UNPKG

4.89 MB · JavaScript · View Raw
1/**
2 * @license
3 * Copyright 2024 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17(function (global, factory) {
18 typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
19 typeof define === 'function' && define.amd ? define(['exports'], factory) :
20 (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.tf = global.tf || {}));
21})(this, (function (exports) { 'use strict';
22
/**
 * Copies the named exports of every module namespace in `m` onto `n`.
 *
 * Strings, arrays and falsy entries in `m` are ignored. The `default` key is
 * never copied, and a key already present on `n` (own or inherited) is left
 * untouched, so earlier sources win. Accessor properties are copied with
 * their original descriptor; data properties become enumerable live getters
 * that read from the source object.
 *
 * @param {object} n - Target namespace object (mutated).
 * @param {Array<*>} m - Candidate source namespaces.
 * @returns {object} The same `n` object.
 */
function _mergeNamespaces(n, m) {
  m.forEach((source) => {
    const isMergeable = source && typeof source !== 'string' && !Array.isArray(source);
    if (!isMergeable) return;
    for (const key of Object.keys(source)) {
      if (key === 'default' || key in n) continue;
      const descriptor = Object.getOwnPropertyDescriptor(source, key);
      const liveBinding = {
        enumerable: true,
        get: function () { return source[key]; }
      };
      Object.defineProperty(n, key, descriptor.get ? descriptor : liveBinding);
    }
  });
  return n;
}
37
// The host's global object: the first defined of globalThis, window, global,
// self; falls back to a fresh empty object when none exist.
var commonjsGlobal = (function () {
  if (typeof globalThis !== 'undefined') return globalThis;
  if (typeof window !== 'undefined') return window;
  if (typeof global !== 'undefined') return global;
  if (typeof self !== 'undefined') return self;
  return {};
})();
39
/**
 * Unwraps the `default` export of a transpiled ES module object.
 *
 * @param {*} x - A CommonJS export value.
 * @returns {*} `x.default` when `x` is truthy, flagged `__esModule`, and has an
 *   own `default` property; otherwise `x` itself.
 */
function getDefaultExportFromCjs(x) {
  const isInteropModule = Boolean(x) && Boolean(x.__esModule) &&
    Object.prototype.hasOwnProperty.call(x, 'default');
  return isInteropModule ? x['default'] : x;
}
43
/**
 * Returns `n.default` when `n` is truthy and has an own `default` property,
 * otherwise `n` unchanged.
 *
 * @param {*} n - A module namespace (or any value).
 * @returns {*}
 */
function getDefaultExportFromNamespaceIfPresent(n) {
  if (!n) return n;
  return Object.prototype.hasOwnProperty.call(n, 'default') ? n['default'] : n;
}
47
/**
 * Returns `n.default` only when `default` is the sole own enumerable key of
 * `n`; a namespace that also carries named exports is returned as-is.
 *
 * @param {*} n - A module namespace (or any value).
 * @returns {*}
 */
function getDefaultExportFromNamespaceIfNotNamed(n) {
  if (!n || !Object.prototype.hasOwnProperty.call(n, 'default')) return n;
  const onlyDefault = Object.keys(n).length === 1;
  return onlyDefault ? n['default'] : n;
}
51
/**
 * Converts an ES-module namespace object into a CommonJS-friendly value.
 *
 * A namespace already flagged `__esModule` is returned unchanged. When its
 * `default` export is a function, the result is a wrapper that forwards both
 * plain calls and `new` invocations to that function (sharing its prototype);
 * otherwise the result is a plain object. The result is flagged `__esModule`
 * and every own enumerable string key of `n` is exposed on it: accessor
 * descriptors are copied verbatim, data properties become live getters.
 *
 * @param {object} n - Module namespace object.
 * @returns {object|Function} The augmented namespace.
 */
function getAugmentedNamespace(n) {
  if (n.__esModule) return n;
  const defaultExport = n.default;
  let augmented;
  if (typeof defaultExport == "function") {
    augmented = function a(...args) {
      if (this instanceof a) {
        // Construct through a bound function so `args` reach the constructor.
        const BoundCtor = Function.bind.apply(defaultExport, [null, ...args]);
        return new BoundCtor();
      }
      return defaultExport.apply(this, args);
    };
    augmented.prototype = defaultExport.prototype;
  } else {
    augmented = {};
  }
  Object.defineProperty(augmented, '__esModule', { value: true });
  for (const key of Object.keys(n)) {
    const descriptor = Object.getOwnPropertyDescriptor(n, key);
    Object.defineProperty(augmented, key, descriptor.get ? descriptor : {
      enumerable: true,
      get: function () {
        return n[key];
      }
    });
  }
  return augmented;
}
79
// Module-level placeholder objects. Nothing in the visible lines assigns to
// them — NOTE(review): presumably filled in by bundled polyfill modules later
// in the file; confirm before relying on their contents.
80 var es_symbol = {};
81
82 var es_symbol_constructor = {};
83
// Wraps a generator object `gen` so that each step is exposed as a promise,
// queuing calls made while an earlier step is still pending.
// NOTE(review): appears to be Babel's async-generator runtime helper — confirm.
84 function _AsyncGenerator(gen) {
// `front`/`back` are the head and tail of a singly linked list of pending
// requests ({key, arg, resolve, reject, next}); a falsy `back` means idle.
85 var front, back;
// Runs one step of `gen` with `key` ("next" | "throw" | "return") and settles
// the request at the head of the queue with the outcome.
86 function resume(key, arg) {
87 try {
88 var result = gen[key](arg),
89 value = result.value,
90 overloaded = value instanceof _OverloadYield;
// The yielded value (or an _OverloadYield's wrapped `v`) is awaited first.
91 Promise.resolve(overloaded ? value.v : value).then(function (arg) {
92 if (overloaded) {
93 var nextKey = "return" === key ? "return" : "next";
// Falsy `k` flag, or an awaited result with `done` set: feed the value back
// into the generator instead of surfacing it to the caller.
94 if (!value.k || arg.done) return resume(nextKey, arg);
// Otherwise step the generator once more and surface that step's value.
95 arg = gen[nextKey](arg).value;
96 }
97 settle(result.done ? "return" : "normal", arg);
98 }, function (err) {
// A rejected await is thrown back into the generator.
99 resume("throw", err);
100 });
101 } catch (err) {
102 settle("throw", err);
103 }
104 }
// Resolves/rejects the head request, then starts the next queued request or
// marks the queue idle.
105 function settle(type, value) {
106 switch (type) {
107 case "return":
108 front.resolve({
109 value: value,
110 done: !0
111 });
112 break;
113 case "throw":
114 front.reject(value);
115 break;
116 default:
117 front.resolve({
118 value: value,
119 done: !1
120 });
121 }
122 (front = front.next) ? resume(front.key, front.arg) : back = null;
123 }
// Enqueues a request and returns its promise; when the queue is idle the
// request becomes the head and runs immediately.
124 this._invoke = function (key, arg) {
125 return new Promise(function (resolve, reject) {
126 var request = {
127 key: key,
128 arg: arg,
129 resolve: resolve,
130 reject: reject,
131 next: null
132 };
133 back ? back = back.next = request : (front = back = request, resume(key, arg));
134 });
// When `gen` has no `return` method, an own `return = undefined` hides the
// prototype's `return` on this instance.
135 }, "function" != typeof gen.return && (this.return = void 0);
136 }
// Async-iterator protocol for _AsyncGenerator instances: the instance is its
// own async iterator, and next/throw/return all route through `_invoke`.
{
  const asyncGeneratorProto = _AsyncGenerator.prototype;
  const asyncIteratorKey = ("function" == typeof Symbol && Symbol.asyncIterator) || "@@asyncIterator";
  asyncGeneratorProto[asyncIteratorKey] = function () {
    return this;
  };
  asyncGeneratorProto.next = function (arg) {
    return this._invoke("next", arg);
  };
  asyncGeneratorProto.throw = function (arg) {
    return this._invoke("throw", arg);
  };
  asyncGeneratorProto.return = function (arg) {
    return this._invoke("return", arg);
  };
}
/**
 * Marker wrapper recognised by _AsyncGenerator when it is yielded.
 *
 * @param {*} value - Wrapped value, stored as `v`.
 * @param {*} kind - Flag stored as `k`.
 */
function _OverloadYield(value, kind) {
  Object.assign(this, { v: value, k: kind });
}
/**
 * Builds the `getMetadata` / `setMetadata` pair exposed on a decorator context.
 *
 * Metadata lives in `metadataMap[key]` (keys must be symbols) and is stored
 * in one of three slots depending on `kind`:
 *   1 -> `public` object, indexed by `property`;
 *   2 -> `private` Map, keyed by `property` (e.g. the per-member Symbol that
 *        old_memberDec creates for private members);
 *   otherwise -> the entry's own `constructor` slot (class metadata).
 * Both methods throw once `decoratorFinishedRef.v` is set.
 *
 * Fix: `setMetadata` for kind 2 previously read `metadataForKey.priv`, which
 * is never written. It was always undefined, so every call replaced the
 * private Map and dropped metadata already stored for other private members
 * under the same key.
 *
 * @param {object} metadataMap - Symbol-keyed metadata storage (mutated).
 * @param {number} kind - 0 (class), 1 (public member) or 2 (private member).
 * @param {*} property - Member identifier used inside the chosen slot.
 * @param {{v: boolean}} decoratorFinishedRef - Set to true when decoration ends.
 * @returns {{getMetadata: Function, setMetadata: Function}}
 */
function old_createMetadataMethodsForProperty(metadataMap, kind, property, decoratorFinishedRef) {
  return {
    getMetadata: function (key) {
      old_assertNotFinished(decoratorFinishedRef, "getMetadata");
      old_assertMetadataKey(key);
      const metadataForKey = metadataMap[key];
      if (void 0 === metadataForKey) return;
      if (1 === kind) {
        const pub = metadataForKey.public;
        if (void 0 !== pub) return pub[property];
      } else if (2 === kind) {
        const priv = metadataForKey.private;
        if (void 0 !== priv) return priv.get(property);
      } else if (Object.prototype.hasOwnProperty.call(metadataForKey, "constructor")) {
        // Own-property check: every object inherits `constructor` from Object.prototype.
        return metadataForKey.constructor;
      }
    },
    setMetadata: function (key, value) {
      old_assertNotFinished(decoratorFinishedRef, "setMetadata");
      old_assertMetadataKey(key);
      let metadataForKey = metadataMap[key];
      if (void 0 === metadataForKey) {
        metadataForKey = metadataMap[key] = {};
      }
      if (1 === kind) {
        let pub = metadataForKey.public;
        if (void 0 === pub) pub = metadataForKey.public = {};
        pub[property] = value;
      } else if (2 === kind) {
        // Read the same `private` slot that getMetadata reads, so entries for
        // other private members are kept.
        let priv = metadataForKey.private;
        if (void 0 === priv) priv = metadataForKey.private = new Map();
        priv.set(property, value);
      } else {
        metadataForKey.constructor = value;
      }
    }
  };
}
/**
 * Finalizes a decorator metadata map and installs it on `obj` under
 * `Symbol.metadata` (or `Symbol.for("Symbol.metadata")` when the former is
 * unavailable).
 *
 * For each symbol key it links `public` objects and whole entries to the
 * parent's (via prototypes). It also flattens the `private` Map into an array
 * of its values, followed by the parent's private array when present. Does
 * nothing when `metadataMap` has no own symbol keys.
 *
 * @param {object} obj - Class or prototype receiving the metadata.
 * @param {object} metadataMap - Symbol-keyed metadata collected during decoration.
 */
function old_convertMetadataMapToFinal(obj, metadataMap) {
  const metadataSymbol = Symbol.metadata || Symbol.for("Symbol.metadata");
  const parentMetadataMap = obj[metadataSymbol];
  const metadataKeys = Object.getOwnPropertySymbols(metadataMap);
  if (metadataKeys.length === 0) return;
  metadataKeys.forEach(function (key) {
    const metaForKey = metadataMap[key];
    const parentMetaForKey = parentMetadataMap ? parentMetadataMap[key] : null;
    const pub = metaForKey.public;
    const parentPub = parentMetaForKey ? parentMetaForKey.public : null;
    if (pub && parentPub) Object.setPrototypeOf(pub, parentPub);
    const priv = metaForKey.private;
    if (priv) {
      const ownPrivate = Array.from(priv.values());
      const parentPriv = parentMetaForKey ? parentMetaForKey.private : null;
      metaForKey.private = parentPriv ? ownPrivate.concat(parentPriv) : ownPrivate;
    }
    if (parentMetaForKey) Object.setPrototypeOf(metaForKey, parentMetaForKey);
  });
  if (parentMetadataMap) Object.setPrototypeOf(metadataMap, parentMetadataMap);
  obj[metadataSymbol] = metadataMap;
}
/**
 * Creates the `addInitializer` function exposed on a decorator context.
 *
 * The returned function throws after decoration has finished or when given a
 * non-function; otherwise it appends the initializer to `initializers`.
 *
 * @param {Function[]} initializers - Destination list (mutated).
 * @param {{v: boolean}} decoratorFinishedRef - Set to true when decoration ends.
 * @returns {Function}
 */
function old_createAddInitializerMethod(initializers, decoratorFinishedRef) {
  return function (initializer) {
    old_assertNotFinished(decoratorFinishedRef, "addInitializer");
    old_assertCallable(initializer, "An initializer");
    initializers.push(initializer);
  };
}
/**
 * Invokes one member decorator `dec` with `value` and a freshly built context.
 *
 * Context keys, in order: kind, name ("#"-prefixed when private), isStatic,
 * isPrivate, addInitializer (non-fields only), access (private only), then the
 * getMetadata/setMetadata pair. Private members get metadata kind 2 keyed by a
 * new Symbol(name); public members use kind 1 keyed by `name`. The
 * finished flag is raised once `dec` returns or throws.
 *
 * @param {Function} dec - The decorator.
 * @param {string} name - Member name (without "#").
 * @param {object} desc - Member descriptor (read for private members).
 * @param {object} metadataMap - Metadata storage for this class side.
 * @param {Function[]} initializers - Destination for addInitializer calls.
 * @param {number} kind - 0 field, 1 accessor, 2 method, 3 getter, 4 setter.
 * @param {boolean} isStatic
 * @param {boolean} isPrivate
 * @param {*} value - Current value handed to the decorator.
 * @returns {*} Whatever `dec` returns.
 */
function old_memberDec(dec, name, desc, metadataMap, initializers, kind, isStatic, isPrivate, value) {
  const kindStr =
    kind === 1 ? "accessor" :
    kind === 2 ? "method" :
    kind === 3 ? "getter" :
    kind === 4 ? "setter" : "field";
  const ctx = {
    kind: kindStr,
    name: isPrivate ? "#" + name : name,
    isStatic: isStatic,
    isPrivate: isPrivate
  };
  const decoratorFinishedRef = { v: false };
  let metadataKind;
  let metadataName;
  if (kind !== 0) {
    ctx.addInitializer = old_createAddInitializerMethod(initializers, decoratorFinishedRef);
  }
  if (isPrivate) {
    metadataKind = 2;
    metadataName = Symbol(name);
    const access = {};
    if (kind === 0) {
      access.get = desc.get;
      access.set = desc.set;
    } else if (kind === 2) {
      access.get = function () {
        return desc.value;
      };
    } else {
      if (kind === 1 || kind === 3) {
        access.get = function () {
          return desc.get.call(this);
        };
      }
      if (kind === 1 || kind === 4) {
        access.set = function (v) {
          desc.set.call(this, v);
        };
      }
    }
    ctx.access = access;
  } else {
    metadataKind = 1;
    metadataName = name;
  }
  try {
    const metadataMethods = old_createMetadataMethodsForProperty(metadataMap, metadataKind, metadataName, decoratorFinishedRef);
    return dec(value, Object.assign(ctx, metadataMethods));
  } finally {
    decoratorFinishedRef.v = true;
  }
}
/**
 * Throws when a context method is used after its decorator has returned.
 *
 * @param {{v: boolean}} decoratorFinishedRef - Finished flag.
 * @param {string} fnName - Name of the method being called (for the message).
 * @throws {Error} If `decoratorFinishedRef.v` is truthy.
 */
function old_assertNotFinished(decoratorFinishedRef, fnName) {
  if (!decoratorFinishedRef.v) return;
  throw new Error("attempted to call " + fnName + " after decoration was finished");
}
/**
 * Ensures a metadata key is a symbol.
 *
 * @param {*} key - Candidate key.
 * @throws {TypeError} If `key` is not a symbol.
 */
function old_assertMetadataKey(key) {
  const isSymbol = typeof key === "symbol";
  if (!isSymbol) {
    throw new TypeError("Metadata keys must be symbols, received: " + key);
  }
}
/**
 * Ensures `fn` is a function.
 *
 * @param {*} fn - Value to check.
 * @param {string} hint - Prefix for the error message.
 * @throws {TypeError} If `fn` is not a function.
 */
function old_assertCallable(fn, hint) {
  if (typeof fn === "function") return;
  throw new TypeError(hint + " must be a function");
}
/**
 * Validates a non-undefined value returned by a decorator.
 *
 * Accessor decorators (kind 1) must return a non-null object whose
 * get/set/init/initializer members, when present, are functions. Every other
 * kind must return a function; the error names it "field" (0), "class" (10)
 * or "method" (anything else).
 *
 * @param {number} kind - Decorator kind code.
 * @param {*} value - Decorator return value.
 * @throws {TypeError} On an invalid shape.
 */
function old_assertValidReturnValue(kind, value) {
  const type = typeof value;
  if (kind === 1) {
    if (type !== "object" || value === null) {
      throw new TypeError("accessor decorators must return an object with get, set, or init properties or void 0");
    }
    const members = [
      ["get", "accessor.get"],
      ["set", "accessor.set"],
      ["init", "accessor.init"],
      ["initializer", "accessor.initializer"]
    ];
    for (const [prop, hint] of members) {
      if (value[prop] !== void 0) old_assertCallable(value[prop], hint);
    }
    return;
  }
  if (type === "function") return;
  const hint = kind === 0 ? "field" : kind === 10 ? "class" : "method";
  throw new TypeError(hint + " decorators must return a function or void 0");
}
/**
 * Reads the initializer from an accessor decorator's return object.
 *
 * Prefers `init`; when `init` is null/undefined it falls back to the legacy
 * `initializer` property and, if that is truthy, emits a rename warning via
 * `console.warn` (when `console` exists).
 *
 * @param {object} desc - Accessor decorator return value.
 * @returns {*} The chosen initializer (possibly undefined).
 */
function old_getInit(desc) {
  let initializer = desc.init;
  if (initializer == null) {
    initializer = desc.initializer;
    if (initializer && typeof console !== "undefined") {
      console.warn(".initializer has been renamed to .init as of March 2022");
    }
  }
  return initializer;
}
// Applies the decorator(s) for one class member and records the results.
// decInfo layout: [0] a decorator function or an array of them, [1] kind,
// [2] name, [3]/[4] the private getter/setter or private value (present only
// for private members). Kind codes: 0 field, 1 accessor, 2 method,
// 3 getter, 4 setter.
// Pushes onto `ret`: one initializer for fields/accessors; for private
// non-fields, the accessor functions or method. Public non-fields are
// redefined on `base` instead.
271 function old_applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, metadataMap, initializers) {
272 var desc,
273 initializer,
274 value,
275 newValue,
276 get,
277 set,
278 decs = decInfo[0];
// The `if` condition is a comma expression: it first builds `desc` (from
// decInfo for private members, from `base` for public non-fields) and the
// current `value`; its last operand picks the single-decorator path.
279 if (isPrivate ? desc = 0 === kind || 1 === kind ? {
280 get: decInfo[3],
281 set: decInfo[4]
282 } : 3 === kind ? {
283 get: decInfo[3]
284 } : 4 === kind ? {
285 set: decInfo[3]
286 } : {
287 value: decInfo[3]
288 } : 0 !== kind && (desc = Object.getOwnPropertyDescriptor(base, name)), 1 === kind ? value = {
289 get: desc.get,
290 set: desc.set
291 } : 2 === kind ? value = desc.value : 3 === kind ? value = desc.get : 4 === kind && (value = desc.set), "function" == typeof decs) void 0 !== (newValue = old_memberDec(decs, name, desc, metadataMap, initializers, kind, isStatic, isPrivate, value)) && (old_assertValidReturnValue(kind, newValue), 0 === kind ? initializer = newValue : 1 === kind ? (initializer = old_getInit(newValue), get = newValue.get || value.get, set = newValue.set || value.set, value = {
292 get: get,
293 set: set
294 }) : value = newValue);else for (var i = decs.length - 1; i >= 0; i--) {
// Array of decorators: applied last-to-first; each sees the previous result.
// Initializers accumulate into a function, then an array.
295 var newInit;
296 if (void 0 !== (newValue = old_memberDec(decs[i], name, desc, metadataMap, initializers, kind, isStatic, isPrivate, value))) old_assertValidReturnValue(kind, newValue), 0 === kind ? newInit = newValue : 1 === kind ? (newInit = old_getInit(newValue), get = newValue.get || value.get, set = newValue.set || value.set, value = {
297 get: get,
298 set: set
299 }) : value = newValue, void 0 !== newInit && (void 0 === initializer ? initializer = newInit : "function" == typeof initializer ? initializer = [initializer, newInit] : initializer.push(newInit));
300 }
// Fields/accessors always contribute one (instance, init) => value function:
// identity when no initializer, a left-to-right pipeline for an array, or a
// wrapper calling the single initializer with `instance` as `this`.
301 if (0 === kind || 1 === kind) {
302 if (void 0 === initializer) initializer = function (instance, init) {
303 return init;
304 };else if ("function" != typeof initializer) {
305 var ownInitializers = initializer;
306 initializer = function (instance, init) {
307 for (var value = init, i = 0; i < ownInitializers.length; i++) value = ownInitializers[i].call(instance, value);
308 return value;
309 };
310 } else {
311 var originalInitializer = initializer;
312 initializer = function (instance, init) {
313 return originalInitializer.call(instance, init);
314 };
315 }
316 ret.push(initializer);
317 }
// Non-fields: write the decorated value back into `desc`. Private members push
// callable accessors/method onto `ret`; public ones are redefined on `base`.
318 0 !== kind && (1 === kind ? (desc.get = value.get, desc.set = value.set) : 2 === kind ? desc.value = value : 3 === kind ? desc.get = value : 4 === kind && (desc.set = value), isPrivate ? 1 === kind ? (ret.push(function (instance, args) {
319 return value.get.call(instance, args);
320 }), ret.push(function (instance, args) {
321 return value.set.call(instance, args);
322 })) : 2 === kind ? ret.push(value) : ret.push(function (instance, args) {
323 return value.call(instance, args);
324 }) : Object.defineProperty(base, name, desc));
325 }
// Applies every member decorator entry in `decInfos` to `Class`, appending
// results to `ret`. Non-array entries are skipped.
// Finally appends a runner for prototype initializers, then one for static
// initializers (each only if any non-field member was seen on that side).
326 function old_applyMemberDecs(ret, Class, protoMetadataMap, staticMetadataMap, decInfos) {
327 for (var protoInitializers, staticInitializers, existingProtoNonFields = new Map(), existingStaticNonFields = new Map(), i = 0; i < decInfos.length; i++) {
328 var decInfo = decInfos[i];
329 if (Array.isArray(decInfo)) {
// NOTE(review): these are function-scoped `var`s, so `initializers` keeps the
// previous iteration's array for fields (kind 0). old_memberDec only exposes
// addInitializer when kind !== 0, so it goes unused there.
330 var base,
331 metadataMap,
332 initializers,
333 kind = decInfo[1],
334 name = decInfo[2],
335 isPrivate = decInfo.length > 3,
336 isStatic = kind >= 5;
// Static kinds are encoded as kind + 5; `kind -= 5` normalizes them in place.
// Initializer arrays are created lazily, only for non-field members.
337 if (isStatic ? (base = Class, metadataMap = staticMetadataMap, 0 !== (kind -= 5) && (initializers = staticInitializers = staticInitializers || [])) : (base = Class.prototype, metadataMap = protoMetadataMap, 0 !== kind && (initializers = protoInitializers = protoInitializers || [])), 0 !== kind && !isPrivate) {
// Public non-fields: a name may be decorated once, except that one getter (3)
// and one setter (4) may pair up. `true` marks a name as closed.
338 var existingNonFields = isStatic ? existingStaticNonFields : existingProtoNonFields,
339 existingKind = existingNonFields.get(name) || 0;
340 if (!0 === existingKind || 3 === existingKind && 4 !== kind || 4 === existingKind && 3 !== kind) throw new Error("Attempted to decorate a public method/accessor that has the same name as a previously decorated public method/accessor. This is not currently supported by the decorators plugin. Property name was: " + name);
341 !existingKind && kind > 2 ? existingNonFields.set(name, kind) : existingNonFields.set(name, !0);
342 }
343 old_applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, metadataMap, initializers);
344 }
345 }
346 old_pushInitializers(ret, protoInitializers), old_pushInitializers(ret, staticInitializers);
347 }
/**
 * Appends to `ret` a runner that calls each queued initializer with the
 * instance as `this`, then returns the instance. No-op when `initializers`
 * is falsy.
 *
 * @param {Array} ret - Output list (mutated).
 * @param {Function[]|undefined} initializers - Queued initializers.
 */
function old_pushInitializers(ret, initializers) {
  if (!initializers) return;
  ret.push(function (instance) {
    // Re-read `length` each pass so initializers added while running also run.
    let index = 0;
    while (index < initializers.length) {
      initializers[index].call(instance);
      index += 1;
    }
    return instance;
  });
}
/**
 * Applies class decorators (last to first) to `targetClass`.
 *
 * Each decorator receives the current class and a context with kind "class",
 * the original class name, addInitializer, and class-level metadata methods.
 * A non-undefined return value is validated and replaces the class. When at
 * least one decorator exists, pushes the final class and a runner that calls
 * every registered class initializer with the final class as `this`.
 *
 * @param {Array} ret - Output list (mutated).
 * @param {Function} targetClass - Class being decorated.
 * @param {object} metadataMap - Static-side metadata storage.
 * @param {Function[]} classDecs - Class decorators.
 */
function old_applyClassDecs(ret, targetClass, metadataMap, classDecs) {
  if (classDecs.length > 0) {
    const initializers = [];
    const name = targetClass.name;
    let currentClass = targetClass;
    for (let idx = classDecs.length - 1; idx >= 0; idx--) {
      const finishedRef = { v: false };
      let replacement;
      try {
        const ctx = Object.assign({
          kind: "class",
          name: name,
          addInitializer: old_createAddInitializerMethod(initializers, finishedRef)
        }, old_createMetadataMethodsForProperty(metadataMap, 0, name, finishedRef));
        replacement = classDecs[idx](currentClass, ctx);
      } finally {
        finishedRef.v = true;
      }
      if (replacement !== void 0) {
        old_assertValidReturnValue(10, replacement);
        currentClass = replacement;
      }
    }
    const finalClass = currentClass;
    ret.push(finalClass, function () {
      for (let j = 0; j < initializers.length; j++) initializers[j].call(finalClass);
    });
  }
}
/**
 * Entry point of the metadata-aware decorator helper.
 *
 * Applies member decorators, finalizes prototype metadata, applies class
 * decorators, then finalizes static metadata on the class.
 *
 * @param {Function} targetClass - Class being decorated.
 * @param {Array} memberDecs - Member decorator descriptors.
 * @param {Function[]} classDecs - Class decorators.
 * @returns {Array} Collected initializers, private accessors and, when class
 *   decorators exist, the final class plus its initializer runner.
 */
function _applyDecs(targetClass, memberDecs, classDecs) {
  const ret = [];
  const staticMetadataMap = {};
  const protoMetadataMap = {};
  old_applyMemberDecs(ret, targetClass, protoMetadataMap, staticMetadataMap, memberDecs);
  old_convertMetadataMapToFinal(targetClass.prototype, protoMetadataMap);
  old_applyClassDecs(ret, targetClass, staticMetadataMap, classDecs);
  old_convertMetadataMapToFinal(targetClass, staticMetadataMap);
  return ret;
}
// Builds the implementation behind _applyDecs2203. Compared with _applyDecs,
// member contexts use `static`/`private` keys and carry no
// getMetadata/setMetadata methods.
383 function applyDecs2203Factory() {
// Returns ctx.addInitializer: throws once decoration has finished or for a
// non-function argument, otherwise queues the initializer.
384 function createAddInitializerMethod(initializers, decoratorFinishedRef) {
385 return function (initializer) {
386 !function (decoratorFinishedRef, fnName) {
387 if (decoratorFinishedRef.v) throw new Error("attempted to call " + fnName + " after decoration was finished");
388 }(decoratorFinishedRef, "addInitializer"), assertCallable(initializer, "An initializer"), initializers.push(initializer);
389 };
390 }
// Invokes one member decorator. Kind codes: 0 field, 1 accessor, 2 method,
// 3 getter, 4 setter. ctx.access holds only the get/set helpers that apply;
// public fields read/write `this[name]`.
391 function memberDec(dec, name, desc, initializers, kind, isStatic, isPrivate, value) {
392 var kindStr;
393 switch (kind) {
394 case 1:
395 kindStr = "accessor";
396 break;
397 case 2:
398 kindStr = "method";
399 break;
400 case 3:
401 kindStr = "getter";
402 break;
403 case 4:
404 kindStr = "setter";
405 break;
406 default:
407 kindStr = "field";
408 }
409 var get,
410 set,
411 ctx = {
412 kind: kindStr,
413 name: isPrivate ? "#" + name : name,
414 static: isStatic,
415 private: isPrivate
416 },
417 decoratorFinishedRef = {
418 v: !1
419 };
420 0 !== kind && (ctx.addInitializer = createAddInitializerMethod(initializers, decoratorFinishedRef)), 0 === kind ? isPrivate ? (get = desc.get, set = desc.set) : (get = function () {
421 return this[name];
422 }, set = function (v) {
423 this[name] = v;
424 }) : 2 === kind ? get = function () {
425 return desc.value;
426 } : (1 !== kind && 3 !== kind || (get = function () {
427 return desc.get.call(this);
428 }), 1 !== kind && 4 !== kind || (set = function (v) {
429 desc.set.call(this, v);
430 })), ctx.access = get && set ? {
431 get: get,
432 set: set
433 } : get ? {
434 get: get
435 } : {
436 set: set
437 };
438 try {
439 return dec(value, ctx);
440 } finally {
// Marks the context as finished even if the decorator threw.
441 decoratorFinishedRef.v = !0;
442 }
443 }
444 function assertCallable(fn, hint) {
445 if ("function" != typeof fn) throw new TypeError(hint + " must be a function");
446 }
// Validates a decorator's non-undefined return: accessors (kind 1) must return
// an object with callable get/set/init; other kinds (10 = class) a function.
447 function assertValidReturnValue(kind, value) {
448 var type = typeof value;
449 if (1 === kind) {
450 if ("object" !== type || null === value) throw new TypeError("accessor decorators must return an object with get, set, or init properties or void 0");
451 void 0 !== value.get && assertCallable(value.get, "accessor.get"), void 0 !== value.set && assertCallable(value.set, "accessor.set"), void 0 !== value.init && assertCallable(value.init, "accessor.init");
452 } else if ("function" !== type) {
453 var hint;
454 throw hint = 0 === kind ? "field" : 10 === kind ? "class" : "method", new TypeError(hint + " decorators must return a function or void 0");
455 }
456 }
// Applies the decorator(s) of one member. decInfo: [0] decorator or array,
// [1] kind, [2] name, [3]/[4] private getter/setter or value. Arrays apply
// last-to-first. Fields/accessors push one initializer onto `ret`; private
// non-fields push accessor functions/method; public non-fields are
// redefined on `base`.
457 function applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, initializers) {
458 var desc,
459 init,
460 value,
461 newValue,
462 get,
463 set,
464 decs = decInfo[0];
465 if (isPrivate ? desc = 0 === kind || 1 === kind ? {
466 get: decInfo[3],
467 set: decInfo[4]
468 } : 3 === kind ? {
469 get: decInfo[3]
470 } : 4 === kind ? {
471 set: decInfo[3]
472 } : {
473 value: decInfo[3]
474 } : 0 !== kind && (desc = Object.getOwnPropertyDescriptor(base, name)), 1 === kind ? value = {
475 get: desc.get,
476 set: desc.set
477 } : 2 === kind ? value = desc.value : 3 === kind ? value = desc.get : 4 === kind && (value = desc.set), "function" == typeof decs) void 0 !== (newValue = memberDec(decs, name, desc, initializers, kind, isStatic, isPrivate, value)) && (assertValidReturnValue(kind, newValue), 0 === kind ? init = newValue : 1 === kind ? (init = newValue.init, get = newValue.get || value.get, set = newValue.set || value.set, value = {
478 get: get,
479 set: set
480 }) : value = newValue);else for (var i = decs.length - 1; i >= 0; i--) {
481 var newInit;
482 if (void 0 !== (newValue = memberDec(decs[i], name, desc, initializers, kind, isStatic, isPrivate, value))) assertValidReturnValue(kind, newValue), 0 === kind ? newInit = newValue : 1 === kind ? (newInit = newValue.init, get = newValue.get || value.get, set = newValue.set || value.set, value = {
483 get: get,
484 set: set
485 }) : value = newValue, void 0 !== newInit && (void 0 === init ? init = newInit : "function" == typeof init ? init = [init, newInit] : init.push(newInit));
486 }
// Normalize the collected initializer(s) into one (instance, init) function.
487 if (0 === kind || 1 === kind) {
488 if (void 0 === init) init = function (instance, init) {
489 return init;
490 };else if ("function" != typeof init) {
491 var ownInitializers = init;
492 init = function (instance, init) {
493 for (var value = init, i = 0; i < ownInitializers.length; i++) value = ownInitializers[i].call(instance, value);
494 return value;
495 };
496 } else {
497 var originalInitializer = init;
498 init = function (instance, init) {
499 return originalInitializer.call(instance, init);
500 };
501 }
502 ret.push(init);
503 }
504 0 !== kind && (1 === kind ? (desc.get = value.get, desc.set = value.set) : 2 === kind ? desc.value = value : 3 === kind ? desc.get = value : 4 === kind && (desc.set = value), isPrivate ? 1 === kind ? (ret.push(function (instance, args) {
505 return value.get.call(instance, args);
506 }), ret.push(function (instance, args) {
507 return value.set.call(instance, args);
508 })) : 2 === kind ? ret.push(value) : ret.push(function (instance, args) {
509 return value.call(instance, args);
510 }) : Object.defineProperty(base, name, desc));
511 }
// Appends a runner calling each initializer with the instance as `this`.
512 function pushInitializers(ret, initializers) {
513 initializers && ret.push(function (instance) {
514 for (var i = 0; i < initializers.length; i++) initializers[i].call(instance);
515 return instance;
516 });
517 }
// The returned helper applies member decorators (first inline function), then
// class decorators (second inline function), and returns `ret`.
518 return function (targetClass, memberDecs, classDecs) {
519 var ret = [];
520 return function (ret, Class, decInfos) {
521 for (var protoInitializers, staticInitializers, existingProtoNonFields = new Map(), existingStaticNonFields = new Map(), i = 0; i < decInfos.length; i++) {
522 var decInfo = decInfos[i];
523 if (Array.isArray(decInfo)) {
524 var base,
525 initializers,
526 kind = decInfo[1],
527 name = decInfo[2],
528 isPrivate = decInfo.length > 3,
529 isStatic = kind >= 5;
// Static kinds are encoded as kind + 5 and normalized in place here.
530 if (isStatic ? (base = Class, 0 != (kind -= 5) && (initializers = staticInitializers = staticInitializers || [])) : (base = Class.prototype, 0 !== kind && (initializers = protoInitializers = protoInitializers || [])), 0 !== kind && !isPrivate) {
// A public non-field name may be decorated once; only a getter/setter pair
// may share a name.
531 var existingNonFields = isStatic ? existingStaticNonFields : existingProtoNonFields,
532 existingKind = existingNonFields.get(name) || 0;
533 if (!0 === existingKind || 3 === existingKind && 4 !== kind || 4 === existingKind && 3 !== kind) throw new Error("Attempted to decorate a public method/accessor that has the same name as a previously decorated public method/accessor. This is not currently supported by the decorators plugin. Property name was: " + name);
534 !existingKind && kind > 2 ? existingNonFields.set(name, kind) : existingNonFields.set(name, !0);
535 }
536 applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, initializers);
537 }
538 }
539 pushInitializers(ret, protoInitializers), pushInitializers(ret, staticInitializers);
// Class decorators run last-to-first; when any exist, `ret` gains the final
// class followed by a runner for class initializers.
540 }(ret, targetClass, memberDecs), function (ret, targetClass, classDecs) {
541 if (classDecs.length > 0) {
542 for (var initializers = [], newClass = targetClass, name = targetClass.name, i = classDecs.length - 1; i >= 0; i--) {
543 var decoratorFinishedRef = {
544 v: !1
545 };
546 try {
547 var nextNewClass = classDecs[i](newClass, {
548 kind: "class",
549 name: name,
550 addInitializer: createAddInitializerMethod(initializers, decoratorFinishedRef)
551 });
552 } finally {
553 decoratorFinishedRef.v = !0;
554 }
555 void 0 !== nextNewClass && (assertValidReturnValue(10, nextNewClass), newClass = nextNewClass);
556 }
557 ret.push(newClass, function () {
558 for (var i = 0; i < initializers.length; i++) initializers[i].call(newClass);
559 });
560 }
561 }(ret, targetClass, classDecs), ret;
562 };
563 }
// Lazily built implementation shared by every _applyDecs2203 call.
var applyDecs2203Impl;
/**
 * Applies decorators using the 2022-03 helper, building it on first use.
 *
 * @param {Function} targetClass - Class being decorated.
 * @param {Array} memberDecs - Member decorator descriptors.
 * @param {Function[]} classDecs - Class decorators.
 * @returns {Array} Whatever the built implementation returns.
 */
function _applyDecs2203(targetClass, memberDecs, classDecs) {
  if (!applyDecs2203Impl) {
    applyDecs2203Impl = applyDecs2203Factory();
  }
  return applyDecs2203Impl(targetClass, memberDecs, classDecs);
}
// Builds the implementation behind _applyDecs2203R. The returned helper gives
// `{ e, c }`. `e` is the member-decorator result array, computed immediately.
// `c` is a getter that runs the class decorators each time it is read.
// Member contexts use `static`/`private` keys and carry no metadata methods.
568 function applyDecs2203RFactory() {
// Returns ctx.addInitializer: throws once decoration has finished or for a
// non-function argument, otherwise queues the initializer.
569 function createAddInitializerMethod(initializers, decoratorFinishedRef) {
570 return function (initializer) {
571 !function (decoratorFinishedRef, fnName) {
572 if (decoratorFinishedRef.v) throw new Error("attempted to call " + fnName + " after decoration was finished");
573 }(decoratorFinishedRef, "addInitializer"), assertCallable(initializer, "An initializer"), initializers.push(initializer);
574 };
575 }
// Invokes one member decorator. Kind codes: 0 field, 1 accessor, 2 method,
// 3 getter, 4 setter. ctx.access holds only the get/set helpers that apply;
// public fields read/write `this[name]`.
576 function memberDec(dec, name, desc, initializers, kind, isStatic, isPrivate, value) {
577 var kindStr;
578 switch (kind) {
579 case 1:
580 kindStr = "accessor";
581 break;
582 case 2:
583 kindStr = "method";
584 break;
585 case 3:
586 kindStr = "getter";
587 break;
588 case 4:
589 kindStr = "setter";
590 break;
591 default:
592 kindStr = "field";
593 }
594 var get,
595 set,
596 ctx = {
597 kind: kindStr,
598 name: isPrivate ? "#" + name : name,
599 static: isStatic,
600 private: isPrivate
601 },
602 decoratorFinishedRef = {
603 v: !1
604 };
605 0 !== kind && (ctx.addInitializer = createAddInitializerMethod(initializers, decoratorFinishedRef)), 0 === kind ? isPrivate ? (get = desc.get, set = desc.set) : (get = function () {
606 return this[name];
607 }, set = function (v) {
608 this[name] = v;
609 }) : 2 === kind ? get = function () {
610 return desc.value;
611 } : (1 !== kind && 3 !== kind || (get = function () {
612 return desc.get.call(this);
613 }), 1 !== kind && 4 !== kind || (set = function (v) {
614 desc.set.call(this, v);
615 })), ctx.access = get && set ? {
616 get: get,
617 set: set
618 } : get ? {
619 get: get
620 } : {
621 set: set
622 };
623 try {
624 return dec(value, ctx);
625 } finally {
// Marks the context as finished even if the decorator threw.
626 decoratorFinishedRef.v = !0;
627 }
628 }
629 function assertCallable(fn, hint) {
630 if ("function" != typeof fn) throw new TypeError(hint + " must be a function");
631 }
// Validates a decorator's non-undefined return: accessors (kind 1) must return
// an object with callable get/set/init; other kinds (10 = class) a function.
632 function assertValidReturnValue(kind, value) {
633 var type = typeof value;
634 if (1 === kind) {
635 if ("object" !== type || null === value) throw new TypeError("accessor decorators must return an object with get, set, or init properties or void 0");
636 void 0 !== value.get && assertCallable(value.get, "accessor.get"), void 0 !== value.set && assertCallable(value.set, "accessor.set"), void 0 !== value.init && assertCallable(value.init, "accessor.init");
637 } else if ("function" !== type) {
638 var hint;
639 throw hint = 0 === kind ? "field" : 10 === kind ? "class" : "method", new TypeError(hint + " decorators must return a function or void 0");
640 }
641 }
// Applies the decorator(s) of one member. decInfo: [0] decorator or array,
// [1] kind, [2] name, [3]/[4] private getter/setter or value. Arrays apply
// last-to-first. Fields/accessors push one initializer onto `ret`; private
// non-fields push accessor functions/method; public non-fields are
// redefined on `base`.
642 function applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, initializers) {
643 var desc,
644 init,
645 value,
646 newValue,
647 get,
648 set,
649 decs = decInfo[0];
650 if (isPrivate ? desc = 0 === kind || 1 === kind ? {
651 get: decInfo[3],
652 set: decInfo[4]
653 } : 3 === kind ? {
654 get: decInfo[3]
655 } : 4 === kind ? {
656 set: decInfo[3]
657 } : {
658 value: decInfo[3]
659 } : 0 !== kind && (desc = Object.getOwnPropertyDescriptor(base, name)), 1 === kind ? value = {
660 get: desc.get,
661 set: desc.set
662 } : 2 === kind ? value = desc.value : 3 === kind ? value = desc.get : 4 === kind && (value = desc.set), "function" == typeof decs) void 0 !== (newValue = memberDec(decs, name, desc, initializers, kind, isStatic, isPrivate, value)) && (assertValidReturnValue(kind, newValue), 0 === kind ? init = newValue : 1 === kind ? (init = newValue.init, get = newValue.get || value.get, set = newValue.set || value.set, value = {
663 get: get,
664 set: set
665 }) : value = newValue);else for (var i = decs.length - 1; i >= 0; i--) {
666 var newInit;
667 if (void 0 !== (newValue = memberDec(decs[i], name, desc, initializers, kind, isStatic, isPrivate, value))) assertValidReturnValue(kind, newValue), 0 === kind ? newInit = newValue : 1 === kind ? (newInit = newValue.init, get = newValue.get || value.get, set = newValue.set || value.set, value = {
668 get: get,
669 set: set
670 }) : value = newValue, void 0 !== newInit && (void 0 === init ? init = newInit : "function" == typeof init ? init = [init, newInit] : init.push(newInit));
671 }
// Normalize the collected initializer(s) into one (instance, init) function.
672 if (0 === kind || 1 === kind) {
673 if (void 0 === init) init = function (instance, init) {
674 return init;
675 };else if ("function" != typeof init) {
676 var ownInitializers = init;
677 init = function (instance, init) {
678 for (var value = init, i = 0; i < ownInitializers.length; i++) value = ownInitializers[i].call(instance, value);
679 return value;
680 };
681 } else {
682 var originalInitializer = init;
683 init = function (instance, init) {
684 return originalInitializer.call(instance, init);
685 };
686 }
687 ret.push(init);
688 }
689 0 !== kind && (1 === kind ? (desc.get = value.get, desc.set = value.set) : 2 === kind ? desc.value = value : 3 === kind ? desc.get = value : 4 === kind && (desc.set = value), isPrivate ? 1 === kind ? (ret.push(function (instance, args) {
690 return value.get.call(instance, args);
691 }), ret.push(function (instance, args) {
692 return value.set.call(instance, args);
693 })) : 2 === kind ? ret.push(value) : ret.push(function (instance, args) {
694 return value.call(instance, args);
695 }) : Object.defineProperty(base, name, desc));
696 }
// Applies all member decorators to `Class` and returns a fresh result array:
// per-member entries, then prototype and static initializer runners.
697 function applyMemberDecs(Class, decInfos) {
698 for (var protoInitializers, staticInitializers, ret = [], existingProtoNonFields = new Map(), existingStaticNonFields = new Map(), i = 0; i < decInfos.length; i++) {
699 var decInfo = decInfos[i];
700 if (Array.isArray(decInfo)) {
701 var base,
702 initializers,
703 kind = decInfo[1],
704 name = decInfo[2],
705 isPrivate = decInfo.length > 3,
706 isStatic = kind >= 5;
// Static kinds are encoded as kind + 5 and normalized in place here.
707 if (isStatic ? (base = Class, 0 !== (kind -= 5) && (initializers = staticInitializers = staticInitializers || [])) : (base = Class.prototype, 0 !== kind && (initializers = protoInitializers = protoInitializers || [])), 0 !== kind && !isPrivate) {
// A public non-field name may be decorated once; only a getter/setter pair
// may share a name.
708 var existingNonFields = isStatic ? existingStaticNonFields : existingProtoNonFields,
709 existingKind = existingNonFields.get(name) || 0;
710 if (!0 === existingKind || 3 === existingKind && 4 !== kind || 4 === existingKind && 3 !== kind) throw new Error("Attempted to decorate a public method/accessor that has the same name as a previously decorated public method/accessor. This is not currently supported by the decorators plugin. Property name was: " + name);
711 !existingKind && kind > 2 ? existingNonFields.set(name, kind) : existingNonFields.set(name, !0);
712 }
713 applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, initializers);
714 }
715 }
716 return pushInitializers(ret, protoInitializers), pushInitializers(ret, staticInitializers), ret;
717 }
// Appends a runner calling each initializer with the instance as `this`.
718 function pushInitializers(ret, initializers) {
719 initializers && ret.push(function (instance) {
720 for (var i = 0; i < initializers.length; i++) initializers[i].call(instance);
721 return instance;
722 });
723 }
724 return function (targetClass, memberDecs, classDecs) {
725 return {
726 e: applyMemberDecs(targetClass, memberDecs),
// Re-evaluated on every read: runs the class decorators last-to-first and
// returns [finalClass, runClassInitializers], or undefined with no decorators.
727 get c() {
728 return function (targetClass, classDecs) {
729 if (classDecs.length > 0) {
730 for (var initializers = [], newClass = targetClass, name = targetClass.name, i = classDecs.length - 1; i >= 0; i--) {
731 var decoratorFinishedRef = {
732 v: !1
733 };
734 try {
735 var nextNewClass = classDecs[i](newClass, {
736 kind: "class",
737 name: name,
738 addInitializer: createAddInitializerMethod(initializers, decoratorFinishedRef)
739 });
740 } finally {
741 decoratorFinishedRef.v = !0;
742 }
743 void 0 !== nextNewClass && (assertValidReturnValue(10, nextNewClass), newClass = nextNewClass);
744 }
745 return [newClass, function () {
746 for (var i = 0; i < initializers.length; i++) initializers[i].call(newClass);
747 }];
748 }
749 }(targetClass, classDecs);
750 }
751 };
752 };
753 }
/**
 * Entry point for the 2022-03 decorators helper. On first use it builds the
 * implementation via applyDecs2203RFactory(), overwrites this binding with it
 * so later calls skip the factory, and forwards the current call.
 */
function _applyDecs2203R(targetClass, memberDecs, classDecs) {
  var implementation = applyDecs2203RFactory();
  _applyDecs2203R = implementation;
  return implementation(targetClass, memberDecs, classDecs);
}
/**
 * Creates the `addInitializer` function handed to a decorator's context.
 * The returned function refuses to run once decoration is finished
 * (decoratorFinishedRef.v is truthy), requires a callable argument, and
 * appends it to `initializers`.
 */
function createAddInitializerMethod(initializers, decoratorFinishedRef) {
  return function (initializer) {
    assertNotFinished(decoratorFinishedRef, "addInitializer");
    assertCallable(initializer, "An initializer");
    initializers.push(initializer);
  };
}
/**
 * Guards private-element access: throws a TypeError unless the brand check
 * `has(target)` is truthy.
 */
function assertInstanceIfPrivate(has, target) {
  var isInstance = has(target);
  if (isInstance) return;
  throw new TypeError("Attempted to access private element on non-instance");
}
/**
 * Calls one member decorator `dec` as `dec(value, ctx)` and returns its
 * result (undefined or a replacement value).
 *
 * Kind codes (see the switch below): 0 = field, 1 = accessor, 2 = method,
 * 3 = getter, 4 = setter.
 *
 * @param {Function} dec - decorator to invoke.
 * @param {*} name - element name; reported as "#" + name when private.
 * @param {object} desc - descriptor used by the access functions built below
 *   (for private elements the caller passes a synthetic descriptor).
 * @param {Array} initializers - list that ctx.addInitializer appends to.
 * @param {number} kind - element kind code.
 * @param {boolean} isStatic - reported as ctx.static.
 * @param {boolean} isPrivate - reported as ctx.private; selects brand-checked access.
 * @param {*} value - first argument passed to the decorator.
 * @param {Function} hasPrivateBrand - brand predicate used for private elements.
 * @returns {*} whatever `dec` returns.
 */
function memberDec(dec, name, desc, initializers, kind, isStatic, isPrivate, value, hasPrivateBrand) {
  var kindStr;
  switch (kind) {
    case 1:
      kindStr = "accessor";
      break;
    case 2:
      kindStr = "method";
      break;
    case 3:
      kindStr = "getter";
      break;
    case 4:
      kindStr = "setter";
      break;
    default:
      kindStr = "field";
  }
  var get,
    set,
    ctx = {
      kind: kindStr,
      name: isPrivate ? "#" + name : name,
      static: isStatic,
      private: isPrivate
    },
    // Flipped to true in the finally below so late addInitializer calls throw.
    decoratorFinishedRef = {
      v: !1
    };
  // Fields get no ctx.addInitializer. Private elements and public
  // accessors/getters/setters read and write through `desc`. Public fields and
  // methods fall through to the else branch and use plain `target[name]`.
  if (0 !== kind && (ctx.addInitializer = createAddInitializerMethod(initializers, decoratorFinishedRef)), isPrivate || 0 !== kind && 2 !== kind) {
    // Reaching here with kind 2 means a private method: expose its value after a brand check.
    if (2 === kind) get = function (target) {
      return assertInstanceIfPrivate(hasPrivateBrand, target), desc.value;
    };else {
      // Fields (0) and accessors (1) have both a getter and a setter.
      var t = 0 === kind || 1 === kind;
      (t || 3 === kind) && (get = isPrivate ? function (target) {
        return assertInstanceIfPrivate(hasPrivateBrand, target), desc.get.call(target);
      } : function (target) {
        return desc.get.call(target);
      }), (t || 4 === kind) && (set = isPrivate ? function (target, value) {
        assertInstanceIfPrivate(hasPrivateBrand, target), desc.set.call(target, value);
      } : function (target, value) {
        desc.set.call(target, value);
      });
    }
  } else get = function (target) {
    return target[name];
  }, 0 === kind && (set = function (target, v) {
    target[name] = v;
  });
  // Private: the brand predicate itself (bound copy); public: an `in` check.
  var has = isPrivate ? hasPrivateBrand.bind() : function (target) {
    return name in target;
  };
  // Only include the get/set entries that were actually built above.
  ctx.access = get && set ? {
    get: get,
    set: set,
    has: has
  } : get ? {
    get: get,
    has: has
  } : {
    set: set,
    has: has
  };
  try {
    return dec(value, ctx);
  } finally {
    decoratorFinishedRef.v = !0;
  }
}
/**
 * Throws if a decorator-context method named `fnName` is called after the
 * decorator has returned (decoratorFinishedRef.v is truthy).
 */
function assertNotFinished(decoratorFinishedRef, fnName) {
  if (!decoratorFinishedRef.v) return;
  throw new Error("attempted to call " + fnName + " after decoration was finished");
}
/**
 * Throws a TypeError ("<hint> must be a function") unless `fn` is a function.
 */
function assertCallable(fn, hint) {
  if (typeof fn === "function") return;
  throw new TypeError(hint + " must be a function");
}
/**
 * Validates a decorator's non-undefined return value.
 *
 * kind 1 (accessor): must be a non-null object whose optional get/set/init
 * entries are functions. Every other kind must return a function; the error
 * names kind 0 "field", kind 10 "class", anything else "method".
 */
function assertValidReturnValue(kind, value) {
  var valueType = typeof value;
  if (kind === 1) {
    if (valueType !== "object" || value === null) {
      throw new TypeError("accessor decorators must return an object with get, set, or init properties or void 0");
    }
    if (value.get !== void 0) assertCallable(value.get, "accessor.get");
    if (value.set !== void 0) assertCallable(value.set, "accessor.set");
    if (value.init !== void 0) assertCallable(value.init, "accessor.init");
    return;
  }
  if (valueType === "function") return;
  var hint;
  if (kind === 0) {
    hint = "field";
  } else if (kind === 10) {
    hint = "class";
  } else {
    hint = "method";
  }
  throw new TypeError(hint + " decorators must return a function or void 0");
}
/**
 * Wraps `fn(receiver)` as a method: the returned function passes its own
 * `this` as `fn`'s first argument and returns `fn`'s result.
 */
function curryThis1(fn) {
  return function () {
    var receiver = this;
    return fn(receiver);
  };
}
/**
 * Wraps `fn(receiver, value)` as a setter-style method: the returned function
 * passes its `this` and its single argument to `fn` and returns undefined.
 */
function curryThis2(fn) {
  return function (value) {
    var receiver = this;
    fn(receiver, value);
  };
}
/**
 * Applies the decorator(s) in decInfo[0] to a single class element and
 * records the outcome.
 *
 * - Private elements: a synthetic descriptor is built from decInfo[3]
 *   (and decInfo[4] for fields/accessors, wrapped with curryThis1/curryThis2).
 * - Public non-fields: the descriptor is read from `base` and, once
 *   decorated, written back with Object.defineProperty.
 * - Fields (0) and accessors (1): a combined initializer function is pushed
 *   onto `ret`.
 * - Private non-fields: the decorated accessor functions / method are pushed
 *   onto `ret` instead of being defined on `base`.
 *
 * decInfo[0] is either one decorator function or an array of decorators;
 * arrays are applied from last to first.
 *
 * @param {Array} ret - output list that this function pushes onto.
 * @param {object} base - the class (static) or its prototype (instance).
 * @param {Array} decInfo - [decs, kind, name, ...private access functions].
 * @param {*} name - element name.
 * @param {number} kind - 0 field, 1 accessor, 2 method, 3 getter, 4 setter.
 * @param {boolean} isStatic
 * @param {boolean} isPrivate
 * @param {Array} initializers - forwarded to memberDec for ctx.addInitializer.
 * @param {Function} hasPrivateBrand - forwarded to memberDec for brand checks.
 */
function applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, initializers, hasPrivateBrand) {
  var desc,
    init,
    value,
    newValue,
    get,
    set,
    decs = decInfo[0];
  // Step 1: obtain `desc` (synthetic for private, own descriptor for public
  // non-fields; public fields have none) and derive the initial `value`.
  // Step 2: if `decs` is a single function, run it once; otherwise loop below.
  if (isPrivate ? desc = 0 === kind || 1 === kind ? {
    get: curryThis1(decInfo[3]),
    set: curryThis2(decInfo[4])
  } : 3 === kind ? {
    get: decInfo[3]
  } : 4 === kind ? {
    set: decInfo[3]
  } : {
    value: decInfo[3]
  } : 0 !== kind && (desc = Object.getOwnPropertyDescriptor(base, name)), 1 === kind ? value = {
    get: desc.get,
    set: desc.set
  } : 2 === kind ? value = desc.value : 3 === kind ? value = desc.get : 4 === kind && (value = desc.set), "function" == typeof decs) void 0 !== (newValue = memberDec(decs, name, desc, initializers, kind, isStatic, isPrivate, value, hasPrivateBrand)) && (assertValidReturnValue(kind, newValue), 0 === kind ? init = newValue : 1 === kind ? (init = newValue.init, get = newValue.get || value.get, set = newValue.set || value.set, value = {
    get: get,
    set: set
  }) : value = newValue);else for (var i = decs.length - 1; i >= 0; i--) {
    // Multiple decorators: each sees the value produced by the previous one;
    // initializers are accumulated into a function or an array of functions.
    var newInit;
    if (void 0 !== (newValue = memberDec(decs[i], name, desc, initializers, kind, isStatic, isPrivate, value, hasPrivateBrand))) assertValidReturnValue(kind, newValue), 0 === kind ? newInit = newValue : 1 === kind ? (newInit = newValue.init, get = newValue.get || value.get, set = newValue.set || value.set, value = {
      get: get,
      set: set
    }) : value = newValue, void 0 !== newInit && (void 0 === init ? init = newInit : "function" == typeof init ? init = [init, newInit] : init.push(newInit));
  }
  // Fields and accessors: normalize `init` into one (instance, init) => value
  // function (identity, a chain over the array, or the single initializer).
  if (0 === kind || 1 === kind) {
    if (void 0 === init) init = function (instance, init) {
      return init;
    };else if ("function" != typeof init) {
      var ownInitializers = init;
      init = function (instance, init) {
        for (var value = init, i = 0; i < ownInitializers.length; i++) value = ownInitializers[i].call(instance, value);
        return value;
      };
    } else {
      var originalInitializer = init;
      init = function (instance, init) {
        return originalInitializer.call(instance, init);
      };
    }
    ret.push(init);
  }
  // Non-fields: write the decorated value into `desc`, then either expose it
  // via `ret` (private) or redefine the property on `base` (public).
  0 !== kind && (1 === kind ? (desc.get = value.get, desc.set = value.set) : 2 === kind ? desc.value = value : 3 === kind ? desc.get = value : 4 === kind && (desc.set = value), isPrivate ? 1 === kind ? (ret.push(function (instance, args) {
    return value.get.call(instance, args);
  }), ret.push(function (instance, args) {
    return value.set.call(instance, args);
  })) : 2 === kind ? ret.push(value) : ret.push(function (instance, args) {
    return value.call(instance, args);
  }) : Object.defineProperty(base, name, desc));
}
/**
 * Applies all member decorators for `Class` and returns the accumulated
 * `ret` list. That list holds what applyMemberDec pushed, followed by one
 * runner for the prototype initializers and one for the static initializers
 * (each only if that list was created).
 *
 * Each array entry of `decInfos` is read as [decs, kind, name, ...]:
 * kind >= 5 marks a static element (5 is subtracted to get the base kind),
 * and an entry with more than 3 items is treated as private. Entries that
 * are not arrays are skipped.
 *
 * @param {Function} Class - class being decorated.
 * @param {Array} decInfos - member decorator descriptors.
 * @param {Function} [instanceBrand] - brand predicate for private instance elements.
 * @returns {Array} ret
 * @throws {Error} when a public non-field name is decorated twice in an
 *   unsupported way (anything other than one getter plus one setter).
 */
function applyMemberDecs(Class, decInfos, instanceBrand) {
  for (var protoInitializers, staticInitializers, staticBrand, ret = [], existingProtoNonFields = new Map(), existingStaticNonFields = new Map(), i = 0; i < decInfos.length; i++) {
    var decInfo = decInfos[i];
    if (Array.isArray(decInfo)) {
      // NOTE: these are function-scoped `var`s. `base`/`initializers` are not
      // reset per entry, so a field entry (kind 0) keeps `initializers` from an
      // earlier entry. memberDec only uses it for non-field kinds.
      var base,
        initializers,
        kind = decInfo[1],
        name = decInfo[2],
        isPrivate = decInfo.length > 3,
        isStatic = kind >= 5,
        hasPrivateBrand = instanceBrand;
      // Static: decorate the class itself, normalize kind, and lazily create a
      // brand predicate accepting only `Class` for private statics.
      // Instance: decorate the prototype. Non-field kinds share one
      // initializer list per placement. Then, for public non-fields, track names
      // to reject clashes: a map value of 3/4 means "getter/setter seen", `true`
      // means anything else (or a completed getter+setter pair).
      if (isStatic ? (base = Class, 0 !== (kind -= 5) && (initializers = staticInitializers = staticInitializers || []), isPrivate && !staticBrand && (staticBrand = function (_) {
        return _checkInRHS(_) === Class;
      }), hasPrivateBrand = staticBrand) : (base = Class.prototype, 0 !== kind && (initializers = protoInitializers = protoInitializers || [])), 0 !== kind && !isPrivate) {
        var existingNonFields = isStatic ? existingStaticNonFields : existingProtoNonFields,
          existingKind = existingNonFields.get(name) || 0;
        if (!0 === existingKind || 3 === existingKind && 4 !== kind || 4 === existingKind && 3 !== kind) throw new Error("Attempted to decorate a public method/accessor that has the same name as a previously decorated public method/accessor. This is not currently supported by the decorators plugin. Property name was: " + name);
        !existingKind && kind > 2 ? existingNonFields.set(name, kind) : existingNonFields.set(name, !0);
      }
      applyMemberDec(ret, base, decInfo, name, kind, isStatic, isPrivate, initializers, hasPrivateBrand);
    }
  }
  return pushInitializers(ret, protoInitializers), pushInitializers(ret, staticInitializers), ret;
}
/**
 * When `initializers` is truthy, appends to `ret` a runner that calls every
 * initializer with `this` bound to the given instance and then returns
 * that instance.
 */
function pushInitializers(ret, initializers) {
  if (!initializers) return;
  ret.push(function (instance) {
    // Re-read length each pass, exactly like the original loop.
    for (var idx = 0; idx < initializers.length; idx++) {
      initializers[idx].call(instance);
    }
    return instance;
  });
}
/**
 * Applies class decorators from last to first. Each decorator may return a
 * replacement class, which must be a function.
 *
 * Returns undefined when there are no decorators. Otherwise returns
 * [finalClass, runInitializers], where runInitializers calls every collected
 * class initializer with `this` bound to the final class.
 */
function applyClassDecs(targetClass, classDecs) {
  if (!(classDecs.length > 0)) return;
  var initializers = [];
  var newClass = targetClass;
  var name = targetClass.name;
  for (var i = classDecs.length - 1; i >= 0; i--) {
    var finishedRef = {
      v: false
    };
    var replacement;
    try {
      replacement = classDecs[i](newClass, {
        kind: "class",
        name: name,
        addInitializer: createAddInitializerMethod(initializers, finishedRef)
      });
    } finally {
      // Late addInitializer calls must throw once this decorator has returned.
      finishedRef.v = true;
    }
    if (replacement !== void 0) {
      assertValidReturnValue(10, replacement);
      newClass = replacement;
    }
  }
  return [newClass, function () {
    for (var j = 0; j < initializers.length; j++) initializers[j].call(newClass);
  }];
}
/**
 * 2023-01 decorators entry point. Member decorators are applied eagerly
 * (result under `e`). Class decorators run only when the `c` getter is read.
 */
function _applyDecs2301(targetClass, memberDecs, classDecs, instanceBrand) {
  var memberResult = applyMemberDecs(targetClass, memberDecs, instanceBrand);
  return {
    e: memberResult,
    get c() {
      return applyClassDecs(targetClass, classDecs);
    }
  };
}
/**
 * Builds the iterator used to implement `yield*` inside an async generator
 * that delegates to `inner`.
 *
 * A forwarded call wraps the inner result in `_OverloadYield(promise, 1)` and
 * sets the waiting flag. The next call then clears the flag and hands its
 * argument straight back (or rethrows it, for `throw`). `throw` and `return`
 * are only exposed when `inner` has them.
 */
function _asyncGeneratorDelegate(inner) {
  var delegate = {};
  var isWaiting = false;

  function forward(method, arg) {
    isWaiting = true;
    var pending = new Promise(function (resolve) {
      resolve(inner[method](arg));
    });
    return {
      done: false,
      value: new _OverloadYield(pending, 1)
    };
  }

  var iteratorKey = ("undefined" != typeof Symbol && Symbol.iterator) || "@@iterator";
  delegate[iteratorKey] = function () {
    return this;
  };

  delegate.next = function (value) {
    if (!isWaiting) return forward("next", value);
    isWaiting = false;
    return value;
  };

  if ("function" == typeof inner.throw) {
    delegate.throw = function (value) {
      if (!isWaiting) return forward("throw", value);
      isWaiting = false;
      throw value;
    };
  }

  if ("function" == typeof inner.return) {
    delegate.return = function (value) {
      if (!isWaiting) return forward("return", value);
      isWaiting = false;
      return value;
    };
  }

  return delegate;
}
/**
 * Returns an async iterator for `iterable`.
 *
 * It first tries Symbol.asyncIterator / Symbol.iterator (when Symbol exists),
 * then the "@@asyncIterator" / "@@iterator" string keys. A sync iterator is
 * adapted with AsyncFromSyncIterator. Throws a TypeError if nothing is found.
 */
function _asyncIterator(iterable) {
  var asyncKey, syncKey;
  if ("undefined" != typeof Symbol) {
    asyncKey = Symbol.asyncIterator;
    syncKey = Symbol.iterator;
  }
  for (var attempt = 0; attempt < 2; attempt++) {
    var method;
    if (asyncKey) {
      method = iterable[asyncKey];
      if (method != null) return method.call(iterable);
    }
    if (syncKey) {
      method = iterable[syncKey];
      if (method != null) return new AsyncFromSyncIterator(method.call(iterable));
    }
    // Second pass falls back to the string-keyed protocol names.
    asyncKey = "@@asyncIterator";
    syncKey = "@@iterator";
  }
  throw new TypeError("Object is not async iterable");
}
/**
 * Adapts a synchronous iterator `s` to the async-iterator protocol
 * (%AsyncFromSyncIteratorPrototype% semantics).
 *
 * On first use it replaces itself with the real constructor and prototype,
 * then returns a new instance. Each method forwards to the same-named method
 * of the sync iterator and turns the result into a promise of
 * { value, done }, awaiting `value`.
 *
 * Fix: `throw` previously looked up `this.s.return`. So calling throw() on the
 * async wrapper called the sync iterator's `return` (or just rejected when only
 * `throw` existed), instead of forwarding to the sync iterator's own `throw`.
 */
function AsyncFromSyncIterator(s) {
  // Converts a sync iterator result into a promise of an iterator result.
  function AsyncFromSyncIteratorContinuation(r) {
    if (Object(r) !== r) return Promise.reject(new TypeError(r + " is not an object."));
    var done = r.done;
    return Promise.resolve(r.value).then(function (value) {
      return {
        value: value,
        done: done
      };
    });
  }
  AsyncFromSyncIterator = function (s) {
    this.s = s;
    this.n = s.next;
  };
  AsyncFromSyncIterator.prototype = {
    s: null,
    n: null,
    next: function () {
      return AsyncFromSyncIteratorContinuation(this.n.apply(this.s, arguments));
    },
    return: function (value) {
      var ret = this.s.return;
      return void 0 === ret ? Promise.resolve({
        value: value,
        done: !0
      }) : AsyncFromSyncIteratorContinuation(ret.apply(this.s, arguments));
    },
    throw: function (value) {
      // Forward to the sync iterator's own `throw`; without one, reject with the value.
      var thr = this.s.throw;
      return void 0 === thr ? Promise.reject(value) : AsyncFromSyncIteratorContinuation(thr.apply(this.s, arguments));
    }
  };
  return new AsyncFromSyncIterator(s);
}
/**
 * Marks `value` as an `await` inside an async generator by wrapping it in
 * `_OverloadYield` with kind 0 (the delegate helper uses kind 1).
 */
function _awaitAsyncGenerator(value) {
  var awaitKind = 0;
  return new _OverloadYield(value, awaitKind);
}
/**
 * Validates the right-hand side of an `in` check: returns `value` when it is
 * an object (including functions), otherwise throws a TypeError naming its type.
 */
function _checkInRHS(value) {
  if (Object(value) === value) return value;
  var described = value === null ? "null" : typeof value;
  throw new TypeError("right-hand side of 'in' should be an object, got " + described);
}
/**
 * Defines a configurable, enumerable accessor on `obj`. `type` names the
 * descriptor slot ("get" or "set") that receives `fn`. Returns `obj`
 * (the result of Object.defineProperty).
 */
function _defineAccessor(type, obj, key, fn) {
  var descriptor = {
    configurable: true,
    enumerable: true
  };
  descriptor[type] = fn;
  return Object.defineProperty(obj, key, descriptor);
}
/**
 * Collects at most `i` values from the iterable `arr` into a new array,
 * closing the iterator (calling its `return`) when it stops early.
 *
 * Returns undefined when `arr` is null/undefined or has no iterator method.
 * The `return;` paths inside the try/finally also yield undefined: when
 * i === 0 and the iterator is not an object, or when `return()` yields a
 * non-object. NOTE(review): callers presumably treat undefined as "not
 * iterable" — confirm at call sites.
 *
 * @param {*} arr - source iterable.
 * @param {number} i - maximum number of items to take.
 * @returns {Array|undefined}
 */
function _iterableToArrayLimit(arr, i) {
  var _i = null == arr ? null : "undefined" != typeof Symbol && arr[Symbol.iterator] || arr["@@iterator"];
  if (null != _i) {
    // _s: current step result; _e/_d: captured error and "errored" flag;
    // _x: cached `next` method; _r: result of `return()`;
    // _n: true when no close is needed (iterator finished, or next() threw).
    var _s,
      _e,
      _x,
      _r,
      _arr = [],
      _n = !0,
      _d = !1;
    try {
      if (_x = (_i = _i.call(arr)).next, 0 === i) {
        if (Object(_i) !== _i) return;
        // Nothing to take: still close the iterator in the finally below.
        _n = !1;
      } else for (; !(_n = (_s = _x.call(_i)).done) && (_arr.push(_s.value), _arr.length !== i); _n = !0);
    } catch (err) {
      _d = !0, _e = err;
    } finally {
      try {
        // Close only when we stopped before the iterator reported done.
        if (!_n && null != _i.return && (_r = _i.return(), Object(_r) !== _r)) return;
      } finally {
        // Re-throw an error from the iteration after attempting to close.
        if (_d) throw _e;
      }
    }
    return _arr;
  }
}
/**
 * Loose-mode variant of _iterableToArrayLimit: collects at most `i` values
 * from the iterable `arr` without closing the iterator on early exit.
 *
 * Fix: the loop previously compared the SOURCE's length (`arr.length < i`).
 * Iterables without `length` (Set, Map, generators) therefore always produced
 * [], and a string longer than `i` did too. The bound must be the number of
 * items collected so far (`_arr.length`).
 *
 * @param {*} arr - source iterable.
 * @param {number} i - maximum number of items to take.
 * @returns {Array|undefined} collected items, or undefined when `arr` is
 *   falsy or has no iterator method.
 */
function _iterableToArrayLimitLoose(arr, i) {
  var _i = arr && ("undefined" != typeof Symbol && arr[Symbol.iterator] || arr["@@iterator"]);
  if (null != _i) {
    var _s,
      _arr = [];
    for (_i = _i.call(arr); _arr.length < i && !(_s = _i.next()).done;) _arr.push(_s.value);
    return _arr;
  }
}
// Cached $$typeof marker for elements built by _jsx (filled in lazily).
var REACT_ELEMENT_TYPE;
/**
 * Builds a React element object for `type`.
 *
 * Extra arguments after `key` become props.children: a single value when
 * there is exactly one, an array when there are several. When a props object
 * is present, undefined props are filled from type.defaultProps. With no
 * props object at all, props becomes type.defaultProps itself (or {}).
 * `key` is stringified; undefined becomes null.
 */
function _jsx(type, props, key, children) {
  if (!REACT_ELEMENT_TYPE) {
    REACT_ELEMENT_TYPE = ("function" == typeof Symbol && Symbol.for && Symbol.for("react.element")) || 60103;
  }
  var defaults = type && type.defaultProps;
  var extraCount = arguments.length - 3;
  if (!props && extraCount !== 0) {
    props = {
      children: void 0
    };
  }
  if (extraCount === 1) {
    props.children = children;
  } else if (extraCount > 1) {
    var collected = new Array(extraCount);
    for (var j = 0; j < extraCount; j++) collected[j] = arguments[j + 3];
    props.children = collected;
  }
  if (props && defaults) {
    for (var prop in defaults) {
      if (props[prop] === void 0) props[prop] = defaults[prop];
    }
  } else if (!props) {
    props = defaults || {};
  }
  return {
    $$typeof: REACT_ELEMENT_TYPE,
    type: type,
    key: key === void 0 ? null : "" + key,
    ref: null,
    props: props,
    _owner: null
  };
}
/**
 * Returns the enumerable own string keys of `object` followed by its own
 * symbol keys. When `enumerableOnly` is truthy, non-enumerable symbols are
 * dropped.
 */
function ownKeys$5(object, enumerableOnly) {
  var keys = Object.keys(object);
  if (!Object.getOwnPropertySymbols) return keys;
  var symbols = Object.getOwnPropertySymbols(object);
  if (enumerableOnly) {
    symbols = symbols.filter(function (sym) {
      return Object.getOwnPropertyDescriptor(object, sym).enumerable;
    });
  }
  keys.push.apply(keys, symbols);
  return keys;
}
/**
 * Copies properties from every argument after `target` onto `target`;
 * null/undefined sources are treated as {}. Returns `target`.
 *
 * Copying alternates by argument position, as the original does:
 * - odd positions (1, 3, ...): enumerable own string and symbol keys,
 *   assigned by value through _defineProperty;
 * - even positions: the source's own property descriptors, copied with
 *   Object.defineProperties (or one at a time when getOwnPropertyDescriptors
 *   is unavailable).
 */
function _objectSpread2(target) {
  for (var idx = 1; idx < arguments.length; idx++) {
    var src = arguments[idx] != null ? arguments[idx] : {};
    if (idx % 2 !== 0) {
      ownKeys$5(Object(src), true).forEach(function (k) {
        _defineProperty(target, k, src[k]);
      });
    } else if (Object.getOwnPropertyDescriptors) {
      Object.defineProperties(target, Object.getOwnPropertyDescriptors(src));
    } else {
      ownKeys$5(Object(src)).forEach(function (k) {
        Object.defineProperty(target, k, Object.getOwnPropertyDescriptor(src, k));
      });
    }
  }
  return target;
}
/**
 * Babel's inlined regenerator runtime. The first call builds the runtime
 * object (the local `exports` below, which shadows the module-level
 * `exports` of the surrounding wrapper). It then replaces `_regeneratorRuntime`
 * with a function returning that same object, and returns it.
 *
 * Members assigned below: wrap, isGeneratorFunction, mark, awrap,
 * AsyncIterator, async, keys, values.
 */
function _regeneratorRuntime() {
  "use strict"; /*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
  // Memoize: later calls return the already-built runtime object.
  _regeneratorRuntime = function () {
    return exports;
  };
  var exports = {},
    Op = Object.prototype,
    hasOwn = Op.hasOwnProperty,
    // Fallback when Object.defineProperty is missing: plain assignment of desc.value.
    defineProperty = Object.defineProperty || function (obj, key, desc) {
      obj[key] = desc.value;
    },
    // String keys stand in for well-known symbols when Symbol is unavailable.
    $Symbol = "function" == typeof Symbol ? Symbol : {},
    iteratorSymbol = $Symbol.iterator || "@@iterator",
    asyncIteratorSymbol = $Symbol.asyncIterator || "@@asyncIterator",
    toStringTagSymbol = $Symbol.toStringTag || "@@toStringTag";
  // Defines an enumerable, configurable, writable property and returns its value.
  function define(obj, key, value) {
    return Object.defineProperty(obj, key, {
      value: value,
      enumerable: !0,
      configurable: !0,
      writable: !0
    }), obj[key];
  }
  // If defineProperty throws here, switch define() to plain assignment.
  try {
    define({}, "");
  } catch (err) {
    define = function (obj, key, value) {
      return obj[key] = value;
    };
  }
  // Creates a generator object for the compiled body `innerFn`. Its prototype
  // is outerFn.prototype when that inherits from Generator, otherwise
  // Generator.prototype. The state machine is installed as `_invoke`.
  function wrap(innerFn, outerFn, self, tryLocsList) {
    var protoGenerator = outerFn && outerFn.prototype instanceof Generator ? outerFn : Generator,
      generator = Object.create(protoGenerator.prototype),
      context = new Context(tryLocsList || []);
    return defineProperty(generator, "_invoke", {
      value: makeInvokeMethod(innerFn, self, context)
    }), generator;
  }
  // Calls fn.call(obj, arg) and returns a completion record { type, arg }
  // ("normal" with the result, or "throw" with the caught error).
  function tryCatch(fn, obj, arg) {
    try {
      return {
        type: "normal",
        arg: fn.call(obj, arg)
      };
    } catch (err) {
      return {
        type: "throw",
        arg: err
      };
    }
  }
  exports.wrap = wrap;
  // Returned by internal steps to tell the invoke loop to keep running.
  var ContinueSentinel = {};
  function Generator() {}
  function GeneratorFunction() {}
  function GeneratorFunctionPrototype() {}
  var IteratorPrototype = {};
  define(IteratorPrototype, iteratorSymbol, function () {
    return this;
  });
  // Prefer the engine's own iterator prototype when it can be reached
  // through the prototype chain of values([]).
  var getProto = Object.getPrototypeOf,
    NativeIteratorPrototype = getProto && getProto(getProto(values([])));
  NativeIteratorPrototype && NativeIteratorPrototype !== Op && hasOwn.call(NativeIteratorPrototype, iteratorSymbol) && (IteratorPrototype = NativeIteratorPrototype);
  var Gp = GeneratorFunctionPrototype.prototype = Generator.prototype = Object.create(IteratorPrototype);
  // Adds next/throw/return methods that dispatch to this._invoke(method, arg).
  function defineIteratorMethods(prototype) {
    ["next", "throw", "return"].forEach(function (method) {
      define(prototype, method, function (arg) {
        return this._invoke(method, arg);
      });
    });
  }
  // Async wrapper around a generator. A yielded value carrying `__await`
  // (see awrap) is awaited and fed back in. Any other yielded value is
  // awaited and delivered as the result.
  function AsyncIterator(generator, PromiseImpl) {
    function invoke(method, arg, resolve, reject) {
      var record = tryCatch(generator[method], generator, arg);
      if ("throw" !== record.type) {
        var result = record.arg,
          value = result.value;
        return value && "object" == typeof value && hasOwn.call(value, "__await") ? PromiseImpl.resolve(value.__await).then(function (value) {
          invoke("next", value, resolve, reject);
        }, function (err) {
          invoke("throw", err, resolve, reject);
        }) : PromiseImpl.resolve(value).then(function (unwrapped) {
          result.value = unwrapped, resolve(result);
        }, function (error) {
          return invoke("throw", error, resolve, reject);
        });
      }
      reject(record.arg);
    }
    var previousPromise;
    // Chains each call after the previous one so requests run in order.
    defineProperty(this, "_invoke", {
      value: function (method, arg) {
        function callInvokeWithMethodAndArg() {
          return new PromiseImpl(function (resolve, reject) {
            invoke(method, arg, resolve, reject);
          });
        }
        return previousPromise = previousPromise ? previousPromise.then(callInvokeWithMethodAndArg, callInvokeWithMethodAndArg) : callInvokeWithMethodAndArg();
      }
    });
  }
  // Builds the generator state machine. States used: "suspendedStart",
  // "executing", "suspendedYield", "completed".
  function makeInvokeMethod(innerFn, self, context) {
    var state = "suspendedStart";
    return function (method, arg) {
      if ("executing" === state) throw new Error("Generator is already running");
      if ("completed" === state) {
        if ("throw" === method) throw arg;
        return doneResult();
      }
      for (context.method = method, context.arg = arg;;) {
        // An active yield* delegate handles the call first.
        var delegate = context.delegate;
        if (delegate) {
          var delegateResult = maybeInvokeDelegate(delegate, context);
          if (delegateResult) {
            if (delegateResult === ContinueSentinel) continue;
            return delegateResult;
          }
        }
        if ("next" === context.method) context.sent = context._sent = context.arg;else if ("throw" === context.method) {
          if ("suspendedStart" === state) throw state = "completed", context.arg;
          context.dispatchException(context.arg);
        } else "return" === context.method && context.abrupt("return", context.arg);
        state = "executing";
        var record = tryCatch(innerFn, self, context);
        if ("normal" === record.type) {
          if (state = context.done ? "completed" : "suspendedYield", record.arg === ContinueSentinel) continue;
          return {
            value: record.arg,
            done: context.done
          };
        }
        // The body threw: complete the generator and loop to dispatch the exception.
        "throw" === record.type && (state = "completed", context.method = "throw", context.arg = record.arg);
      }
    };
  }
  // Forwards context.method/arg to a yield* target iterator. Returns
  // ContinueSentinel to keep looping, or the inner result object to yield.
  function maybeInvokeDelegate(delegate, context) {
    var methodName = context.method,
      method = delegate.iterator[methodName];
    if (undefined === method) return context.delegate = null, "throw" === methodName && delegate.iterator.return && (context.method = "return", context.arg = undefined, maybeInvokeDelegate(delegate, context), "throw" === context.method) || "return" !== methodName && (context.method = "throw", context.arg = new TypeError("The iterator does not provide a '" + methodName + "' method")), ContinueSentinel;
    var record = tryCatch(method, delegate.iterator, context.arg);
    if ("throw" === record.type) return context.method = "throw", context.arg = record.arg, context.delegate = null, ContinueSentinel;
    var info = record.arg;
    return info ? info.done ? (context[delegate.resultName] = info.value, context.next = delegate.nextLoc, "return" !== context.method && (context.method = "next", context.arg = undefined), context.delegate = null, ContinueSentinel) : info : (context.method = "throw", context.arg = new TypeError("iterator result is not an object"), context.delegate = null, ContinueSentinel);
  }
  // locs is read as [tryLoc, catchLoc?, finallyLoc?, afterLoc?].
  function pushTryEntry(locs) {
    var entry = {
      tryLoc: locs[0]
    };
    1 in locs && (entry.catchLoc = locs[1]), 2 in locs && (entry.finallyLoc = locs[2], entry.afterLoc = locs[3]), this.tryEntries.push(entry);
  }
  // Resets an entry's completion record to { type: "normal" }.
  function resetTryEntry(entry) {
    var record = entry.completion || {};
    record.type = "normal", delete record.arg, entry.completion = record;
  }
  // Per-generator execution context; tryEntries[0] is the "root" entry.
  function Context(tryLocsList) {
    this.tryEntries = [{
      tryLoc: "root"
    }], tryLocsList.forEach(pushTryEntry, this), this.reset(!0);
  }
  // Returns an iterator for `iterable`: its own iterator method, the object
  // itself when it has next(), an index-based iterator over own indices of
  // array-likes, or an already-done iterator.
  function values(iterable) {
    if (iterable) {
      var iteratorMethod = iterable[iteratorSymbol];
      if (iteratorMethod) return iteratorMethod.call(iterable);
      if ("function" == typeof iterable.next) return iterable;
      if (!isNaN(iterable.length)) {
        var i = -1,
          next = function next() {
            for (; ++i < iterable.length;) if (hasOwn.call(iterable, i)) return next.value = iterable[i], next.done = !1, next;
            return next.value = undefined, next.done = !0, next;
          };
        return next.next = next;
      }
    }
    return {
      next: doneResult
    };
  }
  function doneResult() {
    return {
      value: undefined,
      done: !0
    };
  }
  // Wire up prototypes/constructors and populate the exported API.
  return GeneratorFunction.prototype = GeneratorFunctionPrototype, defineProperty(Gp, "constructor", {
    value: GeneratorFunctionPrototype,
    configurable: !0
  }), defineProperty(GeneratorFunctionPrototype, "constructor", {
    value: GeneratorFunction,
    configurable: !0
  }), GeneratorFunction.displayName = define(GeneratorFunctionPrototype, toStringTagSymbol, "GeneratorFunction"), exports.isGeneratorFunction = function (genFun) {
    var ctor = "function" == typeof genFun && genFun.constructor;
    return !!ctor && (ctor === GeneratorFunction || "GeneratorFunction" === (ctor.displayName || ctor.name));
  }, exports.mark = function (genFun) {
    return Object.setPrototypeOf ? Object.setPrototypeOf(genFun, GeneratorFunctionPrototype) : (genFun.__proto__ = GeneratorFunctionPrototype, define(genFun, toStringTagSymbol, "GeneratorFunction")), genFun.prototype = Object.create(Gp), genFun;
  }, exports.awrap = function (arg) {
    // Marker object recognized by AsyncIterator's invoke() via hasOwn "__await".
    return {
      __await: arg
    };
  }, defineIteratorMethods(AsyncIterator.prototype), define(AsyncIterator.prototype, asyncIteratorSymbol, function () {
    return this;
  }), exports.AsyncIterator = AsyncIterator, exports.async = function (innerFn, outerFn, self, tryLocsList, PromiseImpl) {
    void 0 === PromiseImpl && (PromiseImpl = Promise);
    var iter = new AsyncIterator(wrap(innerFn, outerFn, self, tryLocsList), PromiseImpl);
    // For a plain async function, run to completion and resolve with the final value.
    return exports.isGeneratorFunction(outerFn) ? iter : iter.next().then(function (result) {
      return result.done ? result.value : iter.next();
    });
  }, defineIteratorMethods(Gp), define(Gp, toStringTagSymbol, "Generator"), define(Gp, iteratorSymbol, function () {
    return this;
  }), define(Gp, "toString", function () {
    return "[object Generator]";
  }), exports.keys = function (val) {
    // Snapshot of for-in keys, yielded in order while they are still present.
    var object = Object(val),
      keys = [];
    for (var key in object) keys.push(key);
    return keys.reverse(), function next() {
      for (; keys.length;) {
        var key = keys.pop();
        if (key in object) return next.value = key, next.done = !1, next;
      }
      return next.done = !0, next;
    };
  }, exports.values = values, Context.prototype = {
    constructor: Context,
    // Resets control state; unless skipTempReset, also clears own "t<number>" temporaries.
    reset: function (skipTempReset) {
      if (this.prev = 0, this.next = 0, this.sent = this._sent = undefined, this.done = !1, this.delegate = null, this.method = "next", this.arg = undefined, this.tryEntries.forEach(resetTryEntry), !skipTempReset) for (var name in this) "t" === name.charAt(0) && hasOwn.call(this, name) && !isNaN(+name.slice(1)) && (this[name] = undefined);
    },
    // Marks the generator done; rethrows an exception recorded on the root entry.
    stop: function () {
      this.done = !0;
      var rootRecord = this.tryEntries[0].completion;
      if ("throw" === rootRecord.type) throw rootRecord.arg;
      return this.rval;
    },
    // Routes `exception` to the innermost enclosing catch/finally location.
    dispatchException: function (exception) {
      if (this.done) throw exception;
      var context = this;
      function handle(loc, caught) {
        return record.type = "throw", record.arg = exception, context.next = loc, caught && (context.method = "next", context.arg = undefined), !!caught;
      }
      for (var i = this.tryEntries.length - 1; i >= 0; --i) {
        var entry = this.tryEntries[i],
          record = entry.completion;
        if ("root" === entry.tryLoc) return handle("end");
        if (entry.tryLoc <= this.prev) {
          var hasCatch = hasOwn.call(entry, "catchLoc"),
            hasFinally = hasOwn.call(entry, "finallyLoc");
          if (hasCatch && hasFinally) {
            if (this.prev < entry.catchLoc) return handle(entry.catchLoc, !0);
            if (this.prev < entry.finallyLoc) return handle(entry.finallyLoc);
          } else if (hasCatch) {
            if (this.prev < entry.catchLoc) return handle(entry.catchLoc, !0);
          } else {
            if (!hasFinally) throw new Error("try statement without catch or finally");
            if (this.prev < entry.finallyLoc) return handle(entry.finallyLoc);
          }
        }
      }
    },
    // Handles break/continue/return. Runs an enclosing finally block first when one applies.
    abrupt: function (type, arg) {
      for (var i = this.tryEntries.length - 1; i >= 0; --i) {
        var entry = this.tryEntries[i];
        if (entry.tryLoc <= this.prev && hasOwn.call(entry, "finallyLoc") && this.prev < entry.finallyLoc) {
          var finallyEntry = entry;
          break;
        }
      }
      finallyEntry && ("break" === type || "continue" === type) && finallyEntry.tryLoc <= arg && arg <= finallyEntry.finallyLoc && (finallyEntry = null);
      var record = finallyEntry ? finallyEntry.completion : {};
      return record.type = type, record.arg = arg, finallyEntry ? (this.method = "next", this.next = finallyEntry.finallyLoc, ContinueSentinel) : this.complete(record);
    },
    // Applies a completion record: rethrows "throw", otherwise updates next/rval.
    complete: function (record, afterLoc) {
      if ("throw" === record.type) throw record.arg;
      return "break" === record.type || "continue" === record.type ? this.next = record.arg : "return" === record.type ? (this.rval = this.arg = record.arg, this.method = "return", this.next = "end") : "normal" === record.type && afterLoc && (this.next = afterLoc), ContinueSentinel;
    },
    // End of a finally block: resume the completion that triggered it.
    finish: function (finallyLoc) {
      for (var i = this.tryEntries.length - 1; i >= 0; --i) {
        var entry = this.tryEntries[i];
        if (entry.finallyLoc === finallyLoc) return this.complete(entry.completion, entry.afterLoc), resetTryEntry(entry), ContinueSentinel;
      }
    },
    // Start of a catch block: returns the thrown value recorded for tryLoc.
    catch: function (tryLoc) {
      for (var i = this.tryEntries.length - 1; i >= 0; --i) {
        var entry = this.tryEntries[i];
        if (entry.tryLoc === tryLoc) {
          var record = entry.completion;
          if ("throw" === record.type) {
            var thrown = record.arg;
            resetTryEntry(entry);
          }
          return thrown;
        }
      }
      throw new Error("illegal catch attempt");
    },
    // Begins a yield* delegation; the result is stored in context[resultName].
    delegateYield: function (iterable, resultName, nextLoc) {
      return this.delegate = {
        iterator: values(iterable),
        resultName: resultName,
        nextLoc: nextLoc
      }, "next" === this.method && (this.arg = undefined), ContinueSentinel;
    }
  }, exports;
}
/**
 * `typeof` replacement that also reports "symbol" for polyfilled symbols.
 * On first call it installs the implementation suited to the environment
 * (native `typeof` when real symbols exist) and delegates to it.
 */
function _typeof(obj) {
  "@babel/helpers - typeof";

  var hasNativeSymbols = typeof Symbol === "function" && typeof Symbol.iterator === "symbol";
  if (hasNativeSymbols) {
    _typeof = function (obj) {
      return typeof obj;
    };
  } else {
    _typeof = function (obj) {
      return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj;
    };
  }
  return _typeof(obj);
}
/**
 * Helper for transpiled named capture groups. The first call defines a
 * RegExp subclass (BabelRegExp), replaces `_wrapRegExp` with
 * `(re, groups) => new BabelRegExp(re, void 0, groups)`, and forwards the
 * current arguments to it.
 *
 * `groups` is stored per regex in a WeakMap. Each value is read either as a
 * number or as an array of numbers (see buildGroups). NOTE(review): these
 * presumably map group names to capture indices, with arrays covering names
 * that occur more than once — confirm against the transform emitting calls.
 */
function _wrapRegExp() {
  _wrapRegExp = function (re, groups) {
    return new BabelRegExp(re, void 0, groups);
  };
  var _super = RegExp.prototype,
    _groups = new WeakMap();
  // Builds a real RegExp, records its group map (inheriting `re`'s map if
  // none is given), and re-parents it onto BabelRegExp.prototype.
  function BabelRegExp(re, flags, groups) {
    var _this = new RegExp(re, flags);
    return _groups.set(_this, groups || _groups.get(re)), _setPrototypeOf(_this, BabelRegExp.prototype);
  }
  // Builds a prototype-less { name: value } object from a match (or its
  // indices). For array-valued entries, it picks the first index whose
  // result is defined, falling back to the last index.
  function buildGroups(result, re) {
    var g = _groups.get(re);
    return Object.keys(g).reduce(function (groups, name) {
      var i = g[name];
      if ("number" == typeof i) groups[name] = result[i];else {
        for (var k = 0; void 0 === result[i[k]] && k + 1 < i.length;) k++;
        groups[name] = result[i[k]];
      }
      return groups;
    }, Object.create(null));
  }
  // exec: attach `.groups` to the match and, when present, to `.indices`.
  // Symbol.replace: rewrite "$<name>" in string replacements to numeric
  // references, and pass a groups object to function replacers when the
  // engine did not already supply one as the last argument.
  return _inherits(BabelRegExp, RegExp), BabelRegExp.prototype.exec = function (str) {
    var result = _super.exec.call(this, str);
    if (result) {
      result.groups = buildGroups(result, this);
      var indices = result.indices;
      indices && (indices.groups = buildGroups(indices, this));
    }
    return result;
  }, BabelRegExp.prototype[Symbol.replace] = function (str, substitution) {
    if ("string" == typeof substitution) {
      var groups = _groups.get(this);
      return _super[Symbol.replace].call(this, str, substitution.replace(/\$<([^>]+)>/g, function (_, name) {
        var group = groups[name];
        return "$" + (Array.isArray(group) ? group.join("$") : group);
      }));
    }
    if ("function" == typeof substitution) {
      var _this = this;
      return _super[Symbol.replace].call(this, str, function () {
        var args = arguments;
        return "object" != typeof args[args.length - 1] && (args = [].slice.call(args)).push(buildGroups(args, _this)), substitution.apply(this, args);
      });
    }
    return _super[Symbol.replace].call(this, str, substitution);
  }, _wrapRegExp.apply(this, arguments);
}
/**
 * Legacy await-marker constructor: stores the awaited value on `wrapped`.
 */
function _AwaitValue(value) {
  var instance = this;
  instance.wrapped = value;
}
/**
 * Turns a compiled generator function into an async-generator factory. The
 * returned function calls `fn` with the same `this`/arguments and wraps the
 * generator in `_AsyncGenerator`.
 */
function _wrapAsyncGenerator(fn) {
  return function () {
    var generator = fn.apply(this, arguments);
    return new _AsyncGenerator(generator);
  };
}
/**
 * Advances `gen` by one step with gen[key](arg).
 *
 * If that call (or reading `.value`) throws, `reject` is called with the
 * error. If the step is done, `resolve` gets its value. Otherwise the value
 * is awaited via Promise.resolve and the outcome is routed to `_next` or
 * `_throw`.
 */
function asyncGeneratorStep(gen, resolve, reject, _next, _throw, key, arg) {
  var step, result;
  try {
    step = gen[key](arg);
    result = step.value;
  } catch (err) {
    reject(err);
    return;
  }
  if (!step.done) {
    Promise.resolve(result).then(_next, _throw);
    return;
  }
  resolve(result);
}
/**
 * Converts a compiled generator function into an async function. Each call
 * returns a Promise that drives the generator: every yielded value is
 * awaited and fed back in, until the generator finishes or throws.
 */
function _asyncToGenerator(fn) {
  return function () {
    var thisArg = this;
    var callArgs = arguments;
    return new Promise(function (resolve, reject) {
      var gen = fn.apply(thisArg, callArgs);
      function step(key, arg) {
        asyncGeneratorStep(gen, resolve, reject, _next, _throw, key, arg);
      }
      function _next(value) {
        step("next", value);
      }
      function _throw(err) {
        step("throw", err);
      }
      _next(undefined);
    });
  };
}
/**
 * Guards a transpiled class constructor against being invoked without `new`.
 *
 * @param {*} instance - The `this` value inside the constructor.
 * @param {Function} Constructor - The class being constructed.
 * @throws {TypeError} If `instance` is not an instance of `Constructor`.
 */
function _classCallCheck(instance, Constructor) {
  var calledWithNew = instance instanceof Constructor;
  if (calledWithNew) {
    return;
  }
  throw new TypeError("Cannot call a class as a function");
}
/**
 * Installs class members described by `props` onto `target`.
 * Each entry is mutated in place before use: `enumerable` defaults to false,
 * `configurable` is forced true, and data properties become writable.
 *
 * @param {Object} target - Prototype or constructor receiving the members.
 * @param {Array<Object>} props - Descriptors, each carrying a `key`.
 */
function _defineProperties(target, props) {
  var index = 0;
  while (index < props.length) {
    var prop = props[index];
    index += 1;
    prop.enumerable = prop.enumerable || false;
    prop.configurable = true;
    if ("value" in prop) {
      prop.writable = true;
    }
    Object.defineProperty(target, _toPropertyKey(prop.key), prop);
  }
}
/**
 * Finishes a transpiled class. Instance members go on the prototype and
 * static members on the constructor. The `prototype` slot is then made
 * read-only, as it is for native `class` declarations.
 *
 * @param {Function} Constructor - The class constructor.
 * @param {Array<Object>} [protoProps] - Prototype member descriptors.
 * @param {Array<Object>} [staticProps] - Static member descriptors.
 * @returns {Function} `Constructor`.
 */
function _createClass(Constructor, protoProps, staticProps) {
  if (protoProps) {
    _defineProperties(Constructor.prototype, protoProps);
  }
  if (staticProps) {
    _defineProperties(Constructor, staticProps);
  }
  var lockPrototype = { writable: false };
  Object.defineProperty(Constructor, "prototype", lockPrototype);
  return Constructor;
}
/**
 * Defines every descriptor in `descs` on `obj` as enumerable and
 * configurable. Data descriptors are also made writable. Descriptors are
 * mutated in place.
 *
 * String keys are visited with `for...in`, so inherited enumerable keys are
 * included. Own symbol keys are visited afterwards, when the engine
 * supports symbols.
 *
 * @param {Object} obj - Receiver of the properties.
 * @param {Object} descs - Map from key to property descriptor.
 * @returns {Object} `obj`.
 */
function _defineEnumerableProperties(obj, descs) {
  function install(key) {
    var desc = descs[key];
    desc.configurable = desc.enumerable = true;
    if ("value" in desc) {
      desc.writable = true;
    }
    Object.defineProperty(obj, key, desc);
  }
  for (var key in descs) {
    install(key);
  }
  if (Object.getOwnPropertySymbols) {
    Object.getOwnPropertySymbols(descs).forEach(function (sym) {
      install(sym);
    });
  }
  return obj;
}
/**
 * Copies each own property of `defaults` (including non-enumerable ones)
 * onto `obj`. A property is copied only when its descriptor is configurable
 * and `obj[key]` is currently `undefined`.
 *
 * @param {Object} obj - Object to fill in (mutated).
 * @param {Object} defaults - Source of default properties.
 * @returns {Object} `obj`.
 */
function _defaults(obj, defaults) {
  Object.getOwnPropertyNames(defaults).forEach(function (key) {
    var desc = Object.getOwnPropertyDescriptor(defaults, key);
    var shouldCopy = desc && desc.configurable && obj[key] === undefined;
    if (shouldCopy) {
      Object.defineProperty(obj, key, desc);
    }
  });
  return obj;
}
/**
 * Sets `obj[key] = value` as an enumerable, writable, configurable property.
 *
 * If `key` is already present (own or inherited), `Object.defineProperty`
 * is used so that an inherited setter or read-only slot is not triggered.
 * Otherwise a plain assignment is enough.
 *
 * @param {Object} obj - Target object.
 * @param {*} key - Property key; normalized with `_toPropertyKey`.
 * @param {*} value - Value to store.
 * @returns {Object} `obj`.
 */
function _defineProperty(obj, key, value) {
  var propKey = _toPropertyKey(key);
  if (!(propKey in obj)) {
    obj[propKey] = value;
    return obj;
  }
  Object.defineProperty(obj, propKey, {
    value: value,
    enumerable: true,
    configurable: true,
    writable: true
  });
  return obj;
}
/**
 * `Object.assign` polyfill. On the first call it replaces itself with either
 * a bound `Object.assign` or a manual own-enumerable-key copier, then
 * forwards the current call to that replacement.
 *
 * @param {Object} target - Object receiving the properties.
 * @param {...Object} sources - Objects copied left to right.
 * @returns {Object} `target`.
 */
function _extends() {
  _extends = Object.assign ? Object.assign.bind() : function (target) {
    for (var argIndex = 1; argIndex < arguments.length; argIndex++) {
      var src = arguments[argIndex];
      for (var prop in src) {
        if (!Object.prototype.hasOwnProperty.call(src, prop)) {
          continue;
        }
        target[prop] = src[prop];
      }
    }
    return target;
  };
  return _extends.apply(this, arguments);
}
/**
 * Object-spread helper. For each source after `target`, copies its own
 * enumerable string keys, then its own enumerable symbol keys (when
 * supported), onto `target` via `_defineProperty`. A `null`/`undefined`
 * source contributes nothing; primitive sources are boxed with `Object()`.
 *
 * @param {Object} target - Object to copy into (mutated).
 * @returns {Object} `target`.
 */
function _objectSpread(target) {
  for (var argIndex = 1; argIndex < arguments.length; argIndex++) {
    var raw = arguments[argIndex];
    var source = raw == null ? {} : Object(raw);
    var keys = Object.keys(source);
    if (typeof Object.getOwnPropertySymbols === 'function') {
      var enumerableSymbols = Object.getOwnPropertySymbols(source).filter(function (sym) {
        return Object.getOwnPropertyDescriptor(source, sym).enumerable;
      });
      keys = keys.concat(enumerableSymbols);
    }
    keys.forEach(function (key) {
      _defineProperty(target, key, source[key]);
    });
  }
  return target;
}
/**
 * Wires up `subClass extends superClass` for transpiled classes:
 *  - `subClass.prototype` inherits from `superClass.prototype` (or null);
 *  - its `constructor` points back at `subClass` (non-enumerable);
 *  - the `prototype` slot is made read-only;
 *  - static inheritance is set up when `superClass` is non-null.
 *
 * @param {Function} subClass - Derived class.
 * @param {Function|null} superClass - Base class or null.
 * @throws {TypeError} If `superClass` is neither a function nor null.
 */
function _inherits(subClass, superClass) {
  var isValidSuper = typeof superClass === "function" || superClass === null;
  if (!isValidSuper) {
    throw new TypeError("Super expression must either be null or a function");
  }
  var parentProto = superClass && superClass.prototype;
  subClass.prototype = Object.create(parentProto, {
    constructor: {
      value: subClass,
      writable: true,
      configurable: true
    }
  });
  Object.defineProperty(subClass, "prototype", { writable: false });
  if (superClass) {
    _setPrototypeOf(subClass, superClass);
  }
}
/**
 * "Loose" inheritance: replaces `subClass.prototype` with an object
 * inheriting from `superClass.prototype`, assigns `constructor` by plain
 * assignment (so it is enumerable), and links the constructors for static
 * inheritance. No argument validation is performed.
 *
 * @param {Function} subClass - Derived class.
 * @param {Function} superClass - Base class.
 */
function _inheritsLoose(subClass, superClass) {
  var parentProto = superClass.prototype;
  subClass.prototype = Object.create(parentProto);
  subClass.prototype.constructor = subClass;
  _setPrototypeOf(subClass, superClass);
}
/**
 * Returns the prototype of `o`. On the first call it replaces itself with a
 * bound `Object.getPrototypeOf`. The fallback, which reads `__proto__`
 * first, is used only when `Object.setPrototypeOf` is missing.
 * NOTE(review): the feature test probes `setPrototypeOf`, not
 * `getPrototypeOf` — kept as-is.
 *
 * @param {Object} o - Object to inspect.
 * @returns {Object|null} Its prototype.
 */
function _getPrototypeOf(o) {
  if (Object.setPrototypeOf) {
    _getPrototypeOf = Object.getPrototypeOf.bind();
  } else {
    _getPrototypeOf = function _getPrototypeOf(o) {
      return o.__proto__ || Object.getPrototypeOf(o);
    };
  }
  return _getPrototypeOf(o);
}
/**
 * Sets the prototype of `o` to `p` and returns `o`. On the first call it
 * replaces itself with a bound `Object.setPrototypeOf`, or with a
 * `__proto__` assignment fallback when that API is missing.
 *
 * @param {Object} o - Object to re-parent.
 * @param {Object|null} p - New prototype.
 * @returns {Object} `o`.
 */
function _setPrototypeOf(o, p) {
  if (Object.setPrototypeOf) {
    _setPrototypeOf = Object.setPrototypeOf.bind();
  } else {
    _setPrototypeOf = function _setPrototypeOf(o, p) {
      o.__proto__ = p;
      return o;
    };
  }
  return _setPrototypeOf(o, p);
}
/**
 * Reports whether a real (non-polyfilled) `Reflect.construct` is available.
 * A `.sham` marker means a polyfill. If `Proxy` exists, `Reflect.construct`
 * is assumed native. Otherwise a probe construction of `Boolean` with a
 * foreign new.target is attempted.
 *
 * @returns {boolean} True when `Reflect.construct` can be relied upon.
 */
function _isNativeReflectConstruct() {
  if (typeof Reflect === "undefined" || !Reflect.construct) return false;
  if (Reflect.construct.sham) return false;
  if (typeof Proxy === "function") return true;
  var probeSucceeded;
  try {
    var probe = Reflect.construct(Boolean, [], function () {});
    Boolean.prototype.valueOf.call(probe);
    probeSucceeded = true;
  } catch (e) {
    probeSucceeded = false;
  }
  return probeSucceeded;
}
/**
 * Equivalent of `Reflect.construct(Parent, args, Class)`. On the first call
 * it replaces itself with the native `Reflect.construct` when available.
 * Otherwise it uses a fallback that calls `new` on a bound `Parent` and then
 * re-parents the result to `Class.prototype` when `Class` is given.
 *
 * @param {Function} Parent - Constructor to invoke.
 * @param {ArrayLike} args - Constructor arguments.
 * @param {Function} [Class] - Class whose prototype the result should use.
 * @returns {Object} The constructed instance.
 */
function _construct(Parent, args, Class) {
  if (!_isNativeReflectConstruct()) {
    _construct = function _construct(Parent, args, Class) {
      // Leading null is the (ignored) thisArg for Function.prototype.bind.
      var bindArgs = [null];
      Array.prototype.push.apply(bindArgs, args);
      var BoundParent = Function.bind.apply(Parent, bindArgs);
      var instance = new BoundParent();
      if (Class) {
        _setPrototypeOf(instance, Class.prototype);
      }
      return instance;
    };
  } else {
    _construct = Reflect.construct.bind();
  }
  return _construct.apply(null, arguments);
}
/**
 * Reports whether `fn` is an engine-provided (native) function, detected by
 * the "[native code]" marker in its source text.
 *
 * `Function.prototype.toString` throws a TypeError when its receiver is not
 * callable (and some engines throw for host objects). The predicate
 * therefore catches that and falls back to a plain `typeof` check instead of
 * propagating the error. This matches upstream @babel/helpers.
 *
 * @param {*} fn - Value to test.
 * @returns {boolean} True for native functions; false for user-defined
 *     functions and non-functions.
 */
function _isNativeFunction(fn) {
  try {
    return Function.toString.call(fn).indexOf("[native code]") !== -1;
  } catch (e) {
    return typeof fn === "function";
  }
}
/**
 * Returns a subclassable stand-in for a native constructor (one whose source
 * contains "[native code]"). Null and non-native classes are returned
 * unchanged.
 * NOTE(review): presumably used as the superclass when a transpiled class
 * extends a built-in (e.g. Error) — call sites not visible here.
 *
 * On the first call this function replaces itself with a closure over one
 * shared Map cache (when Map exists), so each native class is wrapped at
 * most once.
 *
 * @param {Function|null} Class - Candidate superclass.
 * @returns {Function|null} The cached/new Wrapper, or `Class` itself.
 * @throws {TypeError} "Super expression must either be null or a function".
 */
function _wrapNativeSuper(Class) {
  var _cache = typeof Map === "function" ? new Map() : undefined;
  _wrapNativeSuper = function _wrapNativeSuper(Class) {
    if (Class === null || !_isNativeFunction(Class)) return Class;
    // NOTE(review): for a non-function, _isNativeFunction either throws or
    // reports false before reaching here, so this branch may be unreachable.
    if (typeof Class !== "function") {
      throw new TypeError("Super expression must either be null or a function");
    }
    if (typeof _cache !== "undefined") {
      if (_cache.has(Class)) return _cache.get(Class);
      // Wrapper is a hoisted function declaration, so it can be cached here
      // before its prototype is configured below.
      _cache.set(Class, Wrapper);
    }
    function Wrapper() {
      // Build via the native class, but use the constructor of the object's
      // actual prototype as new.target so the derived prototype is kept.
      return _construct(Class, arguments, _getPrototypeOf(this).constructor);
    }
    // Wrapper.prototype inherits from Class.prototype; `constructor` is
    // non-enumerable like a native class prototype's.
    Wrapper.prototype = Object.create(Class.prototype, {
      constructor: {
        value: Wrapper,
        enumerable: false,
        writable: true,
        configurable: true
      }
    });
    // Static members of Class are reachable through Wrapper.
    return _setPrototypeOf(Wrapper, Class);
  };
  return _wrapNativeSuper(Class);
}
/**
 * `left instanceof right`, honoring a custom `Symbol.hasInstance` on `right`
 * when symbols are available. The custom check's result is coerced to a
 * boolean.
 *
 * @param {*} left - Value being tested.
 * @param {*} right - Constructor or object with `Symbol.hasInstance`.
 * @returns {boolean}
 */
function _instanceof(left, right) {
  var hasCustomCheck = right != null && typeof Symbol !== "undefined" && right[Symbol.hasInstance];
  return hasCustomCheck ? !!right[Symbol.hasInstance](left) : left instanceof right;
}
/**
 * Normalizes a required module for default-import access. Objects flagged
 * `__esModule` are returned as-is; anything else is wrapped as
 * `{ default: obj }`.
 *
 * @param {*} obj - Module value.
 * @returns {Object} Object with a usable `default` member.
 */
function _interopRequireDefault(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  return { default: obj };
}
/**
 * Returns the WeakMap cache used by `_interopRequireWildcard`, or null when
 * WeakMap is unavailable. Two separate caches exist: one for node interop
 * mode and one for Babel mode. The first successful call creates them and
 * replaces this function with a selector over them.
 *
 * @param {boolean} nodeInterop - Selects the node-interop cache.
 * @returns {WeakMap|null}
 */
function _getRequireWildcardCache(nodeInterop) {
  if (typeof WeakMap !== "function") return null;
  var babelCache = new WeakMap();
  var nodeCache = new WeakMap();
  _getRequireWildcardCache = function (nodeInterop) {
    return nodeInterop ? nodeCache : babelCache;
  };
  return _getRequireWildcardCache(nodeInterop);
}
/**
 * Builds a namespace object for `import * as ns` from a CommonJS value.
 *  - ES-module-flagged objects are returned directly (unless `nodeInterop`).
 *  - Primitives and null become `{ default: obj }`.
 *  - Otherwise own keys except "default" are copied. Accessor properties are
 *    copied as accessors, data properties by value. `default` is set to
 *    `obj` and the result is cached per object.
 *
 * @param {*} obj - Module value.
 * @param {boolean} [nodeInterop] - Use Node-style interop semantics.
 * @returns {Object} Namespace object.
 */
function _interopRequireWildcard(obj, nodeInterop) {
  if (!nodeInterop && obj && obj.__esModule) {
    return obj;
  }
  var isObjectLike = obj !== null && (typeof obj === "object" || typeof obj === "function");
  if (!isObjectLike) {
    return { default: obj };
  }
  var cache = _getRequireWildcardCache(nodeInterop);
  if (cache && cache.has(obj)) {
    return cache.get(obj);
  }
  var namespace = {};
  var canReadDescriptors = Object.defineProperty && Object.getOwnPropertyDescriptor;
  for (var key in obj) {
    if (key === "default" || !Object.prototype.hasOwnProperty.call(obj, key)) {
      continue;
    }
    var desc = canReadDescriptors ? Object.getOwnPropertyDescriptor(obj, key) : null;
    var isAccessor = desc && (desc.get || desc.set);
    if (isAccessor) {
      Object.defineProperty(namespace, key, desc);
    } else {
      namespace[key] = obj[key];
    }
  }
  namespace.default = obj;
  if (cache) {
    cache.set(obj, namespace);
  }
  return namespace;
}
/**
 * Detects an attempt to use a down-leveled arrow function as a constructor.
 * The `this` seen inside the function must equal the lexically bound one.
 *
 * @param {*} innerThis - `this` inside the transpiled arrow.
 * @param {*} boundThis - The captured lexical `this`.
 * @throws {TypeError} When the two differ.
 */
function _newArrowCheck(innerThis, boundThis) {
  if (innerThis === boundThis) {
    return;
  }
  throw new TypeError("Cannot instantiate an arrow function");
}
/**
 * Mirrors the TypeError that `const {} = obj` raises when `obj` is null or
 * undefined.
 *
 * @param {*} obj - Value being destructured.
 * @throws {TypeError} If `obj` is nullish.
 */
function _objectDestructuringEmpty(obj) {
  var isNullish = obj == null;
  if (!isNullish) {
    return;
  }
  throw new TypeError("Cannot destructure " + obj);
}
/**
 * Object-rest helper (string keys only): copies the own enumerable string
 * keys of `source` that are not listed in `excluded` into a new object.
 *
 * @param {Object|null|undefined} source - Object being destructured.
 * @param {Array} excluded - Keys already taken by the destructuring pattern.
 * @returns {Object} New object holding the remaining properties.
 */
function _objectWithoutPropertiesLoose(source, excluded) {
  if (source == null) return {};
  var rest = {};
  Object.keys(source).forEach(function (key) {
    if (excluded.indexOf(key) < 0) {
      rest[key] = source[key];
    }
  });
  return rest;
}
/**
 * Full object-rest helper: the loose string-key copy plus own enumerable
 * symbol keys not listed in `excluded`.
 *
 * @param {Object|null|undefined} source - Object being destructured.
 * @param {Array} excluded - Keys/symbols taken by the pattern.
 * @returns {Object} New object holding the remaining properties.
 */
function _objectWithoutProperties(source, excluded) {
  if (source == null) return {};
  var rest = _objectWithoutPropertiesLoose(source, excluded);
  if (!Object.getOwnPropertySymbols) {
    return rest;
  }
  var symbols = Object.getOwnPropertySymbols(source);
  for (var idx = 0; idx < symbols.length; idx++) {
    var sym = symbols[idx];
    var keep = excluded.indexOf(sym) < 0 && Object.prototype.propertyIsEnumerable.call(source, sym);
    if (keep) {
      rest[sym] = source[sym];
    }
  }
  return rest;
}
/**
 * Ensures a derived constructor's `this` exists, i.e. `super()` already ran.
 *
 * @param {*} self - The derived instance (undefined before `super()`).
 * @returns {*} `self`.
 * @throws {ReferenceError} When `self` is undefined.
 */
function _assertThisInitialized(self) {
  if (self !== void 0) {
    return self;
  }
  throw new ReferenceError("this hasn't been initialised - super() hasn't been called");
}
/**
 * Picks what a derived constructor returns. An object or function returned
 * by the super call wins. `undefined` falls back to the initialized `this`.
 * Any other value is an error.
 *
 * @param {*} self - The derived `this`.
 * @param {*} call - Result of the super-constructor invocation.
 * @returns {Object} The value to return from the constructor.
 * @throws {TypeError} For non-object, non-undefined results.
 */
function _possibleConstructorReturn(self, call) {
  var isObjectLike = call && (typeof call === "object" || typeof call === "function");
  if (isObjectLike) {
    return call;
  }
  if (call !== void 0) {
    throw new TypeError("Derived constructors may only return object or undefined");
  }
  return _assertThisInitialized(self);
}
/**
 * Builds the `super(...)` caller for a transpiled derived class. When native
 * `Reflect.construct` works, the parent is constructed with new.target set
 * to the current instance's constructor. Otherwise the parent is applied
 * directly to `this`.
 *
 * @param {Function} Derived - The derived class.
 * @returns {Function} Function to invoke with the derived `this` and args.
 */
function _createSuper(Derived) {
  var canUseReflect = _isNativeReflectConstruct();
  return function _createSuperInternal() {
    var Super = _getPrototypeOf(Derived);
    var result = canUseReflect
      ? Reflect.construct(Super, arguments, _getPrototypeOf(this).constructor)
      : Super.apply(this, arguments);
    return _possibleConstructorReturn(this, result);
  };
}
/**
 * Walks up the prototype chain from `object` to the first object that owns
 * `property`.
 *
 * @param {Object} object - Starting object.
 * @param {PropertyKey} property - Property to locate.
 * @returns {Object|null} The owner, or null when the chain ends first.
 */
function _superPropBase(object, property) {
  var current = object;
  while (!Object.prototype.hasOwnProperty.call(current, property)) {
    current = _getPrototypeOf(current);
    if (current === null) {
      break;
    }
  }
  return current;
}
/**
 * `Reflect.get` polyfill used for `super.prop` reads. On the first call it
 * replaces itself with the bound native `Reflect.get` when available.
 * Otherwise it uses a fallback that finds the owning prototype and invokes a
 * getter with `receiver`. If only two arguments were given, `target` is the
 * receiver instead.
 *
 * @returns {*} The property value, or undefined when not found.
 */
function _get() {
  var hasReflectGet = typeof Reflect !== "undefined" && Reflect.get;
  _get = hasReflectGet ? Reflect.get.bind() : function _get(target, property, receiver) {
    var base = _superPropBase(target, property);
    if (!base) return;
    var desc = Object.getOwnPropertyDescriptor(base, property);
    if (!desc.get) {
      return desc.value;
    }
    var thisArg = arguments.length < 3 ? target : receiver;
    return desc.get.call(thisArg);
  };
  return _get.apply(this, arguments);
}
/**
 * `Reflect.set` polyfill used for `super.prop = value`. On the first call it
 * replaces itself with native `Reflect.set` when available. The fallback
 * behaves as follows:
 *  - it prefers an inherited setter on the owning prototype;
 *  - it refuses non-writable inherited or own data properties;
 *  - otherwise it updates the receiver's own property or creates one.
 *
 * @returns {boolean} Whether the assignment succeeded.
 */
function set$4(target, property, value, receiver) {
  if (typeof Reflect === "undefined" || !Reflect.set) {
    set$4 = function set(target, property, value, receiver) {
      var base = _superPropBase(target, property);
      if (base) {
        var baseDesc = Object.getOwnPropertyDescriptor(base, property);
        if (baseDesc.set) {
          baseDesc.set.call(receiver, value);
          return true;
        }
        if (!baseDesc.writable) {
          return false;
        }
      }
      var ownDesc = Object.getOwnPropertyDescriptor(receiver, property);
      if (!ownDesc) {
        _defineProperty(receiver, property, value);
        return true;
      }
      if (!ownDesc.writable) {
        return false;
      }
      ownDesc.value = value;
      Object.defineProperty(receiver, property, ownDesc);
      return true;
    };
  } else {
    set$4 = Reflect.set;
  }
  return set$4(target, property, value, receiver);
}
/**
 * Performs a `super` property assignment. The receiver defaults to `target`.
 * In strict mode a failed assignment throws.
 *
 * @returns {*} The assigned `value` (assignment-expression result).
 * @throws {TypeError} 'failed to set property' when strict and rejected.
 */
function _set(target, property, value, receiver, isStrict) {
  var succeeded = set$4(target, property, value, receiver || target);
  if (isStrict && !succeeded) {
    throw new TypeError('failed to set property');
  }
  return value;
}
/**
 * Prepares a tagged-template strings array. It attaches a frozen `raw`
 * array as a non-enumerable property (a copy of `strings` when `raw` is
 * omitted) and freezes `strings`.
 *
 * @param {Array<string>} strings - Cooked strings (mutated then frozen).
 * @param {Array<string>} [raw] - Raw strings.
 * @returns {Array<string>} `strings`.
 */
function _taggedTemplateLiteral(strings, raw) {
  var rawStrings = raw ? raw : strings.slice(0);
  Object.defineProperties(strings, {
    raw: {
      value: Object.freeze(rawStrings)
    }
  });
  return Object.freeze(strings);
}
/**
 * Loose tagged-template preparation: assigns `raw` (or a copy of `strings`
 * when `raw` is falsy) as a plain property, without freezing.
 *
 * @param {Array<string>} strings - Cooked strings (mutated).
 * @param {Array<string>} [raw] - Raw strings.
 * @returns {Array<string>} `strings`.
 */
function _taggedTemplateLiteralLoose(strings, raw) {
  strings.raw = raw || strings.slice(0);
  return strings;
}
/**
 * Throws the error for assigning to a read-only binding (e.g. a `const`).
 *
 * @param {string} name - Binding name.
 * @throws {TypeError} Always.
 */
function _readOnlyError(name) {
  var message = "\"" + name + "\" is read-only";
  throw new TypeError(message);
}
/**
 * Throws the error for reading a write-only binding.
 *
 * @param {string} name - Binding name.
 * @throws {TypeError} Always.
 */
function _writeOnlyError(name) {
  var message = "\"" + name + "\" is write-only";
  throw new TypeError(message);
}
/**
 * Throws when a class references its own name inside a computed key, before
 * the class binding is initialized.
 *
 * @param {string} name - Class name.
 * @throws {ReferenceError} Always.
 */
function _classNameTDZError(name) {
  var message = "Class \"" + name + "\" cannot be referenced in computed property keys.";
  throw new ReferenceError(message);
}
/**
 * Sentinel function. Its identity marks a binding still in its temporal
 * dead zone (compared by `_temporalRef`). Calling it does nothing.
 */
function _temporalUndefined() {
  return void 0;
}
/**
 * Throws the temporal-dead-zone access error for `name`.
 *
 * @param {string} name - Binding name.
 * @throws {ReferenceError} Always.
 */
function _tdz(name) {
  var message = name + " is not defined - temporal dead zone";
  throw new ReferenceError(message);
}
/**
 * Reads a TDZ-tracked binding. If it still holds the `_temporalUndefined`
 * sentinel, a TDZ error is thrown via `_tdz`; otherwise the value is
 * returned.
 *
 * @param {*} val - Current binding value.
 * @param {string} name - Binding name for the error message.
 * @returns {*} `val`.
 */
function _temporalRef(val, name) {
  if (val === _temporalUndefined) {
    return _tdz(name);
  }
  return val;
}
/**
 * Converts `arr` into an array for `[a, b] = arr` destructuring. Strategies
 * are tried in order: the array itself, an iterable (via
 * `_iterableToArrayLimit`, defined elsewhere — presumably limited to `i`
 * items), then known array-likes. If none applies, a TypeError is thrown.
 *
 * @param {*} arr - Value being destructured.
 * @param {number} i - Number of elements the pattern needs.
 * @returns {Array}
 */
function _slicedToArray(arr, i) {
  var result = _arrayWithHoles(arr);
  if (result) return result;
  result = _iterableToArrayLimit(arr, i);
  if (result) return result;
  result = _unsupportedIterableToArray(arr, i);
  if (result) return result;
  return _nonIterableRest();
}
/**
 * Loose variant of `_slicedToArray`. It uses `_iterableToArrayLimitLoose`
 * (defined elsewhere) for the iterable step.
 *
 * @param {*} arr - Value being destructured.
 * @param {number} i - Number of elements the pattern needs.
 * @returns {Array}
 */
function _slicedToArrayLoose(arr, i) {
  var result = _arrayWithHoles(arr);
  if (result) return result;
  result = _iterableToArrayLimitLoose(arr, i);
  if (result) return result;
  result = _unsupportedIterableToArray(arr, i);
  if (result) return result;
  return _nonIterableRest();
}
/**
 * Converts `arr` to an array for rest-element destructuring
 * (`[first, ...rest] = arr`). Arrays are returned as-is; iterables and known
 * array-likes are copied. Anything else throws.
 *
 * @param {*} arr - Value being destructured.
 * @returns {Array}
 */
function _toArray(arr) {
  var result = _arrayWithHoles(arr);
  if (result) return result;
  result = _iterableToArray(arr);
  if (result) return result;
  result = _unsupportedIterableToArray(arr);
  if (result) return result;
  return _nonIterableRest();
}
/**
 * Converts `arr` to a fresh array for spread (`[...arr]`). Arrays are copied
 * (holes become undefined), iterables are drained, and known array-likes are
 * copied. Anything else throws the non-iterable-spread error.
 *
 * @param {*} arr - Value being spread.
 * @returns {Array}
 */
function _toConsumableArray(arr) {
  var result = _arrayWithoutHoles(arr);
  if (result) return result;
  result = _iterableToArray(arr);
  if (result) return result;
  result = _unsupportedIterableToArray(arr);
  if (result) return result;
  return _nonIterableSpread();
}
/**
 * For real arrays, returns a dense copy (holes read as undefined).
 * Returns undefined for anything else.
 *
 * @param {*} arr - Candidate array.
 * @returns {Array|undefined}
 */
function _arrayWithoutHoles(arr) {
  if (!Array.isArray(arr)) {
    return;
  }
  return _arrayLikeToArray(arr);
}
/**
 * Returns `arr` unchanged if it is a real array, otherwise undefined.
 *
 * @param {*} arr - Candidate array.
 * @returns {Array|undefined}
 */
function _arrayWithHoles(arr) {
  return Array.isArray(arr) ? arr : undefined;
}
/**
 * Lets array-likes (truthy non-arrays with a numeric `length`) be
 * destructured by copying up to `i` items (or all items when `i` is absent
 * or too large). Every other value is delegated to `next(arr, i)`.
 *
 * @param {Function} next - Fallback conversion helper.
 * @param {*} arr - Value being destructured.
 * @param {number} [i] - Maximum number of items needed.
 * @returns {Array}
 */
function _maybeArrayLike(next, arr, i) {
  var isArrayLike = arr && !Array.isArray(arr) && typeof arr.length === "number";
  if (!isArrayLike) {
    return next(arr, i);
  }
  var len = arr.length;
  var limit = i !== void 0 && i < len ? i : len;
  return _arrayLikeToArray(arr, limit);
}
/**
 * Copies `iter` into an array with `Array.from` when it exposes
 * `Symbol.iterator` or the legacy `"@@iterator"` key. Otherwise returns
 * undefined.
 *
 * @param {*} iter - Candidate iterable.
 * @returns {Array|undefined}
 */
function _iterableToArray(iter) {
  var hasSymbolIterator = typeof Symbol !== "undefined" && iter[Symbol.iterator] != null;
  if (hasSymbolIterator || iter["@@iterator"] != null) {
    return Array.from(iter);
  }
}
/**
 * Fallback conversion for values that may lack `Symbol.iterator` in older
 * engines. Supported inputs are strings, Map/Set, `arguments` objects and
 * integer typed arrays, classified by their `Object.prototype.toString` tag
 * (or constructor name for plain-tagged objects). Returns undefined for
 * anything else.
 *
 * @param {*} o - Value to convert.
 * @param {number} [minLen] - Length limit for array-like copies.
 * @returns {Array|undefined}
 */
function _unsupportedIterableToArray(o, minLen) {
  if (!o) return;
  if (typeof o === "string") return _arrayLikeToArray(o, minLen);
  var tag = Object.prototype.toString.call(o).slice(8, -1);
  if (tag === "Object" && o.constructor) {
    tag = o.constructor.name;
  }
  if (tag === "Map" || tag === "Set") {
    return Array.from(o);
  }
  var isIndexable = tag === "Arguments" || /^(?:Ui|I)nt(?:8|16|32)(?:Clamped)?Array$/.test(tag);
  if (isIndexable) {
    return _arrayLikeToArray(o, minLen);
  }
}
/**
 * Copies the first `len` indexed items of an array-like into a new dense
 * array. `len` is clamped to `arr.length` and defaults to it when nullish.
 *
 * @param {ArrayLike} arr - Source array-like.
 * @param {number} [len] - Maximum number of items to copy.
 * @returns {Array}
 */
function _arrayLikeToArray(arr, len) {
  var size = len == null || len > arr.length ? arr.length : len;
  var copy = new Array(size);
  for (var idx = 0; idx < size; idx++) {
    copy[idx] = arr[idx];
  }
  return copy;
}
/**
 * Throws the error raised when spreading a value that is not iterable.
 *
 * @throws {TypeError} Always.
 */
function _nonIterableSpread() {
  var message = "Invalid attempt to spread non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.";
  throw new TypeError(message);
}
/**
 * Throws the error raised when destructuring a value that is not iterable.
 *
 * @throws {TypeError} Always.
 */
function _nonIterableRest() {
  var message = "Invalid attempt to destructure non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.";
  throw new TypeError(message);
}
/**
 * Creates an iteration helper with four methods: `s` (start), `n` (next
 * step, returns `{done, value}`), `e` (record an error) and `f` (finish).
 * NOTE(review): presumably driven by transpiled `for...of` loops — call
 * sites not visible here.
 *
 * Uses `o[Symbol.iterator]` (or the legacy `"@@iterator"` key) when present.
 * Otherwise it falls back to index-based iteration over arrays, over
 * whatever `_unsupportedIterableToArray` can convert, or (when
 * `allowArrayLike` is truthy) over any object with a numeric `length`.
 *
 * @param {*} o - Value to iterate.
 * @param {boolean} [allowArrayLike] - Accept plain array-likes.
 * @returns {{s: Function, n: Function, e: Function, f: Function}}
 * @throws {TypeError} If `o` can be iterated in none of these ways.
 */
function _createForOfIteratorHelper(o, allowArrayLike) {
  var it = typeof Symbol !== "undefined" && o[Symbol.iterator] || o["@@iterator"];
  if (!it) {
    // `it` is reused here to hold the converted array, if any.
    if (Array.isArray(o) || (it = _unsupportedIterableToArray(o)) || allowArrayLike && o && typeof o.length === "number") {
      if (it) o = it;
      var i = 0;
      var F = function () {};
      // Index-based helper: s/f are no-ops and e simply rethrows.
      return {
        s: F,
        n: function () {
          if (i >= o.length) return {
            done: true
          };
          return {
            done: false,
            value: o[i++]
          };
        },
        e: function (e) {
          throw e;
        },
        f: F
      };
    }
    throw new TypeError("Invalid attempt to iterate non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.");
  }
  // normalCompletion: whether the last step reported done.
  // didErr/err: error recorded via e(), rethrown by f().
  var normalCompletion = true,
    didErr = false,
    err;
  return {
    s: function () {
      // Replace the iterator method with the iterator object it returns.
      it = it.call(o);
    },
    n: function () {
      var step = it.next();
      normalCompletion = step.done;
      return step;
    },
    e: function (e) {
      didErr = true;
      err = e;
    },
    f: function () {
      try {
        // Left before a done step was seen: close the iterator via return().
        if (!normalCompletion && it.return != null) it.return();
      } finally {
        // Rethrow any recorded error after the cleanup attempt.
        if (didErr) throw err;
      }
    }
  };
}
/**
 * Loose iteration helper: returns a bare `next()` function, with no
 * return()/error handling.
 *  - For iterables, it is the iterator's `next`, bound to the iterator.
 *  - Otherwise it indexes through an array, a value `_unsupportedIterableToArray`
 *    can convert, or (with `allowArrayLike`) any object with numeric `length`.
 *
 * @param {*} o - Value to iterate.
 * @param {boolean} [allowArrayLike] - Accept plain array-likes.
 * @returns {Function} Function producing `{done, value}` steps.
 * @throws {TypeError} If `o` cannot be iterated.
 */
function _createForOfIteratorHelperLoose(o, allowArrayLike) {
  var iteratorMethod = typeof Symbol !== "undefined" && o[Symbol.iterator] || o["@@iterator"];
  if (iteratorMethod) {
    var iterator = iteratorMethod.call(o);
    return iterator.next.bind(iterator);
  }
  var converted;
  var canIndex = Array.isArray(o) || (converted = _unsupportedIterableToArray(o)) || allowArrayLike && o && typeof o.length === "number";
  if (!canIndex) {
    throw new TypeError("Invalid attempt to iterate non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.");
  }
  var list = converted ? converted : o;
  var cursor = 0;
  return function () {
    if (cursor >= list.length) {
      return { done: true };
    }
    return { done: false, value: list[cursor++] };
  };
}
/**
 * Wraps a generator function so the returned generator has already been
 * advanced once. The first external `next(x)` then delivers `x` to the
 * first `yield`.
 *
 * @param {Function} fn - Generator function.
 * @returns {Function} Wrapper forwarding `this` and arguments.
 */
function _skipFirstGeneratorNext(fn) {
  return function () {
    var primed = fn.apply(this, arguments);
    primed.next();
    return primed;
  };
}
/**
 * Simplified ToPrimitive:
 *  - non-objects (and null) are returned unchanged;
 *  - `input[Symbol.toPrimitive]` is called with `hint` (default "default")
 *    when present;
 *  - otherwise the value goes through `String()` for the "string" hint and
 *    `Number()` for any other hint.
 *
 * @param {*} input - Value to convert.
 * @param {string} [hint] - "string", "number" or "default".
 * @returns {*} A primitive.
 * @throws {TypeError} If `Symbol.toPrimitive` returns an object.
 */
function _toPrimitive(input, hint) {
  var isObject = typeof input === "object" && input !== null;
  if (!isObject) return input;
  var exoticToPrim = input[Symbol.toPrimitive];
  if (exoticToPrim === undefined) {
    return hint === "string" ? String(input) : Number(input);
  }
  var result = exoticToPrim.call(input, hint || "default");
  if (typeof result !== "object") return result;
  throw new TypeError("@@toPrimitive must return a primitive value.");
}
/**
 * Converts `arg` to a property key: symbols are kept and everything else
 * becomes a string (using the "string" ToPrimitive hint).
 *
 * @param {*} arg - Raw key.
 * @returns {string|symbol}
 */
function _toPropertyKey(arg) {
  var key = _toPrimitive(arg, "string");
  if (typeof key === "symbol") {
    return key;
  }
  return String(key);
}
/**
 * Always throws. It signals that a decorated class property was reached
 * without the class-properties transform having run after the decorators
 * transform. Both arguments are ignored.
 *
 * @throws {Error} Always.
 */
function _initializerWarningHelper(descriptor, context) {
  var message = 'Decorating class property failed. Please ensure that ' + 'proposal-class-properties is enabled and runs after the decorators transform.';
  throw new Error(message);
}
/**
 * Defines a decorated class field on `target`. The field takes the
 * descriptor's flags, and its value is produced by calling
 * `descriptor.initializer` with `context` as `this` (or undefined when there
 * is no initializer). A missing descriptor is a no-op.
 *
 * @param {Object} target - Receiver of the field.
 * @param {PropertyKey} property - Field name.
 * @param {Object} [descriptor] - Decorated descriptor.
 * @param {*} context - `this` for the initializer.
 */
function _initializerDefineProperty(target, property, descriptor, context) {
  if (!descriptor) {
    return;
  }
  var attributes = {
    enumerable: descriptor.enumerable,
    configurable: descriptor.configurable,
    writable: descriptor.writable,
    value: descriptor.initializer ? descriptor.initializer.call(context) : void 0
  };
  Object.defineProperty(target, property, attributes);
}
/**
 * Applies legacy member decorators to a copy of `descriptor`. Decorators run
 * last-to-first; a decorator returning a falsy value keeps the previous
 * descriptor.
 *
 * If `context` is given, any initializer is evaluated into `value`. When no
 * initializer remains, the result is defined on `target` and null is
 * returned. Otherwise the descriptor is returned so the caller can install
 * the field later.
 *
 * @param {Object} target - Prototype or constructor.
 * @param {PropertyKey} property - Member name.
 * @param {Array<Function>} decorators - Decorators in source order.
 * @param {Object} descriptor - Original descriptor (not mutated).
 * @param {*} [context] - `this` for the initializer.
 * @returns {Object|null}
 */
function _applyDecoratedDescriptor(target, property, decorators, descriptor, context) {
  var desc = {};
  Object.keys(descriptor).forEach(function (name) {
    desc[name] = descriptor[name];
  });
  desc.enumerable = !!desc.enumerable;
  desc.configurable = !!desc.configurable;
  if ('value' in desc || desc.initializer) {
    desc.writable = true;
  }
  desc = decorators.slice().reverse().reduce(function (current, decorator) {
    return decorator(target, property, current) || current;
  }, desc);
  if (context && desc.initializer !== void 0) {
    desc.value = desc.initializer ? desc.initializer.call(context) : void 0;
    desc.initializer = undefined;
  }
  if (desc.initializer !== void 0) {
    return desc;
  }
  Object.defineProperty(target, property, desc);
  return null;
}
// Counter making each loose private-field key unique within this bundle.
var id$3 = 0;
/**
 * Produces a unique string key `"__private_<n>_<name>"` for a loose-mode
 * private field, incrementing the shared counter.
 *
 * @param {string} name - Field name (without `#`).
 * @returns {string}
 */
function _classPrivateFieldLooseKey(name) {
  var uid = id$3;
  id$3 += 1;
  return "__private_" + uid + "_" + name;
}
/**
 * Brand check for loose private fields: `receiver` must own `privateKey`.
 *
 * @param {Object} receiver - Object being accessed.
 * @param {string} privateKey - Key from `_classPrivateFieldLooseKey`.
 * @returns {Object} `receiver`.
 * @throws {TypeError} If the key is not an own property.
 */
function _classPrivateFieldLooseBase(receiver, privateKey) {
  var isBranded = Object.prototype.hasOwnProperty.call(receiver, privateKey);
  if (isBranded) {
    return receiver;
  }
  throw new TypeError("attempted to use private field on non-instance");
}
/**
 * Reads a spec-mode private field: finds `receiver`'s descriptor in
 * `privateMap`, then applies its getter or returns its stored value.
 *
 * @param {Object} receiver - Instance being read.
 * @param {WeakMap} privateMap - Per-field map from instance to descriptor.
 * @returns {*}
 */
function _classPrivateFieldGet(receiver, privateMap) {
  return _classApplyDescriptorGet(receiver, _classExtractFieldDescriptor(receiver, privateMap, "get"));
}
/**
 * Writes a spec-mode private field via its descriptor.
 *
 * @param {Object} receiver - Instance being written.
 * @param {WeakMap} privateMap - Per-field map from instance to descriptor.
 * @param {*} value - Value to store.
 * @returns {*} `value` (assignment-expression result).
 */
function _classPrivateFieldSet(receiver, privateMap, value) {
  _classApplyDescriptorSet(receiver, _classExtractFieldDescriptor(receiver, privateMap, "set"), value);
  return value;
}
/**
 * Returns an object whose `value` property can be used as a destructuring
 * assignment target for a spec-mode private field.
 *
 * @param {Object} receiver - Instance being written.
 * @param {WeakMap} privateMap - Per-field map from instance to descriptor.
 * @returns {Object} Target object with an assignable `value`.
 */
function _classPrivateFieldDestructureSet(receiver, privateMap) {
  var fieldDescriptor = _classExtractFieldDescriptor(receiver, privateMap, "set");
  var assignTarget = _classApplyDescriptorDestructureSet(receiver, fieldDescriptor);
  return assignTarget;
}
/**
 * Looks up `receiver`'s private-field descriptor, enforcing the brand check.
 *
 * @param {Object} receiver - Instance being accessed.
 * @param {WeakMap} privateMap - Per-field map from instance to descriptor.
 * @param {string} action - "get" or "set", used in the error message.
 * @returns {Object} The descriptor.
 * @throws {TypeError} If `receiver` is not in `privateMap`.
 */
function _classExtractFieldDescriptor(receiver, privateMap, action) {
  if (privateMap.has(receiver)) {
    return privateMap.get(receiver);
  }
  throw new TypeError("attempted to " + action + " private field on non-instance");
}
/**
 * Reads a static private field. It checks that `receiver` is the declaring
 * class and that the field is initialized, then reads through the
 * descriptor.
 *
 * @param {Function} receiver - Object the access was made on.
 * @param {Function} classConstructor - Declaring class.
 * @param {Object|undefined} descriptor - Field descriptor.
 * @returns {*}
 */
function _classStaticPrivateFieldSpecGet(receiver, classConstructor, descriptor) {
  _classCheckPrivateStaticAccess(receiver, classConstructor);
  _classCheckPrivateStaticFieldDescriptor(descriptor, "get");
  var fieldValue = _classApplyDescriptorGet(receiver, descriptor);
  return fieldValue;
}
/**
 * Writes a static private field after the provenance and initialization
 * checks.
 *
 * @param {Function} receiver - Object the access was made on.
 * @param {Function} classConstructor - Declaring class.
 * @param {Object|undefined} descriptor - Field descriptor.
 * @param {*} value - Value to store.
 * @returns {*} `value`.
 */
function _classStaticPrivateFieldSpecSet(receiver, classConstructor, descriptor, value) {
  _classCheckPrivateStaticAccess(receiver, classConstructor);
  _classCheckPrivateStaticFieldDescriptor(descriptor, "set");
  _classApplyDescriptorSet(receiver, descriptor, value);
  var assigned = value;
  return assigned;
}
/**
 * Accesses a static private method after checking that `receiver` is the
 * declaring class.
 *
 * @param {Function} receiver - Object the access was made on.
 * @param {Function} classConstructor - Declaring class.
 * @param {Function} method - The private method.
 * @returns {Function} `method`.
 */
function _classStaticPrivateMethodGet(receiver, classConstructor, method) {
  _classCheckPrivateStaticAccess(receiver, classConstructor);
  var privateMethod = method;
  return privateMethod;
}
/**
 * Throws on any attempt to assign to a static private method.
 *
 * @throws {TypeError} Always.
 */
function _classStaticPrivateMethodSet() {
  var message = "attempted to set read only static private field";
  throw new TypeError(message);
}
/**
 * Reads through a private-field descriptor: calls its getter with
 * `receiver`, or returns the stored `value`.
 *
 * @param {Object} receiver - `this` for the getter.
 * @param {Object} descriptor - Field descriptor.
 * @returns {*}
 */
function _classApplyDescriptorGet(receiver, descriptor) {
  return descriptor.get ? descriptor.get.call(receiver) : descriptor.value;
}
/**
 * Writes through a private-field descriptor. It calls the setter with
 * `receiver` if one exists; otherwise it stores into `descriptor.value`,
 * but only when the descriptor is writable.
 *
 * @param {Object} receiver - `this` for the setter.
 * @param {Object} descriptor - Field descriptor (mutated for data fields).
 * @param {*} value - Value to store.
 * @throws {TypeError} For read-only data fields.
 */
function _classApplyDescriptorSet(receiver, descriptor, value) {
  if (descriptor.set) {
    descriptor.set.call(receiver, value);
    return;
  }
  if (!descriptor.writable) {
    throw new TypeError("attempted to set read only private field");
  }
  descriptor.value = value;
}
/**
 * Returns something with an assignable `value` for destructuring into a
 * private field.
 *  - Accessor fields get a cached `__destrObj` whose `value` setter forwards
 *    to `descriptor.set`. The cached object keeps the `receiver` from the
 *    call that created it.
 *  - Writable data fields return the descriptor itself.
 *
 * @param {Object} receiver - Instance being written.
 * @param {Object} descriptor - Field descriptor.
 * @returns {Object}
 * @throws {TypeError} For read-only data fields.
 */
function _classApplyDescriptorDestructureSet(receiver, descriptor) {
  if (!descriptor.set) {
    if (!descriptor.writable) {
      throw new TypeError("attempted to set read only private field");
    }
    return descriptor;
  }
  if (!("__destrObj" in descriptor)) {
    descriptor.__destrObj = {
      set value(v) {
        descriptor.set.call(receiver, v);
      }
    };
  }
  return descriptor.__destrObj;
}
/**
 * Destructuring-assignment target for a static private field, after the
 * provenance and initialization checks.
 *
 * @param {Function} receiver - Object the access was made on.
 * @param {Function} classConstructor - Declaring class.
 * @param {Object|undefined} descriptor - Field descriptor.
 * @returns {Object} Target with an assignable `value`.
 */
function _classStaticPrivateFieldDestructureSet(receiver, classConstructor, descriptor) {
  _classCheckPrivateStaticAccess(receiver, classConstructor);
  _classCheckPrivateStaticFieldDescriptor(descriptor, "set");
  var assignTarget = _classApplyDescriptorDestructureSet(receiver, descriptor);
  return assignTarget;
}
/**
 * Ensures a static private member is accessed on its declaring class itself.
 *
 * @param {*} receiver - Object the access was made on.
 * @param {Function} classConstructor - Declaring class.
 * @throws {TypeError} When they differ.
 */
function _classCheckPrivateStaticAccess(receiver, classConstructor) {
  if (receiver === classConstructor) {
    return;
  }
  throw new TypeError("Private static access of wrong provenance");
}
/**
 * Ensures a static private field has been declared (descriptor present)
 * before it is read or written.
 *
 * @param {Object|undefined} descriptor - Field descriptor.
 * @param {string} action - "get" or "set", used in the error message.
 * @throws {TypeError} When `descriptor` is undefined.
 */
function _classCheckPrivateStaticFieldDescriptor(descriptor, action) {
  if (descriptor !== undefined) {
    return;
  }
  throw new TypeError("attempted to " + action + " private static field before its declaration");
}
/**
 * Entry point for class decorators. It proceeds in five steps:
 *  1. Get the decorators API and let each mixin transform it in order.
 *  2. Call `factory(initialize, superClass)` to obtain `{ F, d }` (class and
 *     element definitions). `initialize(O)` installs the "own" elements on an
 *     instance `O`, using the decorated result computed below.
 *  3. Normalize, coalesce and decorate the element definitions.
 *  4. Install static/prototype elements on `F`.
 *  5. Run class finishers and return the (possibly replaced) constructor.
 *
 * @param {Array<Function>} decorators - Class decorators.
 * @param {Function} factory - Builds the class; receives `initialize`.
 * @param {Function} [superClass] - Passed through to `factory`.
 * @param {Array<Function>} [mixins] - API transformers.
 * @returns {Function} Final class constructor.
 */
function _decorate(decorators, factory, superClass, mixins) {
  var api = _getDecoratorsApi();
  for (var m = 0; mixins && m < mixins.length; m++) {
    api = mixins[m](api);
  }
  // Assigned below; `initialize` reads it lazily when instances are built.
  var decorated;
  var definition = factory(function initialize(O) {
    api.initializeInstanceElements(O, decorated.elements);
  }, superClass);
  var elementDescriptors = definition.d.map(_createElementDescriptor);
  decorated = api.decorateClass(_coalesceClassElements(elementDescriptors), decorators);
  api.initializeClassElements(definition.F, decorated.elements);
  return api.runClassFinishers(definition.F, decorated.finishers);
}
/**
 * Returns the runtime API object that `_decorate` uses to place class
 * elements, run element and class decorators, validate the descriptors they
 * return, and run finishers.
 *
 * On the first call this function replaces itself with a thunk returning
 * the same `api` object, so all later calls share one instance.
 *
 * @returns {Object} The decorators API.
 */
function _getDecoratorsApi() {
  // Memoize: subsequent calls return the `api` object built below.
  _getDecoratorsApi = function () {
    return api;
  };
  var api = {
    elementsDefinitionOrder: [["method"], ["field"]],
    // Define every "own"-placement element on instance O: all methods
    // first, then all fields.
    initializeInstanceElements: function (O, elements) {
      ["method", "field"].forEach(function (kind) {
        elements.forEach(function (element) {
          if (element.kind === kind && element.placement === "own") {
            this.defineClassElement(O, element);
          }
        }, this);
      }, this);
    },
    // Define "static" elements on constructor F and "prototype" elements on
    // F.prototype: all methods first, then all fields.
    initializeClassElements: function (F, elements) {
      var proto = F.prototype;
      ["method", "field"].forEach(function (kind) {
        elements.forEach(function (element) {
          var placement = element.placement;
          if (element.kind === kind && (placement === "static" || placement === "prototype")) {
            var receiver = placement === "static" ? F : proto;
            this.defineClassElement(receiver, element);
          }
        }, this);
      }, this);
    },
    // Install one element on receiver. For fields, the initializer (if any)
    // is called with `receiver` as `this` to produce the value.
    defineClassElement: function (receiver, element) {
      var descriptor = element.descriptor;
      if (element.kind === "field") {
        var initializer = element.initializer;
        descriptor = {
          enumerable: descriptor.enumerable,
          writable: descriptor.writable,
          configurable: descriptor.configurable,
          value: initializer === void 0 ? void 0 : initializer.call(receiver)
        };
      }
      Object.defineProperty(receiver, element.key, descriptor);
    },
    // Run element decorators, then (if any) class decorators. Returns the
    // resulting element list and every finisher collected along the way.
    decorateClass: function (elements, decorators) {
      var newElements = [];
      var finishers = [];
      var placements = {
        static: [],
        prototype: [],
        own: []
      };
      elements.forEach(function (element) {
        this.addElementPlacement(element, placements);
      }, this);
      elements.forEach(function (element) {
        if (!_hasDecorators(element)) return newElements.push(element);
        var elementFinishersExtras = this.decorateElement(element, placements);
        newElements.push(elementFinishersExtras.element);
        newElements.push.apply(newElements, elementFinishersExtras.extras);
        finishers.push.apply(finishers, elementFinishersExtras.finishers);
      }, this);
      if (!decorators) {
        return {
          elements: newElements,
          finishers: finishers
        };
      }
      var result = this.decorateConstructor(newElements, decorators);
      finishers.push.apply(finishers, result.finishers);
      result.finishers = finishers;
      return result;
    },
    // Record element.key under its placement. Throws on a duplicate key
    // unless `silent` is truthy.
    addElementPlacement: function (element, placements, silent) {
      var keys = placements[element.placement];
      if (!silent && keys.indexOf(element.key) !== -1) {
        throw new TypeError("Duplicated element (" + element.key + ")");
      }
      keys.push(element.key);
    },
    // Apply an element's decorators last-to-first. Each decorator receives a
    // public descriptor object. Its result (or the input, if it returns a
    // falsy value) is validated back into an internal element and
    // re-registered in `placements`; finishers and extras are collected.
    decorateElement: function (element, placements) {
      var extras = [];
      var finishers = [];
      for (var decorators = element.decorators, i = decorators.length - 1; i >= 0; i--) {
        var keys = placements[element.placement];
        keys.splice(keys.indexOf(element.key), 1);
        var elementObject = this.fromElementDescriptor(element);
        // `(0, fn)(...)` invokes the decorator without a `this` binding.
        var elementFinisherExtras = this.toElementFinisherExtras((0, decorators[i])(elementObject) || elementObject);
        element = elementFinisherExtras.element;
        this.addElementPlacement(element, placements);
        if (elementFinisherExtras.finisher) {
          finishers.push(elementFinisherExtras.finisher);
        }
        var newExtras = elementFinisherExtras.extras;
        if (newExtras) {
          for (var j = 0; j < newExtras.length; j++) {
            this.addElementPlacement(newExtras[j], placements);
          }
          extras.push.apply(extras, newExtras);
        }
      }
      return {
        element: element,
        finishers: finishers,
        extras: extras
      };
    },
    // Apply class decorators last-to-first. A decorator may replace the
    // element list (then checked for duplicate key+placement pairs) and/or
    // contribute a finisher.
    decorateConstructor: function (elements, decorators) {
      var finishers = [];
      for (var i = decorators.length - 1; i >= 0; i--) {
        var obj = this.fromClassDescriptor(elements);
        var elementsAndFinisher = this.toClassDescriptor((0, decorators[i])(obj) || obj);
        if (elementsAndFinisher.finisher !== undefined) {
          finishers.push(elementsAndFinisher.finisher);
        }
        if (elementsAndFinisher.elements !== undefined) {
          elements = elementsAndFinisher.elements;
          for (var j = 0; j < elements.length - 1; j++) {
            for (var k = j + 1; k < elements.length; k++) {
              if (elements[j].key === elements[k].key && elements[j].placement === elements[k].placement) {
                throw new TypeError("Duplicated element (" + elements[j].key + ")");
              }
            }
          }
        }
      }
      return {
        elements: elements,
        finishers: finishers
      };
    },
    // Internal element -> public descriptor object handed to decorators,
    // tagged "Descriptor" via Symbol.toStringTag.
    fromElementDescriptor: function (element) {
      var obj = {
        kind: element.kind,
        key: element.key,
        placement: element.placement,
        descriptor: element.descriptor
      };
      var desc = {
        value: "Descriptor",
        configurable: true
      };
      Object.defineProperty(obj, Symbol.toStringTag, desc);
      if (element.kind === "field") obj.initializer = element.initializer;
      return obj;
    },
    // Validate a list of public element objects (e.g. `extras`). Returns
    // undefined when the input is undefined.
    toElementDescriptors: function (elementObjects) {
      if (elementObjects === undefined) return;
      return _toArray(elementObjects).map(function (elementObject) {
        var element = this.toElementDescriptor(elementObject);
        this.disallowProperty(elementObject, "finisher", "An element descriptor");
        this.disallowProperty(elementObject, "extras", "An element descriptor");
        return element;
      }, this);
    },
    // Public element object -> internal element. Validates kind, placement
    // and field-specific restrictions, and copies the property descriptor.
    toElementDescriptor: function (elementObject) {
      var kind = String(elementObject.kind);
      if (kind !== "method" && kind !== "field") {
        throw new TypeError('An element descriptor\'s .kind property must be either "method" or' + ' "field", but a decorator created an element descriptor with' + ' .kind "' + kind + '"');
      }
      var key = _toPropertyKey(elementObject.key);
      var placement = String(elementObject.placement);
      if (placement !== "static" && placement !== "prototype" && placement !== "own") {
        throw new TypeError('An element descriptor\'s .placement property must be one of "static",' + ' "prototype" or "own", but a decorator created an element descriptor' + ' with .placement "' + placement + '"');
      }
      var descriptor = elementObject.descriptor;
      this.disallowProperty(elementObject, "elements", "An element descriptor");
      var element = {
        kind: kind,
        key: key,
        placement: placement,
        descriptor: Object.assign({}, descriptor)
      };
      if (kind !== "field") {
        this.disallowProperty(elementObject, "initializer", "A method descriptor");
      } else {
        this.disallowProperty(descriptor, "get", "The property descriptor of a field descriptor");
        this.disallowProperty(descriptor, "set", "The property descriptor of a field descriptor");
        this.disallowProperty(descriptor, "value", "The property descriptor of a field descriptor");
        element.initializer = elementObject.initializer;
      }
      return element;
    },
    // Split an element decorator's result into the element, an optional
    // finisher, and optional extra elements.
    toElementFinisherExtras: function (elementObject) {
      var element = this.toElementDescriptor(elementObject);
      var finisher = _optionalCallableProperty(elementObject, "finisher");
      var extras = this.toElementDescriptors(elementObject.extras);
      return {
        element: element,
        finisher: finisher,
        extras: extras
      };
    },
    // Internal elements -> public class descriptor handed to class
    // decorators.
    fromClassDescriptor: function (elements) {
      var obj = {
        kind: "class",
        elements: elements.map(this.fromElementDescriptor, this)
      };
      var desc = {
        value: "Descriptor",
        configurable: true
      };
      Object.defineProperty(obj, Symbol.toStringTag, desc);
      return obj;
    },
    // Validate a class decorator's result. Only `kind`, `elements` and
    // `finisher` are permitted.
    toClassDescriptor: function (obj) {
      var kind = String(obj.kind);
      if (kind !== "class") {
        throw new TypeError('A class descriptor\'s .kind property must be "class", but a decorator' + ' created a class descriptor with .kind "' + kind + '"');
      }
      this.disallowProperty(obj, "key", "A class descriptor");
      this.disallowProperty(obj, "placement", "A class descriptor");
      this.disallowProperty(obj, "descriptor", "A class descriptor");
      this.disallowProperty(obj, "initializer", "A class descriptor");
      this.disallowProperty(obj, "extras", "A class descriptor");
      var finisher = _optionalCallableProperty(obj, "finisher");
      var elements = this.toElementDescriptors(obj.elements);
      return {
        elements: elements,
        finisher: finisher
      };
    },
    // Call each finisher in order. A finisher may return a replacement
    // constructor (which must be a function) or undefined to keep the
    // current one.
    runClassFinishers: function (constructor, finishers) {
      for (var i = 0; i < finishers.length; i++) {
        var newConstructor = (0, finishers[i])(constructor);
        if (newConstructor !== undefined) {
          if (typeof newConstructor !== "function") {
            throw new TypeError("Finishers must return a constructor.");
          }
          constructor = newConstructor;
        }
      }
      return constructor;
    },
    // Throw if obj[name] is anything other than undefined.
    disallowProperty: function (obj, name, objectType) {
      if (obj[name] !== undefined) {
        throw new TypeError(objectType + " can't have a ." + name + " property.");
      }
    }
  };
  return api;
}
/**
 * Builds a normalized element descriptor from a class-member definition.
 *
 * Methods, getters and setters all become kind "method"; fields stay "field".
 * Placement is "static" for static members, "own" for instance fields and
 * "prototype" for everything else. An unrecognized `def.kind` leaves
 * `descriptor` undefined.
 *
 * @param {Object} def member definition ({kind, key, value, static, decorators})
 * @returns {Object} {kind, key, placement, descriptor[, decorators][, initializer]}
 */
function _createElementDescriptor(def) {
  var key = _toPropertyKey(def.key);
  var memberKind = def.kind;
  var isField = memberKind === "field";
  var descriptor;
  switch (memberKind) {
    case "method":
      descriptor = { value: def.value, writable: true, configurable: true, enumerable: false };
      break;
    case "get":
      descriptor = { get: def.value, configurable: true, enumerable: false };
      break;
    case "set":
      descriptor = { set: def.value, configurable: true, enumerable: false };
      break;
    case "field":
      // Fields get their value later through `initializer`.
      descriptor = { configurable: true, writable: true, enumerable: true };
      break;
  }
  var placement;
  if (def.static) {
    placement = "static";
  } else {
    placement = isField ? "own" : "prototype";
  }
  var element = {
    kind: isField ? "field" : "method",
    key: key,
    placement: placement,
    descriptor: descriptor
  };
  if (def.decorators) element.decorators = def.decorators;
  if (isField) element.initializer = def.value;
  return element;
}
/**
 * Copies the accessor half defined by `element` onto `other`'s descriptor
 * (mutating it in place): the getter when `element` has one, otherwise the
 * setter.
 */
function _coalesceGetterSetter(element, other) {
  var source = element.descriptor;
  var target = other.descriptor;
  if (source.get === undefined) {
    target.set = source.set;
    return;
  }
  target.get = source.get;
}
/**
 * Merges duplicate method-like elements (same key and placement) in a list of
 * element descriptors, returning a new array in first-occurrence order.
 *
 * - If either duplicate has a data descriptor (value/writable), the later
 *   descriptor replaces the earlier one, and neither may be decorated.
 * - Otherwise both are accessors. The getter/setter halves are combined onto
 *   the earlier element, and at most one of the two may carry decorators.
 *
 * Keys are passed through String() when building error messages. Element keys
 * may be Symbols, and `"..." + symbol` throws a TypeError that would mask the
 * intended ReferenceError.
 *
 * @param {Array<Object>} elements element descriptors to coalesce
 * @returns {Array<Object>} the coalesced elements
 * @throws {ReferenceError} when duplicated members carry conflicting decorators
 */
function _coalesceClassElements(elements) {
  var newElements = [];
  // Reads the loop's current `element` (function-scoped var) at call time.
  var isSameElement = function (other) {
    return other.kind === "method" && other.key === element.key && other.placement === element.placement;
  };
  for (var i = 0; i < elements.length; i++) {
    var element = elements[i];
    var other;
    if (element.kind === "method" && (other = newElements.find(isSameElement))) {
      if (_isDataDescriptor(element.descriptor) || _isDataDescriptor(other.descriptor)) {
        if (_hasDecorators(element) || _hasDecorators(other)) {
          throw new ReferenceError("Duplicated methods (" + String(element.key) + ") can't be decorated.");
        }
        other.descriptor = element.descriptor;
      } else {
        if (_hasDecorators(element)) {
          if (_hasDecorators(other)) {
            throw new ReferenceError("Decorators can't be placed on different accessors for " + "the same property (" + String(element.key) + ").");
          }
          other.decorators = element.decorators;
        }
        _coalesceGetterSetter(element, other);
      }
    } else {
      newElements.push(element);
    }
  }
  return newElements;
}
/**
 * Truthiness test for a non-empty `decorators` list on `element`.
 * Returns the list length when the list is present, otherwise the falsy
 * `decorators` value itself. Callers should only test the result for truthiness.
 */
function _hasDecorators(element) {
  var decorators = element.decorators;
  return decorators ? decorators.length : decorators;
}
/**
 * True when `desc` is defined and has a `value` or `writable` attribute that
 * is not undefined, i.e. it is a data rather than an accessor descriptor.
 */
function _isDataDescriptor(desc) {
  if (desc === undefined) {
    return false;
  }
  return desc.value !== undefined || desc.writable !== undefined;
}
/**
 * Reads `obj[name]` and returns it, requiring it to be undefined or a function.
 * @throws {TypeError} when the property is defined but not a function
 */
function _optionalCallableProperty(obj, name) {
  var candidate = obj[name];
  var isAllowed = candidate === undefined || typeof candidate === "function";
  if (!isAllowed) {
    throw new TypeError("Expected '" + name + "' to be a function");
  }
  return candidate;
}
/**
 * Private-method access check. Returns `fn` only if `privateSet.has(receiver)`.
 * @throws {TypeError} when `receiver` is not registered in `privateSet`
 */
function _classPrivateMethodGet(receiver, privateSet, fn) {
  if (privateSet.has(receiver)) {
    return fn;
  }
  throw new TypeError("attempted to get private field on non-instance");
}
/**
 * Guards against initializing private state twice for `obj`.
 * @throws {TypeError} if `privateCollection.has(obj)` is already true
 */
function _checkPrivateRedeclaration(obj, privateCollection) {
  var alreadyInitialized = privateCollection.has(obj);
  if (!alreadyInitialized) return;
  throw new TypeError("Cannot initialize the same private elements twice on an object");
}
/**
 * Initializes a private field by storing `initialValue` for `target` in
 * `fieldMap`. The redeclaration check throws (TypeError) before anything is
 * written if `target` already has an entry.
 */
function _classPrivateFieldInitSpec(target, fieldMap, initialValue) {
  _checkPrivateRedeclaration(target, fieldMap);
  fieldMap.set(target, initialValue);
}
/**
 * Registers `target` in `brandSet` so its private methods become accessible.
 * The redeclaration check throws (TypeError) if it is already registered.
 */
function _classPrivateMethodInitSpec(target, brandSet) {
  _checkPrivateRedeclaration(target, brandSet);
  brandSet.add(target);
}
/**
 * Always throws: private methods cannot be reassigned.
 * @throws {TypeError}
 */
function _classPrivateMethodSet() {
  var message = "attempted to reassign private method";
  throw new TypeError(message);
}
/** Identity function: returns its single argument unchanged. */
function _identity(x) {
  var unchanged = x;
  return unchanged;
}
2553
// Returns `it` when it looks like this realm's global object (its `Math`
// property is this realm's `Math`); otherwise returns a falsy value.
var check = function check(it) {
  return it && it.Math == Math && it;
};

// Resolve the global object by trying globalThis, window, self and the
// bundle's `commonjsGlobal` in turn; each candidate must pass `check`.
// https://github.com/zloirock/core-js/issues/86#issuecomment-115759028
var global$Z =
// eslint-disable-next-line es/no-global-this -- safe
check((typeof globalThis === "undefined" ? "undefined" : _typeof(globalThis)) == 'object' && globalThis) || check((typeof window === "undefined" ? "undefined" : _typeof(window)) == 'object' && window) ||
// eslint-disable-next-line no-restricted-globals -- safe
check((typeof self === "undefined" ? "undefined" : _typeof(self)) == 'object' && self) || check(_typeof(commonjsGlobal) == 'object' && commonjsGlobal) ||
// Last resorts:
// - `this` of a plain call, which is undefined when the enclosing code is
//   strict (the UMD factory in the file head declares 'use strict');
// - a `Function` constructor body, which is not strict.
// eslint-disable-next-line no-new-func -- fallback
function () {
  return this;
}() || Function('return this')();
// Bundler interop alias: getDefaultExportFromCjs (file head) returns the value
// unchanged unless it is an `__esModule` carrying a `default`.
var global$_ = /*@__PURE__*/getDefaultExportFromCjs(global$Z);
2569
// Module object for the getOwnPropertyDescriptor helper; its `.f` is assigned
// where that helper is defined.
var objectGetOwnPropertyDescriptor = {};

/**
 * Runs the probe `exec` and reports whether it "fails". The result is true when
 * `exec` throws or returns a truthy value, and false otherwise. Feature probes
 * therefore return true to signal a broken or missing feature.
 */
var fails$1m = function fails(exec) {
  var failed;
  try {
    failed = Boolean(exec());
  } catch (error) {
    // Any exception thrown by the probe counts as a failure.
    failed = true;
  }
  return failed;
};
var fails$1n = /*@__PURE__*/getDefaultExportFromCjs(fails$1m);
2580
var fails$1l = fails$1m;

// Detect IE8's incomplete defineProperty implementation.
// `descriptors` is true when Object.defineProperty can install a getter on a
// plain object. The probe "fails" (returns true) if defineProperty throws or
// the getter does not yield 7.
var descriptors = !fails$1l(function () {
  // eslint-disable-next-line es/no-object-defineproperty -- required for testing
  return Object.defineProperty({}, 1, {
    get: function get() {
      return 7;
    }
  })[1] != 7;
});
var descriptors$1 = /*@__PURE__*/getDefaultExportFromCjs(descriptors);
2593
var fails$1k = fails$1m;
// True when Function.prototype.bind exists and the bound function is a
// function without an own `prototype` property. A missing `bind` makes the
// probe throw, which counts as failure. (Presumably this rules out bind
// polyfills that add `prototype`; confirm.)
var functionBindNative = !fails$1k(function () {
  // eslint-disable-next-line es/no-function-prototype-bind -- safe
  var test = function () {/* empty */}.bind();
  // eslint-disable-next-line no-prototype-builtins -- safe
  return typeof test != 'function' || test.hasOwnProperty('prototype');
});
var functionBindNative$1 = /*@__PURE__*/getDefaultExportFromCjs(functionBindNative);
2602
var NATIVE_BIND$4 = functionBindNative;
// Function.prototype.call captured at load time.
var call$E = Function.prototype.call;
// `functionCall(fn, thisArg, ...args)` behaves like `fn.call(thisArg, ...args)`
// but uses the captured `call` instead of looking up `fn.call`. It is `call`
// bound to itself when bind is native, and an `apply`-based equivalent
// otherwise.
var functionCall = NATIVE_BIND$4 ? call$E.bind(call$E) : function () {
  return call$E.apply(call$E, arguments);
};
var functionCall$1 = /*@__PURE__*/getDefaultExportFromCjs(functionCall);
2609
// Module object; its `.f` is the propertyIsEnumerable implementation below.
var objectPropertyIsEnumerable = {};

// NOTE(review): this string is not at the start of a script or function body,
// so it is not a directive, just a no-op expression statement. The enclosing
// UMD factory (file head) already declares 'use strict'.
'use strict';
var $propertyIsEnumerable$2 = {}.propertyIsEnumerable;
// eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
var getOwnPropertyDescriptor$a = Object.getOwnPropertyDescriptor;

// Nashorn ~ JDK8 bug: propertyIsEnumerable reports false for the own,
// enumerable key `1` of `{1: 2}`.
var NASHORN_BUG = getOwnPropertyDescriptor$a && !$propertyIsEnumerable$2.call({
  1: 2
}, 1);

// `Object.prototype.propertyIsEnumerable` method implementation
// https://tc39.es/ecma262/#sec-object.prototype.propertyisenumerable
// Uses the native method unless the bug is detected. In that case the answer
// is derived from the own property descriptor. Call it with the target object
// as `this`.
var f$8 = objectPropertyIsEnumerable.f = NASHORN_BUG ? function propertyIsEnumerable(V) {
  var descriptor = getOwnPropertyDescriptor$a(this, V);
  return !!descriptor && descriptor.enumerable;
} : $propertyIsEnumerable$2;
2628
/**
 * Builds a data property descriptor for `value`.
 * Each set bit in `bitmap` turns an attribute OFF:
 * 1 = not enumerable, 2 = not configurable, 4 = not writable.
 * A bitmap of 0 (or undefined) yields all three attributes true.
 */
var createPropertyDescriptor$c = function createPropertyDescriptor(bitmap, value) {
  var descriptor = {};
  descriptor.enumerable = (bitmap & 1) === 0;
  descriptor.configurable = (bitmap & 2) === 0;
  descriptor.writable = (bitmap & 4) === 0;
  descriptor.value = value;
  return descriptor;
};
var createPropertyDescriptor$d = /*@__PURE__*/getDefaultExportFromCjs(createPropertyDescriptor$c);
2638
var NATIVE_BIND$3 = functionBindNative;
var FunctionPrototype$4 = Function.prototype;
var call$D = FunctionPrototype$4.call;
// `uncurryThis(method)` turns a prototype method into a plain function that
// takes the receiver as its first argument. For example,
// `uncurryThis(''.slice)(str, a, b)` behaves like `str.slice(a, b)`.
// The native-bind variant is `call.bind(call, method)`; the fallback forwards
// through `call.apply`.
var uncurryThisWithBind = NATIVE_BIND$3 && FunctionPrototype$4.bind.bind(call$D, call$D);
var functionUncurryThis = NATIVE_BIND$3 ? uncurryThisWithBind : function (fn) {
  return function () {
    return call$D.apply(fn, arguments);
  };
};
var functionUncurryThis$1 = /*@__PURE__*/getDefaultExportFromCjs(functionUncurryThis);
2649
var uncurryThis$1j = functionUncurryThis;
var toString$A = uncurryThis$1j({}.toString);
var stringSlice$i = uncurryThis$1j(''.slice);

/**
 * Returns the tag reported by Object.prototype.toString for `it`, with the
 * "[object " prefix and "]" suffix stripped (e.g. "Array", "String").
 */
var classofRaw$2 = function classofRaw(it) {
  var tag = toString$A(it);
  // "[object ".length === 8; the end index -1 drops the closing bracket.
  return stringSlice$i(tag, 8, -1);
};
var classofRaw$3 = /*@__PURE__*/getDefaultExportFromCjs(classofRaw$2);
2657
var uncurryThis$1i = functionUncurryThis;
var fails$1j = fails$1m;
var classof$o = classofRaw$2;
var $Object$5 = Object;
var split$7 = uncurryThis$1i(''.split);

// fallback for non-array-like ES3 and non-enumerable old V8 strings
// When a boxed string does not report index 0 as enumerable, strings (and
// String objects) are split into a character array. All other values are
// simply passed through Object().
var indexedObject = fails$1j(function () {
  // throws an error in rhino, see https://github.com/mozilla/rhino/issues/346
  // eslint-disable-next-line no-prototype-builtins -- safe
  return !$Object$5('z').propertyIsEnumerable(0);
}) ? function (it) {
  return classof$o(it) == 'String' ? split$7(it, '') : $Object$5(it);
} : $Object$5;
var indexedObject$1 = /*@__PURE__*/getDefaultExportFromCjs(indexedObject);
2673
// we can't use just `it == null` since of `document.all` special case
// https://tc39.es/ecma262/#sec-IsHTMLDDA-internal-slot-aec
// `switch` compares with ===, so only the two real nullish values match.
var isNullOrUndefined$e = function isNullOrUndefined(it) {
  switch (it) {
    case null:
    case undefined:
      return true;
    default:
      return false;
  }
};
var isNullOrUndefined$f = /*@__PURE__*/getDefaultExportFromCjs(isNullOrUndefined$e);
2680
var isNullOrUndefined$d = isNullOrUndefined$e;
var $TypeError$p = TypeError;

// `RequireObjectCoercible` abstract operation
// https://tc39.es/ecma262/#sec-requireobjectcoercible
// Returns `it` unchanged unless it is null/undefined, which throws TypeError.
var requireObjectCoercible$j = function requireObjectCoercible(it) {
  if (!isNullOrUndefined$d(it)) {
    return it;
  }
  throw $TypeError$p("Can't call method on " + it);
};
var requireObjectCoercible$k = /*@__PURE__*/getDefaultExportFromCjs(requireObjectCoercible$j);
2691
// toObject with fallback for non-array-like ES3 strings
var IndexedObject$5 = indexedObject;
var requireObjectCoercible$i = requireObjectCoercible$j;
// Rejects null/undefined, then converts via the indexedObject helper.
var toIndexedObject$j = function toIndexedObject(it) {
  var coercible = requireObjectCoercible$i(it);
  return IndexedObject$5(coercible);
};
var toIndexedObject$k = /*@__PURE__*/getDefaultExportFromCjs(toIndexedObject$j);
2699
// `document.all` when a `document` object exists; otherwise a falsy value.
var documentAll$2 = (typeof document === "undefined" ? "undefined" : _typeof(document)) == 'object' && document.all;

// https://tc39.es/ecma262/#sec-IsHTMLDDA-internal-slot
// An [[IsHTMLDDA]] object reports native `typeof` "undefined" although it is
// not undefined; this flag detects that quirk for `document.all`.
// eslint-disable-next-line unicorn/no-typeof-undefined -- required for testing
var IS_HTMLDDA = typeof documentAll$2 == 'undefined' && documentAll$2 !== undefined;
var documentAll_1 = {
  all: documentAll$2,
  IS_HTMLDDA: IS_HTMLDDA
};
var documentAll$3 = /*@__PURE__*/getDefaultExportFromCjs(documentAll_1);
2710
var $documentAll$1 = documentAll_1;
var documentAll$1 = $documentAll$1.all;

// `IsCallable` abstract operation
// https://tc39.es/ecma262/#sec-iscallable
// When `document.all` is an [[IsHTMLDDA]] object its `typeof` is "undefined",
// so it is recognized by identity and treated as callable.
var isCallable$z;
if ($documentAll$1.IS_HTMLDDA) {
  isCallable$z = function (argument) {
    return typeof argument == 'function' || argument === documentAll$1;
  };
} else {
  isCallable$z = function (argument) {
    return typeof argument == 'function';
  };
}
var isCallable$A = /*@__PURE__*/getDefaultExportFromCjs(isCallable$z);
2722
var isCallable$y = isCallable$z;
var $documentAll = documentAll_1;
var documentAll = $documentAll.all;
// True for non-null objects and for callables. When `document.all` is an
// [[IsHTMLDDA]] object, it is also accepted by identity.
var isObject$z = $documentAll.IS_HTMLDDA ? function (it) {
  if (_typeof(it) == 'object') return it !== null;
  return isCallable$y(it) || it === documentAll;
} : function (it) {
  if (_typeof(it) == 'object') return it !== null;
  return isCallable$y(it);
};
var isObject$A = /*@__PURE__*/getDefaultExportFromCjs(isObject$z);
2732
var global$Y = global$Z;
var isCallable$x = isCallable$z;
// Returns `argument` when it is callable, otherwise undefined.
var aFunction = function aFunction(argument) {
  if (isCallable$x(argument)) return argument;
  return undefined;
};
/**
 * getBuiltIn(namespace) returns the global `namespace` if it is callable,
 * else undefined.
 * getBuiltIn(namespace, method) returns global[namespace][method], or a falsy
 * value when the namespace is missing.
 */
var getBuiltIn$m = function getBuiltIn(namespace, method) {
  if (arguments.length < 2) {
    return aFunction(global$Y[namespace]);
  }
  var container = global$Y[namespace];
  return container && container[method];
};
var getBuiltIn$n = /*@__PURE__*/getDefaultExportFromCjs(getBuiltIn$m);
2742
var uncurryThis$1h = functionUncurryThis;
// objectIsPrototypeOf(proto, obj) behaves like proto.isPrototypeOf(obj).
var objectIsPrototypeOf = uncurryThis$1h({}.isPrototypeOf);
var objectIsPrototypeOf$1 = /*@__PURE__*/getDefaultExportFromCjs(objectIsPrototypeOf);
2746
// navigator.userAgent as a string, or '' when there is no `navigator` global
// or the stringified user agent is empty.
var engineUserAgent = typeof navigator != 'undefined' && String(navigator.userAgent) || '';
var engineUserAgent$1 = /*@__PURE__*/getDefaultExportFromCjs(engineUserAgent);
2749
var global$X = global$Z;
var userAgent$6 = engineUserAgent;
var process$5 = global$X.process;
var Deno$1 = global$X.Deno;
// Version table from `process.versions` (Node-like hosts) or `Deno.version`.
var versions = process$5 && process$5.versions || Deno$1 && Deno$1.version;
var v8 = versions && versions.v8;
var match, version$8;
if (v8) {
  match = v8.split('.');
  // in old Chrome, versions of V8 isn't V8 = Chrome / 10
  // but their correct versions are not interesting for us
  // Majors 1-3 collapse to 1. Otherwise the major and minor digit strings are
  // concatenated, e.g. "11.3.244" -> 113.
  version$8 = match[0] > 0 && match[0] < 4 ? 1 : +(match[0] + match[1]);
}

// BrowserFS NodeJS `process` polyfill incorrectly set `.v8` to `0.0`
// so check `userAgent` even if `.v8` exists, but 0
if (!version$8 && userAgent$6) {
  match = userAgent$6.match(/Edge\/(\d+)/);
  // NOTE(review): presumably skips legacy EdgeHTML (Edge/ below 74), whose UA
  // also contains a Chrome/ token but is not V8-based. Confirm.
  if (!match || match[1] >= 74) {
    match = userAgent$6.match(/Chrome\/(\d+)/);
    if (match) version$8 = +match[1];
  }
}
// A number, or undefined when neither source yields a version.
var engineV8Version = version$8;
var engineV8Version$1 = /*@__PURE__*/getDefaultExportFromCjs(engineV8Version);
2775
/* eslint-disable es/no-symbol -- required for testing */
var V8_VERSION$3 = engineV8Version;
var fails$1i = fails$1m;

// True when Object.getOwnPropertySymbols exists and the probe below detects
// no broken Symbol behavior.
// eslint-disable-next-line es/no-object-getownpropertysymbols -- required for testing
var symbolConstructorDetection = !!Object.getOwnPropertySymbols && !fails$1i(function () {
  var symbol = Symbol();
  // Chrome 38 Symbol has incorrect toString conversion
  // `get-own-property-symbols` polyfill symbols converted to object are not Symbol instances
  return !String(symbol) || !(Object(symbol) instanceof Symbol) ||
  // Chrome 38-40 symbols are not inherited from DOM collections prototypes to instances
  !Symbol.sham && V8_VERSION$3 && V8_VERSION$3 < 41;
});
var symbolConstructorDetection$1 = /*@__PURE__*/getDefaultExportFromCjs(symbolConstructorDetection);
2790
var NATIVE_SYMBOL$6 = symbolConstructorDetection;
// True when usable Symbols exist, are not marked `sham`, and
// `Symbol.iterator` is a real symbol.
var useSymbolAsUid = NATIVE_SYMBOL$6 && !Symbol.sham && _typeof(Symbol.iterator) == 'symbol';
var useSymbolAsUid$1 = /*@__PURE__*/getDefaultExportFromCjs(useSymbolAsUid);
2794
var getBuiltIn$l = getBuiltIn$m;
var isCallable$w = isCallable$z;
var isPrototypeOf$b = objectIsPrototypeOf;
var USE_SYMBOL_AS_UID$1 = useSymbolAsUid;
var $Object$4 = Object;
// With real symbols, a typeof check suffices. Otherwise `it` counts as a
// symbol when its boxed form inherits from the global Symbol.prototype.
var isSymbol$7;
if (USE_SYMBOL_AS_UID$1) {
  isSymbol$7 = function (it) {
    return _typeof(it) == 'symbol';
  };
} else {
  isSymbol$7 = function (it) {
    var $Symbol = getBuiltIn$l('Symbol');
    if (!isCallable$w($Symbol)) return false;
    return isPrototypeOf$b($Symbol.prototype, $Object$4(it));
  };
}
var isSymbol$8 = /*@__PURE__*/getDefaultExportFromCjs(isSymbol$7);
2807
var $String$7 = String;
// String(argument), or the label 'Object' when that conversion throws.
var tryToString$7 = function tryToString(argument) {
  var text = 'Object';
  try {
    text = $String$7(argument);
  } catch (error) {
    // Deliberate best-effort: keep the generic label.
  }
  return text;
};
var tryToString$8 = /*@__PURE__*/getDefaultExportFromCjs(tryToString$7);
2817
var isCallable$v = isCallable$z;
var tryToString$6 = tryToString$7;
var $TypeError$o = TypeError;

// `Assert: IsCallable(argument) is true`
// Returns `argument` unchanged, or throws a TypeError naming it.
var aCallable$l = function aCallable(argument) {
  if (!isCallable$v(argument)) {
    throw $TypeError$o(tryToString$6(argument) + ' is not a function');
  }
  return argument;
};
var aCallable$m = /*@__PURE__*/getDefaultExportFromCjs(aCallable$l);
2828
var aCallable$k = aCallable$l;
var isNullOrUndefined$c = isNullOrUndefined$e;

// `GetMethod` abstract operation
// https://tc39.es/ecma262/#sec-getmethod
// Returns undefined for a null/undefined property; otherwise the property,
// which must be callable (else TypeError).
var getMethod$9 = function getMethod(V, P) {
  var method = V[P];
  if (isNullOrUndefined$c(method)) return undefined;
  return aCallable$k(method);
};
var getMethod$a = /*@__PURE__*/getDefaultExportFromCjs(getMethod$9);
2839
var call$C = functionCall;
var isCallable$u = isCallable$z;
var isObject$y = isObject$z;
var $TypeError$n = TypeError;

// `OrdinaryToPrimitive` abstract operation
// https://tc39.es/ecma262/#sec-ordinarytoprimitive
// Tries toString then valueOf for the 'string' hint, and valueOf then toString
// otherwise. Returns the first non-object result from a callable method.
var ordinaryToPrimitive$2 = function ordinaryToPrimitive(input, pref) {
  var methodNames = pref === 'string' ? ['toString', 'valueOf'] : ['valueOf', 'toString'];
  for (var i = 0; i < methodNames.length; i++) {
    var method = input[methodNames[i]];
    if (isCallable$u(method)) {
      var result = call$C(method, input);
      if (!isObject$y(result)) return result;
    }
  }
  throw $TypeError$n("Can't convert object to primitive value");
};
var ordinaryToPrimitive$3 = /*@__PURE__*/getDefaultExportFromCjs(ordinaryToPrimitive$2);
2855
// CommonJS-style module record for the `shared` helper; its `exports` is
// reassigned when that helper is defined.
var shared$a = {exports: {}};

// This build is not the "pure" flavor; it registers itself with mode
// 'global' in the shared versions list.
var isPure = false;
var isPure$1 = /*@__PURE__*/getDefaultExportFromCjs(isPure);
2860
var global$W = global$Z;

// eslint-disable-next-line es/no-object-defineproperty -- safe
var defineProperty$e = Object.defineProperty;
// Installs `value` on the global object under `key` as a configurable,
// writable (non-enumerable) property, and returns `value`. Falls back to plain
// assignment when defineProperty throws.
var defineGlobalProperty$3 = function defineGlobalProperty(key, value) {
  var attributes = {
    value: value,
    configurable: true,
    writable: true
  };
  try {
    defineProperty$e(global$W, key, attributes);
  } catch (error) {
    global$W[key] = value;
  }
  return value;
};
var defineGlobalProperty$4 = /*@__PURE__*/getDefaultExportFromCjs(defineGlobalProperty$3);
2878
var global$V = global$Z;
var defineGlobalProperty$2 = defineGlobalProperty$3;
// Global-object key of the shared store. An existing store is reused;
// otherwise an empty one is installed on the global object.
var SHARED = '__core-js_shared__';
var store$3 = global$V[SHARED] || defineGlobalProperty$2(SHARED, {});
var sharedStore = store$3;
var sharedStore$1 = /*@__PURE__*/getDefaultExportFromCjs(sharedStore);
2885
var shared$8 = shared$a.exports;
var IS_PURE$k = isPure;
var store$2 = sharedStore;
// `shared(key, value)` returns the shared-store entry for `key`. When the
// entry is missing (falsy), it is first set to `value`, or `{}` when `value`
// is undefined. The helper is invoked immediately to append this library's
// version record to the 'versions' list.
(shared$a.exports = function (key, value) {
  return store$2[key] || (store$2[key] = value !== undefined ? value : {});
})('versions', []).push({
  version: '3.29.1',
  mode: IS_PURE$k ? 'pure' : 'global',
  copyright: '© 2014-2023 Denis Pushkarev (zloirock.ru)',
  license: 'https://github.com/zloirock/core-js/blob/v3.29.1/LICENSE',
  source: 'https://github.com/zloirock/core-js'
});
var sharedExports = shared$a.exports;
var shared$9 = /*@__PURE__*/getDefaultExportFromCjs(sharedExports);
2900
var requireObjectCoercible$h = requireObjectCoercible$j;
var $Object$3 = Object;

// `ToObject` abstract operation
// https://tc39.es/ecma262/#sec-toobject
// Throws TypeError for null/undefined; otherwise boxes via Object().
var toObject$t = function toObject(argument) {
  var coercible = requireObjectCoercible$h(argument);
  return $Object$3(coercible);
};
var toObject$u = /*@__PURE__*/getDefaultExportFromCjs(toObject$t);
2910
var uncurryThis$1g = functionUncurryThis;
var toObject$s = toObject$t;
var hasOwnProperty = uncurryThis$1g({}.hasOwnProperty);

// `HasOwnProperty` abstract operation
// https://tc39.es/ecma262/#sec-hasownproperty
// Native Object.hasOwn when available. Otherwise the key is checked on
// ToObject(it), so null/undefined throw a TypeError.
// eslint-disable-next-line es/no-object-hasown -- safe
var hasOwnProperty_1 = Object.hasOwn || function hasOwn(it, key) {
  var object = toObject$s(it);
  return hasOwnProperty(object, key);
};
var hasOwnProperty$1 = /*@__PURE__*/getDefaultExportFromCjs(hasOwnProperty_1);
2922
var uncurryThis$1f = functionUncurryThis;
var id$2 = 0;
var postfix = Math.random();
var toString$z = uncurryThis$1f(1.0.toString);
// Builds a unique string key "Symbol(<key>)_<base36>". Uniqueness comes from
// an incrementing counter plus a random per-load offset.
var uid$6 = function uid(key) {
  var prefix = 'Symbol(' + (key === undefined ? '' : key) + ')_';
  var counter = ++id$2 + postfix;
  return prefix + toString$z(counter, 36);
};
var uid$7 = /*@__PURE__*/getDefaultExportFromCjs(uid$6);
2931
var global$U = global$Z;
var shared$7 = sharedExports;
var hasOwn$w = hasOwnProperty_1;
var uid$5 = uid$6;
var NATIVE_SYMBOL$5 = symbolConstructorDetection;
var USE_SYMBOL_AS_UID = useSymbolAsUid;
var _Symbol$2 = global$U.Symbol;
// Memo of well-known symbols, kept in the shared store under 'wks'.
var WellKnownSymbolsStore$1 = shared$7('wks');
// Factory for substitutes of symbols not provided natively:
// - Symbol.for (or Symbol) when symbols are usable as keys;
// - otherwise Symbol.withoutSetter if present, else a string uid.
var createWellKnownSymbol = USE_SYMBOL_AS_UID ? _Symbol$2['for'] || _Symbol$2 : _Symbol$2 && _Symbol$2.withoutSetter || uid$5;
// `wellKnownSymbol(name)` returns Symbol[name] when native symbols have it;
// otherwise a memoized substitute created from 'Symbol.' + name.
var wellKnownSymbol$z = function wellKnownSymbol(name) {
  if (!hasOwn$w(WellKnownSymbolsStore$1, name)) {
    WellKnownSymbolsStore$1[name] = NATIVE_SYMBOL$5 && hasOwn$w(_Symbol$2, name) ? _Symbol$2[name] : createWellKnownSymbol('Symbol.' + name);
  }
  return WellKnownSymbolsStore$1[name];
};
var wellKnownSymbol$A = /*@__PURE__*/getDefaultExportFromCjs(wellKnownSymbol$z);
2948
var call$B = functionCall;
var isObject$x = isObject$z;
var isSymbol$6 = isSymbol$7;
var getMethod$8 = getMethod$9;
var ordinaryToPrimitive$1 = ordinaryToPrimitive$2;
var wellKnownSymbol$y = wellKnownSymbol$z;
var $TypeError$m = TypeError;
var TO_PRIMITIVE$1 = wellKnownSymbol$y('toPrimitive');

// `ToPrimitive` abstract operation
// https://tc39.es/ecma262/#sec-toprimitive
// Primitives (and symbols) pass through. Objects use [Symbol.toPrimitive]
// with hint 'default' when no hint is given; without that method, they fall
// back to OrdinaryToPrimitive with hint 'number'.
var toPrimitive$4 = function toPrimitive(input, pref) {
  if (!isObject$x(input) || isSymbol$6(input)) return input;
  var exoticToPrim = getMethod$8(input, TO_PRIMITIVE$1);
  if (!exoticToPrim) {
    return ordinaryToPrimitive$1(input, pref === undefined ? 'number' : pref);
  }
  var result = call$B(exoticToPrim, input, pref === undefined ? 'default' : pref);
  if (!isObject$x(result) || isSymbol$6(result)) return result;
  throw $TypeError$m("Can't convert object to primitive value");
};
var toPrimitive$5 = /*@__PURE__*/getDefaultExportFromCjs(toPrimitive$4);
2974
var toPrimitive$3 = toPrimitive$4;
var isSymbol$5 = isSymbol$7;

// `ToPropertyKey` abstract operation
// https://tc39.es/ecma262/#sec-topropertykey
// Symbols are kept as-is; every other primitive becomes a string.
var toPropertyKey$8 = function toPropertyKey(argument) {
  var key = toPrimitive$3(argument, 'string');
  if (isSymbol$5(key)) return key;
  return key + '';
};
var toPropertyKey$9 = /*@__PURE__*/getDefaultExportFromCjs(toPropertyKey$8);
2985
var global$T = global$Z;
var isObject$w = isObject$z;
var document$3 = global$T.document;
// typeof document.createElement is 'object' in old IE
var EXISTS$1 = isObject$w(document$3) && isObject$w(document$3.createElement);
// Creates the DOM element `it` when a usable `document` exists; otherwise
// returns a plain empty object as a stand-in.
var documentCreateElement$2 = function documentCreateElement(it) {
  if (!EXISTS$1) return {};
  return document$3.createElement(it);
};
var documentCreateElement$3 = /*@__PURE__*/getDefaultExportFromCjs(documentCreateElement$2);
2995
var DESCRIPTORS$J = descriptors;
var fails$1h = fails$1m;
var createElement$1 = documentCreateElement$2;

// Thanks to IE8 for its funny defineProperty
// True only when accessor descriptors are NOT generally supported
// (`descriptors` is false) yet defineProperty works with a getter on the
// object returned by documentCreateElement('div').
var ie8DomDefine = !DESCRIPTORS$J && !fails$1h(function () {
  // eslint-disable-next-line es/no-object-defineproperty -- required for testing
  return Object.defineProperty(createElement$1('div'), 'a', {
    get: function get() {
      return 7;
    }
  }).a != 7;
});
var ie8DomDefine$1 = /*@__PURE__*/getDefaultExportFromCjs(ie8DomDefine);
3010
var DESCRIPTORS$I = descriptors;
var call$A = functionCall;
var propertyIsEnumerableModule$2 = objectPropertyIsEnumerable;
var createPropertyDescriptor$b = createPropertyDescriptor$c;
var toIndexedObject$i = toIndexedObject$j;
var toPropertyKey$7 = toPropertyKey$8;
var hasOwn$v = hasOwnProperty_1;
var IE8_DOM_DEFINE$1 = ie8DomDefine;

// eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
var $getOwnPropertyDescriptor$2 = Object.getOwnPropertyDescriptor;

// `Object.getOwnPropertyDescriptor` method
// https://tc39.es/ecma262/#sec-object.getownpropertydescriptor
// Native when descriptors are supported. Otherwise the fallback first tries
// the native method (IE8 DOM case), then synthesizes a data descriptor for
// own properties. It returns undefined for properties that are not own.
var f$7 = objectGetOwnPropertyDescriptor.f = DESCRIPTORS$I ? $getOwnPropertyDescriptor$2 : function getOwnPropertyDescriptor(O, P) {
  O = toIndexedObject$i(O);
  P = toPropertyKey$7(P);
  if (IE8_DOM_DEFINE$1) try {
    return $getOwnPropertyDescriptor$2(O, P);
  } catch (error) {/* empty */}
  // Bitmap bit 1 (non-enumerable) is set when propertyIsEnumerable is false;
  // configurable and writable are reported as true.
  if (hasOwn$v(O, P)) return createPropertyDescriptor$b(!call$A(propertyIsEnumerableModule$2.f, O, P), O[P]);
};
3033
// Module object; its `.f` is the defineProperty implementation.
var objectDefineProperty = {};

var DESCRIPTORS$H = descriptors;
var fails$1g = fails$1m;

// V8 ~ Chrome 36-
// https://bugs.chromium.org/p/v8/issues/detail?id=3334
// True when descriptors are supported but defining a non-writable `prototype`
// value on a function either throws or does not read back as 42.
var v8PrototypeDefineBug = DESCRIPTORS$H && fails$1g(function () {
  // eslint-disable-next-line es/no-object-defineproperty -- required for testing
  return Object.defineProperty(function () {/* empty */}, 'prototype', {
    value: 42,
    writable: false
  }).prototype != 42;
});
var v8PrototypeDefineBug$1 = /*@__PURE__*/getDefaultExportFromCjs(v8PrototypeDefineBug);
3049
var isObject$v = isObject$z;
var $String$6 = String;
var $TypeError$l = TypeError;

// `Assert: Type(argument) is Object`
// Returns `argument` unchanged, or throws a TypeError naming it.
var anObject$D = function anObject(argument) {
  if (!isObject$v(argument)) {
    throw $TypeError$l($String$6(argument) + ' is not an object');
  }
  return argument;
};
var anObject$E = /*@__PURE__*/getDefaultExportFromCjs(anObject$D);
3060
var DESCRIPTORS$G = descriptors;
var IE8_DOM_DEFINE = ie8DomDefine;
var V8_PROTOTYPE_DEFINE_BUG$1 = v8PrototypeDefineBug;
var anObject$C = anObject$D;
var toPropertyKey$6 = toPropertyKey$8;
var $TypeError$k = TypeError;
// eslint-disable-next-line es/no-object-defineproperty -- safe
var $defineProperty$1 = Object.defineProperty;
// eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
var $getOwnPropertyDescriptor$1 = Object.getOwnPropertyDescriptor;
var ENUMERABLE = 'enumerable';
var CONFIGURABLE$1 = 'configurable';
var WRITABLE = 'writable';

// `Object.defineProperty` method
// https://tc39.es/ecma262/#sec-object.defineproperty
// One of three implementations is chosen at load time:
//  1. descriptors supported + V8 prototype bug: native call with a workaround;
//  2. descriptors supported: native Object.defineProperty;
//  3. no descriptor support: try native (IE8 DOM objects); otherwise only a
//     `value` can be assigned, and accessors throw.
var f$6 = objectDefineProperty.f = DESCRIPTORS$G ? V8_PROTOTYPE_DEFINE_BUG$1 ? function defineProperty(O, P, Attributes) {
  anObject$C(O);
  P = toPropertyKey$6(P);
  anObject$C(Attributes);
  // Making a currently writable function `prototype` non-writable with a new
  // value: assign the value first, then redefine with writable: false. The
  // current configurable/enumerable flags are kept unless given explicitly.
  if (typeof O === 'function' && P === 'prototype' && 'value' in Attributes && WRITABLE in Attributes && !Attributes[WRITABLE]) {
    var current = $getOwnPropertyDescriptor$1(O, P);
    if (current && current[WRITABLE]) {
      O[P] = Attributes.value;
      Attributes = {
        configurable: CONFIGURABLE$1 in Attributes ? Attributes[CONFIGURABLE$1] : current[CONFIGURABLE$1],
        enumerable: ENUMERABLE in Attributes ? Attributes[ENUMERABLE] : current[ENUMERABLE],
        writable: false
      };
    }
  }
  return $defineProperty$1(O, P, Attributes);
} : $defineProperty$1 : function defineProperty(O, P, Attributes) {
  anObject$C(O);
  P = toPropertyKey$6(P);
  anObject$C(Attributes);
  if (IE8_DOM_DEFINE) try {
    return $defineProperty$1(O, P, Attributes);
  } catch (error) {/* empty */}
  if ('get' in Attributes || 'set' in Attributes) throw $TypeError$k('Accessors not supported');
  if ('value' in Attributes) O[P] = Attributes.value;
  return O;
};
3104
var DESCRIPTORS$F = descriptors;
var definePropertyModule$b = objectDefineProperty;
var createPropertyDescriptor$a = createPropertyDescriptor$c;
// Defines `object[key] = value` as non-enumerable (configurable and writable)
// when descriptors are supported; otherwise falls back to plain assignment.
// Returns `object` in the fallback, and defineProperty's result otherwise.
var createNonEnumerableProperty$f;
if (DESCRIPTORS$F) {
  createNonEnumerableProperty$f = function (object, key, value) {
    var descriptor = createPropertyDescriptor$a(1, value);
    return definePropertyModule$b.f(object, key, descriptor);
  };
} else {
  createNonEnumerableProperty$f = function (object, key, value) {
    object[key] = value;
    return object;
  };
}
var createNonEnumerableProperty$g = /*@__PURE__*/getDefaultExportFromCjs(createNonEnumerableProperty$f);
3115
// CommonJS-style module record for the makeBuiltIn helper.
var makeBuiltIn$5 = {exports: {}};

var DESCRIPTORS$E = descriptors;
var hasOwn$u = hasOwnProperty_1;
var FunctionPrototype$3 = Function.prototype;
// eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
var getDescriptor = DESCRIPTORS$E && Object.getOwnPropertyDescriptor;
// EXISTS: Function.prototype has an own `name` property.
var EXISTS = hasOwn$u(FunctionPrototype$3, 'name');
// additional protection from minified / mangled / dropped function names
// PROPER: a named function expression reports its source name.
var PROPER = EXISTS && function something() {/* empty */}.name === 'something';
// CONFIGURABLE: Function.prototype.name is configurable. Assumed true when
// descriptors are unsupported and cannot be inspected.
var CONFIGURABLE = EXISTS && (!DESCRIPTORS$E || DESCRIPTORS$E && getDescriptor(FunctionPrototype$3, 'name').configurable);
var functionName = {
  EXISTS: EXISTS,
  PROPER: PROPER,
  CONFIGURABLE: CONFIGURABLE
};
var functionName$1 = /*@__PURE__*/getDefaultExportFromCjs(functionName);
3133
var uncurryThis$1e = functionUncurryThis;
var isCallable$t = isCallable$z;
var store$1 = sharedStore;
var functionToString$1 = uncurryThis$1e(Function.toString);

// this helper broken in `core-js@3.4.1-3.4.4`, so we can't use `shared` helper
// Install `inspectSource` (returns a function's source text via
// Function.prototype.toString) on the shared store unless a callable one is
// already there, then reuse whatever the store holds.
if (!isCallable$t(store$1.inspectSource)) {
  store$1.inspectSource = function (it) {
    return functionToString$1(it);
  };
}
var inspectSource$3 = store$1.inspectSource;
var inspectSource$4 = /*@__PURE__*/getDefaultExportFromCjs(inspectSource$3);
3147
var global$S = global$Z;
var isCallable$s = isCallable$z;
var WeakMap$2 = global$S.WeakMap;
// True when a callable global WeakMap exists and its source text contains
// "native code", i.e. it does not look like a JS-implemented polyfill.
var weakMapBasicDetection = isCallable$s(WeakMap$2) && /native code/.test(String(WeakMap$2));
var weakMapBasicDetection$1 = /*@__PURE__*/getDefaultExportFromCjs(weakMapBasicDetection);
3153
var shared$6 = sharedExports;
var uid$4 = uid$6;
var keys$2 = shared$6('keys');
// Returns a stable uid string for `key`, memoized in the shared 'keys' store.
var sharedKey$4 = function sharedKey(key) {
  var cached = keys$2[key];
  if (cached) return cached;
  return keys$2[key] = uid$4(key);
};
var sharedKey$5 = /*@__PURE__*/getDefaultExportFromCjs(sharedKey$4);
3161
// Registry of internal property keys, filled in by other modules; for example,
// the internal-state fallback records its STATE key here.
var hiddenKeys$6 = {};
var hiddenKeys$7 = /*@__PURE__*/getDefaultExportFromCjs(hiddenKeys$6);
3164
var NATIVE_WEAK_MAP$1 = weakMapBasicDetection;
var global$R = global$Z;
var isObject$u = isObject$z;
var createNonEnumerableProperty$e = createNonEnumerableProperty$f;
var hasOwn$t = hasOwnProperty_1;
var shared$5 = sharedStore;
var sharedKey$3 = sharedKey$4;
var hiddenKeys$5 = hiddenKeys$6;
var OBJECT_ALREADY_INITIALIZED = 'Object already initialized';
var TypeError$8 = global$R.TypeError;
var WeakMap$1 = global$R.WeakMap;
// Per-object internal state accessors; the backend is chosen below.
var set$3, get$4, has;
// Returns the state of `it`, creating an empty state object on first use.
var enforce = function enforce(it) {
  return has(it) ? get$4(it) : set$3(it, {});
};
// `getterFor(TYPE)` returns an accessor that yields the state of `it` when
// `it` is an object whose state has `.type === TYPE`, and throws otherwise.
var getterFor$1 = function getterFor(TYPE) {
  return function (it) {
    var state;
    if (!isObject$u(it) || (state = get$4(it)).type !== TYPE) {
      throw TypeError$8('Incompatible receiver, ' + TYPE + ' required');
    }
    return state;
  };
};
// Preferred backend: a WeakMap stored in the shared store (an existing one is
// reused). State objects record their owner in `.facade`; `get` returns `{}`
// for unknown objects.
if (NATIVE_WEAK_MAP$1 || shared$5.state) {
  var store = shared$5.state || (shared$5.state = new WeakMap$1());
  /* eslint-disable no-self-assign -- prototype methods protection */
  // Copy the methods onto the instance as own properties, so later changes to
  // WeakMap.prototype do not affect this store.
  store.get = store.get;
  store.has = store.has;
  store.set = store.set;
  /* eslint-enable no-self-assign -- prototype methods protection */
  set$3 = function set(it, metadata) {
    if (store.has(it)) throw TypeError$8(OBJECT_ALREADY_INITIALIZED);
    metadata.facade = it;
    store.set(it, metadata);
    return metadata;
  };
  get$4 = function get(it) {
    return store.get(it) || {};
  };
  has = function has(it) {
    return store.has(it);
  };
} else {
  // Fallback backend: keep the state in a non-enumerable own property under a
  // uid-based key, which is registered in hiddenKeys.
  var STATE = sharedKey$3('state');
  hiddenKeys$5[STATE] = true;
  set$3 = function set(it, metadata) {
    if (hasOwn$t(it, STATE)) throw TypeError$8(OBJECT_ALREADY_INITIALIZED);
    metadata.facade = it;
    createNonEnumerableProperty$e(it, STATE, metadata);
    return metadata;
  };
  get$4 = function get(it) {
    return hasOwn$t(it, STATE) ? it[STATE] : {};
  };
  has = function has(it) {
    return hasOwn$t(it, STATE);
  };
}
var internalState = {
  set: set$3,
  get: get$4,
  has: has,
  enforce: enforce,
  getterFor: getterFor$1
};
var internalState$1 = /*@__PURE__*/getDefaultExportFromCjs(internalState);
3232
// `makeBuiltIn` module: makes a polyfill function look like a native
// built-in. It normalises `name` and `length`, locks or removes `prototype`,
// and records a fake native source string in internal state. The patched
// Function#toString at the end of this module reports that string.
var makeBuiltIn_1 = makeBuiltIn$5.exports;
var uncurryThis$1d = functionUncurryThis;
var fails$1f = fails$1m;
var isCallable$r = isCallable$z;
var hasOwn$s = hasOwnProperty_1;
var DESCRIPTORS$D = descriptors;
var CONFIGURABLE_FUNCTION_NAME$2 = functionName.CONFIGURABLE;
var inspectSource$2 = inspectSource$3;
var InternalStateModule$d = internalState;
var enforceInternalState$4 = InternalStateModule$d.enforce;
var getInternalState$a = InternalStateModule$d.get;
var $String$5 = String;
// eslint-disable-next-line es/no-object-defineproperty -- safe
var defineProperty$d = Object.defineProperty;
var stringSlice$h = uncurryThis$1d(''.slice);
var replace$d = uncurryThis$1d(''.replace);
var join$7 = uncurryThis$1d([].join);
// Feature test: descriptors are supported AND a function's `length` can be
// redefined through defineProperty.
var CONFIGURABLE_LENGTH = DESCRIPTORS$D && !fails$1f(function () {
  return defineProperty$d(function () {/* empty */}, 'length', {
    value: 8
  }).length !== 8;
});
// The native source of `String`, split around the word 'String'. Joining the
// pieces with another name produces a native-looking source for that name.
var TEMPLATE = String(String).split('String');
// makeBuiltIn(value, name, options): options may carry getter / setter
// (prefix the name), arity (new `length`), and constructor (keep `prototype`
// but make it non-writable). Returns `value`.
var makeBuiltIn$3 = makeBuiltIn$5.exports = function (value, name, options) {
  // Symbol-keyed names become "[description]", e.g. `Symbol(foo)` -> `[foo]`.
  if (stringSlice$h($String$5(name), 0, 7) === 'Symbol(') {
    name = '[' + replace$d($String$5(name), /^Symbol\(([^)]*)\)/, '$1') + ']';
  }
  if (options && options.getter) name = 'get ' + name;
  if (options && options.setter) name = 'set ' + name;
  // Fix `name` when it is missing, or when names are configurable and differ.
  if (!hasOwn$s(value, 'name') || CONFIGURABLE_FUNCTION_NAME$2 && value.name !== name) {
    if (DESCRIPTORS$D) defineProperty$d(value, 'name', {
      value: name,
      configurable: true
    });else value.name = name;
  }
  if (CONFIGURABLE_LENGTH && options && hasOwn$s(options, 'arity') && value.length !== options.arity) {
    defineProperty$d(value, 'length', {
      value: options.arity
    });
  }
  // Constructors get a non-writable `prototype`; other functions have any
  // `prototype` cleared. Failures are deliberately ignored.
  try {
    if (options && hasOwn$s(options, 'constructor') && options.constructor) {
      if (DESCRIPTORS$D) defineProperty$d(value, 'prototype', {
        writable: false
      });
      // in V8 ~ Chrome 53, prototypes of some methods, like `Array.prototype.values`, are non-writable
    } else if (value.prototype) value.prototype = undefined;
  } catch (error) {/* empty */}
  // Record the fake source once; the patched Function#toString below reads it.
  var state = enforceInternalState$4(value);
  if (!hasOwn$s(state, 'source')) {
    state.source = join$7(TEMPLATE, typeof name == 'string' ? name : '');
  }
  return value;
};

// add fake Function#toString for correct work wrapped methods / constructors with methods like LoDash isNative
// Functions registered through makeBuiltIn report their recorded source;
// any other function falls back to its real source via inspectSource.
// eslint-disable-next-line no-extend-native -- required
Function.prototype.toString = makeBuiltIn$3(function toString() {
  return isCallable$r(this) && getInternalState$a(this).source || inspectSource$2(this);
}, 'toString');
var makeBuiltInExports = makeBuiltIn$5.exports;
var makeBuiltIn$4 = /*@__PURE__*/getDefaultExportFromCjs(makeBuiltInExports);
3295
var isCallable$q = isCallable$z;
var definePropertyModule$a = objectDefineProperty;
var makeBuiltIn$2 = makeBuiltInExports;
var defineGlobalProperty$1 = defineGlobalProperty$3;
// Installs `value` as `O[key]` the way core-js installs built-ins and returns `O`.
// Recognised options:
//   enumerable      - use plain assignment instead of defineProperty
//   name            - function name for makeBuiltIn (defaults to `key`)
//   global          - install on the global object via defineGlobalProperty
//   unsafe          - skip the `delete`; assign when a truthy value already exists
//   nonConfigurable - make the defined property non-configurable
//   nonWritable     - make the defined property non-writable
// Callable values are first passed through makeBuiltIn.
var defineBuiltIn$m = function defineBuiltIn(O, key, value, options) {
  var opts = options || {};
  var assign = opts.enumerable;
  var fnName = opts.name !== undefined ? opts.name : key;
  if (isCallable$q(value)) makeBuiltIn$2(value, fnName, opts);

  if (opts.global) {
    if (assign) {
      O[key] = value;
    } else {
      defineGlobalProperty$1(key, value);
    }
    return O;
  }

  try {
    if (opts.unsafe) {
      if (O[key]) assign = true;
    } else {
      delete O[key];
    }
  } catch (error) {/* empty */}

  if (assign) {
    O[key] = value;
  } else {
    definePropertyModule$a.f(O, key, {
      value: value,
      enumerable: false,
      configurable: !opts.nonConfigurable,
      writable: !opts.nonWritable
    });
  }
  return O;
};
var defineBuiltIn$n = /*@__PURE__*/getDefaultExportFromCjs(defineBuiltIn$m);
3321
// Module object for `Object.getOwnPropertyNames`; its `.f` implementation is
// assigned later in the bundle.
var objectGetOwnPropertyNames = {};

var ceil$4 = Math.ceil;
var floor$d = Math.floor;

// `Math.trunc` method
// https://tc39.es/ecma262/#sec-math.trunc
// eslint-disable-next-line es/no-math-trunc -- safe
var mathTrunc = Math.trunc || function trunc(x) {
  var value = +x;
  // Positive values round down; everything else (negatives, zeros, NaN) rounds up.
  var round = value > 0 ? floor$d : ceil$4;
  return round(value);
};
var mathTrunc$1 = /*@__PURE__*/getDefaultExportFromCjs(mathTrunc);
3335
var trunc$1 = mathTrunc;

// `ToIntegerOrInfinity` abstract operation
// https://tc39.es/ecma262/#sec-tointegerorinfinity
// Coerces to a number; NaN, +0 and -0 all become +0, anything else is
// truncated toward zero (infinities pass through).
var toIntegerOrInfinity$l = function toIntegerOrInfinity(argument) {
  var number = +argument;
  // eslint-disable-next-line no-self-compare -- NaN check
  var collapsesToZero = number !== number || number === 0;
  if (collapsesToZero) return 0;
  return trunc$1(number);
};
var toIntegerOrInfinity$m = /*@__PURE__*/getDefaultExportFromCjs(toIntegerOrInfinity$l);
3346
var toIntegerOrInfinity$k = toIntegerOrInfinity$l;
var max$b = Math.max;
var min$e = Math.min;

// Helper for a popular repeating case of the spec:
// Let integer be ? ToInteger(index).
// If integer < 0, let result be max((length + integer), 0); else let result be min(integer, length).
var toAbsoluteIndex$a = function toAbsoluteIndex(index, length) {
  var relative = toIntegerOrInfinity$k(index);
  if (relative < 0) return max$b(relative + length, 0);
  return min$e(relative, length);
};
var toAbsoluteIndex$b = /*@__PURE__*/getDefaultExportFromCjs(toAbsoluteIndex$a);
3359
var toIntegerOrInfinity$j = toIntegerOrInfinity$l;
var min$d = Math.min;

// `ToLength` abstract operation
// https://tc39.es/ecma262/#sec-tolength
// Non-positive (or NaN) inputs give 0; otherwise the integer part, capped at 2 ** 53 - 1.
var toLength$d = function toLength(argument) {
  if (!(argument > 0)) return 0;
  var integer = toIntegerOrInfinity$j(argument);
  return min$d(integer, 0x1FFFFFFFFFFFFF); // 2 ** 53 - 1 == 9007199254740991
};

var toLength$e = /*@__PURE__*/getDefaultExportFromCjs(toLength$d);

var toLength$c = toLength$d;

// `LengthOfArrayLike` abstract operation
// https://tc39.es/ecma262/#sec-lengthofarraylike
var lengthOfArrayLike$t = function lengthOfArrayLike(obj) {
  var rawLength = obj.length;
  return toLength$c(rawLength);
};
var lengthOfArrayLike$u = /*@__PURE__*/getDefaultExportFromCjs(lengthOfArrayLike$t);
3379
var toIndexedObject$h = toIndexedObject$j;
var toAbsoluteIndex$9 = toAbsoluteIndex$a;
var lengthOfArrayLike$s = lengthOfArrayLike$t;

// `Array.prototype.{ indexOf, includes }` methods implementation.
// IS_INCLUDES selects the flavour:
//   includes: SameValueZero, visits holes, returns a boolean;
//   indexOf:  strict equality, skips holes, returns an index or -1.
var createMethod$7 = function createMethod(IS_INCLUDES) {
  return function ($this, el, fromIndex) {
    var O = toIndexedObject$h($this);
    var length = lengthOfArrayLike$s(O);
    var index = toAbsoluteIndex$9(fromIndex, length);
    var current;
    // Array#includes uses SameValueZero equality algorithm, so NaN must find NaN
    // eslint-disable-next-line no-self-compare -- NaN check
    if (IS_INCLUDES && el != el) {
      while (index < length) {
        current = O[index];
        index++;
        // eslint-disable-next-line no-self-compare -- NaN check
        if (current != current) return true;
      }
    } else {
      for (; index < length; index++) {
        // Array#indexOf ignores holes, Array#includes - not
        if (!IS_INCLUDES && !(index in O)) continue;
        // `|| 0` turns a -0 start index into +0
        if (O[index] === el) return IS_INCLUDES || index || 0;
      }
    }
    return !IS_INCLUDES && -1;
  };
};
var arrayIncludes = {
  // `Array.prototype.includes` method
  // https://tc39.es/ecma262/#sec-array.prototype.includes
  includes: createMethod$7(true),
  // `Array.prototype.indexOf` method
  // https://tc39.es/ecma262/#sec-array.prototype.indexof
  indexOf: createMethod$7(false)
};
var arrayIncludes$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayIncludes);
3413
var uncurryThis$1c = functionUncurryThis;
var hasOwn$r = hasOwnProperty_1;
var toIndexedObject$g = toIndexedObject$j;
var indexOf$2 = arrayIncludes.indexOf;
var hiddenKeys$4 = hiddenKeys$6;
var push$e = uncurryThis$1c([].push);
// Collects the own `for...in` keys of `object` (skipping hidden internal
// keys), then appends each entry of `names` that is an own property and not
// already collected.
var objectKeysInternal = function objectKeysInternal(object, names) {
  var O = toIndexedObject$g(object);
  var result = [];
  var key, i;
  for (key in O) {
    if (!hasOwn$r(hiddenKeys$4, key) && hasOwn$r(O, key)) push$e(result, key);
  }
  // Don't enum bug & hidden keys
  for (i = 0; i < names.length; i++) {
    key = names[i];
    if (hasOwn$r(O, key) && indexOf$2(result, key) === -1) push$e(result, key);
  }
  return result;
};
var objectKeysInternal$1 = /*@__PURE__*/getDefaultExportFromCjs(objectKeysInternal);
3433
// IE8- don't enum bug keys
var enumBugKeys$3 = [
  'constructor',
  'hasOwnProperty',
  'isPrototypeOf',
  'propertyIsEnumerable',
  'toLocaleString',
  'toString',
  'valueOf'
];
var enumBugKeys$4 = /*@__PURE__*/getDefaultExportFromCjs(enumBugKeys$3);

var internalObjectKeys$1 = objectKeysInternal;
var enumBugKeys$2 = enumBugKeys$3;
// Extra candidate names that the fallback below checks explicitly as own
// properties, in addition to whatever `for...in` reports.
var hiddenKeys$3 = enumBugKeys$2.concat('length', 'prototype');

// `Object.getOwnPropertyNames` method
// https://tc39.es/ecma262/#sec-object.getownpropertynames
// eslint-disable-next-line es/no-object-getownpropertynames -- safe
var f$5 = objectGetOwnPropertyNames.f = Object.getOwnPropertyNames || function getOwnPropertyNames(O) {
  return internalObjectKeys$1(O, hiddenKeys$3);
};

// Module object for `Object.getOwnPropertySymbols` (native only, may be undefined).
var objectGetOwnPropertySymbols = {};

// eslint-disable-next-line es/no-object-getownpropertysymbols -- safe
var f$4 = objectGetOwnPropertySymbols.f = Object.getOwnPropertySymbols;
3453
var getBuiltIn$k = getBuiltIn$m;
var uncurryThis$1b = functionUncurryThis;
var getOwnPropertyNamesModule$2 = objectGetOwnPropertyNames;
var getOwnPropertySymbolsModule$3 = objectGetOwnPropertySymbols;
var anObject$B = anObject$D;
var concat$6 = uncurryThis$1b([].concat);

// all object keys, includes non-enumerable and symbols:
// native `Reflect.ownKeys` when available, otherwise names followed by symbols
var ownKeys$3 = getBuiltIn$k('Reflect', 'ownKeys') || function ownKeys(it) {
  var names = getOwnPropertyNamesModule$2.f(anObject$B(it));
  var getSymbols = getOwnPropertySymbolsModule$3.f;
  if (!getSymbols) return names;
  return concat$6(names, getSymbols(it));
};
var ownKeys$4 = /*@__PURE__*/getDefaultExportFromCjs(ownKeys$3);
3468
var hasOwn$q = hasOwnProperty_1;
var ownKeys$2 = ownKeys$3;
var getOwnPropertyDescriptorModule$6 = objectGetOwnPropertyDescriptor;
var definePropertyModule$9 = objectDefineProperty;
// Copies every own key (including symbols and non-enumerables) of `source`
// onto `target`, descriptor and all. Keys `target` already owns, and keys
// listed as own properties of `exceptions`, are skipped.
var copyConstructorProperties$5 = function copyConstructorProperties(target, source, exceptions) {
  var sourceKeys = ownKeys$2(source);
  var define = definePropertyModule$9.f;
  var describe = getOwnPropertyDescriptorModule$6.f;
  for (var idx = 0; idx < sourceKeys.length; idx++) {
    var k = sourceKeys[idx];
    if (hasOwn$q(target, k)) continue;
    if (exceptions && hasOwn$q(exceptions, k)) continue;
    define(target, k, describe(source, k));
  }
};
var copyConstructorProperties$6 = /*@__PURE__*/getDefaultExportFromCjs(copyConstructorProperties$5);
3485
var fails$1e = fails$1m;
var isCallable$p = isCallable$z;
// Separators normalised to '.' when building a feature key, so that
// 'Array#at' and 'Array.prototype.at' both become 'array.at'.
var replacement = /#|\.prototype\./;
// Decides whether `feature` must be polyfilled. An explicit entry in
// `isForced.data` ('P' = polyfill, 'N' = native) wins. Otherwise a callable
// `detection` is passed to `fails`, and any other value is coerced to boolean.
var isForced$5 = function isForced(feature, detection) {
  var value = data[normalize(feature)];
  if (value == POLYFILL) return true;
  if (value == NATIVE) return false;
  return isCallable$p(detection) ? fails$1e(detection) : !!detection;
};
var normalize = isForced$5.normalize = function (string) {
  return String(string).replace(replacement, '.').toLowerCase();
};
var data = isForced$5.data = {};
var NATIVE = isForced$5.NATIVE = 'N';
var POLYFILL = isForced$5.POLYFILL = 'P';
var isForced_1 = isForced$5;
var isForced$6 = /*@__PURE__*/getDefaultExportFromCjs(isForced_1);
3501
var global$Q = global$Z;
var getOwnPropertyDescriptor$9 = objectGetOwnPropertyDescriptor.f;
var createNonEnumerableProperty$d = createNonEnumerableProperty$f;
var defineBuiltIn$l = defineBuiltIn$m;
var defineGlobalProperty = defineGlobalProperty$3;
var copyConstructorProperties$4 = copyConstructorProperties$5;
var isForced$4 = isForced_1;

/*
  options.target - name of the target object
  options.global - target is the global object
  options.stat - export as static methods of target
  options.proto - export as prototype methods of target
  options.real - real prototype method for the `pure` version
  options.forced - export even if the native feature is available
  options.bind - bind methods to the target, required for the `pure` version
  options.wrap - wrap constructors to preventing global pollution, required for the `pure` version
  options.unsafe - use the simple assignment of property instead of delete + defineProperty
  options.sham - add a flag to not completely full polyfills
  options.enumerable - export as enumerable property
  options.dontCallGetSet - prevent calling a getter on target
  options.name - the .name of the function if it does not match the key
*/
var _export = function _export(options, source) {
  var TARGET = options.target;
  var GLOBAL = options.global;
  var STATIC = options.stat;
  var FORCED, target, key, targetProperty, sourceProperty, descriptor;

  // Where the members go: the global object itself, a (possibly newly
  // created) global namespace, or the prototype of a global constructor.
  if (GLOBAL) target = global$Q;
  else if (STATIC) target = global$Q[TARGET] || defineGlobalProperty(TARGET, {});
  else target = (global$Q[TARGET] || {}).prototype;
  if (!target) return;

  for (key in source) {
    sourceProperty = source[key];
    if (!options.dontCallGetSet) {
      targetProperty = target[key];
    } else {
      descriptor = getOwnPropertyDescriptor$9(target, key);
      targetProperty = descriptor && descriptor.value;
    }
    FORCED = isForced$4(GLOBAL ? key : TARGET + (STATIC ? '.' : '#') + key, options.forced);
    // contained in target: an existing member of the same `typeof` is kept
    // unless the feature is forced
    if (!FORCED && targetProperty !== undefined) {
      if (_typeof(sourceProperty) == _typeof(targetProperty)) continue;
      copyConstructorProperties$4(sourceProperty, targetProperty);
    }
    // add a flag to not completely full polyfills
    if (options.sham || (targetProperty && targetProperty.sham)) {
      createNonEnumerableProperty$d(sourceProperty, 'sham', true);
    }
    defineBuiltIn$l(target, key, sourceProperty, options);
  }
};
var _export$1 = /*@__PURE__*/getDefaultExportFromCjs(_export);
3557
var wellKnownSymbol$x = wellKnownSymbol$z;
var TO_STRING_TAG$5 = wellKnownSymbol$x('toStringTag');
// Feature test: does Object#toString honour a `Symbol.toStringTag` property?
var test$2 = {};
test$2[TO_STRING_TAG$5] = 'z';
var toStringTagSupport = String(test$2) === '[object z]';
var toStringTagSupport$1 = /*@__PURE__*/getDefaultExportFromCjs(toStringTagSupport);

var TO_STRING_TAG_SUPPORT$2 = toStringTagSupport;
var isCallable$o = isCallable$z;
var classofRaw$1 = classofRaw$2;
var wellKnownSymbol$w = wellKnownSymbol$z;
var TO_STRING_TAG$4 = wellKnownSymbol$w('toStringTag');
var $Object$2 = Object;

// ES3 wrong here: checks whether an `arguments` object reports 'Arguments'
var CORRECT_ARGUMENTS = classofRaw$1((function () {
  return arguments;
})()) == 'Arguments';

// fallback for IE11 Script Access Denied error: yields undefined if the read throws
var tryGet = function tryGet(it, key) {
  var value;
  try {
    value = it[key];
  } catch (error) {/* empty */}
  return value;
};

// getting tag from ES6+ `Object.prototype.toString`
var classof$m = TO_STRING_TAG_SUPPORT$2 ? classofRaw$1 : function (it) {
  if (it === undefined) return 'Undefined';
  if (it === null) return 'Null';
  var O = $Object$2(it);
  var tag = tryGet(O, TO_STRING_TAG$4);
  // @@toStringTag case
  if (typeof tag == 'string') return tag;
  // builtinTag case
  if (CORRECT_ARGUMENTS) return classofRaw$1(O);
  // ES3 arguments fallback
  var result = classofRaw$1(O);
  return result == 'Object' && isCallable$o(O.callee) ? 'Arguments' : result;
};
var classof$n = /*@__PURE__*/getDefaultExportFromCjs(classof$m);
3596
var classof$l = classof$m;
var $String$4 = String;
// Spec `ToString`: unlike `String(symbol)`, which succeeds, Symbols must throw.
var toString$x = function toString(argument) {
  if (classof$l(argument) !== 'Symbol') return $String$4(argument);
  throw TypeError('Cannot convert a Symbol value to a string');
};
var toString$y = /*@__PURE__*/getDefaultExportFromCjs(toString$x);
3604
// Module object for `Object.defineProperties`; its `.f` implementation is
// assigned later in the bundle.
var objectDefineProperties = {};

var internalObjectKeys = objectKeysInternal;
var enumBugKeys$1 = enumBugKeys$3;

// `Object.keys` method
// https://tc39.es/ecma262/#sec-object.keys
// eslint-disable-next-line es/no-object-keys -- safe
var objectKeys$5 = Object.keys || function keys(O) {
  // Fallback: own for-in keys (minus hidden internal keys) plus any IE8-
  // "don't enum" names that are own properties.
  var extraNames = enumBugKeys$1;
  return internalObjectKeys(O, extraNames);
};
var objectKeys$6 = /*@__PURE__*/getDefaultExportFromCjs(objectKeys$5);
3617
var DESCRIPTORS$C = descriptors;
var V8_PROTOTYPE_DEFINE_BUG = v8PrototypeDefineBug;
var definePropertyModule$8 = objectDefineProperty;
var anObject$A = anObject$D;
var toIndexedObject$f = toIndexedObject$j;
var objectKeys$4 = objectKeys$5;

// `Object.defineProperties` method
// https://tc39.es/ecma262/#sec-object.defineproperties
// Native version unless descriptors are unsupported or V8's prototype bug is present.
// eslint-disable-next-line es/no-object-defineproperties -- safe
var f$3 = objectDefineProperties.f = DESCRIPTORS$C && !V8_PROTOTYPE_DEFINE_BUG
  ? Object.defineProperties
  : function defineProperties(O, Properties) {
    anObject$A(O);
    var props = toIndexedObject$f(Properties);
    var keys = objectKeys$4(Properties);
    var count = keys.length;
    for (var i = 0; i < count; i++) {
      var key = keys[i];
      definePropertyModule$8.f(O, key, props[key]);
    }
    return O;
  };

var getBuiltIn$j = getBuiltIn$m;
// The document's root element (`document.documentElement`); presumably
// undefined outside browsers.
var html$2 = getBuiltIn$j('document', 'documentElement');
var html$3 = /*@__PURE__*/getDefaultExportFromCjs(html$2);
3642
/* global ActiveXObject -- old IE, WSH */
// `Object.create` polyfill. A `null`-prototype object is emulated with an
// `Object` constructor taken from a fresh ActiveX document or hidden iframe.
// That constructor's prototype is then stripped of the IE8- "don't enum"
// method names.
var anObject$z = anObject$D;
var definePropertiesModule$1 = objectDefineProperties;
var enumBugKeys = enumBugKeys$3;
var hiddenKeys$2 = hiddenKeys$6;
var html$1 = html$2;
var documentCreateElement$1 = documentCreateElement$2;
var sharedKey$2 = sharedKey$4;
var GT = '>';
var LT = '<';
var PROTOTYPE$2 = 'prototype';
var SCRIPT = 'script';
// Hidden key under which created objects record the prototype they were given.
var IE_PROTO$1 = sharedKey$2('IE_PROTO');
var EmptyConstructor = function EmptyConstructor() {/* empty */};
// Wraps `content` in a script-element string. The tag is assembled from
// pieces, presumably so the literal closing tag never appears in this file.
var scriptTag = function scriptTag(content) {
  return LT + SCRIPT + GT + content + LT + '/' + SCRIPT + GT;
};

// Create object with fake `null` prototype: use ActiveX Object with cleared prototype
var NullProtoObjectViaActiveX = function NullProtoObjectViaActiveX(activeXDocument) {
  activeXDocument.write(scriptTag(''));
  activeXDocument.close();
  var temp = activeXDocument.parentWindow.Object;
  // NOTE(review): this clears only the parameter, which shadows the
  // module-level `activeXDocument`; that outer variable still holds the document.
  activeXDocument = null; // avoid memory leak
  return temp;
};

// Create object with fake `null` prototype: use iframe Object with cleared prototype
var NullProtoObjectViaIFrame = function NullProtoObjectViaIFrame() {
  // Workaround for an IE garbage-collection bug
  var iframe = documentCreateElement$1('iframe');
  // 'javascript:' built from parts (presumably to avoid tooling flagging the literal)
  var JS = 'java' + SCRIPT + ':';
  var iframeDocument;
  iframe.style.display = 'none';
  html$1.appendChild(iframe);
  // https://github.com/zloirock/core-js/issues/475
  iframe.src = String(JS);
  iframeDocument = iframe.contentWindow.document;
  iframeDocument.open();
  // The injected script exposes the iframe's own `Object` as `document.F`.
  iframeDocument.write(scriptTag('document.F=Object'));
  iframeDocument.close();
  return iframeDocument.F;
};

// Check for document.domain and active x support
// No need to use active x approach when document.domain is not set
// see https://github.com/es-shims/es5-shim/issues/150
// variation of https://github.com/kitcambridge/es5-shim/commit/4f738ac066346
// avoid IE GC bug
var activeXDocument;
// On its first call this function replaces itself (`_NullProtoObject`) with
// a foreign `Object` constructor:
//   - ActiveX when a document with `document.domain` set and ActiveX exist (old IE);
//   - the iframe approach for other documents;
//   - ActiveX when there is no document (WSH).
// It then deletes the "don't enum" names from that constructor's prototype and
// returns `_NullProtoObject()`, a new object inheriting from the stripped prototype.
var _NullProtoObject = function NullProtoObject() {
  try {
    activeXDocument = new ActiveXObject('htmlfile');
  } catch (error) {/* ignore */}
  _NullProtoObject = typeof document != 'undefined' ? document.domain && activeXDocument ? NullProtoObjectViaActiveX(activeXDocument) // old IE
  : NullProtoObjectViaIFrame() : NullProtoObjectViaActiveX(activeXDocument); // WSH
  var length = enumBugKeys.length;
  while (length--) delete _NullProtoObject[PROTOTYPE$2][enumBugKeys[length]];
  return _NullProtoObject();
};
// Keep the IE_PROTO marker out of the key-enumeration fallbacks.
hiddenKeys$2[IE_PROTO$1] = true;

// `Object.create` method
// https://tc39.es/ecma262/#sec-object.create
// Non-null prototypes: temporarily point EmptyConstructor.prototype at `O`,
// instantiate, then reset it. `null`: use the stripped foreign-Object emulation.
// Optional `Properties` are applied through the defineProperties module.
// eslint-disable-next-line es/no-object-create -- safe
var objectCreate = Object.create || function create(O, Properties) {
  var result;
  if (O !== null) {
    EmptyConstructor[PROTOTYPE$2] = anObject$z(O);
    result = new EmptyConstructor();
    EmptyConstructor[PROTOTYPE$2] = null;
    // add "__proto__" for Object.getPrototypeOf polyfill
    result[IE_PROTO$1] = O;
  } else result = _NullProtoObject();
  return Properties === undefined ? result : definePropertiesModule$1.f(result, Properties);
};
var objectCreate$1 = /*@__PURE__*/getDefaultExportFromCjs(objectCreate);
3720
// Module object for the IE11-safe `getOwnPropertyNames`; its `.f` is assigned
// later in the bundle.
var objectGetOwnPropertyNamesExternal = {};

// (A stray `'use strict';` expression statement used to sit here. Outside a
// directive prologue it has no effect, so it has been removed.)
var toPropertyKey$5 = toPropertyKey$8;
var definePropertyModule$7 = objectDefineProperty;
var createPropertyDescriptor$9 = createPropertyDescriptor$c;
// Stores `value` under `key` on `object` as an own property.
// When the key is already reachable (own or inherited), the property is
// (re)defined through defineProperty with `createPropertyDescriptor(0, value)`.
// This bypasses inherited setters and read-only inherited properties that a
// plain assignment would hit. Bitmap 0 presumably means an enumerable,
// configurable, writable data descriptor (createPropertyDescriptor is defined
// elsewhere). Otherwise a plain assignment is used.
var createProperty$9 = function createProperty(object, key, value) {
  var propertyKey = toPropertyKey$5(key);
  if (propertyKey in object) {
    definePropertyModule$7.f(object, propertyKey, createPropertyDescriptor$9(0, value));
  } else {
    object[propertyKey] = value;
  }
};
var createProperty$a = /*@__PURE__*/getDefaultExportFromCjs(createProperty$9);
3732
var toAbsoluteIndex$8 = toAbsoluteIndex$a;
var lengthOfArrayLike$r = lengthOfArrayLike$t;
var createProperty$8 = createProperty$9;
var $Array$9 = Array;
var max$a = Math.max;
// Minimal `Array.prototype.slice` for array-likes: copies O[start..end) into
// a new plain Array. Negative / out-of-range bounds are clamped like `slice`.
var arraySliceSimple = function arraySliceSimple(O, start, end) {
  var length = lengthOfArrayLike$r(O);
  var from = toAbsoluteIndex$8(start, length);
  var to = toAbsoluteIndex$8(end === undefined ? length : end, length);
  var result = $Array$9(max$a(to - from, 0));
  var n = 0;
  while (from < to) {
    createProperty$8(result, n, O[from]);
    from++;
    n++;
  }
  result.length = n;
  return result;
};
var arraySliceSimple$1 = /*@__PURE__*/getDefaultExportFromCjs(arraySliceSimple);
3748
var classof$k = classofRaw$2;
var toIndexedObject$e = toIndexedObject$j;
var $getOwnPropertyNames$1 = objectGetOwnPropertyNames.f;
var arraySlice$c = arraySliceSimple;
// Snapshot of the current window's own property names (empty outside
// browsers). It is the fallback when reading names of a window object throws.
var windowNames = [];
if ((typeof window === "undefined" ? "undefined" : _typeof(window)) == 'object' && window && Object.getOwnPropertyNames) {
  windowNames = Object.getOwnPropertyNames(window);
}
var getWindowNames = function getWindowNames(it) {
  var names;
  try {
    names = $getOwnPropertyNames$1(it);
  } catch (error) {
    names = arraySlice$c(windowNames);
  }
  return names;
};

// fallback for IE11 buggy Object.getOwnPropertyNames with iframe and window
var f$2 = objectGetOwnPropertyNamesExternal.f = function getOwnPropertyNames(it) {
  if (windowNames && classof$k(it) == 'Window') return getWindowNames(it);
  return $getOwnPropertyNames$1(toIndexedObject$e(it));
};
3766
var makeBuiltIn$1 = makeBuiltInExports;
var defineProperty$c = objectDefineProperty;
// Defines an accessor property. Its getter / setter are first passed through
// makeBuiltIn so they carry 'get <name>' / 'set <name>' names.
// Returns whatever objectDefineProperty.f returns.
var defineBuiltInAccessor$h = function defineBuiltInAccessor(target, name, descriptor) {
  var getter = descriptor.get;
  if (getter) makeBuiltIn$1(getter, name, { getter: true });
  var setter = descriptor.set;
  if (setter) makeBuiltIn$1(setter, name, { setter: true });
  return defineProperty$c.f(target, name, descriptor);
};
var defineBuiltInAccessor$i = /*@__PURE__*/getDefaultExportFromCjs(defineBuiltInAccessor$h);
3779
// Module object whose `.f` returns the well-known symbol for a given name.
var wellKnownSymbolWrapped = {};

var wellKnownSymbol$v = wellKnownSymbol$z;
var f$1 = wellKnownSymbolWrapped.f = wellKnownSymbol$v;

// `path` aliases the global object here.
var global$P = global$Z;
var path$2 = global$P;
var path$3 = /*@__PURE__*/getDefaultExportFromCjs(path$2);

var path$1 = path$2;
var hasOwn$p = hasOwnProperty_1;
var wrappedWellKnownSymbolModule$1 = wellKnownSymbolWrapped;
var defineProperty$b = objectDefineProperty.f;
// Ensures `path.Symbol[NAME]` exists, creating a placeholder `Symbol` object
// when none is present. Only `value` is given in the descriptor.
var wellKnownSymbolDefine = function wellKnownSymbolDefine(NAME) {
  var SymbolHolder = path$1.Symbol;
  if (!SymbolHolder) SymbolHolder = path$1.Symbol = {};
  if (hasOwn$p(SymbolHolder, NAME)) return;
  defineProperty$b(SymbolHolder, NAME, {
    value: wrappedWellKnownSymbolModule$1.f(NAME)
  });
};
var wellKnownSymbolDefine$1 = /*@__PURE__*/getDefaultExportFromCjs(wellKnownSymbolDefine);
3800
var call$z = functionCall;
var getBuiltIn$i = getBuiltIn$m;
var wellKnownSymbol$u = wellKnownSymbol$z;
var defineBuiltIn$k = defineBuiltIn$m;
// Installs `Symbol.prototype[@@toPrimitive]`, which returns the receiver's
// `valueOf()`. It does so only when a Symbol prototype exists and lacks one.
var symbolDefineToPrimitive = function symbolDefineToPrimitive() {
  var _Symbol = getBuiltIn$i('Symbol');
  var SymbolPrototype = _Symbol && _Symbol.prototype;
  var valueOf = SymbolPrototype && SymbolPrototype.valueOf;
  var TO_PRIMITIVE = wellKnownSymbol$u('toPrimitive');
  if (!SymbolPrototype || SymbolPrototype[TO_PRIMITIVE]) return;
  // `Symbol.prototype[@@toPrimitive]` method
  // https://tc39.es/ecma262/#sec-symbol.prototype-@@toprimitive
  // eslint-disable-next-line no-unused-vars -- required for .length
  defineBuiltIn$k(SymbolPrototype, TO_PRIMITIVE, function (hint) {
    return call$z(valueOf, this);
  }, {
    arity: 1
  });
};
var symbolDefineToPrimitive$1 = /*@__PURE__*/getDefaultExportFromCjs(symbolDefineToPrimitive);
3822
var defineProperty$a = objectDefineProperty.f;
var hasOwn$o = hasOwnProperty_1;
var wellKnownSymbol$t = wellKnownSymbol$z;
var TO_STRING_TAG$3 = wellKnownSymbol$t('toStringTag');
// Sets a configurable `@@toStringTag` of TAG on `target.prototype`, or on
// `target` itself when STATIC is truthy. Skipped when the holder is falsy or
// already owns a tag.
var setToStringTag$d = function setToStringTag(target, TAG, STATIC) {
  var holder = target && !STATIC ? target.prototype : target;
  if (!holder || hasOwn$o(holder, TO_STRING_TAG$3)) return;
  defineProperty$a(holder, TO_STRING_TAG$3, {
    configurable: true,
    value: TAG
  });
};
var setToStringTag$e = /*@__PURE__*/getDefaultExportFromCjs(setToStringTag$d);
3837
var classofRaw = classofRaw$2;
var uncurryThis$1a = functionUncurryThis;
// Uncurries `fn` only when its raw class tag is exactly 'Function';
// returns undefined for anything else.
var functionUncurryThisClause = function functionUncurryThisClause(fn) {
  // Nashorn bug:
  // https://github.com/zloirock/core-js/issues/1128
  // https://github.com/zloirock/core-js/issues/1130
  if (classofRaw(fn) !== 'Function') return undefined;
  return uncurryThis$1a(fn);
};
var functionUncurryThisClause$1 = /*@__PURE__*/getDefaultExportFromCjs(functionUncurryThisClause);
3847
var uncurryThis$19 = functionUncurryThisClause;
var aCallable$j = aCallable$l;
var NATIVE_BIND$2 = functionBindNative;
var bind$e = uncurryThis$19(uncurryThis$19.bind);

// optional / simple context binding:
// validates `fn`; returns it unchanged when `that` is undefined, otherwise a
// version bound to `that` (native bind when available, else an apply wrapper).
var functionBindContext = function functionBindContext(fn, that) {
  aCallable$j(fn);
  if (that === undefined) return fn;
  if (NATIVE_BIND$2) return bind$e(fn, that);
  return function (/* ...args */) {
    return fn.apply(that, arguments);
  };
};
var functionBindContext$1 = /*@__PURE__*/getDefaultExportFromCjs(functionBindContext);
3862
var classof$j = classofRaw$2;

// `IsArray` abstract operation
// https://tc39.es/ecma262/#sec-isarray
// Native `Array.isArray` when present; otherwise compare the raw class tag.
// eslint-disable-next-line es/no-array-isarray -- safe
var isArray$9 = Array.isArray || function isArray(argument) {
  var tag = classof$j(argument);
  return tag == 'Array';
};
var isArray$a = /*@__PURE__*/getDefaultExportFromCjs(isArray$9);
3872
var uncurryThis$18 = functionUncurryThis;
var fails$1d = fails$1m;
var isCallable$n = isCallable$z;
var classof$i = classof$m;
var getBuiltIn$h = getBuiltIn$m;
var inspectSource$1 = inspectSource$3;
var noop = function noop() {/* empty */};
var empty = [];
var construct$1 = getBuiltIn$h('Reflect', 'construct');
var constructorRegExp = /^\s*(?:class|function)\b/;
var exec$a = uncurryThis$18(constructorRegExp.exec);
// True when the stringified `noop` does not match constructorRegExp, i.e. the
// source-text heuristic below cannot be trusted in this engine.
var INCORRECT_TO_STRING$2 = !constructorRegExp.exec(noop);

// Modern check: `Reflect.construct` throws when `newTarget` is not a constructor.
var isConstructorModern = function isConstructor(argument) {
  if (!isCallable$n(argument)) return false;
  try {
    construct$1(noop, empty, argument);
  } catch (error) {
    return false;
  }
  return true;
};

// Legacy heuristic based on the function's class tag and source text.
var isConstructorLegacy = function isConstructor(argument) {
  if (!isCallable$n(argument)) return false;
  var kind = classof$i(argument);
  if (kind === 'AsyncFunction' || kind === 'GeneratorFunction' || kind === 'AsyncGeneratorFunction') {
    return false;
  }
  try {
    // we can't check .prototype since constructors produced by .bind haven't it
    // `Function#toString` throws on some built-in functions in some legacy engines
    // (for example, `DOMQuad` and similar in FF41-)
    return INCORRECT_TO_STRING$2 || !!exec$a(constructorRegExp, inspectSource$1(argument));
  } catch (error) {
    return true;
  }
};
isConstructorLegacy.sham = true;

// `IsConstructor` abstract operation
// https://tc39.es/ecma262/#sec-isconstructor
// The modern check is used unless `Reflect.construct` is missing or misbehaves.
var isConstructor$6 = isConstructorModern;
if (!construct$1 || fails$1d(function () {
  var called;
  return isConstructorModern(isConstructorModern.call) || !isConstructorModern(Object) || !isConstructorModern(function () {
    called = true;
  }) || called;
})) isConstructor$6 = isConstructorLegacy;
var isConstructor$7 = /*@__PURE__*/getDefaultExportFromCjs(isConstructor$6);
3922
var isArray$8 = isArray$9;
var isConstructor$5 = isConstructor$6;
var isObject$t = isObject$z;
var wellKnownSymbol$s = wellKnownSymbol$z;
var SPECIES$6 = wellKnownSymbol$s('species');
var $Array$8 = Array;

// a part of `ArraySpeciesCreate` abstract operation
// https://tc39.es/ecma262/#sec-arrayspeciescreate
// Picks the constructor for arrays derived from `originalArray`. Non-arrays,
// a (cross-realm) Array constructor, and null/undefined species all give Array.
var arraySpeciesConstructor$1 = function arraySpeciesConstructor(originalArray) {
  if (!isArray$8(originalArray)) return $Array$8;
  var C = originalArray.constructor;
  // cross-realm fallback
  if (isConstructor$5(C) && (C === $Array$8 || isArray$8(C.prototype))) return $Array$8;
  if (isObject$t(C)) {
    C = C[SPECIES$6];
    if (C === null) return $Array$8;
  }
  return C === undefined ? $Array$8 : C;
};
var arraySpeciesConstructor$2 = /*@__PURE__*/getDefaultExportFromCjs(arraySpeciesConstructor$1);
3945
var arraySpeciesConstructor = arraySpeciesConstructor$1;

// `ArraySpeciesCreate` abstract operation
// https://tc39.es/ecma262/#sec-arrayspeciescreate
/**
 * Constructs a new array of `length` using the species constructor of
 * `originalArray`.
 *
 * @param {*} originalArray Array whose species decides the result type.
 * @param {number} length Requested length; `-0` is normalized to `+0`.
 * @returns {*} The newly constructed object.
 */
var arraySpeciesCreate$5 = function arraySpeciesCreate(originalArray, length) {
  var Constructor = arraySpeciesConstructor(originalArray);
  // `-0 === 0`, so this collapses negative zero to positive zero
  return new Constructor(length === 0 ? 0 : length);
};
var arraySpeciesCreate$6 = /*@__PURE__*/getDefaultExportFromCjs(arraySpeciesCreate$5);
3954
var bind$d = functionBindContext;
var uncurryThis$17 = functionUncurryThis;
var IndexedObject$4 = indexedObject;
var toObject$r = toObject$t;
var lengthOfArrayLike$q = lengthOfArrayLike$t;
var arraySpeciesCreate$4 = arraySpeciesCreate$5;
var push$d = uncurryThis$17([].push);

// `Array.prototype.{ forEach, map, filter, some, every, find, findIndex, filterReject }` methods implementation
/**
 * Builds one array-iteration method; `TYPE` selects which:
 * 0 forEach, 1 map, 2 filter, 3 some, 4 every, 5 find, 6 findIndex, 7 filterReject.
 *
 * The returned function takes `($this, callbackfn, that, specificCreate)`:
 * `callbackfn` is bound to `that` and called as `(value, index, O)`;
 * `specificCreate` optionally replaces `arraySpeciesCreate$4` when building the
 * result array for map / filter / filterReject.
 */
var createMethod$6 = function createMethod(TYPE) {
  var IS_MAP = TYPE == 1;
  var IS_FILTER = TYPE == 2;
  var IS_SOME = TYPE == 3;
  var IS_EVERY = TYPE == 4;
  var IS_FIND_INDEX = TYPE == 6;
  var IS_FILTER_REJECT = TYPE == 7;
  // find / findIndex visit every index; the others skip holes (`index in self`)
  var NO_HOLES = TYPE == 5 || IS_FIND_INDEX;
  return function ($this, callbackfn, that, specificCreate) {
    var O = toObject$r($this);
    var self = IndexedObject$4(O);
    var boundFunction = bind$d(callbackfn, that);
    // length is read once up front; later growth of `self` is not visited
    var length = lengthOfArrayLike$q(self);
    var index = 0;
    var create = specificCreate || arraySpeciesCreate$4;
    // map pre-sizes its result; filter-like methods start empty and push
    var target = IS_MAP ? create($this, length) : IS_FILTER || IS_FILTER_REJECT ? create($this, 0) : undefined;
    var value, result;
    for (; length > index; index++) if (NO_HOLES || index in self) {
      value = self[index];
      result = boundFunction(value, index, O);
      // TYPE 0 (forEach) ignores the callback result entirely
      if (TYPE) {
        if (IS_MAP) target[index] = result; // map
        else if (result) switch (TYPE) {
          case 3:
            return true; // some
          case 5:
            return value; // find
          case 6:
            return index; // findIndex
          case 2:
            push$d(target, value); // filter
        } else switch (TYPE) {
          case 4:
            return false; // every
          case 7:
            push$d(target, value); // filterReject
        }
      }
    }

    // no early exit: findIndex -> -1, some -> false, every -> true,
    // otherwise the collected target (undefined for forEach / find)
    return IS_FIND_INDEX ? -1 : IS_SOME || IS_EVERY ? IS_EVERY : target;
  };
};
var arrayIteration = {
  // `Array.prototype.forEach` method
  // https://tc39.es/ecma262/#sec-array.prototype.foreach
  forEach: createMethod$6(0),
  // `Array.prototype.map` method
  // https://tc39.es/ecma262/#sec-array.prototype.map
  map: createMethod$6(1),
  // `Array.prototype.filter` method
  // https://tc39.es/ecma262/#sec-array.prototype.filter
  filter: createMethod$6(2),
  // `Array.prototype.some` method
  // https://tc39.es/ecma262/#sec-array.prototype.some
  some: createMethod$6(3),
  // `Array.prototype.every` method
  // https://tc39.es/ecma262/#sec-array.prototype.every
  every: createMethod$6(4),
  // `Array.prototype.find` method
  // https://tc39.es/ecma262/#sec-array.prototype.find
  find: createMethod$6(5),
  // `Array.prototype.findIndex` method
  // https://tc39.es/ecma262/#sec-array.prototype.findIndex
  findIndex: createMethod$6(6),
  // `Array.prototype.filterReject` method
  // https://github.com/tc39/proposal-array-filtering
  filterReject: createMethod$6(7)
};
var arrayIteration$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayIteration);
4040
'use strict';
// es.symbol module: dependency aliases and shared state for the `Symbol` polyfill.
var $$2X = _export;
var global$O = global$Z;
var call$y = functionCall;
var uncurryThis$16 = functionUncurryThis;
var IS_PURE$j = isPure;
var DESCRIPTORS$B = descriptors;
var NATIVE_SYMBOL$4 = symbolConstructorDetection;
var fails$1c = fails$1m;
var hasOwn$n = hasOwnProperty_1;
var isPrototypeOf$a = objectIsPrototypeOf;
var anObject$y = anObject$D;
var toIndexedObject$d = toIndexedObject$j;
var toPropertyKey$4 = toPropertyKey$8;
var $toString$3 = toString$x;
var createPropertyDescriptor$8 = createPropertyDescriptor$c;
var nativeObjectCreate = objectCreate;
var objectKeys$3 = objectKeys$5;
var getOwnPropertyNamesModule$1 = objectGetOwnPropertyNames;
var getOwnPropertyNamesExternal = objectGetOwnPropertyNamesExternal;
var getOwnPropertySymbolsModule$2 = objectGetOwnPropertySymbols;
var getOwnPropertyDescriptorModule$5 = objectGetOwnPropertyDescriptor;
var definePropertyModule$6 = objectDefineProperty;
var definePropertiesModule = objectDefineProperties;
var propertyIsEnumerableModule$1 = objectPropertyIsEnumerable;
var defineBuiltIn$j = defineBuiltIn$m;
var defineBuiltInAccessor$g = defineBuiltInAccessor$h;
var shared$4 = sharedExports;
var sharedKey$1 = sharedKey$4;
var hiddenKeys$1 = hiddenKeys$6;
var uid$3 = uid$6;
var wellKnownSymbol$r = wellKnownSymbol$z;
var wrappedWellKnownSymbolModule = wellKnownSymbolWrapped;
var defineWellKnownSymbol$d = wellKnownSymbolDefine;
var defineSymbolToPrimitive$1 = symbolDefineToPrimitive;
var setToStringTag$c = setToStringTag$d;
var InternalStateModule$c = internalState;
var $forEach$2 = arrayIteration.forEach;
// Own-property name under which each object records the symbol keys that were
// defined non-enumerable (see `$defineProperty`).
var HIDDEN = sharedKey$1('hidden');
var SYMBOL = 'Symbol';
var PROTOTYPE$1 = 'prototype';
var setInternalState$b = InternalStateModule$c.set;
var getInternalState$9 = InternalStateModule$c.getterFor(SYMBOL);
var ObjectPrototype$5 = Object[PROTOTYPE$1];
// `$Symbol` / `SymbolPrototype$1` are reassigned when the polyfill is installed.
var $Symbol = global$O.Symbol;
var SymbolPrototype$1 = $Symbol && $Symbol[PROTOTYPE$1];
var TypeError$7 = global$O.TypeError;
var QObject = global$O.QObject;
// The `native*` helpers capture the module hooks BEFORE the polyfill section
// below may overwrite them (e.g. `definePropertyModule$6.f = $defineProperty`),
// so they keep pointing at the underlying implementations.
var nativeGetOwnPropertyDescriptor$2 = getOwnPropertyDescriptorModule$5.f;
var nativeDefineProperty$1 = definePropertyModule$6.f;
var nativeGetOwnPropertyNames = getOwnPropertyNamesExternal.f;
var nativePropertyIsEnumerable = propertyIsEnumerableModule$1.f;
var push$c = uncurryThis$16([].push);
// tag string -> polyfilled symbol object (filled by `wrap`)
var AllSymbols = shared$4('symbols');
// mirror of symbol-keyed definitions made directly on Object.prototype
var ObjectPrototypeSymbols = shared$4('op-symbols');
// its keys are later passed to `defineWellKnownSymbol$d`
var WellKnownSymbolsStore = shared$4('wks');

// Don't use setters in Qt Script, https://github.com/zloirock/core-js/issues/173
var USE_SETTER = !QObject || !QObject[PROTOTYPE$1] || !QObject[PROTOTYPE$1].findChild;
4100
// fallback for old Android, https://code.google.com/p/v8/issues/detail?id=687
// Probe: a getter that redefines its own property should make the inherited
// read return the new value (7). When it does not (and descriptors exist), use
// a wrapper that temporarily removes any same-named own property from
// Object.prototype while defining, then restores it. Presumably this keeps the
// per-tag setters the Symbol polyfill installs on Object.prototype out of the way.
var setSymbolDescriptor = DESCRIPTORS$B && fails$1c(function () {
  var base = nativeDefineProperty$1({}, 'a', {
    get: function get() {
      return nativeDefineProperty$1(this, 'a', {
        value: 7
      }).a;
    }
  });
  return nativeObjectCreate(base).a != 7;
}) ? function (O, P, Attributes) {
  var savedDescriptor = nativeGetOwnPropertyDescriptor$2(ObjectPrototype$5, P);
  if (!savedDescriptor) {
    nativeDefineProperty$1(O, P, Attributes);
    return;
  }
  delete ObjectPrototype$5[P];
  nativeDefineProperty$1(O, P, Attributes);
  // restore the removed property unless Object.prototype itself was the target
  if (O !== ObjectPrototype$5) nativeDefineProperty$1(ObjectPrototype$5, P, savedDescriptor);
} : nativeDefineProperty$1;
/**
 * Creates the polyfilled symbol object for `tag` and registers it in `AllSymbols`.
 *
 * @param {string} tag Unique internal key of the symbol.
 * @param {*} description Value reported by `Symbol.prototype.description`.
 * @returns {object} New object inheriting from `SymbolPrototype$1`.
 */
var wrap = function wrap(tag, description) {
  var symbol = nativeObjectCreate(SymbolPrototype$1);
  AllSymbols[tag] = symbol;
  setInternalState$b(symbol, {
    type: SYMBOL,
    tag: tag,
    description: description
  });
  // when DESCRIPTORS$B is falsy, `description` becomes a plain data property
  if (!DESCRIPTORS$B) symbol.description = description;
  return symbol;
};
/**
 * `Object.defineProperty` replacement aware of polyfilled (string-tag) symbols.
 *
 * Symbol-keyed properties are always physically defined non-enumerable; the
 * user-requested enumerability is tracked in `O[HIDDEN]` instead.
 *
 * @param {object} O Target object.
 * @param {*} P Property key (converted with `toPropertyKey$4`).
 * @param {object} Attributes Property descriptor.
 * @returns {*} Result of the underlying define call.
 */
var $defineProperty = function defineProperty(O, P, Attributes) {
  // definitions on Object.prototype are mirrored into `ObjectPrototypeSymbols`
  if (O === ObjectPrototype$5) $defineProperty(ObjectPrototypeSymbols, P, Attributes);
  anObject$y(O);
  var key = toPropertyKey$4(P);
  anObject$y(Attributes);
  // ordinary keys: defer to the captured native implementation
  if (!hasOwn$n(AllSymbols, key)) return nativeDefineProperty$1(O, key, Attributes);
  var descriptor = Attributes;
  if (descriptor.enumerable) {
    if (hasOwn$n(O, HIDDEN) && O[HIDDEN][key]) O[HIDDEN][key] = false;
    // inherit from the caller's descriptor but force `enumerable: false`
    descriptor = nativeObjectCreate(descriptor, {
      enumerable: createPropertyDescriptor$8(0, false)
    });
  } else {
    if (!hasOwn$n(O, HIDDEN)) nativeDefineProperty$1(O, HIDDEN, createPropertyDescriptor$8(1, {}));
    O[HIDDEN][key] = true;
  }
  return setSymbolDescriptor(O, key, descriptor);
};
/**
 * `Object.defineProperties` replacement that also walks polyfilled symbol keys.
 *
 * @param {object} O Target object.
 * @param {*} Properties Map of keys to descriptors.
 * @returns {object} `O`.
 */
var $defineProperties = function defineProperties(O, Properties) {
  anObject$y(O);
  var properties = toIndexedObject$d(Properties);
  var stringKeys = objectKeys$3(properties);
  var allKeys = stringKeys.concat($getOwnPropertySymbols(properties));
  $forEach$2(allKeys, function (key) {
    // with descriptor support, only enumerable entries of `properties` count
    var eligible = !DESCRIPTORS$B || call$y($propertyIsEnumerable$1, properties, key);
    if (eligible) $defineProperty(O, key, properties[key]);
  });
  return O;
};
/**
 * `Object.create` replacement routing the optional property map through
 * `$defineProperties`.
 *
 * @param {object|null} O Prototype of the new object.
 * @param {*} [Properties] Optional descriptor map.
 * @returns {object} The new object.
 */
var $create = function create(O, Properties) {
  var object = nativeObjectCreate(O);
  if (Properties === undefined) return object;
  return $defineProperties(object, Properties);
};
/**
 * `Object.prototype.propertyIsEnumerable` replacement that reports the
 * user-visible enumerability of polyfilled symbol keys.
 *
 * @param {*} V Property key.
 * @returns {boolean} Whether the property counts as enumerable.
 */
var $propertyIsEnumerable$1 = function propertyIsEnumerable(V) {
  var P = toPropertyKey$4(V);
  var enumerable = call$y(nativePropertyIsEnumerable, this, P);
  // symbol tags on Object.prototype that were never user-defined are internal
  if (this === ObjectPrototype$5 && hasOwn$n(AllSymbols, P) && !hasOwn$n(ObjectPrototypeSymbols, P)) return false;
  if (enumerable || !hasOwn$n(this, P) || !hasOwn$n(AllSymbols, P)) return enumerable;
  // own symbol property stored non-enumerable: consult the HIDDEN record
  if (hasOwn$n(this, HIDDEN) && this[HIDDEN][P]) return enumerable;
  return true;
};
/**
 * `Object.getOwnPropertyDescriptor` replacement for polyfilled symbols.
 *
 * @param {*} O Object to inspect.
 * @param {*} P Property key.
 * @returns {object|undefined} The descriptor, with `enumerable` restored to
 *   `true` for symbol keys not recorded in `O[HIDDEN]`.
 */
var $getOwnPropertyDescriptor = function getOwnPropertyDescriptor(O, P) {
  var it = toIndexedObject$d(O);
  var key = toPropertyKey$4(P);
  // presumably hides the per-tag setters the polyfill puts on Object.prototype
  if (it === ObjectPrototype$5 && hasOwn$n(AllSymbols, key) && !hasOwn$n(ObjectPrototypeSymbols, key)) return undefined;
  var descriptor = nativeGetOwnPropertyDescriptor$2(it, key);
  if (!descriptor || !hasOwn$n(AllSymbols, key)) return descriptor;
  if (hasOwn$n(it, HIDDEN) && it[HIDDEN][key]) return descriptor;
  descriptor.enumerable = true;
  return descriptor;
};
/**
 * `Object.getOwnPropertyNames` replacement that filters out polyfilled symbol
 * tags and internal hidden keys.
 *
 * @param {*} O Object to inspect.
 * @returns {string[]} Visible own property names.
 */
var $getOwnPropertyNames = function getOwnPropertyNames(O) {
  var ownNames = nativeGetOwnPropertyNames(toIndexedObject$d(O));
  var visible = [];
  $forEach$2(ownNames, function (name) {
    var isInternal = hasOwn$n(AllSymbols, name) || hasOwn$n(hiddenKeys$1, name);
    if (!isInternal) push$c(visible, name);
  });
  return visible;
};
/**
 * `Object.getOwnPropertySymbols` replacement returning polyfilled symbols.
 *
 * For Object.prototype, the mirror store `ObjectPrototypeSymbols` is scanned
 * instead, keeping only tags still present on Object.prototype.
 *
 * @param {*} O Object to inspect.
 * @returns {object[]} Symbol objects whose tags are own keys of `O`.
 */
var $getOwnPropertySymbols = function $getOwnPropertySymbols(O) {
  var onObjectPrototype = O === ObjectPrototype$5;
  var source = onObjectPrototype ? ObjectPrototypeSymbols : toIndexedObject$d(O);
  var symbols = [];
  $forEach$2(nativeGetOwnPropertyNames(source), function (key) {
    if (!hasOwn$n(AllSymbols, key)) return;
    if (onObjectPrototype && !hasOwn$n(ObjectPrototype$5, key)) return;
    push$c(symbols, AllSymbols[key]);
  });
  return symbols;
};
4194
// `Symbol` constructor
// https://tc39.es/ecma262/#sec-symbol-constructor
// Installed only when the engine lacks a usable native Symbol: builds the
// constructor, then swaps the shared object-model module hooks (`.f`) for the
// symbol-aware `$` variants defined above.
if (!NATIVE_SYMBOL$4) {
  $Symbol = function _Symbol() {
    // calling as a constructor (`new Symbol()`) is rejected
    if (isPrototypeOf$a(SymbolPrototype$1, this)) throw TypeError$7('Symbol is not a constructor');
    var description = !arguments.length || arguments[0] === undefined ? undefined : $toString$3(arguments[0]);
    // each symbol is represented internally by a unique string tag
    var tag = uid$3(description);
    // setter so that plain assignment `obj[symbol] = value` defines the tag
    // property through `setSymbolDescriptor`
    var setter = function setter(value) {
      // assignments on Object.prototype itself are mirrored into ObjectPrototypeSymbols
      if (this === ObjectPrototype$5) call$y(setter, ObjectPrototypeSymbols, value);
      if (hasOwn$n(this, HIDDEN) && hasOwn$n(this[HIDDEN], tag)) this[HIDDEN][tag] = false;
      setSymbolDescriptor(this, tag, createPropertyDescriptor$8(1, value));
    };
    if (DESCRIPTORS$B && USE_SETTER) setSymbolDescriptor(ObjectPrototype$5, tag, {
      configurable: true,
      set: setter
    });
    return wrap(tag, description);
  };
  SymbolPrototype$1 = $Symbol[PROTOTYPE$1];
  // String(symbol) / property-key conversion yields the internal tag
  defineBuiltIn$j(SymbolPrototype$1, 'toString', function toString() {
    return getInternalState$9(this).tag;
  });
  // creates a symbol without installing the Object.prototype setter
  defineBuiltIn$j($Symbol, 'withoutSetter', function (description) {
    return wrap(uid$3(description), description);
  });
  // route the shared object-model helpers through the symbol-aware versions
  propertyIsEnumerableModule$1.f = $propertyIsEnumerable$1;
  definePropertyModule$6.f = $defineProperty;
  definePropertiesModule.f = $defineProperties;
  getOwnPropertyDescriptorModule$5.f = $getOwnPropertyDescriptor;
  getOwnPropertyNamesModule$1.f = getOwnPropertyNamesExternal.f = $getOwnPropertyNames;
  getOwnPropertySymbolsModule$2.f = $getOwnPropertySymbols;
  wrappedWellKnownSymbolModule.f = function (name) {
    return wrap(wellKnownSymbol$r(name), name);
  };
  if (DESCRIPTORS$B) {
    // https://github.com/tc39/proposal-Symbol-description
    defineBuiltInAccessor$g(SymbolPrototype$1, 'description', {
      configurable: true,
      get: function description() {
        return getInternalState$9(this).description;
      }
    });
    // the global Object.prototype is only patched in the non-pure build
    if (!IS_PURE$j) {
      defineBuiltIn$j(ObjectPrototype$5, 'propertyIsEnumerable', $propertyIsEnumerable$1, {
        unsafe: true
      });
    }
  }
}
// Export the (possibly polyfilled) `Symbol` global.
$$2X({
  global: true,
  constructor: true,
  wrap: true,
  forced: !NATIVE_SYMBOL$4,
  sham: !NATIVE_SYMBOL$4
}, {
  Symbol: $Symbol
});
// define every well-known symbol whose name is recorded in the shared store
$forEach$2(objectKeys$3(WellKnownSymbolsStore), function (name) {
  defineWellKnownSymbol$d(name);
});
// `Symbol.useSetter()` / `Symbol.useSimple()` toggle the setter-based mode
// read by the polyfilled constructor (USE_SETTER).
$$2X({
  target: SYMBOL,
  stat: true,
  forced: !NATIVE_SYMBOL$4
}, {
  useSetter: function useSetter() {
    USE_SETTER = true;
  },
  useSimple: function useSimple() {
    USE_SETTER = false;
  }
});
$$2X({
  target: 'Object',
  stat: true,
  forced: !NATIVE_SYMBOL$4,
  sham: !DESCRIPTORS$B
}, {
  // `Object.create` method
  // https://tc39.es/ecma262/#sec-object.create
  create: $create,
  // `Object.defineProperty` method
  // https://tc39.es/ecma262/#sec-object.defineproperty
  defineProperty: $defineProperty,
  // `Object.defineProperties` method
  // https://tc39.es/ecma262/#sec-object.defineproperties
  defineProperties: $defineProperties,
  // `Object.getOwnPropertyDescriptor` method
  // https://tc39.es/ecma262/#sec-object.getownpropertydescriptor
  getOwnPropertyDescriptor: $getOwnPropertyDescriptor
});
$$2X({
  target: 'Object',
  stat: true,
  forced: !NATIVE_SYMBOL$4
}, {
  // `Object.getOwnPropertyNames` method
  // https://tc39.es/ecma262/#sec-object.getownpropertynames
  getOwnPropertyNames: $getOwnPropertyNames
});

// `Symbol.prototype[@@toPrimitive]` method
// https://tc39.es/ecma262/#sec-symbol.prototype-@@toprimitive
defineSymbolToPrimitive$1();

// `Symbol.prototype[@@toStringTag]` property
// https://tc39.es/ecma262/#sec-symbol.prototype-@@tostringtag
setToStringTag$c($Symbol, SYMBOL);
// mark the HIDDEN bookkeeping key as internal so it is filtered from key listings
hiddenKeys$1[HIDDEN] = true;
4305
var es_symbol_for = {};

var NATIVE_SYMBOL$3 = symbolConstructorDetection;

/* eslint-disable es/no-symbol -- safe */
// Truthy only when native symbols exist AND both `Symbol.for` and
// `Symbol.keyFor` are present; bracket access avoids the reserved word `for`.
var symbolRegistryDetection = NATIVE_SYMBOL$3 && !!Symbol['for'] && !!Symbol.keyFor;
var symbolRegistryDetection$1 = /*@__PURE__*/getDefaultExportFromCjs(symbolRegistryDetection);
4313
var $$2W = _export;
var getBuiltIn$g = getBuiltIn$m;
var hasOwn$m = hasOwnProperty_1;
var toString$w = toString$x;
var shared$3 = sharedExports;
var NATIVE_SYMBOL_REGISTRY$1 = symbolRegistryDetection;
var StringToSymbolRegistry = shared$3('string-to-symbol-registry');
var SymbolToStringRegistry$1 = shared$3('symbol-to-string-registry');

// `Symbol.for` method
// https://tc39.es/ecma262/#sec-symbol.for
$$2W({
  target: 'Symbol',
  stat: true,
  forced: !NATIVE_SYMBOL_REGISTRY$1
}, {
  /**
   * Returns the registry symbol for `key`, creating it on first request.
   * NOTE(review): if the registries are ordinary objects, a key of
   * '__proto__' hits the inherited accessor rather than an own property —
   * confirm whether that case matters for supported engines.
   */
  'for': function _for(key) {
    var string = toString$w(key);
    var alreadyRegistered = hasOwn$m(StringToSymbolRegistry, string);
    if (alreadyRegistered) return StringToSymbolRegistry[string];
    // first request: create the symbol and record both lookup directions
    var created = getBuiltIn$g('Symbol')(string);
    StringToSymbolRegistry[string] = created;
    SymbolToStringRegistry$1[created] = string;
    return created;
  }
});
4339
var es_symbol_keyFor = {};

var $$2V = _export;
var hasOwn$l = hasOwnProperty_1;
var isSymbol$4 = isSymbol$7;
var tryToString$5 = tryToString$7;
var shared$2 = sharedExports;
var NATIVE_SYMBOL_REGISTRY = symbolRegistryDetection;
var SymbolToStringRegistry = shared$2('symbol-to-string-registry');

// `Symbol.keyFor` method
// https://tc39.es/ecma262/#sec-symbol.keyfor
$$2V({
  target: 'Symbol',
  stat: true,
  forced: !NATIVE_SYMBOL_REGISTRY
}, {
  /**
   * Looks up the registry key of `sym`.
   * @throws {TypeError} when `sym` is not a symbol.
   * @returns {string|undefined} The key, or `undefined` if `sym` is unregistered.
   */
  keyFor: function keyFor(sym) {
    if (!isSymbol$4(sym)) throw TypeError(tryToString$5(sym) + ' is not a symbol');
    // the same shared store is written by the `Symbol.for` polyfill
    if (!hasOwn$l(SymbolToStringRegistry, sym)) return undefined;
    return SymbolToStringRegistry[sym];
  }
});
4362
var es_json_stringify = {};

var NATIVE_BIND$1 = functionBindNative;
var FunctionPrototype$2 = Function.prototype;
var apply$b = FunctionPrototype$2.apply;
var call$x = FunctionPrototype$2.call;

/**
 * Uncurried `apply`: `functionApply$1(fn, thisArg, args)` behaves like
 * `fn.apply(thisArg, args)`. Uses `Reflect.apply` when available; otherwise
 * binds `call` to `apply` (native bind) or forwards through `call.apply`.
 */
var functionApply$1 = function () {
  // eslint-disable-next-line es/no-reflect -- safe
  var reflectIsObject = (typeof Reflect === "undefined" ? "undefined" : _typeof(Reflect)) == 'object';
  // eslint-disable-next-line es/no-reflect -- safe
  if (reflectIsObject && Reflect.apply) return Reflect.apply;
  if (NATIVE_BIND$1) return call$x.bind(apply$b);
  return function () {
    return call$x.apply(apply$b, arguments);
  };
}();
var functionApply$2 = /*@__PURE__*/getDefaultExportFromCjs(functionApply$1);
4375
var uncurryThis$15 = functionUncurryThis;
// Uncurried `Array.prototype.slice`: `arraySlice$a(arrayLike, start, end)`.
var arraySlice$a = uncurryThis$15([].slice);
var arraySlice$b = /*@__PURE__*/getDefaultExportFromCjs(arraySlice$a);
4379
var uncurryThis$14 = functionUncurryThis;
var isArray$7 = isArray$9;
var isCallable$m = isCallable$z;
var classof$h = classofRaw$2;
var toString$v = toString$x;
var push$b = uncurryThis$14([].push);

/**
 * Normalizes the `replacer` argument of `JSON.stringify` into a function.
 *
 * @param {*} replacer Function, property-list array, or anything else.
 * @returns {Function|undefined} The replacer itself if callable; for arrays, a
 *   filter keeping only listed keys (strings, numbers and, presumably, boxed
 *   Number/String values by class); otherwise `undefined`.
 */
var getJsonReplacerFunction = function getJsonReplacerFunction(replacer) {
  if (isCallable$m(replacer)) return replacer;
  if (!isArray$7(replacer)) return;
  var rawLength = replacer.length;
  var allowedKeys = [];
  var i = 0;
  while (i < rawLength) {
    var element = replacer[i++];
    if (typeof element == 'string') {
      push$b(allowedKeys, element);
    } else if (typeof element == 'number' || classof$h(element) == 'Number' || classof$h(element) == 'String') {
      push$b(allowedKeys, toString$v(element));
    }
  }
  var allowedCount = allowedKeys.length;
  var atRoot = true;
  return function (key, value) {
    // the very first invocation (the root value) is never filtered
    if (atRoot) {
      atRoot = false;
      return value;
    }
    // elements of arrays are kept regardless of the key list
    if (isArray$7(this)) return value;
    for (var j = 0; j < allowedCount; j++) {
      if (allowedKeys[j] === key) return value;
    }
  };
};
var getJsonReplacerFunction$1 = /*@__PURE__*/getDefaultExportFromCjs(getJsonReplacerFunction);
4407
// es.json.stringify module: dependency aliases and helpers.
var $$2U = _export;
var getBuiltIn$f = getBuiltIn$m;
var apply$a = functionApply$1;
var call$w = functionCall;
var uncurryThis$13 = functionUncurryThis;
var fails$1b = fails$1m;
var isCallable$l = isCallable$z;
var isSymbol$3 = isSymbol$7;
var arraySlice$9 = arraySlice$a;
var getReplacerFunction = getJsonReplacerFunction;
var NATIVE_SYMBOL$2 = symbolConstructorDetection;
var $String$3 = String;
// original engine implementation (may be falsy if JSON is missing)
var $stringify = getBuiltIn$f('JSON', 'stringify');
var exec$9 = uncurryThis$13(/./.exec);
var charAt$e = uncurryThis$13(''.charAt);
var charCodeAt$5 = uncurryThis$13(''.charCodeAt);
var replace$c = uncurryThis$13(''.replace);
var numberToString$2 = uncurryThis$13(1.0.toString);
// any UTF-16 surrogate code unit (global: used with replace)
var tester = /[\uD800-\uDFFF]/g;
// NOTE: despite the names, `low` matches a LEAD (high) surrogate D800-DBFF and
// `hi` matches a TRAIL (low) surrogate DC00-DFFF.
var low = /^[\uD800-\uDBFF]$/;
var hi = /^[\uDC00-\uDFFF]$/;
// True when native symbols are missing or the engine serializes symbols
// incorrectly in any of the three probed positions.
var WRONG_SYMBOLS_CONVERSION = !NATIVE_SYMBOL$2 || fails$1b(function () {
  var symbol = getBuiltIn$f('Symbol')();
  // MS Edge converts symbol values to JSON as {}
  if ($stringify([symbol]) != '[null]') return true;
  // WebKit converts symbol values to JSON as null
  if ($stringify({
    a: symbol
  }) != '{}') return true;
  // V8 throws on boxed symbols
  return $stringify(Object(symbol)) != '{}';
});
4440
// https://github.com/tc39/proposal-well-formed-stringify
// True when lone surrogates are not emitted as lowercase `\uXXXX` escapes.
var ILL_FORMED_UNICODE = fails$1b(function () {
  if ($stringify("\uDF06\uD834") !== "\"\\udf06\\ud834\"") return true;
  return $stringify("\uDEAD") !== "\"\\udead\"";
});
/**
 * Calls the native `JSON.stringify` with a replacer that drops symbol values,
 * composing it with the caller's replacer when one is given.
 *
 * @param {*} it Value to serialize.
 * @param {*} replacer Original replacer argument.
 * @returns {string|undefined} Serialized JSON.
 */
var stringifyWithSymbolsFix = function stringifyWithSymbolsFix(it, replacer) {
  var args = arraySlice$9(arguments);
  var $replacer = getReplacerFunction(replacer);
  var hasReplacerFn = isCallable$l($replacer);
  if (!hasReplacerFn && (it === undefined || isSymbol$3(it))) return; // IE8 returns string on undefined
  args[1] = function (key, value) {
    // some old implementations (like WebKit) could pass numbers as keys
    var result = hasReplacerFn ? call$w($replacer, this, $String$3(key), value) : value;
    if (!isSymbol$3(result)) return result;
  };
  return apply$a($stringify, null, args);
};
/**
 * `String.prototype.replace` callback: escapes an unpaired surrogate as
 * `\uXXXX` (lowercase hex) and leaves correctly paired ones unchanged.
 *
 * @param {string} match Single surrogate code unit.
 * @param {number} offset Position of `match` in `string`.
 * @param {string} string Whole JSON text.
 * @returns {string} Replacement text.
 */
var fixIllFormed = function fixIllFormed(match, offset, string) {
  var prev = charAt$e(string, offset - 1);
  var next = charAt$e(string, offset + 1);
  // lead surrogate not followed by a trail one
  var isLoneLead = exec$9(low, match) && !exec$9(hi, next);
  // trail surrogate not preceded by a lead one
  var isLoneTrail = !isLoneLead && exec$9(hi, match) && !exec$9(low, prev);
  if (!isLoneLead && !isLoneTrail) return match;
  return "\\u" + numberToString$2(charCodeAt$5(match, 0), 16);
};
if ($stringify) {
  // `JSON.stringify` method
  // https://tc39.es/ecma262/#sec-json.stringify
  $$2U({
    target: 'JSON',
    stat: true,
    arity: 3,
    forced: WRONG_SYMBOLS_CONVERSION || ILL_FORMED_UNICODE
  }, {
    // eslint-disable-next-line no-unused-vars -- required for `.length`
    stringify: function stringify(it, replacer, space) {
      var args = arraySlice$9(arguments);
      var implementation = WRONG_SYMBOLS_CONVERSION ? stringifyWithSymbolsFix : $stringify;
      var result = apply$a(implementation, null, args);
      // escape lone surrogates the engine emitted raw
      if (ILL_FORMED_UNICODE && typeof result == 'string') return replace$c(result, tester, fixIllFormed);
      return result;
    }
  });
}
4481
var es_object_getOwnPropertySymbols = {};

var $$2T = _export;
var NATIVE_SYMBOL$1 = symbolConstructorDetection;
var fails$1a = fails$1m;
var getOwnPropertySymbolsModule$1 = objectGetOwnPropertySymbols;
var toObject$q = toObject$t;

// V8 ~ Chrome 38 and 39 `Object.getOwnPropertySymbols` fails on primitives
// https://bugs.chromium.org/p/v8/issues/detail?id=3443
var FORCED$D = !NATIVE_SYMBOL$1 || fails$1a(function () {
  getOwnPropertySymbolsModule$1.f(1);
});

// `Object.getOwnPropertySymbols` method
// https://tc39.es/ecma262/#sec-object.getownpropertysymbols
$$2T({
  target: 'Object',
  stat: true,
  forced: FORCED$D
}, {
  getOwnPropertySymbols: function getOwnPropertySymbols(it) {
    // read the hook at call time (presumably it may be replaced by the Symbol polyfill)
    var implementation = getOwnPropertySymbolsModule$1.f;
    if (!implementation) return [];
    // coerce primitives to objects before delegating
    return implementation(toObject$q(it));
  }
});
4508
var es_symbol_description = {};

// `Symbol.prototype.description` getter
// https://tc39.es/ecma262/#sec-symbol.prototype.description
// Patches engines whose native Symbol lacks `description`, or (Safari 12)
// reports a non-undefined description for `Symbol()`.
'use strict';
var $$2S = _export;
var DESCRIPTORS$A = descriptors;
var global$N = global$Z;
var uncurryThis$12 = functionUncurryThis;
var hasOwn$k = hasOwnProperty_1;
var isCallable$k = isCallable$z;
var isPrototypeOf$9 = objectIsPrototypeOf;
var toString$u = toString$x;
var defineBuiltInAccessor$f = defineBuiltInAccessor$h;
var copyConstructorProperties$3 = copyConstructorProperties$5;
var NativeSymbol = global$N.Symbol;
var SymbolPrototype = NativeSymbol && NativeSymbol.prototype;
if (DESCRIPTORS$A && isCallable$k(NativeSymbol) && (!('description' in SymbolPrototype) ||
// Safari 12 bug
NativeSymbol().description !== undefined)) {
  // symbols created with description '' (so the getter can tell '' from undefined)
  var EmptyStringDescriptionStore = {};
  // wrap Symbol constructor for correct work with undefined description
  var SymbolWrapper = function _Symbol() {
    var description = arguments.length < 1 || arguments[0] === undefined ? undefined : toString$u(arguments[0]);
    // `new Symbol()` is delegated to `new NativeSymbol(...)`, presumably so the
    // engine raises its own TypeError
    var result = isPrototypeOf$9(SymbolPrototype, this) ? new NativeSymbol(description)
    // in Edge 13, String(Symbol(undefined)) === 'Symbol(undefined)'
    : description === undefined ? NativeSymbol() : NativeSymbol(description);
    if (description === '') EmptyStringDescriptionStore[result] = true;
    return result;
  };
  copyConstructorProperties$3(SymbolWrapper, NativeSymbol);
  SymbolWrapper.prototype = SymbolPrototype;
  SymbolPrototype.constructor = SymbolWrapper;
  // whether String(symbol) has the plain `Symbol(<desc>)` form
  var NATIVE_SYMBOL = String(NativeSymbol('test')) == 'Symbol(test)';
  var thisSymbolValue = uncurryThis$12(SymbolPrototype.valueOf);
  var symbolDescriptiveString = uncurryThis$12(SymbolPrototype.toString);
  // fallback parser for descriptive strings with a trailing suffix after `)`
  var regexp = /^Symbol\((.*)\)[^)]+$/;
  var replace$b = uncurryThis$12(''.replace);
  var stringSlice$g = uncurryThis$12(''.slice);
  defineBuiltInAccessor$f(SymbolPrototype, 'description', {
    configurable: true,
    get: function description() {
      var symbol = thisSymbolValue(this);
      if (hasOwn$k(EmptyStringDescriptionStore, symbol)) return '';
      var string = symbolDescriptiveString(symbol);
      // strip the leading 'Symbol(' (7 chars) and trailing ')'
      var desc = NATIVE_SYMBOL ? stringSlice$g(string, 7, -1) : replace$b(string, regexp, '$1');
      // an empty string here means "no description" (real '' is handled above)
      return desc === '' ? undefined : desc;
    }
  });
  $$2S({
    global: true,
    constructor: true,
    forced: true
  }, {
    Symbol: SymbolWrapper
  });
}
4566
// Well-known symbol modules: each one calls `wellKnownSymbolDefine` with the
// symbol's name; the toPrimitive and toStringTag modules additionally install
// the corresponding `Symbol.prototype` members.
var es_symbol_asyncIterator = {};

var defineWellKnownSymbol$c = wellKnownSymbolDefine;

// `Symbol.asyncIterator` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.asynciterator
defineWellKnownSymbol$c('asyncIterator');

var es_symbol_hasInstance = {};

var defineWellKnownSymbol$b = wellKnownSymbolDefine;

// `Symbol.hasInstance` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.hasinstance
defineWellKnownSymbol$b('hasInstance');

var es_symbol_isConcatSpreadable = {};

var defineWellKnownSymbol$a = wellKnownSymbolDefine;

// `Symbol.isConcatSpreadable` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.isconcatspreadable
defineWellKnownSymbol$a('isConcatSpreadable');

var es_symbol_iterator = {};

var defineWellKnownSymbol$9 = wellKnownSymbolDefine;

// `Symbol.iterator` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.iterator
defineWellKnownSymbol$9('iterator');

var es_symbol_match = {};

var defineWellKnownSymbol$8 = wellKnownSymbolDefine;

// `Symbol.match` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.match
defineWellKnownSymbol$8('match');

var es_symbol_matchAll = {};

var defineWellKnownSymbol$7 = wellKnownSymbolDefine;

// `Symbol.matchAll` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.matchall
defineWellKnownSymbol$7('matchAll');

var es_symbol_replace = {};

var defineWellKnownSymbol$6 = wellKnownSymbolDefine;

// `Symbol.replace` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.replace
defineWellKnownSymbol$6('replace');

var es_symbol_search = {};

var defineWellKnownSymbol$5 = wellKnownSymbolDefine;

// `Symbol.search` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.search
defineWellKnownSymbol$5('search');

var es_symbol_species = {};

var defineWellKnownSymbol$4 = wellKnownSymbolDefine;

// `Symbol.species` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.species
defineWellKnownSymbol$4('species');

var es_symbol_split = {};

var defineWellKnownSymbol$3 = wellKnownSymbolDefine;

// `Symbol.split` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.split
defineWellKnownSymbol$3('split');

var es_symbol_toPrimitive = {};

var defineWellKnownSymbol$2 = wellKnownSymbolDefine;
var defineSymbolToPrimitive = symbolDefineToPrimitive;

// `Symbol.toPrimitive` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.toprimitive
defineWellKnownSymbol$2('toPrimitive');

// `Symbol.prototype[@@toPrimitive]` method
// https://tc39.es/ecma262/#sec-symbol.prototype-@@toprimitive
defineSymbolToPrimitive();

var es_symbol_toStringTag = {};

var getBuiltIn$e = getBuiltIn$m;
var defineWellKnownSymbol$1 = wellKnownSymbolDefine;
var setToStringTag$b = setToStringTag$d;

// `Symbol.toStringTag` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.tostringtag
defineWellKnownSymbol$1('toStringTag');

// `Symbol.prototype[@@toStringTag]` property
// https://tc39.es/ecma262/#sec-symbol.prototype-@@tostringtag
setToStringTag$b(getBuiltIn$e('Symbol'), 'Symbol');

var es_symbol_unscopables = {};

var defineWellKnownSymbol = wellKnownSymbolDefine;

// `Symbol.unscopables` well-known symbol
// https://tc39.es/ecma262/#sec-symbol.unscopables
defineWellKnownSymbol('unscopables');
4681
var es_error_cause = {};

var uncurryThis$11 = functionUncurryThis;
var aCallable$i = aCallable$l;
/**
 * Returns the uncurried `get`/`set` accessor of `object[key]`.
 *
 * @param {object} object Owner of the accessor property.
 * @param {PropertyKey} key Property name.
 * @param {string} method 'get' or 'set'.
 * @returns {Function|undefined} `undefined` when the descriptor is missing or
 *   the accessor is not callable (deliberate best effort).
 */
var functionUncurryThisAccessor = function functionUncurryThisAccessor(object, key, method) {
  try {
    // eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
    var descriptor = Object.getOwnPropertyDescriptor(object, key);
    var accessor = descriptor[method];
    return uncurryThis$11(aCallable$i(accessor));
  } catch (error) {
    // fall through and return undefined
  }
};
var functionUncurryThisAccessor$1 = /*@__PURE__*/getDefaultExportFromCjs(functionUncurryThisAccessor);
4693
var isCallable$j = isCallable$z;
var $String$2 = String;
var $TypeError$j = TypeError;
/**
 * Validates a prototype candidate for `setPrototypeOf`.
 *
 * @param {*} argument Candidate prototype; accepted when `_typeof` reports
 *   'object' (presumably including `null`) or it is callable.
 * @returns {*} `argument` unchanged.
 * @throws {TypeError} for any other value.
 */
var aPossiblePrototype$2 = function aPossiblePrototype(argument) {
  var isAcceptable = _typeof(argument) == 'object' || isCallable$j(argument);
  if (!isAcceptable) throw $TypeError$j("Can't set " + $String$2(argument) + ' as a prototype');
  return argument;
};
var aPossiblePrototype$3 = /*@__PURE__*/getDefaultExportFromCjs(aPossiblePrototype$2);
4702
/* eslint-disable no-proto -- safe */
var uncurryThisAccessor = functionUncurryThisAccessor;
var anObject$x = anObject$D;
var aPossiblePrototype$1 = aPossiblePrototype$2;

// `Object.setPrototypeOf` method
// https://tc39.es/ecma262/#sec-object.setprototypeof
// Works with __proto__ only. Old v8 can't work with null proto objects.
// Native when available; otherwise a `__proto__`-based fallback, or undefined
// when the engine has no `__proto__` at all.
// eslint-disable-next-line es/no-object-setprototypeof -- safe
var objectSetPrototypeOf$1 = Object.setPrototypeOf || ('__proto__' in {} ? function () {
  var protoSetter;
  var setterWorks = false;
  try {
    protoSetter = uncurryThisAccessor(Object.prototype, '__proto__', 'set');
    // probe: does the extracted setter really change the prototype?
    var probe = {};
    protoSetter(probe, []);
    setterWorks = probe instanceof Array;
  } catch (error) {/* empty */}
  return function setPrototypeOf(O, proto) {
    anObject$x(O);
    aPossiblePrototype$1(proto);
    if (setterWorks) {
      protoSetter(O, proto);
    } else {
      O.__proto__ = proto;
    }
    return O;
  };
}() : undefined);
var objectSetPrototypeOf$2 = /*@__PURE__*/getDefaultExportFromCjs(objectSetPrototypeOf$1);
4729
var defineProperty$9 = objectDefineProperty.f;
/**
 * Makes `Target[key]` a configurable accessor forwarding reads and writes to
 * `Source[key]`; does nothing when `key` is already present (own or inherited)
 * on `Target`.
 *
 * @param {object} Target Object receiving the accessor.
 * @param {object} Source Object backing the value.
 * @param {PropertyKey} key Property to proxy.
 */
var proxyAccessor$2 = function proxyAccessor(Target, Source, key) {
  if (key in Target) return;
  defineProperty$9(Target, key, {
    configurable: true,
    get: function get() {
      return Source[key];
    },
    set: function set(it) {
      Source[key] = it;
    }
  });
};
var proxyAccessor$3 = /*@__PURE__*/getDefaultExportFromCjs(proxyAccessor$2);
4743
var isCallable$i = isCallable$z;
var isObject$s = isObject$z;
var setPrototypeOf$9 = objectSetPrototypeOf$1;

// makes subclassing work correct for wrapped built-ins
/**
 * Re-parents `$this` onto the prototype of `dummy.constructor` when that
 * constructor is a subclass of `Wrapper` (an approximation of `new.target`).
 *
 * @param {object} $this Freshly created built-in instance.
 * @param {object} dummy The `this` of the wrapper call.
 * @param {Function} Wrapper The wrapping constructor.
 * @returns {object} `$this`.
 */
var inheritIfRequired$6 = function inheritIfRequired($this, dummy, Wrapper) {
  // nothing to do when no `setPrototypeOf` implementation is available
  if (!setPrototypeOf$9) return $this;
  // we haven't completely correct pre-ES6 way for getting `new.target`, so use this
  var NewTarget = dummy.constructor;
  if (!isCallable$i(NewTarget) || NewTarget === Wrapper) return $this;
  var NewTargetPrototype = NewTarget.prototype;
  if (isObject$s(NewTargetPrototype) && NewTargetPrototype !== Wrapper.prototype) {
    setPrototypeOf$9($this, NewTargetPrototype);
  }
  return $this;
};
var inheritIfRequired$7 = /*@__PURE__*/getDefaultExportFromCjs(inheritIfRequired$6);
4759
var toString$t = toString$x;

// Normalises an optional string argument. A defined value is coerced with `toString`;
// `undefined` becomes `$default` when a second argument was passed (even `undefined`),
// and `''` otherwise.
var normalizeStringArgument$5 = function normalizeStringArgument(argument, $default) {
  if (argument !== undefined) return toString$t(argument);
  return arguments.length < 2 ? '' : $default;
};
var normalizeStringArgument$6 = /*@__PURE__*/getDefaultExportFromCjs(normalizeStringArgument$5);
4765
var isObject$r = isObject$z;
var createNonEnumerableProperty$c = createNonEnumerableProperty$f;

// `InstallErrorCause` abstract operation
// https://tc39.es/proposal-error-cause/#sec-errorobjects-install-error-cause
// Copies `options.cause` onto `O` as a non-enumerable property, but only when `options` is
// an object that has a `cause` key (own or inherited; its value may be `undefined`).
var installErrorCause$2 = function installErrorCause(O, options) {
  var hasCause = isObject$r(options) && 'cause' in options;
  if (!hasCause) return;
  createNonEnumerableProperty$c(O, 'cause', options.cause);
};
var installErrorCause$3 = /*@__PURE__*/getDefaultExportFromCjs(installErrorCause$2);
4777
var uncurryThis$10 = functionUncurryThis;
var $Error$1 = Error;
var replace$a = uncurryThis$10(''.replace);
// Sample stack used to detect V8/Chakra-style "\n    at ..." frame lines.
var TEST = function (arg) {
  return String($Error$1(arg).stack);
}('zxcasd');
// eslint-disable-next-line redos/no-vulnerable -- safe
var V8_OR_CHAKRA_STACK_ENTRY = /\n\s*at [^:]*:[^\n]*/;
var IS_V8_OR_CHAKRA_STACK = V8_OR_CHAKRA_STACK_ENTRY.test(TEST);

// Removes the first `dropEntries` frames from a V8/Chakra-formatted `stack` string.
// Non-string stacks, other stack formats, and stacks customised through
// `Error.prepareStackTrace` are returned untouched.
var errorStackClear = function errorStackClear(stack, dropEntries) {
  var canClear = IS_V8_OR_CHAKRA_STACK && typeof stack == 'string' && !$Error$1.prepareStackTrace;
  if (!canClear) return stack;
  var cleared = stack;
  for (var remaining = dropEntries; remaining--;) {
    cleared = replace$a(cleared, V8_OR_CHAKRA_STACK_ENTRY, '');
  }
  return cleared;
};
var errorStackClear$1 = /*@__PURE__*/getDefaultExportFromCjs(errorStackClear);
4794
var fails$19 = fails$1m;
var createPropertyDescriptor$7 = createPropertyDescriptor$c;

// Feature flag: the probe below reports "failure" (so the flag is false) when errors have no
// `stack` property, when redefining `stack` throws, or when the redefinition does not stick.
// It is true only when a custom `stack` can be installed with `Object.defineProperty`.
var errorStackInstallable = !fails$19(function () {
  var probe = Error('a');
  if ('stack' in probe) {
    // eslint-disable-next-line es/no-object-defineproperty -- safe
    Object.defineProperty(probe, 'stack', createPropertyDescriptor$7(1, 7));
    return probe.stack !== 7;
  }
  return true;
});
var errorStackInstallable$1 = /*@__PURE__*/getDefaultExportFromCjs(errorStackInstallable);
4805
var createNonEnumerableProperty$b = createNonEnumerableProperty$f;
var clearErrorStack$2 = errorStackClear;
var ERROR_STACK_INSTALLABLE$1 = errorStackInstallable;

// non-standard V8
var captureStackTrace = Error.captureStackTrace;

// Gives `error` a stack when stacks are installable. With `Error.captureStackTrace` (V8) the
// stack is recaptured, omitting frames from `C` upward. Otherwise the supplied `stack` is
// stored after its first `dropEntries` frames are removed.
var errorStackInstall = function errorStackInstall(error, C, stack, dropEntries) {
  if (!ERROR_STACK_INSTALLABLE$1) return;
  if (captureStackTrace) {
    captureStackTrace(error, C);
    return;
  }
  createNonEnumerableProperty$b(error, 'stack', clearErrorStack$2(stack, dropEntries));
};
var errorStackInstall$1 = /*@__PURE__*/getDefaultExportFromCjs(errorStackInstall);
4818
'use strict';
var getBuiltIn$d = getBuiltIn$m;
var hasOwn$j = hasOwnProperty_1;
var createNonEnumerableProperty$a = createNonEnumerableProperty$f;
var isPrototypeOf$8 = objectIsPrototypeOf;
var setPrototypeOf$8 = objectSetPrototypeOf$1;
var copyConstructorProperties$2 = copyConstructorProperties$5;
var proxyAccessor$1 = proxyAccessor$2;
var inheritIfRequired$5 = inheritIfRequired$6;
var normalizeStringArgument$4 = normalizeStringArgument$5;
var installErrorCause$1 = installErrorCause$2;
var installErrorStack$1 = errorStackInstall;
var DESCRIPTORS$z = descriptors;
var IS_PURE$i = isPure;
// Produces an error constructor for FULL_NAME (e.g. 'TypeError' or 'WebAssembly.CompileError')
// that honours the `{ cause }` options argument.
// - Returns undefined when the built-in cannot be resolved.
// - Returns the original constructor when FORCED is falsy (after the V8 `cause` cleanup).
// - Otherwise returns the constructor produced by `wrapper(init)`. `init` creates a native
//   error and installs message, stack, subclass prototype and cause on it. The returned
//   constructor shares the original's prototype object.
// IS_AGGREGATE_ERROR switches the argument layout to `(errors, message, options)`.
var wrapErrorConstructorWithCause$2 = function wrapErrorConstructorWithCause(FULL_NAME, wrapper, FORCED, IS_AGGREGATE_ERROR) {
  var STACK_TRACE_LIMIT = 'stackTraceLimit';
  // Index of the `options` argument: 1 for `(message, options)`, 2 for AggregateError.
  var OPTIONS_POSITION = IS_AGGREGATE_ERROR ? 2 : 1;
  var path = FULL_NAME.split('.');
  var ERROR_NAME = path[path.length - 1];
  var OriginalError = getBuiltIn$d.apply(null, path);
  if (!OriginalError) return;
  var OriginalErrorPrototype = OriginalError.prototype;

  // V8 9.3- bug https://bugs.chromium.org/p/v8/issues/detail?id=12006
  if (!IS_PURE$i && hasOwn$j(OriginalErrorPrototype, 'cause')) delete OriginalErrorPrototype.cause;
  if (!FORCED) return OriginalError;
  var BaseError = getBuiltIn$d('Error');
  var WrappedError = wrapper(function (a, b) {
    // The explicit `undefined` default keeps a missing message `undefined` (not '').
    var message = normalizeStringArgument$4(IS_AGGREGATE_ERROR ? b : a, undefined);
    // The native constructor is called without the message; it is installed separately.
    var result = IS_AGGREGATE_ERROR ? new OriginalError(a) : new OriginalError();
    if (message !== undefined) createNonEnumerableProperty$a(result, 'message', message);
    installErrorStack$1(result, WrappedError, result.stack, 2);
    // Subclass support: re-parent `result` when invoked through a subclass of the original.
    if (this && isPrototypeOf$8(OriginalErrorPrototype, this)) inheritIfRequired$5(result, this, WrappedError);
    if (arguments.length > OPTIONS_POSITION) installErrorCause$1(result, arguments[OPTIONS_POSITION]);
    return result;
  });
  WrappedError.prototype = OriginalErrorPrototype;
  if (ERROR_NAME !== 'Error') {
    // Error subtypes inherit (or copy) statics from the current global `Error`.
    if (setPrototypeOf$8) setPrototypeOf$8(WrappedError, BaseError);else copyConstructorProperties$2(WrappedError, BaseError, {
      name: true
    });
  } else if (DESCRIPTORS$z && STACK_TRACE_LIMIT in OriginalError) {
    // Wrapping `Error` itself: forward `stackTraceLimit` / `prepareStackTrace` reads and writes
    // from the wrapper to the original constructor.
    proxyAccessor$1(WrappedError, OriginalError, STACK_TRACE_LIMIT);
    proxyAccessor$1(WrappedError, OriginalError, 'prepareStackTrace');
  }
  copyConstructorProperties$2(WrappedError, OriginalError);
  if (!IS_PURE$i) try {
    // Safari 13- bug: WebAssembly errors does not have a proper `.name`
    if (OriginalErrorPrototype.name !== ERROR_NAME) {
      createNonEnumerableProperty$a(OriginalErrorPrototype, 'name', ERROR_NAME);
    }
    // Point the shared prototype's `constructor` back at the wrapper.
    OriginalErrorPrototype.constructor = WrappedError;
  } catch (error) {/* empty */}
  return WrappedError;
};
var wrapErrorConstructorWithCause$3 = /*@__PURE__*/getDefaultExportFromCjs(wrapErrorConstructorWithCause$2);
4875
/* eslint-disable no-unused-vars -- required for functions `.length` */
var $$2R = _export;
var global$M = global$Z;
var apply$9 = functionApply$1;
var wrapErrorConstructorWithCause$1 = wrapErrorConstructorWithCause$2;
var WEB_ASSEMBLY = 'WebAssembly';
var WebAssembly$2 = global$M[WEB_ASSEMBLY];
// True when the native `Error` ignores the `{ cause }` option, so wrappers must be forced.
var FORCED$C = Error('e', {
  cause: 7
}).cause !== 7;

// Wraps the global error constructor ERROR_NAME and registers it on the global object
// through `_export`, passing FORCED$C as the `forced` flag.
var exportGlobalErrorCauseWrapper = function exportGlobalErrorCauseWrapper(ERROR_NAME, wrapper) {
  var exported = {};
  exported[ERROR_NAME] = wrapErrorConstructorWithCause$1(ERROR_NAME, wrapper, FORCED$C);
  $$2R({ global: true, constructor: true, arity: 1, forced: FORCED$C }, exported);
};

// Same as above for `WebAssembly.<ERROR_NAME>`; a no-op when the engine lacks that constructor.
var exportWebAssemblyErrorCauseWrapper = function exportWebAssemblyErrorCauseWrapper(ERROR_NAME, wrapper) {
  if (!WebAssembly$2 || !WebAssembly$2[ERROR_NAME]) return;
  var exported = {};
  exported[ERROR_NAME] = wrapErrorConstructorWithCause$1(WEB_ASSEMBLY + '.' + ERROR_NAME, wrapper, FORCED$C);
  $$2R({ target: WEB_ASSEMBLY, stat: true, constructor: true, arity: 1, forced: FORCED$C }, exported);
};
4909
// https://tc39.es/ecma262/#sec-nativeerror
// https://github.com/tc39/proposal-error-cause
// Registers `{ cause }`-aware wrappers for each native error constructor. Each `wrapper`
// returns a named function expression, so the wrapped constructor carries the native name
// ('Error', 'TypeError', ...) and a `.length` of 1 from the otherwise-unused `message`
// parameter. The body forwards `this` and every argument to `init`.
exportGlobalErrorCauseWrapper('Error', function (init) {
  return function Error(message) {
    return apply$9(init, this, arguments);
  };
});
exportGlobalErrorCauseWrapper('EvalError', function (init) {
  return function EvalError(message) {
    return apply$9(init, this, arguments);
  };
});
exportGlobalErrorCauseWrapper('RangeError', function (init) {
  return function RangeError(message) {
    return apply$9(init, this, arguments);
  };
});
exportGlobalErrorCauseWrapper('ReferenceError', function (init) {
  return function ReferenceError(message) {
    return apply$9(init, this, arguments);
  };
});
exportGlobalErrorCauseWrapper('SyntaxError', function (init) {
  return function SyntaxError(message) {
    return apply$9(init, this, arguments);
  };
});
exportGlobalErrorCauseWrapper('TypeError', function (init) {
  return function TypeError(message) {
    return apply$9(init, this, arguments);
  };
});
exportGlobalErrorCauseWrapper('URIError', function (init) {
  return function URIError(message) {
    return apply$9(init, this, arguments);
  };
});
// WebAssembly error constructors are wrapped only when the engine provides them.
exportWebAssemblyErrorCauseWrapper('CompileError', function (init) {
  return function CompileError(message) {
    return apply$9(init, this, arguments);
  };
});
exportWebAssemblyErrorCauseWrapper('LinkError', function (init) {
  return function LinkError(message) {
    return apply$9(init, this, arguments);
  };
});
exportWebAssemblyErrorCauseWrapper('RuntimeError', function (init) {
  return function RuntimeError(message) {
    return apply$9(init, this, arguments);
  };
});

// Placeholder namespace object emitted by the bundler for the next module.
var es_error_toString = {};
4964
'use strict';
var DESCRIPTORS$y = descriptors;
var fails$18 = fails$1m;
var anObject$w = anObject$D;
var create$b = objectCreate;
var normalizeStringArgument$3 = normalizeStringArgument$5;
var nativeErrorToString = Error.prototype.toString;

// Detects a non-conforming native `Error.prototype.toString`.
var INCORRECT_TO_STRING$1 = fails$18(function () {
  if (DESCRIPTORS$y) {
    // Chrome 32- incorrectly call accessor
    // eslint-disable-next-line es/no-object-defineproperty -- safe
    var object = create$b(Object.defineProperty({}, 'name', {
      get: function get() {
        return this === object;
      }
    }));
    if (nativeErrorToString.call(object) !== 'true') return true;
  }
  // FF10- does not properly handle non-strings
  if (nativeErrorToString.call({ message: 1, name: 2 }) !== '2: 1') return true;
  // IE8 does not properly handle defaults
  return nativeErrorToString.call({}) !== 'Error';
});

// Spec-style replacement: `name` defaults to 'Error' and `message` to ''. When either part is
// empty the other is returned alone; otherwise the two are joined with ': '.
var errorToString$2 = INCORRECT_TO_STRING$1 ? function toString() {
  var O = anObject$w(this);
  var name = normalizeStringArgument$3(O.name, 'Error');
  var message = normalizeStringArgument$3(O.message);
  if (!name) return message;
  if (!message) return name;
  return name + ': ' + message;
} : nativeErrorToString;
var errorToString$3 = /*@__PURE__*/getDefaultExportFromCjs(errorToString$2);
4998
var defineBuiltIn$i = defineBuiltIn$m;
var errorToString$1 = errorToString$2;
var ErrorPrototype$1 = Error.prototype;

// `Error.prototype.toString` method fix
// https://tc39.es/ecma262/#sec-error.prototype.tostring
if (errorToString$1 !== ErrorPrototype$1.toString) {
  // Install the selected implementation only when it is not already the current one.
  defineBuiltIn$i(ErrorPrototype$1, 'toString', errorToString$1);
}

// Placeholder namespace objects emitted by the bundler for the AggregateError modules.
var es_aggregateError = {};

var es_aggregateError_constructor = {};
5012
var fails$17 = fails$1m;

// True when `Object.getPrototypeOf` returns the real [[Prototype]] even after the object's
// `constructor` link has been broken (probe: `constructor` on the prototype set to null).
var correctPrototypeGetter = !fails$17(function () {
  function Probe() {/* empty */}
  Probe.prototype.constructor = null;
  var instance = new Probe();
  // eslint-disable-next-line es/no-object-getprototypeof -- required for testing
  return Object.getPrototypeOf(instance) !== Probe.prototype;
});
var correctPrototypeGetter$1 = /*@__PURE__*/getDefaultExportFromCjs(correctPrototypeGetter);
5021
var hasOwn$i = hasOwnProperty_1;
var isCallable$h = isCallable$z;
var toObject$p = toObject$t;
var sharedKey = sharedKey$4;
var CORRECT_PROTOTYPE_GETTER$2 = correctPrototypeGetter;
var IE_PROTO = sharedKey('IE_PROTO');
var $Object$1 = Object;
var ObjectPrototype$4 = $Object$1.prototype;

// `Object.getPrototypeOf` method
// https://tc39.es/ecma262/#sec-object.getprototypeof
// Uses the native method when it is reliable. The fallback tries, in order: the IE_PROTO
// marker key, a callable `constructor` the object is an instance of, and finally
// Object.prototype for Object instances (null otherwise).
// eslint-disable-next-line es/no-object-getprototypeof -- safe
var objectGetPrototypeOf$1 = CORRECT_PROTOTYPE_GETTER$2 ? $Object$1.getPrototypeOf : function (O) {
  var object = toObject$p(O);
  if (hasOwn$i(object, IE_PROTO)) return object[IE_PROTO];
  var Ctor = object.constructor;
  var builtByCtor = isCallable$h(Ctor) && object instanceof Ctor;
  if (builtByCtor) return Ctor.prototype;
  if (object instanceof $Object$1) return ObjectPrototype$4;
  return null;
};
var objectGetPrototypeOf$2 = /*@__PURE__*/getDefaultExportFromCjs(objectGetPrototypeOf$1);
5044
// Shared registry of default iterator methods, keyed by class name (e.g. `Array`).
var iterators = {};
var iterators$1 = /*@__PURE__*/getDefaultExportFromCjs(iterators);

var wellKnownSymbol$q = wellKnownSymbol$z;
var Iterators$4 = iterators;
var ITERATOR$a = wellKnownSymbol$q('iterator');
var ArrayPrototype$1 = Array.prototype;

// check on default Array iterator
// True when `it` is the registered `Array` iterator or the engine's
// `Array.prototype[@@iterator]`.
var isArrayIteratorMethod$3 = function isArrayIteratorMethod(it) {
  if (it === undefined) return false;
  return Iterators$4.Array === it || ArrayPrototype$1[ITERATOR$a] === it;
};
var isArrayIteratorMethod$4 = /*@__PURE__*/getDefaultExportFromCjs(isArrayIteratorMethod$3);
5058
var classof$g = classof$m;
var getMethod$7 = getMethod$9;
var isNullOrUndefined$b = isNullOrUndefined$e;
var Iterators$3 = iterators;
var wellKnownSymbol$p = wellKnownSymbol$z;
var ITERATOR$9 = wellKnownSymbol$p('iterator');

// Looks up the iterator method of `it`. It tries `@@iterator`, then the legacy '@@iterator'
// string key, then the registry entry for the value's class. Returns undefined for
// null/undefined.
var getIteratorMethod$5 = function getIteratorMethod(it) {
  if (isNullOrUndefined$b(it)) return undefined;
  var method = getMethod$7(it, ITERATOR$9);
  if (!method) method = getMethod$7(it, '@@iterator');
  if (!method) method = Iterators$3[classof$g(it)];
  return method;
};
var getIteratorMethod$6 = /*@__PURE__*/getDefaultExportFromCjs(getIteratorMethod$5);
5069
var call$v = functionCall;
var aCallable$h = aCallable$l;
var anObject$v = anObject$D;
var tryToString$4 = tryToString$7;
var getIteratorMethod$4 = getIteratorMethod$5;
var $TypeError$i = TypeError;

// Opens an iterator over `argument`. An explicitly passed second argument (even undefined)
// replaces the usual iterator-method lookup. Throws a TypeError when no callable method is
// available, and requires the produced iterator to be an object.
var getIterator$4 = function getIterator(argument, usingIterator) {
  var iteratorMethod = usingIterator;
  if (arguments.length < 2) iteratorMethod = getIteratorMethod$4(argument);
  if (!aCallable$h(iteratorMethod)) {
    throw $TypeError$i(tryToString$4(argument) + ' is not iterable');
  }
  return anObject$v(call$v(iteratorMethod, argument));
};
var getIterator$5 = /*@__PURE__*/getDefaultExportFromCjs(getIterator$4);
5082
var call$u = functionCall;
var anObject$u = anObject$D;
var getMethod$6 = getMethod$9;
// `IteratorClose` abstract operation: calls `iterator.return()` (when present) after
// iteration stopped early. `kind` is 'normal' or 'throw' at the call sites in this bundle.
// - 'throw': `value` (the original error) is always rethrown, taking precedence over any
//   error raised while looking up or calling `return()`.
// - 'normal': errors from `return()` propagate, its result must be an object, and `value`
//   is returned.
var iteratorClose$2 = function iteratorClose(iterator, kind, value) {
  var innerResult, innerError;
  anObject$u(iterator);
  try {
    innerResult = getMethod$6(iterator, 'return');
    if (!innerResult) {
      // No `return` method. For 'throw', this rethrow is caught just below and then rethrown
      // again after the try/catch, so `value` still wins.
      if (kind === 'throw') throw value;
      return value;
    }
    innerResult = call$u(innerResult, iterator);
  } catch (error) {
    innerError = true;
    innerResult = error;
  }
  if (kind === 'throw') throw value;
  if (innerError) throw innerResult;
  anObject$u(innerResult);
  return value;
};
var iteratorClose$3 = /*@__PURE__*/getDefaultExportFromCjs(iteratorClose$2);
5106
var bind$c = functionBindContext;
var call$t = functionCall;
var anObject$t = anObject$D;
var tryToString$3 = tryToString$7;
var isArrayIteratorMethod$2 = isArrayIteratorMethod$3;
var lengthOfArrayLike$p = lengthOfArrayLike$t;
var isPrototypeOf$7 = objectIsPrototypeOf;
var getIterator$3 = getIterator$4;
var getIteratorMethod$3 = getIteratorMethod$5;
var iteratorClose$1 = iteratorClose$2;
var $TypeError$h = TypeError;
// Completion record returned by `iterate`. `stopped` is true when iteration was ended via
// `stop()`, and `result` carries the value passed to `stop`.
var Result = function Result(stopped, result) {
  this.stopped = stopped;
  this.result = result;
};
var ResultPrototype = Result.prototype;
// Walks `iterable` and feeds each item to `unboundFunction`.
// options:
//   that        - `this` bound to `unboundFunction`
//   AS_ENTRIES  - items are entries; the callback receives (item[0], item[1])
//   IS_RECORD   - `iterable` is an iterator record with `iterator` and `next` fields
//   IS_ITERATOR - `iterable` is already an iterator
//   INTERRUPTED - the callback additionally receives `stop` as its last argument
// The walk ends early as soon as the callback returns a Result instance, which is then
// returned. A full walk returns `new Result(false)`.
var iterate$a = function iterate(iterable, unboundFunction, options) {
  var that = options && options.that;
  var AS_ENTRIES = !!(options && options.AS_ENTRIES);
  var IS_RECORD = !!(options && options.IS_RECORD);
  var IS_ITERATOR = !!(options && options.IS_ITERATOR);
  var INTERRUPTED = !!(options && options.INTERRUPTED);
  var fn = bind$c(unboundFunction, that);
  var iterator, iterFn, index, length, result, next, step;
  // Ends iteration: closes the iterator (if one was opened) and wraps `condition` in a
  // stopped Result.
  var stop = function stop(condition) {
    if (iterator) iteratorClose$1(iterator, 'normal', condition);
    return new Result(true, condition);
  };
  var callFn = function callFn(value) {
    if (AS_ENTRIES) {
      anObject$t(value);
      return INTERRUPTED ? fn(value[0], value[1], stop) : fn(value[0], value[1]);
    }
    return INTERRUPTED ? fn(value, stop) : fn(value);
  };
  if (IS_RECORD) {
    iterator = iterable.iterator;
  } else if (IS_ITERATOR) {
    iterator = iterable;
  } else {
    iterFn = getIteratorMethod$3(iterable);
    if (!iterFn) throw $TypeError$h(tryToString$3(iterable) + ' is not iterable');
    // optimisation for array iterators
    // The default array iterator is replaced by an index loop, so no iterator object is created.
    if (isArrayIteratorMethod$2(iterFn)) {
      for (index = 0, length = lengthOfArrayLike$p(iterable); length > index; index++) {
        result = callFn(iterable[index]);
        if (result && isPrototypeOf$7(ResultPrototype, result)) return result;
      }
      return new Result(false);
    }
    iterator = getIterator$3(iterable, iterFn);
  }
  next = IS_RECORD ? iterable.next : iterator.next;
  while (!(step = call$t(next, iterator)).done) {
    try {
      result = callFn(step.value);
    } catch (error) {
      // Close the iterator before propagating: iteratorClose rethrows `error` for 'throw'.
      iteratorClose$1(iterator, 'throw', error);
    }
    if (_typeof(result) == 'object' && result && isPrototypeOf$7(ResultPrototype, result)) return result;
  }
  return new Result(false);
};
var iterate$b = /*@__PURE__*/getDefaultExportFromCjs(iterate$a);
5171
'use strict';
var $$2Q = _export;
var isPrototypeOf$6 = objectIsPrototypeOf;
var getPrototypeOf$a = objectGetPrototypeOf$1;
var setPrototypeOf$7 = objectSetPrototypeOf$1;
var copyConstructorProperties$1 = copyConstructorProperties$5;
var create$a = objectCreate;
var createNonEnumerableProperty$9 = createNonEnumerableProperty$f;
var createPropertyDescriptor$6 = createPropertyDescriptor$c;
var installErrorCause = installErrorCause$2;
var installErrorStack = errorStackInstall;
var iterate$9 = iterate$a;
var normalizeStringArgument$2 = normalizeStringArgument$5;
var wellKnownSymbol$o = wellKnownSymbol$z;
var TO_STRING_TAG$2 = wellKnownSymbol$o('toStringTag');
var $Error = Error;
var push$a = [].push;
// Polyfilled `AggregateError(errors, message, options)`. Works with or without `new`.
// `AggregateErrorPrototype` is declared further down; it is assigned before this function
// can ever be called.
var $AggregateError$1 = function AggregateError(errors, message /* , options */) {
  // Called via `new` (or a subclass)? Then keep the caller's prototype chain.
  var isInstance = isPrototypeOf$6(AggregateErrorPrototype, this);
  var that;
  if (setPrototypeOf$7) {
    // Start from a real native Error object, then re-parent it.
    that = setPrototypeOf$7($Error(), isInstance ? getPrototypeOf$a(this) : AggregateErrorPrototype);
  } else {
    // No prototype mutation available: use `this`/a fresh object and tag it as an Error.
    that = isInstance ? this : create$a(AggregateErrorPrototype);
    createNonEnumerableProperty$9(that, TO_STRING_TAG$2, 'Error');
  }
  if (message !== undefined) createNonEnumerableProperty$9(that, 'message', normalizeStringArgument$2(message));
  installErrorStack(that, $AggregateError$1, that.stack, 1);
  if (arguments.length > 2) installErrorCause(that, arguments[2]);
  // Copy every item of the `errors` iterable into a fresh array stored as `errors`.
  var errorsArray = [];
  iterate$9(errors, push$a, {
    that: errorsArray
  });
  createNonEnumerableProperty$9(that, 'errors', errorsArray);
  return that;
};
// Inherit (or copy) the static members of `Error`.
if (setPrototypeOf$7) setPrototypeOf$7($AggregateError$1, $Error);else copyConstructorProperties$1($AggregateError$1, $Error, {
  name: true
});
// Prototype derives from Error.prototype with `constructor`, empty `message` and `name` set.
var AggregateErrorPrototype = $AggregateError$1.prototype = create$a($Error.prototype, {
  constructor: createPropertyDescriptor$6(1, $AggregateError$1),
  message: createPropertyDescriptor$6(1, ''),
  name: createPropertyDescriptor$6(1, 'AggregateError')
});

// `AggregateError` constructor
// https://tc39.es/ecma262/#sec-aggregate-error-constructor
$$2Q({
  global: true,
  constructor: true,
  arity: 2
}, {
  AggregateError: $AggregateError$1
});

// Placeholder namespace object emitted by the bundler for the next module.
var es_aggregateError_cause = {};
5228
var $$2P = _export;
var getBuiltIn$c = getBuiltIn$m;
var apply$8 = functionApply$1;
var fails$16 = fails$1m;
var wrapErrorConstructorWithCause = wrapErrorConstructorWithCause$2;
var AGGREGATE_ERROR = 'AggregateError';
var $AggregateError = getBuiltIn$c(AGGREGATE_ERROR);

// Force the wrapper only when the current `AggregateError` collects `errors` correctly but
// ignores the `{ cause }` option.
var FORCED$B = function () {
  var collectsErrors = !fails$16(function () {
    return $AggregateError([1]).errors[0] !== 1;
  });
  if (!collectsErrors) return false;
  return fails$16(function () {
    return $AggregateError([1], AGGREGATE_ERROR, {
      cause: 7
    }).cause !== 7;
  });
}();

// https://github.com/tc39/proposal-error-cause
$$2P({
  global: true,
  constructor: true,
  arity: 2,
  forced: FORCED$B
}, {
  AggregateError: wrapErrorConstructorWithCause(AGGREGATE_ERROR, function (init) {
    // The named expression keeps `.name === 'AggregateError'`; the two declared parameters
    // give the wrapper the native `.length` of 2.
    // eslint-disable-next-line no-unused-vars -- required for functions `.length`
    return function AggregateError(errors, message) {
      return apply$8(init, this, arguments);
    };
  }, FORCED$B, true)
});

var es_array_at = {};
5260
var wellKnownSymbol$n = wellKnownSymbol$z;
var create$9 = objectCreate;
var defineProperty$8 = objectDefineProperty.f;
var UNSCOPABLES = wellKnownSymbol$n('unscopables');
var ArrayPrototype = Array.prototype;

// Array.prototype[@@unscopables]
// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
// When the engine provides no unscopables object (null or undefined), install a fresh
// `objectCreate(null)` one.
if (ArrayPrototype[UNSCOPABLES] == null) {
  defineProperty$8(ArrayPrototype, UNSCOPABLES, {
    configurable: true,
    value: create$9(null)
  });
}

// add a key to Array.prototype[@@unscopables]
var addToUnscopables$e = function addToUnscopables(key) {
  var unscopables = ArrayPrototype[UNSCOPABLES];
  unscopables[key] = true;
};
var addToUnscopables$f = /*@__PURE__*/getDefaultExportFromCjs(addToUnscopables$e);
5281
'use strict';
var $$2O = _export;
var toObject$o = toObject$t;
var lengthOfArrayLike$o = lengthOfArrayLike$t;
var toIntegerOrInfinity$i = toIntegerOrInfinity$l;
var addToUnscopables$d = addToUnscopables$e;

// `Array.prototype.at` method
// https://github.com/tc39/proposal-relative-indexing-method
// Negative indices count back from the end; out-of-range positions yield undefined.
$$2O({
  target: 'Array',
  proto: true
}, {
  at: function at(index) {
    var O = toObject$o(this);
    var len = lengthOfArrayLike$o(O);
    var relativeIndex = toIntegerOrInfinity$i(index);
    var k = relativeIndex < 0 ? len + relativeIndex : relativeIndex;
    if (k < 0 || k >= len) return undefined;
    return O[k];
  }
});
addToUnscopables$d('at');

var es_array_concat = {};
5306
var $TypeError$g = TypeError;
var MAX_SAFE_INTEGER = 9007199254740991; // 2 ** 53 - 1 == 0x1FFFFFFFFFFFFF

// Guards array-length arithmetic. Throws a TypeError once `it` exceeds 2 ** 53 - 1;
// otherwise returns `it` unchanged (NaN passes through because the comparison is false).
var doesNotExceedSafeInteger$6 = function doesNotExceedSafeInteger(it) {
  var exceedsLimit = it > MAX_SAFE_INTEGER;
  if (exceedsLimit) throw $TypeError$g('Maximum allowed index exceeded');
  return it;
};
var doesNotExceedSafeInteger$7 = /*@__PURE__*/getDefaultExportFromCjs(doesNotExceedSafeInteger$6);
5315
var fails$15 = fails$1m;
var wellKnownSymbol$m = wellKnownSymbol$z;
var V8_VERSION$2 = engineV8Version;
var SPECIES$5 = wellKnownSymbol$m('species');

// Reports whether `Array.prototype[METHOD_NAME]` builds its result through
// `constructor[@@species]`.
var arrayMethodHasSpeciesSupport$5 = function arrayMethodHasSpeciesSupport(METHOD_NAME) {
  // V8 51+ is trusted without probing: running this feature detection there causes
  // deoptimization and serious performance degradation
  // https://github.com/zloirock/core-js/issues/677
  if (V8_VERSION$2 >= 51) return true;
  return !fails$15(function () {
    var probe = [];
    var fakeConstructor = {};
    fakeConstructor[SPECIES$5] = function () {
      return {
        foo: 1
      };
    };
    probe.constructor = fakeConstructor;
    return probe[METHOD_NAME](Boolean).foo !== 1;
  });
};
var arrayMethodHasSpeciesSupport$6 = /*@__PURE__*/getDefaultExportFromCjs(arrayMethodHasSpeciesSupport$5);
5336
'use strict';
var $$2N = _export;
var fails$14 = fails$1m;
var isArray$6 = isArray$9;
var isObject$q = isObject$z;
var toObject$n = toObject$t;
var lengthOfArrayLike$n = lengthOfArrayLike$t;
var doesNotExceedSafeInteger$5 = doesNotExceedSafeInteger$6;
var createProperty$7 = createProperty$9;
var arraySpeciesCreate$3 = arraySpeciesCreate$5;
var arrayMethodHasSpeciesSupport$4 = arrayMethodHasSpeciesSupport$5;
var wellKnownSymbol$l = wellKnownSymbol$z;
var V8_VERSION$1 = engineV8Version;
var IS_CONCAT_SPREADABLE = wellKnownSymbol$l('isConcatSpreadable');

// V8 51+ is trusted without probing: running this feature detection there causes
// deoptimization and serious performance degradation
// https://github.com/zloirock/core-js/issues/679
var IS_CONCAT_SPREADABLE_SUPPORT = V8_VERSION$1 >= 51 || !fails$14(function () {
  var probe = [];
  probe[IS_CONCAT_SPREADABLE] = false;
  // When @@isConcatSpreadable is honoured, `probe` itself becomes the only element.
  return probe.concat()[0] !== probe;
});

// Spec `IsConcatSpreadable`: non-objects never spread; an explicit @@isConcatSpreadable
// value wins (coerced to boolean); otherwise only arrays spread.
var isConcatSpreadable = function isConcatSpreadable(O) {
  if (!isObject$q(O)) return false;
  var spreadable = O[IS_CONCAT_SPREADABLE];
  if (spreadable === undefined) return isArray$6(O);
  return !!spreadable;
};
var FORCED$A = !IS_CONCAT_SPREADABLE_SUPPORT || !arrayMethodHasSpeciesSupport$4('concat');

// `Array.prototype.concat` method
// https://tc39.es/ecma262/#sec-array.prototype.concat
// with adding support of @@isConcatSpreadable and @@species
$$2N({
  target: 'Array',
  proto: true,
  arity: 1,
  forced: FORCED$A
}, {
  // eslint-disable-next-line no-unused-vars -- required for `.length`
  concat: function concat(arg) {
    var O = toObject$n(this);
    var A = arraySpeciesCreate$3(O, 0);
    var n = 0;
    var argCount = arguments.length;
    // Position -1 stands for the receiver itself, followed by every argument in order.
    for (var position = -1; position < argCount; position++) {
      var item = position === -1 ? O : arguments[position];
      if (!isConcatSpreadable(item)) {
        doesNotExceedSafeInteger$5(n + 1);
        createProperty$7(A, n++, item);
        continue;
      }
      var itemLength = lengthOfArrayLike$n(item);
      doesNotExceedSafeInteger$5(n + itemLength);
      for (var k = 0; k < itemLength; k++, n++) {
        // Holes stay holes: only present indices are copied, but `n` still advances.
        if (k in item) createProperty$7(A, n, item[k]);
      }
    }
    A.length = n;
    return A;
  }
});

var es_array_copyWithin = {};
5399
'use strict';
var tryToString$2 = tryToString$7;
var $TypeError$f = TypeError;

// Deletes `O[P]` and throws a TypeError when the `delete` operator reports failure.
var deletePropertyOrThrow$4 = function deletePropertyOrThrow(O, P) {
  var deleted = delete O[P];
  if (deleted) return;
  throw $TypeError$f('Cannot delete property ' + tryToString$2(P) + ' of ' + tryToString$2(O));
};
var deletePropertyOrThrow$5 = /*@__PURE__*/getDefaultExportFromCjs(deletePropertyOrThrow$4);
5407
'use strict';
var toObject$m = toObject$t;
var toAbsoluteIndex$7 = toAbsoluteIndex$a;
var lengthOfArrayLike$m = lengthOfArrayLike$t;
var deletePropertyOrThrow$3 = deletePropertyOrThrow$4;
var min$c = Math.min;

// `Array.prototype.copyWithin` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.copywithin
// The native method is used when present; the fallback follows the spec algorithm.
// eslint-disable-next-line es/no-array-prototype-copywithin -- safe
var arrayCopyWithin = [].copyWithin || function copyWithin(target /* = 0 */, start /* = 0, end = @length */) {
  var O = toObject$m(this);
  var len = lengthOfArrayLike$m(O);
  var to = toAbsoluteIndex$7(target, len);
  var from = toAbsoluteIndex$7(start, len);
  var end = arguments.length > 2 ? arguments[2] : undefined;
  var stopIndex = end === undefined ? len : toAbsoluteIndex$7(end, len);
  var count = min$c(stopIndex - from, len - to);
  var direction = 1;
  // When the destination starts inside the source range, copy back-to-front so every
  // element is read before it is overwritten.
  if (from < to && to < from + count) {
    direction = -1;
    from += count - 1;
    to += count - 1;
  }
  for (; count > 0; count--) {
    if (from in O) {
      O[to] = O[from];
    } else {
      // Holes in the source become deleted keys in the destination.
      deletePropertyOrThrow$3(O, to);
    }
    to += direction;
    from += direction;
  }
  return O;
};
var arrayCopyWithin$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayCopyWithin);
5439
var $$2M = _export;
var copyWithin = arrayCopyWithin;
var addToUnscopables$c = addToUnscopables$e;

// `Array.prototype.copyWithin` method
// https://tc39.es/ecma262/#sec-array.prototype.copywithin
$$2M({
  target: 'Array',
  proto: true
}, {
  copyWithin: copyWithin
});

// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
// Also hide the method from `with` scopes, as the spec requires.
addToUnscopables$c('copyWithin');

var es_array_every = {};
5457
'use strict';
var fails$13 = fails$1m;

// True when `Array.prototype[METHOD_NAME]` exists and throws when invoked with a `null`
// receiver (the probe calls it with `argument`, or a dummy callback, plus an extra `1`).
var arrayMethodIsStrict$9 = function arrayMethodIsStrict(METHOD_NAME, argument) {
  var method = [][METHOD_NAME];
  if (!method) return false;
  return fails$13(function () {
    // eslint-disable-next-line no-useless-call -- required for testing
    method.call(null, argument || function () {
      return 1;
    }, 1);
  });
};
var arrayMethodIsStrict$a = /*@__PURE__*/getDefaultExportFromCjs(arrayMethodIsStrict$9);
5470
'use strict';
var $$2L = _export;
var $every$1 = arrayIteration.every;
var arrayMethodIsStrict$8 = arrayMethodIsStrict$9;
var STRICT_METHOD$4 = arrayMethodIsStrict$8('every');

// `Array.prototype.every` method
// https://tc39.es/ecma262/#sec-array.prototype.every
// Forced when the native method does not throw for a `null` receiver.
$$2L({
  target: 'Array',
  proto: true,
  forced: !STRICT_METHOD$4
}, {
  every: function every(callbackfn /* , thisArg */) {
    var thisArg = arguments.length > 1 ? arguments[1] : undefined;
    return $every$1(this, callbackfn, thisArg);
  }
});

var es_array_fill = {};
5490
'use strict';
var toObject$l = toObject$t;
var toAbsoluteIndex$6 = toAbsoluteIndex$a;
var lengthOfArrayLike$l = lengthOfArrayLike$t;

// `Array.prototype.fill` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.fill
// Writes `value` into every index in [start, end) (both resolved relative to the length)
// and returns the receiver object.
var arrayFill$1 = function fill(value /* , start = 0, end = @length */) {
  var O = toObject$l(this);
  var length = lengthOfArrayLike$l(O);
  var argc = arguments.length;
  var start = argc > 1 ? arguments[1] : undefined;
  var end = argc > 2 ? arguments[2] : undefined;
  var index = toAbsoluteIndex$6(start, length);
  var endPos = end === undefined ? length : toAbsoluteIndex$6(end, length);
  for (; index < endPos; index++) O[index] = value;
  return O;
};
var arrayFill$2 = /*@__PURE__*/getDefaultExportFromCjs(arrayFill$1);
5509
var $$2K = _export;
var fill$4 = arrayFill$1;
var addToUnscopables$b = addToUnscopables$e;

// `Array.prototype.fill` method
// https://tc39.es/ecma262/#sec-array.prototype.fill
$$2K({
  target: 'Array',
  proto: true
}, {
  fill: fill$4
});

// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
// Also hide the method from `with` scopes, as the spec requires.
addToUnscopables$b('fill');

var es_array_filter = {};
5527
'use strict';
var $$2J = _export;
var $filter$1 = arrayIteration.filter;
var arrayMethodHasSpeciesSupport$3 = arrayMethodHasSpeciesSupport$5;
var HAS_SPECIES_SUPPORT$3 = arrayMethodHasSpeciesSupport$3('filter');

// `Array.prototype.filter` method
// https://tc39.es/ecma262/#sec-array.prototype.filter
// with adding support of @@species (forced when the native method ignores @@species)
$$2J({
  target: 'Array',
  proto: true,
  forced: !HAS_SPECIES_SUPPORT$3
}, {
  filter: function filter(callbackfn /* , thisArg */) {
    var thisArg = arguments.length > 1 ? arguments[1] : undefined;
    return $filter$1(this, callbackfn, thisArg);
  }
});

var es_array_find = {};
5548
'use strict';
var $$2I = _export;
var $find$1 = arrayIteration.find;
var addToUnscopables$a = addToUnscopables$e;
var FIND = 'find';
var SKIPS_HOLES$1 = true;

// Shouldn't skip holes: a conforming `find` calls back for the hole in `Array(1)`, which
// clears the flag. If the callback never runs, the native method is replaced.
if (FIND in []) {
  Array(1)[FIND](function () {
    SKIPS_HOLES$1 = false;
  });
}

// `Array.prototype.find` method
// https://tc39.es/ecma262/#sec-array.prototype.find
$$2I({
  target: 'Array',
  proto: true,
  forced: SKIPS_HOLES$1
}, {
  find: function find(callbackfn /* , that = undefined */) {
    var that = arguments.length > 1 ? arguments[1] : undefined;
    return $find$1(this, callbackfn, that);
  }
});

// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables$a(FIND);

var es_array_findIndex = {};
5577
5578 'use strict';
5579 var $$2H = _export;
5580 var $findIndex$1 = arrayIteration.findIndex;
5581 var addToUnscopables$9 = addToUnscopables$e;
5582 var FIND_INDEX = 'findIndex';
5583 var SKIPS_HOLES = true;
5584
5585 // Shouldn't skip holes
5586 if (FIND_INDEX in []) Array(1)[FIND_INDEX](function () {
5587 SKIPS_HOLES = false;
5588 });
5589
5590 // `Array.prototype.findIndex` method
5591 // https://tc39.es/ecma262/#sec-array.prototype.findindex
5592 $$2H({
5593 target: 'Array',
5594 proto: true,
5595 forced: SKIPS_HOLES
5596 }, {
5597 findIndex: function findIndex(callbackfn /* , that = undefined */) {
5598 return $findIndex$1(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
5599 }
5600 });
5601
5602 // https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
5603 addToUnscopables$9(FIND_INDEX);
5604
// Empty object for the `es.array.find-last` module; nothing in this section fills it.
var es_array_findLast = {};

var bind$b = functionBindContext;
var IndexedObject$3 = indexedObject;
var toObject$k = toObject$t;
var lengthOfArrayLike$k = lengthOfArrayLike$t;

/**
 * Factory for the right-to-left search helpers behind
 * `Array.prototype.findLast` (TYPE 0) and `Array.prototype.findLastIndex` (TYPE 1).
 *
 * The returned function `($this, callbackfn, that)` walks indices from
 * `length - 1` down to 0. It calls `bind$b(callbackfn, that)` with
 * `(value, index, O)`. Holes are visited because there is no `in` check.
 * On the first truthy result it returns the element (TYPE 0) or its index
 * (TYPE 1). When nothing matches it returns `undefined` (TYPE 0) or -1 (TYPE 1).
 */
var createMethod$5 = function createMethod(TYPE) {
  var IS_FIND_LAST_INDEX = TYPE == 1;
  return function ($this, callbackfn, that) {
    var O = toObject$k($this);
    var self = IndexedObject$3(O);
    var boundFunction = bind$b(callbackfn, that);
    for (var index = lengthOfArrayLike$k(self) - 1; index >= 0; index--) {
      var value = self[index];
      if (!boundFunction(value, index, O)) continue;
      if (TYPE === 0) return value; // findLast
      if (TYPE === 1) return index; // findLastIndex
    }
    return IS_FIND_LAST_INDEX ? -1 : undefined;
  };
};
var arrayIterationFromLast = {
  // `Array.prototype.findLast` method
  // https://github.com/tc39/proposal-array-find-from-last
  findLast: createMethod$5(0),
  // `Array.prototype.findLastIndex` method
  // https://github.com/tc39/proposal-array-find-from-last
  findLastIndex: createMethod$5(1)
};
var arrayIterationFromLast$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayIterationFromLast);
5646
'use strict';
var $$2G = _export;
var $findLast$1 = arrayIterationFromLast.findLast;
var addToUnscopables$8 = addToUnscopables$e;

// `Array.prototype.findLast` method
// https://github.com/tc39/proposal-array-find-from-last
// No `forced` flag is passed for this registration.
$$2G({ target: 'Array', proto: true }, {
  findLast: function findLast(callbackfn /* , that = undefined */) {
    var that = arguments.length > 1 ? arguments[1] : undefined;
    return $findLast$1(this, callbackfn, that);
  }
});
addToUnscopables$8('findLast');
5663
// Empty object for the `es.array.find-last-index` module; nothing in this section fills it.
var es_array_findLastIndex = {};

'use strict';
var $$2F = _export;
var $findLastIndex$1 = arrayIterationFromLast.findLastIndex;
var addToUnscopables$7 = addToUnscopables$e;

// `Array.prototype.findLastIndex` method
// https://github.com/tc39/proposal-array-find-from-last
// No `forced` flag is passed for this registration.
$$2F({ target: 'Array', proto: true }, {
  findLastIndex: function findLastIndex(callbackfn /* , that = undefined */) {
    var that = arguments.length > 1 ? arguments[1] : undefined;
    return $findLastIndex$1(this, callbackfn, that);
  }
});
addToUnscopables$7('findLastIndex');
5682
// Empty object for the `es.array.flat` module; nothing in this section fills it.
var es_array_flat = {};

'use strict';
var isArray$5 = isArray$9;
var lengthOfArrayLike$j = lengthOfArrayLike$t;
var doesNotExceedSafeInteger$4 = doesNotExceedSafeInteger$6;
var bind$a = functionBindContext;

/**
 * `FlattenIntoArray` abstract operation
 * https://tc39.github.io/proposal-flatMap/#sec-FlattenIntoArray
 *
 * Copies the present elements of `source[0..sourceLen)` into `target` starting
 * at `start`. Holes are skipped. Array elements are spliced in recursively
 * while `depth > 0`. If `mapper` is given, it is bound to `thisArg` and applied
 * only at this top level as `mapper(element, index, original)`.
 * @returns {number} the index one past the last slot written in `target`.
 */
var flattenIntoArray$2 = function flattenIntoArray(target, original, source, sourceLen, start, depth, mapper, thisArg) {
  var mapFn = mapper ? bind$a(mapper, thisArg) : false;
  var nextIndex = start;
  for (var i = 0; i < sourceLen; i++) {
    if (!(i in source)) continue;
    var element = source[i];
    if (mapFn) element = mapFn(element, i, original);
    if (depth > 0 && isArray$5(element)) {
      // The recursive call returns the next free slot, so no extra increment here.
      nextIndex = flattenIntoArray(target, original, element, lengthOfArrayLike$j(element), nextIndex, depth - 1);
    } else {
      doesNotExceedSafeInteger$4(nextIndex + 1);
      target[nextIndex] = element;
      nextIndex++;
    }
  }
  return nextIndex;
};
var flattenIntoArray_1 = flattenIntoArray$2;
var flattenIntoArray$3 = /*@__PURE__*/getDefaultExportFromCjs(flattenIntoArray_1);
5716
'use strict';
var $$2E = _export;
var flattenIntoArray$1 = flattenIntoArray_1;
var toObject$j = toObject$t;
var lengthOfArrayLike$i = lengthOfArrayLike$t;
var toIntegerOrInfinity$h = toIntegerOrInfinity$l;
var arraySpeciesCreate$2 = arraySpeciesCreate$5;

// `Array.prototype.flat` method
// https://tc39.es/ecma262/#sec-array.prototype.flat
$$2E({ target: 'Array', proto: true }, {
  // Declared with no formal parameters so that `flat.length === 0`.
  flat: function flat(/* depthArg = 1 */) {
    var depthArg = arguments.length > 0 ? arguments[0] : undefined;
    var O = toObject$j(this);
    var sourceLen = lengthOfArrayLike$i(O);
    var A = arraySpeciesCreate$2(O, 0);
    // An omitted or `undefined` depth means 1.
    var depth = 1;
    if (depthArg !== undefined) depth = toIntegerOrInfinity$h(depthArg);
    A.length = flattenIntoArray$1(A, O, O, sourceLen, 0, depth);
    return A;
  }
});
5741
// Empty object for the `es.array.flat-map` module; nothing in this section fills it.
var es_array_flatMap = {};

'use strict';
var $$2D = _export;
var flattenIntoArray = flattenIntoArray_1;
var aCallable$g = aCallable$l;
var toObject$i = toObject$t;
var lengthOfArrayLike$h = lengthOfArrayLike$t;
var arraySpeciesCreate$1 = arraySpeciesCreate$5;

// `Array.prototype.flatMap` method
// https://tc39.es/ecma262/#sec-array.prototype.flatmap
$$2D({ target: 'Array', proto: true }, {
  flatMap: function flatMap(callbackfn /* , thisArg */) {
    var thisArg = arguments.length > 1 ? arguments[1] : undefined;
    var O = toObject$i(this);
    var sourceLen = lengthOfArrayLike$h(O);
    // Validate the callback before allocating the result.
    aCallable$g(callbackfn);
    var A = arraySpeciesCreate$1(O, 0);
    // flatMap always flattens exactly one level.
    A.length = flattenIntoArray(A, O, O, sourceLen, 0, 1, callbackfn, thisArg);
    return A;
  }
});
5768
// Empty object for the `es.array.for-each` module; nothing in this section fills it.
var es_array_forEach = {};

'use strict';
var $forEach$1 = arrayIteration.forEach;
var arrayMethodIsStrict$7 = arrayMethodIsStrict$9;
var STRICT_METHOD$3 = arrayMethodIsStrict$7('forEach');

// `Array.prototype.forEach` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.foreach
// The native method is reused when `arrayMethodIsStrict$7('forEach')` accepts
// it. Otherwise a wrapper around the generic iteration helper is used.
// eslint-disable-next-line es/no-array-prototype-foreach -- safe
var arrayForEach = STRICT_METHOD$3 ? [].forEach : function forEach(callbackfn /* , thisArg */) {
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  return $forEach$1(this, callbackfn, thisArg);
};
var arrayForEach$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayForEach);
5783
'use strict';
var $$2C = _export;
var forEach$3 = arrayForEach;

// `Array.prototype.forEach` method
// https://tc39.es/ecma262/#sec-array.prototype.foreach
// `forced` is set only when the selected implementation is not the engine's own.
// eslint-disable-next-line es/no-array-prototype-foreach -- safe
$$2C({ target: 'Array', proto: true, forced: [].forEach !== forEach$3 }, {
  forEach: forEach$3
});
5798
// Empty object for the `es.array.from` module; nothing in this section fills it.
var es_array_from = {};

var anObject$s = anObject$D;
var iteratorClose = iteratorClose$2;

/**
 * Invokes `fn` for one iteration step. If `fn` (or the entry check) throws,
 * the error is handed to `iteratorClose(iterator, 'throw', error)`.
 *
 * @param {Object} iterator - the iterator being consumed.
 * @param {Function} fn - the step callback.
 * @param {*} value - the step value, or an entry pair when `ENTRIES` is truthy.
 * @param {boolean} ENTRIES - when truthy, `value` must pass `anObject$s` and
 *   `fn` is called as `fn(value[0], value[1])`; otherwise as `fn(value)`.
 * @returns {*} whatever `fn` returns (undefined if `iteratorClose` returns).
 */
var callWithSafeIterationClosing$1 = function callWithSafeIterationClosing(iterator, fn, value, ENTRIES) {
  try {
    if (!ENTRIES) return fn(value);
    var pair = anObject$s(value);
    return fn(pair[0], value[1]);
  } catch (error) {
    iteratorClose(iterator, 'throw', error);
  }
};
var callWithSafeIterationClosing$2 = /*@__PURE__*/getDefaultExportFromCjs(callWithSafeIterationClosing$1);
5813
'use strict';
var bind$9 = functionBindContext;
var call$s = functionCall;
var toObject$h = toObject$t;
var callWithSafeIterationClosing = callWithSafeIterationClosing$1;
var isArrayIteratorMethod$1 = isArrayIteratorMethod$3;
var isConstructor$4 = isConstructor$6;
var lengthOfArrayLike$g = lengthOfArrayLike$t;
var createProperty$6 = createProperty$9;
var getIterator$2 = getIterator$4;
var getIteratorMethod$2 = getIteratorMethod$5;
var $Array$7 = Array;

// `Array.from` method implementation
// https://tc39.es/ecma262/#sec-array.from
// The receiver `this` acts as the result constructor: when
// `isConstructor$4(this)` holds, the result is built with `new this(...)`;
// otherwise a plain array is used.
var arrayFrom$1 = function from(arrayLike /* , mapfn = undefined, thisArg = undefined */) {
  var O = toObject$h(arrayLike);
  var IS_CONSTRUCTOR = isConstructor$4(this);
  var argumentsLength = arguments.length;
  var mapfn = argumentsLength > 1 ? arguments[1] : undefined;
  var mapping = mapfn !== undefined;
  // Bind the mapper to the optional `thisArg` once, up front.
  // NOTE(review): there is no explicit callable check here; presumably
  // `bind$9` (functionBindContext) validates `mapfn` — confirm.
  if (mapping) mapfn = bind$9(mapfn, argumentsLength > 2 ? arguments[2] : undefined);
  var iteratorMethod = getIteratorMethod$2(O);
  var index = 0;
  var length, result, step, iterator, next, value;
  // if the target is not iterable or it's an array with the default iterator - use a simple case
  if (iteratorMethod && !(this === $Array$7 && isArrayIteratorMethod$1(iteratorMethod))) {
    // Iterator path: `next` is read once, then invoked via `call$s` on every step.
    iterator = getIterator$2(O, iteratorMethod);
    next = iterator.next;
    result = IS_CONSTRUCTOR ? new this() : [];
    for (; !(step = call$s(next, iterator)).done; index++) {
      // With ENTRIES = true, `callWithSafeIterationClosing` calls
      // `mapfn(step.value, index)` and routes any thrown error to `iteratorClose`.
      value = mapping ? callWithSafeIterationClosing(iterator, mapfn, [step.value, index], true) : step.value;
      createProperty$6(result, index, value);
    }
  } else {
    // Array-like path: read indices 0..length-1 directly from `O`.
    length = lengthOfArrayLike$g(O);
    result = IS_CONSTRUCTOR ? new this(length) : $Array$7(length);
    for (; length > index; index++) {
      value = mapping ? mapfn(O[index], index) : O[index];
      createProperty$6(result, index, value);
    }
  }
  // Both paths end with `index` equal to the number of elements written.
  result.length = index;
  return result;
};
var arrayFrom$2 = /*@__PURE__*/getDefaultExportFromCjs(arrayFrom$1);
5860
var wellKnownSymbol$k = wellKnownSymbol$z;
var ITERATOR$8 = wellKnownSymbol$k('iterator');
// Set to true below only if the engine's `Array.from` calls an iterator's
// `return` method when the mapping callback throws.
var SAFE_CLOSING = false;
try {
  var called = 0;
  // Iterator whose first `next()` yields `done: false` and later ones `done: true`.
  var iteratorWithReturn = {
    next: function next() {
      return {
        done: !!called++
      };
    },
    'return': function _return() {
      SAFE_CLOSING = true;
    }
  };
  iteratorWithReturn[ITERATOR$8] = function () {
    return this;
  };
  // The mapper throws on the first element; a correct engine must close the
  // iterator (call `return`) before propagating.
  // eslint-disable-next-line es/no-array-from, no-throw-literal -- required for testing
  Array.from(iteratorWithReturn, function () {
    throw 2;
  });
} catch (error) {/* empty */}
// Deliberately swallowed: the thrown `2` (or a missing `Array.from`) is the
// expected outcome of the probe above.

/**
 * Runs `exec(object)` with a minimal iterable and reports whether its
 * iterator's `next` was invoked.
 * Returns false immediately when the engine failed the safe-closing probe,
 * unless `SKIP_CLOSING` is truthy. Errors thrown by `exec` are swallowed.
 * @param {Function} exec - consumes the provided iterable.
 * @param {boolean} [SKIP_CLOSING] - ignore the safe-closing probe result.
 * @returns {boolean}
 */
var checkCorrectnessOfIteration$4 = function checkCorrectnessOfIteration(exec, SKIP_CLOSING) {
  if (!SKIP_CLOSING && !SAFE_CLOSING) return false;
  var ITERATION_SUPPORT = false;
  try {
    var object = {};
    object[ITERATOR$8] = function () {
      return {
        next: function next() {
          // Records that iteration happened, and ends it immediately.
          return {
            done: ITERATION_SUPPORT = true
          };
        }
      };
    };
    exec(object);
  } catch (error) {/* empty */}
  return ITERATION_SUPPORT;
};
var checkCorrectnessOfIteration$5 = /*@__PURE__*/getDefaultExportFromCjs(checkCorrectnessOfIteration$4);
5903
var $$2B = _export;
var from = arrayFrom$1;
var checkCorrectnessOfIteration$3 = checkCorrectnessOfIteration$4;

// The native `Array.from` is considered broken when it cannot consume a generic
// iterable (per `checkCorrectnessOfIteration$3`); that result becomes `forced`.
var INCORRECT_ITERATION = !checkCorrectnessOfIteration$3(function (iterable) {
  // eslint-disable-next-line es/no-array-from -- required for testing
  Array.from(iterable);
});

// `Array.from` method
// https://tc39.es/ecma262/#sec-array.from
$$2B({ target: 'Array', stat: true, forced: INCORRECT_ITERATION }, { from: from });
5921
// Empty object for the `es.array.includes` module; nothing in this section fills it.
var es_array_includes = {};

'use strict';
var $$2A = _export;
var $includes$1 = arrayIncludes.includes;
var fails$12 = fails$1m;
var addToUnscopables$6 = addToUnscopables$e;

// FF99+ bug: on a one-element sparse array, `includes()` (searching for
// `undefined`) should be true. The probe flags engines where it is not.
var BROKEN_ON_SPARSE = fails$12(function () {
  // eslint-disable-next-line es/no-array-prototype-includes -- detection
  return !Array(1).includes();
});

// `Array.prototype.includes` method
// https://tc39.es/ecma262/#sec-array.prototype.includes
$$2A({ target: 'Array', proto: true, forced: BROKEN_ON_SPARSE }, {
  includes: function includes(el /* , fromIndex = 0 */) {
    var fromIndex = arguments.length > 1 ? arguments[1] : undefined;
    return $includes$1(this, el, fromIndex);
  }
});

// Add 'includes' to the Array unscopables list
// (https://tc39.es/ecma262/#sec-array.prototype-@@unscopables).
addToUnscopables$6('includes');
5950
// Empty object for the `es.array.index-of` module; nothing in this section fills it.
var es_array_indexOf = {};

'use strict';
/* eslint-disable es/no-array-prototype-indexof -- required for testing */
var $$2z = _export;
var uncurryThis$$ = functionUncurryThisClause;
var $indexOf$1 = arrayIncludes.indexOf;
var arrayMethodIsStrict$6 = arrayMethodIsStrict$9;
var nativeIndexOf = uncurryThis$$([].indexOf);
// Detects engines whose native `indexOf([1], 1, -0)` returns -0 instead of +0.
var NEGATIVE_ZERO$1 = !!nativeIndexOf && 1 / nativeIndexOf([1], 1, -0) < 0;
var FORCED$z = NEGATIVE_ZERO$1 || !arrayMethodIsStrict$6('indexOf');

// `Array.prototype.indexOf` method
// https://tc39.es/ecma262/#sec-array.prototype.indexof
$$2z({ target: 'Array', proto: true, forced: FORCED$z }, {
  indexOf: function indexOf(searchElement /* , fromIndex = 0 */) {
    var fromIndex = arguments.length > 1 ? arguments[1] : undefined;
    if (NEGATIVE_ZERO$1) {
      // `|| 0` maps a -0 result to +0; -1 and positive indices pass through.
      return nativeIndexOf(this, searchElement, fromIndex) || 0;
    }
    return $indexOf$1(this, searchElement, fromIndex);
  }
});
5977
// Empty object for the `es.array.is-array` module; nothing in this section fills it.
var es_array_isArray = {};

var $$2y = _export;
var isArray$4 = isArray$9;

// `Array.isArray` method (static, no `forced` flag)
// https://tc39.es/ecma262/#sec-array.isarray
$$2y({ target: 'Array', stat: true }, { isArray: isArray$4 });
5991
'use strict';
var fails$11 = fails$1m;
var isCallable$g = isCallable$z;
var isObject$p = isObject$z;
var create$8 = objectCreate;
var getPrototypeOf$9 = objectGetPrototypeOf$1;
var defineBuiltIn$h = defineBuiltIn$m;
var wellKnownSymbol$j = wellKnownSymbol$z;
var IS_PURE$h = isPure;
var ITERATOR$7 = wellKnownSymbol$j('iterator');
// Becomes true when the engine's array iterators lack a `next` method.
var BUGGY_SAFARI_ITERATORS$1 = false;

// `%IteratorPrototype%` object
// https://tc39.es/ecma262/#sec-%iteratorprototype%-object
var IteratorPrototype$2, PrototypeOfArrayIteratorPrototype, arrayIterator;

// Locate the native %IteratorPrototype% as the prototype of
// %ArrayIteratorPrototype%, i.e. two `getPrototypeOf` steps up from `[].keys()`.
/* eslint-disable es/no-array-prototype-keys -- safe */
if ([].keys) {
  arrayIterator = [].keys();
  // Safari 8 has buggy iterators w/o `next`
  if (!('next' in arrayIterator)) BUGGY_SAFARI_ITERATORS$1 = true; else {
    PrototypeOfArrayIteratorPrototype = getPrototypeOf$9(getPrototypeOf$9(arrayIterator));
    // Reaching plain `Object.prototype` means there is no distinct %IteratorPrototype%.
    if (PrototypeOfArrayIteratorPrototype !== Object.prototype) IteratorPrototype$2 = PrototypeOfArrayIteratorPrototype;
  }
}
// A fresh prototype object is used when none was found, or when the native
// one's @@iterator does not return its receiver. NOTE(review): presumably
// `fails$11` also reports true when the probe throws — confirm in `fails`.
var NEW_ITERATOR_PROTOTYPE = !isObject$p(IteratorPrototype$2) || fails$11(function () {
  var test = {};
  // FF44- legacy iterators case
  return IteratorPrototype$2[ITERATOR$7].call(test) !== test;
});
// In the pure (non-global-polluting) build, wrap the native prototype instead
// of mutating it.
if (NEW_ITERATOR_PROTOTYPE) IteratorPrototype$2 = {}; else if (IS_PURE$h) IteratorPrototype$2 = create$8(IteratorPrototype$2);

// `%IteratorPrototype%[@@iterator]()` method
// https://tc39.es/ecma262/#sec-%iteratorprototype%-@@iterator
if (!isCallable$g(IteratorPrototype$2[ITERATOR$7])) {
  defineBuiltIn$h(IteratorPrototype$2, ITERATOR$7, function () {
    return this;
  });
}
// Shared by the iterator factories below (see `iteratorCreateConstructor`
// and `iteratorDefine`).
var iteratorsCore = {
  IteratorPrototype: IteratorPrototype$2,
  BUGGY_SAFARI_ITERATORS: BUGGY_SAFARI_ITERATORS$1
};
var iteratorsCore$1 = /*@__PURE__*/getDefaultExportFromCjs(iteratorsCore);
6036
'use strict';
var IteratorPrototype$1 = iteratorsCore.IteratorPrototype;
var create$7 = objectCreate;
var createPropertyDescriptor$5 = createPropertyDescriptor$c;
var setToStringTag$a = setToStringTag$d;
var Iterators$2 = iterators;
var returnThis$1 = function returnThis() {
  return this;
};

/**
 * Turns `IteratorConstructor` into the iterator class for `NAME`:
 * - its `prototype` is replaced by a new object inheriting from the shared
 *   `iteratorsCore.IteratorPrototype`, with an own `next` property described
 *   by `createPropertyDescriptor$5(+!ENUMERABLE_NEXT, next)`;
 * - `setToStringTag$a` is applied with the tag `NAME + ' Iterator'`;
 * - `Iterators$2[tag]` is set to a function that returns `this`.
 * NOTE(review): `+!ENUMERABLE_NEXT` is presumably a descriptor bitmap where 1
 * means non-enumerable — confirm in createPropertyDescriptor.
 * @returns {Function} the same `IteratorConstructor`.
 */
var iteratorCreateConstructor = function iteratorCreateConstructor(IteratorConstructor, NAME, next, ENUMERABLE_NEXT) {
  var tag = NAME + ' Iterator';
  var nextDescriptor = createPropertyDescriptor$5(+!ENUMERABLE_NEXT, next);
  IteratorConstructor.prototype = create$7(IteratorPrototype$1, { next: nextDescriptor });
  setToStringTag$a(IteratorConstructor, tag, false, true);
  Iterators$2[tag] = returnThis$1;
  return IteratorConstructor;
};
var iteratorCreateConstructor$1 = /*@__PURE__*/getDefaultExportFromCjs(iteratorCreateConstructor);
6056
'use strict';
var $$2x = _export;
var call$r = functionCall;
var IS_PURE$g = isPure;
var FunctionName$1 = functionName;
var isCallable$f = isCallable$z;
var createIteratorConstructor$2 = iteratorCreateConstructor;
var getPrototypeOf$8 = objectGetPrototypeOf$1;
var setPrototypeOf$6 = objectSetPrototypeOf$1;
var setToStringTag$9 = setToStringTag$d;
var createNonEnumerableProperty$8 = createNonEnumerableProperty$f;
var defineBuiltIn$g = defineBuiltIn$m;
var wellKnownSymbol$i = wellKnownSymbol$z;
var Iterators$1 = iterators;
var IteratorsCore = iteratorsCore;
var PROPER_FUNCTION_NAME$3 = FunctionName$1.PROPER;
var CONFIGURABLE_FUNCTION_NAME$1 = FunctionName$1.CONFIGURABLE;
var IteratorPrototype = IteratorsCore.IteratorPrototype;
var BUGGY_SAFARI_ITERATORS = IteratorsCore.BUGGY_SAFARI_ITERATORS;
var ITERATOR$6 = wellKnownSymbol$i('iterator');
var KEYS = 'keys';
var VALUES = 'values';
var ENTRIES = 'entries';
var returnThis = function returnThis() {
  return this;
};

/**
 * Installs iteration support on `Iterable` (a constructor, e.g. `Array`).
 * @param {Function} Iterable - constructor whose `prototype` receives the methods.
 * @param {string} NAME - collection name, used for the toStringTag and the `_export` target.
 * @param {Function} IteratorConstructor - iterator class; `new IteratorConstructor(iterated, kind)`.
 * @param {Function} next - the iterator's `next` implementation.
 * @param {string} [DEFAULT] - which kind ('keys' | 'values' | 'entries') backs @@iterator.
 * @param {boolean} [IS_SET] - when truthy, `keys` reuses the default iterator.
 * @param {boolean} [FORCED] - define methods directly with `defineBuiltIn` instead of via `_export`.
 * @returns {Object|undefined} the `{ values, keys, entries }` methods map when `DEFAULT` is set.
 */
var iteratorDefine = function iteratorDefine(Iterable, NAME, IteratorConstructor, next, DEFAULT, IS_SET, FORCED) {
  createIteratorConstructor$2(IteratorConstructor, NAME, next);
  // Relies on `defaultIterator` / `IterablePrototype` declared further down
  // (var hoisting). On the first call, made while `defaultIterator` is being
  // initialised, `defaultIterator` is still undefined.
  var getIterationMethod = function getIterationMethod(KIND) {
    if (KIND === DEFAULT && defaultIterator) return defaultIterator;
    // Prefer a native method unless array iterators are known to be broken.
    if (!BUGGY_SAFARI_ITERATORS && KIND in IterablePrototype) return IterablePrototype[KIND];
    switch (KIND) {
      case KEYS:
        return function keys() {
          return new IteratorConstructor(this, KIND);
        };
      case VALUES:
        return function values() {
          return new IteratorConstructor(this, KIND);
        };
      case ENTRIES:
        return function entries() {
          return new IteratorConstructor(this, KIND);
        };
    }
    return function () {
      return new IteratorConstructor(this);
    };
  };
  var TO_STRING_TAG = NAME + ' Iterator';
  var INCORRECT_VALUES_NAME = false;
  var IterablePrototype = Iterable.prototype;
  // Native default iterator, also accepting the legacy '@@iterator' string key.
  var nativeIterator = IterablePrototype[ITERATOR$6] || IterablePrototype['@@iterator'] || DEFAULT && IterablePrototype[DEFAULT];
  var defaultIterator = !BUGGY_SAFARI_ITERATORS && nativeIterator || getIterationMethod(DEFAULT);
  // For arrays, `entries` is used to reach the native iterator prototype.
  var anyNativeIterator = NAME == 'Array' ? IterablePrototype.entries || nativeIterator : nativeIterator;
  var CurrentIteratorPrototype, methods, KEY;

  // fix native: make the native iterator prototype inherit from the shared
  // %IteratorPrototype% (or at least give it an @@iterator returning `this`).
  if (anyNativeIterator) {
    CurrentIteratorPrototype = getPrototypeOf$8(anyNativeIterator.call(new Iterable()));
    if (CurrentIteratorPrototype !== Object.prototype && CurrentIteratorPrototype.next) {
      if (!IS_PURE$g && getPrototypeOf$8(CurrentIteratorPrototype) !== IteratorPrototype) {
        if (setPrototypeOf$6) {
          setPrototypeOf$6(CurrentIteratorPrototype, IteratorPrototype);
        } else if (!isCallable$f(CurrentIteratorPrototype[ITERATOR$6])) {
          defineBuiltIn$g(CurrentIteratorPrototype, ITERATOR$6, returnThis);
        }
      }
      // Set @@toStringTag to native iterators
      setToStringTag$9(CurrentIteratorPrototype, TO_STRING_TAG, true, true);
      if (IS_PURE$g) Iterators$1[TO_STRING_TAG] = returnThis;
    }
  }

  // fix Array.prototype.{ values, @@iterator }.name in V8 / FF
  if (PROPER_FUNCTION_NAME$3 && DEFAULT == VALUES && nativeIterator && nativeIterator.name !== VALUES) {
    if (!IS_PURE$g && CONFIGURABLE_FUNCTION_NAME$1) {
      // NOTE(review): this defines `name` on `IterablePrototype`, not on
      // `nativeIterator` whose `.name` was checked above — confirm intended.
      createNonEnumerableProperty$8(IterablePrototype, 'name', VALUES);
    } else {
      // Fall back to a correctly-named wrapper that delegates to the native iterator.
      INCORRECT_VALUES_NAME = true;
      defaultIterator = function values() {
        return call$r(nativeIterator, this);
      };
    }
  }

  // export additional methods
  if (DEFAULT) {
    methods = {
      values: getIterationMethod(VALUES),
      keys: IS_SET ? defaultIterator : getIterationMethod(KEYS),
      entries: getIterationMethod(ENTRIES)
    };
    // FORCED: define missing (or known-broken) methods directly; otherwise go
    // through `_export`, forcing only when iterators are broken or misnamed.
    if (FORCED) for (KEY in methods) {
      if (BUGGY_SAFARI_ITERATORS || INCORRECT_VALUES_NAME || !(KEY in IterablePrototype)) {
        defineBuiltIn$g(IterablePrototype, KEY, methods[KEY]);
      }
    } else $$2x({
      target: NAME,
      proto: true,
      forced: BUGGY_SAFARI_ITERATORS || INCORRECT_VALUES_NAME
    }, methods);
  }

  // define iterator
  if ((!IS_PURE$g || FORCED) && IterablePrototype[ITERATOR$6] !== defaultIterator) {
    defineBuiltIn$g(IterablePrototype, ITERATOR$6, defaultIterator, {
      name: DEFAULT
    });
  }
  // Registry entry keyed by collection name (e.g. `Iterators.Array`).
  Iterators$1[NAME] = defaultIterator;
  return methods;
};
var iteratorDefine$1 = /*@__PURE__*/getDefaultExportFromCjs(iteratorDefine);
6171
/**
 * `CreateIterResultObject` abstract operation
 * https://tc39.es/ecma262/#sec-createiterresultobject
 * Builds a fresh `{ value, done }` iterator-result record. A literal is used
 * so that the properties are defined as own data properties.
 * @param {*} value - the iteration value.
 * @param {boolean} done - whether iteration has finished.
 * @returns {{value: *, done: boolean}}
 */
var createIterResultObject$4 = function createIterResultObject(value, done) {
  return { value: value, done: done };
};
var createIterResultObject$5 = /*@__PURE__*/getDefaultExportFromCjs(createIterResultObject$4);
6181
'use strict';
var toIndexedObject$c = toIndexedObject$j;
var addToUnscopables$5 = addToUnscopables$e;
var Iterators = iterators;
var InternalStateModule$b = internalState;
var defineProperty$7 = objectDefineProperty.f;
var defineIterator$2 = iteratorDefine;
var createIterResultObject$3 = createIterResultObject$4;
var IS_PURE$f = isPure;
var DESCRIPTORS$x = descriptors;
var ARRAY_ITERATOR = 'Array Iterator';
var setInternalState$a = InternalStateModule$b.set;
// Accessor for iterator state tagged with ARRAY_ITERATOR.
var getInternalState$8 = InternalStateModule$b.getterFor(ARRAY_ITERATOR);

// `Array.prototype.entries` method
// https://tc39.es/ecma262/#sec-array.prototype.entries
// `Array.prototype.keys` method
// https://tc39.es/ecma262/#sec-array.prototype.keys
// `Array.prototype.values` method
// https://tc39.es/ecma262/#sec-array.prototype.values
// `Array.prototype[@@iterator]` method
// https://tc39.es/ecma262/#sec-array.prototype-@@iterator
// `CreateArrayIterator` internal method
// https://tc39.es/ecma262/#sec-createarrayiterator
// `es_array_iterator` holds the methods map returned by `iteratorDefine`;
// 'values' is the default kind behind @@iterator.
var es_array_iterator = defineIterator$2(Array, 'Array', function (iterated, kind) {
  setInternalState$a(this, {
    type: ARRAY_ITERATOR,
    target: toIndexedObject$c(iterated), // target
    index: 0, // next index
    kind: kind // kind
  });
},
// `%ArrayIteratorPrototype%.next` method
// https://tc39.es/ecma262/#sec-%arrayiteratorprototype%.next
function () {
  var state = getInternalState$8(this);
  var target = state.target;
  var kind = state.kind;
  var index = state.index++;
  // Once exhausted, drop the target so later calls keep reporting done.
  if (!target || index >= target.length) {
    state.target = undefined;
    return createIterResultObject$3(undefined, true);
  }
  if (kind == 'keys') return createIterResultObject$3(index, false);
  if (kind == 'values') return createIterResultObject$3(target[index], false);
  // 'entries' (and any other kind): an [index, value] pair.
  return createIterResultObject$3([index, target[index]], false);
}, 'values');

// argumentsList[@@iterator] is %ArrayProto_values%
// https://tc39.es/ecma262/#sec-createunmappedargumentsobject
// https://tc39.es/ecma262/#sec-createmappedargumentsobject
// `Iterators.Array` was set to the default iterator by `iteratorDefine`.
var values = Iterators.Arguments = Iterators.Array;

// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables$5('keys');
addToUnscopables$5('values');
addToUnscopables$5('entries');

// V8 ~ Chrome 45- bug: the default iterator's `name` is not 'values'.
// Best-effort patch; failures (e.g. a non-configurable `name`) are ignored.
if (!IS_PURE$f && DESCRIPTORS$x && values.name !== 'values') try {
  defineProperty$7(values, 'name', {
    value: 'values'
  });
} catch (error) {/* empty */}
var es_array_iterator$1 = /*@__PURE__*/getDefaultExportFromCjs(es_array_iterator);
6248
// Empty object for the `es.array.join` module; nothing in this section fills it.
var es_array_join = {};

'use strict';
var $$2w = _export;
var uncurryThis$_ = functionUncurryThis;
var IndexedObject$2 = indexedObject;
var toIndexedObject$b = toIndexedObject$j;
var arrayMethodIsStrict$5 = arrayMethodIsStrict$9;
var nativeJoin = uncurryThis$_([].join);
// NOTE(review): `indexedObject` differing from `Object` presumably signals an
// engine whose strings are not index-accessible (ES3); confirm in indexedObject.
var ES3_STRINGS = IndexedObject$2 !== Object;
var FORCED$y = ES3_STRINGS || !arrayMethodIsStrict$5('join', ',');

// `Array.prototype.join` method
// https://tc39.es/ecma262/#sec-array.prototype.join
$$2w({ target: 'Array', proto: true, forced: FORCED$y }, {
  join: function join(separator) {
    var sep = separator === undefined ? ',' : separator;
    return nativeJoin(toIndexedObject$b(this), sep);
  }
});
6272
// Empty object for the `es.array.last-index-of` module; nothing in this section fills it.
var es_array_lastIndexOf = {};

'use strict';
/* eslint-disable es/no-array-prototype-lastindexof -- safe */
var apply$7 = functionApply$1;
var toIndexedObject$a = toIndexedObject$j;
var toIntegerOrInfinity$g = toIntegerOrInfinity$l;
var lengthOfArrayLike$f = lengthOfArrayLike$t;
var arrayMethodIsStrict$4 = arrayMethodIsStrict$9;
var min$b = Math.min;
var $lastIndexOf$1 = [].lastIndexOf;
// Detects engines whose native `[1].lastIndexOf(1, -0)` returns -0 instead of +0.
var NEGATIVE_ZERO = !!$lastIndexOf$1 && 1 / [1].lastIndexOf(1, -0) < 0;
var STRICT_METHOD$2 = arrayMethodIsStrict$4('lastIndexOf');
var FORCED$x = NEGATIVE_ZERO || !STRICT_METHOD$2;

// `Array.prototype.lastIndexOf` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.lastindexof
// The native method is kept unless it is flagged above.
var arrayLastIndexOf = !FORCED$x ? $lastIndexOf$1 : function lastIndexOf(searchElement /* , fromIndex = @[*-1] */) {
  // Only the -0 quirk: delegate to native and normalise -0 to +0.
  if (NEGATIVE_ZERO) return apply$7($lastIndexOf$1, this, arguments) || 0;
  var O = toIndexedObject$a(this);
  var length = lengthOfArrayLike$f(O);
  var k = length - 1;
  if (arguments.length > 1) k = min$b(k, toIntegerOrInfinity$g(arguments[1]));
  // A negative fromIndex counts back from the end.
  if (k < 0) k += length;
  while (k >= 0) {
    // Holes are skipped; `k || 0` turns a -0 index into +0.
    if (k in O && O[k] === searchElement) return k || 0;
    k--;
  }
  return -1;
};
var arrayLastIndexOf$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayLastIndexOf);
6302
var $$2v = _export;
var lastIndexOf = arrayLastIndexOf;

// `Array.prototype.lastIndexOf` method
// https://tc39.es/ecma262/#sec-array.prototype.lastindexof
// `forced` is set only when the selected implementation is not the native one.
// eslint-disable-next-line es/no-array-prototype-lastindexof -- required for testing
$$2v({ target: 'Array', proto: true, forced: lastIndexOf !== [].lastIndexOf }, {
  lastIndexOf: lastIndexOf
});
6316
// Empty object for the `es.array.map` module; nothing in this section fills it.
var es_array_map = {};

'use strict';
var $$2u = _export;
var $map$1 = arrayIteration.map;
var arrayMethodHasSpeciesSupport$2 = arrayMethodHasSpeciesSupport$5;
// Evaluated before the registration below.
// NOTE(review): presumably this reflects the engine's own `map` @@species
// behaviour; confirm in arrayMethodHasSpeciesSupport.
var HAS_SPECIES_SUPPORT$2 = arrayMethodHasSpeciesSupport$2('map');

/*
 * `Array.prototype.map` with @@species support.
 * https://tc39.es/ecma262/#sec-array.prototype.map
 * `forced` is set when the species probe above failed.
 */
$$2u({ target: 'Array', proto: true, forced: !HAS_SPECIES_SUPPORT$2 }, {
  map: function map(callbackfn /* , thisArg */) {
    var thisArg;
    if (arguments.length > 1) thisArg = arguments[1];
    return $map$1(this, callbackfn, thisArg);
  }
});
6337
// Empty object for the `es.array.of` module; nothing in this section fills it.
var es_array_of = {};

'use strict';
var $$2t = _export;
var fails$10 = fails$1m;
var isConstructor$3 = isConstructor$6;
var createProperty$5 = createProperty$9;
var $Array$6 = Array;
// WebKit's `Array.of` isn't generic: invoked on another constructor it should
// produce an instance of that constructor.
var ISNT_GENERIC = fails$10(function () {
  function F() {/* empty */}
  // eslint-disable-next-line es/no-array-of -- safe
  return !($Array$6.of.call(F) instanceof F);
});

// `Array.of` method
// https://tc39.es/ecma262/#sec-array.of
$$2t({ target: 'Array', stat: true, forced: ISNT_GENERIC }, {
  // Declared with no formal parameters so that `of.length === 0`.
  of: function of(/* ...args */) {
    var argumentsLength = arguments.length;
    // Use the receiver as the result constructor when it is one; otherwise Array.
    var Constructor = isConstructor$3(this) ? this : $Array$6;
    var result = new Constructor(argumentsLength);
    for (var index = 0; index < argumentsLength; index++) {
      createProperty$5(result, index, arguments[index]);
    }
    result.length = argumentsLength;
    return result;
  }
});
6370
// Empty object for the `es.array.push` module; nothing in this section fills it.
var es_array_push = {};

'use strict';
var DESCRIPTORS$w = descriptors;
var isArray$3 = isArray$9;
var $TypeError$e = TypeError;
// eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
var getOwnPropertyDescriptor$8 = Object.getOwnPropertyDescriptor;

// Safari < 13 does not throw an error in this case
// True when assigning to a non-writable array `length` does NOT throw a
// TypeError in strict code. The IIFE returns:
// - true  when strict mode is unavailable (`this` is defined), so the flag is false;
// - true  when the assignment throws a TypeError (correct engine), so the flag is false;
// - undefined/false when the assignment is silent or throws something else, so the flag is true.
var SILENT_ON_NON_WRITABLE_LENGTH_SET = DESCRIPTORS$w && !function () {
  // makes no sense without proper strict mode support
  if (this !== undefined) return true;
  try {
    // eslint-disable-next-line es/no-object-defineproperty -- safe
    Object.defineProperty([], 'length', {
      writable: false
    }).length = 1;
  } catch (error) {
    return error instanceof TypeError;
  }
}();

// Sets `O.length = length` and returns the assigned value. On engines that
// fail silently, it first throws a TypeError explicitly for arrays whose
// `length` is non-writable.
var arraySetLength = SILENT_ON_NON_WRITABLE_LENGTH_SET ? function (O, length) {
  if (isArray$3(O) && !getOwnPropertyDescriptor$8(O, 'length').writable) {
    throw $TypeError$e('Cannot set read only .length');
  }
  return O.length = length;
} : function (O, length) {
  return O.length = length;
};
var arraySetLength$1 = /*@__PURE__*/getDefaultExportFromCjs(arraySetLength);
6402
'use strict';
var $$2s = _export;
var toObject$g = toObject$t;
var lengthOfArrayLike$e = lengthOfArrayLike$t;
var setArrayLength$2 = arraySetLength;
var doesNotExceedSafeInteger$3 = doesNotExceedSafeInteger$6;
var fails$$ = fails$1m;

// A compliant `push` must not wrap a `length` of 2 ** 32 back to 32 bits.
var INCORRECT_TO_LENGTH = fails$$(function () {
  return [].push.call({ length: 0x100000000 }, 1) !== 4294967297;
});

// V8 and Safari <= 15.4, FF < 23 throws InternalError
// https://bugs.chromium.org/p/v8/issues/detail?id=12681
// Returns true only when pushing onto an array with a non-writable `length`
// raises a TypeError; if nothing is thrown the result is undefined.
var properErrorOnNonWritableLength$1 = function properErrorOnNonWritableLength() {
  try {
    // eslint-disable-next-line es/no-object-defineproperty -- safe
    Object.defineProperty([], 'length', { writable: false }).push();
  } catch (error) {
    return error instanceof TypeError;
  }
};
var FORCED$w = INCORRECT_TO_LENGTH || !properErrorOnNonWritableLength$1();

// `Array.prototype.push` method
// https://tc39.es/ecma262/#sec-array.prototype.push
$$2s({ target: 'Array', proto: true, arity: 1, forced: FORCED$w }, {
  // eslint-disable-next-line no-unused-vars -- required for `.length`
  push: function push(item) {
    var O = toObject$g(this);
    var len = lengthOfArrayLike$e(O);
    var argCount = arguments.length;
    var newLength = len + argCount;
    doesNotExceedSafeInteger$3(newLength);
    for (var k = 0; k < argCount; k++) O[len + k] = arguments[k];
    setArrayLength$2(O, newLength);
    return newLength;
  }
});
6452
// Empty object for the `es.array.reduce` module; nothing in this section fills it.
var es_array_reduce = {};

var aCallable$f = aCallable$l;
var toObject$f = toObject$t;
var IndexedObject$1 = indexedObject;
var lengthOfArrayLike$d = lengthOfArrayLike$t;
var $TypeError$d = TypeError;

/**
 * `Array.prototype.{ reduce, reduceRight }` methods implementation.
 * https://tc39.es/ecma262/#sec-array.prototype.reduce
 *
 * @param {boolean} IS_RIGHT - iterate from the last index down (reduceRight)
 *   instead of from 0 up (reduce).
 * @returns {Function} `(that, callbackfn, argumentsLength, memo)` where `that`
 *   is the receiver, `argumentsLength` is the caller's `arguments.length` (an
 *   initial value is present only when it is >= 2) and `memo` is that initial
 *   value. Throws a TypeError when `callbackfn` is not callable, or when there
 *   is no initial value and no element is present.
 */
var createMethod$4 = function createMethod(IS_RIGHT) {
  var REDUCE_EMPTY = 'Reduce of empty array with no initial value';
  return function (that, callbackfn, argumentsLength, memo) {
    // Spec order: ToObject, LengthOfArrayLike, then the IsCallable check.
    var O = toObject$f(that);
    var self = IndexedObject$1(O);
    var length = lengthOfArrayLike$d(O);
    aCallable$f(callbackfn);
    // An empty receiver with no initial value must throw, even if it happens
    // to own a property at the starting index (e.g. `{ 0: 'x', length: 0 }`).
    if (length === 0 && argumentsLength < 2) throw new $TypeError$d(REDUCE_EMPTY);
    var index = IS_RIGHT ? length - 1 : 0;
    var i = IS_RIGHT ? -1 : 1;
    // No initial value: seed `memo` with the first present element.
    if (argumentsLength < 2) {
      while (true) {
        if (index in self) {
          memo = self[index];
          index += i;
          break;
        }
        index += i;
        if (IS_RIGHT ? index < 0 : length <= index) {
          throw new $TypeError$d(REDUCE_EMPTY);
        }
      }
    }
    // Holes are skipped.
    for (; IS_RIGHT ? index >= 0 : length > index; index += i) {
      if (index in self) {
        memo = callbackfn(memo, self[index], index, O);
      }
    }
    return memo;
  };
};
var arrayReduce = {
  // `Array.prototype.reduce` method
  // https://tc39.es/ecma262/#sec-array.prototype.reduce
  left: createMethod$4(false),
  // `Array.prototype.reduceRight` method
  // https://tc39.es/ecma262/#sec-array.prototype.reduceright
  right: createMethod$4(true)
};
var arrayReduce$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayReduce);
6496
var classof$f = classofRaw$2;
// `true` when a global `process` exists and `classofRaw$2(process)` reports the tag
// 'process' (presumably meaning the code runs under Node.js); `false` otherwise.
var engineIsNode = typeof process == 'undefined' ? false : classof$f(process) == 'process';
var engineIsNode$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsNode);
6500
'use strict';
var $$2r = _export;
var $reduce$1 = arrayReduce.left;
var arrayMethodIsStrict$3 = arrayMethodIsStrict$9;
var CHROME_VERSION$1 = engineV8Version;
var IS_NODE$8 = engineIsNode;

// Chrome 80-82 has a critical bug
// https://bugs.chromium.org/p/chromium/issues/detail?id=1049982
// (Node.js is exempt even when its V8 version falls inside that range.)
var CHROME_BUG$1 = !IS_NODE$8 && CHROME_VERSION$1 > 79 && CHROME_VERSION$1 < 83;
var FORCED$v = CHROME_BUG$1 || !arrayMethodIsStrict$3('reduce');

// `Array.prototype.reduce` method
// https://tc39.es/ecma262/#sec-array.prototype.reduce
$$2r({ target: 'Array', proto: true, forced: FORCED$v }, {
  reduce: function reduce(callbackfn /* , initialValue */) {
    // The raw argument count tells the shared helper whether an initial value was given.
    var argc = arguments.length;
    var initialValue = argc > 1 ? arguments[1] : undefined;
    return $reduce$1(this, callbackfn, argc, initialValue);
  }
});
6525
var es_array_reduceRight = {};

'use strict';
var $$2q = _export;
var $reduceRight$1 = arrayReduce.right;
var arrayMethodIsStrict$2 = arrayMethodIsStrict$9;
var CHROME_VERSION = engineV8Version;
var IS_NODE$7 = engineIsNode;

// Chrome 80-82 has a critical bug
// https://bugs.chromium.org/p/chromium/issues/detail?id=1049982
// (Node.js is exempt even when its V8 version falls inside that range.)
var CHROME_BUG = !IS_NODE$7 && CHROME_VERSION > 79 && CHROME_VERSION < 83;
var FORCED$u = CHROME_BUG || !arrayMethodIsStrict$2('reduceRight');

// `Array.prototype.reduceRight` method
// https://tc39.es/ecma262/#sec-array.prototype.reduceright
$$2q({ target: 'Array', proto: true, forced: FORCED$u }, {
  reduceRight: function reduceRight(callbackfn /* , initialValue */) {
    // The raw argument count tells the shared helper whether an initial value was given.
    var argc = arguments.length;
    var initialValue = argc > 1 ? arguments[1] : undefined;
    return $reduceRight$1(this, callbackfn, argc, initialValue);
  }
});
6551
var es_array_reverse = {};

'use strict';
var $$2p = _export;
var uncurryThis$Z = functionUncurryThis;
var isArray$2 = isArray$9;
// `[].reverse` wrapped by functionUncurryThis; called below as `nativeReverse(this)`.
var nativeReverse = uncurryThis$Z([].reverse);
// Probe array for the `forced` check below; that check reverses it in place.
var test$1 = [1, 2];

// `Array.prototype.reverse` method
// https://tc39.es/ecma262/#sec-array.prototype.reverse
// fix for Safari 12.0 bug
// https://bugs.webkit.org/show_bug.cgi?id=188794
$$2p({
  target: 'Array',
  proto: true,
  // `String(test$1)` is evaluated before `test$1.reverse()` runs, so a working engine
  // compares '1,2' with '2,1'; the polyfill is forced only when the two strings match.
  forced: String(test$1) === String(test$1.reverse())
}, {
  reverse: function reverse() {
    // Writing `length` back to itself on real arrays is the workaround for the bug linked
    // above; it has to happen before delegating to the native method.
    // eslint-disable-next-line no-self-assign -- dirty hack
    if (isArray$2(this)) this.length = this.length;
    return nativeReverse(this);
  }
});
6576
var es_array_slice = {};

'use strict';
var $$2o = _export;
var isArray$1 = isArray$9;
var isConstructor$2 = isConstructor$6;
var isObject$o = isObject$z;
var toAbsoluteIndex$5 = toAbsoluteIndex$a;
var lengthOfArrayLike$c = lengthOfArrayLike$t;
var toIndexedObject$9 = toIndexedObject$j;
var createProperty$4 = createProperty$9;
var wellKnownSymbol$h = wellKnownSymbol$z;
var arrayMethodHasSpeciesSupport$1 = arrayMethodHasSpeciesSupport$5;
var nativeSlice = arraySlice$a;
// The polyfill is forced only when this check reports no @@species support for `slice`.
var HAS_SPECIES_SUPPORT$1 = arrayMethodHasSpeciesSupport$1('slice');
var SPECIES$4 = wellKnownSymbol$h('species');
var $Array$5 = Array;
var max$9 = Math.max;

// `Array.prototype.slice` method
// https://tc39.es/ecma262/#sec-array.prototype.slice
// fallback for not array-like ES3 strings and DOM objects
$$2o({
  target: 'Array',
  proto: true,
  forced: !HAS_SPECIES_SUPPORT$1
}, {
  slice: function slice(start, end) {
    var O = toIndexedObject$9(this);
    var length = lengthOfArrayLike$c(O);
    // Both bounds are resolved against `length`; an omitted `end` means `length`.
    var k = toAbsoluteIndex$5(start, length);
    var fin = toAbsoluteIndex$5(end === undefined ? length : end, length);
    // inline `ArraySpeciesCreate` for usage native `Array#slice` where it's possible
    var Constructor, result, n;
    if (isArray$1(O)) {
      Constructor = O.constructor;
      // cross-realm fallback
      // (a constructor that is `Array` or whose prototype is itself an array, e.g. another
      // realm's `Array`, is treated as "no species override")
      if (isConstructor$2(Constructor) && (Constructor === $Array$5 || isArray$1(Constructor.prototype))) {
        Constructor = undefined;
      } else if (isObject$o(Constructor)) {
        Constructor = Constructor[SPECIES$4];
        if (Constructor === null) Constructor = undefined;
      }
      // A plain Array result is wanted: let the native helper do the copy.
      if (Constructor === $Array$5 || Constructor === undefined) {
        return nativeSlice(O, k, fin);
      }
    }
    // Generic path: allocate via the species constructor (or Array for non-arrays), copy only
    // the indices present in O so holes stay holes, then set the final length explicitly.
    result = new (Constructor === undefined ? $Array$5 : Constructor)(max$9(fin - k, 0));
    for (n = 0; k < fin; k++, n++) if (k in O) createProperty$4(result, n, O[k]);
    result.length = n;
    return result;
  }
});
6630
var es_array_some = {};

'use strict';
var $$2n = _export;
var $some$1 = arrayIteration.some;
var arrayMethodIsStrict$1 = arrayMethodIsStrict$9;
var STRICT_METHOD$1 = arrayMethodIsStrict$1('some');

// `Array.prototype.some` method
// https://tc39.es/ecma262/#sec-array.prototype.some
// Forced only when `arrayMethodIsStrict('some')` flags the native implementation.
$$2n({ target: 'Array', proto: true, forced: !STRICT_METHOD$1 }, {
  some: function some(callbackfn /* , thisArg */) {
    var thisArg = arguments.length > 1 ? arguments[1] : undefined;
    return $some$1(this, callbackfn, thisArg);
  }
});
6650
var es_array_sort = {};

var arraySlice$8 = arraySliceSimple;
var floor$c = Math.floor;

// Stable merge sort used by the `Array#sort` polyfill. Sorts `array` in place and returns it.
// Inputs shorter than 8 elements go straight to insertion sort.
// `comparefn(a, b) > 0` means `a` must come after `b`; elements that compare equal keep
// their original relative order.
var mergeSort = function mergeSort(array, comparefn) {
  var length = array.length;
  if (length < 8) return insertionSort(array, comparefn);
  var middle = floor$c(length / 2);
  var sortedLeft = mergeSort(arraySlice$8(array, 0, middle), comparefn);
  var sortedRight = mergeSort(arraySlice$8(array, middle), comparefn);
  return merge(array, sortedLeft, sortedRight, comparefn);
};

// In-place insertion sort: shifts larger predecessors one slot right, then drops the
// current element into the gap.
var insertionSort = function insertionSort(array, comparefn) {
  var length = array.length;
  var element, j;
  for (var i = 1; i < length; i++) {
    element = array[i];
    j = i;
    while (j > 0 && comparefn(array[j - 1], element) > 0) {
      array[j] = array[j - 1];
      j--;
    }
    if (j !== i) array[j] = element;
  }
  return array;
};

// Merges two sorted runs into `array`, starting at index 0. Ties take from `left`,
// which is what keeps the sort stable.
var merge = function merge(array, left, right, comparefn) {
  var llength = left.length;
  var rlength = right.length;
  var lindex = 0;
  var rindex = 0;
  var takeLeft;
  while (lindex < llength || rindex < rlength) {
    if (lindex >= llength) {
      takeLeft = false;
    } else if (rindex >= rlength) {
      takeLeft = true;
    } else {
      takeLeft = comparefn(left[lindex], right[rindex]) <= 0;
    }
    array[lindex + rindex] = takeLeft ? left[lindex++] : right[rindex++];
  }
  return array;
};
var arraySort$1 = mergeSort;
var arraySort$2 = /*@__PURE__*/getDefaultExportFromCjs(arraySort$1);
6686
// Firefox major version parsed from the user agent, or `false` when no Firefox token is present.
var userAgent$5 = engineUserAgent;
var firefox = userAgent$5.match(/firefox\/(\d+)/i);
var engineFfVersion = firefox ? +firefox[1] : false;
var engineFfVersion$1 = /*@__PURE__*/getDefaultExportFromCjs(engineFfVersion);

// `true` when the user agent string contains `MSIE` or `Trident`.
var UA = engineUserAgent;
var engineIsIeOrEdge = /MSIE|Trident/.test(UA);
var engineIsIeOrEdge$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsIeOrEdge);

// AppleWebKit major version parsed from the user agent, or `false` when absent.
var userAgent$4 = engineUserAgent;
var webkit = userAgent$4.match(/AppleWebKit\/(\d+)\./);
var engineWebkitVersion = webkit ? +webkit[1] : false;
var engineWebkitVersion$1 = /*@__PURE__*/getDefaultExportFromCjs(engineWebkitVersion);
6700
'use strict';
var $$2m = _export;
var uncurryThis$Y = functionUncurryThis;
var aCallable$e = aCallable$l;
var toObject$e = toObject$t;
var lengthOfArrayLike$b = lengthOfArrayLike$t;
var deletePropertyOrThrow$2 = deletePropertyOrThrow$4;
var toString$s = toString$x;
var fails$_ = fails$1m;
var internalSort$1 = arraySort$1;
var arrayMethodIsStrict = arrayMethodIsStrict$9;
var FF$1 = engineFfVersion;
var IE_OR_EDGE$1 = engineIsIeOrEdge;
var V8$2 = engineV8Version;
var WEBKIT$2 = engineWebkitVersion;
// Probe array shared by the feature checks below. The stability check pushes 517 records
// into it, so these checks must keep their current order.
var test = [];
var nativeSort$1 = uncurryThis$Y(test.sort);
var push$9 = uncurryThis$Y(test.push);

// IE8-
var FAILS_ON_UNDEFINED = fails$_(function () {
  test.sort(undefined);
});
// V8 bug
var FAILS_ON_NULL = fails$_(function () {
  test.sort(null);
});
// Old WebKit
var STRICT_METHOD = arrayMethodIsStrict('sort');
// The callback returns a truthy value for engines treated as having an unstable sort;
// STABLE_SORT$1 is the negation of what `fails$_` reports for it.
var STABLE_SORT$1 = !fails$_(function () {
  // feature detection can be too slow, so check engines versions
  if (V8$2) return V8$2 < 70;
  if (FF$1 && FF$1 > 3) return;
  if (IE_OR_EDGE$1) return true;
  if (WEBKIT$2) return WEBKIT$2 < 603;
  var result = '';
  var code, chr, value, index;

  // generate an array with more than 512 elements (Chakra and old V8 fail only in this case):
  // 11 letters 'A'..'K', 47 records each, every letter mapped to a sort key of 2, 3 or 4
  for (code = 65; code < 76; code++) {
    chr = String.fromCharCode(code);
    switch (code) {
      case 66:
      case 69:
      case 70:
      case 72:
        value = 3;
        break;
      case 68:
      case 71:
        value = 4;
        break;
      default:
        value = 2;
    }
    for (index = 0; index < 47; index++) {
      test.push({
        k: chr + index,
        v: value
      });
    }
  }
  // Sort by `v` descending; a stable sort keeps the letters in insertion order within each key.
  test.sort(function (a, b) {
    return b.v - a.v;
  });
  // Collapse the sorted records into the sequence of distinct consecutive first letters.
  for (index = 0; index < test.length; index++) {
    chr = test[index].k.charAt(0);
    if (result.charAt(result.length - 1) !== chr) result += chr;
  }
  return result !== 'DGBEFHACIJK';
});
// `forced` is set when any of the checks above indicates a problem with the native sort.
var FORCED$t = FAILS_ON_UNDEFINED || !FAILS_ON_NULL || !STRICT_METHOD || !STABLE_SORT$1;
// Wraps the user comparator:
// - `undefined` elements always sort after everything else;
// - with a comparator, its result is coerced to a number and NaN/falsy results become 0;
// - without one, values are ordered by their string form and equality is never reported.
var getSortCompare$1 = function getSortCompare(comparefn) {
  return function (x, y) {
    if (y === undefined) return -1;
    if (x === undefined) return 1;
    if (comparefn !== undefined) return +comparefn(x, y) || 0;
    return toString$s(x) > toString$s(y) ? 1 : -1;
  };
};

// `Array.prototype.sort` method
// https://tc39.es/ecma262/#sec-array.prototype.sort
$$2m({
  target: 'Array',
  proto: true,
  forced: FORCED$t
}, {
  sort: function sort(comparefn) {
    if (comparefn !== undefined) aCallable$e(comparefn);
    var array = toObject$e(this);
    // The native sort is reused whenever it passed the stability check.
    if (STABLE_SORT$1) return comparefn === undefined ? nativeSort$1(array) : nativeSort$1(array, comparefn);
    // Fallback path:
    // 1. gather the present elements (skipping holes);
    // 2. merge-sort them;
    // 3. write them back starting at index 0;
    // 4. delete the leftover trailing indices.
    var items = [];
    var arrayLength = lengthOfArrayLike$b(array);
    var itemsLength, index;
    for (index = 0; index < arrayLength; index++) {
      if (index in array) push$9(items, array[index]);
    }
    internalSort$1(items, getSortCompare$1(comparefn));
    itemsLength = lengthOfArrayLike$b(items);
    index = 0;
    while (index < itemsLength) array[index] = items[index++];
    while (index < arrayLength) deletePropertyOrThrow$2(array, index++);
    return array;
  }
});
6807
var es_array_species = {};

'use strict';
var getBuiltIn$b = getBuiltIn$m;
var defineBuiltInAccessor$e = defineBuiltInAccessor$h;
var wellKnownSymbol$g = wellKnownSymbol$z;
var DESCRIPTORS$v = descriptors;
var SPECIES$3 = wellKnownSymbol$g('species');

// Installs a configurable `@@species` getter that returns `this` on the built-in
// constructor named `CONSTRUCTOR_NAME`. Does nothing in any of these cases:
// - `descriptors` is falsy;
// - the constructor cannot be found;
// - the constructor already exposes a truthy `@@species` value.
var setSpecies$6 = function setSpecies(CONSTRUCTOR_NAME) {
  var Constructor = getBuiltIn$b(CONSTRUCTOR_NAME);
  if (!DESCRIPTORS$v || !Constructor || Constructor[SPECIES$3]) return;
  defineBuiltInAccessor$e(Constructor, SPECIES$3, {
    configurable: true,
    get: function get() {
      return this;
    }
  });
};
var setSpecies$7 = /*@__PURE__*/getDefaultExportFromCjs(setSpecies$6);
6828
var setSpecies$5 = setSpecies$6;

// `Array[@@species]` getter
// https://tc39.es/ecma262/#sec-get-array-@@species
// setSpecies$6 leaves `Array` untouched when it already exposes a truthy @@species value.
setSpecies$5('Array');
6834
var es_array_splice = {};

'use strict';
var $$2l = _export;
var toObject$d = toObject$t;
var toAbsoluteIndex$4 = toAbsoluteIndex$a;
var toIntegerOrInfinity$f = toIntegerOrInfinity$l;
var lengthOfArrayLike$a = lengthOfArrayLike$t;
var setArrayLength$1 = arraySetLength;
var doesNotExceedSafeInteger$2 = doesNotExceedSafeInteger$6;
var arraySpeciesCreate = arraySpeciesCreate$5;
var createProperty$3 = createProperty$9;
var deletePropertyOrThrow$1 = deletePropertyOrThrow$4;
var arrayMethodHasSpeciesSupport = arrayMethodHasSpeciesSupport$5;
var HAS_SPECIES_SUPPORT = arrayMethodHasSpeciesSupport('splice');
var max$8 = Math.max;
var min$a = Math.min;

// `Array.prototype.splice` method
// https://tc39.es/ecma262/#sec-array.prototype.splice
// with adding support of @@species
$$2l({
  target: 'Array',
  proto: true,
  forced: !HAS_SPECIES_SUPPORT
}, {
  splice: function splice(start, deleteCount /* , ...items */) {
    var O = toObject$d(this);
    var len = lengthOfArrayLike$a(O);
    var actualStart = toAbsoluteIndex$4(start, len);
    var argumentsLength = arguments.length;
    var insertCount, actualDeleteCount, A, k, from, to;
    // Argument handling:
    // - no arguments: remove nothing;
    // - only `start`: remove everything from `start` on;
    // - otherwise: clamp `deleteCount` to [0, len - actualStart] and insert the remaining arguments.
    if (argumentsLength === 0) {
      insertCount = actualDeleteCount = 0;
    } else if (argumentsLength === 1) {
      insertCount = 0;
      actualDeleteCount = len - actualStart;
    } else {
      insertCount = argumentsLength - 2;
      actualDeleteCount = min$a(max$8(toIntegerOrInfinity$f(deleteCount), 0), len - actualStart);
    }
    doesNotExceedSafeInteger$2(len + insertCount - actualDeleteCount);
    // Collect the removed elements into the species-created result; holes stay holes.
    A = arraySpeciesCreate(O, actualDeleteCount);
    for (k = 0; k < actualDeleteCount; k++) {
      from = actualStart + k;
      if (from in O) createProperty$3(A, k, O[from]);
    }
    A.length = actualDeleteCount;
    if (insertCount < actualDeleteCount) {
      // Shrinking: shift the tail left (front to back), then delete the now-unused trailing slots.
      for (k = actualStart; k < len - actualDeleteCount; k++) {
        from = k + actualDeleteCount;
        to = k + insertCount;
        if (from in O) O[to] = O[from];else deletePropertyOrThrow$1(O, to);
      }
      for (k = len; k > len - actualDeleteCount + insertCount; k--) deletePropertyOrThrow$1(O, k - 1);
    } else if (insertCount > actualDeleteCount) {
      // Growing: shift the tail right, back to front, so no element is overwritten before it is read.
      for (k = len - actualDeleteCount; k > actualStart; k--) {
        from = k + actualDeleteCount - 1;
        to = k + insertCount - 1;
        if (from in O) O[to] = O[from];else deletePropertyOrThrow$1(O, to);
      }
    }
    // Write the inserted items (the arguments after `start` and `deleteCount`).
    for (k = 0; k < insertCount; k++) {
      O[k + actualStart] = arguments[k + 2];
    }
    setArrayLength$1(O, len - actualDeleteCount + insertCount);
    return A;
  }
});
6904
var es_array_toReversed = {};

var lengthOfArrayLike$9 = lengthOfArrayLike$t;

// https://tc39.es/proposal-change-array-by-copy/#sec-array.prototype.toReversed
// https://tc39.es/proposal-change-array-by-copy/#sec-%typedarray%.prototype.toReversed
// Returns a new `C` instance of O's length holding O's elements in reverse order.
// O itself is only read.
var arrayToReversed$2 = function arrayToReversed(O, C) {
  var len = lengthOfArrayLike$9(O);
  var A = new C(len);
  for (var to = 0, from = len - 1; to < len; to++, from--) {
    A[to] = O[from];
  }
  return A;
};
var arrayToReversed$3 = /*@__PURE__*/getDefaultExportFromCjs(arrayToReversed$2);
6919
'use strict';
var $$2k = _export;
var arrayToReversed$1 = arrayToReversed$2;
var toIndexedObject$8 = toIndexedObject$j;
var addToUnscopables$4 = addToUnscopables$e;
var $Array$4 = Array;

// `Array.prototype.toReversed` method
// https://tc39.es/proposal-change-array-by-copy/#sec-array.prototype.toReversed
// Returns a reversed plain-Array copy of the receiver.
$$2k({ target: 'Array', proto: true }, {
  toReversed: function toReversed() {
    var source = toIndexedObject$8(this);
    return arrayToReversed$1(source, $Array$4);
  }
});
addToUnscopables$4('toReversed');
6938
var es_array_toSorted = {};

var lengthOfArrayLike$8 = lengthOfArrayLike$t;
// Creates `new Constructor(length)` and copies `list[0..length)` into it by index.
var arrayFromConstructorAndList$3 = function arrayFromConstructorAndList(Constructor, list) {
  var length = lengthOfArrayLike$8(list);
  var result = new Constructor(length);
  for (var index = 0; index < length; index++) {
    result[index] = list[index];
  }
  return result;
};
var arrayFromConstructorAndList$4 = /*@__PURE__*/getDefaultExportFromCjs(arrayFromConstructorAndList$3);
6950
var global$L = global$Z;
// Returns the `prototype` of the global constructor named `CONSTRUCTOR` (e.g. 'Array').
var entryVirtual = function entryVirtual(CONSTRUCTOR) {
  var Constructor = global$L[CONSTRUCTOR];
  return Constructor.prototype;
};
var entryVirtual$1 = /*@__PURE__*/getDefaultExportFromCjs(entryVirtual);
6956
'use strict';
var $$2j = _export;
var uncurryThis$X = functionUncurryThis;
var aCallable$d = aCallable$l;
var toIndexedObject$7 = toIndexedObject$j;
var arrayFromConstructorAndList$2 = arrayFromConstructorAndList$3;
var getVirtual = entryVirtual;
var addToUnscopables$3 = addToUnscopables$e;
var $Array$3 = Array;
// `sort` read from the global Array's prototype, wrapped by functionUncurryThis.
var sort$1 = uncurryThis$X(getVirtual('Array').sort);

// `Array.prototype.toSorted` method
// https://tc39.es/proposal-change-array-by-copy/#sec-array.prototype.toSorted
// Copies the receiver into a fresh plain Array and sorts that copy; the receiver is only read.
$$2j({ target: 'Array', proto: true }, {
  toSorted: function toSorted(compareFn) {
    // Reject a non-callable comparator before anything is copied.
    if (compareFn !== undefined) aCallable$d(compareFn);
    var copy = arrayFromConstructorAndList$2($Array$3, toIndexedObject$7(this));
    return sort$1(copy, compareFn);
  }
});
addToUnscopables$3('toSorted');
6982
var es_array_toSpliced = {};

'use strict';
var $$2i = _export;
var addToUnscopables$2 = addToUnscopables$e;
var doesNotExceedSafeInteger$1 = doesNotExceedSafeInteger$6;
var lengthOfArrayLike$7 = lengthOfArrayLike$t;
var toAbsoluteIndex$3 = toAbsoluteIndex$a;
var toIndexedObject$6 = toIndexedObject$j;
var toIntegerOrInfinity$e = toIntegerOrInfinity$l;
var $Array$2 = Array;
var max$7 = Math.max;
var min$9 = Math.min;

// `Array.prototype.toSpliced` method
// https://tc39.es/proposal-change-array-by-copy/#sec-array.prototype.toSpliced
// Copying counterpart of `splice`: builds a new plain Array and only reads the receiver.
$$2i({
  target: 'Array',
  proto: true
}, {
  toSpliced: function toSpliced(start, deleteCount /* , ...items */) {
    var O = toIndexedObject$6(this);
    var len = lengthOfArrayLike$7(O);
    var actualStart = toAbsoluteIndex$3(start, len);
    var argumentsLength = arguments.length;
    var k = 0;
    var insertCount, actualDeleteCount, newLen, A;
    // Argument handling:
    // - no arguments: remove nothing;
    // - only `start`: remove everything from `start` on;
    // - otherwise: clamp `deleteCount` to [0, len - actualStart] and insert the remaining arguments.
    if (argumentsLength === 0) {
      insertCount = actualDeleteCount = 0;
    } else if (argumentsLength === 1) {
      insertCount = 0;
      actualDeleteCount = len - actualStart;
    } else {
      insertCount = argumentsLength - 2;
      actualDeleteCount = min$9(max$7(toIntegerOrInfinity$e(deleteCount), 0), len - actualStart);
    }
    // The return value of the safe-integer check is used as the new length
    // (presumably the checked value itself; doesNotExceedSafeInteger is defined elsewhere).
    newLen = doesNotExceedSafeInteger$1(len + insertCount - actualDeleteCount);
    A = $Array$2(newLen);
    // Fill A in three passes that share the cursor `k`:
    // 1. the prefix before `actualStart`;
    // 2. the inserted items;
    // 3. the tail after the deleted range.
    // Every index is assigned unconditionally, so A has no holes.
    for (; k < actualStart; k++) A[k] = O[k];
    for (; k < actualStart + insertCount; k++) A[k] = arguments[k - actualStart + 2];
    for (; k < newLen; k++) A[k] = O[k + actualDeleteCount - insertCount];
    return A;
  }
});
addToUnscopables$2('toSpliced');
7028
var es_array_unscopables_flat = {};

// this method was added to unscopables after implementation
// in popular engines, so it's moved to a separate module
var addToUnscopables$1 = addToUnscopables$e;

// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
// Registers 'flat' through addToUnscopables (defined elsewhere; presumably it adds the key
// to Array.prototype[@@unscopables]).
addToUnscopables$1('flat');

var es_array_unscopables_flatMap = {};

// this method was added to unscopables after implementation
// in popular engines, so it's moved to a separate module
var addToUnscopables = addToUnscopables$e;

// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
// Same registration as above, for 'flatMap'.
addToUnscopables('flatMap');
7046
var es_array_unshift = {};

'use strict';
var $$2h = _export;
var toObject$c = toObject$t;
var lengthOfArrayLike$6 = lengthOfArrayLike$t;
var setArrayLength = arraySetLength;
var deletePropertyOrThrow = deletePropertyOrThrow$4;
var doesNotExceedSafeInteger = doesNotExceedSafeInteger$6;

// IE8-
// (`unshift` must return the new length, which is 1 here)
var INCORRECT_RESULT = [].unshift(0) !== 1;

// V8 ~ Chrome < 71 and Safari <= 15.4, FF < 23 throws InternalError
// Result: `true` when unshift() on an array with a non-writable `length` throws a TypeError,
// `false` for any other error, `undefined` when nothing is thrown.
var properErrorOnNonWritableLength = function properErrorOnNonWritableLength() {
  try {
    // eslint-disable-next-line es/no-object-defineproperty -- safe
    Object.defineProperty([], 'length', {
      writable: false
    }).unshift();
  } catch (error) {
    return error instanceof TypeError;
  }
};
var FORCED$s = INCORRECT_RESULT || !properErrorOnNonWritableLength();

// `Array.prototype.unshift` method
// https://tc39.es/ecma262/#sec-array.prototype.unshift
$$2h({
  target: 'Array',
  proto: true,
  arity: 1,
  forced: FORCED$s
}, {
  // eslint-disable-next-line no-unused-vars -- required for `.length`
  unshift: function unshift(item) {
    var O = toObject$c(this);
    var len = lengthOfArrayLike$6(O);
    var argCount = arguments.length;
    if (argCount) {
      doesNotExceedSafeInteger(len + argCount);
      // Move the existing elements up by argCount, last to first. A hole at the source
      // becomes a delete at the target.
      var k = len;
      while (k--) {
        var to = k + argCount;
        if (k in O) O[to] = O[k];else deletePropertyOrThrow(O, to);
      }
      // Place the new items at the front, in argument order.
      for (var j = 0; j < argCount; j++) {
        O[j] = arguments[j];
      }
    }
    // setArrayLength returns the assigned length, so this returns the new length (even with no items).
    return setArrayLength(O, len + argCount);
  }
});
7100
var es_array_with = {};

var lengthOfArrayLike$5 = lengthOfArrayLike$t;
var toIntegerOrInfinity$d = toIntegerOrInfinity$l;
var $RangeError$9 = RangeError;

// https://tc39.es/proposal-change-array-by-copy/#sec-array.prototype.with
// https://tc39.es/proposal-change-array-by-copy/#sec-%typedarray%.prototype.with
// Returns a new `C` of O's length. The element at `index` (negative values count from the
// end) is `value`; every other element is copied from O.
// Throws RangeError('Incorrect index') when the resolved index is outside [0, length).
var arrayWith$2 = function arrayWith(O, C, index, value) {
  var len = lengthOfArrayLike$5(O);
  var relativeIndex = toIntegerOrInfinity$d(index);
  var actualIndex = relativeIndex < 0 ? len + relativeIndex : relativeIndex;
  if (actualIndex < 0 || actualIndex >= len) throw $RangeError$9('Incorrect index');
  var A = new C(len);
  for (var k = 0; k < len; k++) {
    A[k] = k === actualIndex ? value : O[k];
  }
  return A;
};
var arrayWith$3 = /*@__PURE__*/getDefaultExportFromCjs(arrayWith$2);
7120
'use strict';
var $$2g = _export;
var arrayWith$1 = arrayWith$2;
var toIndexedObject$5 = toIndexedObject$j;
var $Array$1 = Array;

// `Array.prototype.with` method
// https://tc39.es/proposal-change-array-by-copy/#sec-array.prototype.with
// Returns a plain-Array copy of the receiver with the element at `index` replaced by `value`.
$$2g({ target: 'Array', proto: true }, {
  // `with` is a reserved word: it cannot name a function (hence `_with`), and the key is
  // quoted (presumably for ES3-era parsers).
  'with': function _with(index, value) {
    var source = toIndexedObject$5(this);
    return arrayWith$1(source, $Array$1, index, value);
  }
});
7137
var es_arrayBuffer_constructor = {};

// `true` when the global scope provides both an `ArrayBuffer` and a `DataView` binding.
// eslint-disable-next-line es/no-typed-arrays -- safe
var arrayBufferBasicDetection = typeof ArrayBuffer !== 'undefined' && typeof DataView !== 'undefined';
var arrayBufferBasicDetection$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayBufferBasicDetection);
7143
var defineBuiltIn$f = defineBuiltIn$m;
// Calls `defineBuiltIn` for every enumerable key of `src` (own or inherited, via for-in),
// forwarding `options` unchanged, and returns `target`.
var defineBuiltIns$5 = function defineBuiltIns(target, src, options) {
  var name;
  for (name in src) {
    defineBuiltIn$f(target, name, src[name], options);
  }
  return target;
};
var defineBuiltIns$6 = /*@__PURE__*/getDefaultExportFromCjs(defineBuiltIns$5);
7150
var isPrototypeOf$5 = objectIsPrototypeOf;
var $TypeError$c = TypeError;
// Returns `it` unchanged when `isPrototypeOf$5(Prototype, it)` is truthy; otherwise throws
// TypeError('Incorrect invocation'). The polyfilled constructors call it on `this` to reject misuse.
var anInstance$a = function anInstance(it, Prototype) {
  if (!isPrototypeOf$5(Prototype, it)) throw $TypeError$c('Incorrect invocation');
  return it;
};
var anInstance$b = /*@__PURE__*/getDefaultExportFromCjs(anInstance$a);
7158
var toIntegerOrInfinity$c = toIntegerOrInfinity$l;
var toLength$b = toLength$d;
var $RangeError$8 = RangeError;

// `ToIndex` abstract operation
// https://tc39.es/ecma262/#sec-toindex
// `undefined` maps to 0. Any other value is converted with toIntegerOrInfinity and must
// come through toLength unchanged; otherwise RangeError('Wrong length or index') is thrown.
var toIndex$2 = function toIndex(it) {
  if (it === undefined) return 0;
  var integer = toIntegerOrInfinity$c(it);
  var clamped = toLength$b(integer);
  if (integer === clamped) return clamped;
  throw $RangeError$8('Wrong length or index');
};
var toIndex$3 = /*@__PURE__*/getDefaultExportFromCjs(toIndex$2);
7173
// IEEE754 conversions based on https://github.com/feross/ieee754
var $Array = Array;
var abs$b = Math.abs;
var pow$9 = Math.pow;
var floor$b = Math.floor;
var log$c = Math.log;
var LN2$2 = Math.LN2;
// Encodes `number` as an IEEE 754 value with `mantissaLength` fraction bits in `bytes` bytes.
// Returns an Array of byte values, least significant byte first; the sign bit goes into the
// top bit of the last byte. This file calls it with (23, 4) for float32 and (52, 8) for float64.
var pack$2 = function pack(number, mantissaLength, bytes) {
  var buffer = $Array(bytes);
  var exponentLength = bytes * 8 - mantissaLength - 1;
  var eMax = (1 << exponentLength) - 1;
  var eBias = eMax >> 1;
  // Rounding increment, only for the 23-bit (float32) mantissa; 0 otherwise.
  // NOTE(review): it is slightly less than half a float32 ULP and the mantissa is later
  // truncated by `& 255`, so exact ties appear to round toward zero rather than to even;
  // confirm whether that matters before relying on bit-exact float32 results.
  var rt = mantissaLength === 23 ? pow$9(2, -24) - pow$9(2, -77) : 0;
  // Sign bit; `1 / number < 0` distinguishes -0 from +0.
  var sign = number < 0 || number === 0 && 1 / number < 0 ? 1 : 0;
  var index = 0;
  var exponent, mantissa, c;
  number = abs$b(number);
  // eslint-disable-next-line no-self-compare -- NaN check
  if (number != number || number === Infinity) {
    // NaN and Infinity both use the all-ones exponent; NaN gets a non-zero mantissa.
    // eslint-disable-next-line no-self-compare -- NaN check
    mantissa = number != number ? 1 : 0;
    exponent = eMax;
  } else {
    // Base-2 exponent estimated from the logarithm, corrected below if it overshot.
    exponent = floor$b(log$c(number) / LN2$2);
    c = pow$9(2, -exponent);
    if (number * c < 1) {
      exponent--;
      c *= 2;
    }
    if (exponent + eBias >= 1) {
      number += rt / c;
    } else {
      number += rt * pow$9(2, 1 - eBias);
    }
    // Adding the rounding increment may carry into the next power of two.
    if (number * c >= 2) {
      exponent++;
      c /= 2;
    }
    if (exponent + eBias >= eMax) {
      // Too large for this format: all-ones exponent with a zero mantissa (Infinity).
      mantissa = 0;
      exponent = eMax;
    } else if (exponent + eBias >= 1) {
      // Normal number: drop the implicit leading 1 and scale the fraction to mantissa units.
      mantissa = (number * c - 1) * pow$9(2, mantissaLength);
      exponent = exponent + eBias;
    } else {
      // Subnormal number: the biased exponent field is 0.
      mantissa = number * pow$9(2, eBias - 1) * pow$9(2, mantissaLength);
      exponent = 0;
    }
  }
  // Emit whole mantissa bytes first; `& 255` keeps the low 8 bits of the remaining value.
  while (mantissaLength >= 8) {
    buffer[index++] = mantissa & 255;
    mantissa /= 256;
    mantissaLength -= 8;
  }
  // Leftover mantissa bits sit below the exponent; then emit the exponent bytes.
  exponent = exponent << mantissaLength | mantissa;
  exponentLength += mantissaLength;
  while (exponentLength > 0) {
    buffer[index++] = exponent & 255;
    exponent /= 256;
    exponentLength -= 8;
  }
  buffer[--index] |= sign * 128;
  return buffer;
};
// Decodes a byte Array laid out like `pack`'s output (same `mantissaLength`) back into a
// Number, handling zero/subnormal, Infinity and NaN encodings.
var unpack$2 = function unpack(buffer, mantissaLength) {
  var bytes = buffer.length;
  var exponentLength = bytes * 8 - mantissaLength - 1;
  var eMax = (1 << exponentLength) - 1;
  var eBias = eMax >> 1;
  // Exponent bits still to read after the 7 exponent bits held in the last byte.
  var nBits = exponentLength - 7;
  var index = bytes - 1;
  // Last byte: the sign bit plus the high exponent bits.
  var sign = buffer[index--];
  var exponent = sign & 127;
  var mantissa;
  sign >>= 7;
  while (nBits > 0) {
    exponent = exponent * 256 + buffer[index--];
    nBits -= 8;
  }
  // The last byte read may carry -nBits low bits that belong to the mantissa; split them off.
  mantissa = exponent & (1 << -nBits) - 1;
  exponent >>= -nBits;
  nBits += mantissaLength;
  while (nBits > 0) {
    mantissa = mantissa * 256 + buffer[index--];
    nBits -= 8;
  }
  if (exponent === 0) {
    // Zero or subnormal: no implicit leading bit.
    exponent = 1 - eBias;
  } else if (exponent === eMax) {
    return mantissa ? NaN : sign ? -Infinity : Infinity;
  } else {
    // Normal number: restore the implicit leading 1 and remove the bias.
    mantissa = mantissa + pow$9(2, mantissaLength);
    exponent = exponent - eBias;
  }
  return (sign ? -1 : 1) * mantissa * pow$9(2, exponent - mantissaLength);
};
var ieee754 = {
  pack: pack$2,
  unpack: unpack$2
};
var ieee754$1 = /*@__PURE__*/getDefaultExportFromCjs(ieee754);
7275
'use strict';
var global$K = global$Z;
var uncurryThis$W = functionUncurryThis;
var DESCRIPTORS$u = descriptors;
var NATIVE_ARRAY_BUFFER$2 = arrayBufferBasicDetection;
var FunctionName = functionName;
var createNonEnumerableProperty$7 = createNonEnumerableProperty$f;
var defineBuiltInAccessor$d = defineBuiltInAccessor$h;
var defineBuiltIns$4 = defineBuiltIns$5;
var fails$Z = fails$1m;
var anInstance$9 = anInstance$a;
var toIntegerOrInfinity$b = toIntegerOrInfinity$l;
var toLength$a = toLength$d;
var toIndex$1 = toIndex$2;
var IEEE754 = ieee754;
var getPrototypeOf$7 = objectGetPrototypeOf$1;
var setPrototypeOf$5 = objectSetPrototypeOf$1;
var getOwnPropertyNames$4 = objectGetOwnPropertyNames.f;
var arrayFill = arrayFill$1;
var arraySlice$7 = arraySliceSimple;
var setToStringTag$8 = setToStringTag$d;
var InternalStateModule$a = internalState;
var PROPER_FUNCTION_NAME$2 = FunctionName.PROPER;
var CONFIGURABLE_FUNCTION_NAME = FunctionName.CONFIGURABLE;
var ARRAY_BUFFER$1 = 'ArrayBuffer';
var DATA_VIEW = 'DataView';
var PROTOTYPE = 'prototype';
var WRONG_LENGTH$1 = 'Wrong length';
var WRONG_INDEX = 'Wrong index';
// Internal-state accessors keyed by type name (internalState is defined elsewhere; presumably
// getterFor rejects objects whose stored state has a different `type`).
var getInternalArrayBufferState = InternalStateModule$a.getterFor(ARRAY_BUFFER$1);
var getInternalDataViewState = InternalStateModule$a.getterFor(DATA_VIEW);
var setInternalState$9 = InternalStateModule$a.set;
// Native constructors and prototypes, if the environment has them. When
// NATIVE_ARRAY_BUFFER$2 is false, the code that follows reassigns these bindings to polyfills.
var NativeArrayBuffer$1 = global$K[ARRAY_BUFFER$1];
var $ArrayBuffer = NativeArrayBuffer$1;
var ArrayBufferPrototype$1 = $ArrayBuffer && $ArrayBuffer[PROTOTYPE];
var $DataView = global$K[DATA_VIEW];
var DataViewPrototype$1 = $DataView && $DataView[PROTOTYPE];
var ObjectPrototype$3 = Object.prototype;
var Array$2 = global$K.Array;
var RangeError$4 = global$K.RangeError;
var fill$3 = uncurryThis$W(arrayFill);
var reverse$3 = uncurryThis$W([].reverse);
var packIEEE754 = IEEE754.pack;
var unpackIEEE754 = IEEE754.unpack;
7317 var reverse$3 = uncurryThis$W([].reverse);
7318 var packIEEE754 = IEEE754.pack;
7319 var unpackIEEE754 = IEEE754.unpack;
7320 var packInt8 = function packInt8(number) {
7321 return [number & 0xFF];
7322 };
7323 var packInt16 = function packInt16(number) {
7324 return [number & 0xFF, number >> 8 & 0xFF];
7325 };
7326 var packInt32 = function packInt32(number) {
7327 return [number & 0xFF, number >> 8 & 0xFF, number >> 16 & 0xFF, number >> 24 & 0xFF];
7328 };
7329 var unpackInt32 = function unpackInt32(buffer) {
7330 return buffer[3] << 24 | buffer[2] << 16 | buffer[1] << 8 | buffer[0];
7331 };
7332 var packFloat32 = function packFloat32(number) {
7333 return packIEEE754(number, 23, 4);
7334 };
7335 var packFloat64 = function packFloat64(number) {
7336 return packIEEE754(number, 52, 8);
7337 };
7338 var addGetter$1 = function addGetter(Constructor, key, getInternalState) {
7339 defineBuiltInAccessor$d(Constructor[PROTOTYPE], key, {
7340 configurable: true,
7341 get: function get() {
7342 return getInternalState(this)[key];
7343 }
7344 });
7345 };
// Reads `count` bytes at view-relative `index` from a polyfilled DataView.
// The result is ordered as stored when isLittleEndian is truthy and reversed otherwise,
// so callers always combine bytes[0] as the least significant byte.
// Throws RangeError(WRONG_INDEX) when the read would run past the view's byteLength.
var get$3 = function get(view, count, index, isLittleEndian) {
  var position = toIndex$1(index);
  var state = getInternalDataViewState(view);
  if (position + count > state.byteLength) {
    throw RangeError$4(WRONG_INDEX);
  }
  var storage = state.bytes;
  var from = state.byteOffset + position;
  var chunk = arraySlice$7(storage, from, from + count);
  if (isLittleEndian) return chunk;
  return reverse$3(chunk);
};
// Writes `count` bytes into a polyfilled DataView at view-relative `index`.
// `conversion(+value)` yields the bytes least-significant first. They are stored in that
// order when isLittleEndian is truthy and reversed otherwise.
// Throws RangeError(WRONG_INDEX) before converting `value` if the write would overflow.
var set$2 = function set(view, count, index, conversion, value, isLittleEndian) {
  var position = toIndex$1(index);
  var state = getInternalDataViewState(view);
  if (position + count > state.byteLength) {
    throw RangeError$4(WRONG_INDEX);
  }
  var target = state.bytes;
  var from = state.byteOffset + position;
  var encoded = conversion(+value);
  var k = 0;
  while (k < count) {
    target[from + k] = isLittleEndian ? encoded[k] : encoded[count - k - 1];
    k++;
  }
};
if (!NATIVE_ARRAY_BUFFER$2) {
  // No usable native ArrayBuffer: emulate ArrayBuffer and DataView on top of plain
  // arrays of bytes kept in internal state (see get$3 / set$2 for the byte access).

  $ArrayBuffer = function ArrayBuffer(length) {
    anInstance$9(this, ArrayBufferPrototype$1);
    var byteLength = toIndex$1(length);
    var zeroed = fill$3(Array$2(byteLength), 0);
    setInternalState$9(this, {
      type: ARRAY_BUFFER$1,
      bytes: zeroed,
      byteLength: byteLength
    });
    // The prototype getters are only installed when DESCRIPTORS$u is truthy,
    // so otherwise mirror the state onto own data properties.
    if (!DESCRIPTORS$u) {
      this.byteLength = byteLength;
      this.detached = false;
    }
  };
  ArrayBufferPrototype$1 = $ArrayBuffer[PROTOTYPE];

  $DataView = function DataView(buffer, byteOffset, byteLength) {
    anInstance$9(this, DataViewPrototype$1);
    anInstance$9(buffer, ArrayBufferPrototype$1);
    var bufferState = getInternalArrayBufferState(buffer);
    var bufferLength = bufferState.byteLength;
    var offset = toIntegerOrInfinity$b(byteOffset);
    if (offset < 0 || offset > bufferLength) {
      throw RangeError$4('Wrong offset');
    }
    // An omitted length means "to the end of the buffer".
    var viewLength;
    if (byteLength === undefined) {
      viewLength = bufferLength - offset;
    } else {
      viewLength = toLength$a(byteLength);
    }
    if (offset + viewLength > bufferLength) {
      throw RangeError$4(WRONG_LENGTH$1);
    }
    // The view shares the buffer's byte array; byteOffset/byteLength select the window.
    setInternalState$9(this, {
      type: DATA_VIEW,
      buffer: buffer,
      byteLength: viewLength,
      byteOffset: offset,
      bytes: bufferState.bytes
    });
    if (!DESCRIPTORS$u) {
      this.buffer = buffer;
      this.byteLength = viewLength;
      this.byteOffset = offset;
    }
  };
  DataViewPrototype$1 = $DataView[PROTOTYPE];

  if (DESCRIPTORS$u) {
    addGetter$1($ArrayBuffer, 'byteLength', getInternalArrayBufferState);
    addGetter$1($DataView, 'buffer', getInternalDataViewState);
    addGetter$1($DataView, 'byteLength', getInternalDataViewState);
    addGetter$1($DataView, 'byteOffset', getInternalDataViewState);
  }

  // Accessors: getters combine bytes least-significant first; setters encode with pack*.
  // The optional littleEndian argument is read from `arguments` so each method keeps its
  // spec `.length`.
  defineBuiltIns$4(DataViewPrototype$1, {
    getInt8: function getInt8(byteOffset) {
      var bytes = get$3(this, 1, byteOffset);
      return bytes[0] << 24 >> 24;
    },
    getUint8: function getUint8(byteOffset) {
      var bytes = get$3(this, 1, byteOffset);
      return bytes[0];
    },
    getInt16: function getInt16(byteOffset /* , littleEndian */) {
      var littleEndian = arguments.length > 1 ? arguments[1] : undefined;
      var bytes = get$3(this, 2, byteOffset, littleEndian);
      return (bytes[1] << 8 | bytes[0]) << 16 >> 16;
    },
    getUint16: function getUint16(byteOffset /* , littleEndian */) {
      var littleEndian = arguments.length > 1 ? arguments[1] : undefined;
      var bytes = get$3(this, 2, byteOffset, littleEndian);
      return bytes[1] << 8 | bytes[0];
    },
    getInt32: function getInt32(byteOffset /* , littleEndian */) {
      var littleEndian = arguments.length > 1 ? arguments[1] : undefined;
      return unpackInt32(get$3(this, 4, byteOffset, littleEndian));
    },
    getUint32: function getUint32(byteOffset /* , littleEndian */) {
      var littleEndian = arguments.length > 1 ? arguments[1] : undefined;
      return unpackInt32(get$3(this, 4, byteOffset, littleEndian)) >>> 0;
    },
    getFloat32: function getFloat32(byteOffset /* , littleEndian */) {
      var littleEndian = arguments.length > 1 ? arguments[1] : undefined;
      return unpackIEEE754(get$3(this, 4, byteOffset, littleEndian), 23);
    },
    getFloat64: function getFloat64(byteOffset /* , littleEndian */) {
      var littleEndian = arguments.length > 1 ? arguments[1] : undefined;
      return unpackIEEE754(get$3(this, 8, byteOffset, littleEndian), 52);
    },
    setInt8: function setInt8(byteOffset, value) {
      set$2(this, 1, byteOffset, packInt8, value);
    },
    setUint8: function setUint8(byteOffset, value) {
      set$2(this, 1, byteOffset, packInt8, value);
    },
    setInt16: function setInt16(byteOffset, value /* , littleEndian */) {
      var littleEndian = arguments.length > 2 ? arguments[2] : undefined;
      set$2(this, 2, byteOffset, packInt16, value, littleEndian);
    },
    setUint16: function setUint16(byteOffset, value /* , littleEndian */) {
      var littleEndian = arguments.length > 2 ? arguments[2] : undefined;
      set$2(this, 2, byteOffset, packInt16, value, littleEndian);
    },
    setInt32: function setInt32(byteOffset, value /* , littleEndian */) {
      var littleEndian = arguments.length > 2 ? arguments[2] : undefined;
      set$2(this, 4, byteOffset, packInt32, value, littleEndian);
    },
    setUint32: function setUint32(byteOffset, value /* , littleEndian */) {
      var littleEndian = arguments.length > 2 ? arguments[2] : undefined;
      set$2(this, 4, byteOffset, packInt32, value, littleEndian);
    },
    setFloat32: function setFloat32(byteOffset, value /* , littleEndian */) {
      var littleEndian = arguments.length > 2 ? arguments[2] : undefined;
      set$2(this, 4, byteOffset, packFloat32, value, littleEndian);
    },
    setFloat64: function setFloat64(byteOffset, value /* , littleEndian */) {
      var littleEndian = arguments.length > 2 ? arguments[2] : undefined;
      set$2(this, 8, byteOffset, packFloat64, value, littleEndian);
    }
  });
} else {
  // Native ArrayBuffer exists: patch its known deviations instead of replacing it wholesale.
  var INCORRECT_ARRAY_BUFFER_NAME = PROPER_FUNCTION_NAME$2 && NativeArrayBuffer$1.name !== ARRAY_BUFFER$1;
  // Wrap the native constructor when any of these holds:
  // - it can be called without `new`;
  // - it accepts a negative length;
  // - it throws for a missing, fractional or NaN length;
  // - its `length` is not 1;
  // - its `name` is wrong while CONFIGURABLE_FUNCTION_NAME is falsy.
  /* eslint-disable no-new -- required for testing */
  if (!fails$Z(function () {
    NativeArrayBuffer$1(1);
  }) || !fails$Z(function () {
    new NativeArrayBuffer$1(-1);
  }) || fails$Z(function () {
    new NativeArrayBuffer$1();
    new NativeArrayBuffer$1(1.5);
    new NativeArrayBuffer$1(NaN);
    return NativeArrayBuffer$1.length != 1 || INCORRECT_ARRAY_BUFFER_NAME && !CONFIGURABLE_FUNCTION_NAME;
  })) {
    /* eslint-enable no-new -- required for testing */
    $ArrayBuffer = function ArrayBuffer(length) {
      anInstance$9(this, ArrayBufferPrototype$1);
      var normalized = toIndex$1(length);
      return new NativeArrayBuffer$1(normalized);
    };
    $ArrayBuffer[PROTOTYPE] = ArrayBufferPrototype$1;
    // Copy over the native constructor's own properties that the wrapper does not already have.
    var keys$1 = getOwnPropertyNames$4(NativeArrayBuffer$1);
    for (var j = 0; j < keys$1.length; j++) {
      var key$2 = keys$1[j];
      if (!(key$2 in $ArrayBuffer)) {
        createNonEnumerableProperty$7($ArrayBuffer, key$2, NativeArrayBuffer$1[key$2]);
      }
    }
    ArrayBufferPrototype$1.constructor = $ArrayBuffer;
  } else if (INCORRECT_ARRAY_BUFFER_NAME && CONFIGURABLE_FUNCTION_NAME) {
    createNonEnumerableProperty$7(NativeArrayBuffer$1, 'name', ARRAY_BUFFER$1);
  }

  // WebKit bug - the same parent prototype for typed arrays and data view
  if (setPrototypeOf$5 && getPrototypeOf$7(DataViewPrototype$1) !== ObjectPrototype$3) {
    setPrototypeOf$5(DataViewPrototype$1, ObjectPrototype$3);
  }

  // iOS Safari 7.x bug: probe setInt8 with out-of-range values. A conforming engine wraps
  // them to 0 and 1 respectively; otherwise wrap the value manually before delegating.
  var testView = new $DataView(new $ArrayBuffer(2));
  var $setInt8 = uncurryThis$W(DataViewPrototype$1.setInt8);
  testView.setInt8(0, 2147483648);
  testView.setInt8(1, 2147483649);
  if (testView.getInt8(0) || !testView.getInt8(1)) {
    defineBuiltIns$4(DataViewPrototype$1, {
      setInt8: function setInt8(byteOffset, value) {
        var wrapped = value << 24 >> 24;
        $setInt8(this, byteOffset, wrapped);
      },
      setUint8: function setUint8(byteOffset, value) {
        var wrapped = value << 24 >> 24;
        $setInt8(this, byteOffset, wrapped);
      }
    }, {
      unsafe: true
    });
  }
}
// Tag both constructors (native or polyfilled) via setToStringTag$8,
// then publish them as this module's exports.
setToStringTag$8($ArrayBuffer, ARRAY_BUFFER$1);
setToStringTag$8($DataView, DATA_VIEW);
var arrayBuffer = {};
arrayBuffer.ArrayBuffer = $ArrayBuffer;
arrayBuffer.DataView = $DataView;
var arrayBuffer$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayBuffer);
7517
'use strict';
// es.array-buffer.constructor: publish the ArrayBuffer chosen by the array-buffer module
// as a global, forcing the replacement only when it differs from the realm's native one,
// then call setSpecies for it.
var ARRAY_BUFFER = 'ArrayBuffer';
var $$2f = _export;
var global$J = global$Z;
var setSpecies$4 = setSpecies$6;
var arrayBufferModule = arrayBuffer;
var ArrayBuffer$4 = arrayBufferModule[ARRAY_BUFFER];
var NativeArrayBuffer = global$J[ARRAY_BUFFER];

// `ArrayBuffer` constructor
// https://tc39.es/ecma262/#sec-arraybuffer-constructor
$$2f({ global: true, constructor: true, forced: NativeArrayBuffer !== ArrayBuffer$4 }, { ArrayBuffer: ArrayBuffer$4 });
setSpecies$4(ARRAY_BUFFER);
7537
var es_arrayBuffer_isView = {};

'use strict';
// Environment probes.
var NATIVE_ARRAY_BUFFER$1 = arrayBufferBasicDetection;
var DESCRIPTORS$t = descriptors;
var global$I = global$Z;

// Type predicates and prototype utilities.
var isCallable$e = isCallable$z;
var isObject$n = isObject$z;
var hasOwn$h = hasOwnProperty_1;
var classof$e = classof$m;
var tryToString$1 = tryToString$7;
var isPrototypeOf$4 = objectIsPrototypeOf;
var getPrototypeOf$6 = objectGetPrototypeOf$1;
var setPrototypeOf$4 = objectSetPrototypeOf$1;

// Property definers.
var createNonEnumerableProperty$6 = createNonEnumerableProperty$f;
var defineBuiltIn$e = defineBuiltIn$m;
var defineBuiltInAccessor$c = defineBuiltInAccessor$h;

var wellKnownSymbol$f = wellKnownSymbol$z;
var uid$2 = uid$6;
var InternalStateModule$9 = internalState;
var enforceInternalState$3 = InternalStateModule$9.enforce;
var getInternalState$7 = InternalStateModule$9.get;

// Native constructors (if present) and the shared parents reached through Int8Array.
var Int8Array$4 = global$I.Int8Array;
var Int8ArrayPrototype$1 = Int8Array$4 ? Int8Array$4.prototype : Int8Array$4;
var Uint8ClampedArray$1 = global$I.Uint8ClampedArray;
var Uint8ClampedArrayPrototype = Uint8ClampedArray$1 ? Uint8ClampedArray$1.prototype : Uint8ClampedArray$1;
var TypedArray$1 = Int8Array$4 ? getPrototypeOf$6(Int8Array$4) : Int8Array$4;
var TypedArrayPrototype$2 = Int8ArrayPrototype$1 ? getPrototypeOf$6(Int8ArrayPrototype$1) : Int8ArrayPrototype$1;
var ObjectPrototype$2 = Object.prototype;
var TypeError$6 = global$I.TypeError;

// Keys used to tag prototypes and instances.
var TO_STRING_TAG$1 = wellKnownSymbol$f('toStringTag');
var TYPED_ARRAY_TAG$1 = uid$2('TYPED_ARRAY_TAG');
var TYPED_ARRAY_CONSTRUCTOR = 'TypedArrayConstructor';
// Fixing native typed arrays in Opera Presto crashes the browser, see #595
var NATIVE_ARRAY_BUFFER_VIEWS$3 = NATIVE_ARRAY_BUFFER$1 && !!setPrototypeOf$4 && classof$e(global$I.opera) !== 'Opera';
var TYPED_ARRAY_TAG_REQUIRED = false;
// Shared cursors for the constructor loops of this module.
var NAME$1, Constructor, Prototype;
// Typed-array constructor names handled by this module, mapped to each type's
// BYTES_PER_ELEMENT. Key order is significant: it is the for-in order used when iterating.
var TypedArrayConstructorsList = {
  Int8Array: 1, Uint8Array: 1, Uint8ClampedArray: 1,
  Int16Array: 2, Uint16Array: 2,
  Int32Array: 4, Uint32Array: 4, Float32Array: 4,
  Float64Array: 8
};
// BigInt-backed typed arrays, kept in a separate table.
var BigIntArrayConstructorsList = { BigInt64Array: 8, BigUint64Array: 8 };
// `ArrayBuffer.isView` core: true for objects whose class is DataView or any listed typed array.
var isView = function isView(it) {
  if (isObject$n(it)) {
    var klass = classof$e(it);
    if (klass === 'DataView') return true;
    return hasOwn$h(TypedArrayConstructorsList, klass) || hasOwn$h(BigIntArrayConstructorsList, klass);
  }
  return false;
};
// Walks the prototype chain of `it` and returns the constructor recorded under
// TYPED_ARRAY_CONSTRUCTOR in the nearest prototype's internal state.
// Returns undefined when the chain ends without such a record.
var getTypedArrayConstructor$4 = function getTypedArrayConstructor(it) {
  var proto = getPrototypeOf$6(it);
  while (isObject$n(proto)) {
    var state = getInternalState$7(proto);
    if (state && hasOwn$h(state, TYPED_ARRAY_CONSTRUCTOR)) return state[TYPED_ARRAY_CONSTRUCTOR];
    proto = getPrototypeOf$6(proto);
  }
};
// True for objects whose class is one of the listed (classic or BigInt) typed arrays.
var isTypedArray$2 = function isTypedArray(it) {
  if (isObject$n(it)) {
    var kind = classof$e(it);
    if (hasOwn$h(TypedArrayConstructorsList, kind)) return true;
    return hasOwn$h(BigIntArrayConstructorsList, kind);
  }
  return false;
};
// Asserts `it` is a typed array and returns it unchanged; throws TypeError otherwise.
var aTypedArray$s = function aTypedArray(it) {
  if (!isTypedArray$2(it)) throw TypeError$6('Target is not a typed array');
  return it;
};
// Asserts `C` is callable and, when setPrototypeOf is available, inherits from %TypedArray%.
// Returns `C` unchanged, otherwise throws a TypeError naming it.
var aTypedArrayConstructor$4 = function aTypedArrayConstructor(C) {
  var acceptable = isCallable$e(C) && (!setPrototypeOf$4 || isPrototypeOf$4(TypedArray$1, C));
  if (!acceptable) throw TypeError$6(tryToString$1(C) + ' is not a typed array constructor');
  return C;
};
// Installs prototype method KEY on %TypedArray%.prototype.
// - Without property descriptors this is a no-op.
// - When `forced`, own KEY overrides on each listed constructor's prototype are removed
//   first, and `property` is always installed.
// - When not forced, nothing happens if KEY already exists; otherwise the native Int8Array
//   method is preferred (when native views are usable) over `property`.
var exportTypedArrayMethod$t = function exportTypedArrayMethod(KEY, property, forced, options) {
  if (!DESCRIPTORS$t) return;
  if (forced) {
    for (var name in TypedArrayConstructorsList) {
      var ctor = global$I[name];
      if (!ctor || !hasOwn$h(ctor.prototype, KEY)) continue;
      try {
        delete ctor.prototype[KEY];
      } catch (error) {
        // old WebKit bug - some methods are non-configurable, so overwrite in place instead
        try {
          ctor.prototype[KEY] = property;
        } catch (error2) { /* empty */ }
      }
    }
  }
  if (TypedArrayPrototype$2[KEY] && !forced) return;
  var value = forced ? property : NATIVE_ARRAY_BUFFER_VIEWS$3 && Int8ArrayPrototype$1[KEY] || property;
  defineBuiltIn$e(TypedArrayPrototype$2, KEY, value, options);
};
// Installs static method KEY on %TypedArray% or, failing that, on each listed constructor.
// - Returns the defineBuiltIn$e result when the %TypedArray% definition succeeds.
// - Returns undefined early when KEY already exists on %TypedArray% and is not forced.
var exportTypedArrayStaticMethod$2 = function exportTypedArrayStaticMethod(KEY, property, forced) {
  var name, ctor;
  if (!DESCRIPTORS$t) return;
  if (setPrototypeOf$4) {
    if (forced) {
      // Remove own KEY overrides so the constructors see the %TypedArray% version
      // (presumably via their prototype link).
      for (name in TypedArrayConstructorsList) {
        ctor = global$I[name];
        if (!ctor || !hasOwn$h(ctor, KEY)) continue;
        try {
          delete ctor[KEY];
        } catch (error) { /* empty */ }
      }
    }
    if (TypedArray$1[KEY] && !forced) return;
    // V8 ~ Chrome 49-50 `%TypedArray%` methods are non-writable non-configurable,
    // so a failed definition falls back to the per-constructor loop below.
    try {
      var value = forced ? property : NATIVE_ARRAY_BUFFER_VIEWS$3 && TypedArray$1[KEY] || property;
      return defineBuiltIn$e(TypedArray$1, KEY, value);
    } catch (error) { /* empty */ }
  }
  for (name in TypedArrayConstructorsList) {
    ctor = global$I[name];
    if (ctor && (!ctor[KEY] || forced)) defineBuiltIn$e(ctor, KEY, property);
  }
};
// Record each native constructor in its prototype's internal state under
// TYPED_ARRAY_CONSTRUCTOR. A missing classic typed array disables native views.
for (NAME$1 in TypedArrayConstructorsList) {
  Constructor = global$I[NAME$1];
  Prototype = Constructor ? Constructor.prototype : Constructor;
  if (!Prototype) {
    NATIVE_ARRAY_BUFFER_VIEWS$3 = false;
    continue;
  }
  enforceInternalState$3(Prototype)[TYPED_ARRAY_CONSTRUCTOR] = Constructor;
}
// BigInt arrays are registered too, but their absence leaves NATIVE_ARRAY_BUFFER_VIEWS$3 untouched.
for (NAME$1 in BigIntArrayConstructorsList) {
  Constructor = global$I[NAME$1];
  Prototype = Constructor ? Constructor.prototype : Constructor;
  if (Prototype) enforceInternalState$3(Prototype)[TYPED_ARRAY_CONSTRUCTOR] = Constructor;
}
7665
// WebKit bug - typed arrays constructors prototype is Object.prototype.
// When %TypedArray% is unusable, substitute a throwing stand-in. If native views are
// otherwise usable, re-parent every native constructor onto it.
if (!(NATIVE_ARRAY_BUFFER_VIEWS$3 && isCallable$e(TypedArray$1) && TypedArray$1 !== Function.prototype)) {
  // eslint-disable-next-line no-shadow -- safe
  TypedArray$1 = function TypedArray() {
    throw TypeError$6('Incorrect invocation');
  };
  if (NATIVE_ARRAY_BUFFER_VIEWS$3) {
    for (NAME$1 in TypedArrayConstructorsList) {
      if (!global$I[NAME$1]) continue;
      setPrototypeOf$4(global$I[NAME$1], TypedArray$1);
    }
  }
}
// Same repair for %TypedArray%.prototype, falling back to the stand-in's prototype.
if (!(NATIVE_ARRAY_BUFFER_VIEWS$3 && TypedArrayPrototype$2 && TypedArrayPrototype$2 !== ObjectPrototype$2)) {
  TypedArrayPrototype$2 = TypedArray$1.prototype;
  if (NATIVE_ARRAY_BUFFER_VIEWS$3) {
    for (NAME$1 in TypedArrayConstructorsList) {
      if (!global$I[NAME$1]) continue;
      setPrototypeOf$4(global$I[NAME$1].prototype, TypedArrayPrototype$2);
    }
  }
}

// WebKit bug - one more object in Uint8ClampedArray prototype chain
if (NATIVE_ARRAY_BUFFER_VIEWS$3 && getPrototypeOf$6(Uint8ClampedArrayPrototype) !== TypedArrayPrototype$2) {
  setPrototypeOf$4(Uint8ClampedArrayPrototype, TypedArrayPrototype$2);
}

// If %TypedArray%.prototype lacks an own @@toStringTag, add a getter that reads
// this[TYPED_ARRAY_TAG$1], and tag each listed global constructor with its name.
if (DESCRIPTORS$t && !hasOwn$h(TypedArrayPrototype$2, TO_STRING_TAG$1)) {
  TYPED_ARRAY_TAG_REQUIRED = true;
  defineBuiltInAccessor$c(TypedArrayPrototype$2, TO_STRING_TAG$1, {
    configurable: true,
    get: function get() {
      if (!isObject$n(this)) return undefined;
      return this[TYPED_ARRAY_TAG$1];
    }
  });
  for (NAME$1 in TypedArrayConstructorsList) {
    if (!global$I[NAME$1]) continue;
    createNonEnumerableProperty$6(global$I[NAME$1], TYPED_ARRAY_TAG$1, NAME$1);
  }
}
// Public surface of the array-buffer-view-core module, assembled key by key
// (assignment order = key order).
var arrayBufferViewCore = {};
arrayBufferViewCore.NATIVE_ARRAY_BUFFER_VIEWS = NATIVE_ARRAY_BUFFER_VIEWS$3;
// The tag key is only exposed when the @@toStringTag shim was installed.
arrayBufferViewCore.TYPED_ARRAY_TAG = TYPED_ARRAY_TAG_REQUIRED ? TYPED_ARRAY_TAG$1 : TYPED_ARRAY_TAG_REQUIRED;
arrayBufferViewCore.aTypedArray = aTypedArray$s;
arrayBufferViewCore.aTypedArrayConstructor = aTypedArrayConstructor$4;
arrayBufferViewCore.exportTypedArrayMethod = exportTypedArrayMethod$t;
arrayBufferViewCore.exportTypedArrayStaticMethod = exportTypedArrayStaticMethod$2;
arrayBufferViewCore.getTypedArrayConstructor = getTypedArrayConstructor$4;
arrayBufferViewCore.isView = isView;
arrayBufferViewCore.isTypedArray = isTypedArray$2;
arrayBufferViewCore.TypedArray = TypedArray$1;
arrayBufferViewCore.TypedArrayPrototype = TypedArrayPrototype$2;
var arrayBufferViewCore$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayBufferViewCore);
7713
var $$2e = _export;
var ArrayBufferViewCore$v = arrayBufferViewCore;
var NATIVE_ARRAY_BUFFER_VIEWS$2 = ArrayBufferViewCore$v.NATIVE_ARRAY_BUFFER_VIEWS;

// `ArrayBuffer.isView` method -- forced whenever native array-buffer views are not usable.
// https://tc39.es/ecma262/#sec-arraybuffer.isview
$$2e({ target: 'ArrayBuffer', stat: true, forced: !NATIVE_ARRAY_BUFFER_VIEWS$2 }, { isView: ArrayBufferViewCore$v.isView });
7727
var es_arrayBuffer_slice = {};

var $TypeError$b = TypeError;
var tryToString = tryToString$7;
var isConstructor$1 = isConstructor$6;

// Spec assertion `IsConstructor(argument) is true`: returns `argument` unchanged
// when it is a constructor, otherwise throws a TypeError naming it.
var aConstructor$3 = function aConstructor(argument) {
  if (!isConstructor$1(argument)) {
    throw $TypeError$b(tryToString(argument) + ' is not a constructor');
  }
  return argument;
};
var aConstructor$4 = /*@__PURE__*/getDefaultExportFromCjs(aConstructor$3);
7740
var anObject$r = anObject$D;
var aConstructor$2 = aConstructor$3;
var isNullOrUndefined$a = isNullOrUndefined$e;
var wellKnownSymbol$e = wellKnownSymbol$z;
var SPECIES$2 = wellKnownSymbol$e('species');

// `SpeciesConstructor` abstract operation: O.constructor[@@species] when present,
// otherwise `defaultConstructor`.
// https://tc39.es/ecma262/#sec-speciesconstructor
var speciesConstructor$6 = function speciesConstructor(O, defaultConstructor) {
  var C = anObject$r(O).constructor;
  if (C === undefined) return defaultConstructor;
  var S = anObject$r(C)[SPECIES$2];
  if (isNullOrUndefined$a(S)) return defaultConstructor;
  return aConstructor$2(S);
};
var speciesConstructor$7 = /*@__PURE__*/getDefaultExportFromCjs(speciesConstructor$6);
7755
'use strict';
var $$2d = _export;
var uncurryThis$V = functionUncurryThisClause;
var fails$Y = fails$1m;
var anObject$q = anObject$D;
var toAbsoluteIndex$2 = toAbsoluteIndex$a;
var toLength$9 = toLength$d;
var speciesConstructor$5 = speciesConstructor$6;

// ArrayBuffer / DataView as exported by the array-buffer module (native or polyfilled).
var ArrayBufferModule$2 = arrayBuffer;
var ArrayBuffer$3 = ArrayBufferModule$2.ArrayBuffer;
var DataView$2 = ArrayBufferModule$2.DataView;
var DataViewPrototype = DataView$2.prototype;

// Uncurried intrinsics. nativeArrayBufferSlice may be falsy; the slice implementation checks it.
var nativeArrayBufferSlice = uncurryThis$V(ArrayBuffer$3.prototype.slice);
var getUint8 = uncurryThis$V(DataViewPrototype.getUint8);
var setUint8 = uncurryThis$V(DataViewPrototype.setUint8);

// Force the polyfill when slice(1, undefined) on a 2-byte buffer produces an empty buffer
// (or throws).
var INCORRECT_SLICE = fails$Y(function () {
  return !new ArrayBuffer$3(2).slice(1, undefined).byteLength;
});
7774
// `ArrayBuffer.prototype.slice` method
// https://tc39.es/ecma262/#sec-arraybuffer.prototype.slice
$$2d({ target: 'ArrayBuffer', proto: true, unsafe: true, forced: INCORRECT_SLICE }, {
  slice: function slice(start, end) {
    // When `end` is omitted, delegate to the native slice (upstream labels this an "FF fix").
    if (nativeArrayBufferSlice && end === undefined) {
      return nativeArrayBufferSlice(anObject$q(this), start);
    }
    var byteLength = anObject$q(this).byteLength;
    var from = toAbsoluteIndex$2(start, byteLength);
    var to = toAbsoluteIndex$2(end === undefined ? byteLength : end, byteLength);
    // Allocate through the species constructor, then copy byte by byte via DataViews.
    var Species = speciesConstructor$5(this, ArrayBuffer$3);
    var copy = new Species(toLength$9(to - from));
    var source = new DataView$2(this);
    var target = new DataView$2(copy);
    for (var offset = 0; from < to; offset++, from++) {
      setUint8(target, offset, getUint8(source, from));
    }
    return copy;
  }
});
7801
var es_dataView = {};

var es_dataView_constructor = {};

var $$2c = _export;
var NATIVE_ARRAY_BUFFER = arrayBufferBasicDetection;
var ArrayBufferModule$1 = arrayBuffer;

// `DataView` constructor -- the module's DataView is forced globally only when there
// is no native ArrayBuffer support.
// https://tc39.es/ecma262/#sec-dataview-constructor
$$2c({ global: true, constructor: true, forced: !NATIVE_ARRAY_BUFFER }, { DataView: ArrayBufferModule$1.DataView });
7819
var es_date_getYear = {};

'use strict';
var $$2b = _export;
var uncurryThis$U = functionUncurryThis;
var fails$X = fails$1m;
var getFullYear = uncurryThis$U(Date.prototype.getFullYear);

// IE8- non-standard case: force the polyfill unless getYear() on 16e11 ms (a date in 2020)
// reports 120, i.e. "full year - 1900".
var FORCED$r = fails$X(function () {
  // eslint-disable-next-line es/no-date-prototype-getyear-setyear -- detection
  return new Date(16e11).getYear() !== 120;
});

// `Date.prototype.getYear` method
// https://tc39.es/ecma262/#sec-date.prototype.getyear
$$2b({ target: 'Date', proto: true, forced: FORCED$r }, {
  getYear: function getYear() {
    var fullYear = getFullYear(this);
    return fullYear - 1900;
  }
});
7845
var es_date_now = {};

// TODO: Remove from `core-js@4`
var $$2a = _export;
var uncurryThis$T = functionUncurryThis;
var $Date = Date;
var thisTimeValue$4 = uncurryThis$T($Date.prototype.getTime);

// `Date.now` method: the time value of a freshly constructed Date.
// https://tc39.es/ecma262/#sec-date.now
$$2a({ target: 'Date', stat: true }, {
  now: function now() {
    var current = new $Date();
    return thisTimeValue$4(current);
  }
});
7864
var es_date_setYear = {};

'use strict';
var $$29 = _export;
var uncurryThis$S = functionUncurryThis;
var toIntegerOrInfinity$a = toIntegerOrInfinity$l;
var DatePrototype$3 = Date.prototype;
var thisTimeValue$3 = uncurryThis$S(DatePrototype$3.getTime);
var setFullYear = uncurryThis$S(DatePrototype$3.setFullYear);

// `Date.prototype.setYear` method
// https://tc39.es/ecma262/#sec-date.prototype.setyear
// Sets the year and returns the new time value. Two-digit years 0..99 are interpreted as
// 1900..1999 (MakeFullYear), and a NaN year invalidates the date instead of meaning 1900.
$$29({
  target: 'Date',
  proto: true
}, {
  setYear: function setYear(year) {
    // validate: getTime throws a TypeError when `this` is not a Date
    thisTimeValue$3(this);
    // ToNumber(year) exactly once; unary plus throws for Symbol / BigInt as ToNumber does
    var y = +year;
    var yyyy;
    // eslint-disable-next-line no-self-compare -- NaN check
    if (y !== y) {
      // MakeFullYear(NaN) is NaN; toIntegerOrInfinity would otherwise turn it into 0 -> 1900
      yyyy = NaN;
    } else {
      var yi = toIntegerOrInfinity$a(y);
      yyyy = 0 <= yi && yi <= 99 ? yi + 1900 : yi;
    }
    return setFullYear(this, yyyy);
  }
});
7889
var es_date_toGmtString = {};

var $$28 = _export;

// `Date.prototype.toGMTString` method: an alias sharing the very same function object
// as `Date.prototype.toUTCString`.
// https://tc39.es/ecma262/#sec-date.prototype.togmtstring
$$28({ target: 'Date', proto: true }, { toGMTString: Date.prototype.toUTCString });
7902
var es_date_toIsoString = {};

'use strict';
var toIntegerOrInfinity$9 = toIntegerOrInfinity$l;
var toString$r = toString$x;
var requireObjectCoercible$g = requireObjectCoercible$j;
var $RangeError$7 = RangeError;

// `String.prototype.repeat` method implementation
// https://tc39.es/ecma262/#sec-string.prototype.repeat
// Builds the result by binary decomposition of the count: the chunk doubles each step
// and is appended whenever the current low bit is set.
// Throws RangeError for negative or infinite counts.
var stringRepeat = function repeat(count) {
  var chunk = toString$r(requireObjectCoercible$g(this));
  var times = toIntegerOrInfinity$9(count);
  if (times < 0 || times === Infinity) throw $RangeError$7('Wrong number of repetitions');
  var out = '';
  while (times > 0) {
    if (times & 1) out += chunk;
    times >>>= 1;
    if (times) chunk += chunk;
  }
  return out;
};
var stringRepeat$1 = /*@__PURE__*/getDefaultExportFromCjs(stringRepeat);
7922
// https://github.com/tc39/proposal-string-pad-start-end
var uncurryThis$R = functionUncurryThis;
var toLength$8 = toLength$d;
var toString$q = toString$x;
var $repeat$2 = stringRepeat;
var requireObjectCoercible$f = requireObjectCoercible$j;
var repeat$4 = uncurryThis$R($repeat$2);
var stringSlice$f = uncurryThis$R(''.slice);
var ceil$3 = Math.ceil;

// `String.prototype.{ padStart, padEnd }` methods implementation.
// Returns a function ($this, maxLength, fillString) that pads String($this) to maxLength
// with fillString (default ' '), prepending unless IS_END. The string is returned unchanged
// when it is already long enough or fillString is empty.
var createMethod$3 = function createMethod(IS_END) {
  return function ($this, maxLength, fillString) {
    var S = toString$q(requireObjectCoercible$f($this));
    var targetLength = toLength$8(maxLength);
    var filler = fillString === undefined ? ' ' : toString$q(fillString);
    var missing = targetLength - S.length;
    if (missing <= 0 || filler === '') return S;
    // Repeat just enough copies to cover the gap, then trim the overshoot.
    var padding = repeat$4(filler, ceil$3(missing / filler.length));
    if (padding.length > missing) padding = stringSlice$f(padding, 0, missing);
    if (IS_END) return S + padding;
    return padding + S;
  };
};
var stringPad = {
  // `String.prototype.padStart` method
  // https://tc39.es/ecma262/#sec-string.prototype.padstart
  start: createMethod$3(false),
  // `String.prototype.padEnd` method
  // https://tc39.es/ecma262/#sec-string.prototype.padend
  end: createMethod$3(true)
};
var stringPad$1 = /*@__PURE__*/getDefaultExportFromCjs(stringPad);
7957
'use strict';
var uncurryThis$Q = functionUncurryThis;
var fails$W = fails$1m;
var padStart = stringPad.start;
var $RangeError$6 = RangeError;
var $isFinite$1 = isFinite;
var abs$a = Math.abs;
var DatePrototype$2 = Date.prototype;
var nativeDateToISOString = DatePrototype$2.toISOString;
var thisTimeValue$2 = uncurryThis$Q(DatePrototype$2.getTime);
var getUTCDate = uncurryThis$Q(DatePrototype$2.getUTCDate);
var getUTCFullYear = uncurryThis$Q(DatePrototype$2.getUTCFullYear);
var getUTCHours = uncurryThis$Q(DatePrototype$2.getUTCHours);
var getUTCMilliseconds = uncurryThis$Q(DatePrototype$2.getUTCMilliseconds);
var getUTCMinutes = uncurryThis$Q(DatePrototype$2.getUTCMinutes);
var getUTCMonth = uncurryThis$Q(DatePrototype$2.getUTCMonth);
var getUTCSeconds = uncurryThis$Q(DatePrototype$2.getUTCSeconds);

// `Date.prototype.toISOString` method implementation
// https://tc39.es/ecma262/#sec-date.prototype.toisostring
// Keep the native method unless it misformats a year-385 date or fails to throw for an
// invalid date (PhantomJS / old WebKit fail here).
var dateToIsoString = nativeDateToISOString;
if (fails$W(function () {
  return nativeDateToISOString.call(new Date(-5e13 - 1)) != '0385-07-25T07:06:39.999Z';
}) || !fails$W(function () {
  nativeDateToISOString.call(new Date(NaN));
})) {
  dateToIsoString = function toISOString() {
    if (!$isFinite$1(thisTimeValue$2(this))) throw $RangeError$6('Invalid time value');
    var date = this;
    var year = getUTCFullYear(date);
    var milliseconds = getUTCMilliseconds(date);
    // Years outside 0000..9999 use the signed six-digit extended form.
    var sign = '';
    if (year < 0) sign = '-';
    else if (year > 9999) sign = '+';
    var datePart = sign + padStart(abs$a(year), sign ? 6 : 4, 0)
      + '-' + padStart(getUTCMonth(date) + 1, 2, 0)
      + '-' + padStart(getUTCDate(date), 2, 0);
    var timePart = padStart(getUTCHours(date), 2, 0)
      + ':' + padStart(getUTCMinutes(date), 2, 0)
      + ':' + padStart(getUTCSeconds(date), 2, 0);
    return datePart + 'T' + timePart + '.' + padStart(milliseconds, 3, 0) + 'Z';
  };
}
var dateToIsoString$1 = /*@__PURE__*/getDefaultExportFromCjs(dateToIsoString);
7992
var $$27 = _export;
var toISOString = dateToIsoString;

// `Date.prototype.toISOString` method
// https://tc39.es/ecma262/#sec-date.prototype.toisostring
// PhantomJS / old WebKit have broken implementations, so install `dateToIsoString`
// (forced) whenever it differs from the current Date.prototype.toISOString.
$$27({ target: 'Date', proto: true, forced: Date.prototype.toISOString !== toISOString }, { toISOString: toISOString });
8006
var es_date_toJson = {};

'use strict';
var $$26 = _export;
var fails$V = fails$1m;
var toObject$b = toObject$t;
var toPrimitive$2 = toPrimitive$4;

// Force the polyfill when toJSON on an invalid date is not null, or when toJSON does not
// delegate to the receiver's own toISOString.
var FORCED$q = fails$V(function () {
  return new Date(NaN).toJSON() !== null || Date.prototype.toJSON.call({
    toISOString: function toISOString() {
      return 1;
    }
  }) !== 1;
});

// `Date.prototype.toJSON` method
// https://tc39.es/ecma262/#sec-date.prototype.tojson
// Yields null for non-finite numeric primitives, otherwise the result of O.toISOString().
$$26({ target: 'Date', proto: true, arity: 1, forced: FORCED$q }, {
  // eslint-disable-next-line no-unused-vars -- required for `.length`
  toJSON: function toJSON(key) {
    var O = toObject$b(this);
    var pv = toPrimitive$2(O, 'number');
    if (typeof pv === 'number' && !isFinite(pv)) return null;
    return O.toISOString();
  }
});
8037
var es_date_toPrimitive = {};

'use strict';
var anObject$p = anObject$D;
var ordinaryToPrimitive = ordinaryToPrimitive$2;
var $TypeError$a = TypeError;

// `Date.prototype[@@toPrimitive](hint)` method implementation
// https://tc39.es/ecma262/#sec-date.prototype-@@toprimitive
// 'string' and 'default' both map to the string preference, 'number' to the number
// preference, and any other hint throws TypeError('Incorrect hint').
var dateToPrimitive$1 = function dateToPrimitive(hint) {
  anObject$p(this);
  var preferred;
  switch (hint) {
    case 'string':
    case 'default':
      preferred = 'string';
      break;
    case 'number':
      preferred = 'number';
      break;
    default:
      throw $TypeError$a('Incorrect hint');
  }
  return ordinaryToPrimitive(this, preferred);
};
var dateToPrimitive$2 = /*@__PURE__*/getDefaultExportFromCjs(dateToPrimitive$1);
8053
var hasOwn$g = hasOwnProperty_1;
var defineBuiltIn$d = defineBuiltIn$m;
var dateToPrimitive = dateToPrimitive$1;
var wellKnownSymbol$d = wellKnownSymbol$z;
var DatePrototype$1 = Date.prototype;
var TO_PRIMITIVE = wellKnownSymbol$d('toPrimitive');

// `Date.prototype[@@toPrimitive]` method
// https://tc39.es/ecma262/#sec-date.prototype-@@toprimitive
// Only installed when Date.prototype has no own @@toPrimitive.
if (!hasOwn$g(DatePrototype$1, TO_PRIMITIVE)) defineBuiltIn$d(DatePrototype$1, TO_PRIMITIVE, dateToPrimitive);
8066
var es_date_toString = {};

// TODO: Remove from `core-js@4`
var uncurryThis$P = functionUncurryThis;
var defineBuiltIn$c = defineBuiltIn$m;
var DatePrototype = Date.prototype;
var INVALID_DATE = 'Invalid Date';
var TO_STRING$1 = 'toString';
var nativeDateToString = uncurryThis$P(DatePrototype[TO_STRING$1]);
var thisTimeValue$1 = uncurryThis$P(DatePrototype.getTime);

// `Date.prototype.toString` method
// https://tc39.es/ecma262/#sec-date.prototype.tostring
// Patched only when an invalid date does not already stringify to 'Invalid Date'.
if (String(new Date(NaN)) !== INVALID_DATE) {
  defineBuiltIn$c(DatePrototype, TO_STRING$1, function toString() {
    var time = thisTimeValue$1(this);
    // eslint-disable-next-line no-self-compare -- NaN check
    if (time !== time) return INVALID_DATE;
    return nativeDateToString(this);
  });
}
8087
var es_escape = {};

'use strict';
var $$25 = _export;
var uncurryThis$O = functionUncurryThis;
var toString$p = toString$x;
var charAt$d = uncurryThis$O(''.charAt);
var charCodeAt$4 = uncurryThis$O(''.charCodeAt);
var exec$8 = uncurryThis$O(/./.exec);
var numberToString$1 = uncurryThis$O(1.0.toString);
var toUpperCase = uncurryThis$O(''.toUpperCase);
// Characters that `escape` copies through unchanged.
var raw = /[\w*+\-./@]/;

// Uppercase hexadecimal digits of `code`, left-padded with '0' to `length` characters.
// The spec requires uppercase for both the %XX and the %uXXXX forms.
var hex$1 = function hex(code, length) {
  var result = toUpperCase(numberToString$1(code, 16));
  while (result.length < length) result = '0' + result;
  return result;
};

// `escape` method
// https://tc39.es/ecma262/#sec-escape-string
// Code units below 256 become %XX; all others become %uXXXX.
$$25({
  global: true
}, {
  escape: function escape(string) {
    var str = toString$p(string);
    var result = '';
    var length = str.length;
    var index = 0;
    var chr, code;
    while (index < length) {
      chr = charAt$d(str, index++);
      if (exec$8(raw, chr)) {
        result += chr;
      } else {
        code = charCodeAt$4(chr, 0);
        result += code < 256 ? '%' + hex$1(code, 2) : '%u' + hex$1(code, 4);
      }
    }
    return result;
  }
});
8133
var es_function_bind = {};

'use strict';
var uncurryThis$N = functionUncurryThis;
var aCallable$c = aCallable$l;
var isObject$m = isObject$z;
var hasOwn$f = hasOwnProperty_1;
var arraySlice$6 = arraySlice$a;
var NATIVE_BIND = functionBindNative;
var $Function = Function;
var concat$5 = uncurryThis$N([].concat);
var join$6 = uncurryThis$N([].join);
// Cache of generated `new C(a[0], ..., a[n-1])` factories, keyed by argument count.
var factories = {};
// Calls `new C(...args)` through a Function-constructor factory generated once per arity.
var construct = function construct(C, argsLength, args) {
  if (!hasOwn$f(factories, argsLength)) {
    var list = [];
    for (var i = 0; i < argsLength; i++) list.push('a[' + i + ']');
    factories[argsLength] = $Function('C,a', 'return new C(' + join$6(list, ',') + ')');
  }
  return factories[argsLength](C, args);
};

// `Function.prototype.bind` method implementation
// https://tc39.es/ecma262/#sec-function.prototype.bind
// Uses the native bind when available. Otherwise the returned wrapper prepends the bound
// arguments, calls F with `that`, and constructs F when invoked with `new`.
var functionBind;
if (NATIVE_BIND) {
  // eslint-disable-next-line es/no-function-prototype-bind -- detection
  functionBind = $Function.bind;
} else {
  functionBind = function bind(that /* , ...args */) {
    var F = aCallable$c(this);
    var Prototype = F.prototype;
    var partArgs = arraySlice$6(arguments, 1);
    var boundFunction = function bound( /* args... */) {
      var args = concat$5(partArgs, arraySlice$6(arguments));
      if (this instanceof boundFunction) return construct(F, args.length, args);
      return F.apply(that, args);
    };
    // Sharing the prototype makes `this instanceof boundFunction` detect `new` calls.
    if (isObject$m(Prototype)) boundFunction.prototype = Prototype;
    return boundFunction;
  };
}
var functionBind$1 = /*@__PURE__*/getDefaultExportFromCjs(functionBind);
8171
8172 // TODO: Remove from `core-js@4`
8173 var $$24 = _export;
8174 var bind$8 = functionBind;
8175
8176 // `Function.prototype.bind` method
8177 // https://tc39.es/ecma262/#sec-function.prototype.bind
8178 // eslint-disable-next-line es/no-function-prototype-bind -- detection
8179 $$24({
8180 target: 'Function',
8181 proto: true,
8182 forced: Function.bind !== bind$8
8183 }, {
8184 bind: bind$8
8185 });
8186
var es_function_hasInstance = {};

'use strict';
var isCallable$d = isCallable$z;
var isObject$l = isObject$z;
var definePropertyModule$5 = objectDefineProperty;
var getPrototypeOf$5 = objectGetPrototypeOf$1;
var wellKnownSymbol$c = wellKnownSymbol$z;
var makeBuiltIn = makeBuiltInExports;
var HAS_INSTANCE = wellKnownSymbol$c('hasInstance');
var FunctionPrototype$1 = Function.prototype;

// `Function.prototype[@@hasInstance]` method
// https://tc39.es/ecma262/#sec-function.prototype-@@hasinstance
// Installed only when Function.prototype has no @@hasInstance key at all.
if (!(HAS_INSTANCE in FunctionPrototype$1)) {
  definePropertyModule$5.f(FunctionPrototype$1, HAS_INSTANCE, {
    value: makeBuiltIn(function (O) {
      var C = this;
      // A non-callable receiver or a primitive operand can never match.
      if (!isCallable$d(C) || !isObject$l(O)) return false;
      var proto = C.prototype;
      // Without an object `prototype`, defer to the engine's own `instanceof`.
      if (!isObject$l(proto)) return O instanceof C;
      // For engines without native @@hasInstance logic, plain `instanceof` is
      // enough; this walk covers explicit calls: look for C.prototype on O's chain.
      for (var link = getPrototypeOf$5(O); link; link = getPrototypeOf$5(link)) {
        if (link === proto) return true;
      }
      return false;
    }, HAS_INSTANCE)
  });
}
8213
var es_function_name = {};

var DESCRIPTORS$s = descriptors;
var FUNCTION_NAME_EXISTS = functionName.EXISTS;
var uncurryThis$M = functionUncurryThis;
var defineBuiltInAccessor$b = defineBuiltInAccessor$h;
var FunctionPrototype = Function.prototype;
var functionToString = uncurryThis$M(FunctionPrototype.toString);
// Captures the identifier that follows `function`, skipping any whitespace,
// block comments and line comments in between; the capture stops at
// whitespace, `(` or `/`.
var nameRE = /function\b(?:\s|\/\*[\S\s]*?\*\/|\/\/[^\n\r]*[\n\r]+)*([^\s(/]*)/;
var regExpExec$4 = uncurryThis$M(nameRE.exec);
var NAME = 'name';

// Function instances `.name` property
// https://tc39.es/ecma262/#sec-function-instances-name
// Fallback accessor, installed only when `descriptors` is truthy and
// `functionName.EXISTS` is falsy (both defined elsewhere in the bundle).
if (DESCRIPTORS$s && !FUNCTION_NAME_EXISTS) {
  defineBuiltInAccessor$b(FunctionPrototype, NAME, {
    configurable: true,
    get: function get() {
      var source;
      try {
        source = functionToString(this);
        return regExpExec$4(nameRE, source)[1];
      } catch (error) {
        // Non-function receivers, or sources without a `function` keyword
        // (exec returns null), report an empty name.
        return '';
      }
    }
  });
}
8240
var es_globalThis = {};

var $$23 = _export;
var global$H = global$Z;

// `globalThis` object
// https://tc39.es/ecma262/#sec-globalthis
// Publishes the bundle's detected global object (`global$Z`, defined elsewhere)
// as the global `globalThis` binding. Registration is forced whenever that
// object's existing `globalThis` property is not the object itself.
$$23({
  global: true,
  forced: global$H.globalThis !== global$H
}, {
  globalThis: global$H
});
8254
var es_json_toStringTag = {};

var global$G = global$Z;
var setToStringTag$7 = setToStringTag$d;

// JSON[@@toStringTag] property
// https://tc39.es/ecma262/#sec-json-@@tostringtag
// Tags the global `JSON` object with 'JSON' via the shared `setToStringTag`
// helper (defined elsewhere). NOTE(review): the `true` argument is presumably
// that helper's "static" flag — confirm against its definition.
setToStringTag$7(global$G.JSON, 'JSON', true);
8263
var es_map = {};

var es_map_constructor = {};

var internalMetadata$2 = {exports: {}};

// FF26- bug: ArrayBuffers are non-extensible, but Object.isExtensible does not report it.
// The probe defines a property on a fresh ArrayBuffer that `Object.isExtensible`
// claims is extensible. The flag is whatever `fails$1m` (defined elsewhere)
// reports for the probe, presumably true when defining the property throws.
var fails$U = fails$1m;
var arrayBufferNonExtensible = fails$U(function () {
  if (typeof ArrayBuffer != 'function') return;
  var buffer = new ArrayBuffer(8);
  // eslint-disable-next-line es/no-object-isextensible -- safe
  if (!Object.isExtensible(buffer)) return;
  // eslint-disable-next-line es/no-object-defineproperty -- safe
  Object.defineProperty(buffer, 'a', {
    value: 8
  });
});
var arrayBufferNonExtensible$1 = /*@__PURE__*/getDefaultExportFromCjs(arrayBufferNonExtensible);
8282
var fails$T = fails$1m;
var isObject$k = isObject$z;
var classof$d = classofRaw$2;
var ARRAY_BUFFER_NON_EXTENSIBLE$2 = arrayBufferNonExtensible;

// eslint-disable-next-line es/no-object-isextensible -- safe
var $isExtensible$2 = Object.isExtensible;
// Probe: does the native method throw when handed a primitive?
var FAILS_ON_PRIMITIVES$6 = fails$T(function () {
  $isExtensible$2(1);
});

// `Object.isExtensible` method
// https://tc39.es/ecma262/#sec-object.isextensible
// The native method is used unless one of the probes above flags a problem.
// In that case a wrapper is used: primitives -> false, ArrayBuffers (when the
// FF26 bug is present) -> false, everything else -> the native answer (or true
// if the native method is missing).
var objectIsExtensible = $isExtensible$2;
if (FAILS_ON_PRIMITIVES$6 || ARRAY_BUFFER_NON_EXTENSIBLE$2) {
  objectIsExtensible = function isExtensible(it) {
    if (!isObject$k(it) || (ARRAY_BUFFER_NON_EXTENSIBLE$2 && classof$d(it) == 'ArrayBuffer')) return false;
    return $isExtensible$2 ? $isExtensible$2(it) : true;
  };
}
var objectIsExtensible$1 = /*@__PURE__*/getDefaultExportFromCjs(objectIsExtensible);
8302
var fails$S = fails$1m;
// Negation of `fails$1m` (defined elsewhere) applied to a probe returning
// `Object.isExtensible(Object.preventExtensions({}))`. Assuming `fails` reports
// true for a truthy probe result or a throw, `freezing` is true only when
// `preventExtensions` really makes an object non-extensible.
var freezing = !fails$S(function () {
  // eslint-disable-next-line es/no-object-isextensible, es/no-object-preventextensions -- required for testing
  return Object.isExtensible(Object.preventExtensions({}));
});
var freezing$1 = /*@__PURE__*/getDefaultExportFromCjs(freezing);
8309
var internalMetadata = internalMetadata$2.exports;
var $$22 = _export;
var uncurryThis$L = functionUncurryThis;
var hiddenKeys = hiddenKeys$6;
var isObject$j = isObject$z;
var hasOwn$e = hasOwnProperty_1;
var defineProperty$6 = objectDefineProperty.f;
var getOwnPropertyNamesModule = objectGetOwnPropertyNames;
var getOwnPropertyNamesExternalModule = objectGetOwnPropertyNamesExternal;
var isExtensible$1 = objectIsExtensible;
var uid$1 = uid$6;
var FREEZING$5 = freezing;
// Flipped to true by `enable()`; until then `onFreeze` attaches no metadata.
var REQUIRED = false;
// Unique property key under which each object's metadata record is stored.
var METADATA = uid$1('meta');
// Counter used to mint object IDs ('O0', 'O1', ...).
var id$1 = 0;
/**
 * Attaches a fresh metadata record `{ objectID, weakData }` to `it` under
 * METADATA. The descriptor passed to `objectDefineProperty.f` only specifies
 * `value`.
 */
var setMetadata = function setMetadata(it) {
  defineProperty$6(it, METADATA, {
    value: {
      objectID: 'O' + id$1++,
      // object ID
      weakData: {} // weak collections IDs
    }
  });
};

/**
 * Returns a primitive lookup key for `it` (used to index collection entries).
 * - Non-objects: symbols are returned as-is; strings get the prefix 'S' and
 *   every other primitive the prefix 'P', so 'S1' and 'P1' never collide.
 * - Objects without metadata: 'F' if not extensible (cannot be tagged),
 *   'E' if `create` is falsy; otherwise metadata is attached first.
 * - Tagged objects: their stored `objectID`.
 * @param {*} it - key value.
 * @param {boolean} [create] - attach metadata to an untagged extensible object.
 * @returns {string|symbol}
 */
var fastKey$1 = function fastKey(it, create) {
  // return a primitive with prefix
  if (!isObject$j(it)) return _typeof(it) == 'symbol' ? it : (typeof it == 'string' ? 'S' : 'P') + it;
  if (!hasOwn$e(it, METADATA)) {
    // can't set metadata to uncaught frozen object
    if (!isExtensible$1(it)) return 'F';
    // not necessary to add metadata
    if (!create) return 'E';
    // add missing metadata
    setMetadata(it);
    // return object ID
  }
  return it[METADATA].objectID;
};
/**
 * Returns the per-object store of weak-collection IDs for object `it`.
 * Returns `true` for untagged non-extensible objects and `false` for untagged
 * objects when `create` is falsy; otherwise it tags the object if needed and
 * returns its `weakData` record.
 */
var getWeakData$1 = function getWeakData(it, create) {
  if (!hasOwn$e(it, METADATA)) {
    // can't set metadata to uncaught frozen object
    if (!isExtensible$1(it)) return true;
    // not necessary to add metadata
    if (!create) return false;
    // add missing metadata
    setMetadata(it);
    // return the store of weak collections IDs
  }
  return it[METADATA].weakData;
};

// add metadata on freeze-family methods calling
// Tags `it` before it loses extensibility, so it can still be keyed afterwards.
// This only happens once `enable()` has run and the engine supports
// preventExtensions (FREEZING). Always returns `it`.
var onFreeze$3 = function onFreeze(it) {
  if (FREEZING$5 && REQUIRED && isExtensible$1(it) && !hasOwn$e(it, METADATA)) setMetadata(it);
  return it;
};
/**
 * Turns metadata tagging on: replaces itself with a no-op and sets REQUIRED.
 * If the METADATA key would show up in `getOwnPropertyNames`, it wraps
 * `getOwnPropertyNamesModule.f` to splice that key out. It then force-exports
 * `Object.getOwnPropertyNames` from `objectGetOwnPropertyNamesExternal`.
 * NOTE(review): that external module presumably builds on the patched `.f` —
 * confirm.
 */
var enable = function enable() {
  meta.enable = function () {/* empty */};
  REQUIRED = true;
  var getOwnPropertyNames = getOwnPropertyNamesModule.f;
  var splice = uncurryThis$L([].splice);
  var test = {};
  test[METADATA] = 1;

  // prevent exposing of metadata key
  if (getOwnPropertyNames(test).length) {
    getOwnPropertyNamesModule.f = function (it) {
      var result = getOwnPropertyNames(it);
      for (var i = 0, length = result.length; i < length; i++) {
        if (result[i] === METADATA) {
          splice(result, i, 1);
          break;
        }
      }
      return result;
    };
    $$22({
      target: 'Object',
      stat: true,
      forced: true
    }, {
      getOwnPropertyNames: getOwnPropertyNamesExternalModule.f
    });
  }
};
// Public surface of this internal module; also referenced by `enable` above.
var meta = internalMetadata$2.exports = {
  enable: enable,
  fastKey: fastKey$1,
  getWeakData: getWeakData$1,
  onFreeze: onFreeze$3
};
// Register the metadata key in the shared hidden-keys table (presumably
// excluded from key enumeration by other helpers — confirm).
hiddenKeys[METADATA] = true;
var internalMetadataExports = internalMetadata$2.exports;
var internalMetadata$1 = /*@__PURE__*/getDefaultExportFromCjs(internalMetadataExports);
8405
'use strict';
var $$21 = _export;
var global$F = global$Z;
var uncurryThis$K = functionUncurryThis;
var isForced$3 = isForced_1;
var defineBuiltIn$b = defineBuiltIn$m;
var InternalMetadataModule$1 = internalMetadataExports;
var iterate$8 = iterate$a;
var anInstance$8 = anInstance$a;
var isCallable$c = isCallable$z;
var isNullOrUndefined$9 = isNullOrUndefined$e;
var isObject$i = isObject$z;
var fails$R = fails$1m;
var checkCorrectnessOfIteration$2 = checkCorrectnessOfIteration$4;
var setToStringTag$6 = setToStringTag$d;
var inheritIfRequired$4 = inheritIfRequired$6;
/**
 * Installs (or repairs) the global keyed collection named CONSTRUCTOR_NAME.
 * The kind is derived from the name: a 'Map' substring means set/get
 * semantics, 'Weak' means a weak collection.
 * @param {string} CONSTRUCTOR_NAME - e.g. 'Map', 'Set', 'WeakMap', 'WeakSet'.
 * @param {Function} wrapper - receives an init function and returns a constructor.
 * @param {object} common - implementation provider exposing `getConstructor`
 *   and, for strong collections, `setStrong`.
 * @returns {Function} the constructor that was exported.
 */
var collection$4 = function collection(CONSTRUCTOR_NAME, wrapper, common) {
  var IS_MAP = CONSTRUCTOR_NAME.indexOf('Map') !== -1;
  var IS_WEAK = CONSTRUCTOR_NAME.indexOf('Weak') !== -1;
  var ADDER = IS_MAP ? 'set' : 'add';
  var NativeConstructor = global$F[CONSTRUCTOR_NAME];
  var NativePrototype = NativeConstructor && NativeConstructor.prototype;
  var Constructor = NativeConstructor;
  var exported = {};
  // Overrides NativePrototype[KEY] with a wrapper around the native method.
  // The wrapper turns a -0 key/value into +0. For weak collections it also
  // answers non-object keys directly (false / undefined) instead of calling
  // the native method.
  var fixMethod = function fixMethod(KEY) {
    var uncurriedNativeMethod = uncurryThis$K(NativePrototype[KEY]);
    defineBuiltIn$b(NativePrototype, KEY, KEY == 'add' ? function add(value) {
      uncurriedNativeMethod(this, value === 0 ? 0 : value);
      return this;
    } : KEY == 'delete' ? function (key) {
      return IS_WEAK && !isObject$i(key) ? false : uncurriedNativeMethod(this, key === 0 ? 0 : key);
    } : KEY == 'get' ? function get(key) {
      return IS_WEAK && !isObject$i(key) ? undefined : uncurriedNativeMethod(this, key === 0 ? 0 : key);
    } : KEY == 'has' ? function has(key) {
      return IS_WEAK && !isObject$i(key) ? false : uncurriedNativeMethod(this, key === 0 ? 0 : key);
    } : function set(key, value) {
      uncurriedNativeMethod(this, key === 0 ? 0 : key, value);
      return this;
    });
  };
  // Full replacement when `isForced` says so. The default argument is true if
  // the native constructor is not callable, or (for strong collections) if it
  // lacks `forEach` or throws on a basic `entries().next()` probe.
  var REPLACE = isForced$3(CONSTRUCTOR_NAME, !isCallable$c(NativeConstructor) || !(IS_WEAK || NativePrototype.forEach && !fails$R(function () {
    new NativeConstructor().entries().next();
  })));
  if (REPLACE) {
    // create collection constructor
    Constructor = common.getConstructor(wrapper, CONSTRUCTOR_NAME, IS_MAP, ADDER);
    InternalMetadataModule$1.enable();
  } else if (isForced$3(CONSTRUCTOR_NAME, true)) {
    // Otherwise keep the native constructor and patch only the detected defects.
    var instance = new Constructor();
    // early implementations do not support chaining (ADDER must return `this`)
    var HASNT_CHAINING = instance[ADDER](IS_WEAK ? {} : -0, 1) != instance;
    // V8 ~ Chromium 40- weak-collections throws on primitives, but should return false
    var THROWS_ON_PRIMITIVES = fails$R(function () {
      instance.has(1);
    });
    // most early implementations do not support iterables; most modern ones do not close them correctly
    // eslint-disable-next-line no-new -- required for testing
    var ACCEPT_ITERABLES = checkCorrectnessOfIteration$2(function (iterable) {
      new NativeConstructor(iterable);
    });
    // for early implementations -0 and +0 are not the same
    var BUGGY_ZERO = !IS_WEAK && fails$R(function () {
      // V8 ~ Chromium 42- fails only with 5+ elements
      var $instance = new NativeConstructor();
      var index = 5;
      while (index--) $instance[ADDER](index, index);
      return !$instance.has(-0);
    });
    if (!ACCEPT_ITERABLES) {
      // Wrap the native constructor so an optional iterable is fed through ADDER.
      Constructor = wrapper(function (dummy, iterable) {
        anInstance$8(dummy, NativePrototype);
        var that = inheritIfRequired$4(new NativeConstructor(), dummy, Constructor);
        if (!isNullOrUndefined$9(iterable)) iterate$8(iterable, that[ADDER], {
          that: that,
          AS_ENTRIES: IS_MAP
        });
        return that;
      });
      Constructor.prototype = NativePrototype;
      NativePrototype.constructor = Constructor;
    }
    if (THROWS_ON_PRIMITIVES || BUGGY_ZERO) {
      fixMethod('delete');
      fixMethod('has');
      IS_MAP && fixMethod('get');
    }
    if (BUGGY_ZERO || HASNT_CHAINING) fixMethod(ADDER);

    // weak collections should not have a .clear method
    if (IS_WEAK && NativePrototype.clear) delete NativePrototype.clear;
  }
  exported[CONSTRUCTOR_NAME] = Constructor;
  // Export globally; forced only when a replacement/wrapper constructor was built.
  $$21({
    global: true,
    constructor: true,
    forced: Constructor != NativeConstructor
  }, exported);
  setToStringTag$6(Constructor, CONSTRUCTOR_NAME);
  if (!IS_WEAK) common.setStrong(Constructor, CONSTRUCTOR_NAME, IS_MAP);
  return Constructor;
};
var collection$5 = /*@__PURE__*/getDefaultExportFromCjs(collection$4);
8508
'use strict';
var create$6 = objectCreate;
var defineBuiltInAccessor$a = defineBuiltInAccessor$h;
var defineBuiltIns$3 = defineBuiltIns$5;
var bind$7 = functionBindContext;
var anInstance$7 = anInstance$a;
var isNullOrUndefined$8 = isNullOrUndefined$e;
var iterate$7 = iterate$a;
var defineIterator$1 = iteratorDefine;
var createIterResultObject$2 = createIterResultObject$4;
var setSpecies$3 = setSpecies$6;
var DESCRIPTORS$r = descriptors;
var fastKey = internalMetadataExports.fastKey;
var InternalStateModule$8 = internalState;
var setInternalState$8 = InternalStateModule$8.set;
var internalStateGetterFor$1 = InternalStateModule$8.getterFor;

/**
 * Shared implementation of the strong (iterable) keyed collections, passed as
 * the `common` argument of the `collection` factory (e.g. for 'Map').
 *
 * Per-instance internal state:
 * - `first` / `last`: ends of a doubly linked list of entries in insertion order;
 * - `index`: prototype-less dictionary from `fastKey(key)` to the entry;
 * - `size`: entry count (a plain own `size` property when `descriptors` is falsy).
 * Deleted entries are flagged `removed: true` and keep their `previous` link,
 * so a live `forEach` or iterator can step back to a surviving entry.
 */
var collectionStrong$2 = {
  /**
   * Builds the constructor and installs the prototype methods.
   * @param {Function} wrapper - receives the init function, returns the public constructor.
   * @param {string} CONSTRUCTOR_NAME - collection name, also the internal-state type tag.
   * @param {boolean} IS_MAP - install `get`/`set` (true) or `add` (false).
   * @param {string} ADDER - method used to ingest the optional constructor iterable.
   * @returns {Function} the constructor.
   */
  getConstructor: function getConstructor(wrapper, CONSTRUCTOR_NAME, IS_MAP, ADDER) {
    var Constructor = wrapper(function (that, iterable) {
      anInstance$7(that, Prototype);
      setInternalState$8(that, {
        type: CONSTRUCTOR_NAME,
        index: create$6(null),
        first: undefined,
        last: undefined,
        size: 0
      });
      // Without accessor support `size` lives on the instance and is kept in sync manually.
      if (!DESCRIPTORS$r) that.size = 0;
      if (!isNullOrUndefined$8(iterable)) iterate$7(iterable, that[ADDER], {
        that: that,
        AS_ENTRIES: IS_MAP
      });
    });
    var Prototype = Constructor.prototype;
    var getInternalState = internalStateGetterFor$1(CONSTRUCTOR_NAME);
    // Inserts `key` -> `value` at the end of the list, or overwrites the value
    // of an existing entry in place (keeping its position). Returns `that`.
    var define = function define(that, key, value) {
      var state = getInternalState(that);
      var entry = getEntry(that, key);
      var previous, index;
      // change existing entry
      if (entry) {
        entry.value = value;
        // create new entry
      } else {
        state.last = entry = {
          index: index = fastKey(key, true),
          key: key,
          value: value,
          previous: previous = state.last,
          next: undefined,
          removed: false
        };
        if (!state.first) state.first = entry;
        if (previous) previous.next = entry;
        if (DESCRIPTORS$r) state.size++;else that.size++;
        // add to index ('F' keys cannot be indexed and are found by a list scan)
        if (index !== 'F') state.index[index] = entry;
      }
      return that;
    };
    // Returns the live entry for `key`, or undefined.
    var getEntry = function getEntry(that, key) {
      var state = getInternalState(that);
      // fast case
      var index = fastKey(key);
      var entry;
      if (index !== 'F') return state.index[index];
      // frozen object case: the key could not be tagged, so scan the list.
      // Identity (`===`) is required: `==` would coerce the object key through
      // valueOf/toString (running user code) and could wrongly match a
      // primitive key, e.g. 'a' for a frozen object whose toString returns 'a'.
      for (entry = state.first; entry; entry = entry.next) {
        if (entry.key === key) return entry;
      }
    };
    defineBuiltIns$3(Prototype, {
      // `{ Map, Set }.prototype.clear()` methods
      // https://tc39.es/ecma262/#sec-map.prototype.clear
      // https://tc39.es/ecma262/#sec-set.prototype.clear
      clear: function clear() {
        var that = this;
        var state = getInternalState(that);
        var data = state.index;
        var entry = state.first;
        while (entry) {
          entry.removed = true;
          // Drop backward links (and the predecessor's forward link); `entry.next`
          // is kept because the loop follows it.
          if (entry.previous) entry.previous = entry.previous.next = undefined;
          delete data[entry.index];
          entry = entry.next;
        }
        state.first = state.last = undefined;
        if (DESCRIPTORS$r) state.size = 0;else that.size = 0;
      },
      // `{ Map, Set }.prototype.delete(key)` methods
      // https://tc39.es/ecma262/#sec-map.prototype.delete
      // https://tc39.es/ecma262/#sec-set.prototype.delete
      'delete': function _delete(key) {
        var that = this;
        var state = getInternalState(that);
        var entry = getEntry(that, key);
        if (entry) {
          var next = entry.next;
          var prev = entry.previous;
          delete state.index[entry.index];
          entry.removed = true;
          if (prev) prev.next = next;
          if (next) next.previous = prev;
          if (state.first === entry) state.first = next;
          if (state.last === entry) state.last = prev;
          if (DESCRIPTORS$r) state.size--;else that.size--;
        }
        return !!entry;
      },
      // `{ Map, Set }.prototype.forEach(callbackfn, thisArg = undefined)` methods
      // https://tc39.es/ecma262/#sec-map.prototype.foreach
      // https://tc39.es/ecma262/#sec-set.prototype.foreach
      forEach: function forEach(callbackfn /* , that = undefined */) {
        var state = getInternalState(this);
        var boundFunction = bind$7(callbackfn, arguments.length > 1 ? arguments[1] : undefined);
        var entry;
        while (entry = entry ? entry.next : state.first) {
          boundFunction(entry.value, entry.key, this);
          // revert to the last existing entry (the callback may have deleted entries)
          while (entry && entry.removed) entry = entry.previous;
        }
      },
      // `{ Map, Set}.prototype.has(key)` methods
      // https://tc39.es/ecma262/#sec-map.prototype.has
      // https://tc39.es/ecma262/#sec-set.prototype.has
      has: function has(key) {
        return !!getEntry(this, key);
      }
    });
    defineBuiltIns$3(Prototype, IS_MAP ? {
      // `Map.prototype.get(key)` method
      // https://tc39.es/ecma262/#sec-map.prototype.get
      get: function get(key) {
        var entry = getEntry(this, key);
        return entry && entry.value;
      },
      // `Map.prototype.set(key, value)` method
      // https://tc39.es/ecma262/#sec-map.prototype.set
      // -0 keys are normalized to +0.
      set: function set(key, value) {
        return define(this, key === 0 ? 0 : key, value);
      }
    } : {
      // `Set.prototype.add(value)` method
      // https://tc39.es/ecma262/#sec-set.prototype.add
      // A Set stores each value as both key and value; -0 is normalized to +0.
      add: function add(value) {
        return define(this, value = value === 0 ? 0 : value, value);
      }
    });
    if (DESCRIPTORS$r) defineBuiltInAccessor$a(Prototype, 'size', {
      configurable: true,
      get: function get() {
        return getInternalState(this).size;
      }
    });
    return Constructor;
  },
  /**
   * Adds the iteration protocol (`keys`, `values`, `entries`, @@iterator) and
   * the @@species accessor to a strong collection constructor.
   * @param {Function} Constructor - constructor returned by `getConstructor`.
   * @param {string} CONSTRUCTOR_NAME - collection name ('<name> Iterator' tags iterators).
   * @param {boolean} IS_MAP - default iteration kind is 'entries' (true) or 'values'.
   */
  setStrong: function setStrong(Constructor, CONSTRUCTOR_NAME, IS_MAP) {
    var ITERATOR_NAME = CONSTRUCTOR_NAME + ' Iterator';
    var getInternalCollectionState = internalStateGetterFor$1(CONSTRUCTOR_NAME);
    var getInternalIteratorState = internalStateGetterFor$1(ITERATOR_NAME);
    // `{ Map, Set }.prototype.{ keys, values, entries, @@iterator }()` methods
    // https://tc39.es/ecma262/#sec-map.prototype.entries
    // https://tc39.es/ecma262/#sec-map.prototype.keys
    // https://tc39.es/ecma262/#sec-map.prototype.values
    // https://tc39.es/ecma262/#sec-map.prototype-@@iterator
    // https://tc39.es/ecma262/#sec-set.prototype.entries
    // https://tc39.es/ecma262/#sec-set.prototype.keys
    // https://tc39.es/ecma262/#sec-set.prototype.values
    // https://tc39.es/ecma262/#sec-set.prototype-@@iterator
    defineIterator$1(Constructor, CONSTRUCTOR_NAME, function (iterated, kind) {
      setInternalState$8(this, {
        type: ITERATOR_NAME,
        target: iterated,
        state: getInternalCollectionState(iterated),
        kind: kind,
        last: undefined
      });
    }, function () {
      var state = getInternalIteratorState(this);
      var kind = state.kind;
      var entry = state.last;
      // revert to the last existing entry
      while (entry && entry.removed) entry = entry.previous;
      // get next entry
      if (!state.target || !(state.last = entry = entry ? entry.next : state.state.first)) {
        // or finish the iteration (dropping `target` makes later calls finish too)
        state.target = undefined;
        return createIterResultObject$2(undefined, true);
      }
      // return step by kind
      if (kind === 'keys') return createIterResultObject$2(entry.key, false);
      if (kind === 'values') return createIterResultObject$2(entry.value, false);
      return createIterResultObject$2([entry.key, entry.value], false);
    }, IS_MAP ? 'entries' : 'values', !IS_MAP, true);

    // `{ Map, Set }.prototype[@@species]` accessors
    // https://tc39.es/ecma262/#sec-get-map-@@species
    // https://tc39.es/ecma262/#sec-get-set-@@species
    setSpecies$3(CONSTRUCTOR_NAME);
  }
};
var collectionStrong$3 = /*@__PURE__*/getDefaultExportFromCjs(collectionStrong$2);
8711
'use strict';
var collection$3 = collection$4;
var collectionStrong$1 = collectionStrong$2;

// `Map` constructor
// https://tc39.es/ecma262/#sec-map-objects
// The wrapper forwards its optional first argument (the initial iterable)
// to the `init` function supplied by the collection implementation.
collection$3('Map', function (init) {
  return function Map() {
    return init(this, arguments.length ? arguments[0] : undefined);
  };
}, collectionStrong$1);
8723
var es_math_acosh = {};

var log$b = Math.log;

// `Math.log1p` method implementation
// https://tc39.es/ecma262/#sec-math.log1p
// Fallback: for |n| < 1e-8 use the series n - n^2/2, because 1 + n rounds to
// 1 in double precision and log(1 + n) would lose all of n.
// eslint-disable-next-line es/no-math-log1p -- safe
var mathLog1p = Math.log1p || function log1p(x) {
  var n = +x;
  var isTiny = n > -1e-8 && n < 1e-8;
  if (isTiny) return n - n * n / 2;
  return log$b(1 + n);
};
var mathLog1p$1 = /*@__PURE__*/getDefaultExportFromCjs(mathLog1p);
8736
var $$20 = _export;
var log1p$4 = mathLog1p;

// eslint-disable-next-line es/no-math-acosh -- required for testing
var $acosh = Math.acosh;
var log$a = Math.log;
var sqrt$5 = Math.sqrt;
var LN2$1 = Math.LN2;
var FORCED$p = !$acosh
  // V8 bug: https://code.google.com/p/v8/issues/detail?id=3509
  || Math.floor($acosh(Number.MAX_VALUE)) != 710
  // Tor Browser bug: Math.acosh(Infinity) -> NaN
  || $acosh(Infinity) != Infinity;

// `Math.acosh` method
// https://tc39.es/ecma262/#sec-math.acosh
$$20({
  target: 'Math',
  stat: true,
  forced: FORCED$p
}, {
  acosh: function acosh(x) {
    var n = +x;
    // Below the domain [1, Infinity]; NaN input falls through to log1p(NaN).
    if (n < 1) return NaN;
    // Large n: use the asymptotic form ln(2n) = ln(n) + ln 2, which also
    // keeps the sum below from overflowing near Number.MAX_VALUE.
    if (n > 94906265.62425156) return log$a(n) + LN2$1;
    // acosh(n) = ln(n + sqrt(n^2 - 1)) = log1p((n - 1) + sqrt(n - 1) * sqrt(n + 1))
    return log1p$4(n - 1 + sqrt$5(n - 1) * sqrt$5(n + 1));
  }
});
8763
var es_math_asinh = {};

var $$1$ = _export;

// eslint-disable-next-line es/no-math-asinh -- required for testing
var $asinh = Math.asinh;
var log$9 = Math.log;
var sqrt$4 = Math.sqrt;
/**
 * asinh(x) = ln(x + sqrt(x^2 + 1)), evaluated on |x| and mirrored through odd
 * symmetry for negative inputs. Signed zeros, NaN and +/-Infinity are
 * returned unchanged.
 */
function asinh$3(x) {
  var n = +x;
  if (!isFinite(n) || n == 0) return n;
  if (n < 0) return -asinh$3(-n);
  return log$9(n + sqrt$4(n * n + 1));
}
// Tor Browser bug: Math.asinh(0) -> -0
var FORCED$o = !($asinh && 1 / $asinh(0) > 0);

// `Math.asinh` method
// https://tc39.es/ecma262/#sec-math.asinh
$$1$({
  target: 'Math',
  stat: true,
  forced: FORCED$o
}, {
  asinh: asinh$3
});
8788
var es_math_atanh = {};

var $$1_ = _export;

// eslint-disable-next-line es/no-math-atanh -- required for testing
var $atanh = Math.atanh;
var log$8 = Math.log;
// Tor Browser bug: Math.atanh(-0) -> 0
var FORCED$n = !($atanh && 1 / $atanh(-0) < 0);

// `Math.atanh` method
// https://tc39.es/ecma262/#sec-math.atanh
$$1_({
  target: 'Math',
  stat: true,
  forced: FORCED$n
}, {
  atanh: function atanh(x) {
    var n = +x;
    // Return signed zeros as-is (the formula would turn -0 into +0).
    if (n == 0) return n;
    // atanh(n) = ln((1 + n) / (1 - n)) / 2
    return log$8((1 + n) / (1 - n)) / 2;
  }
});
8811
var es_math_cbrt = {};

// `Math.sign` method implementation
// https://tc39.es/ecma262/#sec-math.sign
// eslint-disable-next-line es/no-math-sign -- safe
var mathSign = Math.sign || function sign(x) {
  var n = +x;
  // Zeros (keeping their sign) and NaN pass straight through.
  // eslint-disable-next-line no-self-compare -- NaN check
  var passThrough = n == 0 || n != n;
  if (passThrough) return n;
  return n < 0 ? -1 : 1;
};
var mathSign$1 = /*@__PURE__*/getDefaultExportFromCjs(mathSign);
8823
var $$1Z = _export;
var sign$6 = mathSign;
var abs$9 = Math.abs;
var pow$8 = Math.pow;

// `Math.cbrt` method
// https://tc39.es/ecma262/#sec-math.cbrt
// The cube root is taken of |n| (pow returns NaN for negative bases with a
// fractional exponent) and the sign is restored afterwards.
$$1Z({
  target: 'Math',
  stat: true
}, {
  cbrt: function cbrt(x) {
    var n = +x;
    var magnitude = pow$8(abs$9(n), 1 / 3);
    return sign$6(n) * magnitude;
  }
});
8840
var es_math_clz32 = {};

var $$1Y = _export;
var floor$a = Math.floor;
var log$7 = Math.log;
var LOG2E = Math.LOG2E;

// `Math.clz32` method
// https://tc39.es/ecma262/#sec-math.clz32
$$1Y({
  target: 'Math',
  stat: true
}, {
  clz32: function clz32(x) {
    var bits = x >>> 0;
    if (!bits) return 32;
    // floor(log2(bits)) is the index of the highest set bit. The +0.5 moves
    // the argument off exact powers of two, presumably so rounding error in
    // log() cannot land just below the integer.
    return 31 - floor$a(log$7(bits + 0.5) * LOG2E);
  }
});
8859
var es_math_cosh = {};

// eslint-disable-next-line es/no-math-expm1 -- safe
var $expm1 = Math.expm1;
var exp$5 = Math.exp;

// `Math.expm1` method implementation
// https://tc39.es/ecma262/#sec-math.expm1
// The native method is kept unless it is missing or fails a probe.
var mathExpm1 = $expm1;
if (!$expm1
  // Old FF bug
  || $expm1(10) > 22025.465794806719 || $expm1(10) < 22025.4657948067165168
  // Tor Browser bug
  || $expm1(-2e-17) != -2e-17) {
  mathExpm1 = function expm1(x) {
    var n = +x;
    // Signed zeros pass through unchanged.
    if (n == 0) return n;
    // Near zero use n + n^2/2; exp(n) - 1 would cancel catastrophically.
    if (n > -1e-6 && n < 1e-6) return n + n * n / 2;
    return exp$5(n) - 1;
  };
}
var mathExpm1$1 = /*@__PURE__*/getDefaultExportFromCjs(mathExpm1);
8877
var $$1X = _export;
var expm1$6 = mathExpm1;

// eslint-disable-next-line es/no-math-cosh -- required for testing
var $cosh = Math.cosh;
var abs$8 = Math.abs;
var E$1 = Math.E;
// A native cosh that overflows at 710 (the true value is still finite) is replaced.
var FORCED$m = !$cosh || $cosh(710) === Infinity;

// `Math.cosh` method
// https://tc39.es/ecma262/#sec-math.cosh
$$1X({
  target: 'Math',
  stat: true,
  forced: FORCED$m
}, {
  cosh: function cosh(x) {
    // t = e^(|x| - 1); working one exponent lower keeps e^|x| from overflowing,
    // and the final factor e/2 restores the scale:
    // cosh(x) = (e^|x| + e^-|x|) / 2 = (t + 1 / (t * e^2)) * e / 2.
    var t = expm1$6(abs$8(x) - 1) + 1;
    var halfE = E$1 / 2;
    return (t + 1 / (t * E$1 * E$1)) * halfE;
  }
});
8899
var es_math_expm1 = {};

var $$1W = _export;
var expm1$5 = mathExpm1;

// `Math.expm1` method
// https://tc39.es/ecma262/#sec-math.expm1
// `mathExpm1` is either the native method or its fallback. Registration is
// forced exactly when it differs from the current `Math.expm1`, i.e. when the
// fallback was selected.
// eslint-disable-next-line es/no-math-expm1 -- required for testing
$$1W({
  target: 'Math',
  stat: true,
  forced: expm1$5 != Math.expm1
}, {
  expm1: expm1$5
});
8915
var es_math_fround = {};

var sign$5 = mathSign;
var abs$7 = Math.abs;
var pow$7 = Math.pow;
// Spacing of float64 values at 1 (2^-52) and of float32 values at 1 (2^-23).
var EPSILON = pow$7(2, -52);
var EPSILON32 = pow$7(2, -23);
// Largest finite float32: 2^127 * (2 - 2^-23).
var MAX32 = pow$7(2, 127) * (2 - EPSILON32);
// Smallest positive normal float32: 2^-126.
var MIN32 = pow$7(2, -126);
/**
 * Adds and subtracts 1 / EPSILON (2^52). Double-precision rounding (ties to
 * even) therefore discards the fractional part of `n`. Presumably intended
 * for 0 <= n < 2^52, which is how it is used below.
 */
var roundTiesToEven = function roundTiesToEven(n) {
  return n + 1 / EPSILON - 1 / EPSILON;
};

// `Math.fround` method implementation
// https://tc39.es/ecma262/#sec-math.fround
// eslint-disable-next-line es/no-math-fround -- safe
var mathFround = Math.fround || function fround(x) {
  var n = +x;
  var $abs = abs$7(n);
  var $sign = sign$5(n);
  var a, result;
  // Below the normal range: round to a multiple of the float32 subnormal
  // step MIN32 * EPSILON32 (2^-149), then scale back.
  if ($abs < MIN32) return $sign * roundTiesToEven($abs / MIN32 / EPSILON32) * MIN32 * EPSILON32;
  // Multiplying by (1 + 2^29) and subtracting back leaves $abs rounded to
  // 53 - 29 = 24 significant bits (the float32 significand width).
  a = (1 + EPSILON32 / EPSILON) * $abs;
  result = a - (a - $abs);
  // Overflow past the float32 range (or a NaN from Infinity - Infinity) maps to a signed Infinity.
  // eslint-disable-next-line no-self-compare -- NaN check
  if (result > MAX32 || result != result) return $sign * Infinity;
  return $sign * result;
};
var mathFround$1 = /*@__PURE__*/getDefaultExportFromCjs(mathFround);
8945
var $$1V = _export;
var fround = mathFround;

// `Math.fround` method
// https://tc39.es/ecma262/#sec-math.fround
// Registers `mathFround` (native `Math.fround` when present, otherwise the
// fallback) as `Math.fround`; no `forced` flag is passed.
$$1V({
  target: 'Math',
  stat: true
}, {
  fround: fround
});
8957
var es_math_hypot = {};

var $$1U = _export;

// eslint-disable-next-line es/no-math-hypot -- required for testing
var $hypot = Math.hypot;
var abs$6 = Math.abs;
var sqrt$3 = Math.sqrt;

// Chrome 77 bug
// https://bugs.chromium.org/p/v8/issues/detail?id=9546
var FORCED$l = !!$hypot && $hypot(Infinity, NaN) !== Infinity;

// `Math.hypot` method
// https://tc39.es/ecma262/#sec-math.hypot
$$1U({
  target: 'Math',
  stat: true,
  arity: 2,
  forced: FORCED$l
}, {
  // eslint-disable-next-line no-unused-vars -- required for `.length`
  hypot: function hypot(value1, value2) {
    // Scaled accumulation: `scaledSum` holds sum((arg / largest)^2), so no
    // unscaled argument is ever squared (avoids overflow and underflow).
    var count = arguments.length;
    var largest = 0;
    var scaledSum = 0;
    var current, ratio;
    for (var k = 0; k < count; k++) {
      current = abs$6(arguments[k]);
      if (current > largest) {
        // New maximum: rescale what has been accumulated so far.
        ratio = largest / current;
        scaledSum = scaledSum * ratio * ratio + 1;
        largest = current;
      } else if (current > 0) {
        ratio = current / largest;
        scaledSum += ratio * ratio;
      } else {
        // Zeros add nothing; NaN poisons the sum.
        scaledSum += current;
      }
    }
    // Any infinite argument makes `largest` Infinity, which wins even over NaN.
    return largest === Infinity ? Infinity : largest * sqrt$3(scaledSum);
  }
});
9000
var es_math_imul = {};

var $$1T = _export;
var fails$Q = fails$1m;

// eslint-disable-next-line es/no-math-imul -- required for testing
var $imul = Math.imul;
// some WebKit versions fail with big numbers, some have the wrong arity
var FORCED$k = fails$Q(function () {
  return $imul(0xFFFFFFFF, 5) != -5 || $imul.length != 2;
});

// `Math.imul` method
// https://tc39.es/ecma262/#sec-math.imul
$$1T({
  target: 'Math',
  stat: true,
  forced: FORCED$k
}, {
  imul: function imul(x, y) {
    // 32-bit multiply built from 16-bit halves so every partial product is
    // exact in a double: low*low + ((high*low + low*high) << 16), then int32.
    var UINT16 = 0xFFFF;
    var xn = +x;
    var yn = +y;
    var xLow = UINT16 & xn;
    var yLow = UINT16 & yn;
    var xHigh = UINT16 & xn >>> 16;
    var yHigh = UINT16 & yn >>> 16;
    var cross = (xHigh * yLow + xLow * yHigh) << 16 >>> 0;
    return 0 | xLow * yLow + cross;
  }
});
9029
var es_math_log10 = {};

var log$6 = Math.log;
var LOG10E = Math.LOG10E;

// `Math.log10` fallback: log10(x) = ln(x) * log10(e).
// eslint-disable-next-line es/no-math-log10 -- safe
var mathLog10 = Math.log10 || function log10(x) {
  var naturalLog = log$6(x);
  return naturalLog * LOG10E;
};
var mathLog10$1 = /*@__PURE__*/getDefaultExportFromCjs(mathLog10);
9040
var $$1S = _export;
var log10$1 = mathLog10;

// `Math.log10` method
// https://tc39.es/ecma262/#sec-math.log10
// Registers `mathLog10` (native `Math.log10` when present, otherwise the
// fallback) as `Math.log10`; no `forced` flag is passed.
$$1S({
  target: 'Math',
  stat: true
}, {
  log10: log10$1
});
9052
var es_math_log1p = {};

var $$1R = _export;
var log1p$3 = mathLog1p;

// `Math.log1p` method
// https://tc39.es/ecma262/#sec-math.log1p
// Registers `mathLog1p` (native `Math.log1p` when present, otherwise the
// fallback) as `Math.log1p`; no `forced` flag is passed.
$$1R({
  target: 'Math',
  stat: true
}, {
  log1p: log1p$3
});
9066
var es_math_log2 = {};

var $$1Q = _export;
var log$5 = Math.log;
var LN2 = Math.LN2;

// `Math.log2` method
// https://tc39.es/ecma262/#sec-math.log2
// log2(x) = ln(x) / ln(2); no `forced` flag is passed.
$$1Q({
  target: 'Math',
  stat: true
}, {
  log2: function log2(x) {
    var naturalLog = log$5(x);
    return naturalLog / LN2;
  }
});
9083
var es_math_sign = {};

var $$1P = _export;
var sign$4 = mathSign;

// `Math.sign` method
// https://tc39.es/ecma262/#sec-math.sign
// Registers `mathSign` (native `Math.sign` when present, otherwise the
// fallback) as `Math.sign`; no `forced` flag is passed.
$$1P({
  target: 'Math',
  stat: true
}, {
  sign: sign$4
});
9097
var es_math_sinh = {};

var $$1O = _export;
var fails$P = fails$1m;
var expm1$4 = mathExpm1;
var abs$5 = Math.abs;
var exp$4 = Math.exp;
var E = Math.E;
// V8 near Chromium 38 has a problem with very small numbers
var FORCED$j = fails$P(function () {
  // eslint-disable-next-line es/no-math-sinh -- required for testing
  return Math.sinh(-2e-17) != -2e-17;
});

// `Math.sinh` method
// https://tc39.es/ecma262/#sec-math.sinh
$$1O({
  target: 'Math',
  stat: true,
  forced: FORCED$j
}, {
  sinh: function sinh(x) {
    var n = +x;
    // Small |n|: expm1 keeps precision where e^n - e^-n would cancel.
    if (abs$5(n) < 1) return (expm1$4(n) - expm1$4(-n)) / 2;
    // Otherwise (including NaN): shift the exponent down by one and multiply
    // by e/2, so e^n itself never has to be representable.
    return (exp$4(n - 1) - exp$4(-n - 1)) * (E / 2);
  }
});
9124
var es_math_tanh = {};

var $$1N = _export;
var expm1$3 = mathExpm1;
var exp$3 = Math.exp;

// Register `Math.tanh` (https://tc39.es/ecma262/#sec-math.tanh).
$$1N({ target: 'Math', stat: true }, {
  tanh: function tanh(x) {
    var n = +x;
    var up = expm1$3(n);
    var down = expm1$3(-n);
    // Saturate to +/-1 instead of evaluating Infinity / Infinity.
    if (up == Infinity) return 1;
    if (down == Infinity) return -1;
    return (up - down) / (exp$3(n) + exp$3(-n));
  }
});
9144
var es_math_toStringTag = {};

var setToStringTag$5 = setToStringTag$d;

// Give the Math namespace its @@toStringTag of 'Math'
// (https://tc39.es/ecma262/#sec-math-@@tostringtag).
setToStringTag$5(Math, 'Math', true);
9152
var es_math_trunc = {};

var $$1M = _export;
var trunc = mathTrunc;

// Register `Math.trunc` (https://tc39.es/ecma262/#sec-math.trunc).
// The implementation, `mathTrunc`, is defined elsewhere in this bundle.
$$1M({ target: 'Math', stat: true }, { trunc: trunc });
9166
var es_number_constructor = {};

var uncurryThis$J = functionUncurryThis;

// `thisNumberValue` abstract operation
// https://tc39.es/ecma262/#sec-thisnumbervalue
// Number.prototype.valueOf passed through `functionUncurryThis`, so the
// receiver is supplied as the first argument.
var thisNumberValue$5 = uncurryThis$J(1.0.valueOf);
var thisNumberValue$6 = /*@__PURE__*/getDefaultExportFromCjs(thisNumberValue$5);

// The WhiteSpace and LineTerminator code points: TAB, LF, VT, FF, CR, SPACE,
// NBSP, the Zs space separators, LS, PS and the BOM.
var whitespaces$5 = "\t\n\x0B\f\r \xA0\u1680\u2000\u2001\u2002" + "\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200A\u202F\u205F\u3000\u2028\u2029\uFEFF";
var whitespaces$6 = /*@__PURE__*/getDefaultExportFromCjs(whitespaces$5);

var uncurryThis$I = functionUncurryThis;
var requireObjectCoercible$e = requireObjectCoercible$j;
var toString$o = toString$x;
var whitespaces$4 = whitespaces$5;
var replace$9 = uncurryThis$I(''.replace);
// Leading whitespace run.
var ltrim = RegExp('^[' + whitespaces$4 + ']+');
// Trailing whitespace run. The group captures the preceding non-whitespace
// character (or the start of input) so the replacement can put it back.
var rtrim = RegExp('(^|[^' + whitespaces$4 + '])[' + whitespaces$4 + ']+$');

// `String.prototype.{ trim, trimStart, trimEnd, trimLeft, trimRight }` methods implementation.
// TYPE is a bit mask: 1 strips the start, 2 strips the end, 3 strips both.
var createMethod$2 = function createMethod(TYPE) {
  return function ($this) {
    var str = toString$o(requireObjectCoercible$e($this));
    if (TYPE & 1) {
      str = replace$9(str, ltrim, '');
    }
    if (TYPE & 2) {
      str = replace$9(str, rtrim, '$1');
    }
    return str;
  };
};
var stringTrim = {
  // trimStart / trimLeft: https://tc39.es/ecma262/#sec-string.prototype.trimstart
  start: createMethod$2(1),
  // trimEnd / trimRight: https://tc39.es/ecma262/#sec-string.prototype.trimend
  end: createMethod$2(2),
  // trim: https://tc39.es/ecma262/#sec-string.prototype.trim
  trim: createMethod$2(3)
};
var stringTrim$1 = /*@__PURE__*/getDefaultExportFromCjs(stringTrim);
9209
'use strict';
var $$1L = _export;
var IS_PURE$e = isPure;
var DESCRIPTORS$q = descriptors;
var global$E = global$Z;
var path = path$2;
var uncurryThis$H = functionUncurryThis;
var isForced$2 = isForced_1;
var hasOwn$d = hasOwnProperty_1;
var inheritIfRequired$3 = inheritIfRequired$6;
var isPrototypeOf$3 = objectIsPrototypeOf;
var isSymbol$2 = isSymbol$7;
var toPrimitive$1 = toPrimitive$4;
var fails$O = fails$1m;
var getOwnPropertyNames$3 = objectGetOwnPropertyNames.f;
var getOwnPropertyDescriptor$7 = objectGetOwnPropertyDescriptor.f;
var defineProperty$5 = objectDefineProperty.f;
var thisNumberValue$4 = thisNumberValue$5;
var trim$2 = stringTrim.trim;
var NUMBER = 'Number';
var NativeNumber = global$E[NUMBER];
var PureNumberNamespace = path[NUMBER];
var NumberPrototype = NativeNumber.prototype;
var TypeError$5 = global$E.TypeError;
var stringSlice$e = uncurryThis$H(''.slice);
var charCodeAt$3 = uncurryThis$H(''.charCodeAt);

// `ToNumeric` abstract operation
// https://tc39.es/ecma262/#sec-tonumeric
// A BigInt primitive passes through unchanged; anything else goes through ToNumber.
var toNumeric = function toNumeric(value) {
  var prim = toPrimitive$1(value, 'number');
  if (typeof prim == 'bigint') return prim;
  return toNumber(prim);
};

// `ToNumber` abstract operation
// https://tc39.es/ecma262/#sec-tonumber
// On top of unary `+`, this parses '0b' / '0o' literals and rejects
// signed hex such as '+0x1'.
var toNumber = function toNumber(argument) {
  var prim = toPrimitive$1(argument, 'number');
  if (isSymbol$2(prim)) throw TypeError$5('Cannot convert a Symbol value to a number');
  // Only strings longer than two characters can carry a prefix handled below.
  if (typeof prim != 'string' || prim.length <= 2) return +prim;
  prim = trim$2(prim);
  var lead = charCodeAt$3(prim, 0);
  if (lead === 43 || lead === 45) {
    // Leading '+' / '-'. Number('+0x1') should be NaN, old V8 fix.
    var third = charCodeAt$3(prim, 2);
    return third === 88 || third === 120 ? NaN : +prim;
  }
  // Not starting with '0'.
  if (lead !== 48) return +prim;
  var base, topCode;
  var marker = charCodeAt$3(prim, 1);
  if (marker === 66 || marker === 98) {
    // 'B' / 'b': fast equal of /^0b[01]+$/i
    base = 2;
    topCode = 49;
  } else if (marker === 79 || marker === 111) {
    // 'O' / 'o': fast equal of /^0o[0-7]+$/i
    base = 8;
    topCode = 55;
  } else {
    return +prim;
  }
  var body = stringSlice$e(prim, 2);
  for (var pos = 0; pos < body.length; pos++) {
    var code = charCodeAt$3(body, pos);
    // parseInt stops at the first invalid character,
    // but ToNumber must return NaN when the string contains any invalid character.
    if (code < 48 || code > topCode) return NaN;
  }
  return parseInt(body, base);
};
var FORCED$i = isForced$2(NUMBER, !NativeNumber(' 0o1') || !NativeNumber('0b1') || NativeNumber('+0x1'));

// Distinguishes `new Number(...)` from a plain call. A `new` receiver inherits from
// Number.prototype but holds no number value, so unwrapping it fails. The prototype
// check also covers the 1..constructor(foo) case.
var calledWithNew = function calledWithNew(dummy) {
  if (!isPrototypeOf$3(NumberPrototype, dummy)) return false;
  return fails$O(function () {
    thisNumberValue$4(dummy);
  });
};

// `Number` constructor
// https://tc39.es/ecma262/#sec-number-constructor
var NumberWrapper = function Number(value) {
  var n = arguments.length < 1 ? 0 : NativeNumber(toNumeric(value));
  if (!calledWithNew(this)) return n;
  return inheritIfRequired$3(Object(n), this, NumberWrapper);
};
NumberWrapper.prototype = NumberPrototype;
if (FORCED$i && !IS_PURE$e) NumberPrototype.constructor = NumberWrapper;
$$1L({ global: true, constructor: true, wrap: true, forced: FORCED$i }, { Number: NumberWrapper });

// Use `internal/copy-constructor-properties` helper in `core-js@4`
// Copies `source`'s own properties onto `target`, skipping keys `target` already owns.
// Without descriptor support, it copies a fixed list of known Number statics.
var copyConstructorProperties = function copyConstructorProperties(target, source) {
  var keys = DESCRIPTORS$q ? getOwnPropertyNames$3(source) : (
    // ES3:
    'MAX_VALUE,MIN_VALUE,NaN,NEGATIVE_INFINITY,POSITIVE_INFINITY,' +
    // ES2015 (in case, if modules with ES2015 Number statics required before):
    'EPSILON,MAX_SAFE_INTEGER,MIN_SAFE_INTEGER,isFinite,isInteger,isNaN,isSafeInteger,parseFloat,parseInt,' +
    // ESNext
    'fromString,range').split(',');
  for (var j = 0; j < keys.length; j++) {
    var key = keys[j];
    if (!hasOwn$d(source, key) || hasOwn$d(target, key)) continue;
    defineProperty$5(target, key, getOwnPropertyDescriptor$7(source, key));
  }
};
if (IS_PURE$e && PureNumberNamespace) copyConstructorProperties(path[NUMBER], PureNumberNamespace);
if (FORCED$i || IS_PURE$e) copyConstructorProperties(path[NUMBER], NativeNumber);
9327
var es_number_epsilon = {};

var $$1K = _export;

// Register the `Number.EPSILON` constant, 2^-52, requested as non-configurable
// and non-writable (https://tc39.es/ecma262/#sec-number.epsilon).
$$1K({ target: 'Number', stat: true, nonConfigurable: true, nonWritable: true }, {
  EPSILON: Math.pow(2, -52)
});
9342
var es_number_isFinite = {};

var global$D = global$Z;
var globalIsFinite = global$D.isFinite;

// `Number.isFinite` method
// https://tc39.es/ecma262/#sec-number.isfinite
// Fallback: only number primitives qualify, so the coercing global isFinite
// is never applied to anything else.
// eslint-disable-next-line es/no-number-isfinite -- safe
var numberIsFinite$1 = Number.isFinite || function isFinite(it) {
  if (typeof it != 'number') return false;
  return globalIsFinite(it);
};
var numberIsFinite$2 = /*@__PURE__*/getDefaultExportFromCjs(numberIsFinite$1);

var $$1J = _export;
var numberIsFinite = numberIsFinite$1;

// Register `Number.isFinite`.
$$1J({ target: 'Number', stat: true }, { isFinite: numberIsFinite });
9367
var es_number_isInteger = {};

// Bundle alias kept unchanged; the fallback below does not need it.
var isObject$h = isObject$z;
var floor$9 = Math.floor;

// `IsIntegralNumber` abstract operation
// https://tc39.es/ecma262/#sec-isintegralnumber
// Returns true only for finite number primitives with no fractional part;
// every non-Number value (objects, strings, symbols, bigints, ...) yields false.
// The fallback tests `typeof` first, rather than merely excluding objects.
// This keeps symbols and bigints away from the coercing global `isFinite`,
// which would throw a TypeError for them.
// eslint-disable-next-line es/no-number-isinteger -- safe
var isIntegralNumber$3 = Number.isInteger || function isInteger(it) {
  return typeof it == 'number' && isFinite(it) && floor$9(it) === it;
};
var isIntegralNumber$4 = /*@__PURE__*/getDefaultExportFromCjs(isIntegralNumber$3);
9380
var $$1I = _export;
var isIntegralNumber$2 = isIntegralNumber$3;

// Register `Number.isInteger` (https://tc39.es/ecma262/#sec-number.isinteger).
$$1I({ target: 'Number', stat: true }, { isInteger: isIntegralNumber$2 });
9392
var es_number_isNan = {};

var $$1H = _export;

// Register `Number.isNaN` (https://tc39.es/ecma262/#sec-number.isnan).
$$1H({ target: 'Number', stat: true }, {
  isNaN: function isNaN(number) {
    // NaN is the only value not equal to itself. Comparing a value with
    // itself never coerces, so non-numbers are simply reported as false.
    // eslint-disable-next-line no-self-compare -- NaN check
    return !(number == number);
  }
});
9408
var es_number_isSafeInteger = {};

var $$1G = _export;
var isIntegralNumber$1 = isIntegralNumber$3;
var abs$4 = Math.abs;

// Register `Number.isSafeInteger` (https://tc39.es/ecma262/#sec-number.issafeinteger).
$$1G({ target: 'Number', stat: true }, {
  isSafeInteger: function isSafeInteger(number) {
    if (!isIntegralNumber$1(number)) return false;
    // 0x1FFFFFFFFFFFFF is 2^53 - 1, the largest safe integer magnitude.
    return abs$4(number) <= 0x1FFFFFFFFFFFFF;
  }
});
9425
var es_number_maxSafeInteger = {};

var $$1F = _export;

// Register `Number.MAX_SAFE_INTEGER` (2^53 - 1), requested as non-configurable
// and non-writable (https://tc39.es/ecma262/#sec-number.max_safe_integer).
$$1F({ target: 'Number', stat: true, nonConfigurable: true, nonWritable: true }, {
  MAX_SAFE_INTEGER: 0x1FFFFFFFFFFFFF
});
9440
var es_number_minSafeInteger = {};

var $$1E = _export;

// Register `Number.MIN_SAFE_INTEGER` (-(2^53 - 1)), requested as non-configurable
// and non-writable (https://tc39.es/ecma262/#sec-number.min_safe_integer).
$$1E({ target: 'Number', stat: true, nonConfigurable: true, nonWritable: true }, {
  MIN_SAFE_INTEGER: -0x1FFFFFFFFFFFFF
});
9455
var es_number_parseFloat = {};

var global$C = global$Z;
var fails$N = fails$1m;
var uncurryThis$G = functionUncurryThis;
var toString$n = toString$x;
var trim$1 = stringTrim.trim;
var whitespaces$3 = whitespaces$5;
var charAt$c = uncurryThis$G(''.charAt);
var $parseFloat$1 = global$C.parseFloat;
var _Symbol$1 = global$C.Symbol;
var ITERATOR$5 = _Symbol$1 && _Symbol$1.iterator;
// The wrapper is used when the native parseFloat either loses the sign of a
// whitespace-prefixed '-0' or does not throw for a boxed symbol.
var FORCED$h = 1 / $parseFloat$1(whitespaces$3 + '-0') !== -Infinity
  // MS Edge 18- broken with boxed symbols
  || ITERATOR$5 && !fails$N(function () {
    $parseFloat$1(Object(ITERATOR$5));
  });

// `parseFloat` method
// https://tc39.es/ecma262/#sec-parsefloat-string
var numberParseFloat = FORCED$h ? function parseFloat(string) {
  var text = trim$1(toString$n(string));
  var parsed = $parseFloat$1(text);
  // Restore the negative zero for inputs such as '-0'.
  if (parsed === 0 && charAt$c(text, 0) == '-') return -0;
  return parsed;
} : $parseFloat$1;
var numberParseFloat$1 = /*@__PURE__*/getDefaultExportFromCjs(numberParseFloat);

var $$1D = _export;
var parseFloat$1 = numberParseFloat;

// `Number.parseFloat` method
// https://tc39.es/ecma262/#sec-number.parseFloat
// eslint-disable-next-line es/no-number-parsefloat -- required for testing
$$1D({ target: 'Number', stat: true, forced: Number.parseFloat != parseFloat$1 }, {
  parseFloat: parseFloat$1
});
9496
var es_number_parseInt = {};

var global$B = global$Z;
var fails$M = fails$1m;
var uncurryThis$F = functionUncurryThis;
var toString$m = toString$x;
var trim = stringTrim.trim;
var whitespaces$2 = whitespaces$5;
var $parseInt$1 = global$B.parseInt;
var _Symbol = global$B.Symbol;
var ITERATOR$4 = _Symbol && _Symbol.iterator;
var hex = /^[+-]?0x/i;
var exec$7 = uncurryThis$F(hex.exec);
// The wrapper is used unless the native parseInt skips the full whitespace set,
// reads '08' as 8 and '0x16' as 22, and throws for a boxed symbol.
var FORCED$g = $parseInt$1(whitespaces$2 + '08') !== 8 || $parseInt$1(whitespaces$2 + '0x16') !== 22
  // MS Edge 18- broken with boxed symbols
  || ITERATOR$4 && !fails$M(function () {
    $parseInt$1(Object(ITERATOR$4));
  });

// `parseInt` method
// https://tc39.es/ecma262/#sec-parseint-string-radix
var numberParseInt = FORCED$g ? function parseInt(string, radix) {
  var S = trim(toString$m(string));
  // A radix that converts to 0 means "auto":
  // hexadecimal after an optionally signed 0x prefix, decimal otherwise.
  var base = radix >>> 0;
  if (!base) base = exec$7(hex, S) ? 16 : 10;
  return $parseInt$1(S, base);
} : $parseInt$1;
var numberParseInt$1 = /*@__PURE__*/getDefaultExportFromCjs(numberParseInt);

var $$1C = _export;
var parseInt$2 = numberParseInt;

// `Number.parseInt` method
// https://tc39.es/ecma262/#sec-number.parseint
// eslint-disable-next-line es/no-number-parseint -- required for testing
$$1C({ target: 'Number', stat: true, forced: Number.parseInt != parseInt$2 }, {
  parseInt: parseInt$2
});
9537
var es_number_toExponential = {};

'use strict';
var $$1B = _export;
var uncurryThis$E = functionUncurryThis;
var toIntegerOrInfinity$8 = toIntegerOrInfinity$l;
var thisNumberValue$3 = thisNumberValue$5;
var $repeat$1 = stringRepeat;
var log10 = mathLog10;
var fails$L = fails$1m;
var $RangeError$5 = RangeError;
var $String$1 = String;
var $isFinite = isFinite;
var abs$3 = Math.abs;
var floor$8 = Math.floor;
var pow$6 = Math.pow;
var round$5 = Math.round;
var nativeToExponential = uncurryThis$E(1.0.toExponential);
var repeat$3 = uncurryThis$E($repeat$1);
var stringSlice$d = uncurryThis$E(''.slice);

// The native method is delegated to (for in-range inputs) only when every
// rounding probe below matches.
var ROUNDS_PROPERLY = nativeToExponential(-6.9e-11, 4) === '-6.9000e-11' // Edge 17-
  && nativeToExponential(1.255, 2) === '1.25e+0' // IE11- && Edge 14-
  && nativeToExponential(12345, 3) === '1.235e+4' // FF86-, V8 ~ Chrome 49-50
  && nativeToExponential(25, 0) === '3e+1'; // FF86-, V8 ~ Chrome 49-50

// IE8-: fraction digits of +Infinity and -Infinity must both throw.
var throwsOnInfinityFraction = function throwsOnInfinityFraction() {
  if (!fails$L(function () {
    nativeToExponential(1, Infinity);
  })) return false;
  return fails$L(function () {
    nativeToExponential(1, -Infinity);
  });
};

// Safari <11 && FF <50: a non-finite receiver must not throw even when the
// fraction digits are out of range.
var properNonFiniteThisCheck = function properNonFiniteThisCheck() {
  var threw = fails$L(function () {
    nativeToExponential(Infinity, Infinity);
    nativeToExponential(NaN, Infinity);
  });
  return !threw;
};
var FORCED$f = !ROUNDS_PROPERLY || !throwsOnInfinityFraction() || !properNonFiniteThisCheck();

// `Number.prototype.toExponential` method
// https://tc39.es/ecma262/#sec-number.prototype.toexponential
$$1B({ target: 'Number', proto: true, forced: FORCED$f }, {
  toExponential: function toExponential(fractionDigits) {
    var x = thisNumberValue$3(this);
    if (fractionDigits === undefined) return nativeToExponential(x);
    var f = toIntegerOrInfinity$8(fractionDigits);
    if (!$isFinite(x)) return String(x);
    // TODO: ES2018 increased the maximum number of fraction digits to 100, need to improve the implementation
    if (f < 0 || f > 20) throw $RangeError$5('Incorrect fraction digits');
    if (ROUNDS_PROPERLY) return nativeToExponential(x, f);

    var sign = '';
    if (x < 0) {
      sign = '-';
      x = -x;
    }

    var exponent = 0;
    var digits;
    if (x === 0) {
      digits = repeat$3('0', f + 1);
    } else {
      // this block is based on https://gist.github.com/SheetJSDev/1100ad56b9f856c95299ed0e068eea08
      // TODO: improve accuracy with big fraction digits
      exponent = floor$8(log10(x));
      var unit = pow$6(10, exponent - f);
      var n = round$5(x / unit);
      if (2 * x >= (2 * n + 1) * unit) n += 1;
      // Rounding produced f + 2 digits (e.g. 9.99 -> 10.0): shift one place.
      if (n >= pow$6(10, f + 1)) {
        n /= 10;
        exponent += 1;
      }
      digits = $String$1(n);
    }

    if (f !== 0) digits = stringSlice$d(digits, 0, 1) + '.' + stringSlice$d(digits, 1);
    var expSign = exponent < 0 ? '-' : '+';
    var expDigits = exponent === 0 ? '0' : $String$1(abs$3(exponent));
    return sign + digits + 'e' + expSign + expDigits;
  }
});
9644
var es_number_toFixed = {};

'use strict';
var $$1A = _export;
var uncurryThis$D = functionUncurryThis;
var toIntegerOrInfinity$7 = toIntegerOrInfinity$l;
var thisNumberValue$2 = thisNumberValue$5;
var $repeat = stringRepeat;
var fails$K = fails$1m;
var $RangeError$4 = RangeError;
var $String = String;
var floor$7 = Math.floor;
var repeat$2 = uncurryThis$D($repeat);
var stringSlice$c = uncurryThis$D(''.slice);
var nativeToFixed = uncurryThis$D(1.0.toFixed);

// acc * x^n for a non-negative integer n, by recursive binary exponentiation.
var pow$5 = function pow(x, n, acc) {
  if (n === 0) return acc;
  if (n % 2 === 1) return pow(x, n - 1, acc * x);
  return pow(x * x, n / 2, acc);
};

// Number of halvings that keep x >= 1, i.e. floor(log2(x)) for x >= 1 (0 when x < 2).
// It strips 12 bits at a time before single bits.
var log$4 = function log(x) {
  var bits = 0;
  var rest = x;
  for (; rest >= 4096; rest /= 4096) bits += 12;
  for (; rest >= 2; rest /= 2) bits += 1;
  return bits;
};

// `data` is a 6-limb big number in base 1e7, least significant limb first.
// This computes data = data * n + c in place.
var multiply$4 = function multiply(data, n, c) {
  var carry = c;
  for (var i = 0; i < 6; i++) {
    carry += n * data[i];
    data[i] = carry % 1e7;
    carry = floor$7(carry / 1e7);
  }
};

// In place: data = floor(data / n), working from the most significant limb down.
var divide = function divide(data, n) {
  var rem = 0;
  for (var i = 5; i >= 0; i--) {
    rem += data[i];
    data[i] = floor$7(rem / n);
    rem = rem % n * 1e7;
  }
};

// Decimal rendering of `data`: leading zero limbs are dropped, and every limb
// after the first printed one is zero-padded to 7 digits.
var dataToString = function dataToString(data) {
  var out = '';
  for (var i = 5; i >= 0; i--) {
    if (out === '' && i !== 0 && data[i] === 0) continue;
    var digits = $String(data[i]);
    out = out === '' ? digits : out + repeat$2('0', 7 - digits.length) + digits;
  }
  return out;
};

var FORCED$e = fails$K(function () {
  return nativeToFixed(0.00008, 3) !== '0.000' || nativeToFixed(0.9, 0) !== '1' || nativeToFixed(1.255, 2) !== '1.25' || nativeToFixed(1000000000000000128.0, 0) !== '1000000000000000128';
}) || !fails$K(function () {
  // V8 ~ Android 4.3-
  nativeToFixed({});
});

// `Number.prototype.toFixed` method
// https://tc39.es/ecma262/#sec-number.prototype.tofixed
$$1A({ target: 'Number', proto: true, forced: FORCED$e }, {
  toFixed: function toFixed(fractionDigits) {
    var number = thisNumberValue$2(this);
    var fractDigits = toIntegerOrInfinity$7(fractionDigits);
    var data = [0, 0, 0, 0, 0, 0];
    var sign = '';
    var result = '0';
    var e, z, j, k;

    // TODO: ES2018 increased the maximum number of fraction digits to 100, need to improve the implementation
    if (fractDigits < 0 || fractDigits > 20) throw $RangeError$4('Incorrect fraction digits');
    // eslint-disable-next-line no-self-compare -- NaN check
    if (number != number) return 'NaN';
    if (number <= -1e21 || number >= 1e21) return $String(number);
    if (number < 0) {
      sign = '-';
      number = -number;
    }

    if (number > 1e-21) {
      // Write number as z * 2^(-e), with z an integer of at most 53 bits.
      e = log$4(number * pow$5(2, 69, 1)) - 69;
      z = e < 0 ? number * pow$5(2, -e, 1) : number / pow$5(2, e, 1);
      z *= 0x10000000000000;
      e = 52 - e;
      if (e > 0) {
        multiply$4(data, 0, z);
        // Scale by 10^fractDigits, seven decimal digits per step.
        for (j = fractDigits; j >= 7; j -= 7) multiply$4(data, 1e7, 0);
        multiply$4(data, pow$5(10, j, 1), 0);
        // Dividing by 2^(e-1), 23 bits per step, leaves floor(2 * scaled value).
        for (j = e - 1; j >= 23; j -= 23) divide(data, 1 << 23);
        divide(data, 1 << j);
        // (v + 1) / 2 rounds the scaled value half up.
        multiply$4(data, 1, 1);
        divide(data, 2);
        result = dataToString(data);
      } else {
        // Already an integer: scale z back up by 2^(-e) and pad the fraction.
        multiply$4(data, 0, z);
        multiply$4(data, 1 << -e, 0);
        result = dataToString(data) + repeat$2('0', fractDigits);
      }
    }

    if (fractDigits > 0) {
      k = result.length;
      return sign + (k <= fractDigits
        ? '0.' + repeat$2('0', fractDigits - k) + result
        : stringSlice$c(result, 0, k - fractDigits) + '.' + stringSlice$c(result, k - fractDigits));
    }
    return sign + result;
  }
});
9773
var es_number_toPrecision = {};

'use strict';
var $$1z = _export;
var uncurryThis$C = functionUncurryThis;
var fails$J = fails$1m;
var thisNumberValue$1 = thisNumberValue$5;
var nativeToPrecision = uncurryThis$C(1.0.toPrecision);
// The wrapper is used when an explicit `undefined` precision is mishandled (IE7-),
// or when a non-number receiver is accepted without throwing (V8 ~ Android 4.3-).
var FORCED$d = fails$J(function () {
  // IE7-
  return nativeToPrecision(1, undefined) !== '1';
}) || !fails$J(function () {
  // V8 ~ Android 4.3-
  nativeToPrecision({});
});

// `Number.prototype.toPrecision` method
// https://tc39.es/ecma262/#sec-number.prototype.toprecision
$$1z({ target: 'Number', proto: true, forced: FORCED$d }, {
  toPrecision: function toPrecision(precision) {
    var x = thisNumberValue$1(this);
    // Omit the argument entirely rather than forwarding `undefined`.
    if (precision === undefined) return nativeToPrecision(x);
    return nativeToPrecision(x, precision);
  }
});
9801
var es_object_assign = {};

'use strict';
var DESCRIPTORS$p = descriptors;
var uncurryThis$B = functionUncurryThis;
var call$q = functionCall;
var fails$I = fails$1m;
var objectKeys$2 = objectKeys$5;
var getOwnPropertySymbolsModule = objectGetOwnPropertySymbols;
var propertyIsEnumerableModule = objectPropertyIsEnumerable;
var toObject$a = toObject$t;
var IndexedObject = indexedObject;

// eslint-disable-next-line es/no-object-assign -- safe
var $assign = Object.assign;
// eslint-disable-next-line es/no-object-defineproperty -- required for testing
var defineProperty$4 = Object.defineProperty;
var concat$4 = uncurryThis$B([].concat);

// `Object.assign` method
// https://tc39.es/ecma262/#sec-object.assign
// The native implementation is kept only if it exists and passes both probes below.
var objectAssign = !$assign || fails$I(function () {
  // Probe 1 (Edge bug): reading `a` runs a getter that makes the source's `b`
  // non-enumerable, so `b` must then be skipped and the target keeps b === 1.
  if (DESCRIPTORS$p && $assign({ b: 1 }, $assign(defineProperty$4({}, 'a', {
    enumerable: true,
    get: function get() {
      defineProperty$4(this, 'b', { value: 3, enumerable: false });
    }
  }), { b: 2 })).b !== 1) return true;
  // Probe 2 (V8 bug): symbols must be copied, and string keys must keep
  // a deterministic order.
  var A = {};
  var B = {};
  // eslint-disable-next-line es/no-symbol -- safe
  var symbol = Symbol();
  var alphabet = 'abcdefghijklmnopqrst';
  A[symbol] = 7;
  alphabet.split('').forEach(function (chr) {
    B[chr] = chr;
  });
  return $assign({}, A)[symbol] != 7 || objectKeys$2($assign({}, B)).join('') != alphabet;
}) ? function assign(target, source) { // eslint-disable-line no-unused-vars -- required for `.length`
  var T = toObject$a(target);
  var argumentsLength = arguments.length;
  var getOwnPropertySymbols = getOwnPropertySymbolsModule.f;
  var propertyIsEnumerable = propertyIsEnumerableModule.f;
  for (var index = 1; index < argumentsLength; index++) {
    var S = IndexedObject(arguments[index]);
    var keys = getOwnPropertySymbols ? concat$4(objectKeys$2(S), getOwnPropertySymbols(S)) : objectKeys$2(S);
    for (var j = 0; j < keys.length; j++) {
      var key = keys[j];
      // Enumerability is re-checked at copy time. Without descriptor support
      // (DESCRIPTORS falsy), every listed key is copied.
      if (!DESCRIPTORS$p || call$q(propertyIsEnumerable, S, key)) T[key] = S[key];
    }
  }
  return T;
} : $assign;
var objectAssign$1 = /*@__PURE__*/getDefaultExportFromCjs(objectAssign);

var $$1y = _export;
var assign$1 = objectAssign;

// `Object.assign` method
// https://tc39.es/ecma262/#sec-object.assign
// eslint-disable-next-line es/no-object-assign -- required for testing
$$1y({ target: 'Object', stat: true, arity: 2, forced: Object.assign !== assign$1 }, {
  assign: assign$1
});
9885
var es_object_create = {};

// TODO: Remove from `core-js@4`
var $$1x = _export;
var DESCRIPTORS$o = descriptors;
var create$5 = objectCreate;

// Register `Object.create` (https://tc39.es/ecma262/#sec-object.create).
// It is flagged as a sham when descriptor support is missing.
$$1x({ target: 'Object', stat: true, sham: !DESCRIPTORS$o }, { create: create$5 });
9902
var es_object_defineGetter = {};

'use strict';
var IS_PURE$d = isPure;
var global$A = global$Z;
var fails$H = fails$1m;
var WEBKIT$1 = engineWebkitVersion;

// Whether the legacy __defineGetter__ / __defineSetter__ replacements are forced.
// They are always forced in the pure build. Otherwise they are forced when calling
// the native __defineSetter__ with a null receiver does not throw.
var objectPrototypeAccessorsForced = IS_PURE$d || !fails$H(function () {
  // Running this probe crashes old WebKit, so skip it there:
  // https://github.com/zloirock/core-js/issues/232
  if (WEBKIT$1 && WEBKIT$1 < 535) return;
  var key = Math.random();
  // In FF throws only define methods
  // eslint-disable-next-line no-undef, no-useless-call, es/no-legacy-object-prototype-accessor-methods -- required for testing
  __defineSetter__.call(null, key, function () {/* empty */});
  // Clean up the random global the probe may have created.
  delete global$A[key];
});
var objectPrototypeAccessorsForced$1 = /*@__PURE__*/getDefaultExportFromCjs(objectPrototypeAccessorsForced);
9923
'use strict';
var $$1w = _export;
var DESCRIPTORS$n = descriptors;
var FORCED$c = objectPrototypeAccessorsForced;
var aCallable$b = aCallable$l;
var toObject$9 = toObject$t;
var definePropertyModule$4 = objectDefineProperty;

// `Object.prototype.__defineGetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__defineGetter__
// Installed only when property descriptors are supported.
if (DESCRIPTORS$n) {
  $$1w({ target: 'Object', proto: true, forced: FORCED$c }, {
    __defineGetter__: function __defineGetter__(P, getter) {
      var O = toObject$9(this);
      var descriptor = { get: aCallable$b(getter), enumerable: true, configurable: true };
      definePropertyModule$4.f(O, P, descriptor);
    }
  });
}
9949
var es_object_defineProperties = {};

var $$1v = _export;
var DESCRIPTORS$m = descriptors;
var defineProperties = objectDefineProperties.f;

// Register `Object.defineProperties` (https://tc39.es/ecma262/#sec-object.defineproperties).
// It is forced whenever the native function differs from the bundled one.
// eslint-disable-next-line es/no-object-defineproperties -- safe
$$1v({ target: 'Object', stat: true, forced: Object.defineProperties !== defineProperties, sham: !DESCRIPTORS$m }, {
  defineProperties: defineProperties
});
9967
var es_object_defineProperty = {};

var $$1u = _export;
var DESCRIPTORS$l = descriptors;
var defineProperty$3 = objectDefineProperty.f;

// Register `Object.defineProperty` (https://tc39.es/ecma262/#sec-object.defineproperty).
// It is forced whenever the native function differs from the bundled one.
// eslint-disable-next-line es/no-object-defineproperty -- safe
$$1u({ target: 'Object', stat: true, forced: Object.defineProperty !== defineProperty$3, sham: !DESCRIPTORS$l }, {
  defineProperty: defineProperty$3
});
9985
var es_object_defineSetter = {};

'use strict';
var $$1t = _export;
var DESCRIPTORS$k = descriptors;
var FORCED$b = objectPrototypeAccessorsForced;
var aCallable$a = aCallable$l;
var toObject$8 = toObject$t;
var definePropertyModule$3 = objectDefineProperty;

// `Object.prototype.__defineSetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__defineSetter__
// Installed only when property descriptors are supported.
if (DESCRIPTORS$k) {
  $$1t({ target: 'Object', proto: true, forced: FORCED$b }, {
    __defineSetter__: function __defineSetter__(P, setter) {
      var O = toObject$8(this);
      var descriptor = { set: aCallable$a(setter), enumerable: true, configurable: true };
      definePropertyModule$3.f(O, P, descriptor);
    }
  });
}
10013
var es_object_entries = {};

var DESCRIPTORS$j = descriptors;
var uncurryThis$A = functionUncurryThis;
var objectKeys$1 = objectKeys$5;
var toIndexedObject$4 = toIndexedObject$j;
var $propertyIsEnumerable = objectPropertyIsEnumerable.f;
var propertyIsEnumerable = uncurryThis$A($propertyIsEnumerable);
var push$8 = uncurryThis$A([].push);

// `Object.{ entries, values }` methods implementation.
// TO_ENTRIES selects [key, value] pairs (entries) or bare values (values).
var createMethod$1 = function createMethod(TO_ENTRIES) {
  return function (it) {
    var O = toIndexedObject$4(it);
    var keys = objectKeys$1(O);
    var result = [];
    for (var i = 0; i < keys.length; i++) {
      var key = keys[i];
      // When descriptors are supported, enumerability is re-checked for each key
      // as it is visited.
      if (DESCRIPTORS$j && !propertyIsEnumerable(O, key)) continue;
      push$8(result, TO_ENTRIES ? [key, O[key]] : O[key]);
    }
    return result;
  };
};
var objectToArray = {
  // `Object.entries`: https://tc39.es/ecma262/#sec-object.entries
  entries: createMethod$1(true),
  // `Object.values`: https://tc39.es/ecma262/#sec-object.values
  values: createMethod$1(false)
};
var objectToArray$1 = /*@__PURE__*/getDefaultExportFromCjs(objectToArray);
10051
var $$1s = _export;
var $entries = objectToArray.entries;

// Register `Object.entries` (https://tc39.es/ecma262/#sec-object.entries).
$$1s({ target: 'Object', stat: true }, {
  entries: function entries(O) {
    var pairs = $entries(O);
    return pairs;
  }
});
10065
var es_object_freeze = {};

var $$1r = _export;
var FREEZING$4 = freezing;
var fails$G = fails$1m;
var isObject$g = isObject$z;
var onFreeze$2 = internalMetadataExports.onFreeze;

// eslint-disable-next-line es/no-object-freeze -- safe
var $freeze = Object.freeze;
// True when the native Object.freeze throws for a primitive argument.
var FAILS_ON_PRIMITIVES$5 = fails$G(function () {
  $freeze(1);
});

// Register `Object.freeze` (https://tc39.es/ecma262/#sec-object.freeze).
$$1r({ target: 'Object', stat: true, forced: FAILS_ON_PRIMITIVES$5, sham: !FREEZING$4 }, {
  freeze: function freeze(it) {
    // Primitives, and everything when no native freeze exists, pass through unchanged.
    if (!$freeze || !isObject$g(it)) return it;
    return $freeze(onFreeze$2(it));
  }
});
10092
var es_object_fromEntries = {};

var $$1q = _export;
var iterate$6 = iterate$a;
var createProperty$2 = createProperty$9;

// Register `Object.fromEntries` (https://github.com/tc39/proposal-object-from-entries).
$$1q({ target: 'Object', stat: true }, {
  fromEntries: function fromEntries(iterable) {
    var obj = {};
    var addEntry = function (k, v) {
      createProperty$2(obj, k, v);
    };
    iterate$6(iterable, addEntry, { AS_ENTRIES: true });
    return obj;
  }
});
10115
var es_object_getOwnPropertyDescriptor = {};

var $$1p = _export;
var fails$F = fails$1m;
var toIndexedObject$3 = toIndexedObject$j;
var nativeGetOwnPropertyDescriptor$1 = objectGetOwnPropertyDescriptor.f;
var DESCRIPTORS$i = descriptors;
// Replace the built-in when descriptors are unsupported, or when `fails$1m`
// reports that calling it on a primitive (`1`) fails.
var FORCED$a = !DESCRIPTORS$i || fails$F(function () {
  nativeGetOwnPropertyDescriptor$1(1);
});

// `Object.getOwnPropertyDescriptor` method
// https://tc39.es/ecma262/#sec-object.getownpropertydescriptor
$$1p({
  target: 'Object',
  stat: true,
  forced: FORCED$a,
  sham: !DESCRIPTORS$i
}, {
  getOwnPropertyDescriptor: function getOwnPropertyDescriptor(it, key) {
    // Normalise the target with toIndexedObject$3 before looking the key up.
    var target = toIndexedObject$3(it);
    return nativeGetOwnPropertyDescriptor$1(target, key);
  }
});
10139
var es_object_getOwnPropertyDescriptors = {};

var $$1o = _export;
var DESCRIPTORS$h = descriptors;
var ownKeys$1 = ownKeys$3;
var toIndexedObject$2 = toIndexedObject$j;
var getOwnPropertyDescriptorModule$4 = objectGetOwnPropertyDescriptor;
var createProperty$1 = createProperty$9;

// `Object.getOwnPropertyDescriptors` method
// https://tc39.es/ecma262/#sec-object.getownpropertydescriptors
$$1o({
  target: 'Object',
  stat: true,
  sham: !DESCRIPTORS$h
}, {
  getOwnPropertyDescriptors: function getOwnPropertyDescriptors(object) {
    var O = toIndexedObject$2(object);
    var describe = getOwnPropertyDescriptorModule$4.f;
    var keyList = ownKeys$1(O);
    var result = {};
    for (var i = 0; i < keyList.length; i++) {
      var key = keyList[i];
      var descriptor = describe(O, key);
      // Keys whose descriptor lookup yields `undefined` are left out of the result.
      if (descriptor !== undefined) createProperty$1(result, key, descriptor);
    }
    return result;
  }
});
10170
var es_object_getOwnPropertyNames = {};

var $$1n = _export;
var fails$E = fails$1m;
var getOwnPropertyNames$2 = objectGetOwnPropertyNamesExternal.f;

// Probe the native method with a primitive argument; `fails$1m` decides whether
// the replacement has to be forced.
var FAILS_ON_PRIMITIVES$4 = fails$E(function () {
  // eslint-disable-next-line es/no-object-getownpropertynames -- required for testing
  var names = Object.getOwnPropertyNames(1);
  return !names;
});

// `Object.getOwnPropertyNames` method
// https://tc39.es/ecma262/#sec-object.getownpropertynames
$$1n({ target: 'Object', stat: true, forced: FAILS_ON_PRIMITIVES$4 }, {
  getOwnPropertyNames: getOwnPropertyNames$2
});
10191
var es_object_getPrototypeOf = {};

var $$1m = _export;
var fails$D = fails$1m;
var toObject$7 = toObject$t;
var nativeGetPrototypeOf = objectGetPrototypeOf$1;
var CORRECT_PROTOTYPE_GETTER$1 = correctPrototypeGetter;
// Force the replacement when `fails$1m` reports that a primitive (`1`) is rejected.
var FAILS_ON_PRIMITIVES$3 = fails$D(function () {
  nativeGetPrototypeOf(1);
});

// `Object.getPrototypeOf` method
// https://tc39.es/ecma262/#sec-object.getprototypeof
$$1m({
  target: 'Object',
  stat: true,
  forced: FAILS_ON_PRIMITIVES$3,
  sham: !CORRECT_PROTOTYPE_GETTER$1
}, {
  getPrototypeOf: function getPrototypeOf(it) {
    // Coerce the argument with toObject$7 before reading its prototype.
    var object = toObject$7(it);
    return nativeGetPrototypeOf(object);
  }
});
10215
var es_object_hasOwn = {};

var $$1l = _export;
var hasOwn$c = hasOwnProperty_1;

// `Object.hasOwn` method: exposes the shared own-property helper as-is.
// https://github.com/tc39/proposal-accessible-object-hasownproperty
$$1l({ target: 'Object', stat: true }, { hasOwn: hasOwn$c });
10229
var es_object_is = {};

// `SameValue` abstract operation: native `Object.is` when present, otherwise an
// ES5-compatible fallback.
// https://tc39.es/ecma262/#sec-samevalue
// eslint-disable-next-line es/no-object-is -- safe
var sameValue$1 = Object.is || function is(x, y) {
  if (x === y) {
    // +0 and -0 are strictly equal but not SameValue; 1/x separates them (Infinity vs -Infinity).
    return x !== 0 || 1 / x === 1 / y;
  }
  // NaN is the only value unequal to itself, and SameValue(NaN, NaN) is true.
  // eslint-disable-next-line no-self-compare -- NaN check
  return x !== x && y !== y;
};
var sameValue$2 = /*@__PURE__*/getDefaultExportFromCjs(sameValue$1);

var $$1k = _export;
var is = sameValue$1;

// `Object.is` method, backed by the `sameValue$1` SameValue helper.
// https://tc39.es/ecma262/#sec-object.is
$$1k({ target: 'Object', stat: true }, { is: is });
10252
var es_object_isExtensible = {};

var $$1j = _export;
var $isExtensible$1 = objectIsExtensible;

// `Object.isExtensible` method
// https://tc39.es/ecma262/#sec-object.isextensible
// The replacement is forced whenever the global is not already the shared helper.
$$1j({
  target: 'Object',
  stat: true,
  // eslint-disable-next-line es/no-object-isextensible -- safe
  forced: Object.isExtensible !== $isExtensible$1
}, {
  isExtensible: $isExtensible$1
});
10268
var es_object_isFrozen = {};

var $$1i = _export;
var fails$C = fails$1m;
var isObject$f = isObject$z;
var classof$c = classofRaw$2;
var ARRAY_BUFFER_NON_EXTENSIBLE$1 = arrayBufferNonExtensible;

// eslint-disable-next-line es/no-object-isfrozen -- safe
var $isFrozen = Object.isFrozen;
// Force the replacement on engines flagged by `arrayBufferNonExtensible`, or
// when `fails$1m` reports that the native method rejects a primitive.
var FORCED$9 = ARRAY_BUFFER_NON_EXTENSIBLE$1 || fails$C(function () {
  $isFrozen(1);
});

// `Object.isFrozen` method
// https://tc39.es/ecma262/#sec-object.isfrozen
$$1i({
  target: 'Object',
  stat: true,
  forced: FORCED$9
}, {
  isFrozen: function isFrozen(it) {
    // Primitives are always reported as frozen.
    if (!isObject$f(it)) return true;
    // On flagged engines, ArrayBuffers are reported as frozen unconditionally.
    var flaggedBuffer = ARRAY_BUFFER_NON_EXTENSIBLE$1 && classof$c(it) == 'ArrayBuffer';
    if (flaggedBuffer) return true;
    // Without a native implementation nothing else is reported as frozen.
    if (!$isFrozen) return false;
    return $isFrozen(it);
  }
});
10296
var es_object_isSealed = {};

var $$1h = _export;
var fails$B = fails$1m;
var isObject$e = isObject$z;
var classof$b = classofRaw$2;
var ARRAY_BUFFER_NON_EXTENSIBLE = arrayBufferNonExtensible;

// eslint-disable-next-line es/no-object-issealed -- safe
var $isSealed = Object.isSealed;
// Force the replacement on engines flagged by `arrayBufferNonExtensible`, or
// when `fails$1m` reports that the native method rejects a primitive.
var FORCED$8 = ARRAY_BUFFER_NON_EXTENSIBLE || fails$B(function () {
  $isSealed(1);
});

// `Object.isSealed` method
// https://tc39.es/ecma262/#sec-object.issealed
$$1h({
  target: 'Object',
  stat: true,
  forced: FORCED$8
}, {
  isSealed: function isSealed(it) {
    // Primitives are always reported as sealed.
    if (!isObject$e(it)) return true;
    // On flagged engines, ArrayBuffers are reported as sealed unconditionally.
    var flaggedBuffer = ARRAY_BUFFER_NON_EXTENSIBLE && classof$b(it) == 'ArrayBuffer';
    if (flaggedBuffer) return true;
    // Without a native implementation nothing else is reported as sealed.
    if (!$isSealed) return false;
    return $isSealed(it);
  }
});
10324
var es_object_keys = {};

var $$1g = _export;
var toObject$6 = toObject$t;
var nativeKeys = objectKeys$5;
var fails$A = fails$1m;
// Force the replacement when `fails$1m` reports that a primitive (`1`) is rejected.
var FAILS_ON_PRIMITIVES$2 = fails$A(function () {
  nativeKeys(1);
});

// `Object.keys` method
// https://tc39.es/ecma262/#sec-object.keys
$$1g({
  target: 'Object',
  stat: true,
  forced: FAILS_ON_PRIMITIVES$2
}, {
  keys: function keys(it) {
    var object = toObject$6(it);
    return nativeKeys(object);
  }
});
10346
var es_object_lookupGetter = {};

var $$1f = _export;
var DESCRIPTORS$g = descriptors;
var FORCED$7 = objectPrototypeAccessorsForced;
var toObject$5 = toObject$t;
var toPropertyKey$3 = toPropertyKey$8;
var getPrototypeOf$4 = objectGetPrototypeOf$1;
var getOwnPropertyDescriptor$6 = objectGetOwnPropertyDescriptor.f;

// `Object.prototype.__lookupGetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__lookupGetter__
// Installed only where property descriptors are supported.
if (DESCRIPTORS$g) {
  $$1f({
    target: 'Object',
    proto: true,
    forced: FORCED$7
  }, {
    __lookupGetter__: function __lookupGetter__(P) {
      var O = toObject$5(this);
      var key = toPropertyKey$3(P);
      // Walk up the prototype chain; the first own descriptor found decides the result.
      for (var current = O; current; current = getPrototypeOf$4(current)) {
        var desc = getOwnPropertyDescriptor$6(current, key);
        if (desc) return desc.get;
      }
    }
  });
}
10376
var es_object_lookupSetter = {};

var $$1e = _export;
var DESCRIPTORS$f = descriptors;
var FORCED$6 = objectPrototypeAccessorsForced;
var toObject$4 = toObject$t;
var toPropertyKey$2 = toPropertyKey$8;
var getPrototypeOf$3 = objectGetPrototypeOf$1;
var getOwnPropertyDescriptor$5 = objectGetOwnPropertyDescriptor.f;

// `Object.prototype.__lookupSetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__lookupSetter__
// Installed only where property descriptors are supported.
if (DESCRIPTORS$f) {
  $$1e({
    target: 'Object',
    proto: true,
    forced: FORCED$6
  }, {
    __lookupSetter__: function __lookupSetter__(P) {
      var O = toObject$4(this);
      var key = toPropertyKey$2(P);
      // Walk up the prototype chain; the first own descriptor found decides the result.
      for (var current = O; current; current = getPrototypeOf$3(current)) {
        var desc = getOwnPropertyDescriptor$5(current, key);
        if (desc) return desc.set;
      }
    }
  });
}
10406
var es_object_preventExtensions = {};

var $$1d = _export;
var isObject$d = isObject$z;
var onFreeze$1 = internalMetadataExports.onFreeze;
var FREEZING$3 = freezing;
var fails$z = fails$1m;

// eslint-disable-next-line es/no-object-preventextensions -- safe
var $preventExtensions = Object.preventExtensions;
// Force the replacement when `fails$1m` reports that a primitive (`1`) is rejected.
var FAILS_ON_PRIMITIVES$1 = fails$z(function () {
  $preventExtensions(1);
});

// `Object.preventExtensions` method
// https://tc39.es/ecma262/#sec-object.preventextensions
$$1d({
  target: 'Object',
  stat: true,
  forced: FAILS_ON_PRIMITIVES$1,
  sham: !FREEZING$3
}, {
  preventExtensions: function preventExtensions(it) {
    // Without a native implementation, or for non-objects, return the argument untouched.
    if (!$preventExtensions || !isObject$d(it)) return it;
    return $preventExtensions(onFreeze$1(it));
  }
});
10433
var es_object_proto = {};

var DESCRIPTORS$e = descriptors;
var defineBuiltInAccessor$9 = defineBuiltInAccessor$h;
var isObject$c = isObject$z;
var toObject$3 = toObject$t;
var requireObjectCoercible$d = requireObjectCoercible$j;

// eslint-disable-next-line es/no-object-getprototypeof -- safe
var getPrototypeOf$2 = Object.getPrototypeOf;
// eslint-disable-next-line es/no-object-setprototypeof -- safe
var setPrototypeOf$3 = Object.setPrototypeOf;
var ObjectPrototype$1 = Object.prototype;
var PROTO = '__proto__';

// `Object.prototype.__proto__` accessor
// https://tc39.es/ecma262/#sec-object.prototype.__proto__
// Installed only when descriptors and native get/setPrototypeOf exist but the
// accessor is missing. Installation is best-effort: any failure is ignored.
if (DESCRIPTORS$e && getPrototypeOf$2 && setPrototypeOf$3 && !(PROTO in ObjectPrototype$1)) {
  try {
    defineBuiltInAccessor$9(ObjectPrototype$1, PROTO, {
      configurable: true,
      get: function __proto__() {
        return getPrototypeOf$2(toObject$3(this));
      },
      set: function __proto__(proto) {
        var O = requireObjectCoercible$d(this);
        // Only objects or null are accepted as prototypes, and only object receivers change.
        var validProto = isObject$c(proto) || proto === null;
        if (!validProto || !isObject$c(O)) return;
        setPrototypeOf$3(O, proto);
      }
    });
  } catch (error) { /* best-effort: leave the environment as it was */ }
}
10465
var es_object_seal = {};

var $$1c = _export;
var isObject$b = isObject$z;
var onFreeze = internalMetadataExports.onFreeze;
var FREEZING$2 = freezing;
var fails$y = fails$1m;

// eslint-disable-next-line es/no-object-seal -- safe
var $seal = Object.seal;
// Force the replacement when `fails$1m` reports that sealing a primitive (`1`) fails.
var FAILS_ON_PRIMITIVES = fails$y(function () {
  $seal(1);
});

// `Object.seal` method
// https://tc39.es/ecma262/#sec-object.seal
$$1c({
  target: 'Object',
  stat: true,
  forced: FAILS_ON_PRIMITIVES,
  sham: !FREEZING$2
}, {
  seal: function seal(it) {
    // Without a native seal, or for non-objects, return the argument untouched.
    if (!$seal || !isObject$b(it)) return it;
    return $seal(onFreeze(it));
  }
});
10492
var es_object_setPrototypeOf = {};

var $$1b = _export;
var setPrototypeOf$2 = objectSetPrototypeOf$1;

// `Object.setPrototypeOf` method: exposes the shared helper unchanged.
// https://tc39.es/ecma262/#sec-object.setprototypeof
$$1b({ target: 'Object', stat: true }, { setPrototypeOf: setPrototypeOf$2 });
10506
var es_object_toString = {};

var TO_STRING_TAG_SUPPORT$1 = toStringTagSupport;
var classof$a = classof$m;

// `Object.prototype.toString` method implementation
// https://tc39.es/ecma262/#sec-object.prototype.tostring
// With @@toStringTag support the built-in is reused; otherwise the tag comes from `classof$a`.
var objectToString;
if (TO_STRING_TAG_SUPPORT$1) {
  objectToString = {}.toString;
} else {
  objectToString = function toString() {
    return '[object ' + classof$a(this) + ']';
  };
}
var objectToString$1 = /*@__PURE__*/getDefaultExportFromCjs(objectToString);

var TO_STRING_TAG_SUPPORT = toStringTagSupport;
var defineBuiltIn$a = defineBuiltIn$m;
var toString$l = objectToString;

// `Object.prototype.toString` method
// https://tc39.es/ecma262/#sec-object.prototype.tostring
// Only engines without @@toStringTag support get the replacement installed.
if (!TO_STRING_TAG_SUPPORT) {
  defineBuiltIn$a(Object.prototype, 'toString', toString$l, { unsafe: true });
}
10531
var es_object_values = {};

var $$1a = _export;
var $values = objectToArray.values;

// Register `Object.values` as a static method of `Object`; the work is
// delegated to the shared values implementation.
// https://tc39.es/ecma262/#sec-object.values
$$1a({ target: 'Object', stat: true }, {
  values: function values(O) {
    return $values(O);
  }
});
10547
var es_parseFloat = {};

var $$19 = _export;
var $parseFloat = numberParseFloat;

// `parseFloat` method
// https://tc39.es/ecma262/#sec-parsefloat-string
// The global is overridden only when it is not already the shared implementation
// (both operands are functions, so strict and loose comparison agree).
$$19({ global: true, forced: parseFloat !== $parseFloat }, { parseFloat: $parseFloat });
10561
var es_parseInt = {};

var $$18 = _export;
var $parseInt = numberParseInt;

// `parseInt` method
// https://tc39.es/ecma262/#sec-parseint-string-radix
// The global is overridden only when it is not already the shared implementation
// (both operands are functions, so strict and loose comparison agree).
$$18({ global: true, forced: parseInt !== $parseInt }, { parseInt: $parseInt });
10575
var es_promise = {};

var es_promise_constructor = {};

var $TypeError$9 = TypeError;
// Arity guard: throws a TypeError when `passed < required`; otherwise returns
// `passed` unchanged.
var validateArgumentsLength$8 = function validateArgumentsLength(passed, required) {
  var tooFew = passed < required;
  if (tooFew) throw $TypeError$9('Not enough arguments');
  return passed;
};
var validateArgumentsLength$9 = /*@__PURE__*/getDefaultExportFromCjs(validateArgumentsLength$8);
10586
var userAgent$3 = engineUserAgent;

// True when the user agent mentions ipad/iphone/ipod followed later by applewebkit
// (case-insensitive).
// eslint-disable-next-line redos/no-vulnerable -- safe
var engineIsIos = /(?:ipad|iphone|ipod).*applewebkit/i.test(userAgent$3);
var engineIsIos$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsIos);
10592
// `setImmediate` / `clearImmediate` polyfill. The host functions are kept when
// both exist. Otherwise `set$1` / `clear` are rebuilt on top of a deferral
// primitive (`defer$1`) chosen below from the environment, in priority order.
10593 var global$z = global$Z;
10594 var apply$6 = functionApply$1;
10595 var bind$6 = functionBindContext;
10596 var isCallable$b = isCallable$z;
10597 var hasOwn$b = hasOwnProperty_1;
10598 var fails$x = fails$1m;
10599 var html = html$2;
10600 var arraySlice$5 = arraySlice$a;
10601 var createElement = documentCreateElement$2;
10602 var validateArgumentsLength$7 = validateArgumentsLength$8;
10603 var IS_IOS$1 = engineIsIos;
10604 var IS_NODE$6 = engineIsNode;
10605 var set$1 = global$z.setImmediate;
10606 var clear = global$z.clearImmediate;
10607 var process$4 = global$z.process;
10608 var Dispatch = global$z.Dispatch;
10609 var Function$2 = global$z.Function;
10610 var MessageChannel = global$z.MessageChannel;
10611 var String$1 = global$z.String;
// Task ids come from an increasing counter; pending callbacks are stored in
// `queue$3`, keyed by id.
10612 var counter = 0;
10613 var queue$3 = {};
10614 var ONREADYSTATECHANGE = 'onreadystatechange';
10615 var $location, defer$1, channel, port;
// Read `location` through fails$x so that a throwing access leaves `$location` undefined.
10616 fails$x(function () {
10617 // Deno throws a ReferenceError on `location` access without `--location` flag
10618 $location = global$z.location;
10619 });
// Run the task registered under `id` at most once: it is removed before being
// called, and ids that were cleared (or already ran) are ignored.
10620 var run = function run(id) {
10621 if (hasOwn$b(queue$3, id)) {
10622 var fn = queue$3[id];
10623 delete queue$3[id];
10624 fn();
10625 }
10626 };
// Returns a zero-argument thunk that runs task `id`, for callback-style APIs.
10627 var runner = function runner(id) {
10628 return function () {
10629 run(id);
10630 };
10631 };
// Message handler for the postMessage-based strategies: the message data is the task id.
10632 var eventListener = function eventListener(event) {
10633 run(event.data);
10634 };
// Posts the stringified id to this window, targeting its own origin.
10635 var globalPostMessageDefer = function globalPostMessageDefer(id) {
10636 // old engines lack `location.origin`, so the target origin is rebuilt from protocol + host
10637 global$z.postMessage(String$1(id), $location.protocol + '//' + $location.host);
10638 };
10639
10640 // Node.js 0.9+ & IE10+ has setImmediate, otherwise:
10641 if (!set$1 || !clear) {
// setImmediate(handler, ...args): requires at least one argument. A non-callable
// handler is passed to the global `Function` constructor, i.e. compiled from a
// string. NOTE(review): never pass untrusted strings here.
10642 set$1 = function setImmediate(handler) {
10643 validateArgumentsLength$7(arguments.length, 1);
10644 var fn = isCallable$b(handler) ? handler : Function$2(handler);
10645 var args = arraySlice$5(arguments, 1);
// Store a closure that applies the extra arguments, then schedule it by id.
10646 queue$3[++counter] = function () {
10647 apply$6(fn, undefined, args);
10648 };
10649 defer$1(counter);
10650 return counter;
10651 };
// clearImmediate only has to drop the pending entry; `run` skips unknown ids.
10652 clear = function clearImmediate(id) {
10653 delete queue$3[id];
10654 };
10655 // Node.js 0.8-
10656 if (IS_NODE$6) {
10657 defer$1 = function defer(id) {
10658 process$4.nextTick(runner(id));
10659 };
10660 // Sphere (JS game engine) Dispatch API
10661 } else if (Dispatch && Dispatch.now) {
10662 defer$1 = function defer(id) {
10663 Dispatch.now(runner(id));
10664 };
10665 // Browsers with MessageChannel, includes WebWorkers
10666 // except iOS - https://github.com/zloirock/core-js/issues/624
10667 } else if (MessageChannel && !IS_IOS$1) {
// Ids posted on port2 arrive at port1's `onmessage` (eventListener).
10668 channel = new MessageChannel();
10669 port = channel.port2;
10670 channel.port1.onmessage = eventListener;
10671 defer$1 = bind$6(port.postMessage, port);
10672 // Browsers with postMessage, skip WebWorkers
10673 // IE8 has postMessage, but it's sync & typeof its postMessage is 'object'
// Also requires a readable non-`file:` location, and a trial post (via fails$x) that does not fail.
10674 } else if (global$z.addEventListener && isCallable$b(global$z.postMessage) && !global$z.importScripts && $location && $location.protocol !== 'file:' && !fails$x(globalPostMessageDefer)) {
10675 defer$1 = globalPostMessageDefer;
10676 global$z.addEventListener('message', eventListener, false);
10677 // IE8-
10678 } else if (ONREADYSTATECHANGE in createElement('script')) {
// Append a throwaway <script>; its readystatechange callback removes it and runs the task.
10679 defer$1 = function defer(id) {
10680 html.appendChild(createElement('script'))[ONREADYSTATECHANGE] = function () {
10681 html.removeChild(this);
10682 run(id);
10683 };
10684 };
10685 // Rest old browsers
10686 } else {
10687 defer$1 = function defer(id) {
10688 setTimeout(runner(id), 0);
10689 };
10690 }
10691 }
// Module export: the (native or polyfilled) set/clear pair.
10692 var task$1 = {
10693 set: set$1,
10694 clear: clear
10695 };
10696 var task$2 = /*@__PURE__*/getDefaultExportFromCjs(task$1);
10697
// Minimal FIFO queue backed by a singly linked list of `{ item, next }` nodes.
var Queue$2 = function Queue() {
  this.head = null;
  this.tail = null;
};
Queue$2.prototype = {
  // Append `item` at the tail.
  add: function add(item) {
    var entry = { item: item, next: null };
    if (this.tail) {
      this.tail.next = entry;
    } else {
      this.head = entry;
    }
    this.tail = entry;
  },
  // Remove and return the oldest item; returns undefined when the queue is empty.
  get: function get() {
    var first = this.head;
    if (!first) return;
    this.head = first.next;
    // The queue just became empty: drop the stale tail pointer too.
    if (this.head === null) this.tail = null;
    return first.item;
  }
};
var queue$1 = Queue$2;
var queue$2 = /*@__PURE__*/getDefaultExportFromCjs(queue$1);
10723
var userAgent$2 = engineUserAgent;
// An iPad/iPhone/iPod user agent while a global `Pebble` binding exists.
var engineIsIosPebble = /ipad|iphone|ipod/i.test(userAgent$2) && typeof Pebble !== 'undefined';
var engineIsIosPebble$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsIosPebble);

var userAgent$1 = engineUserAgent;
// A `web0s` user agent with no later mention of `chrome`.
var engineIsWebosWebkit = /web0s(?!.*chrome)/i.test(userAgent$1);
var engineIsWebosWebkit$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsWebosWebkit);
10731
// Microtask scheduler. It uses the host `queueMicrotask` when the global has it
// as a data property. Otherwise it keeps a FIFO of callbacks that `flush` drains,
// and `notify$1` schedules `flush` with the best primitive available below.
10732 var global$y = global$Z;
10733 var bind$5 = functionBindContext;
10734 var getOwnPropertyDescriptor$4 = objectGetOwnPropertyDescriptor.f;
10735 var macrotask = task$1.set;
10736 var Queue$1 = queue$1;
10737 var IS_IOS = engineIsIos;
10738 var IS_IOS_PEBBLE = engineIsIosPebble;
10739 var IS_WEBOS_WEBKIT = engineIsWebosWebkit;
10740 var IS_NODE$5 = engineIsNode;
10741 var MutationObserver = global$y.MutationObserver || global$y.WebKitMutationObserver;
10742 var document$2 = global$y.document;
10743 var process$3 = global$y.process;
10744 var Promise$1 = global$y.Promise;
10745 // Node.js 11 shows ExperimentalWarning on getting `queueMicrotask`
// so the value is read from its property descriptor instead of via a plain get.
10746 var queueMicrotaskDescriptor = getOwnPropertyDescriptor$4(global$y, 'queueMicrotask');
10747 var microtask$2 = queueMicrotaskDescriptor && queueMicrotaskDescriptor.value;
10748 var notify$1, toggle, node, promise, then;
10749
10750 // modern engines have queueMicrotask method
10751 if (!microtask$2) {
10752 var queue = new Queue$1();
// Drain the queue. Under Node, the active `process.domain` is exited for the
// duration of the drain and re-entered afterwards.
10753 var flush = function flush() {
10754 var parent, fn;
10755 if (IS_NODE$5 && (parent = process$3.domain)) parent.exit();
10756 while (fn = queue.get()) try {
10757 fn();
10758 } catch (error) {
// A callback threw: re-schedule a flush if work remains, then rethrow.
// NOTE(review): on this path `parent.enter()` below is skipped.
10759 if (queue.head) notify$1();
10760 throw error;
10761 }
10762 if (parent) parent.enter();
10763 };
10764
10765 // browsers with MutationObserver, except iOS - https://github.com/zloirock/core-js/issues/339
10766 // also except WebOS Webkit https://github.com/zloirock/core-js/issues/898
10767 if (!IS_IOS && !IS_NODE$5 && !IS_WEBOS_WEBKIT && MutationObserver && document$2) {
// Toggling the data of an observed text node makes the observer call `flush`.
10768 toggle = true;
10769 node = document$2.createTextNode('');
10770 new MutationObserver(flush).observe(node, {
10771 characterData: true
10772 });
10773 notify$1 = function notify() {
10774 node.data = toggle = !toggle;
10775 };
10776 // environments with maybe non-completely correct, but existent Promise
10777 } else if (!IS_IOS_PEBBLE && Promise$1 && Promise$1.resolve) {
10778 // Promise.resolve without an argument throws an error in LG WebOS 2
10779 promise = Promise$1.resolve(undefined);
10780 // workaround of WebKit ~ iOS Safari 10.1 bug
10781 promise.constructor = Promise$1;
10782 then = bind$5(promise.then, promise);
10783 notify$1 = function notify() {
10784 then(flush);
10785 };
10786 // Node.js without promises
10787 } else if (IS_NODE$5) {
10788 notify$1 = function notify() {
10789 process$3.nextTick(flush);
10790 };
10791 // for other environments - macrotask based on:
10792 // - setImmediate
10793 // - MessageChannel
10794 // - window.postMessage
10795 // - onreadystatechange
10796 // - setTimeout
10797 } else {
10798 // `webpack` dev server bug on IE global methods - use bind(fn, global)
10799 macrotask = bind$5(macrotask, global$y);
10800 notify$1 = function notify() {
10801 macrotask(flush);
10802 };
10803 }
// Enqueue `fn`. A flush is scheduled only when the queue was empty, because a
// non-empty queue already has one pending.
10804 microtask$2 = function microtask(fn) {
10805 if (!queue.head) notify$1();
10806 queue.add(fn);
10807 };
10808 }
10809 var microtask_1 = microtask$2;
10810 var microtask$3 = /*@__PURE__*/getDefaultExportFromCjs(microtask_1);
10811
// Best-effort error reporting to the console. Only the arguments actually
// supplied are logged, and any failure while logging is ignored.
var hostReportErrors$1 = function hostReportErrors(a, b) {
  try {
    if (arguments.length === 1) {
      // eslint-disable-next-line no-console -- safe
      console.error(a);
    } else {
      // eslint-disable-next-line no-console -- safe
      console.error(a, b);
    }
  } catch (error) { /* reporting must never throw */ }
};
var hostReportErrors$2 = /*@__PURE__*/getDefaultExportFromCjs(hostReportErrors$1);
10819
// Run `exec` and capture its outcome as a record instead of letting it throw:
// `{ error: false, value: <return value> }` or `{ error: true, value: <thrown value> }`.
var perform$5 = function perform(exec) {
  var outcome;
  try {
    outcome = { error: false, value: exec() };
  } catch (error) {
    outcome = { error: true, value: error };
  }
  return outcome;
};
var perform$6 = /*@__PURE__*/getDefaultExportFromCjs(perform$5);
10834
var global$x = global$Z;
// Whatever `Promise` the global object currently holds (may be missing on old engines).
var promiseNativeConstructor = global$x.Promise;
var promiseNativeConstructor$1 = /*@__PURE__*/getDefaultExportFromCjs(promiseNativeConstructor);
10838
// Deno is detected by a global `Deno` object that carries a `version` object.
// Note: if `Deno` is null, the result is that falsy value rather than `false`,
// exactly as in the plain `&&` chain.
var engineIsDeno = (function () {
  if (typeof Deno === "undefined") return false;
  return _typeof(Deno) == 'object' && Deno && _typeof(Deno.version) == 'object';
})();
var engineIsDeno$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsDeno);

var IS_DENO$2 = engineIsDeno;
var IS_NODE$4 = engineIsNode;
// A browser is neither Deno nor Node and exposes object-typed `window` and `document` globals.
var engineIsBrowser = (function () {
  if (IS_DENO$2 || IS_NODE$4) return false;
  if (typeof window === "undefined" || _typeof(window) != 'object') return false;
  return (typeof document === "undefined" ? "undefined" : _typeof(document)) == 'object';
})();
var engineIsBrowser$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsBrowser);
10846
// Decides whether the native Promise constructor must be replaced, and records
// two capabilities of the host:
// - CONSTRUCTOR: the result of `isForced$1('Promise', detection)`.
// - REJECTION_EVENT: whether `PromiseRejectionEvent` is callable.
// - SUBCLASSING: whether native `then` honours @@species.
// NOTE(review): `isForced_1` is defined elsewhere; presumably it returns the
// detection result unless overridden by configuration. If so, the detection may
// not run and SUBCLASSING would stay `false` — confirm.
10847 var global$w = global$Z;
10848 var NativePromiseConstructor$5 = promiseNativeConstructor;
10849 var isCallable$a = isCallable$z;
10850 var isForced$1 = isForced_1;
10851 var inspectSource = inspectSource$3;
10852 var wellKnownSymbol$b = wellKnownSymbol$z;
10853 var IS_BROWSER$1 = engineIsBrowser;
10854 var IS_DENO$1 = engineIsDeno;
10855 var IS_PURE$c = isPure;
10856 var V8_VERSION = engineV8Version;
10857 var NativePromisePrototype$3 = NativePromiseConstructor$5 && NativePromiseConstructor$5.prototype;
10858 var SPECIES$1 = wellKnownSymbol$b('species');
// Updated as a side effect of the detection callback below.
10859 var SUBCLASSING = false;
10860 var NATIVE_PROMISE_REJECTION_EVENT$1 = isCallable$a(global$w.PromiseRejectionEvent);
// Detection callback: returning true means "replace the native constructor".
// NOTE(review): if there is no native Promise, `new NativePromiseConstructor$5`
// would throw here; how that is treated depends on `isForced_1` — confirm.
10861 var FORCED_PROMISE_CONSTRUCTOR$5 = isForced$1('Promise', function () {
10862 var PROMISE_CONSTRUCTOR_SOURCE = inspectSource(NativePromiseConstructor$5);
// Differs when `inspectSource` reports source other than `String(ctor)`; presumably
// this flags a global Promise installed by core-js itself — confirm.
10863 var GLOBAL_CORE_JS_PROMISE = PROMISE_CONSTRUCTOR_SOURCE !== String(NativePromiseConstructor$5);
10864 // V8 6.6 (Node 10 and Chrome 66) have a bug with resolving custom thenables
10865 // https://bugs.chromium.org/p/chromium/issues/detail?id=830565
10866 // We can't detect it synchronously, so just check versions
10867 if (!GLOBAL_CORE_JS_PROMISE && V8_VERSION === 66) return true;
10868 // We need Promise#{ catch, finally } in the pure version for preventing prototype pollution
10869 if (IS_PURE$c && !(NativePromisePrototype$3['catch'] && NativePromisePrototype$3['finally'])) return true;
10870 // We can't use @@species feature detection in V8 since it causes
10871 // deoptimization and performance degradation
10872 // https://github.com/zloirock/core-js/issues/679
10873 if (!V8_VERSION || V8_VERSION < 51 || !/native code/.test(PROMISE_CONSTRUCTOR_SOURCE)) {
10874 // Detect correctness of subclassing with @@species support
// Point the probe promise's `constructor[@@species]` at FakePromise; a
// species-aware `then` must then return a FakePromise instance.
10875 var promise = new NativePromiseConstructor$5(function (resolve) {
10876 resolve(1);
10877 });
10878 var FakePromise = function FakePromise(exec) {
10879 exec(function () {/* empty */}, function () {/* empty */});
10880 };
10881 var constructor = promise.constructor = {};
10882 constructor[SPECIES$1] = FakePromise;
10883 SUBCLASSING = promise.then(function () {/* empty */}) instanceof FakePromise;
10884 if (!SUBCLASSING) return true;
10885 // Unhandled rejections tracking support, NodeJS Promise without it fails @@species test
10886 }
// Finally, force replacement in browsers/Deno that lack PromiseRejectionEvent,
// unless the global Promise is flagged by GLOBAL_CORE_JS_PROMISE.
10887 return !GLOBAL_CORE_JS_PROMISE && (IS_BROWSER$1 || IS_DENO$1) && !NATIVE_PROMISE_REJECTION_EVENT$1;
10888 });
// Snapshot of the results. SUBCLASSING is copied after the detection above has run (if it ran).
10889 var promiseConstructorDetection = {
10890 CONSTRUCTOR: FORCED_PROMISE_CONSTRUCTOR$5,
10891 REJECTION_EVENT: NATIVE_PROMISE_REJECTION_EVENT$1,
10892 SUBCLASSING: SUBCLASSING
10893 };
10894 var promiseConstructorDetection$1 = /*@__PURE__*/getDefaultExportFromCjs(promiseConstructorDetection);
10895
var newPromiseCapability$2 = {};

var aCallable$9 = aCallable$l;
var $TypeError$8 = TypeError;

// Record holding a promise created with constructor `C`, plus the resolve and
// reject functions that `C` handed to its executor.
var PromiseCapability = function PromiseCapability(C) {
  var capturedResolve, capturedReject;
  this.promise = new C(function ($$resolve, $$reject) {
    // Reject constructors that call the executor again after functions were captured.
    if (capturedResolve !== undefined || capturedReject !== undefined) throw $TypeError$8('Bad Promise constructor');
    capturedResolve = $$resolve;
    capturedReject = $$reject;
  });
  // Both captured values must be callable (checked by aCallable$9).
  this.resolve = aCallable$9(capturedResolve);
  this.reject = aCallable$9(capturedReject);
};

// `NewPromiseCapability` abstract operation
// https://tc39.es/ecma262/#sec-newpromisecapability
newPromiseCapability$2.f = function (C) {
  return new PromiseCapability(C);
};
var f = newPromiseCapability$2.f;
10917
// `Promise` constructor polyfill: local aliases and constants.
// NOTE: this `'use strict'` is not in directive position, so it is a no-op
// expression statement. The enclosing factory is already strict.
10918 'use strict';
10919 var $$17 = _export;
10920 var IS_PURE$b = isPure;
10921 var IS_NODE$3 = engineIsNode;
10922 var global$v = global$Z;
10923 var call$p = functionCall;
10924 var defineBuiltIn$9 = defineBuiltIn$m;
10925 var setPrototypeOf$1 = objectSetPrototypeOf$1;
10926 var setToStringTag$4 = setToStringTag$d;
10927 var setSpecies$2 = setSpecies$6;
10928 var aCallable$8 = aCallable$l;
10929 var isCallable$9 = isCallable$z;
10930 var isObject$a = isObject$z;
10931 var anInstance$6 = anInstance$a;
10932 var speciesConstructor$4 = speciesConstructor$6;
// `task` is the setImmediate-style scheduler used for rejection tracking; `microtask$1` runs reactions.
10933 var task = task$1.set;
10934 var microtask$1 = microtask_1;
10935 var hostReportErrors = hostReportErrors$1;
10936 var perform$4 = perform$5;
10937 var Queue = queue$1;
10938 var InternalStateModule$7 = internalState;
10939 var NativePromiseConstructor$4 = promiseNativeConstructor;
10940 var PromiseConstructorDetection = promiseConstructorDetection;
10941 var newPromiseCapabilityModule$5 = newPromiseCapability$2;
10942 var PROMISE = 'Promise';
10943 var FORCED_PROMISE_CONSTRUCTOR$4 = PromiseConstructorDetection.CONSTRUCTOR;
10944 var NATIVE_PROMISE_REJECTION_EVENT = PromiseConstructorDetection.REJECTION_EVENT;
10945 var NATIVE_PROMISE_SUBCLASSING = PromiseConstructorDetection.SUBCLASSING;
10946 var getInternalPromiseState = InternalStateModule$7.getterFor(PROMISE);
10947 var setInternalState$7 = InternalStateModule$7.set;
10948 var NativePromisePrototype$2 = NativePromiseConstructor$4 && NativePromiseConstructor$4.prototype;
10949 var PromiseConstructor = NativePromiseConstructor$4;
10950 var PromisePrototype = NativePromisePrototype$2;
10951 var TypeError$4 = global$v.TypeError;
10952 var document$1 = global$v.document;
10953 var process$2 = global$v.process;
10954 var newPromiseCapability$1 = newPromiseCapabilityModule$5.f;
10955 var newGenericPromiseCapability = newPromiseCapability$1;
// DOM events can be dispatched only when `document.createEvent` and a global `dispatchEvent` exist.
10956 var DISPATCH_EVENT = !!(document$1 && document$1.createEvent && global$v.dispatchEvent);
10957 var UNHANDLED_REJECTION = 'unhandledrejection';
10958 var REJECTION_HANDLED = 'rejectionhandled';
// Values of an internal state's `state` field. FULFILLED is compared in callReaction,
// REJECTED is set by internalReject; PENDING is presumably the initial value (set elsewhere).
10959 var PENDING = 0;
10960 var FULFILLED = 1;
10961 var REJECTED = 2;
// Values of an internal state's `rejection` field (rejection-tracking status).
10962 var HANDLED = 1;
10963 var UNHANDLED = 2;
// Assigned later in this module (outside the visible chunk).
10964 var Internal, OwnPromiseCapability, PromiseWrapper, nativeThen;
10965
// helpers
// Returns the `then` method of `it` when `it` is an object whose `then` is
// callable, and `false` otherwise. `it.then` is read exactly once.
var isThenable = function isThenable(it) {
  if (!isObject$a(it)) return false;
  var then = it.then;
  return isCallable$9(then) ? then : false;
};
// Run a single reaction (one registered `then` callback pair) against a settled
// promise state, then settle the reaction's derived promise via its resolve/reject.
// - The handler is `reaction.ok` for fulfilled states and `reaction.fail` otherwise.
// - A handler equal to `true` passes the value through unchanged.
// - A falsy handler propagates the value as a rejection.
// - When `reaction.domain` is set, the handler call is wrapped in domain.enter()/exit().
10971 var callReaction = function callReaction(reaction, state) {
10972 var value = state.value;
10973 var ok = state.state == FULFILLED;
10974 var handler = ok ? reaction.ok : reaction.fail;
10975 var resolve = reaction.resolve;
10976 var reject = reaction.reject;
10977 var domain = reaction.domain;
10978 var result, then, exited;
10979 try {
10980 if (handler) {
10981 if (!ok) {
// A rejection handler is attached now. If the rejection was already reported as
// unhandled, announce that it is handled, then mark it HANDLED.
10982 if (state.rejection === UNHANDLED) onHandleUnhandled(state);
10983 state.rejection = HANDLED;
10984 }
10985 if (handler === true) result = value;else {
10986 if (domain) domain.enter();
10987 result = handler(value); // can throw
// `exited` records that the domain was already left, so the catch below does not exit it twice.
10988 if (domain) {
10989 domain.exit();
10990 exited = true;
10991 }
10992 }
// A handler returning its own derived promise would never settle: reject with a TypeError.
10993 if (result === reaction.promise) {
10994 reject(TypeError$4('Promise-chain cycle'));
// Thenable results are adopted by calling their `then` with this reaction's resolve/reject.
10995 } else if (then = isThenable(result)) {
10996 call$p(then, result, resolve, reject);
10997 } else resolve(result);
10998 } else reject(value);
10999 } catch (error) {
11000 if (domain && !exited) domain.exit();
11001 reject(error);
11002 }
11003 };
// Schedule (once per batch) a microtask that drains the state's pending
// reactions. For rejections with no handler recorded afterwards, start
// unhandled-rejection tracking.
var notify = function notify(state, isReject) {
  if (state.notified) return;
  state.notified = true;
  microtask$1(function () {
    var reactions = state.reactions;
    for (var reaction = reactions.get(); reaction; reaction = reactions.get()) {
      callReaction(reaction, state);
    }
    state.notified = false;
    if (isReject && !state.rejection) onUnhandled(state);
  });
};
// Report a rejection-tracking event (`unhandledrejection` / `rejectionhandled`).
// - When the DOM allows it, a real Event is dispatched on the global object.
// - When the host lacks native PromiseRejectionEvent, a global `on<name>`
//   handler, if present, is called with the event.
// - Otherwise unhandled rejections fall back to console reporting.
var dispatchEvent = function dispatchEvent(name, promise, reason) {
  var event;
  if (DISPATCH_EVENT) {
    event = document$1.createEvent('Event');
    event.promise = promise;
    event.reason = reason;
    event.initEvent(name, false, true);
    global$v.dispatchEvent(event);
  } else {
    event = { promise: promise, reason: reason };
  }
  var handler;
  if (!NATIVE_PROMISE_REJECTION_EVENT) handler = global$v['on' + name];
  if (handler) {
    handler(event);
  } else if (name === UNHANDLED_REJECTION) {
    hostReportErrors('Unhandled promise rejection', reason);
  }
};
// Called after a rejected state's reactions were drained without any of them handling it.
// The check runs later, inside a `task` callback. If the rejection is still unhandled at that
// point, it emits Node's 'unhandledRejection' process event, or the browser-style event via
// dispatchEvent.
var onUnhandled = function onUnhandled(state) {
  call$p(task, global$v, function () {
    var promise = state.facade;
    var value = state.value;
    var IS_UNHANDLED = isUnhandled(state);
    var result;
    if (IS_UNHANDLED) {
      // `perform$4` appears to run the callback and capture a thrown error as
      // `{ error: true, value: thrown }` (see the rethrow below).
      result = perform$4(function () {
        if (IS_NODE$3) {
          process$2.emit('unhandledRejection', value, promise);
        } else dispatchEvent(UNHANDLED_REJECTION, promise, value);
      });
      // Browsers should not fire `rejectionHandled` for a rejection that got handled inside the
      // event handler above; Node.js should, so Node always records UNHANDLED here.
      state.rejection = IS_NODE$3 || isUnhandled(state) ? UNHANDLED : HANDLED;
      // Re-throw an error raised by a user handler only after the state has been updated.
      if (result.error) throw result.value;
    }
  });
};
// A rejection still counts as unhandled unless a reaction has marked it HANDLED or `then`
// has been called on the promise (`then` sets `state.parent`).
var isUnhandled = function isUnhandled(state) {
  if (state.rejection === HANDLED) return false;
  return !state.parent;
};
// Announces, from a `task` callback, that a rejection previously reported as unhandled now
// has a handler. Non-Node hosts get REJECTION_HANDLED through dispatchEvent; Node.js gets a
// 'rejectionHandled' process event.
var onHandleUnhandled = function onHandleUnhandled(state) {
  call$p(task, global$v, function () {
    var facade = state.facade;
    if (!IS_NODE$3) {
      dispatchEvent(REJECTION_HANDLED, facade, state.value);
      return;
    }
    process$2.emit('rejectionHandled', facade);
  });
};
// Pre-binds a settle function (internalResolve / internalReject) to a promise state and an
// optional outer state. The result is a one-argument resolve/reject callback suitable for
// user code; extra arguments are ignored and the callback always returns undefined.
var bind$4 = function bind(settle, promiseState, outerState) {
  return function (value) {
    settle(promiseState, value, outerState);
  };
};
// Rejects a promise state at most once.
// `state` is the guard whose `done` flag blocks repeated settlement. When `unwrap` is
// supplied (the thenable-adoption path), the rejection is written to `unwrap` instead, while
// `done` is still recorded on the guard. Subscribers are then scheduled via notify(…, true).
var internalReject = function internalReject(state, value, unwrap) {
  if (!state.done) {
    state.done = true;
    var target = unwrap ? unwrap : state;
    target.value = value;
    target.state = REJECTED;
    notify(target, true);
  }
};
// Resolves a promise state at most once: it either fulfils with `value` or adopts the state of
// a thenable `value`.
// `state`/`unwrap` follow the same convention as internalReject: `done` is tracked on the
// `state` guard, and when `unwrap` is supplied the outcome is written to `unwrap`.
var internalResolve = function internalResolve(state, value, unwrap) {
  if (state.done) return;
  state.done = true;
  if (unwrap) state = unwrap;
  try {
    if (state.facade === value) throw TypeError$4("Promise can't be resolved itself");
    var then = isThenable(value);
    if (then) {
      // Thenable: invoke its `then` in a later microtask. The fresh `wrapper` guard lets only
      // the first of its resolve/reject callbacks (or a synchronous throw) take effect, and
      // the outcome is written back into `state`.
      microtask$1(function () {
        var wrapper = {
          done: false
        };
        try {
          call$p(then, value, bind$4(internalResolve, wrapper, state), bind$4(internalReject, wrapper, state));
        } catch (error) {
          internalReject(wrapper, error, state);
        }
      });
    } else {
      state.value = value;
      state.state = FULFILLED;
      notify(state, false);
    }
  } catch (error) {
    // Self-resolution, or an error thrown by isThenable. The original guard was already
    // marked `done` above, so a fresh guard object is passed to let the rejection through.
    internalReject({
      done: false
    }, error, state);
  }
};
11102
// constructor polyfill
// Everything below runs only when detection flagged the native Promise for replacement
// (FORCED_PROMISE_CONSTRUCTOR$4).
if (FORCED_PROMISE_CONSTRUCTOR$4) {
  // 25.4.3.1 Promise(executor)
  PromiseConstructor = function Promise(executor) {
    anInstance$6(this, PromisePrototype);
    aCallable$8(executor);
    // Initialise the internal state record on `this` (see `Internal` below).
    call$p(Internal, this);
    var state = getInternalPromiseState(this);
    try {
      executor(bind$4(internalResolve, state), bind$4(internalReject, state));
    } catch (error) {
      // A synchronous throw from the executor rejects the promise (ignored if already settled).
      internalReject(state, error);
    }
  };
  PromisePrototype = PromiseConstructor.prototype;

  // Creates the internal state record of a promise. It is called on `this` by
  // PromiseConstructor and also used with `new` by OwnPromiseCapability.
  // eslint-disable-next-line no-unused-vars -- required for `.length`
  Internal = function Promise(executor) {
    setInternalState$7(this, {
      type: PROMISE,
      done: false,
      notified: false,
      parent: false,
      reactions: new Queue(),
      rejection: false,
      state: PENDING,
      value: undefined
    });
  };

  // `Promise.prototype.then` method
  // https://tc39.es/ecma262/#sec-promise.prototype.then
  // `defineBuiltIn$9` appears to return its target, which also makes
  // `Internal.prototype === PromisePrototype`.
  Internal.prototype = defineBuiltIn$9(PromisePrototype, 'then', function then(onFulfilled, onRejected) {
    var state = getInternalPromiseState(this);
    var reaction = newPromiseCapability$1(speciesConstructor$4(this, PromiseConstructor));
    // The promise now has a consumer; isUnhandled() consults this flag.
    state.parent = true;
    // `true` = pass the value through; `false` = propagate the rejection (see callReaction).
    reaction.ok = isCallable$9(onFulfilled) ? onFulfilled : true;
    reaction.fail = isCallable$9(onRejected) && onRejected;
    reaction.domain = IS_NODE$3 ? process$2.domain : undefined;
    // Pending: queue the reaction for notify(). Already settled: run it in a microtask.
    if (state.state == PENDING) state.reactions.add(reaction);else microtask$1(function () {
      callReaction(reaction, state);
    });
    return reaction.promise;
  });
  // Capability for the polyfilled constructor: builds an Internal instance directly.
  // Other constructors go through newGenericPromiseCapability.
  OwnPromiseCapability = function OwnPromiseCapability() {
    var promise = new Internal();
    var state = getInternalPromiseState(promise);
    this.promise = promise;
    this.resolve = bind$4(internalResolve, state);
    this.reject = bind$4(internalReject, state);
  };
  newPromiseCapabilityModule$5.f = newPromiseCapability$1 = function newPromiseCapability(C) {
    return C === PromiseConstructor || C === PromiseWrapper ? new OwnPromiseCapability(C) : newGenericPromiseCapability(C);
  };
  if (!IS_PURE$b && isCallable$9(NativePromiseConstructor$4) && NativePromisePrototype$2 !== Object.prototype) {
    nativeThen = NativePromisePrototype$2.then;
    if (!NATIVE_PROMISE_SUBCLASSING) {
      // make `Promise#then` return a polyfilled `Promise` for native promise-based APIs
      defineBuiltIn$9(NativePromisePrototype$2, 'then', function then(onFulfilled, onRejected) {
        var that = this;
        return new PromiseConstructor(function (resolve, reject) {
          call$p(nativeThen, that, resolve, reject);
        }).then(onFulfilled, onRejected);
        // https://github.com/zloirock/core-js/issues/640
      }, {
        unsafe: true
      });
    }

    // make `.constructor === Promise` work for native promise-based APIs
    // (best effort: a failed delete is deliberately ignored)
    try {
      delete NativePromisePrototype$2.constructor;
    } catch (error) {/* empty */}

    // make `instanceof Promise` work for native promise-based APIs
    if (setPrototypeOf$1) {
      setPrototypeOf$1(NativePromisePrototype$2, PromisePrototype);
    }
  }
}
// Register PromiseConstructor as the global `Promise` through the shared `_export` helper.
// NOTE(review): `forced` presumably controls whether an existing native Promise is
// overwritten; `_export` is defined outside this chunk.
$$17({
  global: true,
  constructor: true,
  wrap: true,
  forced: FORCED_PROMISE_CONSTRUCTOR$4
}, {
  Promise: PromiseConstructor
});
// Tag the constructor's prototype for Object.prototype.toString and install @@species.
setToStringTag$4(PromiseConstructor, PROMISE, false, true);
setSpecies$2(PROMISE);
11193
var es_promise_all = {};

var NativePromiseConstructor$3 = promiseNativeConstructor;
var checkCorrectnessOfIteration$1 = checkCorrectnessOfIteration$4;
var FORCED_PROMISE_CONSTRUCTOR$3 = promiseConstructorDetection.CONSTRUCTOR;
// Used as the `forced` flag for Promise.all / race / allSettled / any below.
// True when the whole constructor is being replaced, or when native `Promise.all` fails the
// iteration-correctness probe. The probe's rejection is swallowed on purpose so that it is not
// reported as unhandled.
var promiseStaticsIncorrectIteration = FORCED_PROMISE_CONSTRUCTOR$3 || !checkCorrectnessOfIteration$1(function (iterable) {
  NativePromiseConstructor$3.all(iterable).then(undefined, function () {/* empty */});
});
var promiseStaticsIncorrectIteration$1 = /*@__PURE__*/getDefaultExportFromCjs(promiseStaticsIncorrectIteration);
11203
'use strict';
var $$16 = _export;
var call$o = functionCall;
var aCallable$7 = aCallable$l;
var newPromiseCapabilityModule$4 = newPromiseCapability$2;
var perform$3 = perform$5;
var iterate$5 = iterate$a;
var PROMISE_STATICS_INCORRECT_ITERATION$3 = promiseStaticsIncorrectIteration;

// `Promise.all` method
// https://tc39.es/ecma262/#sec-promise.all
// Resolves with an array of fulfilment values in iteration order, or rejects with the first
// rejection reason.
$$16({
  target: 'Promise',
  stat: true,
  forced: PROMISE_STATICS_INCORRECT_ITERATION$3
}, {
  all: function all(iterable) {
    var C = this;
    var capability = newPromiseCapabilityModule$4.f(C);
    var resolve = capability.resolve;
    var reject = capability.reject;
    // Synchronous errors (non-callable `C.resolve`, iteration failures) reject the result below.
    var result = perform$3(function () {
      var $promiseResolve = aCallable$7(C.resolve);
      var values = [];
      var counter = 0;
      // Starts at 1 so the aggregate cannot resolve while iteration is still in progress.
      // The extra count is released by the decrement after the loop.
      var remaining = 1;
      iterate$5(iterable, function (promise) {
        var index = counter++;
        // Each element may contribute at most once.
        var alreadyCalled = false;
        remaining++;
        call$o($promiseResolve, C, promise).then(function (value) {
          if (alreadyCalled) return;
          alreadyCalled = true;
          values[index] = value;
          --remaining || resolve(values);
        }, reject);
      });
      // Release the iteration guard; resolves immediately for an empty iterable.
      --remaining || resolve(values);
    });
    if (result.error) reject(result.value);
    return capability.promise;
  }
});
11247
var es_promise_catch = {};

'use strict';
var $$15 = _export;
var IS_PURE$a = isPure;
var FORCED_PROMISE_CONSTRUCTOR$2 = promiseConstructorDetection.CONSTRUCTOR;
var NativePromiseConstructor$2 = promiseNativeConstructor;
var getBuiltIn$a = getBuiltIn$m;
var isCallable$8 = isCallable$z;
var defineBuiltIn$8 = defineBuiltIn$m;
var NativePromisePrototype$1 = NativePromiseConstructor$2 && NativePromiseConstructor$2.prototype;

// `Promise.prototype.catch` method
// https://tc39.es/ecma262/#sec-promise.prototype.catch
// Defined generically in terms of `this.then`, so it also works on other thenable receivers.
$$15({
  target: 'Promise',
  proto: true,
  forced: FORCED_PROMISE_CONSTRUCTOR$2,
  real: true
}, {
  'catch': function _catch(onRejected) {
    return this.then(undefined, onRejected);
  }
});

// Make sure `Promise#catch` of native promise-based APIs works with the patched `Promise#then`.
// When the global Promise's `catch` differs from the native prototype's, copy it onto the
// native prototype. Skipped when IS_PURE$a is set (presumably the build that must not
// modify globals).
if (!IS_PURE$a && isCallable$8(NativePromiseConstructor$2)) {
  var method$1 = getBuiltIn$a('Promise').prototype['catch'];
  if (NativePromisePrototype$1['catch'] !== method$1) {
    defineBuiltIn$8(NativePromisePrototype$1, 'catch', method$1, {
      unsafe: true
    });
  }
}
11282
var es_promise_race = {};

'use strict';
var $$14 = _export;
var call$n = functionCall;
var aCallable$6 = aCallable$l;
var newPromiseCapabilityModule$3 = newPromiseCapability$2;
var perform$2 = perform$5;
var iterate$4 = iterate$a;
var PROMISE_STATICS_INCORRECT_ITERATION$2 = promiseStaticsIncorrectIteration;

// `Promise.race` method
// https://tc39.es/ecma262/#sec-promise.race
// Each element is normalised through `C.resolve` and wired to the shared capability, so the
// first element to settle decides the outcome. A synchronous failure (non-callable
// `C.resolve`, or an iteration error) rejects the returned promise instead.
$$14({
  target: 'Promise',
  stat: true,
  forced: PROMISE_STATICS_INCORRECT_ITERATION$2
}, {
  race: function race(iterable) {
    var Constructor = this;
    var capability = newPromiseCapabilityModule$3.f(Constructor);
    var rejectAll = capability.reject;
    var outcome = perform$2(function () {
      var resolveItem = aCallable$6(Constructor.resolve);
      iterate$4(iterable, function (item) {
        var wrapped = call$n(resolveItem, Constructor, item);
        wrapped.then(capability.resolve, rejectAll);
      });
    });
    if (outcome.error) rejectAll(outcome.value);
    return capability.promise;
  }
});
11315
var es_promise_reject = {};

'use strict';
var $$13 = _export;
var call$m = functionCall;
var newPromiseCapabilityModule$2 = newPromiseCapability$2;
var FORCED_PROMISE_CONSTRUCTOR$1 = promiseConstructorDetection.CONSTRUCTOR;

// `Promise.reject` method
// https://tc39.es/ecma262/#sec-promise.reject
// Creates a capability for `this` and rejects it right away with `r`. The capability's reject
// function is called with an undefined `this`.
$$13({
  target: 'Promise',
  stat: true,
  forced: FORCED_PROMISE_CONSTRUCTOR$1
}, {
  reject: function reject(r) {
    var capability = newPromiseCapabilityModule$2.f(this);
    call$m(capability.reject, undefined, r);
    return capability.promise;
  }
});
11337
var es_promise_resolve = {};

var anObject$o = anObject$D;
var isObject$9 = isObject$z;
var newPromiseCapability = newPromiseCapability$2;
// `PromiseResolve` abstract operation.
// `C` must be an object (anObject$o throws otherwise). An object `x` whose `constructor` is
// exactly `C` is returned as-is. Anything else gets a new promise from C's capability,
// resolved with `x`; the resolve function is called without a receiver.
var promiseResolve$2 = function promiseResolve(C, x) {
  anObject$o(C);
  var alreadyOfC = isObject$9(x) && x.constructor === C;
  if (alreadyOfC) return x;
  var capability = newPromiseCapability.f(C);
  var settle = capability.resolve;
  settle(x);
  return capability.promise;
};
var promiseResolve$3 = /*@__PURE__*/getDefaultExportFromCjs(promiseResolve$2);
11352
'use strict';
var $$12 = _export;
var getBuiltIn$9 = getBuiltIn$m;
var IS_PURE$9 = isPure;
var NativePromiseConstructor$1 = promiseNativeConstructor;
var FORCED_PROMISE_CONSTRUCTOR = promiseConstructorDetection.CONSTRUCTOR;
var promiseResolve$1 = promiseResolve$2;
var PromiseConstructorWrapper = getBuiltIn$9('Promise');
// When IS_PURE$9 is set and the native constructor was not flagged for replacement,
// `Promise.resolve` called on the exported wrapper is redirected to the native constructor.
var CHECK_WRAPPER = IS_PURE$9 && !FORCED_PROMISE_CONSTRUCTOR;

// `Promise.resolve` method
// https://tc39.es/ecma262/#sec-promise.resolve
$$12({
  target: 'Promise',
  stat: true,
  forced: IS_PURE$9 || FORCED_PROMISE_CONSTRUCTOR
}, {
  resolve: function resolve(x) {
    return promiseResolve$1(CHECK_WRAPPER && this === PromiseConstructorWrapper ? NativePromiseConstructor$1 : this, x);
  }
});
11374
var es_promise_allSettled = {};

'use strict';
var $$11 = _export;
var call$l = functionCall;
var aCallable$5 = aCallable$l;
var newPromiseCapabilityModule$1 = newPromiseCapability$2;
var perform$1 = perform$5;
var iterate$3 = iterate$a;
var PROMISE_STATICS_INCORRECT_ITERATION$1 = promiseStaticsIncorrectIteration;

// `Promise.allSettled` method
// https://tc39.es/ecma262/#sec-promise.allsettled
// Resolves, once every element has settled, with `{ status: 'fulfilled', value }` /
// `{ status: 'rejected', reason }` records in iteration order. Element rejections never reject
// the result; only synchronous failures (bad `C.resolve`, iteration errors) do.
$$11({
  target: 'Promise',
  stat: true,
  forced: PROMISE_STATICS_INCORRECT_ITERATION$1
}, {
  allSettled: function allSettled(iterable) {
    var C = this;
    var capability = newPromiseCapabilityModule$1.f(C);
    var resolve = capability.resolve;
    var reject = capability.reject;
    var result = perform$1(function () {
      var promiseResolve = aCallable$5(C.resolve);
      var values = [];
      var counter = 0;
      // Starts at 1 to hold the result open during iteration; released after the loop.
      var remaining = 1;
      iterate$3(iterable, function (promise) {
        var index = counter++;
        // Shared by both callbacks so each element is recorded at most once.
        var alreadyCalled = false;
        remaining++;
        call$l(promiseResolve, C, promise).then(function (value) {
          if (alreadyCalled) return;
          alreadyCalled = true;
          values[index] = {
            status: 'fulfilled',
            value: value
          };
          --remaining || resolve(values);
        }, function (error) {
          if (alreadyCalled) return;
          alreadyCalled = true;
          values[index] = {
            status: 'rejected',
            reason: error
          };
          --remaining || resolve(values);
        });
      });
      // Release the iteration guard; resolves immediately for an empty iterable.
      --remaining || resolve(values);
    });
    if (result.error) reject(result.value);
    return capability.promise;
  }
});
11431
var es_promise_any = {};

'use strict';
var $$10 = _export;
var call$k = functionCall;
var aCallable$4 = aCallable$l;
var getBuiltIn$8 = getBuiltIn$m;
var newPromiseCapabilityModule = newPromiseCapability$2;
var perform = perform$5;
var iterate$2 = iterate$a;
var PROMISE_STATICS_INCORRECT_ITERATION = promiseStaticsIncorrectIteration;
var PROMISE_ANY_ERROR = 'No one promise resolved';

// `Promise.any` method
// https://tc39.es/ecma262/#sec-promise.any
// Resolves with the first fulfilment value. If every element rejects (including the
// empty-iterable case), it rejects with an AggregateError holding the reasons in iteration
// order.
$$10({
  target: 'Promise',
  stat: true,
  forced: PROMISE_STATICS_INCORRECT_ITERATION
}, {
  any: function any(iterable) {
    var C = this;
    // Looked up per call through getBuiltIn (the constructor may itself be a polyfill).
    var AggregateError = getBuiltIn$8('AggregateError');
    var capability = newPromiseCapabilityModule.f(C);
    var resolve = capability.resolve;
    var reject = capability.reject;
    var result = perform(function () {
      var promiseResolve = aCallable$4(C.resolve);
      var errors = [];
      var counter = 0;
      // Starts at 1 to hold the result open during iteration; released after the loop.
      var remaining = 1;
      // Shared across elements: once any element fulfils, later settlements are ignored.
      var alreadyResolved = false;
      iterate$2(iterable, function (promise) {
        var index = counter++;
        // Per element: its rejection is counted at most once.
        var alreadyRejected = false;
        remaining++;
        call$k(promiseResolve, C, promise).then(function (value) {
          if (alreadyRejected || alreadyResolved) return;
          alreadyResolved = true;
          resolve(value);
        }, function (error) {
          if (alreadyRejected || alreadyResolved) return;
          alreadyRejected = true;
          errors[index] = error;
          --remaining || reject(new AggregateError(errors, PROMISE_ANY_ERROR));
        });
      });
      // Release the iteration guard; rejects immediately for an empty iterable.
      --remaining || reject(new AggregateError(errors, PROMISE_ANY_ERROR));
    });
    if (result.error) reject(result.value);
    return capability.promise;
  }
});
11485
var es_promise_finally = {};

'use strict';
var $$$ = _export;
var IS_PURE$8 = isPure;
var NativePromiseConstructor = promiseNativeConstructor;
var fails$w = fails$1m;
var getBuiltIn$7 = getBuiltIn$m;
var isCallable$7 = isCallable$z;
var speciesConstructor$3 = speciesConstructor$6;
var promiseResolve = promiseResolve$2;
var defineBuiltIn$7 = defineBuiltIn$m;
var NativePromisePrototype = NativePromiseConstructor && NativePromiseConstructor.prototype;

// Safari bug https://bugs.webkit.org/show_bug.cgi?id=200829
// Probe: native `finally` must work on any thenable, not just real promises. `fails$w`
// presumably reports true when the probe throws.
var NON_GENERIC = !!NativePromiseConstructor && fails$w(function () {
  // eslint-disable-next-line unicorn/no-thenable -- required for testing
  NativePromisePrototype['finally'].call({
    then: function then() {/* empty */}
  }, function () {/* empty */});
});

// `Promise.prototype.finally` method
// https://tc39.es/ecma262/#sec-promise.prototype.finally
// A callable `onFinally` runs on either outcome. Its result is awaited via promiseResolve with
// the species constructor; then the original value is passed through or the original reason
// re-thrown. A non-callable `onFinally` is handed to `then` unchanged.
$$$({
  target: 'Promise',
  proto: true,
  real: true,
  forced: NON_GENERIC
}, {
  'finally': function _finally(onFinally) {
    var C = speciesConstructor$3(this, getBuiltIn$7('Promise'));
    var isFunction = isCallable$7(onFinally);
    return this.then(isFunction ? function (x) {
      return promiseResolve(C, onFinally()).then(function () {
        return x;
      });
    } : onFinally, isFunction ? function (e) {
      return promiseResolve(C, onFinally()).then(function () {
        throw e;
      });
    } : onFinally);
  }
});

// Make sure `Promise#finally` of native promise-based APIs works with the patched `Promise#then`
// by copying the global Promise's `finally` onto the native prototype when they differ.
if (!IS_PURE$8 && isCallable$7(NativePromiseConstructor)) {
  var method = getBuiltIn$7('Promise').prototype['finally'];
  if (NativePromisePrototype['finally'] !== method) {
    defineBuiltIn$7(NativePromisePrototype, 'finally', method, {
      unsafe: true
    });
  }
}
11540
var es_reflect_apply = {};

var $$_ = _export;
var functionApply = functionApply$1;
var aCallable$3 = aCallable$l;
var anObject$n = anObject$D;
var fails$v = fails$1m;

// MS Edge treats the argumentsList argument as optional, although the spec requires it.
// OPTIONAL_ARGUMENTS_LIST is true when the native call without argumentsList does NOT fail;
// the polyfill then takes over and validates argumentsList with anObject$n.
var OPTIONAL_ARGUMENTS_LIST = !fails$v(function () {
  // eslint-disable-next-line es/no-reflect -- required for testing
  Reflect.apply(function () {/* empty */});
});

// `Reflect.apply` method
// https://tc39.es/ecma262/#sec-reflect.apply
$$_({
  target: 'Reflect',
  stat: true,
  forced: OPTIONAL_ARGUMENTS_LIST
}, {
  apply: function apply(target, thisArgument, argumentsList) {
    return functionApply(aCallable$3(target), thisArgument, anObject$n(argumentsList));
  }
});
11566
var es_reflect_construct = {};

var $$Z = _export;
var getBuiltIn$6 = getBuiltIn$m;
var apply$5 = functionApply$1;
var bind$3 = functionBind;
var aConstructor$1 = aConstructor$3;
var anObject$m = anObject$D;
var isObject$8 = isObject$z;
var create$4 = objectCreate;
var fails$u = fails$1m;
var nativeConstruct = getBuiltIn$6('Reflect', 'construct');
var ObjectPrototype = Object.prototype;
var push$7 = [].push;

// `Reflect.construct` method
// https://tc39.es/ecma262/#sec-reflect.construct
// MS Edge supports only 2 arguments and treats the argumentsList argument as optional
// FF Nightly sets third argument as `new.target`, but does not create `this` from it
// NEW_TARGET_BUG: the instance is not created from `newTarget.prototype` (or the probe throws).
var NEW_TARGET_BUG = fails$u(function () {
  function F() {/* empty */}
  return !(nativeConstruct(function () {/* empty */}, [], F) instanceof F);
});
// ARGS_BUG: the native call does not reject a missing argumentsList.
var ARGS_BUG = !fails$u(function () {
  nativeConstruct(function () {/* empty */});
});
var FORCED$5 = NEW_TARGET_BUG || ARGS_BUG;
$$Z({
  target: 'Reflect',
  stat: true,
  forced: FORCED$5,
  sham: FORCED$5
}, {
  construct: function construct(Target, args /* , newTarget */) {
    aConstructor$1(Target);
    anObject$m(args);
    var newTarget = arguments.length < 3 ? Target : aConstructor$1(arguments[2]);
    // Only the argument validation was broken: `args` was checked above, so defer to native.
    if (ARGS_BUG && !NEW_TARGET_BUG) return nativeConstruct(Target, args, newTarget);
    if (Target == newTarget) {
      // w/o altered newTarget, optimization for 0-4 arguments
      switch (args.length) {
        case 0:
          return new Target();
        case 1:
          return new Target(args[0]);
        case 2:
          return new Target(args[0], args[1]);
        case 3:
          return new Target(args[0], args[1], args[2]);
        case 4:
          return new Target(args[0], args[1], args[2], args[3]);
      }
      // w/o altered newTarget, many-arguments case: bind the arguments and `new` the bound
      // function (the leading `null` is bind's thisArg, which `new` ignores)
      var $args = [null];
      apply$5(push$7, $args, args);
      return new (apply$5(bind$3, Target, $args))();
    }
    // with an altered newTarget (built-in constructors are not supported on this path):
    // create `this` from newTarget.prototype and call Target as a plain function on it
    var proto = newTarget.prototype;
    var instance = create$4(isObject$8(proto) ? proto : ObjectPrototype);
    var result = apply$5(Target, instance, args);
    return isObject$8(result) ? result : instance;
  }
});
11631
var es_reflect_defineProperty = {};

var $$Y = _export;
var DESCRIPTORS$d = descriptors;
var anObject$l = anObject$D;
var toPropertyKey$1 = toPropertyKey$8;
var definePropertyModule$2 = objectDefineProperty;
var fails$t = fails$1m;

// MS Edge has broken Reflect.defineProperty - throwing instead of returning false
// Probe: key 1 is first defined with default (non-writable, non-configurable) attributes, so
// redefining its value must return false rather than throw.
var ERROR_INSTEAD_OF_FALSE = fails$t(function () {
  // eslint-disable-next-line es/no-reflect -- required for testing
  Reflect.defineProperty(definePropertyModule$2.f({}, 1, {
    value: 1
  }), 1, {
    value: 2
  });
});

// `Reflect.defineProperty` method
// https://tc39.es/ecma262/#sec-reflect.defineproperty
// Converts a throwing definition into a `false` result, as the spec requires.
$$Y({
  target: 'Reflect',
  stat: true,
  forced: ERROR_INSTEAD_OF_FALSE,
  sham: !DESCRIPTORS$d
}, {
  defineProperty: function defineProperty(target, propertyKey, attributes) {
    anObject$l(target);
    var key = toPropertyKey$1(propertyKey);
    anObject$l(attributes);
    try {
      definePropertyModule$2.f(target, key, attributes);
      return true;
    } catch (error) {
      return false;
    }
  }
});
11671
var es_reflect_deleteProperty = {};

var $$X = _export;
var anObject$k = anObject$D;
var getOwnPropertyDescriptor$3 = objectGetOwnPropertyDescriptor.f;

// `Reflect.deleteProperty` method
// https://tc39.es/ecma262/#sec-reflect.deleteproperty
// An own non-configurable property yields `false` up front, because a `delete` in this
// strict-mode bundle would throw for it. Otherwise the result of the `delete` operator is
// returned.
$$X({
  target: 'Reflect',
  stat: true
}, {
  deleteProperty: function deleteProperty(target, propertyKey) {
    var own = getOwnPropertyDescriptor$3(anObject$k(target), propertyKey);
    if (own && !own.configurable) return false;
    return delete target[propertyKey];
  }
});
11689
var es_reflect_get = {};

var hasOwn$a = hasOwnProperty_1;
// Tells whether `descriptor` describes a data property, i.e. whether it has an own `value` or
// `writable` field. An absent (undefined) descriptor is not a data descriptor.
var isDataDescriptor$2 = function isDataDescriptor(descriptor) {
  if (descriptor === undefined) return false;
  return hasOwn$a(descriptor, 'value') || hasOwn$a(descriptor, 'writable');
};
var isDataDescriptor$3 = /*@__PURE__*/getDefaultExportFromCjs(isDataDescriptor$2);
11697
// Local aliases of shared core-js helpers used by the `Reflect.get` polyfill (get$2).
var $$W = _export;
var call$j = functionCall;
var isObject$7 = isObject$z;
var anObject$j = anObject$D;
var isDataDescriptor$1 = isDataDescriptor$2;
var getOwnPropertyDescriptorModule$3 = objectGetOwnPropertyDescriptor;
var getPrototypeOf$1 = objectGetPrototypeOf$1;
11705
// `Reflect.get` method
// https://tc39.es/ecma262/#sec-reflect.get
// Polyfill of Reflect.get(target, propertyKey[, receiver]).
// Without a distinct receiver this is a plain property read. Otherwise the own property of
// `target` is used: a data value, or an accessor's getter called with `receiver` as `this`.
// A missing property is looked up along the prototype chain; undefined if never found.
// Throws (via anObject$j) when `target` is not an object.
function get$2(target, propertyKey /* , receiver */) {
  var receiver = arguments.length >= 3 ? arguments[2] : target;
  if (receiver === anObject$j(target)) return target[propertyKey];
  var own = getOwnPropertyDescriptorModule$3.f(target, propertyKey);
  if (own) {
    if (isDataDescriptor$1(own)) return own.value;
    return own.get === undefined ? undefined : call$j(own.get, receiver);
  }
  var parent = getPrototypeOf$1(target);
  if (!isObject$7(parent)) return undefined;
  return get$2(parent, propertyKey, receiver);
}
// Expose get$2 as `Reflect.get`. No `forced` flag is passed here — presumably it is only
// installed where missing (`_export` semantics live outside this chunk).
$$W({
  target: 'Reflect',
  stat: true
}, {
  get: get$2
});
11722
var es_reflect_getOwnPropertyDescriptor = {};

var $$V = _export;
var DESCRIPTORS$c = descriptors;
var anObject$i = anObject$D;
var getOwnPropertyDescriptorModule$2 = objectGetOwnPropertyDescriptor;

// `Reflect.getOwnPropertyDescriptor` method
// https://tc39.es/ecma262/#sec-reflect.getownpropertydescriptor
// Like Object.getOwnPropertyDescriptor, but rejects non-object targets (anObject$i).
// Marked `sham` when DESCRIPTORS$c is falsy (presumably engines without real descriptor support).
$$V({
  target: 'Reflect',
  stat: true,
  sham: !DESCRIPTORS$c
}, {
  getOwnPropertyDescriptor: function getOwnPropertyDescriptor(target, propertyKey) {
    return getOwnPropertyDescriptorModule$2.f(anObject$i(target), propertyKey);
  }
});
11741
var es_reflect_getPrototypeOf = {};

var $$U = _export;
var anObject$h = anObject$D;
var objectGetPrototypeOf = objectGetPrototypeOf$1;
var CORRECT_PROTOTYPE_GETTER = correctPrototypeGetter;

// `Reflect.getPrototypeOf` method
// https://tc39.es/ecma262/#sec-reflect.getprototypeof
// Like Object.getPrototypeOf, but rejects non-object targets (anObject$h).
// Marked `sham` when CORRECT_PROTOTYPE_GETTER is falsy.
$$U({
  target: 'Reflect',
  stat: true,
  sham: !CORRECT_PROTOTYPE_GETTER
}, {
  getPrototypeOf: function getPrototypeOf(target) {
    return objectGetPrototypeOf(anObject$h(target));
  }
});
11760
var es_reflect_has = {};

var $$T = _export;

// `Reflect.has` method
// https://tc39.es/ecma262/#sec-reflect.has
// Same as the `in` operator, which already throws a TypeError for a non-object target.
$$T({
  target: 'Reflect',
  stat: true
}, {
  has: function has(target, propertyKey) {
    return propertyKey in target;
  }
});
11775
var es_reflect_isExtensible = {};

var $$S = _export;
var anObject$g = anObject$D;
var $isExtensible = objectIsExtensible;

// `Reflect.isExtensible` method
// https://tc39.es/ecma262/#sec-reflect.isextensible
// Rejects non-object targets first (Object.isExtensible would instead return false for them).
$$S({
  target: 'Reflect',
  stat: true
}, {
  isExtensible: function isExtensible(target) {
    anObject$g(target);
    return $isExtensible(target);
  }
});
11793
var es_reflect_ownKeys = {};

var $$R = _export;
var ownKeys = ownKeys$3;

// `Reflect.ownKeys` method
// https://tc39.es/ecma262/#sec-reflect.ownkeys
// Delegates directly to the shared ownKeys$3 helper (defined outside this chunk).
$$R({
  target: 'Reflect',
  stat: true
}, {
  ownKeys: ownKeys
});
11807
var es_reflect_preventExtensions = {};

var $$Q = _export;
var getBuiltIn$5 = getBuiltIn$m;
var anObject$f = anObject$D;
var FREEZING$1 = freezing;

// `Reflect.preventExtensions` method
// https://tc39.es/ecma262/#sec-reflect.preventextensions
// Returns true after calling Object.preventExtensions (or when it is unavailable, in which case
// nothing is done), and false if it threw. Marked `sham` when FREEZING$1 is falsy.
$$Q({
  target: 'Reflect',
  stat: true,
  sham: !FREEZING$1
}, {
  preventExtensions: function preventExtensions(target) {
    anObject$f(target);
    try {
      // Looked up at call time so a later-installed Object.preventExtensions is honoured.
      var objectPreventExtensions = getBuiltIn$5('Object', 'preventExtensions');
      if (objectPreventExtensions) objectPreventExtensions(target);
      return true;
    } catch (error) {
      return false;
    }
  }
});
11833
var es_reflect_set = {};

// Local aliases of shared core-js helpers used by the `Reflect.set` polyfill (set) below.
var $$P = _export;
var call$i = functionCall;
var anObject$e = anObject$D;
var isObject$6 = isObject$z;
var isDataDescriptor = isDataDescriptor$2;
var fails$s = fails$1m;
var definePropertyModule$1 = objectDefineProperty;
var getOwnPropertyDescriptorModule$1 = objectGetOwnPropertyDescriptor;
var getPrototypeOf = objectGetPrototypeOf$1;
var createPropertyDescriptor$4 = createPropertyDescriptor$c;
11846
// `Reflect.set` method
// https://tc39.es/ecma262/#sec-reflect.set
// Polyfill of Reflect.set(target, propertyKey, V[, receiver]); returns true on success.
// - The property is looked up on `target`, then up its prototype chain. If it is found nowhere,
//   it is treated as a writable data property created by createPropertyDescriptor$4(0).
// - Accessor found: call its setter with `receiver` as `this`; false if there is no setter.
// - Data property found: fails for a non-writable property or a non-object receiver.
//   Otherwise the value is written as an own property of `receiver`, updating an existing
//   writable data property or creating a new one. An existing accessor on the receiver fails.
function set(target, propertyKey, V /* , receiver */) {
  var receiver = arguments.length >= 4 ? arguments[3] : target;
  var descriptor = getOwnPropertyDescriptorModule$1.f(anObject$e(target), propertyKey);
  if (!descriptor) {
    var parent = getPrototypeOf(target);
    if (isObject$6(parent)) return set(parent, propertyKey, V, receiver);
    descriptor = createPropertyDescriptor$4(0);
  }
  // Accessor property: delegate to its setter, if any.
  if (!isDataDescriptor(descriptor)) {
    var setter = descriptor.set;
    if (setter === undefined) return false;
    call$i(setter, receiver, V);
    return true;
  }
  if (descriptor.writable === false) return false;
  if (!isObject$6(receiver)) return false;
  var current = getOwnPropertyDescriptorModule$1.f(receiver, propertyKey);
  if (!current) {
    definePropertyModule$1.f(receiver, propertyKey, createPropertyDescriptor$4(0, V));
    return true;
  }
  if (current.get || current.set || current.writable === false) return false;
  current.value = V;
  definePropertyModule$1.f(receiver, propertyKey, current);
  return true;
}
11873
// MS Edge 17-18: Reflect.set allows setting the property on the receiver even when the
// prototype's property is non-writable.
// Probe: `Constructor.prototype` has no 'a', so a spec-compliant set must find the receiver's
// own accessor-less 'a' (defined with only `configurable: true`, hence non-writable) and
// return false.
var MS_EDGE_BUG = fails$s(function () {
  var Constructor = function Constructor() {/* empty */};
  var object = definePropertyModule$1.f(new Constructor(), 'a', {
    configurable: true
  });
  // eslint-disable-next-line es/no-reflect -- required for testing
  return Reflect.set(Constructor.prototype, 'a', 1, object) !== false;
});
$$P({
  target: 'Reflect',
  stat: true,
  forced: MS_EDGE_BUG
}, {
  set: set
});
11891
var es_reflect_setPrototypeOf = {};

var $$O = _export;
var anObject$d = anObject$D;
var aPossiblePrototype = aPossiblePrototype$2;
var objectSetPrototypeOf = objectSetPrototypeOf$1;

// `Reflect.setPrototypeOf` method
// https://tc39.es/ecma262/#sec-reflect.setprototypeof
// Only registered when an Object.setPrototypeOf implementation is available. Validates both
// arguments, then turns a throwing prototype change into a `false` result.
if (objectSetPrototypeOf) $$O({
  target: 'Reflect',
  stat: true
}, {
  setPrototypeOf: function setPrototypeOf(target, proto) {
    anObject$d(target);
    aPossiblePrototype(proto);
    try {
      objectSetPrototypeOf(target, proto);
      return true;
    } catch (error) {
      return false;
    }
  }
});
11916
var es_reflect_toStringTag = {};

var $$N = _export;
var global$u = global$Z;
var setToStringTag$3 = setToStringTag$d;
// Ensure a global `Reflect` namespace object exists. NOTE(review): presumably `_export` keeps
// an existing native Reflect and only installs `{}` when it is missing — confirm.
$$N({
  global: true
}, {
  Reflect: {}
});

// Reflect[@@toStringTag] property
// https://tc39.es/ecma262/#sec-reflect-@@tostringtag
setToStringTag$3(global$u.Reflect, 'Reflect', true);
11931
var es_regexp_constructor = {};

var isObject$5 = isObject$z;
var classof$9 = classofRaw$2;
var wellKnownSymbol$a = wellKnownSymbol$z;
var MATCH$2 = wellKnownSymbol$a('match');

// `IsRegExp` abstract operation
// https://tc39.es/ecma262/#sec-isregexp
// Non-objects are never regexps. For objects, a defined @@match property (own or inherited)
// decides, coerced to boolean. Otherwise the class tag reported by classof$9 must be 'RegExp'.
var isRegexp = function isRegexp(it) {
  if (!isObject$5(it)) return false;
  var matcher = it[MATCH$2];
  if (matcher !== undefined) return !!matcher;
  return classof$9(it) == 'RegExp';
};
var isRegexp$1 = /*@__PURE__*/getDefaultExportFromCjs(isRegexp);
11946
'use strict';
var anObject$c = anObject$D;

// `RegExp.prototype.flags` getter implementation
// https://tc39.es/ecma262/#sec-get-regexp.prototype.flags
// Reads each flag property of `this` in spec order ("dgimsuvy") and appends the letter of
// every truthy one. Throws (via anObject$c) when `this` is not an object.
var regexpFlags$1 = function regexpFlags() {
  var that = anObject$c(this);
  var flagTable = [
    ['hasIndices', 'd'],
    ['global', 'g'],
    ['ignoreCase', 'i'],
    ['multiline', 'm'],
    ['dotAll', 's'],
    ['unicode', 'u'],
    ['unicodeSets', 'v'],
    ['sticky', 'y']
  ];
  var flags = '';
  for (var i = 0; i < flagTable.length; i++) {
    if (that[flagTable[i][0]]) flags += flagTable[i][1];
  }
  return flags;
};
var regexpFlags$2 = /*@__PURE__*/getDefaultExportFromCjs(regexpFlags$1);
11966
var call$h = functionCall;
var hasOwn$9 = hasOwnProperty_1;
var isPrototypeOf$2 = objectIsPrototypeOf;
var regExpFlags$1 = regexpFlags$1;
var RegExpPrototype$7 = RegExp.prototype;

// Returns `R.flags`. The flags string is computed with the polyfilled getter only when:
// the read gave undefined, RegExp.prototype has no `flags` at all, `R` has no own
// `flags`, and `R` inherits from RegExp.prototype.
var regexpGetFlags = function regexpGetFlags(R) {
  var flags = R.flags;
  var needsFallback = flags === undefined
    && !('flags' in RegExpPrototype$7)
    && !hasOwn$9(R, 'flags')
    && isPrototypeOf$2(RegExpPrototype$7, R);
  if (needsFallback) return call$h(regExpFlags$1, R);
  return flags;
};
var regexpGetFlags$1 = /*@__PURE__*/getDefaultExportFromCjs(regexpGetFlags);
11977
var fails$r = fails$1m;
var global$t = global$Z;

// babel-minify and Closure Compiler transpile RegExp('a', 'y') -> /a/y, which is a
// SyntaxError where `y` is unsupported, so the constructor is called dynamically.
var $RegExp$2 = global$t.RegExp;

// With lastIndex = 2, /a/y must not match 'abcd' (index 2 holds 'c');
// a match means the sticky flag is unsupported or ignored.
var UNSUPPORTED_Y$3 = fails$r(function () {
  var probe = $RegExp$2('a', 'y');
  probe.lastIndex = 2;
  return probe.exec('abcd') != null;
});

// UC Browser bug: a `y` regexp does not report `sticky` as true.
// https://github.com/zloirock/core-js/issues/1008
var MISSED_STICKY$2 = UNSUPPORTED_Y$3 || fails$r(function () {
  var probe = $RegExp$2('a', 'y');
  return !probe.sticky;
});

// `^` must still anchor to the input start under `y`:
// /^r/gy at lastIndex 2 must not match 'str'.
var BROKEN_CARET = UNSUPPORTED_Y$3 || fails$r(function () {
  // https://bugzilla.mozilla.org/show_bug.cgi?id=773687
  var probe = $RegExp$2('^r', 'gy');
  probe.lastIndex = 2;
  return probe.exec('str') != null;
});
var regexpStickyHelpers = {
  BROKEN_CARET: BROKEN_CARET,
  MISSED_STICKY: MISSED_STICKY$2,
  UNSUPPORTED_Y: UNSUPPORTED_Y$3
};
var regexpStickyHelpers$1 = /*@__PURE__*/getDefaultExportFromCjs(regexpStickyHelpers);
12006
var fails$q = fails$1m;
var global$s = global$Z;

// babel-minify and Closure Compiler transpile RegExp('.', 's') -> /./s, which is a
// SyntaxError where `s` is unsupported, so the constructor is called dynamically.
var $RegExp$1 = global$s.RegExp;
// The `s` flag counts as supported only if the regexp reports dotAll, `.` matches
// '\n', and its flags string is exactly 's'.
var regexpUnsupportedDotAll = fails$q(function () {
  var dotAllRe = $RegExp$1('.', 's');
  var supported = dotAllRe.dotAll && dotAllRe.exec('\n') && dotAllRe.flags === 's';
  return !supported;
});
var regexpUnsupportedDotAll$1 = /*@__PURE__*/getDefaultExportFromCjs(regexpUnsupportedDotAll);
12017
var fails$p = fails$1m;
var global$r = global$Z;

// babel-minify and Closure Compiler transpile RegExp('(?<a>b)', 'g') -> /(?<a>b)/g,
// which is a SyntaxError where named groups are unsupported.
var $RegExp = global$r.RegExp;
// Named capture groups count as supported only if `exec` exposes `groups.a` and
// `replace` understands the `$<a>` substitution.
var regexpUnsupportedNcg = fails$p(function () {
  var named = $RegExp('(?<a>b)', 'g');
  if (named.exec('b').groups.a !== 'b') return true;
  return 'b'.replace(named, '$<a>c') !== 'bc';
});
var regexpUnsupportedNcg$1 = /*@__PURE__*/getDefaultExportFromCjs(regexpUnsupportedNcg);
12028
// Dependencies and feature probes for the `RegExp` constructor wrapper.
var DESCRIPTORS$b = descriptors;
var global$q = global$Z;
var uncurryThis$z = functionUncurryThis;
var isForced = isForced_1;
var inheritIfRequired$2 = inheritIfRequired$6;
var createNonEnumerableProperty$5 = createNonEnumerableProperty$f;
var getOwnPropertyNames$1 = objectGetOwnPropertyNames.f;
var isPrototypeOf$1 = objectIsPrototypeOf;
var isRegExp$4 = isRegexp;
var toString$k = toString$x;
var getRegExpFlags$4 = regexpGetFlags;
var stickyHelpers$2 = regexpStickyHelpers;
var proxyAccessor = proxyAccessor$2;
var defineBuiltIn$6 = defineBuiltIn$m;
var fails$o = fails$1m;
var hasOwn$8 = hasOwnProperty_1;
var enforceInternalState$2 = internalState.enforce;
var setSpecies$1 = setSpecies$6;
var wellKnownSymbol$9 = wellKnownSymbol$z;
var UNSUPPORTED_DOT_ALL$2 = regexpUnsupportedDotAll;
var UNSUPPORTED_NCG$1 = regexpUnsupportedNcg;
var MATCH$1 = wellKnownSymbol$9('match');
// Engine's own constructor/prototype, captured before any wrapping happens.
var NativeRegExp = global$q.RegExp;
var RegExpPrototype$6 = NativeRegExp.prototype;
var SyntaxError$1 = global$q.SyntaxError;
var exec$6 = uncurryThis$z(RegExpPrototype$6.exec);
var charAt$b = uncurryThis$z(''.charAt);
var replace$8 = uncurryThis$z(''.replace);
var stringIndexOf$4 = uncurryThis$z(''.indexOf);
var stringSlice$b = uncurryThis$z(''.slice);
// Matches the `?<name>` prefix that follows the `(` of a named capture group.
// The first character class excludes `=` and `!`, so lookbehinds `(?<=` / `(?<!` do not match.
// TODO: Use only proper RegExpIdentifierName
var IS_NCG = /^\?<[^\s\d!#%&*+<=>@^][^\s!#%&*+<=>@^]*>/;
var re1 = /a/g;
var re2 = /a/g;

// "new" should create a new object, old webkit bug
var CORRECT_NEW = new NativeRegExp(re1) !== re1;
var MISSED_STICKY$1 = stickyHelpers$2.MISSED_STICKY;
var UNSUPPORTED_Y$2 = stickyHelpers$2.UNSUPPORTED_Y;
// Candidate for replacement (final decision is `isForced('RegExp', BASE_FORCED)`) when
// descriptors are available and any probe detects a deficiency. The last probe only runs
// if all earlier checks pass, and it permanently sets `re2[@@match] = false`.
var BASE_FORCED = DESCRIPTORS$b && (!CORRECT_NEW || MISSED_STICKY$1 || UNSUPPORTED_DOT_ALL$2 || UNSUPPORTED_NCG$1 || fails$o(function () {
  re2[MATCH$1] = false;
  // RegExp constructor can alter flags and IsRegExp works correct with @@match
  // Expected: a bare call returns the same regexp, a regexp with @@match = false is
  // not treated as a regexp (so a new object is created), and explicit flags win.
  return NativeRegExp(re1) != re1 || NativeRegExp(re2) == re2 || NativeRegExp(re1, 'i') != '/a/i';
}));
// Emulates the `s` (dotAll) flag by rewriting every unescaped `.` outside a
// character class into `[\s\S]`. An escape (`\` plus the next character) is
// copied through verbatim, so `\.` and `\[` are never reinterpreted.
var handleDotAll = function handleDotAll(string) {
  var output = '';
  var inClass = false;
  for (var pos = 0; pos < string.length; pos++) {
    var ch = charAt$b(string, pos);
    if (ch === '\\') {
      pos++;
      output += ch + charAt$b(string, pos);
    } else if (ch === '.' && !inClass) {
      output += '[\\s\\S]';
    } else {
      if (ch === '[') {
        inClass = true;
      } else if (ch === ']') {
        inClass = false;
      }
      output += ch;
    }
  }
  return output;
};
// Strips named-capture-group syntax for engines without it: `(?<name>` becomes `(`.
// Returns [rewrittenPattern, [[name, groupNumber], ...]], where groupNumber counts
// every `(` seen outside a character class up to and including the named one.
// Throws SyntaxError('Invalid capture group name') for an empty or duplicate name.
var handleNCG = function handleNCG(string) {
  var source = '';
  var namedGroups = [];
  var seenNames = {};
  var inClass = false;
  var readingName = false;
  var groupCount = 0;
  var currentName = '';
  for (var pos = 0; pos < string.length; pos++) {
    var ch = charAt$b(string, pos);
    if (ch === '\\') {
      ch += charAt$b(string, ++pos);
    } else if (ch === ']') {
      inClass = false;
    } else if (!inClass) {
      if (ch === '[') {
        inClass = true;
      } else if (ch === '(') {
        // For `(?<name>`, skip the `?<` prefix and start collecting the name.
        if (exec$6(IS_NCG, stringSlice$b(string, pos + 1))) {
          pos += 2;
          readingName = true;
        }
        source += ch;
        groupCount++;
        continue;
      } else if (ch === '>' && readingName) {
        if (currentName === '' || hasOwn$8(seenNames, currentName)) {
          throw new SyntaxError$1('Invalid capture group name');
        }
        seenNames[currentName] = true;
        namedGroups[namedGroups.length] = [currentName, groupCount];
        readingName = false;
        currentName = '';
        continue;
      }
    }
    if (readingName) {
      currentName += ch;
    } else {
      source += ch;
    }
  }
  return [source, namedGroups];
};
12141
// `RegExp` constructor
// https://tc39.es/ecma262/#sec-regexp-constructor
// When forced, the global `RegExp` is replaced by a wrapper. The wrapper does the following:
// - normalizes pattern and flags to strings;
// - strips flags the engine cannot honor (`s` when dotAll is unsupported, `y` when
//   sticky is unsupported);
// - rewrites named groups when they are unsupported;
// - records what it emulated in the regexp's internal state (dotAll / raw /
//   sticky / groups), which the getters and the patched `exec` in this bundle read.
if (isForced('RegExp', BASE_FORCED)) {
  var RegExpWrapper = function RegExp(pattern, flags) {
    var thisIsRegExp = isPrototypeOf$1(RegExpPrototype$6, this);
    var patternIsRegExp = isRegExp$4(pattern);
    var flagsAreUndefined = flags === undefined;
    var groups = [];
    var rawPattern = pattern;
    var rawFlags, dotAll, sticky, handled, result, state;
    // Plain call (no `new`) with one of our regexps and no flags: return it unchanged.
    if (!thisIsRegExp && patternIsRegExp && flagsAreUndefined && pattern.constructor === RegExpWrapper) {
      return pattern;
    }
    // Regexp-like argument: take its source, and its flags unless flags were given.
    if (patternIsRegExp || isPrototypeOf$1(RegExpPrototype$6, pattern)) {
      pattern = pattern.source;
      if (flagsAreUndefined) flags = getRegExpFlags$4(rawPattern);
    }
    pattern = pattern === undefined ? '' : toString$k(pattern);
    flags = flags === undefined ? '' : toString$k(flags);
    // From here on rawPattern is the caller-visible source text.
    rawPattern = pattern;
    // Remove `s` from the native flags; it is emulated through `state.raw` below.
    if (UNSUPPORTED_DOT_ALL$2 && 'dotAll' in re1) {
      dotAll = !!flags && stringIndexOf$4(flags, 's') > -1;
      if (dotAll) flags = replace$8(flags, /s/g, '');
    }
    // Flags without an emulated `s` but still including any `y`.
    rawFlags = flags;
    if (MISSED_STICKY$1 && 'sticky' in re1) {
      sticky = !!flags && stringIndexOf$4(flags, 'y') > -1;
      if (sticky && UNSUPPORTED_Y$2) flags = replace$8(flags, /y/g, '');
    }
    if (UNSUPPORTED_NCG$1) {
      handled = handleNCG(pattern);
      pattern = handled[0];
      groups = handled[1];
    }
    result = inheritIfRequired$2(NativeRegExp(pattern, flags), thisIsRegExp ? this : RegExpPrototype$6, RegExpWrapper);
    if (dotAll || sticky || groups.length) {
      state = enforceInternalState$2(result);
      if (dotAll) {
        state.dotAll = true;
        // Companion regexp with `.` rewritten to `[\s\S]`; the patched exec delegates to it.
        state.raw = RegExpWrapper(handleDotAll(pattern), rawFlags);
      }
      if (sticky) state.sticky = true;
      if (groups.length) state.groups = groups;
    }
    // If the pattern text was rewritten, expose the caller's original as `source`.
    if (pattern !== rawPattern) try {
      // fails in old engines, but we have no alternatives for unsupported regex syntax
      createNonEnumerableProperty$5(result, 'source', rawPattern === '' ? '(?:)' : rawPattern);
    } catch (error) {/* empty */}
    return result;
  };
  // Mirror every own property of the native constructor onto the wrapper.
  for (var keys = getOwnPropertyNames$1(NativeRegExp), index$4 = 0; keys.length > index$4;) {
    proxyAccessor(RegExpWrapper, NativeRegExp, keys[index$4++]);
  }
  // Native and wrapped constructors share one prototype object.
  RegExpPrototype$6.constructor = RegExpWrapper;
  RegExpWrapper.prototype = RegExpPrototype$6;
  defineBuiltIn$6(global$q, 'RegExp', RegExpWrapper, {
    constructor: true
  });
}

// https://tc39.es/ecma262/#sec-get-regexp-@@species
setSpecies$1('RegExp');
12204
var es_regexp_dotAll = {};

var DESCRIPTORS$a = descriptors;
var UNSUPPORTED_DOT_ALL$1 = regexpUnsupportedDotAll;
var classof$8 = classofRaw$2;
var defineBuiltInAccessor$8 = defineBuiltInAccessor$h;
var getInternalState$6 = internalState.get;
var RegExpPrototype$5 = RegExp.prototype;
var $TypeError$7 = TypeError;

// `RegExp.prototype.dotAll` getter
// https://tc39.es/ecma262/#sec-get-regexp.prototype.dotall
// Installed only where dotAll is unsupported; reports the emulated flag stored in
// internal state. RegExp.prototype itself yields undefined, non-regexps throw.
if (DESCRIPTORS$a && UNSUPPORTED_DOT_ALL$1) {
  defineBuiltInAccessor$8(RegExpPrototype$5, 'dotAll', {
    configurable: true,
    get: function dotAll() {
      if (this === RegExpPrototype$5) return undefined;
      // InternalStateModule.getterFor is not usable: regexp literals carry no metadata,
      // so the receiver is checked by class tag instead.
      if (classof$8(this) !== 'RegExp') {
        throw $TypeError$7('Incompatible receiver, RegExp required');
      }
      return !!getInternalState$6(this).dotAll;
    }
  });
}
12231
var es_regexp_exec = {};

'use strict';
/* eslint-disable regexp/no-empty-capturing-group, regexp/no-empty-group, regexp/no-lazy-ends -- testing */
/* eslint-disable regexp/no-useless-quantifier -- testing */
var call$g = functionCall;
var uncurryThis$y = functionUncurryThis;
var toString$j = toString$x;
var regexpFlags = regexpFlags$1;
var stickyHelpers$1 = regexpStickyHelpers;
var shared$1 = sharedExports;
var create$3 = objectCreate;
var getInternalState$5 = internalState.get;
var UNSUPPORTED_DOT_ALL = regexpUnsupportedDotAll;
var UNSUPPORTED_NCG = regexpUnsupportedNcg;
// Kept in the shared store under 'native-string-replace'
// (presumably so the original String#replace survives later patching — confirm).
var nativeReplace = shared$1('native-string-replace', String.prototype.replace);
var nativeExec = RegExp.prototype.exec;
// Starts as the native exec; replaced with the patched version when PATCH is set.
var patchedExec = nativeExec;
var charAt$a = uncurryThis$y(''.charAt);
var indexOf$1 = uncurryThis$y(''.indexOf);
var replace$7 = uncurryThis$y(''.replace);
var stringSlice$a = uncurryThis$y(''.slice);

// Both a non-global regexp and an empty global match (/b*/g on 'a') must leave
// lastIndex at 0; any other value means exec updates lastIndex incorrectly.
var UPDATES_LAST_INDEX_WRONG = function () {
  var nonGlobal = /a/;
  var emptyGlobal = /b*/g;
  call$g(nativeExec, nonGlobal, 'a');
  call$g(nativeExec, emptyGlobal, 'a');
  return nonGlobal.lastIndex !== 0 || emptyGlobal.lastIndex !== 0;
}();
var UNSUPPORTED_Y$1 = stickyHelpers$1.BROKEN_CARET;

// nonparticipating capturing group, copied from es5-shim's String#split patch.
// A group that did not take part in the match must be reported as undefined.
var NPCG_INCLUDED = /()??/.exec('')[1] !== undefined;
var PATCH = UPDATES_LAST_INDEX_WRONG || NPCG_INCLUDED || UNSUPPORTED_Y$1 || UNSUPPORTED_DOT_ALL || UNSUPPORTED_NCG;
// Replacement `RegExp.prototype.exec`, defined only when a probe above found a defect.
// It works around each detected problem in turn: dotAll emulation, sticky emulation,
// lastIndex bookkeeping, undefined non-participating groups, and named groups.
if (PATCH) {
  patchedExec = function exec(string) {
    var re = this;
    var state = getInternalState$5(re);
    var str = toString$j(string);
    var raw = state.raw;
    var result, reCopy, lastIndex, match, i, object, group;
    // dotAll emulation: run the companion regexp built by the RegExp wrapper,
    // syncing lastIndex into it and back out afterwards.
    if (raw) {
      raw.lastIndex = re.lastIndex;
      result = call$g(patchedExec, raw, str);
      re.lastIndex = raw.lastIndex;
      return result;
    }
    var groups = state.groups;
    var sticky = UNSUPPORTED_Y$1 && re.sticky;
    var flags = call$g(regexpFlags, re);
    var source = re.source;
    var charsAdded = 0;
    var strCopy = str;
    // Sticky emulation: match an anchored, global copy of the pattern against the
    // input sliced from lastIndex.
    if (sticky) {
      flags = replace$7(flags, 'y', '');
      if (indexOf$1(flags, 'g') === -1) {
        flags += 'g';
      }
      strCopy = stringSlice$a(str, re.lastIndex);
      // Support anchored sticky behavior.
      // A space is prepended to both pattern and input so a `^` in the pattern cannot
      // match at the slice start, unless lastIndex is 0 or (multiline) follows a '\n'.
      if (re.lastIndex > 0 && (!re.multiline || re.multiline && charAt$a(str, re.lastIndex - 1) !== '\n')) {
        source = '(?: ' + source + ')';
        strCopy = ' ' + strCopy;
        charsAdded++;
      }
      // ^(? + rx + ) is needed, in combination with some str slicing, to
      // simulate the 'y' flag.
      reCopy = new RegExp('^(?:' + source + ')', flags);
    }
    // NOTE(review): if NPCG_INCLUDED and sticky emulation are both active, this
    // overwrites the sticky reCopy above before it is used by nativeExec below — confirm.
    if (NPCG_INCLUDED) {
      reCopy = new RegExp('^' + source + '$(?!\\s)', flags);
    }
    if (UPDATES_LAST_INDEX_WRONG) lastIndex = re.lastIndex;
    match = call$g(nativeExec, sticky ? reCopy : re, strCopy);
    if (sticky) {
      // Strip the padding and report positions relative to the original lastIndex.
      if (match) {
        match.input = stringSlice$a(match.input, charsAdded);
        match[0] = stringSlice$a(match[0], charsAdded);
        match.index = re.lastIndex;
        re.lastIndex += match[0].length;
      } else re.lastIndex = 0;
    } else if (UPDATES_LAST_INDEX_WRONG && match) {
      // Restore spec lastIndex: end of match for global regexps, unchanged otherwise.
      re.lastIndex = re.global ? match.index + match[0].length : lastIndex;
    }
    if (NPCG_INCLUDED && match && match.length > 1) {
      // Fix browsers whose `exec` methods don't consistently return `undefined`
      // for NPCG, like IE8. NOTE: This doesn't work for /(.?)?/
      call$g(nativeReplace, match[0], reCopy, function () {
        for (i = 1; i < arguments.length - 2; i++) {
          if (arguments[i] === undefined) match[i] = undefined;
        }
      });
    }
    // Named-group emulation: build `groups` from the [name, index] pairs
    // recorded by the RegExp wrapper.
    if (match && groups) {
      match.groups = object = create$3(null);
      for (i = 0; i < groups.length; i++) {
        group = groups[i];
        object[group[0]] = match[group[1]];
      }
    }
    return match;
  };
}
var regexpExec$3 = patchedExec;
var regexpExec$4 = /*@__PURE__*/getDefaultExportFromCjs(regexpExec$3);
12337
'use strict';
var $$M = _export;
var exec$5 = regexpExec$3;

// `RegExp.prototype.exec` method
// https://tc39.es/ecma262/#sec-regexp.prototype.exec
// `forced` is true exactly when the engine's exec is not already the selected implementation.
$$M({ target: 'RegExp', proto: true, forced: /./.exec !== exec$5 }, { exec: exec$5 });
12351
var es_regexp_flags = {};

var global$p = global$Z;
var DESCRIPTORS$9 = descriptors;
var defineBuiltInAccessor$7 = defineBuiltInAccessor$h;
var regExpFlags = regexpFlags$1;
var fails$n = fails$1m;

// babel-minify and Closure Compiler transpile RegExp('.', 'd') -> /./d, which is a
// SyntaxError where `d` is unsupported, so the constructor is called dynamically.
var RegExp$2 = global$p.RegExp;
var RegExpPrototype$4 = RegExp$2.prototype;
// The native `flags` getter is probed against a fake receiver whose flag accessors log
// each read. It is replaced if the result is wrong, or if the properties are read out of
// order or more than once (modern V8 bug).
var FORCED$4 = DESCRIPTORS$9 && fails$n(function () {
  var hasIndicesSupport = true;
  try {
    RegExp$2('.', 'd');
  } catch (error) {
    hasIndicesSupport = false;
  }
  var flagByProperty = {
    dotAll: 's',
    global: 'g',
    ignoreCase: 'i',
    multiline: 'm',
    sticky: 'y'
  };
  if (hasIndicesSupport) flagByProperty.hasIndices = 'd';
  var receiver = {};
  var readLog = '';
  var logReads = function logReads(property, letter) {
    // eslint-disable-next-line es/no-object-defineproperty -- safe
    Object.defineProperty(receiver, property, {
      get: function get() {
        readLog += letter;
        return true;
      }
    });
  };
  for (var property in flagByProperty) logReads(property, flagByProperty[property]);
  var expected = hasIndicesSupport ? 'dgimsy' : 'gimsy';
  // eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
  var actual = Object.getOwnPropertyDescriptor(RegExpPrototype$4, 'flags').get.call(receiver);
  return actual !== expected || readLog !== expected;
});

// `RegExp.prototype.flags` getter
// https://tc39.es/ecma262/#sec-get-regexp.prototype.flags
if (FORCED$4) {
  defineBuiltInAccessor$7(RegExpPrototype$4, 'flags', {
    configurable: true,
    get: regExpFlags
  });
}
12404
var es_regexp_sticky = {};

var DESCRIPTORS$8 = descriptors;
var MISSED_STICKY = regexpStickyHelpers.MISSED_STICKY;
var classof$7 = classofRaw$2;
var defineBuiltInAccessor$6 = defineBuiltInAccessor$h;
var getInternalState$4 = internalState.get;
var RegExpPrototype$3 = RegExp.prototype;
var $TypeError$6 = TypeError;

// `RegExp.prototype.sticky` getter
// https://tc39.es/ecma262/#sec-get-regexp.prototype.sticky
// Installed only where `sticky` is not reported natively; reads the flag stored in
// internal state. RegExp.prototype itself yields undefined, non-regexps throw.
if (DESCRIPTORS$8 && MISSED_STICKY) {
  defineBuiltInAccessor$6(RegExpPrototype$3, 'sticky', {
    configurable: true,
    get: function sticky() {
      if (this === RegExpPrototype$3) return;
      // InternalStateModule.getterFor is not usable: regexp literals carry no metadata,
      // so the receiver is checked by class tag instead.
      if (classof$7(this) !== 'RegExp') {
        throw $TypeError$6('Incompatible receiver, RegExp required');
      }
      return !!getInternalState$4(this).sticky;
    }
  });
}
12431
var es_regexp_test = {};

'use strict';
// TODO: Remove from `core-js@4` since it's moved to entry points

var $$L = _export;
var call$f = functionCall;
var isCallable$6 = isCallable$z;
var anObject$b = anObject$D;
var toString$i = toString$x;
// The native `test` must route through a user-overridden `exec`.
var DELEGATES_TO_EXEC = function () {
  var execWasCalled = false;
  var probe = /[ac]/;
  probe.exec = function () {
    execWasCalled = true;
    return /./.exec.apply(this, arguments);
  };
  return probe.test('abc') === true && execWasCalled;
}();
var nativeTest = /./.test;

// `RegExp.prototype.test` method
// https://tc39.es/ecma262/#sec-regexp.prototype.test
// Uses the receiver's own callable `exec` when present (a non-null result must be an
// object), otherwise falls back to the native `test`.
$$L({ target: 'RegExp', proto: true, forced: !DELEGATES_TO_EXEC }, {
  test: function test(S) {
    var regexp = anObject$b(this);
    var input = toString$i(S);
    var userExec = regexp.exec;
    if (isCallable$6(userExec)) {
      var match = call$f(userExec, regexp, input);
      if (match === null) return false;
      anObject$b(match);
      return true;
    }
    return call$f(nativeTest, regexp, input);
  }
});
12471
var es_regexp_toString = {};

'use strict';
var PROPER_FUNCTION_NAME$1 = functionName.PROPER;
var defineBuiltIn$5 = defineBuiltIn$m;
var anObject$a = anObject$D;
var $toString$2 = toString$x;
var fails$m = fails$1m;
var getRegExpFlags$3 = regexpGetFlags;
var TO_STRING = 'toString';
var RegExpPrototype$2 = RegExp.prototype;
var nativeToString = RegExpPrototype$2[TO_STRING];
// The native method must be generic: it should work on any object with
// `source` and `flags` properties.
var NOT_GENERIC = fails$m(function () {
  var rendered = nativeToString.call({ source: 'a', flags: 'b' });
  return rendered != '/a/b';
});
// FF44- RegExp#toString has a wrong name
var INCORRECT_NAME = PROPER_FUNCTION_NAME$1 && nativeToString.name != TO_STRING;

// `RegExp.prototype.toString` method
// https://tc39.es/ecma262/#sec-regexp.prototype.tostring
// Renders '/' + source + '/' + flags from the receiver's properties.
if (NOT_GENERIC || INCORRECT_NAME) {
  defineBuiltIn$5(RegExp.prototype, TO_STRING, function toString() {
    var R = anObject$a(this);
    var sourceText = $toString$2(R.source);
    var flagText = $toString$2(getRegExpFlags$3(R));
    return '/' + sourceText + '/' + flagText;
  }, { unsafe: true });
}
12505
var es_set = {};

var es_set_constructor = {};

'use strict';
var collection$2 = collection$4;
var collectionStrong = collectionStrong$2;

// `Set` constructor
// https://tc39.es/ecma262/#sec-set-objects
// The collection helper passes `init`; the returned constructor forwards only its
// first argument (presumably the optional iterable of initial values) or undefined.
collection$2('Set', function (init) {
  return function Set() {
    var initial = arguments.length ? arguments[0] : undefined;
    return init(this, initial);
  };
}, collectionStrong);
12521
var es_string_atAlternative = {};

'use strict';
var $$K = _export;
var uncurryThis$x = functionUncurryThis;
var requireObjectCoercible$c = requireObjectCoercible$j;
var toIntegerOrInfinity$6 = toIntegerOrInfinity$l;
var toString$h = toString$x;
var fails$l = fails$1m;
var charAt$9 = uncurryThis$x(''.charAt);
// The native method works on UTF-16 units: at(-2) of the surrogate pair '𠮷'
// must be the lone high surrogate "\uD842".
var FORCED$3 = fails$l(function () {
  // eslint-disable-next-line es/no-array-string-prototype-at -- safe
  return '𠮷'.at(-2) !== "\uD842";
});

// `String.prototype.at` method
// https://github.com/tc39/proposal-relative-indexing-method
// Negative indexes count back from the end; out-of-range positions yield undefined.
$$K({ target: 'String', proto: true, forced: FORCED$3 }, {
  at: function at(index) {
    var str = toString$h(requireObjectCoercible$c(this));
    var size = str.length;
    var offset = toIntegerOrInfinity$6(index);
    var position = offset < 0 ? size + offset : offset;
    if (position < 0 || position >= size) return undefined;
    return charAt$9(str, position);
  }
});
12552
var es_string_codePointAt = {};

var uncurryThis$w = functionUncurryThis;
var toIntegerOrInfinity$5 = toIntegerOrInfinity$l;
var toString$g = toString$x;
var requireObjectCoercible$b = requireObjectCoercible$j;
var charAt$8 = uncurryThis$w(''.charAt);
var charCodeAt$2 = uncurryThis$w(''.charCodeAt);
var stringSlice$9 = uncurryThis$w(''.slice);
// Builds a surrogate-pair-aware accessor `($this, pos)`.
// - CONVERT_TO_STRING false: returns the numeric code point, or undefined when out of range.
// - CONVERT_TO_STRING true: returns the code point as a 1- or 2-unit string, or '' when out of range.
// A lone or unpaired surrogate is treated as a single unit.
var createMethod = function createMethod(CONVERT_TO_STRING) {
  return function ($this, pos) {
    var str = toString$g(requireObjectCoercible$b($this));
    var position = toIntegerOrInfinity$5(pos);
    var size = str.length;
    if (position < 0 || position >= size) return CONVERT_TO_STRING ? '' : undefined;
    var lead = charCodeAt$2(str, position);
    var trail;
    var isPair = lead >= 0xD800 && lead <= 0xDBFF && position + 1 !== size
      && (trail = charCodeAt$2(str, position + 1)) >= 0xDC00 && trail <= 0xDFFF;
    if (!isPair) return CONVERT_TO_STRING ? charAt$8(str, position) : lead;
    if (CONVERT_TO_STRING) return stringSlice$9(str, position, position + 2);
    return (lead - 0xD800 << 10) + (trail - 0xDC00) + 0x10000;
  };
};
var stringMultibyte = {
  // `String.prototype.codePointAt` method
  // https://tc39.es/ecma262/#sec-string.prototype.codepointat
  codeAt: createMethod(false),
  // Code-point-aware character access (used by the string iterator and
  // AdvanceStringIndex); after https://github.com/mathiasbynens/String.prototype.at
  charAt: createMethod(true)
};
var stringMultibyte$1 = /*@__PURE__*/getDefaultExportFromCjs(stringMultibyte);
12582
'use strict';
var $$J = _export;
var codeAt$1 = stringMultibyte.codeAt;

// `String.prototype.codePointAt` method
// https://tc39.es/ecma262/#sec-string.prototype.codepointat
// Thin adapter: passes the receiver and position to the shared multibyte helper.
$$J({ target: 'String', proto: true }, {
  codePointAt: function codePointAt(pos) {
    var receiver = this;
    return codeAt$1(receiver, pos);
  }
});
12597
var es_string_endsWith = {};

var isRegExp$3 = isRegexp;
var $TypeError$5 = TypeError;
// Guard for String methods that reject regexp search values: returns `it`
// unchanged, or throws a TypeError when IsRegExp(it) holds.
var notARegexp = function notARegexp(it) {
  if (!isRegExp$3(it)) return it;
  throw $TypeError$5("The method doesn't accept regular expressions");
};
var notARegexp$1 = /*@__PURE__*/getDefaultExportFromCjs(notARegexp);
12609
var wellKnownSymbol$8 = wellKnownSymbol$z;
var MATCH = wellKnownSymbol$8('match');
// Checks whether '/./'[METHOD_NAME] applies IsRegExp correctly. It must throw for
// a real regexp, yet accept one whose @@match is set to false. In that case the
// method's result is returned; every other outcome yields false.
var correctIsRegexpLogic = function correctIsRegexpLogic(METHOD_NAME) {
  var probe = /./;
  try {
    '/./'[METHOD_NAME](probe);
    return false;
  } catch (firstError) {
    try {
      probe[MATCH] = false;
      return '/./'[METHOD_NAME](probe);
    } catch (secondError) {
      return false;
    }
  }
};
var correctIsRegexpLogic$1 = /*@__PURE__*/getDefaultExportFromCjs(correctIsRegexpLogic);
12625
'use strict';
var $$I = _export;
var uncurryThis$v = functionUncurryThisClause;
var getOwnPropertyDescriptor$2 = objectGetOwnPropertyDescriptor.f;
var toLength$7 = toLength$d;
var toString$f = toString$x;
var notARegExp$2 = notARegexp;
var requireObjectCoercible$a = requireObjectCoercible$j;
var correctIsRegExpLogic$2 = correctIsRegexpLogic;
var IS_PURE$7 = isPure;

// eslint-disable-next-line es/no-string-prototype-endswith -- safe
var nativeEndsWith = uncurryThis$v(''.endsWith);
var slice$4 = uncurryThis$v(''.slice);
var min$8 = Math.min;
var CORRECT_IS_REGEXP_LOGIC$1 = correctIsRegExpLogic$2('endsWith');
// https://github.com/zloirock/core-js/pull/702
// A non-writable existing `endsWith` (left by the MDN polyfill) must not be forced over.
var MDN_POLYFILL_BUG$1 = !IS_PURE$7 && !CORRECT_IS_REGEXP_LOGIC$1 && !!function () {
  var existing = getOwnPropertyDescriptor$2(String.prototype, 'endsWith');
  return existing && !existing.writable;
}();

// `String.prototype.endsWith` method
// https://tc39.es/ecma262/#sec-string.prototype.endswith
$$I({ target: 'String', proto: true, forced: !MDN_POLYFILL_BUG$1 && !CORRECT_IS_REGEXP_LOGIC$1 }, {
  endsWith: function endsWith(searchString /* , endPosition = @length */) {
    var str = toString$f(requireObjectCoercible$a(this));
    notARegExp$2(searchString);
    var endPosition = arguments.length > 1 ? arguments[1] : undefined;
    var size = str.length;
    var end = endPosition === undefined ? size : min$8(toLength$7(endPosition), size);
    var search = toString$f(searchString);
    if (nativeEndsWith) return nativeEndsWith(str, search, end);
    return slice$4(str, end - search.length, end) === search;
  }
});
12665
var es_string_fromCodePoint = {};

var $$H = _export;
var uncurryThis$u = functionUncurryThis;
var toAbsoluteIndex$1 = toAbsoluteIndex$a;
var $RangeError$3 = RangeError;
var fromCharCode$3 = String.fromCharCode;
// eslint-disable-next-line es/no-string-fromcodepoint -- required for testing
var $fromCodePoint = String.fromCodePoint;
var join$5 = uncurryThis$u([].join);

// length should be 1, old FF problem
var INCORRECT_LENGTH = !!$fromCodePoint && $fromCodePoint.length != 1;

// `String.fromCodePoint` method
// https://tc39.es/ecma262/#sec-string.fromcodepoint
// Each argument becomes one or two UTF-16 units. A value that toAbsoluteIndex$1
// does not map to itself throws a RangeError (presumably: non-integral or outside
// [0, 0x10FFFF]).
$$H({ target: 'String', stat: true, arity: 1, forced: INCORRECT_LENGTH }, {
  // eslint-disable-next-line no-unused-vars -- required for `.length`
  fromCodePoint: function fromCodePoint(x) {
    var units = [];
    var count = arguments.length;
    for (var k = 0; k < count; k++) {
      var codePoint = +arguments[k];
      if (toAbsoluteIndex$1(codePoint, 0x10FFFF) !== codePoint) throw $RangeError$3(codePoint + ' is not a valid code point');
      if (codePoint < 0x10000) {
        units[k] = fromCharCode$3(codePoint);
      } else {
        // Astral code point: split into a high and a low surrogate.
        var offset = codePoint - 0x10000;
        units[k] = fromCharCode$3((offset >> 10) + 0xD800, offset % 0x400 + 0xDC00);
      }
    }
    return join$5(units, '');
  }
});
12702
var es_string_includes = {};

'use strict';
var $$G = _export;
var uncurryThis$t = functionUncurryThis;
var notARegExp$1 = notARegexp;
var requireObjectCoercible$9 = requireObjectCoercible$j;
var toString$e = toString$x;
var correctIsRegExpLogic$1 = correctIsRegexpLogic;
var stringIndexOf$3 = uncurryThis$t(''.indexOf);

// `String.prototype.includes` method
// https://tc39.es/ecma262/#sec-string.prototype.includes
// Regexp search values are rejected; otherwise this is an indexOf presence test.
$$G({ target: 'String', proto: true, forced: !correctIsRegExpLogic$1('includes') }, {
  includes: function includes(searchString /* , position = 0 */) {
    var str = toString$e(requireObjectCoercible$9(this));
    var search = toString$e(notARegExp$1(searchString));
    var position = arguments.length > 1 ? arguments[1] : undefined;
    return stringIndexOf$3(str, search, position) !== -1;
  }
});
12725
var es_string_iterator = {};

'use strict';
var charAt$7 = stringMultibyte.charAt;
var toString$d = toString$x;
var InternalStateModule$6 = internalState;
var defineIterator = iteratorDefine;
var createIterResultObject$1 = createIterResultObject$4;
var STRING_ITERATOR = 'String Iterator';
var setInternalState$6 = InternalStateModule$6.set;
var getInternalState$3 = InternalStateModule$6.getterFor(STRING_ITERATOR);

// `String.prototype[@@iterator]` method
// https://tc39.es/ecma262/#sec-string.prototype-@@iterator
// The iterator stores the stringified input and a UTF-16 cursor in internal state.
defineIterator(String, 'String', function (iterated) {
  setInternalState$6(this, {
    type: STRING_ITERATOR,
    string: toString$d(iterated),
    index: 0
  });
// `%StringIteratorPrototype%.next` method
// https://tc39.es/ecma262/#sec-%stringiteratorprototype%.next
}, function next() {
  var state = getInternalState$3(this);
  var source = state.string;
  var position = state.index;
  if (position >= source.length) return createIterResultObject$1(undefined, true);
  // charAt$7 returns a whole surrogate pair when present, so the cursor advances by 1 or 2.
  var codePoint = charAt$7(source, position);
  state.index = position + codePoint.length;
  return createIterResultObject$1(codePoint, false);
});
12758
var es_string_match = {};

'use strict';
// TODO: Remove from `core-js@4` since it's moved to entry points

var uncurryThis$s = functionUncurryThisClause;
var defineBuiltIn$4 = defineBuiltIn$m;
var regexpExec$2 = regexpExec$3;
var fails$k = fails$1m;
var wellKnownSymbol$7 = wellKnownSymbol$z;
var createNonEnumerableProperty$4 = createNonEnumerableProperty$f;
var SPECIES = wellKnownSymbol$7('species');
var RegExpPrototype$1 = RegExp.prototype;
// Shared fixer for a String method `KEY` (e.g. 'split') and the RegExp method keyed by
// the well-known symbol of the same name (SYMBOL).
// Both are replaced when any of these holds:
// - String#KEY does not delegate to an object's SYMBOL method;
// - RegExp#SYMBOL does not call `.exec`;
// - FORCED is set.
// `exec(SYMBOL, nativeStringMethod, maybeCallNative)` must return a pair:
// [0] is installed as String.prototype[KEY], [1] as RegExp.prototype[SYMBOL].
// If SHAM is set, RegExp.prototype[SYMBOL] gets a non-enumerable `sham: true` marker.
var fixRegexpWellKnownSymbolLogic = function fixRegexpWellKnownSymbolLogic(KEY, exec, FORCED, SHAM) {
  var SYMBOL = wellKnownSymbol$7(KEY);
  var DELEGATES_TO_SYMBOL = !fails$k(function () {
    // String methods call symbol-named RegExp methods
    var O = {};
    O[SYMBOL] = function () {
      return 7;
    };
    return ''[KEY](O) != 7;
  });
  var DELEGATES_TO_EXEC = DELEGATES_TO_SYMBOL && !fails$k(function () {
    // Symbol-named RegExp methods call .exec
    var execCalled = false;
    var re = /a/;
    if (KEY === 'split') {
      // We can't use real regex here since it causes deoptimization
      // and serious performance degradation in V8
      // https://github.com/zloirock/core-js/issues/306
      re = {};
      // RegExp[@@split] doesn't call the regex's exec method, but first creates
      // a new one. We need to return the patched regex when creating the new one.
      re.constructor = {};
      re.constructor[SPECIES] = function () {
        return re;
      };
      re.flags = '';
      re[SYMBOL] = /./[SYMBOL];
    }
    re.exec = function () {
      execCalled = true;
      return null;
    };
    re[SYMBOL]('');
    return !execCalled;
  });
  if (!DELEGATES_TO_SYMBOL || !DELEGATES_TO_EXEC || FORCED) {
    var uncurriedNativeRegExpMethod = uncurryThis$s(/./[SYMBOL]);
    // maybeCallNative: if `regexp.exec` is the built-in (or this bundle's) exec,
    // compute the result natively and return { done: true, value }; otherwise
    // return { done: false } so the caller uses its own spec-following logic.
    var methods = exec(SYMBOL, ''[KEY], function (nativeMethod, regexp, str, arg2, forceStringMethod) {
      var uncurriedNativeMethod = uncurryThis$s(nativeMethod);
      var $exec = regexp.exec;
      if ($exec === regexpExec$2 || $exec === RegExpPrototype$1.exec) {
        if (DELEGATES_TO_SYMBOL && !forceStringMethod) {
          // The native String method already delegates to @@method (this
          // polyfilled function), leading to infinite recursion.
          // We avoid it by directly calling the native @@method method.
          return {
            done: true,
            value: uncurriedNativeRegExpMethod(regexp, str, arg2)
          };
        }
        return {
          done: true,
          value: uncurriedNativeMethod(str, regexp, arg2)
        };
      }
      return {
        done: false
      };
    });
    defineBuiltIn$4(String.prototype, KEY, methods[0]);
    defineBuiltIn$4(RegExpPrototype$1, SYMBOL, methods[1]);
  }
  if (SHAM) createNonEnumerableProperty$4(RegExpPrototype$1[SYMBOL], 'sham', true);
};
var fixRegexpWellKnownSymbolLogic$1 = /*@__PURE__*/getDefaultExportFromCjs(fixRegexpWellKnownSymbolLogic);
12837
'use strict';
var charAt$6 = stringMultibyte.charAt;

/**
 * `AdvanceStringIndex` abstract operation.
 * https://tc39.es/ecma262/#sec-advancestringindex
 *
 * @param {string} S - The string being scanned.
 * @param {number} index - Current position.
 * @param {boolean} unicode - Whether to step over the whole character that
 *   `charAt$6` reports at `index` rather than a single code unit.
 * @returns {number} The next position to resume scanning from.
 */
var advanceStringIndex$4 = function advanceStringIndex(S, index, unicode) {
  if (!unicode) return index + 1;
  var step = charAt$6(S, index).length;
  return index + step;
};
var advanceStringIndex$5 = /*@__PURE__*/getDefaultExportFromCjs(advanceStringIndex$4);
12847
var call$e = functionCall;
var anObject$9 = anObject$D;
var isCallable$5 = isCallable$z;
var classof$6 = classofRaw$2;
var regexpExec$1 = regexpExec$3;
var $TypeError$4 = TypeError;

/**
 * `RegExpExec` abstract operation.
 * https://tc39.es/ecma262/#sec-regexpexec
 *
 * Runs `R.exec(S)` when `R.exec` is callable, and requires its result to be
 * `null` or an object. Otherwise it falls back to `regexpExec$1`, but only
 * for receivers whose raw class is 'RegExp'.
 *
 * @throws {TypeError} When `R` has no callable `exec` and is not a RegExp, or
 *   when a custom `exec` returns a non-object other than `null`.
 */
var regexpExecAbstract = function regexpExecAbstract(R, S) {
  var userExec = R.exec;
  if (!isCallable$5(userExec)) {
    if (classof$6(R) !== 'RegExp') throw $TypeError$4('RegExp#exec called on incompatible receiver');
    return call$e(regexpExec$1, R, S);
  }
  var matchResult = call$e(userExec, R, S);
  if (matchResult === null) return null;
  // Validates that a user-supplied exec produced an object.
  anObject$9(matchResult);
  return matchResult;
};
var regexpExecAbstract$1 = /*@__PURE__*/getDefaultExportFromCjs(regexpExecAbstract);
12868
'use strict';
var call$d = functionCall;
var fixRegExpWellKnownSymbolLogic$3 = fixRegexpWellKnownSymbolLogic;
var anObject$8 = anObject$D;
var isNullOrUndefined$7 = isNullOrUndefined$e;
var toLength$6 = toLength$d;
var toString$c = toString$x;
var requireObjectCoercible$8 = requireObjectCoercible$j;
var getMethod$5 = getMethod$9;
var advanceStringIndex$3 = advanceStringIndex$4;
var regExpExec$3 = regexpExecAbstract;

// @@match logic: the factory returns
// [String.prototype.match, RegExp.prototype[@@match]].
fixRegExpWellKnownSymbolLogic$3('match', function (MATCH, nativeMatch, maybeCallNative) {
  return [
    // `String.prototype.match` method
    // https://tc39.es/ecma262/#sec-string.prototype.match
    function match(regexp) {
      var O = requireObjectCoercible$8(this);
      var matcher = isNullOrUndefined$7(regexp) ? undefined : getMethod$5(regexp, MATCH);
      if (matcher) return call$d(matcher, regexp, O);
      // No custom @@match: build a RegExp from the argument and delegate to it.
      return new RegExp(regexp)[MATCH](toString$c(O));
    },
    // `RegExp.prototype[@@match]` method
    // https://tc39.es/ecma262/#sec-regexp.prototype-@@match
    function (string) {
      var rx = anObject$8(this);
      var S = toString$c(string);
      var nativeOutcome = maybeCallNative(nativeMatch, rx, S);
      if (nativeOutcome.done) return nativeOutcome.value;
      // A non-global regexp yields a single exec result.
      if (!rx.global) return regExpExec$3(rx, S);
      var fullUnicode = rx.unicode;
      rx.lastIndex = 0;
      var matches = [];
      for (var step = regExpExec$3(rx, S); step !== null; step = regExpExec$3(rx, S)) {
        var matchStr = toString$c(step[0]);
        matches[matches.length] = matchStr;
        // An empty match does not move lastIndex; step past it manually.
        if (matchStr === '') rx.lastIndex = advanceStringIndex$3(S, toLength$6(rx.lastIndex), fullUnicode);
      }
      return matches.length ? matches : null;
    }
  ];
});
12913
var es_string_matchAll = {};

'use strict';
/* eslint-disable es/no-string-prototype-matchall -- safe */
// Local aliases of shared helpers defined elsewhere in the bundle.
var $$F = _export;
var call$c = functionCall;
var uncurryThis$r = functionUncurryThisClause;
var createIteratorConstructor$1 = iteratorCreateConstructor;
var createIterResultObject = createIterResultObject$4;
var requireObjectCoercible$7 = requireObjectCoercible$j;
var toLength$5 = toLength$d;
var toString$b = toString$x;
var anObject$7 = anObject$D;
var isNullOrUndefined$6 = isNullOrUndefined$e;
var classof$5 = classofRaw$2;
var isRegExp$2 = isRegexp;
var getRegExpFlags$2 = regexpGetFlags;
var getMethod$4 = getMethod$9;
var defineBuiltIn$3 = defineBuiltIn$m;
var fails$j = fails$1m;
var wellKnownSymbol$6 = wellKnownSymbol$z;
var speciesConstructor$2 = speciesConstructor$6;
var advanceStringIndex$2 = advanceStringIndex$4;
var regExpExec$2 = regexpExecAbstract;
var InternalStateModule$5 = internalState;
var IS_PURE$6 = isPure;
var MATCH_ALL = wellKnownSymbol$6('matchAll');
var REGEXP_STRING = 'RegExp String';
// 'RegExp String Iterator': the type tag stored in (and checked against) the
// iterator's internal state.
var REGEXP_STRING_ITERATOR = REGEXP_STRING + ' Iterator';
var setInternalState$5 = InternalStateModule$5.set;
var getInternalState$2 = InternalStateModule$5.getterFor(REGEXP_STRING_ITERATOR);
var RegExpPrototype = RegExp.prototype;
var $TypeError$3 = TypeError;
var stringIndexOf$2 = uncurryThis$r(''.indexOf);
// Native `String.prototype.matchAll` as `nativeMatchAll(str, regexp)`, or a
// falsy value when the engine lacks it (presumably what uncurryThis$r returns
// for `undefined` — TODO confirm).
var nativeMatchAll = uncurryThis$r(''.matchAll);
// True when native matchAll exists and accepts a NON-global regexp without
// throwing. The polyfill below rejects that case with a TypeError, so this flag
// forces replacing the native method.
var WORKS_WITH_NON_GLOBAL_REGEX = !!nativeMatchAll && !fails$j(function () {
  nativeMatchAll('a', /./);
});
// Iterator returned by `$matchAll`. Each `next()` runs the stored regexp once
// against the stored string. Global regexps keep producing matches until
// exec returns null; a non-global regexp yields at most one match.
var $RegExpStringIterator = createIteratorConstructor$1(function RegExpStringIterator(regexp, string, $global, fullUnicode) {
  setInternalState$5(this, {
    type: REGEXP_STRING_ITERATOR,
    regexp: regexp,
    string: string,
    global: $global,
    unicode: fullUnicode,
    done: false
  });
}, REGEXP_STRING, function next() {
  var state = getInternalState$2(this);
  if (state.done) return createIterResultObject(undefined, true);
  var R = state.regexp;
  var S = state.string;
  var match = regExpExec$2(R, S);
  if (match === null) {
    state.done = true;
    return createIterResultObject(undefined, true);
  }
  if (!state.global) {
    // Non-global: this single match is the only one ever produced.
    state.done = true;
  } else if (toString$b(match[0]) === '') {
    // Empty global match: move lastIndex forward so the next exec progresses.
    R.lastIndex = advanceStringIndex$2(S, toLength$5(R.lastIndex), state.unicode);
  }
  return createIterResultObject(match, false);
});
/**
 * `RegExp.prototype[@@matchAll]` implementation.
 *
 * Builds a fresh matcher through the species constructor, using the same
 * flags and starting `lastIndex` as `this`. It then wraps the matcher in a
 * `$RegExpStringIterator`, so `this` is never mutated.
 *
 * @param {*} string - Coerced to a string and scanned by the iterator.
 * @returns {object} A `$RegExpStringIterator` instance.
 */
var $matchAll = function $matchAll(string) {
  var R = anObject$7(this);
  var S = toString$b(string);
  var C = speciesConstructor$2(R, RegExp);
  var flags = toString$b(getRegExpFlags$2(R));
  // The plain RegExp constructor gets the source string; a custom species gets R itself.
  var pattern = C === RegExp ? R.source : R;
  var matcher = new C(pattern, flags);
  var isGlobal = stringIndexOf$2(flags, 'g') !== -1;
  var isFullUnicode = stringIndexOf$2(flags, 'u') !== -1;
  matcher.lastIndex = toLength$5(R.lastIndex);
  return new $RegExpStringIterator(matcher, S, isGlobal, isFullUnicode);
};
12990
// `String.prototype.matchAll` method
// https://tc39.es/ecma262/#sec-string.prototype.matchall
$$F({ target: 'String', proto: true, forced: WORKS_WITH_NON_GLOBAL_REGEX }, {
  matchAll: function matchAll(regexp) {
    var O = requireObjectCoercible$7(this);
    var hasArgument = !isNullOrUndefined$6(regexp);
    if (hasArgument && isRegExp$2(regexp)) {
      var flags = toString$b(requireObjectCoercible$7(getRegExpFlags$2(regexp)));
      // RegExp arguments must carry the `g` flag.
      if (stringIndexOf$2(flags, 'g') === -1) throw $TypeError$3('`.matchAll` does not allow non-global regexes');
    }
    // The native method is only reached after the argument has been validated above.
    if (WORKS_WITH_NON_GLOBAL_REGEX) return nativeMatchAll(O, regexp);
    if (hasArgument) {
      var matcher = getMethod$4(regexp, MATCH_ALL);
      if (matcher === undefined && IS_PURE$6 && classof$5(regexp) === 'RegExp') matcher = $matchAll;
      if (matcher) return call$c(matcher, regexp, O);
    }
    // No usable @@matchAll: treat the argument as a pattern for a new global RegExp.
    var S = toString$b(O);
    var rx = new RegExp(regexp, 'g');
    return IS_PURE$6 ? call$c($matchAll, rx, S) : rx[MATCH_ALL](S);
  }
});
// Outside pure mode, install `RegExp.prototype[@@matchAll]` if the engine lacks it.
if (!IS_PURE$6 && !(MATCH_ALL in RegExpPrototype)) defineBuiltIn$3(RegExpPrototype, MATCH_ALL, $matchAll);
13017
var es_string_padEnd = {};

// True when the user agent looks like Safari 10 (desktop or Mobile); the
// padStart/padEnd exports use this flag to force their replacement.
// https://github.com/zloirock/core-js/issues/280
var userAgent = engineUserAgent;
var stringPadWebkitBug = (function () {
  var safari10Pattern = /Version\/10(?:\.\d+){1,2}(?: [\w./]+)?(?: Mobile\/\w+)? Safari\//;
  return safari10Pattern.test(userAgent);
})();
var stringPadWebkitBug$1 = /*@__PURE__*/getDefaultExportFromCjs(stringPadWebkitBug);
13024
'use strict';
var $$E = _export;
var $padEnd = stringPad.end;
var WEBKIT_BUG$1 = stringPadWebkitBug;

// `String.prototype.padEnd` method
// https://tc39.es/ecma262/#sec-string.prototype.padend
// Forced on engines flagged by `stringPadWebkitBug`.
$$E({ target: 'String', proto: true, forced: WEBKIT_BUG$1 }, {
  padEnd: function padEnd(maxLength /* , fillString = ' ' */) {
    // The optional fill string is read from `arguments` so `padEnd.length` stays 1.
    var fillString = arguments.length > 1 ? arguments[1] : undefined;
    return $padEnd(this, maxLength, fillString);
  }
});
13041
var es_string_padStart = {};

'use strict';
var $$D = _export;
var $padStart = stringPad.start;
var WEBKIT_BUG = stringPadWebkitBug;

// `String.prototype.padStart` method
// https://tc39.es/ecma262/#sec-string.prototype.padstart
// Forced on engines flagged by `stringPadWebkitBug`.
$$D({ target: 'String', proto: true, forced: WEBKIT_BUG }, {
  padStart: function padStart(maxLength /* , fillString = ' ' */) {
    // The optional fill string is read from `arguments` so `padStart.length` stays 1.
    var fillString = arguments.length > 1 ? arguments[1] : undefined;
    return $padStart(this, maxLength, fillString);
  }
});
13060
var es_string_raw = {};

var $$C = _export;
var uncurryThis$q = functionUncurryThis;
var toIndexedObject$1 = toIndexedObject$j;
var toObject$2 = toObject$t;
var toString$a = toString$x;
var lengthOfArrayLike$4 = lengthOfArrayLike$t;
var push$6 = uncurryThis$q([].push);
var join$4 = uncurryThis$q([].join);

// `String.raw` method
// https://tc39.es/ecma262/#sec-string.raw
$$C({ target: 'String', stat: true }, {
  /**
   * Joins the `template.raw` segments, inserting the remaining arguments
   * (substitutions) between consecutive segments. Missing substitutions are
   * skipped and extra ones are ignored.
   */
  raw: function raw(template) {
    var rawStrings = toIndexedObject$1(toObject$2(template).raw);
    var segmentCount = lengthOfArrayLike$4(rawStrings);
    if (!segmentCount) return '';
    var argCount = arguments.length;
    var pieces = [];
    for (var next = 0; ;) {
      push$6(pieces, toString$a(rawStrings[next]));
      next += 1;
      if (next === segmentCount) return join$4(pieces, '');
      // arguments[next] is the substitution that follows segment `next - 1`.
      if (next < argCount) push$6(pieces, toString$a(arguments[next]));
    }
  }
});
13092
var es_string_repeat = {};

var $$B = _export;
var repeat$1 = stringRepeat;

// `String.prototype.repeat` method
// https://tc39.es/ecma262/#sec-string.prototype.repeat
// Registers the shared `stringRepeat` implementation; no `forced` option is
// passed here.
$$B({ target: 'String', proto: true }, { repeat: repeat$1 });
13106
var es_string_replace = {};

var uncurryThis$p = functionUncurryThis;
var toObject$1 = toObject$t;
var floor$6 = Math.floor;
var charAt$5 = uncurryThis$p(''.charAt);
var replace$6 = uncurryThis$p(''.replace);
var stringSlice$8 = uncurryThis$p(''.slice);
// eslint-disable-next-line redos/no-vulnerable -- safe
var SUBSTITUTION_SYMBOLS = /\$([$&'`]|\d{1,2}|<[^>]*>)/g;
var SUBSTITUTION_SYMBOLS_NO_NAMED = /\$([$&'`]|\d{1,2})/g;

/**
 * `GetSubstitution` abstract operation.
 * https://tc39.es/ecma262/#sec-getsubstitution
 *
 * Expands the `$`-patterns in `replacement`:
 * - `$$` becomes a literal `$`.
 * - `$&` is the matched text.
 * - `` $` `` and `$'` are the text before and after the match.
 * - `$n` / `$nn` are numbered captures.
 * - `$<name>` is a named capture, recognised only when `namedCaptures` is given.
 * Unmatched captures expand to ''.
 *
 * @param {string} matched - The matched substring.
 * @param {string} str - The whole input string.
 * @param {number} position - Index of the match within `str`.
 * @param {Array} captures - Capture values, `undefined` for unmatched groups.
 * @param {object|undefined} namedCaptures - The match's `groups` object, if any.
 * @param {string} replacement - The replacement template.
 * @returns {string} The expanded replacement.
 */
var getSubstitution$2 = function getSubstitution(matched, str, position, captures, namedCaptures, replacement) {
  var tailPos = position + matched.length;
  var m = captures.length;
  var hasNamed = namedCaptures !== undefined;
  if (hasNamed) namedCaptures = toObject$1(namedCaptures);
  var symbols = hasNamed ? SUBSTITUTION_SYMBOLS : SUBSTITUTION_SYMBOLS_NO_NAMED;
  return replace$6(replacement, symbols, function (token, body) {
    var lead = charAt$5(body, 0);
    if (lead === '$') return '$';
    if (lead === '&') return matched;
    if (lead === '`') return stringSlice$8(str, 0, position);
    if (lead === "'") return stringSlice$8(str, tailPos);
    var capture;
    if (lead === '<') {
      capture = namedCaptures[stringSlice$8(body, 1, -1)];
    } else {
      // One or two decimal digits.
      var n = +body;
      if (n === 0) return token;
      if (n > m) {
        // Out-of-range `$nn`: reinterpret it as `$n` followed by a literal digit.
        var f = floor$6(n / 10);
        if (f === 0 || f > m) return token;
        return captures[f - 1] === undefined ? charAt$5(body, 1) : captures[f - 1] + charAt$5(body, 1);
      }
      capture = captures[n - 1];
    }
    return capture === undefined ? '' : capture;
  });
};
var getSubstitution$3 = /*@__PURE__*/getDefaultExportFromCjs(getSubstitution$2);
13159
'use strict';
var apply$4 = functionApply$1;
var call$b = functionCall;
var uncurryThis$o = functionUncurryThis;
var fixRegExpWellKnownSymbolLogic$2 = fixRegexpWellKnownSymbolLogic;
var fails$i = fails$1m;
var anObject$6 = anObject$D;
var isCallable$4 = isCallable$z;
var isNullOrUndefined$5 = isNullOrUndefined$e;
var toIntegerOrInfinity$4 = toIntegerOrInfinity$l;
var toLength$4 = toLength$d;
var toString$9 = toString$x;
var requireObjectCoercible$6 = requireObjectCoercible$j;
var advanceStringIndex$1 = advanceStringIndex$4;
var getMethod$3 = getMethod$9;
var getSubstitution$1 = getSubstitution$2;
var regExpExec$1 = regexpExecAbstract;
var wellKnownSymbol$5 = wellKnownSymbol$z;
var REPLACE$1 = wellKnownSymbol$5('replace');
var max$6 = Math.max;
var min$7 = Math.min;
var concat$3 = uncurryThis$o([].concat);
var push$5 = uncurryThis$o([].push);
var stringIndexOf$1 = uncurryThis$o(''.indexOf);
var stringSlice$7 = uncurryThis$o(''.slice);

// Stringifies a capture value but leaves `undefined` (an unmatched group) untouched.
var maybeToString = function maybeToString(it) {
  if (it === undefined) return undefined;
  return String(it);
};

// IE <= 11 replaces $0 with the whole match, as if it was $&; this flag is
// true when the engine leaves '$0' literally in place.
// https://stackoverflow.com/questions/6024666/getting-ie-to-replace-a-regex-with-the-literal-string-0
// eslint-disable-next-line regexp/prefer-escape-replacement-dollar-char -- required for testing
var REPLACE_KEEPS_$0 = 'a'.replace(/./, '$0') === '$0';

// Safari <= 13.0.3(?) substitutes nth capture where n>m with an empty string.
// Only probed when RegExp.prototype[@@replace] exists; false otherwise.
var REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE = !!/./[REPLACE$1] && /./[REPLACE$1]('a', '$0') === '';

// True when `String#replace` expands `$<a>` from the `groups` object returned
// by a user-defined `exec`.
var REPLACE_SUPPORTS_NAMED_GROUPS = !fails$i(function () {
  var re = /./;
  re.exec = function () {
    var fakeMatch = [];
    fakeMatch.groups = { a: '7' };
    return fakeMatch;
  };
  // eslint-disable-next-line regexp/no-useless-dollar-replacements -- false positive
  return ''.replace(re, '$<a>') !== '7';
});
13215
// @@replace logic
// The factory returns [String.prototype.replace, RegExp.prototype[@@replace]].
// The trailing argument (FORCED) requests replacement when any of the replace
// feature flags above reports a defect.
fixRegExpWellKnownSymbolLogic$2('replace', function (_, nativeReplace, maybeCallNative) {
  // A string replacement containing this token is never sent down the native
  // fast path: '$' on engines that blank out-of-range captures, '$0' otherwise.
  var UNSAFE_SUBSTITUTE = REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE ? '$' : '$0';
  return [
    // `String.prototype.replace` method
    // https://tc39.es/ecma262/#sec-string.prototype.replace
    function replace(searchValue, replaceValue) {
      var O = requireObjectCoercible$6(this);
      // Dispatch to `searchValue[@@replace]` when present, else to the native string replace.
      var replacer = isNullOrUndefined$5(searchValue) ? undefined : getMethod$3(searchValue, REPLACE$1);
      return replacer ? call$b(replacer, searchValue, O, replaceValue) : call$b(nativeReplace, toString$9(O), searchValue, replaceValue);
    },
    // `RegExp.prototype[@@replace]` method
    // https://tc39.es/ecma262/#sec-regexp.prototype-@@replace
    function (string, replaceValue) {
      var rx = anObject$6(this);
      var S = toString$9(string);
      // Fast path: a string replacement free of the unsafe token and of `$<`
      // (named groups) may be handled natively via maybeCallNative.
      if (typeof replaceValue == 'string' && stringIndexOf$1(replaceValue, UNSAFE_SUBSTITUTE) === -1 && stringIndexOf$1(replaceValue, '$<') === -1) {
        var res = maybeCallNative(nativeReplace, rx, S, replaceValue);
        if (res.done) return res.value;
      }
      var functionalReplace = isCallable$4(replaceValue);
      if (!functionalReplace) replaceValue = toString$9(replaceValue);
      var global = rx.global;
      if (global) {
        // Only needed for global regexps, to step past empty matches below.
        var fullUnicode = rx.unicode;
        rx.lastIndex = 0;
      }
      // Phase 1: collect exec results (only the first one for non-global regexps).
      var results = [];
      while (true) {
        var result = regExpExec$1(rx, S);
        if (result === null) break;
        push$5(results, result);
        if (!global) break;
        var matchStr = toString$9(result[0]);
        // An empty match leaves lastIndex unchanged; advance it so the loop terminates.
        if (matchStr === '') rx.lastIndex = advanceStringIndex$1(S, toLength$4(rx.lastIndex), fullUnicode);
      }
      // Phase 2: splice each replacement into the output.
      var accumulatedResult = '';
      var nextSourcePosition = 0;
      for (var i = 0; i < results.length; i++) {
        result = results[i];
        var matched = toString$9(result[0]);
        // Clamp the reported index into [0, S.length].
        var position = max$6(min$7(toIntegerOrInfinity$4(result.index), S.length), 0);
        var captures = [];
        // NOTE: This is equivalent to
        //   captures = result.slice(1).map(maybeToString)
        // but for some reason `nativeSlice.call(result, 1, result.length)` (called in
        // the slice polyfill when slicing native arrays) "doesn't work" in safari 9 and
        // causes a crash (https://pastebin.com/N21QzeQA) when trying to debug it.
        for (var j = 1; j < result.length; j++) push$5(captures, maybeToString(result[j]));
        var namedCaptures = result.groups;
        if (functionalReplace) {
          // Replacer arguments: matched, ...captures, position, S, then groups if present.
          // NOTE(review): assumes apply$4 is Reflect.apply-style (fn, thisArg, args),
          // so the replacer gets `undefined` as `this`.
          var replacerArgs = concat$3([matched], captures, position, S);
          if (namedCaptures !== undefined) push$5(replacerArgs, namedCaptures);
          var replacement = toString$9(apply$4(replaceValue, undefined, replacerArgs));
        } else {
          replacement = getSubstitution$1(matched, S, position, captures, namedCaptures, replaceValue);
        }
        // Results that start before the end of the previously replaced match are skipped.
        if (position >= nextSourcePosition) {
          accumulatedResult += stringSlice$7(S, nextSourcePosition, position) + replacement;
          nextSourcePosition = position + matched.length;
        }
      }
      return accumulatedResult + stringSlice$7(S, nextSourcePosition);
    }];
}, !REPLACE_SUPPORTS_NAMED_GROUPS || !REPLACE_KEEPS_$0 || REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE);
13281
var es_string_replaceAll = {};

'use strict';
// Local aliases of shared helpers defined elsewhere in the bundle, used by
// the `String.prototype.replaceAll` polyfill.
var $$A = _export;
var call$a = functionCall;
var uncurryThis$n = functionUncurryThis;
var requireObjectCoercible$5 = requireObjectCoercible$j;
var isCallable$3 = isCallable$z;
var isNullOrUndefined$4 = isNullOrUndefined$e;
var isRegExp$1 = isRegexp;
var toString$8 = toString$x;
var getMethod$2 = getMethod$9;
var getRegExpFlags$1 = regexpGetFlags;
var getSubstitution = getSubstitution$2;
var wellKnownSymbol$4 = wellKnownSymbol$z;
var IS_PURE$5 = isPure;
var REPLACE = wellKnownSymbol$4('replace');
var $TypeError$2 = TypeError;
// Uncurried string methods: indexOf(str, search, from), replace(str, pattern, replacement), stringSlice(str, start, end).
var indexOf = uncurryThis$n(''.indexOf);
var replace$5 = uncurryThis$n(''.replace);
var stringSlice$6 = uncurryThis$n(''.slice);
var max$5 = Math.max;
/**
 * `StringIndexOf` abstract operation, as used by `replaceAll`.
 *
 * Returns -1 when `fromIndex` lies past the end of `string`. An empty
 * `searchValue` is found at `fromIndex` itself. Otherwise defers to
 * `String.prototype.indexOf`.
 */
var stringIndexOf = function stringIndexOf(string, searchValue, fromIndex) {
  var pastEnd = fromIndex > string.length;
  if (pastEnd) return -1;
  return searchValue === '' ? fromIndex : indexOf(string, searchValue, fromIndex);
};
13309
// `String.prototype.replaceAll` method
// https://tc39.es/ecma262/#sec-string.prototype.replaceall
$$A({
  target: 'String',
  proto: true
}, {
  /**
   * Replaces every occurrence of `searchValue` in `this`.
   *
   * - A RegExp `searchValue` must be global, otherwise a TypeError is thrown.
   * - If `searchValue` has an `@@replace` method, the call is delegated to it.
   * - Otherwise `searchValue` is treated as a literal string, and
   *   `replaceValue` is either a function, called as
   *   `fn(searchString, position, string)`, or a `$`-pattern template.
   */
  replaceAll: function replaceAll(searchValue, replaceValue) {
    var O = requireObjectCoercible$5(this);
    var IS_REG_EXP, flags, replacer, string, searchString, functionalReplace, searchLength, advanceBy, replacement;
    var position = 0;
    var endOfLastMatch = 0;
    var result = '';
    if (!isNullOrUndefined$4(searchValue)) {
      IS_REG_EXP = isRegExp$1(searchValue);
      if (IS_REG_EXP) {
        flags = toString$8(requireObjectCoercible$5(getRegExpFlags$1(searchValue)));
        if (!~indexOf(flags, 'g')) throw $TypeError$2('`.replaceAll` does not allow non-global regexes');
      }
      replacer = getMethod$2(searchValue, REPLACE);
      if (replacer) {
        return call$a(replacer, searchValue, O, replaceValue);
      } else if (IS_PURE$5 && IS_REG_EXP) {
        // Pure mode with no @@replace available: use the native string replace directly.
        return replace$5(toString$8(O), searchValue, replaceValue);
      }
    }
    // Literal-string path.
    string = toString$8(O);
    searchString = toString$8(searchValue);
    functionalReplace = isCallable$3(replaceValue);
    if (!functionalReplace) replaceValue = toString$8(replaceValue);
    searchLength = searchString.length;
    // Step at least one code unit so an empty search string still progresses.
    advanceBy = max$5(1, searchLength);
    position = stringIndexOf(string, searchString, 0);
    while (position !== -1) {
      replacement = functionalReplace ? toString$8(replaceValue(searchString, position, string)) : getSubstitution(searchString, string, position, [], undefined, replaceValue);
      result += stringSlice$6(string, endOfLastMatch, position) + replacement;
      endOfLastMatch = position + searchLength;
      position = stringIndexOf(string, searchString, position + advanceBy);
    }
    // Append whatever follows the final match.
    if (endOfLastMatch < string.length) {
      result += stringSlice$6(string, endOfLastMatch);
    }
    return result;
  }
});
13354
var es_string_search = {};

'use strict';
var call$9 = functionCall;
var fixRegExpWellKnownSymbolLogic$1 = fixRegexpWellKnownSymbolLogic;
var anObject$5 = anObject$D;
var isNullOrUndefined$3 = isNullOrUndefined$e;
var requireObjectCoercible$4 = requireObjectCoercible$j;
var sameValue = sameValue$1;
var toString$7 = toString$x;
var getMethod$1 = getMethod$9;
var regExpExec = regexpExecAbstract;

// @@search logic: the factory returns
// [String.prototype.search, RegExp.prototype[@@search]].
fixRegExpWellKnownSymbolLogic$1('search', function (SEARCH, nativeSearch, maybeCallNative) {
  return [
    // `String.prototype.search` method
    // https://tc39.es/ecma262/#sec-string.prototype.search
    function search(regexp) {
      var O = requireObjectCoercible$4(this);
      var searcher = isNullOrUndefined$3(regexp) ? undefined : getMethod$1(regexp, SEARCH);
      if (searcher) return call$9(searcher, regexp, O);
      // No custom @@search: build a RegExp from the argument and delegate to it.
      return new RegExp(regexp)[SEARCH](toString$7(O));
    },
    // `RegExp.prototype[@@search]` method
    // https://tc39.es/ecma262/#sec-regexp.prototype-@@search
    function (string) {
      var rx = anObject$5(this);
      var S = toString$7(string);
      var nativeOutcome = maybeCallNative(nativeSearch, rx, S);
      if (nativeOutcome.done) return nativeOutcome.value;
      // Search from index 0, then restore the caller's lastIndex. Each write
      // happens only when the value differs (SameValue).
      var savedLastIndex = rx.lastIndex;
      if (!sameValue(savedLastIndex, 0)) rx.lastIndex = 0;
      var found = regExpExec(rx, S);
      if (!sameValue(rx.lastIndex, savedLastIndex)) rx.lastIndex = savedLastIndex;
      if (found === null) return -1;
      return found.index;
    }
  ];
});
13392
var es_string_split = {};

'use strict';
// Local aliases of shared helpers defined elsewhere in the bundle.
var apply$3 = functionApply$1;
var call$8 = functionCall;
var uncurryThis$m = functionUncurryThis;
var fixRegExpWellKnownSymbolLogic = fixRegexpWellKnownSymbolLogic;
var anObject$4 = anObject$D;
var isNullOrUndefined$2 = isNullOrUndefined$e;
var isRegExp = isRegexp;
var requireObjectCoercible$3 = requireObjectCoercible$j;
var speciesConstructor$1 = speciesConstructor$6;
var advanceStringIndex = advanceStringIndex$4;
var toLength$3 = toLength$d;
var toString$6 = toString$x;
var getMethod = getMethod$9;
var arraySlice$4 = arraySliceSimple;
var callRegExpExec = regexpExecAbstract;
var regexpExec = regexpExec$3;
var stickyHelpers = regexpStickyHelpers;
var fails$h = fails$1m;
// Presumably true when the engine lacks a usable `y` (sticky) flag; the split
// logic then emulates it with a `^(?:...)` pattern — TODO confirm in regexpStickyHelpers.
var UNSUPPORTED_Y = stickyHelpers.UNSUPPORTED_Y;
// Default `limit` when none is given (limits are coerced with `>>> 0`).
var MAX_UINT32 = 0xFFFFFFFF;
var min$6 = Math.min;
var $push = [].push;
var exec$4 = uncurryThis$m(/./.exec);
var push$4 = uncurryThis$m($push);
var stringSlice$5 = uncurryThis$m(''.slice);

// Chrome 51 has a buggy "split" implementation when RegExp#exec !== nativeExec
// Weex JS has frozen built-in prototypes, so use try / catch wrapper
// True when `'ab'.split(re)` still yields ['a', 'b'] after `re.exec` is
// replaced by a pass-through wrapper.
var SPLIT_WORKS_WITH_OVERWRITTEN_EXEC = !fails$h(function () {
  // eslint-disable-next-line regexp/no-empty-group -- required for testing
  var re = /(?:)/;
  var originalExec = re.exec;
  re.exec = function () {
    return originalExec.apply(this, arguments);
  };
  var result = 'ab'.split(re);
  return result.length !== 2 || result[0] !== 'a' || result[1] !== 'b';
});
13434
// @@split logic
// The factory returns [String.prototype.split, RegExp.prototype[@@split]].
// FORCED when split misbehaves with an overwritten exec; SHAM when the sticky
// flag is unsupported.
fixRegExpWellKnownSymbolLogic('split', function (SPLIT, nativeSplit, maybeCallNative) {
  // Pick the string-splitting primitive:
  //  - If native split fails any of the quirk probes below, use an es5-shim based
  //    reimplementation.
  //  - Else, if `'0'.split(undefined, 0)` is non-empty, wrap native split so that
  //    (undefined, 0) returns [].
  //  - Otherwise use native split as-is.
  var internalSplit;
  if ('abbc'.split(/(b)*/)[1] == 'c' ||
  // eslint-disable-next-line regexp/no-empty-group -- required for testing
  'test'.split(/(?:)/, -1).length != 4 || 'ab'.split(/(?:ab)*/).length != 2 || '.'.split(/(.?)(.?)/).length != 4 ||
  // eslint-disable-next-line regexp/no-empty-capturing-group, regexp/no-empty-group -- required for testing
  '.'.split(/()()/).length > 1 || ''.split(/.?/).length) {
    // based on es5-shim implementation, need to rework it
    internalSplit = function internalSplit(separator, limit) {
      var string = toString$6(requireObjectCoercible$3(this));
      // ToUint32(limit); an omitted limit means "no limit".
      var lim = limit === undefined ? MAX_UINT32 : limit >>> 0;
      if (lim === 0) return [];
      if (separator === undefined) return [string];
      // If `separator` is not a regex, use native split
      if (!isRegExp(separator)) {
        return call$8(nativeSplit, string, separator, lim);
      }
      var output = [];
      var flags = (separator.ignoreCase ? 'i' : '') + (separator.multiline ? 'm' : '') + (separator.unicode ? 'u' : '') + (separator.sticky ? 'y' : '');
      var lastLastIndex = 0;
      // Make `global` and avoid `lastIndex` issues by working with a copy
      var separatorCopy = new RegExp(separator.source, flags + 'g');
      var match, lastIndex, lastLength;
      while (match = call$8(regexpExec, separatorCopy, string)) {
        lastIndex = separatorCopy.lastIndex;
        if (lastIndex > lastLastIndex) {
          // Emit the text before the separator, then any captured groups.
          push$4(output, stringSlice$5(string, lastLastIndex, match.index));
          if (match.length > 1 && match.index < string.length) apply$3($push, output, arraySlice$4(match, 1));
          lastLength = match[0].length;
          lastLastIndex = lastIndex;
          if (output.length >= lim) break;
        }
        if (separatorCopy.lastIndex === match.index) separatorCopy.lastIndex++; // Avoid an infinite loop
      }

      // Emit the trailing segment; at end of string, a final '' is pushed only
      // if the last match was non-empty or the separator cannot match ''.
      if (lastLastIndex === string.length) {
        if (lastLength || !exec$4(separatorCopy, '')) push$4(output, '');
      } else push$4(output, stringSlice$5(string, lastLastIndex));
      return output.length > lim ? arraySlice$4(output, 0, lim) : output;
    };
    // Chakra, V8
  } else if ('0'.split(undefined, 0).length) {
    internalSplit = function internalSplit(separator, limit) {
      return separator === undefined && limit === 0 ? [] : call$8(nativeSplit, this, separator, limit);
    };
  } else internalSplit = nativeSplit;
  return [
    // `String.prototype.split` method
    // https://tc39.es/ecma262/#sec-string.prototype.split
    function split(separator, limit) {
      var O = requireObjectCoercible$3(this);
      // Dispatch to `separator[@@split]` when present, else to internalSplit on the string.
      var splitter = isNullOrUndefined$2(separator) ? undefined : getMethod(separator, SPLIT);
      return splitter ? call$8(splitter, separator, O, limit) : call$8(internalSplit, toString$6(O), separator, limit);
    },
    // `RegExp.prototype[@@split]` method
    // https://tc39.es/ecma262/#sec-regexp.prototype-@@split
    //
    // NOTE: This cannot be properly polyfilled in engines that don't support
    // the 'y' flag.
    function (string, limit) {
      var rx = anObject$4(this);
      var S = toString$6(string);
      // When internalSplit is a replacement, the last argument (forceStringMethod)
      // makes maybeCallNative call internalSplit instead of the native @@split.
      var res = maybeCallNative(internalSplit, rx, S, limit, internalSplit !== nativeSplit);
      if (res.done) return res.value;
      var C = speciesConstructor$1(rx, RegExp);
      var unicodeMatching = rx.unicode;
      var flags = (rx.ignoreCase ? 'i' : '') + (rx.multiline ? 'm' : '') + (rx.unicode ? 'u' : '') + (UNSUPPORTED_Y ? 'g' : 'y');

      // ^(? + rx + ) is needed, in combination with some S slicing, to
      // simulate the 'y' flag.
      var splitter = new C(UNSUPPORTED_Y ? '^(?:' + rx.source + ')' : rx, flags);
      var lim = limit === undefined ? MAX_UINT32 : limit >>> 0;
      if (lim === 0) return [];
      if (S.length === 0) return callRegExpExec(splitter, S) === null ? [S] : [];
      // p: start of the pending (not yet emitted) segment; q: position being tried.
      var p = 0;
      var q = 0;
      var A = [];
      while (q < S.length) {
        splitter.lastIndex = UNSUPPORTED_Y ? 0 : q;
        var z = callRegExpExec(splitter, UNSUPPORTED_Y ? stringSlice$5(S, q) : S);
        // e: end of the match at q (in S coordinates), clamped to S.length.
        var e;
        if (z === null || (e = min$6(toLength$3(splitter.lastIndex + (UNSUPPORTED_Y ? q : 0)), S.length)) === p) {
          // No match at q, or an empty match at the segment start: try the next position.
          q = advanceStringIndex(S, q, unicodeMatching);
        } else {
          push$4(A, stringSlice$5(S, p, q));
          if (A.length === lim) return A;
          for (var i = 1; i <= z.length - 1; i++) {
            push$4(A, z[i]);
            if (A.length === lim) return A;
          }
          q = p = e;
        }
      }
      push$4(A, stringSlice$5(S, p));
      return A;
    }];
}, !SPLIT_WORKS_WITH_OVERWRITTEN_EXEC, UNSUPPORTED_Y);
13533
var es_string_startsWith = {};

'use strict';
var $$z = _export;
var uncurryThis$l = functionUncurryThisClause;
var getOwnPropertyDescriptor$1 = objectGetOwnPropertyDescriptor.f;
var toLength$2 = toLength$d;
var toString$5 = toString$x;
var notARegExp = notARegexp;
var requireObjectCoercible$2 = requireObjectCoercible$j;
var correctIsRegExpLogic = correctIsRegexpLogic;
var IS_PURE$4 = isPure;

// eslint-disable-next-line es/no-string-prototype-startswith -- safe
var nativeStartsWith = uncurryThis$l(''.startsWith);
var stringSlice$4 = uncurryThis$l(''.slice);
var min$5 = Math.min;
var CORRECT_IS_REGEXP_LOGIC = correctIsRegExpLogic('startsWith');
// Outside pure mode, and only when the regexp check is incorrect: true if
// `String.prototype.startsWith` exists but is non-writable. When set, the
// method is not force-replaced.
// https://github.com/zloirock/core-js/pull/702
var MDN_POLYFILL_BUG = (function () {
  if (IS_PURE$4 || CORRECT_IS_REGEXP_LOGIC) return false;
  var descriptor = getOwnPropertyDescriptor$1(String.prototype, 'startsWith');
  return !!descriptor && !descriptor.writable;
})();

// `String.prototype.startsWith` method
// https://tc39.es/ecma262/#sec-string.prototype.startswith
$$z({ target: 'String', proto: true, forced: !MDN_POLYFILL_BUG && !CORRECT_IS_REGEXP_LOGIC }, {
  startsWith: function startsWith(searchString /* , position = 0 */) {
    var that = toString$5(requireObjectCoercible$2(this));
    notARegExp(searchString);
    // Optional position, clamped to the string length.
    var position = arguments.length > 1 ? arguments[1] : undefined;
    var index = toLength$2(min$5(position, that.length));
    var search = toString$5(searchString);
    if (nativeStartsWith) return nativeStartsWith(that, search, index);
    return stringSlice$4(that, index, index + search.length) === search;
  }
});
13573
var es_string_substr = {};

'use strict';
var $$y = _export;
var uncurryThis$k = functionUncurryThis;
var requireObjectCoercible$1 = requireObjectCoercible$j;
var toIntegerOrInfinity$3 = toIntegerOrInfinity$l;
var toString$4 = toString$x;
var stringSlice$3 = uncurryThis$k(''.slice);
var max$4 = Math.max;
var min$4 = Math.min;

// Replace native substr when it is missing or mishandles a negative start.
// eslint-disable-next-line unicorn/prefer-string-slice -- required for testing
var FORCED$2 = !''.substr || 'ab'.substr(-1) !== 'b';

// `String.prototype.substr` method (Annex B legacy feature)
// https://tc39.es/ecma262/#sec-string.prototype.substr
$$y({
  target: 'String',
  proto: true,
  forced: FORCED$2
}, {
  /**
   * Returns up to `length` UTF-16 code units of `this`, converted to a
   * string, starting at `start`.
   *
   * @param {*} start - Start index. Negative values count back from the end;
   *   the result is clamped into [0, size], so -Infinity maps to 0 and
   *   +Infinity maps to `size` (empty result).
   * @param {*} [length] - Number of code units. It defaults to the rest of the
   *   string and is clamped into [0, size], so +Infinity reads to the end.
   * @returns {string} The extracted substring ('' when nothing is selected).
   * @throws {TypeError} When `this` is null or undefined.
   */
  substr: function substr(start, length) {
    var that = toString$4(requireObjectCoercible$1(this));
    var size = that.length;
    var intStart = toIntegerOrInfinity$3(start);
    // Spec steps 5-7. The previous code reset +Infinity to 0 instead of `size`,
    // so `'abc'.substr(Infinity)` returned 'abc' rather than ''.
    if (intStart < 0) intStart = max$4(size + intStart, 0);
    else intStart = min$4(intStart, size);
    var intLength = length === undefined ? size : toIntegerOrInfinity$3(length);
    // Spec step 9: clamp into [0, size]. The previous code returned '' for an
    // infinite length, so `'abc'.substr(0, Infinity)` returned '' instead of 'abc'.
    intLength = min$4(max$4(intLength, 0), size);
    var intEnd = min$4(intStart + intLength, size);
    return intStart >= intEnd ? '' : stringSlice$3(that, intStart, intEnd);
  }
});
13609
var es_string_trim = {};

var PROPER_FUNCTION_NAME = functionName.PROPER;
var fails$g = fails$1m;
var whitespaces$1 = whitespaces$5;
var non = "\u200B\x85\u180E";

// Builds the "does the native trim method need replacing?" check for METHOD_NAME.
// The method counts as broken when any of these holds:
// - trimming `whitespaces$5` (presumably the full whitespace list) leaves something behind;
// - trimming the `non` characters changes them;
// - function names are reliable (PROPER_FUNCTION_NAME) and the method has the wrong name.
// fails$g presumably also reports `true` if the probe throws (e.g. a missing method).
var stringTrimForced = function stringTrimForced(METHOD_NAME) {
  return fails$g(function () {
    if (whitespaces$1[METHOD_NAME]()) return true;
    if (non[METHOD_NAME]() !== non) return true;
    return PROPER_FUNCTION_NAME && whitespaces$1[METHOD_NAME].name !== METHOD_NAME;
  });
};
var stringTrimForced$1 = /*@__PURE__*/getDefaultExportFromCjs(stringTrimForced);
13625
'use strict';
var $$x = _export;
var $trim = stringTrim.trim;
var forcedStringTrimMethod$2 = stringTrimForced;

// `String.prototype.trim` method
// https://tc39.es/ecma262/#sec-string.prototype.trim
// Installed over the native method only when stringTrimForced flags it.
$$x({
  target: 'String',
  proto: true,
  forced: forcedStringTrimMethod$2('trim')
}, {
  trim: function trim() {
    var receiver = this;
    return $trim(receiver);
  }
});
13642
var es_string_trimEnd = {};

var es_string_trimRight = {};

'use strict';
var $trimEnd = stringTrim.end;
var forcedStringTrimMethod$1 = stringTrimForced;

// Shared implementation behind `String.prototype.{ trimEnd, trimRight }`:
// the polyfill when the native method fails the stringTrimForced check,
// otherwise the native `''.trimEnd` function itself.
// https://tc39.es/ecma262/#sec-string.prototype.trimend
// https://tc39.es/ecma262/#String.prototype.trimright
var stringTrimEnd;
if (forcedStringTrimMethod$1('trimEnd')) {
  stringTrimEnd = function trimEnd() {
    return $trimEnd(this);
  };
} else {
  // eslint-disable-next-line es/no-string-prototype-trimstart-trimend -- safe
  stringTrimEnd = ''.trimEnd;
}
var stringTrimEnd$1 = /*@__PURE__*/getDefaultExportFromCjs(stringTrimEnd);
13659
var $$w = _export;
var trimEnd$1 = stringTrimEnd;

// `String.prototype.trimRight` method: an Annex B alias that is expected to be
// the very same function object as `trimEnd`, so it is force-installed whenever
// the identities differ.
// https://tc39.es/ecma262/#sec-string.prototype.trimend
$$w({
  target: 'String',
  proto: true,
  name: 'trimEnd',
  // eslint-disable-next-line es/no-string-prototype-trimleft-trimright -- safe
  forced: trimEnd$1 !== ''.trimRight
}, {
  trimRight: trimEnd$1
});
13674
// TODO: Remove this line from `core-js@4`
// NOTE(review): the statement this TODO refers to is not present here; it
// appears to have been dropped when the module was bundled.

var $$v = _export;
var trimEnd = stringTrimEnd;

// `String.prototype.trimEnd` method
// https://tc39.es/ecma262/#sec-string.prototype.trimend
// Force-installed whenever the native method is not the shared stringTrimEnd function.
$$v({
  target: 'String',
  proto: true,
  name: 'trimEnd',
  // eslint-disable-next-line es/no-string-prototype-trimstart-trimend -- safe
  forced: trimEnd !== ''.trimEnd
}, {
  trimEnd: trimEnd
});
13691
var es_string_trimStart = {};

var es_string_trimLeft = {};

'use strict';
var $trimStart = stringTrim.start;
var forcedStringTrimMethod = stringTrimForced;

// Shared implementation behind `String.prototype.{ trimStart, trimLeft }`:
// the polyfill when the native method fails the stringTrimForced check,
// otherwise the native `''.trimStart` function itself.
// https://tc39.es/ecma262/#sec-string.prototype.trimstart
// https://tc39.es/ecma262/#String.prototype.trimleft
var stringTrimStart;
if (forcedStringTrimMethod('trimStart')) {
  stringTrimStart = function trimStart() {
    return $trimStart(this);
  };
} else {
  // eslint-disable-next-line es/no-string-prototype-trimstart-trimend -- safe
  stringTrimStart = ''.trimStart;
}
var stringTrimStart$1 = /*@__PURE__*/getDefaultExportFromCjs(stringTrimStart);
13708
var $$u = _export;
var trimStart$1 = stringTrimStart;

// `String.prototype.trimLeft` method: an Annex B alias that is expected to be
// the very same function object as `trimStart`, so it is force-installed
// whenever the identities differ.
// https://tc39.es/ecma262/#sec-string.prototype.trimleft
$$u({
  target: 'String',
  proto: true,
  name: 'trimStart',
  // eslint-disable-next-line es/no-string-prototype-trimleft-trimright -- safe
  forced: trimStart$1 !== ''.trimLeft
}, {
  trimLeft: trimStart$1
});
13723
// TODO: Remove this line from `core-js@4`
// NOTE(review): the statement this TODO refers to is not present here; it
// appears to have been dropped when the module was bundled.

var $$t = _export;
var trimStart = stringTrimStart;

// `String.prototype.trimStart` method
// https://tc39.es/ecma262/#sec-string.prototype.trimstart
// Force-installed whenever the native method is not the shared stringTrimStart function.
$$t({
  target: 'String',
  proto: true,
  name: 'trimStart',
  // eslint-disable-next-line es/no-string-prototype-trimstart-trimend -- safe
  forced: trimStart !== ''.trimStart
}, {
  trimStart: trimStart
});
13740
var es_string_anchor = {};

var uncurryThis$j = functionUncurryThis;
var requireObjectCoercible = requireObjectCoercible$j;
var toString$3 = toString$x;
var quot = /"/g;
var replace$4 = uncurryThis$j(''.replace);

// `CreateHTML` abstract operation
// https://tc39.es/ecma262/#sec-createhtml
// Produces `<tag>string</tag>`, or `<tag attribute="value">string</tag>` when
// `attribute` is non-empty. Only double quotes in `value` are escaped (as `&quot;`).
var createHtml = function createHtml(string, tag, attribute, value) {
  var content = toString$3(requireObjectCoercible(string));
  var openTag = '<' + tag;
  if (attribute !== '') {
    var escaped = replace$4(toString$3(value), quot, '&quot;');
    openTag = openTag + ' ' + attribute + '="' + escaped + '"';
  }
  return openTag + '>' + content + '</' + tag + '>';
};
var createHtml$1 = /*@__PURE__*/getDefaultExportFromCjs(createHtml);
13758
var fails$f = fails$1m;

// Builds the "must this Annex B HTML method be replaced?" check for METHOD_NAME.
// The native method counts as broken when either of these holds:
// - its output for the argument '"' is not all lowercase;
// - its output contains an unescaped quote beyond the attribute delimiters
//   (more than 3 pieces after splitting on '"').
// fails$f presumably also reports `true` when the probe throws (e.g. a missing method).
var stringHtmlForced = function stringHtmlForced(METHOD_NAME) {
  return fails$f(function () {
    var sample = ''[METHOD_NAME]('"');
    if (sample !== sample.toLowerCase()) return true;
    return sample.split('"').length > 3;
  });
};
var stringHtmlForced$1 = /*@__PURE__*/getDefaultExportFromCjs(stringHtmlForced);
13770
'use strict';
var $$s = _export;
var createHTML$c = createHtml;
var forcedStringHTMLMethod$c = stringHtmlForced;

// `String.prototype.anchor(name)` (Annex B): wraps the string as
// `<a name="…">…</a>`; replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.anchor
$$s({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$c('anchor')
}, {
  anchor: function anchor(name) {
    var receiver = this;
    return createHTML$c(receiver, 'a', 'name', name);
  }
});
13787
var es_string_big = {};

'use strict';
var $$r = _export;
var createHTML$b = createHtml;
var forcedStringHTMLMethod$b = stringHtmlForced;

// `String.prototype.big()` (Annex B): wraps the string as `<big>…</big>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.big
$$r({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$b('big')
}, {
  big: function big() {
    var receiver = this;
    return createHTML$b(receiver, 'big', '', '');
  }
});
13806
var es_string_blink = {};

'use strict';
var $$q = _export;
var createHTML$a = createHtml;
var forcedStringHTMLMethod$a = stringHtmlForced;

// `String.prototype.blink()` (Annex B): wraps the string as `<blink>…</blink>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.blink
$$q({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$a('blink')
}, {
  blink: function blink() {
    var receiver = this;
    return createHTML$a(receiver, 'blink', '', '');
  }
});
13825
var es_string_bold = {};

'use strict';
var $$p = _export;
var createHTML$9 = createHtml;
var forcedStringHTMLMethod$9 = stringHtmlForced;

// `String.prototype.bold()` (Annex B): wraps the string as `<b>…</b>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.bold
$$p({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$9('bold')
}, {
  bold: function bold() {
    var receiver = this;
    return createHTML$9(receiver, 'b', '', '');
  }
});
13844
var es_string_fixed = {};

'use strict';
var $$o = _export;
var createHTML$8 = createHtml;
var forcedStringHTMLMethod$8 = stringHtmlForced;

// `String.prototype.fixed()` (Annex B): wraps the string as `<tt>…</tt>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.fixed
$$o({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$8('fixed')
}, {
  fixed: function fixed() {
    var receiver = this;
    return createHTML$8(receiver, 'tt', '', '');
  }
});
13863
var es_string_fontcolor = {};

'use strict';
var $$n = _export;
var createHTML$7 = createHtml;
var forcedStringHTMLMethod$7 = stringHtmlForced;

// `String.prototype.fontcolor(color)` (Annex B): wraps the string as
// `<font color="…">…</font>`; replaces the native method only when
// stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.fontcolor
$$n({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$7('fontcolor')
}, {
  fontcolor: function fontcolor(color) {
    var receiver = this;
    return createHTML$7(receiver, 'font', 'color', color);
  }
});
13882
var es_string_fontsize = {};

'use strict';
var $$m = _export;
var createHTML$6 = createHtml;
var forcedStringHTMLMethod$6 = stringHtmlForced;

// `String.prototype.fontsize(size)` (Annex B): wraps the string as
// `<font size="…">…</font>`; replaces the native method only when
// stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.fontsize
$$m({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$6('fontsize')
}, {
  fontsize: function fontsize(size) {
    var receiver = this;
    return createHTML$6(receiver, 'font', 'size', size);
  }
});
13901
var es_string_italics = {};

'use strict';
var $$l = _export;
var createHTML$5 = createHtml;
var forcedStringHTMLMethod$5 = stringHtmlForced;

// `String.prototype.italics()` (Annex B): wraps the string as `<i>…</i>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.italics
$$l({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$5('italics')
}, {
  italics: function italics() {
    var receiver = this;
    return createHTML$5(receiver, 'i', '', '');
  }
});
13920
var es_string_link = {};

'use strict';
var $$k = _export;
var createHTML$4 = createHtml;
var forcedStringHTMLMethod$4 = stringHtmlForced;

// `String.prototype.link(url)` (Annex B): wraps the string as
// `<a href="…">…</a>`; replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.link
$$k({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$4('link')
}, {
  link: function link(url) {
    var receiver = this;
    return createHTML$4(receiver, 'a', 'href', url);
  }
});
13939
var es_string_small = {};

'use strict';
var $$j = _export;
var createHTML$3 = createHtml;
var forcedStringHTMLMethod$3 = stringHtmlForced;

// `String.prototype.small()` (Annex B): wraps the string as `<small>…</small>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.small
$$j({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$3('small')
}, {
  small: function small() {
    var receiver = this;
    return createHTML$3(receiver, 'small', '', '');
  }
});
13958
var es_string_strike = {};

'use strict';
var $$i = _export;
var createHTML$2 = createHtml;
var forcedStringHTMLMethod$2 = stringHtmlForced;

// `String.prototype.strike()` (Annex B): wraps the string as `<strike>…</strike>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.strike
$$i({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$2('strike')
}, {
  strike: function strike() {
    var receiver = this;
    return createHTML$2(receiver, 'strike', '', '');
  }
});
13977
var es_string_sub = {};

'use strict';
var $$h = _export;
var createHTML$1 = createHtml;
var forcedStringHTMLMethod$1 = stringHtmlForced;

// `String.prototype.sub()` (Annex B): wraps the string as `<sub>…</sub>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.sub
$$h({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod$1('sub')
}, {
  sub: function sub() {
    var receiver = this;
    return createHTML$1(receiver, 'sub', '', '');
  }
});
13996
var es_string_sup = {};

'use strict';
var $$g = _export;
var createHTML = createHtml;
var forcedStringHTMLMethod = stringHtmlForced;

// `String.prototype.sup()` (Annex B): wraps the string as `<sup>…</sup>`;
// replaces the native method only when stringHtmlForced flags it.
// https://tc39.es/ecma262/#sec-string.prototype.sup
$$g({
  target: 'String',
  proto: true,
  forced: forcedStringHTMLMethod('sup')
}, {
  sup: function sup() {
    var receiver = this;
    return createHTML(receiver, 'sup', '', '');
  }
});
14015
var es_typedArray_float32Array = {};

var typedArrayConstructor$2 = {exports: {}};

/* eslint-disable no-new -- required for testing */
var global$o = global$Z;
var fails$e = fails$1m;
var checkCorrectnessOfIteration = checkCorrectnessOfIteration$4;
var NATIVE_ARRAY_BUFFER_VIEWS$1 = arrayBufferViewCore.NATIVE_ARRAY_BUFFER_VIEWS;
var ArrayBuffer$2 = global$o.ArrayBuffer;
var Int8Array$3 = global$o.Int8Array;

// Decides whether the native typed-array constructors must be wrapped. Each
// probe below short-circuits to `true` as soon as one misbehaviour is found.
var typedArrayConstructorsRequireWrappers = (function () {
  // No native typed arrays at all.
  if (!NATIVE_ARRAY_BUFFER_VIEWS$1) return true;
  // Calling the constructor without `new` does not throw.
  if (!fails$e(function () {
    Int8Array$3(1);
  })) return true;
  // A negative length does not throw.
  if (!fails$e(function () {
    new Int8Array$3(-1);
  })) return true;
  // The no-arg / null / fractional / iterable argument forms fail the
  // checkCorrectnessOfIteration probe (defined elsewhere).
  if (!checkCorrectnessOfIteration(function (iterable) {
    new Int8Array$3();
    new Int8Array$3(null);
    new Int8Array$3(1.5);
    new Int8Array$3(iterable);
  }, true)) return true;
  // Safari (11+) bug - a reason why even Safari 13 should load a typed array polyfill
  return fails$e(function () {
    return new Int8Array$3(new ArrayBuffer$2(2), 1, undefined).length !== 1;
  });
})();
var typedArrayConstructorsRequireWrappers$1 = /*@__PURE__*/getDefaultExportFromCjs(typedArrayConstructorsRequireWrappers);
14041
var toIntegerOrInfinity$2 = toIntegerOrInfinity$l;
var $RangeError$2 = RangeError;

// Converts `it` with toIntegerOrInfinity and rejects negative results with a
// RangeError; any non-negative result is returned unchanged.
var toPositiveInteger$1 = function toPositiveInteger(it) {
  var integer = toIntegerOrInfinity$2(it);
  if (!(integer < 0)) return integer;
  throw $RangeError$2("The argument can't be less than 0");
};
var toPositiveInteger$2 = /*@__PURE__*/getDefaultExportFromCjs(toPositiveInteger$1);
14050
var toPositiveInteger = toPositiveInteger$1;
var $RangeError$1 = RangeError;

// Validates a byte offset: converts `it` with toPositiveInteger and throws
// RangeError('Wrong offset') unless it is a multiple of BYTES (the element size
// the typed-array code passes in).
var toOffset$2 = function toOffset(it, BYTES) {
  var offset = toPositiveInteger(it);
  var remainder = offset % BYTES;
  if (!remainder) return offset;
  throw $RangeError$1('Wrong offset');
};
var toOffset$3 = /*@__PURE__*/getDefaultExportFromCjs(toOffset$2);
14059
var classof$4 = classof$m;

// True when classof reports one of the two BigInt typed-array kinds.
var isBigIntArray$2 = function isBigIntArray(it) {
  switch (classof$4(it)) {
    case 'BigInt64Array':
    case 'BigUint64Array':
      return true;
    default:
      return false;
  }
};
var isBigIntArray$3 = /*@__PURE__*/getDefaultExportFromCjs(isBigIntArray$2);
14066
var toPrimitive = toPrimitive$4;
var $TypeError$1 = TypeError;

// `ToBigInt` abstract operation
// https://tc39.es/ecma262/#sec-tobigint
// Converts `argument` to a primitive with the 'number' hint, then to a BigInt.
// Number primitives are rejected with a TypeError; other primitives are handed
// to BigInt(), which throws for values it cannot parse.
var toBigInt$3 = function toBigInt(argument) {
  var primitive = toPrimitive(argument, 'number');
  if (typeof primitive !== 'number') {
    // eslint-disable-next-line es/no-bigint -- safe
    return BigInt(primitive);
  }
  throw $TypeError$1("Can't convert number to bigint");
};
var toBigInt$4 = /*@__PURE__*/getDefaultExportFromCjs(toBigInt$3);
14079
var bind$2 = functionBindContext;
var call$7 = functionCall;
var aConstructor = aConstructor$3;
var toObject = toObject$t;
var lengthOfArrayLike$3 = lengthOfArrayLike$t;
var getIterator$1 = getIterator$4;
var getIteratorMethod$1 = getIteratorMethod$5;
var isArrayIteratorMethod = isArrayIteratorMethod$3;
var isBigIntArray$1 = isBigIntArray$2;
var aTypedArrayConstructor$3 = arrayBufferViewCore.aTypedArrayConstructor;
var toBigInt$2 = toBigInt$3;

// `%TypedArray%.from(source, mapfn?, thisArg?)`, called with `this` set to the
// target constructor. Steps:
// 1. Non-array iterables are first drained into a plain array; anything else is
//    read as an array-like.
// 2. Each element is optionally passed through `mapfn` (bound to `thisArg` when
//    one is given).
// 3. Each element is coerced before the store: toBigInt for BigInt arrays,
//    unary `+` otherwise.
var typedArrayFrom$2 = function from(source /* , mapfn, thisArg */) {
  var Constructor = aConstructor(this);
  var list = toObject(source);
  var argCount = arguments.length;
  var mapfn = argCount > 1 ? arguments[1] : undefined;
  var mapping = mapfn !== undefined;
  var iteratorMethod = getIteratorMethod$1(list);
  if (iteratorMethod && !isArrayIteratorMethod(iteratorMethod)) {
    var iterator = getIterator$1(list, iteratorMethod);
    var nextMethod = iterator.next;
    var collected = [];
    for (var step = call$7(nextMethod, iterator); !step.done; step = call$7(nextMethod, iterator)) {
      collected.push(step.value);
    }
    list = collected;
  }
  if (mapping && argCount > 2) mapfn = bind$2(mapfn, arguments[2]);
  var length = lengthOfArrayLike$3(list);
  var target = new (aTypedArrayConstructor$3(Constructor))(length);
  var bigIntTarget = isBigIntArray$1(target);
  var index = 0;
  while (index < length) {
    var element = mapping ? mapfn(list[index], index) : list[index];
    // FF30- typed arrays don't properly convert objects to typed array values,
    // so coerce explicitly before storing.
    target[index] = bigIntTarget ? toBigInt$2(element) : +element;
    index++;
  }
  return target;
};
var typedArrayFrom$3 = /*@__PURE__*/getDefaultExportFromCjs(typedArrayFrom$2);
14121
// ---- Typed-array constructor factory module: local aliases of shared helpers ----
// NOTE(review): this reads `typedArrayConstructor$2.exports` before this module
// assigns the factory to it, so it captures the initial placeholder object; it
// looks like a bundler artifact and appears unused nearby.
var typedArrayConstructor = typedArrayConstructor$2.exports;
'use strict';
var $$f = _export;
var global$n = global$Z;
var call$6 = functionCall;
// Gates the whole factory below (`if (DESCRIPTORS$7)`).
var DESCRIPTORS$7 = descriptors;
var TYPED_ARRAYS_CONSTRUCTORS_REQUIRES_WRAPPERS$2 = typedArrayConstructorsRequireWrappers;
var ArrayBufferViewCore$u = arrayBufferViewCore;
var ArrayBufferModule = arrayBuffer;
var anInstance$5 = anInstance$a;
var createPropertyDescriptor$3 = createPropertyDescriptor$c;
var createNonEnumerableProperty$3 = createNonEnumerableProperty$f;
var isIntegralNumber = isIntegralNumber$3;
var toLength$1 = toLength$d;
var toIndex = toIndex$2;
var toOffset$1 = toOffset$2;
var toPropertyKey = toPropertyKey$8;
var hasOwn$7 = hasOwnProperty_1;
var classof$3 = classof$m;
var isObject$4 = isObject$z;
var isSymbol$1 = isSymbol$7;
var create$2 = objectCreate;
var isPrototypeOf = objectIsPrototypeOf;
var setPrototypeOf = objectSetPrototypeOf$1;
var getOwnPropertyNames = objectGetOwnPropertyNames.f;
var typedArrayFrom$1 = typedArrayFrom$2;
var forEach$2 = arrayIteration.forEach;
var setSpecies = setSpecies$6;
var defineBuiltInAccessor$5 = defineBuiltInAccessor$h;
// Shared descriptor modules; their `.f` functions are captured here as the
// "native" versions before the sham may overwrite `.f` further down.
var definePropertyModule = objectDefineProperty;
var getOwnPropertyDescriptorModule = objectGetOwnPropertyDescriptor;
var InternalStateModule$4 = internalState;
var inheritIfRequired$1 = inheritIfRequired$6;
var getInternalState$1 = InternalStateModule$4.get;
var setInternalState$4 = InternalStateModule$4.set;
var enforceInternalState$1 = InternalStateModule$4.enforce;
var nativeDefineProperty = definePropertyModule.f;
var nativeGetOwnPropertyDescriptor = getOwnPropertyDescriptorModule.f;
// Used by the Uint8ClampedArray setter.
var round$4 = Math.round;
var RangeError$3 = global$n.RangeError;
var ArrayBuffer$1 = ArrayBufferModule.ArrayBuffer;
var ArrayBufferPrototype = ArrayBuffer$1.prototype;
var DataView$1 = ArrayBufferModule.DataView;
var NATIVE_ARRAY_BUFFER_VIEWS = ArrayBufferViewCore$u.NATIVE_ARRAY_BUFFER_VIEWS;
var TYPED_ARRAY_TAG = ArrayBufferViewCore$u.TYPED_ARRAY_TAG;
var TypedArray = ArrayBufferViewCore$u.TypedArray;
var TypedArrayPrototype$1 = ArrayBufferViewCore$u.TypedArrayPrototype;
var aTypedArrayConstructor$2 = ArrayBufferViewCore$u.aTypedArrayConstructor;
var isTypedArray$1 = ArrayBufferViewCore$u.isTypedArray;
var BYTES_PER_ELEMENT = 'BYTES_PER_ELEMENT';
// Message for every invalid buffer/length combination in the sham constructor.
var WRONG_LENGTH = 'Wrong length';
// Copies the array-like `list` element by element into a new instance of the
// typed-array constructor `C` (validated by aTypedArrayConstructor$2 first).
var fromList = function fromList(C, list) {
  aTypedArrayConstructor$2(C);
  var count = list.length;
  var copy = new C(count);
  for (var position = 0; position < count; position++) {
    copy[position] = list[position];
  }
  return copy;
};
// Defines a configurable accessor `key` on `it` whose getter returns the
// same-named field from the receiver's internal state.
var addGetter = function addGetter(it, key) {
  var descriptor = {
    configurable: true,
    get: function get() {
      var state = getInternalState$1(this);
      return state[key];
    }
  };
  defineBuiltInAccessor$5(it, key, descriptor);
};
// Truthy when `it` has ArrayBuffer.prototype on its prototype chain, or when
// classof tags it as 'ArrayBuffer' or 'SharedArrayBuffer'.
var isArrayBuffer = function isArrayBuffer(it) {
  var inherits = isPrototypeOf(ArrayBufferPrototype, it);
  if (inherits) return inherits;
  var klass = classof$3(it);
  return klass == 'ArrayBuffer' || klass == 'SharedArrayBuffer';
};
// True when `key` names an existing, non-negative integral index of the typed
// array `target`. Symbols and keys not present on `target` are rejected.
var isTypedArrayIndex = function isTypedArrayIndex(target, key) {
  var typed = isTypedArray$1(target);
  if (!typed) return typed;
  if (isSymbol$1(key)) return false;
  if (!(key in target)) return false;
  var integral = isIntegralNumber(+key);
  if (!integral) return integral;
  return key >= 0;
};
// Typed-array-aware Object.getOwnPropertyDescriptor replacement.
// - Existing integer indices of typed arrays get a synthesized descriptor holding
//   the current element value, built by createPropertyDescriptor$3 with bitmap 2.
//   NOTE(review): confirm which attribute flags bitmap 2 selects.
// - Every other key is delegated to the original helper.
var wrappedGetOwnPropertyDescriptor = function getOwnPropertyDescriptor(target, key) {
  var propertyKey = toPropertyKey(key);
  if (!isTypedArrayIndex(target, propertyKey)) {
    return nativeGetOwnPropertyDescriptor(target, propertyKey);
  }
  return createPropertyDescriptor$3(2, target[propertyKey]);
};
// Typed-array-aware Object.defineProperty replacement (also published as
// Object.defineProperty further down when typed arrays are not native).
// A data descriptor aimed at an existing integer index of a typed array is
// reduced to a plain element store (`target[key] = value`). That happens only when
// the descriptor:
// - has `value` and no `get`/`set`;
// - is not configurable;
// - is not writable: false;
// - is not enumerable: false.
// Anything else is passed to nativeDefineProperty unchanged.
// Returns `target` in the element-store case.
var wrappedDefineProperty = function defineProperty(target, key, descriptor) {
  key = toPropertyKey(key);
  if (isTypedArrayIndex(target, key)
    && isObject$4(descriptor)
    && hasOwn$7(descriptor, 'value')
    && !hasOwn$7(descriptor, 'get')
    && !hasOwn$7(descriptor, 'set')
    // TODO: add validation descriptor w/o calling accessors
    && !descriptor.configurable
    && (!hasOwn$7(descriptor, 'writable') || descriptor.writable)
    && (!hasOwn$7(descriptor, 'enumerable') || descriptor.enumerable)) {
    target[key] = descriptor.value;
    return target;
  }
  return nativeDefineProperty(target, key, descriptor);
};
// Installs the typed-array descriptor wrappers and exports the per-type
// constructor factory; when DESCRIPTORS$7 is falsy the factory is a no-op.
if (DESCRIPTORS$7) {
  // Without native typed arrays, route the shared defineProperty /
  // getOwnPropertyDescriptor helpers through the typed-array-aware wrappers and
  // expose buffer/byteOffset/byteLength/length as getters over internal state.
  if (!NATIVE_ARRAY_BUFFER_VIEWS) {
    getOwnPropertyDescriptorModule.f = wrappedGetOwnPropertyDescriptor;
    definePropertyModule.f = wrappedDefineProperty;
    addGetter(TypedArrayPrototype$1, 'buffer');
    addGetter(TypedArrayPrototype$1, 'byteOffset');
    addGetter(TypedArrayPrototype$1, 'byteLength');
    addGetter(TypedArrayPrototype$1, 'length');
  }
  // Publish the wrappers on Object; forced only when typed arrays are not native.
  $$f({
    target: 'Object',
    stat: true,
    forced: !NATIVE_ARRAY_BUFFER_VIEWS
  }, {
    getOwnPropertyDescriptor: wrappedGetOwnPropertyDescriptor,
    defineProperty: wrappedDefineProperty
  });
  // Factory for a single typed-array constructor.
  //   TYPE    - element type such as 'Float32' or 'Uint8' (see the callers below)
  //   wrapper - receives an init(that, data, offset, length) function and returns
  //             the named constructor function
  //   CLAMPED - truthy only for Uint8ClampedArray
  typedArrayConstructor$2.exports = function (TYPE, wrapper, CLAMPED) {
    // Element size in bytes: the bit width embedded in TYPE divided by 8.
    var BYTES = TYPE.match(/\d+/)[0] / 8;
    var CONSTRUCTOR_NAME = TYPE + (CLAMPED ? 'Clamped' : '') + 'Array';
    // DataView accessor names, e.g. 'getFloat32' / 'setFloat32'.
    var GETTER = 'get' + TYPE;
    var SETTER = 'set' + TYPE;
    var NativeTypedArrayConstructor = global$n[CONSTRUCTOR_NAME];
    var TypedArrayConstructor = NativeTypedArrayConstructor;
    var TypedArrayConstructorPrototype = TypedArrayConstructor && TypedArrayConstructor.prototype;
    var exported = {};
    // Element read/write through the instance's internal DataView; the trailing
    // `true` is the littleEndian flag in the DataView API.
    var getter = function getter(that, index) {
      var data = getInternalState$1(that);
      return data.view[GETTER](index * BYTES + data.byteOffset, true);
    };
    var setter = function setter(that, index, value) {
      var data = getInternalState$1(that);
      // Uint8ClampedArray: round, then clamp into [0, 255]; NaN ends up 0 via `& 0xFF`.
      // NOTE(review): Math.round rounds .5 ties upward (2.5 -> 3) while the spec's
      // ToUint8Clamp rounds ties to even (2.5 -> 2) — confirm whether the sham should match.
      if (CLAMPED) value = (value = round$4(value)) < 0 ? 0 : value > 0xFF ? 0xFF : value & 0xFF;
      data.view[SETTER](index * BYTES + data.byteOffset, value, true);
    };
    // Defines an enumerable accessor for one numeric index forwarding to getter/setter.
    var addElement = function addElement(that, index) {
      nativeDefineProperty(that, index, {
        get: function get() {
          return getter(this, index);
        },
        set: function set(value) {
          return setter(this, index, value);
        },
        enumerable: true
      });
    };
    if (!NATIVE_ARRAY_BUFFER_VIEWS) {
      // Full sham: instances keep buffer/offset/length and a DataView in internal
      // state and get one accessor property per element.
      TypedArrayConstructor = wrapper(function (that, data, offset, $length) {
        anInstance$5(that, TypedArrayConstructorPrototype);
        var index = 0;
        var byteOffset = 0;
        var buffer, byteLength, length;
        if (!isObject$4(data)) {
          // new T(length): allocate a fresh buffer.
          length = toIndex(data);
          byteLength = length * BYTES;
          buffer = new ArrayBuffer$1(byteLength);
        } else if (isArrayBuffer(data)) {
          // new T(buffer, byteOffset?, length?): view over an existing buffer.
          buffer = data;
          byteOffset = toOffset$1(offset, BYTES);
          var $len = data.byteLength;
          if ($length === undefined) {
            // Implicit length: the buffer size must be a multiple of the element
            // size and the offset must not run past its end.
            if ($len % BYTES) throw RangeError$3(WRONG_LENGTH);
            byteLength = $len - byteOffset;
            if (byteLength < 0) throw RangeError$3(WRONG_LENGTH);
          } else {
            // Explicit length: the requested range must fit inside the buffer.
            byteLength = toLength$1($length) * BYTES;
            if (byteLength + byteOffset > $len) throw RangeError$3(WRONG_LENGTH);
          }
          length = byteLength / BYTES;
        } else if (isTypedArray$1(data)) {
          // new T(typedArray): element-wise copy.
          return fromList(TypedArrayConstructor, data);
        } else {
          // new T(iterable or array-like): delegate to the %TypedArray%.from helper.
          return call$6(typedArrayFrom$1, TypedArrayConstructor, data);
        }
        setInternalState$4(that, {
          buffer: buffer,
          byteOffset: byteOffset,
          byteLength: byteLength,
          length: length,
          view: new DataView$1(buffer)
        });
        // One accessor per element: O(length) property definitions per instance.
        while (index < length) addElement(that, index++);
      });
      if (setPrototypeOf) setPrototypeOf(TypedArrayConstructor, TypedArray);
      TypedArrayConstructorPrototype = TypedArrayConstructor.prototype = create$2(TypedArrayPrototype$1);
    } else if (TYPED_ARRAYS_CONSTRUCTORS_REQUIRES_WRAPPERS$2) {
      // Native typed arrays exist but failed the wrapper probes: wrap the native
      // constructor and normalize the arguments before calling it.
      TypedArrayConstructor = wrapper(function (dummy, data, typedArrayOffset, $length) {
        anInstance$5(dummy, TypedArrayConstructorPrototype);
        // NOTE(review): inheritIfRequired$1 presumably fixes up the prototype of the
        // native instance for subclass construction — confirm against its definition.
        return inheritIfRequired$1(function () {
          if (!isObject$4(data)) return new NativeTypedArrayConstructor(toIndex(data));
          if (isArrayBuffer(data)) return $length !== undefined ? new NativeTypedArrayConstructor(data, toOffset$1(typedArrayOffset, BYTES), $length) : typedArrayOffset !== undefined ? new NativeTypedArrayConstructor(data, toOffset$1(typedArrayOffset, BYTES)) : new NativeTypedArrayConstructor(data);
          if (isTypedArray$1(data)) return fromList(TypedArrayConstructor, data);
          return call$6(typedArrayFrom$1, TypedArrayConstructor, data);
        }(), dummy, TypedArrayConstructor);
      });
      if (setPrototypeOf) setPrototypeOf(TypedArrayConstructor, TypedArray);
      // Copy the native constructor's own static properties the wrapper lacks.
      forEach$2(getOwnPropertyNames(NativeTypedArrayConstructor), function (key) {
        if (!(key in TypedArrayConstructor)) {
          createNonEnumerableProperty$3(TypedArrayConstructor, key, NativeTypedArrayConstructor[key]);
        }
      });
      // The wrapper shares the native prototype object.
      TypedArrayConstructor.prototype = TypedArrayConstructorPrototype;
    }
    if (TypedArrayConstructorPrototype.constructor !== TypedArrayConstructor) {
      createNonEnumerableProperty$3(TypedArrayConstructorPrototype, 'constructor', TypedArrayConstructor);
    }
    enforceInternalState$1(TypedArrayConstructorPrototype).TypedArrayConstructor = TypedArrayConstructor;
    if (TYPED_ARRAY_TAG) {
      createNonEnumerableProperty$3(TypedArrayConstructorPrototype, TYPED_ARRAY_TAG, CONSTRUCTOR_NAME);
    }
    // Export as a global; forced whenever a replacement constructor was built.
    var FORCED = TypedArrayConstructor != NativeTypedArrayConstructor;
    exported[CONSTRUCTOR_NAME] = TypedArrayConstructor;
    $$f({
      global: true,
      constructor: true,
      forced: FORCED,
      sham: !NATIVE_ARRAY_BUFFER_VIEWS
    }, exported);
    if (!(BYTES_PER_ELEMENT in TypedArrayConstructor)) {
      createNonEnumerableProperty$3(TypedArrayConstructor, BYTES_PER_ELEMENT, BYTES);
    }
    if (!(BYTES_PER_ELEMENT in TypedArrayConstructorPrototype)) {
      createNonEnumerableProperty$3(TypedArrayConstructorPrototype, BYTES_PER_ELEMENT, BYTES);
    }
    setSpecies(CONSTRUCTOR_NAME);
  };
} else typedArrayConstructor$2.exports = function () {/* empty */};
// The factory assigned to `typedArrayConstructor$2.exports` (or its no-op
// fallback); each typed-array constructor registration below calls it.
var typedArrayConstructorExports = typedArrayConstructor$2.exports;
var typedArrayConstructor$1 = /*@__PURE__*/getDefaultExportFromCjs(typedArrayConstructorExports);
14339
var createTypedArrayConstructor$8 = typedArrayConstructorExports;

// `Float32Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$8('Float32', function (initialize) {
  var Float32ArrayWrapper = function Float32Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Float32ArrayWrapper;
});
14349
var es_typedArray_float64Array = {};

var createTypedArrayConstructor$7 = typedArrayConstructorExports;

// `Float64Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$7('Float64', function (initialize) {
  var Float64ArrayWrapper = function Float64Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Float64ArrayWrapper;
});
14361
var es_typedArray_int8Array = {};

var createTypedArrayConstructor$6 = typedArrayConstructorExports;

// `Int8Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$6('Int8', function (initialize) {
  var Int8ArrayWrapper = function Int8Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Int8ArrayWrapper;
});
14373
var es_typedArray_int16Array = {};

var createTypedArrayConstructor$5 = typedArrayConstructorExports;

// `Int16Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$5('Int16', function (initialize) {
  var Int16ArrayWrapper = function Int16Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Int16ArrayWrapper;
});
14385
var es_typedArray_int32Array = {};

var createTypedArrayConstructor$4 = typedArrayConstructorExports;

// `Int32Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$4('Int32', function (initialize) {
  var Int32ArrayWrapper = function Int32Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Int32ArrayWrapper;
});
14397
var es_typedArray_uint8Array = {};

var createTypedArrayConstructor$3 = typedArrayConstructorExports;

// `Uint8Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$3('Uint8', function (initialize) {
  var Uint8ArrayWrapper = function Uint8Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Uint8ArrayWrapper;
});
14409
var es_typedArray_uint8ClampedArray = {};

var createTypedArrayConstructor$2 = typedArrayConstructorExports;

// `Uint8ClampedArray` constructor: built on the 'Uint8' element type with the
// factory's CLAMPED flag set (the trailing `true`).
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$2('Uint8', function (initialize) {
  var Uint8ClampedArrayWrapper = function Uint8ClampedArray(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Uint8ClampedArrayWrapper;
}, true);
14421
var es_typedArray_uint16Array = {};

var createTypedArrayConstructor$1 = typedArrayConstructorExports;

// `Uint16Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor$1('Uint16', function (initialize) {
  var Uint16ArrayWrapper = function Uint16Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Uint16ArrayWrapper;
});
14433
var es_typedArray_uint32Array = {};

var createTypedArrayConstructor = typedArrayConstructorExports;

// `Uint32Array` constructor: registered through the shared factory, which
// supplies the `initialize` routine the named wrapper forwards to.
// https://tc39.es/ecma262/#sec-typedarray-objects
createTypedArrayConstructor('Uint32', function (initialize) {
  var Uint32ArrayWrapper = function Uint32Array(data, byteOffset, length) {
    return initialize(this, data, byteOffset, length);
  };
  return Uint32ArrayWrapper;
});
14445
var es_typedArray_at = {};

'use strict';
var ArrayBufferViewCore$t = arrayBufferViewCore;
var lengthOfArrayLike$2 = lengthOfArrayLike$t;
var toIntegerOrInfinity$1 = toIntegerOrInfinity$l;
var aTypedArray$r = ArrayBufferViewCore$t.aTypedArray;
var exportTypedArrayMethod$s = ArrayBufferViewCore$t.exportTypedArrayMethod;

// `%TypedArray%.prototype.at` method: relative indexing where negative indices
// count back from the end; out-of-range positions yield `undefined`.
// https://github.com/tc39/proposal-relative-indexing-method
exportTypedArrayMethod$s('at', function at(index) {
  var array = aTypedArray$r(this);
  var size = lengthOfArrayLike$2(array);
  var relative = toIntegerOrInfinity$1(index);
  var position = relative < 0 ? size + relative : relative;
  if (position < 0 || position >= size) return undefined;
  return array[position];
});
14464
var es_typedArray_copyWithin = {};

'use strict';
var uncurryThis$i = functionUncurryThis;
var ArrayBufferViewCore$s = arrayBufferViewCore;
var $ArrayCopyWithin = arrayCopyWithin;
var u$ArrayCopyWithin = uncurryThis$i($ArrayCopyWithin);
var aTypedArray$q = ArrayBufferViewCore$s.aTypedArray;
var exportTypedArrayMethod$r = ArrayBufferViewCore$s.exportTypedArrayMethod;

// `%TypedArray%.prototype.copyWithin` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.copywithin
// Validates the receiver, then delegates to the generic Array copyWithin.
// `end` is forwarded as `undefined` when it was not supplied.
exportTypedArrayMethod$r('copyWithin', function copyWithin(target, start /* , end */) {
  var typedArray = aTypedArray$q(this);
  var end = arguments.length > 2 ? arguments[2] : undefined;
  return u$ArrayCopyWithin(typedArray, target, start, end);
});
14480
var es_typedArray_every = {};

'use strict';
var ArrayBufferViewCore$r = arrayBufferViewCore;
var $every = arrayIteration.every;
var aTypedArray$p = ArrayBufferViewCore$r.aTypedArray;
var exportTypedArrayMethod$q = ArrayBufferViewCore$r.exportTypedArrayMethod;

// `%TypedArray%.prototype.every` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.every
// Validates the receiver and delegates to the shared `every` iteration helper.
exportTypedArrayMethod$q('every', function every(callbackfn /* , thisArg */) {
  var typedArray = aTypedArray$p(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  return $every(typedArray, callbackfn, thisArg);
});
14494
var es_typedArray_fill = {};

'use strict';
var ArrayBufferViewCore$q = arrayBufferViewCore;
var $fill = arrayFill$1;
var toBigInt$1 = toBigInt$3;
var classof$2 = classof$m;
var call$5 = functionCall;
var uncurryThis$h = functionUncurryThis;
var fails$d = fails$1m;
var aTypedArray$o = ArrayBufferViewCore$q.aTypedArray;
var exportTypedArrayMethod$p = ArrayBufferViewCore$q.exportTypedArrayMethod;
var slice$3 = uncurryThis$h(''.slice);

// V8 ~ Chrome < 59, Safari < 14.1, FF < 55, Edge <=18
// Flags engines whose native `fill` does not coerce the fill value exactly once.
var CONVERSION_BUG = fails$d(function () {
  var valueOfCalls = 0;
  // eslint-disable-next-line es/no-typed-arrays -- safe
  new Int8Array(2).fill({
    valueOf: function valueOf() {
      return valueOfCalls++;
    }
  });
  return valueOfCalls !== 1;
});

// `%TypedArray%.prototype.fill` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.fill
// Coerces `value` once, up front (to BigInt for classes whose classof starts with
// 'Big', otherwise to Number), then runs the generic fill with that value.
exportTypedArrayMethod$p('fill', function fill(value /* , start, end */) {
  var argc = arguments.length;
  aTypedArray$o(this);
  var isBigIntKind = slice$3(classof$2(this), 0, 3) === 'Big';
  var actualValue = isBigIntKind ? toBigInt$1(value) : +value;
  var start = argc > 1 ? arguments[1] : undefined;
  var end = argc > 2 ? arguments[2] : undefined;
  return call$5($fill, this, actualValue, start, end);
}, CONVERSION_BUG);
14529
var es_typedArray_filter = {};

var ArrayBufferViewCore$p = arrayBufferViewCore;
var speciesConstructor = speciesConstructor$6;
var aTypedArrayConstructor$1 = ArrayBufferViewCore$p.aTypedArrayConstructor;
var getTypedArrayConstructor$3 = ArrayBufferViewCore$p.getTypedArrayConstructor;

// a part of `TypedArraySpeciesCreate` abstract operation
// https://tc39.es/ecma262/#typedarray-species-create
// Resolves the @@species constructor of `originalArray` (falling back to its own
// typed-array constructor) and checks that the result is a typed-array constructor.
var typedArraySpeciesConstructor$4 = function typedArraySpeciesConstructor(originalArray) {
  var defaultConstructor = getTypedArrayConstructor$3(originalArray);
  var C = speciesConstructor(originalArray, defaultConstructor);
  return aTypedArrayConstructor$1(C);
};
var typedArraySpeciesConstructor$5 = /*@__PURE__*/getDefaultExportFromCjs(typedArraySpeciesConstructor$4);

var arrayFromConstructorAndList$1 = arrayFromConstructorAndList$3;
var typedArraySpeciesConstructor$3 = typedArraySpeciesConstructor$4;
// Builds a new typed array of `instance`'s species, filled from `list`.
var typedArrayFromSpeciesAndList = function typedArrayFromSpeciesAndList(instance, list) {
  var Constructor = typedArraySpeciesConstructor$3(instance);
  return arrayFromConstructorAndList$1(Constructor, list);
};
var typedArrayFromSpeciesAndList$1 = /*@__PURE__*/getDefaultExportFromCjs(typedArrayFromSpeciesAndList);
14550
'use strict';
var ArrayBufferViewCore$o = arrayBufferViewCore;
var $filter = arrayIteration.filter;
var fromSpeciesAndList = typedArrayFromSpeciesAndList;
var aTypedArray$n = ArrayBufferViewCore$o.aTypedArray;
var exportTypedArrayMethod$o = ArrayBufferViewCore$o.exportTypedArrayMethod;

// `%TypedArray%.prototype.filter` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.filter
// Collects the accepted elements into a plain list, then copies them into a new
// typed array created via the receiver's species constructor.
exportTypedArrayMethod$o('filter', function filter(callbackfn /* , thisArg */) {
  var typedArray = aTypedArray$n(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  var selected = $filter(typedArray, callbackfn, thisArg);
  return fromSpeciesAndList(this, selected);
});
14564
var es_typedArray_find = {};

'use strict';
var ArrayBufferViewCore$n = arrayBufferViewCore;
var $find = arrayIteration.find;
var aTypedArray$m = ArrayBufferViewCore$n.aTypedArray;
var exportTypedArrayMethod$n = ArrayBufferViewCore$n.exportTypedArrayMethod;

// `%TypedArray%.prototype.find` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.find
// Validates the receiver and delegates to the shared `find` iteration helper.
exportTypedArrayMethod$n('find', function find(predicate /* , thisArg */) {
  var typedArray = aTypedArray$m(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  return $find(typedArray, predicate, thisArg);
});
14578
var es_typedArray_findIndex = {};

'use strict';
var ArrayBufferViewCore$m = arrayBufferViewCore;
var $findIndex = arrayIteration.findIndex;
var aTypedArray$l = ArrayBufferViewCore$m.aTypedArray;
var exportTypedArrayMethod$m = ArrayBufferViewCore$m.exportTypedArrayMethod;

// `%TypedArray%.prototype.findIndex` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.findindex
// Validates the receiver and delegates to the shared `findIndex` iteration helper.
exportTypedArrayMethod$m('findIndex', function findIndex(predicate /* , thisArg */) {
  var typedArray = aTypedArray$l(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  return $findIndex(typedArray, predicate, thisArg);
});
14592
var es_typedArray_findLast = {};

'use strict';
var ArrayBufferViewCore$l = arrayBufferViewCore;
var $findLast = arrayIterationFromLast.findLast;
var aTypedArray$k = ArrayBufferViewCore$l.aTypedArray;
var exportTypedArrayMethod$l = ArrayBufferViewCore$l.exportTypedArrayMethod;

// `%TypedArray%.prototype.findLast` method
// https://github.com/tc39/proposal-array-find-from-last
// Validates the receiver and delegates to the shared from-the-end search helper.
exportTypedArrayMethod$l('findLast', function findLast(predicate /* , thisArg */) {
  var typedArray = aTypedArray$k(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  return $findLast(typedArray, predicate, thisArg);
});
14606
var es_typedArray_findLastIndex = {};

'use strict';
var ArrayBufferViewCore$k = arrayBufferViewCore;
var $findLastIndex = arrayIterationFromLast.findLastIndex;
var aTypedArray$j = ArrayBufferViewCore$k.aTypedArray;
var exportTypedArrayMethod$k = ArrayBufferViewCore$k.exportTypedArrayMethod;

// `%TypedArray%.prototype.findLastIndex` method
// https://github.com/tc39/proposal-array-find-from-last
// Validates the receiver and delegates to the shared from-the-end index search.
exportTypedArrayMethod$k('findLastIndex', function findLastIndex(predicate /* , thisArg */) {
  var typedArray = aTypedArray$j(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  return $findLastIndex(typedArray, predicate, thisArg);
});
14620
var es_typedArray_forEach = {};

'use strict';
var ArrayBufferViewCore$j = arrayBufferViewCore;
var $forEach = arrayIteration.forEach;
var aTypedArray$i = ArrayBufferViewCore$j.aTypedArray;
var exportTypedArrayMethod$j = ArrayBufferViewCore$j.exportTypedArrayMethod;

// `%TypedArray%.prototype.forEach` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.foreach
// Validates the receiver, runs the shared `forEach` helper and returns undefined.
exportTypedArrayMethod$j('forEach', function forEach(callbackfn /* , thisArg */) {
  var typedArray = aTypedArray$i(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  $forEach(typedArray, callbackfn, thisArg);
});
14634
var es_typedArray_from = {};

'use strict';
var TYPED_ARRAYS_CONSTRUCTORS_REQUIRES_WRAPPERS$1 = typedArrayConstructorsRequireWrappers;
var exportTypedArrayStaticMethod$1 = arrayBufferViewCore.exportTypedArrayStaticMethod;
var typedArrayFrom = typedArrayFrom$2;

// `%TypedArray%.from` method
// https://tc39.es/ecma262/#sec-%typedarray%.from
// Installs the shared `typedArrayFrom` implementation as a static method. The
// third argument is the "constructors require wrappers" flag; presumably
// exportTypedArrayStaticMethod (defined outside this chunk) uses it to force
// replacing the native method — confirm there.
exportTypedArrayStaticMethod$1('from', typedArrayFrom, TYPED_ARRAYS_CONSTRUCTORS_REQUIRES_WRAPPERS$1);
14645
var es_typedArray_includes = {};

'use strict';
var ArrayBufferViewCore$i = arrayBufferViewCore;
var $includes = arrayIncludes.includes;
var aTypedArray$h = ArrayBufferViewCore$i.aTypedArray;
var exportTypedArrayMethod$i = ArrayBufferViewCore$i.exportTypedArrayMethod;

// `%TypedArray%.prototype.includes` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.includes
// Validates the receiver and delegates to the shared `includes` helper.
exportTypedArrayMethod$i('includes', function includes(searchElement /* , fromIndex */) {
  var typedArray = aTypedArray$h(this);
  var fromIndex = arguments.length > 1 ? arguments[1] : undefined;
  return $includes(typedArray, searchElement, fromIndex);
});
14659
var es_typedArray_indexOf = {};

'use strict';
var ArrayBufferViewCore$h = arrayBufferViewCore;
var $indexOf = arrayIncludes.indexOf;
var aTypedArray$g = ArrayBufferViewCore$h.aTypedArray;
var exportTypedArrayMethod$h = ArrayBufferViewCore$h.exportTypedArrayMethod;

// `%TypedArray%.prototype.indexOf` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.indexof
// Validates the receiver and delegates to the shared `indexOf` helper.
exportTypedArrayMethod$h('indexOf', function indexOf(searchElement /* , fromIndex */) {
  var typedArray = aTypedArray$g(this);
  var fromIndex = arguments.length > 1 ? arguments[1] : undefined;
  return $indexOf(typedArray, searchElement, fromIndex);
});
14673
var es_typedArray_iterator = {};

'use strict';
var global$m = global$Z;
var fails$c = fails$1m;
var uncurryThis$g = functionUncurryThis;
var ArrayBufferViewCore$g = arrayBufferViewCore;
var ArrayIterators = es_array_iterator;
var wellKnownSymbol$3 = wellKnownSymbol$z;
// The well-known @@iterator symbol.
var ITERATOR$3 = wellKnownSymbol$3('iterator');
var Uint8Array$2 = global$m.Uint8Array;
// Array iterator factories wrapped by uncurryThis$g, so the receiver is passed
// as the first argument (see the `arrayValues(aTypedArray$f(this))` calls below).
var arrayValues = uncurryThis$g(ArrayIterators.values);
var arrayKeys = uncurryThis$g(ArrayIterators.keys);
var arrayEntries = uncurryThis$g(ArrayIterators.entries);
var aTypedArray$f = ArrayBufferViewCore$g.aTypedArray;
var exportTypedArrayMethod$g = ArrayBufferViewCore$g.exportTypedArrayMethod;
// Prototype reached through Uint8Array; falsy when Uint8Array is unavailable.
var TypedArrayPrototype = Uint8Array$2 && Uint8Array$2.prototype;
// GENERIC: the native @@iterator can be called on a plain array without throwing,
// i.e. it does not validate its receiver.
// NOTE(review): assumes fails$c returns true when its callback throws — helper
// is defined outside this chunk.
var GENERIC = !fails$c(function () {
  TypedArrayPrototype[ITERATOR$3].call([1]);
});
// ITERATOR_IS_VALUES: @@iterator and `values` are the same function, and that
// function is named 'values'.
var ITERATOR_IS_VALUES = !!TypedArrayPrototype && TypedArrayPrototype.values && TypedArrayPrototype[ITERATOR$3] === TypedArrayPrototype.values && TypedArrayPrototype.values.name === 'values';
// One implementation shared by `values` and @@iterator below.
var typedArrayValues = function values() {
  return arrayValues(aTypedArray$f(this));
};

// `%TypedArray%.prototype.entries` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.entries
// The third argument of exportTypedArrayMethod$g is a "forced" flag; presumably
// the native method is replaced when it is true — confirm in arrayBufferViewCore.
exportTypedArrayMethod$g('entries', function entries() {
  return arrayEntries(aTypedArray$f(this));
}, GENERIC);
// `%TypedArray%.prototype.keys` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.keys
exportTypedArrayMethod$g('keys', function keys() {
  return arrayKeys(aTypedArray$f(this));
}, GENERIC);
// `%TypedArray%.prototype.values` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.values
// Also forced when @@iterator is not the native `values`, so both stay in sync.
exportTypedArrayMethod$g('values', typedArrayValues, GENERIC || !ITERATOR_IS_VALUES, {
  name: 'values'
});
// `%TypedArray%.prototype[@@iterator]` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype-@@iterator
// Installs the very same function object as `values`.
exportTypedArrayMethod$g(ITERATOR$3, typedArrayValues, GENERIC || !ITERATOR_IS_VALUES, {
  name: 'values'
});
14719
var es_typedArray_join = {};

'use strict';
var ArrayBufferViewCore$f = arrayBufferViewCore;
var uncurryThis$f = functionUncurryThis;
var aTypedArray$e = ArrayBufferViewCore$f.aTypedArray;
var exportTypedArrayMethod$f = ArrayBufferViewCore$f.exportTypedArrayMethod;
var $join = uncurryThis$f([].join);

// `%TypedArray%.prototype.join` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.join
// Validates the receiver and reuses Array.prototype.join on it.
exportTypedArrayMethod$f('join', function join(separator) {
  var typedArray = aTypedArray$e(this);
  return $join(typedArray, separator);
});
14734
var es_typedArray_lastIndexOf = {};

'use strict';
var ArrayBufferViewCore$e = arrayBufferViewCore;
var apply$2 = functionApply$1;
var $lastIndexOf = arrayLastIndexOf;
var aTypedArray$d = ArrayBufferViewCore$e.aTypedArray;
var exportTypedArrayMethod$e = ArrayBufferViewCore$e.exportTypedArrayMethod;

// `%TypedArray%.prototype.lastIndexOf` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.lastindexof
// `fromIndex` is forwarded only when the caller actually passed it, because
// lastIndexOf treats an absent fromIndex differently from an explicit undefined.
exportTypedArrayMethod$e('lastIndexOf', function lastIndexOf(searchElement /* , fromIndex */) {
  var typedArray = aTypedArray$d(this);
  var args = arguments.length > 1 ? [searchElement, arguments[1]] : [searchElement];
  return apply$2($lastIndexOf, typedArray, args);
});
14750
var es_typedArray_map = {};

'use strict';
var ArrayBufferViewCore$d = arrayBufferViewCore;
var $map = arrayIteration.map;
var typedArraySpeciesConstructor$2 = typedArraySpeciesConstructor$4;
var aTypedArray$c = ArrayBufferViewCore$d.aTypedArray;
var exportTypedArrayMethod$d = ArrayBufferViewCore$d.exportTypedArrayMethod;

// `%TypedArray%.prototype.map` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.map
// The shared `map` helper receives a factory so the result is allocated with
// the receiver's species constructor instead of a plain array.
exportTypedArrayMethod$d('map', function map(mapfn /* , thisArg */) {
  var typedArray = aTypedArray$c(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  var createResult = function (O, length) {
    var Constructor = typedArraySpeciesConstructor$2(O);
    return new Constructor(length);
  };
  return $map(typedArray, mapfn, thisArg, createResult);
});
14767
var es_typedArray_of = {};

'use strict';
var ArrayBufferViewCore$c = arrayBufferViewCore;
var TYPED_ARRAYS_CONSTRUCTORS_REQUIRES_WRAPPERS = typedArrayConstructorsRequireWrappers;
var aTypedArrayConstructor = ArrayBufferViewCore$c.aTypedArrayConstructor;
var exportTypedArrayStaticMethod = ArrayBufferViewCore$c.exportTypedArrayStaticMethod;

// `%TypedArray%.of` method
// https://tc39.es/ecma262/#sec-%typedarray%.of
// Allocates a typed array of `this` constructor sized to the argument count and
// copies every argument into it in order.
exportTypedArrayStaticMethod('of', function of(/* ...items */) {
  var count = arguments.length;
  var Constructor = aTypedArrayConstructor(this);
  var result = new Constructor(count);
  for (var position = 0; position < count; position++) {
    result[position] = arguments[position];
  }
  return result;
}, TYPED_ARRAYS_CONSTRUCTORS_REQUIRES_WRAPPERS);
14786
var es_typedArray_reduce = {};

'use strict';
var ArrayBufferViewCore$b = arrayBufferViewCore;
var $reduce = arrayReduce.left;
var aTypedArray$b = ArrayBufferViewCore$b.aTypedArray;
var exportTypedArrayMethod$c = ArrayBufferViewCore$b.exportTypedArrayMethod;

// `%TypedArray%.prototype.reduce` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.reduce
// The raw argument count is forwarded so the helper can tell "no initial value"
// apart from an explicit `undefined` initial value.
exportTypedArrayMethod$c('reduce', function reduce(callbackfn /* , initialValue */) {
  var typedArray = aTypedArray$b(this);
  var argumentsLength = arguments.length;
  var initialValue = argumentsLength > 1 ? arguments[1] : undefined;
  return $reduce(typedArray, callbackfn, argumentsLength, initialValue);
});
14801
var es_typedArray_reduceRight = {};

'use strict';
var ArrayBufferViewCore$a = arrayBufferViewCore;
var $reduceRight = arrayReduce.right;
var aTypedArray$a = ArrayBufferViewCore$a.aTypedArray;
var exportTypedArrayMethod$b = ArrayBufferViewCore$a.exportTypedArrayMethod;

// `%TypedArray%.prototype.reduceRight` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.reduceright
// The raw argument count is forwarded so the helper can tell "no initial value"
// apart from an explicit `undefined` initial value.
exportTypedArrayMethod$b('reduceRight', function reduceRight(callbackfn /* , initialValue */) {
  var typedArray = aTypedArray$a(this);
  var argumentsLength = arguments.length;
  var initialValue = argumentsLength > 1 ? arguments[1] : undefined;
  return $reduceRight(typedArray, callbackfn, argumentsLength, initialValue);
});
14816
var es_typedArray_reverse = {};

'use strict';
var ArrayBufferViewCore$9 = arrayBufferViewCore;
var aTypedArray$9 = ArrayBufferViewCore$9.aTypedArray;
var exportTypedArrayMethod$a = ArrayBufferViewCore$9.exportTypedArrayMethod;
var floor$5 = Math.floor;

// `%TypedArray%.prototype.reverse` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.reverse
// Reverses in place by swapping mirrored pairs from both ends towards the middle,
// and returns the receiver.
exportTypedArrayMethod$a('reverse', function reverse() {
  var that = this;
  var length = aTypedArray$9(that).length;
  var middle = floor$5(length / 2);
  var last = length - 1;
  for (var first = 0; first < middle; first++, last--) {
    var saved = that[first];
    that[first] = that[last];
    that[last] = saved;
  }
  return that;
});
14840
var es_typedArray_set = {};

'use strict';
var global$l = global$Z;
var call$4 = functionCall;
var ArrayBufferViewCore$8 = arrayBufferViewCore;
var lengthOfArrayLike$1 = lengthOfArrayLike$t;
var toOffset = toOffset$2;
// NOTE: despite its name, this aliases `toObject$t` (object coercion of the source).
var toIndexedObject = toObject$t;
var fails$b = fails$1m;
var RangeError$2 = global$l.RangeError;
var Int8Array$2 = global$l.Int8Array;
var Int8ArrayPrototype = Int8Array$2 && Int8Array$2.prototype;
// Native %TypedArray%.prototype.set reached through Int8Array; falsy without typed arrays.
var $set = Int8ArrayPrototype && Int8ArrayPrototype.set;
var aTypedArray$8 = ArrayBufferViewCore$8.aTypedArray;
var exportTypedArrayMethod$9 = ArrayBufferViewCore$8.exportTypedArrayMethod;
// True when the native `set` accepts a plain array-like object as the source and
// honours the offset (index 1 must become 3).
// NOTE(review): assumes fails$b reports a throw or a truthy return as failure.
var WORKS_WITH_OBJECTS_AND_GENERIC_ON_TYPED_ARRAYS = !fails$b(function () {
  // eslint-disable-next-line es/no-typed-arrays -- required for testing
  var array = new Uint8ClampedArray(2);
  call$4($set, array, {
    length: 1,
    0: 3
  }, 1);
  return array[1] !== 3;
});

// https://bugs.chromium.org/p/v8/issues/detail?id=11294 and other
// Only checked when the test above passed and native views exist: flags engines
// that mishandle primitive sources (a number, then the string '2' at offset 1).
var TO_OBJECT_BUG = WORKS_WITH_OBJECTS_AND_GENERIC_ON_TYPED_ARRAYS && ArrayBufferViewCore$8.NATIVE_ARRAY_BUFFER_VIEWS && fails$b(function () {
  var array = new Int8Array$2(2);
  array.set(1);
  array.set('2', 1);
  return array[0] !== 0 || array[1] !== 2;
});

// `%TypedArray%.prototype.set` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.set
exportTypedArrayMethod$9('set', function set(arrayLike /* , offset */) {
  aTypedArray$8(this);
  var offset = toOffset(arguments.length > 1 ? arguments[1] : undefined, 1);
  // Coerce the source up front so primitive sources behave like their wrapper objects.
  var src = toIndexedObject(arrayLike);
  // Native path: the engine handles object sources correctly, so delegate to it
  // with the already-coerced source.
  if (WORKS_WITH_OBJECTS_AND_GENERIC_ON_TYPED_ARRAYS) return call$4($set, this, src, offset);
  // Fallback: bounds-check, then copy element by element.
  var length = this.length;
  var len = lengthOfArrayLike$1(src);
  var index = 0;
  if (len + offset > length) throw RangeError$2('Wrong length');
  while (index < len) this[offset + index] = src[index++];
}, !WORKS_WITH_OBJECTS_AND_GENERIC_ON_TYPED_ARRAYS || TO_OBJECT_BUG);
14888
var es_typedArray_slice = {};

'use strict';
var ArrayBufferViewCore$7 = arrayBufferViewCore;
var typedArraySpeciesConstructor$1 = typedArraySpeciesConstructor$4;
var fails$a = fails$1m;
var arraySlice$3 = arraySlice$a;
var aTypedArray$7 = ArrayBufferViewCore$7.aTypedArray;
var exportTypedArrayMethod$8 = ArrayBufferViewCore$7.exportTypedArrayMethod;
// Forces the polyfill when the native slice throws even for a trivial call.
var FORCED$1 = fails$a(function () {
  // eslint-disable-next-line es/no-typed-arrays -- required for testing
  new Int8Array(1).slice();
});

// `%TypedArray%.prototype.slice` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.slice
// Slices into a plain list first, then copies it into a species-constructed result.
exportTypedArrayMethod$8('slice', function slice(start, end) {
  var list = arraySlice$3(aTypedArray$7(this), start, end);
  var C = typedArraySpeciesConstructor$1(this);
  var count = list.length;
  var result = new C(count);
  for (var position = 0; position < count; position++) {
    result[position] = list[position];
  }
  return result;
}, FORCED$1);
14914
var es_typedArray_some = {};

'use strict';
var ArrayBufferViewCore$6 = arrayBufferViewCore;
var $some = arrayIteration.some;
var aTypedArray$6 = ArrayBufferViewCore$6.aTypedArray;
var exportTypedArrayMethod$7 = ArrayBufferViewCore$6.exportTypedArrayMethod;

// `%TypedArray%.prototype.some` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.some
// Validates the receiver and delegates to the shared `some` iteration helper.
exportTypedArrayMethod$7('some', function some(callbackfn /* , thisArg */) {
  var typedArray = aTypedArray$6(this);
  var thisArg = arguments.length > 1 ? arguments[1] : undefined;
  return $some(typedArray, callbackfn, thisArg);
});
14928
var es_typedArray_sort = {};

'use strict';
var global$k = global$Z;
var uncurryThis$e = functionUncurryThisClause;
var fails$9 = fails$1m;
var aCallable$2 = aCallable$l;
var internalSort = arraySort$1;
var ArrayBufferViewCore$5 = arrayBufferViewCore;
var FF = engineFfVersion;
var IE_OR_EDGE = engineIsIeOrEdge;
var V8$1 = engineV8Version;
var WEBKIT = engineWebkitVersion;
var aTypedArray$5 = ArrayBufferViewCore$5.aTypedArray;
var exportTypedArrayMethod$6 = ArrayBufferViewCore$5.exportTypedArrayMethod;
var Uint16Array$1 = global$k.Uint16Array;
// Native %TypedArray%.prototype.sort, reached through Uint16Array (falsy if absent).
var nativeSort = Uint16Array$1 && uncurryThis$e(Uint16Array$1.prototype.sort);

// WebKit
// True when the native sort does NOT reject at least one of `null` or `{}` as a
// comparator (a correct engine throws for both).
// NOTE(review): assumes fails$9 returns true when its callback throws.
var ACCEPT_INCORRECT_ARGUMENTS = !!nativeSort && !(fails$9(function () {
  nativeSort(new Uint16Array$1(2), null);
}) && fails$9(function () {
  nativeSort(new Uint16Array$1(2), {});
}));
// True when the native sort is considered stable. The callback returns a truthy
// value for known-unstable engines, or when a 516-element sort that groups
// elements by `value / 4` does not keep equal-key elements in their original order.
var STABLE_SORT = !!nativeSort && !fails$9(function () {
  // feature detection can be too slow, so check engines versions
  if (V8$1) return V8$1 < 74;
  if (FF) return FF < 67;
  if (IE_OR_EDGE) return true;
  if (WEBKIT) return WEBKIT < 602;
  var array = new Uint16Array$1(516);
  var expected = Array(516);
  var index, mod;
  for (index = 0; index < 516; index++) {
    mod = index % 4;
    array[index] = 515 - index;
    expected[index] = index - 2 * mod + 3;
  }
  nativeSort(array, function (a, b) {
    return (a / 4 | 0) - (b / 4 | 0);
  });
  for (index = 0; index < 516; index++) {
    if (array[index] !== expected[index]) return true;
  }
});
// Builds the comparator used by the fallback (non-native) typed-array sort.
// - With a user `comparefn`: its result is coerced to a number and NaN maps to 0.
// - Without one: ascending numeric order in which NaN sorts after every other
//   value and -0 sorts before +0.
// The default branch returns -1 / 0 / 1 instead of a boolean (`x > y`), so the
// comparator honours the standard sign contract and works with any sort routine,
// not only one that tests `result > 0`.
var getSortCompare = function getSortCompare(comparefn) {
  return function (x, y) {
    if (comparefn !== undefined) return +comparefn(x, y) || 0;
    // eslint-disable-next-line no-self-compare -- NaN check
    if (y !== y) return -1;
    // eslint-disable-next-line no-self-compare -- NaN check
    if (x !== x) return 1;
    // Distinguish -0 from +0 through the sign of 1 / x.
    if (x === 0 && y === 0) return 1 / x > 0 && 1 / y < 0 ? 1 : -1;
    if (x > y) return 1;
    return x < y ? -1 : 0;
  };
};
14985
// `%TypedArray%.prototype.sort` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.sort
// A supplied comparator must be callable. Engines whose native sort passed the
// stability check use it directly; otherwise the internal merge sort runs with
// the comparator produced by getSortCompare.
exportTypedArrayMethod$6('sort', function sort(comparefn) {
  if (comparefn !== undefined) aCallable$2(comparefn);
  return STABLE_SORT
    ? nativeSort(this, comparefn)
    : internalSort(aTypedArray$5(this), getSortCompare(comparefn));
}, !STABLE_SORT || ACCEPT_INCORRECT_ARGUMENTS);
14993
var es_typedArray_subarray = {};

'use strict';
var ArrayBufferViewCore$4 = arrayBufferViewCore;
var toLength = toLength$d;
var toAbsoluteIndex = toAbsoluteIndex$a;
var typedArraySpeciesConstructor = typedArraySpeciesConstructor$4;
var aTypedArray$4 = ArrayBufferViewCore$4.aTypedArray;
var exportTypedArrayMethod$5 = ArrayBufferViewCore$4.exportTypedArrayMethod;

// `%TypedArray%.prototype.subarray` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.subarray
// Creates a species-constructed view over the same buffer. The buffer and byte
// offset are read before `end` is coerced, matching the original argument order.
exportTypedArrayMethod$5('subarray', function subarray(begin, end) {
  var O = aTypedArray$4(this);
  var length = O.length;
  var beginIndex = toAbsoluteIndex(begin, length);
  var C = typedArraySpeciesConstructor(O);
  var buffer = O.buffer;
  var byteOffset = O.byteOffset + beginIndex * O.BYTES_PER_ELEMENT;
  var endIndex = end === undefined ? length : toAbsoluteIndex(end, length);
  return new C(buffer, byteOffset, toLength(endIndex - beginIndex));
});
15013
var es_typedArray_toLocaleString = {};

'use strict';
var global$j = global$Z;
var apply$1 = functionApply$1;
var ArrayBufferViewCore$3 = arrayBufferViewCore;
var fails$8 = fails$1m;
var arraySlice$2 = arraySlice$a;
var Int8Array$1 = global$j.Int8Array;
var aTypedArray$3 = ArrayBufferViewCore$3.aTypedArray;
var exportTypedArrayMethod$4 = ArrayBufferViewCore$3.exportTypedArrayMethod;
var $toLocaleString = [].toLocaleString;

// iOS Safari 6.x fails here
// True when Array.prototype.toLocaleString cannot be applied to a typed array;
// the polyfill then runs it on a plain-array copy instead.
var TO_LOCALE_STRING_BUG = !!Int8Array$1 && fails$8(function () {
  $toLocaleString.call(new Int8Array$1(1));
});
// Force the polyfill when the native output differs from Array's output, or when
// the native method does not reject a plain-array receiver.
// Strict comparison: both operands are strings, so `!==` is the idiomatic form.
var FORCED = fails$8(function () {
  return [1, 2].toLocaleString() !== new Int8Array$1([1, 2]).toLocaleString();
}) || !fails$8(function () {
  Int8Array$1.prototype.toLocaleString.call([1, 2]);
});

// `%TypedArray%.prototype.toLocaleString` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.tolocalestring
// Applies Array.prototype.toLocaleString to the validated receiver (or to a copy
// of it on buggy engines), forwarding all locale arguments.
exportTypedArrayMethod$4('toLocaleString', function toLocaleString() {
  return apply$1($toLocaleString, TO_LOCALE_STRING_BUG ? arraySlice$2(aTypedArray$3(this)) : aTypedArray$3(this), arraySlice$2(arguments));
}, FORCED);
15042
var es_typedArray_toReversed = {};

'use strict';
var arrayToReversed = arrayToReversed$2;
var ArrayBufferViewCore$2 = arrayBufferViewCore;
var aTypedArray$2 = ArrayBufferViewCore$2.aTypedArray;
var exportTypedArrayMethod$3 = ArrayBufferViewCore$2.exportTypedArrayMethod;
var getTypedArrayConstructor$2 = ArrayBufferViewCore$2.getTypedArrayConstructor;

// `%TypedArray%.prototype.toReversed` method
// https://tc39.es/proposal-change-array-by-copy/#sec-%typedarray%.prototype.toReversed
// Returns a reversed copy built with the receiver's own typed-array constructor.
exportTypedArrayMethod$3('toReversed', function toReversed() {
  var typedArray = aTypedArray$2(this);
  var Constructor = getTypedArrayConstructor$2(this);
  return arrayToReversed(typedArray, Constructor);
});
15057
var es_typedArray_toSorted = {};

'use strict';
var ArrayBufferViewCore$1 = arrayBufferViewCore;
var uncurryThis$d = functionUncurryThis;
var aCallable$1 = aCallable$l;
var arrayFromConstructorAndList = arrayFromConstructorAndList$3;
var aTypedArray$1 = ArrayBufferViewCore$1.aTypedArray;
var getTypedArrayConstructor$1 = ArrayBufferViewCore$1.getTypedArrayConstructor;
var exportTypedArrayMethod$2 = ArrayBufferViewCore$1.exportTypedArrayMethod;
var sort = uncurryThis$d(ArrayBufferViewCore$1.TypedArrayPrototype.sort);

// `%TypedArray%.prototype.toSorted` method
// https://tc39.es/proposal-change-array-by-copy/#sec-%typedarray%.prototype.toSorted
// Validates the comparator first, copies the receiver into a fresh typed array of
// the same constructor, then sorts and returns that copy.
exportTypedArrayMethod$2('toSorted', function toSorted(compareFn) {
  var O, Constructor, copy;
  if (compareFn !== undefined) aCallable$1(compareFn);
  O = aTypedArray$1(this);
  Constructor = getTypedArrayConstructor$1(O);
  copy = arrayFromConstructorAndList(Constructor, O);
  return sort(copy, compareFn);
});
15078
var es_typedArray_toString = {};

'use strict';
var exportTypedArrayMethod$1 = arrayBufferViewCore.exportTypedArrayMethod;
var fails$7 = fails$1m;
var global$i = global$Z;
var uncurryThis$c = functionUncurryThis;
var Uint8Array$1 = global$i.Uint8Array;
var Uint8ArrayPrototype = Uint8Array$1 && Uint8Array$1.prototype || {};
var arrayToString = [].toString;
var join$3 = uncurryThis$c([].join);
// Engines whose Array.prototype.toString cannot handle a non-array receiver get
// a join-based replacement.
if (fails$7(function () {
  arrayToString.call({});
})) {
  arrayToString = function toString() {
    return join$3(this);
  };
}
// Export only when the typed-array toString is not already the (possibly
// replaced) Array toString. This is a function identity check, so strict
// inequality is used.
var IS_NOT_ARRAY_METHOD = Uint8ArrayPrototype.toString !== arrayToString;

// `%TypedArray%.prototype.toString` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.tostring
exportTypedArrayMethod$1('toString', arrayToString, IS_NOT_ARRAY_METHOD);
15102
var es_typedArray_with = {};

'use strict';
var arrayWith = arrayWith$2;
var ArrayBufferViewCore = arrayBufferViewCore;
var isBigIntArray = isBigIntArray$2;
var toIntegerOrInfinity = toIntegerOrInfinity$l;
var toBigInt = toBigInt$3;
var aTypedArray = ArrayBufferViewCore.aTypedArray;
var getTypedArrayConstructor = ArrayBufferViewCore.getTypedArrayConstructor;
var exportTypedArrayMethod = ArrayBufferViewCore.exportTypedArrayMethod;
// True when the native `with` coerces `value` before range-checking `index`:
// the out-of-range call must surface the error thrown by valueOf.
var PROPER_ORDER = !!function () {
  var marker = 8;
  try {
    // eslint-disable-next-line no-throw-literal, es/no-typed-arrays, es/no-array-prototype-with -- required for testing
    new Int8Array(1)['with'](2, {
      valueOf: function valueOf() {
        throw marker;
      }
    });
  } catch (error) {
    // some early implementations, like WebKit, does not follow the final semantic
    // https://github.com/tc39/proposal-change-array-by-copy/pull/86
    return error === marker;
  }
}();

// `%TypedArray%.prototype.with` method
// https://tc39.es/proposal-change-array-by-copy/#sec-%typedarray%.prototype.with
// Coerces the index and then the value (BigInt or Number, depending on the array)
// before delegating, so conversion errors surface before any range check.
exportTypedArrayMethod('with', function _with(index, value) {
  var O = aTypedArray(this);
  var relativeIndex = toIntegerOrInfinity(index);
  var actualValue = isBigIntArray(O) ? toBigInt(value) : +value;
  var Constructor = getTypedArrayConstructor(O);
  return arrayWith(O, Constructor, relativeIndex, actualValue);
}, !PROPER_ORDER);
15139
var es_unescape = {};

'use strict';
var $$e = _export;
var uncurryThis$b = functionUncurryThis;
var toString$2 = toString$x;
var fromCharCode$2 = String.fromCharCode;
var charAt$4 = uncurryThis$b(''.charAt);
var exec$3 = uncurryThis$b(/./.exec);
var stringSlice$2 = uncurryThis$b(''.slice);
var hex2 = /^[\da-f]{2}$/i;
var hex4 = /^[\da-f]{4}$/i;

// `unescape` method
// https://tc39.es/ecma262/#sec-unescape-string
// Decodes `%XX` and `%uXXXX` escapes. A `%` that does not begin a valid escape
// is copied through unchanged.
$$e({
  global: true
}, {
  unescape: function unescape(string) {
    var str = toString$2(string);
    var length = str.length;
    var result = '';
    var index = 0;
    while (index < length) {
      var chr = charAt$4(str, index++);
      if (chr === '%') {
        var isUnicode = charAt$4(str, index) === 'u';
        var hexDigits = isUnicode
          ? stringSlice$2(str, index + 1, index + 5)
          : stringSlice$2(str, index, index + 2);
        if (exec$3(isUnicode ? hex4 : hex2, hexDigits)) {
          result += fromCharCode$2(parseInt(hexDigits, 16));
          index += isUnicode ? 5 : 2;
          continue;
        }
      }
      result += chr;
    }
    return result;
  }
});
15188
var es_weakMap = {};

var es_weakMap_constructor = {};

'use strict';
// Module-level aliases for the weak-collection helpers that follow. Several of
// them (anInstance$4, setInternalState$3, isNullOrUndefined$1, iterate$1, id)
// are used by collectionWeak$2.getConstructor further down, so the names must
// stay exactly as they are.
var uncurryThis$a = functionUncurryThis;
var defineBuiltIns$2 = defineBuiltIns$5;
var getWeakData = internalMetadataExports.getWeakData;
var anInstance$4 = anInstance$a;
var anObject$3 = anObject$D;
var isNullOrUndefined$1 = isNullOrUndefined$e;
var isObject$3 = isObject$z;
var iterate$1 = iterate$a;
var ArrayIterationModule = arrayIteration;
var hasOwn$6 = hasOwnProperty_1;
var InternalStateModule$3 = internalState;
var setInternalState$3 = InternalStateModule$3.set;
var internalStateGetterFor = InternalStateModule$3.getterFor;
var find$1 = ArrayIterationModule.find;
var findIndex = ArrayIterationModule.findIndex;
// Array.prototype.splice with the receiver passed as the first argument.
var splice$1 = uncurryThis$a([].splice);
// Counter stored into each new collection's internal state as `id: id++`
// by collectionWeak$2.getConstructor.
var id = 0;
15211
15212 // fallback for uncaught frozen keys
15213 var uncaughtFrozenStore = function uncaughtFrozenStore(state) {
15214 return state.frozen || (state.frozen = new UncaughtFrozenStore());
15215 };
15216 var UncaughtFrozenStore = function UncaughtFrozenStore() {
15217 this.entries = [];
15218 };
15219 var findUncaughtFrozen = function findUncaughtFrozen(store, key) {
15220 return find$1(store.entries, function (it) {
15221 return it[0] === key;
15222 });
15223 };
15224 UncaughtFrozenStore.prototype = {
15225 get: function get(key) {
15226 var entry = findUncaughtFrozen(this, key);
15227 if (entry) return entry[1];
15228 },
15229 has: function has(key) {
15230 return !!findUncaughtFrozen(this, key);
15231 },
15232 set: function set(key, value) {
15233 var entry = findUncaughtFrozen(this, key);
15234 if (entry) entry[1] = value;else this.entries.push([key, value]);
15235 },
15236 'delete': function _delete(key) {
15237 var index = findIndex(this.entries, function (it) {
15238 return it[0] === key;
15239 });
15240 if (~index) splice$1(this.entries, index, 1);
15241 return !!~index;
15242 }
15243 };
15244 var collectionWeak$2 = {
15245 getConstructor: function getConstructor(wrapper, CONSTRUCTOR_NAME, IS_MAP, ADDER) {
15246 var Constructor = wrapper(function (that, iterable) {
15247 anInstance$4(that, Prototype);
15248 setInternalState$3(that, {
15249 type: CONSTRUCTOR_NAME,
15250 id: id++,
15251 frozen: undefined
15252 });
15253 if (!isNullOrUndefined$1(iterable)) iterate$1(iterable, that[ADDER], {
15254 that: that,
15255 AS_ENTRIES: IS_MAP
15256 });
15257 });
15258 var Prototype = Constructor.prototype;
15259 var getInternalState = internalStateGetterFor(CONSTRUCTOR_NAME);
15260 var define = function define(that, key, value) {
15261 var state = getInternalState(that);
15262 var data = getWeakData(anObject$3(key), true);
15263 if (data === true) uncaughtFrozenStore(state).set(key, value);else data[state.id] = value;
15264 return that;
15265 };
15266 defineBuiltIns$2(Prototype, {
15267 // `{ WeakMap, WeakSet }.prototype.delete(key)` methods
15268 // https://tc39.es/ecma262/#sec-weakmap.prototype.delete
15269 // https://tc39.es/ecma262/#sec-weakset.prototype.delete
15270 'delete': function _delete(key) {
15271 var state = getInternalState(this);
15272 if (!isObject$3(key)) return false;
15273 var data = getWeakData(key);
15274 if (data === true) return uncaughtFrozenStore(state)['delete'](key);
15275 return data && hasOwn$6(data, state.id) && delete data[state.id];
15276 },
15277 // `{ WeakMap, WeakSet }.prototype.has(key)` methods
15278 // https://tc39.es/ecma262/#sec-weakmap.prototype.has
15279 // https://tc39.es/ecma262/#sec-weakset.prototype.has
15280 has: function has(key) {
15281 var state = getInternalState(this);
15282 if (!isObject$3(key)) return false;
15283 var data = getWeakData(key);
15284 if (data === true) return uncaughtFrozenStore(state).has(key);
15285 return data && hasOwn$6(data, state.id);
15286 }
15287 });
15288 defineBuiltIns$2(Prototype, IS_MAP ? {
15289 // `WeakMap.prototype.get(key)` method
15290 // https://tc39.es/ecma262/#sec-weakmap.prototype.get
15291 get: function get(key) {
15292 var state = getInternalState(this);
15293 if (isObject$3(key)) {
15294 var data = getWeakData(key);
15295 if (data === true) return uncaughtFrozenStore(state).get(key);
15296 return data ? data[state.id] : undefined;
15297 }
15298 },
15299 // `WeakMap.prototype.set(key, value)` method
15300 // https://tc39.es/ecma262/#sec-weakmap.prototype.set
15301 set: function set(key, value) {
15302 return define(this, key, value);
15303 }
15304 } : {
15305 // `WeakSet.prototype.add(value)` method
15306 // https://tc39.es/ecma262/#sec-weakset.prototype.add
15307 add: function add(value) {
15308 return define(this, value, true);
15309 }
15310 });
15311 return Constructor;
15312 }
15313 };
15314 var collectionWeak$3 = /*@__PURE__*/getDefaultExportFromCjs(collectionWeak$2);
15315
'use strict';
// `WeakMap` constructor (obtained through `collection$1`), plus patches for native
// WeakMaps in IE11 and Chakra Edge that mishandle frozen / sealed keys.
var FREEZING = freezing;
var global$h = global$Z;
var uncurryThis$9 = functionUncurryThis;
var defineBuiltIns$1 = defineBuiltIns$5;
var InternalMetadataModule = internalMetadataExports;
var collection$1 = collection$4;
var collectionWeak$1 = collectionWeak$2;
var isObject$2 = isObject$z;
var enforceInternalState = internalState.enforce;
var fails$6 = fails$1m;
var NATIVE_WEAK_MAP = weakMapBasicDetection;
var $Object = Object;
// eslint-disable-next-line es/no-array-isarray -- safe
var isArray = Array.isArray;
// eslint-disable-next-line es/no-object-isextensible -- safe
var isExtensible = $Object.isExtensible;
// eslint-disable-next-line es/no-object-isfrozen -- safe
var isFrozen = $Object.isFrozen;
// eslint-disable-next-line es/no-object-issealed -- safe
var isSealed = $Object.isSealed;
// eslint-disable-next-line es/no-object-freeze -- safe
var freeze = $Object.freeze;
// eslint-disable-next-line es/no-object-seal -- safe
var seal = $Object.seal;
// Sentinels recording an array key's integrity level in the Chakra Edge patch.
var FROZEN = {};
var SEALED = {};
// IE11 heuristic: `ActiveXObject` reads as falsy, yet the `in` operator still finds it.
var IS_IE11 = !global$h.ActiveXObject && 'ActiveXObject' in global$h;
// Polyfilled WeakMap constructor used as side storage for non-extensible keys (IE11 patch only).
var InternalWeakMap;
// Turns the internal initializer `init(that, iterable)` into the public `WeakMap`
// function, forwarding its optional first argument.
var wrapper = function wrapper(init) {
  return function WeakMap() {
    return init(this, arguments.length ? arguments[0] : undefined);
  };
};

// `WeakMap` constructor
// https://tc39.es/ecma262/#sec-weakmap-constructor
var $WeakMap = collection$1('WeakMap', wrapper, collectionWeak$1);
var WeakMapPrototype = $WeakMap.prototype;
// `set` of the WeakMap returned above, captured before the patches below replace it,
// so the patched methods can delegate to it.
var nativeSet = uncurryThis$9(WeakMapPrototype.set);

// Chakra Edge bug: adding frozen arrays to WeakMap unfreezes them
// Probes the bug (only when FREEZING is truthy) by using a frozen array as a key
// and checking whether it is still frozen afterwards.
var hasMSEdgeFreezingBug = function hasMSEdgeFreezingBug() {
  return FREEZING && fails$6(function () {
    var frozenArray = freeze([]);
    nativeSet(new $WeakMap(), frozenArray, 1);
    return !isFrozen(frozenArray);
  });
};

// IE11 WeakMap frozen keys fix
// We can't use feature detection because it crashes some old IE builds
// https://github.com/zloirock/core-js/issues/485
// Note: the `else if` further down pairs with `if (IS_IE11)`, so both patches are
// applied only when NATIVE_WEAK_MAP is truthy.
if (NATIVE_WEAK_MAP) if (IS_IE11) {
  // Non-extensible keys are looked up in the native map first, then in a lazily
  // created per-instance polyfill map (`state.frozen`).
  InternalWeakMap = collectionWeak$1.getConstructor(wrapper, 'WeakMap', true);
  // NOTE(review): presumably switches on metadata tagging used by getWeakData — TODO confirm.
  InternalMetadataModule.enable();
  var nativeDelete = uncurryThis$9(WeakMapPrototype['delete']);
  var nativeHas = uncurryThis$9(WeakMapPrototype.has);
  var nativeGet = uncurryThis$9(WeakMapPrototype.get);
  defineBuiltIns$1(WeakMapPrototype, {
    'delete': function _delete(key) {
      if (isObject$2(key) && !isExtensible(key)) {
        var state = enforceInternalState(this);
        if (!state.frozen) state.frozen = new InternalWeakMap();
        return nativeDelete(this, key) || state.frozen['delete'](key);
      }
      return nativeDelete(this, key);
    },
    has: function has(key) {
      if (isObject$2(key) && !isExtensible(key)) {
        var state = enforceInternalState(this);
        if (!state.frozen) state.frozen = new InternalWeakMap();
        return nativeHas(this, key) || state.frozen.has(key);
      }
      return nativeHas(this, key);
    },
    get: function get(key) {
      if (isObject$2(key) && !isExtensible(key)) {
        var state = enforceInternalState(this);
        if (!state.frozen) state.frozen = new InternalWeakMap();
        return nativeHas(this, key) ? nativeGet(this, key) : state.frozen.get(key);
      }
      return nativeGet(this, key);
    },
    // A non-extensible key already held natively is updated natively; otherwise it
    // goes to the side map.
    set: function set(key, value) {
      if (isObject$2(key) && !isExtensible(key)) {
        var state = enforceInternalState(this);
        if (!state.frozen) state.frozen = new InternalWeakMap();
        nativeHas(this, key) ? nativeSet(this, key, value) : state.frozen.set(key, value);
      } else nativeSet(this, key, value);
      return this;
    }
  });
  // Chakra Edge frozen keys fix
} else if (hasMSEdgeFreezingBug()) {
  defineBuiltIns$1(WeakMapPrototype, {
    // Remembers whether an array key was frozen or sealed, calls the native `set`,
    // then re-applies that integrity level.
    set: function set(key, value) {
      var arrayIntegrityLevel;
      if (isArray(key)) {
        if (isFrozen(key)) arrayIntegrityLevel = FROZEN;else if (isSealed(key)) arrayIntegrityLevel = SEALED;
      }
      nativeSet(this, key, value);
      if (arrayIntegrityLevel == FROZEN) freeze(key);
      if (arrayIntegrityLevel == SEALED) seal(key);
      return this;
    }
  });
}
15424
var es_weakSet = {};

var es_weakSet_constructor = {};

'use strict';
var collectionWeak = collectionWeak$2;
var collection = collection$4;

// `WeakSet` constructor
// https://tc39.es/ecma262/#sec-weakset-constructor
// The second argument is a wrapper factory: given the internal initializer
// `init(that, iterable)`, it returns the public `WeakSet` function, which forwards
// its optional first argument as the iterable.
collection('WeakSet', function (init) {
  return function WeakSet() {
    var iterable = arguments.length > 0 ? arguments[0] : undefined;
    return init(this, iterable);
  };
}, collectionWeak);
15440
var web_atob = {};

// Base64 alphabet shared by the `atob` / `btoa` polyfills: the 64 digits in value
// order, followed by the '=' padding character (index 64).
var itoc$1 = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=';
// Reverse lookup: character -> its index in `itoc$1`.
var ctoi$1 = {};
// Bounded by the alphabet length (65). A hard-coded bound of 66 read one position
// past the end (`charAt` returns '') and added a bogus '' -> 65 entry.
for (var index$3 = 0; index$3 < itoc$1.length; index$3++) ctoi$1[itoc$1.charAt(index$3)] = index$3;
var base64Map = {
  itoc: itoc$1,
  ctoi: ctoi$1
};
// CommonJS-interop alias of base64Map (not referenced within this section).
var base64Map$1 = /*@__PURE__*/getDefaultExportFromCjs(base64Map);

// `atob` polyfill (forgiving-base64 decode), installed when the native `atob` is
// missing or non-compliant. `fails$5(fn)` is assumed to return true when `fn`
// throws or returns a truthy value.
var $$d = _export;
var global$g = global$Z;
var getBuiltIn$4 = getBuiltIn$m;
var uncurryThis$8 = functionUncurryThis;
var call$3 = functionCall;
var fails$5 = fails$1m;
var toString$1 = toString$x;
var hasOwn$5 = hasOwnProperty_1;
var validateArgumentsLength$6 = validateArgumentsLength$8;
var ctoi = base64Map.ctoi;
// Any character outside the base64 digit alphabet (padding is stripped beforehand).
var disallowed = /[^\d+/a-z]/i;
// ASCII whitespace, removed before decoding.
var whitespaces = /[\t\n\f\r ]+/g;
// One or two trailing '=' padding characters.
var finalEq = /[=]{1,2}$/;
var $atob = getBuiltIn$4('atob');
var fromCharCode$1 = String.fromCharCode;
var charAt$3 = uncurryThis$8(''.charAt);
var replace$3 = uncurryThis$8(''.replace);
var exec$2 = uncurryThis$8(disallowed.exec);
// Native `atob` does not ignore whitespace (also true when `$atob` is missing and the call throws).
var NO_SPACES_IGNORE = fails$5(function () {
  return $atob(' ') !== '';
});
// Native `atob` accepts 'a' (length % 4 == 1), which must be rejected.
var NO_ENCODING_CHECK = !fails$5(function () {
  $atob('a');
});
// Otherwise-compliant native `atob` that does not throw when called without arguments.
var NO_ARG_RECEIVING_CHECK$1 = !NO_SPACES_IGNORE && !NO_ENCODING_CHECK && !fails$5(function () {
  $atob();
});
// Otherwise-compliant native `atob` whose declared arity is not 1.
var WRONG_ARITY$1 = !NO_SPACES_IGNORE && !NO_ENCODING_CHECK && $atob.length !== 1;

// `atob` method
// https://html.spec.whatwg.org/multipage/webappapis.html#dom-atob
$$d({
  global: true,
  bind: true,
  enumerable: true,
  forced: NO_SPACES_IGNORE || NO_ENCODING_CHECK || NO_ARG_RECEIVING_CHECK$1 || WRONG_ARITY$1
}, {
  atob: function atob(data) {
    validateArgumentsLength$6(arguments.length, 1);
    // `webpack` dev server bug on IE global methods - use call(fn, global, ...)
    // When only argument handling / arity is wrong, the native decoder is still used.
    if (NO_ARG_RECEIVING_CHECK$1 || WRONG_ARITY$1) return call$3($atob, global$g, data);
    var string = replace$3(toString$1(data), whitespaces, '');
    var output = '';
    var position = 0;
    // bc: number of base64 digits consumed so far; bs: bit buffer holding their 6-bit values.
    var bc = 0;
    var chr, bs;
    // Padding is only stripped when the padded length is a multiple of 4.
    if (string.length % 4 == 0) {
      string = replace$3(string, finalEq, '');
    }
    // A single leftover digit cannot encode a byte; any non-alphabet character is invalid.
    if (string.length % 4 == 1 || exec$2(disallowed, string)) {
      throw new (getBuiltIn$4('DOMException'))('The string is not correctly encoded', 'InvalidCharacterError');
    }
    // `charAt` returns '' past the end, which ends the loop.
    while (chr = charAt$3(string, position++)) {
      if (hasOwn$5(ctoi, chr)) {
        // Start a fresh buffer every 4 digits; otherwise shift in 6 more bits.
        bs = bc % 4 ? bs * 64 + ctoi[chr] : ctoi[chr];
        // After the 2nd, 3rd and 4th digit of a group a whole byte is available;
        // (-2 * bc & 6) is the right shift (4, 2, 0) that extracts it.
        if (bc++ % 4) output += fromCharCode$1(255 & bs >> (-2 * bc & 6));
      }
    }
    return output;
  }
});
15513
var web_btoa = {};

// `btoa` polyfill, installed when the native `btoa` is missing or mishandles its
// argument. `fails$4(fn)` is assumed to return true when `fn` throws or returns truthy.
var $$c = _export;
var global$f = global$Z;
var getBuiltIn$3 = getBuiltIn$m;
var uncurryThis$7 = functionUncurryThis;
var call$2 = functionCall;
var fails$4 = fails$1m;
var toString = toString$x;
var validateArgumentsLength$5 = validateArgumentsLength$8;
var itoc = base64Map.itoc;
var $btoa = getBuiltIn$3('btoa');
var charAt$2 = uncurryThis$7(''.charAt);
var charCodeAt$1 = uncurryThis$7(''.charCodeAt);
// Native `btoa` does not throw when called without arguments.
var NO_ARG_RECEIVING_CHECK = !!$btoa && !fails$4(function () {
  $btoa();
});
// Native `btoa(null)` throws or does not encode the string 'null'.
var WRONG_ARG_CONVERSION = !!$btoa && fails$4(function () {
  return $btoa(null) !== 'bnVsbA==';
});
// Native `btoa` declares an arity other than 1.
var WRONG_ARITY = !!$btoa && $btoa.length !== 1;

// `btoa` method
// https://html.spec.whatwg.org/multipage/webappapis.html#dom-btoa
$$c({
  global: true,
  bind: true,
  enumerable: true,
  forced: NO_ARG_RECEIVING_CHECK || WRONG_ARG_CONVERSION || WRONG_ARITY
}, {
  btoa: function btoa(data) {
    validateArgumentsLength$5(arguments.length, 1);
    // `webpack` dev server bug on IE global methods - use call(fn, global, ...)
    // The argument is stringified here so a native btoa with wrong conversion receives a string.
    if (NO_ARG_RECEIVING_CHECK || WRONG_ARG_CONVERSION || WRONG_ARITY) return call$2($btoa, global$f, toString(data));
    var string = toString(data);
    var output = '';
    // Fractional cursor advanced by 3/4 per output digit (6 of every 8 bits).
    // charAt / charCodeAt truncate it to an integer index.
    var position = 0;
    // Alphabet in use; switched to '=' once the input is exhausted, to emit padding.
    var map = itoc;
    var block, charCode;
    // Runs while input remains at the current byte; after that, keeps emitting
    // digits until `position` is back on a whole-byte boundary.
    while (charAt$2(string, position) || (map = '=', position % 1)) {
      charCode = charCodeAt$1(string, position += 3 / 4);
      if (charCode > 0xFF) {
        throw new (getBuiltIn$3('DOMException'))('The string contains characters outside of the Latin1 range', 'InvalidCharacterError');
      }
      // Past the end charCodeAt yields NaN, which `|` turns into 0 (zero-filled bits).
      block = block << 8 | charCode;
      // Extract the 6-bit group selected by the fractional part of `position`; in the
      // padding phase those bits are zero, so `map` ('=') yields '=' at index 0.
      output += charAt$2(map, 63 & block >> 8 - position % 1 * 8);
    }
    return output;
  }
});
15564
var web_domCollections_forEach = {};

// iterable DOM collections
// flag - `iterable` interface - 'entries', 'keys', 'values', 'forEach' methods
// Every collection maps to 0 except DOMTokenList and NodeList, which map to 1.
// Insertion order matches the listing below, since the table is walked with `for...in`.
var domIterables = {};
[
  'CSSRuleList',
  'CSSStyleDeclaration',
  'CSSValueList',
  'ClientRectList',
  'DOMRectList',
  'DOMStringList',
  'DOMTokenList',
  'DataTransferItemList',
  'FileList',
  'HTMLAllCollection',
  'HTMLCollection',
  'HTMLFormElement',
  'HTMLSelectElement',
  'MediaList',
  'MimeTypeArray',
  'NamedNodeMap',
  'NodeList',
  'PaintRequestList',
  'Plugin',
  'PluginArray',
  'SVGLengthList',
  'SVGNumberList',
  'SVGPathSegList',
  'SVGPointList',
  'SVGStringList',
  'SVGTransformList',
  'SourceBufferList',
  'StyleSheetList',
  'TextTrackCueList',
  'TextTrackList',
  'TouchList'
].forEach(function (collectionName) {
  domIterables[collectionName] = collectionName === 'DOMTokenList' || collectionName === 'NodeList' ? 1 : 0;
});
var domIterables$1 = /*@__PURE__*/getDefaultExportFromCjs(domIterables);

// in old WebKit versions, `element.classList` is not an instance of global `DOMTokenList`,
// so the prototype is taken from a live `classList` object instead.
var documentCreateElement = documentCreateElement$2;
var classList = documentCreateElement('span').classList;
// Walk classList -> constructor -> prototype, stopping at the first falsy link
// (the value it stops at is kept, exactly like an `&&` chain).
var DOMTokenListPrototype$2 = classList;
if (DOMTokenListPrototype$2) DOMTokenListPrototype$2 = classList.constructor;
if (DOMTokenListPrototype$2) DOMTokenListPrototype$2 = classList.constructor.prototype;
// Reaching `Object.prototype` (e.g. a plain-object `classList`) means no real
// DOMTokenList prototype was found.
var domTokenListPrototype = DOMTokenListPrototype$2 !== Object.prototype ? DOMTokenListPrototype$2 : undefined;
var domTokenListPrototype$1 = /*@__PURE__*/getDefaultExportFromCjs(domTokenListPrototype);
15610
var global$e = global$Z;
var DOMIterables$1 = domIterables;
var DOMTokenListPrototype$1 = domTokenListPrototype;
var forEach$1 = arrayForEach;
var createNonEnumerableProperty$2 = createNonEnumerableProperty$f;
// Makes `CollectionPrototype.forEach` the shared `arrayForEach`; does nothing for a
// missing prototype or one that already has it.
var handlePrototype$1 = function handlePrototype(CollectionPrototype) {
  if (!CollectionPrototype || CollectionPrototype.forEach === forEach$1) return;
  try {
    createNonEnumerableProperty$2(CollectionPrototype, 'forEach', forEach$1);
  } catch (error) {
    // some Chrome versions have non-configurable methods on DOMTokenList,
    // so fall back to plain assignment
    CollectionPrototype.forEach = forEach$1;
  }
};
// Only collections flagged as iterable receive `forEach`.
for (var COLLECTION_NAME$1 in DOMIterables$1) {
  if (!DOMIterables$1[COLLECTION_NAME$1]) continue;
  handlePrototype$1(global$e[COLLECTION_NAME$1] && global$e[COLLECTION_NAME$1].prototype);
}
handlePrototype$1(DOMTokenListPrototype$1);
15630
var web_domCollections_iterator = {};

var global$d = global$Z;
var DOMIterables = domIterables;
var DOMTokenListPrototype = domTokenListPrototype;
var ArrayIteratorMethods = es_array_iterator;
var createNonEnumerableProperty$1 = createNonEnumerableProperty$f;
var wellKnownSymbol$2 = wellKnownSymbol$z;
var ITERATOR$2 = wellKnownSymbol$2('iterator');
var TO_STRING_TAG = wellKnownSymbol$2('toStringTag');
var ArrayValues = ArrayIteratorMethods.values;
// Gives a DOM collection prototype array-style iteration: `@@iterator`, a
// `@@toStringTag` (only if missing) and, for collections flagged as iterable,
// every method found on `ArrayIteratorMethods`.
var handlePrototype = function handlePrototype(CollectionPrototype, COLLECTION_NAME) {
  if (!CollectionPrototype) return;
  // Puts `value` under `key` unless it is already there. Some Chrome versions have
  // non-configurable methods on DOMTokenList, hence the plain-assignment fallback.
  var install = function (key, value) {
    if (CollectionPrototype[key] === value) return;
    try {
      createNonEnumerableProperty$1(CollectionPrototype, key, value);
    } catch (error) {
      CollectionPrototype[key] = value;
    }
  };
  install(ITERATOR$2, ArrayValues);
  if (!CollectionPrototype[TO_STRING_TAG]) {
    createNonEnumerableProperty$1(CollectionPrototype, TO_STRING_TAG, COLLECTION_NAME);
  }
  if (!DOMIterables[COLLECTION_NAME]) return;
  for (var METHOD_NAME in ArrayIteratorMethods) {
    install(METHOD_NAME, ArrayIteratorMethods[METHOD_NAME]);
  }
};
for (var COLLECTION_NAME in DOMIterables) {
  handlePrototype(global$d[COLLECTION_NAME] && global$d[COLLECTION_NAME].prototype, COLLECTION_NAME);
}
handlePrototype(DOMTokenListPrototype, 'DOMTokenList');
15667
var web_domException_constructor = {};

var IS_NODE$2 = engineIsNode;
// Best-effort `require(name)`: returns the module, or undefined when not running in
// Node.js or when the lookup throws. The `Function` body runs in global scope, so
// only a global `require` is found (presumably also hiding the call from bundlers).
var tryNodeRequire$1 = function tryNodeRequire(name) {
  if (!IS_NODE$2) return undefined;
  try {
    // eslint-disable-next-line no-new-func -- safe
    return Function('return require("' + name + '")')();
  } catch (error) {
    return undefined;
  }
};
var tryNodeRequire$2 = /*@__PURE__*/getDefaultExportFromCjs(tryNodeRequire$1);
15678
// Legacy DOMException error table, keyed by modern error name:
//   s - legacy constant name, c - legacy numeric code,
//   m - 1 when `DOMException#code` reports `c` for this name, 0 when it reports 0.
// Entries are inserted in code order, which `for...in` consumers observe.
var domExceptionConstants = {};
[
  // [name, legacy constant, code, mapped]
  ['IndexSizeError', 'INDEX_SIZE_ERR', 1, 1],
  ['DOMStringSizeError', 'DOMSTRING_SIZE_ERR', 2, 0],
  ['HierarchyRequestError', 'HIERARCHY_REQUEST_ERR', 3, 1],
  ['WrongDocumentError', 'WRONG_DOCUMENT_ERR', 4, 1],
  ['InvalidCharacterError', 'INVALID_CHARACTER_ERR', 5, 1],
  ['NoDataAllowedError', 'NO_DATA_ALLOWED_ERR', 6, 0],
  ['NoModificationAllowedError', 'NO_MODIFICATION_ALLOWED_ERR', 7, 1],
  ['NotFoundError', 'NOT_FOUND_ERR', 8, 1],
  ['NotSupportedError', 'NOT_SUPPORTED_ERR', 9, 1],
  ['InUseAttributeError', 'INUSE_ATTRIBUTE_ERR', 10, 1],
  ['InvalidStateError', 'INVALID_STATE_ERR', 11, 1],
  ['SyntaxError', 'SYNTAX_ERR', 12, 1],
  ['InvalidModificationError', 'INVALID_MODIFICATION_ERR', 13, 1],
  ['NamespaceError', 'NAMESPACE_ERR', 14, 1],
  ['InvalidAccessError', 'INVALID_ACCESS_ERR', 15, 1],
  ['ValidationError', 'VALIDATION_ERR', 16, 0],
  ['TypeMismatchError', 'TYPE_MISMATCH_ERR', 17, 1],
  ['SecurityError', 'SECURITY_ERR', 18, 1],
  ['NetworkError', 'NETWORK_ERR', 19, 1],
  ['AbortError', 'ABORT_ERR', 20, 1],
  ['URLMismatchError', 'URL_MISMATCH_ERR', 21, 1],
  ['QuotaExceededError', 'QUOTA_EXCEEDED_ERR', 22, 1],
  ['TimeoutError', 'TIMEOUT_ERR', 23, 1],
  ['InvalidNodeTypeError', 'INVALID_NODE_TYPE_ERR', 24, 1],
  ['DataCloneError', 'DATA_CLONE_ERR', 25, 1]
].forEach(function (row) {
  domExceptionConstants[row[0]] = {
    s: row[1],
    c: row[2],
    m: row[3]
  };
});
// CommonJS-interop alias of domExceptionConstants (not referenced within this section).
var domExceptionConstants$1 = /*@__PURE__*/getDefaultExportFromCjs(domExceptionConstants);

'use strict';
// `DOMException` constructor polyfill. It is exported when the native constructor is
// missing or broken (any of the bugs below in pure builds). Afterwards `toString`,
// `code` and the legacy constants are back-filled where needed.
var $$b = _export;
var tryNodeRequire = tryNodeRequire$1;
var getBuiltIn$2 = getBuiltIn$m;
var fails$3 = fails$1m;
var create$1 = objectCreate;
var createPropertyDescriptor$2 = createPropertyDescriptor$c;
var defineProperty$2 = objectDefineProperty.f;
var defineBuiltIn$2 = defineBuiltIn$m;
var defineBuiltInAccessor$4 = defineBuiltInAccessor$h;
var hasOwn$4 = hasOwnProperty_1;
var anInstance$3 = anInstance$a;
var anObject$2 = anObject$D;
var errorToString = errorToString$2;
var normalizeStringArgument$1 = normalizeStringArgument$5;
var DOMExceptionConstants$1 = domExceptionConstants;
var clearErrorStack$1 = errorStackClear;
var InternalStateModule$2 = internalState;
var DESCRIPTORS$6 = descriptors;
var IS_PURE$3 = isPure;
var DOM_EXCEPTION$2 = 'DOMException';
// Legacy constant name of the DataCloneError code (25).
var DATA_CLONE_ERR = 'DATA_CLONE_ERR';
var Error$3 = getBuiltIn$2('Error');
// NodeJS < 17.0 does not expose `DOMException` to global: recover the native
// constructor from the DataCloneError thrown when posting an uncloneable WeakMap.
var NativeDOMException$1 = getBuiltIn$2(DOM_EXCEPTION$2) || function () {
  try {
    // NodeJS < 15.0 does not expose `MessageChannel` to global
    var MessageChannel = getBuiltIn$2('MessageChannel') || tryNodeRequire('worker_threads').MessageChannel;
    // eslint-disable-next-line es/no-weak-map, unicorn/require-post-message-target-origin -- safe
    new MessageChannel().port1.postMessage(new WeakMap());
  } catch (error) {
    // Per Web IDL a DOMException's `name` is the modern error name ('DataCloneError'),
    // not the legacy constant name, so comparing against DATA_CLONE_ERR alone never
    // matched. Both names are accepted; the legacy code must be 25 either way.
    if ((error.name == 'DataCloneError' || error.name == DATA_CLONE_ERR) && error.code == 25) return error.constructor;
  }
}();
var NativeDOMExceptionPrototype = NativeDOMException$1 && NativeDOMException$1.prototype;
var ErrorPrototype = Error$3.prototype;
var setInternalState$2 = InternalStateModule$2.set;
var getInternalState = InternalStateModule$2.getterFor(DOM_EXCEPTION$2);
// Whether this engine gives plain Errors a `stack` property.
var HAS_STACK = ('stack' in Error$3(DOM_EXCEPTION$2));
// Legacy numeric code for an error name: the table's `c` when the name is listed
// with a truthy `m`, otherwise 0.
var codeFor = function codeFor(name) {
  return hasOwn$4(DOMExceptionConstants$1, name) && DOMExceptionConstants$1[name].m ? DOMExceptionConstants$1[name].c : 0;
};
// Polyfill constructor, used as `new DOMException(message, name)`; both arguments are optional.
var $DOMException$1 = function DOMException() {
  // `DOMExceptionPrototype$1` (a hoisted `var`) is assigned right after this expression.
  anInstance$3(this, DOMExceptionPrototype$1);
  var argumentsLength = arguments.length;
  var message = normalizeStringArgument$1(argumentsLength < 1 ? undefined : arguments[0]);
  // 'Error' is passed as the fallback for a missing name.
  var name = normalizeStringArgument$1(argumentsLength < 2 ? undefined : arguments[1], 'Error');
  var code = codeFor(name);
  setInternalState$2(this, {
    type: DOM_EXCEPTION$2,
    name: name,
    message: message,
    code: code
  });
  // Without descriptor support the prototype getters below are not installed,
  // so the fields are exposed as own properties instead.
  if (!DESCRIPTORS$6) {
    this.name = name;
    this.message = message;
    this.code = code;
  }
  // Borrow a stack trace from an ordinary Error created here.
  if (HAS_STACK) {
    var error = Error$3(message);
    error.name = DOM_EXCEPTION$2;
    defineProperty$2(this, 'stack', createPropertyDescriptor$2(1, clearErrorStack$1(error.stack, 1)));
  }
};
var DOMExceptionPrototype$1 = $DOMException$1.prototype = create$1(ErrorPrototype);
// Enumerable, configurable accessor descriptor wrapping `get`.
var createGetterDescriptor = function createGetterDescriptor(get) {
  return {
    enumerable: true,
    configurable: true,
    get: get
  };
};
// Getter descriptor reading `key` from the instance's internal state.
var getterFor = function getterFor(key) {
  return createGetterDescriptor(function () {
    return getInternalState(this)[key];
  });
};
if (DESCRIPTORS$6) {
  // `DOMException.prototype.code` getter
  defineBuiltInAccessor$4(DOMExceptionPrototype$1, 'code', getterFor('code'));
  // `DOMException.prototype.message` getter
  defineBuiltInAccessor$4(DOMExceptionPrototype$1, 'message', getterFor('message'));
  // `DOMException.prototype.name` getter
  defineBuiltInAccessor$4(DOMExceptionPrototype$1, 'name', getterFor('name'));
}
defineProperty$2(DOMExceptionPrototype$1, 'constructor', createPropertyDescriptor$2(1, $DOMException$1));

// FF36- DOMException is a function, but can't be constructed
// (also reports true when no native constructor was found, since `new undefined`
// throws — assuming `fails$3` treats a thrown error as failure)
var INCORRECT_CONSTRUCTOR = fails$3(function () {
  return !(new NativeDOMException$1() instanceof Error$3);
});

// Safari 10.1 / Chrome 32- / IE8- DOMException.prototype.toString bugs
var INCORRECT_TO_STRING = INCORRECT_CONSTRUCTOR || fails$3(function () {
  return ErrorPrototype.toString !== errorToString || String(new NativeDOMException$1(1, 2)) !== '2: 1';
});

// Deno 1.6.3- DOMException.prototype.code just missed
var INCORRECT_CODE = INCORRECT_CONSTRUCTOR || fails$3(function () {
  return new NativeDOMException$1(1, 'DataCloneError').code !== 25;
});

// Deno 1.6.3- DOMException constants just missed
var MISSED_CONSTANTS = INCORRECT_CONSTRUCTOR || NativeDOMException$1[DATA_CLONE_ERR] !== 25 || NativeDOMExceptionPrototype[DATA_CLONE_ERR] !== 25;
// Pure builds replace the constructor for any of the bugs above; global builds replace
// it only when it cannot be constructed, and patch the kept native one below.
var FORCED_CONSTRUCTOR$1 = IS_PURE$3 ? INCORRECT_TO_STRING || INCORRECT_CODE || MISSED_CONSTANTS : INCORRECT_CONSTRUCTOR;

// `DOMException` constructor
// https://webidl.spec.whatwg.org/#idl-DOMException
$$b({
  global: true,
  constructor: true,
  forced: FORCED_CONSTRUCTOR$1
}, {
  DOMException: FORCED_CONSTRUCTOR$1 ? $DOMException$1 : NativeDOMException$1
});
var PolyfilledDOMException$1 = getBuiltIn$2(DOM_EXCEPTION$2);
var PolyfilledDOMExceptionPrototype$1 = PolyfilledDOMException$1.prototype;
// Replace a broken `toString` (pure builds, or when the native constructor was kept).
if (INCORRECT_TO_STRING && (IS_PURE$3 || NativeDOMException$1 === PolyfilledDOMException$1)) {
  defineBuiltIn$2(PolyfilledDOMExceptionPrototype$1, 'toString', errorToString);
}
// Give a kept native constructor a `code` getter derived from `name`.
if (INCORRECT_CODE && DESCRIPTORS$6 && NativeDOMException$1 === PolyfilledDOMException$1) {
  defineBuiltInAccessor$4(PolyfilledDOMExceptionPrototype$1, 'code', createGetterDescriptor(function () {
    return codeFor(anObject$2(this).name);
  }));
}

// `DOMException` constants
// Each legacy constant (e.g. DATA_CLONE_ERR = 25) is added to both the constructor
// and its prototype unless already present. NOTE(review): bitmap 6 presumably yields
// an enumerable, non-writable, non-configurable descriptor — confirm against createPropertyDescriptor.
for (var key$1 in DOMExceptionConstants$1) if (hasOwn$4(DOMExceptionConstants$1, key$1)) {
  var constant$2 = DOMExceptionConstants$1[key$1];
  var constantName$1 = constant$2.s;
  var descriptor$2 = createPropertyDescriptor$2(6, constant$2.c);
  if (!hasOwn$4(PolyfilledDOMException$1, constantName$1)) {
    defineProperty$2(PolyfilledDOMException$1, constantName$1, descriptor$2);
  }
  if (!hasOwn$4(PolyfilledDOMExceptionPrototype$1, constantName$1)) {
    defineProperty$2(PolyfilledDOMExceptionPrototype$1, constantName$1, descriptor$2);
  }
}
15947
var web_domException_stack = {};

'use strict';
// Replaces a native `DOMException` whose instances lack `stack` (while plain Errors
// have one) with a wrapper that attaches a cleaned-up `stack`, then repairs the
// prototype's `constructor` and the static legacy constants.
// Assumes a global `DOMException` is available by this point (the constructor
// polyfill earlier in this file installs one when missing).
var $$a = _export;
var global$c = global$Z;
var getBuiltIn$1 = getBuiltIn$m;
var createPropertyDescriptor$1 = createPropertyDescriptor$c;
var defineProperty$1 = objectDefineProperty.f;
var hasOwn$3 = hasOwnProperty_1;
var anInstance$2 = anInstance$a;
var inheritIfRequired = inheritIfRequired$6;
var normalizeStringArgument = normalizeStringArgument$5;
var DOMExceptionConstants = domExceptionConstants;
var clearErrorStack = errorStackClear;
var DESCRIPTORS$5 = descriptors;
var IS_PURE$2 = isPure;
var DOM_EXCEPTION$1 = 'DOMException';
var Error$2 = getBuiltIn$1('Error');
var NativeDOMException = getBuiltIn$1(DOM_EXCEPTION$1);
// Wrapper constructor: creates a real native DOMException, copies a `stack` from an
// ordinary Error onto it, and returns that object instead of `this`.
var $DOMException = function DOMException() {
  anInstance$2(this, DOMExceptionPrototype);
  var argumentsLength = arguments.length;
  var message = normalizeStringArgument(argumentsLength < 1 ? undefined : arguments[0]);
  var name = normalizeStringArgument(argumentsLength < 2 ? undefined : arguments[1], 'Error');
  var that = new NativeDOMException(message, name);
  var error = Error$2(message);
  error.name = DOM_EXCEPTION$1;
  defineProperty$1(that, 'stack', createPropertyDescriptor$1(1, clearErrorStack(error.stack, 1)));
  // NOTE(review): presumably re-points `that`'s prototype when called through a subclass — confirm.
  inheritIfRequired(that, this, $DOMException);
  return that;
};
// The wrapper shares the native prototype, so native instances still satisfy
// `instanceof` checks against it.
var DOMExceptionPrototype = $DOMException.prototype = NativeDOMException.prototype;
var ERROR_HAS_STACK = ('stack' in Error$2(DOM_EXCEPTION$1));
var DOM_EXCEPTION_HAS_STACK = ('stack' in new NativeDOMException(1, 2));

// eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
var descriptor$1 = NativeDOMException && DESCRIPTORS$5 && Object.getOwnPropertyDescriptor(global$c, DOM_EXCEPTION$1);

// Bun ~ 0.1.1 DOMException has an incorrect descriptor and we can't redefine it
// https://github.com/Jarred-Sumner/bun/issues/399
var BUGGY_DESCRIPTOR = !!descriptor$1 && !(descriptor$1.writable && descriptor$1.configurable);
// Replace only when Errors have `stack`, native DOMExceptions do not, and the global
// binding can be redefined.
var FORCED_CONSTRUCTOR = ERROR_HAS_STACK && !BUGGY_DESCRIPTOR && !DOM_EXCEPTION_HAS_STACK;

// `DOMException` constructor patch for `.stack` where it's required
// https://webidl.spec.whatwg.org/#es-DOMException-specialness
$$a({
  global: true,
  constructor: true,
  forced: IS_PURE$2 || FORCED_CONSTRUCTOR
}, {
  // TODO: fix export logic
  DOMException: FORCED_CONSTRUCTOR ? $DOMException : NativeDOMException
});
var PolyfilledDOMException = getBuiltIn$1(DOM_EXCEPTION$1);
var PolyfilledDOMExceptionPrototype = PolyfilledDOMException.prototype;
// Typically true only when the wrapper was installed: it shares the native prototype,
// whose `constructor` still points at the native function.
if (PolyfilledDOMExceptionPrototype.constructor !== PolyfilledDOMException) {
  if (!IS_PURE$2) {
    defineProperty$1(PolyfilledDOMExceptionPrototype, 'constructor', createPropertyDescriptor$1(1, PolyfilledDOMException));
  }
  // Copy the legacy numeric constants (e.g. DATA_CLONE_ERR) onto the wrapper as statics.
  for (var key in DOMExceptionConstants) if (hasOwn$3(DOMExceptionConstants, key)) {
    var constant$1 = DOMExceptionConstants[key];
    var constantName = constant$1.s;
    if (!hasOwn$3(PolyfilledDOMException, constantName)) {
      defineProperty$1(PolyfilledDOMException, constantName, createPropertyDescriptor$1(6, constant$1.c));
    }
  }
}
16015
var web_domException_toStringTag = {};

var DOM_EXCEPTION = 'DOMException';
var setToStringTag$2 = setToStringTag$d;
var getBuiltIn = getBuiltIn$m;

// `DOMException.prototype[@@toStringTag]` property
// Tags whichever `DOMException` is globally available at this point (native or polyfilled).
setToStringTag$2(getBuiltIn(DOM_EXCEPTION), DOM_EXCEPTION);
16024
var web_immediate = {};

var web_clearImmediate = {};

var $$9 = _export;
var clearImmediate = task$1.clear;
var global$b = global$Z;

// `clearImmediate` method
// http://w3c.github.io/setImmediate/#si-clearImmediate
// `forced` is set whenever the global `clearImmediate` is not the task module's `clear`.
$$9({
  global: true,
  bind: true,
  enumerable: true,
  forced: clearImmediate !== global$b.clearImmediate
}, {
  clearImmediate: clearImmediate
});
16043
var web_setImmediate = {};

/* global Bun -- Bun case */
// `true` only under the Bun runtime: a global `Bun` function exposing a string `version`.
var engineIsBun = function () {
  if (typeof Bun != 'function' || !Bun) return false;
  return typeof Bun.version == 'string';
}();
16048 var engineIsBun$1 = /*@__PURE__*/getDefaultExportFromCjs(engineIsBun);
16049
'use strict';
var global$a = global$Z;
var apply = functionApply$1;
var isCallable$2 = isCallable$z;
var ENGINE_IS_BUN = engineIsBun;
var USER_AGENT = engineUserAgent;
var arraySlice$1 = arraySlice$a;
var validateArgumentsLength$4 = validateArgumentsLength$8;
var Function$1 = global$a.Function;

// Engines whose native timers need wrapping: IE9- (matched on the user agent) and
// Bun up to 0.3.0 (version string parsed below; https://github.com/oven-sh/bun/issues/1633).
var WRAP = (function () {
  if (/MSIE .\./.test(USER_AGENT)) return true;
  if (!ENGINE_IS_BUN) return ENGINE_IS_BUN;
  var parts = global$a.Bun.version.split('.');
  if (parts.length < 3) return true;
  return parts[0] == 0 && (parts[1] < 3 || parts[1] == 3 && parts[2] == 0);
})();

/**
 * Adapts a timer-like `scheduler` so that extra arguments after the handler (and
 * after the delay, when `hasTimeArg`) reach the handler. Returns `scheduler`
 * itself when `WRAP` is falsy.
 * https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#timers
 *
 * @param scheduler - the function to adapt (e.g. `setTimeout`, or a task-queue `set`).
 * @param hasTimeArg - whether `scheduler` takes a delay as its second argument.
 */
var schedulersFix$3 = function schedulersFix(scheduler, hasTimeArg) {
  if (!WRAP) return scheduler;
  var firstParamIndex = hasTimeArg ? 2 : 1;
  return function (handler, timeout /* , ...arguments */) {
    var hasBoundArgs = validateArgumentsLength$4(arguments.length, 1) > firstParamIndex;
    // Non-callable handlers are compiled with the global `Function` constructor.
    var fn = isCallable$2(handler) ? handler : Function$1(handler);
    var params = hasBoundArgs ? arraySlice$1(arguments, firstParamIndex) : [];
    var callback = fn;
    if (hasBoundArgs) {
      callback = function () {
        apply(fn, this, params);
      };
    }
    if (hasTimeArg) return scheduler(callback, timeout);
    return scheduler(callback);
  };
};
var schedulersFix$4 = /*@__PURE__*/getDefaultExportFromCjs(schedulersFix$3);
16081
var $$8 = _export;
var global$9 = global$Z;
var setTask = task$1.set;
var schedulersFix$2 = schedulersFix$3;

// Use the task-queue `set` directly; when the engine already has a global
// `setImmediate`, pass it through `schedulersFix` first
// (https://github.com/oven-sh/bun/issues/1633).
var setImmediate$1 = setTask;
if (global$9.setImmediate) setImmediate$1 = schedulersFix$2(setTask, false);

// `setImmediate` method
// http://w3c.github.io/setImmediate/#si-setImmediate
$$8({ global: true, bind: true, enumerable: true, forced: global$9.setImmediate !== setImmediate$1 }, { setImmediate: setImmediate$1 });
16100
var web_queueMicrotask = {};

var $$7 = _export;
var global$8 = global$Z;
var microtask = microtask_1;
var aCallable = aCallable$l;
var validateArgumentsLength$3 = validateArgumentsLength$8;
var IS_NODE$1 = engineIsNode;
var process$1 = global$8.process;

// `queueMicrotask` method
// https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#dom-queuemicrotask
$$7({
  global: true,
  enumerable: true,
  dontCallGetSet: true
}, {
  queueMicrotask: function queueMicrotask(fn) {
    validateArgumentsLength$3(arguments.length, 1);
    aCallable(fn);
    var task = fn;
    // On Node, bind the callback to the active `process.domain`, if there is one.
    if (IS_NODE$1) {
      var domain = process$1.domain;
      if (domain) task = domain.bind(fn);
    }
    microtask(task);
  }
});
16125
var web_self = {};

'use strict';
var $$6 = _export;
var global$7 = global$Z;
var defineBuiltInAccessor$3 = defineBuiltInAccessor$h;
var DESCRIPTORS$4 = descriptors;
var $TypeError = TypeError;
// eslint-disable-next-line es/no-object-defineproperty -- safe
var defineProperty = Object.defineProperty;
var INCORRECT_VALUE = global$7.self !== global$7;

// `self` getter
// https://html.spec.whatwg.org/multipage/window-object.html#dom-self
// Best effort: any error thrown while patching is swallowed.
try {
  if (!DESCRIPTORS$4) {
    // Without descriptor support, hand `self` to `_export` as a plain value.
    $$6({
      global: true,
      simple: true,
      forced: INCORRECT_VALUE
    }, {
      self: global$7
    });
  } else {
    // eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
    var descriptor = Object.getOwnPropertyDescriptor(global$7, 'self');
    // Reinstall unless `self` already equals the global and is an enumerable accessor
    // with a getter; some engines ship an incorrect descriptor
    // (https://github.com/denoland/deno/issues/15765).
    if (INCORRECT_VALUE || !(descriptor && descriptor.get && descriptor.enumerable)) {
      defineBuiltInAccessor$3(global$7, 'self', {
        get: function self() {
          return global$7;
        },
        // Assigning through the global replaces the accessor with a writable data property.
        set: function self(value) {
          if (this !== global$7) throw $TypeError('Illegal invocation');
          defineProperty(global$7, 'self', {
            value: value,
            writable: true,
            configurable: true,
            enumerable: true
          });
        },
        configurable: true,
        enumerable: true
      });
    }
  }
} catch (error) {/* empty */}
16172
var web_structuredClone = {};

var uncurryThis$6 = functionUncurryThis;

// eslint-disable-next-line es/no-map -- safe
var MapPrototype = Map.prototype;

/**
 * `Map` helpers for the `structuredClone` polyfill: the `Map` constructor, its
 * prototype, and prototype methods wrapped by `uncurryThis$6` (presumably called as
 * `helper(map, ...args)`).
 */
var mapHelpers = (function (proto) {
  return {
    // eslint-disable-next-line es/no-map -- safe
    Map: Map,
    set: uncurryThis$6(proto.set),
    get: uncurryThis$6(proto.get),
    has: uncurryThis$6(proto.has),
    remove: uncurryThis$6(proto['delete']),
    proto: proto
  };
})(MapPrototype);
var mapHelpers$1 = /*@__PURE__*/getDefaultExportFromCjs(mapHelpers);
16189
var uncurryThis$5 = functionUncurryThis;

// eslint-disable-next-line es/no-set -- safe
var SetPrototype = Set.prototype;

/**
 * `Set` helpers for the `structuredClone` polyfill. `add`/`has`/`remove` are wrapped
 * by `uncurryThis$5`; `$has` and `$keys` are the raw prototype methods.
 */
var setHelpers = (function (proto) {
  return {
    // eslint-disable-next-line es/no-set -- safe
    Set: Set,
    add: uncurryThis$5(proto.add),
    has: uncurryThis$5(proto.has),
    remove: uncurryThis$5(proto['delete']),
    proto: proto,
    $has: proto.has,
    $keys: proto.keys
  };
})(SetPrototype);
var setHelpers$1 = /*@__PURE__*/getDefaultExportFromCjs(setHelpers);
16205
var global$6 = global$Z;
var fails$2 = fails$1m;
var V8 = engineV8Version;
var IS_BROWSER = engineIsBrowser;
var IS_DENO = engineIsDeno;
var IS_NODE = engineIsNode;
var structuredClone = global$6.structuredClone;

/**
 * Whether the native `structuredClone` honours `transfer`. The probe requires the
 * transferred 8-byte source to be detached and the clone to keep 8 bytes. It is
 * `false` when there is no native implementation.
 */
var structuredCloneProperTransfer = structuredClone ? !fails$2(function () {
  // Newer V8-based runtimes are trusted without probing, to avoid invalidating V8's
  // ArrayBufferDetaching protector cell (performance degradation):
  // https://github.com/zloirock/core-js/issues/679
  var skipProbe = IS_DENO && V8 > 92 || IS_NODE && V8 > 94 || IS_BROWSER && V8 > 97;
  if (skipProbe) return false;
  var source = new ArrayBuffer(8);
  var copy = structuredClone(source, { transfer: [source] });
  return source.byteLength != 0 || copy.byteLength != 8;
}) : false;
var structuredCloneProperTransfer$1 = /*@__PURE__*/getDefaultExportFromCjs(structuredCloneProperTransfer);
16224
// Bindings used by the `structuredClone` polyfill below: aliases of shared modules
// bundled elsewhere in this file. Statement order is kept as emitted (note the
// `uid(...)` call at the end).
16225 var IS_PURE$1 = isPure;
16226 var $$5 = _export;
16227 var global$5 = global$Z;
16228 var getBuiltin = getBuiltIn$m;
16229 var uncurryThis$4 = functionUncurryThis;
16230 var fails$1 = fails$1m;
16231 var uid = uid$6;
16232 var isCallable$1 = isCallable$z;
16233 var isConstructor = isConstructor$6;
16234 var isNullOrUndefined = isNullOrUndefined$e;
16235 var isObject$1 = isObject$z;
16236 var isSymbol = isSymbol$7;
16237 var iterate = iterate$a;
16238 var anObject$1 = anObject$D;
16239 var classof$1 = classof$m;
16240 var hasOwn$2 = hasOwnProperty_1;
16241 var createProperty = createProperty$9;
16242 var createNonEnumerableProperty = createNonEnumerableProperty$f;
16243 var lengthOfArrayLike = lengthOfArrayLike$t;
16244 var validateArgumentsLength$2 = validateArgumentsLength$8;
16245 var getRegExpFlags = regexpGetFlags;
16246 var MapHelpers = mapHelpers;
16247 var SetHelpers = setHelpers;
16248 var ERROR_STACK_INSTALLABLE = errorStackInstallable;
16249 var PROPER_TRANSFER = structuredCloneProperTransfer;
// Constructors read from `global$5` once, at load time.
16250 var Object$1 = global$5.Object;
16251 var Array$1 = global$5.Array;
16252 var Date$1 = global$5.Date;
16253 var Error$1 = global$5.Error;
16254 var EvalError = global$5.EvalError;
16255 var RangeError$1 = global$5.RangeError;
16256 var ReferenceError$1 = global$5.ReferenceError;
16257 var SyntaxError = global$5.SyntaxError;
16258 var TypeError$3 = global$5.TypeError;
16259 var URIError = global$5.URIError;
16260 var PerformanceMark = global$5.PerformanceMark;
16261 var WebAssembly$1 = global$5.WebAssembly;
// WebAssembly error types fall back to `Error$1` when WebAssembly (or the specific type) is absent.
16262 var CompileError = WebAssembly$1 && WebAssembly$1.CompileError || Error$1;
16263 var LinkError = WebAssembly$1 && WebAssembly$1.LinkError || Error$1;
16264 var RuntimeError$1 = WebAssembly$1 && WebAssembly$1.RuntimeError || Error$1;
16265 var DOMException = getBuiltin('DOMException');
// Map/Set helpers (wrapped prototype methods) from `mapHelpers` / `setHelpers`.
16266 var Map$1 = MapHelpers.Map;
16267 var mapHas = MapHelpers.has;
16268 var mapGet = MapHelpers.get;
16269 var mapSet = MapHelpers.set;
16270 var Set$1 = SetHelpers.Set;
16271 var setAdd = SetHelpers.add;
16272 var objectKeys = getBuiltin('Object', 'keys');
// Wrapped methods used to unbox Boolean/Number/String wrappers and read a Date's time value.
16273 var push$3 = uncurryThis$4([].push);
16274 var thisBooleanValue = uncurryThis$4(true.valueOf);
16275 var thisNumberValue = uncurryThis$4(1.0.valueOf);
16276 var thisStringValue = uncurryThis$4(''.valueOf);
16277 var thisTimeValue = uncurryThis$4(Date$1.prototype.getTime);
// Unique name used by the `PerformanceMark` fallback and as the probe message in
// `checkNewErrorsCloningSemantic`.
16278 var PERFORMANCE_MARK = uid('structuredClone');
// DOMException name used for every clone/transfer failure.
16279 var DATA_CLONE_ERROR = 'DataCloneError';
// Action label passed to `throwUnpolyfillable` for transfer failures.
16280 var TRANSFERRING = 'Transferring';
/**
 * Returns `structuredCloneImplementation` if it passes two checks, else `false`:
 *   - it clones a `Set` into a distinct `Set` that still contains the original elements;
 *   - it keeps a boxed number boxed and equal to 7.
 * A probe that throws is also rejected (assuming `fails$1` reports throwing as failure).
 */
var checkBasicSemantic = function checkBasicSemantic(structuredCloneImplementation) {
  var broken = fails$1(function () {
    var original = new global$5.Set([7]);
    var copy = structuredCloneImplementation(original);
    var boxed = structuredCloneImplementation(Object$1(7));
    if (copy == original || !copy.has(7)) return true;
    return _typeof(boxed) != 'object' || boxed != 7;
  });
  return broken ? false : structuredCloneImplementation;
};
/**
 * Returns `true` when `structuredCloneImplementation` clones an object holding the
 * same `$Error` instance twice into an object whose two properties are one shared
 * `$Error` instance with the original `stack`; `false` otherwise.
 */
var checkErrorsCloning = function checkErrorsCloning(structuredCloneImplementation, $Error) {
  var broken = fails$1(function () {
    var error = new $Error();
    var result = structuredCloneImplementation({
      a: error,
      b: error
    });
    var preservesErrors = !!result && result.a === result.b && result.a instanceof $Error && result.a.stack === error.stack;
    return !preservesErrors;
  });
  return !broken;
};
16299
// Newer error-cloning semantic (https://github.com/whatwg/html/pull/5749): a cloned
// `AggregateError` must keep its `name`, `errors`, `message` and `cause`.
var checkNewErrorsCloningSemantic = function checkNewErrorsCloningSemantic(structuredCloneImplementation) {
  var broken = fails$1(function () {
    var source = new global$5.AggregateError([1], PERFORMANCE_MARK, {
      cause: 3
    });
    var copy = structuredCloneImplementation(source);
    if (copy.name != 'AggregateError') return true;
    if (copy.errors[0] != 1) return true;
    return copy.message != PERFORMANCE_MARK || copy.cause != 3;
  });
  return !broken;
};
16309
// Native `structuredClone` exists in FF94+, Safari 15.4+, Chrome 98+, NodeJS 17.0+, Deno 1.13+,
// with known gaps:
// - FF<103 and Safari can't clone errors (https://bugzilla.mozilla.org/show_bug.cgi?id=1556604);
//   FF103 clones them with an empty `.stack` (https://bugzilla.mozilla.org/show_bug.cgi?id=1778762);
//   FF104+ fixed usual errors but not DOMExceptions (https://bugzilla.mozilla.org/show_bug.cgi?id=1777321).
// - Chrome <102 returns `null` when the cloned object references one error several times
//   (https://bugs.chromium.org/p/v8/issues/detail?id=12542).
// - NodeJS can't clone DOMExceptions (https://github.com/nodejs/node/issues/41038).
// - Only FF103+ supports the new (html/5749) error cloning semantic.
// Pure mode, or failing any probe below, forces the polyfill to replace the native function.
var nativeStructuredClone = global$5.structuredClone;
var FORCED_REPLACEMENT = IS_PURE$1;
if (!FORCED_REPLACEMENT) FORCED_REPLACEMENT = !checkErrorsCloning(nativeStructuredClone, Error$1);
if (!FORCED_REPLACEMENT) FORCED_REPLACEMENT = !checkErrorsCloning(nativeStructuredClone, DOMException);
if (!FORCED_REPLACEMENT) FORCED_REPLACEMENT = !checkNewErrorsCloningSemantic(nativeStructuredClone);
16324
// `PerformanceMark` (Chrome 82+, Safari 14.1+, Deno 1.11+) clones its `detail`, which gives
// a cloner on engines without a global `structuredClone`. It is limited:
// - Chrome 78-81 swaps `.name` and `.message` of cloned DOMExceptions.
// - Chrome returns `null` for multiply-referenced errors.
// - Safari 14.1 drops some RegExp flags and can't clone errors.
// - Deno 1.2-1.10 is too naive.
// - NodeJS 16.0+ has no `PerformanceMark` constructor.
// - NodeJS <17.2 can't clone e.g. RegExp or some boxed primitives
//   (https://github.com/nodejs/node/issues/40840).
// - None of these support the html/5749 error semantic.
// So it is only used as a "restricted" fallback inside `structuredCloneInternal`.
var structuredCloneFromMark = false;
if (!nativeStructuredClone) {
  structuredCloneFromMark = checkBasicSemantic(function (value) {
    return new PerformanceMark(PERFORMANCE_MARK, {
      detail: value
    }).detail;
  });
}
// Prefer the native implementation whenever it passes the basic checks.
var nativeRestrictedStructuredClone = checkBasicSemantic(nativeStructuredClone);
if (!nativeRestrictedStructuredClone) nativeRestrictedStructuredClone = structuredCloneFromMark;
/** Throws a 'DataCloneError' DOMException saying `type` can never be cloned. */
var throwUncloneable = function throwUncloneable(type) {
  var message = 'Uncloneable type: ' + type;
  throw new DOMException(message, DATA_CLONE_ERROR);
};

/**
 * Throws a 'DataCloneError' DOMException saying `action` (default 'Cloning' when
 * `action` is falsy) of `type` cannot be polyfilled in this engine.
 */
var throwUnpolyfillable = function throwUnpolyfillable(type, action) {
  var verb = action ? action : 'Cloning';
  throw new DOMException(verb + ' of ' + type + ' cannot be properly polyfilled in this engine', DATA_CLONE_ERROR);
};
/**
 * Creates a DataTransfer-like object (used to rebuild a `FileList`): first via
 * `new DataTransfer()`, otherwise via `new ClipboardEvent('').clipboardData`.
 * Returns `null` unless the result exposes both `items` and `files`.
 */
var createDataTransfer = function createDataTransfer() {
  var candidate;
  try {
    candidate = new global$5.DataTransfer();
  } catch (error) {
    try {
      candidate = new global$5.ClipboardEvent('').clipboardData;
    } catch (error2) {/* empty */}
  }
  var usable = !!candidate && !!candidate.items && !!candidate.files;
  return usable ? candidate : null;
};
/**
 * Recursive worker behind the `structuredClone` polyfill.
 *
 * Each object is handled in two passes:
 *   1. The first `switch (type)` creates the clone. Containers start empty and set `deep = true`.
 *   2. The clone is registered in `map`.
 *   3. The second `switch` fills in container contents recursively.
 * Registering before recursing is what makes shared and circular references
 * resolve to the same cloned object.
 *
 * @param value - the value to clone.
 * @param map - optional `Map$1` from already-seen source objects to their clones.
 *   `structuredClone` pre-fills it via `tryToTransfer`; a fresh one is created when omitted.
 * @returns the clone; non-object values other than symbols are returned unchanged.
 * @throws DOMException ('DataCloneError') via `throwUncloneable` / `throwUnpolyfillable`
 *   for symbols and for types that cannot be cloned in this engine.
 */
16359 var structuredCloneInternal = function structuredCloneInternal(value, map) {
16360 if (isSymbol(value)) throwUncloneable('Symbol');
16361 if (!isObject$1(value)) return value;
16362 // effectively preserves circular references
16363 if (map) {
16364 if (mapHas(map, value)) return mapGet(map, value);
16365 } else map = new Map$1();
// `classof$1` yields the class name the cases below dispatch on (e.g. 'Array', 'Map', 'DOMException').
16366 var type = classof$1(value);
16367 var deep = false;
16368 var C, name, cloned, dataTransfer, i, length, keys, key, source, target, options;
// First pass: create the clone for this class.
16369 switch (type) {
16370 case 'Array':
16371 cloned = Array$1(lengthOfArrayLike(value));
16372 deep = true;
16373 break;
16374 case 'Object':
16375 cloned = {};
16376 deep = true;
16377 break;
16378 case 'Map':
16379 cloned = new Map$1();
16380 deep = true;
16381 break;
16382 case 'Set':
16383 cloned = new Set$1();
16384 deep = true;
16385 break;
16386 case 'RegExp':
16387 // in this block because of a Safari 14.1 bug
16388 // old FF does not clone regexes passed to the constructor, so get the source and flags directly
16389 cloned = new RegExp(value.source, getRegExpFlags(value));
16390 break;
16391 case 'Error':
// Create a bare error of the matching constructor, chosen by `value.name`.
// `message`, `cause`, `errors` and `stack` are copied in the second pass.
16392 name = value.name;
16393 switch (name) {
16394 case 'AggregateError':
16395 cloned = getBuiltin('AggregateError')([]);
16396 break;
16397 case 'EvalError':
16398 cloned = EvalError();
16399 break;
16400 case 'RangeError':
16401 cloned = RangeError$1();
16402 break;
16403 case 'ReferenceError':
16404 cloned = ReferenceError$1();
16405 break;
16406 case 'SyntaxError':
16407 cloned = SyntaxError();
16408 break;
16409 case 'TypeError':
16410 cloned = TypeError$3();
16411 break;
16412 case 'URIError':
16413 cloned = URIError();
16414 break;
16415 case 'CompileError':
16416 cloned = CompileError();
16417 break;
16418 case 'LinkError':
16419 cloned = LinkError();
16420 break;
16421 case 'RuntimeError':
16422 cloned = RuntimeError$1();
16423 break;
16424 default:
16425 cloned = Error$1();
16426 }
16427 deep = true;
16428 break;
16429 case 'DOMException':
16430 cloned = new DOMException(value.message, value.name);
16431 deep = true;
16432 break;
// Views: the underlying buffer is cloned through `map`, so two views over one
// buffer still share a single cloned buffer.
16433 case 'DataView':
16434 case 'Int8Array':
16435 case 'Uint8Array':
16436 case 'Uint8ClampedArray':
16437 case 'Int16Array':
16438 case 'Uint16Array':
16439 case 'Int32Array':
16440 case 'Uint32Array':
16441 case 'Float32Array':
16442 case 'Float64Array':
16443 case 'BigInt64Array':
16444 case 'BigUint64Array':
16445 C = global$5[type];
16446 // in some old engines like Safari 9, typeof C is 'object'
16447 // on Uint8ClampedArray or some other constructors
16448 if (!isObject$1(C)) throwUnpolyfillable(type);
16449 cloned = new C(
16450 // this is safe, since arraybuffer cannot have circular references
16451 structuredCloneInternal(value.buffer, map), value.byteOffset, type === 'DataView' ? value.byteLength : value.length);
16452 break;
// DOMQuad / FileList / ImageData: rebuild by hand, falling back to the restricted
// native cloner, and otherwise report the type as unpolyfillable.
16453 case 'DOMQuad':
16454 try {
16455 cloned = new DOMQuad(structuredCloneInternal(value.p1, map), structuredCloneInternal(value.p2, map), structuredCloneInternal(value.p3, map), structuredCloneInternal(value.p4, map));
16456 } catch (error) {
16457 if (nativeRestrictedStructuredClone) {
16458 cloned = nativeRestrictedStructuredClone(value);
16459 } else throwUnpolyfillable(type);
16460 }
16461 break;
16462 case 'FileList':
16463 dataTransfer = createDataTransfer();
16464 if (dataTransfer) {
16465 for (i = 0, length = lengthOfArrayLike(value); i < length; i++) {
16466 dataTransfer.items.add(structuredCloneInternal(value[i], map));
16467 }
16468 cloned = dataTransfer.files;
16469 } else if (nativeRestrictedStructuredClone) {
16470 cloned = nativeRestrictedStructuredClone(value);
16471 } else throwUnpolyfillable(type);
16472 break;
16473 case 'ImageData':
16474 // Safari 9 ImageData is a constructor, but typeof ImageData is 'object'
16475 try {
16476 cloned = new ImageData(structuredCloneInternal(value.data, map), value.width, value.height, {
16477 colorSpace: value.colorSpace
16478 });
16479 } catch (error) {
16480 if (nativeRestrictedStructuredClone) {
16481 cloned = nativeRestrictedStructuredClone(value);
16482 } else throwUnpolyfillable(type);
16483 }
16484 break;
// Remaining (leaf) types: prefer the restricted native cloner; otherwise polyfill
// the cases below by hand or throw.
16485 default:
16486 if (nativeRestrictedStructuredClone) {
16487 cloned = nativeRestrictedStructuredClone(value);
16488 } else switch (type) {
16489 case 'BigInt':
16490 // can be a 3rd party polyfill
16491 cloned = Object$1(value.valueOf());
16492 break;
16493 case 'Boolean':
16494 cloned = Object$1(thisBooleanValue(value));
16495 break;
16496 case 'Number':
16497 cloned = Object$1(thisNumberValue(value));
16498 break;
16499 case 'String':
16500 cloned = Object$1(thisStringValue(value));
16501 break;
16502 case 'Date':
16503 cloned = new Date$1(thisTimeValue(value));
16504 break;
16505 case 'ArrayBuffer':
16506 C = global$5.DataView;
16507 // `ArrayBuffer#slice` is not available in IE10
16508 // `ArrayBuffer#slice` and `DataView` are not available in old FF
16509 if (!C && typeof value.slice != 'function') throwUnpolyfillable(type);
16510 // detached buffers throws in `DataView` and `.slice`
16511 try {
// Resizable buffers are copied byte by byte so `maxByteLength` can be preserved.
16512 if (typeof value.slice == 'function' && !value.resizable) {
16513 cloned = value.slice(0);
16514 } else {
16515 length = value.byteLength;
16516 options = 'maxByteLength' in value ? {
16517 maxByteLength: value.maxByteLength
16518 } : undefined;
16519 cloned = new ArrayBuffer(length, options);
16520 source = new C(value);
16521 target = new C(cloned);
16522 for (i = 0; i < length; i++) {
16523 target.setUint8(i, source.getUint8(i));
16524 }
16525 }
16526 } catch (error) {
16527 throw new DOMException('ArrayBuffer is detached', DATA_CLONE_ERROR);
16528 }
16529 break;
16530 case 'SharedArrayBuffer':
16531 // SharedArrayBuffer should use shared memory, we can't polyfill it, so return the original
16532 cloned = value;
16533 break;
16534 case 'Blob':
16535 try {
16536 cloned = value.slice(0, value.size, value.type);
16537 } catch (error) {
16538 throwUnpolyfillable(type);
16539 }
16540 break;
16541 case 'DOMPoint':
16542 case 'DOMPointReadOnly':
16543 C = global$5[type];
16544 try {
16545 cloned = C.fromPoint ? C.fromPoint(value) : new C(value.x, value.y, value.z, value.w);
16546 } catch (error) {
16547 throwUnpolyfillable(type);
16548 }
16549 break;
16550 case 'DOMRect':
16551 case 'DOMRectReadOnly':
16552 C = global$5[type];
16553 try {
16554 cloned = C.fromRect ? C.fromRect(value) : new C(value.x, value.y, value.width, value.height);
16555 } catch (error) {
16556 throwUnpolyfillable(type);
16557 }
16558 break;
16559 case 'DOMMatrix':
16560 case 'DOMMatrixReadOnly':
16561 C = global$5[type];
16562 try {
16563 cloned = C.fromMatrix ? C.fromMatrix(value) : new C(value);
16564 } catch (error) {
16565 throwUnpolyfillable(type);
16566 }
16567 break;
16568 case 'AudioData':
16569 case 'VideoFrame':
16570 if (!isCallable$1(value.clone)) throwUnpolyfillable(type);
16571 try {
16572 cloned = value.clone();
16573 } catch (error) {
16574 throwUncloneable(type);
16575 }
16576 break;
16577 case 'File':
16578 try {
16579 cloned = new File([value], value.name, value);
16580 } catch (error) {
16581 throwUnpolyfillable(type);
16582 }
16583 break;
16584 case 'CropTarget':
16585 case 'CryptoKey':
16586 case 'FileSystemDirectoryHandle':
16587 case 'FileSystemFileHandle':
16588 case 'FileSystemHandle':
16589 case 'GPUCompilationInfo':
16590 case 'GPUCompilationMessage':
16591 case 'ImageBitmap':
16592 case 'RTCCertificate':
16593 case 'WebAssembly.Module':
16594 throwUnpolyfillable(type);
16595 // break omitted
// (`throwUnpolyfillable` always throws, so `default` is never reached from here.)
16596 default:
16597 throwUncloneable(type);
16598 }
16599 }
// Register the clone before filling it, so references back to `value` resolve to `cloned`.
16600 mapSet(map, value, cloned);
// Second pass: copy the contents of container types.
16601 if (deep) switch (type) {
// Only the keys returned by `objectKeys` (Object.keys) are copied, via `createProperty`.
16602 case 'Array':
16603 case 'Object':
16604 keys = objectKeys(value);
16605 for (i = 0, length = lengthOfArrayLike(keys); i < length; i++) {
16606 key = keys[i];
16607 createProperty(cloned, key, structuredCloneInternal(value[key], map));
16608 }
16609 break;
16610 case 'Map':
16611 value.forEach(function (v, k) {
16612 mapSet(cloned, structuredCloneInternal(k, map), structuredCloneInternal(v, map));
16613 });
16614 break;
16615 case 'Set':
16616 value.forEach(function (v) {
16617 setAdd(cloned, structuredCloneInternal(v, map));
16618 });
16619 break;
16620 case 'Error':
16621 createNonEnumerableProperty(cloned, 'message', structuredCloneInternal(value.message, map));
16622 if (hasOwn$2(value, 'cause')) {
16623 createNonEnumerableProperty(cloned, 'cause', structuredCloneInternal(value.cause, map));
16624 }
16625 if (name == 'AggregateError') {
16626 cloned.errors = structuredCloneInternal(value.errors, map);
16627 }
16628 // break omitted
// (falls through so errors also get `stack` copied below)
16629 case 'DOMException':
16630 if (ERROR_STACK_INSTALLABLE) {
16631 createNonEnumerableProperty(cloned, 'stack', structuredCloneInternal(value.stack, map));
16632 }
16633 }
16634 return cloned;
16635 };
/**
 * Handles the `transfer` option of the `structuredClone` polyfill.
 *
 * Every transferable in `rawTransfer` is moved (or cloned-then-closed). Each
 * `original -> transferred` pair is recorded in `map`, so the following
 * `structuredCloneInternal(value, map)` pass substitutes the transferred object
 * wherever the original occurs.
 *
 * @param rawTransfer - the `transfer` option; must be an object and is walked with `iterate`.
 * @param map - `Map$1` that receives the original -> transferred entries.
 * @throws TypeError when `rawTransfer` is not an object.
 * @throws DOMException ('DataCloneError') in these cases:
 *   - duplicate entries;
 *   - items that could not be transferred;
 *   - types this engine cannot polyfill (via `throwUnpolyfillable`).
 */
var tryToTransfer = function tryToTransfer(rawTransfer, map) {
  if (!isObject$1(rawTransfer)) throw TypeError$3('Transfer option cannot be converted to a sequence');
  var transfer = [];
  iterate(rawTransfer, function (value) {
    push$3(transfer, anObject$1(value));
  });
  var i = 0;
  var length = lengthOfArrayLike(transfer);
  var value, type, C, transferredArray, transferred, canvas, context;
  if (PROPER_TRANSFER) {
    // The native implementation moves every item in one call; record results by position.
    transferredArray = nativeStructuredClone(transfer, {
      transfer: transfer
    });
    while (i < length) mapSet(map, transfer[i], transferredArray[i++]);
  } else while (i < length) {
    value = transfer[i++];
    if (mapHas(map, value)) throw new DOMException('Duplicate transferable', DATA_CLONE_ERROR);
    type = classof$1(value);
    // Reset for every item. Otherwise the previous item's result leaks into this one when:
    //   - this item's type is not handled below, or
    //   - its best-effort transfer fails silently.
    // The item would then be mapped to the wrong object instead of raising the
    // DataCloneError below.
    transferred = undefined;
    switch (type) {
      case 'ImageBitmap':
        C = global$5.OffscreenCanvas;
        if (!isConstructor(C)) throwUnpolyfillable(type, TRANSFERRING);
        try {
          canvas = new C(value.width, value.height);
          context = canvas.getContext('bitmaprenderer');
          context.transferFromImageBitmap(value);
          transferred = canvas.transferToImageBitmap();
        } catch (error) {/* empty */}
        break;
      case 'AudioData':
      case 'VideoFrame':
        if (!isCallable$1(value.clone) || !isCallable$1(value.close)) throwUnpolyfillable(type, TRANSFERRING);
        try {
          transferred = value.clone();
          value.close();
        } catch (error) {/* empty */}
        break;
      case 'ArrayBuffer':
        if (!isCallable$1(value.transfer)) throwUnpolyfillable(type, TRANSFERRING);
        transferred = value.transfer();
        break;
      case 'MediaSourceHandle':
      case 'MessagePort':
      case 'OffscreenCanvas':
      case 'ReadableStream':
      case 'TransformStream':
      case 'WritableStream':
        throwUnpolyfillable(type, TRANSFERRING);
    }
    if (transferred === undefined) throw new DOMException('This object cannot be transferred: ' + type, DATA_CLONE_ERROR);
    mapSet(map, value, transferred);
  }
};
16689
// `structuredClone` method
// https://html.spec.whatwg.org/multipage/structured-data.html#dom-structuredclone
// Flagged as a sham when transfers cannot be delegated to the native implementation.
$$5({
  global: true,
  enumerable: true,
  sham: !PROPER_TRANSFER,
  forced: FORCED_REPLACEMENT
}, {
  structuredClone: function structuredClone(value /* , { transfer } */) {
    var hasOptions = validateArgumentsLength$2(arguments.length, 1) > 1 && !isNullOrUndefined(arguments[1]);
    var options = hasOptions ? anObject$1(arguments[1]) : undefined;
    var transfer = options ? options.transfer : undefined;
    if (transfer === undefined) return structuredCloneInternal(value);
    // Record transferred objects first so the clone pass substitutes them.
    var map = new Map$1();
    tryToTransfer(transfer, map);
    return structuredCloneInternal(value, map);
  }
});
16709
var web_timers = {};

var web_setInterval = {};

var global$4 = global$Z;
var schedulersFix$1 = schedulersFix$3;
var $$4 = _export;

// Bun / IE9- setInterval additional parameters fix; `schedulersFix` hands back the
// native function untouched when no wrapping is needed.
// https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#dom-setinterval
var setInterval$1 = schedulersFix$1(global$4.setInterval, true);

$$4({ global: true, bind: true, forced: global$4.setInterval !== setInterval$1 }, { setInterval: setInterval$1 });
16728
var web_setTimeout = {};

var global$3 = global$Z;
var schedulersFix = schedulersFix$3;
var $$3 = _export;

// Bun / IE9- setTimeout additional parameters fix; `schedulersFix` hands back the
// native function untouched when no wrapping is needed.
// https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#dom-settimeout
var setTimeout$1 = schedulersFix(global$3.setTimeout, true);

$$3({ global: true, bind: true, forced: global$3.setTimeout !== setTimeout$1 }, { setTimeout: setTimeout$1 });
16745
var web_url = {};

var web_url_constructor = {};

var fails = fails$1m;
var wellKnownSymbol$1 = wellKnownSymbol$z;
var DESCRIPTORS$3 = descriptors;
var IS_PURE = isPure;
var ITERATOR$1 = wellKnownSymbol$1('iterator');

/**
 * `true` when the native `URL` / `URLSearchParams` pass every probe below, and
 * `false` when any probe detects a known engine bug. A throwing probe is presumably
 * also reported as a failure by `fails`.
 */
var urlConstructorDetection = !fails(function () {
  // eslint-disable-next-line unicorn/relative-url-style -- required for testing
  var url = new URL('b?a=1&b=2&c=3', 'http://a');
  var searchParams = url.searchParams;
  var result = '';
  url.pathname = 'c%20d';
  // Deleting an entry during iteration must not skip the entry after it.
  searchParams.forEach(function (value, key) {
    searchParams['delete']('b');
    result += key + value;
  });
  if (IS_PURE && !url.toJSON) return true;
  if (!searchParams.size && (IS_PURE || !DESCRIPTORS$3)) return true;
  if (!searchParams.sort) return true;
  if (url.href !== 'http://a/c%20d?a=1&c=3') return true;
  if (searchParams.get('c') !== '3') return true;
  if (String(new URLSearchParams('?a=1')) !== 'a=1') return true;
  if (!searchParams[ITERATOR$1]) return true;
  // throws in Edge
  if (new URL('https://a@b').username !== 'a') return true;
  if (new URLSearchParams(new URLSearchParams('a=b')).get('a') !== 'b') return true;
  // not punycoded in Edge
  if (new URL('http://тест').host !== 'xn--e1aybc') return true;
  // not escaped in Chrome 62-
  if (new URL('http://a#б').hash !== '#%D0%B1') return true;
  // fails in Chrome 66-
  if (result !== 'a1c3') return true;
  // throws in Safari
  return new URL('http://x', undefined).host !== 'x';
});
var urlConstructorDetection$1 = /*@__PURE__*/getDefaultExportFromCjs(urlConstructorDetection);
16778
// Punycode helpers, based on https://github.com/bestiejs/punycode.js/blob/master/punycode.js
var uncurryThis$3 = functionUncurryThis;

// Bootstring parameter values for Punycode (RFC 3492, section 5).
var base = 36;
var tMin = 1;
var tMax = 26;
var skew = 38;
var damp = 700;
var initialBias = 72;
var initialN = 0x80; // 128
var delimiter = '-';
var baseMinusTMin = base - tMin;
// Upper bound for `delta` and `m` in the encoder: 2^31 - 1.
var maxInt = 0x7FFFFFFF;
var OVERFLOW_ERROR = 'Overflow: input needs wider integers to process';
var $RangeError = RangeError;

// Matches characters outside U+0000..U+007E (so U+007F also counts as non-ASCII here).
var regexNonASCII = /[^\0-\u007E]/;
// RFC 3490 label separators: '.', U+3002, U+FF0E, U+FF61.
var regexSeparators = /[.\u3002\uFF0E\uFF61]/g;

// Built-ins captured once; the `uncurryThis$3` wrappers presumably take the receiver first.
var floor$4 = Math.floor;
var fromCharCode = String.fromCharCode;
var exec$1 = uncurryThis$3(regexSeparators.exec);
var charCodeAt = uncurryThis$3(''.charCodeAt);
var join$2 = uncurryThis$3([].join);
var push$2 = uncurryThis$3([].push);
var replace$2 = uncurryThis$3(''.replace);
var split$6 = uncurryThis$3(''.split);
var toLowerCase$1 = uncurryThis$3(''.toLowerCase);
16804
/**
 * Splits a JavaScript (UTF-16) string into an array of numeric code points.
 * A high surrogate followed by a low surrogate is combined into one code point.
 * Any unpaired surrogate is emitted as its own code unit, and the unit after an
 * unpaired high surrogate is processed normally.
 */
var ucs2decode = function ucs2decode(string) {
  var codePoints = [];
  var size = string.length;
  var index = 0;
  while (index < size) {
    var unit = charCodeAt(string, index);
    index++;
    var isHighSurrogate = unit >= 0xD800 && unit <= 0xDBFF;
    if (!isHighSurrogate || index >= size) {
      push$2(codePoints, unit);
      continue;
    }
    // Peek at the next unit; consume it only if it completes the pair.
    var next = charCodeAt(string, index);
    if ((next & 0xFC00) == 0xDC00) {
      index++;
      push$2(codePoints, ((unit & 0x3FF) << 10) + (next & 0x3FF) + 0x10000);
    } else {
      push$2(codePoints, unit);
    }
  }
  return codePoints;
};
16836
/**
 * Maps a Punycode digit to its basic code point:
 * 0..25 -> 'a'..'z' (97..122), 26..35 -> '0'..'9' (48..57).
 */
var digitToBasic = function digitToBasic(digit) {
  if (digit < 26) return digit + 97;
  return digit + 22;
};
16845
/**
 * Bias adaptation function from RFC 3492, section 3.4
 * (https://tools.ietf.org/html/rfc3492#section-3.4).
 *
 * @param delta - the delta just encoded.
 * @param numPoints - number of code points handled so far, including this one.
 * @param firstTime - whether this is the first adaptation (uses `damp` instead of halving).
 * @returns the new bias.
 */
var adapt = function adapt(delta, numPoints, firstTime) {
  var scaled = firstTime ? floor$4(delta / damp) : delta >> 1;
  scaled += floor$4(scaled / numPoints);
  var threshold = baseMinusTMin * tMax >> 1;
  var k = 0;
  for (; scaled > threshold; k += base) {
    scaled = floor$4(scaled / baseMinusTMin);
  }
  return floor$4(k + (baseMinusTMin + 1) * scaled / (scaled + skew));
};
16860
16861 /**
16862 * Converts a string of Unicode symbols (e.g. a domain name label) to a
16863 * Punycode string of ASCII-only symbols.
 * Implements the Punycode (RFC 3492) encoding procedure for one label;
 * `stringPunycodeToAscii` adds the 'xn--' prefix.
 * @param {string} input - the label to encode (split into code points by `ucs2decode`).
 * @returns {string} the encoded label, made of three parts in order:
 *   the code points below 0x80, copied as-is;
 *   `delimiter`, present only if there were any such code points;
 *   the encoded deltas for the remaining code points.
 * @throws {RangeError} `OVERFLOW_ERROR` when `delta` would exceed `maxInt`.
16864 */
16865 var encode = function encode(input) {
16866 var output = [];
16867
16868 // Convert the input in UCS-2 to an array of Unicode code points.
16869 input = ucs2decode(input);
16870
16871 // Cache the length.
16872 var inputLength = input.length;
16873
16874 // Initialize the state.
16875 var n = initialN;
16876 var delta = 0;
16877 var bias = initialBias;
16878 var i, currentValue;
16879
16880 // Handle the basic code points.
16881 for (i = 0; i < input.length; i++) {
16882 currentValue = input[i];
16883 if (currentValue < 0x80) {
16884 push$2(output, fromCharCode(currentValue));
16885 }
16886 }
16887 var basicLength = output.length; // number of basic code points.
16888 var handledCPCount = basicLength; // number of code points that have been handled;
16889
16890 // Finish the basic string with a delimiter unless it's empty.
16891 if (basicLength) {
16892 push$2(output, delimiter);
16893 }
16894
16895 // Main encoding loop:
16896 while (handledCPCount < inputLength) {
16897 // All non-basic code points < n have been handled already. Find the next larger one:
16898 var m = maxInt;
16899 for (i = 0; i < input.length; i++) {
16900 currentValue = input[i];
16901 if (currentValue >= n && currentValue < m) {
16902 m = currentValue;
16903 }
16904 }
16905
16906 // Increase `delta` enough to advance the decoder's <n,i> state to <m,0>, but guard against overflow.
16907 var handledCPCountPlusOne = handledCPCount + 1;
16908 if (m - n > floor$4((maxInt - delta) / handledCPCountPlusOne)) {
16909 throw $RangeError(OVERFLOW_ERROR);
16910 }
16911 delta += (m - n) * handledCPCountPlusOne;
16912 n = m;
// Walk the input again: every smaller code point bumps `delta`; each occurrence of `n`
// emits the current `delta` and resets it.
16913 for (i = 0; i < input.length; i++) {
16914 currentValue = input[i];
16915 if (currentValue < n && ++delta > maxInt) {
16916 throw $RangeError(OVERFLOW_ERROR);
16917 }
16918 if (currentValue == n) {
16919 // Represent delta as a generalized variable-length integer.
16920 var q = delta;
16921 var k = base;
16922 while (true) {
// Threshold for this digit position, clamped to [tMin, tMax] around `bias`.
16923 var t = k <= bias ? tMin : k >= bias + tMax ? tMax : k - bias;
16924 if (q < t) break;
16925 var qMinusT = q - t;
16926 var baseMinusT = base - t;
16927 push$2(output, fromCharCode(digitToBasic(t + qMinusT % baseMinusT)));
16928 q = floor$4(qMinusT / baseMinusT);
16929 k += base;
16930 }
16931 push$2(output, fromCharCode(digitToBasic(q)));
// `firstTime` is true only for the first non-basic code point emitted.
16932 bias = adapt(delta, handledCPCountPlusOne, handledCPCount == basicLength);
16933 delta = 0;
16934 handledCPCount++;
16935 }
16936 }
// Advance to the next candidate code point.
16937 delta++;
16938 n++;
16939 }
16940 return join$2(output, '');
16941 };
/**
 * Converts a domain name to its ASCII form. The input is lower-cased, every
 * `regexSeparators` match is normalized to '.', and each resulting label that
 * contains a `regexNonASCII` match is replaced by 'xn--' + its Punycode encoding.
 * @param {string} input
 * @returns {string}
 */
var stringPunycodeToAscii = function stringPunycodeToAscii(input) {
  var normalized = replace$2(toLowerCase$1(input), regexSeparators, ".");
  var labels = split$6(normalized, '.');
  var asciiLabels = [];
  var position = 0;
  var current;
  while (position < labels.length) {
    current = labels[position++];
    if (exec$1(regexNonASCII, current)) current = 'xn--' + encode(current);
    push$2(asciiLabels, current);
  }
  return join$2(asciiLabels, '.');
};
var stringPunycodeToAscii$1 = /*@__PURE__*/getDefaultExportFromCjs(stringPunycodeToAscii);
16953
16954 'use strict';
16955 // TODO: in core-js@4, move /modules/ dependencies to public entries for better optimization by tools like `preset-env`
16956
16957 var $$2 = _export;
16958 var global$2 = global$Z;
16959 var call$1 = functionCall;
16960 var uncurryThis$2 = functionUncurryThis;
16961 var DESCRIPTORS$2 = descriptors;
16962 var USE_NATIVE_URL$1 = urlConstructorDetection;
16963 var defineBuiltIn$1 = defineBuiltIn$m;
16964 var defineBuiltInAccessor$2 = defineBuiltInAccessor$h;
16965 var defineBuiltIns = defineBuiltIns$5;
16966 var setToStringTag$1 = setToStringTag$d;
16967 var createIteratorConstructor = iteratorCreateConstructor;
16968 var InternalStateModule$1 = internalState;
16969 var anInstance$1 = anInstance$a;
16970 var isCallable = isCallable$z;
16971 var hasOwn$1 = hasOwnProperty_1;
16972 var bind$1 = functionBindContext;
16973 var classof = classof$m;
16974 var anObject = anObject$D;
16975 var isObject = isObject$z;
16976 var $toString$1 = toString$x;
16977 var create = objectCreate;
16978 var createPropertyDescriptor = createPropertyDescriptor$c;
16979 var getIterator = getIterator$4;
16980 var getIteratorMethod = getIteratorMethod$5;
16981 var validateArgumentsLength$1 = validateArgumentsLength$8;
16982 var wellKnownSymbol = wellKnownSymbol$z;
16983 var arraySort = arraySort$1;
16984 var ITERATOR = wellKnownSymbol('iterator');
16985 var URL_SEARCH_PARAMS = 'URLSearchParams';
16986 var URL_SEARCH_PARAMS_ITERATOR = URL_SEARCH_PARAMS + 'Iterator';
16987 var setInternalState$1 = InternalStateModule$1.set;
16988 var getInternalParamsState = InternalStateModule$1.getterFor(URL_SEARCH_PARAMS);
16989 var getInternalIteratorState = InternalStateModule$1.getterFor(URL_SEARCH_PARAMS_ITERATOR);
16990 // eslint-disable-next-line es/no-object-getownpropertydescriptor -- safe
16991 var getOwnPropertyDescriptor = Object.getOwnPropertyDescriptor;
16992
// Avoid NodeJS experimental warning: when property descriptors are supported,
// read a global's data value straight from its descriptor so that no accessor
// getter is ever invoked (an accessor property therefore yields undefined).
var safeGetBuiltIn = function safeGetBuiltIn(name) {
  if (DESCRIPTORS$2) {
    var descriptor = getOwnPropertyDescriptor(global$2, name);
    return descriptor ? descriptor.value : descriptor;
  }
  return global$2[name];
};
// Native builtins captured once at module load. The fetch/Request wrappers later
// installed as globals call these captured originals.
var nativeFetch = safeGetBuiltIn('fetch');
var NativeRequest = safeGetBuiltIn('Request');
var Headers = safeGetBuiltIn('Headers');
var RequestPrototype = NativeRequest && NativeRequest.prototype;
var HeadersPrototype = Headers && Headers.prototype;
var RegExp$1 = global$2.RegExp;
var TypeError$2 = global$2.TypeError;
var decodeURIComponent$1 = global$2.decodeURIComponent;
var encodeURIComponent$1 = global$2.encodeURIComponent;
// "Uncurried" prototype methods: called as f(receiver, ...args).
var charAt$1 = uncurryThis$2(''.charAt);
var join$1 = uncurryThis$2([].join);
var push$1 = uncurryThis$2([].push);
var replace$1 = uncurryThis$2(''.replace);
var shift$1 = uncurryThis$2([].shift);
var splice = uncurryThis$2([].splice);
var split$5 = uncurryThis$2(''.split);
var stringSlice$1 = uncurryThis$2(''.slice);
// Every '+', which form-urlencoded input uses to encode a space.
var plus = /\+/g;
// Lazily built cache: slot n - 1 holds a global, case-insensitive regex that
// matches a run of exactly n consecutive percent-encoded bytes.
var sequences = Array(4);
var percentSequence = function percentSequence(bytes) {
  var slot = bytes - 1;
  if (!sequences[slot]) {
    sequences[slot] = RegExp$1('((?:%[\\da-f]{2}){' + bytes + '})', 'gi');
  }
  return sequences[slot];
};
// Best-effort decode of one percent-encoded run: a malformed sequence is
// returned untouched instead of throwing.
var percentDecode = function percentDecode(sequence) {
  var decoded = sequence;
  try {
    decoded = decodeURIComponent$1(sequence);
  } catch (error) {
    // keep the raw sequence
  }
  return decoded;
};
// Decodes one form-urlencoded name or value: '+' becomes a space, then the
// percent escapes are decoded. If the whole string is not valid, runs of
// 4, 3, 2 and finally 1 encoded bytes are decoded independently, and any run
// that still fails is left as-is.
var deserialize$1 = function deserialize(it) {
  var text = replace$1(it, plus, ' ');
  try {
    return decodeURIComponent$1(text);
  } catch (error) {
    for (var size = 4; size > 0; size--) {
      text = replace$1(text, percentSequence(size), percentDecode);
    }
    return text;
  }
};
// encodeURIComponent() leaves ! ' ( ) ~ unescaped and writes a space as %20.
// This serializer additionally escapes those five characters and turns %20
// into '+': `find` matches exactly these cases, and `replacements` gives the
// substitute for each match.
var find = /[!'()~]|%20/g;
var replacements = {};
replacements['!'] = '%21';
replacements["'"] = '%27';
replacements['('] = '%28';
replacements[')'] = '%29';
replacements['~'] = '%7E';
replacements['%20'] = '+';
var replacer = function replacer(match) {
  var substitute = replacements[match];
  return substitute;
};
var _serialize = function serialize(it) {
  var percentEncoded = encodeURIComponent$1(it);
  return replace$1(percentEncoded, find, replacer);
};
// Iterator over a URLSearchParams' pairs. `kind` selects what each step yields:
// 'keys' -> name, 'values' -> value, anything else -> [name, value].
var URLSearchParamsIterator = createIteratorConstructor(function Iterator(params, kind) {
  setInternalState$1(this, {
    type: URL_SEARCH_PARAMS_ITERATOR,
    iterator: getIterator(getInternalParamsState(params).entries),
    kind: kind
  });
}, 'Iterator', function next() {
  var state = getInternalIteratorState(this);
  var step = state.iterator.next();
  if (step.done) return step;
  var pair = step.value;
  switch (state.kind) {
    case 'keys':
      step.value = pair.key;
      break;
    case 'values':
      step.value = pair.value;
      break;
    default:
      step.value = [pair.key, pair.value];
  }
  return step;
}, true);
/**
 * Internal state behind one URLSearchParams instance.
 * - An object `init` is read as pairs by `parseObject`.
 * - A string `init` is parsed as a query, with one leading '?' dropped.
 * - Any other defined value is stringified and parsed without '?' stripping.
 */
var URLSearchParamsState = function URLSearchParamsState(init) {
  this.entries = [];
  this.url = null;
  if (init === undefined) return;
  if (isObject(init)) {
    this.parseObject(init);
  } else if (typeof init !== 'string') {
    this.parseQuery($toString$1(init));
  } else if (charAt$1(init, 0) === '?') {
    this.parseQuery(stringSlice$1(init, 1));
  } else {
    this.parseQuery(init);
  }
};
/**
 * Methods shared by every internal URLSearchParams state record. `entries`
 * holds the ordered `{ key, value }` string pairs; `url`, when non-null, is a
 * bound URL state whose `query` this list mirrors.
 */
URLSearchParamsState.prototype = {
  // Tag for the internal-state getters (presumably checked by `getterFor(URL_SEARCH_PARAMS)`).
  type: URL_SEARCH_PARAMS,
  // Binds this list to `url` and immediately reloads the entries from its query.
  bindURL: function bindURL(url) {
    this.url = url;
    this.update();
  },
  // Appends the pairs described by `object`. If `getIteratorMethod` finds an
  // iterator, every yielded item must itself yield exactly two values
  // (name, value). Otherwise each own enumerable string key becomes a pair.
  parseObject: function parseObject(object) {
    var iteratorMethod = getIteratorMethod(object);
    var iterator, next, step, entryIterator, entryNext, first, second;
    if (iteratorMethod) {
      iterator = getIterator(object, iteratorMethod);
      next = iterator.next;
      while (!(step = call$1(next, iterator)).done) {
        entryIterator = getIterator(anObject(step.value));
        entryNext = entryIterator.next;
        // Pull the first and second values, then require a third pull to report `done`.
        // The short-circuit order matters: each call advances the inner iterator.
        if ((first = call$1(entryNext, entryIterator)).done || (second = call$1(entryNext, entryIterator)).done || !call$1(entryNext, entryIterator).done) throw TypeError$2('Expected sequence with length 2');
        push$1(this.entries, {
          key: $toString$1(first.value),
          value: $toString$1(second.value)
        });
      }
    } else for (var key in object) if (hasOwn$1(object, key)) {
      push$1(this.entries, {
        key: key,
        value: $toString$1(object[key])
      });
    }
  },
  // Appends the pairs of a form-urlencoded query string:
  // - pairs are split on '&', and empty pairs are skipped;
  // - each pair is split at its first '=' (no '=' gives an empty value);
  // - both sides are decoded with `deserialize$1`.
  // A falsy `query` (e.g. null or '') adds nothing.
  parseQuery: function parseQuery(query) {
    if (query) {
      var attributes = split$5(query, '&');
      var index = 0;
      var attribute, entry;
      while (index < attributes.length) {
        attribute = attributes[index++];
        if (attribute.length) {
          entry = split$5(attribute, '=');
          push$1(this.entries, {
            key: deserialize$1(shift$1(entry)),
            value: deserialize$1(join$1(entry, '='))
          });
        }
      }
    }
  },
  // Serializes the entries as `name=value` pairs joined by '&' (no leading '?').
  serialize: function serialize() {
    var entries = this.entries;
    var result = [];
    var index = 0;
    var entry;
    while (index < entries.length) {
      entry = entries[index++];
      push$1(result, _serialize(entry.key) + '=' + _serialize(entry.value));
    }
    return join$1(result, '&');
  },
  // Replaces the entries with those parsed from the bound URL's current query.
  update: function update() {
    this.entries.length = 0;
    this.parseQuery(this.url.query);
  },
  // Tells the bound URL (if any) that the entries changed.
  updateURL: function updateURL() {
    if (this.url) this.url.update();
  }
};
17142
// `URLSearchParams` constructor
// https://url.spec.whatwg.org/#interface-urlsearchparams
// Declared without formal parameters (so its `.length` is 0); the optional
// `init` argument is read from `arguments`.
var URLSearchParamsConstructor = function URLSearchParams(/* init */) {
  anInstance$1(this, URLSearchParamsPrototype$1);
  var hasInit = arguments.length > 0;
  var params = new URLSearchParamsState(hasInit ? arguments[0] : undefined);
  var internal = setInternalState$1(this, params);
  // Without property-descriptor support there is no `size` accessor, so a plain
  // `length` property reflects the number of pairs instead.
  if (!DESCRIPTORS$2) this.length = internal.entries.length;
};
var URLSearchParamsPrototype$1 = URLSearchParamsConstructor.prototype;
// Public URLSearchParams methods, installed on the prototype as enumerable
// properties. Each method reads the receiver's pair list through
// `getInternalParamsState`. Mutators call `state.updateURL()` so a bound URL can
// refresh its query. When property descriptors are unavailable (`!DESCRIPTORS$2`),
// mutators also keep a plain `length` property in sync with the pair count.
defineBuiltIns(URLSearchParamsPrototype$1, {
  // `URLSearchParams.prototype.append` method
  // https://url.spec.whatwg.org/#dom-urlsearchparams-append
  // Adds a name/value pair after all existing pairs.
  append: function append(name, value) {
    validateArgumentsLength$1(arguments.length, 2);
    var state = getInternalParamsState(this);
    push$1(state.entries, {
      key: $toString$1(name),
      value: $toString$1(value)
    });
    if (!DESCRIPTORS$2) this.length++;
    state.updateURL();
  },
  // `URLSearchParams.prototype.delete` method
  // https://url.spec.whatwg.org/#dom-urlsearchparams-delete
  // Removes every pair whose name equals `name`.
  'delete': function _delete(name) {
    validateArgumentsLength$1(arguments.length, 1);
    var state = getInternalParamsState(this);
    var entries = state.entries;
    var key = $toString$1(name);
    var index = 0;
    while (index < entries.length) {
      if (entries[index].key === key) splice(entries, index, 1);else index++;
    }
    if (!DESCRIPTORS$2) this.length = entries.length;
    state.updateURL();
  },
  // `URLSearchParams.prototype.get` method
  // https://url.spec.whatwg.org/#dom-urlsearchparams-get
  // Returns the value of the first pair named `name`, or null when there is none.
  get: function get(name) {
    validateArgumentsLength$1(arguments.length, 1);
    var entries = getInternalParamsState(this).entries;
    var key = $toString$1(name);
    var index = 0;
    for (; index < entries.length; index++) {
      if (entries[index].key === key) return entries[index].value;
    }
    return null;
  },
  // `URLSearchParams.prototype.getAll` method
  // https://url.spec.whatwg.org/#dom-urlsearchparams-getall
  // Returns the values of all pairs named `name`, in list order.
  getAll: function getAll(name) {
    validateArgumentsLength$1(arguments.length, 1);
    var entries = getInternalParamsState(this).entries;
    var key = $toString$1(name);
    var result = [];
    var index = 0;
    for (; index < entries.length; index++) {
      if (entries[index].key === key) push$1(result, entries[index].value);
    }
    return result;
  },
  // `URLSearchParams.prototype.has` method
  // https://url.spec.whatwg.org/#dom-urlsearchparams-has
  // Reports whether any pair is named `name`.
  has: function has(name) {
    validateArgumentsLength$1(arguments.length, 1);
    var entries = getInternalParamsState(this).entries;
    var key = $toString$1(name);
    var index = 0;
    while (index < entries.length) {
      if (entries[index++].key === key) return true;
    }
    return false;
  },
  // `URLSearchParams.prototype.set` method
  // https://url.spec.whatwg.org/#dom-urlsearchparams-set
  // Sets the first pair named `name` to `value` and removes any later pairs with
  // that name; appends a new pair when no pair has that name.
  set: function set(name, value) {
    // The IDL signature is `set(USVString name, USVString value)`: both arguments
    // are required, so a one-argument call must throw rather than store the
    // string 'undefined'.
    validateArgumentsLength$1(arguments.length, 2);
    var state = getInternalParamsState(this);
    var entries = state.entries;
    var found = false;
    var key = $toString$1(name);
    var val = $toString$1(value);
    var index = 0;
    var entry;
    for (; index < entries.length; index++) {
      entry = entries[index];
      if (entry.key === key) {
        // `index--` re-examines the slot that the splice shifted down.
        if (found) splice(entries, index--, 1);else {
          found = true;
          entry.value = val;
        }
      }
    }
    if (!found) push$1(entries, {
      key: key,
      value: val
    });
    if (!DESCRIPTORS$2) this.length = entries.length;
    state.updateURL();
  },
  // `URLSearchParams.prototype.sort` method
  // https://url.spec.whatwg.org/#dom-urlsearchparams-sort
  // Sorts pairs by name using JS string comparison (UTF-16 code units). The
  // comparator never returns 0, so the order of equal names relies on `arraySort`
  // keeping the left element first (presumably a stable sort; TODO confirm).
  sort: function sort() {
    var state = getInternalParamsState(this);
    arraySort(state.entries, function (a, b) {
      return a.key > b.key ? 1 : -1;
    });
    state.updateURL();
  },
  // `URLSearchParams.prototype.forEach` method
  // Calls `callback(value, name, this)` for each pair, bound to the optional `thisArg`.
  forEach: function forEach(callback /* , thisArg */) {
    var entries = getInternalParamsState(this).entries;
    var boundFunction = bind$1(callback, arguments.length > 1 ? arguments[1] : undefined);
    var index = 0;
    var entry;
    while (index < entries.length) {
      entry = entries[index++];
      boundFunction(entry.value, entry.key, this);
    }
  },
  // `URLSearchParams.prototype.keys` method
  keys: function keys() {
    return new URLSearchParamsIterator(this, 'keys');
  },
  // `URLSearchParams.prototype.values` method
  values: function values() {
    return new URLSearchParamsIterator(this, 'values');
  },
  // `URLSearchParams.prototype.entries` method
  entries: function entries() {
    return new URLSearchParamsIterator(this, 'entries');
  }
}, {
  enumerable: true
});
17279
// `URLSearchParams.prototype[@@iterator]` method
// Reuses the `entries` function object, registered under the name 'entries'.
defineBuiltIn$1(URLSearchParamsPrototype$1, ITERATOR, URLSearchParamsPrototype$1.entries, {
  name: 'entries'
});

// `URLSearchParams.prototype.toString` method
// https://url.spec.whatwg.org/#urlsearchparams-stringification-behavior
// Returns the serialized query without a leading '?'.
defineBuiltIn$1(URLSearchParamsPrototype$1, 'toString', function toString() {
  return getInternalParamsState(this).serialize();
}, {
  enumerable: true
});

// `URLSearchParams.prototype.size` getter
// https://github.com/whatwg/url/pull/734
// Only installed when property descriptors are supported; otherwise instances
// carry a plain `length` property that the constructor and mutators maintain.
if (DESCRIPTORS$2) defineBuiltInAccessor$2(URLSearchParamsPrototype$1, 'size', {
  get: function size() {
    return getInternalParamsState(this).entries.length;
  },
  configurable: true,
  enumerable: true
});
// Give instances the 'URLSearchParams' toString tag (presumably what `classof` reads).
setToStringTag$1(URLSearchParamsConstructor, URL_SEARCH_PARAMS);
// Publish the polyfill as the global `URLSearchParams`; `forced` is set when the
// native URL implementation was judged unusable (`!USE_NATIVE_URL$1`).
$$2({
  global: true,
  constructor: true,
  forced: !USE_NATIVE_URL$1
}, {
  URLSearchParams: URLSearchParamsConstructor
});
17310
// Wrap `fetch` and `Request` for correct work with polyfilled `URLSearchParams`.
// A polyfilled instance is not a native URLSearchParams, so (presumably) native
// fetch would not treat it as form data on its own. Such bodies are therefore
// converted to strings with a form-urlencoded content type. Only done when the
// polyfill replaces the native URL (`!USE_NATIVE_URL$1`) and `Headers` exists.
if (!USE_NATIVE_URL$1 && isCallable(Headers)) {
  var headersHas = uncurryThis$2(HeadersPrototype.has);
  var headersSet = uncurryThis$2(HeadersPrototype.set);
  // Returns `init` unchanged unless its `body` has the 'URLSearchParams' class tag.
  // Otherwise it returns a new object that inherits from `init` and overrides:
  // - `body` with the serialized string;
  // - `headers` with a fresh Headers copy whose `content-type` defaults to
  //   form-urlencoded when the caller did not set one.
  var wrapRequestOptions = function wrapRequestOptions(init) {
    if (isObject(init)) {
      var body = init.body;
      var headers;
      if (classof(body) === URL_SEARCH_PARAMS) {
        headers = init.headers ? new Headers(init.headers) : new Headers();
        if (!headersHas(headers, 'content-type')) {
          headersSet(headers, 'content-type', 'application/x-www-form-urlencoded;charset=UTF-8');
        }
        return create(init, {
          body: createPropertyDescriptor(0, $toString$1(body)),
          headers: createPropertyDescriptor(0, headers)
        });
      }
    }
    return init;
  };
  if (isCallable(nativeFetch)) {
    $$2({
      global: true,
      enumerable: true,
      dontCallGetSet: true,
      forced: true
    }, {
      // Delegates to the captured native fetch. When called without `init`,
      // `{}` is passed instead of undefined.
      fetch: function fetch(input /* , init */) {
        return nativeFetch(input, arguments.length > 1 ? wrapRequestOptions(arguments[1]) : {});
      }
    });
  }
  if (isCallable(NativeRequest)) {
    // Replacement constructor: checks it was invoked with `new`, then returns a
    // native Request built from the wrapped options.
    var RequestConstructor = function Request(input /* , init */) {
      anInstance$1(this, RequestPrototype);
      return new NativeRequest(input, arguments.length > 1 ? wrapRequestOptions(arguments[1]) : {});
    };
    // Share the native prototype so native Request instances still pass `instanceof Request`.
    RequestPrototype.constructor = RequestConstructor;
    RequestConstructor.prototype = RequestPrototype;
    $$2({
      global: true,
      constructor: true,
      dontCallGetSet: true,
      forced: true
    }, {
      Request: RequestConstructor
    });
  }
}
// Module exports consumed by the URL polyfill: the constructor, plus the
// accessor for an instance's internal pair list.
var web_urlSearchParams_constructor = {};
web_urlSearchParams_constructor.URLSearchParams = URLSearchParamsConstructor;
web_urlSearchParams_constructor.getState = getInternalParamsState;
var web_urlSearchParams_constructor$1 = /*@__PURE__*/getDefaultExportFromCjs(web_urlSearchParams_constructor);
17366
17367 'use strict';
17368 // TODO: in core-js@4, move /modules/ dependencies to public entries for better optimization by tools like `preset-env`
17369
17370 var $$1 = _export;
17371 var DESCRIPTORS$1 = descriptors;
17372 var USE_NATIVE_URL = urlConstructorDetection;
17373 var global$1 = global$Z;
17374 var bind = functionBindContext;
17375 var uncurryThis$1 = functionUncurryThis;
17376 var defineBuiltIn = defineBuiltIn$m;
17377 var defineBuiltInAccessor$1 = defineBuiltInAccessor$h;
17378 var anInstance = anInstance$a;
17379 var hasOwn = hasOwnProperty_1;
17380 var assign = objectAssign;
17381 var arrayFrom = arrayFrom$1;
17382 var arraySlice = arraySliceSimple;
17383 var codeAt = stringMultibyte.codeAt;
17384 var toASCII = stringPunycodeToAscii;
17385 var $toString = toString$x;
17386 var setToStringTag = setToStringTag$d;
17387 var validateArgumentsLength = validateArgumentsLength$8;
17388 var URLSearchParamsModule = web_urlSearchParams_constructor;
17389 var InternalStateModule = internalState;
17390 var setInternalState = InternalStateModule.set;
17391 var getInternalURLState = InternalStateModule.getterFor('URL');
17392 var URLSearchParams$1 = URLSearchParamsModule.URLSearchParams;
17393 var getInternalSearchParamsState = URLSearchParamsModule.getState;
17394 var NativeURL = global$1.URL;
17395 var TypeError$1 = global$1.TypeError;
17396 var parseInt$1 = global$1.parseInt;
17397 var floor$3 = Math.floor;
17398 var pow$4 = Math.pow;
17399 var charAt = uncurryThis$1(''.charAt);
17400 var exec = uncurryThis$1(/./.exec);
17401 var join = uncurryThis$1([].join);
17402 var numberToString = uncurryThis$1(1.0.toString);
17403 var pop = uncurryThis$1([].pop);
17404 var push = uncurryThis$1([].push);
17405 var replace = uncurryThis$1(''.replace);
17406 var shift = uncurryThis$1([].shift);
17407 var split$4 = uncurryThis$1(''.split);
17408 var stringSlice = uncurryThis$1(''.slice);
17409 var toLowerCase = uncurryThis$1(''.toLowerCase);
17410 var unshift = uncurryThis$1([].unshift);
17411 var INVALID_AUTHORITY = 'Invalid authority';
17412 var INVALID_SCHEME = 'Invalid scheme';
17413 var INVALID_HOST = 'Invalid host';
17414 var INVALID_PORT = 'Invalid port';
17415 var ALPHA = /[a-z]/i;
17416 // eslint-disable-next-line regexp/no-obscure-range -- safe
17417 var ALPHANUMERIC = /[\d+-.a-z]/i;
17418 var DIGIT = /\d/;
17419 var HEX_START = /^0x/i;
17420 var OCT = /^[0-7]+$/;
17421 var DEC = /^\d+$/;
17422 var HEX = /^[\da-f]+$/i;
17423 /* eslint-disable regexp/no-control-character -- safe */
17424 var FORBIDDEN_HOST_CODE_POINT = /[\0\t\n\r #%/:<>?@[\\\]^|]/;
17425 var FORBIDDEN_HOST_CODE_POINT_EXCLUDING_PERCENT = /[\0\t\n\r #/:<>?@[\\\]^|]/;
17426 var LEADING_C0_CONTROL_OR_SPACE = /^[\u0000-\u0020]+/;
17427 var TRAILING_C0_CONTROL_OR_SPACE = /(^|[^\u0000-\u0020])[\u0000-\u0020]+$/;
17428 var TAB_AND_NEW_LINE = /[\t\n\r]/g;
17429 /* eslint-enable regexp/no-control-character -- safe */
17430 var EOF;
17431
// https://url.spec.whatwg.org/#concept-ipv4-parser
/**
 * Parses a host string as an IPv4 address.
 *
 * Each dot-separated part may be decimal, octal (leading "0") or hexadecimal
 * (leading "0x"/"0X"). A part consisting only of such a prefix counts as 0.
 * The last part may cover all remaining bytes (e.g. "127.1").
 *
 * @param {string} input - host string.
 * @returns {number|string|null} one of:
 *   - the address as a 32-bit number;
 *   - `input` itself when it is not syntactically an IPv4 address (callers can
 *     then keep it as a domain);
 *   - `null` when it is numeric but out of range.
 */
var parseIPv4 = function parseIPv4(input) {
  var parts = split$4(input, '.');
  var partsLength, numbers, index, part, radix, number, ipv4;
  // A trailing dot ("1.2.3.4.") is tolerated. Per the spec the trailing empty
  // part is only removed when another part remains; removing the sole part of
  // "" would leave no parts and produce an `undefined` address.
  if (parts.length > 1 && parts[parts.length - 1] === '') {
    parts.length--;
  }
  partsLength = parts.length;
  if (partsLength > 4) return input;
  numbers = [];
  for (index = 0; index < partsLength; index++) {
    part = parts[index];
    if (part === '') return input;
    radix = 10;
    if (part.length > 1 && charAt(part, 0) === '0') {
      radix = exec(HEX_START, part) ? 16 : 8;
      part = stringSlice(part, radix === 8 ? 1 : 2);
    }
    if (part === '') {
      number = 0;
    } else {
      if (!exec(radix === 10 ? DEC : radix === 8 ? OCT : HEX, part)) return input;
      number = parseInt$1(part, radix);
    }
    push(numbers, number);
  }
  // Every part but the last must fit in one byte; the last one fills the remaining bytes.
  for (index = 0; index < partsLength; index++) {
    number = numbers[index];
    if (index === partsLength - 1) {
      if (number >= pow$4(256, 5 - partsLength)) return null;
    } else if (number > 255) return null;
  }
  ipv4 = pop(numbers);
  for (index = 0; index < numbers.length; index++) {
    ipv4 += numbers[index] * pow$4(256, 3 - index);
  }
  return ipv4;
};
17470
// https://url.spec.whatwg.org/#concept-ipv6-parser
/**
 * Parses an IPv6 address (presumably the text between '[' and ']' of a host).
 * @param {string} input
 * @returns {number[]|undefined} the eight 16-bit pieces, or undefined on any syntax error.
 */
// eslint-disable-next-line max-statements -- TODO
var parseIPv6 = function parseIPv6(input) {
  var address = [0, 0, 0, 0, 0, 0, 0, 0];
  var pieceIndex = 0;
  // Piece index at which '::' occurred, or null if none has been seen yet.
  var compress = null;
  var pointer = 0;
  var value, length, numbersSeen, ipv4Piece, number, swaps, swap;
  // Current character; '' once `pointer` is past the end of `input`.
  var chr = function chr() {
    return charAt(input, pointer);
  };
  // A leading ':' is only valid as the start of a leading '::'.
  if (chr() == ':') {
    if (charAt(input, 1) != ':') return;
    pointer += 2;
    pieceIndex++;
    compress = pieceIndex;
  }
  while (chr()) {
    // Input remains but all eight pieces are already filled.
    if (pieceIndex == 8) return;
    // A ':' here is the second half of '::'; only one '::' is allowed.
    if (chr() == ':') {
      if (compress !== null) return;
      pointer++;
      pieceIndex++;
      compress = pieceIndex;
      continue;
    }
    // Read up to four hex digits into one 16-bit piece.
    value = length = 0;
    while (length < 4 && exec(HEX, chr())) {
      value = value * 16 + parseInt$1(chr(), 16);
      pointer++;
      length++;
    }
    // A '.' means the digits just read begin an embedded dotted-quad IPv4 tail:
    // rewind and re-read them as four decimal octets filling two more pieces.
    if (chr() == '.') {
      if (length == 0) return;
      pointer -= length;
      if (pieceIndex > 6) return;
      numbersSeen = 0;
      while (chr()) {
        ipv4Piece = null;
        if (numbersSeen > 0) {
          if (chr() == '.' && numbersSeen < 4) pointer++;else return;
        }
        if (!exec(DIGIT, chr())) return;
        while (exec(DIGIT, chr())) {
          number = parseInt$1(chr(), 10);
          // An octet with a leading zero followed by more digits is rejected.
          if (ipv4Piece === null) ipv4Piece = number;else if (ipv4Piece == 0) return;else ipv4Piece = ipv4Piece * 10 + number;
          if (ipv4Piece > 255) return;
          pointer++;
        }
        // Two octets are packed into each 16-bit piece.
        address[pieceIndex] = address[pieceIndex] * 256 + ipv4Piece;
        numbersSeen++;
        if (numbersSeen == 2 || numbersSeen == 4) pieceIndex++;
      }
      if (numbersSeen != 4) return;
      break;
    } else if (chr() == ':') {
      pointer++;
      // A single trailing ':' is invalid.
      if (!chr()) return;
    } else if (chr()) return;
    address[pieceIndex++] = value;
  }
  // Shift the pieces read after '::' to the end of the address; the gap keeps its zeros.
  if (compress !== null) {
    swaps = pieceIndex - compress;
    pieceIndex = 7;
    while (pieceIndex != 0 && swaps > 0) {
      swap = address[pieceIndex];
      address[pieceIndex--] = address[compress + swaps - 1];
      address[compress + --swaps] = swap;
    }
  } else if (pieceIndex != 8) return;
  return address;
};
/**
 * Finds the run of consecutive zero pieces in an 8-piece IPv6 address that
 * should be written as '::'. Runs of length 1 are ignored, and on a tie the
 * earliest run wins.
 * @param {number[]} ipv6 - the eight pieces of the address.
 * @returns {number|null} index of the first piece of the chosen run, or null.
 */
var findLongestZeroSequence = function findLongestZeroSequence(ipv6) {
  var bestStart = null;
  var bestLength = 1;
  var runStart = 0;
  var runLength;
  while (runStart < 8) {
    if (ipv6[runStart] !== 0) {
      runStart++;
      continue;
    }
    runLength = 0;
    while (runStart + runLength < 8 && ipv6[runStart + runLength] === 0) runLength++;
    if (runLength > bestLength) {
      bestStart = runStart;
      bestLength = runLength;
    }
    runStart += runLength;
  }
  return bestStart;
};
17568
// https://url.spec.whatwg.org/#host-serializing
// Serializes a parsed host:
// - a number is an IPv4 address and is written in dotted-decimal form;
// - an object is an IPv6 piece array, written in brackets with the longest
//   zero run compressed to '::';
// - anything else is returned unchanged.
var serializeHost = function serializeHost(host) {
  var output, position, compressAt, skippingZeros;
  if (typeof host == 'number') {
    output = [];
    position = 4;
    while (position--) {
      unshift(output, host % 256);
      host = floor$3(host / 256);
    }
    return join(output, '.');
  }
  if (_typeof(host) != 'object') return host;
  output = '';
  compressAt = findLongestZeroSequence(host);
  skippingZeros = false;
  for (position = 0; position < 8; position++) {
    if (skippingZeros) {
      if (host[position] === 0) continue;
      skippingZeros = false;
    }
    if (position === compressAt) {
      // The preceding piece already ended with ':', so one more makes '::'.
      output += position === 0 ? '::' : ':';
      skippingZeros = true;
      continue;
    }
    output += numberToString(host[position], 16);
    if (position < 7) output += ':';
  }
  return '[' + output + ']';
};
// Percent-encode sets (https://url.spec.whatwg.org/#percent-encoded-bytes), used
// as membership tables by `percentEncode`. Each set extends the previous one.
// `percentEncode` escapes every code point outside U+0021..U+007E regardless of
// the set (which also covers U+0020 SPACE), so the C0 control set needs no
// explicit members.
var C0ControlPercentEncodeSet = {};
var fragmentPercentEncodeSet = assign({}, C0ControlPercentEncodeSet, {
  ' ': 1,
  '"': 1,
  '<': 1,
  '>': 1,
  '`': 1
});
var pathPercentEncodeSet = assign({}, fragmentPercentEncodeSet, {
  '#': 1,
  '?': 1,
  '{': 1,
  '}': 1
});
var userinfoPercentEncodeSet = assign({}, pathPercentEncodeSet, {
  '/': 1,
  ':': 1,
  ';': 1,
  '=': 1,
  '@': 1,
  '[': 1,
  '\\': 1,
  ']': 1,
  '^': 1,
  '|': 1
});
// Returns `chr` (a single code point) unchanged when it is visible ASCII
// (U+0021..U+007E) and not listed in `set`; otherwise returns its UTF-8
// percent-encoding.
var percentEncode = function percentEncode(chr, set) {
  var code = codeAt(chr, 0);
  var isVisibleAscii = code > 0x20 && code < 0x7F;
  if (!isVisibleAscii || hasOwn(set, chr)) return encodeURIComponent(chr);
  return chr;
};
17629
// https://url.spec.whatwg.org/#special-scheme
// Special schemes mapped to their default port; `file` has no default port (null).
// Membership (own key) is what makes a scheme special; a port equal to the listed
// default is dropped by the parser.
var specialSchemes = {
  ftp: 21,
  file: null,
  http: 80,
  https: 443,
  ws: 80,
  wss: 443
};
17639
// https://url.spec.whatwg.org/#windows-drive-letter
// Checks for a two-character drive letter such as "C:" (or "C|" unless
// `normalized` is truthy). Note: a non-letter first character yields the
// falsy `exec` result (null) rather than `false`.
var isWindowsDriveLetter = function isWindowsDriveLetter(string, normalized) {
  if (string.length != 2) return false;
  var letterMatch = exec(ALPHA, charAt(string, 0));
  if (!letterMatch) return letterMatch;
  var separator = charAt(string, 1);
  return separator == ':' || (!normalized && separator == '|');
};
17645
// https://url.spec.whatwg.org/#start-with-a-windows-drive-letter
// True when `string` begins with a drive letter that is either the whole
// string or followed by '/', '\', '?' or '#'.
var startsWithWindowsDriveLetter = function startsWithWindowsDriveLetter(string) {
  if (!(string.length > 1)) return false;
  var isDrive = isWindowsDriveLetter(stringSlice(string, 0, 2));
  if (!isDrive) return isDrive;
  if (string.length == 2) return true;
  switch (charAt(string, 2)) {
    case '/':
    case '\\':
    case '?':
    case '#':
      return true;
    default:
      return false;
  }
};
17651
// https://url.spec.whatwg.org/#single-dot-path-segment
// A "." path segment, either literal or percent-encoded as "%2e" (any case).
var isSingleDot = function isSingleDot(segment) {
  if (segment === '.') return true;
  return toLowerCase(segment) === '%2e';
};
17656
// https://url.spec.whatwg.org/#double-dot-path-segment
// A ".." path segment where either dot may be percent-encoded as "%2e" (any case).
var isDoubleDot = function isDoubleDot(segment) {
  switch (toLowerCase(segment)) {
    case '..':
    case '%2e.':
    case '.%2e':
    case '%2e%2e':
      return true;
    default:
      return false;
  }
};
17662
// States:
// Parser states for `URLState.prototype.parse`. Each is a distinct empty object
// that serves only as an identity token for `switch (state)`. The names mirror
// the state names of https://url.spec.whatwg.org/#url-parsing.
var SCHEME_START = {};
var SCHEME = {};
var NO_SCHEME = {};
var SPECIAL_RELATIVE_OR_AUTHORITY = {};
var PATH_OR_AUTHORITY = {};
var RELATIVE = {};
var RELATIVE_SLASH = {};
var SPECIAL_AUTHORITY_SLASHES = {};
var SPECIAL_AUTHORITY_IGNORE_SLASHES = {};
var AUTHORITY = {};
var HOST = {};
var HOSTNAME = {};
var PORT = {};
var FILE = {};
var FILE_SLASH = {};
var FILE_HOST = {};
var PATH_START = {};
var PATH = {};
var CANNOT_BE_A_BASE_URL_PATH = {};
var QUERY = {};
var FRAGMENT = {};
/**
 * Internal state of a URL object; parsing failures throw TypeError$1 with the
 * parser's message.
 * - With `isBase`, the state is a plain parse of `url`, without searchParams.
 * - Otherwise `url` is parsed relative to `base` (itself parsed as a base
 *   state when given), and a bound URLSearchParams state is attached.
 */
var URLState = function URLState(url, isBase, base) {
  var urlString = $toString(url);
  var failure;
  if (isBase) {
    failure = this.parse(urlString);
    if (failure) throw TypeError$1(failure);
    this.searchParams = null;
    return;
  }
  var baseState = base === undefined ? undefined : new URLState(base, true);
  failure = this.parse(urlString, null, baseState);
  if (failure) throw TypeError$1(failure);
  var searchParams = getInternalSearchParamsState(new URLSearchParams$1());
  searchParams.bindURL(this);
  this.searchParams = searchParams;
};
17701 URLState.prototype = {
17702 type: 'URL',
17703 // https://url.spec.whatwg.org/#url-parsing
17704 // eslint-disable-next-line max-statements -- TODO
17705 parse: function parse(input, stateOverride, base) {
17706 var url = this;
17707 var state = stateOverride || SCHEME_START;
17708 var pointer = 0;
17709 var buffer = '';
17710 var seenAt = false;
17711 var seenBracket = false;
17712 var seenPasswordToken = false;
17713 var codePoints, chr, bufferCodePoints, failure;
17714 input = $toString(input);
17715 if (!stateOverride) {
17716 url.scheme = '';
17717 url.username = '';
17718 url.password = '';
17719 url.host = null;
17720 url.port = null;
17721 url.path = [];
17722 url.query = null;
17723 url.fragment = null;
17724 url.cannotBeABaseURL = false;
17725 input = replace(input, LEADING_C0_CONTROL_OR_SPACE, '');
17726 input = replace(input, TRAILING_C0_CONTROL_OR_SPACE, '$1');
17727 }
17728 input = replace(input, TAB_AND_NEW_LINE, '');
17729 codePoints = arrayFrom(input);
17730 while (pointer <= codePoints.length) {
17731 chr = codePoints[pointer];
17732 switch (state) {
17733 case SCHEME_START:
17734 if (chr && exec(ALPHA, chr)) {
17735 buffer += toLowerCase(chr);
17736 state = SCHEME;
17737 } else if (!stateOverride) {
17738 state = NO_SCHEME;
17739 continue;
17740 } else return INVALID_SCHEME;
17741 break;
17742 case SCHEME:
17743 if (chr && (exec(ALPHANUMERIC, chr) || chr == '+' || chr == '-' || chr == '.')) {
17744 buffer += toLowerCase(chr);
17745 } else if (chr == ':') {
17746 if (stateOverride && (url.isSpecial() != hasOwn(specialSchemes, buffer) || buffer == 'file' && (url.includesCredentials() || url.port !== null) || url.scheme == 'file' && !url.host)) return;
17747 url.scheme = buffer;
17748 if (stateOverride) {
17749 if (url.isSpecial() && specialSchemes[url.scheme] == url.port) url.port = null;
17750 return;
17751 }
17752 buffer = '';
17753 if (url.scheme == 'file') {
17754 state = FILE;
17755 } else if (url.isSpecial() && base && base.scheme == url.scheme) {
17756 state = SPECIAL_RELATIVE_OR_AUTHORITY;
17757 } else if (url.isSpecial()) {
17758 state = SPECIAL_AUTHORITY_SLASHES;
17759 } else if (codePoints[pointer + 1] == '/') {
17760 state = PATH_OR_AUTHORITY;
17761 pointer++;
17762 } else {
17763 url.cannotBeABaseURL = true;
17764 push(url.path, '');
17765 state = CANNOT_BE_A_BASE_URL_PATH;
17766 }
17767 } else if (!stateOverride) {
17768 buffer = '';
17769 state = NO_SCHEME;
17770 pointer = 0;
17771 continue;
17772 } else return INVALID_SCHEME;
17773 break;
17774 case NO_SCHEME:
17775 if (!base || base.cannotBeABaseURL && chr != '#') return INVALID_SCHEME;
17776 if (base.cannotBeABaseURL && chr == '#') {
17777 url.scheme = base.scheme;
17778 url.path = arraySlice(base.path);
17779 url.query = base.query;
17780 url.fragment = '';
17781 url.cannotBeABaseURL = true;
17782 state = FRAGMENT;
17783 break;
17784 }
17785 state = base.scheme == 'file' ? FILE : RELATIVE;
17786 continue;
17787 case SPECIAL_RELATIVE_OR_AUTHORITY:
17788 if (chr == '/' && codePoints[pointer + 1] == '/') {
17789 state = SPECIAL_AUTHORITY_IGNORE_SLASHES;
17790 pointer++;
17791 } else {
17792 state = RELATIVE;
17793 continue;
17794 }
17795 break;
17796 case PATH_OR_AUTHORITY:
17797 if (chr == '/') {
17798 state = AUTHORITY;
17799 break;
17800 } else {
17801 state = PATH;
17802 continue;
17803 }
17804 case RELATIVE:
17805 url.scheme = base.scheme;
17806 if (chr == EOF) {
17807 url.username = base.username;
17808 url.password = base.password;
17809 url.host = base.host;
17810 url.port = base.port;
17811 url.path = arraySlice(base.path);
17812 url.query = base.query;
17813 } else if (chr == '/' || chr == '\\' && url.isSpecial()) {
17814 state = RELATIVE_SLASH;
17815 } else if (chr == '?') {
17816 url.username = base.username;
17817 url.password = base.password;
17818 url.host = base.host;
17819 url.port = base.port;
17820 url.path = arraySlice(base.path);
17821 url.query = '';
17822 state = QUERY;
17823 } else if (chr == '#') {
17824 url.username = base.username;
17825 url.password = base.password;
17826 url.host = base.host;
17827 url.port = base.port;
17828 url.path = arraySlice(base.path);
17829 url.query = base.query;
17830 url.fragment = '';
17831 state = FRAGMENT;
17832 } else {
17833 url.username = base.username;
17834 url.password = base.password;
17835 url.host = base.host;
17836 url.port = base.port;
17837 url.path = arraySlice(base.path);
17838 url.path.length--;
17839 state = PATH;
17840 continue;
17841 }
17842 break;
17843 case RELATIVE_SLASH:
17844 if (url.isSpecial() && (chr == '/' || chr == '\\')) {
17845 state = SPECIAL_AUTHORITY_IGNORE_SLASHES;
17846 } else if (chr == '/') {
17847 state = AUTHORITY;
17848 } else {
17849 url.username = base.username;
17850 url.password = base.password;
17851 url.host = base.host;
17852 url.port = base.port;
17853 state = PATH;
17854 continue;
17855 }
17856 break;
17857 case SPECIAL_AUTHORITY_SLASHES:
17858 state = SPECIAL_AUTHORITY_IGNORE_SLASHES;
17859 if (chr != '/' || charAt(buffer, pointer + 1) != '/') continue;
17860 pointer++;
17861 break;
17862 case SPECIAL_AUTHORITY_IGNORE_SLASHES:
17863 if (chr != '/' && chr != '\\') {
17864 state = AUTHORITY;
17865 continue;
17866 }
17867 break;
17868 case AUTHORITY:
17869 if (chr == '@') {
17870 if (seenAt) buffer = '%40' + buffer;
17871 seenAt = true;
17872 bufferCodePoints = arrayFrom(buffer);
17873 for (var i = 0; i < bufferCodePoints.length; i++) {
17874 var codePoint = bufferCodePoints[i];
17875 if (codePoint == ':' && !seenPasswordToken) {
17876 seenPasswordToken = true;
17877 continue;
17878 }
17879 var encodedCodePoints = percentEncode(codePoint, userinfoPercentEncodeSet);
17880 if (seenPasswordToken) url.password += encodedCodePoints;else url.username += encodedCodePoints;
17881 }
17882 buffer = '';
17883 } else if (chr == EOF || chr == '/' || chr == '?' || chr == '#' || chr == '\\' && url.isSpecial()) {
17884 if (seenAt && buffer == '') return INVALID_AUTHORITY;
17885 pointer -= arrayFrom(buffer).length + 1;
17886 buffer = '';
17887 state = HOST;
17888 } else buffer += chr;
17889 break;
17890 case HOST:
17891 case HOSTNAME:
17892 if (stateOverride && url.scheme == 'file') {
17893 state = FILE_HOST;
17894 continue;
17895 } else if (chr == ':' && !seenBracket) {
17896 if (buffer == '') return INVALID_HOST;
17897 failure = url.parseHost(buffer);
17898 if (failure) return failure;
17899 buffer = '';
17900 state = PORT;
17901 if (stateOverride == HOSTNAME) return;
17902 } else if (chr == EOF || chr == '/' || chr == '?' || chr == '#' || chr == '\\' && url.isSpecial()) {
17903 if (url.isSpecial() && buffer == '') return INVALID_HOST;
17904 if (stateOverride && buffer == '' && (url.includesCredentials() || url.port !== null)) return;
17905 failure = url.parseHost(buffer);
17906 if (failure) return failure;
17907 buffer = '';
17908 state = PATH_START;
17909 if (stateOverride) return;
17910 continue;
17911 } else {
17912 if (chr == '[') seenBracket = true;else if (chr == ']') seenBracket = false;
17913 buffer += chr;
17914 }
17915 break;
17916 case PORT:
17917 if (exec(DIGIT, chr)) {
17918 buffer += chr;
17919 } else if (chr == EOF || chr == '/' || chr == '?' || chr == '#' || chr == '\\' && url.isSpecial() || stateOverride) {
17920 if (buffer != '') {
17921 var port = parseInt$1(buffer, 10);
17922 if (port > 0xFFFF) return INVALID_PORT;
17923 url.port = url.isSpecial() && port === specialSchemes[url.scheme] ? null : port;
17924 buffer = '';
17925 }
17926 if (stateOverride) return;
17927 state = PATH_START;
17928 continue;
17929 } else return INVALID_PORT;
17930 break;
17931 case FILE:
17932 url.scheme = 'file';
17933 if (chr == '/' || chr == '\\') state = FILE_SLASH;else if (base && base.scheme == 'file') {
17934 if (chr == EOF) {
17935 url.host = base.host;
17936 url.path = arraySlice(base.path);
17937 url.query = base.query;
17938 } else if (chr == '?') {
17939 url.host = base.host;
17940 url.path = arraySlice(base.path);
17941 url.query = '';
17942 state = QUERY;
17943 } else if (chr == '#') {
17944 url.host = base.host;
17945 url.path = arraySlice(base.path);
17946 url.query = base.query;
17947 url.fragment = '';
17948 state = FRAGMENT;
17949 } else {
17950 if (!startsWithWindowsDriveLetter(join(arraySlice(codePoints, pointer), ''))) {
17951 url.host = base.host;
17952 url.path = arraySlice(base.path);
17953 url.shortenPath();
17954 }
17955 state = PATH;
17956 continue;
17957 }
17958 } else {
17959 state = PATH;
17960 continue;
17961 }
17962 break;
17963 case FILE_SLASH:
17964 if (chr == '/' || chr == '\\') {
17965 state = FILE_HOST;
17966 break;
17967 }
17968 if (base && base.scheme == 'file' && !startsWithWindowsDriveLetter(join(arraySlice(codePoints, pointer), ''))) {
17969 if (isWindowsDriveLetter(base.path[0], true)) push(url.path, base.path[0]);else url.host = base.host;
17970 }
17971 state = PATH;
17972 continue;
17973 case FILE_HOST:
17974 if (chr == EOF || chr == '/' || chr == '\\' || chr == '?' || chr == '#') {
17975 if (!stateOverride && isWindowsDriveLetter(buffer)) {
17976 state = PATH;
17977 } else if (buffer == '') {
17978 url.host = '';
17979 if (stateOverride) return;
17980 state = PATH_START;
17981 } else {
17982 failure = url.parseHost(buffer);
17983 if (failure) return failure;
17984 if (url.host == 'localhost') url.host = '';
17985 if (stateOverride) return;
17986 buffer = '';
17987 state = PATH_START;
17988 }
17989 continue;
17990 } else buffer += chr;
17991 break;
17992 case PATH_START:
17993 if (url.isSpecial()) {
17994 state = PATH;
17995 if (chr != '/' && chr != '\\') continue;
17996 } else if (!stateOverride && chr == '?') {
17997 url.query = '';
17998 state = QUERY;
17999 } else if (!stateOverride && chr == '#') {
18000 url.fragment = '';
18001 state = FRAGMENT;
18002 } else if (chr != EOF) {
18003 state = PATH;
18004 if (chr != '/') continue;
18005 }
18006 break;
18007 case PATH:
18008 if (chr == EOF || chr == '/' || chr == '\\' && url.isSpecial() || !stateOverride && (chr == '?' || chr == '#')) {
18009 if (isDoubleDot(buffer)) {
18010 url.shortenPath();
18011 if (chr != '/' && !(chr == '\\' && url.isSpecial())) {
18012 push(url.path, '');
18013 }
18014 } else if (isSingleDot(buffer)) {
18015 if (chr != '/' && !(chr == '\\' && url.isSpecial())) {
18016 push(url.path, '');
18017 }
18018 } else {
18019 if (url.scheme == 'file' && !url.path.length && isWindowsDriveLetter(buffer)) {
18020 if (url.host) url.host = '';
18021 buffer = charAt(buffer, 0) + ':'; // normalize windows drive letter
18022 }
18023
18024 push(url.path, buffer);
18025 }
18026 buffer = '';
18027 if (url.scheme == 'file' && (chr == EOF || chr == '?' || chr == '#')) {
18028 while (url.path.length > 1 && url.path[0] === '') {
18029 shift(url.path);
18030 }
18031 }
18032 if (chr == '?') {
18033 url.query = '';
18034 state = QUERY;
18035 } else if (chr == '#') {
18036 url.fragment = '';
18037 state = FRAGMENT;
18038 }
18039 } else {
18040 buffer += percentEncode(chr, pathPercentEncodeSet);
18041 }
18042 break;
18043 case CANNOT_BE_A_BASE_URL_PATH:
18044 if (chr == '?') {
18045 url.query = '';
18046 state = QUERY;
18047 } else if (chr == '#') {
18048 url.fragment = '';
18049 state = FRAGMENT;
18050 } else if (chr != EOF) {
18051 url.path[0] += percentEncode(chr, C0ControlPercentEncodeSet);
18052 }
18053 break;
18054 case QUERY:
18055 if (!stateOverride && chr == '#') {
18056 url.fragment = '';
18057 state = FRAGMENT;
18058 } else if (chr != EOF) {
18059 if (chr == "'" && url.isSpecial()) url.query += '%27';else if (chr == '#') url.query += '%23';else url.query += percentEncode(chr, C0ControlPercentEncodeSet);
18060 }
18061 break;
18062 case FRAGMENT:
18063 if (chr != EOF) url.fragment += percentEncode(chr, fragmentPercentEncodeSet);
18064 break;
18065 }
18066 pointer++;
18067 }
18068 },
18069 // https://url.spec.whatwg.org/#host-parsing
18070 parseHost: function parseHost(input) {
18071 var result, codePoints, index;
18072 if (charAt(input, 0) == '[') {
18073 if (charAt(input, input.length - 1) != ']') return INVALID_HOST;
18074 result = parseIPv6(stringSlice(input, 1, -1));
18075 if (!result) return INVALID_HOST;
18076 this.host = result;
18077 // opaque host
18078 } else if (!this.isSpecial()) {
18079 if (exec(FORBIDDEN_HOST_CODE_POINT_EXCLUDING_PERCENT, input)) return INVALID_HOST;
18080 result = '';
18081 codePoints = arrayFrom(input);
18082 for (index = 0; index < codePoints.length; index++) {
18083 result += percentEncode(codePoints[index], C0ControlPercentEncodeSet);
18084 }
18085 this.host = result;
18086 } else {
18087 input = toASCII(input);
18088 if (exec(FORBIDDEN_HOST_CODE_POINT, input)) return INVALID_HOST;
18089 result = parseIPv4(input);
18090 if (result === null) return INVALID_HOST;
18091 this.host = result;
18092 }
18093 },
18094 // https://url.spec.whatwg.org/#cannot-have-a-username-password-port
18095 cannotHaveUsernamePasswordPort: function cannotHaveUsernamePasswordPort() {
18096 return !this.host || this.cannotBeABaseURL || this.scheme == 'file';
18097 },
18098 // https://url.spec.whatwg.org/#include-credentials
18099 includesCredentials: function includesCredentials() {
18100 return this.username != '' || this.password != '';
18101 },
18102 // https://url.spec.whatwg.org/#is-special
18103 isSpecial: function isSpecial() {
18104 return hasOwn(specialSchemes, this.scheme);
18105 },
18106 // https://url.spec.whatwg.org/#shorten-a-urls-path
18107 shortenPath: function shortenPath() {
18108 var path = this.path;
18109 var pathSize = path.length;
18110 if (pathSize && (this.scheme != 'file' || pathSize != 1 || !isWindowsDriveLetter(path[0], true))) {
18111 path.length--;
18112 }
18113 },
18114 // https://url.spec.whatwg.org/#concept-url-serializer
18115 serialize: function serialize() {
18116 var url = this;
18117 var scheme = url.scheme;
18118 var username = url.username;
18119 var password = url.password;
18120 var host = url.host;
18121 var port = url.port;
18122 var path = url.path;
18123 var query = url.query;
18124 var fragment = url.fragment;
18125 var output = scheme + ':';
18126 if (host !== null) {
18127 output += '//';
18128 if (url.includesCredentials()) {
18129 output += username + (password ? ':' + password : '') + '@';
18130 }
18131 output += serializeHost(host);
18132 if (port !== null) output += ':' + port;
18133 } else if (scheme == 'file') output += '//';
18134 output += url.cannotBeABaseURL ? path[0] : path.length ? '/' + join(path, '/') : '';
18135 if (query !== null) output += '?' + query;
18136 if (fragment !== null) output += '#' + fragment;
18137 return output;
18138 },
18139 // https://url.spec.whatwg.org/#dom-url-href
18140 setHref: function setHref(href) {
18141 var failure = this.parse(href);
18142 if (failure) throw TypeError$1(failure);
18143 this.searchParams.update();
18144 },
18145 // https://url.spec.whatwg.org/#dom-url-origin
18146 getOrigin: function getOrigin() {
18147 var scheme = this.scheme;
18148 var port = this.port;
18149 if (scheme == 'blob') try {
18150 return new URLConstructor(scheme.path[0]).origin;
18151 } catch (error) {
18152 return 'null';
18153 }
18154 if (scheme == 'file' || !this.isSpecial()) return 'null';
18155 return scheme + '://' + serializeHost(this.host) + (port !== null ? ':' + port : '');
18156 },
18157 // https://url.spec.whatwg.org/#dom-url-protocol
18158 getProtocol: function getProtocol() {
18159 return this.scheme + ':';
18160 },
18161 setProtocol: function setProtocol(protocol) {
18162 this.parse($toString(protocol) + ':', SCHEME_START);
18163 },
18164 // https://url.spec.whatwg.org/#dom-url-username
18165 getUsername: function getUsername() {
18166 return this.username;
18167 },
18168 setUsername: function setUsername(username) {
18169 var codePoints = arrayFrom($toString(username));
18170 if (this.cannotHaveUsernamePasswordPort()) return;
18171 this.username = '';
18172 for (var i = 0; i < codePoints.length; i++) {
18173 this.username += percentEncode(codePoints[i], userinfoPercentEncodeSet);
18174 }
18175 },
18176 // https://url.spec.whatwg.org/#dom-url-password
18177 getPassword: function getPassword() {
18178 return this.password;
18179 },
18180 setPassword: function setPassword(password) {
18181 var codePoints = arrayFrom($toString(password));
18182 if (this.cannotHaveUsernamePasswordPort()) return;
18183 this.password = '';
18184 for (var i = 0; i < codePoints.length; i++) {
18185 this.password += percentEncode(codePoints[i], userinfoPercentEncodeSet);
18186 }
18187 },
18188 // https://url.spec.whatwg.org/#dom-url-host
18189 getHost: function getHost() {
18190 var host = this.host;
18191 var port = this.port;
18192 return host === null ? '' : port === null ? serializeHost(host) : serializeHost(host) + ':' + port;
18193 },
18194 setHost: function setHost(host) {
18195 if (this.cannotBeABaseURL) return;
18196 this.parse(host, HOST);
18197 },
18198 // https://url.spec.whatwg.org/#dom-url-hostname
18199 getHostname: function getHostname() {
18200 var host = this.host;
18201 return host === null ? '' : serializeHost(host);
18202 },
18203 setHostname: function setHostname(hostname) {
18204 if (this.cannotBeABaseURL) return;
18205 this.parse(hostname, HOSTNAME);
18206 },
18207 // https://url.spec.whatwg.org/#dom-url-port
18208 getPort: function getPort() {
18209 var port = this.port;
18210 return port === null ? '' : $toString(port);
18211 },
18212 setPort: function setPort(port) {
18213 if (this.cannotHaveUsernamePasswordPort()) return;
18214 port = $toString(port);
18215 if (port == '') this.port = null;else this.parse(port, PORT);
18216 },
18217 // https://url.spec.whatwg.org/#dom-url-pathname
18218 getPathname: function getPathname() {
18219 var path = this.path;
18220 return this.cannotBeABaseURL ? path[0] : path.length ? '/' + join(path, '/') : '';
18221 },
18222 setPathname: function setPathname(pathname) {
18223 if (this.cannotBeABaseURL) return;
18224 this.path = [];
18225 this.parse(pathname, PATH_START);
18226 },
18227 // https://url.spec.whatwg.org/#dom-url-search
18228 getSearch: function getSearch() {
18229 var query = this.query;
18230 return query ? '?' + query : '';
18231 },
18232 setSearch: function setSearch(search) {
18233 search = $toString(search);
18234 if (search == '') {
18235 this.query = null;
18236 } else {
18237 if ('?' == charAt(search, 0)) search = stringSlice(search, 1);
18238 this.query = '';
18239 this.parse(search, QUERY);
18240 }
18241 this.searchParams.update();
18242 },
18243 // https://url.spec.whatwg.org/#dom-url-searchparams
18244 getSearchParams: function getSearchParams() {
18245 return this.searchParams.facade;
18246 },
18247 // https://url.spec.whatwg.org/#dom-url-hash
18248 getHash: function getHash() {
18249 var fragment = this.fragment;
18250 return fragment ? '#' + fragment : '';
18251 },
18252 setHash: function setHash(hash) {
18253 hash = $toString(hash);
18254 if (hash == '') {
18255 this.fragment = null;
18256 return;
18257 }
18258 if ('#' == charAt(hash, 0)) hash = stringSlice(hash, 1);
18259 this.fragment = '';
18260 this.parse(hash, FRAGMENT);
18261 },
18262 update: function update() {
18263 this.query = this.searchParams.serialize() || null;
18264 }
18265 };
18266
// `URL` constructor - https://url.spec.whatwg.org/#url-class
// Only `url` is a declared parameter (so URL.length is 1); the optional base
// is read from `arguments`. The parsed state is attached as internal state.
var URLConstructor = function URL(url /* , base */) {
  var that = anInstance(this, URLPrototype);
  var argumentCount = validateArgumentsLength(arguments.length, 1);
  var base = argumentCount > 1 ? arguments[1] : undefined;
  var state = setInternalState(that, new URLState(url, false, base));
  if (DESCRIPTORS$1) return;
  // Without descriptor support the URLPrototype accessors are never installed,
  // so snapshot every component onto the instance as plain data properties.
  var snapshot = [
    ['href', 'serialize'],
    ['origin', 'getOrigin'],
    ['protocol', 'getProtocol'],
    ['username', 'getUsername'],
    ['password', 'getPassword'],
    ['host', 'getHost'],
    ['hostname', 'getHostname'],
    ['port', 'getPort'],
    ['pathname', 'getPathname'],
    ['search', 'getSearch'],
    ['searchParams', 'getSearchParams'],
    ['hash', 'getHash']
  ];
  for (var i = 0; i < snapshot.length; i++) {
    that[snapshot[i][0]] = state[snapshot[i][1]]();
  }
};
var URLPrototype = URLConstructor.prototype;
// Builds an accessor descriptor whose `get` (and, when a setter name is given,
// `set`) forwards to the named method of the receiver's internal URL state.
var accessorDescriptor = function accessorDescriptor(getter, setter) {
  var forwardGet = function get() {
    return getInternalURLState(this)[getter]();
  };
  // With no setter name, `set` is that falsy value, giving a read-only accessor.
  var forwardSet = setter ? function (value) {
    return getInternalURLState(this)[setter](value);
  } : setter;
  return {
    get: forwardGet,
    set: forwardSet,
    configurable: true,
    enumerable: true
  };
};
if (DESCRIPTORS$1) {
  // `URL.prototype` attribute accessors (https://url.spec.whatwg.org/#dom-url-href
  // and the neighbouring #dom-url-* entries), installed in this order.
  // Each row is [property, state getter, state setter]. Rows without a setter
  // (origin, searchParams) produce read-only accessors.
  [
    ['href', 'serialize', 'setHref'],
    ['origin', 'getOrigin'],
    ['protocol', 'getProtocol', 'setProtocol'],
    ['username', 'getUsername', 'setUsername'],
    ['password', 'getPassword', 'setPassword'],
    ['host', 'getHost', 'setHost'],
    ['hostname', 'getHostname', 'setHostname'],
    ['port', 'getPort', 'setPort'],
    ['pathname', 'getPathname', 'setPathname'],
    ['search', 'getSearch', 'setSearch'],
    ['searchParams', 'getSearchParams'],
    ['hash', 'getHash', 'setHash']
  ].forEach(function (row) {
    defineBuiltInAccessor$1(URLPrototype, row[0], accessorDescriptor(row[1], row[2]));
  });
}
18339
// `URL.prototype.toJSON` - https://url.spec.whatwg.org/#dom-url-tojson
// Enumerable; returns the full serialization of the receiver's URL state.
defineBuiltIn(URLPrototype, 'toJSON', function toJSON() {
  var state = getInternalURLState(this);
  return state.serialize();
}, { enumerable: true });

// `URL.prototype.toString` - https://url.spec.whatwg.org/#URL-stringification-behavior
// Enumerable; same serialization as toJSON.
defineBuiltIn(URLPrototype, 'toString', function toString() {
  var state = getInternalURLState(this);
  return state.serialize();
}, { enumerable: true });
// When a native URL constructor exists, re-expose whichever of its
// createObjectURL / revokeObjectURL are present as statics of the polyfill,
// bound to NativeURL.
if (NativeURL) {
  var nativeCreateObjectURL = NativeURL.createObjectURL;
  var nativeRevokeObjectURL = NativeURL.revokeObjectURL;
  // https://developer.mozilla.org/en-US/docs/Web/API/URL/createObjectURL
  if (nativeCreateObjectURL) {
    defineBuiltIn(URLConstructor, 'createObjectURL', bind(nativeCreateObjectURL, NativeURL));
  }
  // https://developer.mozilla.org/en-US/docs/Web/API/URL/revokeObjectURL
  if (nativeRevokeObjectURL) {
    defineBuiltIn(URLConstructor, 'revokeObjectURL', bind(nativeRevokeObjectURL, NativeURL));
  }
}
setToStringTag(URLConstructor, 'URL');

// Publish the polyfill as the global `URL` constructor. `forced` is set when
// USE_NATIVE_URL is false. `sham` is set when DESCRIPTORS$1 is false (the
// accessor-less fallback path of the constructor).
$$1({
  sham: !DESCRIPTORS$1,
  forced: !USE_NATIVE_URL,
  constructor: true,
  global: true
}, {
  URL: URLConstructor
});
18374
var web_url_toJson = {};

'use strict';
var $ = _export;
var call = functionCall;

// `URL.prototype.toJSON` - https://url.spec.whatwg.org/#dom-url-tojson
// An enumerable prototype method on the global URL that calls the current
// URL.prototype.toString with the same receiver.
$({ target: 'URL', proto: true, enumerable: true }, {
  toJSON: function toJSON() {
    var urlToString = URL.prototype.toString;
    return call(urlToString, this);
  }
});
18392
var web_urlSearchParams = {};

var web_urlSearchParams_size = {};

'use strict';
var DESCRIPTORS = descriptors;
var uncurryThis = functionUncurryThis;
var defineBuiltInAccessor = defineBuiltInAccessor$h;
var URLSearchParamsPrototype = URLSearchParams.prototype;
var forEach = uncurryThis(URLSearchParamsPrototype.forEach);

// `URLSearchParams.prototype.size` getter - https://github.com/whatwg/url/pull/734
// Only installed when descriptors are supported and the prototype lacks `size`.
// The getter counts entries by walking them with the prototype's own forEach.
if (DESCRIPTORS && !('size' in URLSearchParamsPrototype)) {
  defineBuiltInAccessor(URLSearchParamsPrototype, 'size', {
    get: function size() {
      var entries = 0;
      forEach(this, function () {
        entries += 1;
      });
      return entries;
    },
    configurable: true,
    enumerable: true
  });
}
18419
// Default export of the bundled `path$2` namespace (presumably core-js's
// `stable` entry - confirm against the bundle's module map).
var stable = path$2;
var index$2 = /*@__PURE__*/getDefaultExportFromCjs(stable);

// CommonJS-style module record for the regenerator runtime. Its `exports`
// object is presumably populated by the module wrapper that follows.
var runtime$1 = { exports: {} };
var runtime_1 = runtime$1.exports;
18426 (function (module) {
18427 var runtime = function (exports) {
18428 "use strict";
18429
// Shorthands for Object.prototype and its hasOwnProperty.
var Op = Object.prototype;
var hasOwn = Op.hasOwnProperty;
// Left uninitialised: a compressible local alias for `undefined`.
var undefined$1;
// Use the real well-known symbols when Symbol exists; otherwise (or when a
// particular symbol is missing) fall back to "@@name" string keys.
var $Symbol = {};
if (typeof Symbol === "function") $Symbol = Symbol;
var iteratorSymbol = $Symbol.iterator || "@@iterator";
var asyncIteratorSymbol = $Symbol.asyncIterator || "@@asyncIterator";
var toStringTagSymbol = $Symbol.toStringTag || "@@toStringTag";
// Defines `key` on `obj` as an enumerable, configurable, writable data
// property and returns the stored value.
function define(obj, key, value) {
  var descriptor = {
    value: value,
    enumerable: true,
    configurable: true,
    writable: true
  };
  Object.defineProperty(obj, key, descriptor);
  return obj[key];
}
// Probe once: IE 8's Object.defineProperty only works on DOM objects and
// throws for plain objects. In that case fall back to plain assignment.
try {
  define({}, "");
} catch (probeError) {
  define = function define(obj, key, value) {
    obj[key] = value;
    return value;
  };
}
18454 function wrap(innerFn, outerFn, self, tryLocsList) {
18455 // If outerFn provided and outerFn.prototype is a Generator, then outerFn.prototype instanceof Generator.
18456 var protoGenerator = outerFn && outerFn.prototype instanceof Generator ? outerFn : Generator;
18457 var generator = Object.create(protoGenerator.prototype);
18458 var context = new Context(tryLocsList || []);
18459
18460 // The ._invoke method unifies the implementations of the .next,
18461 // .throw, and .return methods.
18462 generator._invoke = makeInvokeMethod(innerFn, self, context);
18463 return generator;
18464 }
18465 exports.wrap = wrap;
18466
// Calls fn.call(obj, arg) and reports the outcome as a completion record
// instead of letting an exception escape:
//   { type: "normal", arg: result } or { type: "throw", arg: error }.
// Taking (fn, obj, arg) rather than a closure keeps callers from allocating a
// function object; the only allocation is the fixed-shape record.
function tryCatch(fn, obj, arg) {
  var completion;
  try {
    completion = { type: "normal", arg: fn.call(obj, arg) };
  } catch (err) {
    completion = { type: "throw", arg: err };
  }
  return completion;
}
// Generator lifecycle states tracked by makeInvokeMethod.
var GenStateSuspendedStart = "suspendedStart";
var GenStateSuspendedYield = "suspendedYield";
var GenStateExecuting = "executing";
var GenStateCompleted = "completed";

// Identity-compared marker: when innerFn (or a delegate step) produces this
// object, makeInvokeMethod loops again instead of returning an iterator result.
var ContinueSentinel = {};
18498
// Placeholder constructors used only for their .prototype/.constructor links,
// which are wired up just below. isGeneratorFunction compares constructor
// names, so a minifier should be told not to mangle these names.
function Generator() {}
function GeneratorFunction() {}
function GeneratorFunctionPrototype() {}
18506
// %IteratorPrototype% fallback: an object whose @@iterator returns `this`.
var IteratorPrototype = {};
IteratorPrototype[iteratorSymbol] = function () {
  return this;
};
// Climb two prototype levels from the iterator returned by values([])
// (presumably an array iterator) to reach the engine's own
// %IteratorPrototype%. Prefer it when it is not Object.prototype and owns @@iterator.
var getProto = Object.getPrototypeOf;
var NativeIteratorPrototype = getProto && getProto(getProto(values([])));
if (NativeIteratorPrototype && NativeIteratorPrototype !== Op) {
  if (hasOwn.call(NativeIteratorPrototype, iteratorSymbol)) {
    IteratorPrototype = NativeIteratorPrototype;
  }
}
// Gp is the shared prototype of generator objects; wire up the
// Generator / GeneratorFunction / GeneratorFunctionPrototype relationships.
var Gp = Object.create(IteratorPrototype);
Generator.prototype = Gp;
GeneratorFunctionPrototype.prototype = Gp;
Gp.constructor = GeneratorFunctionPrototype;
GeneratorFunction.prototype = GeneratorFunctionPrototype;
GeneratorFunctionPrototype.constructor = GeneratorFunction;
GeneratorFunction.displayName = define(GeneratorFunctionPrototype, toStringTagSymbol, "GeneratorFunction");
18524
// Installs next/throw/return on `prototype`. Each one forwards to
// `this._invoke(methodName, arg)`.
function defineIteratorMethods(prototype) {
  function install(method) {
    define(prototype, method, function (arg) {
      return this._invoke(method, arg);
    });
  }
  ["next", "throw", "return"].forEach(install);
}
18534 exports.isGeneratorFunction = function (genFun) {
18535 var ctor = typeof genFun === "function" && genFun.constructor;
18536 return ctor ? ctor === GeneratorFunction ||
18537 // For the native GeneratorFunction constructor, the best we can
18538 // do is to check its .name property.
18539 (ctor.displayName || ctor.name) === "GeneratorFunction" : false;
18540 };
18541 exports.mark = function (genFun) {
18542 if (Object.setPrototypeOf) {
18543 Object.setPrototypeOf(genFun, GeneratorFunctionPrototype);
18544 } else {
18545 genFun.__proto__ = GeneratorFunctionPrototype;
18546 define(genFun, toStringTagSymbol, "GeneratorFunction");
18547 }
18548 genFun.prototype = Object.create(Gp);
18549 return genFun;
18550 };
18551
18552 // Within the body of any async function, `await x` is transformed to
18553 // `yield regeneratorRuntime.awrap(x)`, so that the runtime can test
18554 // `hasOwn.call(value, "__await")` to determine if the yielded value is
18555 // meant to be awaited.
18556 exports.awrap = function (arg) {
18557 return {
18558 __await: arg
18559 };
18560 };
// Wraps a generator object as an async iterator. Each `_invoke(method, arg)`
// returns a PromiseImpl promise for an iterator result. Requests are chained
// so results are delivered in call order.
function AsyncIterator(generator, PromiseImpl) {
  // Tail of the request chain; undefined until the first request.
  var previousPromise;

  function invoke(method, arg, resolve, reject) {
    var record = tryCatch(generator[method], generator, arg);
    if (record.type === "throw") {
      reject(record.arg);
      return;
    }
    var result = record.arg;
    var value = result.value;
    if (value && _typeof(value) === "object" && hasOwn.call(value, "__await")) {
      // An awrap()ed value: wait for it, then resume the generator with the
      // fulfilled value, or throw the rejection reason into it.
      return PromiseImpl.resolve(value.__await).then(function (awaited) {
        invoke("next", awaited, resolve, reject);
      }, function (err) {
        invoke("throw", err, resolve, reject);
      });
    }
    // A plain yielded value: once it settles, its fulfilled value replaces
    // result.value and the request resolves. A rejection is thrown back into
    // the generator so it can handle it.
    return PromiseImpl.resolve(value).then(function (unwrapped) {
      result.value = unwrapped;
      resolve(result);
    }, function (error) {
      return invoke("throw", error, resolve, reject);
    });
  }

  function enqueue(method, arg) {
    function callInvokeWithMethodAndArg() {
      return new PromiseImpl(function (resolve, reject) {
        invoke(method, arg, resolve, reject);
      });
    }
    if (previousPromise) {
      // Wait for every earlier request to settle (fulfilled or rejected, so one
      // failure does not poison later calls) before invoking this one.
      previousPromise = previousPromise.then(callInvokeWithMethodAndArg, callInvokeWithMethodAndArg);
    } else {
      // First request: invoke immediately. The Promise executor runs
      // synchronously, so the generator's setup happens predictably.
      previousPromise = callInvokeWithMethodAndArg();
    }
    return previousPromise;
  }

  // Unified entry point used by next/throw/return (see defineIteratorMethods).
  this._invoke = enqueue;
}
// AsyncIterator instances get next/throw/return (forwarding to `_invoke`) and
// are their own async iterator; the constructor is exported.
defineIteratorMethods(AsyncIterator.prototype);
AsyncIterator.prototype[asyncIteratorSymbol] = function () {
  return this;
};
exports.AsyncIterator = AsyncIterator;
18624
18625 // Note that simple async functions are implemented on top of
18626 // AsyncIterator objects; they just return a Promise for the value of
18627 // the final result produced by the iterator.
18628 exports.async = function (innerFn, outerFn, self, tryLocsList, PromiseImpl) {
18629 if (PromiseImpl === void 0) PromiseImpl = Promise;
18630 var iter = new AsyncIterator(wrap(innerFn, outerFn, self, tryLocsList), PromiseImpl);
18631 return exports.isGeneratorFunction(outerFn) ? iter // If outerFn is a generator, return the full iterator.
18632 : iter.next().then(function (result) {
18633 return result.done ? result.value : iter.next();
18634 });
18635 };
// Builds the `_invoke(method, arg)` function installed on generator objects by
// `wrap`. It drives the compiled body `innerFn` through the states
// suspendedStart -> executing -> suspendedYield / completed, routing
// next / throw / return requests and any active `yield*` delegate.
function makeInvokeMethod(innerFn, self, context) {
  var state = GenStateSuspendedStart;
  return function invoke(method, arg) {
    if (state === GenStateExecuting) throw new Error("Generator is already running");
    if (state === GenStateCompleted) {
      // Resuming a finished generator: rethrow a thrown value, otherwise report
      // completion (forgiving, per 25.3.3.3.3 of the ES6 draft:
      // https://people.mozilla.org/~jorendorff/es6-draft.html#sec-generatorresume).
      if (method === "throw") throw arg;
      return doneResult();
    }
    context.method = method;
    context.arg = arg;
    for (;;) {
      // An active delegate gets first chance at the request.
      var delegate = context.delegate;
      if (delegate) {
        var delegateResult = maybeInvokeDelegate(delegate, context);
        if (delegateResult === ContinueSentinel) continue;
        if (delegateResult) return delegateResult;
      }
      switch (context.method) {
        case "next":
          // context._sent is kept for legacy Babel `function.sent` support.
          context.sent = context._sent = context.arg;
          break;
        case "throw":
          if (state === GenStateSuspendedStart) {
            // Throwing into a generator that never started completes it
            // without running the body.
            state = GenStateCompleted;
            throw context.arg;
          }
          context.dispatchException(context.arg);
          break;
        case "return":
          context.abrupt("return", context.arg);
          break;
      }
      state = GenStateExecuting;
      var record = tryCatch(innerFn, self, context);
      if (record.type === "normal") {
        state = context.done ? GenStateCompleted : GenStateSuspendedYield;
        if (record.arg !== ContinueSentinel) {
          return { value: record.arg, done: context.done };
        }
      } else if (record.type === "throw") {
        // The body threw: mark the generator completed and loop back so the
        // error is routed through context.dispatchException above.
        state = GenStateCompleted;
        context.method = "throw";
        context.arg = record.arg;
      }
    }
  };
}
18698
18699 // Call delegate.iterator[context.method](context.arg) and handle the
18700 // result, either by returning a { value, done } result from the
18701 // delegate iterator, or by modifying context.method and context.arg,
18702 // setting context.delegate to null, and returning the ContinueSentinel.
18703 function maybeInvokeDelegate(delegate, context) {
18704 var method = delegate.iterator[context.method];
18705 if (method === undefined$1) {
18706 // A .throw or .return when the delegate iterator has no .throw
18707 // method always terminates the yield* loop.
18708 context.delegate = null;
18709 if (context.method === "throw") {
18710 // Note: ["return"] must be used for ES3 parsing compatibility.
18711 if (delegate.iterator["return"]) {
18712 // If the delegate iterator has a return method, give it a
18713 // chance to clean up.
18714 context.method = "return";
18715 context.arg = undefined$1;
18716 maybeInvokeDelegate(delegate, context);
18717 if (context.method === "throw") {
18718 // If maybeInvokeDelegate(context) changed context.method from
18719 // "return" to "throw", let that override the TypeError below.
18720 return ContinueSentinel;
18721 }
18722 }
18723 context.method = "throw";
18724 context.arg = new TypeError("The iterator does not provide a 'throw' method");
18725 }
18726 return ContinueSentinel;
18727 }
18728 var record = tryCatch(method, delegate.iterator, context.arg);
18729 if (record.type === "throw") {
18730 context.method = "throw";
18731 context.arg = record.arg;
18732 context.delegate = null;
18733 return ContinueSentinel;
18734 }
18735 var info = record.arg;
18736 if (!info) {
18737 context.method = "throw";
18738 context.arg = new TypeError("iterator result is not an object");
18739 context.delegate = null;
18740 return ContinueSentinel;
18741 }
18742 if (info.done) {
18743 // Assign the result of the finished delegate to the temporary
18744 // variable specified by delegate.resultName (see delegateYield).
18745 context[delegate.resultName] = info.value;
18746
18747 // Resume execution at the desired location (see delegateYield).
18748 context.next = delegate.nextLoc;
18749
18750 // If context.method was "throw" but the delegate handled the
18751 // exception, let the outer generator proceed normally. If
18752 // context.method was "next", forget context.arg since it has been
18753 // "consumed" by the delegate iterator. If context.method was
18754 // "return", allow the original .return call to continue in the
18755 // outer generator.
18756 if (context.method !== "return") {
18757 context.method = "next";
18758 context.arg = undefined$1;
18759 }
18760 } else {
18761 // Re-yield the result returned by the delegate method.
18762 return info;
18763 }
18764
18765 // The delegate iterator is finished, so forget it and continue with
18766 // the outer generator.
18767 context.delegate = null;
18768 return ContinueSentinel;
18769 }
18770
18771 // Define Generator.prototype.{next,throw,return} in terms of the
18772 // unified ._invoke helper method.
18773 defineIteratorMethods(Gp);
18774 define(Gp, toStringTagSymbol, "Generator");
18775
18776 // A Generator should always return itself as the iterator object when the
18777 // @@iterator function is called on it. Some browsers' implementations of the
18778 // iterator prototype chain incorrectly implement this, causing the Generator
18779 // object to not be returned from this call. This ensures that doesn't happen.
18780 // See https://github.com/facebook/regenerator/issues/274 for more details.
18781 Gp[iteratorSymbol] = function () {
18782 return this;
18783 };
18784 Gp.toString = function () {
18785 return "[object Generator]";
18786 };
18787 function pushTryEntry(locs) {
18788 var entry = {
18789 tryLoc: locs[0]
18790 };
18791 if (1 in locs) {
18792 entry.catchLoc = locs[1];
18793 }
18794 if (2 in locs) {
18795 entry.finallyLoc = locs[2];
18796 entry.afterLoc = locs[3];
18797 }
18798 this.tryEntries.push(entry);
18799 }
18800 function resetTryEntry(entry) {
18801 var record = entry.completion || {};
18802 record.type = "normal";
18803 delete record.arg;
18804 entry.completion = record;
18805 }
18806 function Context(tryLocsList) {
18807 // The root entry object (effectively a try statement without a catch
18808 // or a finally block) gives us a place to store values thrown from
18809 // locations where there is no enclosing try statement.
18810 this.tryEntries = [{
18811 tryLoc: "root"
18812 }];
18813 tryLocsList.forEach(pushTryEntry, this);
18814 this.reset(true);
18815 }
18816 exports.keys = function (object) {
18817 var keys = [];
18818 for (var key in object) {
18819 keys.push(key);
18820 }
18821 keys.reverse();
18822
18823 // Rather than returning an object with a next method, we keep
18824 // things simple and return the next function itself.
18825 return function next() {
18826 while (keys.length) {
18827 var key = keys.pop();
18828 if (key in object) {
18829 next.value = key;
18830 next.done = false;
18831 return next;
18832 }
18833 }
18834
18835 // To avoid creating an additional object, we just hang the .value
18836 // and .done properties off the next function object itself. This
18837 // also ensures that the minifier will not anonymize the function.
18838 next.done = true;
18839 return next;
18840 };
18841 };
18842 function values(iterable) {
18843 if (iterable) {
18844 var iteratorMethod = iterable[iteratorSymbol];
18845 if (iteratorMethod) {
18846 return iteratorMethod.call(iterable);
18847 }
18848 if (typeof iterable.next === "function") {
18849 return iterable;
18850 }
18851 if (!isNaN(iterable.length)) {
18852 var i = -1,
18853 next = function next() {
18854 while (++i < iterable.length) {
18855 if (hasOwn.call(iterable, i)) {
18856 next.value = iterable[i];
18857 next.done = false;
18858 return next;
18859 }
18860 }
18861 next.value = undefined$1;
18862 next.done = true;
18863 return next;
18864 };
18865 return next.next = next;
18866 }
18867 }
18868
18869 // Return an iterator with no values.
18870 return {
18871 next: doneResult
18872 };
18873 }
18874 exports.values = values;
18875 function doneResult() {
18876 return {
18877 value: undefined$1,
18878 done: true
18879 };
18880 }
18881 Context.prototype = {
18882 constructor: Context,
18883 reset: function reset(skipTempReset) {
18884 this.prev = 0;
18885 this.next = 0;
18886 // Resetting context._sent for legacy support of Babel's
18887 // function.sent implementation.
18888 this.sent = this._sent = undefined$1;
18889 this.done = false;
18890 this.delegate = null;
18891 this.method = "next";
18892 this.arg = undefined$1;
18893 this.tryEntries.forEach(resetTryEntry);
18894 if (!skipTempReset) {
18895 for (var name in this) {
18896 // Not sure about the optimal order of these conditions:
18897 if (name.charAt(0) === "t" && hasOwn.call(this, name) && !isNaN(+name.slice(1))) {
18898 this[name] = undefined$1;
18899 }
18900 }
18901 }
18902 },
18903 stop: function stop() {
18904 this.done = true;
18905 var rootEntry = this.tryEntries[0];
18906 var rootRecord = rootEntry.completion;
18907 if (rootRecord.type === "throw") {
18908 throw rootRecord.arg;
18909 }
18910 return this.rval;
18911 },
18912 dispatchException: function dispatchException(exception) {
18913 if (this.done) {
18914 throw exception;
18915 }
18916 var context = this;
18917 function handle(loc, caught) {
18918 record.type = "throw";
18919 record.arg = exception;
18920 context.next = loc;
18921 if (caught) {
18922 // If the dispatched exception was caught by a catch block,
18923 // then let that catch block handle the exception normally.
18924 context.method = "next";
18925 context.arg = undefined$1;
18926 }
18927 return !!caught;
18928 }
18929 for (var i = this.tryEntries.length - 1; i >= 0; --i) {
18930 var entry = this.tryEntries[i];
18931 var record = entry.completion;
18932 if (entry.tryLoc === "root") {
18933 // Exception thrown outside of any try block that could handle
18934 // it, so set the completion value of the entire function to
18935 // throw the exception.
18936 return handle("end");
18937 }
18938 if (entry.tryLoc <= this.prev) {
18939 var hasCatch = hasOwn.call(entry, "catchLoc");
18940 var hasFinally = hasOwn.call(entry, "finallyLoc");
18941 if (hasCatch && hasFinally) {
18942 if (this.prev < entry.catchLoc) {
18943 return handle(entry.catchLoc, true);
18944 } else if (this.prev < entry.finallyLoc) {
18945 return handle(entry.finallyLoc);
18946 }
18947 } else if (hasCatch) {
18948 if (this.prev < entry.catchLoc) {
18949 return handle(entry.catchLoc, true);
18950 }
18951 } else if (hasFinally) {
18952 if (this.prev < entry.finallyLoc) {
18953 return handle(entry.finallyLoc);
18954 }
18955 } else {
18956 throw new Error("try statement without catch or finally");
18957 }
18958 }
18959 }
18960 },
18961 abrupt: function abrupt(type, arg) {
18962 for (var i = this.tryEntries.length - 1; i >= 0; --i) {
18963 var entry = this.tryEntries[i];
18964 if (entry.tryLoc <= this.prev && hasOwn.call(entry, "finallyLoc") && this.prev < entry.finallyLoc) {
18965 var finallyEntry = entry;
18966 break;
18967 }
18968 }
18969 if (finallyEntry && (type === "break" || type === "continue") && finallyEntry.tryLoc <= arg && arg <= finallyEntry.finallyLoc) {
18970 // Ignore the finally entry if control is not jumping to a
18971 // location outside the try/catch block.
18972 finallyEntry = null;
18973 }
18974 var record = finallyEntry ? finallyEntry.completion : {};
18975 record.type = type;
18976 record.arg = arg;
18977 if (finallyEntry) {
18978 this.method = "next";
18979 this.next = finallyEntry.finallyLoc;
18980 return ContinueSentinel;
18981 }
18982 return this.complete(record);
18983 },
18984 complete: function complete(record, afterLoc) {
18985 if (record.type === "throw") {
18986 throw record.arg;
18987 }
18988 if (record.type === "break" || record.type === "continue") {
18989 this.next = record.arg;
18990 } else if (record.type === "return") {
18991 this.rval = this.arg = record.arg;
18992 this.method = "return";
18993 this.next = "end";
18994 } else if (record.type === "normal" && afterLoc) {
18995 this.next = afterLoc;
18996 }
18997 return ContinueSentinel;
18998 },
18999 finish: function finish(finallyLoc) {
19000 for (var i = this.tryEntries.length - 1; i >= 0; --i) {
19001 var entry = this.tryEntries[i];
19002 if (entry.finallyLoc === finallyLoc) {
19003 this.complete(entry.completion, entry.afterLoc);
19004 resetTryEntry(entry);
19005 return ContinueSentinel;
19006 }
19007 }
19008 },
19009 "catch": function _catch(tryLoc) {
19010 for (var i = this.tryEntries.length - 1; i >= 0; --i) {
19011 var entry = this.tryEntries[i];
19012 if (entry.tryLoc === tryLoc) {
19013 var record = entry.completion;
19014 if (record.type === "throw") {
19015 var thrown = record.arg;
19016 resetTryEntry(entry);
19017 }
19018 return thrown;
19019 }
19020 }
19021
19022 // The context.catch method must only be called with a location
19023 // argument that corresponds to a known catch block.
19024 throw new Error("illegal catch attempt");
19025 },
19026 delegateYield: function delegateYield(iterable, resultName, nextLoc) {
19027 this.delegate = {
19028 iterator: values(iterable),
19029 resultName: resultName,
19030 nextLoc: nextLoc
19031 };
19032 if (this.method === "next") {
19033 // Deliberately forget the last sent value so that we don't
19034 // accidentally pass it on to the delegate.
19035 this.arg = undefined$1;
19036 }
19037 return ContinueSentinel;
19038 }
19039 };
19040
19041 // Regardless of whether this script is executing as a CommonJS module
19042 // or not, return the runtime object so that we can declare the variable
19043 // regeneratorRuntime in the outer scope, which allows this module to be
19044 // injected easily by `bin/regenerator --include-runtime script.js`.
19045 return exports;
19046 }(
19047 // If this script is executing as a CommonJS module, use module.exports
19048 // as the regeneratorRuntime namespace. Otherwise create a new empty
19049 // object. Either way, the resulting object will be used to initialize
19050 // the regeneratorRuntime variable at the top of this file.
19051 'object' === "object" ? module.exports : {});
19052 try {
19053 regeneratorRuntime = runtime;
19054 } catch (accidentalStrictMode) {
19055 // This module should not be running in strict mode, so the above
19056 // assignment should always work unless something is misconfigured. Just
19057 // in case runtime.js accidentally runs in strict mode, we can escape
19058 // strict mode using a global Function call. This could conceivably fail
19059 // if a Content Security Policy forbids using Function, but in that case
19060 // the proper solution is to fix the accidental strict mode problem. If
19061 // you've misconfigured your bundler to force strict mode and applied a
19062 // CSP to forbid Function, and you're not willing to fix either of those
19063 // problems, please detail your unique predicament in a GitHub issue.
19064 Function("r", "regeneratorRuntime = r")(runtime);
19065 }
19066 })(runtime$1);
19067 var runtimeExports = runtime$1.exports;
19068 var runtime = /*@__PURE__*/getDefaultExportFromCjs(runtimeExports);
19069
19070 /**
19071 * @license
19072 * Copyright 2020 Google LLC. All Rights Reserved.
19073 * Licensed under the Apache License, Version 2.0 (the "License");
19074 * you may not use this file except in compliance with the License.
19075 * You may obtain a copy of the License at
19076 *
19077 * http://www.apache.org/licenses/LICENSE-2.0
19078 *
19079 * Unless required by applicable law or agreed to in writing, software
19080 * distributed under the License is distributed on an "AS IS" BASIS,
19081 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19082 * See the License for the specific language governing permissions and
19083 * limitations under the License.
19084 * =============================================================================
19085 */
19086 var EPSILON_FLOAT32$1 = 1e-7;
19087 var EPSILON_FLOAT16$1 = 1e-4;
19088 /** Convenient class for storing tensor-related data. */
19089 var DataStorage = /*#__PURE__*/function () {
19090 function DataStorage(backend, dataMover) {
19091 _classCallCheck(this, DataStorage);
19092 this.backend = backend;
19093 this.dataMover = dataMover;
19094 this.data = new WeakMap();
19095 this.dataIdsCount = 0;
19096 }
19097 _createClass(DataStorage, [{
19098 key: "get",
19099 value: function get(dataId) {
19100 if (!this.data.has(dataId)) {
19101 this.dataMover.moveData(this.backend, dataId);
19102 }
19103 return this.data.get(dataId);
19104 }
19105 }, {
19106 key: "set",
19107 value: function set(dataId, value) {
19108 this.dataIdsCount++;
19109 this.data.set(dataId, value);
19110 }
19111 }, {
19112 key: "has",
19113 value: function has(dataId) {
19114 return this.data.has(dataId);
19115 }
19116 }, {
19117 key: "delete",
19118 value: function _delete(dataId) {
19119 this.dataIdsCount--;
19120 return this.data.delete(dataId);
19121 }
19122 }, {
19123 key: "numDataIds",
19124 value: function numDataIds() {
19125 return this.dataIdsCount;
19126 }
19127 }]);
19128 return DataStorage;
19129 }();
19130 /**
19131 * The interface that defines the kernels that should be implemented when
19132 * adding a new backend. New backends don't need to implement every one of the
19133 * methods, this can be done gradually (throw an error for unimplemented
19134 * methods).
19135 */
19136 var KernelBackend = /*#__PURE__*/function () {
19137 function KernelBackend() {
19138 _classCallCheck(this, KernelBackend);
19139 }
19140 _createClass(KernelBackend, [{
19141 key: "refCount",
19142 value: function refCount(dataId) {
19143 return notYetImplemented('refCount');
19144 }
19145 }, {
19146 key: "incRef",
19147 value: function incRef(dataId) {
19148 return notYetImplemented('incRef');
19149 }
19150 }, {
19151 key: "timerAvailable",
19152 value: function timerAvailable() {
19153 return true;
19154 }
19155 }, {
19156 key: "time",
19157 value: function time(f) {
19158 return notYetImplemented('time');
19159 }
19160 }, {
19161 key: "read",
19162 value: function read(dataId) {
19163 return notYetImplemented('read');
19164 }
19165 }, {
19166 key: "readSync",
19167 value: function readSync(dataId) {
19168 return notYetImplemented('readSync');
19169 }
19170 }, {
19171 key: "readToGPU",
19172 value: function readToGPU(dataId, options) {
19173 return notYetImplemented('readToGPU');
19174 }
19175 }, {
19176 key: "numDataIds",
19177 value: function numDataIds() {
19178 return notYetImplemented('numDataIds');
19179 }
19180 }, {
19181 key: "disposeData",
19182 value: function disposeData(dataId, force) {
19183 return notYetImplemented('disposeData');
19184 }
19185 }, {
19186 key: "write",
19187 value: function write(values, shape, dtype) {
19188 return notYetImplemented('write');
19189 }
19190 }, {
19191 key: "move",
19192 value: function move(dataId, values, shape, dtype, refCount) {
19193 return notYetImplemented('move');
19194 }
19195 }, {
19196 key: "createTensorFromGPUData",
19197 value: function createTensorFromGPUData(values, shape, dtype) {
19198 return notYetImplemented('createTensorFromGPUData');
19199 }
19200 }, {
19201 key: "memory",
19202 value: function memory() {
19203 return notYetImplemented('memory');
19204 }
19205 /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
19206 }, {
19207 key: "floatPrecision",
19208 value: function floatPrecision() {
19209 return notYetImplemented('floatPrecision');
19210 }
19211 /** Returns the smallest representable number. */
19212 }, {
19213 key: "epsilon",
19214 value: function epsilon() {
19215 return this.floatPrecision() === 32 ? EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1;
19216 }
19217 }, {
19218 key: "dispose",
19219 value: function dispose() {
19220 return notYetImplemented('dispose');
19221 }
19222 }]);
19223 return KernelBackend;
19224 }();
19225 function notYetImplemented(kernelName) {
19226 throw new Error("'".concat(kernelName, "' not yet implemented or not found in the registry. ") + "This kernel may not be supported by the tfjs backend you have chosen");
19227 }
19228
19229 /**
19230 * @license
19231 * Copyright 2020 Google LLC. All Rights Reserved.
19232 * Licensed under the Apache License, Version 2.0 (the "License");
19233 * you may not use this file except in compliance with the License.
19234 * You may obtain a copy of the License at
19235 *
19236 * http://www.apache.org/licenses/LICENSE-2.0
19237 *
19238 * Unless required by applicable law or agreed to in writing, software
19239 * distributed under the License is distributed on an "AS IS" BASIS,
19240 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19241 * See the License for the specific language governing permissions and
19242 * limitations under the License.
19243 * =============================================================================
19244 */
19245 /**
19246 * Shuffles the array in-place using Fisher-Yates algorithm.
19247 *
19248 * ```js
19249 * const a = [1, 2, 3, 4, 5];
19250 * tf.util.shuffle(a);
19251 * console.log(a);
19252 * ```
19253 *
19254 * @param array The array to shuffle in-place.
19255 *
19256 * @doc {heading: 'Util', namespace: 'util'}
19257 */
19258 // tslint:disable-next-line:no-any
19259 function shuffle(array) {
19260 var counter = array.length;
19261 var index = 0;
19262 // While there are elements in the array
19263 while (counter > 0) {
19264 // Pick a random index
19265 index = Math.random() * counter | 0;
19266 // Decrease counter by 1
19267 counter--;
19268 // And swap the last element with it
19269 swap(array, counter, index);
19270 }
19271 }
19272 /**
19273 * Shuffles two arrays in-place the same way using Fisher-Yates algorithm.
19274 *
19275 * ```js
19276 * const a = [1,2,3,4,5];
19277 * const b = [11,22,33,44,55];
19278 * tf.util.shuffleCombo(a, b);
19279 * console.log(a, b);
19280 * ```
19281 *
19282 * @param array The first array to shuffle in-place.
19283 * @param array2 The second array to shuffle in-place with the same permutation
19284 * as the first array.
19285 *
19286 * @doc {heading: 'Util', namespace: 'util'}
19287 */
19288 function shuffleCombo(
19289 // tslint:disable-next-line:no-any
19290 array,
19291 // tslint:disable-next-line:no-any
19292 array2) {
19293 if (array.length !== array2.length) {
19294 throw new Error("Array sizes must match to be shuffled together " + "First array length was ".concat(array.length) + "Second array length was ".concat(array2.length));
19295 }
19296 var counter = array.length;
19297 var index = 0;
19298 // While there are elements in the array
19299 while (counter > 0) {
19300 // Pick a random index
19301 index = Math.random() * counter | 0;
19302 // Decrease counter by 1
19303 counter--;
19304 // And swap the last element of each array with it
19305 swap(array, counter, index);
19306 swap(array2, counter, index);
19307 }
19308 }
19309 /** Clamps a value to a specified range. */
19310 function clamp(min, x, max) {
19311 return Math.max(min, Math.min(x, max));
19312 }
19313 function nearestLargerEven(val) {
19314 return val % 2 === 0 ? val : val + 1;
19315 }
19316 function swap(object, left, right) {
19317 var temp = object[left];
19318 object[left] = object[right];
19319 object[right] = temp;
19320 }
19321 function sum$4(arr) {
19322 var sum = 0;
19323 for (var i = 0; i < arr.length; i++) {
19324 sum += arr[i];
19325 }
19326 return sum;
19327 }
19328 /**
19329 * Returns a sample from a uniform [a, b) distribution.
19330 *
19331 * @param a The minimum support (inclusive).
19332 * @param b The maximum support (exclusive).
19333 * @return A pseudorandom number on the half-open interval [a,b).
19334 */
19335 function randUniform(a, b) {
19336 var r = Math.random();
19337 return b * r + (1 - r) * a;
19338 }
19339 /** Returns the squared Euclidean distance between two vectors. */
19340 function distSquared(a, b) {
19341 var result = 0;
19342 for (var i = 0; i < a.length; i++) {
19343 var diff = Number(a[i]) - Number(b[i]);
19344 result += diff * diff;
19345 }
19346 return result;
19347 }
19348 /**
19349 * Asserts that the expression is true. Otherwise throws an error with the
19350 * provided message.
19351 *
19352 * ```js
19353 * const x = 2;
19354 * tf.util.assert(x === 2, 'x is not 2');
19355 * ```
19356 *
19357 * @param expr The expression to assert (as a boolean).
19358 * @param msg A function that returns the message to report when throwing an
19359 * error. We use a function for performance reasons.
19360 *
19361 * @doc {heading: 'Util', namespace: 'util'}
19362 */
19363 function assert$1(expr, msg) {
19364 if (!expr) {
19365 throw new Error(typeof msg === 'string' ? msg : msg());
19366 }
19367 }
19368 function assertShapesMatch(shapeA, shapeB) {
19369 var errorMessagePrefix = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : '';
19370 assert$1(arraysEqual(shapeA, shapeB), function () {
19371 return errorMessagePrefix + " Shapes ".concat(shapeA, " and ").concat(shapeB, " must match");
19372 });
19373 }
19374 function assertNonNull(a) {
19375 assert$1(a != null, function () {
19376 return "The input to the tensor constructor must be a non-null value.";
19377 });
19378 }
19379 /**
19380 * Returns the size (number of elements) of the tensor given its shape.
19381 *
19382 * ```js
19383 * const shape = [3, 4, 2];
19384 * const size = tf.util.sizeFromShape(shape);
19385 * console.log(size);
19386 * ```
19387 *
19388 * @doc {heading: 'Util', namespace: 'util'}
19389 */
19390 function sizeFromShape(shape) {
19391 if (shape.length === 0) {
19392 // Scalar.
19393 return 1;
19394 }
19395 var size = shape[0];
19396 for (var i = 1; i < shape.length; i++) {
19397 size *= shape[i];
19398 }
19399 return size;
19400 }
19401 function isScalarShape(shape) {
19402 return shape.length === 0;
19403 }
19404 function arraysEqualWithNull(n1, n2) {
19405 if (n1 === n2) {
19406 return true;
19407 }
19408 if (n1 == null || n2 == null) {
19409 return false;
19410 }
19411 if (n1.length !== n2.length) {
19412 return false;
19413 }
19414 for (var i = 0; i < n1.length; i++) {
19415 if (n1[i] !== null && n2[i] !== null && n1[i] !== n2[i]) {
19416 return false;
19417 }
19418 }
19419 return true;
19420 }
19421 function arraysEqual(n1, n2) {
19422 if (n1 === n2) {
19423 return true;
19424 }
19425 if (n1 == null || n2 == null) {
19426 return false;
19427 }
19428 if (n1.length !== n2.length) {
19429 return false;
19430 }
19431 for (var i = 0; i < n1.length; i++) {
19432 if (n1[i] !== n2[i]) {
19433 return false;
19434 }
19435 }
19436 return true;
19437 }
19438 function isInt(a) {
19439 return a % 1 === 0;
19440 }
19441 function tanh$3(x) {
19442 // tslint:disable-next-line:no-any
19443 if (Math.tanh != null) {
19444 // tslint:disable-next-line:no-any
19445 return Math.tanh(x);
19446 }
19447 if (x === Infinity) {
19448 return 1;
19449 } else if (x === -Infinity) {
19450 return -1;
19451 } else {
19452 var e2x = Math.exp(2 * x);
19453 return (e2x - 1) / (e2x + 1);
19454 }
19455 }
19456 function sizeToSquarishShape(size) {
19457 var width = Math.ceil(Math.sqrt(size));
19458 return [width, Math.ceil(size / width)];
19459 }
19460 /**
19461 * Creates a new array with randomized indices to a given quantity.
19462 *
19463 * ```js
19464 * const randomTen = tf.util.createShuffledIndices(10);
19465 * console.log(randomTen);
19466 * ```
19467 *
19468 * @param number Quantity of how many shuffled indices to create.
19469 *
19470 * @doc {heading: 'Util', namespace: 'util'}
19471 */
19472 function createShuffledIndices(n) {
19473 var shuffledIndices = new Uint32Array(n);
19474 for (var i = 0; i < n; ++i) {
19475 shuffledIndices[i] = i;
19476 }
19477 shuffle(shuffledIndices);
19478 return shuffledIndices;
19479 }
19480 function rightPad(a, size) {
19481 if (size <= a.length) {
19482 return a;
19483 }
19484 return a + ' '.repeat(size - a.length);
19485 }
19486 function repeatedTry(checkFn) {
19487 var delayFn = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : function (counter) {
19488 return 0;
19489 };
19490 var maxCounter = arguments.length > 2 ? arguments[2] : undefined;
19491 var scheduleFn = arguments.length > 3 ? arguments[3] : undefined;
19492 return new Promise(function (resolve, reject) {
19493 var tryCount = 0;
19494 var tryFn = function tryFn() {
19495 if (checkFn()) {
19496 resolve();
19497 return;
19498 }
19499 tryCount++;
19500 var nextBackoff = delayFn(tryCount);
19501 if (maxCounter != null && tryCount >= maxCounter) {
19502 reject();
19503 return;
19504 }
19505 if (scheduleFn != null) {
19506 scheduleFn(tryFn, nextBackoff);
19507 } else {
19508 // google3 does not allow assigning another variable to setTimeout.
19509 // Don't refactor this so scheduleFn has a default value of setTimeout.
19510 setTimeout(tryFn, nextBackoff);
19511 }
19512 };
19513 tryFn();
19514 });
19515 }
19516 /**
19517 * Given the full size of the array and a shape that may contain -1 as the
19518 * implicit dimension, returns the inferred shape where -1 is replaced.
19519 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
19520 *
19521 * @param shape The shape, which may contain -1 in some dimension.
19522 * @param size The full size (number of elements) of the array.
19523 * @return The inferred shape where -1 is replaced with the inferred size.
19524 */
19525 function inferFromImplicitShape(shape, size) {
19526 var shapeProd = 1;
19527 var implicitIdx = -1;
19528 for (var i = 0; i < shape.length; ++i) {
19529 if (shape[i] >= 0) {
19530 shapeProd *= shape[i];
19531 } else if (shape[i] === -1) {
19532 if (implicitIdx !== -1) {
19533 throw Error("Shapes can only have 1 implicit size. " + "Found -1 at dim ".concat(implicitIdx, " and dim ").concat(i));
19534 }
19535 implicitIdx = i;
19536 } else if (shape[i] < 0) {
19537 throw Error("Shapes can not be < 0. Found ".concat(shape[i], " at dim ").concat(i));
19538 }
19539 }
19540 if (implicitIdx === -1) {
19541 if (size > 0 && size !== shapeProd) {
19542 throw Error("Size(".concat(size, ") must match the product of shape ").concat(shape));
19543 }
19544 return shape;
19545 }
19546 if (shapeProd === 0) {
19547 throw Error("Cannot infer the missing size in [".concat(shape, "] when ") + "there are 0 elements");
19548 }
19549 if (size % shapeProd !== 0) {
19550 throw Error("The implicit shape can't be a fractional number. " + "Got ".concat(size, " / ").concat(shapeProd));
19551 }
19552 var newShape = shape.slice();
19553 newShape[implicitIdx] = size / shapeProd;
19554 return newShape;
19555 }
19556 function parseAxisParam(axis, shape) {
19557 var rank = shape.length;
19558 // Normalize input
19559 axis = axis == null ? shape.map(function (s, i) {
19560 return i;
19561 }) : [].concat(axis);
19562 // Check for valid range
19563 assert$1(axis.every(function (ax) {
19564 return ax >= -rank && ax < rank;
19565 }), function () {
19566 return "All values in axis param must be in range [-".concat(rank, ", ").concat(rank, ") but ") + "got axis ".concat(axis);
19567 });
19568 // Check for only integers
19569 assert$1(axis.every(function (ax) {
19570 return isInt(ax);
19571 }), function () {
19572 return "All values in axis param must be integers but " + "got axis ".concat(axis);
19573 });
19574 // Handle negative axis.
19575 return axis.map(function (a) {
19576 return a < 0 ? rank + a : a;
19577 });
19578 }
19579 /** Reduces the shape by removing all dimensions of shape 1. */
19580 function squeezeShape(shape, axis) {
19581 var newShape = [];
19582 var keptDims = [];
19583 var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
19584 var axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort();
19585 var j = 0;
19586 for (var i = 0; i < shape.length; ++i) {
19587 if (axes != null) {
19588 if (axes[j] === i && shape[i] !== 1) {
19589 throw new Error("Can't squeeze axis ".concat(i, " since its dim '").concat(shape[i], "' is not 1"));
19590 }
19591 if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
19592 newShape.push(shape[i]);
19593 keptDims.push(i);
19594 }
19595 if (axes[j] <= i) {
19596 j++;
19597 }
19598 }
19599 if (shape[i] !== 1) {
19600 newShape.push(shape[i]);
19601 keptDims.push(i);
19602 }
19603 }
19604 return {
19605 newShape: newShape,
19606 keptDims: keptDims
19607 };
19608 }
/**
 * Allocates a zero-initialized array of `size` elements for `dtype`.
 * Thin alias that delegates entirely to `getArrayFromDType`.
 */
function getTypedArrayFromDType(dtype, size) {
  var allocated = getArrayFromDType(dtype, size);
  return allocated;
}
/**
 * Allocates storage of `size` elements for `dtype`: Float32Array for
 * 'float32' (also the default when `dtype` is null/undefined), Int32Array for
 * 'int32', Uint8Array for 'bool' and a plain Array for 'string'.
 *
 * @throws Error for any other dtype.
 */
function getArrayFromDType(dtype, size) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    case 'string':
      return new Array(size);
    default:
      throw new Error("Unknown data type ".concat(dtype));
  }
}
/**
 * Throws if any entry of `vals` is NaN or +/-Infinity, using the coercing
 * global `isNaN`/`isFinite`. `dtype` is only used in the error message.
 */
function checkConversionForErrors(vals, dtype) {
  var idx = 0;
  while (idx < vals.length) {
    var candidate = vals[idx];
    var isInvalid = isNaN(candidate) || !isFinite(candidate);
    if (isInvalid) {
      throw Error("A tensor of type ".concat(dtype, " being uploaded contains ").concat(candidate, "."));
    }
    idx++;
  }
}
/**
 * Returns true if `dtype` is one of the supported data types: 'bool',
 * 'complex64', 'float32', 'int32' or 'string'.
 */
function isValidDtype(dtype) {
  return ['bool', 'complex64', 'float32', 'int32', 'string'].indexOf(dtype) !== -1;
}
/**
 * Returns true if values of `oldType` cannot be represented in `newType`
 * without losing information.
 *
 * complex64 accepts everything; float32 accepts everything but complex64;
 * int32 rejects float32 and complex64; bool accepts only bool. Any other
 * target type is treated as lossy.
 */
function hasEncodingLoss(oldType, newType) {
  switch (newType) {
    case 'complex64':
      return false;
    case 'float32':
      return oldType === 'complex64';
    case 'int32':
      return oldType === 'float32' || oldType === 'complex64';
    case 'bool':
      return oldType !== 'bool';
    default:
      return true;
  }
}
/**
 * Returns the storage size in bytes of one element of `dtype`.
 *
 * @throws Error for dtypes without a fixed element size (e.g. 'string').
 */
function bytesPerElement(dtype) {
  switch (dtype) {
    case 'float32':
    case 'int32':
      return 4;
    case 'complex64':
      return 8;
    case 'bool':
      return 1;
    default:
      throw new Error("Unknown dtype ".concat(dtype));
  }
}
/**
 * Returns the sum of the `length` of every entry of `arr`, or 0 when `arr` is
 * null/undefined.
 *
 * Each entry contributes exactly `x.length`; there is no per-character
 * scaling. NOTE(review): string tensor values appear to be held as encoded
 * byte arrays (Uint8Array[]), in which case this is their exact byte count.
 * If plain JS strings are passed instead, the result is the number of UTF-16
 * code units, not bytes.
 *
 * @param arr The string values, or null/undefined.
 * @returns Total length of all entries.
 */
function bytesFromStringArray(arr) {
  if (arr == null) {
    return 0;
  }
  var bytes = 0;
  arr.forEach(function (x) {
    bytes += x.length;
  });
  return bytes;
}
/** Returns true for primitive strings and for `String` wrapper objects. */
function isString(value) {
  if (typeof value === 'string') {
    return true;
  }
  return value instanceof String;
}
/** Returns true only for primitive booleans (not `Boolean` wrapper objects). */
function isBoolean(value) {
  var kind = typeof value;
  return kind === 'boolean';
}
/** Returns true only for primitive numbers, including NaN and Infinity. */
function isNumber(value) {
  var kind = typeof value;
  return kind === 'number';
}
/**
 * Guesses the dtype of tensor values.
 *
 * Nested arrays are inspected through their first element. Int32Array,
 * Uint8Array and Uint8ClampedArray map to 'int32', strings to 'string' and
 * booleans to 'bool'; everything else, including Float32Array, numbers and
 * empty arrays, maps to 'float32'.
 */
function inferDtype(values) {
  if (Array.isArray(values)) {
    return inferDtype(values[0]);
  }
  if (values instanceof Int32Array || values instanceof Uint8Array || values instanceof Uint8ClampedArray) {
    return 'int32';
  }
  if (isString(values)) {
    return 'string';
  }
  if (isBoolean(values)) {
    return 'bool';
  }
  // Float32Array, plain numbers and anything unrecognized.
  return 'float32';
}
/**
 * Duck-typed callable check: true when `f` is truthy and has truthy
 * `constructor`, `call` and `apply` properties.
 */
function isFunction(f) {
  if (!f || !f.constructor || !f.call) {
    return false;
  }
  return !!f.apply;
}
/**
 * Returns the first of `start`, `start + 1`, ... (below `size`) that divides
 * `size` evenly, or `size` itself when none does.
 */
function nearestDivisor(size, start) {
  var candidate = start;
  while (candidate < size && size % candidate !== 0) {
    candidate++;
  }
  return candidate < size ? candidate : size;
}
/**
 * Returns row-major strides for every dimension except the last, whose stride
 * is implicitly 1: `strides[i]` is the product of `shape[i + 1 ..]`.
 * Shapes of rank 0 or 1 yield [].
 */
function computeStrides(shape) {
  var rank = shape.length;
  if (rank < 2) {
    return [];
  }
  var strides = new Array(rank - 1);
  var product = shape[rank - 1];
  strides[rank - 2] = product;
  for (var axis = rank - 2; axis > 0; axis--) {
    product = product * shape[axis];
    strides[axis - 1] = product;
  }
  return strides;
}
/**
 * Recursively copies values of `a`, starting at `offset`, into nested plain
 * arrays shaped like `shape`. With `isComplex` (optional fourth argument,
 * default false) each element occupies two consecutive entries, so innermost
 * arrays hold `2 * shape[last]` values.
 */
function createNestedArray(offset, shape, a) {
  var isComplex = arguments[3] === undefined ? false : arguments[3];
  var factor = isComplex ? 2 : 1;
  var out = [];
  if (shape.length === 1) {
    var count = shape[0] * factor;
    for (var k = 0; k < count; k++) {
      out[k] = a[offset + k];
    }
    return out;
  }
  var outerSize = shape[0];
  var innerShape = shape.slice(1);
  // Number of flat entries covered by one outer slice.
  var stride = innerShape.reduce(function (total, dim) {
    return total * dim;
  }) * factor;
  for (var m = 0; m < outerSize; m++) {
    out[m] = createNestedArray(offset + m * stride, innerShape, a, isComplex);
  }
  return out;
}
/**
 * Reshapes the flat values `a` into nested plain arrays following `shape`.
 *
 * A rank-0 shape returns the single value `a[0]`, and a shape whose product is
 * 0 returns []. With `isComplex` (optional third argument, default false)
 * each element occupies two consecutive slots of `a`.
 *
 * @throws Error if `a.length` differs from the product of `shape` (doubled
 *     for complex values).
 */
function toNestedArray(shape, a) {
  var isComplex = arguments[2] === undefined ? false : arguments[2];
  if (shape.length === 0) {
    // Scalar: hand back the value itself, not an array.
    return a[0];
  }
  var expectedLength = shape.reduce(function (total, dim) {
    return total * dim;
  }) * (isComplex ? 2 : 1);
  if (expectedLength === 0) {
    return [];
  }
  if (expectedLength !== a.length) {
    throw new Error("[".concat(shape, "] does not match the input size ").concat(a.length).concat(isComplex ? ' for a complex tensor' : '', "."));
  }
  return createNestedArray(0, shape, a, isComplex);
}
/**
 * Coerces backend data into the typed array matching `dtype`.
 *
 * Plain arrays (e.g. Uint8Array[] for strings) are returned as-is. For
 * 'float32'/'int32' an existing array of the right type is returned unchanged,
 * otherwise one is constructed from `data`. For 'bool' and 'string' the data
 * is read as Int32 values and copied into a new Uint8Array.
 *
 * @throws Error for any other dtype.
 */
function convertBackendValuesAndArrayBuffer(data, dtype) {
  if (Array.isArray(data)) {
    return data;
  }
  switch (dtype) {
    case 'float32':
      return data instanceof Float32Array ? data : new Float32Array(data);
    case 'int32':
      return data instanceof Int32Array ? data : new Int32Array(data);
    case 'bool':
    case 'string':
      return Uint8Array.from(new Int32Array(data));
    default:
      throw new Error("Unknown dtype ".concat(dtype));
  }
}
/**
 * Like `makeZerosTypedArray`, but every element is set to 1.
 */
function makeOnesTypedArray(size, dtype) {
  return makeZerosTypedArray(size, dtype).fill(1);
}
/**
 * Allocates a zero-filled typed array of `size` elements: Float32Array for
 * 'float32', 'complex64' or a null/undefined dtype, Int32Array for 'int32'
 * and Uint8Array for 'bool'.
 *
 * @throws Error for any other dtype.
 */
function makeZerosTypedArray(size, dtype) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
    case 'complex64':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error("Unknown data type ".concat(dtype));
  }
}
/**
 * Builds a nested array of zeros with the given shape, via `toNestedArray`.
 *
 * @param shape The shape of the nested array.
 * @param dtype 'float32' (default when null/undefined), 'int32' or 'bool';
 *     selects the typed array the zeros are read from.
 * @throws Error for any other dtype.
 */
function makeZerosNestedTypedArray(shape, dtype) {
  var size = 1;
  for (var k = 0; k < shape.length; k++) {
    size *= shape[k];
  }
  var zeros;
  if (dtype == null || dtype === 'float32') {
    zeros = new Float32Array(size);
  } else if (dtype === 'int32') {
    zeros = new Int32Array(size);
  } else if (dtype === 'bool') {
    zeros = new Uint8Array(size);
  } else {
    throw new Error("Unknown data type ".concat(dtype));
  }
  return toNestedArray(shape, zeros);
}
/**
 * Asserts (via `assert$1`) that every entry of `shape` is an integer >= 0.
 *
 * Zero-sized dimensions are accepted, so the error message says
 * "non-negative", matching the check rather than claiming "positive".
 *
 * @param shape The shape to validate.
 */
function assertNonNegativeIntegerDimensions(shape) {
  shape.forEach(function (dimSize) {
    assert$1(Number.isInteger(dimSize) && dimSize >= 0, function () {
      return "Tensor must have a shape comprised of non-negative integers but got " + "shape [".concat(shape, "].");
    });
  });
}
/**
 * Converts a multidimensional location into a flat (row-major) index.
 *
 * @param locs Location in the tensor, one coordinate per dimension.
 * @param rank Rank of the tensor; rank 0 maps to 0 and rank 1 to `locs[0]`.
 * @param strides Tensor strides for every dimension but the last.
 */
function locToIndex(locs, rank, strides) {
  if (rank === 0) {
    return 0;
  }
  if (rank === 1) {
    return locs[0];
  }
  var last = locs.length - 1;
  var flat = locs[last];
  for (var d = 0; d < last; ++d) {
    flat += strides[d] * locs[d];
  }
  return flat;
}
/**
 * Converts a flat (row-major) index into a multidimensional location; the
 * inverse of `locToIndex`.
 *
 * @param index Index into the flat array.
 * @param rank Rank of the tensor; rank 0 gives [] and rank 1 gives [index].
 * @param strides Tensor strides for every dimension but the last.
 */
function indexToLoc(index, rank, strides) {
  if (rank === 0) {
    return [];
  }
  if (rank === 1) {
    return [index];
  }
  var locs = new Array(rank);
  var last = locs.length - 1;
  var remainder = index;
  for (var d = 0; d < last; ++d) {
    var coord = Math.floor(remainder / strides[d]);
    locs[d] = coord;
    remainder -= coord * strides[d];
  }
  locs[last] = remainder;
  return locs;
}
/**
 * Returns true if `object` looks like a Promise, i.e. it is non-null and has
 * a callable `then` property.
 *
 * Always returns a boolean, and reads `then` only once so an accessor is not
 * triggered twice.
 *
 * @param object The value to inspect.
 */
// tslint:disable-next-line: no-any
function isPromise(object) {
  // We chose to not use 'obj instanceOf Promise' for two reasons:
  // 1. It only reliably works for es6 Promise, not other Promise
  // implementations.
  // 2. It doesn't work with framework that uses zone.js. zone.js monkey
  // patch the async calls, so it is possible the obj (patched) is
  // comparing to a pre-patched Promise.
  if (object == null) {
    return false;
  }
  return typeof object.then === 'function';
}
19891
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * Holds evaluated feature flags together with the registered platform. A
 * single instance serves as the global singleton returned by `tf.env()`.
 *
 * @doc {heading: 'Environment'}
 */
var Environment = /*#__PURE__*/function () {
  // tslint:disable-next-line: no-any
  function Environment(global) {
    _classCallCheck(this, Environment);
    this.global = global;
    // Evaluated flag values, keyed by flag name.
    this.flags = {};
    // flagName -> {evaluationFn, setHook}.
    this.flagRegistry = {};
    // Overrides parsed from the page URL; applied when a flag is registered.
    this.urlFlags = {};
    // Jasmine spies on this in 'environment_test.ts'
    this.getQueryParams = getQueryParams;
    this.populateURLFlags();
  }
  _createClass(Environment, [{
    key: "setPlatform",
    value: function setPlatform(platformName, platform) {
      // Warn only when replacing an existing platform, and stay quiet in test
      // and production mode.
      var isReplacing = this.platform != null;
      if (isReplacing && !(env().getBool('IS_TEST') || env().getBool('PROD'))) {
        console.warn("Platform ".concat(this.platformName, " has already been set. ") + "Overwriting the platform with ".concat(platformName, "."));
      }
      this.platformName = platformName;
      this.platform = platform;
    }
  }, {
    key: "registerFlag",
    value: function registerFlag(flagName, evaluationFn, setHook) {
      this.flagRegistry[flagName] = {
        evaluationFn: evaluationFn,
        setHook: setHook
      };
      // URL overrides are applied here because the environment (and thus the
      // URL parsing) is created before any flag gets registered.
      var urlValue = this.urlFlags[flagName];
      if (urlValue == null) {
        return;
      }
      if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
        console.warn("Setting feature override from URL ".concat(flagName, ": ").concat(urlValue, "."));
      }
      this.set(flagName, urlValue);
    }
  }, {
    key: "getAsync",
    // Asynchronous counterpart of get(): awaits the evaluation function, so it
    // also works for flags whose evaluator returns a promise. Kept in its
    // transpiled generator form.
    value: function () {
      var _getAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(flagName) {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              if (!(flagName in this.flags)) {
                _context.next = 2;
                break;
              }
              return _context.abrupt("return", this.flags[flagName]);
            case 2:
              _context.next = 4;
              return this.evaluateFlag(flagName);
            case 4:
              this.flags[flagName] = _context.sent;
              return _context.abrupt("return", this.flags[flagName]);
            case 6:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function getAsync(_x) {
        return _getAsync.apply(this, arguments);
      }
      return getAsync;
    }()
  }, {
    key: "get",
    value: function get(flagName) {
      // Evaluate lazily on first access and cache the result.
      if (!(flagName in this.flags)) {
        var evaluated = this.evaluateFlag(flagName);
        if (isPromise(evaluated)) {
          throw new Error("Flag ".concat(flagName, " cannot be synchronously evaluated. ") + "Please use getAsync() instead.");
        }
        this.flags[flagName] = evaluated;
      }
      return this.flags[flagName];
    }
  }, {
    key: "getNumber",
    value: function getNumber(flagName) {
      // Typed alias of get(); no conversion is performed.
      return this.get(flagName);
    }
  }, {
    key: "getBool",
    value: function getBool(flagName) {
      // Typed alias of get(); no conversion is performed.
      return this.get(flagName);
    }
  }, {
    key: "getString",
    value: function getString(flagName) {
      // Typed alias of get(); no conversion is performed.
      return this.get(flagName);
    }
  }, {
    key: "getFlags",
    value: function getFlags() {
      return this.flags;
    }
    // For backwards compatibility.
  }, {
    key: "features",
    get: function get() {
      return this.flags;
    }
  }, {
    key: "set",
    value: function set(flagName, value) {
      var registration = this.flagRegistry[flagName];
      if (registration == null) {
        throw new Error("Cannot set flag ".concat(flagName, " as it has not been registered."));
      }
      this.flags[flagName] = value;
      if (registration.setHook != null) {
        registration.setHook(value);
      }
    }
  }, {
    key: "evaluateFlag",
    value: function evaluateFlag(flagName) {
      var registration = this.flagRegistry[flagName];
      if (registration == null) {
        throw new Error("Cannot evaluate flag '".concat(flagName, "': no evaluation function found."));
      }
      return registration.evaluationFn();
    }
  }, {
    key: "setFlags",
    value: function setFlags(flags) {
      // Shallow copy so later mutation of the argument does not leak in.
      this.flags = Object.assign({}, flags);
    }
  }, {
    key: "reset",
    value: function reset() {
      this.flags = {};
      this.urlFlags = {};
      this.populateURLFlags();
    }
  }, {
    key: "populateURLFlags",
    value: function populateURLFlags() {
      var self = this;
      // Nothing to parse when the global object exposes no location.search.
      if (typeof this.global === 'undefined' || typeof this.global.location === 'undefined' || typeof this.global.location.search === 'undefined') {
        return;
      }
      var urlParams = this.getQueryParams(this.global.location.search);
      if (!(TENSORFLOWJS_FLAGS_PREFIX in urlParams)) {
        return;
      }
      urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',').forEach(function (entry) {
        // Each entry is "<flag>:<value>"; text after a second ':' is dropped.
        var pieces = entry.split(':');
        self.urlFlags[pieces[0]] = parseValue(pieces[0], pieces[1]);
      });
    }
  }]);
  return Environment;
}();
/**
 * Parses a query string (such as `location.search`) into a plain object by
 * handing each `name[=value]` pair to `decodeParam`. A parameter without '='
 * is passed with an undefined value.
 */
function getQueryParams(queryString) {
  var params = {};
  // replace() is used only to visit every match; its result is discarded.
  queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (match, name, value) {
    decodeParam(params, name, value);
    return match;
  });
  return params;
}
/**
 * Decodes one query-string `name`/`value` pair and stores it in `params`; a
 * missing value is stored as ''.
 *
 * The query string belongs to the hosting page, so it may contain '%'
 * sequences that are not valid percent-encoding (e.g. `?discount=50%`).
 * `decodeURIComponent` throws a URIError on those, and because this runs
 * while the Environment is being constructed, that error would abort loading
 * the library. Malformed components are therefore stored verbatim.
 *
 * @param params Object receiving the decoded pair.
 * @param name Raw parameter name.
 * @param value Raw parameter value, or undefined when the pair has no '='.
 */
function decodeParam(params, name, value) {
  // Returns the decoded text, or the raw text when it is not valid
  // percent-encoding. Any other error is rethrown.
  function decodeLeniently(text) {
    try {
      return decodeURIComponent(text);
    } catch (err) {
      if (err instanceof URIError) {
        return text;
      }
      throw err;
    }
  }
  params[decodeLeniently(name)] = decodeLeniently(value || '');
}
/**
 * Converts the raw text of a URL flag override into a typed value.
 *
 * 'true'/'false' (any casing) become booleans, text that round-trips through
 * numeric conversion becomes a number, and anything else is returned as the
 * original (non-lowercased) string.
 *
 * A missing value, from an entry such as `DEBUG` with no ':', yields
 * `undefined` instead of throwing a TypeError from `toLowerCase()` while the
 * environment is being constructed. `registerFlag` only applies URL overrides
 * that are `!= null`, so such an entry is simply ignored.
 *
 * @param flagName Name of the flag (not used by the conversion).
 * @param value Raw value text, or undefined when absent.
 */
function parseValue(flagName, value) {
  if (value == null) {
    return undefined;
  }
  var lowerCaseValue = value.toLowerCase();
  if (lowerCaseValue === 'true' || lowerCaseValue === 'false') {
    return lowerCaseValue === 'true';
  } else if ("".concat(+lowerCaseValue) === lowerCaseValue) {
    return +lowerCaseValue;
  } else {
    return value;
  }
}
20086 /**
20087 * Returns the current environment (a global singleton).
20088 *
20089 * The environment object contains the evaluated feature values as well as the
20090 * active platform.
20091 *
20092 * @doc {heading: 'Environment'}
20093 */
20094 function env() {
20095 return exports.ENV;
20096 }
20097 exports.ENV = null;
20098 function setEnvironmentGlobal(environment) {
20099 exports.ENV = environment;
20100 }
20101
20102 /**
20103 * @license
20104 * Copyright 2020 Google LLC. All Rights Reserved.
20105 * Licensed under the Apache License, Version 2.0 (the "License");
20106 * you may not use this file except in compliance with the License.
20107 * You may obtain a copy of the License at
20108 *
20109 * http://www.apache.org/licenses/LICENSE-2.0
20110 *
20111 * Unless required by applicable law or agreed to in writing, software
20112 * distributed under the License is distributed on an "AS IS" BASIS,
20113 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20114 * See the License for the specific language governing permissions and
20115 * limitations under the License.
20116 * =============================================================================
20117 */
// `globalNameSpace` is scoped to this module, but it caches a host global
// object, so it always resolves to the same global object regardless of how
// the module is resolved.
// tslint:disable-next-line:no-any
var globalNameSpace;
// tslint:disable-next-line:no-any
function getGlobalNamespace() {
  if (globalNameSpace != null) {
    return globalNameSpace;
  }
  // Probe in priority order: window, global, process, then self.
  var ns = typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof process !== 'undefined' ? process : typeof self !== 'undefined' ? self : undefined;
  if (ns === undefined) {
    throw new Error('Could not find a global object');
  }
  globalNameSpace = ns;
  return globalNameSpace;
}
/**
 * Returns the Map stored as `_tfGlobals` on the global namespace object,
 * creating it on first use.
 */
// tslint:disable-next-line:no-any
function getGlobalMap() {
  var ns = getGlobalNamespace();
  if (ns._tfGlobals != null) {
    return ns._tfGlobals;
  }
  ns._tfGlobals = new Map();
  return ns._tfGlobals;
}
/**
 * Returns the globally shared object stored under `key` in the global map,
 * creating it with `init()` the first time; later calls return the stored
 * object without calling `init` again.
 *
 * @param key the name of the object
 * @param init factory invoked once to build the object on first access
 */
function getGlobal(key, init) {
  var globalMap = getGlobalMap();
  if (!globalMap.has(key)) {
    globalMap.set(key, init());
  }
  return globalMap.get(key);
}
20168
// Official kernel names. Each binding holds the kernel's registry name; the
// `$1`/`$2` suffixes only disambiguate identifiers within this bundle.
var Abs = "Abs";
var Acos = "Acos";
var Acosh = "Acosh";
var Add$1 = "Add";
var AddN = "AddN";
var All = "All";
var Any = "Any";
var ArgMax = "ArgMax";
var ArgMin = "ArgMin";
var Asin = "Asin";
var Asinh = "Asinh";
var Atan = "Atan";
var Atanh = "Atanh";
var Atan2 = "Atan2";
var AvgPool = "AvgPool";
var AvgPoolGrad = "AvgPoolGrad";
var AvgPool3D = "AvgPool3D";
var AvgPool3DGrad = "AvgPool3DGrad";
var BatchMatMul = "BatchMatMul";
var BatchToSpaceND = "BatchToSpaceND";
var Bincount = "Bincount";
var BitwiseAnd = "BitwiseAnd";
var BroadcastTo = "BroadcastTo";
var BroadcastArgs = "BroadcastArgs";
var Cast = "Cast";
var Ceil = "Ceil";
var ClipByValue = "ClipByValue";
var Complex = "Complex";
var ComplexAbs = "ComplexAbs";
var Concat = "Concat";
var Conv2D$1 = "Conv2D";
var Conv2DBackpropFilter = "Conv2DBackpropFilter";
var Conv2DBackpropInput = "Conv2DBackpropInput";
var Conv3D$1 = "Conv3D";
var Conv3DBackpropFilterV2 = "Conv3DBackpropFilterV2";
var Conv3DBackpropInputV2 = "Conv3DBackpropInputV2";
var Cos = "Cos";
var Cosh = "Cosh";
var Cumprod = "Cumprod";
var Cumsum = "Cumsum";
var CropAndResize = "CropAndResize";
var DenseBincount = "DenseBincount";
var DepthToSpace = "DepthToSpace";
var DepthwiseConv2dNative = "DepthwiseConv2dNative";
var DepthwiseConv2dNativeBackpropFilter = "DepthwiseConv2dNativeBackpropFilter";
var DepthwiseConv2dNativeBackpropInput = "DepthwiseConv2dNativeBackpropInput";
var Diag = "Diag";
var Dilation2D = "Dilation2D";
var Dilation2DBackpropInput = "Dilation2DBackpropInput";
var Dilation2DBackpropFilter = "Dilation2DBackpropFilter";
var Draw = "Draw";
var RealDiv = "RealDiv";
var Einsum = "Einsum";
var Elu$1 = "Elu";
var EluGrad = "EluGrad";
var Erf = "Erf";
var Equal = "Equal";
var Exp = "Exp";
var ExpandDims = "ExpandDims";
var Expm1 = "Expm1";
var FFT = "FFT";
var Fill = "Fill";
var FlipLeftRight = "FlipLeftRight";
var Floor = "Floor";
var FloorDiv = "FloorDiv";
var FusedBatchNorm = "FusedBatchNorm";
var GatherV2 = "GatherV2";
var GatherNd = "GatherNd";
var Greater = "Greater";
var GreaterEqual = "GreaterEqual";
var Identity$1 = "Identity";
var IFFT = "IFFT";
var Imag = "Imag";
var IsFinite = "IsFinite";
var IsInf = "IsInf";
var IsNan = "IsNan";
var LeakyRelu = "LeakyRelu";
var Less = "Less";
var LessEqual = "LessEqual";
var LinSpace = "LinSpace";
var Log = "Log";
var Log1p = "Log1p";
var LogicalAnd = "LogicalAnd";
var LogicalNot = "LogicalNot";
var LogicalOr = "LogicalOr";
var LogicalXor = "LogicalXor";
var LogSoftmax$1 = "LogSoftmax";
var LowerBound = "LowerBound";
var LRN = "LRN";
var LRNGrad = "LRNGrad";
var MatrixBandPart = "MatrixBandPart";
var Max = "Max";
var Maximum$1 = "Maximum";
var MaxPool = "MaxPool";
var MaxPoolGrad = "MaxPoolGrad";
var MaxPool3D = "MaxPool3D";
var MaxPool3DGrad = "MaxPool3DGrad";
var MaxPoolWithArgmax = "MaxPoolWithArgmax";
var Mean = "Mean";
var Min = "Min";
var Minimum$1 = "Minimum";
var MirrorPad = "MirrorPad";
var Mod = "Mod";
var Multinomial = "Multinomial";
var Multiply$1 = "Multiply";
var Neg = "Neg";
var NotEqual = "NotEqual";
var NonMaxSuppressionV3 = "NonMaxSuppressionV3";
var NonMaxSuppressionV4 = "NonMaxSuppressionV4";
var NonMaxSuppressionV5 = "NonMaxSuppressionV5";
var OnesLike = "OnesLike";
var OneHot = "OneHot";
var Pack = "Pack";
var PadV2 = "PadV2";
var Pool = "Pool";
var Pow = "Pow";
var Prelu = "Prelu";
var Prod = "Prod";
var RaggedGather = "RaggedGather";
var RaggedRange = "RaggedRange";
var RaggedTensorToTensor = "RaggedTensorToTensor";
var Range = "Range";
var Real = "Real";
var Reciprocal = "Reciprocal";
var Relu$1 = "Relu";
var Reshape$1 = "Reshape";
var ResizeNearestNeighbor = "ResizeNearestNeighbor";
var ResizeNearestNeighborGrad = "ResizeNearestNeighborGrad";
var ResizeBilinear = "ResizeBilinear";
var ResizeBilinearGrad = "ResizeBilinearGrad";
var Relu6$1 = "Relu6";
var Reverse = "Reverse";
var Round = "Round";
var Rsqrt = "Rsqrt";
var ScatterNd = "ScatterNd";
var TensorScatterUpdate = "TensorScatterUpdate";
var SearchSorted = "SearchSorted";
var Select = "Select";
var Selu$1 = "Selu";
var Slice = "Slice";
var Sin = "Sin";
var Sinh = "Sinh";
var Sign = "Sign";
var Sigmoid$1 = "Sigmoid";
var Softplus$1 = "Softplus";
var Sqrt = "Sqrt";
var Sum = "Sum";
var SpaceToBatchND = "SpaceToBatchND";
var SplitV = "SplitV";
var Softmax$2 = "Softmax";
var SparseFillEmptyRows = "SparseFillEmptyRows";
var SparseReshape = "SparseReshape";
var SparseSegmentMean = "SparseSegmentMean";
var SparseSegmentSum = "SparseSegmentSum";
var SparseToDense = "SparseToDense";
var SquaredDifference = "SquaredDifference";
var Square = "Square";
var StaticRegexReplace = "StaticRegexReplace";
var StridedSlice = "StridedSlice";
var StringNGrams = "StringNGrams";
var StringSplit = "StringSplit";
var StringToHashBucketFast = "StringToHashBucketFast";
var Sub = "Sub";
var Tan = "Tan";
var Tanh$1 = "Tanh";
var Tile = "Tile";
var TopK = "TopK";
var Transform = "Transform";
var Transpose = "Transpose";
var Unique = "Unique";
var Unpack = "Unpack";
var UnsortedSegmentSum = "UnsortedSegmentSum";
var UpperBound = "UpperBound";
var ZerosLike = "ZerosLike";
/**
 * TensorFlow.js-only kernels
 */
var Step = "Step";
var FromPixels = "FromPixels";
var RotateWithOffset = "RotateWithOffset";
var _FusedMatMul = "_FusedMatMul";
var FusedConv2D = "FusedConv2D";
var FusedDepthwiseConv2D = "FusedDepthwiseConv2D";
20352
20353 /**
20354 * @license
20355 * Copyright 2018 Google LLC. All Rights Reserved.
20356 * Licensed under the Apache License, Version 2.0 (the "License");
20357 * you may not use this file except in compliance with the License.
20358 * You may obtain a copy of the License at
20359 *
20360 * http://www.apache.org/licenses/LICENSE-2.0
20361 *
20362 * Unless required by applicable law or agreed to in writing, software
20363 * distributed under the License is distributed on an "AS IS" BASIS,
20364 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20365 * See the License for the specific language governing permissions and
20366 * limitations under the License.
20367 * =============================================================================
20368 */
/**
 * Forwards all arguments to `console.warn`, unless the IS_TEST or PROD flag
 * is set.
 */
function warn() {
  var isSilenced = env().getBool('IS_TEST') || env().getBool('PROD');
  if (isSilenced) {
    return;
  }
  console.warn.apply(console, arguments);
}
/**
 * Forwards all arguments to `console.log`, unless the IS_TEST or PROD flag
 * is set.
 */
function log$3() {
  var isSilenced = env().getBool('IS_TEST') || env().getBool('PROD');
  if (isSilenced) {
    return;
  }
  console.log.apply(console, arguments);
}
20381
// Registries of kernel configs and gradient configs. Both are obtained through
// getGlobal, so every lookup under the same key returns the same Map.
var kernelRegistry = getGlobal('kernelRegistry', function makeKernelRegistry() {
  return new Map();
});
var gradRegistry = getGlobal('gradRegistry', function makeGradRegistry() {
  return new Map();
});
/**
 * Looks up the kernel config registered for the given kernel/backend pair;
 * returns undefined when none is registered.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 */
function getKernel(kernelName, backendName) {
  return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Looks up the gradient config registered for `kernelName`; returns
 * undefined when none is registered.
 *
 * @param kernelName The official TF kernel name.
 */
function getGradient(kernelName) {
  var gradConfig = gradRegistry.get(kernelName);
  return gradConfig;
}
/**
 * Returns the configs of every kernel registered for `backendName`, in
 * registration order.
 *
 * Matching uses each config's own `backendName` rather than the registry key.
 * Keys have the form `<backendName>_<kernelName>` (see makeKey), so splitting
 * a key on '_' to recover the backend breaks for backend names that contain
 * '_' themselves: 'my_backend' would be read as 'my', which both drops its
 * kernels and attributes them to a backend named 'my'.
 *
 * @param backendName The official name of the backend.
 * @returns Array of kernel configs (possibly empty).
 */
function getKernelsForBackend(backendName) {
  var result = [];
  kernelRegistry.forEach(function (config) {
    if (config.backendName === backendName) {
      result.push(config);
    }
  });
  return result;
}
/**
 * Stores a kernel config (forward pass) in the global kernel registry under
 * its kernel/backend key, warning when that key is already taken. The new
 * config replaces any previous one.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 * disposed.
 */
function registerKernel(config) {
  var name = config.kernelName;
  var backend = config.backendName;
  var registryKey = makeKey(name, backend);
  if (kernelRegistry.has(registryKey)) {
    warn("The kernel '".concat(name, "' for backend ") + "'".concat(backend, "' is already registered"));
  }
  kernelRegistry.set(registryKey, config);
}
/**
 * Stores a gradient config in the global gradient registry, replacing any
 * previous one for the same kernel. Overrides are only reported when the
 * DEBUG flag is on.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
  var kernelName = config.kernelName;
  var isOverride = gradRegistry.has(kernelName);
  // TODO (yassogba) after 3.0 assess whether we need to keep this gated
  // to debug mode.
  if (isOverride && env().getBool('DEBUG')) {
    warn("Overriding the gradient for '".concat(kernelName, "'"));
  }
  gradRegistry.set(kernelName, config);
}
/**
 * Removes a kernel config from the global registry.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @throws Error if no kernel is registered for the pair.
 */
function unregisterKernel(kernelName, backendName) {
  var registryKey = makeKey(kernelName, backendName);
  if (kernelRegistry.has(registryKey)) {
    kernelRegistry.delete(registryKey);
    return;
  }
  throw new Error("The kernel '".concat(kernelName, "' for backend ") + "'".concat(backendName, "' is not registered"));
}
/**
 * Removes the registered gradient for `kernelName` from the global registry.
 *
 * @throws Error if no gradient is registered for that kernel.
 */
function unregisterGradient(kernelName) {
  if (gradRegistry.has(kernelName)) {
    gradRegistry.delete(kernelName);
    return;
  }
  throw new Error("The gradient '".concat(kernelName, "' for backend is not registered"));
}
/**
 * Re-registers every kernel of an existing backend under a new backend name,
 * by shallow-copying each config with `backendName` replaced. Useful for
 * registering custom backends.
 *
 * @param registeredBackendName Already registered backend.
 * @param newBackendName New backend.
 */
function copyRegisteredKernels(registeredBackendName, newBackendName) {
  var sourceConfigs = getKernelsForBackend(registeredBackendName);
  for (var idx = 0; idx < sourceConfigs.length; idx++) {
    var copied = Object.assign({}, sourceConfigs[idx], {
      backendName: newBackendName
    });
    registerKernel(copied);
  }
}
/** Builds the kernel-registry key `<backendName>_<kernelName>`. */
function makeKey(kernelName, backendName) {
  return "".concat(backendName, "_", kernelName);
}
20505
20506 /**
20507 * @license
20508 * Copyright 2023 Google LLC.
20509 * Licensed under the Apache License, Version 2.0 (the "License");
20510 * you may not use this file except in compliance with the License.
20511 * You may obtain a copy of the License at
20512 *
20513 * http://www.apache.org/licenses/LICENSE-2.0
20514 *
20515 * Unless required by applicable law or agreed to in writing, software
20516 * distributed under the License is distributed on an "AS IS" BASIS,
20517 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20518 * See the License for the specific language governing permissions and
20519 * limitations under the License.
20520 * =============================================================================
20521 */
/**
 * Returns true if `a` is a Float32Array, Int32Array, Uint8Array or
 * Uint8ClampedArray (the typed arrays accepted as tensor data here).
 */
function isTypedArrayBrowser(a) {
  return [Float32Array, Int32Array, Uint8Array, Uint8ClampedArray].some(function (TypedCtor) {
    return a instanceof TypedCtor;
  });
}
20525
20526 var long = Long$1;
20527
20528 /**
20529 * wasm optimizations, to do native i64 multiplication and divide
20530 */
20531 var wasm = null;
20532 try {
20533 wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11])), {}).exports;
20534 } catch (e) {
20535 // no wasm support :(
20536 }
20537
20538 /**
20539 * Constructs a 64 bit two's-complement integer, given its low and high 32 bit values as *signed* integers.
20540 * See the from* functions below for more convenient ways of constructing Longs.
20541 * @exports Long
20542 * @class A Long class for representing a 64 bit two's-complement integer value.
20543 * @param {number} low The low (signed) 32 bits of the long
20544 * @param {number} high The high (signed) 32 bits of the long
20545 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
20546 * @constructor
20547 */
20548 function Long$1(low, high, unsigned) {
20549 /**
20550 * The low 32 bits as a signed value.
20551 * @type {number}
20552 */
20553 this.low = low | 0;
20554
20555 /**
20556 * The high 32 bits as a signed value.
20557 * @type {number}
20558 */
20559 this.high = high | 0;
20560
20561 /**
20562 * Whether unsigned or not.
20563 * @type {boolean}
20564 */
20565 this.unsigned = !!unsigned;
20566 }
20567
20568 // The internal representation of a long is the two given signed, 32-bit values.
20569 // We use 32-bit pieces because these are the size of integers on which
20570 // Javascript performs bit-operations. For operations like addition and
20571 // multiplication, we split each number into 16 bit pieces, which can easily be
20572 // multiplied within Javascript's floating-point representation without overflow
20573 // or change in sign.
20574 //
20575 // In the algorithms below, we frequently reduce the negative case to the
20576 // positive case by negating the input(s) and then post-processing the result.
20577 // Note that we must ALWAYS check specially whether those values are MIN_VALUE
20578 // (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as
20579 // a positive number, it overflows back into a negative). Not handling this
20580 // case would often result in infinite recursion.
20581 //
20582 // Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from*
20583 // methods on which they depend.
20584
20585 /**
20586 * An indicator used to reliably determine if an object is a Long or not.
20587 * @type {boolean}
20588 * @const
20589 * @private
20590 */
20591 Long$1.prototype.__isLong__;
20592 Object.defineProperty(Long$1.prototype, "__isLong__", {
20593 value: true
20594 });
20595
20596 /**
20597 * @function
20598 * @param {*} obj Object
20599 * @returns {boolean}
20600 * @inner
20601 */
20602 function isLong(obj) {
20603 return (obj && obj["__isLong__"]) === true;
20604 }
20605
20606 /**
20607 * Tests if the specified object is a Long.
20608 * @function
20609 * @param {*} obj Object
20610 * @returns {boolean}
20611 */
20612 Long$1.isLong = isLong;
20613
20614 /**
20615 * A cache of the Long representations of small integer values.
20616 * @type {!Object}
20617 * @inner
20618 */
20619 var INT_CACHE = {};
20620
20621 /**
20622 * A cache of the Long representations of small unsigned integer values.
20623 * @type {!Object}
20624 * @inner
20625 */
20626 var UINT_CACHE = {};
20627
20628 /**
20629 * @param {number} value
20630 * @param {boolean=} unsigned
20631 * @returns {!Long}
20632 * @inner
20633 */
20634 function fromInt(value, unsigned) {
20635 var obj, cachedObj, cache;
20636 if (unsigned) {
20637 value >>>= 0;
20638 if (cache = 0 <= value && value < 256) {
20639 cachedObj = UINT_CACHE[value];
20640 if (cachedObj) return cachedObj;
20641 }
20642 obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true);
20643 if (cache) UINT_CACHE[value] = obj;
20644 return obj;
20645 } else {
20646 value |= 0;
20647 if (cache = -128 <= value && value < 128) {
20648 cachedObj = INT_CACHE[value];
20649 if (cachedObj) return cachedObj;
20650 }
20651 obj = fromBits(value, value < 0 ? -1 : 0, false);
20652 if (cache) INT_CACHE[value] = obj;
20653 return obj;
20654 }
20655 }
20656
20657 /**
20658 * Returns a Long representing the given 32 bit integer value.
20659 * @function
20660 * @param {number} value The 32 bit integer in question
20661 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
20662 * @returns {!Long} The corresponding Long value
20663 */
20664 Long$1.fromInt = fromInt;
20665
20666 /**
20667 * @param {number} value
20668 * @param {boolean=} unsigned
20669 * @returns {!Long}
20670 * @inner
20671 */
20672 function fromNumber(value, unsigned) {
20673 if (isNaN(value)) return unsigned ? UZERO : ZERO;
20674 if (unsigned) {
20675 if (value < 0) return UZERO;
20676 if (value >= TWO_PWR_64_DBL) return MAX_UNSIGNED_VALUE;
20677 } else {
20678 if (value <= -TWO_PWR_63_DBL) return MIN_VALUE;
20679 if (value + 1 >= TWO_PWR_63_DBL) return MAX_VALUE;
20680 }
20681 if (value < 0) return fromNumber(-value, unsigned).neg();
20682 return fromBits(value % TWO_PWR_32_DBL | 0, value / TWO_PWR_32_DBL | 0, unsigned);
20683 }
20684
20685 /**
20686 * Returns a Long representing the given value, provided that it is a finite number. Otherwise, zero is returned.
20687 * @function
20688 * @param {number} value The number in question
20689 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
20690 * @returns {!Long} The corresponding Long value
20691 */
20692 Long$1.fromNumber = fromNumber;
20693
20694 /**
20695 * @param {number} lowBits
20696 * @param {number} highBits
20697 * @param {boolean=} unsigned
20698 * @returns {!Long}
20699 * @inner
20700 */
20701 function fromBits(lowBits, highBits, unsigned) {
20702 return new Long$1(lowBits, highBits, unsigned);
20703 }
20704
20705 /**
20706 * Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is
20707 * assumed to use 32 bits.
20708 * @function
20709 * @param {number} lowBits The low 32 bits
20710 * @param {number} highBits The high 32 bits
20711 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
20712 * @returns {!Long} The corresponding Long value
20713 */
20714 Long$1.fromBits = fromBits;
20715
20716 /**
20717 * @function
20718 * @param {number} base
20719 * @param {number} exponent
20720 * @returns {number}
20721 * @inner
20722 */
20723 var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4)
20724
20725 /**
20726 * @param {string} str
20727 * @param {(boolean|number)=} unsigned
20728 * @param {number=} radix
20729 * @returns {!Long}
20730 * @inner
20731 */
20732 function fromString(str, unsigned, radix) {
20733 if (str.length === 0) throw Error('empty string');
20734 if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity") return ZERO;
20735 if (typeof unsigned === 'number') {
20736 // For goog.math.long compatibility
20737 radix = unsigned, unsigned = false;
20738 } else {
20739 unsigned = !!unsigned;
20740 }
20741 radix = radix || 10;
20742 if (radix < 2 || 36 < radix) throw RangeError('radix');
20743 var p;
20744 if ((p = str.indexOf('-')) > 0) throw Error('interior hyphen');else if (p === 0) {
20745 return fromString(str.substring(1), unsigned, radix).neg();
20746 }
20747
20748 // Do several (8) digits each time through the loop, so as to
20749 // minimize the calls to the very expensive emulated div.
20750 var radixToPower = fromNumber(pow_dbl(radix, 8));
20751 var result = ZERO;
20752 for (var i = 0; i < str.length; i += 8) {
20753 var size = Math.min(8, str.length - i),
20754 value = parseInt(str.substring(i, i + size), radix);
20755 if (size < 8) {
20756 var power = fromNumber(pow_dbl(radix, size));
20757 result = result.mul(power).add(fromNumber(value));
20758 } else {
20759 result = result.mul(radixToPower);
20760 result = result.add(fromNumber(value));
20761 }
20762 }
20763 result.unsigned = unsigned;
20764 return result;
20765 }
20766
20767 /**
20768 * Returns a Long representation of the given string, written using the specified radix.
20769 * @function
20770 * @param {string} str The textual representation of the Long
20771 * @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed
20772 * @param {number=} radix The radix in which the text is written (2-36), defaults to 10
20773 * @returns {!Long} The corresponding Long value
20774 */
20775 Long$1.fromString = fromString;
20776
20777 /**
20778 * @function
20779 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val
20780 * @param {boolean=} unsigned
20781 * @returns {!Long}
20782 * @inner
20783 */
20784 function fromValue(val, unsigned) {
20785 if (typeof val === 'number') return fromNumber(val, unsigned);
20786 if (typeof val === 'string') return fromString(val, unsigned);
20787 // Throws for non-objects, converts non-instanceof Long:
20788 return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned);
20789 }
20790
20791 /**
20792 * Converts the specified value to a Long using the appropriate from* function for its type.
20793 * @function
20794 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value
20795 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
20796 * @returns {!Long}
20797 */
20798 Long$1.fromValue = fromValue;
20799
20800 // NOTE: the compiler should inline these constant values below and then remove these variables, so there should be
20801 // no runtime penalty for these.
20802
20803 /**
20804 * @type {number}
20805 * @const
20806 * @inner
20807 */
20808 var TWO_PWR_16_DBL = 1 << 16;
20809
20810 /**
20811 * @type {number}
20812 * @const
20813 * @inner
20814 */
20815 var TWO_PWR_24_DBL = 1 << 24;
20816
20817 /**
20818 * @type {number}
20819 * @const
20820 * @inner
20821 */
20822 var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL;
20823
20824 /**
20825 * @type {number}
20826 * @const
20827 * @inner
20828 */
20829 var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL;
20830
20831 /**
20832 * @type {number}
20833 * @const
20834 * @inner
20835 */
20836 var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2;
20837
20838 /**
20839 * @type {!Long}
20840 * @const
20841 * @inner
20842 */
20843 var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL);
20844
20845 /**
20846 * @type {!Long}
20847 * @inner
20848 */
20849 var ZERO = fromInt(0);
20850
20851 /**
20852 * Signed zero.
20853 * @type {!Long}
20854 */
20855 Long$1.ZERO = ZERO;
20856
20857 /**
20858 * @type {!Long}
20859 * @inner
20860 */
20861 var UZERO = fromInt(0, true);
20862
20863 /**
20864 * Unsigned zero.
20865 * @type {!Long}
20866 */
20867 Long$1.UZERO = UZERO;
20868
20869 /**
20870 * @type {!Long}
20871 * @inner
20872 */
20873 var ONE = fromInt(1);
20874
20875 /**
20876 * Signed one.
20877 * @type {!Long}
20878 */
20879 Long$1.ONE = ONE;
20880
20881 /**
20882 * @type {!Long}
20883 * @inner
20884 */
20885 var UONE = fromInt(1, true);
20886
20887 /**
20888 * Unsigned one.
20889 * @type {!Long}
20890 */
20891 Long$1.UONE = UONE;
20892
20893 /**
20894 * @type {!Long}
20895 * @inner
20896 */
20897 var NEG_ONE = fromInt(-1);
20898
20899 /**
20900 * Signed negative one.
20901 * @type {!Long}
20902 */
20903 Long$1.NEG_ONE = NEG_ONE;
20904
20905 /**
20906 * @type {!Long}
20907 * @inner
20908 */
20909 var MAX_VALUE = fromBits(0xFFFFFFFF | 0, 0x7FFFFFFF | 0, false);
20910
20911 /**
20912 * Maximum signed value.
20913 * @type {!Long}
20914 */
20915 Long$1.MAX_VALUE = MAX_VALUE;
20916
20917 /**
20918 * @type {!Long}
20919 * @inner
20920 */
20921 var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF | 0, 0xFFFFFFFF | 0, true);
20922
20923 /**
20924 * Maximum unsigned value.
20925 * @type {!Long}
20926 */
20927 Long$1.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE;
20928
20929 /**
20930 * @type {!Long}
20931 * @inner
20932 */
20933 var MIN_VALUE = fromBits(0, 0x80000000 | 0, false);
20934
20935 /**
20936 * Minimum signed value.
20937 * @type {!Long}
20938 */
20939 Long$1.MIN_VALUE = MIN_VALUE;
20940
20941 /**
20942 * @alias Long.prototype
20943 * @inner
20944 */
20945 var LongPrototype = Long$1.prototype;
20946
20947 /**
20948 * Converts the Long to a 32 bit integer, assuming it is a 32 bit integer.
20949 * @returns {number}
20950 */
20951 LongPrototype.toInt = function toInt() {
20952 return this.unsigned ? this.low >>> 0 : this.low;
20953 };
20954
20955 /**
20956 * Converts the Long to a the nearest floating-point representation of this value (double, 53 bit mantissa).
20957 * @returns {number}
20958 */
20959 LongPrototype.toNumber = function toNumber() {
20960 if (this.unsigned) return (this.high >>> 0) * TWO_PWR_32_DBL + (this.low >>> 0);
20961 return this.high * TWO_PWR_32_DBL + (this.low >>> 0);
20962 };
20963
20964 /**
20965 * Converts the Long to a string written in the specified radix.
20966 * @param {number=} radix Radix (2-36), defaults to 10
20967 * @returns {string}
20968 * @override
20969 * @throws {RangeError} If `radix` is out of range
20970 */
20971 LongPrototype.toString = function toString(radix) {
20972 radix = radix || 10;
20973 if (radix < 2 || 36 < radix) throw RangeError('radix');
20974 if (this.isZero()) return '0';
20975 if (this.isNegative()) {
20976 // Unsigned Longs are never negative
20977 if (this.eq(MIN_VALUE)) {
20978 // We need to change the Long value before it can be negated, so we remove
20979 // the bottom-most digit in this base and then recurse to do the rest.
20980 var radixLong = fromNumber(radix),
20981 div = this.div(radixLong),
20982 rem1 = div.mul(radixLong).sub(this);
20983 return div.toString(radix) + rem1.toInt().toString(radix);
20984 } else return '-' + this.neg().toString(radix);
20985 }
20986
20987 // Do several (6) digits each time through the loop, so as to
20988 // minimize the calls to the very expensive emulated div.
20989 var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned),
20990 rem = this;
20991 var result = '';
20992 while (true) {
20993 var remDiv = rem.div(radixToPower),
20994 intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0,
20995 digits = intval.toString(radix);
20996 rem = remDiv;
20997 if (rem.isZero()) return digits + result;else {
20998 while (digits.length < 6) digits = '0' + digits;
20999 result = '' + digits + result;
21000 }
21001 }
21002 };
21003
21004 /**
21005 * Gets the high 32 bits as a signed integer.
21006 * @returns {number} Signed high bits
21007 */
21008 LongPrototype.getHighBits = function getHighBits() {
21009 return this.high;
21010 };
21011
21012 /**
21013 * Gets the high 32 bits as an unsigned integer.
21014 * @returns {number} Unsigned high bits
21015 */
21016 LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
21017 return this.high >>> 0;
21018 };
21019
21020 /**
21021 * Gets the low 32 bits as a signed integer.
21022 * @returns {number} Signed low bits
21023 */
21024 LongPrototype.getLowBits = function getLowBits() {
21025 return this.low;
21026 };
21027
21028 /**
21029 * Gets the low 32 bits as an unsigned integer.
21030 * @returns {number} Unsigned low bits
21031 */
21032 LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
21033 return this.low >>> 0;
21034 };
21035
21036 /**
21037 * Gets the number of bits needed to represent the absolute value of this Long.
21038 * @returns {number}
21039 */
21040 LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
21041 if (this.isNegative())
21042 // Unsigned Longs are never negative
21043 return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
21044 var val = this.high != 0 ? this.high : this.low;
21045 for (var bit = 31; bit > 0; bit--) if ((val & 1 << bit) != 0) break;
21046 return this.high != 0 ? bit + 33 : bit + 1;
21047 };
21048
21049 /**
21050 * Tests if this Long's value equals zero.
21051 * @returns {boolean}
21052 */
21053 LongPrototype.isZero = function isZero() {
21054 return this.high === 0 && this.low === 0;
21055 };
21056
21057 /**
21058 * Tests if this Long's value equals zero. This is an alias of {@link Long#isZero}.
21059 * @returns {boolean}
21060 */
21061 LongPrototype.eqz = LongPrototype.isZero;
21062
21063 /**
21064 * Tests if this Long's value is negative.
21065 * @returns {boolean}
21066 */
21067 LongPrototype.isNegative = function isNegative() {
21068 return !this.unsigned && this.high < 0;
21069 };
21070
21071 /**
21072 * Tests if this Long's value is positive.
21073 * @returns {boolean}
21074 */
21075 LongPrototype.isPositive = function isPositive() {
21076 return this.unsigned || this.high >= 0;
21077 };
21078
21079 /**
21080 * Tests if this Long's value is odd.
21081 * @returns {boolean}
21082 */
21083 LongPrototype.isOdd = function isOdd() {
21084 return (this.low & 1) === 1;
21085 };
21086
21087 /**
21088 * Tests if this Long's value is even.
21089 * @returns {boolean}
21090 */
21091 LongPrototype.isEven = function isEven() {
21092 return (this.low & 1) === 0;
21093 };
21094
21095 /**
21096 * Tests if this Long's value equals the specified's.
21097 * @param {!Long|number|string} other Other value
21098 * @returns {boolean}
21099 */
21100 LongPrototype.equals = function equals(other) {
21101 if (!isLong(other)) other = fromValue(other);
21102 if (this.unsigned !== other.unsigned && this.high >>> 31 === 1 && other.high >>> 31 === 1) return false;
21103 return this.high === other.high && this.low === other.low;
21104 };
21105
21106 /**
21107 * Tests if this Long's value equals the specified's. This is an alias of {@link Long#equals}.
21108 * @function
21109 * @param {!Long|number|string} other Other value
21110 * @returns {boolean}
21111 */
21112 LongPrototype.eq = LongPrototype.equals;
21113
21114 /**
21115 * Tests if this Long's value differs from the specified's.
21116 * @param {!Long|number|string} other Other value
21117 * @returns {boolean}
21118 */
21119 LongPrototype.notEquals = function notEquals(other) {
21120 return !this.eq( /* validates */other);
21121 };
21122
21123 /**
21124 * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
21125 * @function
21126 * @param {!Long|number|string} other Other value
21127 * @returns {boolean}
21128 */
21129 LongPrototype.neq = LongPrototype.notEquals;
21130
21131 /**
21132 * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
21133 * @function
21134 * @param {!Long|number|string} other Other value
21135 * @returns {boolean}
21136 */
21137 LongPrototype.ne = LongPrototype.notEquals;
21138
21139 /**
21140 * Tests if this Long's value is less than the specified's.
21141 * @param {!Long|number|string} other Other value
21142 * @returns {boolean}
21143 */
21144 LongPrototype.lessThan = function lessThan(other) {
21145 return this.comp( /* validates */other) < 0;
21146 };
21147
21148 /**
21149 * Tests if this Long's value is less than the specified's. This is an alias of {@link Long#lessThan}.
21150 * @function
21151 * @param {!Long|number|string} other Other value
21152 * @returns {boolean}
21153 */
21154 LongPrototype.lt = LongPrototype.lessThan;
21155
21156 /**
21157 * Tests if this Long's value is less than or equal the specified's.
21158 * @param {!Long|number|string} other Other value
21159 * @returns {boolean}
21160 */
21161 LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
21162 return this.comp( /* validates */other) <= 0;
21163 };
21164
21165 /**
21166 * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
21167 * @function
21168 * @param {!Long|number|string} other Other value
21169 * @returns {boolean}
21170 */
21171 LongPrototype.lte = LongPrototype.lessThanOrEqual;
21172
21173 /**
21174 * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
21175 * @function
21176 * @param {!Long|number|string} other Other value
21177 * @returns {boolean}
21178 */
21179 LongPrototype.le = LongPrototype.lessThanOrEqual;
21180
21181 /**
21182 * Tests if this Long's value is greater than the specified's.
21183 * @param {!Long|number|string} other Other value
21184 * @returns {boolean}
21185 */
21186 LongPrototype.greaterThan = function greaterThan(other) {
21187 return this.comp( /* validates */other) > 0;
21188 };
21189
21190 /**
21191 * Tests if this Long's value is greater than the specified's. This is an alias of {@link Long#greaterThan}.
21192 * @function
21193 * @param {!Long|number|string} other Other value
21194 * @returns {boolean}
21195 */
21196 LongPrototype.gt = LongPrototype.greaterThan;
21197
21198 /**
21199 * Tests if this Long's value is greater than or equal the specified's.
21200 * @param {!Long|number|string} other Other value
21201 * @returns {boolean}
21202 */
21203 LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
21204 return this.comp( /* validates */other) >= 0;
21205 };
21206
21207 /**
21208 * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
21209 * @function
21210 * @param {!Long|number|string} other Other value
21211 * @returns {boolean}
21212 */
21213 LongPrototype.gte = LongPrototype.greaterThanOrEqual;
21214
21215 /**
21216 * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
21217 * @function
21218 * @param {!Long|number|string} other Other value
21219 * @returns {boolean}
21220 */
21221 LongPrototype.ge = LongPrototype.greaterThanOrEqual;
21222
21223 /**
21224 * Compares this Long's value with the specified's.
21225 * @param {!Long|number|string} other Other value
21226 * @returns {number} 0 if they are the same, 1 if the this is greater and -1
21227 * if the given one is greater
21228 */
21229 LongPrototype.compare = function compare(other) {
21230 if (!isLong(other)) other = fromValue(other);
21231 if (this.eq(other)) return 0;
21232 var thisNeg = this.isNegative(),
21233 otherNeg = other.isNegative();
21234 if (thisNeg && !otherNeg) return -1;
21235 if (!thisNeg && otherNeg) return 1;
21236 // At this point the sign bits are the same
21237 if (!this.unsigned) return this.sub(other).isNegative() ? -1 : 1;
21238 // Both are positive if at least one is unsigned
21239 return other.high >>> 0 > this.high >>> 0 || other.high === this.high && other.low >>> 0 > this.low >>> 0 ? -1 : 1;
21240 };
21241
21242 /**
21243 * Compares this Long's value with the specified's. This is an alias of {@link Long#compare}.
21244 * @function
21245 * @param {!Long|number|string} other Other value
21246 * @returns {number} 0 if they are the same, 1 if the this is greater and -1
21247 * if the given one is greater
21248 */
21249 LongPrototype.comp = LongPrototype.compare;
21250
21251 /**
21252 * Negates this Long's value.
21253 * @returns {!Long} Negated Long
21254 */
21255 LongPrototype.negate = function negate() {
21256 if (!this.unsigned && this.eq(MIN_VALUE)) return MIN_VALUE;
21257 return this.not().add(ONE);
21258 };
21259
21260 /**
21261 * Negates this Long's value. This is an alias of {@link Long#negate}.
21262 * @function
21263 * @returns {!Long} Negated Long
21264 */
21265 LongPrototype.neg = LongPrototype.negate;
21266
21267 /**
21268 * Returns the sum of this and the specified Long.
21269 * @param {!Long|number|string} addend Addend
21270 * @returns {!Long} Sum
21271 */
21272 LongPrototype.add = function add(addend) {
21273 if (!isLong(addend)) addend = fromValue(addend);
21274
21275 // Divide each number into 4 chunks of 16 bits, and then sum the chunks.
21276
21277 var a48 = this.high >>> 16;
21278 var a32 = this.high & 0xFFFF;
21279 var a16 = this.low >>> 16;
21280 var a00 = this.low & 0xFFFF;
21281 var b48 = addend.high >>> 16;
21282 var b32 = addend.high & 0xFFFF;
21283 var b16 = addend.low >>> 16;
21284 var b00 = addend.low & 0xFFFF;
21285 var c48 = 0,
21286 c32 = 0,
21287 c16 = 0,
21288 c00 = 0;
21289 c00 += a00 + b00;
21290 c16 += c00 >>> 16;
21291 c00 &= 0xFFFF;
21292 c16 += a16 + b16;
21293 c32 += c16 >>> 16;
21294 c16 &= 0xFFFF;
21295 c32 += a32 + b32;
21296 c48 += c32 >>> 16;
21297 c32 &= 0xFFFF;
21298 c48 += a48 + b48;
21299 c48 &= 0xFFFF;
21300 return fromBits(c16 << 16 | c00, c48 << 16 | c32, this.unsigned);
21301 };
21302
21303 /**
21304 * Returns the difference of this and the specified Long.
21305 * @param {!Long|number|string} subtrahend Subtrahend
21306 * @returns {!Long} Difference
21307 */
21308 LongPrototype.subtract = function subtract(subtrahend) {
21309 if (!isLong(subtrahend)) subtrahend = fromValue(subtrahend);
21310 return this.add(subtrahend.neg());
21311 };
21312
21313 /**
21314 * Returns the difference of this and the specified Long. This is an alias of {@link Long#subtract}.
21315 * @function
21316 * @param {!Long|number|string} subtrahend Subtrahend
21317 * @returns {!Long} Difference
21318 */
21319 LongPrototype.sub = LongPrototype.subtract;
21320
21321 /**
21322 * Returns the product of this and the specified Long.
21323 * @param {!Long|number|string} multiplier Multiplier
21324 * @returns {!Long} Product
21325 */
21326 LongPrototype.multiply = function multiply(multiplier) {
21327 if (this.isZero()) return ZERO;
21328 if (!isLong(multiplier)) multiplier = fromValue(multiplier);
21329
21330 // use wasm support if present
21331 if (wasm) {
21332 var low = wasm.mul(this.low, this.high, multiplier.low, multiplier.high);
21333 return fromBits(low, wasm.get_high(), this.unsigned);
21334 }
21335 if (multiplier.isZero()) return ZERO;
21336 if (this.eq(MIN_VALUE)) return multiplier.isOdd() ? MIN_VALUE : ZERO;
21337 if (multiplier.eq(MIN_VALUE)) return this.isOdd() ? MIN_VALUE : ZERO;
21338 if (this.isNegative()) {
21339 if (multiplier.isNegative()) return this.neg().mul(multiplier.neg());else return this.neg().mul(multiplier).neg();
21340 } else if (multiplier.isNegative()) return this.mul(multiplier.neg()).neg();
21341
21342 // If both longs are small, use float multiplication
21343 if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24)) return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned);
21344
21345 // Divide each long into 4 chunks of 16 bits, and then add up 4x4 products.
21346 // We can skip products that would overflow.
21347
21348 var a48 = this.high >>> 16;
21349 var a32 = this.high & 0xFFFF;
21350 var a16 = this.low >>> 16;
21351 var a00 = this.low & 0xFFFF;
21352 var b48 = multiplier.high >>> 16;
21353 var b32 = multiplier.high & 0xFFFF;
21354 var b16 = multiplier.low >>> 16;
21355 var b00 = multiplier.low & 0xFFFF;
21356 var c48 = 0,
21357 c32 = 0,
21358 c16 = 0,
21359 c00 = 0;
21360 c00 += a00 * b00;
21361 c16 += c00 >>> 16;
21362 c00 &= 0xFFFF;
21363 c16 += a16 * b00;
21364 c32 += c16 >>> 16;
21365 c16 &= 0xFFFF;
21366 c16 += a00 * b16;
21367 c32 += c16 >>> 16;
21368 c16 &= 0xFFFF;
21369 c32 += a32 * b00;
21370 c48 += c32 >>> 16;
21371 c32 &= 0xFFFF;
21372 c32 += a16 * b16;
21373 c48 += c32 >>> 16;
21374 c32 &= 0xFFFF;
21375 c32 += a00 * b32;
21376 c48 += c32 >>> 16;
21377 c32 &= 0xFFFF;
21378 c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48;
21379 c48 &= 0xFFFF;
21380 return fromBits(c16 << 16 | c00, c48 << 16 | c32, this.unsigned);
21381 };
21382
21383 /**
21384 * Returns the product of this and the specified Long. This is an alias of {@link Long#multiply}.
21385 * @function
21386 * @param {!Long|number|string} multiplier Multiplier
21387 * @returns {!Long} Product
21388 */
21389 LongPrototype.mul = LongPrototype.multiply;
21390
21391 /**
21392 * Returns this Long divided by the specified. The result is signed if this Long is signed or
21393 * unsigned if this Long is unsigned.
21394 * @param {!Long|number|string} divisor Divisor
21395 * @returns {!Long} Quotient
21396 */
21397 LongPrototype.divide = function divide(divisor) {
21398 if (!isLong(divisor)) divisor = fromValue(divisor);
21399 if (divisor.isZero()) throw Error('division by zero');
21400
21401 // use wasm support if present
21402 if (wasm) {
21403 // guard against signed division overflow: the largest
21404 // negative number / -1 would be 1 larger than the largest
21405 // positive number, due to two's complement.
21406 if (!this.unsigned && this.high === -0x80000000 && divisor.low === -1 && divisor.high === -1) {
21407 // be consistent with non-wasm code path
21408 return this;
21409 }
21410 var low = (this.unsigned ? wasm.div_u : wasm.div_s)(this.low, this.high, divisor.low, divisor.high);
21411 return fromBits(low, wasm.get_high(), this.unsigned);
21412 }
21413 if (this.isZero()) return this.unsigned ? UZERO : ZERO;
21414 var approx, rem, res;
21415 if (!this.unsigned) {
21416 // This section is only relevant for signed longs and is derived from the
21417 // closure library as a whole.
21418 if (this.eq(MIN_VALUE)) {
21419 if (divisor.eq(ONE) || divisor.eq(NEG_ONE)) return MIN_VALUE; // recall that -MIN_VALUE == MIN_VALUE
21420 else if (divisor.eq(MIN_VALUE)) return ONE;else {
21421 // At this point, we have |other| >= 2, so |this/other| < |MIN_VALUE|.
21422 var halfThis = this.shr(1);
21423 approx = halfThis.div(divisor).shl(1);
21424 if (approx.eq(ZERO)) {
21425 return divisor.isNegative() ? ONE : NEG_ONE;
21426 } else {
21427 rem = this.sub(divisor.mul(approx));
21428 res = approx.add(rem.div(divisor));
21429 return res;
21430 }
21431 }
21432 } else if (divisor.eq(MIN_VALUE)) return this.unsigned ? UZERO : ZERO;
21433 if (this.isNegative()) {
21434 if (divisor.isNegative()) return this.neg().div(divisor.neg());
21435 return this.neg().div(divisor).neg();
21436 } else if (divisor.isNegative()) return this.div(divisor.neg()).neg();
21437 res = ZERO;
21438 } else {
21439 // The algorithm below has not been made for unsigned longs. It's therefore
21440 // required to take special care of the MSB prior to running it.
21441 if (!divisor.unsigned) divisor = divisor.toUnsigned();
21442 if (divisor.gt(this)) return UZERO;
21443 if (divisor.gt(this.shru(1)))
21444 // 15 >>> 1 = 7 ; with divisor = 8 ; true
21445 return UONE;
21446 res = UZERO;
21447 }
21448
21449 // Repeat the following until the remainder is less than other: find a
21450 // floating-point that approximates remainder / other *from below*, add this
21451 // into the result, and subtract it from the remainder. It is critical that
21452 // the approximate value is less than or equal to the real value so that the
21453 // remainder never becomes negative.
21454 rem = this;
21455 while (rem.gte(divisor)) {
21456 // Approximate the result of division. This may be a little greater or
21457 // smaller than the actual value.
21458 approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber()));
21459
21460 // We will tweak the approximate result by changing it in the 48-th digit or
21461 // the smallest non-fractional digit, whichever is larger.
21462 var log2 = Math.ceil(Math.log(approx) / Math.LN2),
21463 delta = log2 <= 48 ? 1 : pow_dbl(2, log2 - 48),
21464 // Decrease the approximation until it is smaller than the remainder. Note
21465 // that if it is too large, the product overflows and is negative.
21466 approxRes = fromNumber(approx),
21467 approxRem = approxRes.mul(divisor);
21468 while (approxRem.isNegative() || approxRem.gt(rem)) {
21469 approx -= delta;
21470 approxRes = fromNumber(approx, this.unsigned);
21471 approxRem = approxRes.mul(divisor);
21472 }
21473
21474 // We know the answer can't be zero... and actually, zero would cause
21475 // infinite recursion since we would make no progress.
21476 if (approxRes.isZero()) approxRes = ONE;
21477 res = res.add(approxRes);
21478 rem = rem.sub(approxRem);
21479 }
21480 return res;
21481 };
21482
/**
 * Returns this Long divided by the specified. This is an alias of {@link Long#divide}.
 * @function
 * @param {!Long|number|string} divisor Divisor
 * @returns {!Long} Quotient
 */
LongPrototype.div = LongPrototype.divide;

/**
 * Returns this Long modulo the specified, i.e. `this - (this / divisor) * divisor`.
 * @param {!Long|number|string} divisor Divisor
 * @returns {!Long} Remainder
 * @throws {Error} 'division by zero' when `divisor` is zero.
 */
LongPrototype.modulo = function modulo(divisor) {
  if (!isLong(divisor)) divisor = fromValue(divisor);

  // Reject zero before picking an implementation. Without this check the wasm
  // `rem_*` helpers trap with a WebAssembly RuntimeError instead of a regular
  // Error, and the fallback below relies on `div` to notice the zero divisor.
  if (divisor.isZero()) throw new Error('division by zero');

  // use wasm support if present: rem_* returns the low 32 bits and leaves the
  // high 32 bits to be read back through get_high().
  if (wasm) {
    var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(this.low, this.high, divisor.low, divisor.high);
    return fromBits(low, wasm.get_high(), this.unsigned);
  }
  return this.sub(this.div(divisor).mul(divisor));
};

/**
 * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
 * @function
 * @param {!Long|number|string} divisor Divisor
 * @returns {!Long} Remainder
 */
LongPrototype.mod = LongPrototype.modulo;

/**
 * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
 * @function
 * @param {!Long|number|string} divisor Divisor
 * @returns {!Long} Remainder
 */
LongPrototype.rem = LongPrototype.modulo;
21522
/**
 * Bitwise NOT: flips every bit of both 32-bit halves, keeping signedness.
 * @returns {!Long}
 */
LongPrototype.not = function not() {
  var flippedLow = ~this.low;
  var flippedHigh = ~this.high;
  return fromBits(flippedLow, flippedHigh, this.unsigned);
};

/**
 * Bitwise AND with `other` (converted to a Long if needed); the result keeps
 * this Long's signedness.
 * @param {!Long|number|string} other Other Long
 * @returns {!Long}
 */
LongPrototype.and = function and(other) {
  var rhs = isLong(other) ? other : fromValue(other);
  return fromBits(this.low & rhs.low, this.high & rhs.high, this.unsigned);
};

/**
 * Bitwise OR with `other` (converted to a Long if needed); the result keeps
 * this Long's signedness.
 * @param {!Long|number|string} other Other Long
 * @returns {!Long}
 */
LongPrototype.or = function or(other) {
  var rhs = isLong(other) ? other : fromValue(other);
  return fromBits(this.low | rhs.low, this.high | rhs.high, this.unsigned);
};

/**
 * Bitwise XOR with `other` (converted to a Long if needed); the result keeps
 * this Long's signedness.
 * @param {!Long|number|string} other Other Long
 * @returns {!Long}
 */
LongPrototype.xor = function xor(other) {
  var rhs = isLong(other) ? other : fromValue(other);
  return fromBits(this.low ^ rhs.low, this.high ^ rhs.high, this.unsigned);
};
21560
/**
 * Returns this Long shifted left by `numBits` (taken modulo 64). A shift of
 * 0 (mod 64) returns this instance unchanged.
 * @param {number|!Long} numBits Number of bits
 * @returns {!Long} Shifted Long
 */
LongPrototype.shiftLeft = function shiftLeft(numBits) {
  var bits = (isLong(numBits) ? numBits.toInt() : numBits) & 63;
  if (bits === 0) {
    return this;
  }
  if (bits >= 32) {
    // The whole low word moves into the high word.
    return fromBits(0, this.low << (bits - 32), this.unsigned);
  }
  // Bits shifted out of the low word carry into the high word.
  var carry = this.low >>> (32 - bits);
  return fromBits(this.low << bits, (this.high << bits) | carry, this.unsigned);
};

/**
 * Alias of {@link Long#shiftLeft}.
 * @function
 * @param {number|!Long} numBits Number of bits
 * @returns {!Long} Shifted Long
 */
LongPrototype.shl = LongPrototype.shiftLeft;
21578
/**
 * Returns this Long arithmetically shifted right by `numBits` (taken modulo
 * 64): the sign bit of the high word is replicated into the vacated bits.
 * A shift of 0 (mod 64) returns this instance unchanged.
 * @param {number|!Long} numBits Number of bits
 * @returns {!Long} Shifted Long
 */
LongPrototype.shiftRight = function shiftRight(numBits) {
  var bits = (isLong(numBits) ? numBits.toInt() : numBits) & 63;
  if (bits === 0) {
    return this;
  }
  if (bits >= 32) {
    // Only (part of) the high word survives; the new high word is pure sign fill.
    var signFill = this.high >= 0 ? 0 : -1;
    return fromBits(this.high >> (bits - 32), signFill, this.unsigned);
  }
  var lowWord = (this.low >>> bits) | (this.high << (32 - bits));
  return fromBits(lowWord, this.high >> bits, this.unsigned);
};

/**
 * Alias of {@link Long#shiftRight}.
 * @function
 * @param {number|!Long} numBits Number of bits
 * @returns {!Long} Shifted Long
 */
LongPrototype.shr = LongPrototype.shiftRight;
21596
/**
 * Returns this Long logically shifted right by `numBits` (taken modulo 64):
 * vacated high bits are always zero. A shift of 0 (mod 64) returns this
 * instance unchanged.
 * @param {number|!Long} numBits Number of bits
 * @returns {!Long} Shifted Long
 */
LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
  var bits = (isLong(numBits) ? numBits.toInt() : numBits) & 63;
  if (bits === 0) {
    return this;
  }
  var high = this.high;
  if (bits > 32) {
    return fromBits(high >>> (bits - 32), 0, this.unsigned);
  }
  if (bits === 32) {
    // The high word becomes the low word verbatim.
    return fromBits(high, 0, this.unsigned);
  }
  var lowWord = (this.low >>> bits) | (high << (32 - bits));
  return fromBits(lowWord, high >>> bits, this.unsigned);
};

/**
 * Alias of {@link Long#shiftRightUnsigned}.
 * @function
 * @param {number|!Long} numBits Number of bits
 * @returns {!Long} Shifted Long
 */
LongPrototype.shru = LongPrototype.shiftRightUnsigned;

/**
 * Alias of {@link Long#shiftRightUnsigned}.
 * @function
 * @param {number|!Long} numBits Number of bits
 * @returns {!Long} Shifted Long
 */
LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
21629
/**
 * Reinterprets the same 64 bits as a signed Long. Returns this instance when
 * it is already signed.
 * @returns {!Long} Signed long
 */
LongPrototype.toSigned = function toSigned() {
  return this.unsigned ? fromBits(this.low, this.high, false) : this;
};

/**
 * Reinterprets the same 64 bits as an unsigned Long. Returns this instance
 * when it is already unsigned.
 * @returns {!Long} Unsigned long
 */
LongPrototype.toUnsigned = function toUnsigned() {
  return this.unsigned ? this : fromBits(this.low, this.high, true);
};
21647
/**
 * Converts this Long to an array of 8 bytes.
 * @param {boolean=} le Truthy for little endian; otherwise big endian (the default).
 * @returns {!Array.<number>} Byte representation
 */
LongPrototype.toBytes = function toBytes(le) {
  if (le) {
    return this.toBytesLE();
  }
  return this.toBytesBE();
};

/**
 * Converts this Long to 8 bytes, least significant byte of the low word first.
 * @returns {!Array.<number>} Little endian byte representation
 */
LongPrototype.toBytesLE = function toBytesLE() {
  var hi = this.high;
  var lo = this.low;
  var out = [];
  [lo, hi].forEach(function (word) {
    for (var shift = 0; shift < 32; shift += 8) {
      out.push(word >>> shift & 0xff);
    }
  });
  return out;
};

/**
 * Converts this Long to 8 bytes, most significant byte of the high word first.
 * @returns {!Array.<number>} Big endian byte representation
 */
LongPrototype.toBytesBE = function toBytesBE() {
  var hi = this.high;
  var lo = this.low;
  var out = [];
  [hi, lo].forEach(function (word) {
    for (var shift = 24; shift >= 0; shift -= 8) {
      out.push(word >>> shift & 0xff);
    }
  });
  return out;
};
21676
/**
 * Creates a Long from 8 bytes.
 * @param {!Array.<number>} bytes Byte representation
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @param {boolean=} le Truthy for little endian; otherwise big endian (the default)
 * @returns {Long} The corresponding Long value
 */
Long$1.fromBytes = function fromBytes(bytes, unsigned, le) {
  if (le) {
    return Long$1.fromBytesLE(bytes, unsigned);
  }
  return Long$1.fromBytesBE(bytes, unsigned);
};

/**
 * Creates a Long from little endian bytes: bytes[0..3] form the low word and
 * bytes[4..7] the high word. Missing entries contribute zero bits.
 * @param {!Array.<number>} bytes Little endian byte representation
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {Long} The corresponding Long value
 */
Long$1.fromBytesLE = function fromBytesLE(bytes, unsigned) {
  var lowWord = bytes[3] << 24 | bytes[2] << 16 | bytes[1] << 8 | bytes[0];
  var highWord = bytes[7] << 24 | bytes[6] << 16 | bytes[5] << 8 | bytes[4];
  return new Long$1(lowWord, highWord, unsigned);
};

/**
 * Creates a Long from big endian bytes: bytes[0..3] form the high word and
 * bytes[4..7] the low word.
 * @param {!Array.<number>} bytes Big endian byte representation
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {Long} The corresponding Long value
 */
Long$1.fromBytesBE = function fromBytesBE(bytes, unsigned) {
  var highWord = bytes[0] << 24 | bytes[1] << 16 | bytes[2] << 8 | bytes[3];
  var lowWord = bytes[4] << 24 | bytes[5] << 16 | bytes[6] << 8 | bytes[7];
  return new Long$1(lowWord, highWord, unsigned);
};
// Unwrap the bundled CommonJS `long` module. getDefaultExportFromCjs returns
// its `default` export when the module is flagged `__esModule` and has an own
// `default` property, and the module object itself otherwise.
var long$1 = /*@__PURE__*/getDefaultExportFromCjs(long);

// Namespace-style view of `long`: a null-prototype object whose `default` is
// `long$1`. _mergeNamespaces then adds every other own key of `long` (except
// `default` and keys already present) as an enumerable getter, reusing the
// original accessor descriptor when there is one.
var LongExports = /*#__PURE__*/_mergeNamespaces({
  __proto__: null,
  default: long$1
}, [long]);
21713
21714 /**
21715 * @license
21716 * Copyright 2021 Google LLC. All Rights Reserved.
21717 * Licensed under the Apache License, Version 2.0 (the "License");
21718 * you may not use this file except in compliance with the License.
21719 * You may obtain a copy of the License at
21720 *
21721 * http://www.apache.org/licenses/LICENSE-2.0
21722 *
21723 * Unless required by applicable law or agreed to in writing, software
21724 * distributed under the License is distributed on an "AS IS" BASIS,
21725 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21726 * See the License for the specific language governing permissions and
21727 * limitations under the License.
21728 * =============================================================================
21729 */
// Prefer the unwrapped default export of `long`; fall back to the namespace
// wrapper when that is falsy.
var Long = long$1 || LongExports;

/**
 * Parses a string of hexadecimal digits (no `0x` prefix) as an unsigned
 * 64-bit Long.
 * @param {string} hex Hex digits.
 * @returns {!Long} The unsigned value.
 */
function hexToLong(hex) {
  var radix = 16;
  var isUnsigned = true;
  return Long.fromString(hex, isUnsigned, radix);
}

// Primes between 2^63 and 2^64 used as multipliers by the hash functions
// below (each has its top bit set).
var k0 = hexToLong('c3a5c85c97cb3127'); // 0xc3a5c85c97cb3127
var k1 = hexToLong('b492b66fbe98f273'); // 0xb492b66fbe98f273
var k2 = hexToLong('9ae16a3b2f90404f'); // 0x9ae16a3b2f90404f
/**
 * XOR-folds the top 17 bits of `val` into its low bits: val ^ (val >>> 47).
 * @param {!Long} val
 * @returns {!Long}
 */
function shiftMix(val) {
  var topBits = val.shru(47);
  return val.xor(topBits);
}

/**
 * Reads `numBytes` bytes of `s` starting at `offset` as an unsigned
 * little-endian Long. Bytes past the end of `s` contribute zero.
 * @param s Byte array (must support `slice`).
 * @param {number} offset Start index.
 * @param {number} numBytes Number of bytes to read (at most 8).
 * @returns {!Long}
 */
function fetch$2(s, offset, numBytes) {
  var chunk = s.slice(offset, offset + numBytes);
  var isUnsigned = true;
  var littleEndian = true;
  return Long.fromBytes(Array.from(chunk), isUnsigned, littleEndian);
}

/** Reads 8 bytes of `s` at `offset` as an unsigned little-endian Long. */
function fetch64(s, offset) {
  var width = 8;
  return fetch$2(s, offset, width);
}

/** Reads 4 bytes of `s` at `offset` as an unsigned little-endian Long. */
function fetch32(s, offset) {
  var width = 4;
  return fetch$2(s, offset, width);
}

/**
 * Rotates the 64 bits of `val` right by `shift` positions.
 * @param {!Long} val
 * @param {number} shift Rotation amount in [0, 63].
 * @returns {!Long}
 */
function rotate64(val, shift) {
  if (shift === 0) {
    // A zero rotation is the identity; short-circuit so we never ask for a
    // 64-bit shift.
    return val;
  }
  var movedDown = val.shru(shift);
  return movedDown.or(val.shl(64 - shift));
}
/**
 * Murmur-inspired mix of two 64-bit values into one.
 * @param {!Long} u
 * @param {!Long} v
 * @param {!Long=} mul Optional third argument: the multiplier, defaulting to
 *     0x9ddfea08eb382d69 when omitted or undefined.
 * @returns {!Long}
 */
function hashLen16(u, v) {
  var mul = arguments[2] !== undefined ? arguments[2] : hexToLong('9ddfea08eb382d69');
  // Fold the top 17 bits back into the low bits.
  function fold47(value) {
    return value.xor(value.shru(47));
  }
  var a = fold47(u.xor(v).mul(mul));
  var b = fold47(v.xor(a).mul(mul));
  return b.mul(mul);
}

/**
 * Quick-and-dirty 16-byte hash of 48 bytes of input (four 64-bit words plus
 * two seeds). Callers should use "random-looking" seeds for `a` and `b`.
 * @returns {!Array.<!Long>} Two-element array of 64-bit results.
 */
function weakHashLen32WithSeeds(w, x, y, z, a, b) {
  var seededA = a.add(w);
  var mixedB = rotate64(b.add(seededA).add(z), 21);
  var sum = seededA.add(x).add(y);
  mixedB = mixedB.add(rotate64(sum, 44));
  return [sum.add(z), mixedB.add(seededA)];
}

/**
 * Applies weakHashLen32WithSeeds to the four little-endian 64-bit words
 * stored in `s` at `offset`, `offset + 8`, `offset + 16` and `offset + 24`.
 * @returns {!Array.<!Long>}
 */
function weakHashLen32WithSeedsStr(s, offset, a, b) {
  var w = fetch64(s, offset);
  var x = fetch64(s, offset + 8);
  var y = fetch64(s, offset + 16);
  var z = fetch64(s, offset + 24);
  return weakHashLen32WithSeeds(w, x, y, z, a, b);
}
/**
 * Hashes the first `len` bytes of `s` when `len` is at most 16.
 * @param s Byte array.
 * @param {number=} len Optional second argument: number of bytes to hash,
 *     defaulting to `s.length`.
 * @returns {!Long}
 */
function hashLen0to16(s) {
  var len = arguments[1] !== undefined ? arguments[1] : s.length;

  if (len >= 8) {
    // 8..16 bytes: mix the first and the last 8-byte words (they may overlap).
    var mul8 = k2.add(len * 2);
    var head64 = fetch64(s, 0).add(k2);
    var tail64 = fetch64(s, len - 8);
    var left = rotate64(tail64, 37).mul(mul8).add(head64);
    var right = rotate64(head64, 25).add(tail64).mul(mul8);
    return hashLen16(left, right, mul8);
  }

  if (len >= 4) {
    // 4..7 bytes: mix the first and the last 4-byte words (they may overlap).
    var mul4 = k2.add(len * 2);
    var head32 = fetch32(s, 0);
    var tail32 = fetch32(s, len - 4);
    return hashLen16(head32.shl(3).add(len), tail32, mul4);
  }

  if (len > 0) {
    // 1..3 bytes: combine the first, middle and last bytes arithmetically.
    var first = s[0];
    var middle = s[len >> 1];
    var last = s[len - 1];
    var lo = first + (middle << 8);
    var hi = len + (last << 2);
    return shiftMix(k2.mul(lo).xor(k0.mul(hi))).mul(k2);
  }

  return k2;
}
/**
 * Mixing step shared by the 17..64-byte hashes. Given four 64-bit words
 * a, b, c, d and an addend, returns the pair fed to hashLen16:
 *   [rotr(a + b, 43) + rotr(c, 30) + d,  a + rotr(b + addend, 18) + c]
 * @returns {!Array.<!Long>}
 */
function hashLenMixPair(a, b, c, d, addend) {
  var first = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
  var second = a.add(rotate64(b.add(addend), 18)).add(c);
  return [first, second];
}

/**
 * Hashes the first `len` bytes of `s` when `len` is in 17..32.
 * @param s Byte array.
 * @param {number=} len Optional second argument, defaulting to `s.length`.
 * @returns {!Long}
 */
function hashLen17to32(s) {
  var len = arguments[1] !== undefined ? arguments[1] : s.length;
  var mul = k2.add(len * 2);
  var pair = hashLenMixPair(
    fetch64(s, 0).mul(k1),
    fetch64(s, 8),
    fetch64(s, len - 8).mul(mul),
    fetch64(s, len - 16).mul(k2),
    k2);
  return hashLen16(pair[0], pair[1], mul);
}

/**
 * Hashes the first `len` bytes of `s` when `len` is in 33..64. It runs two
 * rounds of the shared mixing step: one over the first/last 16 bytes and one
 * over bytes 16..31 and the 32 bytes before the end.
 * @param s Byte array.
 * @param {number=} len Optional second argument, defaulting to `s.length`.
 * @returns {!Long}
 */
function hashLen33to64(s) {
  var len = arguments[1] !== undefined ? arguments[1] : s.length;
  var mul = k2.add(len * 2);
  var a = fetch64(s, 0).mul(k2);
  var head = hashLenMixPair(a, fetch64(s, 8), fetch64(s, len - 8).mul(mul), fetch64(s, len - 16).mul(k2), k2);
  var y = head[0];
  var z = hashLen16(y, head[1], mul);
  var e = fetch64(s, 16).mul(mul);
  var f = fetch64(s, 24);
  var g = y.add(fetch64(s, len - 32)).mul(mul);
  var h = z.add(fetch64(s, len - 24)).mul(mul);
  var tail = hashLenMixPair(e, f, g, h, a);
  return hashLen16(tail[0], tail[1], mul);
}
/**
 * Computes a 64-bit fingerprint of the first `len` bytes of `s`.
 *
 * Inputs of at most 64 bytes are delegated to hashLen0to16, hashLen17to32 or
 * hashLen33to64. Longer inputs are consumed in 64-byte chunks while keeping
 * 56 bytes of state (v and w hold two 64-bit words each; x, y and z one
 * each). A final round then runs over the last 64 bytes of input.
 *
 * @param s Byte array to hash. It is indexed and sliced; presumably a
 *     Uint8Array (NOTE(review): confirm for all callers).
 * @param {number=} len Optional second argument: number of leading bytes to
 *     hash, defaulting to `s.length`.
 * @returns {!Long} The fingerprint.
 */
function fingerPrint64(s) {
  var len = arguments[1] !== undefined ? arguments[1] : s.length;
  if (len <= 16) {
    return hashLen0to16(s, len);
  }
  if (len <= 32) {
    return hashLen17to32(s, len);
  }
  if (len <= 64) {
    return hashLen33to64(s, len);
  }

  var seed = Long.fromNumber(81, true);
  var y = seed.mul(k1).add(113);
  var z = shiftMix(y.mul(k2).add(113)).mul(k2);
  var x = seed.mul(k2).add(fetch64(s, 0));
  var v0 = Long.UZERO;
  var v1 = Long.UZERO;
  var w0 = Long.UZERO;
  var w1 = Long.UZERO;
  var vPair;
  var wPair;
  var swap;

  // `end` is chosen so that 1..64 bytes remain for the final round;
  // `last64` is the offset of the last 64 bytes of input.
  var end = (len - 1 >> 6) * 64;
  var last64 = end + (len - 1 & 63) - 63;
  var offset = 0;
  do {
    x = rotate64(x.add(y).add(v0).add(fetch64(s, offset + 8)), 37).mul(k1).xor(w1);
    y = rotate64(y.add(v1).add(fetch64(s, offset + 48)), 42).mul(k1).add(v0).add(fetch64(s, offset + 40));
    z = rotate64(z.add(w0), 33).mul(k1);
    vPair = weakHashLen32WithSeedsStr(s, offset, v1.mul(k1), x.add(w0));
    wPair = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w1), y.add(fetch64(s, offset + 16)));
    v0 = vPair[0];
    v1 = vPair[1];
    w0 = wPair[0];
    w1 = wPair[1];
    swap = x;
    x = z;
    z = swap;
    offset += 64;
  } while (offset !== end);

  // Final round over the last 64 bytes, with a state-dependent multiplier.
  var mul = k1.add(z.and(0xff).shl(1));
  offset = last64;
  w0 = w0.add(len - 1 & 63);
  v0 = v0.add(w0);
  w0 = w0.add(v0);
  x = rotate64(x.add(y).add(v0).add(fetch64(s, offset + 8)), 37).mul(mul).xor(w1.mul(9));
  y = rotate64(y.add(v1).add(fetch64(s, offset + 48)), 42).mul(mul).add(v0.mul(9).add(fetch64(s, offset + 40)));
  z = rotate64(z.add(w0), 33).mul(mul);
  vPair = weakHashLen32WithSeedsStr(s, offset, v1.mul(mul), x.add(w0));
  wPair = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w1), y.add(fetch64(s, offset + 16)));
  swap = x;
  x = z;
  z = swap;

  var left = hashLen16(vPair[0], wPair[0], mul).add(shiftMix(y).mul(k0)).add(z);
  var right = hashLen16(vPair[1], wPair[1], mul).add(x);
  return hashLen16(left, right, mul);
}
21889
21890 /**
21891 * @license
21892 * Copyright 2017 Google LLC. All Rights Reserved.
21893 * Licensed under the Apache License, Version 2.0 (the "License");
21894 * you may not use this file except in compliance with the License.
21895 * You may obtain a copy of the License at
21896 *
21897 * http://www.apache.org/licenses/LICENSE-2.0
21898 *
21899 * Unless required by applicable law or agreed to in writing, software
21900 * distributed under the License is distributed on an "AS IS" BASIS,
21901 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21902 * See the License for the specific language governing permissions and
21903 * limitations under the License.
21904 * =============================================================================
21905 */
/**
 * Creates the storage form of a single scalar, as kept in `DataStorage`.
 * 'string' scalars are encoded to bytes; every other dtype becomes a
 * one-element typed array.
 * @param value The scalar value.
 * @param dtype Target dtype.
 */
function createScalarValue(value, dtype) {
  return dtype === 'string' ? encodeString(value) : toTypedArray([value], dtype);
}

/**
 * Whether `a` is already the typed array that `dtype` maps to
 * ('float32' -> Float32Array, 'int32' -> Int32Array, 'bool' -> Uint8Array).
 */
function noConversionNeeded(a, dtype) {
  switch (dtype) {
    case 'float32':
      return a instanceof Float32Array;
    case 'int32':
      return a instanceof Int32Array;
    case 'bool':
      return a instanceof Uint8Array;
    default:
      return false;
  }
}
/**
 * Converts `a` (a possibly nested array or a typed array) into the typed
 * array used to store `dtype`. If `a` already has the right type it is
 * returned as-is.
 * @param a Values to convert; nested arrays are flattened first.
 * @param dtype Target dtype; null/undefined is treated as 'float32'.
 * @throws {Error} for 'string' or an unknown dtype.
 */
function toTypedArray(a, dtype) {
  if (dtype === 'string') {
    throw new Error('Cannot convert a string[] to a TypedArray');
  }
  var values = Array.isArray(a) ? flatten$2(a) : a;
  if (env().getBool('DEBUG')) {
    checkConversionForErrors(values, dtype);
  }
  if (noConversionNeeded(values, dtype)) {
    return values;
  }
  if (dtype == null) {
    return new Float32Array(values);
  }
  switch (dtype) {
    case 'float32':
    case 'complex64':
      return new Float32Array(values);
    case 'int32':
      return new Int32Array(values);
    case 'bool': {
      // A value is true when it rounds to something other than zero.
      var mask = new Uint8Array(values.length);
      for (var idx = 0; idx < mask.length; idx++) {
        mask[idx] = Math.round(values[idx]) !== 0 ? 1 : 0;
      }
      return mask;
    }
    default:
      throw new Error("Unknown data type ".concat(dtype));
  }
}
/**
 * Returns the current high-resolution time in milliseconds relative to an
 * arbitrary time in the past. It works across different platforms (node.js,
 * browsers).
 *
 * ```js
 * console.log(tf.util.now());
 * ```
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function now() {
  var platform = env().platform;
  return platform.now();
}

/**
 * Returns a platform-specific implementation of
 * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
 *
 * If `fetch` is defined on the global object (`window`, `process`, etc.),
 * `tf.util.fetch` returns that function.
 *
 * If not, `tf.util.fetch` returns a platform-specific solution.
 *
 * ```js
 * const resource = await tf.util.fetch('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs');
 * // handle response
 * ```
 *
 * @doc {heading: 'Util'}
 */
function fetch$1(path, requestInits) {
  var platform = env().platform;
  return platform.fetch(path, requestInits);
}

/**
 * Encodes the provided string into bytes using the provided encoding scheme.
 *
 * @param s The string to encode.
 * @param encoding Optional second argument: the encoding scheme. A missing or
 *     falsy value means utf-8.
 *
 * @doc {heading: 'Util'}
 */
function encodeString(s) {
  var encoding = arguments[1] || 'utf-8';
  return env().platform.encode(s, encoding);
}

/**
 * Decodes the provided bytes into a string using the provided encoding scheme.
 *
 * @param bytes The bytes to decode.
 * @param encoding Optional second argument: the encoding scheme. A missing or
 *     falsy value means utf-8.
 *
 * @doc {heading: 'Util'}
 */
function decodeString(bytes) {
  var encoding = arguments[1] || 'utf-8';
  return env().platform.decode(bytes, encoding);
}

/**
 * Whether `a` is a typed array. Uses the platform's own check when it has
 * one, otherwise the browser implementation.
 */
function isTypedArray(a) {
  var platform = env().platform;
  // TODO(mattsoulanille): Remove this fallback in 5.0.0
  return platform.isTypedArray != null ? platform.isTypedArray(a) : isTypedArrayBrowser(a);
}
/**
 * Flattens an arbitrarily nested array.
 *
 * ```js
 * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
 * const flat = tf.util.flatten(a);
 * console.log(flat);
 * ```
 *
 * Booleans, numbers, strings, promises and null/undefined are leaves.
 * Arrays and typed arrays are descended into. Any other object is treated
 * as array-like: its indices 0..max (taken from its non-negative integer
 * keys) are visited in order.
 *
 * @param arr The nested array to flatten.
 * @param result Optional second argument: the destination array which holds
 *     the elements. It is appended to and returned; null/undefined means a new
 *     array.
 * @param skipTypedArray Optional third argument: if true, typed arrays are
 *     pushed as single elements instead of being flattened. Defaults to false.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function flatten$2(arr) {
  var result = arguments[1] == null ? [] : arguments[1];
  var skipTypedArray = arguments[2] === undefined ? false : arguments[2];

  var isLeaf = typeof arr === 'boolean' || typeof arr === 'number' || typeof arr === 'string' || isPromise(arr) || arr == null;
  if (isLeaf || isTypedArray(arr) && skipTypedArray) {
    result.push(arr);
    return result;
  }

  if (Array.isArray(arr) || isTypedArray(arr)) {
    for (var i = 0; i < arr.length; ++i) {
      flatten$2(arr[i], result, skipTypedArray);
    }
    return result;
  }

  // Array-like object: find the largest "0 or positive integer" key.
  var maxIndex = Object.keys(arr).reduce(function (best, key) {
    return /^([1-9]+[0-9]*|0)$/.test(key) ? Math.max(best, Number(key)) : best;
  }, -1);
  for (var j = 0; j <= maxIndex; j++) {
    flatten$2(arr[j], result, skipTypedArray);
  }
  return result;
}
22061
/**
 * The `tf.util` namespace object. It has a null prototype (`__proto__: null`)
 * and maps each public utility name to its implementation in this bundle.
 * Some implementations carry a bundler-added `$N` suffix, for example
 * `flatten: flatten$2` and `fetch: fetch$1`.
 */
var util = {
  __proto__: null,
  arraysEqual: arraysEqual,
  arraysEqualWithNull: arraysEqualWithNull,
  assert: assert$1,
  assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
  assertNonNull: assertNonNull,
  assertShapesMatch: assertShapesMatch,
  bytesFromStringArray: bytesFromStringArray,
  bytesPerElement: bytesPerElement,
  checkConversionForErrors: checkConversionForErrors,
  clamp: clamp,
  computeStrides: computeStrides,
  convertBackendValuesAndArrayBuffer: convertBackendValuesAndArrayBuffer,
  createScalarValue: createScalarValue,
  createShuffledIndices: createShuffledIndices,
  decodeString: decodeString,
  distSquared: distSquared,
  encodeString: encodeString,
  fetch: fetch$1,
  fingerPrint64: fingerPrint64,
  flatten: flatten$2,
  getArrayFromDType: getArrayFromDType,
  getTypedArrayFromDType: getTypedArrayFromDType,
  hasEncodingLoss: hasEncodingLoss,
  hexToLong: hexToLong,
  indexToLoc: indexToLoc,
  inferDtype: inferDtype,
  inferFromImplicitShape: inferFromImplicitShape,
  isBoolean: isBoolean,
  isFunction: isFunction,
  isInt: isInt,
  isNumber: isNumber,
  isPromise: isPromise,
  isScalarShape: isScalarShape,
  isString: isString,
  isTypedArray: isTypedArray,
  isValidDtype: isValidDtype,
  locToIndex: locToIndex,
  makeOnesTypedArray: makeOnesTypedArray,
  makeZerosNestedTypedArray: makeZerosNestedTypedArray,
  makeZerosTypedArray: makeZerosTypedArray,
  nearestDivisor: nearestDivisor,
  nearestLargerEven: nearestLargerEven,
  now: now,
  parseAxisParam: parseAxisParam,
  randUniform: randUniform,
  repeatedTry: repeatedTry,
  rightPad: rightPad,
  shuffle: shuffle,
  shuffleCombo: shuffleCombo,
  sizeFromShape: sizeFromShape,
  sizeToSquarishShape: sizeToSquarishShape,
  squeezeShape: squeezeShape,
  sum: sum$4,
  swap: swap,
  tanh: tanh$3,
  toNestedArray: toNestedArray,
  toTypedArray: toTypedArray
};
22122
var Profiler = /*#__PURE__*/function () {
  /**
   * Times kernel executions and hands the results to a logger.
   * @param backendTimer Object providing `timerAvailable()` and `time(fn)`.
   * @param logger Receiver of `logKernelProfile(...)`; a new `Logger` is
   *     created when null/undefined.
   */
  function Profiler(backendTimer, logger) {
    _classCallCheck(this, Profiler);
    this.backendTimer = backendTimer;
    this.logger = logger == null ? new Logger() : logger;
  }
  _createClass(Profiler, [{
    key: "profileKernel",
    /**
     * Runs kernel function `f` and returns a profile with its outputs and
     * promises for the measured time and any extra backend info.
     */
    value: function profileKernel(kernelName, inputs, f) {
      var outputs;
      var runKernel = function runKernel() {
        outputs = f();
      };
      var start = now();
      var timer;
      if (this.backendTimer.timerAvailable()) {
        timer = this.backendTimer.time(runKernel);
      } else {
        runKernel();
        // With no backend timer, block on every output so the wall-clock
        // measurement covers the actual computation.
        for (var i = 0; i < outputs.length; i++) {
          outputs[i].dataSync();
        }
        timer = Promise.resolve({
          kernelMs: now() - start
        });
      }
      if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) {
        var watchForErrors = function watchForErrors(output) {
          // Dangling promise here because we don't want to propagate up
          // asynchronicity.
          output.data().then(function (tensorVals) {
            checkComputationForErrors(tensorVals, output.dtype, kernelName);
          });
        };
        for (var j = 0; j < outputs.length; j++) {
          watchForErrors(outputs[j]);
        }
      }
      var timeMs = timer.then(function (timing) {
        return timing.kernelMs;
      });
      var extraInfo = timer.then(function (timing) {
        return timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : '';
      });
      return {
        kernelName: kernelName,
        outputs: outputs,
        inputs: inputs,
        timeMs: timeMs,
        extraInfo: extraInfo
      };
    }
  }, {
    key: "logKernelProfile",
    /**
     * For each output, waits for its data plus the profile's timing/extra
     * info, then forwards everything to `this.logger.logKernelProfile`.
     */
    value: function logKernelProfile(kernelProfile) {
      var self = this;
      var kernelName = kernelProfile.kernelName;
      var outputs = kernelProfile.outputs;
      var timeMs = kernelProfile.timeMs;
      var inputs = kernelProfile.inputs;
      var extraInfo = kernelProfile.extraInfo;
      outputs.forEach(function (result) {
        Promise.all([result.data(), timeMs, extraInfo]).then(function (resolved) {
          var vals = resolved[0];
          var ms = resolved[1];
          var info = resolved[2];
          self.logger.logKernelProfile(kernelName, result, vals, ms, inputs, info);
        });
      });
    }
  }]);
  return Profiler;
}();
/**
 * Scans kernel output values for NaN or +/-Infinity.
 *
 * Only 'float32' outputs are scanned; every other dtype returns false without
 * looking at `vals`. On the first non-finite value it emits a console warning
 * (it does not throw) and returns true.
 *
 * @param vals Array-like of output values.
 * @param dtype Dtype of the output tensor.
 * @param kernelName Kernel name, included in the warning.
 * @returns true if a non-finite value was found, false otherwise.
 */
function checkComputationForErrors(vals, dtype, kernelName) {
  if (dtype !== 'float32') {
    // Only floating point computations will generate NaN values
    return false;
  }
  for (var i = 0; i < vals.length; i++) {
    var num = vals[i];
    // Global isFinite (which coerces like isNaN) is false for NaN as well as
    // for +/-Infinity, so this single test covers both cases.
    if (!isFinite(num)) {
      // Warn instead of throwing so a bad value does not abort the caller;
      // the boolean return value lets callers and tests observe the finding.
      console.warn("Found ".concat(num, " in the result of '").concat(kernelName, "'"));
      return true;
    }
  }
  return false;
}
var Logger = /*#__PURE__*/function () {
  /** Default profile sink: writes one styled console line per kernel output. */
  function Logger() {
    _classCallCheck(this, Logger);
  }
  _createClass(Logger, [{
    key: "logKernelProfile",
    /**
     * Logs name, time, output rank/shape, size, input shapes and extra info
     * as a single `console.log` line with `%c` styling.
     * @param timeMs A number of milliseconds, or an object whose `error`
     *     field is printed instead.
     */
    value: function logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
      var time;
      if (typeof timeMs === 'number') {
        time = rightPad("".concat(timeMs, "ms"), 9);
      } else {
        time = timeMs['error'];
      }
      var paddedName = rightPad(name, 25);
      var rank = result.rank;
      var size = result.size;
      var shape = rightPad(result.shape.toString(), 14);
      var descriptions = [];
      for (var inputName in inputs) {
        var input = inputs[inputName];
        if (input == null) {
          continue;
        }
        // A non-tensor input (e.g. an HTMLImageElement) has no `shape`; the
        // output shape is reported in its place.
        var inputShape = input.shape || result.shape;
        var inputRank = inputShape.length;
        descriptions.push("".concat(inputName, ": ").concat(inputRank, "D ").concat(inputRank > 0 ? inputShape : '', " "));
      }
      var inputShapesDescription = descriptions.join('');
      var message = "%c".concat(paddedName, "\t%c").concat(time, "\t%c").concat(rank, "D ").concat(shape, "\t%c").concat(size, "\t%c").concat(inputShapesDescription, "\t%c").concat(extraInfo);
      console.log(message, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
    }
  }]);
  return Logger;
}();
22248
22249 /**
22250 * @license
22251 * Copyright 2017 Google LLC. All Rights Reserved.
22252 * Licensed under the Apache License, Version 2.0 (the "License");
22253 * you may not use this file except in compliance with the License.
22254 * You may obtain a copy of the License at
22255 *
22256 * http://www.apache.org/licenses/LICENSE-2.0
22257 *
22258 * Unless required by applicable law or agreed to in writing, software
22259 * distributed under the License is distributed on an "AS IS" BASIS,
22260 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22261 * See the License for the specific language governing permissions and
22262 * limitations under the License.
22263 * =============================================================================
22264 */
/**
 * Computes a list of TapeNodes that connect x to y, filtering everything else
 * out and preserving the order of the original tape elements.
 *
 * Returned nodes are shallow copies whose `inputs` keep only the inputs that
 * are (transitively) a function of some x. Absent optional inputs
 * (null/undefined entries in `node.inputs`) are treated as unconnected
 * instead of being dereferenced.
 *
 * @param tape The tape elements to filter.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 */
function getFilteredNodesXToY(tape, xs, y) {
  // Forward pass: mark every tensor (and the node producing it) that is
  // transitively a function of some x.
  var tensorsFromX = {};
  var nodesFromX = {};
  for (var i = 0; i < xs.length; i++) {
    tensorsFromX[xs[i].id] = true;
  }
  for (var f = 0; f < tape.length; f++) {
    var node = tape[f];
    var nodeInputs = node.inputs;
    for (var inputName in nodeInputs) {
      var input = nodeInputs[inputName];
      // One input derived from x is enough: mark the node and all its outputs,
      // then stop scanning this node's inputs.
      if (input != null && tensorsFromX[input.id]) {
        node.outputs.forEach(function (output) {
          tensorsFromX[output.id] = true;
        });
        nodesFromX[node.id] = true;
        break;
      }
    }
  }

  // Backward pass: find all nodes and tensors that lead to y.
  var tensorsLeadToY = {};
  tensorsLeadToY[y.id] = true;
  var nodesToY = {};
  for (var b = tape.length - 1; b >= 0; b--) {
    var backNode = tape[b];
    var leadsToY = backNode.outputs.some(function (output) {
      return tensorsLeadToY[output.id];
    });
    if (!leadsToY) {
      continue;
    }
    // If any output leads to y, every (present) input leads to y as well.
    var backInputs = backNode.inputs;
    for (var backName in backInputs) {
      var backInput = backInputs[backName];
      if (backInput == null) {
        continue;
      }
      tensorsLeadToY[backInput.id] = true;
      nodesToY[backNode.id] = true;
    }
  }

  // Keep the nodes on some x -> y path, pruning inputs that aren't a function of x.
  var filteredTape = [];
  for (var k = 0; k < tape.length; k++) {
    var candidate = tape[k];
    if (!(nodesFromX[candidate.id] && nodesToY[candidate.id])) {
      continue;
    }
    var prunedInputs = {};
    for (var key in candidate.inputs) {
      var candidateInput = candidate.inputs[key];
      if (candidateInput != null && tensorsFromX[candidateInput.id]) {
        prunedInputs[key] = candidateInput;
      }
    }
    // Copy the node and overwrite inputs with the pruned version.
    var prunedNode = Object.assign({}, candidate);
    prunedNode.inputs = prunedInputs;
    prunedNode.outputs = candidate.outputs;
    filteredTape.push(prunedNode);
  }
  return filteredTape;
}
22342 /**
22343 * Backpropagate gradients through the filtered TapeNodes.
22344 *
22345 * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
22346 * is mutated by this method.
22347 * @param filteredTape The filtered TapeNodes to backprop through.
22348 */
22349 function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
22350 var _loop = function _loop() {
22351 var node = filteredTape[i];
22352 var dys = [];
22353 node.outputs.forEach(function (o) {
22354 var gradTensor = tensorAccumulatedGradientMap[o.id];
22355 if (gradTensor != null) {
22356 dys.push(gradTensor);
22357 } else {
22358 // This particular output is not in the back-propagation subgraph, so it
22359 // does not affect the final output, thus we put null for its dy.
22360 dys.push(null);
22361 }
22362 });
22363 if (node.gradient == null) {
22364 throw new Error("Cannot compute gradient: gradient function not found " + "for ".concat(node.kernelName, "."));
22365 }
22366 // Backprop dy through this node and accumulate gradients over the inputs.
22367 var inputGradients = node.gradient(dys);
22368 var _loop2 = function _loop2(inputName) {
22369 if (!(inputName in inputGradients)) {
22370 throw new Error("Cannot backprop through input ".concat(inputName, ". ") + "Available gradients found: ".concat(Object.keys(inputGradients), "."));
22371 }
22372 // Call the gradient function.
22373 var dx = tidy(function () {
22374 return inputGradients[inputName]();
22375 });
22376 if (dx.dtype !== 'float32') {
22377 throw new Error("Error in gradient for op ".concat(node.kernelName, ". The gradient of input ") + "".concat(inputName, " must have 'float32' dtype, but has '").concat(dx.dtype, "'"));
22378 }
22379 var x = node.inputs[inputName];
22380 if (!arraysEqual(dx.shape, x.shape)) {
22381 throw new Error("Error in gradient for op ".concat(node.kernelName, ". The gradient of input ") + "'".concat(inputName, "' has shape '").concat(dx.shape, "', which does not match ") + "the shape of the input '".concat(x.shape, "'"));
22382 }
22383 if (tensorAccumulatedGradientMap[x.id] == null) {
22384 tensorAccumulatedGradientMap[x.id] = dx;
22385 } else {
22386 var curGradient = tensorAccumulatedGradientMap[x.id];
22387 tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
22388 curGradient.dispose();
22389 }
22390 };
22391 for (var inputName in node.inputs) {
22392 _loop2(inputName);
22393 }
22394 };
22395 // Walk the tape backward and keep a map of Tensor to its gradient.
22396 for (var i = filteredTape.length - 1; i >= 0; i--) {
22397 _loop();
22398 }
22399 }
22400
22401 // Maximum number of values before we decide to show ellipsis.
22402 var FORMAT_LIMIT_NUM_VALS = 20;
22403 // Number of first and last values to show when displaying a, b,...,y, z.
22404 var FORMAT_NUM_FIRST_LAST_VALS = 3;
22405 // Number of significant digits to show.
22406 var FORMAT_NUM_SIG_DIGITS = 7;
22407 function tensorToString(vals, shape, dtype, verbose) {
22408 var strides = computeStrides(shape);
22409 var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
22410 var rank = shape.length;
22411 var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
22412 var lines = ['Tensor'];
22413 if (verbose) {
22414 lines.push(" dtype: ".concat(dtype));
22415 lines.push(" rank: ".concat(rank));
22416 lines.push(" shape: [".concat(shape, "]"));
22417 lines.push(" values:");
22418 }
22419 lines.push(valsLines.map(function (l) {
22420 return ' ' + l;
22421 }).join('\n'));
22422 return lines.join('\n');
22423 }
22424 function computeMaxSizePerColumn(vals, shape, dtype, strides) {
22425 var n = sizeFromShape(shape);
22426 var numCols = strides[strides.length - 1];
22427 var padPerCol = new Array(numCols).fill(0);
22428 var rank = shape.length;
22429 var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
22430 if (rank > 1) {
22431 for (var row = 0; row < n / numCols; row++) {
22432 var offset = row * numCols;
22433 for (var j = 0; j < numCols; j++) {
22434 padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
22435 }
22436 }
22437 }
22438 return padPerCol;
22439 }
22440 function valToString(val, pad, dtype) {
22441 var valStr;
22442 if (Array.isArray(val)) {
22443 valStr = "".concat(parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)), " + ") + "".concat(parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)), "j");
22444 } else if (isString(val)) {
22445 valStr = "'".concat(val, "'");
22446 } else if (dtype === 'bool') {
22447 valStr = boolNumToString(val);
22448 } else {
22449 valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
22450 }
22451 return rightPad(valStr, pad);
22452 }
22453 function boolNumToString(v) {
22454 return v === 0 ? 'false' : 'true';
22455 }
22456 function subTensorToString(vals, shape, dtype, strides, padPerCol) {
22457 var isLast = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : true;
22458 var storagePerElement = dtype === 'complex64' ? 2 : 1;
22459 var size = shape[0];
22460 var rank = shape.length;
22461 if (rank === 0) {
22462 if (dtype === 'complex64') {
22463 var complexTuple = createComplexTuples(vals);
22464 return [valToString(complexTuple[0], 0, dtype)];
22465 }
22466 if (dtype === 'bool') {
22467 return [boolNumToString(vals[0])];
22468 }
22469 return [vals[0].toString()];
22470 }
22471 if (rank === 1) {
22472 if (size > FORMAT_LIMIT_NUM_VALS) {
22473 var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
22474 var firstVals = Array.from(vals.slice(0, firstValsSize));
22475 var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
22476 if (dtype === 'complex64') {
22477 firstVals = createComplexTuples(firstVals);
22478 lastVals = createComplexTuples(lastVals);
22479 }
22480 return ['[' + firstVals.map(function (x, i) {
22481 return valToString(x, padPerCol[i], dtype);
22482 }).join(', ') + ', ..., ' + lastVals.map(function (x, i) {
22483 return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype);
22484 }).join(', ') + ']'];
22485 }
22486 var displayVals = dtype === 'complex64' ? createComplexTuples(vals) : Array.from(vals);
22487 return ['[' + displayVals.map(function (x, i) {
22488 return valToString(x, padPerCol[i], dtype);
22489 }).join(', ') + ']'];
22490 }
22491 // The array is rank 2 or more.
22492 var subshape = shape.slice(1);
22493 var substrides = strides.slice(1);
22494 var stride = strides[0] * storagePerElement;
22495 var lines = [];
22496 if (size > FORMAT_LIMIT_NUM_VALS) {
22497 for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
22498 var start = i * stride;
22499 var end = start + stride;
22500 lines.push.apply(lines, _toConsumableArray(subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)));
22501 }
22502
22503 lines.push('...');
22504 for (var _i = size - FORMAT_NUM_FIRST_LAST_VALS; _i < size; _i++) {
22505 var _start = _i * stride;
22506 var _end = _start + stride;
22507 lines.push.apply(lines, _toConsumableArray(subTensorToString(vals.slice(_start, _end), subshape, dtype, substrides, padPerCol, _i === size - 1 /* isLast */)));
22508 }
22509 } else {
22510 for (var _i2 = 0; _i2 < size; _i2++) {
22511 var _start2 = _i2 * stride;
22512 var _end2 = _start2 + stride;
22513 lines.push.apply(lines, _toConsumableArray(subTensorToString(vals.slice(_start2, _end2), subshape, dtype, substrides, padPerCol, _i2 === size - 1 /* isLast */)));
22514 }
22515 }
22516
22517 var sep = rank === 2 ? ',' : '';
22518 lines[0] = '[' + (size > 0 ? lines[0] + sep : '');
22519 for (var _i3 = 1; _i3 < lines.length - 1; _i3++) {
22520 lines[_i3] = ' ' + lines[_i3] + sep;
22521 }
22522 var newLineSep = ',\n';
22523 for (var _i4 = 2; _i4 < rank; _i4++) {
22524 newLineSep += '\n';
22525 }
22526 lines[lines.length - 1] = ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
22527 return lines;
22528 }
22529 function createComplexTuples(vals) {
22530 var complexTuples = [];
22531 for (var i = 0; i < vals.length; i += 2) {
22532 complexTuples.push([vals[i], vals[i + 1]]);
22533 }
22534 return complexTuples;
22535 }
22536
22537 /**
22538 * A mutable object, similar to `tf.Tensor`, that allows users to set values
22539 * at locations before converting to an immutable `tf.Tensor`.
22540 *
22541 * See `tf.buffer` for creating a tensor buffer.
22542 *
22543 * @doc {heading: 'Tensors', subheading: 'Classes'}
22544 */
22545 var TensorBuffer = /*#__PURE__*/function () {
22546 function TensorBuffer(shape, dtype, values) {
22547 var _this = this;
22548 _classCallCheck(this, TensorBuffer);
22549 this.dtype = dtype;
22550 this.shape = shape.slice();
22551 this.size = sizeFromShape(shape);
22552 if (values != null) {
22553 var n = values.length;
22554 assert$1(n === this.size, function () {
22555 return "Length of values '".concat(n, "' does not match the size ") + "inferred by the shape '".concat(_this.size, "'.");
22556 });
22557 }
22558 if (dtype === 'complex64') {
22559 throw new Error("complex64 dtype TensorBuffers are not supported. Please create " + "a TensorBuffer for the real and imaginary parts separately and " + "call tf.complex(real, imag).");
22560 }
22561 this.values = values || getArrayFromDType(dtype, this.size);
22562 this.strides = computeStrides(shape);
22563 }
22564 /**
22565 * Sets a value in the buffer at a given location.
22566 *
22567 * @param value The value to set.
22568 * @param locs The location indices.
22569 *
22570 * @doc {heading: 'Tensors', subheading: 'Creation'}
22571 */
22572 _createClass(TensorBuffer, [{
22573 key: "set",
22574 value: function set(value) {
22575 var _this2 = this;
22576 for (var _len = arguments.length, locs = new Array(_len > 1 ? _len - 1 : 0), _key = 1; _key < _len; _key++) {
22577 locs[_key - 1] = arguments[_key];
22578 }
22579 if (locs.length === 0) {
22580 locs = [0];
22581 }
22582 assert$1(locs.length === this.rank, function () {
22583 return "The number of provided coordinates (".concat(locs.length, ") must ") + "match the rank (".concat(_this2.rank, ")");
22584 });
22585 var index = this.locToIndex(locs);
22586 this.values[index] = value;
22587 }
22588 /**
22589 * Returns the value in the buffer at the provided location.
22590 *
22591 * @param locs The location indices.
22592 *
22593 * @doc {heading: 'Tensors', subheading: 'Creation'}
22594 */
22595 }, {
22596 key: "get",
22597 value: function get() {
22598 for (var _len2 = arguments.length, locs = new Array(_len2), _key2 = 0; _key2 < _len2; _key2++) {
22599 locs[_key2] = arguments[_key2];
22600 }
22601 if (locs.length === 0) {
22602 locs = [0];
22603 }
22604 var i = 0;
22605 for (var _i = 0, _locs = locs; _i < _locs.length; _i++) {
22606 var loc = _locs[_i];
22607 if (loc < 0 || loc >= this.shape[i]) {
22608 var msg = "Requested out of range element at ".concat(locs, ". ") + " Buffer shape=".concat(this.shape);
22609 throw new Error(msg);
22610 }
22611 i++;
22612 }
22613 var index = locs[locs.length - 1];
22614 for (var _i2 = 0; _i2 < locs.length - 1; ++_i2) {
22615 index += this.strides[_i2] * locs[_i2];
22616 }
22617 return this.values[index];
22618 }
22619 }, {
22620 key: "locToIndex",
22621 value: function locToIndex(locs) {
22622 if (this.rank === 0) {
22623 return 0;
22624 } else if (this.rank === 1) {
22625 return locs[0];
22626 }
22627 var index = locs[locs.length - 1];
22628 for (var i = 0; i < locs.length - 1; ++i) {
22629 index += this.strides[i] * locs[i];
22630 }
22631 return index;
22632 }
22633 }, {
22634 key: "indexToLoc",
22635 value: function indexToLoc(index) {
22636 if (this.rank === 0) {
22637 return [];
22638 } else if (this.rank === 1) {
22639 return [index];
22640 }
22641 var locs = new Array(this.shape.length);
22642 for (var i = 0; i < locs.length - 1; ++i) {
22643 locs[i] = Math.floor(index / this.strides[i]);
22644 index -= locs[i] * this.strides[i];
22645 }
22646 locs[locs.length - 1] = index;
22647 return locs;
22648 }
22649 }, {
22650 key: "rank",
22651 get: function get() {
22652 return this.shape.length;
22653 }
22654 /**
22655 * Creates an immutable `tf.Tensor` object from the buffer.
22656 *
22657 * @doc {heading: 'Tensors', subheading: 'Creation'}
22658 */
22659 }, {
22660 key: "toTensor",
22661 value: function toTensor() {
22662 return trackerFn().makeTensor(this.values, this.shape, this.dtype);
22663 }
22664 }]);
22665 return TensorBuffer;
22666 }();
22667 // For tracking tensor creation and disposal.
22668 var trackerFn = null;
22669 // Used by chaining methods to call into ops.
22670 var opHandler$1 = null;
22671 // Used to warn about deprecated methods.
22672 var deprecationWarningFn = null;
22673 // This here so that we can use this method on dev branches and keep the
22674 // functionality at master.
22675 // tslint:disable-next-line:no-unused-expression
22676 [deprecationWarningFn];
22677 /**
22678 * An external consumer can register itself as the tensor tracker. This way
22679 * the Tensor class can notify the tracker for every tensor created and
22680 * disposed.
22681 */
22682 function setTensorTracker(fn) {
22683 trackerFn = fn;
22684 }
22685 /**
22686 * An external consumer can register itself as the op handler. This way the
22687 * Tensor class can have chaining methods that call into ops via the op
22688 * handler.
22689 */
22690 function setOpHandler(handler) {
22691 opHandler$1 = handler;
22692 }
22693 /**
22694 * Sets the deprecation warning function to be used by this file. This way the
22695 * Tensor class can be a leaf but still use the environment.
22696 */
22697 function setDeprecationWarningFn(fn) {
22698 deprecationWarningFn = fn;
22699 }
22700 /**
22701 * A `tf.Tensor` object represents an immutable, multidimensional array of
22702 * numbers that has a shape and a data type.
22703 *
22704 * For performance reasons, functions that create tensors do not necessarily
22705 * perform a copy of the data passed to them (e.g. if the data is passed as a
22706 * `Float32Array`), and changes to the data will change the tensor. This is not
22707 * a feature and is not supported. To avoid this behavior, use the tensor before
22708 * changing the input data or create a copy with `copy = tf.add(yourTensor, 0)`.
22709 *
22710 * See `tf.tensor` for details on how to create a `tf.Tensor`.
22711 *
22712 * @doc {heading: 'Tensors', subheading: 'Classes'}
22713 */
22714 var Tensor = /*#__PURE__*/function () {
22715 function Tensor(shape, dtype, dataId, id) {
22716 _classCallCheck(this, Tensor);
22717 /** Whether this tensor has been globally kept. */
22718 this.kept = false;
22719 this.isDisposedInternal = false;
22720 this.shape = shape.slice();
22721 this.dtype = dtype || 'float32';
22722 this.size = sizeFromShape(shape);
22723 this.strides = computeStrides(shape);
22724 this.dataId = dataId;
22725 this.id = id;
22726 this.rankType = this.rank < 5 ? this.rank.toString() : 'higher';
22727 }
22728 _createClass(Tensor, [{
22729 key: "rank",
22730 get: function get() {
22731 return this.shape.length;
22732 }
22733 /**
22734 * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
22735 *
22736 * @doc {heading: 'Tensors', subheading: 'Classes'}
22737 */
22738 }, {
22739 key: "buffer",
22740 value: function () {
22741 var _buffer = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
22742 var vals;
22743 return _regeneratorRuntime().wrap(function _callee$(_context) {
22744 while (1) switch (_context.prev = _context.next) {
22745 case 0:
22746 _context.next = 2;
22747 return this.data();
22748 case 2:
22749 vals = _context.sent;
22750 return _context.abrupt("return", opHandler$1.buffer(this.shape, this.dtype, vals));
22751 case 4:
22752 case "end":
22753 return _context.stop();
22754 }
22755 }, _callee, this);
22756 }));
22757 function buffer() {
22758 return _buffer.apply(this, arguments);
22759 }
22760 return buffer;
22761 }()
22762 /**
22763 * Returns a `tf.TensorBuffer` that holds the underlying data.
22764 * @doc {heading: 'Tensors', subheading: 'Classes'}
22765 */
22766 }, {
22767 key: "bufferSync",
22768 value: function bufferSync() {
22769 return opHandler$1.buffer(this.shape, this.dtype, this.dataSync());
22770 }
22771 /**
22772 * Returns the tensor data as a nested array. The transfer of data is done
22773 * asynchronously.
22774 *
22775 * @doc {heading: 'Tensors', subheading: 'Classes'}
22776 */
22777 }, {
22778 key: "array",
22779 value: function () {
22780 var _array = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
22781 var vals;
22782 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
22783 while (1) switch (_context2.prev = _context2.next) {
22784 case 0:
22785 _context2.next = 2;
22786 return this.data();
22787 case 2:
22788 vals = _context2.sent;
22789 return _context2.abrupt("return", toNestedArray(this.shape, vals, this.dtype === 'complex64'));
22790 case 4:
22791 case "end":
22792 return _context2.stop();
22793 }
22794 }, _callee2, this);
22795 }));
22796 function array() {
22797 return _array.apply(this, arguments);
22798 }
22799 return array;
22800 }()
22801 /**
22802 * Returns the tensor data as a nested array. The transfer of data is done
22803 * synchronously.
22804 *
22805 * @doc {heading: 'Tensors', subheading: 'Classes'}
22806 */
22807 }, {
22808 key: "arraySync",
22809 value: function arraySync() {
22810 return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
22811 }
22812 /**
22813 * Asynchronously downloads the values from the `tf.Tensor`. Returns a
22814 * promise of `TypedArray` that resolves when the computation has finished.
22815 *
22816 * @doc {heading: 'Tensors', subheading: 'Classes'}
22817 */
22818 }, {
22819 key: "data",
22820 value: function () {
22821 var _data = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
22822 var data, bytes;
22823 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
22824 while (1) switch (_context3.prev = _context3.next) {
22825 case 0:
22826 this.throwIfDisposed();
22827 data = trackerFn().read(this.dataId);
22828 if (!(this.dtype === 'string')) {
22829 _context3.next = 13;
22830 break;
22831 }
22832 _context3.next = 5;
22833 return data;
22834 case 5:
22835 bytes = _context3.sent;
22836 _context3.prev = 6;
22837 return _context3.abrupt("return", bytes.map(function (b) {
22838 return decodeString(b);
22839 }));
22840 case 10:
22841 _context3.prev = 10;
22842 _context3.t0 = _context3["catch"](6);
22843 throw new Error('Failed to decode the string bytes into utf-8. ' + 'To get the original bytes, call tensor.bytes().');
22844 case 13:
22845 return _context3.abrupt("return", data);
22846 case 14:
22847 case "end":
22848 return _context3.stop();
22849 }
22850 }, _callee3, this, [[6, 10]]);
22851 }));
22852 function data() {
22853 return _data.apply(this, arguments);
22854 }
22855 return data;
22856 }()
22857 /**
22858 * Copy the tensor's data to a new GPU resource. Comparing to the `dataSync()`
22859 * and `data()`, this method prevents data from being downloaded to CPU.
22860 *
22861 * For WebGL backend, the data will be stored on a densely packed texture.
22862 * This means that the texture will use the RGBA channels to store value.
22863 *
22864 * For WebGPU backend, the data will be stored on a buffer. There is no
22865 * parameter, so can not use a user-defined size to create the buffer.
22866 *
22867 * @param options:
22868 * For WebGL,
22869 * - customTexShape: Optional. If set, will use the user defined
22870 * texture shape to create the texture.
22871 *
22872 * @returns For WebGL backend, a GPUData contains the new texture and
22873 * its information.
22874 * {
22875 * tensorRef: The tensor that is associated with this texture,
22876 * texture: WebGLTexture,
22877 * texShape: [number, number] // [height, width]
22878 * }
22879 *
22880 * For WebGPU backend, a GPUData contains the new buffer.
22881 * {
22882 * tensorRef: The tensor that is associated with this buffer,
22883 * buffer: GPUBuffer,
22884 * }
22885 *
22886 * Remember to dispose the GPUData after it is used by
22887 * `res.tensorRef.dispose()`.
22888 *
22889 * @doc {heading: 'Tensors', subheading: 'Classes'}
22890 */
22891 }, {
22892 key: "dataToGPU",
22893 value: function dataToGPU(options) {
22894 this.throwIfDisposed();
22895 return trackerFn().readToGPU(this.dataId, options);
22896 }
22897 /**
22898 * Synchronously downloads the values from the `tf.Tensor`. This blocks the
22899 * UI thread until the values are ready, which can cause performance issues.
22900 *
22901 * @doc {heading: 'Tensors', subheading: 'Classes'}
22902 */
22903 }, {
22904 key: "dataSync",
22905 value: function dataSync() {
22906 this.throwIfDisposed();
22907 var data = trackerFn().readSync(this.dataId);
22908 if (this.dtype === 'string') {
22909 try {
22910 return data.map(function (b) {
22911 return decodeString(b);
22912 });
22913 } catch (_a) {
22914 throw new Error('Failed to decode the string bytes into utf-8. ' + 'To get the original bytes, call tensor.bytes().');
22915 }
22916 }
22917 return data;
22918 }
22919 /** Returns the underlying bytes of the tensor's data. */
22920 }, {
22921 key: "bytes",
22922 value: function () {
22923 var _bytes = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4() {
22924 var data;
22925 return _regeneratorRuntime().wrap(function _callee4$(_context4) {
22926 while (1) switch (_context4.prev = _context4.next) {
22927 case 0:
22928 this.throwIfDisposed();
22929 _context4.next = 3;
22930 return trackerFn().read(this.dataId);
22931 case 3:
22932 data = _context4.sent;
22933 if (!(this.dtype === 'string')) {
22934 _context4.next = 8;
22935 break;
22936 }
22937 return _context4.abrupt("return", data);
22938 case 8:
22939 return _context4.abrupt("return", new Uint8Array(data.buffer));
22940 case 9:
22941 case "end":
22942 return _context4.stop();
22943 }
22944 }, _callee4, this);
22945 }));
22946 function bytes() {
22947 return _bytes.apply(this, arguments);
22948 }
22949 return bytes;
22950 }()
22951 /**
22952 * Disposes `tf.Tensor` from memory.
22953 *
22954 * @doc {heading: 'Tensors', subheading: 'Classes'}
22955 */
22956 }, {
22957 key: "dispose",
22958 value: function dispose() {
22959 if (this.isDisposed) {
22960 return;
22961 }
22962 if (this.kerasMask) {
22963 this.kerasMask.dispose();
22964 }
22965 trackerFn().disposeTensor(this);
22966 this.isDisposedInternal = true;
22967 }
22968 }, {
22969 key: "isDisposed",
22970 get: function get() {
22971 return this.isDisposedInternal;
22972 }
22973 }, {
22974 key: "throwIfDisposed",
22975 value: function throwIfDisposed() {
22976 if (this.isDisposed) {
22977 throw new Error("Tensor is disposed.");
22978 }
22979 }
22980 /**
22981 * Prints the `tf.Tensor`. See `tf.print` for details.
22982 *
22983 * @param verbose Whether to print verbose information about the tensor,
22984 * including dtype and size.
22985 *
22986 * @doc {heading: 'Tensors', subheading: 'Classes'}
22987 */
22988 }, {
22989 key: "print",
22990 value: function print() {
22991 var verbose = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : false;
22992 return opHandler$1.print(this, verbose);
22993 }
22994 /**
22995 * Returns a copy of the tensor. See `tf.clone` for details.
22996 * @doc {heading: 'Tensors', subheading: 'Classes'}
22997 */
22998 }, {
22999 key: "clone",
23000 value: function clone() {
23001 this.throwIfDisposed();
23002 return opHandler$1.clone(this);
23003 }
23004 /**
23005 * Returns a human-readable description of the tensor. Useful for logging.
23006 *
23007 * @doc {heading: 'Tensors', subheading: 'Classes'}
23008 */
23009 }, {
23010 key: "toString",
23011 value: function toString() {
23012 var verbose = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : false;
23013 var vals = this.dataSync();
23014 return tensorToString(vals, this.shape, this.dtype, verbose);
23015 }
23016 }, {
23017 key: "cast",
23018 value: function cast(dtype) {
23019 this.throwIfDisposed();
23020 return opHandler$1.cast(this, dtype);
23021 }
23022 }, {
23023 key: "variable",
23024 value: function variable() {
23025 var trainable = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : true;
23026 var name = arguments.length > 1 ? arguments[1] : undefined;
23027 var dtype = arguments.length > 2 ? arguments[2] : undefined;
23028 this.throwIfDisposed();
23029 return trackerFn().makeVariable(this, trainable, name, dtype);
23030 }
23031 }]);
23032 return Tensor;
23033 }();
23034 Object.defineProperty(Tensor, Symbol.hasInstance, {
23035 value: function value(instance) {
23036 // Implementation note: we should use properties of the object that will be
23037 // defined before the constructor body has finished executing (methods).
23038 // This is because when this code is transpiled by babel, babel will call
23039 // classCallCheck before the constructor body is run.
23040 // See https://github.com/tensorflow/tfjs/issues/3384 for backstory.
23041 return !!instance && instance.data != null && instance.dataSync != null && instance.throwIfDisposed != null;
23042 }
23043 });
23044 function getGlobalTensorClass() {
23045 // Use getGlobal so that we can augment the Tensor class across package
23046 // boundaries because the node resolution alg may result in different modules
23047 // being returned for this file depending on the path they are loaded from.
23048 return getGlobal('Tensor', function () {
23049 return Tensor;
23050 });
23051 }
23052 // Global side effect. Cache global reference to Tensor class
23053 getGlobalTensorClass();
23054 /**
23055 * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
23056 *
23057 * @doc {heading: 'Tensors', subheading: 'Classes'}
23058 */
23059 var Variable = /*#__PURE__*/function (_Tensor) {
23060 _inherits(Variable, _Tensor);
23061 var _super = _createSuper(Variable);
23062 function Variable(initialValue, trainable, name, tensorId) {
23063 var _this3;
23064 _classCallCheck(this, Variable);
23065 _this3 = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);
23066 _this3.trainable = trainable;
23067 _this3.name = name;
23068 return _this3;
23069 }
23070 /**
23071 * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
23072 * the same shape and dtype as the old `tf.Tensor`.
23073 *
23074 * @param newValue New tensor to be assigned to this variable.
23075 *
23076 * @doc {heading: 'Tensors', subheading: 'Classes'}
23077 */
23078 _createClass(Variable, [{
23079 key: "assign",
23080 value: function assign(newValue) {
23081 if (newValue.dtype !== this.dtype) {
23082 throw new Error("dtype of the new value (".concat(newValue.dtype, ") and ") + "previous value (".concat(this.dtype, ") must match"));
23083 }
23084 if (!arraysEqual(newValue.shape, this.shape)) {
23085 throw new Error("shape of the new value (".concat(newValue.shape, ") and ") + "previous value (".concat(this.shape, ") must match"));
23086 }
23087 trackerFn().disposeTensor(this);
23088 this.dataId = newValue.dataId;
23089 trackerFn().incRef(this, null /* backend */);
23090 }
23091 }, {
23092 key: "dispose",
23093 value: function dispose() {
23094 trackerFn().disposeVariable(this);
23095 this.isDisposedInternal = true;
23096 }
23097 }]);
23098 return Variable;
23099 }(Tensor);
23100 Object.defineProperty(Variable, Symbol.hasInstance, {
23101 value: function value(instance) {
23102 return instance instanceof Tensor && instance.assign != null && instance.assign instanceof Function;
23103 }
23104 });
23105
23106 /**
23107 * @license
23108 * Copyright 2017 Google LLC. All Rights Reserved.
23109 * Licensed under the Apache License, Version 2.0 (the "License");
23110 * you may not use this file except in compliance with the License.
23111 * You may obtain a copy of the License at
23112 *
23113 * http://www.apache.org/licenses/LICENSE-2.0
23114 *
23115 * Unless required by applicable law or agreed to in writing, software
23116 * distributed under the License is distributed on an "AS IS" BASIS,
23117 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23118 * See the License for the specific language governing permissions and
23119 * limitations under the License.
23120 * =============================================================================
23121 */
23122 exports.Rank = void 0;
23123 (function (Rank) {
23124 Rank["R0"] = "R0";
23125 Rank["R1"] = "R1";
23126 Rank["R2"] = "R2";
23127 Rank["R3"] = "R3";
23128 Rank["R4"] = "R4";
23129 Rank["R5"] = "R5";
23130 Rank["R6"] = "R6";
23131 })(exports.Rank || (exports.Rank = {}));
// Looks for upcasting types. Used, for example, in operations with mixed dtype
// inputs.
//
// Each `Upcast<X>AndMap` below maps a dtype Y to the dtype produced by
// combining X with Y. Together they encode the ordering
// bool < int32 < float32 < complex64.
var UpcastInt32AndMap;
(function (UpcastInt32AndMap) {
  UpcastInt32AndMap["float32"] = "float32";
  UpcastInt32AndMap["int32"] = "int32";
  UpcastInt32AndMap["bool"] = "int32";
  UpcastInt32AndMap["complex64"] = "complex64";
})(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
var UpcastBoolAndMap;
(function (UpcastBoolAndMap) {
  UpcastBoolAndMap["float32"] = "float32";
  UpcastBoolAndMap["int32"] = "int32";
  UpcastBoolAndMap["bool"] = "bool";
  UpcastBoolAndMap["complex64"] = "complex64";
})(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
var UpcastFloat32AndMap;
(function (UpcastFloat32AndMap) {
  UpcastFloat32AndMap["float32"] = "float32";
  UpcastFloat32AndMap["int32"] = "float32";
  UpcastFloat32AndMap["bool"] = "float32";
  UpcastFloat32AndMap["complex64"] = "complex64";
})(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
var UpcastComplex64AndMap;
(function (UpcastComplex64AndMap) {
  UpcastComplex64AndMap["float32"] = "complex64";
  UpcastComplex64AndMap["int32"] = "complex64";
  UpcastComplex64AndMap["bool"] = "complex64";
  UpcastComplex64AndMap["complex64"] = "complex64";
})(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
var upcastTypeMap = {
  'float32': UpcastFloat32AndMap,
  'int32': UpcastInt32AndMap,
  'bool': UpcastBoolAndMap,
  'complex64': UpcastComplex64AndMap
};
/**
 * Returns the dtype that results from combining `typeA` with `typeB`.
 *
 * 'string' only combines with 'string'. The numeric dtypes ('bool', 'int32',
 * 'float32', 'complex64') combine via `upcastTypeMap`.
 *
 * @param typeA first dtype.
 * @param typeB second dtype.
 * @returns the common (upcast) dtype.
 * @throws Error "Can not upcast <typeA> with <typeB>" when the dtypes cannot
 *     be combined. This includes either dtype not being one of the five
 *     listed above.
 */
function upcastType(typeA, typeB) {
  if (typeA === 'string' || typeB === 'string') {
    if (typeA === 'string' && typeB === 'string') {
      return 'string';
    }
    throw new Error("Can not upcast ".concat(typeA, " with ").concat(typeB));
  }
  var hasOwn = Object.prototype.hasOwnProperty;
  // Own-property checks keep inherited keys (e.g. 'constructor') out of the
  // lookup. They also make an unknown dtype fail loudly: previously an
  // unknown typeA raised an opaque TypeError and an unknown typeB silently
  // yielded undefined.
  var row = hasOwn.call(upcastTypeMap, typeA) ? upcastTypeMap[typeA] : null;
  if (row == null || !hasOwn.call(row, typeB)) {
    throw new Error("Can not upcast ".concat(typeA, " with ").concat(typeB));
  }
  return row[typeB];
}
/**
 * Returns the output type after summation: 'bool' is promoted to 'int32',
 * and wider numeric dtypes are kept.
 */
function sumOutType(type) {
  return upcastType(type, 'int32');
}
/**
 * Returns true if `values` is an object whose `texture` property is a
 * `WebGLTexture`.
 *
 * The `typeof` guard (mirroring isWebGPUData) makes this return false instead
 * of throwing a ReferenceError in environments without `WebGLTexture`, such
 * as Node.js.
 */
function isWebGLData(values) {
  return typeof WebGLTexture !== 'undefined' && values != null && _typeof(values) === 'object' && 'texture' in values && values.texture instanceof WebGLTexture;
}
/**
 * Returns true if `values` is an object whose `buffer` property is a
 * `GPUBuffer`. Returns false where `GPUBuffer` is not defined.
 */
function isWebGPUData(values) {
  return typeof GPUBuffer !== 'undefined' && values != null && _typeof(values) === 'object' && 'buffer' in values && values.buffer instanceof GPUBuffer;
}
23187
/**
 * Brings two tensors to a common dtype (chosen by `upcastType`) so a binary
 * op can combine them. When the dtypes already agree, the inputs are
 * returned unchanged.
 *
 * @returns a two-element array `[a', b']`.
 */
function makeTypesMatch(a, b) {
  if (a.dtype !== b.dtype) {
    var common = upcastType(a.dtype, b.dtype);
    return [a.cast(common), b.cast(common)];
  }
  return [a, b];
}
/**
 * Asserts (via assert$1) that tensors `a` and `b` share the same dtype. The
 * error message names both dtypes.
 */
function assertTypesMatch(a, b) {
  var sameDtype = a.dtype === b.dtype;
  assert$1(sameDtype, function () {
    var firstPart = "The dtypes of the first(".concat(a.dtype, ") and");
    var secondPart = " second(".concat(b.dtype, ") input must match");
    return firstPart + secondPart;
  });
}
/**
 * Returns true if any element of `tensorList` has the same `id` as `tensor`.
 * Only ids are compared; object identity does not matter.
 */
function isTensorInList(tensor, tensorList) {
  var hasSameId = function (candidate) {
    return candidate.id === tensor.id;
  };
  return tensorList.some(hasSameId);
}
/**
 * Extracts any `Tensor`s found within the provided object.
 *
 * @param result an object that may be a `Tensor` or may directly contain
 *     `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. Nested
 *     containers are searched recursively. In general it is safe to pass any
 *     object here, except that `Promise`s are not supported.
 * @returns An array of the `Tensor`s found, in depth-first key order.
 *     - If the argument is itself a `Tensor`, the list holds just that
 *       `Tensor`.
 *     - If the object is not a `Tensor` and contains no `Tensor`s, the list
 *       is empty.
 *     - A value reached through more than one key is visited only once.
 */
function getTensorsInContainer(result) {
  var found = [];
  walkTensorContainer(result, found, new Set());
  return found;
}
/**
 * Recursive helper for `getTensorsInContainer`. Appends every `Tensor`
 * reachable from `container` to `list`. Values reached through keys are
 * recorded in `seen` and walked only once, which also stops cycles.
 */
function walkTensorContainer(container, list, seen) {
  // Reject null/undefined before any type test or key enumeration.
  if (container == null) {
    return;
  }
  if (container instanceof Tensor) {
    list.push(container);
    return;
  }
  if (isIterable$1(container)) {
    // for...in enumerates the keys of plain objects and the indices of arrays.
    for (var key in container) {
      var child = container[key];
      if (seen.has(child)) {
        continue;
      }
      seen.add(child);
      walkTensorContainer(child, list, seen);
    }
  }
}
/** Returns true for arrays and for any other value whose `_typeof` is 'object'. */
function isIterable$1(obj) {
  if (Array.isArray(obj)) {
    return true;
  }
  return _typeof(obj) === 'object';
}
23248
// Prototype-less namespace object bundling the tensor utility helpers
// assertTypesMatch, getTensorsInContainer, isTensorInList and makeTypesMatch.
var tensor_util = Object.create(null);
tensor_util.assertTypesMatch = assertTypesMatch;
tensor_util.getTensorsInContainer = getTensorsInContainer;
tensor_util.isTensorInList = isTensorInList;
tensor_util.makeTypesMatch = makeTypesMatch;
23256
/**
 * Returns true when `kernelInvocation` names a registered kernel, i.e. it
 * carries a non-null `kernelName`. A custom-gradient invocation supplies a
 * `forwardFunc` instead and returns false.
 */
function isRegisteredKernelInvocation(kernelInvocation) {
  var name = kernelInvocation.kernelName;
  return name !== null && name !== undefined;
}
/**
 * Mutable bookkeeping owned by an `Engine`. It holds:
 *   - byte, tensor, string-tensor and data-buffer counters;
 *   - scope and gradient/kernel depth tracking;
 *   - per-dataId tensor info;
 *   - the active profiling record.
 */
var EngineState = /*#__PURE__*/function () {
  function EngineState() {
    _classCallCheck(this, EngineState);
    // Public since optimizers will use it.
    this.registeredVariables = {};
    this.nextTapeNodeId = 0;
    // Allocation counters.
    this.numBytes = 0;
    this.numTensors = 0;
    this.numStringTensors = 0;
    this.numDataBuffers = 0;
    // Depth of nested tf.grad() calls: 1 for first-order gradients, 2 for
    // second-order, and so on. Used to decide whether the tape should be
    // removed after a backprop.
    this.gradientDepth = 0;
    // Depth of nested kernel calls; the tape is turned off once this exceeds 1.
    this.kernelDepth = 0;
    this.scopeStack = [];
    // One data-move counter per kernel execution (kernels may call other
    // kernels recursively, hence a stack).
    this.numDataMovesStack = [];
    this.nextScopeId = 0;
    this.tensorInfo = new WeakMap();
    this.profiling = false;
    this.activeProfile = {
      newBytes: 0,
      newTensors: 0,
      peakBytes: 0,
      kernels: [],
      result: null,
      // Distinct kernel names, in order of first appearance.
      get kernelNames() {
        var allNames = this.kernels.map(function (entry) {
          return entry.name;
        });
        return Array.from(new Set(allNames));
      }
    };
  }
  _createClass(EngineState, [{
    key: "dispose",
    // Disposes every registered variable.
    value: function dispose() {
      for (var name in this.registeredVariables) {
        this.registeredVariables[name].dispose();
      }
    }
  }]);
  return EngineState;
}();
23309 var Engine = /*#__PURE__*/function () {
function Engine(ENV) {
  _classCallCheck(this, Engine);
  // Fields set up here:
  //   ENV: environment flags, queried via getBool elsewhere in this class.
  //   registry: initialized backend instances, keyed by backend name.
  //   registryFactory: {factory, priority} registrations, keyed by name.
  //   pendingBackendInitId: bumped for every async backend init so that
  //     superseded initializations can be recognized and ignored.
  Object.assign(this, {
    ENV: ENV,
    registry: {},
    registryFactory: {},
    pendingBackendInitId: 0
  });
  this.state = new EngineState();
}
23318 _createClass(Engine, [{
// Resolves once a backend is ready for use. This is Babel/regenerator output
// of an `async ready()` method. Each `case` label is a state of the generated
// state machine, and `_context.next` selects the state to resume in.
//   - If an async backend initialization is pending, waits for it and
//     resolves to undefined.
//   - If a backend instance is already active, resolves immediately.
//   - Otherwise tries each registered backend in priority order
//     (getSortedBackends), awaiting initializeBackend(name).success, and
//     activates the first one that succeeds via setBackend.
// Rejects with an Error if every backend fails to initialize.
23319 key: "ready",
23320 value: function () {
23321 var _ready = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
23322 var sortedBackends, i, backendName, success;
23323 return _regeneratorRuntime().wrap(function _callee$(_context) {
23324 while (1) switch (_context.prev = _context.next) {
23325 case 0:
23326 if (!(this.pendingBackendInit != null)) {
23327 _context.next = 2;
23328 break;
23329 }
// A pending async init exists: wait for it, discarding its boolean result.
23330 return _context.abrupt("return", this.pendingBackendInit.then(function () {}));
23331 case 2:
23332 if (!(this.backendInstance != null)) {
23333 _context.next = 4;
23334 break;
23335 }
23336 return _context.abrupt("return");
23337 case 4:
// No active backend: try candidates from highest to lowest priority.
23338 sortedBackends = this.getSortedBackends();
23339 i = 0;
23340 case 6:
23341 if (!(i < sortedBackends.length)) {
23342 _context.next = 18;
23343 break;
23344 }
23345 backendName = sortedBackends[i];
23346 _context.next = 10;
// Await the init result: a boolean, or a Promise<boolean> for async factories.
23347 return this.initializeBackend(backendName).success;
23348 case 10:
23349 success = _context.sent;
23350 if (!success) {
23351 _context.next = 15;
23352 break;
23353 }
23354 _context.next = 14;
// The first candidate that initializes successfully becomes the active backend.
23355 return this.setBackend(backendName);
23356 case 14:
23357 return _context.abrupt("return");
23358 case 15:
23359 i++;
23360 _context.next = 6;
23361 break;
23362 case 18:
// Every candidate failed to initialize.
23363 throw new Error("Could not initialize any backends, all backend initializations " + "failed.");
23364 case 19:
23365 case "end":
23366 return _context.stop();
23367 }
23368 }, _callee, this);
23369 }));
23370 function ready() {
23371 return _ready.apply(this, arguments);
23372 }
23373 return ready;
23374 }()
23375 }, {
23376 key: "backend",
23377 get: function get() {
23378 if (this.pendingBackendInit != null) {
23379 throw new Error("Backend '".concat(this.backendName, "' has not yet been initialized. Make ") + "sure to await tf.ready() or await tf.setBackend() before calling " + "other methods");
23380 }
23381 if (this.backendInstance == null) {
23382 var _this$initializeBacke = this.initializeBackendsAndReturnBest(),
23383 name = _this$initializeBacke.name,
23384 asyncInit = _this$initializeBacke.asyncInit;
23385 if (asyncInit) {
23386 throw new Error("The highest priority backend '".concat(name, "' has not yet been ") + "initialized. Make sure to await tf.ready() or " + "await tf.setBackend() before calling other methods");
23387 }
23388 this.setBackend(name);
23389 }
23390 return this.backendInstance;
23391 }
23392 }, {
23393 key: "backendNames",
23394 value: function backendNames() {
23395 return Object.keys(this.registryFactory);
23396 }
23397 }, {
23398 key: "findBackend",
23399 value: function findBackend(backendName) {
23400 if (!(backendName in this.registry)) {
23401 // If the backend hasn't been initialized but we have a registry entry for
23402 // it, initialize it and return it.
23403 if (backendName in this.registryFactory) {
23404 var _this$initializeBacke2 = this.initializeBackend(backendName),
23405 asyncInit = _this$initializeBacke2.asyncInit;
23406 if (asyncInit) {
23407 // Backend is not ready yet.
23408 return null;
23409 }
23410 } else {
23411 return null;
23412 }
23413 }
23414 return this.registry[backendName];
23415 }
23416 }, {
23417 key: "findBackendFactory",
23418 value: function findBackendFactory(backendName) {
23419 if (!(backendName in this.registryFactory)) {
23420 return null;
23421 }
23422 return this.registryFactory[backendName].factory;
23423 }
23424 }, {
23425 key: "registerBackend",
23426 value: function registerBackend(backendName, factory) {
23427 var priority = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 1;
23428 if (backendName in this.registryFactory) {
23429 warn("".concat(backendName, " backend was already registered. ") + "Reusing existing backend factory.");
23430 return false;
23431 }
23432 this.registryFactory[backendName] = {
23433 factory: factory,
23434 priority: priority
23435 };
23436 return true;
23437 }
23438 }, {
// Makes `backendName` the active backend. This is Babel/regenerator output of
// an `async setBackend(backendName)` method; each `case` label is a state of
// the generated state machine.
//   - Rejects with an Error if no factory is registered under the name.
//   - Records the name in this.backendName immediately.
//   - If this.registry has no instance yet, clears backendInstance and
//     initializes the backend, awaiting the result when the factory is async.
//     Resolves to false if that fails; this.backendName then keeps the new
//     name and backendInstance stays null.
//   - Otherwise (or after a successful init) activates the instance, runs the
//     kernels' setup functions and creates a fresh Profiler. Resolves to true.
23439 key: "setBackend",
23440 value: function () {
23441 var _setBackend = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(backendName) {
23442 var _this$initializeBacke3, success, asyncInit, result;
23443 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
23444 while (1) switch (_context2.prev = _context2.next) {
23445 case 0:
23446 if (!(this.registryFactory[backendName] == null)) {
23447 _context2.next = 2;
23448 break;
23449 }
23450 throw new Error("Backend name '".concat(backendName, "' not found in registry"));
23451 case 2:
23452 this.backendName = backendName;
23453 if (!(this.registry[backendName] == null)) {
23454 _context2.next = 16;
23455 break;
23456 }
// Not initialized yet: drop the previous instance while initializing.
23457 this.backendInstance = null;
23458 _this$initializeBacke3 = this.initializeBackend(backendName), success = _this$initializeBacke3.success, asyncInit = _this$initializeBacke3.asyncInit;
23459 if (!asyncInit) {
23460 _context2.next = 12;
23461 break;
23462 }
23463 _context2.next = 9;
// Async factory: await the Promise<boolean> returned by initializeBackend.
23464 return success;
23465 case 9:
23466 _context2.t0 = _context2.sent;
23467 _context2.next = 13;
23468 break;
23469 case 12:
// Sync factory: success is already a boolean.
23470 _context2.t0 = success;
23471 case 13:
23472 result = _context2.t0;
23473 if (result) {
23474 _context2.next = 16;
23475 break;
23476 }
23477 return _context2.abrupt("return", false);
23478 case 16:
23479 this.backendInstance = this.registry[backendName];
23480 this.setupRegisteredKernels();
23481 // Reset the profiler.
23482 this.profiler = new Profiler(this.backendInstance);
23483 return _context2.abrupt("return", true);
23484 case 20:
23485 case "end":
23486 return _context2.stop();
23487 }
23488 }, _callee2, this);
23489 }));
23490 function setBackend(_x) {
23491 return _setBackend.apply(this, arguments);
23492 }
23493 return setBackend;
23494 }()
23495 }, {
23496 key: "setupRegisteredKernels",
23497 value: function setupRegisteredKernels() {
23498 var _this = this;
23499 var kernels = getKernelsForBackend(this.backendName);
23500 kernels.forEach(function (kernel) {
23501 if (kernel.setupFunc != null) {
23502 kernel.setupFunc(_this.backendInstance);
23503 }
23504 });
23505 }
23506 }, {
23507 key: "disposeRegisteredKernels",
23508 value: function disposeRegisteredKernels(backendName) {
23509 var _this2 = this;
23510 var kernels = getKernelsForBackend(backendName);
23511 kernels.forEach(function (kernel) {
23512 if (kernel.disposeFunc != null) {
23513 kernel.disposeFunc(_this2.registry[backendName]);
23514 }
23515 });
23516 }
23517 /**
23518 * Initializes a backend by looking up the backend name in the factory
23519 * registry and calling the factory method. Returns a boolean representing
23520 * whether the initialization of the backend succeeded. Throws an error if
23521 * there is no backend in the factory registry.
23522 */
23523 }, {
23524 key: "initializeBackend",
23525 value: function initializeBackend(backendName) {
23526 var _this3 = this;
23527 var registryFactoryEntry = this.registryFactory[backendName];
23528 if (registryFactoryEntry == null) {
23529 throw new Error("Cannot initialize backend ".concat(backendName, ", no registration found."));
23530 }
23531 try {
23532 var backend = registryFactoryEntry.factory();
23533 /* Test if the factory returns a promise.
23534 Done in a more liberal way than
23535 previous 'Promise.resolve(backend)===backend'
23536 as we needed to account for custom Promise
23537 implementations (e.g. Angular) */
23538 if (backend && !(backend instanceof KernelBackend) && typeof backend.then === 'function') {
23539 var promiseId = ++this.pendingBackendInitId;
23540 var success = backend.then(function (backendInstance) {
23541 // Outdated promise. Another backend was set in the meantime.
23542 if (promiseId < _this3.pendingBackendInitId) {
23543 return false;
23544 }
23545 _this3.registry[backendName] = backendInstance;
23546 _this3.pendingBackendInit = null;
23547 return true;
23548 }).catch(function (err) {
23549 // Outdated promise. Another backend was set in the meantime.
23550 if (promiseId < _this3.pendingBackendInitId) {
23551 return false;
23552 }
23553 _this3.pendingBackendInit = null;
23554 warn("Initialization of backend ".concat(backendName, " failed"));
23555 warn(err.stack || err.message);
23556 return false;
23557 });
23558 this.pendingBackendInit = success;
23559 return {
23560 success: success,
23561 asyncInit: true
23562 };
23563 } else {
23564 this.registry[backendName] = backend;
23565 return {
23566 success: true,
23567 asyncInit: false
23568 };
23569 }
23570 } catch (err) {
23571 warn("Initialization of backend ".concat(backendName, " failed"));
23572 warn(err.stack || err.message);
23573 return {
23574 success: false,
23575 asyncInit: false
23576 };
23577 }
23578 }
23579 }, {
23580 key: "removeBackend",
23581 value: function removeBackend(backendName) {
23582 if (!(backendName in this.registryFactory)) {
23583 throw new Error("".concat(backendName, " backend not found in registry"));
23584 }
23585 if (this.backendName === backendName && this.pendingBackendInit != null) {
23586 // There is a pending promise of the backend we want to remove. Make it
23587 // obsolete.
23588 this.pendingBackendInitId++;
23589 }
23590 if (backendName in this.registry) {
23591 this.disposeRegisteredKernels(backendName);
23592 this.registry[backendName].dispose();
23593 delete this.registry[backendName];
23594 }
23595 delete this.registryFactory[backendName];
23596 // Unset the backend if it is active.
23597 if (this.backendName === backendName) {
23598 this.pendingBackendInit = null;
23599 this.backendName = null;
23600 this.backendInstance = null;
23601 }
23602 }
23603 }, {
23604 key: "getSortedBackends",
23605 value: function getSortedBackends() {
23606 var _this4 = this;
23607 if (Object.keys(this.registryFactory).length === 0) {
23608 throw new Error('No backend found in registry.');
23609 }
23610 return Object.keys(this.registryFactory).sort(function (a, b) {
23611 // Highest priority comes first.
23612 return _this4.registryFactory[b].priority - _this4.registryFactory[a].priority;
23613 });
23614 }
23615 }, {
23616 key: "initializeBackendsAndReturnBest",
23617 value: function initializeBackendsAndReturnBest() {
23618 var sortedBackends = this.getSortedBackends();
23619 for (var i = 0; i < sortedBackends.length; i++) {
23620 var backendName = sortedBackends[i];
23621 var _this$initializeBacke4 = this.initializeBackend(backendName),
23622 success = _this$initializeBacke4.success,
23623 asyncInit = _this$initializeBacke4.asyncInit;
23624 if (asyncInit || success) {
23625 return {
23626 name: backendName,
23627 asyncInit: asyncInit
23628 };
23629 }
23630 }
23631 throw new Error("Could not initialize any backends, all backend initializations " + "failed.");
23632 }
23633 }, {
23634 key: "moveData",
23635 value: function moveData(backend, dataId) {
23636 var info = this.state.tensorInfo.get(dataId);
23637 var srcBackend = info.backend;
23638 var values = this.readSync(dataId);
23639 var refCount = srcBackend.refCount(dataId);
23640 // Delete the tensor from the old backend and move it to the new
23641 // backend.
23642 srcBackend.disposeData(dataId, true);
23643 info.backend = backend;
23644 backend.move(dataId, values, info.shape, info.dtype, refCount);
23645 if (this.shouldCheckForMemLeaks()) {
23646 // Track the number of moves during a kernel execution to correctly
23647 // detect memory leaks.
23648 this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
23649 }
23650 }
23651 }, {
23652 key: "tidy",
23653 value: function tidy(nameOrFn, fn) {
23654 var _this5 = this;
23655 var name = null;
23656 if (fn == null) {
23657 // Called with only 1 argument.
23658 if (typeof nameOrFn !== 'function') {
23659 throw new Error('Please provide a function to tidy()');
23660 }
23661 fn = nameOrFn;
23662 } else {
23663 // Called with 2 arguments.
23664 if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
23665 throw new Error('When calling with two arguments, the first argument ' + 'to tidy() must be a string');
23666 }
23667 if (typeof fn !== 'function') {
23668 throw new Error('When calling with two arguments, the 2nd argument ' + 'to tidy() must be a function');
23669 }
23670 name = nameOrFn;
23671 // TODO(nsthorat,smilkov): Do operation logging and performance
23672 // profiling.
23673 }
23674
23675 var result;
23676 return this.scopedRun(function () {
23677 return _this5.startScope(name);
23678 }, function () {
23679 return _this5.endScope(result);
23680 }, function () {
23681 result = fn();
23682 if (result instanceof Promise) {
23683 console.error('Cannot return a Promise inside of tidy.');
23684 }
23685 return result;
23686 });
23687 }
23688 }, {
23689 key: "scopedRun",
23690 value: function scopedRun(start, end, f) {
23691 start();
23692 try {
23693 var res = f();
23694 end();
23695 return res;
23696 } catch (ex) {
23697 end();
23698 throw ex;
23699 }
23700 }
23701 }, {
23702 key: "nextTensorId",
23703 value: function nextTensorId() {
23704 return Engine.nextTensorId++;
23705 }
23706 }, {
23707 key: "nextVariableId",
23708 value: function nextVariableId() {
23709 return Engine.nextVariableId++;
23710 }
23711 /**
23712 * This method is called instead of the public-facing tensor.clone() when
23713 * saving a tensor for backwards pass. It makes sure to add the clone
23714 * operation to the tape regardless of being called inside a kernel
23715 * execution.
23716 */
23717 }, {
23718 key: "clone",
23719 value: function clone(x) {
23720 var y = ENGINE.runKernel(Identity$1, {
23721 x: x
23722 });
23723 var inputs = {
23724 x: x
23725 };
23726 var grad = function grad(dy) {
23727 return {
23728 x: function x() {
23729 var dtype = 'float32';
23730 var gradInputs = {
23731 x: dy
23732 };
23733 var attrs = {
23734 dtype: dtype
23735 };
23736 return ENGINE.runKernel(Cast, gradInputs,
23737 // tslint:disable-next-line: no-unnecessary-type-assertion
23738 attrs);
23739 }
23740 };
23741 };
23742 var saved = [];
23743 this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
23744 return y;
23745 }
23746 /**
23747 * Execute a kernel with the given name and return the output tensor.
23748 *
23749 * @param kernelName The name of the kernel to execute.
23750 * @param inputs A map of input names to tensors.
23751 * @param attrs A map of attribute names to their values. An attribute is a
23752 * primitive (non-tensor) input to the kernel.
23753 * @param inputsToSave A list of tensors, inputs to save for the backprop
23754 * computation.
23755 * @param outputsToSave A list of booleans, specifying which output to save
23756 * for the backprop computation. These are booleans since the output
23757 * tensors are not visible to the user.
23758 */
23759 }, {
23760 key: "runKernel",
23761 value: function runKernel(kernelName, inputs, attrs) {
23762 if (this.backendName == null) {
23763 // backend has not been initialized yet (backend initialization is lazy
23764 // can be deferred until an op/ kernel is run).
23765 // The below getter has side effects that will try to initialize the
23766 // backend and set properties like this.backendName
23767 // tslint:disable-next-line: no-unused-expression
23768 this.backend;
23769 }
23770 var hasKernel = getKernel(kernelName, this.backendName) != null;
23771 if (!hasKernel) {
23772 throw new Error("Kernel '".concat(kernelName, "' not registered for backend '").concat(this.backendName, "'"));
23773 }
23774 return this.runKernelFunc({
23775 kernelName: kernelName,
23776 inputs: inputs,
23777 attrs: attrs
23778 });
23779 }
23780 }, {
23781 key: "shouldCheckForMemLeaks",
23782 value: function shouldCheckForMemLeaks() {
23783 return this.ENV.getBool('IS_TEST');
23784 }
23785 }, {
23786 key: "checkKernelForMemLeak",
23787 value: function checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
23788 var numDataIdsAfter = this.backend.numDataIds();
23789 // Count the number of data ids associated with the result of the kernel.
23790 var numOutputDataIds = 0;
23791 outInfos.forEach(function (info) {
23792 // Complex numbers allocate 3 data ids, one for 'real', one for
23793 // 'imaginary', and one for the container that holds the former two.
23794 numOutputDataIds += info.dtype === 'complex64' ? 3 : 1;
23795 });
23796 // Account for the number of moves during kernel execution. A "data move"
23797 // can happen in the middle of a kernel execution, placing a new (key,value)
23798 // pair in the data storage. Since data moves have net zero effect (we
23799 // always remove the data from the old backend), we have to cancel them out
23800 // when detecting memory leaks.
23801 var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
23802 var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
23803 if (dataIdsLeaked > 0) {
23804 throw new Error("Backend '".concat(this.backendName, "' has an internal memory leak ") + "(".concat(dataIdsLeaked, " data ids) after running '").concat(kernelName, "'"));
23805 }
23806 }
23807 /**
23808 * Internal helper method to execute a kernel Func
23809 *
23810 * Use `runKernel` to execute kernels from outside of engine.
23811 */
23812 }, {
// Internal: runs either a registered kernel ({kernelName, inputs, attrs}) or a
// custom-gradient op ({forwardFunc, backwardsFunc, inputs, attrs}) with engine
// bookkeeping:
//   - lazy backend initialization;
//   - memory-leak checks when IS_TEST is set;
//   - profiling when the DEBUG flag or state.profiling is on;
//   - saving tensors for backprop and recording a tape node when the tape is on.
// Returns an array of output tensors if the kernel/forwardFunc returned an
// array, otherwise the single output tensor.
23813 key: "runKernelFunc",
23814 value: function runKernelFunc(kernelParams) {
23815 var _this6 = this;
23816 var outputs;
// Reassigned by the kernelFunc/saveFunc closures below; passed to addTapeNode.
23817 var saved = [];
23818 var isTapeOn = this.isTapeOn();
23819 var startingBytecount = this.state.numBytes;
23820 var startingNumTensors = this.state.numTensors;
23821 if (this.shouldCheckForMemLeaks()) {
// NOTE(review): no matching pop of numDataMovesStack is visible in this chunk.
// Confirm it is popped elsewhere; otherwise the stack grows by one entry per
// kernel call in IS_TEST mode.
23822 this.state.numDataMovesStack.push(0);
23823 }
23824 var kernelFunc;
23825 if (this.backendName == null) {
23826 // backend has not been initialized yet (backend initialization is lazy
23827 // can be deferred until an op/ kernel is run).
23828 // The below getter has side effects that will try to initialize the
23829 // backend and set properties like this.backendName
23830 // tslint:disable-next-line: no-unused-expression
23831 this.backend;
23832 }
// Raw kernel/forwardFunc result; read at the end to decide whether to return
// an array or a single tensor.
23833 var out;
23834 var kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ? kernelParams.kernelName : this.state.activeScope != null ? this.state.activeScope.name : '';
23835 // Create the kernelFunc from either a registered kernel OR passed in
23836 // forward/backward functions (used by custom grad). In this context a
23837 // kernelFunc wraps a kernel implementation with some bookkeeping.
23838 if (isRegisteredKernelInvocation(kernelParams)) {
23839 var kernelName = kernelParams.kernelName,
23840 _inputs = kernelParams.inputs,
23841 _attrs = kernelParams.attrs;
// NOTE(review): duplicates the identical lazy-init check above; redundant but harmless.
23842 if (this.backendName == null) {
23843 // backend has not been initialized yet (backend initialization is lazy
23844 // can be deferred until an op/ kernel is run).
23845 // The below getter has side effects that will try to initialize the
23846 // backend and set properties like this.backendName
23847 // tslint:disable-next-line: no-unused-expression
23848 this.backend;
23849 }
23850 var kernel = getKernel(kernelName, this.backendName);
23851 assert$1(kernel != null, function () {
23852 return "Cannot find registered kernel '".concat(kernelName, "' for backend '").concat(_this6.backendName, "'");
23853 });
23854 kernelFunc = function kernelFunc() {
23855 var numDataIdsBefore = _this6.backend.numDataIds();
23856 out = kernel.kernelFunc({
23857 inputs: _inputs,
23858 attrs: _attrs,
23859 backend: _this6.backend
23860 });
23861 var outInfos = Array.isArray(out) ? out : [out];
23862 if (_this6.shouldCheckForMemLeaks()) {
23863 _this6.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
23864 }
23865 var outTensors = outInfos.map(function (outInfo) {
23866 // todo (yassogba) remove this option (Tensor) when node backend
23867 // methods have been modularized and they all return tensorInfo.
23868 // TensorInfos do not have a rank attribute.
23869 if (outInfo.rank != null) {
23870 return outInfo;
23871 }
23872 return _this6.makeTensorFromTensorInfo(outInfo);
23873 });
23874 // Save any required inputs and outputs.
23875 // Do not save unless we are recording to the tape. Otherwise it would
23876 // cause a mem leak since there would be no backprop for these tensors
23877 // (which would otherwise dispose them).
23878 if (isTapeOn) {
23879 var tensorsToSave = _this6.getTensorsForGradient(kernelName, _inputs, outTensors);
23880 saved = _this6.saveTensorsForBackwardMode(tensorsToSave);
23881 }
23882 return outTensors;
23883 };
23884 } else {
23885 var forwardFunc = kernelParams.forwardFunc;
23886 // Running a customGrad op.
23887 var saveFunc = function saveFunc(tensors) {
23888 // Do not save unless we are recording to the tape. Otherwise it would
23889 // cause a mem leak since we would never run backprop, which disposes
23890 // the kept tensors.
23891 if (!isTapeOn) {
23892 return;
23893 }
23894 saved = tensors.map(function (tensor) {
23895 return _this6.keep(_this6.clone(tensor));
23896 });
23897 };
23898 kernelFunc = function kernelFunc() {
23899 var numDataIdsBefore = _this6.backend.numDataIds();
23900 out = _this6.tidy(function () {
23901 return forwardFunc(_this6.backend, saveFunc);
23902 });
23903 var outs = Array.isArray(out) ? out : [out];
23904 if (_this6.shouldCheckForMemLeaks()) {
23905 // Scope name is used to print a more helpful error message if needed.
23906 _this6.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
23907 }
23908 return outs;
23909 };
23910 }
23911 //
23912 // Run the kernelFunc. Optionally profiling it.
23913 //
23914 var inputs = kernelParams.inputs,
23915 attrs = kernelParams.attrs;
23916 var backwardsFunc = isRegisteredKernelInvocation(kernelParams) ? null : kernelParams.backwardsFunc;
// Assigned only on the profiling path below. That path always runs when
// state.profiling is true, which the profiling block after it relies on.
23917 var kernelProfile;
23918 this.scopedRun(
23919 // Stop recording to a tape when running a kernel.
23920 function () {
23921 return _this6.state.kernelDepth++;
23922 }, function () {
23923 return _this6.state.kernelDepth--;
23924 }, function () {
23925 if (!_this6.ENV.getBool('DEBUG') && !_this6.state.profiling) {
23926 outputs = kernelFunc();
23927 } else {
23928 kernelProfile = _this6.profiler.profileKernel(kernelOrScopeName, inputs, function () {
23929 return kernelFunc();
23930 });
23931 if (_this6.ENV.getBool('DEBUG')) {
23932 _this6.profiler.logKernelProfile(kernelProfile);
23933 }
23934 outputs = kernelProfile.outputs;
23935 }
23936 });
23937 if (isTapeOn) {
23938 this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
23939 }
23940 if (this.state.profiling) {
23941 this.state.activeProfile.kernels.push({
23942 name: kernelOrScopeName,
23943 bytesAdded: this.state.numBytes - startingBytecount,
23944 totalBytesSnapshot: this.state.numBytes,
23945 tensorsAdded: this.state.numTensors - startingNumTensors,
23946 totalTensorsSnapshot: this.state.numTensors,
23947 inputShapes: Object.keys(inputs).map(function (key) {
23948 return inputs[key] != null ? inputs[key].shape : null;
23949 }),
23950 outputShapes: outputs.map(function (item) {
23951 return item.shape;
23952 }),
23953 kernelTimeMs: kernelProfile.timeMs,
23954 extraInfo: kernelProfile.extraInfo
23955 });
23956 }
// Mirror the kernel's arity: an array result stays an array, otherwise unwrap.
23957 return Array.isArray(out) ? outputs : outputs[0];
23958 }
23959 /**
23960 * Saves tensors used in forward mode for use in backward mode.
23961 *
23962 * @param tensors the list of tensors to save.
23963 */
23964 }, {
23965 key: "saveTensorsForBackwardMode",
23966 value: function saveTensorsForBackwardMode(tensors) {
23967 var _this7 = this;
23968 var saved = tensors.map(function (tensor) {
23969 return _this7.keep(_this7.clone(tensor));
23970 });
23971 return saved;
23972 }
23973 /**
23974 * Returns a list of tensors to save for a given gradient calculation.
23975 *
23976 * @param kernelName name of kernel to look up gradient for.
23977 * @param inputs a map of input tensors.
23978 * @param outputs an array of output tensors from forward mode of kernel.
23979 */
23980 }, {
23981 key: "getTensorsForGradient",
23982 value: function getTensorsForGradient(kernelName, inputs, outputs) {
23983 var gradConfig = getGradient(kernelName);
23984 if (gradConfig != null) {
23985 var inputsToSave = gradConfig.inputsToSave || [];
23986 var outputsToSave = gradConfig.outputsToSave || [];
23987 // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
23988 // specified in inputsToSave will be saved.
23989 var inputTensorsToSave;
23990 if (gradConfig.saveAllInputs) {
23991 assert$1(Array.isArray(inputs), function () {
23992 return 'saveAllInputs is true, expected inputs to be an array.';
23993 });
23994 inputTensorsToSave = Object.keys(inputs).map(function (key) {
23995 return inputs[key];
23996 });
23997 } else {
23998 inputTensorsToSave = inputsToSave.map(function (inputName) {
23999 return inputs[inputName];
24000 });
24001 }
24002 var outputTensorsToSave = outputs.filter(function (_, i) {
24003 return outputsToSave[i];
24004 });
24005 return inputTensorsToSave.concat(outputTensorsToSave);
24006 }
24007 // We return an empty list rather than throw an error because the kernel we
24008 // are looking up may not actually be relevant to backproping through the
24009 // overall function
24010 //
24011 // See 'does not error if irrelevant (pruned) ops are missing grads' test
24012 // in gradients_test.ts for an example.
24013 return [];
24014 }
24015 /**
24016 * Internal method used by public APIs for tensor creation. Makes a new
24017 * tensor with the provided shape, dtype and values. It always
24018 * creates a new data id and writes the values to the underlying backend.
24019 */
24020 }, {
24021 key: "makeTensor",
24022 value: function makeTensor(values, shape, dtype, backend) {
24023 if (values == null) {
24024 throw new Error('Values passed to engine.makeTensor() are null');
24025 }
24026 dtype = dtype || 'float32';
24027 backend = backend || this.backend;
24028 var backendVals = values;
24029 if (dtype === 'string' && isString(values[0])) {
24030 backendVals = values.map(function (d) {
24031 return encodeString(d);
24032 });
24033 }
24034 var dataId = backend.write(backendVals, shape, dtype);
24035 var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
24036 this.trackTensor(t, backend);
24037 // Count bytes for string tensors.
24038 if (dtype === 'string') {
24039 var info = this.state.tensorInfo.get(dataId);
24040 var newBytes = bytesFromStringArray(backendVals);
24041 this.state.numBytes += newBytes - info.bytes;
24042 info.bytes = newBytes;
24043 }
24044 return t;
24045 }
24046 /**
24047 * Internal method used by backends. Makes a new tensor
24048 * that is a wrapper around an existing data id. It doesn't create
24049 * a new data id, only increments the ref count used in memory tracking.
24050 * @deprecated
24051 */
24052 }, {
24053 key: "makeTensorFromDataId",
24054 value: function makeTensorFromDataId(dataId, shape, dtype, backend) {
24055 dtype = dtype || 'float32';
24056 var tensorInfo = {
24057 dataId: dataId,
24058 shape: shape,
24059 dtype: dtype
24060 };
24061 return this.makeTensorFromTensorInfo(tensorInfo, backend);
24062 }
24063 /**
24064 * Internal method used by backends. Makes a new tensor that is a wrapper
24065 * around an existing data id in TensorInfo. It doesn't create a new data id,
24066 * only increments the ref count used in memory tracking.
24067 */
24068 }, {
24069 key: "makeTensorFromTensorInfo",
24070 value: function makeTensorFromTensorInfo(tensorInfo, backend) {
24071 var dataId = tensorInfo.dataId,
24072 shape = tensorInfo.shape,
24073 dtype = tensorInfo.dtype;
24074 var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
24075 this.trackTensor(t, backend);
24076 return t;
24077 }
24078 }, {
24079 key: "makeVariable",
24080 value: function makeVariable(initialValue) {
24081 var trainable = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
24082 var name = arguments.length > 2 ? arguments[2] : undefined;
24083 var dtype = arguments.length > 3 ? arguments[3] : undefined;
24084 name = name || this.nextVariableId().toString();
24085 if (dtype != null && dtype !== initialValue.dtype) {
24086 initialValue = initialValue.cast(dtype);
24087 }
24088 var v = new Variable(initialValue, trainable, name, this.nextTensorId());
24089 if (this.state.registeredVariables[v.name] != null) {
24090 throw new Error("Variable with name ".concat(v.name, " was already registered"));
24091 }
24092 this.state.registeredVariables[v.name] = v;
24093 this.incRef(v, this.backend);
24094 return v;
24095 }
24096 }, {
24097 key: "trackTensor",
24098 value: function trackTensor(a, backend) {
24099 this.state.numTensors++;
24100 if (a.dtype === 'string') {
24101 this.state.numStringTensors++;
24102 }
24103 // Bytes for complex numbers are counted by their components. Bytes for
24104 // string tensors are counted when writing values.
24105 var bytes = 0;
24106 if (a.dtype !== 'complex64' && a.dtype !== 'string') {
24107 bytes = a.size * bytesPerElement(a.dtype);
24108 }
24109 this.state.numBytes += bytes;
24110 if (!this.state.tensorInfo.has(a.dataId)) {
24111 this.state.numDataBuffers++;
24112 this.state.tensorInfo.set(a.dataId, {
24113 backend: backend || this.backend,
24114 dtype: a.dtype,
24115 shape: a.shape,
24116 bytes: bytes
24117 });
24118 }
24119 if (!(a instanceof Variable)) {
24120 this.track(a);
24121 }
24122 }
24123 // Track the tensor by dataId and increase the refCount for the dataId in the
24124 // backend.
24125 // TODO(pyu10055): This is currently used by makeVariable method, to increase
24126 // refCount on the backend for the dataId. It can potentially be replaced with
// Identity op instead of calling backend directly.
24128 }, {
24129 key: "incRef",
24130 value: function incRef(a, backend) {
24131 this.trackTensor(a, backend);
24132 this.backend.incRef(a.dataId);
24133 }
24134 }, {
24135 key: "removeDataId",
24136 value: function removeDataId(dataId, backend) {
24137 if (this.state.tensorInfo.has(dataId) && this.state.tensorInfo.get(dataId).backend === backend) {
24138 this.state.tensorInfo.delete(dataId);
24139 this.state.numDataBuffers--;
24140 }
24141 }
24142 }, {
24143 key: "disposeTensor",
24144 value: function disposeTensor(a) {
24145 if (!this.state.tensorInfo.has(a.dataId)) {
24146 return;
24147 }
24148 var info = this.state.tensorInfo.get(a.dataId);
24149 this.state.numTensors--;
24150 if (a.dtype === 'string') {
24151 this.state.numStringTensors--;
24152 this.state.numBytes -= info.bytes;
24153 }
24154 // Don't count bytes for complex numbers as they are counted by their
24155 // components.
24156 if (a.dtype !== 'complex64' && a.dtype !== 'string') {
24157 var bytes = a.size * bytesPerElement(a.dtype);
24158 this.state.numBytes -= bytes;
24159 }
24160 // Remove the reference to dataId if backend dispose the data successfully
24161 if (info.backend.disposeData(a.dataId)) {
24162 this.removeDataId(a.dataId, info.backend);
24163 }
24164 // TODO(nsthorat): Construct an error and save the stack trace for
24165 // debugging when in debug mode. Creating a stack trace is too expensive
24166 // to do unconditionally.
24167 }
24168 }, {
24169 key: "disposeVariables",
24170 value: function disposeVariables() {
24171 for (var varName in this.state.registeredVariables) {
24172 var v = this.state.registeredVariables[varName];
24173 this.disposeVariable(v);
24174 }
24175 }
24176 }, {
24177 key: "disposeVariable",
24178 value: function disposeVariable(v) {
24179 this.disposeTensor(v);
24180 if (this.state.registeredVariables[v.name] != null) {
24181 delete this.state.registeredVariables[v.name];
24182 }
24183 }
24184 }, {
24185 key: "memory",
24186 value: function memory() {
24187 var info = this.backend.memory();
24188 info.numTensors = this.state.numTensors;
24189 info.numDataBuffers = this.state.numDataBuffers;
24190 info.numBytes = this.state.numBytes;
24191 if (this.state.numStringTensors > 0) {
24192 info.unreliable = true;
24193 if (info.reasons == null) {
24194 info.reasons = [];
24195 }
24196 info.reasons.push('Memory usage by string tensors is approximate ' + '(2 bytes per character)');
24197 }
24198 return info;
24199 }
24200 }, {
key: "profile",
// Babel regenerator output of an `async profile(query)` method.
// Sets state.profiling, resets activeProfile.kernels and awaits query(). It
// then fills this.state.activeProfile:
//   result     - the value query() resolved to
//   peakBytes  - max totalBytesSnapshot over activeProfile.kernels
//   newBytes   - change in state.numBytes across the call
//   newTensors - change in state.numTensors across the call
// It also awaits and stores each kernel's kernelTimeMs and extraInfo in place.
// Resolves to this.state.activeProfile.
// NOTE(review): if query() rejects, state.profiling is never reset to false
// (the only try region starts after the await). With no entries in
// activeProfile.kernels, Math.max() of an empty list gives -Infinity.
value: function () {
  var _profile = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(query) {
    var startBytes, startNumTensors, _iterator, _step, kernel;
    return _regeneratorRuntime().wrap(function _callee3$(_context3) {
      while (1) switch (_context3.prev = _context3.next) {
        case 0:
          // Synchronous prologue: enable profiling and snapshot counters.
          this.state.profiling = true;
          startBytes = this.state.numBytes;
          startNumTensors = this.state.numTensors;
          this.state.activeProfile.kernels = [];
          _context3.next = 6;
          return query();
        case 6:
          // query() resolved: record result and aggregate statistics.
          this.state.activeProfile.result = _context3.sent;
          this.state.profiling = false;
          this.state.activeProfile.peakBytes = Math.max.apply(Math, _toConsumableArray(this.state.activeProfile.kernels.map(function (d) {
            return d.totalBytesSnapshot;
          })));
          this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
          this.state.activeProfile.newTensors = this.state.numTensors - startNumTensors;
          _iterator = _createForOfIteratorHelper(this.state.activeProfile.kernels);
          _context3.prev = 12;
          _iterator.s();
        case 14:
          // for-of loop head over activeProfile.kernels.
          if ((_step = _iterator.n()).done) {
            _context3.next = 24;
            break;
          }
          kernel = _step.value;
          _context3.next = 18;
          return kernel.kernelTimeMs;
        case 18:
          // Replace the (possibly pending) timing value with its resolved value.
          kernel.kernelTimeMs = _context3.sent;
          _context3.next = 21;
          return kernel.extraInfo;
        case 21:
          kernel.extraInfo = _context3.sent;
        case 22:
          _context3.next = 14;
          break;
        case 24:
          _context3.next = 29;
          break;
        case 26:
          // catch: forward the error to the iterator helper.
          _context3.prev = 26;
          _context3.t0 = _context3["catch"](12);
          _iterator.e(_context3.t0);
        case 29:
          // finally: close the iterator.
          _context3.prev = 29;
          _iterator.f();
          return _context3.finish(29);
        case 32:
          return _context3.abrupt("return", this.state.activeProfile);
        case 33:
        case "end":
          return _context3.stop();
      }
    }, _callee3, this, [[12, 26, 29, 32]]);
  }));
  function profile(_x2) {
    return _profile.apply(this, arguments);
  }
  return profile;
}()
24266 }, {
24267 key: "isTapeOn",
24268 value: function isTapeOn() {
24269 return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
24270 }
24271 }, {
24272 key: "addTapeNode",
24273 value: function addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
24274 var _this8 = this;
24275 var tapeNode = {
24276 id: this.state.nextTapeNodeId++,
24277 kernelName: kernelName,
24278 inputs: inputs,
24279 outputs: outputs,
24280 saved: saved
24281 };
24282 var gradConfig = getGradient(kernelName);
24283 if (gradConfig != null) {
24284 gradientsFunc = gradConfig.gradFunc;
24285 }
24286 if (gradientsFunc != null) {
24287 tapeNode.gradient = function (dys) {
24288 // TODO(smilkov): To optimize back-prop, pass dys that are not used in
24289 // the backprop graph to the user as null instead of zeros
24290 dys = dys.map(function (dy, i) {
24291 if (dy == null) {
24292 var output = outputs[i];
24293 var vals = makeZerosTypedArray(output.size, output.dtype);
24294 return _this8.makeTensor(vals, output.shape, output.dtype);
24295 }
24296 return dy;
24297 });
24298 // Grad functions of ops with single outputs expect a dy, while ops
24299 // with multiple outputs expect dys (array of dy).
24300 return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
24301 };
24302 }
24303 this.state.activeTape.push(tapeNode);
24304 }
24305 }, {
24306 key: "keep",
24307 value: function keep(result) {
24308 result.kept = true;
24309 return result;
24310 }
24311 }, {
24312 key: "startTape",
24313 value: function startTape() {
24314 if (this.state.gradientDepth === 0) {
24315 this.state.activeTape = [];
24316 }
24317 this.state.gradientDepth++;
24318 }
24319 }, {
24320 key: "endTape",
24321 value: function endTape() {
24322 this.state.gradientDepth--;
24323 }
24324 /**
24325 * Start a scope. Use this with endScope() to achieve the same functionality
24326 * as scope() without the need for a function closure.
24327 */
24328 }, {
24329 key: "startScope",
24330 value: function startScope(name) {
24331 var scopeInfo = {
24332 track: [],
24333 name: 'unnamed scope',
24334 id: this.state.nextScopeId++
24335 };
24336 if (name) {
24337 scopeInfo.name = name;
24338 }
24339 this.state.scopeStack.push(scopeInfo);
24340 this.state.activeScope = scopeInfo;
24341 }
24342 /**
24343 * End a scope. Use this with startScope() to achieve the same functionality
24344 * as scope() without the need for a function closure.
24345 */
24346 }, {
24347 key: "endScope",
24348 value: function endScope(result) {
24349 var _this9 = this;
24350 var tensorsToTrackInParent = getTensorsInContainer(result);
24351 var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) {
24352 return t.id;
24353 }));
24354 // Dispose the arrays tracked in this scope.
24355 for (var i = 0; i < this.state.activeScope.track.length; i++) {
24356 var tensor = this.state.activeScope.track[i];
24357 if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
24358 tensor.dispose();
24359 }
24360 }
24361 var oldScope = this.state.scopeStack.pop();
24362 this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1];
24363 // Track the current result in the parent scope.
24364 tensorsToTrackInParent.forEach(function (tensor) {
24365 // Only track the tensor if was allocated in the inner scope and is not
24366 // globally kept.
24367 if (!tensor.kept && tensor.scopeId === oldScope.id) {
24368 _this9.track(tensor);
24369 }
24370 });
24371 }
24372 /**
24373 * Returns gradients of `f` with respect to each of the `xs`. The gradients
24374 * returned are of the same length as `xs`, but some might be null if `f`
24375 * was not a function of that `x`. It also takes optional dy to multiply the
24376 * gradient, which defaults to `1`.
24377 */
24378 }, {
24379 key: "gradients",
24380 value: function gradients(f, xs, dy) {
24381 var _this10 = this;
24382 var allowNoGradients = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
24383 assert$1(xs.length > 0, function () {
24384 return 'gradients() received an empty list of xs.';
24385 });
24386 if (dy != null && dy.dtype !== 'float32') {
24387 throw new Error("dy must have 'float32' dtype, but has '".concat(dy.dtype, "'"));
24388 }
24389 var y = this.scopedRun(function () {
24390 return _this10.startTape();
24391 }, function () {
24392 return _this10.endTape();
24393 }, function () {
24394 return _this10.tidy('forward', f);
24395 });
24396 assert$1(y instanceof Tensor, function () {
24397 return 'The result y returned by f() must be a tensor.';
24398 });
24399 // Filter out the nodes that don't connect x => y.
24400 var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
24401 if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
24402 throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' + 'that the f you passed encloses all operations that lead from x ' + 'to y.');
24403 }
24404 return this.tidy('backward', function () {
24405 var accumulatedGradientMap = {};
24406 accumulatedGradientMap[y.id] = dy == null ? ones$2(y.shape) : dy;
24407 // Backprop gradients through the filtered nodes.
24408 backpropagateGradients(accumulatedGradientMap, filteredTape,
24409 // Pass the tidy function to avoid circular dep with `tape.ts`.
24410 function (f) {
24411 return _this10.tidy(f);
24412 },
24413 // Pass an add function to avoide a circular dep with `tape.ts`.
24414 add$4);
24415 var grads = xs.map(function (x) {
24416 return accumulatedGradientMap[x.id];
24417 });
24418 if (_this10.state.gradientDepth === 0) {
24419 // This means that we are not computing higher-order gradients
24420 // and can clean up the tape.
24421 _this10.state.activeTape.forEach(function (node) {
24422 var _iterator2 = _createForOfIteratorHelper(node.saved),
24423 _step2;
24424 try {
24425 for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
24426 var tensor = _step2.value;
24427 tensor.dispose();
24428 }
24429 } catch (err) {
24430 _iterator2.e(err);
24431 } finally {
24432 _iterator2.f();
24433 }
24434 });
24435 _this10.state.activeTape = null;
24436 }
24437 return {
24438 value: y,
24439 grads: grads
24440 };
24441 });
24442 }
24443 }, {
24444 key: "customGrad",
24445 value: function customGrad(f) {
24446 var _this11 = this;
24447 assert$1(isFunction(f), function () {
24448 return 'The f passed in customGrad(f) must be a function.';
24449 });
24450 return function () {
24451 for (var _len = arguments.length, inputs = new Array(_len), _key = 0; _key < _len; _key++) {
24452 inputs[_key] = arguments[_key];
24453 }
24454 assert$1(inputs.every(function (t) {
24455 return t instanceof Tensor;
24456 }), function () {
24457 return 'The args passed in customGrad(f)(x1, x2,...) must all be ' + 'tensors';
24458 });
24459 var res;
24460 var inputMap = {};
24461 inputs.forEach(function (input, i) {
24462 inputMap[i] = input;
24463 });
24464 var forwardFunc = function forwardFunc(_, save) {
24465 res = f.apply(void 0, [].concat(inputs, [save]));
24466 assert$1(res.value instanceof Tensor, function () {
24467 return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.value` is a tensor';
24468 });
24469 assert$1(isFunction(res.gradFunc), function () {
24470 return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function.';
24471 });
24472 return res.value;
24473 };
24474 var backwardsFunc = function backwardsFunc(dy, saved) {
24475 var gradRes = res.gradFunc(dy, saved);
24476 var grads = Array.isArray(gradRes) ? gradRes : [gradRes];
24477 assert$1(grads.length === inputs.length, function () {
24478 return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'the same number of tensors as inputs passed to f(...).';
24479 });
24480 assert$1(grads.every(function (t) {
24481 return t instanceof Tensor;
24482 }), function () {
24483 return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'a list of only tensors.';
24484 });
24485 var gradMap = {};
24486 grads.forEach(function (grad, i) {
24487 gradMap[i] = function () {
24488 return grad;
24489 };
24490 });
24491 return gradMap;
24492 };
24493 return _this11.runKernelFunc({
24494 forwardFunc: forwardFunc,
24495 backwardsFunc: backwardsFunc,
24496 inputs: inputMap
24497 });
24498 };
24499 }
24500 }, {
24501 key: "readSync",
24502 value: function readSync(dataId) {
24503 // Route the read to the correct backend.
24504 var info = this.state.tensorInfo.get(dataId);
24505 return info.backend.readSync(dataId);
24506 }
24507 }, {
24508 key: "read",
24509 value: function read(dataId) {
24510 // Route the read to the correct backend.
24511 var info = this.state.tensorInfo.get(dataId);
24512 return info.backend.read(dataId);
24513 }
24514 }, {
24515 key: "readToGPU",
24516 value: function readToGPU(dataId, options) {
24517 // Route the read to the correct backend.
24518 var info = this.state.tensorInfo.get(dataId);
24519 return info.backend.readToGPU(dataId, options);
24520 }
24521 }, {
key: "time",
// Babel regenerator output of an `async time(query)` method.
// Awaits this.backend.time(query) and sets `wallMs` on the returned
// timing-info object to the now() delta measured around that await. Resolves
// to that same object.
value: function () {
  var _time = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(query) {
    var start, timingInfo;
    return _regeneratorRuntime().wrap(function _callee4$(_context4) {
      while (1) switch (_context4.prev = _context4.next) {
        case 0:
          // Start the clock before delegating to the backend.
          start = now();
          _context4.next = 3;
          return this.backend.time(query);
        case 3:
          timingInfo = _context4.sent;
          timingInfo.wallMs = now() - start;
          return _context4.abrupt("return", timingInfo);
        case 6:
        case "end":
          return _context4.stop();
      }
    }, _callee4, this);
  }));
  function time(_x3) {
    return _time.apply(this, arguments);
  }
  return time;
}()
24547 /**
24548 * Tracks a Tensor in the current scope to be automatically cleaned up
24549 * when the current scope ends, and returns the value.
24550 *
24551 * @param result The Tensor to track in the current scope.
24552 */
24553 }, {
24554 key: "track",
24555 value: function track(result) {
24556 if (this.state.activeScope != null) {
24557 result.scopeId = this.state.activeScope.id;
24558 this.state.activeScope.track.push(result);
24559 }
24560 return result;
24561 }
24562 }, {
24563 key: "registeredVariables",
24564 get: function get() {
24565 return this.state.registeredVariables;
24566 }
24567 /**
24568 * Resets the engine state. Removes all backends but does not remove
24569 * registered backend factories.
24570 */
24571 }, {
24572 key: "reset",
24573 value: function reset() {
24574 // Make any pending promise obsolete.
24575 this.pendingBackendInitId++;
24576 this.state.dispose();
24577 this.ENV.reset();
24578 this.state = new EngineState();
24579 for (var backendName in this.registry) {
24580 this.disposeRegisteredKernels(backendName);
24581 this.registry[backendName].dispose();
24582 delete this.registry[backendName];
24583 }
24584 this.backendName = null;
24585 this.backendInstance = null;
24586 this.pendingBackendInit = null;
24587 }
24588 }]);
24589 return Engine;
24590 }();
// Class-level counters on Engine, both starting at 0.
// NOTE(review): presumably backing the nextTensorId()/nextVariableId()
// methods — confirm in the class body.
Object.assign(Engine, {
  nextTensorId: 0,
  nextVariableId: 0
});
/**
 * Builds a float32 tensor of `shape` whose values come from
 * makeOnesTypedArray(). gradients() uses it as the default seed dy.
 */
function ones$2(shape) {
  var dtype = 'float32';
  var count = sizeFromShape(shape);
  return ENGINE.makeTensor(makeOnesTypedArray(count, dtype), shape, dtype);
}
/**
 * Returns the engine stored on the global namespace as `_tfengine`.
 *
 * The engine is created, with a fresh Environment, on first use and reused
 * afterwards. Every call also re-points the environment global and the
 * tensor tracker at that engine.
 */
function getOrMakeEngine() {
  var ns = getGlobalNamespace();
  if (ns._tfengine == null) {
    ns._tfengine = new Engine(new Environment(ns));
  }
  setEnvironmentGlobal(ns._tfengine.ENV);
  // The tracker reads ns._tfengine lazily, so it always resolves to whatever
  // engine is currently stored on the namespace.
  setTensorTracker(function () {
    return ns._tfengine;
  });
  return ns._tfengine;
}
// Module-level engine handle used by helpers such as ones$2, add$4 and
// convertToTensor.
var ENGINE = getOrMakeEngine();
/**
 * An implementation of the add op for use within the engine and the tape.
 *
 * It duplicates Add here instead of importing add.ts, which avoids a
 * circular dependency between add.ts and the engine. It is exported so it is
 * available in tape tests.
 */
function add$4(a, b) {
  return ENGINE.runKernel(Add$1, {
    a: a,
    b: b
  });
}
24626
24627 /**
24628 * @license
24629 * Copyright 2017 Google LLC. All Rights Reserved.
24630 * Licensed under the Apache License, Version 2.0 (the "License");
24631 * you may not use this file except in compliance with the License.
24632 * You may obtain a copy of the License at
24633 *
24634 * http://www.apache.org/licenses/LICENSE-2.0
24635 *
24636 * Unless required by applicable law or agreed to in writing, software
24637 * distributed under the License is distributed on an "AS IS" BASIS,
24638 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24639 * See the License for the specific language governing permissions and
24640 * limitations under the License.
24641 * =============================================================================
24642 */
24643 // tslint:disable-next-line:no-any
/** Returns true when a global `navigator` is declared and non-null. */
function _isNavigatorDefined() {
  return typeof navigator !== 'undefined' && navigator != null;
}
// Value forced by mockIsMobile(); `undefined` means "not mocked".
var isMobileMockValue;
/**
 * Overrides the result of isMobile() (e.g. in tests). Pass `undefined` to
 * restore real detection.
 */
function mockIsMobile(value) {
  isMobileMockValue = value;
}
/**
 * Best-effort mobile-device detection.
 *
 * @param nav Optional navigator-like object; defaults to the global
 *     `navigator` when one is defined.
 * @returns The value set via mockIsMobile() if any. Otherwise:
 *     - true for React Native;
 *     - the result of the user-agent patterns below when a UA string (or
 *       vendor / `window.opera` fallback) is available;
 *     - otherwise `userAgentData.mobile` coerced to a boolean;
 *     - false when no navigator is available at all.
 */
function isMobile(nav) {
  if (isMobileMockValue !== undefined) {
    return isMobileMockValue;
  }
  if (!nav && !_isNavigatorDefined()) {
    return false;
  }
  var target = nav || navigator;
  if (target.product === 'ReactNative') {
    return true;
  }
  var ua = target.userAgent || target.vendor || (
  // tslint:disable-next-line:no-any
  typeof window !== 'undefined' ? window.opera : '');
  if (!ua) {
    // No UA string: fall back to `navigator.userAgentData.mobile`, coerced
    // to a boolean so callers never receive `undefined`.
    // tslint:disable-next-line:no-any
    var navAny = target;
    return !!(navAny.userAgentData && navAny.userAgentData.mobile);
  }
  // The fallback value is not guaranteed to be a string (e.g. `window.opera`),
  // so coerce before calling string methods. `slice` replaces the deprecated
  // `substr`; both are identical for (0, 4).
  var uaString = String(ua);
  // tslint:disable-next-line:max-line-length
  return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(uaString) ||
  // Device-code patterns matched against the first four UA characters.
  // tslint:disable-next-line:max-line-length
  /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(uaString.slice(0, 4));
}
/**
 * True when running in a page context (a `window` with a `document`) or in a
 * context where `WorkerGlobalScope` is defined.
 */
function isBrowser() {
  if (typeof window !== 'undefined' && window.document != null) {
    return true;
  }
  //@ts-ignore
  return typeof WorkerGlobalScope !== 'undefined';
}
24683
// Namespace object (null prototype) re-exporting the device helpers.
var device_util = Object.assign(Object.create(null), {
  isBrowser: isBrowser,
  isMobile: isMobile,
  mockIsMobile: mockIsMobile
});
24690
24691 /**
24692 * @license
24693 * Copyright 2019 Google LLC. All Rights Reserved.
24694 * Licensed under the Apache License, Version 2.0 (the "License");
24695 * you may not use this file except in compliance with the License.
24696 * You may obtain a copy of the License at
24697 *
24698 * http://www.apache.org/licenses/LICENSE-2.0
24699 *
24700 * Unless required by applicable law or agreed to in writing, software
24701 * distributed under the License is distributed on an "AS IS" BASIS,
24702 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24703 * See the License for the specific language governing permissions and
24704 * limitations under the License.
24705 * =============================================================================
24706 */
var ENV$3 = env();
/**
 * Core environment flag registrations. Each registerFlag() call below passes
 * a flag name and a function returning the flag's value.
 * NOTE(review): presumably that function is evaluated lazily on first read,
 * unless overridden — confirm against Environment.registerFlag.
 */
/**
 * Whether to enable debug mode. Its value function returns false. The extra
 * hook logs a performance warning when the value it receives is truthy
 * (presumably invoked when the flag is set).
 */
ENV$3.registerFlag('DEBUG', function () {
  return false;
}, function (debugValue) {
  if (debugValue) {
    console.warn('Debugging mode is ON. The output of every math call will ' + 'be downloaded to CPU and checked for NaNs. ' + 'This significantly impacts performance.');
  }
});
/** Whether we are in a browser (as opposed to, say, Node.js); delegates to isBrowser(). */
ENV$3.registerFlag('IS_BROWSER', function () {
  return isBrowser();
});
/** Whether we are running under Node.js (`process.versions.node` is defined). */
ENV$3.registerFlag('IS_NODE', function () {
  return typeof process !== 'undefined' && typeof process.versions !== 'undefined' && typeof process.versions.node !== 'undefined';
});
/** Whether this browser is Chrome (UA mentions Chrome and the vendor is Google Inc). */
ENV$3.registerFlag('IS_CHROME', function () {
  return typeof navigator !== 'undefined' && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor);
});
/** Whether this browser is Safari (UA mentions Safari and the vendor is Apple). */
ENV$3.registerFlag('IS_SAFARI', function () {
  return typeof navigator !== 'undefined' && navigator != null && navigator.userAgent != null && /Safari/.test(navigator.userAgent) && /Apple/.test(navigator.vendor);
});
/**
 * True when the environment is "production", where safety checks are
 * disabled to gain performance. Its value function returns false.
 */
ENV$3.registerFlag('PROD', function () {
  return false;
});
/**
 * Whether to sanity-check shape consistency when inferring a shape from
 * user-provided values while creating a new tensor. Follows the DEBUG flag.
 */
ENV$3.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () {
  return ENV$3.getBool('DEBUG');
});
/** Whether deprecation warnings are enabled. Its value function returns true. */
ENV$3.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () {
  return true;
});
/** True if running unit tests. Its value function returns false. */
ENV$3.registerFlag('IS_TEST', function () {
  return false;
});
/** Whether to check computation results for errors. Follows the DEBUG flag. */
ENV$3.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', function () {
  return ENV$3.getBool('DEBUG');
});
/** Whether the backend needs to wrap input to an ImageBitmap. Its value function returns false. */
ENV$3.registerFlag('WRAP_TO_IMAGEBITMAP', function () {
  return false;
});
/** Whether to enable canvas2d `willReadFrequently` for GPU backends. Its value function returns false. */
ENV$3.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', function () {
  return false;
});
/** Whether to use setTimeoutCustom. Its value function returns false. */
ENV$3.registerFlag('USE_SETTIMEOUTCUSTOM', function () {
  return false;
});
24773
24774 /**
24775 * @license
24776 * Copyright 2018 Google LLC. All Rights Reserved.
24777 * Licensed under the Apache License, Version 2.0 (the "License");
24778 * you may not use this file except in compliance with the License.
24779 * You may obtain a copy of the License at
24780 *
24781 * http://www.apache.org/licenses/LICENSE-2.0
24782 *
24783 * Unless required by applicable law or agreed to in writing, software
24784 * distributed under the License is distributed on an "AS IS" BASIS,
24785 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24786 * See the License for the specific language governing permissions and
24787 * limitations under the License.
24788 * =============================================================================
24789 */
24790 function inferShape(val, dtype) {
24791 var firstElem = val;
24792 if (isTypedArray(val)) {
24793 return dtype === 'string' ? [] : [val.length];
24794 }
24795 if (isWebGLData(val)) {
24796 var usedChannels = val.channels || 'RGBA';
24797 return [val.height, val.width * usedChannels.length];
24798 } else if (isWebGPUData(val)) {
24799 return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))];
24800 }
24801 if (!Array.isArray(val)) {
24802 return []; // Scalar.
24803 }
24804
24805 var shape = [];
24806 while (Array.isArray(firstElem) || isTypedArray(firstElem) && dtype !== 'string') {
24807 shape.push(firstElem.length);
24808 firstElem = firstElem[0];
24809 }
24810 if (Array.isArray(val) && env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
24811 deepAssertShapeConsistency(val, shape, []);
24812 }
24813 return shape;
24814 }
24815 function deepAssertShapeConsistency(val, shape, indices) {
24816 indices = indices || [];
24817 if (!Array.isArray(val) && !isTypedArray(val)) {
24818 assert$1(shape.length === 0, function () {
24819 return "Element arr[".concat(indices.join(']['), "] is a primitive, ") + "but should be an array/TypedArray of ".concat(shape[0], " elements");
24820 });
24821 return;
24822 }
24823 assert$1(shape.length > 0, function () {
24824 return "Element arr[".concat(indices.join(']['), "] should be a primitive, ") + "but is an array of ".concat(val.length, " elements");
24825 });
24826 assert$1(val.length === shape[0], function () {
24827 return "Element arr[".concat(indices.join(']['), "] should have ").concat(shape[0], " ") + "elements, but has ".concat(val.length, " elements");
24828 });
24829 var subShape = shape.slice(1);
24830 for (var i = 0; i < val.length; ++i) {
24831 deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
24832 }
24833 }
24834 function assertDtype(expectedDtype, actualDType, argName, functionName) {
24835 if (expectedDtype === 'string_or_numeric') {
24836 return;
24837 }
24838 if (expectedDtype == null) {
24839 throw new Error("Expected dtype cannot be null.");
24840 }
24841 if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || expectedDtype === 'numeric' && actualDType === 'string') {
24842 throw new Error("Argument '".concat(argName, "' passed to '").concat(functionName, "' must ") + "be ".concat(expectedDtype, " tensor, but got ").concat(actualDType, " tensor"));
24843 }
24844 }
24845 function convertToTensor(x, argName, functionName) {
24846 var parseAsDtype = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'numeric';
24847 if (x instanceof getGlobalTensorClass()) {
24848 assertDtype(parseAsDtype, x.dtype, argName, functionName);
24849 return x;
24850 }
24851 var inferredDtype = inferDtype(x);
24852 // If the user expects a bool/int/float, use that info to update the
24853 // inferredDtype when it is not a string.
24854 if (inferredDtype !== 'string' && ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
24855 inferredDtype = parseAsDtype;
24856 }
24857 assertDtype(parseAsDtype, inferredDtype, argName, functionName);
24858 if (x == null || !isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && typeof x !== 'boolean' && typeof x !== 'string') {
24859 var type = x == null ? 'null' : x.constructor.name;
24860 throw new Error("Argument '".concat(argName, "' passed to '").concat(functionName, "' must be a ") + "Tensor or TensorLike, but got '".concat(type, "'"));
24861 }
24862 var inferredShape = inferShape(x, inferredDtype);
24863 if (!isTypedArray(x) && !Array.isArray(x)) {
24864 x = [x];
24865 }
24866 var skipTypedArray = true;
24867 var values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype) : flatten$2(x, [], skipTypedArray);
24868 return ENGINE.makeTensor(values, inferredShape, inferredDtype);
24869 }
24870 function convertToTensorArray(arg, argName, functionName) {
24871 var parseAsDtype = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'numeric';
24872 if (!Array.isArray(arg)) {
24873 throw new Error("Argument ".concat(argName, " passed to ").concat(functionName, " must be a ") + '`Tensor[]` or `TensorLike[]`');
24874 }
24875 var tensors = arg;
24876 return tensors.map(function (t, i) {
24877 return convertToTensor(t, "".concat(argName, "[").concat(i, "]"), functionName, parseAsDtype);
24878 });
24879 }
24880
24881 /**
24882 * @license
24883 * Copyright 2018 Google LLC. All Rights Reserved.
24884 * Licensed under the Apache License, Version 2.0 (the "License");
24885 * you may not use this file except in compliance with the License.
24886 * You may obtain a copy of the License at
24887 *
24888 * http://www.apache.org/licenses/LICENSE-2.0
24889 *
24890 * Unless required by applicable law or agreed to in writing, software
24891 * distributed under the License is distributed on an "AS IS" BASIS,
24892 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24893 * See the License for the specific language governing permissions and
24894 * limitations under the License.
24895 * =============================================================================
24896 */
24897 var OP_SCOPE_SUFFIX = '__op';
24898 /**
24899 * Used for wrapping functions that perform math operations on
24900 * Tensors. The function will be wrapped in a named scope that cleans all
24901 * memory usage after the function is done.
24902 */
24903 function op(f) {
24904 var keys = Object.keys(f);
24905 if (keys.length !== 1) {
24906 throw new Error("Please provide an object with a single key " + "(operation name) mapping to a function. Got an object with " + "".concat(keys.length, " keys."));
24907 }
24908 var opName = keys[0];
24909 var fn = f[opName];
24910 // Strip the underscore from the end of the function name.
24911 if (opName.endsWith('_')) {
24912 opName = opName.substring(0, opName.length - 1);
24913 }
24914 // add an __op suffix to distinguish ops from kernels in tf.profile
24915 opName = opName + OP_SCOPE_SUFFIX;
24916 // tslint:disable-next-line:no-any
24917 var f2 = function f2() {
24918 ENGINE.startScope(opName);
24919 try {
24920 var result = fn.apply(void 0, arguments);
24921 if (isPromise(result)) {
24922 console.error('Cannot return a Promise inside of tidy.');
24923 }
24924 ENGINE.endScope(result);
24925 return result;
24926 } catch (ex) {
24927 ENGINE.endScope(null);
24928 throw ex;
24929 }
24930 };
24931 Object.defineProperty(f2, 'name', {
24932 value: opName,
24933 configurable: true
24934 });
24935 // tslint:disable-next-line:no-any
24936 return f2;
24937 }
24938
24939 /**
24940 * @license
24941 * Copyright 2020 Google LLC. All Rights Reserved.
24942 * Licensed under the Apache License, Version 2.0 (the "License");
24943 * you may not use this file except in compliance with the License.
24944 * You may obtain a copy of the License at
24945 *
24946 * http://www.apache.org/licenses/LICENSE-2.0
24947 *
24948 * Unless required by applicable law or agreed to in writing, software
24949 * distributed under the License is distributed on an "AS IS" BASIS,
24950 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24951 * See the License for the specific language governing permissions and
24952 * limitations under the License.
24953 * =============================================================================
24954 */
24955 /**
24956 * Converts two real numbers to a complex number.
24957 *
24958 * Given a tensor `real` representing the real part of a complex number, and a
24959 * tensor `imag` representing the imaginary part of a complex number, this
24960 * operation returns complex numbers elementwise of the form [r0, i0, r1, i1],
24961 * where r represents the real part and i represents the imag part.
24962 *
24963 * The input tensors real and imag must have the same shape.
24964 *
24965 * ```js
24966 * const real = tf.tensor1d([2.25, 3.25]);
24967 * const imag = tf.tensor1d([4.75, 5.75]);
24968 * const complex = tf.complex(real, imag);
24969 *
24970 * complex.print();
24971 * ```
24972 *
24973 * @doc {heading: 'Tensors', subheading: 'Creation'}
24974 */
24975 function complex_(real, imag) {
24976 var $real = convertToTensor(real, 'real', 'complex');
24977 var $imag = convertToTensor(imag, 'imag', 'complex');
24978 assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, ".concat($real.shape, " and ").concat($imag.shape, ", ") + "must match in call to tf.complex().");
24979 var inputs = {
24980 real: $real,
24981 imag: $imag
24982 };
24983 return ENGINE.runKernel(Complex, inputs);
24984 }
24985 var complex$2 = /* @__PURE__ */op({
24986 complex_: complex_
24987 });
24988
24989 /**
24990 * @license
24991 * Copyright 2018 Google LLC. All Rights Reserved.
24992 * Licensed under the Apache License, Version 2.0 (the "License");
24993 * you may not use this file except in compliance with the License.
24994 * You may obtain a copy of the License at
24995 *
24996 * http://www.apache.org/licenses/LICENSE-2.0
24997 *
24998 * Unless required by applicable law or agreed to in writing, software
24999 * distributed under the License is distributed on an "AS IS" BASIS,
25000 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25001 * See the License for the specific language governing permissions and
25002 * limitations under the License.
25003 * =============================================================================
25004 */
25005 /** This is shared code across all tensor creation methods. */
25006 function makeTensor(values, shape, inferredShape, dtype) {
25007 if (dtype == null) {
25008 dtype = inferDtype(values);
25009 } else if (dtype === 'complex64') {
25010 throw new Error("Cannot construct a complex64 tensor directly. " + "Please use tf.complex(real, imag).");
25011 }
25012 if (isWebGPUData(values) || isWebGLData(values)) {
25013 if (dtype !== 'float32' && dtype !== 'int32') {
25014 throw new Error("Creating tensor from GPU data only supports " + "'float32'|'int32' dtype, while the dtype is ".concat(dtype, "."));
25015 }
25016 return ENGINE.backend.createTensorFromGPUData(values, shape || inferredShape, dtype);
25017 }
25018 if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== 'number' && typeof values !== 'boolean' && typeof values !== 'string') {
25019 throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + 'an array of numbers/booleans/strings, or a TypedArray');
25020 }
25021 // Verify that the shape matches the inferred shape.
25022 if (shape != null) {
25023 assertNonNegativeIntegerDimensions(shape);
25024 var providedSize = sizeFromShape(shape);
25025 var inferredSize = sizeFromShape(inferredShape);
25026 assert$1(providedSize === inferredSize, function () {
25027 return "Based on the provided shape, [".concat(shape, "], the tensor should have ") + "".concat(providedSize, " values but has ").concat(inferredSize);
25028 });
25029 for (var i = 0; i < inferredShape.length; ++i) {
25030 var inferred = inferredShape[i];
25031 var flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true;
25032 assert$1(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () {
25033 return "Error creating a new Tensor. Inferred shape " + "(".concat(inferredShape, ") does not match the provided ") + "shape (".concat(shape, "). ");
25034 });
25035 }
25036 }
25037 if (!isTypedArray(values) && !Array.isArray(values)) {
25038 values = [values];
25039 }
25040 shape = shape || inferredShape;
25041 values = dtype !== 'string' ? toTypedArray(values, dtype) : flatten$2(values, [], true);
25042 return ENGINE.makeTensor(values, shape, dtype);
25043 }
25044
25045 /**
25046 * @license
25047 * Copyright 2018 Google LLC. All Rights Reserved.
25048 * Licensed under the Apache License, Version 2.0 (the "License");
25049 * you may not use this file except in compliance with the License.
25050 * You may obtain a copy of the License at
25051 *
25052 * http://www.apache.org/licenses/LICENSE-2.0
25053 *
25054 * Unless required by applicable law or agreed to in writing, software
25055 * distributed under the License is distributed on an "AS IS" BASIS,
25056 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25057 * See the License for the specific language governing permissions and
25058 * limitations under the License.
25059 * =============================================================================
25060 */
25061 /**
25062 * Creates a `tf.Tensor` with the provided values, shape and dtype.
25063 *
25064 * ```js
25065 * // Pass an array of values to create a vector.
25066 * tf.tensor([1, 2, 3, 4]).print();
25067 * ```
25068 *
25069 * ```js
25070 * // Pass a nested array of values to make a matrix or a higher
25071 * // dimensional tensor.
25072 * tf.tensor([[1, 2], [3, 4]]).print();
25073 * ```
25074 *
25075 * ```js
25076 * // Pass a flat array and specify a shape yourself.
25077 * tf.tensor([1, 2, 3, 4], [2, 2]).print();
25078 * ```
25079 *
25080 * ```js
25081 * // Pass a `WebGLData` object and specify a shape yourself.
25082 *
25083 * // This makes it possible for TF.js applications to avoid GPU / CPU sync.
25084 * // For example, if your application includes a preprocessing step on the GPU,
25085 * // you could upload the GPU output directly to TF.js, rather than first
25086 * // downloading the values.
25087 *
25088 * // Example for WebGL2:
25089 * if (tf.findBackend('custom-webgl') == null) {
25090 * const customCanvas = document.createElement('canvas');
25091 * const customBackend = new tf.MathBackendWebGL(customCanvas);
25092 * tf.registerBackend('custom-webgl', () => customBackend);
25093 * }
25094 * const savedBackend = tf.getBackend();
25095 * await tf.setBackend('custom-webgl');
25096 * const gl = tf.backend().gpgpu.gl;
25097 * const texture = gl.createTexture();
25098 * const tex2d = gl.TEXTURE_2D;
25099 * const width = 2;
25100 * const height = 2;
25101 *
25102 * gl.bindTexture(tex2d, texture);
25103 * gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
25104 * gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
25105 * gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
25106 * gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
25107 * gl.texImage2D(
25108 * tex2d, 0, gl.RGBA32F, // internalFormat
25109 * width, height, 0,
25110 * gl.RGBA, // textureFormat
25111 * gl.FLOAT, // textureType
25112 * new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
25113 * );
25114 *
25115 * // Currently, the `texture` has 4 pixels:
25116 * // Pixel0 is {R:0, G:1, B:2, A:3}
25117 * // Pixel1 is {R:4, G:5, B:6, A:7}
25118 * // Pixel2 is {R:8, G:9, B:10, A:11}
25119 * // Pixel3 is {R:12, G:13, B:14, A:15}
25120 *
25121 * const logicalShape = [height * width * 2];
25122 * const a = tf.tensor({texture, height, width, channels: 'BR'}, logicalShape);
25123 * a.print();
25124 * // Tensor value will be [2, 0, 6, 4, 10, 8, 14, 12], since [2, 0] is the
25125 * // values of 'B' and 'R' channels of Pixel0, [6, 4] is the values of 'B' and
25126 * 'R'
25127 * // channels of Pixel1...
25128 *
25129 * // For postprocessing on the GPU, it's possible to retrieve the texture
25130 * // backing any tensor by calling the tensor's `dataToGPU` method like
25131 * // so:
25132 *
25133 * const tex = a.dataToGPU();
25134 * await tf.setBackend(savedBackend);
25135 * ```
25136 *
25137 * ```js
25138 * // Pass a `WebGPUData` object and specify a shape yourself.
25139 *
25140 * // This makes it possible for TF.js applications to avoid GPU / CPU sync.
25141 * // For example, if your application includes a preprocessing step on the GPU,
25142 * // you could upload the GPU output directly to TF.js, rather than first
25143 * // downloading the values. Unlike WebGL, this optionally supports zero copy
25144 * // by WebGPUData.zeroCopy. When zeroCopy is false or undefined(default), this
25145 * // passing GPUBuffer can be destroyed after tensor is created. When zeroCopy
25146 * // is true, this GPUBuffer is bound directly by the tensor, so do not destroy
25147 * // this GPUBuffer until all access is done.
25148 *
25149 * // Example for WebGPU:
25150 * function createGPUBufferFromData(device, data, dtype) {
25151 * const bytesPerElement = 4;
25152 * const sizeInBytes = data.length * bytesPerElement;
25153 *
25154 * const gpuWriteBuffer = device.createBuffer({
25155 * mappedAtCreation: true,
25156 * size: sizeInBytes,
25157 * usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC
25158 * });
25159 * const arrayBuffer = gpuWriteBuffer.getMappedRange();
25160 * if (dtype === 'float32') {
25161 * new Float32Array(arrayBuffer).set(data);
25162 * } else if (dtype === 'int32') {
25163 * new Int32Array(arrayBuffer).set(data);
25164 * } else {
25165 * throw new Error(
25166 * `Creating tensor from GPUBuffer only supports` +
25167 * `'float32'|'int32' dtype, while the dtype is ${dtype}.`);
25168 * }
25169 * gpuWriteBuffer.unmap();
25170 *
25171 * const gpuReadBuffer = device.createBuffer({
25172 * mappedAtCreation: false,
25173 * size: sizeInBytes,
25174 * usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE |
25175 * GPUBufferUsage.COPY_SRC
25176 * });
25177 *
25178 * const copyEncoder = device.createCommandEncoder();
25179 * copyEncoder.copyBufferToBuffer(
25180 * gpuWriteBuffer, 0, gpuReadBuffer, 0, sizeInBytes);
25181 * const copyCommands = copyEncoder.finish();
25182 * device.queue.submit([copyCommands]);
25183 * gpuWriteBuffer.destroy();
25184 * return gpuReadBuffer;
25185 * }
25186 *
25187 * const savedBackend = tf.getBackend();
25188 * await tf.setBackend('webgpu').catch(
25189 * () => {throw new Error(
25190 * 'Failed to use WebGPU backend. Please use Chrome Canary to run.')});
25191 * const dtype = 'float32';
25192 * const device = tf.backend().device;
25193 * const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
25194 * const bData = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4];
25195 * const expected = [2, 4, 6, 8, 6, 8, 10, 12, 10, 12, 14, 16, 14, 16, 18, 20];
25196 * const aBuffer = createGPUBufferFromData(device, aData, dtype);
25197 * const shape = [aData.length];
25198 * // To use zeroCopy, use {buffer: aBuffer, zeroCopy: true} instead and destroy
25199 * // aBuffer untill all access is done.
25200 * const a = tf.tensor({buffer: aBuffer}, shape, dtype);
25201 * const b = tf.tensor(bData, shape, dtype);
25202 * const result = tf.add(a, b);
25203 * result.print();
25204 * a.dispose();
25205 * b.dispose();
25206 * result.dispose();
25207 * aBuffer.destroy();
25208 * await tf.setBackend(savedBackend);
25209 * ```
25210 * @param values The values of the tensor. Can be nested array of numbers,
25211 * or a flat array, or a `TypedArray`(At the moment it supports Uint8Array,
25212 * Uint8ClampedArray, Int32Array, Float32Array) data types, or a `WebGLData`
25213 * object, or a `WebGPUData` object. If the values are strings, they will be
25214 * encoded as utf-8 and kept as `Uint8Array[]`. If the values is a `WebGLData`
25215 * object, the dtype could only be 'float32' or 'int32' and the object has to
25216 * have: 1. texture, a `WebGLTexture`, the texture must share the same
25217 * `WebGLRenderingContext` with TFJS's WebGL backend (you could create a custom
25218 * WebGL backend from your texture's canvas) and the internal texture format
25219 * for the input texture must be floating point or normalized integer; 2.
25220 * height, the height of the texture; 3. width, the width of the texture; 4.
25221 * channels, a non-empty subset of 'RGBA', indicating the values of which
25222 * channels will be passed to the tensor, such as 'R' or 'BR' (The order of the
25223 * channels affect the order of tensor values. ). (If the values passed from
25224 * texture is less than the tensor size, zeros will be padded at the rear.). If
25225 * the values is a `WebGPUData` object, the dtype could only be 'float32' or
25226 * 'int32 and the object has to have: buffer, a `GPUBuffer`. The buffer must:
25227 * 1. share the same `GPUDevice` with TFJS's WebGPU backend; 2. buffer.usage
25228 * should at least support GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC; 3.
25229 * buffer.size should not be smaller than the byte size of tensor shape.
25230 * WebGPUData optionally supports zero copy by flag zeroCopy. When zeroCopy is
25231 * false or undefined(default),this passing GPUBuffer can be destroyed after
25232 * tensor is created. When zeroCopy is true, this GPUBuffer is bound directly
25233 * by the tensor, so do not destroy this GPUBuffer until all access is done.
25234 * @param shape The shape of the tensor. Optional. If not provided,
25235 * it is inferred from `values`.
25236 * @param dtype The data type.
25237 *
25238 * @doc {heading: 'Tensors', subheading: 'Creation'}
25239 */
25240 function tensor(values, shape, dtype) {
25241 var inferredShape = inferShape(values, dtype);
25242 return makeTensor(values, shape, inferredShape, dtype);
25243 }
25244
25245 /**
25246 * @license
25247 * Copyright 2018 Google LLC. All Rights Reserved.
25248 * Licensed under the Apache License, Version 2.0 (the "License");
25249 * you may not use this file except in compliance with the License.
25250 * You may obtain a copy of the License at
25251 *
25252 * http://www.apache.org/licenses/LICENSE-2.0
25253 *
25254 * Unless required by applicable law or agreed to in writing, software
25255 * distributed under the License is distributed on an "AS IS" BASIS,
25256 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25257 * See the License for the specific language governing permissions and
25258 * limitations under the License.
25259 * =============================================================================
25260 */
25261 /* Type definitions for exporting and importing of models. */
25262 /**
25263 * A map from Tensor dtype to number of bytes per element of the Tensor.
25264 */
25265 var DTYPE_VALUE_SIZE_MAP = {
25266 'float32': 4,
25267 'float16': 2,
25268 'int32': 4,
25269 'uint16': 2,
25270 'uint8': 1,
25271 'bool': 1,
25272 'complex64': 8
25273 };
25274
25275 /**
25276 * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating
25277 * a large ArrayBuffer.
25278 *
25279 * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads
25280 * its weights as a list of (usually) 4MB ArrayBuffers and then slices the
25281 * weight tensors out of them. For small models, it's safe to concatenate all
25282 * the weight buffers into a single ArrayBuffer and then slice the weight
25283 * tensors out of it, but for large models, a different approach is needed.
25284 */
25285 var CompositeArrayBuffer = /*#__PURE__*/function () {
25286 function CompositeArrayBuffer(buffers) {
25287 _classCallCheck(this, CompositeArrayBuffer);
25288 this.shards = [];
25289 this.previousShardIndex = 0;
25290 if (buffers == null) {
25291 return;
25292 }
25293 // Normalize the `buffers` input to be `ArrayBuffer[]`.
25294 if (!(buffers instanceof Array)) {
25295 buffers = [buffers];
25296 }
25297 buffers = buffers.map(function (bufferOrTypedArray) {
25298 if (isTypedArray(bufferOrTypedArray)) {
25299 return bufferOrTypedArray.buffer;
25300 }
25301 return bufferOrTypedArray;
25302 });
25303 // Skip setting up shards if there are no buffers.
25304 if (buffers.length === 0) {
25305 return;
25306 }
25307 this.bufferUniformSize = buffers[0].byteLength;
25308 var start = 0;
25309 for (var i = 0; i < buffers.length; i++) {
25310 var buffer = buffers[i];
25311 // Check that all buffers except the last one have the same length.
25312 if (i !== buffers.length - 1 && buffer.byteLength !== this.bufferUniformSize) {
25313 // Unset the buffer uniform size, since the buffer sizes are not
25314 // uniform.
25315 this.bufferUniformSize = undefined;
25316 }
25317 // Create the shards, including their start and end points.
25318 var end = start + buffer.byteLength;
25319 this.shards.push({
25320 buffer: buffer,
25321 start: start,
25322 end: end
25323 });
25324 start = end;
25325 }
25326 // Set the byteLength
25327 if (this.shards.length === 0) {
25328 this.byteLength = 0;
25329 }
25330 this.byteLength = this.shards[this.shards.length - 1].end;
25331 }
25332 _createClass(CompositeArrayBuffer, [{
25333 key: "slice",
25334 value: function slice() {
25335 var start = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 0;
25336 var end = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : this.byteLength;
25337 // If there are no shards, then the CompositeArrayBuffer was initialized
25338 // with no data.
25339 if (this.shards.length === 0) {
25340 return new ArrayBuffer(0);
25341 }
25342 // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
25343 start = isNaN(Number(start)) ? 0 : start;
25344 end = isNaN(Number(end)) ? 0 : end;
25345 // Fix the bounds to within the array.
25346 start = Math.max(0, start);
25347 end = Math.min(this.byteLength, end);
25348 if (end <= start) {
25349 return new ArrayBuffer(0);
25350 }
25351 var startShardIndex = this.findShardForByte(start);
25352 if (startShardIndex === -1) {
25353 // This should not happen since the start and end indices are always
25354 // within 0 and the composite array's length.
25355 throw new Error("Could not find start shard for byte ".concat(start));
25356 }
25357 var size = end - start;
25358 var outputBuffer = new ArrayBuffer(size);
25359 var outputArray = new Uint8Array(outputBuffer);
25360 var sliced = 0;
25361 for (var i = startShardIndex; i < this.shards.length; i++) {
25362 var shard = this.shards[i];
25363 var globalStart = start + sliced;
25364 var localStart = globalStart - shard.start;
25365 var outputStart = sliced;
25366 var globalEnd = Math.min(end, shard.end);
25367 var localEnd = globalEnd - shard.start;
25368 var outputSlice = new Uint8Array(shard.buffer, localStart, localEnd - localStart);
25369 outputArray.set(outputSlice, outputStart);
25370 sliced += outputSlice.length;
25371 if (end < shard.end) {
25372 break;
25373 }
25374 }
25375 return outputBuffer;
25376 }
25377 /**
25378 * Get the index of the shard that contains the byte at `byteIndex`.
25379 */
25380 }, {
25381 key: "findShardForByte",
25382 value: function findShardForByte(byteIndex) {
25383 if (this.shards.length === 0 || byteIndex < 0 || byteIndex >= this.byteLength) {
25384 return -1;
25385 }
25386 // If the buffers have a uniform size, compute the shard directly.
25387 if (this.bufferUniformSize != null) {
25388 this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize);
25389 return this.previousShardIndex;
25390 }
25391 // If the buffers don't have a uniform size, we need to search for the
25392 // shard. That means we need a function to check where the byteIndex lies
25393 // relative to a given shard.
25394 function check(shard) {
25395 if (byteIndex < shard.start) {
25396 return -1;
25397 }
25398 if (byteIndex >= shard.end) {
25399 return 1;
25400 }
25401 return 0;
25402 }
25403 // For efficiency, try the previous shard first.
25404 if (check(this.shards[this.previousShardIndex]) === 0) {
25405 return this.previousShardIndex;
25406 }
25407 // Otherwise, use a generic search function.
25408 // This should almost never end up being used in practice since the weight
25409 // entries should always be in order.
25410 var index = search(this.shards, check);
25411 if (index === -1) {
25412 return -1;
25413 }
25414 this.previousShardIndex = index;
25415 return this.previousShardIndex;
25416 }
25417 }], [{
25418 key: "join",
25419 value:
25420 /**
25421 * Concatenate a number of ArrayBuffers into one.
25422 *
25423 * @param buffers An array of ArrayBuffers to concatenate, or a single
25424 * ArrayBuffer.
25425 * @returns Result of concatenating `buffers` in order.
25426 */
25427 function join(buffers) {
25428 return new CompositeArrayBuffer(buffers).slice();
25429 }
25430 }]);
25431 return CompositeArrayBuffer;
25432 }();
25433 /**
25434 * Search for an element of a sorted array.
25435 *
25436 * @param sortedArray The sorted array to search
25437 * @param compare A function to compare the current value against the searched
25438 * value. Return 0 on a match, negative if the searched value is less than
25439 * the value passed to the function, and positive if the searched value is
25440 * greater than the value passed to the function.
25441 * @returns The index of the element, or -1 if it's not in the array.
25442 */
25443 function search(sortedArray, compare) {
25444 // Binary search
25445 var min = 0;
25446 var max = sortedArray.length;
25447 while (min <= max) {
25448 var middle = Math.floor((max - min) / 2) + min;
25449 var side = compare(sortedArray[middle]);
25450 if (side === 0) {
25451 return middle;
25452 } else if (side < 0) {
25453 max = middle;
25454 } else {
25455 min = middle + 1;
25456 }
25457 }
25458 return -1;
25459 }
25460
25461 /**
25462 * @license
25463 * Copyright 2018 Google LLC. All Rights Reserved.
25464 * Licensed under the Apache License, Version 2.0 (the "License");
25465 * you may not use this file except in compliance with the License.
25466 * You may obtain a copy of the License at
25467 *
25468 * http://www.apache.org/licenses/LICENSE-2.0
25469 *
25470 * Unless required by applicable law or agreed to in writing, software
25471 * distributed under the License is distributed on an "AS IS" BASIS,
25472 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25473 * See the License for the specific language governing permissions and
25474 * limitations under the License.
25475 * =============================================================================
25476 */
25477 /**
25478 * Enables production mode which disables correctness checks in favor of
25479 * performance.
25480 *
25481 * @doc {heading: 'Environment'}
25482 */
25483 function enableProdMode() {
25484 env().set('PROD', true);
25485 }
25486 /**
25487 * Enables debug mode which will log information about all executed kernels:
25488 * the elapsed time of the kernel execution, as well as the rank, shape, and
25489 * size of the output tensor.
25490 *
25491 * Debug mode will significantly slow down your application as it will
25492 * download the result of every operation to the CPU. This should not be used in
25493 * production. Debug mode does not affect the timing information of the kernel
25494 * execution as we do not measure download time in the kernel execution time.
25495 *
25496 * See also: `tf.profile`, `tf.memory`.
25497 *
25498 * @doc {heading: 'Environment'}
25499 */
25500 function enableDebugMode() {
25501 env().set('DEBUG', true);
25502 }
25503 /** Globally disables deprecation warnings */
25504 function disableDeprecationWarnings() {
25505 env().set('DEPRECATION_WARNINGS_ENABLED', false);
25506 console.warn("TensorFlow.js deprecation warnings have been disabled.");
25507 }
25508 /** Warn users about deprecated functionality. */
25509 function deprecationWarn(msg) {
25510 if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) {
25511 console.warn(msg + ' You can disable deprecation warnings with ' + 'tf.disableDeprecationWarnings().');
25512 }
25513 }
25514 setDeprecationWarningFn(deprecationWarn);
25515 /**
25516 * Dispose all variables kept in backend engine.
25517 *
25518 * @doc {heading: 'Environment'}
25519 */
25520 function disposeVariables() {
25521 ENGINE.disposeVariables();
25522 }
25523 /**
25524 * It returns the global engine that keeps track of all tensors and backends.
25525 *
25526 * @doc {heading: 'Environment'}
25527 */
25528 function engine() {
25529 return ENGINE;
25530 }
25531 /**
25532 * Returns memory info at the current time in the program. The result is an
25533 * object with the following properties:
25534 *
25535 * - `numBytes`: Number of bytes allocated (undisposed) at this time.
25536 * - `numTensors`: Number of unique tensors allocated.
25537 * - `numDataBuffers`: Number of unique data buffers allocated
25538 * (undisposed) at this time, which is ≤ the number of tensors
25539 * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
25540 * data buffer with `a`).
25541 * - `unreliable`: True if the memory usage is unreliable. See `reasons` when
25542 * `unreliable` is true.
25543 * - `reasons`: `string[]`, reasons why the memory is unreliable, present if
25544 * `unreliable` is true.
25545 *
25546 * WebGL Properties:
25547 * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at
25548 * this time.
25549 *
25550 * @doc {heading: 'Performance', subheading: 'Memory'}
25551 */
25552 function memory() {
25553 return ENGINE.memory();
25554 }
25555 /**
25556 * Executes the provided function `f()` and returns a promise that resolves
25557 * with information about the function's memory use:
25558 * - `newBytes`: the number of new bytes allocated
25559 * - `newTensors`: the number of new tensors created
25560 * - `peakBytes`: the peak number of bytes allocated
25561 * - `kernels`: an array of objects for each kernel involved that reports
25562 * their input and output shapes, number of bytes used, and number of new
25563 * tensors created.
25564 * - `kernelNames`: an array of unique strings with just the names of the
25565 * kernels in the `kernels` array.
25566 *
25567 * ```js
25568 * const profile = await tf.profile(() => {
25569 * const x = tf.tensor1d([1, 2, 3]);
25570 * let x2 = x.square();
25571 * x2.dispose();
25572 * x2 = x.square();
25573 * x2.dispose();
25574 * return x;
25575 * });
25576 *
25577 * console.log(`newBytes: ${profile.newBytes}`);
25578 * console.log(`newTensors: ${profile.newTensors}`);
25579 * console.log(`byte usage over all kernels: ${profile.kernels.map(k =>
25580 * k.totalBytesSnapshot)}`);
25581 * ```
25582 *
25583 *
25584 * @doc {heading: 'Performance', subheading: 'Profile'}
25585 */
25586 function profile(f) {
25587 return ENGINE.profile(f);
25588 }
25589 /**
25590 * Executes the provided function `fn` and after it is executed, cleans up all
25591 * intermediate tensors allocated by `fn` except those returned by `fn`.
25592 * `fn` must not return a Promise (async functions not allowed). The returned
25593 * result can be a complex object.
25594 *
25595 * Using this method helps avoid memory leaks. In general, wrap calls to
25596 * operations in `tf.tidy` for automatic memory cleanup.
25597 *
25598 * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to
25599 * dispose variables, please use `tf.disposeVariables` or call dispose()
25600 * directly on variables.
25601 *
25602 * ```js
25603 * // y = 2 ^ 2 + 1
25604 * const y = tf.tidy(() => {
25605 * // a, b, and one will be cleaned up when the tidy ends.
25606 * const one = tf.scalar(1);
25607 * const a = tf.scalar(2);
25608 * const b = a.square();
25609 *
25610 * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
25611 *
25612 * // The value returned inside the tidy function will return
25613 * // through the tidy, in this case to the variable y.
25614 * return b.add(one);
25615 * });
25616 *
25617 * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
25618 * y.print();
25619 * ```
25620 *
25621 * @param nameOrFn The name of the closure, or the function to execute.
25622 * If a name is provided, the 2nd argument should be the function.
25623 * If debug mode is on, the timing and the memory usage of the function
25624 * will be tracked and displayed on the console using the provided name.
25625 * @param fn The function to execute.
25626 *
25627 * @doc {heading: 'Performance', subheading: 'Memory'}
25628 */
25629 function tidy(nameOrFn, fn) {
25630 return ENGINE.tidy(nameOrFn, fn);
25631 }
/**
 * Disposes every `tf.Tensor` reachable from `container`.
 *
 * @param container A `tf.Tensor`, or an object that directly holds tensors
 *     (for example a `Tensor[]` or a `{key: Tensor, ...}` map). If it is not
 *     a tensor and holds none, this is a no-op. Any object is safe to pass
 *     except a `Promise`, which is not supported.
 *
 * @doc {heading: 'Performance', subheading: 'Memory'}
 */
function dispose(container) {
  var disposeOne = function (t) {
    return t.dispose();
  };
  getTensorsInContainer(container).forEach(disposeOne);
}
/**
 * Marks a `tf.Tensor` created inside a `tf.tidy` so the tidy does not dispose
 * it automatically when it ends.
 *
 * ```js
 * let b;
 * const y = tf.tidy(() => {
 *   const one = tf.scalar(1);
 *   const a = tf.scalar(2);
 *
 *   // b will not be cleaned up by the tidy. a and one will be cleaned up
 *   // when the tidy ends.
 *   b = tf.keep(a.square());
 *
 *   console.log('numTensors (in tidy): ' + tf.memory().numTensors);
 *
 *   // The value returned inside the tidy function will return
 *   // through the tidy, in this case to the variable y.
 *   return b.add(one);
 * });
 *
 * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
 * console.log('y:');
 * y.print();
 * console.log('b:');
 * b.print();
 * ```
 *
 * @param result The tensor that should survive the enclosing tidy.
 *
 * @doc {heading: 'Performance', subheading: 'Memory'}
 */
function keep(result) {
  var engine = ENGINE;
  return engine.keep(result);
}
/**
 * Runs `f()` and returns a promise of timing information about the run.
 *
 * The resolved object has these properties:
 *
 * - `wallMs`: wall-clock execution time.
 * - `kernelMs`: kernel execution time, excluding data transfer. On the WebGL
 *   backend without the query timer extension, this is an error object
 *   instead.
 * - On `WebGL`, additionally:
 *    - `uploadWaitMs`: CPU time blocked on texture uploads.
 *    - `downloadWaitMs`: CPU time blocked on texture downloads (readPixels).
 *
 * ```js
 * const x = tf.randomNormal([20, 20]);
 * const time = await tf.time(() => x.matMul(x));
 *
 * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
 * ```
 *
 * @param f The function to run and time.
 *
 * @doc {heading: 'Performance', subheading: 'Timing'}
 */
function time(f) {
  var engine = ENGINE;
  return engine.time(f);
}
/**
 * Selects the backend (cpu, webgl, wasm, ...) that creates tensors and runs
 * operations on them. The returned promise resolves to a boolean telling
 * whether the backend initialized successfully.
 *
 * The current backend (if any) and the tensors associated with it are
 * disposed. A fresh backend is always initialized, even when it has the same
 * type as the previous one.
 *
 * @param backendName Name of the backend. Currently `'webgl'|'cpu'` in the
 *     browser, `'tensorflow'` under node.js (requires tfjs-node), and `'wasm'`
 *     (requires tfjs-backend-wasm).
 *
 * @doc {heading: 'Backends'}
 */
function setBackend$1(backendName) {
  var engine = ENGINE;
  return engine.setBackend(backendName);
}
/**
 * Returns a promise that settles once the selected backend (or, if none was
 * selected, the highest-priority one) has finished initializing. Await it
 * when using a backend whose initialization is asynchronous.
 *
 * @doc {heading: 'Backends'}
 */
function ready() {
  var engine = ENGINE;
  return engine.ready();
}
/**
 * Returns the name of the active backend (cpu, webgl, ...), i.e. the backend
 * that creates tensors and runs operations on them.
 *
 * @doc {heading: 'Backends'}
 */
function getBackend$1() {
  var engine = ENGINE;
  var activeName = engine.backendName;
  return activeName;
}
/**
 * Unregisters the backend called `name` together with its factory.
 *
 * @doc {heading: 'Backends'}
 */
function removeBackend(name) {
  var engine = ENGINE;
  engine.removeBackend(name);
}
/**
 * Looks up the backend instance registered as `name`. Yields null when no
 * such name is registered or when its registration is still in progress.
 */
function findBackend(name) {
  var engine = ENGINE;
  return engine.findBackend(name);
}
/**
 * Looks up the factory registered as `name`: a function that builds a new
 * backend each time it is called. Yields null for unknown names.
 */
function findBackendFactory(name) {
  var engine = ENGINE;
  return engine.findBackendFactory(name);
}
/**
 * Registers a backend globally. This is meant to happen as a side effect of
 * importing a module (e.g. `backend_webgl.ts`) and enables modular builds
 * (e.g. a custom tfjs bundle with only webgl support).
 *
 * @param name Name to register the backend under.
 * @param factory Backend factory. When called it returns a backend instance
 *     or a promise of one.
 * @param priority Optional third argument (default 1); higher is more
 *     important. When several backends are registered, the priority decides
 *     which one is best.
 * @return False if a backend is already registered under `name`, true
 *     otherwise.
 *
 * @doc {heading: 'Backends'}
 */
function registerBackend(name, factory) {
  // `priority` is read from `arguments` so the declared arity stays 2.
  var priority = arguments[2] === undefined ? 1 : arguments[2];
  var engine = ENGINE;
  return engine.registerBackend(name, factory, priority);
}
/**
 * Returns the active backend. If no backend has been initialized yet, this
 * tries to initialize the best one. It throws when the highest-priority
 * backend initializes asynchronously; in that case run
 * 'await tf.ready()' before any other code.
 *
 * @doc {heading: 'Backends'}
 */
function backend$1() {
  var engine = ENGINE;
  var active = engine.backend;
  return active;
}
/**
 * Installs the global platform implementation.
 *
 * @param platformName Name identifying the platform.
 * @param platform The platform implementation.
 */
function setPlatform(platformName, platform) {
  var environment = env();
  environment.setPlatform(platformName, platform);
}
25811
/**
 * Byte size of the length prefix written before every string value in an
 * encoded weight buffer: one 32-bit unsigned integer.
 */
var NUM_BYTES_STRING_LENGTH = Uint32Array.BYTES_PER_ELEMENT;
/**
 * Serialises tensors into one flat ArrayBuffer plus a list of
 * `WeightsManifestEntry` specs that describe how to read it back.
 *
 * No sharding is performed. `decodeWeights` is the inverse operation.
 *
 * @param tensors Either a map ("dict") from names to tensors, or an array of
 *     `{name, tensor}` entries.
 * @param group Optional group name attached to every generated spec.
 * @returns A `Promise` of `{data, specs}`:
 *   - `data`: a flat `ArrayBuffer` with the binary values of all tensors
 *     concatenated in spec order.
 *   - `specs`: an `Array` of `WeightsManifestEntry`s carrying each tensor's
 *     name, `dtype` and shape (plus `group` when given).
 * @throws Error: on unsupported tensor `dtype`.
 */
function encodeWeights(tensors, group) {
  var forwardedArgs = arguments;
  return _encodeWeights.apply(this, forwardedArgs);
}
// NOTE(review): the JSDoc block below documents `decodeWeights` (defined right
// after this function), not `_encodeWeights`; it presumably ended up here as a
// transpilation artifact.
/**
 * Decode flat ArrayBuffer as weights.
 *
 * This function does not handle sharding.
 *
 * This function is the reverse of `encodeWeights`.
 *
 * @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the
 *   binary values of the tensors concatenated in the order specified in
 *   `specs`.
 * @param specs Specifications of the names, dtypes and shapes of the tensors
 *   whose values are encoded in `weightData`.
 * @return A map from tensor name to tensor value, with the names corresponding
 *   to names in `specs`.
 * @throws Error, if any of the tensors has unsupported dtype.
 */
// Implementation behind `encodeWeights`. On its first call it replaces itself
// with the transpiled async implementation (`_asyncToGenerator(...)`) and
// forwards the call, so later calls go straight to that implementation.
//
// The async implementation builds one spec and one pending data read per
// tensor (in input order). It then awaits all reads and resolves to
// `{data: concatenateTypedArrays(values), specs: specs}`.
function _encodeWeights() {
  _encodeWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(tensors, group) {
    var specs, dataPromises, names, _loop, i, tensorValues;
    return _regeneratorRuntime().wrap(function _callee2$(_context3) {
      while (1) switch (_context3.prev = _context3.next) {
        case 0:
          // TODO(adarob, cais): Support quantization.
          specs = [];
          dataPromises = [];
          // `tensors` is either an array of {name, tensor} entries or a
          // name -> tensor map.
          names = Array.isArray(tensors) ? tensors.map(function (tensor) {
            return tensor.name;
          }) : Object.keys(tensors);
          // Per-tensor loop body (driven by states 5 and 7 below): validate the
          // dtype, build the spec, and queue the tensor's data read.
          _loop = /*#__PURE__*/_regeneratorRuntime().mark(function _loop() {
            var name, t, spec, utf8bytes;
            return _regeneratorRuntime().wrap(function _loop$(_context2) {
              while (1) switch (_context2.prev = _context2.next) {
                case 0:
                  name = names[i];
                  t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
                  if (!(t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && t.dtype !== 'string' && t.dtype !== 'complex64')) {
                    _context2.next = 4;
                    break;
                  }
                  throw new Error("Unsupported dtype in weight '".concat(name, "': ").concat(t.dtype));
                case 4:
                  spec = {
                    name: name,
                    shape: t.shape,
                    dtype: t.dtype
                  };
                  if (t.dtype === 'string') {
                    // Chain off `t.bytes()` instead of wrapping an async function
                    // in `new Promise(...)`. The executor's returned promise was
                    // ignored, so a failing `t.bytes()` left `utf8bytes` pending
                    // forever and hung the `Promise.all` below. Now the failure
                    // propagates as a rejection.
                    utf8bytes = Promise.resolve(t.bytes()).then(function (vals) {
                      // Per element: a 4-byte length prefix (Uint32Array,
                      // platform byte order) followed by the element's bytes.
                      var totalNumBytes = vals.reduce(function (p, c) {
                        return p + c.length;
                      }, 0) + NUM_BYTES_STRING_LENGTH * vals.length;
                      var bytes = new Uint8Array(totalNumBytes);
                      var offset = 0;
                      for (var k = 0; k < vals.length; k++) {
                        var val = vals[k];
                        var bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
                        bytes.set(bytesOfLength, offset);
                        offset += NUM_BYTES_STRING_LENGTH;
                        bytes.set(val, offset);
                        offset += val.length;
                      }
                      return bytes;
                    });
                    dataPromises.push(utf8bytes);
                  } else {
                    // Non-string tensors contribute their typed-array data as-is.
                    dataPromises.push(t.data());
                  }
                  if (group != null) {
                    spec.group = group;
                  }
                  specs.push(spec);
                case 8:
                case "end":
                  return _context2.stop();
              }
            }, _loop);
          });
          i = 0;
        case 5:
          if (!(i < names.length)) {
            _context3.next = 10;
            break;
          }
          return _context3.delegateYield(_loop(), "t0", 7);
        case 7:
          ++i;
          _context3.next = 5;
          break;
        case 10:
          // Promise.all keeps the values in spec order.
          _context3.next = 12;
          return Promise.all(dataPromises);
        case 12:
          tensorValues = _context3.sent;
          return _context3.abrupt("return", {
            data: concatenateTypedArrays(tensorValues),
            specs: specs
          });
        case 14:
        case "end":
          return _context3.stop();
      }
    }, _callee2);
  }));
  return _encodeWeights.apply(this, arguments);
}
/**
 * Decodes a flat weight buffer (or a list of buffers) into named tensors.
 *
 * Sharding is not handled. This is the inverse of `encodeWeights`.
 *
 * @param weightData An ArrayBuffer or an array of ArrayBuffers holding the
 *     tensors' bytes back to back, in the order given by `specs`.
 * @param specs Manifest entries (name, dtype, shape, ...) for each tensor.
 * @returns An object mapping each spec's `name` to its decoded tensor.
 * @throws Error if a spec has an unsupported dtype.
 */
function decodeWeights(weightData, specs) {
  // TODO(adarob, cais): Support quantization.
  var composite = new CompositeArrayBuffer(weightData);
  var decoded = {};
  var cursor = 0;
  // Slices relative to the start of the weight currently being decoded.
  var readRelative = function (start, end) {
    return composite.slice(cursor + start, cursor + end);
  };
  var it = _createForOfIteratorHelper(specs);
  var step;
  try {
    it.s();
    while (!(step = it.n()).done) {
      var spec = step.value;
      var length = getWeightBytelength(spec, readRelative);
      decoded[spec.name] = decodeWeight(spec, composite.slice(cursor, cursor + length));
      cursor += length;
    }
  } catch (err) {
    it.e(err);
  } finally {
    it.f();
  }
  return decoded;
}
/**
 * Computes how many bytes the weight described by `spec` occupies.
 *
 * @param spec Weight manifest entry (`shape`, `dtype`, optional
 *     `quantization`).
 * @param slice Callback `(start, end)` returning an ArrayBuffer of the bytes
 *     between those offsets, relative to the start of this weight. It is only
 *     consulted for string weights, whose size depends on per-element length
 *     prefixes.
 * @returns The weight's byte length.
 */
function getWeightBytelength(spec, slice) {
  var elementCount = sizeFromShape(spec.shape);
  if ('quantization' in spec) {
    return elementCount * DTYPE_VALUE_SIZE_MAP[spec.quantization.dtype];
  }
  if (spec.dtype !== 'string') {
    return elementCount * DTYPE_VALUE_SIZE_MAP[spec.dtype];
  }
  // String lengths are not static: walk the length-prefixed elements.
  var total = 0;
  var remaining = elementCount;
  while (remaining-- > 0) {
    var prefix = new Uint32Array(slice(total, total + NUM_BYTES_STRING_LENGTH));
    total += NUM_BYTES_STRING_LENGTH + prefix[0];
  }
  return total;
}
/**
 * Async counterpart of `getWeightBytelength`: resolves to the number of bytes
 * occupied by the weight described by `spec`.
 *
 * @param spec Weight manifest entry (`shape`, `dtype`, optional
 *     `quantization`).
 * @param slice Callback `(start, end)` returning the bytes between those
 *     offsets (relative to the start of this weight); its result is awaited,
 *     so it may return a Promise. It is only called for string weights, once
 *     per element, to read that element's 4-byte length prefix.
 * @returns A Promise of the byte length.
 */
25998 function getWeightBytelengthAsync(_x3, _x4) {
25999 return _getWeightBytelengthAsync.apply(this, arguments);
26000 }
// Lazily installs the transpiled async implementation on the first call
// (`_getWeightBytelengthAsync = _asyncToGenerator(...)`), then forwards to it.
26001 function _getWeightBytelengthAsync() {
26002 _getWeightBytelengthAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(spec, slice) {
26003 var size, bytesPerValue, quantization, byteLength, i;
26004 return _regeneratorRuntime().wrap(function _callee3$(_context4) {
26005 while (1) switch (_context4.prev = _context4.next) {
26006 case 0:
26007 size = sizeFromShape(spec.shape);
// Quantized weights: size * bytes per quantized value (continues at case 25).
26008 if (!('quantization' in spec)) {
26009 _context4.next = 6;
26010 break;
26011 }
26012 quantization = spec.quantization;
26013 bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
26014 _context4.next = 25;
26015 break;
26016 case 6:
// String weights: sum over elements of (4-byte prefix + the length stored in
// that prefix). Other dtypes continue at case 24.
26017 if (!(spec.dtype === 'string')) {
26018 _context4.next = 24;
26019 break;
26020 }
26021 // Can not statically determine string length.
26022 byteLength = 0;
26023 i = 0;
26024 case 9:
26025 if (!(i < size)) {
26026 _context4.next = 21;
26027 break;
26028 }
// Transpiled form of
//   byteLength += NUM_BYTES_STRING_LENGTH +
//       new Uint32Array(await slice(byteLength, byteLength + 4))[0];
// t0..t4 hold the operands across the await.
26029 _context4.t0 = byteLength;
26030 _context4.t1 = NUM_BYTES_STRING_LENGTH;
26031 _context4.t2 = Uint32Array;
26032 _context4.next = 15;
26033 return slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH);
26034 case 15:
26035 _context4.t3 = _context4.sent;
26036 _context4.t4 = new _context4.t2(_context4.t3)[0];
26037 byteLength = _context4.t0 += _context4.t1 + _context4.t4;
26038 case 18:
26039 i++;
26040 _context4.next = 9;
26041 break;
26042 case 21:
26043 return _context4.abrupt("return", byteLength);
26044 case 24:
// Unquantized fixed-width dtypes.
26045 bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];
26046 case 25:
26047 return _context4.abrupt("return", size * bytesPerValue);
26048 case 26:
26049 case "end":
26050 return _context4.stop();
26051 }
26052 }, _callee3);
26053 }));
26054 return _getWeightBytelengthAsync.apply(this, arguments);
26055 }
/**
 * Decodes the bytes of a single weight into a tensor.
 *
 * @param spec Weight manifest entry: `name`, `dtype`, `shape` and optionally
 *     `quantization` (`{dtype, min?, scale?}`). Supported quantization dtypes
 *     are 'uint8' and 'uint16' (affine: value = q * scale + min, for float32 or
 *     int32 weights) and 'float16' (float32 weights only).
 * @param byteBuffer ArrayBuffer holding exactly this weight's bytes.
 * @returns The decoded tensor. For `complex64`, the interleaved
 *     (real, imaginary) float32 pairs are split into two temporary float32
 *     tensors, combined with `complex$2`, and the temporaries are disposed.
 * @throws Error if the quantization metadata is incomplete or unknown, or if
 *     the dtype (or the dtype/quantization combination) is unsupported.
 */
function decodeWeight(spec, byteBuffer) {
  var name = spec.name;
  var dtype = spec.dtype;
  var shape = spec.shape;
  // Evaluated up front so a malformed shape fails before any decoding; only
  // the string branch actually needs the element count.
  var size = sizeFromShape(shape);
  var values;
  if ('quantization' in spec) {
    var quantization = spec.quantization;
    // Validate the quantization metadata before reading any bytes.
    if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
      if (!('min' in quantization && 'scale' in quantization)) {
        throw new Error("Weight ".concat(spec.name, " with quantization ").concat(quantization.dtype, " ") + "doesn't have corresponding metadata min and scale.");
      }
    } else if (quantization.dtype === 'float16') {
      if (dtype !== 'float32') {
        throw new Error("Weight ".concat(spec.name, " is quantized with ").concat(quantization.dtype, " ") + "which only supports weights of type float32 not ".concat(dtype, "."));
      }
    } else {
      throw new Error("Weight ".concat(spec.name, " has unknown ") + "quantization dtype ".concat(quantization.dtype, ". ") + "Supported quantization dtypes are: " + "'uint8', 'uint16', and 'float16'.");
    }
    // uint8 is read byte-wise; uint16 and float16 are both 16-bit words.
    var quantizedArray = quantization.dtype === 'uint8' ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer);
    if (dtype === 'float32') {
      if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
        values = new Float32Array(quantizedArray.length);
        for (var i = 0; i < quantizedArray.length; i++) {
          values[i] = quantizedArray[i] * quantization.scale + quantization.min;
        }
      } else if (quantization.dtype === 'float16') {
        // TODO: This is inefficient. Make getFloat16Decoder efficient.
        var float16Decode = getFloat16Decoder();
        values = float16Decode(quantizedArray);
      } else {
        // Defensive: unreachable given the validation above.
        throw new Error("Unsupported quantization type ".concat(quantization.dtype, " ") + "for weight type float32.");
      }
    } else if (dtype === 'int32') {
      if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
        throw new Error("Unsupported quantization type ".concat(quantization.dtype, " ") + "for weight type int32.");
      }
      values = new Int32Array(quantizedArray.length);
      for (var j = 0; j < quantizedArray.length; j++) {
        values[j] = Math.round(quantizedArray[j] * quantization.scale + quantization.min);
      }
    } else {
      throw new Error("Unsupported dtype in weight '".concat(name, "': ").concat(dtype));
    }
  } else if (dtype === 'string') {
    // Each element is a 4-byte length prefix followed by that many bytes.
    values = [];
    var offset = 0;
    for (var k = 0; k < size; k++) {
      var byteLength = new Uint32Array(byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
      offset += NUM_BYTES_STRING_LENGTH;
      values.push(new Uint8Array(byteBuffer.slice(offset, offset + byteLength)));
      offset += byteLength;
    }
  } else if (dtype === 'float32') {
    values = new Float32Array(byteBuffer);
  } else if (dtype === 'int32') {
    values = new Int32Array(byteBuffer);
  } else if (dtype === 'bool') {
    values = new Uint8Array(byteBuffer);
  } else if (dtype === 'complex64') {
    // Interleaved (real, imag) float32 pairs -> separate component tensors.
    var interleaved = new Float32Array(byteBuffer);
    var real = new Float32Array(interleaved.length / 2);
    var image = new Float32Array(interleaved.length / 2);
    for (var c = 0; c < real.length; c++) {
      real[c] = interleaved[c * 2];
      image[c] = interleaved[c * 2 + 1];
    }
    var realTensor = tensor(real, shape, 'float32');
    var imageTensor = tensor(image, shape, 'float32');
    var complexTensor = complex$2(realTensor, imageTensor);
    realTensor.dispose();
    imageTensor.dispose();
    return complexTensor;
  } else {
    throw new Error("Unsupported dtype in weight '".concat(name, "': ").concat(dtype));
  }
  return tensor(values, shape, dtype);
}
/**
 * Reads chunks from `reader` until at least `length` bytes are buffered.
 *
 * @param reader A stream reader whose `read()` resolves to `{done, value}`.
 * @param initialData Bytes already buffered (wrapped in a Uint8Array).
 * @param length Minimum number of bytes required.
 * @returns A Promise of an ArrayBuffer holding `initialData`'s bytes followed
 *     by every chunk read. It may be longer than `length`. No read happens
 *     when `initialData` already has at least `length` bytes.
 * Rejects with an Error if the reader reports `done` with no value before
 * `length` bytes have been collected.
 */
26144 function readToLength(_x5, _x6, _x7) {
26145 return _readToLength.apply(this, arguments);
26146 }
// Lazily installs the transpiled async implementation on the first call
// (`_readToLength = _asyncToGenerator(...)`), then forwards to it.
26147 function _readToLength() {
26148 _readToLength = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(reader, initialData, length) {
26149 var data, _yield$reader$read, done, value, missing, newData;
26150 return _regeneratorRuntime().wrap(function _callee4$(_context5) {
26151 while (1) switch (_context5.prev = _context5.next) {
26152 case 0:
26153 data = new Uint8Array(initialData);
26154 case 1:
// Loop head: stop (case 16) once enough bytes are buffered.
26155 if (!(data.byteLength < length)) {
26156 _context5.next = 16;
26157 break;
26158 }
26159 _context5.next = 4;
26160 return reader.read();
26161 case 4:
26162 _yield$reader$read = _context5.sent;
26163 done = _yield$reader$read.done;
26164 value = _yield$reader$read.value;
// Only a finished reader with no value is an error; a chunk delivered
// together with done=true is still appended below.
26165 if (!(done && value == null)) {
26166 _context5.next = 10;
26167 break;
26168 }
26169 missing = length - data.byteLength;
26170 throw new Error("Reader is done but ".concat(missing, " bytes are still expected"));
26171 case 10:
// Append the chunk. The whole accumulated buffer is copied on every chunk,
// so total copying grows quadratically with the number of chunks.
26172 // TODO: Don't create a new array every loop.
26173 newData = new Uint8Array(data.length + value.byteLength);
26174 newData.set(data, 0);
26175 newData.set(new Uint8Array(value), data.length);
26176 data = newData;
26177 _context5.next = 1;
26178 break;
26179 case 16:
26180 return _context5.abrupt("return", data.buffer);
26181 case 17:
26182 case "end":
26183 return _context5.stop();
26184 }
26185 }, _callee4);
26186 }));
26187 return _readToLength.apply(this, arguments);
26188 }
/**
 * Decodes weights from a stream of bytes, one spec at a time.
 *
 * @param weightStream A stream exposing `getReader()`. Its chunks hold the
 *     weights' bytes back to back, in `specs` order.
 * @param specs Weight manifest entries describing each tensor.
 * @returns A Promise of an object mapping each spec's `name` to its decoded
 *     tensor. On the 'webgpu' backend, tensors with at least
 *     `WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD` elements are also handed to the
 *     backend's `uploadToGPU` when that method exists.
 */
26189 function decodeWeightsStream(_x8, _x9) {
26190 return _decodeWeightsStream.apply(this, arguments);
26191 }
// NOTE(review): the one-line JSDoc below documents `concatenateTypedArrays`
// (defined further down), not `_decodeWeightsStream`.
26192 /**
26193 * Concatenate TypedArrays into an ArrayBuffer.
26194 */
// Lazily installs the transpiled async implementation on the first call
// (`_decodeWeightsStream = _asyncToGenerator(...)`), then forwards to it.
26195 function _decodeWeightsStream() {
26196 _decodeWeightsStream = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6(weightStream, specs) {
26197 var tensors, reader, data, _iterator3, _step3, spec, byteLength, tensorData, weightTensor, b;
26198 return _regeneratorRuntime().wrap(function _callee6$(_context7) {
26199 while (1) switch (_context7.prev = _context7.next) {
26200 case 0:
26201 tensors = {};
// NOTE(review): this reader is never released or cancelled in this function,
// even on error; confirm whether callers rely on that.
26202 reader = weightStream.getReader();
// `data` holds bytes read but not yet consumed. It always starts at the
// current weight, because each decoded weight is sliced off its front
// (case 13).
26203 data = new ArrayBuffer(0);
26204 _iterator3 = _createForOfIteratorHelper(specs);
26205 _context7.prev = 4;
26206 _iterator3.s();
26207 case 6:
// Loop head over specs; exit to case 21 when exhausted.
26208 if ((_step3 = _iterator3.n()).done) {
26209 _context7.next = 21;
26210 break;
26211 }
26212 spec = _step3.value;
26213 _context7.next = 10;
// Byte-length callback: make sure `data` holds at least `end` bytes (reading
// more from the stream if needed), then return bytes [start, end).
26214 return getWeightBytelengthAsync(spec, /*#__PURE__*/function () {
26215 var _ref2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(start, end) {
26216 return _regeneratorRuntime().wrap(function _callee5$(_context6) {
26217 while (1) switch (_context6.prev = _context6.next) {
26218 case 0:
26219 _context6.next = 2;
26220 return readToLength(reader, data, end);
26221 case 2:
26222 data = _context6.sent;
26223 return _context6.abrupt("return", data.slice(start, end));
26224 case 4:
26225 case "end":
26226 return _context6.stop();
26227 }
26228 }, _callee5);
26229 }));
26230 return function (_x13, _x14) {
26231 return _ref2.apply(this, arguments);
26232 };
26233 }());
26234 case 10:
// Ensure the whole weight is buffered before splitting it off.
26235 byteLength = _context7.sent;
26236 _context7.next = 13;
26237 return readToLength(reader, data, byteLength);
26238 case 13:
26239 data = _context7.sent;
26240 // Slice the tensor out
26241 tensorData = data.slice(0, byteLength);
26242 data = data.slice(byteLength);
26243 weightTensor = decodeWeight(spec, tensorData);
26244 tensors[spec.name] = weightTensor;
26245 // TODO(mattsoulanille): Better way to call uploadToGPU.
26246 // TODO(mattsoulanille): Make this work for webgl too.
26247 if (getBackend$1() === 'webgpu') {
26248 b = backend$1();
26249 if ('uploadToGPU' in b && sizeFromShape(weightTensor.shape) >= env().get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD')) {
26250 b.uploadToGPU(weightTensor.dataId);
26251 }
26252 }
26253 case 19:
26254 _context7.next = 6;
26255 break;
26256 case 21:
26257 _context7.next = 26;
26258 break;
// Case 23 is the catch handler and case 26 the finally block for the spec
// iteration. They match the [[4, 23, 26, 29]] location table passed to
// wrap() below: the error is recorded on the iterator, and `f()` cleans up
// and rethrows it if set.
26259 case 23:
26260 _context7.prev = 23;
26261 _context7.t0 = _context7["catch"](4);
26262 _iterator3.e(_context7.t0);
26263 case 26:
26264 _context7.prev = 26;
26265 _iterator3.f();
26266 return _context7.finish(26);
26267 case 29:
26268 return _context7.abrupt("return", tensors);
26269 case 30:
26270 case "end":
26271 return _context7.stop();
26272 }
26273 }, _callee6, null, [[4, 23, 26, 29]]);
26274 }));
26275 return _decodeWeightsStream.apply(this, arguments);
26276 }
/**
 * Concatenates TypedArrays into a single, newly allocated ArrayBuffer.
 *
 * Only the bytes covered by each view are copied. Views with a non-zero
 * `byteOffset`, or that span just part of their underlying buffer, are
 * therefore handled correctly.
 *
 * @param xs Array of `Float32Array`, `Int32Array` or `Uint8Array` values.
 * @returns An ArrayBuffer holding the bytes of every element of `xs`, in
 *     order. Its byte length is the sum of the elements' byte lengths.
 * @throws Error if `xs` is null/undefined, or if any element is not one of
 *     the supported TypedArray types.
 */
function concatenateTypedArrays(xs) {
  // TODO(adarob, cais): Support quantization.
  if (xs == null) {
    throw new Error("Invalid input value: ".concat(JSON.stringify(xs)));
  }
  // Validate each element before touching `x.buffer`. Unsupported inputs then
  // get a descriptive error rather than an incidental TypeError.
  var totalByteLength = 0;
  xs.forEach(function (x) {
    if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) {
      var typeName = x == null ? String(x) : x.constructor.name;
      throw new Error("Unsupported TypedArray subtype: ".concat(typeName));
    }
    totalByteLength += x.byteLength;
  });
  var y = new Uint8Array(totalByteLength);
  var offset = 0;
  xs.forEach(function (x) {
    // Byte view over exactly this array's region of its buffer, so no
    // intermediate copy is needed for offset views.
    y.set(new Uint8Array(x.buffer, x.byteOffset, x.byteLength), offset);
    offset += x.byteLength;
  });
  return y.buffer;
}
// Fall back to Node's Buffer whenever any of the browser primitives used by
// the helpers below (Blob, atob, btoa) is unavailable.
var useNodeBuffer = typeof Buffer !== 'undefined' && !(typeof Blob !== 'undefined' && typeof atob !== 'undefined' && typeof btoa !== 'undefined');
/**
 * Returns the UTF-8 encoded size of `str` in bytes. Because a JavaScript
 * string may contain multi-byte characters, this can exceed `str.length`.
 *
 * @param str Input string.
 * @returns Byte length of the string's UTF-8 encoding.
 */
function stringByteLength(str) {
  return useNodeBuffer ? Buffer.byteLength(str, 'utf8') : new Blob([str]).size;
}
/**
 * Base64-encodes the bytes of an ArrayBuffer.
 *
 * @param buffer The `ArrayBuffer` to encode.
 * @returns The base64 representation of `buffer`.
 */
function arrayBufferToBase64String(buffer) {
  if (!useNodeBuffer) {
    // btoa expects a "binary string": one char code (0-255) per byte.
    var bytes = new Uint8Array(buffer);
    var chars = new Array(bytes.length);
    for (var k = 0; k < bytes.length; k++) {
      chars[k] = String.fromCharCode(bytes[k]);
    }
    return btoa(chars.join(''));
  }
  return Buffer.from(buffer).toString('base64');
}
/**
 * Decodes a base64 string into a fresh ArrayBuffer.
 *
 * @param str Base64 string.
 * @returns An `ArrayBuffer` containing exactly the decoded bytes.
 */
function base64StringToArrayBuffer(str) {
  if (useNodeBuffer) {
    var buf = Buffer.from(str, 'base64');
    // The Buffer can be a view into a larger (pooled) ArrayBuffer, so copy
    // out exactly the bytes it covers.
    return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
  }
  var binary = atob(str);
  var bytes = new Uint8Array(binary.length);
  for (var i = 0; i < binary.length; ++i) {
    // Direct indexed store; avoids allocating a one-element array per byte
    // as `bytes.set([code], i)` did.
    bytes[i] = binary.charCodeAt(i);
  }
  return bytes.buffer;
}
/**
 * Joins ArrayBuffers into a single ArrayBuffer, preserving their order.
 *
 * @param buffers Either an array of ArrayBuffers or a lone ArrayBuffer.
 * @returns The concatenation of `buffers`.
 *
 * @deprecated Use tf.io.CompositeArrayBuffer.join() instead.
 */
function concatenateArrayBuffers(buffers) {
  var joined = CompositeArrayBuffer.join(buffers);
  return joined;
}
/**
 * Returns the last '/'-separated component of a path.
 *
 * Leading/trailing whitespace and any trailing '/' characters are removed
 * first, so 'a/b/c/' yields 'c'. A path made only of slashes yields ''.
 *
 * @param path The path to inspect.
 */
function basename(path) {
  var SEPARATOR = '/';
  var trimmed = path.trim();
  var end = trimmed.length;
  while (end > 0 && trimmed.charAt(end - 1) === SEPARATOR) {
    end--;
  }
  var stripped = trimmed.slice(0, end);
  return stripped.slice(stripped.lastIndexOf(SEPARATOR) + 1);
}
/**
 * Builds the `ModelJSON` (the contents of `model.json`) for a set of
 * `ModelArtifacts`.
 *
 * @param artifacts Model artifacts: the model and its weights.
 * @param manifest Weights manifest describing where the artifacts' weights
 *     are stored, plus metadata about them.
 * @returns An object with the topology, format, generator/converter info and
 *     `weightsManifest`. The optional fields `signature`,
 *     `userDefinedMetadata`, `modelInitializer`, `initializerSignature` and
 *     `trainingConfig` are included only when present (not null/undefined).
 */
function getModelJSONForModelArtifacts(artifacts, manifest) {
  var modelJSON = {
    modelTopology: artifacts.modelTopology,
    format: artifacts.format,
    generatedBy: artifacts.generatedBy,
    convertedBy: artifacts.convertedBy,
    weightsManifest: manifest
  };
  var optionalKeys = ['signature', 'userDefinedMetadata', 'modelInitializer', 'initializerSignature', 'trainingConfig'];
  optionalKeys.forEach(function (key) {
    if (artifacts[key] != null) {
      modelJSON[key] = artifacts[key];
    }
  });
  return modelJSON;
}
/**
 * Synchronously builds `ModelArtifacts` from a parsed `model.json` and
 * already-loaded weights.
 *
 * @param modelJSON Object containing the parsed JSON of `model.json`.
 * @param weightSpecs The model's list of WeightsManifestEntry. Required when
 *     `modelJSON` has a `weightsManifest`.
 * @param weightData An ArrayBuffer, or array of ArrayBuffers, holding the
 *     weight data for `weightSpecs`. Required when `modelJSON` has a
 *     `weightsManifest`.
 * @returns The `ModelArtifacts` described by the JSON (returned directly, not
 *     as a Promise).
 * @throws Error if `modelJSON` has a `weightsManifest` but `weightSpecs` or
 *     `weightData` is missing.
 */
function getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData) {
  var artifacts = {
    modelTopology: modelJSON.modelTopology,
    format: modelJSON.format,
    generatedBy: modelJSON.generatedBy,
    convertedBy: modelJSON.convertedBy
  };
  if (modelJSON.trainingConfig != null) {
    artifacts.trainingConfig = modelJSON.trainingConfig;
  }
  if (modelJSON.weightsManifest != null) {
    // Weights are mandatory whenever the JSON advertises a manifest.
    if (!weightSpecs) {
      throw new Error('modelJSON has weightsManifest but weightSpecs is null');
    }
    if (!weightData) {
      throw new Error('modelJSON has weightsManifest but weightData is null');
    }
    artifacts.weightSpecs = weightSpecs;
    artifacts.weightData = weightData;
  }
  ['signature', 'userDefinedMetadata', 'modelInitializer', 'initializerSignature'].forEach(function (key) {
    if (modelJSON[key] != null) {
      artifacts[key] = modelJSON[key];
    }
  });
  return artifacts;
}
26468 /**
26469 * Create `ModelArtifacts` from a JSON file.
26470 *
26471 * @param modelJSON Object containing the parsed JSON of `model.json`
26472 * @param loadWeights Function that takes the JSON file's weights manifest,
26473 * reads weights from the listed path(s), and returns a Promise of the
26474 * weight manifest entries along with the weights data.
 * The Promise is expected to resolve to a two-element
 * `[weightSpecs, weightData]` array. `loadWeights` is only called when
 * `modelJSON.weightsManifest` is non-null.
26475 * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
26476 */
26477 function getModelArtifactsForJSON(_x10, _x11) {
26478 return _getModelArtifactsForJSON.apply(this, arguments);
26479 }
// NOTE(review): the JSDoc block below documents `getModelArtifactsInfoForJSON`
// (defined right after this function), not `_getModelArtifactsForJSON`.
26480 /**
26481 * Populate ModelArtifactsInfo fields for a model with JSON topology.
26482 * @param modelArtifacts
26483 * @returns A ModelArtifactsInfo object.
26484 */
// Lazily installs the transpiled async implementation on the first call
// (`_getModelArtifactsForJSON = _asyncToGenerator(...)`), then forwards to it.
26485 function _getModelArtifactsForJSON() {
26486 _getModelArtifactsForJSON = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee7(modelJSON, loadWeights) {
26487 var weightSpecs, weightData, _yield$loadWeights, _yield$loadWeights2;
26488 return _regeneratorRuntime().wrap(function _callee7$(_context8) {
26489 while (1) switch (_context8.prev = _context8.next) {
26490 case 0:
// Without a manifest, skip loading; weightSpecs/weightData stay undefined.
26491 if (!(modelJSON.weightsManifest != null)) {
26492 _context8.next = 7;
26493 break;
26494 }
26495 _context8.next = 3;
26496 return loadWeights(modelJSON.weightsManifest);
26497 case 3:
// Destructure the resolved [weightSpecs, weightData] pair.
26498 _yield$loadWeights = _context8.sent;
26499 _yield$loadWeights2 = _slicedToArray(_yield$loadWeights, 2);
26500 weightSpecs = _yield$loadWeights2[0];
26501 weightData = _yield$loadWeights2[1];
26502 case 7:
// Delegate to the synchronous builder, which throws if a manifest exists but
// either weightSpecs or weightData is missing.
26503 return _context8.abrupt("return", getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData));
26504 case 8:
26505 case "end":
26506 return _context8.stop();
26507 }
26508 }, _callee7);
26509 }));
26510 return _getModelArtifactsForJSON.apply(this, arguments);
26511 }
/**
 * Populate ModelArtifactsInfo fields for a model with JSON topology.
 *
 * Byte counts for the topology and weight specs are measured on their
 * `JSON.stringify` form. Any field that is null/undefined counts as 0 bytes.
 *
 * @param modelArtifacts The artifacts to describe.
 * @returns A ModelArtifactsInfo object stamped with the current date.
 * @throws Error if the model topology is a binary `ArrayBuffer`.
 */
function getModelArtifactsInfoForJSON(modelArtifacts) {
  var topology = modelArtifacts.modelTopology;
  if (topology instanceof ArrayBuffer) {
    throw new Error('Expected JSON model topology, received ArrayBuffer.');
  }
  var jsonByteLength = function jsonByteLength(value) {
    return value == null ? 0 : stringByteLength(JSON.stringify(value));
  };
  var weightData = modelArtifacts.weightData;
  return {
    dateSaved: new Date(),
    modelTopologyType: 'JSON',
    modelTopologyBytes: jsonByteLength(topology),
    weightSpecsBytes: jsonByteLength(modelArtifacts.weightSpecs),
    weightDataBytes: weightData == null ? 0 : new CompositeArrayBuffer(weightData).byteLength
  };
}
/**
 * Flatten the `weights` lists of every group in a WeightsManifestConfig into
 * a single list of WeightsManifestEntry objects.
 *
 * @param weightsManifest The WeightsManifestConfig (an iterable of groups,
 *     each carrying a `weights` list) to extract weights from.
 * @returns All weight entries, in manifest order.
 */
function getWeightSpecs(weightsManifest) {
  var allSpecs = [];
  Array.from(weightsManifest).forEach(function (group) {
    Array.prototype.push.apply(allSpecs, _toConsumableArray(group.weights));
  });
  return allSpecs;
}
/**
 * Build the mantissa lookup table used to cast Float16 to Float32.
 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
 *
 * Entries 1..1023 hold subnormal half mantissas renormalized into float32
 * bit patterns. Entries 1024..2047 hold normal mantissas shifted into
 * place. Entry 0 is zero.
 *
 * @returns Uint32Array of 2048 mantissa lookup values.
 */
function computeFloat16MantisaTable() {
  // Renormalize a 10-bit subnormal mantissa into float32 mantissa+exponent bits.
  var normalizeSubnormal = function normalizeSubnormal(mantissa) {
    var bits = mantissa << 13;
    var exponentAdjust = 0;
    // Shift until the leading 1 reaches bit 23, decrementing the exponent
    // once per shift.
    while (!(bits & 0x00800000)) {
      bits <<= 1;
      exponentAdjust -= 0x00800000;
    }
    // Drop the now-implicit leading 1 and apply the rebased exponent.
    return (bits & ~0x00800000) | (exponentAdjust + 0x38800000);
  };
  var table = new Uint32Array(2048); // zero-initialized, so table[0] === 0
  for (var idx = 1; idx < 2048; idx++) {
    table[idx] = idx < 1024 ? normalizeSubnormal(idx) : 0x38000000 + ((idx - 1024) << 13);
  }
  return table;
}
/**
 * Build the exponent lookup table used to cast Float16 to Float32.
 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
 *
 * The table is indexed by the top 6 bits of a half (sign + 5-bit exponent).
 * Indices 0..31 are the positive half, 32..63 the negative half (sign bit
 * 0x80000000 set). Indices 31 and 63 correspond to the half
 * Infinity/NaN exponent.
 *
 * @returns Uint32Array of 64 exponent lookup values.
 */
function computeFloat16ExponentTable() {
  var table = new Uint32Array(64); // zero-initialized, so table[0] === 0
  for (var exp = 1; exp < 31; exp++) {
    table[exp] = exp << 23;
    table[exp + 32] = 0x80000000 + (exp << 23);
  }
  table[31] = 0x47800000;
  table[32] = 0x80000000;
  table[63] = 0xc7800000;
  return table;
}
/**
 * Build the offset lookup table used to cast Float16 to Float32.
 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
 *
 * Indexed like the exponent table (sign + 5-bit exponent). Zero-exponent
 * halves (indices 0 and 32) read from the subnormal half of the mantissa
 * table (offset 0). All other halves read from offset 1024.
 *
 * @returns Uint32Array of 64 offset values.
 */
function computeFloat16OffsetTable() {
  var table = new Uint32Array(64).fill(1024);
  table[0] = 0;
  table[32] = 0;
  return table;
}
/**
 * Create a decoder that turns raw Float16 bit patterns into a Float32Array.
 *
 * Based on http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf. The
 * three lookup tables are built once here and shared by every call of the
 * returned function.
 *
 * @returns Function `(quantizedArray: Uint16Array) => Float32Array` mapping
 *     each 16-bit half to the equivalent float32 value.
 */
function getFloat16Decoder() {
  var mantissaLookup = computeFloat16MantisaTable();
  var exponentLookup = computeFloat16ExponentTable();
  var offsetLookup = computeFloat16OffsetTable();
  return function (quantizedArray) {
    var count = quantizedArray.length;
    var float32Bits = new Uint32Array(count);
    for (var k = 0; k < count; k++) {
      var half = quantizedArray[k];
      var signAndExponent = half >> 10;
      float32Bits[k] = mantissaLookup[offsetLookup[signAndExponent] + (half & 0x3ff)] + exponentLookup[signAndExponent];
    }
    // Reinterpret the assembled bit patterns as floats.
    return new Float32Array(float32Bits.buffer);
  };
}
26634
26635 /**
26636 * @license
26637 * Copyright 2018 Google LLC. All Rights Reserved.
26638 * Licensed under the Apache License, Version 2.0 (the "License");
26639 * you may not use this file except in compliance with the License.
26640 * You may obtain a copy of the License at
26641 *
26642 * http://www.apache.org/licenses/LICENSE-2.0
26643 *
26644 * Unless required by applicable law or agreed to in writing, software
26645 * distributed under the License is distributed on an "AS IS" BASIS,
26646 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26647 * See the License for the specific language governing permissions and
26648 * limitations under the License.
26649 * =============================================================================
26650 */
/**
 * Singleton registry of IO routers.
 *
 * A router is a function `(url, loadOptions?) => IOHandler | null`. When a
 * handler is requested for a URL, every router of the requested kind (save
 * or load) is consulted and all non-null results are returned.
 */
var IORouterRegistry = /*#__PURE__*/function () {
  function IORouterRegistry() {
    _classCallCheck(this, IORouterRegistry);
    this.saveRouters = [];
    this.loadRouters = [];
  }
  _createClass(IORouterRegistry, null, [{
    key: "getInstance",
    /** Return the shared registry, creating it on first use. */
    value: function getInstance() {
      if (IORouterRegistry.instance == null) {
        IORouterRegistry.instance = new IORouterRegistry();
      }
      return IORouterRegistry.instance;
    }
    /**
     * Register a save-handler router.
     *
     * @param saveRouter A function that maps a URL-like string onto an instance
     *     of `IOHandler` with the `save` method defined or `null`.
     */
  }, {
    key: "registerSaveRouter",
    value: function registerSaveRouter(saveRouter) {
      IORouterRegistry.getInstance().saveRouters.push(saveRouter);
    }
    /**
     * Register a load-handler router.
     *
     * @param loadRouter A function that maps a URL-like string onto an instance
     *     of `IOHandler` with the `load` method defined or `null`.
     */
  }, {
    key: "registerLoadRouter",
    value: function registerLoadRouter(loadRouter) {
      IORouterRegistry.getInstance().loadRouters.push(loadRouter);
    }
    /**
     * Look up IOHandlers for saving, given a URL-like string.
     *
     * @param url
     * @returns All handlers (possibly none) produced by the registered save
     *     routers for `url`. Deciding what to do with zero or multiple
     *     matches is left to the caller; this method never throws for that.
     */
  }, {
    key: "getSaveHandlers",
    value: function getSaveHandlers(url) {
      return IORouterRegistry.getHandlers(url, 'save');
    }
    /**
     * Look up IOHandlers for loading, given a URL-like string.
     *
     * @param url
     * @param loadOptions Optional, custom load options passed to each router.
     * @returns All valid handlers for `url`, given the currently registered
     *     handler routers.
     */
  }, {
    key: "getLoadHandlers",
    value: function getLoadHandlers(url, loadOptions) {
      return IORouterRegistry.getHandlers(url, 'load', loadOptions);
    }
    /**
     * Collect the handlers produced by every router of `handlerType`.
     *
     * @param url URL-like string passed to each router.
     * @param handlerType 'load' selects the load routers; anything else selects
     *     the save routers.
     * @param loadOptions Forwarded to each router as its second argument.
     * @returns The non-null handlers, in router registration order.
     */
  }, {
    key: "getHandlers",
    value: function getHandlers(url, handlerType, loadOptions) {
      var validHandlers = [];
      var routers = handlerType === 'load' ? IORouterRegistry.getInstance().loadRouters : IORouterRegistry.getInstance().saveRouters;
      routers.forEach(function (router) {
        var handler = router(url, loadOptions);
        // A router signals "not mine" with null; treat a missing return value
        // (undefined) the same way instead of reporting it as a handler.
        if (handler != null) {
          validHandlers.push(handler);
        }
      });
      return validHandlers;
    }
  }]);
  return IORouterRegistry;
}();
// Module-level shorthands that forward to the static IORouterRegistry API.

/** Register a router that may produce save handlers. */
var registerSaveRouter = function registerSaveRouter(saveRouter) {
  return IORouterRegistry.registerSaveRouter(saveRouter);
};
/** Register a router that may produce load handlers. */
var registerLoadRouter = function registerLoadRouter(loadRouter) {
  return IORouterRegistry.registerLoadRouter(loadRouter);
};
/** All save handlers the registered routers produce for `url`. */
var getSaveHandlers = function getSaveHandlers(url) {
  return IORouterRegistry.getSaveHandlers(url);
};
/** All load handlers the registered routers produce for `url`. */
var getLoadHandlers = function getLoadHandlers(url, loadOptions) {
  return IORouterRegistry.getLoadHandlers(url, loadOptions);
};
26741
// Name of the IndexedDB database used by tensorflow.js, and the version
// number passed to `indexedDB.open` for it.
var DATABASE_NAME = 'tensorflowjs';
var DATABASE_VERSION = 1;
// Model data and ModelArtifactsInfo (metadata) live in two separate object
// stores so the list of saved models and their metadata can be read without
// touching the model data itself.
// Store #1 holds the model data: topology, weights and weight manifests.
var MODEL_STORE_NAME = 'models_store';
// Store #2 holds ModelArtifactsInfo: the topology type (JSON vs. binary),
// the byte size of the topology, the byte size of the weights, etc.
var INFO_STORE_NAME = 'model_info_store';
/**
 * Delete the entire database for tensorflow.js, including the models store
 * and the model-info store.
 *
 * @returns A Promise that resolves once the delete request succeeds. It
 *     rejects with the request's error event, or with the error thrown when
 *     no IndexedDB factory is available.
 */
function deleteDatabase() {
  return new Promise(function (resolve, reject) {
    // A throw from getIndexedDBFactory() inside the executor rejects the
    // returned Promise.
    var idbFactory = getIndexedDBFactory();
    var deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME);
    deleteRequest.onsuccess = function () {
      resolve();
    };
    deleteRequest.onerror = function (errorEvent) {
      reject(errorEvent);
    };
  });
}
/**
 * Return the IndexedDB factory of the current browser context.
 *
 * The standard `indexedDB` global is tried first, followed by the legacy
 * vendor-prefixed variants and `shimIndexedDB`.
 *
 * @returns The first available IndexedDB factory object.
 * @throws Error if not running in a browser, or if no IndexedDB factory is
 *     found on the global object.
 */
function getIndexedDBFactory() {
  if (!env().getBool('IS_BROWSER')) {
    // TODO(cais): Add more info about what IOHandler subtypes are available.
    // Maybe point to a doc page on the web and/or automatically determine
    // the available IOHandlers and print them in the error message.
    throw new Error('Failed to obtain IndexedDB factory because the current environment ' + 'is not a web browser.');
  }
  // Without a `window` global (e.g. in a web worker), fall back to `self`.
  // tslint:disable-next-line:no-any
  var theWindow = typeof window === 'undefined' ? self : window;
  var factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB;
  if (factory == null) {
    throw new Error('The current browser does not appear to support IndexedDB.');
  }
  return factory;
}
/**
 * `onupgradeneeded` handler: create the model store and the model-info
 * store, both keyed by each record's `modelPath` property.
 *
 * @param openRequest The pending `indexedDB.open` request whose `result` is
 *     the database being upgraded.
 */
function setUpDatabase(openRequest) {
  var database = openRequest.result;
  [MODEL_STORE_NAME, INFO_STORE_NAME].forEach(function (storeName) {
    database.createObjectStore(storeName, {
      keyPath: 'modelPath'
    });
  });
}
/**
 * IOHandler subclass: Browser IndexedDB.
 *
 * Model artifacts are stored in the `tensorflowjs` database. Metadata goes
 * into the info store and the full artifacts into the model store, both
 * keyed by `modelPath`. See the doc string of `browserIndexedDB` for more
 * details.
 */
var BrowserIndexedDB = /*#__PURE__*/function () {
  /**
   * @param modelPath Key under which the model is saved / loaded. Must be a
   *     non-empty string.
   * @throws Error if IndexedDB is unavailable or `modelPath` is empty.
   */
  function BrowserIndexedDB(modelPath) {
    _classCallCheck(this, BrowserIndexedDB);
    this.indexedDB = getIndexedDBFactory();
    // `!modelPath` covers null, undefined and the empty string.
    if (!modelPath) {
      throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
    }
    this.modelPath = modelPath;
  }
  _createClass(BrowserIndexedDB, [{
    key: "save",
    /**
     * Save model artifacts to IndexedDB.
     *
     * @param modelArtifacts Artifacts with a JSON (non-ArrayBuffer) topology.
     * @returns A Promise of the `SaveResult`.
     */
    value: function () {
      var _save = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(modelArtifacts) {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
                _context.next = 2;
                break;
              }
              throw new Error('BrowserIndexedDB.save() does not support saving model topology ' + 'in binary formats yet.');
            case 2:
              return _context.abrupt("return", this.databaseAction(this.modelPath, modelArtifacts));
            case 3:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function save(_x) {
        return _save.apply(this, arguments);
      }
      return save;
    }()
  }, {
    key: "load",
    /**
     * Load the model artifacts stored under `modelPath`.
     *
     * @returns A Promise of the stored `ModelArtifacts`. It rejects if no
     *     model is stored under the path.
     */
    value: function () {
      var _load = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              return _context2.abrupt("return", this.databaseAction(this.modelPath));
            case 1:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this);
      }));
      function load() {
        return _load.apply(this, arguments);
      }
      return load;
    }()
    /**
     * Perform database action to put model artifacts into or read model artifacts
     * from IndexedDB object store.
     *
     * Whether the action is put or get depends on whether `modelArtifacts` is
     * specified. If it is specified, the action will be put; otherwise the action
     * will be get.
     *
     * @param modelPath A unique string path for the model. (The record key
     *     actually used is `this.modelPath`.)
     * @param modelArtifacts If specified, it will be the model artifacts to be
     *     stored in IndexedDB. Its `weightData` is replaced in place by the
     *     concatenated buffer.
     * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
     *     of `ModelArtifacts`, if the action is get.
     */
  }, {
    key: "databaseAction",
    value: function databaseAction(modelPath, modelArtifacts) {
      var _this = this;
      return new Promise(function (resolve, reject) {
        var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
        openRequest.onupgradeneeded = function () {
          return setUpDatabase(openRequest);
        };
        openRequest.onsuccess = function () {
          var db = openRequest.result;
          if (modelArtifacts == null) {
            // Read model out from object store.
            var modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
            var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
            var getRequest = modelStore.get(_this.modelPath);
            getRequest.onsuccess = function () {
              if (getRequest.result == null) {
                db.close();
                return reject(new Error("Cannot find model with path '".concat(_this.modelPath, "' ") + "in IndexedDB."));
              } else {
                resolve(getRequest.result.modelArtifacts);
              }
            };
            getRequest.onerror = function (error) {
              db.close();
              return reject(getRequest.error);
            };
            modelTx.oncomplete = function () {
              return db.close();
            };
          } else {
            // Put model into object store.
            // Concatenate all the model weights into a single ArrayBuffer. Large
            // models (~1GB) have problems saving if they are not concatenated.
            // TODO(mattSoulanille): Save large models to multiple indexeddb
            // records.
            modelArtifacts.weightData = CompositeArrayBuffer.join(modelArtifacts.weightData);
            var modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
            // Remove the info-store entry written below, close the connection
            // and reject with `cause`. A fresh transaction is required: the
            // transaction that wrote the info entry has already finished by
            // the time the model write can fail, and `objectStore()` on a
            // finished transaction throws.
            var rollBackInfoAndReject = function rollBackInfoAndReject(cause) {
              var finish = function finish() {
                db.close();
                return reject(cause);
              };
              var deleteInfoRequest;
              try {
                var rollbackTx = db.transaction(INFO_STORE_NAME, 'readwrite');
                deleteInfoRequest = rollbackTx.objectStore(INFO_STORE_NAME).delete(_this.modelPath);
              } catch (rollbackError) {
                // Best effort: still report the original failure.
                return finish();
              }
              deleteInfoRequest.onsuccess = finish;
              deleteInfoRequest.onerror = finish;
            };
            // First, put ModelArtifactsInfo into info store.
            var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
            var infoStore = infoTx.objectStore(INFO_STORE_NAME);
            var putInfoRequest;
            try {
              putInfoRequest = infoStore.put({
                modelPath: _this.modelPath,
                modelArtifactsInfo: modelArtifactsInfo
              });
            } catch (error) {
              db.close();
              return reject(error);
            }
            var _modelTx;
            putInfoRequest.onsuccess = function () {
              // Second, put model data into model store.
              _modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
              // Attach the close handler immediately so it cannot miss the
              // model transaction's completion.
              _modelTx.oncomplete = function () {
                return db.close();
              };
              var modelStore = _modelTx.objectStore(MODEL_STORE_NAME);
              var putModelRequest;
              try {
                putModelRequest = modelStore.put({
                  modelPath: _this.modelPath,
                  modelArtifacts: modelArtifacts,
                  modelArtifactsInfo: modelArtifactsInfo
                });
              } catch (error) {
                // Sometimes, the serialized value is too large to store.
                return rollBackInfoAndReject(error);
              }
              putModelRequest.onsuccess = function () {
                return resolve({
                  modelArtifactsInfo: modelArtifactsInfo
                });
              };
              putModelRequest.onerror = function (error) {
                // If the put-model request fails, roll back the info entry as
                // well.
                return rollBackInfoAndReject(putModelRequest.error);
              };
            };
            putInfoRequest.onerror = function (error) {
              db.close();
              return reject(putInfoRequest.error);
            };
            infoTx.oncomplete = function () {
              // When a model transaction exists, it closes the connection itself.
              if (_modelTx == null) {
                db.close();
              }
            };
          }
        };
        openRequest.onerror = function (error) {
          return reject(openRequest.error);
        };
      });
    }
  }]);
  return BrowserIndexedDB;
}();
BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
/**
 * IO router for `indexeddb://` URLs.
 *
 * In a browser, a string URL starting with the IndexedDB scheme yields a
 * `BrowserIndexedDB` handler for the path after the scheme. Everything else
 * (non-browser environments, array URLs, other schemes) yields `null`.
 */
var indexedDBRouter = function indexedDBRouter(url) {
  var handlesUrl = env().getBool('IS_BROWSER') && !Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME);
  return handlesUrl ? browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)) : null;
};
// The same router serves both saving and loading.
IORouterRegistry.registerSaveRouter(indexedDBRouter);
IORouterRegistry.registerLoadRouter(indexedDBRouter);
/**
 * Creates a browser IndexedDB IOHandler for saving and loading models.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
 *
 * const saveResult = await model.save('indexeddb://MyModel');
 * console.log(saveResult);
 * ```
 *
 * @param modelPath A unique identifier for the model to be saved. Must be a
 *     non-empty string.
 * @returns An instance of `BrowserIndexedDB` (subclass of `IOHandler`),
 *     which can be used with, e.g., `tf.Model.save`.
 */
function browserIndexedDB(modelPath) {
  return new BrowserIndexedDB(modelPath);
}
/**
 * Remove a leading `indexeddb://` scheme from `key`, if present.
 *
 * @param key Model path, with or without the scheme.
 * @returns The key without the scheme prefix.
 */
function maybeStripScheme$1(key) {
  var scheme = BrowserIndexedDB.URL_SCHEME;
  if (!key.startsWith(scheme)) {
    return key;
  }
  return key.slice(scheme.length);
}
/**
 * Lists and removes models stored in IndexedDB by `BrowserIndexedDB`.
 */
var BrowserIndexedDBManager = /*#__PURE__*/function () {
  /** @throws Error if IndexedDB is unavailable in the current environment. */
  function BrowserIndexedDBManager() {
    _classCallCheck(this, BrowserIndexedDBManager);
    this.indexedDB = getIndexedDBFactory();
  }
  _createClass(BrowserIndexedDBManager, [{
    key: "listModels",
    /**
     * List all models stored in IndexedDB.
     *
     * @returns A Promise of an object mapping each stored model path to its
     *     ModelArtifactsInfo.
     */
    value: function () {
      var _listModels = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
        var _this2 = this;
        return _regeneratorRuntime().wrap(function _callee3$(_context3) {
          while (1) switch (_context3.prev = _context3.next) {
            case 0:
              return _context3.abrupt("return", new Promise(function (resolve, reject) {
                var openRequest = _this2.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
                openRequest.onupgradeneeded = function () {
                  return setUpDatabase(openRequest);
                };
                openRequest.onsuccess = function () {
                  var db = openRequest.result;
                  var tx = db.transaction(INFO_STORE_NAME, 'readonly');
                  var store = tx.objectStore(INFO_STORE_NAME);
                  // tslint:disable:max-line-length
                  // Need to cast `store` as `any` here because TypeScript's DOM
                  // library does not have the `getAll()` method even though the
                  // method is supported in the latest version of most mainstream
                  // browsers:
                  // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll
                  // tslint:enable:max-line-length
                  // tslint:disable-next-line:no-any
                  var getAllInfoRequest = store.getAll();
                  getAllInfoRequest.onsuccess = function () {
                    var out = {};
                    var _iterator = _createForOfIteratorHelper(getAllInfoRequest.result),
                      _step;
                    try {
                      for (_iterator.s(); !(_step = _iterator.n()).done;) {
                        var item = _step.value;
                        out[item.modelPath] = item.modelArtifactsInfo;
                      }
                    } catch (err) {
                      _iterator.e(err);
                    } finally {
                      _iterator.f();
                    }
                    resolve(out);
                  };
                  getAllInfoRequest.onerror = function (error) {
                    db.close();
                    return reject(getAllInfoRequest.error);
                  };
                  tx.oncomplete = function () {
                    return db.close();
                  };
                };
                openRequest.onerror = function (error) {
                  return reject(openRequest.error);
                };
              }));
            case 1:
            case "end":
              return _context3.stop();
          }
        }, _callee3);
      }));
      function listModels() {
        return _listModels.apply(this, arguments);
      }
      return listModels;
    }()
  }, {
    key: "removeModel",
    /**
     * Remove a stored model: first its info entry, then its model data.
     *
     * @param path Model path, with or without the `indexeddb://` scheme.
     * @returns A Promise of the removed model's ModelArtifactsInfo. It rejects
     *     if no model is stored under `path` or if a delete request fails; the
     *     rejection carries the error of the request that failed.
     */
    value: function () {
      var _removeModel = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(path) {
        var _this3 = this;
        return _regeneratorRuntime().wrap(function _callee4$(_context4) {
          while (1) switch (_context4.prev = _context4.next) {
            case 0:
              path = maybeStripScheme$1(path);
              return _context4.abrupt("return", new Promise(function (resolve, reject) {
                var openRequest = _this3.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
                openRequest.onupgradeneeded = function () {
                  return setUpDatabase(openRequest);
                };
                openRequest.onsuccess = function () {
                  var db = openRequest.result;
                  var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
                  var infoStore = infoTx.objectStore(INFO_STORE_NAME);
                  var getInfoRequest = infoStore.get(path);
                  var modelTx;
                  getInfoRequest.onsuccess = function () {
                    if (getInfoRequest.result == null) {
                      db.close();
                      return reject(new Error("Cannot find model with path '".concat(path, "' ") + "in IndexedDB."));
                    } else {
                      // First, delete the entry in the info store.
                      var deleteInfoRequest = infoStore.delete(path);
                      var deleteModelData = function deleteModelData() {
                        // Second, delete the entry in the model store.
                        modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
                        // Attach the close handler immediately so it cannot
                        // miss the model transaction's completion.
                        modelTx.oncomplete = function () {
                          return db.close();
                        };
                        var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
                        var deleteModelRequest = modelStore.delete(path);
                        deleteModelRequest.onsuccess = function () {
                          return resolve(getInfoRequest.result.modelArtifactsInfo);
                        };
                        deleteModelRequest.onerror = function (error) {
                          // A failed request aborts modelTx, so oncomplete
                          // will not close the connection.
                          db.close();
                          return reject(deleteModelRequest.error);
                        };
                      };
                      // Proceed with deleting model data regardless of whether deletion
                      // of info data succeeds or not.
                      deleteInfoRequest.onsuccess = deleteModelData;
                      deleteInfoRequest.onerror = function (error) {
                        deleteModelData();
                        db.close();
                        return reject(deleteInfoRequest.error);
                      };
                    }
                  };
                  getInfoRequest.onerror = function (error) {
                    db.close();
                    return reject(getInfoRequest.error);
                  };
                  infoTx.oncomplete = function () {
                    // When a model transaction exists, it closes the connection itself.
                    if (modelTx == null) {
                      db.close();
                    }
                  };
                };
                openRequest.onerror = function (error) {
                  return reject(openRequest.error);
                };
              }));
            case 2:
            case "end":
              return _context4.stop();
          }
        }, _callee4);
      }));
      function removeModel(_x2) {
        return _removeModel.apply(this, arguments);
      }
      return removeModel;
    }()
  }]);
  return BrowserIndexedDBManager;
}();
27176
// Local-storage keys are built as
// PATH_PREFIX + PATH_SEPARATOR + <modelPath> + PATH_SEPARATOR + <suffix>.
var PATH_SEPARATOR = '/';
var PATH_PREFIX = 'tensorflowjs_models';
// One key suffix per kind of artifact stored for a model.
var INFO_SUFFIX = 'info';
var MODEL_TOPOLOGY_SUFFIX = 'model_topology';
var WEIGHT_SPECS_SUFFIX = 'weight_specs';
var WEIGHT_DATA_SUFFIX = 'weight_data';
var MODEL_METADATA_SUFFIX = 'model_metadata';
/**
 * Purge all tensorflow.js-saved model artifacts from local storage.
 *
 * Every key that starts with `PATH_PREFIX + PATH_SEPARATOR` and is longer
 * than that prefix is removed.
 *
 * @returns Paths of the models purged, each listed once, in the order they
 *     were first encountered.
 * @throws Error if local storage is unavailable in the current environment.
 */
function purgeLocalStorageArtifacts() {
  if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || typeof window.localStorage === 'undefined') {
    throw new Error('purgeLocalStorageModels() cannot proceed because local storage is ' + 'unavailable in the current environment.');
  }
  var LS = window.localStorage;
  var prefix = PATH_PREFIX + PATH_SEPARATOR;
  // Snapshot the matching keys before removing anything. Removing an item
  // while walking the storage by index shifts later keys down one slot, so a
  // delete-as-you-go loop skips the key right after every removal.
  var keysToPurge = [];
  for (var i = 0; i < LS.length; ++i) {
    var key = LS.key(i);
    if (key != null && key.startsWith(prefix) && key.length > prefix.length) {
      keysToPurge.push(key);
    }
  }
  var purgedModelPaths = [];
  keysToPurge.forEach(function (purgeKey) {
    LS.removeItem(purgeKey);
    var modelName = getModelPathFromKey(purgeKey);
    if (purgedModelPaths.indexOf(modelName) === -1) {
      purgedModelPaths.push(modelName);
    }
  });
  return purgedModelPaths;
}
/**
 * Build the local-storage keys under which the artifacts of `path` live.
 *
 * @param path Model path (without any URL scheme).
 * @returns Object with `info`, `topology`, `weightSpecs`, `weightData` and
 *     `modelMetadata` keys.
 */
function getModelKeys(path) {
  var keyFor = function keyFor(suffix) {
    return [PATH_PREFIX, path, suffix].join(PATH_SEPARATOR);
  };
  return {
    info: keyFor(INFO_SUFFIX),
    topology: keyFor(MODEL_TOPOLOGY_SUFFIX),
    weightSpecs: keyFor(WEIGHT_SPECS_SUFFIX),
    weightData: keyFor(WEIGHT_DATA_SUFFIX),
    modelMetadata: keyFor(MODEL_METADATA_SUFFIX)
  };
}
/**
 * Remove every local-storage item whose key is a value of `keys`.
 *
 * @param keys Object whose values are local-storage keys, e.g. the result
 *     of `getModelKeys`.
 */
function removeItems(keys) {
  Object.values(keys).forEach(function (storageKey) {
    window.localStorage.removeItem(storageKey);
  });
}
/**
 * Extract the model path from a local-storage key by dropping its first
 * (prefix) and last (suffix) segments.
 *
 * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1'
 *
 * @param key Local-storage key.
 * @returns The model path.
 * @throws Error if the key has fewer than three segments.
 */
function getModelPathFromKey(key) {
  var segments = key.split(PATH_SEPARATOR);
  if (segments.length >= 3) {
    return segments.slice(1, -1).join(PATH_SEPARATOR);
  }
  throw new Error("Invalid key format: ".concat(key));
}
/**
 * Remove a leading local-storage URL scheme from `key`, if present.
 *
 * @param key Model path, with or without `BrowserLocalStorage.URL_SCHEME`.
 * @returns The key without the scheme prefix.
 */
function maybeStripScheme(key) {
  var scheme = BrowserLocalStorage.URL_SCHEME;
  if (!key.startsWith(scheme)) {
    return key;
  }
  return key.slice(scheme.length);
}
27240 /**
27241 * IOHandler subclass: Browser Local Storage.
27242 *
27243 * See the doc string to `browserLocalStorage` for more details.
27244 */
27245 var BrowserLocalStorage = /*#__PURE__*/function () {
27246 function BrowserLocalStorage(modelPath) {
27247 _classCallCheck(this, BrowserLocalStorage);
27248 if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || typeof window.localStorage === 'undefined') {
27249 // TODO(cais): Add more info about what IOHandler subtypes are
27250 // available.
27251 // Maybe point to a doc page on the web and/or automatically determine
27252 // the available IOHandlers and print them in the error message.
27253 throw new Error('The current environment does not support local storage.');
27254 }
27255 this.LS = window.localStorage;
27256 if (modelPath == null || !modelPath) {
27257 throw new Error('For local storage, modelPath must not be null, undefined or empty.');
27258 }
27259 this.modelPath = modelPath;
27260 this.keys = getModelKeys(this.modelPath);
27261 }
27262 /**
27263 * Save model artifacts to browser local storage.
27264 *
27265 * See the documentation to `browserLocalStorage` for details on the saved
27266 * artifacts.
27267 *
27268 * @param modelArtifacts The model artifacts to be stored.
27269 * @returns An instance of SaveResult.
27270 */
27271 _createClass(BrowserLocalStorage, [{
27272 key: "save",
27273 value: function () {
27274 var _save = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(modelArtifacts) {
27275 var topology, weightSpecs, modelArtifactsInfo, weightBuffer, metadata;
27276 return _regeneratorRuntime().wrap(function _callee$(_context) {
27277 while (1) switch (_context.prev = _context.next) {
27278 case 0:
27279 if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
27280 _context.next = 4;
27281 break;
27282 }
27283 throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.');
27284 case 4:
27285 topology = JSON.stringify(modelArtifacts.modelTopology);
27286 weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
27287 modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); // TODO(mattsoulanille): Support saving models over 2GB that exceed
27288 // Chrome's ArrayBuffer size limit.
27289 weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
27290 _context.prev = 8;
27291 this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
27292 this.LS.setItem(this.keys.topology, topology);
27293 this.LS.setItem(this.keys.weightSpecs, weightSpecs);
27294 this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(weightBuffer));
27295 // Note that JSON.stringify doesn't write out keys that have undefined
27296 // values, so for some keys, we set undefined instead of a null-ish
27297 // value.
27298 metadata = {
27299 format: modelArtifacts.format,
27300 generatedBy: modelArtifacts.generatedBy,
27301 convertedBy: modelArtifacts.convertedBy,
27302 signature: modelArtifacts.signature != null ? modelArtifacts.signature : undefined,
27303 userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ? modelArtifacts.userDefinedMetadata : undefined,
27304 modelInitializer: modelArtifacts.modelInitializer != null ? modelArtifacts.modelInitializer : undefined,
27305 initializerSignature: modelArtifacts.initializerSignature != null ? modelArtifacts.initializerSignature : undefined,
27306 trainingConfig: modelArtifacts.trainingConfig != null ? modelArtifacts.trainingConfig : undefined
27307 };
27308 this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
27309 return _context.abrupt("return", {
27310 modelArtifactsInfo: modelArtifactsInfo
27311 });
27312 case 18:
27313 _context.prev = 18;
27314 _context.t0 = _context["catch"](8);
27315 // If saving failed, clean up all items saved so far.
27316 removeItems(this.keys);
27317 throw new Error("Failed to save model '".concat(this.modelPath, "' to local storage: ") + "size quota being exceeded is a possible cause of this failure: " + "modelTopologyBytes=".concat(modelArtifactsInfo.modelTopologyBytes, ", ") + "weightSpecsBytes=".concat(modelArtifactsInfo.weightSpecsBytes, ", ") + "weightDataBytes=".concat(modelArtifactsInfo.weightDataBytes, "."));
27318 case 22:
27319 case "end":
27320 return _context.stop();
27321 }
27322 }, _callee, this, [[8, 18]]);
27323 }));
27324 function save(_x) {
27325 return _save.apply(this, arguments);
27326 }
27327 return save;
27328 }()
27329 /**
27330 * Load a model from local storage.
27331 *
27332 * See the documentation to `browserLocalStorage` for details on the saved
27333 * artifacts.
27334 *
27335 * @returns The loaded model (if loading succeeds).
27336 */
27337 }, {
key: "load",
value: function () {
  // Regenerator-compiled async method. Each `case N:` is a resume point of the
  // state machine, and `_context2.next = N; break;` jumps to it; this is how
  // the compiled code expresses its "throw unless present" guard clauses.
  var _load = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
    var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64;
    return _regeneratorRuntime().wrap(function _callee2$(_context2) {
      while (1) switch (_context2.prev = _context2.next) {
        case 0:
          // The info item is required. JSON.parse(null) yields null, so a
          // missing key ends up in the `== null` branch.
          info = JSON.parse(this.LS.getItem(this.keys.info));
          if (!(info == null)) {
            _context2.next = 3;
            break;
          }
          throw new Error("In local storage, there is no model with name '".concat(this.modelPath, "'"));
        case 3:
          // Only JSON topologies can be read back.
          if (!(info.modelTopologyType !== 'JSON')) {
            _context2.next = 5;
            break;
          }
          throw new Error('BrowserLocalStorage does not support loading non-JSON model ' + 'topology yet.');
        case 5:
          out = {}; // Load topology.
          topology = JSON.parse(this.LS.getItem(this.keys.topology));
          if (!(topology == null)) {
            _context2.next = 9;
            break;
          }
          throw new Error("In local storage, the topology of model '".concat(this.modelPath, "' ") + "is missing.");
        case 9:
          out.modelTopology = topology;
          // Load weight specs (required).
          weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
          if (!(weightSpecs == null)) {
            _context2.next = 13;
            break;
          }
          throw new Error("In local storage, the weight specs of model '".concat(this.modelPath, "' ") + "are missing.");
        case 13:
          out.weightSpecs = weightSpecs;
          // Load meta-data fields. This item is optional. When it is present,
          // format/generatedBy/convertedBy are copied as-is (possibly
          // undefined), and the remaining fields only when non-null.
          metadataString = this.LS.getItem(this.keys.modelMetadata);
          if (metadataString != null) {
            metadata = JSON.parse(metadataString);
            out.format = metadata.format;
            out.generatedBy = metadata.generatedBy;
            out.convertedBy = metadata.convertedBy;
            if (metadata.signature != null) {
              out.signature = metadata.signature;
            }
            if (metadata.userDefinedMetadata != null) {
              out.userDefinedMetadata = metadata.userDefinedMetadata;
            }
            if (metadata.modelInitializer != null) {
              out.modelInitializer = metadata.modelInitializer;
            }
            if (metadata.initializerSignature != null) {
              out.initializerSignature = metadata.initializerSignature;
            }
            if (metadata.trainingConfig != null) {
              out.trainingConfig = metadata.trainingConfig;
            }
          }
          // Load weight data (required). It is stored as a base64 string and
          // decoded back into an ArrayBuffer.
          weightDataBase64 = this.LS.getItem(this.keys.weightData);
          if (!(weightDataBase64 == null)) {
            _context2.next = 19;
            break;
          }
          throw new Error("In local storage, the binary weight values of model " + "'".concat(this.modelPath, "' are missing."));
        case 19:
          out.weightData = base64StringToArrayBuffer(weightDataBase64);
          return _context2.abrupt("return", out);
        case 21:
        case "end":
          // Compiler-emitted terminal state.
          return _context2.stop();
      }
    }, _callee2, this);
  }));
  function load() {
    return _load.apply(this, arguments);
  }
  return load;
}()
27420 }]);
27421 return BrowserLocalStorage;
27422 }();
BrowserLocalStorage.URL_SCHEME = 'localstorage://';
/**
 * IO router for `localstorage://` URLs.
 *
 * Yields a `BrowserLocalStorage` handler for a single string URL that starts
 * with `BrowserLocalStorage.URL_SCHEME` while running in a browser. In every
 * other case (non-browser environment, an array of URLs, another scheme) it
 * yields `null`, so that other routers can claim the URL.
 */
var localStorageRouter = function localStorageRouter(url) {
  var handlesUrl = env().getBool('IS_BROWSER') && !Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME);
  if (!handlesUrl) {
    return null;
  }
  var modelPath = url.slice(BrowserLocalStorage.URL_SCHEME.length);
  return browserLocalStorage(modelPath);
};
IORouterRegistry.registerSaveRouter(localStorageRouter);
IORouterRegistry.registerLoadRouter(localStorageRouter);
/**
 * Factory function for local storage IOHandler.
 *
 * This `IOHandler` supports both `save` and `load`.
 *
 * For each model's saved artifacts, five items are saved to local storage:
 * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the
 *   model, such as date saved, type of the topology, size in bytes, etc.
 * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras-
 *   style models, this is a stringized JSON.
 * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the
 *   model, can be used to decode the saved binary weight values (see
 *   item below).
 * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary
 *   weight values, stored as a base64-encoded string.
 * - The model-metadata item (the handler's `keys.modelMetadata` key): a JSON
 *   string with `format`, `generatedBy`, `convertedBy` and, when set,
 *   `signature`, `userDefinedMetadata`, `modelInitializer`,
 *   `initializerSignature` and `trainingConfig`. Loading tolerates its
 *   absence.
 *
 * Saving may throw an `Error` if the total size of the artifacts exceeds the
 * browser-specific quota. In that case the items written so far are removed
 * before the error is thrown.
 *
 * @param modelPath A unique identifier for the model to be saved. Must be a
 *   non-empty string.
 * @returns An instance of `IOHandler`, which can be used with, e.g.,
 *   `tf.Model.save`.
 */
function browserLocalStorage(modelPath) {
  return new BrowserLocalStorage(modelPath);
}
/**
 * Model store manager backed by `window.localStorage`. It implements
 * `listModels` and `removeModel` for models stored by the local-storage
 * IOHandler.
 */
var BrowserLocalStorageManager = /*#__PURE__*/function () {
  function BrowserLocalStorageManager() {
    _classCallCheck(this, BrowserLocalStorageManager);
    assert$1(env().getBool('IS_BROWSER'), function () {
      return 'Current environment is not a web browser';
    });
    // NOTE(review): when `window` is undefined this assertion passes, and the
    // `window.localStorage` read below throws a ReferenceError instead. The
    // IS_BROWSER assertion above presumably prevents that; confirm for
    // environments such as web workers.
    assert$1(typeof window === 'undefined' || typeof window.localStorage !== 'undefined', function () {
      return 'Current browser does not appear to support localStorage';
    });
    this.LS = window.localStorage;
  }
  _createClass(BrowserLocalStorageManager, [{
    key: "listModels",
    value: function () {
      // Regenerator-compiled async method. It resolves to an object mapping
      // each stored model path to its parsed info item.
      var _listModels = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
        var out, prefix, suffix, i, key, modelPath;
        return _regeneratorRuntime().wrap(function _callee3$(_context3) {
          while (1) switch (_context3.prev = _context3.next) {
            case 0:
              out = {};
              // Info items are the keys that start with
              // PATH_PREFIX + PATH_SEPARATOR and end with
              // PATH_SEPARATOR + INFO_SUFFIX.
              prefix = PATH_PREFIX + PATH_SEPARATOR;
              suffix = PATH_SEPARATOR + INFO_SUFFIX;
              // Linear scan over every localStorage key, including keys
              // written by other code.
              for (i = 0; i < this.LS.length; ++i) {
                key = this.LS.key(i);
                if (key.startsWith(prefix) && key.endsWith(suffix)) {
                  modelPath = getModelPathFromKey(key);
                  out[modelPath] = JSON.parse(this.LS.getItem(key));
                }
              }
              return _context3.abrupt("return", out);
            case 5:
            case "end":
              return _context3.stop();
          }
        }, _callee3, this);
      }));
      function listModels() {
        return _listModels.apply(this, arguments);
      }
      return listModels;
    }()
  }, {
    key: "removeModel",
    value: function () {
      // Regenerator-compiled async method. It removes the model stored at
      // `path` and resolves to the model's parsed info item.
      var _removeModel = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(path) {
        var keys, info;
        return _regeneratorRuntime().wrap(function _callee4$(_context4) {
          while (1) switch (_context4.prev = _context4.next) {
            case 0:
              // Presumably strips a leading 'localstorage://' so both forms
              // of path are accepted (helper defined elsewhere).
              path = maybeStripScheme(path);
              keys = getModelKeys(path);
              // The info item's presence decides whether the model exists.
              if (!(this.LS.getItem(keys.info) == null)) {
                _context4.next = 4;
                break;
              }
              throw new Error("Cannot find model at path '".concat(path, "'"));
            case 4:
              // Read the info before deleting so it can be returned.
              info = JSON.parse(this.LS.getItem(keys.info));
              removeItems(keys);
              return _context4.abrupt("return", info);
            case 7:
            case "end":
              return _context4.stop();
          }
        }, _callee4, this);
      }));
      function removeModel(_x2) {
        return _removeModel.apply(this, arguments);
      }
      return removeModel;
    }()
  }]);
  return BrowserLocalStorageManager;
}();
27538
var URL_SCHEME_SUFFIX = '://';
/**
 * Registry of model store managers keyed by URL scheme (e.g. 'localstorage').
 *
 * Every method is static and operates on a lazily created singleton obtained
 * from `getInstance()`.
 */
var ModelStoreManagerRegistry = /*#__PURE__*/function () {
  function ModelStoreManagerRegistry() {
    _classCallCheck(this, ModelStoreManagerRegistry);
    // Maps a scheme (stored without the '://' suffix) to its manager.
    this.managers = {};
  }
  _createClass(ModelStoreManagerRegistry, null, [{
    key: "getInstance",
    value: function getInstance() {
      if (ModelStoreManagerRegistry.instance == null) {
        ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry();
      }
      return ModelStoreManagerRegistry.instance;
    }
    /**
     * Register a model store manager for a URL scheme.
     *
     * @param scheme URL scheme, with or without a trailing '://', e.g.
     *   'localstorage' or 'localstorage://'. Must be non-empty once the
     *   suffix has been stripped.
     * @param manager Manager that serves `listModels`/`removeModel` for
     *   models stored under this scheme.
     * @throws Error (via `assert$1`) if `scheme` is null/undefined or empty,
     *   or if a manager is already registered for it.
     */
  }, {
    key: "registerManager",
    value: function registerManager(scheme, manager) {
      assert$1(scheme != null, function () {
        return 'scheme must not be undefined or null.';
      });
      if (scheme.endsWith(URL_SCHEME_SUFFIX)) {
        scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));
      }
      assert$1(scheme.length > 0, function () {
        return 'scheme must not be an empty string.';
      });
      var registry = ModelStoreManagerRegistry.getInstance();
      // Own-property lookup. A plain `managers[scheme]` also sees inherited
      // Object.prototype members, which would make schemes such as
      // 'constructor' or 'toString' look already registered.
      var existing = Object.prototype.hasOwnProperty.call(registry.managers, scheme) ? registry.managers[scheme] : undefined;
      assert$1(existing == null, function () {
        return "A model store manager is already registered for scheme '".concat(scheme, "'.");
      });
      registry.managers[scheme] = manager;
    }
    /**
     * Look up the manager registered for `scheme` (without the '://' suffix).
     *
     * @throws Error if no manager is registered for `scheme`.
     */
  }, {
    key: "getManager",
    value: function getManager(scheme) {
      var managers = ModelStoreManagerRegistry.getInstance().managers;
      // Own-property lookup, so inherited members (e.g. 'toString') are never
      // returned as managers.
      var manager = Object.prototype.hasOwnProperty.call(managers, scheme) ? managers[scheme] : undefined;
      if (manager == null) {
        throw new Error("Cannot find model manager for scheme '".concat(scheme, "'"));
      }
      return manager;
    }
    /** Returns the registered schemes (without the '://' suffix). */
  }, {
    key: "getSchemes",
    value: function getSchemes() {
      return Object.keys(ModelStoreManagerRegistry.getInstance().managers);
    }
  }]);
  return ModelStoreManagerRegistry;
}();
/**
 * Helper method for parsing a URL string into a scheme and a path.
 *
 * The URL is split at its first '://' only, so a path that itself contains
 * '://' is returned intact. This matches how the IO routers derive a model
 * path (everything after the scheme prefix).
 *
 * @param url E.g., 'localstorage://my-model'
 * @returns A dictionary with two fields: scheme and path.
 *   Scheme: e.g., 'localstorage' in the example above.
 *   Path: e.g., 'my-model' in the example above.
 * @throws Error if `url` contains no '://'. The message lists the schemes
 *   registered in `ModelStoreManagerRegistry`.
 */
function parseURL(url) {
  var separatorIndex = url.indexOf(URL_SCHEME_SUFFIX);
  if (separatorIndex === -1) {
    throw new Error("The url string provided does not contain a scheme. " + "Supported schemes are: " + "".concat(ModelStoreManagerRegistry.getSchemes().join(',')));
  }
  return {
    scheme: url.slice(0, separatorIndex),
    path: url.slice(separatorIndex + URL_SCHEME_SUFFIX.length)
  };
}
// Public entry point. It forwards `this` and every argument (including the
// optional third `deleteSource` flag) to the lazily compiled implementation.
function cloneModelInternal(_x, _x2) {
  var implementation = _cloneModelInternal;
  return implementation.apply(this, arguments);
}
27614 /**
27615 * List all models stored in registered storage mediums.
27616 *
27617 * For a web browser environment, the registered mediums are Local Storage and
27618 * IndexedDB.
27619 *
27620 * ```js
27621 * // First create and save a model.
27622 * const model = tf.sequential();
27623 * model.add(tf.layers.dense(
27624 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
27625 * await model.save('localstorage://demo/management/model1');
27626 *
27627 * // Then list existing models.
27628 * console.log(JSON.stringify(await tf.io.listModels()));
27629 *
27630 * // Delete the model.
27631 * await tf.io.removeModel('localstorage://demo/management/model1');
27632 *
27633 * // List models again.
27634 * console.log(JSON.stringify(await tf.io.listModels()));
27635 * ```
27636 *
27637 * @returns A `Promise` of a dictionary mapping URLs of existing models to
27638 * their model artifacts info. URLs include medium-specific schemes, e.g.,
27639 * 'indexeddb://my/model/1'. Model artifacts info include type of the
27640 * model's topology, byte sizes of the topology, weights, etc.
27641 *
27642 * @doc {
27643 * heading: 'Models',
27644 * subheading: 'Management',
27645 * namespace: 'io',
27646 * ignoreCI: true
27647 * }
27648 */
function _cloneModelInternal() {
  // Lazily compiled implementation of cloneModelInternal(sourceURL, destURL,
  // deleteSource = false). It loads the model at `sourceURL`, saves it to
  // `destURL` and, when `deleteSource` is true, removes the source (a move).
  // Resolves to the destination's ModelArtifactsInfo.
  _cloneModelInternal = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(sourceURL, destURL) {
    var deleteSource,
      loadHandlers,
      loadHandler,
      saveHandlers,
      saveHandler,
      sourceParts,
      sourceScheme,
      sourcePath,
      sameMedium,
      modelArtifacts,
      saveResult,
      _args = arguments;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          deleteSource = _args.length > 2 && _args[2] !== undefined ? _args[2] : false;
          assert$1(sourceURL !== destURL, function () {
            return "Old path and new path are the same: '".concat(sourceURL, "'");
          });
          // Exactly one load handler must claim the source URL...
          loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
          assert$1(loadHandlers.length > 0, function () {
            return "Copying failed because no load handler is found for source URL ".concat(sourceURL, ".");
          });
          assert$1(loadHandlers.length < 2, function () {
            return "Copying failed because more than one (".concat(loadHandlers.length, ") ") + "load handlers for source URL ".concat(sourceURL, ".");
          });
          loadHandler = loadHandlers[0];
          // ...and exactly one save handler the destination URL.
          saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
          assert$1(saveHandlers.length > 0, function () {
            return "Copying failed because no save handler is found for destination " + "URL ".concat(destURL, ".");
          });
          assert$1(saveHandlers.length < 2, function () {
            return "Copying failed because more than one (".concat(saveHandlers.length, ") ") + "save handlers for destination URL ".concat(destURL, ".");
          });
          saveHandler = saveHandlers[0];
          sourceParts = parseURL(sourceURL);
          sourceScheme = sourceParts.scheme;
          sourcePath = sourceParts.path;
          // The source and destination share a medium only if the DESTINATION
          // scheme equals the source scheme. A destination without a scheme
          // cannot be in the source's medium.
          sameMedium = destURL.indexOf(URL_SCHEME_SUFFIX) !== -1 && parseURL(destURL).scheme === sourceScheme;
          _context.next = 15;
          return loadHandler.load();
        case 15:
          modelArtifacts = _context.sent;
          // Within one medium, the source is removed right after loading and
          // before saving, presumably so that both copies never have to fit
          // in the medium's quota at once. If the save then fails, the model
          // is lost.
          if (!(deleteSource && sameMedium)) {
            _context.next = 19;
            break;
          }
          _context.next = 19;
          return ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath);
        case 19:
          _context.next = 21;
          return saveHandler.save(modelArtifacts);
        case 21:
          saveResult = _context.sent;
          // Across media, the source is removed only after the save has
          // succeeded, so a failed save leaves the source intact.
          if (!(deleteSource && !sameMedium)) {
            _context.next = 25;
            break;
          }
          _context.next = 25;
          return ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath);
        case 25:
          return _context.abrupt("return", saveResult.modelArtifactsInfo);
        case 26:
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _cloneModelInternal.apply(this, arguments);
}
// Public entry point. It forwards `this` and all arguments to the lazily
// compiled implementation.
function listModels() {
  var implementation = _listModels;
  return implementation.apply(this, arguments);
}
27722 /**
27723 * Remove a model specified by URL from a registered storage medium.
27724 *
27725 * ```js
27726 * // First create and save a model.
27727 * const model = tf.sequential();
27728 * model.add(tf.layers.dense(
27729 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
27730 * await model.save('localstorage://demo/management/model1');
27731 *
27732 * // Then list existing models.
27733 * console.log(JSON.stringify(await tf.io.listModels()));
27734 *
27735 * // Delete the model.
27736 * await tf.io.removeModel('localstorage://demo/management/model1');
27737 *
27738 * // List models again.
27739 * console.log(JSON.stringify(await tf.io.listModels()));
27740 * ```
27741 *
27742 * @param url A URL to a stored model, with a scheme prefix, e.g.,
27743 * 'localstorage://my-model-1', 'indexeddb://my/model/2'.
27744 * @returns ModelArtifactsInfo of the deleted model (if and only if deletion
27745 * is successful).
 * @throws Error if deletion fails, e.g., if no model exists at `url`.
27747 *
27748 * @doc {
27749 * heading: 'Models',
27750 * subheading: 'Management',
27751 * namespace: 'io',
27752 * ignoreCI: true
27753 * }
27754 */
function _listModels() {
  // Lazily compiled implementation of the public `listModels()`. The first
  // call replaces `_listModels` with the regenerator-compiled async function
  // and delegates to it. It resolves to an object mapping full URLs
  // ('<scheme>://<path>') to each model's artifacts info.
  _listModels = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
    var schemes, out, _iterator, _step, scheme, schemeOut, path, url;
    return _regeneratorRuntime().wrap(function _callee2$(_context2) {
      while (1) switch (_context2.prev = _context2.next) {
        case 0:
          schemes = ModelStoreManagerRegistry.getSchemes();
          out = {};
          // Compiled `for...of` over `schemes`: `s()` starts the iteration and
          // `n()` advances it. `e(err)` is called from the catch state and
          // `f()` from the finally state. Presumably `f()` rethrows an error
          // recorded by `e()` (helper defined elsewhere; confirm).
          _iterator = _createForOfIteratorHelper(schemes);
          _context2.prev = 3;
          _iterator.s();
        case 5:
          if ((_step = _iterator.n()).done) {
            _context2.next = 13;
            break;
          }
          scheme = _step.value;
          _context2.next = 9;
          // Managers are queried one at a time, in getSchemes() order.
          return ModelStoreManagerRegistry.getManager(scheme).listModels();
        case 9:
          schemeOut = _context2.sent;
          // Prefix each manager-relative path with its scheme.
          for (path in schemeOut) {
            url = scheme + URL_SCHEME_SUFFIX + path;
            out[url] = schemeOut[path];
          }
        case 11:
          _context2.next = 5;
          break;
        case 13:
          _context2.next = 18;
          break;
        case 15:
          // Catch state for the try region that starts at case 3.
          _context2.prev = 15;
          _context2.t0 = _context2["catch"](3);
          _iterator.e(_context2.t0);
        case 18:
          // Finally state: always closes the iterator.
          _context2.prev = 18;
          _iterator.f();
          return _context2.finish(18);
        case 21:
          return _context2.abrupt("return", out);
        case 22:
        case "end":
          return _context2.stop();
      }
    }, _callee2, null, [[3, 15, 18, 21]]);
  }));
  return _listModels.apply(this, arguments);
}
// Public entry point. It forwards `this` and the URL argument to the lazily
// compiled implementation.
function removeModel(_x3) {
  var implementation = _removeModel;
  return implementation.apply(this, arguments);
}
27807 /**
27808 * Copy a model from one URL to another.
27809 *
27810 * This function supports:
27811 *
27812 * 1. Copying within a storage medium, e.g.,
27813 * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')`
27814 * 2. Copying between two storage mediums, e.g.,
27815 * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')`
27816 *
27817 * ```js
27818 * // First create and save a model.
27819 * const model = tf.sequential();
27820 * model.add(tf.layers.dense(
27821 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
27822 * await model.save('localstorage://demo/management/model1');
27823 *
27824 * // Then list existing models.
27825 * console.log(JSON.stringify(await tf.io.listModels()));
27826 *
27827 * // Copy the model, from Local Storage to IndexedDB.
27828 * await tf.io.copyModel(
27829 * 'localstorage://demo/management/model1',
27830 * 'indexeddb://demo/management/model1');
27831 *
27832 * // List models again.
27833 * console.log(JSON.stringify(await tf.io.listModels()));
27834 *
27835 * // Remove both models.
27836 * await tf.io.removeModel('localstorage://demo/management/model1');
27837 * await tf.io.removeModel('indexeddb://demo/management/model1');
27838 * ```
27839 *
27840 * @param sourceURL Source URL of copying.
27841 * @param destURL Destination URL of copying.
27842 * @returns ModelArtifactsInfo of the copied model (if and only if copying
27843 * is successful).
27844 * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or
 * if `sourceURL` and `destURL` are identical.
27846 *
27847 * @doc {
27848 * heading: 'Models',
27849 * subheading: 'Management',
27850 * namespace: 'io',
27851 * ignoreCI: true
27852 * }
27853 */
function _removeModel() {
  // Lazily compiled implementation of the public `removeModel(url)`. The
  // first call replaces `_removeModel` with the regenerator-compiled async
  // function.
  _removeModel = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(url) {
    var schemeAndPath, manager;
    return _regeneratorRuntime().wrap(function _callee3$(_context3) {
      while (1) switch (_context3.prev = _context3.next) {
        case 0:
          // Throws if `url` has no scheme or no manager is registered for it.
          schemeAndPath = parseURL(url);
          manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);
          // Resolves to whatever the medium's manager resolves to. The
          // local-storage manager returns the removed model's info.
          return _context3.abrupt("return", manager.removeModel(schemeAndPath.path));
        case 3:
        case "end":
          return _context3.stop();
      }
    }, _callee3);
  }));
  return _removeModel.apply(this, arguments);
}
// Public entry point. It forwards `this` and both URLs to the lazily
// compiled implementation.
function copyModel(_x4, _x5) {
  var implementation = _copyModel;
  return implementation.apply(this, arguments);
}
27874 /**
27875 * Move a model from one URL to another.
27876 *
27877 * This function supports:
27878 *
27879 * 1. Moving within a storage medium, e.g.,
27880 * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')`
27881 * 2. Moving between two storage mediums, e.g.,
27882 * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')`
27883 *
27884 * ```js
27885 * // First create and save a model.
27886 * const model = tf.sequential();
27887 * model.add(tf.layers.dense(
27888 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
27889 * await model.save('localstorage://demo/management/model1');
27890 *
27891 * // Then list existing models.
27892 * console.log(JSON.stringify(await tf.io.listModels()));
27893 *
27894 * // Move the model, from Local Storage to IndexedDB.
27895 * await tf.io.moveModel(
27896 * 'localstorage://demo/management/model1',
27897 * 'indexeddb://demo/management/model1');
27898 *
27899 * // List models again.
27900 * console.log(JSON.stringify(await tf.io.listModels()));
27901 *
27902 * // Remove the moved model.
27903 * await tf.io.removeModel('indexeddb://demo/management/model1');
27904 * ```
27905 *
27906 * @param sourceURL Source URL of moving.
27907 * @param destURL Destination URL of moving.
 * @returns ModelArtifactsInfo of the moved model (if and only if moving
 * is successful).
 * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or
 * if `sourceURL` and `destURL` are identical.
27912 *
27913 * @doc {
27914 * heading: 'Models',
27915 * subheading: 'Management',
27916 * namespace: 'io',
27917 * ignoreCI: true
27918 * }
27919 */
function _copyModel() {
  // Lazily compiled implementation of the public `copyModel(sourceURL,
  // destURL)`: a clone that keeps the source (`deleteSource = false`).
  _copyModel = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(sourceURL, destURL) {
    var deleteSource;
    return _regeneratorRuntime().wrap(function _callee4$(_context4) {
      while (1) switch (_context4.prev = _context4.next) {
        case 0:
          deleteSource = false;
          return _context4.abrupt("return", cloneModelInternal(sourceURL, destURL, deleteSource));
        case 2:
        case "end":
          return _context4.stop();
      }
    }, _callee4);
  }));
  return _copyModel.apply(this, arguments);
}
// Public entry point. It forwards `this` and both URLs to the lazily
// compiled implementation.
function moveModel(_x6, _x7) {
  var implementation = _moveModel;
  return implementation.apply(this, arguments);
}
function _moveModel() {
  // Lazily compiled implementation of the public `moveModel(sourceURL,
  // destURL)`: a clone that removes the source afterwards
  // (`deleteSource = true`).
  _moveModel = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(sourceURL, destURL) {
    var deleteSource;
    return _regeneratorRuntime().wrap(function _callee5$(_context5) {
      while (1) switch (_context5.prev = _context5.next) {
        case 0:
          deleteSource = true;
          return _context5.abrupt("return", cloneModelInternal(sourceURL, destURL, deleteSource));
        case 2:
        case "end":
          return _context5.stop();
      }
    }, _callee5);
  }));
  return _moveModel.apply(this, arguments);
}
27955
/**
 * Platform implementation for web browsers. It provides `fetch`,
 * `performance.now()`, TextEncoder/TextDecoder, and a `setTimeoutCustom` that
 * can bypass the browser's nested-timeout clamping.
 */
var PlatformBrowser = /*#__PURE__*/function () {
  function PlatformBrowser() {
    _classCallCheck(this, PlatformBrowser);
    // For setTimeoutCustom
    this.messageName = 'setTimeoutCustom';
    // Callbacks queued by setTimeoutCustom. Each posted message carries the
    // index of its callback in this array.
    this.functionRefs = [];
    // Number of queued callbacks already run. Once it catches up with
    // functionRefs.length, both are reset so the array does not grow forever.
    this.handledMessageCount = 0;
    this.hasEventListener = false;
  }
  _createClass(PlatformBrowser, [{
    key: "fetch",
    value: function (_fetch) {
      function fetch(_x, _x2) {
        return _fetch.apply(this, arguments);
      }
      fetch.toString = function () {
        return _fetch.toString();
      };
      return fetch;
    }(function (path, init) {
      // Delegates to the global fetch.
      return fetch(path, init);
    })
  }, {
    key: "now",
    value: function now() {
      return performance.now();
    }
  }, {
    key: "encode",
    value: function encode(text, encoding) {
      if (encoding !== 'utf-8' && encoding !== 'utf8') {
        throw new Error("Browser's encoder only supports utf-8, but got ".concat(encoding));
      }
      // Created on first use and reused afterwards.
      if (this.textEncoder == null) {
        this.textEncoder = new TextEncoder();
      }
      return this.textEncoder.encode(text);
    }
  }, {
    key: "decode",
    value: function decode(bytes, encoding) {
      return new TextDecoder(encoding).decode(bytes);
    }
    // If the setTimeout nesting level is greater than 5 and timeout is less
    // than 4ms, timeout will be clamped to 4ms, which hurts the perf.
    // Interleaving window.postMessage and setTimeout will trick the browser and
    // avoid the clamp.
  }, {
    key: "setTimeoutCustom",
    value: function setTimeoutCustom(functionRef, delay) {
      var _this = this;
      // Plain setTimeout outside a window context or when the flag is off.
      if (typeof window === 'undefined' || !env().getBool('USE_SETTIMEOUTCUSTOM')) {
        setTimeout(functionRef, delay);
        return;
      }
      // Record this callback's slot now. If `functionRefs.length - 1` were
      // read only when the timer fires, it would point at whichever callback
      // was queued last, so overlapping calls would run the newest callback
      // repeatedly and never run the older ones.
      var index = this.functionRefs.push(functionRef) - 1;
      setTimeout(function () {
        window.postMessage({
          name: _this.messageName,
          index: index
        }, '*');
      }, delay);
      // A single capturing listener dispatches every queued callback.
      if (!this.hasEventListener) {
        this.hasEventListener = true;
        window.addEventListener('message', function (event) {
          // Other code may post arbitrary data (including null/undefined) to
          // this window; ignore anything that is not one of our messages.
          if (event.source === window && event.data != null && event.data.name === _this.messageName) {
            event.stopPropagation();
            var _functionRef = _this.functionRefs[event.data.index];
            _functionRef();
            _this.handledMessageCount++;
            if (_this.handledMessageCount === _this.functionRefs.length) {
              _this.functionRefs = [];
              _this.handledMessageCount = 0;
            }
          }
        }, true);
      }
    }
  }, {
    key: "isTypedArray",
    value: function isTypedArray(a) {
      return isTypedArrayBrowser(a);
    }
  }]);
  return PlatformBrowser;
}();
if (env().get('IS_BROWSER')) {
  env().setPlatform('browser', new PlatformBrowser());
  // Each model-store manager (LocalStorage, then IndexedDB) is registered
  // independently and on a best-effort basis. If constructing or registering
  // one throws (for instance when a manager's environment assertions fail),
  // that medium is skipped and the error is deliberately ignored.
  [function registerLocalStorageManager() {
    ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
  }, function registerIndexedDBManager() {
    ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
  }].forEach(function (register) {
    try {
      register();
    } catch (err) {}
  });
}
28053
// `importFetch` lives on an object (rather than being a bare function) so that
// tests can replace it with a Jasmine spy.
var getNodeFetch = {};
getNodeFetch.importFetch = function importFetch() {
  // tslint:disable-next-line:no-require-imports
  var nodeFetch = require('node-fetch');
  return nodeFetch;
};
// Cached fetch implementation used by PlatformNode when no global fetch exists.
var systemFetch;
// Test-only accessors. Exposing functions instead of exporting the binding
// keeps `systemFetch` assignable only from inside this module.
function resetSystemFetch() {
  setSystemFetch(null);
}
function setSystemFetch(fetchFn) {
  systemFetch = fetchFn;
}
function getSystemFetch() {
  var current = systemFetch;
  return current;
}
28073 var PlatformNode = /*#__PURE__*/function () {
28074 function PlatformNode() {
28075 _classCallCheck(this, PlatformNode);
28076 // tslint:disable-next-line:no-require-imports
28077 this.util = require('util');
28078 // According to the spec, the built-in encoder can do only UTF-8 encoding.
28079 // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
28080 this.textEncoder = new this.util.TextEncoder();
28081 }
28082 _createClass(PlatformNode, [{
28083 key: "fetch",
28084 value: function fetch(path, requestInits) {
28085 if (env().global.fetch != null) {
28086 return env().global.fetch(path, requestInits);
28087 }
28088 if (systemFetch == null) {
28089 systemFetch = getNodeFetch.importFetch();
28090 }
28091 return systemFetch(path, requestInits);
28092 }
28093 }, {
28094 key: "now",
28095 value: function now() {
28096 var time = process.hrtime();
28097 return time[0] * 1000 + time[1] / 1000000;
28098 }
28099 }, {
28100 key: "encode",
28101 value: function encode(text, encoding) {
28102 if (encoding !== 'utf-8' && encoding !== 'utf8') {
28103 throw new Error("Node built-in encoder only supports utf-8, but got ".concat(encoding));
28104 }
28105 return this.textEncoder.encode(text);
28106 }
28107 }, {
28108 key: "decode",
28109 value: function decode(bytes, encoding) {
28110 if (bytes.length === 0) {
28111 return '';
28112 }
28113 return new this.util.TextDecoder(encoding).decode(bytes);
28114 }
28115 }, {
28116 key: "isTypedArray",
28117 value: function isTypedArray(a) {
28118 return this.util.types.isFloat32Array(a) || this.util.types.isInt32Array(a) || this.util.types.isUint8Array(a) || this.util.types.isUint8ClampedArray(a);
28119 }
28120 }]);
28121 return PlatformNode;
28122 }();
28123 if (env().get('IS_NODE') && !env().get('IS_BROWSER')) {
28124 env().setPlatform('node', new PlatformNode());
28125 }
28126
28127 /**
28128 * @license
28129 * Copyright 2020 Google Inc. All Rights Reserved.
28130 * Licensed under the Apache License, Version 2.0 (the "License");
28131 * you may not use this file except in compliance with the License.
28132 * You may obtain a copy of the License at
28133 *
28134 * http://www.apache.org/licenses/LICENSE-2.0
28135 *
28136 * Unless required by applicable law or agreed to in writing, software
28137 * distributed under the License is distributed on an "AS IS" BASIS,
28138 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28139 * See the License for the specific language governing permissions and
28140 * limitations under the License.
28141 * =============================================================================
28142 */
/**
 * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
 *
 * The values live on the CPU as a `TypedArray`. Fill the buffer with
 * `buffer.set()` or by writing to `buffer.values` directly, then call
 * `buffer.toTensor()` to obtain an immutable `tf.Tensor` holding them.
 *
 * ```js
 * // Create a buffer and set values at particular indices.
 * const buffer = tf.buffer([2, 2]);
 * buffer.set(3, 0, 0);
 * buffer.set(5, 1, 0);
 *
 * // Convert the buffer back to a tensor.
 * buffer.toTensor().print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param dtype The dtype of the buffer. Any falsy value (omitted, undefined,
 *   null, '') means 'float32'.
 * @param values The values of the buffer as `TypedArray`. Defaults to
 *   zeros.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function buffer(shape) {
  // Optional parameters are read from `arguments` so that `buffer.length`
  // stays 1.
  var dtype = arguments.length > 1 ? arguments[1] : undefined;
  var values = arguments.length > 2 ? arguments[2] : undefined;
  if (!dtype) {
    dtype = 'float32';
  }
  assertNonNegativeIntegerDimensions(shape);
  return new TensorBuffer(shape, dtype, values);
}
28176
28177 /**
28178 * @license
28179 * Copyright 2020 Google Inc. All Rights Reserved.
28180 * Licensed under the Apache License, Version 2.0 (the "License");
28181 * you may not use this file except in compliance with the License.
28182 * You may obtain a copy of the License at
28183 *
28184 * http://www.apache.org/licenses/LICENSE-2.0
28185 *
28186 * Unless required by applicable law or agreed to in writing, software
28187 * distributed under the License is distributed on an "AS IS" BASIS,
28188 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28189 * See the License for the specific language governing permissions and
28190 * limitations under the License.
28191 * =============================================================================
28192 */
/**
 * Casts a `tf.Tensor` to a new dtype.
 *
 * ```js
 * const x = tf.tensor1d([1.5, 2.5, 3]);
 * tf.cast(x, 'int32').print();
 * ```
 * @param x The input tensor to be cast.
 * @param dtype The dtype to cast the input tensor to. String tensors can only
 *   be cast to 'string', and non-string tensors never to 'string'.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function cast_(x, dtype) {
  var $x = convertToTensor(x, 'x', 'cast');
  if (!isValidDtype(dtype)) {
    throw new Error("Failed to cast to unknown dtype ".concat(dtype));
  }
  // Exactly one side being 'string' is rejected.
  var targetIsString = dtype === 'string';
  var sourceIsString = $x.dtype === 'string';
  if (targetIsString !== sourceIsString) {
    throw new Error('Only strings can be casted to strings');
  }
  return ENGINE.runKernel(Cast, {
    x: $x
  }, {
    dtype: dtype
  });
}
var cast$3 = /* @__PURE__ */op({
  cast_: cast_
});
28225
28226 /**
28227 * @license
28228 * Copyright 2020 Google LLC. All Rights Reserved.
28229 * Licensed under the Apache License, Version 2.0 (the "License");
28230 * you may not use this file except in compliance with the License.
28231 * You may obtain a copy of the License at
28232 *
28233 * http://www.apache.org/licenses/LICENSE-2.0
28234 *
28235 * Unless required by applicable law or agreed to in writing, software
28236 * distributed under the License is distributed on an "AS IS" BASIS,
28237 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28238 * See the License for the specific language governing permissions and
28239 * limitations under the License.
28240 * =============================================================================
28241 */
28242 /**
28243 * Creates a new tensor with the same values and shape as the specified
28244 * tensor.
28245 *
28246 * ```js
28247 * const x = tf.tensor([1, 2]);
28248 *
28249 * x.clone().print();
28250 * ```
28251 *
28252 * @param x The tensor to clone.
28253 *
28254 * @doc {heading: 'Tensors', subheading: 'Creation'}
28255 */
28256 function clone_(x) {
28257 var $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric');
28258 var inputs = {
28259 x: $x
28260 };
28261 // Note this op is called tf.identity in python. Hence the kernel name used
28262 // here.
28263 return ENGINE.runKernel(Identity$1, inputs);
28264 }
28265 var clone = /* @__PURE__ */op({
28266 clone_: clone_
28267 });
28268
28269 /**
28270 * @license
28271 * Copyright 2020 Google Inc. All Rights Reserved.
28272 * Licensed under the Apache License, Version 2.0 (the "License");
28273 * you may not use this file except in compliance with the License.
28274 * You may obtain a copy of the License at
28275 *
28276 * http://www.apache.org/licenses/LICENSE-2.0
28277 *
28278 * Unless required by applicable law or agreed to in writing, software
28279 * distributed under the License is distributed on an "AS IS" BASIS,
28280 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28281 * See the License for the specific language governing permissions and
28282 * limitations under the License.
28283 * =============================================================================
28284 */
28285 /**
28286 * Prints information about the `tf.Tensor` including its data.
28287 *
28288 * ```js
28289 * const verbose = true;
28290 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
28291 * ```
28292 * @param x The tensor to be printed.
28293 * @param verbose Whether to print verbose information about the ` Tensor`,
28294 * including dtype and size.
28295 *
28296 * @doc {heading: 'Tensors', subheading: 'Creation'}
28297 */
28298 function print(x) {
28299 var verbose = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
28300 console.log(x.toString(verbose));
28301 }
28302
28303 /**
28304 * @license
28305 * Copyright 2020 Google Inc. All Rights Reserved.
28306 * Licensed under the Apache License, Version 2.0 (the "License");
28307 * you may not use this file except in compliance with the License.
28308 * You may obtain a copy of the License at
28309 *
28310 * http://www.apache.org/licenses/LICENSE-2.0
28311 *
28312 * Unless required by applicable law or agreed to in writing, software
28313 * distributed under the License is distributed on an "AS IS" BASIS,
28314 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28315 * See the License for the specific language governing permissions and
28316 * limitations under the License.
28317 * =============================================================================
28318 */
28319 getOrMakeEngine();
28320 var opHandler = {
28321 buffer: buffer,
28322 cast: cast$3,
28323 clone: clone,
28324 print: print
28325 };
28326 setOpHandler(opHandler);
28327
28328 /**
28329 * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.
28330 *
28331 *
28332 * ```js
28333 * const a = tf.tensor1d([1, 2, 3, 4]);
28334 * const b = tf.tensor1d([10, 20, 30, 40]);
28335 *
28336 * a.add(b).print(); // or tf.add(a, b)
28337 * ```
28338 *
28339 * ```js
28340 * // Broadcast add a with b.
28341 * const a = tf.scalar(5);
28342 * const b = tf.tensor1d([10, 20, 30, 40]);
28343 *
28344 * a.add(b).print(); // or tf.add(a, b)
28345 * ```
28346 * @param a The first `tf.Tensor` to add.
28347 * @param b The second `tf.Tensor` to add. Must have the same type as `a`.
28348 *
28349 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
28350 */
28351 function add_(a, b) {
28352 var $a = convertToTensor(a, 'a', 'add');
28353 var $b = convertToTensor(b, 'b', 'add');
28354 var _makeTypesMatch = makeTypesMatch($a, $b);
28355 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
28356 $a = _makeTypesMatch2[0];
28357 $b = _makeTypesMatch2[1];
28358 var inputs = {
28359 a: $a,
28360 b: $b
28361 };
28362 return ENGINE.runKernel(Add$1, inputs);
28363 }
28364 var add$3 = /* @__PURE__ */op({
28365 add_: add_
28366 });
28367
28368 /**
28369 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
28370 * The result is rounded with floor function.
28371 *
28372 *
28373 * ```js
28374 * const a = tf.tensor1d([1, 4, 9, 16]);
28375 * const b = tf.tensor1d([1, 2, 3, 4]);
28376 *
28377 * a.floorDiv(b).print(); // or tf.div(a, b)
28378 * ```
28379 *
28380 * ```js
28381 * // Broadcast div a with b.
28382 * const a = tf.tensor1d([2, 4, 6, 8]);
28383 * const b = tf.scalar(2);
28384 *
28385 * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
28386 * ```
28387 *
28388 * @param a The first tensor as the numerator.
28389 * @param b The second tensor as the denominator. Must have the same dtype as
28390 * `a`.
28391 *
28392 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
28393 */
28394 function floorDiv_(a, b) {
28395 var $a = convertToTensor(a, 'a', 'floorDiv');
28396 var $b = convertToTensor(b, 'b', 'floorDiv');
28397 var _makeTypesMatch = makeTypesMatch($a, $b);
28398 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
28399 $a = _makeTypesMatch2[0];
28400 $b = _makeTypesMatch2[1];
28401 var inputs = {
28402 a: $a,
28403 b: $b
28404 };
28405 return ENGINE.runKernel(FloorDiv, inputs);
28406 }
28407 var floorDiv$2 = /* @__PURE__ */op({
28408 floorDiv_: floorDiv_
28409 });
28410
28411 /**
28412 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
28413 *
28414 * ```js
28415 * const a = tf.tensor1d([1, 4, 9, 16]);
28416 * const b = tf.tensor1d([1, 2, 3, 4]);
28417 *
28418 * a.div(b).print(); // or tf.div(a, b)
28419 * ```
28420 *
28421 * ```js
28422 * // Broadcast div a with b.
28423 * const a = tf.tensor1d([2, 4, 6, 8]);
28424 * const b = tf.scalar(2);
28425 *
28426 * a.div(b).print(); // or tf.div(a, b)
28427 * ```
28428 *
28429 * @param a The first tensor as the numerator.
28430 * @param b The second tensor as the denominator. Must have the same dtype as
28431 * `a`.
28432 *
28433 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
28434 */
28435 function div_(a, b) {
28436 var $a = convertToTensor(a, 'a', 'div');
28437 var $b = convertToTensor(b, 'b', 'div');
28438 var _makeTypesMatch = makeTypesMatch($a, $b);
28439 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
28440 $a = _makeTypesMatch2[0];
28441 $b = _makeTypesMatch2[1];
28442 if ($a.dtype === 'int32' && $b.dtype === 'int32') {
28443 return floorDiv$2($a, $b);
28444 }
28445 var inputs = {
28446 a: $a,
28447 b: $b
28448 };
28449 var attrs = {};
28450 // tslint:disable-next-line: no-unnecessary-type-assertion
28451 return ENGINE.runKernel(RealDiv, inputs, attrs);
28452 }
28453 var div$1 = /* @__PURE__ */op({
28454 div_: div_
28455 });
28456
28457 /**
28458 * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
28459 *
28460 * We also expose `tf.mulStrict` which has the same signature as this op and
28461 * asserts that `a` and `b` are the same shape (does not broadcast).
28462 *
28463 * ```js
28464 * const a = tf.tensor1d([1, 2, 3, 4]);
28465 * const b = tf.tensor1d([2, 3, 4, 5]);
28466 *
28467 * a.mul(b).print(); // or tf.mul(a, b)
28468 * ```
28469 *
28470 * ```js
28471 * // Broadcast mul a with b.
28472 * const a = tf.tensor1d([1, 2, 3, 4]);
28473 * const b = tf.scalar(5);
28474 *
28475 * a.mul(b).print(); // or tf.mul(a, b)
28476 * ```
28477 * @param a The first tensor to multiply.
28478 * @param b The second tensor to multiply. Must have the same dtype as `a`.
28479 *
28480 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
28481 */
28482 function mul_(a, b) {
28483 var $a = convertToTensor(a, 'a', 'mul');
28484 var $b = convertToTensor(b, 'b', 'mul');
28485 var _makeTypesMatch = makeTypesMatch($a, $b);
28486 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
28487 $a = _makeTypesMatch2[0];
28488 $b = _makeTypesMatch2[1];
28489 var inputs = {
28490 a: $a,
28491 b: $b
28492 };
28493 return ENGINE.runKernel(Multiply$1, inputs);
28494 }
28495 var mul = /* @__PURE__ */op({
28496 mul_: mul_
28497 });
28498
28499 /**
28500 * @license
28501 * Copyright 2018 Google LLC. All Rights Reserved.
28502 * Licensed under the Apache License, Version 2.0 (the "License");
28503 * you may not use this file except in compliance with the License.
28504 * You may obtain a copy of the License at
28505 *
28506 * http://www.apache.org/licenses/LICENSE-2.0
28507 *
28508 * Unless required by applicable law or agreed to in writing, software
28509 * distributed under the License is distributed on an "AS IS" BASIS,
28510 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28511 * See the License for the specific language governing permissions and
28512 * limitations under the License.
28513 * =============================================================================
28514 */
28515 /**
28516 * Computes absolute value element-wise: `abs(x)`
28517 *
28518 * ```js
28519 * const x = tf.tensor1d([-1, 2, -3, 4]);
28520 *
28521 * x.abs().print(); // or tf.abs(x)
28522 * ```
28523 * @param x The input `tf.Tensor`.
28524 *
28525 * @doc {heading: 'Operations', subheading: 'Basic math'}
28526 */
28527 function abs_(x) {
28528 var $x = convertToTensor(x, 'x', 'abs');
28529 if ($x.dtype === 'complex64') {
28530 var inputs = {
28531 x: $x
28532 };
28533 return ENGINE.runKernel(ComplexAbs, inputs);
28534 } else {
28535 var _inputs = {
28536 x: $x
28537 };
28538 return ENGINE.runKernel(Abs, _inputs);
28539 }
28540 }
28541 var abs$2 = /* @__PURE__ */op({
28542 abs_: abs_
28543 });
28544
28545 /**
28546 * @license
28547 * Copyright 2018 Google LLC. All Rights Reserved.
28548 * Licensed under the Apache License, Version 2.0 (the "License");
28549 * you may not use this file except in compliance with the License.
28550 * You may obtain a copy of the License at
28551 *
28552 * http://www.apache.org/licenses/LICENSE-2.0
28553 *
28554 * Unless required by applicable law or agreed to in writing, software
28555 * distributed under the License is distributed on an "AS IS" BASIS,
28556 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28557 * See the License for the specific language governing permissions and
28558 * limitations under the License.
28559 * =============================================================================
28560 */
28561 /**
28562 * Computes acos of the input `tf.Tensor` element-wise: `acos(x)`
28563 *
28564 * ```js
28565 * const x = tf.tensor1d([0, 1, -1, .7]);
28566 *
28567 * x.acos().print(); // or tf.acos(x)
28568 * ```
28569 * @param x The input tensor.
28570 * @doc {heading: 'Operations', subheading: 'Basic math'}
28571 */
28572 function acos_(x) {
28573 var $x = convertToTensor(x, 'x', 'acos');
28574 var inputs = {
28575 x: $x
28576 };
28577 return ENGINE.runKernel(Acos, inputs);
28578 }
28579 var acos$2 = /* @__PURE__ */op({
28580 acos_: acos_
28581 });
28582
28583 /**
28584 * @license
28585 * Copyright 2018 Google LLC. All Rights Reserved.
28586 * Licensed under the Apache License, Version 2.0 (the "License");
28587 * you may not use this file except in compliance with the License.
28588 * You may obtain a copy of the License at
28589 *
28590 * http://www.apache.org/licenses/LICENSE-2.0
28591 *
28592 * Unless required by applicable law or agreed to in writing, software
28593 * distributed under the License is distributed on an "AS IS" BASIS,
28594 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28595 * See the License for the specific language governing permissions and
28596 * limitations under the License.
28597 * =============================================================================
28598 */
28599 /**
28600 * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:
28601 * `acosh(x)`
28602 *
28603 * ```js
28604 * const x = tf.tensor1d([10, 1, 3, 5.7]);
28605 *
28606 * x.acosh().print(); // or tf.acosh(x)
28607 * ```
28608 * @param x The input tensor.
28609 *
28610 * @doc {heading: 'Operations', subheading: 'Basic math'}
28611 */
28612 function acosh_(x) {
28613 var $x = convertToTensor(x, 'x', 'acosh');
28614 var inputs = {
28615 x: $x
28616 };
28617 return ENGINE.runKernel(Acosh, inputs);
28618 }
28619 var acosh$2 = /* @__PURE__ */op({
28620 acosh_: acosh_
28621 });
28622
28623 /**
28624 * @license
28625 * Copyright 2020 Google LLC. All Rights Reserved.
28626 * Licensed under the Apache License, Version 2.0 (the "License");
28627 * you may not use this file except in compliance with the License.
28628 * You may obtain a copy of the License at
28629 *
28630 * http://www.apache.org/licenses/LICENSE-2.0
28631 *
28632 * Unless required by applicable law or agreed to in writing, software
28633 * distributed under the License is distributed on an "AS IS" BASIS,
28634 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28635 * See the License for the specific language governing permissions and
28636 * limitations under the License.
28637 * =============================================================================
28638 */
28639 /**
28640 * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.
28641 *
28642 * ```js
28643 * const a = tf.tensor1d([1, 2]);
28644 * const b = tf.tensor1d([3, 4]);
28645 * const c = tf.tensor1d([5, 6]);
28646 *
28647 * tf.addN([a, b, c]).print();
28648 * ```
28649 * @param tensors A list of tensors with the same shape and dtype.
28650 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
28651 */
28652 function addN_(tensors) {
28653 assert$1(Array.isArray(tensors), function () {
28654 return 'The argument passed to tf.addN() must be a list of tensors';
28655 });
28656 assert$1(tensors.length >= 1, function () {
28657 return "Must pass at least one tensor to tf.addN(), but got " + "".concat(tensors.length);
28658 });
28659 var $tensors = tensors.map(function (t, i) {
28660 return convertToTensor(t, "tensors".concat(i), 'addN');
28661 });
28662 var firstTensor = $tensors[0];
28663 $tensors.forEach(function (t) {
28664 if (t.dtype !== firstTensor.dtype) {
28665 throw new Error('All tensors passed to tf.addN() must have the same dtype');
28666 }
28667 });
28668 $tensors.forEach(function (t) {
28669 if (!arraysEqual(t.shape, firstTensor.shape)) {
28670 throw new Error('All tensors passed to tf.addN() must have the same shape');
28671 }
28672 });
28673 var inputs = $tensors;
28674 return ENGINE.runKernel(AddN, inputs);
28675 }
28676 var addN$2 = /* @__PURE__ */op({
28677 addN_: addN_
28678 });
28679
28680 /**
28681 * @license
28682 * Copyright 2020 Google LLC. All Rights Reserved.
28683 * Licensed under the Apache License, Version 2.0 (the "License");
28684 * you may not use this file except in compliance with the License.
28685 * You may obtain a copy of the License at
28686 *
28687 * http://www.apache.org/licenses/LICENSE-2.0
28688 *
28689 * Unless required by applicable law or agreed to in writing, software
28690 * distributed under the License is distributed on an "AS IS" BASIS,
28691 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28692 * See the License for the specific language governing permissions and
28693 * limitations under the License.
28694 * =============================================================================
28695 */
28696 /**
28697 * Computes the logical and of elements across dimensions of a `tf.Tensor`.
28698 *
28699 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
28700 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
28701 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
28702 * length 1. If `axes` has no entries, all dimensions are reduced, and a
28703 * `tf.Tensor` with a single element is returned.
28704 *
28705 * ```js
28706 * const x = tf.tensor1d([1, 1, 1], 'bool');
28707 *
28708 * x.all().print(); // or tf.all(x)
28709 * ```
28710 *
28711 * ```js
28712 * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
28713 *
28714 * const axis = 1;
28715 * x.all(axis).print(); // or tf.all(x, axis)
28716 * ```
28717 *
28718 * @param x The input tensor. Must be of dtype bool.
28719 * @param axis The dimension(s) to reduce. By default it reduces
28720 * all dimensions.
28721 * @param keepDims If true, retains reduced dimensions with size 1.
28722 *
28723 * @doc {heading: 'Operations', subheading: 'Reduction'}
28724 */
28725 function all_(x) {
28726 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : null;
28727 var keepDims = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
28728 var $x = convertToTensor(x, 'x', 'all', 'bool');
28729 var inputs = {
28730 x: $x
28731 };
28732 var attrs = {
28733 axis: axis,
28734 keepDims: keepDims
28735 };
28736 return ENGINE.runKernel(All, inputs, attrs);
28737 }
28738 var all$2 = /* @__PURE__ */op({
28739 all_: all_
28740 });
28741
28742 /**
28743 * @license
28744 * Copyright 2020 Google LLC. All Rights Reserved.
28745 * Licensed under the Apache License, Version 2.0 (the "License");
28746 * you may not use this file except in compliance with the License.
28747 * You may obtain a copy of the License at
28748 *
28749 * http://www.apache.org/licenses/LICENSE-2.0
28750 *
28751 * Unless required by applicable law or agreed to in writing, software
28752 * distributed under the License is distributed on an "AS IS" BASIS,
28753 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28754 * See the License for the specific language governing permissions and
28755 * limitations under the License.
28756 * =============================================================================
28757 */
28758 /**
28759 * Computes the logical or of elements across dimensions of a `tf.Tensor`.
28760 *
28761 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
28762 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
28763 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
28764 * length 1. If `axes` has no entries, all dimensions are reduced, and a
28765 * `tf.Tensor` with a single element is returned.
28766 *
28767 * ```js
28768 * const x = tf.tensor1d([1, 1, 1], 'bool');
28769 *
28770 * x.any().print(); // or tf.any(x)
28771 * ```
28772 *
28773 * ```js
28774 * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
28775 *
28776 * const axis = 1;
28777 * x.any(axis).print(); // or tf.any(x, axis)
28778 * ```
28779 *
28780 * @param x The input tensor. Must be of dtype bool.
28781 * @param axis The dimension(s) to reduce. By default it reduces
28782 * all dimensions.
28783 * @param keepDims If true, retains reduced dimensions with size 1.
28784 *
28785 * @doc {heading: 'Operations', subheading: 'Reduction'}
28786 */
28787 function any_(x) {
28788 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : null;
28789 var keepDims = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
28790 var $x = convertToTensor(x, 'x', 'any', 'bool');
28791 var inputs = {
28792 x: $x
28793 };
28794 var attrs = {
28795 axis: axis,
28796 keepDims: keepDims
28797 };
28798 return ENGINE.runKernel(Any, inputs, attrs);
28799 }
28800 // tslint:disable-next-line:variable-name
28801 var any$2 = /* @__PURE__ */op({
28802 any_: any_
28803 });
28804
28805 /**
28806 * @license
28807 * Copyright 2020 Google Inc. All Rights Reserved.
28808 * Licensed under the Apache License, Version 2.0 (the "License");
28809 * you may not use this file except in compliance with the License.
28810 * You may obtain a copy of the License at
28811 *
28812 * http://www.apache.org/licenses/LICENSE-2.0
28813 *
28814 * Unless required by applicable law or agreed to in writing, software
28815 * distributed under the License is distributed on an "AS IS" BASIS,
28816 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28817 * See the License for the specific language governing permissions and
28818 * limitations under the License.
28819 * =============================================================================
28820 */
28821 /**
28822 * Returns the indices of the maximum values along an `axis`.
28823 *
28824 * The result has the same shape as `input` with the dimension along `axis`
28825 * removed.
28826 *
28827 * ```js
28828 * const x = tf.tensor1d([1, 2, 3]);
28829 *
28830 * x.argMax().print(); // or tf.argMax(x)
28831 * ```
28832 *
28833 * ```js
28834 * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
28835 *
28836 * const axis = 1;
28837 * x.argMax(axis).print(); // or tf.argMax(x, axis)
28838 * ```
28839 *
28840 * @param x The input tensor.
28841 * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
28842 *
28843 * @doc {heading: 'Operations', subheading: 'Reduction'}
28844 */
28845 function argMax_(x) {
28846 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
28847 var $x = convertToTensor(x, 'x', 'argMax');
28848 var inputs = {
28849 x: $x
28850 };
28851 var attrs = {
28852 axis: axis
28853 };
28854 return ENGINE.runKernel(ArgMax, inputs, attrs);
28855 }
28856 var argMax$2 = /* @__PURE__ */op({
28857 argMax_: argMax_
28858 });
28859
28860 /**
28861 * @license
28862 * Copyright 2020 Google Inc. All Rights Reserved.
28863 * Licensed under the Apache License, Version 2.0 (the "License");
28864 * you may not use this file except in compliance with the License.
28865 * You may obtain a copy of the License at
28866 *
28867 * http://www.apache.org/licenses/LICENSE-2.0
28868 *
28869 * Unless required by applicable law or agreed to in writing, software
28870 * distributed under the License is distributed on an "AS IS" BASIS,
28871 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28872 * See the License for the specific language governing permissions and
28873 * limitations under the License.
28874 * =============================================================================
28875 */
28876 /**
28877 * Returns the indices of the minimum values along an `axis`.
28878 *
28879 * The result has the same shape as `input` with the dimension along `axis`
28880 * removed.
28881 *
28882 * ```js
28883 * const x = tf.tensor1d([1, 2, 3]);
28884 *
28885 * x.argMin().print(); // or tf.argMin(x)
28886 * ```
28887 *
28888 * ```js
28889 * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
28890 *
28891 * const axis = 1;
28892 * x.argMin(axis).print(); // or tf.argMin(x, axis)
28893 * ```
28894 *
28895 * @param x The input tensor.
28896 * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
28897 *
28898 * @doc {heading: 'Operations', subheading: 'Reduction'}
28899 */
28900 function argMin_(x) {
28901 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
28902 var $x = convertToTensor(x, 'x', 'argMin');
28903 var inputs = {
28904 x: $x
28905 };
28906 var attrs = {
28907 axis: axis
28908 };
28909 return ENGINE.runKernel(ArgMin, inputs, attrs);
28910 }
28911 var argMin$2 = /* @__PURE__ */op({
28912 argMin_: argMin_
28913 });
28914
28915 /**
28916 * @license
28917 * Copyright 2018 Google LLC. All Rights Reserved.
28918 * Licensed under the Apache License, Version 2.0 (the "License");
28919 * you may not use this file except in compliance with the License.
28920 * You may obtain a copy of the License at
28921 *
28922 * http://www.apache.org/licenses/LICENSE-2.0
28923 *
28924 * Unless required by applicable law or agreed to in writing, software
28925 * distributed under the License is distributed on an "AS IS" BASIS,
28926 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28927 * See the License for the specific language governing permissions and
28928 * limitations under the License.
28929 * =============================================================================
28930 */
28931 /**
28932 * Computes asin of the input `tf.Tensor` element-wise: `asin(x)`
28933 *
28934 * ```js
28935 * const x = tf.tensor1d([0, 1, -1, .7]);
28936 *
28937 * x.asin().print(); // or tf.asin(x)
28938 * ```
28939 * @param x The input tensor.
28940 * @doc {heading: 'Operations', subheading: 'Basic math'}
28941 */
28942 function asin_(x) {
28943 var $x = convertToTensor(x, 'x', 'asin');
28944 var inputs = {
28945 x: $x
28946 };
28947 return ENGINE.runKernel(Asin, inputs);
28948 }
28949 var asin$2 = /* @__PURE__ */op({
28950 asin_: asin_
28951 });
28952
28953 /**
28954 * @license
28955 * Copyright 2018 Google LLC. All Rights Reserved.
28956 * Licensed under the Apache License, Version 2.0 (the "License");
28957 * you may not use this file except in compliance with the License.
28958 * You may obtain a copy of the License at
28959 *
28960 * http://www.apache.org/licenses/LICENSE-2.0
28961 *
28962 * Unless required by applicable law or agreed to in writing, software
28963 * distributed under the License is distributed on an "AS IS" BASIS,
28964 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28965 * See the License for the specific language governing permissions and
28966 * limitations under the License.
28967 * =============================================================================
28968 */
28969 /**
28970 * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:
28971 * `asinh(x)`
28972 *
28973 * ```js
28974 * const x = tf.tensor1d([0, 1, -1, .7]);
28975 *
28976 * x.asinh().print(); // or tf.asinh(x)
28977 * ```
28978 * @param x The input tensor.
28979 *
28980 * @doc {heading: 'Operations', subheading: 'Basic math'}
28981 */
28982 function asinh_(x) {
28983 var $x = convertToTensor(x, 'x', 'asinh');
28984 var inputs = {
28985 x: $x
28986 };
28987 return ENGINE.runKernel(Asinh, inputs);
28988 }
28989 var asinh$2 = /* @__PURE__ */op({
28990 asinh_: asinh_
28991 });
28992
28993 /**
28994 * @license
28995 * Copyright 2018 Google LLC. All Rights Reserved.
28996 * Licensed under the Apache License, Version 2.0 (the "License");
28997 * you may not use this file except in compliance with the License.
28998 * You may obtain a copy of the License at
28999 *
29000 * http://www.apache.org/licenses/LICENSE-2.0
29001 *
29002 * Unless required by applicable law or agreed to in writing, software
29003 * distributed under the License is distributed on an "AS IS" BASIS,
29004 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29005 * See the License for the specific language governing permissions and
29006 * limitations under the License.
29007 * =============================================================================
29008 */
29009 /**
29010 * Computes atan of the input `tf.Tensor` element-wise: `atan(x)`
29011 *
29012 * ```js
29013 * const x = tf.tensor1d([0, 1, -1, .7]);
29014 *
29015 * x.atan().print(); // or tf.atan(x)
29016 * ```
29017 * @param x The input tensor.
29018 *
29019 * @doc {heading: 'Operations', subheading: 'Basic math'}
29020 */
29021 function atan_(x) {
29022 var $x = convertToTensor(x, 'x', 'atan');
29023 var inputs = {
29024 x: $x
29025 };
29026 return ENGINE.runKernel(Atan, inputs);
29027 }
29028 var atan$2 = /* @__PURE__ */op({
29029 atan_: atan_
29030 });
29031
29032 /**
29033 * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
29034 * Supports broadcasting.
29035 *
29036 * ```js
29037 * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
29038 * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
29039 *
29040 * tf.atan2(a, b).print()
29041 * ```
29042 *
29043 * @param a The first tensor.
29044 * @param b The second tensor. Must have the same dtype as `a`.
29045 *
29046 * @doc {heading: 'Operations', subheading: 'Basic math'}
29047 */
29048 function atan2_(a, b) {
29049 var $a = convertToTensor(a, 'a', 'atan2');
29050 var $b = convertToTensor(b, 'b', 'atan2');
29051 var _makeTypesMatch = makeTypesMatch($a, $b);
29052 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
29053 $a = _makeTypesMatch2[0];
29054 $b = _makeTypesMatch2[1];
29055 var inputs = {
29056 a: $a,
29057 b: $b
29058 };
29059 return ENGINE.runKernel(Atan2, inputs);
29060 }
29061 var atan2$2 = /* @__PURE__ */op({
29062 atan2_: atan2_
29063 });
29064
29065 /**
29066 * @license
29067 * Copyright 2018 Google LLC. All Rights Reserved.
29068 * Licensed under the Apache License, Version 2.0 (the "License");
29069 * you may not use this file except in compliance with the License.
29070 * You may obtain a copy of the License at
29071 *
29072 * http://www.apache.org/licenses/LICENSE-2.0
29073 *
29074 * Unless required by applicable law or agreed to in writing, software
29075 * distributed under the License is distributed on an "AS IS" BASIS,
29076 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29077 * See the License for the specific language governing permissions and
29078 * limitations under the License.
29079 * =============================================================================
29080 */
29081 /**
29082 * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
29083 * `atanh(x)`
29084 *
29085 * ```js
29086 * const x = tf.tensor1d([0, .1, -.1, .7]);
29087 *
29088 * x.atanh().print(); // or tf.atanh(x)
29089 * ```
29090 * @param x The input tensor.
29091 *
29092 * @doc {heading: 'Operations', subheading: 'Basic math'}
29093 */
29094 function atanh_(x) {
29095 var $x = convertToTensor(x, 'x', 'atanh');
29096 var inputs = {
29097 x: $x
29098 };
29099 return ENGINE.runKernel(Atanh, inputs);
29100 }
29101 var atanh$2 = /* @__PURE__ */op({
29102 atanh_: atanh_
29103 });
29104
29105 /**
29106 *
29107 * @param inputShape Input tensor shape is of the following dimensions:
29108 * `[batch, height, width, inChannels]`.
29109 * @param filterShape The filter shape is of the following dimensions:
29110 * `[filterHeight, filterWidth, depth]`.
29111 * @param strides The strides of the sliding window for each dimension of the
29112 * input tensor: `[strideHeight, strideWidth]`.
29113 * If `strides` is a single number,
29114 * then `strideHeight == strideWidth`.
29115 * @param pad The type of padding algorithm.
29116 * - `same` and stride 1: output will be of same size as input,
29117 * regardless of filter size.
29118 * - `valid`: output will be smaller than input if filter is larger
29119 * than 1*1x1.
29120 * - For more info, see this guide:
29121 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
29122 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
29123 * @param dataFormat The data format of the input and output data.
29124 * Defaults to 'NHWC'.
29125 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`.
29126 * Defaults to `[1, 1]`. If `dilations` is a single number, then
29127 * `dilationHeight == dilationWidth`.
29128 */
29129 function computeDilation2DInfo(inputShape, filterShape, strides, pad) {
29130 var dataFormat = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 'NHWC';
29131 var dilations = arguments.length > 5 ? arguments[5] : undefined;
29132 // `computerConv2DInfo` require filterShape to be in the dimension of:
29133 // `[filterHeight, filterWidth, depth, outDepth]`, dilation2d doesn't have
29134 // outDepth, it should have the same depth as the input.
29135 // Input shape: [batch, height, width, inChannels]
29136 var inputChannels = inputShape[3];
29137 var $filterShape = [].concat(_toConsumableArray(filterShape), [inputChannels]);
29138 var $dataFormat = convertConv2DDataFormat(dataFormat);
29139 return computeConv2DInfo(inputShape, $filterShape, strides, dilations, pad, null /* roundingMode */, null /* depthWise */, $dataFormat);
29140 }
29141 function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode) {
29142 var dataFormat = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : 'channelsLast';
29143 var _parseTupleParam = parseTupleParam(filterSize),
29144 _parseTupleParam2 = _slicedToArray(_parseTupleParam, 2),
29145 filterHeight = _parseTupleParam2[0],
29146 filterWidth = _parseTupleParam2[1];
29147 var filterShape;
29148 if (dataFormat === 'channelsLast') {
29149 filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
29150 } else if (dataFormat === 'channelsFirst') {
29151 filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
29152 } else {
29153 throw new Error("Unknown dataFormat ".concat(dataFormat));
29154 }
29155 return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
29156 }
29157 /**
29158 * Computes the information for a forward pass of a pooling3D operation.
29159 */
29160 function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode) {
29161 var dataFormat = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : 'NDHWC';
29162 var _parse3TupleParam = parse3TupleParam(filterSize),
29163 _parse3TupleParam2 = _slicedToArray(_parse3TupleParam, 3),
29164 filterDepth = _parse3TupleParam2[0],
29165 filterHeight = _parse3TupleParam2[1],
29166 filterWidth = _parse3TupleParam2[2];
29167 var filterShape;
29168 var $dataFormat;
29169 if (dataFormat === 'NDHWC') {
29170 $dataFormat = 'channelsLast';
29171 filterShape = [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];
29172 } else if (dataFormat === 'NCDHW') {
29173 $dataFormat = 'channelsFirst';
29174 filterShape = [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];
29175 } else {
29176 throw new Error("Unknown dataFormat ".concat(dataFormat));
29177 }
29178 return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
29179 }
29180 /**
29181 * Computes the information for a forward pass of a convolution/pooling
29182 * operation.
29183 */
29184 function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode) {
29185 var depthwise = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : false;
29186 var dataFormat = arguments.length > 7 && arguments[7] !== undefined ? arguments[7] : 'channelsLast';
29187 var batchSize = -1,
29188 inHeight = -1,
29189 inWidth = -1,
29190 inChannels = -1;
29191 if (dataFormat === 'channelsLast') {
29192 var _inShape = _slicedToArray(inShape, 4);
29193 batchSize = _inShape[0];
29194 inHeight = _inShape[1];
29195 inWidth = _inShape[2];
29196 inChannels = _inShape[3];
29197 } else if (dataFormat === 'channelsFirst') {
29198 var _inShape2 = _slicedToArray(inShape, 4);
29199 batchSize = _inShape2[0];
29200 inChannels = _inShape2[1];
29201 inHeight = _inShape2[2];
29202 inWidth = _inShape2[3];
29203 } else {
29204 throw new Error("Unknown dataFormat ".concat(dataFormat));
29205 }
29206 var _filterShape = _slicedToArray(filterShape, 4),
29207 filterHeight = _filterShape[0],
29208 filterWidth = _filterShape[1],
29209 filterChannels = _filterShape[3];
29210 var _parseTupleParam3 = parseTupleParam(strides),
29211 _parseTupleParam4 = _slicedToArray(_parseTupleParam3, 2),
29212 strideHeight = _parseTupleParam4[0],
29213 strideWidth = _parseTupleParam4[1];
29214 var _parseTupleParam5 = parseTupleParam(dilations),
29215 _parseTupleParam6 = _slicedToArray(_parseTupleParam5, 2),
29216 dilationHeight = _parseTupleParam6[0],
29217 dilationWidth = _parseTupleParam6[1];
29218 var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
29219 var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
29220 var _getPadAndOutInfo = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat),
29221 padInfo = _getPadAndOutInfo.padInfo,
29222 outHeight = _getPadAndOutInfo.outHeight,
29223 outWidth = _getPadAndOutInfo.outWidth;
29224 var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
29225 var outShape;
29226 if (dataFormat === 'channelsFirst') {
29227 outShape = [batchSize, outChannels, outHeight, outWidth];
29228 } else if (dataFormat === 'channelsLast') {
29229 outShape = [batchSize, outHeight, outWidth, outChannels];
29230 }
29231 return {
29232 batchSize: batchSize,
29233 dataFormat: dataFormat,
29234 inHeight: inHeight,
29235 inWidth: inWidth,
29236 inChannels: inChannels,
29237 outHeight: outHeight,
29238 outWidth: outWidth,
29239 outChannels: outChannels,
29240 padInfo: padInfo,
29241 strideHeight: strideHeight,
29242 strideWidth: strideWidth,
29243 filterHeight: filterHeight,
29244 filterWidth: filterWidth,
29245 effectiveFilterHeight: effectiveFilterHeight,
29246 effectiveFilterWidth: effectiveFilterWidth,
29247 dilationHeight: dilationHeight,
29248 dilationWidth: dilationWidth,
29249 inShape: inShape,
29250 outShape: outShape,
29251 filterShape: filterShape
29252 };
29253 }
29254 /**
29255 * Computes the information for a forward pass of a 3D convolution/pooling
29256 * operation.
29257 */
29258 function computeConv3DInfo(inShape, filterShape, strides, dilations, pad) {
29259 var depthwise = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : false;
29260 var dataFormat = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : 'channelsLast';
29261 var roundingMode = arguments.length > 7 ? arguments[7] : undefined;
29262 var batchSize = -1,
29263 inDepth = -1,
29264 inHeight = -1,
29265 inWidth = -1,
29266 inChannels = -1;
29267 if (dataFormat === 'channelsLast') {
29268 var _inShape3 = _slicedToArray(inShape, 5);
29269 batchSize = _inShape3[0];
29270 inDepth = _inShape3[1];
29271 inHeight = _inShape3[2];
29272 inWidth = _inShape3[3];
29273 inChannels = _inShape3[4];
29274 } else if (dataFormat === 'channelsFirst') {
29275 var _inShape4 = _slicedToArray(inShape, 5);
29276 batchSize = _inShape4[0];
29277 inChannels = _inShape4[1];
29278 inDepth = _inShape4[2];
29279 inHeight = _inShape4[3];
29280 inWidth = _inShape4[4];
29281 } else {
29282 throw new Error("Unknown dataFormat ".concat(dataFormat));
29283 }
29284 var _filterShape2 = _slicedToArray(filterShape, 5),
29285 filterDepth = _filterShape2[0],
29286 filterHeight = _filterShape2[1],
29287 filterWidth = _filterShape2[2],
29288 filterChannels = _filterShape2[4];
29289 var _parse3TupleParam3 = parse3TupleParam(strides),
29290 _parse3TupleParam4 = _slicedToArray(_parse3TupleParam3, 3),
29291 strideDepth = _parse3TupleParam4[0],
29292 strideHeight = _parse3TupleParam4[1],
29293 strideWidth = _parse3TupleParam4[2];
29294 var _parse3TupleParam5 = parse3TupleParam(dilations),
29295 _parse3TupleParam6 = _slicedToArray(_parse3TupleParam5, 3),
29296 dilationDepth = _parse3TupleParam6[0],
29297 dilationHeight = _parse3TupleParam6[1],
29298 dilationWidth = _parse3TupleParam6[2];
29299 var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
29300 var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
29301 var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
29302 var _get3DPadAndOutInfo = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode),
29303 padInfo = _get3DPadAndOutInfo.padInfo,
29304 outDepth = _get3DPadAndOutInfo.outDepth,
29305 outHeight = _get3DPadAndOutInfo.outHeight,
29306 outWidth = _get3DPadAndOutInfo.outWidth;
29307 var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
29308 var outShape;
29309 if (dataFormat === 'channelsFirst') {
29310 outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
29311 } else if (dataFormat === 'channelsLast') {
29312 outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
29313 }
29314 return {
29315 batchSize: batchSize,
29316 dataFormat: dataFormat,
29317 inDepth: inDepth,
29318 inHeight: inHeight,
29319 inWidth: inWidth,
29320 inChannels: inChannels,
29321 outDepth: outDepth,
29322 outHeight: outHeight,
29323 outWidth: outWidth,
29324 outChannels: outChannels,
29325 padInfo: padInfo,
29326 strideDepth: strideDepth,
29327 strideHeight: strideHeight,
29328 strideWidth: strideWidth,
29329 filterDepth: filterDepth,
29330 filterHeight: filterHeight,
29331 filterWidth: filterWidth,
29332 effectiveFilterDepth: effectiveFilterDepth,
29333 effectiveFilterHeight: effectiveFilterHeight,
29334 effectiveFilterWidth: effectiveFilterWidth,
29335 dilationDepth: dilationDepth,
29336 dilationHeight: dilationHeight,
29337 dilationWidth: dilationWidth,
29338 inShape: inShape,
29339 outShape: outShape,
29340 filterShape: filterShape
29341 };
29342 }
29343 function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
29344 if (zeroPad == null) {
29345 zeroPad = computeDefaultPad(inShape, fieldSize, stride);
29346 }
29347 var inputRows = inShape[0];
29348 var inputCols = inShape[1];
29349 var outputRows = round$3((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
29350 var outputCols = round$3((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
29351 return [outputRows, outputCols];
29352 }
29353 function computeOutputShape4D(inShape, filterShape, outChannels, strides, zeroPad, roundingMode) {
29354 if (zeroPad == null) {
29355 zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]);
29356 }
29357 var outShape = [0, 0, 0, outChannels];
29358 for (var index = 0; index < 3; index++) {
29359 if (inShape[index] + 2 * zeroPad >= filterShape[index]) {
29360 outShape[index] = round$3((inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] + 1, roundingMode);
29361 }
29362 }
29363 return outShape;
29364 }
29365 function computeDefaultPad(inputShape, fieldSize, stride) {
29366 var dilation = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 1;
29367 var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
29368 return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
29369 }
29370 function parseTupleParam(param) {
29371 if (typeof param === 'number') {
29372 return [param, param, param];
29373 }
29374 if (param.length === 2) {
29375 return [param[0], param[1], 1];
29376 }
29377 return param;
29378 }
29379 function parse3TupleParam(param) {
29380 return typeof param === 'number' ? [param, param, param] : param;
29381 }
29382 /* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
29383 * Atrous convolution is equivalent to standard convolution with upsampled
29384 * filters with effective_filter_height =
29385 * filter_height + (filter_height - 1) * (dilation - 1)
29386 * and effective_filter_width =
29387 * filter_width + (filter_width - 1) * (dilation - 1),
29388 * produced by inserting dilation - 1 zeros along consecutive elements across
29389 * the filters' spatial dimensions.
29390 * When there is a dilation, this converts a filter dimension to the
29391 * effective filter dimension, so it can be used in a standard convolution.
29392 */
29393 function getEffectiveFilterSize(filterSize, dilation) {
29394 if (dilation <= 1) {
29395 return filterSize;
29396 }
29397 return filterSize + (filterSize - 1) * (dilation - 1);
29398 }
29399 function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
29400 var padInfo;
29401 var outHeight;
29402 var outWidth;
29403 if (typeof pad === 'number') {
29404 var padType = pad === 0 ? 'VALID' : 'NUMBER';
29405 padInfo = {
29406 top: pad,
29407 bottom: pad,
29408 left: pad,
29409 right: pad,
29410 type: padType
29411 };
29412 var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
29413 outHeight = outShape[0];
29414 outWidth = outShape[1];
29415 } else if (pad === 'same') {
29416 outHeight = Math.ceil(inHeight / strideHeight);
29417 outWidth = Math.ceil(inWidth / strideWidth);
29418 var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
29419 var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
29420 var top = Math.floor(padAlongHeight / 2);
29421 var bottom = padAlongHeight - top;
29422 var left = Math.floor(padAlongWidth / 2);
29423 var right = padAlongWidth - left;
29424 padInfo = {
29425 top: top,
29426 bottom: bottom,
29427 left: left,
29428 right: right,
29429 type: 'SAME'
29430 };
29431 } else if (pad === 'valid') {
29432 padInfo = {
29433 top: 0,
29434 bottom: 0,
29435 left: 0,
29436 right: 0,
29437 type: 'VALID'
29438 };
29439 outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
29440 outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
29441 } else if (_typeof(pad) === 'object') {
29442 var _top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
29443 var _bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
29444 var _left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
29445 var _right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
29446 var _padType = _top === 0 && _bottom === 0 && _left === 0 && _right === 0 ? 'VALID' : 'EXPLICIT';
29447 padInfo = {
29448 top: _top,
29449 bottom: _bottom,
29450 left: _left,
29451 right: _right,
29452 type: _padType
29453 };
29454 outHeight = round$3((inHeight - filterHeight + _top + _bottom) / strideHeight + 1, roundingMode);
29455 outWidth = round$3((inWidth - filterWidth + _left + _right) / strideWidth + 1, roundingMode);
29456 } else {
29457 throw Error("Unknown padding parameter: ".concat(pad));
29458 }
29459 return {
29460 padInfo: padInfo,
29461 outHeight: outHeight,
29462 outWidth: outWidth
29463 };
29464 }
29465 function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
29466 var padInfo;
29467 var outDepth;
29468 var outHeight;
29469 var outWidth;
29470 if (pad === 'valid') {
29471 pad = 0;
29472 }
29473 if (typeof pad === 'number') {
29474 var padType = pad === 0 ? 'VALID' : 'NUMBER';
29475 padInfo = {
29476 top: pad,
29477 bottom: pad,
29478 left: pad,
29479 right: pad,
29480 front: pad,
29481 back: pad,
29482 type: padType
29483 };
29484 var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], [filterDepth, filterHeight, filterWidth], 1, [strideDepth, strideHeight, strideWidth], pad, roundingMode);
29485 outDepth = outShape[0];
29486 outHeight = outShape[1];
29487 outWidth = outShape[2];
29488 } else if (pad === 'same') {
29489 outDepth = Math.ceil(inDepth / strideDepth);
29490 outHeight = Math.ceil(inHeight / strideHeight);
29491 outWidth = Math.ceil(inWidth / strideWidth);
29492 var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
29493 var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
29494 var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
29495 var front = Math.floor(padAlongDepth / 2);
29496 var back = padAlongDepth - front;
29497 var top = Math.floor(padAlongHeight / 2);
29498 var bottom = padAlongHeight - top;
29499 var left = Math.floor(padAlongWidth / 2);
29500 var right = padAlongWidth - left;
29501 padInfo = {
29502 top: top,
29503 bottom: bottom,
29504 left: left,
29505 right: right,
29506 front: front,
29507 back: back,
29508 type: 'SAME'
29509 };
29510 } else {
29511 throw Error("Unknown padding parameter: ".concat(pad));
29512 }
29513 return {
29514 padInfo: padInfo,
29515 outDepth: outDepth,
29516 outHeight: outHeight,
29517 outWidth: outWidth
29518 };
29519 }
29520 /**
29521 * Rounds a value depending on the rounding mode
29522 * @param value
29523 * @param roundingMode A string from: 'ceil', 'round', 'floor'. If none is
29524 * provided, it will default to truncate.
29525 */
29526 function round$3(value, roundingMode) {
29527 if (!roundingMode) {
29528 return Math.trunc(value);
29529 }
29530 switch (roundingMode) {
29531 case 'round':
29532 // used for Caffe Conv
29533 return Math.round(value);
29534 case 'ceil':
29535 // used for Caffe Pool
29536 return Math.ceil(value);
29537 case 'floor':
29538 return Math.floor(value);
29539 default:
29540 throw new Error("Unknown roundingMode ".concat(roundingMode));
29541 }
29542 }
29543 function tupleValuesAreOne(param) {
29544 var _parseTupleParam7 = parseTupleParam(param),
29545 _parseTupleParam8 = _slicedToArray(_parseTupleParam7, 3),
29546 dimA = _parseTupleParam8[0],
29547 dimB = _parseTupleParam8[1],
29548 dimC = _parseTupleParam8[2];
29549 return dimA === 1 && dimB === 1 && dimC === 1;
29550 }
29551 function eitherStridesOrDilationsAreOne(strides, dilations) {
29552 return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
29553 }
29554 function stridesOrDilationsArePositive(values) {
29555 return parseTupleParam(values).every(function (value) {
29556 return value > 0;
29557 });
29558 }
29559 /**
29560 * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
29561 * 'channelsLast'|'channelsFirst'
29562 * @param dataFormat in 'NHWC'|'NCHW' mode
29563 * @return dataFormat in 'channelsLast'|'channelsFirst' mode
29564 * @throws unknown dataFormat
29565 */
29566 function convertConv2DDataFormat(dataFormat) {
29567 if (dataFormat === 'NHWC') {
29568 return 'channelsLast';
29569 } else if (dataFormat === 'NCHW') {
29570 return 'channelsFirst';
29571 } else {
29572 throw new Error("Unknown dataFormat ".concat(dataFormat));
29573 }
29574 }
29575 /**
29576 * Check validity of pad when using dimRoundingMode.
29577 * @param opDesc A string of op description
29578 * @param pad The type of padding algorithm.
29579 * - `same` and stride 1: output will be of same size as input,
29580 * regardless of filter size.
29581 * - `valid` output will be smaller than input if filter is larger
29582 * than 1x1.
29583 * - For more info, see this guide:
29584 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
29585 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
29586 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
29587 * provided, it will default to truncate.
29588 * @throws unknown padding parameter
29589 */
29590 function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) {
29591 if (dimRoundingMode != null) {
29592 if (typeof pad === 'string') {
29593 throw Error("Error in ".concat(opDesc, ": pad must be an integer when using ") + "dimRoundingMode ".concat(dimRoundingMode, " but got pad ").concat(pad, "."));
29594 } else if (typeof pad === 'number') {
29595 assert$1(isInt(pad), function () {
29596 return "Error in ".concat(opDesc, ": pad must be an integer when using ") + "dimRoundingMode ".concat(dimRoundingMode, " but got pad ").concat(pad, ".");
29597 });
29598 } else if (_typeof(pad) === 'object') {
29599 pad.forEach(function (p) {
29600 p.forEach(function (v) {
29601 assert$1(isInt(v), function () {
29602 return "Error in ".concat(opDesc, ": pad must be an integer when using ") + "dimRoundingMode ".concat(dimRoundingMode, " but got pad ").concat(v, ".");
29603 });
29604 });
29605 });
29606 } else {
29607 throw Error("Error in ".concat(opDesc, ": Unknown padding parameter: ").concat(pad));
29608 }
29609 }
29610 }
29611
29612 /**
29613 * @license
29614 * Copyright 2020 Google LLC. All Rights Reserved.
29615 * Licensed under the Apache License, Version 2.0 (the "License");
29616 * you may not use this file except in compliance with the License.
29617 * You may obtain a copy of the License at
29618 *
29619 * http://www.apache.org/licenses/LICENSE-2.0
29620 *
29621 * Unless required by applicable law or agreed to in writing, software
29622 * distributed under the License is distributed on an "AS IS" BASIS,
29623 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29624 * See the License for the specific language governing permissions and
29625 * limitations under the License.
29626 * =============================================================================
29627 */
29628 /**
29629 * Reshapes a `tf.Tensor` to a given shape.
29630 *
29631 * Given an input tensor, returns a new tensor with the same values as the
29632 * input tensor with shape `shape`.
29633 *
29634 * If one component of shape is the special value -1, the size of that
29635 * dimension is computed so that the total size remains constant. In
29636 * particular, a shape of [-1] flattens into 1-D. At most one component of
29637 * shape can be -1.
29638 *
29639 * If shape is 1-D or higher, then the operation returns a tensor with shape
29640 * shape filled with the values of tensor. In this case, the number of
29641 * elements implied by shape must be the same as the number of elements in
29642 * tensor.
29643 *
29644 * ```js
29645 * const x = tf.tensor1d([1, 2, 3, 4]);
29646 * x.reshape([2, 2]).print();
29647 * ```
29648 *
29649 * @param x The input tensor to be reshaped.
29650 * @param shape An array of integers defining the output tensor shape.
29651 *
29652 * @doc {heading: 'Tensors', subheading: 'Transformations'}
29653 */
29654 function reshape_(x, shape) {
29655 var $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric');
29656 var inputs = {
29657 x: $x
29658 };
29659 var attrs = {
29660 shape: shape
29661 };
29662 return ENGINE.runKernel(Reshape$1, inputs, attrs);
29663 }
29664 var reshape$3 = /* @__PURE__ */op({
29665 reshape_: reshape_
29666 });
29667
29668 /**
29669 * @license
29670 * Copyright 2020 Google LLC. All Rights Reserved.
29671 * Licensed under the Apache License, Version 2.0 (the "License");
29672 * you may not use this file except in compliance with the License.
29673 * You may obtain a copy of the License at
29674 *
29675 * http://www.apache.org/licenses/LICENSE-2.0
29676 *
29677 * Unless required by applicable law or agreed to in writing, software
29678 * distributed under the License is distributed on an "AS IS" BASIS,
29679 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29680 * See the License for the specific language governing permissions and
29681 * limitations under the License.
29682 * =============================================================================
29683 */
29684 /**
29685 * Computes the 2D average pooling of an image.
29686 *
29687 * @param x The input tensor, of rank 4 or rank 3 of shape
29688 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
29689 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
29690 * `filterSize` is a single number, then `filterHeight == filterWidth`.
29691 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
29692 * `strides` is a single number, then `strideHeight == strideWidth`.
29693 * @param pad The type of padding algorithm:
29694 * - `same` and stride 1: output will be of same size as input,
29695 * regardless of filter size.
29696 * - `valid`: output will be smaller than input if filter is larger
29697 * than 1x1.
29698 * - For more info, see this guide:
29699 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
29700 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
29701 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
29702 * provided, it will default to truncate.
29703 *
29704 * @doc {heading: 'Operations', subheading: 'Convolution'}
29705 */
29706 function avgPool_(x, filterSize, strides, pad, dimRoundingMode) {
29707 var $x = convertToTensor(x, 'x', 'avgPool', 'float32');
29708 var dilations = 1;
29709 assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
29710 return 'Error in avgPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
29711 });
29712 var x4D = $x;
29713 var reshapedTo4D = false;
29714 if ($x.rank === 3) {
29715 reshapedTo4D = true;
29716 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
29717 }
29718 assert$1(x4D.rank === 4, function () {
29719 return "Error in avgPool: x must be rank 4 but got rank ".concat(x4D.rank, ".");
29720 });
29721 checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
29722 var inputs = {
29723 x: x4D
29724 };
29725 var attrs = {
29726 filterSize: filterSize,
29727 strides: strides,
29728 pad: pad,
29729 dimRoundingMode: dimRoundingMode
29730 };
29731 // tslint:disable-next-line: no-unnecessary-type-assertion
29732 var res = ENGINE.runKernel(AvgPool, inputs, attrs);
29733 res = cast$3(res, $x.dtype);
29734 if (reshapedTo4D) {
29735 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
29736 }
29737 return res;
29738 }
29739 var avgPool$2 = /* @__PURE__ */op({
29740 avgPool_: avgPool_
29741 });
29742
29743 /**
29744 * @license
29745 * Copyright 2020 Google LLC. All Rights Reserved.
29746 * Licensed under the Apache License, Version 2.0 (the "License");
29747 * you may not use this file except in compliance with the License.
29748 * You may obtain a copy of the License at
29749 *
29750 * http://www.apache.org/licenses/LICENSE-2.0
29751 *
29752 * Unless required by applicable law or agreed to in writing, software
29753 * distributed under the License is distributed on an "AS IS" BASIS,
29754 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29755 * See the License for the specific language governing permissions and
29756 * limitations under the License.
29757 * =============================================================================
29758 */
29759 /**
29760 * Computes the 3D average pooling.
29761 *
29762 * ```js
29763 * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
29764 * const result = tf.avgPool3d(x, 2, 1, 'valid');
29765 * result.print();
29766 * ```
29767 *
29768 * @param x The input tensor, of rank 5 or rank 4 of shape
29769 * `[batch, depth, height, width, inChannels]`.
29770 * @param filterSize The filter size:
29771 * `[filterDepth, filterHeight, filterWidth]`.
29772 * If `filterSize` is a single number,
29773 * then `filterDepth == filterHeight == filterWidth`.
29774 * @param strides The strides of the pooling:
29775 * `[strideDepth, strideHeight, strideWidth]`.
29776 * If `strides` is a single number,
29777 * then `strideDepth == strideHeight == strideWidth`.
29778 * @param pad The type of padding algorithm.
29779 * - `same` and stride 1: output will be of same size as input,
29780 * regardless of filter size.
29781 * - `valid`: output will be smaller than input if filter is larger
29782 * than 1*1x1.
29783 * - For more info, see this guide:
29784 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
29785 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
29786 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
29787 * provided, it will default to truncate.
29788 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
29789 * "NDHWC". Specify the data format of the input and output data. With the
29790 * default format "NDHWC", the data is stored in the order of: [batch,
29791 * depth, height, width, channels]. Only "NDHWC" is currently supported.
29792 *
29793 * @doc {heading: 'Operations', subheading: 'Convolution'}
29794 */
29795 function avgPool3d_(x, filterSize, strides, pad, dimRoundingMode) {
29796 var dataFormat = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 'NDHWC';
29797 var $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');
29798 var x5D = $x;
29799 var reshapedTo5D = false;
29800 if ($x.rank === 4) {
29801 reshapedTo5D = true;
29802 x5D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
29803 }
29804 assert$1(x5D.rank === 5, function () {
29805 return "Error in avgPool3d: x must be rank 5 but got rank ".concat(x5D.rank, ".");
29806 });
29807 assert$1(dataFormat === 'NDHWC', function () {
29808 return "Error in avgPool3d: Only NDHWC is currently supported, " + "but got dataFormat of ".concat(dataFormat);
29809 });
29810 assert$1(typeof strides === 'number' && strides > 0 || Array.isArray(strides) && strides[0] > 0 && strides[1] > 0 && strides[2] > 0, function () {
29811 return "Error in avgPool3d: Stride must be > 0, but got '".concat(strides, "'");
29812 });
29813 checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
29814 var inputs = {
29815 x: x5D
29816 };
29817 var attrs = {
29818 filterSize: filterSize,
29819 strides: strides,
29820 pad: pad,
29821 dimRoundingMode: dimRoundingMode,
29822 dataFormat: dataFormat
29823 };
29824 // tslint:disable-next-line: no-unnecessary-type-assertion
29825 var res = ENGINE.runKernel(AvgPool3D, inputs, attrs);
29826 res = cast$3(res, x5D.dtype);
29827 if (reshapedTo5D) {
29828 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
29829 }
29830 return res;
29831 }
29832 var avgPool3d$1 = /* @__PURE__ */op({
29833 avgPool3d_: avgPool3d_
29834 });
29835
29836 /**
29837 * @license
29838 * Copyright 2020 Google LLC. All Rights Reserved.
29839 * Licensed under the Apache License, Version 2.0 (the "License");
29840 * you may not use this file except in compliance with the License.
29841 * You may obtain a copy of the License at
29842 *
29843 * http://www.apache.org/licenses/LICENSE-2.0
29844 *
29845 * Unless required by applicable law or agreed to in writing, software
29846 * distributed under the License is distributed on an "AS IS" BASIS,
29847 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29848 * See the License for the specific language governing permissions and
29849 * limitations under the License.
29850 * =============================================================================
29851 */
/**
 * Concatenates a list of `tf.Tensor`s along a given axis.
 *
 * The tensors' ranks and types must match, and their sizes must match in all
 * dimensions except `axis`.
 *
 * Also available are stricter rank-specific methods that assert that
 * `tensors` are of the given rank:
 *   - `tf.concat1d`
 *   - `tf.concat2d`
 *   - `tf.concat3d`
 *   - `tf.concat4d`
 *
 * Except `tf.concat1d` (which does not have axis param), all methods have
 * same signature as this method.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * a.concat(b).print();  // or tf.concat([a, b])
 * ```
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * const c = tf.tensor1d([5, 6]);
 * tf.concat([a, b, c]).print();
 * ```
 *
 * ```js
 * const a = tf.tensor2d([[1, 2], [10, 20]]);
 * const b = tf.tensor2d([[3, 4], [30, 40]]);
 * const axis = 1;
 * tf.concat([a, b], axis).print();
 * ```
 * @param tensors A non-empty list of tensors to concatenate.
 * @param axis The axis to concatenate along. Defaults to 0 (the first dim).
 * @returns The concatenated tensor. When exactly one tensor is given, a clone
 *     of it is returned and the Concat kernel is not run.
 * @throws Error if `tensors` is empty, or if complex64 tensors are mixed with
 *     tensors of any other dtype, wherever the complex64 tensor appears in
 *     the list.
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
function concat_(tensors) {
  // Transpiled default parameter: `axis = 0`.
  var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
  assert$1(tensors.length >= 1, function () {
    return 'Pass at least one tensor to concat';
  });
  var $tensors = convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric');
  // complex64 may not be mixed with other dtypes. Inspect every input rather
  // than only the first one; otherwise e.g. [float32, complex64] would skip
  // this check and be handed straight to the Concat kernel.
  var hasComplex = $tensors.some(function (tensor) {
    return tensor.dtype === 'complex64';
  });
  if (hasComplex) {
    $tensors.forEach(function (tensor) {
      if (tensor.dtype !== 'complex64') {
        throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype ".concat(tensor.dtype, ". "));
      }
    });
  }
  // A single input needs no kernel call; return an independent copy.
  if ($tensors.length === 1) {
    return clone($tensors[0]);
  }
  return ENGINE.runKernel(Concat, $tensors, {
    axis: axis
  });
}
var concat$2 = /* @__PURE__ */op({
  concat_: concat_
});
29917
/**
 * Multiplies two matrices, `a * b`; both operands must be matrices. The work
 * is delegated to the BatchMatMul kernel after the operands have been given
 * a common dtype.
 *
 * ```js
 * const a = tf.tensor2d([1, 2], [1, 2]);
 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * a.matMul(b).print();  // or tf.matMul(a, b)
 * ```
 * @param a Left-hand matrix.
 * @param b Right-hand matrix.
 * @param transposeA When true, `a` is transposed before the product.
 *     Defaults to false.
 * @param transposeB When true, `b` is transposed before the product.
 *     Defaults to false.
 *
 * @doc {heading: 'Operations', subheading: 'Matrices'}
 */
function matMul_(a, b) {
  // Optional trailing flags (transpiled defaults): undefined means false.
  var transposeA = false;
  if (arguments.length > 2 && arguments[2] !== undefined) {
    transposeA = arguments[2];
  }
  var transposeB = false;
  if (arguments.length > 3 && arguments[3] !== undefined) {
    transposeB = arguments[3];
  }
  var lhs = convertToTensor(a, 'a', 'matMul');
  var rhs = convertToTensor(b, 'b', 'matMul');
  var matched = _slicedToArray(makeTypesMatch(lhs, rhs), 2);
  return ENGINE.runKernel(BatchMatMul, {
    a: matched[0],
    b: matched[1]
  }, {
    transposeA: transposeA,
    transposeB: transposeB
  });
}
var matMul$1 = /* @__PURE__ */op({
  matMul_: matMul_
});
29956
29957 /**
29958 * @license
29959 * Copyright 2018 Google LLC. All Rights Reserved.
29960 * Licensed under the Apache License, Version 2.0 (the "License");
29961 * you may not use this file except in compliance with the License.
29962 * You may obtain a copy of the License at
29963 *
29964 * http://www.apache.org/licenses/LICENSE-2.0
29965 *
29966 * Unless required by applicable law or agreed to in writing, software
29967 * distributed under the License is distributed on an "AS IS" BASIS,
29968 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29969 * See the License for the specific language governing permissions and
29970 * limitations under the License.
29971 * =============================================================================
29972 */
/**
 * Applies the logistic sigmoid, `1 / (1 + exp(-x))`, to every element.
 *
 * ```js
 * const x = tf.tensor1d([0, -1, 2, -3]);
 *
 * x.sigmoid().print();  // or tf.sigmoid(x)
 * ```
 * @param x The input tensor (converted with a 'float32' dtype hint).
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function sigmoid_(x) {
  return ENGINE.runKernel(Sigmoid$1, {
    x: convertToTensor(x, 'x', 'sigmoid', 'float32')
  });
}
var sigmoid$2 = /* @__PURE__ */op({
  sigmoid_: sigmoid_
});
29995
29996 /**
29997 * @license
29998 * Copyright 2018 Google LLC. All Rights Reserved.
29999 * Licensed under the Apache License, Version 2.0 (the "License");
30000 * you may not use this file except in compliance with the License.
30001 * You may obtain a copy of the License at
30002 *
30003 * http://www.apache.org/licenses/LICENSE-2.0
30004 *
30005 * Unless required by applicable law or agreed to in writing, software
30006 * distributed under the License is distributed on an "AS IS" BASIS,
30007 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30008 * See the License for the specific language governing permissions and
30009 * limitations under the License.
30010 * =============================================================================
30011 */
/**
 * Extracts the sub-tensor of `x` that starts at coordinates `begin` and has
 * extent `size`.
 *
 * Stricter rank-specific variants with the same signature assert that `x`
 * has the given rank:
 *   - `tf.slice1d`
 *   - `tf.slice2d`
 *   - `tf.slice3d`
 *   - `tf.slice4d`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 *
 * x.slice([1], [2]).print();
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * x.slice([1, 0], [1, 2]).print();
 * ```
 * @param x The `tf.Tensor` to slice. Must not be a scalar.
 * @param begin Start coordinates. May be shorter than the rank of `x`, in
 *     which case the remaining axes start at 0. A single number specifies
 *     the start along the first axis only.
 * @param size Extent of the slice. May be shorter than the rank of `x`, in
 *     which case the remaining axes use -1. A value of -1 takes everything
 *     that remains along that axis. A single number specifies the extent
 *     along the first axis only.
 * @throws Error when `x` is rank 0.
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
function slice_(x, begin, size) {
  var source = convertToTensor(x, 'x', 'slice', 'string_or_numeric');
  if (source.rank !== 0) {
    return ENGINE.runKernel(Slice, {
      x: source
    }, {
      begin: begin,
      size: size
    });
  }
  throw new Error('Slicing scalar is not possible');
}
var slice$2 = /* @__PURE__ */op({
  slice_: slice_
});
30063
30064 /**
30065 * @license
30066 * Copyright 2018 Google LLC. All Rights Reserved.
30067 * Licensed under the Apache License, Version 2.0 (the "License");
30068 * you may not use this file except in compliance with the License.
30069 * You may obtain a copy of the License at
30070 *
30071 * http://www.apache.org/licenses/LICENSE-2.0
30072 *
30073 * Unless required by applicable law or agreed to in writing, software
30074 * distributed under the License is distributed on an "AS IS" BASIS,
30075 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30076 * See the License for the specific language governing permissions and
30077 * limitations under the License.
30078 * =============================================================================
30079 */
/**
 * Applies the hyperbolic tangent, `tanh(x)`, to every element of a tensor.
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, 70]);
 *
 * x.tanh().print();  // or tf.tanh(x)
 * ```
 * @param x The input tensor (converted with a 'float32' dtype hint).
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function tanh_(x) {
  return ENGINE.runKernel(Tanh$1, {
    x: convertToTensor(x, 'x', 'tanh', 'float32')
  });
}
var tanh$2 = /* @__PURE__ */op({
  tanh_: tanh_
});
30102
30103 /**
30104 * @license
30105 * Copyright 2020 Google LLC. All Rights Reserved.
30106 * Licensed under the Apache License, Version 2.0 (the "License");
30107 * you may not use this file except in compliance with the License.
30108 * You may obtain a copy of the License at
30109 *
30110 * http://www.apache.org/licenses/LICENSE-2.0
30111 *
30112 * Unless required by applicable law or agreed to in writing, software
30113 * distributed under the License is distributed on an "AS IS" BASIS,
30114 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30115 * See the License for the specific language governing permissions and
30116 * limitations under the License.
30117 * =============================================================================
30118 */
/**
 * Runs one step of a BasicLSTMCell and returns `[newC, newH]`.
 *
 * Derived from tf.contrib.rnn.BasicLSTMCell.
 *
 * `data` and `h` are concatenated along axis 1, multiplied by `lstmKernel`,
 * and offset by `lstmBias`. The resulting columns are split into four
 * equal-width blocks, in this order: input gate (i), new input (j), forget
 * gate (f), and output gate (o). Then:
 *
 *   newC = sigmoid(i) * tanh(j) + c * sigmoid(forgetBias + f)
 *   newH = tanh(newC) * sigmoid(o)
 *
 * NOTE: the column count of the pre-activation is assumed to be a multiple
 * of 4; this is not validated here.
 *
 * @param forgetBias Forget bias for the cell.
 * @param lstmKernel The weights for the cell.
 * @param lstmBias The bias for the cell.
 * @param data The input to the cell.
 * @param c Previous cell state.
 * @param h Previous cell output.
 * @returns `[newC, newH]`: the new cell state and the new cell output.
 *
 * @doc {heading: 'Operations', subheading: 'RNN'}
 */
function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
  var opName = 'basicLSTMCell';
  var $forgetBias = convertToTensor(forgetBias, 'forgetBias', opName);
  var $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', opName);
  var $lstmBias = convertToTensor(lstmBias, 'lstmBias', opName);
  var $data = convertToTensor(data, 'data', opName);
  var $c = convertToTensor(c, 'c', opName);
  var $h = convertToTensor(h, 'h', opName);

  var preActivation = add$3(matMul$1(concat$2([$data, $h], 1), $lstmKernel), $lstmBias);
  var rows = preActivation.shape[0];
  var gateWidth = preActivation.shape[1] / 4;
  var gateSize = [rows, gateWidth];
  // One slice per gate, taken left to right: i, j, f, o.
  var gates = [0, gateWidth, gateWidth * 2, gateWidth * 3].map(function (colStart) {
    return slice$2(preActivation, [0, colStart], gateSize);
  });
  var inputGate = gates[0];
  var newInput = gates[1];
  var forgetGate = gates[2];
  var outputGate = gates[3];

  var admitted = mul(sigmoid$2(inputGate), tanh$2(newInput));
  var retained = mul($c, sigmoid$2(add$3($forgetBias, forgetGate)));
  var newC = add$3(admitted, retained);
  var newH = mul(tanh$2(newC), sigmoid$2(outputGate));
  return [newC, newH];
}
var basicLSTMCell = /* @__PURE__ */op({
  basicLSTMCell_: basicLSTMCell_
});
30160
30161 /**
30162 * @license
30163 * Copyright 2020 Google LLC. All Rights Reserved.
30164 * Licensed under the Apache License, Version 2.0 (the "License");
30165 * you may not use this file except in compliance with the License.
30166 * You may obtain a copy of the License at
30167 *
30168 * http://www.apache.org/licenses/LICENSE-2.0
30169 *
30170 * Unless required by applicable law or agreed to in writing, software
30171 * distributed under the License is distributed on an "AS IS" BASIS,
30172 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30173 * See the License for the specific language governing permissions and
30174 * limitations under the License.
30175 * =============================================================================
30176 */
/**
 * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
 * shape `blockShape + [batch]`, then interleaves these blocks back into the
 * grid defined by the spatial dimensions `[1, ..., M]`. The result has the
 * same rank as the input. The spatial dimensions of this intermediate result
 * are then optionally cropped according to `crops` to produce the output.
 * This is the reverse of `tf.spaceToBatchND`. See below for a precise
 * description.
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
 * const blockShape = [2, 2];
 * const crops = [[0, 0], [0, 0]];
 *
 * x.batchToSpaceND(blockShape, crops).print();
 * ```
 *
 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
 *     remainingShape`, where spatialShape has `M` dimensions.
 * @param blockShape A non-empty 1-D array with shape `[M]`. All values must
 *     be >= 1.
 * @param crops A 2-D array with shape `[M, 2]`; all values must be >= 0.
 *     `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from
 *     input dimension `i + 1`, which corresponds to spatial dimension `i`. It
 *     is required that
 *     `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`.
 *
 * This operation is equivalent to the following steps:
 *
 * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
 *    blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
 *    x.shape[N-1]]`
 *
 * 2. Permute dimensions of `reshaped` to produce `permuted` of shape
 *    `[batch / prod(blockShape), x.shape[1], blockShape[0], ..., x.shape[M],
 *    blockShape[M-1], x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 3. Reshape `permuted` to produce `reshapedPermuted` of shape
 *    `[batch / prod(blockShape), x.shape[1] * blockShape[0], ...,
 *    x.shape[M] * blockShape[M-1], x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
 *    according to `crops` to produce the output of shape:
 *    `[batch / prod(blockShape), x.shape[1] * blockShape[0] - crops[0,0] -
 *    crops[0,1], ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
 *    crops[M-1,1], x.shape[M+1], ..., x.shape[N-1]]`
 *
 * @throws Error if `blockShape` is empty or contains a value < 1, if the rank
 *     of `x` is too small for `blockShape`, if `crops.length` differs from
 *     `blockShape.length`, or if the batch size is not divisible by
 *     `prod(blockShape)`.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function batchToSpaceND_(x, blockShape, crops) {
  var $x = convertToTensor(x, 'x', 'batchToSpaceND');
  // Validate blockShape before reducing it. Array#reduce without an initial
  // value throws an unhelpful TypeError on an empty array. Zero or negative
  // block sizes violate the documented ">= 1" contract; a pair like [-2, -2]
  // would otherwise pass the divisibility check below.
  assert$1(blockShape.length >= 1, function () {
    return "blockShape must have at least one element, but got blockShape.length ".concat(blockShape.length);
  });
  assert$1(blockShape.every(function (b) {
    return b >= 1;
  }), function () {
    return "blockShape values must all be >= 1, but got [".concat(blockShape.join(', '), "]");
  });
  var prod = blockShape.reduce(function (a, b) {
    return a * b;
  }, 1);
  assert$1($x.rank >= 1 + blockShape.length, function () {
    return "input rank is ".concat($x.rank, " but should be > than blockShape.length ").concat(blockShape.length);
  });
  assert$1(crops.length === blockShape.length, function () {
    return "crops.length is ".concat(crops.length, " but should be equal to blockShape.length ").concat(blockShape.length);
  });
  assert$1($x.shape[0] % prod === 0, function () {
    return "input tensor batch is ".concat($x.shape[0], " but is not divisible by the product of ") + "the elements of blockShape ".concat(blockShape.join(' * '), " === ").concat(prod);
  });
  var inputs = {
    x: $x
  };
  var attrs = {
    blockShape: blockShape,
    crops: crops
  };
  return ENGINE.runKernel(BatchToSpaceND, inputs, attrs);
}
var batchToSpaceND$2 = /* @__PURE__ */op({
  batchToSpaceND_: batchToSpaceND_
});
30251
/**
 * Gives `x` a rank-4 view by prepending size-1 dimensions:
 *   rank 0 or 1 -> [1, 1, 1, x.size]
 *   rank 2      -> [1, 1, d0, d1]
 *   rank 3      -> [1, d0, d1, d2]
 * Any other rank (including rank 4 and above) is returned as-is.
 */
function xAs4D(x) {
  switch (x.rank) {
    case 0:
    case 1:
      return reshape$3(x, [1, 1, 1, x.size]);
    case 2:
      return reshape$3(x, [1, 1, x.shape[0], x.shape[1]]);
    case 3:
      return reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
    default:
      return x;
  }
}
30265
30266 /**
30267 * @license
30268 * Copyright 2020 Google LLC. All Rights Reserved.
30269 * Licensed under the Apache License, Version 2.0 (the "License");
30270 * you may not use this file except in compliance with the License.
30271 * You may obtain a copy of the License at
30272 *
30273 * http://www.apache.org/licenses/LICENSE-2.0
30274 *
30275 * Unless required by applicable law or agreed to in writing, software
30276 * distributed under the License is distributed on an "AS IS" BASIS,
30277 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30278 * See the License for the specific language governing permissions and
30279 * limitations under the License.
30280 * =============================================================================
30281 */
/**
 * Batch normalization.
 *
 * As described in
 * [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167).
 *
 * Mean, variance, scale, and offset can be of two shapes:
 *   - The same shape as the input.
 *   - In the common case, the depth dimension is the last dimension of x, so
 *     the values would be a `tf.Tensor1D` of shape [depth].
 *
 * Stricter rank-specific variants with the same signature assert that the
 * parameters passed have the given rank:
 *   - `tf.batchNorm2d`
 *   - `tf.batchNorm3d`
 *   - `tf.batchNorm4d`
 *
 * @param x The input Tensor.
 * @param mean A mean Tensor.
 * @param variance A variance Tensor.
 * @param offset An optional offset Tensor.
 * @param scale An optional scale Tensor.
 * @param varianceEpsilon A small float number to avoid dividing by 0.
 *     Defaults to 0.001 when null or undefined.
 * @returns The normalized tensor, reshaped back to `x.shape`.
 * @throws Error when mean, variance, and any supplied offset/scale do not
 *     all share the same rank.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function batchNorm_(x, mean, variance, offset, scale, varianceEpsilon) {
  var epsilon = varianceEpsilon == null ? 0.001 : varianceEpsilon;
  var $x = convertToTensor(x, 'x', 'batchNorm');
  var $mean = convertToTensor(mean, 'mean', 'batchNorm');
  var $variance = convertToTensor(variance, 'variance', 'batchNorm');
  var $scale = scale == null ? undefined : convertToTensor(scale, 'scale', 'batchNorm');
  var $offset = offset == null ? undefined : convertToTensor(offset, 'offset', 'batchNorm');

  // Optional parameters only need to agree in rank with `mean` when present.
  assert$1($mean.rank === $variance.rank, function () {
    return 'Batch normalization gradient requires mean and variance to have equal ranks.';
  });
  assert$1($offset == null || $mean.rank === $offset.rank, function () {
    return 'Batch normalization gradient requires mean and offset to have equal ranks.';
  });
  assert$1($scale == null || $mean.rank === $scale.rank, function () {
    return 'Batch normalization gradient requires mean and scale to have equal ranks.';
  });

  // The kernel receives a rank-4 view of x; its result is reshaped back to
  // the caller's original shape.
  // tslint:disable-next-line: no-unnecessary-type-assertion
  var normalized = ENGINE.runKernel(FusedBatchNorm, {
    x: xAs4D($x),
    scale: $scale,
    offset: $offset,
    mean: $mean,
    variance: $variance
  }, {
    varianceEpsilon: epsilon
  });
  return reshape$3(normalized, $x.shape);
}
var batchNorm$2 = /* @__PURE__ */op({
  batchNorm_: batchNorm_
});
30350
/**
 * Rank-2-only batch normalization. See `tf.batchNorm` for the relaxed
 * version, which this delegates to after validating ranks.
 *
 * @param x The input Tensor; must be rank 2.
 * @param mean A mean Tensor; rank 2 or rank 1.
 * @param variance A variance Tensor; rank 2 or rank 1.
 * @param offset An optional offset Tensor; rank 2 or rank 1 when given.
 * @param scale An optional scale Tensor; rank 2 or rank 1 when given.
 * @param varianceEpsilon A small float number to avoid dividing by 0.
 */
function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
  var prefix = 'Error in batchNorm2D: ';
  var $x = convertToTensor(x, 'x', 'batchNorm');
  var $mean = convertToTensor(mean, 'mean', 'batchNorm');
  var $variance = convertToTensor(variance, 'variance', 'batchNorm');
  var $scale = scale != null ? convertToTensor(scale, 'scale', 'batchNorm') : undefined;
  var $offset = offset != null ? convertToTensor(offset, 'offset', 'batchNorm') : undefined;

  // Parameters may be rank 2 (same as x) or rank 1 (per-depth values).
  function requireRank2Or1(name, tensor) {
    assert$1(tensor.rank === 2 || tensor.rank === 1, function () {
      return prefix + name + ' must be rank 2 or rank 1 but got rank ' + tensor.rank + '.';
    });
  }

  assert$1($x.rank === 2, function () {
    return prefix + 'x must be rank 2 but got rank ' + $x.rank + '.';
  });
  requireRank2Or1('mean', $mean);
  requireRank2Or1('variance', $variance);
  if ($scale != null) {
    requireRank2Or1('scale', $scale);
  }
  if ($offset != null) {
    requireRank2Or1('offset', $offset);
  }
  return batchNorm$2($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
var batchNorm2d = /* @__PURE__ */op({
  batchNorm2d_: batchNorm2d_
});
30398
/**
 * Rank-3-only batch normalization. See `tf.batchNorm` for the relaxed
 * version, which this delegates to after validating ranks.
 *
 * @param x The input Tensor; must be rank 3.
 * @param mean A mean Tensor; rank 3 or rank 1.
 * @param variance A variance Tensor; rank 3 or rank 1.
 * @param offset An optional offset Tensor; rank 3 or rank 1 when given.
 * @param scale An optional scale Tensor; rank 3 or rank 1 when given.
 * @param varianceEpsilon A small float number to avoid dividing by 0.
 */
function batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon) {
  var prefix = 'Error in batchNorm3D: ';
  var $x = convertToTensor(x, 'x', 'batchNorm');
  var $mean = convertToTensor(mean, 'mean', 'batchNorm');
  var $variance = convertToTensor(variance, 'variance', 'batchNorm');
  var $scale = scale != null ? convertToTensor(scale, 'scale', 'batchNorm') : undefined;
  var $offset = offset != null ? convertToTensor(offset, 'offset', 'batchNorm') : undefined;

  // Parameters may be rank 3 (same as x) or rank 1 (per-depth values).
  function requireRank3Or1(name, tensor) {
    assert$1(tensor.rank === 3 || tensor.rank === 1, function () {
      return prefix + name + ' must be rank 3 or rank 1 but got rank ' + tensor.rank + '.';
    });
  }

  assert$1($x.rank === 3, function () {
    return prefix + 'x must be rank 3 but got rank ' + $x.rank + '.';
  });
  requireRank3Or1('mean', $mean);
  requireRank3Or1('variance', $variance);
  if ($scale != null) {
    requireRank3Or1('scale', $scale);
  }
  if ($offset != null) {
    requireRank3Or1('offset', $offset);
  }
  return batchNorm$2($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
var batchNorm3d = /* @__PURE__ */op({
  batchNorm3d_: batchNorm3d_
});
30446
/**
 * Rank-4-only batch normalization. See `tf.batchNorm` for the relaxed
 * version, which this delegates to after validating ranks.
 *
 * @param x The input Tensor; must be rank 4.
 * @param mean A mean Tensor; rank 4 or rank 1.
 * @param variance A variance Tensor; rank 4 or rank 1.
 * @param offset An optional offset Tensor; rank 4 or rank 1 when given.
 * @param scale An optional scale Tensor; rank 4 or rank 1 when given.
 * @param varianceEpsilon A small float number to avoid dividing by 0.
 */
function batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon) {
  var prefix = 'Error in batchNorm4D: ';
  var $x = convertToTensor(x, 'x', 'batchNorm');
  var $mean = convertToTensor(mean, 'mean', 'batchNorm');
  var $variance = convertToTensor(variance, 'variance', 'batchNorm');
  var $scale = scale != null ? convertToTensor(scale, 'scale', 'batchNorm') : undefined;
  var $offset = offset != null ? convertToTensor(offset, 'offset', 'batchNorm') : undefined;

  // Parameters may be rank 4 (same as x) or rank 1 (per-depth values).
  function requireRank4Or1(name, tensor) {
    assert$1(tensor.rank === 4 || tensor.rank === 1, function () {
      return prefix + name + ' must be rank 4 or rank 1 but got rank ' + tensor.rank + '.';
    });
  }

  assert$1($x.rank === 4, function () {
    return prefix + 'x must be rank 4 but got rank ' + $x.rank + '.';
  });
  requireRank4Or1('mean', $mean);
  requireRank4Or1('variance', $variance);
  if ($scale != null) {
    requireRank4Or1('scale', $scale);
  }
  if ($offset != null) {
    requireRank4Or1('offset', $offset);
  }
  return batchNorm$2($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
var batchNorm4d = /* @__PURE__ */op({
  batchNorm4d_: batchNorm4d_
});
30494
30495 /**
30496 * @license
30497 * Copyright 2020 Google LLC. All Rights Reserved.
30498 * Licensed under the Apache License, Version 2.0 (the "License");
30499 * you may not use this file except in compliance with the License.
30500 * You may obtain a copy of the License at
30501 *
30502 * http://www.apache.org/licenses/LICENSE-2.0
30503 *
30504 * Unless required by applicable law or agreed to in writing, software
30505 * distributed under the License is distributed on an "AS IS" BASIS,
30506 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30507 * See the License for the specific language governing permissions and
30508 * limitations under the License.
30509 * =============================================================================
30510 */
/**
 * Outputs a vector with length `size` and the same dtype as `weights`.
 *
 * If `weights` are empty, index `i` stores the number of times the value
 * `i` occurs in `x`. If `weights` are non-empty, index `i` stores the sum of
 * the values in `weights` at each index where the corresponding value in `x`
 * is `i`.
 *
 * Values in `x` outside of the range [0, size) are ignored.
 *
 * @param x The input int tensor, rank 1. Its dtype must be int32.
 * @param weights The weights tensor. It must have the same size as `x`, or
 *     be a length-0 Tensor, in which case it acts as all weights equal to 1.
 * @param size Non-negative integer.
 * @throws Error if `x` is not int32, if `size` is negative, or if `weights`
 *     has neither the same size as `x` nor size 0.
 *
 * @doc {heading: 'Operations', subheading: 'Reduction'}
 */
function bincount_(x, weights, size) {
  var $x = convertToTensor(x, 'x', 'bincount');
  var $weights = convertToTensor(weights, 'weights', 'bincount');
  assert$1($x.dtype === 'int32', function () {
    return "Error in bincount: input " + "dtype must be int32, but got ".concat($x.dtype);
  });
  assert$1(size >= 0, function () {
    return "size must be non-negative, but got ".concat(size, ".");
  });
  // An empty weights tensor means "count occurrences" (all weights are 1).
  assert$1($weights.size === $x.size || $weights.size === 0, function () {
    // The trailing space after "or" keeps the words "or" and "0-length" apart.
    return "Error in bincount: weights must have the same size as input or " + "0-length, but got input shape: ".concat($x.shape, ", weights shape: ") + "".concat($weights.shape, ".");
  });
  var inputs = {
    x: $x,
    weights: $weights
  };
  var attrs = {
    size: size
  };
  return ENGINE.runKernel(Bincount, inputs, attrs);
}
var bincount$2 = /* @__PURE__ */op({
  bincount_: bincount_
});
30552
30553 /**
30554 * @license
30555 * Copyright 2023 Google LLC.
30556 * Licensed under the Apache License, Version 2.0 (the "License");
30557 * you may not use this file except in compliance with the License.
30558 * You may obtain a copy of the License at
30559 *
30560 * http://www.apache.org/licenses/LICENSE-2.0
30561 *
30562 * Unless required by applicable law or agreed to in writing, software
30563 * distributed under the License is distributed on an "AS IS" BASIS,
30564 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30565 * See the License for the specific language governing permissions and
30566 * limitations under the License.
30567 * =============================================================================
30568 */
/**
 * Element-wise bitwise `AND` of two int32 tensors.
 *
 * Both inputs must have identical shapes (no broadcasting is performed here)
 * and both must be of dtype int32.
 *
 * ```js
 * const x = tf.tensor1d([0, 5, 3, 14], 'int32');
 * const y = tf.tensor1d([5, 0, 7, 11], 'int32');
 * tf.bitwiseAnd(x, y).print();
 * ```
 *
 * @param x The first int32 input tensor.
 * @param y The second int32 input tensor.
 * @throws Error if the shapes differ or either dtype is not int32.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function bitwiseAnd_(x, y) {
  var lhs = convertToTensor(x, 'x', 'bitwiseAnd');
  var rhs = convertToTensor(y, 'y', 'bitwiseAnd');
  var sameShape = arraysEqual(lhs.shape, rhs.shape);
  if (!sameShape) {
    throw new Error("BitwiseAnd: Tensors must have the same shape. x: ".concat(lhs.shape, ", y: ").concat(rhs.shape));
  }
  var bothInt32 = lhs.dtype === 'int32' && rhs.dtype === 'int32';
  if (!bothInt32) {
    throw new Error("BitwiseAnd: Only supports 'int32' values in tensor, found type of x: ".concat(lhs.dtype, " and type of y: ").concat(rhs.dtype));
  }
  return ENGINE.runKernel(BitwiseAnd, {
    a: lhs,
    b: rhs
  });
}
var bitwiseAnd$2 = /* @__PURE__ */op({
  bitwiseAnd_: bitwiseAnd_
});
30607
30608 /**
30609 * @license
30610 * Copyright 2021 Google LLC. All Rights Reserved.
30611 * Licensed under the Apache License, Version 2.0 (the "License");
30612 * you may not use this file except in compliance with the License.
30613 * You may obtain a copy of the License at
30614 *
30615 * http://www.apache.org/licenses/LICENSE-2.0
30616 *
30617 * Unless required by applicable law or agreed to in writing, software
30618 * distributed under the License is distributed on an "AS IS" BASIS,
30619 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30620 * See the License for the specific language governing permissions and
30621 * limitations under the License.
30622 * =============================================================================
30623 */
/**
 * Computes the broadcast shape of `s0 op s1`.
 *
 * Both arguments are integer vectors describing tensor shapes. The result is
 * a tensor holding the shape that an element-wise operation between tensors
 * of those two shapes would produce under broadcasting.
 *
 * @param s0 A rank-1 tensor representing a shape (converted as int32).
 * @param s1 A rank-1 tensor representing a shape (converted as int32).
 * @throws Error if either input is not rank 1.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastArgs_(s0, s1) {
  var lhsShape = convertToTensor(s0, 's0', 'broadcastArgs', 'int32');
  var rhsShape = convertToTensor(s1, 's1', 'broadcastArgs', 'int32');
  // Both shape descriptors must be plain vectors.
  function requireVector(shapeTensor, ordinal) {
    if (shapeTensor.rank !== 1) {
      throw new Error('broadcastArgs(): ' + ordinal + ' input must be a vector (rank=1). ' + 'Has rank ' + shapeTensor.rank);
    }
  }
  requireVector(lhsShape, 'first');
  requireVector(rhsShape, 'second');
  return ENGINE.runKernel(BroadcastArgs, {
    s0: lhsShape,
    s1: rhsShape
  });
}
var broadcastArgs$2 = /* @__PURE__ */op({
  broadcastArgs_: broadcastArgs_
});
30656
30657 /**
30658 * @license
30659 * Copyright 2020 Google LLC. All Rights Reserved.
30660 * Licensed under the Apache License, Version 2.0 (the "License");
30661 * you may not use this file except in compliance with the License.
30662 * You may obtain a copy of the License at
30663 *
30664 * http://www.apache.org/licenses/LICENSE-2.0
30665 *
30666 * Unless required by applicable law or agreed to in writing, software
30667 * distributed under the License is distributed on an "AS IS" BASIS,
30668 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30669 * See the License for the specific language governing permissions and
30670 * limitations under the License.
30671 * =============================================================================
30672 */
/**
 * Broadcasts a tensor to a compatible shape, NumPy-style.
 *
 * The tensor's shape is compared to the target shape from the last axis to
 * the first. Ones are prepended to the tensor's shape until it has as many
 * axes as the target. For each axis `i`:
 * - if `x.shape[i] === shape[i]`, the axis is already compatible;
 * - if `x.shape[i] === 1`, the tensor is tiled `shape[i]` times along that
 *   axis (via the Tile kernel). A target size of 0 yields an empty axis.
 *
 * @param x The tensor (or tensor-like value) to broadcast.
 * @param shape The target shape; must have at least as many entries as `x`
 *     has axes.
 * @returns A tensor of shape `shape`. When no axis needs tiling, a clone of
 *     the (possibly rank-extended) input is returned.
 * @throws Error if `shape` has fewer entries than `x` has axes, or if some
 *     axis of `x` is neither equal to the target size nor 1.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastTo_(x, shape) {
  // convertToTensor takes (value, argName, functionName). Using that order
  // keeps conversion error messages pointing at the right op and argument.
  var input = convertToTensor(x, 'x', 'broadcastTo');
  var xShape = input.shape;
  assertNonNegativeIntegerDimensions(shape);
  if (shape.length < input.rank) {
    throw new Error("broadcastTo(): shape.length=".concat(shape.length, " < input.rank=").concat(input.rank, "."));
  }
  if (shape.length > input.rank) {
    // Left-pad the input's shape with ones so both shapes have equal length.
    var newShape = input.shape.slice();
    while (newShape.length < shape.length) {
      newShape.unshift(1);
    }
    input = reshape$3(input, newShape);
  }
  var inputShape = input.shape;
  // reps[i] starts as the target size and is reset to 1 on matching axes.
  var reps = Array.from(shape);
  for (var i = shape.length - 1; i >= 0; i--) {
    if (inputShape[i] === shape[i]) {
      reps[i] = 1;
    } else if (inputShape[i] !== 1) {
      throw new Error("broadcastTo(): [".concat(xShape, "] cannot be broadcast to [").concat(shape, "]."));
    }
  }
  // Any repetition count other than 1 changes the shape. That includes 0,
  // which must produce an empty axis rather than a clone of the size-1 input.
  var needsTile = reps.some(function (n) {
    return n !== 1;
  });
  if (!needsTile) {
    return clone(input);
  }
  // TODO: call the BroadcastTo kernel directly once backends implement it.
  var inputs = {
    x: input
  };
  var attrs = {
    reps: reps
  };
  return ENGINE.runKernel(Tile, inputs, attrs);
}
var broadcastTo = /* @__PURE__ */op({
  broadcastTo_: broadcastTo_
});
30730
30731 /**
30732 * @license
30733 * Copyright 2018 Google LLC. All Rights Reserved.
30734 * Licensed under the Apache License, Version 2.0 (the "License");
30735 * you may not use this file except in compliance with the License.
30736 * You may obtain a copy of the License at
30737 *
30738 * http://www.apache.org/licenses/LICENSE-2.0
30739 *
30740 * Unless required by applicable law or agreed to in writing, software
30741 * distributed under the License is distributed on an "AS IS" BASIS,
30742 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30743 * See the License for the specific language governing permissions and
30744 * limitations under the License.
30745 * =============================================================================
30746 */
/**
 * Computes the element-wise ceiling of a tensor: `ceil(x)`.
 *
 * ```js
 * const x = tf.tensor1d([.6, 1.1, -3.3]);
 *
 * x.ceil().print();  // or tf.ceil(x)
 * ```
 * @param x The input tensor (or tensor-like value). It is converted with a
 *     'float32' dtype hint before running the Ceil kernel.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function ceil_(x) {
  return ENGINE.runKernel(Ceil, {
    x: convertToTensor(x, 'x', 'ceil', 'float32')
  });
}
var ceil$2 = /* @__PURE__ */op({
  ceil_: ceil_
});
30769
30770 /**
30771 * @license
30772 * Copyright 2020 Google LLC. All Rights Reserved.
30773 * Licensed under the Apache License, Version 2.0 (the "License");
30774 * you may not use this file except in compliance with the License.
30775 * You may obtain a copy of the License at
30776 *
30777 * http://www.apache.org/licenses/LICENSE-2.0
30778 *
30779 * Unless required by applicable law or agreed to in writing, software
30780 * distributed under the License is distributed on an "AS IS" BASIS,
30781 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30782 * See the License for the specific language governing permissions and
30783 * limitations under the License.
30784 * =============================================================================
30785 */
/**
 * Creates a `tf.Tensor` filled with a scalar value.
 *
 * ```js
 * tf.fill([2, 2], 4).print();
 * ```
 *
 * @param shape An array of non-negative integers defining the output shape.
 * @param value The scalar value to fill the tensor with.
 * @param dtype The element type of the result. When omitted (or falsy) it is
 *     inferred from `value` via inferDtype (documented as 'float32' for a
 *     number, otherwise 'string').
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function fill$2(shape, value, dtype) {
  assertNonNegativeIntegerDimensions(shape);
  var resolvedDtype = dtype ? dtype : inferDtype(value);
  // Fill takes no tensor inputs; everything it needs travels as attributes.
  return ENGINE.runKernel(Fill, {}, {
    shape: shape,
    value: value,
    dtype: resolvedDtype
  });
}
30810
30811 /**
30812 * @license
30813 * Copyright 2018 Google LLC. All Rights Reserved.
30814 * Licensed under the Apache License, Version 2.0 (the "License");
30815 * you may not use this file except in compliance with the License.
30816 * You may obtain a copy of the License at
30817 *
30818 * http://www.apache.org/licenses/LICENSE-2.0
30819 *
30820 * Unless required by applicable law or agreed to in writing, software
30821 * distributed under the License is distributed on an "AS IS" BASIS,
30822 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30823 * See the License for the specific language governing permissions and
30824 * limitations under the License.
30825 * =============================================================================
30826 */
/**
 * Clips values element-wise: `max(min(x, clipValueMax), clipValueMin)`.
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 4]);
 *
 * x.clipByValue(-2, 3).print();  // or tf.clipByValue(x, -2, 3)
 * ```
 * @param x The input tensor.
 * @param clipValueMin Lower bound of the clipping range.
 * @param clipValueMax Upper bound of the clipping range; must be
 *     `>= clipValueMin`.
 * @throws Error (via assert) if `clipValueMin > clipValueMax`.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function clipByValue_(x, clipValueMin, clipValueMax) {
  var $x = convertToTensor(x, 'x', 'clipByValue');
  assert$1(clipValueMin <= clipValueMax, function () {
    return 'Error in clip: min (' + clipValueMin + ') must be ' + 'less than or equal to max (' + clipValueMax + ').';
  });
  var isDegenerateRange = clipValueMin === clipValueMax;
  if (isDegenerateRange) {
    // Every element clips to the same value, so skip the kernel and emit a
    // constant tensor of the input's shape and dtype.
    return fill$2($x.shape, clipValueMin, $x.dtype);
  }
  return ENGINE.runKernel(ClipByValue, {
    x: $x
  }, {
    clipValueMin: clipValueMin,
    clipValueMax: clipValueMax
  });
}
var clipByValue$2 = /* @__PURE__ */op({
  clipByValue_: clipByValue_
});
30861
/**
 * Concatenates a list of `tf.Tensor1D`s along their only axis.
 * See `concat` for details.
 *
 * For example, if:
 *   A: shape(3) = |r1, g1, b1|
 *   B: shape(2) = |r2, g2|
 * then
 *   tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @return The concatenated tensor.
 */
function concat1d_(tensors) {
  // A 1D tensor has a single axis, so the join is always along axis 0.
  var onlyAxis = 0;
  return concat$2(tensors, onlyAxis);
}

var concat1d = /* @__PURE__ */op({
  concat1d_: concat1d_
});
30880
/**
 * Concatenates a list of `tf.Tensor2D`s along an axis. See `concat` for
 * details.
 *
 * For example, if:
 *   A: shape(2, 3) = | r1, g1, b1 |
 *                    | r2, g2, b2 |
 *
 *   B: shape(2, 3) = | r3, g3, b3 |
 *                    | r4, g4, b4 |
 *
 *   C = tf.concat2d([A, B], axis)
 *
 * if axis = 0:
 *   C: shape(4, 3) = | r1, g1, b1 |
 *                    | r2, g2, b2 |
 *                    | r3, g3, b3 |
 *                    | r4, g4, b4 |
 *
 * if axis = 1:
 *   C: shape(2, 6) = | r1, g1, b1, r3, g3, b3 |
 *                    | r2, g2, b2, r4, g4, b4 |
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated tensor.
 */
function concat2d_(tensors, axis) {
  var joined = concat$2(tensors, axis);
  return joined;
}
var concat2d = /* @__PURE__ */op({
  concat2d_: concat2d_
});
30914
/**
 * Concatenates a list of `tf.Tensor3D`s along an axis.
 * See `concat` for details.
 *
 * For example, with nested-array notation, if:
 *   A: shape(2, 1, 3) = [[[r1, g1, b1]],
 *                        [[r2, g2, b2]]]
 *
 *   B: shape(2, 1, 3) = [[[r3, g3, b3]],
 *                        [[r4, g4, b4]]]
 *
 *   C = tf.concat3d([A, B], axis)
 *
 * if axis = 0:
 *   C: shape(4, 1, 3) = [[[r1, g1, b1]],
 *                        [[r2, g2, b2]],
 *                        [[r3, g3, b3]],
 *                        [[r4, g4, b4]]]
 *
 * if axis = 1:
 *   C: shape(2, 2, 3) = [[[r1, g1, b1], [r3, g3, b3]],
 *                        [[r2, g2, b2], [r4, g4, b4]]]
 *
 * if axis = 2:
 *   C: shape(2, 1, 6) = [[[r1, g1, b1, r3, g3, b3]],
 *                        [[r2, g2, b2, r4, g4, b4]]]
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated tensor.
 */
function concat3d_(tensors, axis) {
  var joined = concat$2(tensors, axis);
  return joined;
}
var concat3d = /* @__PURE__ */op({
  concat3d_: concat3d_
});
30952
/**
 * Concatenates a list of `tf.Tensor4D`s along an axis.
 * See `concat` for details.
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated tensor.
 */
function concat4d_(tensors, axis) {
  var joined = concat$2(tensors, axis);
  return joined;
}
var concat4d = /* @__PURE__ */op({
  concat4d_: concat4d_
});
30967
30968 /**
30969 * @license
30970 * Copyright 2020 Google LLC. All Rights Reserved.
30971 * Licensed under the Apache License, Version 2.0 (the "License");
30972 * you may not use this file except in compliance with the License.
30973 * You may obtain a copy of the License at
30974 *
30975 * http://www.apache.org/licenses/LICENSE-2.0
30976 *
30977 * Unless required by applicable law or agreed to in writing, software
30978 * distributed under the License is distributed on an "AS IS" BASIS,
30979 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30980 * See the License for the specific language governing permissions and
30981 * limitations under the License.
30982 * =============================================================================
30983 */
/**
 * Computes a 2D convolution over the input x.
 *
 * @param x The input tensor, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. A rank-3 input is treated as a
 *     batch of 1, and the result is reshaped back to rank 3.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". With "NHWC" the channel axis is last ([batch, height, width,
 *     channels]); otherwise the channel axis is taken to be axis 1.
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
 *     used to sample input values across the height and width dimensions
 *     in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a
 *     single number, then `dilationHeight == dilationWidth`. If it is
 *     greater than 1, then all values of `strides` must be 1.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @returns The convolved tensor: rank 3 if `x` was rank 3, otherwise rank 4.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv2d_(x, filter, strides, pad) {
  // Optional trailing parameters, read from `arguments` (ES5-style defaults).
  var dataFormat = arguments[4] === undefined ? 'NHWC' : arguments[4];
  var dilations = arguments[5] === undefined ? [1, 1] : arguments[5];
  var dimRoundingMode = arguments[6];
  var $x = convertToTensor(x, 'x', 'conv2d', 'float32');
  var $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
  // Promote an unbatched rank-3 input to a batch of one.
  var reshapedTo4D = $x.rank === 3;
  var x4D = reshapedTo4D ? reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;
  assert$1(x4D.rank === 4, function () {
    return 'Error in conv2d: input must be rank 4, but got rank ' + x4D.rank + '.';
  });
  assert$1($filter.rank === 4, function () {
    return 'Error in conv2d: filter must be rank 4, but got rank ' + $filter.rank + '.';
  });
  checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
  // The channel axis is last for NHWC and axis 1 otherwise.
  var inDepth = x4D.shape[dataFormat === 'NHWC' ? 3 : 1];
  assert$1(inDepth === $filter.shape[2], function () {
    return 'Error in conv2d: depth of input (' + inDepth + ') must match ' + 'input depth for filter ' + $filter.shape[2] + '.';
  });
  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in conv2D: Either strides or dilations must be 1. ' + 'Got strides ' + strides + " and dilations '" + dilations + "'";
  });
  assert$1(stridesOrDilationsArePositive(dilations), function () {
    return 'Error in conv2D: Dilated rates should be larger than 0.';
  });
  assert$1(stridesOrDilationsArePositive(strides), function () {
    return 'Error in conv2D: Strides should be larger than 0.';
  });
  // tslint:disable-next-line: no-unnecessary-type-assertion
  var res = ENGINE.runKernel(Conv2D$1, {
    x: x4D,
    filter: $filter
  }, {
    strides: strides,
    pad: pad,
    dataFormat: dataFormat,
    dilations: dilations,
    dimRoundingMode: dimRoundingMode
  });
  // Undo the batch promotion so the caller gets back the rank it passed in.
  return reshapedTo4D ? reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]) : res;
}
var conv2d$4 = /* @__PURE__ */op({
  conv2d_: conv2d_
});
31069
/**
 * Computes a 1D convolution over the input x.
 *
 * Implemented by inserting a unit "height" axis into both the input and the
 * filter, running `conv2d` in NHWC layout, and removing that axis again.
 *
 * @param x The input tensor, of rank 3 or rank 2, of shape
 *     `[batch, width, inChannels]`. A rank-2 input is treated as a batch of
 *     1, and the result is returned as rank 2.
 * @param filter The filter, rank 3, of shape
 *     `[filterWidth, inDepth, outDepth]`.
 * @param stride The number of entries by which the filter is moved right at
 *     each step.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
 *     where the data is stored as [batch, in_width, in_channels]. Only "NWC"
 *     is currently supported; anything else fails an assertion.
 * @param dilation The dilation rate used to sample input values in atrous
 *     convolution. Defaults to `1`. If it is greater than 1, then stride must
 *     be `1`.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv1d_(x, filter, stride, pad) {
  // Optional trailing parameters, read from `arguments` (ES5-style defaults).
  var dataFormat = arguments[4] === undefined ? 'NWC' : arguments[4];
  var dilation = arguments[5] === undefined ? 1 : arguments[5];
  var dimRoundingMode = arguments[6];
  var $x = convertToTensor(x, 'x', 'conv1d');
  var $filter = convertToTensor(filter, 'filter', 'conv1d');
  // Promote an unbatched rank-2 input to a batch of one.
  var reshapedTo3D = $x.rank === 2;
  var x3D = reshapedTo3D ? reshape$3($x, [1, $x.shape[0], $x.shape[1]]) : $x;
  assert$1(x3D.rank === 3, function () {
    return 'Error in conv1d: input must be rank 3, but got rank ' + x3D.rank + '.';
  });
  assert$1($filter.rank === 3, function () {
    return 'Error in conv1d: filter must be rank 3, but got rank ' + $filter.rank + '.';
  });
  checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
  assert$1(x3D.shape[2] === $filter.shape[1], function () {
    return 'Error in conv1d: depth of input (' + x3D.shape[2] + ') must match ' + 'input depth for filter ' + $filter.shape[1] + '.';
  });
  assert$1(eitherStridesOrDilationsAreOne(stride, dilation), function () {
    return 'Error in conv1D: Either stride or dilation must be 1. ' + 'Got stride ' + stride + " and dilation '" + dilation + "'";
  });
  assert$1(stridesOrDilationsArePositive(dilation), function () {
    return 'Error in conv1D: Dilated rates should be larger than 0.';
  });
  assert$1(stridesOrDilationsArePositive(stride), function () {
    return 'Error in conv1D: Stride should be larger than 0.';
  });
  assert$1(dataFormat === 'NWC', function () {
    return 'Error in conv1d: got dataFormat of ' + dataFormat + ' but only NWC is currently supported.';
  });
  // Insert a unit height axis: filter -> [1, w, in, out], input -> [b, 1, w, c].
  var filter4D = reshape$3($filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);
  var input4D = reshape$3(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);
  var res = conv2d$4(input4D, filter4D, [1, stride], pad, 'NHWC', [1, dilation], dimRoundingMode);
  // Drop the unit height axis, plus the batch axis if it was added above.
  if (!reshapedTo3D) {
    return reshape$3(res, [res.shape[0], res.shape[2], res.shape[3]]);
  }
  return reshape$3(res, [res.shape[2], res.shape[3]]);
}
var conv1d$2 = /* @__PURE__ */op({
  conv1d_: conv1d_
});
31146
31147 /**
31148 * @license
31149 * Copyright 2020 Google LLC. All Rights Reserved.
31150 * Licensed under the Apache License, Version 2.0 (the "License");
31151 * you may not use this file except in compliance with the License.
31152 * You may obtain a copy of the License at
31153 *
31154 * http://www.apache.org/licenses/LICENSE-2.0
31155 *
31156 * Unless required by applicable law or agreed to in writing, software
31157 * distributed under the License is distributed on an "AS IS" BASIS,
31158 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31159 * See the License for the specific language governing permissions and
31160 * limitations under the License.
31161 * =============================================================================
31162 */
/**
 * Computes the derivative of the input of a 2D convolution.
 *
 * @param xShape The shape of the input: [batch, height, width, inDepth].
 *     If of length 3, a batch of 1 is assumed and the result is rank 3.
 * @param dy The derivative of the output, of rank 4 or rank 3, of shape
 *     `[batch, outHeight, outWidth, outDepth]`. Its rank must equal
 *     `xShape.length`.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm used:
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". With "NHWC" the channel axis is last; otherwise it is axis 1.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function conv2DBackpropInput_(xShape, dy, filter, strides, pad) {
  // Optional trailing parameters, read from `arguments` (ES5-style defaults).
  var dataFormat = arguments[5] === undefined ? 'NHWC' : arguments[5];
  var dimRoundingMode = arguments[6];
  assert$1(xShape.length === dy.rank, function () {
    return 'Length of inShape (' + xShape.length + ') and rank of dy (' + dy.rank + ') must match';
  });
  // Promote unbatched (rank-3) arguments to a batch of one.
  var reshapedTo4D = dy.rank === 3;
  var dy4D = reshapedTo4D ? reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) : dy;
  var xShape4D = reshapedTo4D ? [1, xShape[0], xShape[1], xShape[2]] : xShape;
  assert$1(xShape4D.length === 4, function () {
    return 'Error in conv2dDerInput: inShape must be length 4, but got length ' + xShape4D.length + '.';
  });
  assert$1(dy4D.rank === 4, function () {
    return 'Error in conv2dDerInput: dy must be rank 4, but got rank ' + dy4D.rank;
  });
  assert$1(filter.rank === 4, function () {
    return 'Error in conv2dDerInput: filter must be rank 4, but got rank ' + filter.rank;
  });
  var channelAxis = dataFormat === 'NHWC' ? 3 : 1;
  var inDepth = xShape4D[channelAxis];
  var outDepth = dy4D.shape[channelAxis];
  assert$1(inDepth === filter.shape[2], function () {
    return 'Error in conv2dDerInput: depth of input (' + inDepth + ') must ' + 'match input depth for filter ' + filter.shape[2] + '.';
  });
  assert$1(outDepth === filter.shape[3], function () {
    return 'Error in conv2dDerInput: depth of output (' + outDepth + ') must ' + 'match output depth for filter ' + filter.shape[3] + '.';
  });
  checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
  // tslint:disable-next-line: no-unnecessary-type-assertion
  var res = ENGINE.runKernel(Conv2DBackpropInput, {
    dy: dy4D,
    filter: filter
  }, {
    strides: strides,
    pad: pad,
    dataFormat: dataFormat,
    dimRoundingMode: dimRoundingMode,
    inputShape: xShape4D
  });
  // Undo the batch promotion so the caller gets back the rank it passed in.
  return reshapedTo4D ? reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]) : res;
}
var conv2DBackpropInput$2 = /* @__PURE__ */op({
  conv2DBackpropInput_: conv2DBackpropInput_
});
31240
/**
 * Computes the transposed 2D convolution of an image, also known as a
 * deconvolution.
 *
 * Delegates to `conv2DBackpropInput` in NHWC layout, with `outputShape`
 * passed as the shape whose input-gradient is being computed.
 *
 * @param x The input image, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, outDepth, inDepth]`.
 *     `inDepth` must match `inDepth` in `x`.
 * @param outputShape Output shape, of rank 4 or rank 3:
 *     `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed.
 * @param strides The strides of the original convolution:
 *     `[strideHeight, strideWidth]`.
 * @param pad The type of padding algorithm used in the non-transpose version
 *     of the op.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv2dTranspose_(x, filter, outputShape, strides, pad, dimRoundingMode) {
  var input = convertToTensor(x, 'x', 'conv2dTranspose');
  var kernel = convertToTensor(filter, 'filter', 'conv2dTranspose');
  var layout = 'NHWC';
  return conv2DBackpropInput$2(outputShape, input, kernel, strides, pad, layout, dimRoundingMode);
}
var conv2dTranspose$1 = /* @__PURE__ */op({
  conv2dTranspose_: conv2dTranspose_
});
31269
31270 /**
31271 * @license
31272 * Copyright 2020 Google LLC. All Rights Reserved.
31273 * Licensed under the Apache License, Version 2.0 (the "License");
31274 * you may not use this file except in compliance with the License.
31275 * You may obtain a copy of the License at
31276 *
31277 * http://www.apache.org/licenses/LICENSE-2.0
31278 *
31279 * Unless required by applicable law or agreed to in writing, software
31280 * distributed under the License is distributed on an "AS IS" BASIS,
31281 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31282 * See the License for the specific language governing permissions and
31283 * limitations under the License.
31284 * =============================================================================
31285 */
/**
 * Computes a 3D convolution over the input x.
 *
 * @param x The input tensor, of rank 5 or rank 4, of shape
 *     `[batch, depth, height, width, channels]`. A rank-4 input is treated as
 *     a batch of 1, and the result is reshaped back to rank 4.
 * @param filter The filter, rank 5, of shape
 *     `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.
 *     inChannels must match between input and filter.
 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
 *     "NDHWC" ([batch, depth, height, width, channels]). Only "NDHWC" is
 *     currently supported; anything else fails an assertion.
 * @param dilations The dilation rates: `[dilationDepth, dilationHeight,
 *     dilationWidth]` used to sample input values across the spatial
 *     dimensions in atrous convolution. Defaults to `[1, 1, 1]`. If
 *     `dilations` is a single number, then
 *     `dilationDepth == dilationHeight == dilationWidth`. If it is greater
 *     than 1, then all values of `strides` must be 1.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv3d_(x, filter, strides, pad) {
  // Optional trailing parameters, read from `arguments` (ES5-style defaults).
  var dataFormat = arguments[4] === undefined ? 'NDHWC' : arguments[4];
  var dilations = arguments[5] === undefined ? [1, 1, 1] : arguments[5];
  var $x = convertToTensor(x, 'x', 'conv3d');
  var $filter = convertToTensor(filter, 'filter', 'conv3d');
  // Promote an unbatched rank-4 input to a batch of one.
  var reshapedTo5D = $x.rank === 4;
  var x5D = reshapedTo5D ? reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]) : $x;
  assert$1(x5D.rank === 5, function () {
    return 'Error in conv3d: input must be rank 5, but got rank ' + x5D.rank + '.';
  });
  assert$1($filter.rank === 5, function () {
    return 'Error in conv3d: filter must be rank 5, but got rank ' + $filter.rank + '.';
  });
  assert$1(x5D.shape[4] === $filter.shape[3], function () {
    return 'Error in conv3d: depth of input (' + x5D.shape[4] + ') must match ' + 'input depth for filter ' + $filter.shape[3] + '.';
  });
  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in conv3D: Either strides or dilations must be 1. ' + 'Got strides ' + strides + " and dilations '" + dilations + "'";
  });
  assert$1(dataFormat === 'NDHWC', function () {
    return 'Error in conv3d: got dataFormat of ' + dataFormat + ' but only NDHWC is currently supported.';
  });
  assert$1(stridesOrDilationsArePositive(dilations), function () {
    return 'Error in conv3D: Dilated rates should be larger than 0.';
  });
  assert$1(stridesOrDilationsArePositive(strides), function () {
    return 'Error in conv3D: Strides should be larger than 0.';
  });
  // tslint:disable-next-line: no-unnecessary-type-assertion
  var res = ENGINE.runKernel(Conv3D$1, {
    x: x5D,
    filter: $filter
  }, {
    strides: strides,
    pad: pad,
    dataFormat: dataFormat,
    dilations: dilations
  });
  // Undo the batch promotion so the caller gets back the rank it passed in.
  if (!reshapedTo5D) {
    return res;
  }
  return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
var conv3d$2 = /* @__PURE__ */op({
  conv3d_: conv3d_
});
31370
31371 /**
31372 * @license
31373 * Copyright 2020 Google LLC. All Rights Reserved.
31374 * Licensed under the Apache License, Version 2.0 (the "License");
31375 * you may not use this file except in compliance with the License.
31376 * You may obtain a copy of the License at
31377 *
31378 * http://www.apache.org/licenses/LICENSE-2.0
31379 *
31380 * Unless required by applicable law or agreed to in writing, software
31381 * distributed under the License is distributed on an "AS IS" BASIS,
31382 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31383 * See the License for the specific language governing permissions and
31384 * limitations under the License.
31385 * =============================================================================
31386 */
/**
 * Computes the derivative of the input of a 3D convolution.
 *
 * @param xShape The shape of the input: [batch, depth, height, width,
 *     in_channels]. If of length 4, a batch of 1 is assumed and the result
 *     is rank 4.
 * @param dy The derivative of the output, of rank 5 or rank 4, of shape
 *     `[batch, outDepth, outHeight, outWidth, in_channels]`. Its rank must
 *     equal `xShape.length`.
 * @param filter The filter, rank 5, of shape
 *     `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm used:
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 */
function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
  assert$1(xShape.length === dy.rank, function () {
    return 'Length of inShape (' + xShape.length + ') and rank of dy (' + dy.rank + ') must match';
  });
  // Promote unbatched (rank-4) arguments to a batch of one.
  var reshapedTo5D = dy.rank === 4;
  var dy5D = reshapedTo5D ? reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]) : dy;
  var xShape5D = reshapedTo5D ? [1, xShape[0], xShape[1], xShape[2], xShape[3]] : xShape;
  // Channels are the last axis of both the input shape and dy.
  var inDepth = xShape5D[4];
  var outDepth = dy5D.shape[4];
  assert$1(xShape5D.length === 5, function () {
    return 'Error in conv3dDerInput: inShape must be length 5, but got length ' + xShape5D.length + '.';
  });
  assert$1(dy5D.rank === 5, function () {
    return 'Error in conv3dDerInput: dy must be rank 5, but got rank ' + dy5D.rank;
  });
  assert$1(filter.rank === 5, function () {
    return 'Error in conv3dDerInput: filter must be rank 5, but got rank ' + filter.rank;
  });
  assert$1(inDepth === filter.shape[3], function () {
    return 'Error in conv3dDerInput: depth of input (' + inDepth + ') must ' + 'match input depth for filter ' + filter.shape[3] + '.';
  });
  assert$1(outDepth === filter.shape[4], function () {
    return 'Error in conv3dDerInput: depth of output (' + outDepth + ') must ' + 'match output depth for filter ' + filter.shape[4] + '.';
  });
  // tslint:disable-next-line: no-unnecessary-type-assertion
  var res = ENGINE.runKernel(Conv3DBackpropInputV2, {
    dy: dy5D,
    filter: filter
  }, {
    pad: pad,
    strides: strides,
    inputShape: xShape5D
  });
  // Undo the batch promotion so the caller gets back the rank it passed in.
  if (!reshapedTo5D) {
    return res;
  }
  return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
var conv3DBackpropInput$1 = /* @__PURE__ */op({
  conv3DBackpropInput_: conv3DBackpropInput_
});
31453
31454 /**
31455 * Computes the transposed 3D convolution of a volume, also known as a
31456 * deconvolution.
31457 *
31458 * @param x The input image, of rank 5 or rank 4, of shape
31459 * `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
31460 * @param filter The filter, rank 4, of shape
31461 * `[depth, filterHeight, filterWidth, outDepth, inDepth]`.
31462 * `inDepth` must match `inDepth` in `x`.
31463 * @param outputShape Output shape, of rank 5 or rank 4:
31464 * `[batch, depth, height, width, outDepth]`. If rank 3, batch of 1 is
31465 * assumed.
31466 * @param strides The strides of the original convolution:
31467 * `[strideDepth, strideHeight, strideWidth]`.
31468 * @param pad The type of padding algorithm used in the non-transpose version
31469 * of the op.
31470 *
31471 * @doc {heading: 'Operations', subheading: 'Convolution'}
31472 */
31473 function conv3dTranspose_(x, filter, outputShape, strides, pad) {
31474 var $x = convertToTensor(x, 'x', 'conv3dTranspose');
31475 var $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
31476 return conv3DBackpropInput$1(outputShape, $x, $filter, strides, pad);
31477 }
31478 var conv3dTranspose$1 = /* @__PURE__ */op({
31479 conv3dTranspose_: conv3dTranspose_
31480 });
31481
31482 /**
31483 * @license
31484 * Copyright 2018 Google LLC. All Rights Reserved.
31485 * Licensed under the Apache License, Version 2.0 (the "License");
31486 * you may not use this file except in compliance with the License.
31487 * You may obtain a copy of the License at
31488 *
31489 * http://www.apache.org/licenses/LICENSE-2.0
31490 *
31491 * Unless required by applicable law or agreed to in writing, software
31492 * distributed under the License is distributed on an "AS IS" BASIS,
31493 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31494 * See the License for the specific language governing permissions and
31495 * limitations under the License.
31496 * =============================================================================
31497 */
31498 /**
31499 * Computes cos of the input `tf.Tensor` element-wise: `cos(x)`
31500 *
31501 * ```js
31502 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
31503 *
31504 * x.cos().print(); // or tf.cos(x)
31505 * ```
31506 * @param x The input tensor. Must be float32 type.
31507 *
31508 * @doc {heading: 'Operations', subheading: 'Basic math'}
31509 */
31510 function cos_(x) {
31511 var $x = convertToTensor(x, 'x', 'cos', 'float32');
31512 var inputs = {
31513 x: $x
31514 };
31515 return ENGINE.runKernel(Cos, inputs);
31516 }
31517 var cos$2 = /* @__PURE__ */op({
31518 cos_: cos_
31519 });
31520
31521 /**
31522 * @license
31523 * Copyright 2018 Google LLC. All Rights Reserved.
31524 * Licensed under the Apache License, Version 2.0 (the "License");
31525 * you may not use this file except in compliance with the License.
31526 * You may obtain a copy of the License at
31527 *
31528 * http://www.apache.org/licenses/LICENSE-2.0
31529 *
31530 * Unless required by applicable law or agreed to in writing, software
31531 * distributed under the License is distributed on an "AS IS" BASIS,
31532 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31533 * See the License for the specific language governing permissions and
31534 * limitations under the License.
31535 * =============================================================================
31536 */
31537 /**
31538 * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`
31539 *
31540 * ```js
31541 * const x = tf.tensor1d([0, 1, -1, .7]);
31542 *
31543 * x.cosh().print(); // or tf.cosh(x)
31544 * ```
31545 * @param x The input tensor. Must be float32 type.
31546 *
31547 * @doc {heading: 'Operations', subheading: 'Basic math'}
31548 */
31549 function cosh_(x) {
31550 var $x = convertToTensor(x, 'x', 'cosh', 'float32');
31551 var inputs = {
31552 x: $x
31553 };
31554 return ENGINE.runKernel(Cosh, inputs);
31555 }
31556 var cosh$2 = /* @__PURE__ */op({
31557 cosh_: cosh_
31558 });
31559
31560 /**
31561 * @license
31562 * Copyright 2022 Google LLC. All Rights Reserved.
31563 * Licensed under the Apache License, Version 2.0 (the 'License');
31564 * you may not use this file except in compliance with the License.
31565 * You may obtain a copy of the License at
31566 *
31567 * http://www.apache.org/licenses/LICENSE-2.0
31568 *
31569 * Unless required by applicable law or agreed to in writing, software
31570 * distributed under the License is distributed on an 'AS IS' BASIS,
31571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31572 * See the License for the specific language governing permissions and
31573 * limitations under the License.
31574 * =============================================================================
31575 */
31576 /**
31577 * Computes the cumulative product of a `tf.Tensor` along `axis`.
31578 *
31579 * ```js
31580 * const x = tf.tensor([1, 2, 3, 4]);
31581 * x.cumprod().print();
31582 * ```
31583 * ```js
31584 * const x = tf.tensor([[1, 2], [3, 4]]);
31585 * x.cumprod().print();
31586 * ```
31587 *
31588 * @param x The input tensor to cumulatively multiply.
31589 * @param axis The axis along which to multiply. Optional. Defaults to 0.
31590 * @param exclusive Whether to perform exclusive cumulative product. Optional.
31591 * Defaults to false. If set to true then the product of each tensor entry
31592 * does not include its own value, but only the values previous to it
31593 * along the specified axis.
31594 * @param reverse Whether to multiply in the opposite direction. Optional.
31595 * Defaults to false.
31596 *
31597 * @doc {heading: 'Operations', subheading: 'Scan'}
31598 */
31599 function cumprod_(x) {
31600 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
31601 var exclusive = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
31602 var reverse = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
31603 var $x = convertToTensor(x, 'x', 'cumprod');
31604 var inputs = {
31605 x: $x
31606 };
31607 var attrs = {
31608 axis: axis,
31609 exclusive: exclusive,
31610 reverse: reverse
31611 };
31612 return ENGINE.runKernel(Cumprod, inputs, attrs);
31613 }
31614 var cumprod$2 = /* @__PURE__ */op({
31615 cumprod_: cumprod_
31616 });
31617
31618 /**
31619 * @license
31620 * Copyright 2018 Google LLC. All Rights Reserved.
31621 * Licensed under the Apache License, Version 2.0 (the "License");
31622 * you may not use this file except in compliance with the License.
31623 * You may obtain a copy of the License at
31624 *
31625 * http://www.apache.org/licenses/LICENSE-2.0
31626 *
31627 * Unless required by applicable law or agreed to in writing, software
31628 * distributed under the License is distributed on an "AS IS" BASIS,
31629 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31630 * See the License for the specific language governing permissions and
31631 * limitations under the License.
31632 * =============================================================================
31633 */
31634 /**
31635 * Computes the cumulative sum of a `tf.Tensor` along `axis`.
31636 *
31637 * ```js
31638 * const x = tf.tensor([1, 2, 3, 4]);
31639 * x.cumsum().print();
31640 * ```
31641 * ```js
31642 * const x = tf.tensor([[1, 2], [3, 4]]);
31643 * x.cumsum().print();
31644 * ```
31645 *
31646 * @param x The input tensor to be summed.
31647 * @param axis The axis along which to sum. Optional. Defaults to 0.
31648 * @param exclusive Whether to perform exclusive cumulative sum. Optional.
31649 * Defaults to false. If set to true then the sum of each tensor entry
31650 * does not include its own value, but only the values previous to it
31651 * along the specified axis.
31652 * @param reverse Whether to sum in the opposite direction. Optional.
31653 * Defaults to false.
31654 *
31655 * @doc {heading: 'Operations', subheading: 'Scan'}
31656 */
31657 function cumsum_(x) {
31658 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
31659 var exclusive = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
31660 var reverse = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
31661 var $x = convertToTensor(x, 'x', 'cumsum');
31662 var inputs = {
31663 x: $x
31664 };
31665 var attrs = {
31666 axis: axis,
31667 exclusive: exclusive,
31668 reverse: reverse
31669 };
31670 return ENGINE.runKernel(Cumsum, inputs, attrs);
31671 }
31672 var cumsum$2 = /* @__PURE__ */op({
31673 cumsum_: cumsum_
31674 });
31675
31676 /**
31677 * @license
31678 * Copyright 2020 Google LLC. All Rights Reserved.
31679 * Licensed under the Apache License, Version 2.0 (the "License");
31680 * you may not use this file except in compliance with the License.
31681 * You may obtain a copy of the License at
31682 *
31683 * http://www.apache.org/licenses/LICENSE-2.0
31684 *
31685 * Unless required by applicable law or agreed to in writing, software
31686 * distributed under the License is distributed on an "AS IS" BASIS,
31687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31688 * See the License for the specific language governing permissions and
31689 * limitations under the License.
31690 * =============================================================================
31691 */
31692 /**
31693 * Outputs a vector with length `size` and the same dtype as `weights`.
31694 *
31695 * If `weights` are empty, then index `i` stores the number of times the value
31696 * `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
31697 * sum of the value in `weights` at each index where the corresponding value in
31698 * `x` is `i`.
31699 *
31700 * Values in `x` outside of the range [0, size) are ignored.
31701 *
31702 * @param x The input int tensor, rank 1 or rank 2.
31703 * @param weights The weights tensor, must have the same shape as x, or a
31704 * length-0 Tensor, in which case it acts as all weights equal to 1.
31705 * @param size Non-negative integer.
31706 * @param binaryOutput Optional. Whether the kernel should count the appearance
31707 * or number of occurrences. Defaults to False.
31708 *
31709 * @doc {heading: 'Operations', subheading: 'Reduction'}
31710 */
31711 function denseBincount_(x, weights, size) {
31712 var binaryOutput = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
31713 var $x = convertToTensor(x, 'x', 'denseBincount');
31714 var $weights = convertToTensor(weights, 'weights', 'denseBincount');
31715 assert$1($x.dtype === 'int32', function () {
31716 return "Error in denseBincount: input " + "dtype must be int32, but got ".concat($x.dtype);
31717 });
31718 assert$1($x.rank <= 2, function () {
31719 return "Error in denseBincount: input must be at most rank 2, but got " + "rank ".concat($x.rank, ".");
31720 });
31721 assert$1(size >= 0, function () {
31722 return "size must be non-negative, but got ".concat(size, ".");
31723 });
31724 assert$1($weights.size === $x.size || $weights.size === 0, function () {
31725 return "Error in denseBincount: weights must have the same shape as x or " + "0-length, but got x shape: ".concat($x.shape, ", weights shape: ") + "".concat($weights.shape, ".");
31726 });
31727 var inputs = {
31728 x: $x,
31729 weights: $weights
31730 };
31731 var attrs = {
31732 size: size,
31733 binaryOutput: binaryOutput
31734 };
31735 return ENGINE.runKernel(DenseBincount, inputs, attrs);
31736 }
31737 var denseBincount$2 = /* @__PURE__ */op({
31738 denseBincount_: denseBincount_
31739 });
31740
31741 /**
31742 * @license
31743 * Copyright 2020 Google LLC. All Rights Reserved.
31744 * Licensed under the Apache License, Version 2.0 (the "License");
31745 * you may not use this file except in compliance with the License.
31746 * You may obtain a copy of the License at
31747 *
31748 * http://www.apache.org/licenses/LICENSE-2.0
31749 *
31750 * Unless required by applicable law or agreed to in writing, software
31751 * distributed under the License is distributed on an "AS IS" BASIS,
31752 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31753 * See the License for the specific language governing permissions and
31754 * limitations under the License.
31755 * =============================================================================
31756 */
31757 /**
31758 * Rearranges data from depth into blocks of spatial data. More specifically,
31759 * this op outputs a copy of the input tensor where values from the `depth`
31760 * dimension are moved in spatial blocks to the `height` and `width` dimensions.
31761 * The attr `blockSize` indicates the input block size and how the data is
31762 * moved.
31763 *
31764 * - Chunks of data of size `blockSize * blockSize` from depth are rearranged
31765 * into non-overlapping blocks of size `blockSize x blockSize`
31766 *
31767 * - The width the output tensor is `inputWidth * blockSize`, whereas the
31768 * height is `inputHeight * blockSize`
31769 *
31770 * - The Y, X coordinates within each block of the output image are determined
31771 * by the high order component of the input channel index
31772 *
31773 * - The depth of the input tensor must be divisible by `blockSize *
31774 * blockSize`
31775 *
31776 * The `dataFormat` attr specifies the layout of the input and output tensors
31777 * with the following options: "NHWC": [ `batch, height, width, channels` ]
31778 * "NCHW": [ `batch, channels, height, width` ]
31779 *
31780 * ```js
31781 * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
31782 * const blockSize = 2;
31783 * const dataFormat = "NHWC";
31784 *
31785 * tf.depthToSpace(x, blockSize, dataFormat).print();
31786 * ```
31787 *
31788 * @param x The input tensor of rank 4
31789 * @param blockSIze An `int` that is `>= 2`. The size of the spatial block
31790 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
31791 *
31792 * @doc {heading: 'Tensors', subheading: 'Transformations'}
31793 */
31794 function depthToSpace_(x, blockSize) {
31795 var dataFormat = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'NHWC';
31796 var $x = convertToTensor(x, 'x', 'depthToSpace', 'float32');
31797 var inputHeight = dataFormat === 'NHWC' ? $x.shape[1] : $x.shape[2];
31798 var inputWidth = dataFormat === 'NHWC' ? $x.shape[2] : $x.shape[3];
31799 var inputDepth = dataFormat === 'NHWC' ? $x.shape[3] : $x.shape[1];
31800 assert$1(blockSize > 1, function () {
31801 return "blockSize should be > 1 for depthToSpace, but was: ".concat(blockSize);
31802 });
31803 assert$1(inputHeight * blockSize >= 0, function () {
31804 return "Negative dimension size caused by overflow when multiplying\n ".concat(inputHeight, " and ").concat(blockSize, " for depthToSpace with input shape\n ").concat($x.shape);
31805 });
31806 assert$1(inputWidth * blockSize >= 0, function () {
31807 return "Negative dimension size caused by overflow when multiplying\n ".concat(inputWidth, " and ").concat(blockSize, " for depthToSpace with input shape\n ").concat($x.shape);
31808 });
31809 assert$1(inputDepth % (blockSize * blockSize) === 0, function () {
31810 return "Dimension size must be evenly divisible by ".concat(blockSize * blockSize, " but is ").concat(inputDepth, " for depthToSpace with input shape ").concat($x.shape);
31811 });
31812 var inputs = {
31813 x: $x
31814 };
31815 var attrs = {
31816 blockSize: blockSize,
31817 dataFormat: dataFormat
31818 };
31819 return ENGINE.runKernel(DepthToSpace, inputs, attrs);
31820 }
31821 var depthToSpace$2 = /* @__PURE__ */op({
31822 depthToSpace_: depthToSpace_
31823 });
31824
31825 /**
31826 * @license
31827 * Copyright 2020 Google LLC. All Rights Reserved.
31828 * Licensed under the Apache License, Version 2.0 (the "License");
31829 * you may not use this file except in compliance with the License.
31830 * You may obtain a copy of the License at
31831 *
31832 * http://www.apache.org/licenses/LICENSE-2.0
31833 *
31834 * Unless required by applicable law or agreed to in writing, software
31835 * distributed under the License is distributed on an "AS IS" BASIS,
31836 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31837 * See the License for the specific language governing permissions and
31838 * limitations under the License.
31839 * =============================================================================
31840 */
31841 /**
31842 * Depthwise 2D convolution.
31843 *
31844 * Given a 4D `input` array and a `filter` array of shape
31845 * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
31846 * `inChannels` convolutional filters of depth 1, this op applies a
31847 * different filter to each input channel (expanding from 1 channel to
31848 * `channelMultiplier` channels for each), then concatenates the results
31849 * together. The output has `inChannels * channelMultiplier` channels.
31850 *
31851 * See
31852 * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
31853 * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
31854 * for more details.
31855 *
31856 * @param x The input tensor, of rank 4 or rank 3, of shape
31857 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
31858 * assumed.
31859 * @param filter The filter tensor, rank 4, of shape
31860 * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
31861 * @param strides The strides of the convolution: `[strideHeight,
31862 * strideWidth]`. If strides is a single number, then `strideHeight ==
31863 * strideWidth`.
31864 * @param pad The type of padding algorithm.
31865 * - `same` and stride 1: output will be of same size as input,
31866 * regardless of filter size.
31867 * - `valid`: output will be smaller than input if filter is larger
31868 * than 1x1.
31869 * - For more info, see this guide:
31870 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
31871 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
31872 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
31873 * in which we sample input values across the height and width dimensions
31874 * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
31875 * number, then `dilationHeight == dilationWidth`. If it is greater than
31876 * 1, then all values of `strides` must be 1.
31877 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
31878 * "NHWC". Specify the data format of the input and output data. With the
31879 * default format "NHWC", the data is stored in the order of: [batch,
31880 * height, width, channels]. Only "NHWC" is currently supported.
31881 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
31882 * provided, it will default to truncate.
31883 *
31884 * @doc {heading: 'Operations', subheading: 'Convolution'}
31885 */
31886 function depthwiseConv2d_(x, filter, strides, pad) {
31887 var dataFormat = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 'NHWC';
31888 var dilations = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : [1, 1];
31889 var dimRoundingMode = arguments.length > 6 ? arguments[6] : undefined;
31890 var $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
31891 var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
31892 var x4D = $x;
31893 var reshapedTo4D = false;
31894 if ($x.rank === 3) {
31895 reshapedTo4D = true;
31896 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
31897 }
31898 assert$1(x4D.rank === 4, function () {
31899 return "Error in depthwiseConv2d: input must be rank 4, but got " + "rank ".concat(x4D.rank, ".");
31900 });
31901 assert$1($filter.rank === 4, function () {
31902 return "Error in depthwiseConv2d: filter must be rank 4, but got rank " + "".concat($filter.rank, ".");
31903 });
31904 var inChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
31905 assert$1(inChannels === $filter.shape[2], function () {
31906 return "Error in depthwiseConv2d: number of input channels " + "(".concat(inChannels, ") must match the inChannels dimension in ") + "filter ".concat($filter.shape[2], ".");
31907 });
31908 checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
31909 var inputs = {
31910 x: x4D,
31911 filter: $filter
31912 };
31913 var attrs = {
31914 strides: strides,
31915 pad: pad,
31916 dataFormat: dataFormat,
31917 dilations: dilations,
31918 dimRoundingMode: dimRoundingMode
31919 };
31920 // tslint:disable-next-line: no-unnecessary-type-assertion
31921 var res = ENGINE.runKernel(DepthwiseConv2dNative, inputs, attrs);
31922 if (reshapedTo4D) {
31923 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
31924 }
31925 return res;
31926 }
31927 var depthwiseConv2d$3 = /* @__PURE__ */op({
31928 depthwiseConv2d_: depthwiseConv2d_
31929 });
31930
31931 /**
31932 * @license
31933 * Copyright 2020 Google LLC. All Rights Reserved.
31934 * Licensed under the Apache License, Version 2.0 (the "License");
31935 * you may not use this file except in compliance with the License.
31936 * You may obtain a copy of the License at
31937 *
31938 * http://www.apache.org/licenses/LICENSE-2.0
31939 *
31940 * Unless required by applicable law or agreed to in writing, software
31941 * distributed under the License is distributed on an "AS IS" BASIS,
31942 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31943 * See the License for the specific language governing permissions and
31944 * limitations under the License.
31945 * =============================================================================
31946 */
31947 /**
31948 * Returns a diagonal tensor with given diagonal values.
31949 *
31950 * Given a diagonal, this operation returns a tensor with the diagonal and
31951 * everything else padded with zeros.
31952 *
31953 * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
31954 * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
31955 *
31956 * ```js
31957 * const x = tf.tensor1d([1, 2, 3, 4]);
31958 *
31959 * tf.diag(x).print()
31960 * ```
31961 * ```js
31962 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [4, 2])
31963 *
31964 * tf.diag(x).print()
31965 * ```
31966 * @param x The input tensor.
31967 *
31968 * @doc {heading: 'Tensors', subheading: 'Creation'}
31969 */
31970 function diag_(x) {
31971 var $x = convertToTensor(x, 'x', 'diag');
31972 var inputs = {
31973 x: $x
31974 };
31975 return ENGINE.runKernel(Diag, inputs);
31976 }
31977 var diag$2 = /* @__PURE__ */op({
31978 diag_: diag_
31979 });
31980
31981 /**
31982 * @license
31983 * Copyright 2020 Google LLC. All Rights Reserved.
31984 * Licensed under the Apache License, Version 2.0 (the "License");
31985 * you may not use this file except in compliance with the License.
31986 * You may obtain a copy of the License at
31987 *
31988 * http://www.apache.org/licenses/LICENSE-2.0
31989 *
31990 * Unless required by applicable law or agreed to in writing, software
31991 * distributed under the License is distributed on an "AS IS" BASIS,
31992 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31993 * See the License for the specific language governing permissions and
31994 * limitations under the License.
31995 * =============================================================================
31996 */
31997 /**
31998 * Computes the grayscale dilation over the input `x`.
31999 *
32000 * @param x The input tensor, rank 3 or rank 4 of shape
32001 * `[batch, height, width, depth]`. If rank 3, batch of 1 is assumed.
32002 * @param filter The filter tensor, rank 3, of shape
32003 * `[filterHeight, filterWidth, depth]`.
32004 * @param strides The strides of the sliding window for each dimension of the
32005 * input tensor: `[strideHeight, strideWidth]`.
32006 * If `strides` is a single number,
32007 * then `strideHeight == strideWidth`.
32008 * @param pad The type of padding algorithm.
32009 * - `same` and stride 1: output will be of same size as input,
32010 * regardless of filter size.
32011 * - `valid`: output will be smaller than input if filter is larger
32012 * than 1*1x1.
32013 * - For more info, see this guide:
32014 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
32015 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
32016 * @param dataFormat Specify the data format of the input and output data.
32017 * Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the
32018 * default format "NHWC", the data is stored in the order of: [batch,
32019 * height, width, channels].
32020 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
32021 * in which we sample input values across the height and width dimensions
32022 * for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations`
32023 * is a single number, then `dilationHeight == dilationWidth`. If it is
32024 * greater than 1, then all values of `strides` must be 1.
32025 *
32026 * @doc {heading: 'Operations', subheading: 'Convolution'}
32027 */
32028 function dilation2d_(x, filter, strides, pad) {
32029 var dilations = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : [1, 1];
32030 var dataFormat = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 'NHWC';
32031 var $x = convertToTensor(x, 'x', 'dilation2d');
32032 var $filter = convertToTensor(filter, 'filter', 'dilation2d');
32033 assert$1($x.rank === 3 || $x.rank === 4, function () {
32034 return "Error in dilation2d: input must be rank 3 or 4, but got rank " + "".concat($x.rank, ".");
32035 });
32036 assert$1($filter.rank === 3, function () {
32037 return "Error in dilation2d: filter must be rank 3, but got rank " + "".concat($filter.rank, ".");
32038 });
32039 assert$1(dataFormat === 'NHWC', function () {
32040 return "Error in dilation2d: Only NHWC is currently supported, " + "but got dataFormat of ".concat(dataFormat);
32041 });
32042 var x4D = $x;
32043 var reshapedTo4D = false;
32044 if ($x.rank === 3) {
32045 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
32046 reshapedTo4D = true;
32047 }
32048 assert$1(x4D.shape[3] === $filter.shape[2], function () {
32049 return "Error in dilation2d: input and filter must have the same depth: ".concat(x4D.shape[3], " vs ").concat($filter.shape[2]);
32050 });
32051 var inputs = {
32052 x: x4D,
32053 filter: $filter
32054 };
32055 var attrs = {
32056 strides: strides,
32057 pad: pad,
32058 dilations: dilations
32059 };
32060 // tslint:disable-next-line: no-unnecessary-type-assertion
32061 var res = ENGINE.runKernel(Dilation2D, inputs, attrs);
32062 if (reshapedTo4D) {
32063 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
32064 }
32065 return res;
32066 }
32067 var dilation2d = /* @__PURE__ */op({
32068 dilation2d_: dilation2d_
32069 });
32070
32071 /**
32072 * @license
32073 * Copyright 2017 Google LLC. All Rights Reserved.
32074 * Licensed under the Apache License, Version 2.0 (the "License");
32075 * you may not use this file except in compliance with the License.
32076 * You may obtain a copy of the License at
32077 *
32078 * http://www.apache.org/licenses/LICENSE-2.0
32079 *
32080 * Unless required by applicable law or agreed to in writing, software
32081 * distributed under the License is distributed on an "AS IS" BASIS,
32082 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32083 * See the License for the specific language governing permissions and
32084 * limitations under the License.
32085 * =============================================================================
32086 */
32087 /**
32088 * Returns the dimensions in the input shape that are broadcasted to
32089 * produce the provided output shape.
32090 *
32091 * The returned dimensions are 0-indexed and sorted. An example:
32092 * inShape = [4, 1, 3]
32093 * outShape = [5, 4, 3, 3]
32094 * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
32095 */
32096 function getBroadcastDims$1(inShape, outShape) {
32097 var inRank = inShape.length;
32098 var dims = [];
32099 for (var i = 0; i < inRank; i++) {
32100 var dim = inRank - 1 - i;
32101 var a = inShape[dim] || 1;
32102 var b = outShape[outShape.length - 1 - i] || 1;
32103 if (b > 1 && a === 1) {
32104 dims.unshift(dim);
32105 }
32106 }
32107 return dims;
32108 }
32109 /**
32110 * Returns the axes in the output space that should be reduced to produce
32111 * the input space.
32112 */
32113 function getReductionAxes(inShape, outShape) {
32114 var result = [];
32115 for (var i = 0; i < outShape.length; i++) {
32116 var inDim = inShape[inShape.length - i - 1];
32117 var outAxis = outShape.length - i - 1;
32118 var outDim = outShape[outAxis];
32119 if (inDim == null || inDim === 1 && outDim > 1) {
32120 result.unshift(outAxis);
32121 }
32122 }
32123 return result;
32124 }
32125 function assertAndGetBroadcastShape(shapeA, shapeB) {
32126 var l = Math.max(shapeA.length, shapeB.length);
32127 var result = new Array(l);
32128 for (var i = 0; i < l; i++) {
32129 var a = shapeA[shapeA.length - i - 1];
32130 if (a == null) {
32131 a = 1;
32132 }
32133 var b = shapeB[shapeB.length - i - 1];
32134 if (b == null) {
32135 b = 1;
32136 }
32137 if (a === 1) {
32138 result[l - i - 1] = b;
32139 } else if (b === 1) {
32140 result[l - i - 1] = a;
32141 } else if (a !== b) {
32142 var errMsg = "Operands could not be broadcast together with shapes " + "".concat(shapeA, " and ").concat(shapeB, ".");
32143 throw Error(errMsg);
32144 } else {
32145 result[l - i - 1] = a;
32146 }
32147 }
32148 return result;
32149 }
32150
32151 var broadcast_util = {
32152 __proto__: null,
32153 assertAndGetBroadcastShape: assertAndGetBroadcastShape,
32154 getBroadcastDims: getBroadcastDims$1,
32155 getReductionAxes: getReductionAxes
32156 };
32157
32158 /**
32159 * Returns the truth value of (a == b) element-wise. Supports broadcasting.
32160 *
32161 * ```js
32162 * const a = tf.tensor1d([1, 2, 3]);
32163 * const b = tf.tensor1d([2, 2, 2]);
32164 *
32165 * a.equal(b).print();
32166 * ```
32167 *
32168 * @param a The first input tensor.
32169 * @param b The second input tensor. Must have the same dtype as `a`.
32170 *
32171 * @doc {heading: 'Operations', subheading: 'Logical'}
32172 */
32173 function equal_(a, b) {
32174 var $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric');
32175 var $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric');
32176 var _makeTypesMatch = makeTypesMatch($a, $b);
32177 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
32178 $a = _makeTypesMatch2[0];
32179 $b = _makeTypesMatch2[1];
32180 assertAndGetBroadcastShape($a.shape, $b.shape);
32181 var inputs = {
32182 a: $a,
32183 b: $b
32184 };
32185 return ENGINE.runKernel(Equal, inputs);
32186 }
32187 var equal$2 = /* @__PURE__ */op({
32188 equal_: equal_
32189 });
32190
32191 /**
32192 * @license
32193 * Copyright 2020 Google LLC. All Rights Reserved.
32194 * Licensed under the Apache License, Version 2.0 (the "License");
32195 * you may not use this file except in compliance with the License.
32196 * You may obtain a copy of the License at
32197 *
32198 * http://www.apache.org/licenses/LICENSE-2.0
32199 *
32200 * Unless required by applicable law or agreed to in writing, software
32201 * distributed under the License is distributed on an "AS IS" BASIS,
32202 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32203 * See the License for the specific language governing permissions and
32204 * limitations under the License.
32205 * =============================================================================
32206 */
32207 /**
32208 * Returns the elements, either `a` or `b` depending on the `condition`.
32209 *
32210 * If the condition is true, select from `a`, otherwise select from `b`.
32211 *
32212 * ```js
32213 * const cond = tf.tensor1d([false, false, true], 'bool');
32214 * const a = tf.tensor1d([1 , 2, 3]);
32215 * const b = tf.tensor1d([-1, -2, -3]);
32216 *
32217 * a.where(cond, b).print();
32218 * ```
32219 *
32220 * @param condition The input condition. Must be of dtype bool.
32221 * @param a If `condition` is rank 1, `a` may have a higher rank but
32222 * its first dimension must match the size of `condition`.
32223 * @param b A tensor with the same dtype as `a` and with shape that is
32224 * compatible with `a`.
32225 * @return A tensor with same dtype as `a` and `b`, and shape that is
32226 * broadcastable from `a` and `b`.
32227 *
32228 * @doc {heading: 'Operations', subheading: 'Logical'}
32229 */
32230 function where_(condition, a, b) {
32231 var $a = convertToTensor(a, 'a', 'where');
32232 var $b = convertToTensor(b, 'b', 'where');
32233 var $condition = convertToTensor(condition, 'condition', 'where', 'bool');
32234 // TODO: move this logic to forward function when the broadcastTo op is
32235 // implemented in WASM.
32236 // Find the broadcastable shape for $condition, $a, and $b.
32237 var broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape);
32238 var $broadcastedCondition = broadcastTo($condition, broadcastShape);
32239 var $broadcastedA = broadcastTo($a, broadcastShape);
32240 var $broadcastedB = broadcastTo($b, broadcastShape);
32241 var inputs = {
32242 condition: $broadcastedCondition,
32243 t: $broadcastedA,
32244 e: $broadcastedB
32245 };
32246 return ENGINE.runKernel(Select, inputs);
32247 }
32248 var where = /* @__PURE__ */op({
32249 where_: where_
32250 });
32251
32252 /**
32253 * @license
32254 * Copyright 2018 Google LLC. All Rights Reserved.
32255 * Licensed under the Apache License, Version 2.0 (the "License");
32256 * you may not use this file except in compliance with the License.
32257 * You may obtain a copy of the License at
32258 *
32259 * http://www.apache.org/licenses/LICENSE-2.0
32260 *
32261 * Unless required by applicable law or agreed to in writing, software
32262 * distributed under the License is distributed on an "AS IS" BASIS,
32263 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32264 * See the License for the specific language governing permissions and
32265 * limitations under the License.
32266 * =============================================================================
32267 */
32268 /**
32269 * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the
32270 * given tensor.
32271 *
32272 * ```js
32273 * const x = tf.tensor([1, 2]);
32274 * tf.zerosLike(x).print();
32275 * ```
32276 *
32277 * @param x The tensor of required shape.
32278 *
32279 * @doc {heading: 'Tensors', subheading: 'Creation'}
32280 */
32281 function zerosLike_(x) {
32282 var $x = convertToTensor(x, 'x', 'zerosLike');
32283 var inputs = {
32284 x: $x
32285 };
32286 return ENGINE.runKernel(ZerosLike, inputs);
32287 }
32288 var zerosLike$3 = /* @__PURE__ */op({
32289 zerosLike_: zerosLike_
32290 });
32291
32292 /**
32293 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
32294 * if denominator is 0.
32295 *
32296 *
32297 * ```js
32298 * const a = tf.tensor1d([1, 4, 9, 16]);
32299 * const b = tf.tensor1d([1, 2, 3, 4]);
32300 * const c = tf.tensor1d([0, 0, 0, 0]);
32301 *
32302 * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
32303 * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
32304 * ```
32305 *
32306 * ```js
32307 * // Broadcast div a with b.
32308 * const a = tf.tensor1d([2, 4, 6, 8]);
32309 * const b = tf.scalar(2);
32310 * const c = tf.scalar(0);
32311 *
32312 * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
32313 * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
32314 * ```
32315 *
32316 * @param a The first tensor as the numerator.
32317 * @param b The second tensor as the denominator. Must have the same dtype as
32318 * `a`.
32319 *
32320 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
32321 */
32322 function divNoNan_(a, b) {
32323 // TODO: Make this into its own kernel.
32324 var $a = convertToTensor(a, 'a', 'div');
32325 var $b = convertToTensor(b, 'b', 'div');
32326 var _makeTypesMatch = makeTypesMatch($a, $b);
32327 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
32328 $a = _makeTypesMatch2[0];
32329 $b = _makeTypesMatch2[1];
32330 var divResult = div$1($a, $b);
32331 var zeros = zerosLike$3(divResult);
32332 var bEqualsZero = equal$2($b, zeros);
32333 return where(bEqualsZero, zeros, divResult);
32334 }
32335 var divNoNan = /* @__PURE__ */op({
32336 divNoNan_: divNoNan_
32337 });
32338
32339 /**
32340 * @license
32341 * Copyright 2020 Google LLC. All Rights Reserved.
32342 * Licensed under the Apache License, Version 2.0 (the "License");
32343 * you may not use this file except in compliance with the License.
32344 * You may obtain a copy of the License at
32345 *
32346 * http://www.apache.org/licenses/LICENSE-2.0
32347 *
32348 * Unless required by applicable law or agreed to in writing, software
32349 * distributed under the License is distributed on an "AS IS" BASIS,
32350 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32351 * See the License for the specific language governing permissions and
32352 * limitations under the License.
32353 * =============================================================================
32354 */
32355 /**
32356 * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
32357 *
32358 * ```js
32359 * const a = tf.tensor1d([1, 2]);
32360 * const b = tf.tensor2d([[1, 2], [3, 4]]);
32361 * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
32362 *
32363 * a.dot(b).print(); // or tf.dot(a, b)
32364 * b.dot(a).print();
32365 * b.dot(c).print();
32366 * ```
32367 * @param t1 The first tensor in the dot operation.
32368 * @param t2 The second tensor in the dot operation.
32369 *
32370 * @doc {heading: 'Operations', subheading: 'Matrices'}
32371 */
32372 function dot_(t1, t2) {
32373 var $t1 = convertToTensor(t1, 't1', 'dot');
32374 var $t2 = convertToTensor(t2, 't2', 'dot');
32375 assert$1(($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), function () {
32376 return "Error in dot: inputs must all be rank 1 or 2, but got ranks " + "".concat($t1.rank, " and ").concat($t2.rank, ".");
32377 });
32378 var t1Inner = $t1.rank === 1 ? $t1.size : $t1.shape[1];
32379 var t2Inner = $t2.rank === 1 ? $t2.size : $t2.shape[0];
32380 assert$1(t1Inner === t2Inner, function () {
32381 return "Error in dot: inner dimensions of inputs must match, but got " + "".concat(t1Inner, " and ").concat(t2Inner, ".");
32382 });
32383 if ($t1.rank === 1 && $t2.rank === 1) {
32384 var t12D = reshape$3($t1, [1, -1]);
32385 var t22D = reshape$3($t2, [-1, 1]);
32386 var t1t2 = matMul$1(t12D, t22D);
32387 return reshape$3(t1t2, []);
32388 } else if ($t1.rank === 1 && $t2.rank === 2) {
32389 var _t12D = reshape$3($t1, [1, -1]);
32390 var _t22D = reshape$3($t2, [$t2.shape[0], $t2.shape[1]]);
32391 var _t1t = matMul$1(_t12D, _t22D);
32392 return reshape$3(_t1t, [_t1t.size]);
32393 } else if ($t1.rank === 2 && $t2.rank === 1) {
32394 var _t22D2 = reshape$3($t2, [-1, 1]);
32395 var _t1t2 = matMul$1($t1, _t22D2);
32396 return reshape$3(_t1t2, [_t1t2.size]);
32397 } else {
32398 var _t22D3 = reshape$3($t2, [$t2.shape[0], $t2.shape[1]]);
32399 var _t1t3 = matMul$1($t1, _t22D3);
32400 return _t1t3;
32401 }
32402 }
32403 var dot$2 = /* @__PURE__ */op({
32404 dot_: dot_
32405 });
32406
32407 /**
32408 * @license
32409 * Copyright 2021 Google LLC. All Rights Reserved.
32410 * Licensed under the Apache License, Version 2.0 (the "License");
32411 * you may not use this file except in compliance with the License.
32412 * You may obtain a copy of the License at
32413 *
32414 * http://www.apache.org/licenses/LICENSE-2.0
32415 *
32416 * Unless required by applicable law or agreed to in writing, software
32417 * distributed under the License is distributed on an "AS IS" BASIS,
32418 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32419 * See the License for the specific language governing permissions and
32420 * limitations under the License.
32421 * =============================================================================
32422 */
32423 /**
32424 * Tensor contraction over specified indices and outer product.
32425 *
32426 * `einsum` allows defining Tensors by defining their element-wise computation.
32427 * This computation is based on
32428 * [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation).
32429 *
32430 * Some special cases include:
32431 *
32432 * Matrix multiplication:
32433 * ```js
32434 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
32435 * const y = tf.tensor2d([[0, 1], [2, 3], [4, 5]]);
32436 * x.print();
32437 * y.print();
32438 * tf.einsum('ij,jk->ik', x, y).print();
32439 * ```
32440 *
32441 * Dot product:
32442 * ```js
32443 * const x = tf.tensor1d([1, 2, 3]);
32444 * const y = tf.tensor1d([0, 1, 2]);
32445 * x.print();
32446 * y.print();
32447 * tf.einsum('i,i->', x, y).print();
32448 * ```
32449 *
32450 * Batch dot product:
32451 * ```js
32452 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
32453 * const y = tf.tensor2d([[0, 1, 2], [3, 4, 5]]);
32454 * x.print();
32455 * y.print();
32456 * tf.einsum('bi,bi->b', x, y).print();
32457 * ```
32458 *
32459 * Outer prouduct:
32460 * ```js
32461 * const x = tf.tensor1d([1, 3, 5]);
32462 * const y = tf.tensor1d([2, 4, 6]);
32463 * x.print();
32464 * y.print();
32465 * tf.einsum('i,j->ij', x, y).print();
32466 * ```
32467 *
32468 * Matrix transpose:
32469 * ```js
32470 * const x = tf.tensor2d([[1, 2], [3, 4]]);
32471 * x.print();
32472 * tf.einsum('ij->ji', x).print();
32473 * ```
32474 *
32475 * Batch matrix transpose:
32476 * ```js
32477 * const x = tf.tensor3d([[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]);
32478 * x.print();
32479 * tf.einsum('bij->bji', x).print();
32480 * ```
32481 *
32482 * Limitations:
32483 *
32484 * This implementation of einsum has the following limitations:
32485 *
32486 * - Does not support >2 input tensors.
32487 * - Does not support duplicate axes for any given input tensor. E.g., equation
32488 * 'ii->' is not supported.
32489 * - The `...` notation is not supported.
32490 *
32491 * @param equation a string describing the contraction, in the same format as
32492 * [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
32493 * @param tensors the input(s) to contract (each one a Tensor), whose shapes
32494 * should be consistent with equation.
32495 * @returns The output tensor.
32496 *
32497 * @doc {heading: 'Tensors', subheading: 'Matrices'}
32498 */
32499 function einsum_(equation) {
32500 for (var _len = arguments.length, tensors = new Array(_len > 1 ? _len - 1 : 0), _key = 1; _key < _len; _key++) {
32501 tensors[_key - 1] = arguments[_key];
32502 }
32503 var $tensors = tensors.map(function (t, i) {
32504 return convertToTensor(t, "tensors".concat(i), 'einsum');
32505 });
32506 var attrs = {
32507 equation: equation
32508 };
32509 return ENGINE.runKernel(Einsum, $tensors, attrs);
32510 }
32511 var einsum$2 = /* @__PURE__ */op({
32512 einsum_: einsum_
32513 });
32514
32515 /**
32516 * @license
32517 * Copyright 2020 Google LLC. All Rights Reserved.
32518 * Licensed under the Apache License, Version 2.0 (the "License");
32519 * you may not use this file except in compliance with the License.
32520 * You may obtain a copy of the License at
32521 *
32522 * http://www.apache.org/licenses/LICENSE-2.0
32523 *
32524 * Unless required by applicable law or agreed to in writing, software
32525 * distributed under the License is distributed on an "AS IS" BASIS,
32526 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32527 * See the License for the specific language governing permissions and
32528 * limitations under the License.
32529 * =============================================================================
32530 */
32531 /**
32532 * Computes exponential linear element-wise: `x > 0 ? x : (e ^ x) - 1`.
32533 *
32534 * ```js
32535 * const x = tf.tensor1d([-1, 1, -3, 2]);
32536 *
32537 * x.elu().print(); // or tf.elu(x)
32538 * ```
32539 * @param x The input tensor.
32540 *
32541 * @doc {heading: 'Operations', subheading: 'Basic math'}
32542 */
32543 function elu_(x) {
32544 var $x = convertToTensor(x, 'x', 'elu', 'float32');
32545 var inputs = {
32546 x: $x
32547 };
32548 return ENGINE.runKernel(Elu$1, inputs);
32549 }
32550 var elu$4 = /* @__PURE__ */op({
32551 elu_: elu_
32552 });
32553
32554 /**
32555 * @license
32556 * Copyright 2023 Google LLC.
32557 * Licensed under the Apache License, Version 2.0 (the "License");
32558 * you may not use this file except in compliance with the License.
32559 * You may obtain a copy of the License at
32560 *
32561 * http://www.apache.org/licenses/LICENSE-2.0
32562 *
32563 * Unless required by applicable law or agreed to in writing, software
32564 * distributed under the License is distributed on an "AS IS" BASIS,
32565 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32566 * See the License for the specific language governing permissions and
32567 * limitations under the License.
32568 * =============================================================================
32569 */
32570 /**
32571 * Checks the input tensor mathes the given shape.
32572 *
32573 * Given an input tensor, returns a new tensor with the same values as the
32574 * input tensor with shape `shape`.
32575 *
32576 * The method supports the null value in tensor. It will still check the shapes,
32577 * and null is a placeholder.
32578 *
32579 *
32580 * ```js
32581 * const x = tf.tensor1d([1, 2, 3, 4]);
32582 * const y = tf.tensor1d([1, null, 3, 4]);
32583 * const z = tf.tensor2d([1, 2, 3, 4], [2,2]);
32584 * tf.ensureShape(x, [4]).print();
32585 * tf.ensureShape(y, [4]).print();
32586 * tf.ensureShape(z, [null, 2]).print();
32587 * ```
32588 *
32589 * @param x The input tensor to be ensured.
32590 * @param shape A TensorShape representing the shape of this tensor, an array
32591 * or null.
32592 *
32593 * @doc {heading: 'Tensors', subheading: 'Transformations'}
32594 */
32595 function ensureShape_(x, shape) {
32596 var $x = convertToTensor(x, 'x', 'ensureShape', 'string_or_numeric');
32597 if (!arraysEqualWithNull($x.shape, shape)) {
32598 throw new Error("EnsureShape: Shape of tensor ".concat($x.shape, " is not compatible with expected shape ").concat(shape));
32599 }
32600 return x;
32601 }
32602 var ensureShape = /* @__PURE__ */op({
32603 ensureShape_: ensureShape_
32604 });
32605
32606 /**
32607 * @license
32608 * Copyright 2018 Google LLC. All Rights Reserved.
32609 * Licensed under the Apache License, Version 2.0 (the "License");
32610 * you may not use this file except in compliance with the License.
32611 * You may obtain a copy of the License at
32612 *
32613 * http://www.apache.org/licenses/LICENSE-2.0
32614 *
32615 * Unless required by applicable law or agreed to in writing, software
32616 * distributed under the License is distributed on an "AS IS" BASIS,
32617 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32618 * See the License for the specific language governing permissions and
32619 * limitations under the License.
32620 * =============================================================================
32621 */
32622 /**
32623 * Computes Gauss error function of the input `tf.Tensor` element-wise:
32624 * `erf(x)`
32625 *
32626 * ```js
32627 * const x = tf.tensor1d([0, .1, -.1, .7]);
32628 *
32629 * x.erf().print(); // or tf.erf(x);
32630 * ```
32631 * @param x The input tensor.
32632 *
32633 * @doc {heading: 'Operations', subheading: 'Basic math'}
32634 */
32635 function erf_(x) {
32636 var $x = convertToTensor(x, 'x', 'erf');
32637 assert$1($x.dtype === 'int32' || $x.dtype === 'float32', function () {
32638 return 'Input dtype must be `int32` or `float32`.';
32639 });
32640 if ($x.dtype === 'int32') {
32641 $x = cast$3($x, 'float32');
32642 }
32643 var inputs = {
32644 x: $x
32645 };
32646 return ENGINE.runKernel(Erf, inputs);
32647 }
32648 var erf$2 = /* @__PURE__ */op({
32649 erf_: erf_
32650 });
32651
32652 /**
32653 * @license
32654 * Copyright 2017 Google LLC. All Rights Reserved.
32655 * Licensed under the Apache License, Version 2.0 (the "License");
32656 * you may not use this file except in compliance with the License.
32657 * You may obtain a copy of the License at
32658 *
32659 * http://www.apache.org/licenses/LICENSE-2.0
32660 *
32661 * Unless required by applicable law or agreed to in writing, software
32662 * distributed under the License is distributed on an "AS IS" BASIS,
32663 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32664 * See the License for the specific language governing permissions and
32665 * limitations under the License.
32666 * =============================================================================
32667 */
32668 /**
32669 * Returns true if the axis specifies the inner most dimensions of the
32670 * array.
32671 */
32672 function axesAreInnerMostDims(axes, rank) {
32673 for (var i = 0; i < axes.length; ++i) {
32674 if (axes[axes.length - i - 1] !== rank - 1 - i) {
32675 return false;
32676 }
32677 }
32678 return true;
32679 }
32680 function combineLocations(outputLoc, reduceLoc, axes) {
32681 var rank = outputLoc.length + reduceLoc.length;
32682 var loc = [];
32683 var outIdx = 0;
32684 var reduceIdx = 0;
32685 for (var dim = 0; dim < rank; dim++) {
32686 if (axes.indexOf(dim) === -1) {
32687 loc.push(outputLoc[outIdx++]);
32688 } else {
32689 loc.push(reduceLoc[reduceIdx++]);
32690 }
32691 }
32692 return loc;
32693 }
32694 function computeOutAndReduceShapes(aShape, axes) {
32695 var outShape = [];
32696 var rank = aShape.length;
32697 for (var dim = 0; dim < rank; dim++) {
32698 if (axes.indexOf(dim) === -1) {
32699 outShape.push(aShape[dim]);
32700 }
32701 }
32702 var reduceShape = axes.map(function (dim) {
32703 return aShape[dim];
32704 });
32705 return [outShape, reduceShape];
32706 }
32707 function expandShapeToKeepDim(shape, axes) {
32708 var reduceSubShape = axes.map(function (x) {
32709 return 1;
32710 });
32711 return combineLocations(shape, reduceSubShape, axes);
32712 }
32713 function assertAxesAreInnerMostDims(msg, axes, rank) {
32714 assert$1(axesAreInnerMostDims(axes, rank), function () {
32715 return "".concat(msg, " supports only inner-most axes for now. ") + "Got axes ".concat(axes, " and rank-").concat(rank, " input.");
32716 });
32717 }
32718 /**
32719 * Returns the axes permutation to be used with `tf.transpose`, if such
32720 * permutation is necessary. Otherwise it returns null. This method is used by
32721 * operations that operate only on inner-most axes.
32722 */
32723 function getAxesPermutation(axes, rank) {
32724 if (axesAreInnerMostDims(axes, rank)) {
32725 return null;
32726 }
32727 var result = [];
32728 for (var i = 0; i < rank; ++i) {
32729 if (axes.indexOf(i) === -1) {
32730 result.push(i);
32731 }
32732 }
32733 axes.forEach(function (axis) {
32734 return result.push(axis);
32735 });
32736 return result;
32737 }
32738 /** Returns the axes permutation that undoes the original permutation. */
32739 function getUndoAxesPermutation(axes) {
32740 return axes.map(function (axis, i) {
32741 return [i, axis];
32742 }).sort(function (a, b) {
32743 return a[1] - b[1];
32744 }).map(function (x) {
32745 return x[0];
32746 });
32747 }
32748 function getInnerMostAxes(numAxes, rank) {
32749 var res = [];
32750 for (var i = rank - numAxes; i < rank; ++i) {
32751 res.push(i);
32752 }
32753 return res;
32754 }
32755
32756 /**
32757 * @license
32758 * Copyright 2020 Google LLC. All Rights Reserved.
32759 * Licensed under the Apache License, Version 2.0 (the "License");
32760 * you may not use this file except in compliance with the License.
32761 * You may obtain a copy of the License at
32762 *
32763 * http://www.apache.org/licenses/LICENSE-2.0
32764 *
32765 * Unless required by applicable law or agreed to in writing, software
32766 * distributed under the License is distributed on an "AS IS" BASIS,
32767 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32768 * See the License for the specific language governing permissions and
32769 * limitations under the License.
32770 * =============================================================================
32771 */
32772 /**
32773 * Computes the maximum of elements across dimensions of a `tf.Tensor`.
32774 *
32775 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
32776 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
32777 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
32778 * length 1. If `axes` has no entries, all dimensions are reduced, and a
32779 * `tf.Tensor` with a single element is returned.
32780 *
32781 * ```js
32782 * const x = tf.tensor1d([1, 2, 3]);
32783 *
32784 * x.max().print(); // or tf.max(x)
32785 * ```
32786 *
32787 * ```js
32788 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
32789 *
32790 * const axis = 1;
32791 * x.max(axis).print(); // or tf.max(x, axis)
32792 * ```
32793 *
32794 * @param x The input tensor.
32795 * @param axis The dimension(s) to reduce. By default it reduces
32796 * all dimensions.
32797 * @param keepDims If true, retains reduced dimensions with size 1.
32798 *
32799 * @doc {heading: 'Operations', subheading: 'Reduction'}
32800 */
32801 function max_(x) {
32802 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : null;
32803 var keepDims = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
32804 var $x = convertToTensor(x, 'x', 'max');
32805 var inputs = {
32806 x: $x
32807 };
32808 var attrs = {
32809 reductionIndices: axis,
32810 keepDims: keepDims
32811 };
32812 return ENGINE.runKernel(Max, inputs, attrs);
32813 }
32814 var max$3 = /* @__PURE__ */op({
32815 max_: max_
32816 });
32817
32818 /**
32819 * @license
32820 * Copyright 2020 Google Inc. All Rights Reserved.
32821 * Licensed under the Apache License, Version 2.0 (the "License");
32822 * you may not use this file except in compliance with the License.
32823 * You may obtain a copy of the License at
32824 *
32825 * http://www.apache.org/licenses/LICENSE-2.0
32826 *
32827 * Unless required by applicable law or agreed to in writing, software
32828 * distributed under the License is distributed on an "AS IS" BASIS,
32829 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32830 * See the License for the specific language governing permissions and
32831 * limitations under the License.
32832 * =============================================================================
32833 */
32834 /**
32835 * Computes the minimum value from the input.
32836 *
32837 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
32838 * is true, the rank of the array is reduced by 1 for each entry in `axes`.
32839 * If `keepDims` is true, the reduced dimensions are retained with length 1.
32840 * If `axes` has no entries, all dimensions are reduced, and an array with a
32841 * single element is returned.
32842 *
32843 * ```js
32844 * const x = tf.tensor1d([1, 2, 3]);
32845 *
32846 * x.min().print(); // or tf.min(x)
32847 * ```
32848 *
32849 * ```js
32850 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
32851 *
32852 * const axis = 1;
32853 * x.min(axis).print(); // or tf.min(x, axis)
32854 * ```
32855 *
32856 * @param x The input Tensor.
32857 * @param axis The dimension(s) to reduce. By default it reduces
32858 * all dimensions.
32859 * @param keepDims If true, retains reduced dimensions with size 1.
32860 *
32861 * @doc {heading: 'Operations', subheading: 'Reduction'}
32862 */
32863 function min_(x) {
32864 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : null;
32865 var keepDims = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
32866 var $x = convertToTensor(x, 'x', 'min');
32867 var inputs = {
32868 x: $x
32869 };
32870 var attrs = {
32871 axis: axis,
32872 keepDims: keepDims
32873 };
32874 // tslint:disable-next-line: no-unnecessary-type-assertion
32875 return ENGINE.runKernel(Min, inputs, attrs);
32876 }
32877 var min$3 = /* @__PURE__ */op({
32878 min_: min_
32879 });
32880
32881 /**
32882 * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
32883 *
32884 * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
32885 * corresponding elements in x and y. The result's dtype will be the upcasted
32886 * type of the `base` and `exp` dtypes.
32887 *
32888 * ```js
32889 * const a = tf.tensor([[2, 3], [4, 5]])
32890 * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
32891 *
32892 * a.pow(b).print(); // or tf.pow(a, b)
32893 * ```
32894 *
32895 * ```js
32896 * const a = tf.tensor([[1, 2], [3, 4]])
32897 * const b = tf.tensor(2).toInt();
32898 *
32899 * a.pow(b).print(); // or tf.pow(a, b)
32900 * ```
32901 * We also expose `powStrict` which has the same signature as this op and
32902 * asserts that `base` and `exp` are the same shape (does not broadcast).
32903 *
32904 * @param base The base `tf.Tensor` to pow element-wise.
32905 * @param exp The exponent `tf.Tensor` to pow element-wise.
32906 *
32907 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
32908 */
32909 function pow_(base, exp) {
32910 var $base = convertToTensor(base, 'base', 'pow');
32911 var $exp = convertToTensor(exp, 'exp', 'pow');
32912 var _makeTypesMatch = makeTypesMatch($base, $exp);
32913 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
32914 $base = _makeTypesMatch2[0];
32915 $exp = _makeTypesMatch2[1];
32916 var inputs = {
32917 a: $base,
32918 b: $exp
32919 };
32920 return ENGINE.runKernel(Pow, inputs);
32921 }
32922 var pow$3 = /* @__PURE__ */op({
32923 pow_: pow_
32924 });
32925
32926 /**
32927 * @license
32928 * Copyright 2018 Google LLC. All Rights Reserved.
32929 * Licensed under the Apache License, Version 2.0 (the "License");
32930 * you may not use this file except in compliance with the License.
32931 * You may obtain a copy of the License at
32932 *
32933 * http://www.apache.org/licenses/LICENSE-2.0
32934 *
32935 * Unless required by applicable law or agreed to in writing, software
32936 * distributed under the License is distributed on an "AS IS" BASIS,
32937 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32938 * See the License for the specific language governing permissions and
32939 * limitations under the License.
32940 * =============================================================================
32941 */
32942 /**
32943 * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.
32944 *
32945 * The same functionality can be achieved with `tf.tensor`, but in general
32946 * we recommend using `tf.scalar` as it makes the code more readable.
32947 *
32948 * ```js
32949 * tf.scalar(3.14).print();
32950 * ```
32951 *
32952 * @param value The value of the scalar.
32953 * @param dtype The data type.
32954 *
32955 * @doc {heading: 'Tensors', subheading: 'Creation'}
32956 */
32957 function scalar(value, dtype) {
32958 if ((isTypedArray(value) && dtype !== 'string' || Array.isArray(value)) && dtype !== 'complex64') {
32959 throw new Error('Error creating a new Scalar: value must be a primitive ' + '(number|boolean|string)');
32960 }
32961 if (dtype === 'string' && isTypedArray(value) && !(value instanceof Uint8Array)) {
32962 throw new Error('When making a scalar from encoded string, ' + 'the value must be `Uint8Array`.');
32963 }
32964 var shape = [];
32965 var inferredShape = [];
32966 return makeTensor(value, shape, inferredShape, dtype);
32967 }
32968
32969 /**
32970 * @license
32971 * Copyright 2018 Google LLC. All Rights Reserved.
32972 * Licensed under the Apache License, Version 2.0 (the "License");
32973 * you may not use this file except in compliance with the License.
32974 * You may obtain a copy of the License at
32975 *
32976 * http://www.apache.org/licenses/LICENSE-2.0
32977 *
32978 * Unless required by applicable law or agreed to in writing, software
32979 * distributed under the License is distributed on an "AS IS" BASIS,
32980 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32981 * See the License for the specific language governing permissions and
32982 * limitations under the License.
32983 * =============================================================================
32984 */
32985 /**
32986 * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
32987 *
32988 * ```js
32989 * const x = tf.tensor1d([1, 2, 4, -1]);
32990 *
32991 * x.sqrt().print(); // or tf.sqrt(x)
32992 * ```
32993 * @param x The input tensor.
32994 *
32995 * @doc {heading: 'Operations', subheading: 'Basic math'}
32996 */
32997 function sqrt_(x) {
32998 var $x = convertToTensor(x, 'x', 'sqrt', 'float32');
32999 var inputs = {
33000 x: $x
33001 };
33002 return ENGINE.runKernel(Sqrt, inputs);
33003 }
33004 var sqrt$2 = /* @__PURE__ */op({
33005 sqrt_: sqrt_
33006 });
33007
33008 /**
33009 * @license
33010 * Copyright 2019 Google LLC. All Rights Reserved.
33011 * Licensed under the Apache License, Version 2.0 (the "License");
33012 * you may not use this file except in compliance with the License.
33013 * You may obtain a copy of the License at
33014 *
33015 * http://www.apache.org/licenses/LICENSE-2.0
33016 *
33017 * Unless required by applicable law or agreed to in writing, software
33018 * distributed under the License is distributed on an "AS IS" BASIS,
33019 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33020 * See the License for the specific language governing permissions and
33021 * limitations under the License.
33022 * =============================================================================
33023 */
33024 /**
33025 * Computes square of `x` element-wise: `x ^ 2`
33026 *
33027 * ```js
33028 * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]);
33029 *
33030 * x.square().print(); // or tf.square(x)
33031 * ```
33032 * @param x The input Tensor.
33033 *
33034 * @doc {heading: 'Operations', subheading: 'Basic math'}
33035 */
function square_(x) {
  var tensor = convertToTensor(x, 'x', 'square');
  // The kernel is addressed by its literal name and takes no attributes.
  var noAttrs = {};
  return ENGINE.runKernel('Square', { x: tensor }, noAttrs);
}
var square$2 = /* @__PURE__ */op({
  square_: square_
});
33046
33047 /**
33048 * @license
33049 * Copyright 2018 Google LLC. All Rights Reserved.
33050 * Licensed under the Apache License, Version 2.0 (the "License");
33051 * you may not use this file except in compliance with the License.
33052 * You may obtain a copy of the License at
33053 *
33054 * http://www.apache.org/licenses/LICENSE-2.0
33055 *
33056 * Unless required by applicable law or agreed to in writing, software
33057 * distributed under the License is distributed on an "AS IS" BASIS,
33058 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33059 * See the License for the specific language governing permissions and
33060 * limitations under the License.
33061 * =============================================================================
33062 */
33063 /**
33064 * Computes the sum of elements across dimensions of a `tf.Tensor`.
33065 *
33066 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
33067 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
33068 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
33069 * length 1. If axes has no entries, all dimensions are reduced, and a
33070 * `tf.Tensor` with a single element is returned.
33071 *
33072 * ```js
33073 * const x = tf.tensor1d([1, 2, 3]);
33074 *
33075 * x.sum().print(); // or tf.sum(x)
33076 * ```
33077 *
33078 * ```js
33079 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
33080 *
33081 * const axis = 1;
33082 * x.sum(axis).print(); // or tf.sum(x, axis)
33083 * ```
33084 *
33085 * @param x The input tensor to compute the sum over. If the dtype is `bool`
33086 * it will be converted to `int32` and the output dtype will be `int32`.
33087 * @param axis The dimension(s) to reduce. By default it reduces
33088 * all dimensions.
33089 * @param keepDims If true, retains reduced dimensions with size 1.
33090 *
33091 * @doc {heading: 'Operations', subheading: 'Reduction'}
33092 */
function sum_(x) {
  // Optional positional arguments: axis (default null = all axes) and
  // keepDims (default false).
  var reduceAxis = null;
  var keep = false;
  if (arguments.length > 1 && arguments[1] !== undefined) {
    reduceAxis = arguments[1];
  }
  if (arguments.length > 2 && arguments[2] !== undefined) {
    keep = arguments[2];
  }
  var tensor = convertToTensor(x, 'x', 'sum');
  // Boolean inputs are cast to int32 before the reduction.
  var summand = tensor.dtype === 'bool' ? cast$3(tensor, 'int32') : tensor;
  return ENGINE.runKernel(Sum, { x: summand }, { axis: reduceAxis, keepDims: keep });
}
var sum$3 = /* @__PURE__ */op({
  sum_: sum_
});
33112
33113 /**
33114 * @license
33115 * Copyright 2018 Google LLC. All Rights Reserved.
33116 * Licensed under the Apache License, Version 2.0 (the "License");
33117 * you may not use this file except in compliance with the License.
33118 * You may obtain a copy of the License at
33119 *
33120 * http://www.apache.org/licenses/LICENSE-2.0
33121 *
33122 * Unless required by applicable law or agreed to in writing, software
33123 * distributed under the License is distributed on an "AS IS" BASIS,
33124 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33125 * See the License for the specific language governing permissions and
33126 * limitations under the License.
33127 * =============================================================================
33128 */
33129 /**
33130 * Computes the norm of scalar, vectors, and matrices.
33131 * This function can compute several different vector norms (the 1-norm, the
33132 * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0)
33133 * and matrix norms (Frobenius, 1-norm, and inf-norm).
33134 *
33135 * ```js
33136 * const x = tf.tensor1d([1, 2, 3, 4]);
33137 *
33138 * x.norm().print(); // or tf.norm(x)
33139 * ```
33140 *
33141 * @param x The input array.
33142 * @param ord Optional. Order of the norm. Supported norm types are
33143 * following:
33144 *
33145 * | ord | norm for matrices | norm for vectors
33146 * |------------|---------------------------|---------------------
33147 * |'euclidean' |Frobenius norm |2-norm
33148 * |'fro' |Frobenius norm |
33149 * |Infinity |max(sum(abs(x), axis=1)) |max(abs(x))
33150 * |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x))
33151 * |1 |max(sum(abs(x), axis=0)) |sum(abs(x))
33152 * |2 | |sum(abs(x)^2)^(1/2)
33153 *
33154 * @param axis Optional. If axis is null (the default), the input is
33155 * considered a vector and a single vector norm is computed over the entire
33156 * set of values in the Tensor, i.e. norm(x, ord) is equivalent
33157 * to norm(x.reshape([-1]), ord). If axis is an integer, the input
33158 * is considered a batch of vectors, and axis determines the axis in x
33159 * over which to compute vector norms. If axis is a 2-tuple of integer it is
33160 * considered a batch of matrices and axis determines the axes in NDArray
33161 * over which to compute a matrix norm.
33162 * @param keepDims Optional. If true, the norm has the same dimensionality
33163 * as the input.
33164 *
33165 * @doc {heading: 'Operations', subheading: 'Matrices'}
33166 */
function norm_(x) {
  var ord = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 'euclidean';
  var axis = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
  var keepDims = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  x = convertToTensor(x, 'x', 'norm');
  var norm = normImpl(x, ord, axis);
  var keepDimsShape = norm.shape;
  if (keepDims) {
    // Re-insert size-1 dimensions at the reduced axes.
    var axes = parseAxisParam(axis, x.shape);
    keepDimsShape = expandShapeToKeepDim(norm.shape, axes);
  }
  return reshape$3(norm, keepDimsShape);
}
/**
 * Computes the norm of `x` without keepDims handling.
 *
 * @param x Input tensor (already converted).
 * @param p Order of the norm: 1, 2, Infinity, -Infinity, 'euclidean', 'fro',
 *     or (for vectors only) any other finite number > 0.
 * @param axis null (flatten and treat as a vector), a number or 1-element
 *     array (vector norms), or a 2-element array (matrix norms).
 * @throws Error for an unsupported `p` or `axis`.
 */
function normImpl(x, p) {
  var axis = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
  if (x.rank === 0) {
    return abs$2(x);
  }
  // consider vector when no axis is specified
  if (x.rank !== 1 && axis === null) {
    return normImpl(reshape$3(x, [-1]), p, axis);
  }
  // vector
  if (x.rank === 1 || typeof axis === 'number' || Array.isArray(axis) && axis.length === 1) {
    if (p === 1) {
      return sum$3(abs$2(x), axis);
    }
    if (p === Infinity) {
      return max$3(abs$2(x), axis);
    }
    if (p === -Infinity) {
      return min$3(abs$2(x), axis);
    }
    if (p === 'euclidean' || p === 2) {
      // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2
      return sqrt$2(sum$3(pow$3(abs$2(x), scalar(2, 'int32')), axis));
    }
    if (typeof p === 'number' && p > 0) {
      // General p-norm: norm(x, p) = sum(abs(xi) ^ p) ^ (1/p). NaN and
      // non-positive orders fail the guard and fall through to the error.
      return pow$3(sum$3(pow$3(abs$2(x), scalar(p)), axis), scalar(1 / p));
    }
    throw new Error("Error in norm: invalid ord value: ".concat(p));
  }
  // matrix (assumption axis[0] < axis[1])
  if (Array.isArray(axis) && axis.length === 2) {
    if (p === 1) {
      // Summing over axis[0] shifts axis[1] down by one.
      return max$3(sum$3(abs$2(x), axis[0]), axis[1] - 1);
    }
    if (p === Infinity) {
      return max$3(sum$3(abs$2(x), axis[1]), axis[0]);
    }
    if (p === -Infinity) {
      return min$3(sum$3(abs$2(x), axis[1]), axis[0]);
    }
    if (p === 'fro' || p === 'euclidean') {
      // norm(x) = sqrt(sum(pow(x, 2)))
      return sqrt$2(sum$3(square$2(x), axis));
    }
    throw new Error("Error in norm: invalid ord value: ".concat(p));
  }
  throw new Error("Error in norm: invalid axis: ".concat(axis));
}
var norm = /* @__PURE__ */op({
  norm_: norm_
});
33228
33229 /**
33230 * @license
33231 * Copyright 2022 Google LLC. All Rights Reserved.
33232 * Licensed under the Apache License, Version 2.0 (the "License");
33233 * you may not use this file except in compliance with the License.
33234 * You may obtain a copy of the License at
33235 *
33236 * http://www.apache.org/licenses/LICENSE-2.0
33237 *
33238 * Unless required by applicable law or agreed to in writing, software
33239 * distributed under the License is distributed on an "AS IS" BASIS,
33240 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33241 * See the License for the specific language governing permissions and
33242 * limitations under the License.
33243 * =============================================================================
33244 */
33245 /**
33246 * Computes the Euclidean norm of scalar, vectors, and matrices.
33247 *
33248 * ```js
33249 * const x = tf.tensor1d([1, 2, 3, 4]);
33250 *
33251 * x.euclideanNorm().print(); // or tf.euclideanNorm(x)
33252 * ```
33253 *
33254 * @param x The input array.
33255 * @param axis Optional. If axis is null (the default), the input is
33256 * considered a vector and a single vector norm is computed over the entire
33257 * set of values in the Tensor, i.e. euclideanNorm(x) is equivalent
33258 * to euclideanNorm(x.reshape([-1])). If axis is an integer, the input
33259 * is considered a batch of vectors, and axis determines the axis in x
33260 * over which to compute vector norms. If axis is a 2-tuple of integer it is
33261 * considered a batch of matrices and axis determines the axes in NDArray
33262 * over which to compute a matrix norm.
33263 * @param keepDims Optional. If true, the norm has the same dimensionality
33264 * as the input.
33265 *
33266 * @doc {heading: 'Operations', subheading: 'Matrices'}
33267 */
function euclideanNorm_(x) {
  // Thin alias for norm() with ord fixed to 'euclidean'.
  var axis = null;
  var keepDims = false;
  if (arguments.length > 1 && arguments[1] !== undefined) {
    axis = arguments[1];
  }
  if (arguments.length > 2 && arguments[2] !== undefined) {
    keepDims = arguments[2];
  }
  return norm(x, 'euclidean', axis, keepDims);
}
var euclideanNorm = /* @__PURE__ */op({
  euclideanNorm_: euclideanNorm_
});
33276
33277 /**
33278 * @license
33279 * Copyright 2018 Google LLC. All Rights Reserved.
33280 * Licensed under the Apache License, Version 2.0 (the "License");
33281 * you may not use this file except in compliance with the License.
33282 * You may obtain a copy of the License at
33283 *
33284 * http://www.apache.org/licenses/LICENSE-2.0
33285 *
33286 * Unless required by applicable law or agreed to in writing, software
33287 * distributed under the License is distributed on an "AS IS" BASIS,
33288 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33289 * See the License for the specific language governing permissions and
33290 * limitations under the License.
33291 * =============================================================================
33292 */
33293 /**
33294 * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
33295 *
33296 * ```js
33297 * const x = tf.tensor1d([1, 2, -3]);
33298 *
33299 * x.exp().print(); // or tf.exp(x)
33300 * ```
33301 * @param x The input tensor.
33302 *
33303 * @doc {heading: 'Operations', subheading: 'Basic math'}
33304 */
function exp_(x) {
  // Element-wise e^x, dispatched to the Exp kernel with no attributes.
  var inputTensor = convertToTensor(x, 'x', 'exp');
  return ENGINE.runKernel(Exp, { x: inputTensor });
}
var exp$2 = /* @__PURE__ */op({
  exp_: exp_
});
33315
33316 /**
33317 * @license
33318 * Copyright 2020 Google LLC. All Rights Reserved.
33319 * Licensed under the Apache License, Version 2.0 (the "License");
33320 * you may not use this file except in compliance with the License.
33321 * You may obtain a copy of the License at
33322 *
33323 * http://www.apache.org/licenses/LICENSE-2.0
33324 *
33325 * Unless required by applicable law or agreed to in writing, software
33326 * distributed under the License is distributed on an "AS IS" BASIS,
33327 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33328 * See the License for the specific language governing permissions and
33329 * limitations under the License.
33330 * =============================================================================
33331 */
33332 /**
33333 * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
33334 * into the tensor's shape.
33335 *
33336 * ```js
33337 * const x = tf.tensor1d([1, 2, 3, 4]);
33338 * const axis = 1;
33339 * x.expandDims(axis).print();
33340 * ```
33341 *
33342 * @param x The input tensor whose dimensions are to be expanded.
33343 * @param axis The dimension index at which to insert shape of `1`. Defaults
33344 * to 0 (the first dimension).
33345 *
33346 * @doc {heading: 'Tensors', subheading: 'Transformations'}
33347 */
function expandDims_(x) {
  var axis = 0;
  if (arguments.length > 1 && arguments[1] !== undefined) {
    axis = arguments[1];
  }
  var tensor = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric');
  // Only an upper bound on axis is checked here; the value is forwarded to
  // the kernel as the `dim` attribute.
  assert$1(axis <= tensor.rank, function () {
    return 'Axis must be <= rank of the tensor';
  });
  return ENGINE.runKernel(ExpandDims, { input: tensor }, { dim: axis });
}
var expandDims$3 = /* @__PURE__ */op({
  expandDims_: expandDims_
});
33365
33366 /**
33367 * @license
33368 * Copyright 2018 Google LLC. All Rights Reserved.
33369 * Licensed under the Apache License, Version 2.0 (the "License");
33370 * you may not use this file except in compliance with the License.
33371 * You may obtain a copy of the License at
33372 *
33373 * http://www.apache.org/licenses/LICENSE-2.0
33374 *
33375 * Unless required by applicable law or agreed to in writing, software
33376 * distributed under the License is distributed on an "AS IS" BASIS,
33377 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33378 * See the License for the specific language governing permissions and
33379 * limitations under the License.
33380 * =============================================================================
33381 */
33382 /**
33383 * Computes exponential of the input `tf.Tensor` minus one element-wise.
33384 * `e ^ x - 1`
33385 *
33386 * ```js
33387 * const x = tf.tensor1d([1, 2, -3]);
33388 *
33389 * x.expm1().print(); // or tf.expm1(x)
33390 * ```
33391 * @param x The input tensor.
33392 *
33393 * @doc {heading: 'Operations', subheading: 'Basic math'}
33394 */
function expm1_(x) {
  // Element-wise e^x - 1, dispatched to the Expm1 kernel.
  return ENGINE.runKernel(Expm1, {
    x: convertToTensor(x, 'x', 'expm1')
  });
}
var expm1$2 = /* @__PURE__ */op({
  expm1_: expm1_
});
33405
33406 /**
33407 * @license
33408 * Copyright 2020 Google LLC. All Rights Reserved.
33409 * Licensed under the Apache License, Version 2.0 (the "License");
33410 * you may not use this file except in compliance with the License.
33411 * You may obtain a copy of the License at
33412 *
33413 * http://www.apache.org/licenses/LICENSE-2.0
33414 *
33415 * Unless required by applicable law or agreed to in writing, software
33416 * distributed under the License is distributed on an "AS IS" BASIS,
33417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33418 * See the License for the specific language governing permissions and
33419 * limitations under the License.
33420 * =============================================================================
33421 */
33422 /**
33423 * Construct a tensor by repeating it the number of times given by reps.
33424 *
33425 * This operation creates a new tensor by replicating `input` `reps`
33426 * times. The output tensor's `i`th dimension has `input.shape[i] *
33427 * reps[i]` elements, and the values of `input` are replicated
33428 * `reps[i]` times along the `i`th dimension. For example, tiling
33429 * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.
33430 *
33431 * ```js
33432 * const a = tf.tensor1d([1, 2]);
33433 *
33434 * a.tile([2]).print(); // or tf.tile(a, [2])
33435 * ```
33436 *
33437 * ```js
33438 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
33439 *
33440 * a.tile([1, 2]).print(); // or tf.tile(a, [1,2])
33441 * ```
33442 * @param x The tensor to tile.
33443 * @param reps Determines the number of replications per dimension.
33444 *
33445 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
33446 */
function tile_(x, reps) {
  var $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric');
  // Every input dimension needs exactly one repetition count. The message
  // names this op ('tile'); it previously said 'transpose' by mistake.
  assert$1($x.rank === reps.length, function () {
    return "Error in tile: rank of input ".concat($x.rank, " ") + "must match length of reps ".concat(reps, ".");
  });
  var inputs = {
    x: $x
  };
  var attrs = {
    reps: reps
  };
  return ENGINE.runKernel(Tile, inputs, attrs);
}
var tile$3 = /* @__PURE__ */op({
  tile_: tile_
});
33463
33464 /**
33465 * @license
33466 * Copyright 2020 Google LLC. All Rights Reserved.
33467 * Licensed under the Apache License, Version 2.0 (the "License");
33468 * you may not use this file except in compliance with the License.
33469 * You may obtain a copy of the License at
33470 *
33471 * http://www.apache.org/licenses/LICENSE-2.0
33472 *
33473 * Unless required by applicable law or agreed to in writing, software
33474 * distributed under the License is distributed on an "AS IS" BASIS,
33475 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33476 * See the License for the specific language governing permissions and
33477 * limitations under the License.
33478 * =============================================================================
33479 */
33480 /**
33481 * Create an identity matrix.
33482 *
33483 * @param numRows Number of rows.
33484 * @param numColumns Number of columns. Defaults to `numRows`.
33485 * @param batchShape If provided, will add the batch shape to the beginning
33486 * of the shape of the returned `tf.Tensor` by repeating the identity
33487 * matrix.
33488 * @param dtype Data type.
33489 * @returns Identity matrix of the specified size and data type, possibly
33490 * with batch repetition if `batchShape` is specified.
33491 *
33492 * @doc {heading: 'Tensors', subheading: 'Creation'}
33493 */
function eye_(numRows, numColumns, batchShape) {
  var dtype = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'float32';
  if (numColumns == null) {
    numColumns = numRows;
  }
  var buff = buffer([numRows, numColumns], dtype);
  // Write 1 at (i, i) for each of the min(numRows, numColumns) diagonal slots.
  var n = numRows <= numColumns ? numRows : numColumns;
  for (var i = 0; i < n; ++i) {
    buff.set(1, i, i);
  }
  var out = reshape$3(buff.toTensor(), [numRows, numColumns]);
  // An absent or empty batchShape means a single, unbatched matrix.
  if (batchShape == null || batchShape.length === 0) {
    return out;
  }
  // Prepend one size-1 axis per batch dimension, then tile across the batch
  // dimensions; the trailing two (matrix) dimensions get a rep count of 1.
  if (batchShape.length === 1) {
    return tile$3(expandDims$3(out, 0), [batchShape[0], 1, 1]);
  } else if (batchShape.length === 2) {
    return tile$3(expandDims$3(expandDims$3(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]);
  } else if (batchShape.length === 3) {
    return tile$3(expandDims$3(expandDims$3(expandDims$3(out, 0), 0), 0), [batchShape[0], batchShape[1], batchShape[2], 1, 1]);
  }
  throw new Error("eye() currently supports only 1D, 2D and 3D batchShapes, " + "but received ".concat(batchShape.length, "D."));
}
var eye = /* @__PURE__ */op({
  eye_: eye_
});
33525
33526 /**
33527 * @license
33528 * Copyright 2018 Google LLC. All Rights Reserved.
33529 * Licensed under the Apache License, Version 2.0 (the "License");
33530 * you may not use this file except in compliance with the License.
33531 * You may obtain a copy of the License at
33532 *
33533 * http://www.apache.org/licenses/LICENSE-2.0
33534 *
33535 * Unless required by applicable law or agreed to in writing, software
33536 * distributed under the License is distributed on an "AS IS" BASIS,
33537 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33538 * See the License for the specific language governing permissions and
33539 * limitations under the License.
33540 * =============================================================================
33541 */
33542 /**
33543 * Computes floor of input `tf.Tensor` element-wise: `floor(x)`.
33544 *
33545 * ```js
33546 * const x = tf.tensor1d([.6, 1.1, -3.3]);
33547 *
33548 * x.floor().print(); // or tf.floor(x)
33549 * ```
33550 * @param x The input tensor.
33551 *
33552 * @doc {heading: 'Operations', subheading: 'Basic math'}
33553 */
function floor_(x) {
  // Convert with a 'float32' dtype hint, then dispatch to the Floor kernel.
  var operand = convertToTensor(x, 'x', 'floor', 'float32');
  return ENGINE.runKernel(Floor, { x: operand });
}
var floor$2 = /* @__PURE__ */op({
  floor_: floor_
});
33564
33565 /**
33566 * @license
33567 * Copyright 2018 Google LLC. All Rights Reserved.
33568 * Licensed under the Apache License, Version 2.0 (the "License");
33569 * you may not use this file except in compliance with the License.
33570 * You may obtain a copy of the License at
33571 *
33572 * http://www.apache.org/licenses/LICENSE-2.0
33573 *
33574 * Unless required by applicable law or agreed to in writing, software
33575 * distributed under the License is distributed on an "AS IS" BASIS,
33576 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33577 * See the License for the specific language governing permissions and
33578 * limitations under the License.
33579 * =============================================================================
33580 */
33581 /**
33582 * Gather slices from tensor `x`'s axis `axis` according to `indices`.
33583 *
33584 * ```js
33585 * const x = tf.tensor1d([1, 2, 3, 4]);
33586 * const indices = tf.tensor1d([1, 3, 3], 'int32');
33587 *
33588 * x.gather(indices).print();
33589 * ```
33590 *
33591 * ```js
33592 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
33593 * const indices = tf.tensor1d([1, 1, 0], 'int32');
33594 *
33595 * x.gather(indices).print();
33596 * ```
33597 * @param x The input tensor whose slices are to be gathered.
33598 * @param indices The indices of the values to extract.
33599 * @param axis The axis over which to select values. Defaults to 0.
33600 * @param batchDims Optional. The number of batch dimensions. It must be less
33601 * than or equal to rank(indices). Defaults to 0.
33602 * The output tensor will have shape of
33603 * `x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:]`
33604 *
33605 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
33606 */
function gather_(x, indices) {
  // Optional positional arguments: axis (default 0) and batchDims (default 0).
  var axis = 0;
  var batchDims = 0;
  if (arguments.length > 2 && arguments[2] !== undefined) {
    axis = arguments[2];
  }
  if (arguments.length > 3 && arguments[3] !== undefined) {
    batchDims = arguments[3];
  }
  var source = convertToTensor(x, 'x', 'gather');
  // Indices are converted with an 'int32' dtype hint.
  var idx = convertToTensor(indices, 'indices', 'gather', 'int32');
  return ENGINE.runKernel(GatherV2, {
    x: source,
    indices: idx
  }, {
    axis: axis,
    batchDims: batchDims
  });
}
var gather$1 = /* @__PURE__ */op({
  gather_: gather_
});
33625
33626 /**
33627 * Returns the truth value of (a > b) element-wise. Supports broadcasting.
33628 *
33629 * ```js
33630 * const a = tf.tensor1d([1, 2, 3]);
33631 * const b = tf.tensor1d([2, 2, 2]);
33632 *
33633 * a.greater(b).print();
33634 * ```
33635 *
33636 * @param a The first input tensor.
33637 * @param b The second input tensor. Must have the same dtype as `a`.
33638 *
33639 * @doc {heading: 'Operations', subheading: 'Logical'}
33640 */
function greater_(a, b) {
  var lhs = convertToTensor(a, 'a', 'greater', 'string_or_numeric');
  var rhs = convertToTensor(b, 'b', 'greater', 'string_or_numeric');
  // Align operand dtypes; makeTypesMatch returns the [a, b] pair.
  var matched = makeTypesMatch(lhs, rhs);
  lhs = matched[0];
  rhs = matched[1];
  // Called for validation only; the returned broadcast shape is discarded.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernel(Greater, { a: lhs, b: rhs });
}
var greater$3 = /* @__PURE__ */op({
  greater_: greater_
});
33658
33659 /**
33660 * Returns the truth value of (a >= b) element-wise. Supports broadcasting.
33661 *
33662 * ```js
33663 * const a = tf.tensor1d([1, 2, 3]);
33664 * const b = tf.tensor1d([2, 2, 2]);
33665 *
33666 * a.greaterEqual(b).print();
33667 * ```
33668 *
33669 * @param a The first input tensor.
33670 * @param b The second input tensor. Must have the same dtype as `a`.
33671 *
33672 * @doc {heading: 'Operations', subheading: 'Logical'}
33673 */
function greaterEqual_(a, b) {
  var lhs = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric');
  var rhs = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric');
  // Align operand dtypes; makeTypesMatch returns the [a, b] pair.
  var matched = makeTypesMatch(lhs, rhs);
  lhs = matched[0];
  rhs = matched[1];
  // Called for validation only; the returned broadcast shape is discarded.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernel(GreaterEqual, { a: lhs, b: rhs });
}
var greaterEqual$2 = /* @__PURE__ */op({
  greaterEqual_: greaterEqual_
});
33691
33692 /**
33693 * @license
33694 * Copyright 2020 Google LLC. All Rights Reserved.
33695 * Licensed under the Apache License, Version 2.0 (the "License");
33696 * you may not use this file except in compliance with the License.
33697 * You may obtain a copy of the License at
33698 *
33699 * http://www.apache.org/licenses/LICENSE-2.0
33700 *
33701 * Unless required by applicable law or agreed to in writing, software
33702 * distributed under the License is distributed on an "AS IS" BASIS,
33703 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33704 * See the License for the specific language governing permissions and
33705 * limitations under the License.
33706 * =============================================================================
33707 */
33708 /**
33709 * Returns the imaginary part of a complex (or real) tensor.
33710 *
33711 * Given a tensor input, this operation returns a tensor of type float that is
33712 * the imaginary part of each element in input considered as a complex number.
33713 * If input is real, a tensor of all zeros is returned.
33714 *
33715 * ```js
33716 * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
33717 * tf.imag(x).print();
33718 * ```
33719 *
33720 * @doc {heading: 'Tensors', subheading: 'Creation'}
33721 */
function imag_(input) {
  // The Imag kernel reads its operand under the key 'input' (not 'x').
  var source = convertToTensor(input, 'input', 'imag');
  return ENGINE.runKernel(Imag, { input: source });
}
var imag$2 = /* @__PURE__ */op({
  imag_: imag_
});
33732
33733 /**
33734 * @license
33735 * Copyright 2018 Google LLC. All Rights Reserved.
33736 * Licensed under the Apache License, Version 2.0 (the "License");
33737 * you may not use this file except in compliance with the License.
33738 * You may obtain a copy of the License at
33739 *
33740 * http://www.apache.org/licenses/LICENSE-2.0
33741 *
33742 * Unless required by applicable law or agreed to in writing, software
33743 * distributed under the License is distributed on an "AS IS" BASIS,
33744 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33745 * See the License for the specific language governing permissions and
33746 * limitations under the License.
33747 * =============================================================================
33748 */
33749 /**
33750 * Returns which elements of x are finite.
33751 *
33752 * ```js
33753 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
33754 *
 * x.isFinite().print(); // or tf.isFinite(x)
33756 * ```
33757 * @param x The input Tensor.
33758 *
33759 * @doc {heading: 'Operations', subheading: 'Basic math'}
33760 */
function isFinite_(x) {
  // Element-wise finiteness test, dispatched to the IsFinite kernel.
  return ENGINE.runKernel(IsFinite, {
    x: convertToTensor(x, 'x', 'isFinite')
  });
}
var isFinite$3 = /* @__PURE__ */op({
  isFinite_: isFinite_
});
33771
33772 /**
33773 * @license
33774 * Copyright 2018 Google LLC. All Rights Reserved.
33775 * Licensed under the Apache License, Version 2.0 (the "License");
33776 * you may not use this file except in compliance with the License.
33777 * You may obtain a copy of the License at
33778 *
33779 * http://www.apache.org/licenses/LICENSE-2.0
33780 *
33781 * Unless required by applicable law or agreed to in writing, software
33782 * distributed under the License is distributed on an "AS IS" BASIS,
33783 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33784 * See the License for the specific language governing permissions and
33785 * limitations under the License.
33786 * =============================================================================
33787 */
33788 /**
33789 * Returns which elements of x are Infinity or -Infinity.
33790 *
33791 * ```js
33792 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
33793 *
 * x.isInf().print(); // or tf.isInf(x)
33795 * ```
33796 * @param x The input Tensor.
33797 *
33798 * @doc {heading: 'Operations', subheading: 'Basic math'}
33799 */
function isInf_(x) {
  // Element-wise test for +/-Infinity, dispatched to the IsInf kernel.
  var operand = convertToTensor(x, 'x', 'isInf');
  return ENGINE.runKernel(IsInf, { x: operand });
}
var isInf$2 = /* @__PURE__ */op({
  isInf_: isInf_
});
33810
33811 /**
33812 * @license
33813 * Copyright 2018 Google LLC. All Rights Reserved.
33814 * Licensed under the Apache License, Version 2.0 (the "License");
33815 * you may not use this file except in compliance with the License.
33816 * You may obtain a copy of the License at
33817 *
33818 * http://www.apache.org/licenses/LICENSE-2.0
33819 *
33820 * Unless required by applicable law or agreed to in writing, software
33821 * distributed under the License is distributed on an "AS IS" BASIS,
33822 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33823 * See the License for the specific language governing permissions and
33824 * limitations under the License.
33825 * =============================================================================
33826 */
33827 /**
33828 * Returns which elements of x are NaN.
33829 *
33830 * ```js
33831 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
33832 *
33833 * x.isNaN().print(); // or tf.isNaN(x)
33834 * ```
33835 * @param x The input Tensor.
33836 *
33837 * @doc {heading: 'Operations', subheading: 'Basic math'}
33838 */
function isNaN_(x) {
  // Element-wise NaN test; note the kernel constant is spelled `IsNan`.
  var operand = convertToTensor(x, 'x', 'isNaN');
  return ENGINE.runKernel(IsNan, { x: operand });
}
var isNaN$3 = /* @__PURE__ */op({
  isNaN_: isNaN_
});
33849
33850 /**
33851 * @license
33852 * Copyright 2020 Google LLC. All Rights Reserved.
33853 * Licensed under the Apache License, Version 2.0 (the "License");
33854 * you may not use this file except in compliance with the License.
33855 * You may obtain a copy of the License at
33856 *
33857 * http://www.apache.org/licenses/LICENSE-2.0
33858 *
33859 * Unless required by applicable law or agreed to in writing, software
33860 * distributed under the License is distributed on an "AS IS" BASIS,
33861 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33862 * See the License for the specific language governing permissions and
33863 * limitations under the License.
33864 * =============================================================================
33865 */
/**
 * Computes leaky rectified linear element-wise.
 *
 * See
 * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
 * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 4]);
 *
 * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
 * ```
 * @param x The input tensor.
 * @param alpha The scaling factor for negative values, defaults to 0.2.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function leakyRelu_(x) {
  // `alpha` is read from the optional second argument; an omitted or
  // explicitly undefined value falls back to 0.2.
  var alpha = arguments[1] === undefined ? 0.2 : arguments[1];
  var attrs = {alpha: alpha};
  var inputs = {x: convertToTensor(x, 'x', 'leakyRelu')};
  return ENGINE.runKernel(LeakyRelu, inputs, attrs);
}
var leakyRelu$2 = /* @__PURE__ */op({
  leakyRelu_: leakyRelu_
});
33897
/**
 * Returns the truth value of (a < b) element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = tf.tensor1d([2, 2, 2]);
 *
 * a.less(b).print();
 * ```
 * @param a The first input tensor.
 * @param b The second input tensor. Must have the same dtype as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function less_(a, b) {
  // Convert both operands (a first, then b), then bring them to a common dtype.
  var matched = makeTypesMatch(
      convertToTensor(a, 'a', 'less', 'string_or_numeric'),
      convertToTensor(b, 'b', 'less', 'string_or_numeric'));
  var lhs = matched[0];
  var rhs = matched[1];
  // Called for its shape assertion; the returned broadcast shape is unused.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernel(Less, {a: lhs, b: rhs});
}
var less$3 = /* @__PURE__ */op({
  less_: less_
});
33929
/**
 * Returns the truth value of (a <= b) element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = tf.tensor1d([2, 2, 2]);
 *
 * a.lessEqual(b).print();
 * ```
 *
 * @param a The first input tensor.
 * @param b The second input tensor. Must have the same dtype as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function lessEqual_(a, b) {
  // Convert both operands (a first, then b), then bring them to a common dtype.
  var matched = makeTypesMatch(
      convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric'),
      convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric'));
  var lhs = matched[0];
  var rhs = matched[1];
  // Called for its shape assertion; the returned broadcast shape is unused.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernel(LessEqual, {a: lhs, b: rhs});
}
var lessEqual$2 = /* @__PURE__ */op({
  lessEqual_: lessEqual_
});
33962
33963 /**
33964 * @license
33965 * Copyright 2018 Google LLC. All Rights Reserved.
33966 * Licensed under the Apache License, Version 2.0 (the "License");
33967 * you may not use this file except in compliance with the License.
33968 * You may obtain a copy of the License at
33969 *
33970 * http://www.apache.org/licenses/LICENSE-2.0
33971 *
33972 * Unless required by applicable law or agreed to in writing, software
33973 * distributed under the License is distributed on an "AS IS" BASIS,
33974 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33975 * See the License for the specific language governing permissions and
33976 * limitations under the License.
33977 * =============================================================================
33978 */
/**
 * Return an evenly spaced sequence of numbers over the given interval.
 *
 * ```js
 * tf.linspace(0, 9, 10).print();
 * ```
 * @param start The start value of the sequence.
 * @param stop The end value of the sequence.
 * @param num The number of values to generate. Throws if `num <= 0`.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function linspace(start, stop, num) {
  // Only non-positive counts are rejected here; everything else is handed to
  // the LinSpace kernel, which takes no tensor inputs, only attributes.
  if (num <= 0) {
    throw new Error('The number of values should be positive.');
  }
  return ENGINE.runKernel(LinSpace, {}, {start: start, stop: stop, num: num});
}
34002
34003 /**
34004 * @license
34005 * Copyright 2020 Google LLC. All Rights Reserved.
34006 * Licensed under the Apache License, Version 2.0 (the "License");
34007 * you may not use this file except in compliance with the License.
34008 * You may obtain a copy of the License at
34009 *
34010 * http://www.apache.org/licenses/LICENSE-2.0
34011 *
34012 * Unless required by applicable law or agreed to in writing, software
34013 * distributed under the License is distributed on an "AS IS" BASIS,
34014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34015 * See the License for the specific language governing permissions and
34016 * limitations under the License.
34017 * =============================================================================
34018 */
/**
 * Normalizes the activation of a local neighborhood across or within
 * channels.
 *
 * A rank-3 input is reshaped to rank 4 with a leading dimension of 1 before
 * the LRN kernel runs, and that leading dimension is removed from the result.
 *
 * @param x The input tensor, rank 3 or 4. The 4-D input tensor is treated as
 *     a 3-D array of 1D vectors (along the last dimension), and each vector is
 *     normalized independently.
 * @param depthRadius The number of adjacent channels in the 1D normalization
 *     window. Must be an integer. Defaults to 5.
 * @param bias A constant bias term for the basis. Defaults to 1.
 * @param alpha A scale factor, usually positive. Defaults to 1.
 * @param beta An exponent. Defaults to 0.5.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function localResponseNormalization_(x) {
  // Optional positional arguments; undefined (or omitted) uses the default.
  var depthRadius = arguments[1] === undefined ? 5 : arguments[1];
  var bias = arguments[2] === undefined ? 1 : arguments[2];
  var alpha = arguments[3] === undefined ? 1 : arguments[3];
  var beta = arguments[4] === undefined ? 0.5 : arguments[4];
  var $x = convertToTensor(x, 'x', 'localResponseNormalization');
  assert$1($x.rank === 4 || $x.rank === 3, function () {
    return "Error in localResponseNormalization: x must be rank 3 or 4 but got\n rank ".concat($x.rank, ".");
  });
  assert$1(isInt(depthRadius), function () {
    return "Error in localResponseNormalization: depthRadius must be an " + "integer but got depthRadius ".concat(depthRadius, ".");
  });
  // The kernel is always fed a 4-D tensor.
  var addedBatchDim = $x.rank === 3;
  var x4D = addedBatchDim ? reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;
  var res = ENGINE.runKernel(LRN, {x: x4D}, {
    depthRadius: depthRadius,
    bias: bias,
    alpha: alpha,
    beta: beta
  });
  if (!addedBatchDim) {
    return res;
  }
  // Undo the reshape so the caller gets back a rank-3 result.
  return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
var localResponseNormalization = /* @__PURE__ */op({
  localResponseNormalization_: localResponseNormalization_
});
34072
34073 /**
34074 * @license
34075 * Copyright 2018 Google LLC. All Rights Reserved.
34076 * Licensed under the Apache License, Version 2.0 (the "License");
34077 * you may not use this file except in compliance with the License.
34078 * You may obtain a copy of the License at
34079 *
34080 * http://www.apache.org/licenses/LICENSE-2.0
34081 *
34082 * Unless required by applicable law or agreed to in writing, software
34083 * distributed under the License is distributed on an "AS IS" BASIS,
34084 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34085 * See the License for the specific language governing permissions and
34086 * limitations under the License.
34087 * =============================================================================
34088 */
/**
 * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, Math.E]);
 *
 * x.log().print(); // or tf.log(x)
 * ```
 * @param x The input tensor. Converted with a 'float32' dtype hint.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function log_(x) {
  var converted = convertToTensor(x, 'x', 'log', 'float32');
  return ENGINE.runKernel(Log, {x: converted});
}
var log$2 = /* @__PURE__ */op({
  log_: log_
});
34111
34112 /**
34113 * @license
34114 * Copyright 2018 Google LLC. All Rights Reserved.
34115 * Licensed under the Apache License, Version 2.0 (the "License");
34116 * you may not use this file except in compliance with the License.
34117 * You may obtain a copy of the License at
34118 *
34119 * http://www.apache.org/licenses/LICENSE-2.0
34120 *
34121 * Unless required by applicable law or agreed to in writing, software
34122 * distributed under the License is distributed on an "AS IS" BASIS,
34123 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34124 * See the License for the specific language governing permissions and
34125 * limitations under the License.
34126 * =============================================================================
34127 */
/**
 * Computes natural logarithm of the input `tf.Tensor` plus one
 * element-wise: `ln(1 + x)`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, Math.E - 1]);
 *
 * x.log1p().print(); // or tf.log1p(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function log1p_(x) {
  var converted = convertToTensor(x, 'x', 'log1p');
  return ENGINE.runKernel(Log1p, {x: converted});
}
var log1p$2 = /* @__PURE__ */op({
  log1p_: log1p_
});
34151
/**
 * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
 * gradient of `f(x)` with respect to `x`.
 *
 * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
 * `x` is computed instead. `f(x)` must take a single tensor `x` and return a
 * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
 *
 * ```js
 * // f(x) = x ^ 2
 * const f = x => x.square();
 * // f'(x) = 2x
 * const g = tf.grad(f);
 *
 * const x = tf.tensor1d([2, 3]);
 * g(x).print();
 * ```
 *
 * ```js
 * // f(x) = x ^ 3
 * const f = x => x.pow(tf.scalar(3, 'int32'));
 * // f'(x) = 3x ^ 2
 * const g = tf.grad(f);
 * // f''(x) = 6x
 * const gg = tf.grad(g);
 *
 * const x = tf.tensor1d([2, 3]);
 * gg(x).print();
 * ```
 *
 * @param f The function f(x), to compute gradient for.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function grad(f) {
  assert$1(isFunction(f), function () {
    return 'The f passed in grad(f) must be a function';
  });
  return function (x, dy) {
    // x may be a string or numeric tensor, hence the 'string_or_numeric' hint.
    var $x = convertToTensor(x, 'x', 'tf.grad', 'string_or_numeric');
    var $dy = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grad');
    return ENGINE.tidy(function () {
      var result = ENGINE.gradients(function () {
        return f($x);
      }, [$x], $dy);
      if ($dy != null) {
        assertShapesMatch(result.value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' + 'returned by f(x)');
      }
      checkGrads(result.grads);
      // Single input, so the gradient list has exactly one entry of interest.
      return result.grads[0];
    });
  };
}
/**
 * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
 * which gives an array of gradients of `f()` with respect to each input
 * [`x1`,`x2`,...].
 *
 * If `dy` is passed when calling `g()`, the gradient of
 * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
 * The provided `f` must take one or more tensors and return a single tensor
 * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
 *
 * ```js
 * // f(a, b) = a * b
 * const f = (a, b) => a.mul(b);
 * // df / da = b, df / db = a
 * const g = tf.grads(f);
 *
 * const a = tf.tensor1d([2, 3]);
 * const b = tf.tensor1d([-2, -3]);
 * const [da, db] = g([a, b]);
 * console.log('da');
 * da.print();
 * console.log('db');
 * db.print();
 * ```
 *
 * @param f The function `f(x1, x2,...)` to compute gradients for.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function grads(f) {
  assert$1(isFunction(f), function () {
    return 'The f passed in grads(f) must be a function';
  });
  return function (args, dy) {
    assert$1(Array.isArray(args), function () {
      return 'The args passed in grads(f)(args) must be an array ' + 'of `Tensor`s or `TensorLike`s';
    });
    // Inputs may be string or numeric tensors, hence 'string_or_numeric'.
    var $args = convertToTensorArray(args, 'args', 'tf.grads', 'string_or_numeric');
    var $dy = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grads');
    return ENGINE.tidy(function () {
      var result = ENGINE.gradients(function () {
        // Spread the converted inputs as positional arguments to f.
        return f.apply(undefined, $args);
      }, $args, $dy);
      if ($dy != null) {
        assertShapesMatch(result.value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' + 'match the shape returned by f([x1,...])');
      }
      checkGrads(result.grads);
      return result.grads;
    });
  };
}
/**
 * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
 * returns a metric you want to show.
 *
 * The result is a rich object with the following properties:
 * - grad: The gradient of `f(x)` w.r.t. `x` (result of `tf.grad`).
 * - value: The value returned by `f(x)`.
 *
 * ```js
 * // f(x) = x ^ 2
 * const f = x => x.square();
 * // f'(x) = 2x
 * const g = tf.valueAndGrad(f);
 *
 * const x = tf.tensor1d([2, 3]);
 * const {value, grad} = g(x);
 *
 * console.log('value');
 * value.print();
 * console.log('grad');
 * grad.print();
 * ```
 *
 * @param f The function f(x) to evaluate and differentiate. Must be a function.
 * @returns A function `g(x, dy?)`. `x` must be a `Tensor`; `dy`, if given,
 *     must be a `Tensor` whose shape matches the shape of `f(x)`.
 *     `g` returns `{grad, value}`.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function valueAndGrad(f) {
  assert$1(isFunction(f), function () {
    return 'The f passed in valueAndGrad(f) must be a function';
  });
  return function (x, dy) {
    assert$1(x instanceof Tensor, function () {
      return 'The x passed in valueAndGrad(f)(x) must be a tensor';
    });
    assert$1(dy == null || dy instanceof Tensor, function () {
      return 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor';
    });
    var _ENGINE$gradients3 = ENGINE.gradients(function () {
        return f(x);
      }, [x], dy),
      grads = _ENGINE$gradients3.grads,
      value = _ENGINE$gradients3.value;
    // Same shape check as grad, grads and valueAndGrads: without it a
    // mismatched dy is accepted here with no explicit diagnostic.
    if (dy != null) {
      assertShapesMatch(value.shape, dy.shape, 'The shape of dy passed in valueAndGrad(f)(x, dy) must match the shape ' + 'returned by f(x)');
    }
    checkGrads(grads);
    return {
      grad: grads[0],
      value: value
    };
  };
}
/**
 * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
 * returns a metric you want to show.
 *
 * The result is a rich object with the following properties:
 * - grads: The gradients of `f()` w.r.t. each input (result of `tf.grads`).
 * - value: The value returned by `f(x)`.
 *
 * ```js
 * // f(a, b) = a * b
 * const f = (a, b) => a.mul(b);
 * // df/da = b, df/db = a
 * const g = tf.valueAndGrads(f);
 *
 * const a = tf.tensor1d([2, 3]);
 * const b = tf.tensor1d([-2, -3]);
 * const {value, grads} = g([a, b]);
 *
 * const [da, db] = grads;
 *
 * console.log('value');
 * value.print();
 *
 * console.log('da');
 * da.print();
 * console.log('db');
 * db.print();
 * ```
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function valueAndGrads(f) {
  assert$1(isFunction(f), function () {
    return 'The f passed in valueAndGrads(f) must be a function';
  });
  return function (args, dy) {
    var argsAreTensors = Array.isArray(args) && args.every(function (candidate) {
      return candidate instanceof Tensor;
    });
    assert$1(argsAreTensors, function () {
      return 'The args passed in valueAndGrads(f)(args) must be array of ' + 'tensors';
    });
    assert$1(dy == null || dy instanceof Tensor, function () {
      return 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor';
    });
    // The engine's result object ({value, grads}) is returned as-is.
    var outcome = ENGINE.gradients(function () {
      return f.apply(undefined, args);
    }, args, dy);
    if (dy != null) {
      assertShapesMatch(outcome.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' + 'match the shape returned by f([x1,...])');
    }
    checkGrads(outcome.grads);
    return outcome;
  };
}
/**
 * Computes and returns the gradient of f(x) with respect to the list of
 * trainable variables provided by `varList`. If no list is provided, it
 * defaults to all trainable variables.
 *
 * ```js
 * const a = tf.variable(tf.tensor1d([3, 4]));
 * const b = tf.variable(tf.tensor1d([5, 6]));
 * const x = tf.tensor1d([1, 2]);
 *
 * // f(a, b) = a * x ^ 2 + b * x
 * const f = () => a.mul(x.square()).add(b.mul(x)).sum();
 * // df/da = x ^ 2, df/db = x
 * const {value, grads} = tf.variableGrads(f);
 *
 * Object.keys(grads).forEach(varName => grads[varName].print());
 * ```
 *
 * @param f The function to execute. f() should return a scalar.
 * @param varList The list of variables to compute the gradients with respect
 *     to. Defaults to all variables registered with the engine; only the
 *     trainable ones are differentiated.
 * @returns An object with the following keys and values:
 *   - `value`: The value of the function `f`.
 *   - `grads`: A map from the names of the variables to the gradients.
 *     If the `varList` argument is provided explicitly and contains a subset of
 *     non-trainable variables, this map in the return value will contain keys
 *     that map the names of the non-trainable variables to `null`.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function variableGrads(f, varList) {
  assert$1(isFunction(f), function () {
    return 'The f passed in variableGrads(f) must be a function';
  });
  assert$1(varList == null || Array.isArray(varList) && varList.every(function (v) {
    return v instanceof Variable;
  }), function () {
    return 'The varList passed in variableGrads(f, varList) must be an array ' + 'of variables';
  });
  var specifiedVarList = varList != null;
  var candidates;
  if (specifiedVarList) {
    candidates = varList;
  } else {
    // No explicit list: take every variable registered with the engine.
    var registry = ENGINE.registeredVariables;
    candidates = Object.keys(registry).map(function (name) {
      return registry[name];
    });
  }
  // Non-trainable entries are reported back (as null) only when the caller
  // named them explicitly.
  var specifiedNonTrainable = specifiedVarList ? candidates.filter(function (v) {
    return !v.trainable;
  }) : null;
  var originalVarCount = candidates.length;
  var trainableVars = candidates.filter(function (v) {
    return v.trainable;
  });
  assert$1(trainableVars.length > 0, function () {
    return "variableGrads() expects at least one of the input variables to " + "be trainable, but none of the ".concat(originalVarCount, " variables is ") + "trainable.";
  });
  // Individual variables may legitimately have no gradient; the check below
  // only requires that at least one of them does.
  var allowNoGradients = true;
  var result = ENGINE.gradients(f, trainableVars, null, allowNoGradients);
  var value = result.value;
  var grads = result.grads;
  assert$1(grads.some(function (g) {
    return g != null;
  }), function () {
    return 'Cannot find a connection between any variable and the result of ' + 'the loss function y=f(x). Please make sure the operations that ' + 'use variables are inside the function f passed to minimize().';
  });
  assert$1(value.rank === 0, function () {
    return "The f passed in variableGrads(f) must return a scalar, but it " + "returned a rank-".concat(value.rank, " tensor");
  });
  var namedGrads = {};
  for (var i = 0; i < trainableVars.length; i++) {
    if (grads[i] != null) {
      namedGrads[trainableVars[i].name] = grads[i];
    }
  }
  if (specifiedNonTrainable != null) {
    for (var j = 0; j < specifiedNonTrainable.length; j++) {
      namedGrads[specifiedNonTrainable[j].name] = null;
    }
  }
  return {
    value: value,
    grads: namedGrads
  };
}
/**
 * Overrides the gradient computation of a function `f`.
 *
 * Takes a function
 * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
 * and returns another function `g(...inputs)` which takes the same inputs as
 * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
 * with respect to each input of `f` are computed using `f().gradFunc`.
 *
 * The `save` function passed to `f` should be used for saving tensors needed
 * in the gradient. And the `saved` passed to the `gradFunc` is a
 * `NamedTensorMap`, which contains those saved tensors.
 *
 * ```js
 * const customOp = tf.customGrad((x, save) => {
 *   // Save x to make sure it's available later for the gradient.
 *   save([x]);
 *   // Override gradient of our custom x ^ 2 op to be dy * abs(x);
 *   return {
 *     value: x.square(),
 *     // Note `saved[0]`, which points to the `x` we saved earlier.
 *     gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
 *   };
 * });
 *
 * const x = tf.tensor1d([-1, -2, 3]);
 * const dx = tf.grad(x => customOp(x));
 *
 * console.log(`f(x):`);
 * customOp(x).print();
 * console.log(`f'(x):`);
 * dx(x).print();
 * ```
 *
 * @param f The function to evaluate in forward mode, which should return
 *     `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
 *     returns the custom gradients of `f` with respect to its inputs.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function customGrad(f) {
  // Thin public wrapper: the engine builds the wrapped function.
  var wrapped = ENGINE.customGrad(f);
  return wrapped;
}
/**
 * Throws if any entry of `grads` is null or undefined, which means part of
 * the input had no path to the output of f.
 *
 * @param grads Array of gradients, one per differentiated input.
 */
function checkGrads(grads) {
  var hasMissingGradient = grads.some(function (g) {
    return g == null;
  });
  if (hasMissingGradient) {
    throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y.");
  }
}
34503
34504 /**
34505 * @license
34506 * Copyright 2018 Google LLC. All Rights Reserved.
34507 * Licensed under the Apache License, Version 2.0 (the "License");
34508 * you may not use this file except in compliance with the License.
34509 * You may obtain a copy of the License at
34510 *
34511 * http://www.apache.org/licenses/LICENSE-2.0
34512 *
34513 * Unless required by applicable law or agreed to in writing, software
34514 * distributed under the License is distributed on an "AS IS" BASIS,
34515 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34516 * See the License for the specific language governing permissions and
34517 * limitations under the License.
34518 * =============================================================================
34519 */
/**
 * Computes `-1 * x` element-wise.
 *
 * ```js
 * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);
 *
 * x.neg().print(); // or tf.neg(x)
 * ```
 *
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function neg_(x) {
  var converted = convertToTensor(x, 'x', 'neg');
  return ENGINE.runKernel(Neg, {x: converted});
}
var neg$2 = /* @__PURE__ */op({
  neg_: neg_
});
34543
34544 /**
34545 * @license
34546 * Copyright 2018 Google LLC. All Rights Reserved.
34547 * Licensed under the Apache License, Version 2.0 (the "License");
34548 * you may not use this file except in compliance with the License.
34549 * You may obtain a copy of the License at
34550 *
34551 * http://www.apache.org/licenses/LICENSE-2.0
34552 *
34553 * Unless required by applicable law or agreed to in writing, software
34554 * distributed under the License is distributed on an "AS IS" BASIS,
34555 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34556 * See the License for the specific language governing permissions and
34557 * limitations under the License.
34558 * =============================================================================
34559 */
/**
 * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.softplus().print(); // or tf.softplus(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function softplus_(x) {
  var converted = convertToTensor(x, 'x', 'softplus');
  return ENGINE.runKernel(Softplus$1, {x: converted});
}
var softplus$2 = /* @__PURE__ */op({
  softplus_: softplus_
});
34582
34583 /**
34584 * @license
34585 * Copyright 2018 Google LLC. All Rights Reserved.
34586 * Licensed under the Apache License, Version 2.0 (the "License");
34587 * you may not use this file except in compliance with the License.
34588 * You may obtain a copy of the License at
34589 *
34590 * http://www.apache.org/licenses/LICENSE-2.0
34591 *
34592 * Unless required by applicable law or agreed to in writing, software
34593 * distributed under the License is distributed on an "AS IS" BASIS,
34594 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34595 * See the License for the specific language governing permissions and
34596 * limitations under the License.
34597 * =============================================================================
34598 */
/**
 * Computes log sigmoid of the input `tf.Tensor` element-wise:
 * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.logSigmoid().print(); // or tf.logSigmoid(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function logSigmoid_(x) {
  var $x = convertToTensor(x, 'x', 'logSigmoid');
  // There is no LogSigmoid kernel in TF, so ENGINE.runKernel cannot be used
  // directly; a custom gradient keeps the previous implementation.
  var customOp = customGrad(function (input) {
    // TODO(yassogba) the chained softplus call can be removed only after
    // backends have modularized softplus, at which point we can call
    // engine runKernel(..., Softplus, ...) directly.
    var value = neg$2(softplus$2(neg$2(input)));
    return {
      value: value,
      // d/dx logSigmoid(x) = sigmoid(-x)
      gradFunc: function gradFunc(dy) {
        return mul(dy, sigmoid$2(neg$2(input)));
      }
    };
  });
  return customOp($x);
}
var logSigmoid = /* @__PURE__ */op({
  logSigmoid_: logSigmoid_
});
34636
/**
 * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([10, 20, 30, 40]);
 * const b = tf.tensor1d([1, 2, 3, 4]);
 *
 * a.sub(b).print();  // or tf.sub(a, b)
 * ```
 *
 * ```js
 * // Broadcast subtract a with b.
 * const a = tf.tensor1d([10, 20, 30, 40]);
 * const b = tf.scalar(5);
 *
 * a.sub(b).print();  // or tf.sub(a, b)
 * ```
 * @param a The first `tf.Tensor` to subtract from.
 * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
 *     `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
 */
function sub_(a, b) {
  // Convert a then b, and bring them to a common dtype before dispatch.
  var matched = makeTypesMatch(convertToTensor(a, 'a', 'sub'), convertToTensor(b, 'b', 'sub'));
  return ENGINE.runKernel(Sub, {a: matched[0], b: matched[1]});
}
var sub$2 = /* @__PURE__ */op({
  sub_: sub_
});
34676
/**
 * Computes the log softmax.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 *
 * a.logSoftmax().print();  // or tf.logSoftmax(a)
 * ```
 *
 * ```js
 * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
 *
 * a.logSoftmax().print();  // or tf.logSoftmax(a)
 * ```
 *
 * @param logits The logits array.
 * @param axis The dimension softmax would be performed on. Defaults to `-1`
 *     which indicates the last dimension. Only the last dimension is
 *     currently supported; any other axis throws.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function logSoftmax_(logits) {
  var axis = arguments[1] === undefined ? -1 : arguments[1];
  var $logits = convertToTensor(logits, 'logits', 'logSoftmax');
  var lastAxis = $logits.rank - 1;
  if (axis === -1) {
    axis = lastAxis;
  }
  if (axis !== lastAxis) {
    throw new Error('Log Softmax along a non-last dimension is not yet supported. ' + "Logits was rank ".concat($logits.rank, " and axis was ").concat(axis));
  }
  // A custom gradient is used for numerical stability: the forward pass
  // subtracts the max before exponentiating.
  var customOp = customGrad(function (input, save) {
    var xMax = max$3(input, axis, true);
    var shifted = sub$2(input, xMax);
    var shiftedFloat = cast$3(shifted, 'float32');
    var logSumExpShifted = log$2(sum$3(exp$2(shifted), axis, true));
    var value = sub$2(shiftedFloat, logSumExpShifted);
    save([value]);
    return {
      value: value,
      gradFunc: function gradFunc(dy, saved) {
        // exp(logSoftmax) recovers the softmax probabilities.
        var softmax = exp$2(saved[0]);
        return sub$2(dy, mul(sum$3(dy, axis, true), softmax));
      }
    };
  });
  return customOp($logits);
  // TODO Use Engine.runKernel(LogSoftmax, {logits}, {axis}) when
  // CPU/WebGL/WASM backends implement this.
}

var logSoftmax = /* @__PURE__ */op({
  logSoftmax_: logSoftmax_
});
34748
34749 /**
34750 * @license
34751 * Copyright 2020 Google LLC. All Rights Reserved.
34752 * Licensed under the Apache License, Version 2.0 (the "License");
34753 * you may not use this file except in compliance with the License.
34754 * You may obtain a copy of the License at
34755 *
34756 * http://www.apache.org/licenses/LICENSE-2.0
34757 *
34758 * Unless required by applicable law or agreed to in writing, software
34759 * distributed under the License is distributed on an "AS IS" BASIS,
34760 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34761 * See the License for the specific language governing permissions and
34762 * limitations under the License.
34763 * =============================================================================
34764 */
34765 /**
34766 * Computes the log(sum(exp(elements across the reduction dimensions))).
34767 *
34768 * Reduces the input along the dimensions given in `axis`. Unless `keepDims`
34769 * is true, the rank of the array is reduced by 1 for each entry in `axis`.
34770 * If `keepDims` is true, the reduced dimensions are retained with length 1.
34771 * If `axis` has no entries, all dimensions are reduced, and an array with a
34772 * single element is returned.
34773 *
34774 * ```js
34775 * const x = tf.tensor1d([1, 2, 3]);
34776 *
34777 * x.logSumExp().print(); // or tf.logSumExp(x)
34778 * ```
34779 *
34780 * ```js
34781 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
34782 *
34783 * const axis = 1;
 * x.logSumExp(axis).print();  // or tf.logSumExp(x, axis)
34785 * ```
34786 * @param x The input tensor.
34787 * @param axis The dimension(s) to reduce. If null (the default),
34788 * reduces all dimensions.
34789 * @param keepDims If true, retains reduced dimensions with length
34790 * of 1. Defaults to false.
34791 *
34792 * @doc {heading: 'Operations', subheading: 'Reduction'}
34793 */
function logSumExp_(x) {
  // Computes log(sum(exp(x))) over `axis` (null = every axis) using the
  // identity log(sum(exp(x))) = m + log(sum(exp(x - m))), where m is the max
  // over the reduced axes, so the exponentials stay bounded.
  // NOTE(review): if m is +/-Infinity for a slice, x - m evaluates Inf - Inf
  // (NaN) and that slice's result is NaN rather than +/-Infinity.
  var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : null;
  var keepDims = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
  var $x = convertToTensor(x, 'x', 'logSumExp');
  var axes = parseAxisParam(axis, $x.shape);
  var maxKept = max$3($x, axes, true /* keepDims */);
  var logOfShiftedSum = log$2(sum$3(exp$2(sub$2($x, maxKept)), axes));
  // maxKept still has size-1 reduced dims; align it with the reduced result.
  var reduced = add$3(reshape$3(maxKept, logOfShiftedSum.shape), logOfShiftedSum);
  if (!keepDims) {
    return reduced;
  }
  return reshape$3(reduced, expandShapeToKeepDim(reduced.shape, axes));
}
var logSumExp = /* @__PURE__ */op({
  logSumExp_: logSumExp_
});
34814
34815 /**
34816 * @license
34817 * Copyright 2020 Google LLC. All Rights Reserved.
34818 * Licensed under the Apache License, Version 2.0 (the "License");
34819 * you may not use this file except in compliance with the License.
34820 * You may obtain a copy of the License at
34821 *
34822 * http://www.apache.org/licenses/LICENSE-2.0
34823 *
34824 * Unless required by applicable law or agreed to in writing, software
34825 * distributed under the License is distributed on an "AS IS" BASIS,
34826 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34827 * See the License for the specific language governing permissions and
34828 * limitations under the License.
34829 * =============================================================================
34830 */
34831 /**
34832 * Returns the truth value of `a AND b` element-wise. Supports broadcasting.
34833 *
34834 * ```js
34835 * const a = tf.tensor1d([false, false, true, true], 'bool');
34836 * const b = tf.tensor1d([false, true, false, true], 'bool');
34837 *
34838 * a.logicalAnd(b).print();
34839 * ```
34840 *
34841 * @param a The first input tensor. Must be of dtype bool.
34842 * @param b The second input tensor. Must be of dtype bool.
34843 *
34844 * @doc {heading: 'Operations', subheading: 'Logical'}
34845 */
function logicalAnd_(a, b) {
  // Element-wise `a AND b`; both inputs are converted with dtype 'bool'.
  var inputs = {
    a: convertToTensor(a, 'a', 'logicalAnd', 'bool'),
    b: convertToTensor(b, 'b', 'logicalAnd', 'bool')
  };
  // Validates broadcast compatibility; the computed shape is not needed here.
  assertAndGetBroadcastShape(inputs.a.shape, inputs.b.shape);
  return ENGINE.runKernel(LogicalAnd, inputs);
}
var logicalAnd$2 = /* @__PURE__ */op({
  logicalAnd_: logicalAnd_
});
34859
34860 /**
34861 * @license
34862 * Copyright 2020 Google LLC. All Rights Reserved.
34863 * Licensed under the Apache License, Version 2.0 (the "License");
34864 * you may not use this file except in compliance with the License.
34865 * You may obtain a copy of the License at
34866 *
34867 * http://www.apache.org/licenses/LICENSE-2.0
34868 *
34869 * Unless required by applicable law or agreed to in writing, software
34870 * distributed under the License is distributed on an "AS IS" BASIS,
34871 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34872 * See the License for the specific language governing permissions and
34873 * limitations under the License.
34874 * =============================================================================
34875 */
34876 /**
34877 * Returns the truth value of `NOT x` element-wise.
34878 *
34879 * ```js
34880 * const a = tf.tensor1d([false, true], 'bool');
34881 *
34882 * a.logicalNot().print();
34883 * ```
34884 *
34885 * @param x The input tensor. Must be of dtype 'bool'.
34886 *
34887 * @doc {heading: 'Operations', subheading: 'Logical'}
34888 */
function logicalNot_(x) {
  // Element-wise `NOT x`; the input is converted with dtype 'bool'.
  var boolInput = convertToTensor(x, 'x', 'logicalNot', 'bool');
  return ENGINE.runKernel(LogicalNot, {
    x: boolInput
  });
}
var logicalNot$2 = /* @__PURE__ */op({
  logicalNot_: logicalNot_
});
34899
34900 /**
34901 * @license
34902 * Copyright 2020 Google LLC. All Rights Reserved.
34903 * Licensed under the Apache License, Version 2.0 (the "License");
34904 * you may not use this file except in compliance with the License.
34905 * You may obtain a copy of the License at
34906 *
34907 * http://www.apache.org/licenses/LICENSE-2.0
34908 *
34909 * Unless required by applicable law or agreed to in writing, software
34910 * distributed under the License is distributed on an "AS IS" BASIS,
34911 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34912 * See the License for the specific language governing permissions and
34913 * limitations under the License.
34914 * =============================================================================
34915 */
34916 /**
34917 * Returns the truth value of `a OR b` element-wise. Supports broadcasting.
34918 *
34919 * ```js
34920 * const a = tf.tensor1d([false, false, true, true], 'bool');
34921 * const b = tf.tensor1d([false, true, false, true], 'bool');
34922 *
34923 * a.logicalOr(b).print();
34924 * ```
34925 * @param a The first input tensor. Must be of dtype bool.
34926 * @param b The second input tensor. Must be of dtype bool.
34927 *
34928 * @doc {heading: 'Operations', subheading: 'Logical'}
34929 */
function logicalOr_(a, b) {
  // Element-wise `a OR b`; both inputs are converted with dtype 'bool'.
  var inputs = {
    a: convertToTensor(a, 'a', 'logicalOr', 'bool'),
    b: convertToTensor(b, 'b', 'logicalOr', 'bool')
  };
  // Validates broadcast compatibility; the computed shape is not needed here.
  assertAndGetBroadcastShape(inputs.a.shape, inputs.b.shape);
  return ENGINE.runKernel(LogicalOr, inputs);
}
var logicalOr$2 = /* @__PURE__ */op({
  logicalOr_: logicalOr_
});
34943
34944 /**
34945 * @license
34946 * Copyright 2020 Google LLC. All Rights Reserved.
34947 * Licensed under the Apache License, Version 2.0 (the "License");
34948 * you may not use this file except in compliance with the License.
34949 * You may obtain a copy of the License at
34950 *
34951 * http://www.apache.org/licenses/LICENSE-2.0
34952 *
34953 * Unless required by applicable law or agreed to in writing, software
34954 * distributed under the License is distributed on an "AS IS" BASIS,
34955 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34956 * See the License for the specific language governing permissions and
34957 * limitations under the License.
34958 * =============================================================================
34959 */
34960 /**
34961 * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
34962 *
34963 * ```js
34964 * const a = tf.tensor1d([false, false, true, true], 'bool');
34965 * const b = tf.tensor1d([false, true, false, true], 'bool');
34966 *
34967 * a.logicalXor(b).print();
34968 * ```
34969 *
34970 * @param a The first input tensor. Must be of dtype bool.
34971 * @param b The second input tensor. Must be of dtype bool.
34972 *
34973 * @doc {heading: 'Operations', subheading: 'Logical'}
34974 */
function logicalXor_(a, b) {
  // Element-wise `a XOR b` built from AND/OR/NOT; inputs are converted with
  // dtype 'bool' and must be broadcast-compatible.
  var $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
  var $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
  assertAndGetBroadcastShape($a.shape, $b.shape);
  // x ^ y = (x | y) & ~(x & y)
  // Pass the already-validated bool tensors rather than the raw `a`/`b`,
  // which each inner op would otherwise run through convertToTensor again.
  return logicalAnd$2(logicalOr$2($a, $b), logicalNot$2(logicalAnd$2($a, $b)));
}
var logicalXor = /* @__PURE__ */op({
  logicalXor_: logicalXor_
});
34985
34986 /**
34987 * @license
34988 * Copyright 2022 Google LLC. All Rights Reserved.
34989 * Licensed under the Apache License, Version 2.0 (the "License");
34990 * you may not use this file except in compliance with the License.
34991 * You may obtain a copy of the License at
34992 *
34993 * http://www.apache.org/licenses/LICENSE-2.0
34994 *
34995 * Unless required by applicable law or agreed to in writing, software
34996 * distributed under the License is distributed on an "AS IS" BASIS,
34997 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34998 * See the License for the specific language governing permissions and
34999 * limitations under the License.
35000 * =============================================================================
35001 */
// Despite the name, this is 2^31 (one more than the largest int32). It is used
// as an exclusive upper bound: the size checks compare with `>=`, so accepted
// sizes are at most 2^31 - 1.
var INT32_MAX$1 = 2147483648;
35003 /**
35004 * Searches for where a value would go in a sorted sequence.
35005 *
35006 * This is not a method for checking containment (like javascript in).
35007 *
35008 * The typical use case for this operation is "binning", "bucketing", or
35009 * "discretizing". The values are assigned to bucket-indices based on the edges
35010 * listed in 'sortedSequence'. This operation returns the bucket-index for each
35011 * value.
35012 *
35013 * The side argument controls which index is returned if a value lands exactly
35014 * on an edge.
35015 *
35016 * The axis is not settable for this operation. It always operates on the
35017 * innermost dimension (axis=-1). The operation will accept any number of outer
35018 * dimensions.
35019 *
35020 * Note: This operation assumes that 'sortedSequence' is sorted along the
35021 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
35022 * sorted no error is raised and the content of the returned tensor is not well
35023 * defined.
35024 *
35025 * ```js
35026 * const edges = tf.tensor1d([-1, 3.3, 9.1, 10.0]);
35027 * let values = tf.tensor1d([0.0, 4.1, 12.0]);
35028 * const result1 = tf.searchSorted(edges, values, 'left');
35029 * result1.print(); // [1, 2, 4]
35030 *
35031 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
35032 * values = tf.tensor1d([0, 4, 10]);
35033 * const result2 = tf.searchSorted(seq, values, 'left');
35034 * result2.print(); // [0, 2, 3]
35035 * const result3 = tf.searchSorted(seq, values, 'right');
35036 * result3.print(); // [1, 2, 5]
35037 *
35038 * const sortedSequence = tf.tensor2d([[0., 3., 8., 9., 10.],
35039 * [1., 2., 3., 4., 5.]]);
35040 * values = tf.tensor2d([[9.8, 2.1, 4.3],
35041 * [0.1, 6.6, 4.5, ]]);
35042 * const result4 = tf.searchSorted(sortedSequence, values, 'left');
35043 * result4.print(); // [[4, 1, 2], [0, 5, 4]]
35044 * ```
35045 * @param sortedSequence: N-D. Sorted sequence.
35046 * @param values: N-D. Search values.
35047 * @param side: 'left'|'right'. Defaults to 'left'. 'left' corresponds to lower
35048 * bound and 'right' to upper bound.
35049 * @return An N-D int32 tensor the size of values containing the result of
35050 * applying either lower bound or upper bound (depending on side) to each
35051 * value. The result is not a global index to the entire Tensor, but the
35052 * index in the last dimension.
35053 * @doc {heading: 'Operations', subheading: 'Evaluation'}
35054 */
function searchSorted_(sortedSequence, values) {
  // Finds, for each entry of `values`, its insertion index in the innermost
  // axis of `sortedSequence`. `side` ('left' by default) selects lower vs.
  // upper bound and is forwarded to the SearchSorted kernel.
  var side = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'left';
  var $sortedSequence = convertToTensor(sortedSequence, 'sortedSequence', 'searchSorted');
  var $values = convertToTensor(values, 'values', 'searchSorted');
  // Validate ranks on the ORIGINAL tensors. After the [-1, n] reshape below
  // both tensors are always rank 2, so a rank check there can never fire. A
  // scalar has no innermost dimension (shape[length - 1] would be undefined),
  // so neither input may be rank 0.
  if ($sortedSequence.rank < 1) {
    throw new Error('Sorted input argument must be at least 1-dimensional');
  }
  if ($values.rank < 1) {
    throw new Error('Values input argument must be at least 1-dimensional');
  }
  // Collapse all outer dimensions so each tensor becomes [rows, innermost].
  var sequenceSize = $sortedSequence.shape[$sortedSequence.shape.length - 1];
  var valuesSize = $values.shape[$values.shape.length - 1];
  var $sortedSequence2D = reshape$3($sortedSequence, [-1, sequenceSize]);
  var $values2D = reshape$3($values, [-1, valuesSize]);
  if ($sortedSequence2D.shape[0] !== $values2D.shape[0]) {
    throw new Error("Leading dimension of 'sortedSequence' and 'values' must match.");
  }
  // Indices are emitted as int32, so sizes must stay strictly below 2^31.
  if (sizeFromShape($values2D.shape) >= INT32_MAX$1) {
    throw new Error("values tensor size must less than ".concat(INT32_MAX$1));
  }
  if ($sortedSequence2D.shape[1] >= INT32_MAX$1) {
    throw new Error("trailing dim_size must less than ".concat(INT32_MAX$1, " for int32 output type, was ").concat($sortedSequence2D.shape[1]));
  }
  var inputs = {
    sortedSequence: $sortedSequence2D,
    values: $values2D
  };
  var attrs = {
    side: side
  };
  // NOTE(review): the kernel receives the 2-D views; confirm whether its
  // output should be reshaped back to `$values.shape`, as the @return doc
  // ("the size of values") describes.
  return ENGINE.runKernel(SearchSorted, inputs, attrs);
}
var searchSorted$2 = /* @__PURE__ */op({
  searchSorted_: searchSorted_
});
35087
35088 /**
35089 * @license
35090 * Copyright 2022 Google LLC. All Rights Reserved.
35091 * Licensed under the Apache License, Version 2.0 (the "License");
35092 * you may not use this file except in compliance with the License.
35093 * You may obtain a copy of the License at
35094 *
35095 * http://www.apache.org/licenses/LICENSE-2.0
35096 *
35097 * Unless required by applicable law or agreed to in writing, software
35098 * distributed under the License is distributed on an "AS IS" BASIS,
35099 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35100 * See the License for the specific language governing permissions and
35101 * limitations under the License.
35102 * =============================================================================
35103 */
35104 /**
35105 * Searches for where a value would go in a sorted sequence.
35106 *
35107 * This is not a method for checking containment (like javascript in).
35108 *
35109 * The typical use case for this operation is "binning", "bucketing", or
35110 * "discretizing". The values are assigned to bucket-indices based on the edges
35111 * listed in 'sortedSequence'. This operation returns the bucket-index for each
35112 * value.
35113 *
35114 * The index returned corresponds to the first edge greater than or equal to the
35115 * value.
35116 *
35117 * The axis is not settable for this operation. It always operates on the
35118 * innermost dimension (axis=-1). The operation will accept any number of outer
35119 * dimensions.
35120 *
 * Note: This operation assumes that 'sortedSequence' is sorted along the
35122 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
35123 * sorted no error is raised and the content of the returned tensor is not well
35124 * defined.
35125 *
35126 * ```js
35127 * const edges = tf.tensor1d([-1, 3.3, 9.1, 10.0]);
35128 * let values = tf.tensor1d([0.0, 4.1, 12.0]);
35129 * const result1 = tf.lowerBound(edges, values);
35130 * result1.print(); // [1, 2, 4]
35131 *
35132 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
35133 * values = tf.tensor1d([0, 4, 10]);
35134 * const result2 = tf.lowerBound(seq, values);
35135 * result2.print(); // [0, 2, 3]
35136 *
35137 * const sortedSequence = tf.tensor2d([[0., 3., 8., 9., 10.],
35138 * [1., 2., 3., 4., 5.]]);
35139 * values = tf.tensor2d([[9.8, 2.1, 4.3],
35140 * [0.1, 6.6, 4.5, ]]);
35141 * const result3 = tf.lowerBound(sortedSequence, values);
35142 * result3.print(); // [[4, 1, 2], [0, 5, 4]]
35143 * ```
35144 * @param sortedSequence: N-D. Sorted sequence.
35145 * @param values: N-D. Search values.
35146 * @return An N-D int32 tensor the size of values containing the result of
35147 * applying lower bound to each value. The result is not a global index to
35148 * the entire Tensor, but the index in the last dimension.
35149 * @doc {heading: 'Operations', subheading: 'Evaluation'}
35150 */
function lowerBound$1(sortedSequence, values) {
  // A lower bound is searchSorted with side 'left': the first index whose
  // edge is >= the value.
  var lowerSide = 'left';
  return searchSorted$2(sortedSequence, values, lowerSide);
}
35154
35155 /**
35156 * @license
35157 * Copyright 2020 Google LLC. All Rights Reserved.
35158 * Licensed under the Apache License, Version 2.0 (the "License");
35159 * you may not use this file except in compliance with the License.
35160 * You may obtain a copy of the License at
35161 *
35162 * http://www.apache.org/licenses/LICENSE-2.0
35163 *
35164 * Unless required by applicable law or agreed to in writing, software
35165 * distributed under the License is distributed on an "AS IS" BASIS,
35166 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35167 * See the License for the specific language governing permissions and
35168 * limitations under the License.
35169 * =============================================================================
35170 */
35171 /**
35172 * Computes the 2D max pooling of an image.
35173 *
35174 * @param x The input tensor, of rank 4 or rank 3 of shape
35175 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
35176 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
35177 * `filterSize` is a single number, then `filterHeight == filterWidth`.
35178 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
35179 * `strides` is a single number, then `strideHeight == strideWidth`.
 * Note: this op takes no `dilations` argument; pooling always uses a
 *     dilation rate of 1.
35185 * @param pad The type of padding algorithm.
35186 * - `same` and stride 1: output will be of same size as input,
35187 * regardless of filter size.
35188 * - `valid`: output will be smaller than input if filter is larger
35189 * than 1x1.
35190 * - For more info, see this guide:
35191 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
35192 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
35193 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
35194 * provided, it will default to truncate.
35195 */
function maxPool_(x, filterSize, strides, pad, dimRoundingMode) {
  // 2D max pooling over a rank-4 input; a rank-3 input is treated as a batch
  // of one and the result is returned at rank 3.
  var $x = convertToTensor(x, 'x', 'maxPool');
  // No dilation argument is exposed; the rate is fixed at 1 and used only
  // for the strides/dilations validation below.
  var dilations = 1;
  var reshapedTo4D = $x.rank === 3;
  var x4D = reshapedTo4D ? reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;
  assert$1(x4D.rank === 4, function () {
    return "Error in maxPool: input must be rank 4 but got rank ".concat(x4D.rank, ".");
  });
  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in maxPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
  });
  checkPadOnDimRoundingMode('maxPool', pad, dimRoundingMode);
  var pooled = ENGINE.runKernel(MaxPool, {
    x: x4D
  }, {
    filterSize: filterSize,
    strides: strides,
    pad: pad,
    dimRoundingMode: dimRoundingMode
  });
  if (!reshapedTo4D) {
    return pooled;
  }
  // Remove the synthetic batch dimension added above.
  return reshape$3(pooled, [pooled.shape[1], pooled.shape[2], pooled.shape[3]]);
}
var maxPool$2 = /* @__PURE__ */op({
  maxPool_: maxPool_
});
35231
35232 /**
35233 * @license
35234 * Copyright 2020 Google LLC. All Rights Reserved.
35235 * Licensed under the Apache License, Version 2.0 (the "License");
35236 * you may not use this file except in compliance with the License.
35237 * You may obtain a copy of the License at
35238 *
35239 * http://www.apache.org/licenses/LICENSE-2.0
35240 *
35241 * Unless required by applicable law or agreed to in writing, software
35242 * distributed under the License is distributed on an "AS IS" BASIS,
35243 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35244 * See the License for the specific language governing permissions and
35245 * limitations under the License.
35246 * =============================================================================
35247 */
35248 /**
35249 * Computes the 3D max pooling.
35250 *
35251 * ```js
35252 * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
35253 * const result = tf.maxPool3d(x, 2, 1, 'valid');
35254 * result.print();
35255 * ```
35256 *
35257 * @param x The input tensor, of rank 5 or rank 4 of shape
35258 * `[batch, depth, height, width, inChannels]`.
35259 * @param filterSize The filter size:
35260 * `[filterDepth, filterHeight, filterWidth]`.
35261 * If `filterSize` is a single number,
35262 * then `filterDepth == filterHeight == filterWidth`.
35263 * @param strides The strides of the pooling:
35264 * `[strideDepth, strideHeight, strideWidth]`.
35265 * If `strides` is a single number,
35266 * then `strideDepth == strideHeight == strideWidth`.
35267 * @param pad The type of padding algorithm.
35268 * - `same` and stride 1: output will be of same size as input,
35269 * regardless of filter size.
35270 * - `valid`: output will be smaller than input if filter is larger
 *         than 1x1x1.
35272 * - For more info, see this guide:
35273 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
35274 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
35275 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
35276 * provided, it will default to truncate.
35277 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
35278 * "NDHWC". Specify the data format of the input and output data. With the
35279 * default format "NDHWC", the data is stored in the order of: [batch,
35280 * depth, height, width, channels]. Only "NDHWC" is currently supported.
35281 * @doc {heading: 'Operations', subheading: 'Convolution'}
35282 */
function maxPool3d_(x) {
  // 3D max pooling over a rank-5 NDHWC input; a rank-4 input is treated as a
  // batch of one and the result is returned at rank 4.
  var filterSize = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : [1, 1, 1];
  // Missing positional arguments read as undefined.
  var strides = arguments[2];
  var pad = arguments[3];
  var dimRoundingMode = arguments[4];
  var dataFormat = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 'NDHWC';
  var $x = convertToTensor(x, 'x', 'maxPool3d');
  var reshapedTo5D = $x.rank === 4;
  var x5D = reshapedTo5D ? reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]) : $x;
  assert$1(x5D.rank === 5, function () {
    return "Error in maxPool3d: x must be rank 5 but got rank ".concat(x5D.rank, ".");
  });
  assert$1(dataFormat === 'NDHWC', function () {
    return "Error in maxPool3d: Only NDHWC is currently supported, " + "but got dataFormat of ".concat(dataFormat);
  });
  checkPadOnDimRoundingMode('maxPool3d', pad, dimRoundingMode);
  var pooled = ENGINE.runKernel(MaxPool3D, {
    x: x5D
  }, {
    filterSize: filterSize,
    strides: strides,
    pad: pad,
    dimRoundingMode: dimRoundingMode,
    dataFormat: dataFormat
  });
  if (!reshapedTo5D) {
    return pooled;
  }
  // Remove the synthetic batch dimension added above.
  return reshape$3(pooled, [pooled.shape[1], pooled.shape[2], pooled.shape[3], pooled.shape[4]]);
}
var maxPool3d$1 = /* @__PURE__ */op({
  maxPool3d_: maxPool3d_
});
35323
35324 /**
35325 * @license
35326 * Copyright 2018 Google LLC. All Rights Reserved.
35327 * Licensed under the Apache License, Version 2.0 (the "License");
35328 * you may not use this file except in compliance with the License.
35329 * You may obtain a copy of the License at
35330 *
35331 * http://www.apache.org/licenses/LICENSE-2.0
35332 *
35333 * Unless required by applicable law or agreed to in writing, software
35334 * distributed under the License is distributed on an "AS IS" BASIS,
35335 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35336 * See the License for the specific language governing permissions and
35337 * limitations under the License.
35338 * =============================================================================
35339 */
35340 /**
35341 * Computes the 2D max pooling of an image with Argmax index.
35342 * The indices in argmax are flattened, so that a maximum value at position `[b,
35343 * y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
35344 * include_batch_in_index is False; `((b * height + y) * width + x) * channels
35345 * +c` if include_batch_in_index is True.
35346 *
35347 * The indices returned are always in `[0, height) x [0, width)` before
35348 * flattening.
35349 *
35350 * @param x The input tensor, of rank 4 or rank 3 of shape
35351 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
35352 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
35353 * `filterSize` is a single number, then `filterHeight == filterWidth`.
35354 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
35355 * `strides` is a single number, then `strideHeight == strideWidth`.
 * Note: this op takes no `dataFormat` argument; `x` is expected in
 *     `[batch, height, width, inChannels]` order, as described for `x` above.
35360 * @param pad The type of padding algorithm.
35361 * - `same` and stride 1: output will be of same size as input,
35362 * regardless of filter size.
35363 * - `valid`: output will be smaller than input if filter is larger
35364 * than 1x1.
35365 * - For more info, see this guide:
35366 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
35367 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param includeBatchInIndex Defaults to false. Whether to include batch
35369 * dimension in flattened index of argmax.
35370 *
35371 * @doc {heading: 'Operations', subheading: 'Convolution'}
35372 */
function maxPoolWithArgmax_(x, filterSize, strides, pad) {
  // 2D max pooling that also returns the flattened argmax indices. Returns
  // {result, indexes}. As documented, a rank-3 input is treated as a batch of
  // one; both outputs are then returned without the batch dimension.
  var includeBatchInIndex = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
  var $x = convertToTensor(x, 'x', 'maxPoolWithArgmax');
  var x4D = $x;
  var reshapedTo4D = false;
  if ($x.rank === 3) {
    reshapedTo4D = true;
    x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
  }
  // Reject other ranks up front instead of handing an unexpected shape to
  // the kernel (same check maxPool performs).
  assert$1(x4D.rank === 4, function () {
    return "Error in maxPoolWithArgmax: input must be rank 4 but got rank ".concat(x4D.rank, ".");
  });
  var inputs = {
    x: x4D
  };
  var attrs = {
    filterSize: filterSize,
    strides: strides,
    pad: pad,
    includeBatchInIndex: includeBatchInIndex
  };
  var outputs = ENGINE.runKernel(MaxPoolWithArgmax, inputs, attrs);
  var pooled = outputs[0];
  var indexes = outputs[1];
  if (reshapedTo4D) {
    // With a single batch entry (b = 0) the flattened index values are the
    // same whether or not the batch is included, so only shapes change.
    pooled = reshape$3(pooled, [pooled.shape[1], pooled.shape[2], pooled.shape[3]]);
    indexes = reshape$3(indexes, [indexes.shape[1], indexes.shape[2], indexes.shape[3]]);
  }
  return {
    result: pooled,
    indexes: indexes
  };
}
var maxPoolWithArgmax = /* @__PURE__ */op({
  maxPoolWithArgmax_: maxPoolWithArgmax_
});
35395
35396 /**
35397 * Returns the max of a and b (`a > b ? a : b`) element-wise.
35398 * Supports broadcasting.
35399 *
35400 * We also expose `tf.maximumStrict` which has the same signature as this op and
35401 * asserts that `a` and `b` are the same shape (does not broadcast).
35402 *
35403 * ```js
35404 * const a = tf.tensor1d([1, 4, 3, 16]);
35405 * const b = tf.tensor1d([1, 2, 9, 4]);
35406 *
35407 * a.maximum(b).print(); // or tf.maximum(a, b)
35408 * ```
35409 *
35410 * ```js
35411 * // Broadcast maximum a with b.
35412 * const a = tf.tensor1d([2, 4, 6, 8]);
35413 * const b = tf.scalar(5);
35414 *
35415 * a.maximum(b).print(); // or tf.maximum(a, b)
35416 * ```
35417 *
35418 * @param a The first tensor.
35419 * @param b The second tensor. Must have the same type as `a`.
35420 *
35421 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
35422 */
function maximum_(a, b) {
  // Element-wise max of `a` and `b` with broadcasting. The inputs are first
  // brought to a common dtype, and bool operands are cast to int32.
  var matched = makeTypesMatch(convertToTensor(a, 'a', 'maximum'), convertToTensor(b, 'b', 'maximum'));
  var lhs = matched[0];
  var rhs = matched[1];
  if (lhs.dtype === 'bool') {
    lhs = cast$3(lhs, 'int32');
    rhs = cast$3(rhs, 'int32');
  }
  // Validates broadcast compatibility; the computed shape is not needed here.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernel(Maximum$1, {
    a: lhs,
    b: rhs
  });
}
var maximum$4 = /* @__PURE__ */op({
  maximum_: maximum_
});
35444
35445 /**
35446 * @license
35447 * Copyright 2020 Google Inc. All Rights Reserved.
35448 * Licensed under the Apache License, Version 2.0 (the "License");
35449 * you may not use this file except in compliance with the License.
35450 * You may obtain a copy of the License at
35451 *
35452 * http://www.apache.org/licenses/LICENSE-2.0
35453 *
35454 * Unless required by applicable law or agreed to in writing, software
35455 * distributed under the License is distributed on an "AS IS" BASIS,
35456 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35457 * See the License for the specific language governing permissions and
35458 * limitations under the License.
35459 * =============================================================================
35460 */
35461 /**
35462 * Computes the mean of elements across dimensions of a `tf.Tensor`.
35463 *
35464 * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is
35465 * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`.
35466 * If `keepDims` is true, the reduced dimensions are retained with length 1.
35467 * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with
35468 * a single element is returned.
35469 *
35470 * ```js
35471 * const x = tf.tensor1d([1, 2, 3]);
35472 *
 * x.mean().print();  // or tf.mean(x)
35474 * ```
35475 *
35476 * ```js
35477 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
35478 *
35479 * const axis = 1;
35480 * x.mean(axis).print(); // or tf.mean(x, axis)
35481 * ```
35482 *
35483 * @param x The input tensor.
35484 * @param axis The dimension(s) to reduce. By default it reduces
35485 * all dimensions.
35486 * @param keepDims If true, retains reduced dimensions with size 1.
35487 *
35488 * @doc {heading: 'Operations', subheading: 'Reduction'}
35489 */
function mean_(x) {
  // Mean over `axis` (null = every axis); `keepDims` (default false) retains
  // reduced dimensions with size 1. All work is delegated to the Mean kernel.
  var axisArg = arguments[1];
  var keepDimsArg = arguments[2];
  var $x = convertToTensor(x, 'x', 'mean');
  return ENGINE.runKernel(Mean, {
    x: $x
  }, {
    axis: axisArg === undefined ? null : axisArg,
    keepDims: keepDimsArg === undefined ? false : keepDimsArg
  });
}
var mean$3 = /* @__PURE__ */op({
  mean_: mean_
});
35506
35507 /**
35508 * @license
35509 * Copyright 2018 Google LLC. All Rights Reserved.
35510 * Licensed under the Apache License, Version 2.0 (the "License");
35511 * you may not use this file except in compliance with the License.
35512 * You may obtain a copy of the License at
35513 *
35514 * http://www.apache.org/licenses/LICENSE-2.0
35515 *
35516 * Unless required by applicable law or agreed to in writing, software
35517 * distributed under the License is distributed on an "AS IS" BASIS,
35518 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35519 * See the License for the specific language governing permissions and
35520 * limitations under the License.
35521 * =============================================================================
35522 */
/**
 * Creates a `tf.Tensor` with all elements set to 0.
 *
 * ```js
 * tf.zeros([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape. Its
 *     dimensions are validated by `assertNonNegativeIntegerDimensions`.
 * @param dtype The type of an element in the resulting tensor, e.g.
 *     'float32', 'int32', 'bool' or 'complex64'. Defaults to 'float32'.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function zeros$2(shape) {
  var dtype = arguments[1] === undefined ? 'float32' : arguments[1];
  assertNonNegativeIntegerDimensions(shape);
  if (dtype !== 'complex64') {
    var zeroValues = makeZerosTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(zeroValues, shape, dtype);
  }
  // complex64 is assembled from two float32 tensors (real, imaginary),
  // both all-zero here.
  return complex$2(zeros$2(shape, 'float32'), zeros$2(shape, 'float32'));
}
35547
35548 /**
35549 * @license
35550 * Copyright 2018 Google LLC. All Rights Reserved.
35551 * Licensed under the Apache License, Version 2.0 (the "License");
35552 * you may not use this file except in compliance with the License.
35553 * You may obtain a copy of the License at
35554 *
35555 * http://www.apache.org/licenses/LICENSE-2.0
35556 *
35557 * Unless required by applicable law or agreed to in writing, software
35558 * distributed under the License is distributed on an "AS IS" BASIS,
35559 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35560 * See the License for the specific language governing permissions and
35561 * limitations under the License.
35562 * =============================================================================
35563 */
/**
 * Creates a `tf.Tensor` with all elements set to 1.
 *
 * ```js
 * tf.ones([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape. Its
 *     dimensions are validated by `assertNonNegativeIntegerDimensions`.
 * @param dtype The type of an element in the resulting tensor. Defaults to
 *     'float32'. For 'complex64' the real part is 1 and the imaginary part 0.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function ones$1(shape) {
  var dtype = arguments[1] === undefined ? 'float32' : arguments[1];
  assertNonNegativeIntegerDimensions(shape);
  if (dtype !== 'complex64') {
    var oneValues = makeOnesTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(oneValues, shape, dtype);
  }
  // complex64 "one" = 1 + 0i: a float32 ones tensor for the real part and a
  // float32 zeros tensor for the imaginary part.
  return complex$2(ones$1(shape, 'float32'), zeros$2(shape, 'float32'));
}
35588
35589 /**
35590 * @license
35591 * Copyright 2021 Google LLC. All Rights Reserved.
35592 * Licensed under the Apache License, Version 2.0 (the "License");
35593 * you may not use this file except in compliance with the License.
35594 * You may obtain a copy of the License at
35595 *
35596 * http://www.apache.org/licenses/LICENSE-2.0
35597 *
35598 * Unless required by applicable law or agreed to in writing, software
35599 * distributed under the License is distributed on an "AS IS" BASIS,
35600 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35601 * See the License for the specific language governing permissions and
35602 * limitations under the License.
35603 * =============================================================================
35604 */
/**
 * Broadcasts parameters for evaluation on a 2-D grid.
 *
 * Given one-dimensional coordinate arrays `x` and `y`, returns a list of 2-D
 * coordinate arrays for evaluating expressions on a 2-D grid.
 *
 * Notes:
 * `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
 * When the `indexing` option is 'xy' (the default), the broadcasting
 * instructions for the first two dimensions are swapped, so both outputs have
 * shape `[size(y), size(x)]`. With 'ij' both outputs have shape
 * `[size(x), size(y)]`.
 * Examples:
 * Calling `const [X, Y] = meshgrid(x, y)` with the tensors
 *
 * ```javascript
 * const x = [1, 2, 3];
 * const y = [4, 5, 6];
 * const [X, Y] = tf.meshgrid(x, y);
 * // X = [[1, 2, 3],
 * //      [1, 2, 3],
 * //      [1, 2, 3]]
 * // Y = [[4, 4, 4],
 * //      [5, 5, 5],
 * //      [6, 6, 6]]
 * ```
 *
 * @param x Tensor (or TensorLike) of coordinates; flattened before use. If
 *     omitted, an empty array is returned.
 * @param y Tensor (or TensorLike) of coordinates; flattened before use. If
 *     omitted, `[x]` (converted to a tensor) is returned.
 * @param options Optional `{indexing}` object. `indexing` must be 'xy'
 *     (default) or 'ij'; any other value throws a `TypeError`.
 * @return `[X, Y]` coordinate tensors (see the notes above for shapes).
 *
 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
 */
function meshgrid(x, y) {
  var _ref = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {},
    _ref$indexing = _ref.indexing,
    indexing = _ref$indexing === void 0 ? 'xy' : _ref$indexing;
  if (indexing !== 'xy' && indexing !== 'ij') {
    throw new TypeError("".concat(indexing, " is not a valid third argument to meshgrid"));
  }
  if (x === undefined) {
    return [];
  }
  // meshgrid is not wrapped in `op()`, so nothing else disposes the
  // intermediate tensors it creates (converted inputs, reshaped row/column
  // views, the `ones` matrices). Running the body in its own tidy scope keeps
  // only the returned tensors and releases the rest.
  return ENGINE.tidy('meshgrid', function () {
    var $x = convertToTensor(x, 'x', 'meshgrid', x instanceof Tensor ? x.dtype : 'float32');
    if (y === undefined) {
      return [$x];
    }
    var $y = convertToTensor(y, 'y', 'meshgrid', y instanceof Tensor ? y.dtype : 'float32');
    var w = sizeFromShape($x.shape);
    var h = sizeFromShape($y.shape);
    if (indexing === 'xy') {
      // X: the x row repeated h times; Y: the y column repeated w times.
      var xRow = reshape$3($x, [1, -1]);
      var yCol = reshape$3($y, [-1, 1]);
      return [matMul$1(ones$1([h, 1], xRow.dtype), xRow), matMul$1(yCol, ones$1([1, w], yCol.dtype))];
    }
    // 'ij': X: the x column repeated h times; Y: the y row repeated w times.
    var xCol = reshape$3($x, [-1, 1]);
    var yRow = reshape$3($y, [1, -1]);
    return [matMul$1(xCol, ones$1([1, h], xCol.dtype)), matMul$1(ones$1([w, 1], yRow.dtype), yRow)];
  });
}
35662
/**
 * Returns the min of a and b (`a < b ? a : b`) element-wise.
 * Supports broadcasting.
 *
 * We also expose `minimumStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.minimum(b).print();  // or tf.minimum(a, b)
 * ```
 *
 * ```js
 * // Broadcast minimum a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.minimum(b).print();  // or tf.minimum(a, b)
 * ```
 *
 * Both inputs are first brought to a common dtype; bool inputs are then cast
 * to int32 before the kernel runs. The shapes must be broadcast-compatible.
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
 */
function minimum_(a, b) {
  var matched = makeTypesMatch(convertToTensor(a, 'a', 'minimum'), convertToTensor(b, 'b', 'minimum'));
  var lhs = matched[0];
  var rhs = matched[1];
  if (lhs.dtype === 'bool') {
    lhs = cast$3(lhs, 'int32');
    rhs = cast$3(rhs, 'int32');
  }
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernel(Minimum$1, {a: lhs, b: rhs});
}
var minimum$4 = /* @__PURE__ */op({
  minimum_: minimum_
});
35711
35712 /**
35713 * @license
35714 * Copyright 2020 Google LLC. All Rights Reserved.
35715 * Licensed under the Apache License, Version 2.0 (the "License");
35716 * you may not use this file except in compliance with the License.
35717 * You may obtain a copy of the License at
35718 *
35719 * http://www.apache.org/licenses/LICENSE-2.0
35720 *
35721 * Unless required by applicable law or agreed to in writing, software
35722 * distributed under the License is distributed on an "AS IS" BASIS,
35723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35724 * See the License for the specific language governing permissions and
35725 * limitations under the License.
35726 * =============================================================================
35727 */
/**
 * Pads a `tf.Tensor` using mirror padding.
 *
 * This operation implements the `REFLECT` and `SYMMETRIC` modes of pad.
 *
 * ```js
 * const x = tf.range(0, 9).reshape([1, 1, 3, 3]);
 * x.mirrorPad([[0, 0], [0, 0], [2, 2], [2, 2]], 'reflect').print();
 * ```
 * @param x The tensor to pad. Scalars (rank 0) are rejected.
 * @param paddings An array of length `R` (the rank of the tensor), where
 *     each element is a length-2 tuple of ints `[padBefore, padAfter]`,
 *     specifying how much to pad along each dimension of the tensor.
 *     In "reflect" mode, the padded regions do not include the borders,
 *     while in "symmetric" mode the padded regions do include the borders.
 *     For example, if the input is `[1, 2, 3]` and paddings is `[0, 2]`,
 *     then the output is `[1, 2, 3, 2, 1]` in "reflect" mode, and
 *     `[1, 2, 3, 3, 2]` in "symmetric" mode.
 *     If `mode` is "reflect" then both `paddings[D, 0]` and `paddings[D, 1]`
 *     must be no greater than `x.shape[D] - 1`. If mode is "symmetric"
 *     then both `paddings[D, 0]` and `paddings[D, 1]` must be no greater than
 *     `x.shape[D]`.
 * @param mode String to specify padding mode. Can be `'reflect' | 'symmetric'`.
 *
 * Throws if `mode` is invalid, `x` is a scalar, `paddings` does not have one
 * pair per dimension, or any padding is negative or exceeds the limit above.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function mirrorPad_(x, paddings, mode) {
  assert$1(mode === 'reflect' || mode === 'symmetric', function () {
    return "Invalid mode. Mode must be either reflect or symmetric. " + "Got ".concat(mode, ".");
  });
  var $x = convertToTensor(x, 'x', 'mirrorPad');
  if ($x.rank === 0) {
    throw new Error('mirrorPad(scalar) is not defined. ' + 'Pass non-scalar to mirrorPad');
  }
  assert$1(paddings.length === $x.rank, function () {
    return "Padding doesn't match input. Must be ".concat($x.rank, ". ") + "Got ".concat(paddings.length, ".");
  });
  // 'reflect' excludes the border element, so it can mirror at most
  // dimSize - 1 elements; 'symmetric' includes it and can mirror dimSize.
  var shapeOffset = mode === 'reflect' ? 1 : 0;
  // Validates the [padBefore, padAfter] pair of one dimension. Each call gets
  // its own `dim` binding for the lazily-built error messages.
  var assertDimPadding = function (dim) {
    var dimPadding = paddings[dim];
    var maxPad = $x.shape[dim] - shapeOffset;
    assert$1(dimPadding.length === 2, function () {
      return "Invalid number of paddings. Must be length of 2 each.";
    });
    // The limit itself is allowed, so the message says "greater than", not
    // "greater than or equal to".
    assert$1(dimPadding[0] >= 0 && dimPadding[0] <= maxPad && dimPadding[1] >= 0 && dimPadding[1] <= maxPad, function () {
      return "Padding in dimension ".concat(dim, " cannot be greater than ") + "".concat(maxPad, " or less than 0 for input of ") + "shape ".concat($x.shape);
    });
  };
  for (var i = 0; i < $x.rank; i++) {
    assertDimPadding(i);
  }
  var attrs = {
    paddings: paddings,
    mode: mode
  };
  var inputs = {
    x: $x
  };
  return ENGINE.runKernel(MirrorPad, inputs, attrs);
}
var mirrorPad$1 = /* @__PURE__ */op({
  mirrorPad_: mirrorPad_
});
35788
/**
 * Returns the mod of a and b element-wise.
 * `floor(x / y) * y + mod(x, y) = x`
 * Supports broadcasting.
 *
 * We also expose `tf.modStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.mod(b).print();  // or tf.mod(a, b)
 * ```
 *
 * ```js
 * // Broadcast a mod b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.mod(b).print();  // or tf.mod(a, b)
 * ```
 *
 * The inputs are brought to a common dtype and passed straight to the Mod
 * kernel; no broadcast-shape assertion is made here.
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
 */
function mod_(a, b) {
  var typed = makeTypesMatch(convertToTensor(a, 'a', 'mod'), convertToTensor(b, 'b', 'mod'));
  return ENGINE.runKernel(Mod, {a: typed[0], b: typed[1]});
}
var mod$2 = /* @__PURE__ */op({
  mod_: mod_
});
35833
35834 /**
35835 * @license
35836 * Copyright 2020 Google LLC. All Rights Reserved.
35837 * Licensed under the Apache License, Version 2.0 (the "License");
35838 * you may not use this file except in compliance with the License.
35839 * You may obtain a copy of the License at
35840 *
35841 * http://www.apache.org/licenses/LICENSE-2.0
35842 *
35843 * Unless required by applicable law or agreed to in writing, software
35844 * distributed under the License is distributed on an "AS IS" BASIS,
35845 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35846 * See the License for the specific language governing permissions and
35847 * limitations under the License.
35848 * =============================================================================
35849 */
/**
 * Calculates the mean and variance of `x`. The mean and variance are
 * calculated by aggregating the contents of `x` across `axes`. If `x` is
 * 1-D and `axes = [0]` this is just the mean and variance of a vector.
 *
 * The variance is the mean of squared deviations from the mean, computed on
 * `x` cast to float32.
 *
 * @param x The input tensor.
 * @param axis The dimension(s) along which to compute mean and
 *     variance. Defaults to `null`, which reduces all dimensions.
 * @param keepDims If true, the moments have the same dimensionality as the
 *     input. Defaults to false.
 * @return An object with two keys: `mean` and `variance`.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function moments_(x) {
  var axis = arguments[1] === undefined ? null : arguments[1];
  var keepDims = arguments[2] === undefined ? false : arguments[2];
  var $x = convertToTensor(x, 'x', 'moments');
  var axes = parseAxisParam(axis, $x.shape);
  var xMean = mean$3($x, axes, keepDims);
  // The mean must broadcast against x when deviations are taken. If the
  // reduced axes were dropped, expandShapeToKeepDim presumably re-inserts them
  // as size-1 dims (TODO confirm against its definition).
  var broadcastShape = keepDims ? xMean.shape : expandShapeToKeepDim(xMean.shape, axes);
  var deviation = sub$2(cast$3($x, 'float32'), reshape$3(xMean, broadcastShape));
  return {
    mean: xMean,
    variance: mean$3(square$2(deviation), axes, keepDims)
  };
}
var moments = /* @__PURE__ */op({
  moments_: moments_
});
35884
/**
 * Computes the next states and outputs of a stack of LSTMCells.
 *
 * Each cell output is used as input to the next cell.
 *
 * Each cell is called as `cell(input, c[i], h[i])` and must return
 * `[cellState, cellOutput]`. This op returns `[newC, newH]`: the arrays of
 * per-cell states and per-cell outputs, in cell order.
 *
 * Derived from tf.contrib.rnn.MultiRNNCell.
 *
 * @param lstmCells Array of LSTMCell functions.
 * @param data The input to the cell.
 * @param c Array of previous cell states.
 * @param h Array of previous cell outputs.
 *
 * @doc {heading: 'Operations', subheading: 'RNN'}
 */
function multiRNNCell_(lstmCells, data, c, h) {
  var $data = convertToTensor(data, 'data', 'multiRNNCell');
  var $c = convertToTensorArray(c, 'c', 'multiRNNCell');
  var $h = convertToTensorArray(h, 'h', 'multiRNNCell');
  var newC = [];
  var newH = [];
  var layerInput = $data;
  for (var layer = 0; layer < lstmCells.length; layer++) {
    var result = lstmCells[layer](layerInput, $c[layer], $h[layer]);
    newC.push(result[0]);
    newH.push(result[1]);
    // The output of this layer feeds the next one.
    layerInput = result[1];
  }
  return [newC, newH];
}
var multiRNNCell = /* @__PURE__ */op({
  multiRNNCell_: multiRNNCell_
});
35924
35925 /**
35926 * @license
35927 * Copyright 2020 Google LLC. All Rights Reserved.
35928 * Licensed under the Apache License, Version 2.0 (the "License");
35929 * you may not use this file except in compliance with the License.
35930 * You may obtain a copy of the License at
35931 *
35932 * http://www.apache.org/licenses/LICENSE-2.0
35933 *
35934 * Unless required by applicable law or agreed to in writing, software
35935 * distributed under the License is distributed on an "AS IS" BASIS,
35936 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35937 * See the License for the specific language governing permissions and
35938 * limitations under the License.
35939 * =============================================================================
35940 */
/**
 * Creates a `tf.Tensor` with values drawn from a multinomial distribution.
 *
 * ```js
 * const probs = tf.tensor([.75, .25]);
 * tf.multinomial(probs, 3).print();
 * ```
 *
 * @param logits 1D array with unnormalized log-probabilities, or
 *     2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
 *     parameter.
 * @param numSamples Number of samples to draw for each row slice.
 * @param seed The seed number. Any falsy value (including 0) is replaced by
 *     `Math.random()`.
 * @param normalized Whether the provided `logits` are normalized true
 *     probabilities (sum to 1). Defaults to false.
 * @return 1D array of shape `[numSamples]`, or 2D array of shape
 *     `[batchSize, numSamples]`, depending on the rank of the input.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function multinomial_(logits, numSamples, seed) {
  var normalized = arguments[3] === undefined ? false : arguments[3];
  var $logits = convertToTensor(logits, 'logits', 'multinomial');
  // NOTE(review): `size` presumably counts every element, so for rank-2
  // input this is batchSize * numOutcomes rather than the per-row count.
  var numOutcomes = $logits.size;
  var origRank = $logits.rank;
  if (numOutcomes < 2) {
    throw new Error("Error in multinomial: you need at least 2 outcomes, but got " + "".concat(numOutcomes, "."));
  }
  if (origRank > 2) {
    throw new Error("Rank of probabilities must be 1 or 2, but is ".concat(origRank));
  }
  // TODO(lina128): Investigate correct seed behavior. Every falsy seed,
  // including 0, is currently swapped for a random one.
  var effectiveSeed = seed || Math.random();
  // The kernel only accepts (and returns) rank 2 tensors, so a vector input
  // is lifted to a single-row matrix and the result flattened back.
  var isVector = origRank === 1;
  var logits2D = isVector ? reshape$3($logits, [1, -1]) : $logits;
  var samples = ENGINE.runKernel(Multinomial, {logits: logits2D}, {numSamples: numSamples, seed: effectiveSeed, normalized: normalized});
  return isVector ? reshape$3(samples, [samples.size]) : samples;
}
var multinomial$2 = /* @__PURE__ */op({
  multinomial_: multinomial_
});
35993
/**
 * Returns the truth value of (a != b) element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = tf.tensor1d([0, 2, 3]);
 *
 * a.notEqual(b).print();
 * ```
 *
 * String and numeric inputs are accepted; both are brought to a common
 * dtype and their shapes must be broadcast-compatible.
 *
 * @param a The first input tensor.
 * @param b The second input tensor. Must have the same dtype as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function notEqual_(a, b) {
  var matched = makeTypesMatch(convertToTensor(a, 'a', 'notEqual', 'string_or_numeric'), convertToTensor(b, 'b', 'notEqual', 'string_or_numeric'));
  var lhs = matched[0];
  var rhs = matched[1];
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernel(NotEqual, {a: lhs, b: rhs});
}
var notEqual$2 = /* @__PURE__ */op({
  notEqual_: notEqual_
});
36025
36026 /**
36027 * @license
36028 * Copyright 2020 Google LLC. All Rights Reserved.
36029 * Licensed under the Apache License, Version 2.0 (the "License");
36030 * you may not use this file except in compliance with the License.
36031 * You may obtain a copy of the License at
36032 *
36033 * http://www.apache.org/licenses/LICENSE-2.0
36034 *
36035 * Unless required by applicable law or agreed to in writing, software
36036 * distributed under the License is distributed on an "AS IS" BASIS,
36037 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36038 * See the License for the specific language governing permissions and
36039 * limitations under the License.
36040 * =============================================================================
36041 */
/**
 * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
 * value `onValue` (defaults to 1), while all other locations take value
 * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
 * `R+1` with the last axis of size `depth`.
 * `indices` used to encode prediction class must start from 0. For example,
 * if you have 3 classes of data, class 1 should be encoded as 0, class 2
 * should be 1, and class 3 should be 2.
 *
 * ```js
 * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
 * ```
 *
 * @param indices `tf.Tensor` of indices with dtype `int32`. Indices must
 *     start from 0.
 * @param depth The depth of the one hot dimension. Must be at least 2.
 * @param onValue A number used to fill in the output when the index matches
 *     the location.
 * @param offValue A number used to fill in the output when the index does
 *     not match the location.
 * @param dtype The dtype of the output tensor. Defaults to 'int32'.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function oneHot_(indices, depth) {
  var onValue = arguments[2] === undefined ? 1 : arguments[2];
  var offValue = arguments[3] === undefined ? 0 : arguments[3];
  var dtype = arguments[4] === undefined ? 'int32' : arguments[4];
  // Validate depth before touching `indices`.
  if (depth < 2) {
    throw new Error("Error in oneHot: depth must be >=2, but it is ".concat(depth));
  }
  return ENGINE.runKernel(OneHot, {
    indices: convertToTensor(indices, 'indices', 'oneHot', 'int32')
  }, {
    dtype: dtype,
    depth: depth,
    onValue: onValue,
    offValue: offValue
  });
}
var oneHot$3 = /* @__PURE__ */op({
  oneHot_: oneHot_
});
36088
36089 /**
36090 * @license
36091 * Copyright 2018 Google LLC. All Rights Reserved.
36092 * Licensed under the Apache License, Version 2.0 (the "License");
36093 * you may not use this file except in compliance with the License.
36094 * You may obtain a copy of the License at
36095 *
36096 * http://www.apache.org/licenses/LICENSE-2.0
36097 *
36098 * Unless required by applicable law or agreed to in writing, software
36099 * distributed under the License is distributed on an "AS IS" BASIS,
36100 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36101 * See the License for the specific language governing permissions and
36102 * limitations under the License.
36103 * =============================================================================
36104 */
/**
 * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the
 * given tensor.
 *
 * ```js
 * const x = tf.tensor([1, 2]);
 * tf.onesLike(x).print();
 * ```
 * @param x A tensor (or value convertible to one).
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function onesLike_(x) {
  return ENGINE.runKernel(OnesLike, {x: convertToTensor(x, 'x', 'onesLike')});
}
var onesLike$3 = /* @__PURE__ */op({
  onesLike_: onesLike_
});
36127
/**
 * Computes the outer product of two vectors, `v1` and `v2`.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = tf.tensor1d([3, 4, 5]);
 *
 * tf.outerProduct(a, b).print();
 * ```
 * Implemented as a matrix product of `v1` as a column (`[n, 1]`) and `v2`
 * as a row (`[1, m]`), giving an `[n, m]` result.
 *
 * @param v1 The first vector in the outer product operation. Must be rank 1.
 * @param v2 The second vector in the outer product operation. Must be rank 1.
 *
 * @doc {heading: 'Operations', subheading: 'Matrices'}
 */
function outerProduct_(v1, v2) {
  var $v1 = convertToTensor(v1, 'v1', 'outerProduct');
  var $v2 = convertToTensor(v2, 'v2', 'outerProduct');
  var bothVectors = $v1.rank === 1 && $v2.rank === 1;
  assert$1(bothVectors, function () {
    return "Error in outerProduct: inputs must be rank 1, but got ranks " + "".concat($v1.rank, " and ").concat($v2.rank, ".");
  });
  return matMul$1(reshape$3($v1, [-1, 1]), reshape$3($v2, [1, -1]));
}
var outerProduct = /* @__PURE__ */op({
  outerProduct_: outerProduct_
});
36155
36156 /**
36157 * @license
36158 * Copyright 2020 Google LLC. All Rights Reserved.
36159 * Licensed under the Apache License, Version 2.0 (the "License");
36160 * you may not use this file except in compliance with the License.
36161 * You may obtain a copy of the License at
36162 *
36163 * http://www.apache.org/licenses/LICENSE-2.0
36164 *
36165 * Unless required by applicable law or agreed to in writing, software
36166 * distributed under the License is distributed on an "AS IS" BASIS,
36167 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36168 * See the License for the specific language governing permissions and
36169 * limitations under the License.
36170 * =============================================================================
36171 */
/**
 * Pads a `tf.Tensor` with a given value and paddings.
 *
 * This operation implements `CONSTANT` mode. For `REFLECT` and `SYMMETRIC`,
 * refer to `tf.mirrorPad`.
 *
 * Also available are stricter rank-specific methods with the same signature
 * as this method that assert that `paddings` is of given length.
 *   - `tf.pad1d`
 *   - `tf.pad2d`
 *   - `tf.pad3d`
 *   - `tf.pad4d`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * x.pad([[1, 2]]).print();
 * ```
 * @param x The tensor to pad. Scalars (rank 0) are rejected.
 * @param paddings An array of length `R` (the rank of the tensor), where
 *     each element is a length-2 tuple of ints `[padBefore, padAfter]`,
 *     specifying how much to pad along each dimension of the tensor.
 * @param constantValue The pad value to use. Defaults to 0.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function pad_(x, paddings) {
  var constantValue = arguments[2] === undefined ? 0 : arguments[2];
  var $x = convertToTensor(x, 'x', 'pad');
  if ($x.rank === 0) {
    throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
  }
  return ENGINE.runKernel(PadV2, {x: $x}, {paddings: paddings, constantValue: constantValue});
}
var pad = /* @__PURE__ */op({
  pad_: pad_
});
36215
/**
 * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.
 *
 * Here `paddings` is a single `[padBefore, padAfter]` pair; it is wrapped
 * into the per-dimension list that `pad` expects.
 */
function pad1d_(x, paddings) {
  var constantValue = arguments[2] === undefined ? 0 : arguments[2];
  var isPair = paddings.length === 2;
  assert$1(isPair, function () {
    return 'Invalid number of paddings. Must be length of 2.';
  });
  var perDimPaddings = [paddings];
  return pad(x, perDimPaddings, constantValue);
}
var pad1d = /* @__PURE__ */op({
  pad1d_: pad1d_
});
36229
/**
 * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.
 *
 * `paddings` must hold exactly two `[padBefore, padAfter]` pairs.
 */
function pad2d_(x, paddings) {
  var constantValue = arguments[2] === undefined ? 0 : arguments[2];
  // Same short-circuit order as checking each pair explicitly: stop at the
  // first malformed entry.
  var valid = paddings.length === 2;
  for (var dim = 0; valid && dim < 2; dim++) {
    valid = paddings[dim].length === 2;
  }
  assert$1(valid, function () {
    return 'Invalid number of paddings. Must be length of 2 each.';
  });
  return pad(x, paddings, constantValue);
}
var pad2d = /* @__PURE__ */op({
  pad2d_: pad2d_
});
36243
/**
 * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.
 *
 * `paddings` must hold exactly three `[padBefore, padAfter]` pairs.
 */
function pad3d_(x, paddings) {
  var constantValue = arguments[2] === undefined ? 0 : arguments[2];
  // Same short-circuit order as checking each pair explicitly: stop at the
  // first malformed entry.
  var valid = paddings.length === 3;
  for (var dim = 0; valid && dim < 3; dim++) {
    valid = paddings[dim].length === 2;
  }
  assert$1(valid, function () {
    return 'Invalid number of paddings. Must be length of 2 each.';
  });
  return pad(x, paddings, constantValue);
}
var pad3d = /* @__PURE__ */op({
  pad3d_: pad3d_
});
36257
/**
 * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.
 *
 * `paddings` must hold exactly four `[padBefore, padAfter]` pairs.
 */
function pad4d_(x, paddings) {
  var constantValue = arguments[2] === undefined ? 0 : arguments[2];
  // Same short-circuit order as checking each pair explicitly: stop at the
  // first malformed entry.
  var valid = paddings.length === 4;
  for (var dim = 0; valid && dim < 4; dim++) {
    valid = paddings[dim].length === 2;
  }
  assert$1(valid, function () {
    return 'Invalid number of paddings. Must be length of 2 each.';
  });
  return pad(x, paddings, constantValue);
}
var pad4d = /* @__PURE__ */op({
  pad4d_: pad4d_
});
36271
36272 /**
36273 * @license
36274 * Copyright 2020 Google LLC. All Rights Reserved.
36275 * Licensed under the Apache License, Version 2.0 (the "License");
36276 * you may not use this file except in compliance with the License.
36277 * You may obtain a copy of the License at
36278 *
36279 * http://www.apache.org/licenses/LICENSE-2.0
36280 *
36281 * Unless required by applicable law or agreed to in writing, software
36282 * distributed under the License is distributed on an "AS IS" BASIS,
36283 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36284 * See the License for the specific language governing permissions and
36285 * limitations under the License.
36286 * =============================================================================
36287 */
/**
 * Splits the "spatial" dimensions `[1, ..., M]` of the input into a grid of
 * blocks of shape `blockShape` and interleaves those blocks with the "batch"
 * dimension (0).
 *
 * In the output, the spatial dimensions `[1, ..., M]` give the position within
 * the grid. The batch dimension combines the position within a spatial block
 * with the original batch position. Before being divided into blocks, the
 * spatial dimensions of the input are optionally zero padded according to
 * `paddings`. See below for a precise description.
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * const blockShape = [2, 2];
 * const paddings = [[0, 0], [0, 0]];
 *
 * x.spaceToBatchND(blockShape, paddings).print();
 * ```
 *
 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
 * remainingShape`, where spatialShape has `M` dimensions.
 * @param blockShape A 1-D array. Must have shape `[M]`, and all values must be
 * >= 1.
 * @param paddings A 2-D array. Must have shape `[M, 2]`, and all values must
 * be >= 0. `paddings[i] = [padStart, padEnd]` gives the amount of zero padding
 * for input dimension `i + 1`, which is spatial dimension `i`. Each padded
 * extent must be a multiple of its block size:
 * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
 *
 * The operation is equivalent to the following steps:
 *
 * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
 * according to `paddings`, producing `padded` of shape paddedShape.
 *
 * 2. Reshape `padded` to `reshapedPadded` of shape:
 * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
 *
 * 3. Permute the dimensions of `reshapedPadded` to produce
 * `permutedReshapedPadded` of shape: `blockShape + [batch] +
 * [paddedShape[1] / blockShape[0], ..., paddedShape[M] / blockShape[M-1]] +
 * remainingShape`
 *
 * 4. Reshape `permutedReshapedPadded` to fold `blockShape` into the batch
 * dimension, giving an output tensor of shape:
 * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1]] + remainingShape`
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function spaceToBatchND_(x, blockShape, paddings) {
  var $x = convertToTensor(x, 'x', 'spaceToBatchND');
  var numBlockDims = blockShape.length;

  // The input needs one batch dimension plus one dimension per block entry.
  assert$1($x.rank >= numBlockDims + 1, function () {
    return "input rank ".concat($x.rank, " should be > than [blockShape] ").concat(blockShape.length);
  });
  assert$1(paddings.length === numBlockDims, function () {
    return "paddings.shape[0] ".concat(paddings.length, " must be equal to [blockShape] ").concat(blockShape.length);
  });

  // Every padded spatial extent must split evenly into blocks. Stop scanning
  // at the first dimension that does not.
  var divisible = true;
  for (var d = 1; divisible && d <= numBlockDims; d++) {
    var pads = paddings[d - 1];
    divisible = ($x.shape[d] + pads[0] + pads[1]) % blockShape[d - 1] === 0;
  }
  assert$1(divisible, function () {
    return "input spatial dimensions ".concat($x.shape.slice(1), " with paddings ").concat(paddings.toString(), " must be divisible by blockShapes ").concat(blockShape.toString());
  });

  return ENGINE.runKernel(SpaceToBatchND, {
    x: $x
  }, {
    blockShape: blockShape,
    paddings: paddings
  });
}
var spaceToBatchND$2 = /* @__PURE__ */op({
  spaceToBatchND_: spaceToBatchND_
});
36364
/**
 * Performs an N-D pooling operation.
 *
 * When either dilation rate is greater than 1, the pooling is emulated as
 * `batchToSpaceND(pool(spaceToBatchND(x)))`, because the TF kernels do not
 * support dilation > 1.
 *
 * @param input The input tensor, of rank 4 or rank 3, of shape
 * `[batch, height, width, inChannels]`. If rank 3, a batch of 1 is assumed
 * and the result is reshaped back to rank 3.
 * @param windowShape The filter size: `[filterHeight, filterWidth]`. If
 * `windowShape` is a single number, then `filterHeight == filterWidth`.
 * @param poolingType The type of pooling, either 'max' or 'avg'. Any value
 * other than 'avg' selects max pooling.
 * @param pad The type of padding algorithm (`0` is treated as `'valid'`):
 * - `same` and stride 1: output will be of same size as input,
 * regardless of filter size.
 * - `valid`: output will be smaller than input if filter is larger
 * than 1x1.
 * - For more info, see this guide:
 * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
 * https://www.tensorflow.org/api_guides/python/nn#Convolution)
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
 * at which input values are sampled across the height and width dimensions
 * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single
 * number, then `dilationHeight == dilationWidth`. If it is greater than
 * 1, then all values of `strides` must be 1.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`.
 * Defaults to `1`. If `strides` is a single number, then
 * `strideHeight == strideWidth`.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. Forwarded
 * to the underlying `avgPool`/`maxPool`; if none is provided, it will
 * default to truncate.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function pool_(input, windowShape, poolingType, pad, dilations, strides, dimRoundingMode) {
  dilations = dilations == null ? [1, 1] : dilations;
  strides = strides == null ? 1 : strides;
  pad = pad === 0 ? 'valid' : pad;

  var $x = convertToTensor(input, 'x', 'maxPool');
  // Rank-3 inputs are lifted to rank 4 here and reshaped back at the end.
  var reshapedTo4D = $x.rank === 3;
  var x4D = reshapedTo4D ? reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;

  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in pool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
  });

  var convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad);
  var dilation = [convInfo.dilationHeight, convInfo.dilationWidth];
  var isDilationOne = dilation[0] === 1 && dilation[1] === 1;

  // The following implementation does batchToSpace(pool(spaceToBatch(x)))
  // whenever dilation > 1 since the TF kernels do not support dilation > 1.
  // tslint:disable-next-line:max-line-length
  // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037
  var basePadding = pad === 'same' ?
    withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation) :
    [[0, 0], [0, 0]];
  var spaceToBatchParams = requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding);
  var adjustedPadding = spaceToBatchParams[0];
  var adjustedCrops = spaceToBatchParams[1];

  // On the dilated path, padding has already been applied by spaceToBatchND,
  // so the inner pool runs with 'valid'.
  var convertedPad = isDilationOne ? pad : 'valid';
  var convertedX = isDilationOne ? x4D : spaceToBatchND$2(x4D, dilation, adjustedPadding);

  var y;
  if (poolingType === 'avg') {
    y = avgPool$2(convertedX, windowShape, strides, convertedPad, dimRoundingMode);
  } else {
    y = maxPool$2(convertedX, windowShape, strides, convertedPad, dimRoundingMode);
  }

  var res = isDilationOne ? y : batchToSpaceND$2(y, dilation, adjustedCrops);
  if (!reshapedTo4D) {
    return res;
  }
  return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
// Helper function to compute crops and paddings for pool with dilation > 1.
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184
/**
 * Computes the paddings and crops needed to run `spaceToBatchND` /
 * `batchToSpaceND` with `blockShape` over a spatial input of `inputShape`.
 *
 * Each spatial dimension keeps its base padding. It then receives just enough
 * extra end padding that `inputShape[i] + padStart[i] + padEnd[i]` is a
 * multiple of `blockShape[i]`. That extra amount is returned as the end crop,
 * so the round trip restores the base-padded extent.
 *
 * @param inputShape Spatial input sizes, one per block dimension.
 * @param blockShape Block size per spatial dimension (values must be >= 1).
 * @param basePadding `[padStart, padEnd]` pairs, one per block dimension.
 * @return `[paddings, crops]`, each an array of `[start, end]` pairs.
 */
function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) {
  var padStart = basePadding.map(function (b) {
    return b[0];
  });
  var origPadEnd = basePadding.map(function (b) {
    return b[1];
  });
  // The divisibility requirement applies to the *padded* extent, i.e. the
  // element-wise sum input + padStart + padEnd, matching the TF reference
  // (`full_input_shape = input_shape + pad_start + orig_pad_end`). The old
  // `inputShape.concat(padStart, origPadEnd)` ignored the base padding.
  var padEndExtra = blockShape.map(function (b, i) {
    var paddedSize = inputShape[i] + padStart[i] + origPadEnd[i];
    return (b - paddedSize % b) % b;
  });
  var padEnd = origPadEnd.map(function (s, i) {
    return s + padEndExtra[i];
  });
  var paddings = blockShape.map(function (_, i) {
    return [padStart[i], padEnd[i]];
  });
  var crops = blockShape.map(function (_, i) {
    return [0, padEndExtra[i]];
  });
  return [paddings, crops];
}
// Helper function to compute base paddings for pool with dilation > 1.
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524
/**
 * Computes the 'same'-style base padding for each spatial dimension of a
 * dilated filter.
 *
 * @param filterShape Filter size per spatial dimension.
 * @param dilation Dilation rate per spatial dimension.
 * @return One `[padStart, padEnd]` pair per entry of `filterShape`.
 */
function withSpaceToBatchBasePaddings(filterShape, dilation) {
  return filterShape.map(function (size, i) {
    // Extent of the upsampled filter once (rate - 1) zeros are inserted
    // between consecutive filter taps.
    var dilatedSize = size + (size - 1) * (dilation[i] - 1);
    var totalPad = dilatedSize - 1;
    // When the total is odd, the extra element goes at the end, following the
    // same convention as conv2d.
    var before = Math.floor(totalPad / 2);
    return [before, totalPad - before];
  });
}
// Op handle for `pool`, produced by handing the `pool_` implementation to `op`.
var pool$1 = /* @__PURE__ */op({ pool_: pool_ });
36496
36497 /**
36498 * @license
36499 * Copyright 2020 Google LLC. All Rights Reserved.
36500 * Licensed under the Apache License, Version 2.0 (the "License");
36501 * you may not use this file except in compliance with the License.
36502 * You may obtain a copy of the License at
36503 *
36504 * http://www.apache.org/licenses/LICENSE-2.0
36505 *
36506 * Unless required by applicable law or agreed to in writing, software
36507 * distributed under the License is distributed on an "AS IS" BASIS,
36508 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36509 * See the License for the specific language governing permissions and
36510 * limitations under the License.
36511 * =============================================================================
36512 */
/**
 * Computes leaky rectified linear element-wise with parametric alphas:
 *
 * `f(x) = x < 0 ? alpha * x : x`
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 4]);
 * const alpha = tf.scalar(0.1);
 *
 * x.prelu(alpha).print(); // or tf.prelu(x, alpha)
 * ```
 * @param x The input tensor.
 * @param alpha Scaling factor for negative values.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function prelu_(x, alpha) {
  // Property order matters: `x` is converted before `alpha`.
  var inputs = {
    x: convertToTensor(x, 'x', 'prelu'),
    alpha: convertToTensor(alpha, 'alpha', 'prelu')
  };
  return ENGINE.runKernel(Prelu, inputs);
}
var prelu$3 = /* @__PURE__ */op({
  prelu_: prelu_
});
36541
36542 /**
36543 * @license
36544 * Copyright 2020 Google LLC. All Rights Reserved.
36545 * Licensed under the Apache License, Version 2.0 (the "License");
36546 * you may not use this file except in compliance with the License.
36547 * You may obtain a copy of the License at
36548 *
36549 * http://www.apache.org/licenses/LICENSE-2.0
36550 *
36551 * Unless required by applicable law or agreed to in writing, software
36552 * distributed under the License is distributed on an "AS IS" BASIS,
36553 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36554 * See the License for the specific language governing permissions and
36555 * limitations under the License.
36556 * =============================================================================
36557 */
/**
 * Computes the product of elements across dimensions of a `tf.Tensor`.
 *
 * Reduces the input along the dimensions given in `axis`. Unless `keepDims`
 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
 * `axis`. If `keepDims` is true, the reduced dimensions are retained with
 * length 1. If `axis` is not given, all dimensions are reduced, and a
 * `tf.Tensor` with a single element is returned.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3]);
 *
 * x.prod().print(); // or tf.prod(x)
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * const axis = 1;
 * x.prod(axis).print(); // or tf.prod(x, axis)
 * ```
 *
 * @param x The input tensor to compute the product over. If the dtype is `bool`
 * it will be converted to `int32` and the output dtype will be `int32`.
 * @param axis The dimension(s) to reduce. Defaults to `null`, which reduces
 * all dimensions.
 * @param keepDims If true, retains reduced dimensions with size 1. Defaults to
 * `false`.
 *
 * @doc {heading: 'Operations', subheading: 'Reduction'}
 */
function prod_(x) {
  // Optional positional arguments; an explicit `undefined` also selects the
  // default.
  var axis = arguments[1] === undefined ? null : arguments[1];
  var keepDims = arguments[2] === undefined ? false : arguments[2];
  var $x = convertToTensor(x, 'x', 'prod');
  // bool is not an allowed type for the underlying kernel, so promote it.
  var kernelInput = $x.dtype === 'bool' ? cast$3($x, 'int32') : $x;
  return ENGINE.runKernel(Prod, {
    x: kernelInput
  }, {
    axis: axis,
    keepDims: keepDims
  });
}
var prod$2 = /* @__PURE__ */op({
  prod_: prod_
});
36608
36609 /**
36610 * @license
36611 * Copyright 2022 Google LLC. All Rights Reserved.
36612 * Licensed under the Apache License, Version 2.0 (the "License");
36613 * you may not use this file except in compliance with the License.
36614 * You may obtain a copy of the License at
36615 *
36616 * http://www.apache.org/licenses/LICENSE-2.0
36617 *
36618 * Unless required by applicable law or agreed to in writing, software
36619 * distributed under the License is distributed on an "AS IS" BASIS,
36620 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36621 * See the License for the specific language governing permissions and
36622 * limitations under the License.
36623 * =============================================================================
36624 */
/**
 * Runs the RaggedGather kernel on a ragged tensor described by its nested row
 * splits and dense values, selecting rows with `indices`.
 *
 * @param paramsNestedSplits Array of row-split tensors, each converted to
 * `int32`.
 * @param paramsDenseValues The flat values of the ragged tensor.
 * @param indices Indices to gather, converted to `int32`.
 * @param outputRaggedRank Forwarded to the kernel as the `outputRaggedRank`
 * attribute.
 * @return `{outputNestedSplits, outputDenseValues}`. `outputNestedSplits` is
 * every kernel output except the last, and `outputDenseValues` is the last
 * kernel output.
 */
function raggedGather_(paramsNestedSplits, paramsDenseValues, indices, outputRaggedRank) {
  var inputs = {
    paramsNestedSplits: paramsNestedSplits.map(function (split, idx) {
      return convertToTensor(split, "tensors".concat(idx), 'raggedGather', 'int32');
    }),
    paramsDenseValues: convertToTensor(paramsDenseValues, 'paramsDenseValues', 'raggedGather'),
    indices: convertToTensor(indices, 'indices', 'raggedGather', 'int32')
  };
  var outputs = ENGINE.runKernel(RaggedGather, inputs, {
    outputRaggedRank: outputRaggedRank
  });
  var lastIndex = outputs.length - 1;
  return {
    outputNestedSplits: outputs.slice(0, lastIndex),
    outputDenseValues: outputs[lastIndex]
  };
}
var raggedGather$2 = /* @__PURE__ */op({
  raggedGather_: raggedGather_
});
36648
36649 /**
36650 * @license
36651 * Copyright 2022 Google LLC.
36652 * Licensed under the Apache License, Version 2.0 (the "License");
36653 * you may not use this file except in compliance with the License.
36654 * You may obtain a copy of the License at
36655 *
36656 * http://www.apache.org/licenses/LICENSE-2.0
36657 *
36658 * Unless required by applicable law or agreed to in writing, software
36659 * distributed under the License is distributed on an "AS IS" BASIS,
36660 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36661 * See the License for the specific language governing permissions and
36662 * limitations under the License.
36663 * =============================================================================
36664 */
/**
 * Builds a ragged tensor whose row `i` is the sequence that starts at
 * `starts[i]` and steps by `deltas[i]` toward `limits[i]`. The result is
 * returned as its nested splits and dense values.
 *
 * NOTE(review): presumably `limits[i]` is exclusive, as in TF's `RaggedRange`.
 * Confirm against the kernel implementation.
 *
 * @param starts A Tensor of dtype 'float32' or 'int32'. The start of each
 * range.
 * @param limits A Tensor converted to the same dtype as `starts`. The limit of
 * each range.
 * @param deltas A Tensor converted to the same dtype as `starts`. The step of
 * each range.
 * @return An object with the following properties:
 * - rtNestedSplits: the first kernel output (a Tensor of type 'int32').
 * - rtDenseValues: the second kernel output (same type as `starts`).
 */
function raggedRange_(starts, limits, deltas) {
  var opName = 'raggedRange';
  var $starts = convertToTensor(starts, 'starts', opName);
  var dtype = $starts.dtype;
  var outputs = ENGINE.runKernel(RaggedRange, {
    starts: $starts,
    limits: convertToTensor(limits, 'limits', opName, dtype),
    deltas: convertToTensor(deltas, 'deltas', opName, dtype)
  });
  return {
    rtNestedSplits: outputs[0],
    rtDenseValues: outputs[1]
  };
}
var raggedRange$2 = /* @__PURE__ */op({
  raggedRange_: raggedRange_
});
36697
36698 /**
36699 * @license
36700 * Copyright 2022 Google LLC. All Rights Reserved.
36701 * Licensed under the Apache License, Version 2.0 (the "License");
36702 * you may not use this file except in compliance with the License.
36703 * You may obtain a copy of the License at
36704 *
36705 * http://www.apache.org/licenses/LICENSE-2.0
36706 *
36707 * Unless required by applicable law or agreed to in writing, software
36708 * distributed under the License is distributed on an "AS IS" BASIS,
36709 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36710 * See the License for the specific language governing permissions and
36711 * limitations under the License.
36712 * =============================================================================
36713 */
/**
 * Creates a dense tensor from a ragged tensor, possibly altering its shape.
 *
 * The raggedTensorToTensor op creates a dense tensor from an array of row
 * partition tensors, a value vector, and default values. If the shape is
 * unspecified, the minimal shape required to contain all the elements in the
 * ragged tensor (the natural shape) is used. If some dimensions are left
 * unspecified, the size of the natural shape is used in those dimensions.
 *
 * The defaultValue is broadcast to the output shape, and then the values from
 * the ragged tensor overwrite the default values. The defaultValue must have
 * fewer dimensions than the values.
 *
 * The row partition tensors are in the order of the dimensions. At present,
 * the types can be:
 * - "ROW_SPLITS": the row_splits tensor from the ragged tensor.
 * - "VALUE_ROWIDS": the value_rowids tensor from the ragged tensor.
 * - "FIRST_DIM_SIZE": if value_rowids is used for the first dimension, then it
 * is preceded by "FIRST_DIM_SIZE".
 *
 * @param shape A Tensor of type 'int32'. The desired shape of the output
 * tensor. If left unspecified (empty), the natural shape is used. If some
 * dimensions are left unspecified, the natural size is used in those
 * dimensions.
 *
 * Dense dimensions cannot be modified by the shape argument; trying to change
 * the size of a dense dimension will cause the op to fail. Examples:
 * - natural shape: [4, 5, 6], shape: -1, output shape: [4, 5, 6]
 * - natural shape: [4, 5, 6], shape: [3, -1, 2], output shape: [3, 5, 2]
 * - natural shape: [4, 5, 6], shape: [3, 7, 2], output shape: [3, 7, 2]
 * @param values A 1D Tensor holding the values of the ragged tensor.
 * @param defaultValue A Tensor converted to the same dtype as `values`. It
 * fills the output where the shape is larger than the ragged tensor. It is
 * broadcast to the output shape and then overwritten by values from the
 * ragged tensor, so it must be compatible with that broadcast and must have
 * fewer dimensions than `values`.
 * @param rowPartitionTensors A list of at least 1 Tensor, each converted to
 * 'int32'.
 * @param rowPartitionTypes A list of strings naming the type of each row
 * partition tensor, in the order of the dimensions. At present, these can be:
 * "ROW_SPLITS", "VALUE_ROWIDS", or "FIRST_DIM_SIZE" (see above).
 * @return A Tensor with the same type as `values`.
 * @doc {heading: 'Operations', subheading: 'Ragged'}
 */
function raggedTensorToTensor_(shape, values, defaultValue, rowPartitionTensors, rowPartitionTypes) {
  var opName = 'raggedTensorToTensor';
  var $shape = convertToTensor(shape, 'shape', opName, 'int32');
  var $values = convertToTensor(values, 'values', opName);
  // defaultValue is coerced to the dtype of `values`, so `values` must be
  // converted first.
  var inputs = {
    shape: $shape,
    values: $values,
    defaultValue: convertToTensor(defaultValue, 'defaultValue', opName, $values.dtype),
    rowPartitionTensors: rowPartitionTensors.map(function (partition, idx) {
      return convertToTensor(partition, "tensors".concat(idx), opName, 'int32');
    })
  };
  return ENGINE.runKernel(RaggedTensorToTensor, inputs, {
    rowPartitionTypes: rowPartitionTypes
  });
}
var raggedTensorToTensor$2 = /* @__PURE__ */op({
  raggedTensorToTensor_: raggedTensorToTensor_
});
36788
36789 /**
36790 * @license
36791 * Copyright 2020 Google LLC. All Rights Reserved.
36792 * Licensed under the Apache License, Version 2.0 (the "License");
36793 * you may not use this file except in compliance with the License.
36794 * You may obtain a copy of the License at
36795 *
36796 * http://www.apache.org/licenses/LICENSE-2.0
36797 *
36798 * Unless required by applicable law or agreed to in writing, software
36799 * distributed under the License is distributed on an "AS IS" BASIS,
36800 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36801 * See the License for the specific language governing permissions and
36802 * limitations under the License.
36803 * =============================================================================
36804 */
/**
 * Creates a `tf.Tensor` filled with values drawn from a user-supplied random
 * number generator function.
 *
 * @param shape An array of non-negative integers defining the output tensor
 * shape.
 * @param randFunction Called once per output element; its return value is
 * stored in the backing typed array.
 * @param dtype The data type of the output tensor: 'float32' (the default when
 * omitted), 'int32', or 'bool'. Any other value throws.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function rand_(shape, randFunction, dtype) {
  assertNonNegativeIntegerDimensions(shape);
  var size = sizeFromShape(shape);
  var values;
  switch (dtype == null ? 'float32' : dtype) {
    case 'float32':
      values = new Float32Array(size);
      break;
    case 'int32':
      values = new Int32Array(size);
      break;
    case 'bool':
      values = new Uint8Array(size);
      break;
    default:
      throw new Error("Unknown data type ".concat(dtype));
  }
  for (var idx = 0; idx < size; idx++) {
    values[idx] = randFunction();
  }
  // The caller's dtype is passed through unchanged (possibly null/undefined).
  return ENGINE.makeTensor(values, shape, dtype);
}
var rand = /* @__PURE__ */op({
  rand_: rand_
});
36837
// CommonJS-style module shim for the bundled `alea` PRNG. The IIFE below
// receives this object as `module` and assigns the generator factory (`impl`)
// to `alea$3.exports`.
var alea$3 = {exports: {}};

// NOTE(review): `alea$1` is not referenced within this chunk; it looks like a
// bundler artifact. Confirm before relying on or removing it.
var alea$1 = alea$3.exports;
(function (module) {
  // A port of an algorithm by Johannes Baagøe <baagoe@baagoe.com>, 2010
  // http://baagoe.com/en/RandomMusings/javascript/
  // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror
  // Original work is under MIT license -

  // Copyright (C) 2010 by Johannes Baagøe <baagoe@baagoe.org>
  //
  // Permission is hereby granted, free of charge, to any person obtaining a copy
  // of this software and associated documentation files (the "Software"), to deal
  // in the Software without restriction, including without limitation the rights
  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  // copies of the Software, and to permit persons to whom the Software is
  // furnished to do so, subject to the following conditions:
  //
  // The above copyright notice and this permission notice shall be included in
  // all copies or substantial portions of the Software.
  //
  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  // THE SOFTWARE.

  (function (global, module, define) {
    // Generator state: three fractional seeds (s0, s1, s2) and a carry (c).
    // All arithmetic below is order-sensitive; any change alters the sequence.
    function Alea(seed) {
      var me = this,
        mash = Mash();
      // Advance the state. The new s2 is the fractional part of t, and the
      // integer part becomes the next carry.
      me.next = function () {
        var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32
        me.s0 = me.s1;
        me.s1 = me.s2;
        return me.s2 = t - (me.c = t | 0);
      };

      // Apply the seeding algorithm from Baagoe.
      // `mash` keeps internal state between calls, so the three mash(' ')
      // calls (and the three mash(seed) calls) each yield a different value.
      me.c = 1;
      me.s0 = mash(' ');
      me.s1 = mash(' ');
      me.s2 = mash(' ');
      // Subtract a seed-derived value and wrap back into [0, 1).
      me.s0 -= mash(seed);
      if (me.s0 < 0) {
        me.s0 += 1;
      }
      me.s1 -= mash(seed);
      if (me.s1 < 0) {
        me.s1 += 1;
      }
      me.s2 -= mash(seed);
      if (me.s2 < 0) {
        me.s2 += 1;
      }
      // The hash is only needed for seeding; drop the local reference.
      mash = null;
    }
    // Copy generator state (c, s0, s1, s2) from `f` onto `t` and return `t`.
    function copy(f, t) {
      t.c = f.c;
      t.s0 = f.s0;
      t.s1 = f.s1;
      t.s2 = f.s2;
      return t;
    }
    // Factory exported by this module. Returns a function producing floats
    // from `Alea#next`, augmented with:
    //   int32()  - the next value scaled by 2^32 and truncated to a signed
    //              32-bit integer;
    //   double() - combines two draws for extra (2^-53) resolution;
    //   quick    - alias of the base function;
    //   state()  - only when `opts.state` is truthy; returns a snapshot of
    //              the generator state.
    // If `opts.state` is an object, it is loaded as the starting state.
    function impl(seed, opts) {
      var xg = new Alea(seed),
        state = opts && opts.state,
        prng = xg.next;
      prng.int32 = function () {
        return xg.next() * 0x100000000 | 0;
      };
      prng.double = function () {
        return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53
      };

      prng.quick = prng;
      if (state) {
        if (_typeof(state) == 'object') copy(state, xg);
        prng.state = function () {
          return copy(xg, {});
        };
      }
      return prng;
    }
    // Baagoe's string hash. Each call folds String(data) into the closure
    // variable `n` and returns (n >>> 0) * 2^-32, a value in [0, 1).
    function Mash() {
      var n = 0xefc8249d;
      var mash = function mash(data) {
        data = String(data);
        for (var i = 0; i < data.length; i++) {
          n += data.charCodeAt(i);
          var h = 0.02519603282416938 * n;
          n = h >>> 0;
          h -= n;
          h *= n;
          n = h >>> 0;
          h -= n;
          n += h * 0x100000000; // 2^32
        }

        return (n >>> 0) * 2.3283064365386963e-10; // 2^-32
      };

      return mash;
    }
    // Export glue. In this bundle `module` is always the `alea$3` shim (see
    // the call arguments below), so the first branch is the one taken.
    // NOTE(review): the final `this.alea` branch would throw under 'use strict'
    // (`this` is undefined here), but it is unreachable in this bundle.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function () {
        return impl;
      });
    } else {
      this.alea = impl;
    }
  })(commonjsGlobal, 'object' == 'object' && module,
  // present in node.js
  typeof undefined == 'function' && undefined // present with an AMD loader
  );
})(alea$3);
var aleaExports = alea$3.exports;
// `impl` has no `__esModule` flag, so getDefaultExportFromCjs returns it as is.
var alea$2 = /*@__PURE__*/getDefaultExportFromCjs(aleaExports);
36960
// CommonJS-style module shim for the bundled `xor128` PRNG. The IIFE below
// assigns the generator factory (`impl`) to `xor128$3.exports`.
var xor128$3 = {exports: {}};

// NOTE(review): `xor128$1` is not referenced within this chunk; it looks like a
// bundler artifact. Confirm before removing.
var xor128$1 = xor128$3.exports;
(function (module) {
  // A JavaScript implementation of the "xor128" prng algorithm by
  // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

  (function (global, module, define) {
    // Generator state: four 32-bit words x, y, z, w. Bit-exact; do not reorder.
    function XorGen(seed) {
      var me = this,
        strseed = '';
      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;

      // Set up generator function.
      // Shift the words down by one and return the new w as a signed 32-bit
      // int (the result of `^=`).
      me.next = function () {
        var t = me.x ^ me.x << 11;
        me.x = me.y;
        me.y = me.z;
        me.z = me.w;
        return me.w ^= me.w >>> 19 ^ t ^ t >>> 8;
      };
      if (seed === (seed | 0)) {
        // Integer seed.
        me.x = seed;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 64 values.
      // Past the end of the string, charCodeAt returns NaN and `| 0` maps it
      // to 0, so those iterations only advance the generator.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }
    // Copy generator state (x, y, z, w) from `f` onto `t` and return `t`.
    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      return t;
    }
    // Factory exported by this module. The returned function maps the next
    // 32-bit output to [0, 1). It is augmented with:
    //   double() - 21 high bits of one draw plus a second draw as the fraction,
    //              scaled by 2^-21; retried while the result is exactly 0;
    //   int32    - the raw generator step;
    //   quick    - alias of the base function;
    //   state()  - only when `opts.state` is truthy; returns a state snapshot.
    // If `opts.state` is an object, it is loaded as the starting state.
    function impl(seed, opts) {
      var xg = new XorGen(seed),
        state = opts && opts.state,
        prng = function prng() {
          return (xg.next() >>> 0) / 0x100000000;
        };
      prng.double = function () {
        do {
          var top = xg.next() >>> 11,
            bot = (xg.next() >>> 0) / 0x100000000,
            result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (_typeof(state) == 'object') copy(state, xg);
        prng.state = function () {
          return copy(xg, {});
        };
      }
      return prng;
    }
    // Export glue. `module` is always the `xor128$3` shim here, so the first
    // branch is taken. NOTE(review): the `this.xor128` fallback would throw
    // under 'use strict', but it is unreachable in this bundle.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function () {
        return impl;
      });
    } else {
      this.xor128 = impl;
    }
  })(commonjsGlobal, 'object' == 'object' && module,
  // present in node.js
  typeof undefined == 'function' && undefined // present with an AMD loader
  );
})(xor128$3);
var xor128Exports = xor128$3.exports;
// `impl` has no `__esModule` flag, so getDefaultExportFromCjs returns it as is.
var xor128$2 = /*@__PURE__*/getDefaultExportFromCjs(xor128Exports);
37046
// CommonJS-style module shim for the bundled `xorwow` PRNG. The IIFE below
// assigns the generator factory (`impl`) to `xorwow$3.exports`.
var xorwow$3 = {exports: {}};

// NOTE(review): `xorwow$1` is not referenced within this chunk; it looks like a
// bundler artifact. Confirm before removing.
var xorwow$1 = xorwow$3.exports;
(function (module) {
  // A JavaScript implementation of the "xorwow" prng algorithm by
  // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

  (function (global, module, define) {
    // Generator state: five xorshift words (x, y, z, w, v) plus a Weyl counter
    // `d`. Bit-exact; do not reorder.
    function XorGen(seed) {
      var me = this,
        strseed = '';

      // Set up generator function.
      // Step the xorshift words, bump the Weyl counter by 362437, and return
      // their sum truncated to a signed 32-bit int.
      me.next = function () {
        var t = me.x ^ me.x >>> 2;
        me.x = me.y;
        me.y = me.z;
        me.z = me.w;
        me.w = me.v;
        return (me.d = me.d + 362437 | 0) + (me.v = me.v ^ me.v << 4 ^ (t ^ t << 1)) | 0;
      };
      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;
      me.v = 0;
      if (seed === (seed | 0)) {
        // Integer seed.
        me.x = seed;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 64 values.
      // `me.d` is first assigned when k reaches strseed.length. Before that
      // (string seeds only), `undefined + 362437 | 0` evaluates to 0 inside
      // next(), so the counter starts from 0 until it is reset here.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        if (k == strseed.length) {
          me.d = me.x << 10 ^ me.x >>> 4;
        }
        me.next();
      }
    }
    // Copy generator state (x, y, z, w, v, d) from `f` onto `t` and return `t`.
    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      t.v = f.v;
      t.d = f.d;
      return t;
    }
    // Factory exported by this module. The returned function maps the next
    // 32-bit output to [0, 1). It is augmented with:
    //   double() - 21 high bits of one draw plus a second draw as the fraction,
    //              scaled by 2^-21; retried while the result is exactly 0;
    //   int32    - the raw generator step;
    //   quick    - alias of the base function;
    //   state()  - only when `opts.state` is truthy; returns a state snapshot.
    // If `opts.state` is an object, it is loaded as the starting state.
    function impl(seed, opts) {
      var xg = new XorGen(seed),
        state = opts && opts.state,
        prng = function prng() {
          return (xg.next() >>> 0) / 0x100000000;
        };
      prng.double = function () {
        do {
          var top = xg.next() >>> 11,
            bot = (xg.next() >>> 0) / 0x100000000,
            result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (_typeof(state) == 'object') copy(state, xg);
        prng.state = function () {
          return copy(xg, {});
        };
      }
      return prng;
    }
    // Export glue. `module` is always the `xorwow$3` shim here, so the first
    // branch is taken. NOTE(review): the `this.xorwow` fallback would throw
    // under 'use strict', but it is unreachable in this bundle.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function () {
        return impl;
      });
    } else {
      this.xorwow = impl;
    }
  })(commonjsGlobal, 'object' == 'object' && module,
  // present in node.js
  typeof undefined == 'function' && undefined // present with an AMD loader
  );
})(xorwow$3);
var xorwowExports = xorwow$3.exports;
// `impl` has no `__esModule` flag, so getDefaultExportFromCjs returns it as is.
var xorwow$2 = /*@__PURE__*/getDefaultExportFromCjs(xorwowExports);
37139
// CommonJS shim for the bundled "xorshift7" generator: the wrapper below
// assigns the generator factory to `xorshift7$3.exports`.
var xorshift7$3 = {exports: {}};

// Bundler artifact: a reference to the initial, still-empty exports object.
var xorshift7$1 = xorshift7$3.exports;
(function (module) {
  // A JavaScript implementation of the "xorshift7" algorithm by
  // François Panneton and Pierre L'Ecuyer:
  // "On the Xorshift Random Number Generators"
  // http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf

  (function (global, module, define) {
    /**
     * Generator state: an 8-word array `x` used as a ring buffer and the
     * current ring index `i`.
     * @param seed An int32 seed initializes x[0]; any other value is
     *     stringified and its character codes are mixed into the array.
     */
    function XorGen(seed) {
      var me = this;

      // Set up generator function: combines five shifted taps of the ring
      // buffer (offsets 0, 1, 3, 4, 7) into a new word, stores it at index
      // i, advances i modulo 8, and returns the new word.
      me.next = function () {
        // Update xor generator. (`w` is declared but never used.)
        var X = me.x,
          i = me.i,
          t,
          v,
          w;
        t = X[i];
        t ^= t >>> 7;
        v = t ^ t << 24;
        t = X[i + 1 & 7];
        v ^= t ^ t >>> 10;
        t = X[i + 3 & 7];
        v ^= t ^ t >>> 3;
        t = X[i + 4 & 7];
        v ^= t ^ t << 7;
        t = X[i + 7 & 7];
        t = t ^ t << 13;
        v ^= t ^ t << 9;
        X[i] = v;
        me.i = i + 1 & 7;
        return v;
      };
      // Fills the 8-word state from `seed`, guarantees it is not all zero,
      // then warms the generator up by discarding 256 outputs. The local `w`
      // is assigned below but never read afterwards.
      function init(me, seed) {
        var j,
          w,
          X = [];
        if (seed === (seed | 0)) {
          // Seed state array using a 32-bit integer.
          w = X[0] = seed;
        } else {
          // Seed state using a string. Slots not yet written read as
          // undefined; the shift operators coerce those (and the NaN they
          // produce in the addition) to 0.
          seed = '' + seed;
          for (j = 0; j < seed.length; ++j) {
            X[j & 7] = X[j & 7] << 15 ^ seed.charCodeAt(j) + X[j + 1 & 7] << 13;
          }
        }
        // Enforce an array length of 8, not all zeroes.
        while (X.length < 8) X.push(0);
        for (j = 0; j < 8 && X[j] === 0; ++j);
        if (j == 8) w = X[7] = -1; else w = X[j];
        me.x = X;
        me.i = 0;

        // Discard an initial 256 values.
        for (j = 256; j > 0; --j) {
          me.next();
        }
      }
      init(me, seed);
    }
    // Copies the state from `f` onto `t` (cloning the array) and returns `t`.
    function copy(f, t) {
      t.x = f.x.slice();
      t.i = f.i;
      return t;
    }
    /**
     * Public factory. A null/undefined seed is replaced by the current time
     * in milliseconds.
     *   prng()        -> float in [0, 1) built from one 32-bit output.
     *   prng.double() -> float in (0, 1) built from 53 bits of two outputs.
     *   prng.int32()  -> the raw signed 32-bit output.
     *   prng.quick    -> alias of prng.
     * When opts.state is truthy, prng.state() returns a snapshot; a state
     * object that has an `x` array is loaded first.
     */
    function impl(seed, opts) {
      if (seed == null) seed = +new Date();
      var xg = new XorGen(seed),
        state = opts && opts.state,
        prng = function prng() {
          return (xg.next() >>> 0) / 0x100000000;
        };
      // 21 high bits of one output plus a 32-bit fraction from the next give
      // 53 bits; a zero result is rejected and redrawn.
      prng.double = function () {
        do {
          var top = xg.next() >>> 11,
            bot = (xg.next() >>> 0) / 0x100000000,
            result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (state.x) copy(state, xg);
        prng.state = function () {
          return copy(xg, {});
        };
      }
      return prng;
    }
    // Export selection. As rewritten by the bundler, `module` is always the
    // shim object and `define` is always `false`, so the first branch runs.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function () {
        return impl;
      });
    } else {
      this.xorshift7 = impl;
    }
  })(commonjsGlobal, 'object' == 'object' && module,
  // present in node.js
  typeof undefined == 'function' && undefined // present with an AMD loader
  );
})(xorshift7$3);
// Resolve the factory assigned above, unwrapping an ES-module `default`
// export if one is present.
var xorshift7Exports = xorshift7$3.exports;
var xorshift7$2 = /*@__PURE__*/getDefaultExportFromCjs(xorshift7Exports);
37251
// CommonJS shim for the bundled "xor4096" generator: the wrapper below
// assigns the generator factory to `xor4096$3.exports`.
var xor4096$3 = {exports: {}};

// Bundler artifact: a reference to the initial, still-empty exports object.
var xor4096$1 = xor4096$3.exports;
(function (module) {
  // A JavaScript implementation of Richard Brent's Xorgens xor4096 algorithm.
  //
  // This fast non-cryptographic random number generator is designed for
  // use in Monte-Carlo algorithms. It combines a long-period xorshift
  // generator with a Weyl generator, and it passes all common batteries
  // of statistical tests for randomness while consuming only a few nanoseconds
  // for each prng generated. For background on the generator, see Brent's
  // paper: "Some long-period random number generators using shifts and xors."
  // http://arxiv.org/pdf/1004.3115v1.pdf
  //
  // Usage:
  //
  // var xor4096 = require('xor4096');
  // random = xor4096(1); // Seed with int32 or string.
  // assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.
  // assert.equal(random.int32(), 1806534897); // signed int32, 32 bits.
  //
  // NOTE(review): in this implementation random() is built from a single
  // 32-bit output in [0, 1); 53-bit values come from random.double(). Confirm
  // which call the example value above refers to.
  //
  // For nonzero numeric keys, this implementation provides a sequence
  // identical to that by Brent's xorgens 3 implementation in C. This
  // implementation also provides for initializing the generator with
  // string seeds, or for saving and restoring the state of the generator.
  //
  // On Chrome, this prng benchmarks about 2.1 times slower than
  // JavaScript's built-in Math.random().

  (function (global, module, define) {
    /**
     * Generator state: a 128-word ring buffer `X`, its index `i`, and the
     * Weyl value `w`.
     * @param seed An int32 seed or any other value, which is stringified.
     */
    function XorGen(seed) {
      var me = this;

      // Set up generator function.
      me.next = function () {
        var w = me.w,
          X = me.X,
          i = me.i,
          t,
          v;
        // Update Weyl generator.
        me.w = w = w + 0x61c88647 | 0;
        // Update xor generator (indices wrap modulo 128 via `& 127`).
        v = X[i + 34 & 127];
        t = X[i = i + 1 & 127];
        v ^= v << 13;
        t ^= t << 17;
        v ^= v >>> 15;
        t ^= t >>> 12;
        // Update Xor generator array state.
        v = X[i] = v ^ t;
        me.i = i;
        // Result is the combination.
        return v + (w ^ w >>> 16) | 0;
      };
      // Builds X and w from `seed`, forces a nonzero state, then runs 512
      // warm-up steps of the xor generator.
      function init(me, seed) {
        var t,
          v,
          i,
          j,
          w,
          X = [],
          limit = 128;
        if (seed === (seed | 0)) {
          // Numeric seeds initialize v, which is used to generate X.
          v = seed;
          seed = null;
        } else {
          // String seeds are mixed into v and X one character at a time.
          seed = seed + '\0';
          v = 0;
          limit = Math.max(limit, seed.length);
        }
        // Initialize circular array and weyl value.
        for (i = 0, j = -32; j < limit; ++j) {
          // Put the unicode characters into the array, and shuffle them.
          if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
          // After 32 shuffles, take v as the starting w value.
          if (j === 0) w = v;
          v ^= v << 10;
          v ^= v >>> 15;
          v ^= v << 4;
          v ^= v >>> 13;
          if (j >= 0) {
            w = w + 0x61c88647 | 0; // Weyl.
            t = X[j & 127] ^= v + w; // Combine xor and weyl to init array.
            i = 0 == t ? i + 1 : 0; // Count consecutive zero words written.
          }
        }
        // 128 zero words in a row means every slot is zero; make the key nonzero.
        if (i >= 128) {
          X[(seed && seed.length || 0) & 127] = -1;
        }
        // Run the generator 512 times to further mix the state before using it.
        // Factoring this as a function slows the main generator, so it is just
        // unrolled here. The weyl generator is not advanced while warming up.
        i = 127;
        for (j = 4 * 128; j > 0; --j) {
          v = X[i + 34 & 127];
          t = X[i = i + 1 & 127];
          v ^= v << 13;
          t ^= t << 17;
          v ^= v >>> 15;
          t ^= t >>> 12;
          X[i] = v ^ t;
        }
        // Storing state as object members is faster than using closure variables.
        me.w = w;
        me.X = X;
        me.i = i;
      }
      init(me, seed);
    }
    // Copies the state from `f` onto `t` (cloning the array) and returns `t`.
    function copy(f, t) {
      t.i = f.i;
      t.w = f.w;
      t.X = f.X.slice();
      return t;
    }
    ;
    /**
     * Public factory. A null/undefined seed is replaced by the current time
     * in milliseconds.
     *   prng()        -> float in [0, 1) built from one 32-bit output.
     *   prng.double() -> float in (0, 1) built from 53 bits of two outputs.
     *   prng.int32()  -> the raw signed 32-bit output.
     *   prng.quick    -> alias of prng.
     * When opts.state is truthy, prng.state() returns a snapshot; a state
     * object that has an `X` array is loaded first.
     */
    function impl(seed, opts) {
      if (seed == null) seed = +new Date();
      var xg = new XorGen(seed),
        state = opts && opts.state,
        prng = function prng() {
          return (xg.next() >>> 0) / 0x100000000;
        };
      // 21 high bits of one output plus a 32-bit fraction from the next give
      // 53 bits; a zero result is rejected and redrawn.
      prng.double = function () {
        do {
          var top = xg.next() >>> 11,
            bot = (xg.next() >>> 0) / 0x100000000,
            result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (state.X) copy(state, xg);
        prng.state = function () {
          return copy(xg, {});
        };
      }
      return prng;
    }
    // Export selection. As rewritten by the bundler, `module` is always the
    // shim object and `define` is always `false`, so the first branch runs.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function () {
        return impl;
      });
    } else {
      this.xor4096 = impl;
    }
  })(commonjsGlobal,
  // window object or global
  'object' == 'object' && module,
  // present in node.js
  typeof undefined == 'function' && undefined // present with an AMD loader
  );
})(xor4096$3);
// Resolve the factory assigned above, unwrapping an ES-module `default`
// export if one is present.
var xor4096Exports = xor4096$3.exports;
var xor4096$2 = /*@__PURE__*/getDefaultExportFromCjs(xor4096Exports);
37415
// CommonJS shim for the bundled "tychei" generator: the wrapper below assigns
// the generator factory to `tychei$3.exports`.
var tychei$3 = {exports: {}};

// Bundler artifact: a reference to the initial, still-empty exports object.
var tychei$1 = tychei$3.exports;
(function (module) {
  // A JavaScript implementation of the "Tyche-i" prng algorithm by
  // Samuel Neves and Filipe Araujo.
  // See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf

  (function (global, module, define) {
    /**
     * Generator state: four 32-bit words a, b, c, d.
     * @param seed An integral number seeds a (seed / 2^32, truncated to
     *     int32) and b (its low 32 bits). Any other value (including
     *     undefined, which becomes the string 'undefined') is stringified
     *     and its character codes are xor-ed into b.
     */
    function XorGen(seed) {
      var me = this,
        strseed = '';

      // Set up generator function: one round of rotate/xor/subtract mixing
      // over (a, b, c, d); returns the new `a` as a signed 32-bit integer.
      me.next = function () {
        var b = me.b,
          c = me.c,
          d = me.d,
          a = me.a;
        b = b << 25 ^ b >>> 7 ^ c;
        c = c - d | 0;
        d = d << 24 ^ d >>> 8 ^ a;
        a = a - b | 0;
        me.b = b = b << 20 ^ b >>> 12 ^ c;
        me.c = c = c - d | 0;
        me.d = d << 16 ^ c >>> 16 ^ a;
        return me.a = a - b | 0;
      };

      /* The following is non-inverted tyche, which has better internal
       * bit diffusion, but which is about 25% slower than tyche-i in JS.
      me.next = function() {
        var a = me.a, b = me.b, c = me.c, d = me.d;
        a = (me.a + me.b | 0) >>> 0;
        d = me.d ^ a; d = d << 16 ^ d >>> 16;
        c = me.c + d | 0;
        b = me.b ^ c; b = b << 12 ^ d >>> 20;
        me.a = a = a + b | 0;
        d = d ^ a; me.d = d = d << 8 ^ d >>> 24;
        me.c = c = c + d | 0;
        b = b ^ c;
        return me.b = (b << 7 ^ b >>> 25);
      }
      */

      // Fixed initial constants for c and d; a and b come from the seed.
      me.a = 0;
      me.b = 0;
      me.c = 2654435769 | 0;
      me.d = 1367130551;
      if (seed === Math.floor(seed)) {
        // Integer seed.
        me.a = seed / 0x100000000 | 0;
        me.b = seed | 0;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 20 values (the
      // loop runs strseed.length + 20 times). Past the end of strseed,
      // charCodeAt() returns NaN, which `| 0` turns into 0.
      for (var k = 0; k < strseed.length + 20; k++) {
        me.b ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }
    // Copies the four state words from `f` onto `t` and returns `t`.
    function copy(f, t) {
      t.a = f.a;
      t.b = f.b;
      t.c = f.c;
      t.d = f.d;
      return t;
    }
    ;
    /**
     * Public factory. Unlike xorshift7/xor4096, a null/undefined seed is not
     * replaced by a time-based one here.
     *   prng()        -> float in [0, 1) built from one 32-bit output.
     *   prng.double() -> float in (0, 1) built from 53 bits of two outputs.
     *   prng.int32()  -> the raw signed 32-bit output.
     *   prng.quick    -> alias of prng.
     * When opts.state is truthy, prng.state() returns a snapshot; an
     * object-valued opts.state is loaded first.
     */
    function impl(seed, opts) {
      var xg = new XorGen(seed),
        state = opts && opts.state,
        prng = function prng() {
          return (xg.next() >>> 0) / 0x100000000;
        };
      // 21 high bits of one output plus a 32-bit fraction from the next give
      // 53 bits; a zero result is rejected and redrawn.
      prng.double = function () {
        do {
          var top = xg.next() >>> 11,
            bot = (xg.next() >>> 0) / 0x100000000,
            result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (_typeof(state) == 'object') copy(state, xg);
        prng.state = function () {
          return copy(xg, {});
        };
      }
      return prng;
    }
    // Export selection. As rewritten by the bundler, `module` is always the
    // shim object and `define` is always `false`, so the first branch runs.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function () {
        return impl;
      });
    } else {
      this.tychei = impl;
    }
  })(commonjsGlobal, 'object' == 'object' && module,
  // present in node.js
  typeof undefined == 'function' && undefined // present with an AMD loader
  );
})(tychei$3);
// Resolve the factory assigned above, unwrapping an ES-module `default`
// export if one is present.
var tycheiExports = tychei$3.exports;
var tychei$2 = /*@__PURE__*/getDefaultExportFromCjs(tycheiExports);
37528
// CommonJS shim for the bundled ARC4-based "seedrandom" generator: the
// wrapper below assigns the seedrandom() function to `seedrandom$3.exports`.
var seedrandom$3 = {exports: {}};

// Bundler artifact: a reference to the initial, still-empty exports object.
var seedrandom$1 = seedrandom$3.exports;
(function (module) {
  // global: host global object; pool: entropy pool (array of byte values);
  // math: the Math object (see the arguments at the bottom of this wrapper).
  (function (global, pool, math) {
    //
    // The following constants are related to IEEE 754 limits.
    //

    var width = 256,
      // each RC4 output is 0 <= x < 256
      chunks = 6,
      // at least six RC4 outputs for each double
      digits = 52,
      // there are 52 significant digits in a double
      rngname = 'random',
      // rngname: name for Math.random and Math.seedrandom
      startdenom = math.pow(width, chunks),
      // 256^6 = 2^48
      significance = math.pow(2, digits),
      // 2^52
      overflow = significance * 2,
      // 2^53
      mask = width - 1,
      // 255: reduces byte indices modulo width
      nodecrypto; // node.js crypto module, initialized at the bottom.

    //
    // seedrandom()
    // Creates an ARC4-based generator.
    //   seed     - any value; it is flattened to a string (objects are walked
    //              to depth 3). When null/undefined, autoseed() supplies it.
    //   options  - `true` (loosely) is shorthand for {entropy: true}.
    //              Keys read here: entropy (also mix in the accumulated
    //              pool), pass (custom return convention), global (force or
    //              suppress the Math.random replacement), state (enable
    //              state save/restore).
    //   callback - return convention used when options.pass is not set.
    // With the default convention the prng is returned, unless this is a
    // Math-method call, in which case Math.random is replaced and the short
    // seed string is returned.
    //
    function seedrandom(seed, options, callback) {
      var key = [];
      options = options == true ? {
        entropy: true
      } : options || {};

      // Flatten the seed string or build one from local entropy if needed.
      var shortseed = mixkey(flatten(options.entropy ? [seed, tostring(pool)] : seed == null ? autoseed() : seed, 3), key);

      // Use the seed to initialize an ARC4 generator.
      var arc4 = new ARC4(key);

      // This function returns a random double in [0, 1) that contains
      // randomness in every bit of the mantissa of the IEEE 754 value.
      var prng = function prng() {
        var n = arc4.g(chunks),
          // Start with a numerator n < 2 ^ 48
          d = startdenom,
          // and denominator d = 2 ^ 48.
          x = 0; // and no 'extra last byte'.
        while (n < significance) {
          // Fill up all significant digits by
          n = (n + x) * width; // shifting numerator and
          d *= width; // denominator and generating a
          x = arc4.g(1); // new least-significant-byte.
        }

        while (n >= overflow) {
          // To avoid rounding up, before adding
          n /= 2; // last byte, shift everything
          d /= 2; // right using integer math until
          x >>>= 1; // we have exactly the desired bits.
        }

        return (n + x) / d; // Form the number within [0, 1).
      };

      // 32-bit variants: int32() is signed, quick() is a float in [0, 1).
      prng.int32 = function () {
        return arc4.g(4) | 0;
      };
      prng.quick = function () {
        return arc4.g(4) / 0x100000000;
      };
      prng.double = prng;

      // Mix the randomness into accumulated entropy.
      mixkey(tostring(arc4.S), pool);

      // Calling convention: what to return as a function of prng, seed, is_math.
      return (options.pass || callback || function (prng, seed, is_math_call, state) {
        if (state) {
          // Load the arc4 state from the given state if it has an S array.
          if (state.S) {
            copy(state, arc4);
          }
          // Only provide the .state method if requested via options.state.
          prng.state = function () {
            return copy(arc4, {});
          };
        }

        // If called as a method of Math (Math.seedrandom()), mutate
        // Math.random because that is how seedrandom.js has worked since v1.0.
        if (is_math_call) {
          math[rngname] = prng;
          return seed;
        }

        // Otherwise, it is a newer calling convention, so return the
        // prng directly.
        else return prng;
      })(prng, shortseed, 'global' in options ? options.global : this == math, options.state);
    }

    //
    // ARC4
    //
    // An ARC4 implementation. The constructor takes a key in the form of
    // an array of at most (width) integers that should be 0 <= x < (width).
    //
    // The g(count) method returns a pseudorandom integer that concatenates
    // the next (count) outputs from ARC4. Its return value is a number x
    // that is in the range 0 <= x < (width ^ count).
    //
    function ARC4(key) {
      var t,
        keylen = key.length,
        me = this,
        i = 0,
        j = me.i = me.j = 0,
        s = me.S = [];

      // The empty key [] is treated as [0].
      if (!keylen) {
        key = [keylen++];
      }

      // Set up S using the standard key scheduling algorithm.
      while (i < width) {
        s[i] = i++;
      }
      for (i = 0; i < width; i++) {
        s[i] = s[j = mask & j + key[i % keylen] + (t = s[i])];
        s[j] = t;
      }

      // The "g" method returns the next (count) outputs as one number.
      // It is invoked once right away with (width) to drop the first 256
      // outputs (see the note at its end).
      (me.g = function (count) {
        // Using instance members instead of closure state nearly doubles speed.
        var t,
          r = 0,
          i = me.i,
          j = me.j,
          s = me.S;
        while (count--) {
          t = s[i = mask & i + 1];
          r = r * width + s[mask & (s[i] = s[j = mask & j + t]) + (s[j] = t)];
        }
        me.i = i;
        me.j = j;
        return r;
        // For robust unpredictability, the function call below automatically
        // discards an initial batch of values. This is called RC4-drop[256].
        // See http://google.com/search?q=rsa+fluhrer+response&btnI
      })(width);
    }

    //
    // copy()
    // Copies internal state of ARC4 to or from a plain object.
    //
    function copy(f, t) {
      t.i = f.i;
      t.j = f.j;
      t.S = f.S.slice();
      return t;
    }
    ;

    //
    // flatten()
    // Converts an object tree to nested arrays of strings, descending at most
    // `depth` levels. Properties whose access throws are skipped. Leaves that
    // are not strings are converted with `obj + '\0'`.
    //
    function flatten(obj, depth) {
      var result = [],
        typ = _typeof(obj),
        prop;
      if (depth && typ == 'object') {
        for (prop in obj) {
          try {
            result.push(flatten(obj[prop], depth - 1));
          } catch (e) {}
        }
      }
      return result.length ? result : typ == 'string' ? obj : obj + '\0';
    }

    //
    // mixkey()
    // Mixes a string seed into a key that is an array of integers, and
    // returns a shortened string seed that is equivalent to the result key.
    // `smear` starts undefined and unset key slots read as undefined; `^=`
    // coerces both (and the NaN from `undefined * 19`) to 0.
    //
    function mixkey(seed, key) {
      var stringseed = seed + '',
        smear,
        j = 0;
      while (j < stringseed.length) {
        key[mask & j] = mask & (smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++);
      }
      return tostring(key);
    }

    //
    // autoseed()
    // Returns seed material: a 256-character string from Node's
    // crypto.randomBytes, or else from global.crypto / global.msCrypto
    // getRandomValues. If that throws, falls back to an array of weaker
    // sources (time, global object, plugins, screen, and the entropy pool).
    //
    function autoseed() {
      try {
        var out;
        if (nodecrypto && (out = nodecrypto.randomBytes)) {
          // The use of 'out' to remember randomBytes makes tight minified code.
          out = out(width);
        } else {
          out = new Uint8Array(width);
          (global.crypto || global.msCrypto).getRandomValues(out);
        }
        return tostring(out);
      } catch (e) {
        var browser = global.navigator,
          plugins = browser && browser.plugins;
        return [+new Date(), global, plugins, global.screen, tostring(pool)];
      }
    }

    //
    // tostring()
    // Converts an array of charcodes to a string
    //
    function tostring(a) {
      return String.fromCharCode.apply(0, a);
    }

    //
    // When seedrandom.js is loaded, we immediately mix a few bits
    // from the built-in RNG into the entropy pool. Because we do
    // not want to interfere with deterministic PRNG state later,
    // seedrandom will not call math.random on its own again after
    // initialization.
    //
    mixkey(math.random(), pool);

    //
    // Nodejs and AMD support: export the implementation as a module using
    // either convention. As rewritten by the bundler, the first condition
    // is always true here. Any failure loading 'crypto' (for example when
    // `require` does not exist) is ignored, and autoseed() then uses
    // global.crypto instead.
    //
    if ('object' == 'object' && module.exports) {
      module.exports = seedrandom;
      // When in node.js, try using crypto package for autoseeding.
      try {
        nodecrypto = require('crypto');
      } catch (ex) {}
    } else if (typeof undefined == 'function' && undefined.amd) {
      undefined(function () {
        return seedrandom;
      });
    } else {
      // When included as a plain script, set up Math.seedrandom global.
      math['seed' + rngname] = seedrandom;
    }

    // End anonymous scope, and pass initial values.
  })(
  // global: `self` in browsers (including strict mode and web workers),
  // otherwise `this` in Node and other environments
  typeof self !== 'undefined' ? self : commonjsGlobal, [],
  // pool: entropy pool starts empty
  Math // math: package containing random, pow, and seedrandom
  );
})(seedrandom$3);
// Resolve the function assigned above, unwrapping an ES-module `default`
// export if one is present.
var seedrandomExports = seedrandom$3.exports;
var seedrandom$2 = /*@__PURE__*/getDefaultExportFromCjs(seedrandomExports);
37798
// Seedable PRNG collection ("seedrandom"), assembled from the bundled modules.
// The ARC4-based generator is the entry point; the alternative algorithms
// hang off it as properties.
//
// Usage:
//
//   var random = seedrandom(1);  // or any seed.
//   var x = random();            // 0 <= x < 1. Every bit is random.
//   var q = random.quick();      // 0 <= q < 1. 32 bits of randomness.

// Alternative generators, each taken from its bundled CommonJS export.
// alea: Johannes Baagøe's 53-bit multiply-with-carry; period ~2^116;
//   reported to pass all BigCrush tests.
var alea = aleaExports;
// xor128: George Marsaglia's pure xor-shift; period 2^128-1;
//   reported to fail MatrixRank and LinearComp.
var xor128 = xor128Exports;
// xorwow: Marsaglia's 160-bit xor-shift combined with a Weyl sequence;
//   period 2^192-2^32; reported to fail CollisionOver, SimpPoker, LinearComp.
var xorwow = xorwowExports;
// xorshift7: Panneton & L'Ecuyer's 7-shift, 256-bit generator; period
//   2^256-1; no systematic BigCrush failures reported.
var xorshift7 = xorshift7Exports;
// xor4096: Richard Brent's 4096-bit xor-shift plus Weyl generator; period
//   2^4128-2^32; no systematic BigCrush failures reported. The long period
//   helps avoid collisions when many generators are in use.
var xor4096 = xor4096Exports;
// tychei: Neves & Araujo's bit-shifting generator derived from the ChaCha
//   stream cipher (https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf);
//   period ~2^127; no systematic BigCrush failures reported.
var tychei = tycheiExports;

// The original ARC4-based prng (period ~2^1600) doubles as the namespace.
var sr = seedrandomExports;
Object.assign(sr, {
  alea: alea,
  xor128: xor128,
  xorwow: xorwow,
  xorshift7: xorshift7,
  xor4096: xor4096,
  tychei: tychei
});
var seedrandom = sr;
var index$1 = /*@__PURE__*/getDefaultExportFromCjs(seedrandom);
37858
// Default comparison tolerance when the backend computes in 32-bit floats.
var TEST_EPSILON_FLOAT32 = 0.001;
// Looser tolerance used when the backend reports any other float precision.
var TEST_EPSILON_FLOAT16 = 0.1;
/**
 * Asserts that two (possibly nested or typed) arrays are element-wise equal
 * within a tolerance.
 * @param actual Value produced by the code under test.
 * @param expected Reference value.
 * @param epsilon Optional absolute tolerance; when null/undefined the
 *     backend-dependent testEpsilon() is used.
 */
function expectArraysClose(actual, expected, epsilon) {
  var tolerance = epsilon == null ? testEpsilon() : epsilon;
  var withinTolerance = function (a, b) {
    return areClose(a, b, tolerance);
  };
  return expectArraysPredicate(actual, expected, withinTolerance);
}
/**
 * Returns the default comparison tolerance for the active backend: the tight
 * 32-bit epsilon when it reports 32-bit float precision, otherwise the loose
 * 16-bit one.
 */
function testEpsilon() {
  var precision = ENGINE.backend.floatPrecision();
  if (precision === 32) {
    return TEST_EPSILON_FLOAT32;
  }
  return TEST_EPSILON_FLOAT16;
}
/**
 * Shared implementation of the array assertions. It verifies the container
 * type, the nested shape, and the flattened length, then checks
 * `predicate(actualElement, expectedElement)` for every flattened position.
 *
 * Container class names are compared only when both inputs are typed arrays
 * or neither is, so a typed array may be checked against a plain (possibly
 * nested) array.
 *
 * @param actual Value produced by the code under test.
 * @param expected Reference value.
 * @param predicate Element comparison; a falsy result fails the assertion.
 * @throws {Error} On a type, shape, length, or element mismatch.
 */
function expectArraysPredicate(actual, expected, predicate) {
  var actualIsTyped = isTypedArray(actual);
  var expectedIsTyped = isTypedArray(expected);
  var compareClasses = actualIsTyped ? Boolean(expectedIsTyped) : !expectedIsTyped;
  if (compareClasses) {
    var actualType = actual.constructor.name;
    var expectedType = expected.constructor.name;
    if (actualType !== expectedType) {
      throw new Error("Arrays are of different type. Actual: ".concat(actualType, ". ") + "Expected: ".concat(expectedType));
    }
  }

  // Nested plain arrays must agree in shape, not just in element count.
  if (Array.isArray(actual) && Array.isArray(expected)) {
    var actualShape = inferShape(actual);
    var expectedShape = inferShape(expected);
    if (!arraysEqual(actualShape, expectedShape)) {
      throw new Error("Arrays have different shapes. " + "Actual: [".concat(actualShape, "]. Expected: [").concat(expectedShape, "]"));
    }
  }

  var actualFlat = actualIsTyped ? actual : flatten$2(actual);
  var expectedFlat = expectedIsTyped ? expected : flatten$2(expected);
  if (actualFlat.length !== expectedFlat.length) {
    throw new Error("Arrays have different lengths actual: ".concat(actualFlat.length, " vs ") + "expected: ".concat(expectedFlat.length, ".\n") + "Actual: ".concat(actualFlat, ".\n") + "Expected: ".concat(expectedFlat, "."));
  }

  for (var idx = 0; idx < expectedFlat.length; ++idx) {
    var got = actualFlat[idx];
    var want = expectedFlat[idx];
    if (!predicate(got, want)) {
      throw new Error("Arrays differ: actual[".concat(idx, "] = ").concat(got, ", expected[").concat(idx, "] = ").concat(want, ".\n") + "Actual: ".concat(actualFlat, ".\n") + "Expected: ".concat(expectedFlat, "."));
    }
  }

  // Under Jasmine, record an expectation so the spec is not flagged as empty.
  if (typeof expect !== 'undefined') {
    expect().nothing();
  }
}
/**
 * Jasmine-style helper asserting that the promise returned by `fn` rejects.
 * A rejection calls `done()`; a resolution calls `done.fail()`.
 * @param fn Function returning the promise under test.
 * @param done Async completion callback exposing a `fail` method.
 */
function expectPromiseToFail(fn, done) {
  var onResolved = function () {
    return done.fail();
  };
  var onRejected = function () {
    return done();
  };
  fn().then(onResolved, onRejected);
  // Under Jasmine, record an expectation so the spec is not flagged as empty.
  if (typeof expect !== 'undefined') {
    expect().nothing();
  }
}
/**
 * Asserts exact element-wise equality of two (possibly nested or typed)
 * arrays.
 *
 * If strings are involved (either argument, or its first element, is a
 * string), elements are compared with loose `==`, and a scalar `expected`
 * is first wrapped in a one-element array. Otherwise elements are compared
 * numerically with zero tolerance and `expected` is passed through
 * unchanged.
 */
function expectArraysEqual(actual, expected) {
  var expectedKind = typeof expected;
  var isScalar = expectedKind === 'string' || expectedKind === 'number' || expectedKind === 'boolean';
  var wrappedExpected = isScalar ? [expected] : expected;
  var involvesStrings = isString(actual) || isString(actual[0]) || isString(expected) || isString(expected[0]);
  if (!involvesStrings) {
    return expectArraysPredicate(actual, expected, function (a, b) {
      return areClose(a, b, 0);
    });
  }
  // tslint:disable-next-line: triple-equals
  return expectArraysPredicate(actual, wrappedExpected, function (a, b) {
    return a == b;
  });
}
/**
 * Asserts that two numbers are within a tolerance of each other.
 * @param a Actual value.
 * @param e Expected value.
 * @param epsilon Optional absolute tolerance; defaults to testEpsilon().
 * @throws {Error} When the values are not close.
 */
function expectNumbersClose(a, e, epsilon) {
  var tolerance = epsilon == null ? testEpsilon() : epsilon;
  var close = areClose(a, e, tolerance);
  if (!close) {
    throw new Error("Numbers differ: actual === ".concat(a, ", expected === ").concat(e));
  }
  // Under Jasmine, record an expectation so the spec is not flagged as empty.
  if (typeof expect !== 'undefined') {
    expect().nothing();
  }
}
/**
 * Returns whether `a` and `e` are equal within an absolute tolerance.
 *
 * Non-finite values are matched exactly: NaN is close only to NaN, and an
 * infinity is close only to the infinity of the same sign. Previously any
 * two non-finite values (e.g. +Infinity vs -Infinity, or NaN vs Infinity)
 * were reported as close.
 *
 * @param a Actual value.
 * @param e Expected value.
 * @param epsilon Maximum allowed |a - e| for finite values.
 * @returns {boolean}
 */
function areClose(a, e, epsilon) {
  if (!isFinite(a) && !isFinite(e)) {
    var aIsNaN = isNaN(a);
    var eIsNaN = isNaN(e);
    if (aIsNaN || eIsNaN) {
      return aIsNaN && eIsNaN;
    }
    // Both are infinities: they match only if the signs agree.
    return +a === +e;
  }
  // One non-finite side falls through here: NaN fails the isNaN check and an
  // infinity makes |a - e| infinite, so both are rejected.
  if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {
    return false;
  }
  return true;
}
/**
 * Asserts that every element of `actual` lies in the closed interval
 * [low, high].
 *
 * The check is written as a negated in-range test so that NaN elements (for
 * which every comparison is false) are reported instead of silently passing.
 *
 * @param actual Array-like of values to check.
 * @param low Inclusive lower bound.
 * @param high Inclusive upper bound.
 * @throws {Error} Naming the first element that is out of range.
 */
function expectValuesInRange(actual, low, high) {
  for (var i = 0; i < actual.length; i++) {
    var value = actual[i];
    if (!(value >= low && value <= high)) {
      throw new Error("Value out of range:".concat(value, " low: ").concat(low, ", high: ").concat(high));
    }
  }
}
/**
 * Asserts that two ArrayBuffers hold the same float32 values.
 *
 * NaN entries are treated as equal to each other (matching areClose's NaN
 * handling). Previously `NaN !== NaN` made identical buffers containing NaN
 * fail the assertion.
 *
 * @param actual Buffer produced by the code under test.
 * @param expected Reference buffer.
 * @throws {Error} On a length mismatch or the first differing value.
 */
function expectArrayBuffersEqual(actual, expected) {
  // Safari does not like comparing ArrayBuffers directly. Wrapping in
  // a Float32Array solves this issue.
  var actualArray = new Float32Array(actual);
  var expectedArray = new Float32Array(expected);
  if (actualArray.length !== expectedArray.length) {
    throw new Error('Expected ArrayBuffer to be of length ' + "".concat(expectedArray.length, ", but it was ").concat(actualArray.length));
  }
  for (var i = 0; i < expectedArray.length; i++) {
    var got = actualArray[i];
    var want = expectedArray[i];
    var bothNaN = Number.isNaN(got) && Number.isNaN(want);
    if (got !== want && !bothNaN) {
      throw new Error("Expected ArrayBuffer value at ".concat(i, " to be ") + "".concat(want, " but got ").concat(got, " instead"));
    }
  }
}
/**
 * Encodes the strings of a (possibly nested) array into utf-8 bytes, in
 * place. Nested arrays are recursed into; every other element is replaced by
 * encodeString(element).
 * @returns The same array that was passed in.
 */
function encodeStrings(a) {
  var index = 0;
  while (index < a.length) {
    var element = a[index];
    if (!Array.isArray(element)) {
      a[index] = encodeString(element);
    } else {
      encodeStrings(element);
    }
    index++;
  }
  return a;
}
/**
 * Creates an HTMLVideoElement with autoplay-friendly defaults (muted,
 * looping, inline where supported, pinned to the top-left corner), appends
 * `source` to it, and starts loading.
 * @param source Child node (e.g. a <source> element) to append.
 * @returns Promise resolving with the element once 'loadeddata' fires.
 */
function createVideoElement(source) {
  var video = document.createElement('video');
  // Only set playsInline where the element exposes it.
  if ('playsInline' in video) {
    // tslint:disable-next-line:no-any
    video.playsInline = true;
  }
  video.muted = true;
  video.loop = true;
  var style = video.style;
  style.position = 'fixed';
  style.left = '0px';
  style.top = '0px';
  video.preload = 'auto';
  video.appendChild(source);
  return new Promise(function (resolve) {
    var onLoadedData = function () {
      return resolve(video);
    };
    video.addEventListener('loadeddata', onLoadedData);
    video.load();
  });
}
/**
 * Starts playback of a video element. The returned promise settles after
 * video.play() and, when the element has requestVideoFrameCallback, after
 * that callback next fires. Forwards the receiver and every argument to the
 * _play implementation.
 */
function play(_x) {
  var receiver = this;
  var forwarded = arguments;
  return _play.apply(receiver, forwarded);
}
/**
 * Implementation behind play(), in compiled async-function form (an
 * `_asyncToGenerator` wrapper around a `_regeneratorRuntime` state machine).
 * The first call replaces `_play` with the wrapped function, so the wrapper
 * is built only once, and then invokes it. Equivalent steps:
 *   case 0: await video.play();
 *   case 2: if the element has no `requestVideoFrameCallback`, jump to
 *           case 5; otherwise await a Promise that the element's next
 *           requestVideoFrameCallback invocation resolves;
 *   case 5 / "end": finish.
 */
function _play() {
  _play = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(video) {
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          _context.next = 2;
          return video.play();
        case 2:
          if (!('requestVideoFrameCallback' in video)) {
            _context.next = 5;
            break;
          }
          _context.next = 5;
          return new Promise(function (resolve) {
            // tslint:disable-next-line:no-any
            video.requestVideoFrameCallback(resolve);
          });
        case 5:
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _play.apply(this, arguments);
}
38034
// Namespace object for the test utilities. `__proto__: null` gives it a null
// prototype, so it carries no inherited Object.prototype members. Keys are
// listed alphabetically; TEST_EPSILON_FLOAT32 is not included.
var test_util = {
  __proto__: null,
  TEST_EPSILON_FLOAT16: TEST_EPSILON_FLOAT16,
  createVideoElement: createVideoElement,
  encodeStrings: encodeStrings,
  expectArrayBuffersEqual: expectArrayBuffersEqual,
  expectArraysClose: expectArraysClose,
  expectArraysEqual: expectArraysEqual,
  expectNumbersClose: expectNumbersClose,
  expectPromiseToFail: expectPromiseToFail,
  expectValuesInRange: expectValuesInRange,
  play: play,
  testEpsilon: testEpsilon
};
38049
// https://en.wikipedia.org/wiki/Marsaglia_polar_method
/**
 * Gaussian sampler based on the Marsaglia polar method, driven by a seeded
 * alea generator.
 */
var MPRandGauss = /*#__PURE__*/function () {
  /**
   * @param mean Mean of the distribution.
   * @param stdDeviation Standard deviation of the distribution.
   * @param dtype Output dtype. Samples are rounded to integers unless it is
   *     null/undefined or 'float32'.
   * @param truncated When truthy, samples outside mean +/- 2 standard
   *     deviations are rejected and redrawn.
   * @param seed Optional seed, stringified and passed to seedrandom.alea.
   *     Any non-null value, including 0, yields a reproducible stream. When
   *     null/undefined, a seed is drawn from Math.random().
   */
  function MPRandGauss(mean, stdDeviation, dtype, truncated, seed) {
    _classCallCheck(this, MPRandGauss);
    this.mean = mean;
    this.stdDev = stdDeviation;
    this.dtype = dtype;
    // Cached second sample of the last accepted pair; NaN means "empty".
    this.nextVal = NaN;
    this.truncated = truncated;
    if (this.truncated) {
      this.upper = this.mean + this.stdDev * 2;
      this.lower = this.mean - this.stdDev * 2;
    }
    // Compare against null rather than testing truthiness: a seed of 0 is
    // falsy and used to be silently replaced by a random seed.
    var seedValue = seed != null ? seed : Math.random();
    this.random = seedrandom.alea(seedValue.toString());
  }
  /** Returns next sample from a Gaussian distribution. */
  _createClass(MPRandGauss, [{
    key: "nextValue",
    value: function nextValue() {
      // Serve the sample cached by the previous call, if any.
      if (!isNaN(this.nextVal)) {
        var value = this.nextVal;
        this.nextVal = NaN;
        return value;
      }
      var resultX, resultY;
      var isValid = false;
      while (!isValid) {
        var v1 = void 0,
          v2 = void 0,
          s = void 0;
        // Draw points uniformly in the square until one falls strictly
        // inside the unit circle (excluding the origin).
        do {
          v1 = 2 * this.random() - 1;
          v2 = 2 * this.random() - 1;
          s = v1 * v1 + v2 * v2;
        } while (s >= 1 || s === 0);
        var mul = Math.sqrt(-2.0 * Math.log(s) / s);
        resultX = this.mean + this.stdDev * v1 * mul;
        resultY = this.mean + this.stdDev * v2 * mul;
        if (!this.truncated || this.isValidTruncated(resultX)) {
          isValid = true;
        }
      }
      // Keep the second sample for the next call only if it is acceptable.
      if (!this.truncated || this.isValidTruncated(resultY)) {
        this.nextVal = this.convertValue(resultY);
      }
      return this.convertValue(resultX);
    }
    /** Handles proper rounding for non-floating-point numbers. */
  }, {
    key: "convertValue",
    value: function convertValue(value) {
      if (this.dtype == null || this.dtype === 'float32') {
        return value;
      }
      return Math.round(value);
    }
    /** Returns true if less than 2-standard-deviations from the mean. */
  }, {
    key: "isValidTruncated",
    value: function isValidTruncated(value) {
      return value <= this.upper && value >= this.lower;
    }
  }]);
  return MPRandGauss;
}();
// Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating
// Gamma Variables."
var RandGamma = /*#__PURE__*/function () {
  /**
   * Gamma-distributed sampler (Marsaglia–Tsang squeeze/rejection method).
   *
   * @param alpha Shape parameter.
   * @param beta Value whose reciprocal is stored in this.beta; see the NOTE in
   *     nextValue for how it is applied to the samples.
   * @param dtype 'float32' (or null/undefined) returns raw samples; any other
   *     dtype rounds each sample to the nearest integer.
   * @param seed Seed for the uniform PRNG. A falsy seed (including 0) is
   *     replaced by Math.random().
   */
  function RandGamma(alpha, beta, dtype, seed) {
    _classCallCheck(this, RandGamma);
    this.alpha = alpha;
    this.beta = 1 / beta; // convert rate to scale parameter
    this.dtype = dtype;
    var seedValue = seed ? seed : Math.random();
    this.randu = seedrandom.alea(seedValue.toString());
    // The standard-normal deviates are intermediate values of the algorithm and
    // must stay continuous. Rounding them (as happened when the caller's int
    // dtype was forwarded here) collapses the output onto a few values. Only
    // the final gamma sample is rounded, in convertValue.
    this.randn = new MPRandGauss(0, 1, 'float32', false, this.randu());
    // For alpha < 1, sample Gamma(alpha + 1) and apply the U^(1/alpha) boost at
    // the end of nextValue, hence d = (alpha + 1) - 1/3 = alpha + 2/3.
    if (alpha < 1) {
      this.d = alpha + 2 / 3;
    } else {
      this.d = alpha - 1 / 3;
    }
    this.c = 1 / Math.sqrt(9 * this.d);
  }
  _createClass(RandGamma, [{
    key: "nextValue",
    /** Returns next sample from a gamma distribution. */
    value: function nextValue() {
      var x2, v0, v1, x, u, v;
      while (true) {
        // Draw a normal deviate x with 1 + c * x > 0.
        do {
          x = this.randn.nextValue();
          v = 1 + this.c * x;
        } while (v <= 0);
        v *= v * v;
        x2 = x * x;
        // Squeeze test: 1 - 0.0331 * x^4 (the constant from the cited paper) is
        // a cheap lower bound on the acceptance ratio. Most candidates are
        // therefore accepted without evaluating Math.log(u). The previous
        // 0.331 was a looser bound: still valid, but it forced the log test far
        // more often.
        v0 = 1 - 0.0331 * x2 * x2;
        v1 = 0.5 * x2 + this.d * (1 - v + Math.log(v));
        u = this.randu();
        if (u < v0 || Math.log(u) < v1) {
          break;
        }
      }
      // NOTE(review): this.beta holds 1 / beta, so this multiplies d * v by the
      // caller's beta. Assuming d * v ~ Gamma(alpha, 1) per the cited paper,
      // the sample mean is therefore alpha * beta (beta acts as a scale). That
      // disagrees with the "rate" comment in the constructor. Kept unchanged
      // for backward compatibility; confirm the intended parameterization.
      v = 1 / this.beta * this.d * v;
      if (this.alpha < 1) {
        v *= Math.pow(this.randu(), 1 / this.alpha);
      }
      return this.convertValue(v);
    }
  }, {
    key: "convertValue",
    /** Handles proper rounding for non-floating-point numbers. */
    value: function convertValue(value) {
      // A missing dtype means float, consistent with MPRandGauss/UniformRandom.
      if (this.dtype == null || this.dtype === 'float32') {
        return value;
      }
      return Math.round(value);
    }
  }]);
  return RandGamma;
}();
var UniformRandom = /*#__PURE__*/function () {
  /**
   * Uniform sampler over [min, max) backed by a seeded alea PRNG.
   *
   * Positional arguments (read from `arguments`):
   *   min   - lower bound, default 0.
   *   max   - upper bound, default 1.
   *   dtype - null/undefined or 'float32' keeps floats; anything else rounds.
   *   seed  - number or string; null/undefined selects Math.random().
   *
   * @throws Error when the dtype is not float and max - min <= 1.
   */
  function UniformRandom() {
    var instance = this;
    var args = arguments;
    var min = args[0] !== undefined ? args[0] : 0;
    var max = args[1] !== undefined ? args[1] : 1;
    var dtype = args[2];
    var seed = args[3];
    _classCallCheck(this, UniformRandom);
    // Per-instance predicate; it reads dtype lazily on every call.
    this.canReturnFloat = function () {
      return instance.dtype == null || instance.dtype === 'float32';
    };
    this.min = min;
    this.range = max - min;
    this.dtype = dtype;
    var seedKey = seed == null ? Math.random() : seed;
    if (typeof seedKey === 'number') {
      seedKey = seedKey.toString();
    }
    if (!this.canReturnFloat() && this.range <= 1) {
      throw new Error("The difference between ".concat(min, " - ").concat(max, " <= 1 and dtype is not float"));
    }
    this.random = seedrandom.alea(seedKey);
  }
  _createClass(UniformRandom, [{
    key: "convertValue",
    /** Returns the value untouched for float dtypes, otherwise rounded. */
    value: function convertValue(value) {
      return this.canReturnFloat() ? value : Math.round(value);
    }
  }, {
    key: "nextValue",
    /** Draws the next sample, scaled into the configured range. */
    value: function nextValue() {
      var raw = this.min + this.range * this.random();
      return this.convertValue(raw);
    }
  }]);
  return UniformRandom;
}();
/**
 * Jarque–Bera normality check.
 * https://en.wikipedia.org/wiki/Jarque%E2%80%93Bera_test
 *
 * Computes JB = n / 6 * (S^2 + (K - 3)^2 / 4) from skewness S and kurtosis K.
 * Throws when JB exceeds 5.991, the 0.95 critical value of a chi-square
 * distribution with 2 degrees of freedom:
 * http://www.itl.nist.gov/div898/handbook/eda/section3/eda3674.htm
 *
 * @param values Numeric samples.
 * @throws Error("Invalid p-value for JB: <jb>") when the statistic is too large.
 */
function jarqueBeraNormalityTest(values) {
  var CHI_SQUARE_2DEG = 5.991;
  var n = values.length;
  var s = skewness(values);
  var excessKurtosis = kurtosis(values) - 3;
  var jb = n / 6 * (Math.pow(s, 2) + 0.25 * Math.pow(excessKurtosis, 2));
  if (jb > CHI_SQUARE_2DEG) {
    throw new Error("Invalid p-value for JB: ".concat(jb));
  }
}
/**
 * Asserts that the sample mean and (population) standard deviation of
 * `actual` are close to the expected values.
 *
 * @param actual Numeric samples.
 * @param expectedMean Expected mean.
 * @param expectedStdDev Expected standard deviation.
 * @param epsilon Tolerance; null/undefined falls back to testEpsilon().
 */
function expectArrayInMeanStdRange(actual, expectedMean, expectedStdDev, epsilon) {
  var tolerance = epsilon == null ? testEpsilon() : epsilon;
  var observedMean = mean$2(actual);
  expectNumbersClose(observedMean, expectedMean, tolerance);
  var observedStdDev = standardDeviation(actual, observedMean);
  expectNumbersClose(observedStdDev, expectedStdDev, tolerance);
}
/**
 * Arithmetic mean of an array-like of numbers (NaN for an empty input).
 */
function mean$2(values) {
  var count = values.length;
  var total = 0;
  var idx = 0;
  while (idx < count) {
    total += values[idx];
    idx++;
  }
  return total / count;
}
/**
 * Population standard deviation (divides by n) of `values` around the
 * supplied `mean`.
 */
function standardDeviation(values, mean) {
  var count = values.length;
  var total = 0;
  for (var idx = 0; idx < count; idx++) {
    var delta = values[idx] - mean;
    total += delta * delta;
  }
  return Math.sqrt(total / count);
}
/**
 * Kurtosis as m4 / m2^2 using population (1/n) central moments.
 * https://en.wikipedia.org/wiki/Kurtosis
 */
function kurtosis(values) {
  var count = values.length;
  var centre = mean$2(values);
  var secondSum = 0;
  var fourthSum = 0;
  for (var idx = 0; idx < count; idx++) {
    var dev = values[idx] - centre;
    secondSum += Math.pow(dev, 2);
    fourthSum += Math.pow(dev, 4);
  }
  var invCount = 1 / count;
  return invCount * fourthSum / Math.pow(invCount * secondSum, 2);
}
/**
 * Skewness: the third central moment (1/n) divided by the cube of the sample
 * standard deviation (n - 1 denominator).
 * https://en.wikipedia.org/wiki/Skewness
 */
function skewness(values) {
  var count = values.length;
  var centre = mean$2(values);
  var squareSum = 0;
  var cubeSum = 0;
  for (var idx = 0; idx < count; idx++) {
    var dev = values[idx] - centre;
    squareSum += Math.pow(dev, 2);
    cubeSum += Math.pow(dev, 3);
  }
  return 1 / count * cubeSum / Math.pow(1 / (count - 1) * squareSum, 3 / 2);
}
38275
38276 /**
38277 * @license
38278 * Copyright 2020 Google LLC. All Rights Reserved.
38279 * Licensed under the Apache License, Version 2.0 (the "License");
38280 * you may not use this file except in compliance with the License.
38281 * You may obtain a copy of the License at
38282 *
38283 * http://www.apache.org/licenses/LICENSE-2.0
38284 *
38285 * Unless required by applicable law or agreed to in writing, software
38286 * distributed under the License is distributed on an "AS IS" BASIS,
38287 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38288 * See the License for the specific language governing permissions and
38289 * limitations under the License.
38290 * =============================================================================
38291 */
/**
 * Builds a `tf.Tensor` of samples drawn from a gamma distribution.
 *
 * ```js
 * tf.randomGamma([2, 2], 1).print();
 * ```
 *
 * @param shape Output shape; every entry must be a non-negative integer.
 * @param alpha Shape parameter of the gamma distribution.
 * @param beta Documented as the inverse scale parameter; null or omitted
 *     means 1. NOTE(review): RandGamma ends up multiplying each sample by this
 *     value (it stores 1 / beta and then multiplies by 1 / that), so it
 *     behaves like a scale. Confirm the intended meaning.
 * @param dtype 'float32' (default; null also means 'float32') or 'int32'.
 *     Any other value throws.
 * @param seed Seed forwarded to the random number generator.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function randomGamma_(shape, alpha) {
  var beta = arguments[2] != null ? arguments[2] : 1;
  var dtype = arguments[3] != null ? arguments[3] : 'float32';
  var seed = arguments[4];
  assertNonNegativeIntegerDimensions(shape);
  var isSupported = dtype === 'float32' || dtype === 'int32';
  if (!isSupported) {
    throw new Error("Unsupported data type ".concat(dtype));
  }
  var sampler = new RandGamma(alpha, beta, dtype, seed);
  var out = buffer(shape, dtype);
  var total = out.values.length;
  for (var idx = 0; idx < total; idx++) {
    out.values[idx] = sampler.nextValue();
  }
  return out.toTensor();
}
var randomGamma = /* @__PURE__ */op({
  randomGamma_: randomGamma_
});
38332
38333 /**
38334 * @license
38335 * Copyright 2020 Google LLC. All Rights Reserved.
38336 * Licensed under the Apache License, Version 2.0 (the "License");
38337 * you may not use this file except in compliance with the License.
38338 * You may obtain a copy of the License at
38339 *
38340 * http://www.apache.org/licenses/LICENSE-2.0
38341 *
38342 * Unless required by applicable law or agreed to in writing, software
38343 * distributed under the License is distributed on an "AS IS" BASIS,
38344 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38345 * See the License for the specific language governing permissions and
38346 * limitations under the License.
38347 * =============================================================================
38348 */
/**
 * Builds a `tf.Tensor` of samples drawn from a normal distribution.
 *
 * ```js
 * tf.randomNormal([2, 2]).print();
 * ```
 *
 * @param shape Output shape; every entry must be a non-negative integer.
 * @param mean Mean of the distribution (default 0).
 * @param stdDev Standard deviation of the distribution (default 1).
 * @param dtype Output data type; 'bool' is rejected.
 * @param seed Seed for the random number generator.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function randomNormal_(shape) {
  var mean = arguments[1] !== undefined ? arguments[1] : 0;
  var stdDev = arguments[2] !== undefined ? arguments[2] : 1;
  var dtype = arguments[3];
  var seed = arguments[4];
  assertNonNegativeIntegerDimensions(shape);
  if (dtype === 'bool') {
    throw new Error("Unsupported data type ".concat(dtype));
  }
  var sampler = new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
  var out = buffer(shape, dtype);
  var total = out.values.length;
  for (var idx = 0; idx < total; idx++) {
    out.values[idx] = sampler.nextValue();
  }
  return out.toTensor();
}
var randomNormal$2 = /* @__PURE__ */op({
  randomNormal_: randomNormal_
});
38383
38384 /**
38385 * @license
38386 * Copyright 2022 Google LLC. All Rights Reserved.
38387 * Licensed under the Apache License, Version 2.0 (the "License");
38388 * you may not use this file except in compliance with the License.
38389 * You may obtain a copy of the License at
38390 *
38391 * http://www.apache.org/licenses/LICENSE-2.0
38392 *
38393 * Unless required by applicable law or agreed to in writing, software
38394 * distributed under the License is distributed on an "AS IS" BASIS,
38395 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38396 * See the License for the specific language governing permissions and
38397 * limitations under the License.
38398 * =============================================================================
38399 */
/**
 * Builds a `tf.Tensor` of standard-normal samples (mean 0, standard
 * deviation 1).
 *
 * ```js
 * tf.randomStandardNormal([2, 2]).print();
 * ```
 *
 * @param shape Output shape.
 * @param dtype Output data type; 'bool' is rejected.
 * @param seed Seed for the random number generator.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function randomStandardNormal_(shape, dtype, seed) {
  if (dtype === 'bool') {
    throw new Error("Unsupported data type ".concat(dtype));
  }
  var STANDARD_MEAN = 0;
  var STANDARD_STD_DEV = 1;
  return randomNormal$2(shape, STANDARD_MEAN, STANDARD_STD_DEV, dtype, seed);
}
var randomStandardNormal = /* @__PURE__ */op({
  randomStandardNormal_: randomStandardNormal_
});
38424
38425 /**
38426 * @license
38427 * Copyright 2020 Google LLC. All Rights Reserved.
38428 * Licensed under the Apache License, Version 2.0 (the "License");
38429 * you may not use this file except in compliance with the License.
38430 * You may obtain a copy of the License at
38431 *
38432 * http://www.apache.org/licenses/LICENSE-2.0
38433 *
38434 * Unless required by applicable law or agreed to in writing, software
38435 * distributed under the License is distributed on an "AS IS" BASIS,
38436 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38437 * See the License for the specific language governing permissions and
38438 * limitations under the License.
38439 * =============================================================================
38440 */
/**
 * Creates a `tf.Tensor` with values sampled from a uniform distribution.
 *
 * The generated values follow a uniform distribution in the range [minval,
 * maxval). The lower bound minval is included in the range, while the upper
 * bound maxval is excluded.
 *
 * ```js
 * tf.randomUniform([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param minval The lower bound on the range of random values to generate.
 *     Defaults to 0.
 * @param maxval The upper bound on the range of random values to generate.
 *     Defaults to 1.
 * @param dtype The data type of the output tensor. Defaults to 'float32'.
 *     For 'int32' each sample is floored, so integer bounds give integers in
 *     [minval, maxval) with equal probability.
 * @param seed An optional seed (number or string). When omitted (or null),
 *     the generator is seeded randomly. Any other value, including 0, seeds it
 *     deterministically.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function randomUniform_(shape) {
  var minval = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
  var maxval = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 1;
  var dtype = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'float32';
  var seed = arguments.length > 4 ? arguments[4] : undefined;
  assertNonNegativeIntegerDimensions(shape);
  var res = buffer(shape, dtype);
  // Always sample floats; integer conversion is handled below.
  var random = new UniformRandom(minval, maxval, null, seed);
  // Writing a float into an Int32Array truncates toward zero. For negative
  // ranges that maps all of (-1, 1) to 0 (doubling its probability) and almost
  // never yields minval. Flooring keeps [minval, maxval) uniform. For
  // non-negative samples floor and truncation agree, so existing results are
  // unchanged there.
  var floorToInt = dtype === 'int32';
  for (var i = 0; i < res.values.length; i++) {
    var sample = random.nextValue();
    res.values[i] = floorToInt ? Math.floor(sample) : sample;
  }
  return res.toTensor();
}
var randomUniform$1 = /* @__PURE__ */op({
  randomUniform_: randomUniform_
});
38480
38481 /**
38482 * @license
38483 * Copyright 2023 Google LLC.
38484 * Licensed under the Apache License, Version 2.0 (the "License");
38485 * you may not use this file except in compliance with the License.
38486 * You may obtain a copy of the License at
38487 *
38488 * http://www.apache.org/licenses/LICENSE-2.0
38489 *
38490 * Unless required by applicable law or agreed to in writing, software
38491 * distributed under the License is distributed on an "AS IS" BASIS,
38492 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38493 * See the License for the specific language governing permissions and
38494 * limitations under the License.
38495 * =============================================================================
38496 */
/**
 * Builds an int32 `tf.Tensor` of integers drawn uniformly from
 * [minval, maxval): minval may appear, maxval never does.
 *
 * ```js
 * tf.randomUniformInt([2, 2], 0, 10).print();
 * ```
 *
 * @param shape Output shape.
 * @param minval Inclusive lower bound on the generated integers.
 * @param maxval Exclusive upper bound on the generated integers.
 * @param seed Optional seed, forwarded unchanged to `randomUniform`. Omit it
 *     for a randomly seeded generator.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function randomUniformInt_(shape, minval, maxval, seed) {
  // TODO(mattsoulanille): Handle optional seed2 input.
  var outputDtype = 'int32';
  return randomUniform$1(shape, minval, maxval, outputDtype, seed);
}
var randomUniformInt = /* @__PURE__ */op({
  randomUniformInt_: randomUniformInt_
});
38524
38525 /**
38526 * @license
38527 * Copyright 2018 Google LLC. All Rights Reserved.
38528 * Licensed under the Apache License, Version 2.0 (the "License");
38529 * you may not use this file except in compliance with the License.
38530 * You may obtain a copy of the License at
38531 *
38532 * http://www.apache.org/licenses/LICENSE-2.0
38533 *
38534 * Unless required by applicable law or agreed to in writing, software
38535 * distributed under the License is distributed on an "AS IS" BASIS,
38536 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38537 * See the License for the specific language governing permissions and
38538 * limitations under the License.
38539 * =============================================================================
38540 */
/**
 * Creates a `tf.Tensor1D` holding the half-open interval [start, stop)
 * advanced by `step`. Decrementing ranges and negative steps are supported.
 *
 * ```js
 * tf.range(0, 9, 2).print();
 * ```
 *
 * @param start An integer start value.
 * @param stop An integer stop value (excluded).
 * @param step An integer increment. Defaults to 1 in this function; any
 *     adjustment for decreasing ranges is left to the Range kernel. A step of
 *     0 throws.
 * @param dtype The data type of the output tensor. Defaults to 'float32'.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function range$3(start, stop) {
  var step = arguments[2] !== undefined ? arguments[2] : 1;
  var dtype = arguments[3] !== undefined ? arguments[3] : 'float32';
  if (step === 0) {
    throw new Error('Cannot have a step of zero');
  }
  return ENGINE.runKernel(Range, {} /* inputs */, {
    start: start,
    stop: stop,
    step: step,
    dtype: dtype
  });
}
38574
38575 /**
38576 * @license
38577 * Copyright 2020 Google LLC. All Rights Reserved.
38578 * Licensed under the Apache License, Version 2.0 (the "License");
38579 * you may not use this file except in compliance with the License.
38580 * You may obtain a copy of the License at
38581 *
38582 * http://www.apache.org/licenses/LICENSE-2.0
38583 *
38584 * Unless required by applicable law or agreed to in writing, software
38585 * distributed under the License is distributed on an "AS IS" BASIS,
38586 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38587 * See the License for the specific language governing permissions and
38588 * limitations under the License.
38589 * =============================================================================
38590 */
/**
 * Extracts the real part of a complex (or real) tensor. A real input is
 * simply cloned, per the op's contract.
 *
 * ```js
 * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
 * tf.real(x).print();
 * ```
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function real_(input) {
  return ENGINE.runKernel(Real, {
    input: convertToTensor(input, 'input', 'real')
  });
}
var real$2 = /* @__PURE__ */op({
  real_: real_
});
38616
38617 /**
38618 * @license
38619 * Copyright 2018 Google LLC. All Rights Reserved.
38620 * Licensed under the Apache License, Version 2.0 (the "License");
38621 * you may not use this file except in compliance with the License.
38622 * You may obtain a copy of the License at
38623 *
38624 * http://www.apache.org/licenses/LICENSE-2.0
38625 *
38626 * Unless required by applicable law or agreed to in writing, software
38627 * distributed under the License is distributed on an "AS IS" BASIS,
38628 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38629 * See the License for the specific language governing permissions and
38630 * limitations under the License.
38631 * =============================================================================
38632 */
/**
 * Element-wise reciprocal: `1 / x`.
 *
 * ```js
 * const x = tf.tensor1d([0, 1, 2]);
 *
 * x.reciprocal().print(); // or tf.reciprocal(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function reciprocal_(x) {
  return ENGINE.runKernel(Reciprocal, {
    x: convertToTensor(x, 'x', 'reciprocal')
  });
}
var reciprocal$2 = /* @__PURE__ */op({
  reciprocal_: reciprocal_
});
38655
38656 /**
38657 * @license
38658 * Copyright 2020 Google LLC. All Rights Reserved.
38659 * Licensed under the Apache License, Version 2.0 (the "License");
38660 * you may not use this file except in compliance with the License.
38661 * You may obtain a copy of the License at
38662 *
38663 * http://www.apache.org/licenses/LICENSE-2.0
38664 *
38665 * Unless required by applicable law or agreed to in writing, software
38666 * distributed under the License is distributed on an "AS IS" BASIS,
38667 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38668 * See the License for the specific language governing permissions and
38669 * limitations under the License.
38670 * =============================================================================
38671 */
/**
 * Element-wise rectified linear unit: `max(x, 0)`.
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 4]);
 *
 * x.relu().print(); // or tf.relu(x)
 * ```
 * @param x The input tensor. A `bool` input produces an `int32` output.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function relu_(x) {
  return ENGINE.runKernel(Relu$1, {
    x: convertToTensor(x, 'x', 'relu')
  });
}
var relu$2 = /* @__PURE__ */op({
  relu_: relu_
});
38695
38696 /**
38697 * @license
38698 * Copyright 2020 Google LLC. All Rights Reserved.
38699 * Licensed under the Apache License, Version 2.0 (the "License");
38700 * you may not use this file except in compliance with the License.
38701 * You may obtain a copy of the License at
38702 *
38703 * http://www.apache.org/licenses/LICENSE-2.0
38704 *
38705 * Unless required by applicable law or agreed to in writing, software
38706 * distributed under the License is distributed on an "AS IS" BASIS,
38707 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38708 * See the License for the specific language governing permissions and
38709 * limitations under the License.
38710 * =============================================================================
38711 */
/**
 * Element-wise ReLU clipped at 6: `min(max(x, 0), 6)`.
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 8]);
 *
 * x.relu6().print(); // or tf.relu6(x)
 * ```
 * @param x The input tensor. A `bool` input produces an `int32` output.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function relu6_(x) {
  return ENGINE.runKernel(Relu6$1, {
    x: convertToTensor(x, 'x', 'relu6')
  });
}
var relu6$2 = /* @__PURE__ */op({
  relu6_: relu6_
});
38735
38736 /**
38737 * @license
38738 * Copyright 2018 Google LLC. All Rights Reserved.
38739 * Licensed under the Apache License, Version 2.0 (the "License");
38740 * you may not use this file except in compliance with the License.
38741 * You may obtain a copy of the License at
38742 *
38743 * http://www.apache.org/licenses/LICENSE-2.0
38744 *
38745 * Unless required by applicable law or agreed to in writing, software
38746 * distributed under the License is distributed on an "AS IS" BASIS,
38747 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38748 * See the License for the specific language governing permissions and
38749 * limitations under the License.
38750 * =============================================================================
38751 */
/**
 * Reverses a `tf.Tensor` along the given axis or axes.
 *
 * Rank-checked variants are also available: `tf.reverse1d` (no axis
 * parameter), `tf.reverse2d`, `tf.reverse3d` and `tf.reverse4d`. Apart from
 * `tf.reverse1d`, they share this signature.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 *
 * x.reverse().print();
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * const axis = 1;
 * x.reverse(axis).print();
 * ```
 * @param x The input tensor to be reversed.
 * @param axis The dimension(s) to reverse, each within [-rank(x), rank(x)).
 *     Omitting it reverses all axes.
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
function reverse_(x, axis) {
  var attrs = {
    dims: axis
  };
  var inputs = {
    x: convertToTensor(x, 'x', 'reverse')
  };
  return ENGINE.runKernel(Reverse, inputs, attrs);
}
var reverse$2 = /* @__PURE__ */op({
  reverse_: reverse_
});
38796
38797 /**
38798 * @license
38799 * Copyright 2020 Google LLC. All Rights Reserved.
38800 * Licensed under the Apache License, Version 2.0 (the "License");
38801 * you may not use this file except in compliance with the License.
38802 * You may obtain a copy of the License at
38803 *
38804 * http://www.apache.org/licenses/LICENSE-2.0
38805 *
38806 * Unless required by applicable law or agreed to in writing, software
38807 * distributed under the License is distributed on an "AS IS" BASIS,
38808 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38809 * See the License for the specific language governing permissions and
38810 * limitations under the License.
38811 * =============================================================================
38812 */
/**
 * Reverses a `tf.Tensor1D` (always along axis 0).
 *
 * @param x The input tensor; must have rank 1.
 */
function reverse1d_(x) {
  var tensor = convertToTensor(x, 'x', 'reverse');
  var EXPECTED_RANK = 1;
  assert$1(tensor.rank === EXPECTED_RANK, function () {
    return "Error in reverse1D: x must be rank 1 but got rank ".concat(tensor.rank, ".");
  });
  return reverse$2(tensor, 0);
}
var reverse1d = /* @__PURE__ */op({
  reverse1d_: reverse1d_
});
38828
38829 /**
38830 * @license
38831 * Copyright 2020 Google LLC. All Rights Reserved.
38832 * Licensed under the Apache License, Version 2.0 (the "License");
38833 * you may not use this file except in compliance with the License.
38834 * You may obtain a copy of the License at
38835 *
38836 * http://www.apache.org/licenses/LICENSE-2.0
38837 *
38838 * Unless required by applicable law or agreed to in writing, software
38839 * distributed under the License is distributed on an "AS IS" BASIS,
38840 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38841 * See the License for the specific language governing permissions and
38842 * limitations under the License.
38843 * =============================================================================
38844 */
/**
 * Reverses a `tf.Tensor2D` along the given axis or axes.
 *
 * @param x The input tensor; must have rank 2.
 * @param axis The dimension(s) to reverse, each within [-rank(x), rank(x)).
 *     Omitting it reverses all axes.
 */
function reverse2d_(x, axis) {
  var tensor = convertToTensor(x, 'x', 'reverse');
  var EXPECTED_RANK = 2;
  assert$1(tensor.rank === EXPECTED_RANK, function () {
    return "Error in reverse2D: x must be rank 2 but got rank ".concat(tensor.rank, ".");
  });
  return reverse$2(tensor, axis);
}
var reverse2d = /* @__PURE__ */op({
  reverse2d_: reverse2d_
});
38862
38863 /**
38864 * @license
38865 * Copyright 2020 Google LLC. All Rights Reserved.
38866 * Licensed under the Apache License, Version 2.0 (the "License");
38867 * you may not use this file except in compliance with the License.
38868 * You may obtain a copy of the License at
38869 *
38870 * http://www.apache.org/licenses/LICENSE-2.0
38871 *
38872 * Unless required by applicable law or agreed to in writing, software
38873 * distributed under the License is distributed on an "AS IS" BASIS,
38874 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38875 * See the License for the specific language governing permissions and
38876 * limitations under the License.
38877 * =============================================================================
38878 */
/**
 * Reverses a `tf.Tensor3D` along the given axis or axes.
 *
 * @param x The input tensor; must have rank 3.
 * @param axis The dimension(s) to reverse, each within [-rank(x), rank(x)).
 *     Omitting it reverses all axes.
 */
function reverse3d_(x, axis) {
  var tensor = convertToTensor(x, 'x', 'reverse');
  var EXPECTED_RANK = 3;
  assert$1(tensor.rank === EXPECTED_RANK, function () {
    return "Error in reverse3D: x must be rank 3 but got rank ".concat(tensor.rank, ".");
  });
  return reverse$2(tensor, axis);
}
var reverse3d = /* @__PURE__ */op({
  reverse3d_: reverse3d_
});
38896
38897 /**
38898 * @license
38899 * Copyright 2020 Google LLC. All Rights Reserved.
38900 * Licensed under the Apache License, Version 2.0 (the "License");
38901 * you may not use this file except in compliance with the License.
38902 * You may obtain a copy of the License at
38903 *
38904 * http://www.apache.org/licenses/LICENSE-2.0
38905 *
38906 * Unless required by applicable law or agreed to in writing, software
38907 * distributed under the License is distributed on an "AS IS" BASIS,
38908 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38909 * See the License for the specific language governing permissions and
38910 * limitations under the License.
38911 * =============================================================================
38912 */
/**
 * Reverses a `tf.Tensor4D` along the given axis or axes.
 *
 * @param x The input tensor; must have rank 4.
 * @param axis The dimension(s) to reverse, each within [-rank(x), rank(x)).
 *     Omitting it reverses all axes.
 */
function reverse4d_(x, axis) {
  var tensor = convertToTensor(x, 'x', 'reverse');
  var EXPECTED_RANK = 4;
  assert$1(tensor.rank === EXPECTED_RANK, function () {
    return "Error in reverse4D: x must be rank 4 but got rank ".concat(tensor.rank, ".");
  });
  return reverse$2(tensor, axis);
}
var reverse4d = /* @__PURE__ */op({
  reverse4d_: reverse4d_
});
38930
38931 /**
38932 * @license
38933 * Copyright 2018 Google LLC. All Rights Reserved.
38934 * Licensed under the Apache License, Version 2.0 (the "License");
38935 * you may not use this file except in compliance with the License.
38936 * You may obtain a copy of the License at
38937 *
38938 * http://www.apache.org/licenses/LICENSE-2.0
38939 *
38940 * Unless required by applicable law or agreed to in writing, software
38941 * distributed under the License is distributed on an "AS IS" BASIS,
38942 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38943 * See the License for the specific language governing permissions and
38944 * limitations under the License.
38945 * =============================================================================
38946 */
/**
 * Rounds every element of the input `tf.Tensor`: `round(x)`.
 * Ties are resolved with banker's rounding.
 *
 * ```js
 * const x = tf.tensor1d([.6, 1.1, -3.3]);
 *
 * x.round().print();  // or tf.round(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function round_(x) {
  // The Round kernel takes a single named input, `x`.
  return ENGINE.runKernel(Round, {
    x: convertToTensor(x, 'x', 'round')
  });
}
var round$2 = /* @__PURE__ */op({ round_: round_ });
38970
38971 /**
38972 * @license
38973 * Copyright 2018 Google LLC. All Rights Reserved.
38974 * Licensed under the Apache License, Version 2.0 (the "License");
38975 * you may not use this file except in compliance with the License.
38976 * You may obtain a copy of the License at
38977 *
38978 * http://www.apache.org/licenses/LICENSE-2.0
38979 *
38980 * Unless required by applicable law or agreed to in writing, software
38981 * distributed under the License is distributed on an "AS IS" BASIS,
38982 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38983 * See the License for the specific language governing permissions and
38984 * limitations under the License.
38985 * =============================================================================
38986 */
/**
 * Element-wise reciprocal square root of the input `tf.Tensor`:
 * `y = 1 / sqrt(x)`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 4, -1]);
 *
 * x.rsqrt().print();  // or tf.rsqrt(x)
 * ```
 * @param x The input tensor. Converted with a `'float32'` dtype hint.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function rsqrt_(x) {
  return ENGINE.runKernel(Rsqrt, {
    x: convertToTensor(x, 'x', 'rsqrt', 'float32')
  });
}
var rsqrt$2 = /* @__PURE__ */op({ rsqrt_: rsqrt_ });
39010
39011 /**
39012 * @license
39013 * Copyright 2020 Google LLC. All Rights Reserved.
39014 * Licensed under the Apache License, Version 2.0 (the "License");
39015 * you may not use this file except in compliance with the License.
39016 * You may obtain a copy of the License at
39017 *
39018 * http://www.apache.org/licenses/LICENSE-2.0
39019 *
39020 * Unless required by applicable law or agreed to in writing, software
39021 * distributed under the License is distributed on an "AS IS" BASIS,
39022 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39023 * See the License for the specific language governing permissions and
39024 * limitations under the License.
39025 * =============================================================================
39026 */
/**
 * Scaled exponential linear unit, applied element-wise:
 *
 * `x < 0 ? scale * alpha * (exp(x) - 1) : scale * x`
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 4]);
 *
 * x.selu().print();  // or tf.selu(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function selu_(x) {
  return ENGINE.runKernel(Selu$1, {
    x: convertToTensor(x, 'x', 'selu')
  });
}
var selu$2 = /* @__PURE__ */op({ selu_: selu_ });
39051
/**
 * 2-D convolution with separable filters.
 *
 * Performs a depthwise convolution that acts separately on channels followed
 * by a pointwise convolution that mixes channels. Note that this is
 * separability between dimensions [1, 2] and 3, not spatial separability
 * between dimensions 1 and 2.
 *
 * See
 * [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d](
 *     https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d)
 * for more details.
 *
 * @param x The input tensor, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
 *     assumed and the result is returned as rank 3 as well.
 * @param depthwiseFilter The depthwise filter tensor, rank 4, of shape
 *     `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is
 *     the filter used in the first step.
 * @param pointwiseFilter The pointwise filter tensor, rank 4, of shape
 *     `[1, 1, inChannels * channelMultiplier, outChannels]`. This is
 *     the filter used in the second step.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`. If strides is a single number, then `strideHeight ==
 *     strideWidth`. Applied to the depthwise step only; the pointwise step
 *     always uses stride 1.
 * @param pad The type of padding algorithm.
 *     - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *     - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *     - For more info, see this guide:
 *       [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *           https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 *     Applied to the depthwise step only; the pointwise step uses `valid`.
 * @param dilation The dilation rates: `[dilationHeight, dilationWidth]`
 *     in which we sample input values across the height and width dimensions
 *     in atrous convolution. Defaults to `[1, 1]`. If `dilation` is a single
 *     number, then `dilationHeight == dilationWidth`. If it is greater than
 *     1, then all values of `strides` must be 1.
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels]. Only "NHWC" is currently supported; passing
 *     "NCHW" throws an Error.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad) {
  // Optional trailing parameters (transpiled default-parameter form).
  var dilation = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : [1, 1];
  var dataFormat = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : 'NHWC';
  var $x = convertToTensor(x, 'x', 'separableConv2d');
  var $depthwiseFilter = convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d');
  var $pointwiseFilter = convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d');
  var x4D = $x;
  var reshapedTo4D = false;
  if ($x.rank === 3) {
    // Treat an unbatched input as a batch of one; undone before returning.
    reshapedTo4D = true;
    x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
  }
  if (dataFormat === 'NCHW') {
    throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' + 'NHWC is supported');
  }
  assert$1(x4D.rank === 4, function () {
    return "Error in separableConv2d: input must be rank 4, but got " + "rank ".concat(x4D.rank, ".");
  });
  assert$1($depthwiseFilter.rank === 4, function () {
    return "Error in separableConv2d: depthwise filter must be rank 4, but " + "got rank ".concat($depthwiseFilter.rank, ".");
  });
  assert$1($pointwiseFilter.rank === 4, function () {
    // Report the pointwise filter's own rank, since that is what failed.
    return "Error in separableConv2d: pointwise filter must be rank 4, but " + "got rank ".concat($pointwiseFilter.rank, ".");
  });
  assert$1($pointwiseFilter.shape[0] === 1, function () {
    return "Error in separableConv2d: the first dimension of pointwise filter " + "must be 1, but got ".concat($pointwiseFilter.shape[0], ".");
  });
  assert$1($pointwiseFilter.shape[1] === 1, function () {
    return "Error in separableConv2d: the second dimension of pointwise " + "filter must be 1, but got ".concat($pointwiseFilter.shape[1], ".");
  });
  // The pointwise filter must consume every channel produced by the
  // depthwise step: inChannels * channelMultiplier of them.
  var inChannels = $depthwiseFilter.shape[2];
  var channelMultiplier = $depthwiseFilter.shape[3];
  assert$1($pointwiseFilter.shape[2] === inChannels * channelMultiplier, function () {
    return "Error in separableConv2d: the third dimension of pointwise filter " + "must be ".concat(inChannels * channelMultiplier, ", ") + "but got ".concat($pointwiseFilter.shape[2], ".");
  });
  var depthwise = depthwiseConv2d$3(x4D, $depthwiseFilter, strides, pad, dataFormat, dilation);
  var pointwiseStride = 1;
  var res = conv2d$4(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat);
  if (reshapedTo4D) {
    // Drop the synthetic batch dimension added above.
    return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
  }
  return res;
}
var separableConv2d$1 = /* @__PURE__ */op({
  separableConv2d_: separableConv2d_
});
39143
/**
 * Computes the difference between two lists of numbers.
 *
 * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
 * that represents all values that are in `x` but not in `y`. The returned
 * Tensor `out` keeps the values in the same order that they appear in `x`
 * (duplicates are preserved). This operation also returns a Tensor of indices
 * that gives the position of each `out` element in `x`. In other words:
 *
 * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
 *
 * Membership in `y` is tested with a JavaScript `Set`, so values compare
 * with SameValueZero semantics (e.g. `NaN` in `y` removes `NaN` from `x`,
 * and `0` and `-0` are treated as equal).
 *
 * ```js
 * const x = [1, 2, 3, 4, 5, 6];
 * const y = [1, 3, 5];
 *
 * const [out, indices] = await tf.setdiff1dAsync(x, y);
 * out.print(); // [2, 4, 6]
 * indices.print(); // [1, 3, 5]
 * ```
 *
 * @param x 1-D Tensor. Values to keep.
 * @param y 1-D Tensor. Must have the same dtype as x. Values to exclude from
 *     the output.
 * @returns Promise of Tensor tuple [out, indices].
 *     out: Tensor with the same dtype as x.
 *     indices: A Tensor of dtype int32.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function setdiff1dAsync_(_x, _x2) {
  // Public entry point. The named parameters are unused; they only give the
  // function an arity of 2. All arguments are forwarded as-is.
  return _setdiff1dAsync_.apply(this, arguments);
}
function _setdiff1dAsync_() {
  // On first call this rebinds `_setdiff1dAsync_` to the generator-backed
  // async implementation (transpiled async/await output) and then delegates
  // to it, so later calls go straight to that implementation.
  _setdiff1dAsync_ = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(x, y) {
    var $x, $y, xVals, yVals, ySet, outputSize, i, buffer, indices, _i, p;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      // Transpiler-generated state machine: each `case` label is a resume
      // point after an `await`. Do not reorder by hand.
      while (1) switch (_context.prev = _context.next) {
        case 0:
          // NOTE(review): when x / y are not already tensors,
          // convertToTensor presumably allocates new tensors here that this
          // function never disposes. Confirm whether callers must pass
          // tensors to avoid a leak.
          $x = convertToTensor(x, 'x', 'setdiff1d');
          $y = convertToTensor(y, 'y', 'setdiff1d');
          assert$1($x.dtype === $y.dtype, function () {
            return "x and y should have the same dtype, but got x (".concat($x.dtype, ") and y (").concat($y.dtype, ").");
          });
          assert$1($x.rank === 1, function () {
            return "x should be 1D tensor, but got x (".concat($x.shape, ").");
          });
          assert$1($y.rank === 1, function () {
            return "y should be 1D tensor, but got y (".concat($y.shape, ").");
          });
          // await $x.data()
          _context.next = 7;
          return $x.data();
        case 7:
          xVals = _context.sent;
          // await $y.data() (issued only after x's data has arrived)
          _context.next = 10;
          return $y.data();
        case 10:
          yVals = _context.sent;
          ySet = new Set(yVals);
          // First pass: count surviving elements so the output buffers can
          // be allocated at their exact size.
          outputSize = 0;
          for (i = 0; i < xVals.length; i++) {
            if (!ySet.has(xVals[i])) {
              outputSize++;
            }
          }
          buffer = new TensorBuffer([outputSize], $x.dtype);
          indices = new TensorBuffer([outputSize], 'int32');
          // Second pass: copy each surviving value and its position in x,
          // preserving x's order (p is the write cursor).
          for (_i = 0, p = 0; _i < xVals.length; _i++) {
            if (!ySet.has(xVals[_i])) {
              buffer.values[p] = xVals[_i];
              indices.values[p] = _i;
              p++;
            }
          }
          return _context.abrupt("return", [buffer.toTensor(), indices.toTensor()]);
        case 18:
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _setdiff1dAsync_.apply(this, arguments);
}
var setdiff1dAsync = setdiff1dAsync_;
39227
39228 /**
39229 * @license
39230 * Copyright 2018 Google LLC. All Rights Reserved.
39231 * Licensed under the Apache License, Version 2.0 (the "License");
39232 * you may not use this file except in compliance with the License.
39233 * You may obtain a copy of the License at
39234 *
39235 * http://www.apache.org/licenses/LICENSE-2.0
39236 *
39237 * Unless required by applicable law or agreed to in writing, software
39238 * distributed under the License is distributed on an "AS IS" BASIS,
39239 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39240 * See the License for the specific language governing permissions and
39241 * limitations under the License.
39242 * =============================================================================
39243 */
/**
 * Element-wise sign of the input `tf.Tensor`.
 *
 * ```js
 * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);
 *
 * x.sign().print();  // or tf.sign(x)
 * ```
 * @param x The input Tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function sign_(x) {
  return ENGINE.runKernel(Sign, {
    x: convertToTensor(x, 'x', 'sign')
  });
}
var sign$3 = /* @__PURE__ */op({ sign_: sign_ });
39266
39267 /**
39268 * @license
39269 * Copyright 2018 Google LLC. All Rights Reserved.
39270 * Licensed under the Apache License, Version 2.0 (the "License");
39271 * you may not use this file except in compliance with the License.
39272 * You may obtain a copy of the License at
39273 *
39274 * http://www.apache.org/licenses/LICENSE-2.0
39275 *
39276 * Unless required by applicable law or agreed to in writing, software
39277 * distributed under the License is distributed on an "AS IS" BASIS,
39278 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39279 * See the License for the specific language governing permissions and
39280 * limitations under the License.
39281 * =============================================================================
39282 */
/**
 * Element-wise sine of the input Tensor: `sin(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
 *
 * x.sin().print();  // or tf.sin(x)
 * ```
 * @param x The input tensor. Converted with a `'float32'` dtype hint.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function sin_(x) {
  return ENGINE.runKernel(Sin, {
    x: convertToTensor(x, 'x', 'sin', 'float32')
  });
}
var sin$2 = /* @__PURE__ */op({ sin_: sin_ });
39305
39306 /**
39307 * @license
39308 * Copyright 2018 Google LLC. All Rights Reserved.
39309 * Licensed under the Apache License, Version 2.0 (the "License");
39310 * you may not use this file except in compliance with the License.
39311 * You may obtain a copy of the License at
39312 *
39313 * http://www.apache.org/licenses/LICENSE-2.0
39314 *
39315 * Unless required by applicable law or agreed to in writing, software
39316 * distributed under the License is distributed on an "AS IS" BASIS,
39317 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39318 * See the License for the specific language governing permissions and
39319 * limitations under the License.
39320 * =============================================================================
39321 */
/**
 * Element-wise hyperbolic sine of the input `tf.Tensor`: `sinh(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.sinh().print();  // or tf.sinh(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function sinh_(x) {
  return ENGINE.runKernel(Sinh, {
    x: convertToTensor(x, 'x', 'sinh')
  });
}
var sinh$2 = /* @__PURE__ */op({ sinh_: sinh_ });
39344
39345 /**
39346 * @license
39347 * Copyright 2018 Google LLC. All Rights Reserved.
39348 * Licensed under the Apache License, Version 2.0 (the "License");
39349 * you may not use this file except in compliance with the License.
39350 * You may obtain a copy of the License at
39351 *
39352 * http://www.apache.org/licenses/LICENSE-2.0
39353 *
39354 * Unless required by applicable law or agreed to in writing, software
39355 * distributed under the License is distributed on an "AS IS" BASIS,
39356 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39357 * See the License for the specific language governing permissions and
39358 * limitations under the License.
39359 * =============================================================================
39360 */
/**
 * Extracts a 1-D slice of length `size` from a rank-1 tensor, starting at
 * index `begin`. See `slice` for the full semantics.
 *
 * @param x The rank-1 input tensor (or tensor-like value).
 * @param begin Start index of the slice.
 * @param size Length of the slice.
 */
function slice1d_(x, begin, size) {
  var vector = convertToTensor(x, 'x', 'slice1d');
  var rank = vector.rank;
  assert$1(rank === 1, function () {
    return 'slice1d expects a rank-1 tensor, but got a rank-' + rank + ' tensor';
  });
  // The generic slice op takes per-axis arrays, so wrap the scalars.
  var begins = [begin];
  var sizes = [size];
  return slice$2(vector, begins, sizes);
}
var slice1d = /* @__PURE__ */op({ slice1d_: slice1d_ });
39375
39376 /**
39377 * @license
39378 * Copyright 2018 Google LLC. All Rights Reserved.
39379 * Licensed under the Apache License, Version 2.0 (the "License");
39380 * you may not use this file except in compliance with the License.
39381 * You may obtain a copy of the License at
39382 *
39383 * http://www.apache.org/licenses/LICENSE-2.0
39384 *
39385 * Unless required by applicable law or agreed to in writing, software
39386 * distributed under the License is distributed on an "AS IS" BASIS,
39387 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39388 * See the License for the specific language governing permissions and
39389 * limitations under the License.
39390 * =============================================================================
39391 */
/**
 * Extracts a 2-D slice of size `size` from a rank-2 tensor, starting at the
 * coordinates in `begin`. See `slice` for the full semantics.
 *
 * @param x The rank-2 input tensor (or tensor-like value).
 * @param begin Start coordinates of the slice, forwarded to `slice`.
 * @param size Extent of the slice, forwarded to `slice`.
 */
function slice2d_(x, begin, size) {
  var matrix = convertToTensor(x, 'x', 'slice2d');
  var rank = matrix.rank;
  assert$1(rank === 2, function () {
    return 'slice2d expects a rank-2 tensor, but got a rank-' + rank + ' tensor';
  });
  return slice$2(matrix, begin, size);
}
var slice2d = /* @__PURE__ */op({ slice2d_: slice2d_ });
39406
39407 /**
39408 * @license
39409 * Copyright 2018 Google LLC. All Rights Reserved.
39410 * Licensed under the Apache License, Version 2.0 (the "License");
39411 * you may not use this file except in compliance with the License.
39412 * You may obtain a copy of the License at
39413 *
39414 * http://www.apache.org/licenses/LICENSE-2.0
39415 *
39416 * Unless required by applicable law or agreed to in writing, software
39417 * distributed under the License is distributed on an "AS IS" BASIS,
39418 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39419 * See the License for the specific language governing permissions and
39420 * limitations under the License.
39421 * =============================================================================
39422 */
/**
 * Extracts a 3-D slice of size `size` from a rank-3 tensor, starting at the
 * coordinates in `begin`. See `slice` for the full semantics.
 *
 * @param x The rank-3 input tensor (or tensor-like value).
 * @param begin Start coordinates of the slice, forwarded to `slice`.
 * @param size Extent of the slice, forwarded to `slice`.
 */
function slice3d_(x, begin, size) {
  var volume = convertToTensor(x, 'x', 'slice3d');
  var rank = volume.rank;
  assert$1(rank === 3, function () {
    return 'slice3d expects a rank-3 tensor, but got a rank-' + rank + ' tensor';
  });
  return slice$2(volume, begin, size);
}
var slice3d = /* @__PURE__ */op({ slice3d_: slice3d_ });
39437
39438 /**
39439 * @license
39440 * Copyright 2018 Google LLC. All Rights Reserved.
39441 * Licensed under the Apache License, Version 2.0 (the "License");
39442 * you may not use this file except in compliance with the License.
39443 * You may obtain a copy of the License at
39444 *
39445 * http://www.apache.org/licenses/LICENSE-2.0
39446 *
39447 * Unless required by applicable law or agreed to in writing, software
39448 * distributed under the License is distributed on an "AS IS" BASIS,
39449 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39450 * See the License for the specific language governing permissions and
39451 * limitations under the License.
39452 * =============================================================================
39453 */
/**
 * Extracts a 4-D slice of size `size` from a rank-4 tensor, starting at the
 * coordinates in `begin`. See `slice` for the full semantics.
 *
 * @param x The rank-4 input tensor (or tensor-like value).
 * @param begin Start coordinates of the slice, forwarded to `slice`.
 * @param size Extent of the slice, forwarded to `slice`.
 */
function slice4d_(x, begin, size) {
  var input4d = convertToTensor(x, 'x', 'slice4d');
  var rank = input4d.rank;
  assert$1(rank === 4, function () {
    return 'slice4d expects a rank-4 tensor, but got a rank-' + rank + ' tensor';
  });
  return slice$2(input4d, begin, size);
}
var slice4d = /* @__PURE__ */op({ slice4d_: slice4d_ });
39468
39469 /**
39470 * @license
39471 * Copyright 2018 Google LLC. All Rights Reserved.
39472 * Licensed under the Apache License, Version 2.0 (the "License");
39473 * you may not use this file except in compliance with the License.
39474 * You may obtain a copy of the License at
39475 *
39476 * http://www.apache.org/licenses/LICENSE-2.0
39477 *
39478 * Unless required by applicable law or agreed to in writing, software
39479 * distributed under the License is distributed on an "AS IS" BASIS,
39480 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39481 * See the License for the specific language governing permissions and
39482 * limitations under the License.
39483 * =============================================================================
39484 */
/**
 * Computes the softmax-normalized vector for the given logits.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 *
 * a.softmax().print();  // or tf.softmax(a)
 * ```
 *
 * ```js
 * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
 *
 * a.softmax().print();  // or tf.softmax(a)
 * ```
 *
 * @param logits The logits array. Converted with a `'float32'` dtype hint.
 * @param dim The dimension to normalize over. Defaults to `-1`, meaning the
 *     last dimension. Only the last dimension is supported; any other value
 *     throws an Error.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function softmax_(logits) {
  // `dim` is an optional second argument; undefined falls back to -1.
  var requestedDim = arguments[1] === undefined ? -1 : arguments[1];
  var $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
  var lastAxis = $logits.rank - 1;
  var axis = requestedDim === -1 ? lastAxis : requestedDim;
  if (axis !== lastAxis) {
    throw new Error('Softmax along a non-last dimension is not yet supported. ' +
        'Logits was rank ' + $logits.rank + ' and dim was ' + axis);
  }
  return ENGINE.runKernel(Softmax$2, { logits: $logits }, { dim: axis });
}
var softmax$3 = /* @__PURE__ */op({ softmax_: softmax_ });
39526
39527 /**
39528 * @license
39529 * Copyright 2020 Google LLC. All Rights Reserved.
39530 * Licensed under the Apache License, Version 2.0 (the "License");
39531 * you may not use this file except in compliance with the License.
39532 * You may obtain a copy of the License at
39533 *
39534 * http://www.apache.org/licenses/LICENSE-2.0
39535 *
39536 * Unless required by applicable law or agreed to in writing, software
39537 * distributed under the License is distributed on an "AS IS" BASIS,
39538 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39539 * See the License for the specific language governing permissions and
39540 * limitations under the License.
39541 * =============================================================================
39542 */
/**
 * Fast Fourier transform.
 *
 * Computes the 1-dimensional discrete Fourier transform over the innermost
 * dimension of the input.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 * const imag = tf.tensor1d([1, 2, 3]);
 * const x = tf.complex(real, imag);
 *
 * x.fft().print();  // tf.spectral.fft(x).print();
 * ```
 * @param input The complex input to compute an fft over. Its dtype must be
 *     `complex64`; anything else fails the assertion.
 *
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function fft_(input) {
  var dtype = input.dtype;
  assert$1(dtype === 'complex64', function () {
    return 'The dtype for tf.spectral.fft() must be complex64 but got ' + dtype + '.';
  });
  return ENGINE.runKernel(FFT, { input: input });
}
var fft$2 = /* @__PURE__ */op({ fft_: fft_ });
39572
39573 /**
39574 * @license
39575 * Copyright 2020 Google LLC. All Rights Reserved.
39576 * Licensed under the Apache License, Version 2.0 (the "License");
39577 * you may not use this file except in compliance with the License.
39578 * You may obtain a copy of the License at
39579 *
39580 * http://www.apache.org/licenses/LICENSE-2.0
39581 *
39582 * Unless required by applicable law or agreed to in writing, software
39583 * distributed under the License is distributed on an "AS IS" BASIS,
39584 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39585 * See the License for the specific language governing permissions and
39586 * limitations under the License.
39587 * =============================================================================
39588 */
/**
 * Inverse fast Fourier transform.
 *
 * Computes the inverse 1-dimensional discrete Fourier transform over the
 * innermost dimension of the input.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 * const imag = tf.tensor1d([1, 2, 3]);
 * const x = tf.complex(real, imag);
 *
 * x.ifft().print();  // tf.spectral.ifft(x).print();
 * ```
 * @param input The complex input to compute an ifft over. Its dtype must be
 *     `complex64`; anything else fails the assertion.
 *
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function ifft_(input) {
  var dtype = input.dtype;
  assert$1(dtype === 'complex64', function () {
    return 'The dtype for tf.spectral.ifft() must be complex64 but got ' + dtype + '.';
  });
  return ENGINE.runKernel(IFFT, { input: input });
}
var ifft$2 = /* @__PURE__ */op({ ifft_: ifft_ });
39618
39619 /**
39620 * @license
39621 * Copyright 2018 Google LLC. All Rights Reserved.
39622 * Licensed under the Apache License, Version 2.0 (the "License");
39623 * you may not use this file except in compliance with the License.
39624 * You may obtain a copy of the License at
39625 *
39626 * http://www.apache.org/licenses/LICENSE-2.0
39627 *
39628 * Unless required by applicable law or agreed to in writing, software
39629 * distributed under the License is distributed on an "AS IS" BASIS,
39630 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39631 * See the License for the specific language governing permissions and
39632 * limitations under the License.
39633 * =============================================================================
39634 */
/**
 * Inverse real-input fast Fourier transform.
 *
 * Computes the 1-dimensional inverse discrete Fourier transform over the
 * innermost dimension of the input and returns only its real part.
 *
 * When the innermost dimension has length n > 2, the input is treated as the
 * non-redundant half of a spectrum. It is extended to 2 * (n - 1) entries by
 * appending the interior bins (indices 1 .. n - 2) in reverse order with
 * negated imaginary parts (their complex conjugates) before the ifft. When
 * n <= 2 the input is passed to ifft directly.
 *
 * The result is 2-D, `[batch, outputLength]`, where batch is the product of
 * all leading dimensions. If the input is rank 3 with a non-zero first
 * dimension, the result is reshaped back to rank 3.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 * const imag = tf.tensor1d([0, 0, 0]);
 * const x = tf.complex(real, imag);
 *
 * x.irfft().print();
 * ```
 * @param input The half-spectrum input; its real and imaginary parts are
 *     read via `real` / `imag`.
 *
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function irfft_(input) {
  var lastDim = input.shape[input.shape.length - 1];
  var rows = input.size / lastDim;
  var spectrum;
  if (lastDim <= 2) {
    spectrum = ifft$2(reshape$3(input, [rows, lastDim]));
  } else {
    // A real signal of this spectrum has 2 * (lastDim - 1) samples.
    var fullLength = 2 * (lastDim - 1);
    var re = reshape$3(real$2(input), [rows, lastDim]);
    var im = reshape$3(imag$2(input), [rows, lastDim]);
    // Mirror the interior bins; conjugation flips the imaginary sign.
    var reMirror = reverse$2(slice$2(re, [0, 1], [rows, lastDim - 2]), 1);
    var imMirror = mul(reverse$2(slice$2(im, [0, 1], [rows, lastDim - 2]), 1), scalar(-1));
    var fullRe = concat$2([re, reMirror], 1);
    var fullIm = concat$2([im, imMirror], 1);
    spectrum = ifft$2(reshape$3(complex$2(fullRe, fullIm), [rows, fullLength]));
  }
  var result = real$2(spectrum);
  if (input.rank === 3 && input.shape[0] !== 0) {
    // Restore the leading dimension of a rank-3 input.
    var flat = result;
    var outer = input.shape[0];
    result = reshape$3(flat, [outer, flat.shape[0] / outer, flat.shape[1]]);
    flat.dispose();
  }
  return result;
}
var irfft = /* @__PURE__ */op({ irfft_: irfft_ });
39685
39686 /**
39687 * @license
39688 * Copyright 2020 Google LLC. All Rights Reserved.
39689 * Licensed under the Apache License, Version 2.0 (the "License");
39690 * you may not use this file except in compliance with the License.
39691 * You may obtain a copy of the License at
39692 *
39693 * http://www.apache.org/licenses/LICENSE-2.0
39694 *
39695 * Unless required by applicable law or agreed to in writing, software
39696 * distributed under the License is distributed on an "AS IS" BASIS,
39697 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39698 * See the License for the specific language governing permissions and
39699 * limitations under the License.
39700 * =============================================================================
39701 */
/**
 * Splits a `tf.Tensor` into sub tensors.
 *
 * If `numOrSizeSplits` is a number, `x` is split along dimension `axis`
 * into that many equally sized tensors, so it must evenly divide
 * `x.shape[axis]`.
 *
 * If `numOrSizeSplits` is a number array, `x` is split into
 * `numOrSizeSplits.length` pieces. The `i`-th piece matches `x` in every
 * dimension except `axis`, where its size is `numOrSizeSplits[i]`.
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
 * const [a, b] = tf.split(x, 2, 1);
 * a.print();
 * b.print();
 *
 * const [c, d, e] = tf.split(x, [1, 2, 1], 1);
 * c.print();
 * d.print();
 * e.print();
 * ```
 *
 * @param x The input tensor to split.
 * @param numOrSizeSplits Either an integer giving the number of splits
 *     along the axis, or an array of integers giving each output's size
 *     along the axis. A number must evenly divide `x.shape[axis]`; an array
 *     must sum to `x.shape[axis]` and may contain one -1 whose size is
 *     inferred.
 * @param axis The dimension along which to split. Defaults to 0 (the first
 *     dim).
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
function split_(x, numOrSizeSplits) {
  // `axis` is an optional third argument; undefined falls back to 0.
  var axis = arguments[2] === undefined ? 0 : arguments[2];
  var attrs = {
    numOrSizeSplits: numOrSizeSplits,
    axis: axis
  };
  return ENGINE.runKernel(SplitV, { x: convertToTensor(x, 'x', 'split') }, attrs);
}
var split$3 = /* @__PURE__ */op({ split_: split_ });
39752
39753 /**
39754 * @license
39755 * Copyright 2018 Google LLC. All Rights Reserved.
39756 * Licensed under the Apache License, Version 2.0 (the "License");
39757 * you may not use this file except in compliance with the License.
39758 * You may obtain a copy of the License at
39759 *
39760 * http://www.apache.org/licenses/LICENSE-2.0
39761 *
39762 * Unless required by applicable law or agreed to in writing, software
39763 * distributed under the License is distributed on an "AS IS" BASIS,
39764 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39765 * See the License for the specific language governing permissions and
39766 * limitations under the License.
39767 * =============================================================================
39768 */
/**
 * Real value input fast Fourier transform.
 *
 * Computes the 1-dimensional discrete Fourier transform over the
 * inner-most dimension of the real input.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 *
 * real.rfft().print();
 * ```
 * @param input The real value input to compute an rfft over. Its dtype must
 *     be 'float32'.
 * @param fftLength Optional transform length along the inner-most dimension.
 *     If it is smaller than that dimension, the input is cropped. If it is
 *     larger, the input is zero-padded. If it is omitted, the input is used
 *     as-is.
 * @returns A complex tensor shaped like the (cropped or padded) input, except
 *     the inner-most dimension is `floor(n / 2) + 1`, where `n` is the
 *     transform length.
 *
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function rfft_(input, fftLength) {
  assert$1(input.dtype === 'float32', function () {
    return "The dtype for rfft() must be real value but got ".concat(input.dtype);
  });
  var lastAxis = input.shape.length - 1;
  var n = input.shape[lastAxis];
  // The batch count comes from the unmodified input; cropping or padding
  // below only changes the inner-most dimension.
  var batch = input.size / n;
  var signal = input;
  if (fftLength != null) {
    if (fftLength < n) {
      // Crop the inner-most dimension down to fftLength.
      var begin = input.shape.map(function () {
        return 0;
      });
      var size = input.shape.slice();
      size[lastAxis] = fftLength;
      signal = slice$2(input, begin, size);
      n = fftLength;
    } else if (fftLength > n) {
      // Zero-pad the inner-most dimension up to fftLength.
      var padShape = input.shape.slice();
      padShape[lastAxis] = fftLength - n;
      signal = concat$2([input, zeros$2(padShape)], lastAxis);
      n = fftLength;
    }
  }
  // Treat the real signal as complex values with a zero imaginary part.
  var complexSignal = reshape$3(complex$2(signal, zerosLike$3(signal)), [batch, n]);
  var spectrum = fft$2(complexSignal);
  // The remaining bins are symmetric complex conjugates of these, so only the
  // first floor(n / 2) + 1 are kept.
  var keep = Math.floor(n / 2) + 1;
  var re = real$2(spectrum);
  var im = imag$2(spectrum);
  // Separate size arrays on purpose: the split kernel may rewrite a -1 entry.
  var reHead = split$3(re, [keep, n - keep], re.shape.length - 1)[0];
  var imHead = split$3(im, [keep, n - keep], im.shape.length - 1)[0];
  var outputShape = signal.shape.slice();
  outputShape[signal.shape.length - 1] = keep;
  return reshape$3(complex$2(reHead, imHead), outputShape);
}
var rfft = /* @__PURE__ */op({ rfft_: rfft_ });
39830
/**
 * Returns (a - b) * (a - b) element-wise.
 * Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
 * ```
 *
 * ```js
 * // Broadcast squared difference a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
 * @returns The output of the `SquaredDifference` kernel.
 *
 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
 */
function squaredDifference_(a, b) {
  // Convert both operands, then reconcile their dtypes before the kernel runs.
  var matched = _slicedToArray(makeTypesMatch(convertToTensor(a, 'a', 'squaredDifference'), convertToTensor(b, 'b', 'squaredDifference')), 2);
  var $a = matched[0];
  var $b = matched[1];
  // Validates that the shapes broadcast; the returned shape is not needed here.
  assertAndGetBroadcastShape($a.shape, $b.shape);
  return ENGINE.runKernel(SquaredDifference, {
    a: $a,
    b: $b
  }, {});
}
var squaredDifference$2 = /* @__PURE__ */op({ squaredDifference_: squaredDifference_ });
39873
39874 /**
39875 * @license
39876 * Copyright 2020 Google LLC. All Rights Reserved.
39877 * Licensed under the Apache License, Version 2.0 (the "License");
39878 * you may not use this file except in compliance with the License.
39879 * You may obtain a copy of the License at
39880 *
39881 * http://www.apache.org/licenses/LICENSE-2.0
39882 *
39883 * Unless required by applicable law or agreed to in writing, software
39884 * distributed under the License is distributed on an "AS IS" BASIS,
39885 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39886 * See the License for the specific language governing permissions and
39887 * limitations under the License.
39888 * =============================================================================
39889 */
/**
 * Removes dimensions of size 1 from the shape of a `tf.Tensor`.
 *
 * ```js
 * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
 * x.squeeze().print();
 * ```
 *
 * @param x The input tensor to be squeezed (string or numeric).
 * @param axis An optional list of numbers. If given, only the listed
 *     dimensions (0-based) are squeezed. Squeezing a dimension that is not 1
 *     is an error.
 * @returns `x` reshaped to the `newShape` computed by `squeezeShape`.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function squeeze_(x, axis) {
  var $x = convertToTensor(x, 'x', 'squeeze', 'string_or_numeric');
  var squeezed = squeezeShape($x.shape, axis);
  return reshape$3($x, squeezed.newShape);
}
var squeeze = /* @__PURE__ */op({ squeeze_: squeeze_ });
39912
39913 /**
39914 * @license
39915 * Copyright 2020 Google LLC. All Rights Reserved.
39916 * Licensed under the Apache License, Version 2.0 (the "License");
39917 * you may not use this file except in compliance with the License.
39918 * You may obtain a copy of the License at
39919 *
39920 * http://www.apache.org/licenses/LICENSE-2.0
39921 *
39922 * Unless required by applicable law or agreed to in writing, software
39923 * distributed under the License is distributed on an "AS IS" BASIS,
39924 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39925 * See the License for the specific language governing permissions and
39926 * limitations under the License.
39927 * =============================================================================
39928 */
/**
 * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * const c = tf.tensor1d([5, 6]);
 * tf.stack([a, b, c]).print();
 * ```
 *
 * @param tensors A list of tensor objects with the same shape and dtype.
 *     At least one tensor is required.
 * @param axis (second argument, optional) The axis to stack along; must be
 *     <= the rank of the first tensor. Defaults to 0 when omitted or
 *     `undefined`.
 * @returns The output of the `Pack` kernel.
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
function stack_(tensors) {
  var axis = arguments[1] === undefined ? 0 : arguments[1];
  var $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric');
  var count = $tensors.length;
  assert$1(count >= 1, function () {
    return 'Pass at least one tensor to tf.stack';
  });
  // Guard ensures `$tensors[0]` is only read when the list is non-empty.
  if (count > 0) {
    var firstRank = $tensors[0].rank;
    assert$1(axis <= firstRank, function () {
      return 'Axis must be <= rank of the tensor';
    });
  }
  return ENGINE.runKernel(Pack, $tensors, {
    axis: axis
  });
}
var stack = /* @__PURE__ */op({ stack_: stack_ });
39964
39965 /**
39966 * @license
39967 * Copyright 2018 Google LLC. All Rights Reserved.
39968 * Licensed under the Apache License, Version 2.0 (the "License");
39969 * you may not use this file except in compliance with the License.
39970 * You may obtain a copy of the License at
39971 *
39972 * http://www.apache.org/licenses/LICENSE-2.0
39973 *
39974 * Unless required by applicable law or agreed to in writing, software
39975 * distributed under the License is distributed on an "AS IS" BASIS,
39976 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39977 * See the License for the specific language governing permissions and
39978 * limitations under the License.
39979 * =============================================================================
39980 */
/**
 * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha`
 *
 * ```js
 * const x = tf.tensor1d([0, 2, -1, -3]);
 *
 * x.step(.5).print(); // or tf.step(x, .5)
 * ```
 * @param x The input tensor.
 * @param alpha (second argument, optional) The value used where the formula
 *     above does not yield 1. Defaults to 0 when omitted or `undefined`.
 * @returns The output of the `Step` kernel.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function step_(x) {
  var alpha = arguments[1] === undefined ? 0.0 : arguments[1];
  return ENGINE.runKernel(Step, {
    x: convertToTensor(x, 'x', 'step')
  }, {
    alpha: alpha
  });
}
var step$2 = /* @__PURE__ */op({ step_: step_ });
40008
40009 /**
40010 * @license
40011 * Copyright 2018 Google LLC. All Rights Reserved.
40012 * Licensed under the Apache License, Version 2.0 (the "License");
40013 * you may not use this file except in compliance with the License.
40014 * You may obtain a copy of the License at
40015 *
40016 * http://www.apache.org/licenses/LICENSE-2.0
40017 *
40018 * Unless required by applicable law or agreed to in writing, software
40019 * distributed under the License is distributed on an "AS IS" BASIS,
40020 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40021 * See the License for the specific language governing permissions and
40022 * limitations under the License.
40023 * =============================================================================
40024 */
/**
 * Extracts a strided slice of a tensor.
 *
 * Roughly speaking, this op extracts a slice of size (end-begin)/stride from
 * the given input tensor (x). Starting at the location specified by begin, the
 * slice continues by adding stride to the index until all dimensions are not
 * less than end. A stride can be negative, which causes a reverse slice.
 *
 * ```js
 * const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
 *    [3, 2, 3]);
 * t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print()  // [[[3, 3, 3]]]
 * t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print()  // [[[3, 3, 3],
 *                                                     // [4, 4, 4]]]
 * t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4],
 *                                                     // [3, 3, 3]]]
 * ```
 *
 * @param x The tensor to stride slice (string or numeric).
 * @param begin The coordinates to start the slice from.
 * @param end The coordinates to end the slice at.
 * @param strides The step added to the index in each dimension while slicing.
 * @param beginMask If the ith bit of beginMask is set, begin[i] is ignored
 *     and the fullest possible range in that dimension is used instead.
 *     Defaults to 0.
 * @param endMask If the ith bit of endMask is set, end[i] is ignored
 *     and the fullest possible range in that dimension is used instead.
 *     Defaults to 0.
 * @param ellipsisMask Bitmask forwarded to the `StridedSlice` kernel.
 *     Presumably follows TensorFlow's `strided_slice` semantics (bit i marks
 *     an ellipsis); confirm against the kernel. Defaults to 0.
 * @param newAxisMask Bitmask forwarded to the `StridedSlice` kernel.
 *     Presumably follows TensorFlow's `strided_slice` semantics (bit i inserts
 *     a new size-1 dimension); confirm against the kernel. Defaults to 0.
 * @param shrinkAxisMask A bitmask where bit i means the ith specification
 *     should shrink the dimensionality. begin and end must imply a slice of
 *     size 1 in that dimension. Defaults to 0.
 * @returns The output of the `StridedSlice` kernel.
 *
 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
 */
function stridedSlice_(x, begin, end, strides) {
  var callArgs = arguments;
  // The five optional bitmask arguments (positions 4-8) default to 0.
  var maskOrZero = function (position) {
    return callArgs[position] === undefined ? 0 : callArgs[position];
  };
  var attrs = {
    begin: begin,
    end: end,
    strides: strides,
    beginMask: maskOrZero(4),
    endMask: maskOrZero(5),
    ellipsisMask: maskOrZero(6),
    newAxisMask: maskOrZero(7),
    shrinkAxisMask: maskOrZero(8)
  };
  var $x = convertToTensor(x, 'x', 'stridedSlice', 'string_or_numeric');
  return ENGINE.runKernel(StridedSlice, {
    x: $x
  }, attrs);
}
var stridedSlice$2 = /* @__PURE__ */op({ stridedSlice_: stridedSlice_ });
40083
40084 /**
40085 * @license
40086 * Copyright 2018 Google LLC. All Rights Reserved.
40087 * Licensed under the Apache License, Version 2.0 (the "License");
40088 * you may not use this file except in compliance with the License.
40089 * You may obtain a copy of the License at
40090 *
40091 * http://www.apache.org/licenses/LICENSE-2.0
40092 *
40093 * Unless required by applicable law or agreed to in writing, software
40094 * distributed under the License is distributed on an "AS IS" BASIS,
40095 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40096 * See the License for the specific language governing permissions and
40097 * limitations under the License.
40098 * =============================================================================
40099 */
/**
 * Computes tan of the input `tf.Tensor` element-wise, `tan(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
 *
 * x.tan().print();  // or tf.tan(x)
 * ```
 * @param x The input tensor (converted with dtype hint 'float32').
 * @returns The output of the `Tan` kernel.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function tan_(x) {
  var inputs = { x: convertToTensor(x, 'x', 'tan', 'float32') };
  // No attrs argument: the kernel is invoked with inputs only.
  return ENGINE.runKernel(Tan, inputs);
}
var tan$2 = /* @__PURE__ */op({ tan_: tan_ });
40122
40123 /**
40124 * @license
40125 * Copyright 2018 Google LLC. All Rights Reserved.
40126 * Licensed under the Apache License, Version 2.0 (the "License");
40127 * you may not use this file except in compliance with the License.
40128 * You may obtain a copy of the License at
40129 *
40130 * http://www.apache.org/licenses/LICENSE-2.0
40131 *
40132 * Unless required by applicable law or agreed to in writing, software
40133 * distributed under the License is distributed on an "AS IS" BASIS,
40134 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40135 * See the License for the specific language governing permissions and
40136 * limitations under the License.
40137 * =============================================================================
40138 */
/**
 * Creates a rank-1 `tf.Tensor` from the provided values and dtype. The shape
 * is always inferred from `values`.
 *
 * The same functionality can be achieved with `tf.tensor`, but in general
 * we recommend using `tf.tensor1d` as it makes the code more readable.
 *
 * ```js
 * tf.tensor1d([1, 2, 3]).print();
 * ```
 *
 * @param values The values of the tensor. Can be array of numbers,
 *     or a `TypedArray`.
 * @param dtype The data type.
 * @throws Error if `values` does not infer to a rank-1 shape.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function tensor1d(values, dtype) {
  assertNonNull(values);
  var inferredShape = inferShape(values, dtype);
  if (inferredShape.length === 1) {
    // No explicit shape: makeTensor works from the inferred one.
    return makeTensor(values, null, inferredShape, dtype);
  }
  throw new Error('tensor1d() requires values to be a flat/TypedArray');
}
40164
40165 /**
40166 * @license
40167 * Copyright 2018 Google LLC. All Rights Reserved.
40168 * Licensed under the Apache License, Version 2.0 (the "License");
40169 * you may not use this file except in compliance with the License.
40170 * You may obtain a copy of the License at
40171 *
40172 * http://www.apache.org/licenses/LICENSE-2.0
40173 *
40174 * Unless required by applicable law or agreed to in writing, software
40175 * distributed under the License is distributed on an "AS IS" BASIS,
40176 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40177 * See the License for the specific language governing permissions and
40178 * limitations under the License.
40179 * =============================================================================
40180 */
/**
 * Creates a rank-2 `tf.Tensor` with the provided values, shape and dtype.
 *
 * The same functionality can be achieved with `tf.tensor`, but in general
 * we recommend using `tf.tensor2d` as it makes the code more readable.
 *
 * ```js
 * // Pass a nested array.
 * tf.tensor2d([[1, 2], [3, 4]]).print();
 * ```
 * ```js
 * // Pass a flat array and specify a shape.
 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print();
 * ```
 *
 * @param values The values of the tensor. Can be nested array of numbers,
 *     or a flat array, or a `TypedArray`.
 * @param shape The shape of the tensor (two numbers). Required when `values`
 *     is flat; otherwise inferred from `values`.
 * @param dtype The data type.
 * @throws Error if `shape` is not of length 2, if `values` is neither rank 2
 *     nor flat, or if `values` is flat and no shape is given.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function tensor2d(values, shape, dtype) {
  assertNonNull(values);
  var shapeGiven = shape != null;
  if (shapeGiven && shape.length !== 2) {
    throw new Error('tensor2d() requires shape to have two numbers');
  }
  var inferredShape = inferShape(values, dtype);
  var inferredRank = inferredShape.length;
  var isFlat = inferredRank === 1;
  if (!isFlat && inferredRank !== 2) {
    throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
  }
  if (isFlat && !shapeGiven) {
    throw new Error('tensor2d() requires shape to be provided when `values` are a flat/TypedArray');
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
40218
40219 /**
40220 * @license
40221 * Copyright 2018 Google LLC. All Rights Reserved.
40222 * Licensed under the Apache License, Version 2.0 (the "License");
40223 * you may not use this file except in compliance with the License.
40224 * You may obtain a copy of the License at
40225 *
40226 * http://www.apache.org/licenses/LICENSE-2.0
40227 *
40228 * Unless required by applicable law or agreed to in writing, software
40229 * distributed under the License is distributed on an "AS IS" BASIS,
40230 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40231 * See the License for the specific language governing permissions and
40232 * limitations under the License.
40233 * =============================================================================
40234 */
/**
 * Creates a rank-3 `tf.Tensor` with the provided values, shape and dtype.
 *
 * The same functionality can be achieved with `tf.tensor`, but in general
 * we recommend using `tf.tensor3d` as it makes the code more readable.
 *
 * ```js
 * // Pass a nested array.
 * tf.tensor3d([[[1], [2]], [[3], [4]]]).print();
 * ```
 * ```js
 * // Pass a flat array and specify a shape.
 * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();
 * ```
 *
 * @param values The values of the tensor. Can be nested array of numbers,
 *     or a flat array, or a `TypedArray`.
 * @param shape The shape of the tensor (three numbers). Required when
 *     `values` is flat; otherwise inferred from `values`.
 * @param dtype The data type.
 * @throws Error if `shape` is not of length 3, if `values` is neither rank 3
 *     nor flat, or if `values` is flat and no shape is given.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function tensor3d(values, shape, dtype) {
  assertNonNull(values);
  var shapeGiven = shape != null;
  if (shapeGiven && shape.length !== 3) {
    throw new Error('tensor3d() requires shape to have three numbers');
  }
  var inferredShape = inferShape(values, dtype);
  var inferredRank = inferredShape.length;
  var isFlat = inferredRank === 1;
  if (!isFlat && inferredRank !== 3) {
    throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray');
  }
  if (isFlat && !shapeGiven) {
    throw new Error('tensor3d() requires shape to be provided when `values` are a flat array');
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
40272
40273 /**
40274 * @license
40275 * Copyright 2018 Google LLC. All Rights Reserved.
40276 * Licensed under the Apache License, Version 2.0 (the "License");
40277 * you may not use this file except in compliance with the License.
40278 * You may obtain a copy of the License at
40279 *
40280 * http://www.apache.org/licenses/LICENSE-2.0
40281 *
40282 * Unless required by applicable law or agreed to in writing, software
40283 * distributed under the License is distributed on an "AS IS" BASIS,
40284 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40285 * See the License for the specific language governing permissions and
40286 * limitations under the License.
40287 * =============================================================================
40288 */
/**
 * Creates a rank-4 `tf.Tensor` with the provided values, shape and dtype.
 *
 * The same functionality can be achieved with `tf.tensor`, but in general
 * we recommend using `tf.tensor4d` as it makes the code more readable.
 *
 * ```js
 * // Pass a nested array.
 * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();
 * ```
 * ```js
 * // Pass a flat array and specify a shape.
 * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();
 * ```
 *
 * @param values The values of the tensor. Can be nested array of numbers,
 *     or a flat array, or a `TypedArray`.
 * @param shape The shape of the tensor (four numbers). Optional unless
 *     `values` is flat; otherwise inferred from `values`.
 * @param dtype The data type.
 * @throws Error if `shape` is not of length 4, if `values` is neither rank 4
 *     nor flat, or if `values` is flat and no shape is given.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function tensor4d(values, shape, dtype) {
  assertNonNull(values);
  var shapeGiven = shape != null;
  if (shapeGiven && shape.length !== 4) {
    throw new Error('tensor4d() requires shape to have four numbers');
  }
  var inferredShape = inferShape(values, dtype);
  var inferredRank = inferredShape.length;
  var isFlat = inferredRank === 1;
  if (!isFlat && inferredRank !== 4) {
    throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray');
  }
  if (isFlat && !shapeGiven) {
    throw new Error('tensor4d() requires shape to be provided when `values` are a flat array');
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
40326
40327 /**
40328 * @license
40329 * Copyright 2018 Google LLC. All Rights Reserved.
40330 * Licensed under the Apache License, Version 2.0 (the "License");
40331 * you may not use this file except in compliance with the License.
40332 * You may obtain a copy of the License at
40333 *
40334 * http://www.apache.org/licenses/LICENSE-2.0
40335 *
40336 * Unless required by applicable law or agreed to in writing, software
40337 * distributed under the License is distributed on an "AS IS" BASIS,
40338 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40339 * See the License for the specific language governing permissions and
40340 * limitations under the License.
40341 * =============================================================================
40342 */
/**
 * Creates a rank-5 `tf.Tensor` with the provided values, shape and dtype.
 *
 * The same functionality can be achieved with `tf.tensor`, but in general
 * we recommend using `tf.tensor5d` as it makes the code more readable.
 *
 * ```js
 * // Pass a nested array.
 * tf.tensor5d([[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]).print();
 * ```
 * ```js
 * // Pass a flat array and specify a shape.
 * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();
 * ```
 *
 * @param values The values of the tensor. Can be nested array of numbers,
 *     or a flat array, or a `TypedArray`.
 * @param shape The shape of the tensor (five numbers). Optional unless
 *     `values` is flat; otherwise inferred from `values`.
 * @param dtype The data type.
 * @throws Error if `shape` is not of length 5, if `values` is neither rank 5
 *     nor flat, or if `values` is flat and no shape is given.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function tensor5d(values, shape, dtype) {
  assertNonNull(values);
  var shapeGiven = shape != null;
  if (shapeGiven && shape.length !== 5) {
    throw new Error('tensor5d() requires shape to have five numbers');
  }
  var inferredShape = inferShape(values, dtype);
  var inferredRank = inferredShape.length;
  var isFlat = inferredRank === 1;
  if (!isFlat && inferredRank !== 5) {
    throw new Error('tensor5d() requires values to be number[][][][][] or flat/TypedArray');
  }
  if (isFlat && !shapeGiven) {
    throw new Error('tensor5d() requires shape to be provided when `values` are a flat array');
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
40380
40381 /**
40382 * @license
40383 * Copyright 2018 Google LLC. All Rights Reserved.
40384 * Licensed under the Apache License, Version 2.0 (the "License");
40385 * you may not use this file except in compliance with the License.
40386 * You may obtain a copy of the License at
40387 *
40388 * http://www.apache.org/licenses/LICENSE-2.0
40389 *
40390 * Unless required by applicable law or agreed to in writing, software
40391 * distributed under the License is distributed on an "AS IS" BASIS,
40392 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40393 * See the License for the specific language governing permissions and
40394 * limitations under the License.
40395 * =============================================================================
40396 */
/**
 * Creates a rank-6 `tf.Tensor` with the provided values, shape and dtype.
 *
 * The same functionality can be achieved with `tf.tensor`, but in general
 * we recommend using `tf.tensor6d` as it makes the code more readable.
 *
 * ```js
 * // Pass a nested array.
 * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();
 * ```
 * ```js
 * // Pass a flat array and specify a shape.
 * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();
 * ```
 *
 * @param values The values of the tensor. Can be nested array of numbers,
 *     or a flat array, or a `TypedArray`.
 * @param shape The shape of the tensor (six numbers). Optional unless
 *     `values` is flat; otherwise inferred from `values`.
 * @param dtype The data type.
 * @throws Error if `shape` is not of length 6, if `values` is neither rank 6
 *     nor flat, or if `values` is flat and no shape is given.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function tensor6d(values, shape, dtype) {
  assertNonNull(values);
  var shapeGiven = shape != null;
  if (shapeGiven && shape.length !== 6) {
    throw new Error('tensor6d() requires shape to have six numbers');
  }
  var inferredShape = inferShape(values, dtype);
  var inferredRank = inferredShape.length;
  var isFlat = inferredRank === 1;
  if (!isFlat && inferredRank !== 6) {
    throw new Error('tensor6d() requires values to be number[][][][][][] or flat/TypedArray');
  }
  if (isFlat && !shapeGiven) {
    throw new Error('tensor6d() requires shape to be provided when `values` are a flat array');
  }
  // Unlike the lower-rank creators, the inferred shape is passed explicitly
  // when no shape was given.
  return makeTensor(values, shape || inferredShape, inferredShape, dtype);
}
40435
/**
 * Checks that `updates.shape` equals `indices.shape[:batchDim] +
 * shape[sliceDim:]` (Python slice notation). The leading `batchDim`
 * dimensions of `updates` must match `indices`, and its remaining dimensions
 * must match the trailing dimensions of the output `shape`.
 *
 * @param shape The shape of the output tensor (array of numbers).
 * @param indices The indices tensor; only its `rank` and `shape` are read.
 * @param updates The updates tensor; only its `rank` and `shape` are read.
 * @throws Error describing the first mismatch found.
 */
function validateUpdateShape(shape, indices, updates) {
  // For indices of rank > 1, the last dimension is the number of output
  // coordinates each index addresses. Rank-1 indices address one coordinate.
  var sliceDim = indices.rank > 1 ? indices.shape[indices.rank - 1] : 1;
  var batchDim = indices.rank > 1 ? indices.rank - 1 : 1;
  var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' + "shape[sliceDim:], got updates.shape: ".concat(updates.shape) + ", indices.shape: ".concat(indices.shape, ", shape: ").concat(shape) + ", sliceDim: ".concat(sliceDim, ", and batchDim: ").concat(batchDim, ".");
  if (updates.rank < batchDim) {
    throw new Error(shapeError + " update.rank < ".concat(batchDim, ". "));
  }
  if (shape.length < sliceDim + (updates.rank - batchDim)) {
    throw new Error(shapeError + " Output shape length < ".concat(sliceDim + (updates.rank - batchDim)));
  }
  if (updates.rank !== batchDim + shape.length - sliceDim) {
    throw new Error(shapeError + " update.rank != ".concat(batchDim + shape.length - sliceDim));
  }
  for (var d = 0; d < batchDim; ++d) {
    if (updates.shape[d] !== indices.shape[d]) {
      throw new Error(shapeError + " updates.shape[".concat(d, "] (").concat(updates.shape[d], ") != indices.shape[").concat(d, "] (").concat(indices.shape[d], ")."));
    }
  }
  for (var k = 0; k < updates.rank - batchDim; ++k) {
    var updateDim = k + batchDim;
    var outputDim = k + sliceDim;
    if (updates.shape[updateDim] !== shape[outputDim]) {
      // Report the output index that was actually compared (offset by
      // sliceDim, not batchDim) so the message matches the failing check.
      throw new Error(shapeError + " updates.shape[".concat(updateDim, "] (").concat(updates.shape[updateDim], ") != shape[").concat(outputDim, "] (").concat(shape[outputDim], ")"));
    }
  }
}
/**
 * Validates the inputs of a scatterND-style operation.
 *
 * @param updates The tensor containing the update values; must have rank >= 1.
 * @param indices The tensor containing the indices for the update values;
 *     must have rank >= 1 and dtype 'int32'.
 * @param shape The shape of the output tensor; must have at least one
 *     dimension.
 * @throws Error describing the first violated constraint, including any
 *     shape mismatch reported by `validateUpdateShape`.
 */
function validateInput$1(updates, indices, shape) {
  if (indices.rank < 1) {
    throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' + " but the rank was ".concat(indices.rank, "."));
  }
  if (updates.rank < 1) {
    throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' + " but the rank was ".concat(updates.rank, "."));
  }
  if (indices.dtype !== 'int32') {
    throw new Error("The dtype of 'indices' should be int32, but got dtype: ".concat(indices.dtype));
  }
  // Rank-0 outputs are rejected here, so every later check can rely on
  // shape.length >= 1.
  if (shape.length < 1) {
    throw new Error("Output rank must be greater or equal to 1, but got shape: ".concat(shape));
  }
  validateUpdateShape(shape, indices, updates);
}
/**
 * Computes the shape bookkeeping used by scatterND-style kernels.
 *
 * @param updates The tensor containing the update values (not read here).
 * @param indices The tensor containing the indices; only `indices.shape` is
 *     read.
 * @param shape The shape of the output tensor.
 * @returns An object with:
 *     - `sliceRank`: the last dim of `indices.shape` when indices have rank
 *       > 1, otherwise 1;
 *     - `numUpdates`: the element count of `indices` divided by
 *       `max(1, sliceRank)`;
 *     - `sliceSize`: the product of `shape[sliceRank:]`;
 *     - `strides`: the strides of `shape[:sliceRank]` followed by a final 1;
 *     - `outputSize`: the element count of `shape`.
 */
function calculateShapes(updates, indices, shape) {
  var indexShape = indices.shape;
  var indexRank = indexShape.length;
  var sliceRank = indexRank > 1 ? indexShape[indexRank - 1] : 1;
  // Number of elements in each slice copied into the (flattened) output, so
  // whole slices can be moved at once.
  var sliceSize = shape.slice(sliceRank).reduce(function (product, dim) {
    return product * dim;
  }, 1);
  var numUpdates = sizeFromShape(indexShape) / Math.max(1, sliceRank);
  var strides = _toConsumableArray(computeStrides(shape.slice(0, sliceRank))).concat(1);
  return {
    sliceRank: sliceRank,
    numUpdates: numUpdates,
    sliceSize: sliceSize,
    strides: strides,
    outputSize: sizeFromShape(shape)
  };
}
40529
40530 var scatter_nd_util = {
40531 __proto__: null,
40532 calculateShapes: calculateShapes,
40533 validateInput: validateInput$1,
40534 validateUpdateShape: validateUpdateShape
40535 };
40536
40537 /**
40538 * @license
40539 * Copyright 2022 Google LLC. All Rights Reserved.
40540 * Licensed under the Apache License, Version 2.0 (the "License");
40541 * you may not use this file except in compliance with the License.
40542 * You may obtain a copy of the License at
40543 *
40544 * http://www.apache.org/licenses/LICENSE-2.0
40545 *
40546 * Unless required by applicable law or agreed to in writing, software
40547 * distributed under the License is distributed on an "AS IS" BASIS,
40548 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40549 * See the License for the specific language governing permissions and
40550 * limitations under the License.
40551 * =============================================================================
40552 */
40553 /**
40554 * Creates a new tensor by applying sparse updates to individual
40555 * values or slices to the passed in tensor according to
40556 * indices. This operator is the similar to scatterNd op, except that the
40557 * udpates are scattered on an existing tensor (as opposed to a zero-tensor).
40558 *
40559 * If indices contains duplicates, then we pick the last update for the index.
40560 *
40561 * If an out of bound index is found on CPU, an error is returned.
40562 *
40563 * Warning: There are some GPU specific semantics for this operation.
40564 * - If an out of bound index is found, the index is ignored.
40565 * - The order in which updates are applied is nondeterministic, so the output
40566 * will be nondeterministic if indices contains duplicates.
40567 * ```js
40568 * const shape = [8];
40569 * const tensor = tf.ones(shape);
40570 * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
40571 * const updates = tf.tensor1d([9, 10, 11, 12]);
40572 *
40573 * tf.tensorScatterUpdate(tensor, indices, updates).print();
40574 * //[1, 11, 1, 10, 9, 1, 1, 12]
40575 * ```
40576 *
40577 * @param tensor A Tensor. Tensor to copy/update.
40578 * @param indices The tensor contains the indices into the output tensor, must
40579 * have at least 2 axes: (num_updates, index_depth).
40580 * @param updates The tensor contains the value for the indices.
40581 *
40582 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
40583 */
40584 function tensorScatterUpdate_(tensor, indices, updates) {
40585 var $tensor = convertToTensor(tensor, 'tensor', 'tensorScatterupdate');
40586 var $indices = convertToTensor(indices, 'indices', 'tensorScatterupdate', 'int32');
40587 var $updates = convertToTensor(updates, 'updates', 'tensorScatterupdate');
40588 validateInput$1($updates, $indices, $tensor.shape);
40589 if ($tensor.dtype !== $updates.dtype) {
40590 throw new Error("tensor and updates must have the same dtype, instead they are ".concat($tensor.dtype, " and ").concat($updates.dtype, "."));
40591 }
40592 var inputs = {
40593 tensor: $tensor,
40594 indices: $indices,
40595 updates: $updates
40596 };
40597 var attrs = {};
40598 // tslint:disable-next-line: no-unnecessary-type-assertion
40599 return ENGINE.runKernel(TensorScatterUpdate, inputs, attrs);
40600 }
40601 var tensorScatterUpdate$2 = op({
40602 tensorScatterUpdate_: tensorScatterUpdate_
40603 });
40604
40605 /**
40606 * Finds the values and indices of the `k` largest entries along the last
40607 * dimension.
40608 *
40609 * If the input is a vector (rank=1), finds the k largest entries in the vector
40610 * and outputs their values and indices as vectors. Thus values[j] is the j-th
40611 * largest entry in input, and its index is indices[j].
40612 * For higher rank inputs, computes the top k entries along the last dimension.
40613 *
40614 * If two elements are equal, the lower-index element appears first.
40615 *
40616 * ```js
40617 * const a = tf.tensor2d([[1, 5], [4, 3]]);
40618 * const {values, indices} = tf.topk(a);
40619 * values.print();
40620 * indices.print();
40621 * ```
40622 * @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`.
40623 * @param k Number of top elements to look for along the last dimension.
40624 * @param sorted If true, the resulting `k` elements will be sorted by the
40625 * values in descending order.
40626 *
40627 * @doc {heading: 'Operations', subheading: 'Evaluation'}
40628 */
40629 function topk_(x) {
40630 var k = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 1;
40631 var sorted = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
40632 var $x = convertToTensor(x, 'x', 'topk');
40633 if ($x.rank === 0) {
40634 throw new Error('topk() expects the input to be of rank 1 or higher');
40635 }
40636 var lastDim = $x.shape[$x.shape.length - 1];
40637 if (k < 0) {
40638 throw new Error("'k' passed to topk() must be >= 0 but got ".concat(k));
40639 }
40640 if (k > lastDim) {
40641 throw new Error("'k' passed to topk() must be <= the last dimension (".concat(lastDim, ") ") + "but got ".concat(k));
40642 }
40643 var inputs = {
40644 x: $x
40645 };
40646 var attrs = {
40647 k: k,
40648 sorted: sorted
40649 };
40650 var _ENGINE$runKernel = ENGINE.runKernel(TopK, inputs, attrs),
40651 _ENGINE$runKernel2 = _slicedToArray(_ENGINE$runKernel, 2),
40652 values = _ENGINE$runKernel2[0],
40653 indices = _ENGINE$runKernel2[1];
40654 return {
40655 values: values,
40656 indices: indices
40657 };
40658 }
40659 var topk = /* @__PURE__ */op({
40660 topk_: topk_
40661 });
40662
40663 /**
40664 * @license
40665 * Copyright 2020 Google LLC. All Rights Reserved.
40666 * Licensed under the Apache License, Version 2.0 (the "License");
40667 * you may not use this file except in compliance with the License.
40668 * You may obtain a copy of the License at
40669 *
40670 * http://www.apache.org/licenses/LICENSE-2.0
40671 *
40672 * Unless required by applicable law or agreed to in writing, software
40673 * distributed under the License is distributed on an "AS IS" BASIS,
40674 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40675 * See the License for the specific language governing permissions and
40676 * limitations under the License.
40677 * =============================================================================
40678 */
40679 /**
40680 * Creates a `tf.Tensor` with values sampled from a truncated normal
40681 * distribution.
40682 *
40683 * ```js
40684 * tf.truncatedNormal([2, 2]).print();
40685 * ```
40686 *
40687 * The generated values follow a normal distribution with specified mean and
40688 * standard deviation, except that values whose magnitude is more than 2
40689 * standard deviations from the mean are dropped and re-picked.
40690 *
40691 * @param shape An array of integers defining the output tensor shape.
40692 * @param mean The mean of the normal distribution.
40693 * @param stdDev The standard deviation of the normal distribution.
40694 * @param dtype The data type of the output tensor.
40695 * @param seed The seed for the random number generator.
40696 *
40697 * @doc {heading: 'Tensors', subheading: 'Creation'}
40698 */
40699 function truncatedNormal_(shape) {
40700 var mean = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
40701 var stdDev = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 1;
40702 var dtype = arguments.length > 3 ? arguments[3] : undefined;
40703 var seed = arguments.length > 4 ? arguments[4] : undefined;
40704 assertNonNegativeIntegerDimensions(shape);
40705 if (dtype != null && dtype === 'bool') {
40706 throw new Error("Unsupported data type $ { dtype }");
40707 }
40708 var randGauss = new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
40709 var res = buffer(shape, dtype);
40710 for (var i = 0; i < res.values.length; i++) {
40711 res.values[i] = randGauss.nextValue();
40712 }
40713 return res.toTensor();
40714 }
40715 var truncatedNormal$1 = /* @__PURE__ */op({
40716 truncatedNormal_: truncatedNormal_
40717 });
40718
40719 /**
40720 * Finds unique elements along an axis of a tensor.
40721 *
40722 * It returns a tensor `values` containing all of the unique elements along the
40723 * `axis` of the given tensor `x` in the same order that they occur along the
40724 * `axis` in `x`; `x` does not need to be sorted. It also returns a tensor
40725 * `indices` the same size as the number of the elements in `x` along the `axis`
40726 * dimension. It contains the index in the unique output `values`.
40727 *
40728 * ```js
40729 * // A 1-D tensor
40730 * const a = tf.tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
40731 * const {values, indices} = tf.unique(a);
40732 * values.print(); // [1, 2, 4, 7, 8,]
40733 * indices.print(); // [0, 0, 1, 2, 2, 2, 3, 4, 4]
40734 * ```
40735 *
40736 * ```js
40737 * // A 2-D tensor with axis=0
40738 * //
40739 * // 'a' is: [[1, 0, 0],
40740 * // [1, 0, 0],
40741 * // [2, 0, 0]]
40742 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
40743 * const {values, indices} = tf.unique(a, 0)
40744 * values.print(); // [[1, 0, 0],
40745 * // [2, 0, 0]]
40746 * indices.print(); // [0, 0, 1]
40747 * ```
40748 *
40749 * ```js
40750 * // A 2-D tensor with axis=1
40751 * //
40752 * // 'a' is: [[1, 0, 0],
40753 * // [1, 0, 0],
40754 * // [2, 0, 0]]
40755 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
40756 * const {values, indices} = tf.unique(a, 1)
40757 * values.print(); // [[1, 0],
40758 * // [1, 0],
40759 * // [2, 0]]
40760 * indices.print(); // [0, 1, 1]
40761 * ```
40762 * @param x A tensor (int32, string, bool).
40763 * @param axis The axis of the tensor to find the unique elements.
40764 * @returns [uniqueElements, indices] (see above for details)
40765 *
40766 * @doc {heading: 'Operations', subheading: 'Evaluation'}
40767 */
40768 function unique_(x) {
40769 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
40770 var $x = convertToTensor(x, 'x', 'unique', 'string_or_numeric');
40771 assert$1($x.rank > 0, function () {
40772 return 'The input tensor must be at least 1D';
40773 });
40774 var inputs = {
40775 x: $x
40776 };
40777 var attrs = {
40778 axis: axis
40779 };
40780 var _ENGINE$runKernel = ENGINE.runKernel(Unique, inputs, attrs),
40781 _ENGINE$runKernel2 = _slicedToArray(_ENGINE$runKernel, 2),
40782 values = _ENGINE$runKernel2[0],
40783 indices = _ENGINE$runKernel2[1];
40784 return {
40785 values: values,
40786 indices: indices
40787 };
40788 }
40789 var unique$3 = /* @__PURE__ */op({
40790 unique_: unique_
40791 });
40792
40793 /**
40794 * @license
40795 * Copyright 2020 Google LLC. All Rights Reserved.
40796 * Licensed under the Apache License, Version 2.0 (the "License");
40797 * you may not use this file except in compliance with the License.
40798 * You may obtain a copy of the License at
40799 *
40800 * http://www.apache.org/licenses/LICENSE-2.0
40801 *
40802 * Unless required by applicable law or agreed to in writing, software
40803 * distributed under the License is distributed on an "AS IS" BASIS,
40804 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40805 * See the License for the specific language governing permissions and
40806 * limitations under the License.
40807 * =============================================================================
40808 */
40809 /**
40810 * Computes the sum along segments of a `tf.Tensor`.
40811 *
40812 * ```js
40813 * const x = tf.tensor1d([1, 2, 3, 4]);
40814 * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32');
40815 * const numSegments = 3;
40816 *
40817 * x.unsortedSegmentSum(segmentIds, numSegments).print()
40818 * //or tf.unsortedSegmentSum(x, segmentIds, numSegments)
40819 * ```
40820 * @param x The `tf.Tensor` that will be summed along its segments.
40821 * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s
40822 * dimension along the `axis`. Maps each element of `x` to a segment.
40823 * @param numSegments The number of distinct `segmentIds`.
40824 *
40825 * @doc {heading: 'Operations', subheading: 'Segment'}
40826 */
40827 function unsortedSegmentSum_(x, segmentIds, numSegments) {
40828 var $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
40829 var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
40830 assert$1(isInt(numSegments), function () {
40831 return 'numSegments must be of dtype int';
40832 });
40833 var inputs = {
40834 x: $x,
40835 segmentIds: $segmentIds
40836 };
40837 var attrs = {
40838 numSegments: numSegments
40839 };
40840 return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs);
40841 }
40842 var unsortedSegmentSum$2 = /* @__PURE__ */op({
40843 unsortedSegmentSum_: unsortedSegmentSum_
40844 });
40845
40846 /**
40847 * @license
40848 * Copyright 2020 Google LLC. All Rights Reserved.
40849 * Licensed under the Apache License, Version 2.0 (the "License");
40850 * you may not use this file except in compliance with the License.
40851 * You may obtain a copy of the License at
40852 *
40853 * http://www.apache.org/licenses/LICENSE-2.0
40854 *
40855 * Unless required by applicable law or agreed to in writing, software
40856 * distributed under the License is distributed on an "AS IS" BASIS,
40857 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40858 * See the License for the specific language governing permissions and
40859 * limitations under the License.
40860 * =============================================================================
40861 */
40862 /**
40863 * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
40864 *
40865 * ```js
40866 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
40867 *
40868 * tf.unstack(a).forEach(tensor => tensor.print());
40869 * ```
40870 *
40871 * @param x A tensor object.
40872 * @param axis The axis to unstack along. Defaults to 0 (the first dim).
40873 *
40874 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
40875 */
40876 function unstack_(x) {
40877 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;
40878 var $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric');
40879 assert$1(axis >= -$x.shape.length && axis < $x.shape.length, function () {
40880 return "Axis = ".concat(axis, " is not in [-").concat($x.shape.length, ", ").concat($x.shape.length, ")");
40881 });
40882 var inputs = {
40883 value: $x
40884 };
40885 var attrs = {
40886 axis: axis
40887 };
40888 return ENGINE.runKernel(Unpack, inputs, attrs);
40889 }
40890 var unstack = /* @__PURE__ */op({
40891 unstack_: unstack_
40892 });
40893
40894 /**
40895 * @license
40896 * Copyright 2022 Google LLC. All Rights Reserved.
40897 * Licensed under the Apache License, Version 2.0 (the "License");
40898 * you may not use this file except in compliance with the License.
40899 * You may obtain a copy of the License at
40900 *
40901 * http://www.apache.org/licenses/LICENSE-2.0
40902 *
40903 * Unless required by applicable law or agreed to in writing, software
40904 * distributed under the License is distributed on an "AS IS" BASIS,
40905 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40906 * See the License for the specific language governing permissions and
40907 * limitations under the License.
40908 * =============================================================================
40909 */
40910 /**
40911 * Searches for where a value would go in a sorted sequence.
40912 *
40913 * This is not a method for checking containment (like javascript in).
40914 *
40915 * The typical use case for this operation is "binning", "bucketing", or
40916 * "discretizing". The values are assigned to bucket-indices based on the edges
40917 * listed in 'sortedSequence'. This operation returns the bucket-index for each
40918 * value.
40919 *
40920 * The index returned corresponds to the first edge greater than the value.
40921 *
40922 * The axis is not settable for this operation. It always operates on the
40923 * innermost dimension (axis=-1). The operation will accept any number of outer
40924 * dimensions.
40925 *
40926 * Note: This operation assumes that 'upperBound' is sorted along the
40927 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
40928 * sorted no error is raised and the content of the returned tensor is not well
40929 * defined.
40930 *
40931 * ```js
40932 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
40933 * const values = tf.tensor1d([0, 4, 10]);
40934 * const result = tf.upperBound(seq, values);
40935 * result.print(); // [1, 2, 5]
40936 * ```
40937 * @param sortedSequence: N-D. Sorted sequence.
40938 * @param values: N-D. Search values.
40939 * @return An N-D int32 tensor the size of values containing the result of
40940 * applying upper bound to each value. The result is not a global index to
40941 * the entire Tensor, but the index in the last dimension.
40942 * @doc {heading: 'Operations', subheading: 'Evaluation'}
40943 */
40944 function upperBound$1(sortedSequence, values) {
40945 return searchSorted$2(sortedSequence, values, 'right');
40946 }
40947
40948 /**
40949 * @license
40950 * Copyright 2018 Google LLC. All Rights Reserved.
40951 * Licensed under the Apache License, Version 2.0 (the "License");
40952 * you may not use this file except in compliance with the License.
40953 * You may obtain a copy of the License at
40954 *
40955 * http://www.apache.org/licenses/LICENSE-2.0
40956 *
40957 * Unless required by applicable law or agreed to in writing, software
40958 * distributed under the License is distributed on an "AS IS" BASIS,
40959 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40960 * See the License for the specific language governing permissions and
40961 * limitations under the License.
40962 * =============================================================================
40963 */
40964 /**
40965 * Creates a new variable with the provided initial value.
40966 * ```js
40967 * const x = tf.variable(tf.tensor([1, 2, 3]));
40968 * x.assign(tf.tensor([4, 5, 6]));
40969 *
40970 * x.print();
40971 * ```
40972 *
40973 * @param initialValue Initial value for the tensor.
40974 * @param trainable If true, optimizers are allowed to update it.
40975 * @param name Name of the variable. Defaults to a unique id.
40976 * @param dtype If set, initialValue will be converted to the given type.
40977 *
40978 * @doc {heading: 'Tensors', subheading: 'Creation'}
40979 */
40980 function variable$1(initialValue) {
40981 var trainable = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
40982 var name = arguments.length > 2 ? arguments[2] : undefined;
40983 var dtype = arguments.length > 3 ? arguments[3] : undefined;
40984 return ENGINE.makeVariable(initialValue, trainable, name, dtype);
40985 }
40986
40987 /**
40988 * @license
40989 * Copyright 2018 Google LLC. All Rights Reserved.
40990 * Licensed under the Apache License, Version 2.0 (the "License");
40991 * you may not use this file except in compliance with the License.
40992 * You may obtain a copy of the License at
40993 *
40994 * http://www.apache.org/licenses/LICENSE-2.0
40995 *
40996 * Unless required by applicable law or agreed to in writing, software
40997 * distributed under the License is distributed on an "AS IS" BASIS,
40998 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40999 * See the License for the specific language governing permissions and
41000 * limitations under the License.
41001 * =============================================================================
41002 */
41003 function whereImpl$2(condShape, condVals) {
41004 var indices = [];
41005 for (var i = 0; i < condVals.length; i++) {
41006 if (condVals[i]) {
41007 indices.push(i);
41008 }
41009 }
41010 var inBuffer = buffer(condShape, 'int32');
41011 var out = buffer([indices.length, condShape.length], 'int32');
41012 for (var _i = 0; _i < indices.length; _i++) {
41013 var loc = inBuffer.indexToLoc(indices[_i]);
41014 var offset = _i * condShape.length;
41015 out.values.set(loc, offset);
41016 }
41017 return out.toTensor();
41018 }
41019
41020 /**
41021 * Returns the coordinates of true elements of condition.
41022 *
41023 * The coordinates are returned in a 2-D tensor where the first dimension (rows)
41024 * represents the number of true elements, and the second dimension (columns)
41025 * represents the coordinates of the true elements. Keep in mind, the shape of
41026 * the output tensor can vary depending on how many true values there are in
41027 * input. Indices are output in row-major order. The resulting tensor has the
41028 * shape `[numTrueElems, condition.rank]`.
41029 *
41030 * This is analogous to calling the python `tf.where(cond)` without an x or y.
41031 *
41032 * ```js
41033 * const cond = tf.tensor1d([false, false, true], 'bool');
41034 * const result = await tf.whereAsync(cond);
41035 * result.print();
41036 * ```
41037 *
41038 * @doc {heading: 'Operations', subheading: 'Logical'}
41039 */
41040 function whereAsync_(_x) {
41041 return _whereAsync_.apply(this, arguments);
41042 }
41043 function _whereAsync_() {
41044 _whereAsync_ = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(condition) {
41045 var $condition, vals, res;
41046 return _regeneratorRuntime().wrap(function _callee$(_context) {
41047 while (1) switch (_context.prev = _context.next) {
41048 case 0:
41049 $condition = convertToTensor(condition, 'condition', 'whereAsync', 'bool');
41050 _context.next = 3;
41051 return $condition.data();
41052 case 3:
41053 vals = _context.sent;
41054 res = whereImpl$2($condition.shape, vals);
41055 if (condition !== $condition) {
41056 $condition.dispose();
41057 }
41058 return _context.abrupt("return", res);
41059 case 7:
41060 case "end":
41061 return _context.stop();
41062 }
41063 }, _callee);
41064 }));
41065 return _whereAsync_.apply(this, arguments);
41066 }
41067 var whereAsync = whereAsync_;
41068
41069 /**
41070 * Apply boolean mask to tensor.
41071 *
41072 * ```js
41073 * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
41074 * const mask = tf.tensor1d([1, 0, 1], 'bool');
41075 * const result = await tf.booleanMaskAsync(tensor, mask);
41076 * result.print();
41077 * ```
41078 *
41079 * @param tensor N-D tensor.
41080 * @param mask K-D boolean tensor, K <= N and K must be known statically.
41081 * @param axis A 0-D int Tensor representing the axis in tensor to mask from.
41082 * By default, axis is 0 which will mask from the first dimension.
41083 * Otherwise K + axis <= N.
41084 *
41085 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
41086 */
41087 function booleanMaskAsync_(_x, _x2, _x3) {
41088 return _booleanMaskAsync_.apply(this, arguments);
41089 }
41090 function _booleanMaskAsync_() {
41091 _booleanMaskAsync_ = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(tensor, mask, axis) {
41092 var $tensor, $mask, axisFrom, maskDim, tensorShape, leadingSize, i, targetTensorShape, reshapedTensor, reshapedMask, positivePositions, indices, res;
41093 return _regeneratorRuntime().wrap(function _callee$(_context) {
41094 while (1) switch (_context.prev = _context.next) {
41095 case 0:
41096 $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
41097 $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
41098 axisFrom = axis == null ? 0 : axis;
41099 maskDim = $mask.rank;
41100 tensorShape = $tensor.shape;
41101 assert$1(maskDim > 0, function () {
41102 return 'mask cannot be scalar';
41103 });
41104 assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, "mask's shape must match the first K dimensions of tensor's shape,");
41105 leadingSize = 1;
41106 for (i = axisFrom; i < axisFrom + maskDim; i++) {
41107 leadingSize *= tensorShape[i];
41108 }
41109 targetTensorShape = tensorShape.slice(0, axisFrom).concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
41110 reshapedTensor = reshape$3($tensor, targetTensorShape);
41111 reshapedMask = reshape$3($mask, [-1]);
41112 _context.next = 14;
41113 return whereAsync(reshapedMask);
41114 case 14:
41115 positivePositions = _context.sent;
41116 indices = squeeze(positivePositions, [1]);
41117 res = gather$1(reshapedTensor, indices, axisFrom); // Ensure no memory leak.
41118 if (tensor !== $tensor) {
41119 $tensor.dispose();
41120 }
41121 if (mask !== $mask) {
41122 $mask.dispose();
41123 }
41124 indices.dispose();
41125 reshapedTensor.dispose();
41126 reshapedMask.dispose();
41127 positivePositions.dispose();
41128 return _context.abrupt("return", res);
41129 case 24:
41130 case "end":
41131 return _context.stop();
41132 }
41133 }, _callee);
41134 }));
41135 return _booleanMaskAsync_.apply(this, arguments);
41136 }
41137 var booleanMaskAsync = booleanMaskAsync_;
41138
41139 /**
41140 * @license
41141 * Copyright 2018 Google LLC. All Rights Reserved.
41142 * Licensed under the Apache License, Version 2.0 (the "License");
41143 * you may not use this file except in compliance with the License.
41144 * You may obtain a copy of the License at
41145 *
41146 * http://www.apache.org/licenses/LICENSE-2.0
41147 *
41148 * Unless required by applicable law or agreed to in writing, software
41149 * distributed under the License is distributed on an "AS IS" BASIS,
41150 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41151 * See the License for the specific language governing permissions and
41152 * limitations under the License.
41153 * =============================================================================
41154 */
41155 /**
41156 * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.
41157 *
41158 * The returned `tf.Tensor`'s dimension `i` will correspond to the input
41159 * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,
41160 * where `n` is the rank of the input `tf.Tensor`. Hence by default, this
41161 * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.
41162 *
41163 * ```js
41164 * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
41165 *
41166 * a.transpose().print(); // or tf.transpose(a)
41167 * ```
41168 *
41169 * @param x The tensor to transpose.
41170 * @param perm The permutation of the dimensions of a.
41171 * @param conjugate Will conjugate complex input if true.
41172 *
41173 * @doc {heading: 'Operations', subheading: 'Matrices'}
41174 */
41175 function transpose_(x, perm, conjugate) {
41176 var $x = convertToTensor(x, 'x', 'transpose');
41177 if (perm == null) {
41178 perm = $x.shape.map(function (s, i) {
41179 return i;
41180 }).reverse();
41181 }
41182 assert$1($x.rank === perm.length, function () {
41183 return "Error in transpose: rank of input ".concat($x.rank, " ") + "must match length of perm ".concat(perm, ".");
41184 });
41185 perm.forEach(function (axis) {
41186 assert$1(axis >= 0 && axis < $x.rank, function () {
41187 return "All entries in 'perm' must be between 0 and ".concat($x.rank - 1) + " but got ".concat(perm);
41188 });
41189 });
41190 if ($x.rank <= 1) {
41191 return $x.clone();
41192 }
41193 var inputs = {
41194 x: $x
41195 };
41196 var attrs = {
41197 perm: perm
41198 };
41199 if ($x.dtype === 'complex64') {
41200 return tidy(function () {
41201 var $real = real$2($x);
41202 var $imag = imag$2($x);
41203 $real = ENGINE.runKernel(Transpose, {
41204 x: $real
41205 }, attrs);
41206 $imag = ENGINE.runKernel(Transpose, {
41207 x: $imag
41208 }, attrs);
41209 if (conjugate) {
41210 $imag = neg$2($imag);
41211 }
41212 return complex$2($real, $imag);
41213 });
41214 }
41215 return ENGINE.runKernel(Transpose, inputs, attrs);
41216 }
41217 var transpose$2 = /* @__PURE__ */op({
41218 transpose_: transpose_
41219 });
41220
41221 /**
41222 * @license
41223 * Copyright 2018 Google LLC. All Rights Reserved.
41224 * Licensed under the Apache License, Version 2.0 (the "License");
41225 * you may not use this file except in compliance with the License.
41226 * You may obtain a copy of the License at
41227 *
41228 * http://www.apache.org/licenses/LICENSE-2.0
41229 *
41230 * Unless required by applicable law or agreed to in writing, software
41231 * distributed under the License is distributed on an "AS IS" BASIS,
41232 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41233 * See the License for the specific language governing permissions and
41234 * limitations under the License.
41235 * =============================================================================
41236 */
41237 /**
41238 * Compute the moving average of a variable.
41239 *
41240 * Without zeroDebias, the moving average operation is defined by:
41241 * `v += delta`
41242 * where
41243 * `delta = (1 - decay) * (x - v)`
41244 *
41245 * With zeroDebias (default), the `delta` term is scaled to debias the
41246 * effect of the (assumed) zero-initialization of `v`.
41247 * `delta /= (1 - decay ^ step)`
41248 *
41249 * For more details on the zero-debiasing algorithm, see:
41250 * https://arxiv.org/abs/1412.6980
41251 *
41252 * Note that this function is completely stateless and does not keep track of
41253 * step count. The step count needs to be maintained by the caller and passed
41254 * in as `step`.
41255 *
41256 * @param v The current moving average value.
41257 * @param x New input value, must have the same shape and dtype as `v`.
41258 * @param decay The decay factor. Typical values are 0.95 and 0.99.
41259 * @param step Step count.
41260 * @param zeroDebias: Whether zeroDebias is to be performed (default: `true`).
41261 * @returns The new moving average value.
41262 *
41263 * @doc {heading: 'Operations', subheading: 'Moving Average'}
41264 */
41265 function movingAverage_(v, x, decay, step) {
41266 var zeroDebias = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : true;
41267 var $v = convertToTensor(v, 'v', 'movingAverage');
41268 var $x = convertToTensor(x, 'x', 'movingAverage');
41269 var $decay = convertToTensor(decay, 'decay', 'movingAverage');
41270 assertTypesMatch($v, $x);
41271 assert$1(arraysEqual($v.shape, $x.shape), function () {
41272 return 'Shape mismatch in v and x';
41273 });
41274 var one = scalar(1);
41275 var oneMinusDecay = sub$2(one, $decay);
41276 var update = mul(sub$2($x, $v), oneMinusDecay);
41277 if (zeroDebias) {
41278 assert$1(step != null, function () {
41279 return 'When using zeroDebias: true, step is required.';
41280 });
41281 var $step = convertToTensor(step, 'step', 'movingAverage');
41282 update = div$1(update, sub$2(one, pow$3($decay, $step)));
41283 }
41284 return add$3($v, update);
41285 }
41286 var movingAverage = /* @__PURE__ */op({
41287 movingAverage_: movingAverage_
41288 });
41289
41290 /**
41291 * @license
41292 * Copyright 2018 Google LLC. All Rights Reserved.
41293 * Licensed under the Apache License, Version 2.0 (the "License");
41294 * you may not use this file except in compliance with the License.
41295 * You may obtain a copy of the License at
41296 *
41297 * http://www.apache.org/licenses/LICENSE-2.0
41298 *
41299 * Unless required by applicable law or agreed to in writing, software
41300 * distributed under the License is distributed on an "AS IS" BASIS,
41301 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41302 * See the License for the specific language governing permissions and
41303 * limitations under the License.
41304 * =============================================================================
41305 */
/**
 * Creates a new tensor by scattering `updates` into a zero-filled tensor of
 * the given `shape`, at the positions named by `indices`. This is the inverse
 * of `tf.gatherND`, which extracts values or slices from a tensor.
 *
 * ```js
 * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
 * const updates = tf.tensor1d([9, 10, 11, 12]);
 * const shape = [8];
 * tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12]
 * ```
 *
 * @param indices Indices into the output tensor (converted as int32).
 * @param updates The values to write at `indices`.
 * @param shape The shape of the output tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
 */
function scatterND_(indices, updates, shape) {
  assertNonNegativeIntegerDimensions(shape);
  var indicesTensor = convertToTensor(indices, 'indices', 'scatterND', 'int32');
  var updatesTensor = convertToTensor(updates, 'updates', 'scatterND');
  validateInput$1(updatesTensor, indicesTensor, shape);
  // tslint:disable-next-line: no-unnecessary-type-assertion
  return ENGINE.runKernel(ScatterNd, {
    indices: indicesTensor,
    updates: updatesTensor
  }, {
    shape: shape
  });
}
var scatterND = /* @__PURE__ */op({
  scatterND_: scatterND_
});
41343
/**
 * Validate sparseToDense inputs.
 *
 * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
 *     sparseIndices[i] contains the complete index where sparseValues[i] will
 *     be placed.
 * @param sparseValues A 0-D or 1-D Tensor. Values corresponding to each row of
 *     sparseIndices, or a scalar value to be used for all sparse indices.
 * @param outputShape number[]. Shape of the dense output tensor; its length
 *     must equal the number of index dimensions (sparseIndices.shape[1] for a
 *     matrix, otherwise 1).
 * @param defaultValues Tensor holding the default fill value; its dtype must
 *     match `sparseValues.dtype`.
 * @throws Error when the dtype, rank or shape of any input is inconsistent.
 */
function validateInput(sparseIndices, sparseValues, outputShape, defaultValues) {
  if (sparseIndices.dtype !== 'int32') {
    throw new Error('tf.sparseToDense() expects the indices to be int32 type,' + " but the dtype was ".concat(sparseIndices.dtype, "."));
  }
  if (sparseIndices.rank > 2) {
    throw new Error('sparseIndices should be a scalar, vector, or matrix,' + " but got shape ".concat(sparseIndices.shape, "."));
  }
  // A scalar index addresses one element; a vector holds one 1-D index per
  // element; a matrix holds one index of length shape[1] per row.
  var numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;
  var numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;
  if (outputShape.length !== numDims) {
    throw new Error('outputShape has incorrect number of elements: ' + "".concat(outputShape.length, ", should be: ").concat(numDims, "."));
  }
  var numValues = sparseValues.size;
  if (!(sparseValues.rank === 0 || sparseValues.rank === 1 && numValues === numElems)) {
    throw new Error('sparseValues has incorrect shape ' + "".concat(sparseValues.shape, ", should be [] or [").concat(numElems, "]"));
  }
  if (sparseValues.dtype !== defaultValues.dtype) {
    throw new Error('sparseValues.dtype must match defaultValues.dtype');
  }
}
41377
41378 /**
41379 * @license
41380 * Copyright 2018 Google LLC. All Rights Reserved.
41381 * Licensed under the Apache License, Version 2.0 (the "License");
41382 * you may not use this file except in compliance with the License.
41383 * You may obtain a copy of the License at
41384 *
41385 * http://www.apache.org/licenses/LICENSE-2.0
41386 *
41387 * Unless required by applicable law or agreed to in writing, software
41388 * distributed under the License is distributed on an "AS IS" BASIS,
41389 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41390 * See the License for the specific language governing permissions and
41391 * limitations under the License.
41392 * =============================================================================
41393 */
/**
 * Converts a sparse representation into a dense tensor.
 *
 * Builds a tensor `dense` of shape `outputShape` such that:
 *
 * // If sparseIndices is scalar
 * dense[i] = (i == sparseIndices ? sparseValues : defaultValue)
 *
 * // If sparseIndices is a vector, then for each i
 * dense[sparseIndices[i]] = sparseValues[i]
 *
 * // If sparseIndices is an n by d matrix, then for each i in [0, n)
 * dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i]
 *
 * All other values in dense are set to defaultValue. If sparseValues is a
 * scalar, all sparse indices are set to this single value.
 *
 * If indices are repeated the final value is summed over all values for those
 * indices.
 *
 * ```js
 * const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32');
 * const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32');
 * const shape = [8];
 * tf.sparseToDense(indices, values, shape).print();
 * ```
 *
 * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
 *     sparseIndices[i] contains the complete index where sparseValues[i] will
 *     be placed.
 * @param sparseValues A 0-D or 1-D Tensor. Values corresponding to each row
 *     of sparseIndices, or a scalar value to be used for all sparse indices.
 * @param outputShape Shape of the dense output tensor.
 * @param defaultValue Scalar. Value to set for indices not specified in
 *     sparseIndices; converted to the dtype of sparseValues. Defaults to 0.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function sparseToDense_(sparseIndices, sparseValues, outputShape) {
  // Optional 4th positional argument; only `undefined` falls back to 0.
  var defaultValue = arguments[3] === undefined ? 0 : arguments[3];
  assertNonNegativeIntegerDimensions(outputShape);
  var indicesTensor = convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32');
  var valuesTensor = convertToTensor(sparseValues, 'sparseValues', 'sparseToDense', 'string_or_numeric');
  // The default value always takes the dtype of the sparse values.
  var defaultTensor = convertToTensor(defaultValue, 'defaultValue', 'sparseToDense', valuesTensor.dtype);
  validateInput(indicesTensor, valuesTensor, outputShape, defaultTensor);
  return ENGINE.runKernel(SparseToDense, {
    sparseIndices: indicesTensor,
    sparseValues: valuesTensor,
    defaultValue: defaultTensor
  }, {
    outputShape: outputShape
  });
}
var sparseToDense$2 = /* @__PURE__ */op({
  sparseToDense_: sparseToDense_
});
41452
41453 /**
41454 * @license
41455 * Copyright 2018 Google LLC. All Rights Reserved.
41456 * Licensed under the Apache License, Version 2.0 (the "License");
41457 * you may not use this file except in compliance with the License.
41458 * You may obtain a copy of the License at
41459 *
41460 * http://www.apache.org/licenses/LICENSE-2.0
41461 *
41462 * Unless required by applicable law or agreed to in writing, software
41463 * distributed under the License is distributed on an "AS IS" BASIS,
41464 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41465 * See the License for the specific language governing permissions and
41466 * limitations under the License.
41467 * =============================================================================
41468 */
/**
 * Gathers slices from the input tensor into a tensor whose shape is
 * determined by `indices`.
 *
 * `indices` is a K-dimensional integer tensor, best thought of as a
 * (K-1)-dimensional tensor of indices into input, where each element defines a
 * slice of input:
 * output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
 *
 * Whereas in `tf.gather`, `indices` defines slices into the first dimension of
 * input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
 * of input, where N = indices.shape[-1].
 *
 * The last dimension of indices can be at most the rank of input:
 * indices.shape[-1] <= input.rank
 *
 * The last dimension of `indices` corresponds to elements
 * (if indices.shape[-1] == input.rank) or slices
 * (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
 * input.
 * The output tensor has shape
 * indices.shape[:-1] + input.shape[indices.shape[-1]:]
 *
 * Note that on CPU, if an out of bound index is found, an error is returned. On
 * GPU, if an out of bound index is found, a 0 is stored in the corresponding
 * output value.
 *
 * ```js
 * const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
 * const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
 * tf.gatherND(input, indices).print() // [10, 11]
 * ```
 *
 * @param x The tensor from which to gather values.
 * @param indices Index tensor, must be of type int32.
 *
 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
 */
function gatherND_(x, indices) {
  // `indices` is converted before `x`, so its conversion errors surface first.
  var indicesTensor = convertToTensor(indices, 'indices', 'gatherND', 'int32');
  var paramsTensor = convertToTensor(x, 'x', 'gatherND', 'string_or_numeric');
  return ENGINE.runKernel(GatherNd, {
    params: paramsTensor,
    indices: indicesTensor
  });
}
var gatherND = /* @__PURE__ */op({
  gatherND_: gatherND_
});
41519
41520 /**
41521 * @license
41522 * Copyright 2019 Google LLC. All Rights Reserved.
41523 * Licensed under the Apache License, Version 2.0 (the "License");
41524 * you may not use this file except in compliance with the License.
41525 * You may obtain a copy of the License at
41526 *
41527 * http://www.apache.org/licenses/LICENSE-2.0
41528 *
41529 * Unless required by applicable law or agreed to in writing, software
41530 * distributed under the License is distributed on an "AS IS" BASIS,
41531 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41532 * See the License for the specific language governing permissions and
41533 * limitations under the License.
41534 * =============================================================================
41535 */
/**
 * Resolves the noise shape used for dropout masks.
 *
 * - No `noiseShape`: returns a copy of `x.shape`.
 * - `noiseShape` equal to `x.shape` (per `arraysEqual`): returns `noiseShape`
 *   itself.
 * - Same rank as `x`: returns a new array where every null/undefined entry of
 *   `noiseShape` is replaced by the matching `x.shape` entry (when that entry
 *   is not null/undefined).
 * - Different rank: returns `noiseShape` unchanged.
 *
 * @param x Tensor whose `shape` supplies the defaults.
 * @param noiseShape Optional array of numbers for the keep/drop flags.
 * @returns Normalized noise shape.
 */
function getNoiseShape(x, noiseShape) {
  var inputShape = x.shape;
  if (noiseShape == null) {
    return inputShape.slice();
  }
  if (arraysEqual(inputShape, noiseShape)) {
    return noiseShape;
  }
  if (inputShape.length !== noiseShape.length) {
    return noiseShape;
  }
  var resolved = new Array(inputShape.length);
  for (var d = 0; d < inputShape.length; ++d) {
    var requested = noiseShape[d];
    var fallback = inputShape[d];
    resolved[d] = requested == null && fallback != null ? fallback : requested;
  }
  return resolved;
}
41564
41565 /**
41566 * @license
41567 * Copyright 2018 Google LLC. All Rights Reserved.
41568 * Licensed under the Apache License, Version 2.0 (the "License");
41569 * you may not use this file except in compliance with the License.
41570 * You may obtain a copy of the License at
41571 *
41572 * http://www.apache.org/licenses/LICENSE-2.0
41573 *
41574 * Unless required by applicable law or agreed to in writing, software
41575 * distributed under the License is distributed on an "AS IS" BASIS,
41576 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41577 * See the License for the specific language governing permissions and
41578 * limitations under the License.
41579 * =============================================================================
41580 */
/**
 * Computes dropout: each element of `x` is zeroed with probability `rate`,
 * and survivors are scaled by `1 / (1 - rate)`.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 2, 1]);
 * const rate = 0.75;
 * const output = tf.dropout(x, rate);
 * output.print();
 * ```
 *
 * @param x A float32 Tensor or TensorLike.
 * @param rate A float in the range [0, 1): the probability that each element
 *     of x is discarded. A rate of 0 returns x unchanged (cloned when x is
 *     already a Tensor).
 * @param noiseShape Optional array of int32 numbers giving the shape of the
 *     randomly generated keep/drop flags. Null entries are replaced with the
 *     corresponding dimension of x.
 * @param seed Optional seed for the random number generator.
 * @returns A Tensor with the same shape as x.
 *
 * @doc {heading: 'Operations', subheading: 'Dropout'}
 */
function dropout_(x, rate, noiseShape, seed) {
  var input = convertToTensor(x, 'x', 'dropout');
  assert$1(input.dtype === 'float32', function () {
    return "x has to be a floating point tensor since it's going to be " + "scaled, but got a ".concat(input.dtype, " tensor instead.");
  });
  assert$1(rate >= 0 && rate < 1, function () {
    return "rate must be a float in the range [0, 1), but got ".concat(rate, ".");
  });
  if (rate === 0) {
    // Nothing is dropped; hand back a copy so callers never alias their input.
    if (x instanceof Tensor) {
      return input.clone();
    }
    return input;
  }
  var keepProb = 1 - rate;
  var draw = randomUniform$1(getNoiseShape(input, noiseShape), 0, 1, 'float32', seed);
  // floor(U + keepProb) is 1 with probability keepProb and 0 otherwise;
  // dividing by keepProb keeps the expected output equal to x.
  var keepMask = floor$2(add$3(draw, keepProb));
  return mul(input, div$1(keepMask, keepProb));
}
var dropout$2 = /* @__PURE__ */op({
  dropout_: dropout_
});
41622
41623 /**
41624 * @license
41625 * Copyright 2019 Google LLC. All Rights Reserved.
41626 * Licensed under the Apache License, Version 2.0 (the "License");
41627 * you may not use this file except in compliance with the License.
41628 * You may obtain a copy of the License at
41629 *
41630 * http://www.apache.org/licenses/LICENSE-2.0
41631 *
41632 * Unless required by applicable law or agreed to in writing, software
41633 * distributed under the License is distributed on an "AS IS" BASIS,
41634 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41635 * See the License for the specific language governing permissions and
41636 * limitations under the License.
41637 * =============================================================================
41638 */
/**
 * Returns 2**N for the smallest integer N such that 2**N >= value.
 *
 * Uses Math.log2 rather than Math.log(value) / Math.log(2). The division can
 * round to slightly above an integer for an exact power of two, and Math.ceil
 * would then return twice the correct size. Math.log2 is exact for powers of
 * two.
 *
 * @param value A positive number (e.g. a frame length).
 * @returns The enclosing power of two, floored to an integer.
 */
function enclosingPowerOfTwo(value) {
  return Math.floor(Math.pow(2, Math.ceil(Math.log2(value))));
}
/**
 * Builds a raised-cosine window `a - b * cos(2*pi*i / N)` of length
 * `windowLength` as a float32 1-D tensor.
 *
 * N is `windowLength` for even lengths and `windowLength - 1` for odd
 * lengths.
 *
 * A length-1 window returns [1], matching TensorFlow's raised-cosine windows.
 * The general formula would compute 0 / 0 (NaN) in that case.
 *
 * @param windowLength Number of samples in the window.
 * @param a Constant term of the window.
 * @param b Cosine coefficient of the window.
 * @returns A float32 tensor1d of length `windowLength`.
 */
function cosineWindow(windowLength, a, b) {
  var newValues = new Float32Array(windowLength);
  if (windowLength === 1) {
    newValues[0] = 1;
    return tensor1d(newValues, 'float32');
  }
  var even = 1 - windowLength % 2;
  for (var i = 0; i < windowLength; ++i) {
    var cosArg = 2.0 * Math.PI * i / (windowLength + even - 1);
    newValues[i] = a - b * Math.cos(cosArg);
  }
  return tensor1d(newValues, 'float32');
}
41652
/**
 * Returns whether each target is among the top `k` predictions of its row.
 *
 * ```js
 * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
 * const targets = tf.tensor1d([2, 0]);
 * const precision = await tf.inTopKAsync(predictions, targets);
 * precision.print();
 * ```
 * @param predictions 2-D or higher `tf.Tensor` whose last dimension is at
 *     least `k`.
 * @param targets 1-D or higher `tf.Tensor`.
 * @param k Optional number of top elements to consider when computing
 *     precision; defaults to 1.
 *
 * @doc {heading: 'Operations', subheading: 'Evaluation'}
 */
function inTopKAsync_(predictions, targets) {
  // Forward every argument, including the optional `k`, unchanged.
  var forwarded = arguments;
  return _inTopKAsync_.apply(this, forwarded);
}
/**
 * Async implementation behind `inTopKAsync_`.
 *
 * Steps:
 * 1. Converts `predictions` and `targets` to tensors and validates them.
 * 2. Downloads their data.
 * 3. Treats the predictions as rows of length `lastDim` and, for each row,
 *    records whether `targets[row]` is among the indices of the `k` largest
 *    values.
 *
 * Tensors created here by `convertToTensor` are disposed on every exit path,
 * including when validation or a data download fails. Tensors passed in by
 * the caller are never disposed.
 *
 * @param predictions Tensor (or TensorLike) of rank >= 2.
 * @param targets Tensor (or TensorLike) of rank `predictions.rank - 1`.
 * @param k Optional 3rd argument: number of top elements to consider
 *     (defaults to 1).
 * @returns A promise for a 'bool' tensor with the shape of `targets`. The
 *     promise rejects (rather than throwing synchronously) on invalid input.
 */
function _inTopKAsync_(predictions, targets) {
  var k = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 1;
  var $predictions = null;
  var $targets = null;
  function disposeConverted() {
    if ($predictions != null && $predictions !== predictions) {
      $predictions.dispose();
    }
    if ($targets != null && $targets !== targets) {
      $targets.dispose();
    }
  }
  // The executor runs synchronously, and anything it throws rejects the
  // returned promise.
  return new Promise(function (resolve) {
    $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
    $targets = convertToTensor(targets, 'targets', 'inTopK');
    assert$1($predictions.rank > 1, function () {
      return 'inTopK() expects the predictions to be of rank 2 or higher, ' + "but got ".concat($predictions.rank);
    });
    assert$1($predictions.rank - 1 === $targets.rank, function () {
      return "predictions rank should be 1 larger than " + "targets rank, but got predictions rank " + "".concat($predictions.rank, " and targets rank ").concat($targets.rank);
    });
    assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, "predictions's shape should be align with the targets' shape, " + 'except the last dimension.');
    var lastDim = $predictions.shape[$predictions.shape.length - 1];
    assert$1(k > 0 && k <= lastDim, function () {
      return "'k' passed to inTopK() must be > 0 && <= the predictions last " + "dimension (".concat(lastDim, "), but got ").concat(k);
    });
    resolve($predictions.data().then(function (predictionsVals) {
      return $targets.data().then(function (targetsVals) {
        // Reshape predictionsVals into [batch, lastDim] and look up topK
        // along lastDim.
        var batch = predictionsVals.length / lastDim;
        var precision = getTypedArrayFromDType('bool', batch);
        for (var b = 0; b < batch; b++) {
          var offset = b * lastDim;
          var vals = predictionsVals.subarray(offset, offset + lastDim);
          var valAndInd = [];
          for (var i = 0; i < vals.length; i++) {
            valAndInd.push({
              value: vals[i],
              index: i
            });
          }
          valAndInd.sort(function (lhs, rhs) {
            return rhs.value - lhs.value;
          });
          precision[b] = 0;
          for (var j = 0; j < k; j++) {
            if (valAndInd[j].index === targetsVals[b]) {
              precision[b] = 1;
              break;
            }
          }
        }
        // Output precision has the same shape as targets.
        return tensor(precision, $targets.shape, 'bool');
      });
    }));
  }).then(function (result) {
    disposeConverted();
    return result;
  }, function (err) {
    disposeConverted();
    throw err;
  });
}
var inTopKAsync = inTopKAsync_;
41777
41778 /**
41779 * @license
41780 * Copyright 2020 Google LLC. All Rights Reserved.
41781 * Licensed under the Apache License, Version 2.0 (the "License");
41782 * you may not use this file except in compliance with the License.
41783 * You may obtain a copy of the License at
41784 *
41785 * http://www.apache.org/licenses/LICENSE-2.0
41786 *
41787 * Unless required by applicable law or agreed to in writing, software
41788 * distributed under the License is distributed on an "AS IS" BASIS,
41789 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41790 * See the License for the specific language governing permissions and
41791 * limitations under the License.
41792 * =============================================================================
41793 */
/**
 * Computes the derivative of the filter of a 2D convolution.
 *
 * @param x The input tensor, of rank 4 or rank 3 of shape
 *     [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.
 * @param dy The dy image, of rank 4 or rank 3, of shape
 *     [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.
 * @param filterShape The shape of the filter, length 4,
 *     [filterHeight, filterWidth, inDepth, outDepth].
 * @param strides The strides of the convolution: [strideHeight,
 *     strideWidth].
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specifies the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels].
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @throws Error when ranks or channel depths of x, dy and filterShape are
 *     inconsistent.
 */
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad) {
  var dataFormat = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 'NHWC';
  var dimRoundingMode = arguments.length > 6 ? arguments[6] : undefined;
  // Promote rank-3 inputs to a batch of one.
  var x4D = x;
  if (x.rank === 3) {
    x4D = reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
  }
  var dy4D = dy;
  if (dy4D.rank === 3) {
    dy4D = reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
  }
  assert$1(x4D.rank === 4, function () {
    return "Error in conv2dDerFilter: input must be rank 4, but got shape " + "".concat(x4D.shape, ".");
  });
  assert$1(dy4D.rank === 4, function () {
    return "Error in conv2dDerFilter: dy must be rank 4, but got shape " + "".concat(dy4D.shape, ".");
  });
  assert$1(filterShape.length === 4, function () {
    return "Error in conv2dDerFilter: filterShape must be length 4, but got " + "".concat(filterShape, ".");
  });
  // The channel axis is last for NHWC and second for NCHW.
  var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
  var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
  assert$1(inDepth === filterShape[2], function () {
    return "Error in conv2dDerFilter: depth of input (".concat(inDepth, ") must ") + "match input depth in filter (".concat(filterShape[2], ").");
  });
  assert$1(outDepth === filterShape[3], function () {
    return "Error in conv2dDerFilter: depth of dy (".concat(outDepth, ") must ") + "match output depth for filter (".concat(filterShape[3], ").");
  });
  checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
  var inputs = {
    x: x4D,
    dy: dy4D
  };
  var attrs = {
    strides: strides,
    pad: pad,
    dataFormat: dataFormat,
    dimRoundingMode: dimRoundingMode,
    filterShape: filterShape
  };
  // tslint:disable-next-line: no-unnecessary-type-assertion
  return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
}
var conv2DBackpropFilter$2 = /* @__PURE__ */op({
  conv2DBackpropFilter_: conv2DBackpropFilter_
});
41860
41861 /**
41862 * @license
41863 * Copyright 2019 Google LLC. All Rights Reserved.
41864 * Licensed under the Apache License, Version 2.0 (the "License");
41865 * you may not use this file except in compliance with the License.
41866 * You may obtain a copy of the License at
41867 *
41868 * http://www.apache.org/licenses/LICENSE-2.0
41869 *
41870 * Unless required by applicable law or agreed to in writing, software
41871 * distributed under the License is distributed on an "AS IS" BASIS,
41872 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
41873 * See the License for the specific language governing permissions and
41874 * limitations under the License.
41875 * =============================================================================
41876 */
// Returns the gradient flowing back through a fused activation: `dy` is passed
// through unchanged for a missing or 'linear' activation, and masked by
// step(y) for 'relu'. Any other activation is rejected.
function getFusedDyActivation(dy, y, activation) {
  var isIdentity = activation == null || activation === 'linear';
  if (isIdentity) {
    return dy;
  }
  if (activation !== 'relu') {
    throw new Error("Cannot compute gradient for fused activation ".concat(activation, "."));
  }
  return mul(dy, step$2(y));
}
// Returns the gradient for a fused bias. `dyActivation` is summed over the
// axes reported by getReductionAxes(bias.shape, dyActivation.shape)
// (presumably the axes broadcast when the bias was added), then reshaped to
// `bias.shape`.
function getFusedBiasGradient(bias, dyActivation) {
  var axes = getReductionAxes(bias.shape, dyActivation.shape);
  var reduced = axes.length > 0 ? sum$3(dyActivation, axes) : dyActivation;
  return reshape$3(reduced, bias.shape);
}
/**
 * Applies the named fused activation to `x`.
 *
 * @param x Input tensor.
 * @param activation One of 'linear', 'relu', 'elu', 'relu6', 'prelu',
 *     'leakyrelu', 'sigmoid'.
 * @param preluActivationWeights Weights used only by 'prelu'.
 * @param leakyreluAlpha Alpha used only by 'leakyrelu'.
 * @returns The activated tensor ('linear' returns `x` itself).
 * @throws Error for any other activation name.
 */
function applyActivation$1(x, activation, preluActivationWeights, leakyreluAlpha) {
  switch (activation) {
    case 'linear':
      return x;
    case 'relu':
      return relu$2(x);
    case 'elu':
      return elu$4(x);
    case 'relu6':
      return relu6$2(x);
    case 'prelu':
      return prelu$3(x, preluActivationWeights);
    case 'leakyrelu':
      return leakyRelu$2(x, leakyreluAlpha);
    case 'sigmoid':
      return sigmoid$2(x);
    default:
      throw new Error("Unknown fused activation ".concat(activation, "."));
  }
}
// Whether fused ops should be called: always when not in gradient mode
// (gradientDepth <= 0); in gradient mode, only for a 'linear' activation.
var shouldFuse = function shouldFuse(gradientDepth, activation) {
  if (gradientDepth > 0) {
    return activation === 'linear';
  }
  return true;
};
41919
41920 /**
41921 * Computes a 2D convolution over the input x, optionally fused with adding a
41922 * bias and applying an activation.
41923 *
41924 * ```js
41925 * const inputDepth = 2;
41926 * const inShape = [2, 2, 2, inputDepth];
41927 * const outputDepth = 2;
41928 * const fSize = 1;
41929 * const pad = 0;
41930 * const strides = 1;
41931 *
41932 * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
41933 * 16], inShape);
41934 * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,
41935 * outputDepth]);
41936 *
41937 * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',
41938 * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();
41939 * ```
41940 *
41941 * @param obj An object with the following properties:
41942 * @param x The input tensor, of rank 4 or rank 3, of shape
41943 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
41944 * assumed.
41945 * @param filter The filter, rank 4, of shape
41946 * `[filterHeight, filterWidth, inDepth, outDepth]`.
41947 * @param strides The strides of the convolution: `[strideHeight,
41948 * strideWidth]`.
41949 * @param pad The type of padding algorithm.
41950 * - `same` and stride 1: output will be of same size as input,
41951 * regardless of filter size.
41952 * - `valid` output will be smaller than input if filter is larger
41953 * than 1x1.
41954 * - For more info, see this guide:
41955 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
41956 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
41957 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
41958 * "NHWC". Specify the data format of the input and output data. With the
41959 * default format "NHWC", the data is stored in the order of: [batch,
41960 * height, width, channels]. Only "NHWC" is currently supported.
41961 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
41962 * in which we sample input values across the height and width dimensions
41963 * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
41964 * number, then `dilationHeight == dilationWidth`. If it is greater than
41965 * 1, then all values of `strides` must be 1.
41966 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
41967 * provided, it will default to truncate.
41968 * @param bias Tensor to be added to the result.
41969 * @param activation Name of activation kernel (defaults to `linear`) to be
41970 * applied
41971 * after biasAdd.
41972 * @param preluActivationWeights Tensor of prelu weights to be applied as part
41973 * of a `prelu` activation, typically the same shape as `x`.
41974 * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
41975 * activation.
41976 */
41977 function fusedConv2d_(_ref) {
41978 var x = _ref.x,
41979 filter = _ref.filter,
41980 strides = _ref.strides,
41981 pad = _ref.pad,
41982 _ref$dataFormat = _ref.dataFormat,
41983 dataFormat = _ref$dataFormat === void 0 ? 'NHWC' : _ref$dataFormat,
41984 _ref$dilations = _ref.dilations,
41985 dilations = _ref$dilations === void 0 ? [1, 1] : _ref$dilations,
41986 dimRoundingMode = _ref.dimRoundingMode,
41987 bias = _ref.bias,
41988 _ref$activation = _ref.activation,
41989 activation = _ref$activation === void 0 ? 'linear' : _ref$activation,
41990 preluActivationWeights = _ref.preluActivationWeights,
41991 leakyreluAlpha = _ref.leakyreluAlpha;
41992 activation = activation || 'linear';
41993 if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
41994 // TODO: Transpose bias and preluActivationWeights properly for NCHW
41995 // format before computation.
41996 assert$1(dataFormat === 'NHWC', function () {
41997 return "Error in fused conv2d: got dataFormat of ".concat(dataFormat, " but ") + "only NHWC is currently supported for the case of gradient depth " + "is 0 and the activation is not linear.";
41998 });
41999 var result = conv2d$4(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
42000 if (bias != null) {
42001 result = add$3(result, bias);
42002 }
42003 return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
42004 }
42005 var $x = convertToTensor(x, 'x', 'conv2d', 'float32');
42006 var $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
42007 var x4D = $x;
42008 var reshapedTo4D = false;
42009 if ($x.rank === 3) {
42010 reshapedTo4D = true;
42011 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
42012 }
42013 assert$1(x4D.rank === 4, function () {
42014 return "Error in fused conv2d: input must be rank 4, but got rank " + "".concat(x4D.rank, ".");
42015 });
42016 assert$1($filter.rank === 4, function () {
42017 return "Error in fused conv2d: filter must be rank 4, but got rank " + "".concat($filter.rank, ".");
42018 });
42019 checkPadOnDimRoundingMode('fused conv2d', pad, dimRoundingMode);
42020 var inputChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
42021 assert$1($filter.shape[2] === inputChannels, function () {
42022 return "Error in conv2d: depth of input (".concat(inputChannels, ") must match ") + "input depth for filter ".concat($filter.shape[2], ".");
42023 });
42024 assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
42025 return 'Error in conv2D: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
42026 });
42027 var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);
42028 var $bias;
42029 if (bias != null) {
42030 $bias = convertToTensor(bias, 'bias', 'fused conv2d');
42031 var _makeTypesMatch = makeTypesMatch($bias, $x);
42032 var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 1);
42033 $bias = _makeTypesMatch2[0];
42034 // According to TensorFlow, the bias is supposed be a 1-D tensor or a
42035 // scalar.
42036 //
42037 // 3-D or 4-D bias is not disabled for NHWC format, because they are
42038 // currently being used in some cases. For examplem in our code base,
42039 // https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/fused_conv2d_test.ts#L1972.
42040 if (dataFormat === 'NHWC') {
42041 assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
42042 } else {
42043 assert$1($bias.shape.length <= 1, function () {
42044 return "Error in fused conv2d: only supports scalar or 1-D Tensor " + "bias for NCHW format but got the bias of " + "rank-".concat($bias.shape.length, ".");
42045 });
42046 assert$1($bias.shape.length === 0 || $bias.shape[0] === convInfo.outChannels || $bias.shape[0] === 1, function () {
42047 return "Error in fused conv2d: bias shape (".concat($bias.shape, ") is not ") + "compatible with the number of output channels " + "(".concat(convInfo.outChannels, ")");
42048 });
42049 }
42050 }
42051 var $preluActivationWeights;
42052 if (preluActivationWeights != null) {
42053 // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
42054 // tensor.
42055 var alphaShape = preluActivationWeights.shape;
42056 assert$1(alphaShape.length <= 1 || alphaShape.length === 3, function () {
42057 return "Error in fused conv2d: only supports scalar, 1-D Tensor or " + "3-D Tensor PReLU activation weights but got a tensor of " + "rank-".concat(alphaShape.length, ".");
42058 });
42059 if (alphaShape.length === 1) {
42060 // Whether the data format is NCHW or NHWC, the 1-D PReLU activation
42061 // weights tensor should be aligned with the output channels of conv2d
42062 // result.
42063 assert$1(alphaShape[0] === 1 || alphaShape[0] === convInfo.outChannels, function () {
42064 return "Error in fused conv2d: PReLU activation weights " + "(".concat(alphaShape, ") is not compatible with the number of output ") + "channels (".concat(convInfo.outChannels, ").");
42065 });
42066 } else if (alphaShape.length === 3) {
42067 // Whether the data format is NCHW or NHWC, the PReLU activation weights
42068 // tensor should has the compatible shape with the result of conv2d.
42069 try {
42070 assertAndGetBroadcastShape(alphaShape, convInfo.outShape);
42071 } catch (e) {
42072 var errMsg = "Error in fused conv2d: PReLU activation weights (".concat(alphaShape, ") ") + "is not compatible with the output shape of the conv2d " + "(".concat(convInfo.outShape, ").");
42073 throw Error(errMsg);
42074 }
42075 }
42076 $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused conv2d');
42077 }
42078 var grad = function grad(dy, saved) {
42079 assert$1(dataFormat === 'NHWC', function () {
42080 return "Error in gradient of fused conv2D: got dataFormat of ".concat(dataFormat, " but only NHWC is currently supported.");
42081 });
42082 var _saved = _slicedToArray(saved, 4),
42083 $filter = _saved[0],
42084 x4D = _saved[1],
42085 y = _saved[2],
42086 $bias = _saved[3];
42087 var dyActivation = getFusedDyActivation(dy, y, activation);
42088 assert$1(tupleValuesAreOne(dilations), function () {
42089 return 'Error in gradient of fused conv2D: ' + "dilation rates greater than 1 " + "are not yet supported in gradients. Got dilations '".concat(dilations, "'");
42090 });
42091 var xDer = conv2DBackpropInput$2(x4D.shape, dyActivation, $filter, strides, pad);
42092 var filterDer = conv2DBackpropFilter$2(x4D, dyActivation, $filter.shape, strides, pad);
42093 var der = [xDer, filterDer];
42094 if ($bias != null) {
42095 var biasDer = getFusedBiasGradient($bias, dyActivation);
42096 der.push(biasDer);
42097 }
42098 return der;
42099 };
42100 var inputs = {
42101 x: x4D,
42102 filter: $filter,
42103 bias: $bias,
42104 preluActivationWeights: $preluActivationWeights
42105 };
42106 var attrs = {
42107 strides: strides,
42108 pad: pad,
42109 dataFormat: dataFormat,
42110 dilations: dilations,
42111 dimRoundingMode: dimRoundingMode,
42112 activation: activation,
42113 leakyreluAlpha: leakyreluAlpha
42114 };
42115 // Depending on the the params passed in we will have different number of
42116 // inputs and thus a a different number of elements in the gradient.
42117 if (bias == null) {
42118 var customOp = customGrad(function (x4D, filter, save) {
42119 var res =
42120 // tslint:disable-next-line: no-unnecessary-type-assertion
42121 ENGINE.runKernel(FusedConv2D, inputs, attrs);
42122 save([filter, x4D, res]);
42123 if (reshapedTo4D) {
42124 // tslint:disable-next-line: no-unnecessary-type-assertion
42125 res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
42126 }
42127 return {
42128 value: res,
42129 gradFunc: grad
42130 };
42131 });
42132 return customOp(x4D, $filter);
42133 } else {
42134 var customOpWithBias = customGrad(function (x4D, filter, bias, save) {
42135 var res = ENGINE.runKernel(FusedConv2D, inputs, attrs);
42136 save([filter, x4D, res, bias]);
42137 if (reshapedTo4D) {
42138 // tslint:disable-next-line: no-unnecessary-type-assertion
42139 res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
42140 }
42141 return {
42142 value: res,
42143 gradFunc: grad
42144 };
42145 });
42146 return customOpWithBias(x4D, $filter, $bias);
42147 }
42148 }
// Public fused conv2d op: `op` (defined elsewhere in this file) receives a
// one-entry object keyed by the implementation's name and returns the wrapped
// function exposed as conv2d$3.
var conv2d$3 = /* @__PURE__ */op({
  'fusedConv2d_': fusedConv2d_
});
42152
42153 /**
42154 * @license
42155 * Copyright 2020 Google LLC. All Rights Reserved.
42156 * Licensed under the Apache License, Version 2.0 (the "License");
42157 * you may not use this file except in compliance with the License.
42158 * You may obtain a copy of the License at
42159 *
42160 * http://www.apache.org/licenses/LICENSE-2.0
42161 *
42162 * Unless required by applicable law or agreed to in writing, software
42163 * distributed under the License is distributed on an "AS IS" BASIS,
42164 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42165 * See the License for the specific language governing permissions and
42166 * limitations under the License.
42167 * =============================================================================
42168 */
/**
 * Gradient of a depthwise 2D convolution with respect to its filter, computed
 * by the DepthwiseConv2dNativeBackpropFilter kernel.
 *
 * A rank-3 `x` or `dy` is reshaped to rank 4 with a leading dimension of 1
 * before the kernel call; the kernel's result is returned unchanged.
 *
 * @param x Forward-pass input (rank 3 or rank 4).
 * @param dy Upstream gradient (rank 3 or rank 4).
 * @param filterShape Shape of the filter whose gradient is produced.
 * @param strides Convolution strides.
 * @param pad Padding mode.
 * @param dilations Optional 6th argument, defaults to [1, 1].
 * @param dimRoundingMode Optional 7th argument.
 */
function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad) {
  var dilations = [1, 1];
  if (arguments.length > 5 && arguments[5] !== undefined) {
    dilations = arguments[5];
  }
  var dimRoundingMode = arguments.length > 6 ? arguments[6] : undefined;

  var batchedX = x.rank === 3 ? reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2]]) : x;
  var batchedDy = dy.rank === 3 ? reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) : dy;

  var kernelInputs = {
    x: batchedX,
    dy: batchedDy
  };
  var kernelAttrs = {
    strides: strides,
    pad: pad,
    dimRoundingMode: dimRoundingMode,
    dilations: dilations,
    filterShape: filterShape
  };
  // tslint:disable-next-line: no-unnecessary-type-assertion
  return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, kernelInputs, kernelAttrs);
}
var depthwiseConv2dNativeBackpropFilter$2 = op({
  'depthwiseConv2dNativeBackpropFilter_': depthwiseConv2dNativeBackpropFilter_
});
42197
42198 /**
42199 * @license
42200 * Copyright 2020 Google LLC. All Rights Reserved.
42201 * Licensed under the Apache License, Version 2.0 (the "License");
42202 * you may not use this file except in compliance with the License.
42203 * You may obtain a copy of the License at
42204 *
42205 * http://www.apache.org/licenses/LICENSE-2.0
42206 *
42207 * Unless required by applicable law or agreed to in writing, software
42208 * distributed under the License is distributed on an "AS IS" BASIS,
42209 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42210 * See the License for the specific language governing permissions and
42211 * limitations under the License.
42212 * =============================================================================
42213 */
/**
 * Gradient of a depthwise 2D convolution with respect to its input, computed
 * by the DepthwiseConv2dNativeBackpropInput kernel.
 *
 * A rank-3 `dy` is reshaped to rank 4 with a leading dimension of 1 for the
 * kernel call, and that leading dimension is dropped from the result again.
 *
 * @param xShape Shape of the forward-pass input (passed as `inputShape`).
 * @param dy Upstream gradient (rank 3 or rank 4).
 * @param filter Depthwise filter tensor.
 * @param strides Convolution strides.
 * @param pad Padding mode.
 * @param dilations Optional 6th argument, defaults to [1, 1].
 * @param dimRoundingMode Optional 7th argument.
 */
function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad) {
  var dilations = [1, 1];
  if (arguments.length > 5 && arguments[5] !== undefined) {
    dilations = arguments[5];
  }
  var dimRoundingMode = arguments.length > 6 ? arguments[6] : undefined;

  var addedBatchDim = dy.rank === 3;
  var batchedDy = addedBatchDim ? reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) : dy;

  // tslint:disable-next-line: no-unnecessary-type-assertion
  var output = ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, {
    dy: batchedDy,
    filter: filter
  }, {
    strides: strides,
    pad: pad,
    dimRoundingMode: dimRoundingMode,
    dilations: dilations,
    inputShape: xShape
  });
  if (!addedBatchDim) {
    return output;
  }
  return reshape$3(output, [output.shape[1], output.shape[2], output.shape[3]]);
}
var depthwiseConv2dNativeBackpropInput$2 = op({
  'depthwiseConv2dNativeBackpropInput_': depthwiseConv2dNativeBackpropInput_
});
42245
42246 /**
42247 * Computes depthwise 2D convolution, optionally fused with adding a
42248 * bias and applying an activation.
42249 *
42250 * Given a 4D `input` array and a `filter` array of shape
42251 * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
42252 * `inChannels` convolutional filters of depth 1, this op applies a
42253 * different filter to each input channel (expanding from 1 channel to
42254 * `channelMultiplier` channels for each), then concatenates the results
42255 * together. The output has `inChannels * channelMultiplier` channels.
42256 *
42257 * See
42258 * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
42259 * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
42260 * for more details.
42261 *
42262 * @param obj An object with the following properties:
42263 * @param x The input tensor, of rank 4 or rank 3, of shape
42264 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
42265 * assumed.
42266 * @param filter The filter tensor, rank 4, of shape
42267 * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
42268 * @param strides The strides of the convolution: `[strideHeight,
42269 * strideWidth]`. If strides is a single number, then `strideHeight ==
42270 * strideWidth`.
42271 * @param pad The type of padding algorithm.
42272 * - `same` and stride 1: output will be of same size as input,
42273 * regardless of filter size.
42274 * - `valid`: output will be smaller than input if filter is larger
42275 * than 1x1.
42276 * - For more info, see this guide:
42277 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
42278 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
42279 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
42280 * in which we sample input values across the height and width dimensions
42281 * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
42282 * number, then `dilationHeight == dilationWidth`. If it is greater than
42283 * 1, then all values of `strides` must be 1.
42284 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
42285 * "NHWC". Specify the data format of the input and output data. With the
42286 * default format "NHWC", the data is stored in the order of: [batch,
42287 * height, width, channels]. Only "NHWC" is currently supported.
42288 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
42289 * provided, it will default to truncate.
42290 * @param bias Tensor to be added to the result.
42291 * @param activation Name of activation kernel (defaults to `linear`).
42292 * @param preluActivationWeights Tensor of prelu weights to be applied as part
42293 * of a `prelu` activation, typically the same shape as `x`.
42294 * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
42295 * activation.
42296 */
function fusedDepthwiseConv2d_(_ref) {
  // Unpack the options object, applying the documented defaults.
  var x = _ref.x,
    filter = _ref.filter,
    strides = _ref.strides,
    pad = _ref.pad,
    _ref$dataFormat = _ref.dataFormat,
    dataFormat = _ref$dataFormat === void 0 ? 'NHWC' : _ref$dataFormat,
    _ref$dilations = _ref.dilations,
    dilations = _ref$dilations === void 0 ? [1, 1] : _ref$dilations,
    dimRoundingMode = _ref.dimRoundingMode,
    bias = _ref.bias,
    _ref$activation = _ref.activation,
    activation = _ref$activation === void 0 ? 'linear' : _ref$activation,
    preluActivationWeights = _ref.preluActivationWeights,
    leakyreluAlpha = _ref.leakyreluAlpha;
  // When shouldFuse(...) reports that fusion should not be used, compute the
  // same result from the separate unfused ops.
  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
    var result = depthwiseConv2d$3(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
    if (bias != null) {
      result = add$3(result, bias);
    }
    return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
  }
  // Fused path: validate the inputs, then run the FusedDepthwiseConv2D kernel
  // inside customGrad so that `grad` below supplies the gradients.
  var $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
  var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
  var x4D = $x;
  var reshapedTo4D = false;
  if ($x.rank === 3) {
    // Treat an unbatched input as a batch of one; undone on the output.
    reshapedTo4D = true;
    x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
  }
  assert$1(x4D.rank === 4, function () {
    return "Error in fused depthwiseConv2d: input must be rank 4, but got " + "rank ".concat(x4D.rank, ".");
  });
  assert$1($filter.rank === 4, function () {
    return "Error in fused depthwiseConv2d: filter must be rank 4, " + "but got rank ".concat($filter.rank, ".");
  });
  assert$1(x4D.shape[3] === $filter.shape[2], function () {
    return "Error in fused depthwiseConv2d: number of input channels " + "(".concat(x4D.shape[3], ") must match the inChannels dimension in ") + "filter ".concat($filter.shape[2], ".");
  });
  if (dilations == null) {
    dilations = [1, 1];
  }
  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in fused depthwiseConv2d: Either strides or dilations must ' + "be 1. Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
  });
  checkPadOnDimRoundingMode('fused depthwiseConv2d', pad, dimRoundingMode);
  var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
  var $bias;
  if (bias != null) {
    // Label the conversion with this op's own name (it previously said
    // 'fused conv2d'), matching the prelu-weights conversion below.
    $bias = convertToTensor(bias, 'bias', 'fused depthwiseConv2d');
    var _makeTypesMatch = makeTypesMatch($bias, $x);
    var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 1);
    $bias = _makeTypesMatch2[0];
    // The bias must broadcast against the convolution output shape.
    assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
  }
  var $preluActivationWeights;
  if (preluActivationWeights != null) {
    $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
  }
  // Gradient for the customGrad ops below. `saved` is [filter, x4D, y] or
  // [filter, x4D, y, bias], matching the save() calls.
  var grad = function grad(dy, saved) {
    assert$1(tupleValuesAreOne(dilations), function () {
      return 'Error in gradient of fused depthwiseConv2d: dilation rates ' + "greater than 1 are not yet supported. Got dilations " + "'".concat(dilations, "'");
    });
    var _saved = _slicedToArray(saved, 4),
      $filter = _saved[0],
      x4D = _saved[1],
      y = _saved[2],
      $bias = _saved[3];
    var dyActivation = getFusedDyActivation(dy, y, activation);
    var xDer = depthwiseConv2dNativeBackpropInput$2(x4D.shape, dyActivation, $filter, strides, pad, dilations, dimRoundingMode);
    var filterDer = depthwiseConv2dNativeBackpropFilter$2(x4D, dyActivation, $filter.shape, strides, pad, dilations, dimRoundingMode);
    if ($bias != null) {
      // Use the saved bias (as fusedConv2d_ does) rather than the outer one.
      var biasDer = getFusedBiasGradient($bias, dyActivation);
      return [xDer, filterDer, biasDer];
    }
    return [xDer, filterDer];
  };
  var inputs = {
    x: x4D,
    filter: $filter,
    bias: $bias,
    preluActivationWeights: $preluActivationWeights
  };
  var attrs = {
    strides: strides,
    pad: pad,
    dataFormat: dataFormat,
    dilations: dilations,
    dimRoundingMode: dimRoundingMode,
    activation: activation,
    leakyreluAlpha: leakyreluAlpha
  };
  // Depending on the params passed in we will have a different number of
  // inputs and thus a different number of elements in the gradient.
  if (bias == null) {
    var customOp = customGrad(function (x4D, filter, save) {
      // tslint:disable-next-line: no-unnecessary-type-assertion
      var res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
      save([filter, x4D, res]);
      if (reshapedTo4D) {
        // tslint:disable-next-line: no-unnecessary-type-assertion
        res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
      }
      return {
        value: res,
        gradFunc: grad
      };
    });
    return customOp(x4D, $filter);
  } else {
    var customOpWithBias = customGrad(function (x4D, filter, bias, save) {
      // tslint:disable-next-line: no-unnecessary-type-assertion
      var res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
      save([filter, x4D, res, bias]);
      if (reshapedTo4D) {
        // tslint:disable-next-line: no-unnecessary-type-assertion
        res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
      }
      return {
        value: res,
        gradFunc: grad
      };
    });
    return customOpWithBias(x4D, $filter, $bias);
  }
}
var depthwiseConv2d$2 = /* @__PURE__ */op({
  fusedDepthwiseConv2d_: fusedDepthwiseConv2d_
});
42426
42427 /**
42428 * Computes the dot product of two matrices with optional activation and bias.
42429 *
42430 * ```js
42431 * const a = tf.tensor2d([-1, -2], [1, 2]);
42432 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
42433 * const bias = tf.tensor2d([1, 2], [1, 2]);
42434 *
42435 * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();
42436 * ```
42437 *
42438 * @param obj An object with the following properties:
42439 * - `a` First matrix in dot product operation.
42440 * - `b` Second matrix in dot product operation.
42441 * - `transposeA` If true, `a` is transposed before multiplication.
42442 * - `transposeB` If true, `b` is transposed before multiplication.
42443 * - `bias` Matrix to be added to the result.
42444 * - `activation` Name of activation kernel (defaults to `linear`).
42445 * - `preluActivationWeights` Tensor of prelu weights.
42446 * - `leakyreluAlpha` Alpha of leakyrelu.
42447 */
function fusedMatMul_(_ref) {
  // Unpack the options object, applying the documented defaults.
  var a = _ref.a,
    b = _ref.b,
    _ref$transposeA = _ref.transposeA,
    transposeA = _ref$transposeA === void 0 ? false : _ref$transposeA,
    _ref$transposeB = _ref.transposeB,
    transposeB = _ref$transposeB === void 0 ? false : _ref$transposeB,
    bias = _ref.bias,
    _ref$activation = _ref.activation,
    activation = _ref$activation === void 0 ? 'linear' : _ref$activation,
    preluActivationWeights = _ref.preluActivationWeights,
    _ref$leakyreluAlpha = _ref.leakyreluAlpha,
    leakyreluAlpha = _ref$leakyreluAlpha === void 0 ? 0.2 : _ref$leakyreluAlpha;
  // When shouldFuse(...) reports that fusion should not be used, compute the
  // same result from the separate unfused ops.
  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
    var result = matMul$1(a, b, transposeA, transposeB);
    if (bias != null) {
      result = add$3(result, bias);
    }
    return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
  }
  var $a = convertToTensor(a, 'a', 'fused matMul');
  var $b = convertToTensor(b, 'b', 'fused matMul');
  var _makeTypesMatch = makeTypesMatch($a, $b);
  var _makeTypesMatch2 = _slicedToArray(_makeTypesMatch, 2);
  $a = _makeTypesMatch2[0];
  $b = _makeTypesMatch2[1];
  // The last two axes hold the matrices; which of them is contracted
  // ("inner") and which is kept ("outer") depends on the transpose flags.
  var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
  var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
  var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
  var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
  // Every axis before the last two is a batch axis; they are flattened into a
  // single batch dimension for the kernel.
  var outerDimsA = $a.shape.slice(0, -2);
  var outerDimsB = $b.shape.slice(0, -2);
  var batchDimA = sizeFromShape(outerDimsA);
  var batchDimB = sizeFromShape(outerDimsB);
  assert$1(innerShapeA === innerShapeB, function () {
    return "Error in fused matMul: inner shapes (".concat(innerShapeA, ") and (") + "".concat(innerShapeB, ") of Tensors with shapes ").concat($a.shape, " and ") + "".concat($b.shape, " and transposeA=").concat(transposeA) + " and transposeB=".concat(transposeB, " must match.");
  });
  // Reuse the batch dims computed above instead of slicing the shapes again.
  var outShapeOuterDims = assertAndGetBroadcastShape(outerDimsA, outerDimsB);
  var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
  var a3D = transposeA ? reshape$3($a, [batchDimA, innerShapeA, outerShapeA]) : reshape$3($a, [batchDimA, outerShapeA, innerShapeA]);
  var b3D = transposeB ? reshape$3($b, [batchDimB, outerShapeB, innerShapeB]) : reshape$3($b, [batchDimB, innerShapeB, outerShapeB]);
  var $bias;
  if (bias != null) {
    $bias = convertToTensor(bias, 'bias', 'fused matMul');
    var _makeTypesMatch3 = makeTypesMatch($bias, $a);
    var _makeTypesMatch4 = _slicedToArray(_makeTypesMatch3, 1);
    $bias = _makeTypesMatch4[0];
    // The bias must broadcast against the (unflattened) output shape.
    assertAndGetBroadcastShape(outShape, $bias.shape);
  }
  var $preluActivationWeights;
  if (preluActivationWeights != null) {
    $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
  }
  // Gradient for the customGrad ops below. `saved` is [a3D, b3D, y] or
  // [a3D, b3D, y, bias], matching the save() calls.
  var grad = function grad(dy, saved) {
    var _saved = _slicedToArray(saved, 4),
      a3D = _saved[0],
      b3D = _saved[1],
      y = _saved[2],
      $bias = _saved[3];
    // we reshape dy because the result of the forward is not
    // necessarily going to be a 3d tensor due to a reshape done at the end of
    // the customOp.
    var dyActivation = getFusedDyActivation(reshape$3(dy, y.shape), y, activation);
    var aDer;
    var bDer;
    // For C = A * B: dA = dC * B^T and dB = A^T * dC; each branch applies
    // that rule with the operands' stored transpose layout.
    if (!transposeA && !transposeB) {
      aDer = matMul$1(dyActivation, b3D, false, true);
      bDer = matMul$1(a3D, dyActivation, true, false);
    } else if (!transposeA && transposeB) {
      aDer = matMul$1(dyActivation, b3D, false, false);
      bDer = matMul$1(dyActivation, a3D, true, false);
    } else if (transposeA && !transposeB) {
      aDer = matMul$1(b3D, dyActivation, false, true);
      bDer = matMul$1(a3D, dyActivation, false, false);
    } else {
      aDer = matMul$1(b3D, dyActivation, true, true);
      bDer = matMul$1(dyActivation, a3D, true, true);
    }
    if (bias != null) {
      var biasDer = getFusedBiasGradient($bias, dyActivation);
      return [aDer, bDer, biasDer];
    } else {
      return [aDer, bDer];
    }
  };
  var inputs = {
    a: a3D,
    b: b3D,
    bias: $bias,
    preluActivationWeights: $preluActivationWeights
  };
  var attrs = {
    transposeA: transposeA,
    transposeB: transposeB,
    activation: activation,
    leakyreluAlpha: leakyreluAlpha
  };
  // Depending on the params passed in we will have a different number of
  // inputs and thus a different number of elements in the gradient.
  if (bias == null) {
    var customOp = customGrad(function (a3D, b3D, save) {
      var res =
      // tslint:disable-next-line: no-unnecessary-type-assertion
      ENGINE.runKernel(_FusedMatMul, inputs, attrs);
      save([a3D, b3D, res]);
      return {
        value: reshape$3(res, outShape),
        gradFunc: grad
      };
    });
    return customOp(a3D, b3D);
  } else {
    var customOpWithBias = customGrad(function (a3D, b3D, $bias, save) {
      var res =
      // tslint:disable-next-line: no-unnecessary-type-assertion
      ENGINE.runKernel(_FusedMatMul, inputs, attrs);
      save([a3D, b3D, res, $bias]);
      return {
        value: reshape$3(res, outShape),
        gradFunc: grad
      };
    });
    return customOpWithBias(a3D, b3D, $bias);
  }
}
var matMul = /* @__PURE__ */op({
  fusedMatMul_: fusedMatMul_
});
42576
42577 /**
42578 * @license
42579 * Copyright 2019 Google LLC. All Rights Reserved.
42580 * Licensed under the Apache License, Version 2.0 (the "License");
42581 * you may not use this file except in compliance with the License.
42582 * You may obtain a copy of the License at
42583 *
42584 * http://www.apache.org/licenses/LICENSE-2.0
42585 *
42586 * Unless required by applicable law or agreed to in writing, software
42587 * distributed under the License is distributed on an "AS IS" BASIS,
42588 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42589 * See the License for the specific language governing permissions and
42590 * limitations under the License.
42591 * =============================================================================
42592 */
42593
// Namespace object for the fused ops. It has a null prototype, so its only
// properties are the three own entries assigned below (in this order).
var fused_ops = Object.create(null);
fused_ops.conv2d = conv2d$3;
fused_ops.depthwiseConv2d = depthwiseConv2d$2;
fused_ops.matMul = matMul;
42600
42601 /**
42602 * @license
42603 * Copyright 2019 Google LLC. All Rights Reserved.
42604 * Licensed under the Apache License, Version 2.0 (the "License");
42605 * you may not use this file except in compliance with the License.
42606 * You may obtain a copy of the License at
42607 *
42608 * http://www.apache.org/licenses/LICENSE-2.0
42609 *
42610 * Unless required by applicable law or agreed to in writing, software
42611 * distributed under the License is distributed on an "AS IS" BASIS,
42612 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42613 * See the License for the specific language governing permissions and
42614 * limitations under the License.
42615 * =============================================================================
42616 */
42617 /**
42618 * Generate a hamming window.
42619 *
42620 * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
42621 *
42622 * ```js
42623 * tf.signal.hammingWindow(10).print();
42624 * ```
 * @param windowLength The length of the window.
42626 *
42627 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
42628 */
function hammingWindow_(windowLength) {
  // (0.54, 0.46) are the standard Hamming window coefficients; how they are
  // combined is defined by cosineWindow elsewhere in this file.
  var alpha = 0.54;
  var beta = 0.46;
  return cosineWindow(windowLength, alpha, beta);
}
var hammingWindow = /* @__PURE__ */op({
  'hammingWindow_': hammingWindow_
});
42635
42636 /**
42637 * @license
42638 * Copyright 2019 Google LLC. All Rights Reserved.
42639 * Licensed under the Apache License, Version 2.0 (the "License");
42640 * you may not use this file except in compliance with the License.
42641 * You may obtain a copy of the License at
42642 *
42643 * http://www.apache.org/licenses/LICENSE-2.0
42644 *
42645 * Unless required by applicable law or agreed to in writing, software
42646 * distributed under the License is distributed on an "AS IS" BASIS,
42647 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42648 * See the License for the specific language governing permissions and
42649 * limitations under the License.
42650 * =============================================================================
42651 */
42652 /**
42653 * Generate a Hann window.
42654 *
42655 * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
42656 *
42657 * ```js
42658 * tf.signal.hannWindow(10).print();
42659 * ```
 * @param windowLength The length of the window.
42661 *
42662 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
42663 */
function hannWindow_(windowLength) {
  // The Hann window uses equal coefficients (0.5, 0.5); how they are combined
  // is defined by cosineWindow elsewhere in this file.
  var coefficient = 0.5;
  return cosineWindow(windowLength, coefficient, coefficient);
}
var hannWindow = /* @__PURE__ */op({
  'hannWindow_': hannWindow_
});
42670
42671 /**
42672 * @license
42673 * Copyright 2019 Google LLC. All Rights Reserved.
42674 * Licensed under the Apache License, Version 2.0 (the "License");
42675 * you may not use this file except in compliance with the License.
42676 * You may obtain a copy of the License at
42677 *
42678 * http://www.apache.org/licenses/LICENSE-2.0
42679 *
42680 * Unless required by applicable law or agreed to in writing, software
42681 * distributed under the License is distributed on an "AS IS" BASIS,
42682 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42683 * See the License for the specific language governing permissions and
42684 * limitations under the License.
42685 * =============================================================================
42686 */
42687 /**
 * Expands `signal` into frames of length `frameLength`, advancing the
 * window by `frameStep` samples from one frame to the next.
42690 *
42691 * ```js
42692 * tf.signal.frame([1, 2, 3], 2, 1).print();
42693 * ```
42694 * @param signal The input tensor to be expanded
42695 * @param frameLength Length of each frame
42696 * @param frameStep The frame hop size in samples.
42697 * @param padEnd Whether to pad the end of signal with padValue.
 * @param padValue A number to use where the input signal does
 *     not exist when padEnd is true.
42700 *
42701 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
42702 */
function frame_(signal, frameLength, frameStep) {
  var padEnd = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  var padValue = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 0;
  // `op` does not convert its arguments, so a plain array (as in the
  // documented `tf.signal.frame([1, 2, 3], 2, 1)` example) has no `.size`
  // and used to produce an empty result. Convert it explicitly.
  // NOTE(review): 'string_or_numeric' is meant to accept tensors of any
  // dtype unchanged, as slice does in tfjs-core — confirm.
  var $signal = convertToTensor(signal, 'signal', 'frame', 'string_or_numeric');
  // A zero or negative step never moves `start` past the end of the signal,
  // so the loops below would not terminate.
  assert$1(frameStep > 0, function () {
    return "Error in frame: frameStep must be a positive number, but got ".concat(frameStep, ".");
  });
  var start = 0;
  var output = [];
  // Frames that lie entirely inside the signal.
  while (start + frameLength <= $signal.size) {
    output.push(slice$2($signal, start, frameLength));
    start += frameStep;
  }
  if (padEnd) {
    // Remaining partial frames, right-padded with padValue to frameLength.
    while (start < $signal.size) {
      var padLen = start + frameLength - $signal.size;
      var pad = concat$2([slice$2($signal, start, frameLength - padLen), fill$2([padLen], padValue)]);
      output.push(pad);
      start += frameStep;
    }
  }
  if (output.length === 0) {
    return tensor2d([], [0, frameLength]);
  }
  return reshape$3(concat$2(output), [output.length, frameLength]);
}
var frame = /* @__PURE__ */op({
  frame_: frame_
});
42728
42729 /**
42730 * @license
42731 * Copyright 2019 Google LLC. All Rights Reserved.
42732 * Licensed under the Apache License, Version 2.0 (the "License");
42733 * you may not use this file except in compliance with the License.
42734 * You may obtain a copy of the License at
42735 *
42736 * http://www.apache.org/licenses/LICENSE-2.0
42737 *
42738 * Unless required by applicable law or agreed to in writing, software
42739 * distributed under the License is distributed on an "AS IS" BASIS,
42740 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42741 * See the License for the specific language governing permissions and
42742 * limitations under the License.
42743 * =============================================================================
42744 */
42745 /**
42746 * Computes the Short-time Fourier Transform of signals
42747 * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
42748 *
42749 * ```js
42750 * const input = tf.tensor1d([1, 1, 1, 1, 1])
42751 * tf.signal.stft(input, 3, 1).print();
42752 * ```
42753 * @param signal 1-dimensional real value tensor.
42754 * @param frameLength The window length of samples.
42755 * @param frameStep The number of samples to step.
42756 * @param fftLength The size of the FFT to apply.
42757 * @param windowFn A callable that takes a window length and returns 1-d tensor.
42758 *
42759 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
42760 */
function stft_(signal, frameLength, frameStep, fftLength) {
  // Optional 5th argument: window function, defaulting to hannWindow.
  var windowFn = hannWindow;
  if (arguments.length > 4 && arguments[4] !== undefined) {
    windowFn = arguments[4];
  }
  // Without an explicit FFT size, use enclosingPowerOfTwo(frameLength).
  var fftSize = fftLength == null ? enclosingPowerOfTwo(frameLength) : fftLength;
  var frames = frame(signal, frameLength, frameStep);
  var taper = windowFn(frameLength);
  return rfft(mul(frames, taper), fftSize);
}
var stft = /* @__PURE__ */op({
  'stft_': stft_
});
42773
42774 /**
42775 * @license
42776 * Copyright 2020 Google LLC. All Rights Reserved.
42777 * Licensed under the Apache License, Version 2.0 (the "License");
42778 * you may not use this file except in compliance with the License.
42779 * You may obtain a copy of the License at
42780 *
42781 * http://www.apache.org/licenses/LICENSE-2.0
42782 *
42783 * Unless required by applicable law or agreed to in writing, software
42784 * distributed under the License is distributed on an "AS IS" BASIS,
42785 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42786 * See the License for the specific language governing permissions and
42787 * limitations under the License.
42788 * =============================================================================
42789 */
42790 /**
42791 * Extracts crops from the input image tensor and resizes them using bilinear
42792 * sampling or nearest neighbor sampling (possibly with aspect ratio change)
42793 * to a common output size specified by cropSize.
42794 *
42795 * @param image 4d tensor of shape `[batch,imageHeight,imageWidth, depth]`,
42796 * where imageHeight and imageWidth must be positive, specifying the
42797 * batch of images from which to take crops
42798 * @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is
42799 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized
42800 * coordinates of the box in the `boxInd[i]`th image in the batch
42801 * @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range
42802 * `[0, batch)` that specifies the image that the `i`-th box refers to.
42803 * @param cropSize 1d int32 tensor of 2 elements `[cropHeigh, cropWidth]`
42804 * specifying the size to which all crops are resized to.
42805 * @param method Optional string from `'bilinear' | 'nearest'`,
42806 * defaults to bilinear, which specifies the sampling method for resizing
42807 * @param extrapolationValue A threshold for deciding when to remove boxes based
42808 * on score. Defaults to 0.
42809 * @return A 4D tensor of the shape `[numBoxes,cropHeight,cropWidth,depth]`
42810 *
42811 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
42812 */
42813 function cropAndResize_(image, boxes, boxInd, cropSize) {
42814 var method = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 'bilinear';
42815 var extrapolationValue = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 0;
42816 var $image = convertToTensor(image, 'image', 'cropAndResize');
42817 var $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32');
42818 var $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32');
42819 var numBoxes = $boxes.shape[0];
42820 assert$1($image.rank === 4, function () {
42821 return 'Error in cropAndResize: image must be rank 4,' + "but got rank ".concat($image.rank, ".");
42822 });
42823 assert$1($boxes.rank === 2 && $boxes.shape[1] === 4, function () {
42824 return "Error in cropAndResize: boxes must be have size [".concat(numBoxes, ",4] ") + "but had shape ".concat($boxes.shape, ".");
42825 });
42826 assert$1($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, function () {
42827 return "Error in cropAndResize: boxInd must be have size [".concat(numBoxes, "] ") + "but had shape ".concat($boxes.shape, ".");
42828 });
42829 assert$1(cropSize.length === 2, function () {
42830 return "Error in cropAndResize: cropSize must be of length 2, but got " + "length ".concat(cropSize.length, ".");
42831 });
42832 assert$1(cropSize[0] >= 1 && cropSize[1] >= 1, function () {
42833 return "cropSize must be atleast [1,1], but was ".concat(cropSize);
42834 });
42835 assert$1(method === 'bilinear' || method === 'nearest', function () {
42836 return "method must be bilinear or nearest, but was ".concat(method);
42837 });
42838 var inputs = {
42839 image: $image,
42840 boxes: $boxes,
42841 boxInd: $boxInd
42842 };
42843 var attrs = {
42844 method: method,
42845 extrapolationValue: extrapolationValue,
42846 cropSize: cropSize
42847 };
42848 var res = ENGINE.runKernel(CropAndResize, inputs, attrs);
42849 return res;
42850 }
42851 var cropAndResize$3 = /* @__PURE__ */op({
42852 cropAndResize_: cropAndResize_
42853 });
42854
42855 /**
42856 * @license
42857 * Copyright 2020 Google LLC. All Rights Reserved.
42858 * Licensed under the Apache License, Version 2.0 (the "License");
42859 * you may not use this file except in compliance with the License.
42860 * You may obtain a copy of the License at
42861 *
42862 * http://www.apache.org/licenses/LICENSE-2.0
42863 *
42864 * Unless required by applicable law or agreed to in writing, software
42865 * distributed under the License is distributed on an "AS IS" BASIS,
42866 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42867 * See the License for the specific language governing permissions and
42868 * limitations under the License.
42869 * =============================================================================
42870 */
42871 /**
42872 * Flips the image left to right. Currently available in the CPU, WebGL, and
42873 * WASM backends.
42874 *
42875 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
42876 */
42877 /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */
42878 function flipLeftRight_(image) {
42879 var $image = convertToTensor(image, 'image', 'flipLeftRight', 'float32');
42880 assert$1($image.rank === 4, function () {
42881 return 'Error in flipLeftRight: image must be rank 4,' + "but got rank ".concat($image.rank, ".");
42882 });
42883 var inputs = {
42884 image: $image
42885 };
42886 var res = ENGINE.runKernel(FlipLeftRight, inputs, {});
42887 return res;
42888 }
42889 var flipLeftRight = /* @__PURE__ */op({
42890 flipLeftRight_: flipLeftRight_
42891 });
42892
42893 /**
42894 * @license
42895 * Copyright 2021 Google LLC. All Rights Reserved.
42896 * Licensed under the Apache License, Version 2.0 (the "License");
42897 * you may not use this file except in compliance with the License.
42898 * You may obtain a copy of the License at
42899 *
42900 * http://www.apache.org/licenses/LICENSE-2.0
42901 *
42902 * Unless required by applicable law or agreed to in writing, software
42903 * distributed under the License is distributed on an "AS IS" BASIS,
42904 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42905 * See the License for the specific language governing permissions and
42906 * limitations under the License.
42907 * =============================================================================
42908 */
42909 /**
42910 * Converts images from grayscale to RGB format.
42911 *
42912 * @param image A grayscale tensor to convert. The `image`'s last dimension must
42913 * be size 1 with at least a two-dimensional shape.
42914 *
42915 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
42916 */
42917 function grayscaleToRGB_(image) {
42918 var $image = convertToTensor(image, 'image', 'grayscaleToRGB');
42919 var lastDimsIdx = $image.rank - 1;
42920 var lastDims = $image.shape[lastDimsIdx];
42921 assert$1($image.rank >= 2, function () {
42922 return 'Error in grayscaleToRGB: images must be at least rank 2, ' + "but got rank ".concat($image.rank, ".");
42923 });
42924 assert$1(lastDims === 1, function () {
42925 return 'Error in grayscaleToRGB: last dimension of a grayscale image ' + "should be size 1, but got size ".concat(lastDims, ".");
42926 });
42927 var reps = new Array($image.rank);
42928 reps.fill(1, 0, lastDimsIdx);
42929 reps[lastDimsIdx] = 3;
42930 return tile$3($image, reps);
42931 }
42932 var grayscaleToRGB = /* @__PURE__ */op({
42933 grayscaleToRGB_: grayscaleToRGB_
42934 });
42935
42936 /**
42937 * @license
42938 * Copyright 2023 Google LLC.
42939 * Licensed under the Apache License, Version 2.0 (the "License");
42940 * you may not use this file except in compliance with the License.
42941 * You may obtain a copy of the License at
42942 *
42943 * http://www.apache.org/licenses/LICENSE-2.0
42944 *
42945 * Unless required by applicable law or agreed to in writing, software
42946 * distributed under the License is distributed on an "AS IS" BASIS,
42947 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42948 * See the License for the specific language governing permissions and
42949 * limitations under the License.
42950 * =============================================================================
42951 */
42952 /**
42953 * Converts images from RGB format to grayscale.
42954 *
42955 * @param image A RGB tensor to convert. The `image`'s last dimension must
42956 * be size 3 with at least a two-dimensional shape.
42957 *
42958 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
42959 */
42960 function rgbToGrayscale_(image) {
42961 var $image = convertToTensor(image, 'image', 'RGBToGrayscale');
42962 var lastDimsIdx = $image.rank - 1;
42963 var lastDims = $image.shape[lastDimsIdx];
42964 assert$1($image.rank >= 2, function () {
42965 return 'Error in RGBToGrayscale: images must be at least rank 2, ' + "but got rank ".concat($image.rank, ".");
42966 });
42967 assert$1(lastDims === 3, function () {
42968 return 'Error in RGBToGrayscale: last dimension of an RGB image ' + "should be size 3, but got size ".concat(lastDims, ".");
42969 });
42970 // Remember original dtype so we can convert back if needed
42971 var origDtype = $image.dtype;
42972 var fltImage = cast$3($image, 'float32');
42973 var rgbWeights = tensor1d([0.2989, 0.5870, 0.1140]);
42974 var grayFloat;
42975 switch ($image.rank) {
42976 case 2:
42977 grayFloat = einsum$2('ij,j->i', fltImage, rgbWeights);
42978 break;
42979 case 3:
42980 grayFloat = einsum$2('ijk,k->ij', fltImage, rgbWeights);
42981 break;
42982 case 4:
42983 grayFloat = einsum$2('ijkl,l->ijk', fltImage, rgbWeights);
42984 break;
42985 case 5:
42986 grayFloat = einsum$2('ijklm,m->ijkl', fltImage, rgbWeights);
42987 break;
42988 case 6:
42989 grayFloat = einsum$2('ijklmn,n->ijklm', fltImage, rgbWeights);
42990 break;
42991 default:
42992 throw new Error('Not a valid tensor rank.');
42993 }
42994 grayFloat = expandDims$3(grayFloat, -1);
42995 return cast$3(grayFloat, origDtype);
42996 }
42997 var rgbToGrayscale = /* @__PURE__ */op({
42998 rgbToGrayscale_: rgbToGrayscale_
42999 });
43000
43001 /**
43002 * @license
43003 * Copyright 2020 Google LLC. All Rights Reserved.
43004 * Licensed under the Apache License, Version 2.0 (the "License");
43005 * you may not use this file except in compliance with the License.
43006 * You may obtain a copy of the License at
43007 *
43008 * http://www.apache.org/licenses/LICENSE-2.0
43009 *
43010 * Unless required by applicable law or agreed to in writing, software
43011 * distributed under the License is distributed on an "AS IS" BASIS,
43012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
43013 * See the License for the specific language governing permissions and
43014 * limitations under the License.
43015 * =============================================================================
43016 */
43017 /**
43018 * Rotates the input image tensor counter-clockwise with an optional offset
43019 * center of rotation. Currently available in the CPU, WebGL, and WASM backends.
43020 *
43021 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
43022 * @param radians The amount of rotation.
43023 * @param fillValue The value to fill in the empty space leftover
43024 * after rotation. Can be either a single grayscale value (0-255), or an
43025 * array of three numbers `[red, green, blue]` specifying the red, green,
43026 * and blue channels. Defaults to `0` (black).
43027 * @param center The center of rotation. Can be either a single value (0-1), or
43028 * an array of two numbers `[centerX, centerY]`. Defaults to `0.5` (rotates
43029 * the image around its center).
43030 *
43031 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
43032 */
43033 function rotateWithOffset_(image, radians) {
43034 var fillValue = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0;
43035 var center = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 0.5;
43036 var $image = convertToTensor(image, 'image', 'rotateWithOffset', 'float32');
43037 assert$1($image.rank === 4, function () {
43038 return 'Error in rotateWithOffset: image must be rank 4,' + "but got rank ".concat($image.rank, ".");
43039 });
43040 var inputs = {
43041 image: $image
43042 };
43043 var attrs = {
43044 radians: radians,
43045 fillValue: fillValue,
43046 center: center
43047 };
43048 var res = ENGINE.runKernel(RotateWithOffset, inputs, attrs);
43049 return res;
43050 }
43051 var rotateWithOffset = /* @__PURE__ */op({
43052 rotateWithOffset_: rotateWithOffset_
43053 });
43054
43055 /**
43056 * @license
43057 * Copyright 2020 Google LLC. All Rights Reserved.
43058 * Licensed under the Apache License, Version 2.0 (the "License");
43059 * you may not use this file except in compliance with the License.
43060 * You may obtain a copy of the License at
43061 *
43062 * http://www.apache.org/licenses/LICENSE-2.0
43063 *
43064 * Unless required by applicable law or agreed to in writing, software
43065 * distributed under the License is distributed on an "AS IS" BASIS,
43066 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
43067 * See the License for the specific language governing permissions and
43068 * limitations under the License.
43069 * =============================================================================
43070 */
43071 function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
43072 if (iouThreshold == null) {
43073 iouThreshold = 0.5;
43074 }
43075 if (scoreThreshold == null) {
43076 scoreThreshold = Number.NEGATIVE_INFINITY;
43077 }
43078 if (softNmsSigma == null) {
43079 softNmsSigma = 0.0;
43080 }
43081 var numBoxes = boxes.shape[0];
43082 maxOutputSize = Math.min(maxOutputSize, numBoxes);
43083 assert$1(0 <= iouThreshold && iouThreshold <= 1, function () {
43084 return "iouThreshold must be in [0, 1], but was '".concat(iouThreshold, "'");
43085 });
43086 assert$1(boxes.rank === 2, function () {
43087 return "boxes must be a 2D tensor, but was of rank '".concat(boxes.rank, "'");
43088 });
43089 assert$1(boxes.shape[1] === 4, function () {
43090 return "boxes must have 4 columns, but 2nd dimension was ".concat(boxes.shape[1]);
43091 });
43092 assert$1(scores.rank === 1, function () {
43093 return 'scores must be a 1D tensor';
43094 });
43095 assert$1(scores.shape[0] === numBoxes, function () {
43096 return "scores has incompatible shape with boxes. Expected ".concat(numBoxes, ", ") + "but was ".concat(scores.shape[0]);
43097 });
43098 assert$1(0 <= softNmsSigma && softNmsSigma <= 1, function () {
43099 return "softNmsSigma must be in [0, 1], but was '".concat(softNmsSigma, "'");
43100 });
43101 return {
43102 maxOutputSize: maxOutputSize,
43103 iouThreshold: iouThreshold,
43104 scoreThreshold: scoreThreshold,
43105 softNmsSigma: softNmsSigma
43106 };
43107 }
43108
43109 /**
43110 * @license
43111 * Copyright 2020 Google LLC. All Rights Reserved.
43112 * Licensed under the Apache License, Version 2.0 (the "License");
43113 * you may not use this file except in compliance with the License.
43114 * You may obtain a copy of the License at
43115 *
43116 * http://www.apache.org/licenses/LICENSE-2.0
43117 *
43118 * Unless required by applicable law or agreed to in writing, software
43119 * distributed under the License is distributed on an "AS IS" BASIS,
43120 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
43121 * See the License for the specific language governing permissions and
43122 * limitations under the License.
43123 * =============================================================================
43124 */
43125 /**
43126 * Performs non maximum suppression of bounding boxes based on
43127 * iou (intersection over union).
43128 *
43129 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
43130 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
43131 * the bounding box.
43132 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
43133 * @param maxOutputSize The maximum number of boxes to be selected.
43134 * @param iouThreshold A float representing the threshold for deciding whether
43135 * boxes overlap too much with respect to IOU. Must be between [0, 1].
43136 * Defaults to 0.5 (50% box overlap).
43137 * @param scoreThreshold A threshold for deciding when to remove boxes based
43138 * on score. Defaults to -inf, which means any score is accepted.
43139 * @return A 1D tensor with the selected box indices.
43140 *
43141 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
43142 */
43143 function nonMaxSuppression_(boxes, scores, maxOutputSize) {
43144 var iouThreshold = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 0.5;
43145 var scoreThreshold = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : Number.NEGATIVE_INFINITY;
43146 var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression', 'float32');
43147 var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression', 'float32');
43148 var inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
43149 maxOutputSize = inputs.maxOutputSize;
43150 iouThreshold = inputs.iouThreshold;
43151 scoreThreshold = inputs.scoreThreshold;
43152 var attrs = {
43153 maxOutputSize: maxOutputSize,
43154 iouThreshold: iouThreshold,
43155 scoreThreshold: scoreThreshold
43156 };
43157 return ENGINE.runKernel(NonMaxSuppressionV3, {
43158 boxes: $boxes,
43159 scores: $scores
43160 }, attrs);
43161 }
43162 var nonMaxSuppression = /* @__PURE__ */op({
43163 nonMaxSuppression_: nonMaxSuppression_
43164 });
43165
43166 /**
43167 * @license
43168 * Copyright 2019 Google LLC. All Rights Reserved.
43169 * Licensed under the Apache License, Version 2.0 (the "License");
43170 * you may not use this file except in compliance with the License.
43171 * You may obtain a copy of the License at
43172 *
43173 * http://www.apache.org/licenses/LICENSE-2.0
43174 *
43175 * Unless required by applicable law or agreed to in writing, software
43176 * distributed under the License is distributed on an "AS IS" BASIS,
43177 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
43178 * See the License for the specific language governing permissions and
43179 * limitations under the License.
43180 * =============================================================================
43181 */
43182 /**
43183 * Inserts a value into a sorted array. This method allows duplicate, meaning it
43184 * allows inserting duplicate value, in which case, the element will be inserted
43185 * at the lowest index of the value.
43186 * @param arr The array to modify.
43187 * @param element The element to insert.
43188 * @param comparator Optional. If no comparator is specified, elements are
43189 * compared using array_util.defaultComparator, which is suitable for Strings
43190 * and Numbers in ascending arrays. If the array contains multiple instances of
43191 * the target value, the left-most instance will be returned. To provide a
43192 * comparator, it should take 2 arguments to compare and return a negative,
43193 * zero, or a positive number.
43194 */
43195 function binaryInsert(arr, element, comparator) {
43196 var index = binarySearch(arr, element, comparator);
43197 var insertionPoint = index < 0 ? -(index + 1) : index;
43198 arr.splice(insertionPoint, 0, element);
43199 }
43200 /**
43201 * Searches the array for the target using binary search, returns the index
43202 * of the found element, or position to insert if element not found. If no
43203 * comparator is specified, elements are compared using array_
43204 * util.defaultComparator, which is suitable for Strings and Numbers in
43205 * ascending arrays. If the array contains multiple instances of the target
43206 * value, the left-most instance will be returned.
43207 * @param arr The array to be searched in.
43208 * @param target The target to be searched for.
43209 * @param comparator Should take 2 arguments to compare and return a negative,
43210 * zero, or a positive number.
43211 * @return Lowest index of the target value if found, otherwise the insertion
43212 * point where the target should be inserted, in the form of
43213 * (-insertionPoint - 1).
43214 */
43215 function binarySearch(arr, target, comparator) {
43216 return binarySearch_(arr, target, comparator || defaultComparator);
43217 }
43218 /**
43219 * Compares its two arguments for order.
43220 * @param a The first element to be compared.
43221 * @param b The second element to be compared.
43222 * @return A negative number, zero, or a positive number as the first
43223 * argument is less than, equal to, or greater than the second.
43224 */
43225 function defaultComparator(a, b) {
43226 return a > b ? 1 : a < b ? -1 : 0;
43227 }
43228 function binarySearch_(arr, target, comparator) {
43229 var left = 0;
43230 var right = arr.length;
43231 var middle = 0;
43232 var found = false;
43233 while (left < right) {
43234 middle = left + (right - left >>> 1);
43235 var compareResult = comparator(target, arr[middle]);
43236 if (compareResult > 0) {
43237 left = middle + 1;
43238 } else {
43239 right = middle;
43240 // If compareResult is 0, the value is found. We record it is found,
43241 // and then keep looking because there may be duplicate.
43242 found = !compareResult;
43243 }
43244 }
43245 return found ? left : -left - 1;
43246 }
43247
43248 function nonMaxSuppressionV3Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
43249 return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */);
43250 }
43251
43252 function nonMaxSuppressionV4Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
43253 return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */, false /* returnScoresTensor */, padToMaxOutputSize /* padToMaxOutputSize */, true
43254 /* returnValidOutputs */);
43255 }
43256
43257 function nonMaxSuppressionV5Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
43258 return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true /* returnScoresTensor */);
43259 }
43260
43261 function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
43262 var returnScoresTensor = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : false;
43263 var padToMaxOutputSize = arguments.length > 7 && arguments[7] !== undefined ? arguments[7] : false;
43264 var returnValidOutputs = arguments.length > 8 && arguments[8] !== undefined ? arguments[8] : false;
43265 // The list is sorted in ascending order, so that we can always pop the
43266 // candidate with the largest score in O(1) time.
43267 var candidates = [];
43268 for (var i = 0; i < scores.length; i++) {
43269 if (scores[i] > scoreThreshold) {
43270 candidates.push({
43271 score: scores[i],
43272 boxIndex: i,
43273 suppressBeginIndex: 0
43274 });
43275 }
43276 }
43277 candidates.sort(ascendingComparator);
43278 // If softNmsSigma is 0, the outcome of this algorithm is exactly same as
43279 // before.
43280 var scale = softNmsSigma > 0 ? -0.5 / softNmsSigma : 0.0;
43281 var selectedIndices = [];
43282 var selectedScores = [];
43283 while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
43284 var candidate = candidates.pop();
43285 var originalScore = candidate.score,
43286 boxIndex = candidate.boxIndex,
43287 suppressBeginIndex = candidate.suppressBeginIndex;
43288 if (originalScore < scoreThreshold) {
43289 break;
43290 }
43291 // Overlapping boxes are likely to have similar scores, therefore we
43292 // iterate through the previously selected boxes backwards in order to
43293 // see if candidate's score should be suppressed. We use
43294 // suppressBeginIndex to track and ensure a candidate can be suppressed
43295 // by a selected box no more than once. Also, if the overlap exceeds
43296 // iouThreshold, we simply ignore the candidate.
43297 var ignoreCandidate = false;
43298 for (var j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
43299 var iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
43300 if (iou >= iouThreshold) {
43301 ignoreCandidate = true;
43302 break;
43303 }
43304 candidate.score = candidate.score * suppressWeight(iouThreshold, scale, iou);
43305 if (candidate.score <= scoreThreshold) {
43306 break;
43307 }
43308 }
43309 // At this point, if `candidate.score` has not dropped below
43310 // `scoreThreshold`, then we know that we went through all of the
43311 // previous selections and can safely update `suppressBeginIndex` to the
43312 // end of the selected array. Then we can re-insert the candidate with
43313 // the updated score and suppressBeginIndex back in the candidate list.
43314 // If on the other hand, `candidate.score` has dropped below the score
43315 // threshold, we will not add it back to the candidates list.
43316 candidate.suppressBeginIndex = selectedIndices.length;
43317 if (!ignoreCandidate) {
43318 // Candidate has passed all the tests, and is not suppressed, so
43319 // select the candidate.
43320 if (candidate.score === originalScore) {
43321 selectedIndices.push(boxIndex);
43322 selectedScores.push(candidate.score);
43323 } else if (candidate.score > scoreThreshold) {
43324 // Candidate's score is suppressed but is still high enough to be
43325 // considered, so add back to the candidates list.
43326 binaryInsert(candidates, candidate, ascendingComparator);
43327 }
43328 }
43329 }
43330 // NonMaxSuppressionV4 feature: padding output to maxOutputSize.
43331 var validOutputs = selectedIndices.length;
43332 var elemsToPad = maxOutputSize - validOutputs;
43333 if (padToMaxOutputSize && elemsToPad > 0) {
43334 selectedIndices.push.apply(selectedIndices, _toConsumableArray(new Array(elemsToPad).fill(0)));
43335 selectedScores.push.apply(selectedScores, _toConsumableArray(new Array(elemsToPad).fill(0.0)));
43336 }
43337 var result = {
43338 selectedIndices: selectedIndices
43339 };
43340 if (returnScoresTensor) {
43341 result['selectedScores'] = selectedScores;
43342 }
43343 if (returnValidOutputs) {
43344 result['validOutputs'] = validOutputs;
43345 }
43346 return result;
43347 }
43348 function intersectionOverUnion(boxes, i, j) {
43349 var iCoord = boxes.subarray(i * 4, i * 4 + 4);
43350 var jCoord = boxes.subarray(j * 4, j * 4 + 4);
43351 var yminI = Math.min(iCoord[0], iCoord[2]);
43352 var xminI = Math.min(iCoord[1], iCoord[3]);
43353 var ymaxI = Math.max(iCoord[0], iCoord[2]);
43354 var xmaxI = Math.max(iCoord[1], iCoord[3]);
43355 var yminJ = Math.min(jCoord[0], jCoord[2]);
43356 var xminJ = Math.min(jCoord[1], jCoord[3]);
43357 var ymaxJ = Math.max(jCoord[0], jCoord[2]);
43358 var xmaxJ = Math.max(jCoord[1], jCoord[3]);
43359 var areaI = (ymaxI - yminI) * (xmaxI - xminI);
43360 var areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
43361 if (areaI <= 0 || areaJ <= 0) {
43362 return 0.0;
43363 }
43364 var intersectionYmin = Math.max(yminI, yminJ);
43365 var intersectionXmin = Math.max(xminI, xminJ);
43366 var intersectionYmax = Math.min(ymaxI, ymaxJ);
43367 var intersectionXmax = Math.min(xmaxI, xmaxJ);
43368 var intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * Math.max(intersectionXmax - intersectionXmin, 0.0);
43369 return intersectionArea / (areaI + areaJ - intersectionArea);
43370 }
43371 // A Gaussian penalty function, this method always returns values in [0, 1].
43372 // The weight is a function of similarity, the more overlap two boxes are, the
43373 // smaller the weight is,meaning highly overlapping boxes will be significantly
43374 // penalized. On the other hand, a non-overlapping box will not be penalized.
43375 function suppressWeight(iouThreshold, scale, iou) {
43376 var weight = Math.exp(scale * iou * iou);
43377 return iou <= iouThreshold ? weight : 0.0;
43378 }
43379 function ascendingComparator(c1, c2) {
43380 // For objects with same scores, we make the object with the larger index go
43381 // first. In an array that pops from the end, this means that the object with
43382 // the smaller index will be popped first. This ensures the same output as
43383 // the TensorFlow python version.
43384 return c1.score - c2.score || c1.score === c2.score && c2.boxIndex - c1.boxIndex;
43385 }
43386
43387 /**
43388 * Performs non maximum suppression of bounding boxes based on
43389 * iou (intersection over union).
43390 *
43391 * This is the async version of `nonMaxSuppression`
43392 *
43393 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
43394 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
43395 * the bounding box.
43396 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
43397 * @param maxOutputSize The maximum number of boxes to be selected.
43398 * @param iouThreshold A float representing the threshold for deciding whether
43399 * boxes overlap too much with respect to IOU. Must be between [0, 1].
43400 * Defaults to 0.5 (50% box overlap).
43401 * @param scoreThreshold A threshold for deciding when to remove boxes based
43402 * on score. Defaults to -inf, which means any score is accepted.
43403 * @return A 1D tensor with the selected box indices.
43404 *
43405 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
43406 */
/**
 * Entry point bound to `nonMaxSuppressionAsync`. Forwards every actual
 * argument (through `arguments`, so the optional iouThreshold and
 * scoreThreshold reach the implementation too) to `_nonMaxSuppressionAsync_`.
 * The named parameters `_x, _x2, _x3` are never read; they only give this
 * function a declared arity of 3.
 */
function nonMaxSuppressionAsync_(_x, _x2, _x3) {
  return _nonMaxSuppressionAsync_.apply(this, arguments);
}
/**
 * Lazily-initialised implementation. The first call overwrites the
 * `_nonMaxSuppressionAsync_` binding with the result of
 * `_asyncToGenerator(...)` and then invokes it; later calls reach that
 * replacement directly.
 *
 * The body is a regenerator state machine: `_context.next` selects the `case`
 * to run next, and `case 10` reads the awaited value from `_context.sent`.
 * `_asyncToGenerator` is defined elsewhere in this bundle. It is presumably
 * the standard Babel helper that drives the generator and returns a Promise
 * (verify).
 */
function _nonMaxSuppressionAsync_() {
  _nonMaxSuppressionAsync_ = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(boxes, scores, maxOutputSize) {
    var iouThreshold,
      scoreThreshold,
      $boxes,
      $scores,
      inputs,
      boxesAndScores,
      boxesVals,
      scoresVals,
      _nonMaxSuppressionV3I,
      selectedIndices,
      _args = arguments;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          // Optional 4th/5th arguments; a missing or `undefined` value selects
          // the default (0.5 and -Infinity respectively).
          iouThreshold = _args.length > 3 && _args[3] !== undefined ? _args[3] : 0.5;
          scoreThreshold = _args.length > 4 && _args[4] !== undefined ? _args[4] : Number.NEGATIVE_INFINITY;
          $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
          $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
          // Validate the inputs. The returned object carries the parameter
          // values that are actually used from here on.
          inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
          maxOutputSize = inputs.maxOutputSize;
          iouThreshold = inputs.iouThreshold;
          scoreThreshold = inputs.scoreThreshold;
          // Read both tensors' data; processing continues in `case 10`.
          _context.next = 10;
          return Promise.all([$boxes.data(), $scores.data()]);
        case 10:
          boxesAndScores = _context.sent;
          boxesVals = boxesAndScores[0];
          scoresVals = boxesAndScores[1]; // We call a cpu based impl directly with the typedarray data here rather
          // than a kernel because all kernels are synchronous (and thus cannot await
          // .data()).
          _nonMaxSuppressionV3I = nonMaxSuppressionV3Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold), selectedIndices = _nonMaxSuppressionV3I.selectedIndices;
          // Dispose a converted tensor only when it is not the caller's own
          // object (presumably convertToTensor created it from non-tensor input).
          // NOTE(review): these disposals sit only on the success path; if the
          // Promise.all above rejects, the converted tensors are not disposed.
          if ($boxes !== boxes) {
            $boxes.dispose();
          }
          if ($scores !== scores) {
            $scores.dispose();
          }
          return _context.abrupt("return", tensor1d(selectedIndices, 'int32'));
        case 17:
          // No `_context.next` assignment in this function targets 17; both
          // labels fall through to `_context.stop()`.
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _nonMaxSuppressionAsync_.apply(this, arguments);
}
var nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
43459
43460 /**
43461 * @license
43462 * Copyright 2020 Google LLC. All Rights Reserved.
43463 * Licensed under the Apache License, Version 2.0 (the "License");
43464 * you may not use this file except in compliance with the License.
43465 * You may obtain a copy of the License at
43466 *
43467 * http://www.apache.org/licenses/LICENSE-2.0
43468 *
43469 * Unless required by applicable law or agreed to in writing, software
43470 * distributed under the License is distributed on an "AS IS" BASIS,
43471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
43472 * See the License for the specific language governing permissions and
43473 * limitations under the License.
43474 * =============================================================================
43475 */
/**
 * Performs non maximum suppression (NMS) of bounding boxes based on IoU
 * (intersection over union) and returns both the kept indices and scores.
 *
 * A Soft-NMS mode is available (Bodla et al.,
 * https://arxiv.org/abs/1704.04503). In that mode boxes reduce the score of
 * other overlapping boxes, which favours different high-scoring regions of
 * the image. Enable it by passing a `softNmsSigma` larger than 0.
 *
 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
 *     `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
 *     the bounding box.
 * @param scores a 1d tensor of shape `[numBoxes]` holding the box scores.
 * @param maxOutputSize The maximum number of boxes to be selected.
 * @param iouThreshold A float threshold for deciding whether boxes overlap
 *     too much with respect to IoU. Must be between [0, 1]. Defaults to 0.5
 *     (50% box overlap).
 * @param scoreThreshold A threshold for deciding when to remove boxes based
 *     on score. Defaults to -inf, which means any score is accepted.
 * @param softNmsSigma A float sigma parameter for Soft-NMS. When sigma is 0
 *     (the default), it falls back to plain nonMaxSuppression.
 * @return An object with:
 *     - selectedIndices: a 1D tensor with the selected box indices.
 *     - selectedScores: a 1D tensor with the score of each selected box.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize) {
  var args = arguments;
  // Reads optional positional argument `index`. A missing or `undefined`
  // value yields `fallback`.
  function optionalArg(index, fallback) {
    return args.length > index && args[index] !== undefined ? args[index] : fallback;
  }
  var iouThreshold = optionalArg(3, 0.5);
  var scoreThreshold = optionalArg(4, Number.NEGATIVE_INFINITY);
  var softNmsSigma = optionalArg(5, 0.0);

  var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
  var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
  // The sanity check validates the inputs and returns the values to use.
  var checked = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);

  var outputs = ENGINE.runKernel(NonMaxSuppressionV5, {
    boxes: $boxes,
    scores: $scores
  }, {
    maxOutputSize: checked.maxOutputSize,
    iouThreshold: checked.iouThreshold,
    scoreThreshold: checked.scoreThreshold,
    softNmsSigma: checked.softNmsSigma
  });
  return {
    selectedIndices: outputs[0],
    selectedScores: outputs[1]
  };
}
var nonMaxSuppressionWithScore = /* @__PURE__ */op({
  nonMaxSuppressionWithScore_: nonMaxSuppressionWithScore_
});
43536
/**
 * Asynchronously performs non maximum suppression of bounding boxes based on
 * iou (intersection over union), returning both indices and scores.
 *
 * This op also supports a Soft-NMS mode (cf.
 * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
 * of other overlapping boxes, therefore favoring different regions of the image
 * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
 * parameter to be larger than 0.
 *
 * The box/score data is read asynchronously and then processed by the CPU
 * implementation `nonMaxSuppressionV5Impl$2` rather than through a kernel.
 *
 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
 *     `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
 *     the bounding box.
 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
 * @param maxOutputSize The maximum number of boxes to be selected.
 * @param iouThreshold A float representing the threshold for deciding whether
 *     boxes overlap too much with respect to IOU. Must be between [0, 1].
 *     Defaults to 0.5 (50% box overlap).
 * @param scoreThreshold A threshold for deciding when to remove boxes based
 *     on score. Defaults to -inf, which means any score is accepted.
 * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
 *     When sigma is 0 (the default), it falls back to nonMaxSuppression.
 * @return A map with the following properties:
 *     - selectedIndices: A 1D int32 tensor with the selected box indices.
 *     - selectedScores: A 1D tensor with the corresponding scores for each
 *       selected box.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function nonMaxSuppressionWithScoreAsync_(_x, _x2, _x3) {
  // Forwards all actual arguments; `_x.._x3` only fix the declared arity at 3.
  return _nonMaxSuppressionWithScoreAsync_.apply(this, arguments);
}
/**
 * Lazily-initialised implementation. The first call replaces this binding
 * with the function built by `_asyncToGenerator(...)` (a helper defined
 * elsewhere in the bundle; presumably Babel's, returning a Promise) and calls
 * it. The body is a regenerator state machine: `_context.next` picks the next
 * `case`, and `case 12` reads the awaited value from `_context.sent`.
 */
function _nonMaxSuppressionWithScoreAsync_() {
  _nonMaxSuppressionWithScoreAsync_ = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(boxes, scores, maxOutputSize) {
    var iouThreshold,
      scoreThreshold,
      softNmsSigma,
      $boxes,
      $scores,
      params,
      boxesAndScores,
      boxesVals,
      scoresVals,
      _nonMaxSuppressionV5I,
      selectedIndices,
      selectedScores,
      _args = arguments;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          // Optional 4th..6th arguments; missing/`undefined` picks the default.
          iouThreshold = _args.length > 3 && _args[3] !== undefined ? _args[3] : 0.5;
          scoreThreshold = _args.length > 4 && _args[4] !== undefined ? _args[4] : Number.NEGATIVE_INFINITY;
          softNmsSigma = _args.length > 5 && _args[5] !== undefined ? _args[5] : 0.0;
          $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
          $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
          // Validate; the returned object holds the parameter values used below.
          params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
          maxOutputSize = params.maxOutputSize;
          iouThreshold = params.iouThreshold;
          scoreThreshold = params.scoreThreshold;
          softNmsSigma = params.softNmsSigma;
          // Read both tensors' data; processing continues in `case 12`.
          _context.next = 12;
          return Promise.all([$boxes.data(), $scores.data()]);
        case 12:
          boxesAndScores = _context.sent;
          boxesVals = boxesAndScores[0];
          scoresVals = boxesAndScores[1]; // We call a cpu based impl directly with the typedarray data here rather
          // than a kernel because all kernels are synchronous (and thus cannot await
          // .data()).
          _nonMaxSuppressionV5I = nonMaxSuppressionV5Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma), selectedIndices = _nonMaxSuppressionV5I.selectedIndices, selectedScores = _nonMaxSuppressionV5I.selectedScores;
          // Dispose a converted tensor only when it differs from the caller's
          // object. NOTE(review): if the Promise.all above rejects, these
          // disposals never run.
          if ($boxes !== boxes) {
            $boxes.dispose();
          }
          if ($scores !== scores) {
            $scores.dispose();
          }
          return _context.abrupt("return", {
            selectedIndices: tensor1d(selectedIndices, 'int32'),
            selectedScores: tensor1d(selectedScores)
          });
        case 19:
          // Nothing in this function sets `_context.next = 19`; both labels
          // fall through to `_context.stop()`.
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _nonMaxSuppressionWithScoreAsync_.apply(this, arguments);
}
var nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;
43625
43626 /**
43627 * @license
43628 * Copyright 2020 Google LLC. All Rights Reserved.
43629 * Licensed under the Apache License, Version 2.0 (the "License");
43630 * you may not use this file except in compliance with the License.
43631 * You may obtain a copy of the License at
43632 *
43633 * http://www.apache.org/licenses/LICENSE-2.0
43634 *
43635 * Unless required by applicable law or agreed to in writing, software
43636 * distributed under the License is distributed on an "AS IS" BASIS,
43637 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
43638 * See the License for the specific language governing permissions and
43639 * limitations under the License.
43640 * =============================================================================
43641 */
/**
 * Performs (synchronous) non maximum suppression of bounding boxes based on
 * IoU (intersection over union), with an option to pad the results.
 *
 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
 *     `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
 *     the bounding box.
 * @param scores a 1d tensor of shape `[numBoxes]` holding the box scores.
 * @param maxOutputSize The maximum number of boxes to be selected.
 * @param iouThreshold A float threshold for deciding whether boxes overlap
 *     too much with respect to IoU. Must be between [0, 1]. Defaults to 0.5
 *     (50% box overlap).
 * @param scoreThreshold A threshold for deciding when to remove boxes based
 *     on score. Defaults to -inf, which means any score is accepted.
 * @param padToMaxOutputSize Defaults to false. If true, the output
 *     `selectedIndices` is padded to maxOutputSize.
 * @return An object with:
 *     - selectedIndices: a 1D tensor with the selected box indices.
 *     - validOutputs: a scalar denoting how many elements in
 *       `selectedIndices` are valid. Valid elements occur first, then padding.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function nonMaxSuppressionPadded_(boxes, scores, maxOutputSize) {
  var args = arguments;
  // Reads optional positional argument `index`. A missing or `undefined`
  // value yields `fallback`.
  function optionalArg(index, fallback) {
    return args.length > index && args[index] !== undefined ? args[index] : fallback;
  }
  var iouThreshold = optionalArg(3, 0.5);
  var scoreThreshold = optionalArg(4, Number.NEGATIVE_INFINITY);
  var padToMaxOutputSize = optionalArg(5, false);

  var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
  var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
  // This variant takes no softNmsSigma, so null fills that argument slot.
  var checked = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);

  var outputs = ENGINE.runKernel(NonMaxSuppressionV4, {
    boxes: $boxes,
    scores: $scores
  }, {
    maxOutputSize: checked.maxOutputSize,
    iouThreshold: checked.iouThreshold,
    scoreThreshold: checked.scoreThreshold,
    padToMaxOutputSize: padToMaxOutputSize
  });
  return {
    selectedIndices: outputs[0],
    validOutputs: outputs[1]
  };
}
var nonMaxSuppressionPadded = /* @__PURE__ */op({
  nonMaxSuppressionPadded_: nonMaxSuppressionPadded_
});
43695
/**
 * Asynchronously performs non maximum suppression of bounding boxes based on
 * iou (intersection over union), with an option to pad results.
 *
 * The box/score data is read asynchronously and then processed by the CPU
 * implementation `nonMaxSuppressionV4Impl$2` rather than through a kernel.
 *
 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
 *     `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
 *     the bounding box.
 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
 * @param maxOutputSize The maximum number of boxes to be selected.
 * @param iouThreshold A float representing the threshold for deciding whether
 *     boxes overlap too much with respect to IOU. Must be between [0, 1].
 *     Defaults to 0.5 (50% box overlap).
 * @param scoreThreshold A threshold for deciding when to remove boxes based
 *     on score. Defaults to -inf, which means any score is accepted.
 * @param padToMaxOutputSize Defaults to false. If true, size of output
 *     `selectedIndices` is padded to maxOutputSize.
 * @return A map with the following properties:
 *     - selectedIndices: A 1D int32 tensor with the selected box indices.
 *     - validOutputs: An int32 scalar denoting how many elements in
 *       `selectedIndices` are valid. Valid elements occur first, then padding.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function nonMaxSuppressionPaddedAsync_(_x, _x2, _x3) {
  // Forwards all actual arguments; `_x.._x3` only fix the declared arity at 3.
  return _nonMaxSuppressionPaddedAsync_.apply(this, arguments);
}
/**
 * Lazily-initialised implementation. The first call replaces this binding
 * with the function built by `_asyncToGenerator(...)` (a helper defined
 * elsewhere in the bundle; presumably Babel's, returning a Promise) and calls
 * it. The body is a regenerator state machine: `_context.next` picks the next
 * `case`, and `case 11` reads the awaited value from `_context.sent`.
 */
function _nonMaxSuppressionPaddedAsync_() {
  _nonMaxSuppressionPaddedAsync_ = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(boxes, scores, maxOutputSize) {
    var iouThreshold,
      scoreThreshold,
      padToMaxOutputSize,
      $boxes,
      $scores,
      params,
      $maxOutputSize,
      $iouThreshold,
      $scoreThreshold,
      _yield$Promise$all,
      _yield$Promise$all2,
      boxesVals,
      scoresVals,
      _nonMaxSuppressionV4I,
      selectedIndices,
      validOutputs,
      _args = arguments;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          // Optional 4th..6th arguments; missing/`undefined` picks the default.
          iouThreshold = _args.length > 3 && _args[3] !== undefined ? _args[3] : 0.5;
          scoreThreshold = _args.length > 4 && _args[4] !== undefined ? _args[4] : Number.NEGATIVE_INFINITY;
          padToMaxOutputSize = _args.length > 5 && _args[5] !== undefined ? _args[5] : false;
          $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
          $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
          // Validate (no softNmsSigma for this variant, hence null). The
          // returned object holds the parameter values used below.
          params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
          $maxOutputSize = params.maxOutputSize;
          $iouThreshold = params.iouThreshold;
          $scoreThreshold = params.scoreThreshold;
          // Read both tensors' data; processing continues in `case 11`.
          _context.next = 11;
          return Promise.all([$boxes.data(), $scores.data()]);
        case 11:
          // Destructure the [boxesData, scoresData] pair.
          _yield$Promise$all = _context.sent;
          _yield$Promise$all2 = _slicedToArray(_yield$Promise$all, 2);
          boxesVals = _yield$Promise$all2[0];
          scoresVals = _yield$Promise$all2[1];
          // We call a cpu based impl directly with the typedarray data here rather
          // than a kernel because all kernels are synchronous (and thus cannot await
          // .data()).
          _nonMaxSuppressionV4I = nonMaxSuppressionV4Impl$2(boxesVals, scoresVals, $maxOutputSize, $iouThreshold, $scoreThreshold, padToMaxOutputSize), selectedIndices = _nonMaxSuppressionV4I.selectedIndices, validOutputs = _nonMaxSuppressionV4I.validOutputs;
          // Dispose a converted tensor only when it differs from the caller's
          // object. NOTE(review): if the Promise.all above rejects, these
          // disposals never run.
          if ($boxes !== boxes) {
            $boxes.dispose();
          }
          if ($scores !== scores) {
            $scores.dispose();
          }
          return _context.abrupt("return", {
            selectedIndices: tensor1d(selectedIndices, 'int32'),
            validOutputs: scalar(validOutputs, 'int32')
          });
        case 19:
          // Nothing in this function sets `_context.next = 19`; both labels
          // fall through to `_context.stop()`.
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _nonMaxSuppressionPaddedAsync_.apply(this, arguments);
}
var nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
43783
/**
 * Bilinear resize a single 3D image or a batch of 3D images to a new shape.
 *
 * @param images The images, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
 * @param size The new shape `[newHeight, newWidth]` to resize the
 *     images to. Each channel is resized individually.
 * @param alignCorners Defaults to `false`. If true, rescale
 *     input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
 *     corners of images and resized images. If false, rescale by
 *     `new_height / height`. Treat similarly the width dimension.
 * @param halfPixelCenters Defaults to `false`. Whether to assume pixel centers
 *     are at 0.5, which would make the floating point coordinates of the top
 *     left pixel 0.5, 0.5. Must not be combined with `alignCorners`.
 * @return The resized images, with the same rank as `images`. A rank-3 input
 *     is reshaped to rank 4 for the kernel and reshaped back afterwards.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function resizeBilinear_(images, size) {
  var alignCorners = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
  var halfPixelCenters = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  var $images = convertToTensor(images, 'images', 'resizeBilinear');
  assert$1($images.rank === 3 || $images.rank === 4, function () {
    return "Error in resizeBilinear: x must be rank 3 or 4, but got " + "rank ".concat($images.rank, ".");
  });
  assert$1(size.length === 2, function () {
    return "Error in resizeBilinear: new shape must 2D, but got shape " + "".concat(size, ".");
  });
  assert$1(halfPixelCenters === false || alignCorners === false, function () {
    return "Error in resizeBilinear: If halfPixelCenters is true, " + "alignCorners must be false.";
  });
  // The kernel expects a batch; lift a single image to a batch of one.
  var batchImages = $images;
  var reshapedTo4D = false;
  if ($images.rank === 3) {
    reshapedTo4D = true;
    batchImages = reshape$3($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
  }
  var inputs = {
    images: batchImages
  };
  var attrs = {
    alignCorners: alignCorners,
    halfPixelCenters: halfPixelCenters,
    size: size
  };
  var res = ENGINE.runKernel(ResizeBilinear, inputs, attrs);
  if (reshapedTo4D) {
    // Drop the synthetic batch dimension added above.
    return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
  }
  return res;
}
var resizeBilinear$3 = /* @__PURE__ */op({
  resizeBilinear_: resizeBilinear_
});
43839
/**
 * Nearest-neighbor resize a single 3D image or a batch of 3D images to a new
 * shape.
 *
 * @param images The images, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
 *     The dtype must be `int32` or `float32`.
 * @param size The new shape `[newHeight, newWidth]` to resize the
 *     images to. Each channel is resized individually.
 * @param alignCorners Defaults to `false`. If true, rescale
 *     input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
 *     corners of images and resized images. If false, rescale by
 *     `new_height / height`. Treat similarly the width dimension.
 * @param halfPixelCenters Defaults to `false`. Whether to assume pixels are of
 *     half the actual dimensions, and yield more accurate resizes. This flag
 *     would also make the floating point coordinates of the top left pixel
 *     0.5, 0.5. Must not be combined with `alignCorners`.
 * @return The resized images, with the same rank as `images`. A rank-3 input
 *     is reshaped to rank 4 for the kernel and reshaped back afterwards.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function resizeNearestNeighbor_(images, size) {
  var alignCorners = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
  var halfPixelCenters = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  var $images = convertToTensor(images, 'images', 'resizeNearestNeighbor');
  assert$1($images.rank === 3 || $images.rank === 4, function () {
    return "Error in resizeNearestNeighbor: x must be rank 3 or 4, but got " + "rank ".concat($images.rank, ".");
  });
  assert$1(size.length === 2, function () {
    return "Error in resizeNearestNeighbor: new shape must 2D, but got shape " + "".concat(size, ".");
  });
  assert$1($images.dtype === 'float32' || $images.dtype === 'int32', function () {
    return '`images` must have `int32` or `float32` as dtype';
  });
  assert$1(halfPixelCenters === false || alignCorners === false, function () {
    return "Error in resizeNearestNeighbor: If halfPixelCenters is true, " + "alignCorners must be false.";
  });
  // The kernel expects a batch; lift a single image to a batch of one.
  var batchImages = $images;
  var reshapedTo4D = false;
  if ($images.rank === 3) {
    reshapedTo4D = true;
    batchImages = reshape$3($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
  }
  var inputs = {
    images: batchImages
  };
  var attrs = {
    alignCorners: alignCorners,
    halfPixelCenters: halfPixelCenters,
    size: size
  };
  var res = ENGINE.runKernel(ResizeNearestNeighbor, inputs, attrs);
  if (reshapedTo4D) {
    // Drop the synthetic batch dimension added above.
    return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
  }
  return res;
}
var resizeNearestNeighbor$2 = /* @__PURE__ */op({
  resizeNearestNeighbor_: resizeNearestNeighbor_
});
43899
/**
 * Binarizes an image: every output pixel is 255 or 0 depending on how its
 * (grayscale) intensity compares with a threshold.
 *
 * An RGB input is first converted to grayscale with the CCIR 601 luma
 * weights; a single-channel input is used as is. With the 'binary' method the
 * threshold is `threshValue * 255`. With 'otsu' it is computed from the
 * 256-bin histogram of the rounded grayscale image (Otsu's method), and
 * `threshValue` is not used.
 *
 * @param image 3d tensor of shape `[imageHeight, imageWidth, depth]`, where
 *     imageHeight and imageWidth must be positive and depth must be 1 or 3.
 *     The dtype must be int32 or float32, and the color range should be
 *     [0, 255].
 * @param method Optional string from `'binary' | 'otsu'` which specifies the
 *     method for thresholding. Defaults to 'binary'.
 * @param inverted Optional boolean which specifies if colours should be
 *     inverted. When true, pixels at or below the threshold become 255.
 *     Defaults to false.
 * @param threshValue Optional number from 0 to 1 which defines the threshold
 *     for the 'binary' method. Defaults to 0.5.
 * @return An int32 tensor of shape `[imageHeight, imageWidth, 1]` containing
 *     the binarized image (values 0 or 255).
 */
function threshold_(image) {
  var method = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 'binary';
  var inverted = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
  var threshValue = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 0.5;
  var $image = convertToTensor(image, 'image', 'threshold');
  // Validate everything before deriving any tensors from the input, so that
  // invalid calls fail with these messages rather than inside an
  // intermediate op.
  assert$1($image.rank === 3, function () {
    return 'Error in threshold: image must be rank 3, ' + "but got rank ".concat($image.rank, ".");
  });
  assert$1($image.shape[2] === 3 || $image.shape[2] === 1, function () {
    return 'Error in threshold: ' + 'image color channel must be equal to 3 or 1, ' + "but got ".concat($image.shape[2], ".");
  });
  assert$1($image.dtype === 'int32' || $image.dtype === 'float32', function () {
    return 'Error in dtype: image dtype must be int32 or float32, ' + "but got dtype ".concat($image.dtype, ".");
  });
  assert$1(method === 'otsu' || method === 'binary', function () {
    return "Method must be binary or otsu, but was ".concat(method);
  });
  /* 0.2989, 0.5870, 0.1140 are the luma coefficients defined by CCIR 601.
     Reference for converting between RGB and grayscale:
     https://en.wikipedia.org/wiki/Luma_%28video%29 */
  var RED_INTENSITY_COEF = 0.2989;
  var GREEN_INTENSITY_COEF = 0.5870;
  var BLUE_INTENSITY_COEF = 0.1140;
  var totalPixelsInImage = $image.shape[0] * $image.shape[1];
  var grayscale;
  if ($image.shape[2] === 3) {
    // Split along the channel axis into three [H, W, 1] tensors.
    var channels = split$3($image, [1, 1, 1], -1);
    var $r = mul(channels[0], RED_INTENSITY_COEF);
    var $g = mul(channels[1], GREEN_INTENSITY_COEF);
    var $b = mul(channels[2], BLUE_INTENSITY_COEF);
    grayscale = add$3(add$3($r, $g), $b);
  } else {
    // Use the converted tensor (not the raw argument) so TensorLike input is
    // converted only once.
    grayscale = $image;
  }
  var $threshold;
  if (method === 'otsu') {
    var $histogram = bincount$2(cast$3(round$2(grayscale), 'int32'), tensor([]), 256);
    $threshold = otsu($histogram, totalPixelsInImage);
  } else {
    $threshold = mul(tensor1d([threshValue]), 255);
  }
  var invCondition = inverted ? lessEqual$2(grayscale, $threshold) : greater$3(grayscale, $threshold);
  return cast$3(mul(invCondition, 255), 'int32');
}
/**
 * Otsu's method: finds the histogram split that maximizes the between-class
 * variance. Each candidate splits the bins into [0, index] and
 * [index + 1, end).
 *
 * For each split, both classes' weights (bin-count sum / `total`) and mean
 * bin indices are computed. The between-class variance is
 * weightForeground * weightBack * (meanFirst - meanSec)^2. Only a strictly
 * larger variance replaces the current best, so ties keep the earliest index.
 *
 * @param histogram Bin-count tensor. `threshold_` passes the 1d output of
 *     `bincount`.
 * @param total Number of pixels, used to normalize the class weights.
 * @return A 1-element tensor holding the best split index, or -1 when no split
 *     yields a positive between-class variance.
 */
function otsu(histogram, total) {
  var bestThresh = tensor1d([-1]);
  var bestInBetVar = tensor1d([0]);
  for (var index = 0; index < histogram.size - 1; index++) {
    var classFirst = slice$2(histogram, 0, index + 1);
    var classSecond = slice$2(histogram, index + 1);
    // Each class sum is computed once and reused for its weight and its mean.
    var countFirst = sum$3(classFirst);
    var countSecond = sum$3(classSecond);
    var weightForeground = div$1(countFirst, total);
    var weightBack = div$1(countSecond, total);
    // Mean bin index per class. The first class covers bins [0, index]; the
    // second covers the bins that follow it.
    var meanFirst = div$1(sum$3(mul(classFirst, range$3(0, classFirst.size))), countFirst);
    var secondBinIndices = range$3(classFirst.size, classFirst.size + classSecond.size);
    var meanSec = div$1(sum$3(mul(classSecond, secondBinIndices)), countSecond);
    var meanDiff = sub$2(meanFirst, meanSec);
    var cInBetVar = mul(mul(mul(weightForeground, weightBack), meanDiff), meanDiff);
    var condition = greater$3(cInBetVar, bestInBetVar);
    bestInBetVar = where(condition, cInBetVar, bestInBetVar);
    bestThresh = where(condition, tensor1d([index]), bestThresh);
  }
  return bestThresh;
}
var threshold$1 = /* @__PURE__ */op({
  threshold_: threshold_
});
43990
43991 /**
43992 * @license
43993 * Copyright 2021 Google LLC. All Rights Reserved.
43994 * Licensed under the Apache License, Version 2.0 (the "License");
43995 * you may not use this file except in compliance with the License.
43996 * You may obtain a copy of the License at
43997 *
43998 * http://www.apache.org/licenses/LICENSE-2.0
43999 *
44000 * Unless required by applicable law or agreed to in writing, software
44001 * distributed under the License is distributed on an "AS IS" BASIS,
44002 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44003 * See the License for the specific language governing permissions and
44004 * limitations under the License.
44005 * =============================================================================
44006 */
/**
 * Applies the given projective transform(s) to the image(s).
 *
 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
 * @param transforms Projective transform matrix/matrices. This is either a
 *     2d tensor of shape `[batch, 8]` or `[1, 8]`, or a tensor1d of length 8,
 *     which is treated as a single `[1, 8]` row. If one row of transforms is
 *     [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the output point (x, y)
 *     to a transformed input point
 *     (x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k),
 *     where k = c0 x + c1 y + 1. The transforms are inverted compared to the
 *     transform mapping input points to output points.
 * @param interpolation Interpolation mode.
 *     Supported values: 'nearest', 'bilinear'. Defaults to 'nearest'.
 * @param fillMode Points outside the boundaries of the input are filled
 *     according to the given mode, one of 'constant', 'reflect', 'wrap',
 *     'nearest'. Defaults to 'constant'.
 *     'reflect': (d c b a | a b c d | d c b a) The input is extended by
 *     reflecting about the edge of the last pixel.
 *     'constant': (k k k k | a b c d | k k k k) The input is extended by
 *     filling all values beyond the edge with the same constant value k.
 *     'wrap': (a b c d | a b c d | a b c d) The input is extended by
 *     wrapping around to the opposite edge.
 *     'nearest': (a a a a | a b c d | d d d d) The input is extended by
 *     the nearest pixel.
 * @param fillValue A float representing the value to be filled outside the
 *     boundaries when fillMode is 'constant'. Defaults to 0.
 * @param outputShape Output dimension after the transform, [height, width].
 *     If undefined, output is the same size as input image.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function transform_(image, transforms) {
  var interpolation = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'nearest';
  var fillMode = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'constant';
  var fillValue = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 0;
  var outputShape = arguments.length > 5 ? arguments[5] : undefined;
  var $image = convertToTensor(image, 'image', 'transform', 'float32');
  var $transforms = convertToTensor(transforms, 'transforms', 'transform', 'float32');
  // Accept the documented single length-8 transform by treating it as one
  // `[1, 8]` row. Every other shape still goes through the checks below.
  if ($transforms.rank === 1 && $transforms.shape[0] === 8) {
    $transforms = reshape$3($transforms, [1, 8]);
  }
  assert$1($image.rank === 4, function () {
    return 'Error in transform: image must be rank 4, ' + "but got rank ".concat($image.rank, ".");
  });
  assert$1($transforms.rank === 2 && ($transforms.shape[0] === $image.shape[0] || $transforms.shape[0] === 1) && $transforms.shape[1] === 8, function () {
    return "Error in transform: Input transform should be batch x 8 or 1 x 8";
  });
  assert$1(outputShape == null || outputShape.length === 2, function () {
    return 'Error in transform: outputShape must be [height, width] or null, ' + "but got ".concat(outputShape, ".");
  });
  var inputs = {
    image: $image,
    transforms: $transforms
  };
  var attrs = {
    interpolation: interpolation,
    fillMode: fillMode,
    fillValue: fillValue,
    outputShape: outputShape
  };
  return ENGINE.runKernel(Transform, inputs, attrs);
}
var transform$2 = /* @__PURE__ */op({
  transform_: transform_
});
44068
44069 /**
44070 * Copy a tensor setting everything outside a central band in each innermost
44071 * matrix to zero.
44072 *
44073 * The band part is computed as follows: Assume input has `k` dimensions
44074 * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
44075 * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
44076 * The indicator function
44077 * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`
44078 * `&& (num_upper < 0 || (n-m) <= num_upper)`
44079 *
44080 * ```js
44081 * const x = tf.tensor2d([[ 0, 1, 2, 3],
44082 * [-1, 0, 1, 2],
44083 * [-2, -1, 0, 1],
44084 * [-3, -2, -1, 0]]);
44085 * let y = tf.linalg.bandPart(x, 1, -1);
44086 * y.print(); // [[ 0, 1, 2, 3],
44087 * // [-1, 0, 1, 2],
44088 * // [ 0, -1, 0, 1],
44089 * // [ 0, 0 , -1, 0]]
44090 * let z = tf.linalg.bandPart(x, 2, 1);
44091 * z.print(); // [[ 0, 1, 0, 0],
44092 * // [-1, 0, 1, 0],
44093 * // [-2, -1, 0, 1],
44094 * // [ 0, -2, -1, 0]]
44095 * ```
44096 *
44097 * @param x Rank `k` tensor
44098 * @param numLower Number of subdiagonals to keep.
44099 * If negative, keep entire lower triangle.
 * @param numUpper Number of superdiagonals to keep.
44101 * If negative, keep entire upper triangle.
44102 * @returns Rank `k` tensor of the same shape as input.
44103 * The extracted banded tensor.
44104 *
44105 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
44106 */
44107 function bandPart_(a, numLower, numUpper) {
44108 var $a = convertToTensor(a, 'a', 'bandPart');
44109 assert$1($a.rank >= 2, function () {
44110 return "bandPart(): Rank must be at least 2, got ".concat($a.rank, ".");
44111 });
44112 var shape = $a.shape;
44113 var _$a$shape$slice = $a.shape.slice(-2),
44114 _$a$shape$slice2 = _slicedToArray(_$a$shape$slice, 2),
44115 M = _$a$shape$slice2[0],
44116 N = _$a$shape$slice2[1];
44117 var $numLower;
44118 var $numUpper;
44119 if (typeof numLower === 'number') {
44120 assert$1(numLower % 1 === 0, function () {
44121 return "bandPart(): numLower must be an integer, got ".concat(numLower, ".");
44122 });
44123 assert$1(numLower <= M, function () {
44124 return "bandPart(): numLower (".concat(numLower, ")") + " must not be greater than the number of rows (".concat(M, ").");
44125 });
44126 $numLower = convertToTensor(numLower < 0 ? M : numLower, 'numLower', 'bandPart');
44127 } else {
44128 assert$1(numLower.dtype === 'int32', function () {
44129 return "bandPart(): numLower's dtype must be an int32.";
44130 });
44131 // If numLower is a Scalar, checking `numLower <= M` could hurt performance,
44132 // but minimum(numLower, M) could avoid unexpected results.
44133 $numLower = where(less$3(numLower, 0), M, minimum$4(numLower, M));
44134 }
44135 if (typeof numUpper === 'number') {
44136 assert$1(numUpper % 1 === 0, function () {
44137 return "bandPart(): numUpper must be an integer, got ".concat(numUpper, ".");
44138 });
44139 assert$1(numUpper <= N, function () {
44140 return "bandPart(): numUpper (".concat(numUpper, ")") + " must not be greater than the number of columns (".concat(N, ").");
44141 });
44142 $numUpper = convertToTensor(numUpper < 0 ? N : numUpper, 'numUpper', 'bandPart');
44143 } else {
44144 assert$1(numUpper.dtype === 'int32', function () {
44145 return "bandPart(): numUpper's dtype must be an int32.";
44146 });
44147 $numUpper = where(less$3(numUpper, 0), N, minimum$4(numUpper, N));
44148 }
44149 var i = reshape$3(range$3(0, M, 1, 'int32'), [-1, 1]);
44150 var j = range$3(0, N, 1, 'int32');
44151 var ij = sub$2(i, j);
44152 var inBand = logicalAnd$2(lessEqual$2(ij, $numLower), greaterEqual$2(ij, neg$2($numUpper)));
44153 var zero = zeros$2([M, N], $a.dtype);
44154 return reshape$3(stack(unstack(reshape$3($a, [-1, M, N])).map(function (mat) {
44155 return where(inBand, mat, zero);
44156 })), shape);
44157 }
44158 var bandPart = /* @__PURE__ */op({
44159 bandPart_: bandPart_
44160 });
44161
44162 /**
44163 * @license
44164 * Copyright 2020 Google LLC. All Rights Reserved.
44165 * Licensed under the Apache License, Version 2.0 (the "License");
44166 * you may not use this file except in compliance with the License.
44167 * You may obtain a copy of the License at
44168 *
44169 * http://www.apache.org/licenses/LICENSE-2.0
44170 *
44171 * Unless required by applicable law or agreed to in writing, software
44172 * distributed under the License is distributed on an "AS IS" BASIS,
44173 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44174 * See the License for the specific language governing permissions and
44175 * limitations under the License.
44176 * =============================================================================
44177 */
/**
 * Gram-Schmidt orthogonalization.
 *
 * ```js
 * const x = tf.tensor2d([[1, 2], [3, 4]]);
 * let y = tf.linalg.gramSchmidt(x);
 * y.print();
 * console.log('Orthogonalized:');
 * y.dot(y.transpose()).print();  // should be nearly the identity matrix.
 * console.log('First row direction maintained:');
 * const data = await y.array();
 * console.log(data[0][1] / data[0][0]);  // should be nearly 2.
 * ```
 *
 * @param xs The vectors to be orthogonalized, in one of the two following
 *   formats:
 *   - An Array of `tf.Tensor1D`.
 *   - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
 *     of `xs`.
 *   In each case, all the vectors must have the same length and the length
 *   must be greater than or equal to the number of vectors.
 * @returns The orthogonalized and normalized vectors or matrix, in the same
 *   format as the input (an Array for Array input, a stacked matrix for
 *   matrix input).
 *   Orthogonalization means that the vectors or the rows of the matrix
 *   are orthogonal (zero inner products). Normalization means that each
 *   vector or each row of the matrix has an L2 norm that equals `1`.
 * @throws If `xs` is null/undefined or an empty Array, if the vectors have
 *   differing lengths, or if there are more vectors than dimensions.
 *
 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
 */
function gramSchmidt_(xs) {
  var emptyInputMessage = function () {
    return 'Gram-Schmidt process: input must not be null, undefined, or ' + 'empty';
  };
  // Checked before branching: Array.isArray(null) is false, so a null or
  // undefined input would otherwise reach the matrix branch and fail with an
  // opaque TypeError on `xs.shape`.
  assert$1(xs != null, emptyInputMessage);
  var inputIsTensor2D = !Array.isArray(xs);
  var vectors;
  if (inputIsTensor2D) {
    // Each row of the matrix is one input vector.
    vectors = split$3(xs, xs.shape[0], 0).map(function (row) {
      return squeeze(row, [0]);
    });
  } else {
    assert$1(xs.length > 0, emptyInputMessage);
    var dim = xs[0].shape[0];
    xs.slice(1).forEach(function (v) {
      assert$1(v.shape[0] === dim, function () {
        return 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' + "(".concat(v.shape[0], " vs. ").concat(dim, ")");
      });
    });
    vectors = xs;
  }
  assert$1(vectors.length <= vectors[0].shape[0], function () {
    return "Gram-Schmidt: Number of vectors (".concat(vectors.length, ") exceeds ") + "number of dimensions (".concat(vectors[0].shape[0], ").");
  });
  var ys = [];
  vectors.forEach(function (x) {
    ys.push(ENGINE.tidy(function () {
      // Modified Gram-Schmidt: successively remove the projection of the
      // running residual onto every unit vector produced so far.
      var residual = ys.reduce(function (acc, y) {
        return sub$2(acc, mul(sum$3(mul(y, acc)), y));
      }, x);
      return div$1(residual, norm(residual, 'euclidean'));
    }));
  });
  return inputIsTensor2D ? stack(ys, 0) : ys;
}
var gramSchmidt = /* @__PURE__ */op({
  gramSchmidt_: gramSchmidt_
});
44257
/**
 * Compute QR decomposition of m-by-n matrix using Householder transformation.
 *
 * Implementation based on the lecture notes at
 * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf](http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
 *
 * ```js
 * const a = tf.tensor2d([[1, 2], [3, 4]]);
 * let [q, r] = tf.linalg.qr(a);
 * console.log('Q');
 * q.print();
 * console.log('R');
 * r.print();
 * console.log('Orthogonalized');
 * q.dot(q.transpose()).print()  // should be nearly the identity matrix.
 * console.log('Reconstructed');
 * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
 * ```
 *
 * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
 *   it has the shape `[..., M, N]`.
 * @param fullMatrices An optional boolean parameter. Defaults to `false`.
 *   If `true`, compute full-sized `Q`. If `false` (the default) and `M > N`,
 *   compute the reduced factors: only the leading `N` columns of `Q` and the
 *   leading `N` rows of `R`.
 * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
 *   i.e., its columns all have unit norm and are mutually orthogonal.
 *   If `M >= N`,
 *     If `fullMatrices` is `false` (default),
 *       - `Q` has a shape of `[..., M, N]`,
 *       - `R` has a shape of `[..., N, N]`.
 *     If `fullMatrices` is `true`,
 *       - `Q` has a shape of `[..., M, M]`,
 *       - `R` has a shape of `[..., M, N]`.
 *   If `M < N`,
 *     - `Q` has a shape of `[..., M, M]`,
 *     - `R` has a shape of `[..., M, N]`.
 *   The leading `...` dimensions of the input are preserved on both outputs.
 * @throws If the rank of `x` is less than 2.
 *
 * @doc {heading:'Operations',
 *       subheading:'Linear Algebra',
 *       namespace:'linalg'}
 */
function qr_(x) {
  var fullMatrices = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
  assert$1(x.rank >= 2, function () {
    return "qr() requires input tensor to have a rank >= 2, but got rank ".concat(x.rank);
  });
  if (x.rank === 2) {
    return qr2d(x, fullMatrices);
  }
  // Rank > 2.
  // TODO(cais): Below we split the input into individual 2D tensors,
  //   perform QR decomposition on them and then stack the results back
  //   together. We should explore whether this can be parallelized.
  var rank = x.shape.length;
  var outerShape = x.shape.slice(0, rank - 2);
  var outerDimsProd = outerShape.reduce(function (acc, d) {
    return acc * d;
  }, 1);
  var x2ds = unstack(reshape$3(x, [outerDimsProd, x.shape[rank - 2], x.shape[rank - 1]]), 0);
  var q2ds = [];
  var r2ds = [];
  x2ds.forEach(function (x2d) {
    var factors = qr2d(x2d, fullMatrices);
    q2ds.push(factors[0]);
    r2ds.push(factors[1]);
  });
  // Q and R generally do not share the input's [M, N] inner shape (e.g. R is
  // [N, N] for the reduced decomposition, Q is [M, M] when M < N), so reshape
  // each to the original outer dims followed by its own inner shape rather
  // than to `x.shape`.
  var qStacked = stack(q2ds, 0);
  var rStacked = stack(r2ds, 0);
  var q = reshape$3(qStacked, outerShape.concat(qStacked.shape.slice(1)));
  var r = reshape$3(rStacked, outerShape.concat(rStacked.shape.slice(1)));
  return [q, r];
}
/**
 * QR-decomposes a single 2D matrix with Householder reflections.
 *
 * @param x Rank-2 tensor of shape `[m, n]` (asserted inside).
 * @param fullMatrices Optional second argument, defaults to `false`.
 * @returns `[q, r]`. When `fullMatrices` is falsy and `m > n`, `q` is trimmed
 *     to `[m, n]` and `r` to `[n, n]`; otherwise `q` is `[m, m]` and `r` is
 *     `[m, n]`.
 */
function qr2d(x) {
  var fullMatrices = arguments[1] === undefined ? false : arguments[1];
  return ENGINE.tidy(function () {
    assert$1(x.shape.length === 2, function () {
      return "qr2d() requires a 2D Tensor, but got a ".concat(x.shape.length, "D Tensor.");
    });
    var m = x.shape[0];
    var n = x.shape[1];
    var q = eye(m); // Orthogonal factor accumulated so far.
    var r = clone(x); // Matrix being reduced towards upper-triangular form.
    var one2D = tensor2d([[1]], [1, 1]);
    var w = clone(one2D);

    // Builds H = I - tau * w * w' that zeroes column `col` of R below the
    // diagonal and returns the updated [w, R, Q] with R := H R and Q := Q H.
    // Reads the current `r` and `q`; its own tidy frees the per-step
    // temporaries as soon as the step finishes.
    function reflect(col) {
      return ENGINE.tidy(function () {
        var column = slice$2(r, [col, col], [m - col, 1]);
        var columnNorm = norm(column);
        var pivot = slice$2(r, [col, col], [1, 1]);
        // sign() returns 0 on 0, which would later divide by zero, so pick
        // the sign explicitly: -1 for a positive pivot, 1 otherwise.
        var s = where(greater$3(pivot, 0), tensor2d([[-1]]), tensor2d([[1]]));
        var u1 = sub$2(pivot, mul(s, columnNorm));
        var wPre = div$1(column, u1);
        var wNew = wPre.shape[0] === 1 ? clone(one2D) : concat$2([one2D, slice$2(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])], 0);
        var tau = neg$2(div$1(matMul$1(s, u1), columnNorm));

        // R := H R (rows above `col` are untouched).
        var rBelow = slice$2(r, [col, 0], [m - col, n]);
        var tauW = mul(tau, wNew);
        var wNewT = transpose$2(wNew);
        var rBelowNew = sub$2(rBelow, matMul$1(tauW, matMul$1(wNewT, rBelow)));
        var rNew = col === 0 ? rBelowNew : concat$2([slice$2(r, [0, 0], [col, n]), rBelowNew], 0);

        // Q := Q H (columns left of `col` are untouched).
        var tauWT = transpose$2(tauW);
        var qRight = slice$2(q, [0, col], [m, q.shape[1] - col]);
        var qRightNew = sub$2(qRight, matMul$1(matMul$1(qRight, wNew), tauWT));
        var qNew = col === 0 ? qRightNew : concat$2([slice$2(q, [0, 0], [m, col]), qRightNew], 1);
        return [wNew, rNew, qNew];
      });
    }

    var steps = Math.min(m, n);
    for (var col = 0; col < steps; ++col) {
      var stale = [r, w, q];
      var updated = reflect(col);
      w = updated[0];
      r = updated[1];
      q = updated[2];
      dispose(stale);
    }
    if (fullMatrices || m <= n) {
      return [q, r];
    }
    return [slice$2(q, [0, 0], [m, n]), slice$2(r, [0, 0], [n, n])];
  });
}
var qr = /* @__PURE__ */op({
  qr_: qr_
});
44405
44406 /**
44407 * @license
44408 * Copyright 2020 Google LLC. All Rights Reserved.
44409 * Licensed under the Apache License, Version 2.0 (the "License");
44410 * you may not use this file except in compliance with the License.
44411 * You may obtain a copy of the License at
44412 *
44413 * http://www.apache.org/licenses/LICENSE-2.0
44414 *
44415 * Unless required by applicable law or agreed to in writing, software
44416 * distributed under the License is distributed on an "AS IS" BASIS,
44417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44418 * See the License for the specific language governing permissions and
44419 * limitations under the License.
44420 * =============================================================================
44421 */
/**
 * Reduction strategies accepted by the `losses` functions. This is the
 * compiled form of a numeric enum: every member is reachable both by name
 * (`Reduction.MEAN === 1`) and by value (`Reduction[1] === 'MEAN'`).
 */
exports.Reduction = void 0;
(function (Reduction) {
  ['NONE', 'MEAN', 'SUM', 'SUM_BY_NONZERO_WEIGHTS'].forEach(function (name, value) {
    Reduction[Reduction[name] = value] = name;
  });
})(exports.Reduction || (exports.Reduction = {}));
44429
/**
 * Computes the weighted loss between two tensors.
 *
 * @param losses Tensor of shape `[batch_size, d1, ..., dN]`.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `losses`, and must be broadcastable to `losses` (i.e., all
 *    dimensions must be either `1`, or the same as the corresponding
 *    `losses` dimension). May be null/undefined for unweighted losses.
 * @param reduction Optional third argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *    - `NONE`: the (weighted) per-element losses, unreduced.
 *    - `SUM`: the sum of the weighted losses.
 *    - `MEAN`: without weights, the plain mean; with weights, the weighted
 *      sum divided by the sum of the weights, and further divided by
 *      `losses.size / weights.size` when that ratio exceeds 1.
 *    - `SUM_BY_NONZERO_WEIGHTS`: the weighted sum divided by the element
 *      count (no weights) or by the count of non-zero broadcast weights.
 * @throws Error for any other `reduction` value.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function computeWeightedLoss_(losses, weights) {
  var reduction = arguments[2] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[2];
  var lossesT = convertToTensor(losses, 'losses', 'computeWeightedLoss');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'computeWeightedLoss');
  var weighted = weightsT == null ? lossesT : mul(lossesT, weightsT);
  switch (reduction) {
    case exports.Reduction.NONE:
      return weighted;
    case exports.Reduction.SUM:
      return sum$3(weighted);
    case exports.Reduction.MEAN:
      if (weightsT == null) {
        return mean$3(weighted);
      }
      // sum(weights) only counts each weight once; when the weights were
      // broadcast over more elements, scale by the broadcast ratio.
      var broadcastRatio = lossesT.size / weightsT.size;
      var weightedMean = div$1(sum$3(weighted), sum$3(weightsT));
      return broadcastRatio > 1 ? div$1(weightedMean, scalar(broadcastRatio)) : weightedMean;
    case exports.Reduction.SUM_BY_NONZERO_WEIGHTS:
      if (weightsT == null) {
        return div$1(sum$3(weighted), scalar(lossesT.size));
      }
      // Broadcast the weights to the losses' shape before counting non-zeros.
      var fullWeights = mul(weightsT, ones$1(lossesT.shape));
      var nonZeroCount = cast$3(sum$3(notEqual$2(fullWeights, scalar(0))), 'float32');
      return div$1(sum$3(weighted), nonZeroCount);
  }
  throw Error("Unknown reduction: ".concat(reduction));
}
var computeWeightedLoss$1 = /* @__PURE__ */op({
  computeWeightedLoss_: computeWeightedLoss_
});
44478
44479 /**
44480 * @license
44481 * Copyright 2020 Google LLC. All Rights Reserved.
44482 * Licensed under the Apache License, Version 2.0 (the "License");
44483 * you may not use this file except in compliance with the License.
44484 * You may obtain a copy of the License at
44485 *
44486 * http://www.apache.org/licenses/LICENSE-2.0
44487 *
44488 * Unless required by applicable law or agreed to in writing, software
44489 * distributed under the License is distributed on an "AS IS" BASIS,
44490 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44491 * See the License for the specific language governing permissions and
44492 * limitations under the License.
44493 * =============================================================================
44494 */
/**
 * Computes the absolute difference loss, `|labels - predictions|` per
 * element, and reduces it with `computeWeightedLoss`.
 *
 * @param labels The ground truth output tensor, same dimensions as
 *    'predictions'.
 * @param predictions The predicted outputs.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions
 *    must be either `1`, or the same as the corresponding `losses`
 *    dimension). Optional.
 * @param reduction Optional fourth argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function absoluteDifference_(labels, predictions, weights) {
  var reduction = arguments[3] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[3];
  var labelsT = convertToTensor(labels, 'labels', 'absoluteDifference');
  var predictionsT = convertToTensor(predictions, 'predictions', 'absoluteDifference');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'absoluteDifference');
  assertShapesMatch(labelsT.shape, predictionsT.shape, 'Error in absoluteDifference: ');
  var perElement = abs$2(sub$2(labelsT, predictionsT));
  return computeWeightedLoss$1(perElement, weightsT, reduction);
}
var absoluteDifference = /* @__PURE__ */op({
  absoluteDifference_: absoluteDifference_
});
44525
/**
 * Computes the cosine distance loss between two tensors as
 * `1 - sum(labels * predictions, axis)`, keeping `axis` as a size-1
 * dimension, then reduces it with `computeWeightedLoss`. No normalization is
 * done here, so this equals the cosine distance only when both inputs are
 * already unit-length along `axis`.
 *
 * @param labels The ground truth output tensor, same dimensions as
 *    'predictions'.
 * @param predictions The predicted outputs.
 * @param axis The dimension along which the cosine distance is computed.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions
 *    must be either `1`, or the same as the corresponding `losses`
 *    dimension). Optional.
 * @param reduction Optional fifth argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function cosineDistance_(labels, predictions, axis, weights) {
  var reduction = arguments[4] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[4];
  var labelsT = convertToTensor(labels, 'labels', 'cosineDistance');
  var predictionsT = convertToTensor(predictions, 'predictions', 'cosineDistance');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'cosineDistance');
  assertShapesMatch(labelsT.shape, predictionsT.shape, 'Error in cosineDistance: ');
  var one = scalar(1);
  var dotAlongAxis = sum$3(mul(labelsT, predictionsT), axis, true);
  return computeWeightedLoss$1(sub$2(one, dotAlongAxis), weightsT, reduction);
}
var cosineDistance = /* @__PURE__ */op({
  cosineDistance_: cosineDistance_
});
44558
/**
 * Computes the Hinge loss between two tensors: labels are mapped from
 * `{0, 1}` to `{-1, 1}` and the per-element loss is
 * `relu(1 - label * prediction)`, reduced with `computeWeightedLoss`.
 *
 * @param labels The ground truth output tensor, same dimensions as
 *    'predictions'.
 * @param predictions The predicted outputs.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions
 *    must be either `1`, or the same as the corresponding `losses`
 *    dimension). Optional.
 * @param reduction Optional fourth argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function hingeLoss_(labels, predictions, weights) {
  var reduction = arguments[3] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[3];
  var labelsT = convertToTensor(labels, 'labels', 'hingeLoss');
  var predictionsT = convertToTensor(predictions, 'predictions', 'hingeLoss');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'hingeLoss');
  assertShapesMatch(labelsT.shape, predictionsT.shape, 'Error in hingeLoss: ');
  var one = scalar(1);
  // Binary labels {0, 1} -> {-1, 1} via 2 * y - 1.
  var signedLabels = sub$2(mul(scalar(2), labelsT), one);
  var margins = sub$2(one, mul(signedLabels, predictionsT));
  return computeWeightedLoss$1(relu$2(margins), weightsT, reduction);
}
var hingeLoss = /* @__PURE__ */op({
  hingeLoss_: hingeLoss_
});
44592
44593 /**
44594 * @license
44595 * Copyright 2020 Google LLC. All Rights Reserved.
44596 * Licensed under the Apache License, Version 2.0 (the "License");
44597 * you may not use this file except in compliance with the License.
44598 * You may obtain a copy of the License at
44599 *
44600 * http://www.apache.org/licenses/LICENSE-2.0
44601 *
44602 * Unless required by applicable law or agreed to in writing, software
44603 * distributed under the License is distributed on an "AS IS" BASIS,
44604 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44605 * See the License for the specific language governing permissions and
44606 * limitations under the License.
44607 * =============================================================================
44608 */
/**
 * Computes the Huber loss between two tensors. With `e = |predictions -
 * labels|`, the per-element loss is `0.5 * min(e, delta)^2 +
 * delta * (e - min(e, delta))`, reduced with `computeWeightedLoss`.
 *
 * @param labels The ground truth output tensor, same dimensions as
 *    'predictions'.
 * @param predictions The predicted outputs.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions
 *    must be either `1`, or the same as the corresponding `losses`
 *    dimension). Optional.
 * @param delta Point where Huber loss changes from quadratic to linear.
 *    Optional fourth argument; defaults to `1.0`.
 * @param reduction Optional fifth argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function huberLoss_(labels, predictions, weights) {
  var delta = arguments[3] === undefined ? 1.0 : arguments[3];
  var reduction = arguments[4] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[4];
  var labelsT = convertToTensor(labels, 'labels', 'huberLoss');
  var predictionsT = convertToTensor(predictions, 'predictions', 'huberLoss');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'huberLoss');
  assertShapesMatch(labelsT.shape, predictionsT.shape, 'Error in huberLoss: ');
  var deltaT = scalar(delta);
  var absError = abs$2(sub$2(predictionsT, labelsT));
  // Split |error| into the part up to delta (penalized quadratically) and
  // the excess beyond delta (penalized linearly).
  var withinDelta = minimum$4(absError, deltaT);
  var beyondDelta = sub$2(absError, withinDelta);
  var perElement = add$3(mul(scalar(0.5), square$2(withinDelta)), mul(deltaT, beyondDelta));
  return computeWeightedLoss$1(perElement, weightsT, reduction);
}
var huberLoss = /* @__PURE__ */op({
  huberLoss_: huberLoss_
});
44645
44646 /**
44647 * @license
44648 * Copyright 2020 Google LLC. All Rights Reserved.
44649 * Licensed under the Apache License, Version 2.0 (the "License");
44650 * you may not use this file except in compliance with the License.
44651 * You may obtain a copy of the License at
44652 *
44653 * http://www.apache.org/licenses/LICENSE-2.0
44654 *
44655 * Unless required by applicable law or agreed to in writing, software
44656 * distributed under the License is distributed on an "AS IS" BASIS,
44657 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44658 * See the License for the specific language governing permissions and
44659 * limitations under the License.
44660 * =============================================================================
44661 */
/**
 * Computes the log loss between two tensors:
 * `-y * log(p + eps) - (1 - y) * log(1 - p + eps)` per element, reduced with
 * `computeWeightedLoss`.
 *
 * @param labels The ground truth output tensor, same dimensions as
 *    'predictions'.
 * @param predictions The predicted outputs.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions
 *    must be either `1`, or the same as the corresponding `losses`
 *    dimension). Optional.
 * @param epsilon A small increment to avoid taking log of zero. Optional
 *    fourth argument; defaults to `1e-7`.
 * @param reduction Optional fifth argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function logLoss_(labels, predictions, weights) {
  var epsilon = arguments[3] === undefined ? 1e-7 : arguments[3];
  var reduction = arguments[4] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[4];
  var labelsT = convertToTensor(labels, 'labels', 'logLoss');
  var predictionsT = convertToTensor(predictions, 'predictions', 'logLoss');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'logLoss');
  assertShapesMatch(labelsT.shape, predictionsT.shape, 'Error in logLoss: ');
  var one = scalar(1);
  var eps = scalar(epsilon);
  // -y * log(p + eps)
  var positiveTerm = neg$2(mul(labelsT, log$2(add$3(predictionsT, eps))));
  // (1 - y) * log(1 - p + eps)
  var negativeTerm = mul(sub$2(one, labelsT), log$2(add$3(sub$2(one, predictionsT), eps)));
  return computeWeightedLoss$1(sub$2(positiveTerm, negativeTerm), weightsT, reduction);
}
var logLoss = /* @__PURE__ */op({
  logLoss_: logLoss_
});
44698
44699 /**
44700 * @license
44701 * Copyright 2020 Google LLC. All Rights Reserved.
44702 * Licensed under the Apache License, Version 2.0 (the "License");
44703 * you may not use this file except in compliance with the License.
44704 * You may obtain a copy of the License at
44705 *
44706 * http://www.apache.org/licenses/LICENSE-2.0
44707 *
44708 * Unless required by applicable law or agreed to in writing, software
44709 * distributed under the License is distributed on an "AS IS" BASIS,
44710 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44711 * See the License for the specific language governing permissions and
44712 * limitations under the License.
44713 * =============================================================================
44714 */
/**
 * Computes the mean squared error between two tensors: the element-wise
 * squared difference, reduced with `computeWeightedLoss`.
 *
 * @param labels The ground truth output tensor, same dimensions as
 *    'predictions'.
 * @param predictions The predicted outputs.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions
 *    must be either `1`, or the same as the corresponding `losses`
 *    dimension). Optional.
 * @param reduction Optional fourth argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function meanSquaredError_(labels, predictions, weights) {
  var reduction = arguments[3] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[3];
  var labelsT = convertToTensor(labels, 'labels', 'meanSquaredError');
  var predictionsT = convertToTensor(predictions, 'predictions', 'meanSquaredError');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'meanSquaredError');
  assertShapesMatch(labelsT.shape, predictionsT.shape, 'Error in meanSquaredError: ');
  return computeWeightedLoss$1(squaredDifference$2(labelsT, predictionsT), weightsT, reduction);
}
var meanSquaredError$2 = /* @__PURE__ */op({
  meanSquaredError_: meanSquaredError_
});
44745
44746 /**
44747 * @license
44748 * Copyright 2020 Google LLC. All Rights Reserved.
44749 * Licensed under the Apache License, Version 2.0 (the "License");
44750 * you may not use this file except in compliance with the License.
44751 * You may obtain a copy of the License at
44752 *
44753 * http://www.apache.org/licenses/LICENSE-2.0
44754 *
44755 * Unless required by applicable law or agreed to in writing, software
44756 * distributed under the License is distributed on an "AS IS" BASIS,
44757 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44758 * See the License for the specific language governing permissions and
44759 * limitations under the License.
44760 * =============================================================================
44761 */
/**
 * Element-wise sigmoid cross entropy computed directly from logits, in a
 * form that avoids overflow in `exp`.
 *
 * Derivation, with `x = logits` and `z = labels`. The logistic loss is
 *    z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
 *  = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
 *  = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
 *  = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
 *  = (1 - z) * x + log(1 + exp(-x))
 *  = x - x * z + log(1 + exp(-x))
 *
 * For x < 0, exp(-x) can overflow, so rewrite as
 *    x - x * z + log(1 + exp(-x))
 *  = log(exp(x)) - x * z + log(1 + exp(-x))
 *  = - x * z + log(1 + exp(x))
 *
 * Both cases combine into the stable form used below:
 *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
 *
 * @param labels Target tensor `z`; must match the shape of `logits`.
 * @param logits Raw (pre-sigmoid) predictions `x`.
 * @returns The unreduced per-element loss.
 */
function sigmoidCrossEntropyWithLogits_(labels, logits) {
  var labelsT = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');
  var logitsT = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');
  assertShapesMatch(labelsT.shape, logitsT.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
  var positivePart = relu$2(logitsT); // max(x, 0)
  var crossTerm = mul(logitsT, labelsT); // x * z
  var softplusTail = log1p$2(exp$2(neg$2(abs$2(logitsT)))); // log(1 + exp(-|x|))
  return add$3(sub$2(positivePart, crossTerm), softplusTail);
}
/**
 * Computes the sigmoid cross entropy loss between two tensors.
 *
 * If labelSmoothing is nonzero, smooth the labels towards 1/2:
 *
 *   newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)
 *                         + 0.5 * labelSmoothing
 *
 * @param multiClassLabels The ground truth output tensor of shape
 * [batch_size, num_classes], same dimensions as 'predictions'.
 * @param logits The predicted outputs.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions
 *    must be either `1`, or the same as the corresponding `losses`
 *    dimension). Optional.
 * @param labelSmoothing If greater than 0, then smooth the labels. Optional
 *    fourth argument; defaults to `0`.
 * @param reduction Optional fifth argument of type `Reduction`; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
 */
function sigmoidCrossEntropy_(multiClassLabels, logits, weights) {
  var labelSmoothing = arguments[3] === undefined ? 0 : arguments[3];
  var reduction = arguments[4] === undefined ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : arguments[4];
  var labelsT = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');
  var logitsT = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');
  var weightsT = weights == null ? null : convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');
  assertShapesMatch(labelsT.shape, logitsT.shape, 'Error in sigmoidCrossEntropy: ');
  if (labelSmoothing > 0) {
    var smoothing = scalar(labelSmoothing);
    var keptFraction = sub$2(scalar(1), smoothing);
    var halfSmoothing = mul(scalar(0.5), smoothing);
    labelsT = add$3(mul(labelsT, keptFraction), halfSmoothing);
  }
  return computeWeightedLoss$1(sigmoidCrossEntropyWithLogits_(labelsT, logitsT), weightsT, reduction);
}
var sigmoidCrossEntropy = /* @__PURE__ */op({
  sigmoidCrossEntropy_: sigmoidCrossEntropy_
});
44834
44835 /**
44836 * Computes softmax cross entropy between logits and labels.
44837 *
44838 * Measures the probability error in discrete classification tasks in which
44839 * the classes are mutually exclusive (each entry is in exactly one class).
44840 * For example, each CIFAR-10 image is labeled with one and only one label: an
44841 * image can be a dog or a truck, but not both.
44842 *
44843 * `NOTE`: While the classes are mutually exclusive, their probabilities need
44844 * not be. All that is required is that each row of labels is a valid
44845 * probability distribution. If they are not, the computation of the gradient
44846 * will be incorrect.
44847 *
44848 * `WARNING`: This op expects unscaled logits, since it performs a softmax on
44849 * logits internally for efficiency. Do not call this op with the output of
44850 * softmax, as it will produce incorrect results.
44851 *
44852 * logits and labels must have the same shape, e.g. [batch_size, num_classes]
44853 * and the same dtype.
44854 * @param labels The labels array.
44855 * @param logits The logits array.
44856 * @param dim The dimension softmax would be performed on. Defaults to `-1`
44857 * which indicates the last dimension.
44858 */
44859 function softmaxCrossEntropyWithLogits_(labels, logits) {
44860 var dim = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : -1;
44861 if (dim === -1) {
44862 dim = logits.rank - 1;
44863 }
44864 if (dim !== logits.rank - 1) {
44865 throw Error("Softmax cross entropy along a non-last dimension is not yet " + "supported. Labels / logits was rank ".concat(logits.rank, " ") + "and dim was ".concat(dim));
44866 }
44867 // Use a custom gradient for numerical stability.
44868 var customOp = customGrad(function (labels, logits, save) {
44869 // Reference:
44870 // 1. http://cs231n.github.io/linear-classify/#softmax
44871 // 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/
44872 var keepDims = true;
44873 var lse = logSumExp(logits, [dim], keepDims);
44874 var logResult = sub$2(cast$3(logits, 'float32'), lse);
44875 save([labels, logResult]);
44876 var costVector = neg$2(mul(logResult, labels));
44877 var value = sum$3(costVector, [dim]);
44878 var gradFunc = function gradFunc(dy, saved) {
44879 var _saved = _slicedToArray(saved, 2),
44880 labels = _saved[0],
44881 logResult = _saved[1];
44882 var dyShape = expandShapeToKeepDim(dy.shape, [dim]);
44883 return [mul(reshape$3(dy, dyShape), sub$2(cast$3(labels, 'float32'), exp$2(logResult))), mul(reshape$3(dy, dyShape), sub$2(exp$2(logResult), cast$3(labels, 'float32')))];
44884 };
44885 return {
44886 value: value,
44887 gradFunc: gradFunc
44888 };
44889 });
44890 return customOp(labels, logits);
44891 }
44892 /**
44893 * Computes the softmax cross entropy loss between two tensors.
44894 *
44895 * If labelSmoothing is nonzero, smooth the labels towards 1/2:
44896 *
44897 * newOnehotLabels = onehotLabels * (1 - labelSmoothing)
44898 * + labelSmoothing / numClasses
44899 *
44900 * @param onehotLabels One hot encoded labels
44901 * [batch_size, num_classes], same dimensions as 'predictions'.
44902 * @param logits The predicted outputs.
44903 * @param weights Tensor whose rank is either 0, or 1, and must be
44904 * broadcastable to `loss` of shape [batch_size]
44905 * @param labelSmoothing If greater than 0, then smooth the labels.
44906 * @param reduction Type of reduction to apply to loss. Should be of type
44907 * `Reduction`
44908 *
44909 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
44910 */
44911 function softmaxCrossEntropy_(onehotLabels, logits, weights) {
44912 var labelSmoothing = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 0;
44913 var reduction = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
44914 var $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');
44915 var $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');
44916 var $weights = null;
44917 if (weights != null) {
44918 $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');
44919 }
44920 assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');
44921 if (labelSmoothing > 0) {
44922 var labelSmoothingScalar = scalar(labelSmoothing);
44923 var one = scalar(1);
44924 var numClasses = scalar($onehotLabels.shape[1]);
44925 $onehotLabels = add$3(mul($onehotLabels, sub$2(one, labelSmoothingScalar)), div$1(labelSmoothingScalar, numClasses));
44926 }
44927 var losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
44928 return computeWeightedLoss$1(losses, $weights, reduction);
44929 }
44930 var softmaxCrossEntropy = /* @__PURE__ */op({
44931 softmaxCrossEntropy_: softmaxCrossEntropy_
44932 });
44933
44934 /**
44935 * @license
44936 * Copyright 2021 Google LLC. All Rights Reserved.
44937 * Licensed under the Apache License, Version 2.0 (the "License");
44938 * you may not use this file except in compliance with the License.
44939 * You may obtain a copy of the License at
44940 *
44941 * http://www.apache.org/licenses/LICENSE-2.0
44942 *
44943 * Unless required by applicable law or agreed to in writing, software
44944 * distributed under the License is distributed on an "AS IS" BASIS,
44945 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44946 * See the License for the specific language governing permissions and
44947 * limitations under the License.
44948 * =============================================================================
44949 */
44950 /**
44951 * The input SparseTensor is represented via the map of inputs {`indices`,
44952 * `values`, `denseShape`}. The output SparseTensor has the same `denseShape`
44953 * but with indices `outputIndices` and values `outputValues`. This op inserts a
44954 * single entry for every row that doesn't have any values. The index is created
44955 * as `[row, 0, ..., 0]` and the inserted value is `defaultValue`.
44956 *
44957 * For example, suppose `spInput` has shape [5, 6] and non-empty values:
44958 * [0, 1]: a
44959 * [0, 3]: b
44960 * [2, 0]: c
44961 * [3, 1]: d
44962 *
44963 * Rows 1 and 4 are empty, so the output will be of shape [5, 6] with values:
44964 * [0, 1]: a
44965 * [0, 3]: b
44966 * [1, 0]: `defaultValue`
44967 * [2, 0]: c
44968 * [3, 1]: d
44969 * [4, 0]: `defaultValue`
44970 *
44971 * The output SparseTensor will be in row-major order and will have the same
44972 * shape as the input.
44973 *
44974 * This op also returns an indicator vector shaped [dense_shape[0]] such that
44975 * emptyRowIndicator[i] = True iff row i was an empty row.
44976 *
44977 * And a reverse index map vector shaped [indices.shape[0]] that is used during
44978 * backpropagation, reverseIndexMap[i] = outi s.t. indices[i, j] ==
44979 * outputIndices[outi, j] for all j
44980 *
44981 * ```js
44982 * const result = tf.sparse.sparseFillEmptyRows(
44983 * [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]],
44984 * [0, 10, 13, 14, 32, 33], [5, 6], -1);
44985 * console.log(result);
44986 * result['outputIndices'].print(); // [[0, 0], [1, 0], [1, 3], [1, 4],
44987 * // [2, 0], [3, 2], [3, 3], [4, 0]]
44988 * result['outputValues'].print(); // [0, 10, 13, 14,-1, 32, 33, -1]
44989 * result['emptyRowIndicator'].print(); // [false, false, true, false, true]
44990 * result['reverseIndexMap'].print(); // [0, 1, 2, 3, 5, 6]
44991 * ```
44992 * @param indices: 2-D. The indices of the sparse tensor.
44993 * @param values: 1-D. The values of the sparse tensor.
44994 * @param denseShape: 1-D. The shape of the sparse tensor.
44995 * @param defaultValue: 0-D. Default value to insert into location [row, 0, ...,
44996 * 0] for rows missing from the input sparse tensor.
44997 * @return A map with the following properties:
44998 * - outputIndices
44999 * - outputValues: 1-D. The values of the filled sparse tensor.
45000 * - emptyRowIndicator: 1-D. Whether the dense row was missing in the input
45001 * sparse tensor.
45002 * - reverseIndexMap: 1-D. A map from the input indices to the output
45003 * indices.
45004 * @doc {heading: 'Operations', subheading: 'Sparse'}
45005 */
45006 function sparseFillEmptyRows_(indices, values, denseShape, defaultValue) {
45007 var $indices = convertToTensor(indices, 'indices', 'sparseFillEmptyRows', 'int32');
45008 var $values = convertToTensor(values, 'values', 'sparseFillEmptyRows');
45009 var $denseShape = convertToTensor(denseShape, 'denseShape', 'sparseFillEmptyRows', 'int32');
45010 var $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseFillEmptyRows', $values.dtype);
45011 if ($indices.rank !== 2) {
45012 throw new Error("Indices should be Tensor2D but received shape\n ".concat($indices.shape));
45013 }
45014 if ($values.rank !== 1) {
45015 throw new Error("Values should be Tensor1D but received shape ".concat($values.shape));
45016 }
45017 if ($denseShape.rank !== 1) {
45018 throw new Error("Dense shape should be Tensor1D but received shape ".concat($denseShape.shape));
45019 }
45020 if ($defaultValue.rank !== 0) {
45021 throw new Error("Default value should be a scalar but received shape ".concat($defaultValue.shape));
45022 }
45023 var inputs = {
45024 indices: $indices,
45025 values: $values,
45026 denseShape: $denseShape,
45027 defaultValue: $defaultValue
45028 };
45029 var result = ENGINE.runKernel(SparseFillEmptyRows, inputs);
45030 return {
45031 outputIndices: result[0],
45032 outputValues: result[1],
45033 emptyRowIndicator: result[2],
45034 reverseIndexMap: result[3]
45035 };
45036 }
45037 var sparseFillEmptyRows$2 = /* @__PURE__ */op({
45038 sparseFillEmptyRows_: sparseFillEmptyRows_
45039 });
45040
45041 /**
45042 * @license
45043 * Copyright 2021 Google LLC. All Rights Reserved.
45044 * Licensed under the Apache License, Version 2.0 (the "License");
45045 * you may not use this file except in compliance with the License.
45046 * You may obtain a copy of the License at
45047 *
45048 * http://www.apache.org/licenses/LICENSE-2.0
45049 *
45050 * Unless required by applicable law or agreed to in writing, software
45051 * distributed under the License is distributed on an "AS IS" BASIS,
45052 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45053 * See the License for the specific language governing permissions and
45054 * limitations under the License.
45055 * =============================================================================
45056 */
45057 /**
45058 * This operation has the same semantics as reshape on the represented dense
45059 * tensor. The `inputIndices` are recomputed based on the requested `newShape`.
45060 * If one component of `newShape` is the special value -1, the size of that
45061 * dimension is computed so that the total dense size remains constant. At most
45062 * one component of `newShape` can be -1. The number of dense elements implied
45063 * by `newShape` must be the same as the number of dense elements originally
45064 * implied by `inputShape`. Reshaping does not affect the order of values in the
45065 * SparseTensor. If the input tensor has rank R_in and N non-empty values, and
45066 * `newShape` has length R_out, then `inputIndices` has shape [N, R_in],
45067 * `inputShape` has length R_in, `outputIndices` has shape [N, R_out], and
45068 * `outputShape` has length R_out.
45069 *
45070 * ```js
45071 * const result = tf.sparse.sparseReshape(
45072 * [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]],
45073 * [2, 3, 6], [9, -1]);
45074 * console.log(result);
45075 * result['outputIndices'].print(); //[[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]]
45076 * result['outputShape'].print(); // [9, 4]
45077 * ```
45078 * @param inputIndices: 2-D. N x R_in matrix with the indices of non-empty
45079 * values in a SparseTensor.
45080 * @param inputShape: 1-D. R_in Tensor1D with the input SparseTensor's dense
45081 * shape.
45082 * @param newShape: 1-D. R_out Tensor1D with the requested new dense shape.
45083 * @return A map with the following properties:
45084 * - outputIndices: 2-D. N x R_out matrix with the updated indices of
45085 * non-empty values in the output SparseTensor.
45086 * - outputShape: 1-D. R_out vector with the full dense shape of the output
45087 * SparseTensor. This is the same as newShape but with any -1 dimensions
45088 * filled in.
45089 * @doc {heading: 'Operations', subheading: 'Sparse'}
45090 */
45091 function sparseReshape_(inputIndices, inputShape, newShape) {
45092 var $inputIndices = convertToTensor(inputIndices, 'inputIndices', 'sparseReshape', 'int32');
45093 var $inputShape = convertToTensor(inputShape, 'inputShape', 'sparseReshape', 'int32');
45094 var $newShape = convertToTensor(newShape, 'newShape', 'sparseReshape', 'int32');
45095 if ($inputIndices.rank !== 2) {
45096 throw new Error("Input indices should be Tensor2D but received shape\n ".concat($inputIndices.shape));
45097 }
45098 if ($inputShape.rank !== 1) {
45099 throw new Error("Input shape should be Tensor1D but received shape ".concat($inputShape.shape));
45100 }
45101 if ($newShape.rank !== 1) {
45102 throw new Error("New shape should be Tensor1D but received shape ".concat($newShape.shape));
45103 }
45104 var inputs = {
45105 inputIndices: $inputIndices,
45106 inputShape: $inputShape,
45107 newShape: $newShape
45108 };
45109 var result = ENGINE.runKernel(SparseReshape, inputs);
45110 return {
45111 outputIndices: result[0],
45112 outputShape: result[1]
45113 };
45114 }
45115 var sparseReshape$2 = /* @__PURE__ */op({
45116 sparseReshape_: sparseReshape_
45117 });
45118
45119 /**
45120 * @license
45121 * Copyright 2021 Google LLC. All Rights Reserved.
45122 * Licensed under the Apache License, Version 2.0 (the "License");
45123 * you may not use this file except in compliance with the License.
45124 * You may obtain a copy of the License at
45125 *
45126 * http://www.apache.org/licenses/LICENSE-2.0
45127 *
45128 * Unless required by applicable law or agreed to in writing, software
45129 * distributed under the License is distributed on an "AS IS" BASIS,
45130 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45131 * See the License for the specific language governing permissions and
45132 * limitations under the License.
45133 * =============================================================================
45134 */
45135 /**
45136 * Computes the mean along sparse segments of a tensor.
45137 *
45138 * ```js
45139 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
45140 * // Select two rows, one segment.
45141 * const result1 = tf.sparse.sparseSegmentMean(c,
45142 * tf.tensor1d([0, 1], 'int32'),
45143 * tf.tensor1d([0, 0], 'int32'));
45144 * result1.print(); // [[0, 0, 0, 0]]
45145 *
45146 * // Select two rows, two segments.
45147 * const result2 = tf.sparse.sparseSegmentMean(c,
45148 * tf.tensor1d([0, 1], 'int32'),
45149 * tf.tensor1d([0, 1], 'int32'));
45150 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
45151 *
45152 * // Select all rows, two segments.
45153 * const result3 = tf.sparse.sparseSegmentMean(c,
45154 * tf.tensor1d([0, 1, 2], 'int32'),
45155 * tf.tensor1d([0, 1, 1], 'int32'));
45156 * result3.print(); // [[1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]]
45157 * ```
45158 * @param data: A Tensor of at least one dimension with data that will be
45159 * assembled in the output.
45160 * @param indices: A 1-D Tensor with indices into data. Has same rank as
45161 * segmentIds.
45162 * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
45163 * should be sorted and can be repeated.
45164 * @return Has same shape as data, except for dimension 0 which has equal to
45165 * the number of segments.
45166 *
45167 * @doc {heading: 'Operations', subheading: 'Sparse'}
45168 */
45169 function sparseSegmentMean_(data, indices, segmentIds) {
45170 var $data = convertToTensor(data, 'data', 'sparseSegmentMean');
45171 var $indices = convertToTensor(indices, 'indices', 'sparseSegmentMean', 'int32');
45172 var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentMean', 'int32');
45173 if ($data.rank < 1) {
45174 throw new Error("Data should be at least 1 dimensional but received scalar");
45175 }
45176 if ($indices.rank !== 1) {
45177 throw new Error("Indices should be Tensor1D but received shape\n ".concat($indices.shape));
45178 }
45179 if ($segmentIds.rank !== 1) {
45180 throw new Error("Segment ids should be Tensor1D but received shape\n ".concat($segmentIds.shape));
45181 }
45182 var inputs = {
45183 data: $data,
45184 indices: $indices,
45185 segmentIds: $segmentIds
45186 };
45187 return ENGINE.runKernel(SparseSegmentMean, inputs);
45188 }
45189 var sparseSegmentMean$2 = /* @__PURE__ */op({
45190 sparseSegmentMean_: sparseSegmentMean_
45191 });
45192
45193 /**
45194 * @license
45195 * Copyright 2021 Google LLC. All Rights Reserved.
45196 * Licensed under the Apache License, Version 2.0 (the "License");
45197 * you may not use this file except in compliance with the License.
45198 * You may obtain a copy of the License at
45199 *
45200 * http://www.apache.org/licenses/LICENSE-2.0
45201 *
45202 * Unless required by applicable law or agreed to in writing, software
45203 * distributed under the License is distributed on an "AS IS" BASIS,
45204 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45205 * See the License for the specific language governing permissions and
45206 * limitations under the License.
45207 * =============================================================================
45208 */
45209 /**
45210 * Computes the sum along sparse segments of a tensor.
45211 *
45212 * ```js
45213 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]);
45214 * // Select two rows, one segment.
45215 * const result1 = tf.sparse.sparseSegmentSum(c,
45216 * tf.tensor1d([0, 1], 'int32'),
45217 * tf.tensor1d([0, 0], 'int32'));
45218 * result1.print(); // [[0, 0, 0, 0]]
45219 *
45220 * // Select two rows, two segments.
45221 * const result2 = tf.sparse.sparseSegmentSum(c,
45222 * tf.tensor1d([0, 1], 'int32'),
45223 * tf.tensor1d([0, 1], 'int32'));
45224 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
45225 *
45226 * // Select all rows, two segments.
45227 * const result3 = tf.sparse.sparseSegmentSum(c,
45228 * tf.tensor1d([0, 1, 2], 'int32'),
45229 * tf.tensor1d([0, 0, 1], 'int32'));
45230 * result3.print(); // [[0, 0, 0, 0], [5, 6, 7, 8]]
45231 * ```
45232 * @param data: A Tensor of at least one dimension with data that will be
45233 * assembled in the output.
45234 * @param indices: A 1-D Tensor with indices into data. Has same rank as
45235 * segmentIds.
45236 * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
45237 * should be sorted and can be repeated.
45238 * @return Has same shape as data, except for dimension 0 which has equal to
45239 * the number of segments.
45240 *
45241 * @doc {heading: 'Operations', subheading: 'Sparse'}
45242 */
45243 function sparseSegmentSum_(data, indices, segmentIds) {
45244 var $data = convertToTensor(data, 'data', 'sparseSegmentSum');
45245 var $indices = convertToTensor(indices, 'indices', 'sparseSegmentSum', 'int32');
45246 var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentSum', 'int32');
45247 if ($data.rank < 1) {
45248 throw new Error("Data should be at least 1 dimensional but received scalar");
45249 }
45250 if ($indices.rank !== 1) {
45251 throw new Error("Indices should be Tensor1D but received shape\n ".concat($indices.shape));
45252 }
45253 if ($segmentIds.rank !== 1) {
45254 throw new Error("Segment ids should be Tensor1D but received shape\n ".concat($segmentIds.shape));
45255 }
45256 var inputs = {
45257 data: $data,
45258 indices: $indices,
45259 segmentIds: $segmentIds
45260 };
45261 return ENGINE.runKernel(SparseSegmentSum, inputs);
45262 }
45263 var sparseSegmentSum$2 = /* @__PURE__ */op({
45264 sparseSegmentSum_: sparseSegmentSum_
45265 });
45266
45267 /**
45268 * @license
45269 * Copyright 2021 Google LLC. All Rights Reserved.
45270 * Licensed under the Apache License, Version 2.0 (the "License");
45271 * you may not use this file except in compliance with the License.
45272 * You may obtain a copy of the License at
45273 *
45274 * http://www.apache.org/licenses/LICENSE-2.0
45275 *
45276 * Unless required by applicable law or agreed to in writing, software
45277 * distributed under the License is distributed on an "AS IS" BASIS,
45278 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45279 * See the License for the specific language governing permissions and
45280 * limitations under the License.
45281 * =============================================================================
45282 */
45283 /**
45284 * Creates ngrams from ragged string data.
45285 *
45286 * This op accepts a ragged tensor with 1 ragged dimension containing only
45287 * strings and outputs a ragged tensor with 1 ragged dimension containing ngrams
45288 * of that string, joined along the innermost axis.
45289 *
45290 * ```js
45291 * const result = tf.string.stringNGrams(
45292 * ['a', 'b', 'c', 'd'], tf.tensor1d([0, 2, 4], 'int32'),
45293 * '|', [1, 2], 'LP', 'RP', -1, false);
45294 * result['nGrams'].print(); // ['a', 'b', 'LP|a', 'a|b', 'b|RP',
45295 * // 'c', 'd', 'LP|c', 'c|d', 'd|RP']
45296 * result['nGramsSplits'].print(); // [0, 5, 10]
45297 * ```
45298 * @param data: The values tensor of the ragged string tensor to make ngrams out
45299 * of. Must be a 1D string tensor.
45300 * @param dataSplits: The splits tensor of the ragged string tensor to make
45301 * ngrams out of.
45302 * @param separator: The string to append between elements of the token. Use ""
45303 * for no separator.
45304 * @param nGramWidths: The sizes of the ngrams to create.
45305 * @param leftPad: The string to use to pad the left side of the ngram sequence.
45306 * Only used if pad_width !== 0.
45307 * @param rightPad: The string to use to pad the right side of the ngram
45308 * sequence. Only used if pad_width !== 0.
45309 * @param padWidth: The number of padding elements to add to each side of each
45310 * sequence. Note that padding will never be greater than `nGramWidths`-1
45311 * regardless of this value. If `padWidth`=-1, then add max(`nGramWidths`)-1
45312 * elements.
45313 * @param preserveShortSequences: If true, then ensure that at least one ngram
45314 * is generated for each input sequence. In particular, if an input sequence
45315 * is shorter than min(ngramWidth) + 2*padWidth, then generate a single
45316 * ngram containing the entire sequence. If false, then no ngrams are
45317 * generated for these short input sequences.
45318 * @return A map with the following properties:
45319 * - nGrams: The values tensor of the output ngrams ragged tensor.
45320 * - nGramsSplits: The splits tensor of the output ngrams ragged tensor.
45321 *
45322 * @doc {heading: 'Operations', subheading: 'String'}
45323 */
45324 function stringNGrams_(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
45325 var $data = convertToTensor(data, 'data', 'stringNGrams', 'string');
45326 if ($data.dtype !== 'string') {
45327 throw new Error('Data must be of datatype string');
45328 }
45329 if ($data.shape.length !== 1) {
45330 throw new Error("Data must be a vector, saw: ".concat($data.shape));
45331 }
45332 var $dataSplits = convertToTensor(dataSplits, 'dataSplits', 'stringNGrams');
45333 if ($dataSplits.dtype !== 'int32') {
45334 throw new Error('Data splits must be of datatype int32');
45335 }
45336 var attrs = {
45337 separator: separator,
45338 nGramWidths: nGramWidths,
45339 leftPad: leftPad,
45340 rightPad: rightPad,
45341 padWidth: padWidth,
45342 preserveShortSequences: preserveShortSequences
45343 };
45344 var inputs = {
45345 data: $data,
45346 dataSplits: $dataSplits
45347 };
45348 var result = ENGINE.runKernel(StringNGrams, inputs, attrs);
45349 return {
45350 nGrams: result[0],
45351 nGramsSplits: result[1]
45352 };
45353 }
45354 var stringNGrams$2 = /* @__PURE__ */op({
45355 stringNGrams_: stringNGrams_
45356 });
45357
45358 /**
45359 * @license
45360 * Copyright 2021 Google LLC. All Rights Reserved.
45361 * Licensed under the Apache License, Version 2.0 (the "License");
45362 * you may not use this file except in compliance with the License.
45363 * You may obtain a copy of the License at
45364 *
45365 * http://www.apache.org/licenses/LICENSE-2.0
45366 *
45367 * Unless required by applicable law or agreed to in writing, software
45368 * distributed under the License is distributed on an "AS IS" BASIS,
45369 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45370 * See the License for the specific language governing permissions and
45371 * limitations under the License.
45372 * =============================================================================
45373 */
45374 /**
45375 * Split elements of `input` based on `delimiter` into a SparseTensor .
45376 *
45377 * Let N be the size of source (typically N will be the batch size). Split each
45378 * element of `input` based on `delimiter` and return a SparseTensor containing
45379 * the splitted tokens. Empty tokens are ignored if `skipEmpty` is set to True.
45380 *
45381 * `delimiter` can be empty, or a string of split characters. If `delimiter` is
45382 * an empty string, each element of `input` is split into individual
45383 * character strings. Otherwise every character of `delimiter` is a potential
45384 * split point.
45385 *
45386 * ```js
45387 * const result = tf.string.stringSplit(['hello world', 'a b c'], ' ');
45388 * result['indices'].print(); // [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
45389 * result['values'].print(); // ['hello', 'world', 'a', 'b', 'c']
45390 * result['shape'].print(); // [2, 3]
45391 * ```
45392 * @param input: 1-D. Strings to split.
45393 * @param delimiter: 0-D. Delimiter characters, or empty string.
45394 * @param skipEmpty: Optional. If true, skip the empty strings from the result.
45395 * Defaults to true.
45396 * @return A map with the following properties:
45397 * - indices: A dense matrix of int32 representing the indices of the sparse
45398 * tensor.
45399 * - values: A vector of strings corresponding to the splited values.
45400 * - shape: a length-2 vector of int32 representing the shape of the sparse
45401 * tensor, where the first value is N and the second value is the maximum number
45402 * of tokens in a single input entry.
45403 *
45404 * @doc {heading: 'Operations', subheading: 'String'}
45405 */
45406 function stringSplit_(input, delimiter) {
45407 var skipEmpty = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
45408 var $input = convertToTensor(input, 'input', 'stringSplit', 'string');
45409 var $delimiter = convertToTensor(delimiter, 'delimiter', 'stringSplit', 'string');
45410 if ($input.rank !== 1) {
45411 throw new Error("Input should be Tensor1D but received shape ".concat($input.shape));
45412 }
45413 if ($delimiter.rank !== 0) {
45414 throw new Error("Delimiter should be a scalar but received shape ".concat($delimiter.shape));
45415 }
45416 var attrs = {
45417 skipEmpty: skipEmpty
45418 };
45419 var inputs = {
45420 input: $input,
45421 delimiter: $delimiter
45422 };
45423 var result = ENGINE.runKernel(StringSplit, inputs, attrs);
45424 return {
45425 indices: result[0],
45426 values: result[1],
45427 shape: result[2]
45428 };
45429 }
45430 var stringSplit$2 = /* @__PURE__ */op({
45431 stringSplit_: stringSplit_
45432 });
45433
45434 /**
45435 * @license
45436 * Copyright 2021 Google LLC. All Rights Reserved.
45437 * Licensed under the Apache License, Version 2.0 (the "License");
45438 * you may not use this file except in compliance with the License.
45439 * You may obtain a copy of the License at
45440 *
45441 * http://www.apache.org/licenses/LICENSE-2.0
45442 *
45443 * Unless required by applicable law or agreed to in writing, software
45444 * distributed under the License is distributed on an "AS IS" BASIS,
45445 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45446 * See the License for the specific language governing permissions and
45447 * limitations under the License.
45448 * =============================================================================
45449 */
45450 /**
45451 * Converts each string in the input Tensor to its hash mod by a number of
45452 * buckets.
45453 *
45454 * The hash function is deterministic on the content of the string within the
45455 * process and will never change. However, it is not suitable for cryptography.
45456 * This function may be used when CPU time is scarce and inputs are trusted or
45457 * unimportant. There is a risk of adversaries constructing inputs that all hash
45458 * to the same bucket.
45459 *
45460 * ```js
45461 * const result = tf.string.stringToHashBucketFast(
45462 * ['Hello', 'TensorFlow', '2.x'], 3);
45463 * result.print(); // [0, 2, 2]
45464 * ```
45465 * @param input: The strings to assign a hash bucket.
45466 * @param numBuckets: The number of buckets.
45467 * @return A Tensor of the same shape as the input tensor.
45468 *
45469 * @doc {heading: 'Operations', subheading: 'String'}
45470 */
45471 function stringToHashBucketFast_(input, numBuckets) {
45472 var $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string');
45473 var attrs = {
45474 numBuckets: numBuckets
45475 };
45476 if (numBuckets <= 0) {
45477 throw new Error("Number of buckets must be at least 1");
45478 }
45479 var inputs = {
45480 input: $input
45481 };
45482 return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs);
45483 }
45484 var stringToHashBucketFast$2 = /* @__PURE__ */op({
45485 stringToHashBucketFast_: stringToHashBucketFast_
45486 });
45487
45488 /**
45489 * @license
45490 * Copyright 2023 Google LLC.
45491 * Licensed under the Apache License, Version 2.0 (the "License");
45492 * you may not use this file except in compliance with the License.
45493 * You may obtain a copy of the License at
45494 *
45495 * http://www.apache.org/licenses/LICENSE-2.0
45496 *
45497 * Unless required by applicable law or agreed to in writing, software
45498 * distributed under the License is distributed on an "AS IS" BASIS,
45499 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45500 * See the License for the specific language governing permissions and
45501 * limitations under the License.
45502 * =============================================================================
45503 */
/**
 * Replace the match of a `pattern` in `input` with `rewrite`.
 *
 * ```js
 * const result = tf.string.staticRegexReplace(
 *    ['format       this   spacing      better'], ' +', ' ');
 * result.print(); // ['format this spacing better']
 * ```
 * @param input: A Tensor of type string. The text to be processed.
 * @param pattern: A string. The regular expression to match the input.
 * @param rewrite: A string. The rewrite to be applied to the matched
 *    expression.
 * @param replaceGlobal: An optional bool. Defaults to `true`, including when
 *    `undefined` is passed explicitly. If `true`, every match is replaced;
 *    otherwise only the first match is replaced.
 * @return A Tensor of type string.
 *
 * @doc {heading: 'Operations', subheading: 'String'}
 */
function staticRegexReplace_(input, pattern, rewrite) {
  // The optional 4th argument is read from `arguments`. This is the ES5
  // transpiled form of a default parameter, and it keeps the declared arity
  // at three.
  var replaceGlobalArg = arguments[3];
  var replaceGlobal = replaceGlobalArg === undefined ? true : replaceGlobalArg;
  var kernelInputs = {
    x: convertToTensor(input, 'input', 'staticRegexReplace', 'string')
  };
  var kernelAttrs = {
    pattern: pattern,
    rewrite: rewrite,
    replaceGlobal: replaceGlobal
  };
  return ENGINE.runKernel(StaticRegexReplace, kernelInputs, kernelAttrs);
}
// Public wrapper for staticRegexReplace_, built by op() (defined elsewhere in
// this bundle). It is exposed as string$1.staticRegexReplace below.
// NOTE(review): op() presumably derives the op name from the single object
// key, so keep the key spelled `staticRegexReplace_`; confirm in op().
// `/* @__PURE__ */` lets tree-shaking bundlers drop this call when unused.
var staticRegexReplace$2 = /* @__PURE__ */op({
  staticRegexReplace_: staticRegexReplace_
});
45538
45539 /**
45540 * @license
45541 * Copyright 2020 Google LLC. All Rights Reserved.
45542 * Licensed under the Apache License, Version 2.0 (the "License");
45543 * you may not use this file except in compliance with the License.
45544 * You may obtain a copy of the License at
45545 *
45546 * http://www.apache.org/licenses/LICENSE-2.0
45547 *
45548 * Unless required by applicable law or agreed to in writing, software
45549 * distributed under the License is distributed on an "AS IS" BASIS,
45550 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45551 * See the License for the specific language governing permissions and
45552 * limitations under the License.
45553 * =============================================================================
45554 */
// Namespace objects that group related ops. Each property value is an op or
// helper defined elsewhere in this bundle. The `$N` suffixes on some
// identifiers look like bundler renames that avoid name collisions; they are
// not part of the public names.
// NOTE(review): these presumably back the public tf.spectral, tf.signal,
// tf.image, tf.linalg, tf.losses, tf.sparse and tf.string namespaces (the
// staticRegexReplace docs call `tf.string.staticRegexReplace`). The export
// wiring is outside this chunk, so confirm there.
var spectral$1 = {
  fft: fft$2,
  ifft: ifft$2,
  rfft: rfft,
  irfft: irfft
};
var signal = {
  hammingWindow: hammingWindow,
  hannWindow: hannWindow,
  frame: frame,
  stft: stft
};
var image$1 = {
  flipLeftRight: flipLeftRight,
  grayscaleToRGB: grayscaleToRGB,
  resizeNearestNeighbor: resizeNearestNeighbor$2,
  resizeBilinear: resizeBilinear$3,
  rgbToGrayscale: rgbToGrayscale,
  rotateWithOffset: rotateWithOffset,
  cropAndResize: cropAndResize$3,
  nonMaxSuppression: nonMaxSuppression,
  nonMaxSuppressionAsync: nonMaxSuppressionAsync,
  nonMaxSuppressionWithScore: nonMaxSuppressionWithScore,
  nonMaxSuppressionWithScoreAsync: nonMaxSuppressionWithScoreAsync,
  nonMaxSuppressionPadded: nonMaxSuppressionPadded,
  nonMaxSuppressionPaddedAsync: nonMaxSuppressionPaddedAsync,
  threshold: threshold$1,
  transform: transform$2
};
var linalg = {
  bandPart: bandPart,
  gramSchmidt: gramSchmidt,
  qr: qr
};
var losses = {
  absoluteDifference: absoluteDifference,
  computeWeightedLoss: computeWeightedLoss$1,
  cosineDistance: cosineDistance,
  hingeLoss: hingeLoss,
  huberLoss: huberLoss,
  logLoss: logLoss,
  meanSquaredError: meanSquaredError$2,
  sigmoidCrossEntropy: sigmoidCrossEntropy,
  softmaxCrossEntropy: softmaxCrossEntropy
};
var sparse$1 = {
  sparseFillEmptyRows: sparseFillEmptyRows$2,
  sparseReshape: sparseReshape$2,
  sparseSegmentMean: sparseSegmentMean$2,
  sparseSegmentSum: sparseSegmentSum$2
};
// tslint:disable-next-line:variable-name
var string$1 = {
  stringNGrams: stringNGrams$2,
  stringSplit: stringSplit$2,
  stringToHashBucketFast: stringToHashBucketFast$2,
  staticRegexReplace: staticRegexReplace$2
};
45613
/**
 * Two maps that link each custom class registered through registerClass()
 * with its registered name.
 *
 * Every registerClass() call adds one entry to each map. This is what lets
 * getRegisteredName() look up the registered name of a class.
 *
 * For example, after registerClass(MyLayer) with the default package:
 *   GLOBAL_CUSTOM_OBJECT: 'Custom>MyLayer' -> MyLayer
 *   GLOBAL_CUSTOM_NAMES:  MyLayer -> 'Custom>MyLayer'
 */
// registered name (package + '>' + name) -> class constructor.
var GLOBAL_CUSTOM_OBJECT = new Map();
// class constructor -> registered name (the reverse of the map above).
var GLOBAL_CUSTOM_NAMES = new Map();
/**
 * Serializable defines the serialization contract.
 *
 * TFJS requires serializable classes to return their className when asked
 * to avoid issues with minification.
 *
 * This is transpiled ES5 class output. _classCallCheck and _createClass are
 * helpers defined elsewhere in this bundle. In the _createClass call, the
 * first descriptor array holds instance methods (getClassName reads
 * `this.constructor`) and the second holds static methods (fromConfig is
 * read as `cls.fromConfig` by SerializationMap.register).
 */
var Serializable = /*#__PURE__*/function () {
  function Serializable() {
    _classCallCheck(this, Serializable);
  }
  _createClass(Serializable, [{
    key: "getClassName",
    value:
    /**
     * Return the class name for this class to use in serialization contexts.
     *
     * Generally speaking this will be the same thing that constructor.name
     * would have returned. However, the class name needs to be robust
     * against minification for serialization/deserialization to work properly.
     *
     * There's also places such as initializers.VarianceScaling, where
     * implementation details between different languages led to different
     * class hierarchies and a non-leaf node is used for serialization purposes.
     */
    function getClassName() {
      // Reads the static `className` of the concrete (sub)class.
      return this.constructor.className;
    }
    /** @nocollapse */
  }], [{
    /**
     * Creates an instance of T from a ConfigDict.
     *
     * This works for most descendants of serializable. A few need to
     * provide special handling.
     * @param cls A Constructor for the class to instantiate.
     * @param config The Configuration for the object.
     */
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      return new cls(config);
    }
  }]);
  return Serializable;
}();
/**
 * Maps string keys to class constructors.
 *
 * Used during (de)serialization from the cross-language JSON format, which
 * requires the class name in the serialization format to match the class
 * name used in Python, where such a class exists.
 */
var SerializationMap = /*#__PURE__*/function () {
  function SerializationMap() {
    _classCallCheck(this, SerializationMap);
    // className -> [class constructor, that class's static fromConfig].
    this.classNameMap = {};
  }
  _createClass(SerializationMap, null, [{
    /**
     * Returns the singleton instance of the map. It is created on first
     * access and cached in the static `SerializationMap.instance`.
     */
    key: "getMap",
    value: function getMap() {
      if (SerializationMap.instance == null) {
        SerializationMap.instance = new SerializationMap();
      }
      return SerializationMap.instance;
    }
  }, {
    /**
     * Registers the class as serializable.
     *
     * The entry is keyed by the class's static `className`, not by any custom
     * name passed to registerClass(). A later registration with the same
     * className overwrites the earlier entry.
     */
    key: "register",
    value: function register(cls) {
      SerializationMap.getMap().classNameMap[cls.className] = [cls, cls.fromConfig];
    }
  }]);
  return SerializationMap;
}();
/**
 * Register a class with the serialization map of TensorFlow.js.
 *
 * This is often used for registering custom Layers, so they can be
 * serialized and deserialized.
 *
 * Registration does three things:
 * - adds the class to SerializationMap under its static `className`;
 * - stores `pkg + '>' + name` -> class in GLOBAL_CUSTOM_OBJECT;
 * - stores class -> `pkg + '>' + name` in GLOBAL_CUSTOM_NAMES.
 *
 * The registered name can be read back with
 * `tf.serialization.getRegisteredName(cls)`.
 *
 * Example 1. Register the class without package name and specified name.
 *
 * ```js
 * class MyCustomLayer extends tf.layers.Layer {
 *   static className = 'MyCustomLayer';
 *
 *   constructor(config) {
 *     super(config);
 *   }
 * }
 * tf.serialization.registerClass(MyCustomLayer);
 * console.log(tf.serialization.getRegisteredName(MyCustomLayer));
 * // 'Custom>MyCustomLayer'
 * ```
 *
 * Example 2. Register the class with package name: "Package" and specified
 * name: "MyLayer".
 * ```js
 * class MyCustomLayer extends tf.layers.Layer {
 *   static className = 'MyCustomLayer';
 *
 *   constructor(config) {
 *     super(config);
 *   }
 * }
 * tf.serialization.registerClass(MyCustomLayer, "Package", "MyLayer");
 * console.log(tf.serialization.getRegisteredName(MyCustomLayer));
 * // 'Package>MyLayer'
 * ```
 *
 * Example 3. Register the class with specified name: "MyLayer".
 * ```js
 * class MyCustomLayer extends tf.layers.Layer {
 *   static className = 'MyCustomLayer';
 *
 *   constructor(config) {
 *     super(config);
 *   }
 * }
 * tf.serialization.registerClass(MyCustomLayer, undefined, "MyLayer");
 * console.log(tf.serialization.getRegisteredName(MyCustomLayer));
 * // 'Custom>MyLayer'
 * ```
 *
 * Example 4. Register the class with specified package name: "Package".
 * ```js
 * class MyCustomLayer extends tf.layers.Layer {
 *   static className = 'MyCustomLayer';
 *
 *   constructor(config) {
 *     super(config);
 *   }
 * }
 * tf.serialization.registerClass(MyCustomLayer, "Package");
 * console.log(tf.serialization.getRegisteredName(MyCustomLayer));
 * // 'Package>MyCustomLayer'
 * ```
 *
 * @param cls The class to be registered. It must have a public static member
 *     called `className` defined and the value must be a non-empty string.
 * @param pkg The package name that this class belongs to. This is used to
 *     build the key in GLOBAL_CUSTOM_OBJECT. If omitted (undefined or null),
 *     it defaults to `Custom`.
 * @param name The name that the user specified. If omitted (undefined or
 *     null), it defaults to the class's static `className` property.
 * @returns `cls`, unchanged.
 * @throws If `cls.className` is missing, not a string, or empty (via assert$1).
 * @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true}
 */
function registerClass(cls, pkg, name) {
  assert$1(cls.className != null, function () {
    return "Class being registered does not have the static className " + "property defined.";
  });
  assert$1(typeof cls.className === 'string', function () {
    return "className is required to be a string, but got type " + _typeof(cls.className);
  });
  assert$1(cls.className.length > 0, function () {
    return "Class being registered has an empty-string as its className, " + "which is disallowed.";
  });
  // Treat null like an omitted argument. Otherwise the key would silently
  // become 'null>...' or '...>null'.
  var packageName = pkg == null ? 'Custom' : pkg;
  var registeredAs = name == null ? cls.className : name;
  var registerName = packageName + '>' + registeredAs;
  SerializationMap.register(cls);
  GLOBAL_CUSTOM_OBJECT.set(registerName, cls);
  GLOBAL_CUSTOM_NAMES.set(cls, registerName);
  return cls;
}
/**
 * Look up the name under which `cls` was registered with registerClass().
 * A class that was never registered falls back to its static `className`.
 *
 * @param cls The class to look up. It must have a public static member
 *     called `className` defined.
 * @returns The registered name (for example 'Custom>MyLayer'), or
 *     `cls.className` when the class is not registered.
 */
function getRegisteredName(cls) {
  var isRegistered = GLOBAL_CUSTOM_NAMES.has(cls);
  return isRegistered ? GLOBAL_CUSTOM_NAMES.get(cls) : cls.className;
}
45823
// Members of the serialization namespace. In an object literal,
// `__proto__: null` sets the object's prototype to null, so the namespace
// carries no inherited Object.prototype members. Note that
// GLOBAL_CUSTOM_OBJECT / GLOBAL_CUSTOM_NAMES are not members here;
// getRegisteredName is the public way to read the registered names.
var serialization = {
  __proto__: null,
  Serializable: Serializable,
  SerializationMap: SerializationMap,
  getRegisteredName: getRegisteredName,
  registerClass: registerClass
};
45831
// Base class of all optimizers (transpiled ES5 class output; _inherits,
// _createSuper, _classCallCheck and _createClass are helpers defined
// elsewhere in this bundle).
// Subclasses are expected to provide `applyGradients`: minimize() calls
// this.applyGradients(), but this class does not define it.
// The `value: function () { ... }()` entries below are Babel
// regenerator-transpiled async methods. Each returns a Promise, and the
// `case N` / `_context.next = N` labels encode the await resume points, so
// their numbering must stay consistent.
/** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */
var Optimizer = /*#__PURE__*/function (_Serializable) {
  _inherits(Optimizer, _Serializable);
  var _super = _createSuper(Optimizer);
  function Optimizer() {
    _classCallCheck(this, Optimizer);
    return _super.apply(this, arguments);
  }
  _createClass(Optimizer, [{
    key: "minimize",
    value:
    /**
     * Executes `f()` and minimizes the scalar output of `f()` by computing
     * gradients of y with respect to the list of trainable variables provided by
     * `varList`. If no list is provided, it defaults to all trainable variables.
     *
     * @param f The function to execute and whose output to minimize.
     * @param returnCost Whether to return the scalar cost value produced by
     * executing `f()`.
     * @param varList An optional list of variables to update. If specified, only
     * the trainable variables in varList will be updated by minimize. Defaults to
     * all trainable variables.
     * @returns The cost tensor when `returnCost` is true (the caller then owns
     * it); otherwise the cost is disposed and null is returned.
     *
     * @doc {heading: 'Training', subheading: 'Optimizers'}
     */
    function minimize(f) {
      // Transpiled defaults: returnCost = false, varList = undefined.
      var returnCost = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
      var varList = arguments.length > 2 ? arguments[2] : undefined;
      var _this$computeGradient = this.computeGradients(f, varList),
        value = _this$computeGradient.value,
        grads = _this$computeGradient.grads;
      if (varList != null) {
        // Hand gradients over as an ordered [{name, tensor}] list following
        // varList. The subclasses in this file key their per-variable slots
        // by position in what they receive.
        var gradArray = varList.map(function (v) {
          return {
            name: v.name,
            tensor: grads[v.name]
          };
        });
        this.applyGradients(gradArray);
      } else {
        this.applyGradients(grads);
      }
      // Dispose gradients.
      dispose(grads);
      if (returnCost) {
        return value;
      } else {
        value.dispose();
        return null;
      }
    }
  }, {
    /**
     * The number of iterations that this optimizer instance has been invoked for.
     * Lazily initialized to 0 on first read.
     */
    key: "iterations",
    get: function get() {
      if (this.iterations_ == null) {
        this.iterations_ = 0;
      }
      return this.iterations_;
    }
  }, {
    // Bumps the plain-number iteration counter by one.
    key: "incrementIterations",
    value: function incrementIterations() {
      this.iterations_ = this.iterations + 1;
    }
  }, {
    /**
     * Executes f() and computes the gradient of the scalar output of f() with
     * respect to the list of trainable variables provided by `varList`. If no
     * list is provided, it defaults to all trainable variables.
     *
     * @param f The function to execute and whose output to use for computing
     * gradients with respect to variables.
     * @param varList An optional list of variables to compute gradients with
     * respect to. If specified, only the trainable variables in varList will have
     * gradients computed with respect to. Defaults to all trainable variables.
     *
     * @doc {heading: 'Training', subheading: 'Optimizers'}
     */
    key: "computeGradients",
    value: function computeGradients(f, varList) {
      return variableGrads(f, varList);
    }
  }, {
    /**
     * Dispose the variables (if any) owned by this optimizer instance.
     */
    key: "dispose",
    // The function expression is named dispose$1 so that the module-level
    // `dispose` helper called inside it is not shadowed by its own name.
    value: function dispose$1() {
      // NOTE(review): iterations_ is a plain number in this class (set to 0,
      // incremented by 1, or read from data()[0]), so this call is presumably
      // a no-op for non-tensors; confirm against dispose().
      if (this.iterations_ != null) {
        dispose(this.iterations_);
      }
    }
  }, {
    // Async: resolves to {name: 'iter', tensor: int32 scalar holding the
    // current iteration count}. This is the first entry of subclasses'
    // getWeights().
    key: "saveIterations",
    value: function () {
      var _saveIterations = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              if (this.iterations_ == null) {
                this.iterations_ = 0;
              }
              return _context.abrupt("return", {
                name: 'iter',
                // TODO(cais): Use 'int64' type when available.
                tensor: scalar(this.iterations_, 'int32')
              });
            case 2:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function saveIterations() {
        return _saveIterations.apply(this, arguments);
      }
      return saveIterations;
    }()
  }, {
    // Default: async, always rejects. Subclasses override it.
    key: "getWeights",
    value: function () {
      var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              throw new Error('getWeights() is not implemented for this optimizer yet.');
            case 1:
            case "end":
              return _context2.stop();
          }
        }, _callee2);
      }));
      function getWeights() {
        return _getWeights.apply(this, arguments);
      }
      return getWeights;
    }()
  }, {
    // Default: async, always rejects with the concrete class name.
    // Subclasses override it.
    key: "setWeights",
    value: function () {
      var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(weightValues) {
        return _regeneratorRuntime().wrap(function _callee3$(_context3) {
          while (1) switch (_context3.prev = _context3.next) {
            case 0:
              throw new Error("setWeights() is not implemented for this optimizer class " + "".concat(this.getClassName()));
            case 1:
            case "end":
              return _context3.stop();
          }
        }, _callee3, this);
      }));
      function setWeights(_x) {
        return _setWeights.apply(this, arguments);
      }
      return setWeights;
    }()
  }, {
    /**
     * Extract the first element of the weight values and set it
     * as the iterations counter variable of this instance of optimizer.
     *
     * Async: awaits `weightValues[0].tensor.data()` and stores element 0 as
     * `iterations_`.
     *
     * @param weightValues
     * @returns Weight values with the first element consumed and excluded.
     */
    key: "extractIterations",
    value: function () {
      var _extractIterations = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(weightValues) {
        return _regeneratorRuntime().wrap(function _callee4$(_context4) {
          while (1) switch (_context4.prev = _context4.next) {
            case 0:
              _context4.next = 2;
              return weightValues[0].tensor.data();
            case 2:
              // _context4.sent is the awaited data() result.
              this.iterations_ = _context4.sent[0];
              return _context4.abrupt("return", weightValues.slice(1));
            case 4:
            case "end":
              return _context4.stop();
          }
        }, _callee4, this);
      }));
      function extractIterations(_x2) {
        return _extractIterations.apply(this, arguments);
      }
      return extractIterations;
    }()
  }]);
  return Optimizer;
}(Serializable);
// Duck-typed `instanceof`: any object that exposes minimize,
// computeGradients and applyGradients counts as an Optimizer (presumably so
// that optimizers from another copy of the library are still recognized).
// null/undefined are rejected up front. Built-in `instanceof` returns false
// for them, whereas reading `.minimize` on null would throw a TypeError.
Object.defineProperty(Optimizer, Symbol.hasInstance, {
  value: function value(instance) {
    return instance != null && instance.minimize != null && instance.computeGradients != null && instance.applyGradients != null;
  }
});
46029
// Adadelta optimizer. For every variable that receives a gradient g, one
// applyGradients() step performs:
//   accumGrad   <- rho * accumGrad + (1 - rho) * g^2
//   update      <- sqrt(accumUpdate + epsilon) / sqrt(accumGrad + epsilon) * g
//   accumUpdate <- rho * accumUpdate + (1 - rho) * update^2
//   value       <- value - learningRate * update
// The `update` line reads the accumGrad produced by the line above it. This
// follows Zeiler (2012), Algorithm 1, and Keras' Adadelta.
/** @doclink Optimizer */
var AdadeltaOptimizer = /*#__PURE__*/function (_Optimizer) {
  _inherits(AdadeltaOptimizer, _Optimizer);
  var _super = _createSuper(AdadeltaOptimizer);
  /**
   * @param learningRate Multiplier applied to each update before it is
   *     subtracted from the variable.
   * @param rho Decay factor shared by both running averages.
   * @param epsilon Constant added inside both square roots. When omitted,
   *     undefined or null, the backend's `epsilon()` is used.
   */
  function AdadeltaOptimizer(learningRate, rho) {
    var _this;
    var epsilon = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
    _classCallCheck(this, AdadeltaOptimizer);
    _this = _super.call(this);
    _this.learningRate = learningRate;
    _this.rho = rho;
    _this.epsilon = epsilon;
    // Per-variable slots, indexed by the variable's position in the
    // gradients handed to applyGradients().
    _this.accumulatedGrads = [];
    _this.accumulatedUpdates = [];
    if (epsilon == null) {
      _this.epsilon = ENGINE.backend.epsilon();
    }
    return _this;
  }
  _createClass(AdadeltaOptimizer, [{
    key: "applyGradients",
    value: function applyGradients(variableGradients) {
      var _this2 = this;
      // Gradients arrive either as [{name, tensor}, ...] or as {name: tensor}.
      var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
        return item.name;
      }) : Object.keys(variableGradients);
      variableNames.forEach(function (name, i) {
        var value = ENGINE.registeredVariables[name];
        var trainable = false;
        // Lazily create zero-filled, non-trainable accumulator variables.
        if (_this2.accumulatedGrads[i] == null) {
          _this2.accumulatedGrads[i] = {
            originalName: "".concat(name, "/accum_grad"),
            variable: tidy(function () {
              return zerosLike$3(value).variable(trainable);
            })
          };
        }
        if (_this2.accumulatedUpdates[i] == null) {
          _this2.accumulatedUpdates[i] = {
            originalName: "".concat(name, "/accum_var"),
            variable: tidy(function () {
              return zerosLike$3(value).variable(trainable);
            })
          };
        }
        var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
        // Variables without a gradient are left untouched.
        if (gradient == null) {
          return;
        }
        var accumulatedGrad = _this2.accumulatedGrads[i].variable;
        var accumulatedUpdate = _this2.accumulatedUpdates[i].variable;
        tidy(function () {
          var newAccumulatedGrad = add$3(mul(accumulatedGrad, _this2.rho), mul(square$2(gradient), 1 - _this2.rho));
          // Normalize by the RMS of the *updated* gradient accumulator.
          // Normalizing by the stale one would make the very first step equal
          // to the raw gradient, since sqrt(eps) / sqrt(eps) == 1.
          var updates = mul(div$1(sqrt$2(add$3(accumulatedUpdate, _this2.epsilon)), sqrt$2(add$3(newAccumulatedGrad, _this2.epsilon))), gradient);
          var newAccumulatedUpdate = add$3(mul(accumulatedUpdate, _this2.rho), mul(square$2(updates), 1 - _this2.rho));
          accumulatedGrad.assign(newAccumulatedGrad);
          accumulatedUpdate.assign(newAccumulatedUpdate);
          var newValue = add$3(mul(updates, -_this2.learningRate), value);
          value.assign(newValue);
        });
      });
      this.incrementIterations();
    }
  }, {
    key: "dispose",
    // Named dispose$1 so the module-level `dispose` helper stays reachable.
    value: function dispose$1() {
      // Each slot list is checked independently before disposing its variables.
      if (this.accumulatedGrads != null) {
        dispose(this.accumulatedGrads.map(function (v) {
          return v.variable;
        }));
      }
      if (this.accumulatedUpdates != null) {
        dispose(this.accumulatedUpdates.map(function (v) {
          return v.variable;
        }));
      }
    }
  }, {
    // Async (regenerator-transpiled). Resolves to
    // [iterations, ...accum_grad slots, ...accum_var slots] as {name, tensor}.
    key: "getWeights",
    value: function () {
      var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        var variables;
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              // Order matters for Python compatibility.
              variables = [].concat(_toConsumableArray(this.accumulatedGrads), _toConsumableArray(this.accumulatedUpdates));
              _context.next = 3;
              return this.saveIterations();
            case 3:
              _context.t0 = _context.sent;
              return _context.abrupt("return", [_context.t0].concat(variables.map(function (v) {
                return {
                  name: v.originalName,
                  tensor: v.variable
                };
              })));
            case 5:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function getWeights() {
        return _getWeights.apply(this, arguments);
      }
      return getWeights;
    }()
  }, {
    // Async inverse of getWeights(). After the leading iteration entry is
    // consumed, the first half of the remaining weights become the accum_grad
    // slots and the second half become the accum_var slots.
    key: "setWeights",
    value: function () {
      var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(weightValues) {
        var variableCount, trainable;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              _context2.next = 2;
              return this.extractIterations(weightValues);
            case 2:
              weightValues = _context2.sent;
              variableCount = weightValues.length / 2;
              trainable = false;
              // NOTE(review): the previous accumulator variables are replaced
              // without being disposed here, which may leak memory; confirm.
              this.accumulatedGrads = weightValues.slice(0, variableCount).map(function (v) {
                return {
                  originalName: v.name,
                  variable: v.tensor.variable(trainable)
                };
              });
              this.accumulatedUpdates = weightValues.slice(variableCount, variableCount * 2).map(function (v) {
                return {
                  originalName: v.name,
                  variable: v.tensor.variable(trainable)
                };
              });
            case 7:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this);
      }));
      function setWeights(_x) {
        return _setWeights.apply(this, arguments);
      }
      return setWeights;
    }()
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        'learningRate': this.learningRate,
        'rho': this.rho,
        'epsilon': this.epsilon
      };
    }
    /** @nocollapse */
  }], [{
    key: "className",
    get: /** @nocollapse */
    function get() {
      // Name matters for Python compatibility.
      // This is a getter instead of a property because when it's a property, it
      // prevents the entire class from being tree-shaken.
      return 'Adadelta';
    }
  }, {
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      return new cls(config['learningRate'], config['rho'], config['epsilon']);
    }
  }]);
  return AdadeltaOptimizer;
}(Optimizer);
46200
// Adagrad optimizer. For every variable that receives a gradient g, one
// applyGradients() step performs:
//   accum <- accum + g^2
//   value <- value - learningRate * g / sqrt(accum + ENGINE.backend.epsilon())
// using the freshly updated accum. Accumulators start filled with
// initialAccumulatorValue (default 0.1). There is no epsilon constructor
// argument; the backend epsilon is always used.
/** @doclink Optimizer */
var AdagradOptimizer = /*#__PURE__*/function (_Optimizer) {
  _inherits(AdagradOptimizer, _Optimizer);
  var _super = _createSuper(AdagradOptimizer);
  function AdagradOptimizer(learningRate) {
    var _this;
    // Transpiled default parameter: initialAccumulatorValue = 0.1.
    var initialAccumulatorValue = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.1;
    _classCallCheck(this, AdagradOptimizer);
    _this = _super.call(this);
    _this.learningRate = learningRate;
    _this.initialAccumulatorValue = initialAccumulatorValue;
    // Per-variable slots, indexed by the variable's position in the
    // gradients handed to applyGradients().
    _this.accumulatedGrads = [];
    return _this;
  }
  _createClass(AdagradOptimizer, [{
    key: "applyGradients",
    value: function applyGradients(variableGradients) {
      var _this2 = this;
      // Gradients arrive either as [{name, tensor}, ...] or as {name: tensor}.
      var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
        return item.name;
      }) : Object.keys(variableGradients);
      variableNames.forEach(function (name, i) {
        var value = ENGINE.registeredVariables[name];
        // Lazily create a non-trainable accumulator filled with
        // initialAccumulatorValue and shaped like the variable.
        if (_this2.accumulatedGrads[i] == null) {
          var trainable = false;
          _this2.accumulatedGrads[i] = {
            originalName: "".concat(name, "/accumulator"),
            variable: tidy(function () {
              return fill$2(value.shape, _this2.initialAccumulatorValue).variable(trainable);
            })
          };
        }
        var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
        // Variables without a gradient are left untouched.
        if (gradient == null) {
          return;
        }
        var accumulatedGrad = _this2.accumulatedGrads[i].variable;
        tidy(function () {
          var newAccumulatedGrad = add$3(accumulatedGrad, square$2(gradient));
          accumulatedGrad.assign(newAccumulatedGrad);
          var newValue = add$3(mul(div$1(gradient, sqrt$2(add$3(newAccumulatedGrad, ENGINE.backend.epsilon()))), -_this2.learningRate), value);
          value.assign(newValue);
        });
      });
      this.incrementIterations();
    }
  }, {
    key: "dispose",
    // Named dispose$1 so the module-level `dispose` helper stays reachable.
    value: function dispose$1() {
      if (this.accumulatedGrads != null) {
        dispose(this.accumulatedGrads.map(function (v) {
          return v.variable;
        }));
      }
    }
  }, {
    // Async (regenerator-transpiled). Resolves to
    // [iterations, ...accumulator slots] as {name, tensor} pairs.
    key: "getWeights",
    value: function () {
      var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              _context.next = 2;
              return this.saveIterations();
            case 2:
              _context.t0 = _context.sent;
              return _context.abrupt("return", [_context.t0].concat(this.accumulatedGrads.map(function (v) {
                return {
                  name: v.originalName,
                  tensor: v.variable
                };
              })));
            case 4:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function getWeights() {
        return _getWeights.apply(this, arguments);
      }
      return getWeights;
    }()
  }, {
    // Async inverse of getWeights(). After the leading iteration entry is
    // consumed, every remaining weight becomes an accumulator slot.
    key: "setWeights",
    value: function () {
      var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(weightValues) {
        var trainable;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              _context2.next = 2;
              return this.extractIterations(weightValues);
            case 2:
              weightValues = _context2.sent;
              trainable = false;
              // NOTE(review): the previous accumulator variables are replaced
              // without being disposed here, which may leak memory; confirm.
              this.accumulatedGrads = weightValues.map(function (v) {
                return {
                  originalName: v.name,
                  variable: v.tensor.variable(trainable)
                };
              });
            case 5:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this);
      }));
      function setWeights(_x) {
        return _setWeights.apply(this, arguments);
      }
      return setWeights;
    }()
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        'learningRate': this.learningRate,
        'initialAccumulatorValue': this.initialAccumulatorValue
      };
    }
    /** @nocollapse */
  }], [{
    key: "className",
    get: /** @nocollapse */
    function get() {
      // Name matters for Python compatibility.
      // This is a getter instead of a property because when it's a property, it
      // prevents the entire class from being tree-shaken.
      return 'Adagrad';
    }
  }, {
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      return new cls(config['learningRate'], config['initialAccumulatorValue']);
    }
  }]);
  return AdagradOptimizer;
}(Optimizer);
46340
46341 var AdamOptimizer = /*#__PURE__*/function (_Optimizer) {
46342 _inherits(AdamOptimizer, _Optimizer);
46343 var _super = _createSuper(AdamOptimizer);
46344 function AdamOptimizer(learningRate, beta1, beta2) {
46345 var _this;
46346 var epsilon = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : null;
46347 _classCallCheck(this, AdamOptimizer);
46348 _this = _super.call(this);
46349 _this.learningRate = learningRate;
46350 _this.beta1 = beta1;
46351 _this.beta2 = beta2;
46352 _this.epsilon = epsilon;
46353 _this.accumulatedFirstMoment = [];
46354 _this.accumulatedSecondMoment = [];
46355 tidy(function () {
46356 // accB* will be updated by batch.
46357 _this.accBeta1 = scalar(beta1).variable();
46358 _this.accBeta2 = scalar(beta2).variable();
46359 });
46360 if (epsilon == null) {
46361 _this.epsilon = ENGINE.backend.epsilon();
46362 }
46363 return _this;
46364 }
46365 _createClass(AdamOptimizer, [{
46366 key: "applyGradients",
46367 value: function applyGradients(variableGradients) {
46368 var _this2 = this;
46369 var varNames = Array.isArray(variableGradients) ? variableGradients.map(function (v) {
46370 return v.name;
46371 }) : Object.keys(variableGradients);
46372 tidy(function () {
46373 var oneMinusAccBeta1 = sub$2(1, _this2.accBeta1);
46374 var oneMinusAccBeta2 = sub$2(1, _this2.accBeta2);
46375 varNames.forEach(function (name, i) {
46376 var value = ENGINE.registeredVariables[name];
46377 var trainable = false;
46378 if (_this2.accumulatedFirstMoment[i] == null) {
46379 _this2.accumulatedFirstMoment[i] = {
46380 originalName: "".concat(name, "/m"),
46381 variable: tidy(function () {
46382 return zerosLike$3(value).variable(trainable);
46383 })
46384 };
46385 }
46386 if (_this2.accumulatedSecondMoment[i] == null) {
46387 _this2.accumulatedSecondMoment[i] = {
46388 originalName: "".concat(name, "/v"),
46389 variable: tidy(function () {
46390 return zerosLike$3(value).variable(trainable);
46391 })
46392 };
46393 }
46394 var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
46395 if (gradient == null) {
46396 return;
46397 }
46398 var firstMoment = _this2.accumulatedFirstMoment[i].variable;
46399 var secondMoment = _this2.accumulatedSecondMoment[i].variable;
46400 var newFirstMoment = add$3(mul(firstMoment, _this2.beta1), mul(gradient, 1 - _this2.beta1));
46401 var newSecondMoment = add$3(mul(secondMoment, _this2.beta2), mul(square$2(gradient), 1 - _this2.beta2));
46402 var biasCorrectedFirstMoment = div$1(newFirstMoment, oneMinusAccBeta1);
46403 var biasCorrectedSecondMoment = div$1(newSecondMoment, oneMinusAccBeta2);
46404 firstMoment.assign(newFirstMoment);
46405 secondMoment.assign(newSecondMoment);
46406 var newValue = add$3(mul(div$1(biasCorrectedFirstMoment, add$3(sqrt$2(biasCorrectedSecondMoment), _this2.epsilon)), -_this2.learningRate), value);
46407 value.assign(newValue);
46408 });
46409 _this2.accBeta1.assign(mul(_this2.accBeta1, _this2.beta1));
46410 _this2.accBeta2.assign(mul(_this2.accBeta2, _this2.beta2));
46411 });
46412 this.incrementIterations();
46413 }
46414 }, {
46415 key: "dispose",
46416 value: function dispose$1() {
46417 this.accBeta1.dispose();
46418 this.accBeta2.dispose();
46419 if (this.accumulatedFirstMoment != null) {
46420 dispose(this.accumulatedFirstMoment.map(function (v) {
46421 return v.variable;
46422 }));
46423 }
46424 if (this.accumulatedSecondMoment != null) {
46425 dispose(this.accumulatedSecondMoment.map(function (v) {
46426 return v.variable;
46427 }));
46428 }
46429 }
46430 }, {
46431 key: "getWeights",
46432 value: function () {
46433 var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
46434 var variables;
46435 return _regeneratorRuntime().wrap(function _callee$(_context) {
46436 while (1) switch (_context.prev = _context.next) {
46437 case 0:
46438 // Order matters for Python compatibility.
46439 variables = [].concat(_toConsumableArray(this.accumulatedFirstMoment), _toConsumableArray(this.accumulatedSecondMoment));
46440 _context.next = 3;
46441 return this.saveIterations();
46442 case 3:
46443 _context.t0 = _context.sent;
46444 return _context.abrupt("return", [_context.t0].concat(variables.map(function (v) {
46445 return {
46446 name: v.originalName,
46447 tensor: v.variable
46448 };
46449 })));
46450 case 5:
46451 case "end":
46452 return _context.stop();
46453 }
46454 }, _callee, this);
46455 }));
46456 function getWeights() {
46457 return _getWeights.apply(this, arguments);
46458 }
46459 return getWeights;
46460 }()
46461 }, {
46462 key: "setWeights",
46463 value: function () {
46464 var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(weightValues) {
46465 var _this3 = this;
46466 var variableCount, trainable;
46467 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
46468 while (1) switch (_context2.prev = _context2.next) {
46469 case 0:
46470 _context2.next = 2;
46471 return this.extractIterations(weightValues);
46472 case 2:
46473 weightValues = _context2.sent;
46474 tidy(function () {
46475 _this3.accBeta1.assign(pow$3(_this3.beta1, _this3.iterations_ + 1));
46476 _this3.accBeta2.assign(pow$3(_this3.beta2, _this3.iterations_ + 1));
46477 });
46478 variableCount = weightValues.length / 2;
46479 trainable = false;
46480 this.accumulatedFirstMoment = weightValues.slice(0, variableCount).map(function (v) {
46481 return {
46482 originalName: v.name,
46483 variable: v.tensor.variable(trainable)
46484 };
46485 });
46486 this.accumulatedSecondMoment = weightValues.slice(variableCount, variableCount * 2).map(function (v) {
46487 return {
46488 originalName: v.name,
46489 variable: v.tensor.variable(trainable)
46490 };
46491 });
46492 case 8:
46493 case "end":
46494 return _context2.stop();
46495 }
46496 }, _callee2, this);
46497 }));
46498 function setWeights(_x) {
46499 return _setWeights.apply(this, arguments);
46500 }
46501 return setWeights;
46502 }()
46503 }, {
46504 key: "getConfig",
46505 value: function getConfig() {
46506 return {
46507 'learningRate': this.learningRate,
46508 'beta1': this.beta1,
46509 'beta2': this.beta2,
46510 'epsilon': this.epsilon
46511 };
46512 }
46513 /** @nocollapse */
46514 }], [{
46515 key: "className",
46516 get: /** @nocollapse */
46517 function get() {
46518 // Name matters for Python compatibility.
46519 // This is a getter instead of a property because when it's a property, it
46520 // prevents the entire class from being tree-shaken.
46521 return 'Adam';
46522 }
46523 }, {
46524 key: "fromConfig",
46525 value: function fromConfig(cls, config) {
46526 return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
46527 }
46528 }]);
46529 return AdamOptimizer;
46530 }(Optimizer);
46531
var AdamaxOptimizer = /*#__PURE__*/function (_Optimizer) {
  _inherits(AdamaxOptimizer, _Optimizer);
  var _super = _createSuper(AdamaxOptimizer);

  /**
   * Adamax optimizer: an Adam variant whose second slot tracks an
   * exponentially weighted infinity norm (running max of |gradient|) instead
   * of a running mean of squared gradients.
   *
   * @param learningRate Base step size.
   * @param beta1 Decay rate of the first-moment estimate.
   * @param beta2 Decay rate applied to the weighted infinity norm.
   * @param epsilon Optional 4th argument added to the denominator; null or
   *     omitted selects the backend's epsilon.
   * @param decay Optional 5th argument (default 0.0); the effective step is
   *     learningRate / (1 + iteration * decay).
   */
  function AdamaxOptimizer(learningRate, beta1, beta2) {
    // Optional trailing parameters: omission or an explicit undefined picks
    // the default.
    var epsilon = arguments[3] !== undefined ? arguments[3] : null;
    var decay = arguments[4] !== undefined ? arguments[4] : 0.0;
    _classCallCheck(this, AdamaxOptimizer);
    var instance = _super.call(this);
    instance.learningRate = learningRate;
    instance.beta1 = beta1;
    instance.beta2 = beta2;
    instance.epsilon = epsilon;
    instance.decay = decay;
    // Per-variable slot lists, indexed by position in the gradient list.
    instance.accumulatedFirstMoment = [];
    instance.accumulatedWeightedInfNorm = [];
    tidy(function () {
      // Step counter (drives the decay) and beta1^t (bias correction).
      instance.iteration = scalar(0).variable();
      instance.accBeta1 = scalar(beta1).variable();
    });
    if (epsilon == null) {
      instance.epsilon = ENGINE.backend.epsilon();
    }
    return instance;
  }
  _createClass(AdamaxOptimizer, [{
    key: "applyGradients",
    value: function applyGradients(variableGradients) {
      var optimizer = this;
      var listForm = Array.isArray(variableGradients);
      var names = listForm ? variableGradients.map(function (entry) {
        return entry.name;
      }) : Object.keys(variableGradients);
      tidy(function () {
        // 1 - beta1^t: first-moment bias correction for this step.
        var biasCorrection = sub$2(1, optimizer.accBeta1);
        // Negated, decayed step size: -lr / (1 + iteration * decay).
        var stepSize = div$1(-optimizer.learningRate, add$3(mul(optimizer.iteration, optimizer.decay), 1));
        names.forEach(function (name, i) {
          var param = ENGINE.registeredVariables[name];
          // Slots are created lazily (zero-filled, non-trainable) the first
          // time an index is seen, even when its gradient is missing.
          if (optimizer.accumulatedFirstMoment[i] == null) {
            optimizer.accumulatedFirstMoment[i] = {
              originalName: "".concat(name, "/m"),
              variable: zerosLike$3(param).variable(false)
            };
          }
          if (optimizer.accumulatedWeightedInfNorm[i] == null) {
            optimizer.accumulatedWeightedInfNorm[i] = {
              originalName: "".concat(name, "/v"),
              variable: zerosLike$3(param).variable(false)
            };
          }
          var grad = listForm ? variableGradients[i].tensor : variableGradients[name];
          if (grad == null) {
            return;
          }
          var m = optimizer.accumulatedFirstMoment[i].variable;
          var u = optimizer.accumulatedWeightedInfNorm[i].variable;
          // m_t = beta1 * m_{t-1} + (1 - beta1) * g
          var mNext = add$3(mul(m, optimizer.beta1), mul(grad, 1 - optimizer.beta1));
          // u_t = max(beta2 * u_{t-1}, |g|)
          var decayedNorm = mul(u, optimizer.beta2);
          var absGrad = abs$2(grad);
          var uNext = maximum$4(decayedNorm, absGrad);
          m.assign(mNext);
          u.assign(uNext);
          // param += (stepSize / (1 - beta1^t)) * m_t / (u_t + epsilon)
          var scaledStep = div$1(stepSize, biasCorrection);
          var direction = div$1(mNext, add$3(uNext, optimizer.epsilon));
          param.assign(add$3(mul(scaledStep, direction), param));
        });
        optimizer.iteration.assign(add$3(optimizer.iteration, 1));
        optimizer.accBeta1.assign(mul(optimizer.accBeta1, optimizer.beta1));
      });
      this.incrementIterations();
    }
  }, {
    key: "dispose",
    value: function dispose$1() {
      this.accBeta1.dispose();
      this.iteration.dispose();
      // Release both slot lists when present.
      var slotGroups = [this.accumulatedFirstMoment, this.accumulatedWeightedInfNorm];
      slotGroups.forEach(function (group) {
        if (group != null) {
          dispose(group.map(function (slot) {
            return slot.variable;
          }));
        }
      });
    }
  }, {
    key: "getWeights",
    value: function () {
      // Slot serialization is not supported for Adamax: the returned promise
      // always rejects.
      var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              throw new Error('getWeights() is not implemented for Adamax yet.');
            case 1:
            case "end":
              return _context.stop();
          }
        }, _callee);
      }));
      function getWeights() {
        return _getWeights.apply(this, arguments);
      }
      return getWeights;
    }()
  }, {
    key: "setWeights",
    value: function () {
      // Restoring slots is not supported for Adamax: the returned promise
      // always rejects.
      var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(weightValues) {
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              throw new Error('setWeights() is not implemented for Adamax yet.');
            case 1:
            case "end":
              return _context2.stop();
          }
        }, _callee2);
      }));
      function setWeights(_x) {
        return _setWeights.apply(this, arguments);
      }
      return setWeights;
    }()
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {};
      config['learningRate'] = this.learningRate;
      config['beta1'] = this.beta1;
      config['beta2'] = this.beta2;
      config['epsilon'] = this.epsilon;
      config['decay'] = this.decay;
      return config;
    }
    /** @nocollapse */
  }], [{
    key: "className",
    get: /** @nocollapse */
    function get() {
      // Registered name shared with the Python implementation; a getter
      // (not a static data property) keeps the class tree-shakeable.
      var registeredName = 'Adamax';
      return registeredName;
    }
  }, {
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      // Constructor argument order: (learningRate, beta1, beta2, epsilon, decay).
      var learningRate = config['learningRate'];
      var beta1 = config['beta1'];
      var beta2 = config['beta2'];
      var epsilon = config['epsilon'];
      var decay = config['decay'];
      return new cls(learningRate, beta1, beta2, epsilon, decay);
    }
  }]);
  return AdamaxOptimizer;
}(Optimizer);
46685
/** @doclink Optimizer */
var SGDOptimizer = /*#__PURE__*/function (_Optimizer) {
  _inherits(SGDOptimizer, _Optimizer);
  var _super = _createSuper(SGDOptimizer);

  /**
   * Plain stochastic gradient descent: value <- value - learningRate * gradient.
   *
   * @param learningRate Step size.
   */
  function SGDOptimizer(learningRate) {
    _classCallCheck(this, SGDOptimizer);
    var instance = _super.call(this);
    instance.learningRate = learningRate;
    // Also caches -learningRate as a kept scalar in `c`.
    instance.setLearningRate(learningRate);
    return instance;
  }
  _createClass(SGDOptimizer, [{
    key: "applyGradients",
    value: function applyGradients(variableGradients) {
      var optimizer = this;
      var listForm = Array.isArray(variableGradients);
      var names = listForm ? variableGradients.map(function (entry) {
        return entry.name;
      }) : Object.keys(variableGradients);
      for (var i = 0; i < names.length; i++) {
        var name = names[i];
        var grad = listForm ? variableGradients[i].tensor : variableGradients[name];
        // Variables without a gradient are left untouched.
        if (grad == null) {
          continue;
        }
        var target = ENGINE.registeredVariables[name];
        // tidy() runs synchronously, so capturing the loop locals is safe.
        tidy(function () {
          // c == -learningRate, so this computes value - learningRate * grad.
          target.assign(add$3(mul(optimizer.c, grad), target));
        });
      }
      this.incrementIterations();
    }
    /**
     * Sets the learning rate of the optimizer.
     */
  }, {
    key: "setLearningRate",
    value: function setLearningRate(learningRate) {
      this.learningRate = learningRate;
      // Free the previously cached scalar before replacing it.
      var previous = this.c;
      if (previous != null) {
        previous.dispose();
      }
      // keep() so the scalar survives any enclosing tidy() scope.
      this.c = keep(scalar(-learningRate));
    }
  }, {
    key: "dispose",
    value: function dispose() {
      var cached = this.c;
      cached.dispose();
    }
  }, {
    key: "getWeights",
    value: function () {
      // SGD has no slot variables: only the saved iteration count is returned.
      var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              _context.next = 2;
              return this.saveIterations();
            case 2:
              _context.t0 = _context.sent;
              return _context.abrupt("return", [_context.t0]);
            case 4:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function getWeights() {
        return _getWeights.apply(this, arguments);
      }
      return getWeights;
    }()
  }, {
    key: "setWeights",
    value: function () {
      // Restores the iteration count; any additional entries are rejected.
      var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(weightValues) {
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              _context2.next = 2;
              return this.extractIterations(weightValues);
            case 2:
              weightValues = _context2.sent;
              if (weightValues.length === 0) {
                _context2.next = 5;
                break;
              }
              throw new Error('SGD optimizer does not have settable weights.');
            case 5:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this);
      }));
      function setWeights(_x) {
        return _setWeights.apply(this, arguments);
      }
      return setWeights;
    }()
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {};
      config['learningRate'] = this.learningRate;
      return config;
    }
    /** @nocollapse */
  }], [{
    key: "className",
    get: /** @nocollapse */
    function get() {
      // Registered name shared with the Python implementation; a getter
      // (not a static data property) keeps the class tree-shakeable.
      var registeredName = 'SGD';
      return registeredName;
    }
  }, {
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      var learningRate = config['learningRate'];
      return new cls(learningRate);
    }
  }]);
  return SGDOptimizer;
}(Optimizer);
46810
/** @doclink Optimizer */
var MomentumOptimizer = /*#__PURE__*/function (_SGDOptimizer) {
  _inherits(MomentumOptimizer, _SGDOptimizer);
  var _super = _createSuper(MomentumOptimizer);

  /**
   * SGD with (optionally Nesterov) momentum.
   *
   * @param learningRate Step size.
   * @param momentum Coefficient applied to the running accumulation.
   * @param useNesterov Optional 3rd argument (default false); use the
   *     Nesterov form of the update.
   */
  function MomentumOptimizer(learningRate, momentum) {
    var _this;
    var useNesterov = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
    _classCallCheck(this, MomentumOptimizer);
    _this = _super.call(this, learningRate);
    _this.learningRate = learningRate;
    _this.momentum = momentum;
    _this.useNesterov = useNesterov;
    _this.accumulations = [];
    // Cached momentum scalar used by applyGradients(). keep() it (as
    // SGDOptimizer does for `c`) so constructing the optimizer inside tidy()
    // does not dispose it before the first update.
    _this.m = keep(scalar(_this.momentum));
    return _this;
  }
  _createClass(MomentumOptimizer, [{
    key: "applyGradients",
    value: function applyGradients(variableGradients) {
      var _this2 = this;
      var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
        return item.name;
      }) : Object.keys(variableGradients);
      variableNames.forEach(function (name, i) {
        var value = ENGINE.registeredVariables[name];
        // Lazily allocate a zero-filled, non-trainable accumulation slot.
        if (_this2.accumulations[i] == null) {
          var trainable = false;
          _this2.accumulations[i] = {
            originalName: "".concat(name, "/momentum"),
            variable: tidy(function () {
              return zerosLike$3(value).variable(trainable);
            })
          };
        }
        var accumulation = _this2.accumulations[i].variable;
        var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
        if (gradient == null) {
          return;
        }
        tidy(function () {
          var newValue;
          // acc_t = m * acc_{t-1} + g
          var newAccumulation = add$3(mul(_this2.m, accumulation), gradient);
          // c == -learningRate (set by SGDOptimizer.setLearningRate).
          if (_this2.useNesterov) {
            newValue = add$3(mul(_this2.c, add$3(gradient, mul(newAccumulation, _this2.m))), value);
          } else {
            newValue = add$3(mul(_this2.c, newAccumulation), value);
          }
          accumulation.assign(newAccumulation);
          value.assign(newValue);
        });
      });
      this.incrementIterations();
    }
  }, {
    key: "dispose",
    value: function dispose$1() {
      this.m.dispose();
      // This override replaces SGDOptimizer's dispose(), so the inherited
      // learning-rate scalar `c` must be released here as well.
      if (this.c != null) {
        this.c.dispose();
      }
      if (this.accumulations != null) {
        dispose(this.accumulations.map(function (v) {
          return v.variable;
        }));
      }
    }
    /**
     * Sets the momentum of the optimizer.
     *
     * Rebuilds the cached scalar `m` as well, so later applyGradients() calls
     * actually use the new value.
     *
     * @param momentum
     */
  }, {
    key: "setMomentum",
    value: function setMomentum(momentum) {
      this.momentum = momentum;
      if (this.m != null) {
        this.m.dispose();
      }
      this.m = keep(scalar(momentum));
    }
  }, {
    key: "getWeights",
    value: function () {
      // Returns the saved iteration count followed by one accumulation slot
      // per variable.
      var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              _context.next = 2;
              return this.saveIterations();
            case 2:
              _context.t0 = _context.sent;
              return _context.abrupt("return", [_context.t0].concat(this.accumulations.map(function (v) {
                return {
                  name: v.originalName,
                  tensor: v.variable
                };
              })));
            case 4:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function getWeights() {
        return _getWeights.apply(this, arguments);
      }
      return getWeights;
    }()
  }, {
    key: "setWeights",
    value: function () {
      // Restores the iteration count, then rebuilds the accumulation slots
      // from the remaining entries.
      var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(weightValues) {
        var trainable;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              _context2.next = 2;
              return this.extractIterations(weightValues);
            case 2:
              weightValues = _context2.sent;
              trainable = false;
              this.accumulations = weightValues.map(function (v) {
                return {
                  originalName: v.name,
                  variable: v.tensor.variable(trainable)
                };
              });
            case 5:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this);
      }));
      function setWeights(_x) {
        return _setWeights.apply(this, arguments);
      }
      return setWeights;
    }()
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        'learningRate': this.learningRate,
        'momentum': this.momentum,
        'useNesterov': this.useNesterov
      };
    }
    /** @nocollapse */
  }], [{
    key: "className",
    get: /** @nocollapse */
    function get() {
      // Name matters for Python compatibility.
      // This is a getter instead of a property because when it's a property, it
      // prevents the entire class from being tree-shaken.
      return 'Momentum';
    }
  }, {
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
    }
  }]);
  return MomentumOptimizer;
}(SGDOptimizer);
46970
/** @doclink Optimizer */
var RMSPropOptimizer = /*#__PURE__*/function (_Optimizer) {
  _inherits(RMSPropOptimizer, _Optimizer);
  var _super = _createSuper(RMSPropOptimizer);

  /**
   * RMSProp optimizer, optionally centered.
   *
   * @param learningRate Step size (required).
   * @param decay Optional 2nd argument (default 0.9); discount factor of the
   *     running averages.
   * @param momentum Optional 3rd argument (default 0.0).
   * @param epsilon Optional 4th argument added inside the square root; null
   *     or omitted selects the backend's epsilon.
   * @param centered Optional 5th argument (default false); when true the
   *     gradient is normalized by the variance estimate ms - mg^2 instead of
   *     the raw mean square ms.
   * @throws Error if learningRate is null or undefined.
   */
  function RMSPropOptimizer(learningRate) {
    var _this;
    var decay = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.9;
    var momentum = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0.0;
    var epsilon = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : null;
    var centered = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
    _classCallCheck(this, RMSPropOptimizer);
    _this = _super.call(this);
    // Validate before touching the backend so a bad call fails fast.
    if (learningRate == null) {
      throw new Error("learningRate for RMSPropOptimizer must be defined.");
    }
    _this.learningRate = learningRate;
    _this.decay = decay;
    _this.momentum = momentum;
    _this.epsilon = epsilon;
    // Per-variable slot lists, indexed by position in the gradient list.
    _this.accumulatedMeanSquares = [];
    _this.accumulatedMoments = [];
    _this.accumulatedMeanGrads = [];
    _this.centered = centered;
    if (epsilon == null) {
      _this.epsilon = ENGINE.backend.epsilon();
    }
    return _this;
  }
  _createClass(RMSPropOptimizer, [{
    key: "applyGradients",
    value: function applyGradients(variableGradients) {
      var _this2 = this;
      var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
        return item.name;
      }) : Object.keys(variableGradients);
      variableNames.forEach(function (name, i) {
        var value = ENGINE.registeredVariables[name];
        var trainable = false;
        // Slots are allocated lazily (zero-filled, non-trainable).
        if (_this2.accumulatedMeanSquares[i] == null) {
          _this2.accumulatedMeanSquares[i] = {
            originalName: "".concat(name, "/rms"),
            variable: tidy(function () {
              return zerosLike$3(value).variable(trainable);
            })
          };
        }
        if (_this2.accumulatedMoments[i] == null) {
          _this2.accumulatedMoments[i] = {
            originalName: "".concat(name, "/momentum"),
            variable: tidy(function () {
              return zerosLike$3(value).variable(trainable);
            })
          };
        }
        if (_this2.accumulatedMeanGrads[i] == null && _this2.centered) {
          _this2.accumulatedMeanGrads[i] = {
            originalName: "".concat(name, "/mg"),
            variable: tidy(function () {
              return zerosLike$3(value).variable(trainable);
            })
          };
        }
        var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
        if (gradient == null) {
          return;
        }
        var accumulatedMeanSquare = _this2.accumulatedMeanSquares[i].variable;
        var accumulatedMoments = _this2.accumulatedMoments[i].variable;
        tidy(function () {
          // ms_t = decay * ms_{t-1} + (1 - decay) * g^2 (shared by both modes).
          var newAccumulatedMeanSquare = add$3(mul(accumulatedMeanSquare, _this2.decay), mul(square$2(gradient), 1 - _this2.decay));
          var newAccumulatedMoments;
          if (_this2.centered) {
            var accumulatedMeanGrad = _this2.accumulatedMeanGrads[i].variable;
            // mg_t = decay * mg_{t-1} + (1 - decay) * g
            var newAccumulatedMeanGrad = add$3(mul(accumulatedMeanGrad, _this2.decay), mul(gradient, 1 - _this2.decay));
            // Centered denominator: sqrt(ms_t - mg_t^2 + epsilon). Epsilon is
            // ADDED to the variance estimate; subtracting it makes the radicand
            // negative (NaN update) wherever ms_t - mg_t^2 is ~0, e.g. for an
            // element whose gradients have all been zero.
            var variance = sub$2(newAccumulatedMeanSquare, square$2(newAccumulatedMeanGrad));
            var gradContribution = div$1(mul(gradient, _this2.learningRate), sqrt$2(add$3(variance, _this2.epsilon)));
            newAccumulatedMoments = add$3(mul(accumulatedMoments, _this2.momentum), gradContribution);
            accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
            accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
          } else {
            // Plain denominator: sqrt(ms_t + epsilon).
            newAccumulatedMoments = add$3(mul(accumulatedMoments, _this2.momentum), div$1(mul(gradient, _this2.learningRate), sqrt$2(add$3(newAccumulatedMeanSquare, _this2.epsilon))));
            accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
          }
          // mom_t = momentum * mom_{t-1} + lr * g / denominator; value -= mom_t
          accumulatedMoments.assign(newAccumulatedMoments);
          value.assign(sub$2(value, newAccumulatedMoments));
        });
      });
      this.incrementIterations();
    }
  }, {
    key: "dispose",
    value: function dispose$1() {
      if (this.accumulatedMeanSquares != null) {
        dispose(this.accumulatedMeanSquares.map(function (v) {
          return v.variable;
        }));
      }
      // Mean-gradient slots only exist in centered mode.
      if (this.accumulatedMeanGrads != null && this.centered) {
        dispose(this.accumulatedMeanGrads.map(function (v) {
          return v.variable;
        }));
      }
      if (this.accumulatedMoments != null) {
        dispose(this.accumulatedMoments.map(function (v) {
          return v.variable;
        }));
      }
    }
  }, {
    key: "getWeights",
    value: function () {
      // Returns the iteration count, then mean squares, moments and (when
      // centered) mean gradients.
      var _getWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        var variables;
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              // Order matters for Python compatibility.
              variables = [].concat(_toConsumableArray(this.accumulatedMeanSquares), _toConsumableArray(this.accumulatedMoments));
              if (this.centered) {
                variables.push.apply(variables, _toConsumableArray(this.accumulatedMeanGrads));
              }
              _context.next = 4;
              return this.saveIterations();
            case 4:
              _context.t0 = _context.sent;
              return _context.abrupt("return", [_context.t0].concat(variables.map(function (v) {
                return {
                  name: v.originalName,
                  tensor: v.variable
                };
              })));
            case 6:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function getWeights() {
        return _getWeights.apply(this, arguments);
      }
      return getWeights;
    }()
  }, {
    key: "setWeights",
    value: function () {
      // Inverse of getWeights(): splits the remaining entries into two or
      // three equal groups depending on `centered`.
      var _setWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(weightValues) {
        var variableCount, trainable;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              _context2.next = 2;
              return this.extractIterations(weightValues);
            case 2:
              weightValues = _context2.sent;
              variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
              trainable = false;
              this.accumulatedMeanSquares = weightValues.slice(0, variableCount).map(function (v) {
                return {
                  originalName: v.name,
                  variable: v.tensor.variable(trainable)
                };
              });
              this.accumulatedMoments = weightValues.slice(variableCount, variableCount * 2).map(function (v) {
                return {
                  originalName: v.name,
                  variable: v.tensor.variable(trainable)
                };
              });
              if (this.centered) {
                this.accumulatedMeanGrads = weightValues.slice(variableCount * 2, variableCount * 3).map(function (v) {
                  return {
                    originalName: v.name,
                    variable: v.tensor.variable(trainable)
                  };
                });
              }
            case 8:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this);
      }));
      function setWeights(_x) {
        return _setWeights.apply(this, arguments);
      }
      return setWeights;
    }()
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        'learningRate': this.learningRate,
        'decay': this.decay,
        'momentum': this.momentum,
        'epsilon': this.epsilon,
        'centered': this.centered
      };
    }
    /** @nocollapse */
  }], [{
    key: "className",
    get: /** @nocollapse */
    function get() {
      // Name matters for Python compatibility.
      // This is a getter instead of a property because when it's a property, it
      // prevents the entire class from being tree-shaken.
      return 'RMSProp';
    }
  }, {
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
    }
  }]);
  return RMSPropOptimizer;
}(Optimizer);
47192
// Built-in optimizer classes, registered by registerOptimizers() below.
var OPTIMIZERS = [AdadeltaOptimizer, AdagradOptimizer, AdamOptimizer, AdamaxOptimizer, MomentumOptimizer, RMSPropOptimizer, SGDOptimizer];

/**
 * Passes each built-in optimizer class, in OPTIMIZERS order, to
 * registerClass(). Any error thrown by registerClass() propagates.
 */
function registerOptimizers() {
  for (var idx = 0; idx < OPTIMIZERS.length; idx++) {
    registerClass(OPTIMIZERS[idx]);
  }
}
47208
// Default file-name pieces used by BrowserDownloads.
var DEFAULT_FILE_NAME_PREFIX = 'model';
var DEFAULT_JSON_EXTENSION_NAME = '.json';
var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';

/**
 * Runs `f` after yielding to the event loop through a setTimeout with no
 * delay argument.
 *
 * @param f Callback to run; it receives the (undefined) timer resolution value.
 * @returns A promise that settles with f's result.
 */
function defer(f) {
  var afterTimer = new Promise(function (resolve) {
    setTimeout(resolve);
  });
  return afterTimer.then(f);
}
var BrowserDownloads = /*#__PURE__*/function () {
  /**
   * IOHandler that saves model artifacts by triggering browser downloads of a
   * model JSON file and a weights binary file.
   *
   * @param fileNamePrefix Prefix for the downloaded file names. A leading
   *     `downloads://` scheme is stripped; null, undefined or an empty prefix
   *     falls back to DEFAULT_FILE_NAME_PREFIX.
   * @throws Error if the current environment is not a browser.
   */
  function BrowserDownloads(fileNamePrefix) {
    _classCallCheck(this, BrowserDownloads);
    if (!env().getBool('IS_BROWSER')) {
      // TODO(cais): Provide info on what IOHandlers are available under the
      // current environment.
      throw new Error('browserDownloads() cannot proceed because the current environment ' + 'is not a browser.');
    }
    // Check for null before calling string methods, so a missing prefix falls
    // back to the default below instead of throwing a TypeError here.
    if (fileNamePrefix != null && fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
      fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
    }
    if (fileNamePrefix == null || fileNamePrefix.length === 0) {
      fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
    }
    this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
    this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
  }
  _createClass(BrowserDownloads, [{
    key: "save",
    value: function () {
      /**
       * Triggers the downloads.
       *
       * @param modelArtifacts Artifacts to save; a binary (ArrayBuffer)
       *     modelTopology is rejected.
       * @returns Promise of {modelArtifactsInfo}.
       */
      var _save = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(modelArtifacts) {
        var weightBuffer, weightsURL, weightsManifest, modelJSON, modelJsonURL, jsonAnchor, weightDataAnchor;
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              if (!(typeof document === 'undefined')) {
                _context.next = 2;
                break;
              }
              throw new Error('Browser downloads are not supported in ' + 'this environment since `document` is not present');
            case 2:
              // TODO(mattsoulanille): Support saving models over 2GB that exceed
              // Chrome's ArrayBuffer size limit.
              weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
              weightsURL = window.URL.createObjectURL(new Blob([weightBuffer], {
                type: 'application/octet-stream'
              }));
              if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
                _context.next = 8;
                break;
              }
              throw new Error('BrowserDownloads.save() does not support saving model topology ' + 'in binary formats yet.');
            case 8:
              // The manifest points at the weights file by relative path.
              weightsManifest = [{
                paths: ['./' + this.weightDataFileName],
                weights: modelArtifacts.weightSpecs
              }];
              modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
              modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], {
                type: 'application/json'
              }));
              // If anchor elements are not provided, create them without attaching them
              // to parents, so that the downloaded file names can be controlled.
              jsonAnchor = this.modelJsonAnchor == null ? document.createElement('a') : this.modelJsonAnchor;
              jsonAnchor.download = this.modelJsonFileName;
              jsonAnchor.href = modelJsonURL;
              // Trigger downloads by evoking a click event on the download anchors.
              // When multiple downloads are started synchronously, Firefox will only
              // save the last one.
              _context.next = 16;
              return defer(function () {
                return jsonAnchor.dispatchEvent(new MouseEvent('click'));
              });
            case 16:
              // The weights download is only triggered when weight data exists.
              if (!(modelArtifacts.weightData != null)) {
                _context.next = 22;
                break;
              }
              weightDataAnchor = this.weightDataAnchor == null ? document.createElement('a') : this.weightDataAnchor;
              weightDataAnchor.download = this.weightDataFileName;
              weightDataAnchor.href = weightsURL;
              _context.next = 22;
              return defer(function () {
                return weightDataAnchor.dispatchEvent(new MouseEvent('click'));
              });
            case 22:
              return _context.abrupt("return", {
                modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)
              });
            case 23:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function save(_x) {
        return _save.apply(this, arguments);
      }
      return save;
    }()
  }]);
  return BrowserDownloads;
}();
BrowserDownloads.URL_SCHEME = 'downloads://';
/**
 * IOHandler that loads model artifacts from user-supplied `File` objects
 * (for example, files picked through `<input type="file">` elements).
 *
 * The first file must be the model JSON; any further files are binary weight
 * files whose basenames must match the paths in the JSON's weights manifest.
 */
var BrowserFiles = /*#__PURE__*/function () {
  /**
   * @param files Array of `File`s: the model JSON first, then zero or more
   *     binary weight files.
   * @throws Error if `files` is null/undefined or empty.
   */
  function BrowserFiles(files) {
    _classCallCheck(this, BrowserFiles);
    if (files == null || files.length < 1) {
      throw new Error("When calling browserFiles, at least 1 file is required, " + "but received ".concat(files));
    }
    this.jsonFile = files[0];
    this.weightsFiles = files.slice(1);
  }
  _createClass(BrowserFiles, [{
    /**
     * Reads the model JSON file and, when weight files were supplied, the
     * weights it references.
     *
     * @returns A Promise of the model artifacts; only `{modelTopology}` when no
     *     weight files were supplied. Rejects with an Error when the JSON file
     *     cannot be read or parsed, or when it lacks `modelTopology` or
     *     `weightsManifest`.
     */
    key: "load",
    value: function () {
      var _load = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
        var _this = this;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              return _context2.abrupt("return", new Promise(function (resolve, reject) {
                var jsonReader = new FileReader();
                jsonReader.onload = function (event) {
                  // This handler runs outside the Promise executor. Anything it
                  // throws (e.g. a SyntaxError from JSON.parse, or a TypeError
                  // when the file holds `null`) would not reject the Promise
                  // and would leave it pending forever, so turn it into a
                  // rejection.
                  try {
                    // tslint:disable-next-line:no-any
                    var modelJSON = JSON.parse(event.target.result);
                    var modelTopology = modelJSON.modelTopology;
                    if (modelTopology == null) {
                      reject(new Error("modelTopology field is missing from file ".concat(_this.jsonFile.name)));
                      return;
                    }
                    var weightsManifest = modelJSON.weightsManifest;
                    if (weightsManifest == null) {
                      reject(new Error("weightManifest field is missing from file ".concat(_this.jsonFile.name)));
                      return;
                    }
                    if (_this.weightsFiles.length === 0) {
                      resolve({
                        modelTopology: modelTopology
                      });
                      return;
                    }
                    var modelArtifactsPromise = getModelArtifactsForJSON(modelJSON, function (weightsManifest) {
                      return _this.loadWeights(weightsManifest);
                    });
                    resolve(modelArtifactsPromise);
                  } catch (err) {
                    reject(err);
                  }
                };
                jsonReader.onerror = function (error) {
                  return reject(new Error("Failed to read model topology and weights manifest JSON " + "from file '".concat(_this.jsonFile.name, "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only."));
                };
                jsonReader.readAsText(_this.jsonFile);
              }));
            case 1:
            case "end":
              return _context2.stop();
          }
        }, _callee2);
      }));
      function load() {
        return _load.apply(this, arguments);
      }
      return load;
    }()
  }, {
    /**
     * Loads every weight file referenced by `weightsManifest`.
     *
     * @param weightsManifest Array of weight groups, each with `weights`
     *     (weight specs) and `paths` (weight file paths).
     * @returns A Promise of `[weightSpecs, buffers]`. `weightSpecs` holds all
     *     groups' specs concatenated in manifest order. `buffers` holds one
     *     ArrayBuffer per manifest path, in the same order.
     * @throws Error (synchronously) if the manifest and the supplied weight
     *     files do not match one-to-one by basename.
     */
    key: "loadWeights",
    value: function loadWeights(weightsManifest) {
      var _this2 = this;
      var weightSpecs = [];
      var paths = [];
      var _iterator = _createForOfIteratorHelper(weightsManifest),
        _step;
      try {
        for (_iterator.s(); !(_step = _iterator.n()).done;) {
          var entry = _step.value;
          weightSpecs.push.apply(weightSpecs, _toConsumableArray(entry.weights));
          paths.push.apply(paths, _toConsumableArray(entry.paths));
        }
      } catch (err) {
        _iterator.e(err);
      } finally {
        _iterator.f();
      }
      var pathToFile = this.checkManifestAndWeightFiles(weightsManifest);
      var promises = paths.map(function (path) {
        return _this2.loadWeightsFile(path, pathToFile[path]);
      });
      return Promise.all(promises).then(function (buffers) {
        return [weightSpecs, buffers];
      });
    }
  }, {
    /**
     * Reads one weight file into an ArrayBuffer.
     *
     * @param path Manifest path of the file; used only in the error message.
     * @param file The `File` to read.
     * @returns A Promise of the file contents as an ArrayBuffer. Rejects with
     *     an Error if the read fails.
     */
    key: "loadWeightsFile",
    value: function loadWeightsFile(path, file) {
      return new Promise(function (resolve, reject) {
        var weightFileReader = new FileReader();
        weightFileReader.onload = function (event) {
          // tslint:disable-next-line:no-any
          var weightData = event.target.result;
          resolve(weightData);
        };
        weightFileReader.onerror = function (error) {
          return reject(new Error("Failed to read weights data from file of path '".concat(path, "'.")));
        };
        weightFileReader.readAsArrayBuffer(file);
      });
    }
  }, {
    /**
     * Checks that the weights manifest and the supplied weight files match
     * one-to-one by file basename.
     *
     * @param manifest Array of weight groups, each with `paths`.
     * @returns Object mapping each manifest path to its matching `File`.
     * @throws Error if a basename appears twice in the manifest, if a manifest
     *     path has no file with the same basename, or if the number of files
     *     differs from the number of manifest paths.
     */
    key: "checkManifestAndWeightFiles",
    value: function checkManifestAndWeightFiles(manifest) {
      var _this3 = this;
      var basenames = [];
      var fileNames = this.weightsFiles.map(function (file) {
        return basename(file.name);
      });
      var pathToFile = {};
      var _iterator2 = _createForOfIteratorHelper(manifest),
        _step2;
      try {
        for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
          var group = _step2.value;
          group.paths.forEach(function (path) {
            var pathBasename = basename(path);
            if (basenames.indexOf(pathBasename) !== -1) {
              throw new Error("Duplicate file basename found in weights manifest: " + "'".concat(pathBasename, "'"));
            }
            basenames.push(pathBasename);
            var fileIndex = fileNames.indexOf(pathBasename);
            if (fileIndex === -1) {
              throw new Error("Weight file with basename '".concat(pathBasename, "' is not provided."));
            }
            pathToFile[path] = _this3.weightsFiles[fileIndex];
          });
        }
      } catch (err) {
        _iterator2.e(err);
      } finally {
        _iterator2.f();
      }
      if (basenames.length !== this.weightsFiles.length) {
        throw new Error("Mismatch in the number of files in weights manifest " + "(".concat(basenames.length, ") and the number of weight files provided ") + "(".concat(this.weightsFiles.length, ")."));
      }
      return pathToFile;
    }
  }]);
  return BrowserFiles;
}();
/**
 * Save router for `downloads://` URLs.
 *
 * @param url The URL handed to the router: a single string or an array.
 * @returns When running in a browser and `url` is a single string that starts
 *     with `BrowserDownloads.URL_SCHEME`: a `browserDownloads` IOHandler whose
 *     file name prefix is `url` with the scheme removed. Otherwise `null`.
 */
var browserDownloadsRouter = function browserDownloadsRouter(url) {
  var runningInBrowser = env().getBool('IS_BROWSER');
  if (runningInBrowser && !Array.isArray(url)) {
    var scheme = BrowserDownloads.URL_SCHEME;
    if (url.startsWith(scheme)) {
      return browserDownloads(url.slice(scheme.length));
    }
  }
  return null;
};
// Register browserDownloadsRouter with IORouterRegistry as a save router
// (runs once, when this module is evaluated).
IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
/**
 * Creates an IOHandler that triggers file downloads from the browser.
 *
 * The returned `IOHandler` instance can be used as model exporting methods such
 * as `tf.Model.save` and supports only saving.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense(
 *     {units: 1, inputShape: [10], activation: 'sigmoid'}));
 * const saveResult = await model.save('downloads://mymodel');
 * // This will trigger downloading of two files:
 * //   'mymodel.json' and 'mymodel.weights.bin'.
 * console.log(saveResult);
 * ```
 *
 * @param fileNamePrefix Prefix name of the files to be downloaded. For use with
 *   `tf.Model`, `fileNamePrefix` should follow either of the following two
 *   formats:
 *   1. `null` or `undefined`, in which case the default prefix `'model'` is
 *      used, giving the file names:
 *      - 'model.json' for the JSON file containing the model topology and
 *        weights manifest.
 *      - 'model.weights.bin' for the binary file containing the binary weight
 *        values.
 *   2. A single string or an Array of a single string, as the file name prefix.
 *      For example, if `'foo'` is provided, the downloaded JSON
 *      file and binary weights file will be named 'foo.json' and
 *      'foo.weights.bin', respectively.
 * @param config Additional configuration for triggering downloads.
 * @returns An instance of `BrowserDownloads` `IOHandler`.
 *
 * @doc {
 *   heading: 'Models',
 *   subheading: 'Loading',
 *   namespace: 'io',
 *   ignoreCI: true
 * }
 */
function browserDownloads() {
  // `!= null` rather than `!== undefined`: the documented contract gives an
  // explicit `null` the same default prefix as an omitted argument, which a
  // plain default parameter would not do.
  var fileNamePrefix = arguments.length > 0 && arguments[0] != null ? arguments[0] : 'model';
  return new BrowserDownloads(fileNamePrefix);
}
/**
 * Creates an IOHandler that loads model artifacts from user-selected files.
 *
 * Use it to load from files such as ones the user picked in the browser.
 * Combined with `tf.loadLayersModel`, the loaded artifacts can be turned into
 * a `tf.LayersModel` (Keras-style).
 *
 * ```js
 * // Note: This code snippet won't run properly without the actual file input
 * // elements in the HTML DOM.
 *
 * // Suppose there are two HTML file input (`<input type="file" ...>`)
 * // elements.
 * const uploadJSONInput = document.getElementById('upload-json');
 * const uploadWeightsInput = document.getElementById('upload-weights');
 * const model = await tf.loadLayersModel(tf.io.browserFiles(
 *     [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));
 * ```
 *
 * @param files `File`s to load from. Only Keras-style models (`tf.Model`s) are
 *     supported, for which an `Array` of `File`s is expected, in this order:
 *     - A JSON file containing the model topology and weight manifest.
 *     - Optionally, one or more binary files containing the binary weights.
 *       Their names must match the paths in the `weightsManifest` of the JSON
 *       file above, otherwise loading throws an error. They use the same
 *       format as the files produced by `tensorflowjs_converter` from the
 *       `tensorflowjs` Python PIP package. When no weights files are given,
 *       only the model topology is loaded from the JSON file.
 * @returns An instance of `BrowserFiles` `IOHandler`.
 *
 * @doc {
 *   heading: 'Models',
 *   subheading: 'Loading',
 *   namespace: 'io',
 *   ignoreCI: true
 * }
 */
function browserFiles(files) {
  var handler = new BrowserFiles(files);
  return handler;
}
47553
47554 /**
47555 * @license
47556 * Copyright 2019 Google LLC. All Rights Reserved.
47557 * Licensed under the Apache License, Version 2.0 (the "License");
47558 * you may not use this file except in compliance with the License.
47559 * You may obtain a copy of the License at
47560 *
47561 * http://www.apache.org/licenses/LICENSE-2.0
47562 *
47563 * Unless required by applicable law or agreed to in writing, software
47564 * distributed under the License is distributed on an "AS IS" BASIS,
47565 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
47566 * See the License for the specific language governing permissions and
47567 * limitations under the License.
47568 * =============================================================================
47569 */
/**
 * Monitors `Promise.all` progress, firing `onProgress` each time one of the
 * promises fulfills.
 *
 * After the k-th fulfillment the reported fraction is
 *   startFraction + (k / promises.length) * (endFraction - startFraction),
 * so it reaches `endFraction` once every promise has fulfilled.
 *
 * @param promises Non-empty array of promises to monitor.
 * @param onProgress Callback function receiving the current fraction. Fired
 *     when a promise fulfills.
 * @param startFraction Optional fraction start, in [0, 1]. Defaults to 0.
 * @param endFraction Optional fraction end, in [startFraction, 1]. Defaults to 1.
 * @returns A Promise of the fulfilled values, in the same order as `promises`.
 *     It rejects if any input promise rejects, or if `onProgress` throws.
 */
function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
  checkPromises(promises);
  startFraction = startFraction == null ? 0 : startFraction;
  endFraction = endFraction == null ? 1 : endFraction;
  checkFraction(startFraction, endFraction);
  var resolvedPromise = 0;
  var registerMonitor = function registerMonitor(promise) {
    // Return the derived promise so that Promise.all below consumes it. A
    // detached `.then` chain would itself reject, unhandled, whenever an input
    // promise rejects, raising a spurious unhandled-rejection error.
    return promise.then(function (value) {
      var fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction);
      // pass fraction as parameter to callback function.
      onProgress(fraction);
      return value;
    });
  };
  function checkPromises(promises) {
    assert$1(promises != null && Array.isArray(promises) && promises.length > 0, function () {
      return 'promises must be a non-empty array';
    });
  }
  function checkFraction(startFraction, endFraction) {
    assert$1(startFraction >= 0 && startFraction <= 1, function () {
      return "Progress fraction must be in range [0, 1], but " + "got startFraction ".concat(startFraction);
    });
    assert$1(endFraction >= 0 && endFraction <= 1, function () {
      return "Progress fraction must be in range [0, 1], but " + "got endFraction ".concat(endFraction);
    });
    assert$1(endFraction >= startFraction, function () {
      return "startFraction must be no more than endFraction, but " + "got startFraction ".concat(startFraction, " and endFraction ") + "".concat(endFraction);
    });
  }
  return Promise.all(promises.map(registerMonitor));
}
47611
/**
 * Reads binary weights data from a number of URLs.
 *
 * All requests are created up front (in parallel). The responses are awaited
 * first, then every response body is read with `arrayBuffer()`.
 *
 * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
 * @param loadOptions Optional options object; null/undefined is treated as
 *     `{}`. Fields read here:
 *     - `fetchFunc`: overriding value for the fetch function; defaults to
 *       `env().platform.fetch`.
 *     - `requestInit`: RequestInit (options) for the HTTP requests.
 *     - `onProgress`: optional progress callback. Awaiting the responses
 *       reports fractions in [0, 0.5]; reading the bodies reports fractions
 *       in [0.5, 1].
 * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
 *     length as `fetchURLs`.
 */
function loadWeightsAsArrayBuffer(_x, _x2) {
  return _loadWeightsAsArrayBuffer.apply(this, arguments);
}
// Generator-backed implementation of loadWeightsAsArrayBuffer. The first call
// overwrites `_loadWeightsAsArrayBuffer` with the wrapped async function and
// then delegates to it.
function _loadWeightsAsArrayBuffer() {
  _loadWeightsAsArrayBuffer = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(fetchURLs, loadOptions) {
    var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, bufferPromises, bufferStartFraction, bufferEndFraction, buffers;
    return _regeneratorRuntime().wrap(function _callee3$(_context3) {
      while (1) switch (_context3.prev = _context3.next) {
        case 0:
          if (loadOptions == null) {
            loadOptions = {};
          }
          fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc; // Create the requests for all of the weights in parallel.
          requests = fetchURLs.map(function (fetchURL) {
            return fetchFunc(fetchURL, loadOptions.requestInit, {
              isBinary: true
            });
          });
          // Phase 1: await the responses. Progress, if monitored, covers [0, 0.5].
          fetchStartFraction = 0;
          fetchEndFraction = 0.5;
          // With an onProgress callback, jump to case 11 (monitored wait);
          // otherwise fall through to a plain Promise.all.
          if (!(loadOptions.onProgress == null)) {
            _context3.next = 11;
            break;
          }
          _context3.next = 8;
          return Promise.all(requests);
        case 8:
          _context3.t0 = _context3.sent;
          _context3.next = 14;
          break;
        case 11:
          _context3.next = 13;
          return monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction);
        case 13:
          _context3.t0 = _context3.sent;
        case 14:
          responses = _context3.t0;
          // Phase 2: read every body. Progress, if monitored, covers [0.5, 1].
          // NOTE(review): `response.ok` is not checked here, so the body of an
          // HTTP error response would be returned as weight bytes — confirm
          // whether `fetchFunc` is expected to reject on such responses.
          bufferPromises = responses.map(function (response) {
            return response.arrayBuffer();
          });
          bufferStartFraction = 0.5;
          bufferEndFraction = 1;
          // Same branching as phase 1: case 24 is the monitored wait.
          if (!(loadOptions.onProgress == null)) {
            _context3.next = 24;
            break;
          }
          _context3.next = 21;
          return Promise.all(bufferPromises);
        case 21:
          _context3.t1 = _context3.sent;
          _context3.next = 27;
          break;
        case 24:
          _context3.next = 26;
          return monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction);
        case 26:
          _context3.t1 = _context3.sent;
        case 27:
          buffers = _context3.t1;
          return _context3.abrupt("return", buffers);
        case 29:
        case "end":
          return _context3.stop();
      }
    }, _callee3);
  }));
  return _loadWeightsAsArrayBuffer.apply(this, arguments);
}
/**
 * Streams the bytes of the weight files at `fetchURLs` as a single
 * `ReadableStream`. The files are fetched lazily, one after another, in order.
 *
 * @param fetchURLs URLs of the weight files to stream.
 * @param loadOptions Load options. Must be non-null: its fields are read
 *     without a null check. Fields read: `fetchFunc` (defaults to
 *     `env().platform.fetch`), `requestInit`, and `onProgress`.
 * @returns A `ReadableStream` that enqueues each chunk read from the current
 *     response body and closes after the last file has been fully read.
 *     `onProgress`, when set, is called with 0 immediately, and then with
 *     `filesCompleted / fetchURLs.length` each time a file is exhausted.
 */
function streamWeights(fetchURLs, loadOptions) {
  var _a;
  var fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc;
  // Index of the file currently being streamed.
  var fetchIndex = 0;
  // Reader over the current file's body; undefined until that file is fetched.
  var chunkReader;
  // Equivalent to `loadOptions.onProgress?.(0)`.
  (_a = loadOptions.onProgress) === null || _a === void 0 ? void 0 : _a.call(loadOptions, 0);
  return new ReadableStream({
    // Each pull enqueues at most one chunk. It skips past exhausted files and
    // closes the stream once every file has been read.
    // NOTE(review): the fetch response's `ok` status is not checked, so the
    // body of an HTTP error response would be streamed as weight data.
    pull: function () {
      var _pull = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(controller) {
        var _a, body, _yield$chunkReader$re, done, value;
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              // Loop head: once every file has been streamed, go to close().
              if (!(fetchIndex < fetchURLs.length)) {
                _context.next = 20;
                break;
              }
              if (chunkReader) {
                _context.next = 6;
                break;
              }
              // No reader yet: fetch the current file and open its body.
              _context.next = 4;
              return fetchFunc(fetchURLs[fetchIndex], loadOptions.requestInit, {
                isBinary: true
              });
            case 4:
              body = _context.sent.body;
              chunkReader = body.getReader();
            case 6:
              _context.next = 8;
              return chunkReader.read();
            case 8:
              _yield$chunkReader$re = _context.sent;
              done = _yield$chunkReader$re.done;
              value = _yield$chunkReader$re.value;
              if (!done) {
                _context.next = 16;
                break;
              }
              // Current file exhausted: advance to the next file, report
              // progress and resume at the loop head (case 0).
              fetchIndex++;
              chunkReader = undefined;
              (_a = loadOptions.onProgress) === null || _a === void 0 ? void 0 : _a.call(loadOptions, fetchIndex / fetchURLs.length);
              return _context.abrupt("continue", 0);
            case 16:
              controller.enqueue(value);
              return _context.abrupt("return");
            case 20:
              controller.close();
            case 21:
            case "end":
              return _context.stop();
          }
        }, _callee);
      }));
      function pull(_x3) {
        return _pull.apply(this, arguments);
      }
      return pull;
    }()
  });
}
/**
 * Reads a weights manifest JSON configuration, fetches the weights and
 * returns them as `Tensor`s.
 *
 * @param manifest The weights manifest JSON.
 * @param filePathPrefix The path prefix for filenames given in the manifest.
 *     Defaults to the empty string.
 * @param weightNames The names of the weights to be fetched. When omitted,
 *     every weight in the manifest is fetched.
 * @param requestInit Optional RequestInit passed along to each request made
 *     for the weight files.
 */
function loadWeights(_x4) {
  return _loadWeights.apply(this, arguments);
}
/**
 * NOTE(review): this comment documents `weightsLoaderFactory`, which is
 * defined after `_loadWeights`. It does not document `_loadWeights`.
 *
 * Creates a function, which reads a weights manifest JSON configuration,
 * fetches the weight files using the specified function and returns them as
 * `Tensor`s.
 *
 * ```js
 * // example for creating a nodejs weight loader, which reads the weight files
 * // from disk using fs.readFileSync
 *
 * import * as fs from 'fs'
 *
 * const fetchWeightsFromDisk = (filePaths: string[]) =>
 *   filePaths.map(filePath => fs.readFileSync(filePath).buffer)
 *
 * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
 *
 * const manifest = JSON.parse(
 *   fs.readFileSync('./my_model-weights_manifest').toString()
 * )
 * const weightMap = await loadWeights(manifest, './')
 * ```
 * @param fetchWeightsFunction The function used for fetching the weight files.
 * @returns Weight loading function.
 */
function _loadWeights() {
  _loadWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(manifest) {
    var filePathPrefix,
      weightNames,
      requestInit,
      fetchWeights,
      loadWeights,
      _args4 = arguments;
    return _regeneratorRuntime().wrap(function _callee4$(_context4) {
      while (1) switch (_context4.prev = _context4.next) {
        case 0:
          // Optional positional arguments: filePathPrefix (default ''),
          // weightNames and requestInit (both default undefined).
          filePathPrefix = _args4.length > 1 && _args4[1] !== undefined ? _args4[1] : '';
          weightNames = _args4.length > 2 ? _args4[2] : undefined;
          requestInit = _args4.length > 3 ? _args4[3] : undefined;
          // TODO(nsthorat): Groups are currently fetched atomically. If you need a
          // single weight from a group, the whole group will be fetched. At a future
          // date, we should support fetching only the individual shards within a
          // group that are needed to reconstruct the requested weight.
          // TODO(cais): Use `decodeWeights` for implementation.
          fetchWeights = function fetchWeights(fetchUrls) {
            return loadWeightsAsArrayBuffer(fetchUrls, {
              requestInit: requestInit
            });
          };
          loadWeights = weightsLoaderFactory(fetchWeights);
          return _context4.abrupt("return", loadWeights(manifest, filePathPrefix, weightNames));
        case 6:
        case "end":
          return _context4.stop();
      }
    }, _callee4);
  }));
  return _loadWeights.apply(this, arguments);
}
/**
 * Returns an async weight-loading function
 * `(manifest, filePathPrefix = '', weightNames?) => Promise<weightsTensorMap>`
 * that fetches weight files through `fetchWeightsFunction`.
 *
 * The returned function:
 *   1. Selects the manifest groups that contain the requested weights (all
 *      weights when `weightNames` is null/undefined). For each selected
 *      weight it records the byte offset and size within its group.
 *   2. Throws an Error naming any requested weights missing from the manifest.
 *   3. Calls `fetchWeightsFunction` once, with the URLs of every path of the
 *      selected groups.
 *   4. Concatenates each group's buffers, slices out each selected weight and
 *      decodes it with `decodeWeights`. The decoded results are merged into
 *      one name-keyed object, which is returned.
 *
 * @param fetchWeightsFunction Called with an array of URLs. Expected to return
 *     (a Promise of) one buffer per URL, in the same order.
 * @returns The weight-loading function described above.
 */
function weightsLoaderFactory(fetchWeightsFunction) {
  return /*#__PURE__*/function () {
    var _ref = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(manifest) {
      var filePathPrefix,
        weightNames,
        groupIndicesToFetchMap,
        groupWeightsToFetch,
        weightsFound,
        allManifestWeightNames,
        weightsNotFound,
        groupIndicesToFetch,
        fetchUrls,
        buffers,
        weightsTensorMap,
        bufferIndexOffset,
        _args2 = arguments;
      return _regeneratorRuntime().wrap(function _callee2$(_context2) {
        while (1) switch (_context2.prev = _context2.next) {
          case 0:
            filePathPrefix = _args2.length > 1 && _args2[1] !== undefined ? _args2[1] : '';
            weightNames = _args2.length > 2 ? _args2[2] : undefined;
            // Collect all the groups, weights, and their relative offsets to be
            // fetched.
            groupIndicesToFetchMap = manifest.map(function () {
              return false;
            });
            groupWeightsToFetch = {};
            weightsFound = weightNames != null ? weightNames.map(function () {
              return false;
            }) : [];
            allManifestWeightNames = [];
            manifest.forEach(function (manifestGroupConfig, groupIndex) {
              var groupOffset = 0;
              manifestGroupConfig.weights.forEach(function (weightsEntry) {
                // Size the entry by its quantization dtype when present,
                // otherwise by its declared dtype.
                var rawDtype = 'quantization' in weightsEntry ? weightsEntry.quantization.dtype : weightsEntry.dtype;
                var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * sizeFromShape(weightsEntry.shape);
                var enqueueWeightsForFetchingFn = function enqueueWeightsForFetchingFn() {
                  groupIndicesToFetchMap[groupIndex] = true;
                  if (groupWeightsToFetch[groupIndex] == null) {
                    groupWeightsToFetch[groupIndex] = [];
                  }
                  groupWeightsToFetch[groupIndex].push({
                    manifestEntry: weightsEntry,
                    groupOffset: groupOffset,
                    sizeBytes: weightsBytes
                  });
                };
                if (weightNames != null) {
                  weightNames.forEach(function (weightName, weightIndex) {
                    if (weightName === weightsEntry.name) {
                      enqueueWeightsForFetchingFn();
                      weightsFound[weightIndex] = true;
                    }
                  });
                } else {
                  enqueueWeightsForFetchingFn();
                }
                allManifestWeightNames.push(weightsEntry.name);
                // The offset advances past every entry, requested or not, so it
                // stays a valid position within the group's concatenated bytes.
                groupOffset += weightsBytes;
              });
            });
            // Every requested name found (always true when weightNames is
            // null, since weightsFound is then empty): continue at case 10.
            if (weightsFound.every(function (found) {
              return found;
            })) {
              _context2.next = 10;
              break;
            }
            weightsNotFound = weightNames.filter(function (_, i) {
              return !weightsFound[i];
            });
            throw new Error("Could not find weights in manifest with names: " + "".concat(weightsNotFound.join(', '), ". \n") + "Manifest JSON has weights with names: " + "".concat(allManifestWeightNames.join(', '), "."));
          case 10:
            // Convert the one-hot boolean groupId => shouldFetch map to a list of group
            // IDs.
            groupIndicesToFetch = groupIndicesToFetchMap.reduce(function (accumulator, shouldFetch, i) {
              if (shouldFetch) {
                accumulator.push(i);
              }
              return accumulator;
            }, []);
            fetchUrls = [];
            groupIndicesToFetch.forEach(function (i) {
              manifest[i].paths.forEach(function (filepath) {
                // A '/' separator is inserted unless the prefix already ends
                // with one.
                // NOTE(review): with the default prefix '' this yields
                // '/' + filepath (a root-relative URL) — confirm intended.
                var fetchUrl = filePathPrefix + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
                fetchUrls.push(fetchUrl);
              });
            });
            _context2.next = 15;
            return fetchWeightsFunction(fetchUrls);
          case 15:
            buffers = _context2.sent;
            weightsTensorMap = {};
            // `buffers` is laid out group by group, following
            // groupIndicesToFetch. This offset tracks where the current group
            // starts.
            bufferIndexOffset = 0;
            groupIndicesToFetch.forEach(function (i) {
              var numBuffers = manifest[i].paths.length;
              var weightsBuffer = new CompositeArrayBuffer(buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers));
              var weightsEntries = groupWeightsToFetch[i];
              weightsEntries.forEach(function (weightsEntry) {
                var byteBuffer = weightsBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
                var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
                for (var name in nameToTensorMap) {
                  weightsTensorMap[name] = nameToTensorMap[name];
                }
              });
              bufferIndexOffset += numBuffers;
            });
            return _context2.abrupt("return", weightsTensorMap);
          case 20:
          case "end":
            return _context2.stop();
        }
      }, _callee2);
    }));
    return function (_x5) {
      return _ref.apply(this, arguments);
    };
  }();
}
47940
// MIME type given to the 'model.weights.bin' Blob in HTTPRequest.save().
var OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
// MIME type given to the 'model.json' Blob in HTTPRequest.save().
var JSON_TYPE = 'application/json';
47943 var HTTPRequest = /*#__PURE__*/function () {
47944 function HTTPRequest(path, loadOptions) {
47945 _classCallCheck(this, HTTPRequest);
47946 this.DEFAULT_METHOD = 'POST';
47947 if (loadOptions == null) {
47948 loadOptions = {};
47949 }
47950 this.weightPathPrefix = loadOptions.weightPathPrefix;
47951 this.weightUrlConverter = loadOptions.weightUrlConverter;
47952 if (loadOptions.fetchFunc != null) {
47953 assert$1(typeof loadOptions.fetchFunc === 'function', function () {
47954 return 'Must pass a function that matches the signature of ' + '`fetch` (see ' + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)';
47955 });
47956 this.fetch = loadOptions.fetchFunc;
47957 } else {
47958 this.fetch = env().platform.fetch;
47959 }
47960 assert$1(path != null && path.length > 0, function () {
47961 return 'URL path for http must not be null, undefined or ' + 'empty.';
47962 });
47963 if (Array.isArray(path)) {
47964 assert$1(path.length === 2, function () {
47965 return 'URL paths for http must have a length of 2, ' + "(actual length is ".concat(path.length, ").");
47966 });
47967 }
47968 this.path = path;
47969 if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) {
47970 throw new Error('requestInit is expected to have no pre-existing body, but has one.');
47971 }
47972 this.requestInit = loadOptions.requestInit || {};
47973 this.loadOptions = loadOptions;
47974 }
47975 _createClass(HTTPRequest, [{
47976 key: "save",
47977 value: function () {
47978 var _save = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(modelArtifacts) {
47979 var init, weightsManifest, modelTopologyAndWeightManifest, weightBuffer, response;
47980 return _regeneratorRuntime().wrap(function _callee$(_context) {
47981 while (1) switch (_context.prev = _context.next) {
47982 case 0:
47983 if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
47984 _context.next = 2;
47985 break;
47986 }
47987 throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' + 'in binary formats yet.');
47988 case 2:
47989 init = Object.assign({
47990 method: this.DEFAULT_METHOD
47991 }, this.requestInit);
47992 init.body = new FormData();
47993 weightsManifest = [{
47994 paths: ['./model.weights.bin'],
47995 weights: modelArtifacts.weightSpecs
47996 }];
47997 modelTopologyAndWeightManifest = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
47998 init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {
47999 type: JSON_TYPE
48000 }), 'model.json');
48001 if (modelArtifacts.weightData != null) {
48002 // TODO(mattsoulanille): Support saving models over 2GB that exceed
48003 // Chrome's ArrayBuffer size limit.
48004 weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
48005 init.body.append('model.weights.bin', new Blob([weightBuffer], {
48006 type: OCTET_STREAM_MIME_TYPE
48007 }), 'model.weights.bin');
48008 }
48009 _context.next = 10;
48010 return this.fetch(this.path, init);
48011 case 10:
48012 response = _context.sent;
48013 if (!response.ok) {
48014 _context.next = 15;
48015 break;
48016 }
48017 return _context.abrupt("return", {
48018 modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
48019 responses: [response]
48020 });
48021 case 15:
48022 throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + "".concat(response.status, "."));
48023 case 16:
48024 case "end":
48025 return _context.stop();
48026 }
48027 }, _callee, this);
48028 }));
48029 function save(_x) {
48030 return _save.apply(this, arguments);
48031 }
48032 return save;
48033 }()
48034 }, {
48035 key: "loadModelJSON",
48036 value: function () {
48037 var _loadModelJSON = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
48038 var modelConfigRequest, modelJSON, message, modelTopology, weightsManifest;
48039 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
48040 while (1) switch (_context2.prev = _context2.next) {
48041 case 0:
48042 _context2.next = 2;
48043 return this.fetch(this.path, this.requestInit);
48044 case 2:
48045 modelConfigRequest = _context2.sent;
48046 if (modelConfigRequest.ok) {
48047 _context2.next = 5;
48048 break;
48049 }
48050 throw new Error("Request to ".concat(this.path, " failed with status code ") + "".concat(modelConfigRequest.status, ". Please verify this URL points to ") + "the model JSON of the model to load.");
48051 case 5:
48052 _context2.prev = 5;
48053 _context2.next = 8;
48054 return modelConfigRequest.json();
48055 case 8:
48056 modelJSON = _context2.sent;
48057 _context2.next = 16;
48058 break;
48059 case 11:
48060 _context2.prev = 11;
48061 _context2.t0 = _context2["catch"](5);
48062 message = "Failed to parse model JSON of response from ".concat(this.path, "."); // TODO(nsthorat): Remove this after some time when we're comfortable that
48063 // .pb files are mostly gone.
48064 if (this.path.endsWith('.pb')) {
48065 message += ' Your path contains a .pb file extension. ' + 'Support for .pb models have been removed in TensorFlow.js 1.0 ' + 'in favor of .json models. You can re-convert your Python ' + 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' + 'or you can convert your.pb models with the \'pb2json\'' + 'NPM script in the tensorflow/tfjs-converter repository.';
48066 } else {
48067 message += ' Please make sure the server is serving valid ' + 'JSON for this request.';
48068 }
48069 throw new Error(message);
48070 case 16:
48071 // We do not allow both modelTopology and weightsManifest to be missing.
48072 modelTopology = modelJSON.modelTopology;
48073 weightsManifest = modelJSON.weightsManifest;
48074 if (!(modelTopology == null && weightsManifest == null)) {
48075 _context2.next = 20;
48076 break;
48077 }
48078 throw new Error("The JSON from HTTP path ".concat(this.path, " contains neither model ") + "topology or manifest for weights.");
48079 case 20:
48080 return _context2.abrupt("return", modelJSON);
48081 case 21:
48082 case "end":
48083 return _context2.stop();
48084 }
48085 }, _callee2, this, [[5, 11]]);
48086 }));
48087 function loadModelJSON() {
48088 return _loadModelJSON.apply(this, arguments);
48089 }
48090 return loadModelJSON;
48091 }()
48092 /**
48093 * Load model artifacts via HTTP request(s).
48094 *
48095 * See the documentation to `tf.io.http` for details on the saved
48096 * artifacts.
48097 *
48098 * @returns The loaded model artifacts (if loading succeeds).
48099 */
48100 }, {
48101 key: "load",
48102 value: function () {
48103 var _load = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
48104 var _this = this;
48105 var modelJSON;
48106 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
48107 while (1) switch (_context3.prev = _context3.next) {
48108 case 0:
48109 if (!this.loadOptions.streamWeights) {
48110 _context3.next = 2;
48111 break;
48112 }
48113 return _context3.abrupt("return", this.loadStream());
48114 case 2:
48115 _context3.next = 4;
48116 return this.loadModelJSON();
48117 case 4:
48118 modelJSON = _context3.sent;
48119 return _context3.abrupt("return", getModelArtifactsForJSON(modelJSON, function (weightsManifest) {
48120 return _this.loadWeights(weightsManifest);
48121 }));
48122 case 6:
48123 case "end":
48124 return _context3.stop();
48125 }
48126 }, _callee3, this);
48127 }));
48128 function load() {
48129 return _load.apply(this, arguments);
48130 }
48131 return load;
48132 }()
48133 }, {
// `loadStream()`: loads the model JSON and resolved weight URLs, but does not
// download weights. Instead it returns the model JSON fields plus
// `weightSpecs` and a `getWeightStream()` function that starts streaming.
// Babel-transpiled async method, equivalent to:
//   async loadStream() {
//     const modelJSON = await this.loadModelJSON();
//     const fetchURLs = await this.getWeightUrls(modelJSON.weightsManifest);
//     const weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
//     const stream = () => streamWeights(fetchURLs, this.loadOptions);
//     return {...modelJSON, weightSpecs, getWeightStream: stream};
//   }
key: "loadStream",
value: function () {
  var _loadStream = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4() {
    var _this2 = this;
    var modelJSON, fetchURLs, weightSpecs, stream;
    return _regeneratorRuntime().wrap(function _callee4$(_context4) {
      while (1) switch (_context4.prev = _context4.next) {
        case 0:
          _context4.next = 2;
          return this.loadModelJSON();
        case 2:
          modelJSON = _context4.sent;
          _context4.next = 5;
          return this.getWeightUrls(modelJSON.weightsManifest);
        case 5:
          fetchURLs = _context4.sent;
          weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
          // Deferred: nothing is fetched until the caller invokes
          // getWeightStream().
          stream = function stream() {
            return streamWeights(fetchURLs, _this2.loadOptions);
          };
          // Shallow copy of modelJSON with weightSpecs/getWeightStream added.
          return _context4.abrupt("return", Object.assign(Object.assign({}, modelJSON), {
            weightSpecs: weightSpecs,
            getWeightStream: stream
          }));
        case 9:
        case "end":
          return _context4.stop();
      }
    }, _callee4, this);
  }));
  function loadStream() {
    return _loadStream.apply(this, arguments);
  }
  return loadStream;
}()
48169 }, {
// `getWeightUrls(weightsManifest)`: resolves every `paths` entry of every
// weights group in the manifest to a fetchable URL, preserving manifest order.
//  - With `this.weightUrlConverter` set, each path goes through the converter
//    and all resulting promises are awaited together (Promise.all).
//  - Otherwise the URL is `pathPrefix + path + suffix`. `pathPrefix` is
//    `this.weightPathPrefix` if truthy, else the directory of the weight
//    path; `suffix` is that path's trailing query string (see parseUrl).
// When `this.path` is an array, its second element is used as the weight
// path. NOTE(review): presumably [modelJsonUrl, weightsUrl] — confirm.
// Babel-transpiled async method.
key: "getWeightUrls",
value: function () {
  var _getWeightUrls = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(weightsManifest) {
    var weightPath, _parseUrl, _parseUrl2, prefix, suffix, pathPrefix, fetchURLs, urlPromises, _iterator, _step, weightsGroup, _iterator2, _step2, path;
    return _regeneratorRuntime().wrap(function _callee5$(_context5) {
      while (1) switch (_context5.prev = _context5.next) {
        case 0:
          weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
          _parseUrl = parseUrl(weightPath), _parseUrl2 = _slicedToArray(_parseUrl, 2), prefix = _parseUrl2[0], suffix = _parseUrl2[1];
          pathPrefix = this.weightPathPrefix || prefix;
          fetchURLs = [];
          urlPromises = [];
          // Transpiled `for (const weightsGroup of weightsManifest)
          //   for (const path of weightsGroup.paths)`.
          _iterator = _createForOfIteratorHelper(weightsManifest);
          try {
            for (_iterator.s(); !(_step = _iterator.n()).done;) {
              weightsGroup = _step.value;
              _iterator2 = _createForOfIteratorHelper(weightsGroup.paths);
              try {
                for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
                  path = _step2.value;
                  if (this.weightUrlConverter != null) {
                    urlPromises.push(this.weightUrlConverter(path));
                  } else {
                    fetchURLs.push(pathPrefix + path + suffix);
                  }
                }
              } catch (err) {
                _iterator2.e(err);
              } finally {
                _iterator2.f();
              }
            }
          } catch (err) {
            _iterator.e(err);
          } finally {
            _iterator.f();
          }
          if (!this.weightUrlConverter) {
            _context5.next = 16;
            break;
          }
          // Transpiled `fetchURLs.push(...await Promise.all(urlPromises))`.
          _context5.t0 = fetchURLs.push;
          _context5.t1 = fetchURLs;
          _context5.t2 = _toConsumableArray;
          _context5.next = 13;
          return Promise.all(urlPromises);
        case 13:
          _context5.t3 = _context5.sent;
          _context5.t4 = (0, _context5.t2)(_context5.t3);
          _context5.t0.apply.call(_context5.t0, _context5.t1, _context5.t4);
        case 16:
          return _context5.abrupt("return", fetchURLs);
        case 17:
        case "end":
          return _context5.stop();
      }
    }, _callee5, this);
  }));
  function getWeightUrls(_x2) {
    return _getWeightUrls.apply(this, arguments);
  }
  return getWeightUrls;
}()
48233 }, {
// `loadWeights(weightsManifest)`: eagerly downloads all weight files named in
// the manifest and resolves to the pair `[weightSpecs, buffers]`.
// Babel-transpiled async method, equivalent to:
//   async loadWeights(weightsManifest) {
//     const fetchURLs = await this.getWeightUrls(weightsManifest);
//     const weightSpecs = getWeightSpecs(weightsManifest);
//     const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
//     return [weightSpecs, buffers];
//   }
key: "loadWeights",
value: function () {
  var _loadWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6(weightsManifest) {
    var fetchURLs, weightSpecs, buffers;
    return _regeneratorRuntime().wrap(function _callee6$(_context6) {
      while (1) switch (_context6.prev = _context6.next) {
        case 0:
          _context6.next = 2;
          return this.getWeightUrls(weightsManifest);
        case 2:
          fetchURLs = _context6.sent;
          weightSpecs = getWeightSpecs(weightsManifest);
          _context6.next = 6;
          return loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
        case 6:
          buffers = _context6.sent;
          return _context6.abrupt("return", [weightSpecs, buffers]);
        case 8:
        case "end":
          return _context6.stop();
      }
    }, _callee6, this);
  }));
  function loadWeights(_x3) {
    return _loadWeights.apply(this, arguments);
  }
  return loadWeights;
}()
48262 }]);
48263 return HTTPRequest;
48264 }();
// Matches URLs beginning with "http://" or "https://". The regex has no `i`
// flag, so the check is case-sensitive (e.g. "HTTP://..." does not match).
HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//;
/**
 * Splits a URL into the directory part before its last path segment (the
 * prefix) and the query string that follows that segment (the suffix).
 * ```
 * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file'
 * [prefix, suffix] = parseUrl(url)
 * // prefix = 'http://tfhub.dev/model/1/'
 * // suffix = '?tfjs-format=file'
 * ```
 * A `?` that appears before the last `/` is not treated as a suffix. A URL
 * with no `/` at all yields the prefix '/'.
 *
 * @param url the model url to be parsed.
 * @returns `[prefix, suffix]`, where `prefix` always ends in '/' and `suffix`
 *     is '' when there is no trailing query string.
 */
function parseUrl(url) {
  var slashIndex = url.lastIndexOf('/');
  var queryIndex = url.lastIndexOf('?');
  var directory = url.substring(0, slashIndex) + '/';
  var query = '';
  if (queryIndex > slashIndex) {
    query = url.substring(queryIndex);
  }
  return [directory, query];
}
/**
 * Reports whether `url` starts with an http:// or https:// scheme, as
 * defined by `HTTPRequest.URL_SCHEME_REGEX`.
 * @param url The URL string to test.
 * @returns true when the scheme regex matches.
 */
function isHTTPScheme(url) {
  var schemeMatch = url.match(HTTPRequest.URL_SCHEME_REGEX);
  return schemeMatch !== null && schemeMatch !== undefined;
}
/**
 * IO router for HTTP(S) model URLs.
 *
 * Returns an `http(url, loadOptions)` handler when a fetch implementation is
 * available (a global `fetch` or `loadOptions.fetchFunc`) and `url` uses an
 * http:// or https:// scheme. If `url` is an array, every element must use
 * such a scheme. Otherwise returns null.
 */
var httpRouter = function httpRouter(url, loadOptions) {
  var hasCustomFetch = loadOptions != null && loadOptions.fetchFunc != null;
  if (typeof fetch === 'undefined' && !hasCustomFetch) {
    // `http` uses `fetch` or `node-fetch`, if one wants to use it in
    // an environment that is not the browser or node they have to setup a
    // global fetch polyfill.
    return null;
  }
  var candidates = Array.isArray(url) ? url : [url];
  var allHTTP = candidates.every(function (candidate) {
    return isHTTPScheme(candidate);
  });
  return allHTTP ? http(url, loadOptions) : null;
};
// Register httpRouter for both saving and loading. Per the `tf.io.http` docs,
// this is what lets a plain HTTP(S) URL string be passed to e.g. `model.save`.
IORouterRegistry.registerSaveRouter(httpRouter);
IORouterRegistry.registerLoadRouter(httpRouter);
/**
 * Creates an IOHandler subtype that sends model artifacts to HTTP server.
 *
 * An HTTP request of the `multipart/form-data` mime type will be sent to the
 * `path` URL. The form data includes artifacts that represent the topology
 * and/or weights of the model. In the case of Keras-style `tf.Model`, two
 * blobs (files) exist in form-data:
 *   - A JSON file consisting of `modelTopology` and `weightsManifest`.
 *   - A binary weights file consisting of the concatenated weight values.
 * These files are in the same format as the one generated by
 * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
 *
 * The following code snippet exemplifies the client-side code that uses this
 * function:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
 *
 * const saveResult = await model.save(tf.io.http(
 *     'http://model-server:5000/upload', {requestInit: {method: 'PUT'}}));
 * console.log(saveResult);
 * ```
 *
 * If the default `POST` method is to be used, without any custom parameters
 * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`:
 *
 * ```js
 * const saveResult = await model.save('http://model-server:5000/upload');
 * ```
 *
 * The following GitHub Gist
 * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864
 * implements a server based on [flask](https://github.com/pallets/flask) that
 * can receive the request. Upon receiving the model artifacts via the request,
 * this particular server reconstitutes instances of [Keras
 * Models](https://keras.io/models/model/) in memory.
 *
 *
 * @param path A URL path to the model.
 *   Can be an absolute HTTP path (e.g.,
 *   'http://localhost:8000/model-upload') or a relative path (e.g.,
 *   './model-upload'). May also be an array of URLs; when loading, the
 *   weight-file URLs are then derived from the array's second element.
 * @param loadOptions Optional configuration for saving and loading. It
 *   includes the following fields:
 *   - requestInit Optional request configuration used when sending the HTTP
 *     request to the server with `fetch`. It can contain fields such as
 *     `method`, `credentials`, `headers`, `mode`, etc. See
 *     https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
 *     for more information. `requestInit` must not have a body, because the
 *     body will be set by TensorFlow.js. File blobs representing the model
 *     topology (filename: 'model.json') and the weights of the model
 *     (filename: 'model.weights.bin') will be appended to the body. If
 *     `requestInit` has a `body`, an Error will be thrown.
 *   - weightPathPrefix Optional, this specifies the path prefix for weight
 *     files, by default this is calculated from the path param.
 *   - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
 *     the `fetch` from node-fetch can be used here.
 *   - onProgress Optional, progress callback function, fired periodically
 *     before the load is completed.
 *   - streamWeights Optional. When truthy, `load()` resolves to the model
 *     JSON plus `weightSpecs` and a `getWeightStream()` function, instead of
 *     downloading all weight files up front.
 * @returns An instance of `IOHandler`.
 *
 * @doc {
 *   heading: 'Models',
 *   subheading: 'Loading',
 *   namespace: 'io',
 *   ignoreCI: true
 * }
 */
function http(path, loadOptions) {
  return new HTTPRequest(path, loadOptions);
}
/**
 * Deprecated. Use `tf.io.http`.
 *
 * Alias kept for backward compatibility; both arguments are forwarded to
 * `http` unchanged.
 * @param path Model URL (or URLs), passed straight through to `http`.
 * @param loadOptions Optional options, passed straight through to `http`.
 * @returns The handler produced by `http(path, loadOptions)`.
 */
function browserHTTPRequest(path, loadOptions) {
  var handler = http(path, loadOptions);
  return handler;
}
48391
48392 /**
48393 * @license
48394 * Copyright 2018 Google LLC. All Rights Reserved.
48395 * Licensed under the Apache License, Version 2.0 (the "License");
48396 * you may not use this file except in compliance with the License.
48397 * You may obtain a copy of the License at
48398 *
48399 * http://www.apache.org/licenses/LICENSE-2.0
48400 *
48401 * Unless required by applicable law or agreed to in writing, software
48402 * distributed under the License is distributed on an "AS IS" BASIS,
48403 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
48404 * See the License for the specific language governing permissions and
48405 * limitations under the License.
48406 * =============================================================================
48407 */
/**
 * IOHandler whose `load()` returns the ModelArtifacts object given to the
 * constructor (the same reference, not a copy).
 */
var PassthroughLoader = /*#__PURE__*/function () {
  function PassthroughLoader(artifacts) {
    _classCallCheck(this, PassthroughLoader);
    this.modelArtifacts = artifacts;
  }
  var prototypeMethods = [{
    key: "load",
    value: function load() {
      var stored = this.modelArtifacts;
      return stored;
    }
  }];
  _createClass(PassthroughLoader, prototypeMethods);
  return PassthroughLoader;
}();
/**
 * IOHandler whose `save(artifacts)` returns `saveHandler(artifacts)`. The
 * handler is invoked as a method of the saver, so `this` inside it is the
 * saver instance.
 */
var PassthroughSaver = /*#__PURE__*/function () {
  function PassthroughSaver(handler) {
    _classCallCheck(this, PassthroughSaver);
    this.saveHandler = handler;
  }
  var prototypeMethods = [{
    key: "save",
    value: function save(modelArtifacts) {
      var result = this.saveHandler(modelArtifacts);
      return result;
    }
  }];
  _createClass(PassthroughSaver, prototypeMethods);
  return PassthroughSaver;
}();
/**
 * Adapts a (possibly synchronous) IOHandler to the async IOHandler shape.
 * `load` and/or `save` are defined on the instance only when the wrapped
 * handler has a truthy member of that name. Each wrapper calls the handler's
 * method and wraps the result with `Promise.resolve`, so a synchronous throw
 * still propagates synchronously.
 */
var PassthroughAsync = /*#__PURE__*/_createClass(function PassthroughAsync(handler) {
  _classCallCheck(this, PassthroughAsync);
  var target = this;
  if (handler.load) {
    target.load = function () {
      var loaded = handler.load();
      return Promise.resolve(loaded);
    };
  }
  if (handler.save) {
    target.save = function (modelArtifacts) {
      var saved = handler.save(modelArtifacts);
      return Promise.resolve(saved);
    };
  }
});
/**
 * Creates an IOHandler that loads model artifacts from memory.
 *
 * When used in conjunction with `tf.loadLayersModel`, an instance of
 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
 *
 * ```js
 * const model = await tf.loadLayersModel(tf.io.fromMemory(
 *     modelTopology, weightSpecs, weightData));
 * ```
 *
 * This is the async counterpart of `fromMemorySync`: it forwards its exact
 * arguments (including how many were passed, which `fromMemorySync`
 * inspects) and wraps the result in a `PassthroughAsync`.
 *
 * @param modelArtifacts a object containing model topology (i.e., parsed from
 *     the JSON format).
 * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
 *     names, shapes, types, and quantization of the weight data. Optional.
 * @param weightData A single `ArrayBuffer` containing the weight data,
 *     concatenated in the order described by the weightSpecs. Optional.
 * @param trainingConfig Model training configuration. Optional.
 *
 * @returns A passthrough `IOHandler` that simply loads the provided data.
 */
function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
  // `apply` accepts the array-like `arguments` object directly.
  var syncHandler = fromMemorySync.apply(void 0, arguments);
  return new PassthroughAsync(syncHandler);
}
/**
 * Creates an IOHandler that loads model artifacts from memory.
 *
 * When used in conjunction with `tf.loadLayersModel`, an instance of
 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
 *
 * ```js
 * const model = await tf.loadLayersModel(tf.io.fromMemory(
 *     modelTopology, weightSpecs, weightData));
 * ```
 *
 * Called with exactly one argument that has a non-null `modelTopology` or
 * `weightSpecs`, the argument is used as the ModelArtifacts as-is. Every other
 * call shape is a deprecated legacy form: a deprecation warning is logged,
 * and the positional arguments are packed into a ModelArtifacts object.
 *
 * @param modelArtifacts a object containing model topology (i.e., parsed from
 *     the JSON format).
 * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
 *     names, shapes, types, and quantization of the weight data. Optional.
 * @param weightData A single `ArrayBuffer` containing the weight data,
 *     concatenated in the order described by the weightSpecs. Optional.
 * @param trainingConfig Model training configuration. Optional.
 *
 * @returns A passthrough `IOHandlerSync` that simply loads the provided data.
 */
function fromMemorySync(modelArtifacts, weightSpecs, weightData, trainingConfig) {
  var singleArgument = arguments.length === 1;
  if (singleArgument && (modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null)) {
    return new PassthroughLoader(modelArtifacts);
  }
  // Legacy support (bare modelTopology, or the multi-argument signature).
  // TODO(cais): Remove this deprecated API.
  var deprecationMessage = 'Please call tf.io.fromMemory() with only one argument. ' + 'The argument should be of type ModelArtifacts. ' + 'The multi-argument signature of tf.io.fromMemory() has been ' + 'deprecated and will be removed in a future release.';
  console.warn(deprecationMessage);
  if (singleArgument) {
    return new PassthroughLoader({
      modelTopology: modelArtifacts
    });
  }
  return new PassthroughLoader({
    modelTopology: modelArtifacts,
    weightSpecs: weightSpecs,
    weightData: weightData,
    trainingConfig: trainingConfig
  });
}
/**
 * Creates an IOHandler that passes saved model artifacts to a callback.
 *
 * ```js
 * function handleSave(artifacts) {
 *   // ... do something with the artifacts ...
 *   return {modelArtifactsInfo: {...}, ...};
 * }
 *
 * const saveResult = model.save(tf.io.withSaveHandler(handleSave));
 * ```
 *
 * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
 *     promise that resolves to a `SaveResult`.
 * @returns A `PassthroughSaver` whose `save(artifacts)` returns
 *     `saveHandler(artifacts)`.
 */
function withSaveHandler(saveHandler) {
  var saver = new PassthroughSaver(saveHandler);
  return saver;
}
/**
 * Creates an IOHandlerSync that passes saved model artifacts to a callback.
 *
 * ```js
 * function handleSave(artifacts) {
 *   // ... do something with the artifacts ...
 *   return {modelArtifactsInfo: {...}, ...};
 * }
 *
 * const saveResult = model.save(tf.io.withSaveHandlerSync(handleSave));
 * ```
 *
 * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
 *     `SaveResult` (synchronously, not a promise).
 * @returns A `PassthroughSaver` whose `save(artifacts)` returns
 *     `saveHandler(artifacts)`.
 */
function withSaveHandlerSync(saveHandler) {
  return new PassthroughSaver(saveHandler);
}
48554
48555 /**
48556 * @license
48557 * Copyright 2018 Google LLC. All Rights Reserved.
48558 * Licensed under the Apache License, Version 2.0 (the "License");
48559 * you may not use this file except in compliance with the License.
48560 * You may obtain a copy of the License at
48561 *
48562 * http://www.apache.org/licenses/LICENSE-2.0
48563 *
48564 * Unless required by applicable law or agreed to in writing, software
48565 * distributed under the License is distributed on an "AS IS" BASIS,
48566 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
48567 * See the License for the specific language governing permissions and
48568 * limitations under the License.
48569 * =============================================================================
48570 */
48571
// Namespace object collecting the public IO helpers. `__proto__: null` gives
// it a null prototype, so no Object.prototype members appear on it.
// NOTE(review): presumably exposed as `tf.io` elsewhere in this bundle — confirm.
var io = {
  __proto__: null,
  CompositeArrayBuffer: CompositeArrayBuffer,
  browserFiles: browserFiles,
  browserHTTPRequest: browserHTTPRequest,
  concatenateArrayBuffers: concatenateArrayBuffers,
  copyModel: copyModel,
  decodeWeights: decodeWeights,
  decodeWeightsStream: decodeWeightsStream,
  encodeWeights: encodeWeights,
  fromMemory: fromMemory,
  fromMemorySync: fromMemorySync,
  getLoadHandlers: getLoadHandlers,
  getModelArtifactsForJSON: getModelArtifactsForJSON,
  getModelArtifactsForJSONSync: getModelArtifactsForJSONSync,
  getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON,
  getSaveHandlers: getSaveHandlers,
  getWeightSpecs: getWeightSpecs,
  http: http,
  isHTTPScheme: isHTTPScheme,
  listModels: listModels,
  loadWeights: loadWeights,
  moveModel: moveModel,
  registerLoadRouter: registerLoadRouter,
  registerSaveRouter: registerSaveRouter,
  removeModel: removeModel,
  weightsLoaderFactory: weightsLoaderFactory,
  withSaveHandler: withSaveHandler,
  withSaveHandlerSync: withSaveHandlerSync
};
48602
48603 /**
48604 * @license
48605 * Copyright 2018 Google LLC. All Rights Reserved.
48606 * Licensed under the Apache License, Version 2.0 (the "License");
48607 * you may not use this file except in compliance with the License.
48608 * You may obtain a copy of the License at
48609 *
48610 * http://www.apache.org/licenses/LICENSE-2.0
48611 *
48612 * Unless required by applicable law or agreed to in writing, software
48613 * distributed under the License is distributed on an "AS IS" BASIS,
48614 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
48615 * See the License for the specific language governing permissions and
48616 * limitations under the License.
48617 * =============================================================================
48618 */
/**
 * Computes the confusion matrix from true labels and predicted labels.
 *
 * ```js
 * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
 * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
 * const numClasses = 3;
 * const out = tf.math.confusionMatrix(labels, predictions, numClasses);
 * out.print();
 * // Expected output matrix:
 * // [[2, 0, 0],
 * //  [0, 1, 1],
 * //  [0, 0, 1]]
 * ```
 *
 * @param labels The target labels, assumed to be 0-based integers
 *     for the classes. Must be rank 1, shape `[numExamples]`, where
 *     `numExamples` is the number of examples included.
 * @param predictions The predicted classes, assumed to be 0-based integers
 *     for the classes. Must be rank 1 with the same length as `labels`.
 * @param numClasses Number of all classes, as a positive integer. Although
 *     the first check below tolerates null, the later check rejects it, so
 *     this argument is effectively required. Its value must be larger than
 *     the largest element in `labels` and `predictions`.
 * @returns The confusion matrix as a int32-type 2D tensor. The value at
 *     row `r` and column `c` is the number of times examples of actual class
 *     `r` were predicted as class `c`.
 *
 * @doc {heading: 'Operations', subheading: 'Evaluation'}
 */
function confusionMatrix_(labels, predictions, numClasses) {
  var labelsTensor = convertToTensor(labels, 'labels', 'confusionMatrix');
  var predictionsTensor = convertToTensor(predictions, 'predictions', 'confusionMatrix');

  assert$1(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function () {
    return "If provided, numClasses must be a positive integer, " + "but got ".concat(numClasses);
  });
  assert$1(labelsTensor.rank === 1, function () {
    return "Expected the rank of labels to be 1, but got ".concat(labelsTensor.rank);
  });
  assert$1(predictionsTensor.rank === 1, function () {
    return "Expected the rank of predictions to be 1, " + "but got ".concat(predictionsTensor.rank);
  });
  assert$1(labelsTensor.shape[0] === predictionsTensor.shape[0], function () {
    return "Mismatch in the number of examples: " + "".concat(labelsTensor.shape[0], " vs. ").concat(predictionsTensor.shape[0], ". ") + "Labels and predictions should have the same number of elements.";
  });
  assert$1(numClasses > 0 && Number.isInteger(numClasses), function () {
    return "numClasses is required to be a positive integer, but got " + "".concat(numClasses);
  });

  // TODO(cais): In the future, if oneHot supports tensors inputs for
  // `numClasses`, `confusionMatrix` can make `numClasses` optional.
  var labelsOneHot = oneHot$3(cast$3(labelsTensor, 'int32'), numClasses);
  var predictionsOneHot = oneHot$3(cast$3(predictionsTensor, 'int32'), numClasses);
  // transpose(labelsOneHot) x predictionsOneHot: entry (r, c) sums, over all
  // examples, the product "label is r" * "prediction is c".
  var counts = matMul$1(transpose$2(labelsOneHot), predictionsOneHot);
  return cast$3(counts, 'int32');
}
// Public `confusionMatrix` op, built by passing the implementation to `op`
// (defined elsewhere in this bundle). NOTE(review): presumably `op` derives
// the op name from the `confusionMatrix_` key — confirm.
var confusionMatrix = /* @__PURE__ */op({
  confusionMatrix_: confusionMatrix_
});
48677
48678 /**
48679 * @license
48680 * Copyright 2018 Google LLC. All Rights Reserved.
48681 * Licensed under the Apache License, Version 2.0 (the "License");
48682 * you may not use this file except in compliance with the License.
48683 * You may obtain a copy of the License at
48684 *
48685 * http://www.apache.org/licenses/LICENSE-2.0
48686 *
48687 * Unless required by applicable law or agreed to in writing, software
48688 * distributed under the License is distributed on an "AS IS" BASIS,
48689 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
48690 * See the License for the specific language governing permissions and
48691 * limitations under the License.
48692 * =============================================================================
48693 */
48694
// Namespace object for math ops. It has a null prototype, so its only
// visible member is `confusionMatrix`.
var math = Object.create(null);
math.confusionMatrix = confusionMatrix;
48699
// Lazily created 2D context that fromPixels_ reuses to rasterize image,
// video and ImageBitmap inputs. It stays undefined until first needed.
var fromPixels2DContext$1;
// NOTE(review): not referenced in this chunk; presumably used further down
// (toPixels/draw) to emit a warning only once — confirm.
var hasToPixelsWarned = false;
/**
 * Creates a `tf.Tensor` from an image.
 *
 * ```js
 * const image = new ImageData(1, 1);
 * image.data[0] = 100;
 * image.data[1] = 150;
 * image.data[2] = 200;
 * image.data[3] = 255;
 *
 * tf.browser.fromPixels(image).print();
 * ```
 *
 * @param pixels The input image to construct the tensor from. The
 * supported image types are all 4-channel. You can also pass in an image
 * object with following attributes:
 * `{data: Uint8Array; width: number; height: number}`
 * @param numChannels The number of channels of the output tensor. A
 * numChannels value less than 4 allows you to ignore channels. Defaults to
 * 3 (ignores alpha channel of input image).
 *
 * @returns A Tensor3D with the shape `[height, width, numChannels]`.
 * @throws Error if `numChannels` is greater than 4, if `pixels` is null or
 * undefined, or if `pixels` is not one of the supported input kinds.
 *
 * Note: fromPixels can be lossy in some cases, same image may result in
 * slightly different tensor values, if rendered by different rendering
 * engines. This means that results from different browsers, or even same
 * browser with CPU and GPU rendering engines can be different. See discussion
 * in details:
 * https://github.com/tensorflow/tfjs/issues/5482
 *
 * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
 */
function fromPixels_(pixels) {
  var numChannels = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 3;
  // Sanity checks.
  if (numChannels > 4) {
    throw new Error('Cannot construct Tensor with more than 4 channels from pixels.');
  }
  if (pixels == null) {
    throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
  }
  var isPixelData = false;
  var isImageData = false;
  var isVideo = false;
  var isImage = false;
  var isCanvasLike = false;
  var isImageBitmap = false;
  if (pixels.data instanceof Uint8Array) {
    isPixelData = true;
  } else if (typeof ImageData !== 'undefined' && pixels instanceof ImageData) {
    isImageData = true;
  } else if (typeof HTMLVideoElement !== 'undefined' && pixels instanceof HTMLVideoElement) {
    isVideo = true;
  } else if (typeof HTMLImageElement !== 'undefined' && pixels instanceof HTMLImageElement) {
    isImage = true;
    // tslint:disable-next-line: no-any
  } else if (pixels.getContext != null) {
    isCanvasLike = true;
  } else if (typeof ImageBitmap !== 'undefined' && pixels instanceof ImageBitmap) {
    isImageBitmap = true;
  } else {
    // The raw-object form accepted above is `{data: Uint8Array, ...}`, so the
    // message names Uint8Array. Objects without a constructor (e.g. created
    // with Object.create(null)) fall back to `typeof` instead of crashing
    // while building the message.
    var receivedType = pixels.constructor != null ? pixels.constructor.name : typeof pixels;
    throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' + "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData " + "in browser, or OffscreenCanvas, ImageData in webworker" + " or {data: Uint8Array, width: number, height: number}, " + "but was ".concat(receivedType));
  }
  // If the current backend has 'FromPixels' registered, it has a more
  // efficient way of handling pixel uploads, so we call that.
  var kernel = getKernel(FromPixels, ENGINE.backendName);
  if (kernel != null) {
    var inputs = {
      pixels: pixels
    };
    var attrs = {
      numChannels: numChannels
    };
    return ENGINE.runKernel(FromPixels, inputs, attrs);
  }
  // Videos report their intrinsic size via videoWidth/videoHeight.
  var _ref = isVideo ? [pixels.videoWidth, pixels.videoHeight] : [pixels.width, pixels.height],
    _ref2 = _slicedToArray(_ref, 2),
    width = _ref2[0],
    height = _ref2[1];
  var vals;
  if (isCanvasLike) {
    vals =
    // tslint:disable-next-line:no-any
    pixels.getContext('2d').getImageData(0, 0, width, height).data;
  } else if (isImageData || isPixelData) {
    vals = pixels.data;
  } else if (isImage || isVideo || isImageBitmap) {
    // These inputs expose no pixel buffer; draw them onto a shared 2D
    // context (created once and cached in module state) and read back RGBA.
    if (fromPixels2DContext$1 == null) {
      if (typeof document === 'undefined') {
        if (typeof OffscreenCanvas !== 'undefined' && typeof OffscreenCanvasRenderingContext2D !== 'undefined') {
          // @ts-ignore
          fromPixels2DContext$1 = new OffscreenCanvas(1, 1).getContext('2d');
        } else {
          throw new Error('Cannot parse input in current context. ' + 'Reason: OffscreenCanvas Context2D rendering is not supported.');
        }
      } else {
        fromPixels2DContext$1 = document.createElement('canvas').getContext('2d', {
          willReadFrequently: true
        });
      }
    }
    fromPixels2DContext$1.canvas.width = width;
    fromPixels2DContext$1.canvas.height = height;
    fromPixels2DContext$1.drawImage(pixels, 0, 0, width, height);
    vals = fromPixels2DContext$1.getImageData(0, 0, width, height).data;
  }
  var values;
  if (numChannels === 4) {
    values = new Int32Array(vals);
  } else {
    // Source data is RGBA (stride 4); keep the first numChannels of each pixel.
    var numPixels = width * height;
    values = new Int32Array(numPixels * numChannels);
    for (var i = 0; i < numPixels; i++) {
      for (var channel = 0; channel < numChannels; ++channel) {
        values[i * numChannels + channel] = vals[i * 4 + channel];
      }
    }
  }
  var outShape = [height, width, numChannels];
  return tensor3d(values, outShape, 'int32');
}
// Helper functions for |fromPixelsAsync| to check whether the input can
// be wrapped into imageBitmap.

/** True when `pixels` is a raw pixel object whose `data` is a Uint8Array. */
function isPixelData(pixels) {
  if (pixels == null) {
    return false;
  }
  return pixels.data instanceof Uint8Array;
}
/**
 * True only when a `window` global exists, the `ImageBitmap` type is
 * defined, and `window` has an own `createImageBitmap` property.
 */
function isImageBitmapFullySupported() {
  if (typeof window === 'undefined') {
    return false;
  }
  if (typeof ImageBitmap === 'undefined') {
    return false;
  }
  return window.hasOwnProperty('createImageBitmap');
}
/** True when `pixels` exists and neither its width nor its height is 0. */
function isNonEmptyPixels(pixels) {
  if (pixels == null) {
    return false;
  }
  return !(pixels.width === 0 || pixels.height === 0);
}
/**
 * True when `pixels` should be converted with createImageBitmap:
 *  - the environment fully supports ImageBitmap;
 *  - the input is not already an ImageBitmap;
 *  - it has non-zero dimensions;
 *  - it is not raw `{data: Uint8Array}` pixel data.
 */
function canWrapPixelsToImageBitmap(pixels) {
  // Must run first: `ImageBitmap` may be undefined when unsupported.
  if (!isImageBitmapFullySupported()) {
    return false;
  }
  if (pixels instanceof ImageBitmap) {
    return false;
  }
  return isNonEmptyPixels(pixels) && !isPixelData(pixels);
}
/**
 * Creates a `tf.Tensor` from an image in async way.
 *
 * ```js
 * const image = new ImageData(1, 1);
 * image.data[0] = 100;
 * image.data[1] = 150;
 * image.data[2] = 200;
 * image.data[3] = 255;
 *
 * (await tf.browser.fromPixelsAsync(image)).print();
 * ```
 * This API is the async version of fromPixels. The API will first
 * check |WRAP_TO_IMAGEBITMAP| flag, and try to wrap the input to
 * imageBitmap if the flag is set to true.
 *
 * Thin public entry point: it forwards its receiver and all arguments to the
 * transpiled implementation in `_fromPixelsAsync`.
 *
 * @param pixels The input image to construct the tensor from. The
 * supported image types are all 4-channel. You can also pass in an image
 * object with following attributes:
 * `{data: Uint8Array; width: number; height: number}`
 * @param numChannels The number of channels of the output tensor. A
 * numChannels value less than 4 allows you to ignore channels. Defaults to
 * 3 (ignores alpha channel of input image).
 *
 * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
 */
function fromPixelsAsync(_x) {
  var receiver = this;
  var forwardedArgs = arguments;
  return _fromPixelsAsync.apply(receiver, forwardedArgs);
}
/**
 * Babel-transpiled implementation of `fromPixelsAsync`. On first call it
 * replaces `_fromPixelsAsync` with the generator-backed function and then
 * delegates to it. It is equivalent to:
 *
 *   async (pixels, numChannels = 3) => {
 *     let inputs = pixels;
 *     if (env().getBool('WRAP_TO_IMAGEBITMAP') &&
 *         canWrapPixelsToImageBitmap(pixels)) {
 *       let imageBitmap;
 *       try {
 *         imageBitmap = await createImageBitmap(pixels, {premultiplyAlpha: 'none'});
 *       } catch (e) {
 *         imageBitmap = null;
 *       }
 *       if (bitmap size matches pixels) inputs = imageBitmap;
 *       else close the unused bitmap (if any);
 *     }
 *     return fromPixels_(inputs, numChannels);
 *   }
 */
function _fromPixelsAsync() {
  _fromPixelsAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(pixels) {
    var numChannels,
      inputs,
      imageBitmap,
      _args = arguments;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          numChannels = _args.length > 1 && _args[1] !== undefined ? _args[1] : 3;
          inputs = null; // Check whether the backend needs to wrap |pixels| to imageBitmap and
          // whether |pixels| can be wrapped to imageBitmap.
          if (!(env().getBool('WRAP_TO_IMAGEBITMAP') && canWrapPixelsToImageBitmap(pixels))) {
            _context.next = 15;
            break;
          }
          _context.prev = 3;
          _context.next = 6;
          return createImageBitmap(pixels, {
            premultiplyAlpha: 'none'
          });
        case 6:
          imageBitmap = _context.sent;
          _context.next = 12;
          break;
        case 9:
          // A failed createImageBitmap is not fatal: fall back to raw pixels.
          _context.prev = 9;
          _context.t0 = _context["catch"](3);
          imageBitmap = null;
        case 12:
          // createImageBitmap will clip the source size.
          // In some cases, the input will have larger size than its content.
          // E.g. new Image(10, 10) but with 1 x 1 content. Using
          // createImageBitmap will clip the size from 10 x 10 to 1 x 1, which
          // is not correct. We should avoid wrapping such resource to
          // imageBitmap.
          if (imageBitmap != null && imageBitmap.width === pixels.width && imageBitmap.height === pixels.height) {
            inputs = imageBitmap;
          } else {
            // The bitmap (if one was created) is discarded here; release its
            // graphics resources now rather than waiting for GC.
            if (imageBitmap != null && typeof imageBitmap.close === 'function') {
              imageBitmap.close();
            }
            inputs = pixels;
          }
          _context.next = 16;
          break;
        case 15:
          inputs = pixels;
        case 16:
          // NOTE(review): when the bitmap *is* used it is not closed after
          // fromPixels_ either; doing so may be safe if every backend reads it
          // synchronously — confirm before adding.
          return _context.abrupt("return", fromPixels_(inputs, numChannels));
        case 17:
        case "end":
          return _context.stop();
      }
    }, _callee, null, [[3, 9]]);
  }));
  return _fromPixelsAsync.apply(this, arguments);
}
/**
 * Checks that `img` is drawable by toPixels.
 * @param img Tensor-like object; reads `rank`, `shape[2]` (rank 3 only) and
 *     `dtype`.
 * @throws Error if the rank is not 2 or 3, the depth (1 for rank 2, else
 *     `shape[2]`) is not 1, 3 or 4, or the dtype is neither 'float32' nor
 *     'int32'.
 */
function validateImgTensor(img) {
  if (img.rank !== 2 && img.rank !== 3) {
    throw new Error("toPixels only supports rank 2 or 3 tensors, got rank ".concat(img.rank, "."));
  }
  var depth = img.rank === 2 ? 1 : img.shape[2];
  // Accept exactly the depths the error message promises. A range test such
  // as `depth > 4 || depth === 2` would also let depth 0 through.
  if (depth !== 1 && depth !== 3 && depth !== 4) {
    throw new Error("toPixels only supports depth of size " + "1, 3 or 4 but got ".concat(depth));
  }
  if (img.dtype !== 'float32' && img.dtype !== 'int32') {
    throw new Error("Unsupported type for toPixels: ".concat(img.dtype, ".") + " Please use float32 or int32 tensors.");
  }
}
/**
 * Validates the `imageOptions` passed to `tf.browser.draw`.
 *
 * @param imageOptions Optional options object; only `alpha` is checked.
 *     A missing or falsy `alpha` falls back to 1 (fully opaque).
 * @throws Error if `alpha` lies outside the closed range [0, 1].
 */
function validateImageOptions(imageOptions) {
  var alpha = (imageOptions === null || imageOptions === void 0 ? void 0 : imageOptions.alpha) || 1;
  if (alpha > 1 || alpha < 0) {
    throw new Error("Alpha value ".concat(alpha, " is supposed to be in range [0 - 1]."));
  }
}
/**
 * Draws a `tf.Tensor` of pixel values to a byte array or optionally a
 * canvas.
 *
 * When the dtype of the input is 'float32', we assume values in the range
 * [0-1]. Otherwise, when input is 'int32', we assume values in the range
 * [0-255].
 *
 * Returns a promise that resolves when the canvas has been drawn to.
 *
 * @param img A rank-2 tensor with shape `[height, width]`, or a rank-3 tensor
 * of shape `[height, width, numChannels]`. If rank-2, draws grayscale. If
 * rank-3, must have depth of 1, 3 or 4. When depth of 1, draws
 * grayscale. When depth of 3, we draw with the first three components of
 * the depth dimension corresponding to r, g, b and alpha = 1. When depth of
 * 4, all four components of the depth dimension correspond to r, g, b, a.
 * @param canvas The canvas to draw to. Optional: when null/undefined only the
 * byte array is produced. When given, it is resized to the image's width and
 * height before the pixels are written.
 * @returns A promise resolving to a `Uint8ClampedArray` of RGBA bytes with
 * `height * width * 4` entries, whether or not a canvas was given.
 * @throws If the tensor fails validation (rank, depth, dtype), or if a value
 * is outside [0, 1] for float32 or outside [0, 255] for int32.
 *
 * @doc {heading: 'Browser', namespace: 'browser'}
 */
function toPixels(_x2, _x3) {
  // Babel async trampoline: forwards all arguments (and `this`) to the lazily
  // built `_toPixels` implementation.
  return _toPixels.apply(this, arguments);
}
48962 /**
48963 * Draws a `tf.Tensor` to a canvas.
48964 *
48965 * When the dtype of the input is 'float32', we assume values in the range
48966 * [0-1]. Otherwise, when input is 'int32', we assume values in the range
48967 * [0-255].
48968 *
48969 * @param image The tensor to draw on the canvas. Must match one of
48970 * these shapes:
* - Rank-2 with shape `[height, width]`: Drawn as grayscale.
48972 * - Rank-3 with shape `[height, width, 1]`: Drawn as grayscale.
48973 * - Rank-3 with shape `[height, width, 3]`: Drawn as RGB with alpha set in
48974 * `imageOptions` (defaults to 1, which is opaque).
48975 * - Rank-3 with shape `[height, width, 4]`: Drawn as RGBA.
48976 * @param canvas The canvas to draw to.
48977 * @param options The configuration arguments for image to be drawn and the
48978 * canvas to draw to.
48979 *
48980 * @doc {heading: 'Browser', namespace: 'browser'}
48981 */
/**
 * Lazily built async implementation behind `toPixels`. This is Babel
 * `_asyncToGenerator` + regenerator output. The first call replaces
 * `_toPixels` with the generator-backed wrapper and then invokes that
 * wrapper, so later calls go straight to it.
 *
 * Flow:
 * 1. Convert and validate the input tensor.
 * 2. Await its data.
 * 3. Expand every pixel into RGBA bytes, range-checking each value against
 *    its dtype.
 * 4. Optionally paint the bytes onto `canvas`.
 * 5. Dispose any tensor this function created, and resolve with the bytes.
 *
 * NOTE(review): the range-check `throw`s below run before the
 * `$img.dispose()` in state 36, and the wrapped generator declares no
 * try/finally. The int32 copy made for non-Tensor input is therefore not
 * disposed when a value is out of range.
 */
function _toPixels() {
  _toPixels = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(img, canvas) {
    var $img, originalImgTensor, _$img$shape$slice, _$img$shape$slice2, height, width, depth, data, multiplier, bytes, i, rgba, d, value, j, kernel, ctx, imageData;
    return _regeneratorRuntime().wrap(function _callee2$(_context2) {
      // Each `case` label is a resume point of the transpiled async function;
      // `_context2.next` selects the state that runs after each `break`.
      while (1) switch (_context2.prev = _context2.next) {
        case 0:
          $img = convertToTensor(img, 'img', 'toPixels');
          if (!(img instanceof Tensor)) {
            // Assume int32 if user passed a native array.
            originalImgTensor = $img;
            $img = cast$3(originalImgTensor, 'int32');
            originalImgTensor.dispose();
          }
          validateImgTensor($img);
          _$img$shape$slice = $img.shape.slice(0, 2), _$img$shape$slice2 = _slicedToArray(_$img$shape$slice, 2), height = _$img$shape$slice2[0], width = _$img$shape$slice2[1];
          // Rank-2 input has no channel axis and is treated as one channel.
          depth = $img.rank === 2 ? 1 : $img.shape[2];
          _context2.next = 7;
          return $img.data();
        case 7:
          data = _context2.sent;
          // float32 values are in [0, 1] and get scaled up to byte range.
          multiplier = $img.dtype === 'float32' ? 255 : 1;
          bytes = new Uint8ClampedArray(width * height * 4);
          i = 0;
        case 11:
          // Outer loop over pixels: i in [0, height * width).
          if (!(i < height * width)) {
            _context2.next = 36;
            break;
          }
          // Defaults: black, fully opaque (alpha stays 255 when depth < 4).
          rgba = [0, 0, 0, 255];
          d = 0;
        case 14:
          // Inner loop over the channels of pixel i.
          if (!(d < depth)) {
            _context2.next = 28;
            break;
          }
          value = data[i * depth + d];
          if (!($img.dtype === 'float32')) {
            _context2.next = 21;
            break;
          }
          if (!(value < 0 || value > 1)) {
            _context2.next = 19;
            break;
          }
          throw new Error("Tensor values for a float32 Tensor must be in the " + "range [0 - 1] but encountered ".concat(value, "."));
        case 19:
          _context2.next = 24;
          break;
        case 21:
          if (!($img.dtype === 'int32')) {
            _context2.next = 24;
            break;
          }
          if (!(value < 0 || value > 255)) {
            _context2.next = 24;
            break;
          }
          throw new Error("Tensor values for a int32 Tensor must be in the " + "range [0 - 255] but encountered ".concat(value, "."));
        case 24:
          // A single channel is replicated into r, g and b (grayscale).
          if (depth === 1) {
            rgba[0] = value * multiplier;
            rgba[1] = value * multiplier;
            rgba[2] = value * multiplier;
          } else {
            rgba[d] = value * multiplier;
          }
        case 25:
          d++;
          _context2.next = 14;
          break;
        case 28:
          // Pack the pixel into 4 consecutive bytes.
          j = i * 4;
          bytes[j + 0] = Math.round(rgba[0]);
          bytes[j + 1] = Math.round(rgba[1]);
          bytes[j + 2] = Math.round(rgba[2]);
          bytes[j + 3] = Math.round(rgba[3]);
        case 33:
          ++i;
          _context2.next = 11;
          break;
        case 36:
          if (canvas != null) {
            // Warn once (module-level flag defined outside this block) when the
            // active backend registers a Draw kernel that could be used instead.
            if (!hasToPixelsWarned) {
              kernel = getKernel(Draw, ENGINE.backendName);
              if (kernel != null) {
                console.warn('tf.browser.toPixels is not efficient to draw tensor on canvas. ' + 'Please try tf.browser.draw instead.');
                hasToPixelsWarned = true;
              }
            }
            canvas.width = width;
            canvas.height = height;
            ctx = canvas.getContext('2d');
            imageData = new ImageData(bytes, width, height);
            ctx.putImageData(imageData, 0, 0);
          }
          // Only dispose the tensor created here; a caller-owned Tensor is kept.
          if ($img !== img) {
            $img.dispose();
          }
          return _context2.abrupt("return", bytes);
        case 39:
        case "end":
          return _context2.stop();
      }
    }, _callee2);
  }));
  return _toPixels.apply(this, arguments);
}
/**
 * Implementation of `tf.browser.draw`: validates the image tensor and its
 * options, then dispatches the backend's `Draw` kernel to paint `canvas`.
 *
 * Non-Tensor input (e.g. a nested array) is converted and cast to int32.
 * That temporary tensor belongs to this function and is disposed once
 * drawing finishes or validation throws. Previously it leaked on every call.
 * This mirrors the `$img !== img` disposal done by `_toPixels`.
 *
 * @param image Tensor or tensor-like holding the pixels.
 * @param canvas The canvas to draw to.
 * @param options Optional draw options; `options.imageOptions` is validated.
 */
function draw$1(image, canvas, options) {
  var $img = convertToTensor(image, 'img', 'draw');
  if (!(image instanceof Tensor)) {
    // Assume int32 if user passed a native array.
    var originalImgTensor = $img;
    $img = cast$3(originalImgTensor, 'int32');
    originalImgTensor.dispose();
  }
  try {
    validateImgTensor($img);
    validateImageOptions(options === null || options === void 0 ? void 0 : options.imageOptions);
    var inputs = {
      image: $img
    };
    var attrs = {
      canvas: canvas,
      options: options
    };
    ENGINE.runKernel(Draw, inputs, attrs);
  } finally {
    // Release only the tensor created above; a caller-owned Tensor is kept.
    if ($img !== image) {
      $img.dispose();
    }
  }
}
// Public `fromPixels` op: `op()` (defined elsewhere in this file) wraps the
// raw `fromPixels_` implementation.
var fromPixels$1 = /* @__PURE__ */op({
  fromPixels_: fromPixels_
});

// Namespace object for the browser helpers (referred to as `tf.browser.*` in
// the toPixels warning). `__proto__: null` gives it a null prototype, so no
// Object.prototype keys are inherited.
var browser = {
  __proto__: null,
  draw: draw$1,
  fromPixels: fromPixels$1,
  fromPixelsAsync: fromPixelsAsync,
  toPixels: toPixels
};
49119
/**
 * Validates the inputs of `tf.gatherND()` and derives the values a backend
 * needs to execute it.
 *
 * @param tensor The tensor containing the source values (rank >= 1, non-empty).
 * @param indices The int32 tensor (rank >= 1) whose innermost dimension holds
 *     the coordinates used to slice `tensor`.
 *
 * @returns [resultShape, numUpdates, sliceSize, strides]:
 *     - resultShape = indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
 *     - numUpdates  = product of indices.shape[:-1]
 *     - sliceSize   = product of tensor.shape[indices.shape[-1]:]
 *     - strides     = the tensor's strides (with a trailing 1), expressed in
 *                     units of slices and cut to indices.shape[-1] entries.
 */
function prepareAndValidate(tensor, indices) {
  var inputShape = tensor.shape;
  var indicesShape = indices.shape;
  var tensorRank = inputShape.length;
  var indicesRank = indicesShape.length;
  if (tensorRank < 1) {
    throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' + " but the rank was ".concat(tensorRank, "."));
  }
  if (indicesRank < 1) {
    throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' + " but the rank was ".concat(indicesRank, "."));
  }
  if (indices.dtype !== 'int32') {
    throw new Error('tf.gatherND() expects the indices to be int32 type,' + " but the dtype was ".concat(indices.dtype, "."));
  }
  // Number of leading axes of `tensor` addressed by one index tuple.
  var sliceRank = indicesShape[indicesRank - 1];
  if (sliceRank > tensorRank) {
    throw new Error('index innermost dimension length must be <= tensor rank; saw: ' + "".concat(sliceRank, " vs. ").concat(tensorRank));
  }
  if (sizeFromShape(inputShape) === 0) {
    throw new Error('Requested more than 0 entries, but input is empty.' + " Input shape: ".concat(inputShape, "."));
  }
  var product = function (dims) {
    return dims.reduce(function (acc, dim) {
      return acc * dim;
    }, 1);
  };
  var batchShape = indicesShape.slice(0, -1);
  var sliceShape = inputShape.slice(sliceRank);
  var nResult = product(batchShape);
  var sliceSize = product(sliceShape);
  var resultShape = batchShape.concat(sliceShape);
  var strides = computeStrides(inputShape).map(function (stride) {
    return stride / sliceSize;
  }).concat([1]).slice(0, sliceRank);
  return [resultShape, nResult, sliceSize, strides];
}
49167
// Namespace object exposing the gather-nd helpers. `__proto__: null` gives it
// a null prototype, so no Object.prototype keys are inherited.
var gather_nd_util = {
  __proto__: null,
  prepareAndValidate: prepareAndValidate
};
49172
// Sentinel values stored in `finalShapeGatherIndices` by buildDenseSpec.
// NEW_AXIS marks a dimension inserted via newAxisMask; SHRINK_AXIS marks a
// dimension removed via shrinkAxisMask. Both are negative, so they never
// collide with a real (non-negative) dense axis index.
var NEW_AXIS = -2;
var SHRINK_AXIS = -1;
/**
 * Asserts that `begin`/`size` describe a valid slice of `input`: both have
 * one entry per axis, and no axis is overrun (begin[i] + size[i] must not
 * exceed input.shape[i]).
 */
function assertParamsValid(input, begin, size) {
  var shape = input.shape;
  var rank = shape.length;
  assert$1(rank === begin.length, function () {
    return "Error in slice".concat(rank, "D: Length of begin ").concat(begin, " must ") + "match the rank of the array (".concat(rank, ").");
  });
  assert$1(rank === size.length, function () {
    return "Error in slice".concat(rank, "D: Length of size ").concat(size, " must ") + "match the rank of the array (".concat(rank, ").");
  });
  shape.forEach(function (dim, axis) {
    var stop = begin[axis] + size[axis];
    assert$1(stop <= dim, function () {
      return "Error in slice".concat(rank, "D: begin[").concat(axis, "] + size[").concat(axis, "] ") + "(".concat(stop, ") would overflow input.shape[").concat(axis, "] (").concat(dim, ")");
    });
  });
}
/**
 * Converts a binary mask to an array of axes. Used in stridedSlice().
 *
 * Bit k of `mask` set means axis k is included. Returns the set bit
 * positions in ascending order; non-positive masks yield [].
 *
 * The mask is halved with Math.floor. The previous `mask /= 2` kept halving
 * through fractional values until the float underflowed to 0, which took
 * roughly 1075 iterations per call for any positive mask. Floor division
 * visits each bit once (at most 53 steps) and produces exactly the same axes,
 * because `&` already truncates its operand to an integer.
 */
function maskToAxes(mask) {
  var axes = [];
  var axis = 0;
  while (mask > 0) {
    if (mask & 1) {
      axes.push(axis);
    }
    mask = Math.floor(mask / 2);
    axis++;
  }
  return axes;
}
/**
 * Computes the output shape given the strided slice params: for each axis,
 * the number of stride steps needed to cover [begin, end), rounded up.
 */
function computeOutShape$2(begin, end, strides) {
  var outShape = [];
  var rank = begin.length;
  for (var d = 0; d < rank; d++) {
    var span = end[d] - begin[d];
    outShape.push(Math.ceil(span / strides[d]));
  }
  return outShape;
}
/**
 * Builds the stride list once an ellipsis has been expanded.
 *
 * The strides are copied and right-padded with 1 up to the input rank. The
 * slot at the ellipsis position is then overwritten with 1, and for each
 * further elided axis a 1 is inserted there while the last entry is dropped,
 * which keeps the length unchanged.
 */
function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
  var result = Array.from(strides);
  while (result.length < inputShape.length) {
    result.push(1);
  }
  if (numElidedAxes > 0) {
    result[ellipsisInsertionIndex] = 1;
    for (var extra = 1; extra < numElidedAxes; extra++) {
      result.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, 1 /* element to add */);
      result.pop();
    }
  }
  return result;
}
/**
 * Maps an axis of the expanded (dense) index space back to the matching axis
 * of the original spec. Axes up to and including the ellipsis position are
 * unchanged; later axes shift left by the extra elided axes.
 */
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
  var isBeforeOrAtEllipsis = normalizedAxis <= ellipsisInsertionIndex;
  return isBeforeOrAtEllipsis ? normalizedAxis : normalizedAxis - (numElidedAxes - 1);
}
/**
 * Lists the dense axes covered by an ellipsis: `numElidedAxes` consecutive
 * indices starting at `ellipsisInsertionIndex`.
 */
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
  var axes = [];
  while (axes.length < numElidedAxes) {
    axes.push(ellipsisInsertionIndex + axes.length);
  }
  return axes;
}
/**
 * Normalize the start, end and strides.
 *
 * When an ellipsis is present and expands to extra axes, the elided-dims
 * helpers rebuild all three lists. Otherwise each axis is normalized on its
 * own by startForAxis / stopForAxis / stridesForAxis.
 *
 * @returns {{begin: number[], end: number[], strides: number[]}}
 */
function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
  if (ellipsisAxes.length && numInterpolatedAxes > 0) {
    var ellipsisIndex = ellipsisAxes[0];
    // The ellipsis applies to the masked index as well as any dimensions
    // that are interpolated.
    var elidedCount = numInterpolatedAxes + 1;
    return {
      begin: startIndicesWithElidedDims(beginMask, ellipsisIndex, elidedCount, begin, inputShape),
      end: stopIndicesWithElidedDims(endMask, ellipsisIndex, elidedCount, end, inputShape),
      strides: stridesWithElidedDims(strides, ellipsisIndex, elidedCount, inputShape)
    };
  }
  var rank = inputShape.length;
  var starts = new Array(rank);
  var stops = new Array(rank);
  var steps = new Array(rank);
  for (var axis = 0; axis < rank; axis++) {
    starts[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
    stops[axis] = stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
    steps[axis] = stridesForAxis(strides, axis, ellipsisMask);
  }
  return {
    begin: starts,
    end: stops,
    strides: steps
  };
}
/**
 * Creates full selection at the elided dimensions: every axis covered by the
 * ellipsis starts at 0.
 *
 * Every other axis takes its begin value from the original (pre-expansion)
 * spec, or 0 when that original axis is set in `beginMask`. The result has
 * one entry per input dimension.
 */
function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
  var elided = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
  var starts = [];
  for (var axis = 0; axis < inputShape.length; axis++) {
    if (elided.indexOf(axis) !== -1) {
      starts.push(0);
      continue;
    }
    var sourceAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
    starts.push(beginMask & 1 << sourceAxis ? 0 : originalBegin[sourceAxis]);
  }
  return starts;
}
/**
 * Creates full selection at the elided dimensions: every axis covered by the
 * ellipsis stops at the end of that axis.
 *
 * Every other axis takes its end value from the original (pre-expansion)
 * spec, or "to the end" when that original axis is set in `endMask`.
 * Negative stops count from the end of the axis, and each stop is finally
 * clamped to [0, inputShape[axis]].
 */
function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
  var elided = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
  var stops = [];
  for (var axis = 0; axis < inputShape.length; axis++) {
    var stop;
    if (elided.indexOf(axis) !== -1) {
      stop = Number.MAX_SAFE_INTEGER;
    } else {
      var sourceAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
      stop = endMask & 1 << sourceAxis ? Number.MAX_SAFE_INTEGER : originalEnd[sourceAxis];
    }
    var extent = inputShape[axis];
    // Handle negative indices
    if (stop < 0) {
      stop += extent;
    }
    stops.push(clamp(0, stop, extent));
  }
  return stops;
}
/**
 * Returns the stride for `axis`: 1 when the axis is part of the ellipsis or
 * has no stride given, otherwise the given stride.
 */
function stridesForAxis(strides, axis, ellipsisMask) {
  var stride = strides[axis];
  var inEllipsis = (ellipsisMask & 1 << axis) !== 0;
  return inEllipsis || stride == null ? 1 : stride;
}
/**
 * Computes the normalized begin index for one axis of a strided slice.
 *
 * If the axis is masked (begin or ellipsis mask) or has no begin value, the
 * slice starts at the first element for forward strides and at the last
 * element for backward strides. Negative indices count from the end. The
 * result is clamped to [0, axisSize - 1].
 */
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
  var bit = 1 << axis;
  var stride = strides[axis] || 1;
  var start = startIndices[axis];
  var useDefault = (beginMask & bit) !== 0 || (ellipsisMask & bit) !== 0 || start == null;
  if (useDefault) {
    // Use extreme sentinels and let the clamp below pick the real endpoint
    // (kept symmetric with stopForAxis).
    start = stride > 0 ? Number.MIN_SAFE_INTEGER : Number.MAX_SAFE_INTEGER;
  }
  var axisSize = inputShape[axis];
  if (start < 0) {
    start += axisSize;
  }
  return clamp(0, start, axisSize - 1);
}
/**
 * Computes the normalized (exclusive) end index for one axis of a strided
 * slice.
 *
 * If the axis is masked (end or ellipsis mask) or has no end value, the
 * slice runs to the last element for forward strides and to the first
 * element for backward strides. Negative indices count from the end.
 *
 * Because the end index points one past the last element, clamping depends
 * on direction:
 * - forward strides clamp to [0, axisSize];
 * - backward strides clamp to [-1, axisSize - 1].
 */
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
  var bit = 1 << axis;
  var stride = strides[axis] || 1;
  var isForward = stride > 0;
  var stop = stopIndices[axis];
  if ((endMask & bit) !== 0 || (ellipsisMask & bit) !== 0 || stop == null) {
    stop = isForward ? Number.MAX_SAFE_INTEGER : Number.MIN_SAFE_INTEGER;
  }
  var axisSize = inputShape[axis];
  if (stop < 0) {
    stop += axisSize;
  }
  return isForward ? clamp(0, stop, axisSize) : clamp(-1, stop, axisSize - 1);
}
/**
 * Returns true if the slice occupies a contiguous set of elements in the
 * 'flat' (row-major) space.
 *
 * That holds when every axis after the first axis with size > 1 is taken in
 * full: it starts at 0 and spans the whole dimension.
 */
function isSliceContinous(shape, begin, size) {
  var rank = size.length;
  // Skip leading axes that select at most one element.
  var axis = 0;
  while (axis < rank && !(size[axis] > 1)) {
    axis++;
  }
  for (var later = axis + 1; later < rank; later++) {
    var takesWholeAxis = !(begin[later] > 0) && size[later] === shape[later];
    if (!takesWholeAxis) {
      return false;
    }
  }
  return true;
}
/**
 * Returns the row-major flat offset of the position `begin`:
 * sum(begin[i] * strides[i] for i < rank - 1) + begin[rank - 1].
 *
 * @param begin Start coordinates, one per axis.
 * @param strides Row-major strides of the outer axes (length rank - 1; the
 *     innermost stride is implicitly 1).
 * @returns The flat offset. For an empty `begin` (rank 0) this is 0, the
 *     offset of the only element; previously 1 was returned, which points one
 *     past it.
 */
function computeFlatOffset(begin, strides) {
  var rank = begin.length;
  if (rank === 0) {
    return 0;
  }
  var flatOffset = begin[rank - 1];
  for (var i = 0; i < rank - 1; i++) {
    flatOffset += begin[i] * strides[i];
  }
  return flatOffset;
}
/**
 * Normalizes the `begin` / `size` arguments of `slice()` to full per-axis
 * arrays.
 *
 * `begin` may be a number (first axis only) or a shorter array; missing
 * entries default to 0. `size` may be omitted, a number, or a shorter array;
 * missing entries default to -1, meaning "to the end of the axis". That -1 is
 * then resolved to `x.shape[i] - begin[i]`.
 *
 * @returns [begin, size], each with one entry per axis of `x`.
 * @throws (via assert$1) if any begin index is negative, or if a size is
 *     negative but not exactly -1. Previously only a begin of exactly -1 was
 *     rejected, so e.g. -2 slipped through despite the error message.
 */
function parseSliceParams(x, begin, size) {
  // The following logic allows for more ergonomic calls.
  var begin_;
  var xRank = x.shape.length;
  if (typeof begin === 'number') {
    begin_ = [begin].concat(new Array(xRank - 1).fill(0));
  } else if (begin.length < xRank) {
    begin_ = begin.concat(new Array(xRank - begin.length).fill(0));
  } else {
    begin_ = begin.slice();
  }
  begin_.forEach(function (d) {
    assert$1(!(d < 0), function () {
      return 'slice() does not support negative begin indexing.';
    });
  });
  var size_;
  if (size == null) {
    size_ = new Array(xRank).fill(-1);
  } else if (typeof size === 'number') {
    size_ = [size].concat(new Array(xRank - 1).fill(-1));
  } else if (size.length < xRank) {
    size_ = size.concat(new Array(xRank - size.length).fill(-1));
  } else {
    size_ = size;
  }
  size_ = size_.map(function (d, i) {
    if (d >= 0) {
      return d;
    } else {
      assert$1(d === -1, function () {
        return "Negative size values should be exactly -1 but got " + "".concat(d, " for the slice() size at index ").concat(i, ".");
      });
      return x.shape[i] - begin_[i];
    }
  });
  return [begin_, size_];
}
// Convert the slicing specification from a sparse representation to a dense
// representation. This means that all ellipses and newaxis are expanded out.
//
// Returns an object with:
// - the final shape (including new axes, excluding shrunk axes);
// - finalShapeSparse;
// - optimization flags (isIdentity, sliceDim0, isSimpleSlice);
// - the dense per-dimension begin/end/strides.
// Throws on multiple ellipses, zero strides, non-positive strides on shrunk
// axes, and out-of-range shrink indices.
function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
  var stridesNonNull;
  // Missing strides default to 1 for every sparse entry.
  if (strides == null) {
    stridesNonNull = new Array(begin.length);
    stridesNonNull.fill(1);
  } else {
    stridesNonNull = strides;
  }
  // Only one non-zero bit is allowed in ellipsisMask, which means ellipsisMask
  // is a power of 2. Use bit compares to ensure ellipsisMask is 0 or a power
  // of 2. When i is a power of 2, i & (i - 1) is always 0.
  // Also ref:
  // https://stackoverflow.com/questions/600293/how-to-check-if-a-number-is-a-power-of-2
  if (ellipsisMask != null && (ellipsisMask & ellipsisMask - 1) !== 0) {
    throw new Error('Multiple ellipses in slice is not allowed.');
  }
  // Step 1: Account for ellipsis and new axis.
  // Check for an ellipsis and count how many new axes appear after it.
  var ellipsisSeen = false;
  var sparseSpec = {
    dims: stridesNonNull.length,
    numAddAxisAfterEllipsis: 0,
    begin: begin.slice(),
    end: end.slice(),
    strides: stridesNonNull.slice(),
    beginMask: beginMask,
    endMask: endMask,
    ellipsisMask: ellipsisMask,
    newAxisMask: newAxisMask,
    shrinkAxisMask: shrinkAxisMask
  };
  for (var i = 0; i < sparseSpec.dims; i++) {
    if (ellipsisSeen && (1 << i & newAxisMask) !== 0) {
      sparseSpec.numAddAxisAfterEllipsis++;
    }
    if (1 << i & ellipsisMask) {
      ellipsisSeen = true;
    }
  }
  // If no ellipsis insert one at the end.
  if (!ellipsisSeen) {
    sparseSpec.ellipsisMask |= 1 << sparseSpec.dims;
    sparseSpec.dims++; // this affects loop iteration below
  }
  // Step 2: Make a sparse spec into a full index spec.
  //
  // The sparse spec does not correspond to the number of dimensions.
  // Make a dense spec that corresponds to the number of dimensions.
  //
  // For example suppose foo[...,3:] on foo.shape = [2, 2, 3] then we need to
  // produce the missing beginMask for the first two dimensions i.e. from
  // beginMaskSpec = 0, endMaskSpec = 2, we achieve beginMask = 6 (110),
  // endMask = 7 (111).
  var denseSpec = {
    dims: xShape.length,
    beginMask: 0,
    endMask: 0,
    beginValid: false,
    endValid: false
  };
  // Fills denseSpec.begin/end/strides/masks and the gather-index lists.
  buildDenseSpec(sparseSpec, denseSpec);
  // Step 3: Make implicit ranges (non-zero beginMasks and endMasks) explicit
  // and bounds check.
  var isIdentity = true;
  var sliceDim0 = true;
  var isSimpleSlice = true;
  var processingShape = [];
  var finalShape = [];
  for (var _i3 = 0; _i3 < xShape.length; ++_i3) {
    if (denseSpec.strides[_i3] === 0) {
      throw Error("strides[".concat(_i3, "] must be non-zero"));
    }
    var shrinkI = !!(denseSpec.shrinkAxisMask & 1 << _i3);
    var dimI = xShape[_i3];
    // A dimension of -1 is treated as unknown size; nothing more can be
    // computed for it.
    if (dimI === -1) {
      processingShape.push(shrinkI ? 1 : -1);
      continue;
    }
    // masks = [begin bit, end bit]; validRange = clamp bounds for
    // canonical(), which depend on the stride direction.
    var masks = [denseSpec.beginMask & 1 << _i3, denseSpec.endMask & 1 << _i3];
    var validRange = [denseSpec.strides[_i3] > 0 ? 0 : -1, denseSpec.strides[_i3] > 0 ? dimI : dimI - 1];
    if (shrinkI && denseSpec.strides[_i3] <= 0) {
      throw Error('only stride 1 allowed on non-range indexing.');
    }
    isSimpleSlice = isSimpleSlice && denseSpec.strides[_i3] === 1;
    var beginAndEndMasked = !!(denseSpec.beginMask & 1 << _i3 && denseSpec.endMask & 1 << _i3);
    if (denseSpec.beginValid && denseSpec.endValid) {
      if (shrinkI) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular foo[-1] produces sparseBegin = -1, sparseEnd = 0.
        // and canonical puts these to n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin + 1.
        var xFwd = denseSpec.begin[_i3] < 0 ? dimI + denseSpec.begin[_i3] : denseSpec.begin[_i3];
        denseSpec.begin[_i3] = xFwd;
        denseSpec.end[_i3] = denseSpec.begin[_i3] + 1;
        // The message reports the forward (non-negative-adjusted) index.
        if (xFwd < 0 || xFwd >= dimI) {
          throw Error("slice index ".concat(denseSpec.begin[_i3], " of dimension ").concat(_i3, " out of bounds."));
        }
      } else {
        denseSpec.begin[_i3] = canonical(denseSpec.begin[_i3], 0, denseSpec.strides[_i3], dimI, masks, validRange);
        denseSpec.end[_i3] = canonical(denseSpec.end[_i3], 1, denseSpec.strides[_i3], dimI, masks, validRange);
      }
      // Update optimization values
      var takeAllInDimension = denseSpec.strides[_i3] === 1 && denseSpec.begin[_i3] === 0 && denseSpec.end[_i3] === dimI;
      isIdentity = isIdentity && takeAllInDimension;
      sliceDim0 = sliceDim0 && (_i3 === 0 && denseSpec.strides[_i3] === 1 || takeAllInDimension);
    } else {
      isIdentity = isIdentity && denseSpec.strides[_i3] === 1 && beginAndEndMasked;
      sliceDim0 = sliceDim0 && (_i3 === 0 && denseSpec.strides[_i3] === 1 || beginAndEndMasked);
    }
    // Compute the processing shape (the intermediate Eigen will produce)
    var intervalLength = void 0;
    var knownInterval = false;
    if (denseSpec.beginValid && denseSpec.endValid) {
      intervalLength = denseSpec.end[_i3] - denseSpec.begin[_i3];
      knownInterval = true;
    } else if (shrinkI) {
      // The dimension is still known as 1 for the processingShape, but will be
      // discarded for the final shape.
      intervalLength = 1;
      knownInterval = true;
    } else if (beginAndEndMasked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dimI >= 0) {
        if (denseSpec.strides[_i3] < 0) {
          intervalLength = -dimI;
        } else {
          intervalLength = dimI;
        }
        knownInterval = true;
      }
    }
    if (knownInterval) {
      var sizeI = void 0;
      // Hold zero if the interval is degenerate (empty, or its sign disagrees
      // with the stride direction), otherwise ceil(length / stride).
      if (intervalLength === 0 || intervalLength < 0 !== denseSpec.strides[_i3] < 0) {
        sizeI = 0;
      } else {
        sizeI = Math.trunc(intervalLength / denseSpec.strides[_i3]) + (intervalLength % denseSpec.strides[_i3] !== 0 ? 1 : 0);
      }
      processingShape.push(sizeI);
    } else {
      processingShape.push(-1);
    }
  }
  // Step 4: Compute the final shape
  //
  // newAxis will increase dimension by 1 (with a one-size dimension)
  // slices like foo[3, ...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
  for (var denseDim = 0; denseDim < denseSpec.finalShapeGatherIndices.length; ++denseDim) {
    var gatherIndex = denseSpec.finalShapeGatherIndices[denseDim];
    if (gatherIndex >= 0) {
      finalShape.push(processingShape[gatherIndex]);
    } else if (gatherIndex === NEW_AXIS) {
      finalShape.push(1);
    }
    // SHRINK_AXIS entries contribute nothing to finalShape.
  }
  // NOTE(review): `finalShape` gets no entry for SHRINK_AXIS gather indices.
  // So when a shrink axis precedes a new axis, index `i` below is not aligned
  // with `finalShapeGatherIndices[i]`, and the wrong entry may be dropped.
  // Verify whether the intent is "finalShape minus its NEW_AXIS entries".
  var finalShapeSparse = finalShape.filter(function (dim, i) {
    return denseSpec.finalShapeGatherIndices[i] !== NEW_AXIS;
  });
  return {
    finalShapeSparse: finalShapeSparse,
    finalShape: finalShape,
    isIdentity: isIdentity,
    sliceDim0: sliceDim0,
    isSimpleSlice: isSimpleSlice,
    begin: denseSpec.begin,
    end: denseSpec.end,
    strides: denseSpec.strides
  };
}
// Expands the sparse slice spec into one entry per input dimension, writing
// the results into `dense` (mutated in place). It sets:
// - begin/end/strides and the begin/end/shrink masks, indexed by dense
//   dimension;
// - finalShapeGatherIndices: one entry per output axis. Each entry is a dense
//   dim index, NEW_AXIS, or SHRINK_AXIS.
// - finalShapeGatherIndicesSparse and inputShapeGatherIndicesSparse: map
//   entries back to sparse positions.
// Throws if the sparse spec addresses more dimensions than `dense.dims`.
function buildDenseSpec(sparse, dense) {
  dense.beginMask = 0;
  dense.endMask = 0;
  dense.shrinkAxisMask = 0;
  // Next dense dimension to fill.
  var fullIndex = 0;
  dense.beginValid = sparse.begin != null;
  dense.endValid = sparse.end != null;
  dense.begin = new Array(dense.dims);
  dense.end = new Array(dense.dims);
  dense.strides = new Array(dense.dims);
  dense.finalShapeGatherIndices = [];
  dense.finalShapeGatherIndicesSparse = [];
  dense.inputShapeGatherIndicesSparse = new Array(dense.dims);
  for (var i = 0; i < sparse.dims; i++) {
    if (1 << i & sparse.ellipsisMask) {
      // Only the bit that has ellipsis will fall in this condition.
      // Expand the ellipsis into the appropriate indices
      // Note: this only works because we guaranteed one ellipsis.
      // The ellipsis fills dense dims up to the point where the remaining
      // real (non-newaxis) sparse entries after it still fit, capped at
      // dense.dims.
      var nextIndex = Math.min(dense.dims - (sparse.dims - i) + 1 + sparse.numAddAxisAfterEllipsis, dense.dims);
      for (; fullIndex < nextIndex; fullIndex++) {
        // newAxis aren't real axis so you have to skip.
        // Each elided dim is a full, masked range with stride 1.
        dense.begin[fullIndex] = 0;
        dense.end[fullIndex] = 0;
        dense.strides[fullIndex] = 1;
        dense.beginMask |= 1 << fullIndex;
        dense.endMask |= 1 << fullIndex;
        dense.finalShapeGatherIndices.push(fullIndex);
        dense.finalShapeGatherIndicesSparse.push(-1);
        dense.inputShapeGatherIndicesSparse[fullIndex] = i;
      }
    } else if (1 << i & sparse.newAxisMask) {
      // Only the bit that has newAxis will fall in this condition.
      // A new axis consumes no dense dimension; it only adds an output axis.
      dense.finalShapeGatherIndices.push(NEW_AXIS);
      dense.finalShapeGatherIndicesSparse.push(-1);
    } else {
      if (fullIndex === dense.begin.length) {
        throw Error("Index out of range using input dim ".concat(fullIndex, "; input ") + "has only ".concat(dense.dims, " dims, ").concat(dense.begin.length, "."));
      }
      // Gather slicing spec into appropriate index.
      if (sparse.begin != null) {
        dense.begin[fullIndex] = sparse.begin[i];
      }
      if (sparse.end != null) {
        dense.end[fullIndex] = sparse.end[i];
      }
      dense.strides[fullIndex] = sparse.strides[i];
      // Re-key the sparse mask bits to the dense position.
      if (sparse.beginMask & 1 << i) {
        dense.beginMask |= 1 << fullIndex;
      }
      if (sparse.endMask & 1 << i) {
        dense.endMask |= 1 << fullIndex;
      }
      // If shrink, record where to get the dimensionality from (i.e. newAxis)
      // creates a fake 1 size dimension. Also remember shrink axis (now in
      // dense form) so we can ignore dense.end below.
      if (sparse.shrinkAxisMask & 1 << i) {
        dense.finalShapeGatherIndices.push(SHRINK_AXIS);
        dense.finalShapeGatherIndicesSparse.push(-1);
        dense.shrinkAxisMask |= 1 << fullIndex;
      } else {
        dense.finalShapeGatherIndices.push(fullIndex);
        // Remember that where in the sparse shape the dense dim comes from.
        dense.finalShapeGatherIndicesSparse.push(i);
      }
      dense.inputShapeGatherIndicesSparse[fullIndex] = i;
      fullIndex++;
    }
  }
}
/**
 * Resolves one begin (c = 0) or end (c = 1) index of a strided slice.
 *
 * Masked indices map to the matching edge of `validRange`; the edge is
 * swapped when the stride is not positive. Otherwise negative indices count
 * from the end of the axis and the result is clamped to
 * [validRange[0], validRange[1]].
 */
function canonical(x, c, strideI, dimI, masks, validRange) {
  if (!masks[c]) {
    // make negative indices positive
    var forward = x < 0 ? dimI + x : x;
    if (forward < validRange[0]) {
      return validRange[0];
    }
    if (forward > validRange[1]) {
      return validRange[1];
    }
    return forward;
  }
  var edge = strideI > 0 ? c : (c + 1) & 1;
  return validRange[edge];
}
49700
// Namespace object exposing the slice helpers defined above. `__proto__: null`
// gives it a null prototype, so no Object.prototype keys are inherited. Note
// that `computeOutShape` is exported under its un-suffixed name.
var slice_util = {
  __proto__: null,
  assertParamsValid: assertParamsValid,
  computeFlatOffset: computeFlatOffset,
  computeOutShape: computeOutShape$2,
  getNormalizedAxes: getNormalizedAxes,
  isSliceContinous: isSliceContinous,
  maskToAxes: maskToAxes,
  parseSliceParams: parseSliceParams,
  sliceInfo: sliceInfo,
  startForAxis: startForAxis,
  startIndicesWithElidedDims: startIndicesWithElidedDims,
  stopForAxis: stopForAxis,
  stopIndicesWithElidedDims: stopIndicesWithElidedDims,
  stridesForAxis: stridesForAxis,
  stridesWithElidedDims: stridesWithElidedDims
};
49718
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Version string stamped in by the build (presumably the tfjs-core package
// version; confirm against the generator).
var version$7 = '4.22.0';
49722
49723 var OptimizerConstructors = /*#__PURE__*/function () {
49724 function OptimizerConstructors() {
49725 _classCallCheck(this, OptimizerConstructors);
49726 }
49727 _createClass(OptimizerConstructors, null, [{
49728 key: "sgd",
49729 value:
49730 /**
49731 * Constructs a `tf.SGDOptimizer` that uses stochastic gradient descent.
49732 *
49733 * ```js
49734 * // Fit a quadratic function by learning the coefficients a, b, c.
49735 * const xs = tf.tensor1d([0, 1, 2, 3]);
49736 * const ys = tf.tensor1d([1.1, 5.9, 16.8, 33.9]);
49737 *
49738 * const a = tf.scalar(Math.random()).variable();
49739 * const b = tf.scalar(Math.random()).variable();
49740 * const c = tf.scalar(Math.random()).variable();
49741 *
49742 * // y = a * x^2 + b * x + c.
49743 * const f = x => a.mul(x.square()).add(b.mul(x)).add(c);
49744 * const loss = (pred, label) => pred.sub(label).square().mean();
49745 *
49746 * const learningRate = 0.01;
49747 * const optimizer = tf.train.sgd(learningRate);
49748 *
49749 * // Train the model.
49750 * for (let i = 0; i < 10; i++) {
49751 * optimizer.minimize(() => loss(f(xs), ys));
49752 * }
49753 *
49754 * // Make predictions.
49755 * console.log(
49756 * `a: ${a.dataSync()}, b: ${b.dataSync()}, c: ${c.dataSync()}`);
49757 * const preds = f(xs).dataSync();
49758 * preds.forEach((pred, i) => {
49759 * console.log(`x: ${i}, pred: ${pred}`);
49760 * });
49761 * ```
49762 *
49763 * @param learningRate The learning rate to use for the SGD algorithm.
49764 *
49765 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
49766 */
49767 function sgd(learningRate) {
49768 return new SGDOptimizer(learningRate);
49769 }
49770 /**
49771 * Constructs a `tf.MomentumOptimizer` that uses momentum gradient
49772 * descent.
49773 *
49774 * See
49775 * [http://proceedings.mlr.press/v28/sutskever13.pdf](
49776 * http://proceedings.mlr.press/v28/sutskever13.pdf)
49777 *
49778 * @param learningRate The learning rate to use for the Momentum gradient
49779 * descent algorithm.
49780 * @param momentum The momentum to use for the momentum gradient descent
49781 * algorithm.
49782 *
49783 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
49784 */
49785 }, {
49786 key: "momentum",
49787 value: function momentum(learningRate, _momentum) {
49788 var useNesterov = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
49789 return new MomentumOptimizer(learningRate, _momentum, useNesterov);
49790 }
49791 /**
49792 * Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
49793 * descent. This implementation uses plain momentum and is not centered
49794 * version of RMSProp.
49795 *
49796 * See
49797 * [http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf](
49798 * http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
49799 *
49800 * @param learningRate The learning rate to use for the RMSProp gradient
49801 * descent algorithm.
49802 * @param decay The discounting factor for the history/coming gradient.
49803 * @param momentum The momentum to use for the RMSProp gradient descent
49804 * algorithm.
49805 * @param epsilon Small value to avoid zero denominator.
49806 * @param centered If true, gradients are normalized by the estimated
49807 * variance of the gradient.
49808 *
49809 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
49810 */
49811 }, {
49812 key: "rmsprop",
49813 value: function rmsprop(learningRate) {
49814 var decay = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : .9;
49815 var momentum = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0.0;
49816 var epsilon = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : null;
49817 var centered = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
49818 return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
49819 }
49820 /**
49821 * Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
49822 * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
49823 *
49824 * @param learningRate The learning rate to use for the Adam gradient
49825 * descent algorithm.
49826 * @param beta1 The exponential decay rate for the 1st moment estimates.
49827 * @param beta2 The exponential decay rate for the 2nd moment estimates.
49828 * @param epsilon A small constant for numerical stability.
49829 *
49830 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
49831 */
49832 }, {
49833 key: "adam",
49834 value: function adam() {
49835 var learningRate = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 0.001;
49836 var beta1 = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.9;
49837 var beta2 = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0.999;
49838 var epsilon = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : null;
49839 return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
49840 }
49841 /**
49842 * Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
49843 * See [https://arxiv.org/abs/1212.5701](https://arxiv.org/abs/1212.5701)
49844 *
49845 * @param learningRate The learning rate to use for the Adadelta gradient
49846 * descent algorithm.
49847 * @param rho The learning rate decay over each update.
49848 * @param epsilon A constant epsilon used to better condition the grad
49849 * update.
49850 *
49851 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
49852 */
49853 }, {
49854 key: "adadelta",
49855 value: function adadelta() {
49856 var learningRate = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : .001;
49857 var rho = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : .95;
49858 var epsilon = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
49859 return new AdadeltaOptimizer(learningRate, rho, epsilon);
49860 }
49861 /**
49862 * Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
49863 * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
49864 *
49865 * @param learningRate The learning rate to use for the Adamax gradient
49866 * descent algorithm.
49867 * @param beta1 The exponential decay rate for the 1st moment estimates.
49868 * @param beta2 The exponential decay rate for the 2nd moment estimates.
49869 * @param epsilon A small constant for numerical stability.
49870 * @param decay The learning rate decay over each update.
49871 *
49872 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
49873 */
49874 }, {
49875 key: "adamax",
49876 value: function adamax() {
49877 var learningRate = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 0.002;
49878 var beta1 = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.9;
49879 var beta2 = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0.999;
49880 var epsilon = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : null;
49881 var decay = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 0.0;
49882 return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
49883 }
49884 /**
49885 * Constructs a `tf.AdagradOptimizer` that uses the Adagrad algorithm.
49886 * See
49887 * [http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf](
49888 * http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
49889 * or
49890 * [http://ruder.io/optimizing-gradient-descent/index.html#adagrad](
49891 * http://ruder.io/optimizing-gradient-descent/index.html#adagrad)
49892 *
49893 * @param learningRate The learning rate to use for the Adagrad gradient
49894 * descent algorithm.
49895 * @param initialAccumulatorValue Starting value for the accumulators, must be
49896 * positive.
49897 *
49898 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
49899 */
49900 }, {
49901 key: "adagrad",
49902 value: function adagrad(learningRate) {
49903 var initialAccumulatorValue = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.1;
49904 return new AdagradOptimizer(learningRate, initialAccumulatorValue);
49905 }
49906 }]);
49907 return OptimizerConstructors;
49908 }();
49909
49910 /**
49911 * @license
49912 * Copyright 2018 Google LLC. All Rights Reserved.
49913 * Licensed under the Apache License, Version 2.0 (the "License");
49914 * you may not use this file except in compliance with the License.
49915 * You may obtain a copy of the License at
49916 *
49917 * http://www.apache.org/licenses/LICENSE-2.0
49918 *
49919 * Unless required by applicable law or agreed to in writing, software
49920 * distributed under the License is distributed on an "AS IS" BASIS,
49921 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
49922 * See the License for the specific language governing permissions and
49923 * limitations under the License.
49924 * =============================================================================
49925 */
49926 var train = OptimizerConstructors;
49927
49928 /**
49929 * @license
49930 * Copyright 2017 Google LLC. All Rights Reserved.
49931 * Licensed under the Apache License, Version 2.0 (the "License");
49932 * you may not use this file except in compliance with the License.
49933 * You may obtain a copy of the License at
49934 *
49935 * http://www.apache.org/licenses/LICENSE-2.0
49936 *
49937 * Unless required by applicable law or agreed to in writing, software
49938 * distributed under the License is distributed on an "AS IS" BASIS,
49939 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
49940 * See the License for the specific language governing permissions and
49941 * limitations under the License.
49942 * =============================================================================
49943 */
49944 var delayCallback = function () {
49945 if (typeof requestAnimationFrame !== 'undefined') {
49946 return requestAnimationFrame;
49947 } else if (typeof setImmediate !== 'undefined') {
49948 return setImmediate;
49949 }
49950 return function (f) {
49951 return f();
49952 }; // no delays
49953 }();
49954 /**
49955 * Returns a promise that resolves when a requestAnimationFrame has completed.
49956 *
49957 * On Node.js this uses setImmediate instead of requestAnimationFrame.
49958 *
49959 * This is simply a sugar method so that users can do the following:
49960 * `await tf.nextFrame();`
49961 *
49962 * @doc {heading: 'Performance', subheading: 'Timing'}
49963 */
49964 function nextFrame() {
49965 return new Promise(function (resolve) {
49966 return delayCallback(function () {
49967 return resolve();
49968 });
49969 });
49970 }
49971
49972 /**
49973 * @license
49974 * Copyright 2017 Google LLC. All Rights Reserved.
49975 * Licensed under the Apache License, Version 2.0 (the "License");
49976 * you may not use this file except in compliance with the License.
49977 * You may obtain a copy of the License at
49978 *
49979 * http://www.apache.org/licenses/LICENSE-2.0
49980 *
49981 * Unless required by applicable law or agreed to in writing, software
49982 * distributed under the License is distributed on an "AS IS" BASIS,
49983 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
49984 * See the License for the specific language governing permissions and
49985 * limitations under the License.
49986 * =============================================================================
49987 */
49988 function assertParamsConsistent(shapes, axis) {
49989 var rank = shapes[0].length;
49990 shapes.forEach(function (shape, i) {
49991 assert$1(shape.length === rank, function () {
49992 return "Error in concat".concat(rank, "D: rank of tensors[").concat(i, "] must be the same ") + "as the rank of the rest (".concat(rank, ")");
49993 });
49994 });
49995 assert$1(axis >= 0 && axis < rank, function () {
49996 return "Error in concat".concat(rank, "D: axis must be between 0 and ").concat(rank - 1, ".");
49997 });
49998 var firstShape = shapes[0];
49999 shapes.forEach(function (shape, i) {
50000 for (var r = 0; r < rank; r++) {
50001 assert$1(r === axis || shape[r] === firstShape[r], function () {
50002 return "Error in concat".concat(rank, "D: Shape of tensors[").concat(i, "] (").concat(shape, ") ") + "does not match the shape of the rest (".concat(firstShape, ") ") + "along the non-concatenated axis ".concat(i, ".");
50003 });
50004 }
50005 });
50006 }
50007 function computeOutShape$1(shapes, axis) {
50008 var outputShape = shapes[0].slice();
50009 for (var i = 1; i < shapes.length; i++) {
50010 outputShape[axis] += shapes[i][axis];
50011 }
50012 return outputShape;
50013 }
50014
50015 /**
50016 * @license
50017 * Copyright 2020 Google Inc. All Rights Reserved.
50018 * Licensed under the Apache License, Version 2.0 (the "License");
50019 * you may not use this file except in compliance with the License.
50020 * You may obtain a copy of the License at
50021 *
50022 * http://www.apache.org/licenses/LICENSE-2.0
50023 *
50024 * Unless required by applicable law or agreed to in writing, software
50025 * distributed under the License is distributed on an "AS IS" BASIS,
50026 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50027 * See the License for the specific language governing permissions and
50028 * limitations under the License.
50029 * =============================================================================
50030 */
50031
50032 /**
50033 * @license
50034 * Copyright 2022 Google LLC. All Rights Reserved.
50035 * Licensed under the Apache License, Version 2.0 (the "License");
50036 * you may not use this file except in compliance with the License.
50037 * You may obtain a copy of the License at
50038 *
50039 * http://www.apache.org/licenses/LICENSE-2.0
50040 *
50041 * Unless required by applicable law or agreed to in writing, software
50042 * distributed under the License is distributed on an "AS IS" BASIS,
50043 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50044 * See the License for the specific language governing permissions and
50045 * limitations under the License.
50046 * =============================================================================
50047 */
50048 var RowPartitionType$1;
50049 (function (RowPartitionType) {
50050 RowPartitionType[RowPartitionType["FIRST_DIM_SIZE"] = 0] = "FIRST_DIM_SIZE";
50051 RowPartitionType[RowPartitionType["VALUE_ROWIDS"] = 1] = "VALUE_ROWIDS";
50052 RowPartitionType[RowPartitionType["ROW_LENGTHS"] = 2] = "ROW_LENGTHS";
50053 RowPartitionType[RowPartitionType["ROW_SPLITS"] = 3] = "ROW_SPLITS";
50054 RowPartitionType[RowPartitionType["ROW_LIMITS"] = 4] = "ROW_LIMITS";
50055 RowPartitionType[RowPartitionType["ROW_STARTS"] = 5] = "ROW_STARTS";
50056 })(RowPartitionType$1 || (RowPartitionType$1 = {}));
50057 function combineRaggedTensorToTensorShapes(raggedRank, shape, valueShape) {
50058 // Test for consistency of valueShape and shape specified.
50059 // If shape is unspecified and valueShape is specified, then copy
50060 // over the size from the valueShape dimension.
50061 var outputShape = new Array();
50062 if (valueShape == null && shape == null) {
50063 return outputShape;
50064 }
50065 if (shape == null) {
50066 // Here, value_shape must be of known size.
50067 while (outputShape.length < raggedRank + valueShape.length) {
50068 outputShape.push(-1);
50069 }
50070 } else {
50071 outputShape = shape.slice();
50072 }
50073 if (valueShape == null) {
50074 return outputShape;
50075 }
50076 // At this point, valueShape and output_shape have known ranks.
50077 if (raggedRank + valueShape.length !== outputShape.length) {
50078 throw new Error("rt input.shape and shape=".concat(shape, " are incompatible: rt input.rank = ").concat(raggedRank + valueShape.length, ", but shape.rank = ").concat(outputShape.length));
50079 }
50080 for (var i = 1; i < valueShape.length; ++i) {
50081 var valueDim = valueShape[i];
50082 var outputShapeDimIndex = outputShape[outputShape.length - valueShape.length + i];
50083 var outputShapeDim = outputShape[outputShapeDimIndex];
50084 if (valueDim >= 0) {
50085 if (outputShapeDim >= 0) {
50086 if (outputShapeDim !== valueDim) {
50087 throw new Error("rt input.shape and shape=".concat(shape, " are incompatible: rt input.shape[").concat(i + raggedRank, "] = ").concat(valueDim, " but shape[").concat(i + raggedRank, "] = ").concat(outputShapeDim));
50088 }
50089 } else {
50090 outputShape[outputShapeDimIndex] = valueDim;
50091 }
50092 }
50093 }
50094 return outputShape;
50095 }
50096 function getRowPartitionTypesHelper(rowPartitionTypeStrings) {
50097 var stringToType = {
50098 'FIRST_DIM_SIZE': RowPartitionType$1.FIRST_DIM_SIZE,
50099 'VALUE_ROWIDS': RowPartitionType$1.VALUE_ROWIDS,
50100 'ROW_LENGTHS': RowPartitionType$1.ROW_LENGTHS,
50101 'ROW_SPLITS': RowPartitionType$1.ROW_SPLITS,
50102 'ROW_LIMITS': RowPartitionType$1.ROW_LIMITS,
50103 'ROW_STARTS': RowPartitionType$1.ROW_STARTS
50104 };
50105 var result = [];
50106 var _iterator = _createForOfIteratorHelper(rowPartitionTypeStrings),
50107 _step;
50108 try {
50109 for (_iterator.s(); !(_step = _iterator.n()).done;) {
50110 var typeStr = _step.value;
50111 if (typeStr in stringToType) {
50112 result.push(stringToType[typeStr]);
50113 } else {
50114 break;
50115 }
50116 }
50117 } catch (err) {
50118 _iterator.e(err);
50119 } finally {
50120 _iterator.f();
50121 }
50122 return result;
50123 }
50124 function getRaggedRank(rowPartitionTypes) {
50125 if (rowPartitionTypes.length === 0) {
50126 return 0;
50127 }
50128 if (rowPartitionTypes[0] === RowPartitionType$1.FIRST_DIM_SIZE) {
50129 return rowPartitionTypes.length - 1;
50130 }
50131 return rowPartitionTypes.length;
50132 }
50133 function validateDefaultValueShape(defaultValueShape, valueShape) {
50134 if (defaultValueShape == null || valueShape == null) {
50135 return;
50136 }
50137 var defaultNDims = defaultValueShape.length;
50138 var valuesNDims = valueShape.length;
50139 if (defaultNDims >= valuesNDims) {
50140 throw new Error("defaultValue.shape=".concat(defaultValueShape, " and ragged tensor flatValues.shape=").concat(valueShape, ", are incompatible: defaultValue.rank = ").concat(defaultNDims, " must be less than ragged tensor input flatValues.rank = ").concat(valuesNDims, ")"));
50141 }
50142 for (var i = 0; i < Math.min(defaultNDims, valuesNDims - 1); ++i) {
50143 var defaultDim = defaultValueShape[i];
50144 var valueDim = valueShape[i + 1];
50145 if (defaultDim >= 0 && valueDim >= 0 && defaultDim !== 1 && defaultDim !== valueDim) {
50146 throw new Error("defaultValue.shape=".concat(defaultValueShape, ", and ragged tensor input flatValues.shape=").concat(valueShape, " are incompatible: defaultValue.shape[").concat(i - defaultValueShape.length, "] = ").concat(defaultDim, " but ragged tensor input.flatValues.shape[").concat(i - defaultValueShape.length, "] = ").concat(valueDim));
50147 }
50148 }
50149 }
50150
50151 /**
50152 * @license
50153 * Copyright 2017 Google LLC. All Rights Reserved.
50154 * Licensed under the Apache License, Version 2.0 (the "License");
50155 * you may not use this file except in compliance with the License.
50156 * You may obtain a copy of the License at
50157 *
50158 * http://www.apache.org/licenses/LICENSE-2.0
50159 *
50160 * Unless required by applicable law or agreed to in writing, software
50161 * distributed under the License is distributed on an "AS IS" BASIS,
50162 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50163 * See the License for the specific language governing permissions and
50164 * limitations under the License.
50165 * =============================================================================
50166 */
50167 var PARALLELIZE_THRESHOLD = 30;
50168 function computeOptimalWindowSize(inSize) {
50169 if (inSize <= PARALLELIZE_THRESHOLD) {
50170 return inSize;
50171 }
50172 return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
50173 }
50174
50175 /**
50176 * @license
50177 * Copyright 2020 Google LLC. All Rights Reserved.
50178 * Licensed under the Apache License, Version 2.0 (the "License");
50179 * you may not use this file except in compliance with the License.
50180 * You may obtain a copy of the License at
50181 *
50182 * http://www.apache.org/licenses/LICENSE-2.0
50183 *
50184 * Unless required by applicable law or agreed to in writing, software
50185 * distributed under the License is distributed on an "AS IS" BASIS,
50186 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50187 * See the License for the specific language governing permissions and
50188 * limitations under the License.
50189 * =============================================================================
50190 */
50191 // Returns the image center in pixels.
50192 function getImageCenter(center, imageHeight, imageWidth) {
50193 var centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
50194 var centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
50195 return [centerX, centerY];
50196 }
50197
50198 /**
50199 * @license
50200 * Copyright 2018 Google LLC. All Rights Reserved.
50201 * Licensed under the Apache License, Version 2.0 (the "License");
50202 * you may not use this file except in compliance with the License.
50203 * You may obtain a copy of the License at
50204 *
50205 * http://www.apache.org/licenses/LICENSE-2.0
50206 *
50207 * Unless required by applicable law or agreed to in writing, software
50208 * distributed under the License is distributed on an "AS IS" BASIS,
50209 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50210 * See the License for the specific language governing permissions and
50211 * limitations under the License.
50212 * =============================================================================
50213 */
50214 /**
50215 * Gets the new shape of the input Tensor after it's been reshaped
50216 * to:
50217 * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
50218 * inputShape[1], ..., inputShape[N-1]]
50219 *
50220 * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
50221 */
50222 function getReshaped(inputShape, blockShape, prod) {
50223 var batchToSpace = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : true;
50224 var reshaped = [];
50225 if (batchToSpace) {
50226 reshaped = reshaped.concat(blockShape.slice(0));
50227 reshaped.push(inputShape[0] / prod);
50228 reshaped = reshaped.concat(inputShape.slice(1));
50229 } else {
50230 reshaped = reshaped.concat(inputShape[0]);
50231 var spatialLength = blockShape.length;
50232 for (var i = 0; i < spatialLength; ++i) {
50233 reshaped = reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
50234 }
50235 reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
50236 }
50237 return reshaped;
50238 }
50239 /**
50240 * Gets the permutation that will transpose the dimensions of the
50241 * reshaped tensor to shape:
50242 *
50243 * [batch / prod(block_shape),inputShape[1], blockShape[0], ...,
50244 * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
50245 *
50246 * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
50247 */
50248 function getPermuted(reshapedRank, blockShapeRank) {
50249 var batchToSpace = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
50250 var permuted = [];
50251 if (batchToSpace) {
50252 permuted.push(blockShapeRank);
50253 for (var i = blockShapeRank + 1; i < reshapedRank; ++i) {
50254 if (i <= 2 * blockShapeRank) {
50255 permuted.push(i);
50256 permuted.push(i - (blockShapeRank + 1));
50257 } else {
50258 permuted.push(i);
50259 }
50260 }
50261 } else {
50262 var permutedBeforeBatch = [];
50263 var permutedAfterBatch = [];
50264 for (var _i = 1; _i < reshapedRank; ++_i) {
50265 if (_i >= blockShapeRank * 2 + 1 || _i % 2 === 1) {
50266 permutedAfterBatch.push(_i);
50267 } else {
50268 permutedBeforeBatch.push(_i);
50269 }
50270 }
50271 permuted.push.apply(permuted, permutedBeforeBatch);
50272 permuted.push(0);
50273 permuted.push.apply(permuted, permutedAfterBatch);
50274 }
50275 return permuted;
50276 }
50277 /**
50278 * Gets the shape of the reshaped and permuted input Tensor before any cropping
50279 * is applied. The new shape will be:
50280 *
50281 * [batch / prod(blockShape),inputShape[1] * blockShape[0], ...,
50282 * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
50283 *
50284 * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
50285 */
50286 function getReshapedPermuted(inputShape, blockShape, prod) {
50287 var batchToSpace = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : true;
50288 var reshapedPermuted = [];
50289 if (batchToSpace) {
50290 reshapedPermuted.push(inputShape[0] / prod);
50291 } else {
50292 reshapedPermuted.push(inputShape[0] * prod);
50293 }
50294 for (var i = 1; i < inputShape.length; ++i) {
50295 if (i <= blockShape.length) {
50296 if (batchToSpace) {
50297 reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
50298 } else {
50299 reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
50300 }
50301 } else {
50302 reshapedPermuted.push(inputShape[i]);
50303 }
50304 }
50305 return reshapedPermuted;
50306 }
50307 /**
50308 * Converts the crops argument into the beginning coordinates of a slice
50309 * operation.
50310 */
50311 function getSliceBeginCoords(crops, blockShape) {
50312 var sliceBeginCoords = [0];
50313 for (var i = 0; i < blockShape; ++i) {
50314 sliceBeginCoords.push(crops[i][0]);
50315 }
50316 return sliceBeginCoords;
50317 }
50318 /**
50319 * Converts the crops argument into the size of a slice operation. When
50320 * combined with getSliceBeginCoords this function allows the reshaped and
50321 * permuted Tensor to be cropped to its final output shape of:
50322 *
50323 * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
50324 * inputShape[M] * blockShape[M-1] -crops[M-1,0] -
50325 * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]]
50326 *
50327 * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
50328 */
50329 function getSliceSize(uncroppedShape, crops, blockShape) {
50330 var sliceSize = uncroppedShape.slice(0, 1);
50331 for (var i = 0; i < blockShape; ++i) {
50332 sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
50333 }
50334 return sliceSize;
50335 }
50336
50337 /**
50338 * @license
50339 * Copyright 2018 Google LLC. All Rights Reserved.
50340 * Licensed under the Apache License, Version 2.0 (the "License");
50341 * you may not use this file except in compliance with the License.
50342 * You may obtain a copy of the License at
50343 *
50344 * http://www.apache.org/licenses/LICENSE-2.0
50345 *
50346 * Unless required by applicable law or agreed to in writing, software
50347 * distributed under the License is distributed on an "AS IS" BASIS,
50348 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50349 * See the License for the specific language governing permissions and
50350 * limitations under the License.
50351 * =============================================================================
50352 */
50353 var SELU_SCALEALPHA = 1.7580993408473768599402175208123;
50354 var SELU_SCALE = 1.0507009873554804934193349852946;
50355
50356 /**
50357 * @license
50358 * Copyright 2018 Google LLC. All Rights Reserved.
50359 * Licensed under the Apache License, Version 2.0 (the "License");
50360 * you may not use this file except in compliance with the License.
50361 * You may obtain a copy of the License at
50362 *
50363 * http://www.apache.org/licenses/LICENSE-2.0
50364 *
50365 * Unless required by applicable law or agreed to in writing, software
50366 * distributed under the License is distributed on an "AS IS" BASIS,
50367 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50368 * See the License for the specific language governing permissions and
50369 * limitations under the License.
50370 * =============================================================================
50371 */
50372 var ERF_P = 0.3275911;
50373 var ERF_A1 = 0.254829592;
50374 var ERF_A2 = -0.284496736;
50375 var ERF_A3 = 1.421413741;
50376 var ERF_A4 = -1.453152027;
50377 var ERF_A5 = 1.061405429;
50378
50379 /**
50380 * @license
50381 * Copyright 2018 Google LLC. All Rights Reserved.
50382 * Licensed under the Apache License, Version 2.0 (the "License");
50383 * you may not use this file except in compliance with the License.
50384 * You may obtain a copy of the License at
50385 *
50386 * http://www.apache.org/licenses/LICENSE-2.0
50387 *
50388 * Unless required by applicable law or agreed to in writing, software
50389 * distributed under the License is distributed on an "AS IS" BASIS,
50390 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50391 * See the License for the specific language governing permissions and
50392 * limitations under the License.
50393 * =============================================================================
50394 */
50395 /**
50396 * Merges real and imaginary Float32Arrays into a single complex Float32Array.
50397 *
50398 * The memory layout is interleaved as follows:
50399 * real: [r0, r1, r2]
50400 * imag: [i0, i1, i2]
50401 * complex: [r0, i0, r1, i1, r2, i2]
50402 *
50403 * This is the inverse of splitRealAndImagArrays.
50404 *
50405 * @param real The real values of the complex tensor values.
50406 * @param imag The imag values of the complex tensor values.
50407 * @returns A complex tensor as a Float32Array with merged values.
50408 */
50409 function mergeRealAndImagArrays(real, imag) {
50410 if (real.length !== imag.length) {
50411 throw new Error("Cannot merge real and imag arrays of different lengths. real:" + "".concat(real.length, ", imag: ").concat(imag.length, "."));
50412 }
50413 var result = new Float32Array(real.length * 2);
50414 for (var i = 0; i < result.length; i += 2) {
50415 result[i] = real[i / 2];
50416 result[i + 1] = imag[i / 2];
50417 }
50418 return result;
50419 }
50420 /**
50421 * Splits a complex Float32Array into real and imag parts.
50422 *
50423 * The memory layout is interleaved as follows:
50424 * complex: [r0, i0, r1, i1, r2, i2]
50425 * real: [r0, r1, r2]
50426 * imag: [i0, i1, i2]
50427 *
50428 * This is the inverse of mergeRealAndImagArrays.
50429 *
50430 * @param complex The complex tensor values.
50431 * @returns An object with real and imag Float32Array components of the complex
50432 * tensor.
50433 */
50434 function splitRealAndImagArrays(complex) {
50435 var real = new Float32Array(complex.length / 2);
50436 var imag = new Float32Array(complex.length / 2);
50437 for (var i = 0; i < complex.length; i += 2) {
50438 real[i / 2] = complex[i];
50439 imag[i / 2] = complex[i + 1];
50440 }
50441 return {
50442 real: real,
50443 imag: imag
50444 };
50445 }
50446 /**
50447 * Extracts even indexed complex values in the given array.
50448 * @param complex The complex tensor values
50449 */
50450 function complexWithEvenIndex(complex) {
50451 var len = Math.ceil(complex.length / 4);
50452 var real = new Float32Array(len);
50453 var imag = new Float32Array(len);
50454 for (var i = 0; i < complex.length; i += 4) {
50455 real[Math.floor(i / 4)] = complex[i];
50456 imag[Math.floor(i / 4)] = complex[i + 1];
50457 }
50458 return {
50459 real: real,
50460 imag: imag
50461 };
50462 }
50463 /**
50464 * Extracts odd indexed complete values in the given array.
50465 * @param complex The complex tensor values
50466 */
50467 function complexWithOddIndex(complex) {
50468 var len = Math.floor(complex.length / 4);
50469 var real = new Float32Array(len);
50470 var imag = new Float32Array(len);
50471 for (var i = 2; i < complex.length; i += 4) {
50472 real[Math.floor(i / 4)] = complex[i];
50473 imag[Math.floor(i / 4)] = complex[i + 1];
50474 }
50475 return {
50476 real: real,
50477 imag: imag
50478 };
50479 }
50480 /**
50481 * Get the map representing a complex value in the given array.
50482 * @param complex The complex tensor values.
50483 * @param index An index of the target complex value.
50484 */
50485 function getComplexWithIndex(complex, index) {
50486 var real = complex[index * 2];
50487 var imag = complex[index * 2 + 1];
50488 return {
50489 real: real,
50490 imag: imag
50491 };
50492 }
50493 /**
50494 * Insert a given complex value into the TypedArray.
50495 * @param data The array in which the complex value is inserted.
50496 * @param c The complex value to be inserted.
50497 * @param index An index of the target complex value.
50498 */
50499 function assignToTypedArray(data, real, imag, index) {
50500 data[index * 2] = real;
50501 data[index * 2 + 1] = imag;
50502 }
50503 /**
50504 * Make the list of exponent terms used by FFT.
50505 */
50506 function exponents(n, inverse) {
50507 var real = new Float32Array(n / 2);
50508 var imag = new Float32Array(n / 2);
50509 for (var i = 0; i < Math.ceil(n / 2); i++) {
50510 var x = (inverse ? 2 : -2) * Math.PI * (i / n);
50511 real[i] = Math.cos(x);
50512 imag[i] = Math.sin(x);
50513 }
50514 return {
50515 real: real,
50516 imag: imag
50517 };
50518 }
50519 /**
50520 * Make the exponent term used by FFT.
50521 */
50522 function exponent(k, n, inverse) {
50523 var x = (inverse ? 2 : -2) * Math.PI * (k / n);
50524 var real = Math.cos(x);
50525 var imag = Math.sin(x);
50526 return {
50527 real: real,
50528 imag: imag
50529 };
50530 }
50531
50532 var ARROW = '->';
50533 var ARROW_REGEX = /->/g;
50534 var COMMA = ',';
50535 var ELLIPSIS = '...';
50536 /**
50537 * Parse an equation for einsum.
50538 *
50539 * @param equation The einsum equation (e.g., "ij,jk->ik").
50540 * @param numTensors Number of tensors provided along with `equation`. Used to
50541 * check matching number of input tensors.
50542 * @returns An object consisting of the following fields:
50543 * - allDims: all dimension names as strings.
50544 * - summedDims: a list of all dimensions being summed over, as indices to
50545 * the elements of `allDims`.
50546 * - idDims: indices of the dimensions in each input tensor, as indices to
50547 * the elements of `allDims.
50548 */
50549 function decodeEinsumEquation(equation, numTensors) {
50550 equation = equation.replace(/\s/g, ''); // Remove witespace in equation.
50551 var numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) / ARROW.length;
50552 if (numArrows < 1) {
50553 throw new Error('Equations without an arrow are not supported.');
50554 } else if (numArrows > 1) {
50555 throw new Error("Equation must contain exactly one arrow (\"".concat(ARROW, "\")."));
50556 }
50557 var _equation$split = equation.split(ARROW),
50558 _equation$split2 = _slicedToArray(_equation$split, 2),
50559 inputString = _equation$split2[0],
50560 outputString = _equation$split2[1];
50561 assert$1(inputString.indexOf(ELLIPSIS) === -1, function () {
50562 return "The ellipsis notation (\"".concat(ELLIPSIS, "\") is not supported yet.");
50563 });
50564 var inputTerms = inputString.split(COMMA);
50565 var numInputs = inputTerms.length;
50566 if (numTensors !== numInputs) {
50567 throw new Error("Expected ".concat(numInputs, " input tensors, received ").concat(numTensors));
50568 }
50569 if (numInputs > 2) {
50570 throw new Error('Support for more than 2 input tensors is not implemented yet.');
50571 }
50572 var allDims = [];
50573 var _loop = function _loop() {
50574 var dimName = outputString[i];
50575 if (!inputTerms.some(function (inputTerm) {
50576 return inputTerm.indexOf(dimName) !== -1;
50577 })) {
50578 throw new Error("Output subscripts contain the label ".concat(dimName, " ") + "not present in the input subscripts.");
50579 }
50580 if (allDims.indexOf(dimName) === -1) {
50581 allDims.push(dimName);
50582 }
50583 };
50584 for (var i = 0; i < outputString.length; ++i) {
50585 _loop();
50586 }
50587 for (var _i = 0; _i < inputString.length; ++_i) {
50588 var dimName = inputString[_i];
50589 if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
50590 allDims.push(dimName);
50591 }
50592 }
50593 var idDims = new Array(inputTerms.length);
50594 for (var _i2 = 0; _i2 < numInputs; ++_i2) {
50595 if (new Set(inputTerms[_i2].split('')).size !== inputTerms[_i2].length) {
50596 throw new Error("Found duplicate axes in input component ".concat(inputTerms[_i2], ". ") + "Support for duplicate axes in input is not implemented yet.");
50597 }
50598 idDims[_i2] = [];
50599 for (var j = 0; j < inputTerms[_i2].length; ++j) {
50600 idDims[_i2].push(allDims.indexOf(inputTerms[_i2][j]));
50601 }
50602 }
50603 var numDims = allDims.length; // Number of unique dimensions.
50604 var numOutDims = outputString.length; // Number of output dimensions.
50605 var summedDims = []; // Dimensions being summed over.
50606 for (var _i3 = numOutDims; _i3 < numDims; ++_i3) {
50607 summedDims.push(_i3);
50608 }
50609 return {
50610 allDims: allDims,
50611 summedDims: summedDims,
50612 idDims: idDims
50613 };
50614 }
50615 /**
50616 * Get the permutation for a given input tensor.
50617 *
50618 * @param nDims Total number of dimension of all tensors involved in the einsum
50619 * operation.
50620 * @param idDims Dimension indices involve in the tensor in question.
50621 * @returns An object consisting of the following fields:
50622 * - permutationIndices: Indices to permute the axes of the tensor with.
50623 * - expandDims: Indices to the dimension that need to be expanded from the
50624 * tensor after permutation.
50625 */
50626 function getEinsumPermutation(nDims, idDims) {
50627 var permutationIndices = new Array(nDims);
50628 permutationIndices.fill(-1);
50629 for (var i = 0; i < idDims.length; ++i) {
50630 permutationIndices[idDims[i]] = i;
50631 }
50632 var expandDims = [];
50633 for (var _i4 = 0; _i4 < nDims; ++_i4) {
50634 if (permutationIndices[_i4] === -1) {
50635 expandDims.push(_i4);
50636 }
50637 }
50638 permutationIndices = permutationIndices.filter(function (d) {
50639 return d !== -1;
50640 });
50641 return {
50642 permutationIndices: permutationIndices,
50643 expandDims: expandDims
50644 };
50645 }
50646 /**
50647 * Checks that the dimension sizes from different input tensors match the
50648 * equation.
50649 */
50650 function checkEinsumDimSizes(nDims, idDims, tensors) {
50651 var dimSizes = new Array(nDims);
50652 var _loop2 = function _loop2(i) {
50653 var shape = tensors[i].shape;
50654 var _loop3 = function _loop3(j) {
50655 if (dimSizes[idDims[i][j]] === undefined) {
50656 dimSizes[idDims[i][j]] = shape[j];
50657 } else {
50658 assert$1(dimSizes[idDims[i][j]] === shape[j], function () {
50659 return "Expected dimension ".concat(dimSizes[idDims[i][j]], " at axis ").concat(j, " ") + "of input shaped ".concat(JSON.stringify(shape), ", ") + "but got dimension ".concat(shape[j]);
50660 });
50661 }
50662 };
50663 for (var j = 0; j < idDims[i].length; ++j) {
50664 _loop3(j);
50665 }
50666 };
50667 for (var i = 0; i < tensors.length; ++i) {
50668 _loop2(i);
50669 }
50670 }
50671 /**
50672 * Gets path of computation for einsum.
50673 *
50674 * @param summedDims indices to the dimensions being summed over.
50675 * @param idDims A look up table for the dimensions present in each input
50676 * tensor.Each constituent array contains indices for the dimensions in the
50677 * corresponding input tensor.
50678 *
50679 * @return A map with two fields:
50680 * - path: The path of computation, with each element indicating the dimension
50681 * being summed over after the element-wise multiplication in that step.
50682 * - steps: With the same length as `path`. Each element contains the indices
50683 * to the input tensors being used for element-wise multiplication in the
50684 * corresponding step.
50685 */
50686 function getEinsumComputePath(summedDims, idDims) {
50687 var path = summedDims;
50688 var steps = [];
50689 var nSteps = 0;
50690 if (summedDims.length === 0) {
50691 // Einsum that involes no summing: e.g., transpose and outer product.
50692 path.push(-1);
50693 }
50694 nSteps = summedDims.length + 1;
50695 for (var i = 0; i < nSteps; ++i) {
50696 steps.push([]);
50697 }
50698 var computedTermIndices = [];
50699 for (var _i5 = 0; _i5 < path.length; ++_i5) {
50700 var summedDim = path[_i5];
50701 var termIndices = findTermsWithDim(idDims, summedDim);
50702 var _iterator = _createForOfIteratorHelper(termIndices),
50703 _step;
50704 try {
50705 for (_iterator.s(); !(_step = _iterator.n()).done;) {
50706 var termIndex = _step.value;
50707 if (computedTermIndices.indexOf(termIndex) === -1) {
50708 steps[_i5].push(termIndex);
50709 computedTermIndices.push(termIndex);
50710 }
50711 }
50712 } catch (err) {
50713 _iterator.e(err);
50714 } finally {
50715 _iterator.f();
50716 }
50717 }
50718 return {
50719 path: path,
50720 steps: steps
50721 };
50722 }
50723 /** Determines if an axes permutation is the identity permutation. */
50724 function isIdentityPermutation(perm) {
50725 return perm.every(function (dim, index) {
50726 return dim === index;
50727 });
50728 }
50729 function findTermsWithDim(idDims, dim) {
50730 var termIndices = [];
50731 for (var i = 0; i < idDims.length; ++i) {
50732 if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) {
50733 termIndices.push(i);
50734 }
50735 }
50736 return termIndices;
50737 }
50738
50739 /**
50740 * Prepare the split size array. When the input is a number, the axis is evenly
50741 * divided among the split size. When the input contains the negative value, the
50742 * rest of the axis is allocated toward that.
50743 */
50744 function prepareSplitSize(x, numOrSizeSplits) {
50745 var axis = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0;
50746 var splitSizes = [];
50747 if (typeof numOrSizeSplits === 'number') {
50748 assert$1(x.shape[axis] % numOrSizeSplits === 0, function () {
50749 return 'Number of splits must evenly divide the axis.';
50750 });
50751 splitSizes = new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
50752 } else {
50753 var numOfNegs = numOrSizeSplits.reduce(function (count, value) {
50754 if (value === -1) {
50755 count += 1;
50756 }
50757 return count;
50758 }, 0);
50759 assert$1(numOfNegs <= 1, function () {
50760 return 'There should be only one negative value in split array.';
50761 });
50762 var negIndex = numOrSizeSplits.indexOf(-1);
50763 // Allow the number of split array to be -1, which indicates the rest
50764 // of dimension is allocated to that split.
50765 if (negIndex !== -1) {
50766 var total = numOrSizeSplits.reduce(function (a, b) {
50767 return b > 0 ? a + b : a;
50768 });
50769 numOrSizeSplits[negIndex] = x.shape[axis] - total;
50770 }
50771 assert$1(x.shape[axis] === numOrSizeSplits.reduce(function (a, b) {
50772 return a + b;
50773 }), function () {
50774 return 'The sum of sizes must match the size of the axis dimension.';
50775 });
50776 splitSizes = numOrSizeSplits;
50777 }
50778 return splitSizes;
50779 }
50780
50781 /**
50782 * @license
50783 * Copyright 2021 Google LLC. All Rights Reserved.
50784 * Licensed under the Apache License, Version 2.0 (the "License");
50785 * you may not use this file except in compliance with the License.
50786 * You may obtain a copy of the License at
50787 *
50788 * http://www.apache.org/licenses/LICENSE-2.0
50789 *
50790 * Unless required by applicable law or agreed to in writing, software
50791 * distributed under the License is distributed on an "AS IS" BASIS,
50792 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50793 * See the License for the specific language governing permissions and
50794 * limitations under the License.
50795 * =============================================================================
50796 */
50797 /**
50798 * Generates sparse fill empty rows indices, dense shape mismatch error message.
50799 *
50800 * @param indicesLength The first dimension of indices.
50801 */
50802 function getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesLength) {
50803 return "Received SparseTensor with denseShape[0] = 0 but\n indices.shape[0] = ".concat(indicesLength);
50804 }
50805 /**
50806 * Generates sparse fill empty rows negative index error message.
50807 *
50808 * @param index The index with a negative value.
50809 * @param value The negative value.
50810 */
50811 function getSparseFillEmptyRowsNegativeIndexErrorMessage(index, value) {
50812 return "indices(".concat(index, ", 0) is invalid: ").concat(value, " < 0");
50813 }
50814 /**
50815 * Generates sparse fill empty rows out of range index error message.
50816 *
50817 * @param index The index with an out of range value.
50818 * @param value The out of range value.
50819 * @param limit The upper limit for indices.
50820 */
50821 function getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(index, value, limit) {
50822 return "indices(".concat(index, ", 0) is invalid: ").concat(value, " >= ").concat(limit);
50823 }
50824
50825 /**
50826 * @license
50827 * Copyright 2021 Google LLC. All Rights Reserved.
50828 * Licensed under the Apache License, Version 2.0 (the "License");
50829 * you may not use this file except in compliance with the License.
50830 * You may obtain a copy of the License at
50831 *
50832 * http://www.apache.org/licenses/LICENSE-2.0
50833 *
50834 * Unless required by applicable law or agreed to in writing, software
50835 * distributed under the License is distributed on an "AS IS" BASIS,
50836 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50837 * See the License for the specific language governing permissions and
50838 * limitations under the License.
50839 * =============================================================================
50840 */
50841 /**
50842 * Generates sparse reshape multiple negative 1 output dimension error message.
50843 *
50844 * @param dim1 The first dimension with a negative 1 value.
50845 * @param dim2 The second dimension with a negative 1 value.
50846 */
50847 function getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(dim1, dim2) {
50848 return "only one output dimension may be -1, not both ".concat(dim1, " and ").concat(dim2);
50849 }
50850 /**
50851 * Generates sparse reshape negative output dimension error message.
50852 *
50853 * @param dim The dimension with a negative value.
50854 * @param value The negative value.
50855 */
50856 function getSparseReshapeNegativeOutputDimErrorMessage(dim, value) {
50857 return "size ".concat(dim, " must be non-negative, not ").concat(value);
50858 }
50859 /**
50860 * Generates sparse reshape empty tensor zero output dimension error message.
50861 *
50862 */
50863 function getSparseReshapeEmptyTensorZeroOutputDimErrorMessage() {
50864 return 'reshape cannot infer the missing input size for an empty tensor ' + 'unless all specified input sizes are non-zero';
50865 }
50866 /**
50867 * Generates sparse reshape input output multiple mismatch error message.
50868 *
50869 * @param inputShape the input shape.
50870 * @param outputShape the requested output shape.
50871 */
50872 function getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape) {
50873 var inputSize = sizeFromShape(inputShape);
50874 var outputSize = sizeFromShape(outputShape);
50875 return "Input to reshape is a SparseTensor with ".concat(inputSize, "\n dense values, but the requested shape requires a multiple of ").concat(outputSize, ". inputShape=").concat(inputShape, " outputShape= ").concat(outputShape);
50876 }
50877 /**
50878 * Generates sparse reshape input output inequality error message.
50879 *
50880 * @param inputShape the input shape.
50881 * @param outputShape the requested output shape.
50882 */
50883 function getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape) {
50884 var inputSize = sizeFromShape(inputShape);
50885 var outputSize = sizeFromShape(outputShape);
50886 return "Input to reshape is a tensor with ".concat(inputSize, " dense values, but the requested shape has ").concat(outputSize, ". inputShape=").concat(inputShape, " outputShape=").concat(outputShape);
50887 }
50888
50889 /**
50890 * @license
50891 * Copyright 2021 Google LLC. All Rights Reserved.
50892 * Licensed under the Apache License, Version 2.0 (the "License");
50893 * you may not use this file except in compliance with the License.
50894 * You may obtain a copy of the License at
50895 *
50896 * http://www.apache.org/licenses/LICENSE-2.0
50897 *
50898 * Unless required by applicable law or agreed to in writing, software
50899 * distributed under the License is distributed on an "AS IS" BASIS,
50900 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50901 * See the License for the specific language governing permissions and
50902 * limitations under the License.
50903 * =============================================================================
50904 */
50905 /**
50906 * Generates sparse segment reduction negative segment ids error message.
50907 *
50908 */
50909 function getSparseSegmentReductionNegativeSegmentIdsErrorMessage() {
50910 return "segment ids must be >= 0";
50911 }
50912 /**
50913 * Generates sparse segment reduction non increasing segment ids error message.
50914 *
50915 */
50916 function getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage() {
50917 return "segment ids are not increasing";
50918 }
50919 /**
50920 * Generates sparse segment reduction segment id out of range error message.
50921 *
50922 * @param segmentId The segment id index that is out of range.
50923 * @param outputRows Upper bound of valid segment id values.
50924 */
50925 function getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows) {
50926 return "Segment id ".concat(segmentId, " out of range [0, ").concat(outputRows, "), possibly because segmentIds input is not sorted.");
50927 }
50928 /**
50929 * Generates sparse segment reduction input indice out of range error message.
50930 *
50931 * @param index The index that holds the out of range value.
50932 * @param indexValue The value that is out of range.
50933 * @param inputRows Upper bound of valid index values.
50934 */
50935 function getSparseSegmentReductionIndicesOutOfRangeErrorMessage(index, indexValue, inputRows) {
50936 return "Bad: indices[".concat(index, "] == ").concat(indexValue, " out of range [0, ").concat(inputRows, ")");
50937 }
50938
50939 /**
50940 * @license
50941 * Copyright 2018 Google LLC. All Rights Reserved.
50942 * Licensed under the Apache License, Version 2.0 (the "License");
50943 * you may not use this file except in compliance with the License.
50944 * You may obtain a copy of the License at
50945 *
50946 * http://www.apache.org/licenses/LICENSE-2.0
50947 *
50948 * Unless required by applicable law or agreed to in writing, software
50949 * distributed under the License is distributed on an "AS IS" BASIS,
50950 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50951 * See the License for the specific language governing permissions and
50952 * limitations under the License.
50953 * =============================================================================
50954 */
50955 function segOpComputeOptimalWindowSize(inSize, numSegments) {
50956 var done = false;
50957 var res;
50958 if (inSize <= PARALLELIZE_THRESHOLD) {
50959 res = inSize;
50960 done = true;
50961 } else {
50962 res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
50963 }
50964 while (!done) {
50965 if (res > numSegments || res === inSize) {
50966 done = true;
50967 } else {
50968 res = nearestDivisor(inSize, res + 1);
50969 }
50970 }
50971 return res;
50972 }
50973 function computeOutShape(aShape, axis, numSegments) {
50974 var outShape = [];
50975 var rank = aShape.length;
50976 for (var dim = 0; dim < rank; dim++) {
50977 if (dim !== axis) {
50978 outShape.push(aShape[dim]);
50979 } else {
50980 outShape.push(numSegments);
50981 }
50982 }
50983 return outShape;
50984 }
50985 function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
50986 var indicesRank = indices.shape.length;
50987 var xRank = x.shape.length;
50988 if (batchDims !== 0) {
50989 if (batchDims < -indicesRank || batchDims > indicesRank) {
50990 throw new Error("Expect batchDims in the range of [-".concat(indicesRank, ", ").concat(indicesRank, "], but got ").concat(batchDims));
50991 }
50992 }
50993 if (batchDims < 0) {
50994 batchDims += indicesRank;
50995 }
50996 if (batchDims > xRank) {
50997 throw new Error("batchDims (".concat(batchDims, ") must be less than rank(x) (\n ").concat(xRank, ")."));
50998 }
50999 if (axis < batchDims) {
51000 throw new Error("batchDims (".concat(batchDims, ") must be less than or equal to axis (").concat(axis, ")."));
51001 }
51002 for (var i = 0; i < batchDims; ++i) {
51003 if (x.shape[i] !== indices.shape[i]) {
51004 throw new Error("x.shape[".concat(i, "]: ").concat(x.shape[i], " should be equal to indices.shape[").concat(i, "]: ").concat(indices.shape[i], "."));
51005 }
51006 }
51007 var dimSize = x.shape[axis];
51008 var outputShape = [];
51009 var batchSize = 1;
51010 var outerSize = 1;
51011 var sliceSize = 1;
51012 for (var _i = 0; _i < batchDims; ++_i) {
51013 outputShape.push(x.shape[_i]);
51014 batchSize *= x.shape[_i];
51015 }
51016 for (var _i2 = batchDims; _i2 < axis; _i2++) {
51017 outputShape.push(x.shape[_i2]);
51018 outerSize *= x.shape[_i2];
51019 }
51020 for (var _i3 = batchDims; _i3 < indicesRank; _i3++) {
51021 outputShape.push(indices.shape[_i3]);
51022 }
51023 for (var _i4 = axis + 1; _i4 < xRank; _i4++) {
51024 outputShape.push(x.shape[_i4]);
51025 sliceSize *= x.shape[_i4];
51026 }
51027 return {
51028 batchSize: batchSize,
51029 sliceSize: sliceSize,
51030 outerSize: outerSize,
51031 dimSize: dimSize,
51032 outputShape: outputShape
51033 };
51034 }
51035
51036 var segment_util = {
51037 __proto__: null,
51038 collectGatherOpShapeInfo: collectGatherOpShapeInfo,
51039 computeOutShape: computeOutShape,
51040 segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize
51041 };
51042
51043 /**
51044 * @license
51045 * Copyright 2018 Google LLC. All Rights Reserved.
51046 * Licensed under the Apache License, Version 2.0 (the "License");
51047 * you may not use this file except in compliance with the License.
51048 * You may obtain a copy of the License at
51049 *
51050 * http://www.apache.org/licenses/LICENSE-2.0
51051 *
51052 * Unless required by applicable law or agreed to in writing, software
51053 * distributed under the License is distributed on an "AS IS" BASIS,
51054 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51055 * See the License for the specific language governing permissions and
51056 * limitations under the License.
51057 * =============================================================================
51058 */
51059 function fromUint8ToStringArray(vals) {
51060 try {
51061 // Decode the bytes into string.
51062 return vals.map(function (val) {
51063 return decodeString(val);
51064 });
51065 } catch (err) {
51066 throw new Error("Failed to decode encoded string bytes into utf-8, error: ".concat(err));
51067 }
51068 }
51069 function fromStringArrayToUint8(strings) {
51070 return strings.map(function (s) {
51071 return encodeString(s);
51072 });
51073 }
51074
51075 var backend_util = {
51076 __proto__: null,
51077 ERF_A1: ERF_A1,
51078 ERF_A2: ERF_A2,
51079 ERF_A3: ERF_A3,
51080 ERF_A4: ERF_A4,
51081 ERF_A5: ERF_A5,
51082 ERF_P: ERF_P,
51083 PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
51084 get RowPartitionType () { return RowPartitionType$1; },
51085 SELU_SCALE: SELU_SCALE,
51086 SELU_SCALEALPHA: SELU_SCALEALPHA,
51087 applyActivation: applyActivation$1,
51088 assertAndGetBroadcastShape: assertAndGetBroadcastShape,
51089 assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
51090 assertParamsConsistent: assertParamsConsistent,
51091 assignToTypedArray: assignToTypedArray,
51092 axesAreInnerMostDims: axesAreInnerMostDims,
51093 calculateShapes: calculateShapes,
51094 checkEinsumDimSizes: checkEinsumDimSizes,
51095 checkPadOnDimRoundingMode: checkPadOnDimRoundingMode,
51096 combineLocations: combineLocations,
51097 combineRaggedTensorToTensorShapes: combineRaggedTensorToTensorShapes,
51098 complexWithEvenIndex: complexWithEvenIndex,
51099 complexWithOddIndex: complexWithOddIndex,
51100 computeConv2DInfo: computeConv2DInfo,
51101 computeConv3DInfo: computeConv3DInfo,
51102 computeDefaultPad: computeDefaultPad,
51103 computeDilation2DInfo: computeDilation2DInfo,
51104 computeOptimalWindowSize: computeOptimalWindowSize,
51105 computeOutAndReduceShapes: computeOutAndReduceShapes,
51106 computeOutShape: computeOutShape$1,
51107 computePool2DInfo: computePool2DInfo,
51108 computePool3DInfo: computePool3DInfo,
51109 convertConv2DDataFormat: convertConv2DDataFormat,
51110 decodeEinsumEquation: decodeEinsumEquation,
51111 eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
51112 expandShapeToKeepDim: expandShapeToKeepDim,
51113 exponent: exponent,
51114 exponents: exponents,
51115 fromStringArrayToUint8: fromStringArrayToUint8,
51116 fromUint8ToStringArray: fromUint8ToStringArray,
51117 getAxesPermutation: getAxesPermutation,
51118 getBroadcastDims: getBroadcastDims$1,
51119 getComplexWithIndex: getComplexWithIndex,
51120 getEinsumComputePath: getEinsumComputePath,
51121 getEinsumPermutation: getEinsumPermutation,
51122 getFusedBiasGradient: getFusedBiasGradient,
51123 getFusedDyActivation: getFusedDyActivation,
51124 getImageCenter: getImageCenter,
51125 getInnerMostAxes: getInnerMostAxes,
51126 getPermuted: getPermuted,
51127 getRaggedRank: getRaggedRank,
51128 getReductionAxes: getReductionAxes,
51129 getReshaped: getReshaped,
51130 getReshapedPermuted: getReshapedPermuted,
51131 getRowPartitionTypesHelper: getRowPartitionTypesHelper,
51132 getSliceBeginCoords: getSliceBeginCoords,
51133 getSliceSize: getSliceSize,
51134 getSparseFillEmptyRowsIndicesDenseShapeMismatch: getSparseFillEmptyRowsIndicesDenseShapeMismatch,
51135 getSparseFillEmptyRowsNegativeIndexErrorMessage: getSparseFillEmptyRowsNegativeIndexErrorMessage,
51136 getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: getSparseFillEmptyRowsOutOfRangeIndexErrorMessage,
51137 getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: getSparseReshapeEmptyTensorZeroOutputDimErrorMessage,
51138 getSparseReshapeInputOutputMismatchErrorMessage: getSparseReshapeInputOutputMismatchErrorMessage,
51139 getSparseReshapeInputOutputMultipleErrorMessage: getSparseReshapeInputOutputMultipleErrorMessage,
51140 getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: getSparseReshapeMultipleNegativeOneOutputDimErrorMessage,
51141 getSparseReshapeNegativeOutputDimErrorMessage: getSparseReshapeNegativeOutputDimErrorMessage,
51142 getSparseSegmentReductionIndicesOutOfRangeErrorMessage: getSparseSegmentReductionIndicesOutOfRangeErrorMessage,
51143 getSparseSegmentReductionNegativeSegmentIdsErrorMessage: getSparseSegmentReductionNegativeSegmentIdsErrorMessage,
51144 getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage,
51145 getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage,
51146 getUndoAxesPermutation: getUndoAxesPermutation,
51147 isIdentityPermutation: isIdentityPermutation,
51148 log: log$3,
51149 mergeRealAndImagArrays: mergeRealAndImagArrays,
51150 prepareAndValidate: prepareAndValidate,
51151 prepareSplitSize: prepareSplitSize,
51152 segment_util: segment_util,
51153 shouldFuse: shouldFuse,
51154 slice_util: slice_util,
51155 splitRealAndImagArrays: splitRealAndImagArrays,
51156 stridesOrDilationsArePositive: stridesOrDilationsArePositive,
51157 tupleValuesAreOne: tupleValuesAreOne,
51158 upcastType: upcastType,
51159 validateDefaultValueShape: validateDefaultValueShape,
51160 validateInput: validateInput$1,
51161 validateUpdateShape: validateUpdateShape,
51162 warn: warn
51163 };
51164
51165 /**
51166 * @license
51167 * Copyright 2020 Google LLC. All Rights Reserved.
51168 * Licensed under the Apache License, Version 2.0 (the "License");
51169 * you may not use this file except in compliance with the License.
51170 * You may obtain a copy of the License at
51171 *
51172 * http://www.apache.org/licenses/LICENSE-2.0
51173 *
51174 * Unless required by applicable law or agreed to in writing, software
51175 * distributed under the License is distributed on an "AS IS" BASIS,
51176 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51177 * See the License for the specific language governing permissions and
51178 * limitations under the License.
51179 * =============================================================================
51180 */
51181
51182 var kernel_impls = {
51183 __proto__: null,
51184 nonMaxSuppressionV3Impl: nonMaxSuppressionV3Impl$2,
51185 nonMaxSuppressionV4Impl: nonMaxSuppressionV4Impl$2,
51186 nonMaxSuppressionV5Impl: nonMaxSuppressionV5Impl$2,
51187 whereImpl: whereImpl$2
51188 };
51189
51190 /**
51191 * @license
51192 * Copyright 2020 Google Inc. All Rights Reserved.
51193 * Licensed under the Apache License, Version 2.0 (the "License");
51194 * you may not use this file except in compliance with the License.
51195 * You may obtain a copy of the License at
51196 *
51197 * http://www.apache.org/licenses/LICENSE-2.0
51198 *
51199 * Unless required by applicable law or agreed to in writing, software
51200 * distributed under the License is distributed on an "AS IS" BASIS,
51201 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51202 * See the License for the specific language governing permissions and
51203 * limitations under the License.
51204 * =============================================================================
51205 */
51206
51207 /**
51208 * @license
51209 * Copyright 2017 Google LLC. All Rights Reserved.
51210 * Licensed under the Apache License, Version 2.0 (the "License");
51211 * you may not use this file except in compliance with the License.
51212 * You may obtain a copy of the License at
51213 *
51214 * http://www.apache.org/licenses/LICENSE-2.0
51215 *
51216 * Unless required by applicable law or agreed to in writing, software
51217 * distributed under the License is distributed on an "AS IS" BASIS,
51218 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51219 * See the License for the specific language governing permissions and
51220 * limitations under the License.
51221 * =============================================================================
51222 */
51223 registerOptimizers();
51224
51225 var absGradConfig = {
51226 kernelName: Abs,
51227 inputsToSave: ['x'],
51228 gradFunc: function gradFunc(dy, saved) {
51229 var _saved = _slicedToArray(saved, 1),
51230 _x = _saved[0];
51231 return {
51232 x: function x() {
51233 return mul(dy, step$2(cast$3(_x, 'float32'), -1));
51234 }
51235 };
51236 }
51237 };
51238
51239 var acosGradConfig = {
51240 kernelName: Acos,
51241 inputsToSave: ['x'],
51242 gradFunc: function gradFunc(dy, saved) {
51243 var _saved = _slicedToArray(saved, 1),
51244 _x = _saved[0];
51245 return {
51246 x: function x() {
51247 var a = square$2(cast$3(_x, 'float32'));
51248 var b = sqrt$2(sub$2(scalar(1), a));
51249 return neg$2(div$1(dy, b));
51250 }
51251 };
51252 }
51253 };
51254
51255 var acoshGradConfig = {
51256 kernelName: Acosh,
51257 inputsToSave: ['x'],
51258 gradFunc: function gradFunc(dy, saved) {
51259 var _saved = _slicedToArray(saved, 1),
51260 _x = _saved[0];
51261 return {
51262 x: function x() {
51263 var a = sqrt$2(sub$2(square$2(cast$3(_x, 'float32')), 1));
51264 return div$1(dy, a);
51265 }
51266 };
51267 }
51268 };
51269
/**
 * Gradient of Add: dy flows to both operands. Each operand's share is summed
 * over the axes that were broadcast, then reshaped back to its shape.
 */
var addGradConfig = {
  kernelName: Add$1,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var operands = _slicedToArray(saved, 2);
    var a = operands[0];
    var b = operands[1];
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);

    // Collapse dy onto `shape` by summing the broadcast axes.
    var reduceToShape = function (shape) {
      var axes = getReductionAxes(shape, outShape);
      var reduced = axes.length > 0 ? sum$3(dy, axes) : dy;
      return reshape$3(reduced, shape);
    };

    return {
      a: function () {
        return reduceToShape(a.shape);
      },
      b: function () {
        return reduceToShape(b.shape);
      }
    };
  }
};
51300
51301 /**
51302 * @license
51303 * Copyright 2020 Google LLC. All Rights Reserved.
51304 * Licensed under the Apache License, Version 2.0 (the "License");
51305 * you may not use this file except in compliance with the License.
51306 * You may obtain a copy of the License at
51307 *
51308 * http://www.apache.org/licenses/LICENSE-2.0
51309 *
51310 * Unless required by applicable law or agreed to in writing, software
51311 * distributed under the License is distributed on an "AS IS" BASIS,
51312 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51313 * See the License for the specific language governing permissions and
51314 * limitations under the License.
51315 * =============================================================================
51316 */
/**
 * Gradient of AddN: every summand receives its own copy of dy, keyed by the
 * summand's position in `saved`.
 */
var addNGradConfig = {
  kernelName: AddN,
  saveAllInputs: true,
  gradFunc: function gradFunc(dy, saved) {
    var ders = {};
    var copyOfDy = function () {
      return dy.clone();
    };
    saved.forEach(function (_tensor, index) {
      ders[index] = copyOfDy;
    });
    return ders;
  }
};
51330
/**
 * Gradient of ArgMax: defined here as zeros shaped like the saved input.
 */
var argMaxGradConfig = {
  kernelName: ArgMax,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        return zerosLike$3(input);
      }
    };
  }
};
51344
/**
 * Gradient of ArgMin: defined here as zeros shaped like the saved input.
 */
var argMinGradConfig = {
  kernelName: ArgMin,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        return zerosLike$3(input);
      }
    };
  }
};
51358
/**
 * Gradient of Asin: dy / sqrt(1 - x^2), evaluated on x cast to float32.
 */
var asinGradConfig = {
  kernelName: Asin,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var one = scalar(1);
        var xSquared = square$2(cast$3(input, 'float32'));
        var root = sqrt$2(sub$2(one, xSquared));
        return div$1(dy, root);
      }
    };
  }
};
51372
/**
 * Gradient of Asinh: dy / sqrt(1 + x^2), evaluated on x cast to float32.
 */
var asinhGradConfig = {
  kernelName: Asinh,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var one = scalar(1);
        var xSquared = square$2(cast$3(input, 'float32'));
        var root = sqrt$2(add$3(one, xSquared));
        return div$1(dy, root);
      }
    };
  }
};
51387
/**
 * Gradient of Atan2(a, b).
 *
 * With d = a^2 + b^2: da = dy * b / d and db = -dy * a / d. Each result is then
 * summed over its broadcast axes and reshaped to the operand's shape.
 */
var atan2GradConfig = {
  kernelName: Atan2,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var operands = _slicedToArray(saved, 2);
    var a = operands[0];
    var b = operands[1];
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);

    var sumOfSquares = function () {
      return add$3(square$2(a), square$2(b));
    };
    // Undo broadcasting: sum over expanded axes, then restore `shape`.
    var reduceToShape = function (grad, shape) {
      var axes = getReductionAxes(shape, outShape);
      var reduced = axes.length > 0 ? sum$3(grad, axes) : grad;
      return reshape$3(reduced, shape);
    };

    return {
      a: function () {
        var d = sumOfSquares();
        return reduceToShape(mul(dy, div$1(b, d)), a.shape);
      },
      b: function () {
        var d = sumOfSquares();
        return reduceToShape(neg$2(mul(dy, div$1(a, d))), b.shape);
      }
    };
  }
};
51420
/**
 * Gradient of Atan: dy / (x^2 + 1), evaluated on x cast to float32.
 */
var atanGradConfig = {
  kernelName: Atan,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var xSquared = square$2(cast$3(input, 'float32'));
        var denom = add$3(xSquared, 1);
        return div$1(dy, denom);
      }
    };
  }
};
51434
/**
 * Gradient of Atanh: dy / (1 - x^2), evaluated on x cast to float32.
 */
var atanhGradConfig = {
  kernelName: Atanh,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var one = scalar(1);
        var xSquared = square$2(cast$3(input, 'float32'));
        var denom = sub$2(one, xSquared);
        return div$1(dy, denom);
      }
    };
  }
};
51448
51449 /**
51450 * @license
51451 * Copyright 2020 Google LLC. All Rights Reserved.
51452 * Licensed under the Apache License, Version 2.0 (the "License");
51453 * you may not use this file except in compliance with the License.
51454 * You may obtain a copy of the License at
51455 *
51456 * http://www.apache.org/licenses/LICENSE-2.0
51457 *
51458 * Unless required by applicable law or agreed to in writing, software
51459 * distributed under the License is distributed on an "AS IS" BASIS,
51460 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51461 * See the License for the specific language governing permissions and
51462 * limitations under the License.
51463 * =============================================================================
51464 */
/**
 * Computes the backprop of a 3D avg pool.
 *
 * @param dy The dy error, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, batch of 1 is
 *     assumed.
 * @param input The original input image, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, batch of 1 is
 *     assumed. Must have the same rank as `dy`.
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`. If `filterSize` is a single
 *     number, then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`. If `strides` is a single
 *     number, then `strideDepth == strideHeight == strideWidth`.
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @returns The gradient w.r.t. `input`, with the same rank as `input` (a
 *     rank-4 input yields a rank-4 result).
 */
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
  var $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');
  var $input = convertToTensor(input, 'input', 'avgPool3dGrad');
  // Same guard as avgPoolGrad_. Without it, a rank-4 input paired with a
  // rank-5 dy would reach the rank-4 -> rank-5 reshape of dy below with an
  // incompatible shape.
  assert$1($input.rank === $dy.rank, function () {
    return "Rank of input (".concat($input.rank, ") does not match rank of dy (").concat($dy.rank, ")");
  });
  var dy5D = $dy;
  var input5D = $input;
  var reshapedTo5D = false;
  if ($input.rank === 4) {
    // Treat unbatched tensors as a batch of one for the kernel.
    reshapedTo5D = true;
    dy5D = reshape$3($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
    input5D = reshape$3($input, [1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]]);
  }
  assert$1(dy5D.rank === 5, function () {
    return "Error in avgPool3dGrad: dy must be rank 5 but got rank " + "".concat(dy5D.rank, ".");
  });
  assert$1(input5D.rank === 5, function () {
    return "Error in avgPool3dGrad: input must be rank 5 but got rank " + "".concat(input5D.rank, ".");
  });
  checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
  var inputs = {
    dy: dy5D,
    input: input5D
  };
  var attrs = {
    filterSize: filterSize,
    strides: strides,
    pad: pad,
    dimRoundingMode: dimRoundingMode
  };
  var res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs);
  if (reshapedTo5D) {
    // Drop the batch dimension added above so the result matches `input`.
    return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
  }
  return res;
}
var avgPool3dGrad = /* @__PURE__ */op({
  avgPool3dGrad_: avgPool3dGrad_
});
51523
/**
 * Gradient of AvgPool3D: forwards dy and the saved input, together with the
 * forward pooling attributes, to avgPool3dGrad.
 */
var avgPool3DGradConfig$2 = {
  kernelName: AvgPool3D,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = _slicedToArray(saved, 1)[0];
    var poolSize = attrs.filterSize;
    var poolStrides = attrs.strides;
    var padding = attrs.pad;
    var rounding = attrs.dimRoundingMode;
    return {
      x: function () {
        return avgPool3dGrad(dy, input, poolSize, poolStrides, padding, rounding);
      }
    };
  }
};
51541
51542 /**
51543 * @license
51544 * Copyright 2020 Google LLC. All Rights Reserved.
51545 * Licensed under the Apache License, Version 2.0 (the "License");
51546 * you may not use this file except in compliance with the License.
51547 * You may obtain a copy of the License at
51548 *
51549 * http://www.apache.org/licenses/LICENSE-2.0
51550 *
51551 * Unless required by applicable law or agreed to in writing, software
51552 * distributed under the License is distributed on an "AS IS" BASIS,
51553 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51554 * See the License for the specific language governing permissions and
51555 * limitations under the License.
51556 * =============================================================================
51557 */
/**
 * Computes the backprop of a 2D avg pool.
 *
 * @param dy The dy error, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param input The input image, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed. Must have the same rank as `dy`.
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
 *     `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm used in the forward prop of the op.
 *     'same', 'valid', for more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @returns The gradient w.r.t. `input`; a rank-3 input yields a rank-3 result.
 */
function avgPoolGrad_(dy, input, filterSize, strides, pad) {
  var $dy = convertToTensor(dy, 'dy', 'avgPoolGrad');
  var $input = convertToTensor(input, 'input', 'avgPoolGrad');
  assert$1($input.rank === $dy.rank, function () {
    return "Rank of input (".concat($input.rank, ") does not match rank of dy (").concat($dy.rank, ")");
  });

  // Unbatched (rank-3) tensors are lifted to a batch of one for the kernel.
  var addBatchDim = $input.rank === 3;
  var input4D = addBatchDim ? reshape$3($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]) : $input;
  var dy4D = addBatchDim ? reshape$3($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]) : $dy;

  assert$1(dy4D.rank === 4, function () {
    return "Error in avgPoolGrad: dy must be rank 4 but got rank " + "".concat(dy4D.rank, ".");
  });
  assert$1(input4D.rank === 4, function () {
    return "Error in avgPoolGrad: input must be rank 4 but got rank " + "".concat(input4D.rank, ".");
  });

  var kernelInputs = { dy: dy4D, input: input4D };
  var kernelAttrs = { filterSize: filterSize, strides: strides, pad: pad };
  var res = ENGINE.runKernel(AvgPoolGrad, kernelInputs, kernelAttrs);

  // Remove the batch dimension that was added above.
  return addBatchDim ? reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]) : res;
}
var avgPoolGrad$2 = /* @__PURE__ */op({
  avgPoolGrad_: avgPoolGrad_
});
51615
/**
 * Gradient of AvgPool: forwards dy and the saved input, together with the
 * forward pooling attributes, to avgPoolGrad$2.
 */
var avgPoolGradConfig$2 = {
  kernelName: AvgPool,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = _slicedToArray(saved, 1)[0];
    var poolSize = attrs.filterSize;
    var poolStrides = attrs.strides;
    var padding = attrs.pad;
    return {
      x: function () {
        return avgPoolGrad$2(dy, input, poolSize, poolStrides, padding);
      }
    };
  }
};
51632
/**
 * Gradient of BatchMatMul for each combination of the transposeA/transposeB
 * attributes. Every case is expressed as another matMul$1 of dy with the
 * other operand, using the transpose flags listed in each branch.
 */
var batchMatMulGradConfig = {
  kernelName: BatchMatMul,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var operands = _slicedToArray(saved, 2);
    var a = operands[0];
    var b = operands[1];
    var transposeA = attrs.transposeA;
    var transposeB = attrs.transposeB;

    if (transposeA && transposeB) {
      return {
        a: function () {
          return matMul$1(b, dy, true, true);
        },
        b: function () {
          return matMul$1(dy, a, true, true);
        }
      };
    }
    if (transposeA) {
      // transposeA only.
      return {
        a: function () {
          return matMul$1(b, dy, false, true);
        },
        b: function () {
          return matMul$1(a, dy, false, false);
        }
      };
    }
    if (transposeB) {
      // transposeB only.
      return {
        a: function () {
          return matMul$1(dy, b, false, false);
        },
        b: function () {
          return matMul$1(dy, a, true, false);
        }
      };
    }
    // Neither operand transposed.
    return {
      a: function () {
        return matMul$1(dy, b, false, true);
      },
      b: function () {
        return matMul$1(a, dy, true, false);
      }
    };
  }
};
51681
51682 /**
51683 * @license
51684 * Copyright 2020 Google LLC. All Rights Reserved.
51685 * Licensed under the Apache License, Version 2.0 (the "License");
51686 * you may not use this file except in compliance with the License.
51687 * You may obtain a copy of the License at
51688 *
51689 * http://www.apache.org/licenses/LICENSE-2.0
51690 *
51691 * Unless required by applicable law or agreed to in writing, software
51692 * distributed under the License is distributed on an "AS IS" BASIS,
51693 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51694 * See the License for the specific language governing permissions and
51695 * limitations under the License.
51696 * =============================================================================
51697 */
/**
 * Gradient of BatchToSpaceND: applies spaceToBatchND$2 to dy with the same
 * block shape, passing the forward op's `crops` as its third argument.
 */
var batchToSpaceNDGradConfig = {
  kernelName: BatchToSpaceND,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var blockShape = attrs.blockShape;
    var crops = attrs.crops;
    var inverse = function () {
      return spaceToBatchND$2(dy, blockShape, crops);
    };
    return { x: inverse };
  }
};
51710
51711 /**
51712 * @license
51713 * Copyright 2020 Google LLC. All Rights Reserved.
51714 * Licensed under the Apache License, Version 2.0 (the "License");
51715 * you may not use this file except in compliance with the License.
51716 * You may obtain a copy of the License at
51717 *
51718 * http://www.apache.org/licenses/LICENSE-2.0
51719 *
51720 * Unless required by applicable law or agreed to in writing, software
51721 * distributed under the License is distributed on an "AS IS" BASIS,
51722 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51723 * See the License for the specific language governing permissions and
51724 * limitations under the License.
51725 * =============================================================================
51726 */
/**
 * Gradient of BroadcastTo: sums dy over every axis the input was tiled along,
 * keeping dims so the result keeps dy's rank.
 *
 * @throws Error if a mismatched input dimension is not 1.
 */
var broadcastToGradConfig = {
  kernelName: BroadcastTo,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var inputShape = attrs.inputShape;
    var outputShape = attrs.shape;

    // reps starts as the output shape. Every dim where input and output agree
    // is marked 1; the rest keep the output size.
    var reps = Array.from(outputShape);
    var i = inputShape.length;
    while (i-- > 0) {
      if (inputShape[i] === outputShape[i]) {
        reps[i] = 1;
        continue;
      }
      if (inputShape[i] !== 1) {
        throw new Error("broadcastTo(): [".concat(inputShape, "] cannot be broadcast to [").concat(outputShape, "]."));
      }
    }

    var axes = reps.reduce(function (acc, rep, axis) {
      if (rep > 1) {
        acc.push(axis);
      }
      return acc;
    }, []);

    return {
      x: function () {
        return sum$3(dy, axes, true /* keepDims */);
      }
    };
  }
};
51754
51755 /**
51756 * @license
51757 * Copyright 2020 Google LLC. All Rights Reserved.
51758 * Licensed under the Apache License, Version 2.0 (the "License");
51759 * you may not use this file except in compliance with the License.
51760 * You may obtain a copy of the License at
51761 *
51762 * http://www.apache.org/licenses/LICENSE-2.0
51763 *
51764 * Unless required by applicable law or agreed to in writing, software
51765 * distributed under the License is distributed on an "AS IS" BASIS,
51766 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51767 * See the License for the specific language governing permissions and
51768 * limitations under the License.
51769 * =============================================================================
51770 */
/**
 * Gradient of Cast: passes a copy of dy straight through.
 */
var castGradConfig = {
  kernelName: Cast,
  gradFunc: function gradFunc(dy) {
    var passThrough = function () {
      return dy.clone();
    };
    return { x: passThrough };
  }
};
51781
51782 /**
51783 * @license
51784 * Copyright 2020 Google LLC. All Rights Reserved.
51785 * Licensed under the Apache License, Version 2.0 (the "License");
51786 * you may not use this file except in compliance with the License.
51787 * You may obtain a copy of the License at
51788 *
51789 * http://www.apache.org/licenses/LICENSE-2.0
51790 *
51791 * Unless required by applicable law or agreed to in writing, software
51792 * distributed under the License is distributed on an "AS IS" BASIS,
51793 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51794 * See the License for the specific language governing permissions and
51795 * limitations under the License.
51796 * =============================================================================
51797 */
/**
 * Gradient of Ceil: zeros shaped like dy.
 */
var ceilGradConfig = {
  kernelName: Ceil,
  gradFunc: function gradFunc(dy) {
    // TODO(manrajgrover): Return null for gradients when backprop supports it.
    var zeroGrad = function () {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
51809
/**
 * Gradient of ClipByValue: dy where clipValueMin <= x <= clipValueMax,
 * zero elsewhere.
 */
var clipByValueGradConfig = {
  kernelName: ClipByValue,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = _slicedToArray(saved, 1)[0];
    var lower = attrs.clipValueMin;
    var upper = attrs.clipValueMax;
    return {
      x: function () {
        var inRange = logicalAnd$2(greaterEqual$2(input, lower), lessEqual$2(input, upper));
        return where(inRange, dy, zerosLike$3(dy));
      }
    };
  }
};
51825
51826 /**
51827 * @license
51828 * Copyright 2020 Google LLC. All Rights Reserved.
51829 * Licensed under the Apache License, Version 2.0 (the "License");
51830 * you may not use this file except in compliance with the License.
51831 * You may obtain a copy of the License at
51832 *
51833 * http://www.apache.org/licenses/LICENSE-2.0
51834 *
51835 * Unless required by applicable law or agreed to in writing, software
51836 * distributed under the License is distributed on an "AS IS" BASIS,
51837 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51838 * See the License for the specific language governing permissions and
51839 * limitations under the License.
51840 * =============================================================================
51841 */
/**
 * Gradient of ComplexAbs: delegates to the Abs gradient function on the same
 * saved input.
 */
var complexAbsGradConfig = {
  kernelName: ComplexAbs,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    return absGradConfig.gradFunc(dy, saved, attrs);
  }
};
51847
51848 /**
51849 * @license
51850 * Copyright 2020 Google LLC. All Rights Reserved.
51851 * Licensed under the Apache License, Version 2.0 (the "License");
51852 * you may not use this file except in compliance with the License.
51853 * You may obtain a copy of the License at
51854 *
51855 * http://www.apache.org/licenses/LICENSE-2.0
51856 *
51857 * Unless required by applicable law or agreed to in writing, software
51858 * distributed under the License is distributed on an "AS IS" BASIS,
51859 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51860 * See the License for the specific language governing permissions and
51861 * limitations under the License.
51862 * =============================================================================
51863 */
/**
 * Gradient of Concat: splits dy back into per-input pieces along the concat
 * axis. The result is an array indexed like `saved`.
 */
var concatGradConfig = {
  kernelName: Concat,
  saveAllInputs: true,
  gradFunc: function gradFunc(dy, saved, attrs) {
    // Normalize the axis attribute against the first input's shape.
    var $axis = parseAxisParam(attrs.axis, saved[0].shape)[0];
    var sizeSplits = saved.map(function (t) {
      return t.shape[$axis];
    });
    var pieces = split$3(dy, sizeSplits, $axis);
    return pieces.map(function (piece) {
      return function () {
        return piece;
      };
    });
  }
};
51884
/**
 * Gradient of Conv2D w.r.t. the input and the filter.
 *
 * @throws (via assert$1) if any dilation rate is not 1.
 */
var conv2DGradConfig = {
  kernelName: Conv2D$1,
  inputsToSave: ['x', 'filter'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var tensors = _slicedToArray(saved, 2);
    var x4D = tensors[0];
    var $filter = tensors[1];
    var dilations = attrs.dilations;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dataFormat = attrs.dataFormat;

    assert$1(tupleValuesAreOne(dilations), function () {
      return 'Error in gradient of conv2D: dilation rates greater than 1 ' + "are not yet supported in gradients. Got dilations '".concat(dilations, "'");
    });

    return {
      x: function () {
        return conv2DBackpropInput$2(x4D.shape, dy, $filter, strides, pad, dataFormat);
      },
      filter: function () {
        return conv2DBackpropFilter$2(x4D, dy, $filter.shape, strides, pad, dataFormat);
      }
    };
  }
};
51909
/**
 * Gradient of Conv2DBackpropInput.
 *
 * `ddx` is the incoming gradient. The `dy` gradient is a conv2d of ddx with
 * the saved filter, and the filter gradient is conv2DBackpropFilter$2 of ddx
 * and the saved dy.
 */
var conv2DBackpropInputGradConfig = {
  kernelName: Conv2DBackpropInput,
  inputsToSave: ['dy', 'filter'],
  gradFunc: function gradFunc(ddx, saved, attrs) {
    var tensors = _slicedToArray(saved, 2);
    var savedDy = tensors[0];
    var savedFilter = tensors[1];
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dataFormat = attrs.dataFormat;
    var dimRoundingMode = attrs.dimRoundingMode;
    return {
      dy: function () {
        return conv2d$4(ddx, savedFilter, strides, pad, dataFormat, 1 /* dilations */, dimRoundingMode);
      },
      filter: function () {
        return conv2DBackpropFilter$2(ddx, savedDy, savedFilter.shape, strides, pad, dataFormat, dimRoundingMode);
      }
    };
  }
};
51931
51932 /**
51933 * @license
51934 * Copyright 2020 Google LLC. All Rights Reserved.
51935 * Licensed under the Apache License, Version 2.0 (the "License");
51936 * you may not use this file except in compliance with the License.
51937 * You may obtain a copy of the License at
51938 *
51939 * http://www.apache.org/licenses/LICENSE-2.0
51940 *
51941 * Unless required by applicable law or agreed to in writing, software
51942 * distributed under the License is distributed on an "AS IS" BASIS,
51943 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51944 * See the License for the specific language governing permissions and
51945 * limitations under the License.
51946 * =============================================================================
51947 */
/**
 * Computes the derivative of the filter of a 3D convolution.
 *
 * @param x The input tensor, of rank 5 or rank 4 of shape
 *     [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is
 *     assumed.
 * @param dy The dy image, of rank 5 or rank 4, of shape
 *     [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is
 *     assumed.
 * @param filterShape The shape of the filter, length 5,
 *     [filterDepth, filterHeight, filterWidth, inDepth, outDepth].
 * @param strides The strides of the convolution: [strideDepth, strideHeight,
 *     strideWidth].
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @returns The output of the Conv3DBackpropFilterV2 kernel for the (possibly
 *     batch-expanded) x and dy.
 */
function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) {
  // Unbatched (rank-4) tensors are treated as a batch of one.
  var x5D = x;
  if (x.rank === 4) {
    x5D = reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
  }
  var dy5D = dy;
  if (dy5D.rank === 4) {
    dy5D = reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
  }
  assert$1(x5D.rank === 5, function () {
    return "Error in conv3dDerFilter: input must be rank 5, but got shape " + "".concat(x5D.shape, ".");
  });
  assert$1(dy5D.rank === 5, function () {
    return "Error in conv3dDerFilter: dy must be rank 5, but got shape " + "".concat(dy5D.shape, ".");
  });
  assert$1(filterShape.length === 5, function () {
    return "Error in conv3dDerFilter: filterShape must be length 5, but got " + "".concat(filterShape, ".");
  });
  // Fixed: the message previously had unbalanced parentheses
  // ("depth of input 3) must match input depth in filter (5.").
  assert$1(x5D.shape[4] === filterShape[3], function () {
    return "Error in conv3dDerFilter: depth of input (".concat(x5D.shape[4], ") must ") + "match input depth in filter (".concat(filterShape[3], ").");
  });
  assert$1(dy5D.shape[4] === filterShape[4], function () {
    return "Error in conv3dDerFilter: depth of dy (".concat(dy5D.shape[4], ") must ") + "match output depth for filter (".concat(filterShape[4], ").");
  });
  var inputs = {
    x: x5D,
    dy: dy5D
  };
  var attrs = {
    strides: strides,
    pad: pad,
    filterShape: filterShape
  };
  return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs);
}
var conv3DBackpropFilter = /* @__PURE__ */op({
  conv3DBackpropFilter_: conv3DBackpropFilter_
});
52003
/**
 * Gradient of Conv3D w.r.t. the input and the filter.
 *
 * @throws (via assert$1) if any dilation rate is not 1.
 */
var conv3DGradConfig = {
  kernelName: Conv3D$1,
  inputsToSave: ['x', 'filter'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var dilations = attrs.dilations;
    var strides = attrs.strides;
    var pad = attrs.pad;

    assert$1(tupleValuesAreOne(dilations), function () {
      return 'Error in gradient of conv3D: dilation rates greater than 1 are ' + "not yet supported in gradients. Got dilations '".concat(dilations, "'");
    });

    var tensors = _slicedToArray(saved, 2);
    var x5D = tensors[0];
    var $filter = tensors[1];
    return {
      x: function () {
        return conv3DBackpropInput$1(x5D.shape, dy, $filter, strides, pad);
      },
      filter: function () {
        return conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad);
      }
    };
  }
};
52027
/**
 * Gradient of Cos: -sin(x) * dy, evaluated on x cast to float32.
 */
var cosGradConfig = {
  kernelName: Cos,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var negSin = neg$2(sin$2(cast$3(input, 'float32')));
        return mul(negSin, dy);
      }
    };
  }
};
52041
/**
 * Gradient of Cosh: sinh(x) * dy, evaluated on x cast to float32.
 */
var coshGradConfig = {
  kernelName: Cosh,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var sinhX = sinh$2(cast$3(input, 'float32'));
        return mul(sinhX, dy);
      }
    };
  }
};
52055
/**
 * Gradient of Cumsum.
 *
 * A (possibly exclusive) cumulative sum along `axis` is linear, and its
 * adjoint is the cumulative sum of dy along the same axis in the opposite
 * direction, with the same `exclusive` flag.
 *
 * Fixed: the result used to be transposed again by
 * getAxesPermutation([axis], rank). cumsum$2 is called on dy with the original
 * axis, so its output is already laid out like dy; the extra transpose
 * produced a transposed (square) or mis-shaped (non-square) gradient whenever
 * `axis` was not the innermost axis.
 * NOTE(review): relies on the Cumsum kernel handling non-innermost axes
 * internally, as the tfjs CPU/WebGL kernels do — confirm for custom backends.
 */
var cumsumGradConfig = {
  kernelName: Cumsum,
  // The saved input is no longer read, but is kept so the config's saving
  // behaviour is unchanged.
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var axis = attrs.axis;
    var exclusive = attrs.exclusive;
    var reverse = attrs.reverse;
    return {
      x: function () {
        return cumsum$2(dy, axis, exclusive, !reverse);
      }
    };
  }
};
52077
/**
 * Gradient of DepthwiseConv2dNative w.r.t. the input and the filter.
 *
 * Missing dilations default to [1, 1]. The input and filter must be rank 4,
 * the input channels must match filter dim 2, and either strides or
 * dilations must be 1.
 *
 * @throws (via assert$1 / checkPadOnDimRoundingMode) when any check fails.
 */
var depthwiseConv2dNativeGradConfig = {
  kernelName: DepthwiseConv2dNative,
  inputsToSave: ['x', 'filter'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var requestedDilations = attrs.dilations;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dimRoundingMode = attrs.dimRoundingMode;
    var $dilations = requestedDilations == null ? [1, 1] : requestedDilations;

    assert$1(tupleValuesAreOne($dilations), function () {
      return 'Error in gradient of depthwiseConv2dNative: dilation rates ' + "greater than 1 are not yet supported. Got dilations " + "'".concat($dilations, "'");
    });

    var tensors = _slicedToArray(saved, 2);
    var input = tensors[0];
    var filter = tensors[1];

    assert$1(input.rank === 4, function () {
      return "Error in gradient of depthwiseConv2dNative: input must be " + "rank 4, but got rank ".concat(input.rank, ".");
    });
    assert$1(filter.rank === 4, function () {
      return "Error in gradient of depthwiseConv2dNative: filter must be " + "rank 4, but got rank ".concat(filter.rank, ".");
    });
    assert$1(input.shape[3] === filter.shape[2], function () {
      return "Error in gradient of depthwiseConv2d: number of input " + "channels (".concat(input.shape[3], ") must match the inChannels dimension ") + "in filter ".concat(filter.shape[2], ".");
    });
    assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
      return 'Error in gradient of depthwiseConv2d: Either strides or ' + "dilations must be 1. Got strides ".concat(strides, " and dilations ") + "'".concat($dilations, "'.");
    });
    checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);

    return {
      x: function () {
        return depthwiseConv2dNativeBackpropInput$2(input.shape, dy, filter, strides, pad, $dilations, dimRoundingMode);
      },
      filter: function () {
        return depthwiseConv2dNativeBackpropFilter$2(input, dy, filter.shape, strides, pad, $dilations, dimRoundingMode);
      }
    };
  }
};
52116
/**
 * Gradient of Dilation2D. Both gradients are produced by their dedicated
 * kernels from {x, filter, dy} and the forward attrs.
 */
var dilation2dGradConfig = {
  kernelName: Dilation2D,
  inputsToSave: ['x', 'filter'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var tensors = _slicedToArray(saved, 2);
    // Builds a fresh kernel-input map; each backprop kernel gets its own.
    var buildInputs = function () {
      return {
        x: tensors[0],
        filter: tensors[1],
        dy: dy
      };
    };
    var inputInputs = buildInputs();
    var filterInputs = buildInputs();
    return {
      x: function () {
        return ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs);
      },
      filter: function () {
        return ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs);
      }
    };
  }
};
52144
/**
 * Gradient of Elu: runs the EluGrad kernel on dy and the saved forward
 * output y (saved via outputsToSave).
 */
var eluGradConfig$2 = {
  kernelName: Elu$1,
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved) {
    var forwardOutput = _slicedToArray(saved, 1)[0];
    var kernelInputs = { dy: dy, y: forwardOutput };
    return {
      x: function () {
        return ENGINE.runKernel(EluGrad, kernelInputs);
      }
    };
  }
};
52162
/**
 * Gradient of Erf: dy * (2 / sqrt(pi)) * exp(-x^2).
 *
 * The derivative is now built inside the lazy `x` thunk, like every other
 * gradient config. Previously it was computed eagerly when gradFunc ran, so
 * its intermediate tensors were allocated outside whatever scope invokes the
 * thunk.
 */
var erfGradConfig = {
  kernelName: Erf,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    // Named `_x` so that it is not shadowed inside the thunk below.
    var _x = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var a = mul(exp$2(neg$2(square$2(_x))), 2 / Math.sqrt(Math.PI));
        return mul(dy, a);
      }
    };
  }
};
52177
/**
 * Gradient of Exp: dy * y. With outputsToSave: [true], the saved tensor is
 * the forward output y = exp(x), which is also its own derivative.
 */
var expGradConfig = {
  kernelName: Exp,
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved) {
    var forwardOutput = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        return mul(dy, forwardOutput);
      }
    };
  }
};
52191
/**
 * Gradient of ExpandDims: dy reshaped back to the saved input's shape.
 */
var expandDimsGradConfig = {
  kernelName: ExpandDims,
  inputsToSave: ['input'],
  gradFunc: function gradFunc(dy, saved) {
    var original = _slicedToArray(saved, 1)[0];
    return {
      input: function () {
        return reshape$3(dy, original.shape);
      }
    };
  }
};
52205
/**
 * Gradient of Expm1: dy * exp(x), using the saved input x.
 */
var expm1GradConfig = {
  kernelName: Expm1,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function () {
        var expX = exp$2(input);
        return mul(dy, expX);
      }
    };
  }
};
52219
52220 /**
52221 * @license
52222 * Copyright 2020 Google LLC. All Rights Reserved.
52223 * Licensed under the Apache License, Version 2.0 (the "License");
52224 * you may not use this file except in compliance with the License.
52225 * You may obtain a copy of the License at
52226 *
52227 * http://www.apache.org/licenses/LICENSE-2.0
52228 *
52229 * Unless required by applicable law or agreed to in writing, software
52230 * distributed under the License is distributed on an "AS IS" BASIS,
52231 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52232 * See the License for the specific language governing permissions and
52233 * limitations under the License.
52234 * =============================================================================
52235 */
/**
 * Gradient of Floor.
 *
 * Floor is piecewise constant, so the input receives an all-zero gradient
 * shaped like `dy`.
 */
var floorGradConfig = {
  kernelName: Floor,
  gradFunc: function gradFunc(dy) {
    var zeroGrad = function () {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
52246
/**
 * Gradient of FloorDiv, computed exactly like the gradient of real division
 * a / b:
 *   d/da = dy / b
 *   d/db = -dy * a / b^2
 *
 * Both operands are cast to float32. When `a` and `b` were broadcast against
 * each other, each partial is summed over the broadcast axes and reshaped to
 * its input's shape. For `b` this reduction happens before the division by
 * b^2, which already has `b`'s shape.
 */
var floorDivGradConfig = {
  kernelName: FloorDiv,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var operands = _slicedToArray(saved, 2);
    var numerator = operands[0];
    var denominator = operands[1];
    var broadcastShape = assertAndGetBroadcastShape(numerator.shape, denominator.shape);
    // Sums `grad` over the axes along which an input of `targetShape` was
    // broadcast, then restores `targetShape`.
    var sumOverBroadcastAxes = function (grad, targetShape) {
      var axes = getReductionAxes(targetShape, broadcastShape);
      return axes.length > 0 ? reshape$3(sum$3(grad, axes), targetShape) : grad;
    };
    return {
      a: function () {
        var quotient = div$1(dy, cast$3(denominator, 'float32'));
        return sumOverBroadcastAxes(quotient, numerator.shape);
      },
      b: function () {
        var scaled = sumOverBroadcastAxes(mul(dy, cast$3(numerator, 'float32')), denominator.shape);
        var denominatorSquared = square$2(denominator);
        return neg$2(div$1(scaled, cast$3(denominatorSquared, 'float32')));
      }
    };
  }
};
52278
/**
 * Gradient of FusedBatchNorm.
 *
 * The partials returned are those of
 *   y = (x - mean) * rsqrt(variance + varianceEpsilon) * scale + offset
 * with `mean` and `variance` treated as independent inputs, so no gradient
 * flows from them back into `x`. A missing `scale` is treated as scalar 1.
 *
 * When `mean` is rank 1 (one value per channel), the mean/variance/scale/
 * offset gradients are summed over the axes along which `mean` broadcasts
 * against `x`. Every one of those gradients is then reshaped to `mean.shape`.
 */
var fusedBatchNormGradConfig = {
  kernelName: FusedBatchNorm,
  inputsToSave: ['x', 'mean', 'variance', 'scale'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var varianceEpsilon = attrs.varianceEpsilon;
    var savedTensors = _slicedToArray(saved, 4);
    var x = savedTensors[0];
    var mean = savedTensors[1];
    var variance = savedTensors[2];
    var scale = savedTensors[3];
    var scaleValue = scale == null ? scalar(1) : scale;
    var reductionAxes = getReductionAxes(mean.shape, x.shape);
    var perChannel = mean.rank === 1;
    // Tile multiples used by the x gradient: every leading dim of x, then 1
    // for the channel dim.
    var tileShape = perChannel ? x.shape.slice(0, x.shape.length - 1).concat([1]) : [];
    var xMinusMean = sub$2(x, mean);
    var dyTimesScaleValue = mul(dy, scaleValue);
    var oneOverSqrtVariance = rsqrt$2(add$3(variance, scalar(varianceEpsilon)));
    // -0.5 * (variance + eps)^(-3/2), the derivative of rsqrt w.r.t. variance.
    var minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5));
    // Collapses a full-size gradient to the shape of `mean`.
    var toMeanShape = function (grad) {
      var reduced = perChannel ? sum$3(grad, reductionAxes) : grad;
      return reshape$3(reduced, mean.shape);
    };
    return {
      x: function () {
        // NOTE(review): the per-channel path reshapes to four dims and tiles
        // with x.rank multiples, so it assumes x is rank 4 — confirm callers.
        var invStd = perChannel ? tile$3(reshape$3(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape) : oneOverSqrtVariance;
        return reshape$3(mul(mul(dy, invStd), scaleValue), x.shape);
      },
      mean: function () {
        return toMeanShape(mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue));
      },
      variance: function () {
        return toMeanShape(mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue));
      },
      scale: function () {
        return toMeanShape(mul(dy, mul(xMinusMean, oneOverSqrtVariance)));
      },
      offset: function () {
        return toMeanShape(dy);
      }
    };
  }
};
52347
/**
 * Gradient of GatherV2.
 *
 * The gradient with respect to `x` scatters the slices of `dy` back to the
 * positions they were gathered from:
 * 1. `dy` is reshaped to [...outer dims, number of indices, ...inner dims].
 * 2. The indices axis is transposed to the front.
 * 3. The slices are accumulated per index with unsortedSegmentSum, so
 *    repeated indices add up.
 * 4. The result is transposed back.
 *
 * When `batchDims === 1`, `x` is split along its first dimension. Each batch
 * entry is processed with its own row of `indices` and `dy`, and the
 * per-entry results are stacked. NOTE(review): other non-zero batchDims
 * values take the unbatched path — confirm that is intended.
 *
 * The value returned under `indices` is the saved indices tensor itself.
 */
var gatherGradConfig = {
  kernelName: GatherV2,
  inputsToSave: ['x', 'indices'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var savedPair = _slicedToArray(saved, 2);
    var x = savedPair[0];
    var _indices = savedPair[1];
    var axis = attrs.axis;
    var batchDims = attrs.batchDims;
    var parsedAxis = parseAxisParam(axis, x.shape)[0];
    // Returns a thunk computing the gradient of gather(params, ids) along
    // `parsedAxis` w.r.t. `params`, for upstream gradient `upstream`.
    var derXBatch = function derXBatch(params, ids, upstream) {
      return function () {
        var paramsShape = params.shape;
        var indicesSize = ids.size;
        var outerShape = paramsShape.slice(0, parsedAxis);
        var outerDims = outerShape.length;
        var innerShape = paramsShape.slice(parsedAxis + 1);
        var innerDims = innerShape.length;
        var outerAxesIndices = arrayRange(0, outerDims);
        var innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
        var values = reshape$3(upstream, arrayConcat([outerShape, [indicesSize], innerShape]));
        var flatIndices = reshape$3(ids, [indicesSize]);
        // Move the indices axis to the front so segments are summed along axis 0.
        var transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
        var valuesTranspose = transpose$2(values, transposeDims);
        var paramsGrad = unsortedSegmentSum$2(valuesTranspose, flatIndices, params.shape[parsedAxis]);
        return transpose$2(paramsGrad, getUndoAxesPermutation(transposeDims));
      };
    };
    var indicesThunk = function indices() {
      return _indices;
    };
    if (batchDims !== 1) {
      return {
        x: derXBatch(x, _indices, dy),
        indices: indicesThunk
      };
    }
    var batchSize = x.shape[0];
    var xBatch = x.split(batchSize, 0);
    return {
      x: function derXBatched() {
        var perEntry = xBatch.map(function (xSlice, i) {
          return derXBatch(xSlice, _indices.slice(i, 1), dy.slice(i, 1))();
        });
        return stack(perEntry).reshape(x.shape);
      },
      indices: indicesThunk
    };
  }
};
/**
 * Builds the list of consecutive values from `start` (inclusive) up to
 * `stop` (exclusive), stepping by 1. Returns [] when `start >= stop`.
 *
 * @param {number} start First value.
 * @param {number} stop Upper bound (not included).
 * @returns {number[]}
 */
function arrayRange(start, stop) {
  var values = [];
  var current = start;
  while (current < stop) {
    values.push(current);
    current += 1;
  }
  return values;
}
/**
 * Flattens a list of arrays by exactly one level, preserving order.
 * Nested arrays inside the inner arrays are kept as single elements.
 *
 * @param {Array<Array>} arrays Arrays to join.
 * @returns {Array} A new array holding every inner element in order.
 */
function arrayConcat(arrays) {
  var flat = [];
  var outer = 0;
  while (outer < arrays.length) {
    var part = arrays[outer];
    for (var inner = 0; inner < part.length; inner += 1) {
      flat.push(part[inner]);
    }
    outer += 1;
  }
  return flat;
}
52420
/**
 * Gradient of GreaterEqual.
 *
 * Both inputs of the comparison receive all-zero gradients, each shaped
 * like its own input.
 */
var greaterEqualGradConfig = {
  kernelName: GreaterEqual,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var operands = _slicedToArray(saved, 2);
    var zerosFor = function (operand) {
      return function () {
        return zerosLike$3(operand);
      };
    };
    return {
      a: zerosFor(operands[0]),
      b: zerosFor(operands[1])
    };
  }
};
52438
52439 /**
52440 * @license
52441 * Copyright 2020 Google LLC. All Rights Reserved.
52442 * Licensed under the Apache License, Version 2.0 (the "License");
52443 * you may not use this file except in compliance with the License.
52444 * You may obtain a copy of the License at
52445 *
52446 * http://www.apache.org/licenses/LICENSE-2.0
52447 *
52448 * Unless required by applicable law or agreed to in writing, software
52449 * distributed under the License is distributed on an "AS IS" BASIS,
52450 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52451 * See the License for the specific language governing permissions and
52452 * limitations under the License.
52453 * =============================================================================
52454 */
/**
 * Gradient of Identity: `dy` is passed through, always cast to float32.
 */
var identityGradConfig = {
  kernelName: Identity$1,
  gradFunc: function gradFunc(dy) {
    var passThrough = function () {
      return cast$3(dy, 'float32');
    };
    return { x: passThrough };
  }
};
52465
52466 /**
52467 * @license
52468 * Copyright 2020 Google LLC. All Rights Reserved.
52469 * Licensed under the Apache License, Version 2.0 (the "License");
52470 * you may not use this file except in compliance with the License.
52471 * You may obtain a copy of the License at
52472 *
52473 * http://www.apache.org/licenses/LICENSE-2.0
52474 *
52475 * Unless required by applicable law or agreed to in writing, software
52476 * distributed under the License is distributed on an "AS IS" BASIS,
52477 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52478 * See the License for the specific language governing permissions and
52479 * limitations under the License.
52480 * =============================================================================
52481 */
/**
 * Gradient of IsFinite: the input receives an all-zero gradient shaped
 * like `dy`.
 */
var isFiniteGradConfig = {
  kernelName: IsFinite,
  gradFunc: function gradFunc(dy) {
    // TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
    var zeroGrad = function () {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
52494
52495 /**
52496 * @license
52497 * Copyright 2020 Google LLC. All Rights Reserved.
52498 * Licensed under the Apache License, Version 2.0 (the "License");
52499 * you may not use this file except in compliance with the License.
52500 * You may obtain a copy of the License at
52501 *
52502 * http://www.apache.org/licenses/LICENSE-2.0
52503 *
52504 * Unless required by applicable law or agreed to in writing, software
52505 * distributed under the License is distributed on an "AS IS" BASIS,
52506 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52507 * See the License for the specific language governing permissions and
52508 * limitations under the License.
52509 * =============================================================================
52510 */
/**
 * Gradient of IsInf: the input receives an all-zero gradient shaped
 * like `dy`.
 */
var isInfGradConfig = {
  kernelName: IsInf,
  gradFunc: function gradFunc(dy) {
    // TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
    var zeroGrad = function () {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
52523
52524 /**
52525 * @license
52526 * Copyright 2020 Google LLC. All Rights Reserved.
52527 * Licensed under the Apache License, Version 2.0 (the "License");
52528 * you may not use this file except in compliance with the License.
52529 * You may obtain a copy of the License at
52530 *
52531 * http://www.apache.org/licenses/LICENSE-2.0
52532 *
52533 * Unless required by applicable law or agreed to in writing, software
52534 * distributed under the License is distributed on an "AS IS" BASIS,
52535 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52536 * See the License for the specific language governing permissions and
52537 * limitations under the License.
52538 * =============================================================================
52539 */
/**
 * Gradient of IsNan: the input receives an all-zero gradient shaped
 * like `dy`.
 */
var isNanGradConfig = {
  kernelName: IsNan,
  gradFunc: function gradFunc(dy) {
    // TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
    var zeroGrad = function () {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
52552
/**
 * Gradient of LeakyRelu: `dy` where x > 0, and `alpha * dy` where x <= 0.
 * The x > 0 mask is built as soon as gradFunc runs.
 */
var leakyReluGradConfig = {
  kernelName: LeakyRelu,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var features = _slicedToArray(saved, 1)[0];
    var alpha = attrs.alpha;
    var isPositive = greater$3(features, 0);
    return {
      x: function () {
        var leaked = mul(dy, alpha);
        return where(isPositive, dy, leaked);
      }
    };
  }
};
52570
/**
 * Gradient of log1p: d/dx log(1 + x) = 1 / (1 + x).
 */
var log1pGradConfig = {
  kernelName: Log1p,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var dx = function () {
      var onePlusX = add$3(input, 1);
      return div$1(dy, onePlusX);
    };
    return { x: dx };
  }
};
52584
/**
 * Gradient of log: d/dx log(x) = 1 / x, with x cast to float32 before the
 * division.
 */
var logGradConfig = {
  kernelName: Log,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var dx = function () {
      var denominator = cast$3(input, 'float32');
      return div$1(dy, denominator);
    };
    return { x: dx };
  }
};
52598
/**
 * Gradient of LogSoftmax:
 *   dlogits = dy - sum(dy, axis, keepDims) * softmax
 * where softmax is recovered as exp of the saved LogSoftmax output.
 */
var logSoftmaxGradConfig = {
  kernelName: LogSoftmax$1,
  inputsToSave: [],
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var logSoftmaxOut = _slicedToArray(saved, 1)[0];
    var axis = attrs.axis;
    return {
      logits: function () {
        var softmax = exp$2(logSoftmaxOut);
        var dySum = sum$3(dy, axis, true);
        return sub$2(dy, mul(dySum, softmax));
      }
    };
  }
};
52616
52617 /**
52618 * @license
52619 * Copyright 2020 Google LLC. All Rights Reserved.
52620 * Licensed under the Apache License, Version 2.0 (the "License");
52621 * you may not use this file except in compliance with the License.
52622 * You may obtain a copy of the License at
52623 *
52624 * http://www.apache.org/licenses/LICENSE-2.0
52625 *
52626 * Unless required by applicable law or agreed to in writing, software
52627 * distributed under the License is distributed on an "AS IS" BASIS,
52628 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52629 * See the License for the specific language governing permissions and
52630 * limitations under the License.
52631 * =============================================================================
52632 */
/**
 * Runs the LRNGrad kernel (backprop of local response normalization).
 *
 * @param x The forward input.
 * @param y The forward output.
 * @param dy The upstream gradient.
 * Optional 4th-7th positional arguments, each replaced by its default when
 * omitted or `undefined`:
 *   depthRadius (default 5), bias (default 1), alpha (default 1),
 *   beta (default 0.5).
 */
function localResponseNormalizationBackprop_(x, y, dy) {
  var orDefault = function (value, fallback) {
    return value === undefined ? fallback : value;
  };
  var attrs = {
    depthRadius: orDefault(arguments[3], 5),
    bias: orDefault(arguments[4], 1),
    alpha: orDefault(arguments[5], 1),
    beta: orDefault(arguments[6], 0.5)
  };
  return ENGINE.runKernel(LRNGrad, { x: x, y: y, dy: dy }, attrs);
}
var localResponseNormalizationBackprop = op({
  localResponseNormalizationBackprop_: localResponseNormalizationBackprop_
});
52654
/**
 * Gradient of LRN, delegated to localResponseNormalizationBackprop with the
 * saved input `x`, the saved output `y`, and the forward op's attrs.
 */
var lrnGradConfig = {
  kernelName: LRN,
  inputsToSave: ['x'],
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var savedPair = _slicedToArray(saved, 2);
    var input = savedPair[0];
    var output = savedPair[1];
    var depthRadius = attrs.depthRadius;
    var bias = attrs.bias;
    var alpha = attrs.alpha;
    var beta = attrs.beta;
    var dx = function () {
      return localResponseNormalizationBackprop(input, output, dy, depthRadius, bias, alpha, beta);
    };
    return { x: dx };
  }
};
52674
52675 /**
52676 * @license
52677 * Copyright 2020 Google LLC. All Rights Reserved.
52678 * Licensed under the Apache License, Version 2.0 (the "License");
52679 * you may not use this file except in compliance with the License.
52680 * You may obtain a copy of the License at
52681 *
52682 * http://www.apache.org/licenses/LICENSE-2.0
52683 *
52684 * Unless required by applicable law or agreed to in writing, software
52685 * distributed under the License is distributed on an "AS IS" BASIS,
52686 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52687 * See the License for the specific language governing permissions and
52688 * limitations under the License.
52689 * =============================================================================
52690 */
/**
 * Shared gradient helper for the Min and Max reductions.
 *
 * If `y` (the reduced output) or `dy` has lower rank than `xOrig`, it is
 * reshaped with size-1 axes reinserted at `origAxes` so that it broadcasts
 * against `xOrig`. The gradient then flows to every element of `xOrig` that
 * equals its reduced value. When several elements tie, each of them
 * receives the full `dy`.
 *
 * @returns An object whose `x` thunk computes dy * (xOrig == y), with the
 *     mask cast to dy's dtype.
 */
function gradForMinAndMax(dy, y, xOrig, origAxes) {
  var restoreRank = function (t) {
    return t.rank < xOrig.rank ? reshape$3(t, expandShapeToKeepDim(t.shape, origAxes)) : t;
  };
  var yKeepDims = restoreRank(y);
  var dyKeepDims = restoreRank(dy);
  return {
    x: function () {
      var isExtremum = cast$3(equal$2(xOrig, yKeepDims), dyKeepDims.dtype);
      return mul(dyKeepDims, isExtremum);
    }
  };
}
52708
52709 /**
52710 * @license
52711 * Copyright 2020 Google LLC. All Rights Reserved.
52712 * Licensed under the Apache License, Version 2.0 (the "License");
52713 * you may not use this file except in compliance with the License.
52714 * You may obtain a copy of the License at
52715 *
52716 * http://www.apache.org/licenses/LICENSE-2.0
52717 *
52718 * Unless required by applicable law or agreed to in writing, software
52719 * distributed under the License is distributed on an "AS IS" BASIS,
52720 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52721 * See the License for the specific language governing permissions and
52722 * limitations under the License.
52723 * =============================================================================
52724 */
/**
 * Gradient of the Max reduction over `attrs.reductionIndices`, delegated to
 * gradForMinAndMax with the saved input and output.
 */
var maxGradConfig = {
  kernelName: Max,
  inputsToSave: ['x'],
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = saved[0];
    var maxOut = saved[1];
    var origAxes = parseAxisParam(attrs.reductionIndices, input.shape);
    var maxGrad = gradForMinAndMax(dy, maxOut, input, origAxes);
    return {
      x: function () {
        return maxGrad.x();
      }
    };
  }
};
52743
/**
 * Gradient of the element-wise Maximum kernel.
 *
 * d max(a, b) / da is 1 where a >= b and 0 elsewhere. d max(a, b) / db is 1
 * where a < b. Ties therefore send the whole upstream gradient to `a`.
 *
 * `a` and `b` may have been broadcast against each other. Each partial is
 * therefore summed over its broadcast axes and reshaped back to that input's
 * shape, the same reduction multiplyGradConfig applies. Without it, the
 * gradient of a broadcast input would come back in the broadcast output
 * shape rather than the input's own shape.
 */
var maximumGradConfig = {
  kernelName: Maximum$1,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var _saved = _slicedToArray(saved, 2),
      a = _saved[0],
      b = _saved[1];
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Sums `grad` over the axes along which an input of `shape` was
    // broadcast to `outShape`, then restores `shape`.
    var reduceToInputShape = function (grad, shape) {
      var reduceAxes = getReductionAxes(shape, outShape);
      if (reduceAxes.length > 0) {
        return reshape$3(sum$3(grad, reduceAxes), shape);
      }
      return grad;
    };
    var derA = function derA() {
      var res = mul(dy, cast$3(greaterEqual$2(a, b), 'float32'));
      return reduceToInputShape(res, a.shape);
    };
    var derB = function derB() {
      var res = mul(dy, cast$3(less$3(a, b), 'float32'));
      return reduceToInputShape(res, b.shape);
    };
    return {
      a: derA,
      b: derB
    };
  }
};
52763
52764 /**
52765 * @license
52766 * Copyright 2020 Google LLC. All Rights Reserved.
52767 * Licensed under the Apache License, Version 2.0 (the "License");
52768 * you may not use this file except in compliance with the License.
52769 * You may obtain a copy of the License at
52770 *
52771 * http://www.apache.org/licenses/LICENSE-2.0
52772 *
52773 * Unless required by applicable law or agreed to in writing, software
52774 * distributed under the License is distributed on an "AS IS" BASIS,
52775 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52776 * See the License for the specific language governing permissions and
52777 * limitations under the License.
52778 * =============================================================================
52779 */
/**
 * Computes the backprop of a 3D max pool.
 *
 * @param dy The dy error, of rank 5 of shape
 *     [batchSize, outDepth, outHeight, outWidth, channels] (the shape of
 *     `output`). It is rank 4 when `input` is rank 4.
 * @param input The original input image, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, a batch of 1
 *     is assumed. `dy`, `input` and `output` are then given a leading size-1
 *     batch dimension before the kernel runs, and the result drops it again.
 * @param output The original output image, of rank 5 of shape
 *     [batchSize, outDepth, outHeight, outWidth, channels]. It is rank 4
 *     when `input` is rank 4.
 * @param filterSize The filter size: `[filterDepth, filterHeight, filterWidth]`.
 *     If `filterSize` is a single number, then
 *     `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`. If `strides` is a single
 *     number, then `strideDepth == strideHeight == strideWidth`.
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
  var $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad');
  var $input = convertToTensor(input, 'input', 'maxPool3dGrad');
  var $output = convertToTensor(output, 'output', 'maxPool3dGrad');
  var reshapedTo5D = $input.rank === 4;
  // Prepends a size-1 batch dimension to a rank-4 tensor.
  var addBatchDim = function (t) {
    return reshape$3(t, [1, t.shape[0], t.shape[1], t.shape[2], t.shape[3]]);
  };
  var dy5D = reshapedTo5D ? addBatchDim($dy) : $dy;
  var input5D = reshapedTo5D ? addBatchDim($input) : $input;
  var output5D = reshapedTo5D ? addBatchDim($output) : $output;
  assert$1(dy5D.rank === 5, function () {
    return "Error in maxPool3dGrad: dy must be rank 5 but got rank " + "".concat(dy5D.rank, ".");
  });
  assert$1(input5D.rank === 5, function () {
    return "Error in maxPool3dGrad: input must be rank 5 but got rank " + "".concat(input5D.rank, ".");
  });
  assert$1(output5D.rank === 5, function () {
    return "Error in maxPool3dGrad: output must be rank 5 but got rank " + "".concat(output5D.rank, ".");
  });
  checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode);
  var res = ENGINE.runKernel(MaxPool3DGrad, {
    dy: dy5D,
    input: input5D,
    output: output5D
  }, {
    filterSize: filterSize,
    strides: strides,
    pad: pad,
    dimRoundingMode: dimRoundingMode
  });
  if (!reshapedTo5D) {
    return res;
  }
  // Drop the batch dimension that was added above.
  return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
var maxPool3dGrad = /* @__PURE__ */op({
  maxPool3dGrad_: maxPool3dGrad_
});
52847
/**
 * Gradient of MaxPool3D, computed by maxPool3dGrad from the saved input and
 * output. All four pooling attrs of the forward op are passed through,
 * including dimRoundingMode.
 */
var maxPool3DGradConfig$2 = {
  kernelName: MaxPool3D,
  inputsToSave: ['x'],
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var savedPair = _slicedToArray(saved, 2);
    var input = savedPair[0];
    var pooled = savedPair[1];
    var filterSize = attrs.filterSize;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dimRoundingMode = attrs.dimRoundingMode;
    var dx = function () {
      return maxPool3dGrad(dy, input, pooled, filterSize, strides, pad, dimRoundingMode);
    };
    return { x: dx };
  }
};
52867
52868 /**
52869 * @license
52870 * Copyright 2020 Google LLC. All Rights Reserved.
52871 * Licensed under the Apache License, Version 2.0 (the "License");
52872 * you may not use this file except in compliance with the License.
52873 * You may obtain a copy of the License at
52874 *
52875 * http://www.apache.org/licenses/LICENSE-2.0
52876 *
52877 * Unless required by applicable law or agreed to in writing, software
52878 * distributed under the License is distributed on an "AS IS" BASIS,
52879 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52880 * See the License for the specific language governing permissions and
52881 * limitations under the License.
52882 * =============================================================================
52883 */
/**
 * Computes the backprop of a 2D max pool.
 *
 * @param dy The dy error, of rank 4, with the same shape as `output`:
 *     [batchSize, outHeight, outWidth, channels]. Its rank must equal the
 *     rank of `input`.
 * @param input The original input image, of rank 4, of shape
 *     [batchSize, height, width, channels].
 * @param output The original output image, of rank 4, of shape
 *     [batchSize, outHeight, outWidth, channels].
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
 *     `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm used in the forward prop of the op.
 *     'same', 'valid', for more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *     https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
  var $dy = convertToTensor(dy, 'dy', 'maxPoolGrad');
  var $input = convertToTensor(input, 'input', 'maxPoolGrad');
  var $output = convertToTensor(output, 'output', 'maxPoolGrad');
  assert$1($input.rank === $dy.rank, function () {
    return "Rank of input (".concat($input.rank, ") does not match rank of dy ") + "(".concat($dy.rank, ")");
  });
  assert$1($dy.rank === 4, function () {
    return "Error in maxPoolGrad: dy must be rank 4 but got rank " + "".concat($dy.rank, ".");
  });
  assert$1($input.rank === 4, function () {
    return "Error in maxPoolGrad: input must be rank 4 but got rank " + "".concat($input.rank, ".");
  });
  checkPadOnDimRoundingMode('maxPoolGrad', pad, dimRoundingMode);
  var kernelInputs = {
    dy: $dy,
    input: $input,
    output: $output
  };
  var kernelAttrs = {
    filterSize: filterSize,
    strides: strides,
    pad: pad,
    dimRoundingMode: dimRoundingMode
  };
  return ENGINE.runKernel(MaxPoolGrad, kernelInputs, kernelAttrs);
}
var maxPoolGrad$2 = /* @__PURE__ */op({
  maxPoolGrad_: maxPoolGrad_
});
52936
/**
 * Gradient of MaxPool, computed by maxPoolGrad from the saved input `x` and
 * output `y`.
 *
 * Every pooling attr of the forward op is forwarded, including
 * `dimRoundingMode`, matching maxPool3DGradConfig for MaxPool3D.
 * maxPoolGrad_ places dimRoundingMode in the MaxPoolGrad kernel attrs. When
 * it was dropped here, the kernel always saw `undefined`, even if the forward
 * MaxPool had used a rounding mode. The backward pass could then derive
 * different pooling geometry from the forward pass.
 */
var maxPoolGradConfig$2 = {
  kernelName: MaxPool,
  inputsToSave: ['x'],
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var _saved = _slicedToArray(saved, 2),
      _x = _saved[0],
      y = _saved[1];
    var filterSize = attrs.filterSize,
      strides = attrs.strides,
      pad = attrs.pad,
      dimRoundingMode = attrs.dimRoundingMode;
    return {
      x: function x() {
        return maxPoolGrad$2(dy, _x, y, filterSize, strides, pad, dimRoundingMode);
      }
    };
  }
};
52955
/**
 * Gradient of Mean over `attrs.axis`.
 *
 * Every input element contributes 1 / reduceSize to its mean. `dy` is
 * reshaped with the reduced axes restored as size-1 dims, broadcast over
 * x's shape via multiplication by ones, and divided by the number of
 * elements reduced per output.
 */
var meanGradConfig = {
  kernelName: Mean,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var x = _slicedToArray(saved, 1)[0];
    var axes = parseAxisParam(attrs.axis, x.shape);
    var reduceSize = sizeFromShape(computeOutAndReduceShapes(x.shape, axes)[1]);
    return {
      x: function () {
        var keepDimsShape = x.shape.map(function (dim, i) {
          return axes.indexOf(i) >= 0 ? 1 : dim;
        });
        var expandedDy = reshape$3(dy, keepDimsShape);
        return div$1(mul(expandedDy, ones$1(x.shape, 'float32')), reduceSize);
      }
    };
  }
};
52981
/**
 * Gradient of the Min reduction over `attrs.axis`, delegated to
 * gradForMinAndMax with the saved input and output.
 */
var minGradConfig = {
  kernelName: Min,
  inputsToSave: ['x'],
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var axis = attrs.axis;
    var savedPair = _slicedToArray(saved, 2);
    var input = savedPair[0];
    var minOut = savedPair[1];
    var minGrad = gradForMinAndMax(dy, minOut, input, parseAxisParam(axis, input.shape));
    return {
      x: function () {
        return minGrad.x();
      }
    };
  }
};
53001
/**
 * Gradient of the element-wise Minimum kernel.
 *
 * d min(a, b) / da is 1 where a <= b and 0 elsewhere. d min(a, b) / db is 1
 * where a > b. Ties therefore send the whole upstream gradient to `a`.
 *
 * `a` and `b` may have been broadcast against each other. Each partial is
 * therefore summed over its broadcast axes and reshaped back to that input's
 * shape, the same reduction multiplyGradConfig applies. Without it, the
 * gradient of a broadcast input would come back in the broadcast output
 * shape rather than the input's own shape.
 */
var minimumGradConfig = {
  kernelName: Minimum$1,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var _saved = _slicedToArray(saved, 2),
      a = _saved[0],
      b = _saved[1];
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Sums `grad` over the axes along which an input of `shape` was
    // broadcast to `outShape`, then restores `shape`.
    var reduceToInputShape = function (grad, shape) {
      var reduceAxes = getReductionAxes(shape, outShape);
      if (reduceAxes.length > 0) {
        return reshape$3(sum$3(grad, reduceAxes), shape);
      }
      return grad;
    };
    var derA = function derA() {
      var res = mul(dy, cast$3(lessEqual$2(a, b), 'float32'));
      return reduceToInputShape(res, a.shape);
    };
    var derB = function derB() {
      var res = mul(dy, cast$3(greater$3(a, b), 'float32'));
      return reduceToInputShape(res, b.shape);
    };
    return {
      a: derA,
      b: derB
    };
  }
};
53021
53022 /**
53023 * @license
53024 * Copyright 2020 Google LLC. All Rights Reserved.
53025 * Licensed under the Apache License, Version 2.0 (the "License");
53026 * you may not use this file except in compliance with the License.
53027 * You may obtain a copy of the License at
53028 *
53029 * http://www.apache.org/licenses/LICENSE-2.0
53030 *
53031 * Unless required by applicable law or agreed to in writing, software
53032 * distributed under the License is distributed on an "AS IS" BASIS,
53033 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53034 * See the License for the specific language governing permissions and
53035 * limitations under the License.
53036 * =============================================================================
53037 */
/**
 * Gradient of MirrorPad.
 *
 * The gradient is the region of `dy` that corresponds to the original
 * tensor: it starts at each axis's "before" padding amount and has the saved
 * input's shape.
 *
 * NOTE(review): in mirror padding the border values are copies of interior
 * values. An exact gradient would therefore also accumulate the border
 * gradients back onto the mirrored interior positions, but this
 * implementation discards them. Confirm whether that approximation is
 * intended.
 */
var mirrorPadGradConfig = {
  kernelName: MirrorPad,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = saved[0];
    var begin = attrs.paddings.map(function (axisPadding) {
      return axisPadding[0];
    });
    var dx = function () {
      return slice$2(dy, begin, input.shape);
    };
    return { x: dx };
  }
};
53056
/**
 * Gradient of Mod, consistent with mod(a, b) = a - floor(a / b) * b:
 *   d/da = dy
 *   d/db = -dy * floor(a / b)
 *
 * When `a` and `b` were broadcast against each other, each partial is summed
 * over the broadcast axes and reshaped to its input's shape.
 */
var modGradConfig = {
  kernelName: Mod,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var operands = _slicedToArray(saved, 2);
    var dividend = operands[0];
    var divisor = operands[1];
    var outShape = assertAndGetBroadcastShape(dividend.shape, divisor.shape);
    // Sums `grad` over the axes along which an input of `shape` was
    // broadcast, then restores `shape`.
    var unbroadcast = function (grad, shape) {
      var axes = getReductionAxes(shape, outShape);
      return axes.length > 0 ? reshape$3(sum$3(grad, axes), shape) : grad;
    };
    return {
      a: function () {
        return unbroadcast(dy, dividend.shape);
      },
      b: function () {
        var scaled = mul(dy, neg$2(floor$2(div$1(dividend, divisor))));
        return unbroadcast(scaled, divisor.shape);
      }
    };
  }
};
53086
/**
 * Gradient of element-wise Multiply:
 *   d(a * b)/da = dy * b
 *   d(a * b)/db = dy * a
 *
 * The other operand is cast to float32 in each case. Each partial is summed
 * over the axes along which its input was broadcast and reshaped back to
 * that input's shape.
 */
var multiplyGradConfig = {
  kernelName: Multiply$1,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var operands = _slicedToArray(saved, 2);
    var lhs = operands[0];
    var rhs = operands[1];
    var outShape = assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    // Sums `grad` over the axes along which an input of `shape` was
    // broadcast, then restores `shape`.
    var unbroadcast = function (grad, shape) {
      var axes = getReductionAxes(shape, outShape);
      return axes.length > 0 ? reshape$3(sum$3(grad, axes), shape) : grad;
    };
    return {
      a: function () {
        return unbroadcast(mul(dy, cast$3(rhs, 'float32')), lhs.shape);
      },
      b: function () {
        return unbroadcast(mul(dy, cast$3(lhs, 'float32')), rhs.shape);
      }
    };
  }
};
53117
53118 /**
53119 * @license
53120 * Copyright 2020 Google LLC. All Rights Reserved.
53121 * Licensed under the Apache License, Version 2.0 (the "License");
53122 * you may not use this file except in compliance with the License.
53123 * You may obtain a copy of the License at
53124 *
53125 * http://www.apache.org/licenses/LICENSE-2.0
53126 *
53127 * Unless required by applicable law or agreed to in writing, software
53128 * distributed under the License is distributed on an "AS IS" BASIS,
53129 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53130 * See the License for the specific language governing permissions and
53131 * limitations under the License.
53132 * =============================================================================
53133 */
/** Gradient of `Neg`: the upstream gradient is negated. */
var negGradConfig = {
  kernelName: Neg,
  gradFunc: function gradFunc(dy) {
    var negatedDy = function x() {
      return neg$2(dy);
    };
    return { x: negatedDy };
  }
};
53144
53145 /**
53146 * @license
53147 * Copyright 2020 Google LLC. All Rights Reserved.
53148 * Licensed under the Apache License, Version 2.0 (the "License");
53149 * you may not use this file except in compliance with the License.
53150 * You may obtain a copy of the License at
53151 *
53152 * http://www.apache.org/licenses/LICENSE-2.0
53153 *
53154 * Unless required by applicable law or agreed to in writing, software
53155 * distributed under the License is distributed on an "AS IS" BASIS,
53156 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53157 * See the License for the specific language governing permissions and
53158 * limitations under the License.
53159 * =============================================================================
53160 */
/**
 * Gradient of `OneHot` with respect to its integer indices: a float32 tensor
 * of zeros with the saved indices' shape.
 */
var oneHotGradConfig = {
  kernelName: OneHot,
  inputsToSave: ['indices'],
  gradFunc: function gradFunc(dy, saved) {
    var savedIndices = saved[0];
    var zeroGrad = function indices() {
      return zeros$2(savedIndices.shape, 'float32');
    };
    return { indices: zeroGrad };
  }
};
53173
53174 /**
53175 * @license
53176 * Copyright 2020 Google LLC. All Rights Reserved.
53177 * Licensed under the Apache License, Version 2.0 (the "License");
53178 * you may not use this file except in compliance with the License.
53179 * You may obtain a copy of the License at
53180 *
53181 * http://www.apache.org/licenses/LICENSE-2.0
53182 *
53183 * Unless required by applicable law or agreed to in writing, software
53184 * distributed under the License is distributed on an "AS IS" BASIS,
53185 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53186 * See the License for the specific language governing permissions and
53187 * limitations under the License.
53188 * =============================================================================
53189 */
/** Gradient of `OnesLike`: always zerosLike(dy). */
var onesLikeGradConfig = {
  kernelName: OnesLike,
  gradFunc: function gradFunc(dy) {
    var zeroGrad = function x() {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
53200
53201 /**
53202 * @license
53203 * Copyright 2020 Google LLC. All Rights Reserved.
53204 * Licensed under the Apache License, Version 2.0 (the "License");
53205 * you may not use this file except in compliance with the License.
53206 * You may obtain a copy of the License at
53207 *
53208 * http://www.apache.org/licenses/LICENSE-2.0
53209 *
53210 * Unless required by applicable law or agreed to in writing, software
53211 * distributed under the License is distributed on an "AS IS" BASIS,
53212 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53213 * See the License for the specific language governing permissions and
53214 * limitations under the License.
53215 * =============================================================================
53216 */
/**
 * Gradient of `Pack` (stack): dy is unstacked along the packing axis. The
 * i-th slice becomes the gradient of the i-th input, returned as an array of
 * thunks in input order.
 */
var packGradConfig = {
  kernelName: Pack,
  saveAllInputs: true,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var slices = unstack(dy, attrs.axis);
    // Wraps one slice in a zero-argument thunk.
    var thunkOf = function (slice) {
      return function () {
        return slice;
      };
    };
    var thunks = [];
    for (var i = 0; i < slices.length; i++) {
      thunks.push(thunkOf(slices[i]));
    }
    return thunks;
  }
};
53230
53231 /**
53232 * @license
53233 * Copyright 2020 Google LLC. All Rights Reserved.
53234 * Licensed under the Apache License, Version 2.0 (the "License");
53235 * you may not use this file except in compliance with the License.
53236 * You may obtain a copy of the License at
53237 *
53238 * http://www.apache.org/licenses/LICENSE-2.0
53239 *
53240 * Unless required by applicable law or agreed to in writing, software
53241 * distributed under the License is distributed on an "AS IS" BASIS,
53242 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53243 * See the License for the specific language governing permissions and
53244 * limitations under the License.
53245 * =============================================================================
53246 */
/**
 * Gradient of `PadV2`: padding surrounds the original tensor, so the gradient
 * slices the original input's extent out of dy. The slice starts at the
 * leading pad amount of each dimension.
 */
var padV2GradConfig = {
  kernelName: PadV2,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = saved[0];
    var starts = [];
    attrs.paddings.forEach(function (pair) {
      starts.push(pair[0]);
    });
    return {
      x: function x() {
        return slice$2(dy, starts, input.shape);
      }
    };
  }
};
53265
/**
 * Gradient of `Pow` (base `a`, exponent `b`, saved output `y`).
 * d/da = dy * b * a^(b - 1), with b cast to float32.
 * d/db = dy * y * log(a), where log(a) is replaced by 0 wherever a <= 0.
 * Each partial is summed over its broadcast axes and reshaped to the
 * operand's shape.
 */
var powGradConfig = {
  kernelName: Pow,
  inputsToSave: ['a', 'b'],
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved) {
    var tensors = _slicedToArray(saved, 3);
    var base = tensors[0];
    var exp = tensors[1];
    var y = tensors[2];
    var outShape = assertAndGetBroadcastShape(base.shape, exp.shape);
    // Sums `grad` over broadcast axes (if any), then reshapes to `shape`.
    var unbroadcast = function (grad, shape) {
      var axes = getReductionAxes(shape, outShape);
      return reshape$3(axes.length > 0 ? sum$3(grad, axes) : grad, shape);
    };
    return {
      a: function derBase() {
        var expFloat = cast$3(exp, 'float32');
        var powTerm = pow$3(base, sub$2(expFloat, scalar(1)));
        return unbroadcast(mul(dy, mul(expFloat, powTerm)), base.shape);
      },
      b: function derExp() {
        var safeLog = where(greater$3(base, 0), log$2(base), zerosLike$3(base));
        return unbroadcast(mul(dy, mul(y, safeLog)), exp.shape);
      }
    };
  }
};
53303
/**
 * Gradient of `Prelu` with respect to the input and the slope `alpha`.
 * Where x > 0 the input gradient is dy; elsewhere it is dy * alpha.
 * The alpha gradient is dy * x where x <= 0 (zero elsewhere). It is summed
 * over the axes alpha was broadcast along and reshaped to alpha's shape.
 */
var preluGradConfig = {
  kernelName: Prelu,
  inputsToSave: ['x', 'alpha'],
  gradFunc: function gradFunc(dy, saved) {
    var inputs = _slicedToArray(saved, 2);
    var input = inputs[0];
    var slope = inputs[1];
    var positive = greater$3(input, 0);
    return {
      x: function x() {
        return where(positive, dy, mul(dy, slope));
      },
      alpha: function alpha() {
        var masked = where(positive, zerosLike$3(dy), mul(dy, input));
        var axes = getReductionAxes(slope.shape, dy.shape);
        var reduced = axes.length > 0 ? sum$3(masked, axes) : masked;
        return reshape$3(reduced, slope.shape);
      }
    };
  }
};
53327
// Gradient of a product reduced along the single axis `axis` of `x`.
// The local derivative is the product of the two cumprod$2 calls below
// (called with exclusive=true, forward and reversed, assuming tfjs's
// cumprod(x, axis, exclusive, reverse) signature). No division by x is
// involved. dy has the reduced axis removed, so it is reshaped with a size-1
// axis there to broadcast against that product.
function prodGradFn_(x, dy, axis) {
  var dyShape = x.shape.slice();
  dyShape[axis] = 1;
  var dyExpanded = reshape$3(dy, dyShape);
  var forwardExclusive = cumprod$2(x, axis, true, false);
  var reverseExclusive = cumprod$2(x, axis, true, true);
  return mul(dyExpanded, mul(forwardExclusive, reverseExclusive));
}
// Gradient of a product reduced over several axes at once. This is done by
// moving every reduced axis to the end (transposing only when needed) and
// flattening them into one trailing axis, so prodGradFn_ can treat them as a
// single axis. The result is then reshaped and transposed back to x's layout.
function prodsGradFn_(x, dy, axis) {
  var rank = x.shape.length;
  var numReduced = axis.length;
  var trailingAxis = rank - numReduced;
  var perm = getAxesPermutation(axis, rank);
  var needsTranspose = perm != null;
  var moved = needsTranspose ? transpose$2(x, perm) : x;
  // Split the (possibly permuted) shape into kept dims and reduced dims, then
  // collapse the reduced dims into their product.
  var collapsedShape = moved.shape.slice();
  var reducedDims = collapsedShape.splice(rank - numReduced, numReduced);
  collapsedShape.push(reducedDims.reduce(function (acc, d) {
    return acc * d;
  }, 1));
  var collapsed = moved.reshape(collapsedShape);
  var grad = prodGradFn_(collapsed, dy, trailingAxis).reshape(moved.shape);
  return needsTranspose ? transpose$2(grad, getUndoAxesPermutation(perm)) : grad;
}
// Running example for the Prod gradient: an input of shape [2, 3, 2]
//   [
//     [
//       [3.0, 4.0],
//       [5.0, 6.0],
//       [7.0, 8.0]
//     ],
//     [
//       [3.0, 5.0],
//       [0.0, 6.0],
//       [5.0, 6.0]
//     ]
//   ]
// Note the zero entry: a gradient computed as y / x would divide by zero
// there, whereas the cumulative-product formulation in prodGradFn_ involves
// no division.
/**
 * Gradient of `Prod` (product reduction) with respect to its input.
 *
 * attrs.axis may be null/undefined (reduce over every axis), a single number,
 * or an array of numbers, and negative values count from the end. The axes are
 * normalized with parseAxisParam, the same helper sumGradConfig and
 * reverseGradConfig use. In tfjs-core it expands null to all axes, wraps a
 * number in an array and maps negative axes to their positive equivalents.
 * This matters because prodsGradFn_ feeds the axes to getAxesPermutation and
 * transpose. Previously a raw negative axis such as -1 ended up in that
 * permutation and broke the gradient.
 */
var prodGradConfig = {
  kernelName: Prod,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var _saved = _slicedToArray(saved, 1),
      _x = _saved[0];
    var axisArr = parseAxisParam(attrs.axis, _x.shape);
    return {
      x: function x() {
        return prodsGradFn_(_x, dy, axisArr);
      }
    };
  }
};
53411
/**
 * Gradient of `RealDiv` (a / b).
 * d/da = dy / b; d/db = -(dy * a) / b^2. Operands are cast to float32.
 * The `a` partial is summed over a's broadcast axes. The `b` partial is summed
 * over b's broadcast axes before dividing by b^2.
 */
var divGradConfig = {
  kernelName: RealDiv,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var ab = _slicedToArray(saved, 2);
    var numerator = ab[0];
    var denominator = ab[1];
    var outShape = assertAndGetBroadcastShape(numerator.shape, denominator.shape);
    return {
      a: function derA() {
        var quotient = div$1(dy, cast$3(denominator, 'float32'));
        var axes = getReductionAxes(numerator.shape, outShape);
        return axes.length > 0 ? reshape$3(sum$3(quotient, axes), numerator.shape) : quotient;
      },
      b: function derB() {
        var scaled = mul(dy, cast$3(numerator, 'float32'));
        var axes = getReductionAxes(denominator.shape, outShape);
        if (axes.length > 0) {
          scaled = reshape$3(sum$3(scaled, axes), denominator.shape);
        }
        var denomSquared = cast$3(square$2(denominator), 'float32');
        return neg$2(div$1(scaled, denomSquared));
      }
    };
  }
};
53443
/** Gradient of `Reciprocal`: dy / (-(x^2)). */
var reciprocalGradConfig = {
  kernelName: Reciprocal,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var grad = function x() {
      var negSquared = neg$2(square$2(input));
      return div$1(dy, negSquared);
    };
    return { x: grad };
  }
};
53457
/**
 * Gradient of `Relu6`: dy gated by the mask (x <= 6) * step(x), cast to
 * float32. step$2 is presumably 1 for x > 0, giving a mask over (0, 6]; confirm.
 * The mask is built eagerly when gradFunc runs.
 */
var relu6GradConfig = {
  kernelName: Relu6$1,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var upperOk = lessEqual$2(input, 6);
    var lowerOk = step$2(input);
    var gate = mul(upperOk, lowerOk);
    return {
      x: function x() {
        var floatGate = cast$3(gate, 'float32');
        return mul(dy, floatGate);
      }
    };
  }
};
53472
/** Gradient of `Relu`: dy multiplied by step(x) cast to float32. */
var reluGradConfig = {
  kernelName: Relu$1,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function x() {
        var gate = cast$3(step$2(input), 'float32');
        return mul(dy, gate);
      }
    };
  }
};
53486
/** Gradient of `Reshape`: dy reshaped back to the input's shape. */
var reshapeGradConfig = {
  kernelName: Reshape$1,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var original = _slicedToArray(saved, 1)[0];
    var restore = function x() {
      return reshape$3(dy, original.shape);
    };
    return { x: restore };
  }
};
53500
/**
 * Gradient of `ResizeBilinear`: delegates to the ResizeBilinearGrad kernel
 * with dy, the saved original images and the forward op's attrs.
 */
var resizeBilinearGradConfig$2 = {
  kernelName: ResizeBilinear,
  inputsToSave: ['images'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var images = _slicedToArray(saved, 1)[0];
    var kernelInputs = { dy: dy, images: images };
    return {
      images: function imagesDer() {
        return ENGINE.runKernel(ResizeBilinearGrad, kernelInputs, attrs);
      }
    };
  }
};
53522
/**
 * Gradient of `ResizeNearestNeighbor`: delegates to the
 * ResizeNearestNeighborGrad kernel with dy, the saved original images and the
 * forward op's attrs.
 */
var resizeNearestNeighborGradConfig$2 = {
  kernelName: ResizeNearestNeighbor,
  inputsToSave: ['images'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var images = _slicedToArray(saved, 1)[0];
    var kernelInputs = { dy: dy, images: images };
    return {
      images: function imagesDer() {
        return ENGINE.runKernel(ResizeNearestNeighborGrad, kernelInputs, attrs);
      }
    };
  }
};
53544
53545 /**
53546 * @license
53547 * Copyright 2020 Google LLC. All Rights Reserved.
53548 * Licensed under the Apache License, Version 2.0 (the "License");
53549 * you may not use this file except in compliance with the License.
53550 * You may obtain a copy of the License at
53551 *
53552 * http://www.apache.org/licenses/LICENSE-2.0
53553 *
53554 * Unless required by applicable law or agreed to in writing, software
53555 * distributed under the License is distributed on an "AS IS" BASIS,
53556 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53557 * See the License for the specific language governing permissions and
53558 * limitations under the License.
53559 * =============================================================================
53560 */
/**
 * Gradient of `Reverse`: reversing dy along the same (normalized) axes undoes
 * the forward reversal.
 */
var reverseGradConfig = {
  kernelName: Reverse,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var axes = parseAxisParam(attrs.dims, dy.shape);
    var unreverse = function x() {
      return reverse$2(dy, axes);
    };
    return { x: unreverse };
  }
};
53573
53574 /**
53575 * @license
53576 * Copyright 2020 Google LLC. All Rights Reserved.
53577 * Licensed under the Apache License, Version 2.0 (the "License");
53578 * you may not use this file except in compliance with the License.
53579 * You may obtain a copy of the License at
53580 *
53581 * http://www.apache.org/licenses/LICENSE-2.0
53582 *
53583 * Unless required by applicable law or agreed to in writing, software
53584 * distributed under the License is distributed on an "AS IS" BASIS,
53585 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53586 * See the License for the specific language governing permissions and
53587 * limitations under the License.
53588 * =============================================================================
53589 */
/** Gradient of `Round`: always zerosLike(dy). */
var roundGradConfig = {
  kernelName: Round,
  gradFunc: function gradFunc(dy) {
    // TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
    var zeroGrad = function x() {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
53602
/** Gradient of `Rsqrt`: -dy / (2 * x^1.5). */
var rsqrtGradConfig = {
  kernelName: Rsqrt,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function x() {
        var denominator = mul(pow$3(input, 1.5), 2);
        return neg$2(div$1(dy, denominator));
      }
    };
  }
};
53616
/**
 * Gradient of `Select` (where). dy flows to `t` where the condition holds and
 * to `e` where it does not. The condition itself gets a float32 zero tensor.
 */
var selectGradConfig = {
  kernelName: Select,
  inputsToSave: ['condition'],
  gradFunc: function gradFunc(dy, saved) {
    var cond = _slicedToArray(saved, 1)[0];
    // dy multiplied by `mask` cast to dy's dtype.
    var routed = function (mask) {
      return mul(dy, cast$3(mask, dy.dtype));
    };
    return {
      // TODO(julianoks): Return null for condition gradient
      // when backprop supports it.
      condition: function condition() {
        return cast$3(zerosLike$3(cond), 'float32');
      },
      t: function t() {
        return routed(cond);
      },
      e: function e() {
        return routed(logicalNot$2(cond));
      }
    };
  }
};
53638
/**
 * Gradient of `Selu`: dy * SELU_SCALE where x > 0, otherwise
 * dy * SELU_SCALEALPHA * exp(x) (x cast to float32).
 */
var seluGradConfig = {
  kernelName: Selu$1,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    return {
      x: function x() {
        var isPositive = greater$3(input, scalar(0));
        var scaleAlphaT = scalar(SELU_SCALEALPHA);
        var scaleT = scalar(SELU_SCALE);
        var positiveBranch = mul(dy, scaleT);
        var expX = exp$2(cast$3(input, 'float32'));
        var negativeBranch = mul(mul(dy, scaleAlphaT), expX);
        return where(isPositive, positiveBranch, negativeBranch);
      }
    };
  }
};
53657
/** Gradient of `Sigmoid`, using the saved output y: dy * y * (1 - y). */
var sigmoidGradConfig = {
  kernelName: Sigmoid$1,
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved) {
    var y = _slicedToArray(saved, 1)[0];
    return {
      x: function x() {
        var oneMinusY = sub$2(scalar(1), y);
        return mul(dy, mul(y, oneMinusY));
      }
    };
  }
};
53671
53672 /**
53673 * @license
53674 * Copyright 2020 Google LLC. All Rights Reserved.
53675 * Licensed under the Apache License, Version 2.0 (the "License");
53676 * you may not use this file except in compliance with the License.
53677 * You may obtain a copy of the License at
53678 *
53679 * http://www.apache.org/licenses/LICENSE-2.0
53680 *
53681 * Unless required by applicable law or agreed to in writing, software
53682 * distributed under the License is distributed on an "AS IS" BASIS,
53683 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53684 * See the License for the specific language governing permissions and
53685 * limitations under the License.
53686 * =============================================================================
53687 */
/** Gradient of `Sign`: always zerosLike(dy). */
var signGradConfig = {
  kernelName: Sign,
  gradFunc: function gradFunc(dy) {
    var zeroGrad = function x() {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
53698
/** Gradient of `Sin`: cos(x) * dy, with x cast to float32. */
var sinGradConfig = {
  kernelName: Sin,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var grad = function x() {
      var cosine = cos$2(cast$3(input, 'float32'));
      return mul(cosine, dy);
    };
    return { x: grad };
  }
};
53712
/** Gradient of `Sinh`: cosh(x) * dy, with x cast to float32. */
var sinhGradConfig = {
  kernelName: Sinh,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var grad = function x() {
      var hyperbolicCos = cosh$2(cast$3(input, 'float32'));
      return mul(hyperbolicCos, dy);
    };
    return { x: grad };
  }
};
53726
/**
 * Gradient of `Slice`: dy is zero-padded back to the input's shape.
 * Per dimension, the padding is [begin, inputDim - begin - size], using the
 * begin/size values normalized by parseSliceParams.
 */
var sliceGradConfig = {
  kernelName: Slice,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = _slicedToArray(saved, 1)[0];
    var inputShape = input.shape;
    var parsed = _slicedToArray(parseSliceParams(input, attrs.begin, attrs.size), 2);
    var starts = parsed[0];
    var sizes = parsed[1];
    var paddings = [];
    var dim = 0;
    while (dim < dy.rank) {
      var before = starts[dim];
      var after = inputShape[dim] - before - sizes[dim];
      paddings.push([before, after]);
      dim++;
    }
    return {
      x: function x() {
        return pad(dy, paddings);
      }
    };
  }
};
53756
/**
 * Gradient of `Softmax`, using the saved output y:
 * dy*y - sum(dy*y, [dim], keepDims=true) * y.
 * The product dy*y is computed eagerly when gradFunc runs.
 */
var softmaxGradConfig = {
  kernelName: Softmax$2,
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var y = _slicedToArray(saved, 1)[0];
    var reduceDim = attrs.dim;
    var product = mul(dy, y);
    return {
      logits: function logits() {
        var summed = sum$3(product, [reduceDim], true);
        return sub$2(product, mul(summed, y));
      }
    };
  }
};
53773
/** Gradient of `Softplus`: dy * sigmoid(x). */
var softplusGradConfig = {
  kernelName: Softplus$1,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var grad = function x() {
      var gate = sigmoid$2(input);
      return mul(dy, gate);
    };
    return { x: grad };
  }
};
53787
53788 /**
53789 * @license
53790 * Copyright 2020 Google LLC. All Rights Reserved.
53791 * Licensed under the Apache License, Version 2.0 (the "License");
53792 * you may not use this file except in compliance with the License.
53793 * You may obtain a copy of the License at
53794 *
53795 * http://www.apache.org/licenses/LICENSE-2.0
53796 *
53797 * Unless required by applicable law or agreed to in writing, software
53798 * distributed under the License is distributed on an "AS IS" BASIS,
53799 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53800 * See the License for the specific language governing permissions and
53801 * limitations under the License.
53802 * =============================================================================
53803 */
/**
 * Gradient of `SpaceToBatchND`: batchToSpaceND applied to dy with the same
 * blockShape and paddings.
 */
var spaceToBatchNDGradConfig = {
  kernelName: SpaceToBatchND,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var blockShape = attrs.blockShape;
    var paddings = attrs.paddings;
    var inverse = function x() {
      return batchToSpaceND$2(dy, blockShape, paddings);
    };
    return { x: inverse };
  }
};
53816
53817 /**
53818 * @license
53819 * Copyright 2020 Google LLC. All Rights Reserved.
53820 * Licensed under the Apache License, Version 2.0 (the "License");
53821 * you may not use this file except in compliance with the License.
53822 * You may obtain a copy of the License at
53823 *
53824 * http://www.apache.org/licenses/LICENSE-2.0
53825 *
53826 * Unless required by applicable law or agreed to in writing, software
53827 * distributed under the License is distributed on an "AS IS" BASIS,
53828 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53829 * See the License for the specific language governing permissions and
53830 * limitations under the License.
53831 * =============================================================================
53832 */
/**
 * Gradient of `SplitV`: the per-piece gradients in `dy` are concatenated
 * along the split axis.
 */
var splitVGradConfig = {
  kernelName: SplitV,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var splitAxis = attrs.axis;
    var join = function x() {
      return concat$2(dy, splitAxis);
    };
    return { x: join };
  }
};
53844
/** Gradient of `Sqrt`: dy / (2 * sqrt(x)), with x cast to float32. */
var sqrtGradConfig = {
  kernelName: Sqrt,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var grad = function x() {
      var twoRoot = mul(sqrt$2(cast$3(input, 'float32')), 2);
      return div$1(dy, twoRoot);
    };
    return { x: grad };
  }
};
53858
/** Gradient of `Square`: dy * (2 * x), with x cast to float32. */
var squareGradConfig = {
  kernelName: Square,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var grad = function x() {
      var twoX = mul(cast$3(input, 'float32'), 2);
      return mul(dy, twoX);
    };
    return { x: grad };
  }
};
53872
/**
 * Gradient of `SquaredDifference`, i.e. (a - b)^2.
 *
 * d/da = dy * 2 * (a - b) and d/db = dy * 2 * (b - a). The operands may have
 * been broadcast to a common output shape. Each partial is therefore summed
 * over the axes that were broadcast for that operand and reshaped back to the
 * operand's own shape, the same treatment Sub/Mul/Div/Mod/Pow use.
 * Previously the broadcast-shaped partial was returned unchanged, so its shape
 * did not match the operand whenever broadcasting had occurred.
 */
var squaredDifferenceGradConfig = {
  kernelName: SquaredDifference,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var _saved = _slicedToArray(saved, 2),
      a = _saved[0],
      b = _saved[1];
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    var two = scalar(2);
    // Sums `grad` over the axes broadcast to reach `outShape`, then reshapes
    // it to `shape`. This is a plain reshape when nothing was broadcast.
    var reduceToShape = function reduceToShape(grad, shape) {
      var reduceAxes = getReductionAxes(shape, outShape);
      var reduced = reduceAxes.length > 0 ? sum$3(grad, reduceAxes) : grad;
      return reshape$3(reduced, shape);
    };
    var derA = function derA() {
      return reduceToShape(mul(dy, mul(two, sub$2(a, b))), a.shape);
    };
    var derB = function derB() {
      return reduceToShape(mul(dy, mul(two, sub$2(b, a))), b.shape);
    };
    return {
      a: derA,
      b: derB
    };
  }
};
53893
53894 /**
53895 * @license
53896 * Copyright 2020 Google LLC. All Rights Reserved.
53897 * Licensed under the Apache License, Version 2.0 (the "License");
53898 * you may not use this file except in compliance with the License.
53899 * You may obtain a copy of the License at
53900 *
53901 * http://www.apache.org/licenses/LICENSE-2.0
53902 *
53903 * Unless required by applicable law or agreed to in writing, software
53904 * distributed under the License is distributed on an "AS IS" BASIS,
53905 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53906 * See the License for the specific language governing permissions and
53907 * limitations under the License.
53908 * =============================================================================
53909 */
/** Gradient of `Step`: always zerosLike(dy). */
var stepGradConfig = {
  kernelName: Step,
  gradFunc: function gradFunc(dy) {
    // TODO(manrajgrover): Return null for gradients when backprop supports
    // it.
    var zeroGrad = function x() {
      return zerosLike$3(dy);
    };
    return { x: zeroGrad };
  }
};
53922
/**
 * Gradient of `Sub` (a - b): `a` receives dy and `b` receives -dy. Each is
 * first summed over the axes broadcast for that operand, then reshaped to the
 * operand's own shape.
 */
var subGradConfig = {
  kernelName: Sub,
  inputsToSave: ['a', 'b'],
  gradFunc: function gradFunc(dy, saved) {
    var operands = _slicedToArray(saved, 2);
    var a = operands[0];
    var b = operands[1];
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // dy summed over the axes that were broadcast to produce `shape`.
    function reduceDyTo(shape) {
      var axes = getReductionAxes(shape, outShape);
      return axes.length > 0 ? sum$3(dy, axes) : dy;
    }
    return {
      a: function derA() {
        return reshape$3(reduceDyTo(a.shape), a.shape);
      },
      b: function derB() {
        return reshape$3(neg$2(reduceDyTo(b.shape)), b.shape);
      }
    };
  }
};
53953
/**
 * Gradient of `Sum`: dy is reshaped so every reduced axis has size 1, then
 * broadcast to x's shape by multiplying with a float32 ones tensor.
 * The result is computed eagerly; the returned thunk just hands it back.
 */
var sumGradConfig = {
  kernelName: Sum,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = _slicedToArray(saved, 1)[0];
    var reducedAxes = parseAxisParam(attrs.axis, input.shape);
    var keptDimsShape = input.shape.map(function (size, dim) {
      return reducedAxes.indexOf(dim) === -1 ? size : 1;
    });
    var dyExpanded = reshape$3(dy, keptDimsShape);
    var derX = mul(dyExpanded, ones$1(input.shape, 'float32'));
    return {
      x: function x() {
        return derX;
      }
    };
  }
};
53975
/** Gradient of `Tan`: dy / cos(x)^2. */
var tanGradConfig = {
  kernelName: Tan,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved) {
    var input = _slicedToArray(saved, 1)[0];
    var grad = function x() {
      var cosSquared = square$2(cos$2(input));
      return div$1(dy, cosSquared);
    };
    return { x: grad };
  }
};
53989
/** Gradient of `Tanh`, using the saved output y: (1 - y^2) * dy. */
var tanhGradConfig = {
  kernelName: Tanh$1,
  outputsToSave: [true],
  gradFunc: function gradFunc(dy, saved) {
    var y = _slicedToArray(saved, 1)[0];
    return {
      x: function x() {
        var oneMinusYSq = sub$2(scalar(1), square$2(y));
        return mul(oneMinusYSq, dy);
      }
    };
  }
};
54003
/**
 * Gradient of `Tile`: every repetition of x inside dy is sliced out and
 * summed into an accumulator that starts as zerosLike(x). Repetitions are
 * visited in row-major order of their repetition indices, with the last
 * dimension varying fastest. Only ranks 1-4 are supported; other ranks throw.
 */
var tileGradConfig = {
  kernelName: Tile,
  inputsToSave: ['x'],
  gradFunc: function gradFunc(dy, saved, attrs) {
    var input = _slicedToArray(saved, 1)[0];
    var reps = attrs.reps;
    var derX = function derX() {
      var shape = input.shape;
      var acc = zerosLike$3(input);
      // Adds the dy tile whose per-dimension repetition index is `repIdx`.
      // TODO(cais): Maybe reduce memory footprint by avoiding repeated
      // slicing.
      var addTile = function (repIdx) {
        var begin = repIdx.map(function (r, d) {
          return r * shape[d];
        });
        acc = add$3(acc, slice$2(dy, begin, shape.slice(0, repIdx.length)));
      };
      var a, b, c, d;
      switch (input.rank) {
        case 1:
          for (a = 0; a < reps[0]; ++a) {
            addTile([a]);
          }
          break;
        case 2:
          for (a = 0; a < reps[0]; ++a) {
            for (b = 0; b < reps[1]; ++b) {
              addTile([a, b]);
            }
          }
          break;
        case 3:
          for (a = 0; a < reps[0]; ++a) {
            for (b = 0; b < reps[1]; ++b) {
              for (c = 0; c < reps[2]; ++c) {
                addTile([a, b, c]);
              }
            }
          }
          break;
        case 4:
          for (a = 0; a < reps[0]; ++a) {
            for (b = 0; b < reps[1]; ++b) {
              for (c = 0; c < reps[2]; ++c) {
                for (d = 0; d < reps[3]; ++d) {
                  addTile([a, b, c, d]);
                }
              }
            }
          }
          break;
        default:
          throw new Error("Gradient for tile operation is not implemented for rank-" + "".concat(input.rank, " tensors yet."));
      }
      return acc;
    };
    return {
      x: derX
    };
  }
};
54053
54054 /**
54055 * @license
54056 * Copyright 2020 Google LLC. All Rights Reserved.
54057 * Licensed under the Apache License, Version 2.0 (the "License");
54058 * you may not use this file except in compliance with the License.
54059 * You may obtain a copy of the License at
54060 *
54061 * http://www.apache.org/licenses/LICENSE-2.0
54062 *
54063 * Unless required by applicable law or agreed to in writing, software
54064 * distributed under the License is distributed on an "AS IS" BASIS,
54065 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54066 * See the License for the specific language governing permissions and
54067 * limitations under the License.
54068 * =============================================================================
54069 */
/** Gradient of `Transpose`: dy transposed by the inverse permutation. */
var transposeGradConfig = {
  kernelName: Transpose,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var inversePerm = getUndoAxesPermutation(attrs.perm);
    var untranspose = function x() {
      return transpose$2(dy, inversePerm);
    };
    return { x: untranspose };
  }
};
54083
54084 /**
54085 * @license
54086 * Copyright 2020 Google Inc. All Rights Reserved.
54087 * Licensed under the Apache License, Version 2.0 (the "License");
54088 * you may not use this file except in compliance with the License.
54089 * You may obtain a copy of the License at
54090 *
54091 * http://www.apache.org/licenses/LICENSE-2.0
54092 *
54093 * Unless required by applicable law or agreed to in writing, software
54094 * distributed under the License is distributed on an "AS IS" BASIS,
54095 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54096 * See the License for the specific language governing permissions and
54097 * limitations under the License.
54098 * =============================================================================
54099 */
/**
 * Gradient of `Unpack` (unstack): the per-output gradients in `dy` are
 * stacked back together along the unpack axis.
 */
var unpackGradConfig = {
  kernelName: Unpack,
  gradFunc: function gradFunc(dy, saved, attrs) {
    var stackAxis = attrs.axis;
    var restack = function value() {
      return stack(dy, stackAxis);
    };
    return { value: restack };
  }
};
54112
/**
 * Gradient of `UnsortedSegmentSum`: dy is gathered back by segment id. Inputs
 * with a negative segment id receive zeros (see gatherDropNegatives).
 */
var unsortedSegmentSumGradConfig = {
  kernelName: UnsortedSegmentSum,
  inputsToSave: ['segmentIds'],
  gradFunc: function gradFunc(dy, saved) {
    var ids = _slicedToArray(saved, 1)[0];
    return {
      x: function derX() {
        return gatherDropNegatives(dy, ids);
      }
    };
  }
};
54127 function gatherDropNegatives(x, indices) {
54128 // Helper function for unsorted segment ops. Gathers params for
54129 // positive segment ids and gathers 0 for inputs with negative segment id.
54130 // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py
54131 var zeroClippedIndices = maximum$4(indices, zerosLike$3(indices));
54132 var gathered = gather$1(x, zeroClippedIndices);
54133 var isPositive = greaterEqual$2(indices, scalar(0, 'int32'));
54134 var numIters = gathered.rank - isPositive.rank;
54135 for (var i = 0; i < numIters; ++i) {
54136 isPositive = expandDims$3(isPositive, i + 1);
54137 }
54138 isPositive = logicalAnd$2(isPositive, ones$1(gathered.shape, 'bool'));
54139 var zeroSlice = zerosLike$3(gathered);
54140 return where(isPositive, gathered, zeroSlice);
54141 }
54142
54143 /**
54144 * @license
54145 * Copyright 2020 Google LLC. All Rights Reserved.
54146 * Licensed under the Apache License, Version 2.0 (the "License");
54147 * you may not use this file except in compliance with the License.
54148 * You may obtain a copy of the License at
54149 *
54150 * http://www.apache.org/licenses/LICENSE-2.0
54151 *
54152 * Unless required by applicable law or agreed to in writing, software
54153 * distributed under the License is distributed on an "AS IS" BASIS,
54154 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54155 * See the License for the specific language governing permissions and
54156 * limitations under the License.
54157 * =============================================================================
54158 */
54159 var zerosLikeGradConfig = {
54160 kernelName: ZerosLike,
54161 gradFunc: function gradFunc(dy) {
54162 return {
54163 x: function x() {
54164 return zerosLike$3(dy);
54165 }
54166 };
54167 }
54168 };
54169
54170 /**
54171 * @license
54172 * Copyright 2020 Google LLC. All Rights Reserved.
54173 * Licensed under the Apache License, Version 2.0 (the "License");
54174 * you may not use this file except in compliance with the License.
54175 * You may obtain a copy of the License at
54176 *
54177 * http://www.apache.org/licenses/LICENSE-2.0
54178 *
54179 * Unless required by applicable law or agreed to in writing, software
54180 * distributed under the License is distributed on an "AS IS" BASIS,
54181 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54182 * See the License for the specific language governing permissions and
54183 * limitations under the License.
54184 * =============================================================================
54185 */
54186 // Export all kernel configs here so that the package can auto register them
54187 var gradConfigs = [absGradConfig, acosGradConfig, acoshGradConfig, addGradConfig, addNGradConfig, argMaxGradConfig, argMinGradConfig, asinGradConfig, asinhGradConfig, atan2GradConfig, atanGradConfig, atanhGradConfig, avgPool3DGradConfig$2, avgPoolGradConfig$2, batchMatMulGradConfig, batchToSpaceNDGradConfig, broadcastToGradConfig, castGradConfig, ceilGradConfig, clipByValueGradConfig, complexAbsGradConfig, concatGradConfig, conv2DBackpropInputGradConfig, conv2DGradConfig, conv3DGradConfig, cosGradConfig, coshGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, dilation2dGradConfig, divGradConfig, eluGradConfig$2, erfGradConfig, expGradConfig, expandDimsGradConfig, expm1GradConfig, floorDivGradConfig, floorGradConfig, fusedBatchNormGradConfig, gatherGradConfig, greaterEqualGradConfig, identityGradConfig, isFiniteGradConfig, isInfGradConfig, isNanGradConfig, leakyReluGradConfig, log1pGradConfig, logGradConfig, logSoftmaxGradConfig, lrnGradConfig, maxGradConfig, maxGradConfig, maximumGradConfig, maxPool3DGradConfig$2, maxPoolGradConfig$2, meanGradConfig, minGradConfig, minimumGradConfig, mirrorPadGradConfig, modGradConfig, multiplyGradConfig, negGradConfig, oneHotGradConfig, onesLikeGradConfig, packGradConfig, padV2GradConfig, padV2GradConfig, powGradConfig, preluGradConfig, prodGradConfig, reciprocalGradConfig, relu6GradConfig, reluGradConfig, reshapeGradConfig, resizeBilinearGradConfig$2, resizeNearestNeighborGradConfig$2, reverseGradConfig, roundGradConfig, rsqrtGradConfig, selectGradConfig, seluGradConfig, sigmoidGradConfig, signGradConfig, sinGradConfig, sinhGradConfig, sliceGradConfig, softmaxGradConfig, softplusGradConfig, spaceToBatchNDGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, splitVGradConfig, sqrtGradConfig, squaredDifferenceGradConfig, squareGradConfig, stepGradConfig, subGradConfig, sumGradConfig, tanGradConfig, tanhGradConfig, tileGradConfig, transposeGradConfig, unpackGradConfig, unsortedSegmentSumGradConfig, 
zerosLikeGradConfig];
54188 for (var _i$2 = 0, _gradConfigs = gradConfigs; _i$2 < _gradConfigs.length; _i$2++) {
54189 var gradientConfig = _gradConfigs[_i$2];
54190 registerGradient(gradientConfig);
54191 }
54192
54193 /**
54194 * @license
54195 * Copyright 2020 Google LLC. All Rights Reserved.
54196 * Licensed under the Apache License, Version 2.0 (the "License");
54197 * you may not use this file except in compliance with the License.
54198 * You may obtain a copy of the License at
54199 *
54200 * http://www.apache.org/licenses/LICENSE-2.0
54201 *
54202 * Unless required by applicable law or agreed to in writing, software
54203 * distributed under the License is distributed on an "AS IS" BASIS,
54204 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54205 * See the License for the specific language governing permissions and
54206 * limitations under the License.
54207 * =============================================================================
54208 */
54209 getGlobalTensorClass().prototype.abs = function () {
54210 this.throwIfDisposed();
54211 return abs$2(this);
54212 };
54213
54214 /**
54215 * @license
54216 * Copyright 2020 Google LLC. All Rights Reserved.
54217 * Licensed under the Apache License, Version 2.0 (the "License");
54218 * you may not use this file except in compliance with the License.
54219 * You may obtain a copy of the License at
54220 *
54221 * http://www.apache.org/licenses/LICENSE-2.0
54222 *
54223 * Unless required by applicable law or agreed to in writing, software
54224 * distributed under the License is distributed on an "AS IS" BASIS,
54225 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54226 * See the License for the specific language governing permissions and
54227 * limitations under the License.
54228 * =============================================================================
54229 */
54230 getGlobalTensorClass().prototype.acos = function () {
54231 this.throwIfDisposed();
54232 return acos$2(this);
54233 };
54234
54235 /**
54236 * @license
54237 * Copyright 2020 Google LLC. All Rights Reserved.
54238 * Licensed under the Apache License, Version 2.0 (the "License");
54239 * you may not use this file except in compliance with the License.
54240 * You may obtain a copy of the License at
54241 *
54242 * http://www.apache.org/licenses/LICENSE-2.0
54243 *
54244 * Unless required by applicable law or agreed to in writing, software
54245 * distributed under the License is distributed on an "AS IS" BASIS,
54246 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54247 * See the License for the specific language governing permissions and
54248 * limitations under the License.
54249 * =============================================================================
54250 */
54251 getGlobalTensorClass().prototype.acosh = function () {
54252 this.throwIfDisposed();
54253 return acosh$2(this);
54254 };
54255
54256 /**
54257 * @license
54258 * Copyright 2020 Google LLC. All Rights Reserved.
54259 * Licensed under the Apache License, Version 2.0 (the "License");
54260 * you may not use this file except in compliance with the License.
54261 * You may obtain a copy of the License at
54262 *
54263 * http://www.apache.org/licenses/LICENSE-2.0
54264 *
54265 * Unless required by applicable law or agreed to in writing, software
54266 * distributed under the License is distributed on an "AS IS" BASIS,
54267 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54268 * See the License for the specific language governing permissions and
54269 * limitations under the License.
54270 * =============================================================================
54271 */
54272 getGlobalTensorClass().prototype.add = function (b) {
54273 this.throwIfDisposed();
54274 return add$3(this, b);
54275 };
54276
54277 /**
54278 * @license
54279 * Copyright 2020 Google LLC. All Rights Reserved.
54280 * Licensed under the Apache License, Version 2.0 (the "License");
54281 * you may not use this file except in compliance with the License.
54282 * You may obtain a copy of the License at
54283 *
54284 * http://www.apache.org/licenses/LICENSE-2.0
54285 *
54286 * Unless required by applicable law or agreed to in writing, software
54287 * distributed under the License is distributed on an "AS IS" BASIS,
54288 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54289 * See the License for the specific language governing permissions and
54290 * limitations under the License.
54291 * =============================================================================
54292 */
54293 getGlobalTensorClass().prototype.all = function (axis, keepDims) {
54294 this.throwIfDisposed();
54295 return all$2(this, axis, keepDims);
54296 };
54297
54298 /**
54299 * @license
54300 * Copyright 2020 Google LLC. All Rights Reserved.
54301 * Licensed under the Apache License, Version 2.0 (the "License");
54302 * you may not use this file except in compliance with the License.
54303 * You may obtain a copy of the License at
54304 *
54305 * http://www.apache.org/licenses/LICENSE-2.0
54306 *
54307 * Unless required by applicable law or agreed to in writing, software
54308 * distributed under the License is distributed on an "AS IS" BASIS,
54309 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54310 * See the License for the specific language governing permissions and
54311 * limitations under the License.
54312 * =============================================================================
54313 */
54314 getGlobalTensorClass().prototype.any = function (axis, keepDims) {
54315 this.throwIfDisposed();
54316 return any$2(this, axis, keepDims);
54317 };
54318
54319 /**
54320 * @license
54321 * Copyright 2020 Google LLC. All Rights Reserved.
54322 * Licensed under the Apache License, Version 2.0 (the "License");
54323 * you may not use this file except in compliance with the License.
54324 * You may obtain a copy of the License at
54325 *
54326 * http://www.apache.org/licenses/LICENSE-2.0
54327 *
54328 * Unless required by applicable law or agreed to in writing, software
54329 * distributed under the License is distributed on an "AS IS" BASIS,
54330 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54331 * See the License for the specific language governing permissions and
54332 * limitations under the License.
54333 * =============================================================================
54334 */
54335 getGlobalTensorClass().prototype.argMax = function (axis) {
54336 this.throwIfDisposed();
54337 return argMax$2(this, axis);
54338 };
54339
54340 /**
54341 * @license
54342 * Copyright 2020 Google LLC. All Rights Reserved.
54343 * Licensed under the Apache License, Version 2.0 (the "License");
54344 * you may not use this file except in compliance with the License.
54345 * You may obtain a copy of the License at
54346 *
54347 * http://www.apache.org/licenses/LICENSE-2.0
54348 *
54349 * Unless required by applicable law or agreed to in writing, software
54350 * distributed under the License is distributed on an "AS IS" BASIS,
54351 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54352 * See the License for the specific language governing permissions and
54353 * limitations under the License.
54354 * =============================================================================
54355 */
54356 getGlobalTensorClass().prototype.argMin = function (axis) {
54357 this.throwIfDisposed();
54358 return argMin$2(this, axis);
54359 };
54360
54361 /**
54362 * @license
54363 * Copyright 2020 Google LLC. All Rights Reserved.
54364 * Licensed under the Apache License, Version 2.0 (the "License");
54365 * you may not use this file except in compliance with the License.
54366 * You may obtain a copy of the License at
54367 *
54368 * http://www.apache.org/licenses/LICENSE-2.0
54369 *
54370 * Unless required by applicable law or agreed to in writing, software
54371 * distributed under the License is distributed on an "AS IS" BASIS,
54372 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54373 * See the License for the specific language governing permissions and
54374 * limitations under the License.
54375 * =============================================================================
54376 */
54377 /**
54378 * Converts a size-1 `tf.Tensor` to a `tf.Scalar`.
54379 * @doc {heading: 'Tensors', subheading: 'Classes'}
54380 */
54381 getGlobalTensorClass().prototype.asScalar = function () {
54382 this.throwIfDisposed();
54383 assert$1(this.size === 1, function () {
54384 return 'The array must have only 1 element.';
54385 });
54386 return reshape$3(this, []);
54387 };
54388
54389 /**
54390 * @license
54391 * Copyright 2020 Google LLC. All Rights Reserved.
54392 * Licensed under the Apache License, Version 2.0 (the "License");
54393 * you may not use this file except in compliance with the License.
54394 * You may obtain a copy of the License at
54395 *
54396 * http://www.apache.org/licenses/LICENSE-2.0
54397 *
54398 * Unless required by applicable law or agreed to in writing, software
54399 * distributed under the License is distributed on an "AS IS" BASIS,
54400 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54401 * See the License for the specific language governing permissions and
54402 * limitations under the License.
54403 * =============================================================================
54404 */
54405 /**
54406 * Casts a `tf.Tensor` to a specified dtype.
54407 *
54408 * @param dtype Data-type to cast the tensor to.
54409 *
54410 * @doc {heading: 'Tensors', subheading: 'Classes'}
54411 */
54412 getGlobalTensorClass().prototype.asType = function (dtype) {
54413 this.throwIfDisposed();
54414 return cast$3(this, dtype);
54415 };
54416
54417 /**
54418 * @license
54419 * Copyright 2020 Google LLC. All Rights Reserved.
54420 * Licensed under the Apache License, Version 2.0 (the "License");
54421 * you may not use this file except in compliance with the License.
54422 * You may obtain a copy of the License at
54423 *
54424 * http://www.apache.org/licenses/LICENSE-2.0
54425 *
54426 * Unless required by applicable law or agreed to in writing, software
54427 * distributed under the License is distributed on an "AS IS" BASIS,
54428 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54429 * See the License for the specific language governing permissions and
54430 * limitations under the License.
54431 * =============================================================================
54432 */
54433 /**
54434 * Converts a `tf.Tensor` to a `tf.Tensor1D`.
54435 * @doc {heading: 'Tensors', subheading: 'Classes'}
54436 */
54437 getGlobalTensorClass().prototype.as1D = function () {
54438 this.throwIfDisposed();
54439 return reshape$3(this, [this.size]);
54440 };
54441
54442 /**
54443 * @license
54444 * Copyright 2020 Google LLC. All Rights Reserved.
54445 * Licensed under the Apache License, Version 2.0 (the "License");
54446 * you may not use this file except in compliance with the License.
54447 * You may obtain a copy of the License at
54448 *
54449 * http://www.apache.org/licenses/LICENSE-2.0
54450 *
54451 * Unless required by applicable law or agreed to in writing, software
54452 * distributed under the License is distributed on an "AS IS" BASIS,
54453 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54454 * See the License for the specific language governing permissions and
54455 * limitations under the License.
54456 * =============================================================================
54457 */
54458 /**
54459 * Converts a `tf.Tensor` to a `tf.Tensor2D`.
54460 *
54461 * @param rows Number of rows in `tf.Tensor2D`.
54462 * @param columns Number of columns in `tf.Tensor2D`.
54463 * @doc {heading: 'Tensors', subheading: 'Classes'}
54464 */
54465 getGlobalTensorClass().prototype.as2D = function (rows, columns) {
54466 this.throwIfDisposed();
54467 return reshape$3(this, [rows, columns]);
54468 };
54469
54470 /**
54471 * @license
54472 * Copyright 2020 Google LLC. All Rights Reserved.
54473 * Licensed under the Apache License, Version 2.0 (the "License");
54474 * you may not use this file except in compliance with the License.
54475 * You may obtain a copy of the License at
54476 *
54477 * http://www.apache.org/licenses/LICENSE-2.0
54478 *
54479 * Unless required by applicable law or agreed to in writing, software
54480 * distributed under the License is distributed on an "AS IS" BASIS,
54481 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54482 * See the License for the specific language governing permissions and
54483 * limitations under the License.
54484 * =============================================================================
54485 */
54486 /**
54487 * Converts a `tf.Tensor` to a `tf.Tensor3D`.
54488 *
54489 * @param rows Number of rows in `tf.Tensor3D`.
54490 * @param columns Number of columns in `tf.Tensor3D`.
54491 * @param depth Depth of `tf.Tensor3D`.
54492 * @doc {heading: 'Tensors', subheading: 'Classes'}
54493 */
54494 getGlobalTensorClass().prototype.as3D = function (rows, columns, depth) {
54495 this.throwIfDisposed();
54496 return reshape$3(this, [rows, columns, depth]);
54497 };
54498
54499 /**
54500 * @license
54501 * Copyright 2020 Google LLC. All Rights Reserved.
54502 * Licensed under the Apache License, Version 2.0 (the "License");
54503 * you may not use this file except in compliance with the License.
54504 * You may obtain a copy of the License at
54505 *
54506 * http://www.apache.org/licenses/LICENSE-2.0
54507 *
54508 * Unless required by applicable law or agreed to in writing, software
54509 * distributed under the License is distributed on an "AS IS" BASIS,
54510 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54511 * See the License for the specific language governing permissions and
54512 * limitations under the License.
54513 * =============================================================================
54514 */
54515 /**
54516 * Converts a `tf.Tensor` to a `tf.Tensor4D`.
54517 *
54518 * @param rows Number of rows in `tf.Tensor4D`.
54519 * @param columns Number of columns in `tf.Tensor4D`.
54520 * @param depth Depth of `tf.Tensor4D`.
54521 * @param depth2 4th dimension of `tf.Tensor4D`.
54522 * @doc {heading: 'Tensors', subheading: 'Classes'}
54523 */
54524 getGlobalTensorClass().prototype.as4D = function (rows, columns, depth, depth2) {
54525 this.throwIfDisposed();
54526 return reshape$3(this, [rows, columns, depth, depth2]);
54527 };
54528
54529 /**
54530 * @license
54531 * Copyright 2020 Google LLC. All Rights Reserved.
54532 * Licensed under the Apache License, Version 2.0 (the "License");
54533 * you may not use this file except in compliance with the License.
54534 * You may obtain a copy of the License at
54535 *
54536 * http://www.apache.org/licenses/LICENSE-2.0
54537 *
54538 * Unless required by applicable law or agreed to in writing, software
54539 * distributed under the License is distributed on an "AS IS" BASIS,
54540 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54541 * See the License for the specific language governing permissions and
54542 * limitations under the License.
54543 * =============================================================================
54544 */
54545 /**
54546 * Converts a `tf.Tensor` to a `tf.Tensor5D`.
54547 *
54548 * @param rows Number of rows in `tf.Tensor5D`.
54549 * @param columns Number of columns in `tf.Tensor5D`.
54550 * @param depth Depth of `tf.Tensor5D`.
54551 * @param depth2 4th dimension of `tf.Tensor5D`.
54552 * @param depth3 5th dimension of 'tf.Tensor5D'
54553 *
54554 * @doc {heading: 'Tensors', subheading: 'Classes'}
54555 */
54556 getGlobalTensorClass().prototype.as5D = function (rows, columns, depth, depth2, depth3) {
54557 this.throwIfDisposed();
54558 return reshape$3(this, [rows, columns, depth, depth2, depth3]);
54559 };
54560
54561 /**
54562 * @license
54563 * Copyright 2020 Google LLC. All Rights Reserved.
54564 * Licensed under the Apache License, Version 2.0 (the "License");
54565 * you may not use this file except in compliance with the License.
54566 * You may obtain a copy of the License at
54567 *
54568 * http://www.apache.org/licenses/LICENSE-2.0
54569 *
54570 * Unless required by applicable law or agreed to in writing, software
54571 * distributed under the License is distributed on an "AS IS" BASIS,
54572 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54573 * See the License for the specific language governing permissions and
54574 * limitations under the License.
54575 * =============================================================================
54576 */
54577 getGlobalTensorClass().prototype.asin = function () {
54578 this.throwIfDisposed();
54579 return asin$2(this);
54580 };
54581
54582 /**
54583 * @license
54584 * Copyright 2020 Google LLC. All Rights Reserved.
54585 * Licensed under the Apache License, Version 2.0 (the "License");
54586 * you may not use this file except in compliance with the License.
54587 * You may obtain a copy of the License at
54588 *
54589 * http://www.apache.org/licenses/LICENSE-2.0
54590 *
54591 * Unless required by applicable law or agreed to in writing, software
54592 * distributed under the License is distributed on an "AS IS" BASIS,
54593 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54594 * See the License for the specific language governing permissions and
54595 * limitations under the License.
54596 * =============================================================================
54597 */
54598 getGlobalTensorClass().prototype.asinh = function () {
54599 this.throwIfDisposed();
54600 return asinh$2(this);
54601 };
54602
54603 /**
54604 * @license
54605 * Copyright 2020 Google LLC. All Rights Reserved.
54606 * Licensed under the Apache License, Version 2.0 (the "License");
54607 * you may not use this file except in compliance with the License.
54608 * You may obtain a copy of the License at
54609 *
54610 * http://www.apache.org/licenses/LICENSE-2.0
54611 *
54612 * Unless required by applicable law or agreed to in writing, software
54613 * distributed under the License is distributed on an "AS IS" BASIS,
54614 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54615 * See the License for the specific language governing permissions and
54616 * limitations under the License.
54617 * =============================================================================
54618 */
54619 getGlobalTensorClass().prototype.atan = function () {
54620 this.throwIfDisposed();
54621 return atan$2(this);
54622 };
54623
54624 /**
54625 * @license
54626 * Copyright 2020 Google LLC. All Rights Reserved.
54627 * Licensed under the Apache License, Version 2.0 (the "License");
54628 * you may not use this file except in compliance with the License.
54629 * You may obtain a copy of the License at
54630 *
54631 * http://www.apache.org/licenses/LICENSE-2.0
54632 *
54633 * Unless required by applicable law or agreed to in writing, software
54634 * distributed under the License is distributed on an "AS IS" BASIS,
54635 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54636 * See the License for the specific language governing permissions and
54637 * limitations under the License.
54638 * =============================================================================
54639 */
54640 getGlobalTensorClass().prototype.atan2 = function (b) {
54641 this.throwIfDisposed();
54642 return atan2$2(this, b);
54643 };
54644
54645 /**
54646 * @license
54647 * Copyright 2020 Google LLC. All Rights Reserved.
54648 * Licensed under the Apache License, Version 2.0 (the "License");
54649 * you may not use this file except in compliance with the License.
54650 * You may obtain a copy of the License at
54651 *
54652 * http://www.apache.org/licenses/LICENSE-2.0
54653 *
54654 * Unless required by applicable law or agreed to in writing, software
54655 * distributed under the License is distributed on an "AS IS" BASIS,
54656 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54657 * See the License for the specific language governing permissions and
54658 * limitations under the License.
54659 * =============================================================================
54660 */
54661 getGlobalTensorClass().prototype.atanh = function () {
54662 this.throwIfDisposed();
54663 return atanh$2(this);
54664 };
54665
54666 getGlobalTensorClass().prototype.avgPool = function (filterSize, strides, pad, dimRoundingMode) {
54667 this.throwIfDisposed();
54668 return avgPool$2(this, filterSize, strides, pad, dimRoundingMode);
54669 };
54670
54671 /**
54672 * @license
54673 * Copyright 2020 Google LLC. All Rights Reserved.
54674 * Licensed under the Apache License, Version 2.0 (the "License");
54675 * you may not use this file except in compliance with the License.
54676 * You may obtain a copy of the License at
54677 *
54678 * http://www.apache.org/licenses/LICENSE-2.0
54679 *
54680 * Unless required by applicable law or agreed to in writing, software
54681 * distributed under the License is distributed on an "AS IS" BASIS,
54682 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54683 * See the License for the specific language governing permissions and
54684 * limitations under the License.
54685 * =============================================================================
54686 */
54687 getGlobalTensorClass().prototype.batchToSpaceND = function (blockShape, crops) {
54688 this.throwIfDisposed();
54689 return batchToSpaceND$2(this, blockShape, crops);
54690 };
54691
54692 /**
54693 * @license
54694 * Copyright 2020 Google LLC. All Rights Reserved.
54695 * Licensed under the Apache License, Version 2.0 (the "License");
54696 * you may not use this file except in compliance with the License.
54697 * You may obtain a copy of the License at
54698 *
54699 * http://www.apache.org/licenses/LICENSE-2.0
54700 *
54701 * Unless required by applicable law or agreed to in writing, software
54702 * distributed under the License is distributed on an "AS IS" BASIS,
54703 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54704 * See the License for the specific language governing permissions and
54705 * limitations under the License.
54706 * =============================================================================
54707 */
54708 getGlobalTensorClass().prototype.batchNorm = function (mean, variance, offset, scale, varianceEpsilon) {
54709 this.throwIfDisposed();
54710 return batchNorm$2(this, mean, variance, offset, scale, varianceEpsilon);
54711 };
54712
54713 /**
54714 * @license
54715 * Copyright 2020 Google LLC. All Rights Reserved.
54716 * Licensed under the Apache License, Version 2.0 (the "License");
54717 * you may not use this file except in compliance with the License.
54718 * You may obtain a copy of the License at
54719 *
54720 * http://www.apache.org/licenses/LICENSE-2.0
54721 *
54722 * Unless required by applicable law or agreed to in writing, software
54723 * distributed under the License is distributed on an "AS IS" BASIS,
54724 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54725 * See the License for the specific language governing permissions and
54726 * limitations under the License.
54727 * =============================================================================
54728 */
54729 getGlobalTensorClass().prototype.broadcastTo = function (shape) {
54730 this.throwIfDisposed();
54731 return broadcastTo(this, shape);
54732 };
54733
54734 /**
54735 * @license
54736 * Copyright 2020 Google LLC. All Rights Reserved.
54737 * Licensed under the Apache License, Version 2.0 (the "License");
54738 * you may not use this file except in compliance with the License.
54739 * You may obtain a copy of the License at
54740 *
54741 * http://www.apache.org/licenses/LICENSE-2.0
54742 *
54743 * Unless required by applicable law or agreed to in writing, software
54744 * distributed under the License is distributed on an "AS IS" BASIS,
54745 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54746 * See the License for the specific language governing permissions and
54747 * limitations under the License.
54748 * =============================================================================
54749 */
54750 getGlobalTensorClass().prototype.cast = function (dtype) {
54751 this.throwIfDisposed();
54752 return cast$3(this, dtype);
54753 };
54754
54755 /**
54756 * @license
54757 * Copyright 2020 Google LLC. All Rights Reserved.
54758 * Licensed under the Apache License, Version 2.0 (the "License");
54759 * you may not use this file except in compliance with the License.
54760 * You may obtain a copy of the License at
54761 *
54762 * http://www.apache.org/licenses/LICENSE-2.0
54763 *
54764 * Unless required by applicable law or agreed to in writing, software
54765 * distributed under the License is distributed on an "AS IS" BASIS,
54766 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54767 * See the License for the specific language governing permissions and
54768 * limitations under the License.
54769 * =============================================================================
54770 */
54771 getGlobalTensorClass().prototype.ceil = function () {
54772 this.throwIfDisposed();
54773 return ceil$2(this);
54774 };
54775
54776 /**
54777 * @license
54778 * Copyright 2020 Google LLC. All Rights Reserved.
54779 * Licensed under the Apache License, Version 2.0 (the "License");
54780 * you may not use this file except in compliance with the License.
54781 * You may obtain a copy of the License at
54782 *
54783 * http://www.apache.org/licenses/LICENSE-2.0
54784 *
54785 * Unless required by applicable law or agreed to in writing, software
54786 * distributed under the License is distributed on an "AS IS" BASIS,
54787 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54788 * See the License for the specific language governing permissions and
54789 * limitations under the License.
54790 * =============================================================================
54791 */
54792 getGlobalTensorClass().prototype.clipByValue = function (min, max) {
54793 this.throwIfDisposed();
54794 return clipByValue$2(this, min, max);
54795 };
54796
54797 getGlobalTensorClass().prototype.concat = function (x, axis) {
54798 this.throwIfDisposed();
54799 if (x instanceof Tensor) {
54800 x = [x];
54801 }
54802 return concat$2([this].concat(_toConsumableArray(x)), axis);
54803 };
54804
54805 /**
54806 * @license
54807 * Copyright 2020 Google LLC. All Rights Reserved.
54808 * Licensed under the Apache License, Version 2.0 (the "License");
54809 * you may not use this file except in compliance with the License.
54810 * You may obtain a copy of the License at
54811 *
54812 * http://www.apache.org/licenses/LICENSE-2.0
54813 *
54814 * Unless required by applicable law or agreed to in writing, software
54815 * distributed under the License is distributed on an "AS IS" BASIS,
54816 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54817 * See the License for the specific language governing permissions and
54818 * limitations under the License.
54819 * =============================================================================
54820 */
/**
 * Chained-op form of `conv1d$2`.
 *
 * Guards against a disposed receiver, then forwards this tensor as the
 * input followed by every argument, in the same order, to `conv1d$2`.
 *
 * @param {*} filter Forwarded unchanged.
 * @param {*} stride Forwarded unchanged.
 * @param {*} pad Forwarded unchanged.
 * @param {*} dataFormat Forwarded unchanged.
 * @param {*} dilation Forwarded unchanged.
 * @param {*} dimRoundingMode Forwarded unchanged.
 * @returns {*} The result of `conv1d$2`.
 */
getGlobalTensorClass().prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
  var tensor = this;
  tensor.throwIfDisposed();
  return conv1d$2(tensor, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
};
54825
54826 /**
54827 * @license
54828 * Copyright 2020 Google LLC. All Rights Reserved.
54829 * Licensed under the Apache License, Version 2.0 (the "License");
54830 * you may not use this file except in compliance with the License.
54831 * You may obtain a copy of the License at
54832 *
54833 * http://www.apache.org/licenses/LICENSE-2.0
54834 *
54835 * Unless required by applicable law or agreed to in writing, software
54836 * distributed under the License is distributed on an "AS IS" BASIS,
54837 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54838 * See the License for the specific language governing permissions and
54839 * limitations under the License.
54840 * =============================================================================
54841 */
/**
 * Chained-op form of `conv2dTranspose$1`.
 *
 * After the dispose check, this tensor becomes the first argument and the
 * remaining arguments are passed through positionally.
 *
 * @param {*} filter Forwarded unchanged.
 * @param {*} outputShape Forwarded unchanged.
 * @param {*} strides Forwarded unchanged.
 * @param {*} pad Forwarded unchanged.
 * @param {*} dimRoundingMode Forwarded unchanged.
 * @returns {*} The result of `conv2dTranspose$1`.
 */
getGlobalTensorClass().prototype.conv2dTranspose = function (filter, outputShape, strides, pad, dimRoundingMode) {
  var tensor = this;
  tensor.throwIfDisposed();
  return conv2dTranspose$1(tensor, filter, outputShape, strides, pad, dimRoundingMode);
};
54846
54847 /**
54848 * @license
54849 * Copyright 2020 Google LLC. All Rights Reserved.
54850 * Licensed under the Apache License, Version 2.0 (the "License");
54851 * you may not use this file except in compliance with the License.
54852 * You may obtain a copy of the License at
54853 *
54854 * http://www.apache.org/licenses/LICENSE-2.0
54855 *
54856 * Unless required by applicable law or agreed to in writing, software
54857 * distributed under the License is distributed on an "AS IS" BASIS,
54858 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54859 * See the License for the specific language governing permissions and
54860 * limitations under the License.
54861 * =============================================================================
54862 */
/**
 * Chained-op form of `conv2d$4`.
 *
 * Ensures the receiver is still alive, then calls `conv2d$4` with this
 * tensor followed by the caller's arguments in their original order.
 *
 * @param {*} filter Forwarded unchanged.
 * @param {*} strides Forwarded unchanged.
 * @param {*} pad Forwarded unchanged.
 * @param {*} dataFormat Forwarded unchanged.
 * @param {*} dilations Forwarded unchanged.
 * @param {*} dimRoundingMode Forwarded unchanged.
 * @returns {*} The result of `conv2d$4`.
 */
getGlobalTensorClass().prototype.conv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
  var tensor = this;
  tensor.throwIfDisposed();
  return conv2d$4(tensor, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
};
54867
54868 /**
54869 * @license
54870 * Copyright 2020 Google LLC. All Rights Reserved.
54871 * Licensed under the Apache License, Version 2.0 (the "License");
54872 * you may not use this file except in compliance with the License.
54873 * You may obtain a copy of the License at
54874 *
54875 * http://www.apache.org/licenses/LICENSE-2.0
54876 *
54877 * Unless required by applicable law or agreed to in writing, software
54878 * distributed under the License is distributed on an "AS IS" BASIS,
54879 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54880 * See the License for the specific language governing permissions and
54881 * limitations under the License.
54882 * =============================================================================
54883 */
/**
 * Chained-op form of `cos$2`: dispose check, then `cos$2(this)`.
 */
getGlobalTensorClass().prototype.cos = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return cos$2(tensor);
};
54888
54889 /**
54890 * @license
54891 * Copyright 2020 Google LLC. All Rights Reserved.
54892 * Licensed under the Apache License, Version 2.0 (the "License");
54893 * you may not use this file except in compliance with the License.
54894 * You may obtain a copy of the License at
54895 *
54896 * http://www.apache.org/licenses/LICENSE-2.0
54897 *
54898 * Unless required by applicable law or agreed to in writing, software
54899 * distributed under the License is distributed on an "AS IS" BASIS,
54900 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54901 * See the License for the specific language governing permissions and
54902 * limitations under the License.
54903 * =============================================================================
54904 */
/**
 * Chained-op form of `cosh$2`: dispose check, then `cosh$2(this)`.
 */
getGlobalTensorClass().prototype.cosh = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return cosh$2(tensor);
};
54909
54910 /**
54911 * @license
54912 * Copyright 2022 Google LLC. All Rights Reserved.
54913 * Licensed under the Apache License, Version 2.0 (the 'License');
54914 * you may not use this file except in compliance with the License.
54915 * You may obtain a copy of the License at
54916 *
54917 * http://www.apache.org/licenses/LICENSE-2.0
54918 *
54919 * Unless required by applicable law or agreed to in writing, software
54920 * distributed under the License is distributed on an 'AS IS' BASIS,
54921 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54922 * See the License for the specific language governing permissions and
54923 * limitations under the License.
54924 * =============================================================================
54925 */
/**
 * Chained-op form of `cumprod$2`.
 *
 * @param {*} axis Forwarded unchanged.
 * @param {*} exclusive Forwarded unchanged.
 * @param {*} reverse Forwarded unchanged.
 * @returns {*} The result of `cumprod$2(this, axis, exclusive, reverse)`,
 *     computed only after `throwIfDisposed` succeeds.
 */
getGlobalTensorClass().prototype.cumprod = function (axis, exclusive, reverse) {
  var tensor = this;
  tensor.throwIfDisposed();
  return cumprod$2(tensor, axis, exclusive, reverse);
};
54930
54931 /**
54932 * @license
54933 * Copyright 2020 Google LLC. All Rights Reserved.
54934 * Licensed under the Apache License, Version 2.0 (the "License");
54935 * you may not use this file except in compliance with the License.
54936 * You may obtain a copy of the License at
54937 *
54938 * http://www.apache.org/licenses/LICENSE-2.0
54939 *
54940 * Unless required by applicable law or agreed to in writing, software
54941 * distributed under the License is distributed on an "AS IS" BASIS,
54942 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54943 * See the License for the specific language governing permissions and
54944 * limitations under the License.
54945 * =============================================================================
54946 */
/**
 * Chained-op form of `cumsum$2`.
 *
 * @param {*} axis Forwarded unchanged.
 * @param {*} exclusive Forwarded unchanged.
 * @param {*} reverse Forwarded unchanged.
 * @returns {*} The result of `cumsum$2(this, axis, exclusive, reverse)`,
 *     computed only after `throwIfDisposed` succeeds.
 */
getGlobalTensorClass().prototype.cumsum = function (axis, exclusive, reverse) {
  var tensor = this;
  tensor.throwIfDisposed();
  return cumsum$2(tensor, axis, exclusive, reverse);
};
54951
54952 /**
54953 * @license
54954 * Copyright 2020 Google LLC. All Rights Reserved.
54955 * Licensed under the Apache License, Version 2.0 (the "License");
54956 * you may not use this file except in compliance with the License.
54957 * You may obtain a copy of the License at
54958 *
54959 * http://www.apache.org/licenses/LICENSE-2.0
54960 *
54961 * Unless required by applicable law or agreed to in writing, software
54962 * distributed under the License is distributed on an "AS IS" BASIS,
54963 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54964 * See the License for the specific language governing permissions and
54965 * limitations under the License.
54966 * =============================================================================
54967 */
/**
 * Chained-op form of `depthToSpace$2`.
 *
 * @param {*} blockSize Forwarded unchanged.
 * @param {*} dataFormat Forwarded unchanged.
 * @returns {*} The result of `depthToSpace$2(this, blockSize, dataFormat)`;
 *     a disposed receiver throws before the op is reached.
 */
getGlobalTensorClass().prototype.depthToSpace = function (blockSize, dataFormat) {
  var tensor = this;
  tensor.throwIfDisposed();
  return depthToSpace$2(tensor, blockSize, dataFormat);
};
54972
54973 /**
54974 * @license
54975 * Copyright 2020 Google LLC. All Rights Reserved.
54976 * Licensed under the Apache License, Version 2.0 (the "License");
54977 * you may not use this file except in compliance with the License.
54978 * You may obtain a copy of the License at
54979 *
54980 * http://www.apache.org/licenses/LICENSE-2.0
54981 *
54982 * Unless required by applicable law or agreed to in writing, software
54983 * distributed under the License is distributed on an "AS IS" BASIS,
54984 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54985 * See the License for the specific language governing permissions and
54986 * limitations under the License.
54987 * =============================================================================
54988 */
/**
 * Chained-op form of `depthwiseConv2d$3`.
 *
 * Checks the receiver has not been disposed, then passes this tensor and
 * all caller arguments, positionally, to `depthwiseConv2d$3`.
 *
 * @param {*} filter Forwarded unchanged.
 * @param {*} strides Forwarded unchanged.
 * @param {*} pad Forwarded unchanged.
 * @param {*} dataFormat Forwarded unchanged.
 * @param {*} dilations Forwarded unchanged.
 * @param {*} dimRoundingMode Forwarded unchanged.
 * @returns {*} The result of `depthwiseConv2d$3`.
 */
getGlobalTensorClass().prototype.depthwiseConv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
  var tensor = this;
  tensor.throwIfDisposed();
  return depthwiseConv2d$3(tensor, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
};
54993
54994 /**
54995 * @license
54996 * Copyright 2020 Google LLC. All Rights Reserved.
54997 * Licensed under the Apache License, Version 2.0 (the "License");
54998 * you may not use this file except in compliance with the License.
54999 * You may obtain a copy of the License at
55000 *
55001 * http://www.apache.org/licenses/LICENSE-2.0
55002 *
55003 * Unless required by applicable law or agreed to in writing, software
55004 * distributed under the License is distributed on an "AS IS" BASIS,
55005 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55006 * See the License for the specific language governing permissions and
55007 * limitations under the License.
55008 * =============================================================================
55009 */
/**
 * Chained-op form of `dilation2d`.
 *
 * Note the argument order: `dilations` precedes `dataFormat` here, and it is
 * passed to `dilation2d` in exactly that order.
 *
 * @param {*} filter Forwarded unchanged.
 * @param {*} strides Forwarded unchanged.
 * @param {*} pad Forwarded unchanged.
 * @param {*} dilations Forwarded unchanged.
 * @param {*} dataFormat Forwarded unchanged.
 * @returns {*} The result of `dilation2d`, after the dispose check.
 */
getGlobalTensorClass().prototype.dilation2d = function (filter, strides, pad, dilations, dataFormat) {
  var tensor = this;
  tensor.throwIfDisposed();
  return dilation2d(tensor, filter, strides, pad, dilations, dataFormat);
};
55014
55015 /**
55016 * @license
55017 * Copyright 2020 Google LLC. All Rights Reserved.
55018 * Licensed under the Apache License, Version 2.0 (the "License");
55019 * you may not use this file except in compliance with the License.
55020 * You may obtain a copy of the License at
55021 *
55022 * http://www.apache.org/licenses/LICENSE-2.0
55023 *
55024 * Unless required by applicable law or agreed to in writing, software
55025 * distributed under the License is distributed on an "AS IS" BASIS,
55026 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55027 * See the License for the specific language governing permissions and
55028 * limitations under the License.
55029 * =============================================================================
55030 */
/**
 * Chained-op form of `divNoNan`, with this tensor as the first operand.
 *
 * @param {*} b Second operand, forwarded unchanged.
 * @returns {*} The result of `divNoNan(this, b)`, after the dispose check.
 */
getGlobalTensorClass().prototype.divNoNan = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return divNoNan(tensor, b);
};
55035
55036 /**
55037 * @license
55038 * Copyright 2020 Google LLC. All Rights Reserved.
55039 * Licensed under the Apache License, Version 2.0 (the "License");
55040 * you may not use this file except in compliance with the License.
55041 * You may obtain a copy of the License at
55042 *
55043 * http://www.apache.org/licenses/LICENSE-2.0
55044 *
55045 * Unless required by applicable law or agreed to in writing, software
55046 * distributed under the License is distributed on an "AS IS" BASIS,
55047 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55048 * See the License for the specific language governing permissions and
55049 * limitations under the License.
55050 * =============================================================================
55051 */
/**
 * Chained-op form of `div$1`, with this tensor as the first operand.
 *
 * @param {*} b Second operand, forwarded unchanged.
 * @returns {*} The result of `div$1(this, b)`, after the dispose check.
 */
getGlobalTensorClass().prototype.div = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return div$1(tensor, b);
};
55056
55057 /**
55058 * @license
55059 * Copyright 2020 Google LLC. All Rights Reserved.
55060 * Licensed under the Apache License, Version 2.0 (the "License");
55061 * you may not use this file except in compliance with the License.
55062 * You may obtain a copy of the License at
55063 *
55064 * http://www.apache.org/licenses/LICENSE-2.0
55065 *
55066 * Unless required by applicable law or agreed to in writing, software
55067 * distributed under the License is distributed on an "AS IS" BASIS,
55068 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55069 * See the License for the specific language governing permissions and
55070 * limitations under the License.
55071 * =============================================================================
55072 */
/**
 * Chained-op form of `dot$2`, with this tensor as the left operand.
 *
 * @param {*} b Right operand, forwarded unchanged.
 * @returns {*} The result of `dot$2(this, b)`, after the dispose check.
 */
getGlobalTensorClass().prototype.dot = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return dot$2(tensor, b);
};
55077
55078 /**
55079 * @license
55080 * Copyright 2020 Google LLC. All Rights Reserved.
55081 * Licensed under the Apache License, Version 2.0 (the "License");
55082 * you may not use this file except in compliance with the License.
55083 * You may obtain a copy of the License at
55084 *
55085 * http://www.apache.org/licenses/LICENSE-2.0
55086 *
55087 * Unless required by applicable law or agreed to in writing, software
55088 * distributed under the License is distributed on an "AS IS" BASIS,
55089 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55090 * See the License for the specific language governing permissions and
55091 * limitations under the License.
55092 * =============================================================================
55093 */
/**
 * Chained-op form of `elu$4`: dispose check, then `elu$4(this)`.
 */
getGlobalTensorClass().prototype.elu = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return elu$4(tensor);
};
55098
55099 /**
55100 * @license
55101 * Copyright 2020 Google LLC. All Rights Reserved.
55102 * Licensed under the Apache License, Version 2.0 (the "License");
55103 * you may not use this file except in compliance with the License.
55104 * You may obtain a copy of the License at
55105 *
55106 * http://www.apache.org/licenses/LICENSE-2.0
55107 *
55108 * Unless required by applicable law or agreed to in writing, software
55109 * distributed under the License is distributed on an "AS IS" BASIS,
55110 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55111 * See the License for the specific language governing permissions and
55112 * limitations under the License.
55113 * =============================================================================
55114 */
/**
 * Chained-op form of `equal$2`, comparing this tensor against `b`.
 *
 * @param {*} b Forwarded unchanged as the second operand.
 * @returns {*} The result of `equal$2(this, b)`, after the dispose check.
 */
getGlobalTensorClass().prototype.equal = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return equal$2(tensor, b);
};
55119
55120 /**
55121 * @license
55122 * Copyright 2020 Google LLC. All Rights Reserved.
55123 * Licensed under the Apache License, Version 2.0 (the "License");
55124 * you may not use this file except in compliance with the License.
55125 * You may obtain a copy of the License at
55126 *
55127 * http://www.apache.org/licenses/LICENSE-2.0
55128 *
55129 * Unless required by applicable law or agreed to in writing, software
55130 * distributed under the License is distributed on an "AS IS" BASIS,
55131 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55132 * See the License for the specific language governing permissions and
55133 * limitations under the License.
55134 * =============================================================================
55135 */
/**
 * Chained-op form of `erf$2`: dispose check, then `erf$2(this)`.
 */
getGlobalTensorClass().prototype.erf = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return erf$2(tensor);
};
55140
55141 /**
55142 * @license
55143 * Copyright 2021 Google LLC. All Rights Reserved.
55144 * Licensed under the Apache License, Version 2.0 (the "License");
55145 * you may not use this file except in compliance with the License.
55146 * You may obtain a copy of the License at
55147 *
55148 * http://www.apache.org/licenses/LICENSE-2.0
55149 *
55150 * Unless required by applicable law or agreed to in writing, software
55151 * distributed under the License is distributed on an "AS IS" BASIS,
55152 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55153 * See the License for the specific language governing permissions and
55154 * limitations under the License.
55155 * =============================================================================
55156 */
/**
 * Chained-op form of `euclideanNorm`.
 *
 * @param {*} axis Forwarded unchanged.
 * @param {*} keepDims Forwarded unchanged.
 * @returns {*} The result of `euclideanNorm(this, axis, keepDims)`; a
 *     disposed receiver throws before the op is reached.
 */
getGlobalTensorClass().prototype.euclideanNorm = function (axis, keepDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return euclideanNorm(tensor, axis, keepDims);
};
55161
55162 /**
55163 * @license
55164 * Copyright 2020 Google LLC. All Rights Reserved.
55165 * Licensed under the Apache License, Version 2.0 (the "License");
55166 * you may not use this file except in compliance with the License.
55167 * You may obtain a copy of the License at
55168 *
55169 * http://www.apache.org/licenses/LICENSE-2.0
55170 *
55171 * Unless required by applicable law or agreed to in writing, software
55172 * distributed under the License is distributed on an "AS IS" BASIS,
55173 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55174 * See the License for the specific language governing permissions and
55175 * limitations under the License.
55176 * =============================================================================
55177 */
/**
 * Chained-op form of `exp$2`: dispose check, then `exp$2(this)`.
 */
getGlobalTensorClass().prototype.exp = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return exp$2(tensor);
};
55182
55183 /**
55184 * @license
55185 * Copyright 2020 Google LLC. All Rights Reserved.
55186 * Licensed under the Apache License, Version 2.0 (the "License");
55187 * you may not use this file except in compliance with the License.
55188 * You may obtain a copy of the License at
55189 *
55190 * http://www.apache.org/licenses/LICENSE-2.0
55191 *
55192 * Unless required by applicable law or agreed to in writing, software
55193 * distributed under the License is distributed on an "AS IS" BASIS,
55194 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55195 * See the License for the specific language governing permissions and
55196 * limitations under the License.
55197 * =============================================================================
55198 */
/**
 * Chained-op form of `expandDims$3`.
 *
 * @param {*} axis Forwarded unchanged.
 * @returns {*} The result of `expandDims$3(this, axis)`, after the dispose
 *     check.
 */
getGlobalTensorClass().prototype.expandDims = function (axis) {
  var tensor = this;
  tensor.throwIfDisposed();
  return expandDims$3(tensor, axis);
};
55203
55204 /**
55205 * @license
55206 * Copyright 2020 Google LLC. All Rights Reserved.
55207 * Licensed under the Apache License, Version 2.0 (the "License");
55208 * you may not use this file except in compliance with the License.
55209 * You may obtain a copy of the License at
55210 *
55211 * http://www.apache.org/licenses/LICENSE-2.0
55212 *
55213 * Unless required by applicable law or agreed to in writing, software
55214 * distributed under the License is distributed on an "AS IS" BASIS,
55215 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55216 * See the License for the specific language governing permissions and
55217 * limitations under the License.
55218 * =============================================================================
55219 */
/**
 * Chained-op form of `expm1$2`: dispose check, then `expm1$2(this)`.
 */
getGlobalTensorClass().prototype.expm1 = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return expm1$2(tensor);
};
55224
55225 /**
55226 * @license
55227 * Copyright 2020 Google LLC. All Rights Reserved.
55228 * Licensed under the Apache License, Version 2.0 (the "License");
55229 * you may not use this file except in compliance with the License.
55230 * You may obtain a copy of the License at
55231 *
55232 * http://www.apache.org/licenses/LICENSE-2.0
55233 *
55234 * Unless required by applicable law or agreed to in writing, software
55235 * distributed under the License is distributed on an "AS IS" BASIS,
55236 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55237 * See the License for the specific language governing permissions and
55238 * limitations under the License.
55239 * =============================================================================
55240 */
/**
 * Chained-op form of `fft$2`: dispose check, then `fft$2(this)`.
 */
getGlobalTensorClass().prototype.fft = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return fft$2(tensor);
};
55245
55246 /**
55247 * @license
55248 * Copyright 2020 Google LLC. All Rights Reserved.
55249 * Licensed under the Apache License, Version 2.0 (the "License");
55250 * you may not use this file except in compliance with the License.
55251 * You may obtain a copy of the License at
55252 *
55253 * http://www.apache.org/licenses/LICENSE-2.0
55254 *
55255 * Unless required by applicable law or agreed to in writing, software
55256 * distributed under the License is distributed on an "AS IS" BASIS,
55257 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55258 * See the License for the specific language governing permissions and
55259 * limitations under the License.
55260 * =============================================================================
55261 */
55262 /**
55263 * Flatten a Tensor to a 1D array.
55264 * @doc {heading: 'Tensors', subheading: 'Classes'}
55265 */
55266 getGlobalTensorClass().prototype.flatten = function () {
55267 this.throwIfDisposed();
55268 return reshape$3(this, [this.size]);
55269 };
55270
55271 /**
55272 * @license
55273 * Copyright 2020 Google LLC. All Rights Reserved.
55274 * Licensed under the Apache License, Version 2.0 (the "License");
55275 * you may not use this file except in compliance with the License.
55276 * You may obtain a copy of the License at
55277 *
55278 * http://www.apache.org/licenses/LICENSE-2.0
55279 *
55280 * Unless required by applicable law or agreed to in writing, software
55281 * distributed under the License is distributed on an "AS IS" BASIS,
55282 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55283 * See the License for the specific language governing permissions and
55284 * limitations under the License.
55285 * =============================================================================
55286 */
/**
 * Chained-op form of `floor$2`: dispose check, then `floor$2(this)`.
 */
getGlobalTensorClass().prototype.floor = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return floor$2(tensor);
};
55291
55292 /**
55293 * @license
55294 * Copyright 2020 Google LLC. All Rights Reserved.
55295 * Licensed under the Apache License, Version 2.0 (the "License");
55296 * you may not use this file except in compliance with the License.
55297 * You may obtain a copy of the License at
55298 *
55299 * http://www.apache.org/licenses/LICENSE-2.0
55300 *
55301 * Unless required by applicable law or agreed to in writing, software
55302 * distributed under the License is distributed on an "AS IS" BASIS,
55303 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55304 * See the License for the specific language governing permissions and
55305 * limitations under the License.
55306 * =============================================================================
55307 */
/**
 * Chained-op form of `floorDiv$2`, with this tensor as the first operand.
 *
 * @param {*} b Second operand, forwarded unchanged.
 * @returns {*} The result of `floorDiv$2(this, b)`, after the dispose check.
 */
getGlobalTensorClass().prototype.floorDiv = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return floorDiv$2(tensor, b);
};
55312
55313 /**
55314 * @license
55315 * Copyright 2020 Google LLC. All Rights Reserved.
55316 * Licensed under the Apache License, Version 2.0 (the "License");
55317 * you may not use this file except in compliance with the License.
55318 * You may obtain a copy of the License at
55319 *
55320 * http://www.apache.org/licenses/LICENSE-2.0
55321 *
55322 * Unless required by applicable law or agreed to in writing, software
55323 * distributed under the License is distributed on an "AS IS" BASIS,
55324 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55325 * See the License for the specific language governing permissions and
55326 * limitations under the License.
55327 * =============================================================================
55328 */
/**
 * Chained-op form of `gather$1`.
 *
 * @param {*} indices Forwarded unchanged.
 * @param {*} axis Forwarded unchanged.
 * @param {*} batchDims Forwarded unchanged.
 * @returns {*} The result of `gather$1(this, indices, axis, batchDims)`;
 *     a disposed receiver throws before the op is reached.
 */
getGlobalTensorClass().prototype.gather = function (indices, axis, batchDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return gather$1(tensor, indices, axis, batchDims);
};
55333
55334 /**
55335 * @license
55336 * Copyright 2020 Google LLC. All Rights Reserved.
55337 * Licensed under the Apache License, Version 2.0 (the "License");
55338 * you may not use this file except in compliance with the License.
55339 * You may obtain a copy of the License at
55340 *
55341 * http://www.apache.org/licenses/LICENSE-2.0
55342 *
55343 * Unless required by applicable law or agreed to in writing, software
55344 * distributed under the License is distributed on an "AS IS" BASIS,
55345 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55346 * See the License for the specific language governing permissions and
55347 * limitations under the License.
55348 * =============================================================================
55349 */
/**
 * Chained-op form of `greaterEqual$2`, comparing this tensor against `b`.
 *
 * @param {*} b Forwarded unchanged as the second operand.
 * @returns {*} The result of `greaterEqual$2(this, b)`, after the dispose
 *     check.
 */
getGlobalTensorClass().prototype.greaterEqual = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return greaterEqual$2(tensor, b);
};
55354
55355 /**
55356 * @license
55357 * Copyright 2020 Google LLC. All Rights Reserved.
55358 * Licensed under the Apache License, Version 2.0 (the "License");
55359 * you may not use this file except in compliance with the License.
55360 * You may obtain a copy of the License at
55361 *
55362 * http://www.apache.org/licenses/LICENSE-2.0
55363 *
55364 * Unless required by applicable law or agreed to in writing, software
55365 * distributed under the License is distributed on an "AS IS" BASIS,
55366 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55367 * See the License for the specific language governing permissions and
55368 * limitations under the License.
55369 * =============================================================================
55370 */
/**
 * Chained-op form of `greater$3`, comparing this tensor against `b`.
 *
 * @param {*} b Forwarded unchanged as the second operand.
 * @returns {*} The result of `greater$3(this, b)`, after the dispose check.
 */
getGlobalTensorClass().prototype.greater = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return greater$3(tensor, b);
};
55375
55376 /**
55377 * @license
55378 * Copyright 2020 Google LLC. All Rights Reserved.
55379 * Licensed under the Apache License, Version 2.0 (the "License");
55380 * you may not use this file except in compliance with the License.
55381 * You may obtain a copy of the License at
55382 *
55383 * http://www.apache.org/licenses/LICENSE-2.0
55384 *
55385 * Unless required by applicable law or agreed to in writing, software
55386 * distributed under the License is distributed on an "AS IS" BASIS,
55387 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55388 * See the License for the specific language governing permissions and
55389 * limitations under the License.
55390 * =============================================================================
55391 */
/**
 * Chained-op form of `ifft$2`: dispose check, then `ifft$2(this)`.
 */
getGlobalTensorClass().prototype.ifft = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return ifft$2(tensor);
};
55396
55397 /**
55398 * @license
55399 * Copyright 2020 Google LLC. All Rights Reserved.
55400 * Licensed under the Apache License, Version 2.0 (the "License");
55401 * you may not use this file except in compliance with the License.
55402 * You may obtain a copy of the License at
55403 *
55404 * http://www.apache.org/licenses/LICENSE-2.0
55405 *
55406 * Unless required by applicable law or agreed to in writing, software
55407 * distributed under the License is distributed on an "AS IS" BASIS,
55408 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55409 * See the License for the specific language governing permissions and
55410 * limitations under the License.
55411 * =============================================================================
55412 */
/**
 * Chained-op form of `irfft`: dispose check, then `irfft(this)`.
 */
getGlobalTensorClass().prototype.irfft = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return irfft(tensor);
};
55417
55418 /**
55419 * @license
55420 * Copyright 2020 Google LLC. All Rights Reserved.
55421 * Licensed under the Apache License, Version 2.0 (the "License");
55422 * you may not use this file except in compliance with the License.
55423 * You may obtain a copy of the License at
55424 *
55425 * http://www.apache.org/licenses/LICENSE-2.0
55426 *
55427 * Unless required by applicable law or agreed to in writing, software
55428 * distributed under the License is distributed on an "AS IS" BASIS,
55429 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55430 * See the License for the specific language governing permissions and
55431 * limitations under the License.
55432 * =============================================================================
55433 */
/**
 * Chained-op form of `isFinite$3`: dispose check, then `isFinite$3(this)`.
 */
getGlobalTensorClass().prototype.isFinite = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return isFinite$3(tensor);
};
55438
55439 /**
55440 * @license
55441 * Copyright 2020 Google LLC. All Rights Reserved.
55442 * Licensed under the Apache License, Version 2.0 (the "License");
55443 * you may not use this file except in compliance with the License.
55444 * You may obtain a copy of the License at
55445 *
55446 * http://www.apache.org/licenses/LICENSE-2.0
55447 *
55448 * Unless required by applicable law or agreed to in writing, software
55449 * distributed under the License is distributed on an "AS IS" BASIS,
55450 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55451 * See the License for the specific language governing permissions and
55452 * limitations under the License.
55453 * =============================================================================
55454 */
/**
 * Chained-op form of `isInf$2`: dispose check, then `isInf$2(this)`.
 */
getGlobalTensorClass().prototype.isInf = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return isInf$2(tensor);
};
55459
55460 /**
55461 * @license
55462 * Copyright 2020 Google LLC. All Rights Reserved.
55463 * Licensed under the Apache License, Version 2.0 (the "License");
55464 * you may not use this file except in compliance with the License.
55465 * You may obtain a copy of the License at
55466 *
55467 * http://www.apache.org/licenses/LICENSE-2.0
55468 *
55469 * Unless required by applicable law or agreed to in writing, software
55470 * distributed under the License is distributed on an "AS IS" BASIS,
55471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55472 * See the License for the specific language governing permissions and
55473 * limitations under the License.
55474 * =============================================================================
55475 */
/**
 * Chained-op form of `isNaN$3`: dispose check, then `isNaN$3(this)`.
 */
getGlobalTensorClass().prototype.isNaN = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return isNaN$3(tensor);
};
55480
55481 /**
55482 * @license
55483 * Copyright 2020 Google LLC. All Rights Reserved.
55484 * Licensed under the Apache License, Version 2.0 (the "License");
55485 * you may not use this file except in compliance with the License.
55486 * You may obtain a copy of the License at
55487 *
55488 * http://www.apache.org/licenses/LICENSE-2.0
55489 *
55490 * Unless required by applicable law or agreed to in writing, software
55491 * distributed under the License is distributed on an "AS IS" BASIS,
55492 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55493 * See the License for the specific language governing permissions and
55494 * limitations under the License.
55495 * =============================================================================
55496 */
/**
 * Chained-op form of `leakyRelu$2`.
 *
 * @param {*} alpha Forwarded unchanged; when the caller omits it,
 *     `undefined` is passed through and any defaulting is left to
 *     `leakyRelu$2`.
 * @returns {*} The result of `leakyRelu$2(this, alpha)`, after the dispose
 *     check.
 */
getGlobalTensorClass().prototype.leakyRelu = function (alpha) {
  var tensor = this;
  tensor.throwIfDisposed();
  return leakyRelu$2(tensor, alpha);
};
55501
55502 /**
55503 * @license
55504 * Copyright 2020 Google LLC. All Rights Reserved.
55505 * Licensed under the Apache License, Version 2.0 (the "License");
55506 * you may not use this file except in compliance with the License.
55507 * You may obtain a copy of the License at
55508 *
55509 * http://www.apache.org/licenses/LICENSE-2.0
55510 *
55511 * Unless required by applicable law or agreed to in writing, software
55512 * distributed under the License is distributed on an "AS IS" BASIS,
55513 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55514 * See the License for the specific language governing permissions and
55515 * limitations under the License.
55516 * =============================================================================
55517 */
/**
 * Chained-op form of `lessEqual$2`, comparing this tensor against `b`.
 *
 * @param {*} b Forwarded unchanged as the second operand.
 * @returns {*} The result of `lessEqual$2(this, b)`, after the dispose check.
 */
getGlobalTensorClass().prototype.lessEqual = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return lessEqual$2(tensor, b);
};
55522
55523 /**
55524 * @license
55525 * Copyright 2020 Google LLC. All Rights Reserved.
55526 * Licensed under the Apache License, Version 2.0 (the "License");
55527 * you may not use this file except in compliance with the License.
55528 * You may obtain a copy of the License at
55529 *
55530 * http://www.apache.org/licenses/LICENSE-2.0
55531 *
55532 * Unless required by applicable law or agreed to in writing, software
55533 * distributed under the License is distributed on an "AS IS" BASIS,
55534 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55535 * See the License for the specific language governing permissions and
55536 * limitations under the License.
55537 * =============================================================================
55538 */
/**
 * Chained-op method `tensor.less(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `less$3(receiver, b)`.
 */
getGlobalTensorClass().prototype.less = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return less$3(tensor, b);
};
55543
55544 /**
55545 * @license
55546 * Copyright 2020 Google LLC. All Rights Reserved.
55547 * Licensed under the Apache License, Version 2.0 (the "License");
55548 * you may not use this file except in compliance with the License.
55549 * You may obtain a copy of the License at
55550 *
55551 * http://www.apache.org/licenses/LICENSE-2.0
55552 *
55553 * Unless required by applicable law or agreed to in writing, software
55554 * distributed under the License is distributed on an "AS IS" BASIS,
55555 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55556 * See the License for the specific language governing permissions and
55557 * limitations under the License.
55558 * =============================================================================
55559 */
/**
 * Chained-op method
 * `tensor.localResponseNormalization(depthRadius, bias, alpha, beta)`.
 * Calls `throwIfDisposed()` on the receiver, then forwards the receiver and
 * all four arguments, in order, to `localResponseNormalization`.
 */
getGlobalTensorClass().prototype.localResponseNormalization = function (depthRadius, bias, alpha, beta) {
  var tensor = this;
  tensor.throwIfDisposed();
  return localResponseNormalization(tensor, depthRadius, bias, alpha, beta);
};
55564
55565 /**
55566 * @license
55567 * Copyright 2020 Google LLC. All Rights Reserved.
55568 * Licensed under the Apache License, Version 2.0 (the "License");
55569 * you may not use this file except in compliance with the License.
55570 * You may obtain a copy of the License at
55571 *
55572 * http://www.apache.org/licenses/LICENSE-2.0
55573 *
55574 * Unless required by applicable law or agreed to in writing, software
55575 * distributed under the License is distributed on an "AS IS" BASIS,
55576 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55577 * See the License for the specific language governing permissions and
55578 * limitations under the License.
55579 * =============================================================================
55580 */
/**
 * Chained-op method `tensor.logSigmoid()`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `logSigmoid(receiver)`.
 */
getGlobalTensorClass().prototype.logSigmoid = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return logSigmoid(tensor);
};
55585
55586 /**
55587 * @license
55588 * Copyright 2020 Google LLC. All Rights Reserved.
55589 * Licensed under the Apache License, Version 2.0 (the "License");
55590 * you may not use this file except in compliance with the License.
55591 * You may obtain a copy of the License at
55592 *
55593 * http://www.apache.org/licenses/LICENSE-2.0
55594 *
55595 * Unless required by applicable law or agreed to in writing, software
55596 * distributed under the License is distributed on an "AS IS" BASIS,
55597 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55598 * See the License for the specific language governing permissions and
55599 * limitations under the License.
55600 * =============================================================================
55601 */
/**
 * Chained-op method `tensor.logSoftmax(axis)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `logSoftmax(receiver, axis)`.
 */
getGlobalTensorClass().prototype.logSoftmax = function (axis) {
  var tensor = this;
  tensor.throwIfDisposed();
  return logSoftmax(tensor, axis);
};
55606
55607 /**
55608 * @license
55609 * Copyright 2020 Google LLC. All Rights Reserved.
55610 * Licensed under the Apache License, Version 2.0 (the "License");
55611 * you may not use this file except in compliance with the License.
55612 * You may obtain a copy of the License at
55613 *
55614 * http://www.apache.org/licenses/LICENSE-2.0
55615 *
55616 * Unless required by applicable law or agreed to in writing, software
55617 * distributed under the License is distributed on an "AS IS" BASIS,
55618 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55619 * See the License for the specific language governing permissions and
55620 * limitations under the License.
55621 * =============================================================================
55622 */
/**
 * Chained-op method `tensor.logSumExp(axis, keepDims)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `logSumExp(receiver, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.logSumExp = function (axis, keepDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return logSumExp(tensor, axis, keepDims);
};
55627
55628 /**
55629 * @license
55630 * Copyright 2020 Google LLC. All Rights Reserved.
55631 * Licensed under the Apache License, Version 2.0 (the "License");
55632 * you may not use this file except in compliance with the License.
55633 * You may obtain a copy of the License at
55634 *
55635 * http://www.apache.org/licenses/LICENSE-2.0
55636 *
55637 * Unless required by applicable law or agreed to in writing, software
55638 * distributed under the License is distributed on an "AS IS" BASIS,
55639 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55640 * See the License for the specific language governing permissions and
55641 * limitations under the License.
55642 * =============================================================================
55643 */
/**
 * Chained-op method `tensor.log()`.
 * Calls `throwIfDisposed()` on the receiver, then returns `log$2(receiver)`.
 */
getGlobalTensorClass().prototype.log = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return log$2(tensor);
};
55648
55649 /**
55650 * @license
55651 * Copyright 2020 Google LLC. All Rights Reserved.
55652 * Licensed under the Apache License, Version 2.0 (the "License");
55653 * you may not use this file except in compliance with the License.
55654 * You may obtain a copy of the License at
55655 *
55656 * http://www.apache.org/licenses/LICENSE-2.0
55657 *
55658 * Unless required by applicable law or agreed to in writing, software
55659 * distributed under the License is distributed on an "AS IS" BASIS,
55660 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55661 * See the License for the specific language governing permissions and
55662 * limitations under the License.
55663 * =============================================================================
55664 */
/**
 * Chained-op method `tensor.log1p()`.
 * Calls `throwIfDisposed()` on the receiver, then returns `log1p$2(receiver)`.
 */
getGlobalTensorClass().prototype.log1p = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return log1p$2(tensor);
};
55669
55670 /**
55671 * @license
55672 * Copyright 2020 Google LLC. All Rights Reserved.
55673 * Licensed under the Apache License, Version 2.0 (the "License");
55674 * you may not use this file except in compliance with the License.
55675 * You may obtain a copy of the License at
55676 *
55677 * http://www.apache.org/licenses/LICENSE-2.0
55678 *
55679 * Unless required by applicable law or agreed to in writing, software
55680 * distributed under the License is distributed on an "AS IS" BASIS,
55681 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55682 * See the License for the specific language governing permissions and
55683 * limitations under the License.
55684 * =============================================================================
55685 */
/**
 * Chained-op method `tensor.logicalAnd(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `logicalAnd$2(receiver, b)`.
 */
getGlobalTensorClass().prototype.logicalAnd = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return logicalAnd$2(tensor, b);
};
55690
55691 /**
55692 * @license
55693 * Copyright 2020 Google LLC. All Rights Reserved.
55694 * Licensed under the Apache License, Version 2.0 (the "License");
55695 * you may not use this file except in compliance with the License.
55696 * You may obtain a copy of the License at
55697 *
55698 * http://www.apache.org/licenses/LICENSE-2.0
55699 *
55700 * Unless required by applicable law or agreed to in writing, software
55701 * distributed under the License is distributed on an "AS IS" BASIS,
55702 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55703 * See the License for the specific language governing permissions and
55704 * limitations under the License.
55705 * =============================================================================
55706 */
/**
 * Chained-op method `tensor.logicalNot()`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `logicalNot$2(receiver)`.
 */
getGlobalTensorClass().prototype.logicalNot = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return logicalNot$2(tensor);
};
55711
55712 /**
55713 * @license
55714 * Copyright 2020 Google LLC. All Rights Reserved.
55715 * Licensed under the Apache License, Version 2.0 (the "License");
55716 * you may not use this file except in compliance with the License.
55717 * You may obtain a copy of the License at
55718 *
55719 * http://www.apache.org/licenses/LICENSE-2.0
55720 *
55721 * Unless required by applicable law or agreed to in writing, software
55722 * distributed under the License is distributed on an "AS IS" BASIS,
55723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55724 * See the License for the specific language governing permissions and
55725 * limitations under the License.
55726 * =============================================================================
55727 */
/**
 * Chained-op method `tensor.logicalOr(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `logicalOr$2(receiver, b)`.
 */
getGlobalTensorClass().prototype.logicalOr = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return logicalOr$2(tensor, b);
};
55732
55733 /**
55734 * @license
55735 * Copyright 2020 Google LLC. All Rights Reserved.
55736 * Licensed under the Apache License, Version 2.0 (the "License");
55737 * you may not use this file except in compliance with the License.
55738 * You may obtain a copy of the License at
55739 *
55740 * http://www.apache.org/licenses/LICENSE-2.0
55741 *
55742 * Unless required by applicable law or agreed to in writing, software
55743 * distributed under the License is distributed on an "AS IS" BASIS,
55744 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55745 * See the License for the specific language governing permissions and
55746 * limitations under the License.
55747 * =============================================================================
55748 */
/**
 * Chained-op method `tensor.logicalXor(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `logicalXor(receiver, b)`.
 */
getGlobalTensorClass().prototype.logicalXor = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return logicalXor(tensor, b);
};
55753
55754 /**
55755 * @license
55756 * Copyright 2020 Google LLC. All Rights Reserved.
55757 * Licensed under the Apache License, Version 2.0 (the "License");
55758 * you may not use this file except in compliance with the License.
55759 * You may obtain a copy of the License at
55760 *
55761 * http://www.apache.org/licenses/LICENSE-2.0
55762 *
55763 * Unless required by applicable law or agreed to in writing, software
55764 * distributed under the License is distributed on an "AS IS" BASIS,
55765 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55766 * See the License for the specific language governing permissions and
55767 * limitations under the License.
55768 * =============================================================================
55769 */
/**
 * Chained-op method `tensor.matMul(b, transposeA, transposeB)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `matMul$1(receiver, b, transposeA, transposeB)`.
 */
getGlobalTensorClass().prototype.matMul = function (b, transposeA, transposeB) {
  var tensor = this;
  tensor.throwIfDisposed();
  return matMul$1(tensor, b, transposeA, transposeB);
};
55774
/**
 * Chained-op method `tensor.maxPool(filterSize, strides, pad, dimRoundingMode)`.
 * Calls `throwIfDisposed()` on the receiver, then forwards the receiver and
 * all four arguments, in order, to `maxPool$2`.
 * Note: the `pad` parameter shadows any outer `pad` binding inside this body.
 */
getGlobalTensorClass().prototype.maxPool = function (filterSize, strides, pad, dimRoundingMode) {
  var tensor = this;
  tensor.throwIfDisposed();
  return maxPool$2(tensor, filterSize, strides, pad, dimRoundingMode);
};
55779
55780 /**
55781 * @license
55782 * Copyright 2020 Google LLC. All Rights Reserved.
55783 * Licensed under the Apache License, Version 2.0 (the "License");
55784 * you may not use this file except in compliance with the License.
55785 * You may obtain a copy of the License at
55786 *
55787 * http://www.apache.org/licenses/LICENSE-2.0
55788 *
55789 * Unless required by applicable law or agreed to in writing, software
55790 * distributed under the License is distributed on an "AS IS" BASIS,
55791 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55792 * See the License for the specific language governing permissions and
55793 * limitations under the License.
55794 * =============================================================================
55795 */
/**
 * Chained-op method `tensor.max(axis, keepDims)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `max$3(receiver, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.max = function (axis, keepDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return max$3(tensor, axis, keepDims);
};
55800
55801 /**
55802 * @license
55803 * Copyright 2020 Google LLC. All Rights Reserved.
55804 * Licensed under the Apache License, Version 2.0 (the "License");
55805 * you may not use this file except in compliance with the License.
55806 * You may obtain a copy of the License at
55807 *
55808 * http://www.apache.org/licenses/LICENSE-2.0
55809 *
55810 * Unless required by applicable law or agreed to in writing, software
55811 * distributed under the License is distributed on an "AS IS" BASIS,
55812 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55813 * See the License for the specific language governing permissions and
55814 * limitations under the License.
55815 * =============================================================================
55816 */
/**
 * Chained-op method `tensor.maximum(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `maximum$4(receiver, b)`.
 */
getGlobalTensorClass().prototype.maximum = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return maximum$4(tensor, b);
};
55821
55822 /**
55823 * @license
55824 * Copyright 2020 Google LLC. All Rights Reserved.
55825 * Licensed under the Apache License, Version 2.0 (the "License");
55826 * you may not use this file except in compliance with the License.
55827 * You may obtain a copy of the License at
55828 *
55829 * http://www.apache.org/licenses/LICENSE-2.0
55830 *
55831 * Unless required by applicable law or agreed to in writing, software
55832 * distributed under the License is distributed on an "AS IS" BASIS,
55833 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55834 * See the License for the specific language governing permissions and
55835 * limitations under the License.
55836 * =============================================================================
55837 */
/**
 * Chained-op method `tensor.mean(axis, keepDims)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `mean$3(receiver, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.mean = function (axis, keepDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return mean$3(tensor, axis, keepDims);
};
55842
55843 /**
55844 * @license
55845 * Copyright 2020 Google LLC. All Rights Reserved.
55846 * Licensed under the Apache License, Version 2.0 (the "License");
55847 * you may not use this file except in compliance with the License.
55848 * You may obtain a copy of the License at
55849 *
55850 * http://www.apache.org/licenses/LICENSE-2.0
55851 *
55852 * Unless required by applicable law or agreed to in writing, software
55853 * distributed under the License is distributed on an "AS IS" BASIS,
55854 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55855 * See the License for the specific language governing permissions and
55856 * limitations under the License.
55857 * =============================================================================
55858 */
/**
 * Chained-op method `tensor.min(axis, keepDims)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `min$3(receiver, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.min = function (axis, keepDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return min$3(tensor, axis, keepDims);
};
55863
55864 /**
55865 * @license
55866 * Copyright 2020 Google LLC. All Rights Reserved.
55867 * Licensed under the Apache License, Version 2.0 (the "License");
55868 * you may not use this file except in compliance with the License.
55869 * You may obtain a copy of the License at
55870 *
55871 * http://www.apache.org/licenses/LICENSE-2.0
55872 *
55873 * Unless required by applicable law or agreed to in writing, software
55874 * distributed under the License is distributed on an "AS IS" BASIS,
55875 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55876 * See the License for the specific language governing permissions and
55877 * limitations under the License.
55878 * =============================================================================
55879 */
/**
 * Chained-op method `tensor.minimum(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `minimum$4(receiver, b)`.
 */
getGlobalTensorClass().prototype.minimum = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return minimum$4(tensor, b);
};
55884
55885 /**
55886 * @license
55887 * Copyright 2020 Google LLC. All Rights Reserved.
55888 * Licensed under the Apache License, Version 2.0 (the "License");
55889 * you may not use this file except in compliance with the License.
55890 * You may obtain a copy of the License at
55891 *
55892 * http://www.apache.org/licenses/LICENSE-2.0
55893 *
55894 * Unless required by applicable law or agreed to in writing, software
55895 * distributed under the License is distributed on an "AS IS" BASIS,
55896 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55897 * See the License for the specific language governing permissions and
55898 * limitations under the License.
55899 * =============================================================================
55900 */
/**
 * Chained-op method `tensor.mirrorPad(paddings, mode)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `mirrorPad$1(receiver, paddings, mode)`.
 */
getGlobalTensorClass().prototype.mirrorPad = function (paddings, mode) {
  var tensor = this;
  tensor.throwIfDisposed();
  return mirrorPad$1(tensor, paddings, mode);
};
55905
55906 /**
55907 * @license
55908 * Copyright 2020 Google LLC. All Rights Reserved.
55909 * Licensed under the Apache License, Version 2.0 (the "License");
55910 * you may not use this file except in compliance with the License.
55911 * You may obtain a copy of the License at
55912 *
55913 * http://www.apache.org/licenses/LICENSE-2.0
55914 *
55915 * Unless required by applicable law or agreed to in writing, software
55916 * distributed under the License is distributed on an "AS IS" BASIS,
55917 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55918 * See the License for the specific language governing permissions and
55919 * limitations under the License.
55920 * =============================================================================
55921 */
/**
 * Chained-op method `tensor.mod(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `mod$2(receiver, b)`.
 */
getGlobalTensorClass().prototype.mod = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return mod$2(tensor, b);
};
55926
55927 /**
55928 * @license
55929 * Copyright 2020 Google LLC. All Rights Reserved.
55930 * Licensed under the Apache License, Version 2.0 (the "License");
55931 * you may not use this file except in compliance with the License.
55932 * You may obtain a copy of the License at
55933 *
55934 * http://www.apache.org/licenses/LICENSE-2.0
55935 *
55936 * Unless required by applicable law or agreed to in writing, software
55937 * distributed under the License is distributed on an "AS IS" BASIS,
55938 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55939 * See the License for the specific language governing permissions and
55940 * limitations under the License.
55941 * =============================================================================
55942 */
/**
 * Chained-op method `tensor.mul(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns `mul(receiver, b)`.
 */
getGlobalTensorClass().prototype.mul = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return mul(tensor, b);
};
55947
55948 /**
55949 * @license
55950 * Copyright 2020 Google LLC. All Rights Reserved.
55951 * Licensed under the Apache License, Version 2.0 (the "License");
55952 * you may not use this file except in compliance with the License.
55953 * You may obtain a copy of the License at
55954 *
55955 * http://www.apache.org/licenses/LICENSE-2.0
55956 *
55957 * Unless required by applicable law or agreed to in writing, software
55958 * distributed under the License is distributed on an "AS IS" BASIS,
55959 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55960 * See the License for the specific language governing permissions and
55961 * limitations under the License.
55962 * =============================================================================
55963 */
/**
 * Chained-op method `tensor.neg()`.
 * Calls `throwIfDisposed()` on the receiver, then returns `neg$2(receiver)`.
 */
getGlobalTensorClass().prototype.neg = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return neg$2(tensor);
};
55968
55969 /**
55970 * @license
55971 * Copyright 2020 Google LLC. All Rights Reserved.
55972 * Licensed under the Apache License, Version 2.0 (the "License");
55973 * you may not use this file except in compliance with the License.
55974 * You may obtain a copy of the License at
55975 *
55976 * http://www.apache.org/licenses/LICENSE-2.0
55977 *
55978 * Unless required by applicable law or agreed to in writing, software
55979 * distributed under the License is distributed on an "AS IS" BASIS,
55980 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55981 * See the License for the specific language governing permissions and
55982 * limitations under the License.
55983 * =============================================================================
55984 */
/**
 * Chained-op method `tensor.norm(ord, axis, keepDims)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `norm(receiver, ord, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.norm = function (ord, axis, keepDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return norm(tensor, ord, axis, keepDims);
};
55989
55990 /**
55991 * @license
55992 * Copyright 2020 Google LLC. All Rights Reserved.
55993 * Licensed under the Apache License, Version 2.0 (the "License");
55994 * you may not use this file except in compliance with the License.
55995 * You may obtain a copy of the License at
55996 *
55997 * http://www.apache.org/licenses/LICENSE-2.0
55998 *
55999 * Unless required by applicable law or agreed to in writing, software
56000 * distributed under the License is distributed on an "AS IS" BASIS,
56001 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56002 * See the License for the specific language governing permissions and
56003 * limitations under the License.
56004 * =============================================================================
56005 */
/**
 * Chained-op method `tensor.notEqual(b)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `notEqual$2(receiver, b)`.
 */
getGlobalTensorClass().prototype.notEqual = function (b) {
  var tensor = this;
  tensor.throwIfDisposed();
  return notEqual$2(tensor, b);
};
56010
56011 /**
56012 * @license
56013 * Copyright 2020 Google LLC. All Rights Reserved.
56014 * Licensed under the Apache License, Version 2.0 (the "License");
56015 * you may not use this file except in compliance with the License.
56016 * You may obtain a copy of the License at
56017 *
56018 * http://www.apache.org/licenses/LICENSE-2.0
56019 *
56020 * Unless required by applicable law or agreed to in writing, software
56021 * distributed under the License is distributed on an "AS IS" BASIS,
56022 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56023 * See the License for the specific language governing permissions and
56024 * limitations under the License.
56025 * =============================================================================
56026 */
/**
 * Chained-op method `tensor.oneHot(depth, onValue = 1, offValue = 0)`.
 * The optional values are read from `arguments` (transpiled default
 * parameters), so the method's declared arity stays 1. An omitted or
 * `undefined` argument falls back to 1 / 0. Defaults are resolved first.
 * Then `throwIfDisposed()` is called on the receiver, and the result is
 * `oneHot$3(receiver, depth, onValue, offValue)`.
 */
getGlobalTensorClass().prototype.oneHot = function (depth) {
  var argCount = arguments.length;
  var onValue = 1;
  var offValue = 0;
  if (argCount > 1 && arguments[1] !== undefined) {
    onValue = arguments[1];
  }
  if (argCount > 2 && arguments[2] !== undefined) {
    offValue = arguments[2];
  }
  var tensor = this;
  tensor.throwIfDisposed();
  return oneHot$3(tensor, depth, onValue, offValue);
};
56033
56034 /**
56035 * @license
56036 * Copyright 2020 Google LLC. All Rights Reserved.
56037 * Licensed under the Apache License, Version 2.0 (the "License");
56038 * you may not use this file except in compliance with the License.
56039 * You may obtain a copy of the License at
56040 *
56041 * http://www.apache.org/licenses/LICENSE-2.0
56042 *
56043 * Unless required by applicable law or agreed to in writing, software
56044 * distributed under the License is distributed on an "AS IS" BASIS,
56045 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56046 * See the License for the specific language governing permissions and
56047 * limitations under the License.
56048 * =============================================================================
56049 */
/**
 * Chained-op method `tensor.onesLike()`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `onesLike$3(receiver)`.
 */
getGlobalTensorClass().prototype.onesLike = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return onesLike$3(tensor);
};
56054
56055 /**
56056 * @license
56057 * Copyright 2020 Google LLC. All Rights Reserved.
56058 * Licensed under the Apache License, Version 2.0 (the "License");
56059 * you may not use this file except in compliance with the License.
56060 * You may obtain a copy of the License at
56061 *
56062 * http://www.apache.org/licenses/LICENSE-2.0
56063 *
56064 * Unless required by applicable law or agreed to in writing, software
56065 * distributed under the License is distributed on an "AS IS" BASIS,
56066 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56067 * See the License for the specific language governing permissions and
56068 * limitations under the License.
56069 * =============================================================================
56070 */
/**
 * Chained-op method `tensor.pad(paddings, constantValue)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `pad(receiver, paddings, constantValue)`.
 */
getGlobalTensorClass().prototype.pad = function (paddings, constantValue) {
  var tensor = this;
  tensor.throwIfDisposed();
  return pad(tensor, paddings, constantValue);
};
56075
/**
 * Chained-op method
 * `tensor.pool(windowShape, poolingType, padding, dilationRate, strides, dimRoundingMode)`.
 * Calls `throwIfDisposed()` on the receiver, then forwards the receiver and
 * all six arguments, in order, to `pool$1`.
 */
getGlobalTensorClass().prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides, dimRoundingMode) {
  var tensor = this;
  tensor.throwIfDisposed();
  return pool$1(tensor, windowShape, poolingType, padding, dilationRate, strides, dimRoundingMode);
};
56080
56081 /**
56082 * @license
56083 * Copyright 2020 Google LLC. All Rights Reserved.
56084 * Licensed under the Apache License, Version 2.0 (the "License");
56085 * you may not use this file except in compliance with the License.
56086 * You may obtain a copy of the License at
56087 *
56088 * http://www.apache.org/licenses/LICENSE-2.0
56089 *
56090 * Unless required by applicable law or agreed to in writing, software
56091 * distributed under the License is distributed on an "AS IS" BASIS,
56092 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56093 * See the License for the specific language governing permissions and
56094 * limitations under the License.
56095 * =============================================================================
56096 */
/**
 * Chained-op method `tensor.pow(exp)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `pow$3(receiver, exp)`.
 */
getGlobalTensorClass().prototype.pow = function (exp) {
  var tensor = this;
  tensor.throwIfDisposed();
  return pow$3(tensor, exp);
};
56101
56102 /**
56103 * @license
56104 * Copyright 2020 Google LLC. All Rights Reserved.
56105 * Licensed under the Apache License, Version 2.0 (the "License");
56106 * you may not use this file except in compliance with the License.
56107 * You may obtain a copy of the License at
56108 *
56109 * http://www.apache.org/licenses/LICENSE-2.0
56110 *
56111 * Unless required by applicable law or agreed to in writing, software
56112 * distributed under the License is distributed on an "AS IS" BASIS,
56113 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56114 * See the License for the specific language governing permissions and
56115 * limitations under the License.
56116 * =============================================================================
56117 */
/**
 * Chained-op method `tensor.prelu(alpha)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `prelu$3(receiver, alpha)`.
 */
getGlobalTensorClass().prototype.prelu = function (alpha) {
  var tensor = this;
  tensor.throwIfDisposed();
  return prelu$3(tensor, alpha);
};
56122
56123 /**
56124 * @license
56125 * Copyright 2020 Google LLC. All Rights Reserved.
56126 * Licensed under the Apache License, Version 2.0 (the "License");
56127 * you may not use this file except in compliance with the License.
56128 * You may obtain a copy of the License at
56129 *
56130 * http://www.apache.org/licenses/LICENSE-2.0
56131 *
56132 * Unless required by applicable law or agreed to in writing, software
56133 * distributed under the License is distributed on an "AS IS" BASIS,
56134 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56135 * See the License for the specific language governing permissions and
56136 * limitations under the License.
56137 * =============================================================================
56138 */
/**
 * Chained-op method `tensor.prod(axis, keepDims)`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `prod$2(receiver, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.prod = function (axis, keepDims) {
  var tensor = this;
  tensor.throwIfDisposed();
  return prod$2(tensor, axis, keepDims);
};
56143
56144 /**
56145 * @license
56146 * Copyright 2020 Google LLC. All Rights Reserved.
56147 * Licensed under the Apache License, Version 2.0 (the "License");
56148 * you may not use this file except in compliance with the License.
56149 * You may obtain a copy of the License at
56150 *
56151 * http://www.apache.org/licenses/LICENSE-2.0
56152 *
56153 * Unless required by applicable law or agreed to in writing, software
56154 * distributed under the License is distributed on an "AS IS" BASIS,
56155 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56156 * See the License for the specific language governing permissions and
56157 * limitations under the License.
56158 * =============================================================================
56159 */
/**
 * Chained-op method `tensor.reciprocal()`.
 * Calls `throwIfDisposed()` on the receiver, then returns
 * `reciprocal$2(receiver)`.
 */
getGlobalTensorClass().prototype.reciprocal = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return reciprocal$2(tensor);
};
56164
56165 /**
56166 * @license
56167 * Copyright 2020 Google LLC. All Rights Reserved.
56168 * Licensed under the Apache License, Version 2.0 (the "License");
56169 * you may not use this file except in compliance with the License.
56170 * You may obtain a copy of the License at
56171 *
56172 * http://www.apache.org/licenses/LICENSE-2.0
56173 *
56174 * Unless required by applicable law or agreed to in writing, software
56175 * distributed under the License is distributed on an "AS IS" BASIS,
56176 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56177 * See the License for the specific language governing permissions and
56178 * limitations under the License.
56179 * =============================================================================
56180 */
/**
 * Chained-op method `tensor.relu()`.
 * Calls `throwIfDisposed()` on the receiver, then returns `relu$2(receiver)`.
 */
getGlobalTensorClass().prototype.relu = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return relu$2(tensor);
};
56185
56186 /**
56187 * @license
56188 * Copyright 2020 Google LLC. All Rights Reserved.
56189 * Licensed under the Apache License, Version 2.0 (the "License");
56190 * you may not use this file except in compliance with the License.
56191 * You may obtain a copy of the License at
56192 *
56193 * http://www.apache.org/licenses/LICENSE-2.0
56194 *
56195 * Unless required by applicable law or agreed to in writing, software
56196 * distributed under the License is distributed on an "AS IS" BASIS,
56197 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56198 * See the License for the specific language governing permissions and
56199 * limitations under the License.
56200 * =============================================================================
56201 */
/**
 * Chained-op method `tensor.relu6()`.
 * Calls `throwIfDisposed()` on the receiver, then returns `relu6$2(receiver)`.
 */
getGlobalTensorClass().prototype.relu6 = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return relu6$2(tensor);
};
56206
56207 /**
56208 * @license
56209 * Copyright 2020 Google LLC. All Rights Reserved.
56210 * Licensed under the Apache License, Version 2.0 (the "License");
56211 * you may not use this file except in compliance with the License.
56212 * You may obtain a copy of the License at
56213 *
56214 * http://www.apache.org/licenses/LICENSE-2.0
56215 *
56216 * Unless required by applicable law or agreed to in writing, software
56217 * distributed under the License is distributed on an "AS IS" BASIS,
56218 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56219 * See the License for the specific language governing permissions and
56220 * limitations under the License.
56221 * =============================================================================
56222 */
/**
 * Reshapes this tensor so that it takes on another tensor's shape.
 *
 * Calls `throwIfDisposed()` on the receiver first. It then reads `x.shape`
 * and returns `reshape$3(receiver, x.shape)`.
 *
 * @param x The tensor whose `shape` is used as the target shape.
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.reshapeAs = function (x) {
  var tensor = this;
  tensor.throwIfDisposed();
  var targetShape = x.shape;
  return reshape$3(tensor, targetShape);
};
56234
56235 /**
56236 * @license
56237 * Copyright 2020 Google LLC. All Rights Reserved.
56238 * Licensed under the Apache License, Version 2.0 (the "License");
56239 * you may not use this file except in compliance with the License.
56240 * You may obtain a copy of the License at
56241 *
56242 * http://www.apache.org/licenses/LICENSE-2.0
56243 *
56244 * Unless required by applicable law or agreed to in writing, software
56245 * distributed under the License is distributed on an "AS IS" BASIS,
56246 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56247 * See the License for the specific language governing permissions and
56248 * limitations under the License.
56249 * =============================================================================
56250 */
/**
 * Chained-op form of `reshape$3`.
 * Guards with `throwIfDisposed()`, then forwards to `reshape$3`.
 *
 * @param shape Target shape, passed through unchanged.
 * @returns The value returned by `reshape$3(this, shape)`.
 */
getGlobalTensorClass().prototype.reshape = function (shape) {
    var tensor = this;
    tensor.throwIfDisposed();
    return reshape$3(tensor, shape);
};
56255
56256 /**
56257 * @license
56258 * Copyright 2020 Google LLC. All Rights Reserved.
56259 * Licensed under the Apache License, Version 2.0 (the "License");
56260 * you may not use this file except in compliance with the License.
56261 * You may obtain a copy of the License at
56262 *
56263 * http://www.apache.org/licenses/LICENSE-2.0
56264 *
56265 * Unless required by applicable law or agreed to in writing, software
56266 * distributed under the License is distributed on an "AS IS" BASIS,
56267 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56268 * See the License for the specific language governing permissions and
56269 * limitations under the License.
56270 * =============================================================================
56271 */
/**
 * Chained-op form of `resizeBilinear$3`.
 * Guards with `throwIfDisposed()`, then forwards every argument unchanged.
 *
 * @param newShape2D Target size, passed through.
 * @param alignCorners Passed through.
 * @param halfPixelCenters Passed through.
 * @returns The value returned by `resizeBilinear$3`.
 */
getGlobalTensorClass().prototype.resizeBilinear = function (newShape2D, alignCorners, halfPixelCenters) {
    var tensor = this;
    tensor.throwIfDisposed();
    return resizeBilinear$3(tensor, newShape2D, alignCorners, halfPixelCenters);
};
56276
56277 /**
56278 * @license
56279 * Copyright 2020 Google LLC. All Rights Reserved.
56280 * Licensed under the Apache License, Version 2.0 (the "License");
56281 * you may not use this file except in compliance with the License.
56282 * You may obtain a copy of the License at
56283 *
56284 * http://www.apache.org/licenses/LICENSE-2.0
56285 *
56286 * Unless required by applicable law or agreed to in writing, software
56287 * distributed under the License is distributed on an "AS IS" BASIS,
56288 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56289 * See the License for the specific language governing permissions and
56290 * limitations under the License.
56291 * =============================================================================
56292 */
/**
 * Chained-op form of `resizeNearestNeighbor$2`.
 * Guards with `throwIfDisposed()`, then forwards every argument unchanged.
 *
 * The fourth parameter was previously named `halfFloatCenters`. It is renamed
 * to `halfPixelCenters` to match the public op's option name and the sibling
 * `resizeBilinear` chained op. Arguments are positional, so callers are
 * unaffected.
 *
 * @param newShape2D Target size, passed through.
 * @param alignCorners Passed through.
 * @param halfPixelCenters Passed through.
 * @returns The value returned by `resizeNearestNeighbor$2`.
 */
getGlobalTensorClass().prototype.resizeNearestNeighbor = function (newShape2D, alignCorners, halfPixelCenters) {
    this.throwIfDisposed();
    return resizeNearestNeighbor$2(this, newShape2D, alignCorners, halfPixelCenters);
};
56297
56298 /**
56299 * @license
56300 * Copyright 2020 Google LLC. All Rights Reserved.
56301 * Licensed under the Apache License, Version 2.0 (the "License");
56302 * you may not use this file except in compliance with the License.
56303 * You may obtain a copy of the License at
56304 *
56305 * http://www.apache.org/licenses/LICENSE-2.0
56306 *
56307 * Unless required by applicable law or agreed to in writing, software
56308 * distributed under the License is distributed on an "AS IS" BASIS,
56309 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56310 * See the License for the specific language governing permissions and
56311 * limitations under the License.
56312 * =============================================================================
56313 */
/**
 * Chained-op form of `reverse$2`.
 * Guards with `throwIfDisposed()`, then forwards to `reverse$2`.
 *
 * @param axis Passed through unchanged.
 * @returns The value returned by `reverse$2(this, axis)`.
 */
getGlobalTensorClass().prototype.reverse = function (axis) {
    var tensor = this;
    tensor.throwIfDisposed();
    return reverse$2(tensor, axis);
};
56318
56319 /**
56320 * @license
56321 * Copyright 2020 Google LLC. All Rights Reserved.
56322 * Licensed under the Apache License, Version 2.0 (the "License");
56323 * you may not use this file except in compliance with the License.
56324 * You may obtain a copy of the License at
56325 *
56326 * http://www.apache.org/licenses/LICENSE-2.0
56327 *
56328 * Unless required by applicable law or agreed to in writing, software
56329 * distributed under the License is distributed on an "AS IS" BASIS,
56330 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56331 * See the License for the specific language governing permissions and
56332 * limitations under the License.
56333 * =============================================================================
56334 */
/**
 * Chained-op form of `rfft`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `rfft`.
 *
 * @returns The value returned by `rfft(this)`.
 */
getGlobalTensorClass().prototype.rfft = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return rfft(tensor);
};
56339
56340 /**
56341 * @license
56342 * Copyright 2020 Google LLC. All Rights Reserved.
56343 * Licensed under the Apache License, Version 2.0 (the "License");
56344 * you may not use this file except in compliance with the License.
56345 * You may obtain a copy of the License at
56346 *
56347 * http://www.apache.org/licenses/LICENSE-2.0
56348 *
56349 * Unless required by applicable law or agreed to in writing, software
56350 * distributed under the License is distributed on an "AS IS" BASIS,
56351 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56352 * See the License for the specific language governing permissions and
56353 * limitations under the License.
56354 * =============================================================================
56355 */
/**
 * Chained-op form of `round$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `round$2`.
 *
 * @returns The value returned by `round$2(this)`.
 */
getGlobalTensorClass().prototype.round = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return round$2(tensor);
};
56360
56361 /**
56362 * @license
56363 * Copyright 2020 Google LLC. All Rights Reserved.
56364 * Licensed under the Apache License, Version 2.0 (the "License");
56365 * you may not use this file except in compliance with the License.
56366 * You may obtain a copy of the License at
56367 *
56368 * http://www.apache.org/licenses/LICENSE-2.0
56369 *
56370 * Unless required by applicable law or agreed to in writing, software
56371 * distributed under the License is distributed on an "AS IS" BASIS,
56372 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56373 * See the License for the specific language governing permissions and
56374 * limitations under the License.
56375 * =============================================================================
56376 */
/**
 * Chained-op form of `rsqrt$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `rsqrt$2`.
 *
 * @returns The value returned by `rsqrt$2(this)`.
 */
getGlobalTensorClass().prototype.rsqrt = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return rsqrt$2(tensor);
};
56381
56382 /**
56383 * @license
56384 * Copyright 2020 Google LLC. All Rights Reserved.
56385 * Licensed under the Apache License, Version 2.0 (the "License");
56386 * you may not use this file except in compliance with the License.
56387 * You may obtain a copy of the License at
56388 *
56389 * http://www.apache.org/licenses/LICENSE-2.0
56390 *
56391 * Unless required by applicable law or agreed to in writing, software
56392 * distributed under the License is distributed on an "AS IS" BASIS,
56393 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56394 * See the License for the specific language governing permissions and
56395 * limitations under the License.
56396 * =============================================================================
56397 */
/**
 * Chained-op form of `selu$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `selu$2`.
 *
 * @returns The value returned by `selu$2(this)`.
 */
getGlobalTensorClass().prototype.selu = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return selu$2(tensor);
};
56402
56403 /**
56404 * @license
56405 * Copyright 2020 Google LLC. All Rights Reserved.
56406 * Licensed under the Apache License, Version 2.0 (the "License");
56407 * you may not use this file except in compliance with the License.
56408 * You may obtain a copy of the License at
56409 *
56410 * http://www.apache.org/licenses/LICENSE-2.0
56411 *
56412 * Unless required by applicable law or agreed to in writing, software
56413 * distributed under the License is distributed on an "AS IS" BASIS,
56414 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56415 * See the License for the specific language governing permissions and
56416 * limitations under the License.
56417 * =============================================================================
56418 */
/**
 * Chained-op form of `separableConv2d$1`, with this tensor as the input.
 * Guards with `throwIfDisposed()`, then forwards every argument unchanged.
 *
 * @param depthwiseFilter Passed through.
 * @param pointwiseFilter Passed through.
 * @param strides Passed through.
 * @param pad Passed through.
 * @param dilation Passed through.
 * @param dataFormat Passed through.
 * @returns The value returned by `separableConv2d$1`.
 */
getGlobalTensorClass().prototype.separableConv2d = function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
    var input = this;
    input.throwIfDisposed();
    return separableConv2d$1(input, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat);
};
56423
56424 /**
56425 * @license
56426 * Copyright 2020 Google LLC. All Rights Reserved.
56427 * Licensed under the Apache License, Version 2.0 (the "License");
56428 * you may not use this file except in compliance with the License.
56429 * You may obtain a copy of the License at
56430 *
56431 * http://www.apache.org/licenses/LICENSE-2.0
56432 *
56433 * Unless required by applicable law or agreed to in writing, software
56434 * distributed under the License is distributed on an "AS IS" BASIS,
56435 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56436 * See the License for the specific language governing permissions and
56437 * limitations under the License.
56438 * =============================================================================
56439 */
/**
 * Chained-op form of `sigmoid$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `sigmoid$2`.
 *
 * @returns The value returned by `sigmoid$2(this)`.
 */
getGlobalTensorClass().prototype.sigmoid = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return sigmoid$2(tensor);
};
56444
56445 /**
56446 * @license
56447 * Copyright 2020 Google LLC. All Rights Reserved.
56448 * Licensed under the Apache License, Version 2.0 (the "License");
56449 * you may not use this file except in compliance with the License.
56450 * You may obtain a copy of the License at
56451 *
56452 * http://www.apache.org/licenses/LICENSE-2.0
56453 *
56454 * Unless required by applicable law or agreed to in writing, software
56455 * distributed under the License is distributed on an "AS IS" BASIS,
56456 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56457 * See the License for the specific language governing permissions and
56458 * limitations under the License.
56459 * =============================================================================
56460 */
/**
 * Chained-op form of `sign$3`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `sign$3`.
 *
 * @returns The value returned by `sign$3(this)`.
 */
getGlobalTensorClass().prototype.sign = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return sign$3(tensor);
};
56465
56466 /**
56467 * @license
56468 * Copyright 2020 Google LLC. All Rights Reserved.
56469 * Licensed under the Apache License, Version 2.0 (the "License");
56470 * you may not use this file except in compliance with the License.
56471 * You may obtain a copy of the License at
56472 *
56473 * http://www.apache.org/licenses/LICENSE-2.0
56474 *
56475 * Unless required by applicable law or agreed to in writing, software
56476 * distributed under the License is distributed on an "AS IS" BASIS,
56477 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56478 * See the License for the specific language governing permissions and
56479 * limitations under the License.
56480 * =============================================================================
56481 */
/**
 * Chained-op form of `sin$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `sin$2`.
 *
 * @returns The value returned by `sin$2(this)`.
 */
getGlobalTensorClass().prototype.sin = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return sin$2(tensor);
};
56486
56487 /**
56488 * @license
56489 * Copyright 2020 Google LLC. All Rights Reserved.
56490 * Licensed under the Apache License, Version 2.0 (the "License");
56491 * you may not use this file except in compliance with the License.
56492 * You may obtain a copy of the License at
56493 *
56494 * http://www.apache.org/licenses/LICENSE-2.0
56495 *
56496 * Unless required by applicable law or agreed to in writing, software
56497 * distributed under the License is distributed on an "AS IS" BASIS,
56498 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56499 * See the License for the specific language governing permissions and
56500 * limitations under the License.
56501 * =============================================================================
56502 */
/**
 * Chained-op form of `sinh$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `sinh$2`.
 *
 * @returns The value returned by `sinh$2(this)`.
 */
getGlobalTensorClass().prototype.sinh = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return sinh$2(tensor);
};
56507
56508 /**
56509 * @license
56510 * Copyright 2020 Google LLC. All Rights Reserved.
56511 * Licensed under the Apache License, Version 2.0 (the "License");
56512 * you may not use this file except in compliance with the License.
56513 * You may obtain a copy of the License at
56514 *
56515 * http://www.apache.org/licenses/LICENSE-2.0
56516 *
56517 * Unless required by applicable law or agreed to in writing, software
56518 * distributed under the License is distributed on an "AS IS" BASIS,
56519 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56520 * See the License for the specific language governing permissions and
56521 * limitations under the License.
56522 * =============================================================================
56523 */
/**
 * Chained-op form of `slice$2`.
 * Guards with `throwIfDisposed()`, then forwards to `slice$2`.
 *
 * @param begin Passed through unchanged.
 * @param size Passed through unchanged.
 * @returns The value returned by `slice$2(this, begin, size)`.
 */
getGlobalTensorClass().prototype.slice = function (begin, size) {
    var tensor = this;
    tensor.throwIfDisposed();
    return slice$2(tensor, begin, size);
};
56528
56529 /**
56530 * @license
56531 * Copyright 2020 Google LLC. All Rights Reserved.
56532 * Licensed under the Apache License, Version 2.0 (the "License");
56533 * you may not use this file except in compliance with the License.
56534 * You may obtain a copy of the License at
56535 *
56536 * http://www.apache.org/licenses/LICENSE-2.0
56537 *
56538 * Unless required by applicable law or agreed to in writing, software
56539 * distributed under the License is distributed on an "AS IS" BASIS,
56540 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56541 * See the License for the specific language governing permissions and
56542 * limitations under the License.
56543 * =============================================================================
56544 */
/**
 * Chained-op form of `softmax$3`.
 * Guards with `throwIfDisposed()`, then forwards to `softmax$3`.
 *
 * @param dim Passed through unchanged.
 * @returns The value returned by `softmax$3(this, dim)`.
 */
getGlobalTensorClass().prototype.softmax = function (dim) {
    var tensor = this;
    tensor.throwIfDisposed();
    return softmax$3(tensor, dim);
};
56549
56550 /**
56551 * @license
56552 * Copyright 2020 Google LLC. All Rights Reserved.
56553 * Licensed under the Apache License, Version 2.0 (the "License");
56554 * you may not use this file except in compliance with the License.
56555 * You may obtain a copy of the License at
56556 *
56557 * http://www.apache.org/licenses/LICENSE-2.0
56558 *
56559 * Unless required by applicable law or agreed to in writing, software
56560 * distributed under the License is distributed on an "AS IS" BASIS,
56561 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56562 * See the License for the specific language governing permissions and
56563 * limitations under the License.
56564 * =============================================================================
56565 */
/**
 * Chained-op form of `softplus$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `softplus$2`.
 *
 * @returns The value returned by `softplus$2(this)`.
 */
getGlobalTensorClass().prototype.softplus = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return softplus$2(tensor);
};
56570
56571 /**
56572 * @license
56573 * Copyright 2020 Google LLC. All Rights Reserved.
56574 * Licensed under the Apache License, Version 2.0 (the "License");
56575 * you may not use this file except in compliance with the License.
56576 * You may obtain a copy of the License at
56577 *
56578 * http://www.apache.org/licenses/LICENSE-2.0
56579 *
56580 * Unless required by applicable law or agreed to in writing, software
56581 * distributed under the License is distributed on an "AS IS" BASIS,
56582 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56583 * See the License for the specific language governing permissions and
56584 * limitations under the License.
56585 * =============================================================================
56586 */
/**
 * Chained-op form of `spaceToBatchND$2`.
 * Guards with `throwIfDisposed()`, then forwards to `spaceToBatchND$2`.
 *
 * @param blockShape Passed through unchanged.
 * @param paddings Passed through unchanged.
 * @returns The value returned by `spaceToBatchND$2(this, blockShape, paddings)`.
 */
getGlobalTensorClass().prototype.spaceToBatchND = function (blockShape, paddings) {
    var tensor = this;
    tensor.throwIfDisposed();
    return spaceToBatchND$2(tensor, blockShape, paddings);
};
56591
56592 /**
56593 * @license
56594 * Copyright 2020 Google LLC. All Rights Reserved.
56595 * Licensed under the Apache License, Version 2.0 (the "License");
56596 * you may not use this file except in compliance with the License.
56597 * You may obtain a copy of the License at
56598 *
56599 * http://www.apache.org/licenses/LICENSE-2.0
56600 *
56601 * Unless required by applicable law or agreed to in writing, software
56602 * distributed under the License is distributed on an "AS IS" BASIS,
56603 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56604 * See the License for the specific language governing permissions and
56605 * limitations under the License.
56606 * =============================================================================
56607 */
/**
 * Chained-op form of `split$3`.
 * Guards with `throwIfDisposed()`, then forwards to `split$3`.
 *
 * @param numOrSizeSplits Passed through unchanged.
 * @param axis Passed through unchanged.
 * @returns The value returned by `split$3(this, numOrSizeSplits, axis)`.
 */
getGlobalTensorClass().prototype.split = function (numOrSizeSplits, axis) {
    var tensor = this;
    tensor.throwIfDisposed();
    return split$3(tensor, numOrSizeSplits, axis);
};
56612
56613 /**
56614 * @license
56615 * Copyright 2020 Google LLC. All Rights Reserved.
56616 * Licensed under the Apache License, Version 2.0 (the "License");
56617 * you may not use this file except in compliance with the License.
56618 * You may obtain a copy of the License at
56619 *
56620 * http://www.apache.org/licenses/LICENSE-2.0
56621 *
56622 * Unless required by applicable law or agreed to in writing, software
56623 * distributed under the License is distributed on an "AS IS" BASIS,
56624 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56625 * See the License for the specific language governing permissions and
56626 * limitations under the License.
56627 * =============================================================================
56628 */
/**
 * Chained-op form of `sqrt$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `sqrt$2`.
 *
 * @returns The value returned by `sqrt$2(this)`.
 */
getGlobalTensorClass().prototype.sqrt = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return sqrt$2(tensor);
};
56633
56634 /**
56635 * @license
56636 * Copyright 2020 Google LLC. All Rights Reserved.
56637 * Licensed under the Apache License, Version 2.0 (the "License");
56638 * you may not use this file except in compliance with the License.
56639 * You may obtain a copy of the License at
56640 *
56641 * http://www.apache.org/licenses/LICENSE-2.0
56642 *
56643 * Unless required by applicable law or agreed to in writing, software
56644 * distributed under the License is distributed on an "AS IS" BASIS,
56645 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56646 * See the License for the specific language governing permissions and
56647 * limitations under the License.
56648 * =============================================================================
56649 */
/**
 * Chained-op form of `square$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `square$2`.
 *
 * @returns The value returned by `square$2(this)`.
 */
getGlobalTensorClass().prototype.square = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return square$2(tensor);
};
56654
56655 /**
56656 * @license
56657 * Copyright 2020 Google LLC. All Rights Reserved.
56658 * Licensed under the Apache License, Version 2.0 (the "License");
56659 * you may not use this file except in compliance with the License.
56660 * You may obtain a copy of the License at
56661 *
56662 * http://www.apache.org/licenses/LICENSE-2.0
56663 *
56664 * Unless required by applicable law or agreed to in writing, software
56665 * distributed under the License is distributed on an "AS IS" BASIS,
56666 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56667 * See the License for the specific language governing permissions and
56668 * limitations under the License.
56669 * =============================================================================
56670 */
/**
 * Chained-op form of `squaredDifference$2`, with this tensor as the first operand.
 * Guards with `throwIfDisposed()`, then forwards to `squaredDifference$2`.
 *
 * @param b Second operand, passed through unchanged.
 * @returns The value returned by `squaredDifference$2(this, b)`.
 */
getGlobalTensorClass().prototype.squaredDifference = function (b) {
    var lhs = this;
    lhs.throwIfDisposed();
    return squaredDifference$2(lhs, b);
};
56675
56676 /**
56677 * @license
56678 * Copyright 2020 Google LLC. All Rights Reserved.
56679 * Licensed under the Apache License, Version 2.0 (the "License");
56680 * you may not use this file except in compliance with the License.
56681 * You may obtain a copy of the License at
56682 *
56683 * http://www.apache.org/licenses/LICENSE-2.0
56684 *
56685 * Unless required by applicable law or agreed to in writing, software
56686 * distributed under the License is distributed on an "AS IS" BASIS,
56687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56688 * See the License for the specific language governing permissions and
56689 * limitations under the License.
56690 * =============================================================================
56691 */
/**
 * Chained-op form of `squeeze`.
 * Guards with `throwIfDisposed()`, then forwards to `squeeze`.
 *
 * @param axis Passed through unchanged.
 * @returns The value returned by `squeeze(this, axis)`.
 */
getGlobalTensorClass().prototype.squeeze = function (axis) {
    var tensor = this;
    tensor.throwIfDisposed();
    return squeeze(tensor, axis);
};
56696
/**
 * Chained-op form of `stack`, with this tensor placed first in the list.
 * Guards with `throwIfDisposed()` before building the list.
 *
 * @param x Either a single tensor (checked with `instanceof Tensor`) or an
 *     iterable of tensors that is spread after this tensor.
 * @param axis Passed through unchanged to `stack`.
 * @returns The value returned by `stack([this, ...], axis)`.
 */
getGlobalTensorClass().prototype.stack = function (x, axis) {
    var self = this;
    self.throwIfDisposed();
    var toStack;
    if (x instanceof Tensor) {
        toStack = [self, x];
    } else {
        // Not a single tensor: spread the collection after `self`.
        toStack = [self].concat(_toConsumableArray(x));
    }
    return stack(toStack, axis);
};
56702
56703 /**
56704 * @license
56705 * Copyright 2020 Google LLC. All Rights Reserved.
56706 * Licensed under the Apache License, Version 2.0 (the "License");
56707 * you may not use this file except in compliance with the License.
56708 * You may obtain a copy of the License at
56709 *
56710 * http://www.apache.org/licenses/LICENSE-2.0
56711 *
56712 * Unless required by applicable law or agreed to in writing, software
56713 * distributed under the License is distributed on an "AS IS" BASIS,
56714 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56715 * See the License for the specific language governing permissions and
56716 * limitations under the License.
56717 * =============================================================================
56718 */
/**
 * Chained-op form of `step$2`.
 * Guards with `throwIfDisposed()`, then forwards to `step$2`.
 *
 * @param alpha Passed through unchanged.
 * @returns The value returned by `step$2(this, alpha)`.
 */
getGlobalTensorClass().prototype.step = function (alpha) {
    var tensor = this;
    tensor.throwIfDisposed();
    return step$2(tensor, alpha);
};
56723
56724 /**
56725 * @license
56726 * Copyright 2020 Google LLC. All Rights Reserved.
56727 * Licensed under the Apache License, Version 2.0 (the "License");
56728 * you may not use this file except in compliance with the License.
56729 * You may obtain a copy of the License at
56730 *
56731 * http://www.apache.org/licenses/LICENSE-2.0
56732 *
56733 * Unless required by applicable law or agreed to in writing, software
56734 * distributed under the License is distributed on an "AS IS" BASIS,
56735 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56736 * See the License for the specific language governing permissions and
56737 * limitations under the License.
56738 * =============================================================================
56739 */
/**
 * Chained-op form of `stridedSlice$2`.
 * Guards with `throwIfDisposed()`, then forwards every argument unchanged.
 *
 * @param begin Passed through.
 * @param end Passed through.
 * @param strides Passed through.
 * @param beginMask Passed through.
 * @param endMask Passed through.
 * @param ellipsisMask Passed through.
 * @param newAxisMask Passed through.
 * @param shrinkAxisMask Passed through.
 * @returns The value returned by `stridedSlice$2`.
 */
getGlobalTensorClass().prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
    var tensor = this;
    tensor.throwIfDisposed();
    return stridedSlice$2(tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
};
56744
56745 /**
56746 * @license
56747 * Copyright 2020 Google LLC. All Rights Reserved.
56748 * Licensed under the Apache License, Version 2.0 (the "License");
56749 * you may not use this file except in compliance with the License.
56750 * You may obtain a copy of the License at
56751 *
56752 * http://www.apache.org/licenses/LICENSE-2.0
56753 *
56754 * Unless required by applicable law or agreed to in writing, software
56755 * distributed under the License is distributed on an "AS IS" BASIS,
56756 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56757 * See the License for the specific language governing permissions and
56758 * limitations under the License.
56759 * =============================================================================
56760 */
/**
 * Chained-op form of `sub$2`, with this tensor as the first operand.
 * Guards with `throwIfDisposed()`, then forwards to `sub$2`.
 *
 * @param b Second operand, passed through unchanged.
 * @returns The value returned by `sub$2(this, b)`.
 */
getGlobalTensorClass().prototype.sub = function (b) {
    var lhs = this;
    lhs.throwIfDisposed();
    return sub$2(lhs, b);
};
56765
56766 /**
56767 * @license
56768 * Copyright 2020 Google LLC. All Rights Reserved.
56769 * Licensed under the Apache License, Version 2.0 (the "License");
56770 * you may not use this file except in compliance with the License.
56771 * You may obtain a copy of the License at
56772 *
56773 * http://www.apache.org/licenses/LICENSE-2.0
56774 *
56775 * Unless required by applicable law or agreed to in writing, software
56776 * distributed under the License is distributed on an "AS IS" BASIS,
56777 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56778 * See the License for the specific language governing permissions and
56779 * limitations under the License.
56780 * =============================================================================
56781 */
/**
 * Chained-op form of `sum$3`.
 * Guards with `throwIfDisposed()`, then forwards to `sum$3`.
 *
 * @param axis Passed through unchanged.
 * @param keepDims Passed through unchanged.
 * @returns The value returned by `sum$3(this, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.sum = function (axis, keepDims) {
    var tensor = this;
    tensor.throwIfDisposed();
    return sum$3(tensor, axis, keepDims);
};
56786
56787 /**
56788 * @license
56789 * Copyright 2020 Google LLC. All Rights Reserved.
56790 * Licensed under the Apache License, Version 2.0 (the "License");
56791 * you may not use this file except in compliance with the License.
56792 * You may obtain a copy of the License at
56793 *
56794 * http://www.apache.org/licenses/LICENSE-2.0
56795 *
56796 * Unless required by applicable law or agreed to in writing, software
56797 * distributed under the License is distributed on an "AS IS" BASIS,
56798 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56799 * See the License for the specific language governing permissions and
56800 * limitations under the License.
56801 * =============================================================================
56802 */
/**
 * Chained-op form of `tan$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `tan$2`.
 *
 * @returns The value returned by `tan$2(this)`.
 */
getGlobalTensorClass().prototype.tan = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return tan$2(tensor);
};
56807
56808 /**
56809 * @license
56810 * Copyright 2020 Google LLC. All Rights Reserved.
56811 * Licensed under the Apache License, Version 2.0 (the "License");
56812 * you may not use this file except in compliance with the License.
56813 * You may obtain a copy of the License at
56814 *
56815 * http://www.apache.org/licenses/LICENSE-2.0
56816 *
56817 * Unless required by applicable law or agreed to in writing, software
56818 * distributed under the License is distributed on an "AS IS" BASIS,
56819 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56820 * See the License for the specific language governing permissions and
56821 * limitations under the License.
56822 * =============================================================================
56823 */
/**
 * Chained-op form of `tanh$2`.
 * Guards with `throwIfDisposed()`, then forwards this tensor to `tanh$2`.
 *
 * @returns The value returned by `tanh$2(this)`.
 */
getGlobalTensorClass().prototype.tanh = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    return tanh$2(tensor);
};
56828
56829 /**
56830 * @license
56831 * Copyright 2020 Google LLC. All Rights Reserved.
56832 * Licensed under the Apache License, Version 2.0 (the "License");
56833 * you may not use this file except in compliance with the License.
56834 * You may obtain a copy of the License at
56835 *
56836 * http://www.apache.org/licenses/LICENSE-2.0
56837 *
56838 * Unless required by applicable law or agreed to in writing, software
56839 * distributed under the License is distributed on an "AS IS" BASIS,
56840 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56841 * See the License for the specific language governing permissions and
56842 * limitations under the License.
56843 * =============================================================================
56844 */
/**
 * Chained-op form of `tile$3`.
 * Guards with `throwIfDisposed()`, then forwards to `tile$3`.
 *
 * @param reps Passed through unchanged.
 * @returns The value returned by `tile$3(this, reps)`.
 */
getGlobalTensorClass().prototype.tile = function (reps) {
    var tensor = this;
    tensor.throwIfDisposed();
    return tile$3(tensor, reps);
};
56849
56850 /**
56851 * @license
56852 * Copyright 2020 Google LLC. All Rights Reserved.
56853 * Licensed under the Apache License, Version 2.0 (the "License");
56854 * you may not use this file except in compliance with the License.
56855 * You may obtain a copy of the License at
56856 *
56857 * http://www.apache.org/licenses/LICENSE-2.0
56858 *
56859 * Unless required by applicable law or agreed to in writing, software
56860 * distributed under the License is distributed on an "AS IS" BASIS,
56861 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56862 * See the License for the specific language governing permissions and
56863 * limitations under the License.
56864 * =============================================================================
56865 */
/**
 * Casts this tensor to dtype `bool`.
 * Guards with `throwIfDisposed()`, then returns `cast$3(this, 'bool')`.
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.toBool = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    var targetDtype = 'bool';
    return cast$3(tensor, targetDtype);
};
56875
56876 /**
56877 * @license
56878 * Copyright 2020 Google LLC. All Rights Reserved.
56879 * Licensed under the Apache License, Version 2.0 (the "License");
56880 * you may not use this file except in compliance with the License.
56881 * You may obtain a copy of the License at
56882 *
56883 * http://www.apache.org/licenses/LICENSE-2.0
56884 *
56885 * Unless required by applicable law or agreed to in writing, software
56886 * distributed under the License is distributed on an "AS IS" BASIS,
56887 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56888 * See the License for the specific language governing permissions and
56889 * limitations under the License.
56890 * =============================================================================
56891 */
/**
 * Casts this tensor to dtype `float32`.
 * Guards with `throwIfDisposed()`, then returns `cast$3(this, 'float32')`.
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.toFloat = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    var targetDtype = 'float32';
    return cast$3(tensor, targetDtype);
};
56901
56902 /**
56903 * @license
56904 * Copyright 2020 Google LLC. All Rights Reserved.
56905 * Licensed under the Apache License, Version 2.0 (the "License");
56906 * you may not use this file except in compliance with the License.
56907 * You may obtain a copy of the License at
56908 *
56909 * http://www.apache.org/licenses/LICENSE-2.0
56910 *
56911 * Unless required by applicable law or agreed to in writing, software
56912 * distributed under the License is distributed on an "AS IS" BASIS,
56913 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56914 * See the License for the specific language governing permissions and
56915 * limitations under the License.
56916 * =============================================================================
56917 */
/**
 * Casts this tensor to dtype `int32`.
 * Guards with `throwIfDisposed()`, then returns `cast$3(this, 'int32')`.
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.toInt = function () {
    var tensor = this;
    tensor.throwIfDisposed();
    var targetDtype = 'int32';
    return cast$3(tensor, targetDtype);
};
56927
56928 /**
56929 * @license
56930 * Copyright 2020 Google LLC. All Rights Reserved.
56931 * Licensed under the Apache License, Version 2.0 (the "License");
56932 * you may not use this file except in compliance with the License.
56933 * You may obtain a copy of the License at
56934 *
56935 * http://www.apache.org/licenses/LICENSE-2.0
56936 *
56937 * Unless required by applicable law or agreed to in writing, software
56938 * distributed under the License is distributed on an "AS IS" BASIS,
56939 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56940 * See the License for the specific language governing permissions and
56941 * limitations under the License.
56942 * =============================================================================
56943 */
/**
 * Chained-op form of `topk`.
 * Guards with `throwIfDisposed()`, then forwards to `topk`.
 *
 * @param k Passed through unchanged.
 * @param sorted Passed through unchanged.
 * @returns The value returned by `topk(this, k, sorted)`.
 */
getGlobalTensorClass().prototype.topk = function (k, sorted) {
    var tensor = this;
    tensor.throwIfDisposed();
    return topk(tensor, k, sorted);
};
56948
56949 /**
56950 * @license
56951 * Copyright 2020 Google LLC. All Rights Reserved.
56952 * Licensed under the Apache License, Version 2.0 (the "License");
56953 * you may not use this file except in compliance with the License.
56954 * You may obtain a copy of the License at
56955 *
56956 * http://www.apache.org/licenses/LICENSE-2.0
56957 *
56958 * Unless required by applicable law or agreed to in writing, software
56959 * distributed under the License is distributed on an "AS IS" BASIS,
56960 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56961 * See the License for the specific language governing permissions and
56962 * limitations under the License.
56963 * =============================================================================
56964 */
/**
 * Tensor-method form of `transpose$2`: checks that this tensor is still alive,
 * then forwards the optional permutation `perm` unchanged.
 */
getGlobalTensorClass().prototype.transpose = function (perm) {
  var tensor = this;
  tensor.throwIfDisposed();
  return transpose$2(tensor, perm);
};
56969
56970 /**
56971 * @license
56972 * Copyright 2020 Google LLC. All Rights Reserved.
56973 * Licensed under the Apache License, Version 2.0 (the "License");
56974 * you may not use this file except in compliance with the License.
56975 * You may obtain a copy of the License at
56976 *
56977 * http://www.apache.org/licenses/LICENSE-2.0
56978 *
56979 * Unless required by applicable law or agreed to in writing, software
56980 * distributed under the License is distributed on an "AS IS" BASIS,
56981 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56982 * See the License for the specific language governing permissions and
56983 * limitations under the License.
56984 * =============================================================================
56985 */
/**
 * Tensor-method form of `unique$3`: checks that this tensor is still alive,
 * then forwards the optional `axis` unchanged.
 */
getGlobalTensorClass().prototype.unique = function (axis) {
  var tensor = this;
  tensor.throwIfDisposed();
  return unique$3(tensor, axis);
};
56990
56991 /**
56992 * @license
56993 * Copyright 2020 Google LLC. All Rights Reserved.
56994 * Licensed under the Apache License, Version 2.0 (the "License");
56995 * you may not use this file except in compliance with the License.
56996 * You may obtain a copy of the License at
56997 *
56998 * http://www.apache.org/licenses/LICENSE-2.0
56999 *
57000 * Unless required by applicable law or agreed to in writing, software
57001 * distributed under the License is distributed on an "AS IS" BASIS,
57002 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57003 * See the License for the specific language governing permissions and
57004 * limitations under the License.
57005 * =============================================================================
57006 */
/**
 * Tensor-method form of `unsortedSegmentSum$2`: checks that this tensor is
 * still alive, then forwards `segmentIds` and `numSegments` unchanged.
 */
getGlobalTensorClass().prototype.unsortedSegmentSum = function (segmentIds, numSegments) {
  var tensor = this;
  tensor.throwIfDisposed();
  return unsortedSegmentSum$2(tensor, segmentIds, numSegments);
};
57011
57012 /**
57013 * @license
57014 * Copyright 2020 Google LLC. All Rights Reserved.
57015 * Licensed under the Apache License, Version 2.0 (the "License");
57016 * you may not use this file except in compliance with the License.
57017 * You may obtain a copy of the License at
57018 *
57019 * http://www.apache.org/licenses/LICENSE-2.0
57020 *
57021 * Unless required by applicable law or agreed to in writing, software
57022 * distributed under the License is distributed on an "AS IS" BASIS,
57023 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57024 * See the License for the specific language governing permissions and
57025 * limitations under the License.
57026 * =============================================================================
57027 */
/**
 * Tensor-method form of `unstack`: checks that this tensor is still alive,
 * then forwards the optional `axis` unchanged.
 */
getGlobalTensorClass().prototype.unstack = function (axis) {
  var tensor = this;
  tensor.throwIfDisposed();
  return unstack(tensor, axis);
};
57032
57033 /**
57034 * @license
57035 * Copyright 2020 Google LLC. All Rights Reserved.
57036 * Licensed under the Apache License, Version 2.0 (the "License");
57037 * you may not use this file except in compliance with the License.
57038 * You may obtain a copy of the License at
57039 *
57040 * http://www.apache.org/licenses/LICENSE-2.0
57041 *
57042 * Unless required by applicable law or agreed to in writing, software
57043 * distributed under the License is distributed on an "AS IS" BASIS,
57044 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57045 * See the License for the specific language governing permissions and
57046 * limitations under the License.
57047 * =============================================================================
57048 */
/**
 * Tensor-method form of `where`: checks that this tensor is still alive, then
 * calls `where(condition, this, x)`. This tensor goes in the middle
 * position, between `condition` and `x`.
 */
getGlobalTensorClass().prototype.where = function (condition, x) {
  var tensor = this;
  tensor.throwIfDisposed();
  return where(condition, tensor, x);
};
57053
57054 /**
57055 * @license
57056 * Copyright 2020 Google LLC. All Rights Reserved.
57057 * Licensed under the Apache License, Version 2.0 (the "License");
57058 * you may not use this file except in compliance with the License.
57059 * You may obtain a copy of the License at
57060 *
57061 * http://www.apache.org/licenses/LICENSE-2.0
57062 *
57063 * Unless required by applicable law or agreed to in writing, software
57064 * distributed under the License is distributed on an "AS IS" BASIS,
57065 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57066 * See the License for the specific language governing permissions and
57067 * limitations under the License.
57068 * =============================================================================
57069 */
/**
 * Tensor-method form of `zerosLike$3`: checks that this tensor is still alive,
 * then passes it as the sole argument.
 */
getGlobalTensorClass().prototype.zerosLike = function () {
  var tensor = this;
  tensor.throwIfDisposed();
  return zerosLike$3(tensor);
};
57074
57075 /**
57076 * @license
57077 * Copyright 2020 Google LLC. All Rights Reserved.
57078 * Licensed under the Apache License, Version 2.0 (the "License");
57079 * you may not use this file except in compliance with the License.
57080 * You may obtain a copy of the License at
57081 *
57082 * http://www.apache.org/licenses/LICENSE-2.0
57083 *
57084 * Unless required by applicable law or agreed to in writing, software
57085 * distributed under the License is distributed on an "AS IS" BASIS,
57086 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57087 * See the License for the specific language governing permissions and
57088 * limitations under the License.
57089 * =============================================================================
57090 */
57091
57092 /**
57093 * @license
57094 * Copyright 2018 Google LLC
57095 *
57096 * Use of this source code is governed by an MIT-style
57097 * license that can be found in the LICENSE file or at
57098 * https://opensource.org/licenses/MIT.
57099 * =============================================================================
57100 */
57101 /**
57102 * Explicit error types.
57103 *
57104 * See the following link for more information about why the code includes
57105 * calls to setPrototypeOf:
57106 *
57107 * https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work
57108 */
57109 // tslint:enable
57110 /**
57111 * Equivalent of Python's AttributeError.
57112 */
57113 var AttributeError = /*#__PURE__*/function (_Error) {
57114 _inherits(AttributeError, _Error);
57115 var _super = _createSuper(AttributeError);
57116 function AttributeError(message) {
57117 var _this;
57118 _classCallCheck(this, AttributeError);
57119 _this = _super.call(this, message);
57120 // Set the prototype explicitly.
57121 Object.setPrototypeOf(_assertThisInitialized(_this), AttributeError.prototype);
57122 return _this;
57123 }
57124 return _createClass(AttributeError);
57125 }( /*#__PURE__*/_wrapNativeSuper(Error));
57126 /**
57127 * Equivalent of Python's RuntimeError.
57128 */
57129 var RuntimeError = /*#__PURE__*/function (_Error2) {
57130 _inherits(RuntimeError, _Error2);
57131 var _super2 = _createSuper(RuntimeError);
57132 function RuntimeError(message) {
57133 var _this2;
57134 _classCallCheck(this, RuntimeError);
57135 _this2 = _super2.call(this, message);
57136 // Set the prototype explicitly.
57137 Object.setPrototypeOf(_assertThisInitialized(_this2), RuntimeError.prototype);
57138 return _this2;
57139 }
57140 return _createClass(RuntimeError);
57141 }( /*#__PURE__*/_wrapNativeSuper(Error));
57142 /**
57143 * Equivalent of Python's ValueError.
57144 */
57145 var ValueError = /*#__PURE__*/function (_Error3) {
57146 _inherits(ValueError, _Error3);
57147 var _super3 = _createSuper(ValueError);
57148 function ValueError(message) {
57149 var _this3;
57150 _classCallCheck(this, ValueError);
57151 _this3 = _super3.call(this, message);
57152 // Set the prototype explicitly.
57153 Object.setPrototypeOf(_assertThisInitialized(_this3), ValueError.prototype);
57154 return _this3;
57155 }
57156 return _createClass(ValueError);
57157 }( /*#__PURE__*/_wrapNativeSuper(Error));
57158 /**
57159 * Equivalent of Python's NotImplementedError.
57160 */
57161 var NotImplementedError = /*#__PURE__*/function (_Error4) {
57162 _inherits(NotImplementedError, _Error4);
57163 var _super4 = _createSuper(NotImplementedError);
57164 function NotImplementedError(message) {
57165 var _this4;
57166 _classCallCheck(this, NotImplementedError);
57167 _this4 = _super4.call(this, message);
57168 // Set the prototype explicitly.
57169 Object.setPrototypeOf(_assertThisInitialized(_this4), NotImplementedError.prototype);
57170 return _this4;
57171 }
57172 return _createClass(NotImplementedError);
57173 }( /*#__PURE__*/_wrapNativeSuper(Error));
57174 /**
57175 * Equivalent of Python's AssertionError.
57176 */
57177 var AssertionError = /*#__PURE__*/function (_Error5) {
57178 _inherits(AssertionError, _Error5);
57179 var _super5 = _createSuper(AssertionError);
57180 function AssertionError(message) {
57181 var _this5;
57182 _classCallCheck(this, AssertionError);
57183 _this5 = _super5.call(this, message);
57184 // Set the prototype explicitly.
57185 Object.setPrototypeOf(_assertThisInitialized(_this5), AssertionError.prototype);
57186 return _this5;
57187 }
57188 return _createClass(AssertionError);
57189 }( /*#__PURE__*/_wrapNativeSuper(Error));
57190 /**
57191 * Equivalent of Python's IndexError.
57192 */
57193 var IndexError = /*#__PURE__*/function (_Error6) {
57194 _inherits(IndexError, _Error6);
57195 var _super6 = _createSuper(IndexError);
57196 function IndexError(message) {
57197 var _this6;
57198 _classCallCheck(this, IndexError);
57199 _this6 = _super6.call(this, message);
57200 // Set the prototype explicitly.
57201 Object.setPrototypeOf(_assertThisInitialized(_this6), IndexError.prototype);
57202 return _this6;
57203 }
57204 return _createClass(IndexError);
57205 }( /*#__PURE__*/_wrapNativeSuper(Error));
57206
57207 /**
57208 * @license
57209 * Copyright 2022 Google LLC
57210 *
57211 * Use of this source code is governed by an MIT-style
57212 * license that can be found in the LICENSE file or at
57213 * https://opensource.org/licenses/MIT.
57214 * =============================================================================
57215 */
57216 /**
57217 * LruCache: A mapping from the String to T. If the number of the entries is
57218 * exceeding the `maxEntries`, the LruCache will delete the least recently
57219 * used entry.
57220 */
57221 var LruCache = /*#__PURE__*/function () {
57222 function LruCache(maxEntries) {
57223 _classCallCheck(this, LruCache);
57224 this.maxEntries = maxEntries || 100;
57225 this.cache = new Map();
57226 }
57227 /**
57228 * Get the entry for the key and mark it as used recently.
57229 */
57230 _createClass(LruCache, [{
57231 key: "get",
57232 value: function get(key) {
57233 var entry;
57234 if (this.cache.has(key)) {
57235 entry = this.cache.get(key);
57236 this.cache.delete(key);
57237 this.cache.set(key, entry);
57238 }
57239 return entry;
57240 }
57241 /**
57242 * Put the entry into the cache. If the key already existed, mark the key as
57243 * used recently.
57244 */
57245 }, {
57246 key: "put",
57247 value: function put(key, value) {
57248 if (this.cache.has(key)) {
57249 this.cache.delete(key);
57250 } else if (this.cache.size >= this.maxEntries) {
57251 var keyToDelete = this.cache.keys().next().value;
57252 this.cache.delete(keyToDelete);
57253 }
57254 this.cache.set(key, value);
57255 }
57256 /**
57257 * Get the MaxEntries of the cache.
57258 */
57259 }, {
57260 key: "getMaxEntries",
57261 value: function getMaxEntries() {
57262 return this.maxEntries;
57263 }
57264 /**
57265 * Set the MaxEntries of the cache. If the maxEntries is decreased, reduce
57266 * entries in the cache.
57267 */
57268 }, {
57269 key: "setMaxEntries",
57270 value: function setMaxEntries(maxEntries) {
57271 if (maxEntries < 0) {
57272 throw new Error("The maxEntries of LRU caches must be at least 0, but got ".concat(maxEntries, "."));
57273 }
57274 if (this.maxEntries > maxEntries) {
57275 for (var i = 0; i < this.maxEntries - maxEntries; i++) {
57276 var keyToDelete = this.cache.keys().next().value;
57277 this.cache.delete(keyToDelete);
57278 }
57279 }
57280 this.maxEntries = maxEntries;
57281 }
57282 }]);
57283 return LruCache;
57284 }();
57285
57286 // tslint:enable
57287 /**
57288 * If `value` is an Array, equivalent to Python's `value * numValues`.
57289 * If `value` is not an Array, equivalent to Python's `[value] * numValues`
57290 */
57291 // tslint:disable-next-line:no-any
57292 function pyListRepeat(value, numValues) {
57293 if (Array.isArray(value)) {
57294 // tslint:disable-next-line:no-any
57295 var newArray = [];
57296 for (var i = 0; i < numValues; i++) {
57297 newArray = newArray.concat(value);
57298 }
57299 return newArray;
57300 } else {
57301 var _newArray = new Array(numValues);
57302 _newArray.fill(value);
57303 return _newArray;
57304 }
57305 }
/**
 * Throws an `AssertionError` carrying `message` unless `val` is truthy.
 *
 * @param val Value whose truthiness is checked.
 * @param message Message attached to the thrown error.
 */
function assert(val, message) {
  if (val) {
    return;
  }
  throw new AssertionError(message);
}
57311 /**
57312 * Count the number of elements of the `array` that are equal to `reference`.
57313 */
57314 function count(array, refernce) {
57315 var counter = 0;
57316 var _iterator = _createForOfIteratorHelper(array),
57317 _step;
57318 try {
57319 for (_iterator.s(); !(_step = _iterator.n()).done;) {
57320 var item = _step.value;
57321 if (item === refernce) {
57322 counter++;
57323 }
57324 }
57325 } catch (err) {
57326 _iterator.e(err);
57327 } finally {
57328 _iterator.f();
57329 }
57330 return counter;
57331 }
57332 /**
57333 * If an array is of length 1, just return the first element. Otherwise, return
57334 * the full array.
57335 * @param tensors
57336 */
57337 function singletonOrArray(xs) {
57338 if (xs.length === 1) {
57339 return xs[0];
57340 }
57341 return xs;
57342 }
57343 /**
57344 * Normalizes a list/tensor into a list.
57345 *
57346 * If a tensor is passed, we return
57347 * a list of size 1 containing the tensor.
57348 *
57349 * @param x target object to be normalized.
57350 */
57351 // tslint:disable-next-line:no-any
57352 function toList(x) {
57353 if (Array.isArray(x)) {
57354 return x;
57355 }
57356 return [x];
57357 }
57358 /**
57359 * Generate a UID for a list
57360 */
57361 // tslint:disable-next-line:no-any
57362 function objectListUid(objs) {
57363 var objectList = toList(objs);
57364 var retVal = '';
57365 var _iterator2 = _createForOfIteratorHelper(objectList),
57366 _step2;
57367 try {
57368 for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
57369 var obj = _step2.value;
57370 if (obj.id == null) {
57371 throw new ValueError("Object ".concat(obj, " passed to objectListUid without an id"));
57372 }
57373 if (retVal !== '') {
57374 retVal = retVal + ', ';
57375 }
57376 retVal = "".concat(retVal).concat(Math.abs(obj.id));
57377 }
57378 } catch (err) {
57379 _iterator2.e(err);
57380 } finally {
57381 _iterator2.f();
57382 }
57383 return retVal;
57384 }
57385 /**
57386 * Converts string to snake-case.
57387 * @param name
57388 */
57389 function toSnakeCase(name) {
57390 var intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
57391 var insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
57392 /*
57393 If the class is private the name starts with "_" which is not secure
57394 for creating scopes. We prefix the name with "private" in this case.
57395 */
57396 if (insecure[0] !== '_') {
57397 return insecure;
57398 }
57399 return 'private' + insecure;
57400 }
/**
 * Converts a snake_case identifier to camelCase.
 *
 * Each run of underscores is removed and the character following it is
 * upper-cased. Strings of length <= 1, or without any underscore, are
 * returned unchanged.
 *
 * @param identifier The identifier to convert.
 * @returns The camelCase identifier.
 */
function toCamelCase(identifier) {
  var nothingToConvert = identifier.length <= 1 || identifier.indexOf('_') === -1;
  if (nothingToConvert) {
    return identifier;
  }
  return identifier.replace(/[_]+(\w|$)/g, function (match, nextChar) {
    return nextChar.toUpperCase();
  });
}
// Module-level registry of custom objects. `deserializeKerasObject` in this
// file looks names up here and temporarily merges its `customObjects`
// argument into it. Afterwards it reassigns this variable to a copy of the
// previous contents.
// tslint:disable-next-line:no-any
var _GLOBAL_CUSTOM_OBJECTS = {};
/**
 * Serializes an object into a Keras-style `{className, config}` dictionary.
 *
 * @param instance Object exposing `getClassName()` and `getConfig()`.
 * @returns The dictionary, or null when `instance` is null or undefined.
 */
function serializeKerasObject(instance) {
  if (instance === undefined || instance === null) {
    return null;
  }
  return {
    className: instance.getClassName(),
    config: instance.getConfig()
  };
}
57425 /**
57426 * Replace ndarray-style scalar objects in serialization objects with numbers.
57427 *
57428 * Background: In some versions of tf.keras, certain scalar values in the HDF5
57429 * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,
57430 * where in `num` is a plain number. This method converts such serialization
57431 * to a `number`.
57432 *
57433 * @param config The keras-format serialization object to be processed
57434 * (in place).
57435 */
57436 function convertNDArrayScalarsInConfig(config) {
57437 if (config == null || _typeof(config) !== 'object') {
57438 return;
57439 } else if (Array.isArray(config)) {
57440 config.forEach(function (configItem) {
57441 return convertNDArrayScalarsInConfig(configItem);
57442 });
57443 } else {
57444 var fields = Object.keys(config);
57445 for (var _i = 0, _fields = fields; _i < _fields.length; _i++) {
57446 var field = _fields[_i];
57447 var value = config[field];
57448 if (value != null && _typeof(value) === 'object') {
57449 if (!Array.isArray(value) && value['type'] === 'ndarray' && typeof value['value'] === 'number') {
57450 config[field] = value['value'];
57451 } else {
57452 convertNDArrayScalarsInConfig(value);
57453 }
57454 }
57455 }
57456 }
57457 }
57458 /**
57459 * Deserialize a saved Keras Object
57460 * @param identifier either a string ID or a saved Keras dictionary
57461 * @param moduleObjects a list of Python class names to object constructors
57462 * @param customObjects a list of Python class names to object constructors
57463 * @param printableModuleName debug text for the object being reconstituted
57464 * @param fastWeightInit Optional flag to use fast weight initialization
57465 * during deserialization. This is applicable to cases in which
57466 * the initialization will be immediately overwritten by loaded weight
57467 * values. Default: `false`.
57468 * @returns a TensorFlow.js Layers object
57469 */
57470 // tslint:disable:no-any
57471 function deserializeKerasObject(identifier) {
57472 var moduleObjects = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
57473 var customObjects = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
57474 var printableModuleName = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'object';
57475 var fastWeightInit = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
57476 // tslint:enable
57477 if (typeof identifier === 'string') {
57478 var functionName = identifier;
57479 var fn;
57480 if (functionName in customObjects) {
57481 fn = customObjects[functionName];
57482 } else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
57483 fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
57484 } else {
57485 fn = moduleObjects[functionName];
57486 if (fn == null) {
57487 throw new ValueError("Unknown ".concat(printableModuleName, ": ").concat(identifier, ". ") + "This may be due to one of the following reasons:\n" + "1. The ".concat(printableModuleName, " is defined in Python, in which ") + "case it needs to be ported to TensorFlow.js or your JavaScript " + "code.\n" + "2. The custom ".concat(printableModuleName, " is defined in JavaScript, ") + "but is not registered properly with " + "tf.serialization.registerClass().");
57488 // TODO(cais): Add link to tutorial page on custom layers.
57489 }
57490 }
57491
57492 return fn;
57493 } else {
57494 // In this case we are dealing with a Keras config dictionary.
57495 var config = identifier;
57496 if (config['className'] == null || config['config'] == null) {
57497 throw new ValueError("".concat(printableModuleName, ": Improper config format: ") + "".concat(JSON.stringify(config), ".\n") + "'className' and 'config' must set.");
57498 }
57499 var className = config['className'];
57500 var cls, fromConfig;
57501 if (className in customObjects) {
57502 var _customObjects$classN = _slicedToArray(customObjects[className], 2);
57503 cls = _customObjects$classN[0];
57504 fromConfig = _customObjects$classN[1];
57505 } else if (className in _GLOBAL_CUSTOM_OBJECTS) {
57506 var _GLOBAL_CUSTOM_OBJECT = _slicedToArray(_GLOBAL_CUSTOM_OBJECTS['className'], 2);
57507 cls = _GLOBAL_CUSTOM_OBJECT[0];
57508 fromConfig = _GLOBAL_CUSTOM_OBJECT[1];
57509 } else if (className in moduleObjects) {
57510 var _moduleObjects$classN = _slicedToArray(moduleObjects[className], 2);
57511 cls = _moduleObjects$classN[0];
57512 fromConfig = _moduleObjects$classN[1];
57513 }
57514 if (cls == null) {
57515 throw new ValueError("Unknown ".concat(printableModuleName, ": ").concat(className, ". ") + "This may be due to one of the following reasons:\n" + "1. The ".concat(printableModuleName, " is defined in Python, in which ") + "case it needs to be ported to TensorFlow.js or your JavaScript " + "code.\n" + "2. The custom ".concat(printableModuleName, " is defined in JavaScript, ") + "but is not registered properly with " + "tf.serialization.registerClass().");
57516 // TODO(cais): Add link to tutorial page on custom layers.
57517 }
57518
57519 if (fromConfig != null) {
57520 // Porting notes: Instead of checking to see whether fromConfig accepts
57521 // customObjects, we create a customObjects dictionary and tack it on to
57522 // config['config'] as config['config'].customObjects. Objects can use it,
57523 // if they want.
57524 // tslint:disable-next-line:no-any
57525 var customObjectsCombined = {};
57526 for (var _i2 = 0, _Object$keys = Object.keys(_GLOBAL_CUSTOM_OBJECTS); _i2 < _Object$keys.length; _i2++) {
57527 var key = _Object$keys[_i2];
57528 customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
57529 }
57530 for (var _i3 = 0, _Object$keys2 = Object.keys(customObjects); _i3 < _Object$keys2.length; _i3++) {
57531 var _key = _Object$keys2[_i3];
57532 customObjectsCombined[_key] = customObjects[_key];
57533 }
57534 // Add the customObjects to config
57535 var nestedConfig = config['config'];
57536 nestedConfig['customObjects'] = customObjectsCombined;
57537 var backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
57538 for (var _i4 = 0, _Object$keys3 = Object.keys(customObjects); _i4 < _Object$keys3.length; _i4++) {
57539 var _key2 = _Object$keys3[_i4];
57540 _GLOBAL_CUSTOM_OBJECTS[_key2] = customObjects[_key2];
57541 }
57542 convertNDArrayScalarsInConfig(config['config']);
57543 var returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
57544 _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
57545 return returnObj;
57546 } else {
57547 // Then `cls` may be a function returning a class.
57548 // In this case by convention `config` holds
57549 // the kwargs of the function.
57550 var _backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
57551 for (var _i5 = 0, _Object$keys4 = Object.keys(customObjects); _i5 < _Object$keys4.length; _i5++) {
57552 var _key3 = _Object$keys4[_i5];
57553 _GLOBAL_CUSTOM_OBJECTS[_key3] = customObjects[_key3];
57554 }
57555 // In python this is **config['config'], for tfjs-layers we require
57556 // classes that use this fall-through construction method to take
57557 // a config interface that mimics the expansion of named parameters.
57558 var _returnObj = new cls(config['config']);
57559 _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, _backupCustomObjects);
57560 return _returnObj;
57561 }
57562 }
57563 }
57564 /**
57565 * Compares two numbers for sorting.
57566 * @param a
57567 * @param b
57568 */
57569 function numberCompare(a, b) {
57570 return a < b ? -1 : a > b ? 1 : 0;
57571 }
57572 /**
57573 * Comparison of two numbers for reverse sorting.
57574 * @param a
57575 * @param b
57576 */
57577 function reverseNumberCompare(a, b) {
57578 return -1 * numberCompare(a, b);
57579 }
57580 /**
57581 * Convert a string into the corresponding DType.
57582 * @param dtype
57583 * @returns An instance of DType.
57584 */
57585 function stringToDType(dtype) {
57586 switch (dtype) {
57587 case 'float32':
57588 return 'float32';
57589 default:
57590 throw new ValueError("Invalid dtype: ".concat(dtype));
57591 }
57592 }
57593 /**
57594 * Test the element-by-element equality of two Arrays of strings.
57595 * @param xs First array of strings.
57596 * @param ys Second array of strings.
57597 * @returns Wether the two arrays are all equal, element by element.
57598 */
57599 function stringsEqual(xs, ys) {
57600 if (xs == null || ys == null) {
57601 return xs === ys;
57602 }
57603 if (xs.length !== ys.length) {
57604 return false;
57605 }
57606 for (var i = 0; i < xs.length; ++i) {
57607 if (xs[i] !== ys[i]) {
57608 return false;
57609 }
57610 }
57611 return true;
57612 }
57613 /**
57614 * Get the unique elements of an array.
57615 * @param xs Array.
57616 * @returns An Array consisting of the unique elements in `xs`.
57617 */
57618 function unique$2(xs) {
57619 if (xs == null) {
57620 return xs;
57621 }
57622 var out = [];
57623 // TODO(cais): Maybe improve performance by sorting.
57624 var _iterator3 = _createForOfIteratorHelper(xs),
57625 _step3;
57626 try {
57627 for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
57628 var x = _step3.value;
57629 if (out.indexOf(x) === -1) {
57630 out.push(x);
57631 }
57632 }
57633 } catch (err) {
57634 _iterator3.e(err);
57635 } finally {
57636 _iterator3.f();
57637 }
57638 return out;
57639 }
57640 /**
57641 * Determine if an Object is empty (i.e., does not have own properties).
57642 * @param obj Object
57643 * @returns Whether the Object is empty.
57644 * @throws ValueError: If object is `null` or `undefined`.
57645 */
57646 function isObjectEmpty(obj) {
57647 if (obj == null) {
57648 throw new ValueError("Invalid value in obj: ".concat(JSON.stringify(obj)));
57649 }
57650 for (var key in obj) {
57651 if (obj.hasOwnProperty(key)) {
57652 return false;
57653 }
57654 }
57655 return true;
57656 }
57657 /**
57658 * Helper function used to build type union/enum run-time checkers.
57659 * @param values The list of allowed values.
57660 * @param label A string name for the type
57661 * @param value The value to test.
57662 * @throws ValueError: If the value is not in values nor `undefined`/`null`.
57663 */
57664 function checkStringTypeUnionValue(values, label, value) {
57665 if (value == null) {
57666 return;
57667 }
57668 if (values.indexOf(value) < 0) {
57669 throw new ValueError("".concat(value, " is not a valid ").concat(label, ". Valid values are ").concat(values, " or null/undefined."));
57670 }
57671 }
57672 /**
57673 * Helper function for verifying the types of inputs.
57674 *
57675 * Ensures that the elements of `x` are all of type `expectedType`.
57676 * Also verifies that the length of `x` is within bounds.
57677 *
57678 * @param x Object to test.
57679 * @param expectedType The string expected type of all of the elements in the
57680 * Array.
57681 * @param minLength Return false if x.length is less than this.
57682 * @param maxLength Return false if x.length is greater than this.
57683 * @returns true if and only if `x` is an `Array<expectedType>` with
57684 * length >= `minLength` and <= `maxLength`.
57685 */
57686 // tslint:disable:no-any
57687 function checkArrayTypeAndLength(x, expectedType) {
57688 var minLength = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0;
57689 var maxLength = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : Infinity;
57690 assert(minLength >= 0);
57691 assert(maxLength >= minLength);
57692 return Array.isArray(x) && x.length >= minLength && x.length <= maxLength && x.every(function (e) {
57693 return _typeof(e) === expectedType;
57694 });
57695 }
57696 // tslint:enable:no-any
57697 /**
57698 * Assert that a value or an array of value are positive integer.
57699 *
57700 * @param value The value being asserted on. May be a single number or an array
57701 * of numbers.
57702 * @param name Name of the value, used to make the error message.
57703 */
57704 function assertPositiveInteger(value, name) {
57705 if (Array.isArray(value)) {
57706 assert$1(value.length > 0, function () {
57707 return "".concat(name, " is unexpectedly an empty array.");
57708 });
57709 value.forEach(function (v, i) {
57710 return assertPositiveInteger(v, "element ".concat(i + 1, " of ").concat(name));
57711 });
57712 } else {
57713 assert$1(Number.isInteger(value) && value > 0, function () {
57714 return "Expected ".concat(name, " to be a positive integer, but got ") + "".concat(formatAsFriendlyString(value), ".");
57715 });
57716 }
57717 }
57718 /**
57719 * Format a value into a display-friendly, human-readable fashion.
57720 *
57721 * - `null` is formatted as `'null'`
57722 * - Strings are formated with flanking pair of quotes.
57723 * - Arrays are formatted with flanking pair of square brackets.
57724 *
57725 * @param value The value to display.
57726 * @return Formatted string.
57727 */
57728 // tslint:disable-next-line:no-any
57729 function formatAsFriendlyString(value) {
57730 if (value === null) {
57731 return 'null';
57732 } else if (Array.isArray(value)) {
57733 return '[' + value.map(function (v) {
57734 return formatAsFriendlyString(v);
57735 }).join(',') + ']';
57736 } else if (typeof value === 'string') {
57737 return "\"".concat(value, "\"");
57738 } else {
57739 return "".concat(value);
57740 }
57741 }
57742 /**
57743 * Returns a function `f2` (decorator) which wraps the original function
57744 * `f`. `f2` guarantees that `f` can be called at most once
57745 * every `waitMs` ms. If `f2` is called more often, it will return
57746 * the last returned result of `f`.
57747 *
57748 * @param f The original function `f` to wrap.
57749 * @param waitMs The time between two consecutive calls to `f` in ms.
57750 */
57751 function debounce(f, waitMs, nowFunc) {
57752 var lastTime = nowFunc != null ? nowFunc() : now();
57753 var lastResult;
57754 var f2 = function f2() {
57755 var now$1 = nowFunc != null ? nowFunc() : now();
57756 if (now$1 - lastTime < waitMs) {
57757 return lastResult;
57758 }
57759 lastTime = now$1;
57760 lastResult = f.apply(void 0, arguments);
57761 return lastResult;
57762 };
57763 return f2;
57764 }
57765 /**
57766 * Returns the fusable activation given a layers identifier.
57767 *
57768 * @param activationName The layers identifier string.
57769 * @return The name of the fusable activation.
57770 */
57771 function mapActivationToFusedKernel(activationName) {
57772 if (activationName === 'relu') {
57773 return 'relu';
57774 }
57775 if (activationName === 'linear') {
57776 return 'linear';
57777 }
57778 if (activationName === 'elu') {
57779 return 'elu';
57780 }
57781 return null;
57782 }
57783 /**
57784 * Returns the cartesian product of sets of values.
57785 * This works the same as itertools.product in Python.
57786 *
57787 * Example:
57788 *
57789 * filters = [128, 256, 512]
57790 * paddings = ['same', 'valid']
57791 *
57792 * product = [ [128, 'same'], [128, 'valid'], [256, 'same'], [256, 'valid'],
57793 * [512, 'same'], [512, 'valid']]
57794 *
57795 * @param arrayOfValues List/array of values.
57796 * @return The cartesian product.
57797 */
57798 function getCartesianProductOfValues() {
57799 for (var _len = arguments.length, arrayOfValues = new Array(_len), _key4 = 0; _key4 < _len; _key4++) {
57800 arrayOfValues[_key4] = arguments[_key4];
57801 }
57802 assert(arrayOfValues.length > 0, 'arrayOfValues is empty');
57803 for (var _i6 = 0, _arrayOfValues = arrayOfValues; _i6 < _arrayOfValues.length; _i6++) {
57804 var values = _arrayOfValues[_i6];
57805 assert(Array.isArray(values), 'one of the values is not an array');
57806 assert(values.length > 0, 'one of the values is empty');
57807 }
57808 return arrayOfValues.reduce(function (products, values) {
57809 if (products.length === 0) {
57810 return values.map(function (value) {
57811 return [value];
57812 });
57813 }
57814 return values.map(function (value) {
57815 return products.map(function (prevValue) {
57816 return [].concat(_toConsumableArray(prevValue), [value]);
57817 });
57818 }).reduce(function (flattenedProduct, unflattenedProduct) {
57819 return flattenedProduct.concat(unflattenedProduct);
57820 }, []);
57821 }, []);
57822 }
57823
57824 /**
57825 * @license
57826 * Copyright 2018 Google LLC
57827 *
57828 * Use of this source code is governed by an MIT-style
57829 * license that can be found in the LICENSE file or at
57830 * https://opensource.org/licenses/MIT.
57831 * =============================================================================
57832 */
57833 /**
57834 * Utilities related to persistent state in the backend.
57835 */
57836 /**
57837 * An ID to track `tf.SymbolicTensor`s and derived classes.
57838 * Required in different places in engine/topology.ts to identify unique
57839 * tensors.
57840 */
57841 var _nextUniqueTensorId = 0;
57842 function getNextUniqueTensorId() {
57843 return _nextUniqueTensorId++;
57844 }
// Per-prefix counters for getUid(). A prototype-less object keeps prefixes
// such as 'toString', 'constructor' or '__proto__' from colliding with
// members inherited from Object.prototype.
var _uidPrefixes = Object.create(null);
/**
 * Provides a unique UID given a string prefix.
 *
 * The n-th call with a given prefix returns `prefix + n` (counting from 1),
 * e.g. 'dense1', 'dense2', ...
 *
 * @param prefix String prefix; defaults to ''.
 * @return The prefix followed by its updated counter.
 */
function getUid() {
  var prefix = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : '';
  if (!(prefix in _uidPrefixes)) {
    _uidPrefixes[prefix] = 0;
  }
  _uidPrefixes[prefix] += 1;
  return prefix + _uidPrefixes[prefix].toString();
}
57859
57860 /**
57861 * @license
57862 * Copyright 2018 Google LLC
57863 *
57864 * Use of this source code is governed by an MIT-style
57865 * license that can be found in the LICENSE file or at
57866 * https://opensource.org/licenses/MIT.
57867 * =============================================================================
57868 */
// Allowed values for the string-union option types used by layers.
var VALID_DATA_FORMAT_VALUES = [
  'channelsFirst',
  'channelsLast',
];
var VALID_INTERPOLATION_FORMAT_VALUES = [
  'nearest',
  'bilinear',
];
var VALID_PADDING_MODE_VALUES = [
  'valid',
  'same',
  'causal',
];
var VALID_POOL_MODE_VALUES = [
  'max',
  'avg',
];
var VALID_BIDIRECTIONAL_MERGE_MODES = [
  'sum',
  'mul',
  'concat',
  'ave',
];
var VALID_SAMPLE_WEIGHT_MODES = [
  'temporal',
];
57875
57876 /**
57877 * @license
57878 * Copyright 2018 Google LLC
57879 *
57880 * Use of this source code is governed by an MIT-style
57881 * license that can be found in the LICENSE file or at
57882 * https://opensource.org/licenses/MIT.
57883 * =============================================================================
57884 */
// A map from the requested scoped name of a Tensor to the number of Tensors
// wanting that name so far. This allows enforcing name uniqueness by appending
// an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.
// Module-level state: getUniqueTensorName() reads and updates it, and also
// records each composed 'name_<index>' it hands out so that a later request
// for that exact name gets suffixed too. Nothing in this section clears it.
var nameMap = new Map();
/**
 * Validates `value` against VALID_DATA_FORMAT_VALUES via
 * checkStringTypeUnionValue (which presumably throws on a mismatch).
 */
function checkDataFormat(value) {
  var allowedFormats = VALID_DATA_FORMAT_VALUES;
  checkStringTypeUnionValue(allowedFormats, 'DataFormat', value);
}
/** Validates `value` against VALID_INTERPOLATION_FORMAT_VALUES. */
function checkInterpolationFormat(value) {
  var allowedFormats = VALID_INTERPOLATION_FORMAT_VALUES;
  checkStringTypeUnionValue(allowedFormats, 'InterpolationFormat', value);
}
/** Validates `value` against VALID_PADDING_MODE_VALUES. */
function checkPaddingMode(value) {
  var allowedModes = VALID_PADDING_MODE_VALUES;
  checkStringTypeUnionValue(allowedModes, 'PaddingMode', value);
}
/** Validates `value` against VALID_POOL_MODE_VALUES. */
function checkPoolMode(value) {
  var allowedModes = VALID_POOL_MODE_VALUES;
  checkStringTypeUnionValue(allowedModes, 'PoolMode', value);
}
// Stack of active name scopes, outermost first. Joined with the divider it
// forms the prefix that currentNameScopePrefix() reports.
var _nameScopeStack = [];
var _nameScopeDivider = '/';
/**
 * Runs `fn` with `name` pushed onto the name-scope stack and returns what
 * `fn` returns. Scopes nest; the pushed name is popped again whether `fn`
 * returns normally or throws (the exception is propagated).
 */
function nameScope(name, fn) {
  _nameScopeStack.push(name);
  try {
    return fn();
  } finally {
    _nameScopeStack.pop();
  }
}
/**
 * Returns the active scopes joined by '/' with a trailing '/', or '' when no
 * scope is active.
 */
function currentNameScopePrefix() {
  if (_nameScopeStack.length === 0) {
    return '';
  }
  return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
57927 /**
57928 * Get the name a Tensor (or Variable) would have if not uniqueified.
57929 * @param tensorName
57930 * @return Scoped name string.
57931 */
57932 function getScopedTensorName(tensorName) {
57933 if (!isValidTensorName(tensorName)) {
57934 throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
57935 }
57936 return currentNameScopePrefix() + tensorName;
57937 }
57938 /**
57939 * Get unique names for Tensors and Variables.
57940 * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
57941 * `getScopedTensorName()`.
57942 * @return A unique version of the given fully scoped name.
57943 * If this is the first time that the scoped name is seen in this session,
57944 * then the given `scopedName` is returned unaltered. If the same name is
57945 * seen again (producing a collision), an incrementing suffix is added to the
57946 * end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
57947 */
57948 function getUniqueTensorName(scopedName) {
57949 if (!isValidTensorName(scopedName)) {
57950 throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
57951 }
57952 if (!nameMap.has(scopedName)) {
57953 nameMap.set(scopedName, 0);
57954 }
57955 var index = nameMap.get(scopedName);
57956 nameMap.set(scopedName, nameMap.get(scopedName) + 1);
57957 if (index > 0) {
57958 var result = "".concat(scopedName, "_").concat(index);
57959 // Mark the composed name as used in case someone wants
57960 // to call getUniqueTensorName("name_1").
57961 nameMap.set(result, 1);
57962 return result;
57963 } else {
57964 return scopedName;
57965 }
57966 }
// A tensor name starts with an ASCII letter or digit, followed by any mix of
// letters, digits, '-', '.', '_' and '/'.
var tensorNameRegex = /^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/;
/**
 * Determine whether a string is a valid tensor name.
 * @param name Candidate name; it is matched via `name.match(...)`, so it is
 *   expected to be a string.
 * @returns A Boolean indicating whether `name` is a valid tensor name.
 */
function isValidTensorName(name) {
  var matches = name.match(tensorNameRegex);
  return Boolean(matches);
}
57976
57977 /**
57978 * @license
57979 * Copyright 2018 Google LLC
57980 *
57981 * Use of this source code is governed by an MIT-style
57982 * license that can be found in the LICENSE file or at
57983 * https://opensource.org/licenses/MIT.
57984 * =============================================================================
57985 */
57986 /**
57987 * Determine if a number is an integer.
57988 */
57989 function isInteger(x) {
57990 return x === parseInt(x.toString(), 10);
57991 }
57992 /**
57993 * Calculate the product of an array of numbers.
57994 * @param array The array to calculate the product over.
57995 * @param begin Beginning index, inclusive.
57996 * @param end Ending index, exclusive.
57997 * @return The product.
57998 */
57999 function arrayProd(array, begin, end) {
58000 if (begin == null) {
58001 begin = 0;
58002 }
58003 if (end == null) {
58004 end = array.length;
58005 }
58006 var prod = 1;
58007 for (var i = begin; i < end; ++i) {
58008 prod *= array[i];
58009 }
58010 return prod;
58011 }
58012 /**
58013 * Compute minimum value.
58014 * @param array
58015 * @return minimum value.
58016 */
58017 function min$2(array) {
58018 // same behavior as tf.min()
58019 if (array.length === 0) {
58020 return Number.NaN;
58021 }
58022 var min = Number.POSITIVE_INFINITY;
58023 for (var i = 0; i < array.length; i++) {
58024 var value = array[i];
58025 if (value < min) {
58026 min = value;
58027 }
58028 }
58029 return min;
58030 }
58031 /**
58032 * Compute maximum value.
58033 * @param array
58034 * @return maximum value
58035 */
58036 function max$2(array) {
58037 // same behavior as tf.max()
58038 if (array.length === 0) {
58039 return Number.NaN;
58040 }
58041 var max = Number.NEGATIVE_INFINITY;
58042 for (var i = 0; i < array.length; i++) {
58043 var value = array[i];
58044 if (value > max) {
58045 max = value;
58046 }
58047 }
58048 return max;
58049 }
58050 /**
58051 * Compute sum of array.
58052 * @param array
58053 * @return The sum.
58054 */
58055 function sum$2(array) {
58056 var sum = 0;
58057 for (var i = 0; i < array.length; i++) {
58058 var value = array[i];
58059 sum += value;
58060 }
58061 return sum;
58062 }
58063 /**
58064 * Compute mean of array.
58065 * @param array
58066 * @return The mean.
58067 */
58068 function mean$1(array) {
58069 return sum$2(array) / array.length;
58070 }
58071 /**
58072 * Compute variance of array.
58073 * @param array
58074 * @return The variance.
58075 */
58076 function variance(array) {
58077 var meanValue = mean$1(array);
58078 var demeaned = array.map(function (value) {
58079 return value - meanValue;
58080 });
58081 var sumSquare = 0;
58082 for (var i = 0; i < demeaned.length; i++) {
58083 var value = demeaned[i];
58084 sumSquare += value * value;
58085 }
58086 return sumSquare / array.length;
58087 }
58088 /**
58089 * Compute median of array.
58090 * @param array
58091 * @return The median value.
58092 */
58093 function median(array) {
58094 var arraySorted = array.slice().sort(function (a, b) {
58095 return a - b;
58096 });
58097 var lowIdx = Math.floor((arraySorted.length - 1) / 2);
58098 var highIdx = Math.ceil((arraySorted.length - 1) / 2);
58099 if (lowIdx === highIdx) {
58100 return arraySorted[lowIdx];
58101 }
58102 return (arraySorted[lowIdx] + arraySorted[highIdx]) / 2;
58103 }
58104 /**
58105 * Generate an array of integers in [begin, end).
58106 * @param begin Beginning integer, inclusive.
58107 * @param end Ending integer, exclusive.
58108 * @returns Range array.
58109 * @throws ValueError, iff `end` < `begin`.
58110 */
58111 function range$2(begin, end) {
58112 if (end < begin) {
58113 throw new ValueError("end (".concat(end, ") < begin (").concat(begin, ") is forbidden."));
58114 }
58115 var out = [];
58116 for (var i = begin; i < end; ++i) {
58117 out.push(i);
58118 }
58119 return out;
58120 }
58121
58122 /**
58123 * @license
58124 * Copyright 2018 Google LLC
58125 *
58126 * Use of this source code is governed by an MIT-style
58127 * license that can be found in the LICENSE file or at
58128 * https://opensource.org/licenses/MIT.
58129 * =============================================================================
58130 */
// Cached fuzz factor; null/undefined means "ask the backend on next use".
var _epsilon;
/**
 * Returns the value of the fuzz factor used in numeric expressions.
 * When nothing is cached (initially, or after setEpsilon(null)), the value is
 * read from `backend$1().epsilon()` and cached for later calls.
 */
function epsilon$1() {
  if (_epsilon != null) {
    return _epsilon;
  }
  _epsilon = backend$1().epsilon();
  return _epsilon;
}
/**
 * Overrides the cached fuzz factor returned by epsilon$1().
 * @param e New value of epsilon; null/undefined re-enables the backend lookup.
 */
function setEpsilon(e) {
  _epsilon = e;
}
/**
 * Returns the default image data format convention, which is always
 * 'channelsLast'.
 */
function imageDataFormat() {
  var defaultFormat = 'channelsLast';
  return defaultFormat;
}
58154
58155 // tslint:enable
/* Setting and getting the backend of tfjs-core (formerly deeplearn.js). */
// The backend name recorded below defaults to 'webgl' (GPU).
// Name of the backend most recently requested through setBackend(); starts
// out as 'webgl'.
var backend = 'webgl';
/**
 * Forwards the request to the core `setBackend$1` and then records the
 * requested name for getBackend(). The name is only recorded if
 * `setBackend$1` returns without throwing; its return value is ignored.
 */
function setBackend(requestedBackend) {
  setBackend$1(requestedBackend);
  backend = requestedBackend;
}
/** Returns the backend name recorded by the last setBackend() call. */
function getBackend() {
  var current = backend;
  return current;
}
/**
 * Indicates whether the backend is operating symbolically.
 *
 * Callers can use this to decide how to interpret user code: true would mean
 * backend calls build a symbolic graph, false that they execute immediately.
 * This implementation always reports false (imperative execution).
 */
function isBackendSymbolic() {
  return false;
}
58176 /**
58177 * Get the number of elements in a Tensor.
58178 * @param x The Tensor.
58179 * @return Number of elements in `x`.
58180 */
58181 function countParams(x) {
58182 var shape = x.shape;
58183 if (shape.length > 0) {
58184 return shape.reduce(function (a, b) {
58185 return a * b;
58186 });
58187 } else {
58188 // Scalar.
58189 return 1;
58190 }
58191 }
58192 /**
58193 * Casts a tensor to a different dtype and returns it.
58194 * @param x Input tensor.
58195 * @param dtype String: 'float32'|'int32'|'bool'.
58196 * @returns Tensor of the specified `dtype`.
58197 */
58198 function cast$2(x, dtype) {
58199 return cast$3(x, dtype);
58200 }
58201 /**
58202 * Adds a 1-sized dimension at index "axis".
58203 * @param x Input tensor.
58204 * @param axis Position where to add the new axis.
58205 * @returns Result of the dimension expansion.
58206 */
58207 function expandDims$2(x) {
58208 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : -1;
58209 var outShape = x.shape.slice();
58210 if (axis < 0) {
58211 axis = outShape.length + axis + 1;
58212 }
58213 outShape.splice(axis, 0, 1);
58214 return reshape$3(x, outShape);
58215 }
58216 /**
58217 * Repeats a 2D tensor.
58218 *
58219 * If `x` has shape `[samples, dim]` and `n` is 2, for example, the output
58220 * will have shape `[samples, 2, dim]`.
58221 *
58222 * @param x Input tensor.
58223 * @param n Integer, number of times to repeat.
58224 * @returns The result of the repeat operation.
58225 * @throws ValueError: If input tensor is not 2D.
58226 */
58227 function repeat(x, n) {
58228 return tidy(function () {
58229 if (x.shape.length !== 2) {
58230 throw new ValueError("repeat() expects a rank-2 tensor, but received a " + "rank-".concat(x.shape.length, " tensor."));
58231 }
58232 var y = expandDims$2(x, 1);
58233 return tile$2(y, [1, n, 1]);
58234 });
58235 }
58236 /**
58237 * Flatten a Tensor into 1D.
58238 * @param x Input tensor.
58239 * @return The result of the flattening `x`.
58240 */
58241 function flatten$1(x) {
58242 var newShape = [arrayProd(x.shape)];
58243 return reshape$3(x, newShape);
58244 }
58245 /**
58246 * Turn a nD tensor into a 2D tensor with same 0th dimension.
58247 * In other words, it flattens each data samples of a batch.
58248 *
58249 * @param x The tensor to flatten. The rank of this tensor is required to be 2
58250 * or higher.
58251 * @return The result of the flattening.
58252 */
58253 function batchFlatten(x) {
58254 if (x.rank <= 1) {
58255 throw new ValueError("batchFlatten requires a minimum rank of 2. Got rank: ".concat(x.rank, "."));
58256 }
58257 var newShape = [x.shape[0], arrayProd(x.shape, 1)];
58258 return reshape$3(x, newShape);
58259 }
58260 /**
58261 * Do slicing along the first axis.
58262 * @param array input `tf.Tensor`.
58263 * @param start starting index, inclusive.
58264 * @param size size of the slice along the first axis.
58265 * @returns result of the slicing.
58266 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
58267 */
58268 function sliceAlongFirstAxis(array, start, size) {
58269 return tidy(function () {
58270 switch (array.rank) {
58271 case 1:
58272 return slice1d(array, start, size);
58273 case 2:
58274 return slice2d(array, [start, 0], [size, array.shape[1]]);
58275 case 3:
58276 return slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]);
58277 case 4:
58278 return slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]);
58279 case 5:
58280 return slice$2(array, [start, 0, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]]);
58281 case 6:
58282 return slice$2(array, [start, 0, 0, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3], array.shape[4], array.shape[5]]);
58283 default:
58284 throw new ValueError("sliceAlongFirstAxis() received an unsupported tensor rank: " + "".concat(array.rank));
58285 }
58286 });
58287 }
58288 /**
58289 * Do slicing along the last axis.
58290 * @param array input `tf.Tensor`.
58291 * @param start starting index, inclusive.
58292 * @param size size of the slice along the last axis.
58293 * @returns result of the slicing.
58294 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
58295 */
58296 function sliceAlongLastAxis(array, start, size) {
58297 return tidy(function () {
58298 switch (array.rank) {
58299 case 1:
58300 return slice1d(array, start, size);
58301 case 2:
58302 return slice2d(array, [0, start], [array.shape[0], size]);
58303 case 3:
58304 return slice3d(array, [0, 0, start], [array.shape[0], array.shape[1], size]);
58305 case 4:
58306 return slice4d(array, [0, 0, 0, start], [array.shape[0], array.shape[1], array.shape[2], size]);
58307 default:
58308 throw new ValueError("sliceAlongLastAxis() received an unsupported tensor rank: " + "".concat(array.rank));
58309 }
58310 });
58311 }
58312 /**
58313 * Do slicing along the sepcified axis.
58314 * @param array input `tf.Tensor`.
58315 * @param start starting index, inclusive.
58316 * @param size of the slice along the chosen axis.
58317 * @param choose an axis.
58318 * @returns result of the slicing.
58319 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
58320 */
58321 function sliceAlongAxis(array, start, size, axis) {
58322 return tidy(function () {
58323 switch (array.rank) {
58324 case 1:
58325 return slice1d(array, start, size);
58326 case 2:
58327 switch (axis) {
58328 case 1:
58329 return sliceAlongFirstAxis(array, start, size);
58330 case 2:
58331 return sliceAlongLastAxis(array, start, size);
58332 default:
58333 throw new ValueError("The axis is not within the rank of the tensor " + "".concat(axis));
58334 }
58335 case 3:
58336 switch (axis) {
58337 case 1:
58338 return sliceAlongFirstAxis(array, start, size);
58339 case 2:
58340 return slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]);
58341 case 3:
58342 return sliceAlongLastAxis(array, start, size);
58343 default:
58344 throw new ValueError("The axis is not within the rank of the tensor " + "".concat(axis));
58345 }
58346 case 4:
58347 switch (axis) {
58348 case 1:
58349 return sliceAlongFirstAxis(array, start, size);
58350 case 2:
58351 return slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]);
58352 case 3:
58353 return slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]);
58354 case 4:
58355 return sliceAlongLastAxis(array, start, size);
58356 default:
58357 throw new ValueError("The axis is not within the rank of the tensor " + "".concat(axis));
58358 }
58359 default:
58360 throw new ValueError("sliceAlongLastAxis() received an unsupported tensor rank: " + "".concat(array.rank));
58361 }
58362 });
58363 }
58364 /**
58365 * Concatenates a list of tensors alongside the specified axis.
58366 * @param tensors `Array` of tensors to concatenate.
58367 * @param axis Concatenation axis.
58368 * @returns The result of the concatenation.
58369 */
58370 function concatenate$2(tensors) {
58371 var axis = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : -1;
58372 var rank;
58373 if (axis < 0) {
58374 rank = tensors[0].rank;
58375 if (rank !== 0) {
58376 axis = rank;
58377 } else {
58378 axis = 0;
58379 }
58380 }
58381 if (axis === tensors[0].rank) {
58382 // Porting Note: This is necessary because tfc.concat() requires axis to be
58383 // in the interval [-rank, rank).
58384 axis = -1;
58385 }
58386 // Porting Note: Sparse concat is not supported yet.
58387 return concat$2(tensors, axis);
58388 }
58389 /**
58390 * Concatenate two arrays along the first dimension.
58391 * @param a The 1st `tf.Tensor` to concatenate.
58392 * @param b The 2nd `tf.Tensor` to concatenate.
58393 * @returns Result of the concatenation.
58394 * @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.
58395 */
58396 function concatAlongFirstAxis(a, b) {
58397 switch (a.rank) {
58398 case 1:
58399 return concat1d([a, b]);
58400 case 2:
58401 return concat2d([a, b], 0);
58402 case 3:
58403 return concat3d([a, b], 0);
58404 case 4:
58405 return concat4d([a, b], 0);
58406 default:
58407 throw new ValueError("concatAlongFirstAxis() received an unsupported " + "tensor rank: ".concat(a.rank));
58408 }
58409 }
58410 /**
58411 * Creates a tensor by tiling `x` by `n`.
58412 * @param x A tensor.
58413 * @param n An Array of integers or a single integer. If an Array, the length
58414 * must be the same as the number of dimensions in `x`. If a single integer,
58415 * it will be treated as an Array of length 1.
58416 */
58417 function tile$2(x, n) {
58418 if (!Array.isArray(n)) {
58419 n = [n];
58420 }
58421 if (x.rank !== n.length) {
58422 throw new ValueError("The length of input n (".concat(n.length, ") does not match ") + "the number of dimensions in input x (".concat(x.rank, ")"));
58423 }
58424 return tile$3(x, n);
58425 }
58426 /* Creation of random tensors. */
58427 /**
58428 * Get a tensor with normal distribution of values.
58429 *
58430 * @param shape Shape of the tensor.
58431 * @param mean mean value of the normal distribution.
58432 * @param stddev standard deviation of the normal distribution.
58433 * @param dtype
58434 * @param seed
58435 * @return The normal tensor.
58436 */
58437 function randomNormal$1(shape) {
58438 var mean = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.0;
58439 var stddev = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 1.0;
58440 var dtype = arguments.length > 3 ? arguments[3] : undefined;
58441 var seed = arguments.length > 4 ? arguments[4] : undefined;
58442 return randomNormal$2(shape, mean, stddev, dtype, seed);
58443 }
58444 /* Linear Algebra */
58445 /**
58446 * Multiply two tensors and returns the result as a tensor.
58447 *
58448 * For 2D tensors, this is equivalent to matrix multiplication (matMul).
58449 * For tensors of higher ranks, it follows the Theano behavior,
58450 * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
58451 *
58452 * For N dimensions it is a sum product over the last axis of x and the
58453 * second-to-last of y:
58454 *
58455 * @param a A tensor of at least rank 2.
58456 * @param b A tensor of at least rank 2.
58457 * @param activation (optional) A string identifying the activation
58458 * function.
58459 * @return Result of the dot operation.
58460 */
58461 function dot$1(a, b, activation, bias) {
58462 if (a.rank < 2 || b.rank < 2) {
58463 throw new NotImplementedError("dot requires both inputs to be rank >= 2" + " but got x shape = ".concat(a.shape, " and y shape = ").concat(b.shape));
58464 }
58465 if (b.rank >= 3) {
58466 var xLastDim = a.shape.slice(-1)[0];
58467 var ySecondLastDim = b.shape.slice(-2)[0];
58468 if (xLastDim !== ySecondLastDim) {
58469 throw new NotImplementedError("If rank y >= 3, then the second last dim" + " of y must equal the last dim of x but got x shape = ".concat(a.shape, " and ") + " y shape = ".concat(b.shape));
58470 }
58471 }
58472 // Handle basic 2D x 2D case.
58473 if (a.rank === 2 && b.rank === 2) {
58474 var transposeA = false;
58475 var transposeB = false;
58476 // tfc.fused.matMul only fuses certain activation functions. Unsupported
58477 // activation functions are treated as 'linear' activations, which is
58478 // equivalent to a no-op.
58479 return matMul({
58480 a: a,
58481 b: b,
58482 transposeA: transposeA,
58483 transposeB: transposeB,
58484 bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
58485 activation: activation
58486 });
58487 } else {
58488 // Reshape x into the analogous 2D Tensor.
58489 var aFirstDims = a.shape.slice(); // Holds all but the last dim of x.
58490 var aLastDim = aFirstDims.pop();
58491 a = reshape$3(a, [-1, aLastDim]);
58492 // Reshape y into the analogous 2D Tensor, and keep track of the
58493 // required dimensions to reproduce the output shape.
58494 var bShape = b.shape.slice();
58495 var bLastDim = bShape.pop();
58496 var _ySecondLastDim = bShape.pop();
58497 var yOtherDims = [].concat(_toConsumableArray(bShape), [bLastDim]);
58498 // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
58499 // where r is the rank of y.
58500 var perm = Array.from({
58501 length: b.rank
58502 }, function (_, i) {
58503 if (i === 0) {
58504 return b.rank - 2;
58505 } else if (i <= b.rank - 2) {
58506 return i - 1;
58507 }
58508 return i;
58509 });
58510 b = reshape$3(transpose$2(b, perm), [_ySecondLastDim, -1]);
58511 // Multiply x and y as 2D Tensors, and then reshape back to original.
58512 var outputShape = [].concat(_toConsumableArray(aFirstDims), _toConsumableArray(yOtherDims));
58513 var _transposeA = false;
58514 var _transposeB = false;
58515 return reshape$3(matMul({
58516 a: a,
58517 b: b,
58518 transposeA: _transposeA,
58519 transposeB: _transposeB,
58520 bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
58521 activation: activation
58522 }), outputShape);
58523 }
58524 }
58525 /**
58526 * Compute the sign Tensor of an input Tensor.
58527 *
58528 * Elements of the input `tf.Tensor` that are === 0 are mapped to 0.
58529 * Elements of the input `tf.Tensor` that are > 0 are mapped to 1.
58530 * Elements of the input `tf.Tensor` that are < 0 are mapped to -1.
58531 *
58532 * @param x Input `tf.Tensor`.
58533 * @return The sign `tf.Tensor`.
58534 */
58535 function sign$2(x) {
58536 // TODO(cais): Move to the core.
58537 return tidy(function () {
58538 var zerosLikeX = zerosLike$3(x);
58539 var onesLikeX = onesLike$3(x);
58540 return where(equal$2(x, zerosLikeX), zerosLikeX, where(greater$3(x, zerosLike$3(x)), onesLikeX, mul(-1, onesLikeX)));
58541 });
58542 }
58543 /**
58544 * Computes the one-hot representation of an integer tensor.
58545 * @param indices nD integer tensor of shape
58546 * `(batch_size, dim1, dim2, ... dim(n-1))`
58547 * @param numClasses Integer, number of classes to consider.
58548 * @returns (n + 1)D one hot representation of the input
58549 * with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
58550 */
58551 function oneHot$2(indices, numClasses) {
58552 return tidy(function () {
58553 if (indices.rank !== 1) {
58554 throw new Error('Only 1D one-hot tensors are supported in the ' + 'deeplearn backend, at present.');
58555 }
58556 indices = cast$3(indices, 'int32');
58557 return cast$3(oneHot$3(indices, numClasses), 'float32');
58558 });
58559 }
58560 /* Elementary math functions. */
58561 /**
58562 * Retrieves the elements of indices `indices` in the tensor `reference`.
58563 * @param reference A tensor.
58564 * @param indices An integer tensor of indices or an `Array` of integers.
58565 * @param axis Axis along which to perform the gather operation.
58566 * @returns The result of the gathering as a tensor.
58567 */
58568 function gather(reference, indices, axis) {
58569 return tidy(function () {
58570 if (Array.isArray(indices)) {
58571 indices = tensor1d(indices, 'int32');
58572 } else {
58573 indices = cast$3(indices, 'int32');
58574 }
58575 return gather$1(reference, indices, axis);
58576 });
58577 }
58578 /**
58579 * Element-wise square.
58580 * @param x Input tensor.
58581 * @return element-wise x^2
58582 */
58583 function square$1(x) {
58584 return mul(x, x);
58585 }
58586 /**
58587 * Element-wise exponentiation.
58588 *
58589 * Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which
58590 * takes advatnage of the backend's (e.g., TensorFlow's) automatic
58591 * conversion to tensor. Here we allow `a` to be either a number or a tensor.
58592 *
58593 * @param x The base tensor.
58594 * @param a The exponent, tensor or number. If a number, it is rounded to the
58595 * nearest integer and converted to a tensor.
58596 * @returns A tensor of the same shape as `x`.
58597 */
58598 function pow$2(x, a) {
58599 return tidy(function () {
58600 if (typeof a === 'number') {
58601 a = scalar(Math.round(a), 'int32');
58602 }
58603 if (a.dtype !== 'int32') {
58604 throw new NotImplementedError("Non-int32 dtype (".concat(a.dtype, ") is not supported by pow() yet"));
58605 }
58606 return pow$3(x, a);
58607 });
58608 }
58609 /**
58610 * Reshapes bias tensor according to rank of x.
58611 */
58612 function reshapeBias(xRank, bias, dataFormat) {
58613 var biasShape = bias.shape;
58614 if (bias.rank !== 1 && bias.rank !== xRank) {
58615 throw new ValueError("Unexpected bias dimensions: ".concat(bias.rank) + "; expected it to be 1 or ".concat(xRank));
58616 }
58617 if (xRank === 5) {
58618 if (dataFormat === 'channelsFirst') {
58619 if (biasShape.length === 1) {
58620 return reshape$3(bias, [1, biasShape[0], 1, 1, 1]);
58621 } else {
58622 return reshape$3(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
58623 }
58624 } else if (dataFormat === 'channelsLast') {
58625 if (biasShape.length === 1) {
58626 return reshape$3(bias, [1, 1, 1, 1, biasShape[0]]);
58627 } else {
58628 return reshape$3(bias, [1].concat(biasShape));
58629 }
58630 }
58631 } else if (xRank === 4) {
58632 if (dataFormat === 'channelsFirst') {
58633 if (biasShape.length === 1) {
58634 return reshape$3(bias, [1, biasShape[0], 1, 1]);
58635 } else {
58636 return reshape$3(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);
58637 }
58638 } else if (dataFormat === 'channelsLast') {
58639 if (biasShape.length === 1) {
58640 return reshape$3(bias, [1, 1, 1, biasShape[0]]);
58641 } else {
58642 return reshape$3(bias, [1].concat(biasShape));
58643 }
58644 }
58645 } else if (xRank === 3) {
58646 if (dataFormat === 'channelsFirst') {
58647 if (biasShape.length === 1) {
58648 return reshape$3(bias, [1, biasShape[0], 1]);
58649 } else {
58650 return reshape$3(bias, [1, biasShape[1], biasShape[0]]);
58651 }
58652 } else if (dataFormat === 'channelsLast') {
58653 if (biasShape.length === 1) {
58654 return reshape$3(bias, [1, 1, biasShape[0]]);
58655 } else {
58656 return reshape$3(bias, [1].concat(biasShape));
58657 }
58658 }
58659 } else if (xRank < 3) {
58660 return bias;
58661 }
58662 throw new ValueError("Unsupported input rank by biasAdd: ".concat(bias.rank));
58663 }
58664 /* Neural-network operations. */
58665 /**
58666 * Add a bias to a tensor.
58667 *
58668 * @param x The tensor to add the bias to.
58669 * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.
58670 * @return Result of the bias adding.
58671 * @throws ValueError: If the rank of `bias` is incorrect.
58672 */
58673 function biasAdd(x, bias, dataFormat) {
58674 return tidy(function () {
58675 if (dataFormat == null) {
58676 dataFormat = imageDataFormat();
58677 }
58678 checkDataFormat(dataFormat);
58679 return add$3(x, reshapeBias(x.rank, bias, dataFormat));
58680 });
58681 }
58682 /**
58683 * Exponential linear unit (ELU).
58684 * @param x A tensor or variable to compute the activation function for.
58685 * @param alpha: A scalar, a scaling factor for the negative section.
58686 * @return Output of the ELU operation.
58687 */
58688 function elu$3(x) {
58689 var alpha = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 1;
58690 // TODO(cais): Add support for alpha values other than 1.
58691 if (alpha !== 1) {
58692 throw new NotImplementedError("Support for alpha values other than 1 (".concat(alpha, ") is not implemented ") + "yet.");
58693 }
58694 return elu$4(x);
58695 }
58696 /**
58697 * Softsign of a tensor.
58698 *
58699 * Defined as x / (abs(x) + 1), element-wise.
58700 *
58701 * @param x: Input.
58702 * @returns Output.
58703 */
58704 function softsign(x) {
58705 return tidy(function () {
58706 return div$1(x, add$3(abs$2(x), 1));
58707 });
58708 }
58709 /**
58710 * Sets entries in `x` to zero at random, while scaling the entire tensor.
58711 *
58712 * @param x input tensor.
58713 * @param level fraction of the entries in the tensor that will be set to 0.
58714 * @param noiseShape shape of randomly generated keep/drop flags, must be
58715 * broadcastable to the shape of `x`. Optional.
58716 * @param seed random seed to ensure determinism. Optional.
58717 * @returns Result of the dropout operation.
58718 */
58719 function dropout$1(x, level, noiseShape, seed) {
58720 return tidy(function () {
58721 return dropout$2(x, level, noiseShape, seed);
58722 });
58723 }
/**
 * Piecewise-linear approximation of the sigmoid, applied element-wise.
 *
 * Computes `0.2 * x + 0.5` and clips the result to [0, 1], so it yields 0 for
 * x < -2.5, 1 for x > 2.5 and the linear ramp in between.
 *
 * @param x Input tensor.
 * @returns Output tensor.
 */
function hardSigmoid(x) {
  return tidy(function () {
    var scaled = mul(.2, x);
    var shifted = add$3(.5, scaled);
    return clipByValue$2(shifted, 0, 1);
  });
}
/**
 * Chooses between a training-time and an inference-time computation.
 *
 * Porting Note: no placeholder tensor is created for the `training` flag,
 * because the TF.js imperative backend has no such concept.
 *
 * @param x Function invoked when `training` is truthy.
 * @param alt Function invoked otherwise.
 * @param training Optional flag (default `false`) telling whether the
 *   training phase is active.
 * @returns Whatever the selected function returns.
 */
function inTrainPhase(x, alt) {
  // A missing or undefined third argument means "not training".
  var isTraining = Boolean(arguments[2]);
  if (!isTraining) {
    return alt();
  }
  return x();
}
58756
58757 /**
58758 * @license
58759 * Copyright 2018 Google LLC
58760 *
58761 * Use of this source code is governed by an MIT-style
58762 * license that can be found in the LICENSE file or at
58763 * https://opensource.org/licenses/MIT.
58764 * =============================================================================
58765 */
// Allowed values of the `mode` option of VarianceScaling.
var VALID_FAN_MODE_VALUES = [
  'fanIn',
  'fanOut',
  'fanAvg'
];
// Allowed values of the `distribution` option of VarianceScaling.
var VALID_DISTRIBUTION_VALUES = [
  'normal',
  'uniform',
  'truncatedNormal'
];
// A string[] cannot be derived from the string-union type automatically, so
// the list is spelled out by hand; in the TypeScript source it is checked at
// compile time against the `InitializerClassName` union.
/**
 * Names of the Initializer classes that may appear in serialized configs.
 */
var initializerClassNames = [
  'Zeros',
  'Ones',
  'Constant',
  'RandomNormal',
  'RandomUniform',
  'TruncatedNormal',
  'VarianceScaling',
  'Orthogonal',
  'Identity'
];
58777
/**
 * Validates a VarianceScaling `mode` against VALID_FAN_MODE_VALUES via
 * checkStringTypeUnionValue (presumably throws on an unknown value).
 */
function checkFanMode(value) {
  var allowedModes = VALID_FAN_MODE_VALUES;
  checkStringTypeUnionValue(allowedModes, 'FanMode', value);
}
/**
 * Validates a VarianceScaling `distribution` against
 * VALID_DISTRIBUTION_VALUES via checkStringTypeUnionValue (presumably throws
 * on an unknown value).
 */
function checkDistribution(value) {
  var allowedDistributions = VALID_DISTRIBUTION_VALUES;
  checkStringTypeUnionValue(allowedDistributions, 'Distribution', value);
}
/**
 * Initializer base class.
 *
 * Subclasses implement `apply(shape, dtype)`. By default the config is empty
 * and deserialization does not use custom objects.
 *
 * @doc {
 *   heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
var Initializer = /*#__PURE__*/function (SerializableBase) {
  _inherits(Initializer, SerializableBase);
  var superCtor = _createSuper(Initializer);
  function Initializer() {
    _classCallCheck(this, Initializer);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "fromConfigUsesCustomObjects",
    value: function fromConfigUsesCustomObjects() {
      return false;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {};
    }
  }];
  _createClass(Initializer, methods);
  return Initializer;
}(Serializable);
/** Initializer whose `apply` returns `zeros$2(shape, dtype)` (all zeros). */
var Zeros = /*#__PURE__*/function (BaseInitializer) {
  _inherits(Zeros, BaseInitializer);
  var superCtor = _createSuper(Zeros);
  function Zeros() {
    _classCallCheck(this, Zeros);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    value: function apply(shape, dtype) {
      return zeros$2(shape, dtype);
    }
  }];
  _createClass(Zeros, methods);
  return Zeros;
}(Initializer);
/** @nocollapse */
Zeros.className = 'Zeros';
registerClass(Zeros);
/** Initializer whose `apply` returns `ones$1(shape, dtype)` (all ones). */
var Ones = /*#__PURE__*/function (BaseInitializer) {
  _inherits(Ones, BaseInitializer);
  var superCtor = _createSuper(Ones);
  function Ones() {
    _classCallCheck(this, Ones);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    value: function apply(shape, dtype) {
      return ones$1(shape, dtype);
    }
  }];
  _createClass(Ones, methods);
  return Ones;
}(Initializer);
/** @nocollapse */
Ones.className = 'Ones';
registerClass(Ones);
/**
 * Initializer that fills the tensor with a single constant `value`.
 */
var Constant = /*#__PURE__*/function (_Initializer3) {
  _inherits(Constant, _Initializer3);
  var _super4 = _createSuper(Constant);
  /**
   * @param args Config object; `args.value` is required.
   * @throws ValueError if `args` is not a non-null object or has no `value`.
   */
  function Constant(args) {
    var _this;
    _classCallCheck(this, Constant);
    _this = _super4.call(this);
    // `typeof null === 'object'`, so null must be rejected explicitly;
    // otherwise `args.value` below would throw a bare TypeError.
    if (args == null || _typeof(args) !== 'object') {
      throw new ValueError("Expected argument of type ConstantConfig but got ".concat(args));
    }
    if (args.value === undefined) {
      // Interpolating the object directly prints "[object Object]"; show its
      // contents instead, falling back to String() for unserializable input.
      var described;
      try {
        described = JSON.stringify(args);
      } catch (e) {
        described = String(args);
      }
      throw new ValueError("config must have value set but got ".concat(described));
    }
    _this.value = args.value;
    return _this;
  }
  _createClass(Constant, [{
    key: "apply",
    value: function apply(shape, dtype) {
      var _this2 = this;
      return tidy(function () {
        return mul(scalar(_this2.value), ones$1(shape, dtype));
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        value: this.value
      };
    }
  }]);
  return Constant;
}(Initializer);
/** @nocollapse */
Constant.className = 'Constant';
registerClass(Constant);
/**
 * Initializer that samples uniformly from [minval, maxval).
 */
var RandomUniform = /*#__PURE__*/function (_Initializer4) {
  _inherits(RandomUniform, _Initializer4);
  var _super5 = _createSuper(RandomUniform);
  /**
   * @param args Config with optional `minval` (default -0.05), `maxval`
   *   (default 0.05) and `seed`.
   */
  function RandomUniform(args) {
    var _this3;
    _classCallCheck(this, RandomUniform);
    _this3 = _super5.call(this);
    _this3.DEFAULT_MINVAL = -0.05;
    _this3.DEFAULT_MAXVAL = 0.05;
    // Use `!= null` rather than `||`: an explicit bound of 0 is a legitimate
    // value and must not be replaced by the default.
    _this3.minval = args.minval != null ? args.minval : _this3.DEFAULT_MINVAL;
    _this3.maxval = args.maxval != null ? args.maxval : _this3.DEFAULT_MAXVAL;
    _this3.seed = args.seed;
    return _this3;
  }
  _createClass(RandomUniform, [{
    key: "apply",
    value: function apply(shape, dtype) {
      return randomUniform$1(shape, this.minval, this.maxval, dtype, this.seed);
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        minval: this.minval,
        maxval: this.maxval,
        seed: this.seed
      };
    }
  }]);
  return RandomUniform;
}(Initializer);
/** @nocollapse */
RandomUniform.className = 'RandomUniform';
registerClass(RandomUniform);
/**
 * Initializer that samples from a normal distribution.
 */
var RandomNormal = /*#__PURE__*/function (_Initializer5) {
  _inherits(RandomNormal, _Initializer5);
  var _super6 = _createSuper(RandomNormal);
  /**
   * @param args Config with optional `mean` (default 0), `stddev` (default
   *   0.05) and `seed`.
   */
  function RandomNormal(args) {
    var _this4;
    _classCallCheck(this, RandomNormal);
    _this4 = _super6.call(this);
    _this4.DEFAULT_MEAN = 0.;
    _this4.DEFAULT_STDDEV = 0.05;
    // `!= null` instead of `||` so an explicit 0 is kept rather than being
    // swapped for the default.
    _this4.mean = args.mean != null ? args.mean : _this4.DEFAULT_MEAN;
    _this4.stddev = args.stddev != null ? args.stddev : _this4.DEFAULT_STDDEV;
    _this4.seed = args.seed;
    return _this4;
  }
  _createClass(RandomNormal, [{
    key: "apply",
    value: function apply(shape, dtype) {
      dtype = dtype || 'float32';
      if (dtype !== 'float32' && dtype !== 'int32') {
        throw new NotImplementedError("randomNormal does not support dType ".concat(dtype, "."));
      }
      return randomNormal$1(shape, this.mean, this.stddev, dtype, this.seed);
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        mean: this.mean,
        stddev: this.stddev,
        seed: this.seed
      };
    }
  }]);
  return RandomNormal;
}(Initializer);
/** @nocollapse */
RandomNormal.className = 'RandomNormal';
registerClass(RandomNormal);
/**
 * Initializer that samples from a truncated normal distribution.
 */
var TruncatedNormal = /*#__PURE__*/function (_Initializer6) {
  _inherits(TruncatedNormal, _Initializer6);
  var _super7 = _createSuper(TruncatedNormal);
  /**
   * @param args Config with optional `mean` (default 0), `stddev` (default
   *   0.05) and `seed`.
   */
  function TruncatedNormal(args) {
    var _this5;
    _classCallCheck(this, TruncatedNormal);
    _this5 = _super7.call(this);
    _this5.DEFAULT_MEAN = 0.;
    _this5.DEFAULT_STDDEV = 0.05;
    // `!= null` instead of `||` so an explicit 0 is kept rather than being
    // swapped for the default.
    _this5.mean = args.mean != null ? args.mean : _this5.DEFAULT_MEAN;
    _this5.stddev = args.stddev != null ? args.stddev : _this5.DEFAULT_STDDEV;
    _this5.seed = args.seed;
    return _this5;
  }
  _createClass(TruncatedNormal, [{
    key: "apply",
    value: function apply(shape, dtype) {
      dtype = dtype || 'float32';
      if (dtype !== 'float32' && dtype !== 'int32') {
        throw new NotImplementedError("truncatedNormal does not support dType ".concat(dtype, "."));
      }
      return truncatedNormal$1(shape, this.mean, this.stddev, dtype, this.seed);
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        mean: this.mean,
        stddev: this.stddev,
        seed: this.seed
      };
    }
  }]);
  return TruncatedNormal;
}(Initializer);
/** @nocollapse */
TruncatedNormal.className = 'TruncatedNormal';
registerClass(TruncatedNormal);
/**
 * Initializer producing `gain * I` for square 2D shapes.
 */
var Identity = /*#__PURE__*/function (BaseInitializer) {
  _inherits(Identity, BaseInitializer);
  var superCtor = _createSuper(Identity);
  /** @param args Config with optional `gain` (default 1.0). */
  function Identity(args) {
    var self;
    _classCallCheck(this, Identity);
    self = superCtor.call(this);
    self.gain = args.gain == null ? 1.0 : args.gain;
    return self;
  }
  _createClass(Identity, [{
    key: "apply",
    value: function apply(shape, dtype) {
      var gain = this.gain;
      return tidy(function () {
        var isSquareMatrix = shape.length === 2 && shape[0] === shape[1];
        if (!isSquareMatrix) {
          throw new ValueError('Identity matrix initializer can only be used for' + ' 2D square matrices.');
        }
        // NOTE(review): `dtype` is not forwarded to eye(); the result uses
        // eye()'s default dtype.
        return mul(gain, eye(shape[0]));
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        gain: this.gain
      };
    }
  }]);
  return Identity;
}(Initializer);
/** @nocollapse */
Identity.className = 'Identity';
registerClass(Identity);
/**
 * Computes the number of input and output units (fan-in, fan-out) for a
 * weight shape.
 *
 * - rank 2: `[shape[0], shape[1]]`.
 * - rank 3-5 (convolution kernels): channel dims times the product of the
 *   receptive-field dims, located according to `dataFormat`.
 * - any other rank: both fans are `sqrt(prod(shape))`.
 *
 * @param shape Shape of the weight.
 * @param dataFormat Data format of convolution kernels (default
 *   'channelsLast'). Note that all kernels in Keras are standardized on the
 *   CHANNEL_LAST ordering (even when inputs are set to CHANNEL_FIRST).
 * @return A length-2 array: [fanIn, fanOut].
 */
function computeFans(shape) {
  var dataFormat = arguments[1] === undefined ? 'channelsLast' : arguments[1];
  checkDataFormat(dataFormat);
  var rank = shape.length;
  if (rank === 2) {
    return [shape[0], shape[1]];
  }
  if (rank === 3 || rank === 4 || rank === 5) {
    if (dataFormat === 'channelsFirst') {
      // Layout [out, in, ...spatial].
      var spatialFirst = arrayProd(shape, 2);
      return [shape[1] * spatialFirst, shape[0] * spatialFirst];
    }
    if (dataFormat === 'channelsLast') {
      // Layout [...spatial, in, out].
      var spatialLast = arrayProd(shape, 0, rank - 2);
      return [shape[rank - 2] * spatialLast, shape[rank - 1] * spatialLast];
    }
    // Unreachable when checkDataFormat rejects other values.
    return [undefined, undefined];
  }
  var root = Math.sqrt(arrayProd(shape));
  return [root, root];
}
/**
 * Initializer whose variance scales with the fan of the weight shape.
 *
 * The variance is `scale / n`, where `n` is fanIn, fanOut or their average
 * depending on `mode` (clamped to at least 1).
 */
var VarianceScaling = /*#__PURE__*/function (_Initializer8) {
  _inherits(VarianceScaling, _Initializer8);
  var _super9 = _createSuper(VarianceScaling);
  /**
   * Constructor of VarianceScaling.
   *
   * @param args Config with optional `scale` (default 1.0), `mode` (default
   *   'fanIn'), `distribution` (default 'normal') and `seed`.
   * @throws ValueError for a negative `scale`.
   */
  function VarianceScaling(args) {
    var _this8;
    _classCallCheck(this, VarianceScaling);
    _this8 = _super9.call(this);
    if (args.scale < 0.0) {
      throw new ValueError("scale must be a positive float. Got: ".concat(args.scale));
    }
    _this8.scale = args.scale == null ? 1.0 : args.scale;
    _this8.mode = args.mode == null ? 'fanIn' : args.mode;
    checkFanMode(_this8.mode);
    _this8.distribution = args.distribution == null ? 'normal' : args.distribution;
    checkDistribution(_this8.distribution);
    _this8.seed = args.seed;
    return _this8;
  }
  _createClass(VarianceScaling, [{
    key: "apply",
    value: function apply(shape, dtype) {
      var fans = computeFans(shape);
      var fanIn = fans[0];
      var fanOut = fans[1];
      var scale = this.scale;
      if (this.mode === 'fanIn') {
        scale /= Math.max(1, fanIn);
      } else if (this.mode === 'fanOut') {
        scale /= Math.max(1, fanOut);
      } else {
        scale /= Math.max(1, (fanIn + fanOut) / 2);
      }
      // 'truncatedNormal' is an accepted distribution (see checkDistribution);
      // it must take the truncated-normal branch instead of falling through
      // to the uniform one.
      if (this.distribution === 'normal' || this.distribution === 'truncatedNormal') {
        var stddev = Math.sqrt(scale);
        dtype = dtype || 'float32';
        if (dtype !== 'float32' && dtype !== 'int32') {
          throw new NotImplementedError("".concat(this.getClassName(), " does not support dType ").concat(dtype, "."));
        }
        return truncatedNormal$1(shape, 0, stddev, dtype, this.seed);
      } else {
        // U(-limit, limit) has variance limit^2 / 3 == scale.
        var limit = Math.sqrt(3 * scale);
        return randomUniform$1(shape, -limit, limit, dtype, this.seed);
      }
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        scale: this.scale,
        mode: this.mode,
        distribution: this.distribution,
        seed: this.seed
      };
    }
  }]);
  return VarianceScaling;
}(Initializer);
/** @nocollapse */
VarianceScaling.className = 'VarianceScaling';
registerClass(VarianceScaling);
var GlorotUniform = /*#__PURE__*/function (BaseVarianceScaling) {
  _inherits(GlorotUniform, BaseVarianceScaling);
  var superCtor = _createSuper(GlorotUniform);
  /**
   * Glorot uniform initializer: VarianceScaling with scale 1.0, mode
   * 'fanAvg' and a uniform distribution.
   * @param args Optional config; only `args.seed` is read.
   */
  function GlorotUniform(args) {
    _classCallCheck(this, GlorotUniform);
    var seed = args == null ? null : args.seed;
    var config = {
      scale: 1.0,
      mode: 'fanAvg',
      distribution: 'uniform',
      seed: seed
    };
    return superCtor.call(this, config);
  }
  _createClass(GlorotUniform, [{
    key: "getClassName",
    value: function getClassName() {
      // Python Keras has no GlorotUniform class, only a helper that builds a
      // VarianceScaling object, so serialize under that class name.
      return VarianceScaling.className;
    }
  }]);
  return GlorotUniform;
}(VarianceScaling);
/** @nocollapse */
GlorotUniform.className = 'GlorotUniform';
registerClass(GlorotUniform);
var GlorotNormal = /*#__PURE__*/function (BaseVarianceScaling) {
  _inherits(GlorotNormal, BaseVarianceScaling);
  var superCtor = _createSuper(GlorotNormal);
  /**
   * Glorot normal initializer: VarianceScaling with scale 1.0, mode 'fanAvg'
   * and the 'normal' distribution.
   * @param args Optional config; only `args.seed` is read.
   */
  function GlorotNormal(args) {
    _classCallCheck(this, GlorotNormal);
    var seed = args == null ? null : args.seed;
    var config = {
      scale: 1.0,
      mode: 'fanAvg',
      distribution: 'normal',
      seed: seed
    };
    return superCtor.call(this, config);
  }
  _createClass(GlorotNormal, [{
    key: "getClassName",
    value: function getClassName() {
      // Python Keras has no GlorotNormal class, only a helper that builds a
      // VarianceScaling object, so serialize under that class name.
      return VarianceScaling.className;
    }
  }]);
  return GlorotNormal;
}(VarianceScaling);
/** @nocollapse */
GlorotNormal.className = 'GlorotNormal';
registerClass(GlorotNormal);
var HeNormal = /*#__PURE__*/function (BaseVarianceScaling) {
  _inherits(HeNormal, BaseVarianceScaling);
  var superCtor = _createSuper(HeNormal);
  /**
   * He normal initializer: VarianceScaling with scale 2.0, mode 'fanIn' and
   * the 'normal' distribution.
   * @param args Optional config; only `args.seed` is read.
   */
  function HeNormal(args) {
    _classCallCheck(this, HeNormal);
    var seed = args == null ? null : args.seed;
    var config = {
      scale: 2.0,
      mode: 'fanIn',
      distribution: 'normal',
      seed: seed
    };
    return superCtor.call(this, config);
  }
  _createClass(HeNormal, [{
    key: "getClassName",
    value: function getClassName() {
      // Python Keras has no HeNormal class, only a helper that builds a
      // VarianceScaling object, so serialize under that class name.
      return VarianceScaling.className;
    }
  }]);
  return HeNormal;
}(VarianceScaling);
/** @nocollapse */
HeNormal.className = 'HeNormal';
registerClass(HeNormal);
var HeUniform = /*#__PURE__*/function (BaseVarianceScaling) {
  _inherits(HeUniform, BaseVarianceScaling);
  var superCtor = _createSuper(HeUniform);
  /**
   * He uniform initializer: VarianceScaling with scale 2.0, mode 'fanIn' and
   * a uniform distribution.
   * @param args Optional config; only `args.seed` is read.
   */
  function HeUniform(args) {
    _classCallCheck(this, HeUniform);
    var seed = args == null ? null : args.seed;
    var config = {
      scale: 2.0,
      mode: 'fanIn',
      distribution: 'uniform',
      seed: seed
    };
    return superCtor.call(this, config);
  }
  _createClass(HeUniform, [{
    key: "getClassName",
    value: function getClassName() {
      // Python Keras has no HeUniform class, only a helper that builds a
      // VarianceScaling object, so serialize under that class name.
      return VarianceScaling.className;
    }
  }]);
  return HeUniform;
}(VarianceScaling);
/** @nocollapse */
HeUniform.className = 'HeUniform';
registerClass(HeUniform);
var LeCunNormal = /*#__PURE__*/function (BaseVarianceScaling) {
  _inherits(LeCunNormal, BaseVarianceScaling);
  var superCtor = _createSuper(LeCunNormal);
  /**
   * LeCun normal initializer: VarianceScaling with scale 1.0, mode 'fanIn'
   * and the 'normal' distribution.
   * @param args Optional config; only `args.seed` is read.
   */
  function LeCunNormal(args) {
    _classCallCheck(this, LeCunNormal);
    var seed = args == null ? null : args.seed;
    var config = {
      scale: 1.0,
      mode: 'fanIn',
      distribution: 'normal',
      seed: seed
    };
    return superCtor.call(this, config);
  }
  _createClass(LeCunNormal, [{
    key: "getClassName",
    value: function getClassName() {
      // Python Keras has no LeCunNormal class, only a helper that builds a
      // VarianceScaling object, so serialize under that class name.
      return VarianceScaling.className;
    }
  }]);
  return LeCunNormal;
}(VarianceScaling);
/** @nocollapse */
LeCunNormal.className = 'LeCunNormal';
registerClass(LeCunNormal);
var LeCunUniform = /*#__PURE__*/function (BaseVarianceScaling) {
  _inherits(LeCunUniform, BaseVarianceScaling);
  var superCtor = _createSuper(LeCunUniform);
  /**
   * LeCun uniform initializer: VarianceScaling with scale 1.0, mode 'fanIn'
   * and a uniform distribution.
   * @param args Optional config; only `args.seed` is read.
   */
  function LeCunUniform(args) {
    _classCallCheck(this, LeCunUniform);
    var seed = args == null ? null : args.seed;
    var config = {
      scale: 1.0,
      mode: 'fanIn',
      distribution: 'uniform',
      seed: seed
    };
    return superCtor.call(this, config);
  }
  _createClass(LeCunUniform, [{
    key: "getClassName",
    value: function getClassName() {
      // Python Keras has no LeCunUniform class, only a helper that builds a
      // VarianceScaling object, so serialize under that class name.
      return VarianceScaling.className;
    }
  }]);
  return LeCunUniform;
}(VarianceScaling);
/** @nocollapse */
LeCunUniform.className = 'LeCunUniform';
registerClass(LeCunUniform);
/**
 * Initializer that produces a (scaled) random orthogonal matrix, obtained
 * from the QR factorization of a random normal matrix.
 */
var Orthogonal = /*#__PURE__*/function (BaseInitializer) {
  _inherits(Orthogonal, BaseInitializer);
  var superCtor = _createSuper(Orthogonal);
  /** @param args Config with optional `gain` (default 1) and `seed`. */
  function Orthogonal(args) {
    var self;
    _classCallCheck(this, Orthogonal);
    self = superCtor.call(this);
    self.DEFAULT_GAIN = 1;
    // Element count above which a slowness warning is logged.
    self.ELEMENTS_WARN_SLOW = 2000;
    self.gain = args.gain == null ? self.DEFAULT_GAIN : args.gain;
    self.seed = args.seed;
    return self;
  }
  _createClass(Orthogonal, [{
    key: "apply",
    value: function apply(shape, dtype) {
      var initializer = this;
      return tidy(function () {
        if (shape.length < 2) {
          throw new NotImplementedError('Shape must be at least 2D.');
        }
        var dtypeSupported = dtype === undefined || dtype === 'float32' || dtype === 'int32';
        if (!dtypeSupported) {
          throw new TypeError("Unsupported data type ".concat(dtype, "."));
        }
        // Collapse all leading dims into rows and keep the last dim as
        // columns, so conv kernels are handled as 2D matrices.
        var numRows = sizeFromShape(shape.slice(0, -1));
        var numCols = shape[shape.length - 1];
        var numElements = numRows * numCols;
        if (numElements > initializer.ELEMENTS_WARN_SLOW) {
          console.warn("Orthogonal initializer is being called on a matrix with more than ".concat(initializer.ELEMENTS_WARN_SLOW, " (").concat(numElements, ") elements: Slowness may result."));
        }
        var longSide = Math.max(numCols, numRows);
        var shortSide = Math.min(numCols, numRows);
        // Random tall matrix, then its reduced QR factorization.
        var randNormalMat = randomNormal$1([longSide, shortSide], 0, 1, dtype, initializer.seed);
        var qr = linalg.qr(randNormalMat, false);
        var qMat = qr[0];
        var rMat = qr[1];
        // Diagonal of the shortSide x shortSide R: in the flattened matrix
        // it sits at a stride of shortSide + 1. Multiplying Q by its signs
        // makes Q uniform.
        var diag = rMat.flatten().stridedSlice([0], [shortSide * shortSide], [shortSide + 1]);
        qMat = mul(qMat, diag.sign());
        if (numRows < numCols) {
          qMat = qMat.transpose();
        }
        return mul(scalar(initializer.gain), qMat.reshape(shape));
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        gain: this.gain,
        seed: this.seed
      };
    }
  }]);
  return Orthogonal;
}(Initializer);
/** @nocollapse */
Orthogonal.className = 'Orthogonal';
registerClass(Orthogonal);
// Translates camelCase initializer identifiers (as accepted by
// getInitializer) into the class names under which the initializers are
// registered.
var INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
  constant: 'Constant',
  glorotNormal: 'GlorotNormal',
  glorotUniform: 'GlorotUniform',
  heNormal: 'HeNormal',
  heUniform: 'HeUniform',
  identity: 'Identity',
  leCunNormal: 'LeCunNormal',
  leCunUniform: 'LeCunUniform',
  ones: 'Ones',
  orthogonal: 'Orthogonal',
  randomNormal: 'RandomNormal',
  randomUniform: 'RandomUniform',
  truncatedNormal: 'TruncatedNormal',
  varianceScaling: 'VarianceScaling',
  zeros: 'Zeros'
};
/**
 * Rebuilds an initializer from its serialized config, resolving class names
 * through SerializationMap's class-name registry.
 *
 * @param config Serialized initializer config.
 * @param customObjects Optional extra objects (default `{}`) passed through to
 *   deserializeKerasObject.
 */
function deserializeInitializer(config) {
  var customObjects = arguments[1] === undefined ? {} : arguments[1];
  var registry = SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, registry, customObjects, 'initializer');
}
/** Serializes an initializer via serializeKerasObject. */
function serializeInitializer(initializer) {
  var serialized = serializeKerasObject(initializer);
  return serialized;
}
/**
 * Resolves an initializer identifier into an Initializer instance.
 *
 * @param identifier One of:
 *   - a string: a camelCase key of INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP
 *     or a registered class name;
 *   - an Initializer instance, which is returned unchanged;
 *   - a serialized config, which is deserialized.
 * @returns The corresponding Initializer.
 */
function getInitializer(identifier) {
  if (typeof identifier === 'string') {
    // Own-property check: a plain `in` test would also match inherited
    // Object.prototype members such as 'toString' or 'constructor', mapping
    // the identifier to a function instead of a class name.
    var className = Object.prototype.hasOwnProperty.call(INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP, identifier) ? INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
    // These helper classes all serialize as 'VarianceScaling', so they are
    // constructed directly instead of going through deserializeInitializer.
    switch (className) {
      case 'GlorotNormal':
        return new GlorotNormal();
      case 'GlorotUniform':
        return new GlorotUniform();
      case 'HeNormal':
        return new HeNormal();
      case 'HeUniform':
        return new HeUniform();
      case 'LeCunNormal':
        return new LeCunNormal();
      case 'LeCunUniform':
        return new LeCunUniform();
      default:
        return deserializeInitializer({
          className: className,
          config: {}
        });
    }
  }
  if (identifier instanceof Initializer) {
    return identifier;
  }
  return deserializeInitializer(identifier);
}
59414
59415 /**
59416 * @license
59417 * Copyright 2018 Google LLC
59418 *
59419 * Use of this source code is governed by an MIT-style
59420 * license that can be found in the LICENSE file or at
59421 * https://opensource.org/licenses/MIT.
59422 * =============================================================================
59423 */
59424 // tslint:enable
/**
 * Reports whether `x` is an Array whose first element is itself an Array,
 * i.e. a list of Shapes rather than a single Shape.
 */
function isArrayOfShapes(x) {
  if (!Array.isArray(x)) {
    return false;
  }
  return Array.isArray(x[0]);
}
/**
 * Normalizes a Shape or a list of Shapes into a list of Shapes.
 *
 * An empty input yields a new empty array; a single Shape is wrapped in an
 * array; a list of Shapes is returned as is.
 *
 * @param x A Shape or a list of Shapes.
 * @return A list of Shapes.
 */
function normalizeShapeList(x) {
  if (x.length === 0) {
    return [];
  }
  return Array.isArray(x[0]) ? x : [x];
}
/**
 * Returns the single Tensor held by `xs`.
 *
 * @param xs A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
 * @return `xs` itself if it is not an Array, otherwise its only element.
 * @throws ValueError if `xs` is an Array whose length is not 1.
 */
function getExactlyOneTensor(xs) {
  if (!Array.isArray(xs)) {
    return xs;
  }
  if (xs.length !== 1) {
    throw new ValueError("Expected Tensor length to be 1; got ".concat(xs.length));
  }
  return xs[0];
}
/**
 * Returns exactly one Shape.
 *
 * @param shapes A single `Shape` or an Array of `Shape`s.
 * @returns A single `Shape` unchanged, or the only element of a one-element
 *   Array of `Shape`s.
 * @throws ValueError if `shapes` is an Array of `Shape`s whose length is not
 *   1.
 */
function getExactlyOneShape(shapes) {
  var isListOfShapes = Array.isArray(shapes) && Array.isArray(shapes[0]);
  if (!isListOfShapes) {
    return shapes;
  }
  if (shapes.length !== 1) {
    throw new ValueError("Expected exactly 1 Shape; got ".concat(shapes.length));
  }
  return shapes[0];
}
59486
59487 /**
59488 * @license
59489 * Copyright 2018 Google LLC
59490 *
59491 * Use of this source code is governed by an MIT-style
59492 * license that can be found in the LICENSE file or at
59493 * https://opensource.org/licenses/MIT.
59494 * =============================================================================
59495 */
/**
 * Counts the total number of scalar elements across a collection of
 * LayerVariables.
 *
 * A variable with an empty (scalar) shape contributes 1; otherwise it
 * contributes the product of its dimensions.
 *
 * @param weights Iterable of LayerVariables to count.
 * @returns The total element count.
 */
function countParamsInWeights(weights) {
  var total = 0;
  var cursor = _createForOfIteratorHelper(weights);
  var step;
  try {
    for (cursor.s(); !(step = cursor.n()).done;) {
      var dims = step.value.shape;
      var size = dims.length === 0 ? 1 : dims.reduce(function (acc, dim) {
        return acc * dim;
      });
      total += size;
    }
  } catch (err) {
    cursor.e(err);
  } finally {
    cursor.f();
  }
  return total;
}
59525
var DEFAULT_VARIABLE_NAME_PREFIX = 'Variable';
/**
 * A `tf.layers.LayerVariable` is similar to a `tf.Tensor` in that it has a
 * dtype and shape, but its value is mutable. The value is itself represented
 * as a `tf.Tensor`, and can be read with the `read()` method and updated with
 * the `write()` method.
 */
var LayerVariable = /*#__PURE__*/function () {
  /**
   * Construct Variable from a `tf.Tensor`.
   *
   * If not explicitly named, the Variable will be given a name with the
   * prefix 'Variable'. Variable names are unique. In the case of name
   * collision, suffixes '_<num>' will be added to the name.
   *
   * @param val Initial value of the Variable.
   * @param dtype Data type of the Variable; `null`/`undefined` means
   *   'float32'.
   * @param name Name of the variable. If `null` or `undefined` is provided, it
   *   defaults to a name with the prefix 'Variable'.
   * @param trainable Whether the variable is trainable (default `true`).
   * @param constraint Optional projection function applied to the variable
   *   after every `write()` (default `null`).
   */
  function LayerVariable(val) {
    var dtype = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 'float32';
    var name = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : DEFAULT_VARIABLE_NAME_PREFIX;
    var trainable = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : true;
    var constraint = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : null;
    _classCallCheck(this, LayerVariable);
    this.dtype = dtype == null ? 'float32' : dtype;
    this.shape = val.shape;
    this.id = getNextUniqueTensorId();
    name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
    this.originalName = getScopedTensorName(name);
    this.name = getUniqueTensorName(this.originalName);
    this.trainable_ = trainable;
    this.constraint = constraint;
    this.val = variable$1(val, this.trainable_, this.name, this.dtype);
  }
  _createClass(LayerVariable, [{
    key: "read",
    /**
     * Get a snapshot of the Variable's value.
     *
     * The returned value is a snapshot of the Variable's value at the time of
     * the invocation. Future mutations in the value of the tensor will only
     * be reflected by future calls to this method.
     *
     * @throws Error if the variable has been disposed.
     */
    value: function read() {
      this.assertNotDisposed();
      return this.val;
    }
  }, {
    key: "write",
    /**
     * Update the value of the Variable.
     *
     * @param newVal The new value to update to. Must be consistent with the
     *   dtype and shape of the Variable.
     * @return This Variable.
     * @throws Error if the variable has been disposed or the shapes differ.
     */
    value: function write(newVal) {
      // TODO(cais): Once TF.js Core supports Tensor.dtype, check dtype match.
      this.assertNotDisposed();
      checkShapesMatch(this.val, newVal);
      // Skip updating if this is the exact same tensor.
      if (this.val.id !== newVal.id) {
        this.val.assign(newVal);
        if (this.constraint != null) {
          var constrained = this.constraint.apply(this.val);
          this.val.assign(constrained);
          // The constraint's output is a temporary: release it so it does not
          // leak when write() runs outside a tidy() scope. Never dispose the
          // variable itself if a constraint happens to return it unchanged.
          // NOTE(review): relies on Variable.assign() keeping its own
          // reference to the assigned data (tfjs-core ref-counting) — confirm.
          if (constrained !== this.val) {
            constrained.dispose();
          }
        }
      }
      return this;
    }
  }, {
    key: "dispose",
    /**
     * Dispose this LayersVariable instance from memory.
     */
    value: function dispose() {
      this.assertNotDisposed();
      this.val.dispose();
    }
  }, {
    key: "assertNotDisposed",
    value: function assertNotDisposed() {
      if (this.val.isDisposed) {
        throw new Error("LayersVariable ".concat(this.name, " is already disposed."));
      }
    }
  }, {
    key: "trainable",
    get: function get() {
      return this.trainable_;
    },
    // Keeps the underlying tf.Variable's flag in sync with this wrapper.
    set: function set(trainable) {
      this.trainable_ = trainable;
      this.val.trainable = trainable;
    }
  }]);
  return LayerVariable;
}();
/**
 * Throws if `x` and `y` do not have the same shape.
 *
 * @param x Object with a `shape` array.
 * @param y Object with a `shape` array.
 * @throws Error 'Shape mismatch: <x.shape> vs. <y.shape>' on mismatch.
 */
function checkShapesMatch(x, y) {
  var xKey = x.shape.join(',');
  var yKey = y.shape.join(',');
  if (xKey === yKey) {
    return;
  }
  throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' + JSON.stringify(y.shape));
}
/**
 * Creates a trainable LayerVariable.
 *
 * @param x Initial value of the Variable.
 * @param dtype Optional data type of the Variable.
 * @param name Optional name; LayerVariable supplies a default when omitted.
 * @param constraint Optional constraint applied after every update.
 * @return The newly created LayerVariable.
 */
function variable(x, dtype, name, constraint) {
  var isTrainable = true;
  return new LayerVariable(x, dtype, name, isTrainable, constraint);
}
59644 /**
59645 * Instantiates an all-zeros Variable and returns it.
59646 *
59647 * @param shape Shape of the tensor.
59648 * @param dtype DType of the tensor.
59649 * @param name Name of the tensor.
59650 * @return An all-zero Variable.
59651 */
59652 function zerosVariable(shape, dtype, name) {
59653 // TODO(cais): Implement logic for dtype.
59654 return new LayerVariable(zeros$2(shape), dtype, name);
59655 }
59656 /**
59657 * Instantiates an all-zeros tensor of the same shape as another tensor.
59658 *
59659 * @param x The other tensor.
59660 * @param dtype DType of the tensor.
59661 * @param name Name of the tensor.
59662 * @return A newly instantiated Variable.
59663 */
59664 function zerosLike$2(x, dtype, name) {
59665 return new LayerVariable(zerosLike$3(x), dtype, name);
59666 }
59667 /**
59668 * Instantiates an all-ones tensor and returns it.
59669 *
59670 * @param shape Shape of the tensor.
59671 * @param dtype DType of the tensor.
59672 * @param name Name of the tensor.
59673 * @return An all-ones Variable.
59674 */
59675 function onesVariable(shape, dtype, name) {
59676 // TODO(cais): Implement logic for dtype.
59677 var allocated = ones$1(shape);
59678 return new LayerVariable(allocated, dtype, name);
59679 }
59680 /**
59681 * Instantiates an all-ones tensor of the same shape as another tensor.
59682 *
59683 * @param x The other tensor.
59684 * @param dtype DType of the tensor.
59685 * @param name Name of the tensor.
59686 * @return A newly instantiated Variable.
59687 */
59688 function onesLike$2(x, dtype, name) {
59689 var allocated = onesLike$3(x);
59690 return new LayerVariable(allocated, dtype, name);
59691 }
59692 /**
59693 * Instantiate an identity matrix and returns it, as a Variable
59694 *
59695 * @param size Number of rows/columns.
59696 * @param dtype Data type of returned Variable.
59697 * @param name Name of returned Variable.
59698 * @return A Variable, an identity matrix.
59699 */
59700 function eyeVariable(size, dtype, name) {
59701 return new LayerVariable(eye(size), dtype, name);
59702 }
59703 /**
59704 * Get a Variable with uniform distribution of values.
59705 * @param shape Shape of the tensor.
59706 * @param minval Lower bound of the uniform distribution.
59707 * @param maxval Upper bound of the uniform distribution.
59708 * @param dtype
59709 * @param seed
59710 * @param name Optional name.
59711 * @return The uniform-random Variable.
59712 */
59713 function randomUniformVariable(shape, minval, maxval, dtype, seed) {
59714 var name = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 'randomUniform';
59715 return new LayerVariable(randomUniform$1(shape, minval, maxval, dtype), dtype, name);
59716 }
59717 /**
59718 * Get a Variable with truncated-normal distribution of values.
59719 * @param shape Shape of the tensor.
59720 * @param mean mean value of the normal distribution.
59721 * @param stddev standard deviation of the normal distribution.
59722 * @param dtype
59723 * @param seed
59724 * @param name Optional name.
59725 * @return The truncated-normal-random Variable.
59726 */
59727 function truncatedNormalVariable(shape) {
59728 var mean = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.0;
59729 var stddev = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 1.0;
59730 var dtype = arguments.length > 3 ? arguments[3] : undefined;
59731 var seed = arguments.length > 4 ? arguments[4] : undefined;
59732 var name = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 'truncatedNormal';
59733 // TODO(cais): Implement logic for dtype and seed once they are supported
59734 // by deeplearn.js.
59735 dtype = dtype || 'float32';
59736 if (dtype !== 'float32' && dtype !== 'int32') {
59737 throw new NotImplementedError("randomNormal does not support dType ".concat(dtype, "."));
59738 }
59739 return new LayerVariable(truncatedNormal$1(shape, mean, stddev, dtype, seed), dtype, name);
59740 }
59741 /**
59742 * Get a Variable with normal distribution of values.
59743 * @param shape Shape of the tensor.
59744 * @param mean mean value of the normal distribution.
59745 * @param stddev standard deviation of the normal distribution.
59746 * @param dtype
59747 * @param seed
59748 * @param name Optional name.
59749 * @return The truncated-normal-random Variable.
59750 */
59751 function randomNormalVariable(shape) {
59752 var mean = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0.0;
59753 var stddev = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 1.0;
59754 var dtype = arguments.length > 3 ? arguments[3] : undefined;
59755 var seed = arguments.length > 4 ? arguments[4] : undefined;
59756 var name = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 'randomNormal';
59757 dtype = dtype || 'float32';
59758 if (dtype !== 'float32' && dtype !== 'int32') {
59759 throw new NotImplementedError("randomNormalVariable does not support dType ".concat(dtype, "."));
59760 }
59761 return new LayerVariable(randomNormal$2(shape, mean, stddev, dtype, seed), dtype, name);
59762 }
59763 /**
59764 * Update the value of a Variable.
59765 * @param x The Variable to be updated.
59766 * @param xNew The new value to update to.
59767 * @return The Variable updated.
59768 */
59769 function update(x, xNew) {
59770 return x.write(xNew);
59771 }
59772 /**
59773 * Update the value of a Variable by adding an increment.
59774 * @param x The Variable to be updated.
59775 * @param increment The incrment to add to `x`.
59776 * @return The Variable updated.
59777 */
59778 function updateAdd(x, increment) {
59779 return x.write(add$3(x.read(), increment));
59780 }
59781 /**
59782 * Update the value of a Variable by subtracting a decrement.
59783 * @param x The Variable to be updated.
59784 * @param decrement The decrement to subtract from `x`.
59785 * @return The Variable updated.
59786 */
59787 function updateSub(x, decrement) {
59788 return x.write(sub$2(x.read(), decrement));
59789 }
59790 /**
59791 * Get the values of an array of Variables.
59792 *
59793 * @param tensors An `Array` of `Variable`s to get the values of.
59794 * @return The values of the inputs, as an `Array` of`tf.Tensor`s.
59795 */
59796 function batchGetValue(xs) {
59797 return xs.map(function (x) {
59798 return x.read();
59799 });
59800 }
59801 /**
59802 * Update the value of multiple Variables at once.
59803 *
59804 * @param variablesAndValues An `Array`, each element is of type
59805 * [Variable, Tensor]. The first item is the
59806 * `Variable` of which the value is to be updated. The second item
59807 * carries the new value.
59808 */
59809 function batchSetValue(variablesAndValues) {
59810 variablesAndValues.forEach(function (variableAndValue) {
59811 var variable = variableAndValue[0];
59812 variable.write(variableAndValue[1]);
59813 });
59814 }
59815 /**
59816 * Returns the gradients of `variables` w.r.t. the return value of `lossFn`.
59817 * @param lossFn A function which returns a Scalar to be used as the function
59818 * value (i.e., numerator) for differentiation.
59819 * @param variables List of variables to be used as the independent variables
59820 * (i.e., denominator) for differentiation.
59821 * @returns An Array of gradients tensors.
59822 */
59823 function gradients(lossFn, variables) {
59824 // TODO(cais): The return type signature can be simplified if deeplearn makes
59825 // the corresponding type public.
59826 var variableList = variables.map(function (variable) {
59827 return variable.read();
59828 });
59829 var valudAndGrads = variableGrads(lossFn, variableList);
59830 return variables.map(function (variable) {
59831 return valudAndGrads.grads[variable.name];
59832 });
59833 }
59834
59835 /**
59836 * Specifies the ndim, dtype and shape of every input to a layer.
59837 *
59838 * Every layer should expose (if appropriate) an `inputSpec` attribute:
59839 * a list of instances of InputSpec (one per input tensor).
59840 *
59841 * A null entry in a shape is compatible with any dimension,
59842 * a null shape is compatible with any shape.
59843 */
59844 var InputSpec = /*#__PURE__*/_createClass(function InputSpec(args) {
59845 _classCallCheck(this, InputSpec);
59846 this.dtype = args.dtype;
59847 this.shape = args.shape;
59848 /*
59849 TODO(michaelterry): Could throw error if ndim and shape are both defined
59850 (then backport).
59851 */
59852 if (args.shape != null) {
59853 this.ndim = args.shape.length;
59854 } else {
59855 this.ndim = args.ndim;
59856 }
59857 this.maxNDim = args.maxNDim;
59858 this.minNDim = args.minNDim;
59859 this.axes = args.axes || {};
59860 });
59861 /**
59862 * `tf.SymbolicTensor` is a placeholder for a Tensor without any concrete value.
59863 *
59864 * They are most often encountered when building a graph of `Layer`s for a
59865 * `tf.LayersModel` and the input data's shape, but not values are known.
59866 *
59867 * @doc {heading: 'Models', 'subheading': 'Classes'}
59868 */
59869 var SymbolicTensor = /*#__PURE__*/_createClass(
59870 /**
59871 *
59872 * @param dtype
59873 * @param shape
59874 * @param sourceLayer The Layer that produced this symbolic tensor.
59875 * @param inputs The inputs passed to sourceLayer's __call__() method.
59876 * @param nodeIndex
59877 * @param tensorIndex
59878 * @param callArgs The keyword arguments passed to the __call__() method.
59879 * @param name
59880 * @param outputTensorIndex The index of this tensor in the list of outputs
59881 * returned by apply().
59882 */
59883 function SymbolicTensor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
59884 _classCallCheck(this, SymbolicTensor);
59885 this.dtype = dtype;
59886 this.shape = shape;
59887 this.sourceLayer = sourceLayer;
59888 this.inputs = inputs;
59889 this.callArgs = callArgs;
59890 this.outputTensorIndex = outputTensorIndex;
59891 this.id = getNextUniqueTensorId();
59892 if (name != null) {
59893 this.originalName = getScopedTensorName(name);
59894 this.name = getUniqueTensorName(this.originalName);
59895 }
59896 this.rank = shape.length;
59897 });
59898 var _nextNodeID = 0;
59899 /**
59900 * A `Node` describes the connectivity between two layers.
59901 *
59902 * Each time a layer is connected to some new input,
59903 * a node is added to `layer.inboundNodes`.
59904 *
59905 * Each time the output of a layer is used by another layer,
59906 * a node is added to `layer.outboundNodes`.
59907 *
59908 * `nodeIndices` and `tensorIndices` are basically fine-grained coordinates
59909 * describing the origin of the `inputTensors`, verifying the following:
59910 *
59911 * `inputTensors[i] ==
59912 * inboundLayers[i].inboundNodes[nodeIndices[i]].outputTensors[
59913 * tensorIndices[i]]`
59914 *
59915 * A node from layer A to layer B is added to:
59916 * A.outboundNodes
59917 * B.inboundNodes
59918 */
59919 var Node = /*#__PURE__*/function () {
59920 function Node(args,
59921 // TODO(michaelterry): Define actual type for this.
59922 callArgs) {
59923 _classCallCheck(this, Node);
59924 this.callArgs = callArgs;
59925 this.id = _nextNodeID++;
59926 /*
59927 Layer instance (NOT a list).
59928 this is the layer that takes a list of input tensors
59929 and turns them into a list of output tensors.
59930 the current node will be added to
59931 the inboundNodes of outboundLayer.
59932 */
59933 this.outboundLayer = args.outboundLayer;
59934 /*
59935 The following 3 properties describe where
59936 the input tensors come from: which layers,
59937 and for each layer, which node and which
59938 tensor output of each node.
59939 */
59940 // List of layer instances.
59941 this.inboundLayers = args.inboundLayers;
59942 // List of integers, 1:1 mapping with inboundLayers.
59943 this.nodeIndices = args.nodeIndices;
59944 // List of integers, 1:1 mapping with inboundLayers.
59945 this.tensorIndices = args.tensorIndices;
59946 /*
59947 Following 2 properties:
59948 tensor inputs and outputs of outboundLayer.
59949 */
59950 // List of tensors. 1:1 mapping with inboundLayers.
59951 this.inputTensors = args.inputTensors;
59952 // List of tensors, created by outboundLayer.call().
59953 this.outputTensors = args.outputTensors;
59954 /*
59955 Following 2 properties: input and output masks.
59956 List of tensors, 1:1 mapping with inputTensor.
59957 */
59958 this.inputMasks = args.inputMasks;
59959 // List of tensors, created by outboundLayer.computeMask().
59960 this.outputMasks = args.outputMasks;
59961 // Following 2 properties: input and output shapes.
59962 // List of shape tuples, shapes of inputTensors.
59963 this.inputShapes = args.inputShapes;
59964 // List of shape tuples, shapes of outputTensors.
59965 this.outputShapes = args.outputShapes;
59966 // Add nodes to all layers involved.
59967 var _iterator = _createForOfIteratorHelper(args.inboundLayers),
59968 _step;
59969 try {
59970 for (_iterator.s(); !(_step = _iterator.n()).done;) {
59971 var layer = _step.value;
59972 if (layer != null) {
59973 layer.outboundNodes.push(this);
59974 }
59975 }
59976 } catch (err) {
59977 _iterator.e(err);
59978 } finally {
59979 _iterator.f();
59980 }
59981 args.outboundLayer.inboundNodes.push(this);
59982 }
59983 _createClass(Node, [{
59984 key: "getConfig",
59985 value: function getConfig() {
59986 var inboundNames = [];
59987 var _iterator2 = _createForOfIteratorHelper(this.inboundLayers),
59988 _step2;
59989 try {
59990 for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
59991 var layer = _step2.value;
59992 if (layer != null) {
59993 inboundNames.push(layer.name);
59994 } else {
59995 inboundNames.push(null);
59996 }
59997 }
59998 } catch (err) {
59999 _iterator2.e(err);
60000 } finally {
60001 _iterator2.f();
60002 }
60003 return {
60004 outboundLayer: this.outboundLayer ? this.outboundLayer.name : null,
60005 inboundLayers: inboundNames,
60006 nodeIndices: this.nodeIndices,
60007 tensorIndices: this.tensorIndices
60008 };
60009 }
60010 }]);
60011 return Node;
60012 }();
60013 var _nextLayerID = 0;
60014 /**
60015 * A layer is a grouping of operations and weights that can be composed to
60016 * create a `tf.LayersModel`.
60017 *
60018 * Layers are constructed by using the functions under the
60019 * [tf.layers](#Layers-Basic) namespace.
60020 *
60021 * @doc {heading: 'Layers', subheading: 'Classes', namespace: 'layers'}
60022 */
60023 var Layer = /*#__PURE__*/function (_serialization$Serial) {
60024 _inherits(Layer, _serialization$Serial);
60025 var _super = _createSuper(Layer);
60026 function Layer() {
60027 var _this;
60028 var args = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};
60029 _classCallCheck(this, Layer);
60030 _this = _super.call(this);
60031 _this._callHook = null;
60032 _this._addedWeightNames = [];
60033 // Porting Notes: PyKeras does not have this property in this base Layer
60034 // class. Instead lets Layer subclass set it dynamically and checks the
60035 // value with `hasattr`. In tfjs-layers, we let this be a member of this
60036 // base class.
60037 _this._stateful = false;
60038 _this.id = _nextLayerID++;
60039 _this.activityRegularizer = null;
60040 _this.inputSpec = null;
60041 _this.supportsMasking = false;
60042 // These properties will be set upon call of this.build()
60043 _this._trainableWeights = [];
60044 _this._nonTrainableWeights = [];
60045 _this._losses = [];
60046 _this._updates = [];
60047 _this._built = false;
60048 /*
60049 These lists will be filled via successive calls
60050 to this.addInboundNode().
60051 */
60052 _this.inboundNodes = [];
60053 _this.outboundNodes = [];
60054 var name = args.name;
60055 if (!name) {
60056 var prefix = _this.getClassName();
60057 name = toSnakeCase(prefix) + '_' + getUid(prefix);
60058 }
60059 _this.name = name;
60060 _this.trainable_ = args.trainable == null ? true : args.trainable;
60061 if (args.inputShape != null || args.batchInputShape != null) {
60062 /*
60063 In this case we will later create an input layer
60064 to insert before the current layer
60065 */
60066 var batchInputShape;
60067 if (args.batchInputShape != null) {
60068 batchInputShape = args.batchInputShape;
60069 } else if (args.inputShape != null) {
60070 var batchSize = null;
60071 if (args.batchSize != null) {
60072 batchSize = args.batchSize;
60073 }
60074 batchInputShape = [batchSize].concat(args.inputShape);
60075 }
60076 _this.batchInputShape = batchInputShape;
60077 // Set dtype.
60078 var dtype = args.dtype;
60079 if (dtype == null) {
60080 dtype = args.inputDType;
60081 }
60082 if (dtype == null) {
60083 dtype = 'float32';
60084 }
60085 _this.dtype = dtype;
60086 }
60087 if (args.weights != null) {
60088 _this.initialWeights = args.weights;
60089 } else {
60090 _this.initialWeights = null;
60091 }
60092 // The value of `_refCount` is initialized to null. When the layer is used
60093 // in a symbolic way for the first time, it will be set to 1.
60094 _this._refCount = null;
60095 _this.fastWeightInitDuringBuild = false;
60096 return _this;
60097 }
60098 /**
60099 * Converts a layer and its index to a unique (immutable type) name.
60100 * This function is used internally with `this.containerNodes`.
60101 * @param layer The layer.
60102 * @param nodeIndex The layer's position (e.g. via enumerate) in a list of
60103 * nodes.
60104 *
60105 * @returns The unique name.
60106 */
60107 _createClass(Layer, [{
60108 key: "getNodeAtIndex",
60109 value:
60110 /**
60111 * Returns this.inboundNode at index nodeIndex.
60112 *
60113 * Porting note: This is a replacement for _get_node_attribute_at_index()
60114 * @param nodeIndex
60115 * @param attrName The name of the attribute related to request for this node.
60116 */
60117 function getNodeAtIndex(nodeIndex, attrName) {
60118 if (this.inboundNodes.length === 0) {
60119 throw new RuntimeError('The layer has never been called ' + "and thus has no defined ".concat(attrName, "."));
60120 }
60121 if (this.inboundNodes.length <= nodeIndex) {
60122 throw new ValueError("Asked to get ".concat(attrName, " at node ").concat(nodeIndex, ", ") + "but the layer has only ".concat(this.inboundNodes.length, " inbound nodes."));
60123 }
60124 return this.inboundNodes[nodeIndex];
60125 }
60126 /**
60127 * Retrieves the input tensor(s) of a layer at a given node.
60128 *
60129 * @param nodeIndex Integer, index of the node from which to retrieve the
60130 * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
60131 * was called.
60132 *
60133 * @return A tensor (or list of tensors if the layer has multiple inputs).
60134 */
60135 }, {
60136 key: "getInputAt",
60137 value: function getInputAt(nodeIndex) {
60138 return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
60139 }
60140 /**
60141 * Retrieves the output tensor(s) of a layer at a given node.
60142 *
60143 * @param nodeIndex Integer, index of the node from which to retrieve the
60144 * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
60145 * was called.
60146 *
60147 * @return A tensor (or list of tensors if the layer has multiple outputs).
60148 */
60149 }, {
60150 key: "getOutputAt",
60151 value: function getOutputAt(nodeIndex) {
60152 return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
60153 }
60154 // Properties
60155 /**
60156 * Retrieves the input tensor(s) of a layer.
60157 *
60158 * Only applicable if the layer has exactly one inbound node,
60159 * i.e. if it is connected to one incoming layer.
60160 *
60161 * @return Input tensor or list of input tensors.
60162 *
60163 * @exception AttributeError if the layer is connected to more than one
60164 * incoming layers.
60165 */
60166 }, {
60167 key: "input",
60168 get: function get() {
60169 if (this.inboundNodes.length > 1) {
60170 throw new AttributeError("Layer ".concat(this.name) + ' has multiple inbound nodes, ' + 'hence the notion of "layer input" ' + 'is ill-defined. ' + 'Use `getInputAt(nodeIndex)` instead.');
60171 } else if (this.inboundNodes.length === 0) {
60172 throw new AttributeError("Layer ".concat(this.name) + ' is not connected, no input to return.');
60173 }
60174 return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
60175 }
60176 /**
60177 * Retrieves the output tensor(s) of a layer.
60178 *
60179 * Only applicable if the layer has exactly one inbound node,
60180 * i.e. if it is connected to one incoming layer.
60181 *
60182 * @return Output tensor or list of output tensors.
60183 *
60184 * @exception AttributeError if the layer is connected to more than one
60185 * incoming layers.
60186 */
60187 }, {
60188 key: "output",
60189 get: function get() {
60190 if (this.inboundNodes.length === 0) {
60191 throw new AttributeError("Layer ".concat(this.name) + ' has no inbound nodes.');
60192 }
60193 if (this.inboundNodes.length > 1) {
60194 throw new AttributeError("Layer ".concat(this.name) + ' has multiple inbound nodes, ' + 'hence the notion of "layer output" ' + 'is ill-defined. ' + 'Use `getOutputAt(nodeIndex)` instead.');
60195 }
60196 return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
60197 }
60198 }, {
60199 key: "losses",
60200 get: function get() {
60201 return this._losses;
60202 }
60203 /**
60204 * Retrieves the Layer's current loss values.
60205 *
60206 * Used for regularizers during training.
60207 */
60208 }, {
60209 key: "calculateLosses",
60210 value: function calculateLosses() {
60211 // Porting Node: This is an augmentation to Layer.loss in PyKeras.
60212 // In PyKeras, Layer.loss returns symbolic tensors. Here a concrete
60213 // Tensor (specifically Scalar) values are returned. This is due to the
60214 // imperative backend.
60215 return this.losses.map(function (lossFn) {
60216 return lossFn();
60217 });
60218 }
60219 }, {
60220 key: "updates",
60221 get: function get() {
60222 return this._updates;
60223 }
60224 }, {
60225 key: "built",
60226 get: function get() {
60227 return this._built;
60228 },
60229 set: function set(built) {
60230 this._built = built;
60231 }
60232 }, {
60233 key: "trainable",
60234 get: function get() {
60235 return this.trainable_;
60236 },
60237 set: function set(trainable) {
60238 this._trainableWeights.forEach(function (w) {
60239 return w.trainable = trainable;
60240 });
60241 this.trainable_ = trainable;
60242 }
60243 }, {
60244 key: "trainableWeights",
60245 get: function get() {
60246 if (this.trainable_) {
60247 return this._trainableWeights.filter(function (w) {
60248 return w.trainable;
60249 });
60250 } else {
60251 return [];
60252 }
60253 },
60254 set: function set(weights) {
60255 this._trainableWeights = weights;
60256 }
60257 }, {
60258 key: "nonTrainableWeights",
60259 get: function get() {
60260 if (this.trainable) {
60261 return this._trainableWeights.filter(function (w) {
60262 return !w.trainable;
60263 }).concat(this._nonTrainableWeights);
60264 } else {
60265 return this._trainableWeights.concat(this._nonTrainableWeights);
60266 }
60267 },
60268 set: function set(weights) {
60269 this._nonTrainableWeights = weights;
60270 }
60271 /**
60272 * The concatenation of the lists trainableWeights and nonTrainableWeights
60273 * (in this order).
60274 */
60275 }, {
60276 key: "weights",
60277 get: function get() {
60278 return this.trainableWeights.concat(this.nonTrainableWeights);
60279 }
60280 }, {
60281 key: "stateful",
60282 get: function get() {
60283 return this._stateful;
60284 }
60285 /**
60286 * Reset the states of the layer.
60287 *
60288 * This method of the base Layer class is essentially a no-op.
60289 * Subclasses that are stateful (e.g., stateful RNNs) should override this
60290 * method.
60291 */
60292 }, {
60293 key: "resetStates",
60294 value: function resetStates() {
60295 if (!this.stateful) {
60296 throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' + 'object.');
60297 }
60298 }
60299 /**
60300 * Checks compatibility between the layer and provided inputs.
60301 *
60302 * This checks that the tensor(s) `input`
60303 * verify the input assumptions of the layer
60304 * (if any). If not, exceptions are raised.
60305 *
60306 * @param inputs Input tensor or list of input tensors.
60307 *
60308 * @exception ValueError in case of mismatch between
60309 * the provided inputs and the expectations of the layer.
60310 */
60311 }, {
60312 key: "assertInputCompatibility",
60313 value: function assertInputCompatibility(inputs) {
60314 var inputsList = toList(inputs);
60315 if (this.inputSpec == null || this.inputSpec.length === 0) {
60316 return;
60317 }
60318 var inputSpec = toList(this.inputSpec);
60319 if (inputsList.length !== inputSpec.length) {
60320 throw new ValueError("Layer ".concat(this.name, " expects ").concat(inputSpec.length, " inputs, ") + "but it received ".concat(inputsList.length, " input tensors. ") + "Input received: ".concat(inputs));
60321 }
60322 for (var inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
60323 var x = inputsList[inputIndex];
60324 var spec = inputSpec[inputIndex];
60325 if (spec == null) {
60326 continue;
60327 }
60328 // Check ndim.
60329 var ndim = x.rank;
60330 if (spec.ndim != null) {
60331 if (ndim !== spec.ndim) {
60332 throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name, ": ") + "expected ndim=".concat(spec.ndim, ", found ndim=").concat(ndim));
60333 }
60334 }
60335 if (spec.maxNDim != null) {
60336 if (ndim > spec.maxNDim) {
60337 throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name) + ": expected max_ndim=".concat(spec.maxNDim, ", found ndim=").concat(ndim));
60338 }
60339 }
60340 if (spec.minNDim != null) {
60341 if (ndim < spec.minNDim) {
60342 throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name) + ": expected min_ndim=".concat(spec.minNDim, ", found ndim=").concat(ndim, "."));
60343 }
60344 }
60345 // Check dtype.
60346 if (spec.dtype != null) {
60347 if (x.dtype !== spec.dtype) {
60348 throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name, " ") + ": expected dtype=".concat(spec.dtype, ", found dtype=").concat(x.dtype, "."));
60349 }
60350 }
60351 // Check specific shape axes.
60352 if (spec.axes) {
60353 var xShape = x.shape;
60354 for (var key in spec.axes) {
60355 var axis = Number(key);
60356 var value = spec.axes[key];
60357 // Perform Python-style slicing in case axis < 0;
60358 // TODO(cais): Use https://github.com/alvivi/typescript-underscore to
60359 // ensure type safety through Underscore calls.
60360 var xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
60361 if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
60362 throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ") + "".concat(this.name, ": expected axis ").concat(axis, " of input shape to ") + "have value ".concat(value, " but got shape ").concat(xShape, "."));
60363 }
60364 }
60365 }
60366 // Check shape.
60367 if (spec.shape != null) {
60368 for (var i = 0; i < spec.shape.length; ++i) {
60369 var specDim = spec.shape[i];
60370 var dim = x.shape[i];
60371 if (specDim != null && dim != null) {
60372 if (specDim !== dim) {
60373 throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ") + "".concat(this.name, ": expected shape=").concat(spec.shape, ", ") + "found shape=".concat(x.shape, "."));
60374 }
60375 }
60376 }
60377 }
60378 }
60379 }
60380 /**
60381 * This is where the layer's logic lives.
60382 *
60383 * @param inputs Input tensor, or list/tuple of input tensors.
60384 * @param kwargs Additional keyword arguments.
60385 *
60386 * @return A tensor or list/tuple of tensors.
60387 */
60388 }, {
60389 key: "call",
60390 value: function call(inputs, kwargs) {
60391 return inputs;
60392 }
60393 }, {
60394 key: "invokeCallHook",
60395 value: function invokeCallHook(inputs, kwargs) {
60396 if (this._callHook != null) {
60397 this._callHook(inputs, kwargs);
60398 }
60399 }
60400 /**
60401 * Set call hook.
60402 * This is currently used for testing only.
60403 * @param callHook
60404 */
60405 }, {
60406 key: "setCallHook",
60407 value: function setCallHook(callHook) {
60408 this._callHook = callHook;
60409 }
60410 /**
60411 * Clear call hook.
60412 * This is currently used for testing only.
60413 */
60414 }, {
60415 key: "clearCallHook",
60416 value: function clearCallHook() {
60417 this._callHook = null;
60418 }
60419 /**
60420 * Builds or executes a `Layer`'s logic.
60421 *
60422 * When called with `tf.Tensor`(s), execute the `Layer`'s computation and
60423 * return Tensor(s). For example:
60424 *
60425 * ```js
60426 * const denseLayer = tf.layers.dense({
60427 * units: 1,
60428 * kernelInitializer: 'zeros',
60429 * useBias: false
60430 * });
60431 *
60432 * // Invoke the layer's apply() method with a `tf.Tensor` (with concrete
60433 * // numeric values).
60434 * const input = tf.ones([2, 2]);
60435 * const output = denseLayer.apply(input);
60436 *
60437 * // The output's value is expected to be [[0], [0]], due to the fact that
60438 * // the dense layer has a kernel initialized to all-zeros and does not have
60439 * // a bias.
60440 * output.print();
60441 * ```
60442 *
60443 * When called with `tf.SymbolicTensor`(s), this will prepare the layer for
60444 * future execution. This entails internal book-keeping on shapes of
60445 * expected Tensors, wiring layers together, and initializing weights.
60446 *
60447 * Calling `apply` with `tf.SymbolicTensor`s are typically used during the
60448 * building of non-`tf.Sequential` models. For example:
60449 *
60450 * ```js
60451 * const flattenLayer = tf.layers.flatten();
60452 * const denseLayer = tf.layers.dense({units: 1});
60453 *
60454 * // Use tf.layers.input() to obtain a SymbolicTensor as input to apply().
60455 * const input = tf.input({shape: [2, 2]});
60456 * const output1 = flattenLayer.apply(input);
60457 *
60458 * // output1.shape is [null, 4]. The first dimension is the undetermined
60459 * // batch size. The second dimension comes from flattening the [2, 2]
60460 * // shape.
60461 * console.log(JSON.stringify(output1.shape));
60462 *
60463 * // The output SymbolicTensor of the flatten layer can be used to call
60464 * // the apply() of the dense layer:
60465 * const output2 = denseLayer.apply(output1);
60466 *
60467 * // output2.shape is [null, 1]. The first dimension is the undetermined
60468 * // batch size. The second dimension matches the number of units of the
60469 * // dense layer.
60470 * console.log(JSON.stringify(output2.shape));
60471 *
60472 * // The input and output can be used to construct a model that consists
60473 * // of the flatten and dense layers.
60474 * const model = tf.model({inputs: input, outputs: output2});
60475 * ```
60476 *
60477 * @param inputs a `tf.Tensor` or `tf.SymbolicTensor` or an Array of them.
60478 * @param kwargs Additional keyword arguments to be passed to `call()`.
60479 *
60480 * @return Output of the layer's `call` method.
60481 *
60482 * @exception ValueError error in case the layer is missing shape information
60483 * for its `build` call.
60484 *
60485 * @doc {heading: 'Models', 'subheading': 'Classes'}
60486 */
60487 // Porting Note: This is a replacement for __call__() in Python.
60488 }, {
60489 key: "apply",
60490 value: function apply(inputs, kwargs) {
60491 var _this2 = this;
60492 kwargs = kwargs || {};
60493 this.assertNotDisposed();
60494 // Ensure inputs are all the same type.
60495 var inputsList = toList(inputs);
60496 var allAreSymbolic = checkAllSymbolic(inputs);
60497 var noneAreSymbolic = checkNoneSymbolic(inputs);
60498 if (allAreSymbolic === noneAreSymbolic) {
60499 throw new ValueError('Arguments to apply() must be all ' + 'SymbolicTensors or all Tensors');
60500 }
60501 // TODO(michaelterry): nameScope() may not be necessary.
60502 return nameScope(this.name, function () {
60503 // Handle laying building (weight creating, input spec locking).
60504 if (!_this2.built) {
60505 /*
60506 Throw exceptions in case the input is not compatible
60507 with the inputSpec specified in the layer constructor.
60508 */
60509 _this2.assertInputCompatibility(inputs);
60510 // Collect input shapes to build layer.
60511 var inputShapes = [];
60512 var _iterator3 = _createForOfIteratorHelper(toList(inputs)),
60513 _step3;
60514 try {
60515 for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
60516 var xElem = _step3.value;
60517 inputShapes.push(xElem.shape);
60518 }
60519 } catch (err) {
60520 _iterator3.e(err);
60521 } finally {
60522 _iterator3.f();
60523 }
60524 _this2.build(singletonOrArray(inputShapes));
60525 _this2.built = true;
60526 // Load weights that were specified at layer instantiation.
60527 if (_this2.initialWeights) {
60528 _this2.setWeights(_this2.initialWeights);
60529 }
60530 if (_this2._refCount === null && noneAreSymbolic) {
60531 // The first use of this layer is a non-symbolic call, set ref count
60532 // to 1 so the Layer can be properly disposed if its dispose() method
60533 // is called.
60534 _this2._refCount = 1;
60535 }
60536 }
60537 /*
60538 Throw exceptions in case the input is not compatible
60539 with the inputSpec set at build time.
60540 */
60541 _this2.assertInputCompatibility(inputs);
60542 // Handle mask propagation.
60543 // TODO(michaelterry): Mask propagation not currently implemented.
60544 // Actually call the layer, collecting output(s), mask(s), and shape(s).
60545 if (noneAreSymbolic) {
60546 var output = _this2.call(inputs, kwargs);
60547 // Apply masks to the output tensors if the layer supports it.
60548 if (_this2.supportsMasking) {
60549 // TODO(mattsoulanille): pass the input tensors' masks to computeMask
60550 _this2.setMaskMetadata(inputs, output);
60551 }
60552 // If the layer returns tensors from its inputs, unmodified,
60553 // we copy them to avoid loss of tensor metadata.
60554 var outputList = toList(output);
60555 var outputListCopy = [];
60556 // TODO(michaelterry): This copying may not be necessary given our eager
60557 // backend.
60558 var _iterator4 = _createForOfIteratorHelper(outputList),
60559 _step4;
60560 try {
60561 for (_iterator4.s(); !(_step4 = _iterator4.n()).done;) {
60562 var x = _step4.value;
60563 if (inputsList.indexOf(x) !== -1) {
60564 x = x.clone();
60565 }
60566 outputListCopy.push(x);
60567 }
60568 } catch (err) {
60569 _iterator4.e(err);
60570 } finally {
60571 _iterator4.f();
60572 }
60573 output = singletonOrArray(outputListCopy);
60574 if (_this2.activityRegularizer != null) {
60575 throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.');
60576 }
60577 // TODO(michaelterry): Call addInboundNode()?
60578 return output;
60579 } else {
60580 var inputShape = collectInputShape(inputs);
60581 var outputShape = _this2.computeOutputShape(inputShape);
60582 var _output;
60583 var outputDType = guessOutputDType(inputs);
60584 _this2.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] : inputShape);
60585 if (outputShape != null && outputShape.length > 0 && Array.isArray(outputShape[0])) {
60586 // We have multiple output shapes. Create multiple output tensors.
60587 _output = outputShape.map(function (shape, index) {
60588 return new SymbolicTensor(outputDType, shape, _this2, toList(inputs), kwargs, _this2.name, index);
60589 });
60590 } else {
60591 _output = new SymbolicTensor(outputDType, outputShape, _this2, toList(inputs), kwargs, _this2.name);
60592 }
60593 /*
60594 Add an inbound node to the layer, so that it keeps track
60595 of the call and of all new variables created during the call.
60596 This also updates the layer history of the output tensor(s).
60597 If the input tensor(s) had no previous history,
60598 this does nothing.
60599 */
60600 _this2.addInboundNode(inputs, _output, null, null, inputShape, outputShape, kwargs);
60601 _this2._refCount++;
60602 if (_this2.activityRegularizer != null) {
60603 throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.');
60604 }
60605 return _output;
60606 }
60607 });
60608 }
60609 /**
60610 * Check compatibility between input shape and this layer's batchInputShape.
60611 *
60612 * Print warning if any incompatibility is found.
60613 *
60614 * @param inputShape Input shape to be checked.
60615 */
60616 }, {
60617 key: "warnOnIncompatibleInputShape",
60618 value: function warnOnIncompatibleInputShape(inputShape) {
60619 if (this.batchInputShape == null) {
60620 return;
60621 } else if (inputShape.length !== this.batchInputShape.length) {
60622 console.warn("The rank of the input tensor provided (shape: " + "".concat(JSON.stringify(inputShape), ") does not match that of the ") + "batchInputShape (".concat(JSON.stringify(this.batchInputShape), ") ") + "of the layer ".concat(this.name));
60623 } else {
60624 var dimMismatch = false;
60625 this.batchInputShape.forEach(function (dimension, i) {
60626 if (dimension != null && inputShape[i] != null && inputShape[i] !== dimension) {
60627 dimMismatch = true;
60628 }
60629 });
60630 if (dimMismatch) {
60631 console.warn("The shape of the input tensor " + "(".concat(JSON.stringify(inputShape), ") does not ") + "match the expectation of layer ".concat(this.name, ": ") + "".concat(JSON.stringify(this.batchInputShape)));
60632 }
60633 }
60634 }
60635 /**
60636 * Retrieves the output shape(s) of a layer.
60637 *
60638 * Only applicable if the layer has only one inbound node, or if all inbound
60639 * nodes have the same output shape.
60640 *
60641 * @returns Output shape or shapes.
60642 * @throws AttributeError: if the layer is connected to more than one incoming
60643 * nodes.
60644 *
60645 * @doc {heading: 'Models', 'subheading': 'Classes'}
60646 */
60647 }, {
60648 key: "outputShape",
60649 get: function get() {
60650 if (this.inboundNodes == null || this.inboundNodes.length === 0) {
60651 throw new AttributeError("The layer ".concat(this.name, " has never been called and thus has no ") + "defined output shape.");
60652 }
60653 var allOutputShapes = [];
60654 var _iterator5 = _createForOfIteratorHelper(this.inboundNodes),
60655 _step5;
60656 try {
60657 for (_iterator5.s(); !(_step5 = _iterator5.n()).done;) {
60658 var node = _step5.value;
60659 var shapeString = JSON.stringify(node.outputShapes);
60660 if (allOutputShapes.indexOf(shapeString) === -1) {
60661 allOutputShapes.push(shapeString);
60662 }
60663 }
60664 } catch (err) {
60665 _iterator5.e(err);
60666 } finally {
60667 _iterator5.f();
60668 }
60669 if (allOutputShapes.length === 1) {
60670 var outputShapes = this.inboundNodes[0].outputShapes;
60671 if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) && outputShapes.length === 1) {
60672 return outputShapes[0];
60673 } else {
60674 return outputShapes;
60675 }
60676 } else {
60677 throw new AttributeError("The layer ".concat(this.name, " has multiple inbound nodes with different ") + "output shapes. Hence the notion of \"output shape\" is ill-defined " + "for the layer.");
60678 // TODO(cais): Implement getOutputShapeAt().
60679 }
60680 }
60681 /**
60682 * Counts the total number of numbers (e.g., float32, int32) in the
60683 * weights.
60684 *
60685 * @returns An integer count.
60686 * @throws RuntimeError: If the layer is not built yet (in which case its
60687 * weights are not defined yet.)
60688 *
60689 * @doc {heading: 'Models', 'subheading': 'Classes'}
60690 */
60691 }, {
60692 key: "countParams",
60693 value: function countParams() {
60694 if (!this.built) {
60695 throw new RuntimeError("You tried to call countParams() on ".concat(this.name, ", ") + "but the layer is not built yet. Build it first by calling " + "build(batchInputShape).");
60696 }
60697 return countParamsInWeights(this.weights);
60698 }
60699 /**
60700 * Creates the layer weights.
60701 *
60702 * Must be implemented on all layers that have weights.
60703 *
60704 * Called when apply() is called to construct the weights.
60705 *
60706 * @param inputShape A `Shape` or array of `Shape` (unused).
60707 *
60708 * @doc {heading: 'Models', 'subheading': 'Classes'}
60709 */
60710 }, {
60711 key: "build",
60712 value: function build(inputShape) {
60713 this.built = true;
60714 }
60715 /**
60716 * Returns the current values of the weights of the layer.
60717 *
60718 * @param trainableOnly Whether to get the values of only trainable weights.
60719 * @returns Weight values as an `Array` of `tf.Tensor`s.
60720 *
60721 * @doc {heading: 'Models', 'subheading': 'Classes'}
60722 */
60723 }, {
60724 key: "getWeights",
60725 value: function getWeights() {
60726 var trainableOnly = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : false;
60727 return batchGetValue(trainableOnly ? this.trainableWeights : this.weights);
60728 }
60729 /**
60730 * Sets the weights of the layer, from Tensors.
60731 *
60732 * @param weights a list of Tensors. The number of arrays and their shape
60733 * must match number of the dimensions of the weights of the layer (i.e.
60734 * it should match the output of `getWeights`).
60735 *
60736 * @exception ValueError If the provided weights list does not match the
60737 * layer's specifications.
60738 *
60739 * @doc {heading: 'Models', 'subheading': 'Classes'}
60740 */
60741 }, {
60742 key: "setWeights",
60743 value: function setWeights(weights) {
60744 var _this3 = this;
60745 tidy(function () {
60746 var params = _this3.weights;
60747 if (params.length !== weights.length) {
60748 // TODO(cais): Restore the following and use `providedWeights`, instead
60749 // of `weights` in the error message, once the deeplearn.js bug is
60750 // fixed: https://github.com/PAIR-code/deeplearnjs/issues/498 const
60751 // providedWeights = JSON.stringify(weights).slice(0, 50);
60752 throw new ValueError("You called setWeights(weights) on layer \"".concat(_this3.name, "\" ") + "with a weight list of length ".concat(weights.length, ", ") + "but the layer was expecting ".concat(params.length, " weights. ") + "Provided weights: ".concat(weights, "..."));
60753 }
60754 if (params.length === 0) {
60755 return;
60756 }
60757 var weightValueTuples = [];
60758 var paramValues = batchGetValue(params);
60759 for (var i = 0; i < paramValues.length; ++i) {
60760 var pv = paramValues[i];
60761 var p = params[i];
60762 var w = weights[i];
60763 if (!arraysEqual(pv.shape, w.shape)) {
60764 throw new ValueError("Layer weight shape ".concat(pv.shape, " ") + "not compatible with provided weight shape ".concat(w.shape));
60765 }
60766 weightValueTuples.push([p, w]);
60767 }
60768 batchSetValue(weightValueTuples);
60769 });
60770 }
60771 /**
60772 * Adds a weight variable to the layer.
60773 *
60774 * @param name Name of the new weight variable.
60775 * @param shape The shape of the weight.
60776 * @param dtype The dtype of the weight.
60777 * @param initializer An initializer instance.
60778 * @param regularizer A regularizer instance.
60779 * @param trainable Whether the weight should be trained via backprop or not
60780 * (assuming that the layer itself is also trainable).
60781 * @param constraint An optional trainable.
60782 * @return The created weight variable.
60783 *
60784 * @doc {heading: 'Models', 'subheading': 'Classes'}
60785 */
60786 }, {
60787 key: "addWeight",
60788 value: function addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint, getInitializerFunc) {
60789 // Reject duplicate weight names.
60790 if (this._addedWeightNames.indexOf(name) !== -1) {
60791 throw new ValueError("Duplicate weight name ".concat(name, " for layer ").concat(this.name));
60792 }
60793 this._addedWeightNames.push(name);
60794 if (dtype == null) {
60795 dtype = 'float32';
60796 }
60797 if (this.fastWeightInitDuringBuild) {
60798 initializer = getInitializerFunc != null ? getInitializerFunc() : getInitializer('zeros');
60799 }
60800 var initValue = initializer.apply(shape, dtype);
60801 var weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
60802 initValue.dispose();
60803 // Request backend not to dispose the weights of the model on scope() exit.
60804 if (regularizer != null) {
60805 this.addLoss(function () {
60806 return regularizer.apply(weight.read());
60807 });
60808 }
60809 if (trainable == null) {
60810 trainable = true;
60811 }
60812 if (trainable) {
60813 this._trainableWeights.push(weight);
60814 } else {
60815 this._nonTrainableWeights.push(weight);
60816 }
60817 return weight;
60818 }
60819 /**
60820 * Set the fast-weight-initialization flag.
60821 *
60822 * In cases where the initialized weight values will be immediately
60823 * overwritten by loaded weight values during model loading, setting
60824 * the flag to `true` saves unnecessary calls to potentially expensive
60825 * initializers and speeds up the loading process.
60826 *
60827 * @param value Target value of the flag.
60828 */
60829 }, {
60830 key: "setFastWeightInitDuringBuild",
60831 value: function setFastWeightInitDuringBuild(value) {
60832 this.fastWeightInitDuringBuild = value;
60833 }
60834 /**
60835 * Add losses to the layer.
60836 *
60837 * The loss may potentially be conditional on some inputs tensors,
60838 * for instance activity losses are conditional on the layer's inputs.
60839 *
60840 * @doc {heading: 'Models', 'subheading': 'Classes'}
60841 */
60842 }, {
60843 key: "addLoss",
60844 value: function addLoss(losses) {
60845 if (losses == null || Array.isArray(losses) && losses.length === 0) {
60846 return;
60847 }
60848 // Update this.losses
60849 losses = toList(losses);
60850 if (this._losses !== undefined && this._losses !== null) {
60851 var _this$losses;
60852 (_this$losses = this.losses).push.apply(_this$losses, _toConsumableArray(losses));
60853 }
60854 }
60855 /**
60856 * Computes the output shape of the layer.
60857 *
60858 * Assumes that the layer will be built to match that input shape provided.
60859 *
60860 * @param inputShape A shape (tuple of integers) or a list of shape tuples
60861 * (one per output tensor of the layer). Shape tuples can include null for
60862 * free dimensions, instead of an integer.
60863 *
60864 * @doc {heading: 'Models', 'subheading': 'Classes'}
60865 */
60866 }, {
60867 key: "computeOutputShape",
60868 value: function computeOutputShape(inputShape) {
60869 return inputShape;
60870 }
60871 /**
60872 * Computes an output mask tensor.
60873 *
60874 * @param inputs Tensor or list of tensors.
60875 * @param mask Tensor or list of tensors.
60876 *
60877 * @return null or a tensor (or list of tensors, one per output tensor of the
60878 * layer).
60879 */
60880 }, {
60881 key: "computeMask",
60882 value: function computeMask(inputs, mask) {
60883 var _this4 = this;
60884 if (!this.supportsMasking) {
60885 if (mask != null) {
60886 if (Array.isArray(mask)) {
60887 mask.forEach(function (maskElement) {
60888 if (maskElement != null) {
60889 throw new TypeError("Layer ".concat(_this4.name, " does not support masking, ") + 'but was passed an inputMask.');
60890 }
60891 });
60892 } else {
60893 throw new TypeError("Layer ".concat(this.name, " does not support masking, ") + 'but was passed an inputMask.');
60894 }
60895 }
60896 // masking not explicitly supported: return null as mask
60897 return null;
60898 }
60899 // if masking is explictly supported, by default
60900 // carry over the input mask
60901 return mask;
60902 }
60903 }, {
60904 key: "setMaskMetadata",
60905 value: function setMaskMetadata(inputs, outputs, previousMask) {
60906 if (!this.supportsMasking) {
60907 return;
60908 }
60909 var outputMasks = this.computeMask(inputs, previousMask);
60910 var outputsList = toList(outputs);
60911 var outputMasksList = toList(outputMasks);
60912 if (outputsList.length !== outputMasksList.length) {
60913 throw new Error("".concat(this.name, " outputs ").concat(outputsList.length, " tensors ") + "but ".concat(outputsList.length, " masks for those tensors"));
60914 }
60915 for (var i = 0; i < outputsList.length; i++) {
60916 outputsList[i].kerasMask = outputMasksList[i];
60917 }
60918 }
60919 /**
60920 * Internal method to create an inbound node for the layer.
60921 *
60922 * @param inputTensors List of input tensors.
60923 * @param outputTensors List of output tensors.
60924 * @param inputMasks List of input masks (a mask can be a tensor, or null).
60925 * @param outputMasks List of output masks (a mask can be a tensor, or null).
60926 * @param inputShapes List of input shape tuples.
60927 * @param outputShapes List of output shape tuples.
60928 * @param kwargs Dictionary of keyword arguments that were passed to the
60929 * `call` method of the layer at the call that created the node.
60930 */
60931 }, {
60932 key: "addInboundNode",
60933 value: function addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes) {
60934 var kwargs = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : null;
60935 var inputTensorList = toList(inputTensors);
60936 outputTensors = toList(outputTensors);
60937 inputMasks = toList(inputMasks);
60938 outputMasks = toList(outputMasks);
60939 inputShapes = normalizeShapeList(inputShapes);
60940 outputShapes = normalizeShapeList(outputShapes);
60941 // Collect input tensor(s) coordinates.
60942 var inboundLayers = [];
60943 var nodeIndices = [];
60944 var tensorIndices = [];
60945 var _iterator6 = _createForOfIteratorHelper(inputTensorList),
60946 _step6;
60947 try {
60948 for (_iterator6.s(); !(_step6 = _iterator6.n()).done;) {
60949 var x = _step6.value;
60950 /*
60951 * TODO(michaelterry): Keras adds this value to tensors; it's not
60952 * clear whether we'll use this or not.
60953 */
60954 inboundLayers.push(x.sourceLayer);
60955 nodeIndices.push(x.nodeIndex);
60956 tensorIndices.push(x.tensorIndex);
60957 }
60958 // Create node, add it to inbound nodes.
60959 // (This call has side effects.)
60960 // tslint:disable-next-line:no-unused-expression
60961 } catch (err) {
60962 _iterator6.e(err);
60963 } finally {
60964 _iterator6.f();
60965 }
60966 new Node({
60967 outboundLayer: this,
60968 inboundLayers: inboundLayers,
60969 nodeIndices: nodeIndices,
60970 tensorIndices: tensorIndices,
60971 inputTensors: inputTensorList,
60972 outputTensors: outputTensors,
60973 inputMasks: inputMasks,
60974 outputMasks: outputMasks,
60975 inputShapes: inputShapes,
60976 outputShapes: outputShapes
60977 }, kwargs);
60978 // Update tensor history
60979 for (var i = 0; i < outputTensors.length; i++) {
60980 // TODO(michaelterry: _uses_learning_phase not tracked.
60981 outputTensors[i].sourceLayer = this;
60982 outputTensors[i].nodeIndex = this.inboundNodes.length - 1;
60983 outputTensors[i].tensorIndex = i;
60984 }
60985 }
60986 /**
60987 * Returns the config of the layer.
60988 *
60989 * A layer config is a TS dictionary (serializable)
60990 * containing the configuration of a layer.
60991 * The same layer can be reinstantiated later
60992 * (without its trained weights) from this configuration.
60993 *
60994 * The config of a layer does not include connectivity
60995 * information, nor the layer class name. These are handled
60996 * by 'Container' (one layer of abstraction above).
60997 *
60998 * Porting Note: The TS dictionary follows TS naming standards for
60999 * keys, and uses tfjs-layers type-safe Enums. Serialization methods
61000 * should use a helper function to convert to the pythonic storage
61001 * standard. (see serialization_utils.convertTsToPythonic)
61002 *
61003 * @returns TS dictionary of configuration.
61004 *
61005 * @doc {heading: 'Models', 'subheading': 'Classes'}
61006 */
61007 }, {
61008 key: "getConfig",
61009 value: function getConfig() {
61010 var config = {
61011 name: this.name,
61012 trainable: this.trainable
61013 };
61014 if (this.batchInputShape != null) {
61015 config['batchInputShape'] = this.batchInputShape;
61016 }
61017 if (this.dtype != null) {
61018 config['dtype'] = this.dtype;
61019 }
61020 return config;
61021 }
61022 /**
61023 * Dispose the weight variables that this Layer instance holds.
61024 *
61025 * @returns {number} Number of disposed variables.
61026 */
61027 }, {
61028 key: "disposeWeights",
61029 value: function disposeWeights() {
61030 this.weights.forEach(function (weight) {
61031 return weight.dispose();
61032 });
61033 return this.weights.length;
61034 }
61035 }, {
61036 key: "assertNotDisposed",
61037 value: function assertNotDisposed() {
61038 if (this._refCount === 0) {
61039 throw new Error("Layer '".concat(this.name, "' is already disposed."));
61040 }
61041 }
61042 /**
61043 * Attempt to dispose layer's weights.
61044 *
61045 * This method decreases the reference count of the Layer object by 1.
61046 *
61047 * A Layer is reference-counted. Its reference count is incremented by 1
61048 * the first item its `apply()` method is called and when it becomes a part
61049 * of a new `Node` (through calling the `apply()` method on a
61050 * `tf.SymbolicTensor`).
61051 *
61052 * If the reference count of a Layer becomes 0, all the weights will be
61053 * disposed and the underlying memory (e.g., the textures allocated in WebGL)
61054 * will be freed.
61055 *
61056 * Note: If the reference count is greater than 0 after the decrement, the
61057 * weights of the Layer will *not* be disposed.
61058 *
61059 * After a Layer is disposed, it cannot be used in calls such as `apply()`,
61060 * `getWeights()` or `setWeights()` anymore.
61061 *
61062 * @returns A DisposeResult Object with the following fields:
61063 * - refCountAfterDispose: The reference count of the Container after this
61064 * `dispose()` call.
61065 * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
61066 * during this `dispose()` call.
61067 * @throws {Error} If the layer is not built yet, or if the layer has already
61068 * been disposed.
61069 *
61070 * @doc {heading: 'Models', 'subheading': 'Classes'}
61071 */
61072 }, {
61073 key: "dispose",
61074 value: function dispose() {
61075 if (!this.built) {
61076 throw new Error("Cannot dispose Layer ".concat(this.name, " because it has not been ") + "built yet.");
61077 }
61078 if (this._refCount === null) {
61079 throw new Error("Cannot dispose Layer ".concat(this.name, " because it has not been used ") + "yet.");
61080 }
61081 this.assertNotDisposed();
61082 var numDisposedVariables = 0;
61083 if (--this._refCount === 0) {
61084 numDisposedVariables = this.disposeWeights();
61085 }
61086 return {
61087 refCountAfterDispose: this._refCount,
61088 numDisposedVariables: numDisposedVariables
61089 };
61090 }
61091 }], [{
61092 key: "nodeKey",
61093 value: function nodeKey(layer, nodeIndex) {
61094 return layer.name + '_ib-' + nodeIndex.toString();
61095 }
61096 }]);
61097 return Layer;
61098 }(Serializable);
61099 /**
61100 * Collects the input shape(s) of a list of `tf.Tensor`s or
61101 * `tf.SymbolicTensor`s.
61102 *
61103 * TODO(michaelterry): Update PyKeras docs (backport).
61104 *
61105 * @param inputTensors List of input tensors (or single input tensor).
61106 *
61107 * @return List of shape tuples (or single tuple), one tuple per input.
61108 */
61109 function collectInputShape(inputTensors) {
61110 inputTensors = toList(inputTensors);
61111 var shapes = [];
61112 var _iterator7 = _createForOfIteratorHelper(inputTensors),
61113 _step7;
61114 try {
61115 for (_iterator7.s(); !(_step7 = _iterator7.n()).done;) {
61116 var x = _step7.value;
61117 shapes.push(x.shape);
61118 }
61119 } catch (err) {
61120 _iterator7.e(err);
61121 } finally {
61122 _iterator7.f();
61123 }
61124 return singletonOrArray(shapes);
61125 }
61126 /**
61127 * Guesses output dtype based on inputs.
61128 *
61129 * At present, just returns 'float32' for any input.
61130 *
61131 * @param inputTensors List of input tensors (or single input tensor).
61132 *
61133 * @return The guessed DType. At present, always returns 'float32'.
61134 */
61135 function guessOutputDType(inputTensors) {
61136 return 'float32';
61137 }
61138 /**
61139 * Returns the list of input tensors necessary to compute `tensor`.
61140 *
61141 * Output will always be a list of tensors (potentially with 1 element).
61142 *
61143 * @param tensor The tensor to start from.
61144 * @param layer Origin layer of the tensor.
61145 * @param nodeIndex Origin node index of the tensor.
61146 *
61147 * @return Array of input tensors.
61148 */
61149 function getSourceInputs(tensor, layer, nodeIndex) {
61150 if (layer == null || nodeIndex != null && nodeIndex > 0) {
61151 layer = tensor.sourceLayer;
61152 nodeIndex = tensor.nodeIndex;
61153 }
61154 if (layer.inboundNodes.length === 0) {
61155 return [tensor];
61156 } else {
61157 var node = layer.inboundNodes[nodeIndex];
61158 if (node.inboundLayers.length === 0) {
61159 return node.inputTensors;
61160 } else {
61161 var sourceTensors = [];
61162 for (var i = 0; i < node.inboundLayers.length; i++) {
61163 var x = node.inputTensors[i];
61164 var _layer = node.inboundLayers[i];
61165 var _nodeIndex = node.nodeIndices[i];
61166 var previousSources = getSourceInputs(x, _layer, _nodeIndex);
61167 // Avoid input redundancy.
61168 var _iterator8 = _createForOfIteratorHelper(previousSources),
61169 _step8;
61170 try {
61171 for (_iterator8.s(); !(_step8 = _iterator8.n()).done;) {
61172 var _x = _step8.value;
61173 if (sourceTensors.indexOf(_x) === -1) {
61174 sourceTensors.push(_x);
61175 }
61176 }
61177 } catch (err) {
61178 _iterator8.e(err);
61179 } finally {
61180 _iterator8.f();
61181 }
61182 }
61183 return sourceTensors;
61184 }
61185 }
61186 }
61187 function checkAllSymbolic(tensors) {
61188 var allAreSymbolic = true;
61189 var _iterator9 = _createForOfIteratorHelper(toList(tensors)),
61190 _step9;
61191 try {
61192 for (_iterator9.s(); !(_step9 = _iterator9.n()).done;) {
61193 var tensor = _step9.value;
61194 if (!(tensor instanceof SymbolicTensor)) {
61195 allAreSymbolic = false;
61196 break;
61197 }
61198 }
61199 } catch (err) {
61200 _iterator9.e(err);
61201 } finally {
61202 _iterator9.f();
61203 }
61204 return allAreSymbolic;
61205 }
61206 function checkNoneSymbolic(tensors) {
61207 var noneAreSymbolic = true;
61208 var _iterator10 = _createForOfIteratorHelper(toList(tensors)),
61209 _step10;
61210 try {
61211 for (_iterator10.s(); !(_step10 = _iterator10.n()).done;) {
61212 var tensor = _step10.value;
61213 if (tensor instanceof SymbolicTensor) {
61214 noneAreSymbolic = false;
61215 break;
61216 }
61217 }
61218 } catch (err) {
61219 _iterator10.e(err);
61220 } finally {
61221 _iterator10.f();
61222 }
61223 return noneAreSymbolic;
61224 }
61225
61226 var InputLayer = /*#__PURE__*/function (_Layer) {
61227 _inherits(InputLayer, _Layer);
61228 var _super = _createSuper(InputLayer);
61229 function InputLayer(args) {
61230 var _this;
61231 _classCallCheck(this, InputLayer);
61232 _this = _super.call(this, {
61233 dtype: args.dtype,
61234 name: args.name != null ? args.name : getUid('input').toString()
61235 });
61236 // Normalize config.batchSize and config.sparse
61237 if (args.batchSize == null) {
61238 args.batchSize = null;
61239 }
61240 if (args.sparse == null) {
61241 args.sparse = false;
61242 }
61243 _this.trainable = false;
61244 _this.built = true;
61245 _this.sparse = args.sparse;
61246 if (args.inputShape != null && args.batchInputShape != null) {
61247 throw new ValueError('Only provide the inputShape OR ' + 'batchInputShape argument to inputLayer, not both at the same time.');
61248 }
61249 var batchInputShape = args.batchInputShape;
61250 if (batchInputShape == null) {
61251 if (args.inputShape == null) {
61252 throw new ValueError('An InputLayer should be passed either a ' + '`batchInputShape` or an `inputShape`.');
61253 } else {
61254 batchInputShape = [args.batchSize].concat(args.inputShape);
61255 }
61256 } else {
61257 // TODO(michaelterry): Backport to PyKeras
61258 if (args.batchSize != null) {
61259 throw new ValueError('Cannot specify batchSize if batchInputShape is ' + 'specified when creating an InputLayer.');
61260 }
61261 }
61262 var dtype = args.dtype || 'float32';
61263 _this.batchInputShape = batchInputShape;
61264 _this.dtype = dtype;
61265 // TODO(michaelterry): Backport this to PyKeras?
61266 _this.inputSpec = [{
61267 shape: batchInputShape
61268 }];
61269 var inputTensor = new SymbolicTensor(_this.dtype, _this.batchInputShape, _assertThisInitialized(_this), [], {}, _this.name);
61270 inputTensor.nodeIndex = 0;
61271 inputTensor.tensorIndex = 0;
61272 // Create an input node to add to this.outboundNode.
61273 // (This call has side effects.)
61274 // tslint:disable-next-line:no-unused-expression
61275 new Node({
61276 outboundLayer: _assertThisInitialized(_this),
61277 inboundLayers: [],
61278 nodeIndices: [],
61279 tensorIndices: [],
61280 inputTensors: [inputTensor],
61281 outputTensors: [inputTensor],
61282 inputMasks: [null],
61283 outputMasks: [null],
61284 inputShapes: [batchInputShape],
61285 outputShapes: [batchInputShape]
61286 });
61287 return _this;
61288 }
61289 _createClass(InputLayer, [{
61290 key: "apply",
61291 value: function apply(inputs, kwargs) {
61292 throw new ValueError('Cannot pass any input to an ' + "InputLayer's apply() method. InputLayer name: ".concat(this.name));
61293 }
61294 }, {
61295 key: "dispose",
61296 value: function dispose() {
61297 // dispose() for InputLayer is overridden as no-op.
61298 return {
61299 refCountAfterDispose: this._refCount,
61300 numDisposedVariables: 0
61301 };
61302 }
61303 }, {
61304 key: "getConfig",
61305 value: function getConfig() {
61306 return {
61307 batchInputShape: this.batchInputShape,
61308 dtype: this.dtype,
61309 sparse: this.sparse,
61310 name: this.name
61311 };
61312 }
61313 }]);
61314 return InputLayer;
61315 }(Layer);
61316 /** @nocollapse */
61317 InputLayer.className = 'InputLayer';
61318 registerClass(InputLayer);
61319 function Input(config) {
61320 if (config.batchShape == null && config.shape == null) {
61321 throw new Error('Please provide to Input either a `shape`' + ' or a `batchShape` argument. Note that ' + '`shape` does not include the batch ' + 'dimension.');
61322 }
61323 if (config.batchShape != null && config.shape != null) {
61324 // TODO(michaelterry): Backport to PyKeras.
61325 throw new ValueError('Please provide either a `shape` or `batchShape` ' + 'argument to Input, but not both.');
61326 }
61327 var batchShape = config.batchShape;
61328 if (config.shape != null && batchShape == null) {
61329 batchShape = [null].concat(config.shape);
61330 }
61331 var dtype = config.dtype;
61332 if (dtype == null) {
61333 dtype = 'float32';
61334 }
61335 var inputLayer = new InputLayer({
61336 batchInputShape: batchShape,
61337 name: config.name,
61338 dtype: dtype,
61339 sparse: config.sparse
61340 });
61341 var outputs = inputLayer.inboundNodes[0].outputTensors;
61342 return outputs[0];
61343 }
61344
61345 /**
61346 * Helper function to check the dtype and shape compatibility of a feed value.
61347 */
61348 function assertFeedCompatibility(key, val) {
61349 // Check dtype compatibility.
61350 if (key.dtype == null || key.dtype === val.dtype) {
61351 // a. If types match, return val tensor as is.
61352 return val;
61353 }
61354 try {
61355 // b. Attempt to convert to expected type.
61356 return cast$3(val, key.dtype);
61357 } catch (err) {
61358 // c. If conversion fails, return helpful error.
61359 throw new ValueError("The dtype of the feed (".concat(val.dtype, ") can not be cast to the dtype ") + "of the key '".concat(key.name, "' (").concat(key.dtype, ")."));
61360 }
61361 }
/**
 * FeedDict: a mapping from unique SymbolicTensors to the concrete values (and
 * optional masks) fed for them. Each value is a `Tensor`.
 */
var FeedDict = /*#__PURE__*/function () {
  /**
   * Constructor, optionally performing copy-construction.
   *
   * @param feeds An Array of `Feed`s ({key, value} pairs), or another
   *   `FeedDict` whose values, masks and names are copied. The copy is
   *   shallow: the Tensors themselves are shared, not cloned.
   */
  function FeedDict(feeds) {
    _classCallCheck(this, FeedDict);
    // SymbolicTensor id -> fed value.
    this.id2Value = {};
    // SymbolicTensor id -> mask, only for entries that were given a mask.
    this.id2Mask = {};
    // SymbolicTensor name -> id, so values can also be looked up by name.
    this.name2Id = {};
    if (feeds instanceof FeedDict) {
      for (var id in feeds.id2Value) {
        this.id2Value[id] = feeds.id2Value[id];
        if (id in feeds.id2Mask) {
          this.id2Mask[id] = feeds.id2Mask[id];
        }
      }
      // Copy the name index too. Without it the copy would report no
      // `names()` and reject lookups by name even though it holds the values.
      for (var name in feeds.name2Id) {
        this.name2Id[name] = feeds.name2Id[name];
      }
    } else {
      if (feeds == null) {
        return;
      }
      var _iterator = _createForOfIteratorHelper(feeds),
        _step;
      try {
        for (_iterator.s(); !(_step = _iterator.n()).done;) {
          var feed = _step.value;
          this.add(feed.key, feed.value);
        }
      } catch (err) {
        _iterator.e(err);
      } finally {
        _iterator.f();
      }
    }
  }
  _createClass(FeedDict, [{
    key: "add",
    /**
     * Add a key-value pair to the FeedDict.
     *
     * @param key The key of the feed.
     * @param value The value of the tensor feed; passed through
     *   assertFeedCompatibility, which may cast it to `key.dtype`.
     * @param mask The value of the mask feed (optional).
     * @returns This `FeedDict`.
     * @throws ValueError If the key `SymbolicTensor` already exists in the
     *   `FeedDict`.
     */
    value: function add(key, value, mask) {
      if (this.id2Value[key.id] == null) {
        this.id2Value[key.id] = assertFeedCompatibility(key, value);
        this.name2Id[key.name] = key.id;
        if (mask != null) {
          this.id2Mask[key.id] = mask;
        }
      } else {
        throw new ValueError("Duplicate key: name=".concat(key.name, ", id=").concat(key.id));
      }
      return this;
    }
  }, {
    key: "addFeed",
    /**
     * Add a Feed ({key, value}) to the FeedDict.
     *
     * @param feed The new `Feed` to add.
     * @returns This `FeedDict`, as documented (previously undefined).
     */
    value: function addFeed(feed) {
      return this.add(feed.key, feed.value);
    }
  }, {
    key: "hasKey",
    /**
     * Probe whether a key already has a (non-null) value in the FeedDict.
     *
     * @param key The SymbolicTensor to look up (matched by `id`).
     */
    value: function hasKey(key) {
      return this.id2Value[key.id] != null;
    }
  }, {
    key: "names",
    /** Get the names of all SymbolicTensors fed in this FeedDict. */
    value: function names() {
      return Object.keys(this.name2Id);
    }
  }, {
    key: "getValue",
    /**
     * Get the feed value for given key.
     *
     * @param key The SymbolicTensor, or its name (as a string), of which the
     *   value is sought.
     * @returns If `key` exists, the corresponding feed value.
     * @throws ValueError If `key` does not exist in this `FeedDict`.
     */
    value: function getValue(key) {
      if (key instanceof SymbolicTensor) {
        if (this.id2Value[key.id] == null) {
          throw new ValueError("Nonexistent key: ".concat(key.name));
        } else {
          return this.id2Value[key.id];
        }
      } else {
        var id = this.name2Id[key];
        if (id == null) {
          throw new ValueError("Feed dict has no SymbolicTensor name: ".concat(key));
        }
        return this.id2Value[id];
      }
    }
  }, {
    key: "getMask",
    /**
     * Get the feed mask for given key.
     *
     * @param key The SymbolicTensor, or its name (as a string), of which the
     *   mask is sought.
     * @returns If `key` exists, the corresponding feed mask (undefined when no
     *   mask was fed for it).
     * @throws ValueError If `key` does not exist in this `FeedDict`.
     */
    value: function getMask(key) {
      if (key instanceof SymbolicTensor) {
        if (this.id2Value[key.id] == null) {
          throw new ValueError("Nonexistent key: ".concat(key.name));
        } else {
          return this.id2Mask[key.id];
        }
      } else {
        var id = this.name2Id[key];
        if (id == null) {
          throw new ValueError("Feed dict has no SymbolicTensor name: ".concat(key));
        }
        return this.id2Mask[id];
      }
    }
  }, {
    key: "disposeMasks",
    /** Dispose all mask Tensors held by this object. */
    value: function disposeMasks() {
      if (this.id2Mask != null) {
        dispose(this.id2Mask);
      }
    }
  }]);
  return FeedDict;
}();
// LRU caches keyed by the "fetch names|fed names" signature of an execute()
// call. The first holds the topologically-sorted SymbolicTensors for that
// signature; the second holds the recipient-count map computed with it.
var cachedSorted = new LruCache();
var cachedRecipientCounts = new LruCache();
/**
 * Resizes both topological-sort caches to `maxEntries`, skipping any cache
 * that is not (yet) initialized.
 *
 * @param maxEntries New capacity for each cache.
 */
function updateCacheMaxEntries(maxEntries) {
  [cachedSorted, cachedRecipientCounts].forEach(function (cache) {
    if (cache != null) {
      cache.setMaxEntries(maxEntries);
    }
  });
}
/**
 * Execute one or more SymbolicTensors by using concrete feed values.
 *
 * A `SymbolicTensor` object is a node in a computation graph of TF.js
 * Layers, backed by a source layer and the input `SymbolicTensor`s of that
 * layer. This function topologically sorts everything upstream of `fetches`
 * (reusing a cached sort for the same fetches/feeds when available). It then
 * runs each source layer's `apply()` in that order, taking input values
 * either from `feedDict` or from outputs computed earlier in the same call.
 *
 * @param fetches The `SymbolicTensor` to execute, or an Array of them.
 * @param feedDict The feed values, which form the base case of the
 *   execution.
 * @param kwargs Optional keyword arguments forwarded to every layer's
 *   `apply()`. A truthy `kwargs.training` turns off early disposal of
 *   intermediate tensors.
 * @param probe Optional probe object (of interface `ExecutionProbe`). Its
 *   `maxNumTensors` / `minNumTensors` fields record `memory().numTensors`
 *   observed during execution, for testing memory footprint.
 * @returns The concrete value for `fetches`, or an Array of values in the
 *   same order when `fetches` is an Array.
 * @throws ValueError If a `SymbolicTensor` that must be read during
 *   execution has no value in `feedDict` (e.g. an `InputLayer` output that
 *   was not fed).
 */
function execute(fetches, feedDict, kwargs, probe) {
  var training = kwargs == null ? false : kwargs['training'];
  var arrayFetches = Array.isArray(fetches);
  var fetchArray = arrayFetches ? fetches : [fetches];
  var outputNames = fetchArray.map(function (t) {
    return t.name;
  });
  var feedNames = feedDict.names();
  // A fetch that is itself fed is answered straight from the feeds; every
  // other slot stays null until its source layer has been run below.
  var finalOutputs = outputNames.map(function (outputName) {
    return feedNames.indexOf(outputName) !== -1 ? feedDict.getValue(outputName) : null;
  });
  if (probe != null) {
    // For optional probing of memory footprint during execution.
    probe.maxNumTensors = -Infinity;
    probe.minNumTensors = Infinity;
  }
  // The cache key identifies the requested fetches and the set of fed names
  // (sorted, so feed order does not matter).
  var fetchAndFeedKey = outputNames.join(',') + '|' + feedNames.slice().sort().join(',');
  // Consult BOTH caches on every call and treat a miss in either as a miss.
  // They are independent LruCache instances. If only one is read (e.g.
  // skipping the counts cache in training mode), and assuming `get`
  // refreshes recency as an LRU normally does (LruCache is defined
  // elsewhere), they can evict different keys. A later inference call could
  // then find the sort but not the counts. Every count would then be
  // undefined (NaN after `--`), and disposal of intermediate tensors would
  // be silently skipped.
  var sorted = cachedSorted.get(fetchAndFeedKey);
  var cachedCounts = cachedRecipientCounts.get(fetchAndFeedKey);
  if (sorted == null || cachedCounts == null) {
    // Compute the topological sort for this combination and cache it.
    var out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
    sorted = out.sorted;
    cachedCounts = out.recipientCounts;
    cachedSorted.put(fetchAndFeedKey, sorted);
    cachedRecipientCounts.put(fetchAndFeedKey, cachedCounts);
  }
  // Private, mutable copy of the counts (they are decremented below). In
  // training mode it stays empty, so no intermediate tensor is disposed.
  var recipientCounts = {};
  if (!training) {
    Object.assign(recipientCounts, cachedCounts);
  }
  var internalFeedDict = new FeedDict(feedDict);
  // Start iterative execution on the topologically-sorted SymbolicTensors.
  for (var i = 0; i < sorted.length; ++i) {
    if (probe != null) {
      // For optional probing of memory usage during execution.
      var numTensors = memory().numTensors;
      if (numTensors > probe.maxNumTensors) {
        probe.maxNumTensors = numTensors;
      }
      if (numTensors < probe.minNumTensors) {
        probe.minNumTensors = numTensors;
      }
    }
    var symbolic = sorted[i];
    var srcLayer = symbolic.sourceLayer;
    if (srcLayer instanceof InputLayer) {
      // Nothing to compute: InputLayer outputs must come from the feeds.
      continue;
    }
    var inputValues = [];
    var inputMasks = [];
    var tensorsToDispose = [];
    var maskExists = false;
    var _iterator3 = _createForOfIteratorHelper(symbolic.inputs),
      _step3;
    try {
      for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
        var input = _step3.value;
        var value = internalFeedDict.getValue(input);
        var mask = internalFeedDict.getMask(input);
        inputValues.push(value);
        inputMasks.push(mask);
        if (mask != null) {
          maskExists = true;
        }
        if (!training) {
          recipientCounts[input.name]--;
          // Free the value once its last recipient has consumed it, unless
          // it was fed by the caller, is a requested output, is already
          // disposed, or comes from a stateful layer.
          if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) && outputNames.indexOf(input.name) === -1 && !value.isDisposed && input.sourceLayer.stateful !== true) {
            tensorsToDispose.push(value);
          }
        }
      }
    } catch (err) {
      _iterator3.e(err);
    } finally {
      _iterator3.f();
    }
    if (maskExists) {
      // NOTE(review): this writes into the caller-supplied `kwargs` object,
      // and the mask stays set for every later layer in this loop, including
      // layers whose inputs carry no mask. Kept unchanged; confirm intended.
      kwargs = kwargs || {};
      kwargs['mask'] = inputMasks[0];
    }
    var outputTensors = toList(srcLayer.apply(inputValues, kwargs));
    var outputMask = null;
    if (srcLayer.supportsMasking) {
      outputMask = srcLayer.computeMask(inputValues, inputMasks);
    }
    var layerOutputs = getNodeOutputs(symbolic);
    var outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
    for (var _i = 0; _i < outputSymbolicTensors.length; ++_i) {
      if (!internalFeedDict.hasKey(outputSymbolicTensors[_i])) {
        internalFeedDict.add(outputSymbolicTensors[_i], outputTensors[_i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
      }
      var index = outputNames.indexOf(outputSymbolicTensors[_i].name);
      if (index !== -1) {
        finalOutputs[index] = outputTensors[_i];
      }
    }
    if (!training) {
      // Clean up Tensors that are no longer needed.
      dispose(tensorsToDispose);
    }
  }
  // NOTE(cais): Unlike intermediate tensors, we don't discard mask
  // tensors as we go, because these tensors are sometimes passed over a
  // series of multiple layers, i.e., not obeying the immediate input
  // relations in the graph. If this becomes a memory-usage concern,
  // we can improve this in the future.
  internalFeedDict.disposeMasks();
  return arrayFetches ? finalOutputs : finalOutputs[0];
}
/**
 * Topologically sorts the SymbolicTensors required by an Array of fetches.
 *
 * Each fetch is sorted on its own with
 * getTopologicalSortAndRecipientCountsForOneFetch. The per-fetch orders are
 * concatenated with duplicates (by name) dropped, and the per-fetch
 * recipient sets are unioned before being turned into counts.
 *
 * @param fetches The fetches requested. Must be a non-empty Array.
 * @param feedDict The dictionary of fed values.
 * @returns sorted: Topologically-sorted Array of SymbolicTensors.
 *   recipientCounts: For each SymbolicTensor name, the number of distinct
 *   recipients consuming it.
 */
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
  assert$1(fetches != null && fetches.length > 0, function () {
    return "Expected at least one fetch, got none";
  });
  if (fetches.length === 1) {
    // Fast path: a single fetch needs no merging.
    var single = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
    return {
      sorted: single.sorted,
      recipientCounts: recipientMap2Counts(single.recipientMap)
    };
  }
  var mergedSorted = [];
  var mergedRecipients = {};
  var seenNames = new Set();
  for (var f = 0; f < fetches.length; ++f) {
    var partial = getTopologicalSortAndRecipientCountsForOneFetch(fetches[f], feedDict);
    // Append tensors not already emitted for an earlier fetch.
    var partialSorted = partial.sorted;
    for (var s = 0; s < partialSorted.length; ++s) {
      var symbolicTensor = partialSorted[s];
      if (!seenNames.has(symbolicTensor.name)) {
        mergedSorted.push(symbolicTensor);
        seenNames.add(symbolicTensor.name);
      }
    }
    // Union this fetch's recipient sets into the merged map.
    var partialRecipients = partial.recipientMap;
    for (var name in partialRecipients) {
      if (mergedRecipients[name] == null) {
        mergedRecipients[name] = new Set();
      }
      var target = mergedRecipients[name];
      partialRecipients[name].forEach(function (recipient) {
        return target.add(recipient);
      });
    }
  }
  return {
    sorted: mergedSorted,
    recipientCounts: recipientMap2Counts(mergedRecipients)
  };
}
/**
 * Converts a recipient map (name -> Set of recipient names) into a count map
 * (name -> number of recipients).
 *
 * @param recipientMap Object whose values are Sets.
 * @returns A new object with the same keys, each mapped to its Set's size.
 */
function recipientMap2Counts(recipientMap) {
  var counts = {};
  var name;
  for (name in recipientMap) {
    var recipients = recipientMap[name];
    counts[name] = recipients.size;
  }
  return counts;
}
/**
 * Topologically sorts the SymbolicTensors upstream of a single fetch.
 *
 * Uses an explicit stack instead of recursion. A tensor is appended to
 * `sorted` only once all of its inputs have been appended or were fed. Every
 * input edge encountered is recorded in the recipient map.
 *
 * @param fetch The single fetch requested.
 * @param feedDict The dictionary of fed values. Fed names count as already
 *   visited, so the walk stops at them.
 * @returns sorted: Topologically-sorted Array of SymbolicTensors (fed
 *   tensors are not included).
 *   recipientMap: For each input name, the Set of names of the tensors that
 *   consume it.
 */
function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
  var done = new Set();
  var order = [];
  var consumers = {};
  // Seed the visited set with the fed names so the walk never goes above
  // them; this matters when intermediate tensors of the graph are fed.
  var fedNames = feedDict.names();
  for (var f = 0; f < fedNames.length; ++f) {
    done.add(fedNames[f]);
  }
  var stack = [fetch];
  // Stack positions of tensors whose inputs have already been pushed.
  var expanded = [];
  while (stack.length > 0) {
    var topIndex = stack.length - 1;
    var current = stack[topIndex];
    if (done.has(current.name)) {
      stack.pop();
      continue;
    }
    var isExpanded = expanded[expanded.length - 1] === topIndex;
    if (current.inputs.length > 0 && !isExpanded) {
      // First visit of a tensor with inputs: remember its position, then
      // push its not-yet-visited inputs above it.
      expanded.push(topIndex);
      for (var k = 0; k < current.inputs.length; ++k) {
        var upstream = current.inputs[k];
        // Record the edge even when `upstream` was visited before.
        if (consumers[upstream.name] == null) {
          consumers[upstream.name] = new Set();
        }
        consumers[upstream.name].add(current.name);
        if (!done.has(upstream.name)) {
          stack.push(upstream);
        }
      }
      continue;
    }
    // A tensor without inputs, or one whose inputs are all handled: emit it.
    stack.pop();
    order.push(current);
    done.add(current.name);
    if (isExpanded) {
      expanded.pop();
    }
  }
  return {
    sorted: order,
    recipientMap: consumers
  };
}
/**
 * Get the symbolic output tensors of the node to which a given fetch belongs.
 *
 * If the fetch's source layer has exactly one inbound node, the layer's
 * `output` is returned. Otherwise the inbound nodes are scanned in order for
 * the first one whose `outputTensors` contains a tensor with the fetch's
 * `id`, and that node's outputs are returned via `getOutputAt`. The scan
 * stops at the first match; previously the `break` only left the inner loop
 * and every remaining node was still scanned.
 *
 * @param fetch The fetched symbolic tensor.
 * @returns The symbolic tensor(s) output by the node to which `fetch`
 *   belongs.
 */
function getNodeOutputs(fetch) {
  var sourceLayer = fetch.sourceLayer;
  var inboundNodes = sourceLayer.inboundNodes;
  if (inboundNodes.length === 1) {
    return sourceLayer.output;
  }
  var nodeIndex = null;
  for (var i = 0; i < inboundNodes.length && nodeIndex == null; ++i) {
    var outputTensors = inboundNodes[i].outputTensors;
    for (var j = 0; j < outputTensors.length; ++j) {
      if (outputTensors[j].id === fetch.id) {
        nodeIndex = i;
        break;
      }
    }
  }
  // NOTE(review): when no inbound node produced `fetch`, nodeIndex stays
  // null and is passed to getOutputAt unchanged, as before.
  return sourceLayer.getOutputAt(nodeIndex);
}
61867
61868 /**
61869 * @license
61870 * Copyright 2022 Google LLC. All Rights Reserved.
61871 * Licensed under the Apache License, Version 2.0 (the "License");
61872 * you may not use this file except in compliance with the License.
61873 * You may obtain a copy of the License at
61874 *
61875 * http://www.apache.org/licenses/LICENSE-2.0
61876 *
61877 * Unless required by applicable law or agreed to in writing, software
61878 * distributed under the License is distributed on an "AS IS" BASIS,
61879 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61880 * See the License for the specific language governing permissions and
61881 * limitations under the License.
61882 * =============================================================================
61883 */
var ENV$2 = env();
// Register the flag that bounds how many entries each layers topological-sort
// cache keeps. The default evaluator returns 100. `updateCacheMaxEntries` is
// passed as the third argument, presumably the callback run when the flag is
// set (registerFlag is defined elsewhere; confirm).
ENV$2.registerFlag(
  'TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES',
  function () {
    return 100;
  },
  updateCacheMaxEntries
);
61889
/**
 * Computes the L2 norms of `w` along `axis`; shared by the norm-based
 * constraints.
 *
 * The `true` passed to sum$3 presumably keeps the reduced dimension
 * (keepDims), so the result broadcasts against `w`.
 *
 * @param w Weight tensor.
 * @param axis Axis (or axes) to reduce over.
 * @returns sqrt(sum(w * w, axis)), computed inside `tidy`.
 */
function calcL2Norms(w, axis) {
  return tidy(function () {
    var squares = mul(w, w);
    var sumOfSquares = sum$3(squares, axis, true);
    return sqrt$2(sumOfSquares);
  });
}
/**
 * Base class for functions that impose constraints on weight values.
 *
 * Concrete constraints supply an `apply()` method. The base `getConfig()`
 * reports an empty configuration.
 *
 * @doc {
 *   heading: 'Constraints',
 *   subheading: 'Classes',
 *   namespace: 'constraints'
 * }
 */
var Constraint = /*#__PURE__*/function (_Base) {
  _inherits(Constraint, _Base);
  var _superCtor = _createSuper(Constraint);
  function Constraint() {
    _classCallCheck(this, Constraint);
    return _superCtor.apply(this, arguments);
  }
  // No configurable state at this level.
  function getConfig() {
    return {};
  }
  _createClass(Constraint, [{
    key: "getConfig",
    value: getConfig
  }]);
  return Constraint;
}(Serializable);
/**
 * MaxNorm constraint: caps the L2 norm of the weights along `axis` at
 * `maxValue`.
 *
 * Constructor config: `maxValue` (default 2) and `axis` (default 0). The
 * config object itself may be omitted, in which case both defaults apply.
 * Before this change, omitting it threw a TypeError on `args.maxValue`.
 */
var MaxNorm = /*#__PURE__*/function (_Constraint) {
  _inherits(MaxNorm, _Constraint);
  var _super2 = _createSuper(MaxNorm);
  function MaxNorm(args) {
    var _this;
    _classCallCheck(this, MaxNorm);
    _this = _super2.call(this);
    var config = args != null ? args : {};
    _this.defaultMaxValue = 2;
    _this.defaultAxis = 0;
    _this.maxValue = config.maxValue != null ? config.maxValue : _this.defaultMaxValue;
    _this.axis = config.axis != null ? config.axis : _this.defaultAxis;
    return _this;
  }
  _createClass(MaxNorm, [{
    key: "apply",
    /**
     * Rescales `w` so that each norm along `axis` becomes
     * clip(norm, 0, maxValue). epsilon() is added to the divisor to avoid
     * division by zero.
     */
    value: function apply(w) {
      var _this2 = this;
      return tidy(function () {
        var norms = calcL2Norms(w, _this2.axis);
        var desired = clipByValue$2(norms, 0, _this2.maxValue);
        return mul(w, div$1(desired, add$3(epsilon$1(), norms)));
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        maxValue: this.maxValue,
        axis: this.axis
      };
    }
  }]);
  return MaxNorm;
}(Constraint);
/** @nocollapse */
MaxNorm.className = 'MaxNorm';
registerClass(MaxNorm);
/**
 * UnitNorm constraint: rescales the weights so their L2 norm along `axis` is
 * (approximately) one.
 *
 * Constructor config: `axis` (default 0). The config object itself may be
 * omitted. Before this change, omitting it threw a TypeError on `args.axis`.
 */
var UnitNorm = /*#__PURE__*/function (_Constraint2) {
  _inherits(UnitNorm, _Constraint2);
  var _super3 = _createSuper(UnitNorm);
  function UnitNorm(args) {
    var _this3;
    _classCallCheck(this, UnitNorm);
    _this3 = _super3.call(this);
    var config = args != null ? args : {};
    _this3.defaultAxis = 0;
    _this3.axis = config.axis != null ? config.axis : _this3.defaultAxis;
    return _this3;
  }
  _createClass(UnitNorm, [{
    key: "apply",
    /**
     * Divides `w` by its L2 norms along `axis`. epsilon() is added to the
     * divisor to avoid division by zero.
     */
    value: function apply(w) {
      var _this4 = this;
      return tidy(function () {
        return div$1(w, add$3(epsilon$1(), calcL2Norms(w, _this4.axis)));
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        axis: this.axis
      };
    }
  }]);
  return UnitNorm;
}(Constraint);
/** @nocollapse */
UnitNorm.className = 'UnitNorm';
registerClass(UnitNorm);
/**
 * NonNeg constraint: forces every weight to be non-negative by applying
 * `relu$2` to the weights.
 */
var NonNeg = /*#__PURE__*/function (_Base) {
  _inherits(NonNeg, _Base);
  var _superCtor = _createSuper(NonNeg);
  function NonNeg() {
    _classCallCheck(this, NonNeg);
    return _superCtor.apply(this, arguments);
  }
  // Negative entries become 0; the rest pass through.
  function apply(w) {
    return relu$2(w);
  }
  _createClass(NonNeg, [{
    key: "apply",
    value: apply
  }]);
  return NonNeg;
}(Constraint);
/** @nocollapse */
NonNeg.className = 'NonNeg';
registerClass(NonNeg);
/**
 * MinMaxNorm constraint: pulls the L2 norm of the weights along `axis`
 * towards the interval [minValue, maxValue].
 *
 * Constructor config: `minValue` (default 0), `maxValue` (default 1),
 * `rate` (default 1) and `axis` (default 0). The config object itself may be
 * omitted. Before this change, omitting it threw a TypeError on
 * `args.minValue`.
 */
var MinMaxNorm = /*#__PURE__*/function (_Constraint4) {
  _inherits(MinMaxNorm, _Constraint4);
  var _super5 = _createSuper(MinMaxNorm);
  function MinMaxNorm(args) {
    var _this5;
    _classCallCheck(this, MinMaxNorm);
    _this5 = _super5.call(this);
    var config = args != null ? args : {};
    _this5.defaultMinValue = 0.0;
    _this5.defaultMaxValue = 1.0;
    _this5.defaultRate = 1.0;
    _this5.defaultAxis = 0;
    _this5.minValue = config.minValue != null ? config.minValue : _this5.defaultMinValue;
    _this5.maxValue = config.maxValue != null ? config.maxValue : _this5.defaultMaxValue;
    _this5.rate = config.rate != null ? config.rate : _this5.defaultRate;
    _this5.axis = config.axis != null ? config.axis : _this5.defaultAxis;
    return _this5;
  }
  _createClass(MinMaxNorm, [{
    key: "apply",
    /**
     * Rescales `w` so that each norm along `axis` becomes
     * rate * clip(norm, minValue, maxValue) + (1 - rate) * norm.
     * epsilon() is added to the divisor to avoid division by zero.
     */
    value: function apply(w) {
      var _this6 = this;
      return tidy(function () {
        var norms = calcL2Norms(w, _this6.axis);
        var desired = add$3(mul(_this6.rate, clipByValue$2(norms, _this6.minValue, _this6.maxValue)), mul(1.0 - _this6.rate, norms));
        return mul(w, div$1(desired, add$3(epsilon$1(), norms)));
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      return {
        minValue: this.minValue,
        maxValue: this.maxValue,
        rate: this.rate,
        axis: this.axis
      };
    }
  }]);
  return MinMaxNorm;
}(Constraint);
/** @nocollapse */
MinMaxNorm.className = 'MinMaxNorm';
registerClass(MinMaxNorm);
// Lookup from the camelCase identifiers accepted as strings (e.g. 'maxNorm')
// to the class names the constraints are registered under (e.g. 'MaxNorm').
var CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
  'maxNorm': 'MaxNorm',
  'minMaxNorm': 'MinMaxNorm',
  'nonNeg': 'NonNeg',
  'unitNorm': 'UnitNorm'
};
/**
 * Serializes a constraint via serializeKerasObject.
 *
 * @param constraint The constraint to serialize.
 * @returns Whatever serializeKerasObject produces for it.
 */
function serializeConstraint(constraint) {
  var serialized = serializeKerasObject(constraint);
  return serialized;
}
/**
 * Rebuilds a constraint from a serialized config with
 * deserializeKerasObject. Class names are resolved through
 * `SerializationMap.getMap().classNameMap` plus any `customObjects`.
 *
 * @param config Serialized constraint config.
 * @param customObjects Optional second argument: extra classes by name;
 *   `{}` when omitted or `undefined`.
 */
function deserializeConstraint(config) {
  var customObjects = {};
  if (arguments.length > 1 && arguments[1] !== undefined) {
    customObjects = arguments[1];
  }
  var classNameMap = SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, classNameMap, customObjects, 'constraint');
}
/**
 * Resolves a constraint identifier into a constraint.
 *
 * @param identifier One of:
 *   - null/undefined: null is returned;
 *   - a string: a camelCase alias listed in
 *     CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP (e.g. 'maxNorm') is mapped to
 *     its class name; any other string is used as the class name directly.
 *     The result is deserialized with an empty config;
 *   - a Constraint instance: returned as-is;
 *   - anything else: deserialized as a serialized config.
 * @returns The resolved constraint, or null.
 */
function getConstraint(identifier) {
  if (identifier == null) {
    return null;
  }
  if (typeof identifier === 'string') {
    // Own-property check: the `in` operator also matches inherited keys such
    // as 'toString' or 'constructor', which would substitute a Function for
    // the class name.
    var isAlias = Object.prototype.hasOwnProperty.call(CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP, identifier);
    var className = isAlias ? CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
    var config = {
      className: className,
      config: {}
    };
    return deserializeConstraint(config);
  }
  if (identifier instanceof Constraint) {
    return identifier;
  }
  return deserializeConstraint(identifier);
}
62084
62085 /**
62086 * @license
62087 * Copyright 2018 Google LLC
62088 *
62089 * Use of this source code is governed by an MIT-style
62090 * license that can be found in the LICENSE file or at
62091 * https://opensource.org/licenses/MIT.
62092 * =============================================================================
62093 */
/**
 * Creates a MaxNorm weight constraint.
 *
 * The weights incident to each hidden unit are constrained to have a norm no
 * greater than a desired value.
 *
 * References
 *   - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
 *     Srivastava, Hinton, et al.
 *     2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
 *
 * @doc {heading: 'Constraints',namespace: 'constraints'}
 */
function maxNorm(args) {
  var constraint = new MaxNorm(args);
  return constraint;
}
/**
 * Creates a UnitNorm constraint, which rescales the weights incident to each
 * hidden unit to unit norm.
 *
 * @doc {heading: 'Constraints', namespace: 'constraints'}
 */
function unitNorm(args) {
  var constraint = new UnitNorm(args);
  return constraint;
}
/**
 * Creates a NonNeg constraint, which forces every weight to be non-negative.
 *
 * @doc {heading: 'Constraints', namespace: 'constraints'}
 */
function nonNeg() {
  var constraint = new NonNeg();
  return constraint;
}
/**
 * Creates a MinMaxNorm constraint from `config`.
 *
 * @doc {heading: 'Constraints', namespace: 'constraints'}
 */
function minMaxNorm(config) {
  var constraint = new MinMaxNorm(config);
  return constraint;
}
62130
// Namespace object collecting the constraint factories. It has no prototype,
// so only these four names resolve on it.
var exports_constraints = Object.assign(Object.create(null), {
  maxNorm: maxNorm,
  minMaxNorm: minMaxNorm,
  nonNeg: nonNeg,
  unitNorm: unitNorm
});
62138
62139 /**
62140 * @license
62141 * Copyright 2018 Google LLC
62142 *
62143 * Use of this source code is governed by an MIT-style
62144 * license that can be found in the LICENSE file or at
62145 * https://opensource.org/licenses/MIT.
62146 * =============================================================================
62147 */
/**
 * Creates an initializer that fills tensors with 0.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function zeros$1() {
  var initializer = new Zeros();
  return initializer;
}
/**
 * Creates an initializer that fills tensors with 1.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function ones() {
  var initializer = new Ones();
  return initializer;
}
/**
 * Creates an initializer that fills tensors with a constant value,
 * configured by `args`.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function constant(args) {
  var initializer = new Constant(args);
  return initializer;
}
/**
 * Creates an initializer that draws random values from a uniform
 * distribution, spread evenly between the configured minval and maxval.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function randomUniform(args) {
  var initializer = new RandomUniform(args);
  return initializer;
}
/**
 * Creates an initializer that draws random values from a normal
 * distribution.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function randomNormal(args) {
  var initializer = new RandomNormal(args);
  return initializer;
}
/**
 * Creates an initializer that draws random values from a truncated normal
 * distribution.
 *
 * Like `RandomNormal`, except that values more than two standard deviations
 * from the mean are discarded and re-drawn. This is the recommended
 * initializer for neural network weights and filters.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function truncatedNormal(args) {
  var initializer = new TruncatedNormal(args);
  return initializer;
}
/**
 * Creates an initializer that produces the identity matrix. Intended for
 * square 2D matrices only.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function identity$2(args) {
  var initializer = new Identity(args);
  return initializer;
}
/**
 * Creates an initializer whose scale adapts to the shape of the weights.
 *
 * With distribution=NORMAL, samples come from a truncated normal
 * distribution centered on zero with `stddev = sqrt(scale / n)`. With
 * distribution=UNIFORM, samples come from a uniform distribution within
 * [-limit, limit], where `limit = sqrt(3 * scale / n)`. In both cases n is:
 *   - the number of input units in the weight tensor, if mode = FAN_IN;
 *   - the number of output units, if mode = FAN_OUT;
 *   - the average of the numbers of input and output units, if
 *     mode = FAN_AVG.
 *
 * @doc {heading: 'Initializers',namespace: 'initializers'}
 */
function varianceScaling(config) {
  var initializer = new VarianceScaling(config);
  return initializer;
}
/**
 * Creates a Glorot uniform initializer (also called Xavier uniform).
 *
 * Samples come from a uniform distribution within [-limit, limit], where
 * `limit = sqrt(6 / (fanIn + fanOut))`. Here fanIn and fanOut are the
 * numbers of input and output units in the weight tensor.
 *
 * Reference:
 *   Glorot & Bengio, AISTATS 2010
 *   http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function glorotUniform(args) {
  var initializer = new GlorotUniform(args);
  return initializer;
}
/**
 * Creates a Glorot normal initializer (also called Xavier normal).
 *
 * Samples come from a truncated normal distribution centered on 0 with
 * `stddev = sqrt(2 / (fanIn + fanOut))`. Here fanIn and fanOut are the
 * numbers of input and output units in the weight tensor.
 *
 * Reference:
 *   Glorot & Bengio, AISTATS 2010
 *   http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function glorotNormal(args) {
  var initializer = new GlorotNormal(args);
  return initializer;
}
/**
 * Creates a He normal initializer.
 *
 * Samples come from a truncated normal distribution centered on 0 with
 * `stddev = sqrt(2 / fanIn)`, where fanIn is the number of input units in
 * the weight tensor.
 *
 * Reference:
 *   He et al., http://arxiv.org/abs/1502.01852
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function heNormal(args) {
  var initializer = new HeNormal(args);
  return initializer;
}
/**
 * Creates a He uniform initializer.
 *
 * Samples come from a uniform distribution within [-limit, limit], where
 * `limit = sqrt(6 / fanIn)` and fanIn is the number of input units in the
 * weight tensor.
 *
 * Reference:
 *   He et al., http://arxiv.org/abs/1502.01852
 *
 * @doc {heading: 'Initializers',namespace: 'initializers'}
 */
function heUniform(args) {
  var initializer = new HeUniform(args);
  return initializer;
}
/**
 * Creates a LeCun normal initializer.
 *
 * Samples come from a truncated normal distribution centered on 0 with
 * `stddev = sqrt(1 / fanIn)`, where fanIn is the number of input units in
 * the weight tensor.
 *
 * References:
 *   [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
 *   [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
 *
 * @doc {heading: 'Initializers', namespace: 'initializers'}
 */
function leCunNormal(args) {
  var initializer = new LeCunNormal(args);
  return initializer;
}
62309 /**
62310 * LeCun uniform initializer.
62311 *
62312 * It draws samples from a uniform distribution in the interval
62313 * `[-limit, limit]` with `limit = sqrt(3 / fanIn)`,
62314 * where `fanIn` is the number of input units in the weight tensor.
62315 *
62316 * @doc {heading: 'Initializers', namespace: 'initializers'}
62317 */
function leCunUniform(args) {
  // Factory wrapper around the LeCunUniform initializer class.
  var leCunUniformInitializer = new LeCunUniform(args);
  return leCunUniformInitializer;
}
62321 /**
62322 * Initializer that generates a random orthogonal matrix.
62323 *
62324 * Reference:
62325 * [Saxe et al., http://arxiv.org/abs/1312.6120](http://arxiv.org/abs/1312.6120)
62326 *
62327 * @doc {heading: 'Initializers', namespace: 'initializers'}
62328 */
function orthogonal(args) {
  // Factory wrapper around the Orthogonal initializer class.
  var orthogonalInitializer = new Orthogonal(args);
  return orthogonalInitializer;
}
62332
/**
 * Namespace object collecting the initializer factory functions under their
 * public names. It has a null prototype, so only these keys are visible (no
 * inherited `Object.prototype` members), and keys enumerate in the order they
 * are assigned below.
 */
var exports_initializers = Object.create(null);
exports_initializers.constant = constant;
exports_initializers.glorotNormal = glorotNormal;
exports_initializers.glorotUniform = glorotUniform;
exports_initializers.heNormal = heNormal;
exports_initializers.heUniform = heUniform;
exports_initializers.identity = identity$2;
exports_initializers.leCunNormal = leCunNormal;
exports_initializers.leCunUniform = leCunUniform;
exports_initializers.ones = ones;
exports_initializers.orthogonal = orthogonal;
exports_initializers.randomNormal = randomNormal;
exports_initializers.randomUniform = randomUniform;
exports_initializers.truncatedNormal = truncatedNormal;
exports_initializers.varianceScaling = varianceScaling;
exports_initializers.zeros = zeros$1;
62351
62352 /**
62353 * Turn any Scalar values in a Logs object into actual number values.
62354 *
62355 * @param logs The `Logs` object to be resolved in place.
62356 */
function resolveScalarsInLogs(_x) {
  // `_resolveScalarsInLogs` swaps itself for the real async implementation on
  // first use, so it is looked up afresh on every call instead of being
  // captured once. `this` and the full argument list are forwarded unchanged.
  var implementation = _resolveScalarsInLogs;
  return implementation.apply(this, arguments);
}
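/**
 * Dispose all Tensors in an UnresolvedLogs object.
 *
 * Note: despite its position, this comment describes `disposeTensorsInLogs`,
 * which is defined immediately after the `_resolveScalarsInLogs` helper below.
 *
 * @param logs An `UnresolvedLogs` object potentially containing `tf.Tensor`s in
 * places where the values can be `tf.Tensor` or `number`.
 */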
/**
 * Lazily-initialised async implementation behind `resolveScalarsInLogs`.
 *
 * On its first call this function overwrites the `_resolveScalarsInLogs`
 * binding with the `_asyncToGenerator`-wrapped state machine below, then
 * delegates to it. Later calls go straight to the wrapped version.
 *
 * Behaviour of the wrapped implementation, for a `logs` object:
 *  - `null`/`undefined` logs: returns immediately.
 *  - Every value that is not a `number` is treated as a scalar tensor: its
 *    `data()` promise is collected, and after all of them resolve the entry
 *    is replaced in place by element `[0]` of the resolved data.
 *  - The original tensors are then released via `dispose(...)`.
 */
function _resolveScalarsInLogs() {
  _resolveScalarsInLogs = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(logs) {
    var promises, keys, scalarsToDispose, key, value, valueScalar, values, i;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          // Guard: nothing to resolve for null/undefined logs.
          if (!(logs == null)) {
            _context.next = 2;
            break;
          }
          return _context.abrupt("return");
        case 2:
          // Collect the non-number entries, remembering which key each
          // pending `data()` promise belongs to.
          promises = [];
          keys = [];
          scalarsToDispose = [];
          for (key in logs) {
            value = logs[key];
            if (typeof value !== 'number') {
              valueScalar = value;
              promises.push(valueScalar.data());
              keys.push(key);
              scalarsToDispose.push(valueScalar);
            }
          }
          // Skip the await entirely when every entry was already a number.
          if (!(promises.length > 0)) {
            _context.next = 12;
            break;
          }
          _context.next = 9;
          return Promise.all(promises);
        case 9:
          // Resumed after `await Promise.all(promises)`: write the first
          // element of each resolved buffer back under its original key.
          values = _context.sent;
          for (i = 0; i < values.length; ++i) {
            logs[keys[i]] = values[i][0];
          }
          // Dispose the original scalar tensors.
          dispose(scalarsToDispose);
        case 12:
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _resolveScalarsInLogs.apply(this, arguments);
}
/**
 * Release every tensor held in a logs dictionary.
 *
 * Entries whose value is a plain `number` are left alone. Any other value is
 * treated as a tensor and has its `dispose()` method called. A `null` or
 * `undefined` logs object is a no-op.
 *
 * @param logs Logs whose values are either numbers or disposable tensors.
 */
function disposeTensorsInLogs(logs) {
  if (logs == null) {
    return;
  }
  for (var name in logs) {
    var entry = logs[name];
    if (typeof entry === 'number') {
      continue;
    }
    entry.dispose();
  }
}
62422
/**
 * Verbosity logging level when fitting a model.
 *
 * Bidirectional enum (TypeScript-enum layout): each name maps to its numeric
 * level and each numeric level maps back to its name.
 */
var ModelLoggingVerbosity = {};
ModelLoggingVerbosity.SILENT = 0;
ModelLoggingVerbosity[0] = "SILENT";
ModelLoggingVerbosity.VERBOSE = 1;
ModelLoggingVerbosity[1] = "VERBOSE";
/** How often to yield to the main thread when training (in ms). */
var DEFAULT_YIELD_EVERY_MS = 125;
/**
 * Abstract base class used to build new callbacks.
 *
 * The `logs` dictionary that callback methods take as argument will contain
 * keys for quantities relevant to the current batch or epoch.
 *
 * Currently, the `.fit()` method of the `Sequential` model class
 * will include the following quantities in the `logs` that
 * it passes to its callbacks:
 *
 * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`
 * (if validation is enabled in `fit`), and `valAcc` (if validation and
 * accuracy monitoring are enabled).
 * onBatchBegin: Logs include `size`, the number of samples in the current
 * batch.
 * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring
 * is enabled).
 *
 * In this base class every lifecycle hook (`onEpochBegin`, `onEpochEnd`,
 * `onBatchBegin`, `onBatchEnd`, `onTrainBegin`, `onTrainEnd`) is a transpiled
 * `async` function with an empty body. `setModel` is a synchronous no-op.
 * Subclasses override the hooks they need.
 *
 * Transpiled output: the IIFE below is the ES5 form of a `class`. Each hook is
 * a regenerator state machine wrapped by `_asyncToGenerator`.
 */
var BaseCallback = /*#__PURE__*/function () {
  /** Creates a callback with `validationData` set to null. */
  function BaseCallback() {
    _classCallCheck(this, BaseCallback);
    // TODO(michaelterry): This type is a best guess.
    this.validationData = null;
  }
  _createClass(BaseCallback, [{
    key: "setParams",
    // Stores the given training parameters object, as-is, on `this.params`.
    value: function setParams(params) {
      this.params = params;
    }
  }, {
    key: "onEpochBegin",
    // No-op hook for the start of an epoch (async; resolves with no value).
    value: function () {
      var _onEpochBegin = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(epoch, logs) {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
            case "end":
              return _context.stop();
          }
        }, _callee);
      }));
      // The named wrapper keeps the public arity (2) of the original method.
      function onEpochBegin(_x, _x2) {
        return _onEpochBegin.apply(this, arguments);
      }
      return onEpochBegin;
    }()
  }, {
    key: "onEpochEnd",
    // No-op hook for the end of an epoch.
    value: function () {
      var _onEpochEnd = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(epoch, logs) {
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
            case "end":
              return _context2.stop();
          }
        }, _callee2);
      }));
      function onEpochEnd(_x3, _x4) {
        return _onEpochEnd.apply(this, arguments);
      }
      return onEpochEnd;
    }()
  }, {
    key: "onBatchBegin",
    // No-op hook called before a batch is processed.
    value: function () {
      var _onBatchBegin = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(batch, logs) {
        return _regeneratorRuntime().wrap(function _callee3$(_context3) {
          while (1) switch (_context3.prev = _context3.next) {
            case 0:
            case "end":
              return _context3.stop();
          }
        }, _callee3);
      }));
      function onBatchBegin(_x5, _x6) {
        return _onBatchBegin.apply(this, arguments);
      }
      return onBatchBegin;
    }()
  }, {
    key: "onBatchEnd",
    // No-op hook called after a batch is processed.
    value: function () {
      var _onBatchEnd = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(batch, logs) {
        return _regeneratorRuntime().wrap(function _callee4$(_context4) {
          while (1) switch (_context4.prev = _context4.next) {
            case 0:
            case "end":
              return _context4.stop();
          }
        }, _callee4);
      }));
      function onBatchEnd(_x7, _x8) {
        return _onBatchEnd.apply(this, arguments);
      }
      return onBatchEnd;
    }()
  }, {
    key: "onTrainBegin",
    // No-op hook called once at the start of training.
    value: function () {
      var _onTrainBegin = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(logs) {
        return _regeneratorRuntime().wrap(function _callee5$(_context5) {
          while (1) switch (_context5.prev = _context5.next) {
            case 0:
            case "end":
              return _context5.stop();
          }
        }, _callee5);
      }));
      function onTrainBegin(_x9) {
        return _onTrainBegin.apply(this, arguments);
      }
      return onTrainBegin;
    }()
  }, {
    key: "onTrainEnd",
    // No-op hook called once at the end of training.
    value: function () {
      var _onTrainEnd = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6(logs) {
        return _regeneratorRuntime().wrap(function _callee6$(_context6) {
          while (1) switch (_context6.prev = _context6.next) {
            case 0:
            case "end":
              return _context6.stop();
          }
        }, _callee6);
      }));
      function onTrainEnd(_x10) {
        return _onTrainEnd.apply(this, arguments);
      }
      return onTrainEnd;
    }() // LayersModel needs to call Callback.setModel(), but cannot actually depend
    // on Callback because that creates a cyclic dependency. Providing this no-op
    // method on BaseCallback breaks the cycle: this way LayersModel can depend on
    // BaseCallback but not on Callback. The argument is typed as `Container`
    // (the superclass of LayersModel) to avoid recapitulating the cycle. Callback
    // overrides this method and enforces that the argument is really a
    // LayersModel.
  }, {
    key: "setModel",
    value: function setModel(model) {
      // Do nothing. Use Callback instead of BaseCallback to track the model.
    }
  }]);
  return BaseCallback;
}();
/**
 * Container abstracting a list of callbacks.
 *
 * Each hook method fans the call out to every contained callback in list
 * order. The async hooks await each callback before invoking the next one
 * (sequential, not parallel), and they pass the same `logs` object to all of
 * them, so a mutation made by one callback is visible to later callbacks.
 */
var CallbackList = /*#__PURE__*/function () {
  // TODO(cais): When the need arises, uncomment the following lines and
  // implement the queue for time values.
  // private deltaTBatch: number;
  // private deltaTsBatchBegin: Array<number>;
  // private deltaTsBatchEnd: Array<number>;
  /**
   * Constructor of CallbackList.
   * @param callbacks Array of `Callback` instances. `null`/`undefined` is
   *   replaced by an empty array. The array is stored by reference, not copied.
   * @param queueLength Queue length for keeping running statistics over
   * callback execution time. Defaults to 10 when omitted or `undefined`.
   * Currently it is only stored (see the TODO below).
   */
  function CallbackList(callbacks) {
    // Transpiled default parameter: `queueLength = 10`.
    var queueLength = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 10;
    _classCallCheck(this, CallbackList);
    // TODO(cais): Make use of queueLength when implementing the queue for time
    // values.
    if (callbacks == null) {
      callbacks = [];
    }
    this.callbacks = callbacks;
    this.queueLength = queueLength;
  }
  _createClass(CallbackList, [{
    key: "append",
    // Adds a callback to the end of the list.
    value: function append(callback) {
      this.callbacks.push(callback);
    }
  }, {
    key: "setParams",
    // Synchronously forwards `params` to every callback's `setParams`.
    // (Transpiled `for...of`; `_iterator.e` records an error and `_iterator.f`
    // finalises the iteration.)
    value: function setParams(params) {
      var _iterator = _createForOfIteratorHelper(this.callbacks),
        _step;
      try {
        for (_iterator.s(); !(_step = _iterator.n()).done;) {
          var callback = _step.value;
          callback.setParams(params);
        }
      } catch (err) {
        _iterator.e(err);
      } finally {
        _iterator.f();
      }
    }
  }, {
    key: "setModel",
    // Synchronously forwards `model` to every callback's `setModel`.
    value: function setModel(model) {
      var _iterator2 = _createForOfIteratorHelper(this.callbacks),
        _step2;
      try {
        for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
          var callback = _step2.value;
          callback.setModel(model);
        }
      } catch (err) {
        _iterator2.e(err);
      } finally {
        _iterator2.f();
      }
    }
    /**
     * Called at the start of an epoch.
     * @param epoch Index of epoch.
     * @param logs Dictionary of logs.
     */
  }, {
    key: "onEpochBegin",
    value: function () {
      var _onEpochBegin2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee7(epoch, logs) {
        var _iterator3, _step3, callback;
        return _regeneratorRuntime().wrap(function _callee7$(_context7) {
          // Regenerator state machine for:
          //   logs ??= {}; for (const cb of this.callbacks) await cb.onEpochBegin(epoch, logs);
          // case 0: default logs, create iterator, enter the try region (prev = 2).
          // case 4: loop head; when done jump to 10, else call the hook and
          //         await its result (execution resumes at case 8).
          // case 8: jump back to the loop head.
          // case 10: normal exit from the try region, go to the finally block.
          // case 12: catch block; hands the error to `_iterator3.e`.
          // case 15: finally block; `_iterator3.f()` finalises the iteration.
          while (1) switch (_context7.prev = _context7.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              _iterator3 = _createForOfIteratorHelper(this.callbacks);
              _context7.prev = 2;
              _iterator3.s();
            case 4:
              if ((_step3 = _iterator3.n()).done) {
                _context7.next = 10;
                break;
              }
              callback = _step3.value;
              _context7.next = 8;
              return callback.onEpochBegin(epoch, logs);
            case 8:
              _context7.next = 4;
              break;
            case 10:
              _context7.next = 15;
              break;
            case 12:
              _context7.prev = 12;
              _context7.t0 = _context7["catch"](2);
              _iterator3.e(_context7.t0);
            case 15:
              _context7.prev = 15;
              _iterator3.f();
              return _context7.finish(15);
            case 18:
            case "end":
              return _context7.stop();
          }
          // The trailing table is regenerator's try-location list, presumably
          // [tryLoc, catchLoc, finallyLoc, afterLoc], matching the cases above.
        }, _callee7, this, [[2, 12, 15, 18]]);
      }));
      function onEpochBegin(_x11, _x12) {
        return _onEpochBegin2.apply(this, arguments);
      }
      return onEpochBegin;
    }()
    /**
     * Called at the end of an epoch.
     * @param epoch Index of epoch.
     * @param logs Dictionary of logs.
     */
  }, {
    key: "onEpochEnd",
    value: function () {
      var _onEpochEnd2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee8(epoch, logs) {
        var _iterator4, _step4, callback;
        return _regeneratorRuntime().wrap(function _callee8$(_context8) {
          // Same sequential-await loop as onEpochBegin, calling onEpochEnd.
          while (1) switch (_context8.prev = _context8.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              _iterator4 = _createForOfIteratorHelper(this.callbacks);
              _context8.prev = 2;
              _iterator4.s();
            case 4:
              if ((_step4 = _iterator4.n()).done) {
                _context8.next = 10;
                break;
              }
              callback = _step4.value;
              _context8.next = 8;
              return callback.onEpochEnd(epoch, logs);
            case 8:
              _context8.next = 4;
              break;
            case 10:
              _context8.next = 15;
              break;
            case 12:
              _context8.prev = 12;
              _context8.t0 = _context8["catch"](2);
              _iterator4.e(_context8.t0);
            case 15:
              _context8.prev = 15;
              _iterator4.f();
              return _context8.finish(15);
            case 18:
            case "end":
              return _context8.stop();
          }
        }, _callee8, this, [[2, 12, 15, 18]]);
      }));
      function onEpochEnd(_x13, _x14) {
        return _onEpochEnd2.apply(this, arguments);
      }
      return onEpochEnd;
    }()
    /**
     * Called right before processing a batch.
     * @param batch Index of batch within the current epoch.
     * @param logs Dictionary of logs.
     */
  }, {
    key: "onBatchBegin",
    value: function () {
      var _onBatchBegin2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee9(batch, logs) {
        var _iterator5, _step5, callback;
        return _regeneratorRuntime().wrap(function _callee9$(_context9) {
          // Same sequential-await loop as onEpochBegin, calling onBatchBegin.
          while (1) switch (_context9.prev = _context9.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              _iterator5 = _createForOfIteratorHelper(this.callbacks);
              _context9.prev = 2;
              _iterator5.s();
            case 4:
              if ((_step5 = _iterator5.n()).done) {
                _context9.next = 10;
                break;
              }
              callback = _step5.value;
              _context9.next = 8;
              return callback.onBatchBegin(batch, logs);
            case 8:
              _context9.next = 4;
              break;
            case 10:
              _context9.next = 15;
              break;
            case 12:
              _context9.prev = 12;
              _context9.t0 = _context9["catch"](2);
              _iterator5.e(_context9.t0);
            case 15:
              _context9.prev = 15;
              _iterator5.f();
              return _context9.finish(15);
            case 18:
            case "end":
              return _context9.stop();
          }
        }, _callee9, this, [[2, 12, 15, 18]]);
      }));
      function onBatchBegin(_x15, _x16) {
        return _onBatchBegin2.apply(this, arguments);
      }
      return onBatchBegin;
    }()
    /**
     * Called at the end of a batch.
     * @param batch Index of batch within the current epoch.
     * @param logs Dictionary of logs.
     */
  }, {
    key: "onBatchEnd",
    value: function () {
      var _onBatchEnd2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee10(batch, logs) {
        var _iterator6, _step6, callback;
        return _regeneratorRuntime().wrap(function _callee10$(_context10) {
          // Same sequential-await loop as onEpochBegin, calling onBatchEnd.
          while (1) switch (_context10.prev = _context10.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              _iterator6 = _createForOfIteratorHelper(this.callbacks);
              _context10.prev = 2;
              _iterator6.s();
            case 4:
              if ((_step6 = _iterator6.n()).done) {
                _context10.next = 10;
                break;
              }
              callback = _step6.value;
              _context10.next = 8;
              return callback.onBatchEnd(batch, logs);
            case 8:
              _context10.next = 4;
              break;
            case 10:
              _context10.next = 15;
              break;
            case 12:
              _context10.prev = 12;
              _context10.t0 = _context10["catch"](2);
              _iterator6.e(_context10.t0);
            case 15:
              _context10.prev = 15;
              _iterator6.f();
              return _context10.finish(15);
            case 18:
            case "end":
              return _context10.stop();
          }
        }, _callee10, this, [[2, 12, 15, 18]]);
      }));
      function onBatchEnd(_x17, _x18) {
        return _onBatchEnd2.apply(this, arguments);
      }
      return onBatchEnd;
    }()
    /**
     * Called at the beginning of training.
     * @param logs Dictionary of logs.
     */
  }, {
    key: "onTrainBegin",
    value: function () {
      var _onTrainBegin2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee11(logs) {
        var _iterator7, _step7, callback;
        return _regeneratorRuntime().wrap(function _callee11$(_context11) {
          // Same sequential-await loop as onEpochBegin, calling onTrainBegin.
          while (1) switch (_context11.prev = _context11.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              _iterator7 = _createForOfIteratorHelper(this.callbacks);
              _context11.prev = 2;
              _iterator7.s();
            case 4:
              if ((_step7 = _iterator7.n()).done) {
                _context11.next = 10;
                break;
              }
              callback = _step7.value;
              _context11.next = 8;
              return callback.onTrainBegin(logs);
            case 8:
              _context11.next = 4;
              break;
            case 10:
              _context11.next = 15;
              break;
            case 12:
              _context11.prev = 12;
              _context11.t0 = _context11["catch"](2);
              _iterator7.e(_context11.t0);
            case 15:
              _context11.prev = 15;
              _iterator7.f();
              return _context11.finish(15);
            case 18:
            case "end":
              return _context11.stop();
          }
        }, _callee11, this, [[2, 12, 15, 18]]);
      }));
      function onTrainBegin(_x19) {
        return _onTrainBegin2.apply(this, arguments);
      }
      return onTrainBegin;
    }()
    /**
     * Called at the end of training.
     * @param logs Dictionary of logs.
     */
  }, {
    key: "onTrainEnd",
    value: function () {
      var _onTrainEnd2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee12(logs) {
        var _iterator8, _step8, callback;
        return _regeneratorRuntime().wrap(function _callee12$(_context12) {
          // Same sequential-await loop as onEpochBegin, calling onTrainEnd.
          while (1) switch (_context12.prev = _context12.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              _iterator8 = _createForOfIteratorHelper(this.callbacks);
              _context12.prev = 2;
              _iterator8.s();
            case 4:
              if ((_step8 = _iterator8.n()).done) {
                _context12.next = 10;
                break;
              }
              callback = _step8.value;
              _context12.next = 8;
              return callback.onTrainEnd(logs);
            case 8:
              _context12.next = 4;
              break;
            case 10:
              _context12.next = 15;
              break;
            case 12:
              _context12.prev = 12;
              _context12.t0 = _context12["catch"](2);
              _iterator8.e(_context12.t0);
            case 15:
              _context12.prev = 15;
              _iterator8.f();
              return _context12.finish(15);
            case 18:
            case "end":
              return _context12.stop();
          }
        }, _callee12, this, [[2, 12, 15, 18]]);
      }));
      function onTrainEnd(_x20) {
        return _onTrainEnd2.apply(this, arguments);
      }
      return onTrainEnd;
    }()
  }]);
  return CallbackList;
}();
/**
 * Callback that accumulates epoch averages of metrics.
 *
 * This callback is automatically applied to every LayersModel.
 *
 * State:
 *  - `seen`: number of samples seen this epoch, the sum of `logs.size` over
 *    batches (a missing `size` counts as 0).
 *  - `totals`: per-key running sum of `value * batchSize`. It holds a plain
 *    number for numeric log values and a tensor for tensor log values.
 * Both are reset in `onEpochBegin`. `onEpochEnd` writes `totals / seen` into
 * `logs` for each key in `this.params['metrics']`.
 */
var BaseLogger = /*#__PURE__*/function (_BaseCallback) {
  _inherits(BaseLogger, _BaseCallback);
  var _super = _createSuper(BaseLogger);
  function BaseLogger() {
    _classCallCheck(this, BaseLogger);
    return _super.call(this);
  }
  _createClass(BaseLogger, [{
    key: "onEpochBegin",
    // Resets the per-epoch accumulators.
    value: function () {
      var _onEpochBegin3 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee13(epoch) {
        return _regeneratorRuntime().wrap(function _callee13$(_context13) {
          while (1) switch (_context13.prev = _context13.next) {
            case 0:
              this.seen = 0;
              this.totals = {};
            case 2:
            case "end":
              return _context13.stop();
          }
        }, _callee13, this);
      }));
      function onEpochBegin(_x21) {
        return _onEpochBegin3.apply(this, arguments);
      }
      return onEpochBegin;
    }()
  }, {
    key: "onBatchEnd",
    // Adds `value * logs.size` to the running total of every key in `logs`.
    value: function () {
      var _onBatchEnd3 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee14(batch, logs) {
        var _this = this;
        var batchSize, _loop, key;
        return _regeneratorRuntime().wrap(function _callee14$(_context15) {
          while (1) switch (_context15.prev = _context15.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              batchSize = logs['size'] == null ? 0 : logs['size'];
              this.seen += batchSize;
              // Per-key loop body, compiled as a nested generator so that each
              // iteration gets its own `value`/`oldTotalsToDispose`/`total`
              // bindings (closure-captured by the `tidy` callback).
              _loop = /*#__PURE__*/_regeneratorRuntime().mark(function _loop(key) {
                var value, oldTotalsToDispose, total;
                return _regeneratorRuntime().wrap(function _loop$(_context14) {
                  while (1) switch (_context14.prev = _context14.next) {
                    case 0:
                      value = logs[key];
                      if (typeof value === 'number') {
                        // Numeric log: accumulate as a plain JS number.
                        if (!_this.totals.hasOwnProperty(key)) {
                          _this.totals[key] = 0;
                        }
                        _this.totals[key] = _this.totals[key] + value * batchSize;
                      } else {
                        // Tensor log: compute the new total inside `tidy`, then
                        // dispose the previous total tensor (if any) afterwards.
                        if (key in _this.totals) {
                          oldTotalsToDispose = _this.totals[key];
                        } else {
                          _this.totals[key] = 0;
                        }
                        total = tidy(function () {
                          return add$3(_this.totals[key], mul(value, batchSize));
                        });
                        _this.totals[key] = total;
                        if (oldTotalsToDispose != null) {
                          oldTotalsToDispose.dispose();
                        }
                      }
                    case 2:
                    case "end":
                      return _context14.stop();
                  }
                }, _loop);
              });
              // Transpiled `for (key in logs)`: regenerator's `keys` helper
              // produces an iterator over the keys of `logs`.
              _context15.t0 = _regeneratorRuntime().keys(logs);
            case 5:
              if ((_context15.t1 = _context15.t0()).done) {
                _context15.next = 10;
                break;
              }
              key = _context15.t1.value;
              return _context15.delegateYield(_loop(key), "t2", 8);
            case 8:
              _context15.next = 5;
              break;
            case 10:
            case "end":
              return _context15.stop();
          }
        }, _callee14, this);
      }));
      function onBatchEnd(_x22, _x23) {
        return _onBatchEnd3.apply(this, arguments);
      }
      return onBatchEnd;
    }()
  }, {
    key: "onEpochEnd",
    // For each metric name in `this.params['metrics']` that has a running
    // total, writes `total / seen` into `logs[key]`. Metrics without a total
    // are skipped. Nothing happens when `logs` is null/undefined.
    // NOTE(review): if no batch reported a `size`, `seen` is 0 and the average
    // is a division by zero — confirm callers always supply `size`.
    value: function () {
      var _onEpochEnd3 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee15(epoch, logs) {
        var _this2 = this;
        var _iterator9, _step9, _loop2, _ret;
        return _regeneratorRuntime().wrap(function _callee15$(_context17) {
          while (1) switch (_context17.prev = _context17.next) {
            case 0:
              if (!(logs != null)) {
                _context17.next = 20;
                break;
              }
              _iterator9 = _createForOfIteratorHelper(this.params['metrics']);
              _context17.prev = 2;
              // Loop body as a nested generator. Returning "continue" signals
              // the outer loop to skip to the next metric.
              _loop2 = /*#__PURE__*/_regeneratorRuntime().mark(function _loop2() {
                var key;
                return _regeneratorRuntime().wrap(function _loop2$(_context16) {
                  while (1) switch (_context16.prev = _context16.next) {
                    case 0:
                      key = _step9.value;
                      if (!(_this2.totals[key] == null)) {
                        _context16.next = 3;
                        break;
                      }
                      return _context16.abrupt("return", "continue");
                    case 3:
                      if (typeof _this2.totals[key] === 'number') {
                        logs[key] = _this2.totals[key] / _this2.seen;
                      } else {
                        // Tensor total: compute the average inside `tidy`,
                        // dispose the accumulated total, and `keep` the result
                        // so it outlives the tidy scope.
                        tidy(function () {
                          var log = mul(div$1(1, _this2.seen), _this2.totals[key]);
                          logs[key] = log;
                          _this2.totals[key].dispose();
                          keep(logs[key]);
                        });
                      }
                    case 4:
                    case "end":
                      return _context16.stop();
                  }
                }, _loop2);
              });
              _iterator9.s();
            case 5:
              if ((_step9 = _iterator9.n()).done) {
                _context17.next = 12;
                break;
              }
              return _context17.delegateYield(_loop2(), "t0", 7);
            case 7:
              _ret = _context17.t0;
              if (!(_ret === "continue")) {
                _context17.next = 10;
                break;
              }
              return _context17.abrupt("continue", 10);
            case 10:
              _context17.next = 5;
              break;
            case 12:
              _context17.next = 17;
              break;
            case 14:
              // Catch block: hands the error to the iterator helper.
              _context17.prev = 14;
              _context17.t1 = _context17["catch"](2);
              _iterator9.e(_context17.t1);
            case 17:
              // Finally block: finalise the metrics iteration.
              _context17.prev = 17;
              _iterator9.f();
              return _context17.finish(17);
            case 20:
            case "end":
              return _context17.stop();
          }
        }, _callee15, this, [[2, 14, 17, 20]]);
      }));
      function onEpochEnd(_x24, _x25) {
        return _onEpochEnd3.apply(this, arguments);
      }
      return onEpochEnd;
    }()
  }]);
  return BaseLogger;
}(BaseCallback);
/**
 * Callback that records events into a `History` object. This callback is
 * automatically applied to every TF.js Layers model. The `History` object
 * gets returned by the `fit` method of models.
 *
 * State (initialised in `onTrainBegin`):
 *  - `epoch`: array of epoch indices, in the order `onEpochEnd` saw them.
 *  - `history`: map from log key to the array of per-epoch values. Values are
 *    stored as given, which may be numbers or tensors, until `syncData`
 *    replaces the tensors with numbers.
 */
var History = /*#__PURE__*/function (_BaseCallback2) {
  _inherits(History, _BaseCallback2);
  var _super2 = _createSuper(History);
  function History() {
    _classCallCheck(this, History);
    return _super2.apply(this, arguments);
  }
  _createClass(History, [{
    key: "onTrainBegin",
    // Starts a fresh record for this training run.
    value: function () {
      var _onTrainBegin3 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee16(logs) {
        return _regeneratorRuntime().wrap(function _callee16$(_context18) {
          while (1) switch (_context18.prev = _context18.next) {
            case 0:
              this.epoch = [];
              this.history = {};
            case 2:
            case "end":
              return _context18.stop();
          }
        }, _callee16, this);
      }));
      function onTrainBegin(_x26) {
        return _onTrainBegin3.apply(this, arguments);
      }
      return onTrainBegin;
    }()
  }, {
    key: "onEpochEnd",
    // Appends the epoch index and every log value to the history arrays.
    // NOTE(review): relies on `onTrainBegin` having run first; otherwise
    // `this.epoch` is undefined and `push` throws.
    value: function () {
      var _onEpochEnd4 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee17(epoch, logs) {
        var key;
        return _regeneratorRuntime().wrap(function _callee17$(_context19) {
          while (1) switch (_context19.prev = _context19.next) {
            case 0:
              if (logs == null) {
                logs = {};
              }
              this.epoch.push(epoch);
              for (key in logs) {
                if (this.history[key] == null) {
                  this.history[key] = [];
                }
                this.history[key].push(logs[key]);
              }
            case 3:
            case "end":
              return _context19.stop();
          }
        }, _callee17, this);
      }));
      function onEpochEnd(_x27, _x28) {
        return _onEpochEnd4.apply(this, arguments);
      }
      return onEpochEnd;
    }()
    /**
     * Await the values of all losses and metrics.
     *
     * Every non-number entry in `history` is treated as a scalar tensor. All
     * `data()` promises are awaited together, and then each tensor is disposed
     * and replaced in place by element `[0]` of its resolved data.
     */
  }, {
    key: "syncData",
    value: function () {
      var _syncData = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee18() {
        var promises, keys, indices, key, valueArray, i, valueScalar, values, n, tensorToDispose;
        return _regeneratorRuntime().wrap(function _callee18$(_context20) {
          while (1) switch (_context20.prev = _context20.next) {
            case 0:
              // Record (key, index) for each pending read so results can be
              // written back to the exact slot they came from.
              promises = [];
              keys = [];
              indices = [];
              for (key in this.history) {
                valueArray = this.history[key];
                for (i = 0; i < valueArray.length; ++i) {
                  if (typeof valueArray[i] !== 'number') {
                    valueScalar = valueArray[i];
                    promises.push(valueScalar.data());
                    keys.push(key);
                    indices.push(i);
                  }
                }
              }
              _context20.next = 6;
              return Promise.all(promises);
            case 6:
              // Resumed after `await Promise.all(promises)`.
              values = _context20.sent;
              for (n = 0; n < values.length; ++n) {
                tensorToDispose = this.history[keys[n]][indices[n]];
                tensorToDispose.dispose();
                this.history[keys[n]][indices[n]] = values[n][0];
              }
            case 8:
            case "end":
              return _context20.stop();
          }
        }, _callee18, this);
      }));
      function syncData() {
        return _syncData.apply(this, arguments);
      }
      return syncData;
    }()
  }]);
  return History;
}(BaseCallback);
63245 /**
63246 * Custom callback for training.
63247 */
63248 var CustomCallback = /*#__PURE__*/function (_BaseCallback3) {
63249 _inherits(CustomCallback, _BaseCallback3);
63250 var _super3 = _createSuper(CustomCallback);
63251 function CustomCallback(args, yieldEvery) {
63252 var _this3;
63253 _classCallCheck(this, CustomCallback);
63254 _this3 = _super3.call(this);
63255 _this3.currentEpoch = 0;
63256 _this3.nowFunc = args.nowFunc;
63257 _this3.nextFrameFunc = args.nextFrameFunc || nextFrame;
63258 _this3.yieldEvery = yieldEvery || 'auto';
63259 if (_this3.yieldEvery === 'auto') {
63260 _this3.yieldEvery = DEFAULT_YIELD_EVERY_MS;
63261 }
63262 if (_this3.yieldEvery === 'never' && args.onYield != null) {
63263 throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' + 'Either change `yieldEvery` or remove the callback');
63264 }
63265 if (isNumber(_this3.yieldEvery)) {
63266 // Decorate `maybeWait` so it will be called at most once every
63267 // `yieldEvery` ms.
63268 _this3.maybeWait = debounce(_this3.maybeWait.bind(_assertThisInitialized(_this3)), _this3.yieldEvery, _this3.nowFunc);
63269 }
63270 _this3.trainBegin = args.onTrainBegin;
63271 _this3.trainEnd = args.onTrainEnd;
63272 _this3.epochBegin = args.onEpochBegin;
63273 _this3.epochEnd = args.onEpochEnd;
63274 _this3.batchBegin = args.onBatchBegin;
63275 _this3.batchEnd = args.onBatchEnd;
63276 _this3.yield = args.onYield;
63277 return _this3;
63278 }
63279 _createClass(CustomCallback, [{
63280 key: "maybeWait",
63281 value: function () {
63282 var _maybeWait = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee19(epoch, batch, logs) {
63283 var ps;
63284 return _regeneratorRuntime().wrap(function _callee19$(_context21) {
63285 while (1) switch (_context21.prev = _context21.next) {
63286 case 0:
63287 ps = [];
63288 if (!(this.yield != null)) {
63289 _context21.next = 5;
63290 break;
63291 }
63292 _context21.next = 4;
63293 return resolveScalarsInLogs(logs);
63294 case 4:
63295 ps.push(this.yield(epoch, batch, logs));
63296 case 5:
63297 ps.push(this.nextFrameFunc());
63298 _context21.next = 8;
63299 return Promise.all(ps);
63300 case 8:
63301 case "end":
63302 return _context21.stop();
63303 }
63304 }, _callee19, this);
63305 }));
63306 function maybeWait(_x29, _x30, _x31) {
63307 return _maybeWait.apply(this, arguments);
63308 }
63309 return maybeWait;
63310 }()
63311 }, {
/**
 * onEpochBegin(epoch, logs) — transpiled form of an async method.
 *
 * It always records `epoch` in `this.currentEpoch`, which `onBatchEnd` later
 * passes to `maybeWait`. If an `onEpochBegin` callback was supplied
 * (`this.epochBegin`), it then awaits `resolveScalarsInLogs(logs)` and awaits
 * `this.epochBegin(epoch, logs)`.
 */
key: "onEpochBegin",
value: function () {
  var _onEpochBegin4 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee20(epoch, logs) {
    return _regeneratorRuntime().wrap(function _callee20$(_context22) {
      while (1) switch (_context22.prev = _context22.next) {
        case 0:
          this.currentEpoch = epoch;
          // No user callback: jump to the terminal state 6.
          if (!(this.epochBegin != null)) {
            _context22.next = 6;
            break;
          }
          _context22.next = 4;
          // `await resolveScalarsInLogs(logs)`; the resolved value is unused.
          return resolveScalarsInLogs(logs);
        case 4:
          _context22.next = 6;
          // `await this.epochBegin(epoch, logs)`.
          return this.epochBegin(epoch, logs);
        case 6:
        case "end":
          return _context22.stop();
      }
    }, _callee20, this);
  }));
  function onEpochBegin(_x32, _x33) {
    return _onEpochBegin4.apply(this, arguments);
  }
  return onEpochBegin;
}()
63339 }, {
/**
 * onEpochEnd(epoch, logs) — transpiled form of an async method.
 *
 * If an `onEpochEnd` callback was supplied (`this.epochEnd`), it awaits
 * `resolveScalarsInLogs(logs)` and collects the result of
 * `this.epochEnd(epoch, logs)`. When `this.yieldEvery === 'epoch'` it also
 * collects `this.nextFrameFunc()`. Finally it awaits `Promise.all` over
 * everything collected.
 */
key: "onEpochEnd",
value: function () {
  var _onEpochEnd5 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee21(epoch, logs) {
    var ps;
    return _regeneratorRuntime().wrap(function _callee21$(_context23) {
      while (1) switch (_context23.prev = _context23.next) {
        case 0:
          ps = [];
          // No user callback: skip to state 5.
          if (!(this.epochEnd != null)) {
            _context23.next = 5;
            break;
          }
          _context23.next = 4;
          // `await resolveScalarsInLogs(logs)`; the resolved value is unused.
          return resolveScalarsInLogs(logs);
        case 4:
          ps.push(this.epochEnd(epoch, logs));
        // Intentional fall-through into state 5.
        case 5:
          if (this.yieldEvery === 'epoch') {
            ps.push(this.nextFrameFunc());
          }
          _context23.next = 8;
          // `await Promise.all(ps)`.
          return Promise.all(ps);
        case 8:
        case "end":
          return _context23.stop();
      }
    }, _callee21, this);
  }));
  function onEpochEnd(_x34, _x35) {
    return _onEpochEnd5.apply(this, arguments);
  }
  return onEpochEnd;
}()
63373 }, {
/**
 * onBatchBegin(batch, logs) — transpiled form of an async method.
 *
 * If an `onBatchBegin` callback was supplied (`this.batchBegin`), it awaits
 * `resolveScalarsInLogs(logs)` and then awaits `this.batchBegin(batch, logs)`.
 * Otherwise it does nothing.
 */
key: "onBatchBegin",
value: function () {
  var _onBatchBegin3 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee22(batch, logs) {
    return _regeneratorRuntime().wrap(function _callee22$(_context24) {
      while (1) switch (_context24.prev = _context24.next) {
        case 0:
          // No user callback: jump to the terminal state 5.
          if (!(this.batchBegin != null)) {
            _context24.next = 5;
            break;
          }
          _context24.next = 3;
          // `await resolveScalarsInLogs(logs)`; the resolved value is unused.
          return resolveScalarsInLogs(logs);
        case 3:
          _context24.next = 5;
          // `await this.batchBegin(batch, logs)`.
          return this.batchBegin(batch, logs);
        case 5:
        case "end":
          return _context24.stop();
      }
    }, _callee22, this);
  }));
  function onBatchBegin(_x36, _x37) {
    return _onBatchBegin3.apply(this, arguments);
  }
  return onBatchBegin;
}()
63400 }, {
/**
 * onBatchEnd(batch, logs) — transpiled form of an async method.
 *
 * If an `onBatchEnd` callback was supplied (`this.batchEnd`), it awaits
 * `resolveScalarsInLogs(logs)` and collects `this.batchEnd(batch, logs)`.
 * It then yields according to `this.yieldEvery`:
 *   - `'batch'`: collects `this.nextFrameFunc()`;
 *   - a number (per `isNumber`): collects
 *     `this.maybeWait(this.currentEpoch, batch, logs)`;
 *   - anything else: no yield.
 * Finally it awaits `Promise.all` over everything collected.
 */
key: "onBatchEnd",
value: function () {
  var _onBatchEnd4 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee23(batch, logs) {
    var ps;
    return _regeneratorRuntime().wrap(function _callee23$(_context25) {
      while (1) switch (_context25.prev = _context25.next) {
        case 0:
          ps = [];
          // No user callback: skip to state 5.
          if (!(this.batchEnd != null)) {
            _context25.next = 5;
            break;
          }
          _context25.next = 4;
          // `await resolveScalarsInLogs(logs)`; the resolved value is unused.
          return resolveScalarsInLogs(logs);
        case 4:
          ps.push(this.batchEnd(batch, logs));
        // Intentional fall-through into state 5.
        case 5:
          if (this.yieldEvery === 'batch') {
            ps.push(this.nextFrameFunc());
          } else if (isNumber(this.yieldEvery)) {
            // `this.currentEpoch` is the value stored by onEpochBegin.
            ps.push(this.maybeWait(this.currentEpoch, batch, logs));
          }
          _context25.next = 8;
          // `await Promise.all(ps)`.
          return Promise.all(ps);
        case 8:
        case "end":
          return _context25.stop();
      }
    }, _callee23, this);
  }));
  function onBatchEnd(_x38, _x39) {
    return _onBatchEnd4.apply(this, arguments);
  }
  return onBatchEnd;
}()
63436 }, {
/**
 * onTrainBegin(logs) — transpiled form of an async method.
 *
 * If `this.trainBegin` is set, it awaits `resolveScalarsInLogs(logs)` and then
 * awaits `this.trainBegin(logs)`. Otherwise it does nothing.
 * NOTE(review): `this.trainBegin` is assigned outside this view — presumably
 * from `args.onTrainBegin` in the constructor; confirm.
 */
key: "onTrainBegin",
value: function () {
  var _onTrainBegin4 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee24(logs) {
    return _regeneratorRuntime().wrap(function _callee24$(_context26) {
      while (1) switch (_context26.prev = _context26.next) {
        case 0:
          // No user callback: jump to the terminal state 5.
          if (!(this.trainBegin != null)) {
            _context26.next = 5;
            break;
          }
          _context26.next = 3;
          // `await resolveScalarsInLogs(logs)`; the resolved value is unused.
          return resolveScalarsInLogs(logs);
        case 3:
          _context26.next = 5;
          // `await this.trainBegin(logs)`.
          return this.trainBegin(logs);
        case 5:
        case "end":
          return _context26.stop();
      }
    }, _callee24, this);
  }));
  function onTrainBegin(_x40) {
    return _onTrainBegin4.apply(this, arguments);
  }
  return onTrainBegin;
}()
63463 }, {
/**
 * onTrainEnd(logs) — transpiled form of an async method.
 *
 * If `this.trainEnd` is set, it awaits `resolveScalarsInLogs(logs)` and then
 * awaits `this.trainEnd(logs)`. Otherwise it does nothing.
 * NOTE(review): `this.trainEnd` is assigned outside this view — presumably
 * from `args.onTrainEnd` in the constructor; confirm.
 */
key: "onTrainEnd",
value: function () {
  var _onTrainEnd3 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee25(logs) {
    return _regeneratorRuntime().wrap(function _callee25$(_context27) {
      while (1) switch (_context27.prev = _context27.next) {
        case 0:
          // No user callback: jump to the terminal state 5.
          if (!(this.trainEnd != null)) {
            _context27.next = 5;
            break;
          }
          _context27.next = 3;
          // `await resolveScalarsInLogs(logs)`; the resolved value is unused.
          return resolveScalarsInLogs(logs);
        case 3:
          _context27.next = 5;
          // `await this.trainEnd(logs)`.
          return this.trainEnd(logs);
        case 5:
        case "end":
          return _context27.stop();
      }
    }, _callee25, this);
  }));
  function onTrainEnd(_x41) {
    return _onTrainEnd3.apply(this, arguments);
  }
  return onTrainEnd;
}()
63490 }]);
63491 return CustomCallback;
63492 }(BaseCallback);
/**
 * Standardize callbacks or configurations of them to an Array of callbacks.
 *
 * Accepted forms of `callbacks`:
 * - `null` / `undefined`: treated as a single empty config `{}`.
 * - a `BaseCallback` instance: wrapped into a one-element Array.
 * - an Array whose FIRST element is a `BaseCallback`: returned as-is (the
 *   remaining elements are not inspected).
 * - anything else (a config object, or an Array of configs): passed through
 *   `toList` and every element becomes `new CustomCallback(config, yieldEvery)`.
 *
 * @param callbacks Callback(s) or callback config(s).
 * @param yieldEvery Forwarded to every `CustomCallback` that is created.
 * @returns An Array of callback objects.
 */
function standardizeCallbacks(callbacks, yieldEvery) {
  var normalized = callbacks == null ? {} : callbacks;
  if (normalized instanceof BaseCallback) {
    return [normalized];
  }
  var isCallbackArray = Array.isArray(normalized) && normalized[0] instanceof BaseCallback;
  if (isCallbackArray) {
    return normalized;
  }
  // Convert custom callback configs to custom callback objects.
  return toList(normalized).map(function (config) {
    return new CustomCallback(config, yieldEvery);
  });
}
/**
 * A global registry for callback constructors to be used during
 * LayersModel.fit().
 *
 * Constructors are stored in the static `constructors` object, keyed by
 * integer verbosity level, each level holding an Array of constructors.
 * Only the static members are meant to be used.
 */
var CallbackConstructorRegistry = /*#__PURE__*/function () {
  /**
   * Blocks public access to constructor.
   */
  function CallbackConstructorRegistry() {
    _classCallCheck(this, CallbackConstructorRegistry);
  }

  /**
   * Register a tf.LayersModel.fit() callback constructor.
   *
   * The registered callback constructor will be used to instantiate
   * callbacks for every tf.LayersModel.fit() call afterwards.
   *
   * @param verbosityLevel Integer >= 0: the level of verbosity at which the
   *   `callbackConstructor` is to be registered.
   * @param callbackConstructor A no-arg constructor for `tf.Callback`.
   * @throws If `verbosityLevel` is not an integer >= 0 (via `assert$1`).
   * @throws ValueError if the same callbackConstructor has been registered
   *   before, either at the same or a different `verbosityLevel`.
   */
  function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
    assert$1(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), function () {
      return "Verbosity level is expected to be an integer >= 0, " + "but got ".concat(verbosityLevel);
    });
    CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);
    var byLevel = CallbackConstructorRegistry.constructors;
    if (byLevel[verbosityLevel] == null) {
      byLevel[verbosityLevel] = [];
    }
    byLevel[verbosityLevel].push(callbackConstructor);
  }

  /**
   * Throw a ValueError if `callbackConstructor` is already registered at any
   * verbosity level.
   */
  function checkForDuplicate(callbackConstructor) {
    for (var level in CallbackConstructorRegistry.constructors) {
      var registered = CallbackConstructorRegistry.constructors[+level];
      for (var i = 0; i < registered.length; i++) {
        if (registered[i] === callbackConstructor) {
          throw new ValueError('Duplicate callback constructor.');
        }
      }
    }
  }

  /**
   * Clear all registered callback constructors.
   */
  function clear() {
    CallbackConstructorRegistry.constructors = {};
  }

  /**
   * Create callbacks using the registered callback constructors.
   *
   * Every constructor registered at a level <= `verbosityLevel` is invoked
   * with no arguments.
   *
   * @param verbosityLevel Level of verbosity.
   * @returns The freshly instantiated callbacks.
   */
  function createCallbacks(verbosityLevel) {
    var selected = [];
    for (var level in CallbackConstructorRegistry.constructors) {
      var numericLevel = +level;
      if (verbosityLevel >= numericLevel) {
        var atLevel = CallbackConstructorRegistry.constructors[numericLevel];
        for (var i = 0; i < atLevel.length; i++) {
          selected.push(atLevel[i]);
        }
      }
    }
    return selected.map(function (Ctor) {
      return new Ctor();
    });
  }

  _createClass(CallbackConstructorRegistry, null, [{
    key: "registerCallbackConstructor",
    value: registerCallbackConstructor
  }, {
    key: "checkForDuplicate",
    value: checkForDuplicate
  }, {
    key: "clear",
    value: clear
  }, {
    key: "createCallbacks",
    value: createCallbacks
  }]);
  return CallbackConstructorRegistry;
}();
CallbackConstructorRegistry.constructors = {};
/**
 * Assemble the callback list used by one training run.
 *
 * The resulting order is: a `BaseLogger`, then the callbacks created from the
 * `CallbackConstructorRegistry` for `verbose`, then the user-supplied
 * `callbacks` (if any), and finally a `History` instance.
 *
 * @param callbacks User-supplied callbacks, or null/undefined for none.
 * @param verbose Verbosity level; also selects registry callbacks.
 * @param epochs Number of epochs.
 * @param initialEpoch Epoch at which training starts.
 * @param numTrainSamples Number of training samples.
 * @param stepsPerEpoch Steps per epoch.
 * @param batchSize Batch size.
 * @param doValidation Whether validation is performed.
 * @param callbackMetrics Metric names exposed to the callbacks.
 * @returns `{callbackList, history}`: the configured `CallbackList` and the
 *   `History` callback appended to it.
 */
function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
  var history = new History();
  var pipeline = [new BaseLogger()];
  var registered = CallbackConstructorRegistry.createCallbacks(verbose);
  pipeline.push.apply(pipeline, _toConsumableArray(registered));
  if (callbacks != null) {
    pipeline.push.apply(pipeline, _toConsumableArray(callbacks));
  }
  pipeline.push(history);
  var callbackList = new CallbackList(pipeline);
  // TODO(cais): Figure out when this LayersModel instance can have a
  // dynamically
  // set property called 'callback_model' as in PyKeras.
  var params = {
    epochs: epochs,
    initialEpoch: initialEpoch,
    samples: numTrainSamples,
    steps: stepsPerEpoch,
    batchSize: batchSize,
    verbose: verbose,
    doValidation: doValidation,
    metrics: callbackMetrics
  };
  callbackList.setParams(params);
  return {
    callbackList: callbackList,
    history: history
  };
}
63619
63620 /**
63621 * @license
63622 * Copyright 2018 Google LLC
63623 *
63624 * Use of this source code is governed by an MIT-style
63625 * license that can be found in the LICENSE file or at
63626 * https://opensource.org/licenses/MIT.
63627 * =============================================================================
63628 */
/**
 * Instantiate a layer from a config dictionary.
 *
 * @param config dict of the form {class_name: str, config: dict}
 * @param customObjects dict mapping class names (or function names)
 *   of custom (non-Keras) objects to class/functions. Default: `{}`.
 * @param fastWeightInit Optional flag to use fast weight initialization
 *   during deserialization. This is applicable to cases in which
 *   the initialization will be immediately overwritten by loaded weight
 *   values. Default: `false`.
 * @returns Layer instance (may be LayersModel, Sequential, Layer...)
 */
function deserialize(config) {
  var customObjects = {};
  if (arguments.length > 1 && arguments[1] !== undefined) {
    customObjects = arguments[1];
  }
  var fastWeightInit = false;
  if (arguments.length > 2 && arguments[2] !== undefined) {
    fastWeightInit = arguments[2];
  }
  var classNameMap = SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, classNameMap, customObjects, 'layer', fastWeightInit);
}
63645
63646 /**
63647 * @license
63648 * Copyright 2018 Google LLC
63649 *
63650 * Use of this source code is governed by an MIT-style
63651 * license that can be found in the LICENSE file or at
63652 * https://opensource.org/licenses/MIT.
63653 * =============================================================================
63654 */
/**
 * Normalizes a tensor wrt the L2 norm alongside the specified axis.
 *
 * Non-float32 inputs are cast to float32 first. The squared norm is floored
 * at `epsilon$1()` before the square root, so all-zero slices are divided by
 * a small positive number instead of zero.
 *
 * @param x Tensor to normalize.
 * @param axis Axis along which to perform normalization.
 */
function l2Normalize(x, axis) {
  return tidy(function () {
    var input = x.dtype === 'float32' ? x : cast$3(x, 'float32');
    var squaredNorm = sum$3(square$1(input), axis, true);
    var lowerBound = fill$2(squaredNorm.shape, epsilon$1());
    var norm = sqrt$2(maximum$4(squaredNorm, lowerBound));
    return div$1(input, norm);
  });
}
/** Mean squared error over the last axis: mean((yPred - yTrue)^2, -1). */
function meanSquaredError$1(yTrue, yPred) {
  return tidy(function () {
    var error = sub$2(yPred, yTrue);
    var squared = square$1(error);
    return mean$3(squared, -1);
  });
}
/** Mean absolute error over the last axis: mean(|yPred - yTrue|, -1). */
function meanAbsoluteError$1(yTrue, yPred) {
  return tidy(function () {
    var absError = abs$2(sub$2(yPred, yTrue));
    return mean$3(absError, -1);
  });
}
/**
 * Mean absolute percentage error over the last axis:
 * 100 * mean(|(yTrue - yPred) / clip(|yTrue|, epsilon, MAX_VALUE)|, -1).
 */
function meanAbsolutePercentageError$1(yTrue, yPred) {
  return tidy(function () {
    var error = sub$2(yTrue, yPred);
    // Keep the denominator away from zero so the ratio stays finite.
    var scale = clipByValue$2(abs$2(yTrue), epsilon$1(), Number.MAX_VALUE);
    var relative = abs$2(div$1(error, scale));
    return mul(100, mean$3(relative, -1));
  });
}
/**
 * Mean squared logarithmic error over the last axis:
 * mean((log(1 + clip(yPred)) - log(1 + clip(yTrue)))^2, -1), where both
 * inputs are clipped to [epsilon, Number.MAX_VALUE].
 */
function meanSquaredLogarithmicError(yTrue, yPred) {
  return tidy(function () {
    var logOnePlusClipped = function (t) {
      return log$2(add$3(1, clipByValue$2(t, epsilon$1(), Number.MAX_VALUE)));
    };
    var predLog = logOnePlusClipped(yPred);
    var trueLog = logOnePlusClipped(yTrue);
    return mean$3(square$1(sub$2(predLog, trueLog)), -1);
  });
}
/** Squared hinge loss over the last axis: mean(max(0, 1 - yTrue * yPred)^2, -1). */
function squaredHinge(yTrue, yPred) {
  return tidy(function () {
    var margin = sub$2(1, mul(yTrue, yPred));
    var violation = maximum$4(0, margin);
    return mean$3(square$1(violation), -1);
  });
}
/** Hinge loss over the last axis: mean(max(0, 1 - yTrue * yPred), -1). */
function hinge(yTrue, yPred) {
  return tidy(function () {
    var margin = sub$2(1, mul(yTrue, yPred));
    var violation = maximum$4(0, margin);
    return mean$3(violation, -1);
  });
}
/**
 * Categorical hinge loss: max(0, 1 + maxWrong - trueScore), where, assuming
 * a one-hot `yTrue`, `trueScore` is the prediction for the true class and
 * `maxWrong` the largest prediction among the other classes.
 */
function categoricalHinge(yTrue, yPred) {
  return tidy(function () {
    var trueClassScore = sum$3(mul(yTrue, yPred), -1);
    var wrongClassMask = sub$2(1, yTrue);
    var bestWrongScore = max$3(mul(wrongClassMask, yPred), -1);
    var margin = sub$2(bestWrongScore, trueClassScore);
    return maximum$4(0, add$3(1, margin));
  });
}
/**
 * Logarithm of the hyperbolic cosine of the prediction error.
 *
 * `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
 * to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
 * like the mean squared error, but will not be so strongly affected by the
 * occasional wildly incorrect prediction.
 */
function logcosh(yTrue, yPred) {
  return tidy(function () {
    var error = sub$2(yPred, yTrue);
    // log(cosh(x)) == x + softplus(-2x) - log(2); cosh() is never evaluated
    // directly.
    var perElement = sub$2(add$3(error, softplus$2(mul(-2, error))), Math.log(2));
    return mean$3(perElement, -1);
  });
}
/**
 * Categorical crossentropy between a (one-hot or soft) `target` and `output`.
 *
 * If `fromLogits` is true, `output` is passed through a softmax. Otherwise it
 * is rescaled so that each sample's class values sum to 1. The probabilities
 * are then clipped to [epsilon, 1 - epsilon] and
 * -sum(target * log(probs)) is taken over the last axis.
 *
 * @param target Target tensor; cast to float32.
 * @param output Predicted probabilities, or logits when `fromLogits` is true.
 * @param fromLogits Whether `output` holds logits. Default: `false`.
 */
function categoricalCrossentropy$2(target, output) {
  var fromLogits = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
  return tidy(function () {
    var probs;
    if (fromLogits) {
      probs = softmax$3(output);
    } else {
      // scale preds so that the class probabilities of each sample sum to 1.
      var perSampleTotal = sum$3(output, output.shape.length - 1, true);
      probs = div$1(output, perSampleTotal);
    }
    probs = clipByValue$2(probs, epsilon$1(), 1 - epsilon$1());
    var weightedLogs = mul(cast$3(target, 'float32'), log$2(probs));
    return neg$2(sum$3(weightedLogs, probs.shape.length - 1));
  });
}
/**
 * Categorical crossentropy with integer targets.
 *
 * `target` is flattened, floored, cast to int32 and one-hot encoded with as
 * many classes as the last dimension of `output`. The one-hot tensor is
 * reshaped to `output`'s shape and passed, together with `output`, to
 * `categoricalCrossentropy$2`.
 *
 * Only probabilities are clipped to [epsilon, 1 - epsilon] here. Logits are
 * passed through untouched: clipping them into (0, 1) before the softmax
 * would distort the loss. `categoricalCrossentropy$2` clips its own
 * post-softmax probabilities in either case.
 *
 * @param target An integer tensor.
 * @param output A tensor resulting from a softmax (unless `fromLogits` is
 *   `true`, in which case `output` is expected to be the logits).
 * @param fromLogits Boolean, whether `output` is the result of a softmax, or is
 *   a tensor of logits. Default: `false`.
 */
function sparseCategoricalCrossentropy$1(target, output) {
  var fromLogits = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
  return tidy(function () {
    var flatTarget = cast$3(floor$2(flatten$1(target)), 'int32');
    var prepared = fromLogits ? output : clipByValue$2(output, epsilon$1(), 1 - epsilon$1());
    var outputShape = prepared.shape;
    var numClasses = outputShape[outputShape.length - 1];
    var oneHotTarget = reshape$3(oneHot$3(flatTarget, numClasses), outputShape);
    return categoricalCrossentropy$2(oneHotTarget, prepared, fromLogits);
  });
}
/**
 * Numerically stable sigmoid cross-entropy computed directly from logits,
 * following TensorFlow's implementation in nn_impl.py.
 *
 * With `x = logits` and `z = labels`, the logistic loss is
 *     z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
 *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
 *   = x - x * z + log(1 + exp(-x)).
 * For x < 0 the `exp(-x)` term can overflow, so the equivalent form
 *     -x * z + log(1 + exp(x))
 * is preferable there. Both cases combine into
 *     max(x, 0) - x * z + log(1 + exp(-abs(x))),
 * which is what is evaluated below.
 *
 * @param labels The labels; must have the same shape as `logits`.
 * @param logits The logits.
 * @throws ValueError if the shapes of `labels` and `logits` differ.
 */
function sigmoidCrossEntropyWithLogits(labels, logits) {
  var sameShape = arraysEqual(labels.shape, logits.shape);
  if (!sameShape) {
    throw new ValueError("logits and labels must have the same shape, but got shapes " + "".concat(JSON.stringify(labels.shape), " and ").concat(JSON.stringify(logits.shape)));
  }
  return tidy(function () {
    var positivePart = relu$2(logits); // max(x, 0)
    var crossTerm = mul(logits, labels); // x * z
    var softplusTail = log1p$2(exp$2(neg$2(abs$2(logits)))); // log(1 + exp(-|x|))
    return add$3(sub$2(positivePart, crossTerm), softplusTail);
  });
}
/**
 * Binary cross-entropy on probabilities, averaged over the last axis.
 *
 * `yPred` is clipped to [epsilon, 1 - epsilon], converted back to logits via
 * log(p / (1 - p)), and then handed to `sigmoidCrossEntropyWithLogits`.
 */
function binaryCrossentropy$2(yTrue, yPred) {
  return tidy(function () {
    var p = clipByValue$2(yPred, epsilon$1(), 1 - epsilon$1());
    var logits = log$2(div$1(p, sub$2(1, p)));
    var perElement = sigmoidCrossEntropyWithLogits(yTrue, logits);
    return mean$3(perElement, -1);
  });
}
/**
 * Kullback-Leibler divergence summed over the last axis:
 * sum(yTrue * log(clip(yTrue) / clip(yPred)), -1), with both clipped to
 * [epsilon, 1]. The unclipped `yTrue` is used as the weight, so zero targets
 * contribute zero.
 */
function kullbackLeiblerDivergence(yTrue, yPred) {
  return tidy(function () {
    var safeTrue = clipByValue$2(yTrue, epsilon$1(), 1);
    var safePred = clipByValue$2(yPred, epsilon$1(), 1);
    var logRatio = log$2(div$1(safeTrue, safePred));
    return sum$3(mul(yTrue, logRatio), -1);
  });
}
/** Poisson loss over the last axis: mean(yPred - yTrue * log(yPred + epsilon), -1). */
function poisson(yTrue, yPred) {
  return tidy(function () {
    var logPred = log$2(add$3(epsilon$1(), yPred));
    var perElement = sub$2(yPred, mul(yTrue, logPred));
    return mean$3(perElement, -1);
  });
}
/**
 * Cosine proximity: the negated cosine similarity of `yTrue` and `yPred` along
 * the last axis (-1 for identical directions).
 */
function cosineProximity$1(yTrue, yPred) {
  return tidy(function () {
    var products = mul(l2Normalize(yTrue, -1), l2Normalize(yPred, -1));
    var similarity = sum$3(products, -1);
    return neg$2(similarity);
  });
}
// Lower-case short-hand aliases for the loss functions above.
var mse$2 = meanSquaredError$1;
var mae$1 = meanAbsoluteError$1;
var mape$2 = meanAbsolutePercentageError$1;
var msle = meanSquaredLogarithmicError;
var kld = kullbackLeiblerDivergence;
var cosine$1 = cosineProximity$1;
// Upper-case spellings refer to the very same function objects.
var MSE$2 = mse$2;
var MAE$1 = mae$1;
var MAPE$2 = mape$2;
var MSLE = msle;
var KLD = kld;
// TODO(michaelterry): Add deserialize() function.
// Built-in losses keyed by their canonical string name. `get$1()` resolves
// string identifiers through this table, and `getLossOrMetricName()` scans it
// (in key order, before `metricsMap`) to recover a name from a function.
var lossesMap = {
  'meanSquaredError': meanSquaredError$1,
  'meanAbsoluteError': meanAbsoluteError$1,
  'meanAbsolutePercentageError': meanAbsolutePercentageError$1,
  'meanSquaredLogarithmicError': meanSquaredLogarithmicError,
  'squaredHinge': squaredHinge,
  'hinge': hinge,
  'categoricalHinge': categoricalHinge,
  'logcosh': logcosh,
  'categoricalCrossentropy': categoricalCrossentropy$2,
  'sparseCategoricalCrossentropy': sparseCategoricalCrossentropy$1,
  'binaryCrossentropy': binaryCrossentropy$2,
  'kullbackLeiblerDivergence': kullbackLeiblerDivergence,
  'poisson': poisson,
  'cosineProximity': cosineProximity$1
};
// Porting note: This diverges from the PyKeras implementation and may need to
// change based on (de)serialization requirements.
/**
 * Resolve a loss identifier to a loss function.
 *
 * @param identifierOrFn Either the string name of a built-in loss (an own key
 *   of `lossesMap`) or a loss function, which is returned unchanged.
 * @returns The loss function.
 * @throws ValueError if `identifierOrFn` is a string that does not name a
 *   built-in loss. Only own keys are accepted, so inherited property names
 *   such as 'toString' or 'constructor' are rejected instead of resolving to
 *   `Object.prototype` members. Names containing "softmaxcrossentropy"
 *   (case-insensitive) get a hint pointing to "categoricalCrossentropy".
 */
function get$1(identifierOrFn) {
  if (typeof identifierOrFn !== 'string') {
    return identifierOrFn;
  }
  if (Object.prototype.hasOwnProperty.call(lossesMap, identifierOrFn)) {
    return lossesMap[identifierOrFn];
  }
  var errMsg = "Unknown loss ".concat(identifierOrFn);
  if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) {
    errMsg = "Unknown loss ".concat(identifierOrFn, ". ") + 'Use "categoricalCrossentropy" as the string name for ' + 'tf.losses.softmaxCrossEntropy';
  }
  throw new ValueError(errMsg);
}
63876
63877 /**
63878 * @license
63879 * Copyright 2018 Google LLC
63880 *
63881 * Use of this source code is governed by an MIT-style
63882 * license that can be found in the LICENSE file or at
63883 * https://opensource.org/licenses/MIT.
63884 * =============================================================================
63885 */
/**
 * Binary accuracy: `yPred` is thresholded at 0.5 (strictly greater counts as
 * positive), cast to `yTrue`'s dtype, compared with `yTrue`, and the matches
 * are averaged over the last axis.
 */
function binaryAccuracy$1(yTrue, yPred) {
  return tidy(function () {
    var half = mul(.5, onesLike$3(yPred));
    var predictedPositive = greater$3(yPred, half);
    var matches = equal$2(yTrue, cast$2(predictedPositive, yTrue.dtype));
    return mean$3(matches, -1);
  });
}
/**
 * Categorical accuracy: 1 where the argmax over the last axis of `yTrue`
 * equals that of `yPred`, else 0 (as float32).
 */
function categoricalAccuracy$1(yTrue, yPred) {
  return tidy(function () {
    var trueClass = argMax$2(yTrue, -1);
    var predictedClass = argMax$2(yPred, -1);
    return cast$2(equal$2(trueClass, predictedClass), 'float32');
  });
}
/** Count (as float32) of positions where both `yTrue` and `yPred` equal 1. */
function truePositives(yTrue, yPred) {
  return tidy(function () {
    var bothPositive = logicalAnd$2(equal$2(yTrue, 1), equal$2(yPred, 1));
    var count = sum$3(bothPositive);
    return cast$3(count, 'float32');
  });
}
/** Count (as float32) of positions where `yTrue` is 1 but `yPred` is 0. */
function falseNegatives(yTrue, yPred) {
  return tidy(function () {
    var missed = logicalAnd$2(equal$2(yTrue, 1), equal$2(yPred, 0));
    var count = sum$3(missed);
    return cast$3(count, 'float32');
  });
}
/** Count (as float32) of positions where `yTrue` is 0 but `yPred` is 1. */
function falsePositives(yTrue, yPred) {
  return tidy(function () {
    var spurious = logicalAnd$2(equal$2(yTrue, 0), equal$2(yPred, 1));
    var count = sum$3(spurious);
    return cast$3(count, 'float32');
  });
}
/**
 * Precision = TP / (TP + FP), returned as float32; defined as 0 when nothing
 * was predicted positive.
 */
function precision$1(yTrue, yPred) {
  return tidy(function () {
    var tp = truePositives(yTrue, yPred);
    var fp = falsePositives(yTrue, yPred);
    var predictedPositives = add$3(tp, fp);
    var hasPredictions = greater$3(predictedPositives, 0);
    var ratio = where(hasPredictions, div$1(tp, predictedPositives), 0);
    return cast$3(ratio, 'float32');
  });
}
/**
 * Recall = TP / (TP + FN), returned as float32; defined as 0 when there are
 * no actual positives.
 */
function recall$1(yTrue, yPred) {
  return tidy(function () {
    var tp = truePositives(yTrue, yPred);
    var fn = falseNegatives(yTrue, yPred);
    var actualPositives = add$3(tp, fn);
    var hasPositives = greater$3(actualPositives, 0);
    var ratio = where(hasPositives, div$1(tp, actualPositives), 0);
    return cast$3(ratio, 'float32');
  });
}
/** Binary cross-entropy metric; numerically identical to the loss of the same name. */
function binaryCrossentropy$1(yTrue, yPred) {
  var value = binaryCrossentropy$2(yTrue, yPred);
  return value;
}
/**
 * Sparse categorical accuracy: compares integer class labels with the argmax
 * of the predictions over the last axis.
 *
 * If `yTrue` has the same rank as `yPred`, its last axis is squeezed away
 * first (presumably labels shaped `[..., 1]` — confirm against callers). The
 * argmax of `yPred` is cast to `yTrue`'s dtype when they differ.
 *
 * The computation runs inside `tidy()`, like every other metric in this file,
 * so the intermediate squeeze/argMax/cast/equal tensors are disposed and only
 * the returned tensor survives.
 *
 * @returns A float32 tensor: 1 where the predicted class matches, else 0.
 */
function sparseCategoricalAccuracy$1(yTrue, yPred) {
  return tidy(function () {
    var labels = yTrue;
    if (labels.rank === yPred.rank) {
      labels = squeeze(labels, [labels.rank - 1]);
    }
    var predicted = argMax$2(yPred, -1);
    if (predicted.dtype !== labels.dtype) {
      predicted = cast$3(predicted, labels.dtype);
    }
    return cast$3(equal$2(labels, predicted), 'float32');
  });
}
/** Top-k categorical accuracy. Not implemented: always throws NotImplementedError. */
function topKCategoricalAccuracy(yTrue, yPred) {
  var error = new NotImplementedError();
  throw error;
}
/** Sparse top-k categorical accuracy. Not implemented: always throws NotImplementedError. */
function sparseTopKCategoricalAccuracy(yTrue, yPred) {
  var error = new NotImplementedError();
  throw error;
}
/**
 * Coefficient of determination R^2 = 1 - SS_res / SS_tot, where SS_res is the
 * sum of squared residuals (yTrue - yPred) and SS_tot the sum of squared
 * deviations of `yTrue` from its mean. Both sums run over all elements.
 */
function r2Score$1(yTrue, yPred) {
  return tidy(function () {
    var residuals = yTrue.sub(yPred);
    var ssRes = residuals.square().sum();
    var deviations = yTrue.sub(yTrue.mean());
    var ssTot = deviations.square().sum();
    return scalar(1).sub(ssRes.div(ssTot));
  });
}
// Aliases.
// Short names under which metric functions are also known.
var mse$1 = meanSquaredError$1;
var mae = meanAbsoluteError$1;
var mape$1 = meanAbsolutePercentageError$1;
var MSE$1 = mse$1;
var MAE = mae;
var MAPE$1 = mape$1;
// These metrics reuse the loss implementations directly.
var categoricalCrossentropy$1 = categoricalCrossentropy$2;
var sparseCategoricalCrossentropy = sparseCategoricalCrossentropy$1;
var cosine = cosineProximity$1;
// TODO(cais, nielsene): Add serialize().
// Built-in metrics keyed by string name. `get()` resolves string identifiers
// through this table, and `getLossOrMetricName()` scans it (in key order,
// after `lossesMap`) to recover a name from a function.
var metricsMap = {
  'binaryAccuracy': binaryAccuracy$1,
  'categoricalAccuracy': categoricalAccuracy$1,
  'precision': precision$1,
  'categoricalCrossentropy': categoricalCrossentropy$1,
  'sparseCategoricalCrossentropy': sparseCategoricalCrossentropy,
  'mse': mse$1,
  'MSE': MSE$1,
  'mae': mae,
  'MAE': MAE,
  'mape': mape$1,
  'MAPE': MAPE$1,
  'cosine': cosine
};
/**
 * Resolve a metric identifier to a metric function.
 *
 * @param identifier A metric function (returned unchanged) or the string name
 *   of a built-in metric (an own key of `metricsMap`).
 * @returns The metric function.
 * @throws ValueError if `identifier` is null/undefined, or is a string that is
 *   not an own key of `metricsMap`. Inherited names such as 'toString' are
 *   rejected rather than resolving to `Object.prototype` members.
 */
function get(identifier) {
  if (typeof identifier === 'string') {
    if (Object.prototype.hasOwnProperty.call(metricsMap, identifier)) {
      return metricsMap[identifier];
    }
  } else if (identifier != null) {
    return identifier;
  }
  throw new ValueError("Unknown metric ".concat(identifier));
}
/**
 * Get the shortcut function name.
 *
 * - If `fn` is a string, it is returned unchanged.
 * - Otherwise, if `fn` is stored in `lossesMap` or `metricsMap`, the key it is
 *   stored under is returned. If the function is stored under several keys,
 *   the first key found (in `Object.keys` order) wins, and `lossesMap` is
 *   searched before `metricsMap`.
 * - Otherwise the function's own `name` is returned.
 *
 * @param fn loss function, metric function, or short cut name.
 * @returns Loss or Metric name in string.
 * @throws If `fn` is null or undefined (via `assert`).
 */
function getLossOrMetricName(fn) {
  // `!= null` so that `undefined` is rejected here with a clear message
  // instead of failing on `fn.name` below.
  assert(fn != null, "Unknown LossOrMetricFn ".concat(fn));
  if (typeof fn === 'string') {
    return fn;
  }
  var registries = [lossesMap, metricsMap];
  for (var r = 0; r < registries.length; r++) {
    var registry = registries[r];
    var keys = Object.keys(registry);
    for (var i = 0; i < keys.length; i++) {
      if (registry[keys[i]] === fn) {
        return keys[i];
      }
    }
  }
  return fn.name;
}
64035
64036 /**
64037 * @license
64038 * Copyright 2018 Google LLC
64039 *
64040 * Use of this source code is governed by an MIT-style
64041 * license that can be found in the LICENSE file or at
64042 * https://opensource.org/licenses/MIT.
64043 * =============================================================================
64044 */
// Add (de)serialize()
// Porting note: This diverges from the PyKeras implementation and may need to
// change based on (de)serialization requirements.
/**
 * Create a new optimizer instance from its string name.
 *
 * Accepted names are 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'RMSProp' and
 * 'SGD', plus their all-lower-case spellings. Every call builds a fresh
 * optimizer through the matching `train.*` factory with the fixed
 * hyper-parameters listed in the table below.
 *
 * @param identifier Optimizer name.
 * @returns A new optimizer instance.
 * @throws ValueError if `identifier` is not one of the accepted names. Only
 *   own keys of the table are accepted, so inherited names such as
 *   'toString' are rejected instead of invoking `Object.prototype` methods.
 */
function getOptimizer(identifier) {
  var optimizerMap = {
    'Adagrad': function Adagrad() {
      return train.adagrad(0.01);
    },
    'Adadelta': function Adadelta() {
      return train.adadelta(1, 0.95, epsilon$1());
    },
    'Adam': function Adam() {
      return train.adam(0.001, 0.9, 0.999, epsilon$1());
    },
    'Adamax': function Adamax() {
      return train.adamax(0.002, 0.9, 0.999, epsilon$1(), 0);
    },
    'RMSProp': function RMSProp() {
      return train.rmsprop(0.001, 0.9, 0, epsilon$1());
    },
    'SGD': function SGD() {
      return train.sgd(0.01);
    }
  };
  optimizerMap['adagrad'] = optimizerMap['Adagrad'];
  optimizerMap['adadelta'] = optimizerMap['Adadelta'];
  optimizerMap['adam'] = optimizerMap['Adam'];
  optimizerMap['adamax'] = optimizerMap['Adamax'];
  optimizerMap['rmsprop'] = optimizerMap['RMSProp'];
  optimizerMap['sgd'] = optimizerMap['SGD'];
  if (Object.prototype.hasOwnProperty.call(optimizerMap, identifier)) {
    return optimizerMap[identifier]();
  }
  throw new ValueError("Unknown Optimizer ".concat(identifier));
}
64080
64081 /**
64082 * @license
64083 * Copyright 2019 Google LLC
64084 *
64085 * Use of this source code is governed by an MIT-style
64086 * license that can be found in the LICENSE file or at
64087 * https://opensource.org/licenses/MIT.
64088 * =============================================================================
64089 */
/** Utility functions related to user-defined metadata. */
// Maximum recommended length of the JSON-serialized user-defined metadata:
// 1024 * 1024 characters as counted by `String.prototype.length`. Beyond this
// limit, a warning message will be printed during model loading and saving.
var MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1024 * 1024;
/**
 * Check validity of user-defined metadata.
 *
 * @param userDefinedMetadata The metadata to check. It must be a non-null
 *   object whose prototype is exactly `Object.prototype` and whose contents
 *   pass `plainObjectCheck`.
 * @param modelName Name of the model that the user-defined metadata belongs to.
 *   Used in the size warning message.
 * @param checkSize Whether to check that the size of the metadata is under the
 *   recommended limit. Default: `false`. If `true`, the object is serialized
 *   with `JSON.stringify` and a console warning is printed if the serialized
 *   length exceeds `MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH`.
 * @throws Error if `userDefinedMetadata` is not a plain JSON object.
 */
function checkUserDefinedMetadata(userDefinedMetadata, modelName) {
  var checkSize = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
  var isPlainJsonObject = userDefinedMetadata != null && _typeof(userDefinedMetadata) === 'object' && Object.getPrototypeOf(userDefinedMetadata) === Object.prototype && plainObjectCheck(userDefinedMetadata);
  if (!isPlainJsonObject) {
    throw new Error('User-defined metadata is expected to be a JSON object, but is not.');
  }
  if (!checkSize) {
    return;
  }
  var serialized = JSON.stringify(userDefinedMetadata);
  if (serialized.length <= MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
    return;
  }
  console.warn("User-defined metadata of model \"".concat(modelName, "\" is too large in ") + "size (length=".concat(serialized.length, " when serialized). It is not ") + "recommended to store such large objects in user-defined metadata. " + "Please make sure its serialized length is <= " + "".concat(MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH, "."));
}
/**
 * Check if an input is a plain JSON value or any valid subfield of one.
 *
 * @param x The input to be checked.
 * @return Returns `true` if and only if `x` is
 *   - `null`, a string, a number or a boolean; or
 *   - an Array every slot of which (holes read as `undefined`, and so fail)
 *     is itself a plain JSON value; or
 *   - an object whose prototype is exactly `Object.prototype` and all of
 *     whose own enumerable values are plain JSON values.
 *   Everything else yields `false`: `undefined`, functions, symbols, complex
 *   objects such as `Date` or `Error`, objects with a `null` prototype, and
 *   structures containing a reference cycle (which JSON cannot represent).
 */
// tslint:disable-next-line:no-any
function plainObjectCheck(x) {
  // `ancestors` holds the containers on the current recursion path; meeting
  // one of them again means the structure is cyclic. Shared but acyclic
  // sub-objects are still accepted.
  function check(value, ancestors) {
    if (value === null) {
      // Note: typeof `null` is 'object', and `null` is valid in JSON.
      return true;
    }
    var valueType = typeof value;
    if (valueType !== 'object') {
      return valueType === 'string' || valueType === 'number' || valueType === 'boolean';
    }
    if (ancestors.indexOf(value) !== -1) {
      return false;
    }
    var isPlainObject = Object.getPrototypeOf(value) === Object.prototype;
    if (!isPlainObject && !Array.isArray(value)) {
      // A complex object such as `Date` or `Error`.
      return false;
    }
    ancestors.push(value);
    var ok = true;
    if (isPlainObject) {
      var keys = Object.keys(value);
      for (var k = 0; ok && k < keys.length; k++) {
        ok = check(value[keys[k]], ancestors);
      }
    } else {
      for (var i = 0; ok && i < value.length; i++) {
        ok = check(value[i], ancestors);
      }
    }
    ancestors.pop();
    return ok;
  }
  return check(x, []);
}
64183
/**
 * Print the summary of a LayersModel object.
 *
 * @param model tf.LayersModel instance.
 * @param lineLength Total length of printed lines. Set this to adapt the
 *   display to different terminal or console sizes. When not provided (or
 *   0), defaults to 90 for sequential-like models and 115 otherwise.
 * @param positions Relative or absolute positions of log elements in each
 *   line. Each number corresponds to the right-most (i.e., ending) position
 *   of a column. If the last entry is <= 1 the positions are treated as
 *   fractions of `lineLength` and converted to absolute (floored) positions.
 *   If not provided, defaults to `[0.32, 0.61, 0.89, 1]` for sequential-like
 *   models and `[0.24, 0.48, 0.70, 0.80, 1]` for non-sequential-like models
 *   (which get an extra 'Receives inputs' column).
 * @param printFn Print function to use.
 *   It will be called on each line of the summary. You can provide a custom
 *   function in order to capture the string summary. Defaults to `console.log`.
 */
function printSummary(model, lineLength, positions) {
  var printFn = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : console.log;
  var sequentialLike = isModelSequentialLike(model);
  // Column headers; non-sequential-like models get one more column below.
  var toDisplay = ['Layer (type)', 'Input Shape', 'Output shape', 'Param #'];
  if (sequentialLike) {
    lineLength = lineLength || 90;
    positions = positions || [0.32, 0.61, 0.89, 1];
  } else {
    lineLength = lineLength || 115;
    positions = positions || [0.24, 0.48, 0.70, 0.80, 1];
  }
  if (positions[positions.length - 1] <= 1) {
    // `positions` is relative. Convert it to absolute positioning.
    positions = positions.map(function (p) {
      return Math.floor(lineLength * p);
    });
  }
  var relevantNodes;
  if (!sequentialLike) {
    toDisplay.push('Receives inputs');
    // Nodes of this model's graph across all depths; the per-layer printer
    // lists only connections coming through these nodes.
    relevantNodes = [];
    for (var depth in model.nodesByDepth) {
      relevantNodes.push.apply(relevantNodes, _toConsumableArray(model.nodesByDepth[depth]));
    }
  }
  printFn('_'.repeat(lineLength));
  printRow(toDisplay, positions, printFn);
  printFn('='.repeat(lineLength));
  var layers = model.layers;
  for (var i = 0; i < layers.length; ++i) {
    if (sequentialLike) {
      printLayerSummary(layers[i], positions, printFn);
    } else {
      printLayerSummaryWithConnections(layers[i], positions, relevantNodes, printFn);
    }
    // The last layer is closed with '=' instead of the '_' row separator.
    printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength));
  }
  model.checkTrainableWeightsConsistency();
  var trainableCount = countTrainableParams(model);
  var nonTrainableCount = countParamsInWeights(model.nonTrainableWeights);
  printFn("Total params: ".concat(trainableCount + nonTrainableCount));
  printFn("Trainable params: ".concat(trainableCount));
  printFn("Non-trainable params: ".concat(nonTrainableCount));
  printFn('_'.repeat(lineLength));
}
/**
 * Counts the trainable parameters of a model.
 *
 * Uses `model.collectedTrainableWeights` when it is set (not null/undefined),
 * otherwise falls back to `model.trainableWeights`.
 */
function countTrainableParams(model) {
  var weights = model.collectedTrainableWeights != null ? model.collectedTrainableWeights : model.trainableWeights;
  return countParamsInWeights(weights);
}
/**
 * Decides whether a model's graph is a simple chain of layers.
 *
 * The graph is NOT sequential-like when any depth of `model.nodesByDepth`
 * holds more than one node, when a lone node at some depth has more than one
 * inbound layer, or when some layer of `model.layers` owns more than one of
 * the graph's nodes (a shared layer).
 *
 * @param model Model exposing `nodesByDepth` and `layers`.
 * @returns `true` when none of the conditions above holds.
 */
function isModelSequentialLike(model) {
  var graphNodes = [];
  for (var depthKey in model.nodesByDepth) {
    var group = model.nodesByDepth[depthKey];
    var branches = group.length > 1;
    var merges = group.length === 1 && group[0].inboundLayers.length > 1;
    if (branches || merges) {
      return false;
    }
    for (var g = 0; g < group.length; g++) {
      graphNodes.push(group[g]);
    }
  }
  // Look for shared layers: a layer with more than one inbound node that
  // belongs to this graph.
  var layers = model.layers;
  for (var l = 0; l < layers.length; l++) {
    var inbound = layers[l].inboundNodes;
    var owned = 0;
    for (var n = 0; n < inbound.length; n++) {
      if (graphNodes.indexOf(inbound[n]) !== -1) {
        owned++;
        if (owned > 1) {
          return false;
        }
      }
    }
  }
  return true;
}
/**
 * Formats one row of the summary table and hands it to the print function.
 *
 * Each field is appended to the row, after which the row is clipped to and
 * right-padded with spaces up to `positions[i]`. Before appending any field
 * other than the first, the last character of the row so far is replaced by
 * a space, so adjacent columns stay visibly separated even after truncation.
 *
 * @param fields Cell values for the row, in column order.
 * @param positions Ending position of each column.
 * @param printFn Optional third argument receiving the finished row.
 *   Defaults to `console.log`.
 */
function printRow(fields, positions) {
  var emit = arguments[2] === undefined ? console.log : arguments[2];
  var row = '';
  for (var col = 0; col < fields.length; col++) {
    var stop = positions[col];
    var prefix = col === 0 ? row : row.slice(0, row.length - 1) + ' ';
    var clipped = (prefix + fields[col]).slice(0, stop);
    row = clipped + ' '.repeat(stop - clipped.length);
  }
  emit(row);
}
/**
 * Prints the summary row of a single Layer, without connectivity information.
 *
 * Columns: "<name> (<class name>)", the JSON-serialized `inputShapes` of every
 * inbound node (comma-joined), the JSON-serialized output shape, and the
 * parameter count. A shape that cannot be read or serialized is shown as
 * 'multiple'.
 *
 * @param layer Layer instance to print.
 * @param positions Column end positions, forwarded to printRow.
 * @param printFn Function that receives the formatted row.
 */
function printLayerSummary(layer, positions, printFn) {
  // Evaluates `compute`, substituting 'multiple' if it throws.
  var orMultiple = function (compute) {
    try {
      return compute();
    } catch (err) {
      return 'multiple';
    }
  };
  var inputShape = orMultiple(function () {
    return layer.inboundNodes.map(function (node) {
      return JSON.stringify(node.inputShapes);
    }).join(',');
  });
  var outputShape = orMultiple(function () {
    return JSON.stringify(layer.outputShape);
  });
  var title = "".concat(layer.name, " (").concat(layer.getClassName(), ")");
  var paramCount = layer.countParams().toString();
  printRow([title, inputShape, outputShape, paramCount], positions, printFn);
}
/**
 * Prints the summary rows of a single Layer, including connectivity
 * information.
 *
 * The first row holds "<name> (<class name>)", the JSON-serialized
 * `inputShapes` of every inbound node (comma-joined), the JSON-serialized
 * output shape, the parameter count and the first inbound connection ('' when
 * there is none). Each further connection gets its own row with only the last
 * column filled. Connections are rendered as
 * `<inbound layer name>[<node index>][<tensor index>]`. A shape that cannot
 * be read or serialized is shown as 'multiple'.
 *
 * @param layer Layer instance to print.
 * @param positions Column end positions, forwarded to printRow.
 * @param relevantNodes When non-empty, only inbound nodes contained in it
 *   contribute connections; null/undefined or empty disables the filter.
 * @param printFn Function that receives each formatted row.
 */
function printLayerSummaryWithConnections(layer, positions, relevantNodes, printFn) {
  // Evaluates `compute`, substituting 'multiple' if it throws.
  var orMultiple = function (compute) {
    try {
      return compute();
    } catch (err) {
      return 'multiple';
    }
  };
  var inputShape = orMultiple(function () {
    return layer.inboundNodes.map(function (node) {
      return JSON.stringify(node.inputShapes);
    }).join(',');
  });
  var outputShape = orMultiple(function () {
    return JSON.stringify(layer.outputShape);
  });

  var filterByRelevance = relevantNodes != null && relevantNodes.length > 0;
  var connections = [];
  var inbound = layer.inboundNodes;
  for (var n = 0; n < inbound.length; n++) {
    var node = inbound[n];
    if (filterByRelevance && relevantNodes.indexOf(node) === -1) {
      continue;
    }
    for (var k = 0; k < node.inboundLayers.length; k++) {
      connections.push("".concat(node.inboundLayers[k].name, "[").concat(node.nodeIndices[k], "][").concat(node.tensorIndices[k], "]"));
    }
  }

  var title = "".concat(layer.name, " (").concat(layer.getClassName(), ")");
  var paramCount = layer.countParams().toString();
  var firstConnection = connections.length > 0 ? connections[0] : '';
  printRow([title, inputShape, outputShape, paramCount, firstConnection], positions, printFn);
  for (var extra = 1; extra < connections.length; extra++) {
    printRow(['', '', '', '', connections[extra]], positions, printFn);
  }
}
64404
64405 /**
64406 * @license
64407 * Copyright 2018 Google LLC
64408 *
64409 * Use of this source code is governed by an MIT-style
64410 * license that can be found in the LICENSE file or at
64411 * https://opensource.org/licenses/MIT.
64412 * =============================================================================
64413 */
64414 // tslint:enable
/**
 * Reports whether an array element is a layer/model name that must be kept
 * verbatim during config conversion.
 *
 * This holds only for a string found at index 0 of an array stored under one
 * of the keys `inboundNodes`, `outputLayers` or `inputLayers`.
 *
 * @param key Key under which the array was found. For nested arrays this is
 *   the nearest enclosing object key, not necessarily the immediate parent.
 * @param index Position of `value` inside its array.
 * @param value The array element to test.
 * @returns `true` when `value` should be treated as a name.
 */
function isArrayItemInputOrOutputName(key, index, value) {
  var nameHoldingKeys = ['inboundNodes', 'outputLayers', 'inputLayers'];
  if (index !== 0 || typeof value !== 'string') {
    return false;
  }
  return nameHoldingKeys.indexOf(key) !== -1;
}
/**
 * Convert a Pythonic config object to TypeScript config object.
 *
 * Strings are camel-cased, numbers and booleans pass through unchanged, and
 * arrays and other objects are converted recursively (object keys are
 * camel-cased too). Two kinds of strings are kept verbatim:
 * - the string value of a `name` key;
 * - the first element of an array found under `inboundNodes`,
 *   `outputLayers` or `inputLayers` (see isArrayItemInputOrOutputName).
 *
 * @param pythonicConfig The config object to convert.
 * @param key Optional key name of the object being converted.
 * @returns Result of the conversion. Both `null` and `undefined` map to
 *   `null`, matching convertTsToPythonic.
 */
function convertPythonicToTs(pythonicConfig, key) {
  if (pythonicConfig === null || pythonicConfig === undefined) {
    // `undefined` used to fall through to Object.keys() below and throw.
    return null;
  } else if (typeof pythonicConfig === 'string') {
    return toCamelCase(pythonicConfig);
  } else if (typeof pythonicConfig === 'number' || typeof pythonicConfig === 'boolean') {
    return pythonicConfig;
  } else if (Array.isArray(pythonicConfig)) {
    // Array.isArray (unlike `instanceof Array`) also recognizes arrays created
    // in another realm, e.g. a vm context or an iframe.
    var tsArray = [];
    var arrayLength = pythonicConfig.length;
    for (var i = 0; i < arrayLength; ++i) {
      var item = pythonicConfig[i];
      if (isArrayItemInputOrOutputName(key, i, item)) {
        tsArray.push(item);
      } else {
        // Nested arrays keep the enclosing key so name detection still applies.
        tsArray.push(convertPythonicToTs(item, key));
      }
    }
    return tsArray;
  } else {
    var tsDict = {};
    var pythonicKeys = Object.keys(pythonicConfig);
    for (var k = 0; k < pythonicKeys.length; k++) {
      var pythonicKey = pythonicKeys[k];
      var pythonicValue = pythonicConfig[pythonicKey];
      if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
        // Special case the 'name' key with a string value. Name values, such as
        // the names of LayersModel and Layer instances, should not undergo the
        // camel-case conversion.
        tsDict[pythonicKey] = pythonicValue;
      } else {
        var tsKey = toCamelCase(pythonicKey);
        tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
      }
    }
    return tsDict;
  }
}
/**
 * Convert a TypeScript config object to Python config object.
 *
 * Strings are snake-cased, numbers and booleans pass through unchanged, and
 * arrays and other objects are converted recursively (object keys are
 * snake-cased too). Identifier strings are kept verbatim:
 * - the string values of the `name` and `className` keys;
 * - the first element of an array found under `inboundNodes`,
 *   `outputLayers` or `inputLayers` (see isArrayItemInputOrOutputName).
 *
 * @param tsConfig The config object to convert.
 * @param key Optional key name of the object being converted.
 * @returns Result of the conversion. Both `null` and `undefined` map to
 *   `null`.
 */
function convertTsToPythonic(tsConfig, key) {
  if (tsConfig === null || tsConfig === undefined) {
    return null;
  } else if (typeof tsConfig === 'string') {
    return toSnakeCase(tsConfig);
  } else if (typeof tsConfig === 'number' || typeof tsConfig === 'boolean') {
    return tsConfig;
  } else if (Array.isArray(tsConfig)) {
    // Array.isArray (unlike `instanceof Array`) also recognizes arrays created
    // in another realm, e.g. a vm context or an iframe.
    var pyArray = [];
    var arrayLength = tsConfig.length;
    for (var i = 0; i < arrayLength; ++i) {
      var item = tsConfig[i];
      if (isArrayItemInputOrOutputName(key, i, item)) {
        pyArray.push(item);
      } else {
        // Nested arrays keep the enclosing key so name detection still applies.
        pyArray.push(convertTsToPythonic(item, key));
      }
    }
    return pyArray;
  } else {
    var pyDict = {};
    var tsKeys = Object.keys(tsConfig);
    for (var k = 0; k < tsKeys.length; k++) {
      var tsKey = tsKeys[k];
      var tsValue = tsConfig[tsKey];
      var pyKey = toSnakeCase(tsKey);
      if ((tsKey === 'name' || tsKey === 'className') && typeof tsValue === 'string') {
        // Special case string values of the 'name' and 'className' keys. These
        // are identifiers (e.g. the names of LayersModel and Layer instances, or
        // class names) and should not undergo the snake-case conversion.
        pyDict[pyKey] = tsValue;
      } else {
        pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
      }
    }
    return pyDict;
  }
}
64513
64514 /** @license See the LICENSE file. */
64515 // This code is auto-generated, do not modify this file!
// Version string of this bundled package; the value is written by the code
// generator mentioned in the comment above.
var version$6 = '4.22.0';
64517
/**
 * Heuristically detects Keras v3-style weight names, e.g. `dense/0`.
 *
 * Only the first key of `weights` is inspected: the text after its last `/`
 * (or the whole key when there is no `/`) must start with a base-10 integer
 * according to parseInt.
 *
 * @param weights Object keyed by weight name; only the keys are read.
 * @returns `false` for an object with no own enumerable keys.
 */
var isKerasSavedModelFormat = function isKerasSavedModelFormat(weights) {
  var firstName = Object.keys(weights)[0];
  if (firstName === undefined) {
    return false;
  }
  var tail = firstName.slice(firstName.lastIndexOf('/') + 1);
  return !isNaN(parseInt(tail, 10));
};
64528 /**
64529 * A Container is a directed acyclic graph of layers.
64530 *
64531 * It is the topological form of a "model". A LayersModel
64532 * is simply a Container with added training routines.
64533 *
64534 */
64535 var Container = /*#__PURE__*/function (_Layer) {
64536 _inherits(Container, _Layer);
64537 var _super = _createSuper(Container);
64538 function Container(args) {
64539 var _this;
64540 _classCallCheck(this, Container);
64541 // No args passed to super's constructor.
64542 _this = _super.call(this, {});
64543 _this.containerNodes = new Set();
64544 _this.name = args.name;
64545 if (_this.name == null) {
64546 var prefix = _this.getClassName().toLowerCase();
64547 _this.name = getUid(prefix);
64548 }
64549 _this.supportsMasking = false;
64550 _this.trainable_ = true;
64551 // TODO(michaelterry): Initialize perInputLosses/Updates here.
64552 // Container-specific properties.
64553 if (Array.isArray(args.inputs)) {
64554 _this.inputs = args.inputs.slice();
64555 } else {
64556 _this.inputs = [args.inputs];
64557 }
64558 if (Array.isArray(args.outputs)) {
64559 _this.outputs = args.outputs.slice();
64560 } else {
64561 _this.outputs = [args.outputs];
64562 }
64563 // Check for redundancy in inputs.
64564 if (unique$2(_this.inputs).length !== _this.inputs.length) {
64565 throw new ValueError('The list of inputs passed to the model is ' + 'redundant. All inputs should only appear once. Found: ' + "".concat(_this.inputs.map(function (x) {
64566 return x.name;
64567 })));
64568 }
64569 // Check for redundancy in outputs.
64570 if (unique$2(_this.outputs).length !== _this.outputs.length) {
64571 console.warn('The list of outputs passed to the model is redundant. ' + 'All outputs should only appear once. Found: ' + "".concat(_this.outputs.map(function (x) {
64572 return x.name;
64573 })));
64574 }
64575 /*
64576 List of initial layers (1 to 1 mapping with this.inputs, hence the same
64577 layer might appear twice)
64578 */
64579 _this.inputLayers = [];
64580 _this.inputLayersNodeIndices = [];
64581 _this.inputLayersTensorIndices = [];
64582 /*
64583 List of layers (1 to 1 mapping with this.outputs, hence the same layer
64584 might appear twice)
64585 */
64586 _this.outputLayers = [];
64587 _this.outputLayersNodeIndices = [];
64588 _this.outputLayersTensorIndices = [];
64589 /*
64590 All layers in order of horizontal graph traversal. Entries are unique.
64591 Includes input and output layers.
64592 */
64593 _this.layers = [];
64594 /*
64595 References to container layers that were constructed internally. We need
64596 these to properly dispose of tensors from nested containers.
64597 */
64598 _this.internalContainerRefs = [];
64599 // TODO(michaelterry): Determine if caching still needed with eager
64600 // backend.
64601 /*
64602 This is for performance optimization when calling the Container on new
64603 inputs. Every time the Container is called on a set on input tensors,
64604 we compute the output tensors, output masks and output shapes in one pass,
64605 then cache them here. When one of these outputs is queried later,
64606 we retrieve it from there instead of recomputing it.
64607 */
64608 // this.outputTensorCache = {};
64609 // this.outputShapeCache = {};
64610 // Build this.outputLayers:
64611 var _iterator = _createForOfIteratorHelper(_this.outputs),
64612 _step;
64613 try {
64614 for (_iterator.s(); !(_step = _iterator.n()).done;) {
64615 var x = _step.value;
64616 var _layer2 = x.sourceLayer;
64617 var nodeIndex = x.nodeIndex;
64618 var tensorIndex = x.tensorIndex;
64619 _this.outputLayers.push(_layer2);
64620 _this.outputLayersNodeIndices.push(nodeIndex);
64621 _this.outputLayersTensorIndices.push(tensorIndex);
64622 }
64623 // TODO(michaelterry): Add output mask cache code.
64624 // Build this.inputLayers:
64625 } catch (err) {
64626 _iterator.e(err);
64627 } finally {
64628 _iterator.f();
64629 }
64630 var _iterator2 = _createForOfIteratorHelper(_this.inputs),
64631 _step2;
64632 try {
64633 for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
64634 var _x = _step2.value;
64635 var _layer3 = _x.sourceLayer;
64636 var _nodeIndex2 = _x.nodeIndex;
64637 var _tensorIndex2 = _x.tensorIndex;
64638 /*
64639 It's supposed to be an input layer, so only one node
64640 and one tensor output.
64641 */
64642 assert(_nodeIndex2 === 0, 'input layer has >1 nodes');
64643 assert(_tensorIndex2 === 0, 'input layer has >1 tensors');
64644 _this.inputLayers.push(_layer3);
64645 _this.inputLayersNodeIndices.push(_nodeIndex2);
64646 _this.inputLayersTensorIndices.push(_tensorIndex2);
64647 }
64648 // Build this.inputNames and this.outputNames.
64649 } catch (err) {
64650 _iterator2.e(err);
64651 } finally {
64652 _iterator2.f();
64653 }
64654 _this.inputNames = [];
64655 _this.outputNames = [];
64656 _this.feedInputShapes = [];
64657 _this.feedInputNames = [];
64658 _this.feedOutputNames = [];
64659 for (var i = 0; i < _this.inputLayers.length; i++) {
64660 var layer = _this.inputLayers[i];
64661 // Check that layer is an InputLayer.
64662 if (!(layer instanceof InputLayer)) {
64663 throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' + "Received inputs: ".concat(args.inputs, ". ") + "Input ".concat(i, " (0-based) originates ") + "from layer type ".concat(layer.getClassName(), "."));
64664 }
64665 _this.inputNames.push(layer.name);
64666 _this.feedInputShapes.push(layer.batchInputShape);
64667 _this.feedInputNames.push(layer.name);
64668 }
64669 var _iterator3 = _createForOfIteratorHelper(_this.outputLayers),
64670 _step3;
64671 try {
64672 for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
64673 var _layer4 = _step3.value;
64674 _this.outputNames.push(_layer4.name);
64675 }
64676 } catch (err) {
64677 _iterator3.e(err);
64678 } finally {
64679 _iterator3.f();
64680 }
64681 _this.internalInputShapes = _this.inputs.map(function (x) {
64682 return x.shape;
64683 });
64684 _this.internalOutputShapes = _this.outputs.map(function (x) {
64685 return x.shape;
64686 });
64687 /*
64688 Container_nodes: set of nodes included in the graph (not all nodes
64689 included in the layers are relevant to the current graph).
64690 */
64691 // ids of all nodes relevant to the Container:
64692 var nodesDepths = {};
64693 // To recover nodes from their ID.
64694 var nodeIDToNode = {};
64695 var layersDepths = {};
64696 // To layers from their ID.
64697 var layerIDToLayer = {};
64698 var layerIndices = {};
64699 var nodesInDecreasingDepth = [];
64700 /**
64701 * Builds a map of the graph of layers.
64702 *
64703 * This recursively updates the map `layerIndices`,
64704 * the list `nodesInDecreasingDepth` and the set `containerNodes`.
64705 *
64706 * @param tensor Some tensor in a graph.
64707 * @param finishedNodes Set of nodes whose subgraphs have been traversed
64708 * completely. Useful to prevent duplicated work.
64709 * @param nodesInProgress Set of nodes that are currently active on the
64710 * recursion stack. Useful to detect cycles.
64711 * @param layer Layer from which `tensor` comes from. If not provided,
64712 * will be obtained from tensor.sourceLayer.
64713 * @param nodeIndex Node index from which `tensor` comes from.
64714 * @param tensorIndex TensorIndex from which `tensor` comes from.
64715 *
64716 * @exception RuntimeError if a cycle is detected.
64717 */
64718 var buildMapOfGraph = function buildMapOfGraph(tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) {
64719 if (layer == null || nodeIndex == null || tensorIndex == null) {
64720 layer = tensor.sourceLayer;
64721 nodeIndex = tensor.nodeIndex;
64722 tensorIndex = tensor.tensorIndex;
64723 }
64724 var node = layer.inboundNodes[nodeIndex];
64725 // Prevent cycles.
64726 if (nodesInProgress.indexOf(node) !== -1) {
64727 throw new RuntimeError("The tensor ".concat(tensor.name, " at layer \"").concat(layer.name, "\" ") + 'is part of a cycle.');
64728 }
64729 // Don't repeat work for shared subgraphs
64730 if (finishedNodes.indexOf(node) !== -1) {
64731 return;
64732 }
64733 // Update containerNodes.
64734 _this.containerNodes.add(Container.nodeKey(layer, nodeIndex));
64735 // Store the traversal order for layer sorting.
64736 if (!(layer.id in layerIndices)) {
64737 layerIndices[layer.id] = Object.keys(layerIndices).length;
64738 }
64739 if (nodesInProgress.indexOf(node) === -1) {
64740 nodesInProgress.push(node);
64741 }
64742 // Propagate to all previous tensors connected to this node.
64743 var numInboundLayers = node.inboundLayers.length;
64744 for (var _i = 0; _i < numInboundLayers; _i++) {
64745 var x = node.inputTensors[_i];
64746 var _layer = node.inboundLayers[_i];
64747 var _nodeIndex = node.nodeIndices[_i];
64748 var _tensorIndex = node.tensorIndices[_i];
64749 buildMapOfGraph(x, finishedNodes, nodesInProgress, _layer, _nodeIndex, _tensorIndex);
64750 }
64751 finishedNodes.push(node);
64752 while (nodesInProgress.indexOf(node) >= 0) {
64753 nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
64754 }
64755 nodesInDecreasingDepth.push(node);
64756 };
64757 var finishedNodes = [];
64758 var nodesInProgress = [];
64759 var _iterator4 = _createForOfIteratorHelper(_this.outputs),
64760 _step4;
64761 try {
64762 for (_iterator4.s(); !(_step4 = _iterator4.n()).done;) {
64763 var _x2 = _step4.value;
64764 buildMapOfGraph(_x2, finishedNodes, nodesInProgress);
64765 }
64766 } catch (err) {
64767 _iterator4.e(err);
64768 } finally {
64769 _iterator4.f();
64770 }
64771 var reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
64772 var _iterator5 = _createForOfIteratorHelper(reversedNodesInDecreasingDepth),
64773 _step5;
64774 try {
64775 for (_iterator5.s(); !(_step5 = _iterator5.n()).done;) {
64776 var node = _step5.value;
64777 nodeIDToNode[node.id] = node;
64778 // If the depth is not set, the node has no outbound nodes (depth 0).
64779 if (!(node.id in nodesDepths)) {
64780 nodesDepths[node.id] = 0;
64781 }
64782 var _depth2 = nodesDepths[node.id];
64783 // Update the depth of the corresponding layer
64784 var previousDepth = layersDepths[node.outboundLayer.id] == null ? 0 : layersDepths[node.outboundLayer.id];
64785 /*
64786 If we've seen this layer before at a higher depth, we should use that
64787 depth instead of the node depth. This is necessary for shared layers
64788 that have inputs at different depth levels in the graph.
64789 */
64790 _depth2 = Math.max(_depth2, previousDepth);
64791 layersDepths[node.outboundLayer.id] = _depth2;
64792 layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
64793 nodesDepths[node.id] = _depth2;
64794 // Update the depth of inbound nodes.
64795 for (var _i2 = 0; _i2 < node.inboundLayers.length; _i2++) {
64796 var inboundLayer = node.inboundLayers[_i2];
64797 var _nodeIndex3 = node.nodeIndices[_i2];
64798 var inboundNode = inboundLayer.inboundNodes[_nodeIndex3];
64799 var _previousDepth = nodesDepths[inboundNode.id] == null ? 0 : nodesDepths[inboundNode.id];
64800 nodesDepths[inboundNode.id] = Math.max(_depth2 + 1, _previousDepth);
64801 nodeIDToNode[inboundNode.id] = inboundNode;
64802 }
64803 }
64804 // Build a dict {depth: list of nodes with this depth}
64805 } catch (err) {
64806 _iterator5.e(err);
64807 } finally {
64808 _iterator5.f();
64809 }
64810 var nodesByDepth = {};
64811 for (var nodeID in nodesDepths) {
64812 var depth = nodesDepths[nodeID];
64813 if (!(depth in nodesByDepth)) {
64814 nodesByDepth[depth] = [];
64815 }
64816 nodesByDepth[depth].push(nodeIDToNode[nodeID]);
64817 }
64818 // Build a dict {depth: list of layers with this depth}
64819 var layersByDepth = {};
64820 for (var layerID in layersDepths) {
64821 var _depth = layersDepths[layerID];
64822 if (!(_depth in layersByDepth)) {
64823 layersByDepth[_depth] = [];
64824 }
64825 layersByDepth[_depth].push(layerIDToLayer[layerID]);
64826 }
64827 // Get sorted list of layer depths.
64828 var depthKeys = Object.keys(layersByDepth).map(function (x) {
64829 return parseInt(x, 10);
64830 }).sort(reverseNumberCompare);
64831 // Set this.layers and this.layersByDepth.
64832 _this.layers = [];
64833 var _iterator6 = _createForOfIteratorHelper(depthKeys),
64834 _step6;
64835 try {
64836 for (_iterator6.s(); !(_step6 = _iterator6.n()).done;) {
64837 var _depth3 = _step6.value;
64838 var layersForDepth = layersByDepth[_depth3];
64839 // Container.layers needs to have a deterministic order:
64840 // here we order them by traversal order.
64841 layersForDepth.sort(function (a, b) {
64842 var aIndex = layerIndices[a.id];
64843 var bIndex = layerIndices[b.id];
64844 if (aIndex < bIndex) {
64845 return -1;
64846 }
64847 if (aIndex > bIndex) {
64848 return 1;
64849 }
64850 return 0;
64851 });
64852 var _iterator9 = _createForOfIteratorHelper(layersForDepth),
64853 _step9;
64854 try {
64855 for (_iterator9.s(); !(_step9 = _iterator9.n()).done;) {
64856 var _layer5 = _step9.value;
64857 if (_layer5 instanceof Container) {
64858 _this.internalContainerRefs.push(_layer5);
64859 }
64860 _this.layers.push(_layer5);
64861 }
64862 } catch (err) {
64863 _iterator9.e(err);
64864 } finally {
64865 _iterator9.f();
64866 }
64867 }
64868 } catch (err) {
64869 _iterator6.e(err);
64870 } finally {
64871 _iterator6.f();
64872 }
64873 _this.layersByDepth = layersByDepth;
64874 // Get sorted list of node depths;
64875 depthKeys = Object.keys(nodesByDepth).map(function (x) {
64876 return parseInt(x, 10);
64877 }).sort(reverseNumberCompare);
64878 // Check that all tensors required are computable.
64879 // computable_tensors: all tensors in the graph
64880 // that can be computed from the inputs provided.
64881 var computableTensors = _this.inputs.slice();
64882 // To provide a better error msg.
64883 var layersWithCompleteInput = [];
64884 var _iterator7 = _createForOfIteratorHelper(depthKeys),
64885 _step7;
64886 try {
64887 for (_iterator7.s(); !(_step7 = _iterator7.n()).done;) {
64888 var _depth4 = _step7.value;
64889 var _iterator10 = _createForOfIteratorHelper(nodesByDepth[_depth4]),
64890 _step10;
64891 try {
64892 for (_iterator10.s(); !(_step10 = _iterator10.n()).done;) {
64893 var _node = _step10.value;
64894 var _layer6 = _node.outboundLayer;
64895 if (_layer6 != null) {
64896 var _iterator11 = _createForOfIteratorHelper(_node.inputTensors),
64897 _step11;
64898 try {
64899 for (_iterator11.s(); !(_step11 = _iterator11.n()).done;) {
64900 var _x3 = _step11.value;
64901 if (computableTensors.indexOf(_x3) === -1) {
64902 throw new RuntimeError("Graph disconnected: cannot obtain value for tensor ".concat(_x3) + " at layer \"".concat(_layer6.name, "\". ") + 'The following previous layers were accessed without ' + "issue: ".concat(layersWithCompleteInput));
64903 }
64904 }
64905 } catch (err) {
64906 _iterator11.e(err);
64907 } finally {
64908 _iterator11.f();
64909 }
64910 var _iterator12 = _createForOfIteratorHelper(_node.outputTensors),
64911 _step12;
64912 try {
64913 for (_iterator12.s(); !(_step12 = _iterator12.n()).done;) {
64914 var _x4 = _step12.value;
64915 computableTensors.push(_x4);
64916 }
64917 } catch (err) {
64918 _iterator12.e(err);
64919 } finally {
64920 _iterator12.f();
64921 }
64922 layersWithCompleteInput.push(_layer6.name);
64923 }
64924 }
64925 } catch (err) {
64926 _iterator10.e(err);
64927 } finally {
64928 _iterator10.f();
64929 }
64930 }
64931 // Set this.containerNodes and this.nodesByDepth.
64932 } catch (err) {
64933 _iterator7.e(err);
64934 } finally {
64935 _iterator7.f();
64936 }
64937 _this.nodesByDepth = nodesByDepth;
64938 // Ensure name unicity, which will be crucial for serialization
64939 // (since serialized nodes refer to layers by their name).
64940 var allNames = _this.layers.map(function (x) {
64941 return x.name;
64942 });
64943 var _iterator8 = _createForOfIteratorHelper(allNames),
64944 _step8;
64945 try {
64946 var _loop = function _loop() {
64947 var name = _step8.value;
64948 var numOccurrences = allNames.filter(function (x) {
64949 return x === name;
64950 }).length;
64951 if (numOccurrences !== 1) {
64952 throw new RuntimeError("The name \"".concat(name, "\" is used ").concat(numOccurrences, " times ") + 'in the model. All layer names should be unique. Layer names: ' + JSON.stringify(allNames));
64953 }
64954 };
64955 for (_iterator8.s(); !(_step8 = _iterator8.n()).done;) {
64956 _loop();
64957 }
64958 // Layer parameters.
64959 // The new container starts with a single inbound node
64960 // for its inputs, and no outbound nodes.
64961 // Will be appended to by future calls to apply().
64962 } catch (err) {
64963 _iterator8.e(err);
64964 } finally {
64965 _iterator8.f();
64966 }
64967 _this.outboundNodes = [];
64968 // Will be appended to below, and by future calls to apply().
64969 _this.inboundNodes = [];
64970 // Create the node linking internal inputs to internal outputs.
64971 // (This call has side effects.)
64972 // tslint:disable-next-line:no-unused-expression
64973 new Node({
64974 outboundLayer: _assertThisInitialized(_this),
64975 inboundLayers: [],
64976 nodeIndices: [],
64977 tensorIndices: [],
64978 inputTensors: _this.inputs,
64979 outputTensors: _this.outputs,
64980 inputMasks: _this.inputs.map(function (x) {
64981 return null;
64982 }),
64983 outputMasks: _this.outputs.map(function (x) {
64984 return null;
64985 }),
64986 inputShapes: _this.inputs.map(function (x) {
64987 return x.shape;
64988 }),
64989 outputShapes: _this.outputs.map(function (x) {
64990 return x.shape;
64991 })
64992 });
64993 _this.built = true;
64994 _this._refCount = 1; // The ref count of a container always start at 1.
64995 return _this;
64996 }
64997 _createClass(Container, [{
64998 key: "assertNotDisposed",
64999 value: function assertNotDisposed() {
65000 if (this._refCount === 0) {
65001 throw new Error("Container '".concat(this.name, "' is already disposed."));
65002 }
65003 }
65004 /**
65005 * Attempt to dispose a LayersModel's weights.
65006 *
 * This method decreases the reference count of the LayersModel object by 1.
65008 *
65009 * A LayersModel is reference-counted. Its reference count is incremented by 1
65010 * when it is first constructed and when it is used as a Layer of another
65011 * LayersModel.
65012 *
65013 * If the reference count of a LayersModel becomes 0, the `dispose` method of
65014 * all its constituent `Layer`s will be called.
65015 *
65016 * Note: If the reference count is greater than 0 after the decrement, the
65017 * `dispose` method of its constituent `Layer`s will *not* be called.
65018 *
65019 * After a LayersModel is disposed, it cannot be used in calls such as
 * `predict`, `evaluate` or `fit` anymore.
65021 *
65022 * @returns A DisposeResult Object with the following fields:
65023 * - refCountAfterDispose: The reference count of the LayersModel after this
65024 * `dispose()` call.
65025 * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
65026 * during this `dispose()` call.
65027 * @throws {Error} If the layer is not built yet, or if the LayersModel has
65028 * already been disposed.
65029 */
65030 }, {
65031 key: "dispose",
65032 value: function dispose() {
65033 this.assertNotDisposed();
65034 var result = {
65035 refCountAfterDispose: null,
65036 numDisposedVariables: 0
65037 };
65038 if (--this._refCount === 0) {
65039 var _iterator13 = _createForOfIteratorHelper(this.layers),
65040 _step13;
65041 try {
65042 for (_iterator13.s(); !(_step13 = _iterator13.n()).done;) {
65043 var layer = _step13.value;
65044 result.numDisposedVariables += layer.dispose().numDisposedVariables;
65045 }
65046 // Call dispose on each internally created container layer again to ensure
65047 // their refCounts hit zero and their tensors are subsequently deleted.
65048 } catch (err) {
65049 _iterator13.e(err);
65050 } finally {
65051 _iterator13.f();
65052 }
65053 var _iterator14 = _createForOfIteratorHelper(this.internalContainerRefs),
65054 _step14;
65055 try {
65056 for (_iterator14.s(); !(_step14 = _iterator14.n()).done;) {
65057 var container = _step14.value;
65058 result.numDisposedVariables += container.dispose().numDisposedVariables;
65059 }
65060 } catch (err) {
65061 _iterator14.e(err);
65062 } finally {
65063 _iterator14.f();
65064 }
65065 }
65066 result.refCountAfterDispose = this._refCount;
65067 return result;
65068 }
65069 }, {
65070 key: "trainable",
65071 get: function get() {
65072 return this.trainable_;
65073 },
65074 set: function set(trainable) {
65075 this.layers.forEach(function (layer) {
65076 // tslint:disable-next-line:no-any
65077 layer._trainableWeights.forEach(function (w) {
65078 return w.trainable = trainable;
65079 });
65080 });
65081 this.trainable_ = trainable;
65082 }
65083 }, {
65084 key: "trainableWeights",
65085 get: function get() {
65086 // Porting Note: This check below is to prevent errors where the
65087 // _trainableWeights inherited from the parent class (Layer) gets
65088 // inadvertently used.
65089 if (this._trainableWeights.length > 0) {
65090 throw new ValueError('Container instance unexpectedly contains _trainableWeights.' + 'The trainable weights of a Container are a union of the ' + 'trainable weights of its consituent Layers. Its own ' + '_trainableWeights must remain an empty Array.');
65091 }
65092 if (!this.trainable) {
65093 return [];
65094 }
65095 var weights = [];
65096 var _iterator15 = _createForOfIteratorHelper(this.layers),
65097 _step15;
65098 try {
65099 for (_iterator15.s(); !(_step15 = _iterator15.n()).done;) {
65100 var layer = _step15.value;
65101 weights = weights.concat(layer.trainableWeights);
65102 }
65103 } catch (err) {
65104 _iterator15.e(err);
65105 } finally {
65106 _iterator15.f();
65107 }
65108 return weights;
65109 }
65110 }, {
65111 key: "nonTrainableWeights",
65112 get: function get() {
65113 var weights = [];
65114 var _iterator16 = _createForOfIteratorHelper(this.layers),
65115 _step16;
65116 try {
65117 for (_iterator16.s(); !(_step16 = _iterator16.n()).done;) {
65118 var _layer7 = _step16.value;
65119 weights.push.apply(weights, _toConsumableArray(_layer7.nonTrainableWeights));
65120 }
65121 } catch (err) {
65122 _iterator16.e(err);
65123 } finally {
65124 _iterator16.f();
65125 }
65126 if (!this.trainable) {
65127 var trainableWeights = [];
65128 var _iterator17 = _createForOfIteratorHelper(this.layers),
65129 _step17;
65130 try {
65131 for (_iterator17.s(); !(_step17 = _iterator17.n()).done;) {
65132 var layer = _step17.value;
65133 trainableWeights.push.apply(trainableWeights, _toConsumableArray(layer.trainableWeights));
65134 }
65135 } catch (err) {
65136 _iterator17.e(err);
65137 } finally {
65138 _iterator17.f();
65139 }
65140 return trainableWeights.concat(weights);
65141 }
65142 return weights;
65143 }
65144 }, {
65145 key: "weights",
65146 get: function get() {
65147 return this.trainableWeights.concat(this.nonTrainableWeights);
65148 }
65149 /**
65150 * Loads all layer weights from a JSON object.
65151 *
65152 * Porting Note: HDF5 weight files cannot be directly loaded in JavaScript /
65153 * TypeScript. The utility script at `scripts/pykeras.py` offers means
65154 * to convert them into JSON strings compatible with this method.
65155 * Porting Note: TensorFlow.js Layers supports only loading by name currently.
65156 *
65157 * @param weights A JSON mapping weight names to weight values as nested
65158 * arrays of numbers, or a `NamedTensorMap`, i.e., a JSON mapping weight
65159 * names to `tf.Tensor` objects.
65160 * @param strict Require that the provided weights exactly match those
65161 * required by the container. Default: `true`. Passing `false` means that
65162 * extra weights and missing weights will be silently ignored.
65163 */
65164 }, {
65165 key: "loadWeights",
65166 value: function loadWeights(weights) {
65167 var strict = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
65168 var nameToWeight = {};
65169 var totalWeightsCount = 0;
65170 var modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights);
65171 if (modelIsKerasSavedModelFormat) {
65172 this.parseWeights(weights);
65173 }
65174 // Check if weights from keras v3.
65175 var _iterator18 = _createForOfIteratorHelper(this.layers),
65176 _step18;
65177 try {
65178 for (_iterator18.s(); !(_step18 = _iterator18.n()).done;) {
65179 var layer = _step18.value;
65180 var _iterator19 = _createForOfIteratorHelper(layer.weights.entries()),
65181 _step19;
65182 try {
65183 for (_iterator19.s(); !(_step19 = _iterator19.n()).done;) {
65184 var _step19$value = _slicedToArray(_step19.value, 2),
65185 index = _step19$value[0],
65186 weight = _step19$value[1];
65187 // Parse the name to layerName/index.
65188 // e.g. dense/0, dense/1, dense_1/0, dense_1/1
65189 var parsedName = modelIsKerasSavedModelFormat ? "".concat(weight.name.split('/').slice(0, -1).join('/') + '/').concat(index) : weight.originalName;
65190 if (nameToWeight[parsedName] != null) {
65191 throw new ValueError("Duplicate weight name: ".concat(parsedName));
65192 }
65193 nameToWeight[parsedName] = weight;
65194 totalWeightsCount++;
65195 }
65196 } catch (err) {
65197 _iterator19.e(err);
65198 } finally {
65199 _iterator19.f();
65200 }
65201 }
65202 } catch (err) {
65203 _iterator18.e(err);
65204 } finally {
65205 _iterator18.f();
65206 }
65207 var weightValueTuples = [];
65208 for (var name in weights) {
65209 // TF 2.2.0 added cell name to the weight name in the format of
65210 // layer_name/cell_name/weight_name, we need to remove
65211 // the inner cell name.
65212 var validatedName = name;
65213 if (nameToWeight[name] == null) {
65214 var tokens = name.split('/');
65215 var shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]);
65216 validatedName = shortenNameArray.join('/');
65217 }
65218 if (nameToWeight[validatedName] != null) {
65219 weightValueTuples.push([nameToWeight[validatedName], weights[name]]);
65220 } else if (strict) {
65221 throw new ValueError("Provided weight data has no target variable: ".concat(name));
65222 }
65223 delete nameToWeight[validatedName];
65224 }
65225 if (strict) {
65226 // Check that all weights are set.
65227 var unsetNames = [];
65228 for (var _name in nameToWeight) {
65229 unsetNames.push(_name);
65230 }
65231 if (unsetNames.length > 0) {
65232 throw new ValueError("".concat(unsetNames.length, " of ").concat(totalWeightsCount, " weights are not set: ") + "".concat(unsetNames));
65233 }
65234 }
65235 batchSetValue(weightValueTuples);
65236 }
65237 }, {
65238 key: "parseWeights",
65239 value: function parseWeights(weights) {
65240 var _loop2 = function _loop2() {
65241 var listParts = key.split('/');
65242 var list = ['vars', 'layer_checkpoint_dependencies'];
65243 // For keras v3, the weights name are saved based on the folder structure.
65244 // e.g. _backbone/_layer_checkpoint_dependencies/transformer/_self../
65245 // _output_dense/vars/0
65246 // Therefore we discard the `vars` and `layer_checkpoint_depencies` within
65247 // the saved name and only keeps the layer name and weights.
65248 // This can help to mapping the actual name of the layers and load each
65249 // weight accordingly.
65250 var newKey = listParts.map(function (str) {
65251 if (str.startsWith('_')) {
65252 return str.slice(1);
65253 }
65254 return str;
65255 }).filter(function (str) {
65256 return !list.includes(str);
65257 }).join('/');
65258 if (newKey !== key) {
65259 weights[newKey] = weights[key];
65260 delete weights[key];
65261 }
65262 };
65263 for (var key in Object.keys(weights)) {
65264 _loop2();
65265 }
65266 }
65267 /**
65268 * Util shared between different serialization methods.
65269 * @returns LayersModel config with Keras version information added.
65270 */
65271 }, {
65272 key: "updatedConfig",
65273 value: function updatedConfig() {
65274 var theConfig = this.getConfig();
65275 var modelConfig = {};
65276 modelConfig['className'] = this.getClassName();
65277 modelConfig['config'] = theConfig;
65278 modelConfig['kerasVersion'] = "tfjs-layers ".concat(version$6);
65279 // TODO(nielsene): Replace something like K.backend() once
65280 // possible.
65281 modelConfig['backend'] = 'TensorFlow.js';
65282 return modelConfig;
65283 }
65284 /**
65285 * Returns a JSON string containing the network configuration.
65286 *
65287 * To load a network from a JSON save file, use
65288 * models.modelFromJSON(jsonString);
 * @param unused Unused in tfjs-layers, maintained for PyKeras
65290 * @param returnString Whether the return value should be stringified
65291 * (default: `true`).
65292 * @returns a JSON string if `returnString` (default), or a JSON object if
65293 * `!returnString`.
65294 */
65295 // tslint:disable-next-line:no-any
65296 }, {
65297 key: "toJSON",
65298 value: function toJSON(unused) {
65299 var returnString = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
65300 var modelConfig = convertTsToPythonic(this.updatedConfig());
65301 return returnString ? JSON.stringify(modelConfig) : modelConfig;
65302 }
65303 /**
65304 * Call the model on new inputs.
65305 *
65306 * In this case `call` just reapplies all ops in the graph to the new inputs
65307 * (e.g. build a new computational graph from the provided inputs).
65308 *
65309 * @param inputs A tensor or list of tensors.
 * @param kwargs Additional keyword arguments, passed through unchanged to
 * the graph executor.
65312 *
65313 * @return A tensor if there is a single output, or a list of tensors if there
65314 * are more than one outputs.
65315 */
65316 }, {
65317 key: "call",
65318 value: function call(inputs, kwargs) {
65319 var _this2 = this;
65320 return tidy(function () {
65321 inputs = toList(inputs);
65322 var feedDict = new FeedDict();
65323 for (var i = 0; i < _this2.inputs.length; ++i) {
65324 feedDict.add(_this2.inputs[i], inputs[i]);
65325 }
65326 return execute(_this2.outputs, feedDict, kwargs);
65327 });
65328 }
65329 /**
65330 * Computes an output mask tensor.
65331 *
65332 * @param inputs Tensor or list of tensors.
65333 * @param mask Tensor or list of tensors.
65334 *
65335 * @return null or a tensor (or list of tensors, one per output tensor of the
65336 * layer).
65337 */
65338 }, {
65339 key: "computeMask",
65340 value: function computeMask(inputs, mask) {
65341 var _this3 = this;
65342 return tidy(function () {
65343 inputs = toList(inputs);
65344 var masks;
65345 if (mask == null) {
65346 masks = pyListRepeat(null, inputs.length);
65347 } else {
65348 masks = toList(mask);
65349 }
65350 // TODO(michaelterry): Add support for mask caching.
65351 return _this3.runInternalGraph(inputs, masks)[1];
65352 });
65353 }
65354 /**
65355 * Computes the output shape of the layer.
65356 *
65357 * Assumes that the layer will be built to match that input shape provided.
65358 *
65359 * @param inputShape A shape (tuple of integers) or a list of shape tuples
 * (one per input tensor of the layer). Shape tuples can include null for
65361 * free dimensions, instead of an integer.
65362 */
65363 }, {
65364 key: "computeOutputShape",
65365 value: function computeOutputShape(inputShape) {
65366 var inputShapes = normalizeShapeList(inputShape);
65367 if (inputShapes.length !== this.inputLayers.length) {
65368 throw new ValueError("Invalid inputShape argument ".concat(inputShape, ": ") + "model has ".concat(this.inputLayers.length, " tensor inputs."));
65369 }
65370 // TODO(michaelterry): Add caching
65371 var layersToOutputShapes = {};
65372 for (var i = 0; i < inputShapes.length; i++) {
65373 var layer = this.inputLayers[i];
65374 var _inputShape = inputShapes[i];
65375 // It's an input layer: computeOutputShape is identity,
65376 // and there is only one node and one tensor output.
65377 var shapeKey = layer.name + '_0_0';
65378 layersToOutputShapes[shapeKey] = _inputShape;
65379 }
65380 var depthKeys = Object.keys(this.nodesByDepth).map(function (x) {
65381 return parseInt(x, 10);
65382 }).sort(reverseNumberCompare);
65383 // Iterate over nodes, by depth level.
65384 if (depthKeys.length > 1) {
65385 var _iterator20 = _createForOfIteratorHelper(depthKeys),
65386 _step20;
65387 try {
65388 for (_iterator20.s(); !(_step20 = _iterator20.n()).done;) {
65389 var depth = _step20.value;
65390 var nodes = this.nodesByDepth[depth];
65391 var _iterator21 = _createForOfIteratorHelper(nodes),
65392 _step21;
65393 try {
65394 for (_iterator21.s(); !(_step21 = _iterator21.n()).done;) {
65395 var node = _step21.value;
65396 // This is always a single layer, never a list.
65397 var _layer8 = node.outboundLayer;
65398 if (this.inputLayers.map(function (x) {
65399 return x.id;
65400 }).indexOf(_layer8.id) !== -1) {
65401 // We've already covered the input layers a few lines above.
65402 continue;
65403 }
65404 // Potentially redundant list, same size of node.inputTensors.
65405 var _inputShapes = [];
65406 for (var j = 0; j < node.inboundLayers.length; j++) {
65407 var inboundLayer = node.inboundLayers[j];
65408 var _nodeIndex4 = node.nodeIndices[j];
65409 var tensorIndex = node.tensorIndices[j];
65410 var _shapeKey = "".concat(inboundLayer.name, "_").concat(_nodeIndex4, "_").concat(tensorIndex);
65411 var _inputShape2 = layersToOutputShapes[_shapeKey];
65412 _inputShapes.push(_inputShape2);
65413 }
65414 var outputShape = _layer8.computeOutputShape(singletonOrArray(_inputShapes));
65415 var _outputShapes = normalizeShapeList(outputShape);
65416 var nodeIndex = _layer8.inboundNodes.indexOf(node);
65417 for (var _j = 0; _j < _outputShapes.length; _j++) {
65418 var _shapeKey2 = "".concat(_layer8.name, "_").concat(nodeIndex, "_").concat(_j);
65419 layersToOutputShapes[_shapeKey2] = _outputShapes[_j];
65420 }
65421 }
65422 } catch (err) {
65423 _iterator21.e(err);
65424 } finally {
65425 _iterator21.f();
65426 }
65427 }
65428 } catch (err) {
65429 _iterator20.e(err);
65430 } finally {
65431 _iterator20.f();
65432 }
65433 }
65434 // Read final output shapes from layersToOutputShapes.
65435 var outputShapes = [];
65436 var outputShapeKeys = [];
65437 for (var _i3 = 0; _i3 < this.outputLayers.length; _i3++) {
65438 var _layer9 = this.outputLayers[_i3];
65439 var _nodeIndex5 = this.outputLayersNodeIndices[_i3];
65440 var _tensorIndex3 = this.outputLayersTensorIndices[_i3];
65441 var _shapeKey3 = "".concat(_layer9.name, "_").concat(_nodeIndex5, "_").concat(_tensorIndex3);
65442 outputShapeKeys.push(_shapeKey3);
65443 }
65444 for (var _i4 = 0; _i4 < outputShapeKeys.length; _i4++) {
65445 var key = outputShapeKeys[_i4];
65446 assert(key in layersToOutputShapes);
65447 outputShapes.push(layersToOutputShapes[key]);
65448 }
65449 // TODO(michaelterry): Update cache
65450 return singletonOrArray(outputShapes);
65451 }
65452 /**
65453 * Computes output tensors for new inputs.
65454 *
65455 * Note:
65456 * - Expects `inputs` to be a list (potentially with 1 element).
65457 *
65458 * @param inputs List of tensors
65459 * @param masks List of masks (tensors or null).
65460 * @return Three lists: outputTensors, outputMasks, outputShapes
65461 */
65462 }, {
65463 key: "runInternalGraph",
65464 value: function runInternalGraph(inputs, masks) {
65465 if (masks == null) {
65466 masks = pyListRepeat(null, inputs.length);
65467 }
65468 // Dictionary mapping reference tensors to tuples
65469 // (computed tensor, compute mask)
65470 // we assume a 1:1 mapping from tensor to mask
65471 // TODO: raise exception when a `.computeMask()` call
65472 // does not return a list the same size as `call`
65473 var tensorMap = {};
65474 for (var i = 0; i < this.inputs.length; ++i) {
65475 var x = this.inputs[i];
65476 var y = inputs[i];
65477 var mask = masks[i];
65478 tensorMap[x.id] = [y, mask];
65479 }
65480 var depthKeys = Object.keys(this.nodesByDepth).map(function (x) {
65481 return parseInt(x, 10);
65482 }).sort(reverseNumberCompare);
65483 var _iterator22 = _createForOfIteratorHelper(depthKeys),
65484 _step22;
65485 try {
65486 for (_iterator22.s(); !(_step22 = _iterator22.n()).done;) {
65487 var depth = _step22.value;
65488 var nodes = this.nodesByDepth[depth];
65489 var _iterator24 = _createForOfIteratorHelper(nodes),
65490 _step24;
65491 try {
65492 for (_iterator24.s(); !(_step24 = _iterator24.n()).done;) {
65493 var node = _step24.value;
65494 // This is always a single layer, never a list.
65495 var layer = node.outboundLayer;
65496 var referenceInputTensors = node.inputTensors;
65497 var referenceOutputTensors = node.outputTensors;
65498 // If all previous input tensors are available in tensorMap,
65499 // then call node.inboundLayer on them.
65500 // List of tuples [input, mask]:
65501 var computedData = new Array();
65502 var _iterator25 = _createForOfIteratorHelper(referenceInputTensors),
65503 _step25;
65504 try {
65505 for (_iterator25.s(); !(_step25 = _iterator25.n()).done;) {
65506 var _x6 = _step25.value;
65507 if (_x6.id in tensorMap) {
65508 computedData.push(tensorMap[_x6.id]);
65509 }
65510 }
65511 } catch (err) {
65512 _iterator25.e(err);
65513 } finally {
65514 _iterator25.f();
65515 }
65516 if (computedData.length === referenceInputTensors.length) {
65517 // TODO(michaelterry): Add K.name_scope here, if we need it.
65518 var kwargs = {};
65519 var computedTensors = void 0;
65520 var computedMasks = void 0;
65521 var _outputTensors = void 0;
65522 var _outputMasks = void 0;
65523 // call layer
65524 if (node.callArgs != null) {
65525 kwargs = node.callArgs;
65526 }
65527 if (computedData.length === 1) {
65528 var _computedData$ = _slicedToArray(computedData[0], 2),
65529 computedTensor = _computedData$[0],
65530 computedMask = _computedData$[1];
65531 if (kwargs['mask'] == null) {
65532 kwargs['mask'] = computedMask;
65533 }
65534 _outputTensors = toList(layer.call(computedTensor, kwargs));
65535 _outputMasks = toList(layer.computeMask(computedTensor, computedMask));
65536 computedTensors = [computedTensor];
65537 computedMasks = [computedMask];
65538 } else {
65539 computedTensors = computedData.map(function (x) {
65540 return x[0];
65541 });
65542 computedMasks = computedData.map(function (x) {
65543 return x[1];
65544 });
65545 if (kwargs['mask'] == null) {
65546 kwargs['mask'] = computedMasks;
65547 }
65548 _outputTensors = toList(layer.call(computedTensors, kwargs));
65549 _outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
65550 }
65551 if (layer.activityRegularizer) {
65552 throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' + 'presence of activity regularizer(s) is not supported yet.');
65553 }
65554 // TODO(michaelterry): Add model updates and losses
65555 // Update tensor map.
65556 for (var _i5 = 0; _i5 < referenceOutputTensors.length; ++_i5) {
65557 var _x5 = referenceOutputTensors[_i5];
65558 var _y = _outputTensors[_i5];
65559 var _mask = _outputMasks[_i5];
65560 tensorMap[_x5.id] = [_y, _mask];
65561 }
65562 }
65563 }
65564 } catch (err) {
65565 _iterator24.e(err);
65566 } finally {
65567 _iterator24.f();
65568 }
65569 }
65570 } catch (err) {
65571 _iterator22.e(err);
65572 } finally {
65573 _iterator22.f();
65574 }
65575 var outputTensors = [];
65576 var outputMasks = [];
65577 var outputShapes = [];
65578 var _iterator23 = _createForOfIteratorHelper(this.outputs),
65579 _step23;
65580 try {
65581 for (_iterator23.s(); !(_step23 = _iterator23.n()).done;) {
65582 var _x7 = _step23.value;
65583 assert(_x7.id in tensorMap, "Could not compute output ".concat(_x7.name, " : ").concat(_x7.id));
65584 var _tensorMap$_x7$id = _slicedToArray(tensorMap[_x7.id], 2),
65585 tensor = _tensorMap$_x7$id[0],
65586 _mask2 = _tensorMap$_x7$id[1];
65587 outputShapes.push(tensor.shape);
65588 outputTensors.push(tensor);
65589 outputMasks.push(_mask2);
65590 }
65591 // TODO(michaelterry): Add support for caches.
65592 } catch (err) {
65593 _iterator23.e(err);
65594 } finally {
65595 _iterator23.f();
65596 }
65597 return [outputTensors, outputMasks, outputShapes];
65598 }
65599 /**
65600 * Builds a map of internal node keys to node ordering.
 * Used in serialization, as node orderings may change when unused nodes are
65602 * dropped. Porting Note: This helper method was pulled out of getConfig to
65603 * improve readability.
65604 * @param layers An array of Layers in the model.
65605 * @returns Map of Node Keys to index order within the layer.
65606 */
65607 }, {
65608 key: "buildNodeConversionMap",
65609 value: function buildNodeConversionMap(layers) {
65610 var nodeConversionMap = {};
65611 var keptNodes;
65612 var _iterator26 = _createForOfIteratorHelper(this.layers),
65613 _step26;
65614 try {
65615 for (_iterator26.s(); !(_step26 = _iterator26.n()).done;) {
65616 var layer = _step26.value;
65617 keptNodes = layer instanceof Container ? 1 : 0;
65618 for (var originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
65619 var nodeKey = Container.nodeKey(layer, originalNodeIndex);
65620 if (this.containerNodes.has(nodeKey)) {
65621 // i.e. we mark it to be saved
65622 nodeConversionMap[nodeKey] = keptNodes;
65623 keptNodes += 1;
65624 }
65625 }
65626 }
65627 } catch (err) {
65628 _iterator26.e(err);
65629 } finally {
65630 _iterator26.f();
65631 }
65632 return nodeConversionMap;
65633 }
65634 }, {
65635 key: "getLayer",
65636 value: function getLayer(nameOrIndex, index) {
65637 if (index != null) {
65638 return this.findLayer(index);
65639 } else {
65640 if (nameOrIndex == null) {
65641 throw new ValueError('Provide either a layer name or layer index');
65642 }
65643 if (typeof nameOrIndex === 'number') {
65644 return this.findLayer(nameOrIndex);
65645 }
65646 }
65647 var _iterator27 = _createForOfIteratorHelper(this.layers),
65648 _step27;
65649 try {
65650 for (_iterator27.s(); !(_step27 = _iterator27.n()).done;) {
65651 var layer = _step27.value;
65652 if (layer.name === nameOrIndex) {
65653 return layer;
65654 }
65655 }
65656 } catch (err) {
65657 _iterator27.e(err);
65658 } finally {
65659 _iterator27.f();
65660 }
65661 throw new ValueError("No such layer: ".concat(nameOrIndex));
65662 }
65663 }, {
65664 key: "findLayer",
65665 value: function findLayer(index) {
65666 if (this.layers.length <= index) {
65667 throw new ValueError("Was asked to retrieve layer at index ".concat(index, ", but model only ") + "has ".concat(this.layers.length, " layer(s)."));
65668 } else {
65669 return this.layers[index];
65670 }
65671 }
65672 /**
65673 * Retrieves the Container's current loss values.
65674 *
65675 * Used for regularizers during training.
65676 */
65677 }, {
65678 key: "calculateLosses",
65679 value: function calculateLosses() {
65680 var _this4 = this;
65681 // Porting Node: This is an augmentation to Container.loss in PyKeras.
65682 // In PyKeras, Container.loss returns symbolic tensors. Here a concrete
65683 // Tensor (specifically Scalar) values are returned. This is due to the
65684 // imperative backend.
65685 return tidy(function () {
65686 var losses = [];
65687 var _iterator28 = _createForOfIteratorHelper(_this4.layers),
65688 _step28;
65689 try {
65690 for (_iterator28.s(); !(_step28 = _iterator28.n()).done;) {
65691 var layer = _step28.value;
65692 for (var nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
65693 var nodeKey = Container.nodeKey(layer, nodeIndex);
65694 if (_this4.containerNodes.has(nodeKey)) {
65695 losses.push.apply(losses, _toConsumableArray(layer.calculateLosses()));
65696 }
65697 }
65698 }
65699 // TODO(cais): Add any unconditional model-level losses?
65700 } catch (err) {
65701 _iterator28.e(err);
65702 } finally {
65703 _iterator28.f();
65704 }
65705 return losses;
65706 });
65707 }
65708 }, {
65709 key: "getConfig",
65710 value: function getConfig() {
65711 var config = {
65712 name: this.name
65713 };
65714 // Build a map from layer unique name (self._node_key)
65715 // to the index of the nodes that are saved in the config.
65716 // Only nodes in container_nodes are saved.
65717 var nodeConversionMap = this.buildNodeConversionMap(this.layers);
65718 // Serialize and save the layers in layerConfigs
65719 var layerConfigs = [];
65720 var _iterator29 = _createForOfIteratorHelper(this.layers),
65721 _step29;
65722 try {
65723 for (_iterator29.s(); !(_step29 = _iterator29.n()).done;) {
65724 var _layer11 = _step29.value;
65725 var layerClassName = _layer11.getClassName();
65726 var layerConfig = _layer11.getConfig();
65727 var filteredInboundNodes = [];
65728 for (var originalNodeIndex = 0; originalNodeIndex < _layer11.inboundNodes.length; originalNodeIndex++) {
65729 var node = _layer11.inboundNodes[originalNodeIndex];
65730 var _nodeKey2 = Container.nodeKey(_layer11, originalNodeIndex);
65731 var kwargs = {};
65732 if (this.containerNodes.has(_nodeKey2)) {
65733 // The node is relevant to the model:
65734 // add to filteredInboundNodes.
65735 if (node.callArgs) {
65736 try {
65737 JSON.stringify(node.callArgs);
65738 kwargs = node.callArgs;
65739 } catch (err) {
65740 console.warn("Layer ".concat(_layer11.name, " was passed ") + "non-serializable keyword arguments: " + "".concat(node.callArgs, ". They will not be included ") + "in the serialized model (and thus will be " + "missing at deserialization time).");
65741 kwargs = {};
65742 }
65743 }
65744 if (node.inboundLayers.length > 0) {
65745 var nodeData = [];
65746 for (var _i7 = 0; _i7 < node.inboundLayers.length; _i7++) {
65747 var inboundLayer = node.inboundLayers[_i7];
65748 var _nodeIndex7 = node.nodeIndices[_i7];
65749 var _tensorIndex5 = node.tensorIndices[_i7];
65750 var _nodeKey3 = Container.nodeKey(inboundLayer, _nodeIndex7);
65751 var _newNodeIndex2 = nodeConversionMap[_nodeKey3];
65752 if (_newNodeIndex2 == null) {
65753 _newNodeIndex2 = 0;
65754 }
65755 nodeData.push([inboundLayer.name, _newNodeIndex2, _tensorIndex5, kwargs]);
65756 }
65757 filteredInboundNodes.push(nodeData);
65758 }
65759 }
65760 }
65761 var dict = {};
65762 dict['name'] = _layer11.name;
65763 dict['className'] = layerClassName;
65764 dict['config'] = layerConfig;
65765 dict['inboundNodes'] = filteredInboundNodes;
65766 layerConfigs.push(dict);
65767 }
65768 } catch (err) {
65769 _iterator29.e(err);
65770 } finally {
65771 _iterator29.f();
65772 }
65773 config['layers'] = layerConfigs;
65774 // Gather info about inputs and outputs
65775 var modelInputs = [];
65776 for (var i = 0; i < this.inputLayers.length; i++) {
65777 var layer = this.inputLayers[i];
65778 var nodeIndex = this.inputLayersNodeIndices[i];
65779 var nodeKey = Container.nodeKey(layer, nodeIndex);
65780 if (!this.containerNodes.has(nodeKey)) {
65781 continue;
65782 }
65783 var newNodeIndex = nodeConversionMap[nodeKey];
65784 if (newNodeIndex === null || newNodeIndex === undefined) {
65785 newNodeIndex = 0;
65786 }
65787 var tensorIndex = this.inputLayersTensorIndices[i];
65788 modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
65789 }
65790 config['inputLayers'] = modelInputs;
65791 var modelOutputs = [];
65792 for (var _i6 = 0; _i6 < this.outputLayers.length; _i6++) {
65793 var _layer10 = this.outputLayers[_i6];
65794 var _nodeIndex6 = this.outputLayersNodeIndices[_i6];
65795 var _nodeKey = Container.nodeKey(_layer10, _nodeIndex6);
65796 if (!this.containerNodes.has(_nodeKey)) {
65797 continue;
65798 }
65799 var _newNodeIndex = nodeConversionMap[_nodeKey];
65800 if (_newNodeIndex === null || _newNodeIndex === undefined) {
65801 _newNodeIndex = 0;
65802 }
65803 var _tensorIndex4 = this.outputLayersTensorIndices[_i6];
65804 modelOutputs.push([_layer10.name, _newNodeIndex, _tensorIndex4]);
65805 }
65806 config['outputLayers'] = modelOutputs;
65807 return config;
65808 }
65809 /**
65810 * Instantiates a LayersModel from its config (output of `get_config()`).
65811 * @param cls the class to create
65812 * @param config LayersModel config dictionary.
65813 * @param customObjects An optional dictionary of custom objects.
65814 * @param fastWeightInit Optional flag to use fast weight initialization
65815 * during deserialization. This is applicable to cases in which
65816 * the initialization will be immediately overwritten by loaded weight
65817 * values. Default: `false`.
65818 * @returns A LayersModel instance.
65819 * @throws ValueError: In case of improperly formatted config dict.
65820 */
65821 /** @nocollapse */
65822 }, {
65823 key: "stateful",
65824 get:
65825 /**
65826 * Determine whether the container is stateful.
65827 *
65828 * Porting Note: this is the equivalent of the stateful @property of
65829 * the Container class in PyKeras.
65830 */
65831 function get() {
65832 // Porting Note: This check is to prevent inadvertent setting of the
65833 // _stateful property of the Container instance.
65834 if (this._stateful) {
65835 throw new ValueError('Container instance unexpectedly has _stateful = true. The ' + 'statefulness of a Container is determined by the Layers it ' + 'contains. Its _stateful property must remain the default false.');
65836 }
65837 var _iterator30 = _createForOfIteratorHelper(this.layers),
65838 _step30;
65839 try {
65840 for (_iterator30.s(); !(_step30 = _iterator30.n()).done;) {
65841 var layer = _step30.value;
65842 if (layer.stateful) {
65843 return true;
65844 }
65845 }
65846 } catch (err) {
65847 _iterator30.e(err);
65848 } finally {
65849 _iterator30.f();
65850 }
65851 return false;
65852 }
65853 /**
65854 * Reset the state of all stateful constituent layers (if any).
65855 *
65856 * Examples of stateful layers include RNN layers whose `stateful` property
65857 * is set as `true`.
65858 */
65859 }, {
65860 key: "resetStates",
65861 value: function resetStates() {
65862 var _this5 = this;
65863 tidy(function () {
65864 _this5.layers.forEach(function (layer) {
65865 // tslint:disable:no-any
65866 if (layer.stateful) {
65867 layer.resetStates();
65868 }
65869 // tslint:enable:no-any
65870 });
65871 });
65872 }
65873 }], [{
65874 key: "fromConfig",
65875 value: function fromConfig(cls, config) {
65876 var customObjects = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
65877 var fastWeightInit = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
65878 // Layer instances created during
65879 // the graph reconstruction process
65880 var createdLayers = {};
65881 // Dictionary mapping layer instances to
65882 // node data that specifies a layer call.
65883 // It acts as a queue that maintains any unprocessed
65884 // layer call until it becomes possible to process it
65885 // (i.e. until the input tensors to the call all exist).
65886 var unprocessedNodes = {};
65887 function addUnprocessedNode(layer, nodeData) {
65888 if (!(layer.name in unprocessedNodes)) {
65889 unprocessedNodes[layer.name] = [nodeData];
65890 } else {
65891 unprocessedNodes[layer.name].push(nodeData);
65892 }
65893 }
65894 function processNode(layer, nodeData) {
65895 var inputTensors = [];
65896 var kwargs;
65897 var _iterator31 = _createForOfIteratorHelper(nodeData),
65898 _step31;
65899 try {
65900 for (_iterator31.s(); !(_step31 = _iterator31.n()).done;) {
65901 var inputData = _step31.value;
65902 var inboundLayerName = inputData[0];
65903 var inboundNodeIndex = inputData[1];
65904 var inboundTensorIndex = inputData[2];
65905 kwargs = inputData[3] == null ? {} : inputData[3];
65906 if (!(inboundLayerName in createdLayers)) {
65907 addUnprocessedNode(layer, nodeData);
65908 return;
65909 }
65910 var inboundLayer = createdLayers[inboundLayerName];
65911 if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
65912 addUnprocessedNode(layer, nodeData);
65913 return;
65914 }
65915 var inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
65916 inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
65917 }
65918 // Call layer on its inputs, thus creating the node
65919 // and building the layer if needed.
65920 // Note: This has Eager vs Graph Implications.
65921 } catch (err) {
65922 _iterator31.e(err);
65923 } finally {
65924 _iterator31.f();
65925 }
65926 if (inputTensors.length > 0) {
65927 layer.apply(singletonOrArray(inputTensors), kwargs); // was ** kwargs
65928 }
65929 }
65930 /**
65931 * Deserialize a layer, then call it on appropriate inputs.
65932 * @param layerData: layer config dict.
65933 * @throws ValueError: In case of improperly formatted `layer_data`
65934 * dict.
65935 */
65936 function processLayer(layerData) {
65937 var layerName = layerData['name'];
65938 // Instantiate layer.
65939 var layer = deserialize(layerData, config['customObjects'] != null ? config['customObjects'] : {});
65940 layer.setFastWeightInitDuringBuild(fastWeightInit);
65941 createdLayers[layerName] = layer;
65942 // Gather layer inputs.
65943 var inboundNodesData = layerData['inboundNodes'];
65944 inboundNodesData.forEach(function (nodeData) {
65945 if (!(nodeData instanceof Array)) {
65946 throw new ValueError("Corrupted configuration, expected array for nodeData: ".concat(nodeData));
65947 }
65948 // We don't process nodes (i.e. make layer calls)
65949 // on the fly because the inbound node may not yet exist,
65950 // in case of layer shared at different topological depths
65951 // (e.g.a model such as A(B(A(B(x)))))
65952 addUnprocessedNode(layer, nodeData);
65953 });
65954 }
65955 // First, we create all layers and enqueue nodes to be processed.
65956 var name = config['name'];
65957 var layersFromConfig = config['layers'];
65958 var _iterator32 = _createForOfIteratorHelper(layersFromConfig),
65959 _step32;
65960 try {
65961 for (_iterator32.s(); !(_step32 = _iterator32.n()).done;) {
65962 var _layerData = _step32.value;
65963 processLayer(_layerData);
65964 }
65965 // Then we process nodes in order of layer depth.
65966 // Nodes that cannot yet be processed(if the inbound node
65967 // does not yet exist) are re - enqueued, and the process
65968 // is repeated until all nodes are processed.
65969 } catch (err) {
65970 _iterator32.e(err);
65971 } finally {
65972 _iterator32.f();
65973 }
65974 while (!isObjectEmpty(unprocessedNodes)) {
65975 var _iterator33 = _createForOfIteratorHelper(layersFromConfig),
65976 _step33;
65977 try {
65978 for (_iterator33.s(); !(_step33 = _iterator33.n()).done;) {
65979 var layerData = _step33.value;
65980 var layer = createdLayers[layerData['name']];
65981 if (layer.name in unprocessedNodes) {
65982 var currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
65983 delete unprocessedNodes[layer.name];
65984 var _iterator34 = _createForOfIteratorHelper(currentUnprocessedNodesForLayer),
65985 _step34;
65986 try {
65987 for (_iterator34.s(); !(_step34 = _iterator34.n()).done;) {
65988 var nodeData = _step34.value;
65989 processNode(layer, nodeData);
65990 }
65991 } catch (err) {
65992 _iterator34.e(err);
65993 } finally {
65994 _iterator34.f();
65995 }
65996 }
65997 }
65998 } catch (err) {
65999 _iterator33.e(err);
66000 } finally {
66001 _iterator33.f();
66002 }
66003 }
66004 var inputTensors = [];
66005 var outputTensors = [];
66006 var inputLayersFromConfig = config['inputLayers'];
66007 var _iterator35 = _createForOfIteratorHelper(inputLayersFromConfig),
66008 _step35;
66009 try {
66010 for (_iterator35.s(); !(_step35 = _iterator35.n()).done;) {
66011 var _layerData2 = _step35.value;
66012 var layerName = _layerData2[0];
66013 var nodeIndex = _layerData2[1];
66014 var tensorIndex = _layerData2[2];
66015 assert(layerName in createdLayers);
66016 var _layer12 = createdLayers[layerName];
66017 var layerOutputTensors = _layer12.inboundNodes[nodeIndex].outputTensors;
66018 inputTensors.push(layerOutputTensors[tensorIndex]);
66019 }
66020 } catch (err) {
66021 _iterator35.e(err);
66022 } finally {
66023 _iterator35.f();
66024 }
66025 var outputLayersFromConfig = config['outputLayers'];
66026 var _iterator36 = _createForOfIteratorHelper(outputLayersFromConfig),
66027 _step36;
66028 try {
66029 for (_iterator36.s(); !(_step36 = _iterator36.n()).done;) {
66030 var _layerData3 = _step36.value;
66031 var _layerName = _layerData3[0];
66032 var _nodeIndex8 = _layerData3[1];
66033 var _tensorIndex6 = _layerData3[2];
66034 assert(_layerName in createdLayers);
66035 var _layer13 = createdLayers[_layerName];
66036 var _layerOutputTensors = _layer13.inboundNodes[_nodeIndex8].outputTensors;
66037 outputTensors.push(_layerOutputTensors[_tensorIndex6]);
66038 }
66039 } catch (err) {
66040 _iterator36.e(err);
66041 } finally {
66042 _iterator36.f();
66043 }
66044 return new cls({
66045 inputs: inputTensors,
66046 outputs: outputTensors,
66047 name: name
66048 });
66049 }
66050 }]);
66051 return Container;
66052 }(Layer);
66053
/**
 * Normalize user-supplied sample or class weights into one entry per model
 * output.
 *
 * Accepted forms of `xWeight`:
 *  - `null`/`undefined` or an empty array: every output gets `null`.
 *  - Single-output models: a one-element array (returned as-is), an object
 *    keyed by the output name, or the bare weight object itself.
 *  - Multi-output models: an array with exactly one entry per output, or an
 *    object whose first value is an object, mapping output names to weights
 *    (outputs missing from it get `null`).
 *
 * @param xWeight The weights as provided by the caller.
 * @param outputNames Names of all model outputs, in order.
 * @param weightType Label used in error messages ('classWeight' or
 *   'sampleWeight').
 * @returns An array with `outputNames.length` entries.
 * @throws Error if an array's length does not match the number of outputs,
 *   or if the value cannot be interpreted for a multi-output model.
 */
function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
  var numOutputs = outputNames.length;
  var isArray = Array.isArray(xWeight);
  if (xWeight == null || (isArray && xWeight.length === 0)) {
    return outputNames.map(function () {
      return null;
    });
  }
  var isObject = typeof xWeight === 'object';

  if (numOutputs === 1) {
    if (isArray && xWeight.length === 1) {
      return xWeight;
    }
    var soleName = outputNames[0];
    return isObject && soleName in xWeight ? [xWeight[soleName]] : [xWeight];
  }

  if (isArray) {
    if (xWeight.length === numOutputs) {
      return xWeight;
    }
    throw new Error("Provided ".concat(weightType, " is an array of ").concat(xWeight.length, " ") + "element(s), but the model has ".concat(numOutputs, " outputs. ") + "Make sure a set of weights is provided for each model output.");
  }

  var keys = isObject ? Object.keys(xWeight) : [];
  if (keys.length > 0 && typeof xWeight[keys[0]] === 'object') {
    return outputNames.map(function (outputName) {
      return outputName in xWeight ? xWeight[outputName] : null;
    });
  }
  throw new Error("The model has multiple (".concat(numOutputs, ") outputs, ") + "so ".concat(weightType, " must be either an array with ") + "".concat(numOutputs, " elements or an object with ").concat(outputNames, " keys. ") + "Provided ".concat(weightType, " not understood: ").concat(JSON.stringify(xWeight)));
}
/**
 * Standardize class-weighting objects.
 *
 * Accepts a single class-weighting object, an array of them, or a map from
 * output name to class-weighting object, and compares it with the model's
 * output name(s). Based on that comparison it returns an array of
 * class-weighting objects whose length matches the number of outputs.
 *
 * @param classWeight Input class-weighting object(s).
 * @param outputNames All output name(s) of the model.
 * @return An array of class-weighting objects, one per model output.
 */
function standardizeClassWeights(classWeight, outputNames) {
  var label = 'classWeight';
  return standardizeSampleOrClassWeights(classWeight, outputNames, label);
}
/**
 * Standardize per-sample weights. Accepts the same forms as
 * standardizeClassWeights() and uses 'sampleWeight' in error messages.
 */
function standardizeSampleWeights(sampleWeight, outputNames) {
  var label = 'sampleWeight';
  return standardizeSampleOrClassWeights(sampleWeight, outputNames, label);
}
66108 /**
66109 * Standardize by-sample and/or by-class weights for training.
66110 *
66111 * Note that this function operates on one model output at a time. For a model
66112 * with multiple outputs, you must call this function multiple times.
66113 *
66114 * @param y The target tensor that the by-sample and/or by-class weight is for.
66115 * The values of y are assumed to encode the classes, either directly
66116 * as an integer index, or as one-hot encoding.
66117 * @param sampleWeight By-sample weights.
66118 * @param classWeight By-class weights: an object mapping class indices
66119 * (integers) to a weight (float) to apply to the model's loss for the
66120 * samples from this class during training. This can be useful to tell the
66121 * model to "pay more attention" to samples from an under-represented class.
66122 * @param sampleWeightMode The mode for the sample weights.
66123 * @return A Promise of weight tensor, of which the size of the first dimension
66124 * matches that of `y`.
 * The tensor is a float32 tensor1d holding classWeight[c] for each sample's
 * class c. The Promise resolves to null when classWeight is null/undefined.
 * It rejects if sampleWeight or sampleWeightMode is provided (not
 * implemented yet), or if a class present in `y` is missing from
 * classWeight.
 *
 * Transpiler-generated wrapper: forwards `this` and all arguments to
 * _standardizeWeights(), which replaces itself with the async
 * implementation on its first call.
66125 */
66126 function standardizeWeights(_x, _x2, _x3, _x4) {
66127 return _standardizeWeights.apply(this, arguments);
66128 }
// NOTE(review): the JSDoc below documents computeWeightedLoss(), which the
// transpiler emitted after _standardizeWeights().
66129 /**
66130 * Apply per-sample weights on the loss values from a number of samples.
66131 *
66132 * @param losses Loss tensor of shape `[batchSize]`.
66133 * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
66134 * @returns Tensor of the same shape as`losses`.
66135 */
66136 function _standardizeWeights() {
66137 _standardizeWeights = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(y, sampleWeight, classWeight, sampleWeightMode) {
66138 var yClasses, yClassIndices, classSampleWeight;
// Regenerator state machine: each `case` label is a resume point. State 7
// resumes after awaiting yClasses.data().
66139 return _regeneratorRuntime().wrap(function _callee$(_context) {
66140 while (1) switch (_context.prev = _context.next) {
66141 case 0:
// Per-sample weighting is unsupported: fail if either option is given.
66142 if (!(sampleWeight != null || sampleWeightMode != null)) {
66143 _context.next = 2;
66144 break;
66145 }
66146 throw new Error('Support sampleWeight is not implemented yet');
66147 case 2:
// Without classWeight there is nothing to apply: resolve to null (state 15).
66148 if (!(classWeight != null)) {
66149 _context.next = 15;
66150 break;
66151 }
66152 // Apply class weights per sample.
// Derive one class index per sample:
// - rank-1 `y` is cloned as-is;
// - rank-2 `y` with more than one column is argMax'ed along axis 1 (one-hot);
// - rank-2 `y` with a single column is reshaped to rank 1.
// Any other rank, or a zero-width last dimension, throws.
66153 yClasses = tidy(function () {
66154 if (y.shape.length === 1) {
66155 // Assume class indices.
66156 return clone(y);
66157 } else if (y.shape.length === 2) {
66158 if (y.shape[1] > 1) {
66159 // Assume one-hot encoding of classes.
66160 var axis = 1;
66161 return argMax$2(y, axis);
66162 } else if (y.shape[1] === 1) {
66163 // Class index.
66164 return reshape$3(y, [y.shape[0]]);
66165 } else {
66166 throw new Error("Encountered unexpected last-dimension size (".concat(y.shape[1], ") ") + "during handling of class weights. The size is expected to be " + ">= 1.");
66167 }
66168 } else {
66169 throw new Error("Unexpected rank of target (y) tensor (".concat(y.rank, ") during ") + "handling of class weights. The rank is expected to be 1 or 2.");
66170 }
66171 });
66172 _context.t0 = Array;
66173 _context.next = 7;
66174 return yClasses.data();
66175 case 7:
// Class indices are now a plain array. Free the intermediate tensor and
// look up each sample's weight; a class absent from classWeight is an error.
66176 _context.t1 = _context.sent;
66177 yClassIndices = _context.t0.from.call(_context.t0, _context.t1);
66178 dispose(yClasses);
66179 classSampleWeight = [];
66180 yClassIndices.forEach(function (classIndex) {
66181 if (classWeight[classIndex] == null) {
66182 throw new Error("classWeight must contain all classes in the training data. " + "The class ".concat(classIndex, " exists in the data but not in ") + "classWeight");
66183 } else {
66184 classSampleWeight.push(classWeight[classIndex]);
66185 }
66186 });
66187 return _context.abrupt("return", tensor1d(classSampleWeight, 'float32'));
66188 case 15:
66189 return _context.abrupt("return", null);
66190 case 16:
66191 case "end":
66192 return _context.stop();
66193 }
66194 }, _callee);
66195 }));
66196 return _standardizeWeights.apply(this, arguments);
66197 }
/**
 * Scale per-sample loss values by per-sample weights using `mul`.
 *
 * @param losses Loss tensor of shape `[batchSize]`.
 * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
 * @returns The product `losses * sampleWeights`.
 */
function computeWeightedLoss(losses, sampleWeights) {
  var weighted = mul(losses, sampleWeights);
  return weighted;
}
66201
66202 // Default batch size used during tensor-based validation.
// fitDataset() passes it to model.evaluate() when `validationData` is given
// as tensors and `validationBatchSize` is not specified.
66203 var DEFAULT_VALIDATION_BATCH_SIZE = 32;
/**
 * Standardize one element produced by a dataset iterator for use by
 * LayersModel.fitDataset() / evaluateDataset().
 *
 * @param model A `tf.LayersModel`. Reads `inputNames`, `outputNames`,
 *   `inputs` and `outputs`. It is left untyped in the TypeScript source to
 *   avoid a circular dependency with training.ts.
 * @param iteratorOut An object of the form `{xs: TensorOrArrayOrMap,
 *   ys: TensorOrArrayOrMap}`, where `TensorOrArrayOrMap` is a single
 *   `tf.Tensor`, a `tf.Tensor[]`, or a flat map from string names to
 *   `tf.Tensor`s.
 * @returns `{xs, ys}` holding flat arrays of the input and target tensors.
 *   Map-valued entries are ordered by the model's `inputNames` /
 *   `outputNames`.
 * @throws If `xs` or `ys` is missing, if the tensor counts do not match the
 *   model's inputs/outputs, or if any tensor's first (batch) dimension
 *   differs from that of the first input.
 */
function standardizeDataIteratorOutput(model, iteratorOut) {
  var xs = iteratorOut['xs'];
  var ys = iteratorOut['ys'];
  assert$1(xs != null && ys != null, function () {
    return 'A Dataset iterator for fitDataset() is expected to generate ' + 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' + 'values may be `tf.Tensor`, an array of Tensors, or a map of ' + 'string to Tensor. The provided Dataset instead generates ' + "".concat(iteratorOut);
  });
  var flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
  var flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
  // The first input defines the batch size every other tensor must match.
  var batchSize = flattenedXs[0].shape[0];
  assert$1(flattenedXs.length === model.inputs.length, function () {
    return "LayersModel has ".concat(model.inputs.length, " inputs, but the dataset ") + "provides ".concat(flattenedXs.length, " inputs. (Expected input keys: ") + "".concat(JSON.stringify(model.inputNames), ")");
  });
  assert$1(flattenedYs.length === model.outputs.length, function () {
    return "LayersModel has ".concat(model.outputs.length, " outputs, but the dataset ") + "provides ".concat(flattenedYs.length, " outputs. (Expected output keys: ") + "".concat(JSON.stringify(model.outputNames), ")");
  });

  // Check that tensors[index] has the shared batch size. `kind` is
  // 'input' or 'output' and only affects the error message.
  function assertBatchDim(tensors, names, kind, index) {
    var actual = tensors[index].shape[0];
    assert$1(actual === batchSize, function () {
      return "Batch size mismatch: " + kind + " " + "".concat(names[index], " has ").concat(actual, "; ") + "expected ".concat(batchSize, " based on input ").concat(model.inputNames[0], ".");
    });
  }

  for (var xi = 0; xi < flattenedXs.length; xi++) {
    assertBatchDim(flattenedXs, model.inputNames, 'input', xi);
  }
  for (var yi = 0; yi < flattenedYs.length; yi++) {
    assertBatchDim(flattenedYs, model.outputNames, 'output', yi);
  }
  return {
    xs: flattenedXs,
    ys: flattenedYs
  };
}
/**
 * Flatten a Tensor, an array of Tensors, or a name->Tensor map into an array
 * of Tensors.
 *
 * @param inputOrOutput 'input' or 'output'; used only in error messages.
 * @param names Expected key names, in model order.
 * @param values A single Tensor (wrapped into a one-element array), an array
 *   (returned as-is after a length check), or a map (read in `names` order).
 * @returns An array of Tensors.
 * @throws If an array's length differs from `names.length`, or ValueError if
 *   a map lacks one of `names`.
 */
function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
  if (values instanceof Tensor) {
    return [values];
  }
  if (Array.isArray(values)) {
    assert$1(values.length === names.length, function () {
      return "Received an array of ".concat(values.length, " Tensors, but expected ").concat(names.length, " to match the ").concat(inputOrOutput, " keys ").concat(names, ".");
    });
    return values;
  }
  // Map form: every required key must be present.
  var ordered = [];
  for (var idx = 0; idx < names.length; idx++) {
    var key = names[idx];
    var tensor = values[key];
    if (tensor == null) {
      throw new ValueError("The feature data generated by the dataset lacks the required " + "".concat(inputOrOutput, " key '").concat(key, "'."));
    }
    ordered.push(tensor);
  }
  return ordered;
}
/**
 * Split tensor-based validation data into inputs and targets.
 *
 * @param data `[xs, ys]`. A three-element form (with sample weights) is
 *   rejected.
 * @returns `{xs, ys}` taken from positions 0 and 1 of `data`.
 * @throws NotImplementedError when `data` has exactly three elements.
 */
function standardizeTensorValidationData(data) {
  var hasSampleWeights = data.length === 3;
  if (hasSampleWeights) {
    throw new NotImplementedError('Validation with sample weights is not implemented yet.');
  }
  var validationXs = data[0];
  var validationYs = data[1];
  return {
    xs: validationXs,
    ys: validationYs
  };
}
/**
 * Train `model` on batches drawn from `dataset`. This is the implementation
 * behind LayersModel.fitDataset().
 *
 * Transpiler-generated wrapper: forwards `this` and all arguments to
 * _fitDataset(), which replaces itself with the async implementation on its
 * first call.
 *
 * @param model The LayersModel to train. It must already be compiled
 *   (`model.optimizer != null`), and no other fit() may be running.
 * @param dataset Object whose `iterator()` is awaited. Each iterator element
 *   is passed to standardizeDataIteratorOutput(), which expects `{xs, ys}`.
 * @param args fitDataset() config (epochs, batchesPerEpoch, callbacks,
 *   validationData, classWeight, ...).
 * @returns A Promise resolving to `model.history`.
 */
66299 function fitDataset(_x, _x2, _x3) {
66300 return _fitDataset.apply(this, arguments);
66301 }
// NOTE(review): the one-line comment below documents getStepsPerEpoch(),
// which the transpiler emitted after _fitDataset().
66302 /** Helper function that determines number of steps (batches) per epoch. */
66303 function _fitDataset() {
66304 _fitDataset = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(
66305 // Type `model` as `any` here to avoid circular dependency w/
66306 // training.ts.
66307 // tslint:disable-next-line:no-any
66308 model, dataset, args) {
66309 var hasBatchesPerEpoch, doValidation, valXs, valYs, validationData, trainFunction, outLabels, callbackMetrics, callbacks, verbose, _configureCallbacks, callbackList, history, epoch, dataIterator, epochLogs, stepsDone, batchIndex, iteratorOut, _standardizeDataItera, xs, ys, batchLogs, sampleWeights, standardClassWeights, i, ins, outs, _i, label, out, valOuts, _i2;
// Regenerator state machine: each `case` label is a resume point after an
// `await`. The try region starts at state 9; its finally block is state 103.
66310 return _regeneratorRuntime().wrap(function _callee$(_context) {
66311 while (1) switch (_context.prev = _context.next) {
66312 case 0:
// Validate the model and config, then refuse to start while another fit()
// is running.
// NOTE(review): `args.batchesPerEpoch` is read on the next line before the
// `args != null` assertion below. A missing config therefore surfaces as a
// TypeError rather than that assertion's message.
66313 hasBatchesPerEpoch = args.batchesPerEpoch != null;
66314 assert$1(model.optimizer != null, function () {
66315 return 'You must compile a model before training/testing. Use ' + 'LayersModel.compile(modelCompileConfig).';
66316 });
66317 assert$1(args != null, function () {
66318 return "For fitDataset(), the 2nd argument (config) is required, " + "but it is not provided in this call.";
66319 });
66320 assert$1(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), function () {
66321 return "For fitDataset(), config.epochs is expected to be a positive " + "integer, but got ".concat(args.epochs);
66322 });
66323 assert$1(!hasBatchesPerEpoch || args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch), function () {
66324 return "For fitDataset(), config.batchesPerEpoch is expected to be a " + "positive integer if specified, but got ".concat(args.batchesPerEpoch);
66325 });
66326 assert$1(
66327 // tslint:disable-next-line:no-any
66328 args['validationSplit'] == null, function () {
66329 return '`validationSplit` is not supported by `fitDataset()`. ' + 'Use validationData instead.';
66330 });
66331 if (!model.isTraining) {
66332 _context.next = 8;
66333 break;
66334 }
66335 throw new Error('Cannot start training because another fit() call is ongoing.');
66336 case 8:
66337 model.isTraining = true;
// Enter the try region. Whatever happens, model.isTraining is reset to false
// in the finally block (state 103).
66338 _context.prev = 9;
// Tensor validation data is unpacked once here. Dataset validation data is
// evaluated at the end of every epoch (states 73-85).
66339 doValidation = args.validationData != null;
66340 if (doValidation) {
66341 if (isDatasetObject(args.validationData)) {
66342 assert$1(args.validationBatches == null || args.validationBatches > 0 && Number.isInteger(args.validationBatches), function () {
66343 return "For fitDataset() with dataset-based validation, " + "config.validationBatches is expected not to be provided, " + "or to be a positive integer, " + "but got ".concat(args.validationBatches);
66344 });
66345 } else {
66346 validationData = standardizeTensorValidationData(args.validationData);
66347 valXs = validationData.xs;
66348 valYs = validationData.ys;
66349 }
66350 }
66351 trainFunction = model.makeTrainFunction();
66352 outLabels = model.getDedupedMetricsNames();
// Callback metrics are the training metrics, plus their `val_` variants when
// validating.
66353 if (doValidation) {
66354 callbackMetrics = outLabels.slice().concat(outLabels.map(function (n) {
66355 return 'val_' + n;
66356 }));
66357 } else {
66358 callbackMetrics = outLabels.slice();
66359 }
66360 callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
66361 verbose = args.verbose == null ? 1 : args.verbose;
66362 _configureCallbacks = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null,
66363 // Batch size determined by the dataset itself.
66364 doValidation, callbackMetrics), callbackList = _configureCallbacks.callbackList, history = _configureCallbacks.history;
66365 callbackList.setModel(model);
66366 model.history = history;
66367 _context.next = 22;
66368 return callbackList.onTrainBegin();
66369 case 22:
66370 model.stopTraining_ = false;
66371 epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
// When batchesPerEpoch is set, this single iterator is consumed across all
// epochs. Otherwise a fresh iterator replaces it at the start of every epoch
// (states 31-36), so this one goes unused.
66372 _context.next = 26;
66373 return dataset.iterator();
66374 case 26:
66375 dataIterator = _context.sent;
66376 case 27:
// Epoch loop: runs until args.epochs, or until model.stopTraining_ is set.
66377 if (!(epoch < args.epochs)) {
66378 _context.next = 98;
66379 break;
66380 }
66381 epochLogs = {};
66382 _context.next = 31;
66383 return callbackList.onEpochBegin(epoch);
66384 case 31:
66385 stepsDone = 0;
66386 batchIndex = 0;
66387 if (hasBatchesPerEpoch) {
66388 _context.next = 37;
66389 break;
66390 }
66391 _context.next = 36;
66392 return dataset.iterator();
66393 case 36:
66394 dataIterator = _context.sent;
66395 case 37:
// Batch loop: bounded by batchesPerEpoch if set. Otherwise it runs until the
// iterator reports done (checked in state 73).
66396 if (!(hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true)) {
66397 _context.next = 91;
66398 break;
66399 }
66400 _context.next = 40;
66401 return dataIterator.next();
66402 case 40:
66403 iteratorOut = _context.sent;
// NOTE(review): despite the warning text, this only ends the current epoch
// (jump to state 91). Later epochs reuse the same exhausted iterator and will
// presumably hit this branch again.
66404 if (!(hasBatchesPerEpoch && iteratorOut.done)) {
66405 _context.next = 44;
66406 break;
66407 }
66408 console.warn('You provided `batchesPerEpoch` as ' + "".concat(args.batchesPerEpoch, ", ") + 'but your dataset iterator ran out of data after ' + "".concat(stepsDone, " batches; ") + 'interrupting training. Make sure that your ' + 'dataset can generate at least `batchesPerEpoch * epochs` ' + 'batches (in this case, ' + "".concat(args.batchesPerEpoch * args.epochs, " batches). ") + 'You may need to use the repeat() function when building ' + 'your dataset.');
66409 return _context.abrupt("break", 91);
66410 case 44:
// Train only when the iterator produced a value. Otherwise skip to the
// end-of-epoch check (state 73).
66411 if (!(iteratorOut.value != null)) {
66412 _context.next = 73;
66413 break;
66414 }
66415 _standardizeDataItera = standardizeDataIteratorOutput(model, iteratorOut.value), xs = _standardizeDataItera.xs, ys = _standardizeDataItera.ys;
66416 batchLogs = {};
66417 batchLogs['batch'] = batchIndex;
66418 batchLogs['size'] = xs[0].shape[0];
66419 _context.next = 51;
66420 return callbackList.onBatchBegin(batchIndex, batchLogs);
66421 case 51:
// With classWeight, build one per-sample weight tensor per output
// (states 55-61, awaiting standardizeWeights() for each output).
66422 sampleWeights = [];
66423 if (!(args.classWeight != null)) {
66424 _context.next = 64;
66425 break;
66426 }
66427 standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames);
66428 i = 0;
66429 case 55:
66430 if (!(i < standardClassWeights.length)) {
66431 _context.next = 64;
66432 break;
66433 }
66434 _context.t0 = sampleWeights;
66435 _context.next = 59;
66436 return standardizeWeights(ys[i], null, standardClassWeights[i]);
66437 case 59:
66438 _context.t1 = _context.sent;
66439 _context.t0.push.call(_context.t0, _context.t1);
66440 case 61:
66441 ++i;
66442 _context.next = 55;
66443 break;
66444 case 64:
66445 // Train on batch.
// The train function receives inputs, then targets, then sample weights.
// These are disposed right after. Metric outputs are kept for the batch logs
// and released after onBatchEnd().
66446 ins = xs.concat(ys).concat(sampleWeights);
66447 outs = trainFunction(ins);
66448 dispose(ins);
66449 for (_i = 0; _i < outLabels.length; ++_i) {
66450 label = outLabels[_i];
66451 out = outs[_i];
66452 batchLogs[label] = out;
66453 keep(out);
66454 }
66455 _context.next = 70;
66456 return callbackList.onBatchEnd(batchIndex, batchLogs);
66457 case 70:
66458 disposeTensorsInLogs(batchLogs);
66459 batchIndex++;
66460 stepsDone++;
66461 case 73:
// End of epoch: run validation if requested, record `val_*` entries in
// epochLogs, and leave the batch loop.
66462 if (!(hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch : iteratorOut.done)) {
66463 _context.next = 87;
66464 break;
66465 }
66466 if (!doValidation) {
66467 _context.next = 86;
66468 break;
66469 }
66470 valOuts = void 0;
66471 if (!isDatasetObject(args.validationData)) {
66472 _context.next = 84;
66473 break;
66474 }
66475 _context.t2 = toList;
66476 _context.next = 80;
66477 return model.evaluateDataset(args.validationData, {
66478 batches: args.validationBatches
66479 });
66480 case 80:
66481 _context.t3 = _context.sent;
66482 valOuts = (0, _context.t2)(_context.t3);
66483 _context.next = 85;
66484 break;
66485 case 84:
66486 valOuts = toList(model.evaluate(valXs, valYs, {
66487 batchSize: args.validationBatchSize == null ? DEFAULT_VALIDATION_BATCH_SIZE : args.validationBatchSize,
66488 verbose: 0
66489 }));
66490 case 85:
66491 for (_i2 = 0; _i2 < model.metricsNames.length; ++_i2) {
66492 epochLogs["val_".concat(model.metricsNames[_i2])] = valOuts[_i2];
66493 }
66494 case 86:
66495 return _context.abrupt("break", 91);
66496 case 87:
// A callback may request early stopping between batches.
66497 if (!model.stopTraining_) {
66498 _context.next = 89;
66499 break;
66500 }
66501 return _context.abrupt("break", 91);
66502 case 89:
66503 _context.next = 37;
66504 break;
66505 case 91:
66506 _context.next = 93;
66507 return callbackList.onEpochEnd(epoch, epochLogs);
66508 case 93:
66509 epoch++;
66510 if (!model.stopTraining_) {
66511 _context.next = 96;
66512 break;
66513 }
66514 return _context.abrupt("break", 98);
66515 case 96:
66516 _context.next = 27;
66517 break;
66518 case 98:
// Training finished: notify callbacks, sync history data, return history.
66519 _context.next = 100;
66520 return callbackList.onTrainEnd();
66521 case 100:
66522 _context.next = 102;
66523 return model.history.syncData();
66524 case 102:
66525 return _context.abrupt("return", model.history);
66526 case 103:
// finally: always clear the re-entrancy flag.
66527 _context.prev = 103;
66528 model.isTraining = false;
66529 return _context.finish(103);
66530 case 106:
66531 case "end":
66532 return _context.stop();
66533 }
66534 }, _callee, null, [[9,, 103, 106]]);
66535 }));
66536 return _fitDataset.apply(this, arguments);
66537 }
/**
 * Determine the number of steps (batches) per epoch for fitDataset().
 *
 * @param dataset The training dataset. Its `size` is used when finite.
 * @param args The fitDataset() config. `batchesPerEpoch` takes precedence.
 * @returns `args.batchesPerEpoch` if set, else `dataset.size` if it is a
 *   finite number, else `null` (unknown).
 */
function getStepsPerEpoch(dataset, args) {
  if (args.batchesPerEpoch != null) {
    return args.batchesPerEpoch;
  }
  return Number.isFinite(dataset.size) ? dataset.size : null;
}
/**
 * Duck-typed Dataset check: true when `dataset.iterator` is a function.
 * Throws a TypeError if `dataset` is null or undefined.
 */
function isDatasetObject(dataset) {
  var iteratorMember = dataset.iterator;
  return typeof iteratorMember === 'function';
}
/**
 * Duck-typed LazyIterator check: true when `iterator.next` (its `.next`
 * element) is a function. Throws a TypeError if `iterator` is null or
 * undefined.
 */
function isLazyIteratorObject(iterator) {
  var nextMember = iterator.next;
  return typeof nextMember === 'function';
}
/**
 * Evaluate `model` on data from `dataset`. This is the implementation behind
 * LayersModel.evaluateDataset().
 *
 * Transpiler-generated wrapper: forwards `this` and all arguments to
 * _evaluateDataset(), which replaces itself with the async implementation
 * on its first call.
 *
 * @param model The LayersModel. Its `testFunction` is applied to each batch.
 * @param dataset Either an iterator (has a `next()` function, used directly)
 *   or a Dataset (its `iterator()` is awaited).
 * @param args Optional config. `batches` caps the number of batches and must
 *   be a positive integer if given. `verbose` > 0 is not implemented.
 * @returns A Promise resolving to singletonOrArray() of the per-output sums
 *   of `batchSize * batchOutput`, divided by the total number of examples.
 */
66558 function evaluateDataset(_x4, _x5, _x6) {
66559 return _evaluateDataset.apply(this, arguments);
66560 }
66561 function _evaluateDataset() {
66562 _evaluateDataset = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(
66563 // Type `model` as `any` here to avoid circular dependency w/
66564 // training.ts.
66565 // tslint:disable-next-line:no-any
66566 model, dataset, args) {
66567 var hasBatches, f, outs, dataIterator, numExamples, batch, _loop3, _ret, i, oldScalar;
// Regenerator state machine: each `case` label is a resume point after an
// `await` (dataset.iterator(), or one delegated _loop3 iteration).
66568 return _regeneratorRuntime().wrap(function _callee2$(_context3) {
66569 while (1) switch (_context3.prev = _context3.next) {
66570 case 0:
66571 args = args || {};
66572 hasBatches = args.batches != null;
66573 f = model.testFunction;
66574 outs = [];
66575 if (!(args.verbose > 0)) {
66576 _context3.next = 6;
66577 break;
66578 }
66579 throw new NotImplementedError('Verbose mode is not implemented yet.');
66580 case 6:
66581 assert$1(!hasBatches || args.batches > 0 && Number.isInteger(args.batches), function () {
66582 return 'Test loop expects `batches` to be a positive integer, but ' + "received ".concat(JSON.stringify(args.batches));
66583 });
// Use `dataset` directly if it is already an iterator; otherwise await
// dataset.iterator() (states 11-13).
66584 if (!isLazyIteratorObject(dataset)) {
66585 _context3.next = 11;
66586 break;
66587 }
66588 _context3.t0 = dataset;
66589 _context3.next = 14;
66590 break;
66591 case 11:
66592 _context3.next = 13;
66593 return dataset.iterator();
66594 case 13:
66595 _context3.t0 = _context3.sent;
66596 case 14:
66597 dataIterator = _context3.t0;
66598 // Keeps track of number of examples used in this evaluation.
66599 numExamples = 0;
66600 batch = 0;
// One loop iteration: pull a batch, run the test function on inputs+targets,
// and add batchSize * output into the running sums in `outs`. The sums start
// as scalar(0) on the first batch. Returns "break" once the iterator reports
// done.
66601 _loop3 = /*#__PURE__*/_regeneratorRuntime().mark(function _loop3() {
66602 var iteratorOut;
66603 return _regeneratorRuntime().wrap(function _loop3$(_context2) {
66604 while (1) switch (_context2.prev = _context2.next) {
66605 case 0:
66606 _context2.next = 2;
66607 return dataIterator.next();
66608 case 2:
66609 iteratorOut = _context2.sent;
66610 outs = tidy(function () {
66611 if (iteratorOut.value) {
66612 // TODO(cais): Once real dataset is available, use
66613 // `map(x => standardizeDataIteratorOutput(model, x).map(f)`.
66614 var _standardizeDataItera2 = standardizeDataIteratorOutput(model, iteratorOut.value),
66615 xs = _standardizeDataItera2.xs,
66616 ys = _standardizeDataItera2.ys;
66617 var xsAndYs = xs.concat(ys);
66618 var batchOuts = tidy(function () {
66619 return f(xsAndYs);
66620 });
66621 dispose(xsAndYs);
66622 if (batch === 0) {
66623 for (var _i3 = 0; _i3 < batchOuts.length; ++_i3) {
66624 outs.push(scalar(0));
66625 }
66626 }
66627 var batchSize = xsAndYs[0].shape[0];
// Replace each running sum with sum + batchSize * batchOut. The previous
// sum is disposed explicitly only after the first batch.
66628 var _loop4 = function _loop4(_i4) {
66629 var batchOut = batchOuts[_i4];
66630 var oldScalar = outs[_i4];
66631 outs[_i4] = tidy(function () {
66632 return add$3(outs[_i4], mul(batchSize, batchOut));
66633 });
66634 if (batch > 0) {
66635 dispose(oldScalar);
66636 }
66637 };
66638 for (var _i4 = 0; _i4 < batchOuts.length; ++_i4) {
66639 _loop4(_i4);
66640 }
66641 dispose(batchOuts);
66642 numExamples += batchSize;
66643 ++batch;
66644 }
66645 return outs;
66646 });
66647 if (!iteratorOut.done) {
66648 _context2.next = 7;
66649 break;
66650 }
// NOTE(review): the warning text below misspells "evaluation" as
// "evalution". It is a runtime string, so it is left unchanged here.
66651 if (hasBatches) {
66652 console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' + 'Interrupting evalution. Make sure that your ' + 'dataset can generate at least `batches` ' + "batches (in this case, ".concat(args.batches, " batches). ") + 'You may need to use the repeat() function when building ' + 'your dataset.');
66653 }
66654 return _context2.abrupt("return", "break");
66655 case 7:
66656 case "end":
66657 return _context2.stop();
66658 }
66659 }, _loop3);
66660 });
66661 case 18:
// Run _loop3 until `batches` batches are done (unbounded when `batches` is
// unset) or until it signals "break".
66662 if (!(hasBatches ? batch < args.batches : true)) {
66663 _context3.next = 25;
66664 break;
66665 }
66666 return _context3.delegateYield(_loop3(), "t1", 20);
66667 case 20:
66668 _ret = _context3.t1;
66669 if (!(_ret === "break")) {
66670 _context3.next = 23;
66671 break;
66672 }
66673 return _context3.abrupt("break", 25);
66674 case 23:
66675 _context3.next = 18;
66676 break;
66677 case 25:
// Turn the weighted sums into means over all examples, disposing the sums.
66678 for (i = 0; i < outs.length; ++i) {
66679 oldScalar = outs[i];
66680 outs[i] = div$1(outs[i], numExamples);
66681 dispose(oldScalar);
66682 }
66683 return _context3.abrupt("return", singletonOrArray(outs));
66684 case 27:
66685 case "end":
66686 return _context3.stop();
66687 }
66688 }, _callee2);
66689 }));
66690 return _evaluateDataset.apply(this, arguments);
66691 }
66692
66693 /**
66694 * @license
66695 * Copyright 2018 Google LLC
66696 *
66697 * Use of this source code is governed by an MIT-style
66698 * license that can be found in the LICENSE file or at
66699 * https://opensource.org/licenses/MIT.
66700 * =============================================================================
66701 */
/**
 * Assert that `batchSize` is a positive integer.
 *
 * @param batchSize Candidate batch size.
 * @throws Via assert$1, with a message naming the rejected value.
 */
function checkBatchSize(batchSize) {
  var isValid = batchSize > 0 && Number.isInteger(batchSize);
  assert$1(isValid, function () {
    return "batchSize is required to be a positive integer, but got ".concat(batchSize);
  });
}
/**
 * Slice a Tensor, or every Tensor in an Array, along the first axis by start
 * and stop indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 * function and `sliceArraysByIndices()` together.
 *
 * @param arrays The input Tensor, an Array of Tensors, or null/undefined.
 * @param start The starting index (inclusive).
 * @param stop The stopping index (exclusive).
 * @returns `[null]` for null/undefined input, an Array of slices for Array
 *   input, otherwise the single slice.
 */
function sliceArrays(arrays, start, stop) {
  if (arrays == null) {
    return [null];
  }
  var size = stop - start;
  if (!Array.isArray(arrays)) {
    // A single Tensor.
    return sliceAlongFirstAxis(arrays, start, size);
  }
  return arrays.map(function (item) {
    return sliceAlongFirstAxis(item, start, size);
  });
}
/**
 * Gathers entries of a Tensor (or of every Tensor in an Array) at the given
 * indices along the first (batch) dimension.
 *
 * Porting Note: together with `sliceArrays()`, this covers the
 * `_slice_arrays` function of PyKeras.
 *
 * @param arrays A Tensor, an Array of Tensors (Array elements are processed
 *   recursively), or null/undefined.
 * @param indices Index Tensor; it is cast to 'int32' when its dtype differs.
 * @returns null for null/undefined input; otherwise the gathered Tensor(s),
 *   mirroring the Array structure of `arrays`. Work runs inside `tidy`.
 */
function sliceArraysByIndices(arrays, indices) {
  return tidy(function () {
    if (arrays == null) {
      return null;
    }
    if (Array.isArray(arrays)) {
      return arrays.map(function (element) {
        return sliceArraysByIndices(element, indices);
      });
    }
    // TODO(cais): indices should be a pre-constructed Tensor1D to avoid
    // tensor1d() calls.
    var intIndices = indices.dtype === 'int32' ? indices : cast$3(indices, 'int32');
    return gather(arrays, intIndices);
  });
}
/**
 * Splits the index range [0, size) into consecutive batches.
 *
 * @param size Integer, total number of items to split into batches.
 * @param batchSize Integer, number of items per batch; the last batch is
 *   truncated at `size`.
 * @returns An Array of [batchStart, batchEnd] pairs. batchStart is inclusive
 *   and batchEnd is exclusive, i.e. each batch covers the indices x with
 *   batchStart <= x < batchEnd.
 */
function makeBatches(size, batchSize) {
  var batches = [];
  for (var start = 0; start < size; start += batchSize) {
    var end = start + batchSize;
    // Clamp the final batch to the total size.
    batches.push([start, end >= size ? size : end]);
  }
  return batches;
}
/**
 * Ensures every tensor has a rank of at least 2.
 *
 * Rank-1 tensors are expanded with `expandDims$2(tensor, 1)`; tensors of
 * rank >= 2 are passed through unchanged; a rank-0 (scalar) tensor throws.
 *
 * @param tensors A single Tensor or an Array of Tensors.
 * @returns A new Array holding the (possibly expanded) tensors, in order.
 * @throws Error if any tensor is a scalar.
 */
function ensureTensorsRank2OrHigher(tensors) {
  var list = tensors instanceof Tensor ? [tensors] : tensors;
  var result = [];
  for (var idx = 0; idx < list.length; idx++) {
    var current = list[idx];
    if (current.rank === 0) {
      throw new Error('Expected tensor to be at least 1D, but received a 0D tensor (scalar).');
    }
    result.push(current.rank === 1 ? expandDims$2(current, 1) : current);
  }
  return result;
}
/**
 * Disposes the Tensors in `tensors` whose ids do not appear in the
 * reference set `refTensors`.
 *
 * This method is used for memory cleanup during calls such as
 * LayersModel.fit() and LayersModel.evaluate(), where the standardized data
 * is compared against the data originally passed in by the caller.
 *
 * Each argument may be a single Tensor, an Array of Tensors, or an object
 * mapping string names to Tensors. Already-disposed Tensors are skipped.
 *
 * @param tensors New set which may contain Tensors not present in
 *   `refTensors`. Nothing happens when this is null/undefined.
 * @param refTensors Reference Tensor set (may be null/undefined, in which
 *   case every Tensor in `tensors` is disposed).
 */
// TODO(cais, kangyizhang): Deduplicate with tfjs-data.
function disposeNewTensors(tensors, refTensors) {
  if (tensors == null) {
    return;
  }
  // Flattens a Tensor / Array of Tensors / name->Tensor map into a list.
  var toList = function (value) {
    var list = [];
    if (value instanceof Tensor) {
      list.push(value);
    } else if (Array.isArray(value)) {
      value.forEach(function (t) {
        list.push(t);
      });
    } else if (value != null) {
      for (var key in value) {
        list.push(value[key]);
      }
    }
    return list;
  };
  var keepIds = toList(refTensors).map(function (t) {
    return t.id;
  });
  var doomed = toList(tensors).filter(function (t) {
    return keepIds.indexOf(t.id) === -1;
  });
  doomed.forEach(function (t) {
    if (!t.isDisposed) {
      t.dispose();
    }
  });
}
66862
/**
 * Polymorphic input data, case 1: `x` is a single Tensor.
 */
function isDataTensor(x) {
  var isTensor = x instanceof Tensor;
  return isTensor;
}
/**
 * Polymorphic input data, case 2: `x` is an Array (of Tensors).
 */
function isDataArray(x) {
  var isArray = Array.isArray(x);
  return isArray;
}
/**
 * Polymorphic input data, case 3: a "dict" of Tensors. This is the
 * catch-all case: anything that is neither a Tensor nor an Array.
 */
function isDataDict(x) {
  return !(isDataTensor(x) || isDataArray(x));
}
66881 /**
66882 * Normalizes inputs and targets provided by users.
66883 * @param data User-provided input data (polymorphic).
66884 * @param names An Array of expected Tensor names.
66885 * @param shapes Optional Array of expected Tensor shapes.
66886 * @param checkBatchAxis Whether to check that the batch axis of the arrays
66887 * match the expected value found in `shapes`.
66888 * @param exceptionPrefix String prefix used for exception formatting.
66889 * @returns List of standardized input Tensors (one Tensor per model input).
66890 * @throws ValueError: in case of improperly formatted user data.
66891 */
66892 function standardizeInputData(data, names, shapes) {
66893 var checkBatchAxis = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : true;
66894 var exceptionPrefix = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : '';
66895 if (names == null || names.length === 0) {
66896 // Check for the case where the model expected no data, but some data got
66897 // sent.
66898 if (data != null) {
66899 var gotUnexpectedData = false;
66900 if (isDataArray(data) && data.length > 0) {
66901 gotUnexpectedData = true;
66902 } else if (isDataDict(data)) {
66903 for (var key in data) {
66904 if (data.hasOwnProperty(key)) {
66905 gotUnexpectedData = true;
66906 break;
66907 }
66908 }
66909 } else {
66910 // `data` is a singleton Tensor in this case.
66911 gotUnexpectedData = true;
66912 }
66913 if (gotUnexpectedData) {
66914 throw new ValueError("Error when checking model ".concat(exceptionPrefix, " expected no data, ") + "but got ".concat(data));
66915 }
66916 }
66917 return [];
66918 }
66919 if (data == null) {
66920 return names.map(function (name) {
66921 return null;
66922 });
66923 }
66924 var arrays;
66925 if (isDataDict(data)) {
66926 data = data;
66927 arrays = [];
66928 var _iterator = _createForOfIteratorHelper(names),
66929 _step;
66930 try {
66931 for (_iterator.s(); !(_step = _iterator.n()).done;) {
66932 var name = _step.value;
66933 if (data[name] == null) {
66934 throw new ValueError("No data provided for \"".concat(name, "\". Need data for each key in: ") + "".concat(names));
66935 }
66936 arrays.push(data[name]);
66937 }
66938 } catch (err) {
66939 _iterator.e(err);
66940 } finally {
66941 _iterator.f();
66942 }
66943 } else if (isDataArray(data)) {
66944 data = data;
66945 if (data.length !== names.length) {
66946 throw new ValueError("Error when checking model ".concat(exceptionPrefix, ": the Array of ") + "Tensors that you are passing to your model is not the size the " + "model expected. Expected to see ".concat(names.length, " Tensor(s), but ") + "instead got the following list of Tensor(s): ".concat(data));
66947 }
66948 arrays = data;
66949 } else {
66950 data = data;
66951 if (names.length > 1) {
66952 throw new ValueError("The model ".concat(exceptionPrefix, " expects ").concat(names.length, " Tensor(s), ") + "but only received one Tensor. Found: Tensor with shape ".concat(data.shape));
66953 }
66954 arrays = [data];
66955 }
66956 arrays = ensureTensorsRank2OrHigher(arrays);
66957 // Check shape compatibility.
66958 if (shapes != null) {
66959 for (var i = 0; i < names.length; ++i) {
66960 if (shapes[i] == null) {
66961 continue;
66962 }
66963 var array = arrays[i];
66964 if (array.shape.length !== shapes[i].length) {
66965 throw new ValueError("Error when checking ".concat(exceptionPrefix, ": expected ").concat(names[i], " ") + "to have ".concat(shapes[i].length, " dimension(s). but got array with ") + "shape ".concat(array.shape));
66966 }
66967 for (var j = 0; j < shapes[i].length; ++j) {
66968 if (j === 0 && !checkBatchAxis) {
66969 // Skip the first (batch) axis.
66970 continue;
66971 }
66972 var dim = array.shape[j];
66973 var refDim = shapes[i][j];
66974 if (refDim != null && refDim >= 0 && dim !== refDim) {
66975 throw new ValueError("".concat(exceptionPrefix, " expected a batch of elements where each ") + "example has shape [".concat(shapes[i].slice(1, shapes[i].length), "] ") + "(i.e.,tensor shape [*,".concat(shapes[i].slice(1, shapes[i].length), "])") + " but the ".concat(exceptionPrefix, " received an input with ").concat(array.shape[0]) + " examples, each with shape [".concat(array.shape.slice(1, array.shape.length), "]") + " (tensor shape [".concat(array.shape, "])"));
66976 }
66977 }
66978 }
66979 }
66980 return arrays;
66981 }
/**
 * User input validation for Tensors.
 *
 * Checks that all input Tensors share a single size along the first (batch)
 * dimension, that all target Tensors do too, and that the two sizes agree.
 *
 * @param inputs `Array` of `tf.Tensor`s for inputs.
 * @param targets `Array` of `tf.Tensor`s for targets.
 * @param weights Optional `Array` of `tf.Tensor`s for sample weights
 *   (currently not checked).
 * @throws ValueError: in case of incorrectly formatted data.
 */
function checkArrayLengths(inputs, targets, weights) {
  var setX = unique$2(inputs.map(function (input) {
    return input.shape[0];
  }));
  var setY = unique$2(targets.map(function (target) {
    return target.shape[0];
  }));
  // TODO(cais): Check `weights` as well.
  if (setX.length > 1) {
    throw new ValueError("All input Tensors (x) should have the same number of samples. " + "Got array shapes: " + "".concat(JSON.stringify(inputs.map(function (input) {
      return input.shape;
    }))));
  }
  if (setY.length > 1) {
    throw new ValueError("All target Tensors (y) should have the same number of samples. " + "Got array shapes: " + "".concat(JSON.stringify(targets.map(function (target) {
      return target.shape;
    }))));
  }
  // Past the checks above, each set holds at most one distinct batch size,
  // so comparing the single entries is sufficient (no sorting needed).
  if (setX.length > 0 && setY.length > 0 && setX[0] !== setY[0]) {
    throw new ValueError("Input Tensors should have the same number of samples as target " + "Tensors. Found ".concat(setX[0], " input sample(s) and ").concat(setY[0], " target ") + "sample(s).");
  }
}
/**
 * Validates the compatibility of targets and loss functions.
 *
 * This helps prevent users from using loss functions incorrectly:
 * - with `categoricalCrossentropy$2`, a target whose last dimension is 1 is
 *   rejected;
 * - with any loss in `keyLosses` (`meanSquaredError$1`,
 *   `binaryCrossentropy$2`, `categoricalCrossentropy$2`), each non-batch
 *   dimension of the target must equal the corresponding output dimension
 *   (null output dimensions are not checked).
 *
 * @param targets `Array` of `tf.Tensor`s of targets.
 * @param lossFns `Array` of loss functions (null entries are skipped).
 * @param outputShapes `Array` of shapes of model outputs.
 * @throws ValueError on an incompatible target / loss combination.
 */
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
  // TODO(cais): Dedicated test coverage?
  var keyLosses = [meanSquaredError$1, binaryCrossentropy$2, categoricalCrossentropy$2];
  for (var i = 0; i < targets.length; ++i) {
    var y = targets[i];
    var loss = lossFns[i];
    var shape = outputShapes[i];
    if (loss == null) {
      continue;
    }
    if (loss === categoricalCrossentropy$2) {
      if (y.shape[y.shape.length - 1] === 1) {
        // TODO(cais): Example code in error message.
        throw new ValueError("You are passing a target array of shape ".concat(y.shape, " while using ") + "a loss 'categorical_crossentropy'. 'categorical_crossentropy' " + "expects targets to be binary matrices (1s and 0s) of shape " + "[samples, classes].");
      }
    }
    if (keyLosses.indexOf(loss) !== -1) {
      // Compare everything except the batch dimension.
      var slicedYShape = y.shape.slice(1);
      var slicedShape = shape.slice(1);
      for (var j = 0; j < slicedYShape.length; ++j) {
        var targetDim = slicedYShape[j];
        var outDim = slicedShape[j];
        if (outDim != null && targetDim !== outDim) {
          throw new ValueError("A target Tensor with shape ".concat(y.shape, " was passed for an ") + "output of shape ".concat(shape, ", while using a loss function that ") + "expects targets to have the same shape as the output.");
        }
      }
    }
  }
}
/**
 * Check inputs provided by the user.
 *
 * Porting Note: This corresponds to _standardize_input_data() in Python
 * Keras. Because of the strong typing in TF.js, we do not need to convert
 * the data. Specifically:
 *   1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
 *      example. We don't need to worry about that here because there is no
 *      widely popular JavaScript/TypeScript equivalent of pandas (so far).
 *      If one becomes available in the future, we can add support.
 *   2) in PyKeras, inputs can be Python dict. But here we are stipulating
 *      that the data is either a single `tf.Tensor` or an Array of
 *      `tf.Tensor`s. We may add support for `Object` data inputs in the
 *      future when the need arises.
 *
 * Instead, we perform basic checks for number of parameters and shapes.
 * Nothing is returned and the data is not modified.
 *
 * @param data: The input data: a single Tensor or an Array of Tensors.
 * @param names: Names for the inputs, from the model.
 * @param shapes: Expected shapes for the input data, from the model. A null
 *   entry, or a null dimension within an entry, is not checked.
 * @param checkBatchAxis: (optional, default true) Whether the size along the
 *   batch axis (i.e., the first dimension) will be checked for matching.
 * @param exceptionPrefix: (optional, default '') Exception prefix message,
 *   used in generating error messages.
 * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
 */
function checkInputData(data, names, shapes) {
  var checkBatchAxis = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : true;
  var exceptionPrefix = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : '';
  var arrays;
  if (Array.isArray(data)) {
    if (data.length !== names.length) {
      throw new ValueError("Error when checking model ".concat(exceptionPrefix, ": the Array of ") + "Tensors that you are passing to your model is not the size the " + "model expected. Expected to see ".concat(names.length, " Tensor(s),") + " but instead got ".concat(data.length, " Tensor(s)."));
    }
    arrays = data;
  } else {
    // A single Tensor.
    if (names.length > 1) {
      throw new ValueError("The model expects ".concat(names.length, " ").concat(exceptionPrefix, " Tensors, ") + "but only received one Tensor. Found: array with shape " + "".concat(JSON.stringify(data.shape), "."));
    }
    arrays = [data];
  }
  if (shapes != null) {
    for (var i = 0; i < names.length; ++i) {
      if (shapes[i] == null) {
        continue;
      }
      var array = arrays[i];
      if (array.shape.length !== shapes[i].length) {
        throw new ValueError("Error when checking ".concat(exceptionPrefix, ": expected ").concat(names[i], " ") + "to have ".concat(shapes[i].length, " dimension(s), but got array with ") + "shape ".concat(JSON.stringify(array.shape)));
      }
      for (var j = 0; j < shapes[i].length; ++j) {
        if (j === 0 && !checkBatchAxis) {
          // Skip the first (batch) axis.
          continue;
        }
        var dim = array.shape[j];
        var refDim = shapes[i][j];
        if (refDim != null) {
          if (refDim !== dim) {
            throw new ValueError("Error when checking ".concat(exceptionPrefix, ": expected ") + "".concat(names[i], " to have shape ").concat(JSON.stringify(shapes[i]), " but ") + "got array with shape ".concat(JSON.stringify(array.shape), "."));
          }
        }
      }
    }
  }
}
/**
 * Maps metric functions to model outputs.
 * @param metrics A shortcut string name, a metric function, an `Array` of
 *   those, or a dict (`Object`) mapping output names to one metric or an
 *   `Array` of metrics. null/undefined or an empty `Array` means no metrics.
 * @param outputNames An `Array` of the names of model outputs.
 * @returns An `Array` (one entry per model output) of `Array` of metric
 *   functions. For instance, if the model has 2 outputs, and for the first
 *   output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
 *   and just `binaryAccuracy` for the second output, the `Array` would look
 *   like:
 *     `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
 *   When `metrics` is a string, function or `Array`, every output shares
 *   the same inner `Array` object. Outputs missing from a dict get `[]`.
 * @throws TypeError: incompatible metrics format.
 */
function collectMetrics(metrics, outputNames) {
  if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
    return outputNames.map(function () {
      return [];
    });
  }
  var wrappedMetrics;
  if (typeof metrics === 'string' || typeof metrics === 'function') {
    wrappedMetrics = [metrics];
  } else if (Array.isArray(metrics) || typeof metrics === 'object') {
    wrappedMetrics = metrics;
  } else {
    throw new TypeError('Type of metrics argument not understood. Expected a string, ' + "function, Array, or Object, found: ".concat(metrics));
  }
  if (Array.isArray(wrappedMetrics)) {
    // We then apply all metrics to all outputs.
    return outputNames.map(function () {
      return wrappedMetrics;
    });
  }
  // In this case, metrics is a dict keyed by output name.
  var nestedMetrics = [];
  for (var i = 0; i < outputNames.length; ++i) {
    var name = outputNames[i];
    // Call hasOwnProperty through the prototype so that dicts created with
    // Object.create(null) are supported.
    var outputMetrics = Object.prototype.hasOwnProperty.call(wrappedMetrics, name) ? wrappedMetrics[name] : [];
    if (!Array.isArray(outputMetrics)) {
      outputMetrics = [outputMetrics];
    }
    nestedMetrics.push(outputMetrics);
  }
  return nestedMetrics;
}
// Format name string associated with LayersModel artifacts.
// NOTE(review): its use sites are outside this chunk; presumably it is written
// to / compared against the `format` field of serialized models — confirm.
var LAYERS_MODEL_FORMAT_NAME = 'layers-model';
67172 /**
67173 * A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
67174 * for training, evaluation, prediction and saving.
67175 *
67176 * `tf.LayersModel` is the basic unit of training, inference and evaluation in
67177 * TensorFlow.js. To create a `tf.LayersModel`, use `tf.model()`.
67178 *
67179 * See also:
67180 * `tf.Sequential`, `tf.loadLayersModel`.
67181 *
67182 * @doc {heading: 'Models', subheading: 'Classes'}
67183 */
67184 var LayersModel = /*#__PURE__*/function (_Container) {
67185 _inherits(LayersModel, _Container);
67186 var _super = _createSuper(LayersModel);
function LayersModel(args) {
  // Guard against calling the transpiled class without `new`.
  _classCallCheck(this, LayersModel);
  // Run the parent (Container) constructor; its return value is the
  // instance that this constructor finishes and returns.
  var model = _super.call(this, args);
  // A newly constructed model starts with training mode off.
  model.isTraining = false;
  return model;
}
67194 /**
67195 * Print a text summary of the model's layers.
67196 *
67197 * The summary includes
67198 * - Name and type of all layers that comprise the model.
67199 * - Output shape(s) of the layers
67200 * - Number of weight parameters of each layer
67201 * - If the model has non-sequential-like topology, the inputs each layer
67202 * receives
67203 * - The total number of trainable and non-trainable parameters of the model.
67204 *
67205 * ```js
67206 * const input1 = tf.input({shape: [10]});
67207 * const input2 = tf.input({shape: [20]});
67208 * const dense1 = tf.layers.dense({units: 4}).apply(input1);
67209 * const dense2 = tf.layers.dense({units: 8}).apply(input2);
67210 * const concat = tf.layers.concatenate().apply([dense1, dense2]);
67211 * const output =
67212 * tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
67213 *
67214 * const model = tf.model({inputs: [input1, input2], outputs: output});
67215 * model.summary();
67216 * ```
67217 *
67218 * @param lineLength Custom line length, in number of characters.
67219 * @param positions Custom widths of each of the columns, as either
67220 * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
67221 * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
67222 * right-most (i.e., ending) position of a column.
67223 * @param printFn Custom print function. Can be used to replace the default
67224 * `console.log`. For example, you can use `x => {}` to mute the printed
67225 * messages in the console.
67226 *
67227 * @doc {heading: 'Models', subheading: 'Classes'}
67228 */
67229 _createClass(LayersModel, [{
67230 key: "summary",
67231 value: function summary(lineLength, positions) {
67232 var printFn = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : console.log;
67233 if (!this.built) {
67234 throw new ValueError("This model has never been called, thus its weights have not been " + "created yet. So no summary can be displayed. Build the model " + "first (e.g., by calling it on some test data).");
67235 }
67236 printSummary(this, lineLength, positions, printFn);
67237 }
67238 /**
67239 * Configures and prepares the model for training and evaluation. Compiling
67240 * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
67241 * or `evaluate` on an un-compiled model will throw an error.
67242 *
67243 * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
67244 * metrics to be used for fitting and evaluating this model.
67245 *
67246 * @doc {heading: 'Models', subheading: 'Classes'}
67247 */
67248 }, {
67249 key: "compile",
value: function compile(args) {
  // Sets up the optimizer, per-output loss functions, the "feed" bookkeeping
  // arrays, and metric names/functions (metricsNames / metricsTensors), then
  // snapshots the trainable weights for checkTrainableWeightsConsistency().
  var _this2 = this;
  // NOTE: a missing loss is replaced by [] on the caller's `args` object.
  if (args.loss == null) {
    args.loss = [];
  }
  this.loss = args.loss;
  // String optimizers are resolved by getOptimizer() and marked as owned by
  // this model; user-supplied Optimizer instances are marked as not owned.
  // NOTE(review): presumably ownership decides who disposes the optimizer —
  // confirm at the dispose site.
  if (typeof args.optimizer === 'string') {
    this.optimizer_ = getOptimizer(args.optimizer);
    this.isOptimizerOwned = true;
  } else {
    if (!(args.optimizer instanceof Optimizer)) {
      throw new ValueError("User-defined optimizer must be an instance of tf.Optimizer.");
    }
    this.optimizer_ = args.optimizer;
    this.isOptimizerOwned = false;
  }
  // TODO(cais): Add lossWeights.
  // TODO(cais): Add sampleWeightMode.
  // Prepare loss functions (one entry per model output, in output order).
  var lossFunctions = [];
  if (!Array.isArray(args.loss) && typeof args.loss !== 'string' && typeof args.loss !== 'function') {
    // Dict of losses keyed by output name: every key must be an output name.
    args.loss = args.loss;
    for (var name in args.loss) {
      if (this.outputNames.indexOf(name) === -1) {
        throw new ValueError("Unknown entry in loss dictionary: \"".concat(name, "\". ") + "Only expected the following keys: ".concat(this.outputNames));
      }
    }
    var _iterator3 = _createForOfIteratorHelper(this.outputNames),
      _step3;
    try {
      for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
        var _name = _step3.value;
        // Missing outputs only warn; get$1 is still called with the
        // (null/undefined) entry — what it returns for that is not visible here.
        if (args.loss[_name] == null) {
          console.warn("Output \"".concat(_name, "\" is missing from loss dictionary. We assume ") + "this was done on purpose, and we will not be expecting data " + "to be passed to ".concat(_name, " during training"));
        }
        lossFunctions.push(get$1(args.loss[_name]));
      }
    } catch (err) {
      _iterator3.e(err);
    } finally {
      _iterator3.f();
    }
  } else if (Array.isArray(args.loss)) {
    // Array of losses: exactly one per output.
    if (args.loss.length !== this.outputs.length) {
      throw new ValueError("When passing an Array as loss, it should have one entry per " + "model output. The model has ".concat(this.outputs.length, " output(s), ") + "but you passed loss=".concat(args.loss, "."));
    }
    var theLosses = args.loss;
    lossFunctions = theLosses.map(function (l) {
      return get$1(l);
    });
  } else {
    // A single loss (string or function) shared by every output.
    var lossFunction = get$1(args.loss);
    this.outputs.forEach(function (_) {
      lossFunctions.push(lossFunction);
    });
  }
  this.lossFunctions = lossFunctions;
  // Per-output names, shapes and loss functions; currently every output is fed.
  this.feedOutputNames = [];
  this.feedOutputShapes = [];
  this.feedLossFns = [];
  for (var i = 0; i < this.outputs.length; ++i) {
    // TODO(cais): Logic for skipping target(s).
    var shape = this.internalOutputShapes[i];
    var _name2 = this.outputNames[i];
    this.feedOutputNames.push(_name2);
    this.feedOutputShapes.push(shape);
    this.feedLossFns.push(this.lossFunctions[i]);
  }
  // TODO(cais): Add logic for output masks.
  // TODO(cais): Add logic for sample weights.
  // Always empty for now, so no output is skipped in the loops below.
  var skipTargetIndices = [];
  // Prepare metrics.
  this.metrics = args.metrics;
  // TODO(cais): Add weightedMetrics.
  // metricsNames[0] is always 'loss'; metricsTensors holds
  // [function, outputIndex] pairs for the remaining names.
  this.metricsNames = ['loss'];
  this.metricsTensors = [];
  // Compute total loss.
  // Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
  // Here, metricsTensors are TypeScript functions. This difference is due
  // to the difference in symbolic/imperative property of the backends.
  nameScope('loss', function () {
    for (var _i = 0; _i < _this2.outputs.length; ++_i) {
      if (skipTargetIndices.indexOf(_i) !== -1) {
        continue;
      }
      // TODO(cais): Add weightedLoss, sampleWeight and mask.
      // The following line should be weightedLoss
      var weightedLoss = _this2.lossFunctions[_i];
      // Per-output losses are only reported separately for multi-output models,
      // under the name '<outputName>_loss'.
      if (_this2.outputs.length > 1) {
        _this2.metricsTensors.push([weightedLoss, _i]);
        _this2.metricsNames.push(_this2.outputNames[_i] + '_loss');
      }
    }
    // Porting Note: Due to the imperative nature of the backend, we calculate
    // the regularizer penalties in the totalLossFunction, instead of here.
  });

  var nestedMetrics = collectMetrics(args.metrics, this.outputNames);
  // TODO(cais): Add nestedWeightedMetrics.
  /**
   * Records one metric for output `outputIndex`. For multi-output models the
   * name is prefixed with '<outputName>_'; then the name is appended to
   * metricsNames and [metricTensor, outputIndex] to metricsTensors.
   */
  var appendMetric = function appendMetric(outputIndex, metricName, metricTensor) {
    if (_this2.outputNames.length > 1) {
      metricName = _this2.outputNames[outputIndex] + '_' + metricName;
    }
    _this2.metricsNames.push(metricName);
    _this2.metricsTensors.push([metricTensor, outputIndex]);
  };
  nameScope('metric', function () {
    // Transpiled per-output loop body; returning "continue" skips the output.
    var _loop = function _loop(_i2) {
      if (skipTargetIndices.indexOf(_i2) !== -1) {
        return "continue";
      }
      var outputMetrics = nestedMetrics[_i2];
      // TODO(cais): Add weights and outputWeightedMetrics.
      // TODO(cais): Add optional arg `weights` to the following function.
      var handleMetrics = function handleMetrics(metrics) {
        var metricNamePrefix = '';
        var metricName;
        // NOTE: accFn is shared across iterations of _loop2 below; every
        // shortcut string reassigns it before it is read.
        var accFn;
        var weightedMetricFn;
        // TODO(cais): Use 'weights_' for weighted metrics.
        var _iterator4 = _createForOfIteratorHelper(metrics),
          _step4;
        try {
          var _loop2 = function _loop2() {
            var metric = _step4.value;
            // Shortcut strings are resolved according to the output/loss:
            // binary variants if the output's last dimension is 1 or the loss
            // is binaryCrossentropy$2; sparse variants if the loss is
            // sparseCategoricalCrossentropy$1; categorical variants otherwise.
            if (typeof metric === 'string' && ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !== -1) {
              var outputShape = _this2.internalOutputShapes[_i2];
              if (outputShape[outputShape.length - 1] === 1 || _this2.lossFunctions[_i2] === binaryCrossentropy$2) {
                // case: binary accuracy/crossentropy.
                if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                  accFn = binaryAccuracy$1;
                } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                  accFn = binaryCrossentropy$1;
                }
              } else if (_this2.lossFunctions[_i2] === sparseCategoricalCrossentropy$1) {
                // case: categorical accuracy / crossentropy with sparse
                // targets.
                if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                  accFn = sparseCategoricalAccuracy$1;
                } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                  accFn = sparseCategoricalCrossentropy;
                }
              } else {
                // case: categorical accuracy / crossentropy.
                if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                  accFn = categoricalAccuracy$1;
                } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                  accFn = categoricalCrossentropy$1;
                }
              }
              // Shortcut metrics are reported as 'acc' or 'ce'.
              var suffix;
              if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                suffix = 'acc';
              } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                suffix = 'ce';
              }
              // TODO(cais): Add weighting actually.
              weightedMetricFn = accFn;
              metricName = metricNamePrefix + suffix;
            } else {
              // Any other metric is resolved via get() and named via
              // getLossOrMetricName().
              var metricFn = get(metric);
              // TODO(cais): Add weighting actually.
              weightedMetricFn = metricFn;
              metricName = metricNamePrefix + getLossOrMetricName(metric);
            }
            // TODO(cais): Add weighting and masking to metricResult.
            // The metric "result" is the function itself (see Porting Note).
            var metricResult;
            nameScope(metricName, function () {
              metricResult = weightedMetricFn;
            });
            appendMetric(_i2, metricName, metricResult);
          };
          for (_iterator4.s(); !(_step4 = _iterator4.n()).done;) {
            _loop2();
          }
        } catch (err) {
          _iterator4.e(err);
        } finally {
          _iterator4.f();
        }
      };
      handleMetrics(outputMetrics);
      // TODO(cais): Call handleMetrics with weights.
    };
    for (var _i2 = 0; _i2 < _this2.outputs.length; ++_i2) {
      var _ret = _loop(_i2);
      if (_ret === "continue") continue;
    }
  });
  // Porting Notes: Given the imperative backend of tfjs-core,
  // there is no need for constructing the symbolic graph and placeholders.
  // Snapshot compared later by checkTrainableWeightsConsistency().
  this.collectedTrainableWeights = this.trainableWeights;
}
67446 /**
67447 * Check trainable weights count consistency.
67448 *
67449 * This will raise a warning if `this.trainableWeights` and
67450 * `this.collectedTrainableWeights` are inconsistent (i.e., have different
67451 * numbers of parameters).
67452 * Inconsistency will typically arise when one modifies `model.trainable`
67453 * without calling `model.compile()` again.
67454 */
67455 }, {
67456 key: "checkTrainableWeightsConsistency",
67457 value: function checkTrainableWeightsConsistency() {
67458 if (this.collectedTrainableWeights == null) {
67459 return;
67460 }
67461 if (this.trainableWeights.length !== this.collectedTrainableWeights.length) {
67462 console.warn('Discrepancy between trainableweights and collected trainable ' + 'weights. Did you set `model.trainable` without calling ' + '`model.compile()` afterwards?');
67463 }
67464 }
67465 /**
67466 * Returns the loss value & metrics values for the model in test mode.
67467 *
67468 * Loss and metrics are specified during `compile()`, which needs to happen
67469 * before calls to `evaluate()`.
67470 *
67471 * Computation is done in batches.
67472 *
67473 * ```js
67474 * const model = tf.sequential({
67475 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
67476 * });
67477 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
67478 * const result = model.evaluate(
67479 * tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
67480 * result.print();
67481 * ```
67482 *
67483 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
67484 * model has multiple inputs.
67485 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
67486 * model has multiple outputs.
67487 * @param args A `ModelEvaluateArgs`, containing optional fields.
67488 *
67489 * @return `Scalar` test loss (if the model has a single output and no
67490 * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
67491 * and/or metrics). The attribute `model.metricsNames`
67492 * will give you the display labels for the scalar outputs.
67493 *
67494 * @doc {heading: 'Models', subheading: 'Classes'}
67495 */
67496 }, {
67497 key: "evaluate",
67498 value: function evaluate(x, y) {
67499 var args = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
67500 var batchSize = args.batchSize == null ? 32 : args.batchSize;
67501 checkBatchSize(batchSize);
67502 // TODO(cais): Standardize `config.sampleWeights` as well.
67503 // Validate user data.
67504 var checkBatchAxis = true;
67505 var standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
67506 try {
67507 // TODO(cais): If uses `useLearningPhase`, set the corresponding element
67508 // of the input to 0.
67509 var ins = standardizedOuts[0].concat(standardizedOuts[1]);
67510 this.makeTestFunction();
67511 var f = this.testFunction;
67512 var testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
67513 return singletonOrArray(testOuts);
67514 } finally {
67515 disposeNewTensors(standardizedOuts[0], x);
67516 disposeNewTensors(standardizedOuts[1], y);
67517 }
67518 }
67519 // TODO(cais): Add code snippet below once real dataset objects are
67520 // available.
67521 /**
67522 * Evaluate model using a dataset object.
67523 *
67524 * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
67525 *
67526 * @param dataset A dataset object. Its `iterator()` method is expected
67527 * to generate a dataset iterator object, the `next()` method of which
67528 * is expected to produce data batches for evaluation. The return value
67529 * of the `next()` call ought to contain a boolean `done` field and a
67530 * `value` field. The `value` field is expected to be an array of two
67531 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
67532 * case is for models with exactly one input and one output (e.g.
67533 * a sequential model). The latter case is for models with multiple
67534 * inputs and/or multiple outputs. Of the two items in the array, the
67535 * first is the input feature(s) and the second is the output target(s).
67536 * @param args A configuration object for the dataset-based evaluation.
67537 * @returns Loss and metric values as an Array of `Scalar` objects.
67538 *
67539 * @doc {heading: 'Models', subheading: 'Classes'}
67540 */
67541 }, {
key: "evaluateDataset",
// Transpiled form of `async evaluateDataset(dataset, args)`: an IIFE that
// builds the async implementation once and returns a named wrapper method.
value: function () {
  var _evaluateDataset2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(dataset, args) {
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          // (Re)build `this.testFunction` before delegating; presumably the
          // module-level helper uses it — not visible here, confirm.
          this.makeTestFunction();
          // Resolve the promise with whatever the module-level
          // `evaluateDataset(model, dataset, args)` helper returns.
          return _context.abrupt("return", evaluateDataset(this, dataset, args));
        case 2:
        case "end":
          return _context.stop();
      }
    }, _callee, this);
  }));
  // Method exposed on the prototype; forwards `this` and all arguments to the
  // async implementation above.
  function evaluateDataset$1(_x, _x2) {
    return _evaluateDataset2.apply(this, arguments);
  }
  return evaluateDataset$1;
}()
67561 /**
67562 * Get number of samples provided for training, evaluation or prediction.
67563 *
67564 * @param ins Input `tf.Tensor`.
67565 * @param batchSize Integer batch size, optional.
67566 * @param steps Total number of steps (batches of samples) before
67567 * declaring loop finished. Optional.
67568 * @param stepsName The public API's parameter name for `steps`.
67569 * @returns Number of samples provided.
67570 */
67571 }, {
67572 key: "checkNumSamples",
67573 value: function checkNumSamples(ins, batchSize, steps) {
67574 var stepsName = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'steps';
67575 var numSamples;
67576 if (steps != null) {
67577 numSamples = null;
67578 if (batchSize != null) {
67579 throw new ValueError("If ".concat(stepsName, " is set, batchSize must be null or undefined.") + "Got batchSize = ".concat(batchSize));
67580 }
67581 } else if (ins != null) {
67582 if (Array.isArray(ins)) {
67583 numSamples = ins[0].shape[0];
67584 } else {
67585 numSamples = ins.shape[0];
67586 }
67587 } else {
67588 throw new ValueError("Either the input data should have a defined shape, or " + "".concat(stepsName, " shoud be specified."));
67589 }
67590 return numSamples;
67591 }
67592 /**
67593 * Execute internal tensors of the model with input data feed.
67594 * @param inputs Input data feed. Must match the inputs of the model.
67595 * @param outputs Names of the output tensors to be fetched. Must match
67596 * names of the SymbolicTensors that belong to the graph.
67597 * @returns Fetched values for `outputs`.
67598 */
67599 }, {
67600 key: "execute",
67601 value: function execute$1(inputs, outputs) {
67602 if (Array.isArray(outputs) && outputs.length === 0) {
67603 throw new ValueError('`outputs` is an empty Array, which is not allowed.');
67604 }
67605 var outputsIsArray = Array.isArray(outputs);
67606 var outputNames = outputsIsArray ? outputs : [outputs];
67607 var outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
67608 // Format the input into a FeedDict.
67609 var feedDict = new FeedDict();
67610 if (inputs instanceof Tensor) {
67611 inputs = [inputs];
67612 }
67613 if (Array.isArray(inputs)) {
67614 if (inputs.length !== this.inputs.length) {
67615 throw new ValueError("The number of inputs provided (".concat(inputs.length, ") ") + "does not match the number of inputs of this model " + "(".concat(this.inputs.length, ")."));
67616 }
67617 for (var i = 0; i < this.inputs.length; ++i) {
67618 feedDict.add(this.inputs[i], inputs[i]);
67619 }
67620 } else {
67621 var _iterator5 = _createForOfIteratorHelper(this.inputs),
67622 _step5;
67623 try {
67624 for (_iterator5.s(); !(_step5 = _iterator5.n()).done;) {
67625 var input = _step5.value;
67626 var tensorValue = inputs[input.name];
67627 if (tensorValue == null) {
67628 throw new ValueError("No value is provided for the model's input ".concat(input.name));
67629 }
67630 feedDict.add(input, tensorValue);
67631 }
67632 } catch (err) {
67633 _iterator5.e(err);
67634 } finally {
67635 _iterator5.f();
67636 }
67637 }
67638 // Run execution.
67639 var executeOutputs = execute(outputSymbolicTensors, feedDict);
67640 return outputsIsArray ? executeOutputs : executeOutputs[0];
67641 }
67642 /**
67643 * Retrieve the model's internal symbolic tensors from symbolic-tensor names.
67644 */
67645 }, {
67646 key: "retrieveSymbolicTensors",
67647 value: function retrieveSymbolicTensors(symbolicTensorNames) {
67648 var outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
67649 var outputsRemaining = symbolicTensorNames.length;
67650 var _iterator6 = _createForOfIteratorHelper(this.layers),
67651 _step6;
67652 try {
67653 for (_iterator6.s(); !(_step6 = _iterator6.n()).done;) {
67654 var layer = _step6.value;
67655 var layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
67656 var layerOutputNames = layerOutputs.map(function (output) {
67657 return output.name;
67658 });
67659 for (var i = 0; i < symbolicTensorNames.length; ++i) {
67660 var index = layerOutputNames.indexOf(symbolicTensorNames[i]);
67661 if (index !== -1) {
67662 outputSymbolicTensors[i] = layerOutputs[index];
67663 outputsRemaining--;
67664 }
67665 if (outputsRemaining === 0) {
67666 break;
67667 }
67668 }
67669 if (outputsRemaining === 0) {
67670 break;
67671 }
67672 }
67673 } catch (err) {
67674 _iterator6.e(err);
67675 } finally {
67676 _iterator6.f();
67677 }
67678 if (outputsRemaining > 0) {
67679 var remainingNames = [];
67680 outputSymbolicTensors.forEach(function (tensor, i) {
67681 if (tensor == null) {
67682 remainingNames.push(symbolicTensorNames[i]);
67683 }
67684 });
67685 throw new ValueError("Cannot find SymbolicTensors for output name(s): " + "".concat(JSON.stringify(remainingNames)));
67686 }
67687 return outputSymbolicTensors;
67688 }
67689 /**
67690 * Helper method to loop over some data in batches.
67691 *
67692 * Porting Note: Not using the functional approach in the Python equivalent
67693 * due to the imperative backend.
67694 * Porting Note: Does not support step mode currently.
67695 *
67696 * @param ins: input data
67697 * @param batchSize: integer batch size.
 * @param verbose: verbosity mode
 * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
 * `tf.Tensor` (if multiple outputs).
67701 */
67702 }, {
67703 key: "predictLoop",
67704 value: function predictLoop(ins) {
67705 var _this3 = this;
67706 var batchSize = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 32;
67707 var verbose = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
67708 return tidy(function () {
67709 var numSamples = _this3.checkNumSamples(ins);
67710 if (verbose) {
67711 throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
67712 }
67713 // Sample-based predictions.
67714 // Porting Note: Tensor currently does not support sliced assignments as
67715 // in numpy, e.g., x[1:3] = y. Therefore we use concatenation while
67716 // iterating over the batches.
67717 var batches = makeBatches(numSamples, batchSize);
67718 var outsBatches = _this3.outputs.map(function (output) {
67719 return [];
67720 });
67721 // TODO(cais): Can the scope() be pushed down inside the for loop?
67722 var _loop3 = function _loop3(batchIndex) {
67723 var batchOuts = tidy(function () {
67724 var batchStart = batches[batchIndex][0];
67725 var batchEnd = batches[batchIndex][1];
67726 // TODO(cais): Take care of the case of the last element is a flag for
67727 // training/test.
67728 var insBatch = sliceArrays(ins, batchStart, batchEnd);
67729 // Construct the feeds for execute();
67730 var feeds = [];
67731 if (Array.isArray(insBatch)) {
67732 for (var i = 0; i < insBatch.length; ++i) {
67733 feeds.push({
67734 key: _this3.inputs[i],
67735 value: insBatch[i]
67736 });
67737 }
67738 } else {
67739 feeds.push({
67740 key: _this3.inputs[0],
67741 value: insBatch
67742 });
67743 }
67744 var feedDict = new FeedDict(feeds);
67745 return execute(_this3.outputs, feedDict);
67746 });
67747 batchOuts.forEach(function (batchOut, i) {
67748 return outsBatches[i].push(batchOut);
67749 });
67750 };
67751 for (var batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
67752 _loop3(batchIndex);
67753 }
67754 return singletonOrArray(outsBatches.map(function (batches) {
67755 return concat$2(batches, 0);
67756 }));
67757 });
67758 }
67759 /**
67760 * Generates output predictions for the input samples.
67761 *
67762 * Computation is done in batches.
67763 *
67764 * Note: the "step" mode of predict() is currently not supported.
67765 * This is because the TensorFlow.js core backend is imperative only.
67766 *
67767 * ```js
67768 * const model = tf.sequential({
67769 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
67770 * });
67771 * model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
67772 * ```
67773 *
67774 * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
67775 * the model has multiple inputs.
67776 * @param args A `ModelPredictArgs` object containing optional fields.
67777 *
67778 * @return Prediction results as a `tf.Tensor`(s).
67779 *
67780 * @exception ValueError In case of mismatch between the provided input data
67781 * and the model's expectations, or in case a stateful model receives a
67782 * number of samples that is not a multiple of the batch size.
67783 *
67784 * @doc {heading: 'Models', subheading: 'Classes'}
67785 */
67786 }, {
67787 key: "predict",
67788 value: function predict(x) {
67789 var args = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
67790 var xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
67791 checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
67792 try {
67793 // TODO(cais): Take care of stateful models.
67794 // if (this.stateful) ...
67795 // TODO(cais): Take care of the learning_phase boolean flag.
67796 // if (this.useLearningPhase) ...
67797 var batchSize = args.batchSize == null ? 32 : args.batchSize;
67798 checkBatchSize(batchSize);
67799 return this.predictLoop(xsRank2OrHigher, batchSize);
67800 } finally {
67801 disposeNewTensors(xsRank2OrHigher, x);
67802 }
67803 }
67804 /**
67805 * Returns predictions for a single batch of samples.
67806 *
67807 * ```js
67808 * const model = tf.sequential({
67809 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
67810 * });
67811 * model.predictOnBatch(tf.ones([8, 10])).print();
67812 * ```
67813 * @param x: Input samples, as a Tensor (for models with exactly one
67814 * input) or an array of Tensors (for models with more than one input).
67815 * @return Tensor(s) of predictions
67816 *
67817 * @doc {heading: 'Models', subheading: 'Classes'}
67818 */
67819 }, {
67820 key: "predictOnBatch",
67821 value: function predictOnBatch(x) {
67822 checkInputData(x, this.inputNames, this.feedInputShapes, true);
67823 // TODO(cais): Take care of the learning_phase boolean flag.
67824 // if (this.useLearningPhase) ...
67825 var batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
67826 return this.predictLoop(x, batchSize);
67827 }
67828 }, {
67829 key: "standardizeUserDataXY",
67830 value: function standardizeUserDataXY(x, y) {
67831 var checkBatchAxis = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
67832 var batchSize = arguments.length > 3 ? arguments[3] : undefined;
67833 // TODO(cais): Add sampleWeight, classWeight
67834 if (this.optimizer_ == null) {
67835 throw new RuntimeError('You must compile a model before training/testing. Use ' + 'LayersModel.compile(modelCompileArgs).');
67836 }
67837 var outputShapes = [];
67838 for (var i = 0; i < this.feedOutputShapes.length; ++i) {
67839 var outputShape = this.feedOutputShapes[i];
67840 var lossFn = this.feedLossFns[i];
67841 if (lossFn === sparseCategoricalCrossentropy$1) {
67842 outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
67843 } else {
67844 // Porting Note: Because of strong typing `lossFn` must be a function.
67845 outputShapes.push(outputShape);
67846 }
67847 }
67848 x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
67849 y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
67850 // TODO(cais): Standardize sampleWeights & classWeights.
67851 checkArrayLengths(x, y, null);
67852 // TODO(cais): Check sampleWeights as well.
67853 checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
67854 if (this.stateful && batchSize != null && batchSize > 0) {
67855 if (x[0].shape[0] % batchSize !== 0) {
67856 throw new ValueError("In a stateful network, you should only pass inputs with a " + "number of samples that is divisible by the batch size " + "".concat(batchSize, ". Found: ").concat(x[0].shape[0], " sample(s)."));
67857 }
67858 }
67859 return [x, y];
67860 }
67861 }, {
key: "standardizeUserData",
// Transpiled form of
// `async standardizeUserData(x, y, sampleWeight, classWeight,
//                            checkBatchAxis = true, batchSize)`.
// Resolves to [standardXs, standardYs, standardSampleWeights], where the last
// element is null unless `classWeight` is given.
value: function () {
  var _standardizeUserData = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(x, y, sampleWeight, classWeight) {
    var checkBatchAxis,
      batchSize,
      _this$standardizeUser,
      _this$standardizeUser2,
      standardXs,
      standardYs,
      standardSampleWeights,
      classWeights,
      i,
      _args2 = arguments;
    return _regeneratorRuntime().wrap(function _callee2$(_context2) {
      while (1) switch (_context2.prev = _context2.next) {
        case 0:
          // Optional positional args 5 and 6 (checkBatchAxis, batchSize).
          checkBatchAxis = _args2.length > 4 && _args2[4] !== undefined ? _args2[4] : true;
          batchSize = _args2.length > 5 ? _args2[5] : undefined;
          // Destructure [standardXs, standardYs] from the synchronous helper.
          _this$standardizeUser = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize), _this$standardizeUser2 = _slicedToArray(_this$standardizeUser, 2), standardXs = _this$standardizeUser2[0], standardYs = _this$standardizeUser2[1]; // TODO(cais): Handle sampleWeights.
          if (!(sampleWeight != null)) {
            _context2.next = 5;
            break;
          }
          // NOTE(review): throwing here does not dispose any tensors that
          // standardizeUserDataXY may have created — confirm whether that
          // can leak.
          throw new Error('sample weight is not supported yet.');
        case 5:
          standardSampleWeights = null;
          if (!(classWeight != null)) {
            _context2.next = 19;
            break;
          }
          // Convert class weights into one per-output sample-weight entry.
          classWeights = standardizeClassWeights(classWeight, this.outputNames);
          standardSampleWeights = [];
          i = 0;
        case 10:
          // Loop head: `for (i = 0; i < classWeights.length; ++i)`.
          if (!(i < classWeights.length)) {
            _context2.next = 19;
            break;
          }
          _context2.t0 = standardSampleWeights;
          _context2.next = 14;
          // Awaited: the result is pushed in case 14.
          return standardizeWeights(standardYs[i], null, classWeights[i]);
        case 14:
          _context2.t1 = _context2.sent;
          _context2.t0.push.call(_context2.t0, _context2.t1);
        case 16:
          ++i;
          _context2.next = 10;
          break;
        case 19:
          return _context2.abrupt("return", [standardXs, standardYs, standardSampleWeights]);
        case 20:
        case "end":
          return _context2.stop();
      }
    }, _callee2, this);
  }));
  // Prototype method; forwards `this` and all arguments (including the
  // optional 5th/6th) to the async implementation.
  function standardizeUserData(_x3, _x4, _x5, _x6) {
    return _standardizeUserData.apply(this, arguments);
  }
  return standardizeUserData;
}()
67923 /**
67924 * Loop over some test data in batches.
67925 * @param f A Function returning a list of tensors.
67926 * @param ins Array of tensors to be fed to `f`.
67927 * @param batchSize Integer batch size or `null` / `undefined`.
67928 * @param verbose verbosity mode.
67929 * @param steps Total number of steps (batches of samples) before
67930 * declaring test finished. Ignored with the default value of `null` /
67931 * `undefined`.
67932 * @returns Array of Scalars.
67933 */
67934 }, {
67935 key: "testLoop",
67936 value: function testLoop(f, ins, batchSize) {
67937 var _this4 = this;
67938 var verbose = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 0;
67939 var steps = arguments.length > 4 ? arguments[4] : undefined;
67940 return tidy(function () {
67941 var numSamples = _this4.checkNumSamples(ins, batchSize, steps, 'steps');
67942 var outs = [];
67943 if (verbose > 0) {
67944 throw new NotImplementedError('Verbose mode is not implemented yet.');
67945 }
67946 // TODO(cais): Use `indicesForConversionToDense' to prevent slow down.
67947 if (steps != null) {
67948 throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
67949 } else {
67950 var batches = makeBatches(numSamples, batchSize);
67951 var indexArray = tensor1d(range$2(0, numSamples));
67952 for (var batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
67953 var batchStart = batches[batchIndex][0];
67954 var batchEnd = batches[batchIndex][1];
67955 var batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
67956 // TODO(cais): In ins, train flag can be a number, instead of an
67957 // Tensor? Do we need to handle this in tfjs-layers?
67958 var insBatch = sliceArraysByIndices(ins, batchIds);
67959 var batchOuts = f(insBatch);
67960 if (batchIndex === 0) {
67961 for (var i = 0; i < batchOuts.length; ++i) {
67962 outs.push(scalar(0));
67963 }
67964 }
67965 for (var _i3 = 0; _i3 < batchOuts.length; ++_i3) {
67966 var batchOut = batchOuts[_i3];
67967 outs[_i3] = add$3(outs[_i3], mul(batchEnd - batchStart, batchOut));
67968 }
67969 }
67970 for (var _i4 = 0; _i4 < outs.length; ++_i4) {
67971 outs[_i4] = div$1(outs[_i4], numSamples);
67972 }
67973 }
67974 return outs;
67975 });
67976 }
67977 }, {
67978 key: "getDedupedMetricsNames",
67979 value: function getDedupedMetricsNames() {
67980 var outLabels = this.metricsNames;
67981 // Rename duplicated metrics names (can happen with an output layer
67982 // shared among multiple dataflows).
67983 var dedupedOutLabels = [];
67984 for (var i = 0; i < outLabels.length; ++i) {
67985 var label = outLabels[i];
67986 var newLabel = label;
67987 if (count(outLabels, label) > 1) {
67988 var dupIndex = count(outLabels.slice(0, i), label);
67989 newLabel += "_".concat(dupIndex);
67990 }
67991 dedupedOutLabels.push(newLabel);
67992 }
67993 return dedupedOutLabels;
67994 }
67995 /**
67996 * Creates a function that performs the following actions:
67997 *
67998 * 1. computes the losses
67999 * 2. sums them to get the total loss
 * 3. calls the optimizer to compute the gradients of the LayersModel's
 * trainable weights w.r.t. the total loss and update the variables
68002 * 4. calculates the metrics
68003 * 5. returns the values of the losses and metrics.
68004 */
68005 }, {
key: "makeTrainFunction",
// Returns a closure `(data) => [totalLossValue, ...metricsValues]` that runs one
// optimizer step. `data` is a flat Array laid out as
// [inputs..., targets..., sampleWeights...] (see the slices below).
value: function makeTrainFunction() {
  var _this5 = this;
  return function (data) {
    var lossValues = [];
    // Split the flat `data` array by position.
    var inputs = data.slice(0, _this5.inputs.length);
    var targets = data.slice(_this5.inputs.length, _this5.inputs.length + _this5.outputs.length);
    var sampleWeights = data.slice(_this5.inputs.length + _this5.outputs.length, _this5.inputs.length + _this5.outputs.length * 2);
    var metricsValues = [];
    // Create a function that computes the total loss based on the
    // inputs. This function is used for obtaining gradients through
    // backprop.
    var totalLossFunction = function totalLossFunction() {
      var feeds = [];
      for (var i = 0; i < _this5.inputs.length; ++i) {
        feeds.push({
          key: _this5.inputs[i],
          value: inputs[i]
        });
      }
      var feedDict = new FeedDict(feeds);
      // Forward pass in training mode.
      var outputs = execute(_this5.outputs, feedDict, {
        'training': true
      });
      // TODO(cais): Take care of the case of multiple outputs from a
      // single layer?
      var totalLoss;
      for (var _i5 = 0; _i5 < _this5.lossFunctions.length; ++_i5) {
        var lossFunction = _this5.lossFunctions[_i5];
        var loss = lossFunction(targets[_i5], outputs[_i5]);
        if (sampleWeights[_i5] != null) {
          loss = computeWeightedLoss(loss, sampleWeights[_i5]);
        }
        // TODO(cais): push Scalar instead.
        // Per-output mean loss, reused below as a metric for multi-output
        // models.
        var meanLoss = mean$3(loss);
        // TODO(cais): Use a scope() instead, to avoid ownership.
        lossValues.push(meanLoss);
        // The unreduced losses are summed; the mean is taken once below.
        if (_i5 === 0) {
          totalLoss = loss;
        } else {
          totalLoss = add$3(totalLoss, loss);
        }
      }
      // Compute the metrics.
      // TODO(cais): These should probably be calculated outside
      // totalLossFunction to benefit speed?
      for (var _i6 = 0; _i6 < _this5.metricsTensors.length; ++_i6) {
        var weightedMetric = void 0;
        // For multi-output models the first `outputs.length` metric slots
        // reuse the per-output losses computed above (presumably compile()
        // registers one loss entry per output first — confirm there).
        if (_this5.outputs.length > 1 && _i6 < _this5.outputs.length) {
          weightedMetric = lossValues[_i6];
        } else {
          var metric = _this5.metricsTensors[_i6][0];
          var outputIndex = _this5.metricsTensors[_i6][1];
          weightedMetric = mean$3(metric(targets[outputIndex], outputs[outputIndex]));
        }
        // keep() so the metric tensor survives the scope in which minimize()
        // presumably runs this function (not visible here — confirm).
        keep(weightedMetric);
        // TODO(cais): Use a scope() instead, to avoid ownership.
        metricsValues.push(weightedMetric);
      }
      totalLoss = mean$3(totalLoss);
      // Add regularizer penalties.
      _this5.calculateLosses().forEach(function (regularizerLoss) {
        totalLoss = add$3(totalLoss, regularizerLoss);
      });
      return totalLoss;
    };
    // Only the weights collected at compile time are optimized.
    var variables = _this5.collectedTrainableWeights.map(function (param) {
      return param.read();
    });
    var returnCost = true;
    var totalLossValue = _this5.optimizer_.minimize(totalLossFunction, returnCost, variables);
    // Layout: [total loss, ...metrics], labelled by getDedupedMetricsNames().
    return [totalLossValue].concat(metricsValues);
  };
}
68080 /**
68081 * Create a function which, when invoked with an array of `tf.Tensor`s as a
68082 * batch of inputs, returns the prespecified loss and metrics of the model
68083 * under the batch of input data.
68084 */
68085 }, {
68086 key: "makeTestFunction",
68087 value: function makeTestFunction() {
68088 var _this6 = this;
68089 this.testFunction = function (data) {
68090 return tidy(function () {
68091 var valOutputs = [];
68092 var totalLoss;
68093 var inputs = data.slice(0, _this6.inputs.length);
68094 var targets = data.slice(_this6.inputs.length, _this6.inputs.length + _this6.outputs.length);
68095 var feeds = [];
68096 for (var i = 0; i < _this6.inputs.length; ++i) {
68097 feeds.push({
68098 key: _this6.inputs[i],
68099 value: inputs[i]
68100 });
68101 }
68102 var feedDict = new FeedDict(feeds);
68103 var outputs = execute(_this6.outputs, feedDict);
68104 // Compute total loss.
68105 for (var _i7 = 0; _i7 < _this6.lossFunctions.length; ++_i7) {
68106 var lossFunction = _this6.lossFunctions[_i7];
68107 // TODO(cais): Add sample weighting and replace the simple
68108 // averaging.
68109 var loss = mean$3(lossFunction(targets[_i7], outputs[_i7]));
68110 if (_i7 === 0) {
68111 totalLoss = loss;
68112 } else {
68113 totalLoss = add$3(totalLoss, loss);
68114 }
68115 valOutputs.push(totalLoss);
68116 }
68117 // Compute the metrics.
68118 for (var _i8 = 0; _i8 < _this6.metricsTensors.length; ++_i8) {
68119 var metric = _this6.metricsTensors[_i8][0];
68120 var outputIndex = _this6.metricsTensors[_i8][1];
68121 // TODO(cais): Replace K.mean() with a proper weighting function.
68122 var meanMetric = mean$3(metric(targets[outputIndex], outputs[outputIndex]));
68123 valOutputs.push(meanMetric);
68124 }
68125 return valOutputs;
68126 });
68127 };
68128 }
68129 /**
68130 * Trains the model for a fixed number of epochs (iterations on a
68131 * dataset).
68132 *
68133 * ```js
68134 * const model = tf.sequential({
68135 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
68136 * });
68137 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
68138 * for (let i = 1; i < 5 ; ++i) {
68139 * const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
68140 * batchSize: 4,
68141 * epochs: 3
68142 * });
68143 * console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
68144 * }
68145 * ```
68146 *
68147 * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
68148 * model has multiple inputs. If all inputs in the model are named, you
68149 * can also pass a dictionary mapping input names to `tf.Tensor`s.
68150 * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
68151 * the model has multiple outputs. If all outputs in the model are named,
68152 * you can also pass a dictionary mapping output names to `tf.Tensor`s.
68153 * @param args A `ModelFitArgs`, containing optional fields.
68154 *
68155 * @return A `History` instance. Its `history` attribute contains all
68156 * information collected during training.
68157 *
68158 * @exception ValueError In case of mismatch between the provided input
68159 * data and what the model expects.
68160 *
68161 * @doc {heading: 'Models', subheading: 'Classes'}
68162 */
68163 }, {
key: "fit",
// Transpiled form of `async fit(x, y, args = {})`. Standardizes training data,
// prepares optional validation data (explicit `validationData` or a
// `validationSplit` tail slice), builds the train/test functions and delegates
// to fitLoop(). `this.isTraining` guards against concurrent fit() calls, and
// the `finally` section (case 47) always resets it and disposes tensors
// created here.
value: function () {
  var _fit = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(x, y) {
    var args,
      inputs,
      targets,
      originalInputs,
      originalTargets,
      inputValX,
      inputValY,
      valX,
      valY,
      sampleWeights,
      batchSize,
      checkBatchAxis,
      standardizedOuts,
      doValidation,
      valIns,
      _checkBatchAxis,
      valStandardized,
      splitAt,
      originalBatchSize,
      ins,
      trainFunction,
      outLabels,
      valFunction,
      callbackMetrics,
      callbacks,
      out,
      _args3 = arguments;
    return _regeneratorRuntime().wrap(function _callee3$(_context3) {
      while (1) switch (_context3.prev = _context3.next) {
        case 0:
          args = _args3.length > 2 && _args3[2] !== undefined ? _args3[2] : {};
          // Reject re-entrant calls; this check is outside the try/finally,
          // so the flag of the ongoing call is left untouched.
          if (!this.isTraining) {
            _context3.next = 3;
            break;
          }
          throw new Error('Cannot start training because another fit() call is ongoing.');
        case 3:
          this.isTraining = true;
          // Enter the try region (finally handler at case 47).
          _context3.prev = 4;
          batchSize = args.batchSize == null ? 32 : args.batchSize;
          checkBatchSize(batchSize);
          // Validate user data.
          // TODO(cais): Support sampleWeight.
          checkBatchAxis = false;
          _context3.next = 10;
          return this.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize);
        case 10:
          standardizedOuts = _context3.sent;
          inputs = standardizedOuts[0];
          targets = standardizedOuts[1];
          sampleWeights = standardizedOuts[2];
          // Prepare validation data.
          doValidation = false;
          if (!(args.validationData != null && args.validationData.length > 0)) {
            _context3.next = 36;
            break;
          }
          doValidation = true;
          if (!(args.validationData.length === 2)) {
            _context3.next = 22;
            break;
          }
          // config.validationData consists of valX and valY.
          inputValX = args.validationData[0];
          inputValY = args.validationData[1];
          _context3.next = 27;
          break;
        case 22:
          if (!(args.validationData.length === 3)) {
            _context3.next = 26;
            break;
          }
          throw new NotImplementedError('validationData including sample weights is not supported yet.');
        case 26:
          throw new ValueError("When passing validation data, it must contain 2 (valX, valY) " + "or 3 (valX, valY, valSampleWeight) items; " + "".concat(args.validationData, " is invalid."));
        case 27:
          _checkBatchAxis = true;
          _context3.next = 30;
          return this.standardizeUserData(inputValX, inputValY, /** Unused sample weights. */null, /** Unused class weights. */null, _checkBatchAxis, batchSize);
        case 30:
          valStandardized = _context3.sent;
          valX = valStandardized[0];
          valY = valStandardized[1];
          valIns = valX.concat(valY);
          // TODO(cais): Add useLearningPhase data properly.
          _context3.next = 37;
          break;
        case 36:
          // No explicit validationData: optionally hold out the trailing
          // `validationSplit` fraction of the samples. Falls through to 37.
          if (args.validationSplit != null && args.validationSplit > 0 && args.validationSplit < 1) {
            doValidation = true;
            // Porting Note: In tfjs-layers, inputs[0] is always a Tensor.
            splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
            originalBatchSize = inputs[0].shape[0];
            valX = sliceArrays(inputs, splitAt, originalBatchSize);
            // Keep the unsliced arrays so the finally block can dispose them.
            originalInputs = inputs;
            inputs = sliceArrays(inputs, 0, splitAt);
            valY = sliceArrays(targets, splitAt, originalBatchSize);
            originalTargets = targets;
            targets = sliceArrays(targets, 0, splitAt);
            // TODO(cais): Once sampleWeights becomes available, slice it to get
            // valSampleWeights.
            valIns = valX.concat(valY);
            // TODO(cais): Add useLearningPhase data properly.
          } else if (args.validationSteps != null) {
            doValidation = true;
            // TODO(cais): Add useLearningPhase.
          }
        case 37:
          // Flat layout consumed by the train function:
          // [inputs..., targets..., sampleWeights...].
          ins = inputs.concat(targets).concat(sampleWeights);
          this.checkTrainableWeightsConsistency();
          // TODO(cais): Handle use_learning_phase and learning_phase?
          // Porting Note: Here we see a key deviation of tfjs-layers from
          // Keras.
          // Due to the imperative nature of tfjs-layers' backend (tfjs-core),
          // we do not construct symbolic computation graphs to embody the
          // training process. Instead, we define a function that performs the
          // training action. In PyKeras, the data (inputs and targets) are fed
          // through graph placeholders. In tfjs-layers, the data are fed as
          // function arguments. Since the functions are defined below in the
          // scope, we don't have equivalents of PyKeras's
          // `_make_train_function`.
          trainFunction = this.makeTrainFunction();
          outLabels = this.getDedupedMetricsNames();
          if (doValidation) {
            this.makeTestFunction();
            valFunction = this.testFunction;
            // Validation metrics reuse the training labels with a 'val_' prefix.
            callbackMetrics = outLabels.slice().concat(outLabels.map(function (n) {
              return 'val_' + n;
            }));
          } else {
            valFunction = null;
            valIns = [];
            callbackMetrics = outLabels.slice();
          }
          callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
          _context3.next = 45;
          return this.fitLoop(trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null);
        case 45:
          out = _context3.sent;
          return _context3.abrupt("return", out);
        case 47:
          // finally: runs on success and on any error thrown after case 4.
          _context3.prev = 47;
          this.isTraining = false;
          // Memory clean up.
          disposeNewTensors(inputs, x);
          disposeNewTensors(targets, y);
          disposeNewTensors(originalInputs, x);
          disposeNewTensors(originalTargets, y);
          disposeNewTensors(valX, inputValX);
          disposeNewTensors(valY, inputValY);
          if (sampleWeights != null) {
            dispose(sampleWeights);
          }
          return _context3.finish(47);
        case 57:
        case "end":
          return _context3.stop();
      }
    // Try-entry table: try starts at 4, no catch, finally at 47, after at 57.
    }, _callee3, this, [[4,, 47, 57]]);
  }));
  // Prototype method; forwards `this` and all arguments (including the
  // optional `args` object) to the async implementation.
  function fit(_x7, _x8) {
    return _fit.apply(this, arguments);
  }
  return fit;
}()
68332 /**
68333 * Abstract fit function for `f(ins)`.
68334 * @param f A Function returning a list of tensors. For training, this
68335 * function is expected to perform the updates to the variables.
68336 * @param ins List of tensors to be fed to `f`.
68337 * @param outLabels List of strings, display names of the outputs of `f`.
68338 * @param batchSize Integer batch size or `== null` if unknown. Default : 32.
68339 * @param epochs Number of times to iterate over the data. Default : 1.
68340 * @param verbose Verbosity mode: 0, 1, or 2. Default: 1.
68341 * @param callbacks List of callbacks to be called during training.
68342 * @param valF Function to call for validation.
68343 * @param valIns List of tensors to be fed to `valF`.
68344 * @param shuffle Whether to shuffle the data at the beginning of every
68345 * epoch. Default : true.
68346 * @param callbackMetrics List of strings, the display names of the metrics
68347 * passed to the callbacks. They should be the concatenation of the
68348 * display names of the outputs of `f` and the list of display names
68349 * of the outputs of `valF`.
68350 * @param initialEpoch Epoch at which to start training (useful for
68351 * resuming a previous training run). Default : 0.
 * @param stepsPerEpoch Total number of steps (batches of samples) before
68353 * declaring one epoch finished and starting the next epoch. Ignored with
68354 * the default value of `undefined` or `null`.
68355 * @param validationSteps Number of steps to run validation for (only if
68356 * doing validation from data tensors). Not applicable for tfjs-layers.
68357 * @returns A `History` object.
68358 */
68359 }, {
// Generic training loop behind fit(). For each epoch it optionally shuffles
// the sample indices and splits them into batches. For each batch it runs `f`
// on that slice of `ins` and reports the outputs of `f` (keyed by
// `outLabels`) to the callbacks.
// Defaults applied below when an argument is null/undefined:
// batchSize = 32, epochs = 1, shuffle = true, initialEpoch = 0, verbose = 1.
// Resolves to `this.history` once `history.syncData()` has completed.
key: "fitLoop",
value: function () {
  // Transpiled async method: `_callee4` is a regenerator state machine; each
  // `case N:` label is a point where execution resumes (e.g. after an await).
  var _fitLoop = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
    var _this7 = this;
    var doValidation, numTrainSamples, indexArray, _configureCallbacks, callbackList, history, _loop4, epoch, _ret2;
    return _regeneratorRuntime().wrap(function _callee4$(_context6) {
      while (1) switch (_context6.prev = _context6.next) {
        case 0:
          // Defaults for omitted arguments.
          if (batchSize == null) {
            batchSize = 32;
          }
          if (epochs == null) {
            epochs = 1;
          }
          if (shuffle$1 == null) {
            shuffle$1 = true;
          }
          if (initialEpoch == null) {
            initialEpoch = 0;
          }
          // TODO(cais): Change const to let below when implementing validation.
          // Validation is enabled when both a validation function and its
          // inputs were supplied.
          doValidation = false;
          if (valF != null && valIns != null) {
            doValidation = true;
            // TODO(cais): verbose message.
          }
          // `validationSteps` also enables validation, but it is only accepted
          // together with `stepsPerEpoch` (step-wise training).
          if (!(validationSteps != null)) {
            _context6.next = 10;
            break;
          }
          doValidation = true;
          if (!(stepsPerEpoch == null)) {
            _context6.next = 10;
            break;
          }
          throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' + 'i.e., `stepsPerEpoch` must be set.');
        case 10:
          // The index array [0, numTrainSamples) is only built when a sample
          // count is available (presumably null in step-wise mode — confirm
          // against checkNumSamples). It is shuffled and sliced per epoch.
          numTrainSamples = this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
          if (numTrainSamples != null) {
            indexArray = range$2(0, numTrainSamples);
          }
          if (verbose == null) {
            verbose = 1;
          }
          _configureCallbacks = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics), callbackList = _configureCallbacks.callbackList, history = _configureCallbacks.history;
          callbackList.setModel(this);
          this.history = history;
          _context6.next = 18;
          return callbackList.onTrainBegin();
        case 18:
          // Clear any stop request left over from a previous run. Callbacks may
          // set it again through the `stopTraining` setter.
          this.stopTraining_ = false;
          // TODO(cais): Take care of callbacks.validation_data as in PyKeras.
          // TODO(cais): Pre-convert feeds for performance as in PyKeras.
          // `_loop4` is the body of one epoch. It returns "break" when training
          // should stop after this epoch.
          _loop4 = /*#__PURE__*/_regeneratorRuntime().mark(function _loop4() {
            var epochLogs, epochIndexArray1D, batches, _loop5, batchIndex, _ret3;
            return _regeneratorRuntime().wrap(function _loop4$(_context5) {
              while (1) switch (_context5.prev = _context5.next) {
                case 0:
                  _context5.next = 2;
                  return callbackList.onEpochBegin(epoch);
                case 2:
                  epochLogs = {};
                  if (!(stepsPerEpoch != null)) {
                    _context5.next = 7;
                    break;
                  }
                  throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.');
                case 7:
                  // The 'batch' check must come before the truthiness check
                  // below, because the string 'batch' is truthy.
                  if (!(shuffle$1 === 'batch')) {
                    _context5.next = 11;
                    break;
                  }
                  throw new NotImplementedError('batch shuffling is not implemneted' + ' yet');
                case 11:
                  // shuffle()'s return value is ignored, so this relies on it
                  // reordering `indexArray` in place.
                  if (shuffle$1) {
                    shuffle(indexArray);
                  }
                case 12:
                  // Convert the potentially shuffled indices to Tensor1D, to avoid the
                  // cost of repeated creation of Array1Ds later on.
                  // NOTE(review): this tensor is only disposed at `case 24`. This
                  // generator has no try/finally, so an exception from a callback
                  // or from `f` leaks it. Consider guarding the dispose.
                  epochIndexArray1D = tensor1d(indexArray);
                  batches = makeBatches(numTrainSamples, batchSize);
                  // `_loop5` is the body of one batch. It returns "break" when a
                  // stop was requested by the time the batch finished.
                  _loop5 = /*#__PURE__*/_regeneratorRuntime().mark(function _loop5(batchIndex) {
                    var batchLogs;
                    return _regeneratorRuntime().wrap(function _loop5$(_context4) {
                      while (1) switch (_context4.prev = _context4.next) {
                        case 0:
                          batchLogs = {};
                          _context4.next = 3;
                          return callbackList.onBatchBegin(batchIndex, batchLogs);
                        case 3:
                          tidy(function () {
                            // batches[batchIndex] holds [start, end) offsets into
                            // the (possibly shuffled) index tensor.
                            var batchStart = batches[batchIndex][0];
                            var batchEnd = batches[batchIndex][1];
                            var batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart);
                            batchLogs['batch'] = batchIndex;
                            batchLogs['size'] = batchEnd - batchStart;
                            // TODO(cais): In ins, train flag can be a number, instead of an
                            // Tensor? Do we need to handle this in tfjs-layers?
                            var insBatch = sliceArraysByIndices(ins, batchIds);
                            var outs = f(insBatch);
                            // Record each output of `f` under its label. keep() lets
                            // the tensor outlive this tidy() so callbacks can read it.
                            for (var i = 0; i < outLabels.length; ++i) {
                              var label = outLabels[i];
                              var out = outs[i];
                              batchLogs[label] = out;
                              keep(out);
                              // TODO(cais): Use scope() to avoid ownership.
                            }

                            // Validation runs once per epoch, after the last batch.
                            if (batchIndex === batches.length - 1) {
                              // Last batch.
                              if (doValidation) {
                                var valOuts = _this7.testLoop(valF, valIns, batchSize);
                                // Porting Notes: In tfjs-layers, valOuts is always an Array.
                                for (var _i9 = 0; _i9 < outLabels.length; ++_i9) {
                                  var _label = outLabels[_i9];
                                  var _out = valOuts[_i9];
                                  keep(_out);
                                  // TODO(cais): Use scope() to avoid ownership.
                                  epochLogs['val_' + _label] = _out;
                                }
                              }
                            }
                          });
                          _context4.next = 6;
                          return callbackList.onBatchEnd(batchIndex, batchLogs);
                        case 6:
                          // Presumably releases the kept per-batch tensors now that
                          // the callbacks have seen them.
                          disposeTensorsInLogs(batchLogs);
                          if (!_this7.stopTraining_) {
                            _context4.next = 9;
                            break;
                          }
                          return _context4.abrupt("return", "break");
                        case 9:
                        case "end":
                          return _context4.stop();
                      }
                    }, _loop5);
                  });
                  batchIndex = 0;
                case 16:
                  // Batch loop: run `_loop5` for each batch until all batches are
                  // done or a stop is requested ("break").
                  if (!(batchIndex < batches.length)) {
                    _context5.next = 24;
                    break;
                  }
                  return _context5.delegateYield(_loop5(batchIndex), "t0", 18);
                case 18:
                  _ret3 = _context5.t0;
                  if (!(_ret3 === "break")) {
                    _context5.next = 21;
                    break;
                  }
                  return _context5.abrupt("break", 24);
                case 21:
                  ++batchIndex;
                  _context5.next = 16;
                  break;
                case 24:
                  epochIndexArray1D.dispose();
                case 25:
                  // NOTE(review): the val_* tensors stored in epochLogs are not
                  // disposed in this function. Presumably callbacks/History take
                  // ownership — confirm.
                  _context5.next = 27;
                  return callbackList.onEpochEnd(epoch, epochLogs);
                case 27:
                  if (!_this7.stopTraining_) {
                    _context5.next = 29;
                    break;
                  }
                  return _context5.abrupt("return", "break");
                case 29:
                case "end":
                  return _context5.stop();
              }
            }, _loop4);
          });
          // Epoch loop. `epoch` is shared with `_loop4` through the closure.
          epoch = initialEpoch;
        case 21:
          if (!(epoch < epochs)) {
            _context6.next = 29;
            break;
          }
          return _context6.delegateYield(_loop4(), "t0", 23);
        case 23:
          _ret2 = _context6.t0;
          if (!(_ret2 === "break")) {
            _context6.next = 26;
            break;
          }
          return _context6.abrupt("break", 29);
        case 26:
          ++epoch;
          _context6.next = 21;
          break;
        case 29:
          _context6.next = 31;
          return callbackList.onTrainEnd();
        case 31:
          _context6.next = 33;
          return this.history.syncData();
        case 33:
          return _context6.abrupt("return", this.history);
        case 34:
        case "end":
          return _context6.stop();
      }
    }, _callee4, this);
  }));
  function fitLoop(_x9, _x10, _x11, _x12, _x13, _x14, _x15, _x16, _x17, _x18, _x19, _x20, _x21, _x22) {
    return _fitLoop.apply(this, arguments);
  }
  return fitLoop;
}() // TODO(cais): Add code snippet below when it's possible to instantiate
// actual dataset objects.
68572 /**
68573 * Trains the model using a dataset object.
68574 *
68575 * @param dataset A dataset object. Its `iterator()` method is expected
68576 * to generate a dataset iterator object, the `next()` method of which
68577 * is expected to produce data batches for training. The return value
68578 * of the `next()` call ought to contain a boolean `done` field and a
68579 * `value` field. The `value` field is expected to be an array of two
68580 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
68581 * case is for models with exactly one input and one output (e.g.
68582 * a sequential model). The latter case is for models with multiple
68583 * inputs and/or multiple outputs.
68584 * Of the two items in the array, the first is the input feature(s) and
68585 * the second is the output target(s).
68586 * @param args A `ModelFitDatasetArgs`, containing optional fields.
68587 *
68588 * @return A `History` instance. Its `history` attribute contains all
68589 * information collected during training.
68590 *
68591 * @doc {heading: 'Models', subheading: 'Classes'}
68592 */
68593 }, {
// Public `fitDataset()` method. It delegates to the module-level
// `fitDataset(model, dataset, args)` helper, passing this model, and resolves
// to that helper's result.
key: "fitDataset",
value: function () {
  var _fitDataset2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(dataset, args) {
    return _regeneratorRuntime().wrap(function _callee5$(_context7) {
      while (1) switch (_context7.prev = _context7.next) {
        case 0:
          return _context7.abrupt("return", fitDataset(this, dataset, args));
        case 1:
        case "end":
          return _context7.stop();
      }
    }, _callee5, this);
  }));
  // The wrapper is named `fitDataset$1` rather than `fitDataset`. A
  // declaration named `fitDataset` in this IIFE would shadow the module-level
  // helper, turning the call above into infinite recursion (presumably why
  // the bundler renamed it).
  function fitDataset$1(_x23, _x24) {
    return _fitDataset2.apply(this, arguments);
  }
  return fitDataset$1;
}()
68612 /**
68613 * Runs a single gradient update on a single batch of data.
68614 *
68615 * This method differs from `fit()` and `fitDataset()` in the following
68616 * regards:
68617 * - It operates on exactly one batch of data.
68618 * - It returns only the loss and metric values, instead of
68619 * returning the batch-by-batch loss and metric values.
68620 * - It doesn't support fine-grained options such as verbosity and
68621 * callbacks.
68622 *
68623 * @param x Input data. It could be one of the following:
68624 * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
68625 * multiple inputs).
68626 * - An Object mapping input names to corresponding `tf.Tensor` (if the
68627 * model has named inputs).
68628 * @param y Target data. It could be either a `tf.Tensor` or multiple
68629 * `tf.Tensor`s. It should be consistent with `x`.
68630 * @returns Training loss or losses (in case the model has
68631 * multiple outputs), along with metrics (if any), as numbers.
68632 *
68633 * @doc {heading: 'Models', subheading: 'Classes'}
68634 */
68635 }, {
// Runs one training step on a single batch. Resolves to the loss (and metric)
// values as plain numbers, shaped by singletonOrArray(). No callbacks or
// validation are involved.
key: "trainOnBatch",
value: function () {
  var _trainOnBatch = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6(x, y) {
    var standardizeOut, inputs, targets, trainFunction, losses, lossValues, _iterator7, _step7, loss, v;
    // The try-locations list [[9, 21, 24, 27]] passed to wrap() below means:
    // try starts at 9, catch at 21, finally at 24, then continue at 27.
    return _regeneratorRuntime().wrap(function _callee6$(_context8) {
      while (1) switch (_context8.prev = _context8.next) {
        case 0:
          _context8.next = 2;
          return this.standardizeUserData(x, y);
        case 2:
          // standardizeOut[0] holds the model inputs, standardizeOut[1] the targets.
          standardizeOut = _context8.sent;
          inputs = standardizeOut[0];
          targets = standardizeOut[1];
          trainFunction = this.makeTrainFunction();
          losses = trainFunction(inputs.concat(targets));
          lossValues = [];
          // Transpiled `for (const loss of losses)`. Each value is downloaded
          // with an awaited data() call, and only element 0 is kept.
          _iterator7 = _createForOfIteratorHelper(losses);
          _context8.prev = 9;
          _iterator7.s();
        case 11:
          if ((_step7 = _iterator7.n()).done) {
            _context8.next = 19;
            break;
          }
          loss = _step7.value;
          _context8.next = 15;
          return loss.data();
        case 15:
          v = _context8.sent;
          lossValues.push(v[0]);
        case 17:
          _context8.next = 11;
          break;
        case 19:
          _context8.next = 24;
          break;
        case 21:
          // catch: hand the error to the iterator helper.
          _context8.prev = 21;
          _context8.t0 = _context8["catch"](9);
          _iterator7.e(_context8.t0);
        case 24:
          // finally: close the iterator (the helper presumably rethrows any
          // recorded error here).
          _context8.prev = 24;
          _iterator7.f();
          return _context8.finish(24);
        case 27:
          // Release the loss tensors, plus tensors standardizeUserData()
          // created from `x` / `y` (presumably only those not passed in by
          // the caller).
          // NOTE(review): this cleanup is skipped if a data() download rejects.
          // Consider moving it into the finally.
          dispose(losses);
          disposeNewTensors(standardizeOut[0], x);
          disposeNewTensors(standardizeOut[1], y);
          return _context8.abrupt("return", singletonOrArray(lossValues));
        case 31:
        case "end":
          return _context8.stop();
      }
    }, _callee6, this, [[9, 21, 24, 27]]);
  }));
  function trainOnBatch(_x25, _x26) {
    return _trainOnBatch.apply(this, arguments);
  }
  return trainOnBatch;
}()
68696 /**
68697 * Extract weight values of the model.
68698 *
68699 * @param config: An instance of `io.SaveConfig`, which specifies
68700 * model-saving options such as whether only trainable weights are to be
68701 * saved.
68702 * @returns A `NamedTensorMap` mapping original weight names (i.e.,
68703 * non-uniqueified weight names) to their values.
68704 */
68705 }, {
68706 key: "getNamedWeights",
68707 value: function getNamedWeights(config) {
68708 var namedWeights = [];
68709 var trainableOnly = config != null && config.trainableOnly;
68710 var weights = trainableOnly ? this.trainableWeights : this.weights;
68711 var weightValues = this.getWeights(trainableOnly);
68712 for (var i = 0; i < weights.length; ++i) {
68713 if (trainableOnly && !weights[i].trainable) {
68714 // Optionally skip non-trainable weights.
68715 continue;
68716 }
68717 namedWeights.push({
68718 name: weights[i].originalName,
68719 tensor: weightValues[i]
68720 });
68721 }
68722 return namedWeights;
68723 }
68724 /**
68725 * Setter used for force stopping of LayersModel.fit() (i.e., training).
68726 *
68727 * Example:
68728 *
68729 * ```js
68730 * const input = tf.input({shape: [10]});
68731 * const output = tf.layers.dense({units: 1}).apply(input);
68732 * const model = tf.model({inputs: [input], outputs: [output]});
68733 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
68734 * const xs = tf.ones([8, 10]);
68735 * const ys = tf.zeros([8, 1]);
68736 *
68737 * const history = await model.fit(xs, ys, {
68738 * epochs: 10,
68739 * callbacks: {
68740 * onEpochEnd: async (epoch, logs) => {
68741 * if (epoch === 2) {
68742 * model.stopTraining = true;
68743 * }
68744 * }
68745 * }
68746 * });
68747 *
 * // There should be only 3 values in the loss array, instead of 10
 * // values, due to the stopping after 3 epochs.
68751 * console.log(history.history.loss);
68752 * ```
68753 */
68754 }, {
68755 key: "stopTraining",
68756 get: function get() {
68757 return this.stopTraining_;
68758 },
68759 set: function set(stop) {
68760 this.stopTraining_ = stop;
68761 }
68762 }, {
68763 key: "optimizer",
68764 get: function get() {
68765 return this.optimizer_;
68766 },
68767 set: function set(optimizer) {
68768 if (this.optimizer_ !== optimizer) {
68769 this.optimizer_ = optimizer;
68770 this.isOptimizerOwned = false;
68771 }
68772 }
68773 }, {
68774 key: "dispose",
68775 value: function dispose() {
68776 var result = _get(_getPrototypeOf(LayersModel.prototype), "dispose", this).call(this);
68777 if (result.refCountAfterDispose === 0 && this.optimizer != null && this.isOptimizerOwned) {
68778 var numTensorsBeforeOptmizerDisposal = memory().numTensors;
68779 this.optimizer_.dispose();
68780 result.numDisposedVariables += numTensorsBeforeOptmizerDisposal - memory().numTensors;
68781 }
68782 return result;
68783 }
68784 }, {
68785 key: "getLossIdentifiers",
68786 value: function getLossIdentifiers() {
68787 var lossNames;
68788 if (typeof this.loss === 'string') {
68789 lossNames = toSnakeCase(this.loss);
68790 } else if (Array.isArray(this.loss)) {
68791 var _iterator8 = _createForOfIteratorHelper(this.loss),
68792 _step8;
68793 try {
68794 for (_iterator8.s(); !(_step8 = _iterator8.n()).done;) {
68795 var loss = _step8.value;
68796 if (typeof loss !== 'string') {
68797 throw new Error('Serialization of non-string loss is not supported.');
68798 }
68799 }
68800 } catch (err) {
68801 _iterator8.e(err);
68802 } finally {
68803 _iterator8.f();
68804 }
68805 lossNames = this.loss.map(function (name) {
68806 return toSnakeCase(name);
68807 });
68808 } else {
68809 var outputNames = Object.keys(this.loss);
68810 lossNames = {};
68811 var _losses = this.loss;
68812 for (var _i10 = 0, _outputNames = outputNames; _i10 < _outputNames.length; _i10++) {
68813 var outputName = _outputNames[_i10];
68814 if (typeof _losses[outputName] === 'string') {
68815 lossNames[outputName] = toSnakeCase(_losses[outputName]);
68816 } else {
68817 throw new Error('Serialization of non-string loss is not supported.');
68818 }
68819 }
68820 }
68821 return lossNames;
68822 }
68823 }, {
68824 key: "getMetricIdentifiers",
68825 value: function getMetricIdentifiers() {
68826 if (typeof this.metrics === 'string' || typeof this.metrics === 'function') {
68827 return [toSnakeCase(getLossOrMetricName(this.metrics))];
68828 } else if (Array.isArray(this.metrics)) {
68829 return this.metrics.map(function (metric) {
68830 return toSnakeCase(getLossOrMetricName(metric));
68831 });
68832 } else {
68833 var metricsIdentifiers = {};
68834 for (var key in this.metrics) {
68835 metricsIdentifiers[key] = toSnakeCase(getLossOrMetricName(this.metrics[key]));
68836 }
68837 return metricsIdentifiers;
68838 }
68839 }
68840 }, {
68841 key: "getTrainingConfig",
68842 value: function getTrainingConfig() {
68843 return {
68844 loss: this.getLossIdentifiers(),
68845 metrics: this.getMetricIdentifiers(),
68846 optimizer_config: {
68847 class_name: this.optimizer.getClassName(),
68848 config: this.optimizer.getConfig()
68849 }
68850 };
68851 // TODO(cais): Add weight_metrics when they are supported.
68852 // TODO(cais): Add sample_weight_mode when it's supported.
68853 // TODO(cais): Add loss_weights when it's supported.
68854 }
68855 }, {
68856 key: "loadTrainingConfig",
68857 value: function loadTrainingConfig(trainingConfig) {
68858 if (trainingConfig.weighted_metrics != null) {
68859 throw new Error('Loading weight_metrics is not supported yet.');
68860 }
68861 if (trainingConfig.loss_weights != null) {
68862 throw new Error('Loading loss_weights is not supported yet.');
68863 }
68864 if (trainingConfig.sample_weight_mode != null) {
68865 throw new Error('Loading sample_weight_mode is not supported yet.');
68866 }
68867 var tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
68868 var optimizer = deserialize(tsConfig);
68869 var loss;
68870 if (typeof trainingConfig.loss === 'string') {
68871 loss = toCamelCase(trainingConfig.loss);
68872 } else if (Array.isArray(trainingConfig.loss)) {
68873 loss = trainingConfig.loss.map(function (lossEntry) {
68874 return toCamelCase(lossEntry);
68875 });
68876 } else if (trainingConfig.loss != null) {
68877 loss = {};
68878 for (var key in trainingConfig.loss) {
68879 loss[key] = toCamelCase(trainingConfig.loss[key]);
68880 }
68881 }
68882 var metrics;
68883 if (Array.isArray(trainingConfig.metrics)) {
68884 metrics = trainingConfig.metrics.map(function (metric) {
68885 return toCamelCase(metric);
68886 });
68887 } else if (trainingConfig.metrics != null) {
68888 metrics = {};
68889 for (var _key in trainingConfig.metrics) {
68890 metrics[_key] = toCamelCase(trainingConfig.metrics[_key]);
68891 }
68892 }
68893 this.compile({
68894 loss: loss,
68895 metrics: metrics,
68896 optimizer: optimizer
68897 });
68898 }
68899 /**
68900 * Save the configuration and/or weights of the LayersModel.
68901 *
68902 * An `IOHandler` is an object that has a `save` method of the proper
68903 * signature defined. The `save` method manages the storing or
68904 * transmission of serialized data ("artifacts") that represent the
68905 * model's topology and weights onto or via a specific medium, such as
68906 * file downloads, local storage, IndexedDB in the web browser and HTTP
68907 * requests to a server. TensorFlow.js provides `IOHandler`
68908 * implementations for a number of frequently used saving mediums, such as
68909 * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
68910 * for more details.
68911 *
68912 * This method also allows you to refer to certain types of `IOHandler`s
68913 * as URL-like string shortcuts, such as 'localstorage://' and
68914 * 'indexeddb://'.
68915 *
68916 * Example 1: Save `model`'s topology and weights to browser [local
68917 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
68918 * then load it back.
68919 *
68920 * ```js
68921 * const model = tf.sequential(
68922 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
68923 * console.log('Prediction from original model:');
68924 * model.predict(tf.ones([1, 3])).print();
68925 *
68926 * const saveResults = await model.save('localstorage://my-model-1');
68927 *
68928 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
68929 * console.log('Prediction from loaded model:');
68930 * loadedModel.predict(tf.ones([1, 3])).print();
68931 * ```
68932 *
68933 * Example 2. Saving `model`'s topology and weights to browser
68934 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
68935 * then load it back.
68936 *
68937 * ```js
68938 * const model = tf.sequential(
68939 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
68940 * console.log('Prediction from original model:');
68941 * model.predict(tf.ones([1, 3])).print();
68942 *
68943 * const saveResults = await model.save('indexeddb://my-model-1');
68944 *
68945 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
68946 * console.log('Prediction from loaded model:');
68947 * loadedModel.predict(tf.ones([1, 3])).print();
68948 * ```
68949 *
68950 * Example 3. Saving `model`'s topology and weights as two files
68951 * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
68952 * browser.
68953 *
68954 * ```js
68955 * const model = tf.sequential(
68956 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
68957 * const saveResults = await model.save('downloads://my-model-1');
68958 * ```
68959 *
68960 * Example 4. Send `model`'s topology and weights to an HTTP server.
68961 * See the documentation of `tf.io.http` for more details
68962 * including specifying request parameters and implementation of the
68963 * server.
68964 *
68965 * ```js
68966 * const model = tf.sequential(
68967 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
68968 * const saveResults = await model.save('http://my-server/model/upload');
68969 * ```
68970 *
68971 * @param handlerOrURL An instance of `IOHandler` or a URL-like,
68972 * scheme-based string shortcut for `IOHandler`.
68973 * @param config Options for saving the model.
68974 * @returns A `Promise` of `SaveResult`, which summarizes the result of
68975 * the saving, such as byte sizes of the saved artifacts for the model's
68976 * topology and weight values.
68977 *
68978 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
68979 */
68980 }, {
// Serializes the model topology and weights into a ModelArtifacts object,
// plus the optimizer state and training config when `config.includeOptimizer`
// is set. The artifacts are passed to the resolved IOHandler's save(), and
// this method resolves to whatever that save() resolves to.
key: "save",
value: function () {
  var _save = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee7(handlerOrURL, config) {
    var handlers, weightDataAndSpecs, returnString, unusedArg, modelConfig, modelArtifacts, includeOptimizer, _weightDataAndSpecs$s, weightType, _yield$io$encodeWeigh, optimizerWeightData, optimizerWeightSpecs, checkSize;
    return _regeneratorRuntime().wrap(function _callee7$(_context9) {
      while (1) switch (_context9.prev = _context9.next) {
        case 0:
          // A URL-like string must resolve to exactly one registered save
          // handler. Zero or several matches raise a ValueError.
          if (!(typeof handlerOrURL === 'string')) {
            _context9.next = 9;
            break;
          }
          handlers = getSaveHandlers(handlerOrURL);
          if (!(handlers.length === 0)) {
            _context9.next = 6;
            break;
          }
          throw new ValueError("Cannot find any save handlers for URL '".concat(handlerOrURL, "'"));
        case 6:
          if (!(handlers.length > 1)) {
            _context9.next = 8;
            break;
          }
          throw new ValueError("Found more than one (".concat(handlers.length, ") save handlers for ") + "URL '".concat(handlerOrURL, "'"));
        case 8:
          handlerOrURL = handlers[0];
        case 9:
          if (!(handlerOrURL.save == null)) {
            _context9.next = 11;
            break;
          }
          throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' + 'provided does not have the `save` attribute defined.');
        case 11:
          // Encode the model's named weights. getNamedWeights() honours
          // `config.trainableOnly`.
          _context9.next = 13;
          return encodeWeights(this.getNamedWeights(config));
        case 13:
          weightDataAndSpecs = _context9.sent;
          returnString = false;
          unusedArg = null;
          // Topology as a JSON object (returnString = false), not a string.
          modelConfig = this.toJSON(unusedArg, returnString);
          modelArtifacts = {
            modelTopology: modelConfig,
            format: LAYERS_MODEL_FORMAT_NAME,
            generatedBy: "TensorFlow.js tfjs-layers v".concat(version$6),
            convertedBy: null
          };
          // Optimizer state is only saved when explicitly requested and an
          // optimizer exists.
          includeOptimizer = config == null ? false : config.includeOptimizer;
          if (!(includeOptimizer && this.optimizer != null)) {
            _context9.next = 34;
            break;
          }
          modelArtifacts.trainingConfig = this.getTrainingConfig();
          // Optimizer weights are encoded with weight type 'optimizer'
          // (presumably so loaders can tell them apart from model weights).
          weightType = 'optimizer';
          _context9.t0 = io;
          _context9.next = 25;
          return this.optimizer.getWeights();
        case 25:
          _context9.t1 = _context9.sent;
          _context9.t2 = weightType;
          _context9.next = 29;
          return _context9.t0.encodeWeights.call(_context9.t0, _context9.t1, _context9.t2);
        case 29:
          // Append the optimizer specs and bytes after the model weights,
          // in the same buffer.
          _yield$io$encodeWeigh = _context9.sent;
          optimizerWeightData = _yield$io$encodeWeigh.data;
          optimizerWeightSpecs = _yield$io$encodeWeigh.specs;
          (_weightDataAndSpecs$s = weightDataAndSpecs.specs).push.apply(_weightDataAndSpecs$s, _toConsumableArray(optimizerWeightSpecs));
          weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
        case 34:
          if (this.userDefinedMetadata != null) {
            // Check serialized size of user-defined metadata.
            checkSize = true;
            checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
            modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
          }
          modelArtifacts.weightData = weightDataAndSpecs.data;
          modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
          return _context9.abrupt("return", handlerOrURL.save(modelArtifacts));
        case 38:
        case "end":
          return _context9.stop();
      }
    }, _callee7, this);
  }));
  function save(_x27, _x28) {
    return _save.apply(this, arguments);
  }
  return save;
}()
69068 /**
69069 * Set user-defined metadata.
69070 *
69071 * The set metadata will be serialized together with the topology
69072 * and weights of the model during `save()` calls.
69073 *
 * @param userDefinedMetadata The metadata to attach to the model.
69075 */
69076 }, {
69077 key: "setUserDefinedMetadata",
69078 value: function setUserDefinedMetadata(userDefinedMetadata) {
69079 checkUserDefinedMetadata(userDefinedMetadata, this.name);
69080 this.userDefinedMetadata = userDefinedMetadata;
69081 }
69082 /**
69083 * Get user-defined metadata.
69084 *
69085 * The metadata is supplied via one of the two routes:
69086 * 1. By calling `setUserDefinedMetadata()`.
69087 * 2. Loaded during model loading (if the model is constructed
69088 * via `tf.loadLayersModel()`.)
69089 *
69090 * If no user-defined metadata is available from either of the
69091 * two routes, this function will return `undefined`.
69092 */
69093 }, {
69094 key: "getUserDefinedMetadata",
69095 value: function getUserDefinedMetadata() {
69096 return this.userDefinedMetadata;
69097 }
69098 }]);
69099 return LayersModel;
69100 }(Container); // The class name is 'Model' rather than 'LayersModel' for backwards
69101 // compatibility since this class name shows up in the serialization format.
69102 /** @nocollapse */
LayersModel.className = 'Model';
// NOTE(review): presumably registers LayersModel in the serialization
// registry under its `className` ('Model'), so configs naming 'Model' can be
// deserialized into this class. Confirm against registerClass.
registerClass(LayersModel);
69105 /**
69106 * A `tf.Functional` is an alias to `tf.LayersModel`.
69107 *
69108 * See also:
69109 * `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
69110 */
69111 /** @doc {heading: 'Models', subheading: 'Classes'} */
69112 var Functional = /*#__PURE__*/function (_LayersModel) {
69113 _inherits(Functional, _LayersModel);
69114 var _super2 = _createSuper(Functional);
69115 function Functional() {
69116 _classCallCheck(this, Functional);
69117 return _super2.apply(this, arguments);
69118 }
69119 return _createClass(Functional);
69120 }(LayersModel);
69121 Functional.className = 'Functional';
69122 registerClass(Functional);
69123
69124 /**
69125 * Parses a JSON model configuration file and returns a model instance.
69126 *
69127 * ```js
69128 * // This example shows how to serialize a model using `toJSON()` and
69129 * // deserialize it as another model using `tf.models.modelFromJSON()`.
69130 * // Note: this example serializes and deserializes only the topology
69131 * // of the model; the weights of the loaded model will be different
 * // from those of the original model, due to random weight
69133 * // initialization.
69134 * // To load the topology and weights of a model, use `tf.loadLayersModel()`.
69135 * const model1 = tf.sequential();
69136 * model1.add(tf.layers.repeatVector({inputShape: [2], n: 4}));
69137 * // Serialize `model1` as a JSON object.
69138 * const model1JSON = model1.toJSON(null, false);
69139 * model1.summary();
69140 *
69141 * const model2 = await tf.models.modelFromJSON(model1JSON);
69142 * model2.summary();
69143 * ```
69144 *
69145 * @param modelAndWeightsConfig JSON object or string encoding a model and
69146 * weights configuration. It can also be only the topology JSON of the
69147 * model, in which case the weights will not be loaded.
69148 * @param custom_objects Optional dictionary mapping names
69149 * (strings) to custom classes or functions to be
69150 * considered during deserialization.
69151 * @returns A TensorFlow.js Layers `tf.LayersModel` instance (uncompiled).
69152 */
69153 function modelFromJSON(_x, _x2) {
69154 return _modelFromJSON.apply(this, arguments);
69155 }
69156 /**
69157 * Load a model composed of Layer objects, including its topology and optionally
69158 * weights. See the Tutorial named "How to import a Keras Model" for usage
69159 * examples.
69160 *
69161 * This method is applicable to:
69162 *
69163 * 1. Models created with the `tf.layers.*`, `tf.sequential`, and
69164 * `tf.model` APIs of TensorFlow.js and later saved with the
69165 * `tf.LayersModel.save` method.
69166 * 2. Models converted from Keras or TensorFlow tf.keras using the
69167 * [tensorflowjs_converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter).
69168 *
69169 * This mode is *not* applicable to TensorFlow `SavedModel`s or their converted
69170 * forms. For those models, use `tf.loadGraphModel`.
69171 *
69172 * Example 1. Load a model from an HTTP server.
69173 *
69174 * ```js
69175 * const model = await tf.loadLayersModel(
69176 * 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
69177 * model.summary();
69178 * ```
69179 *
69180 * Example 2: Save `model`'s topology and weights to browser [local
69181 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
69182 * then load it back.
69183 *
69184 * ```js
69185 * const model = tf.sequential(
69186 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
69187 * console.log('Prediction from original model:');
69188 * model.predict(tf.ones([1, 3])).print();
69189 *
69190 * const saveResults = await model.save('localstorage://my-model-1');
69191 *
69192 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
69193 * console.log('Prediction from loaded model:');
69194 * loadedModel.predict(tf.ones([1, 3])).print();
69195 * ```
69196 *
69197 * Example 3. Saving `model`'s topology and weights to browser
69198 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
69199 * then load it back.
69200 *
69201 * ```js
69202 * const model = tf.sequential(
69203 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
69204 * console.log('Prediction from original model:');
69205 * model.predict(tf.ones([1, 3])).print();
69206 *
69207 * const saveResults = await model.save('indexeddb://my-model-1');
69208 *
69209 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
69210 * console.log('Prediction from loaded model:');
69211 * loadedModel.predict(tf.ones([1, 3])).print();
69212 * ```
69213 *
69214 * Example 4. Load a model from user-selected files from HTML
69215 * [file input
69216 * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
69217 *
69218 * ```js
69219 * // Note: this code snippet will not work without the HTML elements in the
69220 * // page
69221 * const jsonUpload = document.getElementById('json-upload');
69222 * const weightsUpload = document.getElementById('weights-upload');
69223 *
69224 * const model = await tf.loadLayersModel(
69225 * tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
69226 * ```
69227 *
69228 * @param pathOrIOHandler Can be either of the two formats
69229 * 1. A string path to the `ModelAndWeightsConfig` JSON describing
69230 * the model in the canonical TensorFlow.js format. For file://
69231 * (tfjs-node-only), http:// and https:// schemas, the path can be
69232 * either absolute or relative. The content of the JSON file is assumed to
69233 * be a JSON object with the following fields and values:
69234 * - 'modelTopology': A JSON object that can be either of:
69235 * 1. a model architecture JSON consistent with the format of the return
69236 * value of `keras.Model.to_json()`
69237 * 2. a full model JSON in the format of `keras.models.save_model()`.
69238 * - 'weightsManifest': A TensorFlow.js weights manifest.
69239 * See the Python converter function `save_model()` for more details.
69240 * It is also assumed that model weights can be accessed from relative
69241 * paths described by the `paths` fields in weights manifest.
69242 * 2. A `tf.io.IOHandler` object that loads model artifacts with its `load`
69243 * method.
69244 * @param options Optional configuration arguments for the model loading,
69245 * including:
69246 * - `strict`: Require that the provided weights exactly match those required
69247 * by the layers. Default true. Passing false means that both extra
69248 * weights and missing weights will be silently ignored.
69249 * - `onProgress`: A progress callback of the form:
69250 * `(fraction: number) => void`. This callback can be used to monitor the
69251 * model-loading process.
69252 * @returns A `Promise` of `tf.LayersModel`, with the topology and weights
69253 * loaded.
69254 *
69255 * @doc {heading: 'Models', subheading: 'Loading'}
69256 */
69257 function _modelFromJSON() {
69258 _modelFromJSON = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(modelAndWeightsConfig, customObjects) {
69259 var modelTopology, tsConfig, model, weightValues, uniqueWeightValues, _iterator4, _step4, weight;
69260 return _regeneratorRuntime().wrap(function _callee5$(_context5) {
69261 while (1) switch (_context5.prev = _context5.next) {
69262 case 0:
69263 if (!('modelTopology' in modelAndWeightsConfig)) {
69264 modelAndWeightsConfig = {
69265 modelTopology: modelAndWeightsConfig
69266 };
69267 }
69268 modelAndWeightsConfig = modelAndWeightsConfig;
69269 modelTopology = modelAndWeightsConfig.modelTopology;
69270 if (modelTopology['model_config'] != null) {
69271 // If the model-topology JSON contains a 'model_config' field, then it is
69272 // a full model JSON (e.g., from `keras.Model.save()`), which contains
69273 // not only the model's architecture in its 'model_config' field, but
69274 // additional information such as the model's optimizer. We use only the
69275 // 'model_config' field currently.
69276 modelTopology = modelTopology['model_config'];
69277 }
69278 tsConfig = convertPythonicToTs(modelTopology);
69279 model = deserialize(tsConfig, customObjects);
69280 if (!(modelAndWeightsConfig.weightsManifest != null)) {
69281 _context5.next = 15;
69282 break;
69283 }
69284 _context5.next = 9;
69285 return loadWeights(modelAndWeightsConfig.weightsManifest, modelAndWeightsConfig.pathPrefix, model.weights.map(function (weight) {
69286 return weight.originalName;
69287 }));
69288 case 9:
69289 weightValues = _context5.sent;
69290 // Map the weights to the unique tensor names generated during model loading
69291 uniqueWeightValues = {};
69292 _iterator4 = _createForOfIteratorHelper(model.weights);
69293 try {
69294 for (_iterator4.s(); !(_step4 = _iterator4.n()).done;) {
69295 weight = _step4.value;
69296 uniqueWeightValues[weight.originalName] = weightValues[weight.originalName];
69297 }
69298 } catch (err) {
69299 _iterator4.e(err);
69300 } finally {
69301 _iterator4.f();
69302 }
69303 model.loadWeights(uniqueWeightValues);
69304 // Dispose temporary weight values.
69305 dispose(weightValues);
69306 case 15:
69307 return _context5.abrupt("return", model);
69308 case 16:
69309 case "end":
69310 return _context5.stop();
69311 }
69312 }, _callee5);
69313 }));
69314 return _modelFromJSON.apply(this, arguments);
69315 }
69316 function loadLayersModel(_x3, _x4) {
69317 return _loadLayersModel.apply(this, arguments);
69318 }
69319 /**
69320 * Load a model and optionally its weights, using an IOHandler object.
69321 *
69322 * @param handler The instance of `IOHandler` to be used during the model
69323 * loading.
69324 * @param customObjects Any optional custom objects to be used during model
69325 * loading.
69326 * @param strict Whether the weight loading will be done in strict mode.
69327 * Default: `true`.
69328 */
69329 function _loadLayersModel() {
69330 _loadLayersModel = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6(pathOrIOHandler, options) {
69331 var handlers;
69332 return _regeneratorRuntime().wrap(function _callee6$(_context6) {
69333 while (1) switch (_context6.prev = _context6.next) {
69334 case 0:
69335 if (options == null) {
69336 options = {};
69337 }
69338 if (!(typeof pathOrIOHandler === 'string')) {
69339 _context6.next = 10;
69340 break;
69341 }
69342 handlers = getLoadHandlers(pathOrIOHandler, options);
69343 if (!(handlers.length === 0)) {
69344 _context6.next = 7;
69345 break;
69346 }
69347 // For backward compatibility: if no load handler can be found,
69348 // assume it is a relative http path.
69349 // TODO(cais): Reformat the args into a single `LoadOptions` once the core
69350 // is refactored.
69351 handlers.push(browserHTTPRequest(pathOrIOHandler, options));
69352 _context6.next = 9;
69353 break;
69354 case 7:
69355 if (!(handlers.length > 1)) {
69356 _context6.next = 9;
69357 break;
69358 }
69359 throw new ValueError("Found more than one (".concat(handlers.length, ") load handlers for ") + "URL '".concat(pathOrIOHandler, "'"));
69360 case 9:
69361 pathOrIOHandler = handlers[0];
69362 case 10:
69363 return _context6.abrupt("return", loadLayersModelFromIOHandler(pathOrIOHandler, undefined, options));
69364 case 11:
69365 case "end":
69366 return _context6.stop();
69367 }
69368 }, _callee6);
69369 }));
69370 return _loadLayersModel.apply(this, arguments);
69371 }
69372 function loadLayersModelFromIOHandler(_x5, _x6, _x7) {
69373 return _loadLayersModelFromIOHandler.apply(this, arguments);
69374 }
69375 function _loadLayersModelFromIOHandler() {
69376 _loadLayersModelFromIOHandler = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee7(handler, customObjects, options) {
69377 var artifacts, modelTopology, strict, fastWeightInit, model, trainingConfig, _decodeModelAndOptimi, modelWeights, optimizerWeights;
69378 return _regeneratorRuntime().wrap(function _callee7$(_context7) {
69379 while (1) switch (_context7.prev = _context7.next) {
69380 case 0:
69381 if (options == null) {
69382 options = {};
69383 }
69384 if (!(handler.load == null)) {
69385 _context7.next = 3;
69386 break;
69387 }
69388 throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' + 'does not have the `load` method implemented.');
69389 case 3:
69390 _context7.next = 5;
69391 return handler.load();
69392 case 5:
69393 artifacts = _context7.sent;
69394 modelTopology = artifacts.modelTopology;
69395 if (modelTopology['model_config'] != null) {
69396 modelTopology = modelTopology['model_config'];
69397 }
69398 strict = options.strict == null ? true : options.strict; // If weights are provided and the weight-loading mode is strict, use
69399 // fast weight initialization. This skips costly initializers such as
69400 // 'orthogonal' and saves unnecessary computation in cases where
69401 // the initialized weight values will immediately be overwritten by
69402 // loaded weight values.
69403 fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
69404 model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
69405 trainingConfig = artifacts.trainingConfig;
69406 if (trainingConfig != null) {
69407 model.loadTrainingConfig(trainingConfig);
69408 }
69409 if (artifacts.userDefinedMetadata != null) {
69410 model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
69411 }
69412 // If weightData is present, load the weights into the model.
69413 if (!(artifacts.weightData != null)) {
69414 _context7.next = 24;
69415 break;
69416 }
69417 if (!(artifacts.weightSpecs == null)) {
69418 _context7.next = 17;
69419 break;
69420 }
69421 throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' + 'Therefore loading of weights cannot proceed.');
69422 case 17:
69423 _decodeModelAndOptimi = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs), modelWeights = _decodeModelAndOptimi.modelWeights, optimizerWeights = _decodeModelAndOptimi.optimizerWeights;
69424 model.loadWeights(modelWeights, strict);
69425 if (!(model.optimizer != null && optimizerWeights.length > 0)) {
69426 _context7.next = 22;
69427 break;
69428 }
69429 _context7.next = 22;
69430 return model.optimizer.setWeights(optimizerWeights);
69431 case 22:
69432 // Dispose temporary weight values.
69433 dispose(modelWeights);
69434 dispose(optimizerWeights.map(function (w) {
69435 return w.tensor;
69436 }));
69437 case 24:
69438 return _context7.abrupt("return", model);
69439 case 25:
69440 case "end":
69441 return _context7.stop();
69442 }
69443 }, _callee7);
69444 }));
69445 return _loadLayersModelFromIOHandler.apply(this, arguments);
69446 }
69447 function decodeModelAndOptimizerWeights(weightData, specs) {
69448 var name2Tensor = decodeWeights(weightData, specs);
69449 var modelWeights = {};
69450 var optimizerWeights = [];
69451 specs.forEach(function (spec) {
69452 if (spec.group === 'optimizer') {
69453 optimizerWeights.push({
69454 name: spec.name,
69455 tensor: name2Tensor[spec.name]
69456 });
69457 } else {
69458 modelWeights[spec.name] = name2Tensor[spec.name];
69459 }
69460 });
69461 return {
69462 modelWeights: modelWeights,
69463 optimizerWeights: optimizerWeights
69464 };
69465 }
69466 /**
69467 * A model with a stack of layers, feeding linearly from one to the next.
69468 *
69469 * `tf.sequential` is a factory function that creates an instance of
69470 * `tf.Sequential`.
69471 *
69472 * ```js
69473 * // Define a model for linear regression.
69474 * const model = tf.sequential();
69475 * model.add(tf.layers.dense({units: 1, inputShape: [1]}));
69476 *
69477 * // Prepare the model for training: Specify the loss and the optimizer.
69478 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
69479 *
69480 * // Generate some synthetic data for training.
69481 * const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
69482 * const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
69483 *
69484 * // Train the model using the data then do inference on a data point the
69485 * // model hasn't seen:
69486 * await model.fit(xs, ys);
69487 * model.predict(tf.tensor2d([5], [1, 1])).print();
69488 * ```
69489 *
69490 * @doc {heading: 'Models', subheading: 'Classes'}
69491 */
69492 var Sequential = /*#__PURE__*/function (_LayersModel) {
69493 _inherits(Sequential, _LayersModel);
69494 var _super = _createSuper(Sequential);
69495 function Sequential(args) {
69496 var _this;
69497 _classCallCheck(this, Sequential);
69498 _this = _super.call(this, {
69499 inputs: [],
69500 outputs: []
69501 });
69502 args = args || {};
69503 _this.trainable = true;
69504 _this.built = false;
69505 // Set model name.
69506 _this.name = args.name != null ? args.name : getUid('sequential_');
69507 // Add to the model any layers passed to the constructor.
69508 if (args.layers != null) {
69509 var _iterator = _createForOfIteratorHelper(args.layers),
69510 _step;
69511 try {
69512 for (_iterator.s(); !(_step = _iterator.n()).done;) {
69513 var layer = _step.value;
69514 _this.add(layer);
69515 }
69516 } catch (err) {
69517 _iterator.e(err);
69518 } finally {
69519 _iterator.f();
69520 }
69521 }
69522 return _this;
69523 }
69524 // Helper function to Sequential.add Throws if the new output shape will be
69525 // invalid.
69526 _createClass(Sequential, [{
69527 key: "checkShape",
69528 value: function checkShape(layer) {
69529 var shape = layer.inboundNodes[0].outputTensors[0].shape;
69530 if (shape.some(function (x) {
69531 return x < 0;
69532 })) {
69533 throw new ValueError('Negative dimension size caused by adding layer ' + "".concat(layer.name, " with input shape [") + "".concat(layer.inboundNodes[0].inputTensors[0].shape, "]"));
69534 }
69535 }
69536 /**
69537 * Adds a layer instance on top of the layer stack.
69538 *
69539 * ```js
69540 * const model = tf.sequential();
69541 * model.add(tf.layers.dense({units: 8, inputShape: [1]}));
69542 * model.add(tf.layers.dense({units: 4, activation: 'relu6'}));
69543 * model.add(tf.layers.dense({units: 1, activation: 'relu6'}));
69544 * // Note that the untrained model is random at this point.
69545 * model.predict(tf.randomNormal([10, 1])).print();
69546 * ```
69547 * @param layer Layer instance.
69548 *
69549 * @exception ValueError In case the `layer` argument does not know its
69550 * input shape.
69551 * @exception ValueError In case the `layer` argument has multiple output
69552 * tensors, or is already connected somewhere else (forbidden in
69553 * `Sequential` models).
69554 *
69555 * @doc {heading: 'Models', subheading: 'Classes'}
69556 */
69557 }, {
69558 key: "add",
69559 value: function add(layer) {
69560 var isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel;
69561 var modelLayer;
69562 if (isLayerModelInstance) {
69563 modelLayer = layer;
69564 if (modelLayer.outputs.length !== 1) {
69565 throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.');
69566 }
69567 if (modelLayer.inputs.length !== 1) {
69568 throw new ValueError('All layers in a Sequential model ' + 'should have a single input tensor. ' + 'For multi-input layers, ' + 'use the functional API.');
69569 }
69570 }
69571 if (this.outputs.length === 0) {
69572 // first layer in model: check that it is an input layer
69573 if (layer.inboundNodes.length === 0) {
69574 // create an input layer
69575 if (layer.batchInputShape == null) {
69576 throw new ValueError('The first layer in a Sequential model must ' + 'get an `inputShape` or `batchInputShape` argument.');
69577 }
69578 // Instantiate the input layer.
69579 var x = Input({
69580 batchShape: layer.batchInputShape,
69581 dtype: layer.dtype,
69582 name: layer.name + '_input'
69583 });
69584 // This will build the current layer and create the node connecting
69585 // the current layer to the input layer we just created.
69586 layer.apply(x);
69587 }
69588 if (isLayerModelInstance) {
69589 this.outputs = modelLayer.outputs;
69590 this.inputs = modelLayer.inputs;
69591 } else {
69592 if (layer.inboundNodes.length !== 1) {
69593 throw new ValueError('A layer added to a Sequential model must not already be ' + "connected somewhere else. LayersModel received layer ".concat(layer.name, " ") + "which has ".concat(layer.inboundNodes.length, " pre-existing inbound ") + 'connections.');
69594 }
69595 if (layer.inboundNodes[0].outputTensors.length !== 1) {
69596 throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.');
69597 }
69598 this.checkShape(layer);
69599 this.outputs = [layer.inboundNodes[0].outputTensors[0]];
69600 this.inputs = getSourceInputs(this.outputs[0]);
69601 }
69602 this.inboundNodes = [];
69603 // We create an input node, which we will keep updated
69604 // as we add more layers.
69605 // (This call has side effects.)
69606 // tslint:disable-next-line:no-unused-expression
69607 new Node({
69608 outboundLayer: this,
69609 inboundLayers: [],
69610 nodeIndices: [],
69611 tensorIndices: [],
69612 inputTensors: this.inputs,
69613 outputTensors: this.outputs,
69614 // no model-level masking for now
69615 inputMasks: pyListRepeat(null, this.inputs.length),
69616 outputMasks: [null],
69617 inputShapes: this.inputs.map(function (x) {
69618 return x.shape;
69619 }),
69620 outputShapes: this.outputs[0].shape
69621 });
69622 } else {
69623 var outputTensor = layer.apply(this.outputs[0]);
69624 if (Array.isArray(outputTensor)) {
69625 throw new TypeError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.');
69626 }
69627 this.checkShape(layer);
69628 this.outputs = [outputTensor];
69629 // update self.inbound_nodes
69630 this.inboundNodes[0].outputTensors = this.outputs;
69631 this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
69632 }
69633 this.layers.push(layer);
69634 this.built = false;
69635 }
69636 /**
69637 * Removes the last layer in the model.
69638 *
69639 * @exception TypeError if there are no layers in the model.
69640 */
69641 }, {
69642 key: "pop",
69643 value: function pop() {
69644 if (this.layers.length === 0) {
69645 throw new TypeError('There are no layers in the model.');
69646 }
69647 this.layers.pop();
69648 if (this.layers.length === 0) {
69649 this.outputs = [];
69650 this.inboundNodes = [];
69651 this.outboundNodes = [];
69652 } else {
69653 var lastLayerIndex = this.layers.length - 1;
69654 this.layers[lastLayerIndex].outboundNodes = [];
69655 this.outputs = [this.layers[lastLayerIndex].output];
69656 // update self.inbound_nodes
69657 this.inboundNodes[0].outputTensors = this.outputs;
69658 this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
69659 }
69660 }
69661 }, {
69662 key: "call",
69663 value: function call(inputs, kwargs) {
69664 if (this.model == null) {
69665 this.build();
69666 }
69667 return this.model.call(inputs, kwargs);
69668 }
69669 }, {
69670 key: "build",
69671 value: function build(inputShape) {
69672 // Call `getExactlyOneShape` without using its return value,
69673 // to verify that exactly one input shape is provided.
69674 getExactlyOneShape(inputShape);
69675 if (this.inputs.length === 0 || this.outputs.length === 0) {
69676 throw new TypeError('Sequential model cannot be built: model is empty.' + ' Add some layers first.');
69677 }
69678 // actually create the model
69679 this.model = new LayersModel({
69680 inputs: this.inputs,
69681 outputs: this.outputs[0],
69682 name: this.name + '_model'
69683 });
69684 this.model.trainable = this.trainable;
69685 // mirror model attributes
69686 this.supportsMasking = this.model.supportsMasking;
69687 // TODO(michaelterry): Add caches
69688 this.inputLayers = this.model.inputLayers;
69689 this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
69690 this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
69691 this.outputLayers = this.model.outputLayers;
69692 this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
69693 this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
69694 this.nodesByDepth = this.model.nodesByDepth;
69695 this.containerNodes = this.model.containerNodes;
69696 this.outputNames = this.model.outputNames;
69697 this.inputNames = this.model.inputNames;
69698 // TODO(michaelterry): Add feedInputNames, feedInputs, if needed.
69699 // TODO(michaelterry): Add callbackModel if needed.
69700 this.built = true;
69701 }
69702 }, {
69703 key: "countParams",
69704 value: function countParams() {
69705 if (!this.built) {
69706 this.build();
69707 }
69708 return _get(_getPrototypeOf(Sequential.prototype), "countParams", this).call(this);
69709 }
69710 /**
69711 * Print a text summary of the Sequential model's layers.
69712 *
69713 * The summary includes
69714 * - Name and type of all layers that comprise the model.
69715 * - Output shape(s) of the layers
69716 * - Number of weight parameters of each layer
69717 * - The total number of trainable and non-trainable parameters of the
69718 * model.
69719 *
69720 * ```js
69721 * const model = tf.sequential();
69722 * model.add(
69723 * tf.layers.dense({units: 100, inputShape: [10], activation: 'relu'}));
69724 * model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
69725 *
69726 * model.summary();
69727 * ```
69728 *
69729 * @param lineLength Custom line length, in number of characters.
69730 * @param positions Custom widths of each of the columns, as either
69731 * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
69732 * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
69733 * right-most (i.e., ending) position of a column.
69734 * @param printFn Custom print function. Can be used to replace the default
69735 * `console.log`. For example, you can use `x => {}` to mute the printed
69736 * messages in the console.
69737 *
69738 * @doc {heading: 'Models', subheading: 'Classes'}
69739 */
69740 }, {
69741 key: "summary",
69742 value: function summary(lineLength, positions) {
69743 var printFn = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : console.log;
69744 if (!this.built) {
69745 this.build();
69746 }
69747 _get(_getPrototypeOf(Sequential.prototype), "summary", this).call(this, lineLength, positions, printFn);
69748 }
69749 /**
69750 * Sets the weights of the model.
69751 *
69752 * @param weights Should be a list of Tensors with shapes and types matching
69753 * the output of `model.getWeights()`.
69754 */
69755 }, {
69756 key: "setWeights",
69757 value: function setWeights(weights) {
69758 if (this.model == null) {
69759 this.build();
69760 }
69761 this.model.setWeights(weights);
69762 }
69763 /**
69764 * Returns the loss value & metrics values for the model in test mode.
69765 *
69766 * Loss and metrics are specified during `compile()`, which needs to happen
69767 * before calls to `evaluate()`.
69768 *
69769 * Computation is done in batches.
69770 *
69771 * ```js
69772 * const model = tf.sequential({
69773 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
69774 * });
69775 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
69776 * const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
69777 * batchSize: 4,
69778 * });
69779 * result.print();
69780 * ```
69781 *
69782 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
69783 * model has multiple inputs.
69784 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
69785 * model has multiple outputs.
69786 * @param args A `ModelEvaluateConfig`, containing optional fields.
69787 *
69788 * @return `Scalar` test loss (if the model has a single output and no
69789 * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
69790 * and/or metrics). The attribute `model.metricsNames`
69791 * will give you the display labels for the scalar outputs.
69792 *
69793 * @doc {heading: 'Models', subheading: 'Classes'}
69794 */
69795 }, {
69796 key: "evaluate",
69797 value: function evaluate(x, y) {
69798 var args = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
69799 if (!this.built) {
69800 throw new RuntimeError('The model needs to be compiled before being used.');
69801 }
69802 return this.model.evaluate(x, y, args);
69803 }
69804 // TODO(cais): Add code snippet below once real dataset objects are
69805 // available.
69806 /**
69807 * Evaluate model using a dataset object.
69808 *
69809 * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
69810 *
69811 * @param dataset A dataset object. Its `iterator()` method is expected
69812 * to generate a dataset iterator object, the `next()` method of which
69813 * is expected to produce data batches for evaluation. The return value
69814 * of the `next()` call ought to contain a boolean `done` field and a
69815 * `value` field. The `value` field is expected to be an array of two
69816 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
69817 * case is for models with exactly one input and one output (e.g.
69818 * a sequential model). The latter case is for models with multiple
69819 * inputs and/or multiple outputs. Of the two items in the array, the
69820 * first is the input feature(s) and the second is the output target(s).
69821 * @param args A configuration object for the dataset-based evaluation.
69822 * @returns Loss and metric values as an Array of `Scalar` objects.
69823 *
69824 * @doc {heading: 'Models', subheading: 'Classes'}
69825 */
69826 }, {
69827 key: "evaluateDataset",
69828 value: function () {
69829 var _evaluateDataset = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(dataset, args) {
69830 return _regeneratorRuntime().wrap(function _callee$(_context) {
69831 while (1) switch (_context.prev = _context.next) {
69832 case 0:
69833 if (this.built) {
69834 _context.next = 2;
69835 break;
69836 }
69837 throw new RuntimeError('The model needs to be compiled before being used.');
69838 case 2:
69839 return _context.abrupt("return", this.model.evaluateDataset(dataset, args));
69840 case 3:
69841 case "end":
69842 return _context.stop();
69843 }
69844 }, _callee, this);
69845 }));
69846 function evaluateDataset(_x8, _x9) {
69847 return _evaluateDataset.apply(this, arguments);
69848 }
69849 return evaluateDataset;
69850 }()
69851 /**
69852 * Generates output predictions for the input samples.
69853 *
69854 * Computation is done in batches.
69855 *
69856 * Note: the "step" mode of predict() is currently not supported.
69857 * This is because the TensorFlow.js core backend is imperative only.
69858 *
69859 * ```js
69860 * const model = tf.sequential({
69861 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
69862 * });
69863 * model.predict(tf.ones([2, 10])).print();
69864 * ```
69865 *
69866 * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
69867 * the model has multiple inputs.
69868 * @param conifg A `ModelPredictConfig` object containing optional fields.
69869 *
69870 * @return `tf.Tensor`(s) of predictions.
69871 *
69872 * @exception ValueError In case of mismatch between the provided input data
69873 * and the model's expectations, or in case a stateful model receives a
69874 * number of samples that is not a multiple of the batch size.
69875 *
69876 * @doc {heading: 'Models', subheading: 'Classes'}
69877 */
69878 }, {
69879 key: "predict",
69880 value: function predict(x) {
69881 var args = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
69882 if (this.model == null) {
69883 this.build();
69884 }
69885 return this.model.predict(x, args);
69886 }
69887 /**
69888 * Returns predictions for a single batch of samples.
69889 *
69890 * @param x: Input samples, as a Tensor, or list of Tensors (if the model
69891 * has multiple inputs).
69892 * @return Tensor(s) of predictions
69893 */
69894 }, {
69895 key: "predictOnBatch",
69896 value: function predictOnBatch(x) {
69897 if (this.model == null) {
69898 this.build();
69899 }
69900 return this.model.predictOnBatch(x);
69901 }
69902 /**
69903 * See `LayersModel.compile`.
69904 *
69905 * @param args
69906 */
69907 }, {
69908 key: "compile",
69909 value: function compile(args) {
69910 this.build();
69911 this.model.compile(args);
69912 this.optimizer_ = this.model.optimizer;
69913 // tslint:disable-next-line:no-any
69914 this.isOptimizerOwned = this.model.isOptimizerOwned;
69915 this.loss = this.model.loss;
69916 this.metrics = this.model.metrics;
69917 // TODO(cais): Add this.lossWeights, this.sampleWeightMode,
69918 // this.weightedMetrics, this.targets.
69919 this.metricsTensors = this.model.metricsTensors;
69920 this.metricsNames = this.model.metricsNames;
69921 // TODO(cais): Add sampleWeights.
69922 }
69923 }, {
69924 key: "optimizer",
69925 get: function get() {
69926 return this.model == null ? undefined : this.model.optimizer;
69927 },
69928 set: function set(optimizer) {
69929 this.model.optimizer = optimizer;
69930 }
69931 /**
69932 * Trains the model for a fixed number of epochs (iterations on a dataset).
69933 *
69934 * ```js
69935 * const model = tf.sequential({
69936 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
69937 * });
69938 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
69939 * const history = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
69940 * batchSize: 4,
69941 * epochs: 3
69942 * });
69943 * console.log(history.history.loss[0]);
69944 * ```
69945 *
69946 * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
69947 * model has multiple inputs. If all inputs in the model are named, you can
69948 * also pass a dictionary mapping input names to `tf.Tensor`s.
69949 * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
69950 * the model has multiple outputs. If all outputs in the model are named, you
69951 * can also pass a dictionary mapping output names to `tf.Tensor`s.
69952 * @param args A `ModelFitConfig`, containing optional fields.
69953 *
69954 * @return A `History` instance. Its `history` attribute contains all
69955 * information collected during training.
69956 *
69957 * @exception ValueError In case of mismatch between the provided input data
69958 * and what the model expects.
69959 *
69960 * @doc {heading: 'Models', subheading: 'Classes'}
69961 */
69962 }, {
// `fit` as emitted by Babel's async-to-generator/regenerator transform: the
// state machine below is the compiled body of an `async fit(x, y, args)`.
69963 key: "fit",
69964 value: function () {
69965 var _fit = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(x, y) {
69966 var args,
69967 _args2 = arguments;
69968 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
69969 while (1) switch (_context2.prev = _context2.next) {
69970 case 0:
// Optional third argument; defaults to {} when missing or undefined.
69971 args = _args2.length > 2 && _args2[2] !== undefined ? _args2[2] : {};
// Guard: only delegate once `this.built` is truthy. Note that the check is on
// `built` although the message talks about compilation. The throw happens
// inside the async wrapper, so callers presumably see a rejected promise.
69972 if (this.built) {
69973 _context2.next = 3;
69974 break;
69975 }
69976 throw new RuntimeError('The model needs to be compiled before ' + 'being used.');
69977 case 3:
// Forward to the wrapped inner model and return its result.
69978 return _context2.abrupt("return", this.model.fit(x, y, args));
69979 case 4:
69980 case "end":
69981 return _context2.stop();
69982 }
69983 }, _callee2, this);
69984 }));
// Public wrapper: declares two formal parameters (x, y) and forwards all
// actual arguments, including the optional `args`, with `this` preserved.
69985 function fit(_x10, _x11) {
69986 return _fit.apply(this, arguments);
69987 }
69988 return fit;
69989 }()
69990 /**
69991 * Trains the model using a dataset object.
69992 *
69993 * ```js
69994 * const xArray = [
69995 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
69996 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
69997 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
69998 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
69999 * ];
70000 * const yArray = [1, 1, 1, 1];
70001 * // Create a dataset from the JavaScript array.
70002 * const xDataset = tf.data.array(xArray);
70003 * const yDataset = tf.data.array(yArray);
70004 * // Zip combines the `x` and `y` Datasets into a single Dataset, the
70005 * // iterator of which will return an object containing two tensors,
70006 * // corresponding to `x` and `y`. The call to `batch(4)` will bundle
70007 * // four such samples into a single object, with the same keys now pointing
70008 * // to tensors that hold 4 examples, organized along the batch dimension.
70009 * // The call to `shuffle(4)` causes each iteration through the dataset to
70010 * // happen in a different order. The size of the shuffle window is 4.
70011 * const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset})
70012 * .batch(4)
70013 * .shuffle(4);
70014 * const model = tf.sequential({
70015 * layers: [tf.layers.dense({units: 1, inputShape: [9]})]
70016 * });
70017 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
70018 * const history = await model.fitDataset(xyDataset, {
70019 * epochs: 4,
70020 * callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
70021 * });
70022 * ```
70023 *
70024 * @param dataset A dataset object. Its `iterator()` method is expected to
70025 * generate a dataset iterator object, the `next()` method of which is
70026 * expected to produce data batches for training. The return value of the
70027 * `next()` call ought to contain a boolean `done` field and a `value`
70028 * field.
70029 *
70030 * The `value` field is expected to be an object with fields
70031 * `xs` and `ys`, which point to the feature tensor and the target tensor,
70032 * respectively. This case is for models with exactly one input and one
70033 * output (e.g. a sequential model). For example:
70034 * ```js
70035 * {value: {xs: xsTensor, ys: ysTensor}, done: false}
70036 * ```
70037 *
70038 * If the model has multiple inputs, the `xs` field of `value` should
70039 * be an object mapping input names to their respective feature tensors.
70040 * For example:
70041 * ```js
70042 * {
70043 * value: {
70044 * xs: {
70045 * input_1: xsTensor1,
70046 * input_2: xsTensor2
70047 * },
70048 * ys: ysTensor
70049 * },
70050 * done: false
70051 * }
70052 * ```
70053 * If the model has multiple outputs, the `ys` field of `value` should
70054 * be an object mapping output names to their respective target tensors.
70055 * For example:
70056 * ```js
70057 * {
70058 * value: {
70059 * xs: xsTensor,
70060 * ys: {
70061 * output_1: ysTensor1,
70062 * output_2: ysTensor2
70063 * },
70064 * },
70065 * done: false
70066 * }
70067 * ```
70068 * @param args A `ModelFitDatasetArgs`, containing optional fields.
70069 *
70070 * @return A `History` instance. Its `history` attribute contains all
70071 * information collected during training.
70072 *
70073 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
70074 */
70075 }, {
// `fitDataset` as emitted by Babel's async-to-generator/regenerator transform;
// the compiled body of an `async fitDataset(dataset, args)`.
70076 key: "fitDataset",
70077 value: function () {
70078 var _fitDataset = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(dataset, args) {
70079 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
70080 while (1) switch (_context3.prev = _context3.next) {
70081 case 0:
// Same guard as `fit`: refuse unless `this.built` is truthy. The throw occurs
// inside the async wrapper (presumably surfacing as a rejected promise).
70082 if (this.built) {
70083 _context3.next = 2;
70084 break;
70085 }
70086 throw new RuntimeError('The model needs to be compiled before ' + 'being used.');
70087 case 2:
// Delegate to the wrapped inner model; `args` is passed through unchanged.
70088 return _context3.abrupt("return", this.model.fitDataset(dataset, args));
70089 case 3:
70090 case "end":
70091 return _context3.stop();
70092 }
70093 }, _callee3, this);
70094 }));
// Public wrapper that forwards all arguments with `this` preserved.
70095 function fitDataset(_x12, _x13) {
70096 return _fitDataset.apply(this, arguments);
70097 }
70098 return fitDataset;
70099 }()
70100 /**
70101 * Runs a single gradient update on a single batch of data.
70102 *
70103 * This method differs from `fit()` and `fitDataset()` in the following
70104 * regards:
70105 * - It operates on exactly one batch of data.
70106 * - It returns only the loss and metric values, instead of
70107 * returning the batch-by-batch loss and metric values.
70108 * - It doesn't support fine-grained options such as verbosity and
70109 * callbacks.
70110 *
70111 * @param x Input data. It could be one of the following:
70112 * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
70113 * multiple inputs).
70114 * - An Object mapping input names to corresponding `tf.Tensor` (if the
70115 * model has named inputs).
70116 * @param y Target data. It could be either a `tf.Tensor` or multiple
70117 * `tf.Tensor`s. It should be consistent with `x`.
70118 * @returns Training loss or losses (in case the model has
70119 * multiple outputs), along with metrics (if any), as numbers.
70120 *
70121 * @doc {heading: 'Models', subheading: 'Classes'}
70122 */
70123 }, {
70124 key: "trainOnBatch",
70125 value: function () {
70126 var _trainOnBatch = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(x, y) {
70127 return _regeneratorRuntime().wrap(function _callee4$(_context4) {
70128 while (1) switch (_context4.prev = _context4.next) {
70129 case 0:
70130 return _context4.abrupt("return", this.model.trainOnBatch(x, y));
70131 case 1:
70132 case "end":
70133 return _context4.stop();
70134 }
70135 }, _callee4, this);
70136 }));
70137 function trainOnBatch(_x14, _x15) {
70138 return _trainOnBatch.apply(this, arguments);
70139 }
70140 return trainOnBatch;
70141 }()
70142 /* See parent class for JsDoc */
70143 /** @nocollapse */
70144 }, {
70145 key: "stopTraining",
70146 get: function get() {
70147 if (this.model == null) {
70148 throw new ValueError('Cannot get the stopTraining property of a sequential model before ' + 'it is compiled.');
70149 }
70150 return this.model.stopTraining;
70151 }
70152 // TODO(cais): Override get trainableWeights() here
70153 // tslint:disable-next-line:no-any
70154 ,
70155 set:
70156 /**
70157 * Setter used for force stopping of LayersModel.fit() (i.e., training).
70158 *
70159 * Example:
70160 *
70161 * ```js
70162 * const model = tf.sequential();
70163 * model.add(tf.layers.dense({units: 1, inputShape: [10]}));
70164 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
70165 * const xs = tf.ones([8, 10]);
70166 * const ys = tf.zeros([8, 1]);
70167 *
70168 * const history = await model.fit(xs, ys, {
70169 * epochs: 10,
70170 * callbacks: {
70171 * onEpochEnd: async (epoch, logs) => {
70172 * if (epoch === 2) {
70173 * model.stopTraining = true;
70174 * }
70175 * }
70176 * }
70177 * });
70178 *
70179 * // There should be only 3 values in the loss array, instead of 10 values,
70180 * // due to the stopping after 3 epochs.
70181 * console.log(history.history.loss);
70182 * ```
70183 */
70184 function set(stop) {
70185 // TODO(cais): When refactoring to remove the composition pattern happens,
70186 // remove this method overriding.
70187 if (this.model == null) {
70188 throw new ValueError('Cannot set the stopTraining property of a sequential model before ' + 'it is compiled.');
70189 }
70190 this.model.stopTraining = stop;
70191 }
70192 }, {
70193 key: "getConfig",
70194 value: function getConfig() {
70195 // NOTE(cais): We override the return type of getConfig() to `any` here,
70196 // because the `Sequential` class is a special case among `Container`
70197 // subtypes in that its getConfig() method returns an Array (not a
70198 // dict).
70199 var layers = [];
70200 var _iterator2 = _createForOfIteratorHelper(this.layers),
70201 _step2;
70202 try {
70203 for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
70204 var layer = _step2.value;
70205 var dict = {};
70206 dict['className'] = layer.getClassName();
70207 dict['config'] = layer.getConfig();
70208 layers.push(dict);
70209 }
70210 } catch (err) {
70211 _iterator2.e(err);
70212 } finally {
70213 _iterator2.f();
70214 }
70215 return {
70216 name: this.name,
70217 layers: layers
70218 };
70219 }
70220 }], [{
70221 key: "fromConfig",
70222 value: function fromConfig(cls, config) {
70223 var customObjects = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
70224 var fastWeightInit = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
70225 var configArray;
70226 var extraModelConfig = {};
70227 if (config instanceof Array) {
70228 if (!(config[0].className != null) || config[0]['className'] === 'Merge') {
70229 throw new ValueError('Legacy serialization format not supported yet.');
70230 }
70231 configArray = config;
70232 } else {
70233 assert$1(config['layers'] != null, function () {
70234 return "When the config data for a Sequential model is not an Array, " + "it must be an Object that contains the 'layers' field.";
70235 });
70236 configArray = config['layers'];
70237 delete config['layers'];
70238 extraModelConfig = config;
70239 }
70240 var model = new cls(extraModelConfig);
70241 if (!(model instanceof Sequential)) {
70242 throw new NotImplementedError("Sequential.fromConfig called on non-Sequential input: ".concat(model));
70243 }
70244 var _iterator3 = _createForOfIteratorHelper(configArray),
70245 _step3;
70246 try {
70247 for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
70248 var conf = _step3.value;
70249 var _customObjects = undefined;
70250 var layer = deserialize(conf, _customObjects, fastWeightInit);
70251 if (fastWeightInit) {
70252 layer.setFastWeightInitDuringBuild(true);
70253 }
70254 model.add(layer);
70255 }
70256 } catch (err) {
70257 _iterator3.e(err);
70258 } finally {
70259 _iterator3.f();
70260 }
70261 return model;
70262 }
70263 }]);
70264 return Sequential;
70265 }(LayersModel);
// Static serialization name for Sequential. It must be assigned before
// `registerClass` is called with the class (presumably that call records it
// in the serialization registry under this className; confirm).
70266 /** @nocollapse */
70267 Sequential.className = 'Sequential';
70268 registerClass(Sequential);
70269
70270 /**
70271 * @license
70272 * Copyright 2018 Google LLC
70273 *
70274 * Use of this source code is governed by an MIT-style
70275 * license that can be found in the LICENSE file or at
70276 * https://opensource.org/licenses/MIT.
70277 * =============================================================================
70278 */
70279 // TODO(cais): Add doc string to all the public static functions in this
70280 // class; include executable JavaScript code snippets where applicable
70281 // (b/74074458).
70282 // LayersModel and related factory methods.
70283 /**
70284 * A model is a data structure that consists of `Layers` and defines inputs
70285 * and outputs.
70286 *
70287 * The key difference between `tf.model` and `tf.sequential` is that
70288 * `tf.model` is more generic, supporting an arbitrary graph (without
70289 * cycles) of layers. `tf.sequential` is less generic and supports only a linear
70290 * stack of layers.
70291 *
70292 * When creating a `tf.LayersModel`, specify its input(s) and output(s). Layers
70293 * are used to wire input(s) to output(s).
70294 *
70295 * For example, the following code snippet defines a model consisting of
70296 * two `dense` layers, with 10 and 4 units, respectively.
70297 *
70298 * ```js
70299 * // Define input, which has a size of 5 (not including batch dimension).
70300 * const input = tf.input({shape: [5]});
70301 *
70302 * // First dense layer uses relu activation.
70303 * const denseLayer1 = tf.layers.dense({units: 10, activation: 'relu'});
70304 * // Second dense layer uses softmax activation.
70305 * const denseLayer2 = tf.layers.dense({units: 4, activation: 'softmax'});
70306 *
70307 * // Obtain the output symbolic tensor by applying the layers on the input.
70308 * const output = denseLayer2.apply(denseLayer1.apply(input));
70309 *
70310 * // Create the model based on the inputs.
70311 * const model = tf.model({inputs: input, outputs: output});
70312 *
70313 * // The model can be used for training, evaluation and prediction.
70314 * // For example, the following line runs prediction with the model on
70315 * // some fake data.
70316 * model.predict(tf.ones([2, 5])).print();
70317 * ```
70318 * See also:
70319 * `tf.sequential`, `tf.loadLayersModel`.
70320 *
70321 * @doc {heading: 'Models', subheading: 'Creation'}
70322 */
function model(args) {
  // Thin factory around the LayersModel constructor.
  var layersModel = new LayersModel(args);
  return layersModel;
}
70326 /**
70327 * Creates a `tf.Sequential` model. A sequential model is any model where the
70328 * outputs of one layer are the inputs to the next layer, i.e. the model
70329 * topology is a simple 'stack' of layers, with no branching or skipping.
70330 *
70331 * This means that the first layer passed to a `tf.Sequential` model should have
70332 * a defined input shape. What that means is that it should have received an
70333 * `inputShape` or `batchInputShape` argument, or for some type of layers
70334 * (recurrent, Dense...) an `inputDim` argument.
70335 *
70336 * The key difference between `tf.model` and `tf.sequential` is that
70337 * `tf.sequential` is less generic, supporting only a linear stack of layers.
70338 * `tf.model` is more generic and supports an arbitrary graph (without
70339 * cycles) of layers.
70340 *
70341 * Examples:
70342 *
70343 * ```js
70344 * const model = tf.sequential();
70345 *
70346 * // First layer must have an input shape defined.
70347 * model.add(tf.layers.dense({units: 32, inputShape: [50]}));
70348 * // Afterwards, TF.js does automatic shape inference.
70349 * model.add(tf.layers.dense({units: 4}));
70350 *
70351 * // Inspect the inferred shape of the model's output, which equals
70352 * // `[null, 4]`. The 1st dimension is the undetermined batch dimension; the
70353 * // 2nd is the output size of the model's last layer.
70354 * console.log(JSON.stringify(model.outputs[0].shape));
70355 * ```
70356 *
70357 * It is also possible to specify a batch size (with potentially undetermined
70358 * batch dimension, denoted by "null") for the first layer using the
70359 * `batchInputShape` key. The following example is equivalent to the above:
70360 *
70361 * ```js
70362 * const model = tf.sequential();
70363 *
70364 * // First layer must have a defined input shape
70365 * model.add(tf.layers.dense({units: 32, batchInputShape: [null, 50]}));
70366 * // Afterwards, TF.js does automatic shape inference.
70367 * model.add(tf.layers.dense({units: 4}));
70368 *
70369 * // Inspect the inferred shape of the model's output.
70370 * console.log(JSON.stringify(model.outputs[0].shape));
70371 * ```
70372 *
70373 * You can also use an `Array` of already-constructed `Layer`s to create
70374 * a `tf.Sequential` model:
70375 *
70376 * ```js
70377 * const model = tf.sequential({
70378 * layers: [tf.layers.dense({units: 32, inputShape: [50]}),
70379 * tf.layers.dense({units: 4})]
70380 * });
70381 * console.log(JSON.stringify(model.outputs[0].shape));
70382 * ```
70383 *
70384 * @doc {heading: 'Models', subheading: 'Creation'}
70385 */
function sequential(config) {
  // Thin factory around the Sequential constructor.
  var stack = new Sequential(config);
  return stack;
}
70389 /**
70390 * Used to instantiate an input to a model as a `tf.SymbolicTensor`.
70391 *
70392 * Users should call the `input` factory function for
70393 * consistency with other generator functions.
70394 *
70395 * Example:
70396 *
70397 * ```js
70398 * // Defines a simple logistic regression model with 32 dimensional input
70399 * // and 3 dimensional output.
70400 * const x = tf.input({shape: [32]});
70401 * const y = tf.layers.dense({units: 3, activation: 'softmax'}).apply(x);
70402 * const model = tf.model({inputs: x, outputs: y});
70403 * model.predict(tf.ones([2, 32])).print();
70404 * ```
70405 *
70406 * Note: `input` is only necessary when using `model`. When using
70407 * `sequential`, specify `inputShape` for the first layer or use `inputLayer`
70408 * as the first layer.
70409 *
70410 * @doc {heading: 'Models', subheading: 'Inputs'}
70411 */
function input(config) {
  // Lower-case alias that forwards to the `Input` factory.
  var symbolic = Input(config);
  return symbolic;
}
/**
 * Registers a callback constructor for the given verbosity level by forwarding
 * both arguments to `CallbackConstructorRegistry.registerCallbackConstructor`.
 */
function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
  var registry = CallbackConstructorRegistry;
  registry.registerCallbackConstructor(verbosityLevel, callbackConstructor);
}
70418
70419 /**
70420 * Base class for Activations.
70421 *
70422 * Special note: due to cross-language compatibility reasons, the
70423 * static readonly className field in this family of classes must be set to
70424 * the initialLowerCamelCase name of the activation.
70425 */
var Activation$1 = /*#__PURE__*/function (Base) {
  _inherits(Activation, Base);
  var superCtor = _createSuper(Activation);
  function Activation() {
    _classCallCheck(this, Activation);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "getConfig",
    /** Activations carry no configuration, so a fresh empty object is returned. */
    value: function getConfig() {
      var emptyConfig = {};
      return emptyConfig;
    }
  }];
  _createClass(Activation, methods);
  return Activation;
}(Serializable);
70441 /**
70442 * Exponential linear unit (ELU).
70443 * Reference: https://arxiv.org/abs/1511.07289
70444 */
var Elu = /*#__PURE__*/function (Base) {
  _inherits(Elu, Base);
  var superCtor = _createSuper(Elu);
  function Elu() {
    _classCallCheck(this, Elu);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /**
     * Applies ELU by delegating to the core `elu` op.
     *
     * @param x Input tensor.
     * @param alpha Scaling factor for the negative section (defaults to 1).
     * @return Output of the ELU activation.
     */
    value: function apply(x) {
      var alpha = 1;
      if (arguments.length > 1 && arguments[1] !== undefined) {
        alpha = arguments[1];
      }
      return elu$3(x, alpha);
    }
  }];
  _createClass(Elu, methods);
  return Elu;
}(Activation$1);
/** @nocollapse */
Elu.className = 'elu';
registerClass(Elu);
70472 /**
70473 * Scaled Exponential Linear Unit. (Klambauer et al., 2017).
70474 * Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515
70475 * Notes:
70476 * - To be used together with the initialization "lecunNormal".
70477 * - To be used together with the dropout variant "AlphaDropout".
70478 */
var Selu = /*#__PURE__*/function (Base) {
  _inherits(Selu, Base);
  var superCtor = _createSuper(Selu);
  function Selu() {
    _classCallCheck(this, Selu);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Applies SELU by delegating to the core `selu` op. */
    value: function apply(x) {
      return selu$2(x);
    }
  }];
  _createClass(Selu, methods);
  return Selu;
}(Activation$1);
/** @nocollapse */
Selu.className = 'selu';
registerClass(Selu);
70497 /**
70498 * Rectified linear unit
70499 */
var Relu = /*#__PURE__*/function (Base) {
  _inherits(Relu, Base);
  var superCtor = _createSuper(Relu);
  function Relu() {
    _classCallCheck(this, Relu);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Applies ReLU by delegating to the core `relu` op. */
    value: function apply(x) {
      return relu$2(x);
    }
  }];
  _createClass(Relu, methods);
  return Relu;
}(Activation$1);
/** @nocollapse */
Relu.className = 'relu';
registerClass(Relu);
70518 /**
70519 * Rectified linear unit activation maxing out at 6.0.
70520 */
var Relu6 = /*#__PURE__*/function (Base) {
  _inherits(Relu6, Base);
  var superCtor = _createSuper(Relu6);
  function Relu6() {
    _classCallCheck(this, Relu6);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** ReLU clipped from above at 6.0: min(6, relu(x)). */
    value: function apply(x) {
      return tidy(function () {
        var rectified = relu$2(x);
        return minimum$4(6.0, rectified);
      });
    }
  }];
  _createClass(Relu6, methods);
  return Relu6;
}(Activation$1);
/** @nocollapse */
Relu6.className = 'relu6';
registerClass(Relu6);
/** Linear activation (no-op): returns its input unchanged. */
var Linear = /*#__PURE__*/function (Base) {
  _inherits(Linear, Base);
  var superCtor = _createSuper(Linear);
  function Linear() {
    _classCallCheck(this, Linear);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Identity: the very same tensor object is returned. */
    value: function apply(x) {
      var identity = x;
      return identity;
    }
  }];
  _createClass(Linear, methods);
  return Linear;
}(Activation$1);
/** @nocollapse */
Linear.className = 'linear';
registerClass(Linear);
70560 /**
70561 * Sigmoid activation function.
70562 */
var Sigmoid = /*#__PURE__*/function (Base) {
  _inherits(Sigmoid, Base);
  var superCtor = _createSuper(Sigmoid);
  function Sigmoid() {
    _classCallCheck(this, Sigmoid);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Applies the logistic sigmoid by delegating to the core `sigmoid` op. */
    value: function apply(x) {
      return sigmoid$2(x);
    }
  }];
  _createClass(Sigmoid, methods);
  return Sigmoid;
}(Activation$1);
/** @nocollapse */
Sigmoid.className = 'sigmoid';
registerClass(Sigmoid);
70581 /**
70582 * Segment-wise linear approximation of sigmoid.
70583 */
var HardSigmoid = /*#__PURE__*/function (Base) {
  _inherits(HardSigmoid, Base);
  var superCtor = _createSuper(HardSigmoid);
  function HardSigmoid() {
    _classCallCheck(this, HardSigmoid);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Piecewise-linear sigmoid approximation via the `hardSigmoid` helper. */
    value: function apply(x) {
      return hardSigmoid(x);
    }
  }];
  _createClass(HardSigmoid, methods);
  return HardSigmoid;
}(Activation$1);
/** @nocollapse */
HardSigmoid.className = 'hardSigmoid';
registerClass(HardSigmoid);
70602 /**
70603 * Softplus activation function.
70604 */
var Softplus = /*#__PURE__*/function (Base) {
  _inherits(Softplus, Base);
  var superCtor = _createSuper(Softplus);
  function Softplus() {
    _classCallCheck(this, Softplus);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Applies softplus by delegating to the core `softplus` op. */
    value: function apply(x) {
      return softplus$2(x);
    }
  }];
  _createClass(Softplus, methods);
  return Softplus;
}(Activation$1);
/** @nocollapse */
Softplus.className = 'softplus';
registerClass(Softplus);
70623 /**
70624 * Softsign activation function.
70625 */
var Softsign = /*#__PURE__*/function (Base) {
  _inherits(Softsign, Base);
  var superCtor = _createSuper(Softsign);
  function Softsign() {
    _classCallCheck(this, Softsign);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Applies softsign via the `softsign` helper. */
    value: function apply(x) {
      return softsign(x);
    }
  }];
  _createClass(Softsign, methods);
  return Softsign;
}(Activation$1);
/** @nocollapse */
Softsign.className = 'softsign';
registerClass(Softsign);
70644 /**
70645 * Hyperbolic tangent function.
70646 */
var Tanh = /*#__PURE__*/function (Base) {
  _inherits(Tanh, Base);
  var superCtor = _createSuper(Tanh);
  function Tanh() {
    _classCallCheck(this, Tanh);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /** Applies the hyperbolic tangent by delegating to the core `tanh` op. */
    value: function apply(x) {
      return tanh$2(x);
    }
  }];
  _createClass(Tanh, methods);
  return Tanh;
}(Activation$1);
/** @nocollapse */
Tanh.className = 'tanh';
registerClass(Tanh);
70665 /**
70666 * Softmax activation function
70667 */
var Softmax$1 = /*#__PURE__*/function (Base) {
  _inherits(Softmax, Base);
  var superCtor = _createSuper(Softmax);
  function Softmax() {
    _classCallCheck(this, Softmax);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /**
     * Applies softmax by delegating to the core `softmax` op.
     *
     * @param x Tensor.
     * @param axis Axis along which normalization happens (defaults to -1, the
     *   last axis).
     * @returns a Tensor of the same shape as x.
     * NOTE(review): the upstream docs say inputs with dim(x) < 2 are rejected
     * with a ValueError; any such check lives in the called op, not here.
     */
    value: function apply(x) {
      var axis = -1;
      if (arguments.length > 1 && arguments[1] !== undefined) {
        axis = arguments[1];
      }
      return softmax$3(x, axis);
    }
  }];
  _createClass(Softmax, methods);
  return Softmax;
}(Activation$1);
/** @nocollapse */
Softmax$1.className = 'softmax';
registerClass(Softmax$1);
70700 /**
70701 * Log softmax activation function
70702 */
var LogSoftmax = /*#__PURE__*/function (Base) {
  _inherits(LogSoftmax, Base);
  var superCtor = _createSuper(LogSoftmax);
  function LogSoftmax() {
    _classCallCheck(this, LogSoftmax);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /**
     * Log-softmax, log(exp(x_i) / sum(exp(x))), via the `logSoftmax` op.
     *
     * @param x Tensor.
     * @param axis Axis along which normalization happens (defaults to -1, the
     *   last axis).
     * @returns a Tensor of the same shape as x.
     * NOTE(review): the upstream docs say inputs with dim(x) < 2 are rejected
     * with a ValueError; any such check lives in the called op, not here.
     */
    value: function apply(x) {
      var axis = -1;
      if (arguments.length > 1 && arguments[1] !== undefined) {
        axis = arguments[1];
      }
      return logSoftmax(x, axis);
    }
  }];
  _createClass(LogSoftmax, methods);
  return LogSoftmax;
}(Activation$1);
/** @nocollapse */
LogSoftmax.className = 'logSoftmax';
registerClass(LogSoftmax);
70736 /**
70737 * Gelu activation function
70738 */
var Gelu = /*#__PURE__*/function (_Activation13) {
  _inherits(Gelu, _Activation13);
  var _super14 = _createSuper(Gelu);
  function Gelu() {
    _classCallCheck(this, Gelu);
    return _super14.apply(this, arguments);
  }
  _createClass(Gelu, [{
    key: "apply",
    value:
    /**
     * Calculates the exact (erf-based) GELU activation:
     *   GELU(x) = x * PHI(x),  PHI(x) = 0.5 * (1 + erf(x / sqrt(2))).
     *
     * @param x Tensor.
     * @returns a Tensor of the same shape as x
     */
    function apply(x) {
      // One tidy() scope is enough to dispose every intermediate tensor;
      // nesting a second tidy() inside it added nothing.
      return tidy(function () {
        var sqrtTwo = Math.sqrt(2);
        // PHI(x): standard normal CDF expressed through erf.
        var cdf = mul(0.5, add$3(1, erf$2(div$1(x, sqrtTwo))));
        return mul(x, cdf);
      });
    }
  }]);
  return Gelu;
}(Activation$1);
/** @nocollapse */
Gelu.className = 'gelu';
registerClass(Gelu);
70772 /**
70773 * GeluNew activation function
70774 */
var GeluNew = /*#__PURE__*/function (Base) {
  _inherits(GeluNew, Base);
  var superCtor = _createSuper(GeluNew);
  function GeluNew() {
    _classCallCheck(this, GeluNew);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /**
     * tanh approximation of GELU:
     *   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
     *
     * @param x Tensor.
     * @returns a Tensor of the same shape as x
     */
    value: function apply(x) {
      return tidy(function () {
        var scale = sqrt$2(div$1(2, Math.PI));
        var cubicTerm = mul(0.044715, pow$3(x, 3));
        var inner = mul(scale, add$3(x, cubicTerm));
        var gate = add$3(1, tanh$2(inner));
        return mul(0.5, mul(x, gate));
      });
    }
  }];
  _createClass(GeluNew, methods);
  return GeluNew;
}(Activation$1);
/** @nocollapse */
GeluNew.className = 'gelu_new';
registerClass(GeluNew);
70802 /**
70803 * Mish activation function
70804 */
var Mish = /*#__PURE__*/function (Base) {
  _inherits(Mish, Base);
  var superCtor = _createSuper(Mish);
  function Mish() {
    _classCallCheck(this, Mish);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /**
     * Mish activation: x * tanh(softplus(x)).
     *
     * @param x Tensor.
     * @returns a Tensor of the same shape as x
     */
    value: function apply(x) {
      return tidy(function () {
        var smooth = softplus$2(x);
        return mul(x, tanh$2(smooth));
      });
    }
  }];
  _createClass(Mish, methods);
  return Mish;
}(Activation$1);
/** @nocollapse */
Mish.className = 'mish';
registerClass(Mish);
70832 /**
70833 * Swish activation function
70834 */
var Swish = /*#__PURE__*/function (Base) {
  _inherits(Swish, Base);
  var superCtor = _createSuper(Swish);
  function Swish() {
    _classCallCheck(this, Swish);
    return superCtor.apply(this, arguments);
  }
  var methods = [{
    key: "apply",
    /**
     * Swish activation: sigmoid(alpha * x) * x.
     *
     * @param x Tensor.
     * @param alpha Scaling factor inside the sigmoid (defaults to 1).
     * @returns a Tensor of the same shape as x
     */
    value: function apply(x) {
      var alpha = 1;
      if (arguments.length > 1 && arguments[1] !== undefined) {
        alpha = arguments[1];
      }
      return tidy(function () {
        var scaled = mul(x, alpha);
        return mul(sigmoid$2(scaled), x);
      });
    }
  }];
  _createClass(Swish, methods);
  return Swish;
}(Activation$1);
/** @nocollapse */
Swish.className = 'swish';
registerClass(Swish);
/** Serializes an activation to its registered class name. */
function serializeActivation(activation) {
  var name = activation.getClassName();
  return name;
}
/**
 * Deserializes an activation config through `deserializeKerasObject`, using
 * the global class-name map. `customObjects` defaults to `{}` when omitted or
 * undefined.
 */
function deserializeActivation(config) {
  var customObjects = {};
  if (arguments.length > 1 && arguments[1] !== undefined) {
    customObjects = arguments[1];
  }
  var classNameMap = SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, classNameMap, customObjects, 'activation');
}
/**
 * Resolves an activation identifier.
 * - An `Activation` instance is returned as-is.
 * - null/undefined maps to the 'linear' activation.
 * - A string is treated as a class name with an empty config.
 * - Anything else is handed to `deserializeActivation` unchanged.
 */
function getActivation(identifier) {
  if (identifier instanceof Activation$1) {
    return identifier;
  }
  if (identifier == null || typeof identifier === 'string') {
    var config = {};
    config['className'] = identifier == null ? 'linear' : identifier;
    config['config'] = {};
    return deserializeActivation(config);
  }
  return deserializeActivation(identifier);
}
70889
/** Throws unless `args` is null/undefined or an object (arrays included). */
function assertObjectArgs(args) {
  if (args == null || _typeof(args) === 'object') {
    return;
  }
  throw new Error("Argument to L1L2 regularizer's constructor is expected to be an " + "object, but received: ".concat(args));
}
70895 /**
70896 * Regularizer base class.
70897 */
var Regularizer = /*#__PURE__*/function (Base) {
  _inherits(Regularizer, Base);
  var superCtor = _createSuper(Regularizer);
  // Abstract base; construction simply delegates to the Serializable parent.
  function Regularizer() {
    _classCallCheck(this, Regularizer);
    return superCtor.apply(this, arguments);
  }
  return _createClass(Regularizer);
}(Serializable);
var L1L2 = /*#__PURE__*/function (Base) {
  _inherits(L1L2, Base);
  var superCtor = _createSuper(L1L2);
  /**
   * Combined L1/L2 weight regularizer.
   *
   * @param args Optional `{l1, l2}`. A null/undefined coefficient defaults to
   *   0.01; a coefficient of exactly 0 disables that term.
   */
  function L1L2(args) {
    var self;
    _classCallCheck(this, L1L2);
    self = superCtor.call(this);
    assertObjectArgs(args);
    var hasArgs = args != null;
    self.l1 = hasArgs && args.l1 != null ? args.l1 : 0.01;
    self.l2 = hasArgs && args.l2 != null ? args.l2 : 0.01;
    self.hasL1 = self.l1 !== 0;
    self.hasL2 = self.l2 !== 0;
    return self;
  }
  _createClass(L1L2, [{
    key: "apply",
    /**
     * Computes the scalar penalty l1 * sum(|x|) + l2 * sum(x^2), skipping
     * disabled terms. Porting note: renamed from __call__.
     *
     * @param x Variable of which to calculate the regularization score.
     */
    value: function apply(x) {
      var self = this;
      return tidy(function () {
        var penalty = zeros$2([1]);
        if (self.hasL1) {
          var l1Term = sum$3(mul(self.l1, abs$2(x)));
          penalty = add$3(penalty, l1Term);
        }
        if (self.hasL2) {
          var l2Term = sum$3(mul(self.l2, square$1(x)));
          penalty = add$3(penalty, l2Term);
        }
        // Collapse the [1]-shaped accumulator to a scalar.
        return reshape$3(penalty, []);
      });
    }
  }, {
    key: "getConfig",
    /** Returns the coefficients as `{l1, l2}`. */
    value: function getConfig() {
      var config = {};
      config['l1'] = this.l1;
      config['l2'] = this.l2;
      return config;
    }
  }], [{
    /** @nocollapse */
    key: "fromConfig",
    /** Rebuilds an instance of `cls` from a `{l1, l2}` config. */
    value: function fromConfig(cls, config) {
      var l1 = config['l1'];
      var l2 = config['l2'];
      return new cls({
        l1: l1,
        l2: l2
      });
    }
  }]);
  return L1L2;
}(Regularizer);
/** @nocollapse */
L1L2.className = 'L1L2';
registerClass(L1L2);
/**
 * Builds an L1-only regularizer: the L2 coefficient is fixed to 0.
 *
 * @param args Optional `{l1}` object. A missing object yields `l1: null`, so
 *     the L1L2 constructor applies its default.
 * @returns A new L1L2 instance.
 */
function l1$1(args) {
  assertObjectArgs(args);
  var coefficient = args == null ? null : args.l1;
  return new L1L2({
    l1: coefficient,
    l2: 0
  });
}
/**
 * Builds an L2-only regularizer: the L1 coefficient is fixed to 0.
 *
 * @param args Optional `{l2}` object. A missing object yields `l2: null`, so
 *     the L1L2 constructor applies its default.
 * @returns A new L1L2 instance.
 */
function l2$1(args) {
  assertObjectArgs(args);
  var coefficient = args == null ? null : args.l2;
  return new L1L2({
    l2: coefficient,
    l1: 0
  });
}
// Lookup from JavaScript-style regularizer identifiers (as accepted by
// getRegularizer) to the registered Keras class names.
var REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
  'l1l2': 'L1L2'
};
/**
 * Serializes a regularizer into its Keras config form.
 *
 * @param regularizer The regularizer to serialize. Callers in this file pass
 *     the value returned by getRegularizer, which may be null.
 * @returns Whatever serializeKerasObject produces for it.
 */
function serializeRegularizer(regularizer) {
  var serialized = serializeKerasObject(regularizer);
  return serialized;
}
/**
 * Reconstructs a regularizer from a serialized Keras config.
 *
 * @param config Serialized config object.
 * @param customObjects Optional map of user-supplied classes; an omitted or
 *     undefined value becomes `{}`.
 * @returns The deserialized regularizer.
 */
function deserializeRegularizer(config) {
  var customObjects = arguments[1];
  if (customObjects === undefined) {
    customObjects = {};
  }
  var classNameMap = SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, classNameMap, customObjects, 'regularizer');
}
/**
 * Resolves a regularizer identifier into a Regularizer instance.
 *
 * - null/undefined yields null (no regularization).
 * - A string is translated through REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP
 *   when it is an own key of that map (e.g. 'l1l2' -> 'L1L2'). Otherwise it is
 *   used verbatim as the class name. Either way it is deserialized with an
 *   empty config.
 * - A Regularizer instance is returned unchanged.
 * - Any other value is treated as a serialized config.
 *
 * @param identifier Regularizer name, instance, serialized config, or null.
 * @returns The resolved regularizer, or null.
 */
function getRegularizer(identifier) {
  if (identifier == null) {
    return null;
  }
  if (typeof identifier === 'string') {
    // Use an own-property check. The `in` operator also matches inherited
    // Object.prototype members ('toString', 'constructor', ...), which would
    // turn the class name into a function instead of a string.
    var isAlias = Object.prototype.hasOwnProperty.call(REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP, identifier);
    var className = isAlias ? REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
    var config = {
      className: className,
      config: {}
    };
    return deserializeRegularizer(config);
  } else if (identifier instanceof Regularizer) {
    return identifier;
  } else {
    return deserializeRegularizer(identifier);
  }
}
71005
/**
 * Rectified linear unit layer: computes relu(x). When `args.maxValue` is set,
 * the result is additionally clipped to [0, maxValue].
 *
 * The output shape equals the input shape, and the layer declares masking
 * support (supportsMasking = true).
 */
var ReLU = /*#__PURE__*/function (_Layer) {
  _inherits(ReLU, _Layer);
  var _super = _createSuper(ReLU);
  function ReLU(args) {
    var _this;
    _classCallCheck(this, ReLU);
    _this = _super.call(this, args == null ? {} : args);
    _this.supportsMasking = true;
    // Without args, maxValue stays undefined, which disables clipping in call().
    if (args != null) {
      _this.maxValue = args.maxValue;
    }
    return _this;
  }
  _createClass(ReLU, [{
    key: "call",
    value: function call(inputs, kwargs) {
      inputs = getExactlyOneTensor(inputs);
      var output = relu$2(inputs);
      if (this.maxValue != null) {
        // Cap the activation; the lower bound of 0 matches relu's range.
        output = clipByValue$2(output, 0, this.maxValue);
      }
      return output;
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        maxValue: this.maxValue
      };
      var baseConfig = _get(_getPrototypeOf(ReLU.prototype), "getConfig", this).call(this);
      // Base-class entries are merged in last, so they win on key collisions.
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return ReLU;
}(Layer);
/** @nocollapse */
ReLU.className = 'ReLU';
registerClass(ReLU);
/**
 * Leaky ReLU layer: applies leakyRelu$2(x, alpha). `alpha` defaults to
 * DEFAULT_ALPHA (0.3) when `args` or `args.alpha` is null/undefined.
 *
 * The output shape equals the input shape.
 */
var LeakyReLU = /*#__PURE__*/function (_Layer2) {
  _inherits(LeakyReLU, _Layer2);
  var _super2 = _createSuper(LeakyReLU);
  function LeakyReLU(args) {
    var _this2;
    _classCallCheck(this, LeakyReLU);
    _this2 = _super2.call(this, args == null ? {} : args);
    _this2.DEFAULT_ALPHA = 0.3;
    if (args == null) {
      args = {};
    }
    _this2.alpha = args.alpha == null ? _this2.DEFAULT_ALPHA : args.alpha;
    return _this2;
  }
  _createClass(LeakyReLU, [{
    key: "call",
    value: function call(inputs, kwargs) {
      // kwargs is not used by this layer.
      var x = getExactlyOneTensor(inputs);
      return leakyRelu$2(x, this.alpha);
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        alpha: this.alpha
      };
      var baseConfig = _get(_getPrototypeOf(LeakyReLU.prototype), "getConfig", this).call(this);
      // Base-class entries are merged in last, so they win on key collisions.
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return LeakyReLU;
}(Layer);
/** @nocollapse */
LeakyReLU.className = 'LeakyReLU';
registerClass(LeakyReLU);
/**
 * Parametric ReLU layer. It applies prelu$3 with a trainable `alpha` weight.
 *
 * Constructor args (all optional):
 * - alphaInitializer: initializer for alpha (default 'zeros').
 * - alphaRegularizer / alphaConstraint: optional regularizer and constraint
 *   for alpha.
 * - sharedAxes: a number or array of (1-based, non-batch) axes along which
 *   alpha is shared. Alpha gets size 1 along those axes.
 *
 * @throws ValueError if sharedAxes is neither null, a number, nor an array.
 */
var PReLU = /*#__PURE__*/function (_Layer3) {
  _inherits(PReLU, _Layer3);
  var _super3 = _createSuper(PReLU);
  function PReLU(args) {
    var _this3;
    _classCallCheck(this, PReLU);
    _this3 = _super3.call(this, args == null ? {} : args);
    _this3.DEFAULT_ALPHA_INITIALIZER = 'zeros';
    if (args == null) {
      args = {};
    }
    _this3.supportsMasking = true;
    _this3.alphaInitializer = getInitializer(args.alphaInitializer || _this3.DEFAULT_ALPHA_INITIALIZER);
    _this3.alphaRegularizer = getRegularizer(args.alphaRegularizer);
    _this3.alphaConstraint = getConstraint(args.alphaConstraint);
    // Normalize sharedAxes to null or an array of axis indices.
    if (args.sharedAxes == null) {
      _this3.sharedAxes = null;
    } else if (Array.isArray(args.sharedAxes)) {
      _this3.sharedAxes = args.sharedAxes;
    } else if (typeof args.sharedAxes === 'number') {
      _this3.sharedAxes = [args.sharedAxes];
    } else {
      throw new ValueError("Expected sharedAxes to be a number or an array of numbers, " + "but got ".concat(args.sharedAxes));
    }
    return _this3;
  }
  _createClass(PReLU, [{
    key: "build",
    value: function build(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      // One alpha per non-batch element. Shared axes get size 1 so alpha
      // broadcasts along them. sharedAxes are 1-based (axis 0 is the batch),
      // hence the `i - 1` offset into paramShape.
      var paramShape = inputShape.slice(1);
      if (this.sharedAxes != null) {
        var _iterator = _createForOfIteratorHelper(this.sharedAxes),
          _step;
        try {
          for (_iterator.s(); !(_step = _iterator.n()).done;) {
            var i = _step.value;
            paramShape[i - 1] = 1;
          }
        } catch (err) {
          _iterator.e(err);
        } finally {
          _iterator.f();
        }
      }
      this.alpha = this.addWeight('alpha', paramShape, 'float32', this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint);
      // Set input spec. Only the non-shared axes are pinned to their build-time
      // sizes. Alpha has size 1 along shared axes and broadcasts there, so
      // inputs may legitimately vary in those dimensions (as in Keras' PReLU).
      var axes = {};
      if (this.sharedAxes != null) {
        for (var _i = 1; _i < inputShape.length; ++_i) {
          if (this.sharedAxes.indexOf(_i) === -1) {
            axes[_i] = inputShape[_i];
          }
        }
      }
      this.inputSpec = [new InputSpec({
        ndim: inputShape.length,
        axes: axes
      })];
      this.built = true;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      inputs = getExactlyOneTensor(inputs);
      return prelu$3(inputs, this.alpha.read());
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        alphaInitializer: serializeInitializer(this.alphaInitializer),
        alphaRegularizer: serializeRegularizer(this.alphaRegularizer),
        alphaConstraint: serializeConstraint(this.alphaConstraint),
        sharedAxes: this.sharedAxes
      };
      var baseConfig = _get(_getPrototypeOf(PReLU.prototype), "getConfig", this).call(this);
      // Base-class entries are merged in last, so they win on key collisions.
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return PReLU;
}(Layer);
/** @nocollapse */
PReLU.className = 'PReLU';
registerClass(PReLU);
/**
 * Exponential linear unit layer: computes elu$4(x).
 *
 * Only the default alpha of 1.0 is supported. call() never passes alpha to
 * elu$4, so any other value is rejected at construction time. The output
 * shape equals the input shape.
 *
 * @throws NotImplementedError if a non-default `args.alpha` is given.
 */
var ELU$3 = /*#__PURE__*/function (_Layer4) {
  _inherits(ELU, _Layer4);
  var _super4 = _createSuper(ELU);
  function ELU(args) {
    var _this4;
    _classCallCheck(this, ELU);
    _this4 = _super4.call(this, args == null ? {} : args);
    _this4.DEFAULT_ALPHA = 1.0;
    if (args == null) {
      args = {};
    }
    if (args.alpha != null && args.alpha !== _this4.DEFAULT_ALPHA) {
      throw new NotImplementedError("Non-default alpha value (".concat(args.alpha, ") is not supported by the ") + "ELU layer yet.");
    }
    _this4.alpha = args.alpha == null ? _this4.DEFAULT_ALPHA : args.alpha;
    return _this4;
  }
  _createClass(ELU, [{
    key: "call",
    value: function call(inputs, kwargs) {
      var x = getExactlyOneTensor(inputs);
      // alpha is deliberately not forwarded; only the default is supported.
      return elu$4(x);
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        alpha: this.alpha
      };
      var baseConfig = _get(_getPrototypeOf(ELU.prototype), "getConfig", this).call(this);
      // Base-class entries are merged in last, so they win on key collisions.
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return ELU;
}(Layer);
/** @nocollapse */
ELU$3.className = 'ELU';
registerClass(ELU$3);
/**
 * Thresholded ReLU layer: outputs x where x > theta and 0 elsewhere.
 * `theta` defaults to DEFAULT_THETA (1.0). The output shape equals the input
 * shape.
 */
var ThresholdedReLU = /*#__PURE__*/function (_Layer5) {
  _inherits(ThresholdedReLU, _Layer5);
  var _super5 = _createSuper(ThresholdedReLU);
  function ThresholdedReLU(args) {
    var _this5;
    _classCallCheck(this, ThresholdedReLU);
    _this5 = _super5.call(this, args == null ? {} : args);
    _this5.DEFAULT_THETA = 1.0;
    if (args == null) {
      args = {};
    }
    _this5.theta = args.theta == null ? _this5.DEFAULT_THETA : args.theta;
    return _this5;
  }
  _createClass(ThresholdedReLU, [{
    key: "call",
    value: function call(inputs, kwargs) {
      var x = getExactlyOneTensor(inputs);
      // Multiply by a 0/1 mask (cast to float32) marking where x exceeds theta.
      return mul(x, cast$3(greater$3(x, this.theta), 'float32'));
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        theta: this.theta
      };
      var baseConfig = _get(_getPrototypeOf(ThresholdedReLU.prototype), "getConfig", this).call(this);
      // Base-class entries are merged in last, so they win on key collisions.
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return ThresholdedReLU;
}(Layer);
/** @nocollapse */
ThresholdedReLU.className = 'ThresholdedReLU';
registerClass(ThresholdedReLU);
/**
 * Softmax activation layer.
 *
 * Normalizes along `axis`, which defaults to DEFAULT_AXIS (1). When `axis` is
 * an array with more than one entry, normalization is done jointly over all
 * listed axes as exp(x - logSumExp(x, axis, keepDims=true)). A single-entry
 * array behaves like the plain number.
 *
 * An optional `mask` in kwargs (1 = keep, 0 = masked) pushes masked
 * positions to a large negative score before normalization.
 *
 * NOTE(review): Keras' Softmax layer defaults to axis=-1, but DEFAULT_AXIS
 * here is 1. The two agree only for rank-2 inputs. Left unchanged to preserve
 * behavior; confirm whether this is intentional.
 */
var Softmax = /*#__PURE__*/function (_Layer6) {
  _inherits(Softmax, _Layer6);
  var _super6 = _createSuper(Softmax);
  function Softmax(args) {
    var _this6;
    _classCallCheck(this, Softmax);
    _this6 = _super6.call(this, args == null ? {} : args);
    _this6.DEFAULT_AXIS = 1.0;
    if (args == null) {
      args = {};
    }
    // Unbound method reference: presumably Softmax$1's apply does not depend
    // on `this`. TODO confirm.
    _this6.softmax = new Softmax$1().apply;
    _this6.axis = args.axis == null ? _this6.DEFAULT_AXIS : args.axis;
    return _this6;
  }
  _createClass(Softmax, [{
    key: "call",
    value: function call(inputs, kwargs) {
      var _this7 = this;
      // TODO(pforderique): Add tests for when `this.axis` is a number[].
      return tidy(function () {
        var x = getExactlyOneTensor(inputs);
        // kwargs may be omitted when call() is invoked directly (the other
        // activation layers tolerate that). Treat a missing kwargs as "no mask"
        // instead of throwing a TypeError.
        var mask = kwargs == null ? null : kwargs['mask'];
        if (mask != null) {
          // mask is 1.0 for positions to keep and 0.0 for masked positions, so
          // (1 - mask) * -1e9 is 0.0 where we attend and -1e9 where masked.
          var adder = mul(sub$2(ones$1(x.shape), cast$3(mask, x.dtype)), scalar(-1e9));
          // Adding this to the raw scores before the softmax effectively
          // removes the masked positions.
          x = add$3(x, adder);
        }
        if (_this7.axis instanceof Array) {
          if (_this7.axis.length > 1) {
            // Joint normalization over several axes.
            return exp$2(sub$2(x, logSumExp(x, _this7.axis, true)));
          } else {
            return _this7.softmax(x, _this7.axis[0]);
          }
        }
        return _this7.softmax(x, _this7.axis);
      });
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        axis: this.axis
      };
      var baseConfig = _get(_getPrototypeOf(Softmax.prototype), "getConfig", this).call(this);
      // Base-class entries are merged in last, so they win on key collisions.
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Softmax;
}(Layer);
/** @nocollapse */
Softmax.className = 'Softmax';
registerClass(Softmax);
71323
71324 /**
71325 * @license
71326 * Copyright 2018 Google LLC
71327 *
71328 * Use of this source code is governed by an MIT-style
71329 * license that can be found in the LICENSE file or at
71330 * https://opensource.org/licenses/MIT.
71331 * =============================================================================
71332 */
/**
 * Transforms a single number or an array of numbers into an array of numbers.
 *
 * @param value A number (repeated `n` times via pyListRepeat) or an array that
 *     must contain exactly `n` integers.
 * @param n The size of the tuple to be returned.
 * @param name Name of the parameter, used for generating error messages.
 * @returns An array of numbers. An array input is returned as-is, not copied.
 * @throws ValueError if an array input has the wrong length or contains a
 *     non-integer element.
 */
function normalizeArray(value, n, name) {
  if (typeof value === 'number') {
    return pyListRepeat(value, n);
  }
  if (value.length !== n) {
    throw new ValueError("The ".concat(name, " argument must be an integer or tuple of ").concat(n, " integers.") + " Received: ".concat(value.length, " elements."));
  }
  var idx = 0;
  while (idx < n) {
    var element = value[idx];
    if (!isInteger(element)) {
      throw new ValueError("The ".concat(name, " argument must be an integer or tuple of ").concat(n) + " integers. Received: ".concat(JSON.stringify(value), " including a") + " non-integer number ".concat(element));
    }
    idx += 1;
  }
  return value;
}
/**
 * Determines the output length of a convolution along one dimension.
 *
 * 'same' padding keeps the input length. Every other padding value uses the
 * 'valid' formula `inputLength - dilatedFilterSize + 1`. The result is then
 * divided by `stride`, rounding up (for integer inputs).
 *
 * @param inputLength Input length; null/undefined is returned unchanged.
 * @param filterSize Kernel size along this dimension.
 * @param padding Padding mode.
 * @param stride Stride along this dimension.
 * @param dilation Dilation rate (default 1).
 * @returns The output length, or `inputLength` itself when it is null/undefined.
 */
function convOutputLength(inputLength, filterSize, padding, stride) {
  var dilation = arguments[4] === undefined ? 1 : arguments[4];
  if (inputLength == null) {
    return inputLength;
  }
  // Span covered by the kernel once the dilation gaps are included.
  var dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1);
  var outputLength = padding === 'same' ? inputLength : inputLength - dilatedFilterSize + 1;
  return Math.floor((outputLength + stride - 1) / stride);
}
/**
 * Computes the length along one dimension for a transposed ("de")convolution
 * (interpretation inferred from the name and formula).
 *
 * - 'valid': dimSize * strideSize + max(kernelSize - strideSize, 0)
 * - 'same':  dimSize * strideSize
 *
 * max$2 is assumed to return the largest element of a number array; TODO
 * confirm.
 *
 * @param dimSize Input length; null/undefined yields null.
 * @param strideSize Stride along this dimension.
 * @param kernelSize Kernel size along this dimension.
 * @param padding 'valid' or 'same'.
 * @returns The computed length, or null.
 * @throws ValueError for any other padding mode.
 */
function deconvLength(dimSize, strideSize, kernelSize, padding) {
  if (dimSize == null) {
    return null;
  }
  var scaled = dimSize * strideSize;
  if (padding === 'same') {
    return scaled;
  }
  if (padding === 'valid') {
    return scaled + max$2([kernelSize - strideSize, 0]);
  }
  throw new ValueError("Unsupport padding mode: ".concat(padding, "."));
}
71392
/**
 * Brings an image tensor into NHWC layout ahead of a 2D convolution.
 *
 * For 'channelsFirst' the tensor is transposed with permutation [0, 2, 3, 1]
 * (NCHW -> NHWC). Any other data format returns `x` untouched.
 *
 * @param x Input image tensor.
 * @param dataFormat Data format of `x`; validated by checkDataFormat.
 * @returns The tensor in NHWC layout.
 */
function preprocessConv2DInput(x, dataFormat) {
  // TODO(cais): Cast type to float32 if not.
  return tidy(function () {
    checkDataFormat(dataFormat);
    var isChannelsFirst = dataFormat === 'channelsFirst';
    return isChannelsFirst ? transpose$2(x, [0, 2, 3, 1]) : x;
  });
}
/**
 * Brings a volumetric tensor into NDHWC layout ahead of a 3D convolution.
 *
 * For 'channelsFirst' the tensor is transposed with permutation
 * [0, 2, 3, 4, 1] (NCDHW -> NDHWC). Any other data format returns `x`
 * untouched.
 *
 * @param x Input tensor.
 * @param dataFormat Data format of `x`; validated by checkDataFormat.
 * @returns The tensor in NDHWC layout.
 */
function preprocessConv3DInput(x, dataFormat) {
  return tidy(function () {
    checkDataFormat(dataFormat);
    if (dataFormat !== 'channelsFirst') {
      return x;
    }
    return transpose$2(x, [0, 2, 3, 4, 1]);
  });
}
/**
 * 1D-convolution with bias added.
 *
 * Porting Note: This function does not exist in the Python Keras backend.
 * It is exactly the same as `conv1d`, except the added `bias`.
 *
 * @param x Input tensor, rank-3: `[batchSize, width, inChannels]`, or
 *     `[batchSize, inChannels, width]` when `dataFormat` is 'channelsFirst'.
 * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
 * @param bias Bias, rank-1, of shape `[outDepth]`, or null for no bias.
 * @param strides Stride (default 1).
 * @param padding Padding mode: 'valid' (default) or 'same'. 'causal' throws.
 * @param dataFormat Data format; defaults to imageDataFormat().
 * @param dilationRate Dilation rate (default 1).
 * @returns The result of the 1D convolution, in the same data format as `x`.
 * @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank.
 * @throws NotImplementedError, if `padding` is 'causal'.
 */
function conv1dWithBias(x, kernel, bias) {
  var strides = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 1;
  var padding = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 'valid';
  var dataFormat = arguments.length > 5 ? arguments[5] : undefined;
  var dilationRate = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : 1;
  return tidy(function () {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    // Check the ranks of x, kernel and bias.
    if (x.shape.length !== 3) {
      throw new ValueError("The input of a conv1dWithBias operation should be 3, but is " + "".concat(x.shape.length, " instead."));
    }
    if (kernel.shape.length !== 3) {
      throw new ValueError("The kernel for a conv1dWithBias operation should be 3, but is " + "".concat(kernel.shape.length, " instead"));
    }
    if (bias != null && bias.shape.length !== 1) {
      throw new ValueError("The bias for a conv1dWithBias operation should be 1, but is " + "".concat(bias.shape.length, " instead"));
    }
    // TODO(cais): Support CAUSAL padding mode.
    if (padding === 'causal') {
      throw new NotImplementedError('The support for CAUSAL padding mode in conv1dWithBias is not ' + 'implemented yet.');
    }
    var channelsFirst = dataFormat === 'channelsFirst';
    // The underlying op runs in NWC layout.
    var input = channelsFirst ? transpose$2(x, [0, 2, 1]) : x; // NCW -> NWC.
    var y = conv1d$2(input, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NWC', dilationRate);
    if (bias != null) {
      y = biasAdd(y, bias);
    }
    if (channelsFirst) {
      // Return the result in the caller's layout (NWC -> NCW), as the 2D/3D
      // helpers do and as Conv.computeOutputShape reports for channelsFirst.
      y = transpose$2(y, [0, 2, 1]);
    }
    return y;
  });
}
/**
 * 1D-convolution without bias. Delegates to conv1dWithBias with a null bias.
 *
 * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
 * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
 * @param strides Stride (default 1).
 * @param padding Padding mode (default 'valid').
 * @param dataFormat Data format (left undefined if omitted).
 * @param dilationRate Dilation rate (default 1).
 * @returns The result of the 1D convolution.
 * @throws ValueError, if `x` or `kernel` is not of the correct rank.
 */
function conv1d$1(x, kernel) {
  var strides = arguments[2] === undefined ? 1 : arguments[2];
  var padding = arguments[3] === undefined ? 'valid' : arguments[3];
  var dataFormat = arguments[4];
  var dilationRate = arguments[5] === undefined ? 1 : arguments[5];
  return tidy(function () {
    checkDataFormat(dataFormat);
    var noBias = null;
    return conv1dWithBias(x, kernel, noBias, strides, padding, dataFormat, dilationRate);
  });
}
/**
 * 2D Convolution without bias or fused activation. Delegates to
 * conv2dWithBiasActivation with a null bias.
 *
 * @param x Input tensor.
 * @param kernel Kernel of the convolution.
 * @param strides Strides array (default [1, 1]).
 * @param padding Padding mode (default 'valid').
 * @param dataFormat Data format; forwarded as given (may be undefined).
 * @param dilationRate Dilation rate array; forwarded as given.
 * @returns Result of the 2D convolution.
 */
function conv2d$2(x, kernel) {
  var strides = arguments[2] === undefined ? [1, 1] : arguments[2];
  var padding = arguments[3] === undefined ? 'valid' : arguments[3];
  var dataFormat = arguments[4];
  var dilationRate = arguments[5];
  return tidy(function () {
    checkDataFormat(dataFormat);
    var noBias = null;
    return conv2dWithBiasActivation(x, kernel, noBias, strides, padding, dataFormat, dilationRate);
  });
}
/**
 * 2D Convolution with an added bias and optional fused activation.
 * Note: This function does not exist in the Python Keras Backend. This function
 * is exactly the same as `conv2d`, except the added `bias` and `activation`.
 *
 * @param x Input tensor of rank 3 or 4.
 * @param kernel Kernel tensor of rank 3 or 4.
 * @param bias Optional bias tensor (null for none); passed to the fused op.
 * @param strides Strides array (default [1, 1]).
 * @param padding 'valid' (default) or 'same'. 'causal' throws.
 * @param dataFormat Data format; defaults to imageDataFormat().
 * @param dilationRate Dilation rate array.
 * @param activation Fused activation name, or null (default).
 * @returns The convolution result, in the same data format as `x`.
 * @throws ValueError if `x` or `kernel` is not of rank 3 or 4.
 * @throws NotImplementedError if `padding` is 'causal'.
 */
function conv2dWithBiasActivation(x, kernel, bias) {
  var strides = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : [1, 1];
  var padding = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 'valid';
  var dataFormat = arguments.length > 5 ? arguments[5] : undefined;
  var dilationRate = arguments.length > 6 ? arguments[6] : undefined;
  var activation = arguments.length > 7 && arguments[7] !== undefined ? arguments[7] : null;
  return tidy(function () {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    if (x.rank !== 3 && x.rank !== 4) {
      throw new ValueError("conv2dWithBiasActivation expects input to be of rank 3 or 4, " + "but received ".concat(x.rank, "."));
    }
    if (kernel.rank !== 3 && kernel.rank !== 4) {
      throw new ValueError("conv2dWithBiasActivation expects kernel to be of rank 3 or 4, " + "but received ".concat(kernel.rank, "."));
    }
    if (padding === 'causal') {
      throw new NotImplementedError('The support for CAUSAL padding mode in conv2dWithBiasActivation is not ' + 'implemented yet.');
    }
    // The fused op runs in NHWC; convert the input first.
    var y = preprocessConv2DInput(x, dataFormat);
    y = conv2d$3({
      x: y,
      filter: kernel,
      strides: strides,
      pad: padding === 'same' ? 'same' : 'valid',
      dilations: dilationRate,
      dataFormat: 'NHWC',
      bias: bias,
      activation: activation
    });
    if (dataFormat === 'channelsFirst') {
      // NHWC -> NCHW, back to the caller's layout.
      y = transpose$2(y, [0, 3, 1, 2]);
    }
    return y;
  });
}
/**
 * 3D Convolution without bias. Delegates to conv3dWithBias with a null bias.
 *
 * @param x Input tensor.
 * @param kernel Kernel of the convolution.
 * @param strides Strides array (default [1, 1, 1]).
 * @param padding Padding mode (default 'valid').
 * @param dataFormat Data format; forwarded as given (may be undefined).
 * @param dilationRate Dilation rate array; forwarded as given.
 * @returns Result of the 3D convolution.
 */
function conv3d$1(x, kernel) {
  var strides = arguments[2] === undefined ? [1, 1, 1] : arguments[2];
  var padding = arguments[3] === undefined ? 'valid' : arguments[3];
  var dataFormat = arguments[4];
  var dilationRate = arguments[5];
  return tidy(function () {
    checkDataFormat(dataFormat);
    var noBias = null;
    return conv3dWithBias(x, kernel, noBias, strides, padding, dataFormat, dilationRate);
  });
}
/**
 * 3D Convolution with an added bias.
 * Note: This function does not exist in the Python Keras Backend. This function
 * is exactly the same as `conv3d`, except the added `bias`.
 *
 * @param x Input tensor of rank 4 or 5.
 * @param kernel Kernel tensor of rank 4 or 5.
 * @param bias Optional bias tensor (null for none).
 * @param strides Strides array (default [1, 1, 1]).
 * @param padding 'valid' (default) or 'same'. 'causal' throws.
 * @param dataFormat Data format; defaults to imageDataFormat().
 * @param dilationRate Dilation rate array.
 * @returns The convolution result, in the same data format as `x`.
 * @throws ValueError if `x` or `kernel` is not of rank 4 or 5.
 * @throws NotImplementedError if `padding` is 'causal'.
 */
function conv3dWithBias(x, kernel, bias) {
  var strides = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : [1, 1, 1];
  var padding = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 'valid';
  var dataFormat = arguments.length > 5 ? arguments[5] : undefined;
  var dilationRate = arguments.length > 6 ? arguments[6] : undefined;
  return tidy(function () {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    if (x.rank !== 4 && x.rank !== 5) {
      throw new ValueError("conv3dWithBias expects input to be of rank 4 or 5, but received " + "".concat(x.rank, "."));
    }
    if (kernel.rank !== 4 && kernel.rank !== 5) {
      throw new ValueError("conv3dWithBias expects kernel to be of rank 4 or 5, but received " + "".concat(kernel.rank, "."));
    }
    if (padding === 'causal') {
      throw new NotImplementedError('The support for CAUSAL padding mode in conv3dWithBias is not ' + 'implemented yet.');
    }
    // The underlying op runs in NDHWC; convert the input first.
    var y = preprocessConv3DInput(x, dataFormat);
    y = conv3d$2(y, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NDHWC', dilationRate);
    if (bias != null) {
      y = biasAdd(y, bias);
    }
    if (dataFormat === 'channelsFirst') {
      // NDHWC -> NCDHW, back to the caller's layout.
      y = transpose$2(y, [0, 4, 1, 2, 3]);
    }
    return y;
  });
}
/**
 * Abstract convolution layer.
 *
 * Shared base for the convolution layers. The constructor normalizes and
 * validates the common arguments (kernelSize, strides, padding, dataFormat,
 * dilationRate, activation, bias settings, activity regularizer), and
 * getConfig serializes them. BaseConv itself creates no weights; `bias`
 * starts as null.
 *
 * @param rank Spatial rank of the convolution; only 1, 2 and 3 are supported.
 * @param args Layer arguments; must contain `kernelSize`.
 * @throws NotImplementedError for ranks other than 1, 2 or 3.
 * @throws ValueError for malformed kernelSize, strides or dilationRate.
 */
var BaseConv = /*#__PURE__*/function (_Layer) {
  _inherits(BaseConv, _Layer);
  var _super = _createSuper(BaseConv);
  function BaseConv(rank, args) {
    var _this;
    _classCallCheck(this, BaseConv);
    _this = _super.call(this, args);
    _this.bias = null;
    _this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
    _this.DEFAULT_BIAS_INITIALIZER = 'zeros';
    BaseConv.verifyArgs(args);
    _this.rank = rank;
    assertPositiveInteger(_this.rank, 'rank');
    if (_this.rank !== 1 && _this.rank !== 2 && _this.rank !== 3) {
      throw new NotImplementedError("Convolution layer for rank other than 1, 2, or 3 (".concat(_this.rank, ") is ") + "not implemented yet.");
    }
    // Expand scalar kernelSize/strides to one entry per spatial dimension.
    _this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize');
    _this.strides = normalizeArray(args.strides == null ? 1 : args.strides, rank, 'strides');
    _this.padding = args.padding == null ? 'valid' : args.padding;
    checkPaddingMode(_this.padding);
    _this.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
    checkDataFormat(_this.dataFormat);
    _this.activation = getActivation(args.activation);
    _this.useBias = args.useBias == null ? true : args.useBias;
    _this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER);
    _this.biasConstraint = getConstraint(args.biasConstraint);
    _this.biasRegularizer = getRegularizer(args.biasRegularizer);
    _this.activityRegularizer = getRegularizer(args.activityRegularizer);
    _this.dilationRate = normalizeArray(args.dilationRate == null ? 1 : args.dilationRate, rank, 'dilationRate');
    // NOTE(review): normalizeArray returns either pyListRepeat(value, rank) or
    // an array it has already checked to have length `rank`, so the checks
    // below look unreachable unless pyListRepeat yields something other than a
    // `rank`-length array. Confirm before removing.
    if (_this.rank === 1 && Array.isArray(_this.dilationRate) && _this.dilationRate.length !== 1) {
      throw new ValueError("dilationRate must be a number or an array of a single number " + "for 1D convolution, but received " + "".concat(JSON.stringify(_this.dilationRate)));
    } else if (_this.rank === 2) {
      if (typeof _this.dilationRate === 'number') {
        _this.dilationRate = [_this.dilationRate, _this.dilationRate];
      } else if (_this.dilationRate.length !== 2) {
        throw new ValueError("dilationRate must be a number or array of two numbers for 2D " + "convolution, but received ".concat(JSON.stringify(_this.dilationRate)));
      }
    } else if (_this.rank === 3) {
      if (typeof _this.dilationRate === 'number') {
        _this.dilationRate = [_this.dilationRate, _this.dilationRate, _this.dilationRate];
      } else if (_this.dilationRate.length !== 3) {
        throw new ValueError("dilationRate must be a number or array of three numbers for 3D " + "convolution, but received ".concat(JSON.stringify(_this.dilationRate)));
      }
    }
    return _this;
  }
  _createClass(BaseConv, [{
    key: "getConfig",
    value: function getConfig() {
      var config = {
        kernelSize: this.kernelSize,
        strides: this.strides,
        padding: this.padding,
        dataFormat: this.dataFormat,
        dilationRate: this.dilationRate,
        activation: serializeActivation(this.activation),
        useBias: this.useBias,
        biasInitializer: serializeInitializer(this.biasInitializer),
        biasRegularizer: serializeRegularizer(this.biasRegularizer),
        activityRegularizer: serializeRegularizer(this.activityRegularizer),
        biasConstraint: serializeConstraint(this.biasConstraint)
      };
      var baseConfig = _get(_getPrototypeOf(BaseConv.prototype), "getConfig", this).call(this);
      // Base-class entries are merged in last, so they win on key collisions.
      Object.assign(config, baseConfig);
      return config;
    }
  }], [{
    key: "verifyArgs",
    value: function verifyArgs(args) {
      // Check config.kernelSize type and shape.
      assert('kernelSize' in args, "required key 'kernelSize' not in config");
      if (typeof args.kernelSize !== 'number' && !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 3)) {
        throw new ValueError("BaseConv expects config.kernelSize to be number or number[] with " + "length 1, 2, or 3, but received ".concat(JSON.stringify(args.kernelSize), "."));
      }
    }
  }]);
  return BaseConv;
}(Layer);
71695 /**
71696 * Abstract nD convolution layer. Ancestor of convolution layers which reduce
71697 * across channels, i.e., Conv1D and Conv2D, but not DepthwiseConv2D.
71698 */
71699 var Conv = /*#__PURE__*/function (_BaseConv) {
71700 _inherits(Conv, _BaseConv);
71701 var _super2 = _createSuper(Conv);
71702 function Conv(rank, args) {
71703 var _this2;
71704 _classCallCheck(this, Conv);
71705 _this2 = _super2.call(this, rank, args);
71706 _this2.kernel = null;
71707 Conv.verifyArgs(args);
71708 _this2.filters = args.filters;
71709 assertPositiveInteger(_this2.filters, 'filters');
71710 _this2.kernelInitializer = getInitializer(args.kernelInitializer || _this2.DEFAULT_KERNEL_INITIALIZER);
71711 _this2.kernelConstraint = getConstraint(args.kernelConstraint);
71712 _this2.kernelRegularizer = getRegularizer(args.kernelRegularizer);
71713 return _this2;
71714 }
71715 _createClass(Conv, [{
71716 key: "build",
71717 value: function build(inputShape) {
71718 inputShape = getExactlyOneShape(inputShape);
71719 var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
71720 if (inputShape[channelAxis] == null) {
71721 throw new ValueError("The channel dimension of the input should be defined. " + "Found ".concat(inputShape[channelAxis]));
71722 }
71723 var inputDim = inputShape[channelAxis];
71724 var kernelShape = this.kernelSize.concat([inputDim, this.filters]);
71725 this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
71726 if (this.useBias) {
71727 this.bias = this.addWeight('bias', [this.filters], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
71728 }
71729 this.inputSpec = [{
71730 ndim: this.rank + 2,
71731 axes: _defineProperty({}, channelAxis, inputDim)
71732 }];
71733 this.built = true;
71734 }
71735 }, {
71736 key: "call",
71737 value: function call(inputs, kwargs) {
71738 var _this3 = this;
71739 return tidy(function () {
71740 inputs = getExactlyOneTensor(inputs);
71741 var outputs;
71742 var biasValue = _this3.bias == null ? null : _this3.bias.read();
71743 var fusedActivationName = mapActivationToFusedKernel(_this3.activation.getClassName());
71744 if (fusedActivationName != null && _this3.rank === 2) {
71745 outputs = conv2dWithBiasActivation(inputs, _this3.kernel.read(), biasValue, _this3.strides, _this3.padding, _this3.dataFormat, _this3.dilationRate, fusedActivationName);
71746 } else {
71747 if (_this3.rank === 1) {
71748 outputs = conv1dWithBias(inputs, _this3.kernel.read(), biasValue, _this3.strides[0], _this3.padding, _this3.dataFormat, _this3.dilationRate[0]);
71749 } else if (_this3.rank === 2) {
71750 // TODO(cais): Move up to constructor.
71751 outputs = conv2dWithBiasActivation(inputs, _this3.kernel.read(), biasValue, _this3.strides, _this3.padding, _this3.dataFormat, _this3.dilationRate);
71752 } else if (_this3.rank === 3) {
71753 outputs = conv3dWithBias(inputs, _this3.kernel.read(), biasValue, _this3.strides, _this3.padding, _this3.dataFormat, _this3.dilationRate);
71754 } else {
71755 throw new NotImplementedError('convolutions greater than 3D are not implemented yet.');
71756 }
71757 if (_this3.activation != null) {
71758 outputs = _this3.activation.apply(outputs);
71759 }
71760 }
71761 return outputs;
71762 });
71763 }
71764 }, {
71765 key: "computeOutputShape",
71766 value: function computeOutputShape(inputShape) {
71767 inputShape = getExactlyOneShape(inputShape);
71768 var newSpace = [];
71769 var space = this.dataFormat === 'channelsLast' ? inputShape.slice(1, inputShape.length - 1) : inputShape.slice(2);
71770 for (var i = 0; i < space.length; ++i) {
71771 var newDim = convOutputLength(space[i], this.kernelSize[i], this.padding, this.strides[i], typeof this.dilationRate === 'number' ? this.dilationRate : this.dilationRate[i]);
71772 newSpace.push(newDim);
71773 }
71774 var outputShape = [inputShape[0]];
71775 if (this.dataFormat === 'channelsLast') {
71776 outputShape = outputShape.concat(newSpace);
71777 outputShape.push(this.filters);
71778 } else {
71779 outputShape.push(this.filters);
71780 outputShape = outputShape.concat(newSpace);
71781 }
71782 return outputShape;
71783 }
71784 }, {
71785 key: "getConfig",
71786 value: function getConfig() {
71787 var config = {
71788 filters: this.filters,
71789 kernelInitializer: serializeInitializer(this.kernelInitializer),
71790 kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
71791 kernelConstraint: serializeConstraint(this.kernelConstraint)
71792 };
71793 var baseConfig = _get(_getPrototypeOf(Conv.prototype), "getConfig", this).call(this);
71794 Object.assign(config, baseConfig);
71795 return config;
71796 }
71797 }], [{
71798 key: "verifyArgs",
71799 value: function verifyArgs(args) {
71800 // Check config.filters type, shape, and value.
71801 if (!('filters' in args) || typeof args.filters !== 'number' || args.filters < 1) {
71802 throw new ValueError("Convolution layer expected config.filters to be a 'number' > 0 " + "but got ".concat(JSON.stringify(args.filters)));
71803 }
71804 }
71805 }]);
71806 return Conv;
71807 }(BaseConv);
var Conv2D = /*#__PURE__*/function (ParentConv) {
  _inherits(Conv2D, ParentConv);
  var superCtor = _createSuper(Conv2D);

  /**
   * 2D convolution layer: a Conv of rank 2. `kernelSize` must be a number or
   * a number[] of length 1 or 2 (see verifyArgs).
   */
  function Conv2D(args) {
    _classCallCheck(this, Conv2D);
    var layer = superCtor.call(this, 2, args);
    Conv2D.verifyArgs(args);
    return layer;
  }
  _createClass(Conv2D, [{
    key: "getConfig",
    value: function getConfig() {
      var config = _get(_getPrototypeOf(Conv2D.prototype), "getConfig", this).call(this);
      // The rank is fixed by the class, so it is not serialized.
      delete config['rank'];
      return config;
    }
  }], [{
    key: "verifyArgs",
    value: function verifyArgs(args) {
      var kernelSize = args.kernelSize;
      if (typeof kernelSize === 'number' || checkArrayTypeAndLength(kernelSize, 'number', 1, 2)) {
        return;
      }
      throw new ValueError("Conv2D expects config.kernelSize to be number or number[] with " + "length 1 or 2, but received ".concat(JSON.stringify(args.kernelSize), "."));
    }
  }]);
  return Conv2D;
}(Conv);
/** @nocollapse */
Conv2D.className = 'Conv2D';
registerClass(Conv2D);
var Conv3D = /*#__PURE__*/function (ParentConv) {
  _inherits(Conv3D, ParentConv);
  var superCtor = _createSuper(Conv3D);

  /**
   * 3D convolution layer: a Conv of rank 3. `kernelSize` must be a number or
   * an array of length 1 or 3 (array element types are not checked here).
   */
  function Conv3D(args) {
    _classCallCheck(this, Conv3D);
    var layer = superCtor.call(this, 3, args);
    Conv3D.verifyArgs(args);
    return layer;
  }
  _createClass(Conv3D, [{
    key: "getConfig",
    value: function getConfig() {
      var config = _get(_getPrototypeOf(Conv3D.prototype), "getConfig", this).call(this);
      // The rank is fixed by the class, so it is not serialized.
      delete config['rank'];
      return config;
    }
  }], [{
    key: "verifyArgs",
    value: function verifyArgs(args) {
      var kernelSize = args.kernelSize;
      var isNumber = typeof kernelSize === 'number';
      var isAcceptedArray = Array.isArray(kernelSize) && (kernelSize.length === 1 || kernelSize.length === 3);
      if (!isNumber && !isAcceptedArray) {
        throw new ValueError("Conv3D expects config.kernelSize to be number or" + " [number, number, number], but received ".concat(JSON.stringify(args.kernelSize), "."));
      }
    }
  }]);
  return Conv3D;
}(Conv);
/** @nocollapse */
Conv3D.className = 'Conv3D';
registerClass(Conv3D);
var Conv2DTranspose = /*#__PURE__*/function (ParentConv2D) {
  _inherits(Conv2DTranspose, ParentConv2D);
  var superCtor = _createSuper(Conv2DTranspose);

  /**
   * Transposed 2D convolution layer. Expects rank-4 input and supports only
   * the 'same' and 'valid' padding modes.
   * @param args Layer config, forwarded to Conv2D.
   * @throws ValueError If `padding` is neither 'same' nor 'valid'.
   */
  function Conv2DTranspose(args) {
    _classCallCheck(this, Conv2DTranspose);
    var layer = superCtor.call(this, args);
    layer.inputSpec = [new InputSpec({ ndim: 4 })];
    var mode = layer.padding;
    if (mode !== 'same' && mode !== 'valid') {
      throw new ValueError("Conv2DTranspose currently supports only padding modes 'same' " + "and 'valid', but received padding mode ".concat(mode));
    }
    return layer;
  }
  _createClass(Conv2DTranspose, [{
    key: "build",
    value: function build(inputShape) {
      var shape = getExactlyOneShape(inputShape);
      if (shape.length !== 4) {
        throw new ValueError('Input should have rank 4; Received input shape: ' + JSON.stringify(shape));
      }
      var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : shape.length - 1;
      var inputDim = shape[channelAxis];
      if (inputDim == null) {
        throw new ValueError('The channel dimension of the inputs should be defined. ' + 'Found `None`.');
      }
      // Transposed kernels are kernelSize followed by [filters, inputDim].
      var kernelShape = this.kernelSize.concat([this.filters, inputDim]);
      this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
      if (this.useBias) {
        this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
      }
      this.inputSpec = [new InputSpec({
        ndim: 4,
        axes: _defineProperty({}, channelAxis, inputDim)
      })];
      this.built = true;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var layer = this;
      return tidy(function () {
        var x = getExactlyOneTensor(inputs);
        if (x.shape.length !== 4) {
          throw new ValueError("Conv2DTranspose.call() expects input tensor to be rank-4, but " + "received a tensor of rank-".concat(x.shape.length));
        }
        var shape = x.shape;
        var channelsFirst = layer.dataFormat === 'channelsFirst';
        var height = shape[channelsFirst ? 2 : 1];
        var width = shape[channelsFirst ? 3 : 2];
        // Infer the dynamic output spatial size.
        var outHeight = deconvLength(height, layer.strides[0], layer.kernelSize[0], layer.padding);
        var outWidth = deconvLength(width, layer.strides[1], layer.kernelSize[1], layer.padding);
        // Porting note: the tfjs-core conv2dTranspose op is always given a
        // channels-last output shape; non-channelsLast data is transposed
        // to NHWC before the op and back to NCHW afterwards.
        var outputShape = [shape[0], outHeight, outWidth, layer.filters];
        var needsTranspose = layer.dataFormat !== 'channelsLast';
        if (needsTranspose) {
          x = transpose$2(x, [0, 2, 3, 1]);
        }
        var y = conv2dTranspose$1(x, layer.kernel.read(), outputShape, layer.strides, layer.padding);
        if (needsTranspose) {
          y = transpose$2(y, [0, 3, 1, 2]);
        }
        if (layer.bias != null) {
          y = biasAdd(y, layer.bias.read(), layer.dataFormat);
        }
        if (layer.activation != null) {
          y = layer.activation.apply(y);
        }
        return y;
      });
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      var out = getExactlyOneShape(inputShape).slice();
      var channelsFirst = this.dataFormat === 'channelsFirst';
      var cAxis = channelsFirst ? 1 : 3;
      var hAxis = channelsFirst ? 2 : 1;
      var wAxis = channelsFirst ? 3 : 2;
      out[cAxis] = this.filters;
      out[hAxis] = deconvLength(out[hAxis], this.strides[0], this.kernelSize[0], this.padding);
      out[wAxis] = deconvLength(out[wAxis], this.strides[1], this.kernelSize[1], this.padding);
      return out;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = _get(_getPrototypeOf(Conv2DTranspose.prototype), "getConfig", this).call(this);
      // call() never passes a dilation rate to conv2dTranspose$1, so omit it.
      delete config['dilationRate'];
      return config;
    }
  }]);
  return Conv2DTranspose;
}(Conv2D);
/** @nocollapse */
Conv2DTranspose.className = 'Conv2DTranspose';
registerClass(Conv2DTranspose);
var Conv3DTranspose = /*#__PURE__*/function (_Conv3D) {
  _inherits(Conv3DTranspose, _Conv3D);
  var _super6 = _createSuper(Conv3DTranspose);

  /**
   * Transposed 3D convolution layer. Expects rank-5 input and supports only
   * the 'same' and 'valid' padding modes.
   * @param args Layer config, forwarded to Conv3D (which validates kernelSize).
   * @throws ValueError If `padding` is neither 'same' nor 'valid'.
   */
  function Conv3DTranspose(args) {
    var _this8;
    _classCallCheck(this, Conv3DTranspose);
    _this8 = _super6.call(this, args);
    _this8.inputSpec = [new InputSpec({
      ndim: 5
    })];
    if (_this8.padding !== 'same' && _this8.padding !== 'valid') {
      throw new ValueError("Conv3DTranspose currently supports only padding modes 'same' " + "and 'valid', but received padding mode ".concat(_this8.padding));
    }
    return _this8;
  }
  _createClass(Conv3DTranspose, [{
    key: "build",
    /**
     * Creates the kernel (kernelSize followed by [filters, inputDim]) and,
     * when `useBias` is set, a bias of length `filters`.
     * @throws ValueError If the input is not rank 5 or its channel dimension
     *   is undefined.
     */
    value: function build(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      if (inputShape.length !== 5) {
        throw new ValueError('Input should have rank 5; Received input shape: ' + JSON.stringify(inputShape));
      }
      var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
      if (inputShape[channelAxis] == null) {
        throw new ValueError('The channel dimension of the inputs should be defined. ' + 'Found `None`.');
      }
      var inputDim = inputShape[channelAxis];
      var kernelShape = this.kernelSize.concat([this.filters, inputDim]);
      this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
      if (this.useBias) {
        this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
      }
      // Set input spec: rank 5 with the channel size seen at build time.
      this.inputSpec = [new InputSpec({
        ndim: 5,
        axes: _defineProperty({}, channelAxis, inputDim)
      })];
      this.built = true;
    }
  }, {
    key: "call",
    /**
     * Applies the transposed convolution, then bias and activation.
     * @throws ValueError If the input tensor is not rank 5.
     */
    value: function call(inputs, kwargs) {
      var _this9 = this;
      return tidy(function () {
        var input = getExactlyOneTensor(inputs);
        if (input.shape.length !== 5) {
          throw new ValueError("Conv3DTranspose.call() expects input tensor to be rank-5, but " + "received a tensor of rank-".concat(input.shape.length));
        }
        var inputShape = input.shape;
        var batchSize = inputShape[0];
        var hAxis;
        var wAxis;
        var dAxis;
        if (_this9.dataFormat === 'channelsFirst') {
          dAxis = 2;
          hAxis = 3;
          wAxis = 4;
        } else {
          dAxis = 1;
          hAxis = 2;
          wAxis = 3;
        }
        var depth = inputShape[dAxis];
        var height = inputShape[hAxis];
        var width = inputShape[wAxis];
        var kernelD = _this9.kernelSize[0];
        var kernelH = _this9.kernelSize[1];
        var kernelW = _this9.kernelSize[2];
        var strideD = _this9.strides[0];
        var strideH = _this9.strides[1];
        var strideW = _this9.strides[2];
        // Infer the dynamic output shape.
        var outDepth = deconvLength(depth, strideD, kernelD, _this9.padding);
        var outHeight = deconvLength(height, strideH, kernelH, _this9.padding);
        var outWidth = deconvLength(width, strideW, kernelW, _this9.padding);
        // As with Conv2DTranspose, the op is always given a channels-last
        // output shape; non-channelsLast data is transposed around it.
        var outputShape = [batchSize, outDepth, outHeight, outWidth, _this9.filters];
        if (_this9.dataFormat !== 'channelsLast') {
          input = transpose$2(input, [0, 2, 3, 4, 1]);
        }
        var outputs = conv3dTranspose$1(input, _this9.kernel.read(), outputShape, _this9.strides, _this9.padding);
        if (_this9.dataFormat !== 'channelsLast') {
          outputs = transpose$2(outputs, [0, 4, 1, 2, 3]);
        }
        // Loose `!= null` (as in Conv2DTranspose) so an unset (undefined)
        // bias or activation is skipped rather than dereferenced.
        if (_this9.bias != null) {
          outputs = biasAdd(outputs, _this9.bias.read(), _this9.dataFormat);
        }
        if (_this9.activation != null) {
          outputs = _this9.activation.apply(outputs);
        }
        return outputs;
      });
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      var outputShape = inputShape.slice();
      var channelAxis;
      var depthAxis;
      var heightAxis;
      var widthAxis;
      if (this.dataFormat === 'channelsFirst') {
        channelAxis = 1;
        depthAxis = 2;
        heightAxis = 3;
        widthAxis = 4;
      } else {
        channelAxis = 4;
        depthAxis = 1;
        heightAxis = 2;
        widthAxis = 3;
      }
      var kernelD = this.kernelSize[0];
      var kernelH = this.kernelSize[1];
      var kernelW = this.kernelSize[2];
      var strideD = this.strides[0];
      var strideH = this.strides[1];
      var strideW = this.strides[2];
      outputShape[channelAxis] = this.filters;
      outputShape[depthAxis] = deconvLength(outputShape[depthAxis], strideD, kernelD, this.padding);
      outputShape[heightAxis] = deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
      outputShape[widthAxis] = deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
      return outputShape;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = _get(_getPrototypeOf(Conv3DTranspose.prototype), "getConfig", this).call(this);
      // call() never passes a dilation rate to conv3dTranspose$1, so omit it.
      delete config['dilationRate'];
      return config;
    }
  }]);
  return Conv3DTranspose;
}(Conv3D);
/** @nocollapse */
Conv3DTranspose.className = 'Conv3DTranspose';
registerClass(Conv3DTranspose);
var SeparableConv = /*#__PURE__*/function (_Conv3) {
  _inherits(SeparableConv, _Conv3);
  var _super7 = _createSuper(SeparableConv);

  /**
   * Base class for separable convolutions: a depthwise convolution followed
   * by a pointwise convolution. Only rank 2 is implemented in `call`.
   *
   * @param rank Number of spatial dimensions.
   * @param config Layer config. `filters` is required; the `kernel*` fields
   *   are rejected in favour of `depthwise*` / `pointwise*` fields.
   * @throws ValueError On missing `filters`, any `kernel*` field, or a padding
   *   mode other than 'same' / 'valid'.
   */
  function SeparableConv(rank, config) {
    var _this10;
    _classCallCheck(this, SeparableConv);
    _this10 = _super7.call(this, rank, config);
    _this10.DEFAULT_DEPTHWISE_INITIALIZER = 'glorotUniform';
    _this10.DEFAULT_POINTWISE_INITIALIZER = 'glorotUniform';
    _this10.depthwiseKernel = null;
    _this10.pointwiseKernel = null;
    if (config.filters == null) {
      throw new ValueError('The `filters` configuration field is required by SeparableConv, ' + 'but is unspecified.');
    }
    if (config.kernelInitializer != null || config.kernelRegularizer != null || config.kernelConstraint != null) {
      throw new ValueError('Fields kernelInitializer, kernelRegularizer and kernelConstraint ' + 'are invalid for SeparableConv2D. Use depthwiseInitializer, ' + 'depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, ' + 'pointwiseRegularizer and pointwiseConstraint instead.');
    }
    if (config.padding != null && config.padding !== 'same' && config.padding !== 'valid') {
      throw new ValueError("SeparableConv".concat(_this10.rank, "D supports only padding modes: ") + "'same' and 'valid', but received ".concat(JSON.stringify(config.padding)));
    }
    _this10.depthMultiplier = config.depthMultiplier == null ? 1 : config.depthMultiplier;
    _this10.depthwiseInitializer = getInitializer(config.depthwiseInitializer || _this10.DEFAULT_DEPTHWISE_INITIALIZER);
    _this10.depthwiseRegularizer = getRegularizer(config.depthwiseRegularizer);
    _this10.depthwiseConstraint = getConstraint(config.depthwiseConstraint);
    // Each kernel reads its own initializer field, so a user-supplied
    // `pointwiseInitializer` is honoured (and round-trips via getConfig).
    _this10.pointwiseInitializer = getInitializer(config.pointwiseInitializer || _this10.DEFAULT_POINTWISE_INITIALIZER);
    _this10.pointwiseRegularizer = getRegularizer(config.pointwiseRegularizer);
    _this10.pointwiseConstraint = getConstraint(config.pointwiseConstraint);
    return _this10;
  }
  _createClass(SeparableConv, [{
    key: "build",
    /**
     * Creates the depthwise kernel (kernelSize + [inputDim, depthMultiplier]),
     * the pointwise kernel (rank ones + [inputDim * depthMultiplier, filters])
     * and, when `useBias` is set, a bias of length `filters`.
     * @throws ValueError If the input rank is too small or the channel
     *   dimension is undefined/negative.
     */
    value: function build(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      if (inputShape.length < this.rank + 2) {
        throw new ValueError("Inputs to SeparableConv".concat(this.rank, "D should have rank ") + "".concat(this.rank + 2, ", but received input shape: ") + "".concat(JSON.stringify(inputShape)));
      }
      var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
      if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
        throw new ValueError("The channel dimension of the inputs should be defined, " + "but found ".concat(JSON.stringify(inputShape[channelAxis])));
      }
      var inputDim = inputShape[channelAxis];
      var depthwiseKernelShape = this.kernelSize.concat([inputDim, this.depthMultiplier]);
      var pointwiseKernelShape = [];
      for (var i = 0; i < this.rank; ++i) {
        pointwiseKernelShape.push(1);
      }
      pointwiseKernelShape.push(inputDim * this.depthMultiplier, this.filters);
      var trainable = true;
      this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, 'float32', this.depthwiseInitializer, this.depthwiseRegularizer, trainable, this.depthwiseConstraint);
      this.pointwiseKernel = this.addWeight('pointwise_kernel', pointwiseKernelShape, 'float32', this.pointwiseInitializer, this.pointwiseRegularizer, trainable, this.pointwiseConstraint);
      if (this.useBias) {
        this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, trainable, this.biasConstraint);
      } else {
        this.bias = null;
      }
      this.inputSpec = [new InputSpec({
        ndim: this.rank + 2,
        axes: _defineProperty({}, channelAxis, inputDim)
      })];
      this.built = true;
    }
  }, {
    key: "call",
    /**
     * Runs the separable convolution in NHWC layout, adds the bias and
     * applies the activation, then restores NCHW for 'channelsFirst'.
     * @throws NotImplementedError For any rank other than 2.
     */
    value: function call(inputs, kwargs) {
      var _this11 = this;
      return tidy(function () {
        inputs = getExactlyOneTensor(inputs);
        var output;
        if (_this11.rank === 1) {
          throw new NotImplementedError('1D separable convolution is not implemented yet.');
        } else if (_this11.rank === 2) {
          if (_this11.dataFormat === 'channelsFirst') {
            inputs = transpose$2(inputs, [0, 2, 3, 1]); // NCHW -> NHWC.
          }

          output = separableConv2d$1(inputs, _this11.depthwiseKernel.read(), _this11.pointwiseKernel.read(), _this11.strides, _this11.padding, _this11.dilationRate, 'NHWC');
        } else {
          // Without this, `output` would stay undefined and fail obscurely below.
          throw new NotImplementedError("".concat(_this11.rank, "D separable convolution is not implemented yet."));
        }
        // `output` is NHWC here whatever this.dataFormat is, so the bias must
        // be added along the last axis; passing 'channelsFirst' would
        // broadcast it along the height axis instead.
        if (_this11.useBias) {
          output = biasAdd(output, _this11.bias.read(), 'channelsLast');
        }
        if (_this11.activation != null) {
          output = _this11.activation.apply(output);
        }
        if (_this11.dataFormat === 'channelsFirst') {
          output = transpose$2(output, [0, 3, 1, 2]); // NHWC -> NCHW.
        }

        return output;
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = _get(_getPrototypeOf(SeparableConv.prototype), "getConfig", this).call(this);
      // Replace the single-kernel fields with depthwise/pointwise ones.
      delete config['rank'];
      delete config['kernelInitializer'];
      delete config['kernelRegularizer'];
      delete config['kernelConstraint'];
      config['depthwiseInitializer'] = serializeInitializer(this.depthwiseInitializer);
      config['pointwiseInitializer'] = serializeInitializer(this.pointwiseInitializer);
      config['depthwiseRegularizer'] = serializeRegularizer(this.depthwiseRegularizer);
      config['pointwiseRegularizer'] = serializeRegularizer(this.pointwiseRegularizer);
      config['depthwiseConstraint'] = serializeConstraint(this.depthwiseConstraint);
      config['pointwiseConstraint'] = serializeConstraint(this.pointwiseConstraint);
      return config;
    }
  }]);
  return SeparableConv;
}(Conv);
/** @nocollapse */
SeparableConv.className = 'SeparableConv';
var SeparableConv2D = /*#__PURE__*/function (ParentSeparableConv) {
  _inherits(SeparableConv2D, ParentSeparableConv);
  var superCtor = _createSuper(SeparableConv2D);

  /** 2D separable convolution: SeparableConv with the rank fixed at 2. */
  function SeparableConv2D(args) {
    _classCallCheck(this, SeparableConv2D);
    return superCtor.call(this, 2, args);
  }
  _createClass(SeparableConv2D);
  return SeparableConv2D;
}(SeparableConv);
/** @nocollapse */
SeparableConv2D.className = 'SeparableConv2D';
registerClass(SeparableConv2D);
var Conv1D = /*#__PURE__*/function (ParentConv) {
  _inherits(Conv1D, ParentConv);
  var superCtor = _createSuper(Conv1D);

  /**
   * 1D convolution layer: a Conv of rank 1 over rank-3 input. `kernelSize`
   * must be a number or a number[] of length 1 (see verifyArgs).
   */
  function Conv1D(args) {
    _classCallCheck(this, Conv1D);
    var layer = superCtor.call(this, 1, args);
    Conv1D.verifyArgs(args);
    layer.inputSpec = [{ ndim: 3 }];
    return layer;
  }
  _createClass(Conv1D, [{
    key: "getConfig",
    value: function getConfig() {
      var config = _get(_getPrototypeOf(Conv1D.prototype), "getConfig", this).call(this);
      // Neither the fixed rank nor dataFormat is serialized for Conv1D.
      delete config['rank'];
      delete config['dataFormat'];
      return config;
    }
  }], [{
    key: "verifyArgs",
    value: function verifyArgs(args) {
      var kernelSize = args.kernelSize;
      var valid = typeof kernelSize === 'number' || checkArrayTypeAndLength(kernelSize, 'number', 1, 1);
      if (!valid) {
        throw new ValueError("Conv1D expects config.kernelSize to be number or number[] with " + "length 1, but received ".concat(JSON.stringify(args.kernelSize), "."));
      }
    }
  }]);
  return Conv1D;
}(Conv);
/** @nocollapse */
Conv1D.className = 'Conv1D';
registerClass(Conv1D);
var Cropping2D = /*#__PURE__*/function (BaseLayer) {
  _inherits(Cropping2D, BaseLayer);
  var superCtor = _createSuper(Cropping2D);

  /**
   * Crops the two spatial dimensions of a rank-4 input.
   * `args.cropping` may be a single number (same crop on every edge), a
   * [rows, cols] pair (symmetric per dimension), or
   * [[top, bottom], [left, right]], which is stored as-is.
   * `args.dataFormat` defaults to 'channelsLast' only when it is undefined.
   */
  function Cropping2D(args) {
    _classCallCheck(this, Cropping2D);
    var layer = superCtor.call(this, args);
    var crop = args.cropping;
    if (typeof crop === 'number') {
      layer.cropping = [[crop, crop], [crop, crop]];
    } else if (typeof crop[0] === 'number') {
      layer.cropping = [[crop[0], crop[0]], [crop[1], crop[1]]];
    } else {
      layer.cropping = crop;
    }
    layer.dataFormat = args.dataFormat !== undefined ? args.dataFormat : 'channelsLast';
    layer.inputSpec = [{ ndim: 4 }];
    return layer;
  }
  _createClass(Cropping2D, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      var rows = this.cropping[0];
      var cols = this.cropping[1];
      if (this.dataFormat === 'channelsFirst') {
        return [inputShape[0], inputShape[1], inputShape[2] - rows[0] - rows[1], inputShape[3] - cols[0] - cols[1]];
      }
      return [inputShape[0], inputShape[1] - rows[0] - rows[1], inputShape[2] - cols[0] - cols[1], inputShape[3]];
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var layer = this;
      return tidy(function () {
        var x = getExactlyOneTensor(inputs);
        var rows = layer.cropping[0];
        var cols = layer.cropping[1];
        var channelsLast = layer.dataFormat === 'channelsLast';
        // Shape index of the height dim, and the matching sliceAlongAxis
        // axis argument; width is the next index/axis in both layouts.
        var heightDim = channelsLast ? 1 : 2;
        var heightAxis = channelsLast ? 2 : 3;
        var heightCropped = sliceAlongAxis(x, rows[0], x.shape[heightDim] - rows[0] - rows[1], heightAxis);
        return sliceAlongAxis(heightCropped, cols[0], x.shape[heightDim + 1] - cols[1] - cols[0], heightAxis + 1);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var own = {
        cropping: this.cropping,
        dataFormat: this.dataFormat
      };
      return Object.assign(own, _get(_getPrototypeOf(Cropping2D.prototype), "getConfig", this).call(this));
    }
  }]);
  return Cropping2D;
}(Layer);
/** @nocollapse */
Cropping2D.className = 'Cropping2D';
registerClass(Cropping2D);
var UpSampling2D = /*#__PURE__*/function (BaseLayer) {
  _inherits(UpSampling2D, BaseLayer);
  var superCtor = _createSuper(UpSampling2D);

  /**
   * Upsamples the two spatial dims of a rank-4 input by the factors in
   * `size`, using nearest-neighbour or bilinear resizing.
   * Defaults: size [2, 2], dataFormat 'channelsLast', interpolation 'nearest'.
   */
  function UpSampling2D(args) {
    _classCallCheck(this, UpSampling2D);
    var layer = superCtor.call(this, args);
    layer.DEFAULT_SIZE = [2, 2];
    layer.inputSpec = [{ ndim: 4 }];
    layer.size = args.size != null ? args.size : layer.DEFAULT_SIZE;
    layer.dataFormat = args.dataFormat != null ? args.dataFormat : 'channelsLast';
    checkDataFormat(layer.dataFormat);
    layer.interpolation = args.interpolation != null ? args.interpolation : 'nearest';
    checkInterpolationFormat(layer.interpolation);
    return layer;
  }
  _createClass(UpSampling2D, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      var channelsFirst = this.dataFormat === 'channelsFirst';
      var hIdx = channelsFirst ? 2 : 1;
      var wIdx = channelsFirst ? 3 : 2;
      // Unknown (null) spatial dims stay unknown.
      var scaledH = inputShape[hIdx] == null ? null : this.size[0] * inputShape[hIdx];
      var scaledW = inputShape[wIdx] == null ? null : this.size[1] * inputShape[wIdx];
      return channelsFirst ? [inputShape[0], inputShape[1], scaledH, scaledW] : [inputShape[0], scaledH, scaledW, inputShape[3]];
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var layer = this;
      return tidy(function () {
        var x = getExactlyOneTensor(inputs);
        var shape = x.shape;
        var resize = function (t, h, w) {
          return layer.interpolation === 'nearest' ? image$1.resizeNearestNeighbor(t, [h, w]) : image$1.resizeBilinear(t, [h, w]);
        };
        if (layer.dataFormat !== 'channelsFirst') {
          return resize(x, layer.size[0] * shape[1], layer.size[1] * shape[2]);
        }
        // Resize in channels-last layout, then restore channels-first.
        var nhwc = transpose$2(x, [0, 2, 3, 1]);
        var resized = resize(nhwc, layer.size[0] * shape[2], layer.size[1] * shape[3]);
        return transpose$2(resized, [0, 3, 1, 2]);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var own = {
        size: this.size,
        dataFormat: this.dataFormat,
        interpolation: this.interpolation
      };
      return Object.assign(own, _get(_getPrototypeOf(UpSampling2D.prototype), "getConfig", this).call(this));
    }
  }]);
  return UpSampling2D;
}(Layer);
/** @nocollapse */
UpSampling2D.className = 'UpSampling2D';
registerClass(UpSampling2D);
72426
/**
 * 2D depthwise convolution over a rank-4 input.
 * @param x Input tensor; must be 4-D, laid out according to `dataFormat`.
 * @param depthwiseKernel Convolution kernel for the depthwise convolution;
 *   must be 4-D.
 * @param strides Strides (array of two integers). Defaults to [1, 1].
 * @param padding Padding mode. Anything other than 'same' is treated as
 *   'valid'. Defaults to 'valid'.
 * @param dataFormat Data format; when null/undefined, imageDataFormat() is
 *   used.
 * @param dilationRate Array of two integers, dilation rates for the
 *   depthwise convolution; passed through to the op unchanged.
 * @returns Output tensor, transposed back to NCHW when `dataFormat` is
 *   'channelsFirst'.
 * @throws ValueError If `x` or `depthwiseKernel` is not 4-D.
 */
function depthwiseConv2d$1(x, depthwiseKernel) {
  var strides = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : [1, 1];
  var padding = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : 'valid';
  var dataFormat = arguments.length > 4 ? arguments[4] : undefined;
  var dilationRate = arguments.length > 5 ? arguments[5] : undefined;
  return tidy(function () {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    // Validate ranks before preprocessing, so malformed inputs fail with the
    // ValueError below rather than inside preprocessConv2DInput (which
    // presumably converts to the NHWC layout the op is called with).
    if (x.rank !== 4) {
      throw new ValueError("Input for depthwiseConv2d is required to be 4-D, but is instead " + "".concat(x.rank, "-D"));
    }
    if (depthwiseKernel.rank !== 4) {
      throw new ValueError("depthwiseKernel is required to be 4-D, but is instead " + "".concat(depthwiseKernel.rank, "-D"));
    }
    var y = preprocessConv2DInput(x, dataFormat);
    y = depthwiseConv2d$3(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);
    if (dataFormat === 'channelsFirst') {
      y = transpose$2(y, [0, 3, 1, 2]);
    }
    return y;
  });
}
72463 var DepthwiseConv2D = /*#__PURE__*/function (_BaseConv) {
72464 _inherits(DepthwiseConv2D, _BaseConv);
72465 var _super = _createSuper(DepthwiseConv2D);
72466 function DepthwiseConv2D(args) {
72467 var _this;
72468 _classCallCheck(this, DepthwiseConv2D);
72469 _this = _super.call(this, 2, args);
72470 _this.depthwiseKernel = null;
72471 _this.depthMultiplier = args.depthMultiplier == null ? 1 : args.depthMultiplier;
72472 _this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || _this.DEFAULT_KERNEL_INITIALIZER);
72473 _this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
72474 _this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
72475 return _this;
72476 }
72477 _createClass(DepthwiseConv2D, [{
72478 key: "build",
72479 value: function build(inputShape) {
72480 inputShape = getExactlyOneShape(inputShape);
72481 if (inputShape.length < 4) {
72482 throw new ValueError("Inputs to DepthwiseConv2D should have rank 4. " + "Received input shape: ".concat(JSON.stringify(inputShape), "."));
72483 }
72484 var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;
72485 if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
72486 throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' + "be defined, but is not (".concat(inputShape[channelAxis], ")."));
72487 }
72488 var inputDim = inputShape[channelAxis];
72489 var depthwiseKernelShape = [this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier];
72490 this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
72491 if (this.useBias) {
72492 this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
72493 } else {
72494 this.bias = null;
72495 }
72496 this.built = true;
72497 }
72498 }, {
72499 key: "call",
72500 value: function call(inputs, kwargs) {
72501 var _this2 = this;
72502 return tidy(function () {
72503 inputs = getExactlyOneTensor(inputs);
72504 var outputs = depthwiseConv2d$1(inputs, _this2.depthwiseKernel.read(), _this2.strides, _this2.padding, _this2.dataFormat, null);
72505 // TODO(cais): Add support for dilation.
72506 if (_this2.useBias) {
72507 outputs = biasAdd(outputs, _this2.bias.read(), _this2.dataFormat);
72508 }
72509 if (_this2.activation != null) {
72510 outputs = _this2.activation.apply(outputs);
72511 }
72512 return outputs;
72513 });
72514 }
72515 }, {
72516 key: "computeOutputShape",
72517 value: function computeOutputShape(inputShape) {
72518 inputShape = getExactlyOneShape(inputShape);
72519 var rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
72520 var cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
72521 var outFilters = this.dataFormat === 'channelsFirst' ? inputShape[1] * this.depthMultiplier : inputShape[3] * this.depthMultiplier;
72522 var outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
72523 var outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
72524 if (this.dataFormat === 'channelsFirst') {
72525 return [inputShape[0], outFilters, outRows, outCols];
72526 } else {
72527 // In this case, assume 'channelsLast'.
72528 return [inputShape[0], outRows, outCols, outFilters];
72529 }
72530 }
72531 }, {
72532 key: "getConfig",
72533 value: function getConfig() {
72534 var config = _get(_getPrototypeOf(DepthwiseConv2D.prototype), "getConfig", this).call(this);
72535 config['depthMultiplier'] = this.depthMultiplier;
72536 config['depthwiseInitializer'] = serializeInitializer(this.depthwiseInitializer);
72537 config['depthwiseRegularizer'] = serializeRegularizer(this.depthwiseRegularizer);
72538 config['depthwiseConstraint'] = serializeConstraint(this.depthwiseRegularizer);
72539 return config;
72540 }
72541 }]);
72542 return DepthwiseConv2D;
72543 }(BaseConv);
// Static class name used to identify DepthwiseConv2D; registerClass() below
// presumably keys the serialization registry on it (verify against
// registerClass).
/** @nocollapse */
DepthwiseConv2D.className = 'DepthwiseConv2D';
registerClass(DepthwiseConv2D);
72547
/**
 * Normalize the arguments of `RNN.apply()` into one input tensor plus optional
 * Arrays of initial states and constants.
 *
 * A model loaded from file passes `initialState` and `constants` inside
 * `inputs` instead of through the dedicated kwargs fields. In that case
 * `inputs` is laid out as
 * `[inputs, initialState0, initialState1, ..., constant0, constant1, ...]`,
 * where the trailing `numConstants` entries are the constants.
 *
 * @param inputs Tensor or `Array` of tensors.
 * @param initialState Tensor, `Array` of tensors, or `null`/`undefined`.
 * @param constants Tensor, `Array` of tensors, or `null`/`undefined`.
 * @param numConstants How many trailing entries of an Array `inputs` are
 *   constants; ignored when `null`/`undefined`.
 * @returns An object with:
 *   inputs: the primary input (first element when `inputs` is an Array).
 *   initialState: `Array` of tensors, or the original `null`/`undefined`.
 *   constants: `Array` of tensors, or the original `null`/`undefined`.
 * @throws ValueError if `inputs` is an `Array` and `initialState` or
 *   `constants` is also provided.
 */
function standardizeArgs(inputs, initialState, constants, numConstants) {
  // Wrap a lone value in a one-element Array; leave Arrays and null/undefined as-is.
  var wrapAsList = function (value) {
    return value != null && !Array.isArray(value) ? [value] : value;
  };
  if (!Array.isArray(inputs)) {
    return {
      inputs: inputs,
      initialState: wrapAsList(initialState),
      constants: wrapAsList(constants)
    };
  }
  if (initialState != null || constants != null) {
    throw new ValueError('When inputs is an array, neither initialState or constants ' + 'should be provided');
  }
  var tensors = inputs;
  if (numConstants != null) {
    var split = tensors.length - numConstants;
    constants = tensors.slice(split, tensors.length);
    tensors = tensors.slice(0, split);
  }
  if (tensors.length > 1) {
    initialState = tensors.slice(1, tensors.length);
  }
  return {
    inputs: tensors[0],
    initialState: wrapAsList(initialState),
    constants: wrapAsList(constants)
  };
}
/**
 * Iterates over the time dimension of a tensor, calling `stepFunction` once
 * per time step.
 *
 * @param stepFunction RNN step function, called as
 *   `stepFunction(stepInput, states)`. It must return `[output, newStates]`:
 *     stepInput: the slice of `inputs` for one time step, of shape
 *       `[samples, ...]` (no time dimension).
 *     states: the current Array of state tensors.
 *     output: the step output, e.g. of shape `[samples, outputDim]`.
 *     newStates: Array of tensors, same length and shapes as `states`.
 *   When a mask is given, masked-out steps take their output from `states[0]`,
 *   so `states[0]` must have the same shape as `output`.
 * @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at
 *   least 3D).
 * @param initialStates Array of tensors: the states passed to the first call
 *   of `stepFunction`.
 * @param goBackwards If `true`, iterate over the time dimension in reverse
 *   order and return the reversed sequence. Default: `false`.
 * @param mask Optional mask tensor. It is cast to bool and then to float32, so
 *   non-zero entries mark kept steps and zeros mark masked steps. A mask whose
 *   rank is one less than `inputs` (e.g. `[samples, time]`) gets a trailing
 *   unit axis. For masked steps, the output is taken from `states[0]` and each
 *   state keeps its previous value.
 * @param constants Not supported; must be `null`/`undefined`.
 * @param unroll Has no effect in this imperative backend; a warning is logged
 *   when it is `true`. Default: `false`.
 * @param needPerStepOutputs Whether to stack the per-step outputs into a single
 *   tensor returned as the second element. Default: `false`. Stacking is
 *   relatively expensive, so it is skipped unless requested (e.g. for a layer
 *   with `returnSequences: true`).
 * @returns An Array `[lastOutput, outputs, newStates]`:
 *   lastOutput: the latest output of the RNN, of shape `[samples, ...]`.
 *   outputs: tensor of shape `[samples, time, ...]`, where `outputs[s, t]` is
 *     the step output at time `t` for sample `s`. Present only when
 *     `needPerStepOutputs` is `true`; otherwise `undefined`.
 *   newStates: Array of tensors, the latest states.
 * @throws ValueError If `inputs` has fewer than 3 dimensions.
 * @throws NotImplementedError If `constants` is provided.
 *
 * TODO(nielsene): This needs to be tidy-ed.
 */
function rnn$1(stepFunction, inputs, initialStates) {
  var goBackwards = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  var mask = arguments.length > 4 ? arguments[4] : undefined;
  var constants = arguments.length > 5 ? arguments[5] : undefined;
  var unroll = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : false;
  var needPerStepOutputs = arguments.length > 7 && arguments[7] !== undefined ? arguments[7] : false;
  return tidy(function () {
    var ndim = inputs.shape.length;
    if (ndim < 3) {
      throw new ValueError("Input should be at least 3D, but is ".concat(ndim, "D."));
    }
    // Reject unsupported arguments before doing any tensor work.
    if (constants != null) {
      throw new NotImplementedError('The rnn() function of the deeplearn.js backend does not support ' + 'constants yet.');
    }
    // Porting Note: the unroll option is ignored by the imperative backend.
    if (unroll) {
      console.warn('Backend rnn(): the unroll = true option is not applicable to the ' + 'imperative deeplearn.js backend.');
    }
    // Transpose to time-major, i.e., from [batch, time, ...] to
    // [time, batch, ...].
    var axes = [1, 0].concat(range$2(2, ndim));
    inputs = transpose$2(inputs, axes);
    if (mask != null) {
      // Binarize: any non-zero entry becomes 1, zero stays 0.
      mask = cast$3(cast$3(mask, 'bool'), 'float32');
      if (mask.rank === ndim - 1) {
        mask = expandDims$3(mask, -1);
      }
      mask = transpose$2(mask, axes);
    }
    if (goBackwards) {
      inputs = reverse$2(inputs, 0);
      if (mask != null) {
        mask = reverse$2(mask, 0);
      }
    }
    // Porting Note: PyKeras with the TensorFlow backend uses a symbolic loop
    // (tf.while_loop). This imperative backend simply iterates over the time
    // steps with ordinary control flow.
    // Porting Note: PyKeras patches a "_use_learning_phase" attribute onto the
    // outputs. That is not idiomatic here; whether the RNN runs in a learning
    // (training) phase is passed in a different way.
    var perStepOutputs = [];
    var lastOutput;
    var states = initialStates;
    var timeSteps = inputs.shape[0];
    var perStepInputs = unstack(inputs);
    var perStepMasks;
    if (mask != null) {
      perStepMasks = unstack(mask);
    }
    // Runs the step function for time step `t`, updating `lastOutput` and
    // `states` (and collecting the per-step output when requested).
    var processTimeStep = function processTimeStep(t) {
      var currentInput = perStepInputs[t];
      var stepOutputs = tidy(function () {
        return stepFunction(currentInput, states);
      });
      if (mask == null) {
        lastOutput = stepOutputs[0];
        states = stepOutputs[1];
      } else {
        var maskedOutputs = tidy(function () {
          var stepMask = perStepMasks[t];
          var negStepMask = sub$2(onesLike$3(stepMask), stepMask);
          // Blend: new values where the mask is 1, carried-over values where
          // it is 0.
          // TODO(cais): Would tfc.where() be better for performance?
          var output = add$3(mul(stepOutputs[0], stepMask), mul(states[0], negStepMask));
          var newStates = states.map(function (state, i) {
            return add$3(mul(stepOutputs[1][i], stepMask), mul(state, negStepMask));
          });
          return {
            output: output,
            newStates: newStates
          };
        });
        lastOutput = maskedOutputs.output;
        states = maskedOutputs.newStates;
      }
      if (needPerStepOutputs) {
        perStepOutputs.push(lastOutput);
      }
    };
    for (var t = 0; t < timeSteps; ++t) {
      processTimeStep(t);
    }
    var outputs;
    if (needPerStepOutputs) {
      // Stack along axis 1 to return to batch-major [samples, time, ...].
      outputs = stack(perStepOutputs, 1);
    }
    return [lastOutput, outputs, states];
  });
}
/**
 * Base class for recurrent layers.
 *
 * Wraps an RNN cell (or an Array of cells, combined into a `StackedRNNCells`)
 * and runs it over the time dimension of the input via `rnn$1()`.
 */
var RNN = /*#__PURE__*/function (_Layer) {
  _inherits(RNN, _Layer);
  var _super = _createSuper(RNN);
  /**
   * @param args Layer arguments. Reads `cell` (required), plus
   *   `returnSequences`, `returnState`, `goBackwards`, `stateful` and `unroll`
   *   (each defaults to `false` when omitted).
   * @throws ValueError if `args.cell` is missing or the cell has no
   *   `stateSize`.
   */
  function RNN(args) {
    var _this;
    _classCallCheck(this, RNN);
    _this = _super.call(this, args);
    var cell;
    if (args.cell == null) {
      throw new ValueError('cell property is missing for the constructor of RNN.');
    } else if (Array.isArray(args.cell)) {
      cell = new StackedRNNCells({
        cells: args.cell
      });
    } else {
      cell = args.cell;
    }
    if (cell.stateSize == null) {
      throw new ValueError('The RNN cell should have an attribute `stateSize` (tuple of ' + 'integers, one integer per RNN state).');
    }
    _this.cell = cell;
    _this.returnSequences = args.returnSequences == null ? false : args.returnSequences;
    _this.returnState = args.returnState == null ? false : args.returnState;
    _this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
    _this._stateful = args.stateful == null ? false : args.stateful;
    _this.unroll = args.unroll == null ? false : args.unroll;
    _this.supportsMasking = true;
    _this.inputSpec = [new InputSpec({
      ndim: 3
    })];
    _this.stateSpec = null;
    _this.states_ = null;
    // TODO(cais): Add constantsSpec and numConstants.
    _this.numConstants = null;
    // TODO(cais): Look into the use of initial_state in the kwargs of the
    // constructor.
    // Old state tensors kept by training-time resetStates() calls; they are
    // disposed on the next resetStates() call without states.
    _this.keptStates = [];
    return _this;
  }
  _createClass(RNN, [{
    // Porting Note: This is the equivalent of the `RNN.states` property getter
    // in PyKeras. Returns `states_`, or one `null` per state if unset.
    key: "getStates",
    value: function getStates() {
      if (this.states_ == null) {
        var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
        return range$2(0, numStates).map(function (x) {
          return null;
        });
      } else {
        return this.states_;
      }
    }
    // Porting Note: This is the equivalent of the `RNN.states` property setter
    // in PyKeras.
  }, {
    key: "setStates",
    value: function setStates(states) {
      this.states_ = states;
    }
  }, {
    /**
     * Computes the output shape.
     *
     * The output is `[batch, time, outputDim]` with `returnSequences`, and
     * `[batch, outputDim]` otherwise. `outputDim` is the first entry of the
     * cell's `stateSize`. With `returnState`, one `[batch, dim]` shape per
     * state follows the output shape.
     */
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      // When initial states/constants are passed through apply(), the first
      // shape is the main input.
      if (isArrayOfShapes(inputShape)) {
        inputShape = inputShape[0];
      }
      // TODO(cais): Remove the casting once stacked RNN cells become supported.
      var stateSize = this.cell.stateSize;
      if (!Array.isArray(stateSize)) {
        stateSize = [stateSize];
      }
      var batchDim = inputShape[0];
      var outputDim = stateSize[0];
      var outputShape = this.returnSequences ? [batchDim, inputShape[1], outputDim] : [batchDim, outputDim];
      if (!this.returnState) {
        return outputShape;
      }
      return [outputShape].concat(stateSize.map(function (dim) {
        return [batchDim, dim];
      }));
    }
  }, {
    /**
     * The input mask is propagated to the output only with `returnSequences`.
     * State outputs (with `returnState`) are never masked.
     */
    key: "computeMask",
    value: function computeMask(inputs, mask) {
      var _this2 = this;
      return tidy(function () {
        if (Array.isArray(mask)) {
          mask = mask[0];
        }
        var outputMask = _this2.returnSequences ? mask : null;
        if (_this2.returnState) {
          var stateMask = _this2.states.map(function (s) {
            return null;
          });
          return [outputMask].concat(stateMask);
        } else {
          return outputMask;
        }
      });
    }
    /**
     * Get the current state tensors of the RNN.
     *
     * If the state hasn't been set, return an array of `null`s of the correct
     * length.
     */
  }, {
    key: "states",
    get: function get() {
      if (this.states_ == null) {
        var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
        var output = [];
        for (var i = 0; i < numStates; ++i) {
          output.push(null);
        }
        return output;
      } else {
        return this.states_;
      }
    },
    set: function set(s) {
      this.states_ = s;
    }
  }, {
    /**
     * Builds the layer.
     *
     * - Fixes the input spec to `[batchSize, null, ...inputDim]`; `batchSize`
     *   is only pinned for stateful layers.
     * - Builds the cell on the per-step input shape.
     * - Creates `stateSpec`, or validates it against the cell's `stateSize`
     *   if it was already set.
     * - Stateful layers also get zero states via `resetStates()`.
     *
     * @throws NotImplementedError if constants are in use.
     * @throws ValueError if an existing `stateSpec` disagrees with `stateSize`.
     */
    key: "build",
    value: function build(inputShape) {
      if (this.numConstants != null) {
        throw new NotImplementedError('Constants support is not implemented in RNN yet.');
      }
      // inputShape is an Array of shapes (main input first) when initial
      // states or constants are passed in apply().
      if (isArrayOfShapes(inputShape)) {
        inputShape = inputShape[0];
      }
      var batchSize = this.stateful ? inputShape[0] : null;
      var inputDim = inputShape.slice(2);
      this.inputSpec[0] = new InputSpec({
        shape: [batchSize, null].concat(_toConsumableArray(inputDim))
      });
      // Allow cell (if RNNCell Layer) to build before we set or validate
      // stateSpec. The per-step shape drops the time axis.
      var stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
      this.cell.build(stepInputShape);
      // Set or validate stateSpec.
      var stateSize;
      if (Array.isArray(this.cell.stateSize)) {
        stateSize = this.cell.stateSize;
      } else {
        stateSize = [this.cell.stateSize];
      }
      if (this.stateSpec != null) {
        if (!arraysEqual(this.stateSpec.map(function (spec) {
          return spec.shape[spec.shape.length - 1];
        }), stateSize)) {
          throw new ValueError("An initialState was passed that is not compatible with " + "cell.stateSize. Received stateSpec=".concat(this.stateSpec, "; ") + "However cell.stateSize is ".concat(this.cell.stateSize));
        }
      } else {
        this.stateSpec = stateSize.map(function (dim) {
          return new InputSpec({
            shape: [null, dim]
          });
        });
      }
      if (this.stateful) {
        this.resetStates();
      }
    }
    /**
     * Reset the state tensors of the RNN.
     *
     * If the `states` argument is `undefined` or `null`, will set the
     * state tensor(s) of the RNN to all-zero tensors of the appropriate
     * shape(s).
     *
     * If `states` is provided, will set the state tensors of the RNN to its
     * value.
     *
     * @param states Optional externally-provided initial states.
     * @param training Whether this call is done during training. For stateful
     *   RNNs, this affects whether the old states are kept or discarded. In
     *   particular, if `training` is `true`, the old states will be kept so
     *   that subsequent backpropagation through time (BPTT) may work properly.
     *   Else, the old states will be discarded.
     * @throws AttributeError if the layer is not stateful.
     * @throws ValueError if the batch size is unknown, or `states` has the
     *   wrong count or shapes.
     */
  }, {
    key: "resetStates",
    value: function resetStates(states) {
      var _this3 = this;
      var training = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
      tidy(function () {
        if (!_this3.stateful) {
          throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
        }
        var batchSize = _this3.inputSpec[0].shape[0];
        if (batchSize == null) {
          throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' + 'the batch size of your input tensors: \n' + '- If using a Sequential model, specify the batch size by ' + 'passing a `batchInputShape` option to your first layer.\n' + '- If using the functional API, specify the batch size by ' + 'passing a `batchShape` option to your Input layer.');
        }
        // Initialize state if null.
        if (_this3.states_ == null) {
          if (Array.isArray(_this3.cell.stateSize)) {
            _this3.states_ = _this3.cell.stateSize.map(function (dim) {
              return zeros$2([batchSize, dim]);
            });
          } else {
            _this3.states_ = [zeros$2([batchSize, _this3.cell.stateSize])];
          }
        } else if (states == null) {
          // Dispose old state tensors.
          dispose(_this3.states_);
          // For stateful RNNs, fully dispose kept old states.
          if (_this3.keptStates != null) {
            dispose(_this3.keptStates);
            _this3.keptStates = [];
          }
          if (Array.isArray(_this3.cell.stateSize)) {
            _this3.states_ = _this3.cell.stateSize.map(function (dim) {
              return zeros$2([batchSize, dim]);
            });
          } else {
            _this3.states_[0] = zeros$2([batchSize, _this3.cell.stateSize]);
          }
        } else {
          if (!Array.isArray(states)) {
            states = [states];
          }
          if (states.length !== _this3.states_.length) {
            throw new ValueError("Layer ".concat(_this3.name, " expects ").concat(_this3.states_.length, " state(s), ") + "but it received ".concat(states.length, " state value(s). Input ") + "received: ".concat(states));
          }
          if (training === true) {
            // Store old state tensors for complete disposal later, i.e., during
            // the next no-arg call to this method. We do not dispose the old
            // states immediately because BPTT (among other things) requires
            // them.
            _this3.keptStates.push(_this3.states_.slice());
          } else {
            dispose(_this3.states_);
          }
          for (var index = 0; index < _this3.states_.length; ++index) {
            var value = states[index];
            var dim = Array.isArray(_this3.cell.stateSize) ? _this3.cell.stateSize[index] : _this3.cell.stateSize;
            var expectedShape = [batchSize, dim];
            if (!arraysEqual(value.shape, expectedShape)) {
              throw new ValueError("State ".concat(index, " is incompatible with layer ").concat(_this3.name, ": ") + "expected shape=".concat(expectedShape, ", received shape=").concat(value.shape));
            }
            _this3.states_[index] = value;
          }
        }
        // Clone and keep so the stored states outlive this tidy scope.
        _this3.states_ = _this3.states_.map(function (state) {
          return keep(state.clone());
        });
      });
    }
  }, {
    /**
     * Applies the layer.
     *
     * `initialState` and `constants` may come through `kwargs` or be packed
     * into `inputs` (see `standardizeArgs`). When they are
     * `SymbolicTensor`s, they are appended to the inputs and `inputSpec` is
     * extended for the duration of the base `apply()` call. The original
     * `inputSpec` is restored afterwards, even if that call throws.
     */
    key: "apply",
    value: function apply(inputs, kwargs) {
      // TODO(cais): Figure out whether initialState is in kwargs or inputs.
      var initialState = kwargs == null ? null : kwargs['initialState'];
      var constants = kwargs == null ? null : kwargs['constants'];
      if (kwargs == null) {
        kwargs = {};
      }
      var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
      inputs = standardized.inputs;
      initialState = standardized.initialState;
      constants = standardized.constants;
      // If any of `initial_state` or `constants` are specified and are
      // `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify
      // the input_spec to include them.
      var additionalInputs = [];
      var additionalSpecs = [];
      if (initialState != null) {
        kwargs['initialState'] = initialState;
        additionalInputs = additionalInputs.concat(initialState);
        // initialState is an Array here (standardizeArgs wraps single values).
        this.stateSpec = initialState.map(function (state) {
          return new InputSpec({
            shape: state.shape
          });
        });
        additionalSpecs = additionalSpecs.concat(this.stateSpec);
      }
      if (constants != null) {
        kwargs['constants'] = constants;
        additionalInputs = additionalInputs.concat(constants);
        // TODO(cais): Add this.constantsSpec.
        this.numConstants = constants.length;
      }
      var isTensor = additionalInputs[0] instanceof SymbolicTensor;
      if (!isTensor) {
        return _get(_getPrototypeOf(RNN.prototype), "apply", this).call(this, inputs, kwargs);
      }
      // Compute full input spec, including state and constants.
      var fullInput = [inputs].concat(additionalInputs);
      var fullInputSpec = this.inputSpec.concat(additionalSpecs);
      // Perform the call with a temporarily replaced inputSpec; restore it in
      // `finally` so a failing call does not leave the layer with the
      // augmented spec.
      var originalInputSpec = this.inputSpec;
      this.inputSpec = fullInputSpec;
      try {
        return _get(_getPrototypeOf(RNN.prototype), "apply", this).call(this, fullInput, kwargs);
      } finally {
        this.inputSpec = originalInputSpec;
      }
    }
  }, {
    /**
     * Runs the recurrence.
     *
     * Reads `mask`, `training` and `initialState` from `kwargs`. Without an
     * initial state, a stateful layer uses its stored states and a stateless
     * layer uses zeros (`getInitialState`). A stateful layer stores the final
     * states via `resetStates(states, training)`.
     *
     * @returns The last output (or the full sequence with `returnSequences`),
     *   followed by the final states when `returnState` is set.
     * @throws ValueError if the number of initial states does not match the
     *   cell.
     */
    key: "call",
    value: function call(inputs, kwargs) {
      var _this4 = this;
      // Input shape: `[samples, time (padded with zeros), input_dim]`.
      // Note that the .build() method of subclasses **must** define
      // this.inputSpec and this.stateSpec with complete input shapes.
      return tidy(function () {
        var mask = kwargs == null ? null : kwargs['mask'];
        var training = kwargs == null ? null : kwargs['training'];
        var initialState = kwargs == null ? null : kwargs['initialState'];
        inputs = getExactlyOneTensor(inputs);
        if (initialState == null) {
          if (_this4.stateful) {
            initialState = _this4.states_;
          } else {
            initialState = _this4.getInitialState(inputs);
          }
        }
        var numStates = Array.isArray(_this4.cell.stateSize) ? _this4.cell.stateSize.length : 1;
        if (initialState.length !== numStates) {
          throw new ValueError("RNN Layer has ".concat(numStates, " state(s) but was passed ") + "".concat(initialState.length, " initial state(s)."));
        }
        if (_this4.unroll) {
          console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.');
        }
        var cellCallKwargs = {
          training: training
        };
        // TODO(cais): Add support for constants.
        var step = function step(inputs, states) {
          // `inputs` and `states` are concatenated to form a single `Array` of
          // `tf.Tensor`s as the input to `cell.call()`.
          var outputs = _this4.cell.call([inputs].concat(states), cellCallKwargs);
          // Marshall the return value into output and new states.
          return [outputs[0], outputs.slice(1)];
        };
        // TODO(cais): Add support for constants.
        var rnnOutputs = rnn$1(step, inputs, initialState, _this4.goBackwards, mask, null, _this4.unroll, _this4.returnSequences);
        var lastOutput = rnnOutputs[0];
        var outputs = rnnOutputs[1];
        var states = rnnOutputs[2];
        if (_this4.stateful) {
          _this4.resetStates(states, training);
        }
        var output = _this4.returnSequences ? outputs : lastOutput;
        // TODO(cais): Property set learning phase flag.
        if (_this4.returnState) {
          return [output].concat(states);
        } else {
          return output;
        }
      });
    }
  }, {
    /**
     * Builds all-zero initial states of shape `[samples, dim]`, one per entry
     * of the cell's `stateSize`, by reducing a zeros tensor shaped like
     * `inputs` over axes 1 and 2 and tiling it.
     */
    key: "getInitialState",
    value: function getInitialState(inputs) {
      var _this5 = this;
      return tidy(function () {
        // Build an all-zero tensor of shape [samples, outputDim].
        // [Samples, timeSteps, inputDim].
        var initialState = zeros$2(inputs.shape);
        // [Samples].
        initialState = sum$3(initialState, [1, 2]);
        initialState = expandDims$2(initialState); // [Samples, 1].
        if (Array.isArray(_this5.cell.stateSize)) {
          return _this5.cell.stateSize.map(function (dim) {
            return dim > 1 ? tile$2(initialState, [1, dim]) : initialState;
          });
        } else {
          return _this5.cell.stateSize > 1 ? [tile$2(initialState, [1, _this5.cell.stateSize])] : [initialState];
        }
      });
    }
  }, {
    key: "trainableWeights",
    get: function get() {
      if (!this.trainable) {
        return [];
      }
      // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
      return this.cell.trainableWeights;
    }
  }, {
    key: "nonTrainableWeights",
    get: function get() {
      // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
      // A frozen layer reports all of the cell's weights as non-trainable.
      if (!this.trainable) {
        return this.cell.weights;
      }
      return this.cell.nonTrainableWeights;
    }
  }, {
    key: "setFastWeightInitDuringBuild",
    value: function setFastWeightInitDuringBuild(value) {
      _get(_getPrototypeOf(RNN.prototype), "setFastWeightInitDuringBuild", this).call(this, value);
      if (this.cell != null) {
        this.cell.setFastWeightInitDuringBuild(value);
      }
    }
  }, {
    /**
     * Serializes the layer. The cell's config is merged in first, so the
     * layer's own base config (notably its name) takes precedence. Only a
     * plain `RNN` embeds the cell as a nested `{className, config}` entry.
     */
    key: "getConfig",
    value: function getConfig() {
      var baseConfig = _get(_getPrototypeOf(RNN.prototype), "getConfig", this).call(this);
      var config = {
        returnSequences: this.returnSequences,
        returnState: this.returnState,
        goBackwards: this.goBackwards,
        stateful: this.stateful,
        unroll: this.unroll
      };
      if (this.numConstants != null) {
        config['numConstants'] = this.numConstants;
      }
      var cellConfig = this.cell.getConfig();
      if (this.getClassName() === RNN.className) {
        config['cell'] = {
          'className': this.cell.getClassName(),
          'config': cellConfig
        };
      }
      // this order is necessary, to prevent cell name from replacing layer name
      return Object.assign(Object.assign(Object.assign({}, cellConfig), baseConfig), config);
    }
    /** @nocollapse */
  }], [{
    /**
     * Reconstructs an RNN from `config`, deserializing `config.cell` into a
     * cell instance. The caller's `config` object is left unmodified.
     */
    key: "fromConfig",
    value: function fromConfig(cls, config) {
      var customObjects = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
      var cellConfig = config['cell'];
      var cell = deserialize(cellConfig, customObjects);
      return new cls(Object.assign({}, config, {
        cell: cell
      }));
    }
  }]);
  return RNN;
}(Layer);
/** @nocollapse */
RNN.className = 'RNN';
registerClass(RNN);
// Porting Note: This is a common parent class for RNN cells. There is no
// equivalent of this in PyKeras. Having a common parent class forgoes the
// need for `has_attr(cell, ...)` checks or its TypeScript equivalent.
/**
 * An RNNCell layer.
 *
 * @doc {heading: 'Layers', subheading: 'Classes'}
 */
var RNNCell = /*#__PURE__*/function (_Layer2) {
  _inherits(RNNCell, _Layer2);
  var superConstructor = _createSuper(RNNCell);
  // Adds no state of its own; every constructor argument is forwarded
  // unchanged to the parent Layer constructor.
  function RNNCell() {
    _classCallCheck(this, RNNCell);
    return superConstructor.apply(this, arguments);
  }
  return _createClass(RNNCell);
}(Layer);
73231 var SimpleRNNCell = /*#__PURE__*/function (_RNNCell) {
73232 _inherits(SimpleRNNCell, _RNNCell);
73233 var _super3 = _createSuper(SimpleRNNCell);
/**
 * Cell for a simple (fully-connected) recurrent layer.
 *
 * @param args Cell arguments:
 *   - `units`: validated with `assertPositiveInteger`; also used as
 *     `stateSize`.
 *   - `activation`: defaults to 'tanh'.
 *   - `useBias`: defaults to `true`.
 *   - kernel / recurrent / bias initializers, regularizers and constraints.
 *     Initializers default to 'glorotNormal', 'orthogonal' and 'zeros'
 *     respectively.
 *   - `dropout` and `recurrentDropout`: missing means 0; clamped via
 *     `min$2`/`max$2`, presumably into [0, 1].
 *   - `dropoutFunc`.
 */
function SimpleRNNCell(args) {
  var cell;
  _classCallCheck(this, SimpleRNNCell);
  cell = _super3.call(this, args);
  // Clamp a dropout rate; a missing rate counts as 0.
  var clampRate = function (rate) {
    return min$2([1, max$2([0, rate == null ? 0 : rate])]);
  };
  cell.DEFAULT_ACTIVATION = 'tanh';
  cell.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
  cell.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
  cell.DEFAULT_BIAS_INITIALIZER = 'zeros';
  cell.units = args.units;
  assertPositiveInteger(cell.units, "units");
  cell.activation = getActivation(args.activation == null ? cell.DEFAULT_ACTIVATION : args.activation);
  cell.useBias = args.useBias == null ? true : args.useBias;
  cell.kernelInitializer = getInitializer(args.kernelInitializer || cell.DEFAULT_KERNEL_INITIALIZER);
  cell.recurrentInitializer = getInitializer(args.recurrentInitializer || cell.DEFAULT_RECURRENT_INITIALIZER);
  cell.biasInitializer = getInitializer(args.biasInitializer || cell.DEFAULT_BIAS_INITIALIZER);
  cell.kernelRegularizer = getRegularizer(args.kernelRegularizer);
  cell.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
  cell.biasRegularizer = getRegularizer(args.biasRegularizer);
  cell.kernelConstraint = getConstraint(args.kernelConstraint);
  cell.recurrentConstraint = getConstraint(args.recurrentConstraint);
  cell.biasConstraint = getConstraint(args.biasConstraint);
  cell.dropout = clampRate(args.dropout);
  cell.recurrentDropout = clampRate(args.recurrentDropout);
  cell.dropoutFunc = args.dropoutFunc;
  // The single state has the same width as the output.
  cell.stateSize = cell.units;
  // Dropout masks are created lazily in call().
  cell.dropoutMask = null;
  cell.recurrentDropoutMask = null;
  return cell;
}
73263 _createClass(SimpleRNNCell, [{
73264 key: "build",
73265 value: function build(inputShape) {
73266 inputShape = getExactlyOneShape(inputShape);
73267 // TODO(cais): Use regularizer.
73268 this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
73269 this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
73270 if (this.useBias) {
73271 this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
73272 } else {
73273 this.bias = null;
73274 }
73275 this.built = true;
73276 }
73277 // Porting Note: PyKeras' equivalent of this method takes two tensor inputs:
73278 // `inputs` and `states`. Here, the two tensors are combined into an
73279 // `Tensor[]` Array as the first input argument.
73280 // Similarly, PyKeras' equivalent of this method returns two values:
73281 // `output` and `[output]`. Here the two are combined into one length-2
73282 // `Tensor[]`, consisting of `output` repeated.
73283 }, {
73284 key: "call",
73285 value: function call(inputs, kwargs) {
73286 var _this7 = this;
73287 return tidy(function () {
73288 inputs = inputs;
73289 if (inputs.length !== 2) {
73290 throw new ValueError("SimpleRNNCell expects 2 input Tensors, got ".concat(inputs.length, "."));
73291 }
73292 var prevOutput = inputs[1];
73293 inputs = inputs[0];
73294 var training = kwargs['training'] == null ? false : kwargs['training'];
73295 if (0 < _this7.dropout && _this7.dropout < 1 && _this7.dropoutMask == null) {
73296 _this7.dropoutMask = generateDropoutMask({
73297 ones: function ones() {
73298 return onesLike$3(inputs);
73299 },
73300 rate: _this7.dropout,
73301 training: training,
73302 dropoutFunc: _this7.dropoutFunc
73303 });
73304 }
73305 if (0 < _this7.recurrentDropout && _this7.recurrentDropout < 1 && _this7.recurrentDropoutMask == null) {
73306 _this7.recurrentDropoutMask = generateDropoutMask({
73307 ones: function ones() {
73308 return onesLike$3(prevOutput);
73309 },
73310 rate: _this7.recurrentDropout,
73311 training: training,
73312 dropoutFunc: _this7.dropoutFunc
73313 });
73314 }
73315 var h;
73316 var dpMask = _this7.dropoutMask;
73317 var recDpMask = _this7.recurrentDropoutMask;
73318 if (dpMask != null) {
73319 h = dot$1(mul(inputs, dpMask), _this7.kernel.read());
73320 } else {
73321 h = dot$1(inputs, _this7.kernel.read());
73322 }
73323 if (_this7.bias != null) {
73324 h = biasAdd(h, _this7.bias.read());
73325 }
73326 if (recDpMask != null) {
73327 prevOutput = mul(prevOutput, recDpMask);
73328 }
73329 var output = add$3(h, dot$1(prevOutput, _this7.recurrentKernel.read()));
73330 if (_this7.activation != null) {
73331 output = _this7.activation.apply(output);
73332 }
73333 // TODO(cais): Properly set learning phase on output tensor?
73334 return [output, output];
73335 });
73336 }
73337 }, {
73338 key: "getConfig",
73339 value: function getConfig() {
73340 var baseConfig = _get(_getPrototypeOf(SimpleRNNCell.prototype), "getConfig", this).call(this);
73341 var config = {
73342 units: this.units,
73343 activation: serializeActivation(this.activation),
73344 useBias: this.useBias,
73345 kernelInitializer: serializeInitializer(this.kernelInitializer),
73346 recurrentInitializer: serializeInitializer(this.recurrentInitializer),
73347 biasInitializer: serializeInitializer(this.biasInitializer),
73348 kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
73349 recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
73350 biasRegularizer: serializeRegularizer(this.biasRegularizer),
73351 activityRegularizer: serializeRegularizer(this.activityRegularizer),
73352 kernelConstraint: serializeConstraint(this.kernelConstraint),
73353 recurrentConstraint: serializeConstraint(this.recurrentConstraint),
73354 biasConstraint: serializeConstraint(this.biasConstraint),
73355 dropout: this.dropout,
73356 recurrentDropout: this.recurrentDropout
73357 };
73358 return Object.assign(Object.assign({}, baseConfig), config);
73359 }
73360 }]);
73361 return SimpleRNNCell;
73362 }(RNNCell);
// Serialization name for SimpleRNNCell. registerClass (defined elsewhere in
// this bundle) presumably records the class under this name so serialized
// configs can be turned back into instances — confirm against its definition.
/** @nocollapse */
SimpleRNNCell.className = 'SimpleRNNCell';
registerClass(SimpleRNNCell);
/**
 * Fully-connected RNN layer: an `RNN` driven by a `SimpleRNNCell` that is
 * constructed from the same `args`.
 */
var SimpleRNN = /*#__PURE__*/function (_RNN) {
  _inherits(SimpleRNN, _RNN);
  var callSuper = _createSuper(SimpleRNN);

  /**
   * @param args Layer/cell configuration. NOTE: `args.cell` is overwritten
   *   with the freshly created SimpleRNNCell before the base constructor runs.
   */
  function SimpleRNN(args) {
    _classCallCheck(this, SimpleRNN);
    args.cell = new SimpleRNNCell(args);
    // TODO(cais): Add activityRegularizer.
    return callSuper.call(this, args);
  }

  _createClass(SimpleRNN, [{
    key: "call",
    /**
     * Disposes and clears any dropout masks cached on the cell (so the cell
     * generates fresh ones for this invocation), then delegates to the base
     * `call` with only `mask`, `training` and `initialState`; each of these
     * is null when `kwargs` itself is null/undefined.
     */
    value: function call(inputs, kwargs) {
      var layer = this;
      return tidy(function () {
        ['dropoutMask', 'recurrentDropoutMask'].forEach(function (slot) {
          var cached = layer.cell[slot];
          if (cached != null) {
            dispose(cached);
            layer.cell[slot] = null;
          }
        });
        var options = {
          mask: null,
          training: null,
          initialState: null
        };
        if (kwargs != null) {
          options.mask = kwargs['mask'];
          options.training = kwargs['training'];
          options.initialState = kwargs['initialState'];
        }
        var baseCall = _get(_getPrototypeOf(SimpleRNN.prototype), "call", layer);
        return baseCall.call(layer, inputs, options);
      });
    }
  }], [{
    key: "fromConfig",
    /** @nocollapse */
    value: function fromConfig(cls, config) {
      return new cls(config);
    }
  }]);
  return SimpleRNN;
}(RNN);
/** @nocollapse */
SimpleRNN.className = 'SimpleRNN';
registerClass(SimpleRNN);
/**
 * Cell class for the GRU layer.
 *
 * Weights are fused across the three gates: the input kernel is
 * [inputDim, 3 * units], the recurrent kernel is [units, 3 * units] and the
 * optional bias is [3 * units]. `call()` slices them at execution time into
 * the update (z), reset (r) and candidate (hh) parts.
 */
var GRUCell = /*#__PURE__*/function (_RNNCell2) {
  _inherits(GRUCell, _RNNCell2);
  var _super5 = _createSuper(GRUCell);
  /**
   * @param args Cell configuration; `args.units` must be a positive integer.
   * @throws ValueError if `args.resetAfter` is truthy: only the variant that
   *   applies the reset gate before the recurrent matmul is implemented.
   */
  function GRUCell(args) {
    var _this9;
    _classCallCheck(this, GRUCell);
    _this9 = _super5.call(this, args);
    _this9.DEFAULT_ACTIVATION = 'tanh';
    _this9.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
    _this9.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
    _this9.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
    _this9.DEFAULT_BIAS_INITIALIZER = 'zeros';
    if (args.resetAfter) {
      throw new ValueError("GRUCell does not support reset_after parameter set to true.");
    }
    _this9.units = args.units;
    assertPositiveInteger(_this9.units, 'units');
    // Only `undefined` selects the default activations.
    _this9.activation = getActivation(args.activation === undefined ? _this9.DEFAULT_ACTIVATION : args.activation);
    _this9.recurrentActivation = getActivation(args.recurrentActivation === undefined ? _this9.DEFAULT_RECURRENT_ACTIVATION : args.recurrentActivation);
    _this9.useBias = args.useBias == null ? true : args.useBias;
    _this9.kernelInitializer = getInitializer(args.kernelInitializer || _this9.DEFAULT_KERNEL_INITIALIZER);
    _this9.recurrentInitializer = getInitializer(args.recurrentInitializer || _this9.DEFAULT_RECURRENT_INITIALIZER);
    _this9.biasInitializer = getInitializer(args.biasInitializer || _this9.DEFAULT_BIAS_INITIALIZER);
    _this9.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    _this9.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
    _this9.biasRegularizer = getRegularizer(args.biasRegularizer);
    _this9.kernelConstraint = getConstraint(args.kernelConstraint);
    _this9.recurrentConstraint = getConstraint(args.recurrentConstraint);
    _this9.biasConstraint = getConstraint(args.biasConstraint);
    // min$2/max$2 (defined elsewhere) bound the rates to [0, 1]; null/undefined -> 0.
    _this9.dropout = min$2([1, max$2([0, args.dropout == null ? 0 : args.dropout])]);
    _this9.recurrentDropout = min$2([1, max$2([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])]);
    _this9.dropoutFunc = args.dropoutFunc;
    // Kept for getConfig(); call() always computes "implementation 2".
    _this9.implementation = args.implementation;
    _this9.stateSize = _this9.units;
    // Dropout masks are created lazily by call() and cached here.
    _this9.dropoutMask = null;
    _this9.recurrentDropoutMask = null;
    return _this9;
  }
  _createClass(GRUCell, [{
    key: "build",
    /**
     * Creates the fused kernel [inputDim, 3 * units], recurrent kernel
     * [units, 3 * units] and, when `useBias`, bias [3 * units].
     */
    value: function build(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      var inputDim = inputShape[inputShape.length - 1];
      this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
      this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
      if (this.useBias) {
        this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
      } else {
        this.bias = null;
      }
      // Porting Notes: Unlike the PyKeras implementation, we perform slicing
      // of the weights and bias in the call() method, at execution time.
      this.built = true;
    }
  }, {
    key: "call",
    /**
     * One GRU step. TensorFlow.js always uses the fused "implementation 2"
     * computation, regardless of `this.implementation`.
     *
     * @param inputs Exactly two tensors: `[x, hPrev]`.
     * @param kwargs Object; `kwargs.training` (null/undefined -> false) is
     *   forwarded to dropout-mask generation.
     * @returns `[h, h]`: the new state doubles as the output.
     * @throws ValueError when `inputs` does not hold exactly two tensors.
     */
    value: function call(inputs, kwargs) {
      var _this10 = this;
      return tidy(function () {
        if (inputs.length !== 2) {
          throw new ValueError("GRUCell expects 2 input Tensors (inputs, h), got ".concat(inputs.length, "."));
        }
        var training = kwargs['training'] == null ? false : kwargs['training'];
        var hTMinus1 = inputs[1]; // Previous memory state.
        var x = inputs[0];
        // Masks are generated lazily (count: 3) and cached on the cell; only
        // index 0 of each mask array is used below.
        if (0 < _this10.dropout && _this10.dropout < 1 && _this10.dropoutMask == null) {
          _this10.dropoutMask = generateDropoutMask({
            ones: function ones() {
              return onesLike$3(x);
            },
            rate: _this10.dropout,
            training: training,
            count: 3,
            dropoutFunc: _this10.dropoutFunc
          });
        }
        if (0 < _this10.recurrentDropout && _this10.recurrentDropout < 1 && _this10.recurrentDropoutMask == null) {
          _this10.recurrentDropoutMask = generateDropoutMask({
            ones: function ones() {
              return onesLike$3(hTMinus1);
            },
            rate: _this10.recurrentDropout,
            training: training,
            count: 3,
            dropoutFunc: _this10.dropoutFunc
          });
        }
        var dpMask = _this10.dropoutMask;
        var recDpMask = _this10.recurrentDropoutMask;
        if (0 < _this10.dropout && _this10.dropout < 1) {
          x = mul(x, dpMask[0]);
        }
        var matrixX = dot$1(x, _this10.kernel.read());
        if (_this10.useBias) {
          matrixX = biasAdd(matrixX, _this10.bias.read());
        }
        if (0 < _this10.recurrentDropout && _this10.recurrentDropout < 1) {
          hTMinus1 = mul(hTMinus1, recDpMask[0]);
        }
        // The first 2 * units recurrent-kernel columns feed z and r; the last
        // `units` columns feed the candidate, after the reset gate is applied.
        var recurrentKernelValue = _this10.recurrentKernel.read();
        var _tfc$split = split$3(recurrentKernelValue, [2 * _this10.units, _this10.units], recurrentKernelValue.rank - 1),
          _tfc$split2 = _slicedToArray(_tfc$split, 2),
          rk1 = _tfc$split2[0],
          rk2 = _tfc$split2[1];
        var matrixInner = dot$1(hTMinus1, rk1);
        var _tfc$split3 = split$3(matrixX, 3, matrixX.rank - 1),
          _tfc$split4 = _slicedToArray(_tfc$split3, 3),
          xZ = _tfc$split4[0],
          xR = _tfc$split4[1],
          xH = _tfc$split4[2];
        var _tfc$split5 = split$3(matrixInner, 2, matrixInner.rank - 1),
          _tfc$split6 = _slicedToArray(_tfc$split5, 2),
          recurrentZ = _tfc$split6[0],
          recurrentR = _tfc$split6[1];
        var z = _this10.recurrentActivation.apply(add$3(xZ, recurrentZ));
        var r = _this10.recurrentActivation.apply(add$3(xR, recurrentR));
        var recurrentH = dot$1(mul(r, hTMinus1), rk2);
        var hh = _this10.activation.apply(add$3(xH, recurrentH));
        // h = z * hPrev + (1 - z) * hh
        var h = add$3(mul(z, hTMinus1), mul(add$3(1, neg$2(z)), hh));
        // TODO(cais): Add use_learning_phase flag properly.
        return [h, h];
      });
    }
  }, {
    key: "getConfig",
    /**
     * Base RNNCell config extended with this cell's hyper-parameters.
     * `resetAfter` is always false because the constructor rejects true.
     */
    value: function getConfig() {
      var baseConfig = _get(_getPrototypeOf(GRUCell.prototype), "getConfig", this).call(this);
      var config = {
        units: this.units,
        activation: serializeActivation(this.activation),
        recurrentActivation: serializeActivation(this.recurrentActivation),
        useBias: this.useBias,
        kernelInitializer: serializeInitializer(this.kernelInitializer),
        recurrentInitializer: serializeInitializer(this.recurrentInitializer),
        biasInitializer: serializeInitializer(this.biasInitializer),
        kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
        recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
        biasRegularizer: serializeRegularizer(this.biasRegularizer),
        activityRegularizer: serializeRegularizer(this.activityRegularizer),
        kernelConstraint: serializeConstraint(this.kernelConstraint),
        recurrentConstraint: serializeConstraint(this.recurrentConstraint),
        biasConstraint: serializeConstraint(this.biasConstraint),
        dropout: this.dropout,
        recurrentDropout: this.recurrentDropout,
        implementation: this.implementation,
        resetAfter: false
      };
      return Object.assign(Object.assign({}, baseConfig), config);
    }
  }]);
  return GRUCell;
}(RNNCell);
/** @nocollapse */
GRUCell.className = 'GRUCell';
registerClass(GRUCell);
/**
 * Gated Recurrent Unit layer: an `RNN` driven by a `GRUCell` constructed
 * from the same `args`.
 */
var GRU = /*#__PURE__*/function (_RNN2) {
  _inherits(GRU, _RNN2);
  var _super6 = _createSuper(GRU);
  /**
   * @param args Layer/cell configuration. `args.cell` is overwritten with a
   *   new GRUCell before the base constructor runs.
   */
  function GRU(args) {
    _classCallCheck(this, GRU);
    if (args.implementation === 0) {
      console.warn('`implementation=0` has been deprecated, and now defaults to ' + '`implementation=1`. Please update your layer call.');
    }
    args.cell = new GRUCell(args);
    return _super6.call(this, args); // TODO(cais): Add activityRegularizer.
  }
  _createClass(GRU, [{
    key: "call",
    /**
     * Disposes and clears any dropout masks cached on the cell (so the cell
     * regenerates them for this invocation), then delegates to the base
     * `call` with `mask`, `training` and `initialState` (each null when
     * `kwargs` is null/undefined).
     */
    value: function call(inputs, kwargs) {
      var _this11 = this;
      return tidy(function () {
        if (_this11.cell.dropoutMask != null) {
          dispose(_this11.cell.dropoutMask);
          _this11.cell.dropoutMask = null;
        }
        if (_this11.cell.recurrentDropoutMask != null) {
          dispose(_this11.cell.recurrentDropoutMask);
          _this11.cell.recurrentDropoutMask = null;
        }
        var mask = kwargs == null ? null : kwargs['mask'];
        var training = kwargs == null ? null : kwargs['training'];
        var initialState = kwargs == null ? null : kwargs['initialState'];
        return _get(_getPrototypeOf(GRU.prototype), "call", _this11).call(_this11, inputs, {
          mask: mask,
          training: training,
          initialState: initialState
        });
      });
    }
  }], [{
    key: "fromConfig",
    /**
     * Rebuilds a layer from a config. The deprecated `implementation: 0` is
     * rewritten (in place) to 1 before construction.
     * @nocollapse
     */
    value: function fromConfig(cls, config) {
      // The key must match the one getConfig() emits ('implementation').
      if (config['implementation'] === 0) {
        config['implementation'] = 1;
      }
      return new cls(config);
    }
  }]);
  return GRU;
}(RNN);
/** @nocollapse */
GRU.className = 'GRU';
registerClass(GRU);
/**
 * Cell class for the LSTM layer.
 *
 * Weights are fused across the four gates: the input kernel is
 * [inputDim, 4 * units], the recurrent kernel is [units, 4 * units] and the
 * optional bias is [4 * units]. `call()` splits the summed pre-activation
 * into four equal slices: input gate, forget gate, candidate, output gate.
 * The state is `[h, c]`, hence `stateSize = [units, units]`.
 */
var LSTMCell = /*#__PURE__*/function (_RNNCell3) {
  _inherits(LSTMCell, _RNNCell3);
  var callSuper = _createSuper(LSTMCell);

  // Reads an optional dropout rate (null/undefined -> 0) and bounds it to
  // [0, 1] via min$2/max$2.
  function clampRate(rate) {
    return min$2([1, max$2([0, rate == null ? 0 : rate])]);
  }

  /**
   * Builds the `unitForgetBias` initializer: a [4 * units] bias whose second
   * slice (forget gate) is all ones, while the first slice and the final two
   * slices come from `baseInit`.
   */
  function makeUnitForgetBiasInitializer(baseInit, units) {
    var CustomInit = /*#__PURE__*/function (_Initializer) {
      _inherits(CustomInit, _Initializer);
      var initSuper = _createSuper(CustomInit);
      function CustomInit() {
        _classCallCheck(this, CustomInit);
        return initSuper.apply(this, arguments);
      }
      _createClass(CustomInit, [{
        key: "apply",
        // `shape` and `dtype` are ignored: the layout is fixed by `units`.
        value: function apply(shape, dtype) {
          var inputGateBias = baseInit.apply([units]);
          var forgetGateBias = new Ones().apply([units]);
          var tailBias = baseInit.apply([units * 2]);
          return concatAlongFirstAxis(concatAlongFirstAxis(inputGateBias, forgetGateBias), tailBias);
        }
      }]);
      return CustomInit;
    }(Initializer);
    /** @nocollapse */
    CustomInit.className = 'CustomInit';
    return new CustomInit();
  }

  /**
   * @param args Cell configuration; `args.units` must be a positive integer.
   */
  function LSTMCell(args) {
    _classCallCheck(this, LSTMCell);
    // All fields go on the object returned by the super-constructor call.
    var cell = callSuper.call(this, args);
    cell.DEFAULT_ACTIVATION = 'tanh';
    cell.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
    cell.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
    cell.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
    cell.DEFAULT_BIAS_INITIALIZER = 'zeros';
    cell.units = args.units;
    assertPositiveInteger(cell.units, 'units');
    // Only `undefined` falls back to the default activations.
    cell.activation = getActivation(args.activation !== undefined ? args.activation : cell.DEFAULT_ACTIVATION);
    cell.recurrentActivation = getActivation(args.recurrentActivation !== undefined ? args.recurrentActivation : cell.DEFAULT_RECURRENT_ACTIVATION);
    cell.useBias = args.useBias != null ? args.useBias : true;
    // Any falsy initializer identifier falls back to the default.
    cell.kernelInitializer = getInitializer(args.kernelInitializer || cell.DEFAULT_KERNEL_INITIALIZER);
    cell.recurrentInitializer = getInitializer(args.recurrentInitializer || cell.DEFAULT_RECURRENT_INITIALIZER);
    cell.biasInitializer = getInitializer(args.biasInitializer || cell.DEFAULT_BIAS_INITIALIZER);
    cell.unitForgetBias = args.unitForgetBias;
    cell.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    cell.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
    cell.biasRegularizer = getRegularizer(args.biasRegularizer);
    cell.kernelConstraint = getConstraint(args.kernelConstraint);
    cell.recurrentConstraint = getConstraint(args.recurrentConstraint);
    cell.biasConstraint = getConstraint(args.biasConstraint);
    cell.dropout = clampRate(args.dropout);
    cell.recurrentDropout = clampRate(args.recurrentDropout);
    cell.dropoutFunc = args.dropoutFunc;
    // Kept for getConfig(); call() always computes "implementation 2".
    cell.implementation = args.implementation;
    cell.stateSize = [cell.units, cell.units];
    // Dropout masks are created lazily by call() and cached here.
    cell.dropoutMask = null;
    cell.recurrentDropoutMask = null;
    return cell;
  }

  _createClass(LSTMCell, [{
    key: "build",
    /**
     * Creates the fused kernels and, when `useBias`, a [4 * units] bias
     * (initialized with the unit-forget-bias layout if `unitForgetBias`).
     */
    value: function build(inputShape) {
      var shape = getExactlyOneShape(inputShape);
      var inputDim = shape[shape.length - 1];
      var units = this.units;
      this.kernel = this.addWeight('kernel', [inputDim, units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
      this.recurrentKernel = this.addWeight('recurrent_kernel', [units, units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
      if (!this.useBias) {
        this.bias = null;
      } else {
        var biasInit = this.unitForgetBias ? makeUnitForgetBiasInitializer(this.biasInitializer, units) : this.biasInitializer;
        this.bias = this.addWeight('bias', [units * 4], null, biasInit, this.biasRegularizer, true, this.biasConstraint);
      }
      // Porting Notes: Unlike the PyKeras implementation, slicing of the
      // weights and bias happens in call(), at execution time.
      this.built = true;
    }
  }, {
    key: "call",
    /**
     * One LSTM step (always the fused "implementation 2" computation).
     *
     * @param inputs Exactly three tensors: `[x, hPrev, cPrev]`.
     * @param kwargs Object; `kwargs.training` (null/undefined -> false) is
     *   forwarded to dropout-mask generation.
     * @returns `[h, h, c]`.
     * @throws ValueError when `inputs` does not hold exactly three tensors.
     */
    value: function call(inputs, kwargs) {
      var cell = this;
      return tidy(function () {
        var training = kwargs['training'] == null ? false : kwargs['training'];
        if (inputs.length !== 3) {
          throw new ValueError("LSTMCell expects 3 input Tensors (inputs, h, c), got " + "".concat(inputs.length, "."));
        }
        var hPrev = inputs[1]; // Previous memory state.
        var cPrev = inputs[2]; // Previous carry state.
        var x = inputs[0];
        var useInputDropout = 0 < cell.dropout && cell.dropout < 1;
        var useRecurrentDropout = 0 < cell.recurrentDropout && cell.recurrentDropout < 1;
        // Masks (count: 4) are generated lazily and cached on the cell; only
        // index 0 of each mask array is used below.
        if (useInputDropout && cell.dropoutMask == null) {
          cell.dropoutMask = generateDropoutMask({
            ones: function () {
              return onesLike$3(x);
            },
            rate: cell.dropout,
            training: training,
            count: 4,
            dropoutFunc: cell.dropoutFunc
          });
        }
        if (useRecurrentDropout && cell.recurrentDropoutMask == null) {
          cell.recurrentDropoutMask = generateDropoutMask({
            ones: function () {
              return onesLike$3(hPrev);
            },
            rate: cell.recurrentDropout,
            training: training,
            count: 4,
            dropoutFunc: cell.dropoutFunc
          });
        }
        var inputMasks = cell.dropoutMask;
        var recurrentMasks = cell.recurrentDropoutMask;
        if (useInputDropout) {
          x = mul(x, inputMasks[0]);
        }
        var z = dot$1(x, cell.kernel.read());
        if (useRecurrentDropout) {
          hPrev = mul(hPrev, recurrentMasks[0]);
        }
        z = add$3(z, dot$1(hPrev, cell.recurrentKernel.read()));
        if (cell.useBias) {
          z = biasAdd(z, cell.bias.read());
        }
        var gates = split$3(z, 4, z.rank - 1);
        var inputGate = cell.recurrentActivation.apply(gates[0]);
        var forgetGate = cell.recurrentActivation.apply(gates[1]);
        var c = add$3(mul(forgetGate, cPrev), mul(inputGate, cell.activation.apply(gates[2])));
        var outputGate = cell.recurrentActivation.apply(gates[3]);
        var h = mul(outputGate, cell.activation.apply(c));
        // TODO(cais): Add use_learning_phase flag properly.
        return [h, h, c];
      });
    }
  }, {
    key: "getConfig",
    /** Base RNNCell config extended with this cell's hyper-parameters. */
    value: function getConfig() {
      var merged = Object.assign({}, _get(_getPrototypeOf(LSTMCell.prototype), "getConfig", this).call(this));
      merged.units = this.units;
      merged.activation = serializeActivation(this.activation);
      merged.recurrentActivation = serializeActivation(this.recurrentActivation);
      merged.useBias = this.useBias;
      merged.kernelInitializer = serializeInitializer(this.kernelInitializer);
      merged.recurrentInitializer = serializeInitializer(this.recurrentInitializer);
      merged.biasInitializer = serializeInitializer(this.biasInitializer);
      merged.unitForgetBias = this.unitForgetBias;
      merged.kernelRegularizer = serializeRegularizer(this.kernelRegularizer);
      merged.recurrentRegularizer = serializeRegularizer(this.recurrentRegularizer);
      merged.biasRegularizer = serializeRegularizer(this.biasRegularizer);
      merged.activityRegularizer = serializeRegularizer(this.activityRegularizer);
      merged.kernelConstraint = serializeConstraint(this.kernelConstraint);
      merged.recurrentConstraint = serializeConstraint(this.recurrentConstraint);
      merged.biasConstraint = serializeConstraint(this.biasConstraint);
      merged.dropout = this.dropout;
      merged.recurrentDropout = this.recurrentDropout;
      merged.implementation = this.implementation;
      return merged;
    }
  }]);
  return LSTMCell;
}(RNNCell);
/** @nocollapse */
LSTMCell.className = 'LSTMCell';
registerClass(LSTMCell);
/**
 * Long Short-Term Memory layer: an `RNN` driven by an `LSTMCell`
 * constructed from the same `args`.
 */
var LSTM = /*#__PURE__*/function (_RNN3) {
  _inherits(LSTM, _RNN3);
  var _super9 = _createSuper(LSTM);
  /**
   * @param args Layer/cell configuration. `args.cell` is overwritten with a
   *   new LSTMCell before the base constructor runs.
   */
  function LSTM(args) {
    _classCallCheck(this, LSTM);
    if (args.implementation === 0) {
      console.warn('`implementation=0` has been deprecated, and now defaults to ' + '`implementation=1`. Please update your layer call.');
    }
    args.cell = new LSTMCell(args);
    return _super9.call(this, args); // TODO(cais): Add activityRegularizer.
  }
  _createClass(LSTM, [{
    key: "call",
    /**
     * Disposes and clears any dropout masks cached on the cell (so the cell
     * regenerates them for this invocation), then delegates to the base
     * `call` with `mask`, `training` and `initialState` (each null when
     * `kwargs` is null/undefined).
     */
    value: function call(inputs, kwargs) {
      var _this14 = this;
      return tidy(function () {
        if (_this14.cell.dropoutMask != null) {
          dispose(_this14.cell.dropoutMask);
          _this14.cell.dropoutMask = null;
        }
        if (_this14.cell.recurrentDropoutMask != null) {
          dispose(_this14.cell.recurrentDropoutMask);
          _this14.cell.recurrentDropoutMask = null;
        }
        var mask = kwargs == null ? null : kwargs['mask'];
        var training = kwargs == null ? null : kwargs['training'];
        var initialState = kwargs == null ? null : kwargs['initialState'];
        return _get(_getPrototypeOf(LSTM.prototype), "call", _this14).call(_this14, inputs, {
          mask: mask,
          training: training,
          initialState: initialState
        });
      });
    }
  }], [{
    key: "fromConfig",
    /**
     * Rebuilds a layer from a config. The deprecated `implementation: 0` is
     * rewritten (in place) to 1 before construction.
     * @nocollapse
     */
    value: function fromConfig(cls, config) {
      // The key must match the one getConfig() emits ('implementation').
      if (config['implementation'] === 0) {
        config['implementation'] = 1;
      }
      return new cls(config);
    }
  }]);
  return LSTM;
}(RNN);
/** @nocollapse */
LSTM.className = 'LSTM';
registerClass(LSTM);
/**
 * Wrapper allowing a stack of RNN cells to behave as a single cell.
 *
 * The stack's state is one flat list in REVERSE cell order (the last cell's
 * states come first); e.g. a 2-layer LSTM stack has states [h2, c2, h1, c1].
 * This keeps `stateSize[0]` equal to the output size of the last cell.
 */
var StackedRNNCells = /*#__PURE__*/function (_RNNCell4) {
  _inherits(StackedRNNCells, _RNNCell4);
  var _super10 = _createSuper(StackedRNNCells);

  /** @param args `args.cells`: the cells, from input side to output side. */
  function StackedRNNCells(args) {
    var _this15;
    _classCallCheck(this, StackedRNNCells);
    _this15 = _super10.call(this, args);
    _this15.cells = args.cells;
    return _this15;
  }

  // Number of state tensors carried by `cell` (one per entry of an array
  // stateSize, otherwise one).
  function numStates(cell) {
    return Array.isArray(cell.stateSize) ? cell.stateSize.length : 1;
  }

  // Concatenates the arrays `pick(cell)` over `cells`, preserving order.
  function collectFromCells(cells, pick) {
    var out = [];
    for (var k = 0; k < cells.length; ++k) {
      out.push.apply(out, pick(cells[k]));
    }
    return out;
  }

  _createClass(StackedRNNCells, [{
    key: "stateSize",
    /** Flat list of all cells' state sizes, in reverse cell order. */
    get: function get() {
      var stateSize = [];
      var reversed = this.cells.slice().reverse();
      for (var k = 0; k < reversed.length; ++k) {
        var size = reversed[k].stateSize;
        if (Array.isArray(size)) {
          stateSize.push.apply(stateSize, size);
        } else {
          stateSize.push(size);
        }
      }
      return stateSize;
    }
  }, {
    key: "call",
    /**
     * Runs the cells in order, feeding each cell's output into the next.
     *
     * @param inputs `[x].concat(states)`, with states flat and in reverse cell order.
     * @param kwargs Forwarded unchanged to every cell.
     * @returns `[output].concat(newStates)`, newStates in reverse cell order.
     */
    value: function call(inputs, kwargs) {
      var _this16 = this;
      return tidy(function () {
        var cells = _this16.cells;
        // Split the flat state list back into per-cell groups.
        var flatStates = inputs.slice(1);
        var nestedStates = [];
        var reversedCells = cells.slice().reverse();
        for (var k = 0; k < reversedCells.length; ++k) {
          nestedStates.push(flatStates.splice(0, numStates(reversedCells[k])));
        }
        nestedStates.reverse();
        // Call the cells in order and store the returned states.
        // TODO(cais): Take care of constants.
        var newNestedStates = [];
        var callInputs;
        for (var i = 0; i < cells.length; ++i) {
          var stepInput = i === 0 ? inputs[0] : callInputs[0];
          callInputs = cells[i].call([stepInput].concat(nestedStates[i]), kwargs);
          newNestedStates.push(callInputs.slice(1));
        }
        // Re-flatten the new states in reverse cell order.
        var newStates = [];
        newNestedStates.slice().reverse().forEach(function (cellStates) {
          newStates.push.apply(newStates, cellStates);
        });
        return [callInputs[0]].concat(newStates);
      });
    }
  }, {
    key: "build",
    /**
     * Builds every cell in order. The first cell sees the (first) input shape;
     * each later cell sees [shape[0], outputDim of the previous cell], where
     * outputDim is the first entry of an array stateSize.
     */
    value: function build(inputShape) {
      // TODO(cais): Take care of input constants.
      var shape = isArrayOfShapes(inputShape) ? inputShape[0] : inputShape;
      this.cells.forEach(function (cell, i) {
        nameScope("RNNCell_".concat(i), function () {
          cell.build(shape);
          var outputDim = Array.isArray(cell.stateSize) ? cell.stateSize[0] : cell.stateSize;
          shape = [shape[0], outputDim];
        });
      });
      this.built = true;
    }
  }, {
    key: "getConfig",
    /** Base config plus `cells`: one {className, config} entry per cell. */
    value: function getConfig() {
      var baseConfig = _get(_getPrototypeOf(StackedRNNCells.prototype), "getConfig", this).call(this);
      var cellConfigs = this.cells.map(function (cell) {
        return {
          'className': cell.getClassName(),
          'config': cell.getConfig()
        };
      });
      return Object.assign(Object.assign({}, baseConfig), {
        'cells': cellConfigs
      });
    }
  }, {
    key: "trainableWeights",
    /** All cells' trainable weights, or [] when this stack is not trainable. */
    get: function get() {
      if (!this.trainable) {
        return [];
      }
      return collectFromCells(this.cells, function (cell) {
        return cell.trainableWeights;
      });
    }
  }, {
    key: "nonTrainableWeights",
    /**
     * All cells' non-trainable weights. When this stack is not trainable, the
     * cells' trainable weights are prepended to that list.
     */
    get: function get() {
      var weights = collectFromCells(this.cells, function (cell) {
        return cell.nonTrainableWeights;
      });
      if (!this.trainable) {
        var trainableWeights = collectFromCells(this.cells, function (cell) {
          return cell.trainableWeights;
        });
        return trainableWeights.concat(weights);
      }
      return weights;
    }
  }, {
    key: "getWeights",
    /**
     * Retrieve the weights of all cells.
     *
     * @returns A flat `Array` of `tf.Tensor`s, cell by cell.
     */
    value: function getWeights() {
      return batchGetValue(collectFromCells(this.cells, function (cell) {
        return cell.weights;
      }));
    }
  }, {
    key: "setWeights",
    /**
     * Set the weights of all cells.
     *
     * @param weights A flat `Array` of `tf.Tensor`s with shapes and types
     *   matching the output of `getWeights()`: the first cell's weights come
     *   first, then the next cell's, and so on. The array is not modified.
     */
    value: function setWeights(weights) {
      var tuples = [];
      var offset = 0;
      this.cells.forEach(function (cell) {
        var cellWeights = cell.weights;
        for (var i = 0; i < cellWeights.length; ++i) {
          tuples.push([cellWeights[i], weights[offset + i]]);
        }
        // Advance past this cell's share of the flat list.
        offset += cellWeights.length;
      });
      batchSetValue(tuples);
    }
  }], [{
    key: "fromConfig",
    /**
     * Rebuilds the stack by deserializing every entry of `config.cells`.
     * @nocollapse
     */
    value: function fromConfig(cls, config) {
      var customObjects = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
      var cells = config['cells'].map(function (cellConfig) {
        return deserialize(cellConfig, customObjects);
      });
      return new cls({
        cells: cells
      });
    }
  }]);
  return StackedRNNCells;
}(RNNCell);
/** @nocollapse */
StackedRNNCells.className = 'StackedRNNCells';
registerClass(StackedRNNCells);
74120 function generateDropoutMask(args) {
74121 var ones = args.ones,
74122 rate = args.rate,
74123 _args$training = args.training,
74124 training = _args$training === void 0 ? false : _args$training,
74125 _args$count = args.count,
74126 count = _args$count === void 0 ? 1 : _args$count,
74127 dropoutFunc = args.dropoutFunc;
74128 var droppedInputs = function droppedInputs() {
74129 return dropoutFunc != null ? dropoutFunc(ones(), rate) : dropout$1(ones(), rate);
74130 };
74131 var createMask = function createMask() {
74132 return inTrainPhase(droppedInputs, ones, training);
74133 };
74134 // just in case count is provided with null or undefined
74135 if (!count || count <= 1) {
74136 return keep(createMask().clone());
74137 }
74138 var masks = Array(count).fill(undefined).map(createMask);
74139 return masks.map(function (m) {
74140 return keep(m.clone());
74141 });
74142 }
74143
74144 /**
74145 * @license
74146 * Copyright 2020 Google LLC
74147 *
74148 * Use of this source code is governed by an MIT-style
74149 * license that can be found in the LICENSE file or at
74150 * https://opensource.org/licenses/MIT.
74151 * =============================================================================
74152 */
74153 var __rest = undefined && undefined.__rest || function (s, e) {
74154 var t = {};
74155 for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0) t[p] = s[p];
74156 if (s != null && typeof Object.getOwnPropertySymbols === "function") for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
74157 if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i])) t[p[i]] = s[p[i]];
74158 }
74159 return t;
74160 };
74161 var ConvRNN2DCell = /*#__PURE__*/function (_RNNCell) {
74162 _inherits(ConvRNN2DCell, _RNNCell);
74163 var _super = _createSuper(ConvRNN2DCell);
74164 function ConvRNN2DCell() {
74165 _classCallCheck(this, ConvRNN2DCell);
74166 return _super.apply(this, arguments);
74167 }
74168 return _createClass(ConvRNN2DCell);
74169 }(RNNCell);
74170 /**
74171 * Base class for convolutional-recurrent layers.
74172 */
74173 var ConvRNN2D = /*#__PURE__*/function (_RNN) {
74174 _inherits(ConvRNN2D, _RNN);
74175 var _super2 = _createSuper(ConvRNN2D);
74176 function ConvRNN2D(args) {
74177 var _this;
74178 _classCallCheck(this, ConvRNN2D);
74179 if (args.unroll) {
74180 throw new NotImplementedError('Unrolling is not possible with convolutional RNNs.');
74181 }
74182 if (Array.isArray(args.cell)) {
74183 throw new NotImplementedError('It is not possible at the moment to stack convolutional cells.');
74184 }
74185 _this = _super2.call(this, args);
74186 _this.inputSpec = [new InputSpec({
74187 ndim: 5
74188 })];
74189 return _this;
74190 }
74191 _createClass(ConvRNN2D, [{
74192 key: "call",
74193 value: function call(inputs, kwargs) {
74194 var _this2 = this;
74195 return tidy(function () {
74196 if (_this2.cell.dropoutMask != null) {
74197 dispose(_this2.cell.dropoutMask);
74198 _this2.cell.dropoutMask = null;
74199 }
74200 if (_this2.cell.recurrentDropoutMask != null) {
74201 dispose(_this2.cell.recurrentDropoutMask);
74202 _this2.cell.recurrentDropoutMask = null;
74203 }
74204 if (kwargs && kwargs['constants']) {
74205 throw new ValueError('ConvRNN2D cell does not support constants');
74206 }
74207 var mask = kwargs == null ? null : kwargs['mask'];
74208 var training = kwargs == null ? null : kwargs['training'];
74209 var initialState = kwargs == null ? null : kwargs['initialState'];
74210 return _get(_getPrototypeOf(ConvRNN2D.prototype), "call", _this2).call(_this2, inputs, {
74211 mask: mask,
74212 training: training,
74213 initialState: initialState
74214 });
74215 });
74216 }
74217 }, {
74218 key: "computeOutputShape",
74219 value: function computeOutputShape(inputShape) {
74220 var outShape = this.computeSingleOutputShape(inputShape);
74221 if (!this.returnSequences) {
74222 outShape = [outShape[0]].concat(_toConsumableArray(outShape.slice(2)));
74223 }
74224 if (this.returnState) {
74225 outShape = [outShape].concat(_toConsumableArray(Array(2).fill([inputShape[0]].concat(_toConsumableArray(outShape.slice(-3))))));
74226 }
74227 return outShape;
74228 }
74229 }, {
74230 key: "getInitialState",
74231 value: function getInitialState(inputs) {
74232 var _this3 = this;
74233 return tidy(function () {
74234 var stateSize = _this3.cell.stateSize;
74235 var inputShape = inputs.shape;
74236 var outputShape = _this3.computeSingleOutputShape(inputShape);
74237 var stateShape = [outputShape[0]].concat(_toConsumableArray(outputShape.slice(2)));
74238 var initialState = zeros$2(stateShape);
74239 if (Array.isArray(stateSize)) {
74240 return Array(stateSize.length).fill(initialState);
74241 }
74242 return [initialState];
74243 });
74244 }
74245 }, {
74246 key: "resetStates",
74247 value: function resetStates(states) {
74248 var _this4 = this;
74249 var training = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
74250 tidy(function () {
74251 if (!_this4.stateful) {
74252 throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
74253 }
74254 var inputShape = _this4.inputSpec[0].shape;
74255 var outputShape = _this4.computeSingleOutputShape(inputShape);
74256 var stateShape = [outputShape[0]].concat(_toConsumableArray(outputShape.slice(2)));
74257 var batchSize = inputShape[0];
74258 if (batchSize == null) {
74259 throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' + 'the batch size of your input tensors: \n' + '- If using a Sequential model, specify the batch size by ' + 'passing a `batchInputShape` option to your first layer.\n' + '- If using the functional API, specify the batch size by ' + 'passing a `batchShape` option to your Input layer.');
74260 }
74261 // Initialize state if null.
74262 if (_this4.getStates() == null) {
74263 if (Array.isArray(_this4.cell.stateSize)) {
74264 _this4.states_ = _this4.cell.stateSize.map(function () {
74265 return zeros$2(stateShape);
74266 });
74267 } else {
74268 _this4.states_ = [zeros$2(stateShape)];
74269 }
74270 } else if (states == null) {
74271 // Dispose old state tensors.
74272 dispose(_this4.states_);
74273 // For stateful RNNs, fully dispose kept old states.
74274 if (_this4.keptStates != null) {
74275 dispose(_this4.keptStates);
74276 _this4.keptStates = [];
74277 }
74278 if (Array.isArray(_this4.cell.stateSize)) {
74279 _this4.states_ = _this4.cell.stateSize.map(function () {
74280 return zeros$2(stateShape);
74281 });
74282 } else {
74283 _this4.states_[0] = zeros$2(stateShape);
74284 }
74285 } else {
74286 if (!Array.isArray(states)) {
74287 states = [states];
74288 }
74289 if (states.length !== _this4.states_.length) {
74290 throw new ValueError("Layer ".concat(_this4.name, " expects ").concat(_this4.states_.length, " state(s), ") + "but it received ".concat(states.length, " state value(s). Input ") + "received: ".concat(states));
74291 }
74292 if (training) {
74293 // Store old state tensors for complete disposal later, i.e., during
74294 // the next no-arg call to this method. We do not dispose the old
74295 // states immediately because that BPTT (among other things) require
74296 // them.
74297 _this4.keptStates.push(_this4.states_.slice());
74298 } else {
74299 dispose(_this4.states_);
74300 }
74301 for (var index = 0; index < _this4.states_.length; ++index) {
74302 var value = states[index];
74303 var expectedShape = stateShape;
74304 if (!arraysEqual(value.shape, expectedShape)) {
74305 throw new ValueError("State ".concat(index, " is incompatible with layer ").concat(_this4.name, ": ") + "expected shape=".concat(expectedShape, ", received shape=").concat(value.shape));
74306 }
74307 _this4.states_[index] = value;
74308 }
74309 }
74310 _this4.states_ = _this4.states_.map(function (state) {
74311 return keep(state.clone());
74312 });
74313 });
74314 }
74315 }, {
74316 key: "computeSingleOutputShape",
74317 value: function computeSingleOutputShape(inputShape) {
74318 var _this$cell = this.cell,
74319 dataFormat = _this$cell.dataFormat,
74320 filters = _this$cell.filters,
74321 kernelSize = _this$cell.kernelSize,
74322 padding = _this$cell.padding,
74323 strides = _this$cell.strides,
74324 dilationRate = _this$cell.dilationRate;
74325 var isChannelsFirst = dataFormat === 'channelsFirst';
74326 var h = inputShape[isChannelsFirst ? 3 : 2];
74327 var w = inputShape[isChannelsFirst ? 4 : 3];
74328 var hOut = convOutputLength(h, kernelSize[0], padding, strides[0], dilationRate[0]);
74329 var wOut = convOutputLength(w, kernelSize[1], padding, strides[1], dilationRate[1]);
74330 var outShape = [].concat(_toConsumableArray(inputShape.slice(0, 2)), _toConsumableArray(isChannelsFirst ? [filters, hOut, wOut] : [hOut, wOut, filters]));
74331 return outShape;
74332 }
74333 }]);
74334 return ConvRNN2D;
74335 }(RNN);
74336 /** @nocollapse */
74337 ConvRNN2D.className = 'ConvRNN2D';
74338 var ConvLSTM2DCell = /*#__PURE__*/function (_LSTMCell) {
74339 _inherits(ConvLSTM2DCell, _LSTMCell);
74340 var _super3 = _createSuper(ConvLSTM2DCell);
74341 function ConvLSTM2DCell(args) {
74342 var _this5;
74343 _classCallCheck(this, ConvLSTM2DCell);
74344 var filters = args.filters,
74345 kernelSize = args.kernelSize,
74346 strides = args.strides,
74347 padding = args.padding,
74348 dataFormat = args.dataFormat,
74349 dilationRate = args.dilationRate;
74350 _this5 = _super3.call(this, Object.assign(Object.assign({}, args), {
74351 units: filters
74352 }));
74353 _this5.filters = filters;
74354 assertPositiveInteger(_this5.filters, 'filters');
74355 _this5.kernelSize = normalizeArray(kernelSize, 2, 'kernelSize');
74356 _this5.kernelSize.forEach(function (size) {
74357 return assertPositiveInteger(size, 'kernelSize');
74358 });
74359 _this5.strides = normalizeArray(strides || 1, 2, 'strides');
74360 _this5.strides.forEach(function (stride) {
74361 return assertPositiveInteger(stride, 'strides');
74362 });
74363 _this5.padding = padding || 'valid';
74364 checkPaddingMode(_this5.padding);
74365 _this5.dataFormat = dataFormat || 'channelsLast';
74366 checkDataFormat(_this5.dataFormat);
74367 _this5.dilationRate = normalizeArray(dilationRate || 1, 2, 'dilationRate');
74368 _this5.dilationRate.forEach(function (rate) {
74369 return assertPositiveInteger(rate, 'dilationRate');
74370 });
74371 return _this5;
74372 }
74373 _createClass(ConvLSTM2DCell, [{
74374 key: "build",
74375 value: function build(inputShape) {
74376 var _a;
74377 inputShape = getExactlyOneShape(inputShape);
74378 var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
74379 if (inputShape[channelAxis] == null) {
74380 throw new ValueError("The channel dimension of the input should be defined. " + "Found ".concat(inputShape[channelAxis]));
74381 }
74382 var inputDim = inputShape[channelAxis];
74383 var numOfKernels = 4;
74384 var kernelShape = this.kernelSize.concat([inputDim, this.filters * numOfKernels]);
74385 this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
74386 var recurrentKernelShape = this.kernelSize.concat([this.filters, this.filters * numOfKernels]);
74387 this.recurrentKernel = this.addWeight('recurrent_kernel', recurrentKernelShape, null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
74388 if (this.useBias) {
74389 var biasInitializer;
74390 if (this.unitForgetBias) {
74391 var init = this.biasInitializer;
74392 var filters = this.filters;
74393 biasInitializer = new (_a = /*#__PURE__*/function (_Initializer) {
74394 _inherits(CustomInit, _Initializer);
74395 var _super4 = _createSuper(CustomInit);
74396 function CustomInit() {
74397 _classCallCheck(this, CustomInit);
74398 return _super4.apply(this, arguments);
74399 }
74400 _createClass(CustomInit, [{
74401 key: "apply",
74402 value: function apply(shape, dtype) {
74403 var biasI = init.apply([filters]);
74404 var biasF = ones$1([filters]);
74405 var biasCAndO = init.apply([filters * 2]);
74406 return concatenate$2([biasI, biasF, biasCAndO]);
74407 }
74408 }]);
74409 return CustomInit;
74410 }(Initializer), /** @nocollapse */
74411 _a.className = 'CustomInit', _a)();
74412 } else {
74413 biasInitializer = this.biasInitializer;
74414 }
74415 this.bias = this.addWeight('bias', [this.filters * numOfKernels], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
74416 }
74417 this.built = true;
74418 }
74419 }, {
74420 key: "call",
74421 value: function call(inputs, kwargs) {
74422 var _this6 = this;
74423 return tidy(function () {
74424 if (inputs.length !== 3) {
74425 throw new ValueError("ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got " + "".concat(inputs.length, "."));
74426 }
74427 var training = kwargs['training'] || false;
74428 var x = inputs[0]; // Current input
74429 var hTMinus1 = inputs[1]; // Previous memory state.
74430 var cTMinus1 = inputs[2]; // Previous carry state.
74431 var numOfKernels = 4;
74432 if (0 < _this6.dropout && _this6.dropout < 1 && _this6.dropoutMask == null) {
74433 _this6.dropoutMask = generateDropoutMask({
74434 ones: function ones() {
74435 return onesLike$3(x);
74436 },
74437 rate: _this6.dropout,
74438 training: training,
74439 count: numOfKernels,
74440 dropoutFunc: _this6.dropoutFunc
74441 });
74442 }
74443 var dropoutMask = _this6.dropoutMask;
74444 var applyDropout = function applyDropout(x, mask, index) {
74445 if (!mask || !mask[index]) {
74446 return x;
74447 }
74448 return mul(mask[index], x);
74449 };
74450 var xI = applyDropout(x, dropoutMask, 0);
74451 var xF = applyDropout(x, dropoutMask, 1);
74452 var xC = applyDropout(x, dropoutMask, 2);
74453 var xO = applyDropout(x, dropoutMask, 3);
74454 if (0 < _this6.recurrentDropout && _this6.recurrentDropout < 1 && _this6.recurrentDropoutMask == null) {
74455 _this6.recurrentDropoutMask = generateDropoutMask({
74456 ones: function ones() {
74457 return onesLike$3(hTMinus1);
74458 },
74459 rate: _this6.recurrentDropout,
74460 training: training,
74461 count: numOfKernels,
74462 dropoutFunc: _this6.dropoutFunc
74463 });
74464 }
74465 var recDropoutMask = _this6.recurrentDropoutMask;
74466 var hI = applyDropout(hTMinus1, recDropoutMask, 0);
74467 var hF = applyDropout(hTMinus1, recDropoutMask, 1);
74468 var hC = applyDropout(hTMinus1, recDropoutMask, 2);
74469 var hO = applyDropout(hTMinus1, recDropoutMask, 3);
74470 var kernelChannelAxis = 3;
74471 var _tfc$split = split$3(_this6.kernel.read(), numOfKernels, kernelChannelAxis),
74472 _tfc$split2 = _slicedToArray(_tfc$split, 4),
74473 kernelI = _tfc$split2[0],
74474 kernelF = _tfc$split2[1],
74475 kernelC = _tfc$split2[2],
74476 kernelO = _tfc$split2[3];
74477 var _ref = _this6.useBias ? split$3(_this6.bias.read(), numOfKernels) : [null, null, null, null],
74478 _ref2 = _slicedToArray(_ref, 4),
74479 biasI = _ref2[0],
74480 biasF = _ref2[1],
74481 biasC = _ref2[2],
74482 biasO = _ref2[3];
74483 xI = _this6.inputConv(xI, kernelI, biasI, _this6.padding);
74484 xF = _this6.inputConv(xF, kernelF, biasF, _this6.padding);
74485 xC = _this6.inputConv(xC, kernelC, biasC, _this6.padding);
74486 xO = _this6.inputConv(xO, kernelO, biasO, _this6.padding);
74487 var _tfc$split3 = split$3(_this6.recurrentKernel.read(), numOfKernels, kernelChannelAxis),
74488 _tfc$split4 = _slicedToArray(_tfc$split3, 4),
74489 recKernelI = _tfc$split4[0],
74490 recKernelF = _tfc$split4[1],
74491 recKernelC = _tfc$split4[2],
74492 recKernelO = _tfc$split4[3];
74493 hI = _this6.recurrentConv(hI, recKernelI);
74494 hF = _this6.recurrentConv(hF, recKernelF);
74495 hC = _this6.recurrentConv(hC, recKernelC);
74496 hO = _this6.recurrentConv(hO, recKernelO);
74497 var i = _this6.recurrentActivation.apply(add$3(xI, hI));
74498 var f = _this6.recurrentActivation.apply(add$3(xF, hF));
74499 var c = add$3(mul(f, cTMinus1), mul(i, _this6.activation.apply(add$3(xC, hC))));
74500 var h = mul(_this6.recurrentActivation.apply(add$3(xO, hO)), _this6.activation.apply(c));
74501 return [h, h, c];
74502 });
74503 }
74504 }, {
74505 key: "getConfig",
74506 value: function getConfig() {
74507 var _a = _get(_getPrototypeOf(ConvLSTM2DCell.prototype), "getConfig", this).call(this),
74508 _ = _a['units'],
74509 baseConfig = __rest(_a, ['units']);
74510 var config = {
74511 filters: this.filters,
74512 kernelSize: this.kernelSize,
74513 padding: this.padding,
74514 dataFormat: this.dataFormat,
74515 dilationRate: this.dilationRate,
74516 strides: this.strides
74517 };
74518 return Object.assign(Object.assign({}, baseConfig), config);
74519 }
74520 }, {
74521 key: "inputConv",
74522 value: function inputConv(x, w, b, padding) {
74523 var out = conv2d$4(x, w, this.strides, padding || 'valid', this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC', this.dilationRate);
74524 if (b) {
74525 return biasAdd(out, b, this.dataFormat);
74526 }
74527 return out;
74528 }
74529 }, {
74530 key: "recurrentConv",
74531 value: function recurrentConv(x, w) {
74532 var strides = 1;
74533 return conv2d$4(x, w, strides, 'same', this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC');
74534 }
74535 }]);
74536 return ConvLSTM2DCell;
74537 }(LSTMCell);
74538 /** @nocollapse */
74539 ConvLSTM2DCell.className = 'ConvLSTM2DCell';
74540 registerClass(ConvLSTM2DCell);
74541 var ConvLSTM2D = /*#__PURE__*/function (_ConvRNN2D) {
74542 _inherits(ConvLSTM2D, _ConvRNN2D);
74543 var _super5 = _createSuper(ConvLSTM2D);
74544 function ConvLSTM2D(args) {
74545 _classCallCheck(this, ConvLSTM2D);
74546 var cell = new ConvLSTM2DCell(args);
74547 return _super5.call(this, Object.assign(Object.assign({}, args), {
74548 cell: cell
74549 }));
74550 }
74551 /** @nocollapse */
74552 _createClass(ConvLSTM2D, null, [{
74553 key: "fromConfig",
74554 value: function fromConfig(cls, config) {
74555 return new cls(config);
74556 }
74557 }]);
74558 return ConvLSTM2D;
74559 }(ConvRNN2D);
74560 /** @nocollapse */
74561 ConvLSTM2D.className = 'ConvLSTM2D';
74562 registerClass(ConvLSTM2D);
74563
74564 var Dropout = /*#__PURE__*/function (_Layer) {
74565 _inherits(Dropout, _Layer);
74566 var _super = _createSuper(Dropout);
74567 function Dropout(args) {
74568 var _this;
74569 _classCallCheck(this, Dropout);
74570 _this = _super.call(this, args);
74571 _this.rate = Math.max(Math.min(args.rate, 1), 0);
74572 // So that the scalar doesn't get tidied up between executions.
74573 _this.noiseShape = args.noiseShape;
74574 _this.seed = args.seed;
74575 _this.supportsMasking = true;
74576 return _this;
74577 }
74578 _createClass(Dropout, [{
74579 key: "getNoiseShape",
74580 value: function getNoiseShape(input) {
74581 if (this.noiseShape == null) {
74582 return this.noiseShape;
74583 }
74584 var inputShape = input.shape;
74585 var noiseShape = [];
74586 for (var i = 0; i < this.noiseShape.length; ++i) {
74587 noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]);
74588 }
74589 return noiseShape;
74590 }
74591 }, {
74592 key: "call",
74593 value: function call(inputs, kwargs) {
74594 var _this2 = this;
74595 return tidy(function () {
74596 _this2.invokeCallHook(inputs, kwargs);
74597 var input = getExactlyOneTensor(inputs);
74598 if (0 < _this2.rate && _this2.rate < 1) {
74599 var training = kwargs['training'] == null ? false : kwargs['training'];
74600 var noiseShape = _this2.getNoiseShape(input);
74601 var output = inTrainPhase(function () {
74602 return dropout$1(input, _this2.rate, noiseShape, _this2.seed);
74603 }, function () {
74604 return input;
74605 }, training);
74606 return output;
74607 }
74608 return inputs;
74609 });
74610 }
74611 }, {
74612 key: "getConfig",
74613 value: function getConfig() {
74614 var config = {
74615 rate: this.rate,
74616 noiseShape: this.noiseShape,
74617 seed: this.seed
74618 };
74619 var baseConfig = _get(_getPrototypeOf(Dropout.prototype), "getConfig", this).call(this);
74620 Object.assign(config, baseConfig);
74621 return config;
74622 }
74623 }, {
74624 key: "dispose",
74625 value: function dispose() {
74626 return _get(_getPrototypeOf(Dropout.prototype), "dispose", this).call(this);
74627 }
74628 }]);
74629 return Dropout;
74630 }(Layer);
74631 /** @nocollapse */
74632 Dropout.className = 'Dropout';
74633 registerClass(Dropout);
74634 var SpatialDropout1D = /*#__PURE__*/function (_Dropout) {
74635 _inherits(SpatialDropout1D, _Dropout);
74636 var _super2 = _createSuper(SpatialDropout1D);
74637 function SpatialDropout1D(args) {
74638 var _this3;
74639 _classCallCheck(this, SpatialDropout1D);
74640 _this3 = _super2.call(this, args);
74641 _this3.inputSpec = [{
74642 ndim: 3
74643 }];
74644 return _this3;
74645 }
74646 _createClass(SpatialDropout1D, [{
74647 key: "getNoiseShape",
74648 value: function getNoiseShape(input) {
74649 var inputShape = input.shape;
74650 return [inputShape[0], 1, inputShape[2]];
74651 }
74652 }]);
74653 return SpatialDropout1D;
74654 }(Dropout);
74655 /** @nocollapse */
74656 SpatialDropout1D.className = 'SpatialDropout1D';
74657 registerClass(SpatialDropout1D);
74658 var Dense = /*#__PURE__*/function (_Layer2) {
74659 _inherits(Dense, _Layer2);
74660 var _super3 = _createSuper(Dense);
74661 function Dense(args) {
74662 var _this4;
74663 _classCallCheck(this, Dense);
74664 _this4 = _super3.call(this, args);
74665 // Default activation: Linear (none).
74666 _this4.activation = null;
74667 _this4.useBias = true;
74668 _this4.kernel = null;
74669 _this4.bias = null;
74670 _this4.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
74671 _this4.DEFAULT_BIAS_INITIALIZER = 'zeros';
74672 if (args.batchInputShape == null && args.inputShape == null && args.inputDim != null) {
74673 // This logic is copied from Layer's constructor, since we can't
74674 // do exactly what the Python constructor does for Dense().
74675 var batchSize = null;
74676 if (args.batchSize != null) {
74677 batchSize = args.batchSize;
74678 }
74679 _this4.batchInputShape = [batchSize, args.inputDim];
74680 }
74681 _this4.units = args.units;
74682 assertPositiveInteger(_this4.units, 'units');
74683 _this4.activation = getActivation(args.activation);
74684 if (args.useBias != null) {
74685 _this4.useBias = args.useBias;
74686 }
74687 _this4.kernelInitializer = getInitializer(args.kernelInitializer || _this4.DEFAULT_KERNEL_INITIALIZER);
74688 _this4.biasInitializer = getInitializer(args.biasInitializer || _this4.DEFAULT_BIAS_INITIALIZER);
74689 _this4.kernelConstraint = getConstraint(args.kernelConstraint);
74690 _this4.biasConstraint = getConstraint(args.biasConstraint);
74691 _this4.kernelRegularizer = getRegularizer(args.kernelRegularizer);
74692 _this4.biasRegularizer = getRegularizer(args.biasRegularizer);
74693 _this4.activityRegularizer = getRegularizer(args.activityRegularizer);
74694 _this4.supportsMasking = true;
74695 _this4.inputSpec = [{
74696 minNDim: 2
74697 }];
74698 return _this4;
74699 }
74700 _createClass(Dense, [{
74701 key: "build",
74702 value: function build(inputShape) {
74703 inputShape = getExactlyOneShape(inputShape);
74704 var inputLastDim = inputShape[inputShape.length - 1];
74705 if (this.kernel == null) {
74706 this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
74707 if (this.useBias) {
74708 this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
74709 }
74710 }
74711 this.inputSpec = [{
74712 minNDim: 2,
74713 axes: _defineProperty({}, -1, inputLastDim)
74714 }];
74715 this.built = true;
74716 }
74717 }, {
74718 key: "computeOutputShape",
74719 value: function computeOutputShape(inputShape) {
74720 inputShape = getExactlyOneShape(inputShape);
74721 var outputShape = inputShape.slice();
74722 outputShape[outputShape.length - 1] = this.units;
74723 return outputShape;
74724 }
74725 }, {
74726 key: "call",
74727 value: function call(inputs, kwargs) {
74728 var _this5 = this;
74729 return tidy(function () {
74730 _this5.invokeCallHook(inputs, kwargs);
74731 // Dense layer accepts only a single input.
74732 var input = getExactlyOneTensor(inputs);
74733 var fusedActivationName = mapActivationToFusedKernel(_this5.activation.getClassName());
74734 var output;
74735 if (fusedActivationName != null) {
74736 output = dot$1(input, _this5.kernel.read(), fusedActivationName, _this5.bias ? _this5.bias.read() : null);
74737 } else {
74738 output = dot$1(input, _this5.kernel.read());
74739 if (_this5.bias != null) {
74740 output = biasAdd(output, _this5.bias.read());
74741 }
74742 if (_this5.activation != null) {
74743 output = _this5.activation.apply(output);
74744 }
74745 }
74746 return output;
74747 });
74748 }
74749 }, {
74750 key: "getConfig",
74751 value: function getConfig() {
74752 var config = {
74753 units: this.units,
74754 activation: serializeActivation(this.activation),
74755 useBias: this.useBias,
74756 kernelInitializer: serializeInitializer(this.kernelInitializer),
74757 biasInitializer: serializeInitializer(this.biasInitializer),
74758 kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
74759 biasRegularizer: serializeRegularizer(this.biasRegularizer),
74760 activityRegularizer: serializeRegularizer(this.activityRegularizer),
74761 kernelConstraint: serializeConstraint(this.kernelConstraint),
74762 biasConstraint: serializeConstraint(this.biasConstraint)
74763 };
74764 var baseConfig = _get(_getPrototypeOf(Dense.prototype), "getConfig", this).call(this);
74765 Object.assign(config, baseConfig);
74766 return config;
74767 }
74768 }]);
74769 return Dense;
74770 }(Layer);
74771 /** @nocollapse */
74772 Dense.className = 'Dense';
74773 registerClass(Dense);
74774 var Flatten = /*#__PURE__*/function (_Layer3) {
74775 _inherits(Flatten, _Layer3);
74776 var _super4 = _createSuper(Flatten);
74777 function Flatten(args) {
74778 var _this6;
74779 _classCallCheck(this, Flatten);
74780 args = args || {};
74781 _this6 = _super4.call(this, args);
74782 _this6.inputSpec = [{
74783 minNDim: 3
74784 }];
74785 _this6.dataFormat = args.dataFormat;
74786 return _this6;
74787 }
74788 _createClass(Flatten, [{
74789 key: "computeOutputShape",
74790 value: function computeOutputShape(inputShape) {
74791 inputShape = getExactlyOneShape(inputShape);
74792 var _iterator = _createForOfIteratorHelper(inputShape.slice(1)),
74793 _step;
74794 try {
74795 for (_iterator.s(); !(_step = _iterator.n()).done;) {
74796 var dim = _step.value;
74797 if (dim == null) {
74798 throw new ValueError("The shape of the input to \"Flatten\" is not fully defined " + "(got ".concat(inputShape.slice(1), "). Make sure to pass a complete ") + "\"input_shape\" or \"batch_input_shape\" argument to the first " + "layer in your model.");
74799 }
74800 }
74801 } catch (err) {
74802 _iterator.e(err);
74803 } finally {
74804 _iterator.f();
74805 }
74806 return [inputShape[0], arrayProd(inputShape, 1)];
74807 }
74808 }, {
74809 key: "call",
74810 value: function call(inputs, kwargs) {
74811 var _this7 = this;
74812 return tidy(function () {
74813 _this7.invokeCallHook(inputs, kwargs);
74814 var input = getExactlyOneTensor(inputs);
74815 if (_this7.dataFormat === 'channelsFirst' && input.rank > 1) {
74816 var permutation = [0];
74817 for (var i = 2; i < input.rank; ++i) {
74818 permutation.push(i);
74819 }
74820 permutation.push(1);
74821 input = transpose$2(input, permutation);
74822 }
74823 return batchFlatten(input);
74824 });
74825 }
74826 }, {
74827 key: "getConfig",
74828 value: function getConfig() {
74829 var config = {};
74830 if (this.dataFormat != null) {
74831 config['dataFormat'] = this.dataFormat;
74832 }
74833 var baseConfig = _get(_getPrototypeOf(Flatten.prototype), "getConfig", this).call(this);
74834 Object.assign(config, baseConfig);
74835 return config;
74836 }
74837 }]);
74838 return Flatten;
74839 }(Layer);
74840 /** @nocollapse */
74841 Flatten.className = 'Flatten';
74842 registerClass(Flatten);
74843 var Activation = /*#__PURE__*/function (_Layer4) {
74844 _inherits(Activation, _Layer4);
74845 var _super5 = _createSuper(Activation);
74846 function Activation(args) {
74847 var _this8;
74848 _classCallCheck(this, Activation);
74849 _this8 = _super5.call(this, args);
74850 _this8.supportsMasking = true;
74851 _this8.activation = getActivation(args.activation);
74852 return _this8;
74853 }
74854 _createClass(Activation, [{
74855 key: "call",
74856 value: function call(inputs, kwargs) {
74857 var _this9 = this;
74858 return tidy(function () {
74859 _this9.invokeCallHook(inputs, kwargs);
74860 var input = getExactlyOneTensor(inputs);
74861 return _this9.activation.apply(input);
74862 });
74863 }
74864 }, {
74865 key: "getConfig",
74866 value: function getConfig() {
74867 var config = {
74868 activation: serializeActivation(this.activation)
74869 };
74870 var baseConfig = _get(_getPrototypeOf(Activation.prototype), "getConfig", this).call(this);
74871 Object.assign(config, baseConfig);
74872 return config;
74873 }
74874 }]);
74875 return Activation;
74876 }(Layer);
74877 /** @nocollapse */
74878 Activation.className = 'Activation';
74879 registerClass(Activation);
74880 var RepeatVector = /*#__PURE__*/function (_Layer5) {
74881 _inherits(RepeatVector, _Layer5);
74882 var _super6 = _createSuper(RepeatVector);
74883 function RepeatVector(args) {
74884 var _this10;
74885 _classCallCheck(this, RepeatVector);
74886 _this10 = _super6.call(this, args);
74887 _this10.n = args.n;
74888 _this10.inputSpec = [{
74889 ndim: 2
74890 }];
74891 return _this10;
74892 }
74893 _createClass(RepeatVector, [{
74894 key: "computeOutputShape",
74895 value: function computeOutputShape(inputShape) {
74896 return [inputShape[0], this.n, inputShape[1]];
74897 }
74898 }, {
74899 key: "call",
74900 value: function call(inputs, kwargs) {
74901 var _this11 = this;
74902 return tidy(function () {
74903 inputs = getExactlyOneTensor(inputs);
74904 return repeat(inputs, _this11.n);
74905 });
74906 }
74907 }, {
74908 key: "getConfig",
74909 value: function getConfig() {
74910 var config = {
74911 n: this.n
74912 };
74913 var baseConfig = _get(_getPrototypeOf(RepeatVector.prototype), "getConfig", this).call(this);
74914 Object.assign(config, baseConfig);
74915 return config;
74916 }
74917 }]);
74918 return RepeatVector;
74919 }(Layer);
74920 /** @nocollapse */
74921 RepeatVector.className = 'RepeatVector';
74922 registerClass(RepeatVector);
var Reshape = /*#__PURE__*/function (_Layer6) {
  _inherits(Reshape, _Layer6);
  var _super7 = _createSuper(Reshape);
  /**
   * Layer that reshapes every example (all dimensions after the batch
   * dimension) to `args.targetShape`.
   *
   * @param args Layer config. Entries of `args.targetShape` that are `null`,
   *   `undefined` or negative are treated as unknown and stored as `null`;
   *   `fixUnknownDimension` later rejects more than one such entry.
   */
  function Reshape(args) {
    var _this12;
    _classCallCheck(this, Reshape);
    _this12 = _super7.call(this, args);
    // Copy the target shape so that normalizing unknown entries below does
    // not write into the caller's config array.
    _this12.targetShape = args.targetShape.slice();
    // Make sure that all unknown dimensions are represented as `null`.
    for (var i = 0; i < _this12.targetShape.length; ++i) {
      if (_this12.isUnknown(_this12.targetShape[i])) {
        _this12.targetShape[i] = null;
      }
    }
    return _this12;
  }
  _createClass(Reshape, [{
    key: "isUnknown",
    value: function isUnknown(dim) {
      return dim < 0 || dim == null;
    }
    /**
     * Finds and replaces a missing dimension in output shape.
     *
     * This is a near direct port of the internal Numpy function
     * `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`.
     *
     * @param inputShape: Original shape of the array being reshaped.
     * @param outputShape: Target shape of the array, with at most a single
     * `null` or negative number, which indicates an underdetermined dimension
     * that should be derived from `inputShape` and the known dimensions of
     * `outputShape`.
     * @returns: The output shape with `null` replaced with its computed value.
     * @throws: ValueError: If `inputShape` and `outputShape` do not match, or
     * if `outputShape` has more than one unknown dimension.
     */
  }, {
    key: "fixUnknownDimension",
    value: function fixUnknownDimension(inputShape, outputShape) {
      var errorMsg = 'Total size of new array must be unchanged.';
      var finalShape = outputShape.slice();
      var known = 1;
      var unknown = null;
      for (var i = 0; i < finalShape.length; ++i) {
        var dim = finalShape[i];
        if (this.isUnknown(dim)) {
          if (unknown === null) {
            unknown = i;
          } else {
            throw new ValueError('Can only specifiy one unknown dimension.');
          }
        } else {
          known *= dim;
        }
      }
      var originalSize = arrayProd(inputShape);
      if (unknown !== null) {
        if (known === 0 || originalSize % known !== 0) {
          throw new ValueError(errorMsg);
        }
        finalShape[unknown] = originalSize / known;
      } else if (originalSize !== known) {
        throw new ValueError(errorMsg);
      }
      return finalShape;
    }
    /**
     * Output shape is the input's batch dimension followed by `targetShape`.
     * The unknown entry of `targetShape` is resolved whenever every
     * non-batch input dimension is known.
     */
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      // Only the non-batch dimensions matter here: an unknown batch size
      // (the normal case for symbolic inputs) must not prevent resolving a
      // `-1` entry in `targetShape`. Keras' Reshape likewise only checks
      // `input_shape[1:]`.
      var anyUnknownDims = false;
      for (var i = 1; i < inputShape.length; ++i) {
        if (this.isUnknown(inputShape[i])) {
          anyUnknownDims = true;
          break;
        }
      }
      if (anyUnknownDims) {
        return inputShape.slice(0, 1).concat(this.targetShape);
      } else {
        return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
      }
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this13 = this;
      return tidy(function () {
        _this13.invokeCallHook(inputs, kwargs);
        var input = getExactlyOneTensor(inputs);
        var inputShape = input.shape;
        // Concrete tensors always have a fully known shape, so the unknown
        // target dimension can be resolved here.
        var outputShape = inputShape.slice(0, 1).concat(_this13.fixUnknownDimension(inputShape.slice(1), _this13.targetShape));
        return reshape$3(input, outputShape);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        targetShape: this.targetShape
      };
      var baseConfig = _get(_getPrototypeOf(Reshape.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Reshape;
}(Layer);
/** @nocollapse */
Reshape.className = 'Reshape';
registerClass(Reshape);
var Permute = /*#__PURE__*/function (_Layer7) {
  _inherits(Permute, _Layer7);
  var _super8 = _createSuper(Permute);
  /**
   * Layer that permutes the non-batch dimensions of its input.
   *
   * @param args Layer config. `args.dims` must be an Array containing each
   *   integer in 1..dims.length exactly once; the batch dimension (0) always
   *   stays first.
   * @throws Error if `dims` is missing, is not an Array, or is not such a
   *   permutation.
   */
  function Permute(args) {
    var _this14;
    _classCallCheck(this, Permute);
    _this14 = _super8.call(this, args);
    if (args.dims == null) {
      throw new Error('Required configuration field `dims` is missing during Permute ' + 'constructor call.');
    }
    if (!Array.isArray(args.dims)) {
      throw new Error('Permute constructor requires `dims` to be an Array, but received ' + "".concat(args.dims, " instead."));
    }
    // Check the validity of the permutation indices.
    var expectedSortedIndices = range$2(1, args.dims.length + 1);
    // Sort numerically: the default `sort()` compares elements as strings,
    // ordering [1, ..., 10] as [1, 10, 2, ...], which rejected valid
    // permutations with 10 or more dimensions.
    var sortedDims = args.dims.slice().sort(function (a, b) {
      return a - b;
    });
    if (!arraysEqual(sortedDims, expectedSortedIndices)) {
      throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) + ' `dims` must contain consecutive integers starting from 1.');
    }
    _this14.dims = args.dims;
    // Batch dimension 0 is never moved.
    _this14.dimsIncludingBatch = [0].concat(_this14.dims);
    _this14.inputSpec = [new InputSpec({
      ndim: _this14.dims.length + 1
    })];
    return _this14;
  }
  _createClass(Permute, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      var outputShape = inputShape.slice();
      // Output dimension i+1 takes the size of input dimension dims[i].
      this.dims.forEach(function (dim, i) {
        outputShape[i + 1] = inputShape[dim];
      });
      return outputShape;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      return transpose$2(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        dims: this.dims
      };
      var baseConfig = _get(_getPrototypeOf(Permute.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Permute;
}(Layer);
/** @nocollapse */
Permute.className = 'Permute';
registerClass(Permute);
var Masking = /*#__PURE__*/function (_Layer8) {
  _inherits(Masking, _Layer8);
  var _super9 = _createSuper(Masking);
  /**
   * Layer that masks out positions whose entries along the last axis all
   * equal `maskValue`.
   *
   * @param args Optional layer config; `args.maskValue` defaults to 0.
   */
  function Masking(args) {
    var _this15;
    _classCallCheck(this, Masking);
    _this15 = _super9.call(this, args == null ? {} : args);
    _this15.supportsMasking = true;
    if (args != null) {
      _this15.maskValue = args.maskValue == null ? 0 : args.maskValue;
    } else {
      _this15.maskValue = 0;
    }
    return _this15;
  }
  _createClass(Masking, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var baseConfig = _get(_getPrototypeOf(Masking.prototype), "getConfig", this).call(this);
      var config = {
        maskValue: this.maskValue
      };
      Object.assign(config, baseConfig);
      return config;
    }
    /**
     * Returns a boolean mask (last axis reduced) that is true wherever any
     * entry along the last axis differs from `maskValue`.
     */
  }, {
    key: "computeMask",
    value: function computeMask(inputs, mask) {
      var _this17 = this;
      // `tidy` disposes the intermediate `notEqual` tensor, which was
      // previously leaked on every call.
      return tidy(function () {
        var input = getExactlyOneTensor(inputs);
        var axis = -1;
        return any$2(notEqual$2(input, _this17.maskValue), axis);
      });
    }
    /** Zeroes out every masked position of the input. */
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this16 = this;
      return tidy(function () {
        _this16.invokeCallHook(inputs, kwargs);
        var input = getExactlyOneTensor(inputs);
        var axis = -1;
        // keepDims so the mask broadcasts against the last axis of `input`.
        var keepDims = true;
        var booleanMask = any$2(notEqual$2(input, _this16.maskValue), axis, keepDims);
        var output = mul(input, cast$3(booleanMask, input.dtype));
        return output;
      });
    }
  }]);
  return Masking;
}(Layer);
/** @nocollapse */
Masking.className = 'Masking';
registerClass(Masking);
75145
var Embedding = /*#__PURE__*/function (_Layer) {
  _inherits(Embedding, _Layer);
  var _super = _createSuper(Embedding);
  /**
   * Layer that maps integer indices to rows of a trainable
   * `[inputDim, outputDim]` embeddings matrix.
   *
   * @param args Layer config. `inputDim` and `outputDim` must be positive
   *   integers; `embeddingsInitializer` defaults to 'randomUniform'.
   */
  function Embedding(args) {
    var _this;
    _classCallCheck(this, Embedding);
    _this = _super.call(this, args);
    _this.embeddings = null;
    _this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
    if (args.batchInputShape == null && args.inputShape == null) {
      // Porting Note: This logic is copied from Layer's constructor, since we
      // can't do exactly what the Python constructor does for Embedding().
      // Specifically, the super constructor can not be called after the
      // mutation of the `config` argument.
      var batchSize = null;
      if (args.batchSize != null) {
        batchSize = args.batchSize;
      }
      if (args.inputLength == null) {
        // Fix super-constructor to what it would have done if
        // 'config.inputShape' were (None, )
        _this.batchInputShape = [batchSize, null];
      } else {
        // Fix super-constructor to what it would have done if
        // 'config.inputShape' were (config.inputLength, )
        _this.batchInputShape = [batchSize].concat(toList(args.inputLength));
      }
    }
    _this.inputDim = args.inputDim;
    assertPositiveInteger(_this.inputDim, 'inputDim');
    _this.outputDim = args.outputDim;
    assertPositiveInteger(_this.outputDim, 'outputDim');
    _this.embeddingsInitializer = getInitializer(args.embeddingsInitializer || _this.DEFAULT_EMBEDDINGS_INITIALIZER);
    _this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
    _this.activityRegularizer = getRegularizer(args.activityRegularizer);
    _this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
    _this.maskZero = args.maskZero;
    _this.supportsMasking = args.maskZero;
    _this.inputLength = args.inputLength;
    return _this;
  }
  _createClass(Embedding, [{
    key: "build",
    value: function build(inputShape) {
      this.embeddings = this.addWeight('embeddings', [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint);
      this.built = true;
    }
    // Override warnOnIncompatibleInputShape because an embedding layer allows
    // the input to have varying ranks.
  }, {
    key: "warnOnIncompatibleInputShape",
    value: function warnOnIncompatibleInputShape(inputShape) {}
    /**
     * When `maskZero` is set, returns a boolean mask that is true for every
     * non-zero input index; otherwise returns null.
     */
  }, {
    key: "computeMask",
    value: function computeMask(inputs, mask) {
      var _this2 = this;
      return tidy(function () {
        if (!_this2.maskZero) {
          return null;
        } else {
          inputs = getExactlyOneTensor(inputs);
          return notEqual$2(inputs, zerosLike$3(inputs));
        }
      });
    }
    /**
     * Output shape is the input shape with `outputDim` appended. When
     * `inputLength` is configured it is validated against the non-batch
     * input dimensions, and unknown entries are filled in from either side.
     *
     * @throws ValueError if `inputLength` is incompatible with `inputShape`.
     */
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      if (this.inputLength == null) {
        return [].concat(_toConsumableArray(inputShape), [this.outputDim]);
      }
      // inputLength can be an array if input is 3D or higher. Copy it before
      // filling in unknown entries below: `toList` presumably returns an
      // Array argument as-is, so writing into it would overwrite the layer's
      // configured `inputLength` with the first input shape seen.
      var inLens = toList(this.inputLength).slice();
      if (inLens.length !== inputShape.length - 1) {
        throw new ValueError("\"inputLength\" is ".concat(this.inputLength, ", but received ") + "input shape has shape ".concat(inputShape));
      }
      for (var k = 0; k < inLens.length; ++k) {
        var s1 = inLens[k];
        var s2 = inputShape[k + 1];
        if (s1 != null && s2 != null && s1 !== s2) {
          throw new ValueError("\"inputLength\" is ".concat(this.inputLength, ", but received ") + "input shape has shape ".concat(inputShape));
        } else if (s1 == null) {
          inLens[k] = s2;
        }
      }
      return [inputShape[0]].concat(_toConsumableArray(inLens), [this.outputDim]);
    }
    /**
     * Looks up embedding rows for every index in the (single) input tensor,
     * casting it to int32 first if needed.
     */
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this3 = this;
      return tidy(function () {
        _this3.invokeCallHook(inputs, kwargs);
        // Embedding layer accepts only a single input.
        var input = getExactlyOneTensor(inputs);
        if (input.dtype !== 'int32') {
          input = cast$2(input, 'int32');
        }
        // Gather on the flattened indices, then restore the input's shape
        // with the embedding dimension appended.
        var output = gather(_this3.embeddings.read(), reshape$3(input, [input.size]));
        return reshape$3(output, getExactlyOneShape(_this3.computeOutputShape(input.shape)));
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        inputDim: this.inputDim,
        outputDim: this.outputDim,
        embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
        embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
        activityRegularizer: serializeRegularizer(this.activityRegularizer),
        embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
        maskZero: this.maskZero,
        inputLength: this.inputLength
      };
      var baseConfig = _get(_getPrototypeOf(Embedding.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Embedding;
}(Layer);
/** @nocollapse */
Embedding.className = 'Embedding';
registerClass(Embedding);
75275
/**
 * Generic Merge layer for element-wise merge functions.
 *
 * Used to implement `Add`, `Average`, `Concatenate`, etc. Subclasses provide
 * `mergeFunction`; this base class validates input shapes, aligns inputs of
 * different ranks, and combines input masks.
 */
var Merge = /*#__PURE__*/function (_Layer) {
  _inherits(Merge, _Layer);
  var _super = _createSuper(Merge);
  function Merge(args) {
    var _this;
    _classCallCheck(this, Merge);
    _this = _super.call(this, args || {});
    _this.supportsMasking = true;
    return _this;
  }
  /**
   * Logic for merging multiple tensors, to be overridden by subclasses.
   * @param inputs
   */
  _createClass(Merge, [{
    key: "mergeFunction",
    value: function mergeFunction(inputs) {
      throw new NotImplementedError();
    }
    /**
     * Computes the shape of the result of an elementwise operation.
     *
     * @param shape1: Shape of the first tensor.
     * @param shape2: Shape of the second tensor.
     * @returns Expected output shape when an elementwise operation is carried
     *   out on 2 tensors with shapes `shape1` and `shape2`.
     * @throws ValueError: If `shape1` and `shape2` are not compatible for
     *   element-wise operations.
     */
  }, {
    key: "computeElementwiseOpOutputShape",
    value: function computeElementwiseOpOutputShape(shape1, shape2) {
      if (shape1 == null || shape2 == null) {
        return null;
      } else if (shape1.length < shape2.length) {
        return this.computeElementwiseOpOutputShape(shape2, shape1);
      } else if (shape2.length === 0) {
        return shape1;
      }
      // Leading dimensions only present in the longer shape pass through.
      var outputShape = shape1.slice(0, shape1.length - shape2.length);
      for (var k = 0; k < shape2.length; ++k) {
        var i = shape1[shape1.length - shape2.length + k];
        var j = shape2[k];
        if (i == null || j == null || i < 0 || j < 0) {
          outputShape.push(null);
        } else if (i === 1) {
          outputShape.push(j);
        } else if (j === 1) {
          outputShape.push(i);
        } else {
          if (i !== j) {
            throw new ValueError('Operands could not be broadcast together with shapes ' + JSON.stringify(shape1) + ' ' + JSON.stringify(shape2));
          }
          outputShape.push(i);
        }
      }
      return outputShape;
    }
    /**
     * Validates the input shapes and decides whether `call` must align the
     * ranks of its inputs before merging.
     *
     * @throws ValueError if fewer than 2 inputs are given, the inputs have
     *   different batch sizes, or their shapes are not broadcast-compatible.
     */
  }, {
    key: "build",
    value: function build(inputShape) {
      // Used purely for shape validation.
      if (Array.isArray(inputShape) && !Array.isArray(inputShape[0])) {
        // Make sure that inputShape is an Array of shape.
        inputShape = [getExactlyOneShape(inputShape)];
      }
      if (inputShape.length < 2) {
        throw new ValueError('A merge layer should be called on an Array of at least 2 inputs.' + " Got ".concat(inputShape.length, " input(s)."));
      }
      // Make sure that there is at most one unique batch size among the input
      // shapes.
      var batchSizes = [];
      var _iterator = _createForOfIteratorHelper(inputShape),
        _step;
      try {
        for (_iterator.s(); !(_step = _iterator.n()).done;) {
          var _shape = _step.value;
          if (_shape != null && _shape[0] !== null) {
            batchSizes.push(_shape[0]);
          }
        }
      } catch (err) {
        _iterator.e(err);
      } finally {
        _iterator.f();
      }
      batchSizes = unique$2(batchSizes);
      if (batchSizes.length > 1) {
        throw new ValueError("Can not merge tensors with different batch sizes. " + "Got tensors with shapes: ".concat(JSON.stringify(inputShape), "."));
      }
      // Computed only for its validation side effect (throws on
      // incompatible shapes).
      var outputShape = inputShape[0] == null ? null : inputShape[0].slice(1);
      for (var i = 1; i < inputShape.length; ++i) {
        var shape = inputShape[i] == null ? null : inputShape[i].slice(1);
        outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
      }
      // If the inputs have different ranks, we have to reshape them to make them
      // broadcastable. An unknown (null) shape also forces the reshape path, so
      // map it to a null rank instead of dereferencing it.
      var allRanks = inputShape.map(function (shape) {
        return shape == null ? null : shape.length;
      });
      if (inputShape.indexOf(null) === -1 && unique$2(allRanks).length === 1) {
        this.reshapeRequired = false;
      } else {
        this.reshapeRequired = true;
      }
    }
    /**
     * Merges `inputs` with `mergeFunction`, first aligning their ranks when
     * `build` found that necessary.
     */
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this2 = this;
      return tidy(function () {
        if (_this2.reshapeRequired) {
          var reshapedInputs = [];
          var inputDims = inputs.map(function (input) {
            return input.rank;
          });
          if (inputDims.indexOf(null) === -1) {
            // If ranks of all inputs are available, we simply expand each of them
            // at axis=1 until all of them have the same rank.
            var maxNDim = max$2(inputDims);
            var _iterator2 = _createForOfIteratorHelper(inputs),
              _step2;
            try {
              for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
                var x = _step2.value;
                var xNDim = x.rank;
                for (var k = 0; k < maxNDim - xNDim; ++k) {
                  x = expandDims$2(x, 1);
                }
                reshapedInputs.push(x);
              }
            } catch (err) {
              _iterator2.e(err);
            } finally {
              _iterator2.f();
            }
            return _this2.mergeFunction(reshapedInputs);
          } else {
            // Transpose all inputs so that batch size is the last dimension.
            // [batchSize, dim1, dim2, ...] -> [dim1, dim2, ..., batchSize]
            var transposed = false;
            var _iterator3 = _createForOfIteratorHelper(inputs),
              _step3;
            try {
              for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
                var _x = _step3.value;
                var _xNDim = _x.rank;
                if (_xNDim == null) {
                  var xShape = _x.shape;
                  var _batchSize = xShape[0];
                  var _newShape = xShape.slice(1).concat([_batchSize]);
                  var xTransposed = reshape$3(_x, [_batchSize].concat(arrayProd(xShape.slice(1))));
                  xTransposed = transpose$2(xTransposed, [1, 0]);
                  xTransposed = reshape$3(xTransposed, _newShape);
                  reshapedInputs.push(xTransposed);
                  transposed = true;
                } else if (_xNDim > 1) {
                  var _dims = range$2(1, _xNDim).concat([0]);
                  reshapedInputs.push(transpose$2(_x, _dims));
                  transposed = true;
                } else {
                  // We don't transpose inputs if they are 1D vectors or scalars.
                  reshapedInputs.push(_x);
                }
              }
            } catch (err) {
              _iterator3.e(err);
            } finally {
              _iterator3.f();
            }
            var y = _this2.mergeFunction(reshapedInputs);
            var yNDim = y.rank;
            if (transposed) {
              // If inputs have been transposed, we have to transpose the output
              // too.
              if (yNDim == null) {
                var yShape = y.shape;
                var _yNDim = yShape.length;
                var batchSize = yShape[_yNDim - 1];
                var newShape = [batchSize].concat(yShape.slice(0, yShape.length - 1));
                y = reshape$3(transpose$2(reshape$3(y, [-1, batchSize]), [1, 0]), newShape);
              } else if (yNDim > 1) {
                var dims = [yNDim - 1].concat(range$2(0, yNDim - 1));
                y = transpose$2(y, dims);
              }
            }
            return y;
          }
        } else {
          return _this2.mergeFunction(inputs);
        }
      });
    }
    /**
     * Output shape: the broadcast of all non-batch input shapes, prefixed by
     * the common batch size if exactly one is known, otherwise by null.
     */
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      var outputShape;
      if (inputShape[0] == null) {
        outputShape = null;
      } else {
        outputShape = inputShape[0].slice(1);
      }
      for (var i = 1; i < inputShape.length; ++i) {
        var shape = inputShape[i] == null ? null : inputShape[i].slice(1);
        outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
      }
      var batchSizes = [];
      var _iterator4 = _createForOfIteratorHelper(inputShape),
        _step4;
      try {
        for (_iterator4.s(); !(_step4 = _iterator4.n()).done;) {
          var _shape2 = _step4.value;
          if (_shape2 != null && _shape2[0] !== null) {
            batchSizes.push(_shape2[0]);
          }
        }
      } catch (err) {
        _iterator4.e(err);
      } finally {
        _iterator4.f();
      }
      batchSizes = unique$2(batchSizes);
      if (batchSizes.length === 1) {
        outputShape = batchSizes.concat(outputShape);
      } else {
        outputShape = [null].concat(outputShape);
      }
      return outputShape;
    }
    /**
     * Combines the per-input masks with a logical AND.
     *
     * Inputs whose mask is null impose no constraint. Returns null when no
     * mask is given at all or every entry is null.
     *
     * @throws ValueError if `mask` / `inputs` are not Arrays of equal length.
     */
  }, {
    key: "computeMask",
    value: function computeMask(inputs, mask) {
      return tidy(function () {
        if (mask == null) {
          return null;
        }
        if (!Array.isArray(mask)) {
          throw new ValueError('`mask` should be an Array');
        }
        if (!Array.isArray(inputs)) {
          throw new ValueError('`inputs` should be an Array');
        }
        if (mask.length !== inputs.length) {
          throw new ValueError("The Array 'inputs' and 'mask' are expected to have the same " + "length, but have different lengths " + "(".concat(inputs.length, " vs ").concat(mask.length, ")"));
        }
        if (mask.every(function (m) {
          return m == null;
        })) {
          return null;
        }
        // AND together every non-null mask. The previous loop stopped one
        // element early (ignoring the last mask), fed null masks into
        // logicalAnd, and left an extra leading dimension from expandDims.
        var masks = mask.filter(function (m) {
          return m != null;
        });
        // Clone so a fresh tensor is returned even when only one mask exists.
        var output = masks[0].clone();
        for (var i = 1; i < masks.length; ++i) {
          output = logicalAnd$2(output, masks[i]);
        }
        return output;
      });
    }
  }]);
  return Merge;
}(Layer);
var Add = /*#__PURE__*/function (BaseMerge) {
  _inherits(Add, BaseMerge);
  var parentCtor = _createSuper(Add);
  /** Merge layer computing the element-wise sum of its inputs. */
  function Add(args) {
    _classCallCheck(this, Add);
    return parentCtor.call(this, args);
  }
  _createClass(Add, [{
    key: "mergeFunction",
    value: function mergeFunction(inputs) {
      return tidy(function () {
        // Seed with a copy of the first input so the result never aliases it,
        // then fold the remaining inputs in order.
        return inputs.slice(1).reduce(function (runningSum, term) {
          return add$3(runningSum, term);
        }, inputs[0].clone());
      });
    }
  }]);
  return Add;
}(Merge);
/** @nocollapse */
Add.className = 'Add';
registerClass(Add);
/**
 * Element-wise sum of inputs that all share the same shape.
 *
 * Three calling conventions are supported:
 *
 * 1. With no argument, or a single config object, a new `Add` layer is
 *    returned. It can then be applied to `tf.SymbolicTensor`s or
 *    `tf.Tensor`s:
 *
 * ```js
 * const addLayer = tf.layers.add();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = addLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. With an `Array` of `tf.SymbolicTensor`s, a `Layer` is built internally
 *    and applied, yielding a new `tf.SymbolicTensor`:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.add([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. With an `Array` of concrete `tf.Tensor`s, a `Layer` is built internally
 *    and applied, yielding the computed `tf.Tensor`:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
 * tf.layers.add([input1, input2]).print();
 * // Gives [[11, 22], [33, 44]].
 * ```
 *
 * @param config A layer config (or nothing), or an Array of inputs to sum
 *   immediately.
 * @returns A new `Add` layer, or the result of applying one to `config`.
 */
function add$2(config) {
  if (!Array.isArray(config)) {
    return new Add(config);
  }
  return new Add({}).apply(config);
}
var Multiply = /*#__PURE__*/function (BaseMerge) {
  _inherits(Multiply, BaseMerge);
  var parentCtor = _createSuper(Multiply);
  /** Merge layer computing the element-wise product of its inputs. */
  function Multiply(args) {
    _classCallCheck(this, Multiply);
    return parentCtor.call(this, args);
  }
  _createClass(Multiply, [{
    key: "mergeFunction",
    value: function mergeFunction(inputs) {
      return tidy(function () {
        // Seed with a copy of the first input so the result never aliases it,
        // then multiply in the remaining inputs in order.
        return inputs.slice(1).reduce(function (runningProduct, factor) {
          return mul(runningProduct, factor);
        }, inputs[0].clone());
      });
    }
  }]);
  return Multiply;
}(Merge);
/** @nocollapse */
Multiply.className = 'Multiply';
registerClass(Multiply);
/**
 * Element-wise product of inputs that all share the same shape.
 *
 * Three calling conventions are supported:
 *
 * 1. With no argument, or a single config object, a new `Multiply` layer is
 *    returned. It can then be applied to `tf.SymbolicTensor`s or
 *    `tf.Tensor`s:
 *
 * ```js
 * const multiplyLayer = tf.layers.multiply();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = multiplyLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. With an `Array` of `tf.SymbolicTensor`s, a `Layer` is built internally
 *    and applied, yielding a new `tf.SymbolicTensor`:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.multiply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. With an `Array` of concrete `tf.Tensor`s, a `Layer` is built internally
 *    and applied, yielding the computed `tf.Tensor`:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
 * tf.layers.multiply([input1, input2]).print();
 * // Gives [[10, 40], [90, 160]].
 * ```
 *
 * @param config A layer config (or nothing), or an Array of inputs to
 *   multiply immediately.
 * @returns A new `Multiply` layer, or the result of applying one to `config`.
 */
function multiply$3(config) {
  if (!Array.isArray(config)) {
    return new Multiply(config);
  }
  return new Multiply({}).apply(config);
}
var Average = /*#__PURE__*/function (BaseMerge) {
  _inherits(Average, BaseMerge);
  var parentCtor = _createSuper(Average);
  /** Merge layer computing the element-wise arithmetic mean of its inputs. */
  function Average(args) {
    _classCallCheck(this, Average);
    return parentCtor.call(this, args);
  }
  _createClass(Average, [{
    key: "mergeFunction",
    value: function mergeFunction(inputs) {
      return tidy(function () {
        // Sum all inputs (starting from a copy of the first), then scale.
        var total = inputs.slice(1).reduce(function (runningSum, term) {
          return add$3(runningSum, term);
        }, inputs[0].clone());
        var scale = 1 / inputs.length;
        return mul(scale, total);
      });
    }
  }]);
  return Average;
}(Merge);
/** @nocollapse */
Average.className = 'Average';
registerClass(Average);
/**
 * Element-wise arithmetic mean of inputs that all share the same shape.
 *
 * Three calling conventions are supported:
 *
 * 1. With no argument, or a single config object, a new `Average` layer is
 *    returned. It can then be applied to `tf.SymbolicTensor`s or
 *    `tf.Tensor`s:
 *
 * ```js
 * const averageLayer = tf.layers.average();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = averageLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. With an `Array` of `tf.SymbolicTensor`s, a `Layer` is built internally
 *    and applied, yielding a new `tf.SymbolicTensor`:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.average([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. With an `Array` of concrete `tf.Tensor`s, a `Layer` is built internally
 *    and applied, yielding the computed `tf.Tensor`:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
 * tf.layers.average([input1, input2]).print();
 * // Gives [[5.5, 11], [16.5, 22]].
 * ```
 *
 * @param config A layer config (or nothing), or an Array of inputs to
 *   average immediately.
 * @returns A new `Average` layer, or the result of applying one to `config`.
 */
function average$1(config) {
  if (!Array.isArray(config)) {
    return new Average(config);
  }
  return new Average({}).apply(config);
}
var Maximum = /*#__PURE__*/function (BaseMerge) {
  _inherits(Maximum, BaseMerge);
  var parentCtor = _createSuper(Maximum);
  /** Merge layer computing the element-wise maximum of its inputs. */
  function Maximum(args) {
    _classCallCheck(this, Maximum);
    return parentCtor.call(this, args);
  }
  _createClass(Maximum, [{
    key: "mergeFunction",
    value: function mergeFunction(inputs) {
      return tidy(function () {
        // Fold pairwise maxima left to right, starting from the first input.
        return inputs.slice(1).reduce(function (best, candidate) {
          return maximum$4(best, candidate);
        }, inputs[0]);
      });
    }
  }]);
  return Maximum;
}(Merge);
/** @nocollapse */
Maximum.className = 'Maximum';
registerClass(Maximum);
/**
 * Element-wise maximum of inputs that all share the same shape.
 *
 * Three calling conventions are supported:
 *
 * 1. With no argument, or a single config object, a new `Maximum` layer is
 *    returned. It can then be applied to `tf.SymbolicTensor`s or
 *    `tf.Tensor`s:
 *
 * ```js
 * const maximumLayer = tf.layers.maximum();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = maximumLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. With an `Array` of `tf.SymbolicTensor`s, a `Layer` is built internally
 *    and applied, yielding a new `tf.SymbolicTensor`:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.maximum([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. With an `Array` of concrete `tf.Tensor`s, a `Layer` is built internally
 *    and applied, yielding the computed `tf.Tensor`:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
 * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
 * tf.layers.maximum([input1, input2]).print();
 * // Gives [[10, 20], [30, 40]].
 * ```
 *
 * @param config A layer config (or nothing), or an Array of inputs to reduce
 *   immediately.
 * @returns A new `Maximum` layer, or the result of applying one to `config`.
 */
function maximum$3(config) {
  if (!Array.isArray(config)) {
    return new Maximum(config);
  }
  return new Maximum({}).apply(config);
}
var Minimum = /*#__PURE__*/function (BaseMerge) {
  _inherits(Minimum, BaseMerge);
  var parentCtor = _createSuper(Minimum);
  /** Merge layer computing the element-wise minimum of its inputs. */
  function Minimum(args) {
    _classCallCheck(this, Minimum);
    return parentCtor.call(this, args);
  }
  _createClass(Minimum, [{
    key: "mergeFunction",
    value: function mergeFunction(inputs) {
      return tidy(function () {
        // Fold pairwise minima left to right, starting from the first input.
        return inputs.slice(1).reduce(function (lowest, candidate) {
          return minimum$4(lowest, candidate);
        }, inputs[0]);
      });
    }
  }]);
  return Minimum;
}(Merge);
/** @nocollapse */
Minimum.className = 'Minimum';
registerClass(Minimum);
75884 /**
75885 * Calculate the element-wise minimum of inputs, which all have the same shape.
75886 *
75887 * This function can be invoked in three ways.
75888 *
75889 * 1. Construct an instance of `Minimum` layer, by using no input argument
75890 * or a single configuration argument. The resultant `Minimum` layer can then
75891 * be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
75892 *
75893 * ```js
75894 * const minimumLayer = tf.layers.minimum();
75895 *
75896 * // The layer can be applied to inputs.
75897 * const input1 = tf.input({shape: [2, 2]});
75898 * const input2 = tf.input({shape: [2, 2]});
75899 * const output = minimumLayer.apply([input1, input2]);
75900 * console.log(output.shape);
75901 * // You get [null, 2, 2], with the first dimension as the undetermined batch
75902 * // dimension.
75903 * ```
75904 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
 * generating a new `tf.SymbolicTensor`. For example:
75908 *
75909 * ```js
75910 * const input1 = tf.input({shape: [2, 2]});
75911 * const input2 = tf.input({shape: [2, 2]});
75912 * const output = tf.layers.minimum([input1, input2]);
75913 * console.log(output.shape);
75914 * // You get [null, 2, 2], with the first dimension as the undetermined batch
75915 * // dimension.
75916 * ```
75917 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
 * generating a new `tf.Tensor` as the result of the computation. For
 * example:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
 * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
 * tf.layers.minimum([input1, input2]).print();
 * // Gives [[1, 2], [3, 4]].
 * ```
 *
75929 */
function minimum$3(config) {
  // A non-array argument is a layer config: hand back the layer itself.
  if (!Array.isArray(config)) {
    return new Minimum(config);
  }
  // An array holds the inputs: create a default layer and apply it at once.
  return new Minimum({}).apply(config);
}
var Concatenate = /*#__PURE__*/function (_Merge6) {
  _inherits(Concatenate, _Merge6);
  var _super7 = _createSuper(Concatenate);

  /** Returns true if `shapes` already contains a shape equal to `target`. */
  function containsShape(shapes, target) {
    return shapes.some(function (shape) {
      return arraysEqual(shape, target);
    });
  }

  /**
   * Merge layer that joins its inputs along a single axis.
   *
   * @param args Optional layer config. `args.axis` is the concatenation axis;
   *   when absent it defaults to `DEFAULT_AXIS` (-1, the last axis).
   */
  function Concatenate(args) {
    var _this3;
    _classCallCheck(this, Concatenate);
    _this3 = _super7.call(this, args);
    _this3.DEFAULT_AXIS = -1;
    if (args == null) {
      args = {};
    }
    _this3.axis = args.axis == null ? _this3.DEFAULT_AXIS : args.axis;
    _this3.supportsMasking = true;
    _this3.reshapeRequired = false;
    return _this3;
  }
  _createClass(Concatenate, [{
    key: "build",
    /**
     * Validates the input shapes; no weights are created.
     *
     * @param inputShape Array of at least two input shapes.
     * @throws ValueError if fewer than two shapes are given, or if the shapes
     *   differ on any axis other than `this.axis`.
     */
    value: function build(inputShape) {
      if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0])) || inputShape.length === 1) {
        throw new ValueError('A `Concatenate` layer should be called on a list of at least 2 ' + 'inputs');
      }
      var allNoneShape = inputShape.every(function (shape) {
        return shape == null;
      });
      if (allNoneShape) {
        return;
      }
      // Collect the distinct shapes that remain once the concat axis is
      // removed; more than one distinct shape means the inputs are
      // incompatible. Array#splice accepts a negative `axis` directly.
      var shapeSet = [];
      for (var i = 0; i < inputShape.length; ++i) {
        var shapeWithoutConcatAxis = inputShape[i].slice();
        shapeWithoutConcatAxis.splice(this.axis, 1);
        if (!containsShape(shapeSet, shapeWithoutConcatAxis)) {
          shapeSet.push(shapeWithoutConcatAxis);
        }
      }
      if (shapeSet.length > 1) {
        throw new ValueError('A `Concatenate` layer requires inputs with matching shapes ' + 'except for the concat axis. Got input shapes: ' + JSON.stringify(inputShape));
      }
    }
  }, {
    key: "mergeFunction",
    /** Concatenates the input tensors along `this.axis` inside `tidy`. */
    value: function mergeFunction(inputs) {
      var _this4 = this;
      return tidy(function () {
        return concatenate$2(inputs, _this4.axis);
      });
    }
  }, {
    key: "computeOutputShape",
    /**
     * Output shape is the first input's shape, with the concat-axis size
     * replaced by the sum over all inputs (or `null` if any is unknown).
     *
     * @throws ValueError if `inputShape` is not a list of shapes.
     */
    value: function computeOutputShape(inputShape) {
      if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0]))) {
        throw new ValueError('A `Concatenate` layer should be called on a list of inputs.');
      }
      var inputShapes = inputShape;
      var outputShape = inputShapes[0].slice();
      // Porting Note: resolve negative axes manually because array indexing
      // does not accept negative indices.
      var axis = this.axis < 0 ? outputShape.length + this.axis : this.axis;
      for (var i = 1; i < inputShapes.length; ++i) {
        var shape = inputShapes[i];
        if (outputShape[axis] == null || shape[axis] == null) {
          outputShape[axis] = null;
          break;
        }
        outputShape[axis] += shape[axis];
      }
      return outputShape;
    }
  }, {
    key: "computeMask",
    /**
     * Builds the output mask. Unmasked inputs contribute an all-true mask,
     * masks of lower rank than their input are expanded on the last axis, the
     * masks are concatenated along `this.axis`, and the result is reduced with
     * `all` over the last axis.
     *
     * @returns `null` when no mask is given or every entry of `mask` is null.
     * @throws ValueError if `mask`/`inputs` are not arrays or differ in length.
     */
    value: function computeMask(inputs, mask) {
      var _this5 = this;
      if (mask == null) {
        return null;
      }
      if (!Array.isArray(mask)) {
        throw new ValueError('`mask` should be an array for Concatenate');
      }
      if (!Array.isArray(inputs)) {
        throw new ValueError('`inputs` should be an array for Concatenate');
      }
      if (mask.length !== inputs.length) {
        throw new ValueError("Mismatch in the length of mask (".concat(mask.length, ") ") + "and the length of inputs (".concat(inputs.length, ")"));
      }
      return tidy(function () {
        var allNullMasks = mask.every(function (m) {
          return m == null;
        });
        if (allNullMasks) {
          return null;
        }
        var outputMasks = [];
        for (var i = 0; i < inputs.length; ++i) {
          if (mask[i] == null) {
            // Input is unmasked. Append all 1's to masks.
            outputMasks.push(cast$3(onesLike$3(inputs[i]), 'bool'));
          } else if (mask[i].rank < inputs[i].rank) {
            // Mask is smaller than the input, expand it.
            outputMasks.push(expandDims$3(mask[i], -1));
          } else {
            outputMasks.push(mask[i]);
          }
        }
        var concatenatedMasks = concat$2(outputMasks, _this5.axis);
        return all$2(concatenatedMasks, -1, false);
      });
    }
  }, {
    key: "getConfig",
    /** Serializes `axis` merged with the base layer config. */
    value: function getConfig() {
      var config = {
        'axis': this.axis
      };
      var baseConfig = _get(_getPrototypeOf(Concatenate.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Concatenate;
}(Merge);
/** @nocollapse */
Concatenate.className = 'Concatenate';
registerClass(Concatenate);
76105 /**
76106 * Concatenate an `Array` of inputs.
76107 *
76108 * This function can be invoked in three ways.
76109 *
76110 * 1. Construct an instance of `Concatenate` layer, by using no input argument
76111 * or a single configuration argument. The resultant `Concatenate` layer can
76112 * then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
76113 *
76114 * ```js
76115 * const concatLayer = tf.layers.concatenate();
76116 *
76117 * // The layer can be applied to inputs.
76118 * const input1 = tf.input({shape: [2, 3]});
76119 * const input2 = tf.input({shape: [2, 4]});
76120 * const output = concatLayer.apply([input1, input2]);
76121 * console.log(output.shape);
76122 * // You get [null, 2, 7], with the first dimension as the undetermined batch
76123 * // dimension and the last dimension as the result of concatenating the
76124 * // last dimensions of the two inputs.
76125 * ```
76126 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
 * generating a new `tf.SymbolicTensor`. For example:
76130 *
76131 * ```js
76132 * const input1 = tf.input({shape: [2, 3]});
76133 * const input2 = tf.input({shape: [2, 4]});
76134 * const output = tf.layers.concatenate([input1, input2]);
76135 * console.log(output.shape);
 * // You get [null, 2, 7], with the first dimension as the undetermined batch
 * // dimension and the last dimension as the result of concatenating the
 * // last dimensions of the two inputs.
76139 * ```
76140 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
 * generating a new `tf.Tensor` as the result of the computation. For
 * example:
 *
 * ```js
 * const input1 = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
 * const input2 = tf.tensor2d([[10, 20], [30, 40]], [2, 2]);
 * tf.layers.concatenate([input1, input2]).print();
 * // Gives [[1, 2, 10, 20], [3, 4, 30, 40]].
 * ```
 *
76152 */
function concatenate$1(config) {
  // Non-array argument: it is the layer config, so return the layer itself.
  if (!Array.isArray(config)) {
    return new Concatenate(config);
  }
  // Array argument: these are the inputs; build a default layer and apply it.
  return new Concatenate({}).apply(config);
}
/**
 * Interprets a potentially negative axis index.
 *
 * Negative values count from the end: given `axis = -1` and `dim = 3`, this
 * returns 2. Values that are not negative are returned unchanged; no upper
 * bound check is performed.
 *
 * @param axis The axis index; may be a positive, zero or negative integer.
 * @param dim Total number of dimensions; must be positive when `axis` is
 *   negative.
 * @returns A non-negative axis index equivalent to the input `axis`.
 * @throws ValueError if `axis` is negative and `dim` is not positive. Without
 *   this check the old repeated-addition loop never terminated.
 */
function interpretAxis(axis, dim) {
  if (!(axis < 0)) {
    return axis;
  }
  if (!(dim > 0)) {
    throw new ValueError("Cannot interpret negative axis ".concat(axis, " for a tensor with ") + "".concat(dim, " dimension(s)."));
  }
  // Same result as repeatedly adding `dim` until the value is non-negative.
  return (axis % dim + dim) % dim;
}
/**
 * Batchwise dot product of two tensors of rank 2 or 3.
 *
 * @param x First tensor; rank must be 2 or 3.
 * @param y Second tensor; rank must be 2 or 3.
 * @param axes A single axis applied to both operands, or a pair
 *   `[xAxis, yAxis]`. Defaults to `[x.rank - 1, y.rank - 2]`, i.e. a
 *   batched-matmul contraction.
 * @returns The contracted tensor. Extra size-1 padding dims are squeezed away,
 *   and a rank-1 result is expanded to rank 2 on axis 1.
 * @throws NotImplementedError for rank > 3 or complex64 operands.
 */
function batchDot(x, y, axes) {
  if (x.shape.length > 3 || y.shape.length > 3) {
    throw new NotImplementedError('batchDot is not implemented for tensors of 4D or higher rank yet');
  }
  assert$1(x.shape.length >= 2, function () {
    return "batchDot requires the rank of x to be >= 2, " + "but got ".concat(x.shape.length);
  });
  // Check y's own rank. This previously re-tested x, so rank-1 y slipped through.
  assert$1(y.shape.length >= 2, function () {
    return "batchDot requires the rank of y to be >= 2, " + "but got ".concat(y.shape.length);
  });
  if (typeof axes === 'number') {
    axes = [axes, axes];
  }
  if (x.dtype === 'complex64' || y.dtype === 'complex64') {
    throw new NotImplementedError('batchDot is not implemented for complex64-type Tensors yet.');
  }
  var xNDim = x.shape.length;
  var yNDim = y.shape.length;
  if (axes == null) {
    // Behave like batchMatmul by default.
    axes = [xNDim - 1, yNDim - 2];
  }
  var axesArray = axes;
  return tidy(function () {
    // Pad the lower-rank operand with trailing size-1 dims so the ranks match.
    var diff = Math.abs(xNDim - yNDim);
    var padding = [];
    for (var i = 0; i < diff; ++i) {
      padding.push(1);
    }
    if (xNDim > yNDim) {
      y = reshape$3(y, y.shape.concat(padding));
    } else if (yNDim > xNDim) {
      x = reshape$3(x, x.shape.concat(padding));
    }
    var out;
    if (x.shape.length === 2 && y.shape.length === 2) {
      if (axesArray[0] === axesArray[1]) {
        out = sum$3(mul(x, y), axesArray[0]);
      } else {
        out = sum$3(mul(transpose$2(x, [1, 0]), y), axesArray[1]);
      }
    } else {
      var adjX = axesArray[0] !== x.shape.length - 1;
      var adjY = axesArray[1] === y.shape.length - 1;
      out = matMul$1(x, y, adjX, adjY);
    }
    if (diff > 0) {
      // Remove the padding dims that survived the contraction.
      var idx = xNDim > yNDim ? xNDim + yNDim - 3 : xNDim - 1;
      var squeezeAxes = [];
      for (var j = idx; j < idx + diff; ++j) {
        squeezeAxes.push(j);
      }
      out = squeeze(out, squeezeAxes);
    }
    if (out.shape.length === 1) {
      out = expandDims$3(out, 1);
    }
    return out;
  });
}
var Dot = /*#__PURE__*/function (_Merge7) {
  _inherits(Dot, _Merge7);
  var _super8 = _createSuper(Dot);
  /**
   * Merge layer computing a batchwise dot product between exactly two inputs.
   *
   * @param args Layer config. `args.axes` is either one axis applied to both
   *   inputs or a pair `[axis1, axis2]`; negative values count from the end.
   *   `args.normalize` (default `false`) L2-normalizes each input along its
   *   dot axis before the product.
   */
  function Dot(args) {
    var _this6;
    _classCallCheck(this, Dot);
    _this6 = _super8.call(this, args);
    _this6.axes = args.axes;
    _this6.normalize = args.normalize == null ? false : args.normalize;
    _this6.supportsMasking = true;
    _this6.reshapeRequired = false;
    return _this6;
  }
  _createClass(Dot, [{
    key: "build",
    /**
     * Validates that the two input shapes agree on the dot axes.
     *
     * @throws NotImplementedError for inputs of rank > 3.
     * @throws ValueError if the dot-axis dimensions differ.
     */
    value: function build(inputShape) {
      assert$1(Array.isArray(inputShape) && inputShape.length === 2 && Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), function () {
        return 'A `Dot` layer should be called on a list of exactly 2 inputs.';
      });
      var shape1 = inputShape[0];
      var shape2 = inputShape[1];
      if (shape1.length > 3 || shape2.length > 3) {
        throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
      }
      var axes = this.interpretAxes(shape1, shape2);
      if (shape1[axes[0]] !== shape2[axes[1]]) {
        throw new ValueError("Dimension incompatibility: " + "".concat(shape1[axes[0]], " !== ").concat(shape2[axes[1]]));
      }
    }
  }, {
    key: "mergeFunction",
    /**
     * Optionally L2-normalizes both inputs along their dot axes, then applies
     * `batchDot`.
     *
     * @throws ValueError unless exactly two inputs are given.
     */
    value: function mergeFunction(inputs) {
      if (inputs.length !== 2) {
        throw new ValueError('A `Dot` layer must be called on exactly 2 inputs, ' + "but received ".concat(inputs.length, " input(s)."));
      }
      var x1 = inputs[0];
      var x2 = inputs[1];
      var axes = this.interpretAxes(x1.shape, x2.shape);
      if (this.normalize) {
        x1 = l2Normalize(x1, axes[0]);
        x2 = l2Normalize(x2, axes[1]);
      }
      return batchDot(x1, x2, axes);
    }
  }, {
    key: "interpretAxes",
    /**
     * Resolves `this.axes` into a pair of non-negative axes, one per input.
     * Both the single-integer and the array form are interpreted, so negative
     * per-input axes are handled the same way in build, mergeFunction and
     * computeOutputShape. (The array form used to be returned unresolved.)
     */
    value: function interpretAxes(shape1, shape2) {
      if (!Array.isArray(this.axes)) {
        // `this.axes` is a single integer.
        return [interpretAxis(this.axes, shape1.length), interpretAxis(this.axes, shape2.length)];
      }
      // `this.axes` is an Array of integers.
      return [interpretAxis(this.axes[0], shape1.length), interpretAxis(this.axes[1], shape2.length)];
    }
  }, {
    key: "computeOutputShape",
    /**
     * Drops the dot axis from each shape and the batch axis from the second
     * shape, then concatenates them. A rank-1 result is padded to `[n, 1]`.
     */
    value: function computeOutputShape(inputShape) {
      assert$1(Array.isArray(inputShape) && inputShape.length === 2 && Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), function () {
        return 'A `Dot` layer should be called on a list of exactly 2 inputs.';
      });
      var shape1 = inputShape[0].slice();
      var shape2 = inputShape[1].slice();
      if (shape1.length > 3 || shape2.length > 3) {
        throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
      }
      var axes = this.interpretAxes(shape1, shape2);
      shape1.splice(axes[0], 1);
      shape2.splice(axes[1], 1);
      shape2.splice(0, 1);
      var outputShape = shape1.concat(shape2);
      if (outputShape.length === 1) {
        outputShape.push(1);
      }
      return outputShape;
    }
  }, {
    key: "computeMask",
    /** The dot-product output never carries a mask. */
    value: function computeMask(inputs, mask) {
      return null;
    }
  }, {
    key: "getConfig",
    /** Serializes `axes` and `normalize` merged with the base layer config. */
    value: function getConfig() {
      var config = {
        'axes': this.axes,
        'normalize': this.normalize
      };
      var baseConfig = _get(_getPrototypeOf(Dot.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Dot;
}(Merge);
/** @nocollapse */
Dot.className = 'Dot';
registerClass(Dot);
76356 // TODO(cais): Add functional interfaces for the merge layers.
76357
/**
 * Regularization layer that adds zero-mean Gaussian noise (standard deviation
 * `stddev`) to its input. The noise branch is selected via `inTrainPhase`,
 * driven by `kwargs['training']`; otherwise the input passes through.
 */
var GaussianNoise = /*#__PURE__*/function (_Layer) {
  _inherits(GaussianNoise, _Layer);
  var parentConstructor = _createSuper(GaussianNoise);
  function GaussianNoise(args) {
    _classCallCheck(this, GaussianNoise);
    var layer = parentConstructor.call(this, args);
    layer.supportsMasking = true;
    layer.stddev = args.stddev;
    return layer;
  }
  _createClass(GaussianNoise, [{
    key: "computeOutputShape",
    // Noise never changes the shape.
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    // `stddev` first, followed by every key of the base layer config.
    value: function getConfig() {
      var parentConfig = _get(_getPrototypeOf(GaussianNoise.prototype), "getConfig", this).call(this);
      var merged = {
        stddev: this.stddev
      };
      Object.assign(merged, parentConfig);
      return merged;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var layer = this;
      return tidy(function () {
        layer.invokeCallHook(inputs, kwargs);
        var tensor = getExactlyOneTensor(inputs);
        var isTraining = kwargs['training'] || false;
        return inTrainPhase(function noised() {
          return add$3(randomNormal$1(tensor.shape, 0, layer.stddev), tensor);
        }, function () {
          return tensor;
        }, isTraining);
      });
    }
  }]);
  return GaussianNoise;
}(Layer);
/** @nocollapse */
GaussianNoise.className = 'GaussianNoise';
registerClass(GaussianNoise);
/**
 * Regularization layer that multiplies its input by Gaussian noise with mean 1
 * and standard deviation `sqrt(rate / (1 - rate))`. Noise is only considered
 * when `0 < rate < 1`, and is selected via `inTrainPhase` using
 * `kwargs['training']`. In every other case the input is returned unchanged.
 */
var GaussianDropout = /*#__PURE__*/function (_Layer2) {
  _inherits(GaussianDropout, _Layer2);
  var parentConstructor = _createSuper(GaussianDropout);
  function GaussianDropout(args) {
    _classCallCheck(this, GaussianDropout);
    var layer = parentConstructor.call(this, args);
    layer.supportsMasking = true;
    layer.rate = args.rate;
    return layer;
  }
  _createClass(GaussianDropout, [{
    key: "computeOutputShape",
    // Output shape always equals input shape.
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    // `rate` first, followed by every key of the base layer config.
    value: function getConfig() {
      var parentConfig = _get(_getPrototypeOf(GaussianDropout.prototype), "getConfig", this).call(this);
      var merged = {
        rate: this.rate
      };
      Object.assign(merged, parentConfig);
      return merged;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var layer = this;
      return tidy(function () {
        layer.invokeCallHook(inputs, kwargs);
        var tensor = getExactlyOneTensor(inputs);
        var rateIsUsable = layer.rate > 0 && layer.rate < 1;
        if (!rateIsUsable) {
          return tensor;
        }
        return inTrainPhase(function noised() {
          var noiseStddev = Math.sqrt(layer.rate / (1 - layer.rate));
          return mul(tensor, randomNormal$1(tensor.shape, 1, noiseStddev));
        }, function () {
          return tensor;
        }, kwargs['training'] || false);
      });
    }
  }]);
  return GaussianDropout;
}(Layer);
/** @nocollapse */
GaussianDropout.className = 'GaussianDropout';
registerClass(GaussianDropout);
/**
 * Applies Alpha Dropout to the input.
 *
 * As it is a regularization layer, it is only active at training time.
 *
 * Alpha Dropout is a `Dropout` intended to keep the mean and variance of its
 * inputs at their original values, so that the self-normalizing property is
 * preserved even after dropout. It fits well with Scaled Exponential Linear
 * Units (SELU). Dropped activations are set to the SELU negative saturation
 * value `alphaP = -alpha * scale`, and the affine transform `a * x + b` then
 * re-centers and re-scales the result.
 *
 * Arguments:
 *   - `rate`: float, drop probability (as with `Dropout`). Dropout is only
 *     applied when `0 < rate < 1`; otherwise the inputs are returned as-is.
 *   - `noiseShape`: optional shape (array of integers) for the randomly
 *     generated keep/drop flags. Defaults to the input's shape.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * References:
 *   - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
 */
var AlphaDropout = /*#__PURE__*/function (_Layer3) {
  _inherits(AlphaDropout, _Layer3);
  var _super3 = _createSuper(AlphaDropout);
  function AlphaDropout(args) {
    var _this5;
    _classCallCheck(this, AlphaDropout);
    _this5 = _super3.call(this, args);
    _this5.supportsMasking = true;
    _this5.rate = args.rate;
    _this5.noiseShape = args.noiseShape;
    return _this5;
  }
  _createClass(AlphaDropout, [{
    key: "_getNoiseShape",
    // Explicit `noiseShape` if configured, else the (single) input's shape.
    value: function _getNoiseShape(inputs) {
      return this.noiseShape || getExactlyOneTensor(inputs).shape;
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      return inputShape;
    }
  }, {
    key: "getConfig",
    // NOTE(review): `noiseShape` is not serialized here; confirm whether it
    // should round-trip through the config.
    value: function getConfig() {
      var baseConfig = _get(_getPrototypeOf(AlphaDropout.prototype), "getConfig", this).call(this);
      var config = {
        rate: this.rate
      };
      Object.assign(config, baseConfig);
      return config;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this6 = this;
      return tidy(function () {
        // Same call-hook protocol as the other noise layers.
        _this6.invokeCallHook(inputs, kwargs);
        if (_this6.rate < 1 && _this6.rate > 0) {
          var noiseShape = _this6._getNoiseShape(inputs);
          var droppedInputs = function droppedInputs() {
            var input = getExactlyOneTensor(inputs);
            var alpha = 1.6732632423543772848170429916717;
            var scale = 1.0507009873554804934193349852946;
            // SELU negative saturation value; dropped units are set to it.
            var alphaP = -alpha * scale;
            var keptIdx = greaterEqual$2(randomUniform$1(noiseShape), _this6.rate);
            keptIdx = cast$2(keptIdx, 'float32'); // get default dtype.
            // Get affine transformation params.
            var a = Math.pow((1 - _this6.rate) * (1 + _this6.rate * Math.pow(alphaP, 2)), -0.5);
            var b = -a * alphaP * _this6.rate;
            // Apply mask: x = input * kept + alphaP * (1 - kept).
            // (kept - 1) * (-alphaP) == alphaP * (1 - kept). The previous
            // (kept - 1) * alphaP set dropped units to -alphaP, which `b` does
            // not compensate for.
            var x = add$3(mul(input, keptIdx), mul(add$3(keptIdx, -1), -alphaP));
            return add$3(mul(x, a), b);
          };
          return inTrainPhase(droppedInputs, function () {
            return getExactlyOneTensor(inputs);
          }, kwargs['training'] || false);
        }
        return inputs;
      });
    }
  }]);
  return AlphaDropout;
}(Layer);
/** @nocollapse */
AlphaDropout.className = 'AlphaDropout';
registerClass(AlphaDropout);
76552
/**
 * Applies batch normalization on x given mean, variance, beta and gamma.
 *
 * Delegates to the rank-specific `batchNorm2d` / `batchNorm3d` /
 * `batchNorm4d` ops. Per the `tf.batchNorm` API documentation, these compute
 * `output = (x - mean) * gamma / sqrt(variance + epsilon) + beta`.
 * Epsilon is added to the variance *inside* the square root; the previous
 * doc wrongly wrote `/ (sqrt(var) + epsilon)`.
 *
 * @param x Input tensor; only ranks 2, 3 and 4 are supported.
 * @param mean Mean of batch.
 * @param variance Variance of batch.
 * @param beta Tensor with which to center the input (may be null).
 * @param gamma Tensor by which to scale the input (may be null).
 * @param epsilon Fuzz factor; defaults to 1e-3 when omitted or undefined.
 * @returns The result of the batch normalization.
 * @throws NotImplementedError for any other rank.
 */
function batchNormalization$1(x, mean, variance, beta, gamma) {
  var epsilon = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : 1e-3;
  var out;
  if (x.rank === 2) {
    out = batchNorm2d(x, mean, variance, beta, gamma, epsilon);
  } else if (x.rank === 3) {
    // TODO(cais): Check rank; give proper error message.
    out = batchNorm3d(x, mean, variance, beta, gamma, epsilon);
  } else if (x.rank === 4) {
    out = batchNorm4d(x, mean, variance, beta, gamma, epsilon);
  } else {
    throw new NotImplementedError("batchNormalization is not implemented for array of rank ".concat(x.rank, " ") + "yet");
  }
  return out;
}
/**
 * Non-broadcasting batch normalization for training (not inference).
 *
 * Computes the mean and variance of `x` over `reductionAxes` with `moments`,
 * and passes them directly (without reshaping) to `batchNormalization$1`
 * together with `beta` and `gamma`. Intended for the case where the
 * statistics already line up with `x`'s trailing axis.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor; defaults to 1e-3 when omitted or undefined.
 * @returns `[normalized tensor, batch mean, batch variance]`, computed
 *   inside a single `tidy` scope.
 */
function regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes) {
  var epsilon = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 1e-3;
  return tidy(function () {
    var stats = moments(x, reductionAxes);
    var batchMean = stats.mean;
    var batchVariance = stats.variance;
    return [batchNormalization$1(x, batchMean, batchVariance, beta, gamma, epsilon), batchMean, batchVariance];
  });
}
/**
 * Broadcasting batch normalization for training (not inference).
 *
 * Computes the mean and variance of `x` over `reductionAxes`. It then reshapes
 * those statistics, and `gamma`/`beta` when present, to a broadcast shape: the
 * size of every reduced axis becomes 1, and every other axis keeps its size
 * from `x`. The reshaped tensors are passed to `batchNormalization$1`.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input (may be null).
 * @param beta Tensor by which to center the input (may be null).
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor; defaults to 1e-3 when omitted or undefined.
 * @returns `[normalized tensor, batch mean, batch variance]`. The returned
 *   mean and variance are the un-reshaped statistics.
 */
function broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes) {
  var epsilon = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 1e-3;
  return tidy(function () {
    var stats = moments(x, reductionAxes);
    var batchMean = stats.mean;
    var batchVariance = stats.variance;
    var targetShape = range$2(0, x.rank).map(function (dim) {
      return reductionAxes.indexOf(dim) === -1 ? x.shape[dim] : 1;
    });
    var toTarget = function (t) {
      return reshape$3(t, targetShape);
    };
    var meanB = toTarget(batchMean);
    var varianceB = toTarget(batchVariance);
    var gammaB = gamma == null ? null : toTarget(gamma);
    var betaB = beta == null ? null : toTarget(beta);
    var normalized = batchNormalization$1(x, meanB, varianceB, betaB, gammaB, epsilon);
    return [normalized, batchMean, batchVariance];
  });
}
/**
 * Batch normalization for use in training (not inference).
 *
 * Uses the non-broadcasting path when `reductionAxes` (in any order) are
 * exactly the leading axes `[0, ..., x.rank - 2]`, i.e. every axis except the
 * last. Otherwise it uses the broadcasting path.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize (not modified).
 * @param epsilon Fuzz factor; defaults to 1e-3 when omitted or undefined.
 * @returns An `Array` of three `Tensors`:
 *   [normalized tensor, mean of input, variance of input].
 */
function normalizeBatchInTraining(x, gamma, beta, reductionAxes) {
  var epsilon = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 1e-3;
  // Sort a copy numerically. The default Array#sort compares elements as
  // strings, so e.g. [10, 2] would stay [10, 2].
  var sortedAxes = reductionAxes.slice().sort(function (a, b) {
    return a - b;
  });
  if (arraysEqual(sortedAxes, range$2(0, x.rank - 1))) {
    return regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon);
  }
  return broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon);
}
76676 var BatchNormalization = /*#__PURE__*/function (_Layer) {
76677 _inherits(BatchNormalization, _Layer);
76678 var _super = _createSuper(BatchNormalization);
76679 function BatchNormalization(args) {
76680 var _this;
76681 _classCallCheck(this, BatchNormalization);
76682 if (args == null) {
76683 args = {};
76684 }
76685 _this = _super.call(this, args);
76686 _this.supportsMasking = true;
76687 _this.axis = args.axis == null ? -1 : args.axis;
76688 _this.momentum = args.momentum == null ? 0.99 : args.momentum;
76689 _this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
76690 _this.center = args.center == null ? true : args.center;
76691 _this.scale = args.scale == null ? true : args.scale;
76692 _this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
76693 _this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
76694 _this.movingMeanInitializer = getInitializer(args.movingMeanInitializer || 'zeros');
76695 _this.movingVarianceInitializer = getInitializer(args.movingVarianceInitializer || 'ones');
76696 _this.betaConstraint = getConstraint(args.betaConstraint);
76697 _this.gammaConstraint = getConstraint(args.gammaConstraint);
76698 _this.betaRegularizer = getRegularizer(args.betaRegularizer);
76699 _this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
76700 return _this;
76701 }
76702 _createClass(BatchNormalization, [{
76703 key: "build",
76704 value: function build(inputShape) {
76705 inputShape = getExactlyOneShape(inputShape);
76706 var axis = this.axis >= 0 ? this.axis : this.axis + inputShape.length;
76707 var dim = inputShape[axis];
76708 if (dim == null) {
76709 throw new ValueError("Axis ".concat(axis, " of input tensor should have a defined dimension but ") + "the layer received an input with shape " + "".concat(JSON.stringify(inputShape), "."));
76710 }
76711 this.inputSpec = [new InputSpec({
76712 ndim: inputShape.length,
76713 axes: _defineProperty({}, axis, dim)
76714 })];
76715 var shape = [dim];
76716 if (this.scale) {
76717 this.gamma = this.addWeight('gamma', shape, null, this.gammaInitializer, this.gammaRegularizer, true, this.gammaConstraint);
76718 }
76719 if (this.center) {
76720 this.beta = this.addWeight('beta', shape, null, this.betaInitializer, this.betaRegularizer, true, this.betaConstraint);
76721 }
76722 this.movingMean = this.addWeight('moving_mean', shape, null, this.movingMeanInitializer, null, false);
76723 this.movingVariance = this.addWeight('moving_variance', shape, null, this.movingVarianceInitializer, null, false);
76724 this.built = true;
76725 }
76726 }, {
76727 key: "call",
76728 value: function call(inputs, kwargs) {
76729 var _this2 = this;
76730 return tidy(function () {
76731 var training = kwargs['training'] == null ? false : kwargs['training'];
76732 var input = getExactlyOneTensor(inputs);
76733 var inputShape = input.shape;
76734 var ndim = inputShape.length;
76735 var reductionAxes = range$2(0, ndim);
76736 var axis = _this2.axis >= 0 ? _this2.axis : _this2.axis + ndim;
76737 reductionAxes.splice(axis, 1);
76738 var broadcastShape = pyListRepeat(1, ndim);
76739 broadcastShape[axis] = inputShape[axis];
76740 var sortedReductionAxes = reductionAxes.slice();
76741 sortedReductionAxes.sort();
76742 var needsBroadcasting = !arraysEqual(sortedReductionAxes, range$2(0, ndim).slice(0, ndim - 1));
76743 var normalizeInference = function normalizeInference() {
76744 if (needsBroadcasting) {
76745 var broadcastMovingMean = reshape$3(_this2.movingMean.read(), broadcastShape);
76746 var broadcastMovingVariance = reshape$3(_this2.movingVariance.read(), broadcastShape);
76747 var broadcastBeta = _this2.center ? reshape$3(_this2.beta.read(), broadcastShape) : null;
76748 var broadcastGamma = _this2.scale ? reshape$3(_this2.gamma.read(), broadcastShape) : null;
76749 return batchNormalization$1(input, broadcastMovingMean, broadcastMovingVariance, broadcastBeta, broadcastGamma, _this2.epsilon);
76750 } else {
76751 return batchNormalization$1(input, _this2.movingMean.read(), _this2.movingVariance.read(), _this2.beta == null ? null : _this2.beta.read(), _this2.gamma == null ? null : _this2.gamma.read(), _this2.epsilon);
76752 }
76753 };
76754 if (!training) {
76755 return normalizeInference();
76756 }
76757 var _normalizeBatchInTrai = normalizeBatchInTraining(input, _this2.gamma.read(), _this2.beta.read(), reductionAxes, _this2.epsilon),
76758 _normalizeBatchInTrai2 = _slicedToArray(_normalizeBatchInTrai, 3),
76759 normedTraining = _normalizeBatchInTrai2[0],
76760 mean = _normalizeBatchInTrai2[1],
76761 variance = _normalizeBatchInTrai2[2];
76762 var doMovingAverage = function doMovingAverage(variable, value, momentum) {
76763 tidy(function () {
76764 var decay = 1 - momentum;
76765 var origValue = variable.read();
76766 var updateDelta = mul(sub$2(origValue, value), decay);
76767 variable.write(sub$2(origValue, updateDelta));
76768 });
76769 };
76770 // Perform updates to moving mean and moving variance for training.
76771 // Porting Note: In PyKeras, these updates to `movingMean` and
76772 // `movingAverage` are done as a deferred Graph, added to the `Layer`'s
76773 // `update`s using the `add_update()` method. Here we do it imperatively
76774 // and encapsulate the updates in a function that is invoked
76775 // immediately.
76776 var updateMovingMeanAndVariance = function updateMovingMeanAndVariance() {
76777 doMovingAverage(_this2.movingMean, mean, _this2.momentum);
76778 doMovingAverage(_this2.movingVariance, variance, _this2.momentum);
76779 };
76780 updateMovingMeanAndVariance();
76781 return normedTraining;
76782 });
76783 }
76784 }, {
76785 key: "getConfig",
76786 value: function getConfig() {
76787 var config = {
76788 axis: this.axis,
76789 momentum: this.momentum,
76790 epsilon: this.epsilon,
76791 center: this.center,
76792 scale: this.scale,
76793 betaInitializer: serializeInitializer(this.betaInitializer),
76794 gammaInitializer: serializeInitializer(this.gammaInitializer),
76795 movingMeanInitializer: serializeInitializer(this.movingMeanInitializer),
76796 movingVarianceInitializer: serializeInitializer(this.movingVarianceInitializer),
76797 betaRegularizer: serializeRegularizer(this.betaRegularizer),
76798 gammaRegularizer: serializeRegularizer(this.gammaRegularizer),
76799 betaConstraint: serializeConstraint(this.betaConstraint),
76800 gammaConstraint: serializeConstraint(this.gammaConstraint)
76801 };
76802 var baseConfig = _get(_getPrototypeOf(BatchNormalization.prototype), "getConfig", this).call(this);
76803 Object.assign(config, baseConfig);
76804 return config;
76805 }
76806 }]);
76807 return BatchNormalization;
76808 }(Layer);
// Static class name for BatchNormalization, then pass the class to
// `registerClass` (defined elsewhere in this bundle; presumably the
// serialization registry keyed by `className` — confirm).
/** @nocollapse */
BatchNormalization.className = 'BatchNormalization';
registerClass(BatchNormalization);
var LayerNormalization = /*#__PURE__*/function (_Layer2) {
  _inherits(LayerNormalization, _Layer2);
  var _super2 = _createSuper(LayerNormalization);
  /**
   * Normalizes each input over `axis`, using the mean and variance computed by
   * `moments`. It then applies an optional learned scale (`gamma`, when
   * `scale` is true) and offset (`beta`, when `center` is true).
   *
   * @param args Optional config:
   *   - `axis`: integer or array of integers (default -1).
   *   - `epsilon`: default 1e-3.
   *   - `center` / `scale`: default true.
   *   - `betaInitializer` / `gammaInitializer`: default 'zeros' / 'ones'.
   *   - `betaRegularizer` / `gammaRegularizer`: optional.
   * @throws Error if `axis` is not an integer or an array of integers.
   */
  function LayerNormalization(args) {
    var _this3;
    _classCallCheck(this, LayerNormalization);
    if (args == null) {
      args = {};
    }
    _this3 = _super2.call(this, args);
    var axis = args.axis == null ? -1 : args.axis;
    if (typeof axis === 'number') {
      if (!Number.isInteger(axis)) {
        throw new Error("Expected axis to be an integer, but received ".concat(axis));
      }
    } else if (Array.isArray(axis)) {
      for (var i = 0; i < axis.length; ++i) {
        if (!Number.isInteger(axis[i])) {
          throw new Error("Expected axis to be an array of integers, " + "but received ".concat(JSON.stringify(axis)));
        }
      }
      // Keep a private copy so later normalization in `build` can never
      // leak into an array owned by the caller.
      axis = axis.slice();
    } else {
      throw new Error("Expected axis to be an integer or an array of integers, " + "but received ".concat(JSON.stringify(axis)));
    }
    _this3.axis = axis;
    _this3.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
    _this3.center = args.center == null ? true : args.center;
    _this3.scale = args.scale == null ? true : args.scale;
    _this3.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
    _this3.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
    _this3.betaRegularizer = getRegularizer(args.betaRegularizer);
    _this3.gammaRegularizer = getRegularizer(args.gammaRegularizer);
    _this3.supportsMasking = true;
    return _this3;
  }
  _createClass(LayerNormalization, [{
    key: "build",
    /**
     * Resolves `axis` into an array of non-negative, in-range, unique
     * indices, then creates `gamma` / `beta` (float32) with the input's sizes
     * along those axes.
     * @throws Error on an out-of-range or duplicated axis.
     */
    value: function build(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      var nDims = inputShape.length;
      // Convert axis to an array and resolve negatives. A fresh array is
      // produced rather than rewriting entries in place.
      var requestedAxes = typeof this.axis === 'number' ? [this.axis] : this.axis;
      this.axis = requestedAxes.map(function (axis) {
        return axis < 0 ? axis + nDims : axis;
      });
      // Further validate axes.
      for (var i = 0; i < this.axis.length; ++i) {
        if (this.axis[i] < 0 || this.axis[i] >= nDims) {
          throw new Error("Invalid axis: ".concat(this.axis[i]));
        }
      }
      if (this.axis.length !== unique$2(this.axis).length) {
        throw new Error("Found duplicate axes in: ".concat(this.axis));
      }
      var paramShape = this.axis.map(function (axis) {
        return inputShape[axis];
      });
      var trainable = true;
      if (this.scale) {
        this.gamma = this.addWeight('gamma', paramShape, 'float32', this.gammaInitializer, this.gammaRegularizer, trainable);
      } else {
        this.gamma = null;
      }
      if (this.center) {
        this.beta = this.addWeight('beta', paramShape, 'float32', this.betaInitializer, this.betaRegularizer, trainable);
      } else {
        this.beta = null;
      }
      this.built = true;
    }
  }, {
    key: "call",
    /**
     * Normalizes the input with its own mean/variance over `this.axis`
     * (kept dims), then applies the optional scale and offset.
     */
    value: function call(inputs, kwargs) {
      var _this4 = this;
      var input = getExactlyOneTensor(inputs);
      var inputShape = input.shape;
      var nDims = inputShape.length;
      return tidy(function () {
        var keepDims = true;
        var _moments = moments(input, _this4.axis, keepDims);
        var mean = _moments.mean;
        var variance = _moments.variance;
        // 1 everywhere except the normalized axes, which keep the input size.
        var broadcastShape = pyListRepeat(1, nDims);
        _this4.axis.forEach(function (dim) {
          broadcastShape[dim] = inputShape[dim];
        });
        var broadcast = function broadcast(v) {
          if (v != null && v.shape.length !== nDims) {
            return reshape$3(v, broadcastShape);
          } else {
            return v;
          }
        };
        var scale = _this4.scale ? broadcast(_this4.gamma.read()) : null;
        var offset = _this4.center ? broadcast(_this4.beta.read()) : null;
        // TODO(https://github.com/tensorflow/tfjs/issues/2120): The tiling below
        // is a workaround for the limitation that core's batchNormalization?d
        // ops do not support broadcasting in their gradients. In addition, the
        // tiling is necessary to ensure correctness on the browser CPU backend
        // regardless of forward or backward computation. Remove this workaround
        // once the limitation is addressed (see the issue linked above).
        var momentsTiling = [];
        var scaleOffsetTiling = [];
        for (var i = 0; i < nDims; ++i) {
          if (_this4.axis.indexOf(i) !== -1) {
            momentsTiling.push(inputShape[i]);
            scaleOffsetTiling.push(1);
          } else {
            momentsTiling.push(1);
            scaleOffsetTiling.push(inputShape[i]);
          }
        }
        mean = tile$3(mean, momentsTiling);
        variance = tile$3(variance, momentsTiling);
        if (scale != null) {
          scale = tile$3(scale, scaleOffsetTiling);
        }
        if (offset != null) {
          offset = tile$3(offset, scaleOffsetTiling);
        }
        return batchNormalization$1(input, mean, variance, offset, scale, _this4.epsilon);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        // Copy so callers cannot mutate the layer's axes through the config.
        axis: Array.isArray(this.axis) ? this.axis.slice() : this.axis,
        epsilon: this.epsilon,
        center: this.center,
        scale: this.scale,
        betaInitializer: serializeInitializer(this.betaInitializer),
        gammaInitializer: serializeInitializer(this.gammaInitializer),
        betaRegularizer: serializeRegularizer(this.betaRegularizer),
        gammaRegularizer: serializeRegularizer(this.gammaRegularizer)
      };
      var baseConfig = _get(_getPrototypeOf(LayerNormalization.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return LayerNormalization;
}(Layer);
/** @nocollapse */
LayerNormalization.className = 'LayerNormalization';
registerClass(LayerNormalization);
76988
/**
 * Pads dimension 1 (the middle dimension) of a 3D tensor.
 *
 * @param x Input `tf.Tensor`; must be rank 3.
 * @param padding `[before, after]`: how many zeros to add at the start and end
 *   of dimension 1. Defaults to `[1, 1]`.
 * @return The padded 3D `tf.Tensor`.
 * @throws ValueError if `x` is not 3-D or `padding` does not have length 2.
 */
function temporalPadding(x, padding) {
  return tidy(function () {
    if (x.rank !== 3) {
      throw new ValueError("temporalPadding expects input tensor to be 3-D, but received a " + "".concat(x.rank, "-D tensor."));
    }
    var amounts = padding == null ? [1, 1] : padding;
    if (amounts.length !== 2) {
      throw new ValueError("temporalPadding expects input padding pattern to be a length-2 " + "array, but received a length-".concat(amounts.length, " array."));
    }
    // Leave the batch (0) and last (2) dimensions unpadded.
    return pad(x, [[0, 0], amounts, [0, 0]]);
  });
}
/**
 * Pads the two spatial (height and width) dimensions of a 4D tensor. These are
 * dimensions 1 and 2 for 'channelsLast', and dimensions 2 and 3 for
 * 'channelsFirst'.
 *
 * @param x Input `tf.Tensor` to be padded; must be rank 4.
 * @param padding `Array` of two `Array`s, each of which is an `Array` of two
 *   integers: the amount of padding at the beginning and end of the height
 *   and width dimensions, respectively. Defaults to `[[1, 1], [1, 1]]`.
 * @param dataFormat 'channelsLast' or 'channelsFirst'; defaults to
 *   `imageDataFormat()`.
 * @return Padded 4D `tf.Tensor`.
 * @throws ValueError if `x` is not 4-D, `padding` is malformed, or
 *   `dataFormat` is not recognized.
 */
function spatial2dPadding(x, padding, dataFormat) {
  return tidy(function () {
    if (x.rank !== 4) {
      throw new ValueError("spatial2dPadding expects input tensor to be 4-D, but received a " + "".concat(x.rank, "-D tensor."));
    }
    var spatialPadding = padding == null ? [[1, 1], [1, 1]] : padding;
    // Guard against null rows so malformed padding yields a ValueError, not a
    // TypeError from reading `.length` of null.
    if (spatialPadding.length !== 2 || spatialPadding[0] == null || spatialPadding[0].length !== 2 || spatialPadding[1] == null || spatialPadding[1].length !== 2) {
      throw new ValueError('spatial2dPadding expects `padding` to be an Array of two Arrays, ' + 'each of which is an Array of two integers.');
    }
    var format = dataFormat == null ? imageDataFormat() : dataFormat;
    if (format !== 'channelsLast' && format !== 'channelsFirst') {
      throw new ValueError("Unknown data format: ".concat(format, ". ") + "Supported data formats are 'channelsLast' and 'channelsFirst'.");
    }
    var pattern = format === 'channelsFirst' ? [[0, 0], [0, 0], spatialPadding[0], spatialPadding[1]] : [[0, 0], spatialPadding[0], spatialPadding[1], [0, 0]];
    return pad(x, pattern);
  });
}
var ZeroPadding2D = /*#__PURE__*/function (_Layer) {
  _inherits(ZeroPadding2D, _Layer);
  var _super = _createSuper(ZeroPadding2D);
  /**
   * Normalizes the `padding` argument into `[[top, bottom], [left, right]]`.
   * It accepts four forms:
   *   - null, which gives the default `[[1, 1], [1, 1]]`;
   *   - a single number;
   *   - `[heightPad, widthPad]`;
   *   - `[[top, bottom], [left, right]]`.
   * It always returns fresh arrays, so the layer never aliases (or writes
   * into) anything owned by the caller.
   * @throws ValueError on a wrongly sized array.
   */
  function normalizePadding(padding) {
    if (padding == null) {
      return [[1, 1], [1, 1]];
    }
    if (typeof padding === 'number') {
      return [[padding, padding], [padding, padding]];
    }
    if (padding.length !== 2) {
      throw new ValueError("ZeroPadding2D expects padding to be a length-2 array, but " + "received a length-".concat(padding.length, " array."));
    }
    if (typeof padding[0] === 'number') {
      return [[padding[0], padding[0]], [padding[1], padding[1]]];
    }
    if (padding[0].length !== 2) {
      throw new ValueError("ZeroPadding2D expects height padding to be a length-2 array, " + "but received a length-".concat(padding[0].length, " array."));
    }
    if (padding[1].length !== 2) {
      throw new ValueError("ZeroPadding2D expects width padding to be a length-2 array, " + "but received a length-".concat(padding[1].length, " array."));
    }
    return [[padding[0][0], padding[0][1]], [padding[1][0], padding[1][1]]];
  }
  /** Size of a spatial dimension after `[before, after]` padding, or null if unknown. */
  function paddedSize(size, sidePadding) {
    return size != null && size >= 0 ? size + sidePadding[0] + sidePadding[1] : null;
  }
  /**
   * Zero-padding layer for rank-4 inputs.
   *
   * @param args Optional config. `dataFormat` defaults to `imageDataFormat()`.
   *   `padding` is any form accepted by `normalizePadding`.
   */
  function ZeroPadding2D(args) {
    var _this;
    _classCallCheck(this, ZeroPadding2D);
    if (args == null) {
      args = {};
    }
    _this = _super.call(this, args);
    _this.dataFormat = args.dataFormat == null ? imageDataFormat() : args.dataFormat;
    _this.padding = normalizePadding(args.padding);
    _this.inputSpec = [new InputSpec({
      ndim: 4
    })];
    return _this;
  }
  _createClass(ZeroPadding2D, [{
    key: "computeOutputShape",
    /** Grows the height and width dimensions by their padding; unknown sizes stay null. */
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      if (this.dataFormat === 'channelsFirst') {
        return [inputShape[0], inputShape[1], paddedSize(inputShape[2], this.padding[0]), paddedSize(inputShape[3], this.padding[1])];
      }
      return [inputShape[0], paddedSize(inputShape[1], this.padding[0]), paddedSize(inputShape[2], this.padding[1]), inputShape[3]];
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this2 = this;
      return tidy(function () {
        return spatial2dPadding(getExactlyOneTensor(inputs), _this2.padding, _this2.dataFormat);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        padding: this.padding,
        dataFormat: this.dataFormat
      };
      var baseConfig = _get(_getPrototypeOf(ZeroPadding2D.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return ZeroPadding2D;
}(Layer);
/** @nocollapse */
ZeroPadding2D.className = 'ZeroPadding2D';
registerClass(ZeroPadding2D);
77150
/**
 * 2D pooling.
 *
 * Any supplied `dataFormat`, `poolMode` and `padding` values are validated
 * before defaults are applied.
 *
 * @param x Input tensor.
 * @param poolSize Pool window size; no default is applied here.
 * @param strides Strides. Defaults to [1, 1].
 * @param padding Padding mode. Defaults to 'valid'. Any value other than
 *   'same' is passed to the pooling op as 'valid'.
 * @param dataFormat Data format. Defaults to `imageDataFormat()`.
 * @param poolMode 'max' or 'avg'. Defaults to 'max'.
 * @returns Result of the 2D pooling, in the same data format as the input.
 */
function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) {
  return tidy(function () {
    checkDataFormat(dataFormat);
    checkPoolMode(poolMode);
    checkPaddingMode(padding);
    var effectiveStrides = strides == null ? [1, 1] : strides;
    var effectivePadding = padding == null ? 'valid' : padding;
    var effectiveFormat = dataFormat == null ? imageDataFormat() : dataFormat;
    var effectiveMode = poolMode == null ? 'max' : poolMode;
    // The pooling ops work on NHWC, so convert first.
    var nhwc = preprocessConv2DInput(x, effectiveFormat);
    var opPadding = effectivePadding === 'same' ? 'same' : 'valid';
    var pooled = effectiveMode === 'max' ? maxPool$2(nhwc, poolSize, effectiveStrides, opPadding) : avgPool$2(nhwc, poolSize, effectiveStrides, opPadding);
    // Restore NCHW when the caller asked for channels-first.
    return effectiveFormat === 'channelsFirst' ? transpose$2(pooled, [0, 3, 1, 2]) : pooled;
  });
}
/**
 * 3D pooling.
 *
 * Any supplied `dataFormat`, `poolMode` and `padding` values are validated
 * before defaults are applied.
 *
 * @param x Input tensor.
 * @param poolSize Pool window size `[depth, rows, cols]`; no default is
 *   applied here.
 * @param strides Strides. Defaults to [1, 1, 1].
 * @param padding Padding mode. Defaults to 'valid'. Any value other than
 *   'same' is passed to the pooling op as 'valid'.
 * @param dataFormat Data format. Defaults to `imageDataFormat()`.
 * @param poolMode 'max' or 'avg'. Defaults to 'max'.
 * @returns Result of the 3D pooling, in the same data format as the input.
 */
function pool3d$1(x, poolSize, strides, padding, dataFormat, poolMode) {
  return tidy(function () {
    checkDataFormat(dataFormat);
    checkPoolMode(poolMode);
    checkPaddingMode(padding);
    var effectiveStrides = strides == null ? [1, 1, 1] : strides;
    var effectivePadding = padding == null ? 'valid' : padding;
    var effectiveFormat = dataFormat == null ? imageDataFormat() : dataFormat;
    var effectiveMode = poolMode == null ? 'max' : poolMode;
    // The pooling ops work on NDHWC, so convert first.
    var ndhwc = preprocessConv3DInput(x, effectiveFormat);
    var opPadding = effectivePadding === 'same' ? 'same' : 'valid';
    var pooled = effectiveMode === 'max' ? maxPool3d$1(ndhwc, poolSize, effectiveStrides, opPadding) : avgPool3d$1(ndhwc, poolSize, effectiveStrides, opPadding);
    // Restore NCDHW when the caller asked for channels-first.
    return effectiveFormat === 'channelsFirst' ? transpose$2(pooled, [0, 4, 1, 2, 3]) : pooled;
  });
}
/**
 * Abstract class for different pooling 1D layers.
 *
 * Subclasses provide `poolingFunction(inputs, poolSize, strides, padding,
 * dataFormat)`. `call` reuses it as a 2D pooling over a rank-4 view of the
 * rank-3 input.
 */
var Pooling1D = /*#__PURE__*/function (_Layer) {
  _inherits(Pooling1D, _Layer);
  var _super = _createSuper(Pooling1D);
  /**
   * @param args Parameters for the Pooling layer:
   *   - `poolSize`: a number or a single-element array; defaults to 2.
   *   - `strides`: same forms; defaults to `poolSize`.
   *   - `padding`: defaults to 'valid'.
   *   The caller's `args` object is not modified.
   * @throws ValueError if `poolSize` or `strides` has an unsupported form.
   */
  function Pooling1D(args) {
    var _this;
    _classCallCheck(this, Pooling1D);
    if (args == null) {
      args = {};
    }
    // Resolve the default locally instead of writing it back into `args`.
    var poolSize = args.poolSize == null ? 2 : args.poolSize;
    _this = _super.call(this, args);
    if (typeof poolSize === 'number') {
      _this.poolSize = [poolSize];
    } else if (Array.isArray(poolSize) && poolSize.length === 1 && typeof poolSize[0] === 'number') {
      _this.poolSize = poolSize;
    } else {
      throw new ValueError("poolSize for 1D convolutional layer must be a number or an " + "Array of a single number, but received " + "".concat(JSON.stringify(poolSize)));
    }
    assertPositiveInteger(_this.poolSize, 'poolSize');
    if (args.strides == null) {
      _this.strides = _this.poolSize;
    } else {
      if (typeof args.strides === 'number') {
        _this.strides = [args.strides];
      } else if (Array.isArray(args.strides) && args.strides.length === 1 && typeof args.strides[0] === 'number') {
        _this.strides = args.strides;
      } else {
        throw new ValueError("strides for 1D convolutional layer must be a number or an " + "Array of a single number, but received " + "".concat(JSON.stringify(args.strides)));
      }
    }
    assertPositiveInteger(_this.strides, 'strides');
    _this.padding = args.padding == null ? 'valid' : args.padding;
    checkPaddingMode(_this.padding);
    _this.inputSpec = [new InputSpec({
      ndim: 3
    })];
    return _this;
  }
  _createClass(Pooling1D, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      var length = convOutputLength(inputShape[1], this.poolSize[0], this.padding, this.strides[0]);
      return [inputShape[0], length, inputShape[2]];
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this2 = this;
      return tidy(function () {
        _this2.invokeCallHook(inputs, kwargs);
        // Insert a size-1 axis at position 2 so the rank-3 input can be pooled
        // by the 2D ('channelsLast') pooling function with a window of width 1.
        var expanded = expandDims$2(getExactlyOneTensor(inputs), 2);
        var output = _this2.poolingFunction(expanded, [_this2.poolSize[0], 1], [_this2.strides[0], 1], _this2.padding, 'channelsLast');
        // Remove the inserted axis again.
        return squeeze(output, [2]);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        poolSize: this.poolSize,
        padding: this.padding,
        strides: this.strides
      };
      var baseConfig = _get(_getPrototypeOf(Pooling1D.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Pooling1D;
}(Layer);
var MaxPooling1D = /*#__PURE__*/function (Base) {
  _inherits(MaxPooling1D, Base);
  var callSuper = _createSuper(MaxPooling1D);
  /** 1D max pooling; all configuration handling lives in `Pooling1D`. */
  function MaxPooling1D(args) {
    _classCallCheck(this, MaxPooling1D);
    return callSuper.call(this, args);
  }
  _createClass(MaxPooling1D, [{
    key: "poolingFunction",
    /** Validates the format and padding, then delegates to `pool2d` in 'max' mode. */
    value: function poolingFunction(x, size, stride, paddingMode, format) {
      checkDataFormat(format);
      checkPaddingMode(paddingMode);
      return pool2d(x, size, stride, paddingMode, format, 'max');
    }
  }]);
  return MaxPooling1D;
}(Pooling1D);
/** @nocollapse */
MaxPooling1D.className = 'MaxPooling1D';
registerClass(MaxPooling1D);
var AveragePooling1D = /*#__PURE__*/function (Base) {
  _inherits(AveragePooling1D, Base);
  var callSuper = _createSuper(AveragePooling1D);
  /** 1D average pooling; all configuration handling lives in `Pooling1D`. */
  function AveragePooling1D(args) {
    _classCallCheck(this, AveragePooling1D);
    return callSuper.call(this, args);
  }
  _createClass(AveragePooling1D, [{
    key: "poolingFunction",
    /** Validates the format and padding, then delegates to `pool2d` in 'avg' mode. */
    value: function poolingFunction(x, size, stride, paddingMode, format) {
      checkDataFormat(format);
      checkPaddingMode(paddingMode);
      return pool2d(x, size, stride, paddingMode, format, 'avg');
    }
  }]);
  return AveragePooling1D;
}(Pooling1D);
/** @nocollapse */
AveragePooling1D.className = 'AveragePooling1D';
registerClass(AveragePooling1D);
/**
 * Abstract class for different pooling 2D layers.
 *
 * Subclasses provide `poolingFunction(inputs, poolSize, strides, padding,
 * dataFormat)`.
 */
var Pooling2D = /*#__PURE__*/function (_Layer2) {
  _inherits(Pooling2D, _Layer2);
  var _super4 = _createSuper(Pooling2D);
  /**
   * @param args Layer config:
   *   - `poolSize`: a number or `[rows, cols]`; defaults to `[2, 2]`.
   *   - `strides`: a number or a length-2 array; defaults to `poolSize`.
   *   - `padding`: defaults to 'valid'.
   *   - `dataFormat`: defaults to 'channelsLast'.
   *   `poolSize` and `strides` are validated with `assertPositiveInteger`.
   *   The caller's `args` object is not modified.
   * @throws ValueError if `strides` is an array whose length is not 2.
   */
  function Pooling2D(args) {
    var _this3;
    _classCallCheck(this, Pooling2D);
    if (args == null) {
      args = {};
    }
    // Resolve the default locally instead of writing it back into `args`.
    var poolSize = args.poolSize == null ? [2, 2] : args.poolSize;
    _this3 = _super4.call(this, args);
    _this3.poolSize = Array.isArray(poolSize) ? poolSize : [poolSize, poolSize];
    if (args.strides == null) {
      _this3.strides = _this3.poolSize;
    } else if (Array.isArray(args.strides)) {
      if (args.strides.length !== 2) {
        throw new ValueError("If the strides property of a 2D pooling layer is an Array, " + "it is expected to have a length of 2, but received length " + "".concat(args.strides.length, "."));
      }
      _this3.strides = args.strides;
    } else {
      // `config.strides` is a number.
      _this3.strides = [args.strides, args.strides];
    }
    assertPositiveInteger(_this3.poolSize, 'poolSize');
    assertPositiveInteger(_this3.strides, 'strides');
    _this3.padding = args.padding == null ? 'valid' : args.padding;
    _this3.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
    checkDataFormat(_this3.dataFormat);
    checkPaddingMode(_this3.padding);
    _this3.inputSpec = [new InputSpec({
      ndim: 4
    })];
    return _this3;
  }
  _createClass(Pooling2D, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      var rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
      var cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
      rows = convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]);
      cols = convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]);
      if (this.dataFormat === 'channelsFirst') {
        return [inputShape[0], inputShape[1], rows, cols];
      } else {
        return [inputShape[0], rows, cols, inputShape[3]];
      }
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this4 = this;
      return tidy(function () {
        _this4.invokeCallHook(inputs, kwargs);
        return _this4.poolingFunction(getExactlyOneTensor(inputs), _this4.poolSize, _this4.strides, _this4.padding, _this4.dataFormat);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        poolSize: this.poolSize,
        padding: this.padding,
        strides: this.strides,
        dataFormat: this.dataFormat
      };
      var baseConfig = _get(_getPrototypeOf(Pooling2D.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Pooling2D;
}(Layer);
var MaxPooling2D = /*#__PURE__*/function (Base) {
  _inherits(MaxPooling2D, Base);
  var callSuper = _createSuper(MaxPooling2D);
  /** 2D max pooling; all configuration handling lives in `Pooling2D`. */
  function MaxPooling2D(args) {
    _classCallCheck(this, MaxPooling2D);
    return callSuper.call(this, args);
  }
  _createClass(MaxPooling2D, [{
    key: "poolingFunction",
    /** Validates the format and padding, then delegates to `pool2d` in 'max' mode. */
    value: function poolingFunction(x, size, stride, paddingMode, format) {
      checkDataFormat(format);
      checkPaddingMode(paddingMode);
      return pool2d(x, size, stride, paddingMode, format, 'max');
    }
  }]);
  return MaxPooling2D;
}(Pooling2D);
/** @nocollapse */
MaxPooling2D.className = 'MaxPooling2D';
registerClass(MaxPooling2D);
var AveragePooling2D = /*#__PURE__*/function (Base) {
  _inherits(AveragePooling2D, Base);
  var callSuper = _createSuper(AveragePooling2D);
  /** 2D average pooling; all configuration handling lives in `Pooling2D`. */
  function AveragePooling2D(args) {
    _classCallCheck(this, AveragePooling2D);
    return callSuper.call(this, args);
  }
  _createClass(AveragePooling2D, [{
    key: "poolingFunction",
    /** Validates the format and padding, then delegates to `pool2d` in 'avg' mode. */
    value: function poolingFunction(x, size, stride, paddingMode, format) {
      checkDataFormat(format);
      checkPaddingMode(paddingMode);
      return pool2d(x, size, stride, paddingMode, format, 'avg');
    }
  }]);
  return AveragePooling2D;
}(Pooling2D);
/** @nocollapse */
AveragePooling2D.className = 'AveragePooling2D';
registerClass(AveragePooling2D);
/**
 * Abstract class for different pooling 3D layers.
 *
 * Subclasses provide `poolingFunction(inputs, poolSize, strides, padding,
 * dataFormat)`.
 */
var Pooling3D = /*#__PURE__*/function (_Layer3) {
  _inherits(Pooling3D, _Layer3);
  var _super7 = _createSuper(Pooling3D);
  /**
   * @param args Layer config:
   *   - `poolSize`: a number or `[depth, rows, cols]`; defaults to `[2, 2, 2]`.
   *   - `strides`: a number or a length-3 array; defaults to `poolSize`.
   *   - `padding`: defaults to 'valid'.
   *   - `dataFormat`: defaults to 'channelsLast'.
   *   `poolSize` and `strides` are validated with `assertPositiveInteger`.
   *   The caller's `args` object is not modified.
   * @throws ValueError if `strides` is an array whose length is not 3.
   */
  function Pooling3D(args) {
    var _this5;
    _classCallCheck(this, Pooling3D);
    if (args == null) {
      args = {};
    }
    // Resolve the default locally instead of writing it back into `args`.
    var poolSize = args.poolSize == null ? [2, 2, 2] : args.poolSize;
    _this5 = _super7.call(this, args);
    _this5.poolSize = Array.isArray(poolSize) ? poolSize : [poolSize, poolSize, poolSize];
    if (args.strides == null) {
      _this5.strides = _this5.poolSize;
    } else if (Array.isArray(args.strides)) {
      if (args.strides.length !== 3) {
        throw new ValueError("If the strides property of a 3D pooling layer is an Array, " + "it is expected to have a length of 3, but received length " + "".concat(args.strides.length, "."));
      }
      _this5.strides = args.strides;
    } else {
      // `config.strides` is a number.
      _this5.strides = [args.strides, args.strides, args.strides];
    }
    assertPositiveInteger(_this5.poolSize, 'poolSize');
    assertPositiveInteger(_this5.strides, 'strides');
    _this5.padding = args.padding == null ? 'valid' : args.padding;
    _this5.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
    checkDataFormat(_this5.dataFormat);
    checkPaddingMode(_this5.padding);
    _this5.inputSpec = [new InputSpec({
      ndim: 5
    })];
    return _this5;
  }
  _createClass(Pooling3D, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      var depths = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
      var rows = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
      var cols = this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
      depths = convOutputLength(depths, this.poolSize[0], this.padding, this.strides[0]);
      rows = convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
      cols = convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
      if (this.dataFormat === 'channelsFirst') {
        return [inputShape[0], inputShape[1], depths, rows, cols];
      } else {
        return [inputShape[0], depths, rows, cols, inputShape[4]];
      }
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this6 = this;
      return tidy(function () {
        _this6.invokeCallHook(inputs, kwargs);
        return _this6.poolingFunction(getExactlyOneTensor(inputs), _this6.poolSize, _this6.strides, _this6.padding, _this6.dataFormat);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        poolSize: this.poolSize,
        padding: this.padding,
        strides: this.strides,
        dataFormat: this.dataFormat
      };
      var baseConfig = _get(_getPrototypeOf(Pooling3D.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return Pooling3D;
}(Layer);
var MaxPooling3D = /*#__PURE__*/function (Base) {
  _inherits(MaxPooling3D, Base);
  var callSuper = _createSuper(MaxPooling3D);
  /** 3D max pooling; all configuration handling lives in `Pooling3D`. */
  function MaxPooling3D(args) {
    _classCallCheck(this, MaxPooling3D);
    return callSuper.call(this, args);
  }
  _createClass(MaxPooling3D, [{
    key: "poolingFunction",
    /** Validates the format and padding, then delegates to `pool3d$1` in 'max' mode. */
    value: function poolingFunction(x, size, stride, paddingMode, format) {
      checkDataFormat(format);
      checkPaddingMode(paddingMode);
      return pool3d$1(x, size, stride, paddingMode, format, 'max');
    }
  }]);
  return MaxPooling3D;
}(Pooling3D);
/** @nocollapse */
MaxPooling3D.className = 'MaxPooling3D';
registerClass(MaxPooling3D);
var AveragePooling3D = /*#__PURE__*/function (Base) {
  _inherits(AveragePooling3D, Base);
  var superCtor = _createSuper(AveragePooling3D);

  /**
   * 3D average-pooling layer. The constructor simply forwards `args` to
   * `Pooling3D`; this subclass only supplies the reduction used by `call`.
   */
  function AveragePooling3D(args) {
    _classCallCheck(this, AveragePooling3D);
    return superCtor.call(this, args);
  }

  _createClass(AveragePooling3D, [{
    key: "poolingFunction",
    // Validates data format and padding, then pools with an 'avg' reduction.
    value: function poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
      var reduction = 'avg';
      checkDataFormat(dataFormat);
      checkPaddingMode(padding);
      return pool3d$1(inputs, poolSize, strides, padding, dataFormat, reduction);
    }
  }]);

  return AveragePooling3D;
}(Pooling3D);
/** @nocollapse */
AveragePooling3D.className = 'AveragePooling3D';
registerClass(AveragePooling3D);
/**
 * Abstract base class for the global 1D pooling layers.
 *
 * Only rank-3 inputs are accepted. Concrete subclasses must override `call`;
 * the base implementation throws NotImplementedError.
 */
var GlobalPooling1D = /*#__PURE__*/function (BaseLayer) {
  _inherits(GlobalPooling1D, BaseLayer);
  var superCtor = _createSuper(GlobalPooling1D);

  function GlobalPooling1D(args) {
    _classCallCheck(this, GlobalPooling1D);
    var instance = superCtor.call(this, args);
    instance.inputSpec = [new InputSpec({
      ndim: 3
    })];
    return instance;
  }

  _createClass(GlobalPooling1D, [{
    key: "computeOutputShape",
    // Axis 1 is pooled away: the result keeps only axes 0 and 2.
    value: function computeOutputShape(inputShape) {
      var leading = inputShape[0];
      var trailing = inputShape[2];
      return [leading, trailing];
    }
  }, {
    key: "call",
    // Abstract; subclasses provide the actual reduction.
    value: function call(inputs, kwargs) {
      throw new NotImplementedError();
    }
  }]);

  return GlobalPooling1D;
}(Layer);
var GlobalAveragePooling1D = /*#__PURE__*/function (Base) {
  _inherits(GlobalAveragePooling1D, Base);
  var superCtor = _createSuper(GlobalAveragePooling1D);

  // The config is optional for this layer: a falsy `args` becomes `{}`.
  function GlobalAveragePooling1D(args) {
    _classCallCheck(this, GlobalAveragePooling1D);
    return superCtor.call(this, args || {});
  }

  _createClass(GlobalAveragePooling1D, [{
    key: "call",
    // Averages the single input tensor over axis 1.
    value: function call(inputs, kwargs) {
      return tidy(function () {
        return mean$3(getExactlyOneTensor(inputs), 1);
      });
    }
  }]);

  return GlobalAveragePooling1D;
}(GlobalPooling1D);
/** @nocollapse */
GlobalAveragePooling1D.className = 'GlobalAveragePooling1D';
registerClass(GlobalAveragePooling1D);
var GlobalMaxPooling1D = /*#__PURE__*/function (Base) {
  _inherits(GlobalMaxPooling1D, Base);
  var superCtor = _createSuper(GlobalMaxPooling1D);

  // The config is optional for this layer: a falsy `args` becomes `{}`.
  function GlobalMaxPooling1D(args) {
    _classCallCheck(this, GlobalMaxPooling1D);
    return superCtor.call(this, args || {});
  }

  _createClass(GlobalMaxPooling1D, [{
    key: "call",
    // Takes the maximum of the single input tensor over axis 1.
    value: function call(inputs, kwargs) {
      return tidy(function () {
        return max$3(getExactlyOneTensor(inputs), 1);
      });
    }
  }]);

  return GlobalMaxPooling1D;
}(GlobalPooling1D);
/** @nocollapse */
GlobalMaxPooling1D.className = 'GlobalMaxPooling1D';
registerClass(GlobalMaxPooling1D);
/**
 * Abstract class for different global pooling 2D layers.
 *
 * Accepts rank-4 inputs. `dataFormat` defaults to 'channelsLast' and selects
 * which axis holds the channels in the output shape. Concrete subclasses must
 * override `call`; the base implementation throws NotImplementedError.
 */
var GlobalPooling2D = /*#__PURE__*/function (_Layer5) {
  _inherits(GlobalPooling2D, _Layer5);
  var _super13 = _createSuper(GlobalPooling2D);
  function GlobalPooling2D(args) {
    var _this8;
    _classCallCheck(this, GlobalPooling2D);
    _this8 = _super13.call(this, args);
    // Subclasses forward `arguments` verbatim, so `args` may be missing when
    // the layer is built without a config; fall back to the default format
    // instead of throwing a TypeError on `args.dataFormat`.
    var requestedFormat = args == null ? null : args.dataFormat;
    _this8.dataFormat = requestedFormat == null ? 'channelsLast' : requestedFormat;
    checkDataFormat(_this8.dataFormat);
    _this8.inputSpec = [new InputSpec({
      ndim: 4
    })];
    return _this8;
  }
  _createClass(GlobalPooling2D, [{
    key: "computeOutputShape",
    // Drops both non-channel, non-batch axes: returns [batch, channels], with
    // channels read from axis 3 (channelsLast) or axis 1 (channelsFirst).
    value: function computeOutputShape(inputShape) {
      // Accept a single shape or a one-element list of shapes, like the
      // other pooling layers do.
      var shape = getExactlyOneShape(inputShape);
      if (this.dataFormat === 'channelsLast') {
        return [shape[0], shape[3]];
      }
      return [shape[0], shape[1]];
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      throw new NotImplementedError();
    }
  }, {
    key: "getConfig",
    // Serializes `dataFormat` on top of the base Layer config.
    value: function getConfig() {
      var config = {
        dataFormat: this.dataFormat
      };
      var baseConfig = _get(_getPrototypeOf(GlobalPooling2D.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }]);
  return GlobalPooling2D;
}(Layer);
var GlobalAveragePooling2D = /*#__PURE__*/function (Base) {
  _inherits(GlobalAveragePooling2D, Base);
  var superCtor = _createSuper(GlobalAveragePooling2D);

  // Every constructor argument is forwarded untouched to GlobalPooling2D.
  function GlobalAveragePooling2D() {
    _classCallCheck(this, GlobalAveragePooling2D);
    return superCtor.apply(this, arguments);
  }

  _createClass(GlobalAveragePooling2D, [{
    key: "call",
    // Averages over the two spatial axes: [1, 2] for channelsLast,
    // [2, 3] otherwise.
    value: function call(inputs, kwargs) {
      var spatialAxes = this.dataFormat === 'channelsLast' ? [1, 2] : [2, 3];
      return tidy(function () {
        return mean$3(getExactlyOneTensor(inputs), spatialAxes);
      });
    }
  }]);

  return GlobalAveragePooling2D;
}(GlobalPooling2D);
/** @nocollapse */
GlobalAveragePooling2D.className = 'GlobalAveragePooling2D';
registerClass(GlobalAveragePooling2D);
var GlobalMaxPooling2D = /*#__PURE__*/function (Base) {
  _inherits(GlobalMaxPooling2D, Base);
  var superCtor = _createSuper(GlobalMaxPooling2D);

  // Every constructor argument is forwarded untouched to GlobalPooling2D.
  function GlobalMaxPooling2D() {
    _classCallCheck(this, GlobalMaxPooling2D);
    return superCtor.apply(this, arguments);
  }

  _createClass(GlobalMaxPooling2D, [{
    key: "call",
    // Takes the maximum over the two spatial axes: [1, 2] for channelsLast,
    // [2, 3] otherwise.
    value: function call(inputs, kwargs) {
      var spatialAxes = this.dataFormat === 'channelsLast' ? [1, 2] : [2, 3];
      return tidy(function () {
        return max$3(getExactlyOneTensor(inputs), spatialAxes);
      });
    }
  }]);

  return GlobalMaxPooling2D;
}(GlobalPooling2D);
/** @nocollapse */
GlobalMaxPooling2D.className = 'GlobalMaxPooling2D';
registerClass(GlobalMaxPooling2D);
77765
/**
 * Abstract wrapper base class.
 *
 * Wrappers take another layer and augment it in various ways.
 * Do not use this class as a layer; it is only an abstract base class.
 * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
 *
 * The wrapped layer is stored in `this.layer`; weight, update and loss
 * accessors delegate to it.
 */
var Wrapper = /*#__PURE__*/function (_Layer) {
  _inherits(Wrapper, _Layer);
  var _super = _createSuper(Wrapper);
  function Wrapper(args) {
    var _this;
    _classCallCheck(this, Wrapper);
    // Porting Note: In PyKeras, `self.layer` is set prior to the calling
    // `super()`. But we can't do that here due to TypeScript's restriction.
    // See: https://github.com/Microsoft/TypeScript/issues/8277
    // As a result, we have to add checks in `get trainable()` and
    // `set trainable()` below in order to prevent using `this.layer` when
    // its value is `undefined`. The super constructor does use the getter
    // and the setter of `this.layer`.
    _this = _super.call(this, args);
    _this.layer = args.layer;
    return _this;
  }
  _createClass(Wrapper, [{
    key: "build",
    // Only marks the wrapper itself as built; the wrapped layer is untouched.
    value: function build(inputShape) {
      this.built = true;
    }
    // TODO(cais): Implement activityRegularizer getter.
  }, {
    key: "trainable",
    // Delegates to the wrapped layer. `this.layer` may still be undefined
    // while the super constructor runs (see Porting Note above): the getter
    // then reports false and the setter does nothing.
    get: function get() {
      if (this.layer != null) {
        return this.layer.trainable;
      } else {
        return false;
      }
    },
    set: function set(value) {
      if (this.layer != null) {
        this.layer.trainable = value;
      }
    }
  }, {
    key: "trainableWeights",
    get: function get() {
      return this.layer.trainableWeights;
    }
    // TODO(cais): Implement setter for trainableWeights.
  }, {
    key: "nonTrainableWeights",
    get: function get() {
      return this.layer.nonTrainableWeights;
    }
    // TODO(cais): Implement setter for nonTrainableWeights.
  }, {
    key: "updates",
    get: function get() {
      // tslint:disable-next-line:no-any
      return this.layer._updates;
    }
    // TODO(cais): Implement getUpdatesFor().
  }, {
    key: "losses",
    get: function get() {
      return this.layer.losses;
    }
    // TODO(cais): Implement getLossesFor().
  }, {
    key: "getWeights",
    value: function getWeights() {
      return this.layer.getWeights();
    }
  }, {
    key: "setWeights",
    value: function setWeights(weights) {
      this.layer.setWeights(weights);
    }
  }, {
    key: "getConfig",
    // Serializes the wrapped layer as a nested {className, config} entry
    // under 'layer', merged with the base Layer config.
    value: function getConfig() {
      var config = {
        'layer': {
          'className': this.layer.getClassName(),
          'config': this.layer.getConfig()
        }
      };
      var baseConfig = _get(_getPrototypeOf(Wrapper.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }, {
    key: "setFastWeightInitDuringBuild",
    // Applies the flag to the wrapper and, if present, the wrapped layer.
    value: function setFastWeightInitDuringBuild(value) {
      _get(_getPrototypeOf(Wrapper.prototype), "setFastWeightInitDuringBuild", this).call(this, value);
      if (this.layer != null) {
        this.layer.setFastWeightInitDuringBuild(value);
      }
    }
    /** @nocollapse */
  }], [{
    key: "fromConfig",
    /**
     * Deserializes the nested 'layer' entry of `config` (using the optional
     * third argument `customObjects`, default `{}`) and constructs `cls`
     * with the resulting layer in place of the serialized one.
     *
     * The caller's `config` is left unmodified: constructor args are built
     * on a shallow copy instead of deleting `config['layer']` in place.
     */
    value: function fromConfig(cls, config) {
      var customObjects = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
      var layer = deserialize(config['layer'], customObjects);
      var newConfig = Object.assign({}, config);
      newConfig['layer'] = layer;
      return new cls(newConfig);
    }
  }]);
  return Wrapper;
}(Layer);
/**
 * Wrapper that applies `this.layer` independently to every slice along
 * axis 1 of a rank >= 3 input.
 */
var TimeDistributed = /*#__PURE__*/function (BaseWrapper) {
  _inherits(TimeDistributed, BaseWrapper);
  var superCtor = _createSuper(TimeDistributed);

  function TimeDistributed(args) {
    _classCallCheck(this, TimeDistributed);
    var instance = superCtor.call(this, args);
    instance.supportsMasking = true;
    return instance;
  }

  _createClass(TimeDistributed, [{
    key: "build",
    // Builds the wrapped layer (once) on the per-slice shape, i.e. the input
    // shape with axis 1 removed.
    value: function build(inputShape) {
      var shape = getExactlyOneShape(inputShape);
      if (shape.length < 3) {
        throw new ValueError("TimeDistributed layer expects an input shape >= 3D, but received " + "input shape ".concat(JSON.stringify(shape)));
      }
      this.inputSpec = [{
        shape: shape
      }];
      if (!this.layer.built) {
        this.layer.build([shape[0]].concat(shape.slice(2)));
        this.layer.built = true;
      }
      _get(_getPrototypeOf(TimeDistributed.prototype), "build", this).call(this, shape);
    }
  }, {
    key: "computeOutputShape",
    // Asks the wrapped layer for its per-slice output shape, then re-inserts
    // the axis-1 length right after the leading axis.
    value: function computeOutputShape(inputShape) {
      var shape = getExactlyOneShape(inputShape);
      var timesteps = shape[1];
      var sliceShape = [shape[0]].concat(shape.slice(2));
      var sliceOutputShape = this.layer.computeOutputShape(sliceShape);
      return [sliceOutputShape[0], timesteps].concat(sliceOutputShape.slice(1));
    }
  }, {
    key: "call",
    // Inputs here are always concrete tensors, so the per-step loop is driven
    // by `rnn$1` with a stateless step function; its per-step outputs
    // (index 1 of the result) are returned.
    // TODO(cais): Add 'training', 'useLearningPhase' and activity
    // regularization.
    value: function call(inputs, kwargs) {
      var wrapper = this;
      return tidy(function () {
        var sequence = getExactlyOneTensor(inputs);
        var timeStep = function timeStep(stepInput, states) {
          // `layer.call` may yield a length-1 array (e.g. for a Sequential),
          // hence the unwrapping.
          var stepOutput = getExactlyOneTensor(wrapper.layer.call(stepInput, kwargs));
          return [stepOutput, []];
        };
        var rnnOutputs = rnn$1(timeStep, sequence, [], false /* goBackwards */, null /* mask */, null /* constants */, false /* unroll */, true /* needPerStepOutputs */);
        return rnnOutputs[1];
      });
    }
  }]);

  return TimeDistributed;
}(Wrapper);
/** @nocollapse */
TimeDistributed.className = 'TimeDistributed';
registerClass(TimeDistributed);
/**
 * Validates `value` against VALID_BIDIRECTIONAL_MERGE_MODES by delegating to
 * checkStringTypeUnionValue (presumably throwing on an unknown mode).
 */
function checkBidirectionalMergeMode(value) {
  var allowedModes = VALID_BIDIRECTIONAL_MERGE_MODES;
  checkStringTypeUnionValue(allowedModes, 'BidirectionalMergeMode', value);
}
// Merge mode used by Bidirectional when `args.mergeMode` is undefined.
var DEFAULT_BIDIRECTIONAL_MERGE_MODE = 'concat';
/**
 * Bidirectional wrapper for RNN layers.
 *
 * Holds a forward copy and a backward copy (with `goBackwards` flipped) of
 * the wrapped layer, runs both over the same input and merges their outputs
 * according to `mergeMode`: 'concat' (the default), 'sum', 'ave', 'mul', or
 * null to return the two outputs separately.
 */
var Bidirectional = /*#__PURE__*/function (_Wrapper2) {
  _inherits(Bidirectional, _Wrapper2);
  var _super3 = _createSuper(Bidirectional);
  function Bidirectional(args) {
    var _this4;
    _classCallCheck(this, Bidirectional);
    _this4 = _super3.call(this, args);
    // Note: When creating `this.forwardLayer`, the original Layer object
    // (`config.layer`) ought to be cloned. This is why we call
    // `getConfig()` followed by `deserialize()`. Without this cloning,
    // the layer names saved during serialization will incorrectly contain
    // the 'forward_' prefix. In Python Keras, this is done using
    // `copy.copy` (shallow copy), which does not have a simple equivalent
    // in JavaScript. JavaScript's `Object.assign()` does not copy
    // methods.
    var layerConfig = args.layer.getConfig();
    var forwDict = {};
    forwDict['className'] = args.layer.getClassName();
    forwDict['config'] = layerConfig;
    _this4.forwardLayer = deserialize(forwDict);
    layerConfig['goBackwards'] = layerConfig['goBackwards'] === true ? false : true;
    var backDict = {};
    backDict['className'] = args.layer.getClassName();
    backDict['config'] = layerConfig;
    _this4.backwardLayer = deserialize(backDict);
    _this4.forwardLayer.name = 'forward_' + _this4.forwardLayer.name;
    _this4.backwardLayer.name = 'backward_' + _this4.backwardLayer.name;
    _this4.mergeMode = args.mergeMode === undefined ? DEFAULT_BIDIRECTIONAL_MERGE_MODE : args.mergeMode;
    checkBidirectionalMergeMode(_this4.mergeMode);
    if (args.weights) {
      throw new NotImplementedError('weights support is not implemented for Bidirectional layer yet.');
    }
    _this4._stateful = args.layer.stateful;
    _this4.returnSequences = args.layer.returnSequences;
    _this4.returnState = args.layer.returnState;
    _this4.supportsMasking = true;
    _this4._trainable = true;
    _this4.inputSpec = args.layer.inputSpec;
    _this4.numConstants = null;
    return _this4;
  }
  _createClass(Bidirectional, [{
    key: "trainable",
    get: function get() {
      return this._trainable;
    },
    set: function set(value) {
      // The null checks are needed because the Wrapper super constructor can
      // invoke this setter before `forwardLayer`/`backwardLayer` have been
      // assigned in the Bidirectional constructor.
      this._trainable = value;
      if (this.forwardLayer != null) {
        this.forwardLayer.trainable = value;
      }
      if (this.backwardLayer != null) {
        this.backwardLayer.trainable = value;
      }
    }
  }, {
    key: "getWeights",
    // Forward-layer weights first, then backward-layer weights.
    value: function getWeights() {
      return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights());
    }
  }, {
    key: "setWeights",
    // Inverse of getWeights: the first half goes to the forward layer, the
    // remainder to the backward layer.
    value: function setWeights(weights) {
      var numWeights = weights.length;
      var numWeightsOver2 = Math.floor(numWeights / 2);
      this.forwardLayer.setWeights(weights.slice(0, numWeightsOver2));
      this.backwardLayer.setWeights(weights.slice(numWeightsOver2));
    }
  }, {
    key: "computeOutputShape",
    // Derives the output shape(s) from the forward layer's shape(s), doubling
    // the last dimension for 'concat' and duplicating the output (and state)
    // shapes where both directions are returned.
    value: function computeOutputShape(inputShape) {
      var layerShapes = this.forwardLayer.computeOutputShape(inputShape);
      if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) {
        layerShapes = [layerShapes];
      }
      var outputShape = layerShapes[0];
      var outputShapes;
      var stateShape;
      if (this.returnState) {
        stateShape = layerShapes.slice(1);
      }
      if (this.mergeMode === 'concat') {
        outputShape[outputShape.length - 1] *= 2;
        outputShapes = [outputShape];
      } else if (this.mergeMode == null) {
        outputShapes = [outputShape, outputShape.slice()];
      } else {
        outputShapes = [outputShape];
      }
      if (this.returnState) {
        if (this.mergeMode == null) {
          return outputShapes.concat(stateShape).concat(stateShape.slice());
        }
        return [outputShape].concat(stateShape).concat(stateShape.slice());
      }
      return singletonOrArray(outputShapes);
    }
  }, {
    key: "apply",
    value: function apply(inputs, kwargs) {
      var initialState = kwargs == null ? null : kwargs['initialState'];
      var constants = kwargs == null ? null : kwargs['constants'];
      if (kwargs == null) {
        kwargs = {};
      }
      var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
      inputs = standardized.inputs;
      initialState = standardized.initialState;
      constants = standardized.constants;
      if (Array.isArray(inputs)) {
        initialState = inputs.slice(1);
        inputs = inputs[0];
      }
      if ((initialState == null || initialState.length === 0) && constants == null) {
        return _get(_getPrototypeOf(Bidirectional.prototype), "apply", this).call(this, inputs, kwargs);
      }
      var additionalInputs = [];
      var additionalSpecs = [];
      if (initialState != null) {
        var numStates = initialState.length;
        if (numStates % 2 > 0) {
          throw new ValueError('When passing `initialState` to a Bidrectional RNN, ' + 'the state should be an Array containing the states of ' + 'the underlying RNNs.');
        }
        kwargs['initialState'] = initialState;
        additionalInputs.push.apply(additionalInputs, _toConsumableArray(initialState));
        var stateSpecs = initialState.map(function (state) {
          return new InputSpec({
            shape: state.shape
          });
        });
        this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2);
        this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2);
        additionalSpecs.push.apply(additionalSpecs, _toConsumableArray(stateSpecs));
      }
      if (constants != null) {
        throw new NotImplementedError('Support for constants in Bidirectional layers is not ' + 'implemented yet.');
      }
      var isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor;
      for (var _i = 0, _additionalInputs = additionalInputs; _i < _additionalInputs.length; _i++) {
        var tensor = _additionalInputs[_i];
        if (tensor instanceof SymbolicTensor !== isSymbolicTensor) {
          throw new ValueError('The initial state of a Bidirectional layer cannot be ' + 'specified as a mix of symbolic and non-symbolic tensors');
        }
      }
      if (isSymbolicTensor) {
        // Compute the full input and specs, including the states.
        var fullInput = [inputs].concat(additionalInputs);
        var fullInputSpec = this.inputSpec.concat(additionalSpecs);
        // Perform the call temporarily and replace inputSpec.
        // Note: with initial states symbolic calls and non-symbolic calls to
        // this method differ in how the initial states are passed. For
        // symbolic calls, the initial states are passed in the first arg, as
        // an Array of SymbolicTensors; for non-symbolic calls, they are
        // passed in the second arg as a part of the kwargs. Hence the need to
        // temporarily modify inputSpec here.
        // TODO(cais): Make refactoring so that this hacky code below is no
        // longer needed.
        var originalInputSpec = this.inputSpec;
        this.inputSpec = fullInputSpec;
        var output = _get(_getPrototypeOf(Bidirectional.prototype), "apply", this).call(this, fullInput, kwargs);
        this.inputSpec = originalInputSpec;
        return output;
      } else {
        return _get(_getPrototypeOf(Bidirectional.prototype), "apply", this).call(this, inputs, kwargs);
      }
    }
  }, {
    key: "call",
    // Runs both directions, reverses the backward output along axis 1 when
    // returning sequences, and merges according to `mergeMode`.
    value: function call(inputs, kwargs) {
      var _this5 = this;
      return tidy(function () {
        var initialState = kwargs['initialState'];
        var y;
        var yRev;
        if (initialState == null) {
          y = _this5.forwardLayer.call(inputs, kwargs);
          yRev = _this5.backwardLayer.call(inputs, kwargs);
        } else {
          // First half of the states belongs to the forward layer, second
          // half to the backward layer. Each direction gets its own shallow
          // copy of kwargs so the caller's object (which may be reused, e.g.
          // as stored node call args) keeps its full `initialState`.
          var half = initialState.length / 2;
          var forwardState = initialState.slice(0, half);
          var backwardState = initialState.slice(half);
          y = _this5.forwardLayer.call(inputs, Object.assign({}, kwargs, {
            initialState: forwardState
          }));
          yRev = _this5.backwardLayer.call(inputs, Object.assign({}, kwargs, {
            initialState: backwardState
          }));
        }
        var states;
        if (_this5.returnState) {
          // With returnState the sub-layers yield [output, ...states].
          if (Array.isArray(y)) {
            states = y.slice(1).concat(yRev.slice(1));
          }
          y = y[0];
          yRev = yRev[0];
        }
        if (_this5.returnSequences) {
          yRev = reverse$2(yRev, 1);
        }
        var output;
        if (_this5.mergeMode === 'concat') {
          output = concatenate$2([y, yRev]);
        } else if (_this5.mergeMode === 'sum') {
          output = add$3(y, yRev);
        } else if (_this5.mergeMode === 'ave') {
          output = mul(.5, add$3(y, yRev));
        } else if (_this5.mergeMode === 'mul') {
          output = mul(y, yRev);
        } else if (_this5.mergeMode == null) {
          output = [y, yRev];
        }
        // TODO(cais): Properly set learning phase.
        if (_this5.returnState) {
          if (_this5.mergeMode == null) {
            return output.concat(states);
          }
          return [output].concat(states);
        }
        return output;
      });
    }
  }, {
    key: "resetStates",
    value: function resetStates(states) {
      this.forwardLayer.resetStates();
      this.backwardLayer.resetStates();
    }
  }, {
    key: "build",
    // Builds each direction inside a name scope named after that layer.
    value: function build(inputShape) {
      var _this6 = this;
      nameScope(this.forwardLayer.name, function () {
        _this6.forwardLayer.build(inputShape);
      });
      nameScope(this.backwardLayer.name, function () {
        _this6.backwardLayer.build(inputShape);
      });
      this.built = true;
    }
  }, {
    key: "computeMask",
    value: function computeMask(inputs, mask) {
      if (Array.isArray(mask)) {
        mask = mask[0];
      }
      var outputMask;
      if (this.returnSequences) {
        if (this.mergeMode == null) {
          outputMask = [mask, mask];
        } else {
          outputMask = mask;
        }
      } else {
        if (this.mergeMode == null) {
          outputMask = [null, null];
        } else {
          outputMask = null;
        }
      }
      if (this.returnState) {
        // States are never masked: one null per forward state, twice.
        var states = this.forwardLayer.states;
        var stateMask = states.map(function (state) {
          return null;
        });
        if (Array.isArray(outputMask)) {
          return outputMask.concat(stateMask).concat(stateMask);
        } else {
          return [outputMask].concat(stateMask).concat(stateMask);
        }
      } else {
        return outputMask;
      }
    }
  }, {
    key: "trainableWeights",
    get: function get() {
      return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights);
    }
  }, {
    key: "nonTrainableWeights",
    get: function get() {
      return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights);
    }
    // TODO(cais): Implement constraints().
  }, {
    key: "setFastWeightInitDuringBuild",
    value: function setFastWeightInitDuringBuild(value) {
      _get(_getPrototypeOf(Bidirectional.prototype), "setFastWeightInitDuringBuild", this).call(this, value);
      if (this.forwardLayer != null) {
        this.forwardLayer.setFastWeightInitDuringBuild(value);
      }
      if (this.backwardLayer != null) {
        this.backwardLayer.setFastWeightInitDuringBuild(value);
      }
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        'mergeMode': this.mergeMode
      };
      // TODO(cais): Add logic for `numConstants` once the property is added.
      var baseConfig = _get(_getPrototypeOf(Bidirectional.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
    /** @nocollapse */
  }], [{
    key: "fromConfig",
    /**
     * Deserializes the wrapped RNN from `config['layer']` and constructs
     * `cls`. An optional third argument `customObjects` (default `{}`) is
     * forwarded to `deserialize`, as Wrapper.fromConfig does. The caller's
     * `config` is not modified; constructor args are built on a copy.
     */
    value: function fromConfig(cls, config) {
      var customObjects = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : {};
      var rnnLayer = deserialize(config['layer'], customObjects);
      // TODO(cais): Add logic for `numConstants` once the property is added.
      if (config['numConstants'] != null) {
        throw new NotImplementedError("Deserialization of a Bidirectional layer with numConstants " + "present is not supported yet.");
      }
      // tslint:disable-next-line:no-any
      var newConfig = Object.assign({}, config);
      newConfig['layer'] = rnnLayer;
      return new cls(newConfig);
    }
  }]);
  return Bidirectional;
}(Wrapper);
/** @nocollapse */
Bidirectional.className = 'Bidirectional';
registerClass(Bidirectional);
78294
/**
 * Preprocessing Rescaling Layer
 *
 * Computes `inputs * scale + offset`, casting non-float32 inputs to float32
 * first. A missing (falsy) offset is stored as 0.
 */
var Rescaling = /*#__PURE__*/function (BaseLayer) {
  _inherits(Rescaling, BaseLayer);
  var superCtor = _createSuper(Rescaling);

  function Rescaling(args) {
    _classCallCheck(this, Rescaling);
    var instance = superCtor.call(this, args);
    instance.scale = args.scale;
    instance.offset = args.offset ? args.offset : 0;
    return instance;
  }

  _createClass(Rescaling, [{
    key: "getConfig",
    // Layer-specific keys first, then the base Layer config merged on top.
    value: function getConfig() {
      var config = {
        'scale': this.scale,
        'offset': this.offset
      };
      return Object.assign(config, _get(_getPrototypeOf(Rescaling.prototype), "getConfig", this).call(this));
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var scale = this.scale;
      var offset = this.offset;
      return tidy(function () {
        var x = getExactlyOneTensor(inputs);
        var asFloat = x.dtype === 'float32' ? x : cast$2(x, 'float32');
        return add$3(mul(asFloat, scale), offset);
      });
    }
  }]);

  return Rescaling;
}(Layer);
/** @nocollapse */
Rescaling.className = 'Rescaling';
registerClass(Rescaling);
78344
var resizeBilinear$2 = image$1.resizeBilinear;
var cropAndResize$2 = image$1.cropAndResize;
/**
 * Preprocessing layer that crops the central `height` x `width` region of
 * image inputs, reading height/width from the third- and second-to-last
 * axes. If the input is smaller than the target along either of those axes,
 * the whole input is instead resized to the target with resizeBilinear.
 * The output keeps the input dtype.
 */
var CenterCrop = /*#__PURE__*/function (_Layer) {
  _inherits(CenterCrop, _Layer);
  var _super = _createSuper(CenterCrop);
  function CenterCrop(args) {
    var _this;
    _classCallCheck(this, CenterCrop);
    _this = _super.call(this, args);
    _this.height = args.height;
    _this.width = args.width;
    return _this;
  }
  _createClass(CenterCrop, [{
    key: "centerCrop",
    // Crops a `height` x `width` window offset by (`hBuffer`, `wBuffer`) via
    // cropAndResize with normalized box coordinates ('nearest' sampling).
    // Rank-3 inputs are stacked into a batch of one and unstacked afterwards.
    value: function centerCrop(inputs, hBuffer, wBuffer, height, width, inputHeight, inputWidth, dtype) {
      return tidy(function () {
        var input;
        var isRank3 = false;
        var top = hBuffer / inputHeight;
        var left = wBuffer / inputWidth;
        var bottom = (height + hBuffer) / inputHeight;
        var right = (width + wBuffer) / inputWidth;
        var bound = [top, left, bottom, right];
        var boxesArr = [];
        if (inputs.rank === 3) {
          isRank3 = true;
          input = stack([inputs]);
        } else {
          input = inputs;
        }
        // One identical box per batch element.
        for (var i = 0; i < input.shape[0]; i++) {
          boxesArr.push(bound);
        }
        var boxes = tensor(boxesArr, [boxesArr.length, 4]);
        var boxInd = range$3(0, boxesArr.length, 1, 'int32');
        var cropSize = [height, width];
        var cropped = cropAndResize$2(input, boxes, boxInd, cropSize, 'nearest');
        if (isRank3) {
          return cast$2(getExactlyOneTensor(unstack(cropped)), dtype);
        }
        return cast$2(cropped, dtype);
      });
    }
  }, {
    key: "upsize",
    // Bilinearly resizes the whole input to `height` x `width`.
    value: function upsize(inputs, height, width, dtype) {
      return tidy(function () {
        var outputs = resizeBilinear$2(inputs, [height, width]);
        return cast$2(outputs, dtype);
      });
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this2 = this;
      return tidy(function () {
        var rankedInputs = getExactlyOneTensor(inputs);
        var dtype = rankedInputs.dtype;
        var inputShape = rankedInputs.shape;
        var inputHeight = inputShape[inputShape.length - 3];
        var inputWidth = inputShape[inputShape.length - 2];
        var hBuffer = 0;
        if (inputHeight !== _this2.height) {
          hBuffer = Math.floor((inputHeight - _this2.height) / 2);
        }
        var wBuffer = 0;
        if (inputWidth !== _this2.width) {
          wBuffer = Math.floor((inputWidth - _this2.width) / 2);
          // NOTE(review): unlike hBuffer, a width surplus of exactly 1 starts
          // the crop at column 1 instead of 0; kept as-is.
          if (wBuffer === 0) {
            wBuffer = 1;
          }
        }
        if (hBuffer >= 0 && wBuffer >= 0) {
          return _this2.centerCrop(rankedInputs, hBuffer, wBuffer, _this2.height, _this2.width, inputHeight, inputWidth, dtype);
        }
        // Resize the unwrapped tensor: `inputs` may be a one-element array,
        // which is not a valid image argument for resizeBilinear.
        return _this2.upsize(rankedInputs, _this2.height, _this2.width, dtype);
      });
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        'height': this.height,
        'width': this.width
      };
      var baseConfig = _get(_getPrototypeOf(CenterCrop.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }, {
    key: "computeOutputShape",
    // Replaces the height/width axes with the target size. Works on a copy so
    // the caller's shape array is not modified in place.
    value: function computeOutputShape(inputShape) {
      var outputShape = getExactlyOneShape(inputShape).slice();
      var hAxis = outputShape.length - 3;
      var wAxis = outputShape.length - 2;
      outputShape[hAxis] = this.height;
      outputShape[wAxis] = this.width;
      return outputShape;
    }
  }]);
  return CenterCrop;
}(Layer);
/** @nocollapse */
CenterCrop.className = 'CenterCrop';
registerClass(CenterCrop);
78452
78453 /**
78454 * @license
78455 * Copyright 2022 CodeSmith LLC
78456 *
78457 * Use of this source code is governed by an MIT-style
78458 * license that can be found in the LICENSE file or at
78459 * https://opensource.org/licenses/MIT.
78460 * =============================================================================
78461 */
/**
 * Encodes integer category ids according to `outputMode`.
 *
 * @param inputs A tensor (or one-element tensor array) of category ids;
 *     non-int32 input is cast to int32 first.
 * @param outputMode 'int' returns the int32 ids unchanged; 'oneHot' and
 *     'multiHot' bincount with `binaryOutput` set; 'count' bincounts,
 *     optionally using `weights`; 'tfIdf' multiplies the counts by `weights`.
 * @param depth Bin count forwarded to `denseBincount$2`.
 * @param weights Optional weights; used by 'count', required by 'tfIdf'.
 * @returns The encoded tensor.
 * @throws ValueError when the (expanded) input rank exceeds 2, or when
 *     `outputMode` is 'tfIdf' and no weights were supplied.
 */
function encodeCategoricalInputs(inputs, outputMode, depth, weights) {
  var ids = getExactlyOneTensor(inputs);
  if (ids.dtype !== 'int32') {
    ids = cast$2(ids, 'int32');
  }
  if (outputMode === 'int') {
    return ids;
  }
  var originalShape = ids.shape;
  // Scalars are promoted to rank 1 before bincounting.
  if (ids.rank === 0) {
    ids = expandDims$3(ids, -1);
  }
  // One-hot needs a trailing size-1 axis so every id becomes its own row.
  if (outputMode === 'oneHot' && ids.shape[ids.shape.length - 1] !== 1) {
    ids = expandDims$3(ids, -1);
  }
  if (ids.rank > 2) {
    throw new ValueError("When outputMode is not int, maximum output rank is 2" + " Received outputMode ".concat(outputMode, " and input shape ").concat(originalShape) + " which would result in output rank ".concat(ids.rank, "."));
  }
  var binaryOutput = outputMode === 'multiHot' || outputMode === 'oneHot';
  // Only 'count' mode forwards the caller's weights into the bincount.
  var bincountWeights = typeof weights !== 'undefined' && outputMode === 'count' ? weights : [];
  var binCounts = denseBincount$2(ids, bincountWeights, depth, binaryOutput);
  if (outputMode !== 'tfIdf') {
    return binCounts;
  }
  if (!weights) {
    throw new ValueError("When outputMode is 'tfIdf', weights must be provided.");
  }
  return mul(binCounts, weights);
}
78499
/**
 * Preprocessing layer that encodes integer category ids.
 *
 * Config:
 *   - `numTokens`: encoding depth; every input id must satisfy
 *     `0 <= id < numTokens` (checked in `call`).
 *   - `outputMode`: mode forwarded to `encodeCategoricalInputs`
 *     ('int', 'oneHot', 'multiHot', 'count' or 'tfIdf'); a falsy value
 *     falls back to 'multiHot'.
 */
var CategoryEncoding = /*#__PURE__*/function (_Layer) {
  _inherits(CategoryEncoding, _Layer);
  var _super = _createSuper(CategoryEncoding);
  function CategoryEncoding(args) {
    var _this;
    _classCallCheck(this, CategoryEncoding);
    _this = _super.call(this, args);
    _this.numTokens = args.numTokens;
    _this.outputMode = args.outputMode ? args.outputMode : 'multiHot';
    return _this;
  }
  _createClass(CategoryEncoding, [{
    key: "getConfig",
    value: function getConfig() {
      var config = {
        'numTokens': this.numTokens,
        'outputMode': this.outputMode
      };
      var baseConfig = _get(_getPrototypeOf(CategoryEncoding.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      if (inputShape == null) {
        return [this.numTokens];
      }
      // Build a new array: getExactlyOneShape may return the caller's own
      // shape array, which must not be modified here.
      var outputShape = inputShape.slice();
      // 'int' mode returns the ids unchanged, so the shape is unchanged too.
      if (this.outputMode === 'int') {
        return outputShape;
      }
      // A scalar is expanded to rank 1 before encoding, giving [numTokens].
      if (outputShape.length === 0) {
        return [this.numTokens];
      }
      if (this.outputMode === 'oneHot' && outputShape[outputShape.length - 1] !== 1) {
        outputShape.push(this.numTokens);
        return outputShape;
      }
      outputShape[outputShape.length - 1] = this.numTokens;
      return outputShape;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this2 = this;
      return tidy(function () {
        inputs = getExactlyOneTensor(inputs);
        if (inputs.dtype !== 'int32') {
          inputs = cast$2(inputs, 'int32');
        }
        var countWeights;
        // `kwargs` may be omitted when `call` is invoked directly.
        if (kwargs != null && typeof kwargs['countWeights'] !== 'undefined') {
          if (_this2.outputMode !== 'count') {
            throw new ValueError("countWeights is not used when outputMode !== count.\n Received countWeights=".concat(kwargs['countWeights']));
          }
          countWeights = getExactlyOneTensor(kwargs['countWeights']);
        }
        var maxValue = max$3(inputs);
        var minValue = min$3(inputs);
        // Valid ids satisfy 0 <= id < numTokens.
        var maxBelowNumTokens = greater$3(_this2.numTokens, maxValue).bufferSync().get(0);
        var minAtLeastZero = greaterEqual$2(minValue, 0).bufferSync().get(0);
        if (!(maxBelowNumTokens && minAtLeastZero)) {
          throw new ValueError('Input values must be between 0 <= values <' + " numTokens with numTokens=".concat(_this2.numTokens));
        }
        return encodeCategoricalInputs(inputs, _this2.outputMode, _this2.numTokens, countWeights);
      });
    }
  }]);
  return CategoryEncoding;
}(Layer);
/** @nocollapse */
CategoryEncoding.className = 'CategoryEncoding';
registerClass(CategoryEncoding);
78572
// tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',
// 'gaussian', 'mitchellcubic'
var INTERPOLATION_KEYS$1 = ['bilinear', 'nearest'];
var INTERPOLATION_METHODS$1 = new Set(INTERPOLATION_KEYS$1);
/**
 * Preprocessing Resizing Layer
 *
 * Resizes images to the fixed size `[height, width]` using `interpolation`
 * ('bilinear' by default, or 'nearest'). The trailing channel axis and any
 * leading (e.g. batch) axes are kept.
 *
 * @throws ValueError from the constructor when `interpolation` is not one of
 *     'bilinear' or 'nearest'.
 */
var Resizing = /*#__PURE__*/function (_Layer) {
  _inherits(Resizing, _Layer);
  var _super = _createSuper(Resizing);
  function Resizing(args) {
    var _this;
    _classCallCheck(this, Resizing);
    _this = _super.call(this, args);
    _this.height = args.height;
    _this.width = args.width;
    if (args.interpolation) {
      if (INTERPOLATION_METHODS$1.has(args.interpolation)) {
        _this.interpolation = args.interpolation;
      } else {
        throw new ValueError("Invalid interpolation parameter: ".concat(args.interpolation, " is not implemented"));
      }
    } else {
      _this.interpolation = 'bilinear';
    }
    _this.cropToAspectRatio = Boolean(args.cropToAspectRatio);
    return _this;
  }
  _createClass(Resizing, [{
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      // Replace height/width and keep the trailing channel axis. Leading axes
      // are preserved, so both [h, w, c] and [batch, h, w, c] map correctly
      // (channels used to be read from index 2, which is the width for a
      // batched shape).
      var rank = inputShape.length;
      var numChannels = inputShape[rank - 1];
      return inputShape.slice(0, Math.max(rank - 3, 0)).concat([this.height, this.width, numChannels]);
    }
  }, {
    key: "getConfig",
    value: function getConfig() {
      var config = {
        'height': this.height,
        'width': this.width,
        'interpolation': this.interpolation,
        'cropToAspectRatio': this.cropToAspectRatio
      };
      var baseConfig = _get(_getPrototypeOf(Resizing.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this2 = this;
      return tidy(function () {
        // Accept a single tensor or a one-element tensor array.
        var input = getExactlyOneTensor(inputs);
        var size = [_this2.height, _this2.width];
        // NOTE(review): in the tf.image API the third argument of
        // resizeBilinear / resizeNearestNeighbor is `alignCorners`, so
        // `cropToAspectRatio` currently toggles corner alignment rather than
        // cropping. Kept unchanged to preserve existing outputs; confirm.
        if (_this2.interpolation === 'bilinear') {
          return image$1.resizeBilinear(input, size, !_this2.cropToAspectRatio);
        } else if (_this2.interpolation === 'nearest') {
          return image$1.resizeNearestNeighbor(input, size, !_this2.cropToAspectRatio);
        } else {
          throw new Error("Interpolation is ".concat(_this2.interpolation, " but only ").concat(_toConsumableArray(INTERPOLATION_METHODS$1), " are supported"));
        }
      });
    }
  }]);
  return Resizing;
}(Layer);
/** @nocollapse */
Resizing.className = 'Resizing';
registerClass(Resizing);
78644
78645 /**
78646 * @license
78647 * Copyright 2023 CodeSmith LLC
78648 *
78649 * Use of this source code is governed by an MIT-style
78650 * license that can be found in the LICENSE file or at
78651 * https://opensource.org/licenses/MIT.
78652 * =============================================================================
78653 */
/**
 * Seed source for the random preprocessing layers (one instance is created
 * by `BaseRandomLayer`). Each `next()` call hands out the current seed and
 * then advances it by one; without a seed it always yields `undefined`.
 */
var RandomSeed = /*#__PURE__*/function () {
  function RandomSeed(seed) {
    _classCallCheck(this, RandomSeed);
    this.seed = seed;
  }
  _createClass(RandomSeed, [{
    key: "next",
    value: function next() {
      // Post-increment: the caller gets the current seed, the stored one moves on.
      return this.seed === undefined ? undefined : this.seed++;
    }
  }]);
  return RandomSeed;
}();
RandomSeed.className = 'RandomSeed';
78676
/**
 * Base class for the random preprocessing layers: it owns a `RandomSeed`
 * built from `args.seed` and reports that generator's seed in `getConfig`.
 */
var BaseRandomLayer = /*#__PURE__*/function (_Layer) {
  _inherits(BaseRandomLayer, _Layer);
  var _super = _createSuper(BaseRandomLayer);
  function BaseRandomLayer(args) {
    var instance;
    _classCallCheck(this, BaseRandomLayer);
    instance = _super.call(this, args);
    instance.randomGenerator = new RandomSeed(args.seed);
    return instance;
  }
  _createClass(BaseRandomLayer, [{
    key: "getConfig",
    value: function getConfig() {
      // NOTE(review): RandomSeed.next() increments `seed`, so once the layer
      // has drawn random values this reports the advanced seed rather than
      // the one given to the constructor.
      var ownConfig = {
        'seed': this.randomGenerator.seed
      };
      var parentConfig = _get(_getPrototypeOf(BaseRandomLayer.prototype), "getConfig", this).call(this);
      return Object.assign(ownConfig, parentConfig);
    }
  }]);
  return BaseRandomLayer;
}(Layer);
/** @nocollapse */
BaseRandomLayer.className = 'BaseRandomLayer';
78702
var INTERPOLATION_KEYS = ['bilinear', 'nearest'];
var INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS);
/**
 * Preprocessing layer that randomly varies the width of images.
 *
 * This layer adjusts the width of an image, or a batch of images, by a random
 * factor. The input should be a 3D (unbatched) or 4D (batched) tensor in the
 * `"channels_last"` image data format. Input pixel values can be of any range
 * (e.g. `[0., 1.)` or `[0, 255]`) and of integer or floating point dtype. By
 * default, the layer will output floats.
 *
 * `factor` is either a positive number `f`, giving a width factor drawn
 * between `1 - f` and `1 + f`, or a pair `[lower, upper]`, giving a factor
 * drawn between `1 + lower` and `1 + upper`. Both bounds must be >= -1, and
 * `upper` must not be less than `lower`.
 *
 * NOTE: `call` draws a new factor on every invocation; no training flag is
 * consulted here.
 *
 * tf methods implemented in tfjs: 'bilinear', 'nearest',
 * tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',
 * 'gaussian', 'mitchellcubic'
 */
var RandomWidth = /*#__PURE__*/function (_BaseRandomLayer) {
  _inherits(RandomWidth, _BaseRandomLayer);
  var _super = _createSuper(RandomWidth);
  function RandomWidth(args) {
    var _this;
    _classCallCheck(this, RandomWidth);
    _this = _super.call(this, args);
    var factor = args.factor,
      _args$interpolation = args.interpolation,
      interpolation = _args$interpolation === void 0 ? 'bilinear' : _args$interpolation;
    _this.factor = factor;
    // A pair gives explicit [lower, upper] offsets; a positive scalar f gives
    // the symmetric range [-f, f].
    if (Array.isArray(_this.factor) && _this.factor.length === 2) {
      _this.widthLower = _this.factor[0];
      _this.widthUpper = _this.factor[1];
    } else if (!Array.isArray(_this.factor) && _this.factor > 0) {
      _this.widthLower = -_this.factor;
      _this.widthUpper = _this.factor;
    } else {
      throw new ValueError("Invalid factor: ".concat(_this.factor, ". Must be positive number or tuple of 2 numbers"));
    }
    if (_this.widthLower < -1.0 || _this.widthUpper < -1.0) {
      throw new ValueError("factor must have values larger than -1. Got: ".concat(_this.factor));
    }
    if (_this.widthUpper < _this.widthLower) {
      throw new ValueError("factor cannot have upper bound less than lower bound.\n Got upper bound: ".concat(_this.widthUpper, ".\n Got lower bound: ").concat(_this.widthLower, "\n "));
    }
    if (interpolation) {
      if (INTERPOLATION_METHODS.has(interpolation)) {
        _this.interpolation = interpolation;
      } else {
        throw new ValueError("Invalid interpolation parameter: ".concat(interpolation, " is not implemented"));
      }
    }
    return _this;
  }
  _createClass(RandomWidth, [{
    key: "getConfig",
    value: function getConfig() {
      var config = {
        'factor': this.factor,
        'interpolation': this.interpolation
      };
      var baseConfig = _get(_getPrototypeOf(RandomWidth.prototype), "getConfig", this).call(this);
      Object.assign(config, baseConfig);
      return config;
    }
  }, {
    key: "computeOutputShape",
    value: function computeOutputShape(inputShape) {
      inputShape = getExactlyOneShape(inputShape);
      // Height comes from the given shape (not from the last `call`, so this
      // works before the layer has ever run). The width is random and is
      // reported as -1. The trailing channel axis and any leading (e.g.
      // batch) axes are kept.
      var rank = inputShape.length;
      var imgHeight = inputShape[rank - 3];
      var numChannels = inputShape[rank - 1];
      return inputShape.slice(0, Math.max(rank - 3, 0)).concat([imgHeight, -1, numChannels]);
    }
  }, {
    key: "call",
    value: function call(inputs, kwargs) {
      var _this2 = this;
      return tidy(function () {
        var input = getExactlyOneTensor(inputs);
        _this2.imgHeight = input.shape[input.shape.length - 3];
        var imgWidth = input.shape[input.shape.length - 2];
        // NOTE(review): this tensor is created inside `tidy` and is not
        // returned, so it is presumably disposed once `call` returns.
        _this2.widthFactor = randomUniform$1([1], 1.0 + _this2.widthLower, 1.0 + _this2.widthUpper, 'float32', _this2.randomGenerator.next());
        var adjustedWidth = Math.round(_this2.widthFactor.dataSync()[0] * imgWidth);
        var size = [_this2.imgHeight, adjustedWidth];
        // Resize the unwrapped tensor so a one-element tensor array also works.
        switch (_this2.interpolation) {
          case 'bilinear':
            return image$1.resizeBilinear(input, size);
          case 'nearest':
            return image$1.resizeNearestNeighbor(input, size);
          default:
            throw new Error("Interpolation is ".concat(_this2.interpolation, "\n but only ").concat(_toConsumableArray(INTERPOLATION_METHODS), " are supported"));
        }
      });
    }
  }]);
  return RandomWidth;
}(BaseRandomLayer);
/** @nocollapse */
RandomWidth.className = 'RandomWidth';
registerClass(RandomWidth);
78802
78803 /**
78804 * @license
78805 * Copyright 2018 Google LLC
78806 *
78807 * Use of this source code is governed by an MIT-style
78808 * license that can be found in the LICENSE file or at
78809 * https://opensource.org/licenses/MIT.
78810 * =============================================================================
78811 */
78812 // TODO(cais): Add doc string to all the public static functions in this
// class; include executable JavaScript code snippets where applicable
78814 // (b/74074458).
78815 // Input Layer.
78816 /**
78817 * An input layer is an entry point into a `tf.LayersModel`.
78818 *
 * `InputLayer` is generated automatically for `tf.Sequential` models by
 * specifying the `inputShape` or `batchInputShape` for the first layer. It
 * should not be specified explicitly. However, it can be useful sometimes,
 * e.g., when constructing a sequential model from a subset of another
 * sequential model's layers, as the code snippet below shows.
78824 *
78825 * ```js
78826 * // Define a model which simply adds two inputs.
78827 * const model1 = tf.sequential();
78828 * model1.add(tf.layers.dense({inputShape: [4], units: 3, activation: 'relu'}));
78829 * model1.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
78830 * model1.summary();
78831 * model1.predict(tf.zeros([1, 4])).print();
78832 *
78833 * // Construct another model, reusing the second layer of `model1` while
78834 * // not using the first layer of `model1`. Note that you cannot add the second
 * // layer of `model1` directly as the first layer of the new sequential model,
78836 * // because doing so will lead to an error related to the fact that the layer
78837 * // is not an input layer. Instead, you need to create an `inputLayer` and add
78838 * // it to the new sequential model before adding the reused layer.
78839 * const model2 = tf.sequential();
78840 * // Use an inputShape that matches the input shape of `model1`'s second
78841 * // layer.
78842 * model2.add(tf.layers.inputLayer({inputShape: [3]}));
78843 * model2.add(model1.layers[1]);
78844 * model2.summary();
78845 * model2.predict(tf.zeros([1, 3])).print();
78846 * ```
78847 *
78848 * @doc {heading: 'Layers', subheading: 'Inputs', namespace: 'layers'}
78849 */
function inputLayer(args) {
  // Thin factory: construct and hand back a new `InputLayer` for `args`.
  var layer = new InputLayer(args);
  return layer;
}
78853 // Advanced Activation Layers.
78854 /**
78855 * Exponential Linear Unit (ELU).
78856 *
78857 * It follows:
78858 * `f(x) = alpha * (exp(x) - 1.) for x < 0`,
78859 * `f(x) = x for x >= 0`.
78860 *
78861 * Input shape:
78862 * Arbitrary. Use the configuration `inputShape` when using this layer as the
78863 * first layer in a model.
78864 *
78865 * Output shape:
78866 * Same shape as the input.
78867 *
78868 * References:
78869 * - [Fast and Accurate Deep Network Learning by Exponential Linear Units
78870 * (ELUs)](https://arxiv.org/abs/1511.07289v1)
78871 *
78872 * @doc {
78873 * heading: 'Layers',
78874 * subheading: 'Advanced Activation',
78875 * namespace: 'layers'
78876 * }
78877 */
function elu$2(args) {
  // Thin factory: construct and hand back a new `ELU$3` layer for `args`.
  var layer = new ELU$3(args);
  return layer;
}
78881 /**
78882 * Rectified Linear Unit activation function.
78883 *
78884 * Input shape:
78885 * Arbitrary. Use the config field `inputShape` (Array of integers, does
78886 * not include the sample axis) when using this layer as the first layer
78887 * in a model.
78888 *
78889 * Output shape:
78890 * Same shape as the input.
78891 *
78892 * @doc {
78893 * heading: 'Layers',
78894 * subheading: 'Advanced Activation',
78895 * namespace: 'layers'
78896 * }
78897 */
function reLU(args) {
  // Thin factory: construct and hand back a new `ReLU` layer for `args`.
  var layer = new ReLU(args);
  return layer;
}
78901 /**
78902 * Leaky version of a rectified linear unit.
78903 *
78904 * It allows a small gradient when the unit is not active:
78905 * `f(x) = alpha * x for x < 0.`
78906 * `f(x) = x for x >= 0.`
78907 *
78908 * Input shape:
78909 * Arbitrary. Use the configuration `inputShape` when using this layer as the
78910 * first layer in a model.
78911 *
78912 * Output shape:
78913 * Same shape as the input.
78914 *
78915 * @doc {
78916 * heading: 'Layers',
78917 * subheading: 'Advanced Activation',
78918 * namespace: 'layers'
78919 * }
78920 */
function leakyReLU(args) {
  // Thin factory: construct and hand back a new `LeakyReLU` layer for `args`.
  var layer = new LeakyReLU(args);
  return layer;
}
78924 /**
78925 * Parameterized version of a leaky rectified linear unit.
78926 *
78927 * It follows
78928 * `f(x) = alpha * x for x < 0.`
78929 * `f(x) = x for x >= 0.`
78930 * wherein `alpha` is a trainable weight.
78931 *
78932 * Input shape:
78933 * Arbitrary. Use the configuration `inputShape` when using this layer as the
78934 * first layer in a model.
78935 *
78936 * Output shape:
78937 * Same shape as the input.
78938 *
78939 * @doc {
78940 * heading: 'Layers',
78941 * subheading: 'Advanced Activation',
78942 * namespace: 'layers'
78943 * }
78944 */
function prelu$2(args) {
  // Thin factory: construct and hand back a new `PReLU` layer for `args`.
  var layer = new PReLU(args);
  return layer;
}
78948 /**
78949 * Softmax activation layer.
78950 *
78951 * Input shape:
78952 * Arbitrary. Use the configuration `inputShape` when using this layer as the
78953 * first layer in a model.
78954 *
78955 * Output shape:
78956 * Same shape as the input.
78957 *
78958 * @doc {
78959 * heading: 'Layers',
78960 * subheading: 'Advanced Activation',
78961 * namespace: 'layers'
78962 * }
78963 */
function softmax$2(args) {
  // Thin factory: construct and hand back a new `Softmax` layer for `args`.
  var layer = new Softmax(args);
  return layer;
}
78967 /**
78968 * Thresholded Rectified Linear Unit.
78969 *
78970 * It follows:
78971 * `f(x) = x for x > theta`,
78972 * `f(x) = 0 otherwise`.
78973 *
78974 * Input shape:
78975 * Arbitrary. Use the configuration `inputShape` when using this layer as the
78976 * first layer in a model.
78977 *
78978 * Output shape:
78979 * Same shape as the input.
78980 *
78981 * References:
78982 * - [Zero-Bias Autoencoders and the Benefits of Co-Adapting
78983 * Features](http://arxiv.org/abs/1402.3337)
78984 *
78985 * @doc {
78986 * heading: 'Layers',
78987 * subheading: 'Advanced Activation',
78988 * namespace: 'layers'
78989 * }
78990 */
function thresholdedReLU(args) {
  // Thin factory: construct and hand back a new `ThresholdedReLU` layer for `args`.
  var layer = new ThresholdedReLU(args);
  return layer;
}
78994 // Convolutional Layers.
78995 /**
78996 * 1D convolution layer (e.g., temporal convolution).
78997 *
78998 * This layer creates a convolution kernel that is convolved
78999 * with the layer input over a single spatial (or temporal) dimension
79000 * to produce a tensor of outputs.
79001 *
 * If `useBias` is `true`, a bias vector is created and added to the outputs.
79003 *
79004 * If `activation` is not `null`, it is applied to the outputs as well.
79005 *
79006 * When using this layer as the first layer in a model, provide an
79007 * `inputShape` argument `Array` or `null`.
79008 *
79009 * For example, `inputShape` would be:
79010 * - `[10, 128]` for sequences of 10 vectors of 128-dimensional vectors
79011 * - `[null, 128]` for variable-length sequences of 128-dimensional vectors.
79012 *
79013 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79014 */
function conv1d(args) {
  // Thin factory: construct and hand back a new `Conv1D` layer for `args`.
  var layer = new Conv1D(args);
  return layer;
}
79018 /**
79019 * 2D convolution layer (e.g. spatial convolution over images).
79020 *
79021 * This layer creates a convolution kernel that is convolved
79022 * with the layer input to produce a tensor of outputs.
79023 *
79024 * If `useBias` is True, a bias vector is created and added to the outputs.
79025 *
79026 * If `activation` is not `null`, it is applied to the outputs as well.
79027 *
79028 * When using this layer as the first layer in a model,
79029 * provide the keyword argument `inputShape`
79030 * (Array of integers, does not include the sample axis),
79031 * e.g. `inputShape=[128, 128, 3]` for 128x128 RGB pictures
79032 * in `dataFormat='channelsLast'`.
79033 *
79034 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79035 */
function conv2d$1(args) {
  // Thin factory: construct and hand back a new `Conv2D` layer for `args`.
  var layer = new Conv2D(args);
  return layer;
}
79039 /**
79040 * Transposed convolutional layer (sometimes called Deconvolution).
79041 *
79042 * The need for transposed convolutions generally arises
79043 * from the desire to use a transformation going in the opposite direction of
79044 * a normal convolution, i.e., from something that has the shape of the output
79045 * of some convolution to something that has the shape of its input while
79046 * maintaining a connectivity pattern that is compatible with said
79047 * convolution.
79048 *
79049 * When using this layer as the first layer in a model, provide the
79050 * configuration `inputShape` (`Array` of integers, does not include the
79051 * sample axis), e.g., `inputShape: [128, 128, 3]` for 128x128 RGB pictures in
79052 * `dataFormat: 'channelsLast'`.
79053 *
79054 * Input shape:
79055 * 4D tensor with shape:
79056 * `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`.
79057 * or 4D tensor with shape
79058 * `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast'`.
79059 *
79060 * Output shape:
79061 * 4D tensor with shape:
79062 * `[batch, filters, newRows, newCols]` if `dataFormat` is
79063 * `'channelsFirst'`. or 4D tensor with shape:
79064 * `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`.
79065 *
79066 * References:
79067 * - [A guide to convolution arithmetic for deep
79068 * learning](https://arxiv.org/abs/1603.07285v1)
79069 * - [Deconvolutional
79070 * Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf)
79071 *
79072 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79073 */
function conv2dTranspose(args) {
  // Thin factory: construct and hand back a new `Conv2DTranspose` layer for `args`.
  var layer = new Conv2DTranspose(args);
  return layer;
}
79077 /**
79078 * 3D convolution layer (e.g. spatial convolution over volumes).
79079 *
79080 * This layer creates a convolution kernel that is convolved
79081 * with the layer input to produce a tensor of outputs.
79082 *
79083 * If `useBias` is True, a bias vector is created and added to the outputs.
79084 *
79085 * If `activation` is not `null`, it is applied to the outputs as well.
79086 *
79087 * When using this layer as the first layer in a model,
79088 * provide the keyword argument `inputShape`
79089 * (Array of integers, does not include the sample axis),
79090 * e.g. `inputShape=[128, 128, 128, 1]` for 128x128x128 grayscale volumes
79091 * in `dataFormat='channelsLast'`.
79092 *
79093 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79094 */
function conv3d(args) {
  // Thin factory: construct and hand back a new `Conv3D` layer for `args`.
  var layer = new Conv3D(args);
  return layer;
}
/**
 * Constructs a `Conv3DTranspose` layer from `args`.
 */
function conv3dTranspose(args) {
  var layer = new Conv3DTranspose(args);
  return layer;
}
79101 /**
79102 * Depthwise separable 2D convolution.
79103 *
79104 * Separable convolution consists of first performing
79105 * a depthwise spatial convolution
79106 * (which acts on each input channel separately)
79107 * followed by a pointwise convolution which mixes together the resulting
79108 * output channels. The `depthMultiplier` argument controls how many
79109 * output channels are generated per input channel in the depthwise step.
79110 *
79111 * Intuitively, separable convolutions can be understood as
79112 * a way to factorize a convolution kernel into two smaller kernels,
79113 * or as an extreme version of an Inception block.
79114 *
79115 * Input shape:
79116 * 4D tensor with shape:
 * `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`
 * or 4D tensor with shape:
 * `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast'`.
 *
 * Output shape:
 * 4D tensor with shape:
 * `[batch, filters, newRows, newCols]` if `dataFormat` is `'channelsFirst'`
 * or 4D tensor with shape:
 * `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`.
79126 * `rows` and `cols` values might have changed due to padding.
79127 *
79128 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79129 */
function separableConv2d(args) {
  // Thin factory: construct and hand back a new `SeparableConv2D` layer for `args`.
  var layer = new SeparableConv2D(args);
  return layer;
}
79133 /**
79134 * Cropping layer for 2D input (e.g., image).
79135 *
79136 * This layer can crop an input
79137 * at the top, bottom, left and right side of an image tensor.
79138 *
79139 * Input shape:
79140 * 4D tensor with shape:
79141 * - If `dataFormat` is `"channelsLast"`:
79142 * `[batch, rows, cols, channels]`
 * - If `dataFormat` is `"channelsFirst"`:
79144 * `[batch, channels, rows, cols]`.
79145 *
79146 * Output shape:
79147 * 4D with shape:
79148 * - If `dataFormat` is `"channelsLast"`:
79149 * `[batch, croppedRows, croppedCols, channels]`
79150 * - If `dataFormat` is `"channelsFirst"`:
79151 * `[batch, channels, croppedRows, croppedCols]`.
79152 *
79153 * Examples
79154 * ```js
79155 *
79156 * const model = tf.sequential();
79157 * model.add(tf.layers.cropping2D({cropping:[[2, 2], [2, 2]],
79158 * inputShape: [128, 128, 3]}));
79159 * //now output shape is [batch, 124, 124, 3]
79160 * ```
79161 *
79162 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79163 */
function cropping2D(args) {
  // Thin factory: construct and hand back a new `Cropping2D` layer for `args`.
  var layer = new Cropping2D(args);
  return layer;
}
79167 /**
79168 * Upsampling layer for 2D inputs.
79169 *
79170 * Repeats the rows and columns of the data
79171 * by size[0] and size[1] respectively.
79172 *
79173 *
79174 * Input shape:
79175 * 4D tensor with shape:
79176 * - If `dataFormat` is `"channelsLast"`:
79177 * `[batch, rows, cols, channels]`
79178 * - If `dataFormat` is `"channelsFirst"`:
79179 * `[batch, channels, rows, cols]`
79180 *
79181 * Output shape:
79182 * 4D tensor with shape:
79183 * - If `dataFormat` is `"channelsLast"`:
79184 * `[batch, upsampledRows, upsampledCols, channels]`
79185 * - If `dataFormat` is `"channelsFirst"`:
79186 * `[batch, channels, upsampledRows, upsampledCols]`
79187 *
79188 *
79189 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79190 */
function upSampling2d(args) {
  // Thin factory: construct and hand back a new `UpSampling2D` layer for `args`.
  var layer = new UpSampling2D(args);
  return layer;
}
79194 // Convolutional(depthwise) Layers.
79195 /**
 * Depthwise 2D convolution.
 *
 * Depthwise convolution performs just the first step of a depthwise separable
 * convolution: a depthwise spatial convolution, which acts on each input
 * channel separately. The `depthMultiplier` argument controls how many output
 * channels are generated per input channel in the depthwise step.
79202 *
79203 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
79204 */
function depthwiseConv2d(args) {
  // Thin factory: construct and hand back a new `DepthwiseConv2D` layer for `args`.
  var layer = new DepthwiseConv2D(args);
  return layer;
}
79208 // Basic Layers.
79209 /**
79210 * Applies an activation function to an output.
79211 *
 * This layer applies an element-wise activation function. Other layers, notably
 * `dense`, can also apply activation functions. Use this isolated activation
79214 * function to extract the values before and after the
79215 * activation. For instance:
79216 *
79217 * ```js
79218 * const input = tf.input({shape: [5]});
79219 * const denseLayer = tf.layers.dense({units: 1});
79220 * const activationLayer = tf.layers.activation({activation: 'relu6'});
79221 *
79222 * // Obtain the output symbolic tensors by applying the layers in order.
79223 * const denseOutput = denseLayer.apply(input);
79224 * const activationOutput = activationLayer.apply(denseOutput);
79225 *
79226 * // Create the model based on the inputs.
79227 * const model = tf.model({
79228 * inputs: input,
79229 * outputs: [denseOutput, activationOutput]
79230 * });
79231 *
79232 * // Collect both outputs and print separately.
79233 * const [denseOut, activationOut] = model.predict(tf.randomNormal([6, 5]));
79234 * denseOut.print();
79235 * activationOut.print();
79236 * ```
79237 *
79238 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79239 */
function activation(args) {
  // Thin factory: construct and hand back a new `Activation` layer for `args`.
  var layer = new Activation(args);
  return layer;
}
79243 /**
79244 * Creates a dense (fully connected) layer.
79245 *
79246 * This layer implements the operation:
79247 * `output = activation(dot(input, kernel) + bias)`
79248 *
79249 * `activation` is the element-wise activation function
79250 * passed as the `activation` argument.
79251 *
79252 * `kernel` is a weights matrix created by the layer.
79253 *
79254 * `bias` is a bias vector created by the layer (only applicable if `useBias`
79255 * is `true`).
79256 *
79257 * **Input shape:**
79258 *
79259 * nD `tf.Tensor` with shape: `(batchSize, ..., inputDim)`.
79260 *
79261 * The most common situation would be
79262 * a 2D input with shape `(batchSize, inputDim)`.
79263 *
79264 * **Output shape:**
79265 *
79266 * nD tensor with shape: `(batchSize, ..., units)`.
79267 *
79268 * For instance, for a 2D input with shape `(batchSize, inputDim)`,
79269 * the output would have shape `(batchSize, units)`.
79270 *
79271 * Note: if the input to the layer has a rank greater than 2, then it is
79272 * flattened prior to the initial dot product with the kernel.
79273 *
79274 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79275 */
function dense(args) {
  // Thin factory: construct and hand back a new `Dense` layer for `args`.
  var layer = new Dense(args);
  return layer;
}
79279 /**
79280 * Applies
79281 * [dropout](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) to
79282 * the input.
79283 *
 * Dropout consists of randomly setting a fraction `rate` of input units to 0 at
79285 * each update during training time, which helps prevent overfitting.
79286 *
79287 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79288 */
function dropout(args) {
  // Thin factory: construct and hand back a new `Dropout` layer for `args`.
  var layer = new Dropout(args);
  return layer;
}
79292 /**
79293 * Spatial 1D version of Dropout.
79294 *
79295 * This Layer type performs the same function as the Dropout layer, but it drops
79296 * entire 1D feature maps instead of individual elements. For example, if an
79297 * input example consists of 3 timesteps and the feature map for each timestep
79298 * has a size of 4, a `spatialDropout1d` layer may zero out the feature maps
79299 * of the 1st timesteps and 2nd timesteps completely while sparing all feature
79300 * elements of the 3rd timestep.
79301 *
79302 * If adjacent frames (timesteps) are strongly correlated (as is normally the
79303 * case in early convolution layers), regular dropout will not regularize the
 * activations and will otherwise just result in an effective learning
79305 * rate decrease. In this case, `spatialDropout1d` will help promote
79306 * independence among feature maps and should be used instead.
79307 *
79308 * **Arguments:**
79309 * rate: A floating-point number >=0 and <=1. Fraction of the input elements
79310 * to drop.
79311 *
79312 * **Input shape:**
79313 * 3D tensor with shape `(samples, timesteps, channels)`.
79314 *
79315 * **Output shape:**
79316 * Same as the input shape.
79317 *
79318 * References:
79319 * - [Efficient Object Localization Using Convolutional
79320 * Networks](https://arxiv.org/abs/1411.4280)
79321 *
79322 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79323 */
function spatialDropout1d(args) {
  const layer = new SpatialDropout1D(args);
  return layer;
}
79327 /**
79328 * Flattens the input. Does not affect the batch size.
79329 *
79330 * A `Flatten` layer flattens each batch in its inputs to 1D (making the output
79331 * 2D).
79332 *
79333 * For example:
79334 *
79335 * ```js
79336 * const input = tf.input({shape: [4, 3]});
79337 * const flattenLayer = tf.layers.flatten();
79338 * // Inspect the inferred output shape of the flatten layer, which
79339 * // equals `[null, 12]`. The 2nd dimension is 4 * 3, i.e., the result of the
 * // flattening. (The 1st dimension is the undetermined batch size.)
79341 * console.log(JSON.stringify(flattenLayer.apply(input).shape));
79342 * ```
79343 *
79344 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79345 */
function flatten(args) {
  const layer = new Flatten(args);
  return layer;
}
79349 /**
79350 * Repeats the input n times in a new dimension.
79351 *
79352 * ```js
79353 * const model = tf.sequential();
79354 * model.add(tf.layers.repeatVector({n: 4, inputShape: [2]}));
79355 * const x = tf.tensor2d([[10, 20]]);
79356 * // Use the model to do inference on a data point the model hasn't seen
79357 * model.predict(x).print();
 * // output shape is now [batch, 4, 2]
79359 * ```
79360 *
79361 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79362 */
function repeatVector(args) {
  const layer = new RepeatVector(args);
  return layer;
}
79366 /**
79367 * Reshapes an input to a certain shape.
79368 *
79369 * ```js
79370 * const input = tf.input({shape: [4, 3]});
79371 * const reshapeLayer = tf.layers.reshape({targetShape: [2, 6]});
79372 * // Inspect the inferred output shape of the Reshape layer, which
 * // equals `[null, 2, 6]`. (The 1st dimension is the undetermined batch size.)
79374 * console.log(JSON.stringify(reshapeLayer.apply(input).shape));
79375 * ```
79376 *
79377 * Input shape:
79378 * Arbitrary, although all dimensions in the input shape must be fixed.
79379 * Use the configuration `inputShape` when using this layer as the
79380 * first layer in a model.
79381 *
79382 *
79383 * Output shape:
79384 * [batchSize, targetShape[0], targetShape[1], ...,
79385 * targetShape[targetShape.length - 1]].
79386 *
79387 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79388 */
function reshape$2(args) {
  const layer = new Reshape(args);
  return layer;
}
79392 /**
79393 * Permutes the dimensions of the input according to a given pattern.
79394 *
79395 * Useful for, e.g., connecting RNNs and convnets together.
79396 *
79397 * Example:
79398 *
79399 * ```js
79400 * const model = tf.sequential();
79401 * model.add(tf.layers.permute({
79402 * dims: [2, 1],
79403 * inputShape: [10, 64]
79404 * }));
79405 * console.log(model.outputShape);
79406 * // Now model's output shape is [null, 64, 10], where null is the
79407 * // unpermuted sample (batch) dimension.
79408 * ```
79409 *
79410 * Input shape:
79411 * Arbitrary. Use the configuration field `inputShape` when using this
79412 * layer as the first layer in a model.
79413 *
79414 * Output shape:
79415 * Same rank as the input shape, but with the dimensions re-ordered (i.e.,
79416 * permuted) according to the `dims` configuration of this layer.
79417 *
79418 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79419 */
function permute(args) {
  const layer = new Permute(args);
  return layer;
}
79423 /**
79424 * Maps positive integers (indices) into dense vectors of fixed size.
79425 * E.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
79426 *
79427 * **Input shape:** 2D tensor with shape: `[batchSize, sequenceLength]`.
79428 *
79429 * **Output shape:** 3D tensor with shape: `[batchSize, sequenceLength,
79430 * outputDim]`.
79431 *
79432 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
79433 */
function embedding(args) {
  const layer = new Embedding(args);
  return layer;
}
79437 // Merge Layers.
79438 /**
79439 * Layer that performs element-wise addition on an `Array` of inputs.
79440 *
79441 * It takes as input a list of tensors, all of the same shape, and returns a
79442 * single tensor (also of the same shape). The inputs are specified as an
79443 * `Array` when the `apply` method of the `Add` layer instance is called. For
79444 * example:
79445 *
79446 * ```js
79447 * const input1 = tf.input({shape: [2, 2]});
79448 * const input2 = tf.input({shape: [2, 2]});
79449 * const addLayer = tf.layers.add();
79450 * const sum = addLayer.apply([input1, input2]);
79451 * console.log(JSON.stringify(sum.shape));
79452 * // You get [null, 2, 2], with the first dimension as the undetermined batch
79453 * // dimension.
79454 * ```
79455 *
79456 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
79457 */
function add$1(args) {
  const layer = new Add(args);
  return layer;
}
79461 /**
79462 * Layer that performs element-wise averaging on an `Array` of inputs.
79463 *
79464 * It takes as input a list of tensors, all of the same shape, and returns a
79465 * single tensor (also of the same shape). For example:
79466 *
79467 * ```js
79468 * const input1 = tf.input({shape: [2, 2]});
79469 * const input2 = tf.input({shape: [2, 2]});
79470 * const averageLayer = tf.layers.average();
79471 * const average = averageLayer.apply([input1, input2]);
79472 * console.log(JSON.stringify(average.shape));
79473 * // You get [null, 2, 2], with the first dimension as the undetermined batch
79474 * // dimension.
79475 * ```
79476 *
79477 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
79478 */
function average(args) {
  const layer = new Average(args);
  return layer;
}
79482 /**
79483 * Layer that concatenates an `Array` of inputs.
79484 *
79485 * It takes a list of tensors, all of the same shape except for the
79486 * concatenation axis, and returns a single tensor, the concatenation
79487 * of all inputs. For example:
79488 *
79489 * ```js
79490 * const input1 = tf.input({shape: [2, 2]});
79491 * const input2 = tf.input({shape: [2, 3]});
79492 * const concatLayer = tf.layers.concatenate();
79493 * const output = concatLayer.apply([input1, input2]);
79494 * console.log(JSON.stringify(output.shape));
79495 * // You get [null, 2, 5], with the first dimension as the undetermined batch
79496 * // dimension. The last dimension (5) is the result of concatenating the
79497 * // last dimensions of the inputs (2 and 3).
79498 * ```
79499 *
79500 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
79501 */
function concatenate(args) {
  const layer = new Concatenate(args);
  return layer;
}
79505 /**
79506 * Layer that computes the element-wise maximum of an `Array` of inputs.
79507 *
79508 * It takes as input a list of tensors, all of the same shape, and returns a
79509 * single tensor (also of the same shape). For example:
79510 *
79511 * ```js
79512 * const input1 = tf.input({shape: [2, 2]});
79513 * const input2 = tf.input({shape: [2, 2]});
79514 * const maxLayer = tf.layers.maximum();
79515 * const max = maxLayer.apply([input1, input2]);
79516 * console.log(JSON.stringify(max.shape));
79517 * // You get [null, 2, 2], with the first dimension as the undetermined batch
79518 * // dimension.
79519 * ```
79520 *
79521 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
79522 */
function maximum$2(args) {
  const layer = new Maximum(args);
  return layer;
}
79526 /**
79527 * Layer that computes the element-wise minimum of an `Array` of inputs.
79528 *
79529 * It takes as input a list of tensors, all of the same shape, and returns a
79530 * single tensor (also of the same shape). For example:
79531 *
79532 * ```js
79533 * const input1 = tf.input({shape: [2, 2]});
79534 * const input2 = tf.input({shape: [2, 2]});
79535 * const minLayer = tf.layers.minimum();
79536 * const min = minLayer.apply([input1, input2]);
79537 * console.log(JSON.stringify(min.shape));
79538 * // You get [null, 2, 2], with the first dimension as the undetermined batch
79539 * // dimension.
79540 * ```
79541 *
79542 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
79543 */
function minimum$2(args) {
  const layer = new Minimum(args);
  return layer;
}
79547 /**
79548 * Layer that multiplies (element-wise) an `Array` of inputs.
79549 *
79550 * It takes as input an Array of tensors, all of the same
79551 * shape, and returns a single tensor (also of the same shape).
79552 * For example:
79553 *
79554 * ```js
79555 * const input1 = tf.input({shape: [2, 2]});
79556 * const input2 = tf.input({shape: [2, 2]});
79557 * const input3 = tf.input({shape: [2, 2]});
79558 * const multiplyLayer = tf.layers.multiply();
79559 * const product = multiplyLayer.apply([input1, input2, input3]);
79560 * console.log(product.shape);
79561 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
79563 *
79564 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
79565 */
function multiply$2(args) {
  const layer = new Multiply(args);
  return layer;
}
79569 /**
79570 * Layer that computes a dot product between samples in two tensors.
79571 *
79572 * E.g., if applied to a list of two tensors `a` and `b` both of shape
79573 * `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`,
79574 * where each entry at index `[i, 0]` will be the dot product between
79575 * `a[i, :]` and `b[i, :]`.
79576 *
79577 * Example:
79578 *
79579 * ```js
79580 * const dotLayer = tf.layers.dot({axes: -1});
79581 * const x1 = tf.tensor2d([[10, 20], [30, 40]]);
79582 * const x2 = tf.tensor2d([[-1, -2], [-3, -4]]);
79583 *
79584 * // Invoke the layer's apply() method in eager (imperative) mode.
79585 * const y = dotLayer.apply([x1, x2]);
79586 * y.print();
79587 * ```
79588 *
79589 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
79590 */
function dot(args) {
  const layer = new Dot(args);
  return layer;
}
79594 // Normalization Layers.
79595 /**
 * Batch normalization layer (Ioffe and Szegedy, 2015).
79597 *
79598 * Normalize the activations of the previous layer at each batch,
79599 * i.e. applies a transformation that maintains the mean activation
79600 * close to 0 and the activation standard deviation close to 1.
79601 *
79602 * Input shape:
79603 * Arbitrary. Use the keyword argument `inputShape` (Array of integers, does
79604 * not include the sample axis) when calling the constructor of this class,
79605 * if this layer is used as a first layer in a model.
79606 *
79607 * Output shape:
79608 * Same shape as input.
79609 *
79610 * References:
79611 * - [Batch Normalization: Accelerating Deep Network Training by Reducing
79612 * Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
79613 *
79614 * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
79615 */
function batchNormalization(args) {
  const layer = new BatchNormalization(args);
  return layer;
}
79619 /**
79620 * Layer-normalization layer (Ba et al., 2016).
79621 *
79622 * Normalizes the activations of the previous layer for each given example in a
79623 * batch independently, instead of across a batch like in `batchNormalization`.
79624 * In other words, this layer applies a transformation that maintains the mean
79625 * activation within each example close to 0 and activation variance close to 1.
79626 *
79627 * Input shape:
79628 * Arbitrary. Use the argument `inputShape` when using this layer as the first
79629 * layer in a model.
79630 *
79631 * Output shape:
79632 * Same as input.
79633 *
79634 * References:
79635 * - [Layer Normalization](https://arxiv.org/abs/1607.06450)
79636 *
79637 * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
79638 */
function layerNormalization(args) {
  const layer = new LayerNormalization(args);
  return layer;
}
79642 // Padding Layers.
79643 /**
79644 * Zero-padding layer for 2D input (e.g., image).
79645 *
79646 * This layer can add rows and columns of zeros
79647 * at the top, bottom, left and right side of an image tensor.
79648 *
79649 * Input shape:
79650 * 4D tensor with shape:
79651 * - If `dataFormat` is `"channelsLast"`:
79652 * `[batch, rows, cols, channels]`
 * - If `dataFormat` is `"channelsFirst"`:
79654 * `[batch, channels, rows, cols]`.
79655 *
79656 * Output shape:
79657 * 4D with shape:
79658 * - If `dataFormat` is `"channelsLast"`:
79659 * `[batch, paddedRows, paddedCols, channels]`
79660 * - If `dataFormat` is `"channelsFirst"`:
79661 * `[batch, channels, paddedRows, paddedCols]`.
79662 *
79663 * @doc {heading: 'Layers', subheading: 'Padding', namespace: 'layers'}
79664 */
function zeroPadding2d(args) {
  const layer = new ZeroPadding2D(args);
  return layer;
}
79668 // Pooling Layers.
79669 /**
 * Average pooling operation for temporal data.
79671 *
79672 * Input shape: `[batchSize, inLength, channels]`
79673 *
79674 * Output shape: `[batchSize, pooledLength, channels]`
79675 *
79676 * `tf.avgPool1d` is an alias.
79677 *
79678 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79679 */
function averagePooling1d(args) {
  const layer = new AveragePooling1D(args);
  return layer;
}
/**
 * Alias of `averagePooling1d`: builds an `AveragePooling1D` layer from `args`.
 */
function avgPool1d(args) {
  return new AveragePooling1D(args);
}
/**
 * Alias of `averagePooling1d`, retained for backwards compatibility.
 * See https://github.com/tensorflow/tfjs/issues/152
 */
function avgPooling1d(args) {
  return new AveragePooling1D(args);
}
79691 /**
79692 * Average pooling operation for spatial data.
79693 *
79694 * Input shape:
79695 * - If `dataFormat === CHANNEL_LAST`:
79696 * 4D tensor with shape:
79697 * `[batchSize, rows, cols, channels]`
79698 * - If `dataFormat === CHANNEL_FIRST`:
79699 * 4D tensor with shape:
79700 * `[batchSize, channels, rows, cols]`
79701 *
79702 * Output shape
79703 * - If `dataFormat === CHANNEL_LAST`:
79704 * 4D tensor with shape:
79705 * `[batchSize, pooledRows, pooledCols, channels]`
79706 * - If `dataFormat === CHANNEL_FIRST`:
79707 * 4D tensor with shape:
79708 * `[batchSize, channels, pooledRows, pooledCols]`
79709 *
79710 * `tf.avgPool2d` is an alias.
79711 *
79712 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79713 */
function averagePooling2d(args) {
  const layer = new AveragePooling2D(args);
  return layer;
}
/**
 * Alias of `averagePooling2d`: builds an `AveragePooling2D` layer from `args`.
 */
function avgPool2d(args) {
  return new AveragePooling2D(args);
}
/**
 * Alias of `averagePooling2d`, retained for backwards compatibility.
 * See https://github.com/tensorflow/tfjs/issues/152
 */
function avgPooling2d(args) {
  return new AveragePooling2D(args);
}
79725 /**
79726 * Average pooling operation for 3D data.
79727 *
79728 * Input shape
79729 * - If `dataFormat === channelsLast`:
79730 * 5D tensor with shape:
79731 * `[batchSize, depths, rows, cols, channels]`
79732 * - If `dataFormat === channelsFirst`:
 * 5D tensor with shape:
79734 * `[batchSize, channels, depths, rows, cols]`
79735 *
79736 * Output shape
79737 * - If `dataFormat=channelsLast`:
79738 * 5D tensor with shape:
79739 * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
79740 * - If `dataFormat=channelsFirst`:
79741 * 5D tensor with shape:
79742 * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
79743 *
79744 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79745 */
function averagePooling3d(args) {
  const layer = new AveragePooling3D(args);
  return layer;
}
/**
 * Alias of `averagePooling3d`: builds an `AveragePooling3D` layer from `args`.
 */
function avgPool3d(args) {
  return new AveragePooling3D(args);
}
/**
 * Alias of `averagePooling3d`, retained for backwards compatibility.
 * See https://github.com/tensorflow/tfjs/issues/152
 */
function avgPooling3d(args) {
  return new AveragePooling3D(args);
}
79757 /**
79758 * Global average pooling operation for temporal data.
79759 *
79760 * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
79761 *
79762 * Output Shape: 2D tensor with shape: `[batchSize, features]`.
79763 *
79764 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79765 */
function globalAveragePooling1d(args) {
  const layer = new GlobalAveragePooling1D(args);
  return layer;
}
79769 /**
79770 * Global average pooling operation for spatial data.
79771 *
79772 * Input shape:
79773 * - If `dataFormat` is `CHANNEL_LAST`:
79774 * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
79775 * - If `dataFormat` is `CHANNEL_FIRST`:
79776 * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
79777 *
79778 * Output shape:
79779 * 2D tensor with shape: `[batchSize, channels]`.
79780 *
79781 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79782 */
function globalAveragePooling2d(args) {
  const layer = new GlobalAveragePooling2D(args);
  return layer;
}
79786 /**
79787 * Global max pooling operation for temporal data.
79788 *
79789 * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
79790 *
79791 * Output Shape: 2D tensor with shape: `[batchSize, features]`.
79792 *
79793 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79794 */
function globalMaxPooling1d(args) {
  const layer = new GlobalMaxPooling1D(args);
  return layer;
}
79798 /**
79799 * Global max pooling operation for spatial data.
79800 *
79801 * Input shape:
79802 * - If `dataFormat` is `CHANNEL_LAST`:
79803 * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
79804 * - If `dataFormat` is `CHANNEL_FIRST`:
79805 * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
79806 *
79807 * Output shape:
79808 * 2D tensor with shape: `[batchSize, channels]`.
79809 *
79810 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79811 */
function globalMaxPooling2d(args) {
  const layer = new GlobalMaxPooling2D(args);
  return layer;
}
79815 /**
79816 * Max pooling operation for temporal data.
79817 *
79818 * Input shape: `[batchSize, inLength, channels]`
79819 *
79820 * Output shape: `[batchSize, pooledLength, channels]`
79821 *
79822 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79823 */
function maxPooling1d(args) {
  const layer = new MaxPooling1D(args);
  return layer;
}
79827 /**
79828 * Max pooling operation for spatial data.
79829 *
79830 * Input shape
79831 * - If `dataFormat === CHANNEL_LAST`:
79832 * 4D tensor with shape:
79833 * `[batchSize, rows, cols, channels]`
79834 * - If `dataFormat === CHANNEL_FIRST`:
79835 * 4D tensor with shape:
79836 * `[batchSize, channels, rows, cols]`
79837 *
79838 * Output shape
79839 * - If `dataFormat=CHANNEL_LAST`:
79840 * 4D tensor with shape:
79841 * `[batchSize, pooledRows, pooledCols, channels]`
79842 * - If `dataFormat=CHANNEL_FIRST`:
79843 * 4D tensor with shape:
79844 * `[batchSize, channels, pooledRows, pooledCols]`
79845 *
79846 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79847 */
function maxPooling2d(args) {
  const layer = new MaxPooling2D(args);
  return layer;
}
79851 /**
79852 * Max pooling operation for 3D data.
79853 *
79854 * Input shape
79855 * - If `dataFormat === channelsLast`:
79856 * 5D tensor with shape:
79857 * `[batchSize, depths, rows, cols, channels]`
79858 * - If `dataFormat === channelsFirst`:
79859 * 5D tensor with shape:
79860 * `[batchSize, channels, depths, rows, cols]`
79861 *
79862 * Output shape
79863 * - If `dataFormat=channelsLast`:
79864 * 5D tensor with shape:
79865 * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
79866 * - If `dataFormat=channelsFirst`:
79867 * 5D tensor with shape:
79868 * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
79869 *
79870 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
79871 */
function maxPooling3d(args) {
  const layer = new MaxPooling3D(args);
  return layer;
}
79875 // Recurrent Layers.
79876 /**
79877 * Gated Recurrent Unit - Cho et al. 2014.
79878 *
79879 * This is an `RNN` layer consisting of one `GRUCell`. However, unlike
 * the underlying `GRUCell`, the `apply` method of `GRU` operates
79881 * on a sequence of inputs. The shape of the input (not including the first,
79882 * batch dimension) needs to be at least 2-D, with the first dimension being
79883 * time steps. For example:
79884 *
79885 * ```js
79886 * const rnn = tf.layers.gru({units: 8, returnSequences: true});
79887 *
79888 * // Create an input with 10 time steps.
79889 * const input = tf.input({shape: [10, 20]});
79890 * const output = rnn.apply(input);
79891 *
79892 * console.log(JSON.stringify(output.shape));
79893 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
79894 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
 * // 3rd dimension is the `GRUCell`'s number of units.
 * ```
79896 *
79897 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
79898 */
function gru(args) {
  const layer = new GRU(args);
  return layer;
}
79902 /**
79903 * Cell class for `GRU`.
79904 *
79905 * `GRUCell` is distinct from the `RNN` subclass `GRU` in that its
79906 * `apply` method takes the input data of only a single time step and returns
79907 * the cell's output at the time step, while `GRU` takes the input data
79908 * over a number of time steps. For example:
79909 *
79910 * ```js
79911 * const cell = tf.layers.gruCell({units: 2});
79912 * const input = tf.input({shape: [10]});
79913 * const output = cell.apply(input);
79914 *
79915 * console.log(JSON.stringify(output.shape));
79916 * // [null, 10]: This is the cell's output at a single time step. The 1st
79917 * // dimension is the unknown batch size.
79918 * ```
79919 *
79920 * Instance(s) of `GRUCell` can be used to construct `RNN` layers. The
79921 * most typical use of this workflow is to combine a number of cells into a
79922 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
79923 * RNN. For example:
79924 *
79925 * ```js
79926 * const cells = [
79927 * tf.layers.gruCell({units: 4}),
79928 * tf.layers.gruCell({units: 8}),
79929 * ];
79930 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
79931 *
79932 * // Create an input with 10 time steps and a length-20 vector at each step.
79933 * const input = tf.input({shape: [10, 20]});
79934 * const output = rnn.apply(input);
79935 *
79936 * console.log(JSON.stringify(output.shape));
79937 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
79938 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
79939 * // 3rd dimension is the last `gruCell`'s number of units.
79940 * ```
79941 *
79942 * To create an `RNN` consisting of only *one* `GRUCell`, use the
79943 * `tf.layers.gru`.
79944 *
79945 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
79946 */
function gruCell(args) {
  const cell = new GRUCell(args);
  return cell;
}
79950 /**
79951 * Long-Short Term Memory layer - Hochreiter 1997.
79952 *
79953 * This is an `RNN` layer consisting of one `LSTMCell`. However, unlike
79954 * the underlying `LSTMCell`, the `apply` method of `LSTM` operates
79955 * on a sequence of inputs. The shape of the input (not including the first,
79956 * batch dimension) needs to be at least 2-D, with the first dimension being
79957 * time steps. For example:
79958 *
79959 * ```js
79960 * const lstm = tf.layers.lstm({units: 8, returnSequences: true});
79961 *
79962 * // Create an input with 10 time steps.
79963 * const input = tf.input({shape: [10, 20]});
79964 * const output = lstm.apply(input);
79965 *
79966 * console.log(JSON.stringify(output.shape));
79967 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
79968 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
 * // 3rd dimension is the `LSTMCell`'s number of units.
 * ```
79970 *
79971 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
79972 */
function lstm(args) {
  const layer = new LSTM(args);
  return layer;
}
79976 /**
79977 * Cell class for `LSTM`.
79978 *
79979 * `LSTMCell` is distinct from the `RNN` subclass `LSTM` in that its
79980 * `apply` method takes the input data of only a single time step and returns
79981 * the cell's output at the time step, while `LSTM` takes the input data
79982 * over a number of time steps. For example:
79983 *
79984 * ```js
79985 * const cell = tf.layers.lstmCell({units: 2});
79986 * const input = tf.input({shape: [10]});
79987 * const output = cell.apply(input);
79988 *
79989 * console.log(JSON.stringify(output.shape));
79990 * // [null, 10]: This is the cell's output at a single time step. The 1st
79991 * // dimension is the unknown batch size.
79992 * ```
79993 *
79994 * Instance(s) of `LSTMCell` can be used to construct `RNN` layers. The
79995 * most typical use of this workflow is to combine a number of cells into a
79996 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
79997 * RNN. For example:
79998 *
79999 * ```js
80000 * const cells = [
80001 * tf.layers.lstmCell({units: 4}),
80002 * tf.layers.lstmCell({units: 8}),
80003 * ];
80004 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
80005 *
80006 * // Create an input with 10 time steps and a length-20 vector at each step.
80007 * const input = tf.input({shape: [10, 20]});
80008 * const output = rnn.apply(input);
80009 *
80010 * console.log(JSON.stringify(output.shape));
80011 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
80012 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
80013 * // 3rd dimension is the last `lstmCell`'s number of units.
80014 * ```
80015 *
80016 * To create an `RNN` consisting of only *one* `LSTMCell`, use the
80017 * `tf.layers.lstm`.
80018 *
80019 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
80020 */
function lstmCell(args) {
  const cell = new LSTMCell(args);
  return cell;
}
80024 /**
80025 * Fully-connected RNN where the output is to be fed back to input.
80026 *
80027 * This is an `RNN` layer consisting of one `SimpleRNNCell`. However, unlike
80028 * the underlying `SimpleRNNCell`, the `apply` method of `SimpleRNN` operates
80029 * on a sequence of inputs. The shape of the input (not including the first,
80030 * batch dimension) needs to be at least 2-D, with the first dimension being
80031 * time steps. For example:
80032 *
80033 * ```js
80034 * const rnn = tf.layers.simpleRNN({units: 8, returnSequences: true});
80035 *
80036 * // Create an input with 10 time steps.
80037 * const input = tf.input({shape: [10, 20]});
80038 * const output = rnn.apply(input);
80039 *
80040 * console.log(JSON.stringify(output.shape));
80041 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
80042 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
80043 * // 3rd dimension is the `SimpleRNNCell`'s number of units.
80044 * ```
80045 *
80046 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
80047 */
function simpleRNN(args) {
  const layer = new SimpleRNN(args);
  return layer;
}
80051 /**
80052 * Cell class for `SimpleRNN`.
80053 *
80054 * `SimpleRNNCell` is distinct from the `RNN` subclass `SimpleRNN` in that its
80055 * `apply` method takes the input data of only a single time step and returns
80056 * the cell's output at the time step, while `SimpleRNN` takes the input data
80057 * over a number of time steps. For example:
80058 *
80059 * ```js
80060 * const cell = tf.layers.simpleRNNCell({units: 2});
80061 * const input = tf.input({shape: [10]});
80062 * const output = cell.apply(input);
80063 *
80064 * console.log(JSON.stringify(output.shape));
80065 * // [null, 10]: This is the cell's output at a single time step. The 1st
80066 * // dimension is the unknown batch size.
80067 * ```
80068 *
80069 * Instance(s) of `SimpleRNNCell` can be used to construct `RNN` layers. The
80070 * most typical use of this workflow is to combine a number of cells into a
80071 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
80072 * RNN. For example:
80073 *
80074 * ```js
80075 * const cells = [
80076 * tf.layers.simpleRNNCell({units: 4}),
80077 * tf.layers.simpleRNNCell({units: 8}),
80078 * ];
80079 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
80080 *
80081 * // Create an input with 10 time steps and a length-20 vector at each step.
80082 * const input = tf.input({shape: [10, 20]});
80083 * const output = rnn.apply(input);
80084 *
80085 * console.log(JSON.stringify(output.shape));
80086 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
80087 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
80088 * // 3rd dimension is the last `SimpleRNNCell`'s number of units.
80089 * ```
80090 *
80091 * To create an `RNN` consisting of only *one* `SimpleRNNCell`, use the
80092 * `tf.layers.simpleRNN`.
80093 *
80094 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
80095 */
function simpleRNNCell(args) {
  const cell = new SimpleRNNCell(args);
  return cell;
}
80099 /**
80100 * Convolutional LSTM layer - Xingjian Shi 2015.
80101 *
80102 * This is a `ConvRNN2D` layer consisting of one `ConvLSTM2DCell`. However,
80103 * unlike the underlying `ConvLSTM2DCell`, the `apply` method of `ConvLSTM2D`
80104 * operates on a sequence of inputs. The shape of the input (not including the
80105 * first, batch dimension) needs to be 4-D, with the first dimension being time
80106 * steps. For example:
80107 *
80108 * ```js
80109 * const filters = 3;
80110 * const kernelSize = 3;
80111 *
80112 * const batchSize = 4;
80113 * const sequenceLength = 2;
80114 * const size = 5;
80115 * const channels = 3;
80116 *
80117 * const inputShape = [batchSize, sequenceLength, size, size, channels];
80118 * const input = tf.ones(inputShape);
80119 *
80120 * const layer = tf.layers.convLstm2d({filters, kernelSize});
80121 *
80122 * const output = layer.apply(input);
80123 * ```
80124 */
80125 /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
function convLstm2d(args) {
  const layer = new ConvLSTM2D(args);
  return layer;
}
80129 /**
80130 * Cell class for `ConvLSTM2D`.
80131 *
80132 * `ConvLSTM2DCell` is distinct from the `ConvRNN2D` subclass `ConvLSTM2D` in
80133 * that its `call` method takes the input data of only a single time step and
80134 * returns the cell's output at the time step, while `ConvLSTM2D` takes the
80135 * input data over a number of time steps. For example:
80136 *
80137 * ```js
80138 * const filters = 3;
80139 * const kernelSize = 3;
80140 *
80141 * const sequenceLength = 1;
80142 * const size = 5;
80143 * const channels = 3;
80144 *
80145 * const inputShape = [sequenceLength, size, size, channels];
80146 * const input = tf.ones(inputShape);
80147 *
80148 * const cell = tf.layers.convLstm2dCell({filters, kernelSize});
80149 *
80150 * cell.build(input.shape);
80151 *
80152 * const outputSize = size - kernelSize + 1;
80153 * const outShape = [sequenceLength, outputSize, outputSize, filters];
80154 *
80155 * const initialH = tf.zeros(outShape);
80156 * const initialC = tf.zeros(outShape);
80157 *
80158 * const [o, h, c] = cell.call([input, initialH, initialC], {});
80159 * ```
80160 */
80161 /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
function convLstm2dCell(args) {
  const cell = new ConvLSTM2DCell(args);
  return cell;
}
/**
 * Base class for recurrent layers.
 *
 * Input shape:
 *   3D tensor with shape `[batchSize, timeSteps, inputDim]`.
 *
 * Output shape:
 *   - if `returnState`, an Array of tensors (i.e., `tf.Tensor`s). The first
 *     tensor is the output. The remaining tensors are the states at the
 *     last time step, each with shape `[batchSize, units]`.
 *   - if `returnSequences`, the output will have shape
 *     `[batchSize, timeSteps, units]`.
 *   - else, the output will have shape `[batchSize, units]`.
 *
 * Masking:
 *   This layer supports masking for input data with a variable number
 *   of timesteps. To introduce masks to your data,
 *   use an embedding layer with the `mask_zero` parameter
 *   set to `True`.
 *
 * Notes on using statefulness in RNNs:
 *   You can set RNN layers to be 'stateful', which means that the states
 *   computed for the samples in one batch will be reused as initial states
 *   for the samples in the next batch. This assumes a one-to-one mapping
 *   between samples in different successive batches.
 *
 *   To enable statefulness:
 *     - specify `stateful: true` in the layer constructor.
 *     - specify a fixed batch size for your model. For a sequential model,
 *       pass `batchInputShape=[...]` to the first layer in your model. For a
 *       functional model with one or more Input layers, pass
 *       `batchShape=[...]` to all the first layers in your model.
 *       This is the expected shape of your inputs *including the batch size*.
 *       It should be a tuple of integers, e.g. `(32, 10, 100)`.
 *     - specify `shuffle=False` when calling fit().
 *
 *   To reset the states of your model, call `.resetStates()` on either
 *   a specific layer, or on your entire model.
 *
 * Note on specifying the initial state of RNNs:
 *   You can specify the initial state of RNN layers symbolically by
 *   calling them with the option `initialState`. The value of
 *   `initialState` should be a tensor or list of tensors representing
 *   the initial state of the RNN layer.
 *
 *   You can specify the initial state of RNN layers numerically by
 *   calling `resetStates` with the keyword argument `states`. The value of
 *   `states` should be a numpy array or list of numpy arrays representing
 *   the initial state of the RNN layer.
 *
 * Note on passing external constants to RNNs:
 *   You can pass "external" constants to the cell using the `constants`
 *   keyword argument of `RNN.call` method. This requires that the `cell.call`
 *   method accepts the same keyword argument `constants`. Such constants
 *   can be used to condition the cell transformation on additional static
 *   inputs (not changing over time), a.k.a. an attention mechanism.
 *
 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
 */
function rnn(args) {
  var layer = new RNN(args);
  return layer;
}
/**
 * Wrapper allowing a stack of RNN cells to behave as a single cell.
 *
 * Used to implement efficient stacked RNNs.
 *
 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
 */
function stackedRNNCells(args) {
  var stackedCells = new StackedRNNCells(args);
  return stackedCells;
}
// Wrapper Layers.
/**
 * Factory for the `Bidirectional` wrapper layer, built from `args`.
 *
 * @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'}
 */
function bidirectional(args) {
  var wrapper = new Bidirectional(args);
  return wrapper;
}
/**
 * This wrapper applies a layer to every temporal slice of an input.
 *
 * The input should be at least 3D, and the dimension of index `1` is
 * considered to be the temporal dimension.
 *
 * Consider a batch of 32 samples, where each sample is a sequence of 10
 * vectors of 16 dimensions. The batch input shape of the layer is then
 * `[32, 10, 16]`, and the `inputShape`, not including the sample dimension,
 * is `[10, 16]`.
 *
 * You can then use `TimeDistributed` to apply a `Dense` layer to each of the
 * 10 timesteps, independently:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.timeDistributed({
 *   layer: tf.layers.dense({units: 8}),
 *   inputShape: [10, 16],
 * }));
 *
 * // Now model.outputShape = [null, 10, 8].
 * // The output will then have shape `[32, 10, 8]`.
 *
 * // In subsequent layers, there is no need for `inputShape`:
 * model.add(tf.layers.timeDistributed({layer: tf.layers.dense({units: 32})}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // Now model.outputShape = [null, 10, 32].
 * ```
 *
 * The output will then have shape `[32, 10, 32]`.
 *
 * `TimeDistributed` can be used with arbitrary layers, not just `Dense`; for
 * instance, with a `Conv2D` layer:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.timeDistributed({
 *   layer: tf.layers.conv2d({filters: 64, kernelSize: [3, 3]}),
 *   inputShape: [10, 299, 299, 3],
 * }));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * ```
 *
 * @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'}
 */
function timeDistributed(args) {
  var wrapper = new TimeDistributed(args);
  return wrapper;
}
// Aliases for pooling: each short `*Pool*` name is bound to the very same
// factory function as its `*Pooling*` counterpart (no wrapper is created).
var maxPool1d = maxPooling1d;
var maxPool2d = maxPooling2d;
var globalMaxPool1d = globalMaxPooling1d;
var globalMaxPool2d = globalMaxPooling2d;
/**
 * Apply additive zero-centered Gaussian noise.
 *
 * As it is a regularization layer, it is only active at training time.
 *
 * This is useful to mitigate overfitting
 * (you could see it as a form of random data augmentation).
 * Gaussian Noise (GS) is a natural choice as corruption process
 * for real valued inputs.
 *
 * Arguments:
 *   - `stddev`: float, standard deviation of the noise distribution.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
 */
function gaussianNoise(args) {
  var layer = new GaussianNoise(args);
  return layer;
}
/**
 * Apply multiplicative 1-centered Gaussian noise.
 *
 * As it is a regularization layer, it is only active at training time.
 *
 * Arguments:
 *   - `rate`: float, drop probability (as with `Dropout`).
 *     The multiplicative noise will have
 *     standard deviation `sqrt(rate / (1 - rate))`.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * References:
 *   - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](
 *     http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
 *
 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
 */
function gaussianDropout(args) {
  var layer = new GaussianDropout(args);
  return layer;
}
/**
 * Applies Alpha Dropout to the input.
 *
 * As it is a regularization layer, it is only active at training time.
 *
 * Alpha Dropout is a `Dropout` that keeps the mean and variance of inputs
 * at their original values, in order to ensure the self-normalizing property
 * even after this dropout.
 * Alpha Dropout fits well with Scaled Exponential Linear Units
 * by randomly setting activations to the negative saturation value.
 *
 * Arguments:
 *   - `rate`: float, drop probability (as with `Dropout`).
 *     The multiplicative noise will have
 *     standard deviation `sqrt(rate / (1 - rate))`.
 *   - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the
 *     shape for randomly generated keep/drop flags.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * References:
 *   - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
 *
 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
 */
function alphaDropout(args) {
  var layer = new AlphaDropout(args);
  return layer;
}
/**
 * Masks a sequence by using a mask value to skip timesteps.
 *
 * If all features for a given sample timestep are equal to `mask_value`,
 * then the sample timestep will be masked (skipped) in all downstream layers
 * (as long as they support masking).
 *
 * If any downstream layer does not support masking yet receives such
 * an input mask, an exception will be raised.
 *
 * Arguments:
 *   - `maskValue`: Either None or mask value to skip.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * @doc {heading: 'Layers', subheading: 'Mask', namespace: 'layers'}
 */
function masking(args) {
  var layer = new Masking(args);
  return layer;
}
/**
 * A preprocessing layer which rescales input values to a new range.
 *
 * This layer rescales every value of an input (often an image) by multiplying
 * by `scale` and adding `offset`.
 *
 * For instance:
 *   1. To rescale an input in the `[0, 255]` range to be in the `[0, 1]`
 *      range, you would pass `{scale: 1 / 255}`.
 *   2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]`
 *      range, you would pass `{scale: 1 / 127.5, offset: -1}`.
 *
 * The rescaling is applied both during training and inference. Inputs can be
 * of integer or floating point dtype, and by default the layer will output
 * floats.
 *
 * Arguments:
 *   - `scale`: Float, the scale to apply to the inputs.
 *   - `offset`: Float, the offset to apply to the inputs.
 *
 * Input shape:
 *   Arbitrary.
 *
 * Output shape:
 *   Same as input.
 *
 * @doc {heading: 'Layers', subheading: 'Rescaling', namespace: 'layers'}
 */
function rescaling(args) {
  var layer = new Rescaling(args);
  return layer;
}
/**
 * A preprocessing layer which center crops images.
 *
 * This layer crops the central portion of the images to a target size. If an
 * image is smaller than the target size, it will be resized and cropped so as
 * to return the largest possible window in the image that matches the target
 * aspect ratio.
 *
 * Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
 * of integer or floating point dtype.
 *
 * If the input height/width is even and the target height/width is odd (or
 * inversely), the input image is left-padded by 1 pixel.
 *
 * Arguments:
 *   - `height`: Integer, the height of the output shape.
 *   - `width`: Integer, the width of the output shape.
 *
 * Input shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., height, width, channels)`, in `channelsLast` format.
 *
 * Output shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., targetHeight, targetWidth, channels)`.
 *
 * @doc {heading: 'Layers', subheading: 'CenterCrop', namespace: 'layers'}
 */
function centerCrop(args) {
  var layer = new CenterCrop(args);
  return layer;
}
/**
 * A preprocessing layer which resizes images.
 *
 * This layer resizes an image input to a target height and width. The input
 * should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"`
 * format. Input pixel values can be of any range (e.g. `[0., 1.)` or
 * `[0, 255]`) and of integer or floating point dtype. By default, the layer
 * will output floats.
 *
 * Arguments:
 *   - `height`: number, the height for the output tensor.
 *   - `width`: number, the width for the output tensor.
 *   - `interpolation`: string, the method for image resizing interpolation.
 *   - `cropToAspectRatio`: boolean, whether to keep image aspect ratio.
 *
 * Input shape:
 *   Arbitrary.
 *
 * Output shape:
 *   `(height, width, channels)`, i.e. the target height and width with the
 *   input's number of channels.
 *
 * @doc {heading: 'Layers', subheading: 'Resizing', namespace: 'layers'}
 */
function resizing(args) {
  var layer = new Resizing(args);
  return layer;
}
/**
 * A preprocessing layer which encodes integer features.
 *
 * This layer provides options for condensing data into a categorical encoding
 * when the total number of tokens is known in advance. It accepts integer
 * values as inputs, and it outputs a dense representation of those
 * inputs.
 *
 * Arguments:
 *
 *   numTokens: The total number of tokens the layer should support. All
 *     inputs to the layer must be integers in the range `0 <= value <
 *     numTokens`, or an error will be thrown.
 *
 *   outputMode: Specification for the output of the layer.
 *     Defaults to `multiHot`. Values can be `oneHot`, `multiHot` or
 *     `count`, configuring the layer as follows:
 *
 *     oneHot: Encodes each individual element in the input into an
 *       array of `numTokens` size, containing a 1 at the element index. If
 *       the last dimension is size 1, will encode on that dimension. If the
 *       last dimension is not size 1, will append a new dimension for the
 *       encoded output.
 *
 *     multiHot: Encodes each sample in the input into a single array
 *       of `numTokens` size, containing a 1 for each vocabulary term
 *       present in the sample. Treats the last dimension as the sample
 *       dimension; if input shape is `(..., sampleLength)`, output shape
 *       will be `(..., numTokens)`.
 *
 *     count: Like `multiHot`, but the int array contains a count of
 *       the number of times the token at that index appeared in the sample.
 *
 *   For all output modes, currently only output up to rank 2 is supported.
 *
 * Call arguments:
 *   inputs: A 1D or 2D tensor of integer inputs.
 *   countWeights: A tensor in the same shape as `inputs` indicating the
 *     weight for each sample value when summing up in `count` mode. Not used
 *     in `multiHot` or `oneHot` modes.
 *
 * @doc {heading: 'Layers', subheading: 'CategoryEncoding', namespace: 'layers'}
 */
function categoryEncoding(args) {
  var layer = new CategoryEncoding(args);
  return layer;
}
/**
 * A preprocessing layer which randomly varies image width during training.
 *
 * This layer randomly adjusts the width of a batch of images by a random
 * factor.
 *
 * The input should be a 3D (unbatched) or 4D (batched) tensor in
 * the `"channels_last"` image data format. Input pixel values can be of any
 * range (e.g. `[0., 1.)` or `[0, 255]`) and of integer or floating point
 * dtype. By default, the layer will output floats. By default, this layer is
 * inactive during inference. For an overview and full list of preprocessing
 * layers, see the preprocessing
 * [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers).
 *
 * Arguments:
 *
 *   factor:
 *     A positive float (fraction of original width), or a tuple of size 2
 *     representing lower and upper bound for resizing horizontally.
 *     When represented as a single float, this value is used for both the
 *     upper and lower bound. For instance, `factor=(0.2, 0.3)` results in an
 *     output with width changed by a random amount in the range `[20%, 30%]`.
 *     `factor=(-0.2, 0.3)` results in an output with width changed by a
 *     random amount in the range `[-20%, +30%]`. `factor=0.2` results in an
 *     output with width changed by a random amount in the range
 *     `[-20%, +20%]`.
 *   interpolation:
 *     String, the interpolation method.
 *     Defaults to `bilinear`.
 *     Supports `"bilinear"`, `"nearest"`.
 *     The tf methods `"bicubic"`, `"area"`, `"lanczos3"`, `"lanczos5"`,
 *     `"gaussian"`, `"mitchellcubic"` are unimplemented in tfjs.
 *   seed:
 *     Integer. Used to create a random seed.
 *
 * Input shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., height, width, channels)`, in `"channels_last"` format.
 *
 * Output shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., height, random_width, channels)`.
 *
 * @doc {heading: 'Layers', subheading: 'RandomWidth', namespace: 'layers'}
 */
function randomWidth(args) {
  var layer = new RandomWidth(args);
  return layer;
}
80590
// Namespace object gathering the layer factories exposed under the `layers`
// namespace (the `@doc {namespace: 'layers'}` entries, used as e.g.
// `tf.layers.dense`). Most values are factory functions; a few, such as `RNN`,
// are constructors. Alias keys (`avgPool*`, `maxPool*`, `globalMaxPool*`)
// point at the same functions as their long-form names.
var exports_layers = {
  // Null prototype: the namespace carries no inherited Object.prototype members.
  __proto__: null,
  Layer: Layer,
  RNN: RNN,
  RNNCell: RNNCell,
  activation: activation,
  // NOTE(review): `$N`-suffixed bindings (add$1, conv2d$1, elu$2, ...) look like
  // bundler renames avoiding clashes with same-named functions elsewhere in
  // this bundle — confirm against the build.
  add: add$1,
  alphaDropout: alphaDropout,
  average: average,
  averagePooling1d: averagePooling1d,
  averagePooling2d: averagePooling2d,
  averagePooling3d: averagePooling3d,
  avgPool1d: avgPool1d,
  avgPool2d: avgPool2d,
  avgPool3d: avgPool3d,
  avgPooling1d: avgPooling1d,
  avgPooling2d: avgPooling2d,
  avgPooling3d: avgPooling3d,
  batchNormalization: batchNormalization,
  bidirectional: bidirectional,
  categoryEncoding: categoryEncoding,
  centerCrop: centerCrop,
  concatenate: concatenate,
  conv1d: conv1d,
  conv2d: conv2d$1,
  conv2dTranspose: conv2dTranspose,
  conv3d: conv3d,
  conv3dTranspose: conv3dTranspose,
  convLstm2d: convLstm2d,
  convLstm2dCell: convLstm2dCell,
  cropping2D: cropping2D,
  dense: dense,
  depthwiseConv2d: depthwiseConv2d,
  dot: dot,
  dropout: dropout,
  elu: elu$2,
  embedding: embedding,
  flatten: flatten,
  gaussianDropout: gaussianDropout,
  gaussianNoise: gaussianNoise,
  globalAveragePooling1d: globalAveragePooling1d,
  globalAveragePooling2d: globalAveragePooling2d,
  globalMaxPool1d: globalMaxPool1d,
  globalMaxPool2d: globalMaxPool2d,
  globalMaxPooling1d: globalMaxPooling1d,
  globalMaxPooling2d: globalMaxPooling2d,
  gru: gru,
  gruCell: gruCell,
  input: input,
  inputLayer: inputLayer,
  layerNormalization: layerNormalization,
  leakyReLU: leakyReLU,
  lstm: lstm,
  lstmCell: lstmCell,
  masking: masking,
  maxPool1d: maxPool1d,
  maxPool2d: maxPool2d,
  maxPooling1d: maxPooling1d,
  maxPooling2d: maxPooling2d,
  maxPooling3d: maxPooling3d,
  maximum: maximum$2,
  minimum: minimum$2,
  multiply: multiply$2,
  permute: permute,
  prelu: prelu$2,
  randomWidth: randomWidth,
  reLU: reLU,
  repeatVector: repeatVector,
  rescaling: rescaling,
  reshape: reshape$2,
  resizing: resizing,
  rnn: rnn,
  separableConv2d: separableConv2d,
  simpleRNN: simpleRNN,
  simpleRNNCell: simpleRNNCell,
  softmax: softmax$2,
  spatialDropout1d: spatialDropout1d,
  stackedRNNCells: stackedRNNCells,
  thresholdedReLU: thresholdedReLU,
  timeDistributed: timeDistributed,
  upSampling2d: upSampling2d,
  zeroPadding2d: zeroPadding2d
};
80674
/**
 * Binary accuracy metric function.
 *
 * `yTrue` and `yPred` can have 0-1 values. Example:
 * ```js
 * const x = tf.tensor2d([[1, 1, 1, 1], [0, 0, 0, 0]], [2, 4]);
 * const y = tf.tensor2d([[1, 0, 1, 0], [0, 0, 0, 1]], [2, 4]);
 * const accuracy = tf.metrics.binaryAccuracy(x, y);
 * accuracy.print();
 * ```
 *
 * `yTrue` and `yPred` can also have floating-number values between 0 and 1, in
 * which case the values will be thresholded at 0.5 to yield 0-1 values (i.e.,
 * a value >= 0.5 and <= 1.0 is interpreted as 1).
 *
 * Example:
 * ```js
 * const x = tf.tensor1d([1, 1, 1, 1, 0, 0, 0, 0]);
 * const y = tf.tensor1d([0.2, 0.4, 0.6, 0.8, 0.2, 0.3, 0.4, 0.7]);
 * const accuracy = tf.metrics.binaryAccuracy(x, y);
 * accuracy.print();
 * ```
 *
 * @param yTrue Binary Tensor of truth.
 * @param yPred Binary Tensor of prediction.
 * @return Accuracy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function binaryAccuracy(yTrue, yPred) {
  var accuracy = binaryAccuracy$1(yTrue, yPred);
  return accuracy;
}
/**
 * Binary crossentropy metric function.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d([[0], [1], [1], [1]]);
 * const y = tf.tensor2d([[0], [0], [0.5], [1]]);
 * const crossentropy = tf.metrics.binaryCrossentropy(x, y);
 * crossentropy.print();
 * ```
 *
 * @param yTrue Binary Tensor of truth.
 * @param yPred Binary Tensor of prediction, probabilities for the `1` case.
 * @return Binary crossentropy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function binaryCrossentropy(yTrue, yPred) {
  var crossentropy = binaryCrossentropy$1(yTrue, yPred);
  return crossentropy;
}
/**
 * Sparse categorical accuracy metric function.
 *
 * Example:
 * ```js
 * const yTrue = tf.tensor1d([1, 1, 2, 2, 0]);
 * const yPred = tf.tensor2d(
 *     [[0, 1, 0], [1, 0, 0], [0, 0.4, 0.6], [0, 0.6, 0.4], [0.7, 0.3, 0]]);
 * const accuracy = tf.metrics.sparseCategoricalAccuracy(yTrue, yPred);
 * accuracy.print();
 * ```
 *
 * @param yTrue True labels: indices.
 * @param yPred Predicted probabilities or logits.
 * @returns Accuracy tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function sparseCategoricalAccuracy(yTrue, yPred) {
  var accuracy = sparseCategoricalAccuracy$1(yTrue, yPred);
  return accuracy;
}
/**
 * Categorical accuracy metric function.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d([[0, 0, 0, 1], [0, 0, 0, 1]]);
 * const y = tf.tensor2d([[0.1, 0.8, 0.05, 0.05], [0.1, 0.05, 0.05, 0.8]]);
 * const accuracy = tf.metrics.categoricalAccuracy(x, y);
 * accuracy.print();
 * ```
 *
 * @param yTrue Binary Tensor of truth: one-hot encoding of categories.
 * @param yPred Binary Tensor of prediction: probabilities or logits for the
 *   same categories as in `yTrue`.
 * @return Accuracy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function categoricalAccuracy(yTrue, yPred) {
  var accuracy = categoricalAccuracy$1(yTrue, yPred);
  return accuracy;
}
/**
 * Categorical crossentropy between an output tensor and a target tensor.
 *
 * @param yTrue The target tensor, of the same shape as `yPred`.
 * @param yPred The predicted tensor to compare against `yTrue`.
 *   NOTE(review): presumably expected to be the result of a softmax; this
 *   wrapper exposes no `fromLogits` option — confirm against the underlying
 *   loss implementation.
 * @return Categorical crossentropy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function categoricalCrossentropy(yTrue, yPred) {
  var crossentropy = categoricalCrossentropy$1(yTrue, yPred);
  return crossentropy;
}
/**
 * Computes the precision of the predictions with respect to the labels.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d(
 *    [
 *      [0, 0, 0, 1],
 *      [0, 1, 0, 0],
 *      [0, 0, 0, 1],
 *      [1, 0, 0, 0],
 *      [0, 0, 1, 0]
 *    ]
 * );
 *
 * const y = tf.tensor2d(
 *    [
 *      [0, 0, 1, 0],
 *      [0, 1, 0, 0],
 *      [0, 0, 0, 1],
 *      [0, 1, 0, 0],
 *      [0, 1, 0, 0]
 *    ]
 * );
 *
 * const precision = tf.metrics.precision(x, y);
 * precision.print();
 * ```
 *
 * @param yTrue The ground truth values. Expected to contain only 0-1 values.
 * @param yPred The predicted values. Expected to contain only 0-1 values.
 * @return Precision Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function precision(yTrue, yPred) {
  var result = precision$1(yTrue, yPred);
  return result;
}
/**
 * Computes the recall of the predictions with respect to the labels.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d(
 *    [
 *      [0, 0, 0, 1],
 *      [0, 1, 0, 0],
 *      [0, 0, 0, 1],
 *      [1, 0, 0, 0],
 *      [0, 0, 1, 0]
 *    ]
 * );
 *
 * const y = tf.tensor2d(
 *    [
 *      [0, 0, 1, 0],
 *      [0, 1, 0, 0],
 *      [0, 0, 0, 1],
 *      [0, 1, 0, 0],
 *      [0, 1, 0, 0]
 *    ]
 * );
 *
 * const recall = tf.metrics.recall(x, y);
 * recall.print();
 * ```
 *
 * @param yTrue The ground truth values. Expected to contain only 0-1 values.
 * @param yPred The predicted values. Expected to contain only 0-1 values.
 * @return Recall Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function recall(yTrue, yPred) {
  var result = recall$1(yTrue, yPred);
  return result;
}
/**
 * Loss or metric function: Cosine proximity.
 *
 * Mathematically, cosine proximity is defined as:
 *   `-sum(l2Normalize(yTrue) * l2Normalize(yPred))`,
 * wherein `l2Normalize()` normalizes the L2 norm of the input to 1 and `*`
 * represents element-wise multiplication.
 *
 * ```js
 * const yTrue = tf.tensor2d([[1, 0], [1, 0]]);
 * const yPred = tf.tensor2d([[1 / Math.sqrt(2), 1 / Math.sqrt(2)], [0, 1]]);
 * const proximity = tf.metrics.cosineProximity(yTrue, yPred);
 * proximity.print();
 * ```
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Cosine proximity Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function cosineProximity(yTrue, yPred) {
  var proximity = cosineProximity$1(yTrue, yPred);
  return proximity;
}
/**
 * Loss or metric function: Mean absolute error.
 *
 * Mathematically, mean absolute error is defined as:
 *   `mean(abs(yPred - yTrue))`,
 * wherein the `mean` is applied over feature dimensions.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [0, 0], [2, 3]]);
 * const yPred = tf.tensor2d([[0, 1], [0, 1], [-2, -3]]);
 * const mae = tf.metrics.meanAbsoluteError(yTrue, yPred);
 * mae.print();
 * ```
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Mean absolute error Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function meanAbsoluteError(yTrue, yPred) {
  var error = meanAbsoluteError$1(yTrue, yPred);
  return error;
}
/**
 * Loss or metric function: Mean absolute percentage error.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [10, 20]]);
 * const yPred = tf.tensor2d([[0, 1], [11, 24]]);
 * const error = tf.metrics.meanAbsolutePercentageError(yTrue, yPred);
 * error.print();
 * ```
 *
 * Aliases: `tf.metrics.MAPE`, `tf.metrics.mape`.
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Mean absolute percentage error Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function meanAbsolutePercentageError(yTrue, yPred) {
  var error = meanAbsolutePercentageError$1(yTrue, yPred);
  return error;
}
// Upper-case alias of `meanAbsolutePercentageError`.
function MAPE(yTrue, yPred) {
  return meanAbsolutePercentageError(yTrue, yPred);
}
// Lower-case alias of `meanAbsolutePercentageError`.
function mape(yTrue, yPred) {
  return meanAbsolutePercentageError(yTrue, yPred);
}
/**
 * Loss or metric function: Mean squared error.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
 * const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
 * const mse = tf.metrics.meanSquaredError(yTrue, yPred);
 * mse.print();
 * ```
 *
 * Aliases: `tf.metrics.MSE`, `tf.metrics.mse`.
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Mean squared error Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function meanSquaredError(yTrue, yPred) {
  var error = meanSquaredError$1(yTrue, yPred);
  return error;
}
// Upper-case alias of `meanSquaredError`.
function MSE(yTrue, yPred) {
  return meanSquaredError(yTrue, yPred);
}
// Lower-case alias of `meanSquaredError`.
function mse(yTrue, yPred) {
  return meanSquaredError(yTrue, yPred);
}
/**
 * Computes R2 score.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
 * const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
 * const r2Score = tf.metrics.r2Score(yTrue, yPred);
 * r2Score.print();
 * ```
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return R2 score Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function r2Score(yTrue, yPred) {
  var score = r2Score$1(yTrue, yPred);
  return score;
}
80979
// Namespace object for the metric functions documented with
// `@doc {namespace: 'metrics'}` (used as e.g. `tf.metrics.binaryAccuracy`).
// `MAPE`/`mape` and `MSE`/`mse` are alias entries for the mean absolute
// percentage error and mean squared error metrics.
var exports_metrics = {
  // Null prototype: the namespace carries no inherited Object.prototype members.
  __proto__: null,
  MAPE: MAPE,
  MSE: MSE,
  binaryAccuracy: binaryAccuracy,
  binaryCrossentropy: binaryCrossentropy,
  categoricalAccuracy: categoricalAccuracy,
  categoricalCrossentropy: categoricalCrossentropy,
  cosineProximity: cosineProximity,
  mape: mape,
  meanAbsoluteError: meanAbsoluteError,
  meanAbsolutePercentageError: meanAbsolutePercentageError,
  meanSquaredError: meanSquaredError,
  mse: mse,
  precision: precision,
  r2Score: r2Score,
  recall: recall,
  sparseCategoricalAccuracy: sparseCategoricalAccuracy
};
80999
81000 /**
81001 * @license
81002 * Copyright 2018 Google LLC
81003 *
81004 * Use of this source code is governed by an MIT-style
81005 * license that can be found in the LICENSE file or at
81006 * https://opensource.org/licenses/MIT.
81007 * =============================================================================
81008 */
81009
// Namespace object for model-level helpers; it currently exposes only
// `modelFromJSON`.
var exports_models = {
  // Null prototype: the namespace carries no inherited Object.prototype members.
  __proto__: null,
  modelFromJSON: modelFromJSON
};
81014
81015 /**
81016 * @license
81017 * Copyright 2018 Google LLC
81018 *
81019 * Use of this source code is governed by an MIT-style
81020 * license that can be found in the LICENSE file or at
81021 * https://opensource.org/licenses/MIT.
81022 * =============================================================================
81023 */
/**
 * Regularizer for L1 and L2 regularization.
 *
 * Adds a term to the loss to penalize large weights:
 *   loss += sum(l1 * abs(x)) + sum(l2 * x^2)
 *
 * @param config l1/l2 config.
 *
 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
 */
function l1l2(config) {
  var regularizer = new L1L2(config);
  return regularizer;
}
/**
 * Regularizer for L1 regularization.
 *
 * Adds a term to the loss to penalize large weights:
 *   loss += sum(l1 * abs(x))
 *
 * @param config l1 config.
 *
 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
 */
function l1(config) {
  var regularizer = l1$1(config);
  return regularizer;
}
/**
 * Regularizer for L2 regularization.
 *
 * Adds a term to the loss to penalize large weights:
 *   loss += sum(l2 * x^2)
 *
 * @param config l2 config.
 *
 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
 */
function l2(config) {
  var regularizer = l2$1(config);
  return regularizer;
}
81059
// Namespace object for the regularizer factories documented with
// `@doc {namespace: 'regularizers'}`.
var exports_regularizers = {
  // Null prototype: the namespace carries no inherited Object.prototype members.
  __proto__: null,
  l1: l1,
  l1l2: l1l2,
  l2: l2
};
81066
/**
 * Transpiled (Babel-style) class `Callback`, a subclass of `BaseCallback` that
 * holds a reference to the `LayersModel` it is attached to.
 */
var Callback = /*#__PURE__*/function (_BaseCallback) {
  _inherits(Callback, _BaseCallback);
  var _super = _createSuper(Callback);
  function Callback() {
    var _this;
    _classCallCheck(this, Callback);
    // Forward all constructor arguments to the BaseCallback constructor.
    _this = _super.apply(this, arguments);
    /** The `LayersModel` this callback is attached to; `null` until `setModel` is called. */
    _this.model = null;
    return _this;
  }
  _createClass(Callback, [{
    key: "setModel",
    /**
     * Attaches this callback to `model`.
     *
     * @param model The model to attach to; must be a `LayersModel` instance.
     * @throws Error if `model` is not an instance of `LayersModel`.
     */
    value: function setModel(model) {
      if (!(model instanceof LayersModel)) {
        throw new Error('model must be a LayersModel, not some other Container');
      }
      this.model = model;
    }
  }]);
  return Callback;
}(BaseCallback);
/** Improvement test for monitors that should decrease (mode 'min'). */
function less$2(candidate, reference) {
  var improved = candidate < reference;
  return improved;
}
/** Improvement test for monitors that should increase (mode 'max'). */
function greater$2(candidate, reference) {
  var improved = candidate > reference;
  return improved;
}
/**
 * A Callback that stops training when a monitored quantity has stopped
 * improving.
 *
 * Options (all optional; `args` itself may be null/undefined):
 *   - monitor: key read from the epoch logs (default 'val_loss').
 *   - minDelta: minimum change that counts as an improvement. Its absolute
 *     value is used (default 0).
 *   - patience: number of consecutive non-improving epochs tolerated before
 *     setting `model.stopTraining` (default 0).
 *   - verbose: when truthy, onTrainEnd logs the stopping epoch (default 0).
 *   - mode: 'min', 'max' or 'auto' (default). Other values trigger a warning
 *     and fall back to 'auto'. In 'auto' mode, a monitor whose name contains
 *     'acc' is maximized; anything else is minimized.
 *   - baseline: optional initial best value.
 *   - restoreBestWeights: unsupported; a truthy value throws
 *     NotImplementedError.
 */
var EarlyStopping = /*#__PURE__*/function (CallbackBase) {
  _inherits(EarlyStopping, CallbackBase);
  var superCtor = _createSuper(EarlyStopping);
  function EarlyStopping(args) {
    _classCallCheck(this, EarlyStopping);
    var instance = superCtor.call(this);
    var opts = args == null ? {} : args;
    if (opts.restoreBestWeights) {
      throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
    }
    instance.monitor = opts.monitor || 'val_loss';
    instance.minDelta = Math.abs(opts.minDelta || 0);
    instance.patience = opts.patience || 0;
    instance.verbose = opts.verbose || 0;
    instance.mode = opts.mode || 'auto';
    instance.baseline = opts.baseline;
    var knownModes = ['auto', 'min', 'max'];
    if (knownModes.indexOf(instance.mode) === -1) {
      console.warn("EarlyStopping mode '".concat(instance.mode, "' is invalid. ") + "Falling back to mode 'auto'.");
      instance.mode = 'auto';
    }
    // Decide the optimization direction once, then derive the comparator.
    var minimize;
    switch (instance.mode) {
      case 'min':
        minimize = true;
        break;
      case 'max':
        minimize = false;
        break;
      default:
        // 'auto': accuracy-like monitors go up, everything else goes down.
        minimize = instance.monitor.indexOf('acc') === -1;
    }
    instance.monitorFunc = minimize ? less$2 : greater$2;
    // When minimizing, an improvement means dropping by at least |minDelta|,
    // so the delta is subtracted as a negative number in onEpochEnd.
    if (minimize) {
      instance.minDelta *= -1;
    }
    return instance;
  }
  _createClass(EarlyStopping, [{
    key: "onTrainBegin",
    value: function () {
      var onTrainBeginAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function onTrainBeginBody(logs) {
        return _regeneratorRuntime().wrap(function onTrainBeginBody$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Reset the bookkeeping for a fresh training run.
              this.wait = 0;
              this.stoppedEpoch = 0;
              this.best = this.baseline != null ? this.baseline : this.monitorFunc === less$2 ? Infinity : -Infinity;
            case 3:
            case "end":
              return ctx.stop();
          }
        }, onTrainBeginBody, this);
      }));
      function onTrainBegin(_x) {
        return onTrainBeginAsync.apply(this, arguments);
      }
      return onTrainBegin;
    }()
  }, {
    key: "onEpochEnd",
    value: function () {
      var onEpochEndAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function onEpochEndBody(epoch, logs) {
        var monitored;
        return _regeneratorRuntime().wrap(function onEpochEndBody$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              ctx.next = 2;
              return resolveScalarsInLogs(logs);
            case 2:
              monitored = this.getMonitorValue(logs);
              if (monitored != null) {
                ctx.next = 5;
                break;
              }
              // Metric missing from the logs (already warned about): skip.
              return ctx.abrupt("return");
            case 5:
              if (!this.monitorFunc(monitored - this.minDelta, this.best)) {
                this.wait++;
                if (this.wait >= this.patience) {
                  this.stoppedEpoch = epoch;
                  this.model.stopTraining = true;
                }
                // TODO(cais): Logic for restoreBestWeights.
              } else {
                this.best = monitored;
                this.wait = 0;
                // TODO(cais): Logic for restoreBestWeights.
              }
            case 6:
            case "end":
              return ctx.stop();
          }
        }, onEpochEndBody, this);
      }));
      function onEpochEnd(_x2, _x3) {
        return onEpochEndAsync.apply(this, arguments);
      }
      return onEpochEnd;
    }()
  }, {
    key: "onTrainEnd",
    value: function () {
      var onTrainEndAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function onTrainEndBody(logs) {
        return _regeneratorRuntime().wrap(function onTrainEndBody$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              if (this.stoppedEpoch > 0 && this.verbose) {
                console.log("Epoch ".concat(this.stoppedEpoch, ": early stopping."));
              }
            case 1:
            case "end":
              return ctx.stop();
          }
        }, onTrainEndBody, this);
      }));
      function onTrainEnd(_x4) {
        return onTrainEndAsync.apply(this, arguments);
      }
      return onTrainEnd;
    }()
  }, {
    key: "getMonitorValue",
    value: function getMonitorValue(logs) {
      var available = logs == null ? {} : logs;
      var value = available[this.monitor];
      if (value == null) {
        console.warn("Metric for EarlyStopping ".concat(this.monitor, " is not available. ") + "Available metrics are: ".concat(Object.keys(available)));
      }
      return value;
    }
  }]);
  return EarlyStopping;
}(Callback);
/**
 * Factory function for a Callback that stops training when a monitored
 * quantity has stopped improving.
 *
 * Early stopping is a type of regularization, and it protects the model
 * against overfitting.
 *
 * The following example, based on fake data, illustrates how this callback
 * can be used during `tf.LayersModel.fit()`:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 3,
 *   activation: 'softmax',
 *   kernelInitializer: 'ones',
 *   inputShape: [2]
 * }));
 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
 * model.compile(
 *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
 *
 * // Without the EarlyStopping callback, the val_acc value would be:
 * //   0.5, 0.5, 0.5, 0.5, ...
 * // With val_acc being monitored, training should stop after the 2nd epoch.
 * const history = await model.fit(xs, ys, {
 *   epochs: 10,
 *   validationData: [xsVal, ysVal],
 *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
 * });
 *
 * // Expect to see a length-2 array.
 * console.log(history.history.val_acc);
 * ```
 *
 * @param args Options forwarded unchanged to the EarlyStopping constructor.
 *
 * @doc {
 *   heading: 'Callbacks',
 *   namespace: 'callbacks'
 * }
 */
function earlyStopping(args) {
  var callback = new EarlyStopping(args);
  return callback;
}
// Public `tf.callbacks` namespace.
var callbacks = {};
callbacks.earlyStopping = earlyStopping;
81290
81291 /**
81292 * @license
81293 * Copyright 2018 Google LLC
81294 *
81295 * Use of this source code is governed by an MIT-style
81296 * license that can be found in the LICENSE file or at
81297 * https://opensource.org/licenses/MIT.
81298 * =============================================================================
81299 */
81300
81301 /**
81302 * @license
81303 * Copyright 2021 Google LLC. All Rights Reserved.
81304 * Licensed under the Apache License, Version 2.0 (the "License");
81305 * you may not use this file except in compliance with the License.
81306 * You may obtain a copy of the License at
81307 *
81308 * http://www.apache.org/licenses/LICENSE-2.0
81309 *
81310 * Unless required by applicable law or agreed to in writing, software
81311 * distributed under the License is distributed on an "AS IS" BASIS,
81312 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81313 * See the License for the specific language governing permissions and
81314 * limitations under the License.
81315 * =============================================================================
81316 */
var ENV$1 = env();
/**
 * Whether to keep intermediate tensors.
 * Defaults to false; turning it on emits a one-time performance warning.
 */
ENV$1.registerFlag('KEEP_INTERMEDIATE_TENSORS', function () {
  return false;
}, function (debugValue) {
  if (!debugValue) {
    return;
  }
  // Note: the path previously read "e2e/benchmarks/ model_config.js" because a
  // space leaked in where the string literal was wrapped.
  console.warn('Keep intermediate tensors is ON. This will print the values of all ' +
    'intermediate tensors during model inference. Not all models ' +
    'support this mode. For details, check e2e/benchmarks/' +
    'model_config.js. This significantly impacts performance.');
});
81326
81327 /**
81328 * @license
81329 * Copyright 2019 Google LLC. All Rights Reserved.
81330 * Licensed under the Apache License, Version 2.0 (the "License");
81331 * you may not use this file except in compliance with the License.
81332 * You may obtain a copy of the License at
81333 *
81334 * http://www.apache.org/licenses/LICENSE-2.0
81335 *
81336 * Unless required by applicable law or agreed to in writing, software
81337 * distributed under the License is distributed on an "AS IS" BASIS,
81338 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81339 * See the License for the specific language governing permissions and
81340 * limitations under the License.
81341 *
81342 * =============================================================================
81343 */
/**
 * DataType enum, built as a bidirectional (TypeScript-style) enum:
 * `DataType.DT_FLOAT === 1` and `DataType[1] === 'DT_FLOAT'`.
 */
var DataType;
(function (DataType) {
  // Member names are plain string data, never identifiers. They are used by
  // parseDtypeParam in tfjs-converter/src/operations/operation_mapper.ts to
  // look up dtypes by string name, and Closure must not mangle them.
  //
  // Position in this list is the numeric enum value. DT_INVALID (0) is not a
  // legal DataType; it only marks a DataType field that has not been set.
  var baseNames = [
    'DT_INVALID', 'DT_FLOAT', 'DT_DOUBLE', 'DT_INT32', 'DT_UINT8', 'DT_INT16',
    'DT_INT8', 'DT_STRING', 'DT_COMPLEX64', 'DT_INT64', 'DT_BOOL', 'DT_QINT8',
    'DT_QUINT8', 'DT_QINT32', 'DT_BFLOAT16', 'DT_QINT16', 'DT_QUINT16',
    'DT_UINT16', 'DT_COMPLEX128', 'DT_HALF', 'DT_RESOURCE', 'DT_VARIANT',
    'DT_UINT32', 'DT_UINT64'
  ];
  var addMember = function (name, value) {
    DataType[DataType[name] = value] = name;
  };
  var idx;
  for (idx = 0; idx < baseNames.length; idx++) {
    addMember(baseNames[idx], idx);
  }
  // "_REF" variants are for parameters only and must not be used otherwise.
  // Every legal type above (i.e. all but DT_INVALID) has one at value + 100.
  for (idx = 1; idx < baseNames.length; idx++) {
    addMember(baseNames[idx] + '_REF', idx + 100);
  }
})(DataType || (DataType = {}));
var SaverDef;
(function (saverDefNs) {
  /**
   * CheckpointFormatVersion enum (bidirectional):
   * LEGACY = 0, V1 = 1, V2 = 2.
   */
  var CheckpointFormatVersion = saverDefNs.CheckpointFormatVersion || (saverDefNs.CheckpointFormatVersion = {});
  ['LEGACY', 'V1', 'V2'].forEach(function (label, code) {
    CheckpointFormatVersion[CheckpointFormatVersion[label] = code] = label;
  });
})(SaverDef || (SaverDef = {}));
81414
81415 /**
81416 * @license
81417 * Copyright 2019 Google LLC. All Rights Reserved.
81418 * Licensed under the Apache License, Version 2.0 (the "License");
81419 * you may not use this file except in compliance with the License.
81420 * You may obtain a copy of the License at
81421 *
81422 * http://www.apache.org/licenses/LICENSE-2.0
81423 *
81424 * Unless required by applicable law or agreed to in writing, software
81425 * distributed under the License is distributed on an "AS IS" BASIS,
81426 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81427 * See the License for the specific language governing permissions and
81428 * limitations under the License.
81429 * =============================================================================
81430 */
/**
 * Registry of custom graph-model ops, keyed by TensorFlow op name.
 * Only own properties are treated as registered entries; anything inherited
 * from Object.prototype (e.g. 'toString', 'constructor') is ignored.
 */
var CUSTOM_OPS = {};
/**
 * Register an Op for graph model executor. This allows you to register
 * TensorFlow custom op or override existing op.
 *
 * Here is an example of registering a new MatMul Op.
 * ```js
 * const customMatmul = (node) =>
 *    tf.matMul(
 *        node.inputs[0], node.inputs[1],
 *        node.attrs['transpose_a'], node.attrs['transpose_b']);
 *
 * tf.registerOp('MatMul', customMatmul);
 * ```
 * The inputs and attrs of the node object are based on the TensorFlow op
 * registry.
 *
 * @param name The Tensorflow Op name.
 * @param opFunc An op function which is called with the current graph node
 * during execution and needs to return a tensor or a list of tensors. The node
 * has the following attributes:
 *    - attr: A map from attribute name to its value
 *    - inputs: A list of input tensors
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function registerOp(name, opFunc) {
  var opMapper = {
    tfOpName: name,
    category: 'custom',
    inputs: [],
    attrs: [],
    customExecutor: opFunc
  };
  // Define the entry rather than assigning it. Plain assignment with a name
  // like '__proto__' would swap the registry's prototype instead of storing
  // an entry.
  Object.defineProperty(CUSTOM_OPS, name, {
    value: opMapper,
    writable: true,
    enumerable: true,
    configurable: true
  });
}
/**
 * Retrieve the OpMapper object for the registered op.
 *
 * @param name The Tensorflow Op name.
 * @returns The registered OpMapper, or undefined if `name` was never
 *     registered (inherited Object.prototype members are never returned).
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function getRegisteredOp(name) {
  return Object.prototype.hasOwnProperty.call(CUSTOM_OPS, name) ? CUSTOM_OPS[name] : undefined;
}
/**
 * Deregister the Op for graph model executor.
 *
 * @param name The Tensorflow Op name.
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function deregisterOp(name) {
  delete CUSTOM_OPS[name];
}
81487
/**
 * Resolves the value of parameter `paramName` for `node`.
 *
 * If the param is input-backed (has an `inputIndexStart`), the value is read
 * from the node's input tensors:
 *   - 'tensor'  -> the single tensor at the (possibly negative) start index;
 *   - 'tensors' -> the tensors in [start, end), skipping NoOp inputs. An
 *                  `inputIndexEnd` of 0 means "through the last input";
 *                  undefined means "just the start input".
 *   - otherwise -> the tensor's data, as a scalar for 'number' or as a nested
 *                  array shaped like the tensor.
 * Otherwise the parsed attribute value is returned (undefined if absent).
 */
function getParamValue(paramName, node, tensorMap, context, resourceManager) {
  var input = node.inputParams[paramName];
  if (!input || input.inputIndexStart === undefined) {
    var attr = node.attrParams[paramName];
    return attr && attr.value;
  }
  var begin = input.inputIndexStart;
  var stop;
  if (input.inputIndexEnd === 0) {
    stop = undefined;
  } else if (input.inputIndexEnd === undefined) {
    stop = begin + 1;
  } else {
    stop = input.inputIndexEnd;
  }
  // Negative start indices count back from the end of the input list.
  var resolvedBegin = begin < 0 ? node.inputNames.length + begin : begin;
  switch (input.type) {
    case 'tensor':
      return getTensor(node.inputNames[resolvedBegin], tensorMap, context, resourceManager);
    case 'tensors': {
      // TODO(mattSoulanille): This filters out NoOp nodes during execution, but
      // these should really never be in the execution graph in the first place.
      // They're necessary for ordering the graph, but should not be visible
      // during execution. Perhaps have different sets of children, one for
      // control dependencies and another for real dependencies.
      var candidates = node.inputs.slice(begin, stop);
      return node.inputNames.slice(begin, stop).filter(function (_unused, i) {
        var candidate = candidates[i];
        return (candidate == null ? undefined : candidate.op) !== 'NoOp';
      }).map(function (inputName) {
        return getTensor(inputName, tensorMap, context, resourceManager);
      });
    }
    default: {
      var source = getTensor(node.inputNames[resolvedBegin], tensorMap, context, resourceManager);
      var values = source.dataSync();
      return input.type === 'number' ? values[0] : toNestedArray(source.shape, values);
    }
  }
}
/**
 * Retrieve the tensor from tensorsMap based on input name.
 * @param name Node input name, e.g. "MatMul:1" (output index defaults to 0).
 * @param tensorsMap Tensors map keyed by the (context-qualified) node name.
 * @param context contains tensors and information for running the current node.
 * @param resourceManager Optional. Contains global resources of the model;
 *     a hash-table handle registered under the node name takes precedence.
 * @returns The tensor found in the first current context id that has an
 *     entry for the node, or undefined if none does.
 */
function getTensor(name, tensorsMap, context, resourceManager) {
  // parseNodeName always returns a plain array, so index it directly.
  var parsed = parseNodeName(name, context);
  var nodeName = parsed[0];
  var outputIndex = parsed[1];
  if (resourceManager != null) {
    var handle = resourceManager.getHashTableHandleByName(nodeName);
    if (handle != null) {
      return handle;
    }
  }
  var matchedId = context.currentContextIds.find(function (candidateId) {
    return Boolean(tensorsMap[getNodeNameWithContextId(nodeName, candidateId)]);
  });
  if (matchedId === undefined) {
    return undefined;
  }
  return tensorsMap[getNodeNameWithContextId(nodeName, matchedId)][outputIndex];
}
/**
 * Retrieve the tensors based on input name for current context.
 * @param name Node input name
 * @param tensorsMap Tensors map keyed by the context-qualified node name.
 * @param context Execution context; its `currentContextId` qualifies `name`.
 */
function getTensorsForCurrentContext(name, tensorsMap, context) {
  var qualifiedName = getNodeNameWithContextId(name, context.currentContextId);
  return tensorsMap[qualifiedName];
}
/**
 * Splits a node input name into its parts.
 * @param inputName The input name of the node, in format of
 * node_name:output_index, i.e. MatMul:0. If the output_index is not set, it
 * defaults to 0. If the input name contains an output name, i.e.
 * StringSplit:indices:0, the result is ['StringSplit', 0, 'indices'].
 * @param context Optional execution context. When it has a truthy
 * `currentContextId`, the returned node name is suffixed with "-<id>".
 * @returns [nodeName (context-qualified), outputIndex, outputName].
 */
function getNodeNameAndIndex(inputName, context) {
  var parsed = parseNodeName(inputName, context);
  var contextId = context && context.currentContextId;
  return [getNodeNameWithContextId(parsed[0], contextId), parsed[1], parsed[2]];
}
/**
 * Qualifies a node name with a context id as "<name>-<contextId>".
 * A falsy context id (undefined, null, '', 0) leaves the name unchanged.
 */
function getNodeNameWithContextId(name, contextId) {
  if (!contextId) {
    return name;
  }
  return "".concat(name, "-").concat(contextId);
}
/**
 * Parses "node", "node:index" or "node:outputName:index" into
 * [nodeName, index, outputName]. The index defaults to 0; outputName is only
 * set for the three-part form. Results are memoized in
 * `context.parseNodeNameCache` (a Map-like object) when one is provided.
 * Cached arrays are returned by reference, so callers must not mutate them.
 */
function parseNodeName(name, context) {
  if (name === '') {
    return ['', 0, undefined];
  }
  var cache = context != null && context.parseNodeNameCache != null ? context.parseNodeNameCache : null;
  if (cache !== null) {
    var hit = cache.get(name);
    if (hit != null) {
      return hit;
    }
  }
  var segments = name.split(':');
  var parsed = segments.length === 1 ? [name, 0, undefined] : [segments[0], Number(segments[segments.length - 1]), segments.length === 3 ? segments[1] : undefined];
  if (cache !== null) {
    cache.set(name, parsed);
  }
  return parsed;
}
/**
 * Splits `arr` into consecutive chunks of `size` elements; the last chunk may
 * be shorter.
 * @param arr Array (or array-like with `slice`) to split.
 * @param size Chunk length; must be a positive number.
 * @returns Array of chunks (empty when `arr` is empty).
 * @throws Error when `size` is not positive. Previously a zero or negative
 *     size never advanced the loop index, so the call never returned.
 */
function split$2(arr, size) {
  if (!(size > 0)) {
    throw new Error("split size must be a positive number, got ".concat(size, "."));
  }
  var res = [];
  for (var i = 0; i < arr.length; i += size) {
    res.push(arr.slice(i, i + size));
  }
  return res;
}
/**
 * Returns the node's 'pad' param. When it is the string 'explicit', the flat
 * 'explicitPaddings' list [b0, a0, b1, a1, b2, a2, b3, a3] is reshaped into
 * four [before, after] pairs (one per dimension) and returned instead.
 */
function getPadding(node, tensorMap, context) {
  var padding = getParamValue('pad', node, tensorMap, context);
  if (padding !== 'explicit') {
    return padding;
  }
  var flat = getParamValue('explicitPaddings', node, tensorMap, context);
  var pairs = [];
  for (var dim = 0; dim < 4; dim++) {
    pairs.push([flat[dim * 2], flat[dim * 2 + 1]]);
  }
  return pairs;
}
/**
 * Returns `tensor` itself when it is marked as kept, otherwise a clone, so
 * that later disposal of the original does not affect the result.
 *
 * TensorArray and TensorList ops use a kept tensor as their id and key their
 * internal maps by Tensor.id. Cloning such a tensor would produce a new id
 * and break those lookups, so kept tensors are passed through unchanged.
 * @param tensor
 */
function cloneTensor(tensor) {
  if (tensor.kept) {
    return tensor;
  }
  return clone(tensor);
}
81628
81629 /**
81630 * @license
81631 * Copyright 2023 Google LLC. All Rights Reserved.
81632 * Licensed under the Apache License, Version 2.0 (the "License");
81633 * you may not use this file except in compliance with the License.
81634 * You may obtain a copy of the License at
81635 *
81636 * http://www.apache.org/licenses/LICENSE-2.0
81637 *
81638 * Unless required by applicable law or agreed to in writing, software
81639 * distributed under the License is distributed on an "AS IS" BASIS,
81640 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81641 * See the License for the specific language governing permissions and
81642 * limitations under the License.
81643 * =============================================================================
81644 */
/**
 * Op-mapper definitions for the 'arithmetic' category.
 *
 * Almost every entry is a binary op with tensor inputs `a` (index 0) and
 * `b` (index 1) plus an unsupported `T` dtype attribute. AddN takes a
 * variadic tensor list, and BiasAdd additionally carries an unsupported
 * `data_format` attribute. Each entry gets its own fresh objects (nothing is
 * shared between entries).
 */
var json$i = function () {
  var binaryInputs = function () {
    return [{
      'start': 0,
      'name': 'a',
      'type': 'tensor'
    }, {
      'start': 1,
      'name': 'b',
      'type': 'tensor'
    }];
  };
  var dtypeAttr = function () {
    return {
      'tfName': 'T',
      'name': 'dtype',
      'type': 'dtype',
      'notSupported': true
    };
  };
  var binaryOp = function (tfOpName, extraAttrs) {
    return {
      'tfOpName': tfOpName,
      'category': 'arithmetic',
      'inputs': binaryInputs(),
      'attrs': [dtypeAttr()].concat(extraAttrs || [])
    };
  };
  var leading = [binaryOp('Add'), binaryOp('AddV2'), {
    'tfOpName': 'AddN',
    'category': 'arithmetic',
    'inputs': [{
      'start': 0,
      'end': 0,
      'name': 'tensors',
      'type': 'tensors'
    }]
  }, binaryOp('BiasAdd', [{
    'tfName': 'data_format',
    'name': 'dataFormat',
    'type': 'string',
    'notSupported': true
  }])];
  var plainBinaryNames = ['Sub', 'RealDiv', 'Div', 'DivNoNan', 'FloorDiv', 'Mul', 'Maximum', 'Minimum', 'Pow', 'SquaredDifference', 'Mod', 'FloorMod'];
  return leading.concat(plainBinaryNames.map(function (opName) {
    return binaryOp(opName);
  }));
}();
81930
// Namespace object for the arithmetic op-mapper module (null prototype, so
// only `json` is visible on it).
var arithmetic = Object.create(null);
arithmetic.json = json$i;
81935
81936 /**
81937 * @license
81938 * Copyright 2023 Google LLC. All Rights Reserved.
81939 * Licensed under the Apache License, Version 2.0 (the "License");
81940 * you may not use this file except in compliance with the License.
81941 * You may obtain a copy of the License at
81942 *
81943 * http://www.apache.org/licenses/LICENSE-2.0
81944 *
81945 * Unless required by applicable law or agreed to in writing, software
81946 * distributed under the License is distributed on an "AS IS" BASIS,
81947 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81948 * See the License for the specific language governing permissions and
81949 * limitations under the License.
81950 * =============================================================================
81951 */
81952 var json$h = [{
81953 'tfOpName': 'Abs',
81954 'category': 'basic_math',
81955 'inputs': [{
81956 'start': 0,
81957 'name': 'x',
81958 'type': 'tensor'
81959 }],
81960 'attrs': [{
81961 'tfName': 'T',
81962 'name': 'dtype',
81963 'type': 'dtype',
81964 'notSupported': true
81965 }]
81966 }, {
81967 'tfOpName': 'Acos',
81968 'category': 'basic_math',
81969 'inputs': [{
81970 'start': 0,
81971 'name': 'x',
81972 'type': 'tensor'
81973 }],
81974 'attrs': [{
81975 'tfName': 'T',
81976 'name': 'dtype',
81977 'type': 'dtype',
81978 'notSupported': true
81979 }]
81980 }, {
81981 'tfOpName': 'Asin',
81982 'category': 'basic_math',
81983 'inputs': [{
81984 'start': 0,
81985 'name': 'x',
81986 'type': 'tensor'
81987 }],
81988 'attrs': [{
81989 'tfName': 'T',
81990 'name': 'dtype',
81991 'type': 'dtype',
81992 'notSupported': true
81993 }]
81994 }, {
81995 'tfOpName': 'Atan',
81996 'category': 'basic_math',
81997 'inputs': [{
81998 'start': 0,
81999 'name': 'x',
82000 'type': 'tensor'
82001 }],
82002 'attrs': [{
82003 'tfName': 'T',
82004 'name': 'dtype',
82005 'type': 'dtype',
82006 'notSupported': true
82007 }]
82008 }, {
82009 'tfOpName': 'Atan2',
82010 'category': 'basic_math',
82011 'inputs': [{
82012 'start': 0,
82013 'name': 'x',
82014 'type': 'tensor'
82015 }, {
82016 'start': 1,
82017 'name': 'y',
82018 'type': 'tensor'
82019 }],
82020 'attrs': [{
82021 'tfName': 'T',
82022 'name': 'dtype',
82023 'type': 'dtype',
82024 'notSupported': true
82025 }]
82026 }, {
82027 'tfOpName': 'Ceil',
82028 'category': 'basic_math',
82029 'inputs': [{
82030 'start': 0,
82031 'name': 'x',
82032 'type': 'tensor'
82033 }],
82034 'attrs': [{
82035 'tfName': 'T',
82036 'name': 'dtype',
82037 'type': 'dtype',
82038 'notSupported': true
82039 }]
82040 }, {
82041 'tfOpName': 'ClipByValue',
82042 'category': 'basic_math',
82043 'inputs': [{
82044 'start': 0,
82045 'name': 'x',
82046 'type': 'tensor'
82047 }, {
82048 'start': 1,
82049 'name': 'clipValueMin',
82050 'type': 'number'
82051 }, {
82052 'start': 2,
82053 'name': 'clipValueMax',
82054 'type': 'number'
82055 }],
82056 'attrs': [{
82057 'tfName': 'T',
82058 'name': 'dtype',
82059 'type': 'dtype',
82060 'notSupported': true
82061 }]
82062 }, {
82063 'tfOpName': 'Complex',
82064 'category': 'basic_math',
82065 'inputs': [{
82066 'start': 0,
82067 'name': 'real',
82068 'type': 'tensor'
82069 }, {
82070 'start': 1,
82071 'name': 'imag',
82072 'type': 'tensor'
82073 }],
82074 'attrs': [{
82075 'tfName': 'T',
82076 'name': 'dtype',
82077 'type': 'dtype',
82078 'notSupported': true
82079 }]
82080 }, {
82081 'tfOpName': 'ComplexAbs',
82082 'category': 'basic_math',
82083 'inputs': [{
82084 'start': 0,
82085 'name': 'x',
82086 'type': 'tensor'
82087 }],
82088 'attrs': [{
82089 'tfName': 'T',
82090 'name': 'dtype',
82091 'type': 'dtype',
82092 'notSupported': true
82093 }]
82094 }, {
82095 'tfOpName': 'Cos',
82096 'category': 'basic_math',
82097 'inputs': [{
82098 'start': 0,
82099 'name': 'x',
82100 'type': 'tensor'
82101 }],
82102 'attrs': [{
82103 'tfName': 'T',
82104 'name': 'dtype',
82105 'type': 'dtype',
82106 'notSupported': true
82107 }]
82108 }, {
82109 'tfOpName': 'Cosh',
82110 'category': 'basic_math',
82111 'inputs': [{
82112 'start': 0,
82113 'name': 'x',
82114 'type': 'tensor'
82115 }],
82116 'attrs': [{
82117 'tfName': 'T',
82118 'name': 'dtype',
82119 'type': 'dtype',
82120 'notSupported': true
82121 }]
82122 }, {
82123 'tfOpName': 'Elu',
82124 'category': 'basic_math',
82125 'inputs': [{
82126 'start': 0,
82127 'name': 'x',
82128 'type': 'tensor'
82129 }],
82130 'attrs': [{
82131 'tfName': 'T',
82132 'name': 'dtype',
82133 'type': 'dtype',
82134 'notSupported': true
82135 }]
82136 }, {
82137 'tfOpName': 'Exp',
82138 'category': 'basic_math',
82139 'inputs': [{
82140 'start': 0,
82141 'name': 'x',
82142 'type': 'tensor'
82143 }],
82144 'attrs': [{
82145 'tfName': 'T',
82146 'name': 'dtype',
82147 'type': 'dtype',
82148 'notSupported': true
82149 }]
82150 }, {
82151 'tfOpName': 'Floor',
82152 'category': 'basic_math',
82153 'inputs': [{
82154 'start': 0,
82155 'name': 'x',
82156 'type': 'tensor'
82157 }],
82158 'attrs': [{
82159 'tfName': 'T',
82160 'name': 'dtype',
82161 'type': 'dtype',
82162 'notSupported': true
82163 }]
82164 }, {
82165 'tfOpName': 'Log',
82166 'category': 'basic_math',
82167 'inputs': [{
82168 'start': 0,
82169 'name': 'x',
82170 'type': 'tensor'
82171 }],
82172 'attrs': [{
82173 'tfName': 'T',
82174 'name': 'dtype',
82175 'type': 'dtype',
82176 'notSupported': true
82177 }]
82178 }, {
82179 'tfOpName': 'Imag',
82180 'category': 'basic_math',
82181 'inputs': [{
82182 'start': 0,
82183 'name': 'x',
82184 'type': 'tensor'
82185 }],
82186 'attrs': [{
82187 'tfName': 'T',
82188 'name': 'dtype',
82189 'type': 'dtype',
82190 'notSupported': true
82191 }, {
82192 'tfName': 'Tout',
82193 'name': 'outputType',
82194 'type': 'dtype',
82195 'notSupported': true
82196 }]
82197 }, {
82198 'tfOpName': 'Neg',
82199 'category': 'basic_math',
82200 'inputs': [{
82201 'start': 0,
82202 'name': 'x',
82203 'type': 'tensor'
82204 }],
82205 'attrs': [{
82206 'tfName': 'T',
82207 'name': 'dtype',
82208 'type': 'dtype',
82209 'notSupported': true
82210 }]
82211 }, {
82212 'tfOpName': 'Real',
82213 'category': 'basic_math',
82214 'inputs': [{
82215 'start': 0,
82216 'name': 'x',
82217 'type': 'tensor'
82218 }],
82219 'attrs': [{
82220 'tfName': 'T',
82221 'name': 'dtype',
82222 'type': 'dtype',
82223 'notSupported': true
82224 }, {
82225 'tfName': 'Tout',
82226 'name': 'outputType',
82227 'type': 'dtype',
82228 'notSupported': true
82229 }]
82230 }, {
82231 'tfOpName': 'Prelu',
82232 'category': 'basic_math',
82233 'inputs': [{
82234 'start': 0,
82235 'name': 'x',
82236 'type': 'tensor'
82237 }, {
82238 'start': 1,
82239 'name': 'alpha',
82240 'type': 'tensor'
82241 }],
82242 'attrs': [{
82243 'tfName': 'T',
82244 'name': 'dtype',
82245 'type': 'dtype',
82246 'notSupported': true
82247 }]
82248 }, {
82249 'tfOpName': 'Relu',
82250 'category': 'basic_math',
82251 'inputs': [{
82252 'start': 0,
82253 'name': 'x',
82254 'type': 'tensor'
82255 }],
82256 'attrs': [{
82257 'tfName': 'T',
82258 'name': 'dtype',
82259 'type': 'dtype',
82260 'notSupported': true
82261 }]
82262 }, {
82263 'tfOpName': 'Relu6',
82264 'category': 'basic_math',
82265 'inputs': [{
82266 'start': 0,
82267 'name': 'x',
82268 'type': 'tensor'
82269 }],
82270 'attrs': [{
82271 'tfName': 'T',
82272 'name': 'dtype',
82273 'type': 'dtype',
82274 'notSupported': true
82275 }]
82276 }, {
82277 'tfOpName': 'Selu',
82278 'category': 'basic_math',
82279 'inputs': [{
82280 'start': 0,
82281 'name': 'x',
82282 'type': 'tensor'
82283 }],
82284 'attrs': [{
82285 'tfName': 'T',
82286 'name': 'dtype',
82287 'type': 'dtype',
82288 'notSupported': true
82289 }]
82290 }, {
82291 'tfOpName': 'Sigmoid',
82292 'category': 'basic_math',
82293 'inputs': [{
82294 'start': 0,
82295 'name': 'x',
82296 'type': 'tensor'
82297 }],
82298 'attrs': [{
82299 'tfName': 'T',
82300 'name': 'dtype',
82301 'type': 'dtype',
82302 'notSupported': true
82303 }]
82304 }, {
82305 'tfOpName': 'Sin',
82306 'category': 'basic_math',
82307 'inputs': [{
82308 'start': 0,
82309 'name': 'x',
82310 'type': 'tensor'
82311 }],
82312 'attrs': [{
82313 'tfName': 'T',
82314 'name': 'dtype',
82315 'type': 'dtype',
82316 'notSupported': true
82317 }]
82318 }, {
82319 'tfOpName': 'Sinh',
82320 'category': 'basic_math',
82321 'inputs': [{
82322 'start': 0,
82323 'name': 'x',
82324 'type': 'tensor'
82325 }],
82326 'attrs': [{
82327 'tfName': 'T',
82328 'name': 'dtype',
82329 'type': 'dtype',
82330 'notSupported': true
82331 }]
82332 }, {
82333 'tfOpName': 'Sqrt',
82334 'category': 'basic_math',
82335 'inputs': [{
82336 'start': 0,
82337 'name': 'x',
82338 'type': 'tensor'
82339 }],
82340 'attrs': [{
82341 'tfName': 'T',
82342 'name': 'dtype',
82343 'type': 'dtype',
82344 'notSupported': true
82345 }]
82346 }, {
82347 'tfOpName': 'Rsqrt',
82348 'category': 'basic_math',
82349 'inputs': [{
82350 'start': 0,
82351 'name': 'x',
82352 'type': 'tensor'
82353 }],
82354 'attrs': [{
82355 'tfName': 'T',
82356 'name': 'dtype',
82357 'type': 'dtype',
82358 'notSupported': true
82359 }]
82360 }, {
82361 'tfOpName': 'Square',
82362 'category': 'basic_math',
82363 'inputs': [{
82364 'start': 0,
82365 'name': 'x',
82366 'type': 'tensor'
82367 }],
82368 'attrs': [{
82369 'tfName': 'T',
82370 'name': 'dtype',
82371 'type': 'dtype',
82372 'notSupported': true
82373 }]
82374 }, {
82375 'tfOpName': 'Tan',
82376 'category': 'basic_math',
82377 'inputs': [{
82378 'start': 0,
82379 'name': 'x',
82380 'type': 'tensor'
82381 }],
82382 'attrs': [{
82383 'tfName': 'T',
82384 'name': 'dtype',
82385 'type': 'dtype',
82386 'notSupported': true
82387 }]
82388 }, {
82389 'tfOpName': 'Tanh',
82390 'category': 'basic_math',
82391 'inputs': [{
82392 'start': 0,
82393 'name': 'x',
82394 'type': 'tensor'
82395 }],
82396 'attrs': [{
82397 'tfName': 'T',
82398 'name': 'dtype',
82399 'type': 'dtype',
82400 'notSupported': true
82401 }]
82402 }, {
82403 'tfOpName': 'Sign',
82404 'category': 'basic_math',
82405 'inputs': [{
82406 'start': 0,
82407 'name': 'x',
82408 'type': 'tensor'
82409 }],
82410 'attrs': [{
82411 'tfName': 'T',
82412 'name': 'dtype',
82413 'type': 'dtype',
82414 'notSupported': true
82415 }]
82416 }, {
82417 'tfOpName': 'Round',
82418 'category': 'basic_math',
82419 'inputs': [{
82420 'start': 0,
82421 'name': 'x',
82422 'type': 'tensor'
82423 }],
82424 'attrs': [{
82425 'tfName': 'T',
82426 'name': 'dtype',
82427 'type': 'dtype',
82428 'notSupported': true
82429 }]
82430 }, {
82431 'tfOpName': 'Expm1',
82432 'category': 'basic_math',
82433 'inputs': [{
82434 'start': 0,
82435 'name': 'x',
82436 'type': 'tensor'
82437 }],
82438 'attrs': [{
82439 'tfName': 'T',
82440 'name': 'dtype',
82441 'type': 'dtype',
82442 'notSupported': true
82443 }]
82444 }, {
82445 'tfOpName': 'Log1p',
82446 'category': 'basic_math',
82447 'inputs': [{
82448 'start': 0,
82449 'name': 'x',
82450 'type': 'tensor'
82451 }],
82452 'attrs': [{
82453 'tfName': 'T',
82454 'name': 'dtype',
82455 'type': 'dtype',
82456 'notSupported': true
82457 }]
82458 }, {
82459 'tfOpName': 'Reciprocal',
82460 'category': 'basic_math',
82461 'inputs': [{
82462 'start': 0,
82463 'name': 'x',
82464 'type': 'tensor'
82465 }],
82466 'attrs': [{
82467 'tfName': 'T',
82468 'name': 'dtype',
82469 'type': 'dtype',
82470 'notSupported': true
82471 }]
82472 }, {
82473 'tfOpName': 'Softplus',
82474 'category': 'basic_math',
82475 'inputs': [{
82476 'start': 0,
82477 'name': 'x',
82478 'type': 'tensor'
82479 }],
82480 'attrs': [{
82481 'tfName': 'T',
82482 'name': 'dtype',
82483 'type': 'dtype',
82484 'notSupported': true
82485 }]
82486 }, {
82487 'tfOpName': 'Asinh',
82488 'category': 'basic_math',
82489 'inputs': [{
82490 'start': 0,
82491 'name': 'x',
82492 'type': 'tensor'
82493 }],
82494 'attrs': [{
82495 'tfName': 'T',
82496 'name': 'dtype',
82497 'type': 'dtype',
82498 'notSupported': true
82499 }]
82500 }, {
82501 'tfOpName': 'Acosh',
82502 'category': 'basic_math',
82503 'inputs': [{
82504 'start': 0,
82505 'name': 'x',
82506 'type': 'tensor'
82507 }],
82508 'attrs': [{
82509 'tfName': 'T',
82510 'name': 'dtype',
82511 'type': 'dtype',
82512 'notSupported': true
82513 }]
82514 }, {
82515 'tfOpName': 'Atanh',
82516 'category': 'basic_math',
82517 'inputs': [{
82518 'start': 0,
82519 'name': 'x',
82520 'type': 'tensor'
82521 }],
82522 'attrs': [{
82523 'tfName': 'T',
82524 'name': 'dtype',
82525 'type': 'dtype',
82526 'notSupported': true
82527 }]
82528 }, {
82529 'tfOpName': 'Erf',
82530 'category': 'basic_math',
82531 'inputs': [{
82532 'start': 0,
82533 'name': 'x',
82534 'type': 'tensor'
82535 }],
82536 'attrs': [{
82537 'tfName': 'T',
82538 'name': 'dtype',
82539 'type': 'dtype',
82540 'notSupported': true
82541 }]
82542 }, {
82543 'tfOpName': 'LeakyRelu',
82544 'category': 'basic_math',
82545 'inputs': [{
82546 'start': 0,
82547 'name': 'x',
82548 'type': 'tensor'
82549 }],
82550 'attrs': [{
82551 'tfName': 'alpha',
82552 'name': 'alpha',
82553 'type': 'number',
82554 'defaultValue': 0.2
82555 }, {
82556 'tfName': 'T',
82557 'name': 'dtype',
82558 'type': 'dtype',
82559 'notSupported': true
82560 }]
82561 }, {
82562 'tfOpName': 'IsNan',
82563 'category': 'basic_math',
82564 'inputs': [{
82565 'start': 0,
82566 'name': 'x',
82567 'type': 'tensor'
82568 }],
82569 'attrs': [{
82570 'tfName': 'T',
82571 'name': 'dtype',
82572 'type': 'dtype',
82573 'notSupported': true
82574 }]
82575 }, {
82576 'tfOpName': 'IsFinite',
82577 'category': 'basic_math',
82578 'inputs': [{
82579 'start': 0,
82580 'name': 'x',
82581 'type': 'tensor'
82582 }],
82583 'attrs': [{
82584 'tfName': 'T',
82585 'name': 'dtype',
82586 'type': 'dtype',
82587 'notSupported': true
82588 }]
82589 }, {
82590 'tfOpName': 'IsInf',
82591 'category': 'basic_math',
82592 'inputs': [{
82593 'start': 0,
82594 'name': 'x',
82595 'type': 'tensor'
82596 }],
82597 'attrs': [{
82598 'tfName': 'T',
82599 'name': 'dtype',
82600 'type': 'dtype',
82601 'notSupported': true
82602 }]
82603 }];
82604
/**
 * Module-namespace-style wrapper for the `basic_math` op-mapper table.
 * Built on a null prototype so the only visible key is `json`, which
 * points at the `json$h` definitions array.
 */
var basicMath = Object.assign(Object.create(null), { json: json$h });
82609
82610 /**
82611 * @license
82612 * Copyright 2023 Google LLC. All Rights Reserved.
82613 * Licensed under the Apache License, Version 2.0 (the "License");
82614 * you may not use this file except in compliance with the License.
82615 * You may obtain a copy of the License at
82616 *
82617 * http://www.apache.org/licenses/LICENSE-2.0
82618 *
82619 * Unless required by applicable law or agreed to in writing, software
82620 * distributed under the License is distributed on an "AS IS" BASIS,
82621 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82622 * See the License for the specific language governing permissions and
82623 * limitations under the License.
82624 * =============================================================================
82625 */
/**
 * Op-mapper definitions for the 'control' category (control flow,
 * TensorArray and TensorList ops).
 *
 * Each entry names a TensorFlow op (`tfOpName`) and describes:
 *  - `inputs`: positional graph inputs (`start` is the input index) and the
 *    JS-side name/type each is exposed as. NOTE(review): entries with
 *    `'end': 0` appear to denote a variadic input running to the last input;
 *    confirm against the operation mapper.
 *  - `attrs`: node attributes keyed by their TF name (`tfName`), the JS-side
 *    `name`, the value `type`, an optional `defaultValue`, and
 *    `notSupported` for attrs that are declared but not consumed.
 */
var json$g = [{
  'tfOpName': 'EmptyTensorList',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'elementShape', 'type': 'shape'},
    {'start': 1, 'name': 'maxNumElements', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'LoopCond',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'pred', 'type': 'tensor'}
  ]
}, {
  'tfOpName': 'Switch',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'data', 'type': 'tensor'},
    {'start': 1, 'name': 'pred', 'type': 'tensor'}
  ]
}, {
  'tfOpName': 'Merge',
  'category': 'control',
  'inputs': [
    {'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors'}
  ]
}, {
  'tfOpName': 'Enter',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensor', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true},
    {'tfName': 'frame_name', 'name': 'frameName', 'type': 'string'},
    {'tfName': 'is_constant', 'name': 'isConstant', 'type': 'bool'}
  ]
}, {
  'tfOpName': 'Exit',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensor', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'NextIteration',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensor', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'TensorArrayV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'size', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype'},
    {'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape'},
    {'tfName': 'dynamic_size', 'name': 'dynamicSize', 'type': 'bool'},
    {'tfName': 'clear_after_read', 'name': 'clearAfterRead', 'type': 'bool'},
    {'tfName': 'identical_element_shapes', 'name': 'identicalElementShapes', 'type': 'bool'},
    {'tfName': 'tensor_array_name', 'name': 'name', 'type': 'string'}
  ]
}, {
  'tfOpName': 'TensorArrayWriteV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'},
    {'start': 1, 'name': 'index', 'type': 'number'},
    {'start': 2, 'name': 'tensor', 'type': 'tensor'},
    {'start': 3, 'name': 'flowIn', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'TensorArrayReadV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'},
    {'start': 1, 'name': 'index', 'type': 'number'},
    {'start': 2, 'name': 'flowIn', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'TensorArrayGatherV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'},
    {'start': 1, 'name': 'indices', 'type': 'number[]'},
    {'start': 2, 'name': 'flowIn', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype'},
    {'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape'}
  ]
}, {
  'tfOpName': 'TensorArrayScatterV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'},
    {'start': 1, 'name': 'indices', 'type': 'number[]'},
    {'start': 2, 'name': 'tensor', 'type': 'tensor'},
    {'start': 3, 'name': 'flowIn', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorArrayConcatV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'},
    {'start': 1, 'name': 'flowIn', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype'},
    {'tfName': 'element_shape_except0', 'name': 'elementShapeExcept0', 'type': 'shape', 'notSupported': true}
  ]
}, {
  'tfOpName': 'TensorArraySplitV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'},
    {'start': 1, 'name': 'tensor', 'type': 'tensor'},
    {'start': 2, 'name': 'lengths', 'type': 'number[]'},
    {'start': 3, 'name': 'flowIn', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorArraySizeV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'},
    {'start': 1, 'name': 'flowIn', 'type': 'number'}
  ]
}, {
  'tfOpName': 'TensorArrayCloseV3',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}
  ]
}, {
  'tfOpName': 'StatelessIf',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'cond', 'type': 'tensor'},
    {'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors'}
  ],
  'attrs': [
    {'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func'},
    {'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func'}
  ]
}, {
  'tfOpName': 'If',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'cond', 'type': 'tensor'},
    {'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors'}
  ],
  'attrs': [
    {'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func'},
    {'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func'}
  ]
}, {
  'tfOpName': 'StatelessWhile',
  'category': 'control',
  'inputs': [
    {'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors'}
  ],
  'attrs': [
    {'tfName': 'cond', 'name': 'cond', 'type': 'func'},
    {'tfName': 'body', 'name': 'body', 'type': 'func'}
  ]
}, {
  'tfOpName': 'While',
  'category': 'control',
  'inputs': [
    {'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors'}
  ],
  'attrs': [
    {'tfName': 'cond', 'name': 'cond', 'type': 'func'},
    {'tfName': 'body', 'name': 'body', 'type': 'func'}
  ]
}, {
  'tfOpName': 'TensorListScatter',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensor', 'type': 'tensor'},
    {'start': 1, 'name': 'indices', 'type': 'number[]'},
    {'start': 2, 'name': 'elementShape', 'type': 'shape'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListScatterV2',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensor', 'type': 'tensor'},
    {'start': 1, 'name': 'indices', 'type': 'number[]'},
    {'start': 2, 'name': 'elementShape', 'type': 'shape'},
    {'start': 3, 'name': 'numElements', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListGather',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'},
    {'start': 1, 'name': 'indices', 'type': 'number[]'},
    {'start': 2, 'name': 'elementShape', 'type': 'shape'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListGetItem',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'},
    {'start': 1, 'name': 'index', 'type': 'number'},
    {'start': 2, 'name': 'elementShape', 'type': 'shape'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListSetItem',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'},
    {'start': 1, 'name': 'index', 'type': 'number'},
    {'start': 2, 'name': 'tensor', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListReserve',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'elementShape', 'type': 'shape'},
    {'start': 1, 'name': 'numElements', 'type': 'number'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListFromTensor',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensor', 'type': 'tensor'},
    {'start': 1, 'name': 'elementShape', 'type': 'shape'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListStack',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'},
    {'start': 1, 'name': 'elementShape', 'type': 'shape'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'},
    // TensorFlow declares `num_elements: int = -1`. It was previously typed
    // 'dtype', which made the integer value unreadable. The explicit default
    // mirrors TF's so graphs that omit the attr keep the TF default.
    {'tfName': 'num_elements', 'name': 'numElements', 'type': 'number', 'defaultValue': -1}
  ]
}, {
  'tfOpName': 'TensorListSplit',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensor', 'type': 'tensor'},
    {'start': 1, 'name': 'elementShape', 'type': 'shape'},
    {'start': 2, 'name': 'lengths', 'type': 'number[]'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListConcat',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape'},
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListConcatV2',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape'},
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListPopBack',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'},
    {'start': 1, 'name': 'elementShape', 'type': 'shape'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListPushBack',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'},
    {'start': 1, 'name': 'tensor', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}
  ]
}, {
  'tfOpName': 'TensorListLength',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}
  ]
}, {
  'tfOpName': 'TensorListResize',
  'category': 'control',
  'inputs': [
    {'start': 0, 'name': 'tensorListId', 'type': 'tensor'},
    {'start': 1, 'name': 'size', 'type': 'number'}
  ]
}];
83274
/**
 * Module-namespace-style wrapper for the `control` op-mapper table.
 * Created without a prototype so that `json`, the reference to the
 * `json$g` definitions array, is its only key.
 */
var control = Object.assign(Object.create(null), { json: json$g });
83279
83280 /**
83281 * @license
83282 * Copyright 2023 Google LLC. All Rights Reserved.
83283 * Licensed under the Apache License, Version 2.0 (the "License");
83284 * you may not use this file except in compliance with the License.
83285 * You may obtain a copy of the License at
83286 *
83287 * http://www.apache.org/licenses/LICENSE-2.0
83288 *
83289 * Unless required by applicable law or agreed to in writing, software
83290 * distributed under the License is distributed on an "AS IS" BASIS,
83291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83292 * See the License for the specific language governing permissions and
83293 * limitations under the License.
83294 * =============================================================================
83295 */
83296 var json$f = [{
83297 'tfOpName': 'AvgPool',
83298 'category': 'convolution',
83299 'inputs': [{
83300 'start': 0,
83301 'name': 'x',
83302 'type': 'tensor'
83303 }],
83304 'attrs': [{
83305 'tfName': 'strides',
83306 'name': 'strides',
83307 'type': 'number[]'
83308 }, {
83309 'tfName': 'padding',
83310 'name': 'pad',
83311 'type': 'string'
83312 }, {
83313 'tfName': 'data_format',
83314 'name': 'dataFormat',
83315 'type': 'string',
83316 'notSupported': true
83317 }, {
83318 'tfName': 'ksize',
83319 'name': 'kernelSize',
83320 'type': 'number[]'
83321 }, {
83322 'tfName': 'T',
83323 'name': 'dtype',
83324 'type': 'dtype',
83325 'notSupported': true
83326 }]
83327 }, {
83328 'tfOpName': 'MaxPool',
83329 'category': 'convolution',
83330 'inputs': [{
83331 'start': 0,
83332 'name': 'x',
83333 'type': 'tensor'
83334 }],
83335 'attrs': [{
83336 'tfName': 'strides',
83337 'name': 'strides',
83338 'type': 'number[]'
83339 }, {
83340 'tfName': 'padding',
83341 'name': 'pad',
83342 'type': 'string'
83343 }, {
83344 'tfName': 'data_format',
83345 'name': 'dataFormat',
83346 'type': 'string',
83347 'notSupported': true
83348 }, {
83349 'tfName': 'ksize',
83350 'name': 'kernelSize',
83351 'type': 'number[]'
83352 }, {
83353 'tfName': 'explicit_paddings',
83354 'name': 'explicitPaddings',
83355 'type': 'number[]',
83356 'defaultValue': [],
83357 'notSupported': true
83358 }, {
83359 'tfName': 'T',
83360 'name': 'dtype',
83361 'type': 'dtype',
83362 'notSupported': true
83363 }]
83364 }, {
83365 'tfOpName': 'MaxPoolWithArgmax',
83366 'category': 'convolution',
83367 'inputs': [{
83368 'start': 0,
83369 'name': 'x',
83370 'type': 'tensor'
83371 }],
83372 'attrs': [{
83373 'tfName': 'strides',
83374 'name': 'strides',
83375 'type': 'number[]'
83376 }, {
83377 'tfName': 'padding',
83378 'name': 'pad',
83379 'type': 'string'
83380 }, {
83381 'tfName': 'ksize',
83382 'name': 'kernelSize',
83383 'type': 'number[]'
83384 }, {
83385 'tfName': 'include_batch_in_index',
83386 'name': 'includeBatchInIndex',
83387 'type': 'bool'
83388 }, {
83389 'tfName': 'T',
83390 'name': 'dtype',
83391 'type': 'dtype',
83392 'notSupported': true
83393 }]
83394 }, {
83395 'tfOpName': 'AvgPool3D',
83396 'category': 'convolution',
83397 'inputs': [{
83398 'start': 0,
83399 'name': 'x',
83400 'type': 'tensor'
83401 }],
83402 'attrs': [{
83403 'tfName': 'strides',
83404 'name': 'strides',
83405 'type': 'number[]'
83406 }, {
83407 'tfName': 'padding',
83408 'name': 'pad',
83409 'type': 'string'
83410 }, {
83411 'tfName': 'data_format',
83412 'name': 'dataFormat',
83413 'type': 'string',
83414 'notSupported': true
83415 }, {
83416 'tfName': 'ksize',
83417 'name': 'kernelSize',
83418 'type': 'number[]'
83419 }, {
83420 'tfName': 'T',
83421 'name': 'dtype',
83422 'type': 'dtype',
83423 'notSupported': true
83424 }]
83425 }, {
83426 'tfOpName': 'MaxPool3D',
83427 'category': 'convolution',
83428 'inputs': [{
83429 'start': 0,
83430 'name': 'x',
83431 'type': 'tensor'
83432 }],
83433 'attrs': [{
83434 'tfName': 'strides',
83435 'name': 'strides',
83436 'type': 'number[]'
83437 }, {
83438 'tfName': 'padding',
83439 'name': 'pad',
83440 'type': 'string'
83441 }, {
83442 'tfName': 'data_format',
83443 'name': 'dataFormat',
83444 'type': 'string',
83445 'notSupported': true
83446 }, {
83447 'tfName': 'ksize',
83448 'name': 'kernelSize',
83449 'type': 'number[]'
83450 }, {
83451 'tfName': 'T',
83452 'name': 'dtype',
83453 'type': 'dtype',
83454 'notSupported': true
83455 }]
83456 }, {
83457 'tfOpName': 'Conv1D',
83458 'category': 'convolution',
83459 'inputs': [{
83460 'start': 0,
83461 'name': 'x',
83462 'type': 'tensor'
83463 }, {
83464 'start': 1,
83465 'name': 'filter',
83466 'type': 'tensor'
83467 }],
83468 'attrs': [{
83469 'tfName': 'stride',
83470 'name': 'stride',
83471 'type': 'number'
83472 }, {
83473 'tfName': 'padding',
83474 'name': 'pad',
83475 'type': 'string'
83476 }, {
83477 'tfName': 'data_format',
83478 'name': 'dataFormat',
83479 'type': 'string',
83480 'defaultValue': 'NWC'
83481 }, {
83482 'tfName': 'T',
83483 'name': 'dtype',
83484 'type': 'dtype',
83485 'notSupported': true
83486 }, {
83487 'tfName': 'dilation',
83488 'name': 'dilation',
83489 'type': 'number',
83490 'defaultValue': 1
83491 }]
83492 }, {
83493 'tfOpName': 'Conv2D',
83494 'category': 'convolution',
83495 'inputs': [{
83496 'start': 0,
83497 'name': 'x',
83498 'type': 'tensor'
83499 }, {
83500 'start': 1,
83501 'name': 'filter',
83502 'type': 'tensor'
83503 }],
83504 'attrs': [{
83505 'tfName': 'T',
83506 'name': 'dtype',
83507 'type': 'dtype',
83508 'notSupported': true
83509 }, {
83510 'tfName': 'strides',
83511 'name': 'strides',
83512 'type': 'number[]'
83513 }, {
83514 'tfName': 'padding',
83515 'name': 'pad',
83516 'type': 'string'
83517 }, {
83518 'tfName': 'useCudnnOnGpu',
83519 'name': 'useCudnnOnGpu',
83520 'type': 'bool'
83521 }, {
83522 'tfName': 'data_format',
83523 'name': 'dataFormat',
83524 'type': 'string',
83525 'defaultValue': 'NHWC'
83526 }, {
83527 'tfName': 'explicit_paddings',
83528 'name': 'explicitPaddings',
83529 'type': 'number[]',
83530 'defaultValue': []
83531 }, {
83532 'tfName': 'dilations',
83533 'name': 'dilations',
83534 'type': 'number[]'
83535 }]
83536 }, {
83537 'tfOpName': '_FusedConv2D',
83538 'category': 'convolution',
83539 'inputs': [{
83540 'start': 0,
83541 'name': 'x',
83542 'type': 'tensor'
83543 }, {
83544 'start': 1,
83545 'name': 'filter',
83546 'type': 'tensor'
83547 }, {
83548 'start': 2,
83549 'end': 0,
83550 'name': 'args',
83551 'type': 'tensors'
83552 }],
83553 'attrs': [{
83554 'tfName': 'num_args',
83555 'name': 'numArgs',
83556 'type': 'number'
83557 }, {
83558 'tfName': 'T',
83559 'name': 'dtype',
83560 'type': 'dtype',
83561 'notSupported': true
83562 }, {
83563 'tfName': 'strides',
83564 'name': 'strides',
83565 'type': 'number[]'
83566 }, {
83567 'tfName': 'padding',
83568 'name': 'pad',
83569 'type': 'string'
83570 }, {
83571 'tfName': 'explicit_paddings',
83572 'name': 'explicitPaddings',
83573 'type': 'number[]',
83574 'defaultValue': []
83575 }, {
83576 'tfName': 'use_cudnn_on_gpu',
83577 'name': 'useCudnnOnGpu',
83578 'type': 'bool',
83579 'defaultValue': true
83580 }, {
83581 'tfName': 'data_format',
83582 'name': 'dataFormat',
83583 'type': 'string',
83584 'defaultValue': 'NHWC'
83585 }, {
83586 'tfName': 'dilations',
83587 'name': 'dilations',
83588 'type': 'number[]',
83589 'defaultValue': [1, 1, 1, 1]
83590 }, {
83591 'tfName': 'fused_ops',
83592 'name': 'fusedOps',
83593 'type': 'string[]',
83594 'defaultValue': []
83595 }, {
83596 'tfName': 'epsilon',
83597 'name': 'epsilon',
83598 'type': 'number',
83599 'defaultValue': 0.0001
83600 }, {
83601 'tfName': 'leakyrelu_alpha',
83602 'name': 'leakyreluAlpha',
83603 'type': 'number',
83604 'defaultValue': 0.2
83605 }]
83606 }, {
83607 'tfOpName': 'Conv2DBackpropInput',
83608 'category': 'convolution',
83609 'inputs': [{
83610 'start': 2,
83611 'name': 'x',
83612 'type': 'tensor'
83613 }, {
83614 'start': 1,
83615 'name': 'filter',
83616 'type': 'tensor'
83617 }, {
83618 'start': 0,
83619 'name': 'outputShape',
83620 'type': 'number[]'
83621 }],
83622 'attrs': [{
83623 'tfName': 'strides',
83624 'name': 'strides',
83625 'type': 'number[]'
83626 }, {
83627 'tfName': 'padding',
83628 'name': 'pad',
83629 'type': 'string'
83630 }, {
83631 'tfName': 'data_format',
83632 'name': 'dataFormat',
83633 'type': 'string',
83634 'notSupported': true
83635 }, {
83636 'tfName': 'explicit_paddings',
83637 'name': 'explicitPaddings',
83638 'type': 'number[]',
83639 'defaultValue': []
83640 }, {
83641 'tfName': 'dilations',
83642 'name': 'dilations',
83643 'type': 'number[]',
83644 'notSupported': true
83645 }]
83646 }, {
83647 'tfOpName': 'DepthwiseConv2d',
83648 'category': 'convolution',
83649 'inputs': [{
83650 'start': 0,
83651 'name': 'input',
83652 'type': 'tensor'
83653 }, {
83654 'start': 1,
83655 'name': 'filter',
83656 'type': 'tensor'
83657 }],
83658 'attrs': [{
83659 'tfName': 'strides',
83660 'name': 'strides',
83661 'type': 'number[]'
83662 }, {
83663 'tfName': 'padding',
83664 'name': 'pad',
83665 'type': 'string'
83666 }, {
83667 'tfName': 'data_format',
83668 'name': 'dataFormat',
83669 'type': 'string',
83670 'defaultValue': 'NHWC'
83671 }, {
83672 'tfName': 'explicit_paddings',
83673 'name': 'explicitPaddings',
83674 'type': 'number[]',
83675 'defaultValue': []
83676 }, {
83677 'tfName': 'dilations',
83678 'name': 'dilations',
83679 'type': 'number[]'
83680 }]
83681 }, {
83682 'tfOpName': 'DepthwiseConv2dNative',
83683 'category': 'convolution',
83684 'inputs': [{
83685 'start': 0,
83686 'name': 'input',
83687 'type': 'tensor'
83688 }, {
83689 'start': 1,
83690 'name': 'filter',
83691 'type': 'tensor'
83692 }],
83693 'attrs': [{
83694 'tfName': 'strides',
83695 'name': 'strides',
83696 'type': 'number[]'
83697 }, {
83698 'tfName': 'padding',
83699 'name': 'pad',
83700 'type': 'string'
83701 }, {
83702 'tfName': 'data_format',
83703 'name': 'dataFormat',
83704 'type': 'string',
83705 'defaultValue': 'NHWC'
83706 }, {
83707 'tfName': 'explicit_paddings',
83708 'name': 'explicitPaddings',
83709 'type': 'number[]',
83710 'defaultValue': []
83711 }, {
83712 'tfName': 'dilations',
83713 'name': 'dilations',
83714 'type': 'number[]'
83715 }]
83716 }, {
83717 'tfOpName': 'FusedDepthwiseConv2dNative',
83718 'category': 'convolution',
83719 'inputs': [{
83720 'start': 0,
83721 'name': 'x',
83722 'type': 'tensor'
83723 }, {
83724 'start': 1,
83725 'name': 'filter',
83726 'type': 'tensor'
83727 }, {
83728 'start': 2,
83729 'end': 0,
83730 'name': 'args',
83731 'type': 'tensors'
83732 }],
83733 'attrs': [{
83734 'tfName': 'num_args',
83735 'name': 'numArgs',
83736 'type': 'number'
83737 }, {
83738 'tfName': 'T',
83739 'name': 'dtype',
83740 'type': 'dtype',
83741 'notSupported': true
83742 }, {
83743 'tfName': 'strides',
83744 'name': 'strides',
83745 'type': 'number[]'
83746 }, {
83747 'tfName': 'padding',
83748 'name': 'pad',
83749 'type': 'string'
83750 }, {
83751 'tfName': 'data_format',
83752 'name': 'dataFormat',
83753 'type': 'string',
83754 'defaultValue': 'NHWC'
83755 }, {
83756 'tfName': 'dilations',
83757 'name': 'dilations',
83758 'type': 'number[]',
83759 'defaultValue': [1, 1, 1, 1]
83760 }, {
83761 'tfName': 'fused_ops',
83762 'name': 'fusedOps',
83763 'type': 'string[]',
83764 'defaultValue': []
83765 }, {
83766 'tfName': 'explicit_paddings',
83767 'name': 'explicitPaddings',
83768 'type': 'number[]',
83769 'defaultValue': []
83770 }]
83771 }, {
83772 'tfOpName': 'Conv3D',
83773 'category': 'convolution',
83774 'inputs': [{
83775 'start': 0,
83776 'name': 'x',
83777 'type': 'tensor'
83778 }, {
83779 'start': 1,
83780 'name': 'filter',
83781 'type': 'tensor'
83782 }],
83783 'attrs': [{
83784 'tfName': 'strides',
83785 'name': 'strides',
83786 'type': 'number[]'
83787 }, {
83788 'tfName': 'padding',
83789 'name': 'pad',
83790 'type': 'string'
83791 }, {
83792 'tfName': 'data_format',
83793 'name': 'dataFormat',
83794 'type': 'string',
83795 'defaultValue': 'NHWC'
83796 }, {
83797 'tfName': 'dilations',
83798 'name': 'dilations',
83799 'type': 'number[]'
83800 }]
83801 }, {
83802 'tfOpName': 'Dilation2D',
83803 'category': 'convolution',
83804 'inputs': [{
83805 'start': 0,
83806 'name': 'x',
83807 'type': 'tensor'
83808 }, {
83809 'start': 1,
83810 'name': 'filter',
83811 'type': 'tensor'
83812 }],
83813 'attrs': [{
83814 'tfName': 'strides',
83815 'name': 'strides',
83816 'type': 'number[]'
83817 }, {
83818 'tfName': 'rates',
83819 'name': 'dilations',
83820 'type': 'number[]'
83821 }, {
83822 'tfName': 'padding',
83823 'name': 'pad',
83824 'type': 'string'
83825 }]
83826 }];
83827
// Namespace object exposing the convolution op-mapper table (json$f).
// It is created on a null prototype so that no Object.prototype members are
// inherited; this is the same object shape as a literal with `__proto__: null`.
var convolution = Object.assign(Object.create(null), {
  json: json$f
});
83832
83833 /**
83834 * @license
83835 * Copyright 2023 Google LLC. All Rights Reserved.
83836 * Licensed under the Apache License, Version 2.0 (the "License");
83837 * you may not use this file except in compliance with the License.
83838 * You may obtain a copy of the License at
83839 *
83840 * http://www.apache.org/licenses/LICENSE-2.0
83841 *
83842 * Unless required by applicable law or agreed to in writing, software
83843 * distributed under the License is distributed on an "AS IS" BASIS,
83844 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83845 * See the License for the specific language governing permissions and
83846 * limitations under the License.
83847 * =============================================================================
83848 */
/**
 * Op-mapper table for the 'creation' category: Fill, LinSpace, OneHot,
 * Ones/OnesLike, the random generators, Range, TruncatedNormal,
 * Zeros/ZerosLike and Multinomial.
 *
 * Each entry pairs a TensorFlow op name (`tfOpName`) with two kinds of
 * descriptors:
 * - input descriptors: `start` position among the node inputs, local `name`,
 *   value `type`, and an optional `defaultValue`;
 * - attr descriptors: `tfName` key on the TF node, local `name`, `type`, and
 *   optional `defaultValue` / `notSupported`.
 *
 * RandomUniformInt: TensorFlow registers `minval` and `maxval` as graph
 * inputs (shape, minval, maxval), not as node attributes. They are therefore
 * declared as inputs 1 and 2 here. As attrs with no default, they could not
 * be resolved from a real TF node.
 */
var json$e = [{
  'tfOpName': 'Fill',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'shape',
    'type': 'number[]'
  }, {
    'start': 1,
    'name': 'value',
    'type': 'number'
  }],
  'attrs': [{
    'tfName': 'T',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'LinSpace',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'start',
    'type': 'number'
  }, {
    'start': 1,
    'name': 'stop',
    'type': 'number'
  }, {
    'start': 2,
    'name': 'num',
    'type': 'number'
  }],
  'attrs': [{
    'tfName': 'T',
    'name': 'dtype',
    'type': 'dtype',
    'notSupported': true
  }]
}, {
  'tfOpName': 'OneHot',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'indices',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'depth',
    'type': 'number'
  }, {
    'start': 2,
    'name': 'onValue',
    'type': 'number',
    'defaultValue': 1
  }, {
    'start': 3,
    'name': 'offValue',
    'type': 'number',
    'defaultValue': 0
  }],
  'attrs': [{
    'tfName': 'axis',
    'name': 'axis',
    'type': 'number',
    'notSupported': true
  }, {
    'tfName': 'T',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'Ones',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'shape',
    'type': 'number[]'
  }],
  'attrs': [{
    'tfName': 'T',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'OnesLike',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'dtype',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'RandomStandardNormal',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'shape',
    'type': 'number[]'
  }],
  'attrs': [{
    'tfName': 'seed',
    'name': 'seed',
    'type': 'number',
    'defaultValue': 0
  }, {
    'tfName': 'seed2',
    'name': 'seed2',
    'type': 'number',
    'defaultValue': 0,
    'notSupported': true
  }, {
    'tfName': 'dtype',
    'name': 'dtype',
    'type': 'dtype'
  }, {
    'tfName': 'T',
    'name': 'T',
    'type': 'number',
    'notSupported': true
  }]
}, {
  'tfOpName': 'RandomUniform',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'shape',
    'type': 'number[]'
  }],
  'attrs': [{
    'tfName': 'minval',
    'name': 'minval',
    'type': 'number',
    'defaultValue': 0
  }, {
    'tfName': 'maxval',
    'name': 'maxval',
    'type': 'number',
    'defaultValue': 1
  }, {
    'tfName': 'dtype',
    'name': 'dtype',
    'type': 'dtype'
  }, {
    'tfName': 'seed',
    'name': 'seed',
    'type': 'number',
    'defaultValue': 0
  }, {
    'tfName': 'seed2',
    'name': 'seed2',
    'type': 'number',
    'defaultValue': 0,
    'notSupported': true
  }, {
    'tfName': 'T',
    'name': 'T',
    'type': 'number',
    'notSupported': true
  }]
}, {
  'tfOpName': 'RandomUniformInt',
  'category': 'creation',
  // TF signature: RandomUniformInt(shape, minval, maxval). The bounds are
  // inputs, so they are mapped by position rather than looked up as attrs.
  'inputs': [{
    'start': 0,
    'name': 'shape',
    'type': 'number[]'
  }, {
    'start': 1,
    'name': 'minval',
    'type': 'number'
  }, {
    'start': 2,
    'name': 'maxval',
    'type': 'number'
  }],
  'attrs': [{
    'tfName': 'seed',
    'name': 'seed',
    'type': 'number',
    'defaultValue': 0
  }, {
    'tfName': 'seed2',
    'name': 'seed2',
    'type': 'number',
    'defaultValue': 0,
    'notSupported': true
  }]
}, {
  'tfOpName': 'Range',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'start',
    'type': 'number'
  }, {
    'start': 1,
    'name': 'stop',
    'type': 'number'
  }, {
    'start': 2,
    'name': 'step',
    'type': 'number',
    'defaultValue': 0
  }],
  'attrs': [{
    'tfName': 'Tidx',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'TruncatedNormal',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'shape',
    'type': 'number[]'
  }],
  'attrs': [{
    'tfName': 'means',
    'name': 'mean',
    'type': 'number',
    'defaultValue': 0
  }, {
    'tfName': 'stddev',
    'name': 'stdDev',
    'type': 'number',
    'defaultValue': 1
  }, {
    'tfName': 'seed',
    'name': 'seed',
    'type': 'number'
  }, {
    'tfName': 'seed2',
    'name': 'seed2',
    'type': 'number',
    'defaultValue': 0,
    'notSupported': true
  }, {
    'tfName': 'dtype',
    'name': 'dtype',
    'type': 'dtype'
  }, {
    'tfName': 'T',
    'name': 'T',
    'type': 'number',
    'notSupported': true
  }]
}, {
  'tfOpName': 'Zeros',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'shape',
    'type': 'number[]'
  }],
  'attrs': [{
    'tfName': 'T',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'ZerosLike',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'T',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'Multinomial',
  'category': 'creation',
  'inputs': [{
    'start': 0,
    'name': 'logits',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'numSamples',
    'type': 'number'
  }],
  'attrs': [{
    'tfName': 'seed',
    'name': 'seed',
    'type': 'number'
  }, {
    'tfName': 'seed2',
    'name': 'seed2',
    'type': 'number'
  }, {
    'tfName': 'T',
    'name': 'dtype',
    'type': 'dtype'
  }, {
    'tfName': 'output_dtype',
    'name': 'output_dtype',
    'type': 'dtype'
  }]
}];

// Namespace-style wrapper (null prototype) exposing the creation table.
var creation = {
  __proto__: null,
  json: json$e
};
84163
84164 /**
84165 * @license
84166 * Copyright 2023 Google LLC. All Rights Reserved.
84167 * Licensed under the Apache License, Version 2.0 (the "License");
84168 * you may not use this file except in compliance with the License.
84169 * You may obtain a copy of the License at
84170 *
84171 * http://www.apache.org/licenses/LICENSE-2.0
84172 *
84173 * Unless required by applicable law or agreed to in writing, software
84174 * distributed under the License is distributed on an "AS IS" BASIS,
84175 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84176 * See the License for the specific language governing permissions and
84177 * limitations under the License.
84178 * =============================================================================
84179 */
/**
 * Op-mapper table for the 'dynamic' category: NonMaxSuppressionV2..V5,
 * Where and ListDiff.
 *
 * Input descriptors hold `start` (position among the node inputs), a local
 * `name` and a value `type`. Attr descriptors hold `tfName` (key on the
 * TensorFlow node), a local `name`, a `type` and, where set, `notSupported`.
 */
var json$d = [
  {
    tfOpName: 'NonMaxSuppressionV2',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' }
    ]
  },
  {
    tfOpName: 'NonMaxSuppressionV3',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' }
    ]
  },
  {
    tfOpName: 'NonMaxSuppressionV4',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' }
    ],
    attrs: [
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
      { tfName: 'T_threshold', name: 'threshold', type: 'dtype', notSupported: true },
      { tfName: 'pad_to_max_output_size', name: 'padToMaxOutputSize', type: 'bool' }
    ]
  },
  {
    tfOpName: 'NonMaxSuppressionV5',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' },
      { start: 5, name: 'softNmsSigma', type: 'number' }
    ]
  },
  {
    tfOpName: 'Where',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'condition', type: 'tensor' }
    ],
    attrs: [
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }
    ]
  },
  {
    tfOpName: 'ListDiff',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'y', type: 'tensor' }
    ],
    attrs: [
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }
    ]
  }
];

// Namespace-style wrapper (null prototype) exposing the dynamic table.
var dynamic = {
  __proto__: null,
  json: json$d
};
84329
84330 /**
84331 * @license
84332 * Copyright 2023 Google LLC. All Rights Reserved.
84333 * Licensed under the Apache License, Version 2.0 (the "License");
84334 * you may not use this file except in compliance with the License.
84335 * You may obtain a copy of the License at
84336 *
84337 * http://www.apache.org/licenses/LICENSE-2.0
84338 *
84339 * Unless required by applicable law or agreed to in writing, software
84340 * distributed under the License is distributed on an "AS IS" BASIS,
84341 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84342 * See the License for the specific language governing permissions and
84343 * limitations under the License.
84344 * =============================================================================
84345 */
/**
 * Op-mapper table for the 'evaluation' category: LowerBound, TopKV2,
 * UpperBound, Unique and UniqueV2.
 *
 * Inputs are described by `start` (position among the node inputs), a local
 * `name` and a value `type`. Attrs are described by `tfName`, `name` and
 * `type`.
 */
var json$c = [
  {
    tfOpName: 'LowerBound',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'sortedSequence', type: 'tensor' },
      { start: 1, name: 'values', type: 'tensor' }
    ]
  },
  {
    tfOpName: 'TopKV2',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'k', type: 'number' }
    ],
    attrs: [
      { tfName: 'sorted', name: 'sorted', type: 'bool' }
    ]
  },
  {
    tfOpName: 'UpperBound',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'sortedSequence', type: 'tensor' },
      { start: 1, name: 'values', type: 'tensor' }
    ]
  },
  {
    tfOpName: 'Unique',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' }
    ]
  },
  {
    tfOpName: 'UniqueV2',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'axis', type: 'number' }
    ]
  }
];

// Namespace-style wrapper (null prototype) exposing the evaluation table.
var evaluation = {
  __proto__: null,
  json: json$c
};
84413
84414 /**
84415 * @license
84416 * Copyright 2023 Google LLC. All Rights Reserved.
84417 * Licensed under the Apache License, Version 2.0 (the "License");
84418 * you may not use this file except in compliance with the License.
84419 * You may obtain a copy of the License at
84420 *
84421 * http://www.apache.org/licenses/LICENSE-2.0
84422 *
84423 * Unless required by applicable law or agreed to in writing, software
84424 * distributed under the License is distributed on an "AS IS" BASIS,
84425 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84426 * See the License for the specific language governing permissions and
84427 * limitations under the License.
84428 * =============================================================================
84429 */
/**
 * Op-mapper table for the 'graph' category: placeholders, Const,
 * identity-like ops, shape queries, Print, NoOp, StopGradient and
 * FakeQuantWithMinMaxVars.
 *
 * Input descriptors hold `start` (position among the node inputs), an
 * optional `end`, a local `name` and a value `type`. Attr descriptors hold
 * `tfName` (key on the TF node), `name`, `type` and optional
 * `defaultValue` / `notSupported`.
 *
 * Print: TensorFlow's Print op takes a variadic `data` list after its first
 * input. The other variadic lists in this table (IdentityN, ShapeN) declare
 * `'end': 0`; `data` now does too, so it covers every remaining input rather
 * than only the one at index 1. This assumes the op mapper treats a missing
 * `end` as a single-element range. TODO confirm against the param reader.
 */
var json$b = [{
  'tfOpName': 'PlaceholderWithDefault',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'default',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'shape',
    'name': 'shape',
    'type': 'shape'
  }, {
    'tfName': 'dtype',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'Placeholder',
  'category': 'graph',
  'attrs': [{
    'tfName': 'shape',
    'name': 'shape',
    'type': 'shape'
  }, {
    'tfName': 'dtype',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'Const',
  'category': 'graph'
}, {
  'tfOpName': 'Identity',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }]
}, {
  'tfOpName': 'IdentityN',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'end': 0,
    'name': 'x',
    'type': 'tensors'
  }]
}, {
  'tfOpName': 'Snapshot',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }]
}, {
  'tfOpName': 'Rank',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }]
}, {
  'tfOpName': 'Size',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }]
}, {
  'tfOpName': 'Shape',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }]
}, {
  'tfOpName': 'ShapeN',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'end': 0,
    'name': 'x',
    'type': 'tensors'
  }]
}, {
  'tfOpName': 'Print',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    // Variadic: every input from index 1 onward is a data tensor to print.
    'start': 1,
    'end': 0,
    'name': 'data',
    'type': 'tensors'
  }],
  'attrs': [{
    'tfName': 'message',
    'name': 'message',
    'type': 'string'
  }, {
    'tfName': 'first_n',
    'name': 'firstN',
    'type': 'number',
    'notSupported': true
  }, {
    'tfName': 'summarize',
    'name': 'summarize',
    'type': 'number',
    'defaultValue': 3
  }]
}, {
  'tfOpName': 'NoOp',
  'category': 'graph',
  'inputs': []
}, {
  'tfOpName': 'StopGradient',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }]
}, {
  'tfOpName': 'FakeQuantWithMinMaxVars',
  'category': 'graph',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }],
  // NOTE(review): in TensorFlow this op receives min/max as inputs, not attrs.
  // Left unchanged here; confirm whether the executor reads them at all.
  'attrs': [{
    'tfName': 'min',
    'name': 'min',
    'type': 'number'
  }, {
    'tfName': 'max',
    'name': 'max',
    'type': 'number'
  }]
}];

// Namespace-style wrapper (null prototype) exposing the graph table.
var graph = {
  __proto__: null,
  json: json$b
};
84582
84583 /**
84584 * @license
84585 * Copyright 2023 Google LLC. All Rights Reserved.
84586 * Licensed under the Apache License, Version 2.0 (the "License");
84587 * you may not use this file except in compliance with the License.
84588 * You may obtain a copy of the License at
84589 *
84590 * http://www.apache.org/licenses/LICENSE-2.0
84591 *
84592 * Unless required by applicable law or agreed to in writing, software
84593 * distributed under the License is distributed on an "AS IS" BASIS,
84594 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84595 * See the License for the specific language governing permissions and
84596 * limitations under the License.
84597 * =============================================================================
84598 */
/**
 * Op-mapper table for the 'hash_table' category: HashTable(V2),
 * LookupTableImport(V2), LookupTableFind(V2), LookupTableSize(V2) and
 * InitializeTable(V2).
 *
 * Inputs are described by `start` (position among the node inputs), a local
 * `name` and a `type`. Attrs are described by `tfName` (key on the TF node),
 * a local `name`, a `type` and, where set, `notSupported`.
 */
var json$a = [
  {
    tfOpName: 'HashTable',
    category: 'hash_table',
    inputs: [],
    attrs: [
      { tfName: 'shared_name', name: 'sharedName', type: 'string' },
      { tfName: 'use_node_name_sharing', name: 'useNodeNameSharing', type: 'bool' },
      { tfName: 'key_dtype', name: 'keyDType', type: 'dtype' },
      { tfName: 'value_dtype', name: 'valueDType', type: 'dtype' }
    ]
  },
  {
    tfOpName: 'HashTableV2',
    category: 'hash_table',
    inputs: [],
    attrs: [
      { tfName: 'shared_name', name: 'sharedName', type: 'string' },
      { tfName: 'use_node_name_sharing', name: 'useNodeNameSharing', type: 'bool' },
      { tfName: 'key_dtype', name: 'keyDType', type: 'dtype' },
      { tfName: 'value_dtype', name: 'valueDType', type: 'dtype' }
    ]
  },
  {
    tfOpName: 'LookupTableImport',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'values', type: 'tensor' }
    ],
    attrs: [
      { tfName: 'Tin', name: 'tIn', type: 'dtype', notSupported: true },
      { tfName: 'Tout', name: 'tOut', type: 'dtype', notSupported: true }
    ]
  },
  {
    tfOpName: 'LookupTableImportV2',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'values', type: 'tensor' }
    ],
    attrs: [
      { tfName: 'Tin', name: 'tIn', type: 'dtype', notSupported: true },
      { tfName: 'Tout', name: 'tOut', type: 'dtype', notSupported: true }
    ]
  },
  {
    tfOpName: 'LookupTableFind',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'defaultValue', type: 'tensor' }
    ],
    attrs: [
      { tfName: 'Tin', name: 'tIn', type: 'dtype', notSupported: true },
      { tfName: 'Tout', name: 'tOut', type: 'dtype', notSupported: true }
    ]
  },
  {
    tfOpName: 'LookupTableFindV2',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'defaultValue', type: 'tensor' }
    ],
    attrs: [
      { tfName: 'Tin', name: 'tIn', type: 'dtype', notSupported: true },
      { tfName: 'Tout', name: 'tOut', type: 'dtype', notSupported: true }
    ]
  },
  {
    tfOpName: 'LookupTableSize',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' }
    ]
  },
  {
    tfOpName: 'LookupTableSizeV2',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' }
    ]
  },
  {
    tfOpName: 'InitializeTable',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'values', type: 'tensor' }
    ]
  },
  {
    tfOpName: 'InitializeTableV2',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'values', type: 'tensor' }
    ]
  }
];

// Namespace-style wrapper (null prototype) exposing the hash-table table.
var hashTable = {
  __proto__: null,
  json: json$a
};
84803
84804 /**
84805 * @license
84806 * Copyright 2023 Google LLC. All Rights Reserved.
84807 * Licensed under the Apache License, Version 2.0 (the "License");
84808 * you may not use this file except in compliance with the License.
84809 * You may obtain a copy of the License at
84810 *
84811 * http://www.apache.org/licenses/LICENSE-2.0
84812 *
84813 * Unless required by applicable law or agreed to in writing, software
84814 * distributed under the License is distributed on an "AS IS" BASIS,
84815 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84816 * See the License for the specific language governing permissions and
84817 * limitations under the License.
84818 * =============================================================================
84819 */
/**
 * Op-mapper table for the 'image' category: ResizeBilinear,
 * ResizeNearestNeighbor, CropAndResize and ImageProjectiveTransformV3.
 *
 * Inputs are described by `start` (position among the node inputs), a local
 * `name` and a `type`. Attrs are described by `tfName` (key on the TF node),
 * a local `name`, a `type` and, where set, `notSupported`.
 */
var json$9 = [
  {
    tfOpName: 'ResizeBilinear',
    category: 'image',
    inputs: [
      { start: 0, name: 'images', type: 'tensor' },
      { start: 1, name: 'size', type: 'number[]' }
    ],
    attrs: [
      { tfName: 'align_corners', name: 'alignCorners', type: 'bool' },
      { tfName: 'half_pixel_centers', name: 'halfPixelCenters', type: 'bool' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }
    ]
  },
  {
    tfOpName: 'ResizeNearestNeighbor',
    category: 'image',
    inputs: [
      { start: 0, name: 'images', type: 'tensor' },
      { start: 1, name: 'size', type: 'number[]' }
    ],
    attrs: [
      { tfName: 'align_corners', name: 'alignCorners', type: 'bool' },
      { tfName: 'half_pixel_centers', name: 'halfPixelCenters', type: 'bool' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }
    ]
  },
  {
    tfOpName: 'CropAndResize',
    category: 'image',
    inputs: [
      { start: 0, name: 'image', type: 'tensor' },
      { start: 1, name: 'boxes', type: 'tensor' },
      { start: 2, name: 'boxInd', type: 'tensor' },
      { start: 3, name: 'cropSize', type: 'number[]' }
    ],
    attrs: [
      { tfName: 'method', name: 'method', type: 'string' },
      { tfName: 'extrapolation_value', name: 'extrapolationValue', type: 'number' }
    ]
  },
  {
    tfOpName: 'ImageProjectiveTransformV3',
    category: 'image',
    inputs: [
      { start: 0, name: 'images', type: 'tensor' },
      { start: 1, name: 'transforms', type: 'tensor' },
      { start: 2, name: 'outputShape', type: 'number[]' },
      { start: 3, name: 'fillValue', type: 'number' }
    ],
    attrs: [
      { tfName: 'interpolation', name: 'interpolation', type: 'string' },
      { tfName: 'fill_mode', name: 'fillMode', type: 'string' }
    ]
  }
];

// Namespace-style wrapper (null prototype) exposing the image table.
var image = {
  __proto__: null,
  json: json$9
};
84936
84937 /**
84938 * @license
84939 * Copyright 2023 Google LLC. All Rights Reserved.
84940 * Licensed under the Apache License, Version 2.0 (the "License");
84941 * you may not use this file except in compliance with the License.
84942 * You may obtain a copy of the License at
84943 *
84944 * http://www.apache.org/licenses/LICENSE-2.0
84945 *
84946 * Unless required by applicable law or agreed to in writing, software
84947 * distributed under the License is distributed on an "AS IS" BASIS,
84948 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84949 * See the License for the specific language governing permissions and
84950 * limitations under the License.
84951 * =============================================================================
84952 */
/**
 * Op-definition table for the 'logical' category. For each TensorFlow op
 * (`tfOpName`) it lists the node inputs by position (`start`) and the node
 * attributes (`tfName`) that are mapped to named parameters.
 */
var json$8 = (function () {
  // Builds one entry whose inputs are plain tensors at consecutive positions
  // 0..n-1. When `hasDtypeAttr` is set, the entry also records the `T` dtype
  // attribute, flagged `notSupported`.
  function logicalOp(tfOpName, inputNames, hasDtypeAttr) {
    var entry = {
      'tfOpName': tfOpName,
      'category': 'logical',
      'inputs': inputNames.map(function (inputName, position) {
        return {'start': position, 'name': inputName, 'type': 'tensor'};
      })
    };
    if (hasDtypeAttr) {
      entry['attrs'] = [
        {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
      ];
    }
    return entry;
  }

  // Two-operand ops that read tensors `a` and `b`.
  var twoOperandOps = [
    'Equal', 'NotEqual', 'Greater', 'GreaterEqual', 'Less', 'LessEqual',
    'LogicalAnd'
  ].map(function (tfOpName) {
    return logicalOp(tfOpName, ['a', 'b'], true);
  });

  return twoOperandOps.concat([
    logicalOp('LogicalNot', ['a'], true),
    logicalOp('LogicalOr', ['a', 'b'], true),
    logicalOp('Select', ['condition', 'a', 'b'], true),
    logicalOp('SelectV2', ['condition', 'a', 'b'], true),
    // BitwiseAnd reads `x`/`y` and declares no `attrs` key at all.
    logicalOp('BitwiseAnd', ['x', 'y'], false)
  ]);
})();

// Namespace object for this module: null prototype, single `json` property.
var logical = Object.assign(Object.create(null), {
  json: json$8
});
85173
85174 /**
85175 * @license
85176 * Copyright 2023 Google LLC. All Rights Reserved.
85177 * Licensed under the Apache License, Version 2.0 (the "License");
85178 * you may not use this file except in compliance with the License.
85179 * You may obtain a copy of the License at
85180 *
85181 * http://www.apache.org/licenses/LICENSE-2.0
85182 *
85183 * Unless required by applicable law or agreed to in writing, software
85184 * distributed under the License is distributed on an "AS IS" BASIS,
85185 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85186 * See the License for the specific language governing permissions and
85187 * limitations under the License.
85188 * =============================================================================
85189 */
/**
 * Op-definition table for the 'matrices' category. For each TensorFlow op
 * (`tfOpName`) it lists the node inputs by position (`start`; `end` bounds a
 * variadic `tensors` run) and the node attributes (`tfName`) mapped to named
 * parameters, with a `defaultValue` where one is declared.
 */
var json$7 = [{
  'tfOpName': '_FusedMatMul',
  'category': 'matrices',
  'inputs': [
    {'start': 0, 'name': 'a', 'type': 'tensor'},
    {'start': 1, 'name': 'b', 'type': 'tensor'},
    // NOTE(review): `end: 0` presumably means "through the last input" —
    // confirm against the op mapper.
    {'start': 2, 'end': 0, 'name': 'args', 'type': 'tensors'}
  ],
  'attrs': [
    {'tfName': 'num_args', 'name': 'numArgs', 'type': 'number'},
    {'tfName': 'fused_ops', 'name': 'fusedOps', 'type': 'string[]', 'defaultValue': []},
    {'tfName': 'epsilon', 'name': 'epsilon', 'type': 'number', 'defaultValue': 0.0001},
    {'tfName': 'transpose_a', 'name': 'transposeA', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'transpose_b', 'name': 'transposeB', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'leakyrelu_alpha', 'name': 'leakyreluAlpha', 'type': 'number', 'defaultValue': 0.2},
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'MatMul',
  'category': 'matrices',
  'inputs': [
    {'start': 0, 'name': 'a', 'type': 'tensor'},
    {'start': 1, 'name': 'b', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'transpose_a', 'name': 'transposeA', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'transpose_b', 'name': 'transposeB', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  // BatchMatMul / BatchMatMulV2 name their transpose flags `adj_x` / `adj_y`.
  'tfOpName': 'BatchMatMul',
  'category': 'matrices',
  'inputs': [
    {'start': 0, 'name': 'a', 'type': 'tensor'},
    {'start': 1, 'name': 'b', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'adj_x', 'name': 'transposeA', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'adj_y', 'name': 'transposeB', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'BatchMatMulV2',
  'category': 'matrices',
  'inputs': [
    {'start': 0, 'name': 'a', 'type': 'tensor'},
    {'start': 1, 'name': 'b', 'type': 'tensor'}
  ],
  'attrs': [
    {'tfName': 'adj_x', 'name': 'transposeA', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'adj_y', 'name': 'transposeB', 'type': 'bool', 'defaultValue': false},
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'Transpose',
  'category': 'matrices',
  'inputs': [
    {'start': 0, 'name': 'x', 'type': 'tensor'},
    {'start': 1, 'name': 'perm', 'type': 'number[]'}
  ],
  'attrs': [
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
  ]
}, {
  'tfOpName': 'Einsum',
  'category': 'matrices',
  'inputs': [
    {'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors'}
  ],
  'attrs': [
    {'tfName': 'equation', 'name': 'equation', 'type': 'string'},
    {'tfName': 'N', 'name': 'n', 'type': 'number', 'defaultValue': 2},
    {'tfName': 'T', 'name': 'dtype', 'type': 'dtype'}
  ]
}, {
  // TensorFlow's MatrixBandPart takes three inputs: (input, num_lower,
  // num_upper). `numUpper` previously also used `start: 1`, so it resolved to
  // the num_lower tensor. It must read position 2.
  'tfOpName': 'MatrixBandPart',
  'category': 'matrices',
  'inputs': [
    {'start': 0, 'name': 'a', 'type': 'tensor'},
    {'start': 1, 'name': 'numLower', 'type': 'tensor'},
    {'start': 2, 'name': 'numUpper', 'type': 'tensor'}
  ]
}];

// Namespace object for this module: null prototype, single `json` property.
var matrices = {
  __proto__: null,
  json: json$7
};
85389
85390 /**
85391 * @license
85392 * Copyright 2023 Google LLC. All Rights Reserved.
85393 * Licensed under the Apache License, Version 2.0 (the "License");
85394 * you may not use this file except in compliance with the License.
85395 * You may obtain a copy of the License at
85396 *
85397 * http://www.apache.org/licenses/LICENSE-2.0
85398 *
85399 * Unless required by applicable law or agreed to in writing, software
85400 * distributed under the License is distributed on an "AS IS" BASIS,
85401 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85402 * See the License for the specific language governing permissions and
85403 * limitations under the License.
85404 * =============================================================================
85405 */
/**
 * Op-definition table for the 'normalization' category. For each TensorFlow
 * op (`tfOpName`) it lists the node inputs by position (`start`) and the node
 * attributes (`tfName`) mapped to named parameters.
 */
var json$6 = (function () {
  // Fresh descriptor for the leading `x` tensor input (position 0).
  function xInput() {
    return {'start': 0, 'name': 'x', 'type': 'tensor'};
  }

  // FusedBatchNorm, V2 and V3 share one signature in this table.
  function fusedBatchNorm(tfOpName) {
    return {
      'tfOpName': tfOpName,
      'category': 'normalization',
      'inputs': ['x', 'scale', 'offset', 'mean', 'variance'].map(function (inputName, position) {
        return {'start': position, 'name': inputName, 'type': 'tensor'};
      }),
      'attrs': [
        {'tfName': 'epsilon', 'name': 'epsilon', 'type': 'number', 'defaultValue': 0.001},
        {'tfName': 'data_format', 'name': 'dataFormat', 'type': 'string', 'notSupported': true}
      ]
    };
  }

  // Ops that read only `x` and declare no attributes.
  function xOnlyOp(tfOpName) {
    return {'tfOpName': tfOpName, 'category': 'normalization', 'inputs': [xInput()]};
  }

  var euclideanNorm = {
    'tfOpName': 'EuclideanNorm',
    'category': 'normalization',
    'inputs': [xInput(), {'start': 1, 'name': 'axis', 'type': 'number[]'}],
    'attrs': [{'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool', 'defaultValue': false}]
  };

  var lrn = {
    'tfOpName': 'LRN',
    'category': 'normalization',
    'inputs': [xInput()],
    'attrs': [
      {'tfName': 'depth_radius', 'name': 'radius', 'type': 'number', 'defaultValue': 5},
      {'tfName': 'bias', 'name': 'bias', 'type': 'number', 'defaultValue': 1},
      {'tfName': 'alpha', 'name': 'alpha', 'type': 'number', 'defaultValue': 1},
      {'tfName': 'beta', 'name': 'beta', 'type': 'number', 'defaultValue': 0.5}
    ]
  };

  return [
    euclideanNorm,
    fusedBatchNorm('FusedBatchNorm'),
    fusedBatchNorm('FusedBatchNormV2'),
    fusedBatchNorm('FusedBatchNormV3'),
    lrn,
    xOnlyOp('Softmax'),
    xOnlyOp('LogSoftmax')
  ];
})();

// Namespace object for this module: null prototype, single `json` property.
var normalization = Object.assign(Object.create(null), {
  json: json$6
});
85580
85581 /**
85582 * @license
85583 * Copyright 2023 Google LLC. All Rights Reserved.
85584 * Licensed under the Apache License, Version 2.0 (the "License");
85585 * you may not use this file except in compliance with the License.
85586 * You may obtain a copy of the License at
85587 *
85588 * http://www.apache.org/licenses/LICENSE-2.0
85589 *
85590 * Unless required by applicable law or agreed to in writing, software
85591 * distributed under the License is distributed on an "AS IS" BASIS,
85592 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85593 * See the License for the specific language governing permissions and
85594 * limitations under the License.
85595 * =============================================================================
85596 */
/**
 * Op-definition table for the 'reduction' category. For each TensorFlow op
 * (`tfOpName`) it lists the node inputs by position (`start`) and the node
 * attributes (`tfName`) mapped to named parameters.
 */
var json$5 = (function () {
  // Assembles one entry. The `attrs` key is present only when attributes are
  // passed, so attribute-less ops carry no `attrs` key at all.
  function reductionOp(tfOpName, inputs, attrs) {
    var entry = {'tfOpName': tfOpName, 'category': 'reduction', 'inputs': inputs};
    if (attrs !== undefined) {
      entry['attrs'] = attrs;
    }
    return entry;
  }

  // Inputs shared by Bincount and DenseBincount: x, size, weights.
  function bincountInputs() {
    return [
      {'start': 0, 'name': 'x', 'type': 'tensor'},
      {'start': 1, 'name': 'size', 'type': 'number'},
      {'start': 2, 'name': 'weights', 'type': 'tensor'}
    ];
  }

  // Tensor `x` followed by an `axis` input of the given parameter type.
  function axisInputs(axisType) {
    return [
      {'start': 0, 'name': 'x', 'type': 'tensor'},
      {'start': 1, 'name': 'axis', 'type': axisType}
    ];
  }

  function keepDimsAttr() {
    return {'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool'};
  }

  // Attributes shared by Cumprod and Cumsum.
  function cumulativeAttrs() {
    return [
      {'tfName': 'exclusive', 'name': 'exclusive', 'type': 'bool'},
      {'tfName': 'reverse', 'name': 'reverse', 'type': 'bool'}
    ];
  }

  var keepDimsReductions = ['Max', 'Mean', 'Min', 'Sum', 'All', 'Any'].map(function (tfOpName) {
    return reductionOp(tfOpName, axisInputs('number[]'), [keepDimsAttr()]);
  });

  return [
    reductionOp('Bincount', bincountInputs()),
    reductionOp('DenseBincount', bincountInputs(),
      [{'tfName': 'binary_output', 'name': 'binaryOutput', 'type': 'bool'}])
  ].concat(keepDimsReductions, [
    reductionOp('ArgMax', axisInputs('number')),
    reductionOp('ArgMin', axisInputs('number')),
    reductionOp('Prod', axisInputs('number[]'), [
      keepDimsAttr(),
      {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}
    ]),
    reductionOp('Cumprod', axisInputs('number'), cumulativeAttrs()),
    reductionOp('Cumsum', axisInputs('number'), cumulativeAttrs())
  ]);
})();

// Namespace object for this module: null prototype, single `json` property.
var reduction = Object.assign(Object.create(null), {
  json: json$5
});
85830
85831 /**
85832 * @license
85833 * Copyright 2023 Google LLC. All Rights Reserved.
85834 * Licensed under the Apache License, Version 2.0 (the "License");
85835 * you may not use this file except in compliance with the License.
85836 * You may obtain a copy of the License at
85837 *
85838 * http://www.apache.org/licenses/LICENSE-2.0
85839 *
85840 * Unless required by applicable law or agreed to in writing, software
85841 * distributed under the License is distributed on an "AS IS" BASIS,
85842 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85843 * See the License for the specific language governing permissions and
85844 * limitations under the License.
85845 * =============================================================================
85846 */
/**
 * Op-definition table for the 'slice_join' category. For each TensorFlow op
 * (`tfOpName`) it lists the node inputs by position (`start`; `end` bounds a
 * variadic `tensors` run) and the node attributes (`tfName`) mapped to named
 * parameters, with a `defaultValue` where one is declared.
 */
var json$4 = (function () {
  // Assembles one entry. The `attrs` key is present only when attributes are
  // passed, so attribute-less ops carry no `attrs` key at all.
  function sliceJoinOp(tfOpName, inputs, attrs) {
    var entry = {'tfOpName': tfOpName, 'category': 'slice_join', 'inputs': inputs};
    if (attrs !== undefined) {
      entry['attrs'] = attrs;
    }
    return entry;
  }

  // A single-tensor input at `start`.
  function tensorAt(start, name) {
    return {'start': start, 'name': name, 'type': 'tensor'};
  }

  // A `number[]` input at `start`.
  function listAt(start, name) {
    return {'start': start, 'name': name, 'type': 'number[]'};
  }

  function countAttr() {
    return [{'tfName': 'N', 'name': 'n', 'type': 'number', 'defaultValue': 2}];
  }

  // StridedSlice bit-mask attributes, all numeric with default 0.
  var stridedSliceMasks = [
    ['begin_mask', 'beginMask'],
    ['end_mask', 'endMask'],
    ['new_axis_mask', 'newAxisMask'],
    ['ellipsis_mask', 'ellipsisMask'],
    ['shrink_axis_mask', 'shrinkAxisMask']
  ].map(function (names) {
    return {'tfName': names[0], 'name': names[1], 'type': 'number', 'defaultValue': 0};
  });

  return [
    // NOTE(review): negative `start`/`end` presumably count from the end of
    // the node's input list — confirm against the op mapper.
    sliceJoinOp('ConcatV2', [
      {'start': 0, 'end': -1, 'name': 'tensors', 'type': 'tensors'},
      {'start': -1, 'name': 'axis', 'type': 'number'}
    ], countAttr()),
    sliceJoinOp('Concat', [
      {'start': 1, 'end': 0, 'name': 'tensors', 'type': 'tensors'},
      {'start': 0, 'name': 'axis', 'type': 'number'}
    ], countAttr()),
    sliceJoinOp('GatherV2', [
      tensorAt(0, 'x'),
      tensorAt(1, 'indices'),
      {'start': 2, 'name': 'axis', 'type': 'number', 'defaultValue': 0}
    ], [{'tfName': 'batch_dims', 'name': 'batchDims', 'type': 'number', 'defaultValue': 0}]),
    sliceJoinOp('Gather', [tensorAt(0, 'x'), tensorAt(1, 'indices')],
      [{'tfName': 'validate_indices', 'name': 'validateIndices', 'type': 'bool', 'notSupported': true}]),
    sliceJoinOp('Reverse', [tensorAt(0, 'x'), {'start': 1, 'name': 'dims', 'type': 'bool[]'}]),
    sliceJoinOp('ReverseV2', [tensorAt(0, 'x'), listAt(1, 'axis')]),
    sliceJoinOp('Slice', [tensorAt(0, 'x'), listAt(1, 'begin'), listAt(2, 'size')]),
    sliceJoinOp('StridedSlice', [
      tensorAt(0, 'x'),
      listAt(1, 'begin'),
      listAt(2, 'end'),
      listAt(3, 'strides')
    ], stridedSliceMasks),
    sliceJoinOp('Pack', [{'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors'}],
      [{'tfName': 'axis', 'name': 'axis', 'type': 'number', 'defaultValue': 0}]),
    sliceJoinOp('Unpack', [tensorAt(0, 'tensor')], [
      {'tfName': 'axis', 'name': 'axis', 'type': 'number', 'defaultValue': 0},
      {'tfName': 'num', 'name': 'num', 'type': 'number', 'defaultValue': 0, 'notSupported': true}
    ]),
    sliceJoinOp('Tile', [tensorAt(0, 'x'), listAt(1, 'reps')]),
    // Split takes the axis first and the tensor second.
    sliceJoinOp('Split', [
      {'start': 0, 'name': 'axis', 'type': 'number', 'defaultValue': 0},
      tensorAt(1, 'x')
    ], [{'tfName': 'num_split', 'name': 'numOrSizeSplits', 'type': 'number', 'defaultValue': 1}]),
    sliceJoinOp('SplitV', [
      tensorAt(0, 'x'),
      listAt(1, 'numOrSizeSplits'),
      {'start': 2, 'name': 'axis', 'type': 'number', 'defaultValue': 0}
    ]),
    sliceJoinOp('ScatterNd', [tensorAt(0, 'indices'), tensorAt(1, 'values'), listAt(2, 'shape')]),
    sliceJoinOp('GatherNd', [tensorAt(0, 'x'), tensorAt(1, 'indices')]),
    sliceJoinOp('SparseToDense', [
      tensorAt(0, 'sparseIndices'),
      listAt(1, 'outputShape'),
      tensorAt(2, 'sparseValues'),
      tensorAt(3, 'defaultValue')
    ], [{
      'tfName': 'validate_indices',
      'name': 'validateIndices',
      'type': 'bool',
      'defaultValue': false,
      'notSupported': true
    }]),
    sliceJoinOp('TensorScatterUpdate', [tensorAt(0, 'tensor'), tensorAt(1, 'indices'), tensorAt(2, 'values')])
  ];
})();

// Namespace object for this module: null prototype, single `json` property.
var sliceJoin = Object.assign(Object.create(null), {
  json: json$4
});
86172
86173 /**
86174 * @license
86175 * Copyright 2023 Google LLC. All Rights Reserved.
86176 * Licensed under the Apache License, Version 2.0 (the "License");
86177 * you may not use this file except in compliance with the License.
86178 * You may obtain a copy of the License at
86179 *
86180 * http://www.apache.org/licenses/LICENSE-2.0
86181 *
86182 * Unless required by applicable law or agreed to in writing, software
86183 * distributed under the License is distributed on an "AS IS" BASIS,
86184 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86185 * See the License for the specific language governing permissions and
86186 * limitations under the License.
86187 * =============================================================================
86188 */
/**
 * Op-definition table for the 'sparse' category. Every op here reads plain
 * tensor inputs at consecutive positions (`start` = 0, 1, ...).
 */
var json$3 = (function () {
  // Builds one entry from its ordered input names. The `attrs` key is present
  // only when attributes are passed.
  function sparseOp(tfOpName, inputNames, attrs) {
    var entry = {
      'tfOpName': tfOpName,
      'category': 'sparse',
      'inputs': inputNames.map(function (inputName, position) {
        return {'start': position, 'name': inputName, 'type': 'tensor'};
      })
    };
    if (attrs !== undefined) {
      entry['attrs'] = attrs;
    }
    return entry;
  }

  return [
    sparseOp('SparseFillEmptyRows', ['indices', 'values', 'denseShape', 'defaultValue']),
    sparseOp('SparseReshape', ['inputIndices', 'inputShape', 'newShape'],
      [{'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}]),
    sparseOp('SparseSegmentMean', ['data', 'indices', 'segmentIds']),
    sparseOp('SparseSegmentSum', ['data', 'indices', 'segmentIds'])
  ];
})();

// Namespace object for this module: null prototype, single `json` property.
var sparse = Object.assign(Object.create(null), {
  json: json$3
});
86269
86270 /**
86271 * @license
86272 * Copyright 2023 Google LLC. All Rights Reserved.
86273 * Licensed under the Apache License, Version 2.0 (the "License");
86274 * you may not use this file except in compliance with the License.
86275 * You may obtain a copy of the License at
86276 *
86277 * http://www.apache.org/licenses/LICENSE-2.0
86278 *
86279 * Unless required by applicable law or agreed to in writing, software
86280 * distributed under the License is distributed on an "AS IS" BASIS,
86281 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86282 * See the License for the specific language governing permissions and
86283 * limitations under the License.
86284 * =============================================================================
86285 */
/**
 * Op-definition table for the 'spectral' category. All four ops read the
 * tensor `x` at position 0. RFFT/IRFFT additionally declare an `fft_length`
 * input at position 1 that is flagged `notSupported`.
 */
var json$2 = (function () {
  function spectralOp(tfOpName, hasFftLength) {
    var inputs = [{'start': 0, 'name': 'x', 'type': 'tensor'}];
    if (hasFftLength) {
      inputs.push({'start': 1, 'name': 'fft_length', 'type': 'number', 'notSupported': true});
    }
    return {'tfOpName': tfOpName, 'category': 'spectral', 'inputs': inputs};
  }

  return [
    spectralOp('FFT', false),
    spectralOp('IFFT', false),
    spectralOp('RFFT', true),
    spectralOp('IRFFT', true)
  ];
})();

// Namespace object for this module: null prototype, single `json` property.
var spectral = Object.assign(Object.create(null), {
  json: json$2
});
86334
86335 /**
86336 * @license
86337 * Copyright 2023 Google LLC. All Rights Reserved.
86338 * Licensed under the Apache License, Version 2.0 (the "License");
86339 * you may not use this file except in compliance with the License.
86340 * You may obtain a copy of the License at
86341 *
86342 * http://www.apache.org/licenses/LICENSE-2.0
86343 *
86344 * Unless required by applicable law or agreed to in writing, software
86345 * distributed under the License is distributed on an "AS IS" BASIS,
86346 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86347 * See the License for the specific language governing permissions and
86348 * limitations under the License.
86349 * =============================================================================
86350 */
/**
 * Op-mapping metadata for the `string` op category.
 *
 * Each entry is read by OperationMapper:
 * - `tfOpName` is the key in `opMappers`.
 * - `category` is copied to the mapped node.
 * - `inputs[].start`/`end` become `inputParams[name].inputIndexStart`/`inputIndexEnd`.
 * - `attrs[].tfName` names the raw GraphDef attribute, and `attrs[].type` selects the decoder.
 * - `outputs` lists named outputs, so `node:output_name` inputs can be rewritten to numeric indices.
 *
 * NOTE(review): mapNode decodes `string` attrs with getStringParam without
 * `keepCase`, so values such as `pattern`, `rewrite` and `separator` are
 * lowercased — confirm this is intended for case-sensitive regexes.
 */
var json$1 = [{
  'tfOpName': 'StaticRegexReplace',
  'category': 'string',
  'inputs': [{
    'start': 0,
    'name': 'input',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'pattern',
    'name': 'pattern',
    'type': 'string'
  }, {
    'tfName': 'rewrite',
    'name': 'rewrite',
    'type': 'string'
  }, {
    'tfName': 'replace_global',
    'name': 'replaceGlobal',
    'type': 'bool'
  }]
}, {
  'tfOpName': 'StringNGrams',
  'category': 'string',
  'inputs': [{
    'start': 0,
    'name': 'data',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'dataSplits',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'separator',
    'name': 'separator',
    'type': 'string'
  }, {
    'tfName': 'ngram_widths',
    'name': 'nGramWidths',
    'type': 'number[]'
  }, {
    'tfName': 'left_pad',
    'name': 'leftPad',
    'type': 'string'
  }, {
    'tfName': 'right_pad',
    'name': 'rightPad',
    'type': 'string'
  }, {
    'tfName': 'pad_width',
    'name': 'padWidth',
    'type': 'number'
  }, {
    'tfName': 'preserve_short_sequences',
    'name': 'preserveShortSequences',
    'type': 'bool'
  }],
  // Named outputs; mapNode stores these as the node's `outputs`.
  'outputs': ['ngrams', 'ngrams_splits']
}, {
  'tfOpName': 'StringSplit',
  'category': 'string',
  'inputs': [{
    'start': 0,
    'name': 'input',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'delimiter',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'skip_empty',
    'name': 'skipEmpty',
    'type': 'bool'
  }],
  'outputs': ['indices', 'values', 'shape']
}, {
  'tfOpName': 'StringToHashBucketFast',
  'category': 'string',
  'inputs': [{
    'start': 0,
    'name': 'input',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'num_buckets',
    'name': 'numBuckets',
    'type': 'number'
  }]
}];
86442
// Prototype-less namespace object exposing the string op-mapping list.
var string = Object.create(null);
string.json = json$1;
86447
86448 /**
86449 * @license
86450 * Copyright 2023 Google LLC. All Rights Reserved.
86451 * Licensed under the Apache License, Version 2.0 (the "License");
86452 * you may not use this file except in compliance with the License.
86453 * You may obtain a copy of the License at
86454 *
86455 * http://www.apache.org/licenses/LICENSE-2.0
86456 *
86457 * Unless required by applicable law or agreed to in writing, software
86458 * distributed under the License is distributed on an "AS IS" BASIS,
86459 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86460 * See the License for the specific language governing permissions and
86461 * limitations under the License.
86462 * =============================================================================
86463 */
/**
 * Op-mapping metadata for the `transformation` op category.
 *
 * Entries use the same schema as the other category lists consumed by
 * OperationMapper (`tfOpName`, `category`, `inputs`, `attrs`).
 *
 * NOTE(review): mapNode does not read these two fields; presumably they are
 * informational or consumed elsewhere — confirm:
 * - `notSupported`
 * - `defaultValue` on an *input* entry (mapNode only copies `type`/`start`/`end`
 *   into `inputParams`).
 */
var json = [{
  'tfOpName': 'Cast',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'SrcT',
    'name': 'sdtype',
    'type': 'dtype',
    'notSupported': true
  }, {
    'tfName': 'DstT',
    'name': 'dtype',
    'type': 'dtype'
  }]
}, {
  'tfOpName': 'ExpandDims',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'axis',
    'type': 'number'
  }]
}, {
  'tfOpName': 'MirrorPad',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'padding',
    'type': 'number[]'
  }],
  'attrs': [{
    'tfName': 'mode',
    'name': 'mode',
    'type': 'string'
  }]
}, {
  'tfOpName': 'Pad',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'padding',
    'type': 'number[]'
  }],
  'attrs': [{
    'tfName': 'constant_value',
    'name': 'constantValue',
    'type': 'number',
    'defaultValue': 0
  }]
}, {
  'tfOpName': 'PadV2',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'padding',
    'type': 'number[]'
  }, {
    'start': 2,
    'name': 'constantValue',
    'type': 'number',
    'defaultValue': 0
  }]
}, {
  'tfOpName': 'Reshape',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'shape',
    'type': 'number[]'
  }]
}, {
  'tfOpName': 'EnsureShape',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'shape',
    'type': 'number[]'
  }]
}, {
  'tfOpName': 'Squeeze',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'axis',
    // mapNode retries with this name when `axis` decodes to undefined.
    'tfDeprecatedName': 'squeeze_dims',
    'name': 'axis',
    'type': 'number[]'
  }]
}, {
  'tfOpName': 'SpaceToBatchND',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'blockShape',
    'type': 'number[]'
  }, {
    'start': 2,
    'name': 'paddings',
    'type': 'number[]'
  }]
}, {
  'tfOpName': 'BatchToSpaceND',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'blockShape',
    'type': 'number[]'
  }, {
    'start': 2,
    'name': 'crops',
    'type': 'number[]'
  }]
}, {
  'tfOpName': 'DepthToSpace',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }],
  'attrs': [{
    'tfName': 'block_size',
    'name': 'blockSize',
    'type': 'number'
  }, {
    'tfName': 'data_format',
    'name': 'dataFormat',
    'type': 'string'
  }]
}, {
  'tfOpName': 'BroadcastTo',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 'x',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 'shape',
    'type': 'number[]'
  }],
  'attrs': []
}, {
  'tfOpName': 'BroadcastArgs',
  'category': 'transformation',
  'inputs': [{
    'start': 0,
    'name': 's0',
    'type': 'tensor'
  }, {
    'start': 1,
    'name': 's1',
    'type': 'tensor'
  }],
  'attrs': []
}];
86660
// Prototype-less namespace object exposing the transformation op-mapping list.
var transformation = Object.create(null);
transformation.json = json;
86665
/**
 * Converts TensorFlow GraphDef / FunctionDef objects into the graph
 * representation used by TensorFlow.js. The result holds `nodes`, `inputs`,
 * `outputs`, `weights`, `placeholders`, `signature` and `functions`.
 *
 * Op metadata is looked up in this order:
 * 1. `getRegisteredOp` (custom ops).
 * 2. The merged per-category op-mapping lists, keyed by `tfOpName`.
 */
var OperationMapper = /*#__PURE__*/function () {
  /**
   * Connects every node in `nodes` to the nodes named in its `inputNames`:
   * - the producer is pushed onto `node.inputs`;
   * - the consumer is pushed onto `producer.children`.
   *
   * When the producer's op metadata lists named `outputs` and the input refers
   * to one of them, the input name is rewritten in place to
   * `"<nodeName>:<outputIndex>"`.
   *
   * @param {Object} nodes graph nodes keyed by node name (mutated in place).
   * @throws {Error} if an input refers to a node name that is not in `nodes`.
   */
  function connectNodeInputs(nodes) {
    Object.keys(nodes).forEach(function (key) {
      var node = nodes[key];
      node.inputNames.forEach(function (name, index) {
        var _getNodeNameAndIndex = getNodeNameAndIndex(name),
          _getNodeNameAndIndex2 = _slicedToArray(_getNodeNameAndIndex, 3),
          nodeName = _getNodeNameAndIndex2[0],
          outputName = _getNodeNameAndIndex2[2];
        var inputNode = nodes[nodeName];
        if (inputNode == null) {
          // Fail with a descriptive message instead of an opaque TypeError on
          // `inputNode.outputs` below.
          throw new Error("Node '".concat(node.name, "' references input '").concat(name, "', but no node named '").concat(nodeName, "' exists in the graph."));
        }
        if (inputNode.outputs != null) {
          var outputIndex = inputNode.outputs.indexOf(outputName);
          if (outputIndex !== -1) {
            // update the input name to use the mapped output index directly.
            node.inputNames[index] = "".concat(nodeName, ":").concat(outputIndex);
          }
        }
        node.inputs.push(inputNode);
        inputNode.children.push(node);
      });
    });
  }

  /**
   * Reads a scalar numeric attribute described by `param`.
   *
   * getNumberParam never returns `undefined`: it yields the default or a
   * parsed number, possibly NaN. A post-hoc `=== undefined` check therefore
   * cannot trigger the deprecated-name fallback. The attribute name is chosen
   * up front instead: the deprecated name is used only when the current name
   * is absent and the deprecated one is present.
   */
  function readNumberAttr(attrs, param) {
    var attrName = param.tfName;
    if (attrs[attrName] == null && !!param.tfDeprecatedName && attrs[param.tfDeprecatedName] != null) {
      attrName = param.tfDeprecatedName;
    }
    return getNumberParam(attrs, attrName, param.defaultValue || 0);
  }

  /**
   * Decodes the attribute described by `param` (one entry of an op mapper's
   * `attrs` list) from the raw node attributes.
   *
   * If the value under `param.tfName` decodes to `undefined`, the lookup is
   * retried with `param.tfDeprecatedName`. `tensor`/`tensors` params are not
   * decoded here and yield `undefined`.
   *
   * @param {Object} attrs raw node attributes.
   * @param {Object} param attr descriptor from the op mapping.
   * @param {string} opName op name, used only in the error message.
   * @throws {Error} for an unknown `param.type`.
   */
  function readAttrValue(attrs, param, opName) {
    var getter;
    switch (param.type) {
      case 'number':
        return readNumberAttr(attrs, param);
      case 'string':
        getter = getStringParam;
        break;
      case 'string[]':
        getter = getStringArrayParam;
        break;
      case 'number[]':
        getter = getNumericArrayParam;
        break;
      case 'bool':
        getter = getBoolParam;
        break;
      case 'bool[]':
        getter = getBoolArrayParam;
        break;
      case 'shape':
        getter = getTensorShapeParam;
        break;
      case 'shape[]':
        getter = getTensorShapeArrayParam;
        break;
      case 'dtype':
        getter = getDtypeParam;
        break;
      case 'dtype[]':
        getter = getDtypeArrayParam;
        break;
      case 'func':
        getter = getFuncParam;
        break;
      case 'tensor':
      case 'tensors':
        return undefined;
      default:
        throw new Error("Unsupported param type: ".concat(param.type, " for op: ").concat(opName));
    }
    var value = getter(attrs, param.tfName, param.defaultValue);
    if (value === undefined && !!param.tfDeprecatedName) {
      value = getter(attrs, param.tfDeprecatedName, param.defaultValue);
    }
    return value;
  }

  // Loads the op mapping from the JSON file.
  function OperationMapper() {
    var _ref;
    _classCallCheck(this, OperationMapper);
    var ops = [arithmetic, basicMath, control, convolution, creation, dynamic, evaluation, graph, hashTable, image, logical, matrices, normalization, reduction, sliceJoin, sparse, spectral, string, transformation];
    var mappersJson = (_ref = []).concat.apply(_ref, _toConsumableArray(ops.map(function (op) {
      return op.json;
    })));
    this.opMappers = mappersJson.reduce(function (map, mapper) {
      map[mapper.tfOpName] = mapper;
      return map;
    }, {});
  }
  // Converts the model inference graph from Tensorflow GraphDef to local
  // representation for TensorFlow.js API
  _createClass(OperationMapper, [{
    key: "transformGraph",
    value: function transformGraph(graph) {
      var _this = this;
      var signature = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
      var tfNodes = graph.node;
      var placeholders = [];
      var weights = [];
      var initNodes = [];
      var nodes = tfNodes.reduce(function (map, node) {
        map[node.name] = _this.mapNode(node);
        if (node.op.startsWith('Placeholder')) {
          placeholders.push(map[node.name]);
        } else if (node.op === 'Const') {
          weights.push(map[node.name]);
        } else if (node.input == null || node.input.length === 0) {
          // Nodes with no inputs that are neither placeholders nor weights.
          initNodes.push(map[node.name]);
        }
        return map;
      }, {});
      var inputs = [];
      var outputs = [];
      var inputNodeNameToKey = {};
      var outputNodeNameToKey = {};
      if (signature != null) {
        inputNodeNameToKey = this.mapSignatureEntries(signature.inputs);
        outputNodeNameToKey = this.mapSignatureEntries(signature.outputs);
      }
      var allNodes = Object.keys(nodes);
      connectNodeInputs(nodes);
      // if signature has not outputs set, add any node that does not have
      // outputs.
      if (Object.keys(outputNodeNameToKey).length === 0) {
        allNodes.forEach(function (key) {
          var node = nodes[key];
          if (node.children.length === 0) {
            outputs.push(node);
          }
        });
      } else {
        Object.keys(outputNodeNameToKey).forEach(function (name) {
          var _getNodeNameAndIndex3 = getNodeNameAndIndex(name),
            _getNodeNameAndIndex4 = _slicedToArray(_getNodeNameAndIndex3, 1),
            nodeName = _getNodeNameAndIndex4[0];
          var node = nodes[nodeName];
          if (node != null) {
            node.signatureKey = outputNodeNameToKey[name];
            outputs.push(node);
          }
        });
      }
      if (Object.keys(inputNodeNameToKey).length > 0) {
        Object.keys(inputNodeNameToKey).forEach(function (name) {
          var _getNodeNameAndIndex5 = getNodeNameAndIndex(name),
            _getNodeNameAndIndex6 = _slicedToArray(_getNodeNameAndIndex5, 1),
            nodeName = _getNodeNameAndIndex6[0];
          var node = nodes[nodeName];
          if (node) {
            node.signatureKey = inputNodeNameToKey[name];
            inputs.push(node);
          }
        });
      } else {
        inputs = placeholders;
      }
      var functions = {};
      if (graph.library != null && graph.library.function != null) {
        functions = graph.library.function.reduce(function (functions, func) {
          functions[func.signature.name] = _this.mapFunction(func);
          return functions;
        }, {});
      }
      var result = {
        nodes: nodes,
        inputs: inputs,
        outputs: outputs,
        weights: weights,
        placeholders: placeholders,
        signature: signature,
        functions: functions
      };
      if (initNodes.length > 0) {
        result.initNodes = initNodes;
      }
      return result;
    }
  }, {
    key: "mapSignatureEntries",
    // Inverts a signature map: returns `{[entry.name]: signatureKey}`.
    value: function mapSignatureEntries(entries) {
      return Object.keys(entries || {}).reduce(function (prev, curr) {
        prev[entries[curr].name] = curr;
        return prev;
      }, {});
    }
  }, {
    key: "mapNode",
    value: function mapNode(node) {
      // Unsupported ops will cause an error at run-time (not parse time), since
      // they may not be used by the actual execution subgraph.
      var mapper = getRegisteredOp(node.op) || this.opMappers[node.op] || {};
      if (node.attr == null) {
        node.attr = {};
      }
      var newNode = {
        name: node.name,
        op: node.op,
        category: mapper.category,
        // A leading '^' is stripped from input names.
        inputNames: (node.input || []).map(function (input) {
          return input.startsWith('^') ? input.slice(1) : input;
        }),
        inputs: [],
        children: [],
        inputParams: {},
        attrParams: {},
        rawAttrs: node.attr,
        outputs: mapper.outputs
      };
      if (mapper.inputs != null) {
        newNode.inputParams = mapper.inputs.reduce(function (map, param) {
          map[param.name] = {
            type: param.type,
            inputIndexStart: param.start,
            inputIndexEnd: param.end
          };
          return map;
        }, {});
      }
      if (mapper.attrs != null) {
        newNode.attrParams = mapper.attrs.reduce(function (map, param) {
          map[param.name] = {
            value: readAttrValue(node.attr, param, node.op),
            type: param.type
          };
          return map;
        }, {});
      }
      return newNode;
    }
    // map the TFunctionDef to TFJS graph object
  }, {
    key: "mapFunction",
    value: function mapFunction(functionDef) {
      var _this2 = this;
      var tfNodes = functionDef.nodeDef;
      var placeholders = [];
      var weights = [];
      var nodes = {};
      if (tfNodes != null) {
        nodes = tfNodes.reduce(function (map, node) {
          map[node.name] = _this2.mapNode(node);
          if (node.op === 'Const') {
            weights.push(map[node.name]);
          }
          return map;
        }, {});
      }
      var inputs = [];
      var outputs = [];
      // Each function input argument becomes a synthetic Placeholder node.
      functionDef.signature.inputArg.forEach(function (arg) {
        var _getNodeNameAndIndex7 = getNodeNameAndIndex(arg.name),
          _getNodeNameAndIndex8 = _slicedToArray(_getNodeNameAndIndex7, 1),
          nodeName = _getNodeNameAndIndex8[0];
        var node = {
          name: nodeName,
          op: 'Placeholder',
          inputs: [],
          inputNames: [],
          category: 'graph',
          inputParams: {},
          attrParams: {
            dtype: {
              value: parseDtypeParam(arg.type),
              type: 'dtype'
            }
          },
          children: []
        };
        node.signatureKey = arg.name;
        inputs.push(node);
        nodes[nodeName] = node;
      });
      connectNodeInputs(nodes);
      var returnNodeMap = functionDef.ret;
      functionDef.signature.outputArg.forEach(function (output) {
        var _getNodeNameAndIndex11 = getNodeNameAndIndex(returnNodeMap[output.name]),
          _getNodeNameAndIndex12 = _slicedToArray(_getNodeNameAndIndex11, 2),
          nodeName = _getNodeNameAndIndex12[0],
          index = _getNodeNameAndIndex12[1];
        var node = nodes[nodeName];
        if (node != null) {
          node.defaultOutput = index;
          outputs.push(node);
        }
      });
      var signature = this.mapArgsToSignature(functionDef);
      return {
        nodes: nodes,
        inputs: inputs,
        outputs: outputs,
        weights: weights,
        placeholders: placeholders,
        signature: signature
      };
    }
  }, {
    key: "mapArgsToSignature",
    value: function mapArgsToSignature(functionDef) {
      var _this3 = this;
      return {
        methodName: functionDef.signature.name,
        inputs: functionDef.signature.inputArg.reduce(function (map, arg) {
          map[arg.name] = _this3.mapArgToTensorInfo(arg);
          return map;
        }, {}),
        outputs: functionDef.signature.outputArg.reduce(function (map, arg) {
          map[arg.name] = _this3.mapArgToTensorInfo(arg, functionDef.ret);
          return map;
        }, {})
      };
    }
  }, {
    key: "mapArgToTensorInfo",
    // Builds `{name, dtype}` for an arg; `nameMap`, when given, renames it.
    value: function mapArgToTensorInfo(arg, nameMap) {
      var name = arg.name;
      if (nameMap != null) {
        name = nameMap[name];
      }
      return {
        name: name,
        dtype: arg.type
      };
    }
  }], [{
    key: "Instance",
    get:
    // Singleton instance for the mapper
    function get() {
      return this._instance || (this._instance = new this());
    }
  }]);
  return OperationMapper;
}();
/**
 * Decodes a base64 string using whichever decoder the runtime provides.
 *
 * Prefers `atob` on the environment's global object and falls back to Node's
 * `Buffer`.
 *
 * NOTE(review): the two paths differ for non-ASCII payloads. `atob` returns a
 * one-char-per-byte string, while `Buffer#toString()` decodes UTF-8. The
 * existing behavior is preserved; confirm which one callers expect.
 *
 * @param {string} text base64-encoded input.
 * @returns {string} the decoded string.
 * @throws {Error} when neither `atob` nor `Buffer` is available.
 */
function decodeBase64(text) {
  var global = env().global;
  if (typeof global.atob !== 'undefined') {
    return global.atob(text);
  } else if (typeof Buffer !== 'undefined') {
    // `new Buffer(...)` is deprecated (DEP0005); `Buffer.from` is the
    // supported, equivalent API for decoding a string.
    return Buffer.from(text, 'base64').toString();
  } else {
    throw new Error('Unable to decode base64 in this environment. ' + 'Missing built-in atob() or Buffer()');
  }
}
/**
 * Converts a raw string attribute into a JS string.
 *
 * @param {Array<number>|string} s either an array of char codes or a base64
 *     string (decoded with decodeBase64).
 * @param {boolean} keepCase when falsy, the result is lowercased.
 * @returns {string}
 */
function parseStringParam(s, keepCase) {
  var decoded;
  if (Array.isArray(s)) {
    decoded = String.fromCharCode.apply(null, s);
  } else {
    decoded = decodeBase64(s);
  }
  if (keepCase) {
    return decoded;
  }
  return decoded.toLowerCase();
}
/**
 * Reads a string attribute, returning `def` when the attribute is absent.
 * An optional 4th argument `keepCase` (default false) is forwarded to
 * parseStringParam.
 */
function getStringParam(attrs, name, def) {
  var keepCase = arguments[3] === undefined ? false : arguments[3];
  var attr = attrs[name];
  if (attr == null) {
    return def;
  }
  return parseStringParam(attr.s, keepCase);
}
/**
 * Reads the `b` field of a boolean attribute; returns `def` when the
 * attribute itself is missing (falsy).
 */
function getBoolParam(attrs, name, def) {
  var attr = attrs[name];
  if (!attr) {
    return def;
  }
  return attr.b;
}
/**
 * Reads a scalar numeric attribute.
 *
 * The integer field `i` is checked first, then the float field `f`; otherwise
 * `def` is used. Non-number values are parsed as follows:
 * - A string float uses parseFloat. Proto3 JSON spells non-finite floats as
 *   "NaN", "Infinity" and "-Infinity", which parseInt would turn into NaN,
 *   and parseInt would also truncate "0.5" to 0.
 * - Anything else uses parseInt(value, 10), e.g. int64 values, which proto3
 *   JSON encodes as strings.
 *
 * @param {Object} attrs node attribute map.
 * @param {string} name attribute name.
 * @param {*} def value used when the attribute has neither `i` nor `f`.
 * @returns {number}
 */
function getNumberParam(attrs, name, def) {
  var param = attrs[name] || {};
  if (param['i'] == null && typeof param['f'] === 'string') {
    return parseFloat(param['f']);
  }
  var value = param['i'] != null ? param['i'] : param['f'] != null ? param['f'] : def;
  return typeof value === 'number' ? value : parseInt(value, 10);
}
/**
 * Maps a TensorFlow DataType (a numeric enum value or its name, e.g.
 * 'DT_FLOAT') onto a TensorFlow.js dtype string.
 *
 * Half/double map to 'float32', the 8/32/64-bit integer types to 'int32', and
 * both complex types to 'complex64'. Returns null for anything else.
 */
function parseDtypeParam(value) {
  // tslint:disable-next-line:no-any
  var code = typeof value === 'string' ? DataType[value] : value;
  if (code === DataType.DT_FLOAT || code === DataType.DT_HALF) {
    return 'float32';
  }
  if (code === DataType.DT_INT32 || code === DataType.DT_INT64 || code === DataType.DT_INT8 || code === DataType.DT_UINT8) {
    return 'int32';
  }
  if (code === DataType.DT_BOOL) {
    return 'bool';
  }
  if (code === DataType.DT_DOUBLE) {
    return 'float32';
  }
  if (code === DataType.DT_STRING) {
    return 'string';
  }
  if (code === DataType.DT_COMPLEX64 || code === DataType.DT_COMPLEX128) {
    return 'complex64';
  }
  // Unknown dtype error will happen at runtime (instead of parse time),
  // since these nodes might not be used by the actual subgraph execution.
  return null;
}
/**
 * Returns the function name referenced by a `func` attribute, or `def` when
 * the attribute or its `func` field is missing.
 */
function getFuncParam(attrs, name, def) {
  var attr = attrs[name];
  var fn = attr ? attr.func : undefined;
  return fn ? fn.name : def;
}
/**
 * Reads a dtype attribute and converts it with parseDtypeParam. Returns `def`
 * when the attribute or its `type` field is missing or falsy.
 */
function getDtypeParam(attrs, name, def) {
  var attr = attrs[name];
  var type = attr && attr.type;
  return type ? parseDtypeParam(type) : def;
}
/**
 * Reads a list-of-dtypes attribute (`list.type`), converting every entry with
 * parseDtypeParam. Returns `def` when the list is missing.
 */
function getDtypeArrayParam(attrs, name, def) {
  var attr = attrs[name];
  var types = attr && attr.list ? attr.list.type : undefined;
  if (!types) {
    return def;
  }
  return types.map(function (t) {
    return parseDtypeParam(t);
  });
}
/**
 * Converts a TensorShapeProto-like object into an array of dimension sizes.
 *
 * @param {Object} shape object with optional `unknownRank` and `dim` fields.
 * @returns {Array<number>|undefined}
 *   - `undefined` when the rank is unknown;
 *   - `[]` when there is no `dim` list;
 *   - otherwise one size per dimension. String sizes are parsed with parseInt,
 *     and an omitted size is treated as 0.
 */
function parseTensorShapeParam(shape) {
  if (shape.unknownRank) {
    return undefined;
  }
  if (shape.dim != null) {
    return shape.dim.map(function (dim) {
      if (dim.size == null) {
        // Proto3 JSON omits fields that hold their default value, so a
        // zero-sized dimension arrives as `{}`. parseInt(undefined) would
        // yield NaN here.
        return 0;
      }
      return typeof dim.size === 'number' ? dim.size : parseInt(dim.size, 10);
    });
  }
  return [];
}
/**
 * Reads a shape attribute and converts it with parseTensorShapeParam. Returns
 * `def` when the attribute or its `shape` field is missing.
 */
function getTensorShapeParam(attrs, name, def) {
  var attr = attrs[name];
  var shape = attr ? attr.shape : undefined;
  return shape ? parseTensorShapeParam(shape) : def;
}
/**
 * Reads a list-of-numbers attribute.
 *
 * A non-empty `list.f` is used if present, otherwise `list.i`; a list with
 * neither yields `[]`. Returns `def` when the attribute or its `list` field is
 * missing; that case previously threw a TypeError, unlike the sibling array
 * getters. Non-number entries are parsed as follows:
 * - float entries use parseFloat, because proto3 JSON writes non-finite
 *   floats as "NaN"/"Infinity"/"-Infinity";
 * - integer entries use parseInt(v, 10), because int64 values arrive as
 *   strings.
 *
 * @param {Object} attrs node attribute map.
 * @param {string} name attribute name.
 * @param {*} def fallback value.
 * @returns {Array<number>|*}
 */
function getNumericArrayParam(attrs, name, def) {
  var param = attrs[name];
  if (param && param.list) {
    var list = param.list;
    var useFloats = !!(list.f && list.f.length);
    var raw = (useFloats ? list.f : list.i) || [];
    return raw.map(function (v) {
      if (typeof v === 'number') {
        return v;
      }
      return useFloats ? parseFloat(v) : parseInt(v, 10);
    });
  }
  return def;
}
/**
 * Reads a list-of-strings attribute (`list.s`), decoding each entry with
 * parseStringParam. An optional 4th argument `keepCase` (default false) is
 * forwarded. Returns `def` when the list is missing.
 */
function getStringArrayParam(attrs, name, def) {
  var keepCase = arguments[3] === undefined ? false : arguments[3];
  var attr = attrs[name];
  var strings = attr && attr.list ? attr.list.s : undefined;
  if (!strings) {
    return def;
  }
  return strings.map(function (entry) {
    return parseStringParam(entry, keepCase);
  });
}
/**
 * Reads a list-of-shapes attribute (`list.shape`), converting each entry with
 * parseTensorShapeParam. Returns `def` when the list is missing.
 */
function getTensorShapeArrayParam(attrs, name, def) {
  var attr = attrs[name];
  var shapes = attr && attr.list ? attr.list.shape : undefined;
  if (!shapes) {
    return def;
  }
  return shapes.map(function (shape) {
    return parseTensorShapeParam(shape);
  });
}
/**
 * Returns the `list.b` array of a list-of-bools attribute as-is (same
 * reference), or `def` when the list is missing.
 */
function getBoolArrayParam(attrs, name, def) {
  var attr = attrs[name];
  var flags = attr && attr.list ? attr.list.b : undefined;
  return flags ? flags : def;
}
87174
/**
 * Helper class for looking up the inputs and attribute values of a node in
 * the model graph.
 *
 * The constructor resolves and stores two things:
 * - every entry of `node.inputNames`, resolved via getTensor into `inputs`;
 * - every key of `node.rawAttrs`, decoded into `attrs`.
 */
var NodeValueImpl = /*#__PURE__*/function () {
  function NodeValueImpl(node, tensorMap, context) {
    var _this = this;
    _classCallCheck(this, NodeValueImpl);
    this.node = node;
    this.tensorMap = tensorMap;
    this.context = context;
    this.inputs = node.inputNames.map(function (name) {
      return _this.getInput(name);
    });
    this.attrs = {};
    if (node.rawAttrs != null) {
      this.attrs = Object.keys(node.rawAttrs).reduce(function (attrs, key) {
        attrs[key] = _this.getAttr(key);
        return attrs;
      }, {});
    }
  }
  _createClass(NodeValueImpl, [{
    key: "getInput",
    /**
     * Return the tensor for an input of this node.
     * @param name String: input name, as listed in `node.inputNames`.
     */
    value: function getInput(name) {
      return getTensor(name, this.tensorMap, this.context);
    }
    /**
     * Return the decoded value of a raw attribute.
     * @param name String: attribute name (a key of `node.rawAttrs`).
     * @param defaultValue value returned when the attribute is absent or its
     *     kind is not recognized.
     */
  }, {
    key: "getAttr",
    value: function getAttr(name, defaultValue) {
      var rawAttrs = this.node.rawAttrs;
      var value = rawAttrs != null ? rawAttrs[name] : undefined;
      if (value == null) {
        // Unknown attribute: honor `defaultValue` instead of throwing a
        // TypeError on `value.tensor` below.
        return defaultValue;
      }
      if (value.tensor != null) {
        // NOTE(review): this looks up a tensor by the *attribute* name — confirm
        // this is intended.
        return getTensor(name, this.tensorMap, this.context);
      }
      if (value.i != null || value.f != null) {
        return getNumberParam(rawAttrs, name, defaultValue);
      }
      if (value.s != null) {
        return getStringParam(rawAttrs, name, defaultValue);
      }
      if (value.b != null) {
        return getBoolParam(rawAttrs, name, defaultValue);
      }
      if (value.shape != null) {
        return getTensorShapeParam(rawAttrs, name, defaultValue);
      }
      if (value.type != null) {
        return getDtypeParam(rawAttrs, name, defaultValue);
      }
      if (value.list != null) {
        if (value.list.i != null || value.list.f != null) {
          return getNumericArrayParam(rawAttrs, name, defaultValue);
        }
        if (value.list.s != null) {
          return getStringArrayParam(rawAttrs, name, defaultValue);
        }
        if (value.list.shape != null) {
          return getTensorShapeArrayParam(rawAttrs, name, defaultValue);
        }
        if (value.list.b != null) {
          return getBoolArrayParam(rawAttrs, name, defaultValue);
        }
        if (value.list.type != null) {
          return getDtypeArrayParam(rawAttrs, name, defaultValue);
        }
      }
      return defaultValue;
    }
  }]);
  return NodeValueImpl;
}();
87254
87255 /**
87256 * @license
87257 * Copyright 2020 Google LLC. All Rights Reserved.
87258 * Licensed under the Apache License, Version 2.0 (the "License");
87259 * you may not use this file except in compliance with the License.
87260 * You may obtain a copy of the License at
87261 *
87262 * http://www.apache.org/licenses/LICENSE-2.0
87263 *
87264 * Unless required by applicable law or agreed to in writing, software
87265 * distributed under the License is distributed on an "AS IS" BASIS,
87266 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87267 * See the License for the specific language governing permissions and
87268 * limitations under the License.
87269 * =============================================================================
87270 */
87271
/**
 * Namespace object collecting tf op functions under their public names.
 *
 * The right-hand sides are the bundle's local bindings. Many carry a `$N`
 * suffix, presumably added by the bundler to avoid name clashes.
 * `__proto__: null` leaves the object without a prototype, like an ES module
 * namespace object.
 *
 * NOTE(review): presumably consumed by the graph-model op executors as the op
 * implementation table — confirm at the call sites.
 */
var tfOps = {
  __proto__: null,
  OP_SCOPE_SUFFIX: OP_SCOPE_SUFFIX,
  abs: abs$2,
  acos: acos$2,
  acosh: acosh$2,
  add: add$3,
  addN: addN$2,
  all: all$2,
  any: any$2,
  argMax: argMax$2,
  argMin: argMin$2,
  asin: asin$2,
  asinh: asinh$2,
  atan: atan$2,
  atan2: atan2$2,
  atanh: atanh$2,
  avgPool: avgPool$2,
  avgPool3d: avgPool3d$1,
  basicLSTMCell: basicLSTMCell,
  batchNorm: batchNorm$2,
  batchNorm2d: batchNorm2d,
  batchNorm3d: batchNorm3d,
  batchNorm4d: batchNorm4d,
  batchToSpaceND: batchToSpaceND$2,
  bincount: bincount$2,
  bitwiseAnd: bitwiseAnd$2,
  booleanMaskAsync: booleanMaskAsync,
  broadcastArgs: broadcastArgs$2,
  broadcastTo: broadcastTo,
  buffer: buffer,
  cast: cast$3,
  ceil: ceil$2,
  clipByValue: clipByValue$2,
  clone: clone,
  complex: complex$2,
  concat: concat$2,
  concat1d: concat1d,
  concat2d: concat2d,
  concat3d: concat3d,
  concat4d: concat4d,
  conv1d: conv1d$2,
  conv2d: conv2d$4,
  conv2dTranspose: conv2dTranspose$1,
  conv3d: conv3d$2,
  conv3dTranspose: conv3dTranspose$1,
  cos: cos$2,
  cosh: cosh$2,
  cosineWindow: cosineWindow,
  cumprod: cumprod$2,
  cumsum: cumsum$2,
  denseBincount: denseBincount$2,
  depthToSpace: depthToSpace$2,
  depthwiseConv2d: depthwiseConv2d$3,
  diag: diag$2,
  dilation2d: dilation2d,
  div: div$1,
  divNoNan: divNoNan,
  dot: dot$2,
  dropout: dropout$2,
  einsum: einsum$2,
  elu: elu$4,
  enclosingPowerOfTwo: enclosingPowerOfTwo,
  ensureShape: ensureShape,
  equal: equal$2,
  erf: erf$2,
  euclideanNorm: euclideanNorm,
  exp: exp$2,
  expandDims: expandDims$3,
  expm1: expm1$2,
  eye: eye,
  fft: fft$2,
  fill: fill$2,
  floor: floor$2,
  floorDiv: floorDiv$2,
  fused: fused_ops,
  gather: gather$1,
  gatherND: gatherND,
  greater: greater$3,
  greaterEqual: greaterEqual$2,
  ifft: ifft$2,
  imag: imag$2,
  image: image$1,
  inTopKAsync: inTopKAsync,
  irfft: irfft,
  isFinite: isFinite$3,
  isInf: isInf$2,
  isNaN: isNaN$3,
  leakyRelu: leakyRelu$2,
  less: less$3,
  lessEqual: lessEqual$2,
  linalg: linalg,
  linspace: linspace,
  localResponseNormalization: localResponseNormalization,
  log: log$2,
  log1p: log1p$2,
  logSigmoid: logSigmoid,
  logSoftmax: logSoftmax,
  logSumExp: logSumExp,
  logicalAnd: logicalAnd$2,
  logicalNot: logicalNot$2,
  logicalOr: logicalOr$2,
  logicalXor: logicalXor,
  losses: losses,
  lowerBound: lowerBound$1,
  matMul: matMul$1,
  max: max$3,
  maxPool: maxPool$2,
  maxPool3d: maxPool3d$1,
  maxPoolWithArgmax: maxPoolWithArgmax,
  maximum: maximum$4,
  mean: mean$3,
  meshgrid: meshgrid,
  min: min$3,
  minimum: minimum$4,
  mirrorPad: mirrorPad$1,
  mod: mod$2,
  moments: moments,
  movingAverage: movingAverage,
  mul: mul,
  multiRNNCell: multiRNNCell,
  multinomial: multinomial$2,
  neg: neg$2,
  norm: norm,
  notEqual: notEqual$2,
  oneHot: oneHot$3,
  ones: ones$1,
  onesLike: onesLike$3,
  op: op,
  outerProduct: outerProduct,
  pad: pad,
  pad1d: pad1d,
  pad2d: pad2d,
  pad3d: pad3d,
  pad4d: pad4d,
  pool: pool$1,
  pow: pow$3,
  prelu: prelu$3,
  print: print,
  prod: prod$2,
  raggedGather: raggedGather$2,
  raggedRange: raggedRange$2,
  raggedTensorToTensor: raggedTensorToTensor$2,
  rand: rand,
  randomGamma: randomGamma,
  randomNormal: randomNormal$2,
  randomStandardNormal: randomStandardNormal,
  randomUniform: randomUniform$1,
  randomUniformInt: randomUniformInt,
  range: range$3,
  real: real$2,
  reciprocal: reciprocal$2,
  relu: relu$2,
  relu6: relu6$2,
  reshape: reshape$3,
  reverse: reverse$2,
  reverse1d: reverse1d,
  reverse2d: reverse2d,
  reverse3d: reverse3d,
  reverse4d: reverse4d,
  rfft: rfft,
  round: round$2,
  rsqrt: rsqrt$2,
  scalar: scalar,
  scatterND: scatterND,
  searchSorted: searchSorted$2,
  selu: selu$2,
  separableConv2d: separableConv2d$1,
  setdiff1dAsync: setdiff1dAsync,
  sigmoid: sigmoid$2,
  sign: sign$3,
  signal: signal,
  sin: sin$2,
  sinh: sinh$2,
  slice: slice$2,
  slice1d: slice1d,
  slice2d: slice2d,
  slice3d: slice3d,
  slice4d: slice4d,
  softmax: softmax$3,
  softplus: softplus$2,
  spaceToBatchND: spaceToBatchND$2,
  sparse: sparse$1,
  sparseToDense: sparseToDense$2,
  spectral: spectral$1,
  split: split$3,
  sqrt: sqrt$2,
  square: square$2,
  squaredDifference: squaredDifference$2,
  squeeze: squeeze,
  stack: stack,
  step: step$2,
  stridedSlice: stridedSlice$2,
  string: string$1,
  sub: sub$2,
  sum: sum$3,
  tan: tan$2,
  tanh: tanh$2,
  tensor: tensor,
  tensor1d: tensor1d,
  tensor2d: tensor2d,
  tensor3d: tensor3d,
  tensor4d: tensor4d,
  tensor5d: tensor5d,
  tensor6d: tensor6d,
  tensorScatterUpdate: tensorScatterUpdate$2,
  tile: tile$3,
  topk: topk,
  transpose: transpose$2,
  truncatedNormal: truncatedNormal$1,
  unique: unique$3,
  unsortedSegmentSum: unsortedSegmentSum$2,
  unstack: unstack,
  upperBound: upperBound$1,
  variable: variable$1,
  where: where,
  whereAsync: whereAsync,
  zeros: zeros$2,
  zerosLike: zerosLike$3
};
87492
87493 /**
87494 * @license
87495 * Copyright 2018 Google LLC. All Rights Reserved.
87496 * Licensed under the Apache License, Version 2.0 (the "License");
87497 * you may not use this file except in compliance with the License.
87498 * You may obtain a copy of the License at
87499 *
87500 * http://www.apache.org/licenses/LICENSE-2.0
87501 *
87502 * Unless required by applicable law or agreed to in writing, software
87503 * distributed under the License is distributed on an "AS IS" BASIS,
87504 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87505 * See the License for the specific language governing permissions and
87506 * limitations under the License.
87507 * =============================================================================
87508 */
/**
 * Executes an arithmetic graph node.
 *
 * Binary ops read the node params 'a' and 'b' (in that order) and call the
 * matching method on `ops`; 'AddN' reads the single 'tensors' param instead.
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Forwarded to getParamValue.
 * @param context Forwarded to getParamValue.
 * @param ops Optional 4th argument: the op namespace to call into. Defaults
 *     to `tfOps` when omitted or undefined.
 * @returns A one-element array holding the op's result.
 * @throws TypeError when `node.op` is not one of the ops handled here.
 */
var executeOp$k = function () {
  var hasOwn = Object.prototype.hasOwnProperty;
  // Graph op name -> method on `ops` for ops combining params 'a' and 'b'.
  var binaryMethodByOp = {
    BiasAdd: 'add',
    AddV2: 'add',
    Add: 'add',
    FloorMod: 'mod',
    Mod: 'mod',
    Mul: 'mul',
    RealDiv: 'div',
    Div: 'div',
    DivNoNan: 'divNoNan',
    FloorDiv: 'floorDiv',
    Sub: 'sub',
    Minimum: 'minimum',
    Maximum: 'maximum',
    Pow: 'pow',
    SquaredDifference: 'squaredDifference'
  };
  return function executeOp(node, tensorMap, context) {
    var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
    var opName = node.op;
    if (opName === 'AddN') {
      return [ops.addN(getParamValue('tensors', node, tensorMap, context))];
    }
    // Own-property check so names like 'toString' are not treated as ops.
    var isBinary = typeof opName === 'string' && hasOwn.call(binaryMethodByOp, opName);
    if (!isBinary) {
      throw new TypeError("Node type ".concat(opName, " is not implemented"));
    }
    var lhs = getParamValue('a', node, tensorMap, context);
    var rhs = getParamValue('b', node, tensorMap, context);
    return [ops[binaryMethodByOp[opName]](lhs, rhs)];
  };
}();
var CATEGORY$j = 'arithmetic';
87565
87566 /**
87567 * @license
87568 * Copyright 2018 Google LLC. All Rights Reserved.
87569 * Licensed under the Apache License, Version 2.0 (the "License");
87570 * you may not use this file except in compliance with the License.
87571 * You may obtain a copy of the License at
87572 *
87573 * http://www.apache.org/licenses/LICENSE-2.0
87574 *
87575 * Unless required by applicable law or agreed to in writing, software
87576 * distributed under the License is distributed on an "AS IS" BASIS,
87577 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87578 * See the License for the specific language governing permissions and
87579 * limitations under the License.
87580 * =============================================================================
87581 */
/**
 * Executes a basic-math graph node.
 *
 * Most ops take the single param 'x'. A few take several named params, which
 * are read in the listed order. Rsqrt/IsNan/IsInf/IsFinite instead read the
 * node's first raw input via getTensor(node.inputNames[0], ...).
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Forwarded to getParamValue / getTensor.
 * @param context Forwarded to getParamValue / getTensor.
 * @param ops Optional 4th argument: the op namespace to call into. Defaults
 *     to `tfOps` when omitted or undefined.
 * @returns A one-element array holding the op's result.
 * @throws TypeError when `node.op` is not one of the ops handled here.
 */
var executeOp$j = function () {
  var hasOwn = Object.prototype.hasOwnProperty;
  // Graph op name -> method on `ops`, for ops whose only input is param 'x'.
  var unaryOnX = {
    Abs: 'abs',
    ComplexAbs: 'abs',
    Acos: 'acos',
    Acosh: 'acosh',
    Asin: 'asin',
    Asinh: 'asinh',
    Atan: 'atan',
    Atanh: 'atanh',
    Ceil: 'ceil',
    Cos: 'cos',
    Cosh: 'cosh',
    Elu: 'elu',
    Erf: 'erf',
    Exp: 'exp',
    Expm1: 'expm1',
    Floor: 'floor',
    Log: 'log',
    Log1p: 'log1p',
    Imag: 'imag',
    Neg: 'neg',
    Reciprocal: 'reciprocal',
    Real: 'real',
    Relu: 'relu',
    Round: 'round',
    Selu: 'selu',
    Sigmoid: 'sigmoid',
    Sin: 'sin',
    Sign: 'sign',
    Sinh: 'sinh',
    Softplus: 'softplus',
    Sqrt: 'sqrt',
    Square: 'square',
    Tanh: 'tanh',
    Tan: 'tan',
    Relu6: 'relu6'
  };
  // Graph op name -> [method, param names in call order].
  var multiParam = {
    Atan2: ['atan2', ['x', 'y']],
    Complex: ['complex', ['real', 'imag']],
    ClipByValue: ['clipByValue', ['x', 'clipValueMin', 'clipValueMax']],
    LeakyRelu: ['leakyRelu', ['x', 'alpha']],
    Prelu: ['prelu', ['x', 'alpha']]
  };
  // Graph op name -> method, for ops applied to the node's first raw input.
  var onFirstInput = {
    Rsqrt: 'rsqrt',
    IsNan: 'isNaN',
    IsInf: 'isInf',
    IsFinite: 'isFinite'
  };
  return function executeOp(node, tensorMap, context) {
    var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
    var opName = node.op;
    // Own-property lookups so inherited names like 'toString' never match.
    var listedIn = function (table) {
      return typeof opName === 'string' && hasOwn.call(table, opName);
    };
    if (listedIn(unaryOnX)) {
      return [ops[unaryOnX[opName]](getParamValue('x', node, tensorMap, context))];
    }
    if (listedIn(onFirstInput)) {
      return [ops[onFirstInput[opName]](getTensor(node.inputNames[0], tensorMap, context))];
    }
    if (listedIn(multiParam)) {
      var spec = multiParam[opName];
      var args = spec[1].map(function (paramName) {
        return getParamValue(paramName, node, tensorMap, context);
      });
      return [ops[spec[0]].apply(ops, args)];
    }
    throw new TypeError("Node type ".concat(opName, " is not implemented"));
  };
}();
var CATEGORY$i = 'basic_math';
87697
87698 /**
87699 * @license
87700 * Copyright 2020 Google LLC. All Rights Reserved.
87701 * Licensed under the Apache License, Version 2.0 (the "License");
87702 * you may not use this file except in compliance with the License.
87703 * You may obtain a copy of the License at
87704 *
87705 * http://www.apache.org/licenses/LICENSE-2.0
87706 *
87707 * Unless required by applicable law or agreed to in writing, software
87708 * distributed under the License is distributed on an "AS IS" BASIS,
87709 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87710 * See the License for the specific language governing permissions and
87711 * limitations under the License.
87712 * =============================================================================
87713 */
/**
 * Used by TensorList and TensorArray to check that two element shapes are
 * compatible. A negative dimension on either side matches any size, and a
 * shape given as a plain number (unknown rank) matches any shape.
 * @param shapeA A number (unknown rank) or an array of dimensions.
 * @param shapeB A number (unknown rank) or an array of dimensions.
 * @param errorMessagePrefix Optional text prepended to the mismatch message;
 *     defaults to ''.
 * A rank mismatch, or a pair of dims that are neither negative nor equal, is
 * reported through `assert$1` with the message
 * `<prefix> Shapes <shapeA> and <shapeB> must match`.
 */
function assertShapesMatchAllowUndefinedSize(shapeA, shapeB) {
  var errorMessagePrefix = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : '';
  var rankUnknown = typeof shapeA === 'number' || typeof shapeB === 'number';
  if (rankUnknown) {
    return;
  }
  var describeMismatch = function () {
    return errorMessagePrefix + " Shapes ".concat(shapeA, " and ").concat(shapeB, " must match");
  };
  assert$1(shapeA.length === shapeB.length, describeMismatch);
  var axis = 0;
  while (axis < shapeA.length) {
    var left = shapeA[axis];
    var right = shapeB[axis];
    assert$1(left < 0 || right < 0 || left === right, describeMismatch);
    axis++;
  }
}
/**
 * Reports whether an element shape is fully known.
 * @param elementShape A number (unknown rank) or an array of dimensions.
 * @returns false for a numeric shape or when any dimension is negative;
 *     true otherwise (including the empty shape `[]`).
 */
function fullDefinedShape(elementShape) {
  if (typeof elementShape === 'number') {
    return false;
  }
  return elementShape.every(function (dimSize) {
    return !(dimSize < 0);
  });
}
/**
 * Generate the output element shape from the list's declared elementShape,
 * the tensors stored in the list, and the op's elementShape param.
 *
 * The two declared shapes are merged first. If the result still has unknown
 * dims (or unknown rank), it is refined with the shape of every stored tensor.
 * @param listElementShape Element shape declared on the list.
 * @param tensors Stored tensors; unset slots (array holes) are skipped.
 * @param elementShape Element shape requested by the op.
 * @returns A shape with no negative dimensions.
 * @throws Error when the declared shapes are not fully defined and the list
 *     is empty, or when the shape is still not fully defined after merging
 *     in the stored tensors' shapes.
 */
function inferElementShape(listElementShape, tensors, elementShape) {
  var shape = mergeElementShape(listElementShape, elementShape);
  if (fullDefinedShape(shape)) {
    return shape;
  }
  if (tensors.length === 0) {
    throw new Error("Tried to calculate elements of an empty list" + " with non-fully-defined elementShape: ".concat(shape));
  }
  // reduce (like forEach) visits only present slots, so holes are ignored.
  var refined = tensors.reduce(function (acc, storedTensor) {
    return mergeElementShape(storedTensor.shape, acc);
  }, shape);
  if (!fullDefinedShape(refined)) {
    throw new Error("Non-fully-defined elementShape: ".concat(refined));
  }
  return refined;
}
/**
 * Merges two element shapes dimension by dimension.
 *
 * If either shape is a number (unknown rank), the other shape is returned
 * as-is (same reference). Otherwise the ranks must match; each output dim is
 * taken from `elementShapeA` when it is non-negative, else from
 * `elementShapeB`.
 * @throws Error on a rank mismatch, or when both dims at some position are
 *     non-negative but different.
 */
function mergeElementShape(elementShapeA, elementShapeB) {
  if (typeof elementShapeA === 'number') {
    return elementShapeB;
  }
  if (typeof elementShapeB === 'number') {
    return elementShapeA;
  }
  if (elementShapeA.length !== elementShapeB.length) {
    throw new Error("Incompatible ranks during merge: ".concat(elementShapeA, " vs. ").concat(elementShapeB));
  }
  var merged = [];
  for (var axis = 0; axis < elementShapeA.length; axis++) {
    var fromA = elementShapeA[axis];
    var fromB = elementShapeB[axis];
    if (fromA >= 0) {
      if (fromB >= 0 && fromA !== fromB) {
        throw new Error("Incompatible shape during merge: ".concat(elementShapeA, " vs. ").concat(elementShapeB));
      }
      merged.push(fromA);
    } else {
      merged.push(fromB);
    }
  }
  return merged;
}
87790
/**
 * The TensorArray object keeps an array of Tensors. It
 * allows reading from the array and writing to the array.
 *
 * Each stored slot is a state record `{tensor, written, read, cleared}`.
 * Writing beyond the current end leaves the skipped slots as holes in
 * `this.tensors` until they are written.
 */
var TensorArray = /*#__PURE__*/function () {
  /**
   * @param name Name used in error messages.
   * @param dtype Data type that every written tensor must have.
   * @param maxSize Upper bound on indices when `dynamicSize` is false.
   * @param elementShape Expected element shape. When null or `[]`, it is
   *     taken from the first tensor written to an empty array.
   * @param identicalElementShapes Stored on the instance; not read by this
   *     class's methods.
   * @param dynamicSize When true, writes are not bounded by `maxSize`.
   * @param clearAfterRead When true, each slot can be read only once.
   */
  function TensorArray(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
    _classCallCheck(this, TensorArray);
    this.name = name;
    this.dtype = dtype;
    this.maxSize = maxSize;
    this.elementShape = elementShape;
    this.identicalElementShapes = identicalElementShapes;
    this.dynamicSize = dynamicSize;
    this.clearAfterRead = clearAfterRead;
    this.tensors = [];
    this.closed_ = false;
    // The id of this kept scalar is exposed as the TensorArray's `id`.
    this.idTensor = scalar(0);
    keep(this.idTensor);
  }
  _createClass(TensorArray, [{
    key: "id",
    get: function get() {
      return this.idTensor.id;
    }
  }, {
    key: "closed",
    get: function get() {
      return this.closed_;
    }
    /**
     * Dispose the stored tensors and the idTensor, and mark the TensorArray
     * as closed.
     * @param keepIds Optional collection with a `has(id)` method; stored
     *     tensors whose id it contains are not disposed.
     */
  }, {
    key: "clearAndClose",
    value: function clearAndClose(keepIds) {
      this.tensors.forEach(function (tensor) {
        if (keepIds == null || !keepIds.has(tensor.tensor.id)) {
          tensor.tensor.dispose();
        }
      });
      this.tensors = [];
      this.closed_ = true;
      this.idTensor.dispose();
    }
    /**
     * Number of slots: one past the highest written index, holes included.
     */
  }, {
    key: "size",
    value: function size() {
      return this.tensors.length;
    }
    /**
     * Read the value at location index in the TensorArray.
     * Marks the slot as read (which forbids later writes to it) and, when
     * `clearAfterRead` is set, as cleared (which forbids later reads).
     * @param index Number the index to read from.
     * @throws Error if the array is closed, the index is out of range, the
     *     slot has never been written, or the slot was already cleared.
     */
  }, {
    key: "read",
    value: function read(index) {
      if (this.closed_) {
        throw new Error("TensorArray ".concat(this.name, " has already been closed."));
      }
      if (index < 0 || index >= this.size()) {
        throw new Error("Tried to read from index ".concat(index, ", but array size is: ").concat(this.size()));
      }
      var tensorWithState = this.tensors[index];
      // Writing past the end leaves holes below the new length; without this
      // check, reading one would dereference `undefined` below.
      if (tensorWithState == null) {
        throw new Error("TensorArray ".concat(this.name, ": Could not read from TensorArray index ").concat(index, " because it has not yet been written to."));
      }
      if (tensorWithState.cleared) {
        throw new Error("TensorArray ".concat(this.name, ": Could not read index ").concat(index, " twice because it was cleared after a previous read ") + "(perhaps try setting clear_after_read = false?).");
      }
      if (this.clearAfterRead) {
        tensorWithState.cleared = true;
      }
      tensorWithState.read = true;
      return tensorWithState.tensor;
    }
    /**
     * Helper method to read multiple tensors from the specified indices.
     */
  }, {
    key: "readMany",
    value: function readMany(indices) {
      var _this = this;
      return indices.map(function (index) {
        return _this.read(index);
      });
    }
    /**
     * Write value into the index of the TensorArray. The tensor is kept.
     * @param index number the index to write to.
     * @param tensor
     * @throws Error if the array is closed, the index is negative or (for a
     *     fixed-size array) >= maxSize, the dtype or shape does not match,
     *     or the slot was already read or written.
     */
  }, {
    key: "write",
    value: function write(index, tensor) {
      if (this.closed_) {
        throw new Error("TensorArray ".concat(this.name, " has already been closed."));
      }
      if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
        throw new Error("Tried to write to index ".concat(index, ", but array is not resizeable and size is: ").concat(this.maxSize));
      }
      var t = this.tensors[index] || {};
      if (tensor.dtype !== this.dtype) {
        throw new Error("TensorArray ".concat(this.name, ": Could not write to TensorArray index ").concat(index, ",\n because the value dtype is ").concat(tensor.dtype, ", but TensorArray dtype is ").concat(this.dtype, "."));
      }
      // Set the shape on the first write to a tensor array of unknown shape.
      if (this.size() === 0 && (this.elementShape == null || this.elementShape.length === 0)) {
        this.elementShape = tensor.shape;
      }
      assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, "TensorArray ".concat(this.name, ": Could not write to TensorArray index ").concat(index, "."));
      if (t.read) {
        throw new Error("TensorArray ".concat(this.name, ": Could not write to TensorArray index ").concat(index, ", because it has already been read."));
      }
      if (t.written) {
        throw new Error("TensorArray ".concat(this.name, ": Could not write to TensorArray index ").concat(index, ", because it has already been written."));
      }
      t.tensor = tensor;
      keep(tensor);
      t.written = true;
      this.tensors[index] = t;
    }
    /**
     * Helper method to write multiple tensors to the specified indices.
     */
  }, {
    key: "writeMany",
    value: function writeMany(indices, tensors) {
      var _this2 = this;
      if (indices.length !== tensors.length) {
        throw new Error("TensorArray ".concat(this.name, ": could not write multiple tensors,") + "because the index size: ".concat(indices.length, " is not the same as tensors size: ").concat(tensors.length, "."));
      }
      indices.forEach(function (i, index) {
        return _this2.write(i, tensors[index]);
      });
    }
    /**
     * Return selected values in the TensorArray as a packed Tensor. All of
     * selected values must have been written and their shapes must all match.
     * @param [indices] number[] Optional. Taking values in [0, max_value). If the
     *    TensorArray is not dynamic, max_value=size(). If not specified returns
     *    all tensors in the original order. Only the first size() entries of
     *    the list are used.
     * @param [dtype]
     */
  }, {
    key: "gather",
    value: function gather(indices, dtype) {
      if (!!dtype && dtype !== this.dtype) {
        throw new Error("TensorArray dtype is ".concat(this.dtype, " but gather requested dtype ").concat(dtype));
      }
      if (!indices) {
        indices = [];
        for (var i = 0; i < this.size(); i++) {
          indices.push(i);
        }
      } else {
        indices = indices.slice(0, this.size());
      }
      if (indices.length === 0) {
        return tensor([], [0].concat(this.elementShape));
      }
      // Read all the PersistentTensors into a vector to keep track of
      // their memory.
      var tensors = this.readMany(indices);
      assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');
      return stack(tensors, 0);
    }
    /**
     * Return the values in the TensorArray as a concatenated Tensor.
     */
  }, {
    key: "concat",
    value: function concat(dtype) {
      if (!!dtype && dtype !== this.dtype) {
        throw new Error("TensorArray dtype is ".concat(this.dtype, " but concat requested dtype ").concat(dtype));
      }
      if (this.size() === 0) {
        return tensor([], [0].concat(this.elementShape));
      }
      var indices = [];
      for (var i = 0; i < this.size(); i++) {
        indices.push(i);
      }
      // Collect all the tensors from the tensors array.
      var tensors = this.readMany(indices);
      assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, "TensorArray shape mismatch: tensor array shape (".concat(this.elementShape, ") vs first tensor shape (").concat(tensors[0].shape, ")"));
      return concat$2(tensors, 0);
    }
    /**
     * Scatter the values of a Tensor in specific indices of a TensorArray.
     * @param indices number[] values in [0, max_value). If the
     *    TensorArray is not dynamic, max_value=size().
     * @param tensor Tensor input tensor.
     */
  }, {
    key: "scatter",
    value: function scatter(indices, tensor) {
      if (tensor.dtype !== this.dtype) {
        throw new Error("TensorArray dtype is ".concat(this.dtype, " but tensor has dtype ").concat(tensor.dtype));
      }
      if (indices.length !== tensor.shape[0]) {
        throw new Error("Expected len(indices) == tensor.shape[0], but saw: ".concat(indices.length, " vs. ").concat(tensor.shape[0]));
      }
      // Fold instead of spreading `indices` into Math.max's argument list,
      // which can exceed the engine's argument-count limit for large arrays.
      var maxIndex = indices.reduce(function (largest, idx) {
        return Math.max(largest, idx);
      }, -Infinity);
      if (!this.dynamicSize && maxIndex >= this.maxSize) {
        throw new Error("Max index must be < array size (".concat(maxIndex, " vs. ").concat(this.maxSize, ")"));
      }
      this.writeMany(indices, unstack(tensor, 0));
    }
    /**
     * Split the values of a Tensor into the TensorArray.
     * @param length number[] with the lengths to use when splitting value along
     *    its first dimension.
     * @param tensor Tensor, the tensor to split.
     */
  }, {
    key: "split",
    value: function split(length, tensor) {
      var _this3 = this;
      if (tensor.dtype !== this.dtype) {
        throw new Error("TensorArray dtype is ".concat(this.dtype, " but tensor has dtype ").concat(tensor.dtype));
      }
      var totalLength = 0;
      var cumulativeLengths = length.map(function (len) {
        totalLength += len;
        return totalLength;
      });
      if (totalLength !== tensor.shape[0]) {
        throw new Error("Expected sum of lengths to be equal to\n tensor.shape[0], but sum of lengths is\n ".concat(totalLength, ", and tensor's shape is: ").concat(tensor.shape));
      }
      if (!this.dynamicSize && length.length !== this.maxSize) {
        throw new Error("TensorArray's size is not equal to the size of lengths (".concat(this.maxSize, " vs. ").concat(length.length, "), ") + 'and the TensorArray is not marked as dynamically resizeable');
      }
      var elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
      var tensors = [];
      tidy(function () {
        tensor = reshape$3(tensor, [1, totalLength, elementPerRow]);
        for (var i = 0; i < length.length; ++i) {
          var previousLength = i === 0 ? 0 : cumulativeLengths[i - 1];
          var _indices = [0, previousLength, 0];
          var sizes = [1, length[i], elementPerRow];
          tensors[i] = reshape$3(slice$2(tensor, _indices, sizes), _this3.elementShape);
        }
        return tensors;
      });
      var indices = [];
      for (var i = 0; i < length.length; i++) {
        indices[i] = i;
      }
      this.writeMany(indices, tensors);
    }
  }]);
  return TensorArray;
}();
88040
88041 /**
88042 * TensorList stores a container of `tf.Tensor` objects, which are accessible
88043 * via tensors field.
88044 *
88045 * In order to get a copy of the underlying list, use the copy method:
88046 * ```
88047 * TensorList b = a.copy();
88048 * b.tensors().pushBack(t); // This does not modify a.tensors().
88049 * ```
88050 *
88051 * Note that this is not a deep copy: the memory locations of the underlying
88052 * tensors will still point to the same locations of the corresponding tensors
88053 * in the original.
88054 */
88055 var TensorList = /*#__PURE__*/function () {
88056 /**
88057 *
88058 * @param tensors list of tensors
88059 * @param elementShape shape of each tensor, this can be a single number (any
88060 * shape is allowed) or partial shape (dim = -1).
88061 * @param elementDtype data type of each tensor
88062 * @param maxNumElements The maximum allowed size of `tensors`. Defaults to -1
88063 * meaning that the size of `tensors` is unbounded.
88064 */
88065 function TensorList(tensors, elementShape, elementDtype) {
88066 var maxNumElements = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : -1;
88067 _classCallCheck(this, TensorList);
88068 this.tensors = tensors;
88069 this.elementShape = elementShape;
88070 this.elementDtype = elementDtype;
88071 if (tensors != null) {
88072 tensors.forEach(function (tensor) {
88073 if (elementDtype !== tensor.dtype) {
88074 throw new Error("Invalid data types; op elements ".concat(elementDtype, ", but list elements ").concat(tensor.dtype));
88075 }
88076 assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, 'TensorList shape mismatch: ');
88077 keep(tensor);
88078 });
88079 }
88080 this.idTensor = scalar(0);
88081 this.maxNumElements = maxNumElements;
88082 keep(this.idTensor);
88083 }
88084 /**
88085 * Get a new TensorList containing a copy of the underlying tensor container.
88086 */
88087 _createClass(TensorList, [{
88088 key: "id",
88089 get: function get() {
88090 return this.idTensor.id;
88091 }
88092 }, {
88093 key: "copy",
88094 value: function copy() {
88095 return new TensorList(_toConsumableArray(this.tensors), this.elementShape, this.elementDtype);
88096 }
88097 /**
88098 * Dispose the tensors and idTensor and clear the tensor list.
88099 */
88100 }, {
88101 key: "clearAndClose",
88102 value: function clearAndClose(keepIds) {
88103 this.tensors.forEach(function (tensor) {
88104 if (keepIds == null || !keepIds.has(tensor.id)) {
88105 tensor.dispose();
88106 }
88107 });
88108 this.tensors.length = 0;
88109 this.idTensor.dispose();
88110 }
88111 /**
88112 * The size of the tensors in the tensor list.
88113 */
88114 }, {
88115 key: "size",
88116 value: function size() {
88117 return this.tensors.length;
88118 }
88119 /**
88120 * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)
88121 * tf.Tensor.
88122 * @param elementShape shape of each tensor
88123 * @param elementDtype data type of each tensor
88124 * @param numElements the number of elements to stack
88125 */
88126 }, {
88127 key: "stack",
88128 value: function stack$1(elementShape, elementDtype) {
88129 var _this = this;
88130 var numElements = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : -1;
88131 if (elementDtype !== this.elementDtype) {
88132 throw new Error("Invalid data types; op elements ".concat(elementDtype, ", but list elements ").concat(this.elementDtype));
88133 }
88134 if (numElements !== -1 && this.tensors.length !== numElements) {
88135 throw new Error("Operation expected a list with ".concat(numElements, " elements but got a list with ").concat(this.tensors.length, " elements."));
88136 }
88137 assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, 'TensorList shape mismatch: ');
88138 var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
88139 return tidy(function () {
88140 var reshapedTensors = _this.tensors.map(function (tensor) {
88141 return reshape$3(tensor, outputElementShape);
88142 });
88143 return stack(reshapedTensors, 0);
88144 });
88145 }
88146 /**
88147 * Pop a tensor from the end of the list.
88148 * @param elementShape shape of the tensor
88149 * @param elementDtype data type of the tensor
88150 */
88151 }, {
88152 key: "popBack",
88153 value: function popBack(elementShape, elementDtype) {
88154 if (elementDtype !== this.elementDtype) {
88155 throw new Error("Invalid data types; op elements ".concat(elementDtype, ", but list elements ").concat(this.elementDtype));
88156 }
88157 if (this.size() === 0) {
88158 throw new Error('Trying to pop from an empty list.');
88159 }
88160 var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
88161 var tensor = this.tensors.pop();
88162 tensor.kept = false;
88163 assertShapesMatchAllowUndefinedSize(tensor.shape, elementShape, 'TensorList shape mismatch: ');
88164 return reshape$3(tensor, outputElementShape);
88165 }
88166 /**
88167 * Push a tensor to the end of the list.
88168 * @param tensor Tensor to be pushed.
88169 */
88170 }, {
88171 key: "pushBack",
88172 value: function pushBack(tensor) {
88173 if (tensor.dtype !== this.elementDtype) {
88174 throw new Error("Invalid data types; op elements ".concat(tensor.dtype, ", but list elements ").concat(this.elementDtype));
88175 }
88176 assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, 'TensorList shape mismatch: ');
88177 if (this.maxNumElements === this.size()) {
88178 throw new Error("Trying to push element into a full list.");
88179 }
88180 keep(tensor);
88181 this.tensors.push(tensor);
88182 }
88183 /**
88184 * Update the size of the list.
88185 * @param size the new size of the list.
88186 */
88187 }, {
88188 key: "resize",
88189 value: function resize(size) {
88190 if (size < 0) {
88191 throw new Error("TensorListResize expects size to be non-negative. Got: ".concat(size));
88192 }
88193 if (this.maxNumElements !== -1 && size > this.maxNumElements) {
88194 throw new Error("TensorListResize input size ".concat(size, " is greater maxNumElement ").concat(this.maxNumElements, "."));
88195 }
88196 var destTensorList = new TensorList([], this.elementShape, this.elementDtype, this.maxNumElements);
88197 destTensorList.tensors.length = size;
88198 for (var i = 0; i < Math.min(this.tensors.length, size); ++i) {
88199 destTensorList.tensors[i] = this.tensors[i];
88200 }
88201 return destTensorList;
88202 }
88203 /**
88204 * Retrieve the element at the provided index
88205 * @param elementShape shape of the tensor
88206 * @param elementDtype dtype of the tensor
88207 * @param elementIndex index of the tensor
88208 */
88209 }, {
88210 key: "getItem",
88211 value: function getItem(elementIndex, elementShape, elementDtype) {
88212 if (elementDtype !== this.elementDtype) {
88213 throw new Error("Invalid data types; op elements ".concat(elementDtype, ", but list elements ").concat(this.elementDtype));
88214 }
88215 if (elementIndex < 0 || elementIndex > this.tensors.length) {
88216 throw new Error("Trying to access element ".concat(elementIndex, " in a list with ").concat(this.tensors.length, " elements."));
88217 }
88218 if (this.tensors[elementIndex] == null) {
88219 throw new Error("element at index ".concat(elementIndex, " is null."));
88220 }
88221 assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, 'TensorList shape mismatch: ');
88222 var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
88223 return reshape$3(this.tensors[elementIndex], outputElementShape);
88224 }
88225 /**
88226 * Set the tensor at the index
88227 * @param elementIndex index of the tensor
88228 * @param tensor the tensor to be inserted into the list
88229 */
88230 }, {
88231 key: "setItem",
88232 value: function setItem(elementIndex, tensor) {
88233 if (tensor.dtype !== this.elementDtype) {
88234 throw new Error("Invalid data types; op elements ".concat(tensor.dtype, ", but list elements ").concat(this.elementDtype));
88235 }
88236 if (elementIndex < 0 || this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
88237 throw new Error("Trying to set element ".concat(elementIndex, " in a list with max ").concat(this.maxNumElements, " elements."));
88238 }
88239 assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
88240 keep(tensor);
88241 // dispose the previous value if it is replacing.
88242 if (this.tensors[elementIndex] != null) {
88243 this.tensors[elementIndex].kept = false;
88244 }
88245 this.tensors[elementIndex] = tensor;
88246 }
88247 /**
88248 * Return selected values in the TensorList as a stacked Tensor. All of
88249 * selected values must have been written and their shapes must all match.
88250 * @param indices indices of tensors to gather
88251 * @param elementDtype output tensor dtype
88252 * @param elementShape output tensor element shape
88253 */
88254 }, {
88255 key: "gather",
88256 value: function gather(indices, elementDtype, elementShape) {
88257 var _this2 = this;
88258 if (elementDtype !== this.elementDtype) {
88259 throw new Error("Invalid data types; op elements ".concat(elementDtype, ", but list elements ").concat(this.elementDtype));
88260 }
88261 assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
88262 // When indices is greater than the size of the list, indices beyond the
88263 // size of the list are ignored.
88264 indices = indices.slice(0, this.size());
88265 var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
88266 if (indices.length === 0) {
88267 return tensor([], [0].concat(outputElementShape));
88268 }
88269 return tidy(function () {
88270 var tensors = indices.map(function (i) {
88271 return reshape$3(_this2.tensors[i], outputElementShape);
88272 });
88273 return stack(tensors, 0);
88274 });
88275 }
88276 /**
88277 * Return the values in the TensorList as a concatenated Tensor.
88278 * @param elementDtype output tensor dtype
88279 * @param elementShape output tensor element shape
88280 */
88281 }, {
88282 key: "concat",
88283 value: function concat(elementDtype, elementShape) {
88284 var _this3 = this;
88285 if (!!elementDtype && elementDtype !== this.elementDtype) {
88286 throw new Error("TensorList dtype is ".concat(this.elementDtype, " but concat requested dtype ").concat(elementDtype));
88287 }
88288 assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
88289 var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
88290 if (this.size() === 0) {
88291 return tensor([], [0].concat(outputElementShape));
88292 }
88293 return tidy(function () {
88294 var tensors = _this3.tensors.map(function (t) {
88295 return reshape$3(t, outputElementShape);
88296 });
88297 return concat$2(tensors, 0);
88298 });
88299 }
88300 }]);
88301 return TensorList;
88302 }();
/**
 * Creates a TensorList which, when stacked, has the value of `tensor`.
 * The list elements are the slices produced by unstack(tensor) with its
 * default axis (axis 0 in tfjs).
 * @param tensor source tensor; must have rank >= 1 and dtype `elementDtype`
 * @param elementShape output tensor element shape
 * @param elementDtype expected element dtype of the list
 * @returns a new TensorList holding the unstacked slices
 * @throws Error if `tensor` is a scalar, its dtype differs from
 *     `elementDtype`, or its per-slice shape does not match `elementShape`
 */
function fromTensor(tensor, elementShape, elementDtype) {
  var sourceShape = tensor.shape;
  // A scalar has no leading axis to unstack along.
  if (sourceShape.length < 1) {
    throw new Error("Tensor must be at least a vector, but saw shape: ".concat(tensor.shape));
  }
  var sourceDtype = tensor.dtype;
  if (sourceDtype !== elementDtype) {
    throw new Error("Invalid data types; op elements ".concat(tensor.dtype, ", but list elements ").concat(elementDtype));
  }
  // Each slice has the source shape minus its first dimension.
  assertShapesMatchAllowUndefinedSize(sourceShape.slice(1), elementShape, 'TensorList shape mismatch: ');
  var slices = unstack(tensor);
  return new TensorList(slices, elementShape, sourceDtype);
}
/**
 * Return a TensorList of the given size with empty elements.
 * NOTE(review): `numElements` is accepted but not used here. The list always
 * starts with no stored tensors, and only `maxNumElements` is forwarded to
 * the TensorList constructor.
 * @param elementShape the shape of the future elements of the list
 * @param elementDtype the desired type of elements in the list
 * @param numElements the number of elements to reserve (currently unused)
 * @param maxNumElements the maximum number of elements in the list
 * @returns a new TensorList with no stored tensors
 */
function reserve(elementShape, elementDtype, numElements, maxNumElements) {
  var initialTensors = [];
  return new TensorList(initialTensors, elementShape, elementDtype, maxNumElements);
}
/**
 * Put tensors at specific indices of a stacked tensor into a TensorList.
 * Slice k of `tensor` (from unstack along axis 0) is stored at list index
 * `indices[k]`.
 * @param tensor input tensor; tensor.shape[0] must equal indices.length
 * @param indices list of indices on how to scatter the tensor.
 * @param elementShape the shape of the future elements of the list
 * @param numElements the number of elements to scatter. When it is neither
 *     null/undefined nor -1, every index must be smaller than it. It is also
 *     passed to the TensorList constructor as its fourth argument.
 * @returns a new TensorList holding the scattered slices
 * @throws Error on a length mismatch or an out-of-range index
 */
function scatter(tensor, indices, elementShape, numElements) {
  var rowCount = tensor.shape[0];
  if (indices.length !== rowCount) {
    throw new Error("Expected len(indices) == tensor.shape[0], but saw: ".concat(indices.length, " vs. ").concat(rowCount));
  }
  var largestIndex = Math.max.apply(Math, _toConsumableArray(indices));
  var isBounded = numElements != null && numElements !== -1;
  if (isBounded && largestIndex >= numElements) {
    throw new Error("Max index must be < array size (".concat(largestIndex, " vs. ").concat(numElements, ")"));
  }
  var result = new TensorList([], elementShape, tensor.dtype, numElements);
  var rows = unstack(tensor, 0);
  indices.forEach(function (listIndex, row) {
    result.setItem(listIndex, rows[row]);
  });
  return result;
}
/**
 * Split the values of a Tensor into a TensorList.
 *
 * Element i of the resulting list is the next `length[i]` rows of `tensor`
 * (along its first dimension), reshaped to the element shape obtained by
 * merging the tensor's per-row shape with `elementShape`.
 * @param tensor the tensor to split; tensor.shape[0] must equal the sum of
 *     `length`. It is not modified or disposed.
 * @param length the lengths to use when splitting value along
 *     its first dimension.
 * @param elementShape the shape of the future elements of the list
 * @returns a TensorList whose max size is `length.length`
 * @throws Error if the lengths do not sum to tensor.shape[0]
 */
function split$1(tensor, length, elementShape) {
  var totalLength = 0;
  // Running sums of `length`: entry i is the end row (exclusive) of chunk i.
  var cumulativeLengths = length.map(function (len) {
    totalLength += len;
    return totalLength;
  });
  if (totalLength !== tensor.shape[0]) {
    throw new Error("Expected sum of lengths to be equal to\n tensor.shape[0], but sum of lengths is\n ".concat(totalLength, ", and tensor's shape is: ").concat(tensor.shape));
  }
  var shapeWithoutFirstDim = tensor.shape.slice(1);
  var outputElementShape = mergeElementShape(shapeWithoutFirstDim, elementShape);
  // Number of values per row of the first dimension (0 for an empty split).
  var elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
  var tensors = tidy(function () {
    var chunks = [];
    // View the input as [1, rows, elementPerRow] so each chunk is one slice
    // along axis 1. Kept in its own local instead of reassigning `tensor`, so
    // `tensor` keeps referring to the caller's (undisposed) input below.
    var flattened = reshape$3(tensor, [1, totalLength, elementPerRow]);
    for (var i = 0; i < length.length; ++i) {
      var previousLength = i === 0 ? 0 : cumulativeLengths[i - 1];
      var begin = [0, previousLength, 0];
      var sizes = [1, length[i], elementPerRow];
      chunks[i] = reshape$3(slice$2(flattened, begin, sizes), outputElementShape);
    }
    flattened.dispose();
    return chunks;
  });
  var list = new TensorList([], elementShape, tensor.dtype, length.length);
  for (var j = 0; j < tensors.length; j++) {
    list.setItem(j, tensors[j]);
  }
  return list;
}
88390
/**
 * Executor for control-flow, TensorArray and TensorList graph ops.
 *
 * This is Babel/regenerator output of an async `executeOp(node, tensorMap,
 * context)`. Each numbered `case` is a resume point of the state machine.
 * `_context2.next` selects the next case, and
 * `_context2.abrupt("return", v)` completes the function with `v`.
 * The dispatch in `case 0` maps `node.op` to the first case of its handler.
 * Unknown ops reach `case 184`, which throws a TypeError.
 *
 * @param node graph node to execute; `node.op` selects the handler.
 * @param tensorMap tensors computed so far, read via getParamValue / getTensor.
 * @param context execution context. It provides `functionMap` (sub-graph
 *     functions used by If/While), the TensorArray/TensorList registries
 *     (getTensorArray/addTensorArray, getTensorList/addTensorList) and frame
 *     bookkeeping (enterFrame/exitFrame/nextIteration).
 * @returns a promise for the node's output tensors. `Switch` resolves to a
 *     pair with one undefined slot; `Merge` resolves to undefined when none
 *     of its inputs is available.
 */
88391 var executeOp$i = /*#__PURE__*/function () {
88392 var _ref = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(node, tensorMap, context) {
88393 var thenFunc, elseFunc, cond, args, condValue, bodyFunc, condFunc, _args, condResult, argIds, _condValue, result, _loop, pred, _pred, data, inputName, _data, frameId, _data2, _data3, _data4, size, dtype, elementShape, dynamicSize, clearAfterRead, identicalElementShapes, name, tensorArray, id, index, writeTensor, writeTensorArray, readId, readIndex, readTensorArray, gatherId, gatherIndices, gatherDtype, gatherTensorArray, scatterId, scatterIndices, scatterTensor, scatterTensorArray, concatId, concatTensorArray, concatDtype, splitId, splitTensor, lengths, splitTensorArray, sizeId, sizeTensorArray, closeId, closeTensorArray, idTensor, _index, _writeTensor, tensorList, _idTensor, _readIndex, _elementShape, elementDType, _tensorList, _scatterIndices, _scatterTensor, _elementShape2, numElements, _tensorList2, _elementShape3, elementDtype, numElementsParam, _numElements, maxNumElements, _tensorList3, _gatherId, _gatherIndices, _elementShape4, _elementDtype, _tensorList4, _idTensor2, _elementShape5, _elementDtype2, _numElements2, _tensorList5, tensor, _elementShape6, _elementDtype3, _tensorList6, _concatId, _tensorList7, _concatDtype, _elementShape7, _idTensor3, _writeTensor2, _tensorList8, _idTensor4, _elementShape8, _elementDType, _tensorList9, _splitTensor, _elementShape9, _lengths, _tensorList10, _idTensor5, _tensorList11, _idTensor6, _size, srcTensorList, destTensorList;
88394 return _regeneratorRuntime().wrap(function _callee$(_context2) {
88395 while (1) switch (_context2.prev = _context2.next) {
88396 case 0:
// Dispatch on the op name. Ops sharing a handler (e.g. If/StatelessIf) jump
// to the same case; anything unrecognised goes to case 184.
88397 _context2.t0 = node.op;
88398 _context2.next = _context2.t0 === 'If' ? 3 : _context2.t0 === 'StatelessIf' ? 3 : _context2.t0 === 'While' ? 15 : _context2.t0 === 'StatelessWhile' ? 15 : _context2.t0 === 'LoopCond' ? 33 : _context2.t0 === 'Switch' ? 35 : _context2.t0 === 'Merge' ? 46 : _context2.t0 === 'Enter' ? 51 : _context2.t0 === 'Exit' ? 55 : _context2.t0 === 'NextIteration' ? 58 : _context2.t0 === 'TensorArrayV3' ? 61 : _context2.t0 === 'TensorArrayWriteV3' ? 71 : _context2.t0 === 'TensorArrayReadV3' ? 77 : _context2.t0 === 'TensorArrayGatherV3' ? 81 : _context2.t0 === 'TensorArrayScatterV3' ? 86 : _context2.t0 === 'TensorArrayConcatV3' ? 92 : _context2.t0 === 'TensorArraySplitV3' ? 96 : _context2.t0 === 'TensorArraySizeV3' ? 102 : _context2.t0 === 'TensorArrayCloseV3' ? 105 : _context2.t0 === 'TensorListSetItem' ? 109 : _context2.t0 === 'TensorListGetItem' ? 115 : _context2.t0 === 'TensorListScatterV2' ? 121 : _context2.t0 === 'TensorListScatter' ? 121 : _context2.t0 === 'TensorListReserve' ? 128 : _context2.t0 === 'EmptyTensorList' ? 128 : _context2.t0 === 'TensorListGather' ? 136 : _context2.t0 === 'TensorListStack' ? 142 : _context2.t0 === 'TensorListFromTensor' ? 148 : _context2.t0 === 'TensorListConcat' ? 154 : _context2.t0 === 'TensorListConcatV2' ? 154 : _context2.t0 === 'TensorListPushBack' ? 159 : _context2.t0 === 'TensorListPopBack' ? 164 : _context2.t0 === 'TensorListSplit' ? 169 : _context2.t0 === 'TensorListLength' ? 175 : _context2.t0 === 'TensorListResize' ? 178 : 184;
88399 break;
88400 case 3:
// If / StatelessIf: await the predicate's data, then run the then- or
// else-branch function on `args` depending on its first element.
88401 thenFunc = getParamValue('thenBranch', node, tensorMap, context);
88402 elseFunc = getParamValue('elseBranch', node, tensorMap, context);
88403 cond = getParamValue('cond', node, tensorMap, context);
88404 args = getParamValue('args', node, tensorMap, context);
88405 _context2.next = 9;
88406 return cond.data();
88407 case 9:
// Resumed with the predicate's values.
88408 condValue = _context2.sent;
88409 if (!condValue[0]) {
88410 _context2.next = 14;
88411 break;
88412 }
88413 return _context2.abrupt("return", context.functionMap[thenFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap));
88414 case 14:
// Predicate falsy: run the else branch.
88415 return _context2.abrupt("return", context.functionMap[elseFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap));
88416 case 15:
// While / StatelessWhile: evaluate the cond function once on the initial args
// before deciding whether to run the body at all.
88417 bodyFunc = getParamValue('body', node, tensorMap, context);
88418 condFunc = getParamValue('cond', node, tensorMap, context);
88419 _args = getParamValue('args', node, tensorMap, context); // Calculate the condition of the loop
88420 _context2.next = 20;
88421 return context.functionMap[condFunc].executeFunctionAsync(_args, context.tensorArrayMap, context.tensorListMap);
88422 case 20:
// Resumed with the cond function's outputs; only condResult[0] is inspected.
// argIds records the loop inputs so they are never disposed below.
88423 condResult = _context2.sent;
88424 argIds = _args.map(function (tensor) {
88425 return tensor.id;
88426 });
88427 _context2.next = 24;
88428 return condResult[0].data();
88429 case 24:
88430 _condValue = _context2.sent;
88431 // Dispose the intermediate tensors for condition function
88432 condResult.forEach(function (tensor) {
88433 if (!tensor.kept && argIds.indexOf(tensor.id) === -1) {
88434 tensor.dispose();
88435 }
88436 });
88437 result = _args;
// One loop iteration: run the body on `result`, dispose stale tensors, then
// re-evaluate the condition into `_condValue`. Execution then falls through
// into case 28 (no break).
88438 _loop = /*#__PURE__*/_regeneratorRuntime().mark(function _loop() {
88439 var origResult, resultIds, condResult;
88440 return _regeneratorRuntime().wrap(function _loop$(_context) {
88441 while (1) switch (_context.prev = _context.next) {
88442 case 0:
88443 // Record the previous result for intermediate tensor tracking
88444 origResult = result; // Execute the body of the loop
88445 _context.next = 3;
88446 return context.functionMap[bodyFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap);
88447 case 3:
88448 result = _context.sent;
88449 resultIds = result.map(function (tensor) {
88450 return tensor.id;
88451 }); // Dispose tensors from the previous iteration that are not kept and are
88452 // neither loop inputs nor outputs of this iteration.
88453 origResult.forEach(function (tensor) {
88454 if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && resultIds.indexOf(tensor.id) === -1) {
88455 tensor.dispose();
88456 }
88457 });
88458 // Recalculate the condition of the loop using the latest results.
88459 _context.next = 8;
88460 return context.functionMap[condFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap);
88461 case 8:
88462 condResult = _context.sent;
88463 _context.next = 11;
88464 return condResult[0].data();
88465 case 11:
88466 _condValue = _context.sent;
88467 // Dispose the intermediate tensors for condition function
88468 condResult.forEach(function (tensor) {
88469 if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && resultIds.indexOf(tensor.id) === -1) {
88470 tensor.dispose();
88471 }
88472 });
88473 case 13:
88474 case "end":
88475 return _context.stop();
88476 }
88477 }, _loop);
88478 });
88479 case 28:
// Loop head: run another `_loop` iteration while the first element of the
// latest condition value is truthy.
88480 if (!_condValue[0]) {
88481 _context2.next = 32;
88482 break;
88483 }
88484 return _context2.delegateYield(_loop(), "t1", 30);
88485 case 30:
88486 _context2.next = 28;
88487 break;
88488 case 32:
// Loop finished: return the last body outputs (the original args if the body
// never ran).
88489 return _context2.abrupt("return", result);
88490 case 33:
// LoopCond: forward a clone of the predicate.
88491 pred = getParamValue('pred', node, tensorMap, context);
88492 return _context2.abrupt("return", [cloneTensor(pred)]);
88493 case 35:
// Switch: route `data` to one of two outputs based on `pred`.
88494 _pred = getParamValue('pred', node, tensorMap, context);
88495 data = getParamValue('data', node, tensorMap, context);
// Non-kept inputs are cloned before being forwarded (presumably so the output
// can be disposed independently of the input — confirm).
88496 if (!data.kept) {
88497 data = cloneTensor(data);
88498 }
88499 // Output :0 is used when pred is false, :1 when pred is true.
88500 _context2.next = 40;
88501 return _pred.data();
88502 case 40:
88503 if (!_context2.sent[0]) {
88504 _context2.next = 44;
88505 break;
88506 }
88507 _context2.t2 = [undefined, data];
88508 _context2.next = 45;
88509 break;
88510 case 44:
88511 _context2.t2 = [data, undefined];
88512 case 45:
88513 return _context2.abrupt("return", _context2.t2);
88514 case 46:
// Merge: forward a clone of the first input name that resolves to a tensor;
// resolve to undefined if none does.
88515 inputName = node.inputNames.find(function (name) {
88516 return getTensor(name, tensorMap, context) !== undefined;
88517 });
88518 if (!inputName) {
88519 _context2.next = 50;
88520 break;
88521 }
88522 _data = getTensor(inputName, tensorMap, context);
88523 return _context2.abrupt("return", [cloneTensor(_data)]);
88524 case 50:
88525 return _context2.abrupt("return", undefined);
88526 case 51:
// Enter: enter the named frame and forward a clone of the tensor.
88527 frameId = getParamValue('frameName', node, tensorMap, context);
88528 _data2 = getParamValue('tensor', node, tensorMap, context);
88529 context.enterFrame(frameId);
88530 return _context2.abrupt("return", [cloneTensor(_data2)]);
88531 case 55:
// Exit: leave the current frame and forward a clone of the tensor.
88532 _data3 = getParamValue('tensor', node, tensorMap, context);
88533 context.exitFrame();
88534 return _context2.abrupt("return", [cloneTensor(_data3)]);
88535 case 58:
// NextIteration: advance the context's iteration and forward a clone.
88536 _data4 = getParamValue('tensor', node, tensorMap, context);
88537 context.nextIteration();
88538 return _context2.abrupt("return", [cloneTensor(_data4)]);
88539 case 61:
// TensorArrayV3: create and register a TensorArray. Outputs its id tensor
// plus scalar(1.0) (presumably standing in for TF's "flow" output — confirm).
88540 size = getParamValue('size', node, tensorMap, context);
88541 dtype = getParamValue('dtype', node, tensorMap, context);
88542 elementShape = getParamValue('elementShape', node, tensorMap, context);
88543 dynamicSize = getParamValue('dynamicSize', node, tensorMap, context);
88544 clearAfterRead = getParamValue('clearAfterRead', node, tensorMap, context);
88545 identicalElementShapes = getParamValue('identicalElementShapes', node, tensorMap, context);
88546 name = getParamValue('name', node, tensorMap, context);
88547 tensorArray = new TensorArray(name, dtype, size, elementShape, identicalElementShapes, dynamicSize, clearAfterRead);
88548 context.addTensorArray(tensorArray);
88549 return _context2.abrupt("return", [tensorArray.idTensor, scalar(1.0)]);
88550 case 71:
// TensorArrayWriteV3: write `tensor` at `index`; outputs the array's id tensor.
88551 id = getParamValue('tensorArrayId', node, tensorMap, context);
88552 index = getParamValue('index', node, tensorMap, context);
88553 writeTensor = getParamValue('tensor', node, tensorMap, context);
88554 writeTensorArray = context.getTensorArray(id.id);
88555 writeTensorArray.write(index, writeTensor);
88556 return _context2.abrupt("return", [writeTensorArray.idTensor]);
88557 case 77:
// TensorArrayReadV3: output the element read at `index`.
88558 readId = getParamValue('tensorArrayId', node, tensorMap, context);
88559 readIndex = getParamValue('index', node, tensorMap, context);
88560 readTensorArray = context.getTensorArray(readId.id);
88561 return _context2.abrupt("return", [readTensorArray.read(readIndex)]);
88562 case 81:
// TensorArrayGatherV3: output the array's gather(indices, dtype) result.
88563 gatherId = getParamValue('tensorArrayId', node, tensorMap, context);
88564 gatherIndices = getParamValue('indices', node, tensorMap, context);
88565 gatherDtype = getParamValue('dtype', node, tensorMap, context);
88566 gatherTensorArray = context.getTensorArray(gatherId.id);
88567 return _context2.abrupt("return", [gatherTensorArray.gather(gatherIndices, gatherDtype)]);
88568 case 86:
// TensorArrayScatterV3: scatter `tensor` into the array at `indices`;
// outputs the array's id tensor.
88569 scatterId = getParamValue('tensorArrayId', node, tensorMap, context);
88570 scatterIndices = getParamValue('indices', node, tensorMap, context);
88571 scatterTensor = getParamValue('tensor', node, tensorMap, context);
88572 scatterTensorArray = context.getTensorArray(scatterId.id);
88573 scatterTensorArray.scatter(scatterIndices, scatterTensor);
88574 return _context2.abrupt("return", [scatterTensorArray.idTensor]);
88575 case 92:
// TensorArrayConcatV3: output the array's concat(dtype) result.
88576 concatId = getParamValue('tensorArrayId', node, tensorMap, context);
88577 concatTensorArray = context.getTensorArray(concatId.id);
88578 concatDtype = getParamValue('dtype', node, tensorMap, context);
88579 return _context2.abrupt("return", [concatTensorArray.concat(concatDtype)]);
88580 case 96:
// TensorArraySplitV3: split `tensor` by `lengths` into the array; outputs the
// array's id tensor.
88581 splitId = getParamValue('tensorArrayId', node, tensorMap, context);
88582 splitTensor = getParamValue('tensor', node, tensorMap, context);
88583 lengths = getParamValue('lengths', node, tensorMap, context);
88584 splitTensorArray = context.getTensorArray(splitId.id);
88585 splitTensorArray.split(lengths, splitTensor);
88586 return _context2.abrupt("return", [splitTensorArray.idTensor]);
88587 case 102:
// TensorArraySizeV3: output size() as an int32 scalar.
88588 sizeId = getParamValue('tensorArrayId', node, tensorMap, context);
88589 sizeTensorArray = context.getTensorArray(sizeId.id);
88590 return _context2.abrupt("return", [scalar(sizeTensorArray.size(), 'int32')]);
88591 case 105:
// TensorArrayCloseV3: clear and close the array; outputs its id tensor.
88592 closeId = getParamValue('tensorArrayId', node, tensorMap, context);
88593 closeTensorArray = context.getTensorArray(closeId.id);
88594 closeTensorArray.clearAndClose();
88595 return _context2.abrupt("return", [closeTensorArray.idTensor]);
88596 case 109:
// TensorListSetItem: store `tensor` at `index`; outputs the list's id tensor.
88597 idTensor = getParamValue('tensorListId', node, tensorMap, context);
88598 _index = getParamValue('index', node, tensorMap, context);
88599 _writeTensor = getParamValue('tensor', node, tensorMap, context);
88600 tensorList = context.getTensorList(idTensor.id);
88601 tensorList.setItem(_index, _writeTensor);
88602 return _context2.abrupt("return", [tensorList.idTensor]);
88603 case 115:
// TensorListGetItem: output the element at `index`; getItem checks the
// requested dtype and shape.
88604 _idTensor = getParamValue('tensorListId', node, tensorMap, context);
88605 _readIndex = getParamValue('index', node, tensorMap, context);
88606 _elementShape = getParamValue('elementShape', node, tensorMap, context);
88607 elementDType = getParamValue('elementDType', node, tensorMap, context);
88608 _tensorList = context.getTensorList(_idTensor.id);
88609 return _context2.abrupt("return", [_tensorList.getItem(_readIndex, _elementShape, elementDType)]);
88610 case 121:
// TensorListScatter / TensorListScatterV2: build a new list from the slices of
// `tensor`, register it and output its id tensor.
88611 _scatterIndices = getParamValue('indices', node, tensorMap, context);
88612 _scatterTensor = getParamValue('tensor', node, tensorMap, context);
88613 _elementShape2 = getParamValue('elementShape', node, tensorMap, context);
88614 numElements = getParamValue('numElements', node, tensorMap, context);
88615 _tensorList2 = scatter(_scatterTensor, _scatterIndices, _elementShape2, numElements);
88616 context.addTensorList(_tensorList2);
88617 return _context2.abrupt("return", [_tensorList2.idTensor]);
88618 case 128:
// TensorListReserve / EmptyTensorList: register a new empty list.
// TensorListReserve reads `numElements` and passes -1 as maxNumElements.
// EmptyTensorList reads `maxNumElements` and passes it as both arguments.
88619 _elementShape3 = getParamValue('elementShape', node, tensorMap, context);
88620 elementDtype = getParamValue('elementDType', node, tensorMap, context);
88621 if (node.op === 'TensorListReserve') {
88622 numElementsParam = 'numElements';
88623 } else {
88624 numElementsParam = 'maxNumElements';
88625 }
88626 _numElements = getParamValue(numElementsParam, node, tensorMap, context);
88627 maxNumElements = node.op === 'TensorListReserve' ? -1 : _numElements;
88628 _tensorList3 = reserve(_elementShape3, elementDtype, _numElements, maxNumElements);
88629 context.addTensorList(_tensorList3);
88630 return _context2.abrupt("return", [_tensorList3.idTensor]);
88631 case 136:
// TensorListGather: output the elements at `indices`, stacked by gather().
88632 _gatherId = getParamValue('tensorListId', node, tensorMap, context);
88633 _gatherIndices = getParamValue('indices', node, tensorMap, context);
88634 _elementShape4 = getParamValue('elementShape', node, tensorMap, context);
88635 _elementDtype = getParamValue('elementDType', node, tensorMap, context);
88636 _tensorList4 = context.getTensorList(_gatherId.id);
88637 return _context2.abrupt("return", [_tensorList4.gather(_gatherIndices, _elementDtype, _elementShape4)]);
88638 case 142:
// TensorListStack: output the list's stack(...) result (presumably all
// elements stacked into one tensor; stack() is not visible here).
88639 _idTensor2 = getParamValue('tensorListId', node, tensorMap, context);
88640 _elementShape5 = getParamValue('elementShape', node, tensorMap, context);
88641 _elementDtype2 = getParamValue('elementDType', node, tensorMap, context);
88642 _numElements2 = getParamValue('numElements', node, tensorMap, context);
88643 _tensorList5 = context.getTensorList(_idTensor2.id);
88644 return _context2.abrupt("return", [_tensorList5.stack(_elementShape5, _elementDtype2, _numElements2)]);
88645 case 148:
// TensorListFromTensor: register a new list holding the unstacked `tensor`.
88646 tensor = getParamValue('tensor', node, tensorMap, context);
88647 _elementShape6 = getParamValue('elementShape', node, tensorMap, context);
88648 _elementDtype3 = getParamValue('elementDType', node, tensorMap, context);
88649 _tensorList6 = fromTensor(tensor, _elementShape6, _elementDtype3);
88650 context.addTensorList(_tensorList6);
88651 return _context2.abrupt("return", [_tensorList6.idTensor]);
88652 case 154:
// TensorListConcat / TensorListConcatV2: output all elements concatenated
// along axis 0.
88653 _concatId = getParamValue('tensorListId', node, tensorMap, context);
88654 _tensorList7 = context.getTensorList(_concatId.id);
88655 _concatDtype = getParamValue('dtype', node, tensorMap, context);
88656 _elementShape7 = getParamValue('elementShape', node, tensorMap, context);
88657 return _context2.abrupt("return", [_tensorList7.concat(_concatDtype, _elementShape7)]);
88658 case 159:
// TensorListPushBack: append `tensor`; outputs the list's id tensor.
88659 _idTensor3 = getParamValue('tensorListId', node, tensorMap, context);
88660 _writeTensor2 = getParamValue('tensor', node, tensorMap, context);
88661 _tensorList8 = context.getTensorList(_idTensor3.id);
88662 _tensorList8.pushBack(_writeTensor2);
88663 return _context2.abrupt("return", [_tensorList8.idTensor]);
88664 case 164:
// TensorListPopBack: output popBack()'s result (presumably the removed last
// element; popBack() is not visible here).
88665 _idTensor4 = getParamValue('tensorListId', node, tensorMap, context);
88666 _elementShape8 = getParamValue('elementShape', node, tensorMap, context);
88667 _elementDType = getParamValue('elementDType', node, tensorMap, context);
88668 _tensorList9 = context.getTensorList(_idTensor4.id);
88669 return _context2.abrupt("return", [_tensorList9.popBack(_elementShape8, _elementDType)]);
88670 case 169:
// TensorListSplit: build a new list by splitting `tensor` by `lengths`,
// register it and output its id tensor.
88671 _splitTensor = getParamValue('tensor', node, tensorMap, context);
88672 _elementShape9 = getParamValue('elementShape', node, tensorMap, context);
88673 _lengths = getParamValue('lengths', node, tensorMap, context);
88674 _tensorList10 = split$1(_splitTensor, _lengths, _elementShape9);
88675 context.addTensorList(_tensorList10);
88676 return _context2.abrupt("return", [_tensorList10.idTensor]);
88677 case 175:
// TensorListLength: output size() as an int32 scalar.
88678 _idTensor5 = getParamValue('tensorListId', node, tensorMap, context);
88679 _tensorList11 = context.getTensorList(_idTensor5.id);
88680 return _context2.abrupt("return", [scalar(_tensorList11.size(), 'int32')]);
88681 case 178:
// TensorListResize: resize() returns a new list, which is registered and
// returned by its id tensor.
88682 _idTensor6 = getParamValue('tensorListId', node, tensorMap, context);
88683 _size = getParamValue('size', node, tensorMap, context);
88684 srcTensorList = context.getTensorList(_idTensor6.id);
88685 destTensorList = srcTensorList.resize(_size);
88686 context.addTensorList(destTensorList);
88687 return _context2.abrupt("return", [destTensorList.idTensor]);
88688 case 184:
// Any op not listed in the dispatch above.
88689 throw TypeError("Node type ".concat(node.op, " is not implemented"));
88690 case 185:
88691 case "end":
88692 return _context2.stop();
88693 }
88694 }, _callee);
88695 }));
// Wrapper exposing the (node, tensorMap, context) signature and forwarding to
// the generator-backed implementation.
88696 return function executeOp(_x, _x2, _x3) {
88697 return _ref.apply(this, arguments);
88698 };
88699 }();
// Op category label for this executor (its consumer is outside this chunk).
88700 var CATEGORY$h = 'control';
88701
/**
 * Collects the attributes shared by the fused Conv2D and fused
 * DepthwiseConv2D executors, after validating the fused-op configuration.
 *
 * `fusedOps` is read as [extraOp, activationFunc]. With a 'biasadd' extra op
 * the extra `args` are [bias, preluAlpha]. Without it, the first extra
 * argument is treated as the PReLU alpha and no bias is returned.
 * @param node graph node being executed
 * @param tensorMap tensors computed so far
 * @param context execution context
 * @returns {{stride, pad, dataFormat, dilations, biasArg, preluArg,
 *     activationFunc, leakyreluAlpha}} with `dataFormat` upper-cased
 * @throws Error when the extra-argument count does not fit BiasAdd (+Prelu),
 *     or when the extra op is 'fusedbatchnorm' (not supported)
 */
function fusedConvAndDepthWiseParams(node, tensorMap, context) {
  var readParam = function (paramName) {
    return getParamValue(paramName, node, tensorMap, context);
  };
  var fusedOps = _slicedToArray(readParam('fusedOps'), 2);
  var extraOp = fusedOps[0];
  var activationFunc = fusedOps[1];
  var hasBiasAdd = extraOp === 'biasadd';
  var hasPrelu = activationFunc === 'prelu';
  var numArgs = readParam('numArgs');
  // BiasAdd needs the bias, plus the alpha when the activation is PReLU.
  if (hasBiasAdd && hasPrelu && numArgs !== 2) {
    throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' + 'must have two extra arguments: bias and alpha.');
  }
  if (hasBiasAdd && !hasPrelu && numArgs !== 1) {
    throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd must have ' + 'one extra argument: bias.');
  }
  if (extraOp === 'fusedbatchnorm') {
    throw new Error('FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported');
  }
  var stride = readParam('strides');
  var pad = getPadding(node, tensorMap, context);
  var dataFormat = readParam('dataFormat').toUpperCase();
  var dilations = readParam('dilations');
  var extraArgs = _slicedToArray(readParam('args'), 2);
  // Without a BiasAdd the first extra argument is the PReLU alpha, not a bias.
  var biasArg = hasBiasAdd ? extraArgs[0] : undefined;
  var preluArg = hasBiasAdd ? extraArgs[1] : extraArgs[0];
  var leakyreluAlpha = readParam('leakyreluAlpha');
  return {
    stride: stride,
    pad: pad,
    dataFormat: dataFormat,
    dilations: dilations,
    biasArg: biasArg,
    preluArg: preluArg,
    activationFunc: activationFunc,
    leakyreluAlpha: leakyreluAlpha
  };
}
88747 var executeOp$h = function executeOp(node, tensorMap, context) {
88748 var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
88749 switch (node.op) {
88750 case 'Conv1D':
88751 {
88752 var stride = getParamValue('stride', node, tensorMap, context);
88753 var pad = getParamValue('pad', node, tensorMap, context);
88754 var dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
88755 var dilation = getParamValue('dilation', node, tensorMap, context);
88756 return [ops.conv1d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), stride, pad, dataFormat, dilation)];
88757 }
88758 case 'Conv2D':
88759 {
88760 var _stride = getParamValue('strides', node, tensorMap, context);
88761 var _pad = getPadding(node, tensorMap, context);
88762 var _dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
88763 var dilations = getParamValue('dilations', node, tensorMap, context);
88764 return [ops.conv2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [_stride[1], _stride[2]], _pad, _dataFormat, [dilations[1], dilations[2]])];
88765 }
88766 case '_FusedConv2D':
88767 {
88768 var _fusedConvAndDepthWis = fusedConvAndDepthWiseParams(node, tensorMap, context),
88769 _stride2 = _fusedConvAndDepthWis.stride,
88770 _pad2 = _fusedConvAndDepthWis.pad,
88771 _dataFormat2 = _fusedConvAndDepthWis.dataFormat,
88772 _dilations = _fusedConvAndDepthWis.dilations,
88773 biasArg = _fusedConvAndDepthWis.biasArg,
88774 preluArg = _fusedConvAndDepthWis.preluArg,
88775 activationFunc = _fusedConvAndDepthWis.activationFunc,
88776 leakyreluAlpha = _fusedConvAndDepthWis.leakyreluAlpha;
88777 return [ops.fused.conv2d({
88778 x: getParamValue('x', node, tensorMap, context),
88779 filter: getParamValue('filter', node, tensorMap, context),
88780 strides: [_stride2[1], _stride2[2]],
88781 pad: _pad2,
88782 dataFormat: _dataFormat2,
88783 dilations: [_dilations[1], _dilations[2]],
88784 bias: biasArg,
88785 activation: activationFunc,
88786 preluActivationWeights: preluArg,
88787 leakyreluAlpha: leakyreluAlpha
88788 })];
88789 }
88790 case 'FusedDepthwiseConv2dNative':
88791 {
88792 var _fusedConvAndDepthWis2 = fusedConvAndDepthWiseParams(node, tensorMap, context),
88793 _stride3 = _fusedConvAndDepthWis2.stride,
88794 _pad3 = _fusedConvAndDepthWis2.pad,
88795 _dataFormat3 = _fusedConvAndDepthWis2.dataFormat,
88796 _dilations2 = _fusedConvAndDepthWis2.dilations,
88797 _biasArg = _fusedConvAndDepthWis2.biasArg,
88798 _preluArg = _fusedConvAndDepthWis2.preluArg,
88799 _activationFunc = _fusedConvAndDepthWis2.activationFunc,
88800 _leakyreluAlpha = _fusedConvAndDepthWis2.leakyreluAlpha;
88801 return [ops.fused.depthwiseConv2d({
88802 x: getParamValue('x', node, tensorMap, context),
88803 filter: getParamValue('filter', node, tensorMap, context),
88804 strides: [_stride3[1], _stride3[2]],
88805 pad: _pad3,
88806 dataFormat: _dataFormat3,
88807 dilations: [_dilations2[1], _dilations2[2]],
88808 bias: _biasArg,
88809 activation: _activationFunc,
88810 preluActivationWeights: _preluArg,
88811 leakyreluAlpha: _leakyreluAlpha
88812 })];
88813 }
88814 case 'Conv2DBackpropInput':
88815 case 'Conv2dTranspose':
88816 {
88817 var shape = getParamValue('outputShape', node, tensorMap, context);
88818 var _stride4 = getParamValue('strides', node, tensorMap, context);
88819 var _pad4 = getPadding(node, tensorMap, context);
88820 return [ops.conv2dTranspose(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), shape, [_stride4[1], _stride4[2]], _pad4)];
88821 }
88822 case 'DepthwiseConv2dNative':
88823 case 'DepthwiseConv2d':
88824 {
88825 var _stride5 = getParamValue('strides', node, tensorMap, context);
88826 var _pad5 = getPadding(node, tensorMap, context);
88827 var _dilations3 = getParamValue('dilations', node, tensorMap, context);
88828 var _dataFormat4 = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
88829 return [ops.depthwiseConv2d(getParamValue('input', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [_stride5[1], _stride5[2]], _pad5, _dataFormat4, [_dilations3[1], _dilations3[2]])];
88830 }
88831 case 'Conv3D':
88832 {
88833 var _stride6 = getParamValue('strides', node, tensorMap, context);
88834 var _pad6 = getParamValue('pad', node, tensorMap, context);
88835 var _dataFormat5 = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
88836 var _dilations4 = getParamValue('dilations', node, tensorMap, context);
88837 return [ops.conv3d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [_stride6[1], _stride6[2], _stride6[3]], _pad6, _dataFormat5, [_dilations4[1], _dilations4[2], _dilations4[3]])];
88838 }
88839 case 'AvgPool':
88840 {
88841 var _stride7 = getParamValue('strides', node, tensorMap, context);
88842 var _pad7 = getParamValue('pad', node, tensorMap, context);
88843 var kernelSize = getParamValue('kernelSize', node, tensorMap, context);
88844 return [ops.avgPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [_stride7[1], _stride7[2]], _pad7)];
88845 }
88846 case 'MaxPool':
88847 {
88848 var _stride8 = getParamValue('strides', node, tensorMap, context);
88849 var _pad8 = getParamValue('pad', node, tensorMap, context);
88850 var _kernelSize = getParamValue('kernelSize', node, tensorMap, context);
88851 return [ops.maxPool(getParamValue('x', node, tensorMap, context), [_kernelSize[1], _kernelSize[2]], [_stride8[1], _stride8[2]], _pad8)];
88852 }
88853 case 'MaxPoolWithArgmax':
88854 {
88855 var _stride9 = getParamValue('strides', node, tensorMap, context);
88856 var _pad9 = getParamValue('pad', node, tensorMap, context);
88857 var _kernelSize2 = getParamValue('kernelSize', node, tensorMap, context);
88858 var includeBatchInIndex = getParamValue('includeBatchInIndex', node, tensorMap, context);
88859 var _ops$maxPoolWithArgma = ops.maxPoolWithArgmax(getParamValue('x', node, tensorMap, context), [_kernelSize2[1], _kernelSize2[2]], [_stride9[1], _stride9[2]], _pad9, includeBatchInIndex),
88860 result = _ops$maxPoolWithArgma.result,
88861 indexes = _ops$maxPoolWithArgma.indexes;
88862 return [result, indexes];
88863 }
88864 case 'AvgPool3D':
88865 {
88866 var _stride10 = getParamValue('strides', node, tensorMap, context);
88867 var _pad10 = getParamValue('pad', node, tensorMap, context);
88868 var _kernelSize3 = getParamValue('kernelSize', node, tensorMap, context);
88869 return [ops.avgPool3d(getParamValue('x', node, tensorMap, context), [_kernelSize3[1], _kernelSize3[2], _kernelSize3[3]], [_stride10[1], _stride10[2], _stride10[3]], _pad10)];
88870 }
88871 case 'MaxPool3D':
88872 {
88873 var _stride11 = getParamValue('strides', node, tensorMap, context);
88874 var _pad11 = getParamValue('pad', node, tensorMap, context);
88875 var _kernelSize4 = getParamValue('kernelSize', node, tensorMap, context);
88876 return [ops.maxPool3d(getParamValue('x', node, tensorMap, context), [_kernelSize4[1], _kernelSize4[2], _kernelSize4[3]], [_stride11[1], _stride11[2], _stride11[3]], _pad11)];
88877 }
88878 case 'Dilation2D':
88879 {
88880 var strides = getParamValue('strides', node, tensorMap, context);
88881 var _pad12 = getParamValue('pad', node, tensorMap, context);
88882 var _dilations5 = getParamValue('dilations', node, tensorMap, context);
88883 // strides: [1, stride_height, stride_width, 1].
88884 var strideHeight = strides[1];
88885 var strideWidth = strides[2];
88886 // dilations: [1, dilation_height, dilation_width, 1].
88887 var dilationHeight = _dilations5[1];
88888 var dilationWidth = _dilations5[2];
88889 return [ops.dilation2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [strideHeight, strideWidth], _pad12, [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)];
88890 }
88891
88892 default:
88893 throw TypeError("Node type ".concat(node.op, " is not implemented"));
88894 }
88895 };
// Category label paired with the convolution-op executor defined just above.
var CATEGORY$g = "convolution";
88897
88898 /**
88899 * @license
88900 * Copyright 2018 Google LLC. All Rights Reserved.
88901 * Licensed under the Apache License, Version 2.0 (the "License");
88902 * you may not use this file except in compliance with the License.
88903 * You may obtain a copy of the License at
88904 *
88905 * http://www.apache.org/licenses/LICENSE-2.0
88906 *
88907 * Unless required by applicable law or agreed to in writing, software
88908 * distributed under the License is distributed on an "AS IS" BASIS,
88909 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88910 * See the License for the specific language governing permissions and
88911 * limitations under the License.
88912 * =============================================================================
88913 */
/**
 * Executes the tensor-creation ops (fill, linspace, range, one-hot, the
 * random generators and the ones/zeros families) for a single graph node.
 *
 * @param node Graph node; `node.op` selects the operation and its inputs are
 *     resolved through `getParamValue`.
 * @param tensorMap Map of already-computed tensors used to resolve inputs.
 * @param context Execution context forwarded to `getParamValue`.
 * @param ops Optional 4th argument replacing the ops namespace; `tfOps` is
 *     used when it is omitted or undefined.
 * @returns A one-element array holding the created tensor.
 * @throws TypeError for an op this executor does not implement.
 */
var executeOp$g = function executeOp(node, tensorMap, context) {
  var ops = arguments[3] === undefined ? tfOps : arguments[3];
  // Resolves one named parameter of this node.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  var shape;
  var dtype;
  var mean;
  var stdDev;
  var seed;
  switch (node.op) {
    case 'Fill':
      shape = param('shape');
      dtype = param('dtype');
      // ops.fill takes (shape, value, dtype).
      return [ops.fill(shape, param('value'), dtype)];
    case 'LinSpace':
      return [ops.linspace(param('start'), param('stop'), param('num'))];
    case 'Multinomial':
      return [ops.multinomial(param('logits'), param('numSamples'), param('seed'))];
    case 'OneHot':
      return [ops.oneHot(param('indices'), param('depth'), param('onValue'), param('offValue'), param('dtype'))];
    case 'Ones':
      return [ops.ones(param('shape'), param('dtype'))];
    case 'OnesLike':
      return [ops.onesLike(param('x'))];
    case 'RandomStandardNormal':
      return [ops.randomStandardNormal(param('shape'), param('dtype'), param('seed'))];
    case 'RandomUniform':
      // tslint:disable-next-line:no-any
      return [ops.randomUniform(param('shape'), param('minval'), param('maxval'), param('dtype'))];
    case 'RandomUniformInt':
      return [ops.randomUniformInt(param('shape'), param('minval'), param('maxval'), param('seed'))];
    case 'Range':
      return [ops.range(param('start'), param('stop'), param('step'), param('dtype'))];
    case 'TruncatedNormal':
      shape = param('shape');
      mean = param('mean');
      stdDev = param('stdDev');
      seed = param('seed');
      // ops.truncatedNormal takes dtype before seed.
      return [ops.truncatedNormal(shape, mean, stdDev, param('dtype'), seed)];
    case 'Zeros':
      return [ops.zeros(param('shape'), param('dtype'))];
    case 'ZerosLike':
      return [ops.zerosLike(param('x'))];
    default:
      throw TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$f = 'creation';
88997
/**
 * Collects the non-max-suppression inputs of a node into a single record.
 *
 * @param node Graph node whose params are read via `getParamValue`.
 * @param tensorMap Map of already-computed tensors used to resolve inputs.
 * @param context Execution context forwarded to `getParamValue`.
 * @returns Object with `boxes`, `scores`, `maxOutputSize`, `iouThreshold`,
 *     `scoreThreshold` and `softNmsSigma`, fetched in that order.
 */
function nmsParams(node, tensorMap, context) {
  var fieldNames = ['boxes', 'scores', 'maxOutputSize', 'iouThreshold', 'scoreThreshold', 'softNmsSigma'];
  var record = {};
  for (var idx = 0; idx < fieldNames.length; idx++) {
    var field = fieldNames[idx];
    record[field] = getParamValue(field, node, tensorMap, context);
  }
  return record;
}
/**
 * Executes the "dynamic" ops (NonMaxSuppression V2-V5, Where, ListDiff), all
 * of which dispatch to asynchronous ops, so the result is always a Promise.
 *
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Map of already-computed tensors used to resolve inputs.
 * @param context Execution context forwarded to `getParamValue`/`nmsParams`.
 * @param resourceManager Not used by these ops; kept so the signature matches
 *     the other asynchronous executor in this file.
 * @param ops Optional 5th argument replacing the ops namespace; `tfOps` is
 *     used when it is omitted or undefined.
 * @returns Promise resolving to the array of output tensors. Any error,
 *     including the TypeError for an unimplemented op, rejects the Promise.
 */
var executeOp$f = function executeOp(node, tensorMap, context, resourceManager) {
  var args = arguments;
  // The Promise executor runs synchronously, so parameter lookups happen at
  // call time and any throw becomes a rejection (async-function semantics).
  return new Promise(function (resolve) {
    var ops = args.length > 4 && args[4] !== undefined ? args[4] : tfOps;
    var params;
    var padToMaxOutputSize;
    var condition;
    var pending;
    switch (node.op) {
      case 'NonMaxSuppressionV5':
        params = nmsParams(node, tensorMap, context);
        resolve(Promise.resolve(ops.image.nonMaxSuppressionWithScoreAsync(params.boxes, params.scores, params.maxOutputSize, params.iouThreshold, params.scoreThreshold, params.softNmsSigma)).then(function (result) {
          return [result.selectedIndices, result.selectedScores];
        }));
        return;
      case 'NonMaxSuppressionV4':
        params = nmsParams(node, tensorMap, context);
        padToMaxOutputSize = getParamValue('padToMaxOutputSize', node, tensorMap, context);
        resolve(Promise.resolve(ops.image.nonMaxSuppressionPaddedAsync(params.boxes, params.scores, params.maxOutputSize, params.iouThreshold, params.scoreThreshold, padToMaxOutputSize)).then(function (result) {
          return [result.selectedIndices, result.validOutputs];
        }));
        return;
      case 'NonMaxSuppressionV3':
      case 'NonMaxSuppressionV2':
        params = nmsParams(node, tensorMap, context);
        resolve(Promise.resolve(ops.image.nonMaxSuppressionAsync(params.boxes, params.scores, params.maxOutputSize, params.iouThreshold, params.scoreThreshold)).then(function (selected) {
          return [selected];
        }));
        return;
      case 'Where':
        // Temporary bool tensor owned by this op; it must be released whether
        // whereAsync succeeds, rejects, or throws synchronously.
        condition = ops.cast(getParamValue('condition', node, tensorMap, context), 'bool');
        try {
          pending = Promise.resolve(ops.whereAsync(condition));
        } catch (err) {
          condition.dispose();
          throw err;
        }
        resolve(pending.then(function (coords) {
          condition.dispose();
          return [coords];
        }, function (err) {
          condition.dispose();
          throw err;
        }));
        return;
      case 'ListDiff':
        resolve(ops.setdiff1dAsync(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context)));
        return;
      default:
        throw TypeError("Node type ".concat(node.op, " is not implemented"));
    }
  });
};
var CATEGORY$e = 'dynamic';
89095
89096 /**
89097 * @license
89098 * Copyright 2018 Google LLC. All Rights Reserved.
89099 * Licensed under the Apache License, Version 2.0 (the "License");
89100 * you may not use this file except in compliance with the License.
89101 * You may obtain a copy of the License at
89102 *
89103 * http://www.apache.org/licenses/LICENSE-2.0
89104 *
89105 * Unless required by applicable law or agreed to in writing, software
89106 * distributed under the License is distributed on an "AS IS" BASIS,
89107 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89108 * See the License for the specific language governing permissions and
89109 * limitations under the License.
89110 * =============================================================================
89111 */
/**
 * Executes the evaluation ops (sorted-sequence bounds, top-k, unique).
 *
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Map of already-computed tensors used to resolve inputs.
 * @param context Execution context forwarded to `getParamValue`.
 * @param ops Optional 4th argument replacing the ops namespace; `tfOps` is
 *     used when it is omitted or undefined.
 * @returns `[result]` for the bound ops, `[values, indices]` for TopKV2 and
 *     the Unique variants.
 * @throws TypeError for an op this executor does not implement.
 */
var executeOp$e = function executeOp(node, tensorMap, context) {
  var ops = arguments[3] === undefined ? tfOps : arguments[3];
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  // Ops below that yield a {values, indices} record store it here.
  var pair;
  switch (node.op) {
    case 'LowerBound':
      return [ops.lowerBound(param('sortedSequence'), param('values'))];
    case 'UpperBound':
      return [ops.upperBound(param('sortedSequence'), param('values'))];
    case 'TopKV2':
      pair = ops.topk(param('x'), param('k'), param('sorted'));
      break;
    case 'Unique':
      pair = ops.unique(param('x'));
      break;
    case 'UniqueV2':
      pair = ops.unique(param('x'), param('axis'));
      break;
    default:
      throw TypeError("Node type ".concat(node.op, " is not implemented"));
  }
  return [pair.values, pair.indices];
};
var CATEGORY$d = 'evaluation';
89153
89154 /**
89155 * @license
89156 * Copyright 2018 Google LLC. All Rights Reserved.
89157 * Licensed under the Apache License, Version 2.0 (the "License");
89158 * you may not use this file except in compliance with the License.
89159 * You may obtain a copy of the License at
89160 *
89161 * http://www.apache.org/licenses/LICENSE-2.0
89162 *
89163 * Unless required by applicable law or agreed to in writing, software
89164 * distributed under the License is distributed on an "AS IS" BASIS,
89165 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89166 * See the License for the specific language governing permissions and
89167 * limitations under the License.
89168 * =============================================================================
89169 */
/**
 * Executes the graph-structure ops: constants, placeholders, identity-style
 * pass-throughs, shape/size/rank queries, NoOp and Print.
 *
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Map of computed tensors keyed by node name.
 * @param context Execution context forwarded to `getParamValue`/`getTensor`.
 * @param ops Optional 4th argument replacing the ops namespace; `tfOps` is
 *     used when it is omitted or undefined.
 * @returns The node's outputs. `Const` returns the stored
 *     `tensorMap[node.name]` entry itself rather than a new array.
 * @throws TypeError for an op this executor does not implement.
 */
var executeOp$d = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  switch (node.op) {
    case 'Const':
      return tensorMap[node.name];
    case 'PlaceholderWithDefault': {
      var def = getParamValue('default', node, tensorMap, context);
      // Use the default only when getTensor yields nothing (a falsy value).
      return [getTensor(node.name, tensorMap, context) || def];
    }
    case 'Placeholder':
      return [getTensor(node.name, tensorMap, context)];
    case 'Identity':
    case 'StopGradient':
    case 'FakeQuantWithMinMaxVars':
      // These ops pass a clone of their input through; any op-specific
      // semantics (e.g. fake quantization) are currently ignored.
      return [cloneTensor(getParamValue('x', node, tensorMap, context))];
    case 'IdentityN':
      return getParamValue('x', node, tensorMap, context).map(function (t) {
        return cloneTensor(t);
      });
    case 'Snapshot':
      return [cloneTensor(getParamValue('x', node, tensorMap, context))];
    case 'Shape':
      return [ops.tensor1d(getParamValue('x', node, tensorMap, context).shape, 'int32')];
    case 'ShapeN':
      // Emit int32 shape vectors, matching 'Shape'; without an explicit dtype
      // tensor1d would infer float32 from the plain number array.
      return getParamValue('x', node, tensorMap, context).map(function (t) {
        return ops.tensor1d(t.shape, 'int32');
      });
    case 'Size':
      return [ops.scalar(getParamValue('x', node, tensorMap, context).size, 'int32')];
    case 'Rank':
      return [ops.scalar(getParamValue('x', node, tensorMap, context).rank, 'int32')];
    case 'NoOp':
      return [ops.scalar(1)];
    case 'Print': {
      var input = getParamValue('x', node, tensorMap, context);
      var data = getParamValue('data', node, tensorMap, context);
      var message = getParamValue('message', node, tensorMap, context);
      var summarize = getParamValue('summarize', node, tensorMap, context);
      console.warn('The graph has a tf.print() operation, ' + 'usually used for debugging, which slows down performance.');
      console.log(message);
      // Log at most `summarize` leading values of each data tensor; dataSync()
      // downloads the values synchronously.
      for (var i = 0; i < data.length; i++) {
        console.log(Array.prototype.slice.call(data[i].dataSync()).slice(0, summarize));
      }
      // Print is a pass-through for its `x` input.
      return [input];
    }
    default:
      throw TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$c = 'graph';
89225
/**
 * Hashtable contains a set of tensors, which can be accessed by key.
 *
 * Keys are stored by their primitive values (read with `keys.data()`), so a
 * lookup is a single `Map` access; each value is stored as its own kept tensor.
 */
var HashTable = /*#__PURE__*/function () {
  /**
   * Constructor of HashTable. Creates a hash table.
   *
   * @param keyDType `dtype` of the table keys.
   * @param valueDType `dtype` of the table values.
   */
  function HashTable(keyDType, valueDType) {
    _classCallCheck(this, HashTable);
    this.keyDType = keyDType;
    this.valueDType = valueDType;
    // A kept scalar whose tensor id doubles as this table's id (see `id`).
    this.handle = scalar(0);
    // tslint:disable-next-line: no-any
    this.tensorMap = new Map();
    keep(this.handle);
  }
  _createClass(HashTable, [{
    key: "id",
    get: function get() {
      return this.handle.id;
    }
  }, {
    /**
     * Dispose the tensors and handle and clear the hashtable.
     */
    key: "clearAndClose",
    value: function clearAndClose() {
      this.tensorMap.forEach(function (value) {
        return value.dispose();
      });
      this.tensorMap.clear();
      this.handle.dispose();
    }
  }, {
    /**
     * The number of items in the hash table.
     */
    key: "size",
    value: function size() {
      return this.tensorMap.size;
    }
  }, {
    /**
     * The number of items in the hash table as a rank-0 int32 tensor.
     */
    key: "tensorSize",
    value: function tensorSize() {
      return scalar(this.size(), 'int32');
    }
  }, {
    /**
     * Replaces the contents of the table with the specified keys and values.
     *
     * The key/value counts are validated before the existing entries are
     * disposed, so a failed import leaves the current contents intact.
     *
     * @param keys Keys to store in the hashtable.
     * @param values Values to store in the hashtable; split with `unstack`,
     *     one slice per key.
     * @returns Promise resolving to the table handle.
     */
    key: "import",
    value: function _import(keys, values) {
      var _this = this;
      // The executor runs synchronously, so dtype validation and the key
      // download start at call time; any throw becomes a rejection.
      return new Promise(function (resolve) {
        _this.checkKeyAndValueTensor(keys, values);
        // We only store the primitive values of the keys, this allows lookup
        // to be O(1).
        resolve(keys.data());
      }).then(function ($keys) {
        return tidy(function () {
          var $values = unstack(values);
          var keysLength = $keys.length;
          var valuesLength = $values.length;
          assert$1(keysLength === valuesLength, function () {
            return "The number of elements doesn't match, keys has " + "".concat(keysLength, " elements, the values has ").concat(valuesLength, " ") + "elements.";
          });
          // Clear the hashTable only once the new contents are validated.
          _this.tensorMap.forEach(function (value) {
            return value.dispose();
          });
          _this.tensorMap.clear();
          for (var i = 0; i < keysLength; i++) {
            var value = $values[i];
            // Keep each slice so it outlives the surrounding tidy().
            keep(value);
            _this.tensorMap.set($keys[i], value);
          }
          return _this.handle;
        });
      });
    }
  }, {
    /**
     * Looks up keys in a hash table, outputs the corresponding values.
     *
     * Performs batch lookups, for every element in the key tensor, `find`
     * stacks the corresponding value into the return tensor.
     *
     * If an element is not present in the table, the given `defaultValue` is
     * used.
     *
     * @param keys Keys to look up. Must have the same type as the keys of the
     *     table.
     * @param defaultValue The scalar `defaultValue` is the value output for keys
     *     not present in the table. It must also be of the same type as the
     *     table values.
     * @returns Promise resolving to the stacked values.
     */
    key: "find",
    value: function find(keys, defaultValue) {
      var _this2 = this;
      return new Promise(function (resolve) {
        _this2.checkKeyAndValueTensor(keys, defaultValue);
        resolve(keys.data());
      }).then(function ($keys) {
        return tidy(function () {
          var result = [];
          for (var i = 0; i < $keys.length; i++) {
            result.push(_this2.findWithDefault($keys[i], defaultValue));
          }
          return stack(result);
        });
      });
    }
  }, {
    // tslint:disable-next-line: no-any
    key: "findWithDefault",
    value: function findWithDefault(key, defaultValue) {
      var result = this.tensorMap.get(key);
      return result != null ? result : defaultValue;
    }
  }, {
    /**
     * Throws unless `key` and `value` carry the table's key/value dtypes.
     */
    key: "checkKeyAndValueTensor",
    value: function checkKeyAndValueTensor(key, value) {
      if (key.dtype !== this.keyDType) {
        throw new Error("Expect key dtype ".concat(this.keyDType, ", but got ") + "".concat(key.dtype));
      }
      if (value.dtype !== this.valueDType) {
        throw new Error("Expect value dtype ".concat(this.valueDType, ", but got ") + "".concat(value.dtype));
      }
    }
  }]);
  return HashTable;
}();
89398
/**
 * Executes the hash-table ops: table creation (reusing a handle already
 * registered under the node name), import/initialize, lookup and size.
 *
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Map of already-computed tensors used to resolve inputs.
 * @param context Execution context forwarded to `getParamValue`.
 * @param resourceManager Registry the tables are stored in and fetched from.
 * @returns Promise resolving to a one-element output array. Any error,
 *     including the TypeError for an unimplemented op, rejects the Promise.
 */
var executeOp$c = function executeOp(node, tensorMap, context, resourceManager) {
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  var readTableHandle = function () {
    return getParamValue('tableHandle', node, tensorMap, context, resourceManager);
  };
  var wrapSingle = function (output) {
    return [output];
  };
  // The executor runs synchronously, so everything up to the first async
  // table call happens at call time and any throw becomes a rejection.
  return new Promise(function (resolve) {
    var handle;
    var keys;
    var values;
    var defaultValue;
    var table;
    var existingHandle;
    var created;
    switch (node.op) {
      case 'HashTable':
      case 'HashTableV2':
        // Table is shared with initializer.
        existingHandle = resourceManager.getHashTableHandleByName(node.name);
        if (existingHandle != null) {
          resolve([existingHandle]);
          return;
        }
        created = new HashTable(param('keyDType'), param('valueDType'));
        resourceManager.addHashTable(node.name, created);
        resolve([created.handle]);
        return;
      case 'InitializeTable':
      case 'InitializeTableV2':
      case 'LookupTableImport':
      case 'LookupTableImportV2':
        handle = readTableHandle();
        keys = param('keys');
        values = param('values');
        table = resourceManager.getHashTableById(handle.id);
        resolve(Promise.resolve(table.import(keys, values)).then(wrapSingle));
        return;
      case 'LookupTableFind':
      case 'LookupTableFindV2':
        handle = readTableHandle();
        keys = param('keys');
        defaultValue = param('defaultValue');
        table = resourceManager.getHashTableById(handle.id);
        resolve(Promise.resolve(table.find(keys, defaultValue)).then(wrapSingle));
        return;
      case 'LookupTableSize':
      case 'LookupTableSizeV2':
        handle = readTableHandle();
        table = resourceManager.getHashTableById(handle.id);
        resolve([table.tensorSize()]);
        return;
      default:
        throw TypeError("Node type ".concat(node.op, " is not implemented"));
    }
  });
};
var CATEGORY$b = 'hash_table';
89458
89459 /**
89460 * @license
89461 * Copyright 2018 Google LLC. All Rights Reserved.
89462 * Licensed under the Apache License, Version 2.0 (the "License");
89463 * you may not use this file except in compliance with the License.
89464 * You may obtain a copy of the License at
89465 *
89466 * http://www.apache.org/licenses/LICENSE-2.0
89467 *
89468 * Unless required by applicable law or agreed to in writing, software
89469 * distributed under the License is distributed on an "AS IS" BASIS,
89470 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89471 * See the License for the specific language governing permissions and
89472 * limitations under the License.
89473 * =============================================================================
89474 */
/**
 * Executes the image ops (bilinear/nearest-neighbor resize, crop-and-resize,
 * projective transform) through `ops.image`.
 *
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Map of already-computed tensors used to resolve inputs.
 * @param context Execution context forwarded to `getParamValue`.
 * @param ops Optional 4th argument replacing the ops namespace; `tfOps` is
 *     used when it is omitted or undefined.
 * @returns A one-element array holding the resulting image tensor.
 * @throws TypeError for an op this executor does not implement.
 */
var executeOp$b = function executeOp(node, tensorMap, context) {
  var ops = arguments[3] === undefined ? tfOps : arguments[3];
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  var images;
  var size;
  var alignCorners;
  var halfPixelCenters;
  var resizeMethod;
  switch (node.op) {
    case 'ResizeBilinear':
    case 'ResizeNearestNeighbor':
      images = param('images');
      size = param('size');
      alignCorners = param('alignCorners');
      halfPixelCenters = param('halfPixelCenters');
      resizeMethod = node.op === 'ResizeBilinear' ? 'resizeBilinear' : 'resizeNearestNeighbor';
      // Only the first two entries of `size` are forwarded.
      return [ops.image[resizeMethod](images, [size[0], size[1]], alignCorners, halfPixelCenters)];
    case 'CropAndResize': {
      var source = param('image');
      var boxes = param('boxes');
      var boxInd = param('boxInd');
      var cropSize = param('cropSize');
      var cropMethod = param('method');
      var extrapolationValue = param('extrapolationValue');
      return [ops.image.cropAndResize(source, boxes, boxInd, cropSize, cropMethod, extrapolationValue)];
    }
    case 'ImageProjectiveTransformV3': {
      var input = param('images');
      var transforms = param('transforms');
      var outputShape = param('outputShape');
      var fillValue = param('fillValue');
      var interpolation = param('interpolation');
      var fillMode = param('fillMode');
      // ops.image.transform expects lower-case interpolation/fill-mode names.
      return [ops.image.transform(input, transforms, interpolation.toLowerCase(), fillMode.toLowerCase(), fillValue, outputShape)];
    }
    default:
      throw TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$a = 'image';
89519
89520 /**
89521 * @license
89522 * Copyright 2018 Google LLC. All Rights Reserved.
89523 * Licensed under the Apache License, Version 2.0 (the "License");
89524 * you may not use this file except in compliance with the License.
89525 * You may obtain a copy of the License at
89526 *
89527 * http://www.apache.org/licenses/LICENSE-2.0
89528 *
89529 * Unless required by applicable law or agreed to in writing, software
89530 * distributed under the License is distributed on an "AS IS" BASIS,
89531 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89532 * See the License for the specific language governing permissions and
89533 * limitations under the License.
89534 * =============================================================================
89535 */
/**
 * Executes the logical/comparison ops (equality and ordering comparisons,
 * logical and/or/not, bitwise and, and Select/SelectV2 via `ops.where`).
 *
 * @param node Graph node; `node.op` selects the operation.
 * @param tensorMap Map of already-computed tensors used to resolve inputs.
 * @param context Execution context forwarded to `getParamValue`.
 * @param ops Optional 4th argument replacing the ops namespace; `tfOps` is
 *     used when it is omitted or undefined.
 * @returns A one-element array holding the result tensor.
 * @throws TypeError for an op this executor does not implement.
 */
var executeOp$a = function () {
  // Two-input ops that read params 'a' and 'b': graph op name -> ops method.
  var BINARY_OPS = {
    Equal: 'equal',
    NotEqual: 'notEqual',
    Greater: 'greater',
    GreaterEqual: 'greaterEqual',
    Less: 'less',
    LessEqual: 'lessEqual',
    LogicalAnd: 'logicalAnd',
    LogicalOr: 'logicalOr',
    BitwiseAnd: 'bitwiseAnd'
  };
  // Own-property check so names such as 'toString' never match the table.
  var hasOwn = Object.prototype.hasOwnProperty;
  return function executeOp(node, tensorMap, context) {
    var ops = arguments[3] === undefined ? tfOps : arguments[3];
    var param = function (name) {
      return getParamValue(name, node, tensorMap, context);
    };
    if (hasOwn.call(BINARY_OPS, node.op)) {
      var methodName = BINARY_OPS[node.op];
      return [ops[methodName](param('a'), param('b'))];
    }
    switch (node.op) {
      case 'LogicalNot':
        return [ops.logicalNot(param('a'))];
      case 'Select':
      case 'SelectV2':
        return [ops.where(param('condition'), param('a'), param('b'))];
      default:
        throw TypeError("Node type ".concat(node.op, " is not implemented"));
    }
  };
}();
var CATEGORY$9 = 'logical';
89589
/**
 * Executor for the 'matrices' op category: MatMul (plain and batched),
 * Einsum, Transpose, the fused _FusedMatMul kernel and MatrixBandPart.
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns A one-element array holding the op's output.
 * @throws Error if a _FusedMatMul whose extra op is 'biasadd' was given the
 *     wrong number of extra arguments; TypeError if `node.op` is not an op
 *     of this category.
 */
var executeOp$9 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'BatchMatMul':
    case 'BatchMatMulV2':
    case 'MatMul':
      return [ops.matMul(param('a'), param('b'), param('transposeA'), param('transposeB'))];
    case 'Einsum':
      // einsum is variadic: the equation string followed by every input.
      return [ops.einsum.apply(ops, [param('equation')].concat(_toConsumableArray(param('tensors'))))];
    case 'Transpose':
      return [ops.transpose(param('x'), param('perm'))];
    case '_FusedMatMul': {
      // Only the first two fusedOps entries matter: the extra op (checked
      // against 'biasadd') and the activation name.
      var fused = _slicedToArray(param('fusedOps'), 2);
      var extraOp = fused[0];
      var activation = fused[1];
      var numArgs = param('numArgs');
      var leakyreluAlpha = param('leakyreluAlpha');
      if (extraOp === 'biasadd') {
        // BiasAdd needs the bias; a PReLU activation additionally needs alpha.
        var withPrelu = activation === 'prelu';
        if (withPrelu && numArgs !== 2) {
          throw new Error('Fused MatMul with BiasAdd and Prelu must have two extra arguments: bias and alpha.');
        }
        if (!withPrelu && numArgs !== 1) {
          throw new Error('Fused MatMul with BiasAdd must have one extra argument: bias.');
        }
      }
      // Extra inputs are positional: bias first, then the PReLU weights.
      var extraArgs = _slicedToArray(param('args'), 2);
      var bias = extraArgs[0];
      var preluWeights = extraArgs[1];
      return [ops.fused.matMul({
        a: param('a'),
        b: param('b'),
        transposeA: param('transposeA'),
        transposeB: param('transposeB'),
        bias: bias,
        activation: activation,
        preluActivationWeights: preluWeights,
        leakyreluAlpha: leakyreluAlpha
      })];
    }
    case 'MatrixBandPart':
      return [ops.linalg.bandPart(param('a'), param('numLower'), param('numUpper'))];
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$8 = 'matrices';
89639
89640 /**
89641 * @license
89642 * Copyright 2018 Google LLC. All Rights Reserved.
89643 * Licensed under the Apache License, Version 2.0 (the "License");
89644 * you may not use this file except in compliance with the License.
89645 * You may obtain a copy of the License at
89646 *
89647 * http://www.apache.org/licenses/LICENSE-2.0
89648 *
89649 * Unless required by applicable law or agreed to in writing, software
89650 * distributed under the License is distributed on an "AS IS" BASIS,
89651 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89652 * See the License for the specific language governing permissions and
89653 * limitations under the License.
89654 * =============================================================================
89655 */
/**
 * Executor for the 'normalization' op category: EuclideanNorm, the
 * FusedBatchNorm family, LRN, Softmax and LogSoftmax.
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns A one-element array holding the op's output.
 * @throws TypeError if `node.op` is not an op of this category.
 */
var executeOp$8 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'EuclideanNorm':
      return [ops.euclideanNorm(param('x'), param('axis'), param('keepDims'))];
    case 'FusedBatchNorm':
    case 'FusedBatchNormV2':
    case 'FusedBatchNormV3':
      // All three versions map onto the same batchNorm call, and only that
      // single result is returned.
      return [ops.batchNorm(
        param('x'), param('mean'), param('variance'),
        param('offset'), param('scale'), param('epsilon'))];
    case 'LRN':
      return [ops.localResponseNormalization(
        param('x'), param('radius'), param('bias'), param('alpha'), param('beta'))];
    case 'Softmax':
      return [ops.softmax(param('x'))];
    case 'LogSoftmax':
      return [ops.logSoftmax(param('x'))];
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$7 = 'normalization';
89687
89688 /**
89689 * @license
89690 * Copyright 2022 Google LLC. All Rights Reserved.
89691 * Licensed under the Apache License, Version 2.0 (the "License");
89692 * you may not use this file except in compliance with the License.
89693 * You may obtain a copy of the License at
89694 *
89695 * http://www.apache.org/licenses/LICENSE-2.0
89696 *
89697 * Unless required by applicable law or agreed to in writing, software
89698 * distributed under the License is distributed on an "AS IS" BASIS,
89699 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89700 * See the License for the specific language governing permissions and
89701 * limitations under the License.
89702 * =============================================================================
89703 */
/**
 * Executor for the 'ragged' op category: RaggedGather, RaggedRange and
 * RaggedTensorToTensor.
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns RaggedGather: the output nested splits followed by the dense
 *     values; RaggedRange: [nestedSplits, denseValues];
 *     RaggedTensorToTensor: a one-element array.
 * @throws TypeError if `node.op` is not an op of this category.
 */
var executeOp$7 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'RaggedGather': {
      var gathered = ops.raggedGather(
        param('paramsNestedSplits'), param('paramsDenseValues'),
        param('indices'), param('outputRaggedRank'));
      var nestedSplits = gathered.outputNestedSplits;
      var denseValues = gathered.outputDenseValues;
      // Flatten into one output list: every splits entry, then the values.
      return nestedSplits.concat(denseValues);
    }
    case 'RaggedRange': {
      var ranged = ops.raggedRange(param('starts'), param('limits'), param('splits'));
      var rangeSplits = ranged.rtNestedSplits;
      var rangeValues = ranged.rtDenseValues;
      return [rangeSplits, rangeValues];
    }
    case 'RaggedTensorToTensor':
      return [ops.raggedTensorToTensor(
        param('shape'), param('values'), param('defaultValue'),
        param('rowPartitionTensors'), param('rowPartitionTypes'))];
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$6 = 'ragged';
89730
89731 /**
89732 * @license
89733 * Copyright 2018 Google LLC. All Rights Reserved.
89734 * Licensed under the Apache License, Version 2.0 (the "License");
89735 * you may not use this file except in compliance with the License.
89736 * You may obtain a copy of the License at
89737 *
89738 * http://www.apache.org/licenses/LICENSE-2.0
89739 *
89740 * Unless required by applicable law or agreed to in writing, software
89741 * distributed under the License is distributed on an "AS IS" BASIS,
89742 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89743 * See the License for the specific language governing permissions and
89744 * limitations under the License.
89745 * =============================================================================
89746 */
/**
 * Executor for the 'reduction' op category: Max/Mean/Min/Sum/All/Any/Prod,
 * ArgMax/ArgMin, Cumprod/Cumsum, Bincount and DenseBincount.
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns A one-element array holding the op's output.
 * @throws TypeError if `node.op` is not an op of this category.
 */
var executeOp$6 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  // Calls ops[name](x, axis, keepDims); axis and keepDims are read before x.
  var reduceKeepingDims = function (name) {
    var axis = param('axis');
    var keepDims = param('keepDims');
    return [ops[name](param('x'), axis, keepDims)];
  };
  // Calls ops[name](x, axis, exclusive, reverse) for the cumulative ops.
  var accumulate = function (name) {
    var axis = param('axis');
    var exclusive = param('exclusive');
    var reverse = param('reverse');
    return [ops[name](param('x'), axis, exclusive, reverse)];
  };
  switch (node.op) {
    case 'Max':
      return reduceKeepingDims('max');
    case 'Mean':
      return reduceKeepingDims('mean');
    case 'Min':
      return reduceKeepingDims('min');
    case 'Sum':
      return reduceKeepingDims('sum');
    case 'All':
      return reduceKeepingDims('all');
    case 'Any':
      return reduceKeepingDims('any');
    case 'ArgMax': {
      var argMaxAxis = param('axis');
      return [ops.argMax(param('x'), argMaxAxis)];
    }
    case 'ArgMin': {
      var argMinAxis = param('axis');
      return [ops.argMin(param('x'), argMinAxis)];
    }
    case 'Prod':
      return reduceKeepingDims('prod');
    case 'Cumprod':
      return accumulate('cumprod');
    case 'Cumsum':
      return accumulate('cumsum');
    case 'Bincount': {
      var values = param('x');
      var weights = param('weights');
      var size = param('size');
      return [ops.bincount(values, weights, size)];
    }
    case 'DenseBincount': {
      var denseValues = param('x');
      var denseWeights = param('weights');
      var denseSize = param('size');
      var binaryOutput = param('binaryOutput');
      return [ops.denseBincount(denseValues, denseWeights, denseSize, binaryOutput)];
    }
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$5 = 'reduction';
89834
89835 /**
89836 * @license
89837 * Copyright 2018 Google LLC. All Rights Reserved.
89838 * Licensed under the Apache License, Version 2.0 (the "License");
89839 * you may not use this file except in compliance with the License.
89840 * You may obtain a copy of the License at
89841 *
89842 * http://www.apache.org/licenses/LICENSE-2.0
89843 *
89844 * Unless required by applicable law or agreed to in writing, software
89845 * distributed under the License is distributed on an "AS IS" BASIS,
89846 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89847 * See the License for the specific language governing permissions and
89848 * limitations under the License.
89849 * =============================================================================
89850 */
/**
 * Executor for the 'slice_join' op category: concatenation, gathers,
 * reversal, slicing, Pack/Unpack, Tile, Split, scatter and SparseToDense.
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns Usually a one-element array. Unpack and Split/SplitV return the
 *     op's result unwrapped (presumably already a list of tensors).
 * @throws Error from Pack when an input's shape cannot be aligned with the
 *     first input; TypeError if `node.op` is not an op of this category.
 */
var executeOp$5 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'ConcatV2':
    case 'Concat': {
      var count = param('n');
      var concatAxis = param('axis');
      // Only the first `n` listed tensors take part in the concatenation.
      var concatInputs = param('tensors').slice(0, count);
      return [ops.concat(concatInputs, concatAxis)];
    }
    case 'Gather': {
      var gatherSource = param('x');
      var gatherIndices = param('indices');
      return [ops.gather(gatherSource, ops.cast(gatherIndices, 'int32'), 0)];
    }
    case 'GatherV2': {
      var gatherAxis = param('axis');
      var batchDims = param('batchDims');
      var gatherV2Source = param('x');
      var gatherV2Indices = param('indices');
      return [ops.gather(gatherV2Source, ops.cast(gatherV2Indices, 'int32'), gatherAxis, batchDims)];
    }
    case 'Reverse': {
      // Truthy entries of `dims` mark the dimensions to flip.
      var dimsMask = param('dims');
      var flipAxes = [];
      for (var d = 0; d < dimsMask.length; d += 1) {
        if (dimsMask[d]) {
          flipAxes.push(d);
        }
      }
      return [ops.reverse(param('x'), flipAxes)];
    }
    case 'ReverseV2': {
      var reverseAxes = param('axis');
      return [ops.reverse(param('x'), reverseAxes)];
    }
    case 'Slice': {
      var sliceBegin = param('begin');
      var sliceSize = param('size');
      return [ops.slice(param('x'), sliceBegin, sliceSize)];
    }
    case 'StridedSlice': {
      var begin = param('begin');
      var end = param('end');
      var strides = param('strides');
      var beginMask = param('beginMask');
      var endMask = param('endMask');
      var ellipsisMask = param('ellipsisMask');
      var newAxisMask = param('newAxisMask');
      var shrinkAxisMask = param('shrinkAxisMask');
      var sliced = param('x');
      return [ops.stridedSlice(sliced, begin, end, strides, beginMask, endMask,
        ellipsisMask, newAxisMask, shrinkAxisMask)];
    }
    case 'Pack':
      return tidy(function () {
        var packAxis = param('axis');
        var packInputs = param('tensors');
        // Inputs whose shape differs from the first input's are reshaped to
        // it, provided their ops.squeeze shapes agree; otherwise fail.
        var targetShape = packInputs[0].shape;
        var targetSqueezed = ops.squeeze(packInputs[0]).shape;
        var aligned = packInputs.map(function (candidate) {
          if (arraysEqual(candidate.shape, targetShape)) {
            return candidate;
          }
          if (!arraysEqual(ops.squeeze(candidate).shape, targetSqueezed)) {
            throw new Error('the input tensors shape does not match');
          }
          return ops.reshape(candidate, targetShape);
        });
        return [ops.stack(aligned, packAxis)];
      });
    case 'Unpack': {
      var unpackAxis = param('axis');
      var unpackSource = param('tensor');
      return ops.unstack(unpackSource, unpackAxis);
    }
    case 'Tile': {
      var reps = param('reps');
      return [ops.tile(param('x'), reps)];
    }
    case 'Split':
    case 'SplitV': {
      var splitAxis = param('axis');
      var numOrSizeSplits = param('numOrSizeSplits');
      var splitSource = param('x');
      return ops.split(splitSource, numOrSizeSplits, splitAxis);
    }
    case 'ScatterNd': {
      var scatterIndices = param('indices');
      var scatterUpdates = param('values');
      var scatterShape = param('shape');
      return [ops.scatterND(scatterIndices, scatterUpdates, scatterShape)];
    }
    case 'GatherNd': {
      var gatherNdSource = param('x');
      var gatherNdIndices = param('indices');
      return [ops.gatherND(gatherNdSource, gatherNdIndices)];
    }
    case 'SparseToDense': {
      var denseIndices = param('sparseIndices');
      var denseShape = param('outputShape');
      var sparseValues = param('sparseValues');
      var defaultValue = param('defaultValue');
      // The fill value is cast to the sparse values' dtype when they differ.
      var fillValue = sparseValues.dtype === defaultValue.dtype ?
        defaultValue :
        ops.cast(defaultValue, sparseValues.dtype);
      return [ops.sparseToDense(denseIndices, sparseValues, denseShape, fillValue)];
    }
    case 'TensorScatterUpdate': {
      var updateIndices = param('indices');
      var updateValues = param('values');
      var updateTarget = param('tensor');
      return [ops.tensorScatterUpdate(updateTarget, updateIndices, updateValues)];
    }
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$4 = 'slice_join';
89987
89988 /**
89989 * @license
89990 * Copyright 2021 Google LLC. All Rights Reserved.
89991 * Licensed under the Apache License, Version 2.0 (the "License");
89992 * you may not use this file except in compliance with the License.
89993 * You may obtain a copy of the License at
89994 *
89995 * http://www.apache.org/licenses/LICENSE-2.0
89996 *
89997 * Unless required by applicable law or agreed to in writing, software
89998 * distributed under the License is distributed on an "AS IS" BASIS,
89999 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90000 * See the License for the specific language governing permissions and
90001 * limitations under the License.
90002 * =============================================================================
90003 */
/**
 * Executor for the 'sparse' op category: SparseFillEmptyRows, SparseReshape,
 * SparseSegmentMean and SparseSegmentSum (all through `ops.sparse`).
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns SparseFillEmptyRows: [outputIndices, outputValues,
 *     emptyRowIndicator, reverseIndexMap]; SparseReshape: [outputIndices,
 *     outputShape]; the segment reductions: a one-element array.
 * @throws TypeError if `node.op` is not an op of this category.
 */
var executeOp$4 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'SparseFillEmptyRows': {
      var filled = ops.sparse.sparseFillEmptyRows(
        param('indices'), param('values'), param('denseShape'), param('defaultValue'));
      return [filled.outputIndices, filled.outputValues, filled.emptyRowIndicator, filled.reverseIndexMap];
    }
    case 'SparseReshape': {
      var reshaped = ops.sparse.sparseReshape(
        param('inputIndices'), param('inputShape'), param('newShape'));
      return [reshaped.outputIndices, reshaped.outputShape];
    }
    case 'SparseSegmentMean':
      return [ops.sparse.sparseSegmentMean(param('data'), param('indices'), param('segmentIds'))];
    case 'SparseSegmentSum':
      return [ops.sparse.sparseSegmentSum(param('data'), param('indices'), param('segmentIds'))];
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$3 = 'sparse';
90038
90039 /**
90040 * @license
90041 * Copyright 2018 Google LLC. All Rights Reserved.
90042 * Licensed under the Apache License, Version 2.0 (the "License");
90043 * you may not use this file except in compliance with the License.
90044 * You may obtain a copy of the License at
90045 *
90046 * http://www.apache.org/licenses/LICENSE-2.0
90047 *
90048 * Unless required by applicable law or agreed to in writing, software
90049 * distributed under the License is distributed on an "AS IS" BASIS,
90050 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90051 * See the License for the specific language governing permissions and
90052 * limitations under the License.
90053 * =============================================================================
90054 */
/**
 * Executor for the 'spectral' op category: FFT, IFFT, RFFT and IRFFT, each
 * applied to the node's `x` parameter.
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns A one-element array holding the op's output.
 * @throws TypeError if `node.op` is not an op of this category.
 */
var executeOp$3 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'FFT':
      return [ops.fft(param('x'))];
    case 'IFFT':
      return [ops.ifft(param('x'))];
    case 'RFFT':
      return [ops.rfft(param('x'))];
    case 'IRFFT':
      return [ops.irfft(param('x'))];
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$2 = 'spectral';
90079
90080 /**
90081 * @license
90082 * Copyright 2021 Google LLC. All Rights Reserved.
90083 * Licensed under the Apache License, Version 2.0 (the "License");
90084 * you may not use this file except in compliance with the License.
90085 * You may obtain a copy of the License at
90086 *
90087 * http://www.apache.org/licenses/LICENSE-2.0
90088 *
90089 * Unless required by applicable law or agreed to in writing, software
90090 * distributed under the License is distributed on an "AS IS" BASIS,
90091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90092 * See the License for the specific language governing permissions and
90093 * limitations under the License.
90094 * =============================================================================
90095 */
/**
 * Executor for the 'string' op category: StaticRegexReplace, StringNGrams,
 * StringSplit and StringToHashBucketFast (all through `ops.string`).
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns StringNGrams: [nGrams, nGramsSplits]; StringSplit: [indices,
 *     values, shape]; the other ops: a one-element array.
 * @throws TypeError if `node.op` is not an op of this category.
 */
var executeOp$2 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'StaticRegexReplace':
      return [ops.string.staticRegexReplace(
        param('input'), param('pattern'), param('rewrite'), param('replaceGlobal'))];
    case 'StringNGrams': {
      var ngrams = ops.string.stringNGrams(
        param('data'), param('dataSplits'), param('separator'), param('nGramWidths'),
        param('leftPad'), param('rightPad'), param('padWidth'), param('preserveShortSequences'));
      return [ngrams.nGrams, ngrams.nGramsSplits];
    }
    case 'StringSplit': {
      var pieces = ops.string.stringSplit(param('input'), param('delimiter'), param('skipEmpty'));
      return [pieces.indices, pieces.values, pieces.shape];
    }
    case 'StringToHashBucketFast':
      return [ops.string.stringToHashBucketFast(param('input'), param('numBuckets'))];
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY$1 = 'string';
90128
90129 /**
90130 * @license
90131 * Copyright 2018 Google LLC. All Rights Reserved.
90132 * Licensed under the Apache License, Version 2.0 (the "License");
90133 * you may not use this file except in compliance with the License.
90134 * You may obtain a copy of the License at
90135 *
90136 * http://www.apache.org/licenses/LICENSE-2.0
90137 *
90138 * Unless required by applicable law or agreed to in writing, software
90139 * distributed under the License is distributed on an "AS IS" BASIS,
90140 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90141 * See the License for the specific language governing permissions and
90142 * limitations under the License.
90143 * =============================================================================
90144 */
/**
 * Executor for the 'transformation' op category: Cast, ExpandDims, Squeeze,
 * Reshape, EnsureShape, padding, space/batch/depth rearrangements and
 * broadcasting.
 *
 * An optional fourth argument replaces the default `tfOps` namespace.
 *
 * @param node Graph node to run; `node.op` picks the operation.
 * @param tensorMap Tensors of already-executed nodes and weights.
 * @param context Execution context for the current node.
 * @returns A one-element array holding the op's output.
 * @throws TypeError if `node.op` is not an op of this category.
 */
var executeOp$1 = function executeOp(node, tensorMap, context) {
  var ops = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : tfOps;
  // Shorthand for reading one of this node's parameters.
  var param = function (name) {
    return getParamValue(name, node, tensorMap, context);
  };
  switch (node.op) {
    case 'Cast':
      return [ops.cast(param('x'), param('dtype'))];
    case 'ExpandDims': {
      var expandAxis = param('axis');
      return [ops.expandDims(param('x'), expandAxis)];
    }
    case 'Squeeze': {
      var squeezeAxis = param('axis');
      return [ops.squeeze(param('x'), squeezeAxis)];
    }
    case 'Reshape':
      return [ops.reshape(param('x'), param('shape'))];
    case 'EnsureShape':
      return [ops.ensureShape(param('x'), param('shape'))];
    case 'MirrorPad':
      return [ops.mirrorPad(param('x'), param('padding'), param('mode'))];
    case 'PadV2':
    case 'Pad':
      return [ops.pad(param('x'), param('padding'), param('constantValue'))];
    case 'SpaceToBatchND': {
      var s2bBlockShape = param('blockShape');
      var paddings = param('paddings');
      return [ops.spaceToBatchND(param('x'), s2bBlockShape, paddings)];
    }
    case 'BatchToSpaceND': {
      var b2sBlockShape = param('blockShape');
      var crops = param('crops');
      return [ops.batchToSpaceND(param('x'), b2sBlockShape, crops)];
    }
    case 'DepthToSpace': {
      var blockSize = param('blockSize');
      // The data format string is upper-cased before it is passed on.
      var dataFormat = param('dataFormat').toUpperCase();
      return [ops.depthToSpace(param('x'), blockSize, dataFormat)];
    }
    case 'BroadcastTo':
      return [ops.broadcastTo(param('x'), param('shape'))];
    case 'BroadcastArgs':
      return [ops.broadcastArgs(param('s0'), param('s1'))];
    default:
      throw new TypeError("Node type ".concat(node.op, " is not implemented"));
  }
};
var CATEGORY = 'transformation';
90210
90211 /**
90212 * @license
90213 * Copyright 2018 Google LLC. All Rights Reserved.
90214 * Licensed under the Apache License, Version 2.0 (the "License");
90215 * you may not use this file except in compliance with the License.
90216 * You may obtain a copy of the License at
90217 *
90218 * http://www.apache.org/licenses/LICENSE-2.0
90219 *
90220 * Unless required by applicable law or agreed to in writing, software
90221 * distributed under the License is distributed on an "AS IS" BASIS,
90222 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90223 * See the License for the specific language governing permissions and
90224 * limitations under the License.
90225 * =============================================================================
90226 */
/**
 * Executes the op defined by the node object.
 *
 * Nodes are routed by `node.category`. Most categories select a kernel
 * executor that is run inside a scope function (`tidy` by default). The
 * 'control', 'dynamic', 'hash_table' and 'custom' categories are invoked
 * directly, without that wrapper.
 *
 * @param node Graph node to execute; `node.category` selects the executor and
 *     `node.op` is used in error messages and custom-op lookup.
 * @param tensorMap contains tensors for executed nodes and weights
 * @param context contains tensors and information for running the current node.
 * @param resourceManager Optional. Contains global resources of the model;
 *     only forwarded to the 'hash_table' executor.
 * A fifth, optional argument replaces the scope function; it defaults to
 * `tidy` when omitted or `undefined`.
 * @returns The executor result normalized with `[].concat(...)` into an array,
 *     or a Promise of such an array when the executor returned a promise.
 * @throws TypeError when a custom op is not registered or the category is
 *     unknown.
 */
function executeOp(node, tensorMap, context, resourceManager) {
  var scope = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : tidy;

  // Runs the executor for `node.category` and returns its raw result.
  function runByCategory() {
    var kernel;
    switch (node.category) {
      // Executors called directly, outside of `scope`.
      case 'control':
        return executeOp$i(node, tensorMap, context);
      case 'dynamic':
        return executeOp$f(node, tensorMap, context);
      case 'hash_table':
        return executeOp$c(node, tensorMap, context, resourceManager);
      case 'custom': {
        var opMapper = getRegisteredOp(node.op);
        if (!(opMapper && opMapper.customExecutor)) {
          throw TypeError("Custom op ".concat(node.op, " is not registered."));
        }
        return opMapper.customExecutor(new NodeValueImpl(node, tensorMap, context));
      }
      // Kernel executors: chosen here, run inside `scope` after the switch.
      case 'arithmetic': kernel = executeOp$k; break;
      case 'basic_math': kernel = executeOp$j; break;
      case 'convolution': kernel = executeOp$h; break;
      case 'creation': kernel = executeOp$g; break;
      case 'evaluation': kernel = executeOp$e; break;
      case 'image': kernel = executeOp$b; break;
      case 'graph': kernel = executeOp$d; break;
      case 'logical': kernel = executeOp$a; break;
      case 'matrices': kernel = executeOp$9; break;
      case 'normalization': kernel = executeOp$8; break;
      case 'ragged': kernel = executeOp$7; break;
      case 'reduction': kernel = executeOp$6; break;
      case 'slice_join': kernel = executeOp$5; break;
      case 'sparse': kernel = executeOp$4; break;
      case 'spectral': kernel = executeOp$3; break;
      case 'string': kernel = executeOp$2; break;
      case 'transformation': kernel = executeOp$1; break;
      default:
        throw TypeError("Unknown op '".concat(node.op, "'. File an issue at ") + "https://github.com/tensorflow/tfjs/issues so we can add it" + ", or register a custom execution with tf.registerOp()");
    }
    return scope(function () {
      return kernel(node, tensorMap, context);
    });
  }

  var raw = runByCategory();
  if (!isPromise(raw)) {
    return [].concat(raw);
  }
  return raw.then(function (data) {
    return [].concat(data);
  });
}
90330
/**
 * ExecutionContext captures the runtime environment of the node. It keeps
 * track of the current frame and iteration for the control flow ops.
 *
 * For example, a typical dynamic RNN model may contain loops, for which
 * TensorFlow will generate graphs with Enter/Exit nodes to control the
 * current execution frame, and NextIteration nodes for iteration id increment.
 * For models with branch logic, TensorFlow will generate Switch/Merge ops.
 *
 * A context is a path of frames `{id, frameName, iterationId}` starting at the
 * root frame. Each frame renders as `frameName-iterationId`, except a frame
 * with id 0 and iteration 0 (the root), which renders as ''. The segments are
 * joined with '/'.
 */
var ExecutionContext = /*#__PURE__*/function () {
  /**
   * Every argument is optional. weightMap, tensorArrayMap, tensorListMap and
   * functionMap default to fresh empty objects when omitted or `undefined`;
   * parseNodeNameCache has no default.
   */
  function ExecutionContext() {
    var args = arguments;
    // Optional map arguments fall back to a fresh empty object.
    var argOrEmpty = function (index) {
      return args.length > index && args[index] !== undefined ? args[index] : {};
    };
    var weightMap = argOrEmpty(0);
    var tensorArrayMap = argOrEmpty(1);
    var tensorListMap = argOrEmpty(2);
    var functionMap = argOrEmpty(3);
    var parseNodeNameCache = args.length > 4 ? args[4] : undefined;
    _classCallCheck(this, ExecutionContext);
    this.weightMap = weightMap;
    this.tensorArrayMap = tensorArrayMap;
    this.tensorListMap = tensorListMap;
    this.functionMap = functionMap;
    this.parseNodeNameCache = parseNodeNameCache;
    this.rootContext = { id: 0, frameName: '', iterationId: 0 };
    this.contexts = [this.rootContext];
    this.lastId = 0;
    this.generateCurrentContextIds();
  }
  _createClass(ExecutionContext, [{
    /** Builds a frame record for `frameName` at iteration 0. */
    key: "newFrame",
    value: function newFrame(id, frameName) {
      var frame = { id: id, frameName: frameName, iterationId: 0 };
      return frame;
    }
  }, {
    /**
     * The current path of execution frames. Assigning a different array
     * replaces the path and rebuilds the cached context ids; assigning the
     * array already in use does nothing.
     */
    key: "currentContext",
    get: function get() {
      return this.contexts;
    },
    set: function set(contexts) {
      if (this.contexts === contexts) {
        return;
      }
      this.contexts = contexts;
      this.generateCurrentContextIds();
    }
  }, {
    /** Returns the current (innermost) context in string format. */
    key: "currentContextId",
    get: function get() {
      var ids = this._currentContextIds;
      return ids[0];
    }
  }, {
    /**
     * Returns the current context and all parent contexts in string format,
     * innermost first and ending with '' for the root. This allows access to
     * the nodes in the current and parent frames.
     */
    key: "currentContextIds",
    get: function get() {
      return this._currentContextIds;
    }
  }, {
    /** Rebuilds the cached context ids from `this.contexts`. */
    key: "generateCurrentContextIds",
    value: function generateCurrentContextIds() {
      var ids = [];
      // One id per prefix of the path longer than one frame, longest first.
      for (var prefixLength = this.contexts.length; prefixLength > 1; prefixLength--) {
        ids.push(this.contextIdforContexts(this.contexts.slice(0, prefixLength)));
      }
      ids.push('');
      this._currentContextIds = ids;
    }
  }, {
    /** Renders a frame path as a '/'-joined id string ('' for a falsy path). */
    key: "contextIdforContexts",
    value: function contextIdforContexts(contexts) {
      if (!contexts) {
        return '';
      }
      var segments = contexts.map(function (frame) {
        var isRootFrame = frame.id === 0 && frame.iterationId === 0;
        return isRootFrame ? '' : "".concat(frame.frameName, "-").concat(frame.iterationId);
      });
      return segments.join('/');
    }
  }, {
    /**
     * Enter a new frame: a new context is pushed onto a copy of the current
     * context list and its id is prepended to the cached ids.
     * @param frameId new frame id
     */
    key: "enterFrame",
    value: function enterFrame(frameId) {
      if (!this.contexts) {
        return;
      }
      this.lastId += 1;
      var path = this.contexts.slice();
      this.contexts = path;
      path.push(this.newFrame(this.lastId, frameId));
      this._currentContextIds.unshift(this.contextIdforContexts(path));
    }
  }, {
    /**
     * Exit the current frame: the last context is removed from a copy of the
     * current context list. Throws when only the root frame remains.
     */
    key: "exitFrame",
    value: function exitFrame() {
      var hasParentFrame = this.contexts && this.contexts.length > 1;
      if (!hasParentFrame) {
        throw new Error('Cannot exit frame, the context is empty');
      }
      this.contexts = this.contexts.slice(0, -1);
      this.currentContextIds.shift();
    }
  }, {
    /**
     * Enter the next iteration of a loop: the last context is replaced by a
     * copy whose iteration id is increased and which receives a fresh id.
     */
    key: "nextIteration",
    value: function nextIteration() {
      var hasFrame = this.contexts && this.contexts.length > 0;
      if (!hasFrame) {
        throw new Error('Cannot increase frame iteration, the context is empty');
      }
      var path = this.contexts.slice();
      this.contexts = path;
      this.lastId += 1;
      var lastIndex = path.length - 1;
      var advanced = Object.assign({}, path[lastIndex]);
      advanced.iterationId += 1;
      advanced.id = this.lastId;
      path[lastIndex] = advanced;
      this._currentContextIds[0] = this.contextIdforContexts(path);
    }
  }, {
    key: "getWeight",
    value: function getWeight(name) {
      return this.weightMap[name];
    }
  }, {
    /** Registers a TensorArray under its `id`. */
    key: "addTensorArray",
    value: function addTensorArray(tensorArray) {
      this.tensorArrayMap[tensorArray.id] = tensorArray;
    }
  }, {
    key: "getTensorArray",
    value: function getTensorArray(id) {
      return this.tensorArrayMap[id];
    }
  }, {
    /** Registers a TensorList under its `id`. */
    key: "addTensorList",
    value: function addTensorList(tensorList) {
      this.tensorListMap[tensorList.id] = tensorList;
    }
  }, {
    key: "getTensorList",
    value: function getTensorList(id) {
      return this.tensorListMap[id];
    }
  }, {
    /**
     * Calls `clearAndClose(keepIds)` on every registered TensorArray, then on
     * every registered TensorList.
     */
    key: "dispose",
    value: function dispose(keepIds) {
      var closeAll = function (store) {
        for (var key in store) {
          store[key].clearAndClose(keepIds);
        }
      };
      closeAll(this.tensorArrayMap);
      closeAll(this.tensorListMap);
    }
  }]);
  return ExecutionContext;
}();
90508
/**
 * Given graph inputs and desired outputs, find the minimal set of nodes
 * to execute in order to compute the outputs. In addition, return other
 * useful info such as:
 * - Missing inputs needed to compute the output.
 * - Whether the subgraph contains dynamic ops (control flow, dynamic shape,
 *   hash table).
 * - Alternative inputs in order to avoid async (dynamic op) execution.
 *
 * The traversal walks backwards from `outputs` through `node.inputs` and stops
 * at weights, user-provided inputs and init nodes. Each node is expanded at
 * most once, so a name appears at most once in `missingInputs`.
 *
 * @param inputs Object whose keys name the provided tensors; each key is
 *     reduced to a node name with `parseNodeName(key)[0]`. Only keys are read.
 * @param outputs Nodes whose values are requested.
 * @param weightMap Object keyed by node name; a node with a non-null entry is
 *     treated as already available.
 * @param initNodes Optional. Nodes treated as already available.
 * @returns Object with `inputs` and `outputs` (as passed in), `usedNodes`
 *     (Set of visited node names), `missingInputs` (names of visited input-less
 *     nodes that are not provided), `dynamicNode` (first dynamic node visited,
 *     or null) and `syncInputs` (names of that node's children already visited
 *     when it was found, or null).
 */
function getExecutionSubgraph(inputs, outputs, weightMap, initNodes) {
  var usedNodes = new Set();
  var missingInputs = [];
  var dynamicNode = null;
  var syncInputs = null;
  // Start with the outputs, going backwards and find all the nodes that are
  // needed to compute those outputs.
  var seen = new Set();
  var inputNodeNames = new Set(Object.keys(inputs).map(function (name) {
    return parseNodeName(name)[0];
  }));
  var initNodeList = initNodes || [];
  var initNodeNames = new Set(initNodeList.map(function (node) {
    return parseNodeName(node.name)[0];
  }));
  var frontier = _toConsumableArray(outputs);
  while (frontier.length > 0) {
    var node = frontier.pop();
    // A node can be popped more than once: when it is listed twice in
    // `outputs`, or when an output is also an input of another visited node
    // (`seen` only tracks nodes pushed as inputs). Expanding it again adds
    // nothing and would append a duplicate entry to `missingInputs`.
    if (usedNodes.has(node.name)) {
      continue;
    }
    if (isControlFlow(node) || isDynamicShape(node) || isHashTable(node)) {
      if (dynamicNode == null) {
        dynamicNode = node;
        syncInputs = dynamicNode.children.map(function (child) {
          return child.name;
        }).filter(function (name) {
          return usedNodes.has(name);
        });
      }
    }
    usedNodes.add(node.name);
    // Weights are dead end since we already have their values.
    if (weightMap[node.name] != null) {
      continue;
    }
    // This node is a dead end since it's one of the user-provided inputs.
    if (inputNodeNames.has(node.name)) {
      continue;
    }
    // This node is a dead end since it's one of the init nodes.
    if (initNodeNames.has(node.name)) {
      continue;
    }
    // A node without inputs that is not a weight, a provided input or an init
    // node cannot be computed: report it as missing.
    if (node.inputs.length === 0) {
      missingInputs.push(node.name);
      continue;
    }
    node.inputs.forEach(function (input) {
      // Don't add to the frontier if it is already there.
      if (seen.has(input.name)) {
        return;
      }
      seen.add(input.name);
      frontier.push(input);
    });
  }
  return {
    inputs: inputs,
    outputs: outputs,
    usedNodes: usedNodes,
    missingInputs: missingInputs,
    dynamicNode: dynamicNode,
    syncInputs: syncInputs
  };
}
/**
 * Given the execution info, return a list of nodes in topological order that
 * need to be executed to compute the output.
 *
 * Only nodes in `executionInfo.usedNodes` are scheduled. A node becomes ready
 * once all of its used parents have been scheduled. The resulting order is then
 * narrowed to nodes reachable from the predefined nodes (inputs, weights, init
 * nodes) and validated.
 *
 * @param graph Graph providing `nodes` (object keyed by name), `weights` and
 *     optional `initNodes`.
 * @param executionInfo Result of `getExecutionSubgraph`; `usedNodes` and
 *     `inputs` are read.
 * @returns Nodes in execution order.
 */
function getNodesInTopologicalOrder(graph, executionInfo) {
  var usedNodes = executionInfo.usedNodes;
  var inputs = executionInfo.inputs;
  var inputNodes = Object.keys(inputs).map(function (key) {
    return graph.nodes[parseNodeName(key)[0]];
  });
  var initNodes = graph.initNodes || [];
  var isUsed = function (node) {
    return usedNodes.has(typeof node === 'string' ? node : node.name);
  };
  // One node per name: the first occurrence fixes the position, the last
  // occurrence supplies the node (same as building a Map from pairs).
  var dedupeByName = function (nodes) {
    var byName = new Map();
    nodes.forEach(function (node) {
      byName.set(node.name, node);
    });
    return Array.from(byName.values());
  };
  var predefinedNodes = dedupeByName(inputNodes.concat(graph.weights, initNodes)).filter(isUsed);
  var allNodes = dedupeByName(predefinedNodes.concat(Object.values(graph.nodes))).filter(isUsed);
  var nameToNode = new Map();
  allNodes.forEach(function (node) {
    nameToNode.set(node.name, node);
  });

  // Pending-parent count per node name. Unused children are pinned at
  // +Infinity so they can never reach zero and be scheduled.
  var inCounts = {};
  allNodes.forEach(function (parent) {
    inCounts[parent.name] = inCounts[parent.name] || 0;
    parent.children.forEach(function (child) {
      var base = isUsed(child) ? inCounts[child.name] || 0 : Number.POSITIVE_INFINITY;
      inCounts[child.name] = base + 1;
    });
  });

  // Build execution order for all used nodes regardless whether they are
  // predefined or not, seeding with every node that has no pending parent.
  var ready = Object.keys(inCounts).filter(function (name) {
    return inCounts[name] === 0;
  });
  var orderedNodeNames = ready.slice();
  var release = function (child) {
    inCounts[child.name] -= 1;
    if (inCounts[child.name] === 0) {
      orderedNodeNames.push(child.name);
      ready.push(child.name);
    }
  };
  while (ready.length > 0) {
    nameToNode.get(ready.pop()).children.filter(isUsed).forEach(release);
  }

  var orderedNodes = orderedNodeNames.map(function (name) {
    return nameToNode.get(name);
  });
  var reachableNodes = filterPredefinedReachableNodes(orderedNodes, predefinedNodes);
  // TODO: Turn validation on/off with tf env flag.
  validateNodesExecutionOrder(reachableNodes, predefinedNodes);
  return reachableNodes;
}
/**
 * This is a helper function of `getNodesInTopologicalOrder`.
 * Returns ordered nodes reachable by at least one predefined node.
 * This can help us filter out redundant nodes from the returned node list.
 * For example:
 * If we have four nodes with dependencies like this:
 *   a --> b --> c --> d
 * when node `c` is predefined (e.g. given as an input tensor), we can
 * skip node `a` and `b` since their outputs will never be used.
 *
 * Children that are not in `orderedNodes` are not followed. The relative
 * order of `orderedNodes` is preserved in the result.
 *
 * @param orderedNodes Graph nodes in execution order.
 * @param predefinedNodes Graph inputs, weights, and init nodes. Nodes in this
 *     list must have distinct names.
 */
function filterPredefinedReachableNodes(orderedNodes, predefinedNodes) {
  var byName = new Map();
  orderedNodes.forEach(function (node) {
    byName.set(node.name, node);
  });
  // TODO: Filter out more nodes when >=2 nodes are predefined in a path.
  var pending = predefinedNodes.map(function (node) {
    return node.name;
  });
  var reachable = new Set(pending);
  // Depth-first walk over `children`, starting from every predefined node.
  while (pending.length > 0) {
    var kids = byName.get(pending.pop()).children;
    for (var i = 0; i < kids.length; i++) {
      var kidName = kids[i].name;
      if (byName.has(kidName) && !reachable.has(kidName)) {
        reachable.add(kidName);
        pending.push(kidName);
      }
    }
  }
  return orderedNodes.filter(function (node) {
    return reachable.has(node.name);
  });
}
/**
 * Error raised by `validateNodesExecutionOrder` when the computed node order
 * violates a dependency. Its message is the given detail prefixed with
 * "NodesExecutionOrderError: ".
 */
var NodesExecutionOrderError = /*#__PURE__*/function (BaseError) {
  _inherits(NodesExecutionOrderError, BaseError);
  var callSuper = _createSuper(NodesExecutionOrderError);
  function NodesExecutionOrderError(detail) {
    _classCallCheck(this, NodesExecutionOrderError);
    var prefixedMessage = "NodesExecutionOrderError: ".concat(detail);
    return callSuper.call(this, prefixedMessage);
  }
  return _createClass(NodesExecutionOrderError);
}( /*#__PURE__*/_wrapNativeSuper(Error));
/**
 * This is a helper function of `getNodesInTopologicalOrder`.
 * Validates property: given nodes `a` and `b`, Order(a) > Order(b) if `a`
 * is a child of `b`. This function throws an error if validation fails.
 *
 * For every node in `orderedNodes`:
 * - each child that is also in `orderedNodes` must come after the node;
 * - unless the node is predefined, each input must be in `orderedNodes` and
 *   must come before the node.
 *
 * @param orderedNodes Graph nodes in execution order.
 * @param predefinedNodes Graph inputs, weights, and init nodes. Nodes in this
 *     list must have distinct names.
 * @throws NodesExecutionOrderError describing the first violation found.
 */
function validateNodesExecutionOrder(orderedNodes, predefinedNodes) {
  var nameOf = function (nodeOrName) {
    return typeof nodeOrName === 'string' ? nodeOrName : nodeOrName.name;
  };
  var orderOf = new Map();
  orderedNodes.forEach(function (node, position) {
    orderOf.set(node.name, position);
  });
  var predefinedNames = new Set(predefinedNodes.map(function (node) {
    return node.name;
  }));
  var scheduledNames = new Set(orderedNodes.map(function (node) {
    return node.name;
  }));
  var isPredefined = function (node) {
    return predefinedNames.has(nameOf(node));
  };
  var isScheduled = function (node) {
    return scheduledNames.has(nameOf(node));
  };
  for (var i = 0; i < orderedNodes.length; i++) {
    var node = orderedNodes[i];
    var position = orderOf.get(node.name);
    var scheduledChildren = node.children.filter(isScheduled);
    for (var j = 0; j < scheduledChildren.length; j++) {
      var child = scheduledChildren[j];
      if (!orderOf.has(child.name)) {
        throw new NodesExecutionOrderError("Child ".concat(child.name, " of node ").concat(node.name, " is unreachable."));
      }
      if (position > orderOf.get(child.name)) {
        throw new NodesExecutionOrderError("Node ".concat(node.name, " is scheduled to run after its child ").concat(child.name, "."));
      }
    }
    // Inputs of predefined nodes are not checked.
    if (isPredefined(node)) {
      continue;
    }
    var nodeInputs = node.inputs;
    for (var k = 0; k < nodeInputs.length; k++) {
      var input = nodeInputs[k];
      if (!orderOf.has(input.name)) {
        throw new NodesExecutionOrderError("Input ".concat(input.name, " of node ").concat(node.name, " is unreachable."));
      }
      if (orderOf.get(input.name) > position) {
        throw new NodesExecutionOrderError("Node ".concat(node.name, " is scheduled to run before its input ").concat(input.name, "."));
      }
    }
  }
}
/**
 * Given the nodes in execution order, return a map from node name to the
 * nodes that may be disposed after it executes.
 *
 * A node lives until the latest position among itself and its children, so it
 * can pass its tensors on to every child. Control-flow nodes, and consequently
 * their direct parents, never expire and appear in no list. Children absent
 * from `orderedNodes` (unused or unreachable) do not extend a node's life.
 *
 * @param orderedNodes Graph nodes in execution order.
 * @returns A Map from the name of a node `x` to all nodes whose intermediate
 *     tensors should be disposed after `x` is executed.
 */
function getNodeLiveUntilMap(orderedNodes) {
  var INF_LIFE = Number.MAX_SAFE_INTEGER;
  var positionOf = new Map();
  orderedNodes.forEach(function (node, position) {
    positionOf.set(node.name, position);
  });
  // Make control flow nodes (and consequently their direct parents) live
  // forever since they're tricky to track correctly.
  var selfLifespans = orderedNodes.map(function (node, position) {
    return isControlFlow(node) ? INF_LIFE : position;
  });
  // -1 for nodes missing from `orderedNodes`, so they never extend a life.
  var lifespanOf = function (node) {
    var life = selfLifespans[positionOf.get(node.name)];
    return life == null ? -1 : life;
  };
  var liveUntilMap = new Map();
  orderedNodes.forEach(function (node, position) {
    var liveUntil = node.children.reduce(function (latest, child) {
      return Math.max(latest, lifespanOf(child));
    }, selfLifespans[position]);
    if (liveUntil === INF_LIFE) {
      return;
    }
    var ownerName = orderedNodes[liveUntil].name;
    var disposables = liveUntilMap.get(ownerName);
    if (disposables === undefined) {
      disposables = [];
      liveUntilMap.set(ownerName, disposables);
    }
    disposables.push(node);
  });
  return liveUntilMap;
}
// Ops classified as control flow. getExecutionSubgraph reports the first such
// node as the dynamic node (which makes a synchronous compile fail), and
// getNodeLiveUntilMap keeps these nodes and their direct parents alive.
// TensorFlow's conditional op is named 'If' (matching the 'StatelessIf' entry);
// the lowercase 'if' matches no TensorFlow op but is kept so that this set only
// grows and nothing previously classified as control flow changes.
var CONTROL_FLOW_OPS = new Set(['Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf', 'StatelessWhile', 'If', 'if', 'While']);
// Ops whose output shape depends on input values.
var DYNAMIC_SHAPE_OPS = new Set(['NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where']);
// Hash-table ops; treated as dynamic by getExecutionSubgraph.
var HASH_TABLE_OPS = new Set(['HashTable', 'HashTableV2', 'LookupTableImport', 'LookupTableImportV2', 'LookupTableFind', 'LookupTableFindV2', 'LookupTableSize', 'LookupTableSizeV2']);
/** Returns true when `node.op` is a control-flow op. */
function isControlFlow(node) {
  return CONTROL_FLOW_OPS.has(node.op);
}
/** Returns true when `node.op` is listed in DYNAMIC_SHAPE_OPS. */
function isDynamicShape(node) {
  return DYNAMIC_SHAPE_OPS.has(node.op);
}
/** Returns true when `node.op` is a hash-table op. */
function isHashTable(node) {
  return HASH_TABLE_OPS.has(node.op);
}
90878
90879 var GraphExecutor = /*#__PURE__*/function () {
90880 /**
90881 *
90882 * @param graph Graph the model or function graph to be executed.
90883 * @param parent When building function exector you need to set the parent
90884 * executor. Since the weights and function executor maps are set at parant
90885 * level, that function executor can access the function maps and weight maps
90886 * through the parent.
90887 */
90888 function GraphExecutor(graph, parent) {
90889 var _this = this;
90890 _classCallCheck(this, GraphExecutor);
90891 this.graph = graph;
90892 this.parent = parent;
90893 this.compiledMap = new Map();
90894 this.parseNodeNameCache = new Map();
90895 this._weightMap = {};
90896 this.SEPARATOR = ',';
90897 this._functions = {};
90898 this._functionExecutorMap = {};
90899 this.keepIntermediateTensors = false;
90900 this._outputs = graph.outputs;
90901 this._inputs = graph.inputs;
90902 this._initNodes = graph.initNodes;
90903 this._signature = graph.signature;
90904 this._functions = graph.functions;
90905 // create sub-graph executors
90906 if (graph.functions != null) {
90907 Object.keys(graph.functions).forEach(function (name) {
90908 _this._functionExecutorMap[name] = new GraphExecutor(graph.functions[name], _this);
90909 });
90910 }
90911 }
90912 _createClass(GraphExecutor, [{
90913 key: "weightIds",
90914 get: function get() {
90915 return this.parent ? this.parent.weightIds : this._weightIds;
90916 }
90917 }, {
90918 key: "functionExecutorMap",
90919 get: function get() {
90920 return this.parent ? this.parent.functionExecutorMap : this._functionExecutorMap;
90921 }
90922 }, {
90923 key: "weightMap",
90924 get: function get() {
90925 return this.parent ? this.parent.weightMap : this._weightMap;
90926 },
90927 set: function set(weightMap) {
90928 var _ref;
90929 var weightIds = Object.keys(weightMap).map(function (key) {
90930 return weightMap[key].map(function (tensor) {
90931 return tensor.id;
90932 });
90933 });
90934 this._weightIds = (_ref = []).concat.apply(_ref, _toConsumableArray(weightIds));
90935 this._weightMap = weightMap;
90936 }
90937 /**
90938 * Set `ResourceManager` shared by executors of a model.
90939 * @param resourceManager: `ResourceManager` of the `GraphModel`.
90940 */
90941 }, {
90942 key: "resourceManager",
90943 set: function set(resourceManager) {
90944 this._resourceManager = resourceManager;
90945 }
90946 }, {
90947 key: "inputs",
90948 get: function get() {
90949 return this._inputs.map(function (node) {
90950 return {
90951 name: node.name,
90952 shape: node.attrParams['shape'] ? node.attrParams['shape'].value : undefined,
90953 dtype: node.attrParams['dtype'] ? node.attrParams['dtype'].value : undefined
90954 };
90955 });
90956 }
90957 }, {
90958 key: "outputs",
90959 get: function get() {
90960 return this._outputs.map(function (node) {
90961 return {
90962 name: node.name,
90963 shape: node.attrParams['shape'] ? node.attrParams['shape'].value : undefined,
90964 dtype: node.attrParams['dtype'] ? node.attrParams['dtype'].value : undefined
90965 };
90966 });
90967 }
90968 }, {
90969 key: "inputNodes",
90970 get: function get() {
90971 return this._inputs.map(function (node) {
90972 return node.signatureKey || node.name;
90973 });
90974 }
90975 }, {
90976 key: "outputNodes",
90977 get: function get() {
90978 return this._outputs.map(function (node) {
90979 var name = node.signatureKey || node.name;
90980 return node.defaultOutput ? "".concat(name, ":").concat(node.defaultOutput) : name;
90981 });
90982 }
90983 }, {
90984 key: "functions",
90985 get: function get() {
90986 var _this2 = this;
90987 return Object.keys(this._functions).reduce(function (map, key) {
90988 map[key] = _this2._functions[key].signature;
90989 return map;
90990 }, {});
90991 }
90992 }, {
90993 key: "getCompilationKey",
90994 value: function getCompilationKey(inputs, outputs) {
90995 var sortedInputs = inputs.map(function (node) {
90996 return node.name;
90997 }).sort();
90998 var sortedOutputs = outputs.map(function (node) {
90999 return node.name;
91000 }).sort();
91001 return sortedInputs.join(this.SEPARATOR) + '--' + sortedOutputs.join(this.SEPARATOR);
91002 }
91003 /**
91004 * Compiles the inference graph and returns the minimal set of nodes that are
91005 * required for execution, in the correct execution order.
91006 * @returns {Object} compilation The compile result.
91007 * @returns {Node[]} compilation.orderedNodes Nodes in the correct execution
91008 * order.
91009 * @returns {Map<string, Node[]>} compilation.nodeLiveUntilMap A map from node
91010 * to disposable nodes after its execution. That is, for a node `x`,
91011 * `nodeLiveUntilMap[x]` indicates all nodes whose intermediate
91012 * tensors should be disposed after `x` is executed.
91013 */
91014 }, {
91015 key: "compile",
91016 value: function compile(inputs, outputs) {
91017 var executionInfo = getExecutionSubgraph(inputs, outputs, this.weightMap, this._initNodes);
91018 var missingInputs = executionInfo.missingInputs,
91019 dynamicNode = executionInfo.dynamicNode,
91020 syncInputs = executionInfo.syncInputs;
91021 if (dynamicNode != null) {
91022 throw new Error("This execution contains the node '".concat(dynamicNode.name, "', which has ") + "the dynamic op '".concat(dynamicNode.op, "'. Please use ") + "model.executeAsync() instead. Alternatively, to avoid the " + "dynamic ops, specify the inputs [".concat(syncInputs, "]"));
91023 }
91024 if (missingInputs.length > 0) {
91025 var outNames = outputs.map(function (n) {
91026 return n.name;
91027 });
91028 var inNames = Object.keys(inputs);
91029 throw new Error("Cannot compute the outputs [".concat(outNames, "] from the provided inputs ") + "[".concat(inNames, "]. Missing the following inputs: [").concat(missingInputs, "]"));
91030 }
91031 var orderedNodes = getNodesInTopologicalOrder(this.graph, executionInfo);
91032 var nodeLiveUntilMap = getNodeLiveUntilMap(orderedNodes);
91033 return {
91034 orderedNodes: orderedNodes,
91035 nodeLiveUntilMap: nodeLiveUntilMap
91036 };
91037 }
91038 }, {
91039 key: "cloneAndKeepTensor",
91040 value: function cloneAndKeepTensor(tensor) {
91041 if (tensor == null) {
91042 return null;
91043 }
91044 var clone = tensor.clone();
91045 // Keep the clone because`model.execute()` may be called within
91046 // a `tidy()`, but the user may inspect these tensors after the
91047 // tidy.
91048 keep(clone);
91049 return clone;
91050 }
91051 }, {
91052 key: "cloneTensorList",
91053 value: function cloneTensorList(tensors) {
91054 var _this3 = this;
91055 if (!tensors) {
91056 return null;
91057 }
91058 var clonedTensor = tensors.map(function (tensor) {
91059 return _this3.cloneAndKeepTensor(tensor);
91060 });
91061 return clonedTensor;
91062 }
91063 }, {
91064 key: "cloneTensorMap",
91065 value: function cloneTensorMap(tensorsMap) {
91066 var _this4 = this;
91067 return Object.fromEntries(Object.entries(tensorsMap).map(function (_ref2) {
91068 var _ref3 = _slicedToArray(_ref2, 2),
91069 name = _ref3[0],
91070 tensorsList = _ref3[1];
91071 return [name, _this4.cloneTensorList(tensorsList)];
91072 }));
91073 }
91074 /**
91075 * Executes the inference for given input tensors.
91076 * @param inputs Tensor map for the model inputs, keyed by the input node
91077 * names.
91078 * @param outputs Optional. output node name from the Tensorflow model, if
91079 * no outputs are specified, the default outputs of the model would be used.
91080 * You can inspect intermediate nodes of the model by adding them to the
91081 * outputs array.
91082 */
91083 }, {
91084 key: "execute",
// Synchronous inference. Maps and validates `inputs`/`outputs`, fetches (or
// compiles and caches) the execution plan for this input/output node set,
// then runs the planned nodes in order inside tidy(). Throws if any op
// returns a promise (such graphs need executeAsync()). Returns the tensors
// for `outputs`, in the same order.
91085 value: function execute(inputs, outputs) {
91086 var _this5 = this;
91087 // Dispose any tensors from a prior run to avoid leaking them.
91088 this.disposeIntermediateTensors();
91089 inputs = this.mapInputs(inputs);
// Sorted, presumably so the compilation cache key derived from `inputNodes`
// does not depend on the caller's key order.
91090 var names = Object.keys(inputs).sort();
91091 this.checkInputs(inputs);
91092 this.checkInputShapeAndType(inputs);
91093 outputs = this.mapOutputs(outputs);
91094 this.checkOutputs(outputs);
91095 var inputNodes = names.map(function (name) {
91096 return _this5.graph.nodes[parseNodeName(name)[0]];
91097 });
91098 var outputNodeNames = outputs.map(function (name) {
91099 return parseNodeName(name)[0];
91100 });
91101 var outputNodeNameSet = new Set(outputNodeNames);
91102 var outputNodes = outputNodeNames.map(function (name) {
91103 return _this5.graph.nodes[name];
91104 });
91105 // If no outputs are specified, then use the default outputs of the model.
// NOTE(review): in that case the default outputs are executed, but the value
// returned at the end is built from `outputs` and is therefore an empty
// array. Confirm callers never pass an empty `outputs`.
91106 if (outputNodes.length === 0) {
91107 outputNodes = this._outputs;
91108 }
91109 var compilationKey = this.getCompilationKey(inputNodes, outputNodes);
91110 // Reuse the cached compilation for this key; on a miss, compile and cache it.
91111 var compilation = this.compiledMap.get(compilationKey);
91112 if (compilation == null) {
91113 compilation = this.compile(inputs, outputNodes);
91114 this.compiledMap.set(compilationKey, compilation);
91115 }
91116 // Keep tensors if KEEP_INTERMEDIATE_TENSORS is on.
91117 try {
91118 this.keepIntermediateTensors = env().getBool('KEEP_INTERMEDIATE_TENSORS');
91119 } catch (e) {
91120 this.keepIntermediateTensors = false;
91121 console.warn(e.message);
91122 }
91123 var tensorArrayMap = {};
91124 var tensorListMap = {};
91125 return tidy(function () {
91126 var context = new ExecutionContext(_this5.weightMap, tensorArrayMap, tensorListMap, _this5.functionExecutorMap, _this5.parseNodeNameCache);
// Start from the weights; fed inputs are added below under their node name
// at the requested output index.
91127 var tensorsMap = Object.assign({}, _this5.weightMap);
91128 if (_this5.keepIntermediateTensors) {
91129 _this5.clonedTensorsMap = _this5.cloneTensorMap(_this5.weightMap);
91130 }
91131 Object.keys(inputs).forEach(function (name) {
91132 var _parseNodeName = parseNodeName(name, context),
91133 _parseNodeName2 = _slicedToArray(_parseNodeName, 2),
91134 nodeName = _parseNodeName2[0],
91135 index = _parseNodeName2[1];
91136 var tensors = [];
91137 tensors[index] = inputs[name];
91138 tensorsMap[nodeName] = tensors;
91139 if (_this5.keepIntermediateTensors) {
91140 _this5.clonedTensorsMap[nodeName] = _this5.cloneTensorList(tensors);
91141 }
91142 });
// Ids of weight and fed-input tensors; the per-node disposal check never
// disposes these.
91143 var tensorsToKeep = _this5.getFrozenTensorIds(tensorsMap);
91144 var _compilation = compilation,
91145 orderedNodes = _compilation.orderedNodes,
91146 nodeLiveUntilMap = _compilation.nodeLiveUntilMap;
91147 var _iterator = _createForOfIteratorHelper(orderedNodes),
91148 _step;
91149 try {
91150 for (_iterator.s(); !(_step = _iterator.n()).done;) {
91151 var node = _step.value;
// Nodes that already have tensors (weights, fed inputs) are not re-executed.
91152 if (tensorsMap[node.name]) {
91153 continue;
91154 }
91155 var tensors = executeOp(node, tensorsMap, context, _this5._resourceManager);
91156 if (isPromise(tensors)) {
91157 throw new Error("The execution of the op '".concat(node.op, "' returned a promise. ") + "Please use model.executeAsync() instead.");
91158 }
91159 tensorsMap[node.name] = tensors;
91160 if (_this5.keepIntermediateTensors) {
91161 _this5.clonedTensorsMap[node.name] = _this5.cloneTensorList(tensors);
91162 }
// Dispose tensors of nodes whose last consumer (per the compiled plan) is
// this node.
91163 _this5.checkTensorForDisposalWithNodeLiveUntilInfo(node, tensorsMap, context, tensorsToKeep, outputNodeNameSet, nodeLiveUntilMap.get(node.name));
91164 }
91165 // End of the per-node execution loop.
91166 } catch (err) {
91167 _iterator.e(err);
91168 } finally {
91169 _iterator.f();
91170 }
// Only the root executor (no parent) disposes the execution context.
91171 if (_this5.parent == null) {
91172 context.dispose(tensorsToKeep);
91173 }
91174 return outputs.map(function (name) {
91175 return getTensor(name, tensorsMap, context);
91176 });
91177 });
91178 }
91179 }, {
91180 key: "getFrozenTensorIds",
91181 value: function getFrozenTensorIds(tensorMap) {
91182 var ids = [].concat.apply([], Object.keys(tensorMap).map(function (key) {
91183 return tensorMap[key];
91184 }).map(function (tensors) {
91185 return tensors.map(function (tensor) {
91186 return tensor.id;
91187 });
91188 }));
91189 return new Set(ids);
91190 }
91191 }, {
91192 key: "checkTensorForDisposal",
91193 value: function checkTensorForDisposal(nodeName, node, tensorMap, context, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount) {
91194 // Skip output nodes and any control flow nodes, since its dependency is
91195 // tricky to track correctly.
91196 if (isControlFlow(node) || outputNodeNameSet.has(nodeName)) {
91197 return;
91198 }
91199 var _iterator2 = _createForOfIteratorHelper(tensorMap[nodeName]),
91200 _step2;
91201 try {
91202 for (_iterator2.s(); !(_step2 = _iterator2.n()).done;) {
91203 var tensor = _step2.value;
91204 if (tensor == null) {
91205 continue;
91206 }
91207 intermediateTensorConsumerCount[tensor.id] = (intermediateTensorConsumerCount[tensor.id] || 0) + node.children.length;
91208 }
91209 } catch (err) {
91210 _iterator2.e(err);
91211 } finally {
91212 _iterator2.f();
91213 }
91214 var _iterator3 = _createForOfIteratorHelper(node.inputs),
91215 _step3;
91216 try {
91217 for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
91218 var input = _step3.value;
91219 // Skip any control flow nodes, since its dependency is tricky to track
91220 // correctly.
91221 if (isControlFlow(input)) {
91222 continue;
91223 }
91224 var tensors = getTensorsForCurrentContext(input.name, tensorMap, context);
91225 if (tensors == null) {
91226 continue;
91227 }
91228 var _iterator4 = _createForOfIteratorHelper(tensors),
91229 _step4;
91230 try {
91231 for (_iterator4.s(); !(_step4 = _iterator4.n()).done;) {
91232 var _tensor = _step4.value;
91233 if (!_tensor || _tensor.kept || tensorsToKeep.has(_tensor.id)) {
91234 continue;
91235 }
91236 // Only intermediate nodes' tensors have counts set, not marked as
91237 // kept, and not in `tensorsToKeep`.
91238 // Input and weight nodes' tensors should exist in `tensorsToKeep`.
91239 // Output and control flow nodes' tensors should never have count set.
91240 var count = intermediateTensorConsumerCount[_tensor.id];
91241 if (count === 1) {
91242 _tensor.dispose();
91243 delete intermediateTensorConsumerCount[_tensor.id];
91244 } else if (count != null) {
91245 intermediateTensorConsumerCount[_tensor.id]--;
91246 }
91247 }
91248 } catch (err) {
91249 _iterator4.e(err);
91250 } finally {
91251 _iterator4.f();
91252 }
91253 }
91254 } catch (err) {
91255 _iterator3.e(err);
91256 } finally {
91257 _iterator3.f();
91258 }
91259 }
91260 }, {
91261 key: "checkTensorForDisposalWithNodeLiveUntilInfo",
91262 value: function checkTensorForDisposalWithNodeLiveUntilInfo(node, tensorMap, context, tensorsToKeep, outputNodeNameSet, liveUntilNodes) {
91263 function isNonDisposableNode(node) {
91264 // Skip output nodes and any control flow nodes, since its dependency is
91265 // tricky to track correctly.
91266 return isControlFlow(node) || outputNodeNameSet.has(node.name);
91267 }
91268 if (isControlFlow(node) || liveUntilNodes == null) {
91269 return;
91270 }
91271 var _iterator5 = _createForOfIteratorHelper(liveUntilNodes),
91272 _step5;
91273 try {
91274 for (_iterator5.s(); !(_step5 = _iterator5.n()).done;) {
91275 var nodeToDispose = _step5.value;
91276 if (isNonDisposableNode(nodeToDispose)) {
91277 continue;
91278 }
91279 var tensors = getTensorsForCurrentContext(nodeToDispose.name, tensorMap, context);
91280 var _iterator6 = _createForOfIteratorHelper(tensors),
91281 _step6;
91282 try {
91283 for (_iterator6.s(); !(_step6 = _iterator6.n()).done;) {
91284 var tensor = _step6.value;
91285 if (!tensor || tensor.kept || tensorsToKeep.has(tensor.id)) {
91286 continue;
91287 }
91288 tensor.dispose();
91289 }
91290 } catch (err) {
91291 _iterator6.e(err);
91292 } finally {
91293 _iterator6.f();
91294 }
91295 }
91296 } catch (err) {
91297 _iterator5.e(err);
91298 } finally {
91299 _iterator5.f();
91300 }
91301 }
91302 /**
91303 * Executes the inference for given input tensors in Async fashion.
91304 * @param inputs Tensor map for the model inputs, keyed by the input node
91305 * names.
91306 * @param outputs output node name from the Tensorflow model, if no outputs
91307 * are specified, the default outputs of the model would be used. You can
91308 * inspect intermediate nodes of the model by adding them to the outputs
91309 * array.
91310 */
91311 }, {
91312 key: "executeAsync",
91313 value: function () {
91314 var _executeAsync2 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(inputs, outputs) {
91315 return _regeneratorRuntime().wrap(function _callee$(_context) {
91316 while (1) switch (_context.prev = _context.next) {
91317 case 0:
91318 return _context.abrupt("return", this._executeAsync(inputs, outputs));
91319 case 1:
91320 case "end":
91321 return _context.stop();
91322 }
91323 }, _callee, this);
91324 }));
91325 function executeAsync(_x, _x2) {
91326 return _executeAsync2.apply(this, arguments);
91327 }
91328 return executeAsync;
91329 }()
91330 }, {
91331 key: "disposeIntermediateTensors",
91332 value: function disposeIntermediateTensors() {
91333 if (!this.clonedTensorsMap) {
91334 return;
91335 }
91336 Object.values(this.clonedTensorsMap).forEach(function (tensorsList) {
91337 var _iterator7 = _createForOfIteratorHelper(tensorsList),
91338 _step7;
91339 try {
91340 for (_iterator7.s(); !(_step7 = _iterator7.n()).done;) {
91341 var tensor = _step7.value;
91342 if (tensor && !tensor.isDisposed) {
91343 tensor.dispose();
91344 }
91345 }
91346 } catch (err) {
91347 _iterator7.e(err);
91348 } finally {
91349 _iterator7.f();
91350 }
91351 });
91352 this.clonedTensorsMap = null;
91353 }
91354 }, {
91355 key: "getIntermediateTensors",
91356 value: function getIntermediateTensors() {
91357 return this.clonedTensorsMap;
91358 }
91359 /**
91360 * Executes the inference for given input tensors in Async fashion.
91361 * @param inputs Tensor map for the model inputs, keyed by the input node
91362 * names.
91363 * @param outputs Optional. output node name from the Tensorflow model,
91364 * if no outputs are specified, the default outputs of the model would be
91365 * used. You can inspect intermediate nodes of the model by adding them to
91366 * the outputs array.
91367 * @param isFunctionExecution Optional. Flag for executing a function.
91368 * @param tensorArrayMap Optional, global TensorArray map by id. Used for
91369 * function execution.
91370 * @param tensorListMap Optional, global TensorList map by id. Used for
91371 * function execution.
91372 */
91373 }, {
91374 key: "_executeAsync",
// Shared async execution path, used by executeAsync and, with
// isFunctionExecution = true, by executeFunctionAsync.
// Positional arguments 2-4 are:
//   isFunctionExecution (default false),
//   tensorArrayMap (default {}),
//   tensorListMap (default {}).
// Resolves to the tensors for `outputs`, in order. Every other live tensor in
// the resulting tensors map is disposed first, except fed inputs and weights
// (`this.weightIds`).
91375 value: function () {
91376 var _executeAsync3 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(inputs, outputs) {
91377 var isFunctionExecution,
91378 tensorArrayMap,
91379 tensorListMap,
91380 context,
91381 tensorsMap,
91382 results,
91383 outputIds,
91384 inputIds,
91385 keepIds,
91386 _args2 = arguments;
91387 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
// Regenerator state machine: `case 10` is the resume point after the single
// await on executeWithControlFlow.
91388 while (1) switch (_context2.prev = _context2.next) {
91389 case 0:
91390 isFunctionExecution = _args2.length > 2 && _args2[2] !== undefined ? _args2[2] : false;
91391 tensorArrayMap = _args2.length > 3 && _args2[3] !== undefined ? _args2[3] : {};
91392 tensorListMap = _args2.length > 4 && _args2[4] !== undefined ? _args2[4] : {};
91393 // Dispose any tensors from a prior run to avoid leaking them.
91394 this.disposeIntermediateTensors();
// Function executions (see executeFunctionAsync) already pass graph input
// names and `this.outputNodes`, so signature mapping and validation run only
// for top-level calls.
91395 if (!isFunctionExecution) {
91396 inputs = this.mapInputs(inputs);
91397 this.checkInputs(inputs);
91398 this.checkInputShapeAndType(inputs);
91399 outputs = this.mapOutputs(outputs);
91400 this.checkOutputs(outputs);
91401 }
91402 // Keep tensors if KEEP_INTERMEDIATE_TENSORS is on.
91403 try {
91404 this.keepIntermediateTensors = env().getBool('KEEP_INTERMEDIATE_TENSORS');
91405 } catch (e) {
91406 this.keepIntermediateTensors = false;
91407 console.warn(e.message);
91408 }
91409 context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap, this.parseNodeNameCache);
91410 if (this.keepIntermediateTensors) {
91411 this.clonedTensorsMap = this.cloneTensorMap(this.weightMap);
91412 }
91413 // Graph with control flow op requires runtime evaluation of the execution
91414 // order, while without control flow the execution order is pre-determined
91415 // in the compile method.
// Equivalent of `await this.executeWithControlFlow(...)`. Execution resumes
// at case 10 with the resolved tensors map in `_context2.sent`.
91416 _context2.next = 10;
91417 return this.executeWithControlFlow(inputs, context, outputs, isFunctionExecution);
91418 case 10:
91419 tensorsMap = _context2.sent;
91420 results = outputs.map(function (name) {
91421 return getTensor(name, tensorsMap, context);
91422 }); // dispose all the intermediate tensors
91423 outputIds = results.map(function (t) {
91424 return t.id;
91425 });
91426 inputIds = Object.keys(inputs).map(function (name) {
91427 return inputs[name].id;
91428 });
// Tensors that survive the run: returned outputs, fed inputs and all weights.
91429 keepIds = new Set([].concat(_toConsumableArray(outputIds), _toConsumableArray(inputIds), _toConsumableArray(this.weightIds)));
91430 Object.values(tensorsMap).forEach(function (tensorsList) {
91431 tensorsList.forEach(function (tensor) {
91432 if (tensor && !tensor.isDisposed && !keepIds.has(tensor.id)) {
91433 tensor.dispose();
91434 }
91435 });
91436 });
91437 // dispose the context for the root executor
91438 if (this.parent == null) {
91439 context.dispose(keepIds);
91440 }
91441 return _context2.abrupt("return", results);
91442 case 18:
91443 case "end":
91444 return _context2.stop();
91445 }
91446 }, _callee2, this);
91447 }));
91448 function _executeAsync(_x3, _x4) {
91449 return _executeAsync3.apply(this, arguments);
91450 }
91451 return _executeAsync;
91452 }()
91453 }, {
91454 key: "executeFunctionAsync",
91455 value: function () {
91456 var _executeFunctionAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(inputs, tensorArrayMap, tensorListMap) {
91457 var _this6 = this;
91458 var mappedInputs;
91459 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
91460 while (1) switch (_context3.prev = _context3.next) {
91461 case 0:
91462 mappedInputs = inputs.reduce(function (map, tensor, index) {
91463 map[_this6.inputs[index].name] = tensor;
91464 return map;
91465 }, {});
91466 return _context3.abrupt("return", this._executeAsync(mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap));
91467 case 2:
91468 case "end":
91469 return _context3.stop();
91470 }
91471 }, _callee3, this);
91472 }));
91473 function executeFunctionAsync(_x5, _x6, _x7) {
91474 return _executeFunctionAsync.apply(this, arguments);
91475 }
91476 return executeFunctionAsync;
91477 }()
91478 /**
91479 * When there are control flow nodes in the graph, the graph execution uses
91480 * ExecutionContext to keep track of the frames and loop iterators.
91481 * @param inputs placeholder tensors for the graph.
91482 * @param context the execution context object for current execution.
91483 * @param outputNames Optional. output node name from the Tensorflow model,
91484 * if no outputs are specified, the default outputs of the model would be
91485 * used. You can inspect intermediate nodes of the model by adding them to
91486 * the outputs array.
91487 * @param isFunctionExecution Flag for executing a function.
91488 */
91489 }, {
91490 key: "executeWithControlFlow",
// Dynamic scheduler for graphs with control flow or dynamic-shape ops.
// Seeds a ready-stack with the input, weight and init nodes, then repeatedly
// drains it via processStack and awaits any async ops it started, until no
// runnable node remains. Resolves to the node-name -> tensor-list map.
// Throws if a requested non-control-flow output was not produced.
91491 value: function () {
91492 var _executeWithControlFlow = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(inputs, context, outputNames, isFunctionExecution) {
91493 var _this7 = this;
91494 var names, inputNodes, outputNodeNames, outputNodeNameSet, outputNodes, _getExecutionSubgraph, usedNodes, missingInputs, dynamicNode, syncInputs, stack, tensorsMap, intermediateTensorConsumerCount, tensorsToKeep, added, promises, missingOutputs, alternativeMsg;
91495 return _regeneratorRuntime().wrap(function _callee4$(_context4) {
91496 while (1) switch (_context4.prev = _context4.next) {
91497 case 0:
91498 names = Object.keys(inputs);
91499 inputNodes = names.map(function (name) {
91500 return _this7.graph.nodes[parseNodeName(name)[0]];
91501 });
91502 outputNodeNames = outputNames.map(function (name) {
91503 return parseNodeName(name)[0];
91504 });
91505 outputNodeNameSet = new Set(outputNodeNames);
91506 outputNodes = outputNodeNames.map(function (name) {
91507 return _this7.graph.nodes[name];
91508 }); // If no outputs are specified, then use the default outputs of the model.
91509 if (outputNodes.length === 0) {
91510 outputNodes = this._outputs;
91511 }
// usedNodes limits which children may be scheduled (see processChildNodes).
// missingInputs, dynamicNode and syncInputs only feed the diagnostics below.
91512 _getExecutionSubgraph = getExecutionSubgraph(inputs, outputNodes, this.weightMap, this._initNodes), usedNodes = _getExecutionSubgraph.usedNodes, missingInputs = _getExecutionSubgraph.missingInputs, dynamicNode = _getExecutionSubgraph.dynamicNode, syncInputs = _getExecutionSubgraph.syncInputs; // First nodes to execute include inputNodes, weights, and initNodes.
91513 stack = [].concat(_toConsumableArray(inputNodes), _toConsumableArray(this.graph.weights), _toConsumableArray(this._initNodes || [])).map(function (node) {
91514 return {
91515 node: node,
91516 contexts: context.currentContext
91517 };
91518 });
91519 tensorsMap = Object.assign({}, this.weightMap);
91520 Object.keys(inputs).forEach(function (name) {
91521 var _parseNodeName3 = parseNodeName(name),
91522 _parseNodeName4 = _slicedToArray(_parseNodeName3, 2),
91523 nodeName = _parseNodeName4[0],
91524 index = _parseNodeName4[1];
91525 var tensors = [];
91526 tensors[index] = inputs[name];
91527 tensorsMap[nodeName] = tensors;
91528 });
// Consumer counts for checkTensorForDisposal. Weight and fed-input tensor ids
// (tensorsToKeep) are never disposed. `added` records node names already
// pushed onto the stack.
91529 intermediateTensorConsumerCount = {};
91530 tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
91531 added = {};
// Scheduler loop (case 13 -> 17 -> 13). Promise callbacks from async ops may
// push more children onto `stack`, so the stack is re-checked after each
// await.
91532 case 13:
91533 if (!(stack.length > 0)) {
91534 _context4.next = 19;
91535 break;
91536 }
91537 promises = this.processStack(inputNodes, stack, context, tensorsMap, added, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount, usedNodes);
91538 _context4.next = 17;
91539 return Promise.all(promises);
91540 case 17:
91541 _context4.next = 13;
91542 break;
91543 case 19:
91544 if (dynamicNode == null && !isFunctionExecution) {
91545 console.warn("This model execution did not contain any nodes with control flow " + "or dynamic output shapes. You can use model.execute() instead.");
91546 }
91547 missingOutputs = outputNodes.filter(function (node) {
91548 return !isControlFlow(node) && !getTensor(node.name, tensorsMap, context);
91549 }).map(function (node) {
91550 return node.name;
91551 });
91552 if (!(missingOutputs.length > 0)) {
91553 _context4.next = 25;
91554 break;
91555 }
91556 alternativeMsg = '';
91557 if (dynamicNode != null) {
91558 alternativeMsg = "Alternatively, to avoid the dynamic ops, use model.execute() " + "and specify the inputs [".concat(syncInputs, "]");
91559 }
91560 throw new Error("Cannot compute the outputs [".concat(missingOutputs, "] from the provided ") + "inputs [".concat(names, "]. Consider providing the following inputs: ") + "[".concat(missingInputs, "]. ").concat(alternativeMsg));
91561 case 25:
91562 return _context4.abrupt("return", tensorsMap);
91563 case 26:
91564 case "end":
91565 return _context4.stop();
91566 }
91567 }, _callee4, this);
91568 }));
91569 function executeWithControlFlow(_x8, _x9, _x10, _x11) {
91570 return _executeWithControlFlow.apply(this, arguments);
91571 }
91572 return executeWithControlFlow;
91573 }()
91574 }, {
91575 key: "processStack",
// Drains `stack` synchronously. For each popped item it:
//   - restores the item's execution context;
//   - executes the node if `tensorMap[item.node.name]` is unset;
//   - stores the resulting tensors and runs consumer-count disposal;
//   - pushes any children that became runnable.
// Ops that return a promise finish these steps in a `.then` callback. Those
// promises are returned so the caller can await them and call processStack
// again for any children they pushed.
// NOTE(review): the `inputNodes` parameter is not used in this body.
91576 value: function processStack(inputNodes, stack, context, tensorMap, added, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount, usedNodes) {
91577 var _this8 = this;
91578 var promises = [];
91579 var _loop = function _loop() {
91580 var item = stack.pop();
91581 context.currentContext = item.contexts;
91582 var nodeName = '';
91583 // The tensor of the Enter op with isConstant set should be set
91584 // in the parent scope, so it will be available as constant for the
91585 // whole loop.
// Its name is therefore resolved here, before executeOp runs, against the
// context in effect before the node executes (presumably executing Enter
// changes the current frame; confirm in the Enter op executor).
91586 if (item.node.op === 'Enter' && getParamValue('isConstant', item.node, tensorMap, context)) {
91587 var _getNodeNameAndIndex = getNodeNameAndIndex(item.node.name, context);
91588 var _getNodeNameAndIndex2 = _slicedToArray(_getNodeNameAndIndex, 1);
91589 nodeName = _getNodeNameAndIndex2[0];
91590 }
91591 // Only execute nodes without an entry in tensorMap yet (fed inputs and
91592 // weights are pre-populated by the caller); others just schedule children.
91593 if (tensorMap[item.node.name] == null) {
91594 var tensors = executeOp(item.node, tensorMap, context, _this8._resourceManager);
91595 if (!nodeName) {
91596 var _getNodeNameAndIndex3 = getNodeNameAndIndex(item.node.name, context);
91597 var _getNodeNameAndIndex4 = _slicedToArray(_getNodeNameAndIndex3, 1);
91598 nodeName = _getNodeNameAndIndex4[0];
91599 }
// Captured now because later items overwrite context.currentContext before
// an async op's promise settles. The callback restores it.
91600 var currentContext = context.currentContext;
91601 if (isPromise(tensors)) {
91602 promises.push(tensors.then(function (t) {
91603 tensorMap[nodeName] = t;
91604 if (_this8.keepIntermediateTensors) {
91605 _this8.clonedTensorsMap[nodeName] = _this8.cloneTensorList(t);
91606 }
91607 context.currentContext = currentContext;
91608 _this8.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount);
91609 _this8.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
91610 return t;
91611 }));
91612 } else {
91613 tensorMap[nodeName] = tensors;
91614 if (_this8.keepIntermediateTensors) {
91615 _this8.clonedTensorsMap[nodeName] = _this8.cloneTensorList(tensors);
91616 }
91617 _this8.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount);
91618 _this8.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
91619 }
91620 } else {
91621 _this8.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
91622 }
91623 };
91624 while (stack.length > 0) {
91625 _loop();
91626 }
91627 return promises;
91628 }
91629 }, {
91630 key: "processChildNodes",
91631 value: function processChildNodes(node, stack, context, tensorMap, added, usedNodes) {
91632 node.children.forEach(function (childNode) {
91633 var _getNodeNameAndIndex5 = getNodeNameAndIndex(childNode.name, context),
91634 _getNodeNameAndIndex6 = _slicedToArray(_getNodeNameAndIndex5, 1),
91635 nodeName = _getNodeNameAndIndex6[0];
91636 if (added[nodeName] || !usedNodes.has(childNode.name)) {
91637 return;
91638 }
91639 // Merge op can be pushed if any of its inputs has value.
91640 if (childNode.op === 'Merge') {
91641 if (childNode.inputNames.some(function (name) {
91642 return !!getTensor(name, tensorMap, context);
91643 })) {
91644 added[nodeName] = true;
91645 stack.push({
91646 contexts: context.currentContext,
91647 node: childNode
91648 });
91649 }
91650 } else
91651 // Otherwise all inputs must to have value.
91652 if (childNode.inputNames.every(function (name) {
91653 return !!getTensor(name, tensorMap, context);
91654 })) {
91655 added[nodeName] = true;
91656 stack.push({
91657 contexts: context.currentContext,
91658 node: childNode
91659 });
91660 }
91661 });
91662 }
91663 /**
91664 * Releases the memory used by the weight tensors.
91665 */
91666 }, {
91667 key: "dispose",
91668 value: function dispose() {
91669 var _this9 = this;
91670 Object.keys(this.weightMap).forEach(function (key) {
91671 return _this9.weightMap[key].forEach(function (tensor) {
91672 return tensor.dispose();
91673 });
91674 });
91675 }
91676 }, {
91677 key: "checkInputShapeAndType",
91678 value: function checkInputShapeAndType(inputs) {
91679 var _this10 = this;
91680 Object.keys(inputs).forEach(function (name) {
91681 var input = inputs[name];
91682 var _parseNodeName5 = parseNodeName(name),
91683 _parseNodeName6 = _slicedToArray(_parseNodeName5, 1),
91684 nodeName = _parseNodeName6[0];
91685 var node = _this10.graph.nodes[nodeName];
91686 if (node.attrParams['shape'] && node.attrParams['shape'].value) {
91687 var shape = node.attrParams['shape'].value;
91688 var match = shape.length === input.shape.length && input.shape.every(function (dim, index) {
91689 return shape[index] === -1 || shape[index] === dim;
91690 });
91691 assert$1(match, function () {
91692 return "The shape of dict['".concat(node.name, "'] provided in ") + "model.execute(dict) must be [".concat(shape, "], but was ") + "[".concat(input.shape, "]");
91693 });
91694 }
91695 if (node.attrParams['dtype'] && node.attrParams['dtype'].value) {
91696 assert$1(input.dtype === node.attrParams['dtype'].value, function () {
91697 return "The dtype of dict['".concat(node.name, "'] provided in ") + "model.execute(dict) must be " + "".concat(node.attrParams['dtype'].value, ", but was ").concat(input.dtype);
91698 });
91699 }
91700 });
91701 }
91702 }, {
91703 key: "mapInputs",
91704 value: function mapInputs(inputs) {
91705 var _a, _b;
91706 var result = {};
91707 for (var inputName in inputs) {
91708 var tensor = (_b = (_a = this._signature) === null || _a === void 0 ? void 0 : _a.inputs) === null || _b === void 0 ? void 0 : _b[inputName];
91709 if (tensor != null) {
91710 result[tensor.name] = inputs[inputName];
91711 } else {
91712 result[inputName] = inputs[inputName];
91713 }
91714 }
91715 return result;
91716 }
91717 }, {
91718 key: "checkInputs",
91719 value: function checkInputs(inputs) {
91720 var _this11 = this;
91721 var notInGraph = Object.keys(inputs).filter(function (name) {
91722 var _parseNodeName7 = parseNodeName(name),
91723 _parseNodeName8 = _slicedToArray(_parseNodeName7, 1),
91724 nodeName = _parseNodeName8[0];
91725 return _this11.graph.nodes[nodeName] == null;
91726 });
91727 if (notInGraph.length > 0) {
91728 throw new Error("The dict provided in model.execute(dict) has " + "keys: [".concat(notInGraph, "] that are not part of graph"));
91729 }
91730 }
91731 }, {
91732 key: "mapOutputs",
91733 value: function mapOutputs(outputs) {
91734 var _this12 = this;
91735 return outputs.map(function (name) {
91736 var _a, _b;
91737 var tensor = (_b = (_a = _this12._signature) === null || _a === void 0 ? void 0 : _a.outputs) === null || _b === void 0 ? void 0 : _b[name];
91738 if (tensor != null) {
91739 return tensor.name;
91740 }
91741 return name;
91742 }, {});
91743 }
91744 }, {
91745 key: "checkOutputs",
91746 value: function checkOutputs(outputs) {
91747 var _this13 = this;
91748 outputs.forEach(function (name) {
91749 var _parseNodeName9 = parseNodeName(name),
91750 _parseNodeName10 = _slicedToArray(_parseNodeName9, 1),
91751 normalizedName = _parseNodeName10[0];
91752 if (!_this13.graph.nodes[normalizedName]) {
91753 throw new Error("The output '".concat(name, "' is not found in the graph"));
91754 }
91755 });
91756 }
91757 }]);
91758 return GraphExecutor;
91759 }();
91760
91761 /**
91762 * Contains global resources of a model.
91763 */
var ResourceManager = /*#__PURE__*/function () {
  /**
   * @param hashTableNameToHandle Optional initial map from creating-op node
   *     name to the `HashTable` handle tensor (defaults to `{}`).
   * @param hashTableMap Optional initial map from handle-tensor id to the
   *     `HashTable` (defaults to `{}`).
   */
  function ResourceManager() {
    var nameToHandle = arguments[0] === undefined ? {} : arguments[0];
    var tablesById = arguments[1] === undefined ? {} : arguments[1];
    _classCallCheck(this, ResourceManager);
    this.hashTableNameToHandle = nameToHandle;
    this.hashTableMap = tablesById;
  }
  _createClass(ResourceManager, [{
    key: "addHashTable",
    /**
     * Registers `hashTable`: its handle under the creating op's `name`, and
     * the table itself under `hashTable.id` (retrievable via
     * getHashTableById).
     */
    value: function addHashTable(name, hashTable) {
      var handle = hashTable.handle;
      var tableId = hashTable.id;
      this.hashTableNameToHandle[name] = handle;
      this.hashTableMap[tableId] = hashTable;
    }
  }, {
    key: "getHashTableHandleByName",
    /**
     * Returns the handle tensor registered for the op node `name` (the same
     * name lookup/import HashTable ops use in their inputs), or undefined.
     */
    value: function getHashTableHandleByName(name) {
      var handles = this.hashTableNameToHandle;
      return handles[name];
    }
  }, {
    key: "getHashTableById",
    /** Returns the `HashTable` registered under handle-tensor id `id`, or undefined. */
    value: function getHashTableById(id) {
      var tables = this.hashTableMap;
      return tables[id];
    }
  }, {
    key: "dispose",
    /**
     * Clears and closes every registered HashTable, disposes every handle
     * tensor, and removes all entries from both maps.
     */
    value: function dispose() {
      var tables = this.hashTableMap;
      for (var tableId in tables) {
        tables[tableId].clearAndClose();
        delete tables[tableId];
      }
      var handles = this.hashTableNameToHandle;
      for (var tableName in handles) {
        handles[tableName].dispose();
        delete handles[tableName];
      }
    }
  }]);
  return ResourceManager;
}();
91824
// Query string for TF Hub model URLs. It is used outside this section;
// presumably it is appended to a TF Hub URL to request the tfjs files.
var TFHUB_SEARCH_PARAM = "?tfjs-format=file";
// Default file name of the model topology JSON.
var DEFAULT_MODEL_NAME = "model.json";
91827 /**
91828 * A `tf.GraphModel` is a directed, acyclic graph built from a
91829 * SavedModel GraphDef and allows inference execution.
91830 *
91831 * A `tf.GraphModel` can only be created by loading from a model converted from
91832 * a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using
91833 * the command line converter tool and loaded via `tf.loadGraphModel`.
91834 *
91835 * @doc {heading: 'Models', subheading: 'Classes'}
91836 */
91837 var GraphModel = /*#__PURE__*/function () {
91838 /**
91839 * @param modelUrl url for the model, or an `io.IOHandler`.
91840 * @param loadOptions Optional load options, passed to the IO handler
91841 * lookup and to `browserHTTPRequest` (e.g. `requestInit`, which allows
91842 * sending credentials and custom headers). `null`/omitted becomes `{}`.
91843 * @param tfio Optional IO namespace providing `getLoadHandlers`,
91844 * `browserHTTPRequest` and `decodeWeights`; defaults to the bundled `io`
91845 * namespace.
91846 */
function GraphModel(modelUrl) {
  // Optional positional args: (1) load options, (2) IO namespace.
  var suppliedOptions = arguments.length > 1 ? arguments[1] : undefined;
  var tfio = arguments[2];
  if (tfio === undefined) {
    tfio = io;
  }
  _classCallCheck(this, GraphModel);
  this.modelUrl = modelUrl;
  // Both an omitted and an explicit null options argument become `{}`.
  this.loadOptions = suppliedOptions == null ? {} : suppliedOptions;
  this.version = 'n/a';
  this.io = tfio;
  this.resourceManager = new ResourceManager();
}
91860 _createClass(GraphModel, [{
91861 key: "modelVersion",
91862 get:
91863 // Returns the version information for the tensorflow model GraphDef.
91864 function get() {
91865 return this.version;
91866 }
91867 }, {
91868 key: "inputNodes",
91869 get: function get() {
91870 return this.executor.inputNodes;
91871 }
91872 }, {
91873 key: "outputNodes",
91874 get: function get() {
91875 return this.executor.outputNodes;
91876 }
91877 }, {
91878 key: "inputs",
91879 get: function get() {
91880 return this.executor.inputs;
91881 }
91882 }, {
91883 key: "outputs",
91884 get: function get() {
91885 return this.executor.outputs;
91886 }
91887 }, {
91888 key: "weights",
91889 get: function get() {
91890 return this.executor.weightMap;
91891 }
91892 }, {
91893 key: "metadata",
91894 get: function get() {
91895 return this.artifacts.userDefinedMetadata;
91896 }
91897 }, {
91898 key: "modelSignature",
91899 get: function get() {
91900 return this.signature;
91901 }
91902 }, {
91903 key: "modelStructuredOutputKeys",
91904 get: function get() {
91905 return this.structuredOutputKeys;
91906 }
91907 }, {
91908 key: "findIOHandler",
91909 value: function findIOHandler() {
91910 var path = this.modelUrl;
91911 if (path.load != null) {
91912 // Path is an IO Handler.
91913 this.handler = path;
91914 } else if (this.loadOptions.requestInit != null) {
91915 this.handler = this.io.browserHTTPRequest(path, this.loadOptions);
91916 } else {
91917 var handlers = this.io.getLoadHandlers(path, this.loadOptions);
91918 if (handlers.length === 0) {
91919 // For backward compatibility: if no load handler can be found,
91920 // assume it is a relative http path.
91921 handlers.push(this.io.browserHTTPRequest(path, this.loadOptions));
91922 } else if (handlers.length > 1) {
91923 throw new Error("Found more than one (".concat(handlers.length, ") load handlers for ") + "URL '".concat([path], "'"));
91924 }
91925 this.handler = handlers[0];
91926 }
91927 }
91928 /**
91929 * Loads the model and weight files, construct the in memory weight map and
91930 * compile the inference graph.
91931 */
91932 }, {
91933 key: "load",
91934 value: function load() {
91935 var _this = this;
91936 this.findIOHandler();
91937 if (this.handler.load == null) {
91938 throw new Error('Cannot proceed with model loading because the IOHandler provided ' + 'does not have the `load` method implemented.');
91939 }
91940 var loadResult = this.handler.load();
91941 if (isPromise(loadResult)) {
91942 return loadResult.then(function (artifacts) {
91943 if (artifacts.getWeightStream == null) {
91944 return _this.loadSync(artifacts);
91945 }
91946 return _this.loadStreaming(artifacts);
91947 });
91948 }
91949 return this.loadSync(loadResult);
91950 }
91951 /**
91952 * Synchronously construct the in memory weight map and
91953 * compile the inference graph.
91954 *
91955 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
91956 */
91957 }, {
91958 key: "loadSync",
91959 value: function loadSync(artifacts) {
91960 var weightMap = this.io.decodeWeights(artifacts.weightData, artifacts.weightSpecs);
91961 return this.loadWithWeightMap(artifacts, weightMap);
91962 }
91963 }, {
91964 key: "loadStreaming",
91965 value: function () {
91966 var _loadStreaming = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(artifacts) {
91967 var weightMap;
91968 return _regeneratorRuntime().wrap(function _callee$(_context) {
91969 while (1) switch (_context.prev = _context.next) {
91970 case 0:
91971 if (!(artifacts.getWeightStream == null)) {
91972 _context.next = 2;
91973 break;
91974 }
91975 throw new Error('Model artifacts missing streamWeights function');
91976 case 2:
91977 _context.next = 4;
91978 return decodeWeightsStream(artifacts.getWeightStream(), artifacts.weightSpecs);
91979 case 4:
91980 weightMap = _context.sent;
91981 return _context.abrupt("return", this.loadWithWeightMap(artifacts, weightMap));
91982 case 6:
91983 case "end":
91984 return _context.stop();
91985 }
91986 }, _callee, this);
91987 }));
91988 function loadStreaming(_x) {
91989 return _loadStreaming.apply(this, arguments);
91990 }
91991 return loadStreaming;
91992 }()
91993 }, {
91994 key: "loadWithWeightMap",
91995 value: function loadWithWeightMap(artifacts, weightMap) {
91996 this.artifacts = artifacts;
91997 var graph = this.artifacts.modelTopology;
91998 var signature = this.artifacts.signature;
91999 if (this.artifacts.userDefinedMetadata != null) {
92000 var metadata = this.artifacts.userDefinedMetadata;
92001 if (metadata.signature != null) {
92002 signature = metadata.signature;
92003 }
92004 if (metadata.structuredOutputKeys != null) {
92005 this.structuredOutputKeys = metadata.structuredOutputKeys;
92006 }
92007 }
92008 this.signature = signature;
92009 this.version = "".concat(graph.versions.producer, ".").concat(graph.versions.minConsumer);
92010 this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph, this.signature));
92011 this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
92012 // Attach a model-level resourceManager to each executor to share resources,
92013 // such as `HashTable`.
92014 this.executor.resourceManager = this.resourceManager;
92015 if (artifacts.modelInitializer != null && artifacts.modelInitializer.node != null) {
92016 var initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer);
92017 this.initializer = new GraphExecutor(initializer);
92018 this.initializer.weightMap = this.executor.weightMap;
92019 // Attach a model-level resourceManager to the initializer, the
92020 // hashTables created from when executing the initializer will be stored
92021 // in the resourceManager.
92022 this.initializer.resourceManager = this.resourceManager;
92023 this.initializerSignature = artifacts.initializerSignature;
92024 }
92025 return true;
92026 }
92027 /**
92028 * Save the configuration and/or weights of the GraphModel.
92029 *
92030 * An `IOHandler` is an object that has a `save` method of the proper
92031 * signature defined. The `save` method manages the storing or
92032 * transmission of serialized data ("artifacts") that represent the
92033 * model's topology and weights onto or via a specific medium, such as
92034 * file downloads, local storage, IndexedDB in the web browser and HTTP
92035 * requests to a server. TensorFlow.js provides `IOHandler`
92036 * implementations for a number of frequently used saving mediums, such as
92037 * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
92038 * for more details.
92039 *
92040 * This method also allows you to refer to certain types of `IOHandler`s
92041 * as URL-like string shortcuts, such as 'localstorage://' and
92042 * 'indexeddb://'.
92043 *
92044 * Example 1: Save `model`'s topology and weights to browser [local
92045 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
92046 * then load it back.
92047 *
92048 * ```js
92049 * const modelUrl =
92050 * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
92051 * const model = await tf.loadGraphModel(modelUrl);
92052 * const zeros = tf.zeros([1, 224, 224, 3]);
92053 * model.predict(zeros).print();
92054 *
92055 * const saveResults = await model.save('localstorage://my-model-1');
92056 *
92057 * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
92058 * console.log('Prediction from loaded model:');
 * loadedModel.predict(zeros).print();
92060 * ```
92061 *
92062 * @param handlerOrURL An instance of `IOHandler` or a URL-like,
92063 * scheme-based string shortcut for `IOHandler`.
92064 * @param config Options for saving the model.
92065 * @returns A `Promise` of `SaveResult`, which summarizes the result of
92066 * the saving, such as byte sizes of the saved artifacts for the model's
92067 * topology and weight values.
92068 *
92069 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
92070 */
92071 }, {
key: "save",
value: function () {
  // Transpiled `async save(handlerOrURL, config)`.
  // A string argument must resolve to exactly one registered save handler;
  // the chosen handler's `save` is then called with `this.artifacts` (the
  // artifacts stored by loadWithWeightMap). `config` is not read here.
  var _save = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(handlerOrURL, config) {
    var handlers;
    return _regeneratorRuntime().wrap(function _callee2$(_context2) {
      while (1) switch (_context2.prev = _context2.next) {
        case 0:
          // Non-string arguments skip handler lookup (jump to case 9).
          if (!(typeof handlerOrURL === 'string')) {
            _context2.next = 9;
            break;
          }
          handlers = this.io.getSaveHandlers(handlerOrURL);
          if (!(handlers.length === 0)) {
            _context2.next = 6;
            break;
          }
          throw new Error("Cannot find any save handlers for URL '".concat(handlerOrURL, "'"));
        case 6:
          if (!(handlers.length > 1)) {
            _context2.next = 8;
            break;
          }
          throw new Error("Found more than one (".concat(handlers.length, ") save handlers for ") + "URL '".concat(handlerOrURL, "'"));
        case 8:
          handlerOrURL = handlers[0];
          // falls through
        case 9:
          if (!(handlerOrURL.save == null)) {
            _context2.next = 11;
            break;
          }
          throw new Error('GraphModel.save() cannot proceed because the IOHandler ' + 'provided does not have the `save` attribute defined.');
        case 11:
          return _context2.abrupt("return", handlerOrURL.save(this.artifacts));
        case 12:
        case "end":
          return _context2.stop();
      }
    }, _callee2, this);
  }));
  function save(_x2, _x3) {
    return _save.apply(this, arguments);
  }
  return save;
}()
92116 }, {
92117 key: "addStructuredOutputNames",
92118 value: function addStructuredOutputNames(outputTensors) {
92119 var _this2 = this;
92120 if (this.structuredOutputKeys) {
92121 var outputTensorsArray = outputTensors instanceof Tensor ? [outputTensors] : outputTensors;
92122 var outputTensorMap = {};
92123 outputTensorsArray.forEach(function (outputTensor, i) {
92124 return outputTensorMap[_this2.structuredOutputKeys[i]] = outputTensor;
92125 });
92126 return outputTensorMap;
92127 }
92128 return outputTensors;
92129 }
92130 /**
92131 * Execute the inference for the input tensors.
92132 *
92133 * @param input The input tensors, when there is single input for the model,
92134 * inputs param should be a `tf.Tensor`. For models with multiple inputs,
92135 * inputs params should be in either `tf.Tensor`[] if the input order is
92136 * fixed, or otherwise NamedTensorMap format.
92137 *
92138 * For model with multiple inputs, we recommend you use NamedTensorMap as the
92139 * input type, if you use `tf.Tensor`[], the order of the array needs to
92140 * follow the
92141 * order of inputNodes array. @see {@link GraphModel.inputNodes}
92142 *
92143 * You can also feed any intermediate nodes using the NamedTensorMap as the
92144 * input type. For example, given the graph
92145 * InputNode => Intermediate => OutputNode,
92146 * you can execute the subgraph Intermediate => OutputNode by calling
 * model.execute({'IntermediateNode' : tf.tensor(...)});
 *
 * This is useful for models that use tf.dynamic_rnn, where the intermediate
92150 * state needs to be fed manually.
92151 *
92152 * For batch inference execution, the tensors for each input need to be
92153 * concatenated together. For example with mobilenet, the required input shape
 * is [1, 224, 224, 3], which represents the [batch, height, width, channel].
 * If we provide a batch of 100 images, the input tensor should have the
 * shape [100, 224, 224, 3].
92157 *
92158 * @param config Prediction configuration for specifying the batch size.
92159 * Currently the batch size option is ignored for graph model.
92160 *
92161 * @returns Inference result tensors. If the model is converted and it
92162 * originally had structured_outputs in tensorflow, then a NamedTensorMap
92163 * will be returned matching the structured_outputs. If no structured_outputs
92164 * are present, the output will be single `tf.Tensor` if the model has single
92165 * output node, otherwise Tensor[].
92166 *
92167 * @doc {heading: 'Models', subheading: 'Classes'}
92168 */
92169 }, {
92170 key: "predict",
92171 value: function predict(inputs, config) {
92172 var outputTensors = this.execute(inputs, this.outputNodes);
92173 return this.addStructuredOutputNames(outputTensors);
92174 }
92175 /**
92176 * Execute the inference for the input tensors in async fashion, use this
92177 * method when your model contains control flow ops.
92178 *
92179 * @param input The input tensors, when there is single input for the model,
 * inputs param should be a `tf.Tensor`. For models with multiple inputs,
92181 * inputs params should be in either `tf.Tensor`[] if the input order is
92182 * fixed, or otherwise NamedTensorMap format.
92183 *
92184 * For model with multiple inputs, we recommend you use NamedTensorMap as the
92185 * input type, if you use `tf.Tensor`[], the order of the array needs to
92186 * follow the
92187 * order of inputNodes array. @see {@link GraphModel.inputNodes}
92188 *
92189 * You can also feed any intermediate nodes using the NamedTensorMap as the
92190 * input type. For example, given the graph
92191 * InputNode => Intermediate => OutputNode,
92192 * you can execute the subgraph Intermediate => OutputNode by calling
 * model.execute({'IntermediateNode' : tf.tensor(...)});
 *
 * This is useful for models that use tf.dynamic_rnn, where the intermediate
92196 * state needs to be fed manually.
92197 *
92198 * For batch inference execution, the tensors for each input need to be
92199 * concatenated together. For example with mobilenet, the required input shape
 * is [1, 224, 224, 3], which represents the [batch, height, width, channel].
 * If we provide a batch of 100 images, the input tensor should have the
 * shape [100, 224, 224, 3].
92203 *
92204 * @param config Prediction configuration for specifying the batch size.
92205 * Currently the batch size option is ignored for graph model.
92206 *
92207 * @returns A Promise of inference result tensors. If the model is converted
92208 * and it originally had structured_outputs in tensorflow, then a
92209 * NamedTensorMap will be returned matching the structured_outputs. If no
92210 * structured_outputs are present, the output will be single `tf.Tensor` if
92211 * the model has single output node, otherwise Tensor[].
92212 *
92213 * @doc {heading: 'Models', subheading: 'Classes'}
92214 */
92215 }, {
key: "predictAsync",
value: function () {
  // Transpiled `async predictAsync(inputs, config)`: awaits
  // executeAsync(inputs, this.outputNodes), then passes the result through
  // addStructuredOutputNames. `config` is not read.
  var _predictAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(inputs, config) {
    var outputTensors;
    return _regeneratorRuntime().wrap(function _callee3$(_context3) {
      while (1) switch (_context3.prev = _context3.next) {
        case 0:
          _context3.next = 2;
          // Equivalent to `await this.executeAsync(...)`.
          return this.executeAsync(inputs, this.outputNodes);
        case 2:
          outputTensors = _context3.sent;
          return _context3.abrupt("return", this.addStructuredOutputNames(outputTensors));
        case 4:
        case "end":
          return _context3.stop();
      }
    }, _callee3, this);
  }));
  function predictAsync(_x4, _x5) {
    return _predictAsync.apply(this, arguments);
  }
  return predictAsync;
}()
92239 }, {
92240 key: "normalizeInputs",
92241 value: function normalizeInputs(inputs) {
92242 var _this3 = this;
92243 var _a;
92244 if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {
92245 // The input is already a NamedTensorMap.
92246 var signatureInputs = (_a = this.signature) === null || _a === void 0 ? void 0 : _a.inputs;
92247 if (signatureInputs != null) {
92248 for (var input in signatureInputs) {
92249 var tensor = signatureInputs[input];
92250 if (tensor.resourceId != null) {
92251 inputs[input] = this.resourceIdToCapturedInput[tensor.resourceId];
92252 }
92253 }
92254 }
92255 return inputs;
92256 }
92257 inputs = Array.isArray(inputs) ? inputs : [inputs];
92258 var numCapturedInputs = Object.keys(this.resourceIdToCapturedInput).length;
92259 if (inputs.length + numCapturedInputs !== this.inputNodes.length) {
92260 throw new Error("Input tensor count mismatch, the graph model has ".concat(this.inputNodes.length - numCapturedInputs, " non-resource placeholders, while there are ").concat(inputs.length, " input tensors provided."));
92261 }
92262 var inputIndex = 0;
92263 return this.inputNodes.reduce(function (map, inputName) {
92264 var _a, _b, _c;
92265 var resourceId = (_c = (_b = (_a = _this3.signature) === null || _a === void 0 ? void 0 : _a.inputs) === null || _b === void 0 ? void 0 : _b[inputName]) === null || _c === void 0 ? void 0 : _c.resourceId;
92266 if (resourceId != null) {
92267 map[inputName] = _this3.resourceIdToCapturedInput[resourceId];
92268 } else {
92269 map[inputName] = inputs[inputIndex++];
92270 }
92271 return map;
92272 }, {});
92273 }
92274 }, {
92275 key: "normalizeOutputs",
92276 value: function normalizeOutputs(outputs) {
92277 outputs = outputs || this.outputNodes;
92278 return !Array.isArray(outputs) ? [outputs] : outputs;
92279 }
92280 }, {
92281 key: "executeInitializerGraph",
92282 value: function executeInitializerGraph() {
92283 if (this.initializer == null) {
92284 return [];
92285 }
92286 if (this.initializerSignature == null) {
92287 return this.initializer.execute({}, []);
92288 } else {
92289 return this.initializer.execute({}, Object.keys(this.initializerSignature.outputs));
92290 }
92291 }
92292 }, {
key: "executeInitializerGraphAsync",
value: function () {
  // Transpiled async twin of executeInitializerGraph: resolves to [] when
  // there is no initializer, otherwise to initializer.executeAsync({}, ...)
  // fetching either nothing or every initializer-signature output.
  var _executeInitializerGraphAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4() {
    return _regeneratorRuntime().wrap(function _callee4$(_context4) {
      while (1) switch (_context4.prev = _context4.next) {
        case 0:
          if (!(this.initializer == null)) {
            _context4.next = 2;
            break;
          }
          return _context4.abrupt("return", []);
        case 2:
          if (!(this.initializerSignature == null)) {
            _context4.next = 6;
            break;
          }
          return _context4.abrupt("return", this.initializer.executeAsync({}, []));
        case 6:
          // Fetch order follows Object.keys(outputs); setResourceIdToCapturedInput
          // relies on the same order.
          return _context4.abrupt("return", this.initializer.executeAsync({}, Object.keys(this.initializerSignature.outputs)));
        case 7:
        case "end":
          return _context4.stop();
      }
    }, _callee4, this);
  }));
  function executeInitializerGraphAsync() {
    return _executeInitializerGraphAsync.apply(this, arguments);
  }
  return executeInitializerGraphAsync;
}()
92323 }, {
92324 key: "setResourceIdToCapturedInput",
92325 value: function setResourceIdToCapturedInput(outputs) {
92326 this.resourceIdToCapturedInput = {};
92327 if (this.initializerSignature) {
92328 var signatureOutputs = this.initializerSignature.outputs;
92329 var outputNames = Object.keys(signatureOutputs);
92330 for (var i = 0; i < outputNames.length; i++) {
92331 var outputName = outputNames[i];
92332 var tensorInfo = signatureOutputs[outputName];
92333 this.resourceIdToCapturedInput[tensorInfo.resourceId] = outputs[i];
92334 }
92335 }
92336 }
92337 /**
92338 * Executes inference for the model for given input tensors.
92339 * @param inputs tensor, tensor array or tensor map of the inputs for the
92340 * model, keyed by the input node names.
92341 * @param outputs output node name from the TensorFlow model, if no
92342 * outputs are specified, the default outputs of the model would be used.
92343 * You can inspect intermediate nodes of the model by adding them to the
92344 * outputs array.
92345 *
92346 * @returns A single tensor if provided with a single output or no outputs
92347 * are provided and there is only one default output, otherwise return a
92348 * tensor array. The order of the tensor array is the same as the outputs
92349 * if provided, otherwise the order of outputNodes attribute of the model.
92350 *
92351 * @doc {heading: 'Models', subheading: 'Classes'}
92352 */
92353 }, {
92354 key: "execute",
92355 value: function execute(inputs, outputs) {
92356 if (this.resourceIdToCapturedInput == null) {
92357 this.setResourceIdToCapturedInput(this.executeInitializerGraph());
92358 }
92359 inputs = this.normalizeInputs(inputs);
92360 outputs = this.normalizeOutputs(outputs);
92361 var result = this.executor.execute(inputs, outputs);
92362 return result.length > 1 ? result : result[0];
92363 }
92364 /**
92365 * Executes inference for the model for given input tensors in async
92366 * fashion, use this method when your model contains control flow ops.
92367 * @param inputs tensor, tensor array or tensor map of the inputs for the
92368 * model, keyed by the input node names.
92369 * @param outputs output node name from the TensorFlow model, if no outputs
92370 * are specified, the default outputs of the model would be used. You can
92371 * inspect intermediate nodes of the model by adding them to the outputs
92372 * array.
92373 *
92374 * @returns A Promise of single tensor if provided with a single output or
92375 * no outputs are provided and there is only one default output, otherwise
 * return a tensor array.
92377 *
92378 * @doc {heading: 'Models', subheading: 'Classes'}
92379 */
92380 }, {
key: "executeAsync",
value: function () {
  // Transpiled `async executeAsync(inputs, outputs)`: the async counterpart
  // of execute(). On first use it awaits the initializer graph to populate
  // captured resource inputs, then awaits executor.executeAsync and unwraps
  // a single result.
  var _executeAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(inputs, outputs) {
    var result;
    return _regeneratorRuntime().wrap(function _callee5$(_context5) {
      while (1) switch (_context5.prev = _context5.next) {
        case 0:
          if (!(this.resourceIdToCapturedInput == null)) {
            _context5.next = 6;
            break;
          }
          _context5.t0 = this;
          _context5.next = 4;
          return this.executeInitializerGraphAsync();
        case 4:
          // this.setResourceIdToCapturedInput(await executeInitializerGraphAsync()),
          // then falls through to case 6.
          _context5.t1 = _context5.sent;
          _context5.t0.setResourceIdToCapturedInput.call(_context5.t0, _context5.t1);
        case 6:
          inputs = this.normalizeInputs(inputs);
          outputs = this.normalizeOutputs(outputs);
          _context5.next = 10;
          return this.executor.executeAsync(inputs, outputs);
        case 10:
          result = _context5.sent;
          return _context5.abrupt("return", result.length > 1 ? result : result[0]);
        case 12:
        case "end":
          return _context5.stop();
      }
    }, _callee5, this);
  }));
  function executeAsync(_x6, _x7) {
    return _executeAsync.apply(this, arguments);
  }
  return executeAsync;
}()
92417 /**
92418 * Get intermediate tensors for model debugging mode (flag
92419 * KEEP_INTERMEDIATE_TENSORS is true).
92420 *
92421 * @doc {heading: 'Models', subheading: 'Classes'}
92422 */
92423 }, {
92424 key: "getIntermediateTensors",
92425 value: function getIntermediateTensors() {
92426 return this.executor.getIntermediateTensors();
92427 }
92428 /**
92429 * Dispose intermediate tensors for model debugging mode (flag
92430 * KEEP_INTERMEDIATE_TENSORS is true).
92431 *
92432 * @doc {heading: 'Models', subheading: 'Classes'}
92433 */
92434 }, {
92435 key: "disposeIntermediateTensors",
92436 value: function disposeIntermediateTensors() {
92437 this.executor.disposeIntermediateTensors();
92438 }
92439 }, {
92440 key: "convertTensorMapToTensorsMap",
92441 value: function convertTensorMapToTensorsMap(map) {
92442 return Object.keys(map).reduce(function (newMap, key) {
92443 newMap[key] = [map[key]];
92444 return newMap;
92445 }, {});
92446 }
92447 /**
92448 * Releases the memory used by the weight tensors and resourceManager.
92449 *
92450 * @doc {heading: 'Models', subheading: 'Classes'}
92451 */
92452 }, {
92453 key: "dispose",
92454 value: function dispose$1() {
92455 this.executor.dispose();
92456 if (this.initializer) {
92457 this.initializer.dispose();
92458 if (this.resourceIdToCapturedInput) {
92459 dispose(this.resourceIdToCapturedInput);
92460 }
92461 }
92462 this.resourceManager.dispose();
92463 }
92464 }]);
92465 return GraphModel;
92466 }();
92467 /**
92468 * Load a graph model given a URL to the model definition.
92469 *
92470 * Example of loading MobileNetV2 from a URL and making a prediction with a
92471 * zeros input:
92472 *
92473 * ```js
92474 * const modelUrl =
92475 * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
92476 * const model = await tf.loadGraphModel(modelUrl);
92477 * const zeros = tf.zeros([1, 224, 224, 3]);
92478 * model.predict(zeros).print();
92479 * ```
92480 *
92481 * Example of loading MobileNetV2 from a TF Hub URL and making a prediction
92482 * with a zeros input:
92483 *
92484 * ```js
92485 * const modelUrl =
92486 * 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
92487 * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
92488 * const zeros = tf.zeros([1, 224, 224, 3]);
92489 * model.predict(zeros).print();
92490 * ```
92491 * @param modelUrl The url or an `io.IOHandler` that loads the model.
 * @param options Options for the HTTP request, which allow sending
 * credentials and custom headers. Set `fromTFHub: true` to load from a
 * TF Hub URL.
92495 *
92496 * @doc {heading: 'Models', subheading: 'Loading'}
92497 */
function loadGraphModel(_x8) {
  // Public entry point: forwards all arguments (modelUrl, options, tfio)
  // unchanged to the async implementation.
  var forwardedArgs = arguments;
  return _loadGraphModel.apply(this, forwardedArgs);
}
/**
 * Load a graph model given a synchronous IO handler with a 'load' method.
 * (This comment documents `loadGraphModelSync`; the transpiled
 * `_loadGraphModel` helper sits between this comment and that function.)
 *
 * @param modelSource The `io.IOHandlerSync` that loads the model, or the
 * `io.ModelArtifacts` that encode the model, or a tuple of
 * `[io.ModelJSON, ArrayBuffer]` of which the first element encodes the
 * model and the second contains the weights.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
function _loadGraphModel() {
  // Async implementation behind loadGraphModel. On first call it replaces
  // `_loadGraphModel` with the generated async function and forwards to it,
  // so later calls go straight to that function.
  _loadGraphModel = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6(modelUrl) {
    var options,
      tfio,
      model,
      _args6 = arguments;
    return _regeneratorRuntime().wrap(function _callee6$(_context6) {
      while (1) switch (_context6.prev = _context6.next) {
        case 0:
          // Defaults: options = {}, tfio = the bundled `io` namespace.
          options = _args6.length > 1 && _args6[1] !== undefined ? _args6[1] : {};
          tfio = _args6.length > 2 && _args6[2] !== undefined ? _args6[2] : io;
          if (!(modelUrl == null)) {
            _context6.next = 4;
            break;
          }
          throw new Error('modelUrl in loadGraphModel() cannot be null. Please provide a url ' + 'or an IOHandler that loads the model');
        case 4:
          // Only an explicit `null` reaches here (undefined was defaulted above).
          if (options == null) {
            options = {};
          }
          // TF Hub URLs are rewritten to point at the model file.
          if (options.fromTFHub && typeof modelUrl === 'string') {
            modelUrl = getTFHubUrl(modelUrl);
          }
          model = new GraphModel(modelUrl, options, tfio);
          _context6.next = 9;
          // Equivalent to `await model.load()`.
          return model.load();
        case 9:
          return _context6.abrupt("return", model);
        case 10:
        case "end":
          return _context6.stop();
      }
    }, _callee6);
  }));
  return _loadGraphModel.apply(this, arguments);
}
function loadGraphModelSync(modelSource) {
  // Synchronously builds a GraphModel from an IOHandlerSync, a ModelArtifacts
  // object, or a [modelJSON, weightsArrayBuffer] tuple.
  if (modelSource == null) {
    throw new Error('modelUrl in loadGraphModelSync() cannot be null. Please provide ' + 'model artifacts or an IOHandler that loads the model');
  }
  // The `in` checks below throw a TypeError on primitives (e.g. a URL string
  // passed by mistake); report those as an unsupported format instead.
  if (typeof modelSource !== 'object' && typeof modelSource !== 'function') {
    throw new Error('Unknown model format');
  }
  var ioHandler;
  // Array.isArray (unlike `instanceof Array`) also recognizes arrays created
  // in another realm, such as an iframe or a Node `vm` context.
  if (Array.isArray(modelSource)) {
    var modelJSON = modelSource[0];
    var weights = modelSource[1];
    if (!modelJSON) {
      throw new Error('modelJSON must be the first element of the array');
    }
    if (!weights || !(weights instanceof ArrayBuffer)) {
      throw new Error('An ArrayBuffer of weights must be the second element of' + ' the array');
    }
    if (!('modelTopology' in modelJSON)) {
      throw new Error('Model JSON is missing \'modelTopology\'');
    }
    if (!('weightsManifest' in modelJSON)) {
      throw new Error('Model JSON is missing \'weightsManifest\'');
    }
    var weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
    var modelArtifacts = getModelArtifactsForJSONSync(modelJSON, weightSpecs, weights);
    ioHandler = fromMemorySync(modelArtifacts);
  } else if ('load' in modelSource) {
    // Then modelSource is already an IOHandlerSync.
    ioHandler = modelSource;
  } else if ('modelTopology' in modelSource && 'weightSpecs' in modelSource && 'weightData' in modelSource) {
    // modelSource is of type ModelArtifacts.
    ioHandler = fromMemorySync(modelSource);
  } else {
    throw new Error('Unknown model format');
  }
  var model = new GraphModel(ioHandler);
  model.load();
  return model;
}
function getTFHubUrl(modelUrl) {
  // Ensures exactly one trailing slash is present before appending the
  // default model file name and the TF Hub search parameter.
  var base = modelUrl.endsWith('/') ? modelUrl : modelUrl + '/';
  return base.concat(DEFAULT_MODEL_NAME, TFHUB_SEARCH_PARAM);
}
92590
92591 /** @license See the LICENSE file. */
92592 // This code is auto-generated, do not modify this file!
// Version string emitted by the build for this section of the bundle
// (presumably the converter package, given the surrounding GraphModel code;
// confirm against the generating build script).
var version$5 = '4.22.0';
92594
92595 /**
92596 * @license
92597 * Copyright 2018 Google LLC. All Rights Reserved.
92598 * Licensed under the Apache License, Version 2.0 (the "License");
92599 * you may not use this file except in compliance with the License.
92600 * You may obtain a copy of the License at
92601 *
92602 * http://www.apache.org/licenses/LICENSE-2.0
92603 *
92604 * Unless required by applicable law or agreed to in writing, software
92605 * distributed under the License is distributed on an "AS IS" BASIS,
92606 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92607 * See the License for the specific language governing permissions and
92608 * limitations under the License.
92609 * =============================================================================
92610 */
92611
92612 /**
92613 * Apply a mapping function to a nested structure in a recursive manner.
92614 *
92615 * The result of the mapping is an object with the same nested structure (i.e.,
92616 * of arrays and dicts) as the input, except that some subtrees are replaced,
92617 * according to the results of the mapping function.
92618 *
 * Mappings of nodes that `mapFn` does not recurse into are memoized: if the
 * nested structure contains the same such object in multiple positions, the
 * output will contain the same mapped value in those positions. Nodes that
 * are recursed into are mapped afresh each time. Cycles are not supported.
92622 *
92623 * @param input: The object to which to apply the mapping function.
92624 * @param mapFn: A function that expects a single node of the object tree, and
92625 * returns a `DeepMapResult`. The `DeepMapResult` either provides a
92626 * replacement value for that node (i.e., replacing the subtree), or indicates
92627 * that the node should be processed recursively.
92628 */
function deepMap(input, mapFn) {
  // Start a fresh traversal with an empty memo table and ancestor set.
  return deepMapInternal(input, mapFn);
}
/**
 * Recursive worker for `deepMap`.
 *
 * @param seen: (optional, defaults to a new Map) memoized `mapFn()` results
 *   for nodes that were NOT recursed into; such a node seen again yields the
 *   same mapped value.
 * @param containedIn: (optional, defaults to a new Set) a set containing the
 *   objects on the reference path currently being processed (used to detect
 *   cycles).
 */
function deepMapInternal(input, mapFn) {
  var seen = arguments[2] !== undefined ? arguments[2] : new Map();
  var containedIn = arguments[3] !== undefined ? arguments[3] : new Set();
  if (input == null) {
    return null;
  }
  // Blobs are copied rather than traversed.
  if (typeof Blob === 'function' && input instanceof Blob) {
    return input.slice();
  }
  if (containedIn.has(input)) {
    throw new Error('Circular references are not supported.');
  }
  if (seen.has(input)) {
    return seen.get(input);
  }
  var mapped = mapFn(input);
  if (mapped.recurse && mapped.value !== null) {
    throw new Error('A deep map function may not return both a value and recurse=true.');
  }
  if (!mapped.recurse) {
    seen.set(input, mapped.value);
    return mapped.value;
  }
  if (!isIterable(input)) {
    throw new Error("Can't recurse into non-iterable type: ".concat(input));
  }
  // tslint:disable-next-line:no-any
  var output = Array.isArray(input) ? [] : {};
  containedIn.add(input);
  for (var key in input) {
    output[key] = deepMapInternal(input[key], mapFn, seen, containedIn);
  }
  containedIn.delete(input);
  // Give the mapped container the same prototype as the input.
  if (input.__proto__) {
    output.__proto__ = input.__proto__;
  }
  return output;
}
92678 // TODO(soergel, kangyizhang) Reconsider naming of deepZip() to avoid confusion
92679 // with zip()
92680 /**
92681 * Zip nested structures together in a recursive manner.
92682 *
92683 * This has the effect of transposing or pivoting data, e.g. converting it from
92684 * a row-major representation to a column-major representation.
92685 *
92686 * For example, `deepZip([{a: 1, b: 2}, {a: 3, b: 4}])` returns
92687 * `{a: [1, 3], b: [2, 4]}`.
92688 *
92689 * The inputs should all have the same nested structure (i.e., of arrays and
92690 * dicts). The result is a single object with the same nested structure, where
92691 * the leaves are arrays collecting the values of the inputs at that location
92692 * (or, optionally, the result of a custom function applied to those arrays).
92693 *
92694 * @param inputs: An array of the objects to zip together.
92695 * @param zipFn: (optional) A function that expects an array of elements at a
92696 * single node of the object tree, and returns a `DeepMapResult`. The
92697 * `DeepMapResult` either provides a result value for that node (i.e.,
92698 * representing the subtree), or indicates that the node should be processed
92699 * recursively. The default zipFn recurses as far as possible and places
92700 * arrays at the leaves.
92701 */
92702 function deepZip(inputs) {
92703 var zipFn = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : zipToList;
92704 return deepZipInternal(inputs, zipFn);
92705 }
/**
 * Recursive worker behind `deepZip`.
 *
 * @param inputs: The objects being zipped. Recursion follows the structure of
 *   `inputs[0]`; the other inputs are assumed to share that structure.
 * @param zipFn: Maps the array of values found at one node to a
 *   `{value, recurse}` result.
 * @param containedIn: (optional) A set containing the objects on the reference
 *   path currently being processed (used to detect cycles).
 * @returns The zipped value for this node.
 * @throws Error on a circular reference, when `zipFn` returns both a non-null
 *   value and `recurse: true`, or when recursion is requested into a node that
 *   `isIterable` rejects.
 */
function deepZipInternal(inputs, zipFn) {
  var ancestors = arguments[2] === undefined ? new Set() : arguments[2];
  var head = inputs[0];
  if (ancestors.has(head)) {
    throw new Error('Circular references are not supported.');
  }
  var decision = zipFn(inputs);
  if (decision.recurse && decision.value !== null) {
    throw new Error('A deep zip function may not return both a value and recurse=true.');
  }
  if (!decision.recurse) {
    return decision.value;
  }
  if (!isIterable(head)) {
    throw new Error("Can't recurse into non-iterable type: ".concat(head));
  }
  // Gathers the values found under `key` across all inputs.
  var columnAt = function (key) {
    return inputs.map(function (item) {
      return item[key];
    });
  };
  // tslint:disable-next-line:no-any
  var zipped = Array.isArray(head) ? [] : {};
  // Track the current path only; shared (non-cyclic) references are allowed.
  ancestors.add(head);
  for (var key in head) {
    zipped[key] = deepZipInternal(columnAt(key), zipFn, ancestors);
  }
  ancestors.delete(head);
  return zipped;
}
/**
 * Default `zipFn` for `deepZip`: recurse while the first value at a node is
 * iterable, otherwise keep the whole array of values as the leaf.
 *
 * @param x The array of values found at one node across all inputs.
 * @returns `null` when `x` is null; otherwise a `{value, recurse}` result.
 */
// tslint:disable-next-line:no-any
function zipToList(x) {
  if (x === null) {
    return null;
  }
  // TODO(soergel): validate array type?
  var shouldRecurse = isIterable(x[0]);
  return shouldRecurse ? {
    value: null,
    recurse: true
  } : {
    value: x,
    recurse: false
  };
}
/**
 * Applies an async mapping function to a nested structure recursively.
 *
 * A nested structure of Promises is built first; those Promises are then all
 * awaited, so the caller gets a single Promise for the fully resolved
 * structure.
 *
 * The output mirrors the nesting (arrays and dicts) of `input`, except where
 * the mapping function chose to replace a subtree.
 *
 * Mappings are memoized: when the same object appears at several positions,
 * the output holds the same mapped object at each of them. Cycles are not
 * supported.
 *
 * @param input: The object to which to apply the mapping function.
 * @param mapFn: A function that receives a single node of the object tree and
 *   returns a `DeepMapAsyncResult`. That result either supplies a `Promise`
 *   for a replacement value for the node (replacing its subtree), or asks for
 *   the node to be processed recursively. Whether to recurse must be decided
 *   synchronously; only the mapped value may be promised.
 */
function deepMapAndAwaitAll(_x, _x2) {
  // Forward every argument unchanged to the transpiled async implementation.
  var forwarded = Array.prototype.slice.call(arguments);
  return _deepMapAndAwaitAll.apply(this, forwarded);
}
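/**
 * (Documentation for `isIterable`, which the transpiler placed after
 * `_deepMapAndAwaitAll`.) Determine whether the argument is iterable.
 *
 * @returns true if the argument is an array, or an object that is not a
 *   Tensor, Promise, ArrayBuffer view (e.g. a TypedArray), or text/string
 *   decoder.
 */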
92792 // tslint:disable-next-line:no-any
// Transpiled (regenerator) implementation of `deepMapAndAwaitAll(input, mapFn)`.
// On its first call it replaces itself with the `_asyncToGenerator` wrapper and
// then delegates to that wrapper via `apply`.
// The state machine below performs three phases:
//   1. (state 0) deepMapInternal(input, mapFn, seen) fills the `seen` Map,
//      whose values may be Promises.
//   2. (states 3-11) Iterates a snapshot of `seen`'s keys and, for each value
//      where isPromise(value) is true, awaits it and stores the resolved value
//      back under the same key. The awaits happen one at a time.
//   3. (state 14) Runs deepMapInternal again with the resolved `seen` Map and
//      returns its result.
function _deepMapAndAwaitAll() {
  _deepMapAndAwaitAll = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(input, mapFn) {
    var seen, _i, _Array$from, key, value, mappedValue, result;
    return _regeneratorRuntime().wrap(function _callee$(_context) {
      while (1) switch (_context.prev = _context.next) {
        case 0:
          seen = new Map(); // First do a normal deepMap, collecting Promises in 'seen' as a side effect.
          deepMapInternal(input, mapFn, seen);
          // Replace the Promises in 'seen' in place.
          // Note TypeScript provides no async map iteration, and regular map iteration
          // is broken too, so sadly we have to do Array.from() to make it work.
          // (There's no advantage to Promise.all(), and that would be tricky anyway.)
          _i = 0, _Array$from = Array.from(seen.keys());
        // Falls through into the loop head.
        case 3:
          // Loop head: once every key has been visited, jump to phase 3.
          if (!(_i < _Array$from.length)) {
            _context.next = 14;
            break;
          }
          key = _Array$from[_i];
          value = seen.get(key);
          // Non-Promise values are left as-is; skip straight to the increment.
          if (!isPromise(value)) {
            _context.next = 11;
            break;
          }
          // Yield the Promise to the async driver; execution resumes at
          // state 9 with the resolved value in `_context.sent`.
          _context.next = 9;
          return value;
        case 9:
          mappedValue = _context.sent;
          seen.set(key, mappedValue);
        // Falls through to the loop increment.
        case 11:
          _i++;
          _context.next = 3;
          break;
        case 14:
          // Normal deepMap again, this time filling in the resolved values.
          // It's unfortunate that we have to do two passes.
          // TODO(soergel): test performance and think harder about a fast solution.
          result = deepMapInternal(input, mapFn, seen);
          return _context.abrupt("return", result);
        case 16:
        case "end":
          return _context.stop();
      }
    }, _callee);
  }));
  return _deepMapAndAwaitAll.apply(this, arguments);
}
/**
 * Determine whether the argument is iterable, i.e. a structure that deep
 * map/zip helpers should recurse into.
 *
 * @param obj Any value.
 * @returns true if `obj` is an array, or an object that is not a Tensor,
 *   Promise, ArrayBuffer view (TypedArray/DataView), or text/string decoder;
 *   false otherwise (including null/undefined, primitives and functions).
 */
// tslint:disable-next-line:no-any
function isIterable(obj) {
  // Decide the cheap cases first. Their answer never depends on the decoder
  // check below, so this avoids an env() lookup and (outside the browser) a
  // require() call for every leaf visited by the deep map/zip helpers.
  if (obj == null || ArrayBuffer.isView(obj)) {
    return false;
  }
  if (Array.isArray(obj)) {
    return true;
  }
  if (_typeof(obj) !== 'object' || obj instanceof Tensor || obj instanceof Promise) {
    return false;
  }
  // Only a non-array, non-Tensor, non-Promise object reaches this point; the
  // last remaining exclusion is a text/string decoder instance.
  var isTextDecoder = false;
  if (env().get('IS_BROWSER')) {
    isTextDecoder = obj instanceof TextDecoder;
  } else {
    // tslint:disable-next-line:no-require-imports
    var _require = require('string_decoder'),
      StringDecoder = _require.StringDecoder;
    isTextDecoder = obj instanceof StringDecoder;
  }
  return !isTextDecoder;
}
/**
 * Reports whether the argument can be converted to a Tensor.
 *
 * Accepted: null/undefined, primitives, arrays, Tensors, and anything
 * `isTypedArray` accepts. Everything else is rejected.
 *
 * @returns true if the argument can be converted to Tensor.
 */
// tslint:disable-next-line:no-any
function canTensorify(obj) {
  if (obj == null) {
    return true;
  }
  if (isPrimitive(obj) || Array.isArray(obj)) {
    return true;
  }
  if (_typeof(obj) === 'object' && obj instanceof Tensor) {
    return true;
  }
  return isTypedArray(obj);
}
/**
 * Returns true if the given `value` is a primitive type (null, undefined,
 * boolean, number, string, symbol, bigint); returns false for objects and
 * functions. This is equivalent to Node's `util.isPrimitive`.
 */
function isPrimitive(value) {
  if (value === null) {
    return true;
  }
  var kind = _typeof(value);
  return kind !== 'object' && typeof value !== 'function';
}
92871
92872 /**
92873 * @license
92874 * Copyright 2018 Google LLC. All Rights Reserved.
92875 * Licensed under the Apache License, Version 2.0 (the "License");
92876 * you may not use this file except in compliance with the License.
92877 * You may obtain a copy of the License at
92878 *
92879 * http://www.apache.org/licenses/LICENSE-2.0
92880 *
92881 * Unless required by applicable law or agreed to in writing, software
92882 * distributed under the License is distributed on an "AS IS" BASIS,
92883 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92884 * See the License for the specific language governing permissions and
92885 * limitations under the License.
92886 *
92887 * =============================================================================
92888 */
/**
 * Maps `container` through `deepMap` using `cloneIfTensor`, so every Tensor
 * visited is replaced by its `clone()`. Other leaves are passed through, and
 * iterables are recursed into (per `cloneIfTensor`).
 */
function deepClone(container) {
  var mapper = cloneIfTensor;
  return deepMap(container, mapper);
}
/**
 * Deep-map callback used by `deepClone`.
 *
 * @returns `{value: item.clone(), recurse: false}` for a Tensor,
 *   `{value: null, recurse: true}` for an iterable (so it is recursed into),
 *   and `{value: item, recurse: false}` for anything else.
 */
// tslint:disable-next-line: no-any
function cloneIfTensor(item) {
  var isTensor = item instanceof Tensor;
  if (!isTensor && isIterable(item)) {
    return {
      value: null,
      recurse: true
    };
  }
  return {
    value: isTensor ? item.clone() : item,
    recurse: false
  };
}
92911
92912 /**
92913 * @license
92914 * Copyright 2018 Google LLC. All Rights Reserved.
92915 * Licensed under the Apache License, Version 2.0 (the "License");
92916 * you may not use this file except in compliance with the License.
92917 * You may obtain a copy of the License at
92918 *
92919 * http://www.apache.org/licenses/LICENSE-2.0
92920 *
92921 * Unless required by applicable law or agreed to in writing, software
92922 * distributed under the License is distributed on an "AS IS" BASIS,
92923 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92924 * See the License for the specific language governing permissions and
92925 * limitations under the License.
92926 *
92927 * =============================================================================
92928 */
/**
 * A ring buffer, providing O(1) FIFO, LIFO, and related operations.
 *
 * `begin` (inclusive) and `end` (exclusive) are logical indices kept in the
 * range 0 <= index < 2*capacity; the storage slot for a logical index is
 * `index % capacity`.
 */
var RingBuffer = /*#__PURE__*/function () {
  /**
   * Constructs a `RingBuffer`.
   * @param capacity The number of items that the buffer can accommodate.
   * @throws RangeError if `capacity` is null/undefined or less than 1.
   */
  function RingBuffer(capacity) {
    _classCallCheck(this, RingBuffer);
    this.capacity = capacity;
    // Note we store the indices in the range 0 <= index < 2*capacity.
    // This allows us to distinguish the full from the empty case.
    // See https://www.snellman.net/blog/archive/2016-12-13-ring-buffers/
    this.begin = 0; // inclusive
    this.end = 0; // exclusive
    if (capacity == null) {
      throw new RangeError('Can\'t create a ring buffer of unknown capacity.');
    }
    if (capacity < 1) {
      throw new RangeError('Can\'t create ring buffer of capacity < 1.');
    }
    this.data = new Array(capacity);
    this.doubledCapacity = 2 * capacity;
  }
  _createClass(RingBuffer, [{
    key: "wrap",
    /**
     * Map any index into the range 0 <= index < 2*capacity.
     */
    value: function wrap(index) {
      // don't trust % on negative numbers
      while (index < 0) {
        index += this.doubledCapacity;
      }
      return index % this.doubledCapacity;
    }
  }, {
    key: "get",
    /**
     * Returns the item in the storage slot for logical `index`.
     * @throws RangeError if `index` is negative.
     */
    value: function get(index) {
      if (index < 0) {
        throw new RangeError('Can\'t get item at a negative index.');
      }
      return this.data[index % this.capacity];
    }
  }, {
    key: "set",
    /**
     * Stores `value` in the storage slot for logical `index`.
     * @throws RangeError if `index` is negative.
     */
    value: function set(index, value) {
      if (index < 0) {
        throw new RangeError('Can\'t set item at a negative index.');
      }
      this.data[index % this.capacity] = value;
    }
  }, {
    key: "length",
    /**
     * Returns the current number of items in the buffer.
     */
    value: function length() {
      var length = this.end - this.begin;
      if (length < 0) {
        length = this.doubledCapacity + length;
      }
      return length;
    }
  }, {
    key: "isFull",
    /**
     * Reports whether the buffer is full.
     * @returns true if the number of items in the buffer equals its capacity,
     *   and false otherwise.
     */
    value: function isFull() {
      return this.length() === this.capacity;
    }
  }, {
    key: "isEmpty",
    /**
     * Reports whether the buffer is empty.
     * @returns true if the number of items in the buffer equals zero, and
     *   false otherwise.
     */
    value: function isEmpty() {
      return this.length() === 0;
    }
  }, {
    key: "push",
    /**
     * Adds an item to the end of the buffer.
     * @throws RangeError if the buffer is full.
     */
    value: function push(value) {
      if (this.isFull()) {
        throw new RangeError('Ring buffer is full.');
      }
      this.set(this.end, value);
      this.end = this.wrap(this.end + 1);
    }
  }, {
    key: "pushAll",
    /**
     * Adds many items to the end of the buffer, in order.
     */
    value: function pushAll(values) {
      var _iterator = _createForOfIteratorHelper(values),
        _step;
      try {
        for (_iterator.s(); !(_step = _iterator.n()).done;) {
          var value = _step.value;
          this.push(value);
        }
      } catch (err) {
        _iterator.e(err);
      } finally {
        _iterator.f();
      }
    }
  }, {
    key: "pop",
    /**
     * Removes and returns the last item in the buffer.
     * @throws RangeError if the buffer is empty.
     */
    value: function pop() {
      if (this.isEmpty()) {
        throw new RangeError('Ring buffer is empty.');
      }
      this.end = this.wrap(this.end - 1);
      var result = this.get(this.end);
      this.set(this.end, undefined);
      return result;
    }
  }, {
    key: "unshift",
    /**
     * Adds an item to the beginning of the buffer.
     * @throws RangeError if the buffer is full.
     */
    value: function unshift(value) {
      if (this.isFull()) {
        throw new RangeError('Ring buffer is full.');
      }
      this.begin = this.wrap(this.begin - 1);
      this.set(this.begin, value);
    }
  }, {
    key: "shift",
    /**
     * Removes and returns the first item in the buffer.
     * @throws RangeError if the buffer is empty.
     */
    value: function shift() {
      if (this.isEmpty()) {
        throw new RangeError('Ring buffer is empty.');
      }
      var result = this.get(this.begin);
      this.set(this.begin, undefined);
      this.begin = this.wrap(this.begin + 1);
      return result;
    }
  }, {
    key: "shuffleExcise",
    /**
     * Removes and returns a specific item in the buffer, and moves the last
     * item to the vacated slot. This is useful for implementing a shuffling
     * stream. Note that this operation necessarily scrambles the original
     * order.
     *
     * @param relativeIndex: the index of the item to remove, relative to the
     *   first item in the buffer (i.e., hiding the ring nature of the
     *   underlying storage). Must be an integer with
     *   0 <= relativeIndex < length().
     * @throws RangeError if the buffer is empty or `relativeIndex` is out of
     *   range.
     */
    value: function shuffleExcise(relativeIndex) {
      if (this.isEmpty()) {
        throw new RangeError('Ring buffer is empty.');
      }
      var length = this.length();
      // Without this check an out-of-range index reads an empty slot and then
      // writes the popped last item outside the live range, losing it.
      if (!(relativeIndex >= 0 && relativeIndex < length) || Math.floor(relativeIndex) !== relativeIndex) {
        throw new RangeError("Can't excise item at index ".concat(relativeIndex, " from a ring buffer of length ").concat(length, "."));
      }
      var index = this.wrap(this.begin + relativeIndex);
      var result = this.get(index);
      var last = this.pop();
      // If the excised item was itself the last one, pop() already removed it;
      // writing it back would leave a stale reference outside the live range.
      if (relativeIndex !== length - 1) {
        this.set(index, last);
      }
      return result;
    }
  }]);
  return RingBuffer;
}();
93108
/**
 * A `RingBuffer` that never reports itself full: `push` and `unshift` double
 * the storage capacity first whenever the underlying buffer is at capacity.
 */
var GrowingRingBuffer = /*#__PURE__*/function (_RingBuffer) {
  _inherits(GrowingRingBuffer, _RingBuffer);
  var _super = _createSuper(GrowingRingBuffer);
  // Invokes the parent-class implementation of method `name` on `self`.
  function superCall(self, name, args) {
    return _get(_getPrototypeOf(GrowingRingBuffer.prototype), name, self).apply(self, args);
  }
  /**
   * Constructs a `GrowingRingBuffer` starting at
   * `GrowingRingBuffer.INITIAL_CAPACITY` slots.
   */
  function GrowingRingBuffer() {
    _classCallCheck(this, GrowingRingBuffer);
    return _super.call(this, GrowingRingBuffer.INITIAL_CAPACITY);
  }
  _createClass(GrowingRingBuffer, [{
    key: "isFull",
    // Always false: the buffer grows on demand instead of rejecting items.
    value: function isFull() {
      return false;
    }
  }, {
    key: "push",
    value: function push(value) {
      var atCapacity = superCall(this, "isFull", []);
      if (atCapacity) {
        this.expand();
      }
      superCall(this, "push", [value]);
    }
  }, {
    key: "unshift",
    value: function unshift(value) {
      var atCapacity = superCall(this, "isFull", []);
      if (atCapacity) {
        this.expand();
      }
      superCall(this, "unshift", [value]);
    }
  }, {
    key: "expand",
    /**
     * Doubles the capacity of the buffer, moving the live items so that they
     * start at storage slot 0.
     */
    value: function expand() {
      var count = this.length();
      var grownCapacity = 2 * this.capacity;
      var grown = new Array(grownCapacity);
      // The live range may wrap around the end of the old storage, so copy
      // items in logical order rather than extending the old array in place.
      for (var offset = 0; offset < count; offset++) {
        grown[offset] = this.get(this.wrap(this.begin + offset));
      }
      this.data = grown;
      this.capacity = grownCapacity;
      this.doubledCapacity = 2 * grownCapacity;
      this.begin = 0;
      this.end = count;
    }
  }]);
  return GrowingRingBuffer;
}(RingBuffer);
GrowingRingBuffer.INITIAL_CAPACITY = 32;
93164
93165 // Here we implement a simple asynchronous iterator.
93166 // This lets us avoid using either third-party stream libraries or
93167 // recent TypeScript language support requiring polyfills.
/**
 * Create a `LazyIterator` (an `ArrayIterator`) over an array of items.
 *
 * @param items The items to iterate over.
 */
function iteratorFromItems(items) {
  var iterator = new ArrayIterator(items);
  return iterator;
}
/**
 * Create an unbounded `LazyIterator` of incrementing integers, beginning with
 * `start` and never reporting `done`.
 *
 * @param start The first value produced.
 */
function iteratorFromIncrementing(start) {
  var counter = start;
  var produceNext = function () {
    // Postfix increment: yield the current (numeric) value, then advance.
    var current = counter++;
    return {
      value: current,
      done: false
    };
  };
  return iteratorFromFunction(produceNext);
}
/**
 * Create a `LazyIterator` (a `FunctionCallIterator`) that draws each element
 * by calling `func`.
 *
 * ```js
 * let i = -1;
 * const func = () =>
 *    ++i < 5 ? {value: i, done: false} : {value: null, done: true};
 * const iter = tf.data.iteratorFromFunction(func);
 * await iter.forEachAsync(e => console.log(e));
 * ```
 *
 * @param func A function that produces data on each call, as an iterator
 *   result of the form `{value, done}`.
 */
function iteratorFromFunction(func) {
  var iterator = new FunctionCallIterator(func);
  return iterator;
}
/**
 * Create a `LazyIterator` (a `ChainedIterator`) by concatenating underlying
 * streams, which are themselves provided as a stream.
 *
 * This can also be thought of as a "stream flatten" operation.
 *
 * @param baseIterators A stream of streams to be concatenated.
 * @param baseErrorHandler An optional function that can intercept `Error`s
 *   raised during a `next()` call on the base stream. It can decide whether
 *   the error should be propagated, whether it should be ignored, or whether
 *   the base stream should be terminated.
 */
function iteratorFromConcatenated(baseIterators, baseErrorHandler) {
  var chained = new ChainedIterator(baseIterators, baseErrorHandler);
  return chained;
}
/**
 * Create a `LazyIterator` by concatenating the streams obtained from calling
 * a stream-generating function a given number of times.
 *
 * Since a `LazyIterator` is read-once, it cannot be repeated, but this
 * function can be used to achieve a similar effect:
 *
 *   iteratorFromConcatenatedFunction(
 *       () => ({value: new MyIterator(), done: false}), 6);
 *
 * @param iteratorFunc: Passed to `iteratorFromFunction`, so it is called once
 *   per element of the outer stream. Like any `iteratorFromFunction` producer
 *   it presumably returns `{value: stream, done}` — confirm against
 *   `FunctionCallIterator`.
 * @param count: The maximum number of streams to draw (via `take(count)`).
 * @param baseErrorHandler An optional function that can intercept `Error`s
 *   raised during a `next()` call on the base stream. It can decide whether
 *   the error should be propagated, whether it should be ignored, or whether
 *   the base stream should be terminated.
 */
function iteratorFromConcatenatedFunction(iteratorFunc, count, baseErrorHandler) {
  var streamOfStreams = iteratorFromFunction(iteratorFunc).take(count);
  return iteratorFromConcatenated(streamOfStreams, baseErrorHandler);
}
/**
 * Create a `LazyIterator` by zipping together an array, dict, or nested
 * structure of `LazyIterator`s (and perhaps additional constants).
 *
 * The underlying streams must provide elements in a consistent order such
 * that they correspond.
 *
 * Typically, the underlying streams should have the same number of elements.
 * If they do not, the behavior is determined by the `mismatchMode` argument.
 *
 * The nested structure of the `iterators` argument determines the structure
 * of elements in the resulting iterator.
 *
 * @param iterators: An array or object containing LazyIterators at the
 *   leaves.
 * @param mismatchMode: Determines what to do when one underlying iterator is
 *   exhausted before the others. `ZipMismatchMode.FAIL` (the default) causes
 *   an error to be thrown in this case. `ZipMismatchMode.SHORTEST` causes the
 *   zipped iterator to terminate as soon as the first underlying stream is
 *   exhausted, so elements remaining on the longer streams are ignored.
 *   `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
 *   in nulls for the exhausted streams, until all streams are exhausted.
 */
function iteratorFromZipped(iterators) {
  var explicitMode = arguments.length > 1 ? arguments[1] : undefined;
  // Resolve the default lazily, only when no mode was supplied.
  var mismatchMode = explicitMode === undefined ? ZipMismatchMode.FAIL : explicitMode;
  return new ZipIterator(iterators, mismatchMode);
}
93264 /**
93265 * An asynchronous iterator, providing lazy access to a potentially
93266 * unbounded stream of elements.
93267 *
93268 * Iterator can be obtained from a dataset:
93269 * `const iter = await dataset.iterator();`
93270 */
93271 var LazyIterator = /*#__PURE__*/function () {
93272 function LazyIterator() {
93273 _classCallCheck(this, LazyIterator);
93274 }
93275 _createClass(LazyIterator, [{
93276 key: "toArray",
93277 value:
93278 /**
93279 * Collect all remaining elements of a bounded stream into an array.
93280 * Obviously this will succeed only for small streams that fit in memory.
93281 * Useful for testing.
93282 *
93283 * @returns A Promise for an array of stream elements, which will resolve
93284 * when the stream is exhausted.
93285 */
93286 function () {
93287 var _toArray = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
93288 var result, x;
93289 return _regeneratorRuntime().wrap(function _callee$(_context) {
93290 while (1) switch (_context.prev = _context.next) {
93291 case 0:
93292 result = [];
93293 _context.next = 3;
93294 return this.next();
93295 case 3:
93296 x = _context.sent;
93297 case 4:
93298 if (x.done) {
93299 _context.next = 11;
93300 break;
93301 }
93302 result.push(x.value);
93303 _context.next = 8;
93304 return this.next();
93305 case 8:
93306 x = _context.sent;
93307 _context.next = 4;
93308 break;
93309 case 11:
93310 return _context.abrupt("return", result);
93311 case 12:
93312 case "end":
93313 return _context.stop();
93314 }
93315 }, _callee, this);
93316 }));
93317 function toArray() {
93318 return _toArray.apply(this, arguments);
93319 }
93320 return toArray;
93321 }()
93322 /**
93323 * Collect all elements of this dataset into an array with prefetching 100
93324 * elements. This is useful for testing, because the prefetch changes the
93325 * order in which the Promises are resolved along the processing pipeline.
93326 * This may help expose bugs where results are dependent on the order of
93327 * Promise resolution rather than on the logical order of the stream (i.e.,
93328 * due to hidden mutable state).
93329 *
93330 * @returns A Promise for an array of stream elements, which will resolve
93331 * when the stream is exhausted.
93332 */
93333 }, {
93334 key: "toArrayForTest",
93335 value: function () {
93336 var _toArrayForTest = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
93337 var stream, result, x;
93338 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
93339 while (1) switch (_context2.prev = _context2.next) {
93340 case 0:
93341 stream = this.prefetch(100);
93342 result = [];
93343 _context2.next = 4;
93344 return stream.next();
93345 case 4:
93346 x = _context2.sent;
93347 case 5:
93348 if (x.done) {
93349 _context2.next = 12;
93350 break;
93351 }
93352 result.push(x.value);
93353 _context2.next = 9;
93354 return stream.next();
93355 case 9:
93356 x = _context2.sent;
93357 _context2.next = 5;
93358 break;
93359 case 12:
93360 return _context2.abrupt("return", result);
93361 case 13:
93362 case "end":
93363 return _context2.stop();
93364 }
93365 }, _callee2, this);
93366 }));
93367 function toArrayForTest() {
93368 return _toArrayForTest.apply(this, arguments);
93369 }
93370 return toArrayForTest;
93371 }()
93372 /**
93373 * Draw items from the stream until it is exhausted.
93374 *
93375 * This can be useful when the stream has side effects but no output. In
93376 * that case, calling this function guarantees that the stream will be
93377 * fully processed.
93378 */
93379 }, {
93380 key: "resolveFully",
93381 value: function () {
93382 var _resolveFully = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
93383 var x;
93384 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
93385 while (1) switch (_context3.prev = _context3.next) {
93386 case 0:
93387 _context3.next = 2;
93388 return this.next();
93389 case 2:
93390 x = _context3.sent;
93391 case 3:
93392 if (x.done) {
93393 _context3.next = 9;
93394 break;
93395 }
93396 _context3.next = 6;
93397 return this.next();
93398 case 6:
93399 x = _context3.sent;
93400 _context3.next = 3;
93401 break;
93402 case 9:
93403 case "end":
93404 return _context3.stop();
93405 }
93406 }, _callee3, this);
93407 }));
93408 function resolveFully() {
93409 return _resolveFully.apply(this, arguments);
93410 }
93411 return resolveFully;
93412 }()
93413 /**
93414 * Draw items from the stream until it is exhausted, or a predicate fails.
93415 *
93416 * This can be useful when the stream has side effects but no output. In
93417 * that case, calling this function guarantees that the stream will be
93418 * fully processed.
93419 */
93420 }, {
93421 key: "resolveWhile",
93422 value: function () {
93423 var _resolveWhile = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(predicate) {
93424 var x, shouldContinue;
93425 return _regeneratorRuntime().wrap(function _callee4$(_context4) {
93426 while (1) switch (_context4.prev = _context4.next) {
93427 case 0:
93428 _context4.next = 2;
93429 return this.next();
93430 case 2:
93431 x = _context4.sent;
93432 shouldContinue = predicate(x.value);
93433 case 4:
93434 if (!(!x.done && shouldContinue)) {
93435 _context4.next = 11;
93436 break;
93437 }
93438 _context4.next = 7;
93439 return this.next();
93440 case 7:
93441 x = _context4.sent;
93442 shouldContinue = predicate(x.value);
93443 _context4.next = 4;
93444 break;
93445 case 11:
93446 case "end":
93447 return _context4.stop();
93448 }
93449 }, _callee4, this);
93450 }));
93451 function resolveWhile(_x) {
93452 return _resolveWhile.apply(this, arguments);
93453 }
93454 return resolveWhile;
93455 }()
93456 /**
93457 * Handles errors thrown on this stream using a provided handler function.
93458 *
93459 * @param handler A function that handles any `Error` thrown during a `next()`
93460 * call and returns true if the stream should continue (dropping the failed
93461 * call) or false if the stream should quietly terminate. If the handler
93462 * itself throws (or rethrows) an `Error`, that will be propagated.
93463 *
93464 * @returns A `LazyIterator` of elements passed through from upstream,
93465 * possibly filtering or terminating on upstream `next()` calls that
93466 * throw an `Error`.
93467 */
93468 }, {
93469 key: "handleErrors",
93470 value: function handleErrors(handler) {
93471 return new ErrorHandlingLazyIterator(this, handler);
93472 }
93473 // TODO(soergel): Implement reduce() etc.
93474 /**
93475 * Filters this stream according to `predicate`.
93476 *
93477 * @param predicate A function mapping a stream element to a boolean or a
93478 * `Promise` for one.
93479 *
93480 * @returns A `LazyIterator` of elements for which the predicate was true.
93481 */
93482 }, {
93483 key: "filter",
93484 value: function filter(predicate) {
93485 return new FilterIterator(this, predicate);
93486 }
93487 /**
93488 * Maps this stream through a 1-to-1 transform.
93489 *
93490 * @param transform A function mapping a stream element to a transformed
93491 * element.
93492 *
93493 * @returns A `LazyIterator` of transformed elements.
93494 */
93495 }, {
93496 key: "map",
93497 value: function map(transform) {
93498 return new MapIterator(this, transform);
93499 }
93500 /**
93501 * Maps this stream through an async 1-to-1 transform.
93502 *
93503 * @param transform A function mapping a stream element to a `Promise` for a
93504 * transformed stream element.
93505 *
93506 * @returns A `LazyIterator` of transformed elements.
93507 */
93508 }, {
93509 key: "mapAsync",
93510 value: function mapAsync(transform) {
93511 return new AsyncMapIterator(this, transform);
93512 }
93513 /**
93514 * Maps this stream through a 1-to-1 transform, forcing serial execution.
93515 *
93516 * @param transform A function mapping a stream element to a transformed
93517 * element.
93518 *
93519 * @returns A `LazyIterator` of transformed elements.
93520 */
93521 }, {
93522 key: "serialMapAsync",
93523 value: function serialMapAsync(transform) {
93524 return new AsyncMapIterator(this, transform).serial();
93525 }
93526 /**
93527 * Maps this stream through a 1-to-many transform.
93528 *
93529 * @param transform A function mapping a stream element to an array of
93530 * transformed elements.
93531 *
93532 * @returns A `DataStream` of transformed elements.
93533 */
93534 }, {
93535 key: "flatmap",
93536 value: function flatmap(transform) {
93537 return new FlatmapIterator(this, transform);
93538 }
93539 /**
93540 * Apply a function to every element of the stream.
93541 *
93542 * @param f A function to apply to each stream element.
93543 */
93544 }, {
93545 key: "forEachAsync",
93546 value: function () {
93547 var _forEachAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(f) {
93548 return _regeneratorRuntime().wrap(function _callee5$(_context5) {
93549 while (1) switch (_context5.prev = _context5.next) {
93550 case 0:
93551 return _context5.abrupt("return", this.map(f).resolveFully());
93552 case 1:
93553 case "end":
93554 return _context5.stop();
93555 }
93556 }, _callee5, this);
93557 }));
93558 function forEachAsync(_x2) {
93559 return _forEachAsync.apply(this, arguments);
93560 }
93561 return forEachAsync;
93562 }()
93563 /**
93564 * Apply a function to every element of the stream, forcing serial execution.
93565 *
93566 * @param f A function to apply to each stream element. Should return 'true'
93567 * to indicate that the stream should continue, or 'false' to cause it to
93568 * terminate.
93569 */
93570 }, {
93571 key: "serialForEach",
93572 value: function () {
93573 var _serialForEach = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6(f) {
93574 return _regeneratorRuntime().wrap(function _callee6$(_context6) {
93575 while (1) switch (_context6.prev = _context6.next) {
93576 case 0:
93577 return _context6.abrupt("return", this.serialMapAsync(f).resolveWhile(function (x) {
93578 return x === true;
93579 }));
93580 case 1:
93581 case "end":
93582 return _context6.stop();
93583 }
93584 }, _callee6, this);
93585 }));
93586 function serialForEach(_x3) {
93587 return _serialForEach.apply(this, arguments);
93588 }
93589 return serialForEach;
93590 }()
93591 /**
93592 * Groups elements into batches, represented as arrays of elements.
93593 *
93594 * We can think of the elements of this iterator as 'rows' (even if they are
93595 * nested structures). By the same token, consecutive values for a given
93596 * key within the elements form a 'column'. This matches the usual sense of
93597 * 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
93598 *
93599 * Thus, "Row-major" means that the resulting batch is simply a collection of
93600 * rows: `[row1, row2, row3, ...]`. This is contrast to the column-major
93601 * form, which is needed for vectorized computation.
93602 *
93603 * @param batchSize The number of elements desired per batch.
93604 * @param smallLastBatch Whether to emit the final batch when it has fewer
93605 * than batchSize elements. Default true.
93606 * @returns A `LazyIterator` of batches of elements, represented as arrays
93607 * of the original element type.
93608 */
93609 }, {
93610 key: "rowMajorBatch",
93611 value: function rowMajorBatch(batchSize) {
93612 var smallLastBatch = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
93613 return new RowMajorBatchIterator(this, batchSize, smallLastBatch);
93614 }
93615 /**
93616 * Groups elements into batches, represented in column-major form.
93617 *
93618 * We can think of the elements of this iterator as 'rows' (even if they are
93619 * nested structures). By the same token, consecutive values for a given
93620 * key within the elements form a 'column'. This matches the usual sense of
93621 * 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
93622 *
93623 * Thus, "column-major" means that the resulting batch is a (potentially
93624 * nested) structure representing the columns. Each column entry, then,
93625 * contains a collection of the values found in that column for a range of
93626 * input elements. This representation allows for vectorized computation, in
93627 * contrast to the row-major form.
93628 *
93629 * The inputs should all have the same nested structure (i.e., of arrays and
93630 * dicts). The result is a single object with the same nested structure,
93631 * where the leaves are arrays collecting the values of the inputs at that
93632 * location (or, optionally, the result of a custom function applied to those
93633 * arrays).
93634 *
93635 * @param batchSize The number of elements desired per batch.
93636 * @param smallLastBatch Whether to emit the final batch when it has fewer
93637 * than batchSize elements. Default true.
93638 * @param zipFn: (optional) A function that expects an array of elements at a
93639 * single node of the object tree, and returns a `DeepMapResult`. The
93640 * `DeepMapResult` either provides a result value for that node (i.e.,
93641 * representing the subtree), or indicates that the node should be processed
93642 * recursively. The default zipFn recurses as far as possible and places
93643 * arrays at the leaves.
93644 * @returns A `LazyIterator` of batches of elements, represented as an object
93645 * with collections at the leaves.
93646 */
93647 }, {
93648 key: "columnMajorBatch",
93649 value: function columnMajorBatch(batchSize) {
93650 var smallLastBatch = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
93651 var zipFn = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : zipToList;
93652 // First collect the desired number of input elements as a row-major batch.
93653 var rowBatches = this.rowMajorBatch(batchSize, smallLastBatch);
93654 // Now 'rotate' or 'pivot' the data, collecting all values from each column
93655 // in the batch (i.e., for each key within the elements) into an array.
93656 return rowBatches.map(function (x) {
93657 return deepZip(x, zipFn);
93658 });
93659 }
93660 /**
93661 * Concatenate this `LazyIterator` with another.
93662 *
93663 * @param iterator A `LazyIterator` to be concatenated onto this one.
93664 * @param baseErrorHandler An optional function that can intercept `Error`s
93665 * raised during a `next()` call on the base stream. This function can
93666 * decide whether the error should be propagated, whether the error should
93667 * be ignored, or whether the base stream should be terminated.
93668 * @returns A `LazyIterator`.
93669 */
93670 }, {
93671 key: "concatenate",
93672 value: function concatenate(iterator, baseErrorHandler) {
93673 return new ChainedIterator(iteratorFromItems([this, iterator]), baseErrorHandler);
93674 }
93675 /**
93676 * Limits this stream to return at most `count` items.
93677 *
93678 * @param count The maximum number of items to provide from the stream. If
93679 * a negative or undefined value is given, the entire stream is returned
93680 * unaltered.
93681 */
93682 }, {
93683 key: "take",
93684 value: function take(count) {
93685 if (count < 0 || count == null) {
93686 return this;
93687 }
93688 return new TakeIterator(this, count);
93689 }
93690 /**
93691 * Skips the first `count` items in this stream.
93692 *
93693 * @param count The number of items to skip. If a negative or undefined
93694 * value is given, the entire stream is returned unaltered.
93695 */
93696 }, {
93697 key: "skip",
93698 value: function skip(count) {
93699 if (count < 0 || count == null) {
93700 return this;
93701 }
93702 return new SkipIterator(this, count);
93703 }
93704 /**
93705 * Prefetch the first `bufferSize` items in this stream.
93706 *
93707 * Note this prefetches Promises, but makes no guarantees about when those
93708 * Promises resolve.
93709 *
93710 * @param bufferSize: An integer specifying the number of elements to be
93711 * prefetched.
93712 */
93713 }, {
93714 key: "prefetch",
93715 value: function prefetch(bufferSize) {
93716 return new PrefetchIterator(this, bufferSize);
93717 }
93718 // TODO(soergel): deep sharded shuffle, where supported
93719 /**
93720 * Randomly shuffles the elements of this stream.
93721 *
93722 * @param windowSize: An integer specifying the number of elements from
93723 * this stream from which the new stream will sample.
93724 * @param seed: (Optional.) An integer specifying the random seed that
93725 * will be used to create the distribution.
93726 */
93727 }, {
93728 key: "shuffle",
93729 value: function shuffle(windowSize, seed) {
93730 return new ShuffleIterator(this, windowSize, seed);
93731 }
93732 /**
93733 * Force an iterator to execute serially: each next() call will await the
93734 * prior one, so that they cannot execute concurrently.
93735 */
93736 }, {
93737 key: "serial",
93738 value: function serial() {
93739 return new SerialIterator(this);
93740 }
93741 }]);
93742 return LazyIterator;
93743 }();
93744 // ============================================================================
93745 // The following private classes serve to implement the chainable methods
93746 // on LazyIterator. Unfortunately they can't be placed in separate files,
93747 // due to resulting trouble with circular imports.
93748 // ============================================================================
93749 // Iterators that just extend LazyIterator directly
93750 // ============================================================================
var ArrayIterator = /*#__PURE__*/function (Base) {
  _inherits(ArrayIterator, Base);
  var callSuper = _createSuper(ArrayIterator);

  /**
   * Iterates over an in-memory array, handing out a `deepClone` of each
   * element in index order.
   *
   * @param items The backing array. It is read by index and never modified.
   */
  function ArrayIterator(items) {
    var instance;
    _classCallCheck(this, ArrayIterator);
    instance = callSuper.call(this);
    instance.items = items;
    // Index of the next element to hand out.
    instance.trav = 0;
    return instance;
  }

  _createClass(ArrayIterator, [{
    key: "summary",
    value: function summary() {
      return "Array of ".concat(this.items.length, " items");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function arrayNext() {
        var current;
        return _regeneratorRuntime().wrap(function arrayNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Every element has been handed out: report completion.
              if (!(this.trav >= this.items.length)) {
                ctx.next = 2;
                break;
              }
              return ctx.abrupt("return", {
                value: null,
                done: true
              });
            case 2:
              current = this.items[this.trav];
              this.trav++;
              // Return a deep copy rather than the stored element itself.
              return ctx.abrupt("return", {
                value: deepClone(current),
                done: false
              });
            case 5:
            case "end":
              return ctx.stop();
          }
        }, arrayNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }]);
  return ArrayIterator;
}(LazyIterator);
var FunctionCallIterator = /*#__PURE__*/function (Base) {
  _inherits(FunctionCallIterator, Base);
  var callSuper = _createSuper(FunctionCallIterator);

  /**
   * Prefixes an error raised by `nextFn` with dataset context, then rethrows it.
   *
   * Only object errors are annotated. A primitive throw value (e.g. a thrown
   * string) cannot carry a `message` property; assigning one would itself
   * throw a TypeError in strict mode and hide the real failure. Such values
   * are rethrown untouched.
   *
   * @param e The value thrown or rejected by `nextFn`.
   * @throws Always rethrows `e`.
   */
  function rethrowWithDatasetContext(e) {
    if (e !== null && typeof e === "object") {
      // Modify the error message but leave the stack trace intact.
      e.message = "Error thrown while iterating through a dataset: ".concat(e.message);
    }
    throw e;
  }

  /**
   * An iterator whose elements are produced by repeatedly calling `nextFn`.
   *
   * @param nextFn Called (with this iterator as `this`) once per `next()`.
   *     It may return an iterator result directly or a promise of one.
   */
  function FunctionCallIterator(nextFn) {
    var instance;
    _classCallCheck(this, FunctionCallIterator);
    instance = callSuper.call(this);
    instance.nextFn = nextFn;
    return instance;
  }

  _createClass(FunctionCallIterator, [{
    key: "summary",
    value: function summary() {
      return "Function call";
    }
  }, {
    key: "next",
    /**
     * Produces the next element by invoking `nextFn`.
     *
     * The previous implementation returned `this.nextFn()` from inside a
     * try/catch without awaiting it. A promise rejected by an asynchronous
     * `nextFn` therefore escaped the catch and lost the dataset context.
     * Adopting the result inside a promise covers both synchronous throws and
     * asynchronous rejections.
     *
     * @returns A promise of the iterator result produced by `nextFn`.
     */
    value: function next() {
      var iter = this;
      // The Promise executor runs synchronously, so nextFn is still invoked
      // at call time, exactly as before.
      return new Promise(function (resolve) {
        resolve(iter.nextFn());
      })["catch"](rethrowWithDatasetContext);
    }
  }]);
  return FunctionCallIterator;
}(LazyIterator);
var SerialIterator = /*#__PURE__*/function (Base) {
  _inherits(SerialIterator, Base);
  var callSuper = _createSuper(SerialIterator);

  /**
   * Forces reads from `upstream` to happen one at a time: each next() waits
   * for the previous one before calling `upstream.next()`.
   *
   * @param upstream The `LazyIterator` to serialize.
   */
  function SerialIterator(upstream) {
    var instance;
    _classCallCheck(this, SerialIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    // Tail of the promise chain that orders upstream reads, seeded with an
    // already-resolved placeholder so the first next() has something to
    // chain onto.
    instance.lastRead = Promise.resolve({
      value: null,
      done: false
    });
    return instance;
  }

  _createClass(SerialIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> Serial");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function chainedNext() {
        var iter = this;
        return _regeneratorRuntime().wrap(function chainedNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Replace the chain's tail synchronously. Awaiting the old tail
              // first and only then assigning this.lastRead would let a
              // concurrent next() chain onto the stale tail.
              this.lastRead = this.lastRead.then(function () {
                return iter.serialNext();
              });
              return ctx.abrupt("return", this.lastRead);
            case 2:
            case "end":
              return ctx.stop();
          }
        }, chainedNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }, {
    key: "serialNext",
    value: function () {
      var serialNextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function pullUpstream() {
        return _regeneratorRuntime().wrap(function pullUpstream$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Runs only once the previous read has settled.
              return ctx.abrupt("return", this.upstream.next());
            case 1:
            case "end":
              return ctx.stop();
          }
        }, pullUpstream, this);
      }));
      function serialNext() {
        return serialNextAsync.apply(this, arguments);
      }
      return serialNext;
    }()
  }]);
  return SerialIterator;
}(LazyIterator);
var SkipIterator = /*#__PURE__*/function (Base) {
  _inherits(SkipIterator, Base);
  var callSuper = _createSuper(SkipIterator);

  /**
   * Discards (and disposes) the first `maxCount` upstream elements, then
   * passes each later `upstream.next()` result through unchanged.
   *
   * @param upstream The source `LazyIterator`.
   * @param maxCount How many leading elements to discard.
   */
  function SkipIterator(upstream, maxCount) {
    var instance;
    _classCallCheck(this, SkipIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    instance.maxCount = maxCount;
    // Local progress counter. Reads are serialized through lastRead so that
    // out-of-order execution cannot clobber it.
    instance.count = 0;
    instance.lastRead = Promise.resolve({
      value: null,
      done: false
    });
    return instance;
  }

  _createClass(SkipIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> Skip");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function chainedNext() {
        var iter = this;
        return _regeneratorRuntime().wrap(function chainedNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Replace the chain's tail synchronously. Awaiting the old tail
              // first and only then assigning this.lastRead would let a
              // concurrent next() chain onto the stale tail.
              this.lastRead = this.lastRead.then(function () {
                return iter.serialNext();
              });
              return ctx.abrupt("return", this.lastRead);
            case 2:
            case "end":
              return ctx.stop();
          }
        }, chainedNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }, {
    key: "serialNext",
    value: function () {
      var serialNextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function skipAhead() {
        var discarded;
        return _regeneratorRuntime().wrap(function skipAhead$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Loop head: keep discarding while fewer than maxCount
              // elements have been consumed.
              if (!(this.count++ < this.maxCount)) {
                ctx.next = 9;
                break;
              }
              ctx.next = 3;
              return this.upstream.next();
            case 3:
              discarded = ctx.sent;
              if (!discarded.done) {
                ctx.next = 6;
                break;
              }
              // Upstream ran out while still skipping.
              return ctx.abrupt("return", discarded);
            case 6:
              dispose(discarded.value);
              ctx.next = 0;
              break;
            case 9:
              return ctx.abrupt("return", this.upstream.next());
            case 10:
            case "end":
              return ctx.stop();
          }
        }, skipAhead, this);
      }));
      function serialNext() {
        return serialNextAsync.apply(this, arguments);
      }
      return serialNext;
    }()
  }]);
  return SkipIterator;
}(LazyIterator);
var TakeIterator = /*#__PURE__*/function (Base) {
  _inherits(TakeIterator, Base);
  var callSuper = _createSuper(TakeIterator);

  /**
   * Forwards at most `maxCount` calls to `upstream.next()`. Later calls
   * report completion without touching upstream.
   *
   * @param upstream The source `LazyIterator`.
   * @param maxCount Maximum number of forwarded reads.
   */
  function TakeIterator(upstream, maxCount) {
    var instance;
    _classCallCheck(this, TakeIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    instance.maxCount = maxCount;
    instance.count = 0;
    return instance;
  }

  _createClass(TakeIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> Take");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function takeNext() {
        return _regeneratorRuntime().wrap(function takeNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // The counter is bumped before any await, i.e. at call time.
              if (!(this.count++ >= this.maxCount)) {
                ctx.next = 2;
                break;
              }
              return ctx.abrupt("return", {
                value: null,
                done: true
              });
            case 2:
              return ctx.abrupt("return", this.upstream.next());
            case 3:
            case "end":
              return ctx.stop();
          }
        }, takeNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }]);
  return TakeIterator;
}(LazyIterator); // Note this batch just groups items into row-wise element arrays.
94054 // Rotating these to a column-wise representation happens only at the dataset
94055 // level.
var RowMajorBatchIterator = /*#__PURE__*/function (Base) {
  _inherits(RowMajorBatchIterator, Base);
  var callSuper = _createSuper(RowMajorBatchIterator);

  /**
   * Groups consecutive upstream elements into arrays ("rows") of up to
   * `batchSize` elements.
   *
   * @param upstream The source `LazyIterator`.
   * @param batchSize Number of elements per emitted batch.
   * @param enableSmallLastBatch (optional, default true) Whether to emit a
   *     final batch holding fewer than `batchSize` elements. When false, the
   *     elements of such an incomplete batch are disposed rather than
   *     silently dropped, matching how SkipIterator and FilterIterator
   *     dispose the elements they discard.
   */
  function RowMajorBatchIterator(upstream, batchSize) {
    var instance;
    var enableSmallLastBatch = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
    _classCallCheck(this, RowMajorBatchIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    instance.batchSize = batchSize;
    instance.enableSmallLastBatch = enableSmallLastBatch;
    // Tail of the promise chain that serializes batch assembly.
    instance.lastRead = Promise.resolve({
      value: null,
      done: false
    });
    return instance;
  }

  _createClass(RowMajorBatchIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> RowMajorBatch");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function chainedNext() {
        var iter = this;
        return _regeneratorRuntime().wrap(function chainedNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Replace the chain's tail synchronously. Awaiting the old tail
              // first and only then assigning this.lastRead would let a
              // concurrent next() chain onto the stale tail.
              this.lastRead = this.lastRead.then(function () {
                return iter.serialNext();
              });
              return ctx.abrupt("return", this.lastRead);
            case 2:
            case "end":
              return ctx.stop();
          }
        }, chainedNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }, {
    key: "serialNext",
    value: function () {
      var serialNextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function fillBatch() {
        var batch, item;
        return _regeneratorRuntime().wrap(function fillBatch$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              batch = [];
            // falls through into the loop head
            case 1:
              if (!(batch.length < this.batchSize)) {
                ctx.next = 12;
                break;
              }
              ctx.next = 4;
              return this.upstream.next();
            case 4:
              item = ctx.sent;
              if (!item.done) {
                ctx.next = 9;
                break;
              }
              // Upstream is exhausted: emit the partial batch if allowed.
              if (!(this.enableSmallLastBatch && batch.length > 0)) {
                ctx.next = 8;
                break;
              }
              return ctx.abrupt("return", {
                value: batch,
                done: false
              });
            case 8:
              // The incomplete batch (possibly empty) is not emitted; release
              // its elements instead of leaking them.
              dispose(batch);
              return ctx.abrupt("return", {
                value: null,
                done: true
              });
            case 9:
              batch.push(item.value);
              ctx.next = 1;
              break;
            case 12:
              return ctx.abrupt("return", {
                value: batch,
                done: false
              });
            case 13:
            case "end":
              return ctx.stop();
          }
        }, fillBatch, this);
      }));
      function serialNext() {
        return serialNextAsync.apply(this, arguments);
      }
      return serialNext;
    }()
  }]);
  return RowMajorBatchIterator;
}(LazyIterator);
var FilterIterator = /*#__PURE__*/function (Base) {
  _inherits(FilterIterator, Base);
  var callSuper = _createSuper(FilterIterator);

  /**
   * Yields only the upstream elements for which `predicate` returns truthy.
   * Rejected elements are disposed.
   *
   * @param upstream The source `LazyIterator`.
   * @param predicate Called with each element's value.
   */
  function FilterIterator(upstream, predicate) {
    var instance;
    _classCallCheck(this, FilterIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    instance.predicate = predicate;
    instance.lastRead = Promise.resolve({
      value: null,
      done: false
    });
    return instance;
  }

  _createClass(FilterIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> Filter");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function chainedNext() {
        var iter = this;
        return _regeneratorRuntime().wrap(function chainedNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Replace the chain's tail synchronously. Awaiting the old tail
              // first and only then assigning this.lastRead would let a
              // concurrent next() chain onto the stale tail.
              this.lastRead = this.lastRead.then(function () {
                return iter.serialNext();
              });
              return ctx.abrupt("return", this.lastRead);
            case 2:
            case "end":
              return ctx.stop();
          }
        }, chainedNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }, {
    key: "serialNext",
    value: function () {
      var serialNextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function nextAccepted() {
        var candidate;
        return _regeneratorRuntime().wrap(function nextAccepted$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Loop until an element passes (or upstream is exhausted).
              ctx.next = 3;
              return this.upstream.next();
            case 3:
              candidate = ctx.sent;
              if (!(candidate.done || this.predicate(candidate.value))) {
                ctx.next = 6;
                break;
              }
              return ctx.abrupt("return", candidate);
            case 6:
              dispose(candidate.value);
              ctx.next = 0;
              break;
            case 9:
            case "end":
              return ctx.stop();
          }
        }, nextAccepted, this);
      }));
      function serialNext() {
        return serialNextAsync.apply(this, arguments);
      }
      return serialNext;
    }()
  }]);
  return FilterIterator;
}(LazyIterator);
var MapIterator = /*#__PURE__*/function (Base) {
  _inherits(MapIterator, Base);
  var callSuper = _createSuper(MapIterator);

  /**
   * Applies a synchronous `transform` to each upstream element.
   *
   * Tensors found in the input but not in the output are disposed.
   * Intermediate tensors created inside `transform` are the transform's own
   * responsibility.
   *
   * @param upstream The source `LazyIterator`.
   * @param transform Maps an element value to a new value.
   */
  function MapIterator(upstream, transform) {
    var instance;
    _classCallCheck(this, MapIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    instance.transform = transform;
    return instance;
  }

  _createClass(MapIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> Map");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function mapNext() {
        var item, before, mapped, after, inputIter, step, tensor;
        return _regeneratorRuntime().wrap(function mapNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              ctx.next = 2;
              return this.upstream.next();
            case 2:
              item = ctx.sent;
              if (!item.done) {
                ctx.next = 5;
                break;
              }
              return ctx.abrupt("return", {
                value: null,
                done: true
              });
            case 5:
              // Record the input tensors before transforming: the transform
              // may mutate the item in place.
              before = getTensorsInContainer(item.value);
              mapped = this.transform(item.value);
              after = getTensorsInContainer(mapped);
              // TODO(soergel) faster intersection
              // TODO(soergel) move to tf.disposeExcept(in, out)?
              inputIter = _createForOfIteratorHelper(before);
              try {
                for (inputIter.s(); !(step = inputIter.n()).done;) {
                  tensor = step.value;
                  // Keep tensors that were passed through to the output.
                  if (!isTensorInList(tensor, after)) {
                    tensor.dispose();
                  }
                }
              } catch (err) {
                inputIter.e(err);
              } finally {
                inputIter.f();
              }
              return ctx.abrupt("return", {
                value: mapped,
                done: false
              });
            case 11:
            case "end":
              return ctx.stop();
          }
        }, mapNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }]);
  return MapIterator;
}(LazyIterator);
var ErrorHandlingLazyIterator = /*#__PURE__*/function (Base) {
  _inherits(ErrorHandlingLazyIterator, Base);
  var callSuper = _createSuper(ErrorHandlingLazyIterator);

  /**
   * Catches errors raised by `upstream.next()` and lets `handler` decide.
   * A truthy result retries the read; a falsy result ends the stream.
   *
   * @param upstream The source `LazyIterator`.
   * @param handler Called with each caught error.
   */
  function ErrorHandlingLazyIterator(upstream, handler) {
    var instance;
    _classCallCheck(this, ErrorHandlingLazyIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    instance.handler = handler;
    // NOTE(review): `count` is not read anywhere in this class.
    instance.count = 0;
    instance.lastRead = Promise.resolve({
      value: null,
      done: false
    });
    return instance;
  }

  _createClass(ErrorHandlingLazyIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> handleErrors");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function chainedNext() {
        var iter = this;
        return _regeneratorRuntime().wrap(function chainedNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Replace the chain's tail synchronously. Awaiting the old tail
              // first and only then assigning this.lastRead would let a
              // concurrent next() chain onto the stale tail.
              this.lastRead = this.lastRead.then(function () {
                return iter.serialNext();
              });
              return ctx.abrupt("return", this.lastRead);
            case 2:
            case "end":
              return ctx.stop();
          }
        }, chainedNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }, {
    key: "serialNext",
    value: function () {
      var serialNextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function guardedRead() {
        return _regeneratorRuntime().wrap(function guardedRead$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Loop head; the try region covers states [1, 7).
              ctx.prev = 1;
              ctx.next = 4;
              return this.upstream.next();
            case 4:
              return ctx.abrupt("return", ctx.sent);
            case 7:
              ctx.prev = 7;
              ctx.t0 = ctx["catch"](1);
              if (this.handler(ctx.t0)) {
                ctx.next = 11;
                break;
              }
              return ctx.abrupt("return", {
                value: null,
                done: true
              });
            case 11:
              // The handler asked to continue: retry the read.
              ctx.next = 0;
              break;
            case 13:
            case "end":
              return ctx.stop();
          }
        }, guardedRead, this, [[1, 7]]);
      }));
      function serialNext() {
        return serialNextAsync.apply(this, arguments);
      }
      return serialNext;
    }()
  }]);
  return ErrorHandlingLazyIterator;
}(LazyIterator);
var AsyncMapIterator = /*#__PURE__*/function (Base) {
  _inherits(AsyncMapIterator, Base);
  var callSuper = _createSuper(AsyncMapIterator);

  /**
   * Like MapIterator, but the result of `transform` is awaited before the
   * input tensors are cleaned up.
   *
   * @param upstream The source `LazyIterator`.
   * @param transform Maps an element value to a value or a promise of one.
   */
  function AsyncMapIterator(upstream, transform) {
    var instance;
    _classCallCheck(this, AsyncMapIterator);
    instance = callSuper.call(this);
    instance.upstream = upstream;
    instance.transform = transform;
    return instance;
  }

  _createClass(AsyncMapIterator, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> AsyncMap");
    }
  }, {
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function asyncMapNext() {
        var item, before, mapped, after, inputIter, step, tensor;
        return _regeneratorRuntime().wrap(function asyncMapNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              ctx.next = 2;
              return this.upstream.next();
            case 2:
              item = ctx.sent;
              if (!item.done) {
                ctx.next = 5;
                break;
              }
              return ctx.abrupt("return", {
                value: null,
                done: true
              });
            case 5:
              // Record the input tensors before transforming: the transform
              // may mutate the item in place.
              before = getTensorsInContainer(item.value);
              ctx.next = 8;
              return this.transform(item.value);
            case 8:
              mapped = ctx.sent;
              after = getTensorsInContainer(mapped);
              // TODO(soergel) faster intersection
              // TODO(soergel) move to tf.disposeExcept(in, out)?
              inputIter = _createForOfIteratorHelper(before);
              try {
                for (inputIter.s(); !(step = inputIter.n()).done;) {
                  tensor = step.value;
                  // Keep tensors that were passed through to the output.
                  if (!isTensorInList(tensor, after)) {
                    tensor.dispose();
                  }
                }
              } catch (err) {
                inputIter.e(err);
              } finally {
                inputIter.f();
              }
              return ctx.abrupt("return", {
                value: mapped,
                done: false
              });
            case 13:
            case "end":
              return ctx.stop();
          }
        }, asyncMapNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }]);
  return AsyncMapIterator;
}(LazyIterator); // Iterators that maintain a queue of pending items
94498 // ============================================================================
94499 /**
94500 * A base class for transforming streams that operate by maintaining an
94501 * output queue of elements that are ready to return via next(). This is
94502 * commonly required when the transformation is 1-to-many: A call to next()
94503 * may trigger a call to the underlying stream, which will produce many
94504 * mapped elements of this stream-- of which we need to return only one, so
94505 * we have to queue the rest.
94506 */
var OneToManyIterator = /*#__PURE__*/function (Base) {
  _inherits(OneToManyIterator, Base);
  var callSuper = _createSuper(OneToManyIterator);

  /**
   * Base for iterators that serve elements from an output queue.
   *
   * Subclasses provide `pump()`. serialNext() calls it whenever the queue is
   * empty: a falsy result ends the stream, and a truthy one causes the queue
   * to be checked again. Presumably pump() enqueues new elements when it
   * returns truthy; verify in subclasses.
   */
  function OneToManyIterator() {
    var instance;
    _classCallCheck(this, OneToManyIterator);
    instance = callSuper.call(this);
    instance.outputQueue = new GrowingRingBuffer();
    instance.lastRead = Promise.resolve({
      value: null,
      done: false
    });
    return instance;
  }

  _createClass(OneToManyIterator, [{
    key: "next",
    value: function () {
      var nextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function chainedNext() {
        var iter = this;
        return _regeneratorRuntime().wrap(function chainedNext$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Replace the chain's tail synchronously. Awaiting the old tail
              // first and only then assigning this.lastRead would let a
              // concurrent next() chain onto the stale tail.
              this.lastRead = this.lastRead.then(function () {
                return iter.serialNext();
              });
              return ctx.abrupt("return", this.lastRead);
            case 2:
            case "end":
              return ctx.stop();
          }
        }, chainedNext, this);
      }));
      function next() {
        return nextAsync.apply(this, arguments);
      }
      return next;
    }()
  }, {
    key: "serialNext",
    value: function () {
      var serialNextAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function drainQueue() {
        return _regeneratorRuntime().wrap(function drainQueue$(ctx) {
          while (1) switch (ctx.prev = ctx.next) {
            case 0:
              // Queue non-empty: serve from it.
              if (!(this.outputQueue.length() === 0)) {
                ctx.next = 7;
                break;
              }
              ctx.next = 3;
              return this.pump();
            case 3:
              if (ctx.sent) {
                ctx.next = 5;
                break;
              }
              return ctx.abrupt("return", {
                value: null,
                done: true
              });
            case 5:
              ctx.next = 0;
              break;
            case 7:
              return ctx.abrupt("return", {
                value: this.outputQueue.shift(),
                done: false
              });
            case 8:
            case "end":
              return ctx.stop();
          }
        }, drainQueue, this);
      }));
      function serialNext() {
        return serialNextAsync.apply(this, arguments);
      }
      return serialNext;
    }()
  }]);
  return OneToManyIterator;
}(LazyIterator);
94592 var FlatmapIterator = /*#__PURE__*/function (_OneToManyIterator) {
94593 _inherits(FlatmapIterator, _OneToManyIterator);
94594 var _super12 = _createSuper(FlatmapIterator);
94595 function FlatmapIterator(upstream, transform) {
94596 var _this18;
94597 _classCallCheck(this, FlatmapIterator);
94598 _this18 = _super12.call(this);
94599 _this18.upstream = upstream;
94600 _this18.transform = transform;
94601 return _this18;
94602 }
94603 _createClass(FlatmapIterator, [{
94604 key: "summary",
94605 value: function summary() {
94606 return "".concat(this.upstream.summary(), " -> Flatmap");
94607 }
94608 }, {
94609 key: "pump",
94610 value: function () {
94611 var _pump = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee24() {
94612 var item, inputTensors, mappedArray, outputTensors, _iterator3, _step3, t;
94613 return _regeneratorRuntime().wrap(function _callee24$(_context24) {
94614 while (1) switch (_context24.prev = _context24.next) {
94615 case 0:
94616 _context24.next = 2;
94617 return this.upstream.next();
94618 case 2:
94619 item = _context24.sent;
94620 if (!item.done) {
94621 _context24.next = 5;
94622 break;
94623 }
94624 return _context24.abrupt("return", false);
94625 case 5:
94626 inputTensors = getTensorsInContainer(item.value); // Careful: the transform may mutate the item in place.
94627 // that's why we have to remember the input Tensors above, and then
94628 // below dispose only those that were not passed through to the output.
94629 // Note too that the transform function is responsible for tidying any
94630 // intermediate Tensors. Here we are concerned only about the inputs.
94631 mappedArray = this.transform(item.value);
94632 outputTensors = getTensorsInContainer(mappedArray);
94633 this.outputQueue.pushAll(mappedArray);
94634 // TODO(soergel) faster intersection, and deduplicate outputTensors
94635 // TODO(soergel) move to tf.disposeExcept(in, out)?
94636 _iterator3 = _createForOfIteratorHelper(inputTensors);
94637 try {
94638 for (_iterator3.s(); !(_step3 = _iterator3.n()).done;) {
94639 t = _step3.value;
94640 if (!isTensorInList(t, outputTensors)) {
94641 t.dispose();
94642 }
94643 }
94644 } catch (err) {
94645 _iterator3.e(err);
94646 } finally {
94647 _iterator3.f();
94648 }
94649 return _context24.abrupt("return", true);
94650 case 12:
94651 case "end":
94652 return _context24.stop();
94653 }
94654 }, _callee24, this);
94655 }));
94656 function pump() {
94657 return _pump.apply(this, arguments);
94658 }
94659 return pump;
94660 }()
94661 }]);
94662 return FlatmapIterator;
94663 }(OneToManyIterator);
94664 /**
94665 * Provides a `LazyIterator` that concatenates a stream of underlying
94666 * streams.
94667 *
94668 * Doing this in a concurrency-safe way requires some trickery. In
94669 * particular, we want this stream to return the elements from the
94670 * underlying streams in the correct order according to when next() was
94671 * called, even if the resulting Promises resolve in a different order.
94672 */
94673 var ChainedIterator = /*#__PURE__*/function (_LazyIterator12) {
94674 _inherits(ChainedIterator, _LazyIterator12);
94675 var _super13 = _createSuper(ChainedIterator);
94676 function ChainedIterator(iterators, baseErrorHandler) {
94677 var _this19;
94678 _classCallCheck(this, ChainedIterator);
94679 _this19 = _super13.call(this);
94680 _this19.baseErrorHandler = baseErrorHandler;
94681 // Strict Promise execution order:
94682 // a next() call may not even begin until the previous one completes.
94683 _this19.lastRead = null;
94684 // Local state that should not be clobbered by out-of-order execution.
94685 _this19.iterator = null;
94686 _this19.moreIterators = iterators;
94687 return _this19;
94688 }
94689 _createClass(ChainedIterator, [{
94690 key: "summary",
94691 value: function summary() {
94692 var upstreamSummaries = 'TODO: fill in upstream of chained summaries';
94693 return "".concat(upstreamSummaries, " -> Chained");
94694 }
94695 }, {
94696 key: "next",
94697 value: function () {
94698 var _next12 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee25() {
94699 return _regeneratorRuntime().wrap(function _callee25$(_context25) {
94700 while (1) switch (_context25.prev = _context25.next) {
94701 case 0:
94702 this.lastRead = this.readFromChain(this.lastRead);
94703 return _context25.abrupt("return", this.lastRead);
94704 case 2:
94705 case "end":
94706 return _context25.stop();
94707 }
94708 }, _callee25, this);
94709 }));
94710 function next() {
94711 return _next12.apply(this, arguments);
94712 }
94713 return next;
94714 }()
94715 }, {
94716 key: "readFromChain",
94717 value: function () {
94718 var _readFromChain = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee26(lastRead) {
94719 var iteratorResult, itemResult;
94720 return _regeneratorRuntime().wrap(function _callee26$(_context26) {
94721 while (1) switch (_context26.prev = _context26.next) {
94722 case 0:
94723 _context26.next = 2;
94724 return lastRead;
94725 case 2:
94726 if (!(this.iterator == null)) {
94727 _context26.next = 10;
94728 break;
94729 }
94730 _context26.next = 5;
94731 return this.moreIterators.next();
94732 case 5:
94733 iteratorResult = _context26.sent;
94734 if (!iteratorResult.done) {
94735 _context26.next = 8;
94736 break;
94737 }
94738 return _context26.abrupt("return", {
94739 value: null,
94740 done: true
94741 });
94742 case 8:
94743 this.iterator = iteratorResult.value;
94744 if (this.baseErrorHandler != null) {
94745 this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
94746 }
94747 case 10:
94748 _context26.next = 12;
94749 return this.iterator.next();
94750 case 12:
94751 itemResult = _context26.sent;
94752 if (!itemResult.done) {
94753 _context26.next = 16;
94754 break;
94755 }
94756 this.iterator = null;
94757 return _context26.abrupt("return", this.readFromChain(lastRead));
94758 case 16:
94759 return _context26.abrupt("return", itemResult);
94760 case 17:
94761 case "end":
94762 return _context26.stop();
94763 }
94764 }, _callee26, this);
94765 }));
94766 function readFromChain(_x4) {
94767 return _readFromChain.apply(this, arguments);
94768 }
94769 return readFromChain;
94770 }()
94771 }]);
94772 return ChainedIterator;
94773 }(LazyIterator);
94774 var ZipMismatchMode;
94775 (function (ZipMismatchMode) {
94776 ZipMismatchMode[ZipMismatchMode["FAIL"] = 0] = "FAIL";
94777 ZipMismatchMode[ZipMismatchMode["SHORTEST"] = 1] = "SHORTEST";
94778 ZipMismatchMode[ZipMismatchMode["LONGEST"] = 2] = "LONGEST"; // use nulls for exhausted streams; use up the longest stream.
94779 })(ZipMismatchMode || (ZipMismatchMode = {}));
94780 /**
94781 * Provides a `LazyIterator` that zips together an array, dict, or nested
94782 * structure of `LazyIterator`s (and perhaps additional constants).
94783 *
94784 * The underlying streams must provide elements in a consistent order such
94785 * that they correspond.
94786 *
94787 * Typically, the underlying streams should have the same number of
94788 * elements. If they do not, the behavior is determined by the
94789 * `mismatchMode` argument.
94790 *
94791 * The nested structure of the `iterators` argument determines the
94792 * structure of elements in the resulting iterator.
94793 *
94794 * Doing this in a concurrency-safe way requires some trickery. In
94795 * particular, we want this stream to return the elements from the
94796 * underlying streams in the correct order according to when next() was
94797 * called, even if the resulting Promises resolve in a different order.
94798 *
94799 * @param iterators: An array or object containing LazyIterators at the
94800 * leaves.
94801 * @param mismatchMode: Determines what to do when one underlying iterator
94802 * is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
94803 * causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
94804 * causes the zipped iterator to terminate with the furst underlying
94805 * streams, so elements remaining on the longer streams are ignored.
94806 * `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
94807 * in nulls for the exhausted streams, until all streams are exhausted.
94808 */
94809 var ZipIterator = /*#__PURE__*/function (_LazyIterator13) {
94810 _inherits(ZipIterator, _LazyIterator13);
94811 var _super14 = _createSuper(ZipIterator);
94812 function ZipIterator(iterators) {
94813 var _this20;
94814 var mismatchMode = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : ZipMismatchMode.FAIL;
94815 _classCallCheck(this, ZipIterator);
94816 _this20 = _super14.call(this);
94817 _this20.iterators = iterators;
94818 _this20.mismatchMode = mismatchMode;
94819 _this20.count = 0;
94820 _this20.currentPromise = null;
94821 return _this20;
94822 }
94823 _createClass(ZipIterator, [{
94824 key: "summary",
94825 value: function summary() {
94826 var upstreamSummaries = 'TODO: fill in upstream of zip summaries';
94827 return "{".concat(upstreamSummaries, "} -> Zip");
94828 }
94829 }, {
94830 key: "nextState",
94831 value: function () {
94832 var _nextState = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee27(afterState) {
94833 var numIterators, iteratorsDone, getNext, mapped;
94834 return _regeneratorRuntime().wrap(function _callee27$(_context27) {
94835 while (1) switch (_context27.prev = _context27.next) {
94836 case 0:
94837 getNext = function _getNext(container) {
94838 if (container instanceof LazyIterator) {
94839 var result = container.next();
94840 return {
94841 value: result.then(function (x) {
94842 numIterators++;
94843 if (x.done) {
94844 iteratorsDone++;
94845 }
94846 return x.value;
94847 }),
94848 recurse: false
94849 };
94850 } else {
94851 return {
94852 value: null,
94853 recurse: true
94854 };
94855 }
94856 };
94857 _context27.next = 3;
94858 return afterState;
94859 case 3:
94860 // Collect underlying iterator "done" signals as a side effect in
94861 // getNext()
94862 numIterators = 0;
94863 iteratorsDone = 0;
94864 _context27.next = 7;
94865 return deepMapAndAwaitAll(this.iterators, getNext);
94866 case 7:
94867 mapped = _context27.sent;
94868 if (!(numIterators === iteratorsDone)) {
94869 _context27.next = 10;
94870 break;
94871 }
94872 return _context27.abrupt("return", {
94873 value: null,
94874 done: true
94875 });
94876 case 10:
94877 if (!(iteratorsDone > 0)) {
94878 _context27.next = 16;
94879 break;
94880 }
94881 _context27.t0 = this.mismatchMode;
94882 _context27.next = _context27.t0 === ZipMismatchMode.FAIL ? 14 : _context27.t0 === ZipMismatchMode.SHORTEST ? 15 : _context27.t0 === ZipMismatchMode.LONGEST ? 16 : 16;
94883 break;
94884 case 14:
94885 throw new Error('Zipped streams should have the same length. ' + "Mismatched at element ".concat(this.count, "."));
94886 case 15:
94887 return _context27.abrupt("return", {
94888 value: null,
94889 done: true
94890 });
94891 case 16:
94892 this.count++;
94893 return _context27.abrupt("return", {
94894 value: mapped,
94895 done: false
94896 });
94897 case 18:
94898 case "end":
94899 return _context27.stop();
94900 }
94901 }, _callee27, this);
94902 }));
94903 function nextState(_x5) {
94904 return _nextState.apply(this, arguments);
94905 }
94906 return nextState;
94907 }()
94908 }, {
94909 key: "next",
94910 value: function () {
94911 var _next13 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee28() {
94912 return _regeneratorRuntime().wrap(function _callee28$(_context28) {
94913 while (1) switch (_context28.prev = _context28.next) {
94914 case 0:
94915 this.currentPromise = this.nextState(this.currentPromise);
94916 return _context28.abrupt("return", this.currentPromise);
94917 case 2:
94918 case "end":
94919 return _context28.stop();
94920 }
94921 }, _callee28, this);
94922 }));
94923 function next() {
94924 return _next13.apply(this, arguments);
94925 }
94926 return next;
94927 }()
94928 }]);
94929 return ZipIterator;
94930 }(LazyIterator); // Iterators that maintain a ring buffer of pending promises
94931 // ============================================================================
94932 /**
94933 * A stream that prefetches a given number of items from an upstream source,
94934 * returning them in FIFO order.
94935 *
94936 * Note this prefetches Promises, but makes no guarantees about when those
94937 * Promises resolve.
94938 */
94939 var PrefetchIterator = /*#__PURE__*/function (_LazyIterator14) {
94940 _inherits(PrefetchIterator, _LazyIterator14);
94941 var _super15 = _createSuper(PrefetchIterator);
94942 function PrefetchIterator(upstream, bufferSize) {
94943 var _this21;
94944 _classCallCheck(this, PrefetchIterator);
94945 _this21 = _super15.call(this);
94946 _this21.upstream = upstream;
94947 _this21.bufferSize = bufferSize;
94948 _this21.buffer = new RingBuffer(bufferSize);
94949 return _this21;
94950 }
94951 _createClass(PrefetchIterator, [{
94952 key: "summary",
94953 value: function summary() {
94954 return "".concat(this.upstream.summary(), " -> Prefetch");
94955 }
94956 /**
94957 * Refill the prefetch buffer. Returns only after the buffer is full, or
94958 * the upstream source is exhausted.
94959 */
94960 }, {
94961 key: "refill",
94962 value: function refill() {
94963 while (!this.buffer.isFull()) {
94964 var v = this.upstream.next();
94965 this.buffer.push(v);
94966 }
94967 }
94968 }, {
94969 key: "next",
94970 value: function next() {
94971 this.refill();
94972 // This shift will never throw an error because the buffer is always
94973 // full after a refill. If the stream is exhausted, the buffer will be
94974 // full of Promises that will resolve to the end-of-stream signal.
94975 return this.buffer.shift();
94976 }
94977 }]);
94978 return PrefetchIterator;
94979 }(LazyIterator);
94980 /**
94981 * A stream that performs a sliding-window random shuffle on an upstream
94982 * source. This is like a `PrefetchIterator` except that the items are
94983 * returned in randomized order. Mixing naturally improves as the buffer
94984 * size increases.
94985 */
94986 var ShuffleIterator = /*#__PURE__*/function (_PrefetchIterator) {
94987 _inherits(ShuffleIterator, _PrefetchIterator);
94988 var _super16 = _createSuper(ShuffleIterator);
94989 function ShuffleIterator(upstream, windowSize, seed) {
94990 var _this22;
94991 _classCallCheck(this, ShuffleIterator);
94992 _this22 = _super16.call(this, upstream, windowSize);
94993 _this22.upstream = upstream;
94994 _this22.windowSize = windowSize;
94995 // Local state that should not be clobbered by out-of-order execution.
94996 _this22.upstreamExhausted = false;
94997 _this22.random = seedrandom.alea(seed || now().toString());
94998 _this22.lastRead = Promise.resolve({
94999 value: null,
95000 done: false
95001 });
95002 return _this22;
95003 }
95004 _createClass(ShuffleIterator, [{
95005 key: "next",
95006 value: function () {
95007 var _next14 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee29() {
95008 var _this23 = this;
95009 return _regeneratorRuntime().wrap(function _callee29$(_context29) {
95010 while (1) switch (_context29.prev = _context29.next) {
95011 case 0:
95012 // This sets this.lastRead to a new Promise right away, as opposed to
95013 // saying `await this.lastRead; this.lastRead = this.serialNext();` which
95014 // would not work because this.nextRead would be updated only after the
95015 // promise resolves.
95016 this.lastRead = this.lastRead.then(function () {
95017 return _this23.serialNext();
95018 });
95019 return _context29.abrupt("return", this.lastRead);
95020 case 2:
95021 case "end":
95022 return _context29.stop();
95023 }
95024 }, _callee29, this);
95025 }));
95026 function next() {
95027 return _next14.apply(this, arguments);
95028 }
95029 return next;
95030 }()
95031 }, {
95032 key: "randomInt",
95033 value: function randomInt(max) {
95034 return Math.floor(this.random() * max);
95035 }
95036 }, {
95037 key: "chooseIndex",
95038 value: function chooseIndex() {
95039 return this.randomInt(this.buffer.length());
95040 }
95041 }, {
95042 key: "serialNext",
95043 value: function () {
95044 var _serialNext7 = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee30() {
95045 var chosenIndex, result;
95046 return _regeneratorRuntime().wrap(function _callee30$(_context30) {
95047 while (1) switch (_context30.prev = _context30.next) {
95048 case 0:
95049 // TODO(soergel): consider performance
95050 if (!this.upstreamExhausted) {
95051 this.refill();
95052 }
95053 case 1:
95054 if (this.buffer.isEmpty()) {
95055 _context30.next = 14;
95056 break;
95057 }
95058 chosenIndex = this.chooseIndex();
95059 _context30.next = 5;
95060 return this.buffer.shuffleExcise(chosenIndex);
95061 case 5:
95062 result = _context30.sent;
95063 if (!result.done) {
95064 _context30.next = 10;
95065 break;
95066 }
95067 this.upstreamExhausted = true;
95068 _context30.next = 12;
95069 break;
95070 case 10:
95071 this.refill();
95072 return _context30.abrupt("return", result);
95073 case 12:
95074 _context30.next = 1;
95075 break;
95076 case 14:
95077 return _context30.abrupt("return", {
95078 value: null,
95079 done: true
95080 });
95081 case 15:
95082 case "end":
95083 return _context30.stop();
95084 }
95085 }, _callee30, this);
95086 }));
95087 function serialNext() {
95088 return _serialNext7.apply(this, arguments);
95089 }
95090 return serialNext;
95091 }()
95092 }]);
95093 return ShuffleIterator;
95094 }(PrefetchIterator);
95095
95096 // TODO(soergel): consider vectorized operations within the pipeline.
95097 /**
95098 * Represents a potentially large list of independent data elements (typically
95099 * 'samples' or 'examples').
95100 *
95101 * A 'data example' may be a primitive, an array, a map from string keys to
95102 * values, or any nested structure of these.
95103 *
95104 * A `Dataset` represents an ordered collection of elements, together with a
95105 * chain of transformations to be performed on those elements. Each
95106 * transformation is a method of `Dataset` that returns another `Dataset`, so
95107 * these may be chained, e.g.
95108 * `const processedDataset = rawDataset.filter(...).map(...).batch(...)`.
95109 *
95110 * Data loading and transformation is done in a lazy, streaming fashion. The
95111 * dataset may be iterated over multiple times; each iteration starts the data
95112 * loading anew and recapitulates the transformations.
95113 *
95114 * A `Dataset` is typically processed as a stream of unbatched examples -- i.e.,
95115 * its transformations are applied one example at a time. Batching produces a
95116 * new `Dataset` where each element is a batch. Batching should usually come
95117 * last in a pipeline, because data transformations are easier to express on a
95118 * per-example basis than on a per-batch basis.
95119 *
95120 * The following code examples are calling `await dataset.forEachAsync(...)` to
95121 * iterate once over the entire dataset in order to print out the data.
95122 *
95123 * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
95124 */
95125 var Dataset = /*#__PURE__*/function () {
95126 function Dataset() {
95127 _classCallCheck(this, Dataset);
95128 this.size = null;
95129 }
95130 // TODO(soergel): Make Datasets report whether repeated iterator() calls
95131 // produce the same result (e.g., reading from a file) or different results
95132 // (e.g., from the webcam). Currently we don't make this distinction but it
95133 // could be important for the user to know.
95134 // abstract isDeterministic(): boolean;
95135 /**
95136 * Groups elements into batches.
95137 *
95138 * It is assumed that each of the incoming dataset elements has the same
95139 * structure -- i.e. the same set of keys at each location in an object
95140 * hierarchy. For each key, the resulting `Dataset` provides a batched
95141 * element collecting all of the incoming values for that key.
95142 *
95143 * * Incoming primitives are grouped into a 1-D Tensor.
95144 * * Incoming Tensors are grouped into a new Tensor where the 0th axis is
95145 * the batch dimension.
95146 * * Incoming arrays are converted to Tensor and then batched.
95147 * * A nested array is interpreted as an n-D Tensor, so the batched result
95148 * has n+1 dimensions.
95149 * * An array that cannot be converted to Tensor produces an error.
95150 *
95151 * If an array should not be batched as a unit, it should first be converted
95152 * to an object with integer keys.
95153 *
95154 * Here are a few examples:
95155 *
95156 * Batch a dataset of numbers:
95157 * ```js
95158 * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
95159 * await a.forEachAsync(e => e.print());
95160 * ```
95161 *
95162 * Batch a dataset of arrays:
95163 * ```js
95164 * const b = tf.data.array([[1], [2], [3], [4], [5], [6], [7], [8]]).batch(4);
95165 * await b.forEachAsync(e => e.print());
95166 * ```
95167 *
95168 * Batch a dataset of objects:
95169 * ```js
95170 * const c = tf.data.array([{a: 1, b: 11}, {a: 2, b: 12}, {a: 3, b: 13},
95171 * {a: 4, b: 14}, {a: 5, b: 15}, {a: 6, b: 16}, {a: 7, b: 17},
95172 * {a: 8, b: 18}]).batch(4);
95173 * await c.forEachAsync(e => {
95174 * console.log('{');
95175 * for(var key in e) {
95176 * console.log(key+':');
95177 * e[key].print();
95178 * }
95179 * console.log('}');
95180 * })
95181 * ```
95182 *
95183 * @param batchSize The number of elements desired per batch.
95184 * @param smallLastBatch Whether to emit the final batch when it has fewer
95185 * than batchSize elements. Default true.
95186 * @returns A `Dataset`, from which a stream of batches can be obtained.
95187 *
95188 * @doc {heading: 'Data', subheading: 'Classes'}
95189 */
95190 _createClass(Dataset, [{
95191 key: "batch",
95192 value: function batch(batchSize) {
95193 var smallLastBatch = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
95194 var base = this;
95195 assert$1(batchSize > 0, function () {
95196 return "batchSize needs to be positive, but it is\n ".concat(batchSize);
95197 });
95198 var size;
95199 if (this.size === Infinity || this.size == null) {
95200 // If the size of this dataset is infinity or null, the new size keeps the
95201 // same.
95202 size = this.size;
95203 } else if (smallLastBatch) {
95204 // If the size of this dataset is known and include small last batch, the
95205 // new size is full batch count plus last batch.
95206 size = Math.ceil(this.size / batchSize);
95207 } else {
95208 // If the size of this dataset is known and not include small last batch,
95209 // the new size is full batch count.
95210 size = Math.floor(this.size / batchSize);
95211 }
95212 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
95213 return _regeneratorRuntime().wrap(function _callee$(_context) {
95214 while (1) switch (_context.prev = _context.next) {
95215 case 0:
95216 _context.next = 2;
95217 return base.iterator();
95218 case 2:
95219 return _context.abrupt("return", _context.sent.columnMajorBatch(batchSize, smallLastBatch, deepBatchConcat));
95220 case 3:
95221 case "end":
95222 return _context.stop();
95223 }
95224 }, _callee);
95225 })), size);
95226 }
95227 /**
95228 * Concatenates this `Dataset` with another.
95229 *
95230 * ```js
95231 * const a = tf.data.array([1, 2, 3]);
95232 * const b = tf.data.array([4, 5, 6]);
95233 * const c = a.concatenate(b);
95234 * await c.forEachAsync(e => console.log(e));
95235 * ```
95236 *
95237 * @param dataset A `Dataset` to be concatenated onto this one.
95238 * @returns A `Dataset`.
95239 *
95240 * @doc {heading: 'Data', subheading: 'Classes'}
95241 */
95242 }, {
95243 key: "concatenate",
95244 value: function concatenate(dataset) {
95245 var base = this;
95246 var size;
95247 if (this.size === Infinity || dataset.size === Infinity) {
95248 // If the size of any of these two dataset is infinity, new size is
95249 // infinity.
95250 size = Infinity;
95251 } else if (this.size != null && dataset.size != null) {
95252 // If the size of both datasets are known and not infinity, new size is
95253 // sum the size of these two datasets.
95254 size = this.size + dataset.size;
95255 } else {
95256 // If neither of these two datasets has infinite size and any of these two
95257 // datasets' size is null, the new size is null.
95258 size = null;
95259 }
95260 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
95261 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
95262 while (1) switch (_context2.prev = _context2.next) {
95263 case 0:
95264 _context2.next = 2;
95265 return base.iterator();
95266 case 2:
95267 _context2.t0 = _context2.sent;
95268 _context2.next = 5;
95269 return dataset.iterator();
95270 case 5:
95271 _context2.t1 = _context2.sent;
95272 return _context2.abrupt("return", _context2.t0.concatenate.call(_context2.t0, _context2.t1));
95273 case 7:
95274 case "end":
95275 return _context2.stop();
95276 }
95277 }, _callee2);
95278 })), size);
95279 }
95280 /**
95281 * Filters this dataset according to `predicate`.
95282 *
95283 * ```js
95284 * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
95285 * .filter(x => x%2 === 0);
95286 * await a.forEachAsync(e => console.log(e));
95287 * ```
95288 *
95289 * @param predicate A function mapping a dataset element to a boolean or a
95290 * `Promise` for one.
95291 *
95292 * @returns A `Dataset` of elements for which the predicate was true.
95293 *
95294 * @doc {heading: 'Data', subheading: 'Classes'}
95295 */
95296 }, {
95297 key: "filter",
95298 value: function filter(predicate) {
95299 var base = this;
95300 var size;
95301 if (this.size === Infinity) {
95302 // If the size of this dataset is infinity, new size is infinity
95303 size = Infinity;
95304 } else {
95305 // If this dataset has limited elements, new size is null because it might
95306 // exhausted randomly.
95307 size = null;
95308 }
95309 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
95310 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
95311 while (1) switch (_context3.prev = _context3.next) {
95312 case 0:
95313 _context3.next = 2;
95314 return base.iterator();
95315 case 2:
95316 return _context3.abrupt("return", _context3.sent.filter(function (x) {
95317 return tidy(function () {
95318 return predicate(x);
95319 });
95320 }));
95321 case 3:
95322 case "end":
95323 return _context3.stop();
95324 }
95325 }, _callee3);
95326 })), size);
95327 }
95328 /**
95329 * Apply a function to every element of the dataset.
95330 *
95331 * After the function is applied to a dataset element, any Tensors contained
95332 * within that element are disposed.
95333 *
95334 * ```js
95335 * const a = tf.data.array([1, 2, 3]);
95336 * await a.forEachAsync(e => console.log(e));
95337 * ```
95338 *
95339 * @param f A function to apply to each dataset element.
95340 * @returns A `Promise` that resolves after all elements have been processed.
95341 *
95342 * @doc {heading: 'Data', subheading: 'Classes'}
95343 */
95344 }, {
95345 key: "forEachAsync",
95346 value: function () {
95347 var _forEachAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(f) {
95348 return _regeneratorRuntime().wrap(function _callee4$(_context4) {
95349 while (1) switch (_context4.prev = _context4.next) {
95350 case 0:
95351 _context4.next = 2;
95352 return this.iterator();
95353 case 2:
95354 return _context4.abrupt("return", _context4.sent.forEachAsync(f));
95355 case 3:
95356 case "end":
95357 return _context4.stop();
95358 }
95359 }, _callee4, this);
95360 }));
95361 function forEachAsync(_x) {
95362 return _forEachAsync.apply(this, arguments);
95363 }
95364 return forEachAsync;
95365 }()
95366 /**
95367 * Maps this dataset through a 1-to-1 transform.
95368 *
95369 * ```js
95370 * const a = tf.data.array([1, 2, 3]).map(x => x*x);
95371 * await a.forEachAsync(e => console.log(e));
95372 * ```
95373 *
95374 * @param transform A function mapping a dataset element to a transformed
95375 * dataset element.
95376 *
95377 * @returns A `Dataset` of transformed elements.
95378 *
95379 * @doc {heading: 'Data', subheading: 'Classes'}
95380 */
95381 }, {
95382 key: "map",
95383 value: function map(transform) {
95384 var base = this;
95385 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5() {
95386 return _regeneratorRuntime().wrap(function _callee5$(_context5) {
95387 while (1) switch (_context5.prev = _context5.next) {
95388 case 0:
95389 _context5.next = 2;
95390 return base.iterator();
95391 case 2:
95392 return _context5.abrupt("return", _context5.sent.map(function (x) {
95393 return tidy(function () {
95394 return transform(x);
95395 });
95396 }));
95397 case 3:
95398 case "end":
95399 return _context5.stop();
95400 }
95401 }, _callee5);
95402 })), this.size);
95403 }
95404 /**
95405 * Maps this dataset through an async 1-to-1 transform.
95406 *
95407 * ```js
95408 * const a =
95409 * tf.data.array([1, 2, 3]).mapAsync(x => new Promise(function(resolve){
95410 * setTimeout(() => {
95411 * resolve(x * x);
95412 * }, Math.random()*1000 + 500);
95413 * }));
95414 * console.log(await a.toArray());
95415 * ```
95416 *
95417 * @param transform A function mapping a dataset element to a `Promise` for a
95418 * transformed dataset element. This transform is responsible for disposing
95419 * any intermediate `Tensor`s, i.e. by wrapping its computation in
95420 * `tf.tidy()`; that cannot be automated here (as it is in the synchronous
95421 * `map()` case).
95422 *
95423 * @returns A `Dataset` of transformed elements.
95424 *
95425 * @doc {heading: 'Data', subheading: 'Classes'}
95426 */
95427 }, {
95428 key: "mapAsync",
95429 value: function mapAsync(transform) {
95430 var base = this;
95431 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee6() {
95432 return _regeneratorRuntime().wrap(function _callee6$(_context6) {
95433 while (1) switch (_context6.prev = _context6.next) {
95434 case 0:
95435 _context6.next = 2;
95436 return base.iterator();
95437 case 2:
95438 return _context6.abrupt("return", _context6.sent.mapAsync(transform));
95439 case 3:
95440 case "end":
95441 return _context6.stop();
95442 }
95443 }, _callee6);
95444 })), this.size);
95445 }
95446 /**
95447 * Creates a `Dataset` that prefetches elements from this dataset.
95448 *
95449 * @param bufferSize: An integer specifying the number of elements to be
95450 * prefetched.
95451 * @returns A `Dataset`.
95452 *
95453 * @doc {heading: 'Data', subheading: 'Classes'}
95454 */
95455 }, {
95456 key: "prefetch",
95457 value: function prefetch(bufferSize) {
95458 if (bufferSize == null) {
95459 throw new RangeError('`Dataset.prefetch()` requires bufferSize to be specified.');
95460 }
95461 var base = this;
95462 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee7() {
95463 return _regeneratorRuntime().wrap(function _callee7$(_context7) {
95464 while (1) switch (_context7.prev = _context7.next) {
95465 case 0:
95466 _context7.next = 2;
95467 return base.iterator();
95468 case 2:
95469 return _context7.abrupt("return", _context7.sent.prefetch(bufferSize));
95470 case 3:
95471 case "end":
95472 return _context7.stop();
95473 }
95474 }, _callee7);
95475 })), this.size);
95476 }
95477 /**
95478 * Repeats this dataset `count` times.
95479 *
95480 * NOTE: If this dataset is a function of global state (e.g. a random number
95481 * generator), then different repetitions may produce different elements.
95482 *
95483 * ```js
95484 * const a = tf.data.array([1, 2, 3]).repeat(3);
95485 * await a.forEachAsync(e => console.log(e));
95486 * ```
95487 *
95488 * @param count: (Optional) An integer, representing the number of times
95489 * the dataset should be repeated. The default behavior (if `count` is
95490 * `undefined` or negative) is for the dataset be repeated indefinitely.
95491 * @returns A `Dataset`.
95492 *
95493 * @doc {heading: 'Data', subheading: 'Classes'}
95494 */
95495 }, {
95496 key: "repeat",
95497 value: function repeat(count) {
95498 var base = this;
95499 var size;
95500 if (this.size != null && count > 0) {
95501 // If this dataset has size and count is positive, new size is current
95502 // size multiply count. This also covers the case that current size is
95503 // infinity.
95504 size = this.size * count;
95505 } else if (count === 0) {
95506 // If count is 0, new size is 0.
95507 size = 0;
95508 } else if (this.size != null && (count === undefined || count < 0)) {
95509 // If this dataset has size and count is undefined or negative, the
95510 // dataset will be repeated indefinitely and new size is infinity.
95511 size = Infinity;
95512 } else {
95513 // If the size of this dataset is null, the new dataset's size is null.
95514 size = null;
95515 }
95516 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee9() {
95517 var iteratorIterator;
95518 return _regeneratorRuntime().wrap(function _callee9$(_context9) {
95519 while (1) switch (_context9.prev = _context9.next) {
95520 case 0:
95521 iteratorIterator = iteratorFromFunction( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee8() {
95522 return _regeneratorRuntime().wrap(function _callee8$(_context8) {
95523 while (1) switch (_context8.prev = _context8.next) {
95524 case 0:
95525 _context8.next = 2;
95526 return base.iterator();
95527 case 2:
95528 _context8.t0 = _context8.sent;
95529 return _context8.abrupt("return", {
95530 value: _context8.t0,
95531 done: false
95532 });
95533 case 4:
95534 case "end":
95535 return _context8.stop();
95536 }
95537 }, _callee8);
95538 })));
95539 return _context9.abrupt("return", iteratorFromConcatenated(iteratorIterator.take(count)));
95540 case 2:
95541 case "end":
95542 return _context9.stop();
95543 }
95544 }, _callee9);
95545 })), size);
95546 }
95547 /**
95548 * Creates a `Dataset` that skips `count` initial elements from this dataset.
95549 *
95550 * ```js
95551 * const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
95552 * await a.forEachAsync(e => console.log(e));
95553 * ```
95554 *
95555 * @param count: The number of elements of this dataset that should be skipped
95556 * to form the new dataset. If `count` is greater than the size of this
95557 * dataset, the new dataset will contain no elements. If `count`
95558 * is `undefined` or negative, skips the entire dataset.
95559 *
95560 * @returns A `Dataset`.
95561 *
95562 * @doc {heading: 'Data', subheading: 'Classes'}
95563 */
95564 }, {
95565 key: "skip",
95566 value: function skip(count) {
95567 var base = this;
95568 var size;
95569 if (this.size != null && count >= 0 && this.size >= count) {
95570 // If the size of this dataset is greater than count, the new dataset's
95571 // size is current size minus skipped size.This also covers the case that
95572 // current size is infinity.
95573 size = this.size - count;
95574 } else if (this.size != null && (this.size < count || count === undefined || count < 0)) {
95575 // If the size of this dataset is smaller than count, or count is
95576 // undefined or negative, skips the entire dataset and the new size is 0.
95577 size = 0;
95578 } else {
95579 // If the size of this dataset is null, the new dataset's size is null.
95580 size = null;
95581 }
95582 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee10() {
95583 return _regeneratorRuntime().wrap(function _callee10$(_context10) {
95584 while (1) switch (_context10.prev = _context10.next) {
95585 case 0:
95586 _context10.next = 2;
95587 return base.iterator();
95588 case 2:
95589 return _context10.abrupt("return", _context10.sent.skip(count));
95590 case 3:
95591 case "end":
95592 return _context10.stop();
95593 }
95594 }, _callee10);
95595 })), size);
95596 }
95597 /**
95598 * Pseudorandomly shuffles the elements of this dataset. This is done in a
95599 * streaming manner, by sampling from a given number of prefetched elements.
95600 *
95601 * ```js
95602 * const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
95603 * await a.forEachAsync(e => console.log(e));
95604 * ```
95605 *
95606 * @param bufferSize: An integer specifying the number of elements from this
95607 * dataset from which the new dataset will sample.
95608 * @param seed: (Optional) An integer specifying the random seed that will
95609 * be used to create the distribution.
95610 * @param reshuffleEachIteration: (Optional) A boolean, which if true
95611 * indicates that the dataset should be pseudorandomly reshuffled each time
95612 * it is iterated over. If false, elements will be returned in the same
95613 * shuffled order on each iteration. (Defaults to `true`.)
95614 * @returns A `Dataset`.
95615 *
95616 * @doc {heading: 'Data', subheading: 'Classes'}
95617 */
95618 }, {
95619 key: "shuffle",
95620 value: function shuffle(bufferSize, seed) {
95621 var reshuffleEachIteration = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
95622 if (bufferSize == null || bufferSize < 0) {
95623 if (this.size == null) {
95624 throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified.');
95625 } else {
95626 throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified. ' + 'If your data fits in main memory (for regular JS objects), ' + 'and/or GPU memory (for `tf.Tensor`s), consider setting ' + "bufferSize to the dataset size (".concat(this.size, " elements)"));
95627 }
95628 }
95629 var base = this;
95630 var random = seedrandom.alea(seed || now().toString());
95631 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee11() {
95632 var seed2;
95633 return _regeneratorRuntime().wrap(function _callee11$(_context11) {
95634 while (1) switch (_context11.prev = _context11.next) {
95635 case 0:
95636 seed2 = random.int32();
95637 if (reshuffleEachIteration) {
95638 seed2 += random.int32();
95639 }
95640 _context11.next = 4;
95641 return base.iterator();
95642 case 4:
95643 return _context11.abrupt("return", _context11.sent.shuffle(bufferSize, seed2.toString()));
95644 case 5:
95645 case "end":
95646 return _context11.stop();
95647 }
95648 }, _callee11);
95649 })), this.size);
95650 }
95651 /**
95652 * Creates a `Dataset` with at most `count` initial elements from this
95653 * dataset.
95654 *
95655 * ```js
95656 * const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
95657 * await a.forEachAsync(e => console.log(e));
95658 * ```
95659 *
95660 * @param count: The number of elements of this dataset that should be taken
95661 * to form the new dataset. If `count` is `undefined` or negative, or if
95662 * `count` is greater than the size of this dataset, the new dataset will
95663 * contain all elements of this dataset.
95664 * @returns A `Dataset`.
95665 *
95666 * @doc {heading: 'Data', subheading: 'Classes'}
95667 */
95668 }, {
95669 key: "take",
95670 value: function take(count) {
95671 var base = this;
95672 var size;
95673 if (this.size != null && this.size > count) {
95674 // If the size of this dataset is greater than count, the new dataset's
95675 // size is count.
95676 size = count;
95677 } else if (this.size != null && this.size <= count) {
95678 // If the size of this dataset is equal or smaller than count, the new
95679 // dataset's size is the size of this dataset.
95680 size = this.size;
95681 } else {
95682 // If the size of this dataset is null, the new dataset's size is null.
95683 size = null;
95684 }
95685 return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee12() {
95686 return _regeneratorRuntime().wrap(function _callee12$(_context12) {
95687 while (1) switch (_context12.prev = _context12.next) {
95688 case 0:
95689 _context12.next = 2;
95690 return base.iterator();
95691 case 2:
95692 return _context12.abrupt("return", _context12.sent.take(count));
95693 case 3:
95694 case "end":
95695 return _context12.stop();
95696 }
95697 }, _callee12);
95698 })), size);
95699 }
95700 /**
95701 * Collect all elements of this dataset into an array.
95702 *
95703 * Obviously this will succeed only for small datasets that fit in memory.
95704 * Useful for testing and generally should be avoided if possible.
95705 *
95706 * ```js
95707 * const a = tf.data.array([1, 2, 3, 4, 5, 6]);
95708 * console.log(await a.toArray());
95709 * ```
95710 *
95711 * @returns A Promise for an array of elements, which will resolve
95712 * when a new stream has been obtained and fully consumed.
95713 *
95714 * @doc {heading: 'Data', subheading: 'Classes'}
95715 */
95716 }, {
95717 key: "toArray",
95718 value: function () {
95719 var _toArray = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee13() {
95720 return _regeneratorRuntime().wrap(function _callee13$(_context13) {
95721 while (1) switch (_context13.prev = _context13.next) {
95722 case 0:
95723 if (!(this.size === Infinity)) {
95724 _context13.next = 2;
95725 break;
95726 }
95727 throw new Error('Can not convert infinite data stream to array.');
95728 case 2:
95729 _context13.next = 4;
95730 return this.iterator();
95731 case 4:
95732 return _context13.abrupt("return", _context13.sent.toArray());
95733 case 5:
95734 case "end":
95735 return _context13.stop();
95736 }
95737 }, _callee13, this);
95738 }));
95739 function toArray() {
95740 return _toArray.apply(this, arguments);
95741 }
95742 return toArray;
95743 }()
95744 /**
95745 * Collect all elements of this dataset into an array with prefetching 100
95746 * elements. This is useful for testing, because the prefetch changes the
95747 * order in which the Promises are resolved along the processing pipeline.
95748 * This may help expose bugs where results are dependent on the order of
95749 * Promise resolution rather than on the logical order of the stream (i.e.,
95750 * due to hidden mutable state).
95751 *
95752 * @returns A Promise for an array of elements, which will resolve
95753 * when a new stream has been obtained and fully consumed.
95754 */
95755 }, {
95756 key: "toArrayForTest",
95757 value: function () {
95758 var _toArrayForTest = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee14() {
95759 return _regeneratorRuntime().wrap(function _callee14$(_context14) {
95760 while (1) switch (_context14.prev = _context14.next) {
95761 case 0:
95762 if (!(this.size === Infinity)) {
95763 _context14.next = 2;
95764 break;
95765 }
95766 throw new Error('Can not convert infinite data stream to array.');
95767 case 2:
95768 _context14.next = 4;
95769 return this.iterator();
95770 case 4:
95771 return _context14.abrupt("return", _context14.sent.toArrayForTest());
95772 case 5:
95773 case "end":
95774 return _context14.stop();
95775 }
95776 }, _callee14, this);
95777 }));
95778 function toArrayForTest() {
95779 return _toArrayForTest.apply(this, arguments);
95780 }
95781 return toArrayForTest;
95782 }()
95783 }]);
95784 return Dataset;
95785 }(); // TODO(soergel): deep sharded shuffle, where supported
// Static constant on `Dataset` (10,000). Nothing in this section reads it.
// NOTE(review): presumably an upper bound for element buffers; confirm at its use sites.
Dataset.MAX_BUFFER_SIZE = 1e4;
/**
 * Create a `Dataset` defined by a provided iterator() function.
 *
 * `iteratorFn` is invoked every time the dataset's `iterator()` is called, so
 * it should build and return a fresh stream (or a Promise for one) on each
 * call. Passing an already-built iterator object instead of a function does
 * not work.
 *
 * ```js
 * const ds = tf.data.datasetFromIteratorFn(async () => {
 *   let i = -1;
 *   const func = () =>
 *       ++i < 5 ? {value: i, done: false} : {value: null, done: true};
 *   return tf.data.iteratorFromFunction(func);
 * }, 5);
 * await ds.forEachAsync(e => console.log(e));
 * ```
 *
 * @param iteratorFn A zero-argument function returning a stream of elements,
 *     or a Promise resolving to one.
 * @param size (Optional) The number of elements, if known. Defaults to `null`
 *     (unknown); `undefined` is also treated as `null`.
 * @returns A `Dataset` whose `size` is `size`.
 */
function datasetFromIteratorFn(iteratorFn) {
  var size = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : null;
  return new ( /*#__PURE__*/function (_Dataset) {
    _inherits(_class, _Dataset);
    var _super = _createSuper(_class);
    function _class() {
      var _this;
      _classCallCheck(this, _class);
      _this = _super.apply(this, arguments);
      // Record the (possibly unknown) element count on the instance.
      _this.size = size;
      return _this;
    }
    /*
     * Provide a new stream of elements. Note this will also start new streams
     * from any underlying `Dataset`s.
     */
    _createClass(_class, [{
      key: "iterator",
      value: function () {
        var _iterator = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee15() {
          return _regeneratorRuntime().wrap(function _callee15$(_context15) {
            while (1) switch (_context15.prev = _context15.next) {
              case 0:
                return _context15.abrupt("return", iteratorFn());
              case 1:
              case "end":
                return _context15.stop();
            }
          }, _callee15);
        }));
        function iterator() {
          return _iterator.apply(this, arguments);
        }
        return iterator;
      }()
    }]);
    return _class;
  }(Dataset))();
}
/**
 * Create a `Dataset` from an array of elements.
 *
 * Create a Dataset from an array of objects:
 * ```js
 * const a = tf.data.array([{'item': 1}, {'item': 2}, {'item': 3}]);
 * await a.forEachAsync(e => console.log(e));
 * ```
 *
 * Create a Dataset from an array of numbers:
 * ```js
 * const a = tf.data.array([4, 5, 6]);
 * await a.forEachAsync(e => console.log(e));
 * ```
 * @param items An array of elements that will be parsed as items in a dataset.
 *     The dataset's `size` is `items.length`.
 *
 * @doc {heading: 'Data', subheading: 'Creation', namespace: 'data'}
 */
function array(items) {
  // Each iteration walks `items` from the beginning with a fresh iterator.
  var makeItemIterator = function () {
    return new Promise(function (resolve) {
      resolve(iteratorFromItems(items));
    });
  };
  return datasetFromIteratorFn(makeItemIterator, items.length);
}
/**
 * Create a `Dataset` by zipping together an array, dict, or nested
 * structure of `Dataset`s (and perhaps additional constants).
 * The underlying datasets must provide elements in a consistent order such that
 * they correspond.
 *
 * The number of elements in the resulting dataset is the same as the size of
 * the smallest dataset in datasets. If the size of any top-level component is
 * unknown (or a component is itself a nested structure), the resulting
 * dataset's `size` is `null` (unknown).
 *
 * The nested structure of the `datasets` argument determines the
 * structure of elements in the resulting iterator.
 *
 * Note this means that, given an array of two datasets that produce dict
 * elements, the result is a dataset that produces elements that are arrays
 * of two dicts:
 *
 * Zip an array of datasets:
 * ```js
 * console.log('Zip two datasets of objects:');
 * const ds1 = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
 * const ds2 = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
 * const ds3 = tf.data.zip([ds1, ds2]);
 * await ds3.forEachAsync(e => console.log(JSON.stringify(e)));
 *
 * // If the goal is to merge the dicts in order to produce elements like
 * // {a: ..., b: ...}, this requires a second step such as:
 * console.log('Merge the objects:');
 * const ds4 = ds3.map(x => {return {a: x[0].a, b: x[1].b}});
 * await ds4.forEachAsync(e => console.log(e));
 * ```
 *
 * Zip a dict of datasets:
 * ```js
 * const a = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
 * const b = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
 * const c = tf.data.zip({c: a, d: b});
 * await c.forEachAsync(e => console.log(JSON.stringify(e)));
 * ```
 *
 * @doc {heading: 'Data', subheading: 'Operations', namespace: 'data'}
 */
function zip(datasets) {
  // manually type-check the argument for JS users
  if (!isIterable(datasets)) {
    throw new Error('The argument to zip() must be an object or array.');
  }
  // Zipping stops at the shortest component (SHORTEST mode below), so the size
  // is the smallest component size. As soon as one component's size is
  // unknown, the result is unknown too, because that component might be the
  // shortest. This also avoids Math.min coercing null to 0 (or undefined to
  // NaN). Nested structures are not inspected and count as unknown.
  var components = Array.isArray(datasets) ? datasets : Object.keys(datasets).map(function (key) {
    return datasets[key];
  });
  var size = null;
  for (var i = 0; i < components.length; i++) {
    var component = components[i];
    var componentSize = component instanceof Dataset ? component.size : null;
    if (componentSize == null) {
      size = null;
      break;
    }
    size = size == null ? componentSize : Math.min(size, componentSize);
  }
  return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee17() {
    var streams;
    return _regeneratorRuntime().wrap(function _callee17$(_context17) {
      while (1) switch (_context17.prev = _context17.next) {
        case 0:
          _context17.next = 2;
          return deepMapAndAwaitAll(datasets, function (d) {
            if (d instanceof Dataset) {
              return {
                value: d.iterator(),
                recurse: false
              };
            } else if (isIterable(d)) {
              return {
                value: null,
                recurse: true
              };
            } else {
              throw new Error('Leaves of the structure passed to zip() must be Datasets, ' + 'not primitives.');
            }
          });
        case 2:
          streams = _context17.sent;
          return _context17.abrupt("return", iteratorFromZipped(streams, ZipMismatchMode.SHORTEST));
        case 4:
        case "end":
          return _context17.stop();
      }
    }, _callee17);
  })), size);
}
/**
 * A zip function for use with deepZip, passed via the columnMajorBatch call.
 *
 * Given an array of identically-structured rows, either batches them into one
 * value (when the first row is something `canTensorify` accepts) or asks the
 * caller to recurse into their structure (otherwise). A `null` input yields
 * `null`.
 */
// tslint:disable-next-line:no-any
function deepBatchConcat(rows) {
  if (rows === null) {
    return null;
  }
  // The first row stands in for all rows, which share one structure.
  if (!canTensorify(rows[0])) {
    // Not batchable at this level (e.g. an object): recurse into it instead.
    return {
      value: null,
      recurse: true
    };
  }
  // Primitives, Tensors, or arrays: batch the whole set of rows.
  return {
    value: batchConcat(rows),
    recurse: false
  };
}
/**
 * Assembles a list of same-shaped numbers, number arrays, or Tensors
 * into a single new Tensor where axis 0 is the batch dimension.
 *
 * Throws for an empty list, because the element shape cannot be inferred.
 * The first element decides the path: Tensors are stacked, anything else is
 * handed to `tensor()` as a (possibly nested) array.
 */
function batchConcat(arrays) {
  if (arrays.length === 0) {
    throw new Error('Can\'t make a batch of zero elements.');
  }
  var first = arrays[0];
  return first instanceof Tensor ? stack(arrays) : tensor(arrays);
}
96001
/**
 * Represents a potentially large collection of text lines.
 *
 * The input byte stream is decoded as UTF-8 and split on '\n'. A single
 * trailing '\r' is removed from each line, so files with Windows (CRLF) line
 * endings yield the same lines as files with Unix (LF) endings.
 *
 * The results are not batched.
 */
var TextLineDataset = /*#__PURE__*/function (_Dataset) {
  _inherits(TextLineDataset, _Dataset);
  var _super = _createSuper(TextLineDataset);

  // Removes one trailing carriage return left over from a CRLF line ending.
  function stripTrailingCarriageReturn(line) {
    return line.endsWith('\r') ? line.slice(0, -1) : line;
  }

  /**
   * Create a `TextLineDataset`.
   *
   * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
   */
  function TextLineDataset(input) {
    var self;
    _classCallCheck(this, TextLineDataset);
    self = _super.call(this);
    self.input = input;
    return self;
  }
  _createClass(TextLineDataset, [{
    key: "iterator",
    value: function iterator() {
      var self = this;
      // Each call opens a fresh byte stream on the input and turns it into lines.
      return new Promise(function (resolve) {
        resolve(self.input.iterator());
      }).then(function (byteStream) {
        return byteStream.decodeUTF8().split('\n').map(stripTrailingCarriageReturn);
      });
    }
  }]);
  return TextLineDataset;
}(Dataset);
96057
// The double-quote character (presumably the CSV field-quoting character).
var CODE_QUOTE = '"';
// Row-parser state tokens. Each is a distinct Symbol, so states compare only
// by identity. NOTE(review): the parser that consumes these is outside this
// excerpt; their meanings are inferred from the names only.
var STATE_WITHIN_QUOTE_IN_QUOTE = Symbol('quoteinquote');
var STATE_QUOTE_AFTER_QUOTE = Symbol('quoteafterquote');
var STATE_QUOTE = Symbol('quote');
var STATE_FIELD = Symbol('field');
var STATE_OUT = Symbol('out');
96064 /**
96065 * Represents a potentially large collection of delimited text records.
96066 *
96067 * The produced `TensorContainer`s each contain one key-value pair for
96068 * every column of the table. When a field is empty in the incoming data, the
 * resulting value is `undefined`, or an error is thrown if it is required. Values
96070 * that can be parsed as numbers are emitted as type `number`, other values
96071 * are parsed as `string`.
96072 *
96073 * The results are not batched.
96074 *
96075 * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
96076 */
96077 var CSVDataset = /*#__PURE__*/function (_Dataset) {
96078 _inherits(CSVDataset, _Dataset);
96079 var _super = _createSuper(CSVDataset);
/**
 * Create a `CSVDataset`.
 *
 * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
 * @param csvConfig (Optional) A CSVConfig object that contains configurations
 *     of reading and decoding from CSV file(s).
 *
 *   hasHeader: (Optional) A boolean value that indicates whether the first
 *   row of provided CSV file is a header line with column names, and should
 *   not be included in the data. Only an explicit `false` disables it;
 *   defaults to `true`.
 *
 *   columnNames: (Optional) A list of strings that corresponds to
 *   the CSV column names, in order. If provided, it takes precedence over the
 *   column names inferred from the header row (the header's length must still
 *   match). If not provided, the column names are read from the header row.
 *   If hasHeader is false and columnNames is not provided, an error is thrown
 *   when the column names are first needed.
 *
 *   columnConfigs: (Optional) A dictionary whose keys are column names and
 *   whose values describe each column: whether it is `required`, its `dtype`,
 *   a `default` value, and whether it `isLabel`. Every key must be one of the
 *   names provided in columnNames or read from the header line. If `isLabel`
 *   is true for any column, each element is an object
 *   `{xs: features, ys: labels}`, where both are dicts of key/value pairs.
 *   Otherwise each element is the dict of features only.
 *
 *   configuredColumnsOnly: (Optional) If true, only columns provided in
 *   columnConfigs will be parsed and provided during iteration.
 *
 *   delimiter: (Optional) The string used to parse each line of the input
 *   file. Defaults to `,`.
 *
 *   delimWhitespace: (Optional) If true, the delimiter is set to a single
 *   space and whitespace-delimited parsing is enabled (handled by parseRow).
 *   It must not be combined with `delimiter`.
 */
function CSVDataset(input, csvConfig) {
  var self;
  _classCallCheck(this, CSVDataset);
  self = _super.call(this);
  // Defaults, assigned in a fixed field order before the config is applied.
  self.input = input;
  self.hasHeader = true;
  self.fullColumnNames = null;
  self.columnNamesValidated = false;
  self.columnConfigs = null;
  self.configuredColumnsOnly = false;
  self.delimiter = ',';
  self.delimWhitespace = false;
  self.base = new TextLineDataset(input);
  var config = csvConfig || {};
  self.hasHeader = config.hasHeader !== false;
  // These three are overwritten unconditionally, so they become `undefined`
  // when absent from the config.
  self.fullColumnNames = config.columnNames;
  self.columnConfigs = config.columnConfigs;
  self.configuredColumnsOnly = config.configuredColumnsOnly;
  if (!config.delimWhitespace) {
    self.delimiter = config.delimiter || ',';
    return self;
  }
  assert$1(config.delimiter == null, function () {
    return 'Delimiter should not be provided when delimWhitespace is true.';
  });
  self.delimWhitespace = true;
  self.delimiter = ' ';
  return self;
}
96143 _createClass(CSVDataset, [{
96144 key: "columnNames",
96145 value:
96146 /**
96147 * Returns column names of the csv dataset. If `configuredColumnsOnly` is
96148 * true, return column names in `columnConfigs`. If `configuredColumnsOnly` is
96149 * false and `columnNames` is provided, `columnNames`. If
96150 * `configuredColumnsOnly` is false and `columnNames` is not provided, return
96151 * all column names parsed from the csv file. For example usage please go to
96152 * `tf.data.csv`.
96153 *
96154 * @doc {heading: 'Data', subheading: 'Classes'}
96155 */
96156 function () {
96157 var _columnNames = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
96158 return _regeneratorRuntime().wrap(function _callee$(_context) {
96159 while (1) switch (_context.prev = _context.next) {
96160 case 0:
96161 if (this.columnNamesValidated) {
96162 _context.next = 3;
96163 break;
96164 }
96165 _context.next = 3;
96166 return this.setColumnNames();
96167 case 3:
96168 return _context.abrupt("return", this.configuredColumnsOnly ? Object.keys(this.columnConfigs) : this.fullColumnNames);
96169 case 4:
96170 case "end":
96171 return _context.stop();
96172 }
96173 }, _callee, this);
96174 }));
96175 function columnNames() {
96176 return _columnNames.apply(this, arguments);
96177 }
96178 return columnNames;
96179 }()
96180 /* 1) If `columnNames` is provided as string[], use this string[] as output
96181 * keys in corresponding order. The length must match the number of inferred
 * columns if `hasHeader` is true.
96183 * 2) If `columnNames` is not provided, parse header line as `columnNames` if
96184 * hasHeader is true. If `hasHeader` is false, throw an error.
96185 * 3) If `columnConfigs` is provided, all the keys in `columnConfigs` must
96186 * exist in parsed `columnNames`.
96187 */
96188 }, {
96189 key: "setColumnNames",
96190 value: function () {
96191 var _setColumnNames = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
96192 var _this2 = this;
96193 var columnNamesFromFile, counts, duplicateNames, _i, _Object$keys, key, index;
96194 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
96195 while (1) switch (_context2.prev = _context2.next) {
96196 case 0:
96197 _context2.next = 2;
96198 return this.maybeReadHeaderLine();
96199 case 2:
96200 columnNamesFromFile = _context2.sent;
96201 if (!(!this.fullColumnNames && !columnNamesFromFile)) {
96202 _context2.next = 7;
96203 break;
96204 }
96205 throw new Error('Column names must be provided if there is no header line.');
96206 case 7:
96207 if (this.fullColumnNames && columnNamesFromFile) {
96208 // Check provided columnNames match header line.
96209 assert$1(columnNamesFromFile.length === this.fullColumnNames.length, function () {
96210 return 'The length of provided columnNames (' + _this2.fullColumnNames.length.toString() + ') does not match the length of the header line read from ' + 'file (' + columnNamesFromFile.length.toString() + ').';
96211 });
96212 }
96213 case 8:
96214 if (!this.fullColumnNames) {
96215 this.fullColumnNames = columnNamesFromFile;
96216 }
96217 // Check if there are duplicate column names.
96218 counts = this.fullColumnNames.reduce(function (countAcc, name) {
96219 countAcc[name] = countAcc[name] + 1 || 1;
96220 return countAcc;
96221 }, {});
96222 duplicateNames = Object.keys(counts).filter(function (name) {
96223 return counts[name] > 1;
96224 });
96225 assert$1(duplicateNames.length === 0, function () {
96226 return 'Duplicate column names found: ' + duplicateNames.toString();
96227 });
96228 // Check if keys in columnConfigs match columnNames.
96229 if (!this.columnConfigs) {
96230 _context2.next = 22;
96231 break;
96232 }
96233 _i = 0, _Object$keys = Object.keys(this.columnConfigs);
96234 case 14:
96235 if (!(_i < _Object$keys.length)) {
96236 _context2.next = 22;
96237 break;
96238 }
96239 key = _Object$keys[_i];
96240 index = this.fullColumnNames.indexOf(key);
96241 if (!(index === -1)) {
96242 _context2.next = 19;
96243 break;
96244 }
96245 throw new Error('The key "' + key + '" provided in columnConfigs does not match any of the column ' + 'names (' + this.fullColumnNames.toString() + ').');
96246 case 19:
96247 _i++;
96248 _context2.next = 14;
96249 break;
96250 case 22:
96251 this.columnNamesValidated = true;
96252 case 23:
96253 case "end":
96254 return _context2.stop();
96255 }
96256 }, _callee2, this);
96257 }));
96258 function setColumnNames() {
96259 return _setColumnNames.apply(this, arguments);
96260 }
96261 return setColumnNames;
96262 }()
96263 }, {
96264 key: "maybeReadHeaderLine",
96265 value: function () {
96266 var _maybeReadHeaderLine = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
96267 var iter, firstElement, firstLine, headers;
96268 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
96269 while (1) switch (_context3.prev = _context3.next) {
96270 case 0:
96271 if (!this.hasHeader) {
96272 _context3.next = 14;
96273 break;
96274 }
96275 _context3.next = 3;
96276 return this.base.iterator();
96277 case 3:
96278 iter = _context3.sent;
96279 _context3.next = 6;
96280 return iter.next();
96281 case 6:
96282 firstElement = _context3.sent;
96283 if (!firstElement.done) {
96284 _context3.next = 9;
96285 break;
96286 }
96287 throw new Error('No data was found for CSV parsing.');
96288 case 9:
96289 firstLine = firstElement.value;
96290 headers = this.parseRow(firstLine, false);
96291 return _context3.abrupt("return", headers);
96292 case 14:
96293 return _context3.abrupt("return", null);
96294 case 15:
96295 case "end":
96296 return _context3.stop();
96297 }
96298 }, _callee3, this);
96299 }));
96300 function maybeReadHeaderLine() {
96301 return _maybeReadHeaderLine.apply(this, arguments);
96302 }
96303 return maybeReadHeaderLine;
96304 }()
96305 }, {
96306 key: "iterator",
96307 value: function () {
96308 var _iterator = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4() {
96309 var _this3 = this;
96310 var lines;
96311 return _regeneratorRuntime().wrap(function _callee4$(_context4) {
96312 while (1) switch (_context4.prev = _context4.next) {
96313 case 0:
96314 if (this.columnNamesValidated) {
96315 _context4.next = 3;
96316 break;
96317 }
96318 _context4.next = 3;
96319 return this.setColumnNames();
96320 case 3:
96321 _context4.next = 5;
96322 return this.base.iterator();
96323 case 5:
96324 lines = _context4.sent;
96325 if (this.hasHeader) {
96326 // We previously read the first line to get the columnNames.
96327 // Now that we're providing data, skip it.
96328 lines = lines.skip(1);
96329 }
96330 return _context4.abrupt("return", lines.map(function (x) {
96331 return _this3.makeDataElement(x);
96332 }));
96333 case 8:
96334 case "end":
96335 return _context4.stop();
96336 }
96337 }, _callee4, this);
96338 }));
96339 function iterator() {
96340 return _iterator.apply(this, arguments);
96341 }
96342 return iterator;
96343 }()
96344 }, {
96345 key: "makeDataElement",
96346 value: function makeDataElement(line) {
96347 var values = this.parseRow(line);
96348 var features = {};
96349 var labels = {};
96350 for (var i = 0; i < this.fullColumnNames.length; i++) {
96351 var key = this.fullColumnNames[i];
96352 var config = this.columnConfigs ? this.columnConfigs[key] : null;
96353 if (this.configuredColumnsOnly && !config) {
96354 // This column is not selected.
96355 continue;
96356 } else {
96357 var value = values[i];
96358 var parsedValue = null;
96359 if (value === '') {
96360 // If default value is provided, use it. If default value is not
96361 // provided, set as undefined.
96362 if (config && config.default !== undefined) {
96363 parsedValue = config.default;
96364 } else if (config && (config.required || config.isLabel)) {
96365 throw new Error("Required column ".concat(key, " is empty in this line: ").concat(line));
96366 } else {
96367 parsedValue = undefined;
96368 }
96369 } else {
96370 // A value is present, so parse it based on type
96371 var valueAsNum = Number(value);
96372 if (isNaN(valueAsNum)) {
96373 // The value is a string and this column is declared as boolean
96374 // in config, parse it as boolean.
96375 if (config && config.dtype === 'bool') {
96376 parsedValue = this.getBoolean(value);
96377 } else {
96378 // Set value as string
96379 parsedValue = value;
96380 }
96381 } else if (!config || !config.dtype) {
96382 // If this value is a number and no type config is provided, return
96383 // it as number.
96384 parsedValue = valueAsNum;
96385 } else {
96386 // If this value is a number and data type is provided, parse it
96387 // according to provided data type.
96388 switch (config.dtype) {
96389 case 'float32':
96390 parsedValue = valueAsNum;
96391 break;
96392 case 'int32':
96393 parsedValue = Math.floor(valueAsNum);
96394 break;
96395 case 'bool':
96396 parsedValue = this.getBoolean(value);
96397 break;
96398 default:
96399 parsedValue = valueAsNum;
96400 }
96401 }
96402 }
96403 // Check if this column is label.
96404 config && config.isLabel ? labels[key] = parsedValue : features[key] = parsedValue;
96405 }
96406 }
96407 // If label exists, return an object of features and labels as {xs:features,
96408 // ys:labels}, otherwise return features only.
96409 if (Object.keys(labels).length === 0) {
96410 return features;
96411 } else {
96412 return {
96413 xs: features,
96414 ys: labels
96415 };
96416 }
96417 }
96418 }, {
96419 key: "getBoolean",
96420 value: function getBoolean(value) {
96421 if (value === '1' || value.toLowerCase() === 'true') {
96422 return 1;
96423 } else {
96424 return 0;
96425 }
96426 }
96427 // adapted from https://beta.observablehq.com/@mbostock/streaming-csv
96428 }, {
96429 key: "parseRow",
96430 value: function parseRow(line) {
96431 var validateElementCount = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : true;
96432 var result = [];
96433 var readOffset = 0;
96434 var readLength = line.length;
96435 var currentState = STATE_OUT;
96436 // Goes through the line to parse quote.
96437 for (var i = 0; i < readLength; i++) {
96438 switch (currentState) {
96439 // Before enter a new field
96440 case STATE_OUT:
96441 switch (line.charAt(i)) {
96442 // Enter a quoted field
96443 case CODE_QUOTE:
96444 readOffset = i + 1;
96445 currentState = STATE_QUOTE;
96446 break;
96447 // Read an empty field
96448 case this.delimiter:
96449 readOffset = i + 1;
96450 // If delimiter is white space and configured to collapse
96451 // multiple white spaces, ignore this white space.
96452 if (this.delimiter === ' ' && this.delimWhitespace) {
96453 break;
96454 }
96455 result.push('');
96456 currentState = STATE_OUT;
96457 break;
96458 // Enter an unquoted field
96459 default:
96460 currentState = STATE_FIELD;
96461 readOffset = i;
96462 break;
96463 }
96464 break;
96465 // In an unquoted field
96466 case STATE_FIELD:
96467 switch (line.charAt(i)) {
96468 // Exit an unquoted field, add it to result
96469 case this.delimiter:
96470 result.push(line.substring(readOffset, i));
96471 currentState = STATE_OUT;
96472 readOffset = i + 1;
96473 break;
96474 default:
96475 }
96476 break;
96477 // In a quoted field
96478 case STATE_QUOTE:
96479 switch (line.charAt(i)) {
96480 // Read a quote after a quote
96481 case CODE_QUOTE:
96482 currentState = STATE_QUOTE_AFTER_QUOTE;
96483 break;
96484 default:
96485 }
96486 break;
96487 // This state means it's right after a second quote in a field
96488 case STATE_QUOTE_AFTER_QUOTE:
96489 switch (line.charAt(i)) {
96490 // Finished a quoted field
96491 case this.delimiter:
96492 result.push(line.substring(readOffset, i - 1));
96493 currentState = STATE_OUT;
96494 readOffset = i + 1;
96495 break;
96496 // Finished a quoted part in a quoted field
96497 case CODE_QUOTE:
96498 currentState = STATE_QUOTE;
96499 break;
96500 // In a quoted part in a quoted field
96501 default:
96502 currentState = STATE_WITHIN_QUOTE_IN_QUOTE;
96503 break;
96504 }
96505 break;
96506 case STATE_WITHIN_QUOTE_IN_QUOTE:
96507 switch (line.charAt(i)) {
96508 // Exit a quoted part in a quoted field
96509 case CODE_QUOTE:
96510 currentState = STATE_QUOTE;
96511 break;
96512 default:
96513 }
96514 break;
96515 default:
96516 }
96517 }
96518 // Adds last item based on if it is quoted.
96519 if (currentState === STATE_QUOTE_AFTER_QUOTE) {
96520 result.push(line.substring(readOffset, readLength - 1));
96521 } else {
96522 result.push(line.substring(readOffset));
96523 }
96524 // Check if each row has the same number of elements as column names.
96525 if (validateElementCount && result.length !== this.fullColumnNames.length) {
96526 throw new Error("Invalid row in csv file. Should have ".concat(this.fullColumnNames.length, " elements in a row, but got ").concat(result));
96527 }
96528 return result;
96529 }
96530 }]);
96531 return CSVDataset;
96532 }(Dataset);
96533 // TODO(soergel): add more basic datasets for parity with tf.data
96534 // tf.data.FixedLengthRecordDataset()
96535 // tf.data.TFRecordDataset()
96536
/**
 * Provide a stream of tensors from microphone audio stream. The tensors are
 * representing audio data as frequency-domain spectrogram generated with
 * browser's native FFT. Tensors representing time-domain waveform is available
 * based on configuration. Only works in browser environment.
 */
var MicrophoneIterator = /*#__PURE__*/function (_LazyIterator) {
  _inherits(MicrophoneIterator, _LazyIterator);
  var _super = _createSuper(MicrophoneIterator);
  /**
   * @param microphoneConfig Options read here: fftSize (default 1024),
   *     numFramesPerSpectrogram (default 43), sampleRateHz,
   *     columnTruncateLength (default fftSize), audioTrackConstraints,
   *     smoothingTimeConstant (default 0), includeSpectrogram (default true),
   *     includeWaveform (default false).
   * @throws Error if fftSize is not a power of 2 in [2^4, 2^14], or if both
   *     includeSpectrogram and includeWaveform are false.
   */
  function MicrophoneIterator(microphoneConfig) {
    var _this;
    _classCallCheck(this, MicrophoneIterator);
    _this = _super.call(this);
    _this.microphoneConfig = microphoneConfig;
    _this.isClosed = false;
    _this.fftSize = microphoneConfig.fftSize || 1024;
    var fftSizeLog2 = Math.log2(_this.fftSize);
    if (_this.fftSize < 0 || fftSizeLog2 < 4 || fftSizeLog2 > 14 || !Number.isInteger(fftSizeLog2)) {
      throw new Error("Invalid fftSize: it must be a power of 2 between " + "2 to 4 and 2 to 14, but got ".concat(_this.fftSize));
    }
    _this.numFrames = microphoneConfig.numFramesPerSpectrogram || 43;
    _this.sampleRateHz = microphoneConfig.sampleRateHz;
    _this.columnTruncateLength = microphoneConfig.columnTruncateLength || _this.fftSize;
    _this.audioTrackConstraints = microphoneConfig.audioTrackConstraints;
    _this.smoothingTimeConstant = microphoneConfig.smoothingTimeConstant || 0;
    _this.includeSpectrogram = microphoneConfig.includeSpectrogram !== false;
    _this.includeWaveform = microphoneConfig.includeWaveform === true;
    if (!_this.includeSpectrogram && !_this.includeWaveform) {
      throw new Error('Both includeSpectrogram and includeWaveform are false. ' + 'At least one type of data should be returned.');
    }
    return _this;
  }
  _createClass(MicrophoneIterator, [{
    key: "summary",
    value: function summary() {
      return "microphone";
    }
    // Construct a MicrophoneIterator and start the audio stream.
  }, {
    key: "start",
    value: // Start the audio stream and FFT.
    function () {
      var _start = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        var ctxConstructor, streamSource;
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              _context.prev = 0;
              _context.next = 3;
              return navigator.mediaDevices.getUserMedia({
                audio: this.audioTrackConstraints == null ? true : this.audioTrackConstraints,
                video: false
              });
            case 3:
              this.stream = _context.sent;
              _context.next = 9;
              break;
            case 6:
              _context.prev = 6;
              _context.t0 = _context["catch"](0);
              throw new Error("Error thrown while initializing audio stream: ".concat(_context.t0.message));
            case 9:
              if (this.stream) {
                _context.next = 11;
                break;
              }
              throw new Error('Could not obtain audio from microphone.');
            case 11:
              ctxConstructor =
              // tslint:disable-next-line:no-any
              window.AudioContext || window.webkitAudioContext;
              this.audioContext = new ctxConstructor();
              if (this.sampleRateHz) {
                _context.next = 17;
                break;
              }
              // If sample rate is not provided, use the available sample rate on
              // device.
              this.sampleRateHz = this.audioContext.sampleRate;
              _context.next = 19;
              break;
            case 17:
              if (!(this.audioContext.sampleRate !== this.sampleRateHz)) {
                _context.next = 19;
                break;
              }
              // Release the audio context and the microphone tracks before
              // failing, so the device is not left open after the error.
              this.audioContext.close();
              this.stream.getTracks().forEach(function (track) {
                return track.stop();
              });
              throw new Error("Mismatch in sampling rate: " + "Expected: ".concat(this.sampleRateHz, "; ") + "Actual: ".concat(this.audioContext.sampleRate));
            case 19:
              streamSource = this.audioContext.createMediaStreamSource(this.stream);
              this.analyser = this.audioContext.createAnalyser();
              // fftSize * 2 so that the analyser exposes fftSize frequency bins.
              this.analyser.fftSize = this.fftSize * 2;
              this.analyser.smoothingTimeConstant = this.smoothingTimeConstant;
              streamSource.connect(this.analyser);
              this.freqData = new Float32Array(this.fftSize);
              this.timeData = new Float32Array(this.fftSize);
              return _context.abrupt("return");
            case 27:
            case "end":
              return _context.stop();
          }
        }, _callee, this, [[0, 6]]);
      }));
      function start() {
        return _start.apply(this, arguments);
      }
      return start;
    }()
  }, {
    key: "next",
    value: function () {
      // Collects numFrames frames and returns {value: {spectrogram, waveform},
      // done: false}; each entry is undefined unless enabled in the config.
      // Returns {value: null, done: true} once the iterator is closed.
      var _next = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
        var spectrogramTensor, waveformTensor, audioDataQueue, freqData, timeData;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              if (!this.isClosed) {
                _context2.next = 2;
                break;
              }
              return _context2.abrupt("return", {
                value: null,
                done: true
              });
            case 2:
              _context2.next = 4;
              return this.getAudioData();
            case 4:
              audioDataQueue = _context2.sent;
              if (this.includeSpectrogram) {
                freqData = this.flattenQueue(audioDataQueue.freqDataQueue);
                spectrogramTensor = this.getTensorFromAudioDataArray(freqData, [this.numFrames, this.columnTruncateLength, 1]);
              }
              if (this.includeWaveform) {
                timeData = this.flattenQueue(audioDataQueue.timeDataQueue);
                waveformTensor = this.getTensorFromAudioDataArray(timeData, [this.numFrames * this.fftSize, 1]);
              }
              return _context2.abrupt("return", {
                value: {
                  'spectrogram': spectrogramTensor,
                  'waveform': waveformTensor
                },
                done: false
              });
            case 8:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this);
      }));
      function next() {
        return _next.apply(this, arguments);
      }
      return next;
    }() // Capture one result from the audio stream, and extract the value from
    // iterator.next() result.
  }, {
    key: "capture",
    value: function () {
      var _capture = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
        return _regeneratorRuntime().wrap(function _callee3$(_context3) {
          while (1) switch (_context3.prev = _context3.next) {
            case 0:
              _context3.next = 2;
              return this.next();
            case 2:
              return _context3.abrupt("return", _context3.sent.value);
            case 3:
            case "end":
              return _context3.stop();
          }
        }, _callee3, this);
      }));
      function capture() {
        return _capture.apply(this, arguments);
      }
      return capture;
    }()
  }, {
    key: "getAudioData",
    value: function () {
      // Samples the analyser once every fftSize / sampleRateHz seconds until
      // numFrames frames are collected, then resolves with the frame queues.
      var _getAudioData = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4() {
        var _this2 = this;
        var freqDataQueue, timeDataQueue, currentFrames;
        return _regeneratorRuntime().wrap(function _callee4$(_context4) {
          while (1) switch (_context4.prev = _context4.next) {
            case 0:
              freqDataQueue = [];
              timeDataQueue = [];
              currentFrames = 0;
              return _context4.abrupt("return", new Promise(function (resolve) {
                var intervalID = setInterval(function () {
                  if (_this2.includeSpectrogram) {
                    _this2.analyser.getFloatFrequencyData(_this2.freqData);
                    // If the audio stream is initializing, return the (possibly
                    // empty) queue collected so far. Stop sampling here: do not
                    // push the -Infinity frame and do not leave the interval
                    // running after the promise has resolved.
                    if (_this2.freqData[0] === -Infinity) {
                      clearInterval(intervalID);
                      resolve({
                        freqDataQueue: freqDataQueue,
                        timeDataQueue: timeDataQueue
                      });
                      return;
                    }
                    freqDataQueue.push(_this2.freqData.slice(0, _this2.columnTruncateLength));
                  }
                  if (_this2.includeWaveform) {
                    _this2.analyser.getFloatTimeDomainData(_this2.timeData);
                    timeDataQueue.push(_this2.timeData.slice());
                  }
                  // Clean interval and return when all frames have been collected
                  if (++currentFrames === _this2.numFrames) {
                    clearInterval(intervalID);
                    resolve({
                      freqDataQueue: freqDataQueue,
                      timeDataQueue: timeDataQueue
                    });
                  }
                }, _this2.fftSize / _this2.sampleRateHz * 1e3);
              }));
            case 4:
            case "end":
              return _context4.stop();
          }
        }, _callee4);
      }));
      function getAudioData() {
        return _getAudioData.apply(this, arguments);
      }
      return getAudioData;
    }() // Stop the audio stream and pause the iterator.
  }, {
    key: "stop",
    value: function stop() {
      if (!this.isClosed) {
        this.isClosed = true;
        // Guard each resource: start() may not have completed.
        if (this.analyser != null) {
          this.analyser.disconnect();
        }
        if (this.audioContext != null) {
          this.audioContext.close();
        }
        if (this.stream != null) {
          // Stop every track (not only the first) so the device is released.
          this.stream.getTracks().forEach(function (track) {
            return track.stop();
          });
        }
      }
    }
    // Override toArray() function to prevent collecting.
  }, {
    key: "toArray",
    value: function toArray() {
      throw new Error('Can not convert infinite audio stream to array.');
    }
    // Return audio sampling rate in Hz
  }, {
    key: "getSampleRate",
    value: function getSampleRate() {
      return this.sampleRateHz;
    }
  }, {
    key: "flattenQueue",
    value: function flattenQueue(queue) {
      // Concatenate equally sized frames into one Float32Array. An empty queue
      // (capture stopped before any frame was collected) yields an empty
      // array, which the caller then zero-pads to the full tensor shape.
      if (queue.length === 0) {
        return new Float32Array(0);
      }
      var frameSize = queue[0].length;
      var freqData = new Float32Array(queue.length * frameSize);
      queue.forEach(function (data, i) {
        return freqData.set(data, i * frameSize);
      });
      return freqData;
    }
  }, {
    key: "getTensorFromAudioDataArray",
    value: function getTensorFromAudioDataArray(freqData, shape) {
      var vals = new Float32Array(sizeFromShape(shape));
      // If the data is less than the output shape, the rest is padded with zeros.
      vals.set(freqData, vals.length - freqData.length);
      return tensor(vals, shape);
    }
  }], [{
    key: "create",
    value: function () {
      var _create = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5() {
        var microphoneConfig,
          microphoneIterator,
          _args5 = arguments;
        return _regeneratorRuntime().wrap(function _callee5$(_context5) {
          while (1) switch (_context5.prev = _context5.next) {
            case 0:
              microphoneConfig = _args5.length > 0 && _args5[0] !== undefined ? _args5[0] : {};
              if (env().get('IS_BROWSER')) {
                _context5.next = 3;
                break;
              }
              throw new Error('microphone API is only supported in browser environment.');
            case 3:
              microphoneIterator = new MicrophoneIterator(microphoneConfig); // Call async function start() to initialize the audio stream.
              _context5.next = 6;
              return microphoneIterator.start();
            case 6:
              return _context5.abrupt("return", microphoneIterator);
            case 7:
            case "end":
              return _context5.stop();
          }
        }, _callee5);
      }));
      function create() {
        return _create.apply(this, arguments);
      }
      return create;
    }()
  }]);
  return MicrophoneIterator;
}(LazyIterator);
96842
/**
 * Provide a stream of image tensors from webcam video stream. Only works in
 * browser environment.
 */
var WebcamIterator = /*#__PURE__*/function (_LazyIterator) {
  _inherits(WebcamIterator, _LazyIterator);
  var _super = _createSuper(WebcamIterator);
  /**
   * @param webcamVideoElement Video element that receives the camera stream;
   *     its width/height are requested from getUserMedia in start().
   * @param webcamConfig Options read here: facingMode, deviceId, resizeWidth,
   *     resizeHeight, centerCrop.
   */
  function WebcamIterator(webcamVideoElement, webcamConfig) {
    var _this;
    _classCallCheck(this, WebcamIterator);
    _this = _super.call(this);
    _this.webcamVideoElement = webcamVideoElement;
    _this.webcamConfig = webcamConfig;
    // Stays closed until start() has attached the stream.
    _this.isClosed = true;
    _this.resize = false;
    if (_this.needToResize()) {
      _this.resize = true;
      _this.cropSize = [_this.webcamConfig.resizeHeight, _this.webcamConfig.resizeWidth];
      _this.cropBoxInd = tensor1d([0], 'int32');
      if (_this.webcamConfig.centerCrop) {
        // Calculate the box based on resizing shape.
        var widthCroppingRatio = _this.webcamConfig.resizeWidth * 1.0 / _this.webcamVideoElement.width;
        var heightCroppingRatio = _this.webcamConfig.resizeHeight * 1.0 / _this.webcamVideoElement.height;
        var widthCropStart = (1 - widthCroppingRatio) / 2;
        var heightCropStart = (1 - heightCroppingRatio) / 2;
        var widthCropEnd = widthCropStart + widthCroppingRatio;
        var heightCropEnd = heightCroppingRatio + heightCropStart;
        _this.cropBox = tensor2d([heightCropStart, widthCropStart, heightCropEnd, widthCropEnd], [1, 4]);
      } else {
        _this.cropBox = tensor2d([0, 0, 1, 1], [1, 4]);
      }
    }
    return _this;
  }
  _createClass(WebcamIterator, [{
    key: "summary",
    value: function summary() {
      return "webcam";
    }
    // Construct a WebcamIterator and start its video stream.
  }, {
    key: "start",
    value: // Async function to start video stream.
    function () {
      var _start = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee() {
        var _this2 = this;
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              if (this.webcamConfig.facingMode) {
                assert$1(this.webcamConfig.facingMode === 'user' || this.webcamConfig.facingMode === 'environment', function () {
                  return "Invalid webcam facing mode: ".concat(_this2.webcamConfig.facingMode, ". ") + "Please provide 'user' or 'environment'";
                });
              }
              _context.prev = 1;
              _context.next = 4;
              return navigator.mediaDevices.getUserMedia({
                video: {
                  deviceId: this.webcamConfig.deviceId,
                  facingMode: this.webcamConfig.facingMode ? this.webcamConfig.facingMode : 'user',
                  width: this.webcamVideoElement.width,
                  height: this.webcamVideoElement.height
                }
              });
            case 4:
              this.stream = _context.sent;
              _context.next = 11;
              break;
            case 7:
              _context.prev = 7;
              _context.t0 = _context["catch"](1);
              // Modify the error message but leave the stack trace intact
              _context.t0.message = "Error thrown while initializing video stream: ".concat(_context.t0.message);
              throw _context.t0;
            case 11:
              if (this.stream) {
                _context.next = 13;
                break;
              }
              throw new Error('Could not obtain video from webcam.');
            case 13:
              // Older browsers may not have srcObject
              try {
                this.webcamVideoElement.srcObject = this.stream;
              } catch (error) {
                console.log(error);
                this.webcamVideoElement.src = window.URL.createObjectURL(this.stream);
              }
              // Start the webcam video stream
              this.webcamVideoElement.play();
              this.isClosed = false;
              return _context.abrupt("return", new Promise(function (resolve) {
                // Add event listener to make sure the webcam has been fully initialized.
                _this2.webcamVideoElement.onloadedmetadata = function () {
                  resolve();
                };
              }));
            case 17:
            case "end":
              return _context.stop();
          }
        }, _callee, this, [[1, 7]]);
      }));
      function start() {
        return _start.apply(this, arguments);
      }
      return start;
    }()
  }, {
    key: "next",
    value: function () {
      // Returns the current frame as an image tensor (cropped/resized when
      // configured), or {value: null, done: true} once closed.
      var _next = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
        var img;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              if (!this.isClosed) {
                _context2.next = 2;
                break;
              }
              return _context2.abrupt("return", {
                value: null,
                done: true
              });
            case 2:
              _context2.prev = 2;
              img = fromPixels$1(this.webcamVideoElement);
              _context2.next = 9;
              break;
            case 6:
              _context2.prev = 6;
              _context2.t0 = _context2["catch"](2);
              // JSON.stringify() of an Error yields "{}" (message is not an own
              // enumerable property), so report the message for Error objects.
              throw new Error("Error thrown converting video to pixels: ".concat(_context2.t0 instanceof Error ? _context2.t0.message : JSON.stringify(_context2.t0)));
            case 9:
              if (!this.resize) {
                _context2.next = 22;
                break;
              }
              _context2.prev = 10;
              return _context2.abrupt("return", {
                value: this.cropAndResizeFrame(img),
                done: false
              });
            case 14:
              _context2.prev = 14;
              _context2.t1 = _context2["catch"](10);
              throw new Error("Error thrown cropping the video: ".concat(_context2.t1.message));
            case 17:
              // The uncropped frame is no longer needed once resized.
              _context2.prev = 17;
              img.dispose();
              return _context2.finish(17);
            case 20:
              _context2.next = 23;
              break;
            case 22:
              return _context2.abrupt("return", {
                value: img,
                done: false
              });
            case 23:
            case "end":
              return _context2.stop();
          }
        }, _callee2, this, [[2, 6], [10, 14, 17, 20]]);
      }));
      function next() {
        return _next.apply(this, arguments);
      }
      return next;
    }()
  }, {
    key: "needToResize",
    value: function needToResize() {
      // If resizeWidth and resizeHeight are provided, and different from the
      // width and height of original HTMLVideoElement, then resizing and cropping
      // is required.
      if (this.webcamConfig.resizeWidth && this.webcamConfig.resizeHeight && (this.webcamVideoElement.width !== this.webcamConfig.resizeWidth || this.webcamVideoElement.height !== this.webcamConfig.resizeHeight)) {
        return true;
      }
      return false;
    }
    // Cropping and resizing each frame based on config
  }, {
    key: "cropAndResizeFrame",
    value: function cropAndResizeFrame(img) {
      var _this3 = this;
      return tidy(function () {
        var expandedImage = expandDims$3(cast$3(img, 'float32'), 0);
        var resizedImage;
        resizedImage = image$1.cropAndResize(expandedImage, _this3.cropBox, _this3.cropBoxInd, _this3.cropSize, 'bilinear');
        // Extract image from batch cropping.
        var shape = resizedImage.shape;
        return reshape$3(resizedImage, shape.slice(1));
      });
    }
    // Capture one frame from the video stream, and extract the value from
    // iterator.next() result.
  }, {
    key: "capture",
    value: function () {
      var _capture = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
        return _regeneratorRuntime().wrap(function _callee3$(_context3) {
          while (1) switch (_context3.prev = _context3.next) {
            case 0:
              _context3.next = 2;
              return this.next();
            case 2:
              return _context3.abrupt("return", _context3.sent.value);
            case 3:
            case "end":
              return _context3.stop();
          }
        }, _callee3, this);
      }));
      function capture() {
        return _capture.apply(this, arguments);
      }
      return capture;
    }() // Stop the video stream and pause webcam iterator.
  }, {
    key: "stop",
    value: function stop() {
      // start() may not have obtained a stream (or may not have been called).
      if (this.stream != null) {
        var tracks = this.stream.getTracks();
        tracks.forEach(function (track) {
          return track.stop();
        });
      }
      try {
        this.webcamVideoElement.srcObject = null;
      } catch (error) {
        console.log(error);
        this.webcamVideoElement.src = null;
      }
      this.isClosed = true;
    }
    // Override toArray() function to prevent collecting.
  }, {
    key: "toArray",
    value: function toArray() {
      throw new Error('Can not convert infinite video stream to array.');
    }
  }], [{
    key: "create",
    value: function () {
      var _create = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4(webcamVideoElement) {
        var webcamConfig,
          webcamIterator,
          _args4 = arguments;
        return _regeneratorRuntime().wrap(function _callee4$(_context4) {
          while (1) switch (_context4.prev = _context4.next) {
            case 0:
              webcamConfig = _args4.length > 1 && _args4[1] !== undefined ? _args4[1] : {};
              if (env().get('IS_BROWSER')) {
                _context4.next = 3;
                break;
              }
              throw new Error('tf.data.webcam is only supported in browser environment.');
            case 3:
              if (webcamVideoElement) {
                _context4.next = 9;
                break;
              }
              // If webcam video element is not provided, create a hidden video element
              // with provided width and height.
              webcamVideoElement = document.createElement('video');
              if (!(!webcamConfig.resizeWidth || !webcamConfig.resizeHeight)) {
                _context4.next = 7;
                break;
              }
              throw new Error('Please provide webcam video element, or resizeWidth and ' + 'resizeHeight to create a hidden video element.');
            case 7:
              webcamVideoElement.width = webcamConfig.resizeWidth;
              webcamVideoElement.height = webcamConfig.resizeHeight;
            case 9:
              webcamIterator = new WebcamIterator(webcamVideoElement, webcamConfig); // Call async function to initialize the video stream.
              _context4.next = 12;
              return webcamIterator.start();
            case 12:
              return _context4.abrupt("return", webcamIterator);
            case 13:
            case "end":
              return _context4.stop();
          }
        }, _callee4);
      }));
      function create(_x) {
        return _create.apply(this, arguments);
      }
      return create;
    }()
  }]);
  return WebcamIterator;
}(LazyIterator);
97135
97136 /**
97137 * @license
97138 * Copyright 2018 Google LLC. All Rights Reserved.
97139 * Licensed under the Apache License, Version 2.0 (the "License");
97140 * you may not use this file except in compliance with the License.
97141 * You may obtain a copy of the License at
97142 *
97143 * http://www.apache.org/licenses/LICENSE-2.0
97144 *
97145 * Unless required by applicable law or agreed to in writing, software
97146 * distributed under the License is distributed on an "AS IS" BASIS,
97147 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97148 * See the License for the specific language governing permissions and
97149 * limitations under the License.
97150 *
97151 * =============================================================================
97152 */
/**
 * Abstract base for a data source that can be read as a stream of binary
 * data chunks.
 *
 * A `Dataset` may be iterated many times (via `Dataset.iterator()`), so a
 * data source is the means of opening the underlying data again and again.
 * This base class defines no members of its own.
 */
var DataSource = /*#__PURE__*/function () {
  function DataSource() {
    _classCallCheck(this, DataSource);
  }
  return _createClass(DataSource);
}();
97163 // TODO(soergel): consider convenience factory functions here
97164 // in combination with chainable source->dataset above, e.g.:
97165 // tf.data.url(...).asCsvDataset().shuffle().batch()
97166
var StringIterator = /*#__PURE__*/function (BaseIterator) {
  _inherits(StringIterator, BaseIterator);
  var callSuper = _createSuper(StringIterator);
  function StringIterator() {
    _classCallCheck(this, StringIterator);
    return callSuper.apply(this, arguments);
  }
  var instanceMethods = [{
    key: "split",
    value:
    /**
     * Splits this string stream on a separator.
     *
     * Incoming chunk boundaries carry no meaning: the stream is treated as
     * the concatenation of its chunks, and the resulting stream yields the
     * pieces that a plain string split() of that concatenation would give,
     * even when a piece spans several incoming chunks. Separators are not
     * included in the output. Typical use: turning a text file delivered in
     * arbitrary chunks into a stream of lines.
     *
     * @param separator The string to split on.
     * @returns A SplitIterator reading from this iterator.
     */
    function split(separator) {
      return new SplitIterator(this, separator);
    }
  }];
  _createClass(StringIterator, instanceMethods);
  return StringIterator;
}(LazyIterator);
97201 // ============================================================================
97202 // The following private classes serve to implement the chainable methods
97203 // on StringIterator. Unfortunately they can't be placed in separate files, due
97204 // to resulting trouble with circular imports.
97205 // ============================================================================
97206 // We wanted multiple inheritance, e.g.
97207 // class SplitIterator extends QueueIterator<string>, StringIterator
97208 // but the TypeScript mixin approach is a bit hacky, so we take this adapter
97209 // approach instead.
var SplitIterator = /*#__PURE__*/function (BaseStringIterator) {
  _inherits(SplitIterator, BaseStringIterator);
  var callSuper = _createSuper(SplitIterator);
  // Thin StringIterator facade; all of the splitting work is delegated to a
  // SplitIteratorImpl instance.
  function SplitIterator(upstream, separator) {
    _classCallCheck(this, SplitIterator);
    var self = callSuper.call(this);
    self.upstream = upstream;
    self.impl = new SplitIteratorImpl(upstream, separator);
    return self;
  }
  _createClass(SplitIterator, [{
    key: "summary",
    value: function summary() {
      return this.impl.summary();
    }
  }, {
    key: "next",
    value: function next() {
      // Always returns a fresh promise; impl.next() is invoked synchronously
      // and a synchronous throw becomes a rejection.
      var self = this;
      return new Promise(function (resolve) {
        resolve(self.impl.next());
      });
    }
  }]);
  return SplitIterator;
}(StringIterator);
97248 var SplitIteratorImpl = /*#__PURE__*/function (_OneToManyIterator) {
97249 _inherits(SplitIteratorImpl, _OneToManyIterator);
97250 var _super3 = _createSuper(SplitIteratorImpl);
function SplitIteratorImpl(upstream, separator) {
  _classCallCheck(this, SplitIteratorImpl);
  var self = _super3.call(this);
  self.upstream = upstream;
  self.separator = separator;
  // Partial piece left at the end of the previous upstream chunk; it is
  // prepended to the first piece of the next chunk.
  self.carryover = '';
  return self;
}
97261 _createClass(SplitIteratorImpl, [{
97262 key: "summary",
97263 value: function summary() {
97264 return "".concat(this.upstream.summary(), " -> Split('").concat(this.separator, "')");
97265 }
97266 }, {
97267 key: "pump",
97268 value: function () {
97269 var _pump = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
97270 var chunkResult, lines, _iterator, _step, line;
97271 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
97272 while (1) switch (_context2.prev = _context2.next) {
97273 case 0:
97274 _context2.next = 2;
97275 return this.upstream.next();
97276 case 2:
97277 chunkResult = _context2.sent;
97278 if (!chunkResult.done) {
97279 _context2.next = 9;
97280 break;
97281 }
97282 if (!(this.carryover === '')) {
97283 _context2.next = 6;
97284 break;
97285 }
97286 return _context2.abrupt("return", false);
97287 case 6:
97288 // Pretend that the pump succeeded in order to emit the small last batch.
97289 // The next pump() call will actually fail.
97290 this.outputQueue.push(this.carryover);
97291 this.carryover = '';
97292 return _context2.abrupt("return", true);
97293 case 9:
97294 lines = chunkResult.value.split(this.separator); // Note the behavior: " ab ".split(' ') === ['', 'ab', '']
97295 // Thus the carryover may be '' if the separator falls on a chunk
97296 // boundary; this produces the correct result.
97297 lines[0] = this.carryover + lines[0];
97298 _iterator = _createForOfIteratorHelper(lines.slice(0, -1));
97299 try {
97300 for (_iterator.s(); !(_step = _iterator.n()).done;) {
97301 line = _step.value;
97302 this.outputQueue.push(line);
97303 }
97304 } catch (err) {
97305 _iterator.e(err);
97306 } finally {
97307 _iterator.f();
97308 }
97309 this.carryover = lines[lines.length - 1];
97310 return _context2.abrupt("return", true);
97311 case 15:
97312 case "end":
97313 return _context2.stop();
97314 }
97315 }, _callee2, this);
97316 }));
97317 function pump() {
97318 return _pump.apply(this, arguments);
97319 }
97320 return pump;
97321 }()
97322 }]);
97323 return SplitIteratorImpl;
97324 }(OneToManyIterator);
97325
var ByteChunkIterator = /*#__PURE__*/function (_LazyIterator) {
  _inherits(ByteChunkIterator, _LazyIterator);
  var _superCtor = _createSuper(ByteChunkIterator);

  /** Lazy iterator over chunks of bytes; forwards all arguments upward. */
  function ByteChunkIterator() {
    _classCallCheck(this, ByteChunkIterator);
    return _superCtor.apply(this, arguments);
  }

  /**
   * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
   *
   * The byte arrays produced by this iterator are treated as one
   * concatenated byte stream. Nothing is assumed about where the chunk
   * boundaries fall, so the multi-byte UTF8 encoding of a single character
   * may span two chunks, as happens when reading fixed-size byte arrays from
   * a file.
   */
  function decodeUTF8() {
    return new Utf8Iterator(this);
  }

  _createClass(ByteChunkIterator, [{
    key: "decodeUTF8",
    value: decodeUTF8
  }]);
  return ByteChunkIterator;
}(LazyIterator);
97351 // ============================================================================
97352 // The following private classes serve to implement the chainable methods
97353 // on ByteChunkIterator. Unfortunately they can't be placed in separate files,
97354 // due to resulting trouble with circular imports.
97355 // ============================================================================
97356 // We wanted multiple inheritance, e.g.
97357 // class Utf8Iterator extends QueueIterator<string>, StringIterator
97358 // but the TypeScript mixin approach is a bit hacky, so we take this adapter
97359 // approach instead.
var Utf8Iterator = /*#__PURE__*/function (_StringIterator) {
  _inherits(Utf8Iterator, _StringIterator);
  var _superCtor = _createSuper(Utf8Iterator);

  /**
   * String iterator that decodes the UTF8 bytes of `upstream`. The actual
   * work is delegated to a `Utf8IteratorImpl` held in `impl`.
   *
   * @param upstream Iterator of UTF8-encoded byte chunks.
   */
  function Utf8Iterator(upstream) {
    _classCallCheck(this, Utf8Iterator);
    var self = _superCtor.call(this);
    self.upstream = upstream;
    self.impl = new Utf8IteratorImpl(upstream);
    return self;
  }

  /** Delegates to the wrapped implementation's summary. */
  function summary() {
    return this.impl.summary();
  }

  /**
   * Returns a promise of the implementation's next result.
   *
   * The executor runs synchronously, so `impl.next()` is invoked immediately
   * and a synchronous throw becomes a rejection, as with an async method.
   */
  function next() {
    var impl = this.impl;
    return new Promise(function (resolve) {
      resolve(impl.next());
    });
  }

  _createClass(Utf8Iterator, [{
    key: "summary",
    value: summary
  }, {
    key: "next",
    value: next
  }]);
  return Utf8Iterator;
}(StringIterator);
/**
 * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
 *
 * The incoming byte array boundaries may split a multi-byte UTF8 character,
 * so the bytes of an incomplete character at the end of one chunk must be
 * held back and combined with the start of the next chunk. The native
 * streaming decoders (TextDecoder in the browser, string_decoder in Node) do
 * this bookkeeping for us.
 *
 * In a machine-learning input pipeline, UTF8 decoding is needed to parse text
 * files containing training examples or prediction requests (e.g. CSV or
 * JSON). FileReader.readAsText() cannot be used because it does not support
 * this streaming context.
 *
 * @param upstream A `LazyIterator` of `Uint8Array`s containing UTF8-encoded
 *     text, interpreted as one concatenated byte stream. No assumptions are
 *     made about the chunk boundaries, so the encoding of one character may
 *     span two chunks (as happens when reading fixed-size byte arrays from a
 *     file).
 */
var Utf8IteratorImpl = /*#__PURE__*/function (_OneToManyIterator) {
  _inherits(Utf8IteratorImpl, _OneToManyIterator);
  var _super3 = _createSuper(Utf8IteratorImpl);
  function Utf8IteratorImpl(upstream) {
    var _this2;
    _classCallCheck(this, Utf8IteratorImpl);
    _this2 = _super3.call(this);
    _this2.upstream = upstream;
    if (env().get('IS_BROWSER')) {
      _this2.decoder = new TextDecoder('utf-8');
    } else {
      // tslint:disable-next-line:no-require-imports
      var _require = require('string_decoder'),
        StringDecoder = _require.StringDecoder;
      _this2.decoder = new StringDecoder('utf8');
    }
    return _this2;
  }
  _createClass(Utf8IteratorImpl, [{
    key: "summary",
    value: function summary() {
      return "".concat(this.upstream.summary(), " -> Utf8");
    }
  }, {
    key: "pump",
    /**
     * Pulls one chunk from upstream, decodes it and pushes the text onto
     * `outputQueue`.
     *
     * When upstream is exhausted, the decoder is flushed once. Bytes of a
     * truncated trailing character are emitted (as the decoder's replacement
     * text) instead of being silently dropped. That flush reports `true`; the
     * following call finds nothing buffered and reports `false`.
     *
     * @returns A promise of `true` if text was pushed, `false` when upstream
     *     is done and the decoder holds no pending bytes.
     */
    value: function pump() {
      var _this3 = this;
      // The executor runs synchronously, so upstream.next() is called right
      // away and a synchronous throw becomes a rejection.
      return new Promise(function (resolve) {
        resolve(_this3.upstream.next());
      }).then(function (chunkResult) {
        var isBrowser = env().get('IS_BROWSER');
        if (chunkResult.done) {
          // decode() with no arguments / end() flush any buffered partial
          // character and reset the decoder, so a repeated flush yields ''.
          var tail = isBrowser ? _this3.decoder.decode() : _this3.decoder.end();
          if (!tail) {
            return false;
          }
          _this3.outputQueue.push(tail);
          return true;
        }
        var chunk = chunkResult.value;
        var text;
        if (isBrowser) {
          text = _this3.decoder.decode(chunk, {
            stream: true
          });
        } else {
          // The chunk may be a view onto a larger ArrayBuffer (e.g. a pooled
          // Node Buffer), so decode only the view's own byte window.
          text = _this3.decoder.write(Buffer.from(chunk.buffer, chunk.byteOffset, chunk.byteLength));
        }
        _this3.outputQueue.push(text);
        return true;
      });
    }
  }]);
  return Utf8IteratorImpl;
}(OneToManyIterator);
97486
/**
 * Iterates over sequential chunks of a File, Blob, or Uint8Array.
 *
 * @param file The File, Blob or Uint8Array to read. File and Blob are only
 *     accepted when running in a browser.
 * @param options Optional settings: `offset` (byte position to start at,
 *     default 0) and `chunkSize` (bytes per chunk, default 1MB).
 * @returns A lazy iterator of Uint8Arrays holding consecutive chunks of the
 *     input.
 */
var FileChunkIterator = /*#__PURE__*/function (_ByteChunkIterator) {
  _inherits(FileChunkIterator, _ByteChunkIterator);
  var _superCtor = _createSuper(FileChunkIterator);
  function FileChunkIterator(file) {
    var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
    _classCallCheck(this, FileChunkIterator);
    var self = _superCtor.call(this);
    self.file = file;
    self.options = options;
    var isSupported = file instanceof Uint8Array || (env().get('IS_BROWSER') ? file instanceof File || file instanceof Blob : false);
    assert$1(isSupported, function () {
      return 'FileChunkIterator only supports File, Blob and Uint8Array right now.';
    });
    self.offset = options.offset || 0;
    // A 1MB default chunk has tolerable performance on large files.
    self.chunkSize = options.chunkSize || 1024 * 1024;
    return self;
  }

  /** Short description used in iterator pipeline summaries. */
  function summary() {
    return "FileChunks ".concat(this.file);
  }

  /**
   * Resolves to the next chunk as `{value, done: false}`, or to
   * `{value: null, done: true}` once `offset` has reached the end of the file.
   *
   * Everything up to starting the read happens synchronously inside the
   * promise executor, so `offset` advances before the read completes and any
   * synchronous failure becomes a rejection without advancing it.
   */
  function next() {
    var self = this;
    return new Promise(function (resolve) {
      var byteCount = self.file instanceof Uint8Array ? self.file.byteLength : self.file.size;
      if (self.offset >= byteCount) {
        resolve({
          value: null,
          done: true
        });
        return;
      }
      var start = self.offset;
      var stop = start + self.chunkSize;
      var pendingBytes;
      if (self.file instanceof Uint8Array) {
        // A `stop` past the end simply yields a shorter final chunk.
        pendingBytes = Promise.resolve(new Uint8Array(self.file.slice(start, stop)));
      } else {
        // Only File/Blob reach this branch, i.e. we are in a browser.
        // TODO(soergel): is this a performance issue?
        var reader = new FileReader();
        pendingBytes = new Promise(function (accept, fail) {
          reader.onload = function () {
            var data = reader.result;
            // Not sure we can trust the return type of
            // FileReader.readAsArrayBuffer; see e.g.
            // https://github.com/node-file-api/FileReader/issues/2
            if (data instanceof ArrayBuffer) {
              data = new Uint8Array(data);
            }
            if (!(data instanceof Uint8Array)) {
              return fail(new TypeError('FileReader returned unknown type.'));
            }
            accept(data);
          };
          reader.onabort = function () {
            return fail(new Error('Aborted'));
          };
          reader.onerror = function (event) {
            return fail(new Error(event.type));
          };
        });
        // TODO(soergel): better handle onabort, onerror
        // readAsText is not usable even for text files: the slice boundary
        // may fall inside a multi-byte character. A `stop` past the end
        // simply yields a shorter final chunk.
        reader.readAsArrayBuffer(self.file.slice(start, stop));
      }
      self.offset = stop;
      resolve(pendingBytes.then(function (bytes) {
        return {
          value: bytes,
          done: false
        };
      }));
    });
  }

  _createClass(FileChunkIterator, [{
    key: "summary",
    value: summary
  }, {
    key: "next",
    value: next
  }]);
  return FileChunkIterator;
}(ByteChunkIterator);
97596
/**
 * Provide a stream of chunks from a URL.
 *
 * Note this first downloads the entire response body into memory before
 * providing the first element of the stream, because the Fetch API does not
 * yet reliably provide a reader stream for the response body.
 *
 * @param url A URL string, or a `Request`-like object whose `url` and request
 *     settings are forwarded to the fetch function.
 * @param options Options for the resulting `FileChunkIterator` (default `{}`).
 * @param fetchFunc Optional fetch implementation; defaults to `fetch$1`.
 * @returns A promise of a `FileChunkIterator` over the downloaded bytes.
 *     Rejects with an `Error` when the response is not OK. The message is the
 *     response's `statusText`, or, when that is empty (always the case over
 *     HTTP/2), a message naming the URL and HTTP status.
 */
function urlChunkIterator(url) {
  var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
  var fetchFunc = arguments.length > 2 ? arguments[2] : undefined;
  var urlString;
  var requestInit;
  // The executor runs synchronously, so the request is issued immediately and
  // any synchronous failure surfaces as a rejection.
  return new Promise(function (resolve) {
    if (typeof url === 'string') {
      urlString = url;
    } else {
      urlString = url.url;
      requestInit = getRequestInitFromRequest(url);
    }
    resolve((fetchFunc || fetch$1)(urlString, requestInit));
  }).then(function (response) {
    if (!response.ok) {
      throw new Error(response.statusText || "Failed to fetch ".concat(urlString, " (HTTP status ").concat(response.status, ")"));
    }
    return response.arrayBuffer().then(function (buffer) {
      return new FileChunkIterator(new Uint8Array(buffer), options);
    });
  });
}
// Generate RequestInit from Request to match tf.util.fetch signature.
/**
 * Copies the request settings of a `Request` into a plain RequestInit object
 * (method, headers, body, mode, credentials, cache, redirect, referrer,
 * integrity), in that key order.
 */
var getRequestInitFromRequest = function getRequestInitFromRequest(request) {
  var copiedKeys = ['method', 'headers', 'body', 'mode', 'credentials', 'cache', 'redirect', 'referrer', 'integrity'];
  return copiedKeys.reduce(function (init, key) {
    init[key] = request[key];
    return init;
  }, {});
};
97667
97668 /**
97669 * @license
97670 * Copyright 2018 Google LLC. All Rights Reserved.
97671 * Licensed under the Apache License, Version 2.0 (the "License");
97672 * you may not use this file except in compliance with the License.
97673 * You may obtain a copy of the License at
97674 *
97675 * http://www.apache.org/licenses/LICENSE-2.0
97676 *
97677 * Unless required by applicable law or agreed to in writing, software
97678 * distributed under the License is distributed on an "AS IS" BASIS,
97679 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97680 * See the License for the specific language governing permissions and
97681 * limitations under the License.
97682 *
97683 * =============================================================================
97684 */
// Reports whether `source` is a string using the `file://` scheme. The input
// is deliberately untyped because this function exists to inspect it.
// tslint:disable-next-line:no-any
function isLocalPath(source) {
  if (typeof source !== 'string') {
    return false;
  }
  var prefix = 'file://';
  return source.substring(0, prefix.length) === prefix;
}
97691
/**
 * Represents a local `file://` path (Node only), File, Blob, or Uint8Array
 * readable as a stream of binary data chunks.
 */
var FileDataSource = /*#__PURE__*/function (_DataSource) {
  _inherits(FileDataSource, _DataSource);
  var _superCtor = _createSuper(FileDataSource);
  /**
   * Create a `FileDataSource`.
   *
   * @param input A local file path prefixed with `file://`, or a
   *     `File`/`Blob`/`Uint8Array` object to read. Local paths only work in
   *     the Node environment.
   * @param options Options passed to the underlying `FileChunkIterator`s,
   *     such as `{chunkSize: 1024}`.
   */
  function FileDataSource(input) {
    var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
    _classCallCheck(this, FileDataSource);
    var self = _superCtor.call(this);
    self.input = input;
    self.options = options;
    return self;
  }

  /**
   * Resolves to a new `FileChunkIterator` over the input.
   *
   * Under Node, a `file://` path is first read synchronously with
   * `fs.readFileSync`, and `input` is replaced by the file contents, so later
   * calls reuse those bytes instead of reading the file again.
   */
  function iterator() {
    var self = this;
    return new Promise(function (resolve) {
      var readFromDisk = isLocalPath(self.input) && env().get('IS_NODE');
      if (readFromDisk) {
        // tslint:disable-next-line:no-require-imports
        var fs = require('fs');
        self.input = fs.readFileSync(self.input.slice(7));
      }
      // TODO(kangyizhang): Add LocalFileChunkIterator to split local streaming
      // with file in browser.
      resolve(new FileChunkIterator(self.input, self.options));
    });
  }

  _createClass(FileDataSource, [{
    key: "iterator",
    value: iterator
  }]);
  return FileDataSource;
}(DataSource);
97746
/**
 * Represents a URL readable as a stream of binary data chunks. `file://`
 * URLs are handed to a `FileDataSource` instead of being fetched.
 */
var URLDataSource = /*#__PURE__*/function (_DataSource) {
  _inherits(URLDataSource, _DataSource);
  var _superCtor = _createSuper(URLDataSource);
  /**
   * Create a `URLDataSource`.
   *
   * @param url A source URL string, or a `Request` object.
   * @param fileOptions Options passed to the underlying `FileChunkIterator`s,
   *     such as `{chunkSize: 1024}`.
   */
  function URLDataSource(url) {
    var fileOptions = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
    _classCallCheck(this, URLDataSource);
    var self = _superCtor.call(this);
    self.url = url;
    self.fileOptions = fileOptions;
    return self;
  }

  // TODO(soergel): provide appropriate caching options. Currently this
  // will download the URL anew for each call to iterator(). Since we have
  // to treat the downloaded file as a blob/buffer anyway, we may as well retain
  // it-- but that raises GC issues. Also we may want a persistent disk cache.
  /** Resolves to a chunk iterator over the URL's content. */
  function iterator() {
    var self = this;
    return new Promise(function (resolve) {
      if (isLocalPath(self.url)) {
        resolve(new FileDataSource(self.url, self.fileOptions).iterator());
      } else {
        resolve(urlChunkIterator(self.url, self.fileOptions));
      }
    });
  }

  _createClass(URLDataSource, [{
    key: "iterator",
    value: iterator
  }]);
  return URLDataSource;
}(DataSource);
97801
97802 /**
97803 * Create a `CSVDataset` by reading and decoding CSV file(s) from provided URL
97804 * or local path if it's in Node environment.
97805 *
97806 * Note: If isLabel in columnConfigs is `true` for at least one column, the
97807 * element in returned `CSVDataset` will be an object of
97808 * `{xs:features, ys:labels}`: xs is a dict of features key/value pairs, ys
97809 * is a dict of labels key/value pairs. If no column is marked as label,
97810 * returns a dict of features only.
97811 *
97812 * ```js
97813 * const csvUrl =
97814 * 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
97815 *
97816 * async function run() {
97817 * // We want to predict the column "medv", which represents a median value of
97818 * // a home (in $1000s), so we mark it as a label.
97819 * const csvDataset = tf.data.csv(
97820 * csvUrl, {
97821 * columnConfigs: {
97822 * medv: {
97823 * isLabel: true
97824 * }
97825 * }
97826 * });
97827 *
97828 * // Number of features is the number of column names minus one for the label
97829 * // column.
97830 * const numOfFeatures = (await csvDataset.columnNames()).length - 1;
97831 *
97832 * // Prepare the Dataset for training.
97833 * const flattenedDataset =
97834 * csvDataset
97835 * .map(({xs, ys}) =>
97836 * {
97837 * // Convert xs(features) and ys(labels) from object form (keyed by
97838 * // column name) to array form.
97839 * return {xs:Object.values(xs), ys:Object.values(ys)};
97840 * })
97841 * .batch(10);
97842 *
97843 * // Define the model.
97844 * const model = tf.sequential();
97845 * model.add(tf.layers.dense({
97846 * inputShape: [numOfFeatures],
97847 * units: 1
97848 * }));
97849 * model.compile({
97850 * optimizer: tf.train.sgd(0.000001),
97851 * loss: 'meanSquaredError'
97852 * });
97853 *
97854 * // Fit the model using the prepared Dataset
97855 * return model.fitDataset(flattenedDataset, {
97856 * epochs: 10,
97857 * callbacks: {
97858 * onEpochEnd: async (epoch, logs) => {
97859 * console.log(epoch + ':' + logs.loss);
97860 * }
97861 * }
97862 * });
97863 * }
97864 *
97865 * await run();
97866 * ```
97867 *
97868 * @param source URL or local path to get CSV file. If it's a local path, it
97869 * must have prefix `file://` and it only works in node environment.
97870 * @param csvConfig (Optional) A CSVConfig object that contains configurations
97871 * of reading and decoding from CSV file(s).
97872 *
97873 * @doc {
97874 * heading: 'Data',
97875 * subheading: 'Creation',
97876 * namespace: 'data',
97877 * configParamIndices: [1]
97878 * }
97879 */
function csv(source) {
  // Optional second parameter read from `arguments` so that csv.length stays 1.
  var csvConfig = arguments[1] === undefined ? {} : arguments[1];
  var dataSource = new URLDataSource(source);
  return new CSVDataset(dataSource, csvConfig);
}
97884 /**
97885 * Create a `Dataset` that produces each element by calling a provided function.
97886 *
97887 * Note that repeated iterations over this `Dataset` may produce different
97888 * results, because the function will be called anew for each element of each
97889 * iteration.
97890 *
97891 * Also, beware that the sequence of calls to this function may be out of order
97892 * in time with respect to the logical order of the Dataset. This is due to the
97893 * asynchronous lazy nature of stream processing, and depends on downstream
97894 * transformations (e.g. .shuffle()). If the provided function is pure, this is
97895 * no problem, but if it is a closure over a mutable state (e.g., a traversal
97896 * pointer), then the order of the produced elements may be scrambled.
97897 *
97898 * ```js
97899 * let i = -1;
97900 * const func = () =>
97901 * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
97902 * const ds = tf.data.func(func);
97903 * await ds.forEachAsync(e => console.log(e));
97904 * ```
97905 *
97906 * @param f A function that produces one data element on each call.
97907 */
function func(f) {
  // A single iterator is built up front and handed out for every iteration
  // of the resulting dataset.
  var sharedIterator = iteratorFromFunction(f);
  var provideIterator = function () {
    return Promise.resolve(sharedIterator);
  };
  return datasetFromIteratorFn(provideIterator);
}
97922 /**
97923 * Create a `Dataset` that produces each element from provided JavaScript
97924 * generator, which is a function that returns a (potentially async) iterator.
97925 *
97926 * For more information on iterators and generators, see
97927 * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators .
97928 * For the iterator protocol, see
97929 * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Iteration_protocols .
97930 *
97931 * Example of creating a dataset from an iterator factory:
97932 * ```js
97933 * function makeIterator() {
97934 * const numElements = 10;
97935 * let index = 0;
97936 *
97937 * const iterator = {
97938 * next: () => {
97939 * let result;
97940 * if (index < numElements) {
97941 * result = {value: index, done: false};
97942 * index++;
97943 * return result;
97944 * }
97945 * return {value: index, done: true};
97946 * }
97947 * };
97948 * return iterator;
97949 * }
97950 * const ds = tf.data.generator(makeIterator);
97951 * await ds.forEachAsync(e => console.log(e));
97952 * ```
97953 *
97954 * Example of creating a dataset from a generator:
97955 * ```js
97956 * function* dataGenerator() {
97957 * const numElements = 10;
97958 * let index = 0;
97959 * while (index < numElements) {
97960 * const x = index;
97961 * index++;
97962 * yield x;
97963 * }
97964 * }
97965 *
97966 * const ds = tf.data.generator(dataGenerator);
97967 * await ds.forEachAsync(e => console.log(e));
97968 * ```
97969 *
97970 * @param generator A JavaScript function that returns
97971 * a (potentially async) JavaScript iterator.
97972 *
97973 * @doc {
97974 * heading: 'Data',
97975 * subheading: 'Creation',
97976 * namespace: 'data',
97977 * configParamIndices: [1]
97978 * }
97979 */
function generator(generator) {
  // Each iteration of the dataset calls `generator` afresh (awaiting it, so an
  // async factory works too) and wraps the resulting iterator's next().
  var makeIterator = function () {
    return new Promise(function (resolve) {
      resolve(generator());
    }).then(function (gen) {
      return iteratorFromFunction(function () {
        return gen.next();
      });
    });
  };
  return datasetFromIteratorFn(makeIterator);
}
98000 /**
98001 * Create an iterator that generates `Tensor`s from webcam video stream. This
98002 * API only works in Browser environment when the device has webcam.
98003 *
98004 * Note: this code snippet only works when the device has a webcam. It will
98005 * request permission to open the webcam when running.
98006 * ```js
98007 * const videoElement = document.createElement('video');
98008 * videoElement.width = 100;
98009 * videoElement.height = 100;
98010 * const cam = await tf.data.webcam(videoElement);
98011 * const img = await cam.capture();
98012 * img.print();
98013 * cam.stop();
98014 * ```
98015 *
98016 * @param webcamVideoElement A `HTMLVideoElement` used to play video from
98017 * webcam. If this element is not provided, a hidden `HTMLVideoElement` will
98018 * be created. In that case, `resizeWidth` and `resizeHeight` must be
98019 * provided to set the generated tensor shape.
98020 * @param webcamConfig A `WebcamConfig` object that contains configurations of
98021 * reading and manipulating data from webcam video stream.
98022 *
98023 * @doc {
98024 * heading: 'Data',
98025 * subheading: 'Creation',
98026 * namespace: 'data',
98027 * ignoreCI: true
98028 * }
98029 */
function webcam(webcamVideoElement, webcamConfig) {
  // Thin public entry point; the async work lives in _webcam.
  return _webcam(webcamVideoElement, webcamConfig);
}
98033 /**
98034 * Create an iterator that generates frequency-domain spectrogram `Tensor`s from
98035 * microphone audio stream with browser's native FFT. This API only works in
98036 * browser environment when the device has microphone.
98037 *
98038 * Note: this code snippet only works when the device has a microphone. It will
98039 * request permission to open the microphone when running.
98040 * ```js
98041 * const mic = await tf.data.microphone({
98042 * fftSize: 1024,
98043 * columnTruncateLength: 232,
98044 * numFramesPerSpectrogram: 43,
98045 * sampleRateHz:44100,
98046 * includeSpectrogram: true,
98047 * includeWaveform: true
98048 * });
98049 * const audioData = await mic.capture();
98050 * const spectrogramTensor = audioData.spectrogram;
98051 * spectrogramTensor.print();
98052 * const waveformTensor = audioData.waveform;
98053 * waveformTensor.print();
98054 * mic.stop();
98055 * ```
98056 *
98057 * @param microphoneConfig A `MicrophoneConfig` object that contains
98058 * configurations of reading audio data from microphone.
98059 *
98060 * @doc {
98061 * heading: 'Data',
98062 * subheading: 'Creation',
98063 * namespace: 'data',
98064 * ignoreCI: true
98065 * }
98066 */
function _webcam(webcamVideoElement, webcamConfig) {
  // The executor runs synchronously, so create() is called immediately and a
  // synchronous throw turns into a rejection, as with an async function.
  return new Promise(function (resolve) {
    resolve(WebcamIterator.create(webcamVideoElement, webcamConfig));
  });
}
function microphone(microphoneConfig) {
  // Thin public entry point; the async work lives in _microphone.
  return _microphone(microphoneConfig);
}
function _microphone(microphoneConfig) {
  // The executor runs synchronously, so create() is called immediately and a
  // synchronous throw turns into a rejection, as with an async function.
  return new Promise(function (resolve) {
    resolve(MicrophoneIterator.create(microphoneConfig));
  });
}
98098
98099 /** @license See the LICENSE file. */
98100 // This code is auto-generated, do not modify this file!
// Version string exported on the data namespace object (`index`) under the
// key `version_data`.
var version$4 = '4.22.0';
98102
98103 /**
98104 * @license
98105 * Copyright 2018 Google LLC. All Rights Reserved.
98106 * Licensed under the Apache License, Version 2.0 (the "License");
98107 * you may not use this file except in compliance with the License.
98108 * You may obtain a copy of the License at
98109 *
98110 * http://www.apache.org/licenses/LICENSE-2.0
98111 *
98112 * Unless required by applicable law or agreed to in writing, software
98113 * distributed under the License is distributed on an "AS IS" BASIS,
98114 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98115 * See the License for the specific language governing permissions and
98116 * limitations under the License.
98117 * =============================================================================
98118 */
98119
// Namespace object for the data API (`tf.data` in the docs above). Like an
// ES module namespace it has a null prototype; keys are added in
// alphabetical order.
var index = Object.create(null);
index.CSVDataset = CSVDataset;
index.Dataset = Dataset;
index.FileDataSource = FileDataSource;
index.TextLineDataset = TextLineDataset;
index.URLDataSource = URLDataSource;
index.array = array;
index.csv = csv;
index.func = func;
index.generator = generator;
index.microphone = microphone;
index.version_data = version$4;
index.webcam = webcam;
index.zip = zip;
98136
98137 /**
98138 * @license
98139 * Copyright 2019 Google LLC. All Rights Reserved.
98140 * Licensed under the Apache License, Version 2.0 (the "License");
98141 * you may not use this file except in compliance with the License.
98142 * You may obtain a copy of the License at
98143 *
98144 * http://www.apache.org/licenses/LICENSE-2.0
98145 *
98146 * Unless required by applicable law or agreed to in writing, software
98147 * distributed under the License is distributed on an "AS IS" BASIS,
98148 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98149 * See the License for the specific language governing permissions and
98150 * limitations under the License.
98151 * =============================================================================
98152 */
/**
 * Rejects complex64 inputs for CPU kernels that have no complex implementation.
 *
 * @param tensor A single tensor info or an array of them; null/undefined
 *     entries are ignored.
 * @param opName Kernel name used in the failure message.
 */
function assertNotComplex$1(tensor, opName) {
  var candidates = Array.isArray(tensor) ? tensor : [tensor];
  for (var k = 0; k < candidates.length; k++) {
    var candidate = candidates[k];
    if (candidate == null) {
      continue;
    }
    assert$1(candidate.dtype !== 'complex64', function () {
      return "".concat(opName, " does not support complex64 tensors in the CPU backend.");
    });
  }
}
98165
// Local alias for the shared `where` implementation; used by
// MathBackendCPU#where below.
var whereImpl$1 = whereImpl$2;
/**
 * CPU kernel backend (Babel-transpiled class extending KernelBackend).
 *
 * Tensor storage lives in `this.data`, a DataStorage keyed by dataId objects.
 * Each entry holds `values`, `dtype` and a `refCount`; complex64 entries
 * additionally carry `complexTensorInfos` ({real, imag}) pointing at two
 * float32 entries that hold the parts.
 */
var MathBackendCPU = /*#__PURE__*/function (_KernelBackend) {
  _inherits(MathBackendCPU, _KernelBackend);
  var _super = _createSuper(MathBackendCPU);
  function MathBackendCPU() {
    var _this;
    _classCallCheck(this, MathBackendCPU);
    _this = _super.call(this);
    // NOTE(review): not read anywhere in this class; presumably consumed by
    // blocked kernels elsewhere in the bundle — confirm before changing.
    _this.blockSize = 48;
    // Cleared by the first write(); lets the Node.js hint be emitted only once.
    _this.firstUse = true;
    _this.data = new DataStorage(_assertThisInitialized(_this), engine());
    return _this;
  }
  _createClass(MathBackendCPU, [{
    key: "nextDataId",
    // Returns the next id from the static counter shared by all instances
    // (MathBackendCPU.nextDataId, initialised to 0 after the class).
    value: function nextDataId() {
      return MathBackendCPU.nextDataId++;
    }
  }, {
    key: "write",
    /**
     * Stores `values` under a fresh dataId with refCount 1 and returns that id.
     * `shape` is accepted for interface compatibility but is not stored.
     * On the very first write under Node.js, warns that tfjs-node is faster.
     */
    value: function write(values, shape, dtype) {
      if (this.firstUse) {
        this.firstUse = false;
        if (env().get('IS_NODE')) {
          warn('\n============================\n' + 'Hi, looks like you are running TensorFlow.js in ' + 'Node.js. To speed things up dramatically, install our node ' + 'backend, visit https://github.com/tensorflow/tfjs-node for more details. ' + '\n============================');
        }
      }
      var dataId = {
        id: this.nextDataId()
      };
      this.data.set(dataId, {
        values: values,
        dtype: dtype,
        refCount: 1
      });
      return dataId;
    }
  }, {
    key: "makeTensorInfo",
    /**
     * Create a data bucket in cpu backend.
     * String tensors whose first value is a JS string are encoded element-wise
     * with encodeString before storage; all other values are stored as given.
     * @param shape Shape of the `TensorInfo`.
     * @param dtype DType of the `TensorInfo`.
     * @param values The value of the `TensorInfo` stored as a flattened array.
     */
    value: function makeTensorInfo(shape, dtype, values) {
      var outId;
      if (dtype === 'string' && values != null && values.length > 0 && isString(values[0])) {
        var encodedValues = values.map(function (d) {
          return encodeString(d);
        });
        outId = this.write(encodedValues, shape, dtype);
      } else {
        outId = this.write(values, shape, dtype);
      }
      return {
        dataId: outId,
        shape: shape,
        dtype: dtype
      };
    }
  }, {
    key: "refCount",
    /** Return refCount of a `TensorData`, or 0 if the dataId is unknown. */
    value: function refCount(dataId) {
      if (this.data.has(dataId)) {
        var tensorData = this.data.get(dataId);
        return tensorData.refCount;
      }
      return 0;
    }
  }, {
    key: "incRef",
    /**
     * Increase refCount of a `TensorData`.
     * Unlike decRef, this does not check `has` first, so an unknown dataId
     * makes `tensorData` undefined and the increment throws.
     */
    value: function incRef(dataId) {
      var tensorData = this.data.get(dataId);
      tensorData.refCount++;
    }
  }, {
    key: "decRef",
    /** Decrease refCount of a `TensorData`; unknown dataIds are ignored. */
    value: function decRef(dataId) {
      if (this.data.has(dataId)) {
        var tensorData = this.data.get(dataId);
        tensorData.refCount--;
      }
    }
  }, {
    key: "move",
    // (Re)registers `values` under an existing dataId with the given refCount;
    // `shape` is not stored.
    value: function move(dataId, values, shape, dtype, refCount) {
      this.data.set(dataId, {
        values: values,
        dtype: dtype,
        refCount: refCount
      });
    }
  }, {
    key: "numDataIds",
    value: function numDataIds() {
      return this.data.numDataIds();
    }
  }, {
    key: "read",
    // Async wrapper (regenerator-transpiled) that resolves to readSync(dataId).
    value: function () {
      var _read = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(dataId) {
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              return _context.abrupt("return", this.readSync(dataId));
            case 1:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function read(_x) {
        return _read.apply(this, arguments);
      }
      return read;
    }()
  }, {
    key: "readSync",
    // Returns the stored values; complex64 tensors are reassembled from their
    // real and imaginary parts via mergeRealAndImagArrays.
    value: function readSync(dataId) {
      var _this$data$get = this.data.get(dataId),
        dtype = _this$data$get.dtype,
        complexTensorInfos = _this$data$get.complexTensorInfos;
      if (dtype === 'complex64') {
        var realValues = this.readSync(complexTensorInfos.real.dataId);
        var imagValues = this.readSync(complexTensorInfos.imag.dataId);
        return mergeRealAndImagArrays(realValues, imagValues);
      }
      return convertBackendValuesAndArrayBuffer(this.data.get(dataId).values, dtype);
    }
  }, {
    key: "bufferSync",
    // Wraps the tensor's values in a buffer; string tensors are decoded from
    // their stored byte form first.
    value: function bufferSync(t) {
      var data = this.readSync(t.dataId);
      if (t.dtype === 'string') {
        try {
          // Decode the bytes into string.
          var strings = data.map(function (d) {
            return decodeString(d);
          });
          return buffer(t.shape, t.dtype, strings);
        } catch (_a) {
          throw new Error('Failed to decode encoded string bytes into utf-8');
        }
      }
      return buffer(t.shape, t.dtype, data);
    }
  }, {
    key: "makeOutput",
    // Registers `values` as a new tensorInfo and wraps it in an engine tensor.
    value: function makeOutput(values, shape, dtype) {
      return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
    }
  }, {
    key: "disposeData",
    /**
     * Dispose the memory if the dataId has 0 refCount. Return true if the memory
     * is released or memory is not managed in this backend, false if memory is
     * not cleared.
     * The refCount is decremented even when `force` is true. Disposing a
     * complex64 entry force-disposes its real and imaginary parts as well.
     * @param dataId
     * @param force Optional, remove the data regardless of refCount
     */
    value: function disposeData(dataId) {
      var force = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
      if (this.data.has(dataId)) {
        this.data.get(dataId).refCount--;
        if (!force && this.data.get(dataId).refCount > 0) {
          return false;
        }
        var _this$data$get2 = this.data.get(dataId),
          complexTensorInfos = _this$data$get2.complexTensorInfos;
        if (complexTensorInfos != null) {
          this.disposeData(complexTensorInfos.real.dataId, true);
          this.disposeData(complexTensorInfos.imag.dataId, true);
        }
        this.data.delete(dataId);
      }
      return true;
    }
  }, {
    key: "disposeIntermediateTensorInfo",
    // Convenience wrapper: non-forced disposeData on a TensorInfo's dataId.
    value: function disposeIntermediateTensorInfo(tensorInfo) {
      this.disposeData(tensorInfo.dataId);
    }
  }, {
    key: "time",
    // Runs `f` synchronously and reports elapsed wall-clock time (via now())
    // as `kernelMs`.
    value: function () {
      var _time = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2(f) {
        var start, kernelMs;
        return _regeneratorRuntime().wrap(function _callee2$(_context2) {
          while (1) switch (_context2.prev = _context2.next) {
            case 0:
              start = now();
              f();
              kernelMs = now() - start;
              return _context2.abrupt("return", {
                kernelMs: kernelMs
              });
            case 4:
            case "end":
              return _context2.stop();
          }
        }, _callee2);
      }));
      function time(_x2) {
        return _time.apply(this, arguments);
      }
      return time;
    }()
  }, {
    key: "memory",
    value: function memory() {
      return {
        // Unreliable due to automatic gc. The numbers above are cumulative.
        unreliable: true,
        reasons: ['The reported memory is an upper bound. Due to automatic garbage ' + 'collection, the true allocated memory may be less.']
      };
    }
  }, {
    key: "where",
    // Delegates to the shared where implementation using the condition's
    // shape and values; complex64 conditions are rejected.
    value: function where(condition) {
      assertNotComplex$1([condition], 'where');
      var condVals = this.readSync(condition.dataId);
      return whereImpl$1(condition.shape, condVals);
    }
  }, {
    key: "dispose",
    // No-op for this backend.
    value: function dispose() {}
  }, {
    key: "floatPrecision",
    value: function floatPrecision() {
      return 32;
    }
  }, {
    key: "epsilon",
    /** Returns the smallest representable number (delegates to KernelBackend). */
    value: function epsilon() {
      return _get(_getPrototypeOf(MathBackendCPU.prototype), "epsilon", this).call(this);
    }
  }]);
  return MathBackendCPU;
}(KernelBackend);
// Static dataId counter shared by all MathBackendCPU instances.
MathBackendCPU.nextDataId = 0;
98411
98412 /**
98413 * @license
98414 * Copyright 2020 Google LLC. All Rights Reserved.
98415 * Licensed under the Apache License, Version 2.0 (the License);
98416 * you may not use this file except in compliance with the License.
98417 * You may obtain a copy of the License at
98418 *
98419 * http://www.apache.org/licenses/LICENSE-2.0
98420 *
98421 * Unless required by applicable law or agreed to in writing, software
98422 * distributed under the License is distributed on an AS IS BASIS,
98423 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98424 * See the License for the specific language governing permissions and
98425 * limitations under the License.
98426 * =============================================================================
98427 */
/**
 * Element-wise absolute value.
 * @param vals Input values (any array-like of numbers).
 * @returns A new Float32Array with `Math.abs` applied to every entry.
 */
function simpleAbsImpl(vals) {
  var count = vals.length;
  var absolute = new Float32Array(count);
  var k = count;
  // Entries are independent, so filling from the back is equivalent.
  while (k--) {
    absolute[k] = Math.abs(vals[k]);
  }
  return absolute;
}
/**
 * CPU kernel for Abs: element-wise absolute value of `x`.
 * @param args Kernel args; reads `args.inputs.x` and uses `args.backend` as the
 *     CPU backend. complex64 input is rejected via assertNotComplex$1.
 * @returns The tensor produced by `backend.makeOutput`, keeping x's shape and
 *     dtype.
 */
var abs$1 = function abs(args) {
  var x = args.inputs.x;
  var cpuBackend = args.backend;
  assertNotComplex$1(x, 'abs');
  var values = cpuBackend.data.get(x.dataId).values;
  // simpleAbsImpl allocates its own output array, so nothing is pre-allocated
  // here (the previous code built a same-sized Float32Array and discarded it).
  var resultValues = simpleAbsImpl(values);
  return cpuBackend.makeOutput(resultValues, x.shape, x.dtype);
};
// Kernel config pairing the Abs kernel name with the CPU `abs$1`
// implementation.
var absConfig$1 = {
  kernelName: Abs,
  backendName: 'cpu',
  kernelFunc: abs$1
};
98449
98450 /**
98451 * @license
98452 * Copyright 2020 Google LLC. All Rights Reserved.
98453 * Licensed under the Apache License, Version 2.0 (the "License");
98454 * you may not use this file except in compliance with the License.
98455 * You may obtain a copy of the License at
98456 *
98457 * http://www.apache.org/licenses/LICENSE-2.0
98458 *
98459 * Unless required by applicable law or agreed to in writing, software
98460 * distributed under the License is distributed on an "AS IS" BASIS,
98461 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98462 * See the License for the specific language governing permissions and
98463 * limitations under the License.
98464 * =============================================================================
98465 */
/**
 * Builds a broadcasting element-wise binary kernel from a scalar `op`.
 *
 * The returned function takes `(aShape, bShape, aVals, bVals, dtype)` and
 * returns `[result, resultShape]`, where `result` is a typed array of `dtype`
 * holding `op(aElem, bElem)` for every position of the broadcast shape.
 */
function createSimpleBinaryKernelImpl(op) {
  // Maps an output location onto a flat index into an operand with `rank`
  // dims: keep the trailing `rank` coordinates and pin broadcast dims to 0.
  function operandIndex(outLoc, rank, strides, broadcastDims) {
    var operandLoc = outLoc.slice(-rank);
    for (var d = 0; d < broadcastDims.length; d++) {
      operandLoc[broadcastDims[d]] = 0;
    }
    return locToIndex(operandLoc, rank, strides);
  }
  return function (aShape, bShape, aVals, bVals, dtype) {
    var outShape = assertAndGetBroadcastShape(aShape, bShape);
    var out = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
    var aDims = getBroadcastDims$1(aShape, outShape);
    var bDims = getBroadcastDims$1(bShape, outShape);
    var total = out.length;
    if (aDims.length === 0 && bDims.length === 0) {
      // No size-1 dims to stretch: walk the output linearly and wrap each
      // operand, which tiles a lower-rank input across the leading dims.
      for (var pos = 0; pos < total; pos++) {
        out[pos] = op(aVals[pos % aVals.length], bVals[pos % bVals.length]);
      }
    } else {
      var outRank = outShape.length;
      var outStrides = computeStrides(outShape);
      var aRank = aShape.length;
      var bRank = bShape.length;
      var aStrides = computeStrides(aShape);
      var bStrides = computeStrides(bShape);
      for (var flat = 0; flat < total; flat++) {
        var outLoc = indexToLoc(flat, outRank, outStrides);
        var aIndex = operandIndex(outLoc, aRank, aStrides, aDims);
        var bIndex = operandIndex(outLoc, bRank, bStrides, bDims);
        out[flat] = op(aVals[aIndex], bVals[bIndex]);
      }
    }
    return [out, outShape];
  };
}
98508
98509 /**
98510 * @license
98511 * Copyright 2020 Google LLC. All Rights Reserved.
98512 * Licensed under the Apache License, Version 2.0 (the "License");
98513 * you may not use this file except in compliance with the License.
98514 * You may obtain a copy of the License at
98515 *
98516 * http://www.apache.org/licenses/LICENSE-2.0
98517 *
98518 * Unless required by applicable law or agreed to in writing, software
98519 * distributed under the License is distributed on an "AS IS" BASIS,
98520 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98521 * See the License for the specific language governing permissions and
98522 * limitations under the License.
98523 * =============================================================================
98524 */
/**
 * CPU Complex kernel: assembles a complex64 tensorInfo from `inputs.real` and
 * `inputs.imag`.
 *
 * A fresh complex64 entry is registered first (so it receives the earliest
 * dataId), followed by new float32 entries for the real and imaginary parts
 * that reuse the input value arrays. The complex entry's
 * `complexTensorInfos` points at those parts.
 */
function complex$1(args) {
  var backend = args.backend;
  var realInput = args.inputs.real;
  var imagInput = args.inputs.imag;
  var valuesOf = function (info) {
    return backend.data.get(info.dataId).values;
  };
  var realVals = valuesOf(realInput);
  var imagVals = valuesOf(imagInput);
  var complexInfo = backend.makeTensorInfo(realInput.shape, 'complex64');
  // The complex tensor owns the underlying real and imag tensorInfos, only the
  // complex tensor tracks refCount, when complexData is disposed the
  // underlying tensorData will be disposed.
  var realPart = backend.makeTensorInfo(realInput.shape, 'float32', realVals);
  var imagPart = backend.makeTensorInfo(imagInput.shape, 'float32', imagVals);
  backend.data.get(complexInfo.dataId).complexTensorInfos = {
    real: realPart,
    imag: imagPart
  };
  return complexInfo;
}
// Kernel config pairing the Complex kernel name with the CPU `complex$1`
// implementation.
var complexConfig$1 = {
  kernelName: Complex,
  backendName: 'cpu',
  kernelFunc: complex$1
};
98548
98549 /**
98550 * @license
98551 * Copyright 2020 Google LLC. All Rights Reserved.
98552 * Licensed under the Apache License, Version 2.0 (the "License");
98553 * you may not use this file except in compliance with the License.
98554 * You may obtain a copy of the License at
98555 *
98556 * http://www.apache.org/licenses/LICENSE-2.0
98557 *
98558 * Unless required by applicable law or agreed to in writing, software
98559 * distributed under the License is distributed on an "AS IS" BASIS,
98560 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98561 * See the License for the specific language governing permissions and
98562 * limitations under the License.
98563 * =============================================================================
98564 */
/**
 * Generates a tensorInfo with all zeros value.
 * @param backend cpu backend.
 * @param shape Shape for the zeros tensor.
 * @param dtype Optional, defaults to 'float32'. If set, the result has this
 *     dtype. For 'complex64', both the real and imaginary parts are zeros.
 * @returns A newly registered TensorInfo; the caller owns (and must dispose) it.
 */
function zeros(backend, shape) {
  var dtype = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'float32';
  if (dtype === 'complex64') {
    var real = zeros(backend, shape, 'float32');
    var imag = zeros(backend, shape, 'float32');
    var result = complex$1({
      inputs: {
        real: real,
        imag: imag
      },
      backend: backend
    });
    // complex$1 registers its own real/imag tensorInfos for these values, so
    // the temporary parts created above are released here instead of leaking
    // (cast$1 disposes its temporaries the same way).
    backend.disposeIntermediateTensorInfo(real);
    backend.disposeIntermediateTensorInfo(imag);
    return result;
  }
  var values = makeZerosTypedArray(sizeFromShape(shape), dtype);
  return backend.makeTensorInfo(shape, dtype, values);
}
98587
98588 /**
98589 * @license
98590 * Copyright 2020 Google LLC. All Rights Reserved.
98591 * Licensed under the Apache License, Version 2.0 (the "License");
98592 * you may not use this file except in compliance with the License.
98593 * You may obtain a copy of the License at
98594 *
98595 * http://www.apache.org/licenses/LICENSE-2.0
98596 *
98597 * Unless required by applicable law or agreed to in writing, software
98598 * distributed under the License is distributed on an "AS IS" BASIS,
98599 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98600 * See the License for the specific language governing permissions and
98601 * limitations under the License.
98602 * =============================================================================
98603 */
/**
 * CPU Identity kernel: returns a tensorInfo that shares `inputs.x`'s dataId.
 * The shared data's refCount is bumped so that each holder can dispose
 * independently.
 */
function identity$1(args) {
  var x = args.inputs.x;
  args.backend.incRef(x.dataId);
  return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
// Kernel config pairing the Identity kernel name with the CPU `identity$1`
// implementation.
var identityConfig$1 = {
  kernelName: Identity$1,
  backendName: 'cpu',
  kernelFunc: identity$1
};
98620
98621 /**
98622 * @license
98623 * Copyright 2020 Google LLC. All Rights Reserved.
98624 * Licensed under the Apache License, Version 2.0 (the "License");
98625 * you may not use this file except in compliance with the License.
98626 * You may obtain a copy of the License at
98627 *
98628 * http://www.apache.org/licenses/LICENSE-2.0
98629 *
98630 * Unless required by applicable law or agreed to in writing, software
98631 * distributed under the License is distributed on an "AS IS" BASIS,
98632 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98633 * See the License for the specific language governing permissions and
98634 * limitations under the License.
98635 * =============================================================================
98636 */
/**
 * CPU Real kernel: extracts the real part of a complex tensor as a new
 * tensorInfo that has the stored real part's shape and dtype.
 */
function real$1(args) {
  var backend = args.backend;
  var complexEntry = backend.data.get(args.inputs.input.dataId);
  var realInfo = complexEntry.complexTensorInfos.real;
  var realValues = backend.data.get(realInfo.dataId).values;
  // When complex tensor is disposed, its underlying parts will be disposed too.
  // Make new tensor out of the real value of the complex. This makes sure the
  // value is still accessible even if complex tensor is disposed.
  return backend.makeTensorInfo(realInfo.shape, realInfo.dtype, realValues);
}
// Kernel config pairing the Real kernel name with the CPU `real$1`
// implementation.
var realConfig$1 = {
  kernelName: Real,
  backendName: 'cpu',
  kernelFunc: real$1
};
98653
/**
 * Value-level cast used by the Cast kernel for lossy conversions.
 *
 * - 'int32': `Int32Array.from(values)` (JS ToInt32 conversion).
 * - 'bool':  1 where the value differs from zero of `inputType`, else 0.
 *
 * @returns `[shape, dtype, values]` for the converted data.
 * @throws Error for any other target dtype.
 */
function castImpl(values, shape, inputType, dtype) {
  switch (dtype) {
    case 'int32':
      return [shape, 'int32', Int32Array.from(values)];
    case 'bool': {
      // This is essentially the result of notEqual(x, 0). We avoid using
      // kernel notEqual to avoid circular dependency, i.e. binary_utils ->
      // cast -> notEqual -> binary_utils.
      var zeroVals = toTypedArray([0], inputType);
      var notEqualImpl = createSimpleBinaryKernelImpl(function (a, b) {
        return a !== b ? 1 : 0;
      });
      var outcome = notEqualImpl(shape, [], values, zeroVals, 'bool');
      return [outcome[1], 'bool', outcome[0]];
    }
    default:
      throw new Error("Error in Cast: failed to cast ".concat(inputType, " to ").concat(dtype));
  }
}
/**
 * CPU Cast kernel: converts `inputs.x` to `attrs.dtype`.
 *
 * - To complex64: if x is already complex64 it is returned via identity$1
 *   (shared data, refCount bumped). Otherwise a complex tensor is built with
 *   real = x cast to float32 and imag = zeros, and both temporaries are then
 *   disposed (complex$1 registers its own tensorInfos for the parts).
 * - From complex64: only the real part is cast to the target dtype.
 * - When hasEncodingLoss(x.dtype, dtype) is false: x's data is shared via
 *   identity$1 and only the dtype label changes.
 * - Otherwise: values are converted with castImpl.
 *
 * @returns A TensorInfo owned by the caller.
 */
function cast$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x;
  var dtype = attrs.dtype;
  // Casting to complex64.
  if (dtype === 'complex64') {
    if (x.dtype === 'complex64') {
      return identity$1({
        inputs: {
          x: x
        },
        backend: backend
      });
    }
    // NOTE(review): the zeros are created with x.dtype rather than 'float32',
    // while complex$1 registers the imaginary part as 'float32'; confirm that
    // this is intended.
    var zerosTensorInfo = zeros(backend, x.shape, x.dtype);
    var floatX = cast$1({
      inputs: {
        x: x
      },
      backend: backend,
      attrs: {
        dtype: 'float32'
      }
    });
    var result = complex$1({
      inputs: {
        real: floatX,
        imag: zerosTensorInfo
      },
      backend: backend
    });
    // The temporaries are no longer needed once complex$1 has registered its
    // own part tensorInfos.
    backend.disposeIntermediateTensorInfo(zerosTensorInfo);
    backend.disposeIntermediateTensorInfo(floatX);
    return result;
  }
  // Casting from complex64
  if (x.dtype === 'complex64') {
    var realPart = real$1({
      inputs: {
        input: x
      },
      backend: backend
    });
    var _result = cast$1({
      inputs: {
        x: realPart
      },
      backend: backend,
      attrs: {
        dtype: dtype
      }
    });
    backend.disposeIntermediateTensorInfo(realPart);
    return _result;
  }
  if (!hasEncodingLoss(x.dtype, dtype)) {
    // We don't change the underlying data, since we cast to higher
    // precision.
    var _result2 = identity$1({
      inputs: {
        x: x
      },
      backend: backend
    });
    return {
      dataId: _result2.dataId,
      shape: _result2.shape,
      dtype: dtype
    };
  }
  var values = backend.data.get(x.dataId).values;
  var _castImpl = castImpl(values, x.shape, x.dtype, dtype),
    _castImpl2 = _slicedToArray(_castImpl, 3),
    resultShape = _castImpl2[0],
    resultType = _castImpl2[1],
    resultData = _castImpl2[2];
  return backend.makeTensorInfo(resultShape, resultType, resultData);
}
// Kernel config pairing the Cast kernel name with the CPU `cast$1`
// implementation.
var castConfig$1 = {
  kernelName: Cast,
  backendName: 'cpu',
  kernelFunc: cast$1
};
98759
/**
 * Template that creates a `KernelFunc` for binary ops.
 * @param name Kernel name (used in the complex64 rejection message).
 * @param simpleImpl A `SimpleBinaryKernelImpl` for the kernel, called as
 *     `(aShape, bShape, aVals, bVals, dtype)` and returning
 *     `[resultData, resultShape]`.
 * @param complexImpl Optional. If it exists, it is a `ComplexBinaryKernelImpl`
 *     that is used when either input dtype is `complex64`. Without it,
 *     complex64 inputs are rejected.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the first input. This is mainly used in
 *     comparison kernels, such as Equal, Less, Greater, etc.
 * @returns A kernel function taking `{inputs: {a, b}, backend}`.
 */
function binaryKernelFunc$1(name, simpleImpl, complexImpl, dtype) {
  if (complexImpl == null) {
    return function (_ref) {
      var inputs = _ref.inputs,
        backend = _ref.backend;
      var a = inputs.a,
        b = inputs.b;
      var cpuBackend = backend;
      assertNotComplex$1([a, b], name);
      var aVals = cpuBackend.data.get(a.dataId).values;
      var bVals = cpuBackend.data.get(b.dataId).values;
      // String tensors are stored as encoded bytes. Each operand is decoded
      // based on its OWN dtype; previously `b` was decoded whenever `a` was a
      // string, which is wrong if the two inputs' dtypes ever differ.
      var decodedAVals = a.dtype === 'string' ? fromUint8ToStringArray(aVals) : aVals;
      var decodedBVals = b.dtype === 'string' ? fromUint8ToStringArray(bVals) : bVals;
      var $dtype = dtype || a.dtype;
      var _simpleImpl = simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype),
        _simpleImpl2 = _slicedToArray(_simpleImpl, 2),
        resultData = _simpleImpl2[0],
        resultShape = _simpleImpl2[1];
      return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
    };
  }
  return function (_ref2) {
    var inputs = _ref2.inputs,
      backend = _ref2.backend;
    var a = inputs.a,
      b = inputs.b;
    var cpuBackend = backend;
    if (a.dtype === 'complex64' || b.dtype === 'complex64') {
      // Promote both operands to complex64 so the complex implementation sees
      // real/imag arrays for each of them.
      var $aComplex = cast$1({
        inputs: {
          x: a
        },
        backend: cpuBackend,
        attrs: {
          dtype: 'complex64'
        }
      });
      var $aComplexVals = cpuBackend.data.get($aComplex.dataId);
      var aReal = $aComplexVals.complexTensorInfos.real;
      var aImag = $aComplexVals.complexTensorInfos.imag;
      var aRealVals = cpuBackend.data.get(aReal.dataId).values;
      var aImagVals = cpuBackend.data.get(aImag.dataId).values;
      var $bComplex = cast$1({
        inputs: {
          x: b
        },
        backend: cpuBackend,
        attrs: {
          dtype: 'complex64'
        }
      });
      var $bComplexVals = cpuBackend.data.get($bComplex.dataId);
      var bReal = $bComplexVals.complexTensorInfos.real;
      var bImag = $bComplexVals.complexTensorInfos.imag;
      var bRealVals = cpuBackend.data.get(bReal.dataId).values;
      var bImagVals = cpuBackend.data.get(bImag.dataId).values;
      var _complexImpl = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals),
        _complexImpl2 = _slicedToArray(_complexImpl, 3),
        resultRealData = _complexImpl2[0],
        resultImagData = _complexImpl2[1],
        resultShape = _complexImpl2[2];
      var resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
      var resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
      var result = complex$1({
        inputs: {
          real: resultReal,
          imag: resultImag
        },
        backend: cpuBackend
      });
      // Release the promoted operands and the standalone result parts;
      // complex$1 registered its own tensorInfos for the result's parts.
      cpuBackend.disposeIntermediateTensorInfo($aComplex);
      cpuBackend.disposeIntermediateTensorInfo($bComplex);
      cpuBackend.disposeIntermediateTensorInfo(resultReal);
      cpuBackend.disposeIntermediateTensorInfo(resultImag);
      return result;
    } else {
      var aVals = cpuBackend.data.get(a.dataId).values;
      var bVals = cpuBackend.data.get(b.dataId).values;
      var $dtype = dtype || a.dtype;
      var _simpleImpl3 = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype),
        _simpleImpl4 = _slicedToArray(_simpleImpl3, 2),
        resultData = _simpleImpl4[0],
        _resultShape = _simpleImpl4[1];
      return cpuBackend.makeTensorInfo(_resultShape, $dtype, resultData);
    }
  };
}
/**
 * Template that creates the complex type implementation for binary ops.
 * Supports broadcast.
 *
 * @param op Called as `op(aReal, aImag, bReal, bImag)` for each output
 *     element; must return `{real, imag}`.
 * @returns A function `(aShape, bShape, aRealVals, aImagVals, bRealVals,
 *     bImagVals)` that returns `[resultRealVals, resultImagVals, resultShape]`.
 *     Both value arrays are float32 typed arrays of the broadcast size.
 */
function createComplexBinaryKernelImpl(op) {
  return function (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) {
    var resultShape = assertAndGetBroadcastShape(aShape, bShape);
    var resultSize = sizeFromShape(resultShape);
    var resultRank = resultShape.length;
    var resultStrides = computeStrides(resultShape);
    var resultRealVals = getTypedArrayFromDType('float32', resultSize);
    var resultImagVals = getTypedArrayFromDType('float32', resultSize);
    var aBroadcastDims = getBroadcastDims$1(aShape, resultShape);
    var bBroadcastDims = getBroadcastDims$1(bShape, resultShape);
    // Merged operand arrays; element k lives at [2k] (real) and [2k + 1]
    // (imag), as the indexing below relies on.
    var aVals = mergeRealAndImagArrays(aRealVals, aImagVals);
    var bVals = mergeRealAndImagArrays(bRealVals, bImagVals);
    var aRank = aShape.length;
    var aStrides = computeStrides(aShape);
    var bRank = bShape.length;
    var bStrides = computeStrides(bShape);
    if (aBroadcastDims.length + bBroadcastDims.length === 0) {
      // No size-1 dims to stretch, but a lower-rank operand may still need to
      // be tiled (e.g. [3] with [2, 3]). Wrap on the number of complex
      // ELEMENTS, which is half the merged array's length. Wrapping on the
      // merged length read past the end and produced NaN.
      var aNumElements = aVals.length / 2;
      var bNumElements = bVals.length / 2;
      for (var i = 0; i < resultRealVals.length; i++) {
        var aIdx = i % aNumElements;
        var bIdx = i % bNumElements;
        var result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
        resultRealVals[i] = result.real;
        resultImagVals[i] = result.imag;
      }
    } else {
      for (var flat = 0; flat < resultRealVals.length; flat++) {
        var loc = indexToLoc(flat, resultRank, resultStrides);
        // Project the output location onto each operand, pinning its
        // broadcast (size-1) dimensions to 0.
        var aLoc = loc.slice(-aRank);
        for (var da = 0; da < aBroadcastDims.length; da++) {
          aLoc[aBroadcastDims[da]] = 0;
        }
        var aIndex = locToIndex(aLoc, aRank, aStrides);
        var bLoc = loc.slice(-bRank);
        for (var db = 0; db < bBroadcastDims.length; db++) {
          bLoc[bBroadcastDims[db]] = 0;
        }
        var bIndex = locToIndex(bLoc, bRank, bStrides);
        var opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
        resultRealVals[flat] = opResult.real;
        resultImagVals[flat] = opResult.imag;
      }
    }
    return [resultRealVals, resultImagVals, resultShape];
  };
}
98914
98915 /**
98916 * @license
98917 * Copyright 2020 Google LLC. All Rights Reserved.
98918 * Licensed under the Apache License, Version 2.0 (the "License");
98919 * you may not use this file except in compliance with the License.
98920 * You may obtain a copy of the License at
98921 *
98922 * http://www.apache.org/licenses/LICENSE-2.0
98923 *
98924 * Unless required by applicable law or agreed to in writing, software
98925 * distributed under the License is distributed on an "AS IS" BASIS,
98926 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98927 * See the License for the specific language governing permissions and
98928 * limitations under the License.
98929 * =============================================================================
98930 */
// Element-wise a + b for real-valued inputs; broadcasting is handled by
// createSimpleBinaryKernelImpl.
var addImpl = createSimpleBinaryKernelImpl(function (a, b) {
  return a + b;
});
// Complex addition: real and imaginary parts are summed independently.
var addComplexImpl = createComplexBinaryKernelImpl(function (aReal, aImag, bReal, bImag) {
  return {
    real: aReal + bReal,
    imag: aImag + bImag
  };
});
// Add kernel: binaryKernelFunc$1 dispatches to addComplexImpl when either
// input is complex64 and to addImpl otherwise.
var add = binaryKernelFunc$1(Add$1, addImpl, addComplexImpl);
// Kernel config pairing the Add kernel name with the CPU `add` implementation.
var addConfig$1 = {
  kernelName: Add$1,
  backendName: 'cpu',
  kernelFunc: add
};
98946
98947 /**
98948 * @license
98949 * Copyright 2020 Google LLC. All Rights Reserved.
98950 * Licensed under the Apache License, Version 2.0 (the "License");
98951 * you may not use this file except in compliance with the License.
98952 * You may obtain a copy of the License at
98953 *
98954 * http://www.apache.org/licenses/LICENSE-2.0
98955 *
98956 * Unless required by applicable law or agreed to in writing, software
98957 * distributed under the License is distributed on an "AS IS" BASIS,
98958 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98959 * See the License for the specific language governing permissions and
98960 * limitations under the License.
98961 * =============================================================================
98962 */
/**
 * Counts occurrences of each value of `xVals` into `size` bins.
 *
 * If `weightsShape` describes a non-empty tensor, bin `v` accumulates
 * `weightsVals[i]` for every position `i` where `xVals[i]` is `v`; otherwise
 * each occurrence adds 1. Values `>= size` are skipped.
 *
 * @throws Error('Input x must be non-negative!') on a negative value.
 * @returns The array from makeZerosTypedArray(size, weightsDtype), holding the
 *     per-bin totals.
 */
function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
  var useWeights = sizeFromShape(weightsShape) > 0;
  var bins = makeZerosTypedArray(size, weightsDtype);
  for (var pos = 0; pos < xVals.length; pos++) {
    var bin = xVals[pos];
    if (bin < 0) {
      throw new Error('Input x must be non-negative!');
    }
    if (bin >= size) {
      continue;
    }
    bins[bin] += useWeights ? weightsVals[pos] : 1;
  }
  return bins;
}
/**
 * Row-wise bincount over a 2-D buffer.
 *
 * For each row `r` of `xBuf`, value `v` (if `v < size`) contributes to
 * `out[r][v]`:
 *   - `binaryOutput` true: the cell is set to 1;
 *   - otherwise: `weightsBuf.get(r, c)` is added when `weightsBuf.size > 0`,
 *     else 1 is added.
 *
 * @param binaryOutput Optional, defaults to false.
 * @throws Error('Input x must be non-negative!') on a negative value.
 * @returns A `[rows, size]` buffer whose dtype is `weightsBuf.dtype`.
 */
function bincountReduceImpl(xBuf, weightsBuf, size) {
  var binaryOutput = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  var rows = xBuf.shape[0];
  var cols = xBuf.shape[1];
  var counts = buffer([rows, size], weightsBuf.dtype);
  for (var r = 0; r < rows; r++) {
    for (var c = 0; c < cols; c++) {
      var bin = xBuf.get(r, c);
      if (bin < 0) {
        throw new Error('Input x must be non-negative!');
      }
      if (bin >= size) {
        continue;
      }
      if (binaryOutput) {
        counts.set(1, r, bin);
        continue;
      }
      var increment = weightsBuf.size > 0 ? weightsBuf.get(r, c) : 1;
      counts.set(counts.get(r, bin) + increment, r, bin);
    }
  }
  return counts;
}
99009
99010 /**
99011 * @license
99012 * Copyright 2023 Google LLC.
99013 * Licensed under the Apache License, Version 2.0 (the "License");
99014 * you may not use this file except in compliance with the License.
99015 * You may obtain a copy of the License at
99016 *
99017 * http://www.apache.org/licenses/LICENSE-2.0
99018 *
99019 * Unless required by applicable law or agreed to in writing, software
99020 * distributed under the License is distributed on an "AS IS" BASIS,
99021 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99022 * See the License for the specific language governing permissions and
99023 * limitations under the License.
99024 * =============================================================================
99025 */
// Element-wise bitwise AND with broadcasting. JS `&` converts both operands to
// 32-bit integers before combining them.
var bitwiseAndImpl = createSimpleBinaryKernelImpl(function (a, b) {
  return a & b;
});
// No complex implementation is passed, so binaryKernelFunc$1 rejects
// complex64 inputs via assertNotComplex$1.
var bitwiseAnd$1 = binaryKernelFunc$1(BitwiseAnd, bitwiseAndImpl);
// Kernel config pairing the BitwiseAnd kernel name with the CPU
// `bitwiseAnd$1` implementation.
var bitwiseAndConfig$1 = {
  kernelName: BitwiseAnd,
  backendName: 'cpu',
  kernelFunc: bitwiseAnd$1
};
99035
99036 /**
99037 * @license
99038 * Copyright 2020 Google LLC. All Rights Reserved.
99039 * Licensed under the Apache License, Version 2.0 (the "License");
99040 * you may not use this file except in compliance with the License.
99041 * You may obtain a copy of the License at
99042 *
99043 * http://www.apache.org/licenses/LICENSE-2.0
99044 *
99045 * Unless required by applicable law or agreed to in writing, software
99046 * distributed under the License is distributed on an "AS IS" BASIS,
99047 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99048 * See the License for the specific language governing permissions and
99049 * limitations under the License.
99050 * =============================================================================
99051 */
/**
 * Builds an element-wise unary implementation from a scalar operation.
 *
 * @param op Called as `op(value, attrs)` for every input element.
 * @returns A function `(values, dtype, attrs)` that allocates an output of
 *     `values.length` elements with `getArrayFromDType(dtype, ...)`, fills it
 *     with the results of `op`, and returns it.
 */
function createSimpleUnaryImpl(op) {
  return function (values, dtype, attrs) {
    var count = values.length;
    var result = getArrayFromDType(dtype, count);
    var idx = 0;
    while (idx < count) {
      result[idx] = op(values[idx], attrs);
      idx += 1;
    }
    return result;
  };
}
99064
99065 /**
99066 * @license
99067 * Copyright 2020 Google LLC. All Rights Reserved.
99068 * Licensed under the Apache License, Version 2.0 (the "License");
99069 * you may not use this file except in compliance with the License.
99070 * You may obtain a copy of the License at
99071 *
99072 * http://www.apache.org/licenses/LICENSE-2.0
99073 *
99074 * Unless required by applicable law or agreed to in writing, software
99075 * distributed under the License is distributed on an "AS IS" BASIS,
99076 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99077 * See the License for the specific language governing permissions and
99078 * limitations under the License.
99079 * =============================================================================
99080 */
/**
 * Builds a CPU `KernelFunc` for an element-wise unary op.
 *
 * The scalar `op` is turned into an element-wise implementation with
 * `createSimpleUnaryImpl`, and that implementation is handed to
 * `unaryKernelFuncFromImpl`.
 *
 * @param name Kernel name.
 * @param op Scalar operation, called as `op(value, attrs)`.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the input. This is mainly used in certain
 *     kernels that return bool type, such as isFinite, isInf, etc.
 */
function unaryKernelFunc$1(name, op, dtype) {
  return unaryKernelFuncFromImpl(name, createSimpleUnaryImpl(op), dtype);
}
/**
 * Builds a CPU `KernelFunc` for a unary op from an existing element-wise
 * implementation.
 *
 * The returned kernel does the following:
 * - rejects complex inputs via `assertNotComplex$1`;
 * - decodes string tensors with `fromUint8ToStringArray`;
 * - runs `unaryImpl`;
 * - wraps the result with `makeTensorInfo`, using the input's shape.
 *
 * @param name Kernel name; passed to `assertNotComplex$1`.
 * @param unaryImpl Called as `unaryImpl(values, dtype, attrs)`; returns the
 *     output values.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the input. This is mainly used in certain
 *     kernels that return bool type, such as isFinite, isInf, etc.
 */
function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
  return function (kernelArgs) {
    var x = kernelArgs.inputs.x;
    var cpuBackend = kernelArgs.backend;
    assertNotComplex$1(x, name);
    var values = cpuBackend.data.get(x.dataId).values;
    var outDtype = dtype || x.dtype;
    var decoded = values;
    if (x.dtype === 'string') {
      // String tensor values must be a plain Array (presumably of encoded
      // byte arrays) before they can be decoded.
      if (!Array.isArray(values)) {
        throw new Error('String tensor\'s value was not an instance of Array');
      }
      decoded = fromUint8ToStringArray(values);
    }
    var newValues = unaryImpl(decoded, outDtype, kernelArgs.attrs);
    return cpuBackend.makeTensorInfo(x.shape, outDtype, newValues);
  };
}
99125
99126 /**
99127 * @license
99128 * Copyright 2020 Google LLC. All Rights Reserved.
99129 * Licensed under the Apache License, Version 2.0 (the License);
99130 * you may not use this file except in compliance with the License.
99131 * You may obtain a copy of the License at
99132 *
99133 * http://www.apache.org/licenses/LICENSE-2.0
99134 *
99135 * Unless required by applicable law or agreed to in writing, software
99136 * distributed under the License is distributed on an AS IS BASIS,
99137 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99138 * See the License for the specific language governing permissions and
99139 * limitations under the License.
99140 * =============================================================================
99141 */
/** Element-wise ceiling (round toward +Infinity) kernel for the CPU backend. */
var ceilImpl = createSimpleUnaryImpl(function ceilElement(value) {
  return Math.ceil(value);
});
var ceil$1 = unaryKernelFuncFromImpl(Ceil, ceilImpl);
var ceilConfig$1 = {
  kernelName: Ceil,
  backendName: 'cpu',
  kernelFunc: ceil$1
};
99151
99152 /**
99153 * @license
99154 * Copyright 2020 Google LLC. All Rights Reserved.
99155 * Licensed under the Apache License, Version 2.0 (the "License");
99156 * you may not use this file except in compliance with the License.
99157 * You may obtain a copy of the License at
99158 *
99159 * http://www.apache.org/licenses/LICENSE-2.0
99160 *
99161 * Unless required by applicable law or agreed to in writing, software
99162 * distributed under the License is distributed on an "AS IS" BASIS,
99163 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99164 * See the License for the specific language governing permissions and
99165 * limitations under the License.
99166 * =============================================================================
99167 */
/**
 * Concatenates the values of several tensors into one output array.
 *
 * Fast path: when `simplyConcat` is set and the dtype is not 'string', the
 * inputs are copied back-to-back with `TypedArray.prototype.set`.
 *
 * Otherwise each input is read as a `[rows, cols]` block, and its columns are
 * placed side by side in an output whose row length is `outShape[1]`. String
 * inputs are decoded with `fromUint8ToStringArray` first.
 * NOTE(review): this path assumes callers already reshaped every input to
 * 2-D with matching row counts; confirm at the call site.
 *
 * @param inputs Array of `{vals, shape}` records.
 * @param outShape Output shape; only `outShape[1]` is read on the 2-D path.
 * @param dtype Output dtype, passed to `getArrayFromDType`.
 * @param simplyConcat Whether the flat back-to-back copy may be used.
 * @returns The concatenated values.
 */
function concatImpl$1(inputs, outShape, dtype, simplyConcat) {
  var outVals = getArrayFromDType(dtype, sizeFromShape(outShape));
  var isString = dtype === 'string';
  if (simplyConcat && !isString) {
    // Use built-in TypedArray.set() method for speed.
    var writePos = 0;
    for (var n = 0; n < inputs.length; n++) {
      var flatSize = sizeFromShape(inputs[n].shape);
      outVals.set(inputs[n].vals, writePos);
      writePos += flatSize;
    }
    return outVals;
  }
  var colStart = 0;
  for (var k = 0; k < inputs.length; k++) {
    var input = inputs[k];
    var source = isString ? fromUint8ToStringArray(input.vals) : input.vals;
    var rows = input.shape[0];
    var cols = input.shape[1];
    var readPos = 0;
    for (var r = 0; r < rows; ++r) {
      var rowBase = r * outShape[1] + colStart;
      for (var c = 0; c < cols; ++c) {
        outVals[rowBase + c] = source[readPos++];
      }
    }
    colStart += cols;
  }
  return outVals;
}
99194
99195 /**
99196 * @license
99197 * Copyright 2020 Google LLC. All Rights Reserved.
99198 * Licensed under the Apache License, Version 2.0 (the "License");
99199 * you may not use this file except in compliance with the License.
99200 * You may obtain a copy of the License at
99201 *
99202 * http://www.apache.org/licenses/LICENSE-2.0
99203 *
99204 * Unless required by applicable law or agreed to in writing, software
99205 * distributed under the License is distributed on an "AS IS" BASIS,
99206 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99207 * See the License for the specific language governing permissions and
99208 * limitations under the License.
99209 * =============================================================================
99210 */
/** Element-wise equality kernel for the CPU backend; emits 1/0 as 'bool'. */
var equalImpl = createSimpleBinaryKernelImpl(function isEqual(lhs, rhs) {
  // Strict comparison, so NaN never equals NaN.
  return lhs === rhs ? 1 : 0;
});
var equal$1 = binaryKernelFunc$1(Equal, equalImpl, null /* complexImpl */, 'bool');
var equalConfig$1 = {
  kernelName: Equal,
  backendName: 'cpu',
  kernelFunc: equal$1
};
99220
99221 /**
99222 * @license
99223 * Copyright 2020 Google LLC. All Rights Reserved.
99224 * Licensed under the Apache License, Version 2.0 (the License);
99225 * you may not use this file except in compliance with the License.
99226 * You may obtain a copy of the License at
99227 *
99228 * http://www.apache.org/licenses/LICENSE-2.0
99229 *
99230 * Unless required by applicable law or agreed to in writing, software
99231 * distributed under the License is distributed on an AS IS BASIS,
99232 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99233 * See the License for the specific language governing permissions and
99234 * limitations under the License.
99235 * =============================================================================
99236 */
/** Element-wise natural exponential (e^x) kernel for the CPU backend. */
var expImpl = createSimpleUnaryImpl(function expElement(value) {
  return Math.exp(value);
});
// The explicit 'float32' dtype makes the output float32 even for integer input.
var exp$1 = unaryKernelFuncFromImpl(Exp, expImpl, 'float32');
var expConfig$1 = {
  kernelName: Exp,
  backendName: 'cpu',
  kernelFunc: exp$1
};
99246
99247 /**
99248 * @license
99249 * Copyright 2020 Google LLC. All Rights Reserved.
99250 * Licensed under the Apache License, Version 2.0 (the License);
99251 * you may not use this file except in compliance with the License.
99252 * You may obtain a copy of the License at
99253 *
99254 * http://www.apache.org/licenses/LICENSE-2.0
99255 *
99256 * Unless required by applicable law or agreed to in writing, software
99257 * distributed under the License is distributed on an AS IS BASIS,
99258 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99259 * See the License for the specific language governing permissions and
99260 * limitations under the License.
99261 * =============================================================================
99262 */
/** Element-wise e^x - 1 kernel for the CPU backend (via Math.expm1). */
var expm1Impl = createSimpleUnaryImpl(function expm1Element(value) {
  return Math.expm1(value);
});
var expm1$1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
var expm1Config$1 = {
  kernelName: Expm1,
  backendName: 'cpu',
  kernelFunc: expm1$1
};
99272
99273 /**
99274 * @license
99275 * Copyright 2020 Google LLC. All Rights Reserved.
99276 * Licensed under the Apache License, Version 2.0 (the License);
99277 * you may not use this file except in compliance with the License.
99278 * You may obtain a copy of the License at
99279 *
99280 * http://www.apache.org/licenses/LICENSE-2.0
99281 *
99282 * Unless required by applicable law or agreed to in writing, software
99283 * distributed under the License is distributed on an AS IS BASIS,
99284 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99285 * See the License for the specific language governing permissions and
99286 * limitations under the License.
99287 * =============================================================================
99288 */
/** Element-wise floor (round toward -Infinity) kernel for the CPU backend. */
var floorImpl = createSimpleUnaryImpl(function floorElement(value) {
  return Math.floor(value);
});
var floor$1 = unaryKernelFuncFromImpl(Floor, floorImpl);
var floorConfig$1 = {
  kernelName: Floor,
  backendName: 'cpu',
  kernelFunc: floor$1
};
99298
99299 /**
99300 * @license
99301 * Copyright 2020 Google LLC. All Rights Reserved.
99302 * Licensed under the Apache License, Version 2.0 (the "License");
99303 * you may not use this file except in compliance with the License.
99304 * You may obtain a copy of the License at
99305 *
99306 * http://www.apache.org/licenses/LICENSE-2.0
99307 *
99308 * Unless required by applicable law or agreed to in writing, software
99309 * distributed under the License is distributed on an "AS IS" BASIS,
99310 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99311 * See the License for the specific language governing permissions and
99312 * limitations under the License.
99313 * =============================================================================
99314 */
/** Element-wise floor division kernel for the CPU backend; output is 'int32'. */
var floorDivImpl = createSimpleBinaryKernelImpl(function floorDivide(dividend, divisor) {
  // Round the true quotient toward -Infinity (so -7 / 2 gives -4).
  var quotient = dividend / divisor;
  return Math.floor(quotient);
});
var floorDiv$1 = binaryKernelFunc$1(FloorDiv, floorDivImpl, null /* complexImpl */, 'int32');
var floorDivConfig$1 = {
  kernelName: FloorDiv,
  backendName: 'cpu',
  kernelFunc: floorDiv$1
};
99324
/**
 * CPU implementation of GatherNd.
 *
 * Each of the `numSlices` index tuples in `indicesData` has `sliceRank`
 * values and selects one slice of `sliceSize` elements from `paramsBuf`. The
 * slices are written, one per row, into a `[numSlices, sliceSize]` buffer.
 *
 * @param indicesData Flat index tuples, `sliceRank` values per slice.
 * @param paramsBuf Source buffer, read with `get(...indexToLoc(flat))`.
 * @param dtype Dtype of the output buffer.
 * @param numSlices Number of index tuples, i.e. output rows.
 * @param sliceRank Number of leading params dimensions each tuple addresses.
 * @param sliceSize Number of elements gathered per tuple.
 * @param strides Per-coordinate multipliers that turn a tuple into a slice
 *     number (`sum(dim * strides[j])`).
 * @param paramsShape Shape of params; used for bounds checks and the error
 *     message.
 * @param paramsSize Total number of elements in params.
 * @returns The `[numSlices, sliceSize]` output buffer.
 * @throws Error if a tuple coordinate is outside `[0, paramsShape[j])`, or
 *     the flattened slice number is outside params.
 */
function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
  var outBuf = buffer([numSlices, sliceSize], dtype);
  // Number of addressable slices in params.
  var numParamSlices = paramsSize / sliceSize;
  for (var i = 0; i < numSlices; i++) {
    var index = [];
    var flattenIndex = 0;
    // Each coordinate must be checked on its own: an out-of-range coordinate
    // can still flatten to an in-range slice number (e.g. [0, 5] in a [3, 3]
    // params tensor flattens to 5) and would silently read the wrong slice.
    var coordsInRange = true;
    for (var j = 0; j < sliceRank; j++) {
      var dim = indicesData[i * sliceRank + j];
      if (dim < 0 || dim >= paramsShape[j]) {
        coordsInRange = false;
      }
      flattenIndex += dim * strides[j];
      index.push(dim);
    }
    if (!coordsInRange || flattenIndex < 0 || flattenIndex >= numParamSlices) {
      throw new Error("Invalid indices: ".concat(index, " does not index into ").concat(paramsShape));
    }
    for (var k = 0; k < sliceSize; k++) {
      outBuf.values[i * sliceSize + k] = paramsBuf.get.apply(paramsBuf, _toConsumableArray(paramsBuf.indexToLoc(flattenIndex * sliceSize + k)));
    }
  }
  return outBuf;
}
99344
99345 /**
99346 * @license
99347 * Copyright 2020 Google LLC. All Rights Reserved.
99348 * Licensed under the Apache License, Version 2.0 (the "License");
99349 * you may not use this file except in compliance with the License.
99350 * You may obtain a copy of the License at
99351 *
99352 * http://www.apache.org/licenses/LICENSE-2.0
99353 *
99354 * Unless required by applicable law or agreed to in writing, software
99355 * distributed under the License is distributed on an "AS IS" BASIS,
99356 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99357 * See the License for the specific language governing permissions and
99358 * limitations under the License.
99359 * =============================================================================
99360 */
/**
 * CPU implementation of GatherV2 over a flattened 4-D view.
 *
 * For every output coordinate, the entry at axis 2 is replaced by the value
 * stored in `indicesBuf` at `[loc[0], loc[2]]`, and the element found there
 * in `xBuf` is copied. Source positions outside `xBuf.values` are skipped.
 * NOTE(review): the axes are presumably `[batch, outer, indices, inner]`;
 * confirm with the caller that builds `flattenOutputShape`.
 *
 * @param xBuf Source buffer (`dtype`, `values`, `locToIndex`).
 * @param indicesBuf Index buffer (`values`, `locToIndex`).
 * @param flattenOutputShape Shape of the output buffer.
 * @returns The output buffer created with `buffer(...)`.
 */
function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
  var outBuf = buffer(flattenOutputShape, xBuf.dtype);
  var xValues = xBuf.values;
  for (var flat = 0; flat < outBuf.size; ++flat) {
    // Copy the output coordinate; only its axis-2 entry differs in the source.
    var srcLoc = outBuf.indexToLoc(flat).slice();
    var indicesPos = indicesBuf.locToIndex([srcLoc[0], srcLoc[2]]);
    srcLoc[2] = indicesBuf.values[indicesPos];
    var srcIndex = xBuf.locToIndex(srcLoc);
    if (srcIndex >= 0 && srcIndex < xValues.length) {
      outBuf.values[flat] = xValues[srcIndex];
    }
    // Otherwise the index is out of bounds: keep the default value in outBuf.
  }
  return outBuf;
}
99378
99379 /**
99380 * @license
99381 * Copyright 2020 Google LLC. All Rights Reserved.
99382 * Licensed under the Apache License, Version 2.0 (the "License");
99383 * you may not use this file except in compliance with the License.
99384 * You may obtain a copy of the License at
99385 *
99386 * http://www.apache.org/licenses/LICENSE-2.0
99387 *
99388 * Unless required by applicable law or agreed to in writing, software
99389 * distributed under the License is distributed on an "AS IS" BASIS,
99390 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99391 * See the License for the specific language governing permissions and
99392 * limitations under the License.
99393 * =============================================================================
99394 */
/** Element-wise `a > b` kernel for the CPU backend; emits 1/0 as 'bool'. */
var greaterImpl = createSimpleBinaryKernelImpl(function isGreater(lhs, rhs) {
  return lhs > rhs ? 1 : 0;
});
var greater$1 = binaryKernelFunc$1(Greater, greaterImpl, null /* complexImpl */, 'bool');
var greaterConfig$1 = {
  kernelName: Greater,
  backendName: 'cpu',
  kernelFunc: greater$1
};
99404
99405 /**
99406 * @license
99407 * Copyright 2020 Google LLC. All Rights Reserved.
99408 * Licensed under the Apache License, Version 2.0 (the "License");
99409 * you may not use this file except in compliance with the License.
99410 * You may obtain a copy of the License at
99411 *
99412 * http://www.apache.org/licenses/LICENSE-2.0
99413 *
99414 * Unless required by applicable law or agreed to in writing, software
99415 * distributed under the License is distributed on an "AS IS" BASIS,
99416 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99417 * See the License for the specific language governing permissions and
99418 * limitations under the License.
99419 * =============================================================================
99420 */
/** Element-wise `a >= b` kernel for the CPU backend; emits 1/0 as 'bool'. */
var greaterEqualImpl = createSimpleBinaryKernelImpl(function isGreaterOrEqual(lhs, rhs) {
  return lhs >= rhs ? 1 : 0;
});
var greaterEqual$1 = binaryKernelFunc$1(GreaterEqual, greaterEqualImpl, null /* complexImpl */, 'bool');
var greaterEqualConfig$1 = {
  kernelName: GreaterEqual,
  backendName: 'cpu',
  kernelFunc: greaterEqual$1
};
99430
99431 /**
99432 * @license
99433 * Copyright 2020 Google LLC. All Rights Reserved.
99434 * Licensed under the Apache License, Version 2.0 (the "License");
99435 * you may not use this file except in compliance with the License.
99436 * You may obtain a copy of the License at
99437 *
99438 * http://www.apache.org/licenses/LICENSE-2.0
99439 *
99440 * Unless required by applicable law or agreed to in writing, software
99441 * distributed under the License is distributed on an "AS IS" BASIS,
99442 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99443 * See the License for the specific language governing permissions and
99444 * limitations under the License.
99445 * =============================================================================
99446 */
/** Element-wise `a < b` kernel for the CPU backend; emits 1/0 as 'bool'. */
var lessImpl = createSimpleBinaryKernelImpl(function isLess(lhs, rhs) {
  return lhs < rhs ? 1 : 0;
});
var less$1 = binaryKernelFunc$1(Less, lessImpl, null /* complexImpl */, 'bool');
var lessConfig$1 = {
  kernelName: Less,
  backendName: 'cpu',
  kernelFunc: less$1
};
99456
99457 /**
99458 * @license
99459 * Copyright 2020 Google LLC. All Rights Reserved.
99460 * Licensed under the Apache License, Version 2.0 (the "License");
99461 * you may not use this file except in compliance with the License.
99462 * You may obtain a copy of the License at
99463 *
99464 * http://www.apache.org/licenses/LICENSE-2.0
99465 *
99466 * Unless required by applicable law or agreed to in writing, software
99467 * distributed under the License is distributed on an "AS IS" BASIS,
99468 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99469 * See the License for the specific language governing permissions and
99470 * limitations under the License.
99471 * =============================================================================
99472 */
/** Element-wise `a <= b` kernel for the CPU backend; emits 1/0 as 'bool'. */
var lessEqualImpl = createSimpleBinaryKernelImpl(function isLessOrEqual(lhs, rhs) {
  return lhs <= rhs ? 1 : 0;
});
var lessEqual$1 = binaryKernelFunc$1(LessEqual, lessEqualImpl, null /* complexImpl */, 'bool');
var lessEqualConfig$1 = {
  kernelName: LessEqual,
  backendName: 'cpu',
  kernelFunc: lessEqual$1
};
99482
99483 /**
99484 * @license
99485 * Copyright 2020 Google LLC. All Rights Reserved.
99486 * Licensed under the Apache License, Version 2.0 (the "License");
99487 * you may not use this file except in compliance with the License.
99488 * You may obtain a copy of the License at
99489 *
99490 * http://www.apache.org/licenses/LICENSE-2.0
99491 *
99492 * Unless required by applicable law or agreed to in writing, software
99493 * distributed under the License is distributed on an "AS IS" BASIS,
99494 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99495 * See the License for the specific language governing permissions and
99496 * limitations under the License.
99497 * =============================================================================
99498 */
/**
 * Returns `num` evenly spaced float32 values from `start` to `stop`, with
 * both endpoints included.
 *
 * @param start First value.
 * @param stop Last value; stored exactly when `num > 1`.
 * @param num Number of values to generate.
 * @returns A float32 typed array from `makeZerosTypedArray(num, 'float32')`.
 */
function linSpaceImpl(start, stop, num) {
  var step = (stop - start) / (num - 1);
  var values = makeZerosTypedArray(num, 'float32');
  // Set directly: for num === 1 `step` is not finite, so the single element
  // must not be derived from it.
  values[0] = start;
  // Derive every element from `start` rather than accumulating `step` into
  // the previous (float32-rounded) element, so rounding errors do not
  // compound along the sequence.
  for (var i = 1; i < values.length; i++) {
    values[i] = start + i * step;
  }
  if (values.length > 1) {
    // Pin the endpoint so the sequence ends exactly at `stop`.
    values[values.length - 1] = stop;
  }
  return values;
}
99508
99509 /**
99510 * @license
99511 * Copyright 2020 Google LLC. All Rights Reserved.
99512 * Licensed under the Apache License, Version 2.0 (the License);
99513 * you may not use this file except in compliance with the License.
99514 * You may obtain a copy of the License at
99515 *
99516 * http://www.apache.org/licenses/LICENSE-2.0
99517 *
99518 * Unless required by applicable law or agreed to in writing, software
99519 * distributed under the License is distributed on an AS IS BASIS,
99520 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99521 * See the License for the specific language governing permissions and
99522 * limitations under the License.
99523 * =============================================================================
99524 */
/** Element-wise natural logarithm kernel for the CPU backend. */
var logImpl = createSimpleUnaryImpl(function logElement(value) {
  // Math.log yields NaN for negative input and -Infinity for 0.
  return Math.log(value);
});
var log$1 = unaryKernelFuncFromImpl(Log, logImpl);
var logConfig$1 = {
  kernelName: Log,
  backendName: 'cpu',
  kernelFunc: log$1
};
99534
99535 /**
99536 * @license
99537 * Copyright 2020 Google LLC. All Rights Reserved.
99538 * Licensed under the Apache License, Version 2.0 (the "License");
99539 * you may not use this file except in compliance with the License.
99540 * You may obtain a copy of the License at
99541 *
99542 * http://www.apache.org/licenses/LICENSE-2.0
99543 *
99544 * Unless required by applicable law or agreed to in writing, software
99545 * distributed under the License is distributed on an "AS IS" BASIS,
99546 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99547 * See the License for the specific language governing permissions and
99548 * limitations under the License.
99549 * =============================================================================
99550 */
/**
 * Max-reduces consecutive runs of `reduceSize` values.
 *
 * Output element `i` is the maximum of `aVals[i*reduceSize .. (i+1)*reduceSize)`.
 * A NaN anywhere in a run makes that output NaN.
 *
 * @param aVals Flat input values.
 * @param reduceSize Number of consecutive values per output element.
 * @param outShape Output shape; its size is the number of outputs.
 * @param dtype Output dtype, passed to `getTypedArrayFromDType`.
 * @returns Typed array of maxima.
 */
function maxImpl$1(aVals, reduceSize, outShape, dtype) {
  var result = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
  for (var outIdx = 0; outIdx < result.length; ++outIdx) {
    var start = outIdx * reduceSize;
    var end = start + reduceSize;
    var best = aVals[start];
    for (var k = start; k < end; ++k) {
      var candidate = aVals[k];
      // Any comparison with NaN is false, so NaN is checked explicitly in
      // order to propagate it.
      if (candidate > best || Number.isNaN(candidate)) {
        best = candidate;
      }
    }
    result[outIdx] = best;
  }
  return result;
}
99567
99568 /**
99569 * @license
99570 * Copyright 2020 Google LLC. All Rights Reserved.
99571 * Licensed under the Apache License, Version 2.0 (the "License");
99572 * you may not use this file except in compliance with the License.
99573 * You may obtain a copy of the License at
99574 *
99575 * http://www.apache.org/licenses/LICENSE-2.0
99576 *
99577 * Unless required by applicable law or agreed to in writing, software
99578 * distributed under the License is distributed on an "AS IS" BASIS,
99579 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99580 * See the License for the specific language governing permissions and
99581 * limitations under the License.
99582 * =============================================================================
99583 */
/** Element-wise maximum kernel for the CPU backend (via Math.max). */
var maximumImpl = createSimpleBinaryKernelImpl(function pickLarger(first, second) {
  return Math.max(first, second);
});
var maximum$1 = binaryKernelFunc$1(Maximum$1, maximumImpl);
var maximumConfig$1 = {
  kernelName: Maximum$1,
  backendName: 'cpu',
  kernelFunc: maximum$1
};
99593
99594 /**
99595 * @license
99596 * Copyright 2020 Google LLC. All Rights Reserved.
99597 * Licensed under the Apache License, Version 2.0 (the "License");
99598 * you may not use this file except in compliance with the License.
99599 * You may obtain a copy of the License at
99600 *
99601 * http://www.apache.org/licenses/LICENSE-2.0
99602 *
99603 * Unless required by applicable law or agreed to in writing, software
99604 * distributed under the License is distributed on an "AS IS" BASIS,
99605 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99606 * See the License for the specific language governing permissions and
99607 * limitations under the License.
99608 * =============================================================================
99609 */
/** Element-wise minimum kernel for the CPU backend (via Math.min). */
var minimumImpl = createSimpleBinaryKernelImpl(function pickSmaller(first, second) {
  return Math.min(first, second);
});
var minimum$1 = binaryKernelFunc$1(Minimum$1, minimumImpl);
var minimumConfig$1 = {
  kernelName: Minimum$1,
  backendName: 'cpu',
  kernelFunc: minimum$1
};
99619
99620 /**
99621 * @license
99622 * Copyright 2020 Google LLC. All Rights Reserved.
99623 * Licensed under the Apache License, Version 2.0 (the "License");
99624 * you may not use this file except in compliance with the License.
99625 * You may obtain a copy of the License at
99626 *
99627 * http://www.apache.org/licenses/LICENSE-2.0
99628 *
99629 * Unless required by applicable law or agreed to in writing, software
99630 * distributed under the License is distributed on an "AS IS" BASIS,
99631 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99632 * See the License for the specific language governing permissions and
99633 * limitations under the License.
99634 * =============================================================================
99635 */
/** Element-wise multiplication kernel (real and complex) for the CPU backend. */
var multiplyImpl = createSimpleBinaryKernelImpl(function multiplyReal(lhs, rhs) {
  return lhs * rhs;
});
// (a + bi)(c + di) = (ac - bd) + (ad + bc)i
var multiplyComplexImpl = createComplexBinaryKernelImpl(function multiplyComplex(aRe, aIm, bRe, bIm) {
  var real = aRe * bRe - aIm * bIm;
  var imag = aRe * bIm + aIm * bRe;
  return {
    real: real,
    imag: imag
  };
});
var multiply$1 = binaryKernelFunc$1(Multiply$1, multiplyImpl, multiplyComplexImpl);
var multiplyConfig$1 = {
  kernelName: Multiply$1,
  backendName: 'cpu',
  kernelFunc: multiply$1
};
99651
/**
 * Negates `xVals` by multiplying with the scalar -1 through `multiplyImpl`.
 *
 * @returns The `[values, shape]` pair produced by `multiplyImpl`.
 */
function negImpl(xVals, xShape, xDtype) {
  // The -1 operand is passed with shape [] (a scalar); presumably
  // multiplyImpl broadcasts it against `xShape`.
  return multiplyImpl([], xShape, createScalarValue(-1, xDtype), xVals, xDtype);
}
/** CPU `Neg` kernel: rejects complex input, negates, keeps the input dtype. */
function neg$1(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  assertNotComplex$1(x, 'neg');
  var xVals = backend.data.get(x.dataId).values;
  var negated = _slicedToArray(negImpl(xVals, x.shape, x.dtype), 2);
  return backend.makeTensorInfo(negated[1], x.dtype, negated[0]);
}
var negConfig$1 = {
  kernelName: Neg,
  backendName: 'cpu',
  kernelFunc: neg$1
};
99673
99674 /**
99675 * @license
99676 * Copyright 2020 Google LLC. All Rights Reserved.
99677 * Licensed under the Apache License, Version 2.0 (the "License");
99678 * you may not use this file except in compliance with the License.
99679 * You may obtain a copy of the License at
99680 *
99681 * http://www.apache.org/licenses/LICENSE-2.0
99682 *
99683 * Unless required by applicable law or agreed to in writing, software
99684 * distributed under the License is distributed on an "AS IS" BASIS,
99685 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99686 * See the License for the specific language governing permissions and
99687 * limitations under the License.
99688 * =============================================================================
99689 */
/** Element-wise inequality kernel for the CPU backend; emits 1/0 as 'bool'. */
var notEqualImpl = createSimpleBinaryKernelImpl(function isNotEqual(lhs, rhs) {
  // Strict comparison, so NaN is always "not equal", even to itself.
  return lhs !== rhs ? 1 : 0;
});
var notEqual$1 = binaryKernelFunc$1(NotEqual, notEqualImpl, null /* complexOp */, 'bool');
var notEqualConfig$1 = {
  kernelName: NotEqual,
  backendName: 'cpu',
  kernelFunc: notEqual$1
};
99699
99700 /**
99701 * @license
99702 * Copyright 2020 Google LLC. All Rights Reserved.
99703 * Licensed under the Apache License, Version 2.0 (the "License");
99704 * you may not use this file except in compliance with the License.
99705 * You may obtain a copy of the License at
99706 *
99707 * http://www.apache.org/licenses/LICENSE-2.0
99708 *
99709 * Unless required by applicable law or agreed to in writing, software
99710 * distributed under the License is distributed on an "AS IS" BASIS,
99711 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99712 * See the License for the specific language governing permissions and
99713 * limitations under the License.
99714 * =============================================================================
99715 */
/**
 * Permutes the axes of a flat value array.
 *
 * Each input element at coordinate `loc` is written to the output coordinate
 * whose axis `a` equals `loc[perm[a]]`.
 *
 * @param xVals Flat input values.
 * @param xShape Input shape.
 * @param dtype Output dtype, passed to `getTypedArrayFromDType`.
 * @param perm Axis permutation.
 * @param newShape Output shape (input shape permuted by `perm`).
 * @returns Typed array of permuted values.
 */
function transposeImpl$1(xVals, xShape, dtype, perm, newShape) {
  var rank = xShape.length;
  var srcStrides = computeStrides(xShape);
  var dstStrides = computeStrides(newShape);
  var total = sizeFromShape(xShape);
  var result = getTypedArrayFromDType(dtype, sizeFromShape(newShape));
  for (var src = 0; src < total; ++src) {
    var srcLoc = indexToLoc(src, rank, srcStrides);
    var dstLoc = new Array(srcLoc.length);
    for (var axis = 0; axis < dstLoc.length; axis++) {
      dstLoc[axis] = srcLoc[perm[axis]];
    }
    result[locToIndex(dstLoc, rank, dstStrides)] = xVals[src];
  }
  return result;
}
99734
99735 /**
99736 * @license
99737 * Copyright 2020 Google LLC. All Rights Reserved.
99738 * Licensed under the Apache License, Version 2.0 (the "License");
99739 * you may not use this file except in compliance with the License.
99740 * You may obtain a copy of the License at
99741 *
99742 * http://www.apache.org/licenses/LICENSE-2.0
99743 *
99744 * Unless required by applicable law or agreed to in writing, software
99745 * distributed under the License is distributed on an "AS IS" BASIS,
99746 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99747 * See the License for the specific language governing permissions and
99748 * limitations under the License.
99749 * =============================================================================
99750 */
/**
 * CPU `Transpose` kernel.
 *
 * Rejects complex input, computes the permuted shape, permutes the values
 * with `transposeImpl$1`, and stores them with `backend.write`.
 *
 * @returns `{dataId, shape, dtype}` describing the transposed tensor.
 */
function transpose$1(args) {
  var x = args.inputs.x;
  var perm = args.attrs.perm;
  var backend = args.backend;
  assertNotComplex$1(x, 'transpose');
  // Output axis i has the extent of input axis perm[i].
  var newShape = x.shape.map(function (_extent, axis) {
    return x.shape[perm[axis]];
  });
  var result = transposeImpl$1(backend.data.get(x.dataId).values, x.shape, x.dtype, perm, newShape);
  return {
    dataId: backend.write(result, newShape, x.dtype),
    shape: newShape,
    dtype: x.dtype
  };
}
var transposeConfig$1 = {
  kernelName: Transpose,
  backendName: 'cpu',
  kernelFunc: transpose$1
};
99777
/**
 * Product-reduces `xVals` over `reductionAxes`.
 *
 * The loop multiplies each consecutive run of `reduceSize` values. This
 * assumes the reduced axes are already the innermost ones (the caller
 * `prod$1` transposes first when needed).
 *
 * @param xShape Input shape.
 * @param xDtype Input dtype; the output dtype is `upcastType(xDtype, 'int32')`.
 * @param xVals Flat input values.
 * @param reductionAxes Axes to reduce.
 * @returns `{outVals, outShape, outDtype}`.
 */
function prodImpl(xShape, xDtype, xVals, reductionAxes) {
  var shapes = _slicedToArray(computeOutAndReduceShapes(xShape, reductionAxes), 2);
  var outShape = shapes[0];
  var reduceSize = sizeFromShape(shapes[1]);
  var outDtype = upcastType(xDtype, 'int32');
  var outVals = makeZerosTypedArray(sizeFromShape(outShape), outDtype);
  for (var outIdx = 0; outIdx < outVals.length; ++outIdx) {
    var base = outIdx * reduceSize;
    var product = 1;
    for (var k = 0; k < reduceSize; ++k) {
      product *= xVals[base + k];
    }
    outVals[outIdx] = product;
  }
  return {
    outVals: outVals,
    outShape: outShape,
    outDtype: outDtype
  };
}
/**
 * CPU `Prod` kernel: product-reduces `x` over `attrs.axis`.
 *
 * When `getAxesPermutation` returns a permutation, `x` is transposed first,
 * and the reduction then runs over the innermost axes. The temporary
 * transposed tensor is disposed before returning. With `attrs.keepDims`, the
 * reduced axes are re-inserted via `expandShapeToKeepDim`.
 *
 * @returns Tensor info with the product values and `prodImpl`'s output dtype.
 */
function prod$1(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var axis = args.attrs.axis;
  var keepDims = args.attrs.keepDims;
  assertNotComplex$1(x, 'prod');
  var rank = x.shape.length;
  var axes = parseAxisParam(axis, x.shape);
  var permutation = getAxesPermutation(axes, rank);
  var source = x;
  var reductionAxes = axes;
  var toDispose = [];
  if (permutation != null) {
    // Presumably moves the reduced axes innermost so prodImpl can reduce
    // contiguous runs; getAxesPermutation returns null when no move is needed.
    source = transpose$1({
      inputs: {
        x: x
      },
      backend: backend,
      attrs: {
        perm: permutation
      }
    });
    toDispose.push(source);
    reductionAxes = getInnerMostAxes(reductionAxes.length, rank);
  }
  var sourceVals = backend.data.get(source.dataId).values;
  var reduced = prodImpl(source.shape, source.dtype, sourceVals, reductionAxes);
  var resultShape = keepDims ? expandShapeToKeepDim(reduced.outShape, axes) : reduced.outShape;
  for (var i = 0; i < toDispose.length; i++) {
    backend.disposeIntermediateTensorInfo(toDispose[i]);
  }
  return backend.makeTensorInfo(resultShape, reduced.outDtype, reduced.outVals);
}
var prodConfig$1 = {
  kernelName: Prod,
  backendName: 'cpu',
  kernelFunc: prod$1
};
99846
99847 function validateIndices(indices, indicesShape, numParams) {
99848 indices.forEach(function (index, i) {
99849 if (index < 0 || index >= numParams) {
99850 var locString = indexToLoc(i, indicesShape.length, computeStrides(indicesShape)).join(',');
99851 throw new Error("indices[".concat(locString, "] = ").concat(index, " is not in [0, ").concat(numParams, ")"));
99852 }
99853 });
99854 }
99855 function validateSplits(paramsNestedSplits, numParamsDenseValues) {
99856 // Validate
99857 for (var dim = 0; dim < paramsNestedSplits.length; ++dim) {
99858 var splits = paramsNestedSplits[dim];
99859 var lastSplit = dim === paramsNestedSplits.length - 1 ? numParamsDenseValues : paramsNestedSplits[dim + 1].length;
99860 if (splits.length === 0) {
99861 throw new Error('Ragged splits may not be empty');
99862 }
99863 if (splits[0] < 0) {
99864 throw new Error('Ragged splits must be non-negative');
99865 }
99866 if (splits[splits.length - 1] > lastSplit) {
99867 throw new Error('Ragged splits must not point past values');
99868 }
99869 for (var i = 1; i < splits.length; ++i) {
99870 if (splits[i - 1] > splits[i]) {
99871 throw new Error('Ragged splits must be sorted in ascending order');
99872 }
99873 }
99874 }
99875 }
99876 // Construct the `splits` output tensors, encoded using a nested vector.
99877 // Also find the slices of values that need to be copied, and store them
99878 // in `valueSlices`. The total number of values that will be copied (which
99879 // we need for allocating the output values tensor) is stored in `numValues`.
99880 function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) {
99881 var valueSlices = [];
99882 var numValues = 0;
99883 var numSplits = indicesShape.length - 1 + paramsNestedSplits.length;
99884 var outSplits = new Array(numSplits).fill(null).map(function () {
99885 return [0];
99886 });
99887 validateSplits(paramsNestedSplits, numParamsDenseValues);
99888 // Add `splits` that come from all but the last dimension of the dense
99889 // Tensor `indices`. In particular, for each dimension D, we add a
99890 // splits tensor whose values are:
99891 // range(reduceProd(splits.shape[:D]) + 1) * splits.shape[D+1]
99892 // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
99893 // [0, 3, 6] # length=2+1, stride=3
99894 // [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4
99895 var nrows = 1;
99896 for (var dim = 0; dim < indicesShape.length - 1; ++dim) {
99897 nrows *= indicesShape[dim];
99898 var rowLength = indicesShape[dim + 1];
99899 for (var i = 1; i < nrows + 1; ++i) {
99900 outSplits[dim].push(i * rowLength);
99901 }
99902 }
99903 // Add `splits` that come from `paramsNestedSplits`. Starting with the
99904 // outermost ragged dimension (i.e., the first `splits` tensor), we work
99905 // our way in, finding the range of values that should be copied. As we
99906 // go, we update the output `splits` for each dimension with the appropriate
99907 // values. In particular, the *lengths* of the slices from `param_splits`
99908 // should be copied to generate corresponding slice lengths in the output
99909 // splits. E.g., if we are copying a ragged row with length 4, then we
99910 // should add a new split point to outSplits that is 4 greater than the
99911 // previous split point in outSplits.
99912 for (var _i = 0; _i < indices.length; ++_i) {
99913 var start = indices[_i];
99914 var limit = indices[_i] + 1;
99915 // Copy splits.
99916 for (var _dim = 0; _dim < paramsNestedSplits.length; ++_dim) {
99917 var splits = paramsNestedSplits[_dim];
99918 var outDim = _dim + indicesShape.length - 1;
99919 if (outDim >= 0) {
99920 var outSplitsOutDim = outSplits[outDim];
99921 var delta = outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start];
99922 for (var j = start; j < limit; ++j) {
99923 outSplits[outDim].push(splits[j + 1] + delta);
99924 }
99925 }
99926 start = splits[start];
99927 limit = splits[limit];
99928 }
99929 if (limit !== start) {
99930 valueSlices.push([start, limit]);
99931 numValues += limit - start;
99932 }
99933 }
99934 return {
99935 outSplits: outSplits,
99936 valueSlices: valueSlices,
99937 numValues: numValues
99938 };
99939 }
99940 function getSplits(outSplits) {
99941 var splitsOut = [];
99942 var _loop = function _loop() {
99943 var numSplits = outSplits[i].length;
99944 var splits = getArrayFromDType('int32', numSplits);
99945 splitsOut.push(splits);
99946 outSplits[i].forEach(function (value, j) {
99947 return splits[j] = value;
99948 });
99949 };
99950 for (var i = 0; i < outSplits.length; ++i) {
99951 _loop();
99952 }
99953 return splitsOut;
99954 }
99955 function computeFlatOuterDims(orig, numOutDims) {
99956 var outDims = orig.slice(0, numOutDims);
99957 while (outDims.length < numOutDims) {
99958 outDims.push(1);
99959 }
99960 for (var inDim = numOutDims; inDim < orig.length; inDim++) {
99961 outDims[numOutDims - 1] *= orig[inDim];
99962 }
99963 return outDims;
99964 }
99965 // For each slice in `(start, limit)` in `valueSlices`, append
99966 // `paramsDenseValues[start,...,limit] to `values`. `valueSize` indicates
99967 // the number of scalars contained in each value paramsDenseValues[i].
99968 function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) {
99969 var denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1];
99970 var valuesM = computeFlatOuterDims(valuesShape, 2)[1];
99971 var outPos = 0;
99972 var _iterator = _createForOfIteratorHelper(valueSlices),
99973 _step;
99974 try {
99975 for (_iterator.s(); !(_step = _iterator.n()).done;) {
99976 var slice = _step.value;
99977 for (var i = slice[0]; i < slice[1]; ++i) {
99978 for (var j = 0; j < valueSize; ++j) {
99979 values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j];
99980 }
99981 ++outPos;
99982 }
99983 }
99984 } catch (err) {
99985 _iterator.e(err);
99986 } finally {
99987 _iterator.f();
99988 }
99989 }
99990 function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) {
99991 var valuesShape = paramsDenseValuesShape.slice();
99992 valuesShape[0] = numValues;
99993 var valuesOut = getArrayFromDType(paramsDenseValuesDType, sizeFromShape(valuesShape));
99994 var numElements = paramsDenseValues.length;
99995 var valueSize = numElements === 0 ? 0 : numElements / paramsDenseValuesShape[0];
99996 writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, valuesOut, valuesShape);
99997 return [valuesOut, valuesShape];
99998 }
99999 function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) {
100000 if (paramsNestedSplits.length === 0) {
100001 throw new Error('paramsNestedSplits must be non empty');
100002 }
100003 if (paramsNestedSplitsShapes[0].length === 0) {
100004 throw new Error('Split tensors must not be scalars');
100005 }
100006 var numParams = paramsNestedSplitsShapes[0][0] - 1;
100007 validateIndices(indices, indicesShape, numParams);
100008 if (paramsDenseValuesShape.length === 0) {
100009 throw new Error('params.rank must be nonzero');
100010 }
100011 var numParamsDenseValues = paramsDenseValuesShape[0];
100012 // Calculate the `splits`, and store the value slices that we need to
100013 // copy in `valueSlices`.
100014 var _makeSplits = makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues),
100015 outSplits = _makeSplits.outSplits,
100016 valueSlices = _makeSplits.valueSlices,
100017 numValues = _makeSplits.numValues;
100018 // Write the output tensors.
100019 var outputNestedSplits = getSplits(outSplits);
100020 var outputDenseValues = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues);
100021 return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]];
100022 }
100023
100024 /**
100025 * @license
100026 * Copyright 2022 Google LLC.
100027 * Licensed under the Apache License, Version 2.0 (the "License");
100028 * you may not use this file except in compliance with the License.
100029 * You may obtain a copy of the License at
100030 *
100031 * http://www.apache.org/licenses/LICENSE-2.0
100032 *
100033 * Unless required by applicable law or agreed to in writing, software
100034 * distributed under the License is distributed on an "AS IS" BASIS,
100035 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100036 * See the License for the specific language governing permissions and
100037 * limitations under the License.
100038 * =============================================================================
100039 */
100040 var INT32_MAX = 2147483647;
100041 function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) {
100042 // Check input tensor shapes.
100043 if (startsShape.length > 1) {
100044 throw new Error('starts must be a scalar or vector');
100045 }
100046 if (limitsShape.length > 1) {
100047 throw new Error('limits must be a scalar or vector');
100048 }
100049 if (deltasShape.length > 1) {
100050 throw new Error('deltas must be a scalar or vector');
100051 }
100052 // Determine which tensors we need to broadcast.
100053 var broadcastStarts = startsShape.length === 0;
100054 var broadcastLimits = limitsShape.length === 0;
100055 var broadcastDeltas = deltasShape.length === 0;
100056 // nRows (number of output rows) is the size of the non-broadcast inputs,
100057 // or 1 if all inputs are scalars.
100058 var inSizes = [];
100059 if (!broadcastStarts) {
100060 inSizes.push(startsShape[0]);
100061 }
100062 if (!broadcastLimits) {
100063 inSizes.push(limitsShape[0]);
100064 }
100065 if (!broadcastDeltas) {
100066 inSizes.push(deltasShape[0]);
100067 }
100068 for (var i = 1; i < inSizes.length; ++i) {
100069 if (inSizes[i] !== inSizes[i - 1]) {
100070 throw new Error('starts, limits, and deltas must have the same shape');
100071 }
100072 }
100073 var nRows = inSizes.length === 0 ? 1 : inSizes[0];
100074 // Construct the rtNestedSplits tensor.
100075 var rtNestedSplits = getArrayFromDType('int32', nRows + 1);
100076 rtNestedSplits[0] = 0;
100077 for (var row = 0; row < nRows; ++row) {
100078 var start = broadcastStarts ? starts[0] : starts[row];
100079 var limit = broadcastLimits ? limits[0] : limits[row];
100080 var delta = broadcastDeltas ? deltas[0] : deltas[row];
100081 if (delta === 0) {
100082 throw new Error('Requires delta != 0');
100083 }
100084 var size = void 0; // The number of elements in the specified range.
100085 if (delta > 0 && limit < start || delta < 0 && limit > start) {
100086 size = 0;
100087 } else {
100088 size = Math.ceil(Math.abs((limit - start) / delta));
100089 if (size > INT32_MAX) {
100090 throw new Error("Requires ((limit - start) / delta) <= ".concat(INT32_MAX));
100091 }
100092 }
100093 rtNestedSplits[row + 1] = rtNestedSplits[row] + size;
100094 }
100095 var nVals = rtNestedSplits[nRows];
100096 // Construct the rtDenseValues tensor.
100097 var rtDenseValues = getArrayFromDType(startsDType, nVals);
100098 var valueIndex = 0;
100099 for (var _row = 0; _row < nRows; ++_row) {
100100 var rowSize = rtNestedSplits[_row + 1] - rtNestedSplits[_row];
100101 var value = broadcastStarts ? starts[0] : starts[_row];
100102 var _delta = broadcastDeltas ? deltas[0] : deltas[_row];
100103 for (var _i = 0; _i < rowSize; ++_i) {
100104 rtDenseValues[valueIndex++] = value;
100105 value += _delta;
100106 }
100107 }
100108 return [rtNestedSplits, rtDenseValues];
100109 }
100110
100111 var RowPartitionType = RowPartitionType$1;
100112 // Based on
100113 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
100114 var RaggedTensorToTensorOp = /*#__PURE__*/function () {
100115 function RaggedTensorToTensorOp(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) {
100116 _classCallCheck(this, RaggedTensorToTensorOp);
100117 this.shape = shape;
100118 this.shapeShape = shapeShape;
100119 this.values = values;
100120 this.valuesShape = valuesShape;
100121 this.valuesDType = valuesDType;
100122 this.defaultValue = defaultValue;
100123 this.defaultValueShape = defaultValueShape;
100124 this.rowPartitionValues = rowPartitionValues;
100125 this.rowPartitionValuesShapes = rowPartitionValuesShapes;
100126 this.rowPartitionTypes = getRowPartitionTypesHelper(rowPartitionTypeStrings);
100127 this.raggedRank = getRaggedRank(this.rowPartitionTypes);
100128 }
100129 _createClass(RaggedTensorToTensorOp, [{
100130 key: "getRowPartitionTypeByDimension",
100131 value: function getRowPartitionTypeByDimension(dimension) {
100132 if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
100133 return this.rowPartitionTypes[dimension + 1];
100134 } else {
100135 return this.rowPartitionTypes[dimension];
100136 }
100137 }
100138 // Returns the relationship between dimension and dimension + 1.
100139 }, {
100140 key: "getRowPartitionTensor",
100141 value: function getRowPartitionTensor(dimension) {
100142 if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
100143 return this.rowPartitionValues[dimension + 1];
100144 } else {
100145 return this.rowPartitionValues[dimension];
100146 }
100147 }
100148 }, {
100149 key: "getMaxWidth",
100150 value: function getMaxWidth(dimension) {
100151 var rowPartitionTensor = this.getRowPartitionTensor(dimension - 1);
100152 switch (this.getRowPartitionTypeByDimension(dimension - 1)) {
100153 case RowPartitionType.VALUE_ROWIDS:
100154 return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor);
100155 case RowPartitionType.ROW_SPLITS:
100156 return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor);
100157 default:
100158 throw new Error("Cannot handle partition type ".concat(RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)]));
100159 }
100160 }
100161 }, {
100162 key: "tensorShapeFromTensor",
100163 value: function tensorShapeFromTensor(t, tShape) {
100164 var isPartial = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
100165 if (tShape.length === 0) {
100166 if (t[0] === -1) {
100167 return [];
100168 }
100169 throw new Error("The only valid scalar shape tensor is the fully unknown shape specified as -1.");
100170 }
100171 // MakePartialShape/MakeShapeHelper.
100172 return makeShape(t, isPartial);
100173 }
100174 }, {
100175 key: "calculateOutputSize",
100176 value: function calculateOutputSize(firstDim) {
100177 var valueShape = this.valuesShape;
100178 var defaultValueShape = this.defaultValueShape;
100179 validateDefaultValueShape(defaultValueShape, valueShape);
100180 var shape = this.tensorShapeFromTensor(this.shape, this.shapeShape);
100181 var outputShape = combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape);
100182 var result = outputShape;
100183 if (result[0] < 0) {
100184 result[0] = firstDim;
100185 }
100186 for (var i = 1; i <= this.raggedRank; ++i) {
100187 if (result[i] < 0) {
100188 result[i] = this.getMaxWidth(i);
100189 }
100190 }
100191 return result;
100192 }
100193 /**
100194 * The outputIndex represents the index in the output tensor
100195 * where the first element of a particular dimension would be written.
100196 * If it is -1, it indicates that the index is out of scope.
100197 * Example, given firstDimension = 10, firstDimensionOutput = 6,
100198 * and outputIndexMultiplier = 100:
100199 * result = [0 100 200 300 400 500 -1 -1 -1 -1]
100200 * If firstDimensionOutput = 11 instead, then:
100201 * result = [0 100 200 300 400 500 600 700 800 900]
100202 */
100203 }, {
100204 key: "calculateFirstParentOutputIndex",
100205 value: function calculateFirstParentOutputIndex(firstDimension, outputIndexMultiplier, firstDimensionOutput) {
100206 var minDimension = Math.min(firstDimension, firstDimensionOutput);
100207 var result = [];
100208 var currentOutputIndex = 0;
100209 for (var i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) {
100210 result.push(currentOutputIndex);
100211 }
100212 for (var _i = minDimension; _i < firstDimension; ++_i) {
100213 result.push(-1);
100214 }
100215 assert$1(result.length === firstDimension, function () {
100216 return 'Final length of result must be equal to firstDimension.';
100217 });
100218 return result;
100219 }
100220 }, {
100221 key: "calculateOutputIndexRowSplit",
100222 value: function calculateOutputIndexRowSplit(rowSplit, parentOutputIndex, outputIndexMultiplier, outputSize) {
100223 var rowSplitSize = rowSplit.length;
100224 var result = [];
100225 for (var i = 0; i < rowSplitSize - 1; ++i) {
100226 var rowLength = rowSplit[i + 1] - rowSplit[i];
100227 var realLength = Math.min(outputSize, rowLength);
100228 var parentOutputIndexCurrent = parentOutputIndex[i];
100229 if (parentOutputIndexCurrent === -1) {
100230 realLength = 0;
100231 }
100232 for (var j = 0; j < realLength; ++j) {
100233 result.push(parentOutputIndexCurrent);
100234 parentOutputIndexCurrent += outputIndexMultiplier;
100235 }
100236 for (var _j = 0; _j < rowLength - realLength; ++_j) {
100237 result.push(-1);
100238 }
100239 }
100240 if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) {
100241 throw new Error('Invalid row split size.');
100242 }
100243 return result;
100244 }
100245 // Calculate the output index of the first element of a list.
100246 // The parentOutputIndex is the same computation for the previous list.
100247 // -1 indicates an element or list that is out of range.
100248 // The outputIndexMultiplier is the number of output indices one moves
100249 // forward for each column.
100250 // E.g., given:
100251 // valueRowIds:[0 1 2 2 2 3 5 5 6]
100252 // parentOutputIndex:[1000 1100 2000 2100 -1 3000 4000]
100253 // outputIndexMultiplier: 10
100254 // outputSize: 2
100255 // You get:
100256 // result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
100257 // result[0] = parentOutputIndex[valueRowIds[0]]
100258 // result[1] = parentOutputIndex[valueRowIds[1]]
100259 // result[2] = parentOutputIndex[valueRowIds[2]]
100260 // result[3] = parentOutputIndex[valueRowIds[2] + 10]
100261 // result[4] = -1 because it is the third element the size is 2.
100262 // result[5] = parentOutputIndex[valueRowIds[3]]
100263 // result[6] = -1 because parentOutputIndex[valueRowIds[6]] == -1
100264 // result[7] = -1 because parentOutputIndex[valueRowIds[6]] == -1
100265 // result[8] = parentOutputIndex[valueRowIds[7]]
100266 }, {
100267 key: "calculateOutputIndexValueRowID",
100268 value: function calculateOutputIndexValueRowID(valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) {
100269 var indexSize = valueRowIds.length;
100270 var result = [];
100271 if (indexSize === 0) {
100272 return [];
100273 }
100274 var currentOutputColumn = 0;
100275 var currentValueRowId = valueRowIds[0];
100276 if (currentValueRowId >= parentOutputIndex.length) {
100277 throw new Error("Got currentValueRowId=".concat(currentValueRowId, ", which is not less than ").concat(parentOutputIndex.length));
100278 }
100279 var currentOutputIndex = parentOutputIndex[currentValueRowId];
100280 result.push(currentOutputIndex);
100281 for (var i = 1; i < indexSize; ++i) {
100282 var nextValueRowId = valueRowIds[i];
100283 if (nextValueRowId === currentValueRowId) {
100284 if (currentOutputIndex >= 0) {
100285 ++currentOutputColumn;
100286 if (currentOutputColumn < outputSize) {
100287 currentOutputIndex += outputIndexMultiplier;
100288 } else {
100289 currentOutputIndex = -1;
100290 }
100291 }
100292 } else {
100293 currentOutputColumn = 0;
100294 currentValueRowId = nextValueRowId;
100295 if (nextValueRowId >= parentOutputIndex.length) {
100296 throw new Error("Got nextValueRowId=".concat(nextValueRowId, " which is not less than ").concat(parentOutputIndex.length));
100297 }
100298 currentOutputIndex = parentOutputIndex[nextValueRowId];
100299 }
100300 result.push(currentOutputIndex);
100301 }
100302 if (result.length !== valueRowIds.length) {
100303 throw new Error('Invalid row ids.');
100304 }
100305 return result;
100306 }
100307 }, {
100308 key: "calculateOutputIndex",
100309 value: function calculateOutputIndex(dimension, parentOutputIndex, outputIndexMultiplier, outputSize) {
100310 var rowPartitionTensor = this.getRowPartitionTensor(dimension);
100311 var partitionType = this.getRowPartitionTypeByDimension(dimension);
100312 switch (partitionType) {
100313 case RowPartitionType.VALUE_ROWIDS:
100314 return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
100315 case RowPartitionType.ROW_SPLITS:
100316 if (rowPartitionTensor.length - 1 > parentOutputIndex.length) {
100317 throw new Error("Row partition size is greater than output size: ".concat(rowPartitionTensor.length - 1, " > ").concat(parentOutputIndex.length));
100318 }
100319 return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
100320 default:
100321 throw new Error("Unsupported partition type: ".concat(RowPartitionType[partitionType]));
100322 }
100323 }
100324 }, {
100325 key: "getFirstDimensionSize",
100326 value: function getFirstDimensionSize() {
100327 var firstPartitionTensor = this.rowPartitionValues[0];
100328 if (this.rowPartitionTypes.length === 0) {
100329 throw new Error('No row_partition_types given.');
100330 }
100331 var firstPartitionType = this.rowPartitionTypes[0];
100332 switch (firstPartitionType) {
100333 case RowPartitionType.FIRST_DIM_SIZE:
100334 return firstPartitionTensor[0];
100335 case RowPartitionType.VALUE_ROWIDS:
100336 throw new Error('Cannot handle VALUE_ROWIDS in first dimension.');
100337 case RowPartitionType.ROW_SPLITS:
100338 return this.rowPartitionValuesShapes[0][0] - 1;
100339 default:
100340 throw new Error("Cannot handle type ".concat(RowPartitionType[firstPartitionType]));
100341 }
100342 }
100343 }, {
100344 key: "compute",
100345 value: function compute() {
100346 var firstPartitionTensor = this.rowPartitionValues[0];
100347 if (firstPartitionTensor.length <= 0) {
100348 throw new Error('Invalid first partition input. ' + 'Tensor requires at least one element.');
100349 }
100350 var firstDimension = this.getFirstDimensionSize();
100351 var outputSize = this.calculateOutputSize(firstDimension);
100352 var multiplier = new Array(this.raggedRank + 1);
100353 multiplier[multiplier.length - 1] = 1;
100354 for (var i = multiplier.length - 2; i >= 0; --i) {
100355 multiplier[i] = multiplier[i + 1] * outputSize[i + 1];
100356 }
100357 // Full size of the tensor.
100358 var outputShape = makeShape(outputSize, false);
100359 var outputTensor = getArrayFromDType(this.valuesDType, sizeFromShape(outputShape));
100360 var fullSize = multiplier[0] * outputSize[0];
100361 if (fullSize > 0) {
100362 var outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]);
100363 for (var _i2 = 1; _i2 <= this.raggedRank; ++_i2) {
100364 var newOutputIndex = this.calculateOutputIndex(_i2 - 1, outputIndex, multiplier[_i2], outputSize[_i2]);
100365 outputIndex = newOutputIndex;
100366 }
100367 this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape);
100368 }
100369 return [outputShape, outputTensor];
100370 }
100371 }, {
100372 key: "setOutput",
100373 value: function setOutput(raggedRank, outputIndex, outputTensor, outputShape) {
100374 if (outputTensor.length === 0) {
100375 return;
100376 }
100377 var valuesBase = this.values;
100378 var outputBase = outputTensor;
100379 var elementShape = outputShape.slice();
100380 elementShape = elementShape.slice(raggedRank + 1);
100381 var valueElementSize = sizeFromShape(elementShape);
100382 var outputIndexSize = outputIndex.length;
100383 // Broadcast the default value to value_element_size. (We can skip this
100384 // if defaultValueTensor.size == 1, since we use fill when that's true.)
100385 var defaultValue = this.defaultValue;
100386 if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) {
100387 var srcShape = this.defaultValueShape;
100388 tidy(function () {
100389 var defaultValueTensor = reshape$3(defaultValue, srcShape);
100390 var bCastDefault = broadcastTo(defaultValueTensor, elementShape);
100391 defaultValue = bCastDefault.dataSync();
100392 });
100393 }
100394 // Loop through the outputIndex array, finding contiguous regions that
100395 // should be copied. Once we find the end of a contiguous region, copy it
100396 // and add any necessary padding (with defaultValue).
100397 var srcStart = 0; // Start of contiguous region (in values)
100398 var dstStart = 0; // Destination for contiguous region (in output)
100399 var dstEnd = 0; // Destination for contiguous region (in output)
100400 for (var srcI = 0; srcI <= outputIndexSize; ++srcI) {
100401 // dstI is the destination where the value at srcI should be copied.
100402 var dstI = srcI < outputIndexSize ? outputIndex[srcI] : -1;
100403 // If we're still in a contiguous region, then update dstEnd go to the
100404 // next srcI.
100405 if (dstI === dstEnd) {
100406 ++dstEnd;
100407 continue;
100408 }
100409 // We found the end of contiguous region. This can be because we found
100410 // a gap (dstI > dstEnd), or a source value that shouldn't be copied
100411 // because it's out-of-bounds (dstI == -1), or the end of the tensor
100412 // (dstI === -1).
100413 if (dstStart < dstEnd) {
100414 // Copy the contiguous region.
100415 var src = valuesBase.subarray(srcStart * valueElementSize);
100416 var dst = outputBase.subarray(dstStart * valueElementSize);
100417 var nVals = (dstEnd - dstStart) * valueElementSize;
100418 copyArray(dst, src, nVals);
100419 }
100420 // Add any necessary padding (w/ defaultValue).
100421 if (srcI >= outputIndexSize) {
100422 // We reached the end of values: pad to the end of output.
100423 var outputSize = outputTensor.length;
100424 dstI = Math.floor(outputSize / valueElementSize);
100425 }
100426 if (dstI > dstEnd) {
100427 if (this.defaultValue.length === 1) {
100428 outputBase.subarray(dstEnd * valueElementSize, dstI * valueElementSize).fill(this.defaultValue[0]);
100429 dstEnd = dstI;
100430 } else {
100431 while (dstI > dstEnd) {
100432 var _dst = outputBase.slice(dstEnd * valueElementSize);
100433 copyArray(_dst, defaultValue, valueElementSize);
100434 ++dstEnd;
100435 }
100436 }
100437 }
100438 // Update indices.
100439 if (dstI < 0) {
100440 // srcI should be skipped -- leave it out of the contiguous region.
100441 srcStart = srcI + 1;
100442 dstStart = dstEnd;
100443 } else {
100444 // srcI should be copied -- include it in the contiguous region.
100445 srcStart = srcI;
100446 dstStart = dstEnd;
100447 dstEnd = dstStart + 1;
100448 }
100449 }
100450 }
100451 }], [{
100452 key: "getMaxWidthRowSplit",
100453 value: function getMaxWidthRowSplit(rowSplit) {
100454 var tensorLength = rowSplit.length;
100455 if (tensorLength === 0 || tensorLength === 1) {
100456 return 0;
100457 }
100458 var maxWidth = 0;
100459 for (var i = 0; i < tensorLength - 1; ++i) {
100460 var currentWidth = rowSplit[i + 1] - rowSplit[i];
100461 if (currentWidth > maxWidth) {
100462 maxWidth = currentWidth;
100463 }
100464 }
100465 return maxWidth;
100466 }
100467 }, {
100468 key: "getMaxWidthValueRowID",
100469 value: function getMaxWidthValueRowID(valueRowIds) {
100470 var indexLength = valueRowIds.length;
100471 if (indexLength === 0) {
100472 return 0;
100473 }
100474 var firstEqualIndex = 0;
100475 var firstEqualIndexValue = valueRowIds[0];
100476 var maxWidth = 0;
100477 for (var i = 1; i < indexLength; ++i) {
100478 var value = valueRowIds[i];
100479 if (value !== firstEqualIndexValue) {
100480 firstEqualIndexValue = value;
100481 maxWidth = Math.max(i - firstEqualIndex, maxWidth);
100482 firstEqualIndex = i;
100483 }
100484 }
100485 return Math.max(indexLength - firstEqualIndex, maxWidth);
100486 }
100487 }]);
100488 return RaggedTensorToTensorOp;
100489 }();
100490 function copyArray(dst, src, size) {
100491 for (var i = 0; i < size; i++) {
100492 dst[i] = src[i];
100493 }
100494 }
100495 function makeShape(shape, isPartial) {
100496 var out = [];
100497 var _iterator = _createForOfIteratorHelper(shape),
100498 _step;
100499 try {
100500 for (_iterator.s(); !(_step = _iterator.n()).done;) {
100501 var dim = _step.value;
100502 if (dim < 0) {
100503 if (!isPartial) {
100504 throw new Error("Dimension ".concat(dim, " must be >= 0"));
100505 }
100506 if (dim < -1) {
100507 throw new Error("Dimension ".concat(dim, " must be >= -1"));
100508 }
100509 dim = -1;
100510 }
100511 out.push(dim);
100512 }
100513 } catch (err) {
100514 _iterator.e(err);
100515 } finally {
100516 _iterator.f();
100517 }
100518 return out;
100519 }
100520 function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) {
100521 return new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes).compute();
100522 }
100523
100524 /**
100525 * @license
100526 * Copyright 2020 Google LLC. All Rights Reserved.
100527 * Licensed under the Apache License, Version 2.0 (the "License");
100528 * you may not use this file except in compliance with the License.
100529 * You may obtain a copy of the License at
100530 *
100531 * http://www.apache.org/licenses/LICENSE-2.0
100532 *
100533 * Unless required by applicable law or agreed to in writing, software
100534 * distributed under the License is distributed on an "AS IS" BASIS,
100535 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100536 * See the License for the specific language governing permissions and
100537 * limitations under the License.
100538 * =============================================================================
100539 */
/**
 * Materializes the values of a 1-D range: `start`, `start + step`, ... up to
 * but excluding `stop`.
 *
 * @param {number} start First value of the range.
 * @param {number} stop Exclusive end of the range.
 * @param {number} step Distance between consecutive values. When
 *     `stop < start` and `step` is exactly 1, it is treated as -1, so a
 *     descending range is produced without an explicit negative step.
 * @param {string} dtype Element type, forwarded to `makeZerosTypedArray`.
 * @returns A typed array holding the range. It is empty when
 *     `start === stop`, or when `step` points away from `stop` (a negative
 *     step with `start < stop`, or any positive step other than exactly 1
 *     with `stop < start`).
 */
function rangeImpl(start, stop, step, dtype) {
  var sameStartStop = start === stop;
  var increasingRangeNegativeStep = start < stop && step < 0;
  // On a descending range, every positive step except exactly 1 (auto-flipped
  // below) moves away from `stop`. This includes fractional steps such as 0.5.
  var decreasingRangePositiveStep = stop < start && step > 0 && step !== 1;
  if (sameStartStop || increasingRangeNegativeStep || decreasingRangePositiveStep) {
    return makeZerosTypedArray(0, dtype);
  }
  // Auto-adjust the step's sign if it was not set (or was set to 1). This must
  // happen before counting elements: for 5 -> 0.5, counting with +1 gives
  // |ceil(-4.5)| = 4 and would drop the final value.
  var delta = stop < start && step === 1 ? -1 : step;
  var numElements = Math.abs(Math.ceil((stop - start) / delta));
  var values = makeZerosTypedArray(numElements, dtype);
  values[0] = start;
  for (var i = 1; i < values.length; i++) {
    values[i] = values[i - 1] + delta;
  }
  return values;
}
100560
100561 /**
100562 * @license
100563 * Copyright 2020 Google LLC. All Rights Reserved.
100564 * Licensed under the Apache License, Version 2.0 (the License);
100565 * you may not use this file except in compliance with the License.
100566 * You may obtain a copy of the License at
100567 *
100568 * http://www.apache.org/licenses/LICENSE-2.0
100569 *
100570 * Unless required by applicable law or agreed to in writing, software
100571 * distributed under the License is distributed on an AS IS BASIS,
100572 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100573 * See the License for the specific language governing permissions and
100574 * limitations under the License.
100575 * =============================================================================
100576 */
/**
 * CPU implementation of `Rsqrt`: the reciprocal square root, 1 / sqrt(x),
 * applied to every element through `createSimpleUnaryImpl`.
 */
var rsqrtImpl = createSimpleUnaryImpl(function (value) {
  var root = Math.sqrt(value);
  return 1 / root;
});
// Kernel function built on the shared element-wise implementation above,
// plus its config for the 'cpu' backend.
var rsqrt$1 = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
var rsqrtConfig$1 = {
  kernelName: Rsqrt,
  backendName: 'cpu',
  kernelFunc: rsqrt$1
};
100586
100587 /**
100588 * @license
100589 * Copyright 2020 Google LLC. All Rights Reserved.
100590 * Licensed under the Apache License, Version 2.0 (the "License");
100591 * you may not use this file except in compliance with the License.
100592 * You may obtain a copy of the License at
100593 *
100594 * http://www.apache.org/licenses/LICENSE-2.0
100595 *
100596 * Unless required by applicable law or agreed to in writing, software
100597 * distributed under the License is distributed on an "AS IS" BASIS,
100598 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100599 * See the License for the specific language governing permissions and
100600 * limitations under the License.
100601 * =============================================================================
100602 */
/**
 * Scatters slices of `updates` into an output buffer at the slice positions
 * selected by `indices`.
 *
 * @param indices Buffer whose `values` hold `numUpdates` groups of
 *     `sliceRank` index components.
 * @param updates Buffer whose `values` hold `sliceSize` values per group. A
 *     rank-0 `updates` writes its single value everywhere when
 *     `sumDupeIndices` is false.
 * @param shape Shape used for the empty result and in error messages.
 * @param outputSize Total number of output elements.
 * @param sliceSize Number of elements per slice.
 * @param numUpdates Number of index groups to apply.
 * @param sliceRank Number of components per index group.
 * @param strides Multipliers that flatten an index group into a slice number.
 * @param defaultValue A TensorBuffer that is written into directly and
 *     returned, or a string/number/boolean used to fill a fresh
 *     `[outputSize / sliceSize, sliceSize]` buffer (booleans stored as 1/0).
 * @param sumDupeIndices If true, updates are added with `+=`; otherwise each
 *     write overwrites the previous value.
 * @returns The output buffer.
 * @throws Error when an index group flattens outside
 *     `[0, outputSize / sliceSize)`.
 */
function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
  var numSlices = outputSize / sliceSize;
  var indexValues = indices.values;
  var updateValues = updates.values;
  if (outputSize === 0) {
    return buffer(shape, updates.dtype);
  }

  var outBuf;
  if (defaultValue instanceof TensorBuffer) {
    // Caller-provided buffer: written in place, never pre-filled.
    outBuf = defaultValue;
  } else {
    outBuf = buffer([numSlices, sliceSize], updates.dtype);
  }
  switch (typeof defaultValue) {
    case 'string':
    case 'number':
      outBuf.values.fill(defaultValue);
      break;
    case 'boolean':
      outBuf.values.fill(+defaultValue);
      break;
  }

  var outValues = outBuf.values;
  for (var u = 0; u < numUpdates; u++) {
    var coords = [];
    var slot = 0;
    for (var c = 0; c < sliceRank; c++) {
      var component = indexValues[u * sliceRank + c];
      coords.push(component);
      slot += component * strides[c];
    }
    if (slot < 0 || slot >= numSlices) {
      throw new Error("Invalid indices: ".concat(coords, " does not index into ").concat(shape));
    }
    var outBase = slot * sliceSize;
    var updBase = u * sliceSize;
    if (sumDupeIndices) {
      for (var k = 0; k < sliceSize; k++) {
        outValues[outBase + k] += updateValues[updBase + k];
      }
    } else {
      var scalarUpdate = updates.rank === 0;
      for (var m = 0; m < sliceSize; m++) {
        outValues[outBase + m] = scalarUpdate ? updateValues[0] : updateValues[updBase + m];
      }
    }
  }
  return outBuf;
}
100639
100640 /**
100641 * @license
100642 * Copyright 2020 Google LLC. All Rights Reserved.
100643 * Licensed under the Apache License, Version 2.0 (the License);
100644 * you may not use this file except in compliance with the License.
100645 * You may obtain a copy of the License at
100646 *
100647 * http://www.apache.org/licenses/LICENSE-2.0
100648 *
100649 * Unless required by applicable law or agreed to in writing, software
100650 * distributed under the License is distributed on an AS IS BASIS,
100651 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100652 * See the License for the specific language governing permissions and
100653 * limitations under the License.
100654 * =============================================================================
100655 */
/**
 * CPU implementation of `Sigmoid`: the logistic function 1 / (1 + e^(-x))
 * applied element-wise. `sigmoidImpl` is the shared element-wise
 * implementation; `sigmoid$1` is the kernel function.
 */
var sigmoidImpl;
var sigmoid$1;
(function () {
  // Single definition of the per-element operation, shared by both exports.
  function logistic(x) {
    var denominator = 1 + Math.exp(-x);
    return 1 / denominator;
  }
  sigmoidImpl = createSimpleUnaryImpl(logistic);
  sigmoid$1 = unaryKernelFunc$1(Sigmoid$1, logistic);
})();
var sigmoidConfig$1 = {
  kernelName: Sigmoid$1,
  backendName: 'cpu',
  kernelFunc: sigmoid$1
};
100667
/**
 * Extracts the slice of `vals` (laid out with `shape`) that starts at `begin`
 * and has extent `size`.
 *
 * If `isSliceContinous` reports the slice as one contiguous run, the result
 * is taken directly from `vals`. String data uses `slice`. Other data uses
 * `subarray`, which for a typed array is a view sharing memory with `vals`.
 * Otherwise every output element is gathered through TensorBuffers; string
 * data is decoded before gathering and re-encoded afterwards.
 *
 * @returns The sliced values (re-encoded bytes when `dtype === 'string'`).
 */
function sliceImpl(vals, begin, size, shape, dtype) {
  var isString = dtype === 'string';
  var contiguous = isSliceContinous(shape, begin, size);
  var length = sizeFromShape(size);
  var xStrides = computeStrides(shape);
  if (contiguous) {
    var start = computeFlatOffset(begin, xStrides);
    var stop = start + length;
    return isString ? vals.slice(start, stop) : vals.subarray(start, stop);
  }

  var inBuf = buffer(shape, dtype, isString ? fromUint8ToStringArray(vals) : vals);
  var outBuf = buffer(size, dtype);
  for (var flat = 0; flat < outBuf.size; ++flat) {
    var outLoc = outBuf.indexToLoc(flat);
    var inLoc = [];
    for (var axis = 0; axis < outLoc.length; axis++) {
      inLoc.push(outLoc[axis] + begin[axis]);
    }
    var element = inBuf.get.apply(inBuf, inLoc);
    outBuf.set.apply(outBuf, [element].concat(outLoc));
  }
  return isString ? fromStringArrayToUint8(outBuf.values) : outBuf.values;
}
/**
 * CPU kernel for `Slice`.
 *
 * 1. Calls `assertNotComplex$1` on `inputs.x` (presumably rejecting complex
 *    inputs).
 * 2. Resolves `attrs.begin` / `attrs.size` with `parseSliceParams` and
 *    validates the result with `assertParamsValid`.
 * 3. Slices the backend values of `x` with `sliceImpl`.
 * 4. Returns a new tensor info whose shape is the resolved size and whose
 *    dtype matches `x`.
 */
function slice$1(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var x = args.inputs.x;
  assertNotComplex$1(x, 'slice');
  var resolved = parseSliceParams(x, attrs.begin, attrs.size);
  var resolvedBegin = resolved[0];
  var resolvedSize = resolved[1];
  assertParamsValid(x, resolvedBegin, resolvedSize);
  var xValues = backend.data.get(x.dataId).values;
  var sliced = sliceImpl(xValues, resolvedBegin, resolvedSize, x.shape, x.dtype);
  return backend.makeTensorInfo(resolvedSize, x.dtype, sliced);
}
// Kernel config pairing `slice$1` with the `Slice` kernel on the 'cpu' backend.
var sliceConfig$1 = {
  kernelName: Slice,
  backendName: 'cpu',
  kernelFunc: slice$1
};
100716
100717 /**
100718 * @license
100719 * Copyright 2021 Google LLC. All Rights Reserved.
100720 * Licensed under the Apache License, Version 2.0 (the "License");
100721 * you may not use this file except in compliance with the License.
100722 * You may obtain a copy of the License at
100723 *
100724 * http://www.apache.org/licenses/LICENSE-2.0
100725 *
100726 * Unless required by applicable law or agreed to in writing, software
100727 * distributed under the License is distributed on an "AS IS" BASIS,
100728 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100729 * See the License for the specific language governing permissions and
100730 * limitations under the License.
100731 * =============================================================================
100732 */
/**
 * Fills empty rows of a 2-D sparse tensor given in COO form.
 *
 * The row of entry `i` is `indices[i * rank]`, with `rank = indicesShape[1]`.
 * Every row in `[0, denseShape[0])` with no entry receives one entry at index
 * `[row, 0, ..., 0]` holding `defaultValue`.
 *
 * @param indices Flat index data laid out as `[indicesShape[0], rank]`.
 * @param indicesShape Shape of `indices`.
 * @param indicesDType Dtype for a newly allocated index array.
 * @param values One value per entry.
 * @param valuesDType Dtype for a newly allocated value array.
 * @param denseShape Dense shape; only `denseShape[0]` (row count) is read.
 * @param defaultValue Value stored for each inserted entry.
 * @returns `[outputIndices, [outputCount, rank], outputValues,
 *     emptyRowIndicator, reverseIndexMap]`.
 *     - `emptyRowIndicator[r]` is true when row `r` had no entry.
 *     - `reverseIndexMap[i]` is the output position of input entry `i`.
 *     - When no row is empty and the rows are already non-decreasing,
 *       `indices` and `values` are returned as-is (not copied). Otherwise the
 *       entries are grouped by row, keeping input order within each row.
 * @throws Error when `denseShape[0]` is 0 but entries exist, or when a row
 *     index is negative or not less than `denseShape[0]`.
 */
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
  var indicesCount = indicesShape[0];
  var denseRows = denseShape[0];
  var emptyRowIndicator = new Array(denseRows);
  var reverseIndexMap = new Array(indicesCount);
  var rank = indicesShape[1];

  if (denseRows === 0) {
    if (indicesCount !== 0) {
      throw new Error(getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount));
    }
    return [getArrayFromDType(indicesDType, 0), [0, rank], getArrayFromDType(valuesDType, 0), emptyRowIndicator, reverseIndexMap];
  }

  // Pass 1: validate row indices, count entries per row, and track whether
  // the rows already appear in non-decreasing order.
  var rowEnds = new Array(denseRows).fill(0);
  var rowsAreOrdered = true;
  var previousRow = 0;
  for (var entry = 0; entry < indicesCount; ++entry) {
    var entryRow = indices[entry * rank];
    if (entryRow < 0) {
      throw new Error(getSparseFillEmptyRowsNegativeIndexErrorMessage(entry, entryRow));
    }
    if (entryRow >= denseRows) {
      throw new Error(getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(entry, entryRow, denseRows));
    }
    ++rowEnds[entryRow];
    if (!(entryRow >= previousRow)) {
      rowsAreOrdered = false;
    }
    previousRow = entryRow;
  }

  // Pass 2: turn per-row counts into exclusive end offsets in the filled
  // output, where every row occupies at least one slot.
  var allRowsFull = true;
  var runningTotal = 0;
  for (var row = 0; row < denseRows; ++row) {
    var isEmpty = rowEnds[row] === 0;
    emptyRowIndicator[row] = isEmpty;
    if (isEmpty) {
      allRowsFull = false;
    }
    runningTotal += Math.max(rowEnds[row], 1);
    rowEnds[row] = runningTotal;
  }

  // Fast path: nothing to insert and nothing to reorder.
  if (allRowsFull && rowsAreOrdered) {
    for (var k = 0; k < indicesCount; ++k) {
      reverseIndexMap[k] = k;
    }
    return [indices, [indicesCount, rank], values, emptyRowIndicator, reverseIndexMap];
  }

  var rowStart = function (r) {
    return r === 0 ? 0 : rowEnds[r - 1];
  };
  var fullIndicesCount = rowEnds[denseRows - 1];
  var outputIndices = getArrayFromDType(indicesDType, fullIndicesCount * rank);
  var outputValues = getArrayFromDType(valuesDType, fullIndicesCount);
  var filledCount = new Array(denseRows).fill(0);

  // Copy existing entries into their row's slots, in input order.
  for (var src = 0; src < indicesCount; ++src) {
    var srcRow = indices[src * rank];
    var dst = rowStart(srcRow) + filledCount[srcRow];
    filledCount[srcRow]++;
    for (var d = 0; d < rank; ++d) {
      outputIndices[dst * rank + d] = indices[src * rank + d];
    }
    outputValues[dst] = values[src];
    // Recorded so results can be mapped back to input entries (backprop).
    reverseIndexMap[src] = dst;
  }

  // Give each row that received nothing a single entry [row, 0, ..., 0].
  for (var emptyRow = 0; emptyRow < denseRows; ++emptyRow) {
    if (filledCount[emptyRow] !== 0) {
      continue;
    }
    var slot = rowStart(emptyRow);
    outputIndices[slot * rank] = emptyRow;
    for (var col = 1; col < rank; ++col) {
      outputIndices[slot * rank + col] = 0;
    }
    outputValues[slot] = defaultValue;
  }
  return [outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator, reverseIndexMap];
}
100827
100828 /**
100829 * @license
100830 * Copyright 2021 Google LLC. All Rights Reserved.
100831 * Licensed under the Apache License, Version 2.0 (the "License");
100832 * you may not use this file except in compliance with the License.
100833 * You may obtain a copy of the License at
100834 *
100835 * http://www.apache.org/licenses/LICENSE-2.0
100836 *
100837 * Unless required by applicable law or agreed to in writing, software
100838 * distributed under the License is distributed on an "AS IS" BASIS,
100839 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100840 * See the License for the specific language governing permissions and
100841 * limitations under the License.
100842 * =============================================================================
100843 */
/**
 * Reshapes a sparse tensor given in COO form.
 *
 * @param inputIndices Flat index data read as `[nnz, inputShape.length]`
 *     (row-major), where `nnz = inputIndicesShape[0]`.
 * @param inputIndicesShape Shape of `inputIndices`; only element 0 is read.
 * @param inputDType Dtype for the new index array (via `getArrayFromDType`).
 * @param inputShape Dense shape of the input sparse tensor.
 * @param targetShape Requested dense shape. At most one entry may be -1, in
 *     which case it is inferred from the dense size and the other entries.
 * @returns `[newIndices, [nnz, outputRank], outputShape]`: each entry's
 *     position re-expressed in the resolved `outputShape`.
 * @throws Error when more than one -1 is given; when another entry is
 *     negative; when a -1 must be inferred but the known entries multiply to
 *     <= 0; when that product does not evenly divide the dense size; or when
 *     the resolved shape's size differs from the input's dense size.
 */
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
  // Row-major strides: strides[k] is the product of dims[k + 1 ..].
  function rowMajorStrides(dims) {
    var strides = [];
    var stride = 1;
    for (var k = dims.length - 1; k >= 0; --k) {
      strides[k] = stride;
      stride *= dims[k];
    }
    return strides;
  }

  var denseSize = sizeFromShape(inputShape);
  var nnz = inputIndicesShape[0];
  var outputRank = targetShape.length;

  // Resolve the output shape: multiply the known dimensions and remember
  // where the single unknown (-1) dimension is, using 1 as a placeholder.
  var outputShape = [];
  var product = 1;
  var unknownIndex = -1;
  for (var axis = 0; axis < outputRank; ++axis) {
    var dim = targetShape[axis];
    if (dim !== -1) {
      if (dim < 0) {
        throw new Error(getSparseReshapeNegativeOutputDimErrorMessage(axis, dim));
      }
      product *= dim;
      outputShape.push(dim);
      continue;
    }
    if (unknownIndex !== -1) {
      throw new Error(getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, axis));
    }
    unknownIndex = axis;
    outputShape.push(1);
  }
  if (unknownIndex !== -1) {
    if (product <= 0) {
      throw new Error(getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
    }
    var missing = Math.trunc(denseSize / product);
    if (product * missing !== denseSize) {
      throw new Error(getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
    }
    outputShape[unknownIndex] = missing;
  }
  if (sizeFromShape(outputShape) !== denseSize) {
    throw new Error(getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
  }

  // Re-express each entry: input coordinates -> linear offset -> output
  // coordinates.
  var inputRank = inputShape.length;
  var inputStrides = rowMajorStrides(inputShape);
  var outputStrides = rowMajorStrides(outputShape);
  var newIndices = getArrayFromDType(inputDType, nnz * outputRank);
  for (var row = 0; row < nnz; ++row) {
    var linear = 0;
    var inBase = row * inputRank;
    for (var a = 0; a < inputRank; ++a) {
      linear += inputIndices[inBase + a] * inputStrides[a];
    }
    var outBase = row * outputRank;
    for (var b = 0; b < outputRank; ++b) {
      newIndices[outBase + b] = Math.trunc(linear / outputStrides[b]);
      linear %= outputStrides[b];
    }
  }
  return [newIndices, [nnz, outputRank], outputShape];
}
100913
100914 /**
100915 * @license
100916 * Copyright 2021 Google LLC. All Rights Reserved.
100917 * Licensed under the Apache License, Version 2.0 (the "License");
100918 * you may not use this file except in compliance with the License.
100919 * You may obtain a copy of the License at
100920 *
100921 * http://www.apache.org/licenses/LICENSE-2.0
100922 *
100923 * Unless required by applicable law or agreed to in writing, software
100924 * distributed under the License is distributed on an "AS IS" BASIS,
100925 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100926 * See the License for the specific language governing permissions and
100927 * limitations under the License.
100928 * =============================================================================
100929 */
/**
 * Sparse segment sum/mean.
 *
 * `input` is viewed as a 2-D `[inputShape[0], input.length / inputShape[0]]`
 * matrix. For each k, input row `indices[k]` is accumulated into output row
 * `segmentIds[k]`.
 *
 * Optional trailing arguments:
 *   - arguments[5] `isMean` (default false): divide each segment by its size.
 *   - arguments[6] `defaultValue` (default 0): value for output rows that no
 *     segment covers.
 *
 * `segmentIds` are assumed sorted. The output has `segmentIds[last] + 1` rows,
 * and runs of equal ids must be strictly increasing from one run to the next.
 *
 * @returns `[output, outputShape]`, where `outputShape` is `inputShape` with
 *     dimension 0 replaced by the output row count.
 * @throws Error for a negative output row count, non-increasing segment ids,
 *     a segment id outside `[0, outputRows)`, or an index outside
 *     `[0, inputShape[0])`.
 */
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds) {
  var isMean = arguments[5] === undefined ? false : arguments[5];
  var defaultValue = arguments[6] === undefined ? 0 : arguments[6];
  var numIndices = indices.length;
  var numRows = inputShape[0];
  var numCol = input.length / numRows;

  // Segment ids are sorted, so the last one determines the output row count.
  var outputRows = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
  if (outputRows < 0) {
    throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
  }
  var outputShape = inputShape.slice();
  outputShape[0] = outputRows;
  var outputLength = 1;
  for (var axis = 0; axis < outputShape.length; axis++) {
    outputLength *= outputShape[axis];
  }

  // The output starts zero-filled; rows not covered by any segment are set to
  // `defaultValue` explicitly.
  var output = getArrayFromDType(inputDType, outputLength);
  if (numIndices === 0) {
    if (outputRows > 0) {
      output.fill(defaultValue);
    }
    return [output, outputShape];
  }
  if (outputRows <= 0) {
    throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
  }

  // Walk runs of equal segment ids; run [segStart, segEnd) reduces into
  // output row `segId`.
  var firstUnsetRow = 0;
  var segStart = 0;
  while (segStart < numIndices) {
    var segId = segmentIds[segStart];
    var segEnd = segStart + 1;
    while (segEnd < numIndices && segmentIds[segEnd] === segId) {
      ++segEnd;
    }
    // The next run must have a strictly larger id.
    if (segEnd < numIndices && segId >= segmentIds[segEnd]) {
      throw new Error(getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
    }
    if (segId < 0 || segId >= outputRows) {
      throw new Error(getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segId, outputRows));
    }
    // Rows skipped since the previous segment get the default value.
    if (segId > firstUnsetRow) {
      output.fill(defaultValue, firstUnsetRow * numCol, segId * numCol);
    }
    var outBase = segId * numCol;
    for (var pos = segStart; pos < segEnd; ++pos) {
      var srcRow = indices[pos];
      if (srcRow < 0 || srcRow >= numRows) {
        throw new Error(getSparseSegmentReductionIndicesOutOfRangeErrorMessage(pos, indices[pos], numRows));
      }
      var srcBase = srcRow * numCol;
      for (var c = 0; c < numCol; c++) {
        output[outBase + c] += input[srcBase + c];
      }
    }
    if (isMean) {
      var segSize = segEnd - segStart;
      for (var m = 0; m < numCol; m++) {
        output[outBase + m] /= segSize;
      }
    }
    firstUnsetRow = segId + 1;
    segStart = segEnd;
  }

  // Rows after the last segment get the default value.
  if (firstUnsetRow < outputRows) {
    output.fill(defaultValue, firstUnsetRow * numCol, outputRows * numCol);
  }
  return [output, outputShape];
}
101017
101018 /**
101019 * @license
101020 * Copyright 2020 Google LLC. All Rights Reserved.
101021 * Licensed under the Apache License, Version 2.0 (the License);
101022 * you may not use this file except in compliance with the License.
101023 * You may obtain a copy of the License at
101024 *
101025 * http://www.apache.org/licenses/LICENSE-2.0
101026 *
101027 * Unless required by applicable law or agreed to in writing, software
101028 * distributed under the License is distributed on an AS IS BASIS,
101029 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101030 * See the License for the specific language governing permissions and
101031 * limitations under the License.
101032 * =============================================================================
101033 */
/**
 * CPU implementation of `Sqrt`: element-wise square root.
 *
 * `Math.sqrt` is passed directly as the per-element operation. It reads only
 * its first argument, so any extra arguments passed by the helpers are
 * ignored.
 */
var sqrtImpl = createSimpleUnaryImpl(Math.sqrt);
var sqrt$1 = unaryKernelFunc$1(Sqrt, Math.sqrt);
var sqrtConfig$1 = {
  kernelName: Sqrt,
  backendName: 'cpu',
  kernelFunc: sqrt$1
};
101045
101046 /**
101047 * @license
101048 * Copyright 2020 Google LLC. All Rights Reserved.
101049 * Licensed under the Apache License, Version 2.0 (the "License");
101050 * you may not use this file except in compliance with the License.
101051 * You may obtain a copy of the License at
101052 *
101053 * http://www.apache.org/licenses/LICENSE-2.0
101054 *
101055 * Unless required by applicable law or agreed to in writing, software
101056 * distributed under the License is distributed on an "AS IS" BASIS,
101057 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101058 * See the License for the specific language governing permissions and
101059 * limitations under the License.
101060 * =============================================================================
101061 */
/**
 * CPU implementation of `SquaredDifference`: (a - b)^2 for each pair of
 * elements handed to the operation by `createSimpleBinaryKernelImpl`.
 */
var squaredDifferenceImpl = createSimpleBinaryKernelImpl(function (a, b) {
  return (a - b) * (a - b);
});
var squaredDifference$1 = binaryKernelFunc$1(SquaredDifference, squaredDifferenceImpl);
var squaredDifferenceConfig$1 = {
  kernelName: SquaredDifference,
  backendName: 'cpu',
  kernelFunc: squaredDifference$1
};
101072
101073 /**
101074 * @license
101075 * Copyright 2023 Google LLC.
101076 * Licensed under the Apache License, Version 2.0 (the "License");
101077 * you may not use this file except in compliance with the License.
101078 * You may obtain a copy of the License at
101079 *
101080 * http://www.apache.org/licenses/LICENSE-2.0
101081 *
101082 * Unless required by applicable law or agreed to in writing, software
101083 * distributed under the License is distributed on an "AS IS" BASIS,
101084 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101085 * See the License for the specific language governing permissions and
101086 * limitations under the License.
101087 * =============================================================================
101088 */
/**
 * CPU implementation of `StaticRegexReplace` for string tensors.
 *
 * Each element `x` becomes `x.replace(regex, attrs.rewrite)`, where `regex` is
 * `new RegExp(attrs.pattern, attrs.replaceGlobal ? 'g' : '')`.
 *
 * The compiled RegExp is cached and reused while the pattern and flags stay
 * the same, instead of being rebuilt for every element. Reuse is safe:
 * `String.prototype.replace` resets `lastIndex` to 0 before a global
 * replacement, and non-global, non-sticky regexes ignore `lastIndex`.
 *
 * NOTE(review): `attrs.pattern` is compiled as a regular expression. If it
 * can come from untrusted model files, pathological patterns may be slow
 * (ReDoS).
 */
var staticRegexReplaceImpl = function () {
  var cachedPattern;
  var cachedFlags;
  var cachedRegex = null;
  return createSimpleUnaryImpl(function (x, attrs) {
    var pattern = attrs.pattern;
    var flags = attrs.replaceGlobal ? 'g' : '';
    if (cachedRegex === null || pattern !== cachedPattern || flags !== cachedFlags) {
      // Compile before updating the cache, so an invalid pattern throws
      // without replacing the previous entry.
      var regex = new RegExp(pattern, flags);
      cachedRegex = regex;
      cachedPattern = pattern;
      cachedFlags = flags;
    }
    return x.replace(cachedRegex, attrs.rewrite);
  });
}();
var staticRegexReplace$1 = unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl);
var staticRegexReplaceConfig$1 = {
  kernelName: StaticRegexReplace,
  backendName: 'cpu',
  kernelFunc: staticRegexReplace$1
};
101102
/**
 * Gathers a strided slice of `xBuf`.
 *
 * The output element at location `loc` is read from `xBuf` at
 * `loc[axis] * strides[axis] + begin[axis]` on every axis.
 *
 * @returns A new buffer of shape `outShape` with `xBuf`'s dtype.
 */
function stridedSliceImpl(outShape, xBuf, strides, begin) {
  var result = buffer(outShape, xBuf.dtype);
  var total = result.size;
  for (var flat = 0; flat < total; flat++) {
    var outLoc = result.indexToLoc(flat);
    var srcLoc = [];
    for (var axis = 0; axis < outLoc.length; axis++) {
      srcLoc.push(outLoc[axis] * strides[axis] + begin[axis]);
    }
    var element = xBuf.get.apply(xBuf, srcLoc);
    result.set.apply(result, [element].concat(outLoc));
  }
  return result;
}
101115
/**
 * Creates n-grams from ragged string data.
 *
 * The constructor stores the operation's attributes, and `compute` can then
 * be applied to different ragged inputs:
 *   - `separator`, `leftPad` and `rightPad`, converted with `encodeString`;
 *   - `nGramWidths`, the list of widths to produce;
 *   - `padWidth`, where a negative value means dynamic padding;
 *   - whether short sequences are preserved.
 *
 * Ragged input is `data` plus `splits`. `data` is a flat list with one token
 * per element; the tokens are iterated with `forEach` and copied byte-wise
 * into `Uint8Array`s. Row i covers `data[splits[i]]` up to
 * `data[splits[i + 1] - 1]`. Each produced n-gram is a new `Uint8Array` of
 * pad/token bytes joined by the separator bytes.
 */
var StringNGramsOp = /*#__PURE__*/function () {
  function StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
    _classCallCheck(this, StringNGramsOp);
    this.separator = encodeString(separator);
    this.nGramWidths = nGramWidths;
    this.leftPad = encodeString(leftPad);
    this.rightPad = encodeString(rightPad);
    this.padWidth = padWidth;
    this.preserveShort = preserveShortSequences;
  }
  _createClass(StringNGramsOp, [{
    key: "getPadWidth",
    value: function getPadWidth(nGramWidth) {
      // A negative padWidth requests dynamic padding of nGramWidth - 1. In
      // every case the padding is capped at nGramWidth - 1.
      var maxPad = nGramWidth - 1;
      var requested = this.padWidth < 0 ? maxPad : this.padWidth;
      return Math.min(requested, maxPad);
    }
  }, {
    key: "getNumNGrams",
    value: function getNumNGrams(length, nGramWidth) {
      // Number of width-`nGramWidth` windows over the row padded on both sides.
      var paddedLength = length + 2 * this.getPadWidth(nGramWidth);
      return Math.max(0, paddedLength - nGramWidth + 1);
    }
  }, {
    key: "createNGrams",
    value: function createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
      // `append` copies bytes into the n-gram currently being built.
      var nGram;
      var cursor = 0;
      var append = function (bytes) {
        bytes.forEach(function (value) {
          nGram[cursor++] = value;
        });
      };
      for (var gram = 0; gram < numNGrams; ++gram) {
        var padWidth = this.getPadWidth(nGramWidth);
        var leftPadding = Math.max(0, padWidth - gram);
        var rightPadding = Math.max(0, padWidth - (numNGrams - (gram + 1)));
        var numTokens = nGramWidth - (leftPadding + rightPadding);
        var firstToken = splitIndex + (leftPadding > 0 ? 0 : gram - padWidth);

        // Exact byte length: pads, tokens, and one separator between each
        // adjacent pair of pieces.
        var byteLength = leftPadding * this.leftPad.length;
        for (var t = 0; t < numTokens; ++t) {
          byteLength += data[firstToken + t].length;
        }
        byteLength += rightPadding * this.rightPad.length;
        byteLength += (leftPadding + rightPadding + numTokens - 1) * this.separator.length;

        nGram = new Uint8Array(byteLength);
        output[outputStartIndex + gram] = nGram;
        cursor = 0;

        for (var lp = 0; lp < leftPadding; ++lp) {
          append(this.leftPad);
          append(this.separator);
        }
        // Only the first numTokens - 1 tokens are followed by a separator here.
        for (var tk = 0; tk < numTokens - 1; ++tk) {
          append(data[firstToken + tk]);
          append(this.separator);
        }
        if (numTokens > 0) {
          // Emit the last token, then pair each separator with the right pad
          // after it, so the n-gram ends with a token or a pad.
          append(data[firstToken + numTokens - 1]);
          for (var rp = 0; rp < rightPadding; ++rp) {
            append(this.separator);
            append(this.rightPad);
          }
        } else {
          // No tokens: the left-padding loop ended with a separator, so emit
          // pad/separator pairs and finish with a pad, not a separator.
          for (var rq = 0; rq < rightPadding - 1; ++rq) {
            append(this.rightPad);
            append(this.separator);
          }
          append(this.rightPad);
        }
      }
    }
  }, {
    key: "compute",
    value: function compute(data, splits) {
      var self = this;
      var inputDataSize = data.length;
      var splitsSize = splits.length;

      // When splits are given, they must start at 0, be non-decreasing, stay
      // within the data and end at the data size.
      if (splitsSize > 0) {
        var prevSplit = splits[0];
        if (prevSplit !== 0) {
          throw new Error("First split value must be 0, got ".concat(prevSplit));
        }
        for (var s = 1; s < splitsSize; ++s) {
          var current = splits[s];
          if (!(current >= prevSplit && current <= inputDataSize)) {
            throw new Error("Invalid split value ".concat(current, ", must be in [").concat(prevSplit, ", ").concat(inputDataSize, "]"));
          }
          prevSplit = current;
        }
        if (prevSplit !== inputDataSize) {
          throw new Error("Last split value must be data size. Expected ".concat(inputDataSize, ", got ").concat(prevSplit));
        }
      }

      var numBatchItems = splitsSize - 1;
      var nGramsSplits = getArrayFromDType('int32', splitsSize);
      // No data or no splits: return an empty ragged tensor.
      if (inputDataSize === 0 || splitsSize === 0) {
        for (var z = 0; z <= numBatchItems; ++z) {
          nGramsSplits[z] = 0;
        }
        return [new Array(inputDataSize), nGramsSplits];
      }

      // Pass 1: count the n-grams of each row to build the output splits.
      nGramsSplits[0] = 0;
      for (var row = 1; row <= numBatchItems; ++row) {
        var rowLength = splits[row] - splits[row - 1];
        var rowCount = this.nGramWidths.reduce(function (sum, width) {
          return sum + self.getNumNGrams(rowLength, width);
        }, 0);
        if (this.preserveShort && rowLength > 0 && rowCount === 0) {
          rowCount = 1;
        }
        nGramsSplits[row] = nGramsSplits[row - 1] + rowCount;
      }

      // Pass 2: write each row's n-grams, width by width.
      var nGrams = new Array(nGramsSplits[numBatchItems]);
      for (var item = 0; item < numBatchItems; ++item) {
        var splitIndex = splits[item];
        var outputStartIdx = nGramsSplits[item];
        this.nGramWidths.forEach(function (width) {
          var length = splits[item + 1] - splits[item];
          var count = self.getNumNGrams(length, width);
          self.createNGrams(data, splitIndex, nGrams, outputStartIdx, count, width);
          outputStartIdx += count;
        });
        // With preserveShort, a non-empty row that produced no n-grams gets a
        // single n-gram spanning the whole padded row. Empty rows are skipped.
        if (this.preserveShort && outputStartIdx === nGramsSplits[item]) {
          var dataLength = splits[item + 1] - splits[item];
          if (dataLength === 0) {
            continue;
          }
          // Dynamic padding always yields at least one n-gram, so the
          // configured padWidth is used as-is here.
          this.createNGrams(data, splitIndex, nGrams, outputStartIdx, 1, dataLength + 2 * this.padWidth);
        }
      }
      return [nGrams, nGramsSplits];
    }
  }]);
  return StringNGramsOp;
}();
101307 function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
101308 return new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences).compute(data, dataSplits);
101309 }
101310
101311 /**
101312 * @license
101313 * Copyright 2021 Google LLC. All Rights Reserved.
101314 * Licensed under the Apache License, Version 2.0 (the "License");
101315 * you may not use this file except in compliance with the License.
101316 * You may obtain a copy of the License at
101317 *
101318 * http://www.apache.org/licenses/LICENSE-2.0
101319 *
101320 * Unless required by applicable law or agreed to in writing, software
101321 * distributed under the License is distributed on an "AS IS" BASIS,
101322 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101323 * See the License for the specific language governing permissions and
101324 * limitations under the License.
101325 * =============================================================================
101326 */
101327 function split(str, delimiters, skipEmpty, result) {
101328 if (!str.length) {
101329 return;
101330 }
101331 // When the delimiter is empty, the input is split into individual characters.
101332 if (delimiters.length === 0) {
101333 for (var i = 0; i < str.length; ++i) {
101334 result.push(str.subarray(i, i + 1));
101335 }
101336 return;
101337 }
101338 // When there is one delimiter, the input is split only at that delimiter.
101339 if (delimiters.length === 1) {
101340 var delimiter = delimiters[0];
101341 var f = str.indexOf(delimiter);
101342 while (f !== -1) {
101343 var token = str.subarray(0, f);
101344 if (!skipEmpty || token.length !== 0) {
101345 result.push(token);
101346 }
101347 str = str.subarray(f + 1);
101348 f = str.indexOf(delimiter);
101349 }
101350 if (!skipEmpty || str.length !== 0) {
101351 result.push(str);
101352 }
101353 return;
101354 }
101355 // When there are multiple delimiters, the input is split at every instance
101356 // one of the delimiters appears.
101357 var tokenStart = 0;
101358 for (var _i = 0; _i < str.length + 1; _i++) {
101359 if (_i === str.length || delimiters.indexOf(str[_i]) !== -1) {
101360 var _token = str.subarray(tokenStart, _i);
101361 if (!skipEmpty || _token.length !== 0) {
101362 result.push(_token);
101363 }
101364 tokenStart = _i + 1;
101365 }
101366 }
101367 }
101368 function stringSplitImpl(input, delimiter, skipEmpty) {
101369 var batchSize = input.length;
101370 // Empty delimiter means split the input character by character.
101371 var tokens = [];
101372 var outputSize = 0;
101373 var maxNumEntries = 0;
101374 var numIndices = new Array(batchSize);
101375 for (var i = 0; i < batchSize; ++i) {
101376 var prevTokensLength = tokens.length;
101377 split(input[i], delimiter, skipEmpty, tokens);
101378 var nEntries = tokens.length - prevTokensLength;
101379 numIndices[i] = nEntries;
101380 outputSize += nEntries;
101381 maxNumEntries = Math.max(maxNumEntries, nEntries);
101382 }
101383 var indices = getArrayFromDType('int32', outputSize * 2);
101384 var values = new Array(outputSize);
101385 var shape = [batchSize, maxNumEntries];
101386 var c = 0;
101387 for (var _i2 = 0; _i2 < batchSize; ++_i2) {
101388 for (var j = 0; j < numIndices[_i2]; ++j) {
101389 // indices is a 2d tensor with shape of [outputSize, 2]
101390 indices[c * 2] = _i2;
101391 indices[c * 2 + 1] = j;
101392 values[c] = tokens[c];
101393 ++c;
101394 }
101395 }
101396 return [indices, values, shape];
101397 }
101398
101399 /**
101400 * @license
101401 * Copyright 2021 Google LLC. All Rights Reserved.
101402 * Licensed under the Apache License, Version 2.0 (the "License");
101403 * you may not use this file except in compliance with the License.
101404 * You may obtain a copy of the License at
101405 *
101406 * http://www.apache.org/licenses/LICENSE-2.0
101407 *
101408 * Unless required by applicable law or agreed to in writing, software
101409 * distributed under the License is distributed on an "AS IS" BASIS,
101410 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101411 * See the License for the specific language governing permissions and
101412 * limitations under the License.
101413 * =============================================================================
101414 */
101415 function stringToHashBucketFastImpl(input, numBuckets) {
101416 var output = getArrayFromDType('int32', input.length);
101417 for (var i = 0; i < input.length; ++i) {
101418 output[i] = fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned();
101419 }
101420 return output;
101421 }
101422
101423 /**
101424 * @license
101425 * Copyright 2020 Google LLC. All Rights Reserved.
101426 * Licensed under the Apache License, Version 2.0 (the "License");
101427 * you may not use this file except in compliance with the License.
101428 * You may obtain a copy of the License at
101429 *
101430 * http://www.apache.org/licenses/LICENSE-2.0
101431 *
101432 * Unless required by applicable law or agreed to in writing, software
101433 * distributed under the License is distributed on an "AS IS" BASIS,
101434 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101435 * See the License for the specific language governing permissions and
101436 * limitations under the License.
101437 * =============================================================================
101438 */
101439 var subImpl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
101440 return aValue - bValue;
101441 });
101442 var subComplexImpl = createComplexBinaryKernelImpl(function (aReal, aImag, bReal, bImag) {
101443 return {
101444 real: aReal - bReal,
101445 imag: aImag - bImag
101446 };
101447 });
101448 var sub$1 = binaryKernelFunc$1(Sub, subImpl, subComplexImpl);
101449 var subConfig$1 = {
101450 kernelName: Sub,
101451 backendName: 'cpu',
101452 kernelFunc: sub$1
101453 };
101454
101455 /**
101456 * @license
101457 * Copyright 2019 Google LLC. All Rights Reserved.
101458 * Licensed under the Apache License, Version 2.0 (the "License");
101459 * you may not use this file except in compliance with the License.
101460 * You may obtain a copy of the License at
101461 *
101462 * http://www.apache.org/licenses/LICENSE-2.0
101463 *
101464 * Unless required by applicable law or agreed to in writing, software
101465 * distributed under the License is distributed on an "AS IS" BASIS,
101466 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101467 * See the License for the specific language governing permissions and
101468 * limitations under the License.
101469 * =============================================================================
101470 */
101471 /**
101472 * An implementation of the tile kernel shared between webgl and cpu for string
101473 * tensors only.
101474 */
101475 function tileImpl(xBuf, reps) {
101476 var newShape = new Array(xBuf.rank);
101477 for (var i = 0; i < newShape.length; i++) {
101478 newShape[i] = xBuf.shape[i] * reps[i];
101479 }
101480 var result = buffer(newShape, xBuf.dtype);
101481 for (var _i = 0; _i < result.values.length; ++_i) {
101482 var newLoc = result.indexToLoc(_i);
101483 var originalLoc = new Array(xBuf.rank);
101484 for (var j = 0; j < originalLoc.length; j++) {
101485 originalLoc[j] = newLoc[j] % xBuf.shape[j];
101486 }
101487 var originalIndex = xBuf.locToIndex(originalLoc);
101488 result.values[_i] = xBuf.values[originalIndex];
101489 }
101490 return result;
101491 }
101492
101493 /**
101494 * @license
101495 * Copyright 2020 Google LLC. All Rights Reserved.
101496 * Licensed under the Apache License, Version 2.0 (the "License");
101497 * you may not use this file except in compliance with the License.
101498 * You may obtain a copy of the License at
101499 *
101500 * http://www.apache.org/licenses/LICENSE-2.0
101501 *
101502 * Unless required by applicable law or agreed to in writing, software
101503 * distributed under the License is distributed on an "AS IS" BASIS,
101504 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101505 * See the License for the specific language governing permissions and
101506 * limitations under the License.
101507 * =============================================================================
101508 */
101509 var comparePair = function comparePair(a, b) {
101510 var valueDiff = b.value - a.value;
101511 return valueDiff === 0 ? a.index - b.index : valueDiff;
101512 };
101513 /**
101514 * Partitions array where all elements smaller than the (k+1) smallest element
101515 * are found to the left of it, and all larger to the right of it.
101516 * Based on the Floyd-Rivest Algorithm, ref:
101517 * https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
101518 * @param array: Array to partition
101519 * @param left: Left index for the interval
101520 * @param right: Right index for the interval
101521 * @param k: Desired index value, where array[k] is the (k+1)th smallest element
101522 * when left = 0
101523 */
101524 function select$2(array, k) {
101525 var left = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 0;
101526 var right = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : array.length - 1;
101527 while (right > left) {
101528 // Use select recursively to sample a smaller set of size s
101529 // the arbitrary constants 600 and 0.5 are used in the original
101530 // version to minimize execution time.
101531 if (right - left > 600) {
101532 var n = right - left + 1;
101533 var _i = k - left + 1;
101534 var z = Math.log(n);
101535 var s = 0.5 * Math.exp(2 * z / 3);
101536 var sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(_i - n / 2);
101537 var newLeft = Math.max(left, Math.floor(k - _i * s / n + sd));
101538 var newRight = Math.min(right, Math.floor(k + (n - _i) * s / n + sd));
101539 select$2(array, k, newLeft, newRight);
101540 }
101541 // partition the elements between left and right around t
101542 var t = array[k];
101543 var i = left;
101544 var j = right;
101545 swap(array, left, k);
101546 if (comparePair(array[right], t) > 0) {
101547 swap(array, left, right);
101548 }
101549 while (i < j) {
101550 swap(array, i, j);
101551 i++;
101552 j--;
101553 while (comparePair(array[i], t) < 0) {
101554 i = i + 1;
101555 }
101556 while (comparePair(array[j], t) > 0) {
101557 j = j - 1;
101558 }
101559 }
101560 if (comparePair(array[left], t) === 0) {
101561 swap(array, left, j);
101562 } else {
101563 j = j + 1;
101564 swap(array, j, right);
101565 }
101566 // Adjust left and right towards the boundaries of the subset
101567 // containing the (k - left + 1)th smallest element.
101568 if (j <= k) {
101569 left = j + 1;
101570 }
101571 if (k <= j) {
101572 right = j - 1;
101573 }
101574 }
101575 }
101576 function topKImpl(x, xShape, xDtype, k, sorted) {
101577 // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
101578 var lastDim = xShape[xShape.length - 1];
101579 var batch = x.length / lastDim,
101580 size = lastDim;
101581 var allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
101582 var allTopKIndices = getTypedArrayFromDType('int32', batch * k);
101583 var _loop = function _loop() {
101584 var offset = b * size;
101585 var vals = x.subarray(offset, offset + size);
101586 var valAndInd = new Array(vals.length);
101587 vals.forEach(function (value, index) {
101588 return valAndInd[index] = {
101589 value: value,
101590 index: index
101591 };
101592 });
101593 if (k < valAndInd.length) {
101594 select$2(valAndInd, k);
101595 valAndInd = valAndInd.slice(0, k);
101596 }
101597 if (sorted) {
101598 valAndInd.sort(comparePair);
101599 }
101600 var outOffset = b * k;
101601 var topKVals = allTopKVals.subarray(outOffset, outOffset + k);
101602 var topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
101603 for (var i = 0; i < k; i++) {
101604 topKVals[i] = valAndInd[i].value;
101605 topKIndices[i] = valAndInd[i].index;
101606 }
101607 };
101608 for (var b = 0; b < batch; b++) {
101609 _loop();
101610 }
101611 // Reshape back to the original input shape, except that the last
101612 // dimension is k.
101613 var outputShape = xShape.slice();
101614 outputShape[outputShape.length - 1] = k;
101615 return [buffer(outputShape, xDtype, allTopKVals), buffer(outputShape, 'int32', allTopKIndices)];
101616 }
101617
101618 /**
101619 * @license
101620 * Copyright 2020 Google LLC. All Rights Reserved.
101621 * Licensed under the Apache License, Version 2.0 (the "License");
101622 * you may not use this file except in compliance with the License.
101623 * You may obtain a copy of the License at
101624 *
101625 * http://www.apache.org/licenses/LICENSE-2.0
101626 *
101627 * Unless required by applicable law or agreed to in writing, software
101628 * distributed under the License is distributed on an "AS IS" BASIS,
101629 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101630 * See the License for the specific language governing permissions and
101631 * limitations under the License.
101632 * =============================================================================
101633 */
101634 function uniqueImpl(values, axis, shape, dtype) {
101635 // Normalize and validate axis.
101636 var $axis = parseAxisParam(axis, shape)[0];
101637 // Calculate the new shape that is suitable for extracting data along the
101638 // given axis.
101639 //
101640 // The rank is 3.
101641 // The size of the 1st dimension is the size of all the axes < the given axis.
101642 // The size of the 2nd dimension is the same as the size of the given axis.
101643 // The size of the 3rd dimension is the size of all the axes > the given axis.
101644 //
101645 // For example, for a 4D tensor with shape=[2, 3, 5, 4] and axis=2, the
101646 // newShape would be: [2*3, 5, 4].
101647 //
101648 // Note that this is not the final output shape. This will be the shape for an
101649 // intermediate TensorBuffer (see inputBuffer below) to allow us to extract
101650 // values along the given axis. To demonstrate how it works, consider the
101651 // following example:
101652 //
101653 // Input: a 3D tensor, with shape [1, 2, 3]
101654 // [
101655 // [
101656 // [1,2,3],
101657 // [4,5,6]
101658 // ]
101659 // ]
101660 // Axis: 2 (the last axis).
101661 // Along axis 2, we expect to extract 3 tensors: [1,4], [2,5], [3,6].
101662 //
101663 // For this example, newShape would be: [2, 3, 1], where 2 is calculated from
101664 // 1*2. The re-shaped data would look like:
101665 //
101666 // [
101667 // [
101668 // [1], [2], [3]
101669 // ],
101670 // [
101671 // [4], [5], [6]
101672 // ]
101673 // ]
101674 //
101675 // Then, we can construct a 3-level nested loop by the following dimension
101676 // order to extract the values along the axis (dimension1):
101677 // i: dimension1 // 0,1,2 (newShape[1])
101678 // m: dimension0 // 0,1 (newShape[0])
101679 // n: dimension2 // 0 (newShape[2])
101680 //
101681 // m, i, n
101682 // ---------
101683 // Iteration 0: data at [0, 0, 0] => "1"
101684 // Iteration 1: data at [1, 0, 0] => "4"
101685 // We got [1,4].
101686 // Iteration 2: data at [0, 1, 0] => "2"
101687 // Iteration 3: data at [1, 1, 0] => "5"
101688 // We got [2,5].
101689 // Iteration 4: data at [0, 2, 0] => "3"
101690 // Iteration 5: data at [1, 2, 0] => "6"
101691 // We got [3,6].
101692 var newShape = [1, shape[0], 1];
101693 for (var i = 0; i < $axis; i++) {
101694 newShape[0] *= shape[i];
101695 }
101696 newShape[1] = shape[$axis];
101697 for (var _i = $axis + 1; _i < shape.length; _i++) {
101698 newShape[2] *= shape[_i];
101699 }
101700 // A map from unique elements (their string representations) to their values
101701 // in "indices" (below).
101702 var uniqueElements = new Map();
101703 // The indices of each unique element in the original tensor along the given
101704 // axis. It is 1D and has the same size as the given axis.
101705 var indices = new Int32Array(shape[$axis]);
101706 // Create a buffer so we can easily extract value at a given location.
101707 var inputBuffer = new TensorBuffer(newShape, dtype, values);
101708 // The indices along the given axis that have unique elements. This is a
101709 // de-duped version of "indices" above.
101710 var uniqueIndices = [];
101711 var is1DTensor = newShape[0] === 1 && newShape[2] === 1;
101712 for (var _i2 = 0; _i2 < shape[$axis]; _i2++) {
101713 // Extract values along the axis.
101714 var element = void 0;
101715 if (is1DTensor) {
101716 // Fast path for 1D tensor input.
101717 element = values[_i2].toString();
101718 } else {
101719 var axisValues = [];
101720 for (var m = 0; m < newShape[0]; m++) {
101721 for (var n = 0; n < newShape[2]; n++) {
101722 axisValues.push(inputBuffer.get(m, _i2, n));
101723 }
101724 }
101725 element = axisValues.join(',');
101726 }
101727 // Dedup and update various indices.
101728 var existingIndex = uniqueElements.get(element);
101729 if (existingIndex != null) {
101730 indices[_i2] = existingIndex;
101731 } else {
101732 var uniqueIndex = uniqueElements.size;
101733 uniqueElements.set(element, uniqueIndex);
101734 indices[_i2] = uniqueIndex;
101735 uniqueIndices.push(_i2);
101736 }
101737 }
101738 // Now we know where each of the unique elements are located along the axis
101739 // (uniqueIndices). Extract them from input buffer and store them in the
101740 // output buffer.
101741 var outputTmpShape = newShape.slice();
101742 outputTmpShape[1] = uniqueElements.size;
101743 var outputBuffer = new TensorBuffer(outputTmpShape, dtype);
101744 uniqueIndices.forEach(function (uniqueElementIndex, i) {
101745 for (var _m = 0; _m < newShape[0]; _m++) {
101746 for (var _n = 0; _n < newShape[2]; _n++) {
101747 outputBuffer.set(inputBuffer.get(_m, uniqueElementIndex, _n), _m, i, _n);
101748 }
101749 }
101750 });
101751 // The output shape can be calculated from the input shape with the size of
101752 // the given axis replaced by the number of unique elements along that axis.
101753 var outputShape = shape.slice();
101754 outputShape[$axis] = outputTmpShape[1];
101755 return {
101756 outputValues: outputBuffer.values,
101757 outputShape: outputShape,
101758 indices: indices
101759 };
101760 }
101761
101762 /**
101763 * @license
101764 * Copyright 2020 Google LLC. All Rights Reserved.
101765 * Licensed under the Apache License, Version 2.0 (the "License");
101766 * you may not use this file except in compliance with the License.
101767 * You may obtain a copy of the License at
101768 *
101769 * http://www.apache.org/licenses/LICENSE-2.0
101770 *
101771 * Unless required by applicable law or agreed to in writing, software
101772 * distributed under the License is distributed on an "AS IS" BASIS,
101773 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101774 * See the License for the specific language governing permissions and
101775 * limitations under the License.
101776 * =============================================================================
101777 */
101778
101779 var shared = {
101780 __proto__: null,
101781 addImpl: addImpl,
101782 bincountImpl: bincountImpl,
101783 bincountReduceImpl: bincountReduceImpl,
101784 bitwiseAndImpl: bitwiseAndImpl,
101785 castImpl: castImpl,
101786 ceilImpl: ceilImpl,
101787 concatImpl: concatImpl$1,
101788 equalImpl: equalImpl,
101789 expImpl: expImpl,
101790 expm1Impl: expm1Impl,
101791 floorDivImpl: floorDivImpl,
101792 floorImpl: floorImpl,
101793 gatherNdImpl: gatherNdImpl,
101794 gatherV2Impl: gatherV2Impl,
101795 greaterEqualImpl: greaterEqualImpl,
101796 greaterImpl: greaterImpl,
101797 lessEqualImpl: lessEqualImpl,
101798 lessImpl: lessImpl,
101799 linSpaceImpl: linSpaceImpl,
101800 logImpl: logImpl,
101801 maxImpl: maxImpl$1,
101802 maximumImpl: maximumImpl,
101803 minimumImpl: minimumImpl,
101804 multiplyImpl: multiplyImpl,
101805 negImpl: negImpl,
101806 notEqualImpl: notEqualImpl,
101807 prodImpl: prodImpl,
101808 raggedGatherImpl: raggedGatherImpl,
101809 raggedRangeImpl: raggedRangeImpl,
101810 raggedTensorToTensorImpl: raggedTensorToTensorImpl,
101811 rangeImpl: rangeImpl,
101812 rsqrtImpl: rsqrtImpl,
101813 scatterImpl: scatterImpl,
101814 sigmoidImpl: sigmoidImpl,
101815 simpleAbsImpl: simpleAbsImpl,
101816 sliceImpl: sliceImpl,
101817 sparseFillEmptyRowsImpl: sparseFillEmptyRowsImpl,
101818 sparseReshapeImpl: sparseReshapeImpl,
101819 sparseSegmentReductionImpl: sparseSegmentReductionImpl,
101820 sqrtImpl: sqrtImpl,
101821 squaredDifferenceImpl: squaredDifferenceImpl,
101822 staticRegexReplaceImpl: staticRegexReplaceImpl,
101823 stridedSliceImpl: stridedSliceImpl,
101824 stringNGramsImpl: stringNGramsImpl,
101825 stringSplitImpl: stringSplitImpl,
101826 stringToHashBucketFastImpl: stringToHashBucketFastImpl,
101827 subImpl: subImpl,
101828 tileImpl: tileImpl,
101829 topKImpl: topKImpl,
101830 transposeImpl: transposeImpl$1,
101831 uniqueImpl: uniqueImpl
101832 };
101833
101834 /** @license See the LICENSE file. */
101835 // This code is auto-generated, do not modify this file!
101836 var version$3 = '4.22.0';
101837
101838 /**
101839 * @license
101840 * Copyright 2020 Google LLC. All Rights Reserved.
101841 * Licensed under the Apache License, Version 2.0 (the "License");
101842 * you may not use this file except in compliance with the License.
101843 * You may obtain a copy of the License at
101844 *
101845 * http://www.apache.org/licenses/LICENSE-2.0
101846 *
101847 * Unless required by applicable law or agreed to in writing, software
101848 * distributed under the License is distributed on an "AS IS" BASIS,
101849 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101850 * See the License for the specific language governing permissions and
101851 * limitations under the License.
101852 * =============================================================================
101853 */
101854 // Side effects for default initialization of MathBackendCPU
101855 registerBackend('cpu', function () {
101856 return new MathBackendCPU();
101857 }, 1 /* priority */);
101858
101859 /**
101860 * @license
101861 * Copyright 2020 Google LLC. All Rights Reserved.
101862 * Licensed under the Apache License, Version 2.0 (the License);
101863 * you may not use this file except in compliance with the License.
101864 * You may obtain a copy of the License at
101865 *
101866 * http://www.apache.org/licenses/LICENSE-2.0
101867 *
101868 * Unless required by applicable law or agreed to in writing, software
101869 * distributed under the License is distributed on an AS IS BASIS,
101870 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101871 * See the License for the specific language governing permissions and
101872 * limitations under the License.
101873 * =============================================================================
101874 */
101875 var elu$1 = unaryKernelFunc$1(Elu$1, function (xi) {
101876 return xi >= 0 ? xi : Math.exp(xi) - 1;
101877 });
101878 var eluConfig$1 = {
101879 kernelName: Elu$1,
101880 backendName: 'cpu',
101881 kernelFunc: elu$1
101882 };
101883
101884 /**
101885 * @license
101886 * Copyright 2020 Google LLC. All Rights Reserved.
101887 * Licensed under the Apache License, Version 2.0 (the "License");
101888 * you may not use this file except in compliance with the License.
101889 * You may obtain a copy of the License at
101890 *
101891 * http://www.apache.org/licenses/LICENSE-2.0
101892 *
101893 * Unless required by applicable law or agreed to in writing, software
101894 * distributed under the License is distributed on an "AS IS" BASIS,
101895 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101896 * See the License for the specific language governing permissions and
101897 * limitations under the License.
101898 * =============================================================================
101899 */
101900 function leakyRelu$1(args) {
101901 var inputs = args.inputs,
101902 backend = args.backend,
101903 attrs = args.attrs;
101904 var x = inputs.x;
101905 var alpha = attrs.alpha;
101906 assertNotComplex$1([x], 'leakyRelu');
101907 var xSize = sizeFromShape(x.shape);
101908 var xVals = backend.data.get(x.dataId).values;
101909 var outVals = getTypedArrayFromDType('float32', xSize);
101910 for (var i = 0; i < xVals.length; i++) {
101911 outVals[i] = xVals[i] < 0 ? alpha * xVals[i] : xVals[i];
101912 }
101913 return backend.makeTensorInfo(x.shape, 'float32', outVals);
101914 }
101915 var leakyReluConfig$1 = {
101916 kernelName: LeakyRelu,
101917 backendName: 'cpu',
101918 kernelFunc: leakyRelu$1
101919 };
101920
101921 var preluImpl = createSimpleBinaryKernelImpl(function (xValue, aValue) {
101922 return xValue < 0 ? aValue * xValue : xValue;
101923 });
101924 function prelu$1(args) {
101925 var inputs = args.inputs,
101926 backend = args.backend;
101927 var x = inputs.x,
101928 alpha = inputs.alpha;
101929 assertNotComplex$1([x, alpha], 'prelu');
101930 var aVals = backend.data.get(x.dataId).values;
101931 var bVals = backend.data.get(alpha.dataId).values;
101932 var _preluImpl = preluImpl(x.shape, alpha.shape, aVals, bVals, 'float32'),
101933 _preluImpl2 = _slicedToArray(_preluImpl, 2),
101934 resultData = _preluImpl2[0],
101935 resultShape = _preluImpl2[1];
101936 return backend.makeTensorInfo(resultShape, 'float32', resultData);
101937 }
101938 var preluConfig$1 = {
101939 kernelName: Prelu,
101940 backendName: 'cpu',
101941 kernelFunc: prelu$1
101942 };
101943
101944 /**
101945 * @license
101946 * Copyright 2020 Google LLC. All Rights Reserved.
101947 * Licensed under the Apache License, Version 2.0 (the License);
101948 * you may not use this file except in compliance with the License.
101949 * You may obtain a copy of the License at
101950 *
101951 * http://www.apache.org/licenses/LICENSE-2.0
101952 *
101953 * Unless required by applicable law or agreed to in writing, software
101954 * distributed under the License is distributed on an AS IS BASIS,
101955 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101956 * See the License for the specific language governing permissions and
101957 * limitations under the License.
101958 * =============================================================================
101959 */
101960 var relu$1 = unaryKernelFunc$1(Relu$1, function (xi) {
101961 return Math.max(0, xi);
101962 });
101963 var reluConfig$1 = {
101964 kernelName: Relu$1,
101965 backendName: 'cpu',
101966 kernelFunc: relu$1
101967 };
101968
101969 /**
101970 * @license
101971 * Copyright 2020 Google LLC. All Rights Reserved.
101972 * Licensed under the Apache License, Version 2.0 (the License);
101973 * you may not use this file except in compliance with the License.
101974 * You may obtain a copy of the License at
101975 *
101976 * http://www.apache.org/licenses/LICENSE-2.0
101977 *
101978 * Unless required by applicable law or agreed to in writing, software
101979 * distributed under the License is distributed on an AS IS BASIS,
101980 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101981 * See the License for the specific language governing permissions and
101982 * limitations under the License.
101983 * =============================================================================
101984 */
101985 var relu6$1 = unaryKernelFunc$1(Relu6$1, function (xi) {
101986 return Math.min(Math.max(0, xi), 6);
101987 });
101988 var relu6Config$1 = {
101989 kernelName: Relu6$1,
101990 backendName: 'cpu',
101991 kernelFunc: relu6$1
101992 };
101993
101994 /**
101995 * @license
101996 * Copyright 2020 Google LLC. All Rights Reserved.
101997 * Licensed under the Apache License, Version 2.0 (the "License");
101998 * you may not use this file except in compliance with the License.
101999 * You may obtain a copy of the License at
102000 *
102001 * http://www.apache.org/licenses/LICENSE-2.0
102002 *
102003 * Unless required by applicable law or agreed to in writing, software
102004 * distributed under the License is distributed on an "AS IS" BASIS,
102005 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102006 * See the License for the specific language governing permissions and
102007 * limitations under the License.
102008 * =============================================================================
102009 */
102010 function applyActivation(backend, x, activation, preluActivationWeights, leakyreluAlpha) {
102011 if (activation === 'linear') {
102012 return identity$1({
102013 inputs: {
102014 x: x
102015 },
102016 backend: backend
102017 });
102018 } else if (activation === 'relu') {
102019 return relu$1({
102020 inputs: {
102021 x: x
102022 },
102023 backend: backend
102024 });
102025 } else if (activation === 'elu') {
102026 return elu$1({
102027 inputs: {
102028 x: x
102029 },
102030 backend: backend
102031 });
102032 } else if (activation === 'relu6') {
102033 return relu6$1({
102034 inputs: {
102035 x: x
102036 },
102037 backend: backend
102038 });
102039 } else if (activation === 'prelu') {
102040 return prelu$1({
102041 inputs: {
102042 x: x,
102043 alpha: preluActivationWeights
102044 },
102045 backend: backend
102046 });
102047 } else if (activation === 'leakyrelu') {
102048 return leakyRelu$1({
102049 inputs: {
102050 x: x
102051 },
102052 backend: backend,
102053 attrs: {
102054 alpha: leakyreluAlpha
102055 }
102056 });
102057 } else if (activation === 'sigmoid') {
102058 return sigmoid$1({
102059 inputs: {
102060 x: x
102061 },
102062 backend: backend
102063 });
102064 }
102065 throw new Error("Activation ".concat(activation, " has not been implemented for the CPU backend."));
102066 }
102067
102068 /**
102069 * @license
102070 * Copyright 2020 Google LLC. All Rights Reserved.
102071 * Licensed under the Apache License, Version 2.0 (the "License");
102072 * you may not use this file except in compliance with the License.
102073 * You may obtain a copy of the License at
102074 *
102075 * http://www.apache.org/licenses/LICENSE-2.0
102076 *
102077 * Unless required by applicable law or agreed to in writing, software
102078 * distributed under the License is distributed on an "AS IS" BASIS,
102079 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102080 * See the License for the specific language governing permissions and
102081 * limitations under the License.
102082 * =============================================================================
102083 */
/**
 * CPU kernel for Reshape.
 *
 * No values are copied: the returned TensorInfo reuses `x.dataId` (the backend
 * is told via `backend.incRef`) and only carries a new shape. A -1 entry in
 * `attrs.shape` is resolved by `inferFromImplicitShape`.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {shape: number[]}}} args
 * @returns {{dataId: *, shape: number[], dtype: string}} TensorInfo sharing x's data.
 * @throws if the new shape does not hold exactly as many elements as x.
 */
function reshape$1(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var oldSize = sizeFromShape(x.shape);
  var newShape = inferFromImplicitShape(args.attrs.shape, oldSize);
  var newSize = sizeFromShape(newShape);
  assert$1(oldSize === newSize, function () {
    return "The new shape (".concat(newShape, ") has ").concat(newSize, " elements and the old ") + "shape (".concat(x.shape, ") has ").concat(oldSize, " elements. The new shape and old ") + "shape must have the same number of elements.";
  });
  backend.incRef(x.dataId);
  // Complex tensors keep separate real/imag TensorInfos; their shapes are
  // updated in place so they follow the reshaped tensor.
  var complexParts = backend.data.get(x.dataId).complexTensorInfos;
  if (complexParts != null) {
    complexParts.real.shape = newShape;
    complexParts.imag.shape = newShape;
  }
  return {
    dataId: x.dataId,
    shape: newShape,
    dtype: x.dtype
  };
}
/** CPU-backend registration for Reshape. */
var reshapeConfig$1 = {
  kernelName: Reshape$1,
  backendName: 'cpu',
  kernelFunc: reshape$1
};
102115
/**
 * CPU kernel for BatchMatMul.
 *
 * Multiplies the trailing two dimensions of `a` and `b` (each optionally
 * transposed) for every entry of the broadcast of their leading "batch"
 * dimensions. Both operands are viewed as rank-3 [batch, rows, cols] tensors
 * through reshape$1, and the product is accumulated in tiles whose edge is
 * `backend.blockSize`.
 *
 * @param {{inputs: {a: Object, b: Object}, backend: Object,
 *     attrs: {transposeA: boolean, transposeB: boolean}}} args
 * @returns {Object} TensorInfo of shape
 *     broadcast(aBatchShape, bBatchShape).concat([outerA, outerB]).
 * @throws if the inner dimensions differ or the batch shapes do not broadcast.
 */
function batchMatMul$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var a = inputs.a,
    b = inputs.b;
  var transposeA = attrs.transposeA,
    transposeB = attrs.transposeB;
  assertNotComplex$1([a, b], 'matMul');

  // Maps output batch `flatOutIndex` (row-major over `outDims`) to the flat
  // batch index of an operand whose batch shape is `inDims`, using
  // right-aligned broadcasting: size-1 operand dimensions are repeated and
  // missing leading dimensions are ignored.
  function broadcastBatchIndex(flatOutIndex, outDims, inDims) {
    var rankOffset = outDims.length - inDims.length;
    var remaining = flatOutIndex;
    var inIndex = 0;
    var inStride = 1;
    for (var axis = outDims.length - 1; axis >= 0; axis--) {
      var coord = remaining % outDims[axis];
      remaining = Math.floor(remaining / outDims[axis]);
      var inAxis = axis - rankOffset;
      if (inAxis >= 0) {
        var inDim = inDims[inAxis];
        if (inDim !== 1) {
          inIndex += coord * inStride;
        }
        inStride *= inDim;
      }
    }
    return inIndex;
  }

  var aRank = a.shape.length;
  var bRank = b.shape.length;
  var innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
  var innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
  var outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
  var outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
  var outerDimsA = a.shape.slice(0, -2);
  var outerDimsB = b.shape.slice(0, -2);
  var batchDimA = sizeFromShape(outerDimsA);
  var batchDimB = sizeFromShape(outerDimsB);
  var outShapeOuterDims = assertAndGetBroadcastShape(outerDimsA, outerDimsB);
  var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
  assert$1(innerShapeA === innerShapeB, function () {
    return "Error in matMul: inner shapes (".concat(innerShapeA, ") and (") + "".concat(innerShapeB, ") of Tensors with shapes ").concat(a.shape, " and ") + "".concat(b.shape, " and transposeA=").concat(transposeA) + " and transposeB=".concat(transposeB, " must match.");
  });
  var a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA];
  var b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB];
  // The rest of the implementation is designed to operate on rank-3 tensors
  var a3d = reshape$1({
    inputs: {
      x: a
    },
    backend: backend,
    attrs: {
      shape: a3dShape
    }
  });
  var b3d = reshape$1({
    inputs: {
      x: b
    },
    backend: backend,
    attrs: {
      shape: b3dShape
    }
  });
  var sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
  var leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
  var rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
  // The number of output batches is the size of the *broadcast* batch shape.
  // Math.max(batchDimA, batchDimB) under-counts when both operands broadcast,
  // e.g. batch shapes [2, 1] and [3] broadcast to [2, 3] = 6 batches.
  var batchDim = sizeFromShape(outShapeOuterDims);
  var a3dValues = backend.data.get(a3d.dataId).values;
  var b3dValues = backend.data.get(b3d.dataId).values;
  var a3dStrides = computeStrides(a3d.shape);
  var b3dStrides = computeStrides(b3d.shape);
  // Element steps in the flat buffers for (batch, outer, inner) of a and
  // (inner, outer, batch) of b; transposition only swaps which stride is used.
  var _ref = transposeA ? [a3dStrides[0], 1, a3dStrides[1]] : [a3dStrides[0], a3dStrides[1], 1],
    _ref2 = _slicedToArray(_ref, 3),
    aBatch = _ref2[0],
    aOuterStep = _ref2[1],
    aInnerStep = _ref2[2];
  var _ref3 = transposeB ? [1, b3dStrides[1], b3dStrides[0]] : [b3dStrides[1], 1, b3dStrides[0]],
    _ref4 = _slicedToArray(_ref3, 3),
    bInnerStep = _ref4[0],
    bOuterStep = _ref4[1],
    bBatch = _ref4[2];
  var size = leftDim * rightDim;
  var result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
  var resVals = result.values;
  var blockSize = backend.blockSize;
  for (var bi = 0; bi < batchDim; bi++) {
    var batchIndexA = broadcastBatchIndex(bi, outShapeOuterDims, outerDimsA);
    var batchIndexB = broadcastBatchIndex(bi, outShapeOuterDims, outerDimsB);
    for (var i0 = 0; i0 < leftDim; i0 += blockSize) {
      // for when blockSize doesn't evenly divide the input
      var iBlock = Math.min(i0 + blockSize, leftDim);
      for (var j0 = 0; j0 < rightDim; j0 += blockSize) {
        var jBlock = Math.min(j0 + blockSize, rightDim);
        for (var k0 = 0; k0 < sharedDim; k0 += blockSize) {
          var kBlock = Math.min(k0 + blockSize, sharedDim);
          for (var i = i0; i < iBlock; i++) {
            for (var j = j0; j < jBlock; j++) {
              var sum = 0.0;
              for (var k = k0; k < kBlock; k++) {
                var aVal = a3dValues[batchIndexA * aBatch + i * aOuterStep + k * aInnerStep];
                var bVal = b3dValues[k * bInnerStep + j * bOuterStep + batchIndexB * bBatch];
                sum += aVal * bVal;
              }
              // Partial sums from successive k-tiles accumulate here.
              resVals[bi * size + (i * rightDim + j)] += sum;
            }
          }
        }
      }
    }
  }
  // reshape$1 registered an extra reference on a's and b's data; release it.
  backend.disposeIntermediateTensorInfo(a3d);
  backend.disposeIntermediateTensorInfo(b3d);
  // set correct shape on output.
  return backend.makeTensorInfo(outShape, result.dtype, result.values);
}
/** CPU-backend registration for BatchMatMul. */
var batchMatMulConfig$1 = {
  kernelName: BatchMatMul,
  backendName: 'cpu',
  kernelFunc: batchMatMul$1
};
102222
102223 /**
102224 * @license
102225 * Copyright 2020 Google LLC. All Rights Reserved.
102226 * Licensed under the Apache License, Version 2.0 (the License);
102227 * you may not use this file except in compliance with the License.
102228 * You may obtain a copy of the License at
102229 *
102230 * http://www.apache.org/licenses/LICENSE-2.0
102231 *
102232 * Unless required by applicable law or agreed to in writing, software
102233 * distributed under the License is distributed on an AS IS BASIS,
102234 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102235 * See the License for the specific language governing permissions and
102236 * limitations under the License.
102237 * =============================================================================
102238 */
/**
 * CPU kernel for _FusedMatMul: batchMatMul$1, then an optional bias `add`,
 * then an optional activation via applyActivation.
 *
 * Each stage's input that is superseded by a later stage is collected and
 * disposed (in creation order) only after the last stage has run.
 *
 * @param {{inputs: {a: Object, b: Object, bias: ?Object,
 *     preluActivationWeights: ?Object}, backend: Object,
 *     attrs: {transposeA: boolean, transposeB: boolean, activation: ?string,
 *     leakyreluAlpha: ?number}}} args
 * @returns {Object} TensorInfo of the last stage that ran.
 */
function _fusedMatMul$1(args) {
  var backend = args.backend;
  var inputs = args.inputs;
  var attrs = args.attrs;
  var superseded = [];
  var out = batchMatMul$1({
    inputs: { a: inputs.a, b: inputs.b },
    attrs: { transposeA: attrs.transposeA, transposeB: attrs.transposeB },
    backend: backend
  });
  if (inputs.bias) {
    var biased = add({ inputs: { a: out, b: inputs.bias }, backend: backend });
    superseded.push(out);
    out = biased;
  }
  if (attrs.activation) {
    var activated = applyActivation(backend, out, attrs.activation, inputs.preluActivationWeights, attrs.leakyreluAlpha);
    superseded.push(out);
    out = activated;
  }
  superseded.forEach(function (info) {
    backend.disposeIntermediateTensorInfo(info);
  });
  return out;
}
/** CPU-backend registration for _FusedMatMul. */
var _fusedMatMulConfig$1 = {
  kernelName: _FusedMatMul,
  backendName: 'cpu',
  kernelFunc: _fusedMatMul$1
};
102294
102295 /**
102296 * @license
102297 * Copyright 2020 Google LLC. All Rights Reserved.
102298 * Licensed under the Apache License, Version 2.0 (the License);
102299 * you may not use this file except in compliance with the License.
102300 * You may obtain a copy of the License at
102301 *
102302 * http://www.apache.org/licenses/LICENSE-2.0
102303 *
102304 * Unless required by applicable law or agreed to in writing, software
102305 * distributed under the License is distributed on an AS IS BASIS,
102306 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102307 * See the License for the specific language governing permissions and
102308 * limitations under the License.
102309 * =============================================================================
102310 */
/** CPU kernel for Acos, built by unaryKernelFunc$1 from a per-value Math.acos callback. */
var acos$1 = unaryKernelFunc$1(Acos, function acosOfValue(value) {
  var result = Math.acos(value);
  return result;
});
/** CPU-backend registration for Acos. */
var acosConfig$1 = {
  kernelName: Acos,
  backendName: 'cpu',
  kernelFunc: acos$1
};
102319
102320 /**
102321 * @license
102322 * Copyright 2020 Google LLC. All Rights Reserved.
102323 * Licensed under the Apache License, Version 2.0 (the License);
102324 * you may not use this file except in compliance with the License.
102325 * You may obtain a copy of the License at
102326 *
102327 * http://www.apache.org/licenses/LICENSE-2.0
102328 *
102329 * Unless required by applicable law or agreed to in writing, software
102330 * distributed under the License is distributed on an AS IS BASIS,
102331 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102332 * See the License for the specific language governing permissions and
102333 * limitations under the License.
102334 * =============================================================================
102335 */
/** CPU kernel for Acosh, built by unaryKernelFunc$1 from a per-value Math.acosh callback. */
var acosh$1 = unaryKernelFunc$1(Acosh, function acoshOfValue(value) {
  var result = Math.acosh(value);
  return result;
});
/** CPU-backend registration for Acosh. */
var acoshConfig$1 = {
  kernelName: Acosh,
  backendName: 'cpu',
  kernelFunc: acosh$1
};
102344
102345 /**
102346 * @license
102347 * Copyright 2020 Google LLC. All Rights Reserved.
102348 * Licensed under the Apache License, Version 2.0 (the "License");
102349 * you may not use this file except in compliance with the License.
102350 * You may obtain a copy of the License at
102351 *
102352 * http://www.apache.org/licenses/LICENSE-2.0
102353 *
102354 * Unless required by applicable law or agreed to in writing, software
102355 * distributed under the License is distributed on an "AS IS" BASIS,
102356 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102357 * See the License for the specific language governing permissions and
102358 * limitations under the License.
102359 * =============================================================================
102360 */
/**
 * CPU kernel for AddN: element-wise sum of every tensor in `args.inputs`.
 *
 * The output takes its shape and dtype from the first input; each input's
 * values are added index by index over the output's length.
 *
 * @param {{inputs: Object[], backend: Object}} args
 * @returns {Object} TensorInfo holding the sum.
 */
function addN$1(args) {
  var backend = args.backend;
  var tensors = args.inputs;
  assertNotComplex$1(tensors, 'addN');
  var first = tensors[0];
  var outBuf = buffer(first.shape, first.dtype);
  var total = outBuf.values;
  tensors.forEach(function (tensor) {
    var addend = backend.data.get(tensor.dataId).values;
    for (var idx = 0; idx < total.length; idx++) {
      total[idx] += addend[idx];
    }
  });
  return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
/** CPU-backend registration for AddN. */
var addNConfig$1 = {
  kernelName: AddN,
  backendName: 'cpu',
  kernelFunc: addN$1
};
102384
/**
 * CPU kernel for All: logical AND of `x` over `attrs.axis`.
 *
 * If the reduced axes are not already innermost, `x` is transposed first so
 * each output element reduces one contiguous run of `reduceSize` values.
 * An empty reduction (reduceSize === 0) yields true, the identity of AND.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {axis: (number|number[]), keepDims: boolean}}} args
 * @returns {Object} TensorInfo with the reduced values; when keepDims is set,
 *     reduced axes are kept with size 1.
 */
function all$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x;
  var axis = attrs.axis,
    keepDims = attrs.keepDims;
  assertNotComplex$1(x, 'all');
  var origAxes = parseAxisParam(axis, x.shape);
  var axes = origAxes;
  var permutedAxes = getAxesPermutation(axes, x.shape.length);
  var $x = x;
  if (permutedAxes != null) {
    $x = transpose$1({
      inputs: {
        x: x
      },
      backend: backend,
      attrs: {
        perm: permutedAxes
      }
    });
    axes = getInnerMostAxes(axes.length, x.shape.length);
  }
  assertAxesAreInnerMostDims('all', axes, $x.shape.length);
  var _backend_util$compute = computeOutAndReduceShapes($x.shape, axes),
    _backend_util$compute2 = _slicedToArray(_backend_util$compute, 2),
    outShape = _backend_util$compute2[0],
    reduceShape = _backend_util$compute2[1];
  var reduceSize = sizeFromShape(reduceShape);
  var vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
  var aVals = backend.data.get($x.dataId).values;
  for (var i = 0; i < vals.length; ++i) {
    var offset = i * reduceSize;
    // Seed with true (1) rather than aVals[offset]: for an empty window that
    // index is out of range and would make the result false. For non-empty
    // windows the result is unchanged (1 && v === v).
    var allTrue = 1;
    // Once falsy the AND can no longer change, so stop early.
    for (var j = 0; j < reduceSize && allTrue; ++j) {
      allTrue = allTrue && aVals[offset + j];
    }
    vals[i] = allTrue;
  }
  if (permutedAxes != null) {
    backend.disposeIntermediateTensorInfo($x);
  }
  var result = backend.makeTensorInfo(outShape, $x.dtype, vals);
  if (keepDims) {
    var expandedShape = expandShapeToKeepDim(outShape, origAxes);
    var reshapedResult = reshape$1({
      inputs: {
        x: result
      },
      backend: backend,
      attrs: {
        shape: expandedShape
      }
    });
    backend.disposeIntermediateTensorInfo(result);
    return reshapedResult;
  }
  return result;
}
/** CPU-backend registration for All. */
var allConfig$1 = {
  kernelName: All,
  backendName: 'cpu',
  kernelFunc: all$1
};
102451
/**
 * CPU kernel for Any: logical OR of `x` over `attrs.axis`.
 *
 * When the requested axes are not innermost, `x` is transposed first so each
 * output element reduces one contiguous run of input values; the transposed
 * copy is disposed once the reduction is done.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {axis: (number|number[]), keepDims: boolean}}} args
 * @returns {Object} TensorInfo with the reduced values; when keepDims is set,
 *     reduced axes are kept with size 1.
 */
function any$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var keepDims = args.attrs.keepDims;
  assertNotComplex$1(x, 'any');
  var requestedAxes = parseAxisParam(args.attrs.axis, x.shape);
  var rank = x.shape.length;
  var permutation = getAxesPermutation(requestedAxes, rank);
  var source = x;
  var reduceAxes = requestedAxes;
  if (permutation != null) {
    source = transpose$1({
      inputs: { x: x },
      backend: backend,
      attrs: { perm: permutation }
    });
    reduceAxes = getInnerMostAxes(requestedAxes.length, rank);
  }
  assertAxesAreInnerMostDims('any', reduceAxes, source.shape.length);
  var shapes = _slicedToArray(computeOutAndReduceShapes(source.shape, reduceAxes), 2);
  var outShape = shapes[0];
  var windowSize = sizeFromShape(shapes[1]);
  var reduced = makeZerosTypedArray(sizeFromShape(outShape), source.dtype);
  var inputVals = backend.data.get(source.dataId).values;
  for (var o = 0; o < reduced.length; ++o) {
    var base = o * windowSize;
    // Seeded with the window's first element (out of range when the input is
    // empty); the loop visits it again, which does not change an OR.
    var acc = inputVals[base];
    var w = 0;
    while (w < windowSize) {
      acc = acc || inputVals[base + w];
      w++;
    }
    reduced[o] = acc;
  }
  if (permutation != null) {
    backend.disposeIntermediateTensorInfo(source);
  }
  var result = backend.makeTensorInfo(outShape, source.dtype, reduced);
  if (!keepDims) {
    return result;
  }
  var reshaped = reshape$1({
    inputs: { x: result },
    backend: backend,
    attrs: { shape: expandShapeToKeepDim(outShape, requestedAxes) }
  });
  backend.disposeIntermediateTensorInfo(result);
  return reshaped;
}
/** CPU-backend registration for Any. */
var anyConfig$1 = {
  kernelName: Any,
  backendName: 'cpu',
  kernelFunc: any$1
};
102518
/**
 * CPU kernel for ArgMax: int32 index of the largest value along `attrs.axis`.
 *
 * Only the first axis produced by parseAxisParam is reduced. Ties resolve to
 * the lowest index because a later value must be strictly greater to win.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {axis: (number|number[])}}} args
 * @returns {Object} int32 TensorInfo with the reduced axis removed.
 */
function argMax$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  assertNotComplex$1(x, 'argMax');
  var axes = parseAxisParam(args.attrs.axis, x.shape);
  var permutation = getAxesPermutation(axes, x.shape.length);
  var transposed = null;
  if (permutation != null) {
    // Move the reduced axis innermost so each output reads a contiguous run.
    transposed = transpose$1({
      inputs: { x: x },
      backend: backend,
      attrs: { perm: permutation }
    });
    axes = getInnerMostAxes(axes.length, transposed.shape.length);
  }
  var source = transposed === null ? x : transposed;
  axes = [axes[0]];
  assertAxesAreInnerMostDims('argMax', axes, source.shape.length);
  var shapes = _slicedToArray(computeOutAndReduceShapes(source.shape, axes), 2);
  var outShape = shapes[0];
  var windowSize = sizeFromShape(shapes[1]);
  var indices = makeZerosTypedArray(sizeFromShape(outShape), 'int32');
  var values = backend.data.get(source.dataId).values;
  for (var o = 0; o < indices.length; ++o) {
    var base = o * windowSize;
    var best = values[base];
    var bestAt = 0;
    // Element 0 is the seed, so scanning starts at 1.
    for (var w = 1; w < windowSize; ++w) {
      if (values[base + w] > best) {
        best = values[base + w];
        bestAt = w;
      }
    }
    indices[o] = bestAt;
  }
  if (transposed !== null) {
    backend.disposeIntermediateTensorInfo(transposed);
  }
  return backend.makeTensorInfo(outShape, 'int32', indices);
}
/** CPU-backend registration for ArgMax. */
var argMaxConfig$1 = {
  kernelName: ArgMax,
  backendName: 'cpu',
  kernelFunc: argMax$1
};
102576
/**
 * CPU kernel for ArgMin: int32 index of the smallest value along `attrs.axis`.
 *
 * Only the first axis produced by parseAxisParam is reduced. Ties resolve to
 * the lowest index because a later value must be strictly smaller to win.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {axis: (number|number[])}}} args
 * @returns {Object} int32 TensorInfo with the reduced axis removed.
 */
function argMin$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  assertNotComplex$1(x, 'argMin');
  var axes = parseAxisParam(args.attrs.axis, x.shape);
  var permutation = getAxesPermutation(axes, x.shape.length);
  var transposed = null;
  if (permutation != null) {
    // Move the reduced axis innermost so each output reads a contiguous run.
    transposed = transpose$1({
      inputs: { x: x },
      backend: backend,
      attrs: { perm: permutation }
    });
    axes = getInnerMostAxes(axes.length, transposed.shape.length);
  }
  var source = transposed === null ? x : transposed;
  axes = [axes[0]];
  assertAxesAreInnerMostDims('argMin', axes, source.shape.length);
  var shapes = _slicedToArray(computeOutAndReduceShapes(source.shape, axes), 2);
  var outShape = shapes[0];
  var windowSize = sizeFromShape(shapes[1]);
  var indices = makeZerosTypedArray(sizeFromShape(outShape), 'int32');
  var values = backend.data.get(source.dataId).values;
  for (var o = 0; o < indices.length; ++o) {
    var base = o * windowSize;
    var best = values[base];
    var bestAt = 0;
    // Element 0 is the seed, so scanning starts at 1.
    for (var w = 1; w < windowSize; ++w) {
      if (values[base + w] < best) {
        best = values[base + w];
        bestAt = w;
      }
    }
    indices[o] = bestAt;
  }
  if (transposed !== null) {
    backend.disposeIntermediateTensorInfo(transposed);
  }
  return backend.makeTensorInfo(outShape, 'int32', indices);
}
/** CPU-backend registration for ArgMin. */
var argMinConfig$1 = {
  kernelName: ArgMin,
  backendName: 'cpu',
  kernelFunc: argMin$1
};
102634
102635 /**
102636 * @license
102637 * Copyright 2020 Google LLC. All Rights Reserved.
102638 * Licensed under the Apache License, Version 2.0 (the License);
102639 * you may not use this file except in compliance with the License.
102640 * You may obtain a copy of the License at
102641 *
102642 * http://www.apache.org/licenses/LICENSE-2.0
102643 *
102644 * Unless required by applicable law or agreed to in writing, software
102645 * distributed under the License is distributed on an AS IS BASIS,
102646 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102647 * See the License for the specific language governing permissions and
102648 * limitations under the License.
102649 * =============================================================================
102650 */
/** CPU kernel for Asin, built by unaryKernelFunc$1 from a per-value Math.asin callback. */
var asin$1 = unaryKernelFunc$1(Asin, function asinOfValue(value) {
  var result = Math.asin(value);
  return result;
});
/** CPU-backend registration for Asin. */
var asinConfig$1 = {
  kernelName: Asin,
  backendName: 'cpu',
  kernelFunc: asin$1
};
102659
102660 /**
102661 * @license
102662 * Copyright 2020 Google LLC. All Rights Reserved.
102663 * Licensed under the Apache License, Version 2.0 (the License);
102664 * you may not use this file except in compliance with the License.
102665 * You may obtain a copy of the License at
102666 *
102667 * http://www.apache.org/licenses/LICENSE-2.0
102668 *
102669 * Unless required by applicable law or agreed to in writing, software
102670 * distributed under the License is distributed on an AS IS BASIS,
102671 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102672 * See the License for the specific language governing permissions and
102673 * limitations under the License.
102674 * =============================================================================
102675 */
/** CPU kernel for Asinh, built by unaryKernelFunc$1 from a per-value Math.asinh callback. */
var asinh$1 = unaryKernelFunc$1(Asinh, function asinhOfValue(value) {
  var result = Math.asinh(value);
  return result;
});
/** CPU-backend registration for Asinh. */
var asinhConfig$1 = {
  kernelName: Asinh,
  backendName: 'cpu',
  kernelFunc: asinh$1
};
102684
102685 /**
102686 * @license
102687 * Copyright 2020 Google LLC. All Rights Reserved.
102688 * Licensed under the Apache License, Version 2.0 (the License);
102689 * you may not use this file except in compliance with the License.
102690 * You may obtain a copy of the License at
102691 *
102692 * http://www.apache.org/licenses/LICENSE-2.0
102693 *
102694 * Unless required by applicable law or agreed to in writing, software
102695 * distributed under the License is distributed on an AS IS BASIS,
102696 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102697 * See the License for the specific language governing permissions and
102698 * limitations under the License.
102699 * =============================================================================
102700 */
/** CPU kernel for Atan, built by unaryKernelFunc$1 from a per-value Math.atan callback. */
var atan$1 = unaryKernelFunc$1(Atan, function atanOfValue(value) {
  var result = Math.atan(value);
  return result;
});
/** CPU-backend registration for Atan. */
var atanConfig$1 = {
  kernelName: Atan,
  backendName: 'cpu',
  kernelFunc: atan$1
};
102709
102710 /**
102711 * @license
102712 * Copyright 2020 Google LLC. All Rights Reserved.
102713 * Licensed under the Apache License, Version 2.0 (the License);
102714 * you may not use this file except in compliance with the License.
102715 * You may obtain a copy of the License at
102716 *
102717 * http://www.apache.org/licenses/LICENSE-2.0
102718 *
102719 * Unless required by applicable law or agreed to in writing, software
102720 * distributed under the License is distributed on an AS IS BASIS,
102721 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102722 * See the License for the specific language governing permissions and
102723 * limitations under the License.
102724 * =============================================================================
102725 */
/**
 * Element-wise atan2 implementation, built by createSimpleBinaryKernelImpl
 * from a per-pair Math.atan2(a, b) callback.
 */
var atan2Impl = createSimpleBinaryKernelImpl(function atan2OfPair(aValue, bValue) {
  var angle = Math.atan2(aValue, bValue);
  return angle;
});
/** CPU kernel for Atan2, wrapping atan2Impl with binaryKernelFunc$1. */
var atan2$1 = binaryKernelFunc$1(Atan2, atan2Impl);
/** CPU-backend registration for Atan2. */
var atan2Config$1 = {
  kernelName: Atan2,
  backendName: 'cpu',
  kernelFunc: atan2$1
};
102735
102736 /**
102737 * @license
102738 * Copyright 2020 Google LLC. All Rights Reserved.
102739 * Licensed under the Apache License, Version 2.0 (the License);
102740 * you may not use this file except in compliance with the License.
102741 * You may obtain a copy of the License at
102742 *
102743 * http://www.apache.org/licenses/LICENSE-2.0
102744 *
102745 * Unless required by applicable law or agreed to in writing, software
102746 * distributed under the License is distributed on an AS IS BASIS,
102747 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102748 * See the License for the specific language governing permissions and
102749 * limitations under the License.
102750 * =============================================================================
102751 */
/** CPU kernel for Atanh, built by unaryKernelFunc$1 from a per-value Math.atanh callback. */
var atanh$1 = unaryKernelFunc$1(Atanh, function atanhOfValue(value) {
  var result = Math.atanh(value);
  return result;
});
/** CPU-backend registration for Atanh. */
var atanhConfig$1 = {
  kernelName: Atanh,
  backendName: 'cpu',
  kernelFunc: atanh$1
};
102760
102761 /**
102762 * @license
102763 * Copyright 2020 Google LLC. All Rights Reserved.
102764 * Licensed under the Apache License, Version 2.0 (the "License");
102765 * you may not use this file except in compliance with the License.
102766 * You may obtain a copy of the License at
102767 *
102768 * http://www.apache.org/licenses/LICENSE-2.0
102769 *
102770 * Unless required by applicable law or agreed to in writing, software
102771 * distributed under the License is distributed on an "AS IS" BASIS,
102772 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102773 * See the License for the specific language governing permissions and
102774 * limitations under the License.
102775 * =============================================================================
102776 */
/**
 * Generic 2D max/avg pooling on the CPU.
 *
 * Input element (b, row, col, d) is read at
 * `b * strides[0] + row * strides[1] + col * strides[2] + d`, and output
 * element (b, yR, yC, d) is written at the matching row-major offset in a
 * buffer of shape `convInfo.outShape`.
 *
 * Window taps start at `y * stride - pad` and step by the dilation. Taps that
 * fall in the padding are skipped without shifting the dilation lattice, as
 * in maxPoolPositions and pool3d. For 'avg', a window with no in-bounds taps
 * yields 0 (as in pool3d) rather than 0 / 0 = NaN.
 *
 * @param {ArrayLike<number>} xValues flat input values.
 * @param {number[]} xShape input shape (not read here).
 * @param {string} dtype dtype of the output buffer.
 * @param {number[]} strides input strides for batch, row and column.
 * @param {Object} convInfo pooling geometry (strides, dilations, effective
 *     filter size, padInfo.top/left, in/out sizes, outShape).
 * @param {string} poolType 'max' or 'avg'.
 * @returns {Object} buffer of shape convInfo.outShape.
 */
function pool(xValues, xShape, dtype, strides, convInfo, poolType) {
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight;
  var effectiveFilterWidth = convInfo.effectiveFilterWidth;
  var padTop = convInfo.padInfo.top;
  var padLeft = convInfo.padInfo.left;
  var initialValue = poolType === 'max' ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
  var output = buffer(convInfo.outShape, dtype);
  var outputVals = output.values;
  var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
  var outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
  var outputColStrides = convInfo.outShape[3];
  for (var b = 0; b < convInfo.batchSize; ++b) {
    var outputBatchOffset = b * outputBatchStrides;
    var inputBatchOffset = b * strides[0];
    for (var d = 0; d < convInfo.inChannels; ++d) {
      for (var yR = 0; yR < convInfo.outHeight; ++yR) {
        var xRCorner = yR * strideHeight - padTop;
        // Advance to the first dilated tap inside the input. Clamping with
        // Math.max(0, xRCorner) would shift the sampling lattice whenever the
        // window starts in the padding and dilationHeight > 1.
        var xRMin = xRCorner;
        while (xRMin < 0) {
          xRMin += dilationHeight;
        }
        var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
        var outputRowOffset = outputBatchOffset + yR * outputRowStrides;
        for (var yC = 0; yC < convInfo.outWidth; ++yC) {
          var xCCorner = yC * strideWidth - padLeft;
          var xCMin = xCCorner;
          while (xCMin < 0) {
            xCMin += dilationWidth;
          }
          var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
          var minMaxValue = initialValue;
          var avgValue = 0;
          var count = 0;
          for (var xR = xRMin; xR < xRMax; xR += dilationHeight) {
            var xROffset = inputBatchOffset + xR * strides[1];
            for (var xC = xCMin; xC < xCMax; xC += dilationWidth) {
              var xCOffset = xROffset + xC * strides[2];
              var pixel = xValues[xCOffset + d];
              if (poolType === 'max' && pixel > minMaxValue) {
                minMaxValue = pixel;
              } else if (poolType === 'avg') {
                avgValue += pixel;
                count++;
              }
            }
            // NOTE(review): minMaxValue starts at +/-Infinity and is only
            // replaced by a strictly greater pixel, so it never becomes NaN;
            // this early exit appears unreachable.
            if (isNaN(minMaxValue)) {
              break;
            }
          }
          var outputOffset = outputRowOffset + yC * outputColStrides + d;
          outputVals[outputOffset] = poolType === 'avg' ? avgValue / Math.max(count, 1) : minMaxValue;
        }
      }
    }
  }
  return output;
}
/**
 * For every output cell of a 2D max pool, records where its maximum came from.
 *
 * Call as maxPoolPositions(xValues, xShape, dtype, convInfo
 *     [, flattenPositions = false [, includeBatchInIndex = false]]).
 * - flattenPositions falsy: the position is the offset inside the (dilated)
 *   window, `(row - rowCorner) * effectiveFilterWidth + (col - colCorner)`.
 * - flattenPositions truthy: the position is the flat index of the input
 *   element, `((batch * inHeight + row) * inWidth + col) * inChannels + channel`,
 *   where the batch term is included only when includeBatchInIndex is truthy.
 * Ties keep the first position scanned; a cell with no value greater than
 * -Infinity keeps -1.
 *
 * @returns {Object} int32 buffer of shape convInfo.outShape.
 */
function maxPoolPositions(xValues, xShape, dtype, convInfo) {
  // Optional trailing arguments (ES5 lowering of default parameters).
  var flattenPositions = arguments[4] === undefined ? false : arguments[4];
  var includeBatchInIndex = arguments[5] === undefined ? false : arguments[5];
  var positions = buffer(convInfo.outShape, 'int32');
  var input = buffer(xShape, dtype, xValues);
  var inHeight = convInfo.inHeight;
  var inWidth = convInfo.inWidth;
  var inChannels = convInfo.inChannels;
  var windowHeight = convInfo.effectiveFilterHeight;
  var windowWidth = convInfo.effectiveFilterWidth;
  var rowDilation = convInfo.dilationHeight;
  var colDilation = convInfo.dilationWidth;
  for (var batch = 0; batch < convInfo.batchSize; ++batch) {
    for (var channel = 0; channel < inChannels; ++channel) {
      for (var outRow = 0; outRow < convInfo.outHeight; ++outRow) {
        var rowCorner = outRow * convInfo.strideHeight - convInfo.padInfo.top;
        // First dilated tap that is inside the input (skips top padding
        // without shifting the dilation lattice).
        var rowStart = rowCorner;
        while (rowStart < 0) {
          rowStart += rowDilation;
        }
        var rowEnd = Math.min(inHeight, windowHeight + rowCorner);
        for (var outCol = 0; outCol < convInfo.outWidth; ++outCol) {
          var colCorner = outCol * convInfo.strideWidth - convInfo.padInfo.left;
          var colStart = colCorner;
          while (colStart < 0) {
            colStart += colDilation;
          }
          var colEnd = Math.min(inWidth, windowWidth + colCorner);
          var best = Number.NEGATIVE_INFINITY;
          var bestPosition = -1;
          for (var row = rowStart; row < rowEnd; row += rowDilation) {
            for (var col = colStart; col < colEnd; col += colDilation) {
              var pixel = input.get(batch, row, col, channel);
              if (!(pixel > best)) {
                continue;
              }
              best = pixel;
              if (flattenPositions) {
                var flatRow = includeBatchInIndex ? batch * inHeight + row : row;
                bestPosition = (flatRow * inWidth + col) * inChannels + channel;
              } else {
                bestPosition = (row - rowCorner) * windowWidth + (col - colCorner);
              }
            }
          }
          positions.set(bestPosition, batch, outRow, outCol, channel);
        }
      }
    }
  }
  return positions;
}
/**
 * Generic 3D max/avg pooling on the CPU.
 *
 * Input element (n, z, y, x, c) is read at
 * `n * strides[0] + z * strides[1] + y * strides[2] + x * strides[3] + c`;
 * the output is a buffer of shape `convInfo.outShape` laid out as
 * [batch, depth, row, col, channel].
 *
 * Window taps start at `out * stride - pad` and step by the dilation. Taps in
 * the padding are skipped without shifting the dilation lattice. For 'avg',
 * a window with no in-bounds taps yields 0 because the count is clamped to 1.
 *
 * @param {ArrayLike<number>} xValues flat input values.
 * @param {number[]} xShape input shape (not read here).
 * @param {string} dtype dtype of the output buffer.
 * @param {number[]} strides input strides for batch, depth, row and column.
 * @param {Object} convInfo pooling geometry (strides, dilations, effective
 *     filter sizes, padInfo.front/top/left, in/out sizes, outShape).
 * @param {string} poolType 'max' or 'avg'.
 * @returns {Object} buffer of shape convInfo.outShape.
 */
function pool3d(xValues, xShape, dtype, strides, convInfo, poolType) {
  var pad = convInfo.padInfo;
  var outShape = convInfo.outShape;
  var colStride = outShape[4];
  var rowStride = outShape[3] * colStride;
  var depthStride = outShape[2] * rowStride;
  var batchStride = outShape[1] * depthStride;
  var initialValue = poolType === 'max' ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
  var output = buffer(outShape, dtype);
  var outVals = output.values;

  // Smallest non-negative position on the lattice corner, corner + step, ...
  function firstTapInside(corner, step) {
    var pos = corner;
    while (pos < 0) {
      pos += step;
    }
    return pos;
  }

  for (var n = 0; n < convInfo.batchSize; ++n) {
    var outBatchBase = n * batchStride;
    var inBatchBase = n * strides[0];
    for (var c = 0; c < convInfo.inChannels; ++c) {
      for (var od = 0; od < convInfo.outDepth; ++od) {
        var dCorner = od * convInfo.strideDepth - pad.front;
        var dStart = firstTapInside(dCorner, convInfo.dilationDepth);
        var dEnd = Math.min(convInfo.inDepth, convInfo.effectiveFilterDepth + dCorner);
        var outDepthBase = outBatchBase + od * depthStride;
        for (var oh = 0; oh < convInfo.outHeight; ++oh) {
          var hCorner = oh * convInfo.strideHeight - pad.top;
          var hStart = firstTapInside(hCorner, convInfo.dilationHeight);
          var hEnd = Math.min(convInfo.inHeight, convInfo.effectiveFilterHeight + hCorner);
          var outRowBase = outDepthBase + oh * rowStride;
          for (var ow = 0; ow < convInfo.outWidth; ++ow) {
            var wCorner = ow * convInfo.strideWidth - pad.left;
            var wStart = firstTapInside(wCorner, convInfo.dilationWidth);
            var wEnd = Math.min(convInfo.inWidth, convInfo.effectiveFilterWidth + wCorner);
            var extreme = initialValue;
            var total = 0;
            var count = 0;
            // A single labelled break replaces the three chained NaN checks:
            // once `extreme` is NaN the whole window scan stops.
            // NOTE(review): `extreme` starts at +/-Infinity and is only
            // replaced by a strictly greater pixel, so it appears never to
            // become NaN; the break is kept to mirror the original.
            scan: for (var zd = dStart; zd < dEnd; zd += convInfo.dilationDepth) {
              var depthBase = inBatchBase + zd * strides[1];
              for (var zh = hStart; zh < hEnd; zh += convInfo.dilationHeight) {
                var rowBase = depthBase + zh * strides[2];
                for (var zw = wStart; zw < wEnd; zw += convInfo.dilationWidth) {
                  var pixel = xValues[rowBase + zw * strides[3] + c];
                  if (poolType === 'max' && pixel > extreme) {
                    extreme = pixel;
                  } else if (poolType === 'avg') {
                    total += pixel;
                    count++;
                  }
                  if (isNaN(extreme)) {
                    break scan;
                  }
                }
              }
            }
            outVals[outRowBase + ow * colStride + c] = poolType === 'avg' ? total / Math.max(count, 1) : extreme;
          }
        }
      }
    }
  }
  return output;
}
102975 function maxPool3dPositions(xBuf, convInfo) {
102976 var maxPositions = buffer(convInfo.outShape, 'int32');
102977 var strideDepth = convInfo.strideDepth;
102978 var strideHeight = convInfo.strideHeight;
102979 var strideWidth = convInfo.strideWidth;
102980 var dilationDepth = convInfo.dilationDepth;
102981 var dilationHeight = convInfo.dilationHeight;
102982 var dilationWidth = convInfo.dilationWidth;
102983 var effectiveFilterDepth = convInfo.effectiveFilterDepth;
102984 var effectiveFilterHeight = convInfo.effectiveFilterHeight;
102985 var effectiveFilterWidth = convInfo.effectiveFilterWidth;
102986 var padFront = convInfo.padInfo.front;
102987 var padTop = convInfo.padInfo.top;
102988 var padLeft = convInfo.padInfo.left;
102989 for (var batch = 0; batch < convInfo.batchSize; ++batch) {
102990 for (var channel = 0; channel < convInfo.inChannels; ++channel) {
102991 for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
102992 var xDepthCorner = yDepth * strideDepth - padFront;
102993 var xDepthMin = xDepthCorner;
102994 while (xDepthMin < 0) {
102995 xDepthMin += dilationDepth;
102996 }
102997 var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
102998 for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) {
102999 var xRowCorner = yRow * strideHeight - padTop;
103000 var xRowMin = xRowCorner;
103001 while (xRowMin < 0) {
103002 xRowMin += dilationHeight;
103003 }
103004 var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
103005 for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) {
103006 var xColCorner = yCol * strideWidth - padLeft;
103007 var xColMin = xColCorner;
103008 while (xColMin < 0) {
103009 xColMin += dilationWidth;
103010 }
103011 var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
103012 // Shader code begins
103013 var maxValue = Number.NEGATIVE_INFINITY;
103014 var maxPosition = -1;
103015 for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
103016 var wDepth = xDepth - xDepthCorner;
103017 for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
103018 var wRow = xRow - xRowCorner;
103019 for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
103020 var wCol = xCol - xColCorner;
103021 var pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
103022 if (pixel >= maxValue) {
103023 maxValue = pixel;
103024 maxPosition = wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterHeight + wCol;
103025 }
103026 }
103027 }
103028 }
103029 maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
103030 }
103031 }
103032 }
103033 }
103034 }
103035 return maxPositions;
103036 }
103037
103038 /**
103039 * @license
103040 * Copyright 2020 Google LLC. All Rights Reserved.
103041 * Licensed under the Apache License, Version 2.0 (the "License");
103042 * you may not use this file except in compliance with the License.
103043 * You may obtain a copy of the License at
103044 *
103045 * http://www.apache.org/licenses/LICENSE-2.0
103046 *
103047 * Unless required by applicable law or agreed to in writing, software
103048 * distributed under the License is distributed on an "AS IS" BASIS,
103049 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103050 * See the License for the specific language governing permissions and
103051 * limitations under the License.
103052 * =============================================================================
103053 */
103054 function avgPool$1(args) {
103055 var inputs = args.inputs,
103056 backend = args.backend,
103057 attrs = args.attrs;
103058 var x = inputs.x;
103059 assertNotComplex$1(x, 'avgPool');
103060 var filterSize = attrs.filterSize,
103061 strides = attrs.strides,
103062 pad = attrs.pad,
103063 dimRoundingMode = attrs.dimRoundingMode;
103064 var dilations = 1;
103065 assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
103066 return 'Error in avgPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
103067 });
103068 var convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
103069 var res;
103070 if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) {
103071 res = identity$1({
103072 inputs: {
103073 x: x
103074 },
103075 backend: backend
103076 });
103077 } else {
103078 var xValues = backend.data.get(x.dataId).values;
103079 var _strides = computeStrides(x.shape);
103080 var buffer = pool(xValues, x.shape, x.dtype, _strides, convInfo, 'avg');
103081 res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
103082 }
103083 return res;
103084 }
103085 var avgPoolConfig$1 = {
103086 kernelName: AvgPool,
103087 backendName: 'cpu',
103088 kernelFunc: avgPool$1
103089 };
103090
103091 /**
103092 * @license
103093 * Copyright 2020 Google LLC. All Rights Reserved.
103094 * Licensed under the Apache License, Version 2.0 (the "License");
103095 * you may not use this file except in compliance with the License.
103096 * You may obtain a copy of the License at
103097 *
103098 * http://www.apache.org/licenses/LICENSE-2.0
103099 *
103100 * Unless required by applicable law or agreed to in writing, software
103101 * distributed under the License is distributed on an "AS IS" BASIS,
103102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103103 * See the License for the specific language governing permissions and
103104 * limitations under the License.
103105 * =============================================================================
103106 */
103107 function avgPool3D$1(args) {
103108 var inputs = args.inputs,
103109 backend = args.backend,
103110 attrs = args.attrs;
103111 var x = inputs.x;
103112 var filterSize = attrs.filterSize,
103113 strides = attrs.strides,
103114 pad = attrs.pad,
103115 dimRoundingMode = attrs.dimRoundingMode,
103116 dataFormat = attrs.dataFormat;
103117 assertNotComplex$1(x, 'avgPool3d');
103118 var convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
103119 var xValues = backend.data.get(x.dataId).values;
103120 var outBuf = pool3d(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'avg');
103121 return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
103122 }
103123 var avgPool3DConfig$1 = {
103124 kernelName: AvgPool3D,
103125 backendName: 'cpu',
103126 kernelFunc: avgPool3D$1
103127 };
103128
103129 /**
103130 * @license
103131 * Copyright 2020 Google LLC. All Rights Reserved.
103132 * Licensed under the Apache License, Version 2.0 (the "License");
103133 * you may not use this file except in compliance with the License.
103134 * You may obtain a copy of the License at
103135 *
103136 * http://www.apache.org/licenses/LICENSE-2.0
103137 *
103138 * Unless required by applicable law or agreed to in writing, software
103139 * distributed under the License is distributed on an "AS IS" BASIS,
103140 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103141 * See the License for the specific language governing permissions and
103142 * limitations under the License.
103143 * =============================================================================
103144 */
103145 function avgPool3DGrad$1(args) {
103146 var inputs = args.inputs,
103147 backend = args.backend,
103148 attrs = args.attrs;
103149 var dy = inputs.dy,
103150 input = inputs.input;
103151 var filterSize = attrs.filterSize,
103152 strides = attrs.strides,
103153 pad = attrs.pad,
103154 dimRoundingMode = attrs.dimRoundingMode;
103155 assertNotComplex$1([dy, input], 'avgPool3DGrad');
103156 var convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
103157 var strideDepth = convInfo.strideDepth;
103158 var strideHeight = convInfo.strideHeight;
103159 var strideWidth = convInfo.strideWidth;
103160 var filterDepth = convInfo.filterDepth;
103161 var filterHeight = convInfo.filterHeight;
103162 var filterWidth = convInfo.filterWidth;
103163 var dilationDepth = convInfo.dilationDepth;
103164 var dilationHeight = convInfo.dilationHeight;
103165 var dilationWidth = convInfo.dilationWidth;
103166 var effectiveFilterDepth = convInfo.effectiveFilterDepth;
103167 var effectiveFilterHeight = convInfo.effectiveFilterHeight;
103168 var effectiveFilterWidth = convInfo.effectiveFilterWidth;
103169 var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
103170 var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
103171 var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
103172 var dx = buffer(input.shape, 'float32');
103173 var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
103174 var dyBuf = backend.bufferSync(dy);
103175 for (var batch = 0; batch < convInfo.batchSize; ++batch) {
103176 for (var channel = 0; channel < convInfo.inChannels; ++channel) {
103177 for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
103178 for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
103179 for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
103180 // Shader code begins.
103181 var dyDepthCorner = dxDepth - padFront;
103182 var dyRowCorner = dxRow - padTop;
103183 var dyColCorner = dxCol - padLeft;
103184 var dotProd = 0;
103185 for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
103186 var dyDepth = (dyDepthCorner + wDepth) / strideDepth;
103187 if (dyDepth < 0 || dyDepth >= convInfo.outDepth || Math.floor(dyDepth) !== dyDepth) {
103188 continue;
103189 }
103190 for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
103191 var dyRow = (dyRowCorner + wRow) / strideHeight;
103192 if (dyRow < 0 || dyRow >= convInfo.outHeight || Math.floor(dyRow) !== dyRow) {
103193 continue;
103194 }
103195 for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
103196 var dyCol = (dyColCorner + wCol) / strideWidth;
103197 if (dyCol < 0 || dyCol >= convInfo.outWidth || Math.floor(dyCol) !== dyCol) {
103198 continue;
103199 }
103200 var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
103201 dotProd += pixel;
103202 }
103203 }
103204 }
103205 dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
103206 }
103207 }
103208 }
103209 }
103210 }
103211 return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
103212 }
103213 var avgPool3DGradConfig$1 = {
103214 kernelName: AvgPool3DGrad,
103215 backendName: 'cpu',
103216 kernelFunc: avgPool3DGrad$1
103217 };
103218
103219 /**
103220 * @license
103221 * Copyright 2020 Google LLC. All Rights Reserved.
103222 * Licensed under the Apache License, Version 2.0 (the "License");
103223 * you may not use this file except in compliance with the License.
103224 * You may obtain a copy of the License at
103225 *
103226 * http://www.apache.org/licenses/LICENSE-2.0
103227 *
103228 * Unless required by applicable law or agreed to in writing, software
103229 * distributed under the License is distributed on an "AS IS" BASIS,
103230 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103231 * See the License for the specific language governing permissions and
103232 * limitations under the License.
103233 * =============================================================================
103234 */
103235 function avgPoolGrad$1(args) {
103236 var inputs = args.inputs,
103237 backend = args.backend,
103238 attrs = args.attrs;
103239 var dy = inputs.dy,
103240 input = inputs.input;
103241 var x = input;
103242 assertNotComplex$1([dy, input], 'avgPoolGrad');
103243 var filterSize = attrs.filterSize,
103244 strides = attrs.strides,
103245 pad = attrs.pad;
103246 var convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
103247 var strideHeight = convInfo.strideHeight;
103248 var strideWidth = convInfo.strideWidth;
103249 var filterHeight = convInfo.filterHeight;
103250 var filterWidth = convInfo.filterWidth;
103251 var dilationHeight = convInfo.dilationHeight;
103252 var dilationWidth = convInfo.dilationWidth;
103253 var effectiveFilterHeight = convInfo.effectiveFilterHeight;
103254 var effectiveFilterWidth = convInfo.effectiveFilterWidth;
103255 var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
103256 var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
103257 var dx = buffer(x.shape, 'float32');
103258 var avgMultiplier = 1 / (filterHeight * filterWidth);
103259 var dyData = backend.data.get(dy.dataId).values;
103260 var dyBuf = buffer(dy.shape, 'float32', dyData);
103261 for (var b = 0; b < convInfo.batchSize; ++b) {
103262 for (var d = 0; d < convInfo.inChannels; ++d) {
103263 for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) {
103264 for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) {
103265 // Shader code begins.
103266 var dyRCorner = dxR - padTop;
103267 var dyCCorner = dxC - padLeft;
103268 var dotProd = 0;
103269 for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
103270 var dyR = (dyRCorner + wR) / strideHeight;
103271 if (dyR < 0 || dyR >= convInfo.outHeight || Math.floor(dyR) !== dyR) {
103272 continue;
103273 }
103274 for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
103275 var dyC = (dyCCorner + wC) / strideWidth;
103276 if (dyC < 0 || dyC >= convInfo.outWidth || Math.floor(dyC) !== dyC) {
103277 continue;
103278 }
103279 var pixel = dyBuf.get(b, dyR, dyC, d);
103280 dotProd += pixel;
103281 }
103282 }
103283 dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);
103284 }
103285 }
103286 }
103287 }
103288 return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
103289 }
103290 var avgPoolGradConfig$1 = {
103291 kernelName: AvgPoolGrad,
103292 backendName: 'cpu',
103293 kernelFunc: avgPoolGrad$1
103294 };
103295
103296 /**
103297 * @license
103298 * Copyright 2020 Google LLC. All Rights Reserved.
103299 * Licensed under the Apache License, Version 2.0 (the "License");
103300 * you may not use this file except in compliance with the License.
103301 * You may obtain a copy of the License at
103302 *
103303 * http://www.apache.org/licenses/LICENSE-2.0
103304 *
103305 * Unless required by applicable law or agreed to in writing, software
103306 * distributed under the License is distributed on an "AS IS" BASIS,
103307 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103308 * See the License for the specific language governing permissions and
103309 * limitations under the License.
103310 * =============================================================================
103311 */
103312 function batchNorm$1(args) {
103313 var inputs = args.inputs,
103314 backend = args.backend,
103315 attrs = args.attrs;
103316 var x = inputs.x,
103317 scale = inputs.scale,
103318 offset = inputs.offset,
103319 mean = inputs.mean,
103320 variance = inputs.variance;
103321 assert$1(mean.shape.length === variance.shape.length, function () {
103322 return 'Batch normalization gradient requires mean and variance to have ' + 'equal ranks.';
103323 });
103324 assert$1(offset == null || mean.shape.length === offset.shape.length, function () {
103325 return 'Batch normalization gradient requires mean and offset to have ' + 'equal ranks.';
103326 });
103327 assert$1(scale == null || mean.shape.length === scale.shape.length, function () {
103328 return 'Batch normalization gradient requires mean and scale to have ' + 'equal ranks.';
103329 });
103330 assertNotComplex$1([x, mean, variance, scale, offset], 'batchNorm');
103331 var varianceEpsilon = attrs.varianceEpsilon;
103332 if (varianceEpsilon == null) {
103333 varianceEpsilon = 0.001;
103334 }
103335 var xVals = backend.data.get(x.dataId).values;
103336 var mVals = backend.data.get(mean.dataId).values;
103337 var varVals = backend.data.get(variance.dataId).values;
103338 var sVals = scale ? backend.data.get(scale.dataId).values : new Float32Array([1]);
103339 var offVals = offset ? backend.data.get(offset.dataId).values : new Float32Array([0]);
103340 var outVals = new Float32Array(xVals.length);
103341 var offValsLength = offVals.length;
103342 var sValsLength = sVals.length;
103343 var varValsLength = varVals.length;
103344 var mValsLength = mVals.length;
103345 var offi = 0;
103346 var mi = 0;
103347 var si = 0;
103348 var vi = 0;
103349 for (var i = 0; i < xVals.length; ++i) {
103350 outVals[i] = offVals[offi++] + (xVals[i] - mVals[mi++]) * sVals[si++] / Math.sqrt(varVals[vi++] + varianceEpsilon);
103351 if (offi >= offValsLength) {
103352 offi = 0;
103353 }
103354 if (mi >= mValsLength) {
103355 mi = 0;
103356 }
103357 if (si >= sValsLength) {
103358 si = 0;
103359 }
103360 if (vi >= varValsLength) {
103361 vi = 0;
103362 }
103363 }
103364 return backend.makeTensorInfo(x.shape, x.dtype, outVals);
103365 }
103366 var batchNormConfig$1 = {
103367 kernelName: FusedBatchNorm,
103368 backendName: 'cpu',
103369 kernelFunc: batchNorm$1
103370 };
103371
103372 /**
103373 * @license
103374 * Copyright 2020 Google LLC. All Rights Reserved.
103375 * Licensed under the Apache License, Version 2.0 (the "License");
103376 * you may not use this file except in compliance with the License.
103377 * You may obtain a copy of the License at
103378 *
103379 * http://www.apache.org/licenses/LICENSE-2.0
103380 *
103381 * Unless required by applicable law or agreed to in writing, software
103382 * distributed under the License is distributed on an "AS IS" BASIS,
103383 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103384 * See the License for the specific language governing permissions and
103385 * limitations under the License.
103386 * =============================================================================
103387 */
103388 function batchToSpaceND$1(args) {
103389 var inputs = args.inputs,
103390 backend = args.backend,
103391 attrs = args.attrs;
103392 var x = inputs.x;
103393 var blockShape = attrs.blockShape,
103394 crops = attrs.crops;
103395 assertNotComplex$1([x], 'batchToSpaceND');
103396 var prod = blockShape.reduce(function (a, b) {
103397 return a * b;
103398 });
103399 var reshaped = getReshaped(x.shape, blockShape, prod);
103400 var permuted = getPermuted(reshaped.length, blockShape.length);
103401 var reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
103402 var sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
103403 var sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
103404 var xReshaped = reshape$1({
103405 inputs: {
103406 x: x
103407 },
103408 backend: backend,
103409 attrs: {
103410 shape: reshaped
103411 }
103412 });
103413 var xTransposed = transpose$1({
103414 inputs: {
103415 x: xReshaped
103416 },
103417 backend: backend,
103418 attrs: {
103419 perm: permuted
103420 }
103421 });
103422 var xTransposedReshaped = reshape$1({
103423 inputs: {
103424 x: xTransposed
103425 },
103426 backend: backend,
103427 attrs: {
103428 shape: reshapedPermuted
103429 }
103430 });
103431 var result = slice$1({
103432 inputs: {
103433 x: xTransposedReshaped
103434 },
103435 backend: backend,
103436 attrs: {
103437 begin: sliceBeginCoords,
103438 size: sliceSize
103439 }
103440 });
103441 backend.disposeIntermediateTensorInfo(xReshaped);
103442 backend.disposeIntermediateTensorInfo(xTransposed);
103443 backend.disposeIntermediateTensorInfo(xTransposedReshaped);
103444 return result;
103445 }
103446 var batchToSpaceNDConfig$1 = {
103447 kernelName: BatchToSpaceND,
103448 backendName: 'cpu',
103449 kernelFunc: batchToSpaceND$1
103450 };
103451
103452 /**
103453 * @license
103454 * Copyright 2020 Google LLC. All Rights Reserved.
103455 * Licensed under the Apache License, Version 2.0 (the "License");
103456 * you may not use this file except in compliance with the License.
103457 * You may obtain a copy of the License at
103458 *
103459 * http://www.apache.org/licenses/LICENSE-2.0
103460 *
103461 * Unless required by applicable law or agreed to in writing, software
103462 * distributed under the License is distributed on an "AS IS" BASIS,
103463 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103464 * See the License for the specific language governing permissions and
103465 * limitations under the License.
103466 * =============================================================================
103467 */
103468 function bincount$1(args) {
103469 var inputs = args.inputs,
103470 backend = args.backend,
103471 attrs = args.attrs;
103472 var x = inputs.x,
103473 weights = inputs.weights;
103474 var size = attrs.size;
103475 var xVals = backend.data.get(x.dataId).values;
103476 var weightsVals = backend.data.get(weights.dataId).values;
103477 var outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
103478 return backend.makeTensorInfo([size], weights.dtype, outVals);
103479 }
103480 var bincountConfig$1 = {
103481 kernelName: Bincount,
103482 backendName: 'cpu',
103483 kernelFunc: bincount$1
103484 };
103485
103486 /**
103487 * @license
103488 * Copyright 2021 Google LLC. All Rights Reserved.
103489 * Licensed under the Apache License, Version 2.0 (the "License");
103490 * you may not use this file except in compliance with the License.
103491 * You may obtain a copy of the License at
103492 *
103493 * http://www.apache.org/licenses/LICENSE-2.0
103494 *
103495 * Unless required by applicable law or agreed to in writing, software
103496 * distributed under the License is distributed on an "AS IS" BASIS,
103497 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103498 * See the License for the specific language governing permissions and
103499 * limitations under the License.
103500 * =============================================================================
103501 */
103502 function broadcastArgs$1(args) {
103503 var inputs = args.inputs,
103504 backend = args.backend;
103505 var s0 = inputs.s0,
103506 s1 = inputs.s1;
103507 var s0Vals = backend.data.get(s0.dataId).values;
103508 var s1Vals = backend.data.get(s1.dataId).values;
103509 var broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
103510 return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
103511 }
103512 var broadcastArgsConfig$1 = {
103513 kernelName: BroadcastArgs,
103514 backendName: 'cpu',
103515 kernelFunc: broadcastArgs$1
103516 };
103517
103518 /**
103519 * @license
103520 * Copyright 2020 Google LLC. All Rights Reserved.
103521 * Licensed under the Apache License, Version 2.0 (the License);
103522 * you may not use this file except in compliance with the License.
103523 * You may obtain a copy of the License at
103524 *
103525 * http://www.apache.org/licenses/LICENSE-2.0
103526 *
103527 * Unless required by applicable law or agreed to in writing, software
103528 * distributed under the License is distributed on an AS IS BASIS,
103529 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103530 * See the License for the specific language governing permissions and
103531 * limitations under the License.
103532 * =============================================================================
103533 */
103534 var clipByValue$1 = unaryKernelFunc$1(ClipByValue, function (xi, attrs) {
103535 var clipAttrs = attrs;
103536 if (xi > clipAttrs.clipValueMax) {
103537 return clipAttrs.clipValueMax;
103538 }
103539 return xi < clipAttrs.clipValueMin ? clipAttrs.clipValueMin : xi;
103540 });
103541 var clipByValueConfig$1 = {
103542 kernelName: ClipByValue,
103543 backendName: 'cpu',
103544 kernelFunc: clipByValue$1
103545 };
103546
103547 /**
103548 * @license
103549 * Copyright 2020 Google LLC. All Rights Reserved.
103550 * Licensed under the Apache License, Version 2.0 (the License);
103551 * you may not use this file except in compliance with the License.
103552 * You may obtain a copy of the License at
103553 *
103554 * http://www.apache.org/licenses/LICENSE-2.0
103555 *
103556 * Unless required by applicable law or agreed to in writing, software
103557 * distributed under the License is distributed on an AS IS BASIS,
103558 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103559 * See the License for the specific language governing permissions and
103560 * limitations under the License.
103561 * =============================================================================
103562 */
103563 var complexAbs$1 = function complexAbs(args) {
103564 var x = args.inputs.x;
103565 var cpuBackend = args.backend;
103566 var resultValues = new Float32Array(sizeFromShape(x.shape));
103567 var complexVals = cpuBackend.data.get(x.dataId);
103568 var real = complexVals.complexTensorInfos.real;
103569 var imag = complexVals.complexTensorInfos.imag;
103570 var realVals = cpuBackend.data.get(real.dataId).values;
103571 var imagVals = cpuBackend.data.get(imag.dataId).values;
103572 for (var i = 0; i < realVals.length; i++) {
103573 var _real = realVals[i];
103574 var _imag = imagVals[i];
103575 resultValues[i] = Math.hypot(_real, _imag);
103576 }
103577 return cpuBackend.makeOutput(resultValues, x.shape, 'float32');
103578 };
103579 var complexAbsConfig$1 = {
103580 kernelName: ComplexAbs,
103581 backendName: 'cpu',
103582 kernelFunc: complexAbs$1
103583 };
103584
103585 /**
103586 * @license
103587 * Copyright 2020 Google LLC. All Rights Reserved.
103588 * Licensed under the Apache License, Version 2.0 (the "License");
103589 * you may not use this file except in compliance with the License.
103590 * You may obtain a copy of the License at
103591 *
103592 * http://www.apache.org/licenses/LICENSE-2.0
103593 *
103594 * Unless required by applicable law or agreed to in writing, software
103595 * distributed under the License is distributed on an "AS IS" BASIS,
103596 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103597 * See the License for the specific language governing permissions and
103598 * limitations under the License.
103599 * =============================================================================
103600 */
103601 function imag$1(args) {
103602 var inputs = args.inputs,
103603 backend = args.backend;
103604 var input = inputs.input;
103605 var imag = backend.data.get(input.dataId).complexTensorInfos.imag;
103606 var imagVal = backend.data.get(imag.dataId).values;
103607 // When complex tensor is disposed, its underlying parts will be disposed too.
103608 // Make new tensor out of the imag value of the complex. This makes sure the
103609 // value is still accessible even if complex tensor is disposed.
103610 return backend.makeTensorInfo(imag.shape, imag.dtype, imagVal);
103611 }
103612 var imagConfig$1 = {
103613 kernelName: Imag,
103614 backendName: 'cpu',
103615 kernelFunc: imag$1
103616 };
103617
103618 /**
103619 * @license
103620 * Copyright 2020 Google LLC. All Rights Reserved.
103621 * Licensed under the Apache License, Version 2.0 (the "License");
103622 * you may not use this file except in compliance with the License.
103623 * You may obtain a copy of the License at
103624 *
103625 * http://www.apache.org/licenses/LICENSE-2.0
103626 *
103627 * Unless required by applicable law or agreed to in writing, software
103628 * distributed under the License is distributed on an "AS IS" BASIS,
103629 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103630 * See the License for the specific language governing permissions and
103631 * limitations under the License.
103632 * =============================================================================
103633 */
103634 function concat$1(args) {
103635 var inputs = args.inputs,
103636 backend = args.backend,
103637 attrs = args.attrs;
103638 var axis = attrs.axis;
103639 var $axis = parseAxisParam(axis, inputs[0].shape)[0];
103640 var shapes = inputs.map(function (t) {
103641 return t.shape;
103642 });
103643 assertParamsConsistent(shapes, $axis);
103644 var outShape = computeOutShape$1(inputs.map(function (t) {
103645 return t.shape;
103646 }), $axis);
103647 if (sizeFromShape(outShape) === 0) {
103648 return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
103649 }
103650 // Keep only non-empty tensors (ignore tensors with 0 in their shape).
103651 var $inputs = inputs.filter(function (t) {
103652 return sizeFromShape(t.shape) > 0;
103653 });
103654 if ($inputs.length === 1) {
103655 return identity$1({
103656 inputs: {
103657 x: $inputs[0]
103658 },
103659 backend: backend
103660 });
103661 }
103662 if ($inputs[0].dtype === 'complex64') {
103663 var reals = $inputs.map(function (t) {
103664 return real$1({
103665 inputs: {
103666 input: t
103667 },
103668 backend: backend
103669 });
103670 });
103671 var imags = $inputs.map(function (t) {
103672 return imag$1({
103673 inputs: {
103674 input: t
103675 },
103676 backend: backend
103677 });
103678 });
103679 var realConcated = concat$1({
103680 inputs: reals,
103681 backend: backend,
103682 attrs: {
103683 axis: $axis
103684 }
103685 });
103686 var imagConcated = concat$1({
103687 inputs: imags,
103688 backend: backend,
103689 attrs: {
103690 axis: $axis
103691 }
103692 });
103693 var result = complex$1({
103694 inputs: {
103695 real: realConcated,
103696 imag: imagConcated
103697 },
103698 backend: backend
103699 });
103700 reals.forEach(function (r) {
103701 return backend.disposeIntermediateTensorInfo(r);
103702 });
103703 imags.forEach(function (i) {
103704 return backend.disposeIntermediateTensorInfo(i);
103705 });
103706 backend.disposeIntermediateTensorInfo(realConcated);
103707 backend.disposeIntermediateTensorInfo(imagConcated);
103708 return result;
103709 }
103710 // Any concat of n-dimensional tensors across any axis can be reduced to
103711 // a concatenation of two-dimensional tensors across the axis 1 by first
103712 // partitioning the axes of the original tensors into those less than the
103713 // axis to be concatenated and the rest. Then reshape the tensors
103714 // into a two-dimensional tensor by collapsing these two sets of axes and
103715 // concatenate the resulting matrices across the axis 1, finally reshaping
103716 // the result to have the proper shape.
103717 var inputs2D = $inputs.map(function (t) {
103718 var innerSize = sizeFromShape(t.shape.slice($axis));
103719 var shape = [-1, innerSize];
103720 return reshape$1({
103721 inputs: {
103722 x: t
103723 },
103724 backend: backend,
103725 attrs: {
103726 shape: shape
103727 }
103728 });
103729 });
103730 var inputsValShapes = inputs2D.map(function (t) {
103731 return {
103732 vals: backend.data.get(t.dataId).values,
103733 shape: t.shape
103734 };
103735 });
103736 // Concats 2d tensors along axis=1.
103737 outShape = computeOutShape$1(inputs2D.map(function (t) {
103738 return t.shape;
103739 }), 1 /* axis */);
103740 var simplyConcat = inputs2D[0].shape[0] === 1;
103741 var outVals = concatImpl$1(inputsValShapes, outShape, inputs[0].dtype, simplyConcat);
103742 var finalOutShape = computeOutShape$1($inputs.map(function (t) {
103743 return t.shape;
103744 }), $axis);
103745 var outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);
103746 inputs2D.forEach(function (t) {
103747 return backend.disposeIntermediateTensorInfo(t);
103748 });
103749 return outInfo;
103750 }
/** Kernel config binding `concat$1` to the `Concat` kernel on the 'cpu' backend. */
var concatConfig$1 = {kernelName: Concat, backendName: 'cpu', kernelFunc: concat$1};
103756
103757 /**
103758 * @license
103759 * Copyright 2020 Google LLC. All Rights Reserved.
103760 * Licensed under the Apache License, Version 2.0 (the "License");
103761 * you may not use this file except in compliance with the License.
103762 * You may obtain a copy of the License at
103763 *
103764 * http://www.apache.org/licenses/LICENSE-2.0
103765 *
103766 * Unless required by applicable law or agreed to in writing, software
103767 * distributed under the License is distributed on an "AS IS" BASIS,
103768 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103769 * See the License for the specific language governing permissions and
103770 * limitations under the License.
103771 * =============================================================================
103772 */
/**
 * CPU kernel for 2-D convolution (`Conv2D`).
 *
 * The output buffer is accumulated in place: for every output pixel the kernel
 * walks the filter taps, skips taps whose input row/column lands in the padding
 * region, and adds `x[inChannel] * filter[row][col][inChannel][outChannel]` for
 * every channel pair. Both 'channelsLast' and 'channelsFirst' layouts are
 * handled by choosing the per-axis element strides up front.
 *
 * @param {Object} args `inputs` ({x, filter}), `backend` (whose `data` map holds
 *     the input values) and `attrs` ({strides, pad, dataFormat, dilations,
 *     dimRoundingMode}).
 * @return {Object} TensorInfo with shape `convInfo.outShape` and dtype `x.dtype`.
 */
function conv2D(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var x = inputs.x;
  var filter = inputs.filter;
  assertNotComplex$1([x, filter], 'conv2d');

  var info = computeConv2DInfo(x.shape, filter.shape, attrs.strides,
      attrs.dilations, attrs.pad, attrs.dimRoundingMode, false /* depthwise */,
      convertConv2DDataFormat(attrs.dataFormat));
  var channelsLast = info.dataFormat === 'channelsLast';
  var out = new TensorBuffer(info.outShape, x.dtype);
  var xs = computeStrides(x.shape);
  var fs = computeStrides(filter.shape);

  // Element strides along batch / row / column / channel. In the
  // channelsFirst layout the column axis is the contiguous one.
  var xStrideB = xs[0];
  var xStrideR = channelsLast ? xs[1] : xs[2];
  var xStrideC = channelsLast ? xs[2] : 1;
  var xStrideCh = channelsLast ? 1 : xs[1];
  var yStrideB = out.strides[0];
  var yStrideR = channelsLast ? out.strides[1] : out.strides[2];
  var yStrideC = channelsLast ? out.strides[2] : 1;
  var yStrideCh = channelsLast ? 1 : out.strides[1];

  var xVals = backend.data.get(x.dataId).values;
  var wVals = backend.data.get(filter.dataId).values;
  var yVals = out.values;
  var outChannels = info.outChannels;

  for (var n = 0; n < info.batchSize; ++n) {
    var xBatchBase = n * xStrideB;
    var yBatchBase = n * yStrideB;
    for (var row = 0; row < info.outHeight; ++row) {
      var yRowBase = yBatchBase + row * yStrideR;
      var firstSrcRow = row * info.strideHeight - info.padInfo.top;
      for (var kr = 0; kr < info.filterHeight; ++kr) {
        var srcRow = firstSrcRow + kr * info.dilationHeight;
        if (srcRow < 0 || srcRow >= info.inHeight) {
          continue; // tap falls in the top/bottom padding
        }
        var xRowBase = xBatchBase + srcRow * xStrideR;
        var wRowBase = kr * fs[0];
        for (var col = 0; col < info.outWidth; ++col) {
          var yPixel = yRowBase + col * yStrideC;
          var firstSrcCol = col * info.strideWidth - info.padInfo.left;
          for (var kc = 0; kc < info.filterWidth; ++kc) {
            var srcCol = firstSrcCol + kc * info.dilationWidth;
            if (srcCol < 0 || srcCol >= info.inWidth) {
              continue; // tap falls in the left/right padding
            }
            var xPixel = xRowBase + srcCol * xStrideC;
            var wTap = wRowBase + kc * fs[1];
            for (var ci = 0; ci < info.inChannels; ++ci) {
              var xv = xVals[xPixel + ci * xStrideCh];
              // Each input channel owns a contiguous run of `outChannels`
              // weights within a filter tap.
              var wRun = wTap + ci * outChannels;
              for (var co = 0; co < outChannels; ++co) {
                yVals[yPixel + co * yStrideCh] += xv * wVals[wRun + co];
              }
            }
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(out.shape, out.dtype, yVals);
}
var conv2DConfig$1 = {kernelName: Conv2D$1, backendName: 'cpu', kernelFunc: conv2D};
103851
103852 /**
103853 * @license
103854 * Copyright 2020 Google LLC. All Rights Reserved.
103855 * Licensed under the Apache License, Version 2.0 (the "License");
103856 * you may not use this file except in compliance with the License.
103857 * You may obtain a copy of the License at
103858 *
103859 * http://www.apache.org/licenses/LICENSE-2.0
103860 *
103861 * Unless required by applicable law or agreed to in writing, software
103862 * distributed under the License is distributed on an "AS IS" BASIS,
103863 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103864 * See the License for the specific language governing permissions and
103865 * limitations under the License.
103866 * =============================================================================
103867 */
/**
 * CPU kernel for the filter gradient of 2-D convolution (`Conv2DBackpropFilter`).
 *
 * Each filter element (row, col, inChannel, outChannel) is the dot product of
 * `x` and `dy`, summed over the batch and over every output position whose
 * receptive field places that tap inside the unpadded input. The convolution
 * geometry is computed with dilation fixed to 1.
 *
 * @param {Object} args `inputs` ({x, dy}), `backend`, and `attrs`
 *     ({strides, pad, dataFormat, dimRoundingMode, filterShape}).
 * @return {Object} float32 TensorInfo with shape `convInfo.filterShape`.
 */
function conv2DBackpropFilter$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var x = inputs.x;
  var dy = inputs.dy;
  assertNotComplex$1([x, dy], 'conv2dBackpropFilter');

  var info = computeConv2DInfo(x.shape, attrs.filterShape, attrs.strides,
      1 /* dilations */, attrs.pad, attrs.dimRoundingMode,
      false /* depthwise */, convertConv2DDataFormat(attrs.dataFormat));
  var sh = info.strideHeight;
  var sw = info.strideWidth;
  var channelsLast = info.dataFormat === 'channelsLast';
  var grad = new TensorBuffer(info.filterShape, 'float32');
  var padLeft = info.padInfo.left;
  var padTop = info.padInfo.top;

  var xData = backend.data.get(x.dataId).values;
  var dyData = backend.data.get(dy.dataId).values;
  var xBuf = new TensorBuffer(x.shape, x.dtype, xData);
  var dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyData);

  for (var kr = 0; kr < info.filterHeight; ++kr) {
    // Output rows r for which input row (kr + r * sh - padTop) is in range.
    var rFrom = Math.max(0, Math.ceil((padTop - kr) / sh));
    var rTo = Math.min(info.outHeight, (info.inHeight + padTop - kr) / sh);
    for (var kc = 0; kc < info.filterWidth; ++kc) {
      var cFrom = Math.max(0, Math.ceil((padLeft - kc) / sw));
      var cTo = Math.min(info.outWidth, (info.inWidth + padLeft - kc) / sw);
      for (var ci = 0; ci < info.inChannels; ++ci) {
        for (var co = 0; co < info.outChannels; ++co) {
          var sum = 0;
          for (var n = 0; n < info.batchSize; ++n) {
            for (var r = rFrom; r < rTo; ++r) {
              var srcRow = kr + r * sh - padTop;
              for (var c = cFrom; c < cTo; ++c) {
                var srcCol = kc + c * sw - padLeft;
                sum += channelsLast ?
                    xBuf.get(n, srcRow, srcCol, ci) * dyBuf.get(n, r, c, co) :
                    xBuf.get(n, ci, srcRow, srcCol) * dyBuf.get(n, co, r, c);
              }
            }
          }
          grad.set(sum, kr, kc, ci, co);
        }
      }
    }
  }
  return backend.makeTensorInfo(grad.shape, grad.dtype, grad.values);
}
var conv2DBackpropFilterConfig$1 = {
  kernelName: Conv2DBackpropFilter,
  backendName: 'cpu',
  kernelFunc: conv2DBackpropFilter$1
};
103928
/**
 * CPU kernel for the input gradient of 2-D convolution (`Conv2DBackpropInput`).
 *
 * Each input position (batch, row, col, inChannel) receives the sum of
 * `dy * filter` over every output position whose receptive field covers it.
 * Filter taps are read in flipped order (`filterHeight - 1 - tapRow`,
 * `filterWidth - 1 - tapCol`), with effective padding `filterSize - 1 - pad`.
 * Dilation is fixed to 1.
 *
 * @param {Object} args `inputs` ({dy, filter}), `backend`, and `attrs`
 *     ({inputShape, strides, pad, dataFormat, dimRoundingMode}).
 * @return {Object} float32 TensorInfo with shape `convInfo.inShape`.
 */
function conv2DBackpropInput$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var dy = inputs.dy;
  var filter = inputs.filter;
  assertNotComplex$1([dy, filter], 'conv2dBackpropInput');

  var fltStrides = computeStrides(filter.shape);
  var dyStrides = computeStrides(dy.shape);
  var info = computeConv2DInfo(attrs.inputShape, filter.shape, attrs.strides,
      1 /* dilations */, attrs.pad, attrs.dimRoundingMode, false,
      convertConv2DDataFormat(attrs.dataFormat));

  var dx = new TensorBuffer(info.inShape, 'float32');
  var dxValues = dx.values;
  var dyValues = backend.data.get(dy.dataId).values;
  var fltValues = backend.data.get(filter.dataId).values;

  var fltRowStride = fltStrides[0];
  var fltColStride = fltStrides[1];
  var fltInStride = fltStrides[2];
  var fh = info.filterHeight;
  var fw = info.filterWidth;
  var sh = info.strideHeight;
  var sw = info.strideWidth;
  var outChannels = info.outChannels;
  // Padding of the equivalent correlation with the flipped filter.
  var topPad = fh - 1 - info.padInfo.top;
  var leftPad = fw - 1 - info.padInfo.left;

  var channelsLast = info.dataFormat === 'channelsLast';
  var dxB = dx.strides[0];
  var dxR = channelsLast ? dx.strides[1] : dx.strides[2];
  var dxC = channelsLast ? dx.strides[2] : 1;
  var dxCh = channelsLast ? 1 : dx.strides[1];
  var dyB = dyStrides[0];
  var dyR = channelsLast ? dyStrides[1] : dyStrides[2];
  var dyC = channelsLast ? dyStrides[2] : 1;
  var dyCh = channelsLast ? 1 : dyStrides[1];

  for (var n = 0; n < info.batchSize; ++n) {
    for (var ci = 0; ci < info.inChannels; ++ci) {
      for (var row = 0; row < info.inHeight; ++row) {
        var rowCorner = row - topPad;
        var rFrom = Math.max(0, Math.ceil(rowCorner / sh));
        var rTo = Math.min(info.outHeight, (fh + rowCorner) / sh);
        for (var col = 0; col < info.inWidth; ++col) {
          var colCorner = col - leftPad;
          var cFrom = Math.max(0, Math.ceil(colCorner / sw));
          var cTo = Math.min(info.outWidth, (fw + colCorner) / sw);
          var acc = 0;
          for (var r = rFrom; r < rTo; ++r) {
            var tapRow = r * sh - rowCorner;
            for (var c = cFrom; c < cTo; ++c) {
              var tapCol = c * sw - colCorner;
              var dyBase = dyB * n + dyR * r + dyC * c;
              var fltBase = fltRowStride * (fh - 1 - tapRow) +
                  fltColStride * (fw - 1 - tapCol) + fltInStride * ci;
              for (var co = 0; co < outChannels; ++co) {
                acc += dyValues[dyBase + dyCh * co] * fltValues[fltBase + co];
              }
            }
          }
          dxValues[dxB * n + dxR * row + dxC * col + dxCh * ci] = acc;
        }
      }
    }
  }
  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var conv2DBackpropInputConfig$1 = {
  kernelName: Conv2DBackpropInput,
  backendName: 'cpu',
  kernelFunc: conv2DBackpropInput$1
};
104013
104014 /**
104015 * @license
104016 * Copyright 2020 Google LLC. All Rights Reserved.
104017 * Licensed under the Apache License, Version 2.0 (the "License");
104018 * you may not use this file except in compliance with the License.
104019 * You may obtain a copy of the License at
104020 *
104021 * http://www.apache.org/licenses/LICENSE-2.0
104022 *
104023 * Unless required by applicable law or agreed to in writing, software
104024 * distributed under the License is distributed on an "AS IS" BASIS,
104025 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104026 * See the License for the specific language governing permissions and
104027 * limitations under the License.
104028 * =============================================================================
104029 */
/**
 * CPU kernel for 3-D convolution (`Conv3D`).
 *
 * The output buffer is accumulated in place: for every output voxel each filter
 * tap whose input depth/row/column is inside the unpadded input adds
 * `x[..., inChannel] * filter[d][r][c][inChannel][outChannel]`. Channels are
 * indexed as the innermost axis of both `x` and the output.
 *
 * @param {Object} args `inputs` ({x, filter}), `backend`, and `attrs`
 *     ({strides, pad, dilations}).
 * @return {Object} TensorInfo with shape `convInfo.outShape`, dtype `x.dtype`.
 */
function conv3D$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var x = inputs.x;
  var filter = inputs.filter;
  assertNotComplex$1([x, filter], 'conv3d');

  var info = computeConv3DInfo(x.shape, filter.shape, attrs.strides, attrs.dilations, attrs.pad);
  var padding = info.padInfo;
  var inC = info.inChannels;
  var outC = info.outChannels;
  var out = new TensorBuffer(info.outShape, x.dtype);
  var xVals = backend.data.get(x.dataId).values;
  var wVals = backend.data.get(filter.dataId).values;
  var yVals = out.values;
  var xs = computeStrides(x.shape);
  var ws = computeStrides(filter.shape);
  var ys = out.strides;

  for (var n = 0; n < info.batchSize; ++n) {
    var xBatch = n * xs[0];
    var yBatch = n * ys[0];
    for (var od = 0; od < info.outDepth; ++od) {
      var yDepth = yBatch + od * ys[1];
      var depthStart = od * info.strideDepth - padding.front;
      for (var kd = 0; kd < info.filterDepth; ++kd) {
        var srcD = depthStart + kd * info.dilationDepth;
        if (srcD < 0 || srcD >= info.inDepth) {
          continue; // tap falls in front/back padding
        }
        var wDepth = kd * ws[0];
        var xDepth = xBatch + srcD * xs[1];
        for (var oh = 0; oh < info.outHeight; ++oh) {
          var yRow = yDepth + oh * ys[2];
          var rowStart = oh * info.strideHeight - padding.top;
          for (var kh = 0; kh < info.filterHeight; ++kh) {
            var srcR = rowStart + kh * info.dilationHeight;
            if (srcR < 0 || srcR >= info.inHeight) {
              continue; // tap falls in top/bottom padding
            }
            var wRow = wDepth + kh * ws[1];
            var xRow = xDepth + srcR * xs[2];
            for (var ow = 0; ow < info.outWidth; ++ow) {
              var yVoxel = yRow + ow * outC;
              var colStart = ow * info.strideWidth - padding.left;
              for (var kw = 0; kw < info.filterWidth; ++kw) {
                var srcC = colStart + kw * info.dilationWidth;
                if (srcC < 0 || srcC >= info.inWidth) {
                  continue; // tap falls in left/right padding
                }
                var wTap = wRow + kw * ws[2];
                var xVoxel = xRow + srcC * inC;
                for (var ci = 0; ci < inC; ++ci) {
                  var xv = xVals[xVoxel + ci];
                  // Each input channel owns `outC` consecutive weights.
                  var wRun = wTap + ci * outC;
                  for (var co = 0; co < outC; ++co) {
                    yVals[yVoxel + co] += xv * wVals[wRun + co];
                  }
                }
              }
            }
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(out.shape, out.dtype, out.values);
}
var conv3DConfig$1 = {kernelName: Conv3D$1, backendName: 'cpu', kernelFunc: conv3D$1};
104112
/**
 * CPU kernel for the filter gradient of 3-D convolution
 * (`Conv3DBackpropFilterV2`).
 *
 * Each filter element (depth, row, col, inChannel, outChannel) is the dot
 * product of `x[..., inChannel]` and `dy[..., outChannel]`, summed over the
 * batch and over every output voxel whose receptive field places that tap
 * inside the unpadded input. Dilation is fixed to 1, and channels are indexed
 * as the innermost axis of `x` and `dy`.
 *
 * @param {Object} args `inputs` ({x, dy}), `backend`, and `attrs`
 *     ({strides, pad, filterShape}).
 * @return {Object} float32 TensorInfo with shape `convInfo.filterShape`.
 */
function conv3DBackpropFilterV2$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var x = inputs.x;
  var dy = inputs.dy;
  assertNotComplex$1([x, dy], 'conv3dBackpropFilterV2');

  var xs = computeStrides(x.shape);
  var dys = computeStrides(dy.shape);
  var info = computeConv3DInfo(x.shape, attrs.filterShape, attrs.strides, 1 /* dilations */, attrs.pad);
  var sd = info.strideDepth;
  var sh = info.strideHeight;
  var sw = info.strideWidth;
  var grad = new TensorBuffer(info.filterShape, 'float32');
  var gradVals = grad.values;
  var gs = grad.strides;
  var dyVals = backend.data.get(dy.dataId).values;
  var xVals = backend.data.get(x.dataId).values;
  var padFront = info.padInfo.front;
  var padLeft = info.padInfo.left;
  var padTop = info.padInfo.top;

  for (var kd = 0; kd < info.filterDepth; ++kd) {
    // Output depths for which input depth (kd + d * sd - padFront) is in range.
    var dFrom = Math.max(0, Math.ceil((padFront - kd) / sd));
    var dTo = Math.min(info.outDepth, (info.inDepth + padFront - kd) / sd);
    var gDepth = kd * gs[0];
    for (var kh = 0; kh < info.filterHeight; ++kh) {
      var hFrom = Math.max(0, Math.ceil((padTop - kh) / sh));
      var hTo = Math.min(info.outHeight, (info.inHeight + padTop - kh) / sh);
      var gRow = kh * gs[1] + gDepth;
      for (var kw = 0; kw < info.filterWidth; ++kw) {
        var wFrom = Math.max(0, Math.ceil((padLeft - kw) / sw));
        var wTo = Math.min(info.outWidth, (info.inWidth + padLeft - kw) / sw);
        var gTap = kw * gs[2] + gRow;
        for (var ci = 0; ci < info.inChannels; ++ci) {
          var gRun = ci * gs[3] + gTap;
          for (var co = 0; co < info.outChannels; ++co) {
            var sum = 0;
            for (var n = 0; n < info.batchSize; ++n) {
              var xBatch = n * xs[0];
              var dyBatch = n * dys[0];
              for (var od = dFrom; od < dTo; ++od) {
                var xDepth = (kd + od * sd - padFront) * xs[1] + xBatch;
                var dyDepth = od * dys[1] + dyBatch;
                for (var oh = hFrom; oh < hTo; ++oh) {
                  var xRow = (kh + oh * sh - padTop) * xs[2] + xDepth;
                  var dyRow = oh * dys[2] + dyDepth;
                  for (var ow = wFrom; ow < wTo; ++ow) {
                    var xVoxel = (kw + ow * sw - padLeft) * xs[3] + xRow;
                    var dyVoxel = ow * dys[3] + dyRow;
                    sum += xVals[xVoxel + ci] * dyVals[dyVoxel + co];
                  }
                }
              }
            }
            gradVals[gRun + co] = sum;
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(grad.shape, grad.dtype, grad.values);
}
var conv3DBackpropFilterV2Config$1 = {
  kernelName: Conv3DBackpropFilterV2,
  backendName: 'cpu',
  kernelFunc: conv3DBackpropFilterV2$1
};
104203
/**
 * CPU kernel for the input gradient of 3-D convolution
 * (`Conv3DBackpropInputV2`).
 *
 * Every input voxel (batch, depth, row, col, inChannel) receives the sum of
 * `dy * filter` over the output voxels whose receptive field covers it. Filter
 * taps are read in flipped order, with effective padding
 * `filterSize - 1 - pad` on each spatial axis. Dilation is fixed to 1. Channels
 * are indexed as the innermost axis of `dy`, `filter` and `dx`.
 *
 * @param {Object} args `inputs` ({dy, filter}), `backend`, and `attrs`
 *     ({pad, strides, inputShape}).
 * @return {Object} float32 TensorInfo with shape `convInfo.inShape`.
 * Complex operands are rejected up front via `assertNotComplex$1`.
 */
function conv3DBackpropInputV2(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var dy = inputs.dy;
  var filter = inputs.filter;
  var pad = attrs.pad;
  var strides = attrs.strides;
  var inputShape = attrs.inputShape;
  // Both dy and filter are read below as plain real-valued arrays, so a
  // complex64 operand on either side must be rejected, not just dy.
  assertNotComplex$1([dy, filter], 'conv3dBackpropInputV2');
  var dyStrides = computeStrides(dy.shape);
  var filterStrides = computeStrides(filter.shape);
  var convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);

  var dx = new TensorBuffer(convInfo.inShape, 'float32');
  var dxValues = dx.values;
  var dxS0 = dx.strides[0];
  var dxS1 = dx.strides[1];
  var dxS2 = dx.strides[2];
  var dxS3 = dx.strides[3];
  var dyValues = backend.data.get(dy.dataId).values;
  var dyS0 = dyStrides[0];
  var dyS1 = dyStrides[1];
  var dyS2 = dyStrides[2];
  var dyS3 = dyStrides[3];
  var fltValues = backend.data.get(filter.dataId).values;
  var fltS0 = filterStrides[0];
  var fltS1 = filterStrides[1];
  var fltS2 = filterStrides[2];
  var fltS3 = filterStrides[3];

  var batchSize = convInfo.batchSize;
  var filterDepth = convInfo.filterDepth;
  var filterHeight = convInfo.filterHeight;
  var filterWidth = convInfo.filterWidth;
  var inChannels = convInfo.inChannels;
  var inDepth = convInfo.inDepth;
  var inHeight = convInfo.inHeight;
  var inWidth = convInfo.inWidth;
  var outChannels = convInfo.outChannels;
  var outDepth = convInfo.outDepth;
  var outHeight = convInfo.outHeight;
  var outWidth = convInfo.outWidth;
  var strideDepth = convInfo.strideDepth;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  // Padding of the equivalent correlation with the flipped filter.
  var frontPad = filterDepth - 1 - convInfo.padInfo.front;
  var topPad = filterHeight - 1 - convInfo.padInfo.top;
  var leftPad = filterWidth - 1 - convInfo.padInfo.left;

  for (var b = 0; b < batchSize; ++b) {
    for (var d1 = 0; d1 < inChannels; ++d1) {
      // Frames of depth
      for (var xF = 0; xF < inDepth; ++xF) {
        var xFCorner = xF - frontPad;
        var xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
        var yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);
        // Rows as per standard 2d matrix notation
        for (var xR = 0; xR < inHeight; ++xR) {
          var xRCorner = xR - topPad;
          var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
          var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
          // Columns as per standard 2d matrix notation
          for (var xC = 0; xC < inWidth; ++xC) {
            var xCCorner = xC - leftPad;
            var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
            var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
            var dotProd = 0;
            for (var yF = xFMin; yF < yFMax; ++yF) {
              var wF = yF * strideDepth - xFCorner;
              for (var yR = xRMin; yR < yRMax; ++yR) {
                var wR = yR * strideHeight - xRCorner;
                for (var yC = xCMin; yC < yCMax; ++yC) {
                  var wC = yC * strideWidth - xCCorner;
                  var dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
                  // Flipped filter tap for this (input, output) pair.
                  var fltOffset = fltS0 * (filterDepth - 1 - wF) +
                      fltS1 * (filterHeight - 1 - wR) +
                      fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
                  for (var d2 = 0; d2 < outChannels; ++d2) {
                    var pixel = dyValues[dyOffset + d2];
                    var weight = fltValues[fltOffset + d2];
                    dotProd += pixel * weight;
                  }
                }
              }
            }
            dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] = dotProd;
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var conv3DBackpropInputV2Config = {
  kernelName: Conv3DBackpropInputV2,
  backendName: 'cpu',
  kernelFunc: conv3DBackpropInputV2
};
104301
104302 /**
104303 * @license
104304 * Copyright 2020 Google LLC. All Rights Reserved.
104305 * Licensed under the Apache License, Version 2.0 (the "License");
104306 * you may not use this file except in compliance with the License.
104307 * You may obtain a copy of the License at
104308 *
104309 * http://www.apache.org/licenses/LICENSE-2.0
104310 *
104311 * Unless required by applicable law or agreed to in writing, software
104312 * distributed under the License is distributed on an "AS IS" BASIS,
104313 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104314 * See the License for the specific language governing permissions and
104315 * limitations under the License.
104316 * =============================================================================
104317 */
/**
 * Element-wise cosine kernel for the CPU backend. `Math.cos` is handed to the
 * unary-kernel factory `unaryKernelFunc$1` as the scalar operation. Its arity
 * of 1 matches the original wrapper, and it ignores any extra arguments.
 */
var cos$1 = unaryKernelFunc$1(Cos, Math.cos);
var cosConfig$1 = {kernelName: Cos, backendName: 'cpu', kernelFunc: cos$1};
104326
104327 /**
104328 * @license
104329 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
104331 * you may not use this file except in compliance with the License.
104332 * You may obtain a copy of the License at
104333 *
104334 * http://www.apache.org/licenses/LICENSE-2.0
104335 *
104336 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
104338 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104339 * See the License for the specific language governing permissions and
104340 * limitations under the License.
104341 * =============================================================================
104342 */
/**
 * Element-wise hyperbolic-cosine kernel for the CPU backend. `Math.cosh` is
 * handed to the unary-kernel factory `unaryKernelFunc$1` as the scalar
 * operation. Its arity of 1 matches the original wrapper, and it ignores any
 * extra arguments.
 */
var cosh$1 = unaryKernelFunc$1(Cosh, Math.cosh);
var coshConfig$1 = {kernelName: Cosh, backendName: 'cpu', kernelFunc: cosh$1};
104351
/**
 * CPU kernel for `CropAndResize`.
 *
 * For each box `b`, the coordinates `[y1, x1, y2, x2]` are read from `boxes`
 * and scaled by `(imageHeight - 1)` / `(imageWidth - 1)`. The kernel samples a
 * `cropHeight x cropWidth` grid from image `boxInd[b]`. It uses bilinear
 * interpolation when `method === 'bilinear'` and rounds to the nearest pixel
 * for any other method. Sample points outside the image are written as
 * `extrapolationValue`.
 *
 * Boxes whose image index lies outside `[0, batch)` are skipped, so their
 * output region is left as allocated by `buffer`. This matches the bounds check
 * in TensorFlow's reference kernel:
 * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc
 *
 * @param {Object} args `inputs` ({image, boxes, boxInd}), `backend`, and
 *     `attrs` ({cropSize, method, extrapolationValue}).
 * @return {Object} float32 TensorInfo of shape
 *     `[numBoxes, cropHeight, cropWidth, numChannels]`.
 */
function cropAndResize$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var image = inputs.image;
  var boxes = inputs.boxes;
  var boxInd = inputs.boxInd;
  var cropSize = attrs.cropSize;
  var method = attrs.method;
  var extrapolationValue = attrs.extrapolationValue;
  var batch = image.shape[0];
  var imageHeight = image.shape[1];
  var imageWidth = image.shape[2];
  var numChannels = image.shape[3];
  var numBoxes = boxes.shape[0];
  var cropHeight = cropSize[0];
  var cropWidth = cropSize[1];
  var output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
  var outVals = output.values;
  var boxVals = backend.data.get(boxes.dataId).values;
  var boxIndVals = backend.data.get(boxInd.dataId).values;
  var imageVals = backend.data.get(image.dataId).values;
  var inStride = computeStrides(image.shape); // flat-index strides into image
  var outStride = computeStrides(output.shape); // flat-index strides into output

  // Writes extrapolationValue into every channel of the output pixel whose
  // first channel is at flat index `outPixel`.
  function fillExtrapolated(outPixel) {
    for (var c = 0; c < numChannels; c++) {
      outVals[outPixel + c] = extrapolationValue;
    }
  }

  for (var b = 0; b < numBoxes; b++) {
    var startInd = b * 4;
    var y1 = boxVals[startInd];
    var x1 = boxVals[startInd + 1];
    var y2 = boxVals[startInd + 2];
    var x2 = boxVals[startInd + 3];
    var bInd = boxIndVals[b];
    // A negative index would read before the start of the image data just
    // as an index >= batch would read past its end; skip both.
    if (bInd < 0 || bInd >= batch) {
      continue;
    }
    var imageBase = bInd * inStride[0];
    var heightScale = cropHeight > 1 ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
    var widthScale = cropWidth > 1 ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;

    for (var y = 0; y < cropHeight; y++) {
      var outRowBase = y * outStride[1] + b * outStride[0];
      var yInd = cropHeight > 1 ? y1 * (imageHeight - 1) + y * heightScale : 0.5 * (y1 + y2) * (imageHeight - 1);
      if (yInd < 0 || yInd > imageHeight - 1) {
        for (var fillX = 0; fillX < cropWidth; fillX++) {
          fillExtrapolated(fillX * outStride[2] + outRowBase);
        }
        continue;
      }
      var topRow = imageBase + Math.floor(yInd) * inStride[1];
      var bottomRow = imageBase + Math.ceil(yInd) * inStride[1];
      var yLerp = yInd - Math.floor(yInd);
      var nearestRow = imageBase + Math.round(yInd) * inStride[1];

      for (var x = 0; x < cropWidth; x++) {
        var outPixel = x * outStride[2] + outRowBase;
        var xInd = cropWidth > 1 ? x1 * (imageWidth - 1) + x * widthScale : 0.5 * (x1 + x2) * (imageWidth - 1);
        if (xInd < 0 || xInd > imageWidth - 1) {
          fillExtrapolated(outPixel);
          continue;
        }
        if (method === 'bilinear') {
          var leftCol = Math.floor(xInd) * inStride[2];
          var rightCol = Math.ceil(xInd) * inStride[2];
          var xLerp = xInd - Math.floor(xInd);
          for (var c = 0; c < numChannels; c++) {
            var topLeft = imageVals[topRow + leftCol + c];
            var topRight = imageVals[topRow + rightCol + c];
            var bottomLeft = imageVals[bottomRow + leftCol + c];
            var bottomRight = imageVals[bottomRow + rightCol + c];
            var top = topLeft + (topRight - topLeft) * xLerp;
            var bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
            outVals[outPixel + c] = top + (bottom - top) * yLerp;
          }
        } else {
          // method == "nearest"
          var inPixel = nearestRow + Math.round(xInd) * inStride[2];
          for (var ch = 0; ch < numChannels; ch++) {
            outVals[outPixel + ch] = imageVals[inPixel + ch];
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(output.shape, output.dtype, output.values);
}
var cropAndResizeConfig$1 = {
  kernelName: CropAndResize,
  backendName: 'cpu',
  kernelFunc: cropAndResize$1
};
104463
104464 /**
104465 * @license
104466 * Copyright 2022 Google LLC. All Rights Reserved.
104467 * Licensed under the Apache License, Version 2.0 (the "License");
104468 * you may not use this file except in compliance with the License.
104469 * You may obtain a copy of the License at
104470 *
104471 * http://www.apache.org/licenses/LICENSE-2.0
104472 *
104473 * Unless required by applicable law or agreed to in writing, software
104474 * distributed under the License is distributed on an "AS IS" BASIS,
104475 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104476 * See the License for the specific language governing permissions and
104477 * limitations under the License.
104478 * =============================================================================
104479 */
/**
 * CPU kernel for Cumprod: running product of `x` along `attrs.axis`.
 *
 * If the scan axis is not already the innermost one, `x` is transposed so it
 * is. Each contiguous row of length `rowLen` is then scanned on its own, and
 * the result is transposed back to the original layout.
 *
 * @param {Object} args Kernel arguments: `inputs.x`, `backend`, and `attrs`
 *     with `axis`, `exclusive` and `reverse`.
 *   - `exclusive`: each element receives the product of the elements strictly
 *     before it, so the first element of every row receives 1.
 *   - `reverse`: each row is scanned from its last element to its first.
 * @returns {Object} TensorInfo with dtype `upcastType(x.dtype, 'int32')`.
 */
function cumprod$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var axis = args.attrs.axis;
  var exclusive = args.attrs.exclusive;
  var reverse = args.attrs.reverse;
  assertNotComplex$1(x, 'cumprod');

  var perm = getAxesPermutation([axis], x.shape.length);
  var source = x;
  if (perm != null) {
    source = transpose$1({inputs: {x: x}, backend: backend, attrs: {perm: perm}});
  }
  var innerAxis = getInnerMostAxes(1, x.shape.length)[0];
  var lastAxis = source.shape.length - 1;
  if (innerAxis !== lastAxis) {
    throw new Error("backend.cumprod in CPU expects an inner-most " + "axis=".concat(lastAxis, " but got axis=").concat(innerAxis));
  }

  var outDtype = upcastType(source.dtype, 'int32');
  var out = makeOnesTypedArray(sizeFromShape(source.shape), outDtype);
  var src = backend.data.get(source.dataId).values;
  var rowLen = source.shape[lastAxis];
  // Rows are walked forwards (+1) or, when reversed, backwards (-1).
  var step = reverse ? -1 : 1;
  for (var rowStart = 0; rowStart < src.length; rowStart += rowLen) {
    var pos = reverse ? rowStart + rowLen - 1 : rowStart;
    for (var k = 0; k < rowLen; k++, pos += step) {
      if (k === 0) {
        // Seed of the scan: the empty product for exclusive mode.
        out[pos] = exclusive ? 1 : src[pos];
        continue;
      }
      var prev = pos - step;
      out[pos] = exclusive ? src[prev] * out[prev] : src[pos] * out[prev];
    }
  }

  var result = backend.makeTensorInfo(source.shape, outDtype, out);
  if (perm == null) {
    return result;
  }
  // Restore the caller's layout and release both intermediates.
  var restored = transpose$1({inputs: {x: result}, backend: backend, attrs: {perm: getUndoAxesPermutation(perm)}});
  backend.disposeIntermediateTensorInfo(result);
  backend.disposeIntermediateTensorInfo(source);
  return restored;
}
var cumprodConfig$1 = {
  kernelName: Cumprod,
  backendName: 'cpu',
  kernelFunc: cumprod$1
};
104549
104550 /**
104551 * @license
104552 * Copyright 2020 Google LLC. All Rights Reserved.
104553 * Licensed under the Apache License, Version 2.0 (the "License");
104554 * you may not use this file except in compliance with the License.
104555 * You may obtain a copy of the License at
104556 *
104557 * http://www.apache.org/licenses/LICENSE-2.0
104558 *
104559 * Unless required by applicable law or agreed to in writing, software
104560 * distributed under the License is distributed on an "AS IS" BASIS,
104561 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104562 * See the License for the specific language governing permissions and
104563 * limitations under the License.
104564 * =============================================================================
104565 */
/**
 * CPU kernel for Cumsum: running sum of `x` along `attrs.axis`.
 *
 * If the scan axis is not already the innermost one, `x` is transposed so it
 * is. Each contiguous row of length `rowLen` is then scanned on its own, and
 * the result is transposed back to the original layout.
 *
 * @param {Object} args Kernel arguments: `inputs.x`, `backend`, and `attrs`
 *     with `axis`, `exclusive` and `reverse`.
 *   - `exclusive`: each element receives the sum of the elements strictly
 *     before it, so the first element of every row receives 0.
 *   - `reverse`: each row is scanned from its last element to its first.
 * @returns {Object} TensorInfo with dtype `upcastType(x.dtype, 'int32')`.
 */
function cumsum$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var axis = args.attrs.axis;
  var exclusive = args.attrs.exclusive;
  var reverse = args.attrs.reverse;
  assertNotComplex$1(x, 'cumsum');

  var perm = getAxesPermutation([axis], x.shape.length);
  var source = x;
  if (perm != null) {
    source = transpose$1({inputs: {x: x}, backend: backend, attrs: {perm: perm}});
  }
  var innerAxis = getInnerMostAxes(1, x.shape.length)[0];
  var lastAxis = source.shape.length - 1;
  if (innerAxis !== lastAxis) {
    throw new Error("backend.cumsum in CPU expects an inner-most " + "axis=".concat(lastAxis, " but got axis=").concat(innerAxis));
  }

  var outDtype = upcastType(source.dtype, 'int32');
  var out = makeZerosTypedArray(sizeFromShape(source.shape), outDtype);
  var src = backend.data.get(source.dataId).values;
  var rowLen = source.shape[lastAxis];
  // Rows are walked forwards (+1) or, when reversed, backwards (-1).
  var step = reverse ? -1 : 1;
  for (var rowStart = 0; rowStart < src.length; rowStart += rowLen) {
    var pos = reverse ? rowStart + rowLen - 1 : rowStart;
    for (var k = 0; k < rowLen; k++, pos += step) {
      if (k === 0) {
        // Seed of the scan: the empty sum for exclusive mode.
        out[pos] = exclusive ? 0 : src[pos];
        continue;
      }
      var prev = pos - step;
      out[pos] = exclusive ? src[prev] + out[prev] : src[pos] + out[prev];
    }
  }

  var result = backend.makeTensorInfo(source.shape, outDtype, out);
  if (perm == null) {
    return result;
  }
  // Restore the caller's layout and release both intermediates.
  var restored = transpose$1({inputs: {x: result}, backend: backend, attrs: {perm: getUndoAxesPermutation(perm)}});
  backend.disposeIntermediateTensorInfo(result);
  backend.disposeIntermediateTensorInfo(source);
  return restored;
}
var cumsumConfig$1 = {
  kernelName: Cumsum,
  backendName: 'cpu',
  kernelFunc: cumsum$1
};
104635
104636 /**
104637 * @license
104638 * Copyright 2020 Google LLC. All Rights Reserved.
104639 * Licensed under the Apache License, Version 2.0 (the "License");
104640 * you may not use this file except in compliance with the License.
104641 * You may obtain a copy of the License at
104642 *
104643 * http://www.apache.org/licenses/LICENSE-2.0
104644 *
104645 * Unless required by applicable law or agreed to in writing, software
104646 * distributed under the License is distributed on an "AS IS" BASIS,
104647 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104648 * See the License for the specific language governing permissions and
104649 * limitations under the License.
104650 * =============================================================================
104651 */
/**
 * CPU kernel for DenseBincount.
 *
 * Rank-1 `x` is delegated to `bincountImpl` and produces a `[size]` tensor.
 * Rank-2 `x` is delegated to `bincountReduceImpl`, which is also given
 * `binaryOutput`. The result always has `weights.dtype`.
 *
 * @param {{inputs: {x: Object, weights: Object}, backend: Object,
 *     attrs: {size: number, binaryOutput: boolean}}} args
 * @returns {Object} TensorInfo holding the bin counts.
 * @throws {Error} If `x` is neither rank 1 nor rank 2.
 */
function denseBincount$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x,
    weights = inputs.weights;
  var size = attrs.size,
    binaryOutput = attrs.binaryOutput;
  var rank = x.shape.length;
  if (rank === 1) {
    // NOTE(review): binaryOutput is not forwarded on this path; confirm that
    // bincountImpl's contract makes this correct for binaryOutput=true.
    var xVals = backend.data.get(x.dataId).values;
    var weightsVals = backend.data.get(weights.dataId).values;
    var outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
    return backend.makeTensorInfo([size], weights.dtype, outVals);
  }
  if (rank === 2) {
    var xBuf = backend.bufferSync(x);
    var weightsBuf = backend.bufferSync(weights);
    var outBuf = bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput);
    return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
  }
  // Only ranks 1 and 2 are handled above; rank 0 lands here as well, so the
  // message states the real requirement (and keeps a space before the rank).
  throw new Error("Error in denseBincount: input must be rank 1 or 2, but got rank ".concat(rank, "."));
}
/** Kernel registration entry: the CPU backend runs DenseBincount via denseBincount$1. */
var denseBincountConfig$1 = {kernelName: DenseBincount, backendName: 'cpu', kernelFunc: denseBincount$1};
104678
104679 /**
104680 * @license
104681 * Copyright 2020 Google LLC. All Rights Reserved.
104682 * Licensed under the Apache License, Version 2.0 (the "License");
104683 * you may not use this file except in compliance with the License.
104684 * You may obtain a copy of the License at
104685 *
104686 * http://www.apache.org/licenses/LICENSE-2.0
104687 *
104688 * Unless required by applicable law or agreed to in writing, software
104689 * distributed under the License is distributed on an "AS IS" BASIS,
104690 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104691 * See the License for the specific language governing permissions and
104692 * limitations under the License.
104693 * =============================================================================
104694 */
/**
 * CPU kernel for DepthToSpace (NHWC only).
 *
 * Moves blocks of channel data into `blockSize x blockSize` spatial tiles:
 *   input  [batch, height, width, depth]
 *   output [batch, height * blockSize, width * blockSize,
 *           depth / (blockSize * blockSize)]
 * Output position (h, w, d) reads input channel
 * `((h % blockSize) * blockSize + (w % blockSize)) * outputDepth + d`
 * at input pixel (floor(h / blockSize), floor(w / blockSize)).
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {blockSize: number, dataFormat: string}}} args
 * @returns {Object} TensorInfo with the same dtype as `x`.
 * @throws {Error} (via assert$1) if `dataFormat` is not 'NHWC'.
 */
function depthToSpace$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x;
  var blockSize = attrs.blockSize,
    dataFormat = attrs.dataFormat;
  assert$1(dataFormat === 'NHWC', function () {
    return "Only NHWC dataFormat supported on CPU for depthToSpace. Got ".concat(dataFormat);
  });
  var batchSize = x.shape[0];
  var inputHeight = x.shape[1];
  var inputWidth = x.shape[2];
  var inputDepth = x.shape[3];
  var outputHeight = inputHeight * blockSize;
  var outputWidth = inputWidth * blockSize;
  var outputDepth = inputDepth / (blockSize * blockSize);
  var xValues = backend.data.get(x.dataId).values;
  // Allocate storage that matches x.dtype. A hard-coded Float32Array would
  // round int32 values whose magnitude exceeds 2^24, and its storage type
  // would not match the dtype the result is labelled with.
  var result = getArrayFromDType(x.dtype, batchSize * outputHeight * outputWidth * outputDepth);
  var outputIdx = 0;
  for (var b = 0; b < batchSize; ++b) {
    for (var h = 0; h < outputHeight; ++h) {
      var inH = Math.floor(h / blockSize);
      var offsetH = h % blockSize;
      for (var w = 0; w < outputWidth; ++w) {
        var inW = Math.floor(w / blockSize);
        var offsetW = w % blockSize;
        // First input channel feeding this tile position.
        var offsetD = (offsetH * blockSize + offsetW) * outputDepth;
        for (var d = 0; d < outputDepth; ++d) {
          var inD = d + offsetD;
          var inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
          result[outputIdx++] = xValues[inputIdx];
        }
      }
    }
  }
  return backend.makeTensorInfo([batchSize, outputHeight, outputWidth, outputDepth], x.dtype, result);
}
/** Kernel registration entry: the CPU backend runs DepthToSpace via depthToSpace$1. */
var depthToSpaceConfig$1 = {kernelName: DepthToSpace, backendName: 'cpu', kernelFunc: depthToSpace$1};
104738
104739 /**
104740 * @license
104741 * Copyright 2020 Google LLC. All Rights Reserved.
104742 * Licensed under the Apache License, Version 2.0 (the "License");
104743 * you may not use this file except in compliance with the License.
104744 * You may obtain a copy of the License at
104745 *
104746 * http://www.apache.org/licenses/LICENSE-2.0
104747 *
104748 * Unless required by applicable law or agreed to in writing, software
104749 * distributed under the License is distributed on an "AS IS" BASIS,
104750 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104751 * See the License for the specific language governing permissions and
104752 * limitations under the License.
104753 * =============================================================================
104754 */
/**
 * CPU kernel for DepthwiseConv2dNative.
 *
 * Each input channel c is convolved with its own `chMul` filter slices. The
 * product for multiplier m is accumulated into output channel
 * `c * chMul + m`, following the `outIdx`/`wIdx` walk below. Filter taps
 * whose input row or column falls outside the input (the padding region)
 * are skipped.
 *
 * @param {{inputs: {x: Object, filter: Object}, backend: Object,
 *     attrs: {strides: *, pad: *, dilations: *, dimRoundingMode: *}}} args
 * @returns {Object} TensorInfo with shape `convInfo.outShape` and x's dtype.
 */
function depthwiseConv2dNative$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var filter = args.inputs.filter;
  var strides = args.attrs.strides;
  var pad = args.attrs.pad;
  var dilations = args.attrs.dilations;
  var dimRoundingMode = args.attrs.dimRoundingMode;
  assertNotComplex$1([x, filter], 'depthwiseConv2DNative');
  var xStrides = computeStrides(x.shape);
  var fStrides = computeStrides(filter.shape);
  // Missing dilations mean no dilation in either spatial dimension.
  var $dilations = dilations == null ? [1, 1] : dilations;
  assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
    return 'Error in depthwiseConv2d: Either strides or dilations must be ' + "1. Got strides ".concat(strides, " and dilations '").concat($dilations, "'");
  });
  var info = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
  var fH = info.filterHeight;
  var fW = info.filterWidth;
  var dilH = info.dilationHeight;
  var dilW = info.dilationWidth;
  var padLeft = info.padInfo.left;
  var padTop = info.padInfo.top;
  var inChannels = info.inChannels;
  var chMul = info.outChannels / inChannels;
  var out = new TensorBuffer(info.outShape, x.dtype);
  var xVals = backend.data.get(x.dataId).values;
  var wVals = backend.data.get(filter.dataId).values;
  var outVals = out.values;
  var outStrides = out.strides;
  for (var b = 0; b < info.batchSize; ++b) {
    var xBatch = b * xStrides[0];
    var outBatch = b * outStrides[0];
    for (var oR = 0; oR < info.outHeight; ++oR) {
      var outRow = outBatch + oR * outStrides[1];
      var rowOrigin = oR * info.strideHeight - padTop;
      for (var fR = 0; fR < fH; ++fR) {
        var inR = rowOrigin + fR * dilH;
        if (inR >= 0 && inR < info.inHeight) {
          var fRowBase = fR * fStrides[0];
          var xRowBase = xBatch + inR * xStrides[1];
          for (var oC = 0; oC < info.outWidth; ++oC) {
            var outPix = outRow + oC * outStrides[2];
            var colOrigin = oC * info.strideWidth - padLeft;
            for (var fC = 0; fC < fW; ++fC) {
              var inC = colOrigin + fC * dilW;
              if (inC >= 0 && inC < info.inWidth) {
                var xPix = xRowBase + inC * inChannels;
                var outIdx = outPix;
                var wIdx = fRowBase + fC * fStrides[1];
                for (var c = 0; c < inChannels; ++c) {
                  var xv = xVals[xPix + c];
                  for (var m = 0; m < chMul; ++m) {
                    outVals[outIdx + m] += xv * wVals[wIdx + m];
                  }
                  outIdx += chMul;
                  wIdx += chMul;
                }
              }
            }
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(out.shape, out.dtype, out.values);
}
var depthwiseConv2dNativeConfig$1 = {
  kernelName: DepthwiseConv2dNative,
  backendName: 'cpu',
  kernelFunc: depthwiseConv2dNative$1
};
104833
104834 /**
104835 * @license
104836 * Copyright 2020 Google LLC. All Rights Reserved.
104837 * Licensed under the Apache License, Version 2.0 (the "License");
104838 * you may not use this file except in compliance with the License.
104839 * You may obtain a copy of the License at
104840 *
104841 * http://www.apache.org/licenses/LICENSE-2.0
104842 *
104843 * Unless required by applicable law or agreed to in writing, software
104844 * distributed under the License is distributed on an "AS IS" BASIS,
104845 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104846 * See the License for the specific language governing permissions and
104847 * limitations under the License.
104848 * =============================================================================
104849 */
/**
 * CPU kernel for DepthwiseConv2dNativeBackpropFilter: gradient of a
 * depthwise convolution with respect to its filter.
 *
 * For each filter tap (wR, wC) and output channel d2 the result is
 *   sum over b, yR, yC of x[b, xR, xC, d1] * dy[b, yR, yC, d2],
 * where d1 = trunc(d2 / chMul) and xR = wR + yR * strideHeight - topPad
 * (similarly xC). The result is stored at [wR, wC, d1, d2 % chMul].
 * The yR/yC ranges are clipped so that xR/xC stay inside the input.
 *
 * No dilation term appears in the index arithmetic. NOTE(review): this
 * presumably relies on callers passing unit dilations; confirm.
 *
 * @returns {Object} float32 TensorInfo with shape `convInfo.filterShape`.
 */
function depthwiseConv2dNativeBackpropFilter$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var dy = args.inputs.dy;
  var attrs = args.attrs;
  assertNotComplex$1([x, dy], 'depthwiseConv2dNativeBackpropFilter');
  var info = computeConv2DInfo(x.shape, attrs.filterShape, attrs.strides, attrs.dilations, attrs.pad, attrs.dimRoundingMode, true /* depthwise */);
  var sH = info.strideHeight;
  var sW = info.strideWidth;
  var fH = info.filterHeight;
  var fW = info.filterWidth;
  var grad = new TensorBuffer(info.filterShape, 'float32');
  var padLeft = info.padInfo.left;
  var padTop = info.padInfo.top;
  var chMul = info.outChannels / info.inChannels;
  var xBuf = new TensorBuffer(x.shape, x.dtype, backend.data.get(x.dataId).values);
  var dyBuf = new TensorBuffer(dy.shape, dy.dtype, backend.data.get(dy.dataId).values);
  for (var wR = 0; wR < fH; ++wR) {
    // Output rows whose receptive field places tap wR inside the input.
    var rFirst = Math.max(0, Math.ceil((padTop - wR) / sH));
    var rEnd = Math.min(info.outHeight, (info.inHeight + padTop - wR) / sH);
    for (var wC = 0; wC < fW; ++wC) {
      var cFirst = Math.max(0, Math.ceil((padLeft - wC) / sW));
      var cEnd = Math.min(info.outWidth, (info.inWidth + padLeft - wC) / sW);
      for (var d2 = 0; d2 < info.outChannels; ++d2) {
        var d1 = Math.trunc(d2 / chMul);
        var dm = d2 % chMul;
        var acc = 0;
        for (var b = 0; b < info.batchSize; ++b) {
          for (var yR = rFirst; yR < rEnd; ++yR) {
            var xR = wR + yR * sH - padTop;
            for (var yC = cFirst; yC < cEnd; ++yC) {
              var xC = wC + yC * sW - padLeft;
              acc += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
            }
          }
        }
        grad.set(acc, wR, wC, d1, dm);
      }
    }
  }
  return backend.makeTensorInfo(grad.shape, grad.dtype, grad.values);
}
var depthwiseConv2dNativeBackpropFilterConfig$1 = {
  kernelName: DepthwiseConv2dNativeBackpropFilter,
  backendName: 'cpu',
  kernelFunc: depthwiseConv2dNativeBackpropFilter$1
};
104905
/**
 * CPU kernel for DepthwiseConv2dNativeBackpropInput: gradient of a
 * depthwise convolution with respect to its input.
 *
 * The gradient is computed as a convolution of `dy` with the spatially
 * flipped filter: tap (wR, wC) reads filter element
 * [filterHeight - 1 - wR, filterWidth - 1 - wC, d1, dm]. Padding is
 * re-expressed from that flipped point of view (`filterSize - 1 - pad`).
 * Every dx element is written exactly once.
 *
 * @returns {Object} float32 TensorInfo with shape `convInfo.inShape`.
 */
function depthwiseConv2dNativeBackpropInput$1(args) {
  var backend = args.backend;
  var dy = args.inputs.dy;
  var filter = args.inputs.filter;
  var attrs = args.attrs;
  assertNotComplex$1([dy, filter], 'depthwiseConv2DNativeBackpropInput');
  var dyStrides = computeStrides(dy.shape);
  var fltStrides = computeStrides(filter.shape);
  var info = computeConv2DInfo(attrs.inputShape, filter.shape, attrs.strides, attrs.dilations, attrs.pad, attrs.dimRoundingMode, true /* depthwise */);
  var dx = new TensorBuffer(info.inShape, 'float32');
  var dxVals = dx.values;
  var dxStrideB = dx.strides[0];
  var dxStrideR = dx.strides[1];
  var dxStrideC = dx.strides[2];
  var dyVals = backend.data.get(dy.dataId).values;
  var dyStrideB = dyStrides[0];
  var dyStrideR = dyStrides[1];
  var dyStrideC = dyStrides[2];
  var fltVals = backend.data.get(filter.dataId).values;
  var fltStrideR = fltStrides[0];
  var fltStrideC = fltStrides[1];
  var fltStrideD = fltStrides[2];
  var fH = info.filterHeight;
  var fW = info.filterWidth;
  var inChannels = info.inChannels;
  var outChannels = info.outChannels;
  var sH = info.strideHeight;
  var sW = info.strideWidth;
  var padTop = fH - 1 - info.padInfo.top;
  var padLeft = fW - 1 - info.padInfo.left;
  var chMul = outChannels / inChannels;
  for (var b = 0; b < info.batchSize; ++b) {
    for (var d1 = 0; d1 < inChannels; ++d1) {
      for (var xR = 0; xR < info.inHeight; ++xR) {
        var rOrigin = xR - padTop;
        var yRFirst = Math.max(0, Math.ceil(rOrigin / sH));
        var yREnd = Math.min(info.outHeight, (fH + rOrigin) / sH);
        for (var xC = 0; xC < info.inWidth; ++xC) {
          var cOrigin = xC - padLeft;
          var yCFirst = Math.max(0, Math.ceil(cOrigin / sW));
          var yCEnd = Math.min(info.outWidth, (fW + cOrigin) / sW);
          var acc = 0;
          for (var yR = yRFirst; yR < yREnd; ++yR) {
            var wR = yR * sH - rOrigin;
            for (var yC = yCFirst; yC < yCEnd; ++yC) {
              var wC = yC * sW - cOrigin;
              var dyBase = dyStrideB * b + dyStrideR * yR + dyStrideC * yC;
              var fltBase = fltStrideR * (fH - 1 - wR) + fltStrideC * (fW - 1 - wC) + fltStrideD * d1;
              for (var dm = 0; dm < chMul; ++dm) {
                // Output channel of (d1, dm) is d1 * chMul + dm.
                acc += dyVals[dyBase + (d1 * chMul + dm)] * fltVals[fltBase + dm];
              }
            }
          }
          dxVals[dxStrideB * b + dxStrideR * xR + dxStrideC * xC + d1] = acc;
        }
      }
    }
  }
  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var depthwiseConv2dNativeBackpropInputConfig$1 = {
  kernelName: DepthwiseConv2dNativeBackpropInput,
  backendName: 'cpu',
  kernelFunc: depthwiseConv2dNativeBackpropInput$1
};
104988
/**
 * CPU kernel for Diag.
 *
 * With n = number of elements of `x`, this builds an n x n buffer (zeros
 * from `buffer`) whose diagonal holds x's values in flat order. The result
 * is reshaped to `[...x.shape, ...x.shape]`.
 *
 * @param {{inputs: {x: Object}, backend: Object}} args
 * @returns {Object} TensorInfo with x's dtype.
 */
function diag$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var n = sizeFromShape(x.shape);
  var src = backend.data.get(x.dataId).values;
  var outBuf = buffer([n, n], x.dtype);
  var dst = outBuf.values;
  // Consecutive diagonal cells of a row-major n x n matrix are n + 1 apart.
  for (var i = 0, pos = 0; i < src.length; i++, pos += n + 1) {
    dst[pos] = src[i];
  }
  var outShape = x.shape.concat(x.shape);
  return backend.makeTensorInfo(outShape, outBuf.dtype, outBuf.values);
}
var diagConfig$1 = {
  kernelName: Diag,
  backendName: 'cpu',
  kernelFunc: diag$1
};
105008
105009 /**
105010 * @license
105011 * Copyright 2020 Google LLC. All Rights Reserved.
105012 * Licensed under the Apache License, Version 2.0 (the "License");
105013 * you may not use this file except in compliance with the License.
105014 * You may obtain a copy of the License at
105015 *
105016 * http://www.apache.org/licenses/LICENSE-2.0
105017 *
105018 * Unless required by applicable law or agreed to in writing, software
105019 * distributed under the License is distributed on an "AS IS" BASIS,
105020 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105021 * See the License for the specific language governing permissions and
105022 * limitations under the License.
105023 * =============================================================================
105024 */
/**
 * CPU kernel registration for Dilation2D (grayscale morphological dilation,
 * NHWC layout). For each output position and channel d:
 *
 *   out[b, hOut, wOut, d] = max over filter taps (h, w) inside the input of
 *       x[b, hOut * strideHeight - padTop + h * dilationHeight,
 *            wOut * strideWidth - padLeft + w * dilationWidth, d]
 *       + filter[h, w, d]
 *
 * An output whose window has no in-bounds tap is -Infinity.
 * This follows the TF C++ implementation:
 * https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
 */
var dilation2DConfig$1 = {
  kernelName: Dilation2D,
  backendName: 'cpu',
  kernelFunc: function kernelFunc(_ref) {
    var inputs = _ref.inputs,
      backend = _ref.backend,
      attrs = _ref.attrs;
    var x = inputs.x,
      filter = inputs.filter;
    var strides = attrs.strides,
      pad = attrs.pad,
      dilations = attrs.dilations;
    var cpuBackend = backend;
    var xVals = cpuBackend.data.get(x.dataId).values;
    var xRank = x.shape.length;
    var filterVals = cpuBackend.data.get(filter.dataId).values;
    var filterRank = filter.shape.length;
    var _backend_util$compute = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations),
      batchSize = _backend_util$compute.batchSize,
      inHeight = _backend_util$compute.inHeight,
      inWidth = _backend_util$compute.inWidth,
      inChannels = _backend_util$compute.inChannels,
      outHeight = _backend_util$compute.outHeight,
      outWidth = _backend_util$compute.outWidth,
      padInfo = _backend_util$compute.padInfo,
      strideHeight = _backend_util$compute.strideHeight,
      strideWidth = _backend_util$compute.strideWidth,
      filterHeight = _backend_util$compute.filterHeight,
      filterWidth = _backend_util$compute.filterWidth,
      dilationHeight = _backend_util$compute.dilationHeight,
      dilationWidth = _backend_util$compute.dilationWidth,
      outShape = _backend_util$compute.outShape;
    var outSize = sizeFromShape(outShape);
    var outRank = outShape.length;
    var outputVals = getArrayFromDType(x.dtype, outSize);
    // These strides never change inside the loops, so compute them once
    // instead of once per filter tap.
    var xStrides = computeStrides(x.shape);
    var filterStrides = computeStrides(filter.shape);
    var outStrides = computeStrides(outShape);
    // Filter taps are spaced `dilationHeight` / `dilationWidth` apart in the
    // input (hIn = hBeg + h * dilationHeight, likewise for w).
    for (var b = 0; b < batchSize; ++b) {
      for (var hOut = 0; hOut < outHeight; ++hOut) {
        var hBeg = hOut * strideHeight - padInfo.top;
        for (var wOut = 0; wOut < outWidth; ++wOut) {
          var wBeg = wOut * strideWidth - padInfo.left;
          for (var d = 0; d < inChannels; ++d) {
            // -Infinity is the identity of max, so every in-bounds sum is a
            // candidate. Seeding with Number.MIN_SAFE_INTEGER would report a
            // wrong maximum whenever all sums were below -(2^53 - 1).
            var curVal = -Infinity;
            for (var h = 0; h < filterHeight; ++h) {
              var hIn = hBeg + h * dilationHeight;
              if (hIn >= 0 && hIn < inHeight) {
                for (var w = 0; w < filterWidth; ++w) {
                  var wIn = wBeg + w * dilationWidth;
                  if (wIn >= 0 && wIn < inWidth) {
                    var xIndex = locToIndex([b, hIn, wIn, d], xRank, xStrides);
                    var filterIndex = locToIndex([h, w, d], filterRank, filterStrides);
                    var val = xVals[xIndex] + filterVals[filterIndex];
                    if (val > curVal) {
                      curVal = val;
                    }
                  }
                }
              }
            }
            var outputIndex = locToIndex([b, hOut, wOut, d], outRank, outStrides);
            outputVals[outputIndex] = curVal;
          }
        }
      }
    }
    var dataId = cpuBackend.write(toTypedArray(outputVals, x.dtype), outShape, x.dtype);
    return {
      dataId: dataId,
      shape: outShape,
      dtype: x.dtype
    };
  }
};
105101
105102 /**
105103 * @license
105104 * Copyright 2020 Google LLC. All Rights Reserved.
105105 * Licensed under the Apache License, Version 2.0 (the "License");
105106 * you may not use this file except in compliance with the License.
105107 * You may obtain a copy of the License at
105108 *
105109 * http://www.apache.org/licenses/LICENSE-2.0
105110 *
105111 * Unless required by applicable law or agreed to in writing, software
105112 * distributed under the License is distributed on an "AS IS" BASIS,
105113 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105114 * See the License for the specific language governing permissions and
105115 * limitations under the License.
105116 * =============================================================================
105117 */
/**
 * CPU kernel registration for Dilation2DBackpropFilter (NHWC layout).
 *
 * For every output position and channel, the whole upstream gradient
 * dy[b, hOut, wOut, d] is added to the single filter tap (h, w) that
 * produced the forward maximum of x + filter over in-bounds taps. The
 * result has the filter's shape and dtype.
 *
 * This follows the TF C++ implementation:
 * https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
 */
var dilation2DBackpropFilterConfig = {
  kernelName: Dilation2DBackpropFilter,
  backendName: 'cpu',
  kernelFunc: function kernelFunc(_ref) {
    var inputs = _ref.inputs,
      backend = _ref.backend,
      attrs = _ref.attrs;
    var x = inputs.x,
      filter = inputs.filter,
      dy = inputs.dy;
    var strides = attrs.strides,
      pad = attrs.pad,
      dilations = attrs.dilations;
    var cpuBackend = backend;
    var $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
    var $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
    var _backend_util$compute = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations),
      batchSize = _backend_util$compute.batchSize,
      inHeight = _backend_util$compute.inHeight,
      inWidth = _backend_util$compute.inWidth,
      inChannels = _backend_util$compute.inChannels,
      outHeight = _backend_util$compute.outHeight,
      outWidth = _backend_util$compute.outWidth,
      padInfo = _backend_util$compute.padInfo,
      strideHeight = _backend_util$compute.strideHeight,
      strideWidth = _backend_util$compute.strideWidth,
      filterHeight = _backend_util$compute.filterHeight,
      filterWidth = _backend_util$compute.filterWidth,
      dilationHeight = _backend_util$compute.dilationHeight,
      dilationWidth = _backend_util$compute.dilationWidth,
      outShape = _backend_util$compute.outShape;
    assert$1(dy.rank === outShape.length, function () {
      return "Error in ".concat(Dilation2DBackpropFilter, ", dy ") + "must have the same rank as output ".concat(outShape.length, ", but got ") + "".concat(dy.rank);
    });
    var $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
    // The filter gradient has the same dimensions as the filter:
    // [filterHeight, filterWidth, depth]
    var gradients = makeZerosNestedTypedArray(filter.shape, filter.dtype);
    // Ties: the strict `>` below keeps the FIRST maximizing (h, w) in
    // row-major order. NOTE(review): the TF reference comment describes this
    // as the "last branch"; confirm the intended tie-breaking.
    for (var b = 0; b < batchSize; ++b) {
      for (var hOut = 0; hOut < outHeight; ++hOut) {
        var hBeg = hOut * strideHeight - padInfo.top;
        for (var wOut = 0; wOut < outWidth; ++wOut) {
          var wBeg = wOut * strideWidth - padInfo.left;
          for (var d = 0; d < inChannels; ++d) {
            // -Infinity is the identity of max. Seeding with
            // Number.MIN_SAFE_INTEGER would pick the wrong argmax whenever
            // every sum was below -(2^53 - 1). A window with no in-bounds
            // tap still routes its gradient to tap (0, 0).
            var curVal = -Infinity;
            var hMax = 0;
            var wMax = 0;
            for (var h = 0; h < filterHeight; ++h) {
              var hIn = hBeg + h * dilationHeight;
              if (hIn >= 0 && hIn < inHeight) {
                for (var w = 0; w < filterWidth; ++w) {
                  var wIn = wBeg + w * dilationWidth;
                  if (wIn >= 0 && wIn < inWidth) {
                    var val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                    if (val > curVal) {
                      curVal = val;
                      hMax = h;
                      wMax = w;
                    }
                  }
                }
              }
            }
            gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
          }
        }
      }
    }
    // Flatten with the same dtype the gradients were allocated with and
    // written as (the filter's), rather than x's dtype.
    var dataId = cpuBackend.write(toTypedArray(gradients, filter.dtype), filter.shape, filter.dtype);
    return {
      dataId: dataId,
      shape: filter.shape,
      dtype: filter.dtype
    };
  }
};
105199
105200 /**
105201 * @license
105202 * Copyright 2020 Google LLC. All Rights Reserved.
105203 * Licensed under the Apache License, Version 2.0 (the "License");
105204 * you may not use this file except in compliance with the License.
105205 * You may obtain a copy of the License at
105206 *
105207 * http://www.apache.org/licenses/LICENSE-2.0
105208 *
105209 * Unless required by applicable law or agreed to in writing, software
105210 * distributed under the License is distributed on an "AS IS" BASIS,
105211 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105212 * See the License for the specific language governing permissions and
105213 * limitations under the License.
105214 * =============================================================================
105215 */
var dilation2DBackpropInputConfig = {
  kernelName: Dilation2DBackpropInput,
  backendName: 'cpu',
  /**
   * Gradient of Dilation2D with respect to its input `x` (CPU backend).
   *
   * For every output element, the forward max over the dilated filter window
   * is recomputed. dy for that element is then added to the input pixel that
   * produced the max. The result has the shape and dtype of `x`.
   *
   * @param {{inputs: {x, filter, dy}, backend, attrs: {strides, pad, dilations}}} _ref
   * @returns {{dataId, shape, dtype}} the gradient with respect to `x`.
   */
  kernelFunc: function kernelFunc(_ref) {
    var inputs = _ref.inputs,
      backend = _ref.backend,
      attrs = _ref.attrs;
    var x = inputs.x,
      filter = inputs.filter,
      dy = inputs.dy;
    var strides = attrs.strides,
      pad = attrs.pad,
      dilations = attrs.dilations;
    var cpuBackend = backend;
    var $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
    var $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
    var _backend_util$compute = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations),
      batchSize = _backend_util$compute.batchSize,
      inHeight = _backend_util$compute.inHeight,
      inWidth = _backend_util$compute.inWidth,
      inChannels = _backend_util$compute.inChannels,
      outHeight = _backend_util$compute.outHeight,
      outWidth = _backend_util$compute.outWidth,
      padInfo = _backend_util$compute.padInfo,
      strideHeight = _backend_util$compute.strideHeight,
      strideWidth = _backend_util$compute.strideWidth,
      filterHeight = _backend_util$compute.filterHeight,
      filterWidth = _backend_util$compute.filterWidth,
      dilationHeight = _backend_util$compute.dilationHeight,
      dilationWidth = _backend_util$compute.dilationWidth,
      outShape = _backend_util$compute.outShape;
    // Read the rank from dy.shape instead of dy.rank. The check then also works
    // when dy is a bare tensor-info object, which presumably has no `rank`.
    var dyRank = dy.shape.length;
    assert$1(dyRank === outShape.length, function () {
      return "Error in ".concat(Dilation2DBackpropInput, ", dy ") + "must have the same rank as output ".concat(outShape.length, ", but got ") + "".concat(dyRank);
    });
    var $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
    // The computed gradients have the same dimensions as the input:
    // [batch, inHeight, inWidth, inChannels]
    var gradients = makeZerosNestedTypedArray(x.shape, x.dtype);
    // For each output element, find the filter tap (h, w) that maximises
    // x + filter, and route dy to the matching input pixel. The comparison
    // below is strict, so among tied taps the FIRST one scanned (the smallest
    // `h * filterWidth + w`) receives the gradient.
    // The loop structure follows the TF C++ implementation:
    // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
    for (var b = 0; b < batchSize; ++b) {
      for (var hOut = 0; hOut < outHeight; ++hOut) {
        var hBeg = hOut * strideHeight - padInfo.top;
        for (var wOut = 0; wOut < outWidth; ++wOut) {
          var wBeg = wOut * strideWidth - padInfo.left;
          for (var d = 0; d < inChannels; ++d) {
            // -Infinity lets any finite candidate become the max, however
            // negative. Number.MIN_SAFE_INTEGER would ignore values below -2^53.
            var curVal = -Infinity;
            var hInMax = hBeg < 0 ? 0 : hBeg;
            var wInMax = wBeg < 0 ? 0 : wBeg;
            for (var h = 0; h < filterHeight; ++h) {
              var hIn = hBeg + h * dilationHeight;
              if (hIn >= 0 && hIn < inHeight) {
                for (var w = 0; w < filterWidth; ++w) {
                  var wIn = wBeg + w * dilationWidth;
                  if (wIn >= 0 && wIn < inWidth) {
                    var val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                    if (val > curVal) {
                      curVal = val;
                      hInMax = hIn;
                      wInMax = wIn;
                    }
                  }
                }
              }
            }
            gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d];
          }
        }
      }
    }
    var dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), x.shape, x.dtype);
    return {
      dataId: dataId,
      shape: x.shape,
      dtype: x.dtype
    };
  }
};
105297
/**
 * CPU kernel for Draw. Paints an image tensor onto `attrs.canvas` through a
 * '2d' rendering context and returns the input tensor.
 *
 * Height and width come from the first two dimensions. Depth comes from the
 * third dimension, or is 1 for rank-2 images.
 * - float32 values must lie in [0, 1] and are scaled by 255.
 * - int32 values must lie in [0, 255].
 * - A single channel is replicated to R, G and B. Otherwise channel d is
 *   written to RGBA component d.
 * - The alpha byte defaults to 255 * imageOptions.alpha (255 when no alpha is
 *   given), unless the image supplies a fourth channel.
 *
 * @param {{inputs: {image}, backend, attrs: {canvas, options}}} args
 * @returns the `image` input.
 * @throws Error for a non-'2d' context type, an unavailable context, or
 *     out-of-range float32/int32 pixel values.
 */
function draw(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var image = inputs.image;
  var canvas = attrs.canvas,
    options = attrs.options;
  var _ref = options || {},
    contextOptions = _ref.contextOptions,
    imageOptions = _ref.imageOptions;
  // Fall back to 1 only when alpha is missing. `alpha || 1` would also
  // replace an explicit 0, making a requested fully transparent image opaque.
  var alpha = imageOptions != null && imageOptions.alpha != null ? imageOptions.alpha : 1;
  var contextType = (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextType) || '2d';
  if (contextType !== '2d') {
    throw new Error("Context type ".concat(contextType, " is not supported by the CPU backend."));
  }
  var ctx = canvas.getContext(contextType, (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextAttributes) || {});
  if (ctx == null) {
    throw new Error("Could not get the context with ".concat(contextType, " type."));
  }
  var _image$shape$slice = image.shape.slice(0, 2),
    _image$shape$slice2 = _slicedToArray(_image$shape$slice, 2),
    height = _image$shape$slice2[0],
    width = _image$shape$slice2[1];
  var depth = image.shape.length === 2 ? 1 : image.shape[2];
  var data = backend.data.get(image.dataId).values;
  var multiplier = image.dtype === 'float32' ? 255 : 1;
  var bytes = new Uint8ClampedArray(width * height * 4);
  for (var i = 0; i < height * width; ++i) {
    var rgba = [0, 0, 0, 255 * alpha];
    for (var d = 0; d < depth; d++) {
      var value = data[i * depth + d];
      if (image.dtype === 'float32') {
        if (value < 0 || value > 1) {
          throw new Error("Tensor values for a float32 Tensor must be in the " + "range [0 - 1] but encountered ".concat(value, "."));
        }
      } else if (image.dtype === 'int32') {
        if (value < 0 || value > 255) {
          throw new Error("Tensor values for a int32 Tensor must be in the " + "range [0 - 255] but encountered ".concat(value, "."));
        }
      }
      if (depth === 1) {
        // Grayscale: replicate the single channel to R, G and B.
        rgba[0] = value * multiplier;
        rgba[1] = value * multiplier;
        rgba[2] = value * multiplier;
      } else {
        rgba[d] = value * multiplier;
      }
    }
    var j = i * 4;
    bytes[j + 0] = Math.round(rgba[0]);
    bytes[j + 1] = Math.round(rgba[1]);
    bytes[j + 2] = Math.round(rgba[2]);
    bytes[j + 3] = Math.round(rgba[3]);
  }
  canvas.width = width;
  canvas.height = height;
  var imageData = new ImageData(bytes, width, height);
  ctx.putImageData(imageData, 0, 0);
  return image;
}
var drawConfig = {
  kernelName: Draw,
  backendName: 'cpu',
  kernelFunc: draw
};
105363
/**
 * CPU kernel for Sum.
 *
 * Reduces `inputs.x` over `attrs.axis`. Bool inputs are cast to int32 first.
 * The reduced axes are moved to the innermost positions (via transpose when
 * needed), so each output element sums one contiguous run of values. With
 * `attrs.keepDims`, the reduced axes are kept as size-1 dimensions.
 *
 * @param {{inputs: {x}, backend, attrs: {axis, keepDims}}} args
 * @returns tensor info whose dtype is upcastType(<working dtype>, 'int32').
 */
function sum$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var axis = args.attrs.axis;
  var keepDims = args.attrs.keepDims;
  assertNotComplex$1(x, 'sum');
  // Always work on a fresh tensor (an int32 cast for bool, otherwise an
  // identity copy) so it can be disposed unconditionally at the end.
  var source = x.dtype === 'bool' ? cast$1({
    inputs: { x: x },
    backend: backend,
    attrs: { dtype: 'int32' }
  }) : identity$1({
    inputs: { x: x },
    backend: backend
  });
  var rank = source.shape.length;
  var axes = parseAxisParam(axis, source.shape);
  var perm = getAxesPermutation(axes, rank);
  var needsTranspose = perm != null;
  var reduced = needsTranspose ? transpose$1({
    inputs: { x: source },
    backend: backend,
    attrs: { perm: perm }
  }) : source;
  var innerAxes = needsTranspose ? getInnerMostAxes(axes.length, rank) : axes;
  assertAxesAreInnerMostDims('sum', innerAxes, reduced.shape.length);
  var shapes = computeOutAndReduceShapes(reduced.shape, innerAxes);
  var outShape = shapes[0];
  var reduceShape = shapes[1];
  var result = zeros(backend, outShape, upcastType(reduced.dtype, 'int32'));
  var chunk = sizeFromShape(reduceShape);
  var out = backend.data.get(result.dataId).values;
  var src = backend.data.get(reduced.dataId).values;
  for (var o = 0; o < out.length; ++o) {
    var start = o * chunk;
    var end = start + chunk;
    var acc = 0;
    for (var k = start; k < end; ++k) {
      acc += src[k];
    }
    out[o] = acc;
  }
  if (keepDims) {
    var flat = result;
    result = reshape$1({
      inputs: { x: flat },
      backend: backend,
      attrs: { shape: expandShapeToKeepDim(flat.shape, axes) }
    });
    backend.disposeIntermediateTensorInfo(flat);
  }
  backend.disposeIntermediateTensorInfo(source);
  if (needsTranspose) {
    backend.disposeIntermediateTensorInfo(reduced);
  }
  return result;
}
var sumConfig$1 = {
  kernelName: Sum,
  backendName: 'cpu',
  kernelFunc: sum$1
};
105451
/**
 * CPU kernel for Einsum.
 *
 * Decodes `attrs.equation` and processes the operands in the order given by
 * the compute path. For each step:
 * - every listed operand is transposed and reshaped into the common
 *   broadcastable layout;
 * - the operand is multiplied into the running product;
 * - the dimension named by path[i] (if >= 0) is summed out before the next
 *   step.
 * All intermediate tensors except the returned one are disposed.
 */
function einsum$1(args) {
  var backend = args.backend;
  var tensors = args.inputs;
  var equation = args.attrs.equation;
  var decoded = decodeEinsumEquation(equation, tensors.length);
  var allDims = decoded.allDims;
  var idDims = decoded.idDims;
  checkEinsumDimSizes(allDims.length, idDims, tensors);
  var computePath = getEinsumComputePath(decoded.summedDims, idDims);
  var path = computePath.path;
  var steps = computePath.steps;
  var stepCount = steps.length;
  var out = null;
  var numDimsRemaining = allDims.length;
  var intermediates = [];
  for (var i = 0; i < stepCount; ++i) {
    var stepTerms = steps[i];
    for (var t = 0; t < stepTerms.length; ++t) {
      var idTerm = stepTerms[t];
      var permInfo = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
      var perm = permInfo.permutationIndices;
      var dimsToExpand = permInfo.expandDims;
      var operand = tensors[idTerm];
      if (!isIdentityPermutation(perm)) {
        operand = transpose$1({
          inputs: { x: operand },
          backend: backend,
          attrs: { perm: perm }
        });
        intermediates.push(operand);
      }
      // Insert size-1 axes so the operand broadcasts against the product.
      var targetShape = operand.shape.slice();
      for (var k = 0; k < dimsToExpand.length; ++k) {
        targetShape.splice(dimsToExpand[k], 0, 1);
      }
      if (!arraysEqual(operand.shape, targetShape)) {
        operand = reshape$1({
          inputs: { x: operand },
          backend: backend,
          attrs: { shape: targetShape }
        });
        intermediates.push(operand);
      }
      if (out === null) {
        out = operand;
      } else {
        out = multiply$1({
          inputs: { a: operand, b: out },
          backend: backend
        });
        intermediates.push(out);
      }
    }
    var isLastStep = i >= stepCount - 1;
    if (!isLastStep) {
      if (path[i] >= 0) {
        out = sum$1({
          inputs: { x: out },
          backend: backend,
          attrs: {
            axis: path[i] - (allDims.length - numDimsRemaining),
            keepDims: false
          }
        });
        intermediates.push(out);
      }
      numDimsRemaining--;
    }
  }
  // Release every intermediate except the tensor being returned.
  for (var n = 0; n < intermediates.length; n++) {
    if (intermediates[n] !== out) {
      backend.disposeIntermediateTensorInfo(intermediates[n]);
    }
  }
  return out;
}
var einsumConfig$1 = {
  kernelName: Einsum,
  backendName: 'cpu',
  kernelFunc: einsum$1
};
105561
105562 /**
105563 * @license
105564 * Copyright 2020 Google LLC. All Rights Reserved.
105565 * Licensed under the Apache License, Version 2.0 (the "License");
105566 * you may not use this file except in compliance with the License.
105567 * You may obtain a copy of the License at
105568 *
105569 * http://www.apache.org/licenses/LICENSE-2.0
105570 *
105571 * Unless required by applicable law or agreed to in writing, software
105572 * distributed under the License is distributed on an "AS IS" BASIS,
105573 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105574 * See the License for the specific language governing permissions and
105575 * limitations under the License.
105576 * =============================================================================
105577 */
/**
 * CPU kernel for EluGrad.
 *
 * Takes the forward output `y` and the upstream gradient `dy`. Returns a
 * float32 tensor shaped like `y`, equal to dy where y >= 0 and to
 * dy * (y + 1) elsewhere. For ELU with alpha = 1, y + 1 equals exp(x) on the
 * negative branch.
 */
function eluGrad$1(args) {
  var backend = args.backend;
  var dy = args.inputs.dy;
  var y = args.inputs.y;
  assertNotComplex$1([dy, y], 'eluGrad');
  var grad = new Float32Array(sizeFromShape(y.shape));
  var yVals = backend.data.get(y.dataId).values;
  var dyVals = backend.data.get(dy.dataId).values;
  for (var idx = 0; idx < yVals.length; ++idx) {
    var yi = yVals[idx];
    grad[idx] = yi >= 0 ? dyVals[idx] : dyVals[idx] * (yi + 1);
  }
  return backend.makeTensorInfo(y.shape, 'float32', grad);
}
var eluGradConfig$1 = {
  kernelName: EluGrad,
  backendName: 'cpu',
  kernelFunc: eluGrad$1
};
105602
105603 /**
105604 * @license
105605 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
105607 * you may not use this file except in compliance with the License.
105608 * You may obtain a copy of the License at
105609 *
105610 * http://www.apache.org/licenses/LICENSE-2.0
105611 *
105612 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
105614 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105615 * See the License for the specific language governing permissions and
105616 * limitations under the License.
105617 * =============================================================================
105618 */
// Coefficients of the rational erf approximation below. Their numeric values
// come from the ERF_P and ERF_A1..ERF_A5 constants defined elsewhere.
var p = ERF_P;
var a1 = ERF_A1;
var a2 = ERF_A2;
var a3 = ERF_A3;
var a4 = ERF_A4;
var a5 = ERF_A5;
/**
 * Elementwise erf, evaluated as
 *   sign(x) * (1 - (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) * exp(-x^2)),
 * with t = 1 / (1 + p * |x|). This appears to be the Abramowitz & Stegun
 * 7.1.26 approximation.
 */
var erf$1 = unaryKernelFunc$1(Erf, function (xi) {
  var magnitude = Math.abs(xi);
  var t = 1.0 / (1.0 + p * magnitude);
  // Horner evaluation of a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5.
  var poly = t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5))));
  var tail = poly * Math.exp(-magnitude * magnitude);
  return Math.sign(xi) * (1.0 - tail);
});
var erfConfig$1 = {
  kernelName: Erf,
  backendName: 'cpu',
  kernelFunc: erf$1
};
105636
105637 /**
105638 * @license
105639 * Copyright 2020 Google LLC. All Rights Reserved.
105640 * Licensed under the Apache License, Version 2.0 (the "License");
105641 * you may not use this file except in compliance with the License.
105642 * You may obtain a copy of the License at
105643 *
105644 * http://www.apache.org/licenses/LICENSE-2.0
105645 *
105646 * Unless required by applicable law or agreed to in writing, software
105647 * distributed under the License is distributed on an "AS IS" BASIS,
105648 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105649 * See the License for the specific language governing permissions and
105650 * limitations under the License.
105651 * =============================================================================
105652 */
/**
 * CPU kernel for ExpandDims. Inserts a size-1 axis into `inputs.input` at
 * position `attrs.dim`, implemented as a reshape.
 *
 * `dim` must lie in [-(rank + 1), rank]. Negative values count from the end,
 * so -1 appends a trailing axis.
 *
 * @throws Error (via assert$1) when `dim` is outside that interval.
 */
function expandDims$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var input = inputs.input;
  var dim = attrs.dim;
  var inputRank = input.shape.length;
  // Check both ends of the interval named in the error message. Without the
  // upper bound, Array#splice would silently clamp an out-of-range positive
  // `dim` to "append at the end".
  assert$1(-(inputRank + 1) <= dim && dim <= inputRank, function () {
    return "Axis must be in the interval [".concat(-(inputRank + 1), ", ").concat(inputRank, "]");
  });
  // Map negative axes (counted from the tail) to their positive position.
  var $dim = dim < 0 ? inputRank + dim + 1 : dim;
  var newShape = input.shape.slice();
  newShape.splice($dim, 0, 1);
  return reshape$1({
    inputs: {
      x: input
    },
    backend: backend,
    attrs: {
      shape: newShape
    }
  });
}
var expandDimsConfig$1 = {
  kernelName: ExpandDims,
  backendName: 'cpu',
  kernelFunc: expandDims$1
};
105685
105686 /**
105687 * @license
105688 * Copyright 2020 Google LLC. All Rights Reserved.
105689 * Licensed under the Apache License, Version 2.0 (the "License");
105690 * you may not use this file except in compliance with the License.
105691 * You may obtain a copy of the License at
105692 *
105693 * http://www.apache.org/licenses/LICENSE-2.0
105694 *
105695 * Unless required by applicable law or agreed to in writing, software
105696 * distributed under the License is distributed on an "AS IS" BASIS,
105697 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105698 * See the License for the specific language governing permissions and
105699 * limitations under the License.
105700 * =============================================================================
105701 */
/**
 * Elementwise true division for the CPU backend. realDivImpl is built from
 * the scalar `numerator / denominator` op via createSimpleBinaryKernelImpl.
 * `div` wraps it as the RealDiv kernel function.
 */
var realDivImpl = createSimpleBinaryKernelImpl(function divide(numerator, denominator) {
  var quotient = numerator / denominator;
  return quotient;
});
var div = binaryKernelFunc$1(RealDiv, realDivImpl);
var realDivConfig$1 = {
  kernelName: RealDiv,
  backendName: 'cpu',
  kernelFunc: div
};
105711
105712 /**
105713 * @license
105714 * Copyright 2020 Google LLC. All Rights Reserved.
105715 * Licensed under the Apache License, Version 2.0 (the "License");
105716 * you may not use this file except in compliance with the License.
105717 * You may obtain a copy of the License at
105718 *
105719 * http://www.apache.org/licenses/LICENSE-2.0
105720 *
105721 * Unless required by applicable law or agreed to in writing, software
105722 * distributed under the License is distributed on an "AS IS" BASIS,
105723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105724 * See the License for the specific language governing permissions and
105725 * limitations under the License.
105726 * =============================================================================
105727 */
/**
 * Applies the (inverse) FFT independently to each row of a
 * [batch, innerDim] complex64 tensor, i.e. along the innermost dimension.
 *
 * @param input complex64 tensor info of shape [batch, innerDim].
 * @param inverse whether to compute the inverse transform.
 * @param cpuBackend backend that owns `input` and the returned tensor.
 * @returns a new complex64 tensor info of shape [batch, innerDim].
 */
function fftBatch(input, inverse, cpuBackend) {
  var batch = input.shape[0];
  var innerDim = input.shape[1];
  var parts = cpuBackend.data.get(input.dataId).complexTensorInfos;
  var real2D = parts.real;
  var imag2D = parts.imag;
  // Real and imaginary outputs are collected separately.
  var resultShape = [batch, innerDim];
  var resultSize = sizeFromShape(resultShape);
  var resultReal = getTypedArrayFromDType('float32', resultSize);
  var resultImag = getTypedArrayFromDType('float32', resultSize);
  // Extracts one row of a [batch, innerDim] tensor as a [1, innerDim] slice.
  var takeRow = function (source, rowIndex) {
    return slice$1({
      inputs: { x: source },
      backend: cpuBackend,
      attrs: { begin: [rowIndex, 0], size: [1, innerDim] }
    });
  };
  for (var row = 0; row < batch; row++) {
    // TODO: Support slice ops for complex type.
    var rowReal = takeRow(real2D, row);
    var rowImag = takeRow(imag2D, row);
    var rowInput = complex$1({
      inputs: { real: rowReal, imag: rowImag },
      backend: cpuBackend
    });
    // Transform this batch element on its own.
    var transformed = fftImpl$1(rowInput, inverse, cpuBackend);
    var interleaved = mergeRealAndImagArrays(transformed.real, transformed.imag);
    var offset = row * innerDim;
    for (var col = 0; col < innerDim; col++) {
      var entry = getComplexWithIndex(interleaved, col);
      resultReal[offset + col] = entry.real;
      resultImag[offset + col] = entry.imag;
    }
    cpuBackend.disposeIntermediateTensorInfo(rowReal);
    cpuBackend.disposeIntermediateTensorInfo(rowImag);
    cpuBackend.disposeIntermediateTensorInfo(rowInput);
  }
  var realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultReal);
  var imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImag);
  var result = complex$1({
    inputs: { real: realInfo, imag: imagInfo },
    backend: cpuBackend
  });
  cpuBackend.disposeIntermediateTensorInfo(realInfo);
  cpuBackend.disposeIntermediateTensorInfo(imagInfo);
  return result;
}
/**
 * Computes the (inverse) DFT of one complex row.
 *
 * Power-of-two sizes use the recursive radix-2 path. Other sizes fall back to
 * the direct O(n^2) transform. On both paths the inverse transform is
 * normalised by 1 / inputSize.
 *
 * @param input complex64 tensor info; fftBatch passes a [1, innerDim] row.
 * @param inverse whether to compute the inverse transform.
 * @param cpuBackend backend that owns `input`.
 * @returns {{real, imag}} the transformed real and imaginary values.
 */
function fftImpl$1(input, inverse, cpuBackend) {
  var inputSize = sizeFromShape(input.shape);
  var inputVals = cpuBackend.data.get(input.dataId);
  var realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
  var imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
  if (!isExponentOf2(inputSize)) {
    // fourierTransformByMatmul applies the 1 / size scaling itself.
    var data = mergeRealAndImagArrays(realVals, imagVals);
    var rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
    return splitRealAndImagArrays(rawOutput);
  }
  var result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
  if (!inverse) {
    return result;
  }
  // Scale by 1 / inputSize directly (the same `a / b` as realDivImpl),
  // without round-tripping through temporary tensors and RealDiv kernel calls.
  // Writing into fresh arrays matters: for inputSize === 1, fftRadix2 returns
  // the input's own value arrays, and those must not be modified in place.
  var n = result.real.length;
  var scaledReal = new Float32Array(n);
  var scaledImag = new Float32Array(n);
  for (var i = 0; i < n; ++i) {
    scaledReal[i] = result.real[i] / inputSize;
    scaledImag[i] = result.imag[i] / inputSize;
  }
  return {
    real: scaledReal,
    imag: scaledImag
  };
}
/**
 * Returns true when `size` is a positive power of two (1, 2, 4, ...).
 *
 * Zero is rejected explicitly. `0 & -1` is 0, so the bit trick alone would
 * accept it and send a zero-length input into fftRadix2, whose recursion
 * only stops at size 1.
 */
function isExponentOf2(size) {
  return size > 0 && (size & size - 1) === 0;
}
/**
 * FFT using the Cooley-Tukey algorithm on radix-2 (power-of-two sized) input.
 *
 * Recursively transforms the even- and odd-indexed halves, then combines them
 * with complex element-wise kernels. No 1 / size normalisation is applied
 * here; callers handle it for inverse transforms.
 *
 * @param realVals real parts of the samples.
 * @param imagVals imaginary parts of the samples.
 * @param size number of complex samples; expected to be a power of two.
 * @param inverse forwarded to `exponents` to pick the twiddle factors.
 * @param cpuBackend backend used for the temporary complex tensors.
 * @returns {{real, imag}} the transformed real and imaginary values.
 */
function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
  if (size === 1) {
    return {
      real: realVals,
      imag: imagVals
    };
  }
  var data = mergeRealAndImagArrays(realVals, imagVals);
  var half = size / 2;
  // Recurse on the even- and odd-indexed halves of the input. Only the
  // recursive results are wrapped in tensors, because only they are read.
  var evenComplex = complexWithEvenIndex(data);
  var oddComplex = complexWithOddIndex(data);
  var $evenComplex = fftRadix2(evenComplex.real, evenComplex.imag, half, inverse, cpuBackend);
  var $oddComplex = fftRadix2(oddComplex.real, oddComplex.imag, half, inverse, cpuBackend);

  // Every tensor created below is an intermediate, released before returning.
  var intermediates = [];
  // Wraps a pair of real/imaginary arrays in a complex64 tensor info.
  var makeComplex = function (re, im) {
    var shape = [re.length];
    var realInfo = cpuBackend.makeTensorInfo(shape, 'float32', re);
    var imagInfo = cpuBackend.makeTensorInfo(shape, 'float32', im);
    var complexInfo = complex$1({
      inputs: { real: realInfo, imag: imagInfo },
      backend: cpuBackend
    });
    intermediates.push(realInfo, imagInfo, complexInfo);
    return complexInfo;
  };
  var $evenTensorInfo = makeComplex($evenComplex.real, $evenComplex.imag);
  var $oddTensorInfo = makeComplex($oddComplex.real, $oddComplex.imag);
  var e = exponents(size, inverse);
  var twiddleInfo = makeComplex(e.real, e.imag);

  // Butterfly step. Assuming `exponents` yields the twiddle factors w^k for
  // the first half: first half = even + w^k * odd, second half = even - w^k * odd.
  var exponentInfo = multiply$1({
    inputs: { a: twiddleInfo, b: $oddTensorInfo },
    backend: cpuBackend
  });
  var addPart = add({
    inputs: { a: $evenTensorInfo, b: exponentInfo },
    backend: cpuBackend
  });
  var subPart = sub$1({
    inputs: { a: $evenTensorInfo, b: exponentInfo },
    backend: cpuBackend
  });
  intermediates.push(exponentInfo, addPart, subPart);
  var addPartReal = real$1({ inputs: { input: addPart }, backend: cpuBackend });
  var subPartReal = real$1({ inputs: { input: subPart }, backend: cpuBackend });
  var addPartImag = imag$1({ inputs: { input: addPart }, backend: cpuBackend });
  var subPartImag = imag$1({ inputs: { input: subPart }, backend: cpuBackend });
  intermediates.push(addPartReal, subPartReal, addPartImag, subPartImag);
  var $real = concat$1({
    inputs: [addPartReal, subPartReal],
    backend: cpuBackend,
    attrs: { axis: 0 }
  });
  var $imag = concat$1({
    inputs: [addPartImag, subPartImag],
    backend: cpuBackend,
    attrs: { axis: 0 }
  });
  intermediates.push($real, $imag);
  // Take the raw value arrays before their owning tensor infos are disposed.
  var $realVals = cpuBackend.data.get($real.dataId).values;
  var $imagVals = cpuBackend.data.get($imag.dataId).values;
  for (var k = 0; k < intermediates.length; ++k) {
    cpuBackend.disposeIntermediateTensorInfo(intermediates[k]);
  }
  return {
    real: $realVals,
    imag: $imagVals
  };
}
/**
 * Direct O(size^2) discrete Fourier transform: multiplies the interleaved
 * complex input by the sinusoid matrix, one output row at a time.
 *
 * @param data interleaved complex input (read via getComplexWithIndex).
 * @param size number of complex samples.
 * @param inverse when true, uses the inverse exponent and divides each
 *     output by `size`.
 * @returns {Float32Array} interleaved complex output of length 2 * size.
 */
function fourierTransformByMatmul(data, size, inverse) {
  var output = new Float32Array(size * 2);
  // TODO: Use matmul instead once it supports complex64 type.
  for (var row = 0; row < size; row++) {
    var accReal = 0.0;
    var accImag = 0.0;
    for (var col = 0; col < size; col++) {
      var twiddle = exponent(row * col, size, inverse);
      var sample = getComplexWithIndex(data, col);
      // Complex multiply-accumulate: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
      accReal += sample.real * twiddle.real - sample.imag * twiddle.imag;
      accImag += sample.real * twiddle.imag + sample.imag * twiddle.real;
    }
    if (inverse) {
      accReal /= size;
      accImag /= size;
    }
    assignToTypedArray(output, accReal, accImag, row);
  }
  return output;
}
106039
106040 /**
106041 * @license
106042 * Copyright 2020 Google LLC. All Rights Reserved.
106043 * Licensed under the Apache License, Version 2.0 (the "License");
106044 * you may not use this file except in compliance with the License.
106045 * You may obtain a copy of the License at
106046 *
106047 * http://www.apache.org/licenses/LICENSE-2.0
106048 *
106049 * Unless required by applicable law or agreed to in writing, software
106050 * distributed under the License is distributed on an "AS IS" BASIS,
106051 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106052 * See the License for the specific language governing permissions and
106053 * limitations under the License.
106054 * =============================================================================
106055 */
/**
 * CPU kernel for FFT. Transforms along the innermost dimension of a complex
 * tensor of any rank:
 * 1. collapse all outer dimensions into one batch dimension;
 * 2. run fftBatch (forward transform);
 * 3. reshape the result back to the input shape.
 */
function fft$1(args) {
  var backend = args.backend;
  var input = args.inputs.input;
  var shape = input.shape;
  var innerSize = shape[shape.length - 1];
  var outerSize = sizeFromShape(shape) / innerSize;
  var flattened = reshape$1({
    inputs: { x: input },
    backend: backend,
    attrs: { shape: [outerSize, innerSize] }
  });
  var transformed = fftBatch(flattened, false, backend);
  var restored = reshape$1({
    inputs: { x: transformed },
    backend: backend,
    attrs: { shape: shape }
  });
  backend.disposeIntermediateTensorInfo(flattened);
  backend.disposeIntermediateTensorInfo(transformed);
  return restored;
}
var fftConfig$1 = {
  kernelName: FFT,
  backendName: 'cpu',
  kernelFunc: fft$1
};
106092
106093 /**
106094 * @license
106095 * Copyright 2020 Google LLC. All Rights Reserved.
106096 * Licensed under the Apache License, Version 2.0 (the "License");
106097 * you may not use this file except in compliance with the License.
106098 * You may obtain a copy of the License at
106099 *
106100 * http://www.apache.org/licenses/LICENSE-2.0
106101 *
106102 * Unless required by applicable law or agreed to in writing, software
106103 * distributed under the License is distributed on an "AS IS" BASIS,
106104 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106105 * See the License for the specific language governing permissions and
106106 * limitations under the License.
106107 * =============================================================================
106108 */
/**
 * CPU kernel for Fill. Builds a tensor of `attrs.shape` in which every
 * element equals `attrs.value`. When `attrs.dtype` is falsy, the dtype is
 * inferred from the value.
 */
function fill$1(args) {
  var attrs = args.attrs;
  var shape = attrs.shape;
  var value = attrs.value;
  var resolvedDtype = attrs.dtype || inferDtype(value);
  var buffer = getArrayFromDType(resolvedDtype, sizeFromShape(shape));
  fillValues(buffer, value, resolvedDtype);
  return args.backend.makeTensorInfo(shape, resolvedDtype, buffer);
}
var fillConfig$1 = {
  kernelName: Fill,
  backendName: 'cpu',
  kernelFunc: fill$1
};
106125 function fillValues(values, value, dtype) {
106126 if (dtype === 'string') {
106127 values.fill(value);
106128 } else {
106129 values.fill(value);
106130 }
106131 }
106132
/**
 * CPU kernel for `FlipLeftRight`: mirrors every image of a batch horizontally.
 *
 * `image.shape` is read as [batch, height, width, channels]. The output has
 * the input's shape and dtype. Column `col` of each output row is copied from
 * column `width - 1 - col` of the same input row, channel by channel.
 */
var flipLeftRightConfig$1 = {
  kernelName: FlipLeftRight,
  backendName: 'cpu',
  kernelFunc: function kernelFunc(_ref) {
    var image = _ref.inputs.image;
    var cpuBackend = _ref.backend;
    var shape = image.shape;
    var batch = shape[0];
    var imageHeight = shape[1];
    var imageWidth = shape[2];
    var numChannels = shape[3];
    var output = getTypedArrayFromDType(image.dtype, sizeFromShape(shape));
    var imageVals = cpuBackend.data.get(image.dataId).values;
    // Element counts per row and per image, hoisted out of the loops.
    var rowStride = imageWidth * numChannels;
    var imageStride = imageHeight * rowStride;
    for (var batchIdx = 0; batchIdx < batch; batchIdx++) {
      for (var row = 0; row < imageHeight; row++) {
        var rowStart = batchIdx * imageStride + row * rowStride;
        for (var col = 0; col < imageWidth; col++) {
          // The mirrored column (imageWidth - 1 - col) is always an integer in
          // [0, imageWidth), so neither rounding nor a bounds check is needed.
          var outPixel = rowStart + col * numChannels;
          var srcPixel = rowStart + (imageWidth - 1 - col) * numChannels;
          for (var channel = 0; channel < numChannels; channel++) {
            output[outPixel + channel] = imageVals[srcPixel + channel];
          }
        }
      }
    }
    var dataId = cpuBackend.write(output, image.shape, image.dtype);
    return {
      dataId: dataId,
      shape: image.shape,
      dtype: image.dtype
    };
  }
};
106179
106180 /**
106181 * @license
106182 * Copyright 2020 Google LLC. All Rights Reserved.
106183 * Licensed under the Apache License, Version 2.0 (the "License");
106184 * you may not use this file except in compliance with the License.
106185 * You may obtain a copy of the License at
106186 *
106187 * http://www.apache.org/licenses/LICENSE-2.0
106188 *
106189 * Unless required by applicable law or agreed to in writing, software
106190 * distributed under the License is distributed on an "AS IS" BASIS,
106191 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106192 * See the License for the specific language governing permissions and
106193 * limitations under the License.
106194 * =============================================================================
106195 */
/**
 * CPU kernel for `FusedConv2D`: a 2-D convolution optionally followed by a
 * bias add and then an activation, all in one kernel call.
 *
 * Each stage produces a new tensor info and releases its input with
 * `backend.disposeIntermediateTensorInfo`. Under 'NCHW', a per-channel 1-D
 * bias or PReLU weight vector (length other than 1) is reshaped to [C, 1, 1]
 * so it lines up with the channel axis of the convolution result. Scalars and
 * all 'NHWC' cases are used as given.
 *
 * @param {{inputs: Object, backend: Object, attrs: Object}} args Kernel args.
 * @returns {Object} Tensor info of the final stage.
 */
function fusedConv2D(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var isNCHW = attrs.dataFormat === 'NCHW';

  // True for a 1-D tensor holding more than one element (one per channel).
  function isPerChannelVector(tensor) {
    return tensor.shape.length === 1 && tensor.shape[0] !== 1;
  }

  // Reshapes a length-C vector to [C, 1, 1] for broadcasting over NCHW data.
  function toChannelBroadcast(vector) {
    return reshape$1({
      inputs: { x: vector },
      backend: backend,
      attrs: { shape: [vector.shape[0], 1, 1] }
    });
  }

  var current = conv2D({
    inputs: { x: inputs.x, filter: inputs.filter },
    backend: backend,
    attrs: {
      strides: attrs.strides,
      pad: attrs.pad,
      dataFormat: attrs.dataFormat,
      dilations: attrs.dilations,
      dimRoundingMode: attrs.dimRoundingMode
    }
  });

  var bias = inputs.bias;
  if (bias) {
    var beforeBias = current;
    if (isNCHW && isPerChannelVector(bias)) {
      var biasCHW = toChannelBroadcast(bias);
      current = add({ inputs: { a: current, b: biasCHW }, backend: backend });
      backend.disposeIntermediateTensorInfo(biasCHW);
    } else {
      // NHWC, or NCHW with a scalar bias: rely on ordinary broadcasting.
      current = add({ inputs: { a: current, b: bias }, backend: backend });
    }
    backend.disposeIntermediateTensorInfo(beforeBias);
  }

  var activation = attrs.activation;
  if (activation) {
    var weights = inputs.preluActivationWeights;
    var beforeActivation = current;
    if (isNCHW && activation === 'prelu' && isPerChannelVector(weights)) {
      var alphaCHW = toChannelBroadcast(weights);
      current = applyActivation(backend, current, activation, alphaCHW, attrs.leakyreluAlpha);
      backend.disposeIntermediateTensorInfo(alphaCHW);
    } else {
      current = applyActivation(backend, current, activation, weights, attrs.leakyreluAlpha);
    }
    backend.disposeIntermediateTensorInfo(beforeActivation);
  }

  return current;
}
var fusedConv2DConfig$1 = {
  kernelName: FusedConv2D,
  backendName: 'cpu',
  kernelFunc: fusedConv2D
};
106292
106293 /**
106294 * @license
106295 * Copyright 2020 Google LLC. All Rights Reserved.
106296 * Licensed under the Apache License, Version 2.0 (the "License");
106297 * you may not use this file except in compliance with the License.
106298 * You may obtain a copy of the License at
106299 *
106300 * http://www.apache.org/licenses/LICENSE-2.0
106301 *
106302 * Unless required by applicable law or agreed to in writing, software
106303 * distributed under the License is distributed on an "AS IS" BASIS,
106304 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106305 * See the License for the specific language governing permissions and
106306 * limitations under the License.
106307 * =============================================================================
106308 */
/**
 * CPU kernel for `FusedDepthwiseConv2D`: a depthwise 2-D convolution,
 * optionally followed by adding `inputs.bias` and then applying
 * `attrs.activation`.
 *
 * Each optional stage replaces the running result, and the stage's input is
 * released with `backend.disposeIntermediateTensorInfo`. Unlike `fusedConv2D`,
 * the bias and PReLU weights are passed through without any NCHW reshaping.
 *
 * @param {{inputs: Object, backend: Object, attrs: Object}} args Kernel args.
 * @returns {Object} Tensor info of the last stage that ran.
 */
function fusedDepthwiseConv2D$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;

  var output = depthwiseConv2dNative$1({
    inputs: { x: inputs.x, filter: inputs.filter },
    backend: backend,
    attrs: {
      strides: attrs.strides,
      pad: attrs.pad,
      dataFormat: attrs.dataFormat,
      dilations: attrs.dilations,
      dimRoundingMode: attrs.dimRoundingMode
    }
  });

  var previous;
  if (inputs.bias) {
    previous = output;
    output = add({ inputs: { a: previous, b: inputs.bias }, backend: backend });
    backend.disposeIntermediateTensorInfo(previous);
  }
  if (attrs.activation) {
    previous = output;
    output = applyActivation(backend, previous, attrs.activation, inputs.preluActivationWeights, attrs.leakyreluAlpha);
    backend.disposeIntermediateTensorInfo(previous);
  }
  return output;
}
var fusedDepthwiseConv2DConfig$1 = {
  kernelName: FusedDepthwiseConv2D,
  backendName: 'cpu',
  kernelFunc: fusedDepthwiseConv2D$1
};
106361
/**
 * CPU kernel for `GatherNd`: gathers slices of `params` addressed by index
 * tuples held along the last axis of `indices`.
 *
 * `prepareAndValidate` supplies the result shape, the number of slices, the
 * number of elements per slice and the strides that `gatherNdImpl` uses. When
 * there are no slices, an empty tensor of the result shape is returned
 * without touching either input's data.
 *
 * @param {{inputs: {params: Object, indices: Object}, backend: Object}} args
 * @returns {Object} Tensor info with dtype `params.dtype`.
 */
function gatherNd$1(args) {
  var backend = args.backend;
  var params = args.inputs.params;
  var indices = args.inputs.indices;
  var paramsSize = sizeFromShape(params.shape);
  // Length of each index tuple = size of the last `indices` dimension.
  var sliceRank = indices.shape[indices.shape.length - 1];

  var prepared = _slicedToArray(prepareAndValidate(params, indices), 4);
  var resultShape = prepared[0];
  var numSlices = prepared[1];
  var sliceSize = prepared[2];
  var strides = prepared[3];

  if (numSlices === 0) {
    return backend.makeTensorInfo(resultShape, params.dtype, []);
  }

  var indicesData = backend.data.get(indices.dataId).values;
  var paramsBuf = backend.bufferSync(params);
  var gathered = gatherNdImpl(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
  return backend.makeTensorInfo(resultShape, params.dtype, gathered.values);
}
var gatherNdConfig$1 = {
  kernelName: GatherNd,
  backendName: 'cpu',
  kernelFunc: gatherNd$1
};
106389
106390 /**
106391 * @license
106392 * Copyright 2020 Google LLC. All Rights Reserved.
106393 * Licensed under the Apache License, Version 2.0 (the "License");
106394 * you may not use this file except in compliance with the License.
106395 * You may obtain a copy of the License at
106396 *
106397 * http://www.apache.org/licenses/LICENSE-2.0
106398 *
106399 * Unless required by applicable law or agreed to in writing, software
106400 * distributed under the License is distributed on an "AS IS" BASIS,
106401 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106402 * See the License for the specific language governing permissions and
106403 * limitations under the License.
106404 * =============================================================================
106405 */
/**
 * CPU kernel for `GatherV2`: gathers entries of `x` along `attrs.axis` at the
 * positions listed in `indices`, with optional leading batch dimensions
 * (`attrs.batchDims`, treated as 0 when null or undefined).
 *
 * Every index is first checked to lie in [0, x.shape[axis] - 1]. `x` and
 * `indices` are then reshaped into the 4-D / 2-D layouts described by
 * `collectGatherOpShapeInfo` and handed to `gatherV2Impl`. Both reshaped
 * intermediates are disposed before returning.
 *
 * @param {{inputs: {x: Object, indices: Object}, backend: Object,
 *     attrs: {axis: number, batchDims: (number|undefined)}}} args
 * @returns {Object} Tensor info of shape `shapeInfo.outputShape`.
 */
function gatherV2$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var attrs = args.attrs;
  var x = inputs.x;
  var indices = inputs.indices;
  assertNotComplex$1([x, indices], 'gatherV2');

  var parsedAxis = parseAxisParam(attrs.axis, x.shape)[0];
  var indicesVals = backend.data.get(indices.dataId).values;
  var axisDim = x.shape[parsedAxis];

  // Throws via assert$1 when an index falls outside the gathered axis.
  function checkIndexInRange(index) {
    assert$1(index <= axisDim - 1 && index >= 0, function () {
      return "GatherV2: the index value ".concat(index, " is not in [0, ").concat(axisDim - 1, "]");
    });
  }
  for (var pos = 0; pos < indicesVals.length; ++pos) {
    checkIndexInRange(indicesVals[pos]);
  }

  var $batchDims = attrs.batchDims == null ? 0 : attrs.batchDims;
  var indicesSize = sizeFromShape(indices.shape);
  var shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, $batchDims);
  var indicesPerBatch = indicesSize / shapeInfo.batchSize;

  var flattenX = reshape$1({
    inputs: { x: x },
    backend: backend,
    attrs: { shape: [shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize, shapeInfo.sliceSize] }
  });
  var flattenIndex = reshape$1({
    inputs: { x: indices },
    backend: backend,
    attrs: { shape: [shapeInfo.batchSize, indicesPerBatch] }
  });
  var flattenOutputShape = [shapeInfo.batchSize, shapeInfo.outerSize, indicesPerBatch, shapeInfo.sliceSize];

  var indicesBuf = backend.bufferSync(flattenIndex);
  var xBuf = backend.bufferSync(flattenX);
  var gathered = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);
  backend.disposeIntermediateTensorInfo(flattenX);
  backend.disposeIntermediateTensorInfo(flattenIndex);
  return backend.makeTensorInfo(shapeInfo.outputShape, gathered.dtype, gathered.values);
}
var gatherV2Config$1 = {
  kernelName: GatherV2,
  backendName: 'cpu',
  kernelFunc: gatherV2$1
};
106465
106466 /**
106467 * @license
106468 * Copyright 2020 Google LLC. All Rights Reserved.
106469 * Licensed under the Apache License, Version 2.0 (the "License");
106470 * you may not use this file except in compliance with the License.
106471 * You may obtain a copy of the License at
106472 *
106473 * http://www.apache.org/licenses/LICENSE-2.0
106474 *
106475 * Unless required by applicable law or agreed to in writing, software
106476 * distributed under the License is distributed on an "AS IS" BASIS,
106477 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106478 * See the License for the specific language governing permissions and
106479 * limitations under the License.
106480 * =============================================================================
106481 */
/**
 * CPU kernel for `IFFT`: inverse FFT over the innermost dimension.
 *
 * All leading dimensions are folded into a single batch dimension so that
 * `fftBatch` always receives a [batch, n] tensor. The result is then reshaped
 * back to the input's shape, and both temporaries are disposed.
 *
 * @param {{inputs: {input: Object}, backend: Object}} args Kernel arguments.
 * @returns {Object} Tensor info with the same shape as `inputs.input`.
 */
function ifft$1(args) {
  var backend = args.backend;
  var input = args.inputs.input;
  var shape = input.shape;
  var innerDimensionSize = shape[shape.length - 1];
  var batch = sizeFromShape(shape) / innerDimensionSize;

  var asMatrix = reshape$1({
    inputs: { x: input },
    backend: backend,
    attrs: { shape: [batch, innerDimensionSize] }
  });
  // `true` selects the inverse transform; the forward kernel fft$1 passes false.
  var transformed = fftBatch(asMatrix, true, backend);
  var restored = reshape$1({
    inputs: { x: transformed },
    backend: backend,
    attrs: { shape: shape }
  });

  backend.disposeIntermediateTensorInfo(asMatrix);
  backend.disposeIntermediateTensorInfo(transformed);
  return restored;
}
var ifftConfig$1 = {
  kernelName: IFFT,
  backendName: 'cpu',
  kernelFunc: ifft$1
};
106518
106519 /**
106520 * @license
106521 * Copyright 2020 Google LLC. All Rights Reserved.
106522 * Licensed under the Apache License, Version 2.0 (the License);
106523 * you may not use this file except in compliance with the License.
106524 * You may obtain a copy of the License at
106525 *
106526 * http://www.apache.org/licenses/LICENSE-2.0
106527 *
106528 * Unless required by applicable law or agreed to in writing, software
106529 * distributed under the License is distributed on an AS IS BASIS,
106530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106531 * See the License for the specific language governing permissions and
106532 * limitations under the License.
106533 * =============================================================================
106534 */
/**
 * CPU kernel for `IsFinite`: emits 1 for each finite element and 0 for
 * ±Infinity or NaN. The output dtype is 'bool'.
 */
var isFinite$2 = unaryKernelFunc$1(IsFinite, function (value) {
  if (!Number.isFinite(value)) {
    return 0;
  }
  return 1;
}, 'bool');
var isFiniteConfig$1 = {
  kernelName: IsFinite,
  backendName: 'cpu',
  kernelFunc: isFinite$2
};
106543
106544 /**
106545 * @license
106546 * Copyright 2020 Google LLC. All Rights Reserved.
106547 * Licensed under the Apache License, Version 2.0 (the License);
106548 * you may not use this file except in compliance with the License.
106549 * You may obtain a copy of the License at
106550 *
106551 * http://www.apache.org/licenses/LICENSE-2.0
106552 *
106553 * Unless required by applicable law or agreed to in writing, software
106554 * distributed under the License is distributed on an AS IS BASIS,
106555 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106556 * See the License for the specific language governing permissions and
106557 * limitations under the License.
106558 * =============================================================================
106559 */
/**
 * CPU kernel for `IsInf`: emits 1 where the element's magnitude is Infinity
 * (either sign), otherwise 0. NaN yields 0. The output dtype is 'bool'.
 */
var isInf$1 = unaryKernelFunc$1(IsInf, function (value) {
  var magnitude = Math.abs(value);
  return Number(magnitude === Infinity);
}, 'bool');
var isInfConfig$1 = {
  kernelName: IsInf,
  backendName: 'cpu',
  kernelFunc: isInf$1
};
106568
106569 /**
106570 * @license
106571 * Copyright 2020 Google LLC. All Rights Reserved.
106572 * Licensed under the Apache License, Version 2.0 (the License);
106573 * you may not use this file except in compliance with the License.
106574 * You may obtain a copy of the License at
106575 *
106576 * http://www.apache.org/licenses/LICENSE-2.0
106577 *
106578 * Unless required by applicable law or agreed to in writing, software
106579 * distributed under the License is distributed on an AS IS BASIS,
106580 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106581 * See the License for the specific language governing permissions and
106582 * limitations under the License.
106583 * =============================================================================
106584 */
/**
 * CPU kernel for `IsNan`: emits 1 for NaN elements and 0 otherwise. The
 * output dtype is 'bool'.
 */
var isNaN$2 = unaryKernelFunc$1(IsNan, function (value) {
  // NaN is the only value that is not strictly equal to itself.
  return value !== value ? 1 : 0;
}, 'bool');
var isNaNConfig$1 = {
  kernelName: IsNan,
  backendName: 'cpu',
  kernelFunc: isNaN$2
};
106593
106594 /**
106595 * @license
106596 * Copyright 2020 Google LLC. All Rights Reserved.
106597 * Licensed under the Apache License, Version 2.0 (the "License");
106598 * you may not use this file except in compliance with the License.
106599 * You may obtain a copy of the License at
106600 *
106601 * http://www.apache.org/licenses/LICENSE-2.0
106602 *
106603 * Unless required by applicable law or agreed to in writing, software
106604 * distributed under the License is distributed on an "AS IS" BASIS,
106605 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106606 * See the License for the specific language governing permissions and
106607 * limitations under the License.
106608 * =============================================================================
106609 */
/**
 * CPU kernel for `LinSpace`: materialises the values produced by
 * `linSpaceImpl(start, stop, num)` as a rank-1 'float32' tensor whose length
 * equals the number of generated values.
 *
 * @param {{backend: Object, attrs: {start: number, stop: number, num: number}}} args
 * @returns {Object} Tensor info of shape [outVals.length].
 */
function linSpace$1(args) {
  var attrs = args.attrs;
  var generated = linSpaceImpl(attrs.start, attrs.stop, attrs.num);
  var length = generated.length;
  return args.backend.makeTensorInfo([length], 'float32', generated);
}
var linSpaceConfig$1 = {
  kernelName: LinSpace,
  backendName: 'cpu',
  kernelFunc: linSpace$1
};
106624
106625 /**
106626 * @license
106627 * Copyright 2020 Google LLC. All Rights Reserved.
106628 * Licensed under the Apache License, Version 2.0 (the License);
106629 * you may not use this file except in compliance with the License.
106630 * You may obtain a copy of the License at
106631 *
106632 * http://www.apache.org/licenses/LICENSE-2.0
106633 *
106634 * Unless required by applicable law or agreed to in writing, software
106635 * distributed under the License is distributed on an AS IS BASIS,
106636 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106637 * See the License for the specific language governing permissions and
106638 * limitations under the License.
106639 * =============================================================================
106640 */
/**
 * CPU kernel for `Log1p`: element-wise natural log of (1 + x), computed with
 * `Math.log1p` for accuracy near zero. No explicit output dtype is passed to
 * `unaryKernelFunc$1`.
 */
var log1p$1 = unaryKernelFunc$1(Log1p, function (value) {
  var result = Math.log1p(value);
  return result;
});
var log1pConfig$1 = {
  kernelName: Log1p,
  backendName: 'cpu',
  kernelFunc: log1p$1
};
106649
106650 /**
106651 * @license
106652 * Copyright 2020 Google LLC. All Rights Reserved.
106653 * Licensed under the Apache License, Version 2.0 (the "License");
106654 * you may not use this file except in compliance with the License.
106655 * You may obtain a copy of the License at
106656 *
106657 * http://www.apache.org/licenses/LICENSE-2.0
106658 *
106659 * Unless required by applicable law or agreed to in writing, software
106660 * distributed under the License is distributed on an "AS IS" BASIS,
106661 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106662 * See the License for the specific language governing permissions and
106663 * limitations under the License.
106664 * =============================================================================
106665 */
/**
 * CPU kernel for `LogicalAnd` (element-wise, output dtype 'bool').
 *
 * The element op returns an explicit 1/0, matching `logicalNot`, instead of
 * returning one of its operands (which `a && b` does). This keeps the result
 * strictly 0/1 even if an input holds a value other than 0 or 1. For 0/1
 * inputs the result is unchanged.
 */
var logicalAndImpl = createSimpleBinaryKernelImpl(function (a, b) {
  return (a && b) ? 1 : 0;
});
var logicalAnd$1 = binaryKernelFunc$1(LogicalAnd, logicalAndImpl, null /* complexImpl */, 'bool');
var logicalAndConfig$1 = {
  kernelName: LogicalAnd,
  backendName: 'cpu',
  kernelFunc: logicalAnd$1
};
106675
106676 /**
106677 * @license
106678 * Copyright 2020 Google LLC. All Rights Reserved.
106679 * Licensed under the Apache License, Version 2.0 (the License);
106680 * you may not use this file except in compliance with the License.
106681 * You may obtain a copy of the License at
106682 *
106683 * http://www.apache.org/licenses/LICENSE-2.0
106684 *
106685 * Unless required by applicable law or agreed to in writing, software
106686 * distributed under the License is distributed on an AS IS BASIS,
106687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106688 * See the License for the specific language governing permissions and
106689 * limitations under the License.
106690 * =============================================================================
106691 */
/**
 * CPU kernel for `LogicalNot`: truthy (non-zero) elements become 0 and falsy
 * ones (0, NaN) become 1. The output dtype is 'bool'.
 */
var logicalNot$1 = unaryKernelFunc$1(LogicalNot, function (value) {
  if (value) {
    return 0;
  }
  return 1;
}, 'bool');
var logicalNotConfig$1 = {
  kernelName: LogicalNot,
  backendName: 'cpu',
  kernelFunc: logicalNot$1
};
106700
106701 /**
106702 * @license
106703 * Copyright 2020 Google LLC. All Rights Reserved.
106704 * Licensed under the Apache License, Version 2.0 (the "License");
106705 * you may not use this file except in compliance with the License.
106706 * You may obtain a copy of the License at
106707 *
106708 * http://www.apache.org/licenses/LICENSE-2.0
106709 *
106710 * Unless required by applicable law or agreed to in writing, software
106711 * distributed under the License is distributed on an "AS IS" BASIS,
106712 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106713 * See the License for the specific language governing permissions and
106714 * limitations under the License.
106715 * =============================================================================
106716 */
/**
 * CPU kernel for `LogicalOr` (element-wise, output dtype 'bool').
 *
 * The element op returns an explicit 1/0, matching `logicalNot`, instead of
 * returning one of its operands (which `a || b` does). This keeps the result
 * strictly 0/1 even if an input holds a value other than 0 or 1. For 0/1
 * inputs the result is unchanged.
 */
var logicalOrImpl = createSimpleBinaryKernelImpl(function (a, b) {
  return (a || b) ? 1 : 0;
});
var logicalOr$1 = binaryKernelFunc$1(LogicalOr, logicalOrImpl, null /* complexImpl */, 'bool');
var logicalOrConfig$1 = {
  kernelName: LogicalOr,
  backendName: 'cpu',
  kernelFunc: logicalOr$1
};
106726
106727 /**
106728 * @license
106729 * Copyright 2020 Google LLC. All Rights Reserved.
106730 * Licensed under the Apache License, Version 2.0 (the "License");
106731 * you may not use this file except in compliance with the License.
106732 * You may obtain a copy of the License at
106733 *
106734 * http://www.apache.org/licenses/LICENSE-2.0
106735 *
106736 * Unless required by applicable law or agreed to in writing, software
106737 * distributed under the License is distributed on an "AS IS" BASIS,
106738 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106739 * See the License for the specific language governing permissions and
106740 * limitations under the License.
106741 * =============================================================================
106742 */
/**
 * CPU kernel for `LRN` (local response normalization).
 *
 * `x.shape[3]` is the channel count, and channels are taken as the innermost
 * axis (`position % channels`). For every element:
 *   out = x * (bias + alpha * sumSq) ^ (-beta)
 * where sumSq is the sum of squares of `x` over the channels
 * [c - depthRadius, c + depthRadius], clipped to [0, channels - 1], of the
 * same pixel. The output values are stored in a Float32Array and tagged with
 * x's dtype.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {depthRadius: number,
 *     bias: number, alpha: number, beta: number}}} args
 * @returns {Object} Tensor info with x's shape.
 */
function lRN(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var attrs = args.attrs;
  var depthRadius = attrs.depthRadius;
  var bias = attrs.bias;
  var alpha = attrs.alpha;
  var beta = attrs.beta;
  assertNotComplex$1(x, 'LRN');

  var channels = x.shape[3];
  var lastChannel = channels - 1;
  var xValues = backend.data.get(x.dataId).values;
  var total = sizeFromShape(x.shape);
  var output = new Float32Array(total);

  for (var pos = 0; pos < total; pos++) {
    var channel = pos % channels;
    var pixelStart = pos - channel;
    var first = pixelStart + Math.max(0, channel - depthRadius);
    var last = pixelStart + Math.min(channel + depthRadius, lastChannel);
    var squares = 0.0;
    for (var k = first; k <= last; k++) {
      squares += xValues[k] * xValues[k];
    }
    output[pos] = xValues[pos] * Math.pow(bias + alpha * squares, -beta);
  }
  return backend.makeTensorInfo(x.shape, x.dtype, output);
}
// tslint:disable-next-line: variable-name
var LRNConfig$1 = {
  kernelName: LRN,
  backendName: 'cpu',
  kernelFunc: lRN
};
106782
106783 /**
106784 * @license
106785 * Copyright 2020 Google LLC. All Rights Reserved.
106786 * Licensed under the Apache License, Version 2.0 (the "License");
106787 * you may not use this file except in compliance with the License.
106788 * You may obtain a copy of the License at
106789 *
106790 * http://www.apache.org/licenses/LICENSE-2.0
106791 *
106792 * Unless required by applicable law or agreed to in writing, software
106793 * distributed under the License is distributed on an "AS IS" BASIS,
106794 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106795 * See the License for the specific language governing permissions and
106796 * limitations under the License.
106797 * =============================================================================
106798 */
/**
 * CPU kernel for `LRNGrad`: back-propagates `dy` through local response
 * normalization.
 *
 * `dy.shape[3]` is the channel count, and channels are treated as the
 * innermost axis (`offset % channels`). This presumably assumes rank-4,
 * channels-last input; TODO confirm with callers. For each flat position i it
 * computes
 *   norm_i = bias + alpha * sum_{k in window(i)} x_k^2
 * over the channel window [c - depthRadius, c + depthRadius], clipped to
 * [0, channels - 1] (the same window the forward `lRN` kernel sums over). It
 * then scatters dy_i * g_ik into result[k] for every k in that window, where
 *   g_ik = [i == k] * norm_i^(-beta) - 2 * alpha * beta * x_k * y_i / norm_i.
 * That is d(y_i)/d(x_k) for y_i = x_i * norm_i^(-beta), provided `y` holds the
 * forward LRN output for `x`. NOTE(review): this is inferred from the input
 * name; confirm with callers.
 *
 * @param {{inputs: {x: Object, y: Object, dy: Object}, backend: Object,
 *     attrs: {depthRadius: number, bias: number, alpha: number, beta: number}}} args
 * @returns {Object} Tensor info with dy's shape and x's dtype; the values are
 *     stored in a Float32Array.
 */
function lRNGrad(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x,
    y = inputs.y,
    dy = inputs.dy;
  var depthRadius = attrs.depthRadius,
    bias = attrs.bias,
    alpha = attrs.alpha,
    beta = attrs.beta;
  assertNotComplex$1(dy, 'LRNGrad');
  var dySize = sizeFromShape(dy.shape);
  var channels = dy.shape[3];
  var dyValues = backend.data.get(dy.dataId).values;
  var xValues = backend.data.get(x.dataId).values;
  var yValues = backend.data.get(y.dataId).values;
  // Zero-initialised accumulator; several outputs contribute to each x_k.
  var result = new Float32Array(dySize);
  var size = dySize;
  for (var offset = 0; offset < size; offset++) {
    var currentChannel = offset % channels;
    // Window bounds as flat offsets; depthEnd is exclusive.
    var depthBegin = offset - currentChannel + Math.max(0, currentChannel - depthRadius);
    var depthEnd = offset - currentChannel + Math.min(channels, currentChannel + depthRadius + 1);
    var norm = 0;
    for (var k = depthBegin; k < depthEnd; k++) {
      norm += Math.pow(xValues[k], 2);
    }
    norm = alpha * norm + bias;
    for (var _k = depthBegin; _k < depthEnd; _k++) {
      // Cross term of d(y_offset)/d(x_k), present for every k in the window.
      var dyi = -2 * alpha * beta * xValues[_k] * yValues[offset] / norm;
      if (offset === _k) {
        // Diagonal term from the x_i factor of y_i.
        dyi += Math.pow(norm, -beta);
      }
      // Chain rule: scale by the upstream gradient of this output.
      dyi *= dyValues[offset];
      result[_k] += dyi;
    }
  }
  return backend.makeTensorInfo(dy.shape, x.dtype, result);
}
// tslint:disable-next-line: variable-name
var LRNGradConfig$1 = {
  kernelName: LRNGrad,
  backendName: 'cpu',
  kernelFunc: lRNGrad
};
106844
/**
 * CPU kernel for `Max`: reduces `x` by taking the maximum over the axes in
 * `attrs.reductionIndices`.
 *
 * When the reduction axes are not already the innermost ones
 * (`getAxesPermutation` returns a permutation), the data is transposed first
 * so that `maxImpl$1` can reduce contiguous runs of `reduceSize` elements.
 * With `keepDims`, the reported shape has the reduced axes re-inserted as
 * size-1 dimensions; the stored data is the same.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {reductionIndices: (number|number[]), keepDims: boolean}}} args
 * @returns {{dataId: Object, shape: number[], dtype: string}}
 */
function max$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x;
  var reductionIndices = attrs.reductionIndices,
    keepDims = attrs.keepDims;
  // Validate before reading or transposing the tensor's data, so that
  // unsupported (complex) inputs fail with this kernel's own error rather than
  // an unrelated failure inside the transpose below.
  assertNotComplex$1(x, 'max');
  var cpuBackend = backend;
  var xShape = x.shape;
  var xRank = xShape.length;
  var origAxes = parseAxisParam(reductionIndices, xShape);
  var axes = origAxes;
  var permutedAxes = getAxesPermutation(axes, xRank);
  var xVals = cpuBackend.data.get(x.dataId).values;
  if (permutedAxes != null) {
    var newShape = new Array(xRank);
    for (var i = 0; i < newShape.length; i++) {
      newShape[i] = xShape[permutedAxes[i]];
    }
    xVals = transposeImpl$1(xVals, xShape, x.dtype, permutedAxes, newShape);
    axes = getInnerMostAxes(axes.length, xRank);
    xShape = newShape;
  }
  assertAxesAreInnerMostDims('max', axes, xRank);
  var _backend_util$compute = computeOutAndReduceShapes(xShape, axes),
    _backend_util$compute2 = _slicedToArray(_backend_util$compute, 2),
    maxOutShape = _backend_util$compute2[0],
    reduceShape = _backend_util$compute2[1];
  var reduceSize = sizeFromShape(reduceShape);
  var result = maxImpl$1(xVals, reduceSize, maxOutShape, x.dtype);
  var dataId = cpuBackend.write(result, maxOutShape, x.dtype);
  var outShape = keepDims ? expandShapeToKeepDim(maxOutShape, origAxes) : maxOutShape;
  return {
    dataId: dataId,
    shape: outShape,
    dtype: x.dtype
  };
}
var maxConfig$1 = {
  kernelName: Max,
  backendName: 'cpu',
  kernelFunc: max$1
};
106894
106895 /**
106896 * @license
106897 * Copyright 2020 Google LLC. All Rights Reserved.
106898 * Licensed under the Apache License, Version 2.0 (the "License");
106899 * you may not use this file except in compliance with the License.
106900 * You may obtain a copy of the License at
106901 *
106902 * http://www.apache.org/licenses/LICENSE-2.0
106903 *
106904 * Unless required by applicable law or agreed to in writing, software
106905 * distributed under the License is distributed on an "AS IS" BASIS,
106906 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106907 * See the License for the specific language governing permissions and
106908 * limitations under the License.
106909 * =============================================================================
106910 */
/**
 * CPU kernel for `MaxPool` (2-D max pooling, no dilation).
 *
 * When the pooling window is 1x1 and the computed output shape equals the
 * input shape, the work is delegated to `identity$1`. Otherwise `pool` runs in
 * 'max' mode over x's values, and the result keeps x's dtype.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {filterSize: Object,
 *     strides: Object, pad: Object, dimRoundingMode: (string|undefined)}}} args
 * @returns {Object} Tensor info of shape `convInfo.outShape`, or the identity result.
 */
function maxPool$1(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  assertNotComplex$1(x, 'maxPool');
  var attrs = args.attrs;
  var strides = attrs.strides;
  // Plain max pooling never dilates its window.
  var dilations = 1;
  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in maxPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
  });
  var convInfo = computePool2DInfo(x.shape, attrs.filterSize, strides, dilations, attrs.pad, attrs.dimRoundingMode);

  var isNoOp = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape);
  if (isNoOp) {
    return identity$1({ inputs: { x: x }, backend: backend });
  }

  var xValues = backend.data.get(x.dataId).values;
  var pooled = pool(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'max');
  return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}
var maxPoolConfig$1 = {
  kernelName: MaxPool,
  backendName: 'cpu',
  kernelFunc: maxPool$1
};
106947
106948 /**
106949 * @license
106950 * Copyright 2020 Google LLC. All Rights Reserved.
106951 * Licensed under the Apache License, Version 2.0 (the "License");
106952 * you may not use this file except in compliance with the License.
106953 * You may obtain a copy of the License at
106954 *
106955 * http://www.apache.org/licenses/LICENSE-2.0
106956 *
106957 * Unless required by applicable law or agreed to in writing, software
106958 * distributed under the License is distributed on an "AS IS" BASIS,
106959 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106960 * See the License for the specific language governing permissions and
106961 * limitations under the License.
106962 * =============================================================================
106963 */
106964 function maxPool3D(args) {
106965 var inputs = args.inputs,
106966 backend = args.backend,
106967 attrs = args.attrs;
106968 var x = inputs.x;
106969 var filterSize = attrs.filterSize,
106970 strides = attrs.strides,
106971 pad = attrs.pad,
106972 dimRoundingMode = attrs.dimRoundingMode,
106973 dataFormat = attrs.dataFormat;
106974 assertNotComplex$1(x, 'maxPool3d');
106975 var convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
106976 var xValues = backend.data.get(x.dataId).values;
106977 var outBuf = pool3d(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'max');
106978 return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
106979 }
106980 var maxPool3DConfig$1 = {
106981 kernelName: MaxPool3D,
106982 backendName: 'cpu',
106983 kernelFunc: maxPool3D
106984 };
106985
106986 /**
106987 * @license
106988 * Copyright 2020 Google LLC. All Rights Reserved.
106989 * Licensed under the Apache License, Version 2.0 (the "License");
106990 * you may not use this file except in compliance with the License.
106991 * You may obtain a copy of the License at
106992 *
106993 * http://www.apache.org/licenses/LICENSE-2.0
106994 *
106995 * Unless required by applicable law or agreed to in writing, software
106996 * distributed under the License is distributed on an "AS IS" BASIS,
106997 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
106998 * See the License for the specific language governing permissions and
106999 * limitations under the License.
107000 * =============================================================================
107001 */
107002 function maxPool3DGrad$1(args) {
107003 var inputs = args.inputs,
107004 backend = args.backend,
107005 attrs = args.attrs;
107006 var dy = inputs.dy,
107007 input = inputs.input;
107008 var filterSize = attrs.filterSize,
107009 strides = attrs.strides,
107010 pad = attrs.pad,
107011 dimRoundingMode = attrs.dimRoundingMode;
107012 assertNotComplex$1([dy, input], 'maxPool3DGrad');
107013 var convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
107014 var inputBuf = backend.bufferSync(input);
107015 var maxPosBuf = maxPool3dPositions(inputBuf, convInfo);
107016 var strideDepth = convInfo.strideDepth;
107017 var strideHeight = convInfo.strideHeight;
107018 var strideWidth = convInfo.strideWidth;
107019 var dilationDepth = convInfo.dilationDepth;
107020 var dilationHeight = convInfo.dilationHeight;
107021 var dilationWidth = convInfo.dilationWidth;
107022 var effectiveFilterDepth = convInfo.effectiveFilterDepth;
107023 var effectiveFilterHeight = convInfo.effectiveFilterHeight;
107024 var effectiveFilterWidth = convInfo.effectiveFilterWidth;
107025 var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
107026 var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
107027 var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
107028 var dx = buffer(input.shape, 'float32');
107029 var dyBuf = backend.bufferSync(dy);
107030 for (var batch = 0; batch < convInfo.batchSize; ++batch) {
107031 for (var channel = 0; channel < convInfo.inChannels; ++channel) {
107032 for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
107033 for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
107034 for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
107035 // Shader code begins
107036 var dyDepthCorner = dxDepth - padFront;
107037 var dyRowCorner = dxRow - padTop;
107038 var dyColCorner = dxCol - padLeft;
107039 var dotProd = 0;
107040 for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
107041 var dyDepth = (dyDepthCorner + wDepth) / strideDepth;
107042 if (dyDepth < 0 || dyDepth >= convInfo.outDepth || Math.floor(dyDepth) !== dyDepth) {
107043 continue;
107044 }
107045 for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
107046 var dyRow = (dyRowCorner + wRow) / strideHeight;
107047 if (dyRow < 0 || dyRow >= convInfo.outHeight || Math.floor(dyRow) !== dyRow) {
107048 continue;
107049 }
107050 for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
107051 var dyCol = (dyColCorner + wCol) / strideWidth;
107052 if (dyCol < 0 || dyCol >= convInfo.outWidth || Math.floor(dyCol) !== dyCol) {
107053 continue;
107054 }
107055 var maxPos = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1 - maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
107056 var curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterWidth + wCol;
107057 var mask = maxPos === curPos ? 1 : 0;
107058 if (mask === 0) {
107059 continue;
107060 }
107061 var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
107062 dotProd += pixel * mask;
107063 }
107064 }
107065 }
107066 dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
107067 }
107068 }
107069 }
107070 }
107071 }
107072 return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
107073 }
107074 var maxPool3DGradConfig$1 = {
107075 kernelName: MaxPool3DGrad,
107076 backendName: 'cpu',
107077 kernelFunc: maxPool3DGrad$1
107078 };
107079
107080 /**
107081 * @license
107082 * Copyright 2020 Google LLC. All Rights Reserved.
107083 * Licensed under the Apache License, Version 2.0 (the "License");
107084 * you may not use this file except in compliance with the License.
107085 * You may obtain a copy of the License at
107086 *
107087 * http://www.apache.org/licenses/LICENSE-2.0
107088 *
107089 * Unless required by applicable law or agreed to in writing, software
107090 * distributed under the License is distributed on an "AS IS" BASIS,
107091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107092 * See the License for the specific language governing permissions and
107093 * limitations under the License.
107094 * =============================================================================
107095 */
107096 function maxPoolGrad$1(args) {
107097 var inputs = args.inputs,
107098 backend = args.backend,
107099 attrs = args.attrs;
107100 var dy = inputs.dy,
107101 input = inputs.input,
107102 output = inputs.output;
107103 var x = input;
107104 assertNotComplex$1([input, output], 'maxPoolGrad');
107105 var filterSize = attrs.filterSize,
107106 strides = attrs.strides,
107107 pad = attrs.pad,
107108 dimRoundingMode = attrs.dimRoundingMode;
107109 var convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
107110 var xValues = backend.data.get(x.dataId).values;
107111 var maxPosBuf = buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
107112 var strideHeight = convInfo.strideHeight;
107113 var strideWidth = convInfo.strideWidth;
107114 var dilationHeight = convInfo.dilationHeight;
107115 var dilationWidth = convInfo.dilationWidth;
107116 var effectiveFilterHeight = convInfo.effectiveFilterHeight;
107117 var effectiveFilterWidth = convInfo.effectiveFilterWidth;
107118 var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
107119 var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
107120 var dx = buffer(x.shape, 'float32');
107121 var dyData = backend.data.get(dy.dataId).values;
107122 var dyBuf = buffer(dy.shape, 'float32', dyData);
107123 for (var b = 0; b < convInfo.batchSize; ++b) {
107124 for (var d = 0; d < convInfo.inChannels; ++d) {
107125 for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) {
107126 for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) {
107127 // Shader code begins.
107128 var dyRCorner = dxR - padTop;
107129 var dyCCorner = dxC - padLeft;
107130 var dotProd = 0;
107131 for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
107132 var dyR = (dyRCorner + wR) / strideHeight;
107133 if (dyR < 0 || dyR >= convInfo.outHeight || Math.floor(dyR) !== dyR) {
107134 continue;
107135 }
107136 for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
107137 var dyC = (dyCCorner + wC) / strideWidth;
107138 if (dyC < 0 || dyC >= convInfo.outWidth || Math.floor(dyC) !== dyC) {
107139 continue;
107140 }
107141 var maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 - maxPosBuf.get(b, dyR, dyC, d);
107142 var curPos = wR * effectiveFilterWidth + wC;
107143 var mask = maxPos === curPos ? 1 : 0;
107144 if (mask === 0) {
107145 continue;
107146 }
107147 var pixel = dyBuf.get(b, dyR, dyC, d);
107148 dotProd += pixel * mask;
107149 }
107150 }
107151 dx.set(dotProd, b, dxR, dxC, d);
107152 }
107153 }
107154 }
107155 }
107156 return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
107157 }
107158 var maxPoolGradConfig$1 = {
107159 kernelName: MaxPoolGrad,
107160 backendName: 'cpu',
107161 kernelFunc: maxPoolGrad$1
107162 };
107163
107164 /**
107165 * @license
107166 * Copyright 2020 Google LLC. All Rights Reserved.
107167 * Licensed under the Apache License, Version 2.0 (the "License");
107168 * you may not use this file except in compliance with the License.
107169 * You may obtain a copy of the License at
107170 *
107171 * http://www.apache.org/licenses/LICENSE-2.0
107172 *
107173 * Unless required by applicable law or agreed to in writing, software
107174 * distributed under the License is distributed on an "AS IS" BASIS,
107175 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107176 * See the License for the specific language governing permissions and
107177 * limitations under the License.
107178 * =============================================================================
107179 */
107180 function maxPoolWithArgmaxImpl$1(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
107181 var strides = computeStrides(xShape);
107182 var maxPools = pool(xValues, xShape, dtype, strides, convInfo, 'max');
107183 var maxPositions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
107184 return [maxPools.values, maxPositions.values];
107185 }
107186
107187 var maxPoolWithArgmaxConfig$1 = {
107188 kernelName: MaxPoolWithArgmax,
107189 backendName: 'cpu',
107190 kernelFunc: function kernelFunc(_ref) {
107191 var inputs = _ref.inputs,
107192 attrs = _ref.attrs,
107193 backend = _ref.backend;
107194 var x = inputs.x;
107195 var filterSize = attrs.filterSize,
107196 strides = attrs.strides,
107197 pad = attrs.pad,
107198 includeBatchInIndex = attrs.includeBatchInIndex;
107199 var cpuBackend = backend;
107200 assertNotComplex$1(x, 'MaxPoolWithArgmax');
107201 var values = cpuBackend.data.get(x.dataId).values;
107202 var convInfo = computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
107203 var _maxPoolWithArgmaxImp = maxPoolWithArgmaxImpl$1(values, x.shape, x.dtype, includeBatchInIndex, convInfo),
107204 _maxPoolWithArgmaxImp2 = _slicedToArray(_maxPoolWithArgmaxImp, 2),
107205 pooled = _maxPoolWithArgmaxImp2[0],
107206 indexes = _maxPoolWithArgmaxImp2[1];
107207 var pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
107208 var indexesDataId = cpuBackend.write(indexes, convInfo.outShape, x.dtype);
107209 return [{
107210 dataId: pooledDataId,
107211 shape: convInfo.outShape,
107212 dtype: x.dtype
107213 }, {
107214 dataId: indexesDataId,
107215 shape: convInfo.outShape,
107216 dtype: 'int32'
107217 }];
107218 }
107219 };
107220
107221 /**
107222 * @license
107223 * Copyright 2020 Google LLC. All Rights Reserved.
107224 * Licensed under the Apache License, Version 2.0 (the "License");
107225 * you may not use this file except in compliance with the License.
107226 * You may obtain a copy of the License at
107227 *
107228 * http://www.apache.org/licenses/LICENSE-2.0
107229 *
107230 * Unless required by applicable law or agreed to in writing, software
107231 * distributed under the License is distributed on an "AS IS" BASIS,
107232 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107233 * See the License for the specific language governing permissions and
107234 * limitations under the License.
107235 * =============================================================================
107236 */
107237 function mean(args) {
107238 var inputs = args.inputs,
107239 backend = args.backend,
107240 attrs = args.attrs;
107241 var x = inputs.x;
107242 var axis = attrs.axis,
107243 keepDims = attrs.keepDims;
107244 var axes = parseAxisParam(axis, x.shape);
107245 var shapes = computeOutAndReduceShapes(x.shape, axes);
107246 var reduceShape = shapes[1];
107247 var reduceSize = sizeFromShape(reduceShape);
107248 var toDispose = [];
107249 var reduceSizeScalar = backend.makeTensorInfo([], 'float32', new Float32Array([reduceSize]));
107250 toDispose.push(reduceSizeScalar);
107251 var $x = cast$1({
107252 inputs: {
107253 x: x
107254 },
107255 backend: backend,
107256 attrs: {
107257 dtype: 'float32'
107258 }
107259 });
107260 toDispose.push($x);
107261 var res = div({
107262 inputs: {
107263 a: $x,
107264 b: reduceSizeScalar
107265 },
107266 backend: backend
107267 });
107268 toDispose.push(res);
107269 var result = sum$1({
107270 inputs: {
107271 x: res
107272 },
107273 backend: backend,
107274 attrs: {
107275 axis: axis,
107276 keepDims: keepDims
107277 }
107278 });
107279 toDispose.forEach(function (t) {
107280 return backend.disposeIntermediateTensorInfo(t);
107281 });
107282 return result;
107283 }
107284 var meanConfig$1 = {
107285 kernelName: Mean,
107286 backendName: 'cpu',
107287 kernelFunc: mean
107288 };
107289
107290 function min$1(args) {
107291 var inputs = args.inputs,
107292 backend = args.backend,
107293 attrs = args.attrs;
107294 var x = inputs.x;
107295 var axis = attrs.axis,
107296 keepDims = attrs.keepDims;
107297 assertNotComplex$1(x, 'min');
107298 var origAxes = parseAxisParam(axis, x.shape);
107299 var axes = origAxes;
107300 var permutedAxes = getAxesPermutation(axes, x.shape.length);
107301 var $x = x;
107302 if (permutedAxes != null) {
107303 $x = transpose$1({
107304 inputs: {
107305 x: x
107306 },
107307 backend: backend,
107308 attrs: {
107309 perm: permutedAxes
107310 }
107311 });
107312 axes = getInnerMostAxes(axes.length, x.shape.length);
107313 }
107314 assertAxesAreInnerMostDims('min', axes, $x.shape.length);
107315 var _backend_util$compute = computeOutAndReduceShapes($x.shape, axes),
107316 _backend_util$compute2 = _slicedToArray(_backend_util$compute, 2),
107317 outShape = _backend_util$compute2[0],
107318 reduceShape = _backend_util$compute2[1];
107319 var reduceSize = sizeFromShape(reduceShape);
107320 var vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
107321 var aVals = backend.data.get($x.dataId).values;
107322 for (var i = 0; i < vals.length; ++i) {
107323 var offset = i * reduceSize;
107324 var _min = aVals[offset];
107325 for (var j = 0; j < reduceSize; ++j) {
107326 var value = aVals[offset + j];
107327 if (Number.isNaN(value) || value < _min) {
107328 // comparison with NaN always return false
107329 _min = value;
107330 }
107331 }
107332 vals[i] = _min;
107333 }
107334 if (permutedAxes != null) {
107335 backend.disposeIntermediateTensorInfo($x);
107336 }
107337 var result = backend.makeTensorInfo(outShape, $x.dtype, vals);
107338 if (keepDims) {
107339 var expandedShape = expandShapeToKeepDim(outShape, origAxes);
107340 var reshapedResult = reshape$1({
107341 inputs: {
107342 x: result
107343 },
107344 backend: backend,
107345 attrs: {
107346 shape: expandedShape
107347 }
107348 });
107349 backend.disposeIntermediateTensorInfo(result);
107350 return reshapedResult;
107351 }
107352 return result;
107353 }
107354 var minConfig$1 = {
107355 kernelName: Min,
107356 backendName: 'cpu',
107357 kernelFunc: min$1
107358 };
107359
107360 /**
107361 * @license
107362 * Copyright 2020 Google LLC. All Rights Reserved.
107363 * Licensed under the Apache License, Version 2.0 (the "License");
107364 * you may not use this file except in compliance with the License.
107365 * You may obtain a copy of the License at
107366 *
107367 * http://www.apache.org/licenses/LICENSE-2.0
107368 *
107369 * Unless required by applicable law or agreed to in writing, software
107370 * distributed under the License is distributed on an "AS IS" BASIS,
107371 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107372 * See the License for the specific language governing permissions and
107373 * limitations under the License.
107374 * =============================================================================
107375 */
107376 function mirrorPad(args) {
107377 var inputs = args.inputs,
107378 backend = args.backend,
107379 attrs = args.attrs;
107380 var x = inputs.x;
107381 var paddings = attrs.paddings,
107382 mode = attrs.mode;
107383 assertNotComplex$1(x, 'mirrorPad');
107384 var outShape = paddings.map(function (p, i) {
107385 return p[0] /* beforePad */ + x.shape[i] + p[1];
107386 } /* afterPad */);
107387 var start = paddings.map(function (p) {
107388 return p[0];
107389 });
107390 var end = paddings.map(function (p, i) {
107391 return p[0] + x.shape[i];
107392 });
107393 var offset = mode === 'reflect' ? 0 : 1;
107394 var xVals = backend.data.get(x.dataId).values;
107395 var xRank = x.shape.length;
107396 var xStrides = computeStrides(x.shape);
107397 var resultSize = sizeFromShape(outShape);
107398 var resultRank = outShape.length;
107399 var resultStrides = computeStrides(outShape);
107400 var resVals = getTypedArrayFromDType(x.dtype, resultSize);
107401 for (var i = 0; i < resultSize; i++) {
107402 var coords = indexToLoc(i, resultRank, resultStrides);
107403 for (var _i = 0; _i < resultRank; _i++) {
107404 if (coords[_i] < start[_i]) {
107405 coords[_i] = start[_i] * 2 - coords[_i] - offset;
107406 } else if (coords[_i] >= end[_i]) {
107407 coords[_i] = (end[_i] - 1) * 2 - coords[_i] + offset;
107408 }
107409 }
107410 coords = coords.map(function (c, i) {
107411 return c - start[i];
107412 });
107413 var inIndex = locToIndex(coords, xRank, xStrides);
107414 resVals[i] = xVals[inIndex];
107415 }
107416 var outId = backend.write(resVals, outShape, x.dtype);
107417 return {
107418 dataId: outId,
107419 shape: outShape,
107420 dtype: x.dtype
107421 };
107422 }
107423 var mirrorPadConfig$1 = {
107424 kernelName: MirrorPad,
107425 backendName: 'cpu',
107426 kernelFunc: mirrorPad
107427 };
107428
107429 /**
107430 * @license
107431 * Copyright 2020 Google LLC. All Rights Reserved.
107432 * Licensed under the Apache License, Version 2.0 (the "License");
107433 * you may not use this file except in compliance with the License.
107434 * You may obtain a copy of the License at
107435 *
107436 * http://www.apache.org/licenses/LICENSE-2.0
107437 *
107438 * Unless required by applicable law or agreed to in writing, software
107439 * distributed under the License is distributed on an "AS IS" BASIS,
107440 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107441 * See the License for the specific language governing permissions and
107442 * limitations under the License.
107443 * =============================================================================
107444 */
107445 var modImpl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
107446 var rem = aValue % bValue;
107447 if (aValue < 0 && bValue < 0 || aValue >= 0 && bValue >= 0) {
107448 return rem;
107449 } else {
107450 return (rem + bValue) % bValue;
107451 }
107452 });
107453 var mod$1 = binaryKernelFunc$1(Mod, modImpl);
107454 var modConfig$1 = {
107455 kernelName: Mod,
107456 backendName: 'cpu',
107457 kernelFunc: mod$1
107458 };
107459
107460 /**
107461 * @license
107462 * Copyright 2020 Google LLC. All Rights Reserved.
107463 * Licensed under the Apache License, Version 2.0 (the "License");
107464 * you may not use this file except in compliance with the License.
107465 * You may obtain a copy of the License at
107466 *
107467 * http://www.apache.org/licenses/LICENSE-2.0
107468 *
107469 * Unless required by applicable law or agreed to in writing, software
107470 * distributed under the License is distributed on an "AS IS" BASIS,
107471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107472 * See the License for the specific language governing permissions and
107473 * limitations under the License.
107474 * =============================================================================
107475 */
107476 function softmax$1(args) {
107477 var inputs = args.inputs,
107478 backend = args.backend,
107479 attrs = args.attrs;
107480 var logits = inputs.logits;
107481 var dim = attrs.dim;
107482 var logitsRank = logits.shape.length;
107483 var $dim = dim;
107484 if ($dim === -1) {
107485 $dim = logitsRank - 1;
107486 }
107487 if ($dim !== logitsRank - 1) {
107488 throw Error('Softmax along a non-last dimension is not yet supported. ' + "Logits was rank ".concat(logitsRank, " and dim was ").concat($dim));
107489 }
107490 var axes = parseAxisParam([$dim], logits.shape);
107491 var maxLogit = max$1({
107492 inputs: {
107493 x: logits
107494 },
107495 backend: backend,
107496 attrs: {
107497 reductionIndices: axes,
107498 keepDims: false
107499 }
107500 });
107501 var expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
107502 var maxLogitReshaped = reshape$1({
107503 inputs: {
107504 x: maxLogit
107505 },
107506 backend: backend,
107507 attrs: {
107508 shape: expandedShape
107509 }
107510 });
107511 var a = sub$1({
107512 inputs: {
107513 a: logits,
107514 b: maxLogitReshaped
107515 },
107516 backend: backend
107517 });
107518 var b = exp$1({
107519 inputs: {
107520 x: a
107521 },
107522 backend: backend
107523 });
107524 var sumExp = sum$1({
107525 inputs: {
107526 x: b
107527 },
107528 backend: backend,
107529 attrs: {
107530 axis: axes,
107531 keepDims: false
107532 }
107533 });
107534 var sumReshaped = reshape$1({
107535 inputs: {
107536 x: sumExp
107537 },
107538 backend: backend,
107539 attrs: {
107540 shape: expandedShape
107541 }
107542 });
107543 var result = div({
107544 inputs: {
107545 a: b,
107546 b: sumReshaped
107547 },
107548 backend: backend
107549 });
107550 backend.disposeIntermediateTensorInfo(maxLogit);
107551 backend.disposeIntermediateTensorInfo(maxLogitReshaped);
107552 backend.disposeIntermediateTensorInfo(a);
107553 backend.disposeIntermediateTensorInfo(b);
107554 backend.disposeIntermediateTensorInfo(sumExp);
107555 backend.disposeIntermediateTensorInfo(sumReshaped);
107556 return result;
107557 }
107558 var softmaxConfig$1 = {
107559 kernelName: Softmax$2,
107560 backendName: 'cpu',
107561 kernelFunc: softmax$1
107562 };
107563
107564 /**
107565 * @license
107566 * Copyright 2020 Google LLC. All Rights Reserved.
107567 * Licensed under the Apache License, Version 2.0 (the "License");
107568 * you may not use this file except in compliance with the License.
107569 * You may obtain a copy of the License at
107570 *
107571 * http://www.apache.org/licenses/LICENSE-2.0
107572 *
107573 * Unless required by applicable law or agreed to in writing, software
107574 * distributed under the License is distributed on an "AS IS" BASIS,
107575 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107576 * See the License for the specific language governing permissions and
107577 * limitations under the License.
107578 * =============================================================================
107579 */
107580 function multinomial$1(args) {
107581 var inputs = args.inputs,
107582 backend = args.backend,
107583 attrs = args.attrs;
107584 var logits = inputs.logits;
107585 var numSamples = attrs.numSamples,
107586 seed = attrs.seed,
107587 normalized = attrs.normalized;
107588 assertNotComplex$1(logits, 'multinomial');
107589 var probabilities = normalized ? logits : softmax$1({
107590 inputs: {
107591 logits: logits
107592 },
107593 backend: backend,
107594 attrs: {
107595 dim: -1
107596 }
107597 });
107598 var batchSize = probabilities.shape[0];
107599 var numEvents = probabilities.shape[1];
107600 var probVals = backend.data.get(probabilities.dataId).values;
107601 var resShape = [batchSize, numSamples];
107602 var resVals = makeZerosTypedArray(sizeFromShape(resShape), 'int32');
107603 for (var b = 0; b < batchSize; ++b) {
107604 var offset = b * numEvents;
107605 // The cdf won't include the last event. It will be implicit if no other
107606 // event happened.
107607 var cdf = new Float32Array(numEvents - 1);
107608 cdf[0] = probVals[offset];
107609 for (var event = 1; event < cdf.length; ++event) {
107610 cdf[event] = cdf[event - 1] + probVals[offset + event];
107611 }
107612 var random = seedrandom.alea(seed.toString());
107613 var outOffset = b * numSamples;
107614 for (var sampleId = 0; sampleId < numSamples; ++sampleId) {
107615 var r = random();
107616 // Assume last event happened by default.
107617 resVals[outOffset + sampleId] = cdf.length;
107618 for (var _event = 0; _event < cdf.length; _event++) {
107619 if (r < cdf[_event]) {
107620 resVals[outOffset + sampleId] = _event;
107621 break;
107622 }
107623 }
107624 }
107625 }
107626 if (!normalized) {
107627 backend.disposeIntermediateTensorInfo(probabilities);
107628 }
107629 return backend.makeTensorInfo(resShape, 'int32', resVals);
107630 }
107631 var multinomialConfig$1 = {
107632 kernelName: Multinomial,
107633 backendName: 'cpu',
107634 kernelFunc: multinomial$1
107635 };
107636
107637 /**
107638 * @license
107639 * Copyright 2020 Google LLC. All Rights Reserved.
107640 * Licensed under the Apache License, Version 2.0 (the "License");
107641 * you may not use this file except in compliance with the License.
107642 * You may obtain a copy of the License at
107643 *
107644 * http://www.apache.org/licenses/LICENSE-2.0
107645 *
107646 * Unless required by applicable law or agreed to in writing, software
107647 * distributed under the License is distributed on an "AS IS" BASIS,
107648 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107649 * See the License for the specific language governing permissions and
107650 * limitations under the License.
107651 * =============================================================================
107652 */
107653 var nonMaxSuppressionV3Impl$1 = nonMaxSuppressionV3Impl$2;
107654 function nonMaxSuppressionV3$1(args) {
107655 var inputs = args.inputs,
107656 backend = args.backend,
107657 attrs = args.attrs;
107658 var boxes = inputs.boxes,
107659 scores = inputs.scores;
107660 var maxOutputSize = attrs.maxOutputSize,
107661 iouThreshold = attrs.iouThreshold,
107662 scoreThreshold = attrs.scoreThreshold;
107663 assertNotComplex$1(boxes, 'NonMaxSuppression');
107664 var boxesVals = backend.data.get(boxes.dataId).values;
107665 var scoresVals = backend.data.get(scores.dataId).values;
107666 var _nonMaxSuppressionV3I = nonMaxSuppressionV3Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold),
107667 selectedIndices = _nonMaxSuppressionV3I.selectedIndices;
107668 return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
107669 }
107670 var nonMaxSuppressionV3Config$1 = {
107671 kernelName: NonMaxSuppressionV3,
107672 backendName: 'cpu',
107673 kernelFunc: nonMaxSuppressionV3$1
107674 };
107675
107676 /**
107677 * @license
107678 * Copyright 2020 Google LLC. All Rights Reserved.
107679 * Licensed under the Apache License, Version 2.0 (the "License");
107680 * you may not use this file except in compliance with the License.
107681 * You may obtain a copy of the License at
107682 *
107683 * http://www.apache.org/licenses/LICENSE-2.0
107684 *
107685 * Unless required by applicable law or agreed to in writing, software
107686 * distributed under the License is distributed on an "AS IS" BASIS,
107687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107688 * See the License for the specific language governing permissions and
107689 * limitations under the License.
107690 * =============================================================================
107691 */
107692 var nonMaxSuppressionV4Impl$1 = nonMaxSuppressionV4Impl$2;
107693 function nonMaxSuppressionV4$1(args) {
107694 var inputs = args.inputs,
107695 backend = args.backend,
107696 attrs = args.attrs;
107697 var boxes = inputs.boxes,
107698 scores = inputs.scores;
107699 var maxOutputSize = attrs.maxOutputSize,
107700 iouThreshold = attrs.iouThreshold,
107701 scoreThreshold = attrs.scoreThreshold,
107702 padToMaxOutputSize = attrs.padToMaxOutputSize;
107703 assertNotComplex$1(boxes, 'NonMaxSuppressionPadded');
107704 var boxesVals = backend.data.get(boxes.dataId).values;
107705 var scoresVals = backend.data.get(scores.dataId).values;
107706 var _nonMaxSuppressionV4I = nonMaxSuppressionV4Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize),
107707 selectedIndices = _nonMaxSuppressionV4I.selectedIndices,
107708 validOutputs = _nonMaxSuppressionV4I.validOutputs;
107709 return [backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))];
107710 }
107711 var nonMaxSuppressionV4Config$1 = {
107712 kernelName: NonMaxSuppressionV4,
107713 backendName: 'cpu',
107714 kernelFunc: nonMaxSuppressionV4$1
107715 };
107716
107717 /**
107718 * @license
107719 * Copyright 2019 Google LLC. All Rights Reserved.
107720 * Licensed under the Apache License, Version 2.0 (the "License");
107721 * you may not use this file except in compliance with the License.
107722 * You may obtain a copy of the License at
107723 *
107724 * http://www.apache.org/licenses/LICENSE-2.0
107725 *
107726 * Unless required by applicable law or agreed to in writing, software
107727 * distributed under the License is distributed on an "AS IS" BASIS,
107728 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107729 * See the License for the specific language governing permissions and
107730 * limitations under the License.
107731 * =============================================================================
107732 */
// Local alias for the shared NMS-with-score implementation.
var nonMaxSuppressionV5Impl$1 = nonMaxSuppressionV5Impl$2;

/**
 * CPU kernel for NonMaxSuppressionV5 (non-max suppression that also returns
 * the (possibly soft-NMS adjusted) scores of the kept boxes).
 *
 * Reads the raw box and score values from the backend, delegates the selection
 * to nonMaxSuppressionV5Impl$1 and wraps the two result lists as tensors.
 * Only `boxes` is checked for complex dtype.
 *
 * @param {{inputs: {boxes: Object, scores: Object}, backend: Object,
 *     attrs: {maxOutputSize: number, iouThreshold: number,
 *     scoreThreshold: number, softNmsSigma: number}}} args
 * @returns {Array<Object>} `[selectedIndices, selectedScores]`: a 1-D int32
 *     tensor info and a 1-D float32 tensor info of equal length.
 */
function nonMaxSuppressionV5$1(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var boxes = args.inputs.boxes;
  var scores = args.inputs.scores;
  assertNotComplex$1(boxes, 'NonMaxSuppressionWithScore');
  var valuesOf = function (tensorInfo) {
    return backend.data.get(tensorInfo.dataId).values;
  };
  var boxValues = valuesOf(boxes);
  var scoreValues = valuesOf(scores);
  var picked = nonMaxSuppressionV5Impl$1(
      boxValues, scoreValues, attrs.maxOutputSize, attrs.iouThreshold,
      attrs.scoreThreshold, attrs.softNmsSigma);
  var keptIndices = picked.selectedIndices;
  var keptScores = picked.selectedScores;
  var indicesInfo = backend.makeTensorInfo(
      [keptIndices.length], 'int32', new Int32Array(keptIndices));
  var scoresInfo = backend.makeTensorInfo(
      [keptScores.length], 'float32', new Float32Array(keptScores));
  return [indicesInfo, scoresInfo];
}

// Registers nonMaxSuppressionV5$1 as the 'cpu' implementation of NonMaxSuppressionV5.
var nonMaxSuppressionV5Config$1 = {
  kernelName: NonMaxSuppressionV5,
  backendName: 'cpu',
  kernelFunc: nonMaxSuppressionV5$1
};
107761
/**
 * CPU kernel for OneHot.
 *
 * Produces a tensor of shape `[...indices.shape, depth]` filled with
 * `attrs.offValue`. For each flat position i it writes `attrs.onValue` at
 * `[i, indices[i]]` when `0 <= indices[i] < depth`. Out-of-range indices leave
 * their whole row at `offValue`.
 *
 * NOTE(review): the values are always stored in a Float32Array, even when
 * `attrs.dtype` is not float32. Confirm that makeTensorInfo handles that.
 *
 * @param {{inputs: {indices: Object}, backend: Object,
 *     attrs: {dtype: string, depth: number, onValue: number,
 *     offValue: number}}} args
 * @returns {Object} tensor info for the one-hot encoding.
 */
function oneHot$1(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var indices = args.inputs.indices;
  var depth = attrs.depth;
  var onValue = attrs.onValue;
  assertNotComplex$1(indices, 'oneHot');
  var rows = sizeFromShape(indices.shape);
  var encoded = new Float32Array(rows * depth);
  encoded.fill(attrs.offValue);
  var hotColumns = backend.data.get(indices.dataId).values;
  var row = 0;
  while (row < rows) {
    var column = hotColumns[row];
    // Out-of-range (or NaN) indices fail this test and are skipped.
    var isInRange = column >= 0 && column < depth;
    if (isInRange) {
      encoded[row * depth + column] = onValue;
    }
    row++;
  }
  var outShape = Array.prototype.slice.call(indices.shape);
  outShape.push(depth);
  return backend.makeTensorInfo(outShape, attrs.dtype, encoded);
}

// Registers oneHot$1 as the 'cpu' implementation of OneHot.
var oneHotConfig$1 = {
  kernelName: OneHot,
  backendName: 'cpu',
  kernelFunc: oneHot$1
};
107788
107789 /**
107790 * @license
107791 * Copyright 2020 Google LLC. All Rights Reserved.
107792 * Licensed under the Apache License, Version 2.0 (the "License");
107793 * you may not use this file except in compliance with the License.
107794 * You may obtain a copy of the License at
107795 *
107796 * http://www.apache.org/licenses/LICENSE-2.0
107797 *
107798 * Unless required by applicable law or agreed to in writing, software
107799 * distributed under the License is distributed on an "AS IS" BASIS,
107800 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107801 * See the License for the specific language governing permissions and
107802 * limitations under the License.
107803 * =============================================================================
107804 */
/**
 * CPU kernel for ZerosLike: returns a new tensor of zeros with the same shape
 * and dtype as `inputs.x`.
 *
 * complex64 inputs are handled by zeroing the real and imaginary components
 * separately (via real$1 / imag$1 and a recursive call) and recombining them
 * with complex$1. The four intermediate tensor infos are then disposed.
 *
 * @param {{inputs: {x: Object}, backend: Object}} args
 * @returns {Object} tensor info of zeros shaped like `x`.
 * @throws {Error} if `x` is a string tensor.
 */
function zerosLike$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  if (x.dtype === 'string') {
    throw new Error('zerosLike is not supported for string tensors');
  }
  if (x.dtype !== 'complex64') {
    return fill$1({
      backend: backend,
      attrs: {shape: x.shape, value: 0, dtype: x.dtype}
    });
  }
  var realIn = real$1({inputs: {input: x}, backend: backend});
  var realOut = zerosLike$1({inputs: {x: realIn}, backend: backend});
  var imagIn = imag$1({inputs: {input: x}, backend: backend});
  var imagOut = zerosLike$1({inputs: {x: imagIn}, backend: backend});
  var combined = complex$1({
    inputs: {real: realOut, imag: imagOut},
    backend: backend
  });
  // Only `combined` is returned; release the per-component intermediates.
  [realIn, realOut, imagIn, imagOut].forEach(function (info) {
    backend.disposeIntermediateTensorInfo(info);
  });
  return combined;
}

// Registers zerosLike$1 as the 'cpu' implementation of ZerosLike.
var zerosLikeConfig$1 = {
  kernelName: ZerosLike,
  backendName: 'cpu',
  kernelFunc: zerosLike$1
};
107864
107865 /**
107866 * @license
107867 * Copyright 2020 Google LLC. All Rights Reserved.
107868 * Licensed under the Apache License, Version 2.0 (the "License");
107869 * you may not use this file except in compliance with the License.
107870 * You may obtain a copy of the License at
107871 *
107872 * http://www.apache.org/licenses/LICENSE-2.0
107873 *
107874 * Unless required by applicable law or agreed to in writing, software
107875 * distributed under the License is distributed on an "AS IS" BASIS,
107876 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107877 * See the License for the specific language governing permissions and
107878 * limitations under the License.
107879 * =============================================================================
107880 */
/**
 * CPU kernel for OnesLike: returns a new tensor of ones with the same shape
 * and dtype as `inputs.x`.
 *
 * For complex64 inputs the real component becomes ones (recursive call) and
 * the imaginary component becomes zeros (zerosLike$1), recombined with
 * complex$1. The four intermediate tensor infos are then disposed.
 *
 * @param {{inputs: {x: Object}, backend: Object}} args
 * @returns {Object} tensor info of ones shaped like `x`.
 * @throws {Error} if `x` is a string tensor.
 */
function onesLike$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  if (x.dtype === 'string') {
    throw new Error('onesLike is not supported for string tensors');
  }
  if (x.dtype !== 'complex64') {
    return fill$1({
      backend: backend,
      attrs: {shape: x.shape, value: 1, dtype: x.dtype}
    });
  }
  var realIn = real$1({inputs: {input: x}, backend: backend});
  var realOut = onesLike$1({inputs: {x: realIn}, backend: backend});
  var imagIn = imag$1({inputs: {input: x}, backend: backend});
  // A complex "one" is 1 + 0i, so the imaginary part is zeroed.
  var imagOut = zerosLike$1({inputs: {x: imagIn}, backend: backend});
  var combined = complex$1({
    inputs: {real: realOut, imag: imagOut},
    backend: backend
  });
  [realIn, realOut, imagIn, imagOut].forEach(function (info) {
    backend.disposeIntermediateTensorInfo(info);
  });
  return combined;
}

// Registers onesLike$1 as the 'cpu' implementation of OnesLike.
var onesLikeConfig$1 = {
  kernelName: OnesLike,
  backendName: 'cpu',
  kernelFunc: onesLike$1
};
107940
107941 /**
107942 * @license
107943 * Copyright 2020 Google LLC. All Rights Reserved.
107944 * Licensed under the Apache License, Version 2.0 (the "License");
107945 * you may not use this file except in compliance with the License.
107946 * You may obtain a copy of the License at
107947 *
107948 * http://www.apache.org/licenses/LICENSE-2.0
107949 *
107950 * Unless required by applicable law or agreed to in writing, software
107951 * distributed under the License is distributed on an "AS IS" BASIS,
107952 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107953 * See the License for the specific language governing permissions and
107954 * limitations under the License.
107955 * =============================================================================
107956 */
/**
 * CPU kernel for Pack (tf.stack): stacks `inputs` along a new axis
 * `attrs.axis`.
 *
 * A single input is just expanded at `axis`. Otherwise every input must match
 * the first input's shape and dtype. Each is expanded at `axis`, the results
 * are concatenated along `axis`, and the expanded intermediates are disposed.
 *
 * @param {{inputs: Array<Object>, backend: Object, attrs: {axis: number}}} args
 * @returns {Object} tensor info of the stacked result.
 */
function pack$1(args) {
  var inputs = args.inputs;
  var backend = args.backend;
  var axis = args.attrs.axis;
  var expandAtAxis = function (tensorInfo) {
    return expandDims$1({
      inputs: {input: tensorInfo},
      backend: backend,
      attrs: {dim: axis}
    });
  };
  if (inputs.length === 1) {
    return expandAtAxis(inputs[0]);
  }
  var reference = inputs[0];
  for (var k = 0; k < inputs.length; k++) {
    var candidate = inputs[k];
    assertShapesMatch(reference.shape, candidate.shape, 'All tensors passed to stack must have matching shapes');
    assert$1(reference.dtype === candidate.dtype, function () {
      return 'All tensors passed to stack must have matching dtypes';
    });
  }
  var expanded = inputs.map(function (t) {
    return expandAtAxis(t);
  });
  var stacked = concat$1({
    inputs: expanded,
    backend: backend,
    attrs: {axis: axis}
  });
  expanded.forEach(function (t) {
    backend.disposeIntermediateTensorInfo(t);
  });
  return stacked;
}

// Registers pack$1 as the 'cpu' implementation of Pack.
var packConfig$1 = {
  kernelName: Pack,
  backendName: 'cpu',
  kernelFunc: pack$1
};
108012
108013 /**
108014 * @license
108015 * Copyright 2020 Google LLC. All Rights Reserved.
108016 * Licensed under the Apache License, Version 2.0 (the "License");
108017 * you may not use this file except in compliance with the License.
108018 * You may obtain a copy of the License at
108019 *
108020 * http://www.apache.org/licenses/LICENSE-2.0
108021 *
108022 * Unless required by applicable law or agreed to in writing, software
108023 * distributed under the License is distributed on an "AS IS" BASIS,
108024 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108025 * See the License for the specific language governing permissions and
108026 * limitations under the License.
108027 * =============================================================================
108028 */
/**
 * CPU kernel for PadV2: pads `inputs.x` with `attrs.constantValue`.
 *
 * `attrs.paddings[i]` is a `[before, after]` pair for axis i. The output
 * extent on that axis is `before + x.shape[i] + after`. Every input element is
 * copied to its coordinates shifted by the `before` amounts.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {paddings: Array<Array<number>>, constantValue: number}}} args
 * @returns {{dataId: Object, shape: Array<number>, dtype: string}} the padded
 *     tensor, written through `backend.write`.
 */
function padV2$1(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var paddings = args.attrs.paddings;
  var constantValue = args.attrs.constantValue;
  assertNotComplex$1(x, 'pad');
  var before = [];
  var outShape = [];
  paddings.forEach(function (pad, axis) {
    before.push(pad[0]);
    outShape.push(pad[0] + x.shape[axis] + pad[1]);
  });
  var inVals = backend.data.get(x.dataId).values;
  var inRank = x.shape.length;
  var inStrides = computeStrides(x.shape);
  var outRank = outShape.length;
  var outStrides = computeStrides(outShape);
  var outVals = getTypedArrayFromDType(x.dtype, sizeFromShape(outShape));
  // Skipping the fill for 0 assumes getTypedArrayFromDType returns a
  // zero-initialized buffer (true for typed arrays).
  if (constantValue !== 0) {
    outVals.fill(constantValue);
  }
  var count = sizeFromShape(x.shape);
  for (var flat = 0; flat < count; flat++) {
    var shifted = indexToLoc(flat, inRank, inStrides).map(function (coord, axis) {
      return coord + before[axis];
    });
    outVals[locToIndex(shifted, outRank, outStrides)] = inVals[flat];
  }
  return {
    dataId: backend.write(outVals, outShape, x.dtype),
    shape: outShape,
    dtype: x.dtype
  };
}

// Registers padV2$1 as the 'cpu' implementation of PadV2.
var padV2Config$1 = {
  kernelName: PadV2,
  backendName: 'cpu',
  kernelFunc: padV2$1
};
108074
108075 /**
108076 * @license
108077 * Copyright 2020 Google LLC. All Rights Reserved.
108078 * Licensed under the Apache License, Version 2.0 (the "License");
108079 * you may not use this file except in compliance with the License.
108080 * You may obtain a copy of the License at
108081 *
108082 * http://www.apache.org/licenses/LICENSE-2.0
108083 *
108084 * Unless required by applicable law or agreed to in writing, software
108085 * distributed under the License is distributed on an "AS IS" BASIS,
108086 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108087 * See the License for the specific language governing permissions and
108088 * limitations under the License.
108089 * =============================================================================
108090 */
// Element-wise power: raises each `base` element to the paired `exponent`.
var powImpl = createSimpleBinaryKernelImpl(function (base, exponent) {
  return Math.pow(base, exponent);
});

// CPU kernel function for Pow, built from powImpl by binaryKernelFunc$1.
var pow$1 = binaryKernelFunc$1(Pow, powImpl);

// Registers pow$1 as the 'cpu' implementation of Pow.
var powConfig$1 = {
  kernelName: Pow,
  backendName: 'cpu',
  kernelFunc: pow$1
};
108100
/**
 * CPU kernel for RaggedGather.
 *
 * Pulls the raw values of the nested row-splits, the dense values and the
 * indices from the backend and delegates to raggedGatherImpl. That call yields
 * `[outputNestedSplits, outputDenseValues, outputDenseValuesShape]`.
 *
 * @param {{inputs: {paramsNestedSplits: Array<Object>,
 *     paramsDenseValues: Object, indices: Object}, backend: Object,
 *     attrs: {outputRaggedRank: number}}} args
 * @returns {Array<Object>} one 1-D int32 tensor info per output splits array,
 *     followed by the dense-values tensor info (dtype of paramsDenseValues).
 */
function raggedGather$1(args) {
  var backend = args.backend;
  var inputs = args.inputs;
  var outputRaggedRank = args.attrs.outputRaggedRank;
  var splitTensors = inputs.paramsNestedSplits;
  var denseTensor = inputs.paramsDenseValues;
  var indexTensor = inputs.indices;
  var valuesOf = function (tensorInfo) {
    return backend.data.get(tensorInfo.dataId).values;
  };
  var splitValues = splitTensors.map(function (t) {
    return valuesOf(t);
  });
  var splitShapes = splitTensors.map(function (t) {
    return t.shape;
  });
  var denseValues = valuesOf(denseTensor);
  var indexValues = valuesOf(indexTensor);
  var gathered = raggedGatherImpl(
      splitValues, splitShapes, denseValues, denseTensor.shape,
      denseTensor.dtype, indexValues, indexTensor.shape, outputRaggedRank);
  var outSplits = gathered[0];
  var outDense = gathered[1];
  var outDenseShape = gathered[2];
  var results = outSplits.map(function (splits) {
    return backend.makeTensorInfo([splits.length], 'int32', splits);
  });
  results.push(backend.makeTensorInfo(outDenseShape, denseTensor.dtype, outDense));
  return results;
}

// Registers raggedGather$1 as the 'cpu' implementation of RaggedGather.
var raggedGatherConfig$1 = {
  kernelName: RaggedGather,
  backendName: 'cpu',
  kernelFunc: raggedGather$1
};
108133
/**
 * CPU kernel for RaggedRange.
 *
 * Reads the starts / limits / deltas values and delegates to raggedRangeImpl.
 * That call yields `[nestedSplits, denseValues]`.
 *
 * @param {{inputs: {starts: Object, limits: Object, deltas: Object},
 *     backend: Object}} args
 * @returns {Array<Object>} `[rtNestedSplits (1-D int32),
 *     rtDenseValues (1-D, dtype of starts)]`.
 */
function raggedRange$1(args) {
  var backend = args.backend;
  var inputs = args.inputs;
  var starts = inputs.starts;
  var limits = inputs.limits;
  var deltas = inputs.deltas;
  var valuesOf = function (tensorInfo) {
    return backend.data.get(tensorInfo.dataId).values;
  };
  var startValues = valuesOf(starts);
  var limitValues = valuesOf(limits);
  var deltaValues = valuesOf(deltas);
  var ranged = raggedRangeImpl(
      startValues, starts.shape, starts.dtype, limitValues, limits.shape,
      deltaValues, deltas.shape);
  var splitsData = ranged[0];
  var denseData = ranged[1];
  var splitsInfo = backend.makeTensorInfo([splitsData.length], 'int32', splitsData);
  var denseInfo = backend.makeTensorInfo([denseData.length], starts.dtype, denseData);
  return [splitsInfo, denseInfo];
}

// Registers raggedRange$1 as the 'cpu' implementation of RaggedRange.
var raggedRangeConfig$1 = {
  kernelName: RaggedRange,
  backendName: 'cpu',
  kernelFunc: raggedRange$1
};
108156
/**
 * CPU kernel for RaggedTensorToTensor.
 *
 * Reads the target shape, the flat values, the default value and every row
 * partition tensor from the backend and delegates to raggedTensorToTensorImpl.
 * That call yields `[outputShape, outputValues]`.
 *
 * @param {{inputs: {shape: Object, values: Object, defaultValue: Object,
 *     rowPartitionTensors: Array<Object>}, backend: Object,
 *     attrs: {rowPartitionTypes: Array<string>}}} args
 * @returns {Object} dense tensor info with the dtype of `inputs.values`.
 */
function raggedTensorToTensor$1(args) {
  var backend = args.backend;
  var inputs = args.inputs;
  var rowPartitionTypes = args.attrs.rowPartitionTypes;
  var shapeTensor = inputs.shape;
  var valuesTensor = inputs.values;
  var defaultTensor = inputs.defaultValue;
  var partitions = inputs.rowPartitionTensors;
  var valuesOf = function (tensorInfo) {
    return backend.data.get(tensorInfo.dataId).values;
  };
  var shapeValues = valuesOf(shapeTensor);
  var flatValues = valuesOf(valuesTensor);
  var defaultValues = valuesOf(defaultTensor);
  var partitionValues = partitions.map(function (t) {
    return valuesOf(t);
  });
  var partitionShapes = partitions.map(function (t) {
    return t.shape;
  });
  var converted = raggedTensorToTensorImpl(
      shapeValues, shapeTensor.shape, flatValues, valuesTensor.shape,
      valuesTensor.dtype, defaultValues, defaultTensor.shape, partitionValues,
      partitionShapes, rowPartitionTypes);
  return backend.makeTensorInfo(converted[0], valuesTensor.dtype, converted[1]);
}

// Registers raggedTensorToTensor$1 as the 'cpu' implementation of RaggedTensorToTensor.
var raggedTensorToTensorConfig$1 = {
  kernelName: RaggedTensorToTensor,
  backendName: 'cpu',
  kernelFunc: raggedTensorToTensor$1
};
108186
108187 /**
108188 * @license
108189 * Copyright 2020 Google LLC. All Rights Reserved.
108190 * Licensed under the Apache License, Version 2.0 (the "License");
108191 * you may not use this file except in compliance with the License.
108192 * You may obtain a copy of the License at
108193 *
108194 * http://www.apache.org/licenses/LICENSE-2.0
108195 *
108196 * Unless required by applicable law or agreed to in writing, software
108197 * distributed under the License is distributed on an "AS IS" BASIS,
108198 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108199 * See the License for the specific language governing permissions and
108200 * limitations under the License.
108201 * =============================================================================
108202 */
/**
 * CPU kernel for Range: materializes the sequence produced by rangeImpl for
 * `attrs.start`, `attrs.stop`, `attrs.step` and `attrs.dtype` as a 1-D tensor.
 *
 * @param {{backend: Object, attrs: {start: number, stop: number,
 *     step: number, dtype: string}}} args
 * @returns {Object} 1-D tensor info whose length is the number of generated
 *     values.
 */
function range$1(args) {
  var attrs = args.attrs;
  var dtype = attrs.dtype;
  var sequence = rangeImpl(attrs.start, attrs.stop, attrs.step, dtype);
  return args.backend.makeTensorInfo([sequence.length], dtype, sequence);
}

// Registers range$1 as the 'cpu' implementation of Range.
var rangeConfig$1 = {
  kernelName: Range,
  backendName: 'cpu',
  kernelFunc: range$1
};
108218
108219 /**
108220 * @license
108221 * Copyright 2020 Google LLC. All Rights Reserved.
108222 * Licensed under the Apache License, Version 2.0 (the License);
108223 * you may not use this file except in compliance with the License.
108224 * You may obtain a copy of the License at
108225 *
108226 * http://www.apache.org/licenses/LICENSE-2.0
108227 *
108228 * Unless required by applicable law or agreed to in writing, software
108229 * distributed under the License is distributed on an AS IS BASIS,
108230 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108231 * See the License for the specific language governing permissions and
108232 * limitations under the License.
108233 * =============================================================================
108234 */
// CPU kernel function for Reciprocal: maps each element v to 1 / v
// (plain JS division, so 0 yields Infinity).
var reciprocal$1 = unaryKernelFunc$1(Reciprocal, function (value) {
  return 1 / value;
});

// Registers reciprocal$1 as the 'cpu' implementation of Reciprocal.
var reciprocalConfig$1 = {
  kernelName: Reciprocal,
  backendName: 'cpu',
  kernelFunc: reciprocal$1
};
108243
/**
 * CPU kernel for ResizeBilinear.
 *
 * `images` is read as `[batch, height, width, channels]` and resized to
 * `attrs.size = [newHeight, newWidth]` by bilinear interpolation between the
 * four source pixels surrounding each output position.
 *
 * - alignCorners: when the output axis has more than one pixel, the scale is
 *   `(in - 1) / (out - 1)`, so corner pixels map onto corner pixels.
 * - halfPixelCenters: sample at `scale * (i + 0.5) - 0.5` instead of
 *   `scale * i`.
 *
 * @param {{inputs: {images: Object}, backend: Object,
 *     attrs: {alignCorners: boolean, halfPixelCenters: boolean,
 *     size: Array<number>}}} args
 * @returns {Object} float32 tensor info of shape
 *     `[batch, newHeight, newWidth, channels]`.
 */
function resizeBilinear$1(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var images = args.inputs.images;
  var alignCorners = attrs.alignCorners;
  var halfPixelCenters = attrs.halfPixelCenters;
  assertNotComplex$1(images, 'resizeBilinear');
  var strides = computeStrides(images.shape);
  var outH = attrs.size[0];
  var outW = attrs.size[1];
  var batch = images.shape[0];
  var inH = images.shape[1];
  var inW = images.shape[2];
  var channels = images.shape[3];
  var src = backend.data.get(images.dataId).values;
  var out = new Float32Array(sizeFromShape([batch, outH, outW, channels]));
  var alignRows = alignCorners && outH > 1;
  var alignCols = alignCorners && outW > 1;
  var rowScale = (alignRows ? inH - 1 : inH) / (alignRows ? outH - 1 : outH);
  var colScale = (alignCols ? inW - 1 : inW) / (alignCols ? outW - 1 : outW);
  // Fractional source coordinate for output index `i` on one axis.
  var sourceCoord = function (scale, i) {
    return halfPixelCenters ? scale * (i + 0.5) - 0.5 : scale * i;
  };
  var write = 0;
  for (var b = 0; b < batch; b++) {
    var batchBase = b * strides[0];
    for (var y = 0; y < outH; y++) {
      var fy = sourceCoord(rowScale, y);
      var y0 = Math.max(0, Math.floor(fy));
      var y1 = Math.min(inH - 1, Math.ceil(fy));
      var wy = fy - y0;
      var topBase = batchBase + y0 * strides[1];
      var bottomBase = batchBase + y1 * strides[1];
      for (var x = 0; x < outW; x++) {
        var fx = sourceCoord(colScale, x);
        var x0 = Math.max(0, Math.floor(fx));
        var x1 = Math.min(inW - 1, Math.ceil(fx));
        var wx = fx - x0;
        var tl = topBase + x0 * strides[2];
        var bl = bottomBase + x0 * strides[2];
        var tr = topBase + x1 * strides[2];
        var br = bottomBase + x1 * strides[2];
        for (var ch = 0; ch < channels; ch++) {
          // Interpolate horizontally on both rows, then vertically.
          var top = src[tl + ch] + (src[tr + ch] - src[tl + ch]) * wx;
          var bottom = src[bl + ch] + (src[br + ch] - src[bl + ch]) * wx;
          out[write++] = top + (bottom - top) * wy;
        }
      }
    }
  }
  return backend.makeTensorInfo([batch, outH, outW, channels], 'float32', out);
}

// Registers resizeBilinear$1 as the 'cpu' implementation of ResizeBilinear.
var resizeBilinearConfig$1 = {
  kernelName: ResizeBilinear,
  backendName: 'cpu',
  kernelFunc: resizeBilinear$1
};
108318
/**
 * CPU kernel for ResizeBilinearGrad: back-propagates `dy` through a bilinear
 * resize onto the grid of the original `images`.
 *
 * Each `dy` element is scattered onto the (up to) four source pixels the
 * forward pass interpolated from, weighted by the same bilinear coefficients.
 * `images` is read as `[batch, height, width, depth]`.
 *
 * NOTE(review): only `attrs.alignCorners` is read here. Any `halfPixelCenters`
 * setting is not applied to the source-coordinate mapping, so confirm this
 * matches the forward configurations that reach this gradient.
 *
 * @param {{inputs: {images: Object, dy: Object}, backend: Object,
 *     attrs: {alignCorners: boolean}}} args
 * @returns {Object} float32 tensor info with the same shape as `images`,
 *     `[batch, height, width, depth]`.
 */
function resizeBilinearGrad$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var images = inputs.images,
    dy = inputs.dy;
  var alignCorners = attrs.alignCorners;
  assertNotComplex$1([dy, images], 'resizeBilinearGrad');
  var imagesStrides = computeStrides(images.shape);
  var _images$shape = _slicedToArray(images.shape, 4),
    batch = _images$shape[0],
    xHeight = _images$shape[1],
    xWidth = _images$shape[2],
    depth = _images$shape[3];
  var _dy$shape = _slicedToArray(dy.shape, 3),
    yHeight = _dy$shape[1],
    yWidth = _dy$shape[2];
  var output = new Float32Array(batch * xHeight * xWidth * depth);
  // In the backwards pass, we want to find the pixels that were generated
  // for each pixel in the input image the forward pass and add the
  // corresponding coefficient from dy to the gradient (with some
  // interpolation).
  var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
  var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
  var heightScale = effectiveXSize[0] / effectiveYSize[0];
  var widthScale = effectiveXSize[1] / effectiveYSize[1];
  // Reference implementation
  // tslint:disable-next-line:max-line-length
  // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275
  var dyValues = backend.data.get(dy.dataId).values;
  // dy is traversed densely in row-major order, so a running offset suffices.
  var offset = 0;
  for (var b = 0; b < batch; b++) {
    var bOffset = b * imagesStrides[0];
    for (var r = 0; r < yHeight; r++) {
      var dxR = r * heightScale;
      var topDxRIndex = Math.floor(dxR);
      var bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
      var topDxROffset = bOffset + topDxRIndex * imagesStrides[1];
      var bottomDxROffset = bOffset + bottomDxRIndex * imagesStrides[1];
      var dxRLerp = dxR - topDxRIndex;
      var inverseDxRLerp = 1.0 - dxRLerp;
      for (var c = 0; c < yWidth; c++) {
        var dxC = c * widthScale;
        var leftDxCIndex = Math.floor(dxC);
        var rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
        var dxCLerp = dxC - leftDxCIndex;
        var inverseDxCLerp = 1.0 - dxCLerp;
        var topLeftRCOffset = topDxROffset + leftDxCIndex * imagesStrides[2];
        var topRightRCOffset = topDxROffset + rightDxCIndex * imagesStrides[2];
        var bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * imagesStrides[2];
        var bottomRightRCOffset = bottomDxROffset + rightDxCIndex * imagesStrides[2];
        // Bilinear weights of the four corners; they sum to 1.
        var inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
        var inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
        var dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
        var dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
        for (var d = 0; d < depth; d++) {
          var dyVal = dyValues[offset++];
          output[topLeftRCOffset + d] += dyVal * inverseDxRLerpTimesInverseDxCLerp;
          output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
          output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
          output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
        }
      }
    }
  }
  // The gradient has the layout of `images`: height before width.
  return backend.makeTensorInfo([batch, xHeight, xWidth, depth], 'float32', output);
}

// Registers resizeBilinearGrad$1 as the 'cpu' implementation of ResizeBilinearGrad.
var resizeBilinearGradConfig$1 = {
  kernelName: ResizeBilinearGrad,
  backendName: 'cpu',
  kernelFunc: resizeBilinearGrad$1
};
108391
/**
 * CPU kernel for ResizeNearestNeighbor.
 *
 * `images` is read as `[batch, height, width, channels]` and resized to
 * `attrs.size = [newHeight, newWidth]`. Each output pixel copies the source
 * pixel at the scaled coordinate:
 *
 * - alignCorners: scale is `(in - 1) / (out - 1)` when the output axis has
 *   more than one pixel, and the coordinate is rounded instead of floored.
 * - halfPixelCenters: sample at `scale * (i + 0.5)` and clamp below at 0.
 *
 * Indices are always clamped above at `in - 1`.
 *
 * @param {{inputs: {images: Object}, backend: Object,
 *     attrs: {alignCorners: boolean, halfPixelCenters: boolean,
 *     size: Array<number>}}} args
 * @returns {Object} tensor info of shape `[batch, newHeight, newWidth,
 *     channels]` tagged with `images.dtype` (values held in a Float32Array).
 */
function resizeNearestNeighbor$1(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var images = args.inputs.images;
  var alignCorners = attrs.alignCorners;
  var halfPixelCenters = attrs.halfPixelCenters;
  assertNotComplex$1(images, 'resizeNearestNeighbor');
  var strides = computeStrides(images.shape);
  var outH = attrs.size[0];
  var outW = attrs.size[1];
  var batch = images.shape[0];
  var inH = images.shape[1];
  var inW = images.shape[2];
  var channels = images.shape[3];
  var src = backend.data.get(images.dataId).values;
  var out = new Float32Array(batch * outH * outW * channels);
  var alignRows = alignCorners && outH > 1;
  var alignCols = alignCorners && outW > 1;
  var rowScale = (alignRows ? inH - 1 : inH) / (alignRows ? outH - 1 : outH);
  var colScale = (alignCols ? inW - 1 : inW) / (alignCols ? outW - 1 : outW);
  // Nearest source index on one axis for output index `i`.
  var nearest = function (scale, i, inExtent) {
    var frac = halfPixelCenters ? scale * (i + 0.5) : scale * i;
    var picked = Math.min(inExtent - 1, alignCorners ? Math.round(frac) : Math.floor(frac));
    return halfPixelCenters ? Math.max(0, picked) : picked;
  };
  var write = 0;
  for (var b = 0; b < batch; b++) {
    var batchBase = b * strides[0];
    for (var y = 0; y < outH; y++) {
      var rowBase = batchBase + nearest(rowScale, y, inH) * strides[1];
      for (var x = 0; x < outW; x++) {
        var pixelBase = rowBase + nearest(colScale, x, inW) * strides[2];
        for (var ch = 0; ch < channels; ch++) {
          out[write++] = src[pixelBase + ch];
        }
      }
    }
  }
  return backend.makeTensorInfo([batch, outH, outW, channels], images.dtype, out);
}

// Registers resizeNearestNeighbor$1 as the 'cpu' implementation of ResizeNearestNeighbor.
var resizeNearestNeighborConfig$1 = {
  kernelName: ResizeNearestNeighbor,
  backendName: 'cpu',
  kernelFunc: resizeNearestNeighbor$1
};
108449
/**
 * CPU kernel for ResizeNearestNeighborGrad: gradient of a nearest-neighbor
 * resize with respect to its input `images`.
 *
 * For every input pixel (b, r, c, d) it sums the `dy` values whose forward
 * nearest-neighbor source was (r, c). Instead of scanning all of `dy`, it
 * searches a window of `dy` rows/columns around the scaled position.
 * `images` is read as `[batch, height, width, depth]`.
 *
 * NOTE(review): only `attrs.alignCorners` is read. The source-index mapping
 * below does not apply half-pixel centers, so confirm this matches the
 * forward kernels that route here.
 *
 * @param {{inputs: {images: Object, dy: Object}, backend: Object,
 *     attrs: {alignCorners: boolean}}} args
 * @returns {Object} tensor info with `images.shape` and `images.dtype`.
 */
function resizeNearestNeighborGrad$1(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var images = inputs.images,
    dy = inputs.dy;
  var alignCorners = attrs.alignCorners;
  assertNotComplex$1([dy, images], 'resizeNearestNeighborGrad');
  var imagesStrides = computeStrides(images.shape);
  var dyStrides = computeStrides(dy.shape);
  var _images$shape = _slicedToArray(images.shape, 4),
    batch = _images$shape[0],
    xHeight = _images$shape[1],
    xWidth = _images$shape[2],
    depth = _images$shape[3];
  var _dy$shape = _slicedToArray(dy.shape, 3),
    yHeight = _dy$shape[1],
    yWidth = _dy$shape[2];
  var output = new Float32Array(batch * xHeight * xWidth * depth);
  var dyValues = backend.data.get(dy.dataId).values;
  // In the backwards pass, we want to find the pixels that were generated
  // for each pixel in the input image the forward pass
  var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
  var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
  var heightScale = effectiveXSize[0] / effectiveYSize[0];
  var widthScale = effectiveXSize[1] / effectiveYSize[1];
  var invHeightScale = 1 / heightScale;
  var invWidthScale = 1 / widthScale;
  // This defines the size of the window of values around a particular
  // index in dy that we want to search for contributions to dx.
  var winHeight = Math.ceil(invHeightScale) * 2 + 2;
  var winWidth = Math.ceil(invWidthScale) * 2 + 2;
  // Loop over the output space.
  for (var b = 0; b < batch; b++) {
    var batchOffset = b * imagesStrides[0];
    // dy generally has a different spatial size from images, so its batch
    // offset must come from dy's own strides, not the image strides.
    var dyBatchOffset = b * dyStrides[0];
    for (var r = 0; r < xHeight; r++) {
      var rowOffset = batchOffset + r * imagesStrides[1];
      // Compute bounds for where in dy we will look
      var startRLerp = Math.floor(r * invHeightScale);
      var startDyR = Math.floor(startRLerp - winHeight / 2);
      for (var c = 0; c < xWidth; c++) {
        var colOffset = rowOffset + c * imagesStrides[2];
        // Compute bounds for where in dy we will look
        var startCLerp = Math.floor(c * invWidthScale);
        var startDyC = Math.floor(startCLerp - winWidth / 2);
        for (var d = 0; d < depth; d++) {
          var accum = 0;
          // loop over dy
          for (var dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
            var dyR = dyRIndex + startDyR;
            // Guard against the window exceeding the bounds of dy
            if (dyR < 0 || dyR >= yHeight) {
              continue;
            }
            var dyROffset = dyBatchOffset + dyR * dyStrides[1];
            var sourceFracRow = dyR * heightScale;
            var sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
            // Only dy rows whose forward source row is r contribute.
            if (r !== sourceNearestRow) {
              continue;
            }
            for (var dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
              var dyC = dyCIndex + startDyC;
              // Guard against the window exceeding the bounds of dy
              if (dyC < 0 || dyC >= yWidth) {
                continue;
              }
              var dyCOffset = dyROffset + dyC * dyStrides[2];
              var sourceFracCol = dyC * widthScale;
              var sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) : Math.floor(sourceFracCol));
              if (c === sourceNearestCol) {
                accum += dyValues[dyCOffset + d];
              }
            }
          }
          output[colOffset + d] = accum;
        }
      }
    }
  }
  return backend.makeTensorInfo(images.shape, images.dtype, output);
}

// Registers resizeNearestNeighborGrad$1 as the 'cpu' implementation of ResizeNearestNeighborGrad.
var resizeNearestNeighborGradConfig$1 = {
  kernelName: ResizeNearestNeighborGrad,
  backendName: 'cpu',
  kernelFunc: resizeNearestNeighborGrad$1
};
108536
108537 function reverse$1(args) {
108538 var inputs = args.inputs,
108539 backend = args.backend,
108540 attrs = args.attrs;
108541 var x = inputs.x;
108542 var dims = attrs.dims;
108543 assertNotComplex$1(x, 'reverse');
108544 var xRank = x.shape.length;
108545 var $dims = parseAxisParam(dims, x.shape);
108546 if (xRank === 0) {
108547 return identity$1({
108548 inputs: {
108549 x: x
108550 },
108551 backend: backend
108552 });
108553 }
108554 var outBuf = new TensorBuffer(x.shape, x.dtype);
108555 var xBuf = backend.bufferSync(x);
108556 var _loop = function _loop() {
108557 var outLoc = outBuf.indexToLoc(i);
108558 var inLoc = outLoc.slice();
108559 $dims.forEach(function (d) {
108560 return inLoc[d] = x.shape[d] - 1 - inLoc[d];
108561 });
108562 outBuf.set.apply(outBuf, [xBuf.get.apply(xBuf, _toConsumableArray(inLoc))].concat(_toConsumableArray(outLoc)));
108563 };
108564 for (var i = 0; i < outBuf.size; i++) {
108565 _loop();
108566 }
108567 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
108568 }
108569 var reverseConfig$1 = {
108570 kernelName: Reverse,
108571 backendName: 'cpu',
108572 kernelFunc: reverse$1
108573 };
108574
108575 var rotateWithOffsetConfig$1 = {
108576 kernelName: RotateWithOffset,
108577 backendName: 'cpu',
108578 kernelFunc: function kernelFunc(_ref) {
108579 var inputs = _ref.inputs,
108580 attrs = _ref.attrs,
108581 backend = _ref.backend;
108582 var image = inputs.image;
108583 var radians = attrs.radians,
108584 fillValue = attrs.fillValue,
108585 center = attrs.center;
108586 var cpuBackend = backend;
108587 var output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
108588 var _image$shape = _slicedToArray(image.shape, 4),
108589 batch = _image$shape[0],
108590 imageHeight = _image$shape[1],
108591 imageWidth = _image$shape[2],
108592 numChannels = _image$shape[3];
108593 var _backend_util$getImag = getImageCenter(center, imageHeight, imageWidth),
108594 _backend_util$getImag2 = _slicedToArray(_backend_util$getImag, 2),
108595 centerX = _backend_util$getImag2[0],
108596 centerY = _backend_util$getImag2[1];
108597 var fullOpacityValue = 255;
108598 var sinFactor = Math.sin(radians);
108599 var cosFactor = Math.cos(radians);
108600 var imageVals = cpuBackend.data.get(image.dataId).values;
108601 for (var batchIdx = 0; batchIdx < batch; batchIdx++) {
108602 var batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
108603 for (var row = 0; row < imageHeight; row++) {
108604 var rowOffset = row * (imageWidth * numChannels);
108605 for (var col = 0; col < imageWidth; col++) {
108606 var colOffset = col * numChannels;
108607 for (var channel = 0; channel < numChannels; channel++) {
108608 var coords = [batch, row, col, channel];
108609 var x = coords[2];
108610 var y = coords[1];
108611 // coordX/coordY are the result of rotating and translating x/y.
108612 var coordX = (x - centerX) * cosFactor - (y - centerY) * sinFactor;
108613 var coordY = (x - centerX) * sinFactor + (y - centerY) * cosFactor;
108614 coordX = Math.round(coordX + centerX);
108615 coordY = Math.round(coordY + centerY);
108616 var outputValue = fillValue;
108617 if (typeof fillValue !== 'number') {
108618 if (channel === 3) {
108619 outputValue = fullOpacityValue;
108620 } else {
108621 outputValue = fillValue[channel];
108622 }
108623 }
108624 // If the coordinate position falls within the image boundaries...
108625 if (coordX >= 0 && coordX < imageWidth && coordY >= 0 && coordY < imageHeight) {
108626 // set the output to the image value at the coordinate position.
108627 var rotatedRowOffset = coordY * (imageWidth * numChannels);
108628 var rotatedColOffset = coordX * numChannels;
108629 var imageIdx = batchOffset + rotatedRowOffset + rotatedColOffset + channel;
108630 outputValue = imageVals[imageIdx];
108631 }
108632 var outIdx = batchOffset + rowOffset + colOffset + channel;
108633 output[outIdx] = outputValue;
108634 }
108635 }
108636 }
108637 }
108638 var dataId = cpuBackend.write(output, image.shape, image.dtype);
108639 return {
108640 dataId: dataId,
108641 shape: image.shape,
108642 dtype: image.dtype
108643 };
108644 }
108645 };
108646
108647 /**
108648 * @license
108649 * Copyright 2020 Google LLC. All Rights Reserved.
108650 * Licensed under the Apache License, Version 2.0 (the License);
108651 * you may not use this file except in compliance with the License.
108652 * You may obtain a copy of the License at
108653 *
108654 * http://www.apache.org/licenses/LICENSE-2.0
108655 *
108656 * Unless required by applicable law or agreed to in writing, software
108657 * distributed under the License is distributed on an AS IS BASIS,
108658 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108659 * See the License for the specific language governing permissions and
108660 * limitations under the License.
108661 * =============================================================================
108662 */
108663 var round$1 = unaryKernelFunc$1(Round, function (xi) {
108664 // The algorithm is based on banker's rounding.
108665 var base = Math.floor(xi);
108666 if (xi - base < 0.5) {
108667 return Math.floor(xi);
108668 } else if (xi - base > 0.5) {
108669 return Math.ceil(xi);
108670 } else {
108671 if (base % 2.0 === 0.0) {
108672 return base;
108673 } else {
108674 return base + 1.0;
108675 }
108676 }
108677 });
108678 var roundConfig$1 = {
108679 kernelName: Round,
108680 backendName: 'cpu',
108681 kernelFunc: round$1
108682 };
108683
108684 /**
108685 * @license
108686 * Copyright 2020 Google LLC. All Rights Reserved.
108687 * Licensed under the Apache License, Version 2.0 (the "License");
108688 * you may not use this file except in compliance with the License.
108689 * You may obtain a copy of the License at
108690 *
108691 * http://www.apache.org/licenses/LICENSE-2.0
108692 *
108693 * Unless required by applicable law or agreed to in writing, software
108694 * distributed under the License is distributed on an "AS IS" BASIS,
108695 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108696 * See the License for the specific language governing permissions and
108697 * limitations under the License.
108698 * =============================================================================
108699 */
108700 function scatterNd$1(args) {
108701 var inputs = args.inputs,
108702 backend = args.backend,
108703 attrs = args.attrs;
108704 var indices = inputs.indices,
108705 updates = inputs.updates;
108706 var shape = attrs.shape;
108707 var _backend_util$calcula = calculateShapes(updates, indices, shape),
108708 sliceRank = _backend_util$calcula.sliceRank,
108709 numUpdates = _backend_util$calcula.numUpdates,
108710 sliceSize = _backend_util$calcula.sliceSize,
108711 strides = _backend_util$calcula.strides,
108712 outputSize = _backend_util$calcula.outputSize;
108713 var sumDupeIndices = true;
108714 var indicesBuf = backend.bufferSync(indices);
108715 var updatesBuf = backend.bufferSync(updates);
108716 var outBuf = scatterImpl(indicesBuf, updatesBuf, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, 0 /* defaultValue */, sumDupeIndices);
108717 return backend.makeTensorInfo(shape, outBuf.dtype, outBuf.values);
108718 }
108719 var scatterNdConfig$1 = {
108720 kernelName: ScatterNd,
108721 backendName: 'cpu',
108722 kernelFunc: scatterNd$1
108723 };
108724
108725 /**
108726 * @license
108727 * Copyright 2022 Google LLC. All Rights Reserved.
108728 * Licensed under the Apache License, Version 2.0 (the "License");
108729 * you may not use this file except in compliance with the License.
108730 * You may obtain a copy of the License at
108731 *
108732 * http://www.apache.org/licenses/LICENSE-2.0
108733 *
108734 * Unless required by applicable law or agreed to in writing, software
108735 * distributed under the License is distributed on an "AS IS" BASIS,
108736 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108737 * See the License for the specific language governing permissions and
108738 * limitations under the License.
108739 * =============================================================================
108740 */
108741 function lowerBound(array, value) {
108742 var left = 0;
108743 var right = array.length;
108744 var mid = 0;
108745 while (left < right) {
108746 mid = Math.floor((left + right) / 2);
108747 if (array[mid] < value) {
108748 left = mid + 1;
108749 } else {
108750 right = mid;
108751 }
108752 }
108753 return right;
108754 }
108755 function upperBound(array, value) {
108756 var left = 0;
108757 var right = array.length;
108758 var mid = 0;
108759 while (left < right) {
108760 mid = Math.floor((left + right) / 2);
108761 if (array[mid] <= value) {
108762 left = mid + 1;
108763 } else {
108764 right = mid;
108765 }
108766 }
108767 return right;
108768 }
108769 function searchSortedImpl(sortedInputs, values, batchSize, numInputs, numValues, side) {
108770 var output = getArrayFromDType('int32', batchSize * numValues);
108771 for (var b = 0; b < batchSize; ++b) {
108772 var sortedInputsSlice = sortedInputs.slice(b * numInputs, (b + 1) * numInputs);
108773 var outputOffset = b * numValues;
108774 for (var i = 0; i < numValues; ++i) {
108775 output[outputOffset + i] = side === 'left' ? lowerBound(sortedInputsSlice, values[i + outputOffset]) : upperBound(sortedInputsSlice, values[i + outputOffset]);
108776 }
108777 }
108778 return output;
108779 }
108780
108781 /**
108782 * @license
108783 * Copyright 2022 Google LLC. All Rights Reserved.
108784 * Licensed under the Apache License, Version 2.0 (the "License");
108785 * you may not use this file except in compliance with the License.
108786 * You may obtain a copy of the License at
108787 *
108788 * http://www.apache.org/licenses/LICENSE-2.0
108789 *
108790 * Unless required by applicable law or agreed to in writing, software
108791 * distributed under the License is distributed on an "AS IS" BASIS,
108792 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108793 * See the License for the specific language governing permissions and
108794 * limitations under the License.
108795 * =============================================================================
108796 */
108797 function searchSorted$1(args) {
108798 var inputs = args.inputs,
108799 backend = args.backend,
108800 attrs = args.attrs;
108801 var sortedSequence = inputs.sortedSequence,
108802 values = inputs.values;
108803 var side = attrs.side;
108804 var $sortedSequence = backend.data.get(sortedSequence.dataId).values;
108805 var $values = backend.data.get(values.dataId).values;
108806 var output = searchSortedImpl($sortedSequence, $values, sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
108807 return backend.makeTensorInfo(values.shape, 'int32', output);
108808 }
108809 var searchSortedConfig$1 = {
108810 kernelName: SearchSorted,
108811 backendName: 'cpu',
108812 kernelFunc: searchSorted$1
108813 };
108814
108815 /**
108816 * @license
108817 * Copyright 2020 Google LLC. All Rights Reserved.
108818 * Licensed under the Apache License, Version 2.0 (the "License");
108819 * you may not use this file except in compliance with the License.
108820 * You may obtain a copy of the License at
108821 *
108822 * http://www.apache.org/licenses/LICENSE-2.0
108823 *
108824 * Unless required by applicable law or agreed to in writing, software
108825 * distributed under the License is distributed on an "AS IS" BASIS,
108826 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108827 * See the License for the specific language governing permissions and
108828 * limitations under the License.
108829 * =============================================================================
108830 */
108831 function select$1(args) {
108832 var inputs = args.inputs,
108833 backend = args.backend;
108834 var condition = inputs.condition,
108835 t = inputs.t,
108836 e = inputs.e;
108837 assertNotComplex$1([condition, t, e], 'select');
108838 var conditionRank = condition.shape.length;
108839 var values = backend.data.get(condition.dataId).values;
108840 var tValues = backend.data.get(t.dataId).values;
108841 var eValues = backend.data.get(e.dataId).values;
108842 var resultDtype = upcastType(t.dtype, e.dtype);
108843 var newValues = makeZerosTypedArray(sizeFromShape(t.shape), resultDtype);
108844 var index = 0;
108845 var offset = conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ? 1 : sizeFromShape(t.shape.slice(1));
108846 for (var i = 0; i < values.length; i++) {
108847 for (var j = 0; j < offset; j++) {
108848 if (values[i] === 1) {
108849 newValues[index++] = tValues[i];
108850 } else {
108851 newValues[index++] = eValues[i];
108852 }
108853 }
108854 }
108855 return backend.makeTensorInfo(t.shape, resultDtype, newValues);
108856 }
108857 var selectConfig$1 = {
108858 kernelName: Select,
108859 backendName: 'cpu',
108860 kernelFunc: select$1
108861 };
108862
108863 /**
108864 * @license
108865 * Copyright 2020 Google LLC. All Rights Reserved.
108866 * Licensed under the Apache License, Version 2.0 (the License);
108867 * you may not use this file except in compliance with the License.
108868 * You may obtain a copy of the License at
108869 *
108870 * http://www.apache.org/licenses/LICENSE-2.0
108871 *
108872 * Unless required by applicable law or agreed to in writing, software
108873 * distributed under the License is distributed on an AS IS BASIS,
108874 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108875 * See the License for the specific language governing permissions and
108876 * limitations under the License.
108877 * =============================================================================
108878 */
108879 var scaleAlpha = SELU_SCALEALPHA;
108880 var scale = SELU_SCALE;
108881 var selu$1 = unaryKernelFunc$1(Selu$1, function (xi) {
108882 if (xi >= 0) {
108883 return scale * xi;
108884 } else {
108885 return scaleAlpha * (Math.exp(xi) - 1);
108886 }
108887 });
108888 var seluConfig$1 = {
108889 kernelName: Selu$1,
108890 backendName: 'cpu',
108891 kernelFunc: selu$1
108892 };
108893
108894 /**
108895 * @license
108896 * Copyright 2020 Google LLC. All Rights Reserved.
108897 * Licensed under the Apache License, Version 2.0 (the License);
108898 * you may not use this file except in compliance with the License.
108899 * You may obtain a copy of the License at
108900 *
108901 * http://www.apache.org/licenses/LICENSE-2.0
108902 *
108903 * Unless required by applicable law or agreed to in writing, software
108904 * distributed under the License is distributed on an AS IS BASIS,
108905 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108906 * See the License for the specific language governing permissions and
108907 * limitations under the License.
108908 * =============================================================================
108909 */
108910 var sign$1 = unaryKernelFunc$1(Sign, function (xi) {
108911 if (xi < 0) {
108912 return -1;
108913 } else if (xi > 0) {
108914 return 1;
108915 } else {
108916 return 0;
108917 }
108918 });
108919 var signConfig$1 = {
108920 kernelName: Sign,
108921 backendName: 'cpu',
108922 kernelFunc: sign$1
108923 };
108924
108925 /**
108926 * @license
108927 * Copyright 2020 Google LLC. All Rights Reserved.
108928 * Licensed under the Apache License, Version 2.0 (the License);
108929 * you may not use this file except in compliance with the License.
108930 * You may obtain a copy of the License at
108931 *
108932 * http://www.apache.org/licenses/LICENSE-2.0
108933 *
108934 * Unless required by applicable law or agreed to in writing, software
108935 * distributed under the License is distributed on an AS IS BASIS,
108936 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108937 * See the License for the specific language governing permissions and
108938 * limitations under the License.
108939 * =============================================================================
108940 */
108941 var sin$1 = unaryKernelFunc$1(Sin, function (xi) {
108942 return Math.sin(xi);
108943 });
108944 var sinConfig$1 = {
108945 kernelName: Sin,
108946 backendName: 'cpu',
108947 kernelFunc: sin$1
108948 };
108949
108950 /**
108951 * @license
108952 * Copyright 2020 Google LLC. All Rights Reserved.
108953 * Licensed under the Apache License, Version 2.0 (the License);
108954 * you may not use this file except in compliance with the License.
108955 * You may obtain a copy of the License at
108956 *
108957 * http://www.apache.org/licenses/LICENSE-2.0
108958 *
108959 * Unless required by applicable law or agreed to in writing, software
108960 * distributed under the License is distributed on an AS IS BASIS,
108961 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108962 * See the License for the specific language governing permissions and
108963 * limitations under the License.
108964 * =============================================================================
108965 */
108966 var sinh$1 = unaryKernelFunc$1(Sinh, function (xi) {
108967 return Math.sinh(xi);
108968 });
108969 var sinhConfig$1 = {
108970 kernelName: Sinh,
108971 backendName: 'cpu',
108972 kernelFunc: sinh$1
108973 };
108974
108975 /**
108976 * @license
108977 * Copyright 2020 Google LLC. All Rights Reserved.
108978 * Licensed under the Apache License, Version 2.0 (the License);
108979 * you may not use this file except in compliance with the License.
108980 * You may obtain a copy of the License at
108981 *
108982 * http://www.apache.org/licenses/LICENSE-2.0
108983 *
108984 * Unless required by applicable law or agreed to in writing, software
108985 * distributed under the License is distributed on an AS IS BASIS,
108986 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
108987 * See the License for the specific language governing permissions and
108988 * limitations under the License.
108989 * =============================================================================
108990 */
108991 // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX
108992 // epsilon is the difference between 1.0 and the next representable float.
108993 // For a single precision 32 bit float this should be 2^-23, see:
108994 // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
108995 var epsilon = 1.1920928955078125e-7;
108996 var threshold = Math.log(epsilon) + 2.0;
108997 var softplus$1 = unaryKernelFunc$1(Softplus$1, function (xi) {
108998 // Value above which exp(x) may overflow, but softplus(x) == x
108999 // is within machine epsilon.
109000 var tooLarge = xi > -threshold;
109001 // Value below which exp(x) may underflow, but softplus(x) == exp(x)
109002 // is within machine epsilon.
109003 var tooSmall = xi < threshold;
109004 var expX = Math.exp(xi);
109005 var result;
109006 if (tooSmall) {
109007 result = expX;
109008 } else if (tooLarge) {
109009 result = xi;
109010 } else {
109011 result = Math.log(1.0 + expX);
109012 }
109013 return result;
109014 });
109015 var softplusConfig$1 = {
109016 kernelName: Softplus$1,
109017 backendName: 'cpu',
109018 kernelFunc: softplus$1
109019 };
109020
109021 function spaceToBatchND$1(args) {
109022 var inputs = args.inputs,
109023 backend = args.backend,
109024 attrs = args.attrs;
109025 var x = inputs.x;
109026 var blockShape = attrs.blockShape,
109027 paddings = attrs.paddings;
109028 assertNotComplex$1([x], 'spaceToBatchND');
109029 var prod = sizeFromShape(blockShape);
109030 var completePaddings = [[0, 0]];
109031 completePaddings.push.apply(completePaddings, _toConsumableArray(paddings));
109032 for (var i = 1 + blockShape.length; i < x.shape.length; ++i) {
109033 completePaddings.push([0, 0]);
109034 }
109035 var paddedX = padV2Config$1.kernelFunc({
109036 inputs: {
109037 x: x
109038 },
109039 backend: backend,
109040 attrs: {
109041 paddings: completePaddings,
109042 constantValue: 0
109043 }
109044 });
109045 var reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
109046 var permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
109047 var flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
109048 var reshapeInputs = {
109049 x: paddedX
109050 };
109051 var reshapeAttrs = {
109052 shape: reshapedPaddedShape
109053 };
109054 var paddedXReshaped = reshape$1({
109055 inputs: reshapeInputs,
109056 backend: backend,
109057 attrs: reshapeAttrs
109058 });
109059 var transposeInputs = {
109060 x: paddedXReshaped
109061 };
109062 var transposeAttrs = {
109063 perm: permutedReshapedPaddedPermutation
109064 };
109065 var paddedXT = transpose$1({
109066 inputs: transposeInputs,
109067 backend: backend,
109068 attrs: transposeAttrs
109069 });
109070 var resultReshapeInputs = {
109071 x: paddedXT
109072 };
109073 var resultReshapeAttrs = {
109074 shape: flattenShape
109075 };
109076 var result = reshape$1({
109077 inputs: resultReshapeInputs,
109078 backend: backend,
109079 attrs: resultReshapeAttrs
109080 });
109081 backend.disposeIntermediateTensorInfo(paddedX);
109082 backend.disposeIntermediateTensorInfo(paddedXReshaped);
109083 backend.disposeIntermediateTensorInfo(paddedXT);
109084 return result;
109085 }
109086 var spaceToBatchNDConfig$1 = {
109087 kernelName: SpaceToBatchND,
109088 backendName: 'cpu',
109089 kernelFunc: spaceToBatchND$1
109090 };
109091
109092 function sparseFillEmptyRows$1(args) {
109093 var inputs = args.inputs,
109094 backend = args.backend;
109095 var indices = inputs.indices,
109096 values = inputs.values,
109097 denseShape = inputs.denseShape,
109098 defaultValue = inputs.defaultValue;
109099 if (denseShape.shape.length !== 1) {
109100 throw new Error("Dense shape must be a vector, saw:\n ".concat(denseShape.shape));
109101 }
109102 if (indices.shape.length !== 2) {
109103 throw new Error("Indices must be a matrix, saw:\n ".concat(indices.shape));
109104 }
109105 if (values.shape.length !== 1) {
109106 throw new Error("Values must be a vector, saw:\n ".concat(values.shape));
109107 }
109108 if (defaultValue.shape.length !== 0) {
109109 throw new Error("Default value must be a scalar, saw:\n ".concat(defaultValue.shape));
109110 }
109111 var $indices = backend.data.get(indices.dataId).values;
109112 var $values = backend.data.get(values.dataId).values;
109113 var $denseShape = backend.data.get(denseShape.dataId).values;
109114 var $defaultValue = backend.data.get(defaultValue.dataId).values[0];
109115 var _sparseFillEmptyRowsI = sparseFillEmptyRowsImpl($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue),
109116 _sparseFillEmptyRowsI2 = _slicedToArray(_sparseFillEmptyRowsI, 5),
109117 outputIndices = _sparseFillEmptyRowsI2[0],
109118 outputIndicesShape = _sparseFillEmptyRowsI2[1],
109119 outputValues = _sparseFillEmptyRowsI2[2],
109120 emptyRowIndicator = _sparseFillEmptyRowsI2[3],
109121 reverseIndexMap = _sparseFillEmptyRowsI2[4];
109122 return [backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices), backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues), backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map(function (value) {
109123 return Number(value);
109124 }))), backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap))];
109125 }
109126 var sparseFillEmptyRowsConfig$1 = {
109127 kernelName: SparseFillEmptyRows,
109128 backendName: 'cpu',
109129 kernelFunc: sparseFillEmptyRows$1
109130 };
109131
109132 function sparseReshape$1(args) {
109133 var inputs = args.inputs,
109134 backend = args.backend;
109135 var inputIndices = inputs.inputIndices,
109136 inputShape = inputs.inputShape,
109137 newShape = inputs.newShape;
109138 if (inputIndices.shape.length !== 2) {
109139 throw new Error("Input indices should be a matrix but received shape\n ".concat(inputIndices.shape));
109140 }
109141 if (inputShape.shape.length !== 1) {
109142 throw new Error("Input shape should be a vector but received shape\n ".concat(inputShape.shape));
109143 }
109144 if (newShape.shape.length !== 1) {
109145 throw new Error("Target shape should be a vector but received shape ".concat(newShape.shape));
109146 }
109147 var $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
109148 var $inputIndices = backend.data.get(inputIndices.dataId).values;
109149 var targetShape = Array.from(backend.data.get(newShape.dataId).values);
109150 var _sparseReshapeImpl = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape),
109151 _sparseReshapeImpl2 = _slicedToArray(_sparseReshapeImpl, 3),
109152 newIndices = _sparseReshapeImpl2[0],
109153 indicesShape = _sparseReshapeImpl2[1],
109154 outputShape = _sparseReshapeImpl2[2];
109155 return [backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices), backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape))];
109156 }
109157 var sparseReshapeConfig$1 = {
109158 kernelName: SparseReshape,
109159 backendName: 'cpu',
109160 kernelFunc: sparseReshape$1
109161 };
109162
109163 function sparseSegmentMean$1(args) {
109164 var inputs = args.inputs,
109165 backend = args.backend;
109166 var data = inputs.data,
109167 indices = inputs.indices,
109168 segmentIds = inputs.segmentIds;
109169 if (data.shape.length < 1) {
109170 throw new Error("Data should be at least 1 dimensional but received scalar");
109171 }
109172 if (indices.shape.length !== 1) {
109173 throw new Error("Indices should be a vector but received shape\n ".concat(indices.shape));
109174 }
109175 if (segmentIds.shape.length !== 1) {
109176 throw new Error("Segment ids should be a vector but received shape\n ".concat(segmentIds.shape));
109177 }
109178 if (indices.shape[0] !== segmentIds.shape[0]) {
109179 throw new Error("segmentIds and indices should have same size.");
109180 }
109181 var $data = backend.data.get(data.dataId).values;
109182 var $indices = backend.data.get(indices.dataId).values;
109183 var $segmentIds = backend.data.get(segmentIds.dataId).values;
109184 var _sparseSegmentReducti = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds, true),
109185 _sparseSegmentReducti2 = _slicedToArray(_sparseSegmentReducti, 2),
109186 outputData = _sparseSegmentReducti2[0],
109187 outputDataShape = _sparseSegmentReducti2[1];
109188 return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
109189 }
109190 var sparseSegmentMeanConfig$1 = {
109191 kernelName: SparseSegmentMean,
109192 backendName: 'cpu',
109193 kernelFunc: sparseSegmentMean$1
109194 };
109195
109196 function sparseSegmentSum$1(args) {
109197 var inputs = args.inputs,
109198 backend = args.backend;
109199 var data = inputs.data,
109200 indices = inputs.indices,
109201 segmentIds = inputs.segmentIds;
109202 if (data.shape.length < 1) {
109203 throw new Error("Data should be at least 1 dimensional but received scalar");
109204 }
109205 if (indices.shape.length !== 1) {
109206 throw new Error("Indices should be a vector but received shape\n ".concat(indices.shape));
109207 }
109208 if (segmentIds.shape.length !== 1) {
109209 throw new Error("Segment ids should be a vector but received shape\n ".concat(segmentIds.shape));
109210 }
109211 if (indices.shape[0] !== segmentIds.shape[0]) {
109212 throw new Error("segmentIds and indices should have same size.");
109213 }
109214 var $data = backend.data.get(data.dataId).values;
109215 var $indices = backend.data.get(indices.dataId).values;
109216 var $segmentIds = backend.data.get(segmentIds.dataId).values;
109217 var _sparseSegmentReducti = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds),
109218 _sparseSegmentReducti2 = _slicedToArray(_sparseSegmentReducti, 2),
109219 outputData = _sparseSegmentReducti2[0],
109220 outputDataShape = _sparseSegmentReducti2[1];
109221 return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
109222 }
109223 var sparseSegmentSumConfig$1 = {
109224 kernelName: SparseSegmentSum,
109225 backendName: 'cpu',
109226 kernelFunc: sparseSegmentSum$1
109227 };
109228
109229 /**
109230 * @license
109231 * Copyright 2020 Google LLC. All Rights Reserved.
109232 * Licensed under the Apache License, Version 2.0 (the "License");
109233 * you may not use this file except in compliance with the License.
109234 * You may obtain a copy of the License at
109235 *
109236 * http://www.apache.org/licenses/LICENSE-2.0
109237 *
109238 * Unless required by applicable law or agreed to in writing, software
109239 * distributed under the License is distributed on an "AS IS" BASIS,
109240 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109241 * See the License for the specific language governing permissions and
109242 * limitations under the License.
109243 * =============================================================================
109244 */
109245 function sparseToDense$1(args) {
109246 var inputs = args.inputs,
109247 backend = args.backend,
109248 attrs = args.attrs;
109249 var sparseIndices = inputs.sparseIndices,
109250 sparseValues = inputs.sparseValues,
109251 defaultValue = inputs.defaultValue;
109252 var outputShape = attrs.outputShape;
109253 var _backend_util$calcula = calculateShapes(sparseValues, sparseIndices, outputShape),
109254 sliceRank = _backend_util$calcula.sliceRank,
109255 numUpdates = _backend_util$calcula.numUpdates,
109256 sliceSize = _backend_util$calcula.sliceSize,
109257 strides = _backend_util$calcula.strides,
109258 outputSize = _backend_util$calcula.outputSize;
109259 var sumDupeIndices = false;
109260 var indicesBuf = backend.bufferSync(sparseIndices);
109261 var outBuf;
109262 switch (sparseValues.dtype) {
109263 case 'bool':
109264 {
109265 var updatesBuf = backend.bufferSync(sparseValues);
109266 var $defaultValue = Boolean(backend.data.get(defaultValue.dataId).values[0]);
109267 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
109268 break;
109269 }
109270 case 'float32':
109271 {
109272 var _updatesBuf = backend.bufferSync(sparseValues);
109273 var _$defaultValue = backend.data.get(defaultValue.dataId).values[0];
109274 outBuf = scatterImpl(indicesBuf, _updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, _$defaultValue, sumDupeIndices);
109275 break;
109276 }
109277 case 'int32':
109278 {
109279 var _updatesBuf2 = backend.bufferSync(sparseValues);
109280 var _$defaultValue2 = backend.data.get(defaultValue.dataId).values[0];
109281 outBuf = scatterImpl(indicesBuf, _updatesBuf2, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, _$defaultValue2, sumDupeIndices);
109282 break;
109283 }
109284 case 'string':
109285 {
109286 var _updatesBuf3 = backend.bufferSync(sparseValues);
109287 var _$defaultValue3 = decodeString(backend.data.get(defaultValue.dataId).values[0]);
109288 outBuf = scatterImpl(indicesBuf, _updatesBuf3, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, _$defaultValue3, sumDupeIndices);
109289 break;
109290 }
109291 default:
109292 throw new Error("Unsupported type ".concat(sparseValues.dtype));
109293 }
109294 return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
109295 }
109296 var sparseToDenseConfig$1 = {
109297 kernelName: SparseToDense,
109298 backendName: 'cpu',
109299 kernelFunc: sparseToDense$1
109300 };
109301
109302 function splitV$1(args) {
109303 var inputs = args.inputs,
109304 backend = args.backend,
109305 attrs = args.attrs;
109306 var x = inputs.x;
109307 var numOrSizeSplits = attrs.numOrSizeSplits,
109308 axis = attrs.axis;
109309 var $axis = parseAxisParam(axis, x.shape)[0];
109310 var splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
109311 var begin = new Array(x.shape.length).fill(0);
109312 var size = x.shape.slice();
109313 return splitSizes.map(function (s) {
109314 var sliceSize = _toConsumableArray(size);
109315 sliceSize[$axis] = s;
109316 var sliceT = slice$1({
109317 inputs: {
109318 x: x
109319 },
109320 backend: backend,
109321 attrs: {
109322 begin: begin,
109323 size: sliceSize
109324 }
109325 });
109326 begin[$axis] += s;
109327 return sliceT;
109328 });
109329 }
109330 var splitVConfig$1 = {
109331 kernelName: SplitV,
109332 backendName: 'cpu',
109333 kernelFunc: splitV$1
109334 };
109335
109336 /**
109337 * @license
109338 * Copyright 2019 Google LLC. All Rights Reserved.
109339 * Licensed under the Apache License, Version 2.0 (the "License");
109340 * you may not use this file except in compliance with the License.
109341 * You may obtain a copy of the License at
109342 *
109343 * http://www.apache.org/licenses/LICENSE-2.0
109344 *
109345 * Unless required by applicable law or agreed to in writing, software
109346 * distributed under the License is distributed on an "AS IS" BASIS,
109347 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109348 * See the License for the specific language governing permissions and
109349 * limitations under the License.
109350 * =============================================================================
109351 */
109352 var squareConfig$1 = {
109353 kernelName: Square,
109354 backendName: 'cpu',
109355 kernelFunc: function kernelFunc(_ref) {
109356 var inputs = _ref.inputs,
109357 backend = _ref.backend;
109358 var x = inputs.x;
109359 var cpuBackend = backend;
109360 assertNotComplex$1(x, 'square');
109361 var values = cpuBackend.data.get(x.dataId).values;
109362 var newValues = new Float32Array(values.length);
109363 for (var i = 0; i < values.length; ++i) {
109364 var value = values[i];
109365 newValues[i] = value * value;
109366 }
109367 var dataId = cpuBackend.write(newValues, x.shape, x.dtype);
109368 return {
109369 dataId: dataId,
109370 shape: x.shape,
109371 dtype: x.dtype
109372 };
109373 }
109374 };
109375
109376 /**
109377 * @license
109378 * Copyright 2020 Google LLC. All Rights Reserved.
109379 * Licensed under the Apache License, Version 2.0 (the License);
109380 * you may not use this file except in compliance with the License.
109381 * You may obtain a copy of the License at
109382 *
109383 * http://www.apache.org/licenses/LICENSE-2.0
109384 *
109385 * Unless required by applicable law or agreed to in writing, software
109386 * distributed under the License is distributed on an AS IS BASIS,
109387 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109388 * See the License for the specific language governing permissions and
109389 * limitations under the License.
109390 * =============================================================================
109391 */
109392 var step$1 = unaryKernelFunc$1(Step, function (xi, attrs) {
109393 var stepAttrs = attrs;
109394 if (isNaN(xi)) {
109395 return NaN;
109396 } else {
109397 return xi > 0 ? 1 : stepAttrs.alpha;
109398 }
109399 });
109400 var stepConfig$1 = {
109401 kernelName: Step,
109402 backendName: 'cpu',
109403 kernelFunc: step$1
109404 };
109405
109406 /**
109407 * @license
109408 * Copyright 2020 Google LLC. All Rights Reserved.
109409 * Licensed under the Apache License, Version 2.0 (the "License");
109410 * you may not use this file except in compliance with the License.
109411 * You may obtain a copy of the License at
109412 *
109413 * http://www.apache.org/licenses/LICENSE-2.0
109414 *
109415 * Unless required by applicable law or agreed to in writing, software
109416 * distributed under the License is distributed on an "AS IS" BASIS,
109417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109418 * See the License for the specific language governing permissions and
109419 * limitations under the License.
109420 * =============================================================================
109421 */
109422 function stridedSlice$1(args) {
109423 var inputs = args.inputs,
109424 backend = args.backend,
109425 attrs = args.attrs;
109426 var x = inputs.x;
109427 var begin = attrs.begin,
109428 end = attrs.end,
109429 strides = attrs.strides,
109430 beginMask = attrs.beginMask,
109431 endMask = attrs.endMask,
109432 ellipsisMask = attrs.ellipsisMask,
109433 newAxisMask = attrs.newAxisMask,
109434 shrinkAxisMask = attrs.shrinkAxisMask;
109435 assertNotComplex$1(x, 'stridedSlice');
109436 var _slice_util$sliceInfo = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask),
109437 finalShapeSparse = _slice_util$sliceInfo.finalShapeSparse,
109438 finalShape = _slice_util$sliceInfo.finalShape,
109439 isIdentity = _slice_util$sliceInfo.isIdentity,
109440 sliceDim0 = _slice_util$sliceInfo.sliceDim0,
109441 isSimpleSlice = _slice_util$sliceInfo.isSimpleSlice,
109442 $begin = _slice_util$sliceInfo.begin,
109443 $end = _slice_util$sliceInfo.end,
109444 $strides = _slice_util$sliceInfo.strides;
109445 var result;
109446 // ref:
109447 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/strided_slice_op.cc
109448 if (isIdentity) {
109449 // Optimization #1, slice is a no-op plus reshape
109450 result = reshape$1({
109451 inputs: {
109452 x: x
109453 },
109454 backend: backend,
109455 attrs: {
109456 shape: finalShape
109457 }
109458 });
109459 } else if (sliceDim0 || isSimpleSlice) {
109460 // Optimization #2, slice is memory contiguous (only occurs in dim 0)
109461 assert$1(x.shape.length >= 1, function () {
109462 return "Input must have rank at least 1, got: ".concat(x.shape.length);
109463 });
109464 var size = computeOutShape$2($begin, $end, $strides);
109465 // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
109466 var sliced = slice$1({
109467 inputs: {
109468 x: x
109469 },
109470 backend: backend,
109471 attrs: {
109472 begin: $begin,
109473 size: size
109474 }
109475 });
109476 result = reshape$1({
109477 inputs: {
109478 x: sliced
109479 },
109480 backend: backend,
109481 attrs: {
109482 shape: finalShape
109483 }
109484 });
109485 backend.disposeIntermediateTensorInfo(sliced);
109486 } else {
109487 var xBuf = backend.bufferSync(x);
109488 var outBuf = stridedSliceImpl(finalShapeSparse, xBuf, $strides, $begin);
109489 result = backend.makeTensorInfo(finalShape, outBuf.dtype, outBuf.values);
109490 }
109491 return result;
109492 }
109493 var stridedSliceConfig$1 = {
109494 kernelName: StridedSlice,
109495 backendName: 'cpu',
109496 kernelFunc: stridedSlice$1
109497 };
109498
109499 function stringNGrams$1(args) {
109500 var inputs = args.inputs,
109501 backend = args.backend,
109502 attrs = args.attrs;
109503 var separator = attrs.separator,
109504 nGramWidths = attrs.nGramWidths,
109505 leftPad = attrs.leftPad,
109506 rightPad = attrs.rightPad,
109507 padWidth = attrs.padWidth,
109508 preserveShortSequences = attrs.preserveShortSequences;
109509 var data = inputs.data,
109510 dataSplits = inputs.dataSplits;
109511 var $data = backend.data.get(data.dataId).values;
109512 var $dataSplits = backend.data.get(dataSplits.dataId).values;
109513 var _stringNGramsImpl = stringNGramsImpl($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences),
109514 _stringNGramsImpl2 = _slicedToArray(_stringNGramsImpl, 2),
109515 nGrams = _stringNGramsImpl2[0],
109516 nGramsSplits = _stringNGramsImpl2[1];
109517 return [backend.makeTensorInfo([nGrams.length], 'string', nGrams), backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits)];
109518 }
109519 var stringNGramsConfig$1 = {
109520 kernelName: StringNGrams,
109521 backendName: 'cpu',
109522 kernelFunc: stringNGrams$1
109523 };
109524
109525 function stringSplit$1(args) {
109526 var inputs = args.inputs,
109527 backend = args.backend,
109528 attrs = args.attrs;
109529 var skipEmpty = attrs.skipEmpty;
109530 var input = inputs.input,
109531 delimiter = inputs.delimiter;
109532 if (input.dtype !== 'string') {
109533 throw new Error('Input must be of datatype string');
109534 }
109535 if (input.shape.length !== 1) {
109536 throw new Error("Input must be a vector, got shape: ".concat(input.shape));
109537 }
109538 if (delimiter.shape.length !== 0) {
109539 throw new Error("Delimiter must be a scalar, got shape: ".concat(delimiter.shape));
109540 }
109541 var $input = backend.data.get(input.dataId).values;
109542 var $delimiter = backend.data.get(delimiter.dataId).values[0];
109543 var _stringSplitImpl = stringSplitImpl($input, $delimiter, skipEmpty),
109544 _stringSplitImpl2 = _slicedToArray(_stringSplitImpl, 3),
109545 indices = _stringSplitImpl2[0],
109546 values = _stringSplitImpl2[1],
109547 shape = _stringSplitImpl2[2];
109548 var outputSize = values.length;
109549 return [backend.makeTensorInfo([outputSize, 2], 'int32', indices), backend.makeTensorInfo([outputSize], 'string', values), backend.makeTensorInfo([2], 'int32', new Int32Array(shape))];
109550 }
109551 var stringSplitConfig$1 = {
109552 kernelName: StringSplit,
109553 backendName: 'cpu',
109554 kernelFunc: stringSplit$1
109555 };
109556
109557 /**
109558 * @license
109559 * Copyright 2021 Google LLC. All Rights Reserved.
109560 * Licensed under the Apache License, Version 2.0 (the "License");
109561 * you may not use this file except in compliance with the License.
109562 * You may obtain a copy of the License at
109563 *
109564 * http://www.apache.org/licenses/LICENSE-2.0
109565 *
109566 * Unless required by applicable law or agreed to in writing, software
109567 * distributed under the License is distributed on an "AS IS" BASIS,
109568 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109569 * See the License for the specific language governing permissions and
109570 * limitations under the License.
109571 * =============================================================================
109572 */
109573 function stringToHashBucketFast$1(args) {
109574 var inputs = args.inputs,
109575 backend = args.backend,
109576 attrs = args.attrs;
109577 var numBuckets = attrs.numBuckets;
109578 var input = inputs.input;
109579 if (input.dtype !== 'string') {
109580 throw new Error('Input must be of datatype string');
109581 }
109582 if (numBuckets <= 0) {
109583 throw new Error("Number of buckets must be at least 1");
109584 }
109585 var $input = backend.data.get(input.dataId).values;
109586 var output = stringToHashBucketFastImpl($input, numBuckets);
109587 return backend.makeTensorInfo(input.shape, 'int32', output);
109588 }
109589 var stringToHashBucketFastConfig$1 = {
109590 kernelName: StringToHashBucketFast,
109591 backendName: 'cpu',
109592 kernelFunc: stringToHashBucketFast$1
109593 };
109594
109595 /**
109596 * @license
109597 * Copyright 2020 Google LLC. All Rights Reserved.
109598 * Licensed under the Apache License, Version 2.0 (the License);
109599 * you may not use this file except in compliance with the License.
109600 * You may obtain a copy of the License at
109601 *
109602 * http://www.apache.org/licenses/LICENSE-2.0
109603 *
109604 * Unless required by applicable law or agreed to in writing, software
109605 * distributed under the License is distributed on an AS IS BASIS,
109606 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109607 * See the License for the specific language governing permissions and
109608 * limitations under the License.
109609 * =============================================================================
109610 */
109611 var tan$1 = unaryKernelFunc$1(Tan, function (xi) {
109612 return Math.tan(xi);
109613 });
109614 var tanConfig$1 = {
109615 kernelName: Tan,
109616 backendName: 'cpu',
109617 kernelFunc: tan$1
109618 };
109619
109620 /**
109621 * @license
109622 * Copyright 2020 Google LLC. All Rights Reserved.
109623 * Licensed under the Apache License, Version 2.0 (the License);
109624 * you may not use this file except in compliance with the License.
109625 * You may obtain a copy of the License at
109626 *
109627 * http://www.apache.org/licenses/LICENSE-2.0
109628 *
109629 * Unless required by applicable law or agreed to in writing, software
109630 * distributed under the License is distributed on an AS IS BASIS,
109631 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109632 * See the License for the specific language governing permissions and
109633 * limitations under the License.
109634 * =============================================================================
109635 */
109636 var tanh$1 = unaryKernelFunc$1(Tanh$1, function (xi) {
109637 return Math.tanh(xi);
109638 });
109639 var tanhConfig$1 = {
109640 kernelName: Tanh$1,
109641 backendName: 'cpu',
109642 kernelFunc: tanh$1
109643 };
109644
109645 /**
109646 * @license
109647 * Copyright 2022 Google LLC. All Rights Reserved.
109648 * Licensed under the Apache License, Version 2.0 (the "License");
109649 * you may not use this file except in compliance with the License.
109650 * You may obtain a copy of the License at
109651 *
109652 * http://www.apache.org/licenses/LICENSE-2.0
109653 *
109654 * Unless required by applicable law or agreed to in writing, software
109655 * distributed under the License is distributed on an "AS IS" BASIS,
109656 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109657 * See the License for the specific language governing permissions and
109658 * limitations under the License.
109659 * =============================================================================
109660 */
109661 function tensorScatterUpdate$1(args) {
109662 var inputs = args.inputs,
109663 backend = args.backend;
109664 var tensor = inputs.tensor,
109665 indices = inputs.indices,
109666 updates = inputs.updates;
109667 var _backend_util$calcula = calculateShapes(updates, indices, tensor.shape),
109668 sliceRank = _backend_util$calcula.sliceRank,
109669 numUpdates = _backend_util$calcula.numUpdates,
109670 sliceSize = _backend_util$calcula.sliceSize,
109671 strides = _backend_util$calcula.strides,
109672 outputSize = _backend_util$calcula.outputSize;
109673 var sumDupeIndices = false;
109674 var indicesBuf = backend.bufferSync(indices);
109675 var updatesBuf = backend.bufferSync(updates);
109676 var tensorBuf = backend.bufferSync(tensor);
109677 var outBuf = scatterImpl(indicesBuf, updatesBuf, tensor.shape, outputSize, sliceSize, numUpdates, sliceRank, strides, tensorBuf, sumDupeIndices);
109678 return backend.makeTensorInfo(tensor.shape, outBuf.dtype, outBuf.values);
109679 }
109680 var tensorScatterUpdateConfig$1 = {
109681 kernelName: TensorScatterUpdate,
109682 backendName: 'cpu',
109683 kernelFunc: tensorScatterUpdate$1
109684 };
109685
109686 /**
109687 * @license
109688 * Copyright 2020 Google LLC. All Rights Reserved.
109689 * Licensed under the Apache License, Version 2.0 (the "License");
109690 * you may not use this file except in compliance with the License.
109691 * You may obtain a copy of the License at
109692 *
109693 * http://www.apache.org/licenses/LICENSE-2.0
109694 *
109695 * Unless required by applicable law or agreed to in writing, software
109696 * distributed under the License is distributed on an "AS IS" BASIS,
109697 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109698 * See the License for the specific language governing permissions and
109699 * limitations under the License.
109700 * =============================================================================
109701 */
109702 function tile$1(args) {
109703 var inputs = args.inputs,
109704 backend = args.backend,
109705 attrs = args.attrs;
109706 var x = inputs.x;
109707 var reps = attrs.reps;
109708 assertNotComplex$1(x, 'tile');
109709 var outBuf = tileImpl(backend.bufferSync(x), reps);
109710 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
109711 }
109712 var tileConfig$1 = {
109713 kernelName: Tile,
109714 backendName: 'cpu',
109715 kernelFunc: tile$1
109716 };
109717
109718 function topK$1(args) {
109719 var inputs = args.inputs,
109720 backend = args.backend,
109721 attrs = args.attrs;
109722 var x = inputs.x;
109723 var k = attrs.k,
109724 sorted = attrs.sorted;
109725 assertNotComplex$1(x, 'topk');
109726 var xVals = backend.data.get(x.dataId).values;
109727 var _topKImpl = topKImpl(xVals, x.shape, x.dtype, k, sorted),
109728 _topKImpl2 = _slicedToArray(_topKImpl, 2),
109729 allTopKVals = _topKImpl2[0],
109730 allTopKIndices = _topKImpl2[1];
109731 return [backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values), backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)];
109732 }
109733 var topKConfig$1 = {
109734 kernelName: TopK,
109735 backendName: 'cpu',
109736 kernelFunc: topK$1
109737 };
109738
109739 function transform$1(args) {
109740 var inputs = args.inputs,
109741 attrs = args.attrs,
109742 backend = args.backend;
109743 var image = inputs.image,
109744 transforms = inputs.transforms;
109745 var interpolation = attrs.interpolation,
109746 fillMode = attrs.fillMode,
109747 fillValue = attrs.fillValue,
109748 outputShape = attrs.outputShape;
109749 var _image$shape = _slicedToArray(image.shape, 4),
109750 batch = _image$shape[0],
109751 imageHeight = _image$shape[1],
109752 imageWidth = _image$shape[2],
109753 numChannels = _image$shape[3];
109754 var _ref = outputShape != null ? outputShape : [imageHeight, imageWidth],
109755 _ref2 = _slicedToArray(_ref, 2),
109756 outHeight = _ref2[0],
109757 outWidth = _ref2[1];
109758 var outShape = [batch, outHeight, outWidth, numChannels];
109759 var inStrides = computeStrides(image.shape);
109760 var batchInStride = inStrides[0];
109761 var rowInStride = inStrides[1];
109762 var colInStride = inStrides[2];
109763 var outStrides = computeStrides(outShape);
109764 var batchOutStride = outStrides[0];
109765 var rowOutStride = outStrides[1];
109766 var colOutStride = outStrides[2];
109767 var outVals = getTypedArrayFromDType(image.dtype, sizeFromShape(outShape));
109768 outVals.fill(fillValue);
109769 var imageVals = backend.data.get(image.dataId).values;
109770 var transformVals = backend.data.get(transforms.dataId).values;
109771 // Ref TF implementation:
109772 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/image/image_ops.h
109773 for (var b = 0; b < batch; ++b) {
109774 var _transform = transforms.shape[0] === 1 ? transformVals : transformVals.subarray(b * 8, b * 8 + 8);
109775 for (var outY = 0; outY < outHeight; ++outY) {
109776 for (var outX = 0; outX < outWidth; ++outX) {
109777 for (var channel = 0; channel < numChannels; ++channel) {
109778 var val = void 0;
109779 var projection = _transform[6] * outX + _transform[7] * outY + 1;
109780 if (projection === 0) {
109781 // Return the fill value for infinite coordinates,
109782 // which are outside the input image
109783 continue;
109784 }
109785 var inX = (_transform[0] * outX + _transform[1] * outY + _transform[2]) / projection;
109786 var inY = (_transform[3] * outX + _transform[4] * outY + _transform[5]) / projection;
109787 var x = mapCoord(inX, imageWidth, fillMode);
109788 var y = mapCoord(inY, imageHeight, fillMode);
109789 switch (interpolation) {
109790 case 'nearest':
109791 val = nearestInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
109792 break;
109793 case 'bilinear':
109794 val = bilinearInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
109795 break;
109796 default:
109797 throw new Error("Error in Transform: Expect 'nearest' or " + "'bilinear', but got ".concat(interpolation));
109798 }
109799 var ind = b * batchOutStride + outY * rowOutStride + outX * colOutStride + channel;
109800 outVals[ind] = val;
109801 }
109802 }
109803 }
109804 return backend.makeTensorInfo(outShape, image.dtype, outVals);
109805 }
109806 var dataId = backend.write(outVals, outShape, image.dtype);
109807 return {
109808 dataId: dataId,
109809 shape: image.shape,
109810 dtype: image.dtype
109811 };
109812 }
109813 var transformConfig$1 = {
109814 kernelName: Transform,
109815 backendName: 'cpu',
109816 kernelFunc: transform$1
109817 };
109818 function mapCoord(outCoord, len, mode) {
109819 switch (mode) {
109820 case 'reflect':
109821 return mapCoordReflect(outCoord, len);
109822 case 'wrap':
109823 return mapCoordWrap(outCoord, len);
109824 case 'nearest':
109825 return mapCoordNearest(outCoord, len);
109826 case 'constant':
109827 default:
109828 return mapCoordConstant(outCoord, len);
109829 }
109830 }
109831 function mapCoordReflect(outCoord, len) {
109832 // Reflect [abcd] to [dcba|abcd|dcba].
109833 var inCoord = outCoord;
109834 if (inCoord < 0) {
109835 if (len <= 1) {
109836 inCoord = 0;
109837 } else {
109838 var sz2 = 2 * len;
109839 if (inCoord < sz2) {
109840 inCoord = sz2 * Math.trunc(-inCoord / sz2) + inCoord;
109841 }
109842 inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1;
109843 }
109844 } else if (inCoord > len - 1) {
109845 if (len <= 1) {
109846 inCoord = 0;
109847 } else {
109848 var _sz = 2 * len;
109849 inCoord -= _sz * Math.trunc(inCoord / _sz);
109850 if (inCoord >= len) {
109851 inCoord = _sz - inCoord - 1;
109852 }
109853 }
109854 }
109855 // clamp is necessary because when outCoord = 3.5 and len = 4,
109856 // inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
109857 return clamp(0, inCoord, len - 1);
109858 }
109859 function mapCoordWrap(outCoord, len) {
109860 // Wrap [abcd] to [abcd|abcd|abcd].
109861 var inCoord = outCoord;
109862 if (inCoord < 0) {
109863 if (len <= 1) {
109864 inCoord = 0;
109865 } else {
109866 var sz = len - 1;
109867 inCoord += len * (Math.trunc(-inCoord / sz) + 1);
109868 }
109869 } else if (inCoord > len - 1) {
109870 if (len <= 1) {
109871 inCoord = 0;
109872 } else {
109873 var _sz2 = len - 1;
109874 inCoord -= len * Math.trunc(inCoord / _sz2);
109875 }
109876 }
109877 // clamp is necessary because when outCoord = -0.5 and len = 4,
109878 // inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
109879 return clamp(0, inCoord, len - 1);
109880 }
109881 function mapCoordConstant(outCoord, len) {
109882 return outCoord;
109883 }
109884 function mapCoordNearest(outCoord, len) {
109885 return clamp(0, outCoord, len - 1);
109886 }
109887 function readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
109888 var ind = batch * batchStride + y * rowStride + x * colStride + channel;
109889 if (0 <= y && y < imageHeight && 0 <= x && x < imageWidth) {
109890 return imageVals[ind];
109891 } else {
109892 return fillValue;
109893 }
109894 }
109895 function nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
109896 var $y = Math.round(y);
109897 var $x = Math.round(x);
109898 return readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, $y, $x, channel, fillValue);
109899 }
109900 function bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
109901 var yFloor = Math.floor(y);
109902 var xFloor = Math.floor(x);
109903 var yCeil = yFloor + 1;
109904 var xCeil = xFloor + 1;
109905 // f(x, yFloor) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yFloor)
109906 // + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yFloor)
109907 var valueYFloor = (xCeil - x) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xFloor, channel, fillValue) + (x - xFloor) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xCeil, channel, fillValue);
109908 // f(x, yCeil) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yCeil)
109909 // + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yCeil)
109910 var valueYCeil = (xCeil - x) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xFloor, channel, fillValue) + (x - xFloor) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xCeil, channel, fillValue);
109911 // f(x, y) = (yCeil - y) / (yCeil - yFloor) * f(x, yFloor)
109912 // + (y - yFloor) / (yCeil - yFloor) * f(x, yCeil)
109913 return (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil;
109914 }
109915
109916 /**
109917 * @license
109918 * Copyright 2020 Google LLC. All Rights Reserved.
109919 * Licensed under the Apache License, Version 2.0 (the License);
109920 * you may not use this file except in compliance with the License.
109921 * You may obtain a copy of the License at
109922 *
109923 * http://www.apache.org/licenses/LICENSE-2.0
109924 *
109925 * Unless required by applicable law or agreed to in writing, software
109926 * distributed under the License is distributed on an AS IS BASIS,
109927 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109928 * See the License for the specific language governing permissions and
109929 * limitations under the License.
109930 * =============================================================================
109931 */
109932 function unique$1(args) {
109933 var inputs = args.inputs,
109934 attrs = args.attrs,
109935 backend = args.backend;
109936 var axis = attrs.axis;
109937 var x = inputs.x;
109938 assertNotComplex$1(x, 'unique');
109939 var values = backend.data.get(x.dataId).values;
109940 var _uniqueImpl = uniqueImpl(values, axis, x.shape, x.dtype),
109941 outputValues = _uniqueImpl.outputValues,
109942 outputShape = _uniqueImpl.outputShape,
109943 indices = _uniqueImpl.indices;
109944 return [backend.makeTensorInfo(outputShape, x.dtype, outputValues), backend.makeTensorInfo([indices.length], 'int32', indices)];
109945 }
109946 var uniqueConfig$1 = {
109947 kernelName: Unique,
109948 backendName: 'cpu',
109949 kernelFunc: unique$1
109950 };
109951
109952 /**
109953 * @license
109954 * Copyright 2020 Google LLC. All Rights Reserved.
109955 * Licensed under the Apache License, Version 2.0 (the "License");
109956 * you may not use this file except in compliance with the License.
109957 * You may obtain a copy of the License at
109958 *
109959 * http://www.apache.org/licenses/LICENSE-2.0
109960 *
109961 * Unless required by applicable law or agreed to in writing, software
109962 * distributed under the License is distributed on an "AS IS" BASIS,
109963 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109964 * See the License for the specific language governing permissions and
109965 * limitations under the License.
109966 * =============================================================================
109967 */
109968 function unpack$1(args) {
109969 var inputs = args.inputs,
109970 backend = args.backend,
109971 attrs = args.attrs;
109972 var value = inputs.value;
109973 var axis = attrs.axis;
109974 if (axis < 0) {
109975 axis += value.shape.length;
109976 }
109977 var valueRank = value.shape.length;
109978 var num = value.shape[axis];
109979 var outShape = new Array(valueRank - 1);
109980 var outIndex = 0;
109981 for (var i = 0; i < valueRank; i++) {
109982 if (i !== axis) {
109983 outShape[outIndex++] = value.shape[i];
109984 }
109985 }
109986 var begin = new Array(valueRank).fill(0);
109987 var size = value.shape.slice();
109988 size[axis] = 1;
109989 var res = new Array(num);
109990 for (var _i = 0; _i < res.length; _i++) {
109991 begin[axis] = _i;
109992 var tempRes = slice$1({
109993 inputs: {
109994 x: value
109995 },
109996 backend: backend,
109997 attrs: {
109998 begin: begin,
109999 size: size
110000 }
110001 });
110002 res[_i] = reshape$1({
110003 inputs: {
110004 x: tempRes
110005 },
110006 backend: backend,
110007 attrs: {
110008 shape: outShape
110009 }
110010 });
110011 backend.disposeIntermediateTensorInfo(tempRes);
110012 }
110013 return res;
110014 }
110015 var unpackConfig$1 = {
110016 kernelName: Unpack,
110017 backendName: 'cpu',
110018 kernelFunc: unpack$1
110019 };
110020
110021 /**
110022 * @license
110023 * Copyright 2020 Google LLC. All Rights Reserved.
110024 * Licensed under the Apache License, Version 2.0 (the "License");
110025 * you may not use this file except in compliance with the License.
110026 * You may obtain a copy of the License at
110027 *
110028 * http://www.apache.org/licenses/LICENSE-2.0
110029 *
110030 * Unless required by applicable law or agreed to in writing, software
110031 * distributed under the License is distributed on an "AS IS" BASIS,
110032 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
110033 * See the License for the specific language governing permissions and
110034 * limitations under the License.
110035 * =============================================================================
110036 */
110037 function unsortedSegmentSum$1(args) {
110038 var inputs = args.inputs,
110039 backend = args.backend,
110040 attrs = args.attrs;
110041 var x = inputs.x,
110042 segmentIds = inputs.segmentIds;
110043 var numSegments = attrs.numSegments;
110044 assertNotComplex$1(x, 'unsortedSegmentSum');
110045 var xRank = x.shape.length;
110046 var segmentIdsRank = segmentIds.shape.length;
110047 var res = [];
110048 var intermediates = [];
110049 // Reshape the segment id's so that they can be broadcast with
110050 // x. The new shape should be [segmentIds.shape, 1, ..., 1]
110051 var numIters = xRank - segmentIdsRank;
110052 var $segmentIds = segmentIds;
110053 for (var i = 0; i < numIters; ++i) {
110054 var expanded = expandDims$1({
110055 inputs: {
110056 input: $segmentIds
110057 },
110058 backend: backend,
110059 attrs: {
110060 dim: i + 1
110061 }
110062 });
110063 $segmentIds = expanded;
110064 intermediates.push(expanded);
110065 }
110066 for (var _i = 0; _i < numSegments; ++_i) {
110067 var scalarValue = createScalarValue(_i, 'int32');
110068 var segmentId = backend.makeTensorInfo([], 'int32', scalarValue);
110069 var mask = equal$1({
110070 inputs: {
110071 a: segmentId,
110072 b: $segmentIds
110073 },
110074 backend: backend
110075 });
110076 var maskCasted = cast$1({
110077 inputs: {
110078 x: mask
110079 },
110080 backend: backend,
110081 attrs: {
110082 dtype: 'float32'
110083 }
110084 });
110085 var mul = multiply$1({
110086 inputs: {
110087 a: maskCasted,
110088 b: x
110089 },
110090 backend: backend
110091 });
110092 var sumTensorInfo = sum$1({
110093 inputs: {
110094 x: mul
110095 },
110096 backend: backend,
110097 attrs: {
110098 axis: 0,
110099 keepDims: false
110100 }
110101 });
110102 res.push(sumTensorInfo);
110103 intermediates.push(segmentId);
110104 intermediates.push(mask);
110105 intermediates.push(maskCasted);
110106 intermediates.push(mul);
110107 intermediates.push(sumTensorInfo);
110108 }
110109 var result = pack$1({
110110 inputs: res,
110111 backend: backend,
110112 attrs: {
110113 axis: 0
110114 }
110115 });
110116 intermediates.forEach(function (t) {
110117 return backend.disposeIntermediateTensorInfo(t);
110118 });
110119 return result;
110120 }
110121 var unsortedSegmentSumConfig$1 = {
110122 kernelName: UnsortedSegmentSum,
110123 backendName: 'cpu',
110124 kernelFunc: unsortedSegmentSum$1
110125 };
110126
110127 /**
110128 * @license
110129 * Copyright 2020 Google LLC. All Rights Reserved.
110130 * Licensed under the Apache License, Version 2.0 (the "License");
110131 * you may not use this file except in compliance with the License.
110132 * You may obtain a copy of the License at
110133 *
110134 * http://www.apache.org/licenses/LICENSE-2.0
110135 *
110136 * Unless required by applicable law or agreed to in writing, software
110137 * distributed under the License is distributed on an "AS IS" BASIS,
110138 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
110139 * See the License for the specific language governing permissions and
110140 * limitations under the License.
110141 * =============================================================================
110142 */
// Every CPU kernel config; each one is passed to registerKernel in array order.
var kernelConfigs$1 = [
  _fusedMatMulConfig$1, absConfig$1, acosConfig$1, acoshConfig$1, addConfig$1, addNConfig$1,
  allConfig$1, anyConfig$1, argMaxConfig$1, argMinConfig$1, asinConfig$1, asinhConfig$1,
  atanConfig$1, atan2Config$1, atanhConfig$1, avgPoolConfig$1, avgPool3DConfig$1,
  avgPool3DGradConfig$1, avgPoolGradConfig$1, batchMatMulConfig$1, batchNormConfig$1,
  batchToSpaceNDConfig$1, bincountConfig$1, bitwiseAndConfig$1, broadcastArgsConfig$1,
  castConfig$1, ceilConfig$1, clipByValueConfig$1, complexConfig$1, complexAbsConfig$1,
  concatConfig$1, conv2DConfig$1, conv2DBackpropFilterConfig$1, conv2DBackpropInputConfig$1,
  conv3DConfig$1, conv3DBackpropFilterV2Config$1, conv3DBackpropInputV2Config, cosConfig$1,
  coshConfig$1, cropAndResizeConfig$1, cumprodConfig$1, cumsumConfig$1, denseBincountConfig$1,
  depthToSpaceConfig$1, depthwiseConv2dNativeConfig$1,
  depthwiseConv2dNativeBackpropFilterConfig$1, depthwiseConv2dNativeBackpropInputConfig$1,
  diagConfig$1, dilation2DConfig$1, dilation2DBackpropFilterConfig,
  dilation2DBackpropInputConfig, drawConfig, einsumConfig$1, eluConfig$1, eluGradConfig$1,
  equalConfig$1, erfConfig$1, expConfig$1, expandDimsConfig$1, expm1Config$1, fftConfig$1,
  fillConfig$1, flipLeftRightConfig$1, floorConfig$1, floorDivConfig$1, fusedConv2DConfig$1,
  fusedDepthwiseConv2DConfig$1, gatherNdConfig$1, gatherV2Config$1, greaterConfig$1,
  greaterEqualConfig$1, identityConfig$1, ifftConfig$1, imagConfig$1, isFiniteConfig$1,
  isInfConfig$1, isNaNConfig$1, leakyReluConfig$1, lessConfig$1, lessEqualConfig$1,
  linSpaceConfig$1, logConfig$1, log1pConfig$1, logicalAndConfig$1, logicalNotConfig$1,
  logicalOrConfig$1, LRNConfig$1, LRNGradConfig$1, maxConfig$1, maximumConfig$1,
  maxPoolConfig$1, maxPool3DConfig$1, maxPool3DGradConfig$1, maxPoolGradConfig$1,
  maxPoolWithArgmaxConfig$1, meanConfig$1, minConfig$1, minimumConfig$1, mirrorPadConfig$1,
  modConfig$1, multinomialConfig$1, multiplyConfig$1, negConfig$1,
  nonMaxSuppressionV3Config$1, nonMaxSuppressionV4Config$1, nonMaxSuppressionV5Config$1,
  notEqualConfig$1, oneHotConfig$1, onesLikeConfig$1, packConfig$1, padV2Config$1,
  powConfig$1, preluConfig$1, prodConfig$1, raggedGatherConfig$1, raggedRangeConfig$1,
  raggedTensorToTensorConfig$1, rangeConfig$1, realConfig$1, realDivConfig$1,
  reciprocalConfig$1, reluConfig$1, relu6Config$1, reshapeConfig$1, resizeBilinearConfig$1,
  resizeBilinearGradConfig$1, resizeNearestNeighborConfig$1,
  resizeNearestNeighborGradConfig$1, reverseConfig$1, rotateWithOffsetConfig$1,
  roundConfig$1, rsqrtConfig$1, scatterNdConfig$1, searchSortedConfig$1, selectConfig$1,
  seluConfig$1, sigmoidConfig$1, signConfig$1, sinConfig$1, sinhConfig$1, sliceConfig$1,
  softmaxConfig$1, softplusConfig$1, spaceToBatchNDConfig$1, sparseFillEmptyRowsConfig$1,
  sparseReshapeConfig$1, sparseSegmentMeanConfig$1, sparseSegmentSumConfig$1,
  sparseToDenseConfig$1, splitVConfig$1, sqrtConfig$1, squareConfig$1,
  squaredDifferenceConfig$1, staticRegexReplaceConfig$1, stepConfig$1, stridedSliceConfig$1,
  stringNGramsConfig$1, stringSplitConfig$1, stringToHashBucketFastConfig$1, subConfig$1,
  sumConfig$1, tanConfig$1, tanhConfig$1, tensorScatterUpdateConfig$1, tileConfig$1,
  topKConfig$1, transformConfig$1, transposeConfig$1, uniqueConfig$1, unpackConfig$1,
  unsortedSegmentSumConfig$1, zerosLikeConfig$1
];
kernelConfigs$1.forEach(function (kernelConfig) {
  registerKernel(kernelConfig);
});
110149
110150 /**
110151 * @license
110152 * Copyright 2020 Google LLC. All Rights Reserved.
110153 * Licensed under the Apache License, Version 2.0 (the "License");
110154 * you may not use this file except in compliance with the License.
110155 * You may obtain a copy of the License at
110156 *
110157 * http://www.apache.org/licenses/LICENSE-2.0
110158 *
110159 * Unless required by applicable law or agreed to in writing, software
110160 * distributed under the License is distributed on an "AS IS" BASIS,
110161 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
110162 * See the License for the specific language governing permissions and
110163 * limitations under the License.
110164 * =============================================================================
110165 */
110166
110167 /**
110168 * @license
110169 * Copyright 2018 Google LLC. All Rights Reserved.
110170 * Licensed under the Apache License, Version 2.0 (the "License");
110171 * you may not use this file except in compliance with the License.
110172 * You may obtain a copy of the License at
110173 *
110174 * http://www.apache.org/licenses/LICENSE-2.0
110175 *
110176 * Unless required by applicable law or agreed to in writing, software
110177 * distributed under the License is distributed on an "AS IS" BASIS,
110178 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
110179 * See the License for the specific language governing permissions and
110180 * limitations under the License.
110181 * =============================================================================
110182 */
/** Cache of WebGL rendering contexts, keyed by WebGL version. */
var contexts = {};
/**
 * Default attributes handed to `canvas.getContext` when creating a WebGL
 * context: no alpha channel, antialiasing, premultiplied alpha, preserved
 * drawing buffer, depth buffer or stencil buffer, and creation fails when the
 * browser reports a major performance caveat.
 */
var WEBGL_ATTRIBUTES = {
  alpha: false,
  antialias: false,
  premultipliedAlpha: false,
  preserveDrawingBuffer: false,
  depth: false,
  stencil: false,
  failIfMajorPerformanceCaveat: true
};
/** Removes the cached context for `webGLVersion`, if one exists. */
function clearWebGLContext(webGLVersion) {
  delete contexts[webGLVersion];
}
/** Caches `gl` as the context for `webGLVersion`, replacing any previous entry. */
function setWebGLContext(webGLVersion, gl) {
  contexts[webGLVersion] = gl;
}
/**
 * Returns the WebGL rendering context for `webGLVersion`.
 *
 * A new context is created (and cached) when none is cached for that
 * version, or whenever `customCanvas` is supplied. A cached entry that is
 * null or whose context is lost is evicted and the lookup is retried.
 * Before returning, the context's fixed-function state is reset:
 * depth/stencil tests, blending, dithering, polygon offset fill and sample
 * coverage are disabled; scissor test and back-face culling are enabled.
 *
 * @param webGLVersion WebGL version, used as the cache key.
 * @param customCanvas Optional canvas to create the context from.
 * @returns The context, or null (after logging) when creation fails.
 */
function getWebGLContext(webGLVersion, customCanvas) {
  var mustCreate = !(webGLVersion in contexts) || customCanvas != null;
  if (mustCreate) {
    var created = getWebGLRenderingContext(webGLVersion, customCanvas);
    if (created === null) {
      console.log('Could not get context for WebGL version', webGLVersion);
      return null;
    }
    contexts[webGLVersion] = created;
  }
  var gl = contexts[webGLVersion];
  if (gl == null || gl.isContextLost()) {
    // NOTE(review): the retry does not forward customCanvas, so it creates the
    // replacement context on a fresh canvas.
    delete contexts[webGLVersion];
    return getWebGLContext(webGLVersion);
  }
  var capabilitiesToDisable = [gl.DEPTH_TEST, gl.STENCIL_TEST, gl.BLEND, gl.DITHER, gl.POLYGON_OFFSET_FILL, gl.SAMPLE_COVERAGE];
  capabilitiesToDisable.forEach(function (capability) {
    gl.disable(capability);
  });
  gl.enable(gl.SCISSOR_TEST);
  gl.enable(gl.CULL_FACE);
  gl.cullFace(gl.BACK);
  return contexts[webGLVersion];
}
/**
 * Creates a canvas to host a WebGL context.
 *
 * An OffscreenCanvas (300x150) is used for WebGL 2 when available and not on
 * Safari, whose OffscreenCanvas does not support fencing. Otherwise a DOM
 * <canvas> element is created.
 *
 * @throws Error when neither OffscreenCanvas nor `document` is usable.
 */
function createCanvas(webGLVersion) {
  var useOffscreen = !env().getBool('IS_SAFARI') && typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2;
  if (useOffscreen) {
    return new OffscreenCanvas(300, 150);
  }
  if (typeof document === 'undefined') {
    throw new Error('Cannot create a canvas in this context');
  }
  return document.createElement('canvas');
}
/**
 * Creates a WebGL (version 1 or 2) rendering context on `customCanvas`, or on
 * a new canvas from createCanvas() when none is given.
 *
 * A 'webglcontextlost' listener is attached that prevents the default action
 * and evicts the cached context for `webGLVersion`. The context is requested
 * with a copy of WEBGL_ATTRIBUTES. When SOFTWARE_WEBGL_ENABLED is set, that copy
 * allows a major performance caveat. The shared defaults are never mutated,
 * so turning the flag off later restores the strict behaviour.
 *
 * @param webGLVersion 1 or 2; anything else throws.
 * @param customCanvas Optional canvas to create the context on.
 * @returns Whatever `canvas.getContext` returns (null on failure). For WebGL 1,
 *     'experimental-webgl' is tried when 'webgl' is unavailable.
 */
function getWebGLRenderingContext(webGLVersion, customCanvas) {
  if (webGLVersion !== 1 && webGLVersion !== 2) {
    throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
  }
  var canvas = customCanvas == null ? createCanvas(webGLVersion) : customCanvas;
  canvas.addEventListener('webglcontextlost', function (ev) {
    ev.preventDefault();
    delete contexts[webGLVersion];
  }, false);
  // Per-call copy: previously the shared WEBGL_ATTRIBUTES object was mutated,
  // so enabling software WebGL once relaxed every later context request too.
  var attributes = {};
  for (var key in WEBGL_ATTRIBUTES) {
    if (Object.prototype.hasOwnProperty.call(WEBGL_ATTRIBUTES, key)) {
      attributes[key] = WEBGL_ATTRIBUTES[key];
    }
  }
  if (env().getBool('SOFTWARE_WEBGL_ENABLED')) {
    attributes.failIfMajorPerformanceCaveat = false;
  }
  if (webGLVersion === 1) {
    return (
      // tslint:disable-next-line
      canvas.getContext('webgl', attributes) || canvas.getContext('experimental-webgl', attributes)
    );
  }
  return canvas.getContext('webgl2', attributes);
}
110256
/**
 * Ways tensor values can be laid out inside RGBA texels. This is a compiled
 * TypeScript-style enum, so it maps names to numbers and numbers back to names.
 */
var PackingScheme;
(function (scheme) {
  /**
   * All values in a single texel are densely packed without any constraints.
   *
   * This is how the shader encodes a tensor with shape = [2, 3, 4]
   * (indices are [batch, row, col]).
   *
   * 000|001 010|011 020|021
   * ------- ------- -------
   * 002|003 012|013 022|023
   *
   * 100|101 110|111 120|121
   * ------- ------- -------
   * 102|103 112|113 122|123
   *
   */
  scheme.DENSE = 0;
  scheme[0] = 'DENSE';
  /**
   * Single texels contain only values from the same batch, and from adjacent
   * rows and columns.
   *
   * This is how the shader encodes a tensor with shape = [2, 3, 5]
   * (indices are [batch, row, col]).
   *
   * 000|001 002|003 004|xxx 020|021 022|023 024|xxx
   * ------- ------- ------- ------- ------- -------
   * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
   *
   * 100|101 102|103 104|xxx 120|121 122|123 124|xxx
   * ------- ------- ------- ------- ------- -------
   * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
   *
   */
  scheme.SHARED_BATCH = 1;
  scheme[1] = 'SHARED_BATCH';
})(PackingScheme || (PackingScheme = {}));
/**
 * Texture usage categories: RENDER (0), UPLOAD (1), PIXELS (2), DOWNLOAD (3).
 * This is a compiled TypeScript-style enum, so it maps names to values and
 * values back to names.
 */
var TextureUsage;
(function (usage) {
  ['RENDER', 'UPLOAD', 'PIXELS', 'DOWNLOAD'].forEach(function (name, value) {
    usage[name] = value;
    usage[value] = name;
  });
})(TextureUsage || (TextureUsage = {}));
/**
 * Physical texture storage formats, numbered 0-4 in declaration order. This
 * is a compiled TypeScript-style enum, so it maps names to values and values
 * back to names.
 */
var PhysicalTextureType;
(function (textureType) {
  var names = ['UNPACKED_FLOAT16', 'UNPACKED_FLOAT32', 'PACKED_4X1_UNSIGNED_BYTE', 'PACKED_2X2_FLOAT32', 'PACKED_2X2_FLOAT16'];
  for (var value = 0; value < names.length; value++) {
    textureType[names[value]] = value;
    textureType[value] = names[value];
  }
})(PhysicalTextureType || (PhysicalTextureType = {}));
/** Returns the [width, height] of an unpacked texture for a rows x columns matrix. */
function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
  var width = columns;
  var height = rows;
  return [width, height];
}
/** Number of array slots needed to hold `matrixSize` texels of `channelsPerTexture` channels each. */
function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
  var slots = matrixSize * channelsPerTexture;
  return slots;
}
/** Returns [width, height] for a colour texture: four texel columns per matrix column. */
function getColorMatrixTextureShapeWidthHeight(rows, columns) {
  var width = columns * 4;
  return [width, rows];
}
/**
 * Get shape for densely packed RGBA texture: the element count of `shape` is
 * rounded up to whole 4-value texels and laid out squarishly.
 */
function getDenseTexShape(shape) {
  var texelCount = Math.ceil(sizeFromShape(shape) / 4);
  return sizeToSquarishShape(texelCount);
}
/**
 * Converts an unpacked array length back into a texel count.
 * @throws Error if `unpackedSize` is not a multiple of `channelsPerTexture`.
 */
function getMatrixSizeFromUnpackedArraySize(unpackedSize, channelsPerTexture) {
  var remainder = unpackedSize % channelsPerTexture;
  if (remainder === 0) {
    return unpackedSize / channelsPerTexture;
  }
  throw new Error("unpackedSize (".concat(unpackedSize, ") must be a multiple of ") + "".concat(channelsPerTexture));
}
/**
 * Copies the first `channels` components of every RGBA texel in
 * `unpackedArray` into `matrix`, in order.
 * @throws Error if `matrix` is shorter than unpackedArray.length * channels / 4.
 */
function decodeMatrixFromUnpackedColorRGBAArray(unpackedArray, matrix, channels) {
  var requiredSize = unpackedArray.length * channels / 4;
  if (matrix.length < requiredSize) {
    throw new Error("matrix length (".concat(matrix.length, ") must be >= ").concat(requiredSize));
  }
  var written = 0;
  for (var texel = 0; texel * 4 < unpackedArray.length; texel++) {
    var base = texel * 4;
    for (var channel = 0; channel < channels; channel++) {
      matrix[written] = unpackedArray[base + channel];
      written++;
    }
  }
}
/** Returns [width, height] of a 2x2-packed texture for a rows x columns matrix (each edge at least 1). */
function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
  var width = Math.max(1, Math.ceil(columns / 2));
  var height = Math.max(1, Math.ceil(rows / 2));
  return [width, height];
}
/** Number of RGBA components in the packed texture that holds a rows x columns matrix. */
function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
  var texShape = getPackedMatrixTextureShapeWidthHeight(rows, columns);
  return texShape[0] * texShape[1] * 4;
}
/**
 * Returns the internal formats, formats and types used for float, half-float
 * and packed textures, chosen by env().getNumber('WEBGL_VERSION').
 *
 * @param gl Rendering context whose enum constants are read.
 * @param textureHalfFloatExtension Used only on WebGL 1. Its HALF_FLOAT_OES is
 *     the half-float texture type, or null when the extension is absent
 *     (presumably the OES_texture_half_float extension object; confirm at callers).
 */
function getTextureConfig(gl, textureHalfFloatExtension) {
  // tslint:disable-next-line:no-any
  var glany = gl;
  if (env().getNumber('WEBGL_VERSION') === 2) {
    // WebGL 2: single-channel (RED) float textures; packed textures are RGBA.
    return {
      internalFormatFloat: glany.R32F,
      internalFormatHalfFloat: glany.R16F,
      internalFormatPackedHalfFloat: glany.RGBA16F,
      internalFormatPackedFloat: glany.RGBA32F,
      textureFormatFloat: glany.RED,
      downloadTextureFormat: glany.RGBA8,
      downloadUnpackNumChannels: 4,
      defaultNumChannels: 1,
      textureTypeHalfFloat: glany.HALF_FLOAT,
      textureTypeFloat: glany.FLOAT
    };
  }
  // WebGL 1: every format is RGBA; the half-float type comes from the extension.
  return {
    internalFormatFloat: gl.RGBA,
    internalFormatHalfFloat: gl.RGBA,
    internalFormatPackedHalfFloat: gl.RGBA,
    internalFormatPackedFloat: glany.RGBA,
    textureFormatFloat: gl.RGBA,
    downloadTextureFormat: gl.RGBA,
    downloadUnpackNumChannels: 4,
    defaultNumChannels: 4,
    textureTypeHalfFloat: textureHalfFloatExtension != null ? textureHalfFloatExtension.HALF_FLOAT_OES : null,
    textureTypeFloat: gl.FLOAT
  };
}
110404
/**
 * Invokes `func` and returns its result; when the DEBUG flag is set, also
 * checks `gl` for a pending WebGL error afterwards (which may throw).
 */
function callAndCheck(gl, func) {
  var result = func();
  var debugEnabled = env().getBool('DEBUG');
  if (debugEnabled) {
    checkWebGLError(gl);
  }
  return result;
}
/** Throws if `gl.getError()` reports anything other than NO_ERROR. */
function checkWebGLError(gl) {
  var error = gl.getError();
  if (error === gl.NO_ERROR) {
    return;
  }
  throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
}
// Approximate magnitude bounds of IEEE 754 half-precision floats:
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
var MIN_FLOAT16 = 5.96e-8;
var MAX_FLOAT16 = 65504;
/**
 * Whether `num` can be stored without leaving the supported range. This is
 * always true when WEBGL_RENDER_FLOAT32_ENABLED is set or `num` is zero.
 * Otherwise |num| must lie strictly between MIN_FLOAT16 and MAX_FLOAT16.
 */
function canBeRepresented(num) {
  if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
    return true;
  }
  if (num === 0) {
    return true;
  }
  var magnitude = Math.abs(num);
  return MIN_FLOAT16 < magnitude && magnitude < MAX_FLOAT16;
}
/**
 * Maps a `gl.getError()` status code to the name of the matching WebGL error
 * constant, or "Unknown error code <status>" when none matches.
 */
function getWebGLErrorMessage(gl, status) {
  // Checked in this order with strict equality, like a switch statement.
  var errorNames = ['NO_ERROR', 'INVALID_ENUM', 'INVALID_VALUE', 'INVALID_OPERATION', 'INVALID_FRAMEBUFFER_OPERATION', 'OUT_OF_MEMORY', 'CONTEXT_LOST_WEBGL'];
  for (var i = 0; i < errorNames.length; i++) {
    if (gl[errorNames[i]] === status) {
      return errorNames[i];
    }
  }
  return "Unknown error code ".concat(status);
}
/** Returns `gl.getExtension(extensionName)`, throwing when the extension is unsupported (null). */
function getExtensionOrThrow(gl, extensionName) {
  var lookup = function () {
    return gl.getExtension(extensionName);
  };
  return throwIfNull(gl, lookup, 'Extension "' + extensionName + '" not supported on this browser.');
}
/**
 * Creates and compiles a vertex shader from `vertexShaderSource`.
 * On compile failure the shader info log is printed and an Error is thrown.
 */
function createVertexShader$1(gl, vertexShaderSource) {
  var shader = throwIfNull(gl, function () {
    return gl.createShader(gl.VERTEX_SHADER);
  }, 'Unable to create vertex WebGLShader.');
  callAndCheck(gl, function () {
    return gl.shaderSource(shader, vertexShaderSource);
  });
  callAndCheck(gl, function () {
    return gl.compileShader(shader);
  });
  var compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
  if (compiled !== false) {
    return shader;
  }
  console.log(gl.getShaderInfoLog(shader));
  throw new Error('Failed to compile vertex shader.');
}
/**
 * Creates and compiles a fragment shader from `fragmentShaderSource`.
 *
 * When ENGINE_COMPILE_ONLY is set, the shader is returned without checking
 * the compile status. Otherwise a failed compile logs the annotated source and
 * throws.
 */
function createFragmentShader(gl, fragmentShaderSource) {
  var shader = throwIfNull(gl, function () {
    return gl.createShader(gl.FRAGMENT_SHADER);
  }, 'Unable to create fragment WebGLShader.');
  callAndCheck(gl, function () {
    return gl.shaderSource(shader, fragmentShaderSource);
  });
  callAndCheck(gl, function () {
    return gl.compileShader(shader);
  });
  if (env().get('ENGINE_COMPILE_ONLY')) {
    return shader;
  }
  var compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
  if (compiled !== false) {
    return shader;
  }
  logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(shader));
  throw new Error('Failed to compile fragment shader.');
}
var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
/**
 * Logs a shader compilation failure: the source with 1-based line numbers,
 * the first line of the info log, and the offending line highlighted through
 * a `%c` styled console message.
 *
 * The line number comes from the first `ERROR: <n>:<line>:` match in
 * `shaderInfoLog`. If none is found, the raw info log and source are logged.
 *
 * @param shaderSource Shader source that failed to compile.
 * @param shaderInfoLog Info log text describing the failure.
 */
function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
  // lineNumberRegex is a shared /g regex, so exec() resumes from the lastIndex
  // left by the previous successful call. Without this reset, alternate calls
  // could fail to find the line number.
  lineNumberRegex.lastIndex = 0;
  var lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
  if (lineNumberRegexResult == null) {
    console.log("Couldn't parse line number in error: ".concat(shaderInfoLog));
    console.log(shaderSource);
    return;
  }
  var lineNumber = +lineNumberRegexResult[1];
  var shaderLines = shaderSource.split('\n');
  // Width of the largest line number plus two columns of separation.
  var pad = shaderLines.length.toString().length + 2;
  var linesWithLineNumbers = shaderLines.map(function (line, index) {
    return rightPad((index + 1).toString(), pad) + line;
  });
  var maxLineLength = 0;
  for (var i = 0; i < linesWithLineNumbers.length; i++) {
    maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
  }
  var beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
  var errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
  var afterErrorLines = linesWithLineNumbers.slice(lineNumber);
  // The reported line may fall outside the source (line 0 or past the end).
  // In that case highlight an empty line instead of handing undefined to rightPad.
  var errorLineText = errorLine.length > 0 ? errorLine[0] : '';
  console.log(beforeErrorLines.join('\n'));
  console.log(shaderInfoLog.split('\n')[0]);
  console.log("%c ".concat(rightPad(errorLineText, maxLineLength)), 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
  console.log(afterErrorLines.join('\n'));
}
/** Creates a new WebGLProgram, throwing if `gl.createProgram()` yields null. */
function createProgram(gl) {
  var make = function () {
    return gl.createProgram();
  };
  return throwIfNull(gl, make, 'Unable to create WebGLProgram.');
}
/**
 * Links `program`. Unless ENGINE_COMPILE_ONLY is set, a failed link logs the
 * program info log and throws.
 */
function linkProgram(gl, program) {
  callAndCheck(gl, function () {
    return gl.linkProgram(program);
  });
  if (env().get('ENGINE_COMPILE_ONLY')) {
    return;
  }
  var linked = gl.getProgramParameter(program, gl.LINK_STATUS);
  if (linked !== false) {
    return;
  }
  console.log(gl.getProgramInfoLog(program));
  throw new Error('Failed to link vertex and fragment shaders.');
}
/**
 * Runs `gl.validateProgram`. This is effectively "if we `useProgram(program);
 * drawArrays();`, give feedback in the log about perf/correctness warnings or
 * errors that would occur". So set up all vertex/texture/sampler/uniform data
 * before calling it. A failed validation logs the program info log and throws.
 */
function validateProgram(gl, program) {
  callAndCheck(gl, function () {
    return gl.validateProgram(program);
  });
  var valid = gl.getProgramParameter(program, gl.VALIDATE_STATUS);
  if (valid !== false) {
    return;
  }
  console.log(gl.getProgramInfoLog(program));
  throw new Error('Shader program validation failed.');
}
/**
 * Creates a buffer, binds it to ARRAY_BUFFER and uploads `data` with
 * STATIC_DRAW usage. The buffer is left bound and is returned.
 */
function createStaticVertexBuffer(gl, data) {
  var vertexBuffer = throwIfNull(gl, function () {
    return gl.createBuffer();
  }, 'Unable to create WebGLBuffer');
  var bind = function () {
    return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
  };
  var upload = function () {
    return gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW);
  };
  callAndCheck(gl, bind);
  callAndCheck(gl, upload);
  return vertexBuffer;
}
/**
 * Creates a buffer, binds it to ELEMENT_ARRAY_BUFFER and uploads `data` with
 * STATIC_DRAW usage. The buffer is left bound and is returned.
 */
function createStaticIndexBuffer(gl, data) {
  var indexBuffer = throwIfNull(gl, function () {
    return gl.createBuffer();
  }, 'Unable to create WebGLBuffer');
  var bind = function () {
    return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, indexBuffer);
  };
  var upload = function () {
    return gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW);
  };
  callAndCheck(gl, bind);
  callAndCheck(gl, upload);
  return indexBuffer;
}
/** Default texture channel count: 1 on WebGL 2, 4 otherwise. */
function getNumChannels() {
  var isWebGL2 = env().getNumber('WEBGL_VERSION') === 2;
  return isWebGL2 ? 1 : 4;
}
/** Creates a new WebGLTexture, throwing if `gl.createTexture()` yields null. */
function createTexture(gl) {
  var make = function () {
    return gl.createTexture();
  };
  return throwIfNull(gl, make, 'Unable to create WebGLTexture.');
}
/**
 * Throws if either dimension is non-positive or exceeds
 * WEBGL_MAX_TEXTURE_SIZE.
 */
function validateTextureSize(width, height) {
  var maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
  var isNonPositive = width <= 0 || height <= 0;
  if (isNonPositive) {
    throw new Error('Requested texture size ' + "[".concat(width, "x").concat(height, "]") + ' is invalid.');
  }
  var exceedsMax = width > maxTextureSize || height > maxTextureSize;
  if (exceedsMax) {
    var requestedSize = "[".concat(width, "x").concat(height, "]");
    var maxSize = "[".concat(maxTextureSize, "x").concat(maxTextureSize, "]");
    throw new Error('Requested texture size ' + requestedSize + ' greater than WebGL maximum on this browser / GPU ' + maxSize + '.');
  }
}
/** Creates a new WebGLFramebuffer, throwing if `gl.createFramebuffer()` yields null. */
function createFramebuffer(gl) {
  var make = function () {
    return gl.createFramebuffer();
  };
  return throwIfNull(gl, make, 'Unable to create WebGLFramebuffer.');
}
/**
 * Points attribute `attribute` of `program` at `buffer` (FLOAT components,
 * not normalized) and enables it.
 *
 * @returns false when the attribute has no location (-1), e.g. because the
 *     compiler stripped it as unused; true once it has been bound.
 */
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
  var attribLoc = gl.getAttribLocation(program, attribute);
  var isStripped = attribLoc === -1;
  if (isStripped) {
    return false;
  }
  callAndCheck(gl, function () {
    return gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
  });
  callAndCheck(gl, function () {
    return gl.vertexAttribPointer(attribLoc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes);
  });
  callAndCheck(gl, function () {
    return gl.enableVertexAttribArray(attribLoc);
  });
  return true;
}
/** Validates `textureUnit`, makes it active and binds `texture` to its TEXTURE_2D target. */
function bindTextureUnit(gl, texture, textureUnit) {
  validateTextureUnit(gl, textureUnit);
  var activate = function () {
    return gl.activeTexture(gl.TEXTURE0 + textureUnit);
  };
  var bind = function () {
    return gl.bindTexture(gl.TEXTURE_2D, texture);
  };
  callAndCheck(gl, activate);
  callAndCheck(gl, bind);
}
/** Validates `textureUnit`, makes it active and unbinds its TEXTURE_2D target. */
function unbindTextureUnit(gl, textureUnit) {
  validateTextureUnit(gl, textureUnit);
  var activate = function () {
    return gl.activeTexture(gl.TEXTURE0 + textureUnit);
  };
  var unbind = function () {
    return gl.bindTexture(gl.TEXTURE_2D, null);
  };
  callAndCheck(gl, activate);
  callAndCheck(gl, unbind);
}
/** Returns the location of `uniformName` in `program`, throwing when it is absent (null). */
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
  var lookup = function () {
    return gl.getUniformLocation(program, uniformName);
  };
  return throwIfNull(gl, lookup, 'uniform "' + uniformName + '" not present in program.');
}
/** Returns the location of `uniformName` in `program` (null if absent), without checks. */
function getProgramUniformLocation(gl, program, uniformName) {
  var location = gl.getUniformLocation(program, uniformName);
  return location;
}
/** Binds `texture` to `textureUnit`, then sets the sampler uniform to that unit index. */
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
  var bindToUnit = function () {
    return bindTextureUnit(gl, texture, textureUnit);
  };
  var pointSamplerAtUnit = function () {
    return gl.uniform1i(uniformSamplerLocation, textureUnit);
  };
  callAndCheck(gl, bindToUnit);
  callAndCheck(gl, pointSamplerAtUnit);
}
/**
 * Binds the default framebuffer (null), then sets the viewport and scissor
 * box to cover the whole of `gl.canvas`.
 */
function bindCanvasToFramebuffer(gl) {
  var bindDefault = function () {
    return gl.bindFramebuffer(gl.FRAMEBUFFER, null);
  };
  var coverViewport = function () {
    return gl.viewport(0, 0, gl.canvas.width, gl.canvas.height);
  };
  var coverScissor = function () {
    return gl.scissor(0, 0, gl.canvas.width, gl.canvas.height);
  };
  callAndCheck(gl, bindDefault);
  callAndCheck(gl, coverViewport);
  callAndCheck(gl, coverScissor);
}
/** Binds `framebuffer` and attaches `texture` (mip level 0) as its COLOR_ATTACHMENT0. */
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
  var bindTarget = function () {
    return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
  };
  var attachTexture = function () {
    return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
  };
  callAndCheck(gl, bindTarget);
  callAndCheck(gl, attachTexture);
}
/** Binds `framebuffer` and detaches whatever texture is on its COLOR_ATTACHMENT0. */
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
  var bindTarget = function () {
    return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
  };
  var detachTexture = function () {
    return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0);
  };
  callAndCheck(gl, bindTarget);
  callAndCheck(gl, detachTexture);
}
/** Throws if the currently bound FRAMEBUFFER is not FRAMEBUFFER_COMPLETE. */
function validateFramebuffer(gl) {
  var status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
  if (status === gl.FRAMEBUFFER_COMPLETE) {
    return;
  }
  throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
}
/**
 * Maps a `checkFramebufferStatus` result to the name of the matching
 * incomplete/unsupported constant, or "unknown error <status>".
 */
function getFramebufferErrorMessage(gl, status) {
  // Checked in this order with strict equality, like a switch statement.
  var statusNames = ['FRAMEBUFFER_INCOMPLETE_ATTACHMENT', 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT', 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS', 'FRAMEBUFFER_UNSUPPORTED'];
  for (var i = 0; i < statusNames.length; i++) {
    if (gl[statusNames[i]] === status) {
      return statusNames[i];
    }
  }
  return "unknown error ".concat(status);
}
/**
 * Calls `returnTOrNull` through callAndCheck and returns its result.
 * @throws Error(failureMessage) when the result is null or undefined.
 */
function throwIfNull(gl, returnTOrNull, failureMessage) {
  var value = callAndCheck(gl, function () {
    return returnTOrNull();
  });
  if (value != null) {
    return value;
  }
  throw new Error(failureMessage);
}
/**
 * Throws unless `textureUnit` is a valid texture-unit index, i.e. within
 * [0, MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1].
 *
 * `gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS` is only the enum used to query that
 * limit. The previous code treated the enum value (0x8B4D) as the limit, so the
 * upper bound was effectively never enforced. The real limit is read through
 * `gl.getParameter`.
 */
function validateTextureUnit(gl, textureUnit) {
  var maxCombinedUnits = gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS);
  // getParameter yields null when the context is lost. The limit is then
  // unknown, so only the lower bound is enforced.
  var maxTextureUnit = typeof maxCombinedUnits === 'number' ? maxCombinedUnits - 1 : Infinity;
  if (textureUnit < 0 || textureUnit > maxTextureUnit) {
    var textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE".concat(maxTextureUnit, "]");
    throw new Error("textureUnit must be in ".concat(textureUnitRange, "."));
  }
}
/**
 * Product of the leading dimensions of `shape`, i.e. all but the last
 * `dimsToSkip` dims. `dimsToSkip` is an optional second argument that defaults to 2.
 */
function getBatchDim(shape) {
  var dimsToSkip = arguments[1];
  if (dimsToSkip === undefined) {
    dimsToSkip = 2;
  }
  var batchDims = shape.slice(0, shape.length - dimsToSkip);
  return sizeFromShape(batchDims);
}
/**
 * Returns [rows, cols] taken from the last two dims of `shape`; rows is 1 for
 * a rank-1 shape.
 * @throws Error for an empty shape.
 */
function getRowsCols(shape) {
  var rank = shape.length;
  if (rank === 0) {
    throw new Error('Cannot get rows and columns of an empty shape array.');
  }
  var cols = shape[rank - 1];
  var rows = rank > 1 ? shape[rank - 2] : 1;
  return [rows, cols];
}
/**
 * Collapses `shape` to [batch, rows, cols]. The empty shape and [1] are
 * treated as scalars and map to [1, 1, 1].
 */
function getShapeAs3D(shape) {
  var rank = shape.length;
  var isScalar = rank === 0 || rank === 1 && shape[0] === 1;
  if (isScalar) {
    return [1, 1, 1];
  }
  var batch = getBatchDim(shape);
  var rowsCols = getRowsCols(shape);
  return [batch, rowsCols[0], rowsCols[1]];
}
/**
 * Chooses a 2D texture shape for a tensor with logical shape `logShape`.
 *
 * First the (possibly squeezed) shape is folded into two dimensions that both
 * fit within WEBGL_MAX_TEXTURE_SIZE. A fallback squarish shape from
 * sizeToSquarishShape is used in two cases:
 * - no fold fits;
 * - the fold is a long, narrow strip: one edge is 1 (2 when packed) and the
 *   other exceeds the narrow-texture limit.
 *
 * @param logShape Logical tensor shape.
 * @param isPacked Optional second argument (default false): whether 2x2
 *     packing is used. Limits are doubled and the two innermost dims are
 *     rounded up to even sizes.
 * @returns The texture shape. For an unpacked rank-2 shape that fits and is
 *     not long-narrow, this is the `logShape` array itself.
 */
function getTextureShapeFromLogicalShape(logShape) {
  var isPacked = arguments[1] === undefined ? false : arguments[1];
  var maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
  var maxSizeForNarrowTex = env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
  if (maxSizeForNarrowTex === Infinity && env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) {
    maxSizeForNarrowTex = maxTexSize / 2;
  }
  var shape = logShape;
  if (isPacked) {
    maxTexSize *= 2;
    maxSizeForNarrowTex *= 2;
    // Only values from adjacent pairs of rows/cols within the same batch can
    // share a texel, so odd inner dims are rounded up (3 rows count as 4) to
    // account for half-empty texels.
    var innerStart = shape.length - 2;
    shape = shape.map(function (dim, axis) {
      return axis >= innerStart ? nearestLargerEven(dim) : dim;
    });
    // A packed texel is two rows tall, so a vector still needs height 2.
    if (shape.length === 1) {
      shape = [2, shape[0]];
    }
  }
  // A rank-2 shape already matches the physical layout, so it is not squeezed.
  if (shape.length !== 2) {
    shape = squeezeShape(shape).newShape;
  }
  var size = sizeFromShape(shape);
  var rank = shape.length;
  var textureShape = null;
  if (rank <= 1) {
    if (size <= maxTexSize) {
      textureShape = [1, size];
    }
  } else if (rank === 2) {
    if (shape[0] <= maxTexSize && shape[1] <= maxTexSize) {
      textureShape = shape;
    }
  } else if (rank === 3) {
    if (shape[0] * shape[1] <= maxTexSize && shape[2] <= maxTexSize) {
      textureShape = [shape[0] * shape[1], shape[2]];
    } else if (shape[0] <= maxTexSize && shape[1] * shape[2] <= maxTexSize) {
      textureShape = [shape[0], shape[1] * shape[2]];
    }
  } else if (rank === 4) {
    if (shape[0] * shape[1] * shape[2] <= maxTexSize && shape[3] <= maxTexSize) {
      textureShape = [shape[0] * shape[1] * shape[2], shape[3]];
    } else if (shape[0] <= maxTexSize && shape[1] * shape[2] * shape[3] <= maxTexSize) {
      textureShape = [shape[0], shape[1] * shape[2] * shape[3]];
    }
  }
  var isLongNarrowTex = false;
  if (textureShape != null) {
    var longestEdge = Math.max.apply(Math, textureShape);
    var shortestEdge = Math.min.apply(Math, textureShape);
    isLongNarrowTex = longestEdge > maxSizeForNarrowTex && shortestEdge <= (isPacked ? 2 : 1) && shortestEdge > 0;
  }
  if (textureShape != null && !isLongNarrowTex) {
    return textureShape;
  }
  if (!isPacked) {
    return sizeToSquarishShape(size);
  }
  // Packed: squarify in texel units so the inner dims stay even, then scale
  // the result back to channel units.
  var batchDim = getBatchDim(shape);
  var rows = 2;
  var cols = 2;
  if (shape.length) {
    var rowsCols = getRowsCols(shape);
    rows = rowsCols[0];
    cols = rowsCols[1];
  }
  var texelCount = batchDim * (rows / 2) * (cols / 2);
  return sizeToSquarishShape(texelCount).map(function (edge) {
    return edge * 2;
  });
}
/** True when `n` leaves no remainder modulo 2 (false for NaN and non-integers). */
function isEven(n) {
  var remainder = n % 2;
  return remainder === 0;
}
/**
 * This determines whether reshaping a packed texture requires rearranging
 * the data within the texture, assuming 2x2 packing. Only the last two
 * dimensions of each shape are compared.
 */
function isReshapeFree(shape1, shape2) {
  var inner1 = shape1.slice(-2);
  var inner2 = shape2.slice(-2);
  if (arraysEqual(inner1, inner2)) {
    return true;
  }
  // One of the shapes is a scalar.
  if (inner1.length === 0 || inner2.length === 0) {
    return true;
  }
  // Empty tensors hold no data to rearrange.
  var hasZeroDim = inner1[0] === 0 || inner1[1] === 0 || inner2[0] === 0 || inner2[1] === 0;
  if (hasZeroDim) {
    return true;
  }
  if (inner1.length !== inner2.length) {
    // One of the shapes is a vector.
    var cols1 = inner1[inner1.length - 1];
    var cols2 = inner2[inner2.length - 1];
    if (cols1 === cols2) {
      return true;
    }
    var bothColsEven = isEven(cols1) && isEven(cols2);
    if (bothColsEven && (inner1[0] === 1 || inner2[0] === 1)) {
      return true;
    }
  }
  return inner1[1] === inner2[1] && isEven(inner1[0]) && isEven(inner2[0]);
}
110836 // We cache webgl params because the environment gets reset between
110837 // unit tests and we don't want to constantly query the WebGLContext for
110838 // MAX_TEXTURE_SIZE.
110839 var MAX_TEXTURE_SIZE;
110840 var MAX_TEXTURES_IN_SHADER;
110841 function getWebGLMaxTextureSize(webGLVersion) {
110842 if (MAX_TEXTURE_SIZE == null) {
110843 var gl = getWebGLContext(webGLVersion);
110844 MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
110845 }
110846 return MAX_TEXTURE_SIZE;
110847 }
110848 function resetMaxTextureSize() {
110849 MAX_TEXTURE_SIZE = null;
110850 }
110851 function resetMaxTexturesInShader() {
110852 MAX_TEXTURES_IN_SHADER = null;
110853 }
110854 function getMaxTexturesInShader(webGLVersion) {
110855 if (MAX_TEXTURES_IN_SHADER == null) {
110856 var gl = getWebGLContext(webGLVersion);
110857 MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
110858 }
110859 // We cap at 16 to avoid spurious runtime "memory exhausted" error.
110860 return Math.min(16, MAX_TEXTURES_IN_SHADER);
110861 }
110862 function getWebGLDisjointQueryTimerVersion(webGLVersion) {
110863 if (webGLVersion === 0) {
110864 return 0;
110865 }
110866 var queryTimerVersion;
110867 var gl = getWebGLContext(webGLVersion);
110868 if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') && webGLVersion === 2) {
110869 queryTimerVersion = 2;
110870 } else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
110871 queryTimerVersion = 1;
110872 } else {
110873 queryTimerVersion = 0;
110874 }
110875 return queryTimerVersion;
110876 }
110877 function hasExtension(gl, extensionName) {
110878 var ext = gl.getExtension(extensionName);
110879 return ext != null;
110880 }
110881 function isWebGLVersionEnabled(webGLVersion) {
110882 try {
110883 var gl = getWebGLContext(webGLVersion);
110884 if (gl != null) {
110885 return true;
110886 }
110887 } catch (e) {
110888 console.log('Error when getting WebGL context: ', e);
110889 return false;
110890 }
110891 return false;
110892 }
110893 function isCapableOfRenderingToFloatTexture(webGLVersion) {
110894 if (webGLVersion === 0) {
110895 return false;
110896 }
110897 var gl = getWebGLContext(webGLVersion);
110898 if (webGLVersion === 1) {
110899 if (!hasExtension(gl, 'OES_texture_float')) {
110900 return false;
110901 }
110902 } else {
110903 if (!hasExtension(gl, 'EXT_color_buffer_float')) {
110904 return false;
110905 }
110906 }
110907 var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
110908 return isFrameBufferComplete;
110909 }
110910 /**
110911 * Check if we can download values from a float/half-float texture.
110912 *
110913 * Note that for performance reasons we use binding a texture to a framebuffer
110914 * as a proxy for ability to download float values later using readPixels. The
110915 * texture params of this texture will not match those in readPixels exactly
110916 * but if we are unable to bind some kind of float texture to the frameBuffer
110917 * then we definitely will not be able to read float values from it.
110918 */
110919 function isDownloadFloatTextureEnabled(webGLVersion) {
110920 if (webGLVersion === 0) {
110921 return false;
110922 }
110923 var gl = getWebGLContext(webGLVersion);
110924 if (webGLVersion === 1) {
110925 if (!hasExtension(gl, 'OES_texture_float')) {
110926 return false;
110927 }
110928 if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
110929 return false;
110930 }
110931 } else {
110932 if (hasExtension(gl, 'EXT_color_buffer_float')) {
110933 return createFloatTextureAndBindToFramebuffer(gl);
110934 }
110935 var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
110936 if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
110937 var textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
110938 return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
110939 }
110940 return false;
110941 }
110942 var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
110943 return isFrameBufferComplete;
110944 }
110945 function createFloatTextureAndBindToFramebuffer(gl) {
110946 var texConfig = getTextureConfig(gl);
110947 var texture = gl.createTexture();
110948 gl.bindTexture(gl.TEXTURE_2D, texture);
110949 var width = 1;
110950 var height = 1;
110951 gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
110952 var frameBuffer = gl.createFramebuffer();
110953 gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
110954 gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
110955 var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
110956 gl.bindTexture(gl.TEXTURE_2D, null);
110957 gl.bindFramebuffer(gl.FRAMEBUFFER, null);
110958 gl.deleteTexture(texture);
110959 gl.deleteFramebuffer(frameBuffer);
110960 return isFrameBufferComplete;
110961 }
110962 function createHalfFloatTextureAndBindToFramebuffer(
110963 // tslint:disable-next-line:no-any
110964 gl, textureHalfFloatExtension) {
110965 var texConfig = getTextureConfig(gl, textureHalfFloatExtension);
110966 var texture = gl.createTexture();
110967 gl.bindTexture(gl.TEXTURE_2D, texture);
110968 var width = 1;
110969 var height = 1;
110970 gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
110971 var frameBuffer = gl.createFramebuffer();
110972 gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
110973 gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
110974 var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
110975 gl.bindTexture(gl.TEXTURE_2D, null);
110976 gl.bindFramebuffer(gl.FRAMEBUFFER, null);
110977 gl.deleteTexture(texture);
110978 gl.deleteFramebuffer(frameBuffer);
110979 return isFrameBufferComplete;
110980 }
110981 function isWebGLFenceEnabled(webGLVersion) {
110982 if (webGLVersion !== 2) {
110983 return false;
110984 }
110985 var gl = getWebGLContext(webGLVersion);
110986 // tslint:disable-next-line:no-any
110987 var isEnabled = gl.fenceSync != null;
110988 return isEnabled;
110989 }
110990 function assertNotComplex(tensor, opName) {
110991 if (!Array.isArray(tensor)) {
110992 tensor = [tensor];
110993 }
110994 tensor.forEach(function (t) {
110995 if (t != null) {
110996 assert$1(t.dtype !== 'complex64', function () {
110997 return "".concat(opName, " does not support complex64 tensors ") + 'in the WebGL backend.';
110998 });
110999 }
111000 });
111001 }
111002
111003 var webgl_util = {
111004 __proto__: null,
111005 assertNotComplex: assertNotComplex,
111006 bindCanvasToFramebuffer: bindCanvasToFramebuffer,
111007 bindColorTextureToFramebuffer: bindColorTextureToFramebuffer,
111008 bindTextureToProgramUniformSampler: bindTextureToProgramUniformSampler,
111009 bindTextureUnit: bindTextureUnit,
111010 bindVertexBufferToProgramAttribute: bindVertexBufferToProgramAttribute,
111011 callAndCheck: callAndCheck,
111012 canBeRepresented: canBeRepresented,
111013 createFragmentShader: createFragmentShader,
111014 createFramebuffer: createFramebuffer,
111015 createProgram: createProgram,
111016 createStaticIndexBuffer: createStaticIndexBuffer,
111017 createStaticVertexBuffer: createStaticVertexBuffer,
111018 createTexture: createTexture,
111019 createVertexShader: createVertexShader$1,
111020 getBatchDim: getBatchDim,
111021 getExtensionOrThrow: getExtensionOrThrow,
111022 getFramebufferErrorMessage: getFramebufferErrorMessage,
111023 getMaxTexturesInShader: getMaxTexturesInShader,
111024 getNumChannels: getNumChannels,
111025 getProgramUniformLocation: getProgramUniformLocation,
111026 getProgramUniformLocationOrThrow: getProgramUniformLocationOrThrow,
111027 getRowsCols: getRowsCols,
111028 getShapeAs3D: getShapeAs3D,
111029 getTextureShapeFromLogicalShape: getTextureShapeFromLogicalShape,
111030 getWebGLDisjointQueryTimerVersion: getWebGLDisjointQueryTimerVersion,
111031 getWebGLErrorMessage: getWebGLErrorMessage,
111032 getWebGLMaxTextureSize: getWebGLMaxTextureSize,
111033 hasExtension: hasExtension,
111034 isCapableOfRenderingToFloatTexture: isCapableOfRenderingToFloatTexture,
111035 isDownloadFloatTextureEnabled: isDownloadFloatTextureEnabled,
111036 isReshapeFree: isReshapeFree,
111037 isWebGLFenceEnabled: isWebGLFenceEnabled,
111038 isWebGLVersionEnabled: isWebGLVersionEnabled,
111039 linkProgram: linkProgram,
111040 logShaderSourceAndInfoLog: logShaderSourceAndInfoLog,
111041 resetMaxTextureSize: resetMaxTextureSize,
111042 resetMaxTexturesInShader: resetMaxTexturesInShader,
111043 unbindColorTextureFromFramebuffer: unbindColorTextureFromFramebuffer,
111044 unbindTextureUnit: unbindTextureUnit,
111045 validateFramebuffer: validateFramebuffer,
111046 validateProgram: validateProgram,
111047 validateTextureSize: validateTextureSize
111048 };
111049
111050 /**
111051 * @license
111052 * Copyright 2019 Google LLC. All Rights Reserved.
111053 * Licensed under the Apache License, Version 2.0 (the "License");
111054 * you may not use this file except in compliance with the License.
111055 * You may obtain a copy of the License at
111056 *
111057 * http://www.apache.org/licenses/LICENSE-2.0
111058 *
111059 * Unless required by applicable law or agreed to in writing, software
111060 * distributed under the License is distributed on an "AS IS" BASIS,
111061 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
111062 * See the License for the specific language governing permissions and
111063 * limitations under the License.
111064 * =============================================================================
111065 */
111066 var ENV = env();
111067 /**
111068 * This file contains WebGL-specific flag registrations.
111069 */
111070 /**
111071 * True if WebGL is supported.
111072 */
111073 ENV.registerFlag('HAS_WEBGL', function () {
111074 return ENV.getNumber('WEBGL_VERSION') > 0;
111075 });
111076 /** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */
111077 ENV.registerFlag('WEBGL_VERSION', function () {
111078 if (isWebGLVersionEnabled(2)) {
111079 return 2;
111080 } else if (isWebGLVersionEnabled(1)) {
111081 return 1;
111082 }
111083 return 0;
111084 });
111085 /** Whether to check for numerical representation problems. */
111086 ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', function () {
111087 return false;
111088 });
111089 ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', function () {
111090 return ENV.get('WEBGL_VERSION') === 2;
111091 });
111092 /** Whether the WebGL backend will sometimes forward ops to the CPU. */
111093 ENV.registerFlag('WEBGL_CPU_FORWARD', function () {
111094 return true;
111095 });
111096 /** Whether the WebGL backend will always use f16 textures for rendering. */
111097 ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', function () {
111098 return false;
111099 });
111100 /** Whether to turn all packing related flags on. */
111101 ENV.registerFlag('WEBGL_PACK', function () {
111102 return ENV.getBool('HAS_WEBGL');
111103 });
111104 /** Whether we will pack the batchnormalization op. */
111105 ENV.registerFlag('WEBGL_PACK_NORMALIZATION', function () {
111106 return ENV.getBool('WEBGL_PACK');
111107 });
111108 /** Whether we will pack the clip op. */
111109 ENV.registerFlag('WEBGL_PACK_CLIP', function () {
111110 return ENV.getBool('WEBGL_PACK');
111111 });
111112 /** Whether we will pack the depthwise conv op. */
111113 ENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', function () {
111114 return ENV.getBool('WEBGL_PACK');
111115 });
111116 /** Whether we will pack binary ops. */
111117 ENV.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', function () {
111118 return ENV.getBool('WEBGL_PACK');
111119 });
111120 /** Whether we will pack unary ops. */
111121 ENV.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', function () {
111122 return ENV.getBool('WEBGL_PACK');
111123 });
111124 /** Whether we will pack array ops. */
111125 ENV.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', function () {
111126 return ENV.getBool('WEBGL_PACK');
111127 });
111128 /** Whether we will pack image ops. */
111129 ENV.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', function () {
111130 return ENV.getBool('WEBGL_PACK');
111131 });
111132 /** Whether we will pack reduce ops. */
111133 ENV.registerFlag('WEBGL_PACK_REDUCE', function () {
111134 return ENV.getBool('WEBGL_PACK');
111135 });
111136 /** Whether packed WebGL kernels lazily unpack their outputs. */
111137 ENV.registerFlag('WEBGL_LAZILY_UNPACK', function () {
111138 return ENV.getBool('WEBGL_PACK');
111139 });
111140 /** Whether we will use the im2col algorithm to speed up convolutions. */
111141 ENV.registerFlag('WEBGL_CONV_IM2COL', function () {
111142 return ENV.getBool('WEBGL_PACK');
111143 });
111144 /** Whether we will pack conv2dTranspose op. */
111145 ENV.registerFlag('WEBGL_PACK_CONV2DTRANSPOSE', function () {
111146 return ENV.getBool('WEBGL_PACK');
111147 });
111148 /** The maximum texture dimension. */
111149 ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', function () {
111150 return getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION'));
111151 });
111152 /** The maximum texture dimension. */
111153 ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', function () {
111154 return getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION'));
111155 });
111156 /**
111157 * The disjoint_query_timer extension version.
111158 * 0: disabled, 1: EXT_disjoint_timer_query, 2:
111159 * EXT_disjoint_timer_query_webgl2.
111160 * In Firefox with WebGL 2.0,
111161 * EXT_disjoint_timer_query_webgl2 is not available, so we must use the
111162 * WebGL 1.0 extension.
111163 */
111164 ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', function () {
111165 var webGLVersion = ENV.getNumber('WEBGL_VERSION');
111166 if (webGLVersion === 0) {
111167 return 0;
111168 }
111169 return getWebGLDisjointQueryTimerVersion(webGLVersion);
111170 });
111171 /**
111172 * Whether the timer object from the disjoint_query_timer extension gives
111173 * timing information that is reliable.
111174 */
111175 ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', function () {
111176 return ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 && !isMobile();
111177 });
111178 /**
111179 * Whether the device is physically capable of rendering to float32 textures.
111180 */
111181 ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', function () {
111182 return isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION'));
111183 });
111184 /**
111185 * Whether rendering to float32 textures is enabled. If disabled, renders to
111186 * float16 textures.
111187 */
111188 ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', function () {
111189 return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ? false : ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
111190 });
111191 /**
111192 * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
111193 * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.
111194 */
111195 ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', function () {
111196 return isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION'));
111197 });
111198 /** Whether the fence API is available. */
111199 ENV.registerFlag('WEBGL_FENCE_API_ENABLED', function () {
111200 return isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION'));
111201 });
111202 /**
111203 * Tensors with size <= than this will be uploaded as uniforms, not textures.
111204 */
111205 ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', function () {
111206 // Use uniform uploads only when 32bit floats are supported. In
111207 // 16bit
111208 // environments there are problems with comparing a 16bit texture value
111209 // with a 32bit uniform value.
111210 var useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
111211 return useUniforms ? 4 : 0;
111212 });
111213 /**
111214 * If the total number of bytes allocated on the GPU is greater than this
111215 * number, we will aggressively delete textures upon disposal with
111216 * gl.deleteMatrixTexture, rather than making them available for reuse.
111217 *
111218 * Default value -1 indicates that we will never aggressively delete textures.
111219 */
111220 ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', function () {
111221 return -1;
111222 }, function (threshold) {
111223 if (!(typeof threshold === 'number')) {
111224 throw new Error('WEBGL_DELETE_TEXTURE_THRESHOLD must be a number but ' + "got ".concat(threshold, "."));
111225 }
111226 if (threshold < 0 && threshold !== -1) {
111227 throw new Error("WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never " + "delete) or at least 0, but got ".concat(threshold, "."));
111228 }
111229 });
111230 /**
111231 * Trigger a manual GL command flush if the threshold of time has passed since
111232 * previous Kernel execution. This can be useful for Andorid device where GL
111233 * command flush are delayed un til the end of javascript task. This value is
111234 * measured in millisecond. Typically you want to set this value to close to 1.
111235 *
111236 * Default value 1 for mobile chrome, and -1 for rest cases. -1 indicates that
111237 * we will not enforce manual flush and depend on system default flush schedule.
111238 */
111239 ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', function () {
111240 return isMobile() ? 1 : -1;
111241 }, function (threshold) {
111242 if (!(typeof threshold === 'number')) {
111243 throw new Error('WEBGL_FLUSH_THRESHOLD must be a number but got ' + "".concat(threshold, "."));
111244 }
111245 if (threshold < 0 && threshold !== -1) {
111246 throw new Error("WEBGL_FLUSH_THRESHOLD must be -1 (indicating never " + "manual flush) or at least 0, but got ".concat(threshold, "."));
111247 }
111248 });
111249 /**
111250 * Threshold for input tensor size that determines whether WebGL backend will
111251 * delegate computation to CPU.
111252 *
111253 * Default value is 128.
111254 */
111255 ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', function () {
111256 return 128;
111257 });
111258 /** Whether we will use shapes uniforms. */
111259 ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', function () {
111260 return false;
111261 });
111262 /**
111263 * Threshold for last dimension of input tensor that determines whether
111264 * WebGL backend for the Top K op will delegate computation to CPU. If input
111265 * is smaller than threshold then CPU will be used
111266 *
111267 * Default value is 100000.
111268 */
111269 ENV.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', function () {
111270 return 100000;
111271 });
111272 /**
111273 * Threshold for K that determines whether
111274 * WebGL backend for the Top K op will delegate computation to CPU. If k
111275 * is larger than threshold then CPU will be used
111276 *
111277 * Default value is 128.
111278 */
111279 ENV.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', function () {
111280 return 128;
111281 });
111282 /** Whether we will use the experimental conv op. */
111283 ENV.registerFlag('WEBGL_EXP_CONV', function () {
111284 return false;
111285 });
111286 /**
111287 * If the device performance is low or if no hardware GPU is available, whether
111288 * software WebGL will be used.
111289 */
111290 ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', function () {
111291 return ENV.getBool('IS_TEST');
111292 });
111293 /**
111294 * For narrow texture (physical height or physical width is 1), if the length of
111295 * any texture edges exceed the threshold, the texture will be reshaped to be
111296 * more squarish.
111297 *
111298 * This flag is used to help some GPUs that could not provide correct
111299 * interpolations for long skinny triangles. We found Mali GPU probably has this
111300 * problem: https://github.com/tensorflow/tfjs/issues/6775.
111301 */
111302 ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', function () {
111303 return Infinity;
111304 });
111305 /**
111306 * If the flag is set to true, the max size of the narrow texture will be auto
111307 * computed and it will be considerred as a threshold to reshape the narrow
111308 * texture to be more squarish.
111309 *
111310 * This flag is used to help some GPUs that could not provide correct
111311 * interpolations for long skinny triangles. We found Mali GPU probably has this
111312 * problem: https://github.com/tensorflow/tfjs/issues/6775.
111313 */
111314 ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', function () {
111315 return false;
111316 });
111317 /**
111318 * Whether to use the customized isnan. It's only useful for webgl2 since webgl1
111319 * doesn't have the builtin isnan.
111320 */
111321 ENV.registerFlag('WEBGL2_ISNAN_CUSTOM', function () {
111322 return false;
111323 });
111324 /** Experimental flag, whether enter compile only phase. */
111325 ENV.registerFlag('ENGINE_COMPILE_ONLY', function () {
111326 return false;
111327 });
111328
111329 /**
111330 * @license
111331 * Copyright 2018 Google LLC. All Rights Reserved.
111332 * Licensed under the Apache License, Version 2.0 (the "License");
111333 * you may not use this file except in compliance with the License.
111334 * You may obtain a copy of the License at
111335 *
111336 * http://www.apache.org/licenses/LICENSE-2.0
111337 *
111338 * Unless required by applicable law or agreed to in writing, software
111339 * distributed under the License is distributed on an "AS IS" BASIS,
111340 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
111341 * See the License for the specific language governing permissions and
111342 * limitations under the License.
111343 * =============================================================================
111344 */
111345 function getGlslDifferences() {
111346 var version;
111347 var attribute;
111348 var varyingVs;
111349 var varyingFs;
111350 var texture2D;
111351 var output;
111352 var defineOutput;
111353 var defineSpecialNaN;
111354 var defineSpecialInf;
111355 var defineRound;
111356 if (env().getNumber('WEBGL_VERSION') === 2) {
111357 version = '#version 300 es';
111358 attribute = 'in';
111359 varyingVs = 'out';
111360 varyingFs = 'in';
111361 texture2D = 'texture';
111362 output = 'outputColor';
111363 defineOutput = 'out vec4 outputColor;';
111364 // Use custom isnan definition to work across differences between
111365 // implementations on various platforms. While this should happen in ANGLE
111366 // we still see differences between android and windows (on chrome) when
111367 // using isnan directly. Since WebGL2 supports uint type and
111368 // floatBitsToUinT built-in function, we could implment isnan following
111369 // IEEE 754 rules.
111370 // NaN defination in IEEE 754-1985 is :
111371 // - sign = either 0 or 1.
111372 // - biased exponent = all 1 bits.
111373 // - fraction = anything except all 0 bits (since all 0 bits represents
111374 // infinity).
111375 // https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers
111376 defineSpecialNaN = env().getBool('WEBGL2_ISNAN_CUSTOM') ? "\n bool isnan_custom(float val) {\n uint floatToUint = floatBitsToUint(val);\n return (floatToUint & 0x7fffffffu) > 0x7f800000u;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n " : '';
111377 // In webgl 2 we do not need to specify a custom isinf so there is no
111378 // need for a special INFINITY constant.
111379 defineSpecialInf = "";
111380 defineRound = "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n ";
111381 } else {
111382 version = '';
111383 attribute = 'attribute';
111384 varyingVs = 'varying';
111385 varyingFs = 'varying';
111386 texture2D = 'texture2D';
111387 output = 'gl_FragColor';
111388 defineOutput = '';
111389 // WebGL1 has no built in isnan so we define one here.
111390 defineSpecialNaN = "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n ";
111391 defineSpecialInf = "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n ";
111392 defineRound = "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n ";
111393 }
111394 return {
111395 version: version,
111396 attribute: attribute,
111397 varyingVs: varyingVs,
111398 varyingFs: varyingFs,
111399 texture2D: texture2D,
111400 output: output,
111401 defineOutput: defineOutput,
111402 defineSpecialNaN: defineSpecialNaN,
111403 defineSpecialInf: defineSpecialInf,
111404 defineRound: defineRound
111405 };
111406 }
111407
111408 /**
111409 * @license
111410 * Copyright 2018 Google LLC. All Rights Reserved.
111411 * Licensed under the Apache License, Version 2.0 (the "License");
111412 * you may not use this file except in compliance with the License.
111413 * You may obtain a copy of the License at
111414 *
111415 * http://www.apache.org/licenses/LICENSE-2.0
111416 *
111417 * Unless required by applicable law or agreed to in writing, software
111418 * distributed under the License is distributed on an "AS IS" BASIS,
111419 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
111420 * See the License for the specific language governing permissions and
111421 * limitations under the License.
111422 * =============================================================================
111423 */
111424 /**
111425 * Produces GLSL code that derives logical coordinates from a flat
111426 * index. The code performs integer division with each stride and decrements
111427 * the index until the index equals the final dimension coordinate.
111428 */
111429 function getLogicalCoordinatesFromFlatIndex(coords, shape) {
111430 var index = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'index';
111431 var strides = computeStrides(shape);
111432 return strides.map(function (stride, i) {
111433 var line1 = "int ".concat(coords[i], " = ").concat(index, " / ").concat(stride);
111434 var line2 = i === strides.length - 1 ? "int ".concat(coords[i + 1], " = ").concat(index, " - ").concat(coords[i], " * ").concat(stride) : "index -= ".concat(coords[i], " * ").concat(stride);
111435 return "".concat(line1, "; ").concat(line2, ";");
111436 }).join('');
111437 }
111438 function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape) {
111439 var index = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'index';
111440 var strides = computeStrides(shape);
111441 return strides.map(function (_, i) {
111442 var line1 = "int ".concat(coords[i], " = ").concat(index, " / outShapeStrides[").concat(i, "]");
111443 var line2 = i === strides.length - 1 ? "int ".concat(coords[i + 1], " = ").concat(index, " - ").concat(coords[i], " * outShapeStrides[").concat(i, "]") : "index -= ".concat(coords[i], " * outShapeStrides[").concat(i, "]");
111444 return "".concat(line1, "; ").concat(line2, ";");
111445 }).join('');
111446 }
111447 // Produces GLSL code that computes strides.
111448 function symbolicallyComputeStrides(indicesArr, variableName) {
111449 var numCoords = indicesArr.length;
111450 var shape = indicesArr.map(function (d) {
111451 return "".concat(variableName, "[").concat(d, "]");
111452 });
111453 var strides = new Array(numCoords - 1);
111454 strides[numCoords - 2] = shape[numCoords - 1];
111455 for (var i = numCoords - 3; i >= 0; --i) {
111456 strides[i] = "(".concat(strides[i + 1], " * ").concat(shape[i + 1], ")");
111457 }
111458 return strides;
111459 }
111460 function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName) {
111461 var index = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'index';
111462 var indicesArray = coords.map(function (_, i) {
111463 return i;
111464 });
111465 var strides = symbolicallyComputeStrides(indicesArray, variableName);
111466 return strides.map(function (_, i) {
111467 var line1 = "int ".concat(coords[i], " = ").concat(index, " / ").concat(strides[i]);
111468 var line2 = i === strides.length - 1 ? "int ".concat(coords[i + 1], " = ").concat(index, " - ").concat(coords[i], " * ").concat(strides[i]) : "index -= ".concat(coords[i], " * ").concat(strides[i]);
111469 return "".concat(line1, "; ").concat(line2, ";");
111470 }).join('');
111471 }
111472 function buildVec(x) {
111473 if (x.length === 1) {
111474 return "".concat(x[0]);
111475 }
111476 return "vec".concat(x.length, "(").concat(x.join(','), ")");
111477 }
111478 /**
111479 * Produces GLSL code that computes the dot product of the input x and y
111480 * vectors. Handles splitting inputs into increments of vec4s when necessary.
111481 */
111482 function dotify(x, y) {
111483 if (x.length !== y.length) {
111484 throw new Error("Vectors to be dotted must be of the same length -" + "got ".concat(x.length, " and ").concat(y.length));
111485 }
111486 var slices = [];
111487 var nearestVec4 = Math.floor(x.length / 4);
111488 var nearestVec4Remainder = x.length % 4;
111489 for (var i = 0; i < nearestVec4; i++) {
111490 var xSlice = x.slice(i * 4, i * 4 + 4);
111491 var ySlice = y.slice(i * 4, i * 4 + 4);
111492 slices.push("".concat(buildVec(xSlice), ", ").concat(buildVec(ySlice)));
111493 }
111494 if (nearestVec4Remainder !== 0) {
111495 var _xSlice = x.slice(nearestVec4 * 4);
111496 var _ySlice = y.slice(nearestVec4 * 4);
111497 if (_xSlice.length === 1) {
111498 _xSlice = _xSlice.map(function (d) {
111499 return "float(".concat(d, ")");
111500 });
111501 _ySlice = _ySlice.map(function (d) {
111502 return "float(".concat(d, ")");
111503 });
111504 }
111505 slices.push("".concat(buildVec(_xSlice), ", ").concat(buildVec(_ySlice)));
111506 }
111507 return slices.map(function (d, i) {
111508 return "dot(".concat(d, ")");
111509 }).join('+');
111510 }
111511 /**
111512 * Produces GLSL that computes the flat index from 3D coordinates.
111513 */
111514 function getFlatIndexFrom3D(shape) {
111515 var strides = computeStrides(shape).map(function (d) {
111516 return d.toString();
111517 });
111518 return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * ".concat(strides[0], " + coords.y * ").concat(strides[1], " + coords.z;\n }\n");
111519 }
111520 function getFlatIndexFrom3DOutput() {
111521 return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;\n }\n";
111522 }
  // GLSL helper `encode_float(v)`: packs a highp float into a lowp vec4 whose
  // channels are divided by 255.0 before returning. Per the snippet:
  //  - NaN returns vec4(255, 255, 255, 255) (note: NOT divided by 255);
  //  - |v| < FLOAT_MIN returns all zeros;
  //  - v > FLOAT_MAX and v < -FLOAT_MAX return fixed overflow encodings;
  //  - otherwise the mantissa bits go into c[0], c[1] and the low 7 bits of
  //    c[2], the low exponent bit into c[2]'s top bit, the remaining biased
  //    exponent bits into c[3], and the sign into c[3]'s top bit (+128).
  // NOTE(review): this looks like the byte layout of an IEEE-754 float32 in
  // little-endian order — confirm against the matching decoder before relying on it.
  var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n";

  // Local alias for the bundled `getBroadcastDims$1` helper (the `$1` suffix
  // presumably comes from the bundler de-duplicating names — verify).
  var getBroadcastDims = getBroadcastDims$1;
111526 function makeShader(inputsInfo, outputShape, program) {
111527 var prefixSnippets = [];
111528 inputsInfo.forEach(function (x) {
111529 var size = sizeFromShape(x.shapeInfo.logicalShape);
111530 // Snippet when we decided to upload the values as uniform.
111531 if (x.shapeInfo.isUniform) {
111532 prefixSnippets.push("uniform float ".concat(x.name).concat(size > 1 ? "[".concat(size, "]") : '', ";"));
111533 } else {
111534 prefixSnippets.push("uniform sampler2D ".concat(x.name, ";"));
111535 prefixSnippets.push("uniform int offset".concat(x.name, ";"));
111536 }
111537 if (program.enableShapeUniforms) {
111538 var _getUniformInfoFromSh = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape),
111539 uniformShape = _getUniformInfoFromSh.uniformShape;
111540 switch (uniformShape.length) {
111541 case 1:
111542 prefixSnippets.push("uniform int ".concat(x.name, "Shape;"));
111543 break;
111544 case 2:
111545 prefixSnippets.push("uniform ivec2 ".concat(x.name, "Shape;"));
111546 break;
111547 case 3:
111548 prefixSnippets.push("uniform ivec3 ".concat(x.name, "Shape;"));
111549 break;
111550 case 4:
111551 prefixSnippets.push("uniform ivec4 ".concat(x.name, "Shape;"));
111552 break;
111553 default:
111554 break;
111555 }
111556 prefixSnippets.push("uniform ivec2 ".concat(x.name, "TexShape;"));
111557 }
111558 });
111559 if (program.enableShapeUniforms) {
111560 switch (outputShape.logicalShape.length) {
111561 case 1:
111562 prefixSnippets.push("uniform int outShape;");
111563 break;
111564 case 2:
111565 prefixSnippets.push("uniform ivec2 outShape;");
111566 prefixSnippets.push("uniform int outShapeStrides;");
111567 break;
111568 case 3:
111569 prefixSnippets.push("uniform ivec3 outShape;");
111570 prefixSnippets.push("uniform ivec2 outShapeStrides;");
111571 break;
111572 case 4:
111573 prefixSnippets.push("uniform ivec4 outShape;");
111574 prefixSnippets.push("uniform ivec3 outShapeStrides;");
111575 break;
111576 default:
111577 break;
111578 }
111579 prefixSnippets.push("uniform ivec2 outTexShape;");
111580 }
111581 if (program.customUniforms) {
111582 program.customUniforms.forEach(function (d) {
111583 prefixSnippets.push("uniform ".concat(d.type, " ").concat(d.name).concat(d.arrayIndex ? "[".concat(d.arrayIndex, "]") : '', ";"));
111584 });
111585 }
111586 var inputPrefixSnippet = prefixSnippets.join('\n');
111587 var inputSamplingSnippet = inputsInfo.map(function (x) {
111588 return getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms);
111589 }).join('\n');
111590 var outTexShape = outputShape.texShape;
111591 var glsl = getGlslDifferences();
111592 var floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
111593 var outputSamplingSnippet;
111594 var floatTextureSetOutputSnippet;
111595 var shaderPrefix = getShaderPrefix(glsl);
111596 if (outputShape.isPacked) {
111597 outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
111598 floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
111599 } else {
111600 outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
111601 floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
111602 }
111603 if (program.packedInputs) {
111604 shaderPrefix += SHADER_PACKED_PREFIX;
111605 }
111606 var source = [shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet, inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, program.userCode].join('\n');
111607 return source;
111608 }
111609 function getSamplerFromInInfo(inInfo) {
111610 var enableShapeUniforms = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
111611 var shape = inInfo.shapeInfo.logicalShape;
111612 switch (shape.length) {
111613 case 0:
111614 return getSamplerScalar(inInfo, enableShapeUniforms);
111615 case 1:
111616 return getSampler1D(inInfo, enableShapeUniforms);
111617 case 2:
111618 return getSampler2D(inInfo, enableShapeUniforms);
111619 case 3:
111620 return getSampler3D(inInfo, enableShapeUniforms);
111621 case 4:
111622 return getSampler4D(inInfo, enableShapeUniforms);
111623 case 5:
111624 return getSampler5D(inInfo);
111625 case 6:
111626 return getSampler6D(inInfo);
111627 default:
111628 throw new Error("".concat(shape.length, "-D input sampling") + " is not yet supported");
111629 }
111630 }
111631 function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) {
111632 var shape = inInfo.shapeInfo.logicalShape;
111633 switch (shape.length) {
111634 case 0:
111635 return getPackedSamplerScalar(inInfo);
111636 case 1:
111637 return getPackedSampler1D(inInfo, enableShapeUniforms);
111638 case 2:
111639 return getPackedSampler2D(inInfo, enableShapeUniforms);
111640 case 3:
111641 return getPackedSampler3D(inInfo, enableShapeUniforms);
111642 default:
111643 return getPackedSamplerND(inInfo, enableShapeUniforms);
111644 }
111645 }
111646 function getInputSamplingSnippet(inInfo, outShapeInfo) {
111647 var usesPackedTextures = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
111648 var enableShapeUniforms = arguments.length > 3 ? arguments[3] : undefined;
111649 var res = '';
111650 if (usesPackedTextures) {
111651 res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms);
111652 } else {
111653 res += getSamplerFromInInfo(inInfo, enableShapeUniforms);
111654 }
111655 var inShape = inInfo.shapeInfo.logicalShape;
111656 var outShape = outShapeInfo.logicalShape;
111657 if (inShape.length <= outShape.length) {
111658 if (usesPackedTextures) {
111659 res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
111660 } else {
111661 res += getSamplerAtOutputCoords(inInfo, outShapeInfo);
111662 }
111663 }
111664 return res;
111665 }
111666 function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
111667 switch (outShape.length) {
111668 case 0:
111669 return getOutputScalarCoords();
111670 case 1:
111671 return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms);
111672 case 2:
111673 return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms);
111674 case 3:
111675 return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms);
111676 default:
111677 return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms);
111678 }
111679 }
111680 function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
111681 switch (outShape.length) {
111682 case 0:
111683 return getOutputScalarCoords();
111684 case 1:
111685 return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms);
111686 case 2:
111687 return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms);
111688 case 3:
111689 return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms);
111690 case 4:
111691 return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms);
111692 case 5:
111693 return getOutput5DCoords(outShape, outTexShape);
111694 case 6:
111695 return getOutput6DCoords(outShape, outTexShape);
111696 default:
111697 throw new Error("".concat(outShape.length, "-D output sampling is not yet supported"));
111698 }
111699 }
111700 function getFloatTextureSampleSnippet(glsl) {
111701 return "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return ".concat(glsl.texture2D, "(textureSampler, uv).r;\n }\n ");
111702 }
111703 function getFloatTextureSetRSnippet(glsl) {
111704 return "\n void setOutput(float val) {\n ".concat(glsl.output, " = vec4(val, 0, 0, 0);\n }\n ");
111705 }
111706 function getFloatTextureSetRGBASnippet(glsl) {
111707 return "\n void setOutput(vec4 val) {\n ".concat(glsl.output, " = val;\n }\n ");
111708 }
  /**
   * Builds the common GLSL prelude placed at the top of every fragment shader.
   *
   * In order it contains: `glsl.version`; highp precision qualifiers for
   * float, int and sampler2D; the `resultUV` varying (declared with
   * `glsl.varyingFs`); `glsl.defineOutput`; the `halfCR` (0.5, 0.5) constant;
   * `ivec5` / `ivec6` structs (GLSL has no built-in int vectors wider than 4);
   * the `NAN` uniform plus the NaN/Inf/round defines from `glsl`; integer
   * helpers `imod(x, y)` and `idiv(a, b, sign)` (which subtracts 1 from the
   * truncated quotient when `sign < 0` and the remainder is non-zero); a
   * hash-based `random(seed)` keyed on `resultUV`; and the 1D/2D/3D UV helper
   * snippets.
   *
   * @param {Object} glsl Dialect-specific strings (the result of getGlslDifferences()).
   * @returns {string} GLSL source prefix.
   */
  function getShaderPrefix(glsl) {
    var SHADER_PREFIX = "".concat(glsl.version, "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n ").concat(glsl.varyingFs, " vec2 resultUV;\n ").concat(glsl.defineOutput, "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n ").concat(glsl.defineSpecialNaN, "\n ").concat(glsl.defineSpecialInf, "\n ").concat(glsl.defineRound, "\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n ").concat(SAMPLE_1D_SNIPPET, "\n ").concat(SAMPLE_2D_SNIPPET, "\n ").concat(SAMPLE_3D_SNIPPET, "\n ");
    return SHADER_PREFIX;
  }
  // GLSL UV helpers included in every shader prefix by getShaderPrefix. Each
  // maps an integer position to the UV of a texel centre (offset by `halfCR`)
  // in a texNumR x texNumC texture:
  //  - uvFromFlat: flat element index, one element per texel;
  //  - packedUVfrom1D: flat element index, two elements per texel (index / 2).
  var SAMPLE_1D_SNIPPET = "\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
  // packedUVfrom2D: (row, col) in a packed layout where each texel covers rows
  // 2k..2k+1 and columns 2m..2m+1, with `texelsInLogicalRow` texels per row pair.
  var SAMPLE_2D_SNIPPET = "\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
  // packedUVfrom3D: same as packedUVfrom2D, shifted by `b * texelsInBatch`
  // texels for batch index b.
  var SAMPLE_3D_SNIPPET = "\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
  // Extra GLSL appended to the prefix for packed inputs (makeShader adds it
  // when program.packedInputs is set). `getChannel(frag, innerDims)` picks
  // r/g/b/a from the parity of innerDims.x (r,g vs b,a) and innerDims.y;
  // `getChannel(frag, dim)` picks r or g from the parity of dim.
  var SHADER_PACKED_PREFIX = "\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? frag.r : frag.g;\n }\n";
111717 function getOutputScalarCoords() {
111718 return "\n int getOutputCoords() {\n return 0;\n }\n ";
111719 }
  /**
   * Emits GLSL `int getOutputCoords()` for a packed rank-1 output.
   *
   * The texture dimensions are halved (rounded up) to get the packed texel
   * grid, and the returned logical index is `2 * texelIndex`, so each texel
   * starts at an even element. Single-row and single-column texel grids use a
   * one-axis formula. With `enableShapeUniforms`, sizes come from the
   * `outTexShape` uniform; otherwise they are baked in as literals.
   *
   * @param {number[]} shape Not read by this function.
   * @param {number[]} texShape Texture shape [rows, cols] before halving.
   * @param {boolean} enableShapeUniforms Read sizes from uniforms instead of literals.
   * @returns {string} GLSL source.
   */
  function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) {
    // The branch choice uses the static texShape even when shape uniforms are enabled.
    var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    if (packedTexShape[0] === 1) {
      if (enableShapeUniforms) {
        return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0));\n }\n ";
      }
      return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * ".concat(packedTexShape[1], ".0);\n }\n ");
    }
    if (packedTexShape[1] === 1) {
      if (enableShapeUniforms) {
        return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0));\n }\n ";
      }
      return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * ".concat(packedTexShape[0], ".0);\n }\n ");
    }
    if (enableShapeUniforms) {
      return "\n int getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y);\n }\n ";
    }
    return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n return 2 * (resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y);\n }\n ");
  }
  /**
   * Emits GLSL `int getOutputCoords()` for an unpacked rank-1 output: the
   * row-major flat index of the current texel.
   *
   * Single-row and single-column textures use a one-axis formula. With
   * `enableShapeUniforms`, sizes come from the `outTexShape` uniform;
   * otherwise they are baked in as literals.
   *
   * @param {number[]} shape Not read by this function.
   * @param {number[]} texShape Output texture shape [rows, cols].
   * @param {boolean} enableShapeUniforms Read sizes from uniforms instead of literals.
   * @returns {string} GLSL source.
   */
  function getOutput1DCoords(shape, texShape, enableShapeUniforms) {
    if (texShape[0] === 1) {
      if (enableShapeUniforms) {
        return "\n int getOutputCoords() {\n return int(resultUV.x * float(outTexShape[1]));\n }\n ";
      }
      return "\n int getOutputCoords() {\n return int(resultUV.x * ".concat(texShape[1], ".0);\n }\n ");
    }
    if (texShape[1] === 1) {
      if (enableShapeUniforms) {
        return "\n int getOutputCoords() {\n return int(resultUV.y * float(outTexShape[0]));\n }\n ";
      }
      return "\n int getOutputCoords() {\n return int(resultUV.y * ".concat(texShape[0], ".0);\n }\n ");
    }
    if (enableShapeUniforms) {
      return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n return resTexRC.x * outTexShape[1] + resTexRC.y;\n }\n ";
    }
    return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n return resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n }\n ");
  }
  /**
   * Emits GLSL `ivec3 getOutputCoords()` for a packed rank-3 output.
   *
   * The texel index is split into a batch `b` (texelsInBatch =
   * ceil(shape[2]/2) * ceil(shape[1]/2) texels per batch), then an even row
   * `r` and even column `c`, i.e. the top-left element of the texel.
   *
   * @param {number[]} shape Logical output shape [batch, rows, cols].
   * @param {number[]} texShape Texture shape [rows, cols] before halving.
   * @param {boolean} enableShapeUniforms Compute sizes in GLSL from the
   *     `outShape` / `outTexShape` uniforms instead of baking them in.
   * @returns {string} GLSL source.
   */
  function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
      return "\n ivec3 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0));\n int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n\n int b = index / texelsInBatch;\n index -= b * texelsInBatch;\n\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec3(b, r, c);\n }\n ";
    }
    var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    // Texels needed to hold one pair of logical rows, and one whole batch.
    var texelsInLogicalRow = Math.ceil(shape[2] / 2);
    var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
    return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n int index = resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y;\n\n int b = index / ").concat(texelsInBatch, ";\n index -= b * ").concat(texelsInBatch, ";\n\n int r = 2 * (index / ").concat(texelsInLogicalRow, ");\n int c = imod(index, ").concat(texelsInLogicalRow, ") * 2;\n\n return ivec3(b, r, c);\n }\n ");
  }
  /**
   * Emits GLSL `ivec3 getOutputCoords()` for an unpacked rank-3 output.
   *
   * Computes the row-major flat texel `index`, then splices in a snippet that
   * must declare `r`, `c`, `d` from `index` (produced by
   * getOutputLogicalCoordinatesFromFlatIndexByUniform when shape uniforms are
   * enabled, else getLogicalCoordinatesFromFlatIndex).
   *
   * @param {number[]} shape Logical output shape.
   * @param {number[]} texShape Output texture shape [rows, cols].
   * @param {boolean} enableShapeUniforms Read sizes from uniforms instead of literals.
   * @returns {string} GLSL source.
   */
  function getOutput3DCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
      var _coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape);
      return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n ".concat(_coordsFromIndexSnippet, "\n return ivec3(r, c, d);\n }\n");
    }
    var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
    return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n ").concat(coordsFromIndexSnippet, "\n return ivec3(r, c, d);\n }\n ");
  }
  /**
   * Emits GLSL `ivecN getOutputCoords()` for a packed output of rank N >= 4.
   *
   * The texel index is peeled from the outermost dimension inwards: one
   * `bK` coordinate per extra batch dimension, then `b`, then an even row `r`
   * and even column `c` (the top-left element of the texel).
   *
   * NOTE(review): the shape-uniform branch always returns `ivec4` and reads
   * outShape[1..3], so it only fits rank-4 outputs (see the TODO). It looks
   * like rank 5/6 callers are expected not to enable shape uniforms — confirm.
   *
   * @param {number[]} shape Logical output shape (rank >= 4).
   * @param {number[]} texShape Texture shape [rows, cols] before halving.
   * @param {boolean} enableShapeUniforms Compute sizes in GLSL from uniforms.
   * @returns {string} GLSL source.
   */
  function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
      // TODO: support 5d and 6d
      return "\n ivec4 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n\n int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0));\n int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0));\n int texelsInBatchN = texelsInBatch * outShape[1];\n\n int b2 = index / texelsInBatchN;\n index -= b2 * texelsInBatchN;\n\n int b = index / texelsInBatch;\n index -= b * texelsInBatch;\n\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec4(b2, b, r, c);\n }\n ";
    }
    var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    var texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
    var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
    var texelsInBatchN = texelsInBatch;
    var batches = "";
    var coords = 'b, r, c';
    // Each pass grows the stride by one more outer dimension and PREPENDS its
    // GLSL, so the largest stride is divided out first in the emitted code.
    for (var b = 2; b < shape.length - 1; b++) {
      texelsInBatchN *= shape[shape.length - b - 1];
      batches = "\n int b".concat(b, " = index / ").concat(texelsInBatchN, ";\n index -= b").concat(b, " * ").concat(texelsInBatchN, ";\n ") + batches;
      coords = "b".concat(b, ", ") + coords;
    }
    return "\n ivec".concat(shape.length, " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(").concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n int index = resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y;\n\n ").concat(batches, "\n\n int b = index / ").concat(texelsInBatch, ";\n index -= b * ").concat(texelsInBatch, ";\n\n int r = 2 * (index / ").concat(texelsInLogicalRow, ");\n int c = imod(index, ").concat(texelsInLogicalRow, ") * 2;\n\n return ivec").concat(shape.length, "(").concat(coords, ");\n }\n ");
  }
  /**
   * Emits GLSL `ivec4 getOutputCoords()` for an unpacked rank-4 output.
   *
   * Computes the row-major flat texel `index`, then splices in a snippet that
   * must declare `r`, `c`, `d`, `d2` from `index` (uniform-driven when
   * `enableShapeUniforms`, literal otherwise).
   *
   * @param {number[]} shape Logical output shape.
   * @param {number[]} texShape Output texture shape [rows, cols].
   * @param {boolean} enableShapeUniforms Read sizes from uniforms instead of literals.
   * @returns {string} GLSL source.
   */
  function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
      var _coordsFromIndexSnippet2 = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd', 'd2'], shape);
      return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n ".concat(_coordsFromIndexSnippet2, "\n return ivec4(r, c, d, d2);\n }\n ");
    }
    var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
    return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n ").concat(coordsFromIndexSnippet, "\n return ivec4(r, c, d, d2);\n }\n ");
  }
  /**
   * Emits GLSL `ivec5 getOutputCoords()` for an unpacked rank-5 output, using
   * the `ivec5` struct declared in the shader prefix. Sizes are always baked
   * in as literals (no shape-uniform variant).
   *
   * Note: the GLSL local holding the result is named `outShape`, which is
   * unrelated to the `outShape` uniform of lower-rank shape-uniform shaders.
   *
   * @param {number[]} shape Logical output shape.
   * @param {number[]} texShape Output texture shape [rows, cols].
   * @returns {string} GLSL source.
   */
  function getOutput5DCoords(shape, texShape) {
    var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
    return "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(".concat(texShape[0], ",\n ").concat(texShape[1], "));\n\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n\n ").concat(coordsFromIndexSnippet, "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n ");
  }
  /**
   * Emits GLSL `ivec6 getOutputCoords()` for an unpacked rank-6 output, using
   * the `ivec6` struct declared in the shader prefix. Sizes are always baked
   * in as literals (no shape-uniform variant).
   *
   * @param {number[]} shape Logical output shape.
   * @param {number[]} texShape Output texture shape [rows, cols].
   * @returns {string} GLSL source.
   */
  function getOutput6DCoords(shape, texShape) {
    var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
    return "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n\n ").concat(coordsFromIndexSnippet, "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n ");
  }
  /**
   * Emits GLSL `ivec2 getOutputCoords()` for a packed rank-2 output, returning
   * the even (row, col) of the top-left element held by the current texel.
   *
   * When the logical shape equals the texture shape, the texel grid position
   * is simply doubled. Otherwise the flat texel index is split using
   * `ceil(cols / 2)` texels per logical row pair.
   *
   * @param {number[]} shape Logical output shape [rows, cols].
   * @param {number[]} texShape Texture shape [rows, cols] before halving.
   * @param {boolean} enableShapeUniforms Compute sizes in GLSL from the
   *     `outShape` / `outTexShape` uniforms instead of baking them in.
   * @returns {string} GLSL source.
   */
  function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) {
    var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    if (arraysEqual(shape, texShape)) {
      if (enableShapeUniforms) {
        return "\n ivec2 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));\n }\n ";
      }
      return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n }\n ");
    }
    // texels needed to accommodate a logical row
    var texelsInLogicalRow = Math.ceil(shape[1] / 2);
    /**
     * getOutputCoords
     *
     * resTexRC: The rows and columns of the texels. If you move over one
     * texel to the right in the packed texture, you are moving over one column
     * (not two).
     *
     * index: The texel index
     */
    if (enableShapeUniforms) {
      return "\n ivec2 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec2(r, c);\n }\n ";
    }
    return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n\n int index = resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y;\n int r = 2 * (index / ").concat(texelsInLogicalRow, ");\n int c = imod(index, ").concat(texelsInLogicalRow, ") * 2;\n\n return ivec2(r, c);\n }\n ");
  }
  /**
   * Emits GLSL `ivec2 getOutputCoords()` for an unpacked rank-2 output.
   *
   * Cases, checked in order:
   *  - logical shape equals texture shape: the texel position is the coordinate;
   *  - single column (shape[1] === 1): (flatIndex, 0);
   *  - single row (shape[0] === 1): (0, flatIndex);
   *  - otherwise the flat texel index is split by shape[1] into (r, c).
   * With `enableShapeUniforms`, sizes come from the `outShape` / `outTexShape`
   * uniforms; otherwise they are baked in as literals.
   *
   * @param {number[]} shape Logical output shape [rows, cols].
   * @param {number[]} texShape Output texture shape [rows, cols].
   * @param {boolean} enableShapeUniforms Read sizes from uniforms instead of literals.
   * @returns {string} GLSL source.
   */
  function getOutput2DCoords(shape, texShape, enableShapeUniforms) {
    if (arraysEqual(shape, texShape)) {
      if (enableShapeUniforms) {
        return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1]));\n }\n ";
      }
      return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n }\n ");
    }
    if (shape[1] === 1) {
      if (enableShapeUniforms) {
        return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n return ivec2(index, 0);\n }\n ";
      }
      return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n return ivec2(index, 0);\n }\n ");
    }
    if (shape[0] === 1) {
      if (enableShapeUniforms) {
        return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n return ivec2(0, index);\n }\n ";
      }
      return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n return ivec2(0, index);\n }\n ");
    }
    if (enableShapeUniforms) {
      return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n int r = index / outShape[1];\n int c = index - r * outShape[1];\n return ivec2(r, c);\n }\n ";
    }
    return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n int r = index / ").concat(shape[1], ";\n int c = index - r * ").concat(shape[1], ";\n return ivec2(r, c);\n }\n ");
  }
111856 function getFlatOffsetUniformName(texName) {
111857 return "offset".concat(texName);
111858 }
111859 function getPackedSamplerScalar(inputInfo) {
111860 var texName = inputInfo.name;
111861 var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
111862 var glsl = getGlslDifferences();
111863 return "\n vec4 ".concat(funcName, "() {\n return ").concat(glsl.texture2D, "(").concat(texName, ", halfCR);\n }\n ");
111864 }
111865 function getSamplerScalar(inputInfo, enableShapeUniforms) {
111866 var texName = inputInfo.name;
111867 var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
111868 if (inputInfo.shapeInfo.isUniform) {
111869 return "float ".concat(funcName, "() {return ").concat(texName, ";}");
111870 }
111871 var _inputInfo$shapeInfo$ = _slicedToArray(inputInfo.shapeInfo.texShape, 2),
111872 texNumR = _inputInfo$shapeInfo$[0],
111873 texNumC = _inputInfo$shapeInfo$[1];
111874 if (texNumR === 1 && texNumC === 1) {
111875 return "\n float ".concat(funcName, "() {\n return sampleTexture(").concat(texName, ", halfCR);\n }\n ");
111876 }
111877 var offset = getFlatOffsetUniformName(texName);
111878 if (enableShapeUniforms) {
111879 return "\n float ".concat(funcName, "() {\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
111880 }
111881 var _inputInfo$shapeInfo$2 = _slicedToArray(inputInfo.shapeInfo.texShape, 2),
111882 tNumR = _inputInfo$shapeInfo$2[0],
111883 tNumC = _inputInfo$shapeInfo$2[1];
111884 return "\n float ".concat(funcName, "() {\n vec2 uv = uvFromFlat(").concat(tNumR, ", ").concat(tNumC, ", ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
111885 }
  /**
   * Emits GLSL `vec4 get<Name>(int index)` for a packed rank-1 input: the
   * texel holding element `index`, located with `packedUVfrom1D` on a texture
   * grid halved (rounded up) in both dimensions.
   *
   * @param {Object} inputInfo Input descriptor (`name`, `shapeInfo.texShape`).
   * @param {boolean} enableShapeUniforms Compute the halved grid in GLSL from
   *     the `<name>TexShape` uniform instead of baking it in.
   * @returns {string} GLSL source.
   */
  function getPackedSampler1D(inputInfo, enableShapeUniforms) {
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texShape = inputInfo.shapeInfo.texShape;
    var glsl = getGlslDifferences();
    if (enableShapeUniforms) {
      return "\n vec4 ".concat(funcName, "(int index) {\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n vec2 uv = packedUVfrom1D(\n packedTexShape[0], packedTexShape[1], index);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
    }
    var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    return "\n vec4 ".concat(funcName, "(int index) {\n vec2 uv = packedUVfrom1D(\n ").concat(packedTexShape[0], ", ").concat(packedTexShape[1], ", index);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
  }
  /**
   * Emits GLSL `float get<Name>(int index)` for an unpacked rank-1 input.
   *
   * Cases, checked in order: uniform inputs (delegated to getUniformSampler);
   * a 1x1 texture (single texel at `halfCR`); a single-column texture; a
   * single-row texture; and the general case via `uvFromFlat`. Every texture
   * path adds the `offset<name>` uniform to `index` first.
   *
   * @param {Object} inputInfo Input descriptor (`name`, `shapeInfo`).
   * @param {boolean} enableShapeUniforms Read texture sizes from the
   *     `<name>TexShape` uniform instead of baking them in.
   * @returns {string} GLSL source.
   */
  function getSampler1D(inputInfo, enableShapeUniforms) {
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    if (inputInfo.shapeInfo.isUniform) {
      // Uniform arrays will be less than 65505 (no risk of float16 overflow).
      return "\n float ".concat(funcName, "(int index) {\n ").concat(getUniformSampler(inputInfo), "\n }\n ");
    }
    var texShape = inputInfo.shapeInfo.texShape;
    var tNumR = texShape[0];
    var tNumC = texShape[1];
    if (tNumC === 1 && tNumR === 1) {
      return "\n float ".concat(funcName, "(int index) {\n return sampleTexture(").concat(texName, ", halfCR);\n }\n ");
    }
    var offset = getFlatOffsetUniformName(texName);
    if (tNumC === 1) {
      if (enableShapeUniforms) {
        return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2(0.5, (float(index + ").concat(offset, ") + 0.5) / float(").concat(texName, "TexShape[0]));\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
      }
      return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2(0.5, (float(index + ").concat(offset, ") + 0.5) / ").concat(tNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    if (tNumR === 1) {
      if (enableShapeUniforms) {
        return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2((float(index + ").concat(offset, ") + 0.5) / float(").concat(texName, "TexShape[1]), 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
      }
      return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2((float(index + ").concat(offset, ") + 0.5) / ").concat(tNumC, ".0, 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int index) {\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    return "\n float ".concat(funcName, "(int index) {\n vec2 uv = uvFromFlat(").concat(tNumR, ", ").concat(tNumC, ", index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
111928 function getPackedSampler2D(inputInfo, enableShapeUniforms) {
111929 var shape = inputInfo.shapeInfo.logicalShape;
111930 var texName = inputInfo.name;
111931 var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
111932 var texShape = inputInfo.shapeInfo.texShape;
111933 var texNumR = texShape[0];
111934 var texNumC = texShape[1];
111935 var glsl = getGlslDifferences();
111936 if (texShape != null && arraysEqual(shape, texShape)) {
111937 if (enableShapeUniforms) {
111938 return "\n vec4 ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
111939 }
111940 return "\n vec4 ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
111941 }
111942 if (enableShapeUniforms) {
111943 return "\n vec4 ".concat(funcName, "(int row, int col) {\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n int valuesPerRow = int(ceil(float(").concat(texName, "Shape[1]) / 2.0));\n vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
111944 }
111945 var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
111946 var valuesPerRow = Math.ceil(shape[1] / 2);
111947 return "\n vec4 ".concat(funcName, "(int row, int col) {\n vec2 uv = packedUVfrom2D(").concat(valuesPerRow, ", ").concat(packedTexShape[0], ", ").concat(packedTexShape[1], ", row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
111948 }
/**
 * Builds the GLSL source of an unpacked rank-2 sampler
 * `float get<Name>(int row, int col)` that returns one element of input
 * `inputInfo.name`.
 *
 * The emitted code is specialized by layout, checked in this order:
 *  1. logical shape equals the texture shape: direct (col, row) texel lookup;
 *  2. squeezeShape drops dims: emit the sampler for the squeezed shape plus a
 *     2-arg wrapper forwarding only the kept coordinates;
 *  3. uniform input: flat index handed to the getUniformSampler snippet;
 *  4. single-column / single-row texture: 1D uv from the flat index + offset;
 *  5. otherwise: integer flat index converted with uvFromFlat.
 * With `enableShapeUniforms`, sizes come from the `<name>Shape` /
 * `<name>TexShape` uniforms instead of being baked in as literals.
 *
 * @param {Object} inputInfo - Reads `name` and `shapeInfo.{logicalShape, texShape, isUniform}`.
 * @param {boolean} enableShapeUniforms - Read sizes from uniforms rather than constants.
 * @returns {string} GLSL source for the sampler function(s).
 */
function getSampler2D(inputInfo, enableShapeUniforms) {
  var shape = inputInfo.shapeInfo.logicalShape;
  var texName = inputInfo.name;
  var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  var texShape = inputInfo.shapeInfo.texShape;
  // Logical and physical layouts coincide: (row, col) maps straight to a texel.
  if (texShape != null && arraysEqual(shape, texShape)) {
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    var _texNumR = texShape[0];
    var _texNumC = texShape[1];
    return "\n float ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(").concat(_texNumC, ".0, ").concat(_texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  var _util$squeezeShape = squeezeShape(shape),
    newShape = _util$squeezeShape.newShape,
    keptDims = _util$squeezeShape.keptDims;
  var squeezedShape = newShape;
  if (squeezedShape.length < shape.length) {
    // The squeezed sampler keeps the same function name; the wrapper below
    // presumably overloads it by arity and forwards only the kept coords.
    var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
    var params = ['row', 'col'];
    return "\n ".concat(getSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n float ").concat(funcName, "(int row, int col) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n ");
  }
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    return "\n float ".concat(funcName, "(int row, int col) {\n int index = round(dot(vec2(row, col), vec2(").concat(shape[1], ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n ");
  }
  var texNumR = texShape[0];
  var texNumC = texShape[1];
  var offset = getFlatOffsetUniformName(texName);
  if (texNumC === 1) {
    // index is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(texName, "Shape[1], 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / float(").concat(texName, "TexShape[0]));\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(shape[1], ", 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  if (texNumR === 1) {
    // index is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(texName, "Shape[1], 1, 1));\n vec2 uv = vec2((index + 0.5) / float(").concat(texName, "TexShape[1]), 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(shape[1], ", 1, 1));\n vec2 uv = vec2((index + 0.5) / ").concat(texNumC, ".0, 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  // General case: integer flat index (incl. slice offset) -> uv via uvFromFlat.
  if (enableShapeUniforms) {
    return "\n float ".concat(funcName, "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(texName, "Shape[1] + col + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], index);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  return "\n float ".concat(funcName, "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(shape[1], " + col + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return sampleTexture(").concat(texName, ", uv);\n }\n");
}
/**
 * Builds the GLSL source of a packed rank-3 sampler
 * `vec4 get<Name>(int b, int row, int col)` that fetches the RGBA texel
 * containing element (b, row, col) of input `inputInfo.name`.
 *
 * The ceil(x / 2) sizes below imply each packed texel covers two rows and two
 * columns, so the packed texture is half the physical size on each axis.
 * A leading batch dim of 1 is dropped and delegated to the packed sampler of
 * the remaining rank-2 shape. With `enableShapeUniforms`, sizes are computed
 * in GLSL from `<name>Shape` / `<name>TexShape`; otherwise they are literals.
 *
 * @param {Object} inputInfo - Reads `name` and `shapeInfo.{logicalShape, texShape}`;
 *     `texShape` is dereferenced unconditionally so it must be non-null.
 * @param {boolean} enableShapeUniforms - Read sizes from uniforms rather than constants.
 * @returns {string} GLSL source for the sampler function(s).
 */
function getPackedSampler3D(inputInfo, enableShapeUniforms) {
  var shape = inputInfo.shapeInfo.logicalShape;
  var texName = inputInfo.name;
  var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  var texShape = inputInfo.shapeInfo.texShape;
  var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  if (shape[0] === 1) {
    // Drop the unit batch dim; the 3-arg wrapper forwards (row, col).
    var squeezedShape = shape.slice(1);
    var keptDims = [1, 2];
    var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
    var params = ['b', 'row', 'col'];
    return "\n ".concat(getPackedSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n vec4 ").concat(funcName, "(int b, int row, int col) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n ");
  }
  var glsl = getGlslDifferences();
  if (enableShapeUniforms) {
    return "\n vec4 ".concat(funcName, "(int b, int row, int col) {\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n int valuesPerRow = int(ceil(float(").concat(texName, "Shape[2]) / 2.0));\n int texelsInBatch = valuesPerRow * int(ceil(float(").concat(texName, "Shape[1]) / 2.0));\n vec2 uv = packedUVfrom3D(\n packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
  }
  var texNumR = packedTexShape[0];
  var texNumC = packedTexShape[1];
  // Packed texels per logical row, and per whole batch (rows also paired).
  var valuesPerRow = Math.ceil(shape[2] / 2);
  var texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
  return "\n vec4 ".concat(funcName, "(int b, int row, int col) {\n vec2 uv = packedUVfrom3D(\n ").concat(texNumR, ", ").concat(texNumC, ", ").concat(texelsInBatch, ", ").concat(valuesPerRow, ", b, row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
}
/**
 * Builds the GLSL source of an unpacked rank-3 sampler
 * `float get<Name>(int row, int col, int depth)` for input `inputInfo.name`.
 *
 * Specializations, in order:
 *  1. squeezeShape drops dims: squeezed sampler + forwarding wrapper;
 *  2. uniform input: flat index handed to the getUniformSampler snippet;
 *  3. texture width == shape[1]*shape[2] and no slice offset: `row` is the
 *     texel row and (col, depth) form the texel column;
 *  4. texture width == shape[2] and no slice offset: (row, col) form the
 *     texel row and `depth` is the texel column;
 *  5. otherwise: integer flat index (plus offset uniform) through uvFromFlat.
 * With `enableShapeUniforms`, strides come from `<name>Shape` / `<name>TexShape`.
 *
 * @param {Object} inputInfo - Reads `name` and
 *     `shapeInfo.{logicalShape, texShape, isUniform, flatOffset}`.
 * @param {boolean} enableShapeUniforms - Read sizes from uniforms rather than constants.
 * @returns {string} GLSL source for the sampler function(s).
 */
function getSampler3D(inputInfo, enableShapeUniforms) {
  var shape = inputInfo.shapeInfo.logicalShape;
  var texName = inputInfo.name;
  var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  // Row-major element strides of the logical shape.
  var stride0 = shape[1] * shape[2];
  var stride1 = shape[2];
  var _util$squeezeShape2 = squeezeShape(shape),
    newShape = _util$squeezeShape2.newShape,
    keptDims = _util$squeezeShape2.keptDims;
  var squeezedShape = newShape;
  if (squeezedShape.length < shape.length) {
    var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
    var params = ['row', 'col', 'depth'];
    return "\n ".concat(getSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n float ").concat(funcName, "(int row, int col, int depth) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n ");
  }
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    return "\n float ".concat(funcName, "(int row, int col, int depth) {\n int index = round(dot(vec3(row, col, depth),\n vec3(").concat(stride0, ", ").concat(stride1, ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n ");
  }
  var texShape = inputInfo.shapeInfo.texShape;
  var texNumR = texShape[0];
  var texNumC = texShape[1];
  var flatOffset = inputInfo.shapeInfo.flatOffset;
  if (texNumC === stride0 && flatOffset == null) {
    // texC is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int row, int col, int depth) {\n int stride1 = ").concat(texName, "Shape[2];\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(stride1, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    return "\n float ".concat(funcName, "(int row, int col, int depth) {\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(").concat(stride1, ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  if (texNumC === stride1 && flatOffset == null) {
    // texR is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(").concat(texName, "Shape[1], 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    return "\n float ".concat(funcName, "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(").concat(shape[1], ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  // General case: integer flat index (incl. slice offset) -> uv via uvFromFlat.
  var offset = getFlatOffsetUniformName(texName);
  if (enableShapeUniforms) {
    return "\n float ".concat(funcName, "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int stride0 = ").concat(texName, "Shape[1] * ").concat(texName, "Shape[2];\n int stride1 = ").concat(texName, "Shape[2];\n int index = row * stride0 + col * stride1 + depth + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], index);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  return "\n float ".concat(funcName, "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " + depth + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
}
/**
 * Builds the GLSL source of a packed sampler for an input of rank >= 4,
 * `vec4 get<Name>(int bK, ..., int b2, int b, int row, int col)`, returning the
 * RGBA texel that contains the requested element.
 *
 * Each packed texel covers two rows and two columns, as the ceil(x / 2)
 * sizes below imply. The texel index is built as a nested sum over the batch
 * dims, then split into (texR, texC) over the packed texture width.
 * With `enableShapeUniforms` only the rank-4 signature is emitted (see TODO).
 *
 * @param {Object} inputInfo - Reads `name` and `shapeInfo.{logicalShape, texShape}`.
 * @param {boolean} enableShapeUniforms - Read sizes from uniforms rather than constants.
 * @returns {string} GLSL source for the sampler function.
 */
function getPackedSamplerND(inputInfo, enableShapeUniforms) {
  var texName = inputInfo.name;
  var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  var glsl = getGlslDifferences();
  if (enableShapeUniforms) {
    // TODO: support 5d and 6d
    return "\n vec4 ".concat(funcName, "(int b2, int b, int row, int col) {\n int valuesPerRow = int(ceil(float(").concat(texName, "Shape[3]) / 2.0));\n int texelsInBatch = valuesPerRow * int(ceil(float(").concat(texName, "Shape[2]) / 2.0));\n int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);\n texelsInBatch *= ").concat(texName, "Shape[1];\n index = b2 * texelsInBatch + index;\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n int texR = index / packedTexShape[1];\n int texC = index - texR * packedTexShape[1];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]); return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
  }
  var shape = inputInfo.shapeInfo.logicalShape;
  var rank = shape.length;
  var texShape = inputInfo.shapeInfo.texShape;
  var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  var texNumR = packedTexShape[0];
  var texNumC = packedTexShape[1];
  var valuesPerRow = Math.ceil(shape[rank - 1] / 2);
  var texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
  var params = "int b, int row, int col";
  var index = "b * ".concat(texelsInBatch, " + (row / 2) * ").concat(valuesPerRow, " + (col / 2)");
  // Each further outer dim prepends a parameter `bK` and an index term whose
  // stride is the texel count of everything nested inside it.
  for (var b = 2; b < rank - 1; b++) {
    params = "int b".concat(b, ", ") + params;
    texelsInBatch *= shape[rank - b - 1];
    index = "b".concat(b, " * ").concat(texelsInBatch, " + ") + index;
  }
  return "\n vec4 ".concat(funcName, "(").concat(params, ") {\n int index = ").concat(index, ";\n int texR = index / ").concat(texNumC, ";\n int texC = index - texR * ").concat(texNumC, ";\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(").concat(texNumC, ", ").concat(texNumR, ");\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n ");
}
/**
 * Builds the GLSL source of an unpacked rank-4 sampler
 * `float get<Name>(int row, int col, int depth, int depth2)` for input
 * `inputInfo.name`.
 *
 * Specializations, in order:
 *  1. squeezeShape drops dims: squeezed sampler + forwarding wrapper;
 *  2. uniform input: flat index handed to the getUniformSampler snippet;
 *  3. texture width == stride0 and no slice offset: `row` is the texel row;
 *  4. texture width == shape[3] and no slice offset: `depth2` is the texel column;
 *  5. otherwise: integer flat index plus offset uniform through uvFromFlat.
 * With `enableShapeUniforms`, strides are declared in GLSL from `<name>Shape`.
 *
 * @param {Object} inputInfo - Reads `name` and
 *     `shapeInfo.{logicalShape, texShape, isUniform, flatOffset}`.
 * @param {boolean} enableShapeUniforms - Read sizes from uniforms rather than constants.
 * @returns {string} GLSL source for the sampler function(s).
 */
function getSampler4D(inputInfo, enableShapeUniforms) {
  var shape = inputInfo.shapeInfo.logicalShape;
  var texName = inputInfo.name;
  var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  // Row-major element strides of the logical shape.
  var stride2 = shape[3];
  var stride1 = shape[2] * stride2;
  var stride0 = shape[1] * stride1;
  var _util$squeezeShape3 = squeezeShape(shape),
    newShape = _util$squeezeShape3.newShape,
    keptDims = _util$squeezeShape3.keptDims;
  if (newShape.length < shape.length) {
    var newInputInfo = squeezeInputInfo(inputInfo, newShape);
    var params = ['row', 'col', 'depth', 'depth2'];
    return "\n ".concat(getSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n float ").concat(funcName, "(int row, int col, int depth, int depth2) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n ");
  }
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(").concat(stride0, ", ").concat(stride1, ", ").concat(stride2, ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n ");
  }
  var flatOffset = inputInfo.shapeInfo.flatOffset;
  var texShape = inputInfo.shapeInfo.texShape;
  var texNumR = texShape[0];
  var texNumC = texShape[1];
  // GLSL stride declarations used by the shape-uniform variants.
  var stride2Str = "int stride2 = ".concat(texName, "Shape[3];");
  var stride1Str = "int stride1 = ".concat(texName, "Shape[2] * stride2;");
  var stride0Str = "int stride0 = ".concat(texName, "Shape[1] * stride1;");
  if (texNumC === stride0 && flatOffset == null) {
    // texC is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n ").concat(stride2Str, "\n ").concat(stride1Str, "\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(stride1, stride2, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(").concat(stride1, ", ").concat(stride2, ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  if (texNumC === stride2 && flatOffset == null) {
    // texR is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(").concat(texName, "Shape[1] * ").concat(texName, "Shape[2], ").concat(texName, "Shape[2], 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
    }
    return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(").concat(shape[1] * shape[2], ", ").concat(shape[2], ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  // General case: integer flat index; the slice offset is added at the uvFromFlat call.
  var offset = getFlatOffsetUniformName(texName);
  if (enableShapeUniforms) {
    return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n ").concat(stride2Str, "\n ").concat(stride1Str, "\n ").concat(stride0Str, "\n int index = row * stride0 + col * stride1 +\n depth * stride2 + depth2;\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " +\n depth * ").concat(stride2, " + depth2;\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
}
/**
 * Builds the GLSL source of an unpacked rank-5 sampler
 * `float get<Name>(int row, int col, int depth, int depth2, int depth3)` for
 * input `inputInfo.name`. All sizes are baked in as literals (no
 * shape-uniform variant for this rank).
 *
 * Specializations, in order:
 *  1. squeezeShape drops dims: squeezed sampler + forwarding wrapper;
 *  2. uniform input: integer flat `index` handed to the getUniformSampler
 *     snippet (which compares it against an int loop counter);
 *  3. texture width == stride0 and no slice offset: `row` is the texel row;
 *  4. texture width == shape[4] and no slice offset: `depth3` is the texel column;
 *  5. otherwise: integer flat index plus offset uniform through uvFromFlat.
 *
 * @param {Object} inputInfo - Reads `name` and
 *     `shapeInfo.{logicalShape, texShape, isUniform, flatOffset}`.
 * @returns {string} GLSL source for the sampler function(s).
 */
function getSampler5D(inputInfo) {
  var shape = inputInfo.shapeInfo.logicalShape;
  var texName = inputInfo.name;
  var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  // Row-major element strides of the logical shape.
  var stride3 = shape[4];
  var stride2 = shape[3] * stride3;
  var stride1 = shape[2] * stride2;
  var stride0 = shape[1] * stride1;
  var _util$squeezeShape4 = squeezeShape(shape),
    newShape = _util$squeezeShape4.newShape,
    keptDims = _util$squeezeShape4.keptDims;
  if (newShape.length < shape.length) {
    var newInputInfo = squeezeInputInfo(inputInfo, newShape);
    var params = ['row', 'col', 'depth', 'depth2', 'depth3'];
    return "\n ".concat(getSamplerFromInInfo(newInputInfo), "\n float ").concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n ");
  }
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    // The index must be an int, rounded like the other ranks' uniform
    // samplers. GLSL ES has no implicit int<->float conversion, so the
    // previous `float index = dot(...) + depth3` failed to compile, both on
    // the float + int addition and on the `i == index` test emitted by
    // getUniformSampler.
    return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(").concat(stride0, ", ").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ")) +\n float(depth3));\n ").concat(getUniformSampler(inputInfo), "\n }\n ");
  }
  var flatOffset = inputInfo.shapeInfo.flatOffset;
  var texShape = inputInfo.shapeInfo.texShape;
  var texNumR = texShape[0];
  var texNumC = texShape[1];
  if (texNumC === stride0 && flatOffset == null) {
    // texC is used directly as physical (no risk of float16 overflow).
    return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  if (texNumC === stride3 && flatOffset == null) {
    // texR is used directly as physical (no risk of float16 overflow).
    return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(").concat(shape[1] * shape[2] * shape[3], ",\n ").concat(shape[2] * shape[3], ", ").concat(shape[3], ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  // General case: integer flat index (incl. slice offset) -> uv via uvFromFlat.
  var offset = getFlatOffsetUniformName(texName);
  return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " + depth * ").concat(stride2, " +\n depth2 * ").concat(stride3, " + depth3 + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
}
/**
 * Builds the GLSL source of an unpacked rank-6 sampler
 * `float get<Name>(int row, int col, int depth, int depth2, int depth3, int depth4)`
 * for input `inputInfo.name`. All sizes are baked in as literals (no
 * shape-uniform variant for this rank).
 *
 * Specializations, in order:
 *  1. squeezeShape drops dims: squeezed sampler + forwarding wrapper;
 *  2. uniform input: rounded flat index handed to the getUniformSampler snippet;
 *  3. texture width == stride0 and no slice offset: `row` is the texel row;
 *  4. texture width == shape[5] and no slice offset: `depth4` is the texel column;
 *  5. otherwise: integer flat index plus offset uniform through uvFromFlat.
 *
 * @param {Object} inputInfo - Reads `name` and
 *     `shapeInfo.{logicalShape, texShape, isUniform, flatOffset}`.
 * @returns {string} GLSL source for the sampler function(s).
 */
function getSampler6D(inputInfo) {
  var shape = inputInfo.shapeInfo.logicalShape;
  var texName = inputInfo.name;
  var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  var _util$squeezeShape5 = squeezeShape(shape),
    newShape = _util$squeezeShape5.newShape,
    keptDims = _util$squeezeShape5.keptDims;
  if (newShape.length < shape.length) {
    var newInputInfo = squeezeInputInfo(inputInfo, newShape);
    var params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
    return "\n ".concat(getSamplerFromInInfo(newInputInfo), "\n float ").concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n ");
  }
  // Row-major element strides of the logical shape.
  var stride4 = shape[5];
  var stride3 = shape[4] * stride4;
  var stride2 = shape[3] * stride3;
  var stride1 = shape[2] * stride2;
  var stride0 = shape[1] * stride1;
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    // GLSL has no vec6, so the dot product is split into a vec4 and a vec2 part.
    return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(").concat(stride0, ", ").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ")) +\n dot(\n vec2(depth3, depth4),\n vec2(").concat(stride4, ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n ");
  }
  var flatOffset = inputInfo.shapeInfo.flatOffset;
  var texShape = inputInfo.shapeInfo.texShape;
  var texNumR = texShape[0];
  var texNumC = texShape[1];
  if (texNumC === stride0 && flatOffset == null) {
    // texC is used directly as physical (no risk of float16 overflow).
    return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ", ").concat(stride4, ")) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  if (texNumC === stride4 && flatOffset == null) {
    // texR is used directly as physical (no risk of float16 overflow).
    return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(").concat(shape[1] * shape[2] * shape[3] * shape[4], ",\n ").concat(shape[2] * shape[3] * shape[4], ",\n ").concat(shape[3] * shape[4], ",\n ").concat(shape[4], ")) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
  }
  // General case: integer flat index (incl. slice offset) -> uv via uvFromFlat.
  var offset = getFlatOffsetUniformName(texName);
  return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " + depth * ").concat(stride2, " +\n depth2 * ").concat(stride3, " + depth3 * ").concat(stride4, " + depth4 + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return sampleTexture(").concat(texName, ", uv);\n }\n ");
}
/**
 * Returns the GLSL statements that read element `index` of a uniform input.
 * The snippet relies on an `index` variable declared by the enclosing
 * sampler function.
 *
 * Inputs with fewer than two elements are a scalar uniform and are returned
 * directly. Larger inputs are read by scanning with a constant-bounded loop.
 * NOTE(review): the loop presumably exists because GLSL ES 1.00 restricts
 * non-constant indexing of uniform arrays.
 *
 * @param {Object} inputInfo - Reads `name` and `shapeInfo.logicalShape`.
 * @returns {string} GLSL statements ending in a `return`.
 */
function getUniformSampler(inputInfo) {
  var uniformName = inputInfo.name;
  var elementCount = sizeFromShape(inputInfo.shapeInfo.logicalShape);
  return elementCount < 2
    ? "return " + uniformName + ";"
    : "\n for (int i = 0; i < " + elementCount + "; i++) {\n if (i == index) {\n return " + uniformName + "[i];\n }\n }\n ";
}
/**
 * Builds the GLSL source of `vec4 get<Name>AtOutCoords()`, which reads the
 * packed input `inputInfo.name` at the current output coordinates. It handles
 * broadcasting and rank differences between the input and the output.
 *
 * Steps:
 *  1. Zero the output coordinates along broadcast dims. `broadcastDims` holds
 *     input-dim indices, as the `d + rankDiff` offset into `fields` implies.
 *  2. Pass the trailing `inRank` output coordinates to `get<Name>(...)`.
 *  3. Pick a swizzle that replicates values inside the fetched texel when the
 *     input is rank 1, a scalar, or broadcast along its last two dims.
 *     NOTE(review): these swizzles assume the packed texel holds
 *     (r0c0, r0c1, r1c0, r1c1) in (x, y, z, w); confirm against the packing code.
 *
 * @param {Object} inputInfo - Reads `name` and `shapeInfo.logicalShape`.
 * @param {Object} outShapeInfo - Reads `logicalShape`.
 * @returns {string} GLSL source for the function.
 */
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
  var texName = inputInfo.name;
  var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
  var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
  var inRank = inputInfo.shapeInfo.logicalShape.length;
  var outRank = outShapeInfo.logicalShape.length;
  var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
  var type = getCoordsDataType(outRank);
  var rankDiff = outRank - inRank;
  var coordsSnippet;
  var fields = ['x', 'y', 'z', 'w', 'u', 'v'];
  if (inRank === 0) {
    coordsSnippet = '';
  } else if (outRank < 2 && broadcastDims.length >= 1) {
    // Rank <= 1 output coords are a plain int, so there is no field to zero.
    coordsSnippet = 'coords = 0;';
  } else {
    coordsSnippet = broadcastDims.map(function (d) {
      return "coords.".concat(fields[d + rankDiff], " = 0;");
    }).join('\n');
  }
  var unpackedCoordsSnippet = '';
  if (outRank < 2 && inRank > 0) {
    unpackedCoordsSnippet = 'coords';
  } else {
    // Keep only the trailing output coords that correspond to input dims.
    unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape.map(function (s, i) {
      return "coords.".concat(fields[i + rankDiff]);
    }).join(', ');
  }
  var output = "return outputValue;";
  var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
  var isInputScalar = inSize === 1;
  var outSize = sizeFromShape(outShapeInfo.logicalShape);
  var isOutputScalar = outSize === 1;
  if (inRank === 1 && !isInputScalar && !isOutputScalar) {
    output = "\n return vec4(outputValue.xy, outputValue.xy);\n ";
  } else if (isInputScalar && !isOutputScalar) {
    if (outRank === 1) {
      output = "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n ";
    } else {
      output = "\n return vec4(outputValue.x);\n ";
    }
  } else if (broadcastDims.length) {
    var rows = inRank - 2;
    var cols = inRank - 1;
    if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
      output = "return vec4(outputValue.x);";
    } else if (broadcastDims.indexOf(rows) > -1) {
      output = "return vec4(outputValue.x, outputValue.y, " + "outputValue.x, outputValue.y);";
    } else if (broadcastDims.indexOf(cols) > -1) {
      output = "return vec4(outputValue.xx, outputValue.zz);";
    }
  }
  return "\n vec4 ".concat(funcName, "() {\n ").concat(type, " coords = getOutputCoords();\n ").concat(coordsSnippet, "\n vec4 outputValue = get").concat(texFuncSnippet, "(").concat(unpackedCoordsSnippet, ");\n ").concat(output, "\n }\n ");
}
/**
 * Builds the GLSL source of `float get<Name>AtOutCoords()`, which reads the
 * unpacked input `inputInfo.name` at the current output coordinates.
 *
 * Fast path: a non-uniform input with the output's rank, no slice offset and
 * the same texture shape is sampled directly at `resultUV`. Otherwise the
 * output coords are fetched, broadcast dims are zeroed, and the trailing
 * `inRank` coords are forwarded to `get<Name>(...)`.
 *
 * @param {Object} inputInfo - Reads `name` and
 *     `shapeInfo.{logicalShape, texShape, isUniform, flatOffset}`.
 * @param {Object} outShapeInfo - Reads `logicalShape` and `texShape`.
 * @returns {string} GLSL source for the function.
 */
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
  var texName = inputInfo.name;
  var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
  var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
  var outTexShape = outShapeInfo.texShape;
  var inTexShape = inputInfo.shapeInfo.texShape;
  var inRank = inputInfo.shapeInfo.logicalShape.length;
  var outRank = outShapeInfo.logicalShape.length;
  // Same rank and same physical layout: the output's own uv addresses the input.
  if (!inputInfo.shapeInfo.isUniform && inRank === outRank && inputInfo.shapeInfo.flatOffset == null && arraysEqual(inTexShape, outTexShape)) {
    return "\n float ".concat(funcName, "() {\n return sampleTexture(").concat(texName, ", resultUV);\n }\n ");
  }
  var type = getCoordsDataType(outRank);
  var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
  var rankDiff = outRank - inRank;
  var coordsSnippet;
  var fields = ['x', 'y', 'z', 'w', 'u', 'v'];
  if (inRank === 0) {
    coordsSnippet = '';
  } else if (outRank < 2 && broadcastDims.length >= 1) {
    // Rank <= 1 output coords are a plain int, so there is no field to zero.
    coordsSnippet = 'coords = 0;';
  } else {
    coordsSnippet = broadcastDims.map(function (d) {
      return "coords.".concat(fields[d + rankDiff], " = 0;");
    }).join('\n');
  }
  var unpackedCoordsSnippet = '';
  if (outRank < 2 && inRank > 0) {
    unpackedCoordsSnippet = 'coords';
  } else {
    // Keep only the trailing output coords that correspond to input dims.
    unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape.map(function (s, i) {
      return "coords.".concat(fields[i + rankDiff]);
    }).join(', ');
  }
  return "\n float ".concat(funcName, "() {\n ").concat(type, " coords = getOutputCoords();\n ").concat(coordsSnippet, "\n return get").concat(texFuncSnippet, "(").concat(unpackedCoordsSnippet, ");\n }\n ");
}
/**
 * Maps a tensor rank to the GLSL type used to hold its coordinates.
 *
 * @param {number} rank - Tensor rank; any value <= 1 uses a plain `int`.
 * @returns {string} 'int' or 'ivec2'..'ivec6'.
 * @throws {Error} For any rank that is not <= 1 and not exactly 2..6.
 */
function getCoordsDataType(rank) {
  if (rank <= 1) {
    return 'int';
  }
  switch (rank) {
    case 2:
      return 'ivec2';
    case 3:
      return 'ivec3';
    case 4:
      return 'ivec4';
    case 5:
      return 'ivec5';
    case 6:
      return 'ivec6';
    default:
      throw Error("GPU for rank ".concat(rank, " is not yet supported"));
  }
}
/**
 * Decides which shape to use for an input's shape uniform, and whether it is
 * a squeezed shape.
 *
 * A packed rank-3 input with a unit leading dim uses its trailing two dims.
 * An unpacked input of rank > 1 whose logical shape differs from its texture
 * shape, and which has unit dims, uses the squeezeShape result.
 *
 * @param {boolean} isPacked - Whether the input texture is packed.
 * @param {number[]} shape - Logical shape.
 * @param {number[]} texShape - Physical texture shape.
 * @returns {{useSqueezeShape: boolean, uniformShape: number[], keptDims: number[]}}
 *     `keptDims` always comes from squeezeShape(shape).
 */
function getUniformInfoFromShape(isPacked, shape, texShape) {
  var squeezed = squeezeShape(shape);
  var rank = shape.length;
  var isSqueezablePacked = isPacked && rank === 3 && shape[0] === 1;
  var isSqueezableUnpacked = !isPacked && rank > 1 && !arraysEqual(shape, texShape) && squeezed.newShape.length < rank;
  var useSqueezeShape = isSqueezableUnpacked || isSqueezablePacked;
  var uniformShape = shape;
  if (useSqueezeShape) {
    uniformShape = isSqueezablePacked ? shape.slice(1) : squeezed.newShape;
  }
  return {
    useSqueezeShape: useSqueezeShape,
    uniformShape: uniformShape,
    keptDims: squeezed.keptDims
  };
}
/**
 * Returns a deep copy of `inInfo` whose `shapeInfo.logicalShape` is replaced
 * by `squeezedShape`. The original object is left untouched.
 *
 * The copy is made via a JSON round-trip, so only JSON-representable fields
 * survive (e.g. properties whose value is `undefined` are dropped).
 *
 * @param {Object} inInfo - Input info with a `shapeInfo` object.
 * @param {number[]} squeezedShape - Logical shape to install on the copy.
 * @returns {Object} The modified copy.
 */
function squeezeInputInfo(inInfo, squeezedShape) {
  var serialized = JSON.stringify(inInfo);
  var copy = JSON.parse(serialized);
  copy.shapeInfo.logicalShape = squeezedShape;
  return copy;
}
/**
 * Joins, with ', ', the parameter names at the kept dimension indices.
 * Used to forward only the non-squeezed coordinates to a lower-rank sampler.
 *
 * @param {string[]} params - Parameter names, one per original dim.
 * @param {number[]} keptDims - Indices of dims that survive squeezing.
 * @returns {string} e.g. 'row, depth' for keptDims [0, 2].
 */
function getSqueezedParams(params, keptDims) {
  var picked = [];
  for (var i = 0; i < keptDims.length; i++) {
    picked.push(params[keptDims[i]]);
  }
  return picked.join(', ');
}
112346
/**
 * Generates the shader source for `program` from the concrete inputs and
 * output, compiles and links it, and returns the compiled binary record.
 *
 * When env flag ENGINE_COMPILE_ONLY is set, the VAO is not built and every
 * uniform-location field of the result is null. Otherwise buildVao runs and
 * the locations come from getUniformLocations.
 *
 * @param {Object} gpgpu - Provides `gl`, `createProgram(fragmentShader)` and `buildVao(program)`.
 * @param {Object} program - `variableNames[i]` is the shader name of input i.
 * @param {Object[]} inputs - Records with `shape`, `isUniform` and (for textures) `texData`.
 * @param {Object} output - Record with `shape` and `texData`.
 * @returns {Object} {program, fragmentShader, source, webGLProgram, inShapeInfos,
 *     outShapeInfo, ...uniform locations}.
 */
function compileProgram(gpgpu, program, inputs, output) {
  var inputInfos = inputs.map(function (input, idx) {
    var isUniform = input.isUniform;
    var texData = input.texData;
    var shapeInfo = {
      logicalShape: input.shape,
      texShape: isUniform ? null : texData.texShape,
      isUniform: isUniform,
      isPacked: isUniform ? false : texData.isPacked,
      flatOffset: null
    };
    // Only a strictly positive slice offset is recorded; otherwise it stays null.
    var slice = texData != null ? texData.slice : null;
    if (slice != null && slice.flatOffset > 0) {
      shapeInfo.flatOffset = slice.flatOffset;
    }
    return {
      name: program.variableNames[idx],
      shapeInfo: shapeInfo
    };
  });
  var inShapeInfos = inputInfos.map(function (info) {
    return info.shapeInfo;
  });
  var outTexData = output.texData;
  var outShapeInfo = {
    logicalShape: output.shape,
    texShape: outTexData.texShape,
    isUniform: false,
    isPacked: outTexData.isPacked,
    flatOffset: null
  };
  var source = makeShader(inputInfos, outShapeInfo, program);
  var fragmentShader = createFragmentShader(gpgpu.gl, source);
  var webGLProgram = gpgpu.createProgram(fragmentShader);
  var binary = {
    program: program,
    fragmentShader: fragmentShader,
    source: source,
    webGLProgram: webGLProgram,
    inShapeInfos: inShapeInfos,
    outShapeInfo: outShapeInfo
  };
  if (env().get('ENGINE_COMPILE_ONLY')) {
    return Object.assign(binary, {
      variablesLocations: null,
      customUniformLocations: null,
      infLoc: null,
      nanLoc: null,
      outShapeLocation: null,
      outShapeStridesLocation: null,
      outTexShapeLocation: null
    });
  }
  gpgpu.buildVao(webGLProgram);
  return Object.assign(binary, getUniformLocations(gpgpu, program, webGLProgram));
}
/**
 * Looks up every uniform location the generated shader may declare.
 *
 * Lookups happen in this order:
 *  1. NAN, then INFINITY (only when WEBGL_VERSION is 1);
 *  2. for each input variable: `<name>` and `offset<name>`, plus `<name>Shape`
 *     and `<name>TexShape` when the program enables shape uniforms;
 *  3. `outShape`, `outShapeStrides` and `outTexShape` when shape uniforms are on;
 *  4. each entry of `program.customUniforms`, by its `name`.
 * Every lookup passes `false` as getUniformLocation's third argument.
 * NOTE(review): that flag presumably makes uniforms the GLSL compiler
 * optimized away resolve without throwing; confirm against GPGPUContext.
 *
 * @param {Object} gpgpu - Provides `getUniformLocation(program, name, shouldThrow)`.
 * @param {Object} program - Reads `variableNames`, `enableShapeUniforms` and `customUniforms`.
 * @param {Object} webGLProgram - The linked program to query.
 * @returns {Object} Location record; the out-shape fields are undefined when
 *     shape uniforms are disabled, and `infLoc` is null unless WebGL 1.
 */
function getUniformLocations(gpgpu, program, webGLProgram) {
  var lenient = false;
  var nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
  var infLoc = null;
  if (env().getNumber('WEBGL_VERSION') === 1) {
    infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
  }

  var variablesLocations = [];
  var names = program.variableNames;
  for (var i = 0; i < names.length; i++) {
    var varName = names[i];
    var locs = {
      name: varName,
      uniform: gpgpu.getUniformLocation(webGLProgram, varName, lenient),
      offset: gpgpu.getUniformLocation(webGLProgram, "offset".concat(varName), lenient)
    };
    if (program.enableShapeUniforms) {
      locs.shape = gpgpu.getUniformLocation(webGLProgram, "".concat(varName, "Shape"), lenient);
      locs.texShape = gpgpu.getUniformLocation(webGLProgram, "".concat(varName, "TexShape"), lenient);
    }
    variablesLocations.push(locs);
  }

  var outShapeLocation;
  var outShapeStridesLocation;
  var outTexShapeLocation;
  if (program.enableShapeUniforms) {
    outShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outShape', lenient);
    outShapeStridesLocation = gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', lenient);
    outTexShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outTexShape', lenient);
  }

  var customUniformLocations = [];
  var customs = program.customUniforms;
  if (customs) {
    for (var j = 0; j < customs.length; j++) {
      customUniformLocations.push(gpgpu.getUniformLocation(webGLProgram, customs[j].name, lenient));
    }
  }

  return {
    variablesLocations: variablesLocations,
    customUniformLocations: customUniformLocations,
    infLoc: infLoc,
    nanLoc: nanLoc,
    outShapeLocation: outShapeLocation,
    outShapeStridesLocation: outShapeStridesLocation,
    outTexShapeLocation: outTexShapeLocation
  };
}
/**
 * Checks that the tensors handed to a cached binary match the shape info it
 * was compiled with.
 *
 * @param shapeInfos Compile-time shape info, one per tensor (`logicalShape`,
 *     `texShape`, `isUniform`).
 * @param inputs Run-time tensor data (`shape`, `isUniform`, `texData`).
 * @throws Error when the counts differ, when a logical shape differs, or when
 *     a texture shape differs (skipped only if both sides are uniforms).
 */
function validateBinaryAndProgram(shapeInfos, inputs) {
  var compiledCount = shapeInfos.length;
  var givenCount = inputs.length;
  if (compiledCount !== givenCount) {
    throw Error("Binary was compiled with ".concat(compiledCount, " inputs, but ") + "was executed with ".concat(givenCount, " inputs"));
  }
  shapeInfos.forEach(function (compiled, idx) {
    var actual = inputs[idx];
    var compiledShape = compiled.logicalShape;
    var actualShape = actual.shape;
    if (!arraysEqual(compiledShape, actualShape)) {
      throw Error("Binary was compiled with different shapes than " + "the current args. Shapes ".concat(compiledShape, " and ").concat(actualShape, " must match"));
    }
    // When both sides are uploaded as uniforms there is no texture to compare.
    var bothUniform = compiled.isUniform && actual.isUniform;
    if (bothUniform) {
      return;
    }
    var compiledTexShape = compiled.texShape;
    var actualTexShape = actual.isUniform ? null : actual.texData.texShape;
    if (!arraysEqual(compiledTexShape, actualTexShape)) {
      throw Error("Binary was compiled with different texture shapes than the" + " current args. Shape ".concat(compiledTexShape, " and ").concat(actualTexShape, " must match"));
    }
  });
}
112492 function runProgram(gpgpu, binary, inputs, output, customUniformValues) {
112493 if (!binary.program.enableShapeUniforms) {
112494 validateBinaryAndProgram(binary.inShapeInfos, inputs);
112495 validateBinaryAndProgram([binary.outShapeInfo], [output]);
112496 }
112497 var outTex = output.texData.texture;
112498 var outTexShape = output.texData.texShape;
112499 if (output.texData.isPacked) {
112500 gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
112501 } else {
112502 gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
112503 }
112504 gpgpu.setProgram(binary.webGLProgram);
112505 gpgpu.bindVertexArray(binary.webGLProgram.vao);
112506 // Set special uniforms (NAN, INFINITY)
112507 if (env().getNumber('WEBGL_VERSION') === 1) {
112508 if (binary.infLoc !== null) {
112509 gpgpu.gl.uniform1f(binary.infLoc, Infinity);
112510 }
112511 }
112512 if (binary.nanLoc !== null) {
112513 gpgpu.gl.uniform1f(binary.nanLoc, NaN);
112514 }
112515 // Set user-defined inputs
112516 for (var i = 0; i < inputs.length; ++i) {
112517 var input = inputs[i];
112518 var _binary$variablesLoca = binary.variablesLocations[i],
112519 varLoc = _binary$variablesLoca.uniform,
112520 varOffsetLoc = _binary$variablesLoca.offset,
112521 varShapeLoc = _binary$variablesLoca.shape,
112522 varTexShapeLoc = _binary$variablesLoca.texShape;
112523 if (varShapeLoc) {
112524 var _shader_compiler$getU = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape),
112525 uniformShape = _shader_compiler$getU.uniformShape;
112526 switch (uniformShape.length) {
112527 case 1:
112528 gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape));
112529 break;
112530 case 2:
112531 gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape));
112532 break;
112533 case 3:
112534 gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape));
112535 break;
112536 case 4:
112537 gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape));
112538 break;
112539 default:
112540 break;
112541 }
112542 }
112543 if (varTexShapeLoc) {
112544 gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]);
112545 }
112546 if (varLoc == null) {
112547 // The compiler inferred that this variable is not used in this shader.
112548 continue;
112549 }
112550 if (input.isUniform) {
112551 // Upload the values of the tensor as uniform.
112552 if (sizeFromShape(input.shape) < 2) {
112553 gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
112554 } else {
112555 var vals = input.uniformValues;
112556 if (!(vals instanceof Float32Array)) {
112557 vals = new Float32Array(vals);
112558 }
112559 gpgpu.gl.uniform1fv(varLoc, vals);
112560 }
112561 continue;
112562 }
112563 // If the input was sliced, upload the flat offset index.
112564 if (input.texData.slice != null && varOffsetLoc != null) {
112565 gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
112566 }
112567 gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i);
112568 }
112569 var outShapeLoc = binary.outShapeLocation;
112570 if (outShapeLoc) {
112571 switch (output.shape.length) {
112572 case 1:
112573 gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape));
112574 break;
112575 case 2:
112576 gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape));
112577 break;
112578 case 3:
112579 gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape));
112580 break;
112581 case 4:
112582 gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape));
112583 break;
112584 default:
112585 break;
112586 }
112587 }
112588 if (binary.outShapeStridesLocation) {
112589 var strides = computeStrides(output.shape);
112590 switch (output.shape.length) {
112591 case 2:
112592 gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides));
112593 break;
112594 case 3:
112595 gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides));
112596 break;
112597 case 4:
112598 gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides));
112599 break;
112600 default:
112601 break;
112602 }
112603 }
112604 if (binary.outTexShapeLocation) {
112605 gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]);
112606 }
112607 if (binary.program.customUniforms && customUniformValues) {
112608 for (var _i = 0; _i < binary.program.customUniforms.length; ++_i) {
112609 var d = binary.program.customUniforms[_i];
112610 var customLoc = binary.customUniformLocations[_i];
112611 var customValue = customUniformValues[_i];
112612 if (d.type === 'float') {
112613 gpgpu.gl.uniform1fv(customLoc, customValue);
112614 } else if (d.type === 'vec2') {
112615 gpgpu.gl.uniform2fv(customLoc, customValue);
112616 } else if (d.type === 'vec3') {
112617 gpgpu.gl.uniform3fv(customLoc, customValue);
112618 } else if (d.type === 'vec4') {
112619 gpgpu.gl.uniform4fv(customLoc, customValue);
112620 } else if (d.type === 'int') {
112621 gpgpu.gl.uniform1iv(customLoc, customValue);
112622 } else if (d.type === 'ivec2') {
112623 gpgpu.gl.uniform2iv(customLoc, customValue);
112624 } else if (d.type === 'ivec3') {
112625 gpgpu.gl.uniform3iv(customLoc, customValue);
112626 } else if (d.type === 'ivec4') {
112627 gpgpu.gl.uniform4iv(customLoc, customValue);
112628 } else {
112629 throw Error("uniform type ".concat(d.type, " is not supported yet."));
112630 }
112631 }
112632 }
112633 gpgpu.executeProgram();
112634 }
/**
 * Builds the cache key under which a compiled shader binary is stored.
 *
 * The key is `<constructor name>_<per-tensor keys>_<userCode><WEBGL_VERSION>`,
 * where a per-tensor key is emitted for every input followed by the output.
 * With shape uniforms enabled (and a non-uniform tensor), the per-tensor key
 * lists only the shape-derived facts that the shader compiler embeds in the
 * shader. Otherwise it is the full shape, texture shape (or 'uniform') and
 * whether the tensor is a slice with a non-zero offset.
 *
 * @param program Program description (`enableShapeUniforms`, `packedInputs`,
 *     `userCode`, constructor name).
 * @param inputs Input tensor data.
 * @param output Output tensor data.
 * @returns {string} The cache key.
 */
function makeShaderKey(program, inputs, output) {
  var packed = program.packedInputs;

  // Key fragment for one tensor (an input or the output).
  function keyForTensor(x) {
    var hasOffset = x.texData != null && x.texData.slice != null && x.texData.slice.flatOffset > 0;
    // TODO: Remove the condition of !x.isUniform.
    if (!program.enableShapeUniforms || x.isUniform) {
      var texShapeKey = x.isUniform ? 'uniform' : x.texData.texShape;
      return "".concat(x.shape, "_").concat(texShapeKey, "_").concat(hasOffset);
    }
    var texShape = x.texData.texShape;
    var info = getUniformInfoFromShape(packed, x.shape, texShape);
    var uniformShape = info.uniformShape;
    var uniformRank = uniformShape.length;

    // |rank1| is used in getOutputPacked1DCoords, |rank2| in
    // getOutput2DCoords, and |rank34| in getSampler3D/getSampler4D.
    var rank1 = '';
    var rank2 = '';
    var rank34 = '';
    if (uniformRank === 1 && packed) {
      rank1 = "".concat(Math.ceil(texShape[0] / 2) > 1, "_").concat(Math.ceil(texShape[1] / 2) > 1);
    } else if (uniformRank === 2 && !packed) {
      rank2 = "".concat(uniformShape[0] > 1, "_").concat(uniformShape[1] > 1);
    } else if (uniformRank > 2 && !packed) {
      var strides = computeStrides(uniformShape);
      rank34 = "".concat(strides[0] === texShape[1], "_").concat(strides[strides.length - 1] === texShape[1]);
    }

    var xRank = x.shape.length; // coords length in get[Packed]SamplerAtOutputCoords
    // Used in getOutput[Packed]2DCoords / get[Packed]Sampler2D.
    var logicalShapeMatchesTexShape = uniformRank === 2 && arraysEqual(x.shape, texShape);
    // From isInputScalar/isOutputScalar in getPackedSamplerAtOutputCoords.
    var isScalar = sizeFromShape(x.shape) === 1;
    // From get[Packed]SamplerAtOutputCoords.
    var broadcastDims = getBroadcastDims$1(x.shape, output.shape);
    // Selects the fast path in getSamplerAtOutputCoords.
    var inOutTexShapesMatch = !packed && xRank === output.shape.length && arraysEqual(texShape, output.texData.texShape);
    // Used in getSampler[Scalar|1D|2D] / getOutput1DCoords.
    var texShapeGreaterThanOne = packed || uniformRank > 2 ? '' : "".concat(texShape[0] > 1, "_").concat(texShape[1] > 1);
    // keptDims reflects squeezeInputInfo in getSampler[2|3|4]D/getPackedSampler3D.
    var squeezeKey = info.useSqueezeShape ? info.keptDims : '';

    return "".concat(xRank, "_")
        .concat(inOutTexShapesMatch, "_")
        .concat(squeezeKey, "_")
        .concat(uniformRank, "_")
        .concat(isScalar, "_")
        .concat(broadcastDims, "_")
        .concat(logicalShapeMatchesTexShape, "_")
        .concat(rank1, "_")
        .concat(rank2, "_")
        .concat(rank34, "_")
        .concat(texShapeGreaterThanOne, "_")
        .concat(hasOffset);
  }

  var keyInputs = inputs.concat(output).map(keyForTensor).join('');
  return program.constructor.name + '_' + keyInputs + '_' + program.userCode + "".concat(env().getNumber('WEBGL_VERSION'));
}
/**
 * Whether programs of the given output rank should pass shapes as uniforms.
 *
 * @param rank Tensor rank.
 * @returns The value of the WEBGL_USE_SHAPES_UNIFORMS flag when it is falsy,
 *     otherwise whether `rank <= 4`.
 */
function useShapeUniforms(rank) {
  // TODO: Remove the limitation of rank <= 4.
  var flagEnabled = env().getBool('WEBGL_USE_SHAPES_UNIFORMS');
  return flagEnabled && rank <= 4;
}
112698
/**
 * Program that reads the unpacked input `A` through 3D coordinates (r, c, d)
 * and writes four consecutive flat-index values into each RGBA output texel.
 * The custom `texShape` uniform is used to turn the fragment's `resultUV` into
 * the texel's first flat index.
 */
var DecodeMatrixProgram = /*#__PURE__*/_createClass(function DecodeMatrixProgram(outputShape) {
  _classCallCheck(this, DecodeMatrixProgram);
  this.variableNames = ['A'];
  this.packedInputs = false;
  this.packedOutput = true;
  this.outPackingScheme = PackingScheme.DENSE;
  this.customUniforms = [{
    name: 'texShape',
    type: 'ivec2'
  }];
  var glsl = getGlslDifferences();
  this.outputShape = outputShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var coordNames = ['r', 'c', 'd'];
  var flatToCoords = this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(coordNames, outputShape) : getLogicalCoordinatesFromFlatIndex(coordNames, outputShape);
  this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + flatToCoords +
      "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));\n int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n " +
      glsl.output + " = result;\n }\n ";
});
112714
/**
 * Packed-input variant of DecodeMatrixProgram: reads packed input `A` through
 * 3D coordinates (r, c, d), selects the right channel with `getChannel`, and
 * writes four consecutive flat-index values into each RGBA output texel.
 */
var DecodeMatrixPackedProgram = /*#__PURE__*/_createClass(function DecodeMatrixPackedProgram(outputShape) {
  _classCallCheck(this, DecodeMatrixPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  this.outPackingScheme = PackingScheme.DENSE;
  this.customUniforms = [{
    name: 'texShape',
    type: 'ivec2'
  }];
  var glsl = getGlslDifferences();
  this.outputShape = outputShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var coordNames = ['r', 'c', 'd'];
  var flatToCoords = this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(coordNames, outputShape) : getLogicalCoordinatesFromFlatIndex(coordNames, outputShape);
  this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + flatToCoords +
      "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));\n int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n " +
      glsl.output + " = result;\n }\n ";
});
112730
/**
 * Program that reads `A` at the output coordinates and writes the value
 * through `encode_float` (from ENCODE_FLOAT_SNIPPET) into a texture marked for
 * download (`TextureUsage.DOWNLOAD`).
 */
var EncodeFloatProgram = /*#__PURE__*/_createClass(function EncodeFloatProgram(outputShape) {
  _classCallCheck(this, EncodeFloatProgram);
  this.variableNames = ['A'];
  this.outTexUsage = TextureUsage.DOWNLOAD;
  var glsl = getGlslDifferences();
  this.outputShape = outputShape;
  var mainFn = "\n\n void main() {\n float x = getAAtOutCoords();\n " + glsl.output + " = encode_float(x);\n }\n ";
  this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + mainFn;
});
112739
/**
 * Packed-input variant of EncodeFloatProgram: selects the channel of packed
 * `A` for the output coordinates with `getChannel`, then writes it through
 * `encode_float` into a texture marked for download.
 */
var EncodeFloatPackedProgram = /*#__PURE__*/_createClass(function EncodeFloatPackedProgram(outputShape) {
  _classCallCheck(this, EncodeFloatPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = false;
  this.outTexUsage = TextureUsage.DOWNLOAD;
  var glsl = getGlslDifferences();
  this.outputShape = outputShape;
  var mainFn = "\n\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n " + glsl.output + " = encode_float(x);\n }\n ";
  this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + mainFn;
});
112750
// Maps a texel channel letter to its component index in a GLSL vec4.
var CHANNEL_CHAR_TO_INDEX_MAP = {
  R: 0,
  G: 1,
  B: 2,
  A: 3
};
/**
 * Program that reads texture `A`, which holds a flat tensor stored as
 * `usedChannels.length` consecutive values per texel (in `usedChannels`
 * order), and writes one value per output texel into the first component of
 * the output vec4.
 *
 * @param outputShape Output shape, addressed through 3D coordinates by
 *     `getFlatIndex`.
 * @param inputIsUnsignedByte (optional, default false) When true, the sampled
 *     value is written as `floor(result * 255. + 0.5)`.
 * @param usedChannels (optional, default 'RGBA') The texel channels that hold
 *     data, in storage order; every character must be R, G, B or A.
 * @throws Error if `usedChannels` is empty or names an unknown channel.
 */
var EncodeMatrixProgram = /*#__PURE__*/_createClass(function EncodeMatrixProgram(outputShape) {
  var inputIsUnsignedByte = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
  var usedChannels = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 'RGBA';
  _classCallCheck(this, EncodeMatrixProgram);
  // Validate before generating GLSL: an unknown channel would be spliced into
  // the shader as `values[undefined]`, and an empty list would produce
  // `imod(flatIndex, 0)`, both of which only surface later as shader problems.
  if (usedChannels.length === 0) {
    throw new Error('EncodeMatrixProgram: usedChannels must name at least one channel.');
  }
  for (var checkIndex = 0; checkIndex < usedChannels.length; checkIndex++) {
    if (!Object.prototype.hasOwnProperty.call(CHANNEL_CHAR_TO_INDEX_MAP, usedChannels[checkIndex])) {
      throw new Error("EncodeMatrixProgram: unsupported channel '".concat(usedChannels[checkIndex], "' in usedChannels '").concat(usedChannels, "'; expected only R, G, B or A."));
    }
  }
  this.variableNames = ['A'];
  this.customUniforms = [{
    name: 'texShape',
    type: 'ivec2'
  }];
  var glsl = getGlslDifferences();
  this.outputShape = outputShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var output = "result";
  if (inputIsUnsignedByte) {
    output = "floor(result * 255. + 0.5)";
  }
  // One branch per used channel: the flat index's position within its texel
  // selects which vec4 component to read.
  var mainLoop = '';
  for (var usedChannelIndex = 0; usedChannelIndex < usedChannels.length; usedChannelIndex++) {
    var curChannel = usedChannels[usedChannelIndex];
    mainLoop += "\n if(offset == ".concat(usedChannelIndex, ") {\n result = values[").concat(CHANNEL_CHAR_TO_INDEX_MAP[curChannel], "];\n }");
  }
  this.userCode = "\n ".concat(this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape), "\n\n void main() {\n ivec3 coords = getOutputCoords();\n int flatIndex = getFlatIndex(coords);\n float result = 0.;\n int offset = imod(flatIndex, ").concat(usedChannels.length, ");\n\n flatIndex = idiv(flatIndex, ").concat(usedChannels.length, ", 1.);\n\n int r = flatIndex / texShape[1];\n if (r < texShape[0]) {\n int c = imod(flatIndex, texShape[1]);\n vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);\n vec4 values = ").concat(glsl.texture2D, "(A, uv);\n ").concat(mainLoop, "\n }\n ").concat(glsl.output, " = vec4(").concat(output, ", 0., 0., 0.);\n }\n ");
});
112780
112781 /*
112782 This is how the shader encodes a tensor with shape = [2, 3, 5]
112783 (indices are [batch, row, col]).
112784
112785 000|001 002|003 004|xxx 020|021 022|023 024|xxx
112786 ------- ------- ------- ------- ------- -------
112787 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
112788
112789 100|101 102|103 104|xxx 120|121 122|123 124|xxx
112790 ------- ------- ------- ------- ------- -------
112791 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
112792
112793 Single texels contain only values from the same batch, and from adjacent rows
112794 and columns.
112795 */
/**
 * Program that reads an unpacked flat tensor from `A` (four values per texel)
 * and writes a packed output texel holding the 2x2 block of values at
 * (row, col), (row, col + 1), (row + 1, col), (row + 1, col + 1) of the 3D
 * output coordinates. Neighbours beyond `outShape[1]`/`outShape[2]` (or the
 * static shape) stay 0.
 *
 * @param outputShape Output shape, addressed through 3D coordinates.
 * @param inputIsUnsignedByte (optional, default false) When true, the result
 *     is written as `floor(result * 255. + 0.5)`.
 */
var EncodeMatrixPackedProgram = /*#__PURE__*/_createClass(function EncodeMatrixPackedProgram(outputShape) {
  var inputIsUnsignedByte = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
  _classCallCheck(this, EncodeMatrixPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = false;
  this.packedOutput = true;
  this.customUniforms = [{
    name: 'texShape',
    type: 'ivec2'
  }];
  var glsl = getGlslDifferences();
  this.outputShape = outputShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var useUniforms = this.enableShapeUniforms;
  var output = inputIsUnsignedByte ? 'floor(result * 255. + 0.5)' : 'result';
  // Upper bounds for the column (dim 2) and row (dim 1) neighbours.
  var colLimit = useUniforms ? 'outShape[2]' : "".concat(outputShape[2]);
  var rowLimit = useUniforms ? 'outShape[1]' : "".concat(outputShape[1]);

  // Channel k of the packed texel holds (row, col) = (k >> 1, k & 1).
  var texelSnippets = [];
  for (var channel = 0; channel < 4; channel++) {
    var row = channel >> 1;
    var col = channel & 1;
    texelSnippets.push(
        "\n localCoords = coords;\n if(localCoords[2] + " + col + " < " + colLimit +
        ") {\n localCoords[2] += " + col + ";\n if (localCoords[1] + " + row + " < " + rowLimit +
        ") {\n localCoords[1] += " + row +
        ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n int r = flatIndex / texShape[1];\n int c = imod(flatIndex, texShape[1]);\n vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);\n values = " +
        glsl.texture2D + "(A, uv);\n\n if (offset == 0) {\n result[" + channel +
        "] = values[0];\n } else if (offset == 1) {\n result[" + channel +
        "] = values[1];\n } else if (offset == 2) {\n result[" + channel +
        "] = values[2];\n } else {\n result[" + channel +
        "] = values[3];\n }\n }\n }\n ");
  }
  var mainLoop = texelSnippets.join('');
  var flatIndexSnippet = useUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape);
  this.userCode = "\n " + flatIndexSnippet +
      "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n " +
      mainLoop + "\n\n " + glsl.output + " = " + output + ";\n }\n ";
});
112822
/**
 * Compiles the pass-through vertex shader shared by all GPGPU programs: it
 * forwards `clipSpacePos` to `gl_Position` and `uv` to the `resultUV` varying.
 *
 * @param gl The WebGL context.
 * @returns Whatever `createVertexShader$1` returns for the generated source.
 */
function createVertexShader(gl) {
  var glsl = getGlslDifferences();
  var source = glsl.version +
      '\n precision highp float;' +
      '\n ' + glsl.attribute + ' vec3 clipSpacePos;' +
      '\n ' + glsl.attribute + ' vec2 uv;' +
      '\n ' + glsl.varyingVs + ' vec2 resultUV;' +
      '\n' +
      '\n void main() {' +
      '\n gl_Position = vec4(clipSpacePos, 1);' +
      '\n resultUV = uv;' +
      '\n }';
  return createVertexShader$1(gl, source);
}
/**
 * Creates the static vertex buffer for the full-screen quad.
 *
 * @param gl The WebGL context.
 * @returns The buffer returned by `createStaticVertexBuffer`.
 */
function createVertexBuffer(gl) {
  // One interleaved record per corner: clip-space x, y, z followed by u, v.
  var corners = [
    [-1, 1, 0, 0, 1],  // upper-left
    [-1, -1, 0, 0, 0], // lower-left
    [1, 1, 0, 1, 1],   // upper-right
    [1, -1, 0, 1, 0]   // lower-right
  ];
  var floatsPerCorner = 5;
  var interleaved = new Float32Array(corners.length * floatsPerCorner);
  corners.forEach(function (corner, i) {
    interleaved.set(corner, i * floatsPerCorner);
  });
  return createStaticVertexBuffer(gl, interleaved);
}
/**
 * Creates the static index buffer that draws the quad as two triangles.
 *
 * @param gl The WebGL context.
 * @returns The buffer returned by `createStaticIndexBuffer`.
 */
function createIndexBuffer(gl) {
  // OpenGL (and WebGL) treat counter-clockwise winding as front-facing; both
  // triangles below are wound that way for the corner order
  // [upper-left, lower-left, upper-right, lower-right].
  var firstTriangle = [0, 1, 2];
  var secondTriangle = [2, 1, 3];
  return createStaticIndexBuffer(gl, new Uint16Array(firstTriangle.concat(secondTriangle)));
}
/**
 * Creates a 2D texture with edge clamping and nearest filtering and allocates
 * storage for it, leaving TEXTURE_2D unbound afterwards.
 *
 * @param gl The WebGL context.
 * @param width Texture width in texels.
 * @param height Texture height in texels.
 * @param internalFormat Internal format used for the allocation.
 * @param textureFormat Pixel format (only used on the WebGL 1 path).
 * @param textureType Pixel type (only used on the WebGL 1 path).
 * @returns {{texture, texShape: number[]}} The texture and its shape as
 *     [height, width].
 */
function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
  validateTextureSize(width, height);
  var texture = createTexture(gl);
  var target = gl.TEXTURE_2D;
  var checked = function (op) {
    callAndCheck(gl, op);
  };
  checked(function () {
    return gl.bindTexture(target, texture);
  });
  var params = [
    [gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE],
    [gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE],
    [gl.TEXTURE_MIN_FILTER, gl.NEAREST],
    [gl.TEXTURE_MAG_FILTER, gl.NEAREST]
  ];
  params.forEach(function (param) {
    checked(function () {
      return gl.texParameteri(target, param[0], param[1]);
    });
  });
  if (env().getNumber('WEBGL_VERSION') !== 1) {
    // WebGL 2: immutable storage with a single mip level.
    checked(function () {
      return gl.texStorage2D(target, 1, internalFormat, width, height);
    });
  } else {
    // WebGL 1: allocate via texImage2D without initial data.
    checked(function () {
      return gl.texImage2D(target, 0, internalFormat, width, height, 0, textureFormat, textureType, null);
    });
  }
  checked(function () {
    return gl.bindTexture(gl.TEXTURE_2D, null);
  });
  return {
    texture: texture,
    texShape: [height, width]
  };
}
/** Returns the internal format for unpacked float32 matrix textures. */
function getInternalFormatForFloat32MatrixTexture(textureConfig) {
  var format = textureConfig.internalFormatFloat;
  return format;
}
/**
 * Creates an unpacked float32 texture sized for a rows x columns matrix.
 *
 * @returns The `{texture, texShape}` record from createAndConfigureTexture.
 */
function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
  // The helper returns the texture size as [width, height].
  var size = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  var internalFormat = getInternalFormatForFloat32MatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, size[0], size[1], internalFormat, textureConfig.textureFormatFloat, gl.FLOAT);
}
/** Returns the internal format for unpacked half-float matrix textures. */
function getInternalFormatForFloat16MatrixTexture(textureConfig) {
  var format = textureConfig.internalFormatHalfFloat;
  return format;
}
/**
 * Creates an unpacked half-float texture sized for a rows x columns matrix.
 *
 * @returns The `{texture, texShape}` record from createAndConfigureTexture.
 */
function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
  // The helper returns the texture size as [width, height].
  var size = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  var internalFormat = getInternalFormatForFloat16MatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, size[0], size[1], internalFormat, textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
}
/** Returns the internal format for unsigned-byte (download) matrix textures. */
function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
  var format = textureConfig.downloadTextureFormat;
  return format;
}
/**
 * Creates an unpacked RGBA/UNSIGNED_BYTE texture sized for a rows x columns
 * matrix.
 *
 * @returns The `{texture, texShape}` record from createAndConfigureTexture.
 */
function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
  // The helper returns the texture size as [width, height].
  var size = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  var internalFormat = getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, size[0], size[1], internalFormat, gl.RGBA, gl.UNSIGNED_BYTE);
}
/** Returns the internal format for packed float32 matrix textures. */
function getInternalFormatForPackedMatrixTexture(textureConfig) {
  var format = textureConfig.internalFormatPackedFloat;
  return format;
}
/**
 * Creates a packed RGBA/FLOAT texture sized for a rows x columns matrix.
 *
 * @returns The `{texture, texShape}` record from createAndConfigureTexture.
 */
function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
  // The helper returns the texture size as [width, height].
  var size = getPackedMatrixTextureShapeWidthHeight(rows, columns);
  var internalFormat = getInternalFormatForPackedMatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, size[0], size[1], internalFormat, gl.RGBA, gl.FLOAT);
}
/** Returns the internal format for packed half-float matrix textures. */
function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
  var format = textureConfig.internalFormatPackedHalfFloat;
  return format;
}
/**
 * Creates a packed RGBA half-float texture sized for a rows x columns matrix.
 *
 * @returns The `{texture, texShape}` record from createAndConfigureTexture.
 */
function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
  // The helper returns the texture size as [width, height].
  var size = getPackedMatrixTextureShapeWidthHeight(rows, columns);
  var internalFormat = getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, size[0], size[1], internalFormat, gl.RGBA, textureConfig.textureTypeHalfFloat);
}
/**
 * Binds the quad vertex buffer and wires its interleaved [x y z u v] float
 * records to the `clipSpacePos` and `uv` attributes of `program`.
 *
 * @returns The falsy result of the `clipSpacePos` binding if it failed (the
 *     `uv` binding is then skipped), otherwise the result of the `uv` binding.
 */
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
  var bytesPerFloat = 4;
  var positionComponents = 3;
  var uvComponents = 2;
  var stride = (positionComponents + uvComponents) * bytesPerFloat;
  var uvByteOffset = positionComponents * bytesPerFloat; // uv follows x y z
  callAndCheck(gl, function () {
    return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
  });
  var positionBound = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, positionComponents, stride, 0);
  if (!positionBound) {
    return positionBound;
  }
  return bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, uvComponents, stride, uvByteOffset);
}
/**
 * Uploads `data` into an RGBA texture of `width` x `height` texels.
 *
 * Uint8Array data is uploaded as UNSIGNED_BYTE; anything else is copied into a
 * Float32Array and uploaded as FLOAT. The upload buffer always holds
 * `width * height * 4` values, and values past the end of `data` are 0.
 *
 * @param textureConfig Only read for non-byte data under WebGL 1
 *     (`internalFormatPackedFloat`).
 */
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
  callAndCheck(gl, function () {
    return gl.bindTexture(gl.TEXTURE_2D, texture);
  });
  var isBytes = data instanceof Uint8Array;
  var valueCount = width * height * 4;
  var dataForUpload = isBytes ? new Uint8Array(valueCount) : new Float32Array(valueCount);
  var texelDataType = isBytes ? gl.UNSIGNED_BYTE : gl.FLOAT;
  var internalFormat = isBytes ? gl.RGBA : textureConfig.internalFormatPackedFloat;
  // Typed arrays start zeroed, so this also pads `data` to the full size.
  dataForUpload.set(data);
  var upload;
  if (env().getNumber('WEBGL_VERSION') === 2) {
    upload = function () {
      return gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelDataType, dataForUpload);
    };
  } else {
    upload = function () {
      return gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload);
    };
  }
  callAndCheck(gl, upload);
  callAndCheck(gl, function () {
    return gl.bindTexture(gl.TEXTURE_2D, null);
  });
}
/**
 * Uploads RGBA/UNSIGNED_BYTE pixel data into `texture`.
 *
 * When `pixels.data` is a Uint8Array, the raw bytes are uploaded with the
 * explicit `pixels.width` x `pixels.height` size. Otherwise `pixels` itself
 * is handed to WebGL as the source (presumably an image-like object whose
 * size WebGL reads itself; confirm against callers). WebGL 2 uses
 * texSubImage2D, WebGL 1 uses texImage2D.
 */
function uploadPixelDataToTexture(gl, texture, pixels) {
  callAndCheck(gl, function () {
    return gl.bindTexture(gl.TEXTURE_2D, texture);
  });
  var hasRawBytes = pixels.data instanceof Uint8Array;
  var isWebGL2 = env().getNumber('WEBGL_VERSION') === 2;
  var upload;
  if (hasRawBytes && isWebGL2) {
    upload = function () {
      return gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data);
    };
  } else if (hasRawBytes) {
    upload = function () {
      return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data);
    };
  } else if (isWebGL2) {
    upload = function () {
      return gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels);
    };
  } else {
    upload = function () {
      return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels);
    };
  }
  callAndCheck(gl, upload);
  callAndCheck(gl, function () {
    return gl.bindTexture(gl.TEXTURE_2D, null);
  });
}
/**
 * Creates a PIXEL_PACK_BUFFER sized for `rows * columns` RGBA float32 texels
 * and enqueues a readPixels of the currently bound framebuffer into it
 * (WebGL 2 only: PIXEL_PACK_BUFFER and readPixels-into-buffer).
 *
 * @param textureConfig Unused; kept for interface compatibility.
 * @returns The buffer, unbound from PIXEL_PACK_BUFFER.
 */
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
  var buffer = gl2.createBuffer();
  var bindPackBuffer = function (target) {
    callAndCheck(gl2, function () {
      return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, target);
    });
  };
  bindPackBuffer(buffer);
  // 4 bytes per float x 4 values per texel.
  var byteLength = 4 * 4 * rows * columns;
  callAndCheck(gl2, function () {
    return gl2.bufferData(gl2.PIXEL_PACK_BUFFER, byteLength, gl2.STREAM_READ);
  });
  // With a pack buffer bound, the last readPixels argument is a byte offset
  // into that buffer rather than a client-side array.
  callAndCheck(gl2, function () {
    return gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0);
  });
  bindPackBuffer(null);
  return buffer;
}
/**
 * Copies the first `size` float32 values out of a PIXEL_PACK_BUFFER
 * (WebGL 2 `getBufferSubData`), leaving the pack buffer unbound.
 *
 * @returns {Float32Array} The downloaded values.
 */
function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
  var values = new Float32Array(size);
  var packTarget = gl.PIXEL_PACK_BUFFER;
  gl.bindBuffer(packTarget, buffer);
  gl.getBufferSubData(packTarget, 0, values);
  gl.bindBuffer(packTarget, null);
  return values;
}
/**
 * Reads a byte-encoded float matrix from the bound framebuffer.
 *
 * Each group of 4 bytes holds one float32, so the bytes are reinterpreted
 * through a Float32Array view (platform byte order, IEEE 754 decoding by the
 * runtime).
 *
 * @returns {Float32Array} A view over the downloaded bytes.
 */
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
  // The helper returns the texture size as [width, height].
  var size = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  var width = size[0];
  var height = size[1];
  var channelsPerTexel = 4;
  var bytes = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, channelsPerTexel));
  callAndCheck(gl, function () {
    return gl.readPixels(0, 0, width, height, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, bytes);
  });
  return new Float32Array(bytes.buffer);
}
/**
 * Copies a packed RGBA float32 matrix of `physicalRows` x `physicalCols`
 * texels out of a PIXEL_PACK_BUFFER, leaving the pack buffer unbound.
 *
 * `batch`, `rows`, `cols` and `textureConfig` are unused; they are kept for
 * interface compatibility.
 *
 * @returns {Float32Array} The raw packed values.
 */
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
  var packed = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
  var packTarget = gl.PIXEL_PACK_BUFFER;
  gl.bindBuffer(packTarget, buffer);
  gl.getBufferSubData(packTarget, 0, packed);
  gl.bindBuffer(packTarget, null);
  return packed;
}
/**
 * Reads `physicalRows` x `physicalCols` RGBA float texels from the bound
 * framebuffer with readPixels.
 *
 * @returns {Float32Array} The raw packed values (4 per texel).
 */
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
  var valuesPerTexel = 4;
  var packedRGBA = new Float32Array(physicalRows * physicalCols * valuesPerTexel);
  var readInto = function () {
    return gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA);
  };
  callAndCheck(gl, readInto);
  return packedRGBA;
}
113051
// Namespace object collecting the gpgpu_util helpers. It has a null prototype,
// so it exposes only the keys assigned below, in this order.
var gpgpu_util = Object.create(null);
gpgpu_util.bindVertexProgramAttributeStreams = bindVertexProgramAttributeStreams;
gpgpu_util.createBufferFromOutputTexture = createBufferFromOutputTexture;
gpgpu_util.createFloat16MatrixTexture = createFloat16MatrixTexture;
gpgpu_util.createFloat16PackedMatrixTexture = createFloat16PackedMatrixTexture;
gpgpu_util.createFloat32MatrixTexture = createFloat32MatrixTexture;
gpgpu_util.createIndexBuffer = createIndexBuffer;
gpgpu_util.createPackedMatrixTexture = createPackedMatrixTexture;
gpgpu_util.createUnsignedBytesMatrixTexture = createUnsignedBytesMatrixTexture;
gpgpu_util.createVertexBuffer = createVertexBuffer;
gpgpu_util.createVertexShader = createVertexShader;
gpgpu_util.downloadByteEncodedFloatMatrixFromOutputTexture = downloadByteEncodedFloatMatrixFromOutputTexture;
gpgpu_util.downloadFloat32MatrixFromBuffer = downloadFloat32MatrixFromBuffer;
gpgpu_util.downloadMatrixFromPackedOutputTexture = downloadMatrixFromPackedOutputTexture;
gpgpu_util.downloadPackedMatrixFromBuffer = downloadPackedMatrixFromBuffer;
gpgpu_util.getInternalFormatForFloat16MatrixTexture = getInternalFormatForFloat16MatrixTexture;
gpgpu_util.getInternalFormatForFloat16PackedMatrixTexture = getInternalFormatForFloat16PackedMatrixTexture;
gpgpu_util.getInternalFormatForFloat32MatrixTexture = getInternalFormatForFloat32MatrixTexture;
gpgpu_util.getInternalFormatForPackedMatrixTexture = getInternalFormatForPackedMatrixTexture;
gpgpu_util.getInternalFormatForUnsignedBytesMatrixTexture = getInternalFormatForUnsignedBytesMatrixTexture;
gpgpu_util.uploadDenseMatrixToTexture = uploadDenseMatrixToTexture;
gpgpu_util.uploadPixelDataToTexture = uploadPixelDataToTexture;
113076
var GPGPUContext = /*#__PURE__*/function () {
  /**
   * Wraps a WebGL context together with the GPU resources every GPGPU program
   * run through it shares: a vertex buffer, an index buffer, a framebuffer and
   * a lazily created vertex shader.
   *
   * @param {?WebGLRenderingContext} gl Context to adopt. When null/undefined a
   *     context is obtained with getWebGLContext(WEBGL_VERSION).
   * @throws {Error} When OES_vertex_array_object (WebGL1) or a required
   *     float / half-float rendering extension is unavailable.
   */
  function GPGPUContext(gl) {
    _classCallCheck(this, GPGPUContext);
    this.outputTexture = null;
    this.program = null;
    this.disposed = false;
    this.itemsToPoll = [];
    var glVersion = env().getNumber('WEBGL_VERSION');
    if (gl != null) {
      this.gl = gl;
      setWebGLContext(glVersion, gl);
    } else {
      this.gl = getWebGLContext(glVersion);
    }
    gl = this.gl;
    // Vertex array objects are core in WebGL2 but come from
    // OES_vertex_array_object in WebGL1; expose one API over both.
    if (env().getNumber('WEBGL_VERSION') === 2) {
      var gl2 = gl;
      this.createVertexArray = function () {
        return callAndCheck(gl2, function () {
          return gl2.createVertexArray();
        });
      };
      this.bindVertexArray = function (vao) {
        return callAndCheck(gl2, function () {
          return gl2.bindVertexArray(vao);
        });
      };
      this.deleteVertexArray = function (vao) {
        return callAndCheck(gl2, function () {
          return gl2.deleteVertexArray(vao);
        });
      };
      this.getVertexArray = function () {
        return callAndCheck(gl2, function () {
          return gl2.getParameter(gl2.VERTEX_ARRAY_BINDING);
        });
      };
    } else if (gl != null) {
      var ext = gl.getExtension('OES_vertex_array_object');
      if (ext == null) {
        throw new Error('All WebGL1 implementations are expected to offer' + ' OES_vertex_array_object.');
      }
      this.createVertexArray = function () {
        return callAndCheck(gl, function () {
          return ext.createVertexArrayOES();
        });
      };
      this.bindVertexArray = function (vao) {
        return callAndCheck(gl, function () {
          return ext.bindVertexArrayOES(vao);
        });
      };
      this.deleteVertexArray = function (vao) {
        return callAndCheck(gl, function () {
          return ext.deleteVertexArrayOES(vao);
        });
      };
      this.getVertexArray = function () {
        return callAndCheck(gl, function () {
          return gl.getParameter(ext.VERTEX_ARRAY_BINDING_OES);
        });
      };
    }
    // WebGL 2.0 enables texture floats without an extension.
    var COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
    var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
    this.parallelCompilationExtension = this.gl.getExtension('KHR_parallel_shader_compile');
    if (env().getNumber('WEBGL_VERSION') === 1) {
      var TEXTURE_FLOAT = 'OES_texture_float';
      var TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
      this.textureFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
      if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
        this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
      } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
        throw new Error('GL context does not support half float textures, yet the ' + 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
      }
      this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
      if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
        this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
      } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
        throw new Error('GL context does not support color renderable half floats, yet ' + 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
      }
    } else {
      COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
      if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
        this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
      } else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
        this.colorBufferHalfFloatExtension = this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
      } else {
        throw new Error('GL context does not support color renderable floats');
      }
    }
    this.vertexBuffer = createVertexBuffer(this.gl);
    this.indexBuffer = createIndexBuffer(this.gl);
    this.framebuffer = createFramebuffer(this.gl);
    this.textureConfig = getTextureConfig(this.gl, this.textureHalfFloatExtension);
  }
  _createClass(GPGPUContext, [{
    // True when the DEBUG environment flag is set; enables extra validation.
    key: "debug",
    get: function get() {
      return env().getBool('DEBUG');
    }
  }, {
    key: "dispose",
    /**
     * Releases the GPU objects this context created: the framebuffer, the
     * index and vertex buffers, and the cached vertex shader. Warns if a
     * program or output texture is still bound. Repeated calls are no-ops.
     */
    value: function dispose() {
      var _this = this;
      if (this.disposed) {
        return;
      }
      if (this.program != null) {
        console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' + ' This is probably a resource leak, delete the program with ' + 'GPGPUContext.deleteProgram before disposing.');
      }
      if (this.outputTexture != null) {
        console.warn('Disposing a GPGPUContext that still has a bound output matrix ' + 'texture. This is probably a resource leak, delete the output ' + 'matrix texture with GPGPUContext.deleteMatrixTexture before ' + 'disposing.');
      }
      var gl = this.gl;
      callAndCheck(gl, function () {
        return gl.finish();
      });
      callAndCheck(gl, function () {
        return gl.bindFramebuffer(gl.FRAMEBUFFER, null);
      });
      callAndCheck(gl, function () {
        return gl.deleteFramebuffer(_this.framebuffer);
      });
      callAndCheck(gl, function () {
        return gl.bindBuffer(gl.ARRAY_BUFFER, null);
      });
      callAndCheck(gl, function () {
        return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null);
      });
      callAndCheck(gl, function () {
        return gl.deleteBuffer(_this.indexBuffer);
      });
      // The vertex buffer is created in the constructor next to the index
      // buffer; release it too instead of leaking it.
      callAndCheck(gl, function () {
        return gl.deleteBuffer(_this.vertexBuffer);
      });
      // The vertex shader is created lazily by createProgram and cached here.
      if (this.vertexShader != null) {
        callAndCheck(gl, function () {
          return gl.deleteShader(_this.vertexShader);
        });
      }
      this.disposed = true;
    }
  }, {
    key: "createFloat32MatrixTexture",
    value: function createFloat32MatrixTexture$1(rows, columns) {
      this.throwIfDisposed();
      return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
  }, {
    key: "createFloat16MatrixTexture",
    value: function createFloat16MatrixTexture$1(rows, columns) {
      this.throwIfDisposed();
      return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
  }, {
    key: "createUnsignedBytesMatrixTexture",
    value: function createUnsignedBytesMatrixTexture$1(rows, columns) {
      this.throwIfDisposed();
      return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
  }, {
    key: "uploadPixelDataToTexture",
    value: function uploadPixelDataToTexture$1(texture, pixels) {
      this.throwIfDisposed();
      uploadPixelDataToTexture(this.gl, texture, pixels);
    }
  }, {
    key: "uploadDenseMatrixToTexture",
    value: function uploadDenseMatrixToTexture$1(texture, width, height, data) {
      this.throwIfDisposed();
      uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
    }
  }, {
    key: "createFloat16PackedMatrixTexture",
    value: function createFloat16PackedMatrixTexture$1(rows, columns) {
      this.throwIfDisposed();
      return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
  }, {
    key: "createPackedMatrixTexture",
    value: function createPackedMatrixTexture$1(rows, columns) {
      this.throwIfDisposed();
      return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
  }, {
    key: "deleteMatrixTexture",
    /**
     * Deletes `texture`; if it is the current output texture it is first
     * detached from the framebuffer and outputTexture is cleared.
     */
    value: function deleteMatrixTexture(texture) {
      var _this2 = this;
      this.throwIfDisposed();
      if (this.outputTexture === texture) {
        unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
        this.outputTexture = null;
      }
      callAndCheck(this.gl, function () {
        return _this2.gl.deleteTexture(texture);
      });
    }
  }, {
    key: "downloadByteEncodedFloatMatrixFromOutputTexture",
    value: function downloadByteEncodedFloatMatrixFromOutputTexture$1(texture, rows, columns) {
      var _this3 = this;
      return this.downloadMatrixDriver(texture, function () {
        return downloadByteEncodedFloatMatrixFromOutputTexture(_this3.gl, rows, columns, _this3.textureConfig);
      });
    }
  }, {
    key: "downloadPackedMatrixFromBuffer",
    value: function downloadPackedMatrixFromBuffer$1(buffer, batch, rows, columns, physicalRows, physicalCols) {
      return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
    }
  }, {
    key: "downloadFloat32MatrixFromBuffer",
    value: function downloadFloat32MatrixFromBuffer$1(buffer, size) {
      return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
    }
  }, {
    key: "createBufferFromTexture",
    /**
     * Temporarily binds `texture` to the framebuffer, copies it into a new
     * buffer, then restores the previous output-texture binding.
     */
    value: function createBufferFromTexture(texture, rows, columns) {
      this.bindTextureToFrameBuffer(texture);
      var result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
      this.unbindTextureToFrameBuffer();
      return result;
    }
  }, {
    key: "createAndWaitForFence",
    /** Returns a promise that resolves once the GPU work issued so far is done. */
    value: function createAndWaitForFence() {
      var fenceContext = this.createFence(this.gl);
      return this.pollFence(fenceContext);
    }
  }, {
    key: "createFence",
    /**
     * Creates a fence after the commands issued so far.
     *
     * Uses a WebGL2 sync object when WEBGL_FENCE_API_ENABLED, otherwise a
     * disjoint timer query if available, otherwise a fence that passes
     * immediately.
     *
     * @returns {{query: *, isFencePassed: function(): boolean}} On the sync
     *     path, `query` is the sync object. It is deleted as soon as
     *     isFencePassed first reports true; later calls keep returning true.
     */
    value: function createFence(gl) {
      var _this4 = this;
      var query;
      var isFencePassed;
      if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
        var gl2 = gl;
        var sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
        gl.flush();
        var signaled = false;
        isFencePassed = function isFencePassed() {
          if (signaled) {
            return true;
          }
          var status = gl2.clientWaitSync(sync, 0, 0);
          signaled = status === gl2.ALREADY_SIGNALED || status === gl2.CONDITION_SATISFIED;
          if (signaled) {
            // The sync object is not needed once it has signaled; delete it
            // explicitly rather than leaking one per fence.
            gl2.deleteSync(sync);
          }
          return signaled;
        };
        query = sync;
      } else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
        query = this.beginQuery();
        this.endQuery();
        isFencePassed = function isFencePassed() {
          return _this4.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
        };
      } else {
        // If we have no way to fence, return true immediately. This will fire in
        // WebGL 1.0 when there is no disjoint query timer. In this case, because
        // the fence passes immediately, we'll immediately ask for a download of
        // the texture, which will cause the UI thread to hang.
        isFencePassed = function isFencePassed() {
          return true;
        };
      }
      return {
        query: query,
        isFencePassed: isFencePassed
      };
    }
  }, {
    key: "downloadMatrixFromPackedTexture",
    value: function downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
      var _this5 = this;
      return this.downloadMatrixDriver(texture, function () {
        return downloadMatrixFromPackedOutputTexture(_this5.gl, physicalRows, physicalCols);
      });
    }
  }, {
    key: "createProgram",
    /**
     * Links `fragmentShader` with the shared (lazily created) vertex shader.
     * Attaches a fresh vertex array object as `vao` and validates the program
     * in debug mode.
     */
    value: function createProgram$1(fragmentShader) {
      var _this6 = this;
      this.throwIfDisposed();
      var gl = this.gl;
      if (this.vertexShader == null) {
        this.vertexShader = createVertexShader(gl);
      }
      var program = createProgram(gl);
      callAndCheck(gl, function () {
        return gl.attachShader(program, _this6.vertexShader);
      });
      callAndCheck(gl, function () {
        return gl.attachShader(program, fragmentShader);
      });
      linkProgram(gl, program);
      var program2 = Object.assign(program, {
        vao: this.createVertexArray()
      });
      if (this.debug) {
        validateProgram(gl, program2);
      }
      return program2;
    }
  }, {
    key: "buildVao",
    value: function buildVao(program) {
      var _this7 = this;
      this.setProgram(program);
      this.bindVertexArray(program.vao);
      var gl = this.gl;
      // Bind index buffer, and vertex buffers based on program attrib
      // locations.
      callAndCheck(gl, function () {
        return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, _this7.indexBuffer);
      });
      bindVertexProgramAttributeStreams(gl, program, this.vertexBuffer);
    }
  }, {
    key: "deleteProgram",
    /** Deletes `program` and its vertex array; clears it if currently set. */
    value: function deleteProgram(program) {
      var _this8 = this;
      this.throwIfDisposed();
      if (program === this.program) {
        this.program = null;
      }
      if (program != null) {
        callAndCheck(this.gl, function () {
          return _this8.gl.deleteProgram(program);
        });
        this.deleteVertexArray(program.vao);
      }
    }
  }, {
    key: "setProgram",
    value: function setProgram(program) {
      var _this9 = this;
      this.throwIfDisposed();
      this.program = program;
      if (this.program != null) {
        if (this.debug) {
          validateProgram(this.gl, this.program);
        }
      }
      callAndCheck(this.gl, function () {
        return _this9.gl.useProgram(program);
      });
    }
  }, {
    key: "getUniformLocation",
    value: function getUniformLocation(program, uniformName) {
      var shouldThrow = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;
      this.throwIfDisposed();
      if (shouldThrow) {
        return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
      } else {
        return getProgramUniformLocation(this.gl, program, uniformName);
      }
    }
  }, {
    key: "getAttributeLocation",
    value: function getAttributeLocation(program, attribute) {
      var _this10 = this;
      this.throwIfDisposed();
      return callAndCheck(this.gl, function () {
        return _this10.gl.getAttribLocation(program, attribute);
      });
    }
  }, {
    key: "getUniformLocationNoThrow",
    value: function getUniformLocationNoThrow(program, uniformName) {
      this.throwIfDisposed();
      return this.gl.getUniformLocation(program, uniformName);
    }
  }, {
    key: "setInputMatrixTexture",
    value: function setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
      this.throwIfDisposed();
      this.throwIfNoProgram();
      bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
    }
  }, {
    key: "setOutputMatrixTexture",
    value: function setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
      // Width is the column count and height the row count.
      this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
    }
  }, {
    key: "setOutputPackedMatrixTexture",
    value: function setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
      this.throwIfDisposed();
      var _tex_util$getPackedMa = getPackedMatrixTextureShapeWidthHeight(rows, columns),
        _tex_util$getPackedMa2 = _slicedToArray(_tex_util$getPackedMa, 2),
        width = _tex_util$getPackedMa2[0],
        height = _tex_util$getPackedMa2[1];
      this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
    }
  }, {
    key: "setOutputMatrixWriteRegion",
    value: function setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
      this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
    }
  }, {
    key: "setOutputPackedMatrixWriteRegion",
    value: function setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
      throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
    }
  }, {
    key: "debugValidate",
    value: function debugValidate() {
      if (this.program != null) {
        validateProgram(this.gl, this.program);
      }
      validateFramebuffer(this.gl);
    }
  }, {
    key: "executeProgram",
    /**
     * Draws the current program as 6 UNSIGNED_SHORT indices (two triangles).
     * In debug mode it also asserts that the bound VAO is the program's VAO.
     */
    value: function executeProgram() {
      this.throwIfDisposed();
      this.throwIfNoProgram();
      var gl = this.gl;
      if (this.debug) {
        var boundVao = this.getVertexArray();
        console.assert(boundVao === this.program.vao, 'VAO changed between setProgram and executeProgram!');
        this.debugValidate();
      }
      callAndCheck(gl, function () {
        return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0);
      });
    }
  }, {
    key: "blockUntilAllProgramsCompleted",
    value: function blockUntilAllProgramsCompleted() {
      var _this11 = this;
      this.throwIfDisposed();
      callAndCheck(this.gl, function () {
        return _this11.gl.finish();
      });
    }
  }, {
    key: "getQueryTimerExtension",
    /** Lazily fetches and caches the disjoint timer query extension. */
    value: function getQueryTimerExtension() {
      if (this.disjointQueryTimerExtension == null) {
        this.disjointQueryTimerExtension = getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 'EXT_disjoint_timer_query_webgl2' : 'EXT_disjoint_timer_query');
      }
      return this.disjointQueryTimerExtension;
    }
  }, {
    key: "getQueryTimerExtensionWebGL2",
    value: function getQueryTimerExtensionWebGL2() {
      return this.getQueryTimerExtension();
    }
  }, {
    key: "getQueryTimerExtensionWebGL1",
    value: function getQueryTimerExtensionWebGL1() {
      return this.getQueryTimerExtension();
    }
  }, {
    key: "beginQuery",
    value: function beginQuery() {
      if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
        var gl2 = this.gl;
        var _ext = this.getQueryTimerExtensionWebGL2();
        var _query = gl2.createQuery();
        gl2.beginQuery(_ext.TIME_ELAPSED_EXT, _query);
        return _query;
      }
      var ext = this.getQueryTimerExtensionWebGL1();
      var query = ext.createQueryEXT();
      ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
      return query;
    }
  }, {
    key: "endQuery",
    value: function endQuery() {
      if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
        var gl2 = this.gl;
        var _ext2 = this.getQueryTimerExtensionWebGL2();
        gl2.endQuery(_ext2.TIME_ELAPSED_EXT);
        return;
      }
      var ext = this.getQueryTimerExtensionWebGL1();
      ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
    }
  }, {
    key: "waitForQueryAndGetTime",
    value: function () {
      var _waitForQueryAndGetTime = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(query) {
        var _this12 = this;
        return _regeneratorRuntime().wrap(function _callee$(_context) {
          while (1) switch (_context.prev = _context.next) {
            case 0:
              _context.next = 2;
              return repeatedTry(function () {
                return _this12.disposed ||
                // while testing contexts are created / disposed
                // in rapid succession, so without this check we
                // may poll for the query timer indefinitely
                _this12.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
              });
            case 2:
              return _context.abrupt("return", this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
            case 3:
            case "end":
              return _context.stop();
          }
        }, _callee, this);
      }));
      function waitForQueryAndGetTime(_x) {
        return _waitForQueryAndGetTime.apply(this, arguments);
      }
      return waitForQueryAndGetTime;
    }()
  }, {
    key: "getQueryTime",
    /** Returns the query's elapsed time in milliseconds, or null for version 0. */
    value: function getQueryTime(query, queryTimerVersion) {
      if (queryTimerVersion === 0) {
        return null;
      }
      if (queryTimerVersion === 2) {
        var gl2 = this.gl;
        var timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
        // Return milliseconds.
        return timeElapsedNanos / 1000000;
      } else {
        var ext = this.getQueryTimerExtensionWebGL1();
        var _timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
        // Return milliseconds.
        return _timeElapsedNanos / 1000000;
      }
    }
  }, {
    key: "isQueryAvailable",
    /**
     * True when the query result is available and no GPU disjoint event was
     * reported. GPU_DISJOINT_EXT is read once and cached in this.disjoint.
     */
    value: function isQueryAvailable(query, queryTimerVersion) {
      if (queryTimerVersion === 0) {
        return true;
      }
      if (queryTimerVersion === 2) {
        var gl2 = this.gl;
        var ext = this.getQueryTimerExtensionWebGL2();
        var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
        if (this.disjoint == null) {
          this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
        }
        return available && !this.disjoint;
      } else {
        var _ext3 = this.getQueryTimerExtensionWebGL1();
        var _available = _ext3.getQueryObjectEXT(query, _ext3.QUERY_RESULT_AVAILABLE_EXT);
        if (this.disjoint == null) {
          this.disjoint = this.gl.getParameter(_ext3.GPU_DISJOINT_EXT);
        }
        return _available && !this.disjoint;
      }
    }
  }, {
    key: "pollFence",
    value: function pollFence(fenceContext) {
      var _this13 = this;
      return new Promise(function (resolve) {
        _this13.addItemToPoll(function () {
          return fenceContext.isFencePassed();
        }, function () {
          return resolve();
        });
      });
    }
  }, {
    key: "pollItems",
    /**
     * Resolves every queued item in the leading run of completed items, then
     * drops those items from the queue.
     */
    value: function pollItems() {
      // Find the last query that has finished.
      var index = linearSearchLastTrue(this.itemsToPoll.map(function (x) {
        return x.isDoneFn;
      }));
      for (var i = 0; i <= index; ++i) {
        var resolveFn = this.itemsToPoll[i].resolveFn;
        resolveFn();
      }
      this.itemsToPoll = this.itemsToPoll.slice(index + 1);
    }
  }, {
    key: "addItemToPoll",
    /**
     * Queues a completion check. A polling loop (repeatedTry) is started only
     * when the queue was empty; it uses platform.setTimeoutCustom as the
     * scheduler when the platform provides one.
     */
    value: function addItemToPoll(isDoneFn, resolveFn) {
      var _this14 = this;
      this.itemsToPoll.push({
        isDoneFn: isDoneFn,
        resolveFn: resolveFn
      });
      if (this.itemsToPoll.length > 1) {
        // We already have a running loop that polls.
        return;
      }
      // Start a new loop that polls.
      var scheduleFn = undefined;
      if ('setTimeoutCustom' in env().platform) {
        scheduleFn = env().platform.setTimeoutCustom.bind(env().platform);
      }
      repeatedTry(function () {
        _this14.pollItems();
        // End the loop if no more items to poll.
        return _this14.itemsToPoll.length === 0;
      }, function () {
        return 0;
      }, null, scheduleFn);
    }
  }, {
    key: "bindTextureToFrameBuffer",
    value: function bindTextureToFrameBuffer(texture) {
      this.throwIfDisposed();
      bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
      if (this.debug) {
        validateFramebuffer(this.gl);
      }
    }
  }, {
    key: "unbindTextureToFrameBuffer",
    /**
     * Restores the framebuffer attachment: re-binds the current output
     * texture if there is one, otherwise detaches the color texture.
     */
    value: function unbindTextureToFrameBuffer() {
      if (this.outputTexture != null) {
        bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
        if (this.debug) {
          validateFramebuffer(this.gl);
        }
      } else {
        unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
      }
    }
  }, {
    key: "downloadMatrixDriver",
    value: function downloadMatrixDriver(texture, downloadAndDecode) {
      this.bindTextureToFrameBuffer(texture);
      var result = downloadAndDecode();
      this.unbindTextureToFrameBuffer();
      return result;
    }
  }, {
    key: "setOutputMatrixTextureDriver",
    /**
     * Makes the texture the render target and sets viewport and scissor to
     * its full size.
     */
    value: function setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
      this.throwIfDisposed();
      var gl = this.gl;
      bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
      if (this.debug) {
        validateFramebuffer(gl);
      }
      this.outputTexture = outputMatrixTextureMaybePacked;
      callAndCheck(gl, function () {
        return gl.viewport(0, 0, width, height);
      });
      callAndCheck(gl, function () {
        return gl.scissor(0, 0, width, height);
      });
    }
  }, {
    key: "setOutputMatrixWriteRegionDriver",
    value: function setOutputMatrixWriteRegionDriver(x, y, width, height) {
      var _this15 = this;
      this.throwIfDisposed();
      callAndCheck(this.gl, function () {
        return _this15.gl.scissor(x, y, width, height);
      });
    }
  }, {
    key: "throwIfDisposed",
    value: function throwIfDisposed() {
      if (this.disposed) {
        throw new Error('Attempted to use disposed GPGPUContext.');
      }
    }
  }, {
    key: "throwIfNoProgram",
    value: function throwIfNoProgram() {
      if (this.program == null) {
        throw new Error('No GPU program is currently set.');
      }
    }
  }]);
  return GPGPUContext;
}();
/**
 * Returns the index of the last element in the leading run of predicates
 * that return truthy, or -1 if the first predicate is falsy or `arr` is empty.
 * Predicates are called in order and evaluation stops at the first falsy one.
 *
 * A linear scan is used instead of a binary search because Chrome expects
 * every fence to be tested explicitly before download:
 * https://github.com/tensorflow/tfjs/issues/1145
 *
 * @param {Array<function(): boolean>} arr Completion predicates.
 * @returns {number} Index of the last leading truthy predicate, or -1.
 */
function linearSearchLastTrue(arr) {
  var doneCount = 0;
  while (doneCount < arr.length && arr[doneCount]()) {
    doneCount++;
  }
  return doneCount - 1;
}
113755
113756 /**
113757 * @license
113758 * Copyright 2020 Google LLC. All Rights Reserved.
113759 * Licensed under the Apache License, Version 2.0 (the "License");
113760 * you may not use this file except in compliance with the License.
113761 * You may obtain a copy of the License at
113762 *
113763 * http://www.apache.org/licenses/LICENSE-2.0
113764 *
113765 * Unless required by applicable law or agreed to in writing, software
113766 * distributed under the License is distributed on an "AS IS" BASIS,
113767 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
113768 * See the License for the specific language governing permissions and
113769 * limitations under the License.
113770 * =============================================================================
113771 */
// `*ImplCPU` aliases for the shared kernel implementations (`*Impl`) that are
// defined elsewhere in this bundle.
var addImplCPU = addImpl;
var bincountImplCPU = bincountImpl;
var bincountReduceImplCPU = bincountReduceImpl;
var bitwiseAndImplCPU = bitwiseAndImpl;
var castImplCPU = castImpl;
var ceilImplCPU = ceilImpl;
var concatImplCPU = concatImpl$1;
var equalImplCPU = equalImpl;
var expImplCPU = expImpl;
var expm1ImplCPU = expm1Impl;
var floorImplCPU = floorImpl;
var gatherNdImplCPU = gatherNdImpl;
var gatherV2ImplCPU = gatherV2Impl;
var greaterImplCPU = greaterImpl;
var greaterEqualImplCPU = greaterEqualImpl;
var lessImplCPU = lessImpl;
var lessEqualImplCPU = lessEqualImpl;
var linSpaceImplCPU = linSpaceImpl;
var logImplCPU = logImpl;
var maxImplCPU = maxImpl$1;
var maximumImplCPU = maximumImpl;
var minimumImplCPU = minimumImpl;
var multiplyImplCPU = multiplyImpl;
var negImplCPU = negImpl;
var notEqualImplCPU = notEqualImpl;
var prodImplCPU = prodImpl;
var raggedGatherImplCPU = raggedGatherImpl;
var raggedRangeImplCPU = raggedRangeImpl;
var raggedTensorToTensorImplCPU = raggedTensorToTensorImpl;
var rangeImplCPU = rangeImpl;
var rsqrtImplCPU = rsqrtImpl;
var scatterImplCPU = scatterImpl;
var sigmoidImplCPU = sigmoidImpl;
var simpleAbsImplCPU = simpleAbsImpl;
var sliceImplCPU = sliceImpl;
var sparseFillEmptyRowsImplCPU = sparseFillEmptyRowsImpl;
var sparseReshapeImplCPU = sparseReshapeImpl;
var sparseSegmentReductionImplCPU = sparseSegmentReductionImpl;
var sqrtImplCPU = sqrtImpl;
var staticRegexReplaceImplCPU = staticRegexReplaceImpl;
var stridedSliceImplCPU = stridedSliceImpl;
var stringNGramsImplCPU = stringNGramsImpl;
var stringSplitImplCPU = stringSplitImpl;
var stringToHashBucketFastImplCPU = stringToHashBucketFastImpl;
var subImplCPU = subImpl;
var tileImplCPU = tileImpl;
var topKImplCPU = topKImpl;
var transposeImplCPU = transposeImpl$1;
var uniqueImplCPU = uniqueImpl;
113821
113822 /**
113823 * @license
113824 * Copyright 2018 Google LLC. All Rights Reserved.
113825 * Licensed under the Apache License, Version 2.0 (the "License");
113826 * you may not use this file except in compliance with the License.
113827 * You may obtain a copy of the License at
113828 *
113829 * http://www.apache.org/licenses/LICENSE-2.0
113830 *
113831 * Unless required by applicable law or agreed to in writing, software
113832 * distributed under the License is distributed on an "AS IS" BASIS,
113833 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
113834 * See the License for the specific language governing permissions and
113835 * limitations under the License.
113836 * =============================================================================
113837 */
/**
 * Builds GLSL swizzle accessors such as `name.x`, `name.y`, ... for the first
 * `rank` components of x, y, z, w, u, v.
 *
 * @param {string} name Variable name to prefix.
 * @param {number} rank Number of components (at most 6 are available).
 * @returns {string[]} Accessor expressions.
 */
function getVecChannels(name, rank) {
  var components = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
  var accessors = [];
  for (var i = 0; i < components.length; i++) {
    accessors.push(''.concat(name, '.', components[i]));
  }
  return accessors;
}
/**
 * Returns channel accessor names for a variable: the bare name for rank 1,
 * otherwise the swizzled components produced by getVecChannels.
 *
 * @param {string} name Variable name.
 * @param {number} rank Tensor rank.
 * @returns {string[]} Channel expressions.
 */
function getChannels(name, rank) {
  return rank === 1 ? [name] : getVecChannels(name, rank);
}
/**
 * Builds the comma-separated coordinate list used as GLSL call arguments.
 * Rank 1 always maps to the single coordinate `rc`.
 *
 * @param {number} rank Number of coordinates.
 * @param {string[]} dims Coordinate expressions; the first `rank` are used.
 * @returns {string} Joined coordinate expression.
 */
function getSourceCoords$2(rank, dims) {
  if (rank === 1) {
    return 'rc';
  }
  var parts = [];
  for (var i = 0; i < rank; i++) {
    // Same string conversion as `str += value`.
    parts.push('' + dims[i]);
  }
  return parts.join(',');
}
113862
var PackProgram = /*#__PURE__*/function () {
  /**
   * WebGL program that reads the unpacked texture 'A' and writes a packed
   * output. Each output vec4 gathers a 2x2 block of the two innermost
   * dimensions: (r, c), (r, c+1), (r+1, c), (r+1, c+1).
   *
   * @param {number[]} outputShape Logical shape of the tensor being packed.
   */
  function PackProgram(outputShape) {
    _classCallCheck(this, PackProgram);
    this.variableNames = ['A'];
    this.packedInputs = false;
    this.packedOutput = true;
    this.outputShape = outputShape;
    this.rank = outputShape.length;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
    if (this.rank === 0) {
      // A scalar fills the first channel; the other three are zero.
      this.userCode = "\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n ";
      return;
    }
    var channels = getChannels('rc', this.rank);
    var coordsType = getCoordsDataType(this.rank);
    var outOfBounds = this.getOutOfBoundsCondition(channels);
    var setup = this.getSetup(channels);
    var packedValue = this.getOutput(channels);
    this.userCode = "\n void main() {\n ".concat(coordsType, " rc = getOutputCoords();\n\n if(").concat(outOfBounds, ") {\n setOutput(vec4(0));\n } else {\n ").concat(setup, "\n\n setOutput(vec4(").concat(packedValue, "));\n }\n }\n ");
  }
  _createClass(PackProgram, [{
    key: "getSourceCoordsArr",
    /**
     * Returns the four getA() argument lists for the 2x2 block, in the order
     * (r, c), (r, cp1), (rp1, c), (rp1, cp1).
     */
    value: function getSourceCoordsArr(dims) {
      // The outer dimensions are identical for all four samples, so the
      // shared prefix is built once; outermost dimension comes first.
      var outerPrefix = '';
      for (var d = 2; d < this.rank; d++) {
        outerPrefix = "".concat(dims[dims.length - 1 - d], ",") + outerPrefix;
      }
      var corners = [['r', 'c'], ['r', 'cp1'], ['rp1', 'c'], ['rp1', 'cp1']];
      return corners.map(function (corner) {
        return outerPrefix + "".concat(corner[0], ", ").concat(corner[1]);
      });
    }
  }, {
    key: "getOutOfBoundsCondition",
    /**
     * GLSL condition that is true when the output coordinate lies outside
     * the output shape. Only the two innermost dimensions are checked.
     */
    value: function getOutOfBoundsCondition(dims) {
      if (this.rank === 1) {
        return "rc > ".concat(this.enableShapeUniforms ? 'outShape' : this.outputShape[0]);
      }
      var checks = [];
      for (var i = this.rank - 2; i < this.rank; i++) {
        var bound = this.enableShapeUniforms ? "outShape[".concat(i, "]") : this.outputShape[i];
        checks.push("".concat(dims[i], " >= ").concat(bound));
      }
      return checks.join('||');
    }
  }, {
    key: "getSetup",
    /**
     * GLSL declarations for the block's row/column indices (r, c, rp1, cp1)
     * and the cEdge/rEdge flags marking a neighbour past the last column/row.
     */
    value: function getSetup(dims) {
      if (this.rank === 1) {
        return '';
      }
      var inner = dims.slice(-2);
      var colBound = this.enableShapeUniforms ? "outShape[".concat(this.rank, " - 1]") : this.outputShape[this.rank - 1];
      var rowBound = this.enableShapeUniforms ? "outShape[".concat(this.rank, " - 2]") : this.outputShape[this.rank - 2];
      return "\n int r = ".concat(inner[0], ";\n int c = ").concat(inner[1], ";\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= ").concat(colBound, ";\n bool rEdge = rp1 >= ").concat(rowBound, ";\n ");
    }
  }, {
    key: "getOutput",
    /**
     * GLSL argument list of the output vec4. Samples that fall past an edge
     * are replaced by 0.
     */
    value: function getOutput(dims) {
      if (this.rank === 1) {
        var size = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
        return "getA(rc), (rc + 1 >= ".concat(size, " ? 0. : getA(rc + 1)), 0, 0");
      }
      var src = this.getSourceCoordsArr(dims);
      return "getA(".concat(src[0], "),\n cEdge ? 0. : getA(").concat(src[1], "),\n rEdge ? 0. : getA(").concat(src[2], "),\n rEdge || cEdge ? 0. : getA(").concat(src[3], ")");
    }
  }]);
  return PackProgram;
}();
113938
/**
 * Packed-to-packed reshape. For each of the four texels in an output 2x2
 * block, computes the flat index in the output and maps it back to (r, c, d)
 * coordinates of the 3D input, whose shape comes either from constants or
 * from the `inputShape` ivec3 uniform.
 *
 * @param {number[]} outputShape 3D output shape.
 * @param {number[]} inputShape 3D input shape.
 */
var ReshapePackedProgram = /*#__PURE__*/_createClass(function ReshapePackedProgram(outputShape, inputShape) {
  _classCallCheck(this, ReshapePackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  this.customUniforms = [{
    name: 'inputShape',
    type: 'ivec3'
  }];
  this.outputShape = outputShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  // GLSL that positions `thisRC` on each texel of the block, for result[0..3]:
  // (y, z), (y, z+1), (y+1, z), (y+1, z+1).
  var texelOffsets = ["thisRC = rc;", "thisRC = rc;thisRC.z += 1;", "thisRC = rc;thisRC.y += 1;", "thisRC = rc;thisRC.z += 1;thisRC.y += 1;"];
  var mainLoop = texelOffsets.map(function (offset, i) {
    // Texel 0 is the output coordinate itself and is not guarded; the other
    // three are bounds-checked against rows/cols.
    var guarded = i > 0;
    return "\n ".concat(offset, "\n ").concat(guarded ? "if(thisRC.y < rows && thisRC.z < cols){" : '', "\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[").concat(i, "] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n ").concat(guarded ? '}' : '', "\n ");
  }).join('');
  var reshapeSnippet = getReshapedInputCoords(inputShape, this.enableShapeUniforms);
  var flatIndexSnippet = this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape);
  var rows = this.enableShapeUniforms ? 'outShape[1]' : outputShape[1];
  var cols = this.enableShapeUniforms ? 'outShape[2]' : outputShape[2];
  this.userCode = "\n ".concat(reshapeSnippet, "\n ").concat(flatIndexSnippet, "\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = ").concat(rows, ";\n int cols = ").concat(cols, ";\n\n ").concat(mainLoop, "\n\n setOutput(result);\n }\n ");
});
/**
 * Emits the GLSL helper `ivec3 inputCoordsFromReshapedOutCoords(int index)`,
 * which converts a flat index into (r, c, d) coordinates of the 3D input.
 *
 * @param {number[]} shape Input shape, baked in as constants when shape
 *     uniforms are disabled.
 * @param {boolean} enableShapeUniforms When true, read the shape from the
 *     `inputShape` uniform instead.
 * @returns {string} GLSL source of the helper.
 */
function getReshapedInputCoords(shape, enableShapeUniforms) {
  var coordNames = ['r', 'c', 'd'];
  var body;
  if (enableShapeUniforms) {
    body = getLogicalCoordinatesFromFlatIndexByUniform(coordNames, 'inputShape');
  } else {
    body = getLogicalCoordinatesFromFlatIndex(coordNames, shape);
  }
  return "\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n ".concat(body, "\n return ivec3(r, c, d);\n }\n ");
}
113967
var TextureManager = /*#__PURE__*/function () {
  /**
   * Pools WebGL textures by key (texture shape, physical texture type and
   * packing) so released textures can be handed out again instead of being
   * reallocated.
   *
   * @param gpgpu GPGPU context used to create and delete the textures.
   */
  function TextureManager(gpgpu) {
    _classCallCheck(this, TextureManager);
    this.gpgpu = gpgpu;
    this.numUsedTextures = 0;
    this.numFreeTextures = 0;
    this._numBytesAllocated = 0;
    // Number of bytes that have been allocated and available for reuse.
    this._numBytesFree = 0;
    // shapeKey -> textures available for reuse / currently handed out.
    this.freeTextures = {};
    this.usedTextures = {};
    this.logEnabled = false;
  }
  _createClass(TextureManager, [{
    key: "acquireTexture",
    /**
     * Returns a texture for `shapeRC`. A pooled texture with the same key is
     * reused when available; otherwise a new one is created.
     *
     * @param shapeRC [rows, cols] texture shape.
     * @param usage Logical texture usage.
     * @param isPacked Whether the texture holds packed data.
     */
    value: function acquireTexture(shapeRC, usage, isPacked) {
      var physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
      var shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
      if (!(shapeKey in this.freeTextures)) {
        this.freeTextures[shapeKey] = [];
      }
      if (!(shapeKey in this.usedTextures)) {
        this.usedTextures[shapeKey] = [];
      }
      var texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
      if (this.freeTextures[shapeKey].length > 0) {
        this.numFreeTextures--;
        this.numUsedTextures++;
        this._numBytesFree -= texBytes;
        this.log();
        var reused = this.freeTextures[shapeKey].pop();
        this.usedTextures[shapeKey].push(reused);
        return reused;
      }
      var newTexture;
      if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
        newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
      } else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
        newTexture = this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
      } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
        newTexture = this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
      } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
        newTexture = this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
      } else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
        newTexture = this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
      }
      this.usedTextures[shapeKey].push(newTexture);
      this.numUsedTextures++;
      this._numBytesAllocated += texBytes;
      this.log();
      return newTexture;
    }
  }, {
    key: "releaseTexture",
    /**
     * Returns a texture obtained from acquireTexture. The texture goes back
     * into the pool, unless WEBGL_DELETE_TEXTURE_THRESHOLD is set (not -1)
     * and the allocated bytes exceed it; in that case it is deleted. Does
     * nothing once the manager has been disposed.
     *
     * @throws {Error} if the texture is not currently handed out by this
     *     manager under the given shape/type. The check runs before any
     *     bookkeeping changes, so a bad release leaves the pool consistent.
     */
    value: function releaseTexture(texture, shape, logicalTexType, isPacked) {
      if (this.freeTextures == null) {
        // Already disposed.
        return;
      }
      var physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
      var shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
      // Validate ownership first. Previously the texture was pooled (or
      // deleted) and the counters were updated before this check, so an
      // invalid release corrupted state before throwing.
      var texList = this.usedTextures[shapeKey];
      var texIndex = texList == null ? -1 : texList.indexOf(texture);
      if (texIndex < 0) {
        throw new Error('Cannot release a texture that was never provided by this ' + 'texture manager');
      }
      if (!(shapeKey in this.freeTextures)) {
        this.freeTextures[shapeKey] = [];
      }
      var texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
      var deleteTexThreshold = env().getNumber('WEBGL_DELETE_TEXTURE_THRESHOLD');
      if (deleteTexThreshold !== -1 && this._numBytesAllocated > deleteTexThreshold) {
        this.gpgpu.deleteMatrixTexture(texture.texture);
        this._numBytesAllocated -= texBytes;
      } else {
        this.freeTextures[shapeKey].push(texture);
        this.numFreeTextures++;
        this._numBytesFree += texBytes;
      }
      this.numUsedTextures--;
      // Swap-remove: order within the used list does not matter.
      texList[texIndex] = texList[texList.length - 1];
      texList.pop();
      this.log();
    }
  }, {
    key: "log",
    /** Prints pool statistics to the console when `logEnabled` is set. */
    value: function log() {
      if (!this.logEnabled) {
        return;
      }
      var total = this.numFreeTextures + this.numUsedTextures;
      console.log('Free/Used', "".concat(this.numFreeTextures, " / ").concat(this.numUsedTextures), "(".concat(total, ")"));
      // Avoid printing "NaN%" before anything has been allocated.
      var freeRatio = this._numBytesAllocated > 0 ? this._numBytesFree / this._numBytesAllocated : 0;
      console.log("Bytes allocated: ".concat(this._numBytesAllocated));
      console.log("Bytes unused: ".concat(this._numBytesFree, " (").concat(Math.round(100 * freeRatio), "%)"));
    }
  }, {
    key: "numBytesAllocated",
    get: function get() {
      return this._numBytesAllocated;
    }
  }, {
    key: "numBytesFree",
    get: function get() {
      return this._numBytesFree;
    }
  }, {
    key: "getNumUsedTextures",
    value: function getNumUsedTextures() {
      return this.numUsedTextures;
    }
  }, {
    key: "getNumFreeTextures",
    value: function getNumFreeTextures() {
      return this.numFreeTextures;
    }
  }, {
    key: "dispose",
    /**
     * Deletes every pooled and handed-out texture and resets all counters.
     * The pools are set to null, which makes later releaseTexture calls and
     * repeated dispose calls no-ops.
     */
    value: function dispose() {
      var _this = this;
      if (this.freeTextures == null) {
        // Already disposed.
        return;
      }
      for (var texShape in this.freeTextures) {
        this.freeTextures[texShape].forEach(function (tex) {
          _this.gpgpu.deleteMatrixTexture(tex.texture);
        });
      }
      for (var _texShape in this.usedTextures) {
        this.usedTextures[_texShape].forEach(function (tex) {
          _this.gpgpu.deleteMatrixTexture(tex.texture);
        });
      }
      // TODO: Assign non-null value (empty object) to textures after disposed.
      this.freeTextures = null;
      this.usedTextures = null;
      this.numUsedTextures = 0;
      this.numFreeTextures = 0;
      this._numBytesAllocated = 0;
      this._numBytesFree = 0;
    }
  }]);
  return TextureManager;
}();
/**
 * Returns the number of bytes per texel for a WebGL internal format.
 * Unsized gl.RGBA is counted as 16 bytes, the same as RGBA32F.
 *
 * @param gl WebGL rendering context (only its format constants are read).
 * @param internalFormat Internal format enum value.
 * @returns {number} Bytes per texel.
 * @throws {Error} if the format is not one of the known formats.
 */
function numBytesForInternalFormat(gl, internalFormat) {
  // tslint:disable-next-line:no-any
  var glany = gl;
  // Checked in order; the first matching entry wins.
  var bytesByFormat = [[glany.R32F, 4], [glany.R16F, 2], [glany.RGBA32F, 16], [gl.RGBA, 16], [glany.RGBA16F, 8], [glany.RGBA8, 4]];
  for (var i = 0; i < bytesByFormat.length; i++) {
    if (internalFormat === bytesByFormat[i][0]) {
      return bytesByFormat[i][1];
    }
  }
  throw new Error("Unknown internal format ".concat(internalFormat));
}
/**
 * Returns the GPU memory, in bytes, used by a texture holding a matrix of
 * `shape` ([rows, cols]) with the given physical texture type.
 *
 * `isPacked` is passed explicitly because packed status cannot be inferred
 * from the texture type. Depending on the textureConfig, different texture
 * types may resolve to the same internal format (e.g. in WebGL1, the internal
 * format for UNPACKED_FLOAT16 textures is gl.RGBA).
 *
 * @returns {number} Texel count of the physical texture times bytes per
 *     texel of its internal format.
 */
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
  var internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
  // Both helpers return a [width, height] pair for the physical texture.
  var texDims = isPacked ? getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]) : getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
  var texelCount = texDims[0] * texDims[1];
  return texelCount * numBytesForInternalFormat(gl, internalFormat);
}
/**
 * Maps a physical texture type to its WebGL internal format under the given
 * texture configuration.
 *
 * @throws {Error} for an unrecognized physical texture type.
 */
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
  var types = PhysicalTextureType;
  if (physicalTexType === types.PACKED_2X2_FLOAT32) {
    return getInternalFormatForPackedMatrixTexture(textureConfig);
  }
  if (physicalTexType === types.PACKED_2X2_FLOAT16) {
    return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
  }
  if (physicalTexType === types.UNPACKED_FLOAT32) {
    return getInternalFormatForFloat32MatrixTexture(textureConfig);
  }
  if (physicalTexType === types.UNPACKED_FLOAT16) {
    return getInternalFormatForFloat16MatrixTexture(textureConfig);
  }
  if (physicalTexType === types.PACKED_4X1_UNSIGNED_BYTE) {
    return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
  }
  throw new Error("Unknown physical texture type ".concat(physicalTexType));
}
/**
 * Chooses the physical texture type for render targets: float32 when the
 * WEBGL_RENDER_FLOAT32_ENABLED flag is set, float16 otherwise, in packed
 * (2x2) or unpacked layout.
 *
 * @param {boolean} isPacked Whether the texture holds packed data.
 */
function getPhysicalTextureForRendering(isPacked) {
  var renderFloat32 = env().getBool('WEBGL_RENDER_FLOAT32_ENABLED');
  if (renderFloat32) {
    return isPacked ? PhysicalTextureType.PACKED_2X2_FLOAT32 : PhysicalTextureType.UNPACKED_FLOAT32;
  }
  return isPacked ? PhysicalTextureType.PACKED_2X2_FLOAT16 : PhysicalTextureType.UNPACKED_FLOAT16;
}
/**
 * Maps a logical texture usage to the physical texture type backing it:
 * - UPLOAD: packed float32.
 * - RENDER (or no usage given): the rendering type, via
 *   getPhysicalTextureForRendering.
 * - DOWNLOAD and PIXELS: 4x1 unsigned bytes.
 *
 * @throws {Error} for an unrecognized usage.
 */
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
  switch (logicalTexType) {
    case TextureUsage.UPLOAD:
      return PhysicalTextureType.PACKED_2X2_FLOAT32;
    case TextureUsage.RENDER:
    case null:
    case undefined:
      return getPhysicalTextureForRendering(isPacked);
    case TextureUsage.DOWNLOAD:
    case TextureUsage.PIXELS:
      return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
    default:
      throw new Error("Unknown logical texture type ".concat(logicalTexType));
  }
}
/**
 * Builds the texture-pool key "<rows>_<cols>_<physicalTexType>_<isPacked>".
 */
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
  return [shapeRowsCol[0], shapeRowsCol[1], physicalTexType, isPacked].map(String).join('_');
}
114196
/**
 * Element-wise program over the unpacked input 'A'. `opSnippet` becomes the
 * body of `float unaryOperation(float x)`, which is applied to the input
 * value at each output coordinate.
 *
 * @param {number[]} aShape Shape of the input (and output).
 * @param {string} opSnippet GLSL function body returning a float.
 */
var UnaryOpProgram = /*#__PURE__*/_createClass(function UnaryOpProgram(aShape, opSnippet) {
  _classCallCheck(this, UnaryOpProgram);
  this.variableNames = ['A'];
  this.outputShape = aShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var header = "\n float unaryOperation(float x) {\n ";
  var footer = "\n }\n\n void main() {\n float x = getAAtOutCoords();\n float y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
  this.userCode = header.concat(opSnippet, footer);
});
// GLSL function bodies for element-wise unary ops. The strings below are
// emitted verbatim into shaders, so their exact text matters.
//
// Scalar snippets: they operate on `float x`. Presumably each is used as the
// body of UnaryOpProgram's `float unaryOperation(float x)`; readSync/read
// use CLONE that way.
// Returns NaN inputs unchanged; prepended to snippets that must propagate NaN.
114204 var CHECK_NAN_SNIPPET$1 = "if (isnan(x)) return x;";
114205 var LINEAR$1 = "return x;";
114206 var ABS$1 = "return abs(x);";
// Step function: 1.0 for x > 0.0, otherwise `alpha` (default 0.0).
// NaN is passed through.
114207 function STEP() {
114208 var alpha = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 0.0;
114209 return CHECK_NAN_SNIPPET$1 + "\n return x > 0.0 ? 1.0 : float(".concat(alpha, ");\n ");
114210 }
114211 var ELU$2 = "return (x >= 0.0) ? x : (exp(x) - 1.0);";
114212 var RELU$2 = CHECK_NAN_SNIPPET$1 + "\n return (x < 0.0) ? 0.0 : x;\n";
114213 var RELU6$2 = CHECK_NAN_SNIPPET$1 + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n";
// Identity; used to copy a texture (e.g. to materialize a shallow slice).
114214 var CLONE = 'return x;';
114215 var SIGMOID$2 = "return 1.0 / (1.0 + exp(-1.0 * x));";
114216
// Packed snippets: they operate on `vec4 x`, one lane per packed value.
// Presumably used as the body of UnaryOpPackedProgram's
// `vec4 unaryOperation(vec4 x)`.
114217 var LINEAR = "return x;";
114218 var ELU$1 = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n";
// ReLU via a mask multiply; NaN lanes are explicitly copied back from x.
114219 var RELU$1 = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
// Same as RELU$1 but clamped to 6.0 from above.
114220 var RELU6$1 = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
114221 var SIGMOID$1 = "return 1.0 / (1.0 + exp(-1.0 * x));";
/**
 * Packed counterpart of UnaryOpProgram. Input and output are both packed,
 * and `opSnippet` becomes the body of `vec4 unaryOperation(vec4 x)`, which
 * processes one packed texel at a time.
 *
 * @param {number[]} aShape Shape of the input (and output).
 * @param {string} opSnippet GLSL function body returning a vec4.
 */
var UnaryOpPackedProgram = /*#__PURE__*/_createClass(function UnaryOpPackedProgram(aShape, opSnippet) {
  _classCallCheck(this, UnaryOpPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  this.outputShape = aShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var header = "\n vec4 unaryOperation(vec4 x) {\n ";
  var footer = "\n }\n\n void main() {\n vec4 x = getAAtOutCoords();\n vec4 y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
  this.userCode = header.concat(opSnippet, footer);
});
114231
/**
 * Converts a packed input 'A' into an unpacked output. Each output element
 * samples the packed texel covering its coordinate and selects the channel
 * for its position, using the two innermost coordinates (or `rc` itself for
 * rank <= 1).
 *
 * @param {number[]} outputShape Shape of the unpacked output.
 */
var UnpackProgram = /*#__PURE__*/_createClass(function UnpackProgram(outputShape) {
  _classCallCheck(this, UnpackProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = false;
  this.outputShape = outputShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var rank = outputShape.length;
  var channels = getChannels('rc', rank);
  var coordsType = getCoordsDataType(rank);
  var sampleArgs = getSourceCoords$2(rank, channels);
  var channelSelector;
  if (rank <= 1) {
    channelSelector = 'rc';
  } else {
    channelSelector = "vec2(".concat(channels.slice(-2).join(','), ")");
  }
  this.userCode = "\n void main() {\n ".concat(coordsType, " rc = getOutputCoords();\n vec4 packedInput = getA(").concat(sampleArgs, ");\n\n setOutput(getChannel(packedInput, ").concat(channelSelector, "));\n }\n ");
});
114247
var whereImpl = whereImpl$2;
var EPSILON_FLOAT32 = 1e-7;
var EPSILON_FLOAT16 = 1e-4;
// One cache object per WebGL version. The backend constructor uses these
// only when it creates its own WebGL context; presumably they hold compiled
// programs.
var binaryCaches = {};
/**
 * Returns the shared cache object for `webGLVersion`, creating an empty one
 * on first use.
 */
function getBinaryCache(webGLVersion) {
  if (!(webGLVersion in binaryCaches)) {
    binaryCaches[webGLVersion] = {};
  }
  return binaryCaches[webGLVersion];
}
// Empirically determined constant used to determine size threshold for handing
// off execution to the CPU.
var CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
// Empirically determined constant used to decide the number of MB on GPU
// before we warn about high memory use. The MB are this constant * screen area
// * dpi / 1024 / 1024.
var BEFORE_PAGING_CONSTANT = 600;
/**
 * Returns the amount of GPU memory, in MB, above which the backend warns
 * about high memory use. When no screen is available the value is 1024
 * (1 GB).
 *
 * The device pixel ratio is read from the same global object as `screen`
 * (`env().global`), not from `window`. Environments that expose `screen`
 * without a `window` binding therefore no longer throw a ReferenceError.
 * A missing ratio is treated as 1.
 *
 * @returns {number} Megabytes before warning.
 */
function numMBBeforeWarning() {
  var globalObj = env().global;
  if (globalObj.screen == null) {
    return 1024; // 1 GB.
  }
  var pixelRatio = globalObj.devicePixelRatio != null ? globalObj.devicePixelRatio : 1;
  return globalObj.screen.height * globalObj.screen.width * pixelRatio * BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
114273 var MathBackendWebGL = /*#__PURE__*/function (_KernelBackend) {
114274 _inherits(MathBackendWebGL, _KernelBackend);
114275 var _super = _createSuper(MathBackendWebGL);
/**
 * Creates the WebGL backend.
 *
 * @param gpuResource Optional. An existing GPGPUContext is used as-is. Any
 *     other non-null value is passed to getWebGLContext to obtain the
 *     context. When omitted, a new context is created and the per-version
 *     shared binary cache is used.
 * @throws {Error} if HAS_WEBGL is false.
 */
function MathBackendWebGL(gpuResource) {
  var backend;
  _classCallCheck(this, MathBackendWebGL);
  backend = _super.call(this);
  // Maps data ids that have a pending read operation, to list of subscribers.
  backend.pendingRead = new WeakMap();
  // Data ids scheduled for disposal that are waiting on a pending read.
  backend.pendingDisposal = new WeakSet();
  // Counts the 'shallow' sliced tensors that point to the same data id.
  backend.dataRefCount = new WeakMap();
  backend.numBytesInGPU = 0;
  // Accumulated time spent (including blocking) uploading data to webgl.
  backend.uploadWaitMs = 0;
  // Accumulated time spent (including blocking) downloading data from webgl.
  backend.downloadWaitMs = 0;
  // Time of the last manual GL flush.
  backend.lastGlFlushTime = 0;
  backend.warnedAboutMemory = false;
  backend.pendingDeletes = 0;
  backend.disposed = false;
  if (!env().getBool('HAS_WEBGL')) {
    throw new Error('WebGL is not supported on this device');
  }
  var ownsContext = gpuResource == null;
  var gpgpu;
  if (ownsContext) {
    gpgpu = new GPGPUContext(getWebGLContext(env().getNumber('WEBGL_VERSION')));
    backend.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
  } else {
    gpgpu = gpuResource instanceof GPGPUContext ? gpuResource : new GPGPUContext(getWebGLContext(env().getNumber('WEBGL_VERSION'), gpuResource));
    // Caller-provided contexts get a private cache.
    backend.binaryCache = {};
  }
  backend.gpgpuCreatedLocally = ownsContext;
  backend.gpgpu = gpgpu;
  backend.canvas = backend.gpgpu.gl.canvas;
  backend.textureManager = new TextureManager(backend.gpgpu);
  backend.numMBBeforeWarning = numMBBeforeWarning();
  backend.texData = new DataStorage(_assertThisInitialized(backend), engine());
  return backend;
}
114324 _createClass(MathBackendWebGL, [{
114325 key: "nextDataId",
114326 value: function nextDataId() {
114327 return MathBackendWebGL.nextDataId++;
114328 }
114329 }, {
114330 key: "numDataIds",
114331 value: function numDataIds() {
114332 return this.texData.numDataIds() - this.pendingDeletes;
114333 }
114334 // Writes a new entry to the data store with a WebGL texture, and registers it
114335 // to the texture manager.
114336 }, {
114337 key: "writeTexture",
114338 value: function writeTexture(texture, shape, dtype, texHeight, texWidth, channels) {
114339 // Temporarily create an tensor info to make the texture compatible with
114340 // the runWebGLProgram's input.
114341 var input = this.makeTensorInfo(shape, dtype);
114342 var inData = this.texData.get(input.dataId);
114343 // Even though the input texture could be unpacked or dense packed, it is
114344 // always considered as unpacked for EncodeMatrixProgram.
114345 inData.isPacked = false;
114346 // Bind texture to the input tensor.
114347 inData.texture = {
114348 texture: texture,
114349 texShape: [texHeight, texWidth]
114350 };
114351 inData.texShape = [texHeight, texWidth];
114352 var shapeAs3D = getShapeAs3D(shape);
114353 var program = new EncodeMatrixProgram(shapeAs3D, false /* isByteArray */, channels);
114354 var output = this.runWebGLProgram(program, [input], dtype, [[texHeight, texWidth]]);
114355 output.shape = shape;
114356 // Unbind the texture from the input tensor to avoid the texture being
114357 // released.
114358 inData.texture = null;
114359 this.disposeIntermediateTensorInfo(input);
114360 return output.dataId;
114361 }
114362 }, {
114363 key: "write",
114364 value: function write(values, shape, dtype) {
114365 if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') || env().getBool('DEBUG')) {
114366 this.checkNumericalProblems(values);
114367 }
114368 if (dtype === 'complex64' && values != null) {
114369 throw new Error("Cannot write to a complex64 dtype. " + "Please use tf.complex(real, imag).");
114370 }
114371 var dataId = {
114372 id: this.nextDataId()
114373 };
114374 this.texData.set(dataId, {
114375 shape: shape,
114376 dtype: dtype,
114377 values: values,
114378 usage: TextureUsage.UPLOAD,
114379 refCount: 1
114380 });
114381 return dataId;
114382 }
114383 /** Return refCount of a `TensorData`. */
114384 }, {
114385 key: "refCount",
114386 value: function refCount(dataId) {
114387 if (this.texData.has(dataId)) {
114388 var tensorData = this.texData.get(dataId);
114389 return tensorData.refCount;
114390 }
114391 return 0;
114392 }
114393 /** Increase refCount of a `TextureData`. */
114394 }, {
114395 key: "incRef",
114396 value: function incRef(dataId) {
114397 var texData = this.texData.get(dataId);
114398 texData.refCount++;
114399 }
114400 /** Decrease refCount of a `TextureData`. */
114401 }, {
114402 key: "decRef",
114403 value: function decRef(dataId) {
114404 if (this.texData.has(dataId)) {
114405 var texData = this.texData.get(dataId);
114406 texData.refCount--;
114407 }
114408 }
114409 }, {
114410 key: "move",
114411 value: function move(dataId, values, shape, dtype, refCount) {
114412 if (env().getBool('DEBUG')) {
114413 this.checkNumericalProblems(values);
114414 }
114415 if (dtype === 'complex64') {
114416 throw new Error("Cannot write to a complex64 dtype. " + "Please use tf.complex(real, imag).");
114417 }
114418 this.texData.set(dataId, {
114419 shape: shape,
114420 dtype: dtype,
114421 values: values,
114422 usage: TextureUsage.UPLOAD,
114423 refCount: refCount
114424 });
114425 }
114426 }, {
114427 key: "disposeIntermediateTensorInfo",
114428 value: function disposeIntermediateTensorInfo(tensorInfo) {
114429 this.disposeData(tensorInfo.dataId);
114430 }
114431 }, {
114432 key: "readSync",
114433 value: function readSync(dataId) {
114434 var texData = this.texData.get(dataId);
114435 var values = texData.values,
114436 dtype = texData.dtype,
114437 complexTensorInfos = texData.complexTensorInfos,
114438 slice = texData.slice,
114439 shape = texData.shape,
114440 isPacked = texData.isPacked;
114441 // The presence of `slice` indicates this tensor is a shallow slice of a
114442 // different tensor, and is using that original tensor's texture. Run
114443 // `clone` in order to copy that texture and read from it.
114444 if (slice != null) {
114445 var program;
114446 if (isPacked) {
114447 program = new UnaryOpPackedProgram(shape, CLONE);
114448 } else {
114449 program = new UnaryOpProgram(shape, CLONE);
114450 }
114451 var res = this.runWebGLProgram(program, [{
114452 dataId: dataId,
114453 shape: shape,
114454 dtype: dtype
114455 }], dtype);
114456 var data = this.readSync(res.dataId);
114457 this.disposeIntermediateTensorInfo(res);
114458 return data;
114459 }
114460 if (values != null) {
114461 return this.convertAndCacheOnCPU(dataId);
114462 }
114463 if (dtype === 'string') {
114464 return values;
114465 }
114466 var shouldTimeProgram = this.activeTimers != null;
114467 var start;
114468 if (shouldTimeProgram) {
114469 start = now();
114470 }
114471 var result;
114472 if (dtype === 'complex64') {
114473 var realValues = this.readSync(complexTensorInfos.real.dataId);
114474 var imagValues = this.readSync(complexTensorInfos.imag.dataId);
114475 result = mergeRealAndImagArrays(realValues, imagValues);
114476 } else {
114477 result = this.getValuesFromTexture(dataId);
114478 }
114479 if (shouldTimeProgram) {
114480 this.downloadWaitMs += now() - start;
114481 }
114482 return this.convertAndCacheOnCPU(dataId, result);
114483 }
114484 }, {
/**
 * Asynchronously downloads the values of `dataId` to the CPU.
 *
 * This is a Babel/regenerator-compiled async method; each `case` label below
 * is a resume point of the state machine.
 * - Concurrent reads of the same id share one download. While a read is in
 *   flight (tracked in `pendingRead`), later callers get a promise that
 *   resolves with the same values.
 * - A shallow slice is first cloned into its own texture, and the clone is
 *   read.
 * - Values already on the CPU are returned via convertAndCacheOnCPU.
 * - Otherwise:
 *   - the texture is optionally copied into a GPU buffer;
 *   - a fence is awaited for non-complex dtypes;
 *   - the values are downloaded;
 *   - complex64 instead reads its real and imaginary parts and merges them.
 * - If disposal was requested while the read was pending
 *   (`pendingDisposal`), it is carried out after the subscribers are
 *   notified.
 *
 * NOTE(review): nothing here catches errors. If anything after
 * `pendingRead.set(dataId, [])` throws or rejects (the fence, the download,
 * or either complex part), the pendingRead entry is never removed. Later
 * reads of this id would then wait forever, and the temporary buffer and
 * decoded texture are not released. Confirm whether callers guard against
 * this.
 */
114485 key: "read",
114486 value: function () {
114487 var _read = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee(dataId) {
114488 var _subscribers, texData, values, shape, slice, dtype, complexTensorInfos, isPacked, program, res, data, buffer, tmpDownloadTarget, _this$gpgpu, tmpData, vals, ps, realValues, imagValues, size, gl, dTypeVals, subscribers;
114489 return _regeneratorRuntime().wrap(function _callee$(_context) {
114490 while (1) switch (_context.prev = _context.next) {
114491 case 0:
// Join an in-flight read of the same id instead of starting another one.
114492 if (!this.pendingRead.has(dataId)) {
114493 _context.next = 3;
114494 break;
114495 }
114496 _subscribers = this.pendingRead.get(dataId);
114497 return _context.abrupt("return", new Promise(function (resolve) {
114498 return _subscribers.push(resolve);
114499 }));
114500 case 3:
114501 texData = this.texData.get(dataId);
114502 values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, complexTensorInfos = texData.complexTensorInfos, isPacked = texData.isPacked; // The presence of `slice` indicates this tensor is a shallow slice of a
114503 // different tensor, and is using that original tensor's texture. Run
114504 // `clone` in order to copy that texture and read from it.
114505 if (!(slice != null)) {
114506 _context.next = 11;
114507 break;
114508 }
114509 if (isPacked) {
114510 program = new UnaryOpPackedProgram(shape, CLONE);
114511 } else {
114512 program = new UnaryOpProgram(shape, CLONE);
114513 }
114514 res = this.runWebGLProgram(program, [{
114515 dataId: dataId,
114516 shape: shape,
114517 dtype: dtype
114518 }], dtype);
// `res` is disposed while its read is still in flight. The nested read()
// registers res.dataId in pendingRead before its first await, so disposal
// is presumably deferred through pendingDisposal (handled at the end of
// this method) — confirm against disposeData.
114519 data = this.read(res.dataId);
114520 this.disposeIntermediateTensorInfo(res);
114521 return _context.abrupt("return", data);
114522 case 11:
// Values already on the CPU: no GPU download needed.
114523 if (!(values != null)) {
114524 _context.next = 13;
114525 break;
114526 }
114527 return _context.abrupt("return", this.convertAndCacheOnCPU(dataId));
114528 case 13:
// DEBUG only: reject WebGL2 setups with WEBGL_DOWNLOAD_FLOAT_ENABLED=false.
114529 if (!env().getBool('DEBUG')) {
114530 _context.next = 16;
114531 break;
114532 }
114533 if (!(!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && env().getNumber('WEBGL_VERSION') === 2)) {
114534 _context.next = 16;
114535 break;
114536 }
114537 throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " + "WEBGL_VERSION=2 not yet supported.");
114538 case 16:
114539 buffer = null;
114540 if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
114541 // Possibly copy the texture into a buffer before inserting a fence.
114542 tmpDownloadTarget = this.decode(dataId);
114543 tmpData = this.texData.get(tmpDownloadTarget.dataId);
114544 buffer = (_this$gpgpu = this.gpgpu).createBufferFromTexture.apply(_this$gpgpu, [tmpData.texture.texture].concat(_toConsumableArray(getDenseTexShape(shape))));
114545 }
// From here on, concurrent read() calls for this id subscribe instead of
// starting their own download.
114546 this.pendingRead.set(dataId, []);
// Non-complex dtypes await a GPU fence before downloading. complex64 skips
// it; its parts are read through recursive read() calls.
114547 if (!(dtype !== 'complex64')) {
114548 _context.next = 22;
114549 break;
114550 }
114551 _context.next = 22;
114552 return this.gpgpu.createAndWaitForFence();
114553 case 22:
114554 if (!(dtype === 'complex64')) {
114555 _context.next = 31;
114556 break;
114557 }
114558 _context.next = 25;
114559 return Promise.all([this.read(complexTensorInfos.real.dataId), this.read(complexTensorInfos.imag.dataId)]);
114560 case 25:
114561 ps = _context.sent;
114562 realValues = ps[0];
114563 imagValues = ps[1];
114564 vals = mergeRealAndImagArrays(realValues, imagValues);
114565 _context.next = 32;
114566 break;
114567 case 31:
// Download from the buffer when one was created, else from the texture.
114568 if (buffer == null) {
114569 vals = this.getValuesFromTexture(dataId);
114570 } else {
114571 size = sizeFromShape(shape);
114572 vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
114573 }
114574 case 32:
// Release the temporary decoded tensor and the GL buffer, if any.
114575 if (tmpDownloadTarget != null) {
114576 this.disposeIntermediateTensorInfo(tmpDownloadTarget);
114577 }
114578 if (buffer != null) {
114579 gl = this.gpgpu.gl;
114580 callAndCheck(gl, function () {
114581 return gl.deleteBuffer(buffer);
114582 });
114583 }
114584 dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
114585 subscribers = this.pendingRead.get(dataId);
114586 this.pendingRead.delete(dataId);
114587 // Notify all pending reads.
114588 subscribers.forEach(function (resolve) {
114589 return resolve(dTypeVals);
114590 });
// Complete a disposal that was requested while this read was pending.
114591 if (this.pendingDisposal.has(dataId)) {
114592 this.pendingDisposal.delete(dataId);
114593 if (this.disposeData(dataId)) {
114594 engine().removeDataId(dataId, this);
114595 }
114596 this.pendingDeletes--;
114597 }
114598 return _context.abrupt("return", dTypeVals);
// Appears unreachable: no transition sets _context.next = 40.
114599 case 40:
114600 case "end":
114601 return _context.stop();
114602 }
114603 }, _callee, this);
114604 }));
114605 function read(_x) {
114606 return _read.apply(this, arguments);
114607 }
114608 return read;
114609 }()
114610 /**
114611 * Read tensor to a new texture that is densely packed for ease of use.
114612 * @param dataId The source tensor.
114613 * @param options
114614 * customTexShape: Optional. If set, will use the user defined texture
114615 * shape to create the texture.
114616 */
114617 }, {
114618 key: "readToGPU",
114619 value: function readToGPU(dataId) {
114620 var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
114621 var texData = this.texData.get(dataId);
114622 var values = texData.values,
114623 shape = texData.shape,
114624 slice = texData.slice,
114625 dtype = texData.dtype,
114626 isPacked = texData.isPacked,
114627 texture = texData.texture;
114628 if (dtype === 'complex64') {
114629 throw new Error('Does not support reading texture for complex64 dtype.');
114630 }
114631 // The presence of `slice` indicates this tensor is a shallow slice of a
114632 // different tensor, and is using that original tensor's texture. Run
114633 // `clone` in order to copy that texture and read from it.
114634 if (slice != null) {
114635 var program;
114636 if (isPacked) {
114637 program = new UnaryOpPackedProgram(shape, CLONE);
114638 } else {
114639 program = new UnaryOpProgram(shape, CLONE);
114640 }
114641 var res = this.runWebGLProgram(program, [{
114642 dataId: dataId,
114643 shape: shape,
114644 dtype: dtype
114645 }], dtype);
114646 var gpuResouorce = this.readToGPU(res, options);
114647 this.disposeIntermediateTensorInfo(res);
114648 return gpuResouorce;
114649 }
114650 if (texture == null) {
114651 if (values != null) {
114652 throw new Error('Data is not on GPU but on CPU.');
114653 } else {
114654 throw new Error('There is no data on GPU or CPU.');
114655 }
114656 }
114657 // Decode the texture so that it is stored densely (using four channels).
114658 var tmpTarget = this.decode(dataId, options.customTexShape);
114659 // Make engine track this tensor, so that we can dispose it later.
114660 var tensorRef = engine().makeTensorFromTensorInfo(tmpTarget);
114661 var tmpData = this.texData.get(tmpTarget.dataId);
114662 return Object.assign({
114663 tensorRef: tensorRef
114664 }, tmpData.texture);
114665 }
114666 }, {
114667 key: "bufferSync",
114668 value: function bufferSync(t) {
114669 var data = this.readSync(t.dataId);
114670 if (t.dtype === 'string') {
114671 try {
114672 // Decode the bytes into string.
114673 var strings = data.map(function (d) {
114674 return decodeString(d);
114675 });
114676 return buffer(t.shape, t.dtype, strings);
114677 } catch (_a) {
114678 throw new Error('Failed to decode encoded string bytes into utf-8');
114679 }
114680 }
114681 return buffer(t.shape, t.dtype, data);
114682 }
114683 }, {
114684 key: "checkNumericalProblems",
114685 value: function checkNumericalProblems(values) {
114686 if (values == null) {
114687 return;
114688 }
114689 for (var i = 0; i < values.length; i++) {
114690 var num = values[i];
114691 if (!canBeRepresented(num)) {
114692 if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
114693 throw Error("The value ".concat(num, " cannot be represented with your ") + "current settings. Consider enabling float32 rendering: " + "'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'");
114694 }
114695 throw Error("The value ".concat(num, " cannot be represented on this device."));
114696 }
114697 }
114698 }
114699 }, {
114700 key: "getValuesFromTexture",
114701 value: function getValuesFromTexture(dataId) {
114702 var _this$texData$get = this.texData.get(dataId),
114703 shape = _this$texData$get.shape,
114704 dtype = _this$texData$get.dtype,
114705 isPacked = _this$texData$get.isPacked;
114706 var size = sizeFromShape(shape);
114707 if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
114708 var _this$gpgpu2;
114709 var tmpTarget = this.decode(dataId);
114710 var _tmpData = this.texData.get(tmpTarget.dataId);
114711 var _vals = (_this$gpgpu2 = this.gpgpu).downloadMatrixFromPackedTexture.apply(_this$gpgpu2, [_tmpData.texture.texture].concat(_toConsumableArray(getDenseTexShape(shape)))).subarray(0, size);
114712 this.disposeIntermediateTensorInfo(tmpTarget);
114713 return _vals;
114714 }
114715 var shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true;
114716 var outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
114717 var program = shouldUsePackedProgram ? new EncodeFloatPackedProgram(outputShape) : new EncodeFloatProgram(outputShape);
114718 var output = this.runWebGLProgram(program, [{
114719 shape: outputShape,
114720 dtype: dtype,
114721 dataId: dataId
114722 }], 'float32');
114723 var tmpData = this.texData.get(output.dataId);
114724 var vals = this.gpgpu.downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1]).subarray(0, size);
114725 this.disposeIntermediateTensorInfo(output);
114726 return vals;
114727 }
114728 }, {
114729 key: "timerAvailable",
114730 value: function timerAvailable() {
114731 return env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
114732 }
114733 }, {
114734 key: "time",
114735 value: function time(f) {
114736 var _this2 = this;
114737 var oldActiveTimers = this.activeTimers;
114738 var newActiveTimers = [];
114739 var outerMostTime = false;
114740 if (this.programTimersStack == null) {
114741 this.programTimersStack = newActiveTimers;
114742 outerMostTime = true;
114743 } else {
114744 this.activeTimers.push(newActiveTimers);
114745 }
114746 this.activeTimers = newActiveTimers;
114747 f();
114748 // needing to split these up because util.flatten only accepts certain types
114749 var flattenedActiveTimerQueries = flatten$2(this.activeTimers.map(function (d) {
114750 return d.query;
114751 })).filter(function (d) {
114752 return d != null;
114753 });
114754 var flattenedActiveTimerNames = flatten$2(this.activeTimers.map(function (d) {
114755 return d.name;
114756 })).filter(function (d) {
114757 return d != null;
114758 });
114759 this.activeTimers = oldActiveTimers;
114760 if (outerMostTime) {
114761 this.programTimersStack = null;
114762 }
114763 var res = {
114764 uploadWaitMs: this.uploadWaitMs,
114765 downloadWaitMs: this.downloadWaitMs,
114766 kernelMs: null,
114767 wallMs: null // will be filled by the engine
114768 };
114769
114770 return _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
114771 var kernelMs;
114772 return _regeneratorRuntime().wrap(function _callee2$(_context2) {
114773 while (1) switch (_context2.prev = _context2.next) {
114774 case 0:
114775 if (!(env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0)) {
114776 _context2.next = 8;
114777 break;
114778 }
114779 _context2.next = 3;
114780 return Promise.all(flattenedActiveTimerQueries);
114781 case 3:
114782 kernelMs = _context2.sent;
114783 res['kernelMs'] = sum$4(kernelMs);
114784 res['getExtraProfileInfo'] = function () {
114785 return kernelMs.map(function (d, i) {
114786 return {
114787 name: flattenedActiveTimerNames[i],
114788 ms: d
114789 };
114790 }).map(function (d) {
114791 return "".concat(d.name, ": ").concat(d.ms);
114792 }).join(', ');
114793 };
114794 _context2.next = 9;
114795 break;
114796 case 8:
114797 res['kernelMs'] = {
114798 error: 'WebGL query timers are not supported in this environment.'
114799 };
114800 case 9:
114801 _this2.uploadWaitMs = 0;
114802 _this2.downloadWaitMs = 0;
114803 return _context2.abrupt("return", res);
114804 case 12:
114805 case "end":
114806 return _context2.stop();
114807 }
114808 }, _callee2);
114809 }))();
114810 }
114811 }, {
114812 key: "memory",
114813 value: function memory() {
114814 return {
114815 unreliable: false,
114816 numBytesInGPU: this.numBytesInGPU,
114817 numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
114818 numBytesInGPUFree: this.textureManager.numBytesFree
114819 };
114820 }
114821 }, {
114822 key: "startTimer",
114823 value: function startTimer() {
114824 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
114825 return this.gpgpu.beginQuery();
114826 }
114827 return {
114828 startMs: now(),
114829 endMs: null
114830 };
114831 }
114832 }, {
114833 key: "endTimer",
114834 value: function endTimer(query) {
114835 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
114836 this.gpgpu.endQuery();
114837 return query;
114838 }
114839 query.endMs = now();
114840 return query;
114841 }
114842 }, {
114843 key: "getQueryTime",
114844 value: function () {
114845 var _getQueryTime = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee3(query) {
114846 var timerQuery;
114847 return _regeneratorRuntime().wrap(function _callee3$(_context3) {
114848 while (1) switch (_context3.prev = _context3.next) {
114849 case 0:
114850 if (!(env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0)) {
114851 _context3.next = 2;
114852 break;
114853 }
114854 return _context3.abrupt("return", this.gpgpu.waitForQueryAndGetTime(query));
114855 case 2:
114856 timerQuery = query;
114857 return _context3.abrupt("return", timerQuery.endMs - timerQuery.startMs);
114858 case 4:
114859 case "end":
114860 return _context3.stop();
114861 }
114862 }, _callee3, this);
114863 }));
114864 function getQueryTime(_x2) {
114865 return _getQueryTime.apply(this, arguments);
114866 }
114867 return getQueryTime;
114868 }()
/**
 * Decreases the refCount on the dataId and disposes the memory if the dataId
 * has a refCount of 0. If there are pending reads on the data, the disposal
 * is added to the pending delete queue. Returns true if the dataId is removed
 * from the backend or the backend does not contain the dataId, and false if
 * the dataId is not removed. Memory may or may not be released even when the
 * dataId is removed; that also depends on dataRefCount (see `releaseGPUData`).
 * @param dataId
 * @param force Optional, remove the data regardless of refCount
 */
114879 }, {
114880 key: "disposeData",
114881 value: function disposeData(dataId) {
114882 var force = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
114883 if (this.pendingDisposal.has(dataId)) {
114884 return false;
114885 }
114886 // No-op if already disposed.
114887 if (!this.texData.has(dataId)) {
114888 return true;
114889 }
114890 // if force flag is set, change refCount to 0, this would ensure disposal
114891 // when added to the pendingDisposal queue. Memory may or may not be
114892 // released, which also depends on dataRefCount, see `releaseGPU`.
114893 if (force) {
114894 this.texData.get(dataId).refCount = 0;
114895 } else {
114896 this.texData.get(dataId).refCount--;
114897 }
114898 if (!force && this.texData.get(dataId).refCount > 0) {
114899 return false;
114900 }
114901 if (this.pendingRead.has(dataId)) {
114902 this.pendingDisposal.add(dataId);
114903 this.pendingDeletes++;
114904 return false;
114905 }
114906 this.releaseGPUData(dataId);
114907 var _this$texData$get2 = this.texData.get(dataId),
114908 complexTensorInfos = _this$texData$get2.complexTensorInfos;
114909 if (complexTensorInfos != null) {
114910 this.disposeData(complexTensorInfos.real.dataId, force);
114911 this.disposeData(complexTensorInfos.imag.dataId, force);
114912 }
114913 this.texData.delete(dataId);
114914 return true;
114915 }
114916 }, {
114917 key: "releaseGPUData",
114918 value: function releaseGPUData(dataId) {
114919 var _this$texData$get3 = this.texData.get(dataId),
114920 texture = _this$texData$get3.texture,
114921 dtype = _this$texData$get3.dtype,
114922 texShape = _this$texData$get3.texShape,
114923 usage = _this$texData$get3.usage,
114924 isPacked = _this$texData$get3.isPacked,
114925 slice = _this$texData$get3.slice;
114926 var key = slice && slice.origDataId || dataId;
114927 var refCount = this.dataRefCount.get(key);
114928 if (refCount > 1) {
114929 this.dataRefCount.set(key, refCount - 1);
114930 } else {
114931 this.dataRefCount.delete(key);
114932 if (texture != null) {
114933 this.numBytesInGPU -= this.computeBytes(texShape, dtype);
114934 this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
114935 }
114936 }
114937 var texData = this.texData.get(dataId);
114938 texData.texture = null;
114939 texData.texShape = null;
114940 texData.isPacked = false;
114941 texData.slice = null;
114942 }
114943 }, {
114944 key: "getTexture",
114945 value: function getTexture(dataId) {
114946 this.uploadToGPU(dataId);
114947 return this.texData.get(dataId).texture.texture;
114948 }
114949 /**
114950 * Returns internal information for the specific data bucket. Used in unit
114951 * tests.
114952 */
114953 }, {
114954 key: "getDataInfo",
114955 value: function getDataInfo(dataId) {
114956 return this.texData.get(dataId);
114957 }
114958 /*
114959 Tests whether all the inputs to an op are small and on the CPU. This heuristic
114960 determines when it would be faster to execute a kernel on the CPU. WebGL
114961 kernels opt into running this check and forwarding when appropriate.
114962 TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
114963 sustainable strategy for optimizing backend execution of ops.
114964 */
114965 }, {
114966 key: "shouldExecuteOnCPU",
114967 value: function shouldExecuteOnCPU(inputs) {
114968 var _this3 = this;
114969 var sizeThreshold = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : CPU_HANDOFF_SIZE_THRESHOLD;
114970 return env().getBool('WEBGL_CPU_FORWARD') && inputs.every(function (input) {
114971 return _this3.texData.get(input.dataId).texture == null && sizeFromShape(input.shape) < sizeThreshold;
114972 });
114973 }
114974 }, {
114975 key: "getGPGPUContext",
114976 value: function getGPGPUContext() {
114977 return this.gpgpu;
114978 }
114979 }, {
114980 key: "where",
114981 value: function where(condition) {
114982 warn('tf.where() in webgl locks the UI thread. ' + 'Call tf.whereAsync() instead');
114983 var condVals = condition.dataSync();
114984 return whereImpl(condition.shape, condVals);
114985 }
114986 }, {
114987 key: "packedUnaryOp",
114988 value: function packedUnaryOp(x, op, dtype) {
114989 var program = new UnaryOpPackedProgram(x.shape, op);
114990 var outInfo = this.compileAndRun(program, [x], dtype);
114991 return engine().makeTensorFromTensorInfo(outInfo);
114992 }
114993 // TODO(msoulanille) remove this once the backend has been modularized
114994 // a copy is needed here to break a circular dependency.
114995 // Also remove the op from unary_op.
114996 }, {
114997 key: "abs",
114998 value: function abs(x) {
114999 // TODO: handle cases when x is complex.
115000 if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
115001 var outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
115002 return this.makeOutput(x.shape, x.dtype, outValues);
115003 }
115004 if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
115005 return this.packedUnaryOp(x, ABS$1, x.dtype);
115006 }
115007 var program = new UnaryOpProgram(x.shape, ABS$1);
115008 var outInfo = this.compileAndRun(program, [x]);
115009 return engine().makeTensorFromTensorInfo(outInfo);
115010 }
115011 }, {
115012 key: "makeTensorInfo",
115013 value: function makeTensorInfo(shape, dtype, values) {
115014 var dataId;
115015 if (dtype === 'string' && values != null && values.length > 0 && isString(values[0])) {
115016 var encodedValues = values.map(function (d) {
115017 return encodeString(d);
115018 });
115019 dataId = this.write(encodedValues, shape, dtype);
115020 } else {
115021 dataId = this.write(values, shape, dtype);
115022 }
115023 this.texData.get(dataId).usage = null;
115024 return {
115025 dataId: dataId,
115026 shape: shape,
115027 dtype: dtype
115028 };
115029 }
115030 }, {
115031 key: "makeOutput",
115032 value: function makeOutput(shape, dtype, values) {
115033 return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
115034 }
115035 }, {
115036 key: "unpackTensor",
115037 value: function unpackTensor(input) {
115038 var program = new UnpackProgram(input.shape);
115039 return this.runWebGLProgram(program, [input], input.dtype);
115040 }
115041 }, {
115042 key: "packTensor",
115043 value: function packTensor(input) {
115044 var program = new PackProgram(input.shape);
115045 var preventEagerUnpackingOutput = true;
115046 return this.runWebGLProgram(program, [input], input.dtype, null /* customUniformValues */, preventEagerUnpackingOutput);
115047 }
115048 }, {
115049 key: "packedReshape",
115050 value: function packedReshape(input, afterShape) {
115051 var input3DShape = [getBatchDim(input.shape)].concat(_toConsumableArray(getRowsCols(input.shape)));
115052 var input3D = {
115053 dtype: input.dtype,
115054 shape: input3DShape,
115055 dataId: input.dataId
115056 };
115057 var afterShapeAs3D = [getBatchDim(afterShape)].concat(_toConsumableArray(getRowsCols(afterShape)));
115058 var program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
115059 var preventEagerUnpackingOfOutput = true;
115060 var customValues = [input3DShape];
115061 var output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
115062 return {
115063 dataId: output.dataId,
115064 shape: afterShape,
115065 dtype: output.dtype
115066 };
115067 }
115068 }, {
115069 key: "decode",
115070 value: function decode(dataId, customTexShape) {
115071 var texData = this.texData.get(dataId);
115072 var isPacked = texData.isPacked,
115073 shape = texData.shape,
115074 dtype = texData.dtype;
115075 if (customTexShape != null) {
115076 var size = sizeFromShape(shape);
115077 var texSize = customTexShape[0] * customTexShape[1] * 4;
115078 assert$1(size <= texSize, function () {
115079 return 'customTexShape is too small. ' + 'Row * Column * 4 should be equal or larger than the ' + 'size of the tensor data.';
115080 });
115081 }
115082 var shapeAs3D = getShapeAs3D(shape);
115083 var program;
115084 if (isPacked) {
115085 program = new DecodeMatrixPackedProgram(shapeAs3D);
115086 } else {
115087 program = new DecodeMatrixProgram(shapeAs3D);
115088 }
115089 var preventEagerUnpackingOfOutput = true;
115090 var customValues = [customTexShape != null ? customTexShape : getDenseTexShape(shapeAs3D)];
115091 var out = this.runWebGLProgram(program, [{
115092 shape: shapeAs3D,
115093 dtype: dtype,
115094 dataId: dataId
115095 }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape);
115096 return {
115097 dtype: dtype,
115098 shape: shape,
115099 dataId: out.dataId
115100 };
115101 }
115102 }, {
115103 key: "runWebGLProgram",
115104 value: function runWebGLProgram(program, inputs, outputDtype, customUniformValues) {
115105 var _this4 = this;
115106 var preventEagerUnpackingOfOutput = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
115107 var customTexShape = arguments.length > 5 ? arguments[5] : undefined;
115108 var output = this.makeTensorInfo(program.outputShape, outputDtype);
115109 var outData = this.texData.get(output.dataId);
115110 if (program.packedOutput) {
115111 outData.isPacked = true;
115112 }
115113 if (program.outPackingScheme === PackingScheme.DENSE) {
115114 var texelShape = customTexShape != null ? customTexShape : getDenseTexShape(program.outputShape);
115115 // For a densely packed output, we explicitly set texShape
115116 // so it doesn't get assigned later according to our typical packing
115117 // scheme wherein a single texel can only contain values from adjacent
115118 // rows/cols.
115119 outData.texShape = texelShape.map(function (d) {
115120 return d * 2;
115121 });
115122 }
115123 if (program.outTexUsage != null) {
115124 outData.usage = program.outTexUsage;
115125 }
115126 if (sizeFromShape(output.shape) === 0) {
115127 // Short-circuit the computation since the result is empty (has 0 in its
115128 // shape).
115129 outData.values = getTypedArrayFromDType(output.dtype, 0);
115130 return output;
115131 }
115132 var dataToDispose = [];
115133 var inputsData = inputs.map(function (input) {
115134 if (input.dtype === 'complex64') {
115135 throw new Error("GPGPUProgram does not support complex64 input. For complex64 " + "dtypes, please separate the program into real and imaginary " + "parts.");
115136 }
115137 var texData = _this4.texData.get(input.dataId);
115138 if (texData.texture == null) {
115139 if (!program.packedInputs && sizeFromShape(input.shape) <= env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
115140 // Upload small tensors that live on the CPU as uniforms, not as
115141 // textures. Do this only when the environment supports 32bit floats
115142 // due to problems when comparing 16bit floats with 32bit floats.
115143 // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it
115144 // possible for packed shaders to sample from uniforms.
115145 return {
115146 shape: input.shape,
115147 texData: null,
115148 isUniform: true,
115149 uniformValues: texData.values
115150 };
115151 }
115152 // This ensures that if a packed program's inputs have not yet been
115153 // uploaded to the GPU, they get uploaded as packed right off the bat.
115154 if (program.packedInputs) {
115155 texData.isPacked = true;
115156 texData.shape = input.shape;
115157 }
115158 }
115159 _this4.uploadToGPU(input.dataId);
115160 if (!!texData.isPacked !== !!program.packedInputs) {
115161 input = texData.isPacked ? _this4.unpackTensor(input) : _this4.packTensor(input);
115162 dataToDispose.push(input);
115163 texData = _this4.texData.get(input.dataId);
115164 } else if (texData.isPacked && !isReshapeFree(texData.shape, input.shape)) {
115165 // This is a special case where a texture exists for a tensor
115166 // but the shapes are incompatible (due to packing constraints) because
115167 // the tensor did not have a chance to go through the packed reshape
115168 // shader. This only happens when we reshape the *same* tensor to form
115169 // *distinct* inputs to an op, e.g. dotting a vector with itself. This
115170 // case will disappear once packed uploading is the default.
115171 var savedInput = input;
115172 var targetShape = input.shape;
115173 input.shape = texData.shape;
115174 input = _this4.packedReshape(input, targetShape);
115175 dataToDispose.push(input);
115176 texData = _this4.texData.get(input.dataId);
115177 savedInput.shape = targetShape;
115178 }
115179 return {
115180 shape: input.shape,
115181 texData: texData,
115182 isUniform: false
115183 };
115184 });
115185 this.uploadToGPU(output.dataId);
115186 var outputData = {
115187 shape: output.shape,
115188 texData: outData,
115189 isUniform: false
115190 };
115191 var key = makeShaderKey(program, inputsData, outputData);
115192 var binary = this.getAndSaveBinary(key, function () {
115193 return compileProgram(_this4.gpgpu, program, inputsData, outputData);
115194 });
115195 var shouldTimeProgram = this.activeTimers != null;
115196 var query;
115197 if (shouldTimeProgram) {
115198 query = this.startTimer();
115199 }
115200 if (!env().get('ENGINE_COMPILE_ONLY')) {
115201 runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues);
115202 }
115203 dataToDispose.forEach(function (info) {
115204 return _this4.disposeIntermediateTensorInfo(info);
115205 });
115206 if (shouldTimeProgram) {
115207 query = this.endTimer(query);
115208 this.activeTimers.push({
115209 name: program.constructor.name,
115210 query: this.getQueryTime(query)
115211 });
115212 }
115213 var glFlushThreshold = env().getNumber('WEBGL_FLUSH_THRESHOLD');
115214 // Manually GL flush requested
115215 if (glFlushThreshold > 0) {
115216 var time = now();
115217 if (time - this.lastGlFlushTime > glFlushThreshold) {
115218 this.gpgpu.gl.flush();
115219 this.lastGlFlushTime = time;
115220 }
115221 }
115222 if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked && preventEagerUnpackingOfOutput === false) {
115223 var unpacked = this.unpackTensor(output);
115224 this.disposeIntermediateTensorInfo(output);
115225 return unpacked;
115226 }
115227 return output;
115228 }
115229 }, {
115230 key: "compileAndRun",
115231 value: function compileAndRun(program, inputs, outputDtype, customUniformValues) {
115232 var preventEagerUnpackingOfOutput = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
115233 outputDtype = outputDtype || inputs[0].dtype;
115234 var outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput);
115235 return outInfo;
115236 }
115237 }, {
115238 key: "getAndSaveBinary",
115239 value: function getAndSaveBinary(key, getBinary) {
115240 if (!(key in this.binaryCache)) {
115241 this.binaryCache[key] = getBinary();
115242 }
115243 return this.binaryCache[key];
115244 }
115245 }, {
115246 key: "getTextureManager",
115247 value: function getTextureManager() {
115248 return this.textureManager;
115249 }
115250 }, {
115251 key: "dispose",
115252 value: function dispose() {
115253 var _this5 = this;
115254 if (this.disposed) {
115255 return;
115256 }
115257 // Avoid disposing the compiled webgl programs during unit testing because
115258 // it slows down test execution.
115259 if (!env().getBool('IS_TEST')) {
115260 var allKeys = Object.keys(this.binaryCache);
115261 allKeys.forEach(function (key) {
115262 _this5.gpgpu.deleteProgram(_this5.binaryCache[key].webGLProgram);
115263 delete _this5.binaryCache[key];
115264 });
115265 }
115266 this.textureManager.dispose();
115267 if (this.canvas != null && typeof HTMLCanvasElement !== 'undefined' && this.canvas instanceof HTMLCanvasElement) {
115268 this.canvas.remove();
115269 } else {
115270 this.canvas = null;
115271 }
115272 if (this.gpgpuCreatedLocally) {
115273 this.gpgpu.program = null;
115274 this.gpgpu.dispose();
115275 }
115276 this.disposed = true;
115277 }
115278 }, {
115279 key: "floatPrecision",
115280 value: function floatPrecision() {
115281 var _this6 = this;
115282 if (this.floatPrecisionValue == null) {
115283 this.floatPrecisionValue = tidy(function () {
115284 if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
115285 // Momentarily switching DEBUG flag to false so we don't throw an
115286 // error trying to upload a small value.
115287 var debugFlag = env().getBool('DEBUG');
115288 env().set('DEBUG', false);
115289 var underflowCheckValue = _this6.abs(scalar(1e-8)).dataSync()[0];
115290 env().set('DEBUG', debugFlag);
115291 if (underflowCheckValue > 0) {
115292 return 32;
115293 }
115294 }
115295 return 16;
115296 });
115297 }
115298 return this.floatPrecisionValue;
115299 }
115300 /** Returns the smallest representable number. */
115301 }, {
115302 key: "epsilon",
115303 value: function epsilon() {
115304 return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
115305 }
115306 }, {
115307 key: "uploadToGPU",
115308 value: function uploadToGPU(dataId) {
115309 var texData = this.texData.get(dataId);
115310 var shape = texData.shape,
115311 dtype = texData.dtype,
115312 values = texData.values,
115313 texture = texData.texture,
115314 usage = texData.usage,
115315 isPacked = texData.isPacked;
115316 if (texture != null) {
115317 // Array is already on GPU. No-op.
115318 return;
115319 }
115320 var shouldTimeProgram = this.activeTimers != null;
115321 var start;
115322 if (shouldTimeProgram) {
115323 start = now();
115324 }
115325 var texShape = texData.texShape;
115326 if (texShape == null) {
115327 // This texShape may not be the final texture shape. For packed or dense
115328 // textures, the texShape will be changed when textures are created.
115329 texShape = getTextureShapeFromLogicalShape(shape, isPacked);
115330 texData.texShape = texShape;
115331 }
115332 if (values != null) {
115333 var shapeAs3D = getShapeAs3D(shape);
115334 var program;
115335 var width = texShape[1],
115336 height = texShape[0];
115337 var isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray;
115338 // texture for float array is PhysicalTextureType.PACKED_2X2_FLOAT32, we
115339 // need to make sure the upload uses the same packed size
115340 if (isPacked || !isByteArray) {
115341 var _tex_util$getPackedMa = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
115342 var _tex_util$getPackedMa2 = _slicedToArray(_tex_util$getPackedMa, 2);
115343 width = _tex_util$getPackedMa2[0];
115344 height = _tex_util$getPackedMa2[1];
115345 }
115346 if (isPacked) {
115347 program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray);
115348 } else {
115349 program = new EncodeMatrixProgram(shapeAs3D, isByteArray);
115350 }
115351 // TexShape for float array needs to be the original shape, which byte
115352 // array needs to be packed size. This allow the data upload shape to be
115353 // matched with texture creation logic.
115354 var tempDenseInputTexShape = isByteArray ? [height, width] : texShape;
115355 var tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype);
115356 var tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId);
115357 if (isByteArray) {
115358 tempDenseInputTexData.usage = TextureUsage.PIXELS;
115359 } else {
115360 tempDenseInputTexData.usage = TextureUsage.UPLOAD;
115361 }
115362 tempDenseInputTexData.texShape = tempDenseInputTexShape;
115363 this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
115364 var customValues = [[height, width]];
115365 // We want the output to remain packed regardless of the value of
115366 // WEBGL_PACK.
115367 var preventEagerUnpacking = true;
115368 var encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking);
115369 // Have the original texture assume the identity of the encoded output.
115370 var outputTexData = this.texData.get(encodedOutputTarget.dataId);
115371 texData.texShape = outputTexData.texShape;
115372 texData.isPacked = outputTexData.isPacked;
115373 texData.usage = outputTexData.usage;
115374 if (!env().get('ENGINE_COMPILE_ONLY')) {
115375 texData.texture = outputTexData.texture;
115376 // Once uploaded, don't store the values on cpu.
115377 texData.values = null;
115378 this.texData.delete(encodedOutputTarget.dataId);
115379 } else {
115380 this.disposeData(encodedOutputTarget.dataId);
115381 }
115382 this.disposeIntermediateTensorInfo(tempDenseInputHandle);
115383 if (shouldTimeProgram) {
115384 this.uploadWaitMs += now() - start;
115385 }
115386 } else {
115387 var newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
115388 texData.texture = newTexture;
115389 }
115390 }
115391 }, {
115392 key: "convertAndCacheOnCPU",
115393 value: function convertAndCacheOnCPU(dataId, float32Values) {
115394 var texData = this.texData.get(dataId);
115395 var dtype = texData.dtype;
115396 if (float32Values != null) {
115397 texData.values = float32ToTypedArray(float32Values, dtype);
115398 }
115399 return texData.values;
115400 }
115401 }, {
115402 key: "acquireTexture",
115403 value: function acquireTexture(texShape, texType, dtype, isPacked) {
115404 this.numBytesInGPU += this.computeBytes(texShape, dtype);
115405 if (!this.warnedAboutMemory && this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {
115406 var mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
115407 this.warnedAboutMemory = true;
115408 console.warn("High memory usage in GPU: ".concat(mb, " MB, ") + "most likely due to a memory leak");
115409 }
115410 return this.textureManager.acquireTexture(texShape, texType, isPacked);
115411 }
115412 }, {
115413 key: "computeBytes",
115414 value: function computeBytes(shape, dtype) {
115415 return shape[0] * shape[1] * bytesPerElement(dtype);
115416 }
115417 }, {
115418 key: "checkCompileCompletion",
115419 value: function checkCompileCompletion() {
115420 for (var _i = 0, _Object$entries = Object.entries(this.binaryCache); _i < _Object$entries.length; _i++) {
115421 var _Object$entries$_i = _slicedToArray(_Object$entries[_i], 2),
115422 binary = _Object$entries$_i[1];
115423 this.checkCompletion_(binary);
115424 }
115425 }
115426 }, {
115427 key: "checkCompileCompletionAsync",
115428 value: function () {
115429 var _checkCompileCompletionAsync = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee4() {
115430 var _this7 = this;
115431 var ps, _i2, _Object$entries2, _Object$entries2$_i, binary, _loop, _i3, _Object$entries3;
115432 return _regeneratorRuntime().wrap(function _callee4$(_context5) {
115433 while (1) switch (_context5.prev = _context5.next) {
115434 case 0:
115435 ps = [];
115436 if (!this.gpgpu.parallelCompilationExtension) {
115437 _context5.next = 6;
115438 break;
115439 }
115440 for (_i2 = 0, _Object$entries2 = Object.entries(this.binaryCache); _i2 < _Object$entries2.length; _i2++) {
115441 _Object$entries2$_i = _slicedToArray(_Object$entries2[_i2], 2), binary = _Object$entries2$_i[1];
115442 ps.push(this.checkCompletionAsync_(binary));
115443 }
115444 return _context5.abrupt("return", Promise.all(ps));
115445 case 6:
115446 _loop = /*#__PURE__*/_regeneratorRuntime().mark(function _loop() {
115447 var _Object$entries3$_i, binary, p;
115448 return _regeneratorRuntime().wrap(function _loop$(_context4) {
115449 while (1) switch (_context4.prev = _context4.next) {
115450 case 0:
115451 _Object$entries3$_i = _slicedToArray(_Object$entries3[_i3], 2), binary = _Object$entries3$_i[1];
115452 p = new Promise(function (resolve) {
115453 try {
115454 _this7.checkCompletion_(binary);
115455 resolve(true);
115456 } catch (error) {
115457 throw error;
115458 }
115459 });
115460 ps.push(p);
115461 case 3:
115462 case "end":
115463 return _context4.stop();
115464 }
115465 }, _loop);
115466 });
115467 _i3 = 0, _Object$entries3 = Object.entries(this.binaryCache);
115468 case 8:
115469 if (!(_i3 < _Object$entries3.length)) {
115470 _context5.next = 13;
115471 break;
115472 }
115473 return _context5.delegateYield(_loop(), "t0", 10);
115474 case 10:
115475 _i3++;
115476 _context5.next = 8;
115477 break;
115478 case 13:
115479 return _context5.abrupt("return", Promise.all(ps));
115480 case 14:
115481 case "end":
115482 return _context5.stop();
115483 }
115484 }, _callee4, this);
115485 }));
115486 function checkCompileCompletionAsync() {
115487 return _checkCompileCompletionAsync.apply(this, arguments);
115488 }
115489 return checkCompileCompletionAsync;
115490 }()
115491 }, {
115492 key: "checkCompletionAsync_",
115493 value: function () {
115494 var _checkCompletionAsync_ = _asyncToGenerator( /*#__PURE__*/_regeneratorRuntime().mark(function _callee5(binary) {
115495 return _regeneratorRuntime().wrap(function _callee5$(_context6) {
115496 while (1) switch (_context6.prev = _context6.next) {
115497 case 0:
115498 if (!this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) {
115499 _context6.next = 4;
115500 break;
115501 }
115502 return _context6.abrupt("return", this.checkCompletion_(binary));
115503 case 4:
115504 _context6.next = 6;
115505 return nextFrame();
115506 case 6:
115507 return _context6.abrupt("return", this.checkCompletionAsync_(binary));
115508 case 7:
115509 case "end":
115510 return _context6.stop();
115511 }
115512 }, _callee5, this);
115513 }));
115514 function checkCompletionAsync_(_x3) {
115515 return _checkCompletionAsync_.apply(this, arguments);
115516 }
115517 return checkCompletionAsync_;
115518 }()
115519 }, {
115520 key: "checkCompletion_",
115521 value: function checkCompletion_(binary) {
115522 if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) {
115523 console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram));
115524 if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) {
115525 logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader));
115526 throw new Error('Failed to compile fragment shader.');
115527 }
115528 throw new Error('Failed to link vertex and fragment shaders.');
115529 }
115530 return true;
115531 }
115532 }, {
115533 key: "getUniformLocations",
115534 value: function getUniformLocations$1() {
115535 for (var _i4 = 0, _Object$values = Object.values(this.binaryCache); _i4 < _Object$values.length; _i4++) {
115536 var binary = _Object$values[_i4];
115537 // TODO: Iterating through all binaries to build VAOs is supposed to be in
115538 // a seperate function, like 'setVaos'. However, to avoid breaking changes
115539 // for the users using parallel compile feature now, buildVao is silently
115540 // added here.
115541 this.gpgpu.buildVao(binary.webGLProgram);
115542 var _getUniformLocations2 = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram),
115543 variablesLocations = _getUniformLocations2.variablesLocations,
115544 customUniformLocations = _getUniformLocations2.customUniformLocations,
115545 infLoc = _getUniformLocations2.infLoc,
115546 nanLoc = _getUniformLocations2.nanLoc,
115547 outShapeLocation = _getUniformLocations2.outShapeLocation,
115548 outShapeStridesLocation = _getUniformLocations2.outShapeStridesLocation,
115549 outTexShapeLocation = _getUniformLocations2.outTexShapeLocation;
115550 binary.variablesLocations = variablesLocations;
115551 binary.customUniformLocations = customUniformLocations;
115552 binary.infLoc = infLoc;
115553 binary.nanLoc = nanLoc;
115554 binary.outShapeLocation = outShapeLocation;
115555 binary.outShapeStridesLocation = outShapeStridesLocation;
115556 binary.outTexShapeLocation = outTexShapeLocation;
115557 }
115558 }
115559 /**
115560 * Create a TF.js tensor out of an existing WebGL texture. A new texture will
115561 * be created.
115562 */
115563 }, {
115564 key: "createTensorFromGPUData",
115565 value: function createTensorFromGPUData(values, shape, dtype) {
115566 values.channels = values.channels || 'RGBA';
115567 var texture = values.texture,
115568 height = values.height,
115569 width = values.width,
115570 channels = values.channels;
115571 var backend = engine().backend;
115572 // Have to throw an error, otherwise WebGL just warns and returns wrong
115573 // values.
115574 if (!backend.gpgpu.gl.isTexture(texture)) {
115575 throw new Error("The texture is invalid. Also, please make sure the texture and " + "the TFJS WebGL backend are using the same canvas. If you want to " + "use your own custom canvas, you have to create and use the custom " + "TFJS WebGL backend created from the canvas through " + "'new tf.MathBackendWebGL(customCanvas)'.");
115576 }
115577 var dataId = backend.writeTexture(texture, shape, dtype, height, width, channels);
115578 return engine().makeTensorFromDataId(dataId, shape, dtype, backend);
115579 }
115580 }]);
115581 return MathBackendWebGL;
115582 }(KernelBackend);
115583 MathBackendWebGL.nextDataId = 0;
115584 function float32ToTypedArray(a, dtype) {
115585 if (dtype === 'float32' || dtype === 'complex64') {
115586 return a;
115587 } else if (dtype === 'int32' || dtype === 'bool') {
115588 var result = dtype === 'int32' ? new Int32Array(a.length) : new Uint8Array(a.length);
115589 for (var i = 0; i < result.length; ++i) {
115590 result[i] = Math.round(a[i]);
115591 }
115592 return result;
115593 } else {
115594 throw new Error("Unknown dtype ".concat(dtype));
115595 }
115596 }
115597
115598 /** @license See the LICENSE file. */
115599 // This code is auto-generated, do not modify this file!
115600 var version$2 = '4.22.0';
115601
115602 /**
115603 * @license
115604 * Copyright 2019 Google LLC. All Rights Reserved.
115605 * Licensed under the Apache License, Version 2.0 (the "License");
115606 * you may not use this file except in compliance with the License.
115607 * You may obtain a copy of the License at
115608 *
115609 * http://www.apache.org/licenses/LICENSE-2.0
115610 *
115611 * Unless required by applicable law or agreed to in writing, software
115612 * distributed under the License is distributed on an "AS IS" BASIS,
115613 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
115614 * See the License for the specific language governing permissions and
115615 * limitations under the License.
115616 * =============================================================================
115617 */
115618 /**
115619 * Enforce use of half precision textures if available on the platform.
115620 *
115621 * @doc {heading: 'Environment', namespace: 'webgl'}
115622 */
115623 function forceHalfFloat() {
115624 env().set('WEBGL_FORCE_F16_TEXTURES', true);
115625 }
115626
115627 /**
115628 * @license
115629 * Copyright 2020 Google Inc. All Rights Reserved.
115630 * Licensed under the Apache License, Version 2.0 (the "License");
115631 * you may not use this file except in compliance with the License.
115632 * You may obtain a copy of the License at
115633 *
115634 * http://www.apache.org/licenses/LICENSE-2.0
115635 *
115636 * Unless required by applicable law or agreed to in writing, software
115637 * distributed under the License is distributed on an "AS IS" BASIS,
115638 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
115639 * See the License for the specific language governing permissions and
115640 * limitations under the License.
115641 * =============================================================================
115642 */
115643 if (isBrowser()) {
115644 registerBackend('webgl', function () {
115645 return new MathBackendWebGL();
115646 }, 2 /* priority */);
115647 }
115648 var webgl = {
115649 forceHalfFloat: forceHalfFloat
115650 };
115651
115652 var CHECK_NAN_SNIPPET = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
115653 var SQUARED_DIFFERENCE$1 = 'return (a - b) * (a - b);';
115654 var BinaryOpProgram = /*#__PURE__*/_createClass(function BinaryOpProgram(op, aShape, bShape) {
115655 _classCallCheck(this, BinaryOpProgram);
115656 this.variableNames = ['A', 'B'];
115657 this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
115658 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
115659 this.userCode = "\n float binaryOperation(float a, float b) {\n ".concat(op, "\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n ");
115660 });
115661
115662 var CHECK_NAN_SNIPPET_PACKED = "\n result.r = isNaN.r ? NAN : result.r;\n result.g = isNaN.g ? NAN : result.g;\n result.b = isNaN.b ? NAN : result.b;\n result.a = isNaN.a ? NAN : result.a;\n";
115663 var ELU_DER$1 = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n";
115664 var NOT_EQUAL$1 = "\n return vec4(notEqual(a, b));\n";
115665 var BinaryOpPackedProgram = /*#__PURE__*/_createClass(function BinaryOpPackedProgram(op, aShape, bShape) {
115666 var checkOutOfBounds = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
115667 _classCallCheck(this, BinaryOpPackedProgram);
115668 this.variableNames = ['A', 'B'];
115669 this.supportsBroadcasting = true;
115670 this.packedInputs = true;
115671 this.packedOutput = true;
115672 this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
115673 var rank = this.outputShape.length;
115674 this.enableShapeUniforms = useShapeUniforms(rank);
115675 var checkOutOfBoundsString = '';
115676 if (checkOutOfBounds) {
115677 if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
115678 checkOutOfBoundsString = "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n ";
115679 } else {
115680 var dtype = getCoordsDataType(rank);
115681 checkOutOfBoundsString = "\n ".concat(dtype, " coords = getOutputCoords();\n ");
115682 if (rank === 1) {
115683 if (this.enableShapeUniforms) {
115684 checkOutOfBoundsString += "\n result.y = (coords + 1) >= outShape ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ";
115685 } else {
115686 checkOutOfBoundsString += "\n result.y = (coords + 1) >= ".concat(this.outputShape[0], " ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ");
115687 }
115688 } else {
115689 var channels = getChannels('coords', rank);
115690 if (this.enableShapeUniforms) {
115691 checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (".concat(channels[rank - 2], " + 1) >= outShape[").concat(rank, " - 2];\n bool nextColOutOfBounds =\n (").concat(channels[rank - 1], " + 1) >= outShape[").concat(rank, " - 1];\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ");
115692 } else {
115693 checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (".concat(channels[rank - 2], " + 1) >= ").concat(this.outputShape[rank - 2], ";\n bool nextColOutOfBounds =\n (").concat(channels[rank - 1], " + 1) >= ").concat(this.outputShape[rank - 1], ";\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ");
115694 }
115695 }
115696 }
115697 }
115698 this.userCode = "\n vec4 binaryOperation(vec4 a, vec4 b) {\n ".concat(op, "\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n ").concat(checkOutOfBoundsString, "\n\n setOutput(result);\n }\n ");
115699 });
115700
115701 /**
115702 * @license
115703 * Copyright 2020 Google LLC. All Rights Reserved.
115704 * Licensed under the Apache License, Version 2.0 (the "License");
115705 * you may not use this file except in compliance with the License.
115706 * You may obtain a copy of the License at
115707 *
115708 * http://www.apache.org/licenses/LICENSE-2.0
115709 *
115710 * Unless required by applicable law or agreed to in writing, software
115711 * distributed under the License is distributed on an "AS IS" BASIS,
115712 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
115713 * See the License for the specific language governing permissions and
115714 * limitations under the License.
115715 * =============================================================================
115716 */
115717 function identity(args) {
115718 var inputs = args.inputs,
115719 backend = args.backend;
115720 var x = inputs.x;
115721 backend.incRef(x.dataId);
115722 return {
115723 dataId: x.dataId,
115724 shape: x.shape,
115725 dtype: x.dtype
115726 };
115727 }
115728 var identityConfig = {
115729 kernelName: Identity$1,
115730 backendName: 'webgl',
115731 kernelFunc: identity
115732 };
115733
115734 /**
115735 * @license
115736 * Copyright 2020 Google LLC. All Rights Reserved.
115737 * Licensed under the Apache License, Version 2.0 (the "License");
115738 * you may not use this file except in compliance with the License.
115739 * You may obtain a copy of the License at
115740 *
115741 * http://www.apache.org/licenses/LICENSE-2.0
115742 *
115743 * Unless required by applicable law or agreed to in writing, software
115744 * distributed under the License is distributed on an "AS IS" BASIS,
115745 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
115746 * See the License for the specific language governing permissions and
115747 * limitations under the License.
115748 * =============================================================================
115749 */
115750 /**
115751 * In WebGL data is stored in GPU textures which can't be efficiently copied, so
115752 * complex tensors share data with their real and imaginary components. Complex
115753 * tensors' reference to the components is tracked by refCount on the individual
115754 * component. The refCounts are increased by the identity call.
115755 *
115756 * When a complex tensor is disposed, it will reduce the refCount on the
115757 * components by calling disposeData on each.
115758 */
115759 function complex(args) {
115760 var inputs = args.inputs,
115761 backend = args.backend;
115762 var real = inputs.real,
115763 imag = inputs.imag;
115764 var complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
115765 var complex = backend.texData.get(complexInfo.dataId);
115766 var realTensorInfo = identity({
115767 inputs: {
115768 x: real
115769 },
115770 backend: backend
115771 });
115772 var imagTensorInfo = identity({
115773 inputs: {
115774 x: imag
115775 },
115776 backend: backend
115777 });
115778 complex.complexTensorInfos = {
115779 real: realTensorInfo,
115780 imag: imagTensorInfo
115781 };
115782 return complexInfo;
115783 }
115784 var complexConfig = {
115785 kernelName: Complex,
115786 backendName: 'webgl',
115787 kernelFunc: complex
115788 };
115789
115790 /**
115791 * @license
115792 * Copyright 2020 Google LLC. All Rights Reserved.
115793 * Licensed under the Apache License, Version 2.0 (the "License");
115794 * you may not use this file except in compliance with the License.
115795 * You may obtain a copy of the License at
115796 *
115797 * http://www.apache.org/licenses/LICENSE-2.0
115798 *
115799 * Unless required by applicable law or agreed to in writing, software
115800 * distributed under the License is distributed on an "AS IS" BASIS,
115801 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
115802 * See the License for the specific language governing permissions and
115803 * limitations under the License.
115804 * =============================================================================
115805 */
115806 var LEAKYRELU = "return (a < 0.) ? b * a : a;";
115807 var LEAKYRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
115808 function leakyRelu(args) {
115809 var inputs = args.inputs,
115810 backend = args.backend,
115811 attrs = args.attrs;
115812 var x = inputs.x;
115813 var alpha = attrs.alpha;
115814 var $alpha = backend.makeTensorInfo([], 'float32', createScalarValue(alpha, 'float32'));
115815 var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) : new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
115816 var result = backend.runWebGLProgram(program, [x, $alpha], 'float32');
115817 backend.disposeIntermediateTensorInfo($alpha);
115818 return result;
115819 }
115820 var leakyReluConfig = {
115821 kernelName: LeakyRelu,
115822 backendName: 'webgl',
115823 kernelFunc: leakyRelu
115824 };
115825
115826 /**
115827 * @license
115828 * Copyright 2020 Google LLC. All Rights Reserved.
115829 * Licensed under the Apache License, Version 2.0 (the "License");
115830 * you may not use this file except in compliance with the License.
115831 * You may obtain a copy of the License at
115832 *
115833 * http://www.apache.org/licenses/LICENSE-2.0
115834 *
115835 * Unless required by applicable law or agreed to in writing, software
115836 * distributed under the License is distributed on an "AS IS" BASIS,
115837 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
115838 * See the License for the specific language governing permissions and
115839 * limitations under the License.
115840 * =============================================================================
115841 */
115842 var PRELU = "return (a < 0.) ? b * a : a;";
115843 var PRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
115844 function prelu(args) {
115845 var inputs = args.inputs,
115846 backend = args.backend;
115847 var x = inputs.x,
115848 alpha = inputs.alpha;
115849 var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) : new BinaryOpProgram(PRELU, x.shape, alpha.shape);
115850 return backend.runWebGLProgram(program, [x, alpha], 'float32');
115851 }
115852 var preluConfig = {
115853 kernelName: Prelu,
115854 backendName: 'webgl',
115855 kernelFunc: prelu
115856 };
115857
115858 var CHECK_NAN_SNIPPET_UNARY = "if (isnan(x)) return x;";
115859 /**
115860 * Template that creates a `KernelFunc` for unary ops.
115861 * @param opSnippet Op snippet to create `UnaryOpProgram`.
115862 * @param packedOpSnippet Op snippet to create `UnaryOpPackedProgram`.
115863 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
115864 * result has the same dtype as the first input. This is mainly used in
115865 * comparison kernels, such as Equal, Less, Greater, etc.
115866 */
115867 function unaryKernelFunc(_ref) {
115868 var opSnippet = _ref.opSnippet,
115869 packedOpSnippet = _ref.packedOpSnippet,
115870 cpuKernelImpl = _ref.cpuKernelImpl,
115871 dtype = _ref.dtype;
115872 return function (_ref2) {
115873 var inputs = _ref2.inputs,
115874 backend = _ref2.backend;
115875 var x = inputs.x;
115876 var webglBackend = backend;
115877 var $dtype = dtype || x.dtype;
115878 if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
115879 var xData = webglBackend.texData.get(x.dataId);
115880 var outValues = cpuKernelImpl(xData.values, $dtype);
115881 return webglBackend.makeTensorInfo(x.shape, $dtype, outValues);
115882 }
115883 var shouldUsePackedProgram = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null;
115884 var program;
115885 if (shouldUsePackedProgram) {
115886 program = new UnaryOpPackedProgram(x.shape, packedOpSnippet);
115887 } else {
115888 program = new UnaryOpProgram(x.shape, opSnippet);
115889 }
115890 return webglBackend.runWebGLProgram(program, [x], $dtype);
115891 };
115892 }
115893 /**
115894 * Template that creates a `KernelFunc` for binary ops.
115895 * @param opSnippet Op snippet to create `BinaryOpProgram`.
115896 * @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`.
115897 * @param checkOutOfBoundsForPackedProgram Whether to set checkOutOfBounds=true
115898 * when creating BinaryOpPackedProgram.
115899 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
115900 * result has the same dtype as the first input. This is mainly used in
115901 * comparison kernels, such as Equal, Less, Greater, etc.
115902 */
115903 function binaryKernelFunc(_ref3) {
115904 var opSnippet = _ref3.opSnippet,
115905 packedOpSnippet = _ref3.packedOpSnippet,
115906 _ref3$checkOutOfBound = _ref3.checkOutOfBounds,
115907 checkOutOfBounds = _ref3$checkOutOfBound === void 0 ? false : _ref3$checkOutOfBound,
115908 _ref3$supportsComplex = _ref3.supportsComplex,
115909 supportsComplex = _ref3$supportsComplex === void 0 ? false : _ref3$supportsComplex,
115910 cpuKernelImpl = _ref3.cpuKernelImpl,
115911 dtype = _ref3.dtype;
115912 return function (_ref4) {
115913 var inputs = _ref4.inputs,
115914 backend = _ref4.backend;
115915 var a = inputs.a,
115916 b = inputs.b;
115917 var webglBackend = backend;
115918 if (supportsComplex && a.dtype === 'complex64') {
115919 var aData = webglBackend.texData.get(a.dataId);
115920 var bData = webglBackend.texData.get(b.dataId);
115921 var _map = [[aData.complexTensorInfos.real, bData.complexTensorInfos.real], [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]].map(function (complexParts) {
115922 var _complexParts = _slicedToArray(complexParts, 2),
115923 aPart = _complexParts[0],
115924 bPart = _complexParts[1];
115925 var aHandle = {
115926 dataId: aPart.dataId,
115927 dtype: aPart.dtype,
115928 shape: a.shape
115929 };
115930 var bHandle = {
115931 dataId: bPart.dataId,
115932 dtype: bPart.dtype,
115933 shape: b.shape
115934 };
115935 var program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
115936 return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));
115937 }),
115938 _map2 = _slicedToArray(_map, 2),
115939 real = _map2[0],
115940 imag = _map2[1];
115941 var complexOutput = complex({
115942 inputs: {
115943 real: real,
115944 imag: imag
115945 },
115946 backend: webglBackend
115947 });
115948 webglBackend.disposeIntermediateTensorInfo(real);
115949 webglBackend.disposeIntermediateTensorInfo(imag);
115950 // TODO(annxingyuan): Implement CPU forwarding for complex inputs.
115951 return complexOutput;
115952 }
115953 var $dtype = dtype || upcastType(a.dtype, b.dtype);
115954 if ((a.dtype === 'string' || b.dtype === 'string' || webglBackend.shouldExecuteOnCPU([a, b])) && cpuKernelImpl != null) {
115955 var aVals = webglBackend.texData.get(a.dataId).values;
115956 var bVals = webglBackend.texData.get(b.dataId).values;
115957 var decodedAVals = a.dtype === 'string' ?
115958 // tslint:disable-next-line: no-any
115959 fromUint8ToStringArray(aVals) : aVals;
115960 var decodedBVals = a.dtype === 'string' ?
115961 // tslint:disable-next-line: no-any
115962 fromUint8ToStringArray(bVals) : bVals;
115963 var _cpuKernelImpl = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype),
115964 _cpuKernelImpl2 = _slicedToArray(_cpuKernelImpl, 2),
115965 outValues = _cpuKernelImpl2[0],
115966 outShape = _cpuKernelImpl2[1];
115967 var out = webglBackend.makeTensorInfo(outShape, $dtype);
115968 var outData = webglBackend.texData.get(out.dataId);
115969 outData.values = outValues;
115970 return out;
115971 }
115972 var shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') && packedOpSnippet != null;
115973 var program;
115974 if (shouldUsePackedProgram) {
115975 program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
115976 } else {
115977 program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
115978 }
115979 return webglBackend.runWebGLProgram(program, [a, b], $dtype);
115980 };
115981 }
115982 function mapActivationToShaderProgram(activation) {
115983 var packed = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
115984 if (activation === 'linear') {
115985 if (packed) {
115986 return LINEAR;
115987 }
115988 return LINEAR$1;
115989 } else if (activation === 'relu') {
115990 if (packed) {
115991 return RELU$1;
115992 }
115993 return RELU$2;
115994 } else if (activation === 'elu') {
115995 if (packed) {
115996 return ELU$1;
115997 }
115998 return ELU$2;
115999 } else if (activation === 'relu6') {
116000 if (packed) {
116001 return RELU6$1;
116002 }
116003 return RELU6$2;
116004 } else if (activation === 'prelu') {
116005 if (packed) {
116006 return PRELU_PACKED;
116007 }
116008 return PRELU;
116009 } else if (activation === 'leakyrelu') {
116010 if (packed) {
116011 return LEAKYRELU_PACKED;
116012 }
116013 return LEAKYRELU;
116014 } else if (activation === 'sigmoid') {
116015 if (packed) {
116016 return SIGMOID$1;
116017 }
116018 return SIGMOID$2;
116019 }
116020 throw new Error("Activation ".concat(activation, " has not been implemented for the WebGL backend."));
116021 }
116022
116023 var MatMulPackedProgram = /*#__PURE__*/_createClass(function MatMulPackedProgram(aShape, bShape, outputShape) {
116024 var transposeA = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
116025 var transposeB = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
116026 var addBias = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : false;
116027 var activation = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : null;
116028 var hasPreluActivation = arguments.length > 7 && arguments[7] !== undefined ? arguments[7] : false;
116029 var hasLeakyreluActivation = arguments.length > 8 && arguments[8] !== undefined ? arguments[8] : false;
116030 _classCallCheck(this, MatMulPackedProgram);
116031 this.variableNames = ['matrixA', 'matrixB'];
116032 this.packedInputs = true;
116033 this.packedOutput = true;
116034 this.outputShape = outputShape;
116035 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
116036 var sharedDim = transposeA ? aShape[1] : aShape[2];
116037 var sharedDimensionPacked = Math.ceil(sharedDim / 2);
116038 var aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
116039 var bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
116040 var aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
116041 var bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
116042 var activationSnippet = '',
116043 applyActivationSnippet = '';
116044 if (activation) {
116045 if (hasPreluActivation) {
116046 activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
116047 } else if (hasLeakyreluActivation) {
116048 activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
116049 } else {
116050 activationSnippet = "vec4 activation(vec4 x) {\n ".concat(activation, "\n }");
116051 }
116052 applyActivationSnippet = "result = activation(result);";
116053 }
116054 var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
116055 if (addBias) {
116056 this.variableNames.push('bias');
116057 }
116058 if (hasPreluActivation) {
116059 this.variableNames.push('preluActivationWeights');
116060 }
116061 if (hasLeakyreluActivation) {
116062 this.variableNames.push('leakyreluAlpha');
116063 }
116064 var batchASnippet = 'rc.x';
116065 var batchBSnippet = 'rc.x';
116066 if (aShape[0] < bShape[0]) {
116067 batchASnippet = "imod(rc.x, ".concat(aShape[0], ")");
116068 } else if (bShape[0] < aShape[0]) {
116069 batchBSnippet = "imod(rc.x, ".concat(bShape[0], ")");
116070 }
116071 this.userCode = "\n ".concat(activationSnippet, "\n // Don't use uniform for sharedDimensionPacked for performance.\n const float sharedDimension = ").concat(sharedDimensionPacked, ".0;\n\n vec4 dot2x2ARowBCol(ivec3 rc) {\n vec4 result = vec4(0);\n int batchA = ").concat(batchASnippet, ";\n int batchB = ").concat(batchBSnippet, ";\n for (int i = 0; i < ").concat(sharedDimensionPacked, "; i++) {\n vec4 a = getMatrixA(batchA, ").concat(aSample, ");\n vec4 b = getMatrixB(batchB, ").concat(bSample, ");\n\n // These swizzled products need to be separately added.\n // See: https://github.com/tensorflow/tfjs/issues/1735\n result += (").concat(aSwizzle[0], " * ").concat(bSwizzle[0], ");\n result += (").concat(aSwizzle[1], " * ").concat(bSwizzle[1], ");\n }\n return result;\n }\n\n void main() {\n ivec3 rc = getOutputCoords();\n vec4 result = dot2x2ARowBCol(rc);\n\n ").concat(addBiasSnippet, "\n\n ").concat(applyActivationSnippet, "\n\n setOutput(result);\n }\n ");
116072 });
116073
116074 // (Ar + Ai)(Br + Bi) =
116075 // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr
116076 // Yr = ArBr - AB
116077 // Yi = ArBi + AiBr
116078 var COMPLEX_MULTIPLY = {
116079 REAL: 'return areal * breal - aimag * bimag;',
116080 IMAG: 'return areal * bimag + aimag * breal;'
116081 };
116082 var BinaryOpComplexProgram = /*#__PURE__*/_createClass(function BinaryOpComplexProgram(op, aShape, bShape) {
116083 _classCallCheck(this, BinaryOpComplexProgram);
116084 this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
116085 this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
116086 this.userCode = "\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n ".concat(op, "\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n ");
116087 });
116088
// GLSL body for real-valued elementwise multiplication.
var MUL = 'return a * b;';

/**
 * WebGL kernel for Multiply.
 *
 * - complex64 `a`: the split real/imaginary planes of both operands go
 *   through two BinaryOpComplexProgram passes (real and imaginary parts of
 *   the product). The two results are recombined with complex() and then
 *   disposed.
 * - The backend elects CPU execution: multiplyImplCPU runs on the cached
 *   values and its result is stored in a fresh tensor.
 * - Otherwise a binary-op shader runs. It is packed or unpacked depending on
 *   the WEBGL_PACK_BINARY_OPERATIONS flag.
 *
 * @param {{inputs: {a: Object, b: Object}, backend: Object}} args
 * @returns {Object} TensorInfo holding the product. Its dtype is
 *     upcastType(a.dtype, b.dtype); for complex inputs it is whatever
 *     complex() returns.
 */
function multiply(args) {
  var backend = args.backend;
  var a = args.inputs.a;
  var b = args.inputs.b;
  var dtype = upcastType(a.dtype, b.dtype);

  if (a.dtype === 'complex64') {
    var aTex = backend.texData.get(a.dataId);
    var bTex = backend.texData.get(b.dataId);
    var realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
    var imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
    // Describes one real or imaginary plane as a TensorInfo with the given shape.
    var plane = function (part, shape) {
      return {
        dataId: part.dataId,
        dtype: part.dtype,
        shape: shape
      };
    };
    // Order must match BinaryOpComplexProgram's variableNames.
    var planes = [
      plane(aTex.complexTensorInfos.real, a.shape),
      plane(aTex.complexTensorInfos.imag, a.shape),
      plane(bTex.complexTensorInfos.real, b.shape),
      plane(bTex.complexTensorInfos.imag, b.shape)
    ];
    var realPart = backend.runWebGLProgram(realProgram, planes, 'float32');
    var imagPart = backend.runWebGLProgram(imagProgram, planes, 'float32');
    var complexOutput = complex({
      inputs: {
        real: realPart,
        imag: imagPart
      },
      backend: backend
    });
    backend.disposeIntermediateTensorInfo(realPart);
    backend.disposeIntermediateTensorInfo(imagPart);
    // TODO(annxingyuan): CPU forwarding for complex inputs.
    return complexOutput;
  }

  if (backend.shouldExecuteOnCPU([a, b])) {
    var aTexCpu = backend.texData.get(a.dataId);
    var bTexCpu = backend.texData.get(b.dataId);
    var cpuResult = _slicedToArray(multiplyImplCPU(a.shape, b.shape, aTexCpu.values, bTexCpu.values, dtype), 2);
    var cpuOut = backend.makeTensorInfo(cpuResult[1], dtype);
    backend.texData.get(cpuOut.dataId).values = cpuResult[0];
    return cpuOut;
  }

  var Program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? BinaryOpPackedProgram : BinaryOpProgram;
  return backend.runWebGLProgram(new Program(MUL, a.shape, b.shape), [a, b], dtype);
}
var multiplyConfig = {
  kernelName: Multiply$1,
  backendName: 'webgl',
  kernelFunc: multiply
};
116157
/**
 * Reshapes packed data on the GPU by running ReshapePackedProgram over 3-D
 * [batch, rows, cols] views (built with getBatchDim/getRowsCols) of the input
 * shape and the target shape.
 *
 * @param {Object} input TensorInfo of the packed input.
 * @param {number[]} afterShape Requested output shape.
 * @param {Object} backend The WebGL backend.
 * @returns {{dataId: Object, shape: number[], dtype: string}} The program's
 *     output data, relabelled with `afterShape`.
 */
function packedReshape(input, afterShape, backend) {
  // Collapses a shape to [batch, rows, cols].
  var as3D = function (shape) {
    return [getBatchDim(shape)].concat(_toConsumableArray(getRowsCols(shape)));
  };
  var inShape3D = as3D(input.shape);
  var outShape3D = as3D(afterShape);
  var program = new ReshapePackedProgram(outShape3D, inShape3D);
  var output = backend.runWebGLProgram(
      program,
      [{
        dtype: input.dtype,
        shape: inShape3D,
        dataId: input.dataId
      }],
      input.dtype,
      // Custom uniform values: the 3-D input shape.
      [inShape3D],
      // preventEagerUnpackingOfOutput; per its name, the result is left packed.
      true);
  return {
    dataId: output.dataId,
    shape: afterShape,
    dtype: output.dtype
  };
}
116176
116177 /**
116178 * @license
116179 * Copyright 2020 Google LLC. All Rights Reserved.
116180 * Licensed under the Apache License, Version 2.0 (the "License");
116181 * you may not use this file except in compliance with the License.
116182 * You may obtain a copy of the License at
116183 *
116184 * http://www.apache.org/licenses/LICENSE-2.0
116185 *
116186 * Unless required by applicable law or agreed to in writing, software
116187 * distributed under the License is distributed on an "AS IS" BASIS,
116188 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116189 * See the License for the specific language governing permissions and
116190 * limitations under the License.
116191 * =============================================================================
116192 */
/**
 * WebGL kernel for Reshape.
 *
 * Resolves any implicit dimension in `attrs.shape` via inferFromImplicitShape
 * and asserts that the element count is unchanged. When the data can be
 * shared, it is reused with an incremented ref count. Packed data whose layout
 * is not reshape-free (for either the logical shape or the texture's shape) is
 * rewritten with packedReshape instead.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {shape: number[]}}} args
 * @returns {{dataId: Object, shape: number[], dtype: string}}
 */
function reshape(args) {
  var x = args.inputs.x;
  var webglBackend = args.backend;
  var oldSize = sizeFromShape(x.shape);
  var newShape = inferFromImplicitShape(args.attrs.shape, oldSize);
  var newSize = sizeFromShape(newShape);
  assert$1(oldSize === newSize, function () {
    return 'The new shape (' + newShape + ') has ' + newSize + ' elements and the old ' +
        'shape (' + x.shape + ') has ' + oldSize + ' elements. The new shape and old ' +
        'shape must have the same number of elements.';
  });

  var texInfo = webglBackend.texData.get(x.dataId);
  var canShareData = !texInfo.isPacked ||
      isReshapeFree(x.shape, newShape) ||
      (texInfo.texture !== null && isReshapeFree(texInfo.shape, newShape));
  if (!canShareData) {
    return packedReshape(x, newShape, webglBackend);
  }
  // Same underlying data, new logical shape: keep the data alive for the
  // returned TensorInfo as well.
  webglBackend.incRef(x.dataId);
  return {
    dataId: x.dataId,
    shape: newShape,
    dtype: x.dtype
  };
}
var reshapeConfig = {
  kernelName: Reshape$1,
  backendName: 'webgl',
  kernelFunc: reshape
};
116222
/**
 * WebGL program for one pass of a windowed row sum, optionally scaled by
 * 1 / divisor. Scaling turns the sum into a mean: reduce() passes the full row
 * length as `divisor` on the first pass only.
 *
 * Output shape is [batchSize, outSize]. Output (batch, outIdx) accumulates the
 * `windowSize` inputs that start at column outIdx * windowSize. It consumes
 * them four at a time with dot(values, ones), then handles the 1-3 leftover
 * columns.
 *
 * @param {{windowSize: number, batchSize: number, inSize: number,
 *     outSize: number}} reduceInfo
 * @param {number=} divisor When non-null, each value is multiplied by
 *     1 / divisor before it is summed.
 */
var MeanProgram = /*#__PURE__*/_createClass(function MeanProgram(reduceInfo, divisor) {
  _classCallCheck(this, MeanProgram);
  this.variableNames = ['x'];
  var windowSize = reduceInfo.windowSize,
    batchSize = reduceInfo.batchSize,
    inSize = reduceInfo.inSize,
    outSize = reduceInfo.outSize;
  this.outputShape = [batchSize, outSize];
  // Largest multiple of 4 within the window, plus the 0-3 columns left over.
  var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
  var windowSizeVec4Remainder = windowSize % 4;
  var updateSnippet = "sumValue += dot(values, ones);";
  if (divisor != null) {
    var denominator = 1 / divisor;
    // An integral denominator would print without a decimal point (e.g. "1").
    // toPrecision(2) yields "1.0", presumably so GLSL parses it as a float
    // literal.
    updateSnippet = "sumValue += dot(values * ".concat(isInt(denominator) ? denominator.toPrecision(2) : denominator, ", ones);");
  }
  var checkOutOfBounds = '';
  if (inSize % windowSize > 0) {
    // The last window runs past inSize; those reads contribute 0.0.
    checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return 0.0;\n }\n ");
  }
  // The `windowSizeVec4Remainder === k` booleans are spliced in as GLSL
  // `true`/`false`, so at most one remainder branch is live.
  this.userCode = "\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n ".concat(checkOutOfBounds, "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ").concat(windowSize, ";\n\n float sumValue = 0.0;\n\n for (int i = 0; i < ").concat(windowSizeNearestVec4, "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n ").concat(updateSnippet, "\n }\n\n int inIdx = inOffset + ").concat(windowSizeNearestVec4, ";\n if (").concat(windowSizeVec4Remainder === 1, ") {\n vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1), 0.0, 0.0);\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2), 0.0);\n\n ").concat(updateSnippet, "\n }\n setOutput(sumValue);\n }\n ");
});
116244
116245 /**
116246 * @license
116247 * Copyright 2017 Google LLC. All Rights Reserved.
116248 * Licensed under the Apache License, Version 2.0 (the "License");
116249 * you may not use this file except in compliance with the License.
116250 * You may obtain a copy of the License at
116251 *
116252 * http://www.apache.org/licenses/LICENSE-2.0
116253 *
116254 * Unless required by applicable law or agreed to in writing, software
116255 * distributed under the License is distributed on an "AS IS" BASIS,
116256 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116257 * See the License for the specific language governing permissions and
116258 * limitations under the License.
116259 * =============================================================================
116260 */
/**
 * WebGL program for one pass of a windowed row reduction.
 *
 * Output shape is [batchSize, outSize]. Output (batch, outIdx) reduces the
 * `windowSize` inputs that start at column outIdx * windowSize. It consumes
 * them four at a time, then handles the 1-3 leftover columns, padding the
 * vector with the initialization value.
 *
 * @param {{windowSize: number, batchSize: number, inSize: number,
 *     outSize: number}} reduceInfo
 * @param {string} reduceType The constructor special-cases 'sum', 'prod',
 *     'min', 'max', 'all' and 'any'.
 */
var ReduceProgram = /*#__PURE__*/_createClass(function ReduceProgram(reduceInfo, reduceType) {
  _classCallCheck(this, ReduceProgram);
  this.variableNames = ['x'];
  var windowSize = reduceInfo.windowSize,
    batchSize = reduceInfo.batchSize,
    inSize = reduceInfo.inSize,
    outSize = reduceInfo.outSize;
  this.outputShape = [batchSize, outSize];
  // Identity element of the reduction. It also pads partial windows and
  // out-of-bounds reads so they do not change the result.
  var initializationValue = '0.0';
  var compareOp = "";
  if (reduceType === 'prod') {
    initializationValue = '1.0';
  } else if (reduceType === 'min') {
    // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
    initializationValue = '1.0 / 1e-20';
    compareOp = "min";
  } else if (reduceType === 'max') {
    // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
    initializationValue = '-1.0 / 1e-20';
    compareOp = "max";
  }
  // For min/max, reduceType doubles as the GLSL builtin that folds the four
  // lanes of minMaxValue. The other types return their scalar accumulator.
  var returnValue = "".concat(reduceType, "(").concat(reduceType, "(").concat(reduceType, "(") + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
  if (reduceType === 'sum') {
    returnValue = "sumValue";
  } else if (reduceType === 'prod') {
    returnValue = "prodValue";
  } else if (reduceType === 'all') {
    returnValue = "allValue";
  } else if (reduceType === 'any') {
    returnValue = "anyValue";
  }
  var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
  var windowSizeVec4Remainder = windowSize % 4;
  // The reduceType comparisons are spliced in as GLSL `true`/`false`, so only
  // the matching branch is live. For min/max, any NaN lane forces the running
  // value to NaN.
  // NOTE(review): for min/max, compareOp(values, minMaxValue) is applied twice
  // (before and inside the inner `if`). The second call is redundant because
  // min/max are idempotent, but it is harmless.
  var updateSnippet = "\n if (".concat(reduceType === 'sum', ") {\n sumValue += dot(values, ones);\n } else if (").concat(reduceType === 'prod', ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n if (").concat(reduceType === 'min', " || ").concat(reduceType === 'max', ") {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n bvec4 isNaN = isnan(values);\n if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {\n minMaxValue = vec4(NAN);\n }\n }\n }\n ");
  var vecType = "vec4";
  // Boolean reductions read each group as a bvec4 and keep a 0.0/1.0 float
  // accumulator. The float padding (1.0 for 'all', 0.0 for 'any') is passed
  // to the bvec4 constructor.
  if (reduceType === 'all') {
    initializationValue = '1.0';
    updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ";
    vecType = "bvec4";
  } else if (reduceType === 'any') {
    initializationValue = '0.0';
    updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ";
    vecType = "bvec4";
  }
  var checkOutOfBounds = '';
  if (inSize % windowSize > 0) {
    // The last window runs past inSize; those reads yield the identity.
    checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return initializationValue;\n }\n ");
  }
  this.userCode = "\n const float initializationValue = ".concat(initializationValue, ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n ").concat(checkOutOfBounds, "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ").concat(windowSize, ";\n\n vec4 minMaxValue = vec4(").concat(initializationValue, ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < ").concat(windowSizeNearestVec4, "; i += 4) {\n int inIdx = inOffset + i;\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n ").concat(updateSnippet, "\n }\n\n int inIdx = inOffset + ").concat(windowSizeNearestVec4, ";\n if (").concat(windowSizeVec4Remainder === 1, ") {\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 2, ") {\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 3, ") {\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n ").concat(updateSnippet, "\n }\n setOutput(").concat(returnValue, ");\n }\n ");
});
116311
116312 /**
116313 * @license
116314 * Copyright 2020 Google LLC. All Rights Reserved.
116315 * Licensed under the Apache License, Version 2.0 (the "License");
116316 * you may not use this file except in compliance with the License.
116317 * You may obtain a copy of the License at
116318 *
116319 * http://www.apache.org/licenses/LICENSE-2.0
116320 *
116321 * Unless required by applicable law or agreed to in writing, software
116322 * distributed under the License is distributed on an "AS IS" BASIS,
116323 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116324 * See the License for the specific language governing permissions and
116325 * limitations under the License.
116326 * =============================================================================
116327 */
/**
 * Plans a multi-pass reduction over the second dimension of `inShape`.
 *
 * Each stage shrinks the current length by a window chosen with
 * computeOptimalWindowSize. Stages are added until a single value remains.
 * The first stage always exists, even when inShape[1] is already 1.
 *
 * @param {number[]} inShape 2-D shape; only inShape[1] is read.
 * @returns {Array<{inSize: number, windowSize: number, outSize: number}>}
 *     One entry per pass; each outSize is the next entry's inSize and the
 *     last outSize is 1.
 */
function getReductionStages(inShape) {
  var stages = [];
  var remaining = inShape[1];
  do {
    var windowSize = computeOptimalWindowSize(remaining);
    var nextSize = Math.ceil(remaining / windowSize);
    stages.push({
      inSize: remaining,
      windowSize: windowSize,
      outSize: nextSize
    });
    remaining = nextSize;
  } while (remaining !== 1);
  return stages;
}
/**
 * Reduces each row of a 2-D input with one shader pass per stage returned by
 * getReductionStages.
 *
 * Each pass consumes the previous pass's result, which is disposed
 * immediately afterwards. The caller's `x` is never disposed.
 *
 * @param {Object} x TensorInfo. x.shape[0] is the batch size; x.shape is
 *     handed to getReductionStages, which reads x.shape[1].
 * @param {string} dtype Output dtype for every pass.
 * @param {string} reductionType 'mean' uses MeanProgram, which divides by the
 *     first stage's inSize on the first pass only. Any other value is
 *     forwarded to ReduceProgram.
 * @param {Object} backend The WebGL backend.
 * @returns {Object} TensorInfo produced by the final pass.
 */
function reduce(x, dtype, reductionType, backend) {
  var batchSize = x.shape[0];
  var current = x;
  getReductionStages(x.shape).forEach(function (stage, stageIndex) {
    var reduceInfo = {
      windowSize: stage.windowSize,
      inSize: stage.inSize,
      batchSize: batchSize,
      outSize: stage.outSize
    };
    var program;
    if (reductionType !== 'mean') {
      program = new ReduceProgram(reduceInfo, reductionType);
    } else if (stageIndex === 0) {
      // Scale by 1 / rowLength once; later passes only add partial means.
      program = new MeanProgram(reduceInfo, stage.inSize);
    } else {
      program = new MeanProgram(reduceInfo);
    }
    var consumed = current;
    current = backend.runWebGLProgram(program, [consumed], dtype);
    if (consumed.dataId !== x.dataId) {
      backend.disposeIntermediateTensorInfo(consumed);
    }
  });
  return current;
}
116381
/**
 * Unpacked WebGL program that permutes the axes of input `A`.
 *
 * Output axis i takes its extent from input axis newDim[i]. The shader reads
 * A at the coordinate produced by getSwitchedCoords(newDim).
 *
 * @param {number[]} aShape Shape of A.
 * @param {number[]} newDim Axis permutation.
 */
var TransposeProgram = /*#__PURE__*/_createClass(function TransposeProgram(aShape, newDim) {
  _classCallCheck(this, TransposeProgram);
  this.variableNames = ['A'];
  var permutedShape = [];
  for (var axis = 0; axis < aShape.length; axis++) {
    permutedShape.push(aShape[newDim[axis]]);
  }
  this.outputShape = permutedShape;
  this.rank = permutedShape.length;
  var coordsType = getCoordsDataType(this.rank);
  var inputCoords = getSwitchedCoords(newDim);
  this.userCode = "\n void main() {\n " + coordsType + " resRC = getOutputCoords();\n setOutput(getA(" + inputCoords + "));\n }\n ";
});
/**
 * Builds the comma-separated GLSL argument list that reads input `A` for the
 * current output coordinate `resRC` of a transpose.
 *
 * Output axis i comes from input axis newDim[i], so position newDim[i] of the
 * input coordinate is filled with resRC's i-th component.
 *
 * @param {number[]} newDim Axis permutation of length <= 6.
 * @returns {string} For example 'resRC.y,resRC.x' for newDim = [1, 0].
 * @throws {Error} If the rank exceeds 6.
 */
function getSwitchedCoords(newDim) {
  var rank = newDim.length;
  if (rank > 6) {
    throw Error('Transpose for rank ' + rank + ' is not yet supported');
  }
  var components = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
  var inputCoords = new Array(rank);
  var outAxis = 0;
  while (outAxis < rank) {
    inputCoords[newDim[outAxis]] = components[outAxis];
    outAxis++;
  }
  return inputCoords.join(',');
}
116407
/**
 * Packed WebGL program that permutes the axes of input `A`.
 *
 * Output axis i takes its extent from input axis newDim[i]. Each output texel
 * covers a 2x2 block of the last two output dims:
 *   result[0] = (row, col),   result[1] = (row, col + 1),
 *   result[2] = (row + 1, col), result[3] = (row + 1, col + 1).
 * Each lane except the first is written only if its coordinate is inside
 * outputShape.
 *
 * @param {number[]} aShape Shape of A.
 * @param {number[]} newDim Axis permutation (rank <= 6).
 * @throws {Error} If the output rank exceeds 6.
 */
var TransposePackedProgram = /*#__PURE__*/_createClass(function TransposePackedProgram(aShape, newDim) {
  _classCallCheck(this, TransposePackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  var outputShape = new Array(aShape.length);
  for (var i = 0; i < outputShape.length; i++) {
    outputShape[i] = aShape[newDim[i]];
  }
  this.outputShape = outputShape;
  this.rank = outputShape.length;
  if (this.rank > 6) {
    throw Error("Packed transpose for rank ".concat(this.rank, " is not yet supported."));
  }
  var dtype = getCoordsDataType(this.rank);
  // Names of the components of the output coordinate `rc` (one per output
  // axis), as produced by getVecChannels.
  var outputOrder = getVecChannels('rc', this.rank);
  // Input coordinate position newDim[i] is filled with output component i.
  var switchedOrder = new Array(this.rank);
  for (var _i = 0; _i < newDim.length; _i++) {
    switchedOrder[newDim[_i]] = outputOrder[_i];
  }
  // Last two input coordinates; presumably used by getChannel to pick the
  // lane of the packed input texel.
  var innerDims = "vec2(".concat(switchedOrder.slice(-2).join(), ")");
  // Advances the last output coordinate in place and bounds-checks it.
  var nextColumn = "++".concat(outputOrder[this.rank - 1], " < ").concat(outputShape[this.rank - 1]);
  var getc = "getChannel(getA(".concat(switchedOrder.join(), "), ").concat(innerDims, ")");
  // NOTE(review): indexes outputOrder[this.rank - 2], so this appears to
  // assume rank >= 2.
  this.userCode = "\n void main() {\n ".concat(dtype, " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result[0] = ").concat(getc, ";\n if(").concat(nextColumn, ") {\n result[1] = ").concat(getc, ";\n }\n --").concat(outputOrder[this.rank - 1], ";\n if(++").concat(outputOrder[this.rank - 2], " < ").concat(outputShape[this.rank - 2], ") {\n result[2] = ").concat(getc, ";\n if(").concat(nextColumn, ") {\n result[3] = ").concat(getc, ";\n }\n }\n setOutput(result);\n }\n ");
});
116433
116434 /**
116435 * @license
116436 * Copyright 2020 Google LLC. All Rights Reserved.
116437 * Licensed under the Apache License, Version 2.0 (the "License");
116438 * you may not use this file except in compliance with the License.
116439 * You may obtain a copy of the License at
116440 *
116441 * http://www.apache.org/licenses/LICENSE-2.0
116442 *
116443 * Unless required by applicable law or agreed to in writing, software
116444 * distributed under the License is distributed on an "AS IS" BASIS,
116445 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116446 * See the License for the specific language governing permissions and
116447 * limitations under the License.
116448 * =============================================================================
116449 */
/**
 * Transposes `x` by `perm` on the GPU.
 *
 * Uses TransposePackedProgram when the WEBGL_PACK_ARRAY_OPERATIONS flag is
 * set and TransposeProgram otherwise.
 *
 * @param {Object} x Input TensorInfo.
 * @param {number[]} perm Axis permutation.
 * @param {Object} backend The WebGL backend.
 * @returns {Object} TensorInfo with the same dtype as `x`.
 */
function transposeImpl(x, perm, backend) {
  var usePacked = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS');
  var Program = usePacked ? TransposePackedProgram : TransposeProgram;
  return backend.runWebGLProgram(new Program(x.shape, perm), [x], x.dtype);
}
116454
/**
 * Sums `x` over `axis` on the GPU.
 *
 * Steps:
 * 1. If the requested axes are not already innermost (getAxesPermutation
 *    returns non-null), transpose `x` so they are.
 * 2. View the data as a 2-D [batchSize, inSize] matrix and reduce each row
 *    with reduce(..., 'sum'), using dtype sumOutType(x.dtype).
 * 3. Reshape to the output shape. With keepDims, expandShapeToKeepDim
 *    re-inserts the reduced axes.
 *
 * All temporaries (2-D view, reduced rows, transposed copy) are disposed
 * before returning.
 *
 * @param {Object} x Input TensorInfo.
 * @param {number|number[]} axis Axis or axes to reduce (parsed by
 *     parseAxisParam).
 * @param {boolean} keepDims Whether reduced axes are kept.
 * @param {Object} backend The WebGL backend.
 * @returns {Object} TensorInfo holding the sum.
 */
function sumImpl(x, axis, keepDims, backend) {
  var rank = x.shape.length;
  var requestedAxes = parseAxisParam(axis, x.shape);
  var permutation = getAxesPermutation(requestedAxes, rank);
  var needsTranspose = permutation != null;

  var reduceAxes = requestedAxes;
  var source = x;
  if (needsTranspose) {
    source = transposeImpl(x, permutation, backend);
    reduceAxes = getInnerMostAxes(reduceAxes.length, rank);
  }
  assertAxesAreInnerMostDims('sum', reduceAxes, rank);

  var shapes = _slicedToArray(computeOutAndReduceShapes(source.shape, reduceAxes), 2);
  var reducedShape = shapes[0];
  // Rather than reshape at the end, pick the final target shape here.
  var finalShape = keepDims ? expandShapeToKeepDim(reducedShape, requestedAxes) : reducedShape;
  var inSize = sizeFromShape(shapes[1]);
  var batchSize = sizeFromShape(x.shape) / inSize;

  var asMatrix = reshape({
    inputs: { x: source },
    attrs: { shape: [batchSize, inSize] },
    backend: backend
  });
  var reduced = reduce(asMatrix, sumOutType(x.dtype), 'sum', backend);
  var out = reshape({
    inputs: { x: reduced },
    attrs: { shape: finalShape },
    backend: backend
  });

  backend.disposeIntermediateTensorInfo(asMatrix);
  backend.disposeIntermediateTensorInfo(reduced);
  if (needsTranspose) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return out;
}
116507
116508 /**
116509 * @license
116510 * Copyright 2020 Google LLC. All Rights Reserved.
116511 * Licensed under the Apache License, Version 2.0 (the "License");
116512 * you may not use this file except in compliance with the License.
116513 * You may obtain a copy of the License at
116514 *
116515 * http://www.apache.org/licenses/LICENSE-2.0
116516 *
116517 * Unless required by applicable law or agreed to in writing, software
116518 * distributed under the License is distributed on an "AS IS" BASIS,
116519 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116520 * See the License for the specific language governing permissions and
116521 * limitations under the License.
116522 * =============================================================================
116523 */
/**
 * WebGL kernel for Sum. Forwards `inputs.x` and the `axis`/`keepDims`
 * attributes to sumImpl.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {axis: (number|number[]), keepDims: boolean}}} args
 * @returns {Object} TensorInfo returned by sumImpl.
 */
function sum(args) {
  var attrs = args.attrs;
  return sumImpl(args.inputs.x, attrs.axis, attrs.keepDims, args.backend);
}
var sumConfig = {
  kernelName: Sum,
  backendName: 'webgl',
  kernelFunc: sum
};
116538
116539 /**
116540 * @license
116541 * Copyright 2020 Google LLC. All Rights Reserved.
116542 * Licensed under the Apache License, Version 2.0 (the "License");
116543 * you may not use this file except in compliance with the License.
116544 * You may obtain a copy of the License at
116545 *
116546 * http://www.apache.org/licenses/LICENSE-2.0
116547 *
116548 * Unless required by applicable law or agreed to in writing, software
116549 * distributed under the License is distributed on an "AS IS" BASIS,
116550 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116551 * See the License for the specific language governing permissions and
116552 * limitations under the License.
116553 * =============================================================================
116554 */
/**
 * WebGL kernel for Transpose.
 *
 * Output axis i has the extent of input axis perm[i]. When the backend elects
 * CPU execution, transposeImplCPU permutes the cached values into a fresh
 * tensor. Otherwise transposeImpl runs the permutation on the GPU.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {perm: number[]}}} args
 * @returns {Object} TensorInfo of the transposed data.
 */
function transpose(args) {
  var webglBackend = args.backend;
  var x = args.inputs.x;
  var perm = args.attrs.perm;
  var newShape = [];
  for (var axis = 0; axis < x.shape.length; axis++) {
    newShape.push(x.shape[perm[axis]]);
  }
  if (!webglBackend.shouldExecuteOnCPU([x])) {
    return transposeImpl(x, perm, webglBackend);
  }
  var inputValues = webglBackend.texData.get(x.dataId).values;
  var permutedValues = transposeImplCPU(inputValues, x.shape, x.dtype, perm, newShape);
  var out = webglBackend.makeTensorInfo(newShape, x.dtype);
  webglBackend.texData.get(out.dataId).values = permutedValues;
  return out;
}
var transposeConfig = {
  kernelName: Transpose,
  backendName: 'webgl',
  kernelFunc: transpose
};
116585
116586 /**
116587 * @license
116588 * Copyright 2020 Google LLC. All Rights Reserved.
116589 * Licensed under the Apache License, Version 2.0 (the "License");
116590 * you may not use this file except in compliance with the License.
116591 * You may obtain a copy of the License at
116592 *
116593 * http://www.apache.org/licenses/LICENSE-2.0
116594 *
116595 * Unless required by applicable law or agreed to in writing, software
116596 * distributed under the License is distributed on an "AS IS" BASIS,
116597 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116598 * See the License for the specific language governing permissions and
116599 * limitations under the License.
116600 * =============================================================================
116601 */
// Empirically determined minimal shared dimension in matmul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
var MATMUL_SHARED_DIM_THRESHOLD = 1000;

/**
 * Batched matrix multiplication on the WebGL backend, optionally fused with a
 * bias add and an activation.
 *
 * Both operands are flattened to rank 3 ([batch, rows, cols], with all leading
 * dims folded into `batch`). The vector path (multiply() followed by sum()
 * over the shared axis) is used when all of these hold:
 * - one side has an outer dimension of 1,
 * - the shared dimension exceeds MATMUL_SHARED_DIM_THRESHOLD,
 * - no fused op is requested.
 * Otherwise a single MatMulPackedProgram pass runs.
 *
 * @param {Object} _ref
 * @param {Object} _ref.a Left operand; its last two dims are the matrix dims.
 * @param {Object} _ref.b Right operand; its last two dims are the matrix dims.
 * @param {boolean} _ref.transposeA Treat `a`'s last two dims as swapped.
 * @param {boolean} _ref.transposeB Treat `b`'s last two dims as swapped.
 * @param {Object} _ref.backend The WebGL backend.
 * @param {Object=} _ref.bias Optional bias input (default null).
 * @param {Object=} _ref.preluActivationWeights Optional PReLU weights
 *     (default null).
 * @param {number=} _ref.leakyreluAlpha Alpha used when activation is
 *     'leakyrelu' (default 0).
 * @param {string=} _ref.activation Optional fused activation name
 *     (default null).
 * @returns {Object} TensorInfo of shape
 *     [...broadcast(batchDims(a), batchDims(b)), outerA, outerB].
 * @throws If the inner dimensions of `a` and `b` differ (assert$1), or if
 *     assertAndGetBroadcastShape rejects the batch dims.
 */
function batchMatMulImpl(_ref) {
  var a = _ref.a,
    b = _ref.b,
    transposeA = _ref.transposeA,
    transposeB = _ref.transposeB,
    backend = _ref.backend,
    _ref$bias = _ref.bias,
    bias = _ref$bias === void 0 ? null : _ref$bias,
    _ref$preluActivationW = _ref.preluActivationWeights,
    preluActivationWeights = _ref$preluActivationW === void 0 ? null : _ref$preluActivationW,
    _ref$leakyreluAlpha = _ref.leakyreluAlpha,
    leakyreluAlpha = _ref$leakyreluAlpha === void 0 ? 0 : _ref$leakyreluAlpha,
    _ref$activation = _ref.activation,
    activation = _ref$activation === void 0 ? null : _ref$activation;
  var aRank = a.shape.length;
  var bRank = b.shape.length;
  var innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
  var innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
  var outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
  var outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
  var outerDimsA = a.shape.slice(0, -2);
  var outerDimsB = b.shape.slice(0, -2);
  var batchDimA = sizeFromShape(outerDimsA);
  var batchDimB = sizeFromShape(outerDimsB);
  var outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
  var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
  assert$1(innerShapeA === innerShapeB, function () {
    return "Error in matMul: inner shapes (".concat(innerShapeA, ") and (") + "".concat(innerShapeB, ") of Tensors with shapes ").concat(a.shape, " and ") + "".concat(b.shape, " and transposeA=").concat(transposeA) + " and transposeB=".concat(transposeB, " must match.");
  });
  var a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA];
  var b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB];
  // The rest of the implementation is designed to operate on rank-3 tensors
  var a3d = reshape({
    inputs: {
      x: a
    },
    backend: backend,
    attrs: {
      shape: a3dShape
    }
  });
  var b3d = reshape({
    inputs: {
      x: b
    },
    backend: backend,
    attrs: {
      shape: b3dShape
    }
  });
  var intermediates = [a3d, b3d];
  var batchDim = Math.max(batchDimA, batchDimB);
  var sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
  var hasBias = bias != null;
  var hasPreluActivationWeights = preluActivationWeights != null;
  var hasLeakyreluAlpha = activation === 'leakyrelu';
  var fusedActivation = activation != null ? mapActivationToShaderProgram(activation, true) : null;
  var containsFusedOps = hasBias || hasPreluActivationWeights || hasLeakyreluAlpha || fusedActivation != null;
  var out;
  // Since the matrices are vectors, it is faster to call mul().sum()
  // because sum() is O(sqrt(N)) due to divide-and-conquer.
  if ((outerShapeA === 1 || outerShapeB === 1) && sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
    // After the optional transposes, aVec is [batchDimA, outerShapeA,
    // sharedDim] and bVec is [batchDimB, sharedDim, outerShapeB].
    var aVec = a3d;
    var bVec = b3d;
    if (transposeA) {
      aVec = transpose({
        inputs: {
          x: a3d
        },
        backend: backend,
        attrs: {
          perm: [0, 2, 1]
        }
      });
      intermediates.push(aVec);
    }
    if (transposeB) {
      bVec = transpose({
        inputs: {
          x: b3d
        },
        backend: backend,
        attrs: {
          perm: [0, 2, 1]
        }
      });
      intermediates.push(bVec);
    }
    // outerShapeB === 1: b becomes a row vector [batchDimB, 1, sharedDim] and
    // the product with a is summed over axis 2.
    // Otherwise outerShapeA === 1: a becomes a column [batchDimA, sharedDim, 1]
    // and the product with b is summed over axis 1.
    //
    // Each operand is reshaped with its OWN batch size. Using the broadcast
    // batchDim here made the reshape fail its size check whenever one operand
    // had batch 1 and the other a larger batch (e.g. a: [3, 1, 2000] with
    // b: [2000, 1]). multiply() broadcasts the size-1 batch instead, matching
    // the batch broadcasting the MatMulPackedProgram path supports.
    var shouldReshapeA = outerShapeB !== 1;
    var shouldReshapeB = outerShapeB === 1;
    var aVec3d = aVec;
    if (shouldReshapeA) {
      aVec3d = reshape({
        inputs: {
          x: aVec
        },
        backend: backend,
        attrs: {
          shape: [batchDimA, sharedDim, 1]
        }
      });
      intermediates.push(aVec3d);
    }
    var axis = outerShapeB === 1 ? 2 : 1;
    var bVec3d = bVec;
    if (shouldReshapeB) {
      bVec3d = reshape({
        inputs: {
          x: bVec
        },
        backend: backend,
        attrs: {
          shape: [batchDimB, 1, sharedDim]
        }
      });
      intermediates.push(bVec3d);
    }
    var product = multiply({
      inputs: {
        a: aVec3d,
        b: bVec3d
      },
      backend: backend
    });
    // keepDims leaves the result as [batchDim, outerShapeA, outerShapeB].
    out = sum({
      inputs: {
        x: product
      },
      backend: backend,
      attrs: {
        axis: axis,
        keepDims: true
      }
    });
    intermediates.push(product);
  } else {
    var dtype = upcastType(a.dtype, b.dtype);
    var program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    // Extra inputs are appended in the order: bias, PReLU weights, leaky-ReLU
    // alpha (as a scalar tensor).
    var inputs = [a3d, b3d];
    if (bias != null) {
      inputs.push(bias);
    }
    if (hasPreluActivationWeights) {
      inputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
      var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
      inputs.push($leakyreluAlpha);
      intermediates.push($leakyreluAlpha);
    }
    out = backend.runWebGLProgram(program, inputs, dtype);
  }
  var outReshaped = reshape({
    inputs: {
      x: out
    },
    backend: backend,
    attrs: {
      shape: outShape
    }
  });
  intermediates.push(out);
  for (var _i = 0, _intermediates = intermediates; _i < _intermediates.length; _i++) {
    var i = _intermediates[_i];
    backend.disposeIntermediateTensorInfo(i);
  }
  return outReshaped;
}
116774
116775 /**
116776 * @license
116777 * Copyright 2020 Google LLC. All Rights Reserved.
116778 * Licensed under the Apache License, Version 2.0 (the License);
116779 * you may not use this file except in compliance with the License.
116780 * You may obtain a copy of the License at
116781 *
116782 * http://www.apache.org/licenses/LICENSE-2.0
116783 *
116784 * Unless required by applicable law or agreed to in writing, software
116785 * distributed under the License is distributed on an AS IS BASIS,
116786 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116787 * See the License for the specific language governing permissions and
116788 * limitations under the License.
116789 * =============================================================================
116790 */
/**
 * WebGL kernel for _FusedMatMul.
 *
 * Forwards the operands, the optional bias and PReLU weights, and the
 * transpose/activation attributes to batchMatMulImpl, which applies its own
 * defaults to any that are missing.
 *
 * @param {{inputs: {a: Object, b: Object, bias: (Object|undefined),
 *     preluActivationWeights: (Object|undefined)}, backend: Object,
 *     attrs: Object}} args
 * @returns {Object} TensorInfo returned by batchMatMulImpl.
 */
function _fusedMatMul(args) {
  var kernelInputs = args.inputs;
  var kernelAttrs = args.attrs;
  return batchMatMulImpl({
    a: kernelInputs.a,
    b: kernelInputs.b,
    transposeA: kernelAttrs.transposeA,
    transposeB: kernelAttrs.transposeB,
    backend: args.backend,
    bias: kernelInputs.bias,
    preluActivationWeights: kernelInputs.preluActivationWeights,
    leakyreluAlpha: kernelAttrs.leakyreluAlpha,
    activation: kernelAttrs.activation
  });
}
var _fusedMatMulConfig = {
  kernelName: _FusedMatMul,
  backendName: 'webgl',
  kernelFunc: _fusedMatMul
};
116820
116821 /**
116822 * @license
116823 * Copyright 2020 Google LLC. All Rights Reserved.
116824 * Licensed under the Apache License, Version 2.0 (the "License");
116825 * you may not use this file except in compliance with the License.
116826 * You may obtain a copy of the License at
116827 *
116828 * http://www.apache.org/licenses/LICENSE-2.0
116829 *
116830 * Unless required by applicable law or agreed to in writing, software
116831 * distributed under the License is distributed on an "AS IS" BASIS,
116832 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116833 * See the License for the specific language governing permissions and
116834 * limitations under the License.
116835 * =============================================================================
116836 */
var ABS = "return abs(x);";
/**
 * WebGL kernel for Abs.
 *
 * If the backend decides the input should run on the CPU and it is not
 * complex64, the CPU implementation is applied to the stored values.
 * Otherwise the ABS snippet runs in a unary shader, packed or unpacked
 * depending on the WEBGL_PACK_UNARY_OPERATIONS flag.
 *
 * @param args.inputs {x}
 * @param args.backend WebGL backend.
 * @returns Tensor info with the same shape and dtype as x.
 */
function abs(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  // TODO: handle cases when x is complex. Once the cpu implementation
  // can handle complex values, refactor to use unaryKernelFunc.
  var runOnCpu = backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64';
  if (runOnCpu) {
    var absValues = simpleAbsImplCPU(backend.texData.get(x.dataId).values);
    return backend.makeTensorInfo(x.shape, x.dtype, absValues);
  }
  var ProgramClass = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') ? UnaryOpPackedProgram : UnaryOpProgram;
  return backend.runWebGLProgram(new ProgramClass(x.shape, ABS), [x], x.dtype);
}
var absConfig = { kernelName: Abs, backendName: 'webgl', kernelFunc: abs };
116862
116863 /**
116864 * @license
116865 * Copyright 2020 Google LLC. All Rights Reserved.
116866 * Licensed under the Apache License, Version 2.0 (the "License");
116867 * you may not use this file except in compliance with the License.
116868 * You may obtain a copy of the License at
116869 *
116870 * http://www.apache.org/licenses/LICENSE-2.0
116871 *
116872 * Unless required by applicable law or agreed to in writing, software
116873 * distributed under the License is distributed on an "AS IS" BASIS,
116874 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116875 * See the License for the specific language governing permissions and
116876 * limitations under the License.
116877 * =============================================================================
116878 */
// Acos op snippet: CHECK_NAN_SNIPPET$1 is prepended (presumably NaN
// pass-through — confirm against its definition); inputs with |x| > 1 lie
// outside acos's domain and produce NAN.
var ACOS = CHECK_NAN_SNIPPET$1 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return acos(x);\n";
var acos = unaryKernelFunc({ opSnippet: ACOS });
var acosConfig = { kernelName: Acos, backendName: 'webgl', kernelFunc: acos };
116888
116889 /**
116890 * @license
116891 * Copyright 2020 Google LLC. All Rights Reserved.
116892 * Licensed under the Apache License, Version 2.0 (the "License");
116893 * you may not use this file except in compliance with the License.
116894 * You may obtain a copy of the License at
116895 *
116896 * http://www.apache.org/licenses/LICENSE-2.0
116897 *
116898 * Unless required by applicable law or agreed to in writing, software
116899 * distributed under the License is distributed on an "AS IS" BASIS,
116900 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116901 * See the License for the specific language governing permissions and
116902 * limitations under the License.
116903 * =============================================================================
116904 */
// Acosh op snippet: after the NaN check snippet, inputs below 1 (outside
// acosh's domain) give NAN; otherwise acosh(x) = log(x + sqrt(x^2 - 1)).
var ACOSH = CHECK_NAN_SNIPPET$1 + "\n if (x < 1.0) return NAN;\nreturn log(x + sqrt(x * x - 1.0));";
var acosh = unaryKernelFunc({ opSnippet: ACOSH });
var acoshConfig = { kernelName: Acosh, backendName: 'webgl', kernelFunc: acosh };
116914
116915 /**
116916 * @license
116917 * Copyright 2020 Google LLC. All Rights Reserved.
116918 * Licensed under the Apache License, Version 2.0 (the "License");
116919 * you may not use this file except in compliance with the License.
116920 * You may obtain a copy of the License at
116921 *
116922 * http://www.apache.org/licenses/LICENSE-2.0
116923 *
116924 * Unless required by applicable law or agreed to in writing, software
116925 * distributed under the License is distributed on an "AS IS" BASIS,
116926 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116927 * See the License for the specific language governing permissions and
116928 * limitations under the License.
116929 * =============================================================================
116930 */
// Element-wise addition. The same snippet serves the packed and unpacked
// shaders; complex inputs are supported and `addImplCPU` is handed to
// binaryKernelFunc as the CPU implementation.
var ADD = 'return a + b;';
var addKernelFunc = binaryKernelFunc({ opSnippet: ADD, packedOpSnippet: ADD, supportsComplex: true, cpuKernelImpl: addImplCPU });
var addConfig = { kernelName: Add$1, backendName: 'webgl', kernelFunc: addKernelFunc };
116943
116944 /**
116945 * @license
116946 * Copyright 2019 Google LLC. All Rights Reserved.
116947 * Licensed under the Apache License, Version 2.0 (the "License");
116948 * you may not use this file except in compliance with the License.
116949 * You may obtain a copy of the License at
116950 *
116951 * http://www.apache.org/licenses/LICENSE-2.0
116952 *
116953 * Unless required by applicable law or agreed to in writing, software
116954 * distributed under the License is distributed on an "AS IS" BASIS,
116955 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116956 * See the License for the specific language governing permissions and
116957 * limitations under the License.
116958 * =============================================================================
116959 */
/**
 * Unpacked WebGL program summing N inputs element-wise.
 *
 * Input i is bound as sampler "T<i>". The generated shader reads every input
 * at the output coordinate and writes the sum.
 *
 * @param outputShape Shape of the result.
 * @param shapes Input shapes; only their count is used here.
 */
var AddNProgram = /*#__PURE__*/_createClass(function AddNProgram(outputShape, shapes) {
  _classCallCheck(this, AddNProgram);
  this.outputShape = outputShape;
  var samplerNames = [];
  var fetchLines = [];
  var summands = [];
  for (var idx = 0; idx < shapes.length; idx++) {
    var sampler = 'T' + idx;
    samplerNames.push(sampler);
    fetchLines.push('float v' + sampler + ' = get' + sampler + 'AtOutCoords();');
    summands.push('v' + sampler);
  }
  this.variableNames = samplerNames;
  this.userCode = '\n void main() {\n ' + fetchLines.join('\n ') +
    '\n\n float result = ' + summands.join(' + ') + ';\n setOutput(result);\n }\n ';
});
116978
116979 /**
116980 * @license
116981 * Copyright 2019 Google LLC. All Rights Reserved.
116982 * Licensed under the Apache License, Version 2.0 (the "License");
116983 * you may not use this file except in compliance with the License.
116984 * You may obtain a copy of the License at
116985 *
116986 * http://www.apache.org/licenses/LICENSE-2.0
116987 *
116988 * Unless required by applicable law or agreed to in writing, software
116989 * distributed under the License is distributed on an "AS IS" BASIS,
116990 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116991 * See the License for the specific language governing permissions and
116992 * limitations under the License.
116993 * =============================================================================
116994 */
/**
 * Packed WebGL program summing N inputs element-wise, four values (one vec4
 * texel) at a time.
 *
 * Input i is bound as sampler "T<i>". Both inputs and output use the packed
 * layout.
 *
 * @param outputShape Shape of the result.
 * @param shapes Input shapes; only their count is used here.
 */
var AddNPackedProgram = /*#__PURE__*/_createClass(function AddNPackedProgram(outputShape, shapes) {
  _classCallCheck(this, AddNPackedProgram);
  this.outputShape = outputShape;
  this.packedInputs = true;
  this.packedOutput = true;
  var samplerNames = [];
  var fetchLines = [];
  var summands = [];
  for (var idx = 0; idx < shapes.length; idx++) {
    var sampler = 'T' + idx;
    samplerNames.push(sampler);
    fetchLines.push('vec4 v' + sampler + ' = get' + sampler + 'AtOutCoords();');
    summands.push('v' + sampler);
  }
  this.variableNames = samplerNames;
  this.userCode = '\n void main() {\n ' + fetchLines.join('\n ') +
    '\n\n vec4 result = ' + summands.join(' + ') + ';\n setOutput(result);\n }\n ';
});
117015
117016 /**
117017 * @license
117018 * Copyright 2020 Google LLC. All Rights Reserved.
117019 * Licensed under the Apache License, Version 2.0 (the "License");
117020 * you may not use this file except in compliance with the License.
117021 * You may obtain a copy of the License at
117022 *
117023 * http://www.apache.org/licenses/LICENSE-2.0
117024 *
117025 * Unless required by applicable law or agreed to in writing, software
117026 * distributed under the License is distributed on an "AS IS" BASIS,
117027 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117028 * See the License for the specific language governing permissions and
117029 * limitations under the License.
117030 * =============================================================================
117031 */
/**
 * WebGL kernel for AddN: element-wise sum of a list of tensors. The shapes
 * are expected to be identical; the op level checks this.
 *
 * A single input is forwarded through `identity`. If there are more inputs
 * than WEBGL_MAX_TEXTURES_IN_SHADER, the list is split in half, each half is
 * summed recursively, and the two partial sums are added. The partial sums
 * are disposed afterwards. All other inputs go through one AddN program
 * (packed when WEBGL_PACK is set), with the output dtype upcast across all
 * inputs.
 *
 * @param args.inputs Array of tensor infos to sum.
 * @param args.backend WebGL backend.
 * @returns Tensor info holding the sum.
 */
function addN(args) {
  var inputs = args.inputs,
    backend = args.backend;
  var tensors = inputs;
  if (tensors.length === 1) {
    return identity({
      inputs: {
        x: tensors[0]
      },
      backend: backend
    });
  }
  // Limit the number of uploaded textures for optimization.
  if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
    var midIndex = Math.floor(tensors.length / 2);
    var leftSide = addN({
      inputs: tensors.slice(0, midIndex),
      backend: backend
    });
    var rightSide = addN({
      inputs: tensors.slice(midIndex),
      backend: backend
    });
    var result = addN({
      inputs: [leftSide, rightSide],
      backend: backend
    });
    // The partial sums are private to this kernel. Release them so their
    // textures (or the ref taken by `identity`) do not leak.
    backend.disposeIntermediateTensorInfo(leftSide);
    backend.disposeIntermediateTensorInfo(rightSide);
    return result;
  }
  var dtype = tensors.map(function (t) {
    return t.dtype;
  }).reduce(function (d1, d2) {
    return upcastType(d1, d2);
  });
  var shapes = tensors.map(function (t) {
    return t.shape;
  });
  // We can make sure shapes are identical in op level.
  var usePackedOp = env().getBool('WEBGL_PACK');
  var program = usePackedOp ? new AddNPackedProgram(tensors[0].shape, shapes) : new AddNProgram(tensors[0].shape, shapes);
  return backend.runWebGLProgram(program, tensors, dtype);
}
var addNConfig = {
  kernelName: AddN,
  backendName: 'webgl',
  kernelFunc: addN
};
117078
/**
 * WebGL kernel for the All op: applies the 'all' reduction over `attrs.axis`.
 *
 * If the requested axes are not already innermost, the input is first
 * transposed so they are. It is then viewed as 2-D [-1, reduceSize], reduced
 * with the 'all' reduce program, and reshaped to the output shape. When
 * `keepDims` is set, the reduced axes are kept as size-1 dims. The
 * temporaries created here are disposed before returning.
 *
 * @param args.inputs {x}
 * @param args.backend WebGL backend.
 * @param args.attrs {axis, keepDims}
 * @returns The reduced tensor info.
 */
function all(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var keepDims = args.attrs.keepDims;
  var rank = x.shape.length;
  var requestedAxes = parseAxisParam(args.attrs.axis, x.shape);
  var perm = getAxesPermutation(requestedAxes, rank);
  var transposed = perm != null;
  var source = x;
  var reduceAxes = requestedAxes;
  if (transposed) {
    source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
    reduceAxes = getInnerMostAxes(requestedAxes.length, rank);
  }
  assertAxesAreInnerMostDims('all', reduceAxes, rank);
  var shapes = computeOutAndReduceShapes(source.shape, reduceAxes);
  var outShape = shapes[0];
  var flat = reshape({ inputs: { x: source }, backend: backend, attrs: { shape: [-1, sizeFromShape(shapes[1])] } });
  var reduced = reduce(flat, flat.dtype, 'all', backend);
  var finalShape = keepDims ? expandShapeToKeepDim(outShape, requestedAxes) : outShape;
  var res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: finalShape } });
  backend.disposeIntermediateTensorInfo(flat);
  backend.disposeIntermediateTensorInfo(reduced);
  if (transposed) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return res;
}
var allConfig = { kernelName: All, backendName: 'webgl', kernelFunc: all };
117154
/**
 * WebGL kernel for the Any op: applies the 'any' reduction over `attrs.axis`.
 *
 * If the requested axes are not already innermost, the input is first
 * transposed so they are. It is then viewed as 2-D [-1, reduceSize], reduced
 * with the 'any' reduce program, and reshaped to the output shape. When
 * `keepDims` is set, the reduced axes are kept as size-1 dims. The
 * temporaries created here are disposed before returning.
 *
 * @param args.inputs {x}
 * @param args.backend WebGL backend.
 * @param args.attrs {axis, keepDims}
 * @returns The reduced tensor info.
 */
function any(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var keepDims = args.attrs.keepDims;
  var rank = x.shape.length;
  var requestedAxes = parseAxisParam(args.attrs.axis, x.shape);
  var perm = getAxesPermutation(requestedAxes, rank);
  var transposed = perm != null;
  var source = x;
  var reduceAxes = requestedAxes;
  if (transposed) {
    source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
    reduceAxes = getInnerMostAxes(requestedAxes.length, rank);
  }
  assertAxesAreInnerMostDims('any', reduceAxes, rank);
  var shapes = computeOutAndReduceShapes(source.shape, reduceAxes);
  var outShape = shapes[0];
  var flat = reshape({ inputs: { x: source }, backend: backend, attrs: { shape: [-1, sizeFromShape(shapes[1])] } });
  var reduced = reduce(flat, flat.dtype, 'any', backend);
  var finalShape = keepDims ? expandShapeToKeepDim(outShape, requestedAxes) : outShape;
  var res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: finalShape } });
  backend.disposeIntermediateTensorInfo(flat);
  backend.disposeIntermediateTensorInfo(reduced);
  if (transposed) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return res;
}
var anyConfig = { kernelName: Any, backendName: 'webgl', kernelFunc: any };
117230
117231 /**
117232 * @license
117233 * Copyright 2017 Google LLC. All Rights Reserved.
117234 * Licensed under the Apache License, Version 2.0 (the "License");
117235 * you may not use this file except in compliance with the License.
117236 * You may obtain a copy of the License at
117237 *
117238 * http://www.apache.org/licenses/LICENSE-2.0
117239 *
117240 * Unless required by applicable law or agreed to in writing, software
117241 * distributed under the License is distributed on an "AS IS" BASIS,
117242 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117243 * See the License for the specific language governing permissions and
117244 * limitations under the License.
117245 * =============================================================================
117246 */
/**
 * Unpacked WebGL program for one pass of an argMin/argMax reduction over the
 * second dimension of a 2-D input addressed as getA(batch, index).
 *
 * Output element (batch, outIdx) scans the `windowSize` positions starting at
 * outIdx * windowSize. It writes, as a float, the index of the best value
 * found. A candidate replaces the current best only when strictly better.
 * On the first pass the candidate indices are the raw positions. On later
 * passes they are read from `bestIndicesA`, the index output of the previous
 * pass.
 *
 * @param reduceInfo {windowSize, batchSize, outSize} for this pass.
 * @param op 'max' compares with '>', anything else with '<'.
 * @param firstPass true when no `bestIndicesA` input exists yet.
 */
var ArgMinMaxProgram = /*#__PURE__*/_createClass(function ArgMinMaxProgram(reduceInfo, op, firstPass) {
  _classCallCheck(this, ArgMinMaxProgram);
  this.variableNames = ['A'];
  var windowSize = reduceInfo.windowSize,
    batchSize = reduceInfo.batchSize,
    outSize = reduceInfo.outSize;
  // Later passes also sample the indices chosen by the previous pass.
  if (!firstPass) {
    this.variableNames.push('bestIndicesA');
  }
  this.outputShape = [batchSize, outSize];
  var compOp = op === 'max' ? '>' : '<';
  // First pass: the candidate index is the position itself. Later passes:
  // look up the index the previous pass selected at that position.
  var indexSnippet = firstPass ? 'inOffset + i;' : 'round(getBestIndicesA(batch, inOffset + i));';
  this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ".concat(windowSize, ";\n\n int bestIndex = inOffset;\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < ").concat(windowSize, "; i++) {\n int inIdx = ").concat(indexSnippet, ";\n float candidate = getA(batch, inIdx);\n if (candidate ").concat(compOp, " bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n ");
});
117261
/**
 * Packed WebGL program for one pass of an argMin/argMax reduction over the
 * last dimension of an input of rank > 2 (enforced by the assertion below).
 *
 * The output shape is the input shape without its last dimension. The number
 * of windows is appended as a new last dimension only when it exceeds 1.
 * Each output texel has four lanes, R, G, B and A. These correspond to the
 * output coordinate, the next column, the next row, and the next row plus
 * next column. hasNextCol/hasNextRow stop lanes past the output edge from
 * sampling, and those lanes read 0 instead. A candidate lane replaces the
 * current best only when the comparison holds and the candidate is not NaN.
 *
 * @param shape Input shape for this pass (x on the first pass, otherwise the
 *     previous pass's index tensor).
 * @param windowSize Number of positions each output lane scans.
 * @param op 'max' compares with greaterThan, anything else with lessThan.
 * @param firstPass true when no `bestIndicesA` input exists yet.
 */
var ArgMinMaxPackedProgram = /*#__PURE__*/_createClass(function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) {
  _classCallCheck(this, ArgMinMaxPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  assert$1(shape.length > 2, function () {
    return "Packed arg".concat(op.charAt(0).toUpperCase() + op.slice(1), " supports only inputs with rank above 2.");
  });
  var inSize = shape[shape.length - 1];
  var outSize = Math.ceil(inSize / windowSize);
  this.outputShape = shape.slice(0, -1);
  if (outSize > 1) {
    this.outputShape.push(outSize);
  }
  if (!firstPass) {
    this.variableNames.push('bestIndicesA');
  }
  var outShape = this.outputShape;
  var rank = outShape.length;
  var dtype = getCoordsDataType(rank);
  var coords = getChannels('coords', rank);
  // Build the GLSL that derives the four lane source locations from the output
  // coordinate, stepping the last two coordinates and then restoring them.
  var sourceLocSetup;
  var sourceRank;
  if (outSize === 1) {
    // The reduced dimension was dropped from the output, so the source
    // location has one more coordinate: the window start, appended as 0.
    sourceRank = rank + 1;
    var sourceLocDType = getCoordsDataType(sourceRank);
    sourceLocSetup = "\n ".concat(sourceLocDType, " sourceLocR = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n ++").concat(coords[rank - 1], ";\n ").concat(sourceLocDType, " sourceLocG = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n ++").concat(coords[rank - 2], ";\n ").concat(sourceLocDType, " sourceLocA = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n --").concat(coords[rank - 1], ";\n ").concat(sourceLocDType, " sourceLocB = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n --").concat(coords[rank - 2], ";");
  } else {
    sourceRank = rank;
    sourceLocSetup = "\n ".concat(dtype, " sourceLocR = coords;\n ++").concat(coords[rank - 1], ";\n ").concat(dtype, " sourceLocG = coords;\n ++").concat(coords[rank - 2], ";\n ").concat(dtype, " sourceLocA = coords;\n --").concat(coords[rank - 1], ";\n ").concat(dtype, " sourceLocB = coords;\n --").concat(coords[rank - 2], ";");
  }
  var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
  var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".z" for rank 3.
  var intChannels = channels.map(function (x) {
    return 'int ' + x;
  });
  // Full source coordinates for each lane: the leading source-location
  // components followed by that lane's current input index.
  var srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
  var srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
  var srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
  var srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
  var compOp = op === 'max' ? 'greaterThan' : 'lessThan';
  // Later passes replace the raw positions with the indices chosen by the
  // previous pass.
  var fetchCandidateIdx = firstPass ? '' : "\n inIdx = round(vec4(getBestIndicesAChannel(".concat(srcRCoords.join(), "),\n getBestIndicesAChannel(").concat(srcGCoords.join(), "),\n getBestIndicesAChannel(").concat(srcBCoords.join(), "),\n getBestIndicesAChannel(").concat(srcACoords.join(), ")));");
  var fetchValue = "vec4(\n getAChannel(".concat(srcRCoords.join(), "),\n hasNextCol ? getAChannel(").concat(srcGCoords.join(), ") : 0.,\n hasNextRow ? getAChannel(").concat(srcBCoords.join(), ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(").concat(srcACoords.join(), ") : 0.)");
  var getBestIndicesAChannelSnippet = firstPass ? '' : "\n float getBestIndicesAChannel(".concat(intChannels.join(), ") {\n return getChannel(getBestIndicesA(").concat(channels.join(), "),\n vec2(").concat(channels.slice(-2).join(), "));\n }");
  // NaN candidates are masked out of `replace`, so they never become the best.
  this.userCode = "\n float getAChannel(".concat(intChannels.join(), ") {\n return getChannel(getA(").concat(channels.join(), "),\n vec2(").concat(channels.slice(-2).join(), "));\n }\n ").concat(getBestIndicesAChannelSnippet, "\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n bool hasNextCol = ").concat(coords[rank - 1], " < ").concat(outShape[rank - 1] - 1, ";\n bool hasNextRow = ").concat(coords[rank - 2], " < ").concat(outShape[rank - 2] - 1, ";\n ").concat(sourceLocSetup, "\n ivec4 srcIdx = ivec4(sourceLocR").concat(inChannel, ", sourceLocG").concat(inChannel, ",\n sourceLocB").concat(inChannel, ", sourceLocA").concat(inChannel, ") * ").concat(windowSize, ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = ").concat(fetchValue, ";\n\n for (int i = 0; i < ").concat(windowSize, "; i++) {\n inIdx = srcIdx;\n ").concat(fetchCandidateIdx, "\n vec4 candidate = ").concat(fetchValue, ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(").concat(compOp, "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n ");
});
117308
/**
 * Unpacked multi-pass argMin/argMax over the second dimension of a 2-D tensor.
 *
 * Each pass runs an ArgMinMaxProgram that shrinks the column count by the
 * window size chosen for it. The index tensor of one pass feeds the next as
 * `bestIndicesA`, until a pass yields a single column. Index tensors created
 * here are disposed, newest first, once that final pass has produced its
 * output. A caller-supplied `bestIndicesA` is never disposed.
 *
 * @param backend WebGL backend.
 * @param x 2-D input [batchSize, inSize], sampled by every pass.
 * @param reduceType 'max' for argMax, anything else for argMin.
 * @param {TensorInfo=} bestIndicesA (optional 4th argument) indices from a
 *     previous pass.
 * @returns int32 tensor info whose second dimension is 1.
 */
function argReduce(backend, x, reduceType) {
  var bestIndicesA = arguments[3] === undefined ? null : arguments[3];
  var partials = [];
  for (;;) {
    var isFirstPass = bestIndicesA == null;
    var sizeSource = isFirstPass ? x : bestIndicesA;
    var batchSize = sizeSource.shape[0];
    var inSize = sizeSource.shape[1];
    var windowSize = computeOptimalWindowSize(inSize);
    var program = new ArgMinMaxProgram({
      windowSize: windowSize,
      inSize: inSize,
      batchSize: batchSize,
      outSize: Math.ceil(inSize / windowSize)
    }, reduceType, isFirstPass);
    var output = backend.runWebGLProgram(program, isFirstPass ? [x] : [x, bestIndicesA], 'int32');
    if (output.shape[1] === 1) {
      // No further GPGPU pass is needed; release the per-pass index tensors.
      while (partials.length > 0) {
        backend.disposeIntermediateTensorInfo(partials.pop());
      }
      return output;
    }
    partials.push(output);
    bestIndicesA = output;
  }
}
/**
 * Packed multi-pass argMin/argMax over the last dimension of x.
 *
 * Passes repeat while a pass's output still has the same rank as x, i.e.
 * while more than one window remains. Each later pass consumes the previous
 * index tensor as `bestIndicesA`. Index tensors created here are disposed,
 * newest first, after the final pass. A caller-supplied `bestIndicesA` is
 * never disposed.
 *
 * @param backend WebGL backend.
 * @param x Input tensor info, sampled by every pass.
 * @param reduceType 'max' for argMax, anything else for argMin.
 * @param {TensorInfo=} bestIndicesA (optional 4th argument) indices from a
 *     previous pass.
 * @returns int32 tensor info whose rank is one less than x's.
 */
function argReducePacked(backend, x, reduceType) {
  var bestIndicesA = arguments[3] === undefined ? null : arguments[3];
  var partials = [];
  var output;
  for (;;) {
    var isFirstPass = bestIndicesA == null;
    var inShape = isFirstPass ? x.shape : bestIndicesA.shape;
    var windowSize = computeOptimalWindowSize(inShape[inShape.length - 1]);
    var program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, isFirstPass);
    output = backend.runWebGLProgram(program, isFirstPass ? [x] : [x, bestIndicesA], 'int32');
    if (output.shape.length !== x.shape.length) {
      break;
    }
    partials.push(output);
    bestIndicesA = output;
  }
  while (partials.length > 0) {
    backend.disposeIntermediateTensorInfo(partials.pop());
  }
  return output;
}
/**
 * Shared argMin/argMax driver. Reduces x along `axis`, which must already be
 * the innermost axis.
 *
 * With WEBGL_PACK_REDUCE on and rank > 2, the packed multi-pass path is used.
 * Otherwise x is unpacked if its texture is packed, flattened to 2-D
 * [-1, reduceSize], reduced with the unpacked multi-pass path, and reshaped to
 * the output shape. The temporaries of the unpacked path are disposed before
 * returning.
 *
 * @param backend WebGL backend.
 * @param x Input tensor info.
 * @param axis The (innermost) axis to reduce.
 * @param reduceType 'max' or 'min'.
 * @returns int32 index tensor info.
 */
function argMinMaxReduce(backend, x, axis, reduceType) {
  var axes = [axis];
  var opName = 'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1);
  assertAxesAreInnerMostDims(opName, axes, x.shape.length);
  var usePacked = env().getBool('WEBGL_PACK_REDUCE') && x.shape.length > 2;
  if (usePacked) {
    return argReducePacked(backend, x, reduceType);
  }
  var temporaries = [];
  // Eagerly unpack x input since it is passed in to all the shaders which
  // require unpacked inputs.
  var texData = backend.texData.get(x.dataId);
  var source = x;
  if (texData !== null && texData.isPacked) {
    source = backend.unpackTensor(x);
    temporaries.push(source);
  }
  var shapes = computeOutAndReduceShapes(source.shape, axes);
  var flat = reshape({ inputs: { x: source }, backend: backend, attrs: { shape: [-1, sizeFromShape(shapes[1])] } });
  temporaries.push(flat);
  var reduced = argReduce(backend, flat, reduceType);
  temporaries.push(reduced);
  var result = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: shapes[0] } });
  for (var k = 0; k < temporaries.length; k++) {
    backend.disposeIntermediateTensorInfo(temporaries[k]);
  }
  return result;
}
117400
117401 /**
117402 * @license
117403 * Copyright 2020 Google LLC. All Rights Reserved.
117404 * Licensed under the Apache License, Version 2.0 (the "License");
117405 * you may not use this file except in compliance with the License.
117406 * You may obtain a copy of the License at
117407 *
117408 * http://www.apache.org/licenses/LICENSE-2.0
117409 *
117410 * Unless required by applicable law or agreed to in writing, software
117411 * distributed under the License is distributed on an "AS IS" BASIS,
117412 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117413 * See the License for the specific language governing permissions and
117414 * limitations under the License.
117415 * =============================================================================
117416 */
/**
 * WebGL kernel for ArgMax.
 *
 * If `attrs.axis` is not the innermost axis, x is transposed so it is. The
 * reduction is delegated to `argMinMaxReduce` with 'max', and the transposed
 * temporary (if any) is disposed afterwards.
 *
 * @param args.inputs {x}
 * @param args.backend WebGL backend.
 * @param args.attrs {axis}
 * @returns int32 index tensor info.
 */
function argMax(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var axes = parseAxisParam(args.attrs.axis, x.shape);
  var perm = getAxesPermutation(axes, x.shape.length);
  var transposed = perm != null;
  var source = x;
  if (transposed) {
    source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
    axes = getInnerMostAxes(axes.length, source.shape.length);
  }
  var innerAxis = axes[0];
  assertAxesAreInnerMostDims('argMax', [innerAxis], source.shape.length);
  var out = argMinMaxReduce(backend, source, innerAxis, 'max');
  if (transposed) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return out;
}
var argMaxConfig = { kernelName: ArgMax, backendName: 'webgl', kernelFunc: argMax };
117452
117453 /**
117454 * @license
117455 * Copyright 2020 Google LLC. All Rights Reserved.
117456 * Licensed under the Apache License, Version 2.0 (the "License");
117457 * you may not use this file except in compliance with the License.
117458 * You may obtain a copy of the License at
117459 *
117460 * http://www.apache.org/licenses/LICENSE-2.0
117461 *
117462 * Unless required by applicable law or agreed to in writing, software
117463 * distributed under the License is distributed on an "AS IS" BASIS,
117464 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117465 * See the License for the specific language governing permissions and
117466 * limitations under the License.
117467 * =============================================================================
117468 */
/**
 * WebGL kernel for ArgMin.
 *
 * If `attrs.axis` is not the innermost axis, x is transposed so it is. The
 * reduction is delegated to `argMinMaxReduce` with 'min', and the transposed
 * temporary (if any) is disposed afterwards.
 *
 * @param args.inputs {x}
 * @param args.backend WebGL backend.
 * @param args.attrs {axis}
 * @returns int32 index tensor info.
 */
function argMin(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var axes = parseAxisParam(args.attrs.axis, x.shape);
  var perm = getAxesPermutation(axes, x.shape.length);
  var transposed = perm != null;
  var source = x;
  if (transposed) {
    source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
    axes = getInnerMostAxes(axes.length, source.shape.length);
  }
  var innerAxis = axes[0];
  assertAxesAreInnerMostDims('argMin', [innerAxis], source.shape.length);
  var out = argMinMaxReduce(backend, source, innerAxis, 'min');
  if (transposed) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return out;
}
var argMinConfig = { kernelName: ArgMin, backendName: 'webgl', kernelFunc: argMin };
117504
117505 /**
117506 * @license
117507 * Copyright 2020 Google LLC. All Rights Reserved.
117508 * Licensed under the Apache License, Version 2.0 (the "License");
117509 * you may not use this file except in compliance with the License.
117510 * You may obtain a copy of the License at
117511 *
117512 * http://www.apache.org/licenses/LICENSE-2.0
117513 *
117514 * Unless required by applicable law or agreed to in writing, software
117515 * distributed under the License is distributed on an "AS IS" BASIS,
117516 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117517 * See the License for the specific language governing permissions and
117518 * limitations under the License.
117519 * =============================================================================
117520 */
// Asin op snippet: after the NaN check snippet, inputs with |x| > 1 lie
// outside asin's domain and produce NAN.
var ASIN = CHECK_NAN_SNIPPET$1 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return asin(x);\n";
var asin = unaryKernelFunc({ opSnippet: ASIN });
var asinConfig = { kernelName: Asin, backendName: 'webgl', kernelFunc: asin };
117530
117531 /**
117532 * @license
117533 * Copyright 2020 Google LLC. All Rights Reserved.
117534 * Licensed under the Apache License, Version 2.0 (the "License");
117535 * you may not use this file except in compliance with the License.
117536 * You may obtain a copy of the License at
117537 *
117538 * http://www.apache.org/licenses/LICENSE-2.0
117539 *
117540 * Unless required by applicable law or agreed to in writing, software
117541 * distributed under the License is distributed on an "AS IS" BASIS,
117542 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117543 * See the License for the specific language governing permissions and
117544 * limitations under the License.
117545 * =============================================================================
117546 */
// Asinh op snippet. asinh is an odd function, so it is evaluated on |x| and
// the sign is restored. The direct form log(x + sqrt(x * x + 1.0)) cancels
// catastrophically for large negative x, because x + sqrt(x*x + 1) rounds
// to 0. That yields -Infinity or a badly wrong result; in float32 it already
// happens around x = -1e3. For x = 0, sign(x) is 0, so the result stays 0.
var ASINH = CHECK_NAN_SNIPPET$1 + "return sign(x) * log(abs(x) + sqrt(x * x + 1.0));";
var asinh = unaryKernelFunc({
  opSnippet: ASINH
});
var asinhConfig = {
  kernelName: Asinh,
  backendName: 'webgl',
  kernelFunc: asinh
};
117556
117557 /**
117558 * @license
117559 * Copyright 2020 Google LLC. All Rights Reserved.
117560 * Licensed under the Apache License, Version 2.0 (the "License");
117561 * you may not use this file except in compliance with the License.
117562 * You may obtain a copy of the License at
117563 *
117564 * http://www.apache.org/licenses/LICENSE-2.0
117565 *
117566 * Unless required by applicable law or agreed to in writing, software
117567 * distributed under the License is distributed on an "AS IS" BASIS,
117568 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117569 * See the License for the specific language governing permissions and
117570 * limitations under the License.
117571 * =============================================================================
117572 */
// Unary atan shader: the shared CHECK_NAN_SNIPPET$1 guard (defined elsewhere)
// followed by GLSL's built-in atan. atan is defined on all reals, so no domain
// check is needed.
var ATAN = CHECK_NAN_SNIPPET$1 + "\n return atan(x);\n";
var atan = unaryKernelFunc({ opSnippet: ATAN });
// Registers the shader above as the WebGL implementation of the Atan kernel.
var atanConfig = { kernelName: Atan, backendName: 'webgl', kernelFunc: atan };
117582
117583 /**
117584 * @license
117585 * Copyright 2020 Google LLC. All Rights Reserved.
117586 * Licensed under the Apache License, Version 2.0 (the "License");
117587 * you may not use this file except in compliance with the License.
117588 * You may obtain a copy of the License at
117589 *
117590 * http://www.apache.org/licenses/LICENSE-2.0
117591 *
117592 * Unless required by applicable law or agreed to in writing, software
117593 * distributed under the License is distributed on an "AS IS" BASIS,
117594 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117595 * See the License for the specific language governing permissions and
117596 * limitations under the License.
117597 * =============================================================================
117598 */
// Binary atan2(a, b) shaders.
// - Unpacked: the shared CHECK_NAN_SNIPPET prefix (defined elsewhere), then
//   GLSL's two-argument atan.
// - Packed: computes all four lanes, marks a lane as NaN when either operand
//   lane is NaN, and hands that `isNaN` vector to CHECK_NAN_SNIPPET_PACKED
//   (defined elsewhere; presumably it writes NaN into those lanes of `result`).
var ATAN2 = CHECK_NAN_SNIPPET + "\n return atan(a, b);\n";
var ATAN2_PACKED = "\n vec4 result = atan(a, b);\n" +
  " bvec4 isNaNA = isnan(a);\n" +
  " bvec4 isNaNB = isnan(b);\n" +
  " bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);\n " +
  CHECK_NAN_SNIPPET_PACKED +
  "\n return result;\n";
var atan2 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED });
// Registers the shaders above as the WebGL implementation of the Atan2 kernel.
var atan2Config = { kernelName: Atan2, backendName: 'webgl', kernelFunc: atan2 };
117610
117611 /**
117612 * @license
117613 * Copyright 2020 Google LLC. All Rights Reserved.
117614 * Licensed under the Apache License, Version 2.0 (the "License");
117615 * you may not use this file except in compliance with the License.
117616 * You may obtain a copy of the License at
117617 *
117618 * http://www.apache.org/licenses/LICENSE-2.0
117619 *
117620 * Unless required by applicable law or agreed to in writing, software
117621 * distributed under the License is distributed on an "AS IS" BASIS,
117622 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117623 * See the License for the specific language governing permissions and
117624 * limitations under the License.
117625 * =============================================================================
117626 */
// Unary atanh shader: atanh(x) = (log(1 + x) - log(1 - x)) / 2. It is
// prefixed with the shared CHECK_NAN_SNIPPET$1 guard (defined elsewhere), and
// any x outside [-1, 1] produces NAN.
var ATANH = CHECK_NAN_SNIPPET$1 + "\n if ((x < -1.0) || (x > 1.0)) return NAN;\n" + "return (log(1.0 + x) - log(1.0 - x)) / 2.0;";
var atanh = unaryKernelFunc({ opSnippet: ATANH });
// Registers the shader above as the WebGL implementation of the Atanh kernel.
var atanhConfig = { kernelName: Atanh, backendName: 'webgl', kernelFunc: atanh };
117636
117637 /**
117638 * @license
117639 * Copyright 2017 Google LLC. All Rights Reserved.
117640 * Licensed under the Apache License, Version 2.0 (the "License");
117641 * you may not use this file except in compliance with the License.
117642 * You may obtain a copy of the License at
117643 *
117644 * http://www.apache.org/licenses/LICENSE-2.0
117645 *
117646 * Unless required by applicable law or agreed to in writing, software
117647 * distributed under the License is distributed on an "AS IS" BASIS,
117648 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117649 * See the License for the specific language governing permissions and
117650 * limitations under the License.
117651 * =============================================================================
117652 */
/**
 * WebGL program for 2D pooling. The generated shader reads the input through
 * getX(batch, xR, xC, d) and writes one value per output coordinate.
 *
 * Constructor arguments:
 *   convInfo - pooling geometry: filter/stride/dilation sizes, effective
 *       (dilated) filter sizes, padInfo, and input/output shapes (presumably
 *       produced by computePool2DInfo).
 *   poolType - 'avg' or 'max'. A non-'avg' name is spliced into the final
 *       reduction, but the per-window update always uses max.
 *   computePositions - when true, output the position of the selected
 *       element instead of its value. Throws for 'avg'.
 *   flattenPositions (4th, optional, default false) - with computePositions,
 *       emit a flattened input index instead of the offset inside the
 *       dilated window.
 *   includeBatchInIndex (5th, optional, default false) - with
 *       flattenPositions, include the batch dimension in that index.
 */
var Pool2DProgram = /*#__PURE__*/_createClass(function Pool2DProgram(convInfo, poolType, computePositions) {
  var flattenPositions = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  var includeBatchInIndex = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
  _classCallCheck(this, Pool2DProgram);
  this.variableNames = ['x'];
  if (poolType === 'avg' && computePositions) {
    throw new Error('Cannot compute positions for average pool.');
  }
  var filterWidth = convInfo.filterWidth;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight;
  var effectiveFilterWidth = convInfo.effectiveFilterWidth;
  var padTop = convInfo.padInfo.top;
  var padLeft = convInfo.padInfo.left;
  this.outputShape = convInfo.outShape;
  var isAvgPool = poolType === 'avg';
  // Flattened input indices emitted when flattenPositions is set.
  var batchFlattenPositionStr = "((batch * ".concat(convInfo.inHeight, " + xR) * ").concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + d");
  var flattenPositionStr = "(xR * ".concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + d");
  var initializationValue = '0.0';
  if (!isAvgPool) {
    // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
    initializationValue = '-1.0 / 1e-20';
  }
  if (computePositions) {
    // '>=' lets a later element that ties the current extreme replace it, so
    // ties resolve to the last match in scan order.
    var _compareOp = '>=';
    this.userCode = "\n const ivec2 strides = ivec2(".concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value ").concat(_compareOp, " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = ").concat(flattenPositions ? includeBatchInIndex ? batchFlattenPositionStr : flattenPositionStr : "wR * ".concat(effectiveFilterWidth, " + wC"), ";\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ");
    return;
  }
  var compareOp = 'max';
  var returnValue = "".concat(poolType, "(").concat(poolType, "(").concat(poolType, "(") + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
  if (isAvgPool) {
    // getValue() only counts in-bounds columns and out-of-bounds rows are
    // skipped, so padding is excluded from the divisor. max(count, 1.0)
    // avoids dividing by zero when the whole window lies in padding.
    returnValue = "avgValue / max(count, 1.0)";
  }
  // Window columns are reduced four at a time as a vec4; the leftover
  // filterWidth % 4 columns go through the remainder branches and are padded
  // with initializationValue.
  var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
  var filterWidthVec4Remainder = filterWidth % 4;
  // The vec4 loop walks undilated column indices (wC = 0, 4, 8, ...) and reads
  // input column xCCorner + wC * dilationWidth, so the first remainder column
  // is filterWidthNearestVec4 * dilationWidth columns past the corner. Without
  // the dilation factor, the remainder would read the wrong columns whenever
  // dilationWidth > 1.
  var remainderColumnOffset = filterWidthNearestVec4 * dilationWidth;
  var updateSnippet = "\n if (".concat(isAvgPool, ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n }\n ");
  this.userCode = "\n const ivec2 strides = ivec2(".concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n const float initializationValue = ").concat(initializationValue, ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(").concat(initializationValue, ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidthNearestVec4, "; wC += 4) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ").concat(dilationWidth, ", d),\n getValue(batch, xR, xC + 2 * ").concat(dilationWidth, ", d),\n getValue(batch, xR, xC + 3 * ").concat(dilationWidth, ", d)\n );\n\n ").concat(updateSnippet, "\n }\n\n int xC = xCCorner + ").concat(remainderColumnOffset, ";\n if (").concat(filterWidthVec4Remainder === 1, ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ").concat(dilationWidth, ", d),\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ").concat(dilationWidth, ", d),\n getValue(batch, xR, xC + 2 * ").concat(dilationWidth, ", d),\n initializationValue\n );\n\n ").concat(updateSnippet, "\n }\n }\n setOutput(").concat(returnValue, ");\n }\n ");
});
/**
 * WebGL program for 3D pooling. The generated shader reads the input through
 * getX(batch, xD, xR, xC, ch) and writes one value per output coordinate.
 *
 * Constructor arguments:
 *   convInfo - pooling geometry: filter/stride/dilation sizes, effective
 *       (dilated) filter sizes, padInfo, and input/output shapes (presumably
 *       produced by computePool3DInfo).
 *   poolType - 'avg' or 'max'. A non-'avg' name is spliced into the final
 *       reduction, but the per-window update always uses max.
 *   computePositions - when true, output the position of the selected
 *       element instead of its value. Throws for 'avg'.
 *   flattenPositions (4th, optional, default false) - with computePositions,
 *       emit a flattened input index instead of the offset inside the
 *       dilated window.
 *   includeBatchInIndex (5th, optional, default false) - with
 *       flattenPositions, include the batch dimension in that index.
 */
var Pool3DProgram = /*#__PURE__*/_createClass(function Pool3DProgram(convInfo, poolType, computePositions) {
  var flattenPositions = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
  var includeBatchInIndex = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
  _classCallCheck(this, Pool3DProgram);
  this.variableNames = ['x'];
  if (poolType === 'avg' && computePositions) {
    throw new Error('Cannot compute positions for average pool.');
  }
  var filterWidth = convInfo.filterWidth;
  var strideDepth = convInfo.strideDepth;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var dilationDepth = convInfo.dilationDepth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  var effectiveFilterDepth = convInfo.effectiveFilterDepth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight;
  var effectiveFilterWidth = convInfo.effectiveFilterWidth;
  var padFront = convInfo.padInfo.front;
  var padTop = convInfo.padInfo.top;
  var padLeft = convInfo.padInfo.left;
  this.outputShape = convInfo.outShape;
  var isAvgPool = poolType === 'avg';
  var initializationValue = '0.0';
  if (!isAvgPool) {
    // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
    initializationValue = '-1.0 / 1e-20';
  }
  if (computePositions) {
    // '>=' lets a later element that ties the current extreme replace it, so
    // ties resolve to the last match in scan order.
    var _compareOp2 = '>=';
    this.userCode = "\n const ivec3 strides =\n ivec3(".concat(strideDepth, ", ").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec3 pads = ivec3(").concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float value = getX(batch, xD, xR, xC, ch);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value ").concat(_compareOp2, " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = ").concat(flattenPositions ? includeBatchInIndex ? "(((batch * ".concat(convInfo.inDepth, " + xD) * ").concat(convInfo.inHeight, " + xR) * ").concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + ch") : "((xD * ".concat(convInfo.inHeight, " + xR) * ").concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + ch") : "wD * ".concat(effectiveFilterHeight, " * ").concat(effectiveFilterWidth, " +\n wR * ").concat(effectiveFilterWidth, " + wC"), ";\n }\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ");
    return;
  }
  var compareOp = 'max';
  var returnValue = "".concat(poolType, "(").concat(poolType, "(").concat(poolType, "(") + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
  if (isAvgPool) {
    // Use `max(count, 1.0)` instead of `count` in case count === 0.0.
    // If count === 0.0, `avgValue` is always 0.0 and we change `count`'s
    // value to avoid dividing zero. getValue() only counts in-bounds columns
    // and out-of-bounds depths/rows are skipped, so padding is excluded.
    returnValue = "avgValue / max(count, 1.0)";
  }
  // Window columns are reduced four at a time as a vec4; the leftover
  // filterWidth % 4 columns go through the remainder branches and are padded
  // with initializationValue.
  var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
  var filterWidthVec4Remainder = filterWidth % 4;
  // The vec4 loop walks undilated column indices (wC = 0, 4, 8, ...) and reads
  // input column xCCorner + wC * dilationWidth, so the first remainder column
  // is filterWidthNearestVec4 * dilationWidth columns past the corner. Without
  // the dilation factor, the remainder would read the wrong columns whenever
  // dilationWidth > 1.
  var remainderColumnOffset = filterWidthNearestVec4 * dilationWidth;
  var updateSnippet = "\n if (".concat(isAvgPool, ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n }\n ");
  this.userCode = "\n const ivec3 strides =\n ivec3(".concat(strideDepth, ", ").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec3 pads = ivec3(").concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n const float initializationValue = ").concat(initializationValue, ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xD, int xR, int xC, int ch) {\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xD, xR, xC, ch);\n }\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n // ? = to be determined\n vec4 minMaxValue = vec4(").concat(initializationValue, ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidthNearestVec4, "; wC += 4) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ").concat(dilationWidth, ", ch),\n getValue(batch, xD, xR, xC + 2 * ").concat(dilationWidth, ", ch),\n getValue(batch, xD, xR, xC + 3 * ").concat(dilationWidth, ", ch)\n );\n\n ").concat(updateSnippet, "\n }\n\n int xC = xCCorner + ").concat(remainderColumnOffset, ";\n if (").concat(filterWidthVec4Remainder === 1, ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ").concat(dilationWidth, ", ch),\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ").concat(dilationWidth, ", ch),\n getValue(batch, xD, xR, xC + 2 * ").concat(dilationWidth, ", ch),\n initializationValue\n );\n\n ").concat(updateSnippet, "\n }\n }\n }\n setOutput(").concat(returnValue, ");\n }\n ");
});
117740
117741 /**
117742 * @license
117743 * Copyright 2020 Google LLC. All Rights Reserved.
117744 * Licensed under the Apache License, Version 2.0 (the "License");
117745 * you may not use this file except in compliance with the License.
117746 * You may obtain a copy of the License at
117747 *
117748 * http://www.apache.org/licenses/LICENSE-2.0
117749 *
117750 * Unless required by applicable law or agreed to in writing, software
117751 * distributed under the License is distributed on an "AS IS" BASIS,
117752 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117753 * See the License for the specific language governing permissions and
117754 * limitations under the License.
117755 * =============================================================================
117756 */
/**
 * WebGL kernel for AvgPool (2D average pooling).
 *
 * Always runs with dilation 1. When the window is 1x1 and pooling leaves the
 * shape unchanged, the input is forwarded through the identity kernel instead
 * of running a shader.
 *
 * @param args.inputs.x input tensor info (complex64 is rejected).
 * @param args.backend WebGL backend that runs the program.
 * @param args.attrs filterSize, strides, pad and dimRoundingMode.
 * @returns the pooled tensor info (float32 when the shader runs).
 */
function avgPool(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var attrs = args.attrs;
  assertNotComplex(x, 'avgPool');
  var filterSize = attrs.filterSize;
  var strides = attrs.strides;
  var pad = attrs.pad;
  var dimRoundingMode = attrs.dimRoundingMode;
  // AvgPool has no dilation attribute; pooling is always undilated.
  var dilations = 1;
  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in avgPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
  });
  var convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
  var isPassThrough = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape);
  if (!isPassThrough) {
    return backend.runWebGLProgram(new Pool2DProgram(convInfo, 'avg', false), [x], 'float32');
  }
  return identity({
    inputs: { x: x },
    backend: backend
  });
}
// Registers avgPool as the WebGL implementation of the AvgPool kernel.
var avgPoolConfig = {
  kernelName: AvgPool,
  backendName: 'webgl',
  kernelFunc: avgPool
};
117788
117789 /**
117790 * @license
117791 * Copyright 2020 Google LLC. All Rights Reserved.
117792 * Licensed under the Apache License, Version 2.0 (the "License");
117793 * you may not use this file except in compliance with the License.
117794 * You may obtain a copy of the License at
117795 *
117796 * http://www.apache.org/licenses/LICENSE-2.0
117797 *
117798 * Unless required by applicable law or agreed to in writing, software
117799 * distributed under the License is distributed on an "AS IS" BASIS,
117800 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117801 * See the License for the specific language governing permissions and
117802 * limitations under the License.
117803 * =============================================================================
117804 */
/**
 * WebGL kernel for AvgPool3D (3D average pooling).
 *
 * Always runs undilated (dilations [1, 1, 1]).
 *
 * @param args.inputs.x input tensor info (complex64 is rejected).
 * @param args.backend WebGL backend that runs the program.
 * @param args.attrs filterSize, strides, pad, dimRoundingMode and dataFormat.
 * @returns the pooled tensor info, produced as float32.
 */
function avgPool3D(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x;
  // Match avgPool: Pool3DProgram's shader only handles real float data, so
  // reject complex64 inputs with a clear error instead of running it on them.
  assertNotComplex(x, 'avgPool3D');
  var filterSize = attrs.filterSize,
    strides = attrs.strides,
    pad = attrs.pad,
    dimRoundingMode = attrs.dimRoundingMode,
    dataFormat = attrs.dataFormat;
  var dilations = [1, 1, 1];
  var convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
  var avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false);
  return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
}
// Registers avgPool3D as the WebGL implementation of the AvgPool3D kernel.
var avgPool3DConfig = {
  kernelName: AvgPool3D,
  backendName: 'webgl',
  kernelFunc: avgPool3D
};
117825
117826 /**
117827 * @license
117828 * Copyright 2017 Google LLC. All Rights Reserved.
117829 * Licensed under the Apache License, Version 2.0 (the "License");
117830 * you may not use this file except in compliance with the License.
117831 * You may obtain a copy of the License at
117832 *
117833 * http://www.apache.org/licenses/LICENSE-2.0
117834 *
117835 * Unless required by applicable law or agreed to in writing, software
117836 * distributed under the License is distributed on an "AS IS" BASIS,
117837 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117838 * See the License for the specific language governing permissions and
117839 * limitations under the License.
117840 * =============================================================================
117841 */
/**
 * WebGL program for the gradient of 2D average pooling with respect to its
 * input. For each dx(b, xR, xC, d) the shader walks the (dilated) window,
 * maps each offset back to a dy position (skipping non-integer or
 * out-of-range ones), and accumulates avgMultiplier * dy.
 *
 * @param convInfo pooling geometry of the forward pass (presumably from
 *     computePool2DInfo).
 */
var AvgPool2DBackpropProgram = /*#__PURE__*/_createClass(function AvgPool2DBackpropProgram(convInfo) {
  _classCallCheck(this, AvgPool2DBackpropProgram);
  this.variableNames = ['dy'];
  // dx has the shape of the forward-pass input.
  this.outputShape = convInfo.inShape;
  var strideHeight = convInfo.strideHeight,
    strideWidth = convInfo.strideWidth;
  var dilationHeight = convInfo.dilationHeight,
    dilationWidth = convInfo.dilationWidth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight,
    effectiveFilterWidth = convInfo.effectiveFilterWidth;
  // The dy scan for each dx coordinate starts (effectiveFilter - 1 -
  // forwardPad) cells before it along each spatial axis.
  var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
  var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
  // Every window cell receives an equal share of its window's gradient.
  // NOTE(review): this is 1 / full window area, whereas Pool2DProgram's
  // forward 'avg' divides by the count of in-bounds cells only; the two
  // disagree for windows overlapping padding — confirm this is intended.
  var windowArea = convInfo.filterHeight * convInfo.filterWidth;
  var avgMultiplier = 1 / windowArea;
  this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n const float avgMultiplier = float(").concat(avgMultiplier, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC+= ").concat(dilationWidth, ") {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n ");
});
/**
 * WebGL program for the gradient of 3D average pooling with respect to its
 * input. For each dx(batch, xD, xR, xC, ch) the shader walks the (dilated)
 * window, maps each offset back to a dy position (skipping non-integer or
 * out-of-range ones), and accumulates avgMultiplier * dy.
 *
 * @param convInfo pooling geometry of the forward pass (presumably from
 *     computePool3DInfo).
 */
var AvgPool3DBackpropProgram = /*#__PURE__*/_createClass(function AvgPool3DBackpropProgram(convInfo) {
  _classCallCheck(this, AvgPool3DBackpropProgram);
  this.variableNames = ['dy'];
  // dx has the shape of the forward-pass input.
  this.outputShape = convInfo.inShape;
  var strideDepth = convInfo.strideDepth,
    strideHeight = convInfo.strideHeight,
    strideWidth = convInfo.strideWidth;
  var dilationDepth = convInfo.dilationDepth,
    dilationHeight = convInfo.dilationHeight,
    dilationWidth = convInfo.dilationWidth;
  var effectiveFilterDepth = convInfo.effectiveFilterDepth,
    effectiveFilterHeight = convInfo.effectiveFilterHeight,
    effectiveFilterWidth = convInfo.effectiveFilterWidth;
  // The dy scan for each dx coordinate starts (effectiveFilter - 1 -
  // forwardPad) cells before it along each spatial axis.
  var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
  var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
  var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
  // Every window cell receives an equal share of its window's gradient.
  // NOTE(review): this is 1 / full window volume, whereas Pool3DProgram's
  // forward 'avg' divides by the count of in-bounds cells only — confirm.
  var windowVolume = convInfo.filterDepth * convInfo.filterHeight * convInfo.filterWidth;
  var avgMultiplier = 1 / windowVolume;
  this.userCode = "\n const ivec3 pads = ivec3(".concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n const float avgMultiplier = float(").concat(avgMultiplier, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n float dyD = float(dyDCorner + wD) / ").concat(strideDepth, ".0;\n\n if (dyD < 0.0 || dyD >= ").concat(convInfo.outDepth, ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n ");
});
117881
117882 /**
117883 * @license
117884 * Copyright 2020 Google LLC. All Rights Reserved.
117885 * Licensed under the Apache License, Version 2.0 (the "License");
117886 * you may not use this file except in compliance with the License.
117887 * You may obtain a copy of the License at
117888 *
117889 * http://www.apache.org/licenses/LICENSE-2.0
117890 *
117891 * Unless required by applicable law or agreed to in writing, software
117892 * distributed under the License is distributed on an "AS IS" BASIS,
117893 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117894 * See the License for the specific language governing permissions and
117895 * limitations under the License.
117896 * =============================================================================
117897 */
/**
 * WebGL kernel for AvgPool3DGrad: gradient of 3D average pooling with respect
 * to its input.
 *
 * @param args.inputs.dy upstream gradient.
 * @param args.inputs.input the forward-pass input; only its shape and dtype
 *     are used here.
 * @param args.backend WebGL backend that runs the program.
 * @param args.attrs filterSize, strides, pad and dimRoundingMode of the
 *     forward pooling.
 * @returns the gradient tensor info, produced with `input`'s dtype.
 */
function avgPool3DGrad(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var dy = inputs.dy,
    input = inputs.input;
  var x = input;
  // Match avgPoolGrad: the backprop shader has no complex64 path, so reject
  // complex tensors with a clear error instead of running it on them.
  assertNotComplex([dy, input], 'avgPool3DGrad');
  var filterSize = attrs.filterSize,
    strides = attrs.strides,
    pad = attrs.pad,
    dimRoundingMode = attrs.dimRoundingMode;
  var dilations = [1, 1, 1];
  var convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
  var avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo);
  return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
}
// Registers avgPool3DGrad as the WebGL implementation of AvgPool3DGrad.
var avgPool3DGradConfig = {
  kernelName: AvgPool3DGrad,
  backendName: 'webgl',
  kernelFunc: avgPool3DGrad
};
117919
117920 /**
117921 * @license
117922 * Copyright 2020 Google LLC. All Rights Reserved.
117923 * Licensed under the Apache License, Version 2.0 (the "License");
117924 * you may not use this file except in compliance with the License.
117925 * You may obtain a copy of the License at
117926 *
117927 * http://www.apache.org/licenses/LICENSE-2.0
117928 *
117929 * Unless required by applicable law or agreed to in writing, software
117930 * distributed under the License is distributed on an "AS IS" BASIS,
117931 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117932 * See the License for the specific language governing permissions and
117933 * limitations under the License.
117934 * =============================================================================
117935 */
/**
 * WebGL kernel for AvgPoolGrad: gradient of 2D average pooling with respect
 * to its input.
 *
 * @param args.inputs.dy upstream gradient.
 * @param args.inputs.input the forward-pass input; only its shape and dtype
 *     are used here.
 * @param args.backend WebGL backend that runs the program.
 * @param args.attrs filterSize, strides and pad of the forward pooling.
 * @returns the gradient tensor info, produced with `input`'s dtype.
 */
function avgPoolGrad(args) {
  var dy = args.inputs.dy;
  var x = args.inputs.input;
  var backend = args.backend;
  var attrs = args.attrs;
  assertNotComplex([dy, x], 'avgPoolGrad');
  // The gradient is computed for an undilated window (dilations = 1).
  var convInfo = computePool2DInfo(x.shape, attrs.filterSize, attrs.strides, 1 /* dilations */, attrs.pad);
  var program = new AvgPool2DBackpropProgram(convInfo);
  return backend.runWebGLProgram(program, [dy], x.dtype);
}
// Registers avgPoolGrad as the WebGL implementation of AvgPoolGrad.
var avgPoolGradConfig = {
  kernelName: AvgPoolGrad,
  backendName: 'webgl',
  kernelFunc: avgPoolGrad
};
117956
117957 /**
117958 * @license
117959 * Copyright 2020 Google LLC. All Rights Reserved.
117960 * Licensed under the Apache License, Version 2.0 (the "License");
117961 * you may not use this file except in compliance with the License.
117962 * You may obtain a copy of the License at
117963 *
117964 * http://www.apache.org/licenses/LICENSE-2.0
117965 *
117966 * Unless required by applicable law or agreed to in writing, software
117967 * distributed under the License is distributed on an "AS IS" BASIS,
117968 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
117969 * See the License for the specific language governing permissions and
117970 * limitations under the License.
117971 * =============================================================================
117972 */
/**
 * WebGL kernel for BatchMatMul. A thin adapter that unpacks the kernel's
 * inputs/attrs into the flat argument object batchMatMulImpl takes.
 *
 * @param args.inputs.a left operand.
 * @param args.inputs.b right operand.
 * @param args.attrs.transposeA / transposeB whether to transpose each operand.
 * @param args.backend WebGL backend.
 * @returns whatever batchMatMulImpl returns.
 */
function batchMatMul(args) {
  var inputs = args.inputs;
  var attrs = args.attrs;
  return batchMatMulImpl({
    a: inputs.a,
    b: inputs.b,
    transposeA: attrs.transposeA,
    transposeB: attrs.transposeB,
    backend: args.backend
  });
}
// Registers batchMatMul as the WebGL implementation of BatchMatMul.
var batchMatMulConfig = {
  kernelName: BatchMatMul,
  backendName: 'webgl',
  kernelFunc: batchMatMul
};
117994
/**
 * Unpacked WebGL program for batch normalization. The shader computes
 * (x - mean) * scale * inversesqrt(variance + varianceEpsilon) + offset,
 * evaluated as a dot product.
 *
 * offset defaults to 0.0 and scale to 1.0 when their shapes are null or
 * undefined; otherwise they become extra shader inputs. Every statistic shape
 * is validated against xShape with assertAndGetBroadcastShape (its result is
 * discarded).
 */
var BatchNormProgram = /*#__PURE__*/_createClass(function BatchNormProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
  _classCallCheck(this, BatchNormProgram);
  this.outputShape = [];
  this.variableNames = ['x', 'mean', 'variance'];
  assertAndGetBroadcastShape(xShape, meanShape);
  assertAndGetBroadcastShape(xShape, varianceShape);
  var hasOffset = offsetShape != null;
  if (hasOffset) {
    assertAndGetBroadcastShape(xShape, offsetShape);
    this.variableNames.push('offset');
  }
  var hasScale = scaleShape != null;
  if (hasScale) {
    assertAndGetBroadcastShape(xShape, scaleShape);
    this.variableNames.push('scale');
  }
  var offsetSnippet = hasOffset ? 'getOffsetAtOutCoords()' : '0.0';
  var scaleSnippet = hasScale ? 'getScaleAtOutCoords()' : '1.0';
  this.outputShape = xShape;
  this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = ".concat(offsetSnippet, ";\n float scale = ").concat(scaleSnippet, ";\n float inv = scale * inversesqrt(variance + float(").concat(varianceEpsilon, "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n ");
});
118016
/**
 * Packed (vec4-per-texel) WebGL program for batch normalization.
 *
 * Computes (x - mean) * scale * inversesqrt(variance + varianceEpsilon) + offset
 * on four values at a time. Without an offset or scale input, vec4(0.0) and
 * vec4(1.0) are inlined in their place and no extra sampler is declared. Each
 * supplied statistic shape is checked against `xShape` with
 * assertAndGetBroadcastShape.
 *
 * @param {number[]} xShape - shape of `x`; also used as the output shape.
 * @param {number[]} meanShape
 * @param {number[]} varianceShape
 * @param {?number[]} offsetShape - null when there is no offset input.
 * @param {?number[]} scaleShape - null when there is no scale input.
 * @param {number} varianceEpsilon - inlined into the shader source.
 */
var BatchNormPackedProgram = /*#__PURE__*/_createClass(function BatchNormPackedProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
  _classCallCheck(this, BatchNormPackedProgram);
  var program = this;
  program.packedInputs = true;
  program.packedOutput = true;
  program.variableNames = ['x', 'mean', 'variance'];
  assertAndGetBroadcastShape(xShape, meanShape);
  assertAndGetBroadcastShape(xShape, varianceShape);

  var hasOffset = offsetShape != null;
  if (hasOffset) {
    assertAndGetBroadcastShape(xShape, offsetShape);
    program.variableNames.push('offset');
  }
  var hasScale = scaleShape != null;
  if (hasScale) {
    assertAndGetBroadcastShape(xShape, scaleShape);
    program.variableNames.push('scale');
  }
  var offsetSnippet = hasOffset ? 'getOffsetAtOutCoords()' : 'vec4(0.0)';
  var scaleSnippet = hasScale ? 'getScaleAtOutCoords()' : 'vec4(1.0)';

  program.outputShape = xShape;
  program.userCode = "\n void main() {\n vec4 offset = ".concat(offsetSnippet, ";\n vec4 scale = ").concat(scaleSnippet, ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(").concat(varianceEpsilon, "));\n\n setOutput((x - mean) * inv + offset);\n }\n ");
});
118039
118040 /**
118041 * @license
118042 * Copyright 2020 Google LLC. All Rights Reserved.
118043 * Licensed under the Apache License, Version 2.0 (the "License");
118044 * you may not use this file except in compliance with the License.
118045 * You may obtain a copy of the License at
118046 *
118047 * http://www.apache.org/licenses/LICENSE-2.0
118048 *
118049 * Unless required by applicable law or agreed to in writing, software
118050 * distributed under the License is distributed on an "AS IS" BASIS,
118051 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118052 * See the License for the specific language governing permissions and
118053 * limitations under the License.
118054 * =============================================================================
118055 */
/**
 * WebGL kernel for FusedBatchNorm (forward pass).
 *
 * Checks that `mean`, `variance` and, when present, `offset` and `scale` all
 * have the same rank. It then runs BatchNormPackedProgram when the
 * WEBGL_PACK_NORMALIZATION flag is on, and BatchNormProgram otherwise.
 * Program inputs are ordered [x, mean, variance, offset?, scale?]. The output
 * dtype is x's dtype.
 *
 * @param {Object} _ref - kernel arguments: `inputs` ({x, mean, variance,
 *     offset?, scale?}), `backend`, and `attrs` ({varianceEpsilon?}).
 *     `varianceEpsilon` defaults to 0.001 when null/undefined.
 * @returns {Object} the TensorInfo produced by backend.runWebGLProgram.
 * @throws (via assert$1) when the statistic tensors have mismatched ranks.
 */
var batchNorm = function batchNorm(_ref) {
  var inputs = _ref.inputs,
    backend = _ref.backend,
    attrs = _ref.attrs;
  var x = inputs.x,
    mean = inputs.mean,
    variance = inputs.variance,
    offset = inputs.offset,
    scale = inputs.scale;
  // This is the forward kernel, so the messages must not mention gradients.
  assert$1(mean.shape.length === variance.shape.length, function () {
    return 'Batch normalization requires mean and variance to have ' + 'equal ranks.';
  });
  assert$1(offset == null || mean.shape.length === offset.shape.length, function () {
    return 'Batch normalization requires mean and offset to have ' + 'equal ranks.';
  });
  assert$1(scale == null || mean.shape.length === scale.shape.length, function () {
    return 'Batch normalization requires mean and scale to have ' + 'equal ranks.';
  });
  var varianceEpsilon = attrs.varianceEpsilon;
  if (varianceEpsilon == null) {
    varianceEpsilon = 0.001;
  }
  // Optional inputs are appended in the same order the programs register
  // their samplers: offset first, then scale.
  var finalInputs = [x, mean, variance];
  var offsetShape = null;
  if (offset != null) {
    offsetShape = offset.shape;
    finalInputs.push(offset);
  }
  var scaleShape = null;
  if (scale != null) {
    scaleShape = scale.shape;
    finalInputs.push(scale);
  }
  var program = env().getBool('WEBGL_PACK_NORMALIZATION') ? new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) : new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
  return backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
};
// Kernel config pairing FusedBatchNorm with the 'webgl' backend.
var batchNormConfig = {
  kernelName: FusedBatchNorm,
  backendName: 'webgl',
  kernelFunc: batchNorm
};
118098
/**
 * Unpacked WebGL program that copies a slice of the `source` texture.
 *
 * The slice origin is supplied at run time through the `start` int uniform
 * array (one entry per dimension). The shader adds it to the output
 * coordinates and reads `source` at the resulting location.
 *
 * @param {number[]} destSize - shape of the slice (the program's output shape).
 */
var SliceProgram = /*#__PURE__*/_createClass(function SliceProgram(destSize) {
  _classCallCheck(this, SliceProgram);
  this.variableNames = ['source'];
  this.outputShape = destSize;
  this.rank = destSize.length;
  var dtype = getCoordsDataType(this.rank);
  this.customUniforms = [{
    name: 'start',
    arrayIndex: this.rank,
    type: 'int'
  }];
  var sourceCoords = getCoords$1(this.rank);
  var coordSum;
  if (this.rank === 1) {
    // For rank 1 the coordinate type is presumably a scalar int: getCoords$1
    // passes bare `sourceLoc` to getSource(), and SlicePackedProgram also
    // treats it as a scalar. GLSL ES does not allow component selection
    // (`.x`) on scalars, so use plain scalar arithmetic here.
    coordSum = ['sourceLoc = start[0] + coords;'];
  } else {
    // One assignment per axis: sourceLoc.<axis> = start[i] + coords.<axis>;
    coordSum = destSize.map(function (_, i) {
      return "sourceLoc.".concat(coords[i], " = start[").concat(i, "] + coords.").concat(coords[i], ";");
    });
  }
  var body = "\n ".concat(dtype, " sourceLoc;\n ").concat(dtype, " coords = getOutputCoords();\n ").concat(coordSum.join('\n'), "\n ");
  this.userCode = "\n void main() {\n ".concat(body, "\n setOutput(getSource(").concat(sourceCoords, "));\n }\n ");
});
// Per-dimension component names used in the generated slice shader code
// (one per dimension, up to six dimensions).
var coords = ['x', 'y', 'z', 'w', 'u', 'v'];
/**
 * Builds the argument list that SliceProgram passes to getSource(): bare
 * `sourceLoc` for rank 1, otherwise `sourceLoc.x,sourceLoc.y,...` with one
 * component per dimension.
 *
 * @param {number} rank - rank of the slice.
 * @returns {string} comma-separated GLSL argument list.
 * @throws {Error} when `rank` is not <= 6.
 */
function getCoords$1(rank) {
  if (!(rank <= 6)) {
    throw Error("Slicing for rank ".concat(rank, " is not yet supported"));
  }
  if (rank === 1) {
    return 'sourceLoc';
  }
  var components = coords.slice(0, rank).map(function (axis) {
    return 'sourceLoc.' + axis;
  });
  return components.join(',');
}
118130
/**
 * Packed WebGL program that copies a slice of the packed `source` texture.
 *
 * Each shader invocation fills one packed output texel:
 *   result.x <- source at (row,     col)
 *   result.y <- source at (row,     col + 1)  if col + 1 is inside the slice
 *   result.z <- source at (row + 1, col)      if row + 1 is inside the slice
 *   result.w <- source at (row + 1, col + 1)  if both are inside the slice
 * where row/col are the last two dimensions. Rank 1 only has a column axis,
 * so only result.x/.y are written. Unwritten channels stay 0. The slice
 * origin comes from the `start` int uniform array (one entry per dimension).
 *
 * @param {number[]} destSize - shape of the slice (the program's output shape).
 */
var SlicePackedProgram = /*#__PURE__*/_createClass(function SlicePackedProgram(destSize) {
  _classCallCheck(this, SlicePackedProgram);
  this.variableNames = ['source'];
  this.packedInputs = true;
  this.packedOutput = true;
  this.outputShape = destSize;
  this.rank = destSize.length;
  this.customUniforms = [{
    name: 'start',
    arrayIndex: this.rank,
    type: 'int'
  }];
  var dtype = getCoordsDataType(this.rank);
  // NOTE(review): getChannels presumably yields one GLSL accessor string per
  // dimension (e.g. "coords.x", ...); confirm against its definition.
  var coords = getChannels('coords', this.rank);
  var sourceLoc = getChannels('sourceLoc', this.rank);
  // Location within the packed texel: the last two source coordinates, or the
  // scalar location for rank 1. Presumably getChannel() uses it to pick the
  // texel component.
  var innerDims = this.rank === 1 ? 'sourceLoc' : "vec2(".concat(sourceLoc.slice(-2).join(), ")");
  var getChannel = "getChannel(getSource(".concat(sourceLoc.join(), "), ").concat(innerDims, ")");
  // First row: always writes result.x. It writes result.y only when the next
  // column is inside the slice, then restores sourceLoc. Note that the last
  // output coordinate stays incremented (lowerRow undoes that).
  var upperRow = "\n result.x = ".concat(getChannel, ";\n if (++").concat(coords[this.rank - 1], " < ").concat(destSize[this.rank - 1], ") {\n ++").concat(sourceLoc[this.rank - 1], ";\n result.y = ").concat(getChannel, ";\n --").concat(sourceLoc[this.rank - 1], ";\n }\n ");
  // Second row (absent for rank 1): step the column back, advance to the next
  // row, and write result.z / result.w under the same bounds checks.
  var lowerRow = this.rank === 1 ? '' : "\n --".concat(coords[this.rank - 1], ";\n if (++").concat(coords[this.rank - 2], " < ").concat(destSize[this.rank - 2], ") {\n ++").concat(sourceLoc[this.rank - 2], ";\n result.z = ").concat(getChannel, ";\n if (++").concat(coords[this.rank - 1], " < ").concat(destSize[this.rank - 1], ") {\n ++").concat(sourceLoc[this.rank - 1], ";\n result.w = ").concat(getChannel, ";\n }\n }\n ");
  // Rank <= 4 adds the start offsets as one vector expression. Higher ranks
  // assign each component separately, presumably because those coordinate
  // types do not support vector `+` (confirm against getCoordsDataType).
  var sourceLocSetup = this.rank <= 4 ? "sourceLoc = coords +\n ".concat(dtype, "(").concat(destSize.map(function (_, i) {
    return "start[".concat(i, "]");
  }).join(), ");") : destSize.map(function (_, i) {
    return "".concat(sourceLoc[i], " = ").concat(coords[i], " + start[").concat(i, "];");
  }).join('\n');
  this.userCode = "\n void main() {\n ".concat(dtype, " coords = getOutputCoords();\n ").concat(dtype, " sourceLoc;\n ").concat(sourceLocSetup, "\n vec4 result = vec4(0.);\n ").concat(upperRow, "\n ").concat(lowerRow, "\n setOutput(result);\n }\n ");
});
118157
/**
 * Slices `x` without running a shader. It creates a new tensor whose texData
 * is a shallow copy of x's texData and records a flat offset into the shared
 * data.
 *
 * The copy gets its own refCount (1), shape (`size`) and dtype. Its `slice`
 * field holds two values:
 *   - the accumulated flat offset (slicing a slice adds the parent's offset);
 *   - `origDataId`, the dataId of the tensor that owns the underlying data.
 * The count stored for that owner in backend.dataRefCount is then incremented,
 * with a missing (or zero) count treated as 1.
 *
 * @param {Object} x - TensorInfo being sliced.
 * @param {number[]} begin - start coordinates of the slice.
 * @param {number[]} size - shape of the slice.
 * @param {Object} backend - the WebGL backend.
 * @returns {Object} TensorInfo for the slice.
 */
function shallowSlice(x, begin, size, backend) {
  var texData = backend.texData;
  var sourceData = texData.get(x.dataId);
  var result = backend.makeTensorInfo(size, x.dtype);
  var resultData = texData.get(result.dataId);

  // Share everything (texture, packing, ...) with the source, then override
  // the per-tensor bookkeeping.
  Object.assign(resultData, sourceData);
  resultData.refCount = 1;
  resultData.shape = size;
  resultData.dtype = x.dtype;

  var parentSlice = sourceData.slice;
  var offset = computeFlatOffset(begin, computeStrides(x.shape));
  var rootDataId = x.dataId;
  if (parentSlice) {
    offset += parentSlice.flatOffset;
    rootDataId = parentSlice.origDataId || x.dataId;
  }
  resultData.slice = {
    flatOffset: offset,
    origDataId: rootDataId
  };

  var refCounts = backend.dataRefCount;
  refCounts.set(rootDataId, (refCounts.get(rootDataId) || 1) + 1);
  return result;
}
/**
 * WebGL kernel for Slice.
 *
 * Normalizes `begin`/`size` with parseSliceParams and validates them, then
 * picks a strategy:
 *   - empty result: allocate a tensor with no values;
 *   - CPU-resident input or string dtype: run sliceImplCPU on the raw values;
 *   - unpacked input with a contiguous slice: upload x to the GPU and alias
 *     its data via shallowSlice;
 *   - otherwise: run a slice shader, passing the begin coordinates as the
 *     program's custom uniform values.
 *
 * @param {Object} args - kernel arguments: `inputs` ({x}), `backend`, and
 *     `attrs` ({begin, size}).
 * @returns {Object} TensorInfo holding the slice.
 */
function slice(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var requestedBegin = args.attrs.begin;
  var requestedSize = args.attrs.size;
  var normalized = _slicedToArray(parseSliceParams(x, requestedBegin, requestedSize), 2);
  var $begin = normalized[0];
  var $size = normalized[1];
  assertParamsValid(x, $begin, $size);

  if (sizeFromShape($size) === 0) {
    return backend.makeTensorInfo($size, x.dtype, []);
  }

  // String tensors are stored as Uint8Array[] (one per string). The WebGL
  // backend has no way to move that representation to or from the GPU, and
  // the slice only touches the outer array anyway, so strings always take the
  // CPU path, as do tensors the backend chooses to keep on the CPU.
  var onCpu = backend.shouldExecuteOnCPU([x]) || x.dtype === 'string';
  if (onCpu) {
    var cpuValues = sliceImplCPU(backend.texData.get(x.dataId).values, $begin, $size, x.shape, x.dtype);
    return backend.makeTensorInfo($size, x.dtype, cpuValues);
  }

  var packed = backend.texData.get(x.dataId).isPacked;
  var contiguous = isSliceContinous(x.shape, $begin, $size);
  if (!packed && contiguous) {
    // A contiguous slice of unpacked data can alias the source texture.
    backend.uploadToGPU(x.dataId);
    return shallowSlice(x, $begin, $size, backend);
  }
  var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new SlicePackedProgram($size) : new SliceProgram($size);
  return backend.runWebGLProgram(program, [x], x.dtype, [$begin]);
}
// Kernel config pairing Slice with the 'webgl' backend.
var sliceConfig = {
  kernelName: Slice,
  backendName: 'webgl',
  kernelFunc: slice
};
118225
118226 /**
118227 * @license
118228 * Copyright 2020 Google LLC. All Rights Reserved.
118229 * Licensed under the Apache License, Version 2.0 (the "License");
118230 * you may not use this file except in compliance with the License.
118231 * You may obtain a copy of the License at
118232 *
118233 * http://www.apache.org/licenses/LICENSE-2.0
118234 *
118235 * Unless required by applicable law or agreed to in writing, software
118236 * distributed under the License is distributed on an "AS IS" BASIS,
118237 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118238 * See the License for the specific language governing permissions and
118239 * limitations under the License.
118240 * =============================================================================
118241 */
/**
 * WebGL kernel for BatchToSpaceND (inputs of rank <= 4 only).
 *
 * Implemented as reshape -> transpose -> reshape -> slice. The intermediate
 * shapes, the permutation and the crop window come from getReshaped,
 * getPermuted, getReshapedPermuted, getSliceBeginCoords and getSliceSize. The
 * three intermediate tensors are disposed before the sliced result is
 * returned.
 *
 * @param {Object} args - kernel arguments: `inputs` ({x}), `backend`, and
 *     `attrs` ({blockShape, crops}).
 * @returns {Object} TensorInfo produced by the final slice.
 */
var batchToSpaceND = function batchToSpaceND(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var blockShape = args.attrs.blockShape;
  var crops = args.attrs.crops;
  assert$1(x.shape.length <= 4, function () {
    return 'batchToSpaceND for rank > 4 with a WebGL backend not ' + 'implemented yet';
  });
  var blockRank = blockShape.length;
  var blockSize = blockShape.reduce(function (product, dim) {
    return product * dim;
  });
  var reshaped = getReshaped(x.shape, blockShape, blockSize);
  var permuted = getPermuted(reshaped.length, blockRank);
  var reshapedPermuted = getReshapedPermuted(x.shape, blockShape, blockSize);
  var sliceBeginCoords = getSliceBeginCoords(crops, blockRank);
  var sliceSize = getSliceSize(reshapedPermuted, crops, blockRank);

  var expanded = reshape({
    inputs: { x: x },
    backend: backend,
    attrs: { shape: reshaped }
  });
  var transposed = transpose({
    inputs: { x: expanded },
    backend: backend,
    attrs: { perm: permuted }
  });
  var collapsed = reshape({
    inputs: { x: transposed },
    backend: backend,
    attrs: { shape: reshapedPermuted }
  });
  var sliced = slice({
    inputs: { x: collapsed },
    backend: backend,
    attrs: { begin: sliceBeginCoords, size: sliceSize }
  });

  [expanded, transposed, collapsed].forEach(function (intermediate) {
    backend.disposeIntermediateTensorInfo(intermediate);
  });
  return sliced;
};
// Kernel config pairing BatchToSpaceND with the 'webgl' backend.
var batchToSpaceNDConfig = {
  kernelName: BatchToSpaceND,
  backendName: 'webgl',
  kernelFunc: batchToSpaceND
};
118311
118312 /**
118313 * @license
118314 * Copyright 2020 Google LLC. All Rights Reserved.
118315 * Licensed under the Apache License, Version 2.0 (the "License");
118316 * you may not use this file except in compliance with the License.
118317 * You may obtain a copy of the License at
118318 *
118319 * http://www.apache.org/licenses/LICENSE-2.0
118320 *
118321 * Unless required by applicable law or agreed to in writing, software
118322 * distributed under the License is distributed on an "AS IS" BASIS,
118323 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118324 * See the License for the specific language governing permissions and
118325 * limitations under the License.
118326 * =============================================================================
118327 */
/**
 * WebGL kernel for Bincount, computed on the CPU.
 *
 * Reads `x` and then `weights` back synchronously from the backend and
 * delegates to bincountImplCPU. The result is a 1-D tensor of length
 * `attrs.size` with weights' dtype.
 *
 * @param {Object} args - kernel arguments: `inputs` ({x, weights}), `backend`,
 *     and `attrs` ({size}).
 * @returns {Object} TensorInfo for the counts.
 */
function bincount(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var weights = args.inputs.weights;
  var size = args.attrs.size;
  var counts = bincountImplCPU(backend.readSync(x.dataId), backend.readSync(weights.dataId), weights.dtype, weights.shape, size);
  return backend.makeTensorInfo([size], weights.dtype, counts);
}
// Kernel config pairing Bincount with the 'webgl' backend.
var bincountConfig = {
  kernelName: Bincount,
  backendName: 'webgl',
  kernelFunc: bincount
};
118345
// Packed shader body: bitwise AND of each of the four channels of a and b.
var BITWISEAND = "\n int r = int(a.r) & int(b.r);\n int g = int(a.g) & int(b.g);\n int rb = int(a.b) & int(b.b);\n int ra = int(a.a) & int(b.a);\n return vec4(r, g, rb, ra);\n";
// Unpacked shader body: bitwise AND of a single channel.
var BITWISEAND_UNPACKED = "\n return float(int(a.r) & int(b.r));\n";
/**
 * WebGL kernel for BitwiseAnd.
 *
 * Uses bitwiseAndImplCPU when the backend chooses CPU execution for the
 * inputs or when WEBGL_VERSION is 1. The WebGL 1 case is presumably because
 * GLSL ES 1.00 has no integer bitwise operators. Otherwise it runs the packed
 * or unpacked binary shader, depending on WEBGL_PACK_BINARY_OPERATIONS. Core
 * ensures both inputs are int32, so no other dtypes are handled; the output
 * uses a's dtype.
 *
 * @param {Object} args - kernel arguments: `inputs` ({a, b}) and `backend`.
 * @returns {Object} TensorInfo holding a & b.
 */
function bitwiseAnd(args) {
  var backend = args.backend;
  var a = args.inputs.a;
  var b = args.inputs.b;
  var usePacked = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
  var glVersion = env().getNumber('WEBGL_VERSION');
  var runOnCpu = backend.shouldExecuteOnCPU([a, b]) || glVersion === 1;
  if (runOnCpu) {
    var cpuResult = _slicedToArray(bitwiseAndImplCPU(a.shape, b.shape, backend.texData.get(a.dataId).values, backend.texData.get(b.dataId).values, a.dtype), 2);
    var out = backend.makeTensorInfo(cpuResult[1], a.dtype);
    // Attach the CPU-computed values directly to the new tensor's texData.
    backend.texData.get(out.dataId).values = cpuResult[0];
    return out;
  }
  var program = usePacked ? new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false) : new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape);
  return backend.runWebGLProgram(program, [a, b], a.dtype);
}
// Kernel config pairing BitwiseAnd with the 'webgl' backend.
var bitwiseAndConfig = {
  kernelName: BitwiseAnd,
  backendName: 'webgl',
  kernelFunc: bitwiseAnd
};
118382
118383 /**
118384 * @license
118385 * Copyright 2021 Google LLC. All Rights Reserved.
118386 * Licensed under the Apache License, Version 2.0 (the "License");
118387 * you may not use this file except in compliance with the License.
118388 * You may obtain a copy of the License at
118389 *
118390 * http://www.apache.org/licenses/LICENSE-2.0
118391 *
118392 * Unless required by applicable law or agreed to in writing, software
118393 * distributed under the License is distributed on an "AS IS" BASIS,
118394 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118395 * See the License for the specific language governing permissions and
118396 * limitations under the License.
118397 * =============================================================================
118398 */
/**
 * WebGL kernel for BroadcastArgs, computed on the CPU.
 *
 * Reads the two shape tensors back synchronously, broadcasts them with
 * assertAndGetBroadcastShape, and returns the resulting shape as a 1-D int32
 * tensor.
 *
 * @param {Object} args - kernel arguments: `inputs` ({s0, s1}) and `backend`.
 * @returns {Object} TensorInfo of shape [rank] holding the broadcast shape.
 */
function broadcastArgs(args) {
  var backend = args.backend;
  var s0Vals = backend.readSync(args.inputs.s0.dataId);
  var s1Vals = backend.readSync(args.inputs.s1.dataId);
  var outShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
  var outRank = outShape.length;
  return backend.makeTensorInfo([outRank], 'int32', Int32Array.from(outShape));
}
// Kernel config pairing BroadcastArgs with the 'webgl' backend.
var broadcastArgsConfig = {
  kernelName: BroadcastArgs,
  backendName: 'webgl',
  kernelFunc: broadcastArgs
};
118414
118415 /**
118416 * @license
118417 * Copyright 2020 Google LLC. All Rights Reserved.
118418 * Licensed under the Apache License, Version 2.0 (the "License");
118419 * you may not use this file except in compliance with the License.
118420 * You may obtain a copy of the License at
118421 *
118422 * http://www.apache.org/licenses/LICENSE-2.0
118423 *
118424 * Unless required by applicable law or agreed to in writing, software
118425 * distributed under the License is distributed on an "AS IS" BASIS,
118426 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118427 * See the License for the specific language governing permissions and
118428 * limitations under the License.
118429 * =============================================================================
118430 */
// Shader snippet: 1.0 where the operands differ, 0.0 where they are equal.
var NOT_EQUAL = 'return float(a != b);';
// NotEqual kernel built with the generic binaryKernelFunc factory. It uses
// notEqualImplCPU when running on the CPU and always produces 'bool' output.
var notEqual = binaryKernelFunc({
  opSnippet: NOT_EQUAL,
  cpuKernelImpl: notEqualImplCPU,
  dtype: 'bool'
});
// Kernel config pairing NotEqual with the 'webgl' backend.
var notEqualConfig = {
  kernelName: NotEqual,
  backendName: 'webgl',
  kernelFunc: notEqual
};
118442
118443 /**
118444 * @license
118445 * Copyright 2020 Google LLC. All Rights Reserved.
118446 * Licensed under the Apache License, Version 2.0 (the "License");
118447 * you may not use this file except in compliance with the License.
118448 * You may obtain a copy of the License at
118449 *
118450 * http://www.apache.org/licenses/LICENSE-2.0
118451 *
118452 * Unless required by applicable law or agreed to in writing, software
118453 * distributed under the License is distributed on an "AS IS" BASIS,
118454 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118455 * See the License for the specific language governing permissions and
118456 * limitations under the License.
118457 * =============================================================================
118458 */
/**
 * WebGL kernel for Real: returns an identity copy of the real component that
 * the backend stores for the complex `input` (texData.complexTensorInfos.real).
 *
 * @param {Object} args - kernel arguments: `inputs` ({input}) and `backend`.
 * @returns {Object} TensorInfo produced by `identity`.
 */
function real(args) {
  var backend = args.backend;
  var components = backend.texData.get(args.inputs.input.dataId).complexTensorInfos;
  return identity({
    inputs: { x: components.real },
    backend: backend
  });
}
// Kernel config pairing Real with the 'webgl' backend.
var realConfig = {
  kernelName: Real,
  backendName: 'webgl',
  kernelFunc: real
};
118476
118477 /**
118478 * @license
118479 * Copyright 2020 Google LLC. All Rights Reserved.
118480 * Licensed under the Apache License, Version 2.0 (the "License");
118481 * you may not use this file except in compliance with the License.
118482 * You may obtain a copy of the License at
118483 *
118484 * http://www.apache.org/licenses/LICENSE-2.0
118485 *
118486 * Unless required by applicable law or agreed to in writing, software
118487 * distributed under the License is distributed on an "AS IS" BASIS,
118488 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118489 * See the License for the specific language governing permissions and
118490 * limitations under the License.
118491 * =============================================================================
118492 */
// Unary shader snippet that truncates each value to an integer.
var TO_INT = 'return float(int(x));';
/**
 * Casts `input` to int32 on the GPU with a unary TO_INT program.
 *
 * @param {Object} input - TensorInfo to convert.
 * @param {Object} backend - the WebGL backend.
 * @returns {{dataId: *, shape: number[], dtype: string}} only the dataId,
 *     shape and dtype of the program output.
 */
function int(input, backend) {
  var truncate = new UnaryOpProgram(input.shape, TO_INT);
  var result = backend.runWebGLProgram(truncate, [input], 'int32');
  var dataId = result.dataId;
  var shape = result.shape;
  var dtype = result.dtype;
  return {
    dataId: dataId,
    shape: shape,
    dtype: dtype
  };
}
118503
/**
 * WebGL kernel for Cast.
 *
 * Cases, in the order they are checked:
 *   - to complex64: identity for complex64 input. Otherwise x is cast to
 *     float32 and used as the real part, with an all-zeros imaginary part.
 *   - from complex64: the real part is taken and cast to the target dtype.
 *   - no encoding loss (hasEncodingLoss is false): the data is reused via
 *     identity and only the reported dtype changes.
 *   - CPU execution: castImplCPU on the raw values.
 *   - to int32: GPU truncation via `int`.
 *   - to bool: notEqual(x, 0) on the GPU.
 * Any other target dtype throws.
 *
 * @param {Object} args - kernel arguments: `inputs` ({x}), `backend`, and
 *     `attrs` ({dtype}).
 * @returns {Object} TensorInfo with the requested dtype.
 * @throws {Error} when no conversion path exists for the dtype pair.
 */
function cast(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var dtype = args.attrs.dtype;
  var srcDtype = x.dtype;

  if (dtype === 'complex64') {
    if (srcDtype === 'complex64') {
      return identity({ inputs: { x: x }, backend: backend });
    }
    // TODO(annxingyuan): Import kernel function once zeros is modularized.
    var imagZeros = zeros$2(x.shape);
    var realFloat = cast({ inputs: { x: x }, backend: backend, attrs: { dtype: 'float32' } });
    var complexOut = complex({ inputs: { real: realFloat, imag: imagZeros }, backend: backend });
    imagZeros.dispose();
    backend.disposeIntermediateTensorInfo(realFloat);
    return complexOut;
  }

  if (srcDtype === 'complex64') {
    var realPart = real({ inputs: { input: x }, backend: backend });
    var castReal = cast({ inputs: { x: realPart }, backend: backend, attrs: { dtype: dtype } });
    backend.disposeIntermediateTensorInfo(realPart);
    return castReal;
  }

  if (!hasEncodingLoss(srcDtype, dtype)) {
    // Casting to a type that can represent every source value: keep the
    // underlying data and only relabel the dtype.
    var alias = identity({ inputs: { x: x }, backend: backend });
    return { dataId: alias.dataId, shape: alias.shape, dtype: dtype };
  }

  if (backend.shouldExecuteOnCPU([x])) {
    var cpuCast = _slicedToArray(castImplCPU(backend.texData.get(x.dataId).values, x.shape, srcDtype, dtype), 3);
    return backend.makeTensorInfo(cpuCast[0], cpuCast[1], cpuCast[2]);
  }

  switch (dtype) {
    case 'int32':
      return int(x, backend);
    case 'bool': {
      // bool(x) is computed as x != 0 against a scalar zero.
      var zeroScalar = backend.makeTensorInfo([], 'bool', getTypedArrayFromDType('bool', 1));
      var nonZero = notEqual({ inputs: { a: x, b: zeroScalar }, backend: backend });
      backend.disposeIntermediateTensorInfo(zeroScalar);
      return nonZero;
    }
    default:
      throw new Error("Error in Cast: failed to cast ".concat(srcDtype, " to ").concat(dtype));
  }
}
// Kernel config pairing Cast with the 'webgl' backend.
var castConfig = {
  kernelName: Cast,
  backendName: 'webgl',
  kernelFunc: cast
};
118609
118610 /**
118611 * @license
118612 * Copyright 2020 Google LLC. All Rights Reserved.
118613 * Licensed under the Apache License, Version 2.0 (the "License");
118614 * you may not use this file except in compliance with the License.
118615 * You may obtain a copy of the License at
118616 *
118617 * http://www.apache.org/licenses/LICENSE-2.0
118618 *
118619 * Unless required by applicable law or agreed to in writing, software
118620 * distributed under the License is distributed on an "AS IS" BASIS,
118621 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118622 * See the License for the specific language governing permissions and
118623 * limitations under the License.
118624 * =============================================================================
118625 */
// Shader snippet shared by the packed and unpacked variants (GLSL ceil()
// accepts both scalars and vec4).
var CEIL = 'return ceil(x);';
// Ceil kernel built with the generic unaryKernelFunc factory; ceilImplCPU is
// used when running on the CPU.
var ceil = unaryKernelFunc({
  opSnippet: CEIL,
  packedOpSnippet: CEIL,
  cpuKernelImpl: ceilImplCPU
});
// Kernel config pairing Ceil with the 'webgl' backend.
var ceilConfig = {
  kernelName: Ceil,
  backendName: 'webgl',
  kernelFunc: ceil
};
118637
118638 /**
118639 * @license
118640 * Copyright 2017 Google LLC. All Rights Reserved.
118641 * Licensed under the Apache License, Version 2.0 (the "License");
118642 * you may not use this file except in compliance with the License.
118643 * You may obtain a copy of the License at
118644 *
118645 * http://www.apache.org/licenses/LICENSE-2.0
118646 *
118647 * Unless required by applicable law or agreed to in writing, software
118648 * distributed under the License is distributed on an "AS IS" BASIS,
118649 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118650 * See the License for the specific language governing permissions and
118651 * limitations under the License.
118652 * =============================================================================
118653 */
/**
 * Unpacked WebGL program that clamps every element of `A` into
 * [minVal, maxVal]. NaN inputs are written through unchanged. The bounds are
 * float uniforms (`minVal`, `maxVal`) rather than values inlined into the
 * shader source.
 *
 * @param {number[]} aShape - shape of the input; also the output shape.
 */
var ClipProgram = /*#__PURE__*/_createClass(function ClipProgram(aShape) {
  _classCallCheck(this, ClipProgram);
  this.variableNames = ['A'];
  this.customUniforms = ['minVal', 'maxVal'].map(function (boundName) {
    return {
      name: boundName,
      type: 'float'
    };
  });
  this.outputShape = aShape;
  this.userCode = "\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n ";
});
118667
118668 /**
118669 * @license
118670 * Copyright 2018 Google LLC. All Rights Reserved.
118671 * Licensed under the Apache License, Version 2.0 (the "License");
118672 * you may not use this file except in compliance with the License.
118673 * You may obtain a copy of the License at
118674 *
118675 * http://www.apache.org/licenses/LICENSE-2.0
118676 *
118677 * Unless required by applicable law or agreed to in writing, software
118678 * distributed under the License is distributed on an "AS IS" BASIS,
118679 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118680 * See the License for the specific language governing permissions and
118681 * limitations under the License.
118682 * =============================================================================
118683 */
/**
 * Packed WebGL program that clamps `A` into [minVal, maxVal] four values at a
 * time. If any of a texel's four values is NaN, the whole texel is written
 * through unclamped. The bounds are float uniforms (`minVal`, `maxVal`).
 *
 * @param {number[]} aShape - shape of the input; also the output shape.
 */
var ClipPackedProgram = /*#__PURE__*/_createClass(function ClipPackedProgram(aShape) {
  _classCallCheck(this, ClipPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  this.customUniforms = ['minVal', 'maxVal'].map(function (boundName) {
    return {
      name: boundName,
      type: 'float'
    };
  });
  this.outputShape = aShape;
  this.userCode = "\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n ";
});
118699
118700 /**
118701 * @license
118702 * Copyright 2020 Google LLC. All Rights Reserved.
118703 * Licensed under the Apache License, Version 2.0 (the "License");
118704 * you may not use this file except in compliance with the License.
118705 * You may obtain a copy of the License at
118706 *
118707 * http://www.apache.org/licenses/LICENSE-2.0
118708 *
118709 * Unless required by applicable law or agreed to in writing, software
118710 * distributed under the License is distributed on an "AS IS" BASIS,
118711 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118712 * See the License for the specific language governing permissions and
118713 * limitations under the License.
118714 * =============================================================================
118715 */
/**
 * WebGL kernel for ClipByValue.
 *
 * Chooses ClipPackedProgram when WEBGL_PACK_CLIP is enabled and ClipProgram
 * otherwise. The clip bounds are passed as the programs' custom uniform
 * values; the output keeps x's dtype.
 *
 * @param {Object} args - kernel arguments: `inputs` ({x}), `backend`, and
 *     `attrs` ({clipValueMin, clipValueMax}).
 * @returns {Object} TensorInfo of the clipped values.
 */
function clipByValue(args) {
  var x = args.inputs.x;
  var attrs = args.attrs;
  var Program = env().getBool('WEBGL_PACK_CLIP') ? ClipPackedProgram : ClipProgram;
  var program = new Program(x.shape);
  // One single-element array per uniform, in declaration order: minVal, then
  // maxVal.
  var uniformValues = [[attrs.clipValueMin], [attrs.clipValueMax]];
  return args.backend.runWebGLProgram(program, [x], x.dtype, uniformValues);
}
// Kernel config pairing ClipByValue with the 'webgl' backend.
var clipByValueConfig = {
  kernelName: ClipByValue,
  backendName: 'webgl',
  kernelFunc: clipByValue
};
118737
118738 /**
118739 * @license
118740 * Copyright 2018 Google LLC. All Rights Reserved.
118741 * Licensed under the Apache License, Version 2.0 (the "License");
118742 * you may not use this file except in compliance with the License.
118743 * You may obtain a copy of the License at
118744 *
118745 * http://www.apache.org/licenses/LICENSE-2.0
118746 *
118747 * Unless required by applicable law or agreed to in writing, software
118748 * distributed under the License is distributed on an "AS IS" BASIS,
118749 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118750 * See the License for the specific language governing permissions and
118751 * limitations under the License.
118752 * =============================================================================
118753 */
/**
 * WebGL program computing |re + i*im| from separate `real` and `imag`
 * textures.
 *
 * The magnitude is evaluated as max * length(vec2(1, min / max)), with an
 * explicit 0 when both parts are 0, instead of length(vec2(re, im)). The
 * shader comment notes that GLSL length() is not underflow-safe on some GPUs.
 *
 * @param {number[]} shape - output shape (the complex tensor's shape).
 */
var ComplexAbsProgram = /*#__PURE__*/_createClass(function ComplexAbsProgram(shape) {
  _classCallCheck(this, ComplexAbsProgram);
  var componentSamplers = ['real', 'imag'];
  this.variableNames = componentSamplers;
  this.outputShape = shape;
  this.userCode = "\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n ";
});
118760
118761 /**
118762 * @license
118763 * Copyright 2020 Google LLC. All Rights Reserved.
118764 * Licensed under the Apache License, Version 2.0 (the "License");
118765 * you may not use this file except in compliance with the License.
118766 * You may obtain a copy of the License at
118767 *
118768 * http://www.apache.org/licenses/LICENSE-2.0
118769 *
118770 * Unless required by applicable law or agreed to in writing, software
118771 * distributed under the License is distributed on an "AS IS" BASIS,
118772 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118773 * See the License for the specific language governing permissions and
118774 * limitations under the License.
118775 * =============================================================================
118776 */
118777 // Returns a TensorInfo with the complex shape and the dataId of the
118778 // underlying part. We need to do this because a reshaped complex tensor is
118779 // not reflected in its parts.
118780 function makeComplexComponentTensorInfo(complexTensor, complexPart) {
118781 return {
118782 dataId: complexPart.dataId,
118783 dtype: complexPart.dtype,
118784 shape: complexTensor.shape
118785 };
118786 }
118787 function complexAbs(args) {
118788 var inputs = args.inputs,
118789 backend = args.backend;
118790 var x = inputs.x;
118791 var xData = backend.texData.get(x.dataId);
118792 var program = new ComplexAbsProgram(x.shape);
118793 var programInputs = [makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real), makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag)];
118794 return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
118795 }
118796 var complexAbsConfig = {
118797 kernelName: ComplexAbs,
118798 backendName: 'webgl',
118799 kernelFunc: complexAbs
118800 };
118801
118802 var ConcatProgram = /*#__PURE__*/_createClass(
118803 // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().
118804 function ConcatProgram(shapes) {
118805 _classCallCheck(this, ConcatProgram);
118806 this.outputShape = [];
118807 this.outputShape = computeOutShape$1(shapes, 1 /* axis */);
118808 this.variableNames = shapes.map(function (_, i) {
118809 return "T".concat(i);
118810 });
118811 var offsets = new Array(shapes.length - 1);
118812 offsets[0] = shapes[0][1];
118813 for (var i = 1; i < offsets.length; i++) {
118814 offsets[i] = offsets[i - 1] + shapes[i][1];
118815 }
118816 var snippets = ["if (yC < ".concat(offsets[0], ") setOutput(getT0(yR, yC));")];
118817 for (var _i = 1; _i < offsets.length; _i++) {
118818 var shift = offsets[_i - 1];
118819 snippets.push("else if (yC < ".concat(offsets[_i], ") ") + "setOutput(getT".concat(_i, "(yR, yC-").concat(shift, "));"));
118820 }
118821 var lastIndex = offsets.length;
118822 var lastShift = offsets[offsets.length - 1];
118823 snippets.push("else setOutput(getT".concat(lastIndex, "(yR, yC-").concat(lastShift, "));"));
118824 this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n ".concat(snippets.join('\n '), "\n }\n ");
118825 });
118826
118827 var ConcatPackedProgram = /*#__PURE__*/_createClass(function ConcatPackedProgram(shapes, axis) {
118828 _classCallCheck(this, ConcatPackedProgram);
118829 this.packedInputs = true;
118830 this.packedOutput = true;
118831 this.outputShape = [];
118832 this.outputShape = computeOutShape$1(shapes, axis);
118833 var shape = this.outputShape;
118834 var rank = shape.length;
118835 var dtype = getCoordsDataType(rank);
118836 var coords = getChannels('coords', rank);
118837 var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
118838 this.variableNames = shapes.map(function (_, i) {
118839 return "T".concat(i);
118840 });
118841 var offsets = new Array(shapes.length - 1);
118842 offsets[0] = shapes[0][axis];
118843 for (var i = 1; i < offsets.length; i++) {
118844 offsets[i] = offsets[i - 1] + shapes[i][axis];
118845 }
118846 var channel = channels[axis];
118847 var lastChannels = channels.slice(-2);
118848 var allChannels = channels.join();
118849 var getValueSnippet = "if (".concat(channel, " < ").concat(offsets[0], ") {\n return getChannel(\n getT0(").concat(allChannels, "), vec2(").concat(lastChannels.join(), "));\n }");
118850 for (var _i = 1; _i < offsets.length; _i++) {
118851 var _shift = offsets[_i - 1];
118852 // Note: the >= comparison below may seem unnecessary given the check
118853 // above but is needed to workaround branch execution issues on some
118854 // devices. It makes all the conditions exclusive without relying on
118855 // execution order.
118856 getValueSnippet += "\n if (".concat(channel, " < ").concat(offsets[_i], " && ").concat(channel, " >= ").concat(offsets[_i - 1], ") {\n return getChannel(\n getT").concat(_i, "(").concat(shiftedChannels(channels, channel, _shift), "),\n vec2(").concat(shiftedChannels(lastChannels, channel, _shift), "));\n }");
118857 }
118858 var lastIndex = offsets.length;
118859 var shift = offsets[offsets.length - 1];
118860 getValueSnippet += "\n return getChannel(\n getT".concat(lastIndex, "(").concat(shiftedChannels(channels, channel, shift), "),\n vec2(").concat(shiftedChannels(lastChannels, channel, shift), "));");
118861 this.userCode = "\n float getValue(".concat(channels.map(function (x) {
118862 return 'int ' + x;
118863 }), ") {\n ").concat(getValueSnippet, "\n }\n\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n vec4 result = vec4(getValue(").concat(coords, "), 0., 0., 0.);\n\n ").concat(coords[rank - 1], " = ").concat(coords[rank - 1], " + 1;\n if (").concat(coords[rank - 1], " < ").concat(shape[rank - 1], ") {\n result.g = getValue(").concat(coords, ");\n }\n\n ").concat(coords[rank - 2], " = ").concat(coords[rank - 2], " + 1;\n if (").concat(coords[rank - 2], " < ").concat(shape[rank - 2], ") {\n result.a = getValue(").concat(coords, ");\n }\n\n ").concat(coords[rank - 1], " = ").concat(coords[rank - 1], " - 1;\n if (").concat(coords[rank - 2], " < ").concat(shape[rank - 2], " &&\n ").concat(coords[rank - 1], " < ").concat(shape[rank - 1], ") {\n result.b = getValue(").concat(coords, ");\n }\n setOutput(result);\n }\n ");
118864 });
118865 /**
118866 * Return an expression for coordinates into a vector where a given channel
118867 * will be offset by [shift].
118868 *
118869 * @param channels the channels to consider
118870 * @param channel the channel we want shifted
118871 * @param shift the amount to subtract from the channel.
118872 *
118873 * @returns a string of the form 'x, y-[shift], z' where any one channel can
118874 * have the shift applied.
118875 */
118876 function shiftedChannels(channels, channel, shift) {
118877 var channelIdx = channels.indexOf(channel);
118878 var res = channels.map(function (c, idx) {
118879 if (idx === channelIdx) {
118880 return "".concat(c, " - ").concat(shift);
118881 } else {
118882 return c;
118883 }
118884 });
118885 return res.join();
118886 }
118887
118888 /**
118889 * @license
118890 * Copyright 2020 Google LLC. All Rights Reserved.
118891 * Licensed under the Apache License, Version 2.0 (the "License");
118892 * you may not use this file except in compliance with the License.
118893 * You may obtain a copy of the License at
118894 *
118895 * http://www.apache.org/licenses/LICENSE-2.0
118896 *
118897 * Unless required by applicable law or agreed to in writing, software
118898 * distributed under the License is distributed on an "AS IS" BASIS,
118899 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118900 * See the License for the specific language governing permissions and
118901 * limitations under the License.
118902 * =============================================================================
118903 */
118904 function imag(args) {
118905 var inputs = args.inputs,
118906 backend = args.backend;
118907 var input = inputs.input;
118908 var inputData = backend.texData.get(input.dataId);
118909 return identity({
118910 inputs: {
118911 x: inputData.complexTensorInfos.imag
118912 },
118913 backend: backend
118914 });
118915 }
118916 var imagConfig = {
118917 kernelName: Imag,
118918 backendName: 'webgl',
118919 kernelFunc: imag
118920 };
118921
118922 /**
118923 * @license
118924 * Copyright 2020 Google LLC. All Rights Reserved.
118925 * Licensed under the Apache License, Version 2.0 (the "License");
118926 * you may not use this file except in compliance with the License.
118927 * You may obtain a copy of the License at
118928 *
118929 * http://www.apache.org/licenses/LICENSE-2.0
118930 *
118931 * Unless required by applicable law or agreed to in writing, software
118932 * distributed under the License is distributed on an "AS IS" BASIS,
118933 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
118934 * See the License for the specific language governing permissions and
118935 * limitations under the License.
118936 * =============================================================================
118937 */
118938 function concatImpl(inputs, axis, backend) {
118939 var dtype = inputs[0].dtype;
118940 if (dtype === 'complex64') {
118941 var reals = inputs.map(function (t) {
118942 return real({
118943 inputs: {
118944 input: t
118945 },
118946 backend: backend
118947 });
118948 });
118949 var imags = inputs.map(function (t) {
118950 return imag({
118951 inputs: {
118952 input: t
118953 },
118954 backend: backend
118955 });
118956 });
118957 var realConcated = concatImpl(reals, axis, backend);
118958 var imagConcated = concatImpl(imags, axis, backend);
118959 var _result = complex({
118960 inputs: {
118961 real: realConcated,
118962 imag: imagConcated
118963 },
118964 backend: backend
118965 });
118966 reals.forEach(function (r) {
118967 return backend.disposeIntermediateTensorInfo(r);
118968 });
118969 imags.forEach(function (i) {
118970 return backend.disposeIntermediateTensorInfo(i);
118971 });
118972 backend.disposeIntermediateTensorInfo(realConcated);
118973 backend.disposeIntermediateTensorInfo(imagConcated);
118974 return _result;
118975 }
118976 var runOnCpu = backend.shouldExecuteOnCPU(inputs);
118977 // Run on cpu if dtype is string. For string, the backend represents it
118978 // as Uint8Array[], where each Uint8Array is a character. Given that the
118979 // computation is only on the outer array, uploading the whole data onto
118980 // gpu is wasteful. Also, currently webgl doesn't have a design to
118981 // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
118982 // just run the kernel on cpu if dtype is string.
118983 if (dtype === 'string') {
118984 runOnCpu = true;
118985 }
118986 if (runOnCpu) {
118987 // Any concat of n-dimensional tensors across any axis can be reduced to
118988 // a concatenation of two-dimensional tensors across the axis 1 by first
118989 // partitioning the axes of the original tensors into those less than the
118990 // axis to be concatenated and the rest. Then reshape the tensors
118991 // into a two-dimensional tensor by collapsing these two sets of axes and
118992 // concatenate the resulting matrices across the axis 1, finally reshaping
118993 // the result to have the proper shape.
118994 var _tensors2D = inputs.map(function (t) {
118995 var innerSize = sizeFromShape(t.shape.slice(axis));
118996 var shape = [-1, innerSize];
118997 return reshape({
118998 inputs: {
118999 x: t
119000 },
119001 backend: backend,
119002 attrs: {
119003 shape: shape
119004 }
119005 });
119006 });
119007 var inputsValShapes = _tensors2D.map(function (t) {
119008 return {
119009 vals: backend.readSync(t.dataId),
119010 shape: t.shape
119011 };
119012 });
119013 // Concats 2d tensors along axis=1.
119014 var _outShape = computeOutShape$1(_tensors2D.map(function (t) {
119015 return t.shape;
119016 }), 1 /* axis */);
119017 var simplyConcat = _tensors2D[0].shape[0] === 1;
119018 var outVals = concatImplCPU(inputsValShapes, _outShape, dtype, simplyConcat);
119019 var finalOutShape = computeOutShape$1(inputs.map(function (t) {
119020 return t.shape;
119021 }), axis);
119022 var outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals);
119023 _tensors2D.forEach(function (t) {
119024 return backend.disposeIntermediateTensorInfo(t);
119025 });
119026 return outInfo;
119027 }
119028 // Keep only non-empty tensors (ignore tensors with 0 in their shape).
119029 var $inputs = inputs.filter(function (t) {
119030 return sizeFromShape(t.shape) > 0;
119031 });
119032 var shouldPack = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && $inputs[0].shape.length > 1;
119033 if ($inputs.length === 1) {
119034 // Clone tensor.
119035 var _program = shouldPack ? new UnaryOpProgram(inputs[0].shape, CLONE) : new UnaryOpPackedProgram(inputs[0].shape, CLONE);
119036 return backend.runWebGLProgram(_program, inputs, dtype);
119037 }
119038 var maxTexturesInShader = env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER');
119039 if ($inputs.length > maxTexturesInShader) {
119040 var reducedInputs = [];
119041 for (var i = 0; i < $inputs.length; i += maxTexturesInShader) {
119042 var subArray = $inputs.slice(i, i + maxTexturesInShader);
119043 reducedInputs.push(concatImpl(subArray, axis, backend));
119044 }
119045 var _result2 = concatImpl(reducedInputs, axis, backend);
119046 for (var _i = 0, _reducedInputs = reducedInputs; _i < _reducedInputs.length; _i++) {
119047 var _i2 = _reducedInputs[_i];
119048 backend.disposeIntermediateTensorInfo(_i2);
119049 }
119050 return _result2;
119051 }
119052 if (shouldPack) {
119053 var _program2 = new ConcatPackedProgram($inputs.map(function (t) {
119054 return t.shape;
119055 }), axis);
119056 return backend.runWebGLProgram(_program2, $inputs, dtype);
119057 }
119058 var _computeTensors2D = computeTensors2D($inputs, axis, backend),
119059 tensors2D = _computeTensors2D.tensors2D,
119060 outShape = _computeTensors2D.outShape;
119061 var program = new ConcatProgram(tensors2D.map(function (t) {
119062 return t.shape;
119063 }));
119064 var result = backend.runWebGLProgram(program, tensors2D, dtype);
119065 tensors2D.forEach(function (r) {
119066 return backend.disposeIntermediateTensorInfo(r);
119067 });
119068 var reshapedResult = reshape({
119069 inputs: {
119070 x: result
119071 },
119072 attrs: {
119073 shape: outShape
119074 },
119075 backend: backend
119076 });
119077 backend.disposeIntermediateTensorInfo(result);
119078 return reshapedResult;
119079 }
119080 function computeTensors2D(inputs, axis, backend) {
119081 // Any concat of n-dimensional tensors across any axis can be reduced to
119082 // a concatenation of two-dimensional tensors across the axis 1 by first
119083 // partitioning the axes of the original tensors into those less than the
119084 // axis to be concatenated and the rest. Then reshape the tensors
119085 // into a two-dimensional tensor by collapsing these two sets of axes and
119086 // concatenate the resulting matrices across the axis 1, finally reshaping
119087 // the result to have the proper shape.
119088 var outShape = computeOutShape$1(inputs.map(function (t) {
119089 return t.shape;
119090 }), axis);
119091 var tensors2D = inputs.map(function (x) {
119092 return reshape({
119093 inputs: {
119094 x: x
119095 },
119096 attrs: {
119097 shape: [-1, sizeFromShape(x.shape.slice(axis))]
119098 },
119099 backend: backend
119100 });
119101 });
119102 return {
119103 tensors2D: tensors2D,
119104 outShape: outShape
119105 };
119106 }
119107
119108 /**
119109 * @license
119110 * Copyright 2020 Google LLC. All Rights Reserved.
119111 * Licensed under the Apache License, Version 2.0 (the "License");
119112 * you may not use this file except in compliance with the License.
119113 * You may obtain a copy of the License at
119114 *
119115 * http://www.apache.org/licenses/LICENSE-2.0
119116 *
119117 * Unless required by applicable law or agreed to in writing, software
119118 * distributed under the License is distributed on an "AS IS" BASIS,
119119 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
119120 * See the License for the specific language governing permissions and
119121 * limitations under the License.
119122 * =============================================================================
119123 */
119124 function concat(args) {
119125 var inputs = args.inputs,
119126 backend = args.backend,
119127 attrs = args.attrs;
119128 var axis = attrs.axis;
119129 var $axis = parseAxisParam(axis, inputs[0].shape)[0];
119130 var shapes = inputs.map(function (t) {
119131 return t.shape;
119132 });
119133 assertParamsConsistent(shapes, $axis);
119134 var outShape = computeOutShape$1(inputs.map(function (t) {
119135 return t.shape;
119136 }), $axis);
119137 if (sizeFromShape(outShape) === 0) {
119138 return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
119139 }
119140 // Keep only non-empty tensors (ignore tensors with 0 in their shape).
119141 var $inputs = inputs.filter(function (t) {
119142 return sizeFromShape(t.shape) > 0;
119143 });
119144 if ($inputs.length === 1) {
119145 return identity({
119146 inputs: {
119147 x: $inputs[0]
119148 },
119149 backend: backend
119150 });
119151 }
119152 return concatImpl($inputs, $axis, backend);
119153 }
119154 var concatConfig = {
119155 kernelName: Concat,
119156 backendName: 'webgl',
119157 kernelFunc: concat
119158 };
119159
119160 /**
119161 * @license
119162 * Copyright 2017 Google LLC. All Rights Reserved.
119163 * Licensed under the Apache License, Version 2.0 (the "License");
119164 * you may not use this file except in compliance with the License.
119165 * You may obtain a copy of the License at
119166 *
119167 * http://www.apache.org/licenses/LICENSE-2.0
119168 *
119169 * Unless required by applicable law or agreed to in writing, software
119170 * distributed under the License is distributed on an "AS IS" BASIS,
119171 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
119172 * See the License for the specific language governing permissions and
119173 * limitations under the License.
119174 * =============================================================================
119175 */
119176 var Conv2DProgram = /*#__PURE__*/_createClass(function Conv2DProgram(convInfo) {
119177 var addBias = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
119178 var activation = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
119179 var hasPreluActivationWeights = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
119180 var hasLeakyreluAlpha = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
119181 _classCallCheck(this, Conv2DProgram);
119182 this.variableNames = ['x', 'W'];
119183 this.outputShape = convInfo.outShape;
119184 var padTop = convInfo.padInfo.top;
119185 var padLeft = convInfo.padInfo.left;
119186 var strideHeight = convInfo.strideHeight;
119187 var strideWidth = convInfo.strideWidth;
119188 var dilationHeight = convInfo.dilationHeight;
119189 var dilationWidth = convInfo.dilationWidth;
119190 var filterHeight = convInfo.filterHeight;
119191 var filterWidth = convInfo.filterWidth;
119192 var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
119193 var inputDepthVec4Remainder = convInfo.inChannels % 4;
119194 var isChannelsLast = convInfo.dataFormat === 'channelsLast';
119195 var rowDim = isChannelsLast ? 1 : 2;
119196 var colDim = isChannelsLast ? 2 : 3;
119197 var channelDim = isChannelsLast ? 3 : 1;
119198 var activationSnippet = '',
119199 applyActivationSnippet = '';
119200 if (activation) {
119201 if (hasPreluActivationWeights) {
119202 activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
119203 } else if (hasLeakyreluAlpha) {
119204 activationSnippet = "float activation(float a) {\n float b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
119205 } else {
119206 activationSnippet = "\n float activation(float x) {\n ".concat(activation, "\n }\n ");
119207 }
119208 applyActivationSnippet = "result = activation(result);";
119209 }
119210 var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
119211 if (addBias) {
119212 this.variableNames.push('bias');
119213 }
119214 if (hasPreluActivationWeights) {
119215 this.variableNames.push('preluActivationWeights');
119216 }
119217 if (hasLeakyreluAlpha) {
119218 this.variableNames.push('leakyreluAlpha');
119219 }
119220 this.userCode = "\n ".concat(activationSnippet, "\n\n const ivec2 strides = ivec2(").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[").concat(channelDim, "];\n\n ivec2 xRCCorner =\n ivec2(coords[").concat(rowDim, "], coords[").concat(colDim, "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n int xR = xRCorner + wR * ").concat(dilationHeight, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n for (int d1 = 0; d1 < ").concat(inputDepthNearestVec4, "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (").concat(isChannelsLast, ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (").concat(inputDepthVec4Remainder === 1, ") {\n\n if (").concat(isChannelsLast, ") {\n dotProd +=\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, ") *\n getW(wR, wC, ").concat(inputDepthNearestVec4, ", d2);\n } else {\n dotProd +=\n getX(batch, ").concat(inputDepthNearestVec4, ", xR, xC) *\n getW(wR, wC, 
").concat(inputDepthNearestVec4, ", d2);\n }\n\n } else if (").concat(inputDepthVec4Remainder === 2, ") {\n vec2 wValues = vec2(\n getW(wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2)\n );\n\n if (").concat(isChannelsLast, ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, "),\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, ").concat(inputDepthNearestVec4, ", xR, xC),\n getX(batch, ").concat(inputDepthNearestVec4, " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (").concat(inputDepthVec4Remainder === 3, ") {\n vec3 wValues = vec3(\n getW(wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2),\n getW(wR, wC, ").concat(inputDepthNearestVec4, " + 2, d2)\n );\n\n if (").concat(isChannelsLast, ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, "),\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, " + 1),\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, ").concat(inputDepthNearestVec4, ", xR, xC),\n getX(batch, ").concat(inputDepthNearestVec4, " + 1, xR, xC),\n getX(batch, ").concat(inputDepthNearestVec4, " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n ");
119221 });
119222 var Conv3DProgram = /*#__PURE__*/_createClass(function Conv3DProgram(convInfo) {
119223 _classCallCheck(this, Conv3DProgram);
119224 this.variableNames = ['x', 'W'];
119225 this.outputShape = convInfo.outShape;
119226 var padFront = convInfo.padInfo.front;
119227 var padTop = convInfo.padInfo.top;
119228 var padLeft = convInfo.padInfo.left;
119229 var strideDepth = convInfo.strideDepth;
119230 var strideHeight = convInfo.strideHeight;
119231 var strideWidth = convInfo.strideWidth;
119232 var dilationDepth = convInfo.dilationDepth;
119233 var dilationHeight = convInfo.dilationHeight;
119234 var dilationWidth = convInfo.dilationWidth;
119235 var filterDepth = convInfo.filterDepth;
119236 var filterHeight = convInfo.filterHeight;
119237 var filterWidth = convInfo.filterWidth;
119238 var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
119239 var inputDepthVec4Remainder = convInfo.inChannels % 4;
119240 this.userCode = "\n const ivec3 strides = ivec3(".concat(strideDepth, ", ").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec3 pads = ivec3(").concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < ").concat(filterDepth, "; wF++) {\n int xF = xFCorner + wF * ").concat(dilationDepth, ";\n\n if (xF < 0 || xF >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n int xR = xRCorner + wR * ").concat(dilationHeight, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n for (int d1 = 0; d1 < ").concat(inputDepthNearestVec4, "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (").concat(inputDepthVec4Remainder === 1, ") {\n dotProd +=\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, ") *\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, ", d2);\n } else if (").concat(inputDepthVec4Remainder === 2, ") {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, 
"),\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, " + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (").concat(inputDepthVec4Remainder === 3, ") {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, "),\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, " + 1),\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, " + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2),\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ");
119241 });
119242
119243 var Conv2DPackedProgram = /*#__PURE__*/_createClass(function Conv2DPackedProgram(convInfo) {
119244 var addBias = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
119245 var activation = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
119246 var hasPreluActivation = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
119247 var hasLeakyReluAlpha = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
119248 _classCallCheck(this, Conv2DPackedProgram);
119249 this.variableNames = ['x', 'W'];
119250 this.packedInputs = true;
119251 this.packedOutput = true;
119252 this.customUniforms = [{
119253 name: 'pads',
119254 type: 'ivec2'
119255 }, {
119256 name: 'strides',
119257 type: 'ivec2'
119258 }, {
119259 name: 'dilations',
119260 type: 'ivec2'
119261 }, {
119262 name: 'inDims',
119263 type: 'ivec2'
119264 }];
119265 this.outputShape = convInfo.outShape;
119266 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
119267 var padLeft = convInfo.padInfo.left;
119268 var strideWidth = convInfo.strideWidth;
119269 var dilationWidth = convInfo.dilationWidth;
119270 var filterHeight = convInfo.filterHeight;
119271 var filterWidth = convInfo.filterWidth;
119272 var texelsAcross = filterWidth;
119273 var mainLoop = "\n int xR; int xC; int xCOffset;\n vec4 wTexel; vec4 previous; vec4 final;";
119274 for (var c = 0; c < filterWidth; c++) {
119275 mainLoop += "\n vec4 xTexelC".concat(c * 2, ";\n int xTexelC").concat(c * 2, "Ready;\n vec4 xTexelC").concat(c * 2 + 1, ";\n int xTexelC").concat(c * 2 + 1, "Ready;\n vec4 xC").concat(c, ";");
119276 }
119277 /**
119278 * This vectorized implementation works by gathering the values needed for
119279 * each output channel's dot product into vec4's and then multiplying them
119280 * all together (this happens in the final double for-loop below). Most of
119281 * the main loop consists of constructing these vec4's with the minimum
119282 * number of texture2D calls, which means making use of all four returned
119283 * values from a texture2D call at once.
119284 */
119285 mainLoop += "\n for (int r = 0; r < ".concat(filterHeight, "; r++) {\n for (int d1 = 0; d1 < ").concat(convInfo.inChannels, "; d1 += 2) {\n ");
119286 for (var _c = 0; _c < filterWidth; _c++) {
119287 mainLoop += "\n xTexelC".concat(_c * 2, " = vec4(0.0);\n xTexelC").concat(_c * 2, "Ready = 0;\n xTexelC").concat(_c * 2 + 1, " = vec4(0.0);\n xTexelC").concat(_c * 2 + 1, "Ready = 0;\n xC").concat(_c, " = vec4(0.0);");
119288 }
119289 mainLoop += "\n xR = xRCorner + r * dilations[0];\n if (xR >=0 && xR < inDims[0]) {\n ";
119290 for (var texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
119291 var colIndex = texelC * 2;
119292 mainLoop += "\n xC = xCCorner + ".concat(colIndex * dilationWidth, ";\n ");
119293 if (strideWidth === 1) {
119294 if (colIndex < filterWidth) {
119295 // If padding is odd, the outer texels have to be composed.
119296 if (padLeft % 2 === 1) {
119297 // TODO: Ensure vec4 previous does not result in redundant sample,
119298 // and avoid setting xTexelRC's that exceed the boundary in the
119299 // first place rather than resetting them to vec4(0)).
119300 // To compute xCOffset:
119301 // - If padding is odd, we must add 1 to ensure we ask for an
119302 // even-numbered row.
119303 // - We subtract 2 to access the previous texel.
119304 mainLoop += "\n xCOffset = xC + 1;\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n ");
119305 // This texel has been read in previous iteration if the dilation
119306 // is 1.
119307 if (dilationWidth === 1 && colIndex > 0) {
119308 mainLoop += "\n xC".concat(colIndex, " = vec4(xTexelC").concat(colIndex - 2, ".zw, xTexelC").concat(colIndex, ".xy);\n ");
119309 } else {
119310 mainLoop += "\n xCOffset = xC + 1 - 2;\n\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n previous.zw = vec2(0.0);\n }\n\n xC".concat(colIndex, " = vec4(previous.zw, xTexelC").concat(colIndex, ".xy);\n } else {\n xC").concat(colIndex, " = vec4(0.0, 0.0, xTexelC").concat(colIndex, ".xy);\n }\n ");
119311 }
119312 } else {
119313 // Padding is even, so xRC corresponds to a single texel.
119314 mainLoop += "\n if (xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xC").concat(colIndex, " = xTexelC").concat(colIndex, ";\n ");
119315 }
119316 if (colIndex + 1 < filterWidth) {
119317 // If dilation is even, the second entry should match the first
119318 // (either both are composed or both are single samples). But if
119319 // dilation is odd, then the second entry should be the opposite
119320 // of the first (if the first is composed, the second is a single
119321 // sample, and vice versa.)
119322 var nextTexelOffset = padLeft % 2 === 0 ? nearestLargerEven(dilationWidth) : dilationWidth;
119323 if (dilationWidth % 2 === 0 && padLeft % 2 === 1 || dilationWidth % 2 !== 0 && padLeft % 2 !== 1) {
119324 mainLoop += "\n xCOffset = xC + imod(pads[1], 2) + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n ");
119325 // If dilation > 1 then the xRC's will not be able to share any
119326 // values, so each xRC will require two unique calls to getX.
119327 if (dilationWidth > 1) {
119328 mainLoop += "\n xCOffset -= 2;\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n xC".concat(colIndex + 1, " = vec4(previous.zw, xTexelC").concat(colIndex + 1, ".xy);\n } else {\n xC").concat(colIndex + 1, " = vec4(0.0, 0.0, xTexelC").concat(colIndex + 1, ".xy);\n }\n ");
119329 } else {
119330 mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".xy);\n ");
119331 }
119332 } else {
119333 // If dilation is 1 and padding is odd, we have already read the
119334 // texel when constructing the previous x value. Here we can
119335 // simply skip the texture read.
119336 if (nextTexelOffset === 1) {
119337 mainLoop += "\n xC".concat(colIndex + 1, " = xTexelC").concat(colIndex, ";\n ");
119338 } else {
119339 mainLoop += "\n xCOffset = xC + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex + 1, " = xTexelC").concat(colIndex + 1, ";\n ");
119340 }
119341 }
119342 }
119343 }
119344 } else {
119345 // stride === 2
119346 if (colIndex < filterWidth) {
119347 // Depending on whether padLeft is even or odd, we want either the
119348 // xy or zw channels from X texels for xC${colIndex}. If padLeft is
119349 // even, xC${colIndex +1} is simply the zw channels of texels we've
119350 // already sampled. But if padLeft is odd, xC{$c + 1}.zw will
119351 // need to come from the xy channels of a new texel, hence the `
119352 // vec4
119353 // final` initialized below.
119354 if (padLeft % 2 === 1) {
119355 mainLoop += "\n xCOffset = xC + 1 - strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xC + 1, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xC + 2 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n ");
119356 if (colIndex + 1 < filterWidth) {
119357 mainLoop += "\n final = vec4(0.0);\n xCOffset = xC + 1 + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1]) {\n final = getX(batch, xR, xCOffset, d1);\n }\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex + 1, ".xy, final.xy);\n ");
119358 }
119359 } else {
119360 mainLoop += "\n if(xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xCOffset = xC + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(\n xTexelC").concat(colIndex, ".xy, xTexelC").concat(colIndex + 1, ".xy);\n ");
119361 if (colIndex + 1 < filterWidth) {
119362 mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n ");
119363 }
119364 }
119365 }
119366 }
119367 // localize the dotProd accumulation within the loop, the theory is for
119368 // GPU with limited cache, accumulate sum across large amount of
119369 // veriables will cause lots of cache misses. (i.e. 5x5 filter will have
119370 // 50 variables)
119371 if (colIndex < filterWidth) {
119372 mainLoop += "\n wTexel = getW(r, ".concat(colIndex, ", d1, d2);\n dotProd += xC").concat(colIndex, ".xxzz * vec4(wTexel.xy, wTexel.xy);\n if(d1 + 1 < ").concat(convInfo.inChannels, ") {\n dotProd += xC").concat(colIndex, ".yyww * vec4(wTexel.zw, wTexel.zw);\n }\n ");
119373 if (colIndex + 1 < filterWidth) {
119374 mainLoop += "\n wTexel = getW(r, ".concat(colIndex + 1, ", d1, d2);\n dotProd += xC").concat(colIndex + 1, ".xxzz * vec4(wTexel.xy, wTexel.xy);\n if(d1 + 1 < ").concat(convInfo.inChannels, ") {\n dotProd += xC").concat(colIndex + 1, ".yyww * vec4(wTexel.zw, wTexel.zw);\n }\n ");
119375 }
119376 }
119377 }
119378 mainLoop += "\n }\n ";
119379 mainLoop += "\n }\n ";
119380 mainLoop += "\n }\n ";
119381 var activationSnippet = '',
119382 applyActivationSnippet = '';
119383 if (activation) {
119384 if (hasPreluActivation) {
119385 activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
119386 } else if (hasLeakyReluAlpha) {
119387 activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
119388 } else {
119389 activationSnippet = "vec4 activation(vec4 x) {\n ".concat(activation, "\n }");
119390 }
119391 applyActivationSnippet = "result = activation(result);";
119392 }
119393 var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
119394 if (addBias) {
119395 this.variableNames.push('bias');
119396 }
119397 if (hasPreluActivation) {
119398 this.variableNames.push('preluActivationWeights');
119399 }
119400 if (hasLeakyReluAlpha) {
119401 this.variableNames.push('leakyreluAlpha');
119402 }
119403 this.userCode = "\n ".concat(activationSnippet, "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.\n vec4 dotProd = vec4(0.000000000000001);\n\n ").concat(mainLoop, "\n\n vec4 result = dotProd - vec4(0.000000000000001);\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n ");
119404 });
119405
/**
 * Packed WebGL program performing the im2col rearrangement used by
 * conv2dWithIm2Row.
 *
 * The output is a rank-3 packed texture addressed as (rc.x, pos, blockIndex):
 *  - rc.x is the batch index (it is passed as the first coordinate to getA);
 *  - blockIndex selects an output spatial position, decoded row-major with the
 *    `outWidth` uniform (offsetY from blockIndex / outWidth, offsetX from
 *    imod(blockIndex, outWidth));
 *  - pos selects one receptive-field entry, decoded with the
 *    `itemsPerBlockRow` and `inChannels` uniforms into a filter row
 *    (pos / itemsPerBlockRow), a filter column
 *    (imod(pos, itemsPerBlockRow) / inChannels) and a channel
 *    (imod(pos, inChannels)).
 * Entries outside the output bounds, or that land outside the input (the
 * padded border), are left at 0 because `result` starts as vec4(0).
 *
 * @param outputShape Shape of the im2col output. conv2dWithIm2Row passes
 *     [batchSize, filterWidth * filterHeight * inChannels,
 *      outHeight * outWidth].
 * @param convInfo Conv2D info. Only `dataFormat` is read here. Input shape,
 *     pads, strides, dilations, channel counts and outWidth are supplied at
 *     run time through the custom uniforms declared below, in that order.
 */
var Im2ColPackedProgram = /*#__PURE__*/_createClass(function Im2ColPackedProgram(outputShape, convInfo) {
  _classCallCheck(this, Im2ColPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  // Run-time uniforms; the caller must provide values in this exact order.
  this.customUniforms = [{
    name: 'inputShape',
    type: 'ivec4'
  }, {
    name: 'pad',
    type: 'ivec2'
  }, {
    name: 'stride',
    type: 'ivec2'
  }, {
    name: 'dilation',
    type: 'ivec2'
  }, {
    name: 'inChannels',
    type: 'int'
  }, {
    name: 'itemsPerBlockRow',
    type: 'int'
  }, {
    name: 'outWidth',
    type: 'int'
  }];
  this.outputShape = outputShape;
  this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
  var dataFormat = convInfo.dataFormat;
  var glsl = getGlslDifferences();
  var isChannelsLast = dataFormat === 'channelsLast';
  // Indices into `inputShape` holding the input height and width for the
  // current data format.
  var rowDim = isChannelsLast ? 1 : 2;
  var colDim = isChannelsLast ? 2 : 3;
  // Output bounds come from the outShape uniform when shape uniforms are
  // enabled; otherwise they are baked into the shader as constants.
  var boundsCheckingSnippet = this.enableShapeUniforms ? 'if(blockIndex < outShape[2] && pos < outShape[1]) {' : "if(blockIndex < ".concat(outputShape[2], " && pos < ").concat(outputShape[1], ") {");
  // Each shader invocation fills the 2x2 packed patch anchored at
  // (rc.y, rc.z): result[row * 2 + col] holds element
  // (pos = rc.y + row, blockIndex = rc.z + col). The data-format branch is
  // emitted as a literal `if (true)` / `if (false)` in GLSL.
  var unrolled = "";
  for (var row = 0; row <= 1; row++) {
    for (var col = 0; col <= 1; col++) {
      unrolled += "\n blockIndex = rc.z + ".concat(col, ";\n pos = rc.y + ").concat(row, ";\n\n ").concat(boundsCheckingSnippet, "\n offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];\n d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);\n\n if(d0 < inputShape[").concat(rowDim, "] && d0 >= 0) {\n // Use custom imod instead mod. On Intel GPU, mod may generate\n // unexpected value.\n // https://github.com/tensorflow/tfjs/issues/5447\n offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];\n d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /\n inChannels);\n\n if(d1 < inputShape[").concat(colDim, "] && d1 >= 0) {\n\n ch = imod(pos, inChannels);\n\n if (").concat(isChannelsLast, ") {\n innerDims = vec2(d1, ch);\n result[").concat(row * 2 + col, "] = getChannel(\n getA(rc.x, d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[").concat(row * 2 + col, "] = getChannel(\n getA(rc.x, ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n ");
    }
  }
  this.userCode = "\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n ".concat(unrolled, "\n\n ").concat(glsl.output, " = result;\n }\n ");
});
119449
// conv2dByMatMul and conv2dWithIm2Row both collapse the height and width
// dimensions into a single dimension before calling batchMatMul, so bias and
// activation-weight tensors must be collapsed the same way to stay
// broadcast-compatible with the product.
//
// getShapeForBatchMatMul(shape, isChannelsLast) returns the collapsed target
// shape, or null when `shape` can be used unchanged.
//
// A bias is not expected to be 3-D or 4-D (including batch), and PReLU
// activation weights are not expected to be 4-D. Both are still accepted
// because they have not been disallowed for the NHWC format.
// https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/conv2d.ts#L181-L196
function getShapeForBatchMatMul(shape, isChannelsLast) {
  var rank = shape.length;
  if (rank < 3) {
    // Only a channelsFirst rank-1 tensor holding more than one value needs
    // reshaping; it becomes a [channels, 1] column.
    var isChannelVector = !isChannelsLast && rank === 1 && shape[0] > 1;
    return isChannelVector ? [shape[0], 1] : null;
  }
  // Leading (batch-like) dimensions are carried over untouched.
  var fused = [];
  for (var i = 0; i < rank - 3; i++) {
    fused.push(shape[i]);
  }
  var first = shape[rank - 3];
  var second = shape[rank - 2];
  var third = shape[rank - 1];
  if (isChannelsLast) {
    fused.push(first * second /* height * width */, third /* channel */);
  } else {
    fused.push(first /* channel */, second * third /* height * width */);
  }
  return fused;
}
// For 1x1 kernels that iterate through every point in the input, convolution
// can be expressed as matrix multiplication (without need for memory
// remapping).
//
// Arguments (one object): x, filter, convInfo and backend are required;
// bias, preluActivationWeights and activation default to null and
// leakyreluAlpha defaults to 0. Returns a TensorInfo whose shape is
// convInfo.outShape.
//
// Every temporary TensorInfo is released in a `finally` clause, and the
// temporarily widened texture shape of `x` is always restored, so a failure
// in any step neither leaks textures nor corrupts the input's texData.
function conv2dByMatMul(_ref) {
  var x = _ref.x,
    filter = _ref.filter,
    convInfo = _ref.convInfo,
    backend = _ref.backend,
    _ref$bias = _ref.bias,
    bias = _ref$bias === void 0 ? null : _ref$bias,
    _ref$preluActivationW = _ref.preluActivationWeights,
    preluActivationWeights = _ref$preluActivationW === void 0 ? null : _ref$preluActivationW,
    _ref$leakyreluAlpha = _ref.leakyreluAlpha,
    leakyreluAlpha = _ref$leakyreluAlpha === void 0 ? 0 : _ref$leakyreluAlpha,
    _ref$activation = _ref.activation,
    activation = _ref$activation === void 0 ? null : _ref$activation;
  // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
  // result from 2D to 4D.
  var xShape = x.shape;
  var xTexData = backend.texData.get(x.dataId);
  var sharedMatMulDim = convInfo.inChannels;
  var outerShapeX = xShape[0] * xShape[1] * xShape[2];
  var outerShapeFilter = convInfo.outChannels;
  var isChannelsLast = convInfo.dataFormat === 'channelsLast';
  var transposeA = false;
  var transposeB = false;
  var out;
  // Each temporary TensorInfo is pushed here as soon as it exists and is
  // disposed in the outer `finally`, on success and on failure alike.
  var intermediates = [];
  try {
    if (preluActivationWeights != null) {
      var targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
      if (targetShape != null) {
        preluActivationWeights = reshape({
          inputs: {
            x: preluActivationWeights
          },
          backend: backend,
          attrs: {
            shape: targetShape
          }
        });
        intermediates.push(preluActivationWeights);
      }
    }
    if (bias != null) {
      var _targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
      if (_targetShape != null) {
        bias = reshape({
          inputs: {
            x: bias
          },
          backend: backend,
          attrs: {
            shape: _targetShape
          }
        });
        intermediates.push(bias);
      }
    }
    // TODO: Once reduction ops are packed, batchMatMul will always be packed
    // and we can remove this condition.
    var batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) && sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
    // The algorithm in the if condition assumes (1) the output will be packed,
    // (2) x is packed, (3) x isChannelsLast, (4) x's packed texture is already
    // on GPU, (5) col is odd, (6) the width, height and inChannels are the same
    // for xTexData.shape and xShape.
    var canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked && isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 && arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));
    if (canOptimize) {
      // We avoid expensive packed 2x2 reshape by padding col count to next,
      // even number. When col is odd, the result of packed batchMatMul is
      // the same (has the same texture layout and and values in the texture) as
      // it is for next even col. We make the odd-cols tensor to look like
      // even-cols tensor before the operation and, after the batchMatMul,
      // fix the even-cols result to have odd number of cols.
      var _targetShape2 = xShape[0] * xShape[1] * (xShape[2] + 1);
      var xReshaped = {
        dataId: x.dataId,
        shape: [1, _targetShape2, convInfo.inChannels],
        dtype: x.dtype
      };
      // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
      // Decrementing col count, after batchMatMul->...->compileProgram leads to
      // invalid col count within the reference in GPGPUBinary.inShapeInfos.
      // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
      // in compileProgram method, but that would affect compilation of all
      // programs - instead, provide a copy here, with even col count, before
      // calling batchMatMul->...->compileProgram and after that, the original
      // xTexData.shape is restored.
      var originalXTexDataShape = xTexData.shape;
      xTexData.shape = xTexData.shape.slice();
      xTexData.shape[xTexData.shape.length - 2]++;
      var pointwiseConv;
      try {
        assert$1(isReshapeFree(xTexData.shape, xReshaped.shape), function () {
          return "packed reshape ".concat(xTexData.shape, " to ").concat(xReshaped.shape, " isn't free");
        });
        var filterReshaped = reshape({
          inputs: {
            x: filter
          },
          backend: backend,
          attrs: {
            shape: [1, convInfo.inChannels, convInfo.outChannels]
          }
        });
        intermediates.push(filterReshaped);
        pointwiseConv = batchMatMulImpl({
          a: xReshaped,
          b: filterReshaped,
          backend: backend,
          transposeA: transposeA,
          transposeB: transposeB,
          bias: bias,
          activation: activation,
          preluActivationWeights: preluActivationWeights,
          leakyreluAlpha: leakyreluAlpha
        });
        intermediates.push(pointwiseConv);
      } finally {
        // Restore the input shape to original even when a step above threw:
        // xTexData is shared by every other consumer of x.
        xTexData.shape = originalXTexDataShape;
      }
      var pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
      assert$1(pointwiseConvTexData.isPacked, function () {
        return 'batchMatMul result is expected to be packed';
      });
      // Set the output shape - there is no need for expensive reshape as data
      // layout is already correct.
      pointwiseConvTexData.shape = convInfo.outShape;
      out = identity({
        inputs: {
          x: pointwiseConv
        },
        backend: backend
      });
      out.shape = convInfo.outShape;
    } else {
      var numCols = convInfo.outHeight * convInfo.outWidth;
      var _xReshaped = reshape({
        inputs: {
          x: x
        },
        backend: backend,
        attrs: {
          shape: isChannelsLast ? [convInfo.batchSize, numCols, convInfo.inChannels] : [convInfo.batchSize, convInfo.inChannels, numCols]
        }
      });
      intermediates.push(_xReshaped);
      var _filterReshaped = reshape({
        inputs: {
          x: filter
        },
        backend: backend,
        attrs: {
          shape: [1, convInfo.inChannels, convInfo.outChannels]
        }
      });
      intermediates.push(_filterReshaped);
      var result = batchMatMulImpl({
        a: isChannelsLast ? _xReshaped : _filterReshaped,
        b: isChannelsLast ? _filterReshaped : _xReshaped,
        transposeA: !isChannelsLast,
        transposeB: transposeB,
        backend: backend,
        bias: bias,
        activation: activation,
        preluActivationWeights: preluActivationWeights,
        leakyreluAlpha: leakyreluAlpha
      });
      intermediates.push(result);
      out = reshape({
        inputs: {
          x: result
        },
        backend: backend,
        attrs: {
          shape: convInfo.outShape
        }
      });
    }
  } finally {
    for (var _i = 0; _i < intermediates.length; _i++) {
      backend.disposeIntermediateTensorInfo(intermediates[_i]);
    }
  }
  return out;
}
// Implements the im2row algorithm as outlined in "High Performance
// Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
//
// Arguments (one object): x, filter, convInfo and backend are required;
// bias, preluActivationWeights and activation default to null and
// leakyreluAlpha defaults to 0. Returns the convolution result reshaped to
// convInfo.outShape.
//
// Every temporary TensorInfo is released in a `finally` clause, so GPU
// textures are not leaked when a program run or reshape throws.
function conv2dWithIm2Row(_ref2) {
  var x = _ref2.x,
    filter = _ref2.filter,
    convInfo = _ref2.convInfo,
    backend = _ref2.backend,
    _ref2$bias = _ref2.bias,
    bias = _ref2$bias === void 0 ? null : _ref2$bias,
    _ref2$preluActivation = _ref2.preluActivationWeights,
    preluActivationWeights = _ref2$preluActivation === void 0 ? null : _ref2$preluActivation,
    _ref2$leakyreluAlpha = _ref2.leakyreluAlpha,
    leakyreluAlpha = _ref2$leakyreluAlpha === void 0 ? 0 : _ref2$leakyreluAlpha,
    _ref2$activation = _ref2.activation,
    activation = _ref2$activation === void 0 ? null : _ref2$activation;
  // Rearranges conv2d input so each block to be convolved over forms the
  // column of a new matrix with shape [filterWidth * filterHeight *
  // inChannels, outHeight * outWidth]. The filter is reshaped to
  // [1, filterWidth * filterHeight * inChannels, size(filter) / sharedDim].
  // The convolution is then computed by multiplying these matrices (the
  // im2col matrix is transposed by the matmul) and reshaping the result.
  var filterWidth = convInfo.filterWidth,
    filterHeight = convInfo.filterHeight,
    inChannels = convInfo.inChannels,
    outWidth = convInfo.outWidth,
    outHeight = convInfo.outHeight,
    dataFormat = convInfo.dataFormat;
  var isChannelsLast = dataFormat === 'channelsLast';
  var sharedDim = filterWidth * filterHeight * inChannels;
  var numCols = outHeight * outWidth;
  var x2ColShape = [convInfo.batchSize, sharedDim, numCols];
  var transposeA = true;
  var transposeB = false;
  var out;
  // Each temporary TensorInfo is pushed here as soon as it exists and is
  // disposed in the `finally`, on success and on failure alike.
  var intermediates = [];
  try {
    if (preluActivationWeights != null) {
      var targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
      if (targetShape != null) {
        preluActivationWeights = reshape({
          inputs: {
            x: preluActivationWeights
          },
          backend: backend,
          attrs: {
            shape: targetShape
          }
        });
        intermediates.push(preluActivationWeights);
      }
    }
    if (bias != null) {
      var _targetShape3 = getShapeForBatchMatMul(bias.shape, isChannelsLast);
      if (_targetShape3 != null) {
        bias = reshape({
          inputs: {
            x: bias
          },
          backend: backend,
          attrs: {
            shape: _targetShape3
          }
        });
        intermediates.push(bias);
      }
    }
    var w2Row = reshape({
      inputs: {
        x: filter
      },
      backend: backend,
      attrs: {
        shape: [1, sharedDim, sizeFromShape(filter.shape) / sharedDim]
      }
    });
    intermediates.push(w2Row);
    var im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
    // Uniform values, in the order Im2ColPackedProgram declares them:
    // inputShape, pad, stride, dilation, inChannels, itemsPerBlockRow,
    // outWidth.
    var customValues = [x.shape, [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels], [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]];
    var im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues);
    intermediates.push(im2Col);
    var im2ColReshaped = reshape({
      inputs: {
        x: im2Col
      },
      backend: backend,
      attrs: {
        shape: x2ColShape
      }
    });
    intermediates.push(im2ColReshaped);
    var hasBias = bias != null;
    var hasPreluActivationWeights = preluActivationWeights != null;
    var hasLeakyreluAlpha = activation === 'leakyrelu';
    var fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
    var matmulProgram = new MatMulPackedProgram(isChannelsLast ? im2ColReshaped.shape : w2Row.shape, isChannelsLast ? w2Row.shape : im2ColReshaped.shape, isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] : [convInfo.batchSize, convInfo.outChannels, numCols], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    var inputs = isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped];
    // Extra inputs must be appended in the same order the program flags
    // them: bias, PReLU weights, leaky-ReLU alpha.
    if (hasBias) {
      inputs.push(bias);
    }
    if (hasPreluActivationWeights) {
      inputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
      var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
      inputs.push($leakyreluAlpha);
      intermediates.push($leakyreluAlpha);
    }
    var product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
    intermediates.push(product);
    out = reshape({
      inputs: {
        x: product
      },
      backend: backend,
      attrs: {
        shape: convInfo.outShape
      }
    });
  } finally {
    for (var _i2 = 0; _i2 < intermediates.length; _i2++) {
      backend.disposeIntermediateTensorInfo(intermediates[_i2]);
    }
  }
  return out;
}
119775
119776 /**
119777 * @license
119778 * Copyright 2020 Google LLC. All Rights Reserved.
119779 * Licensed under the Apache License, Version 2.0 (the "License");
119780 * you may not use this file except in compliance with the License.
119781 * You may obtain a copy of the License at
119782 *
119783 * http://www.apache.org/licenses/LICENSE-2.0
119784 *
119785 * Unless required by applicable law or agreed to in writing, software
119786 * distributed under the License is distributed on an "AS IS" BASIS,
119787 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
119788 * See the License for the specific language governing permissions and
119789 * limitations under the License.
119790 * =============================================================================
119791 */
/**
 * WebGL kernel function for Conv2D.
 *
 * The implementation is chosen in this order of preference:
 *  1. conv2dByMatMul, for 1x1 filters with unit strides and dilations whose
 *     padding type is 'SAME' or 'VALID';
 *  2. Conv2DPackedProgram, for channelsLast inputs with strideWidth <= 2
 *     when the WEBGL_EXP_CONV flag is enabled;
 *  3. conv2dWithIm2Row, when the WEBGL_CONV_IM2COL flag is enabled;
 *  4. the unpacked Conv2DProgram otherwise.
 * Whatever the chosen path produces is reshaped to convInfo.outShape, and
 * the un-reshaped tensor is disposed.
 *
 * @param args Kernel arguments: `inputs` {x, filter}, `backend`, and `attrs`
 *     {strides, pad, dataFormat, dilations, dimRoundingMode}.
 * @returns The output TensorInfo with shape convInfo.outShape.
 */
function conv2d(args) {
  var webglBackend = args.backend;
  var x = args.inputs.x;
  var filter = args.inputs.filter;
  var attrs = args.attrs;
  var layout = convertConv2DDataFormat(attrs.dataFormat);
  var info = computeConv2DInfo(x.shape, filter.shape, attrs.strides, attrs.dilations, attrs.pad, attrs.dimRoundingMode, false /* depthwise */, layout);

  var isPointwise = info.filterHeight === 1 && info.filterWidth === 1 &&
    info.dilationHeight === 1 && info.dilationWidth === 1 &&
    info.strideHeight === 1 && info.strideWidth === 1 &&
    (info.padInfo.type === 'SAME' || info.padInfo.type === 'VALID');

  // Runs the preferred convolution implementation and returns its raw
  // (not yet reshaped) output.
  function runConvolution() {
    if (isPointwise) {
      return conv2dByMatMul({
        x: x,
        filter: filter,
        convInfo: info,
        backend: webglBackend
      });
    }
    if (info.strideWidth <= 2 && layout === 'channelsLast' && env().getBool('WEBGL_EXP_CONV')) {
      var packedProgram = new Conv2DPackedProgram(info);
      var uniformValues = [
        [info.padInfo.top, info.padInfo.left],
        [info.strideHeight, info.strideWidth],
        [info.dilationHeight, info.dilationWidth],
        [info.inHeight, info.inWidth]
      ];
      return webglBackend.runWebGLProgram(packedProgram, [x, filter], 'float32', uniformValues);
    }
    if (env().getBool('WEBGL_CONV_IM2COL')) {
      return conv2dWithIm2Row({
        x: x,
        filter: filter,
        convInfo: info,
        backend: webglBackend
      });
    }
    return webglBackend.runWebGLProgram(new Conv2DProgram(info), [x, filter], 'float32');
  }

  var raw = runConvolution();
  var reshaped = reshape({
    inputs: {
      x: raw
    },
    backend: webglBackend,
    attrs: {
      shape: info.outShape
    }
  });
  webglBackend.disposeIntermediateTensorInfo(raw);
  return reshaped;
}
var conv2DConfig = {
  kernelName: Conv2D$1,
  backendName: 'webgl',
  kernelFunc: conv2d
};
119845
119846 /**
119847 * @license
119848 * Copyright 2017 Google LLC. All Rights Reserved.
119849 * Licensed under the Apache License, Version 2.0 (the "License");
119850 * you may not use this file except in compliance with the License.
119851 * You may obtain a copy of the License at
119852 *
119853 * http://www.apache.org/licenses/LICENSE-2.0
119854 *
119855 * Unless required by applicable law or agreed to in writing, software
119856 * distributed under the License is distributed on an "AS IS" BASIS,
119857 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
119858 * See the License for the specific language governing permissions and
119859 * limitations under the License.
119860 * =============================================================================
119861 */
/**
 * Unpacked WebGL program computing the gradient of conv2d with respect to
 * the filter.
 *
 * Each output element dw(wR, wC, d1, d2) is the sum over batch b and output
 * positions (yR, yC) of x(b, xR, xC, d1) * dy(b, yR, yC, d2), where
 * xR = wR + yR * strideHeight - padTop and xC = wC + yC * strideWidth - padLeft.
 * Terms whose xR / xC fall outside the input are skipped (padding).
 * For channelsFirst data, x and dy are indexed as (b, channel, row, col)
 * instead.
 *
 * NOTE(review): dilation is not read here, so the result matches an
 * undilated convolution. Confirm that callers never route dilated
 * convolutions to this program.
 *
 * @param convInfo Conv2D info. filterShape becomes the output shape; batch
 *     size, output/input extents, strides and pads are baked into the
 *     generated shader source.
 */
var Conv2DDerFilterProgram = /*#__PURE__*/_createClass(function Conv2DDerFilterProgram(convInfo) {
  _classCallCheck(this, Conv2DDerFilterProgram);
  this.variableNames = ['x', 'dy'];
  this.outputShape = convInfo.filterShape;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var padTop = convInfo.padInfo.top;
  var padLeft = convInfo.padInfo.left;
  var isChannelsLast = convInfo.dataFormat === 'channelsLast';
  // The data-format choice is resolved in JS; only the matching getDy/getX
  // indexing is emitted into the shader.
  this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < ".concat(convInfo.batchSize, "; b++) {\n for (int yR = 0; yR < ").concat(convInfo.outHeight, "; yR++) {\n int xR = wR + yR * ").concat(strideHeight, " - ").concat(padTop, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int yC = 0; yC < ").concat(convInfo.outWidth, "; yC++) {\n int xC = wC + yC * ").concat(strideWidth, " - ").concat(padLeft, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n ").concat(isChannelsLast ? "float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);" : "float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);", "\n }\n }\n }\n setOutput(dotProd);\n }\n ");
});
/**
 * Unpacked WebGL program computing the gradient of conv2d with respect to
 * its input (dx).
 *
 * For every input coordinate the shader walks the filter window. A dy
 * position contributes only when (dyCorner + offset) / stride is
 * non-negative, below the dy extent, and has no fractional part, i.e. when
 * that input pixel was actually touched by the forward convolution. The
 * filter is read spatially flipped (wRPerm / wCPerm) and the products are
 * summed over all outChannels (d2).
 *
 * NOTE(review): dilation is not read here. Confirm that callers only use
 * this program for undilated convolutions.
 *
 * @param convInfo Conv2D info. inShape becomes the output shape; filter
 *     size, strides, pads, output extents and channel count are baked into
 *     the generated shader source.
 */
var Conv2DDerInputProgram = /*#__PURE__*/_createClass(function Conv2DDerInputProgram(convInfo) {
  _classCallCheck(this, Conv2DDerInputProgram);
  this.variableNames = ['dy', 'W'];
  this.outputShape = convInfo.inShape;
  var filterHeight = convInfo.filterHeight;
  var filterWidth = convInfo.filterWidth;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var isChannelsLast = convInfo.dataFormat === 'channelsLast';
  // Padding to apply when sliding the flipped filter over dy:
  // (filter size - 1) minus the forward-pass padding.
  var padTop = filterHeight - 1 - convInfo.padInfo.top;
  var padLeft = filterWidth - 1 - convInfo.padInfo.left;
  // Positions of the row, column and channel axes within the 4-D output
  // coordinates for the current data format.
  var rowDim = isChannelsLast ? 1 : 2;
  var colDim = isChannelsLast ? 2 : 3;
  var channelDim = isChannelsLast ? 3 : 1;
  // The data-format choice is emitted as a literal `if (true)` / `if (false)`
  // in GLSL.
  this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[").concat(channelDim, "];\n\n ivec2 dyCorner = ivec2(coords[").concat(rowDim, "], coords[").concat(colDim, "]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ").concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ").concat(filterWidth, " - 1 - wC;\n\n for (int d2 = 0; d2 < ").concat(convInfo.outChannels, "; d2++) {\n\n if (").concat(isChannelsLast, ") {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ");
});
// Shader program for the WebGL Conv3DBackpropFilterV2 kernel
// (conv3DBackpropFilterV2 builds it), i.e. the gradient of a 3-D convolution
// with respect to its filter.
//
// Inputs (variableNames): 'x' and 'dy'. The output has convInfo.filterShape.
// For every output coordinate (wF, wR, wC, d1, d2) the generated GLSL sums
//   getX(b, xF, xR, xC, d1) * getDy(b, yF, yR, yC, d2)
// over every batch b and every dy position (yF, yR, yC), where
//   xF = wF + yF * strideDepth - padFront   (likewise xR / xC),
// skipping x positions outside [0, inDepth) / [0, inHeight) / [0, inWidth).
// NOTE(review): dilation is not used by this shader; the caller in view
// passes dilations = 1.
// NOTE(review): the argument order of getX/getDy (batch first, channel last)
// suggests NDHWC tensors — confirm against computeConv3DInfo.
119889 var Conv3DDerFilterProgram = /*#__PURE__*/_createClass(function Conv3DDerFilterProgram(convInfo) {
119890 _classCallCheck(this, Conv3DDerFilterProgram);
119891 this.variableNames = ['x', 'dy'];
119892 this.outputShape = convInfo.filterShape;
119893 var strideDepth = convInfo.strideDepth;
119894 var strideHeight = convInfo.strideHeight;
119895 var strideWidth = convInfo.strideWidth;
// Forward-pass padding: subtracted when mapping a dy position back to x.
119896 var padFront = convInfo.padInfo.front;
119897 var padTop = convInfo.padInfo.top;
119898 var padLeft = convInfo.padInfo.left;
// ivec5 / coords.u come from the backend's shader prelude (5-D coordinates).
119899 this.userCode = "\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < ".concat(convInfo.batchSize, "; b++) {\n for (int yF = 0; yF < ").concat(convInfo.outDepth, "; yF++) {\n int xF = wF + yF * ").concat(strideDepth, " - ").concat(padFront, ";\n\n if (xF < 0 || xF >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int yR = 0; yR < ").concat(convInfo.outHeight, "; yR++) {\n int xR = wR + yR * ").concat(strideHeight, " - ").concat(padTop, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int yC = 0; yC < ").concat(convInfo.outWidth, "; yC++) {\n int xC = wC + yC * ").concat(strideWidth, " - ").concat(padLeft, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ");
119900 });
// Shader program for the WebGL Conv3DBackpropInputV2 kernel
// (conv3DBackpropInput builds it), i.e. the gradient of a 3-D convolution
// with respect to its input.
//
// Inputs (variableNames): 'dy' and 'W'. The output has convInfo.inShape.
// For each output coordinate the shader visits every filter tap (wF, wR, wC),
// maps it to a dy position (dyCorner + w) / stride, and only accumulates when
// that position is an exact integer (fract == 0) inside
// [0, outDepth) / [0, outHeight) / [0, outWidth). The filter is read flipped
// (filterSize - 1 - w) and the products are summed over all outChannels d2.
// NOTE(review): coords.x = batch and coords.u = channel suggests NDHWC
// tensors — confirm against computeConv3DInfo.
119901 var Conv3DDerInputProgram = /*#__PURE__*/_createClass(function Conv3DDerInputProgram(convInfo) {
119902 _classCallCheck(this, Conv3DDerInputProgram);
119903 this.variableNames = ['dy', 'W'];
119904 this.outputShape = convInfo.inShape;
119905 var filterDepth = convInfo.filterDepth;
119906 var filterHeight = convInfo.filterHeight;
119907 var filterWidth = convInfo.filterWidth;
119908 var strideDepth = convInfo.strideDepth;
119909 var strideHeight = convInfo.strideHeight;
119910 var strideWidth = convInfo.strideWidth;
// Padding of the equivalent transposed convolution: (filterSize - 1) minus
// the forward pass's leading pad on each spatial axis.
119911 var padFront = filterDepth - 1 - convInfo.padInfo.front;
119912 var padTop = filterHeight - 1 - convInfo.padInfo.top;
119913 var padLeft = filterWidth - 1 - convInfo.padInfo.left;
119914 this.userCode = "\n const ivec3 pads = ivec3(".concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < ").concat(filterDepth, "; wF++) {\n float dyF = float(dyFCorner + wF) / ").concat(strideDepth, ".0;\n\n if (dyF < 0.0 || dyF >= ").concat(convInfo.outDepth, ".0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = ").concat(filterDepth, " - 1 - wF;\n\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ").concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ").concat(filterWidth, " - 1 - wC;\n\n for (int d2 = 0; d2 < ").concat(convInfo.outChannels, "; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ");
119915 });
119916
119917 /**
119918 * @license
119919 * Copyright 2020 Google LLC. All Rights Reserved.
119920 * Licensed under the Apache License, Version 2.0 (the "License");
119921 * you may not use this file except in compliance with the License.
119922 * You may obtain a copy of the License at
119923 *
119924 * http://www.apache.org/licenses/LICENSE-2.0
119925 *
119926 * Unless required by applicable law or agreed to in writing, software
119927 * distributed under the License is distributed on an "AS IS" BASIS,
119928 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
119929 * See the License for the specific language governing permissions and
119930 * limitations under the License.
119931 * =============================================================================
119932 */
/**
 * WebGL kernel for Conv2DBackpropFilter.
 *
 * Resolves the convolution geometry (dilation fixed at 1, not depthwise) and
 * runs Conv2DDerFilterProgram over `x` and `dy`, producing a float32 result.
 *
 * @param {{inputs: {x, dy}, backend, attrs}} args Kernel arguments; attrs
 *     supplies strides, pad, dataFormat, dimRoundingMode and filterShape.
 * @returns The TensorInfo returned by backend.runWebGLProgram.
 */
function conv2DBackpropFilter(args) {
  var attrs = args.attrs;
  var x = args.inputs.x;
  var dy = args.inputs.dy;
  var convInfo = computeConv2DInfo(
    x.shape,
    attrs.filterShape,
    attrs.strides,
    1 /* dilations */,
    attrs.pad,
    attrs.dimRoundingMode,
    false /* depthwise */,
    convertConv2DDataFormat(attrs.dataFormat));
  var derFilterProgram = new Conv2DDerFilterProgram(convInfo);
  return args.backend.runWebGLProgram(derFilterProgram, [x, dy], 'float32');
}
var conv2DBackpropFilterConfig = {
  kernelName: Conv2DBackpropFilter,
  backendName: 'webgl',
  kernelFunc: conv2DBackpropFilter
};
119954
// Packed-texture variant of the 2-D convolution input-gradient shader.
// conv2DBackpropInput selects it only for channelsLast data when the
// WEBGL_PACK_CONV2DTRANSPOSE flag is on. Inputs: 'dy' and 'W' (both packed);
// the packed output has convInfo.inShape.
//
// Unlike the unpacked program, the strides are not baked into the source:
// they arrive through the custom 'strides' vec2 uniform, which the caller
// fills with [strideHeight, strideWidth].
//
// Each invocation accumulates a vec4. result.xy uses the dy column derived
// from the output column (idyC), and result.zw the one derived from the next
// column (idyC2, computed from wC + 1). A dy column only contributes when the
// division by the stride is exact and lies in [0, outWidth). d2 advances by 2
// per iteration.
// NOTE(review): this relies on the packed layout holding two adjacent
// columns/channels per texel (hence .xy/.zw and the mod(idyC, 2) selection)
// — confirm against the packing helpers.
119955 var Conv2DDerInputPackedProgram = /*#__PURE__*/_createClass(function Conv2DDerInputPackedProgram(convInfo) {
119956 _classCallCheck(this, Conv2DDerInputPackedProgram);
119957 this.variableNames = ['dy', 'W'];
119958 this.packedInputs = true;
119959 this.packedOutput = true;
119960 this.customUniforms = [{
119961 name: 'strides',
119962 type: 'vec2'
119963 }];
119964 this.outputShape = convInfo.inShape;
119965 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
119966 var filterHeight = convInfo.filterHeight;
119967 var filterWidth = convInfo.filterWidth;
// Padding of the equivalent transposed convolution: (filterSize - 1) minus
// the forward pass's leading pad.
119968 var padTop = filterHeight - 1 - convInfo.padInfo.top;
119969 var padLeft = filterWidth - 1 - convInfo.padInfo.left;
119970 this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n\n ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n vec4 result = vec4(0.);\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) / strides[0];\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n int wRPerm = ").concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n int wCPerm = ").concat(filterWidth, " - 1 - wC;\n\n float dyC = float(dyCCorner + wC) / strides[1];\n bool idyCVal = (dyC >= 0.0) && (dyC < ").concat(convInfo.outWidth, ".0)\n && (fract(dyC) == 0.0);\n int idyC = int(dyC);\n\n float dyC2 = float(dyCCorner + wC + 1) / strides[1];\n bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ").concat(convInfo.outWidth, ".0)\n && (fract(dyC2) == 0.0);\n int idyC2 = int(dyC2);\n\n if (idyCVal && idyCVal2) {\n for (int d2 = 0; d2 < ").concat(convInfo.outChannels, "; d2 += 2) {\n vec4 wValue = getW(wRPerm, wCPerm, d1, d2);\n vec4 dySample = getDy(batch, idyR, idyC, d2);\n vec4 dySample2 = (idyC / 2 == idyC2 / 2) ?\n dySample : getDy(batch, idyR, idyC2, d2);\n\n vec2 dyValue = mod(float(idyC), 2.) == 0. ?\n dySample.xy : dySample.zw;\n result.xy += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n\n dyValue = mod(float(idyC2), 2.) == 0. ?\n dySample2.xy : dySample2.zw;\n result.zw += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n }\n } else if (idyCVal) {\n for (int d2 = 0; d2 < ").concat(convInfo.outChannels, "; d2 += 2) {\n vec4 wValue = getW(wRPerm, wCPerm, d1, d2);\n vec4 dySample = getDy(batch, idyR, idyC, d2);\n vec2 dyValue = mod(float(idyC), 2.) == 0. 
?\n dySample.xy : dySample.zw;\n result.xy += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n }\n } else if (idyCVal2) {\n for (int d2 = 0; d2 < ").concat(convInfo.outChannels, "; d2 += 2) {\n vec4 wValue = getW(wRPerm, wCPerm, d1, d2);\n vec4 dySample = getDy(batch, idyR, idyC2, d2);\n vec2 dyValue = mod(float(idyC2), 2.) == 0. ?\n dySample.xy : dySample.zw;\n result.zw += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n }\n }\n }\n }\n setOutput(result);\n }\n ");
119971 });
119972
119973 /**
119974 * @license
119975 * Copyright 2020 Google LLC. All Rights Reserved.
119976 * Licensed under the Apache License, Version 2.0 (the "License");
119977 * you may not use this file except in compliance with the License.
119978 * You may obtain a copy of the License at
119979 *
119980 * http://www.apache.org/licenses/LICENSE-2.0
119981 *
119982 * Unless required by applicable law or agreed to in writing, software
119983 * distributed under the License is distributed on an "AS IS" BASIS,
119984 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
119985 * See the License for the specific language governing permissions and
119986 * limitations under the License.
119987 * =============================================================================
119988 */
/**
 * WebGL kernel for Conv2DBackpropInput.
 *
 * Chooses between two shader programs:
 *  - the packed Conv2DDerInputPackedProgram, when the
 *    WEBGL_PACK_CONV2DTRANSPOSE flag is set and the data format resolves to
 *    'channelsLast'. The strides are passed as the program's custom uniform.
 *  - otherwise the unpacked Conv2DDerInputProgram.
 *
 * @param {{inputs: {dy, filter}, backend, attrs}} args Kernel arguments;
 *     attrs supplies inputShape, strides, pad, dataFormat, dimRoundingMode.
 * @returns The float32 TensorInfo returned by backend.runWebGLProgram.
 */
function conv2DBackpropInput(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var dy = args.inputs.dy;
  var filter = args.inputs.filter;
  var dataFormat = convertConv2DDataFormat(attrs.dataFormat);
  var convInfo = computeConv2DInfo(attrs.inputShape, filter.shape, attrs.strides, 1 /* dilations */, attrs.pad, attrs.dimRoundingMode, false, dataFormat);
  var usePackedProgram = env().getBool('WEBGL_PACK_CONV2DTRANSPOSE') && dataFormat === 'channelsLast';
  if (!usePackedProgram) {
    var unpackedProgram = new Conv2DDerInputProgram(convInfo);
    return backend.runWebGLProgram(unpackedProgram, [dy, filter], 'float32');
  }
  // The packed shader reads the strides from its 'strides' uniform.
  var strideUniform = [[convInfo.strideHeight, convInfo.strideWidth]];
  var packedProgram = new Conv2DDerInputPackedProgram(convInfo);
  return backend.runWebGLProgram(packedProgram, [dy, filter], 'float32', strideUniform);
}
var conv2DBackpropInputConfig = {
  kernelName: Conv2DBackpropInput,
  backendName: 'webgl',
  kernelFunc: conv2DBackpropInput
};
120016
120017 /**
120018 * @license
120019 * Copyright 2020 Google LLC. All Rights Reserved.
120020 * Licensed under the Apache License, Version 2.0 (the "License");
120021 * you may not use this file except in compliance with the License.
120022 * You may obtain a copy of the License at
120023 *
120024 * http://www.apache.org/licenses/LICENSE-2.0
120025 *
120026 * Unless required by applicable law or agreed to in writing, software
120027 * distributed under the License is distributed on an "AS IS" BASIS,
120028 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120029 * See the License for the specific language governing permissions and
120030 * limitations under the License.
120031 * =============================================================================
120032 */
/**
 * WebGL kernel for Conv3D.
 *
 * Resolves the 3-D convolution geometry from the input and filter shapes and
 * the strides/dilations/pad attributes, then runs Conv3DProgram.
 *
 * @param {{inputs: {x, filter}, backend, attrs}} args Kernel arguments.
 * @returns The float32 TensorInfo returned by backend.runWebGLProgram.
 */
function conv3D(args) {
  var attrs = args.attrs;
  var x = args.inputs.x;
  var filter = args.inputs.filter;
  var convInfo = computeConv3DInfo(x.shape, filter.shape, attrs.strides, attrs.dilations, attrs.pad);
  var conv3DProgram = new Conv3DProgram(convInfo);
  return args.backend.runWebGLProgram(conv3DProgram, [x, filter], 'float32');
}
var conv3DConfig = {
  kernelName: Conv3D$1,
  backendName: 'webgl',
  kernelFunc: conv3D
};
120051
120052 /**
120053 * @license
120054 * Copyright 2020 Google LLC. All Rights Reserved.
120055 * Licensed under the Apache License, Version 2.0 (the "License");
120056 * you may not use this file except in compliance with the License.
120057 * You may obtain a copy of the License at
120058 *
120059 * http://www.apache.org/licenses/LICENSE-2.0
120060 *
120061 * Unless required by applicable law or agreed to in writing, software
120062 * distributed under the License is distributed on an "AS IS" BASIS,
120063 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120064 * See the License for the specific language governing permissions and
120065 * limitations under the License.
120066 * =============================================================================
120067 */
/**
 * WebGL kernel for Conv3DBackpropFilterV2.
 *
 * Resolves the convolution geometry from `x.shape` and the requested
 * `filterShape` (dilation fixed at 1) and runs Conv3DDerFilterProgram over
 * `x` and `dy`.
 *
 * @param {{inputs: {x, dy}, backend, attrs}} args Kernel arguments; attrs
 *     supplies strides, pad and filterShape.
 * @returns The float32 TensorInfo returned by backend.runWebGLProgram.
 */
function conv3DBackpropFilterV2(args) {
  var attrs = args.attrs;
  var x = args.inputs.x;
  var dy = args.inputs.dy;
  var convInfo = computeConv3DInfo(x.shape, attrs.filterShape, attrs.strides, 1 /* dilations */, attrs.pad);
  var derFilterProgram = new Conv3DDerFilterProgram(convInfo);
  return args.backend.runWebGLProgram(derFilterProgram, [x, dy], 'float32');
}
var conv3DBackpropFilterV2Config = {
  kernelName: Conv3DBackpropFilterV2,
  backendName: 'webgl',
  kernelFunc: conv3DBackpropFilterV2
};
120086
120087 /**
120088 * @license
120089 * Copyright 2020 Google LLC. All Rights Reserved.
120090 * Licensed under the Apache License, Version 2.0 (the "License");
120091 * you may not use this file except in compliance with the License.
120092 * You may obtain a copy of the License at
120093 *
120094 * http://www.apache.org/licenses/LICENSE-2.0
120095 *
120096 * Unless required by applicable law or agreed to in writing, software
120097 * distributed under the License is distributed on an "AS IS" BASIS,
120098 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120099 * See the License for the specific language governing permissions and
120100 * limitations under the License.
120101 * =============================================================================
120102 */
/**
 * WebGL kernel for Conv3DBackpropInputV2.
 *
 * Resolves the convolution geometry from the requested `inputShape` and the
 * filter's shape (dilation fixed at 1) and runs Conv3DDerInputProgram over
 * `dy` and `filter`.
 *
 * @param {{inputs: {dy, filter}, backend, attrs}} args Kernel arguments;
 *     attrs supplies pad, strides and inputShape.
 * @returns The float32 TensorInfo returned by backend.runWebGLProgram.
 */
function conv3DBackpropInput(args) {
  var attrs = args.attrs;
  var dy = args.inputs.dy;
  var filter = args.inputs.filter;
  var convInfo = computeConv3DInfo(attrs.inputShape, filter.shape, attrs.strides, 1 /* dilations */, attrs.pad);
  var derInputProgram = new Conv3DDerInputProgram(convInfo);
  return args.backend.runWebGLProgram(derInputProgram, [dy, filter], 'float32');
}
var conv3DBackpropInputConfig = {
  kernelName: Conv3DBackpropInputV2,
  backendName: 'webgl',
  kernelFunc: conv3DBackpropInput
};
120121
120122 /**
120123 * @license
120124 * Copyright 2020 Google LLC. All Rights Reserved.
120125 * Licensed under the Apache License, Version 2.0 (the "License");
120126 * you may not use this file except in compliance with the License.
120127 * You may obtain a copy of the License at
120128 *
120129 * http://www.apache.org/licenses/LICENSE-2.0
120130 *
120131 * Unless required by applicable law or agreed to in writing, software
120132 * distributed under the License is distributed on an "AS IS" BASIS,
120133 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120134 * See the License for the specific language governing permissions and
120135 * limitations under the License.
120136 * =============================================================================
120137 */
// GLSL snippets for the Cos kernel.
// Unpacked: CHECK_NAN_SNIPPET_UNARY is prepended to `return cos(x);`.
// Packed: cos() is evaluated on all four lanes, isnan(x) is recorded in
// `isNaN`, and CHECK_NAN_SNIPPET_PACKED is spliced in before the return.
var COS = "".concat(CHECK_NAN_SNIPPET_UNARY, "\n return cos(x);\n");
var COS_PACKED = "\n vec4 result = cos(x);\n bvec4 isNaN = isnan(x);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n";
var cos = unaryKernelFunc({ opSnippet: COS, packedOpSnippet: COS_PACKED });
var cosConfig = {
  kernelName: Cos,
  backendName: 'webgl',
  kernelFunc: cos
};
120149
120150 /**
120151 * @license
120152 * Copyright 2020 Google LLC. All Rights Reserved.
120153 * Licensed under the Apache License, Version 2.0 (the "License");
120154 * you may not use this file except in compliance with the License.
120155 * You may obtain a copy of the License at
120156 *
120157 * http://www.apache.org/licenses/LICENSE-2.0
120158 *
120159 * Unless required by applicable law or agreed to in writing, software
120160 * distributed under the License is distributed on an "AS IS" BASIS,
120161 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120162 * See the License for the specific language governing permissions and
120163 * limitations under the License.
120164 * =============================================================================
120165 */
// GLSL snippet for the Cosh kernel: cosh(x) = (e^x + e^-x) / 2, computed from
// a single exp() call. e2x holds e^-x and 1.0 / e2x supplies e^x.
var COSH = "\n float e2x = exp(-x);\n return (e2x + 1.0 / e2x) / 2.0;\n";
var cosh = unaryKernelFunc({ opSnippet: COSH });
var coshConfig = {
  kernelName: Cosh,
  backendName: 'webgl',
  kernelFunc: cosh
};
120175
120176 /**
120177 * @license
120178 * Copyright 2017 Google LLC. All Rights Reserved.
120179 * Licensed under the Apache License, Version 2.0 (the "License");
120180 * you may not use this file except in compliance with the License.
120181 * You may obtain a copy of the License at
120182 *
120183 * http://www.apache.org/licenses/LICENSE-2.0
120184 *
120185 * Unless required by applicable law or agreed to in writing, software
120186 * distributed under the License is distributed on an "AS IS" BASIS,
120187 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120188 * See the License for the specific language governing permissions and
120189 * limitations under the License.
120190 * =============================================================================
120191 */
/**
 * Shader program for the WebGL CropAndResize kernel.
 *
 * Inputs (variableNames): 'Image', 'Boxes' and 'BoxInd'. The output shape is
 * [numBoxes, cropHeight, cropWidth, depth]. For output (b, y, x, d) the shader:
 *  - reads box b's [y1, x1, y2, x2] coordinates;
 *  - looks up the image it refers to, bInd = round(BoxInd[b]);
 *  - maps (y, x) into that image;
 *  - samples it bilinearly (method === 'bilinear') or by nearest neighbour.
 * Sampling coordinates outside the image produce `extrapolationValue`. An
 * out-of-range bInd produces 0.
 *
 * Reference implementation:
 * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
 *
 * @param imageShape [batch, imageHeight, imageWidth, depth].
 * @param boxShape Shape of the boxes tensor; its first entry is numBoxes.
 * @param cropSize [cropHeight, cropWidth].
 * @param method 'bilinear' or any other value for nearest neighbour.
 * @param extrapolationValue Value written for samples outside the image.
 */
var CropAndResizeProgram = /*#__PURE__*/_createClass(function CropAndResizeProgram(imageShape, boxShape, cropSize, method, extrapolationValue) {
  _classCallCheck(this, CropAndResizeProgram);
  this.variableNames = ['Image', 'Boxes', 'BoxInd'];
  var _imageShape = _slicedToArray(imageShape, 4),
    batch = _imageShape[0],
    imageHeight = _imageShape[1],
    imageWidth = _imageShape[2],
    depth = _imageShape[3];
  var _boxShape = _slicedToArray(boxShape, 1),
    numBoxes = _boxShape[0];
  var _cropSize = _slicedToArray(cropSize, 2),
    cropHeight = _cropSize[0],
    cropWidth = _cropSize[1];
  this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
  var methodId = method === 'bilinear' ? 1 : 0;
  var inputHeightFloat = "".concat(imageHeight - 1, ".0");
  var inputWidthFloat = "".concat(imageWidth - 1, ".0");

  // With a single output row/column the sample is taken at the box centre;
  // otherwise the box is sampled uniformly from edge to edge.
  var heightRatio, heightScale, inY;
  if (cropHeight > 1) {
    heightRatio = "".concat((imageHeight - 1) / (cropHeight - 1));
    heightScale = '(y2-y1) * height_ratio';
    inY = "y1*".concat(inputHeightFloat, " + float(y)*(height_scale)");
  } else {
    heightRatio = '0.0';
    heightScale = '0.0';
    inY = "0.5 * (y1+y2) * ".concat(inputHeightFloat);
  }
  var widthRatio, widthScale, inX;
  if (cropWidth > 1) {
    widthRatio = "".concat((imageWidth - 1) / (cropWidth - 1));
    widthScale = '(x2-x1) * width_ratio';
    inX = "x1*".concat(inputWidthFloat, " + float(x)*(width_scale)");
  } else {
    widthRatio = '0.0';
    widthScale = '0.0';
    inX = "0.5 * (x1+x2) * ".concat(inputWidthFloat);
  }

  this.userCode = [
    '',
    '      const float height_ratio = float(' + heightRatio + ');',
    '      const float width_ratio = float(' + widthRatio + ');',
    '      void main() {',
    '        ivec4 coords = getOutputCoords();',
    '        int b = coords[0];',
    '        int y = coords[1];',
    '        int x = coords[2];',
    '        int d = coords[3];',
    '',
    '        // get box vals',
    '        float y1 = getBoxes(b,0);',
    '        float x1 = getBoxes(b,1);',
    '        float y2 = getBoxes(b,2);',
    '        float x2 = getBoxes(b,3);',
    '',
    '        // get image in batch index',
    '        int bInd = round(getBoxInd(b));',
    '        if(bInd < 0 || bInd >= ' + batch + ') {',
    '          // Invalid box index: write a defined value instead of leaving',
    '          // this output texel unwritten.',
    '          setOutput(0.0);',
    '          return;',
    '        }',
    '',
    '        float height_scale = ' + heightScale + ';',
    '        float width_scale = ' + widthScale + ';',
    '',
    '        float in_y = ' + inY + ';',
    '        if( in_y < 0.0 || in_y > ' + inputHeightFloat + ' ) {',
    '          setOutput(float(' + extrapolationValue + '));',
    '          return;',
    '        }',
    '        float in_x = ' + inX + ';',
    '        if( in_x < 0.0 || in_x > ' + inputWidthFloat + ' ) {',
    '          setOutput(float(' + extrapolationValue + '));',
    '          return;',
    '        }',
    '',
    '        vec2 sourceFracIndexCR = vec2(in_x,in_y);',
    '        if(' + methodId + ' == 1) {',
    '          // Compute the four integer indices.',
    '          ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);',
    '          ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));',
    '',
    '          // Sample the image this box refers to (bInd), not image b.',
    '          float topLeft = getImage(bInd, sourceFloorCR.y, sourceFloorCR.x, d);',
    '          float bottomLeft = getImage(bInd, sourceCeilCR.y, sourceFloorCR.x, d);',
    '          float topRight = getImage(bInd, sourceFloorCR.y, sourceCeilCR.x, d);',
    '          float bottomRight = getImage(bInd, sourceCeilCR.y, sourceCeilCR.x, d);',
    '',
    '          vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);',
    '',
    '          float top = topLeft + (topRight - topLeft) * fracCR.x;',
    '          float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;',
    '          float newValue = top + (bottom - top) * fracCR.y;',
    '          setOutput(newValue);',
    '        } else {',
    '          // Compute the coordinates of the nearest neighbor point.',
    '          ivec2 sourceNearestCR = ivec2(floor(',
    '            sourceFracIndexCR + vec2(0.5,0.5)));',
    '          float newValue = getImage(bInd, sourceNearestCR.y, sourceNearestCR.x, d);',
    '          setOutput(newValue);',
    '        }',
    '      }',
    '    '
  ].join('\n');
});
120225
120226 /**
120227 * @license
120228 * Copyright 2020 Google LLC. All Rights Reserved.
120229 * Licensed under the Apache License, Version 2.0 (the "License");
120230 * you may not use this file except in compliance with the License.
120231 * You may obtain a copy of the License at
120232 *
120233 * http://www.apache.org/licenses/LICENSE-2.0
120234 *
120235 * Unless required by applicable law or agreed to in writing, software
120236 * distributed under the License is distributed on an "AS IS" BASIS,
120237 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120238 * See the License for the specific language governing permissions and
120239 * limitations under the License.
120240 * =============================================================================
120241 */
/**
 * WebGL kernel for CropAndResize.
 *
 * Builds a CropAndResizeProgram from the image/box shapes and the cropSize,
 * method and extrapolationValue attributes, then runs it over
 * [image, boxes, boxInd].
 *
 * @returns The float32 TensorInfo returned by backend.runWebGLProgram.
 */
var cropAndResize = function cropAndResize(args) {
  var inputs = args.inputs;
  var attrs = args.attrs;
  var image = inputs.image;
  var boxes = inputs.boxes;
  var program = new CropAndResizeProgram(image.shape, boxes.shape, attrs.cropSize, attrs.method, attrs.extrapolationValue);
  return args.backend.runWebGLProgram(program, [image, boxes, inputs.boxInd], 'float32');
};
var cropAndResizeConfig = {
  kernelName: CropAndResize,
  backendName: 'webgl',
  kernelFunc: cropAndResize
};
120260
// Binary operators used by the cumulative-scan shader. The value is spliced
// into the GLSL as `val <op>= getX(...)`.
var CumOpType;
(function (opTypes) {
  opTypes.Prod = '*';
  opTypes.Sum = '+';
})(CumOpType || (CumOpType = {}));

/**
 * One pass of the cumulative scan along the last output axis.
 *
 * Inclusive pass (exclusive falsy): each element is combined with the element
 * 2^index positions before it, or after it when `reverse`. `index` is the
 * float custom uniform.
 * Exclusive pass: the element one position before/after is copied. The
 * boundary element keeps the op's identity (1.0 for Prod, 0.0 for Sum).
 *
 * @param op A CumOpType value.
 * @param outputShape Shape of both the input 'x' and the output.
 * @param exclusive Whether this is the exclusive shift pass.
 * @param reverse Whether the scan runs from the end of the axis.
 */
var CumProgram = /*#__PURE__*/_createClass(function CumProgram(op, outputShape, exclusive, reverse) {
  _classCallCheck(this, CumProgram);
  this.op = op;
  this.outputShape = outputShape;
  this.variableNames = ['x'];
  this.customUniforms = [{ name: 'index', type: 'float' }];
  var rank = outputShape.length;
  var identityValue = op === CumOpType.Prod ? '1.0' : '0.0';
  var startValue = exclusive ? identityValue : 'getX(' + getCoords(rank, 'coords', op) + ')';
  var axisLength = outputShape[rank - 1];
  var guard;
  var sourceIndex;
  if (!exclusive) {
    // Inclusive pass: read the element 2^index positions away when in range.
    guard = reverse ? 'end + pow2 < ' + axisLength : 'end >= pow2';
    sourceIndex = reverse ? 'end + pow2' : 'end - pow2';
  } else {
    // Exclusive pass acts as a roll by one in the scan direction.
    guard = reverse ? 'end != ' + (axisLength - 1) : 'end != 0';
    sourceIndex = reverse ? 'end + 1' : 'end - 1';
  }
  this.userCode = [
    '',
    ' void main() {',
    ' ' + getCoordsDataType(rank) + ' coords = getOutputCoords();',
    ' int end = ' + getFinalCoord(rank, 'coords', op) + ';',
    ' float val = ' + startValue + ';',
    ' int pow2 = int(pow(2.0, index));',
    ' if (' + guard + ') {',
    ' int idx = ' + sourceIndex + ';',
    ' ' + getFinalCoord(rank, 'coords', op) + ' = idx;',
    ' val ' + op + '= getX(' + getCoords(rank, 'coords', op) + ');',
    ' }',
    ' setOutput(val);',
    ' }',
    ' '
  ].join('\n');
});
/**
 * Builds the comma-separated GLSL component list of the coordinate variable
 * `name` for ranks 1-4. Examples: 'coords' for rank 1,
 * 'coords.x, coords.y' for rank 2. CumProgram uses it as getX() arguments.
 *
 * @param rank Rank of the coordinate (1-4).
 * @param name GLSL variable name.
 * @param op Operator, used only in the error message.
 * @throws {Error} for any other rank.
 */
function getCoords(rank, name, op) {
  var components;
  switch (rank) {
    case 1:
      return "".concat(name);
    case 2:
      components = ['x', 'y'];
      break;
    case 3:
      components = ['x', 'y', 'z'];
      break;
    case 4:
      components = ['x', 'y', 'z', 'w'];
      break;
    default:
      throw new Error("Cumulative ".concat(op, " for rank ").concat(rank, " is not yet supported"));
  }
  return components.map(function (component) {
    return "".concat(name, ".").concat(component);
  }).join(', ');
}
/**
 * Builds the GLSL expression that selects the last component of the
 * coordinate variable `name` for ranks 1-4: the variable itself for rank 1,
 * otherwise name.y / name.z / name.w. CumProgram reads and overwrites this
 * component to walk along the scanned axis.
 *
 * @param rank Rank of the coordinate (1-4).
 * @param name GLSL variable name.
 * @param op Operator, used only in the error message.
 * @throws {Error} for any other rank.
 */
function getFinalCoord(rank, name, op) {
  if (rank === 1) {
    return "".concat(name);
  }
  var lastComponent = rank === 2 ? 'y' : rank === 3 ? 'z' : rank === 4 ? 'w' : null;
  if (lastComponent === null) {
    throw new Error("Cumulative ".concat(op, " for rank ").concat(rank, " is not yet supported"));
  }
  return "".concat(name, ".").concat(lastComponent);
}
120319
120320 /**
120321 * @license
120322 * Copyright 2022 Google LLC. All Rights Reserved.
120323 * Licensed under the Apache License, Version 2.0 (the "License");
120324 * you may not use this file except in compliance with the License.
120325 * You may obtain a copy of the License at
120326 *
120327 * http://www.apache.org/licenses/LICENSE-2.0
120328 *
120329 * Unless required by applicable law or agreed to in writing, software
120330 * distributed under the License is distributed on an "AS IS" BASIS,
120331 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120332 * See the License for the specific language governing permissions and
120333 * limitations under the License.
120334 * =============================================================================
120335 */
/**
 * Shared implementation of the WebGL Cumprod / Cumsum kernels.
 *
 * The shader always scans the innermost axis. If `axis` is not innermost, x
 * is first transposed so that it is, and the result is transposed back at the
 * end.
 *
 * The inclusive scan is a log-step parallel prefix scan. Pass i (for
 * i < ceil(log2(size))) combines each element with the one 2^i positions away
 * (see CumProgram). An exclusive scan then takes one extra CumProgram pass
 * that shifts the inclusive result by one position and leaves the op's
 * identity at the boundary.
 *
 * Every intermediate TensorInfo is released through
 * backend.disposeIntermediateTensorInfo.
 *
 * @param op CumOpType.Prod or CumOpType.Sum.
 * @param x Input TensorInfo.
 * @param backend WebGL backend used to run programs and free intermediates.
 * @param axis Axis to scan along.
 * @param exclusive Whether to compute an exclusive scan.
 * @param reverse Whether to scan from the end of the axis.
 * @returns The scanned TensorInfo (transposed back when x was transposed).
 * @throws {Error} if the innermost axis is not the last axis.
 */
function cumImpl(op, x, backend, axis, exclusive, reverse) {
  var xRank = x.shape.length;
  var permutation = getAxesPermutation([axis], xRank);
  // Validate before transposing so that a failure does not leak the
  // transposed intermediate.
  var permutedAxis = getInnerMostAxes(1, xRank)[0];
  if (permutedAxis !== xRank - 1) {
    var opName = op === CumOpType.Prod ? 'cumprod' : 'cumsum';
    throw new Error("WebGL ".concat(opName, " shader expects an inner-most axis=").concat(xRank - 1, " ") + "but got axis=".concat(axis));
  }
  var permutedX = x;
  if (permutation != null) {
    permutedX = transpose({
      inputs: {
        x: x
      },
      backend: backend,
      attrs: {
        perm: permutation
      }
    });
  }
  var size = permutedX.shape[permutedAxis];
  var result = identity({
    inputs: {
      x: permutedX
    },
    backend: backend
  });
  // Use cum parallel algorithm, inspired by:
  // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
  // Note: although the algorithm is called sum, it works for any associative
  // operator with an identity.
  var numPasses = Math.ceil(Math.log2(size));
  for (var i = 0; i < numPasses; i++) {
    var program = new CumProgram(op, permutedX.shape, false, reverse);
    var customValues = [[i]];
    var prevResult = result;
    result = backend.runWebGLProgram(program, [result], result.dtype, customValues);
    backend.disposeIntermediateTensorInfo(prevResult);
  }
  // For exclusive cum, shift the end result in the direction of product or sum
  // and add 1 for product or 0 for sum to the front index.
  if (exclusive) {
    var shiftProgram = new CumProgram(op, permutedX.shape, exclusive, reverse);
    var unshifted = result;
    result = backend.runWebGLProgram(shiftProgram, [result], result.dtype);
    backend.disposeIntermediateTensorInfo(unshifted);
  }
  if (permutation != null) {
    var reversePermutation = getUndoAxesPermutation(permutation);
    var reverseTransposedResult = transpose({
      inputs: {
        x: result
      },
      backend: backend,
      attrs: {
        perm: reversePermutation
      }
    });
    backend.disposeIntermediateTensorInfo(result);
    backend.disposeIntermediateTensorInfo(permutedX);
    return reverseTransposedResult;
  }
  return result;
}
120398
120399 /**
120400 * @license
120401 * Copyright 2022 Google LLC. All Rights Reserved.
120402 * Licensed under the Apache License, Version 2.0 (the "License");
120403 * you may not use this file except in compliance with the License.
120404 * You may obtain a copy of the License at
120405 *
120406 * http://www.apache.org/licenses/LICENSE-2.0
120407 *
120408 * Unless required by applicable law or agreed to in writing, software
120409 * distributed under the License is distributed on an "AS IS" BASIS,
120410 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120411 * See the License for the specific language governing permissions and
120412 * limitations under the License.
120413 * =============================================================================
120414 */
/**
 * WebGL kernel for Cumprod: forwards to cumImpl with CumOpType.Prod and the
 * axis / exclusive / reverse attributes.
 */
function cumprod(args) {
  var attrs = args.attrs;
  return cumImpl(CumOpType.Prod, args.inputs.x, args.backend, attrs.axis, attrs.exclusive, attrs.reverse);
}
var cumprodConfig = {
  kernelName: Cumprod,
  backendName: 'webgl',
  kernelFunc: cumprod
};
120430
120431 /**
120432 * @license
120433 * Copyright 2022 Google LLC. All Rights Reserved.
120434 * Licensed under the Apache License, Version 2.0 (the "License");
120435 * you may not use this file except in compliance with the License.
120436 * You may obtain a copy of the License at
120437 *
120438 * http://www.apache.org/licenses/LICENSE-2.0
120439 *
120440 * Unless required by applicable law or agreed to in writing, software
120441 * distributed under the License is distributed on an "AS IS" BASIS,
120442 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120443 * See the License for the specific language governing permissions and
120444 * limitations under the License.
120445 * =============================================================================
120446 */
/**
 * WebGL kernel for Cumsum: forwards to cumImpl with CumOpType.Sum and the
 * axis / exclusive / reverse attributes.
 */
function cumsum(args) {
  var attrs = args.attrs;
  return cumImpl(CumOpType.Sum, args.inputs.x, args.backend, attrs.axis, attrs.exclusive, attrs.reverse);
}
var cumsumConfig = {
  kernelName: Cumsum,
  backendName: 'webgl',
  kernelFunc: cumsum
};
120462
120463 /**
120464 * @license
120465 * Copyright 2020 Google LLC. All Rights Reserved.
120466 * Licensed under the Apache License, Version 2.0 (the "License");
120467 * you may not use this file except in compliance with the License.
120468 * You may obtain a copy of the License at
120469 *
120470 * http://www.apache.org/licenses/LICENSE-2.0
120471 *
120472 * Unless required by applicable law or agreed to in writing, software
120473 * distributed under the License is distributed on an "AS IS" BASIS,
120474 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120475 * See the License for the specific language governing permissions and
120476 * limitations under the License.
120477 * =============================================================================
120478 */
120479 function denseBincount(args) {
120480 var inputs = args.inputs,
120481 backend = args.backend,
120482 attrs = args.attrs;
120483 var x = inputs.x,
120484 weights = inputs.weights;
120485 var size = attrs.size,
120486 binaryOutput = attrs.binaryOutput;
120487 if (x.shape.length === 1) {
120488 var xVals = backend.readSync(x.dataId);
120489 var weightsVals = backend.readSync(weights.dataId);
120490 var outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
120491 return backend.makeTensorInfo([size], weights.dtype, outVals);
120492 } else if (x.shape.length === 2) {
120493 var xBuf = backend.bufferSync(x);
120494 var weightsBuf = backend.bufferSync(weights);
120495 var outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput);
120496 return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
120497 }
120498 throw new Error("Error in denseBincount: input must be at most rank 2, but got rank" + "".concat(x.shape.length, "."));
120499 }
120500 var denseBincountConfig = {
120501 kernelName: DenseBincount,
120502 backendName: 'webgl',
120503 kernelFunc: denseBincount
120504 };
120505
120506 /**
120507 * @license
120508 * Copyright 2018 Google LLC. All Rights Reserved.
120509 * Licensed under the Apache License, Version 2.0 (the "License");
120510 * you may not use this file except in compliance with the License.
120511 * You may obtain a copy of the License at
120512 *
120513 * http://www.apache.org/licenses/LICENSE-2.0
120514 *
120515 * Unless required by applicable law or agreed to in writing, software
120516 * distributed under the License is distributed on an "AS IS" BASIS,
120517 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120518 * See the License for the specific language governing permissions and
120519 * limitations under the License.
120520 * =============================================================================
120521 */
120522 var DepthToSpaceProgram = /*#__PURE__*/function () {
120523 function DepthToSpaceProgram(outputShape, blockSize, dataFormat) {
120524 _classCallCheck(this, DepthToSpaceProgram);
120525 this.variableNames = ['x'];
120526 this.outputShape = [];
120527 this.outputShape = outputShape;
120528 this.blockSize = blockSize;
120529 this.dataFormat = dataFormat;
120530 this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = ".concat(this.getHeightCoordString(), ";\n int w = ").concat(this.getWidthCoordString(), ";\n int d = ").concat(this.getDepthCoordString(), ";\n\n int in_h = h / ").concat(blockSize, ";\n int offset_h = imod(h, ").concat(blockSize, ");\n int in_w = w / ").concat(blockSize, ";\n int offset_w = imod(w, ").concat(blockSize, ");\n int offset_d = (offset_h * ").concat(blockSize, " + offset_w) *\n ").concat(this.getOutputDepthSize(), ";\n int in_d = d + offset_d;\n\n float result = ").concat(this.getInputSamplingString(), ";\n setOutput(result);\n }\n ");
120531 }
120532 _createClass(DepthToSpaceProgram, [{
120533 key: "getHeightCoordString",
120534 value: function getHeightCoordString() {
120535 if (this.dataFormat === 'NHWC') {
120536 return "coords[1]";
120537 } else {
120538 return "coords[2]";
120539 }
120540 }
120541 }, {
120542 key: "getWidthCoordString",
120543 value: function getWidthCoordString() {
120544 if (this.dataFormat === 'NHWC') {
120545 return "coords[2]";
120546 } else {
120547 return "coords[3]";
120548 }
120549 }
120550 }, {
120551 key: "getDepthCoordString",
120552 value: function getDepthCoordString() {
120553 if (this.dataFormat === 'NHWC') {
120554 return "coords[3]";
120555 } else {
120556 return "coords[1]";
120557 }
120558 }
120559 }, {
120560 key: "getOutputDepthSize",
120561 value: function getOutputDepthSize() {
120562 if (this.dataFormat === 'NHWC') {
120563 return this.outputShape[3];
120564 } else {
120565 return this.outputShape[1];
120566 }
120567 }
120568 }, {
120569 key: "getInputSamplingString",
120570 value: function getInputSamplingString() {
120571 if (this.dataFormat === 'NHWC') {
120572 return "getX(b, in_h, in_w, in_d)";
120573 } else {
120574 return "getX(b, in_d, in_h, in_w)";
120575 }
120576 }
120577 }]);
120578 return DepthToSpaceProgram;
120579 }();
120580
120581 /**
120582 * @license
120583 * Copyright 2020 Google LLC. All Rights Reserved.
120584 * Licensed under the Apache License, Version 2.0 (the "License");
120585 * you may not use this file except in compliance with the License.
120586 * You may obtain a copy of the License at
120587 *
120588 * http://www.apache.org/licenses/LICENSE-2.0
120589 *
120590 * Unless required by applicable law or agreed to in writing, software
120591 * distributed under the License is distributed on an "AS IS" BASIS,
120592 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120593 * See the License for the specific language governing permissions and
120594 * limitations under the License.
120595 * =============================================================================
120596 */
120597 function depthToSpace(args) {
120598 var inputs = args.inputs,
120599 backend = args.backend,
120600 attrs = args.attrs;
120601 var x = inputs.x;
120602 var blockSize = attrs.blockSize,
120603 dataFormat = attrs.dataFormat;
120604 var batchSize = x.shape[0];
120605 var inputHeight = dataFormat === 'NHWC' ? x.shape[1] : x.shape[2];
120606 var inputWidth = dataFormat === 'NHWC' ? x.shape[2] : x.shape[3];
120607 var inputDepth = dataFormat === 'NHWC' ? x.shape[3] : x.shape[1];
120608 var outputHeight = inputHeight * blockSize;
120609 var outputWidth = inputWidth * blockSize;
120610 var outputDepth = inputDepth / (blockSize * blockSize);
120611 var outputShape = dataFormat === 'NHWC' ? [batchSize, outputHeight, outputWidth, outputDepth] : [batchSize, outputDepth, outputHeight, outputWidth];
120612 var program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);
120613 return backend.runWebGLProgram(program, [x], x.dtype);
120614 }
120615 var depthToSpaceConfig = {
120616 kernelName: DepthToSpace,
120617 backendName: 'webgl',
120618 kernelFunc: depthToSpace
120619 };
120620
120621 var DepthwiseConv2DProgram = /*#__PURE__*/_createClass(function DepthwiseConv2DProgram(convInfo) {
120622 var addBias = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
120623 var activation = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
120624 var hasPreluActivation = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
120625 var hasLeakyReluAlpha = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
120626 _classCallCheck(this, DepthwiseConv2DProgram);
120627 this.variableNames = ['x', 'W'];
120628 this.customUniforms = [{
120629 name: 'pads',
120630 type: 'ivec2'
120631 }, {
120632 name: 'strides',
120633 type: 'ivec2'
120634 }, {
120635 name: 'dilations',
120636 type: 'ivec2'
120637 }, {
120638 name: 'inDims',
120639 type: 'ivec2'
120640 }];
120641 this.outputShape = convInfo.outShape;
120642 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
120643 var filterHeight = convInfo.filterHeight;
120644 var filterWidth = convInfo.filterWidth;
120645 var channelMul = convInfo.outChannels / convInfo.inChannels;
120646 var activationSnippet = '',
120647 applyActivationSnippet = '';
120648 if (activation) {
120649 if (hasPreluActivation) {
120650 activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
120651 } else if (hasLeakyReluAlpha) {
120652 activationSnippet = "float activation(float a) {\n float b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
120653 } else {
120654 activationSnippet = "\n float activation(float x) {\n ".concat(activation, "\n }\n ");
120655 }
120656 applyActivationSnippet = "result = activation(result);";
120657 }
120658 var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
120659 if (addBias) {
120660 this.variableNames.push('bias');
120661 }
120662 if (hasPreluActivation) {
120663 this.variableNames.push('preluActivationWeights');
120664 }
120665 if (hasLeakyReluAlpha) {
120666 this.variableNames.push('leakyreluAlpha');
120667 }
120668 this.userCode = "\n ".concat(activationSnippet, "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / ").concat(channelMul, ";\n int q = d2 - d1 * ").concat(channelMul, ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n int xR = xRCorner + wR * dilations[0];\n\n if (xR < 0 || xR >= inDims[0]) {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n int xC = xCCorner + wC * dilations[1];\n\n if (xC < 0 || xC >= inDims[1]) {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n ");
120669 });
120670
120671 var DepthwiseConvPacked2DProgram = /*#__PURE__*/_createClass(function DepthwiseConvPacked2DProgram(convInfo) {
120672 var addBias = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
120673 var activation = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null;
120674 var hasPreluActivation = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : false;
120675 var hasLeakyReluAlpha = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false;
120676 _classCallCheck(this, DepthwiseConvPacked2DProgram);
120677 this.variableNames = ['x', 'W'];
120678 this.packedInputs = true;
120679 this.packedOutput = true;
120680 this.customUniforms = [{
120681 name: 'pads',
120682 type: 'ivec2'
120683 }, {
120684 name: 'strides',
120685 type: 'ivec2'
120686 }, {
120687 name: 'dilations',
120688 type: 'ivec2'
120689 }, {
120690 name: 'inDims',
120691 type: 'ivec2'
120692 }];
120693 this.outputShape = convInfo.outShape;
120694 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
120695 var channelMul = convInfo.outChannels / convInfo.inChannels;
120696 var padLeft = convInfo.padInfo.left;
120697 var strideWidth = convInfo.strideWidth;
120698 var dilationWidth = convInfo.dilationWidth;
120699 var filterHeight = convInfo.filterHeight;
120700 var filterWidth = convInfo.filterWidth;
120701 var texelsAcross = filterWidth;
120702 var mainLoop = "\n int xR; int xC; int xCOffset;\n vec4 wTexel; vec4 previous; vec4 final;";
120703 for (var c = 0; c < filterWidth; c++) {
120704 mainLoop += "\n vec4 xTexelC".concat(c * 2, ";\n int xTexelC").concat(c * 2, "Ready;\n vec4 xTexelC").concat(c * 2 + 1, ";\n int xTexelC").concat(c * 2 + 1, "Ready;\n vec4 xC").concat(c, ";");
120705 }
120706 /**
120707 * This vectorized implementation works by gathering the values needed for
120708 * each output channel's dot product into vec4's and then multiplying them
120709 * all together (this happens in the final double for-loop below). Most of
120710 * the main loop consists of constructing these vec4's with the minimum
120711 * number of texture2D calls, which means making use of all four returned
120712 * values from a texture2D call at once.
120713 */
120714 mainLoop += "\n for (int r = 0; r < ".concat(filterHeight, "; r++) {\n ");
120715 for (var _c = 0; _c < filterWidth; _c++) {
120716 mainLoop += "\n xTexelC".concat(_c * 2, " = vec4(0.0);\n xTexelC").concat(_c * 2, "Ready = 0;\n xTexelC").concat(_c * 2 + 1, " = vec4(0.0);\n xTexelC").concat(_c * 2 + 1, "Ready = 0;\n xC").concat(_c, " = vec4(0.0);");
120717 }
120718 mainLoop += "\n xR = xRCorner + r * dilations[0];\n if (xR >=0 && xR < inDims[0]) {\n ";
120719 for (var texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
120720 var colIndex = texelC * 2;
120721 mainLoop += "\n xC = xCCorner + ".concat(colIndex * dilationWidth, ";\n ");
120722 if (strideWidth === 1) {
120723 if (colIndex < filterWidth) {
120724 // If padding is odd, the outer texels have to be composed.
120725 if (padLeft % 2 === 1) {
120726 // TODO: Ensure vec4 previous does not result in redundant sample,
120727 // and avoid setting xTexelRC's that exceed the boundary in the
120728 // first place rather than resetting them to vec4(0)).
120729 // To compute xCOffset:
120730 // - If padding is odd, we must add 1 to ensure we ask for an
120731 // even-numbered row.
120732 // - We subtract 2 to access the previous texel.
120733 mainLoop += "\n xCOffset = xC + 1;\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n ");
120734 // This texel has been read in previous iteration if the dilation
120735 // is 1.
120736 if (dilationWidth === 1 && colIndex > 0) {
120737 mainLoop += "\n xC".concat(colIndex, " = vec4(xTexelC").concat(colIndex - 2, ".zw, xTexelC").concat(colIndex, ".xy);\n ");
120738 } else {
120739 mainLoop += "\n xCOffset = xC + 1 - 2;\n\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n previous.zw = vec2(0.0);\n }\n\n xC".concat(colIndex, " = vec4(previous.zw, xTexelC").concat(colIndex, ".xy);\n } else {\n xC").concat(colIndex, " = vec4(0.0, 0.0, xTexelC").concat(colIndex, ".xy);\n }\n ");
120740 }
120741 } else {
120742 // Padding is even, so xRC corresponds to a single texel.
120743 mainLoop += "\n if (xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xC").concat(colIndex, " = xTexelC").concat(colIndex, ";\n ");
120744 }
120745 if (colIndex + 1 < filterWidth) {
120746 // If dilation is even, the second entry should match the first
120747 // (either both are composed or both are single samples). But if
120748 // dilation is odd, then the second entry should be the opposite
120749 // of the first (if the first is composed, the second is a single
120750 // sample, and vice versa.)
120751 var nextTexelOffset = padLeft % 2 === 0 ? nearestLargerEven(dilationWidth) : dilationWidth;
120752 if (dilationWidth % 2 === 0 && padLeft % 2 === 1 || dilationWidth % 2 !== 0 && padLeft % 2 !== 1) {
120753 mainLoop += "\n xCOffset = xC + imod(pads[1], 2) + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n ");
120754 // If dilation > 1 then the xRC's will not be able to share any
120755 // values, so each xRC will require two unique calls to getX.
120756 if (dilationWidth > 1) {
120757 mainLoop += "\n xCOffset -= 2;\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n xC".concat(colIndex + 1, " = vec4(previous.zw, xTexelC").concat(colIndex + 1, ".xy);\n } else {\n xC").concat(colIndex + 1, " = vec4(0.0, 0.0, xTexelC").concat(colIndex + 1, ".xy);\n }\n ");
120758 } else {
120759 mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".xy);\n ");
120760 }
120761 } else {
120762 // If dilation is 1 and padding is odd, we have already read the
120763 // texel when constructing the previous x value. Here we can
120764 // simply skip the texture read.
120765 if (nextTexelOffset === 1) {
120766 mainLoop += "\n xC".concat(colIndex + 1, " = xTexelC").concat(colIndex, ";\n ");
120767 } else {
120768 mainLoop += "\n xCOffset = xC + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex + 1, " = xTexelC").concat(colIndex + 1, ";\n ");
120769 }
120770 }
120771 }
120772 }
120773 } else {
120774 // stride === 2
120775 if (colIndex < filterWidth) {
120776 // Depending on whether padLeft is even or odd, we want either the
120777 // xy or zw channels from X texels for xC${colIndex}. If padLeft is
120778 // even, xC${colIndex +1} is simply the zw channels of texels we've
120779 // already sampled. But if padLeft is odd, xC{$c + 1}.zw will
120780 // need to come from the xy channels of a new texel, hence the `
120781 // vec4
120782 // final` initialized below.
120783 if (padLeft % 2 === 1) {
120784 mainLoop += "\n xCOffset = xC + 1 - strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xC + 1, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xC + 2 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n ");
120785 if (colIndex + 1 < filterWidth) {
120786 mainLoop += "\n final = vec4(0.0);\n xCOffset = xC + 1 + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1]) {\n final = getX(batch, xR, xCOffset, d1);\n }\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex + 1, ".xy, final.xy);\n ");
120787 }
120788 } else {
120789 mainLoop += "\n if(xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xCOffset = xC + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(\n xTexelC").concat(colIndex, ".xy, xTexelC").concat(colIndex + 1, ".xy);\n ");
120790 if (colIndex + 1 < filterWidth) {
120791 mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n ");
120792 }
120793 }
120794 }
120795 }
120796 // localize the dotProd accumulation within the loop, the theory is for
120797 // GPU with limited cache, accumulate sum across large amount of
120798 // veriables will cause lots of cache misses. (i.e. 5x5 filter will have
120799 // 50 variables)
120800 if (colIndex < filterWidth) {
120801 mainLoop += "\n wTexel = getW(r, ".concat(colIndex, ", d1, q);\n dotProd += xC").concat(colIndex, " * vec4(wTexel.xz, wTexel.xz);\n ");
120802 if (colIndex + 1 < filterWidth) {
120803 mainLoop += "\n wTexel = getW(r, ".concat(colIndex + 1, ", d1, q);\n dotProd += xC").concat(colIndex + 1, " * vec4(wTexel.xz, wTexel.xz);\n ");
120804 }
120805 }
120806 }
120807 mainLoop += "\n }\n ";
120808 mainLoop += "\n }\n ";
120809 var activationSnippet = '',
120810 applyActivationSnippet = '';
120811 if (activation) {
120812 if (hasPreluActivation) {
120813 activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
120814 } else if (hasLeakyReluAlpha) {
120815 activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
120816 } else {
120817 activationSnippet = "vec4 activation(vec4 x) {\n ".concat(activation, "\n }");
120818 }
120819 applyActivationSnippet = "result = activation(result);";
120820 }
120821 var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
120822 if (addBias) {
120823 this.variableNames.push('bias');
120824 }
120825 if (hasPreluActivation) {
120826 this.variableNames.push('preluActivationWeights');
120827 }
120828 if (hasLeakyReluAlpha) {
120829 this.variableNames.push('leakyreluAlpha');
120830 }
120831 this.userCode = "\n ".concat(activationSnippet, "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / ").concat(channelMul, ";\n int q = d2 - d1 * ").concat(channelMul, ";\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.\n vec4 dotProd = vec4(0.000000000000001);\n\n ").concat(mainLoop, "\n\n vec4 result = dotProd - vec4(0.000000000000001);\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n ");
120832 });
120833
120834 /**
120835 * @license
120836 * Copyright 2020 Google LLC. All Rights Reserved.
120837 * Licensed under the Apache License, Version 2.0 (the "License");
120838 * you may not use this file except in compliance with the License.
120839 * You may obtain a copy of the License at
120840 *
120841 * http://www.apache.org/licenses/LICENSE-2.0
120842 *
120843 * Unless required by applicable law or agreed to in writing, software
120844 * distributed under the License is distributed on an "AS IS" BASIS,
120845 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120846 * See the License for the specific language governing permissions and
120847 * limitations under the License.
120848 * =============================================================================
120849 */
120850 function depthwiseConv2dNative(args) {
120851 var inputs = args.inputs,
120852 backend = args.backend,
120853 attrs = args.attrs;
120854 var x = inputs.x,
120855 filter = inputs.filter;
120856 var strides = attrs.strides,
120857 pad = attrs.pad,
120858 dilations = attrs.dilations,
120859 dimRoundingMode = attrs.dimRoundingMode;
120860 var $dilations = dilations;
120861 if ($dilations == null) {
120862 $dilations = [1, 1];
120863 }
120864 assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
120865 return 'Error in depthwiseConv2d: Either strides or dilations must be ' + "1. Got strides ".concat(strides, " and dilations '").concat($dilations, "'");
120866 });
120867 var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
120868 var program;
120869 if (env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1) {
120870 program = new DepthwiseConvPacked2DProgram(convInfo);
120871 } else {
120872 program = new DepthwiseConv2DProgram(convInfo);
120873 }
120874 var customValues = [[convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth]];
120875 return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
120876 }
120877 var depthwiseConv2dNativeConfig = {
120878 kernelName: DepthwiseConv2dNative,
120879 backendName: 'webgl',
120880 kernelFunc: depthwiseConv2dNative
120881 };
120882
120883 /**
120884 * @license
120885 * Copyright 2018 Google LLC. All Rights Reserved.
120886 * Licensed under the Apache License, Version 2.0 (the "License");
120887 * you may not use this file except in compliance with the License.
120888 * You may obtain a copy of the License at
120889 *
120890 * http://www.apache.org/licenses/LICENSE-2.0
120891 *
120892 * Unless required by applicable law or agreed to in writing, software
120893 * distributed under the License is distributed on an "AS IS" BASIS,
120894 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120895 * See the License for the specific language governing permissions and
120896 * limitations under the License.
120897 * =============================================================================
120898 */
120899 var DepthwiseConv2DDerFilterProgram = /*#__PURE__*/_createClass(function DepthwiseConv2DDerFilterProgram(convInfo) {
120900 _classCallCheck(this, DepthwiseConv2DDerFilterProgram);
120901 this.variableNames = ['x', 'dy'];
120902 this.outputShape = convInfo.filterShape;
120903 var strideHeight = convInfo.strideHeight;
120904 var strideWidth = convInfo.strideWidth;
120905 var padTop = convInfo.padInfo.top;
120906 var padLeft = convInfo.padInfo.left;
120907 var channelMul = convInfo.outChannels / convInfo.inChannels;
120908 this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * ".concat(channelMul, " + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < ").concat(convInfo.batchSize, "; b++) {\n for (int yR = 0; yR < ").concat(convInfo.outHeight, "; yR++) {\n int xR = wR + yR * ").concat(strideHeight, " - ").concat(padTop, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int yC = 0; yC < ").concat(convInfo.outWidth, "; yC++) {\n int xC = wC + yC * ").concat(strideWidth, " - ").concat(padLeft, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n ");
120909 });
120910 var DepthwiseConv2DDerInputProgram = /*#__PURE__*/_createClass(function DepthwiseConv2DDerInputProgram(convInfo) {
120911 _classCallCheck(this, DepthwiseConv2DDerInputProgram);
120912 this.variableNames = ['dy', 'W'];
120913 this.outputShape = convInfo.inShape;
120914 var filterHeight = convInfo.filterHeight;
120915 var filterWidth = convInfo.filterWidth;
120916 var strideHeight = convInfo.strideHeight;
120917 var strideWidth = convInfo.strideWidth;
120918 var padTop = filterHeight - 1 - convInfo.padInfo.top;
120919 var padLeft = filterWidth - 1 - convInfo.padInfo.left;
120920 var channelMul = convInfo.outChannels / convInfo.inChannels;
120921 this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ").concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ").concat(filterWidth, " - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < ").concat(channelMul, "; dm++) {\n int d2 = d1 * ").concat(channelMul, " + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n ");
120922 });
120923
120924 /**
120925 * @license
120926 * Copyright 2020 Google LLC. All Rights Reserved.
120927 * Licensed under the Apache License, Version 2.0 (the "License");
120928 * you may not use this file except in compliance with the License.
120929 * You may obtain a copy of the License at
120930 *
120931 * http://www.apache.org/licenses/LICENSE-2.0
120932 *
120933 * Unless required by applicable law or agreed to in writing, software
120934 * distributed under the License is distributed on an "AS IS" BASIS,
120935 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120936 * See the License for the specific language governing permissions and
120937 * limitations under the License.
120938 * =============================================================================
120939 */
120940 function depthwiseConv2dNativeBackpropFilter(args) {
120941 var inputs = args.inputs,
120942 backend = args.backend,
120943 attrs = args.attrs;
120944 var x = inputs.x,
120945 dy = inputs.dy;
120946 var strides = attrs.strides,
120947 dilations = attrs.dilations,
120948 pad = attrs.pad,
120949 dimRoundingMode = attrs.dimRoundingMode,
120950 filterShape = attrs.filterShape;
120951 var convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
120952 var program = new DepthwiseConv2DDerFilterProgram(convInfo);
120953 return backend.runWebGLProgram(program, [x, dy], 'float32');
120954 }
120955 var depthwiseConv2dNativeBackpropFilterConfig = {
120956 kernelName: DepthwiseConv2dNativeBackpropFilter,
120957 backendName: 'webgl',
120958 kernelFunc: depthwiseConv2dNativeBackpropFilter
120959 };
120960
120961 /**
120962 * @license
120963 * Copyright 2020 Google LLC. All Rights Reserved.
120964 * Licensed under the Apache License, Version 2.0 (the "License");
120965 * you may not use this file except in compliance with the License.
120966 * You may obtain a copy of the License at
120967 *
120968 * http://www.apache.org/licenses/LICENSE-2.0
120969 *
120970 * Unless required by applicable law or agreed to in writing, software
120971 * distributed under the License is distributed on an "AS IS" BASIS,
120972 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120973 * See the License for the specific language governing permissions and
120974 * limitations under the License.
120975 * =============================================================================
120976 */
/**
 * WebGL kernel for DepthwiseConv2dNativeBackpropInput: computes the gradient
 * of a depthwise 2-D convolution with respect to its input.
 *
 * @param {Object} args kernel arguments: `inputs.dy` (upstream gradient),
 *     `inputs.filter`, `attrs` (strides, dilations, pad, dimRoundingMode,
 *     inputShape) and the WebGL `backend`.
 * @returns {Object} float32 TensorInfo produced by
 *     DepthwiseConv2DDerInputProgram.
 */
function depthwiseConv2dNativeBackpropInput(args) {
  var attrs = args.attrs;
  var dy = args.inputs.dy;
  var filter = args.inputs.filter;
  // The forward input is not available here, so its shape comes from attrs.
  var convInfo = computeConv2DInfo(attrs.inputShape, filter.shape, attrs.strides, attrs.dilations, attrs.pad, attrs.dimRoundingMode, true /* depthwise */);
  var program = new DepthwiseConv2DDerInputProgram(convInfo);
  return args.backend.runWebGLProgram(program, [dy, filter], 'float32');
}
// Registers the kernel above for the 'webgl' backend.
var depthwiseConv2dNativeBackpropInputConfig = {
  kernelName: DepthwiseConv2dNativeBackpropInput,
  backendName: 'webgl',
  kernelFunc: depthwiseConv2dNativeBackpropInput
};
120997
120998 /**
120999 * @license
121000 * Copyright 2019 Google LLC. All Rights Reserved.
121001 * Licensed under the Apache License, Version 2.0 (the "License");
121002 * you may not use this file except in compliance with the License.
121003 * You may obtain a copy of the License at
121004 *
121005 * http://www.apache.org/licenses/LICENSE-2.0
121006 *
121007 * Unless required by applicable law or agreed to in writing, software
121008 * distributed under the License is distributed on an "AS IS" BASIS,
121009 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121010 * See the License for the specific language governing permissions and
121011 * limitations under the License.
121012 * =============================================================================
121013 */
/**
 * Shader program producing a size x size matrix: entry (i, i) is X[i] and
 * every off-diagonal entry is 0.0. X is read with a single index, so it is
 * expected to be a 1-D tensor of length `size`.
 */
var DiagProgram = /*#__PURE__*/_createClass(function DiagProgram(size) {
  _classCallCheck(this, DiagProgram);
  this.variableNames = ['X'];
  this.outputShape = [size, size];
  this.userCode = [
    '',
    ' void main() {',
    ' ivec2 coords = getOutputCoords();',
    ' float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;',
    ' setOutput(val);',
    ' }',
    ' '
  ].join('\n');
});
121020
/**
 * WebGL kernel for Diag.
 *
 * Flattens x to 1-D, runs DiagProgram to build an xSize x xSize matrix with x
 * on the diagonal, then reshapes that matrix to [...x.shape, ...x.shape].
 *
 * @param {Object} args `inputs.x` and the WebGL `backend`.
 * @returns {Object} TensorInfo of shape x.shape concatenated with itself.
 */
function diag(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var xSize = sizeFromShape(x.shape);
  var flat = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: [xSize] } });
  var res = backend.runWebGLProgram(new DiagProgram(xSize), [flat], flat.dtype);
  // Output rank is twice the input rank: x.shape followed by x.shape.
  var out = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: x.shape.concat(x.shape) } });
  // Both intermediates are released; only `out` is handed back.
  [flat, res].forEach(function (intermediate) {
    backend.disposeIntermediateTensorInfo(intermediate);
  });
  return out;
}
// Registers the kernel above for the 'webgl' backend.
var diagConfig = {
  kernelName: Diag,
  backendName: 'webgl',
  kernelFunc: diag
};
121056
121057 /**
121058 * @license
121059 * Copyright 2017 Google LLC. All Rights Reserved.
121060 * Licensed under the Apache License, Version 2.0 (the "License");
121061 * you may not use this file except in compliance with the License.
121062 * You may obtain a copy of the License at
121063 *
121064 * http://www.apache.org/licenses/LICENSE-2.0
121065 *
121066 * Unless required by applicable law or agreed to in writing, software
121067 * distributed under the License is distributed on an "AS IS" BASIS,
121068 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121069 * See the License for the specific language governing permissions and
121070 * limitations under the License.
121071 * =============================================================================
121072 */
/**
 * Shader program for grayscale morphological dilation (NHWC, per channel).
 *
 * For each output position it starts from -3.4e38 and keeps the maximum of
 * x + W over the (strided, padded, dilated) filter window. Window positions
 * that fall outside the input height/width are skipped.
 *
 * @param {Object} convInfo geometry from computeDilation2DInfo (outShape,
 *     inHeight/inWidth, padInfo, strides, filter size, dilations).
 */
var Dilation2DProgram = /*#__PURE__*/_createClass(function Dilation2DProgram(convInfo) {
  _classCallCheck(this, Dilation2DProgram);
  this.variableNames = ['x', 'W'];
  this.outputShape = convInfo.outShape;
  // Geometry baked into the shader source as literal constants.
  var padTop = convInfo.padInfo.top;
  var padLeft = convInfo.padInfo.left;
  var inHeight = convInfo.inHeight;
  var inWidth = convInfo.inWidth;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var filterHeight = convInfo.filterHeight;
  var filterWidth = convInfo.filterWidth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  this.userCode = "\n const ivec2 strides = ivec2(".concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n const float neg_infinity = -3.4e38;\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.w;\n ivec2 outTopLeftCorner =\n coords.yz * strides - pads;\n int hBeg = outTopLeftCorner.x;\n int wBeg = outTopLeftCorner.y;\n\n float curVal = neg_infinity;\n for (int h = 0; h < ").concat(filterHeight, "; h++) {\n int hIn = hBeg + h * ").concat(dilationHeight, ";\n\n if (hIn >= 0 && hIn < ").concat(inHeight, ") {\n for (int w = 0; w < ").concat(filterWidth, "; w++) {\n int wIn = wBeg + w * ").concat(dilationWidth, ";\n\n if (wIn >= 0 && wIn < ").concat(inWidth, ") {\n float xVal = getX(batch, hIn, wIn, d1);\n float wVal = getW(h, w, d1);\n\n float val = xVal + wVal;\n if (val > curVal) {\n curVal = val;\n }\n }\n }\n }\n }\n\n float result = curVal;\n setOutput(result);\n }\n ");
});
121090
121091 /**
121092 * @license
121093 * Copyright 2020 Google LLC. All Rights Reserved.
121094 * Licensed under the Apache License, Version 2.0 (the "License");
121095 * you may not use this file except in compliance with the License.
121096 * You may obtain a copy of the License at
121097 *
121098 * http://www.apache.org/licenses/LICENSE-2.0
121099 *
121100 * Unless required by applicable law or agreed to in writing, software
121101 * distributed under the License is distributed on an "AS IS" BASIS,
121102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121103 * See the License for the specific language governing permissions and
121104 * limitations under the License.
121105 * =============================================================================
121106 */
/**
 * WebGL kernel for Dilation2D (NHWC data format).
 *
 * @param {Object} args `inputs.x`, `inputs.filter`, `attrs` (strides, pad,
 *     dilations) and the WebGL `backend`.
 * @returns {Object} float32 result reshaped to convInfo.outShape.
 */
function dilation2D(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var x = args.inputs.x;
  var filter = args.inputs.filter;
  var convInfo = computeDilation2DInfo(x.shape, filter.shape, attrs.strides, attrs.pad, 'NHWC' /* dataFormat */, attrs.dilations);
  var raw = backend.runWebGLProgram(new Dilation2DProgram(convInfo), [x, filter], 'float32');
  var outReshaped = reshape({ inputs: { x: raw }, backend: backend, attrs: { shape: convInfo.outShape } });
  // The program output is only an intermediate; the reshaped view is returned.
  backend.disposeIntermediateTensorInfo(raw);
  return outReshaped;
}
// Registers the kernel above for the 'webgl' backend.
var dilation2DConfig = {
  kernelName: Dilation2D,
  backendName: 'webgl',
  kernelFunc: dilation2D
};
121137
/**
 * WebGL kernel for Einsum.
 *
 * Decodes the equation, then follows the contraction order returned by
 * getEinsumComputePath. Within each step, every operand listed for that step
 * goes through three stages:
 *   1. it is transposed by getEinsumPermutation's permutation;
 *   2. size-1 axes are inserted at that helper's `expandDims` positions;
 *   3. it is multiplied into the running product `out`.
 * After every step except the last, the axis named by `path[step]` (when
 * non-negative) is reduced with `sum`.
 *
 * @param {Object} args `inputs` (array of TensorInfos), `attrs.equation` and
 *     the WebGL `backend`.
 * @returns {Object} the final TensorInfo. It may be one of the inputs when no
 *     transpose, reshape or multiply was needed. It is null if there are no
 *     steps.
 */
function einsum(args) {
  var backend = args.backend;
  var tensors = args.inputs;
  var decoded = decodeEinsumEquation(args.attrs.equation, tensors.length);
  var allDims = decoded.allDims;
  var summedDims = decoded.summedDims;
  var idDims = decoded.idDims;
  checkEinsumDimSizes(allDims.length, idDims, tensors);
  var computePath = getEinsumComputePath(summedDims, idDims);
  var path = computePath.path;
  var steps = computePath.steps;

  var out = null;
  var numDimsRemaining = allDims.length;
  // Every tensor created here; all but the returned one are freed at the end.
  var tensorsToDispose = [];

  for (var stepIdx = 0; stepIdx < steps.length; ++stepIdx) {
    var stepTerms = steps[stepIdx];
    for (var termIdx = 0; termIdx < stepTerms.length; ++termIdx) {
      var idTerm = stepTerms[termIdx];
      var permInfo = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
      var perm = permInfo.permutationIndices;
      var dimsToExpand = permInfo.expandDims;

      // Reorder axes, skipping the transpose when it would be a no-op.
      var operand = tensors[idTerm];
      if (!isIdentityPermutation(perm)) {
        operand = transpose({ inputs: { x: operand }, backend: backend, attrs: { perm: perm } });
        tensorsToDispose.push(operand);
      }

      // Insert size-1 axes at the requested positions; reshape only if needed.
      var targetShape = operand.shape.slice();
      for (var e = 0; e < dimsToExpand.length; ++e) {
        targetShape.splice(dimsToExpand[e], 0, 1);
      }
      if (!arraysEqual(operand.shape, targetShape)) {
        operand = reshape({ inputs: { x: operand }, backend: backend, attrs: { shape: targetShape } });
        tensorsToDispose.push(operand);
      }

      if (out === null) {
        out = operand;
      } else {
        out = multiply({ inputs: { a: operand, b: out }, backend: backend });
        tensorsToDispose.push(out);
      }
    }

    var isLastStep = stepIdx === steps.length - 1;
    if (!isLastStep) {
      if (path[stepIdx] >= 0) {
        // Offset the axis by the number of earlier steps, i.e. how many times
        // numDimsRemaining has already been decremented.
        out = sum({
          inputs: { x: out },
          backend: backend,
          attrs: { axis: path[stepIdx] - (allDims.length - numDimsRemaining), keepDims: false }
        });
        tensorsToDispose.push(out);
      }
      numDimsRemaining--;
    }
  }

  for (var d = 0; d < tensorsToDispose.length; ++d) {
    if (tensorsToDispose[d] !== out) {
      backend.disposeIntermediateTensorInfo(tensorsToDispose[d]);
    }
  }
  return out;
}
// Registers the kernel above for the 'webgl' backend.
var einsumConfig = {
  kernelName: Einsum,
  backendName: 'webgl',
  kernelFunc: einsum
};
121247
121248 /**
121249 * @license
121250 * Copyright 2020 Google LLC. All Rights Reserved.
121251 * Licensed under the Apache License, Version 2.0 (the "License");
121252 * you may not use this file except in compliance with the License.
121253 * You may obtain a copy of the License at
121254 *
121255 * http://www.apache.org/licenses/LICENSE-2.0
121256 *
121257 * Unless required by applicable law or agreed to in writing, software
121258 * distributed under the License is distributed on an "AS IS" BASIS,
121259 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121260 * See the License for the specific language governing permissions and
121261 * limitations under the License.
121262 * =============================================================================
121263 */
// GLSL snippet for ELU: x when x >= 0, exp(x) - 1 otherwise.
var ELU = "return (x >= 0.0) ? x : (exp(x) - 1.0);";
// Packed (vec4) variant: the same rule applied to each of the r, g, b, a
// channels in turn.
var ELU_PACKED = "\n vec4 result;\n\n" + ['r', 'g', 'b', 'a'].map(function (c) {
  return " result." + c + " = (x." + c + " >= 0.0) ? x." + c + " : (exp(x." + c + ") - 1.0);\n";
}).join('') + "\n return result;\n";
var elu = unaryKernelFunc({
  opSnippet: ELU,
  packedOpSnippet: ELU_PACKED
});
// Registers the ELU kernel for the 'webgl' backend.
var eluConfig = {
  kernelName: Elu$1,
  backendName: 'webgl',
  kernelFunc: elu
};
121275
121276 /**
121277 * @license
121278 * Copyright 2020 Google LLC. All Rights Reserved.
121279 * Licensed under the Apache License, Version 2.0 (the "License");
121280 * you may not use this file except in compliance with the License.
121281 * You may obtain a copy of the License at
121282 *
121283 * http://www.apache.org/licenses/LICENSE-2.0
121284 *
121285 * Unless required by applicable law or agreed to in writing, software
121286 * distributed under the License is distributed on an "AS IS" BASIS,
121287 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121288 * See the License for the specific language governing permissions and
121289 * limitations under the License.
121290 * =============================================================================
121291 */
// ELU derivative expressed in terms of the forward output y. For y < 0,
// y = exp(x) - 1, so d/dx = exp(x) = y + 1. Here `a` and `b` presumably bind
// to dy and y in input order — confirm against BinaryOpProgram.
var ELU_DER = "return (b >= 0.0) ? a : a * (b + 1.0);";
var ELU_DER_PACKED = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n";
/**
 * WebGL kernel for EluGrad.
 *
 * @param {Object} args `inputs.dy` (upstream gradient), `inputs.y` (ELU
 *     output) and the WebGL `backend`.
 * @returns {Object} TensorInfo with dy's dtype. The packed program is used
 *     when the WEBGL_PACK_BINARY_OPERATIONS flag is set.
 */
var eluGrad = function eluGrad(args) {
  var dy = args.inputs.dy;
  var y = args.inputs.y;
  var program;
  if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
    program = new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape);
  } else {
    program = new BinaryOpProgram(ELU_DER, dy.shape, y.shape);
  }
  return args.backend.runWebGLProgram(program, [dy, y], dy.dtype);
};
// Registers the kernel above for the 'webgl' backend.
var eluGradConfig = {
  kernelName: EluGrad,
  backendName: 'webgl',
  kernelFunc: eluGrad
};
121307
121308 /**
121309 * @license
121310 * Copyright 2020 Google LLC. All Rights Reserved.
121311 * Licensed under the Apache License, Version 2.0 (the "License");
121312 * you may not use this file except in compliance with the License.
121313 * You may obtain a copy of the License at
121314 *
121315 * http://www.apache.org/licenses/LICENSE-2.0
121316 *
121317 * Unless required by applicable law or agreed to in writing, software
121318 * distributed under the License is distributed on an "AS IS" BASIS,
121319 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121320 * See the License for the specific language governing permissions and
121321 * limitations under the License.
121322 * =============================================================================
121323 */
// Component-wise equality for packed (vec4) inputs, as a float 0/1 vec4.
var PACKED_EQUAL = "\n return vec4(equal(a, b));\n";
// Scalar equality: 1.0 when a == b, else 0.0.
var EQUAL = "return float(a == b);";
// Output dtype is 'bool'; equalImplCPU is passed as the CPU implementation.
var equal = binaryKernelFunc({
  opSnippet: EQUAL,
  packedOpSnippet: PACKED_EQUAL,
  dtype: 'bool',
  cpuKernelImpl: equalImplCPU
});
// Registers the Equal kernel for the 'webgl' backend.
var equalConfig = {
  kernelName: Equal,
  backendName: 'webgl',
  kernelFunc: equal
};
121337
121338 /**
121339 * @license
121340 * Copyright 2020 Google LLC. All Rights Reserved.
121341 * Licensed under the Apache License, Version 2.0 (the "License");
121342 * you may not use this file except in compliance with the License.
121343 * You may obtain a copy of the License at
121344 *
121345 * http://www.apache.org/licenses/LICENSE-2.0
121346 *
121347 * Unless required by applicable law or agreed to in writing, software
121348 * distributed under the License is distributed on an "AS IS" BASIS,
121349 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121350 * See the License for the specific language governing permissions and
121351 * limitations under the License.
121352 * =============================================================================
121353 */
// GLSL snippet approximating erf(x) with the Abramowitz & Stegun rational
// form: sign(x) * (1 - (a1 t + a2 t^2 + ... + a5 t^5) * exp(-x^2)),
// where t = 1 / (1 + p |x|). The coefficients come from ERF_P and ERF_A1..A5.
var ERF = "\n // Error function is calculated approximately with elementary function.\n" +
  " // See \"Handbook of Mathematical Functions with Formulas,\n" +
  " // Graphs, and Mathematical Tables\", Abramowitz and Stegun.\n" +
  " float p = " + ERF_P + ";\n" +
  " float a1 = " + ERF_A1 + ";\n" +
  " float a2 = " + ERF_A2 + ";\n" +
  " float a3 = " + ERF_A3 + ";\n" +
  " float a4 = " + ERF_A4 + ";\n" +
  " float a5 = " + ERF_A5 + ";\n" +
  "\n" +
  " float sign = sign(x);\n" +
  " x = abs(x);\n" +
  " float t = 1.0 / (1.0 + p * x);\n" +
  " return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));\n";
var erf = unaryKernelFunc({
  opSnippet: ERF
});
// Registers the Erf kernel for the 'webgl' backend (no packed variant).
var erfConfig = {
  kernelName: Erf,
  backendName: 'webgl',
  kernelFunc: erf
};
121363
121364 /**
121365 * @license
121366 * Copyright 2020 Google LLC. All Rights Reserved.
121367 * Licensed under the Apache License, Version 2.0 (the "License");
121368 * you may not use this file except in compliance with the License.
121369 * You may obtain a copy of the License at
121370 *
121371 * http://www.apache.org/licenses/LICENSE-2.0
121372 *
121373 * Unless required by applicable law or agreed to in writing, software
121374 * distributed under the License is distributed on an "AS IS" BASIS,
121375 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121376 * See the License for the specific language governing permissions and
121377 * limitations under the License.
121378 * =============================================================================
121379 */
// Scalar exp, prefixed with the shared CHECK_NAN_SNIPPET_UNARY guard
// (presumably an early NaN pass-through — see its definition).
var EXP = CHECK_NAN_SNIPPET_UNARY + "\n return exp(x);\n";
// Packed exp: channels whose input is NaN keep x itself; every other channel
// gets exp(x).
var EXP_PACKED = "\n vec4 result = exp(x);\n bvec4 isNaN = isnan(x);\n" + ['r', 'g', 'b', 'a'].map(function (c) {
  return " result." + c + " = isNaN." + c + " ? x." + c + " : result." + c + ";\n";
}).join('') + "\n return result;\n";
// Output is always float32; expImplCPU is passed as the CPU implementation.
var exp = unaryKernelFunc({
  opSnippet: EXP,
  packedOpSnippet: EXP_PACKED,
  cpuKernelImpl: expImplCPU,
  dtype: 'float32'
});
// Registers the Exp kernel for the 'webgl' backend.
var expConfig = {
  kernelName: Exp,
  backendName: 'webgl',
  kernelFunc: exp
};
121393
121394 /**
121395 * @license
121396 * Copyright 2020 Google LLC. All Rights Reserved.
121397 * Licensed under the Apache License, Version 2.0 (the License);
121398 * you may not use this file except in compliance with the License.
121399 * You may obtain a copy of the License at
121400 *
121401 * http://www.apache.org/licenses/LICENSE-2.0
121402 *
121403 * Unless required by applicable law or agreed to in writing, software
121404 * distributed under the License is distributed on an AS IS BASIS,
121405 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121406 * See the License for the specific language governing permissions and
121407 * limitations under the License.
121408 * =============================================================================
121409 */
/**
 * WebGL kernel for ExpandDims: inserts a size-1 axis into `input` at position
 * `dim`, implemented as a reshape.
 *
 * @param {Object} args `inputs.input`, `attrs.dim` and the WebGL `backend`.
 *     `dim` may be negative, in which case it counts from the end; accepted
 *     values are [-(rank + 1), rank].
 * @returns {Object} result of `reshape` with the new shape.
 * @throws (via assert$1) when `dim` is outside [-(rank + 1), rank].
 */
function expandDims(args) {
  var inputs = args.inputs,
    attrs = args.attrs,
    backend = args.backend;
  var dim = attrs.dim;
  var input = inputs.input;
  var inputRank = input.shape.length;
  // Enforce both ends of the interval named in the message. Without the upper
  // check, Array#splice silently clamps an out-of-range `dim` to the end of
  // the shape, so an invalid axis would be appended instead of rejected.
  assert$1(-(inputRank + 1) <= dim && dim <= inputRank, function () {
    return "Axis must be in the interval [".concat(-(inputRank + 1), ", ").concat(inputRank, "]");
  });
  // Negative value is counted from the tail of rank.
  var $dim = dim < 0 ? inputRank + dim + 1 : dim;
  // Copy first so the caller's shape array is never mutated.
  var newShape = input.shape.slice();
  newShape.splice($dim, 0, 1);
  return reshape({
    inputs: {
      x: input
    },
    backend: backend,
    attrs: {
      shape: newShape
    }
  });
}
// Registers the kernel above for the 'webgl' backend.
var expandDimsConfig = {
  kernelName: ExpandDims,
  backendName: 'webgl',
  kernelFunc: expandDims
};
121442
121443 /**
121444 * @license
121445 * Copyright 2020 Google LLC. All Rights Reserved.
121446 * Licensed under the Apache License, Version 2.0 (the "License");
121447 * you may not use this file except in compliance with the License.
121448 * You may obtain a copy of the License at
121449 *
121450 * http://www.apache.org/licenses/LICENSE-2.0
121451 *
121452 * Unless required by applicable law or agreed to in writing, software
121453 * distributed under the License is distributed on an "AS IS" BASIS,
121454 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121455 * See the License for the specific language governing permissions and
121456 * limitations under the License.
121457 * =============================================================================
121458 */
// exp(x) - 1. The same text serves as the packed snippet because GLSL's exp
// and the scalar subtraction also apply component-wise to vec4.
var EXPM1 = "return exp(x) - 1.0;";
var expm1 = unaryKernelFunc({
  opSnippet: EXPM1,
  packedOpSnippet: EXPM1,
  cpuKernelImpl: expm1ImplCPU
});
// Registers the Expm1 kernel for the 'webgl' backend.
var expm1Config = {
  kernelName: Expm1,
  backendName: 'webgl',
  kernelFunc: expm1
};
121470
121471 /**
121472 * @license
121473 * Copyright 2018 Google LLC. All Rights Reserved.
121474 * Licensed under the Apache License, Version 2.0 (the "License");
121475 * you may not use this file except in compliance with the License.
121476 * You may obtain a copy of the License at
121477 *
121478 * http://www.apache.org/licenses/LICENSE-2.0
121479 *
121480 * Unless required by applicable law or agreed to in writing, software
121481 * distributed under the License is distributed on an "AS IS" BASIS,
121482 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121483 * See the License for the specific language governing permissions and
121484 * limitations under the License.
121485 * =============================================================================
121486 */
/**
 * Shader program computing one component (real or imaginary) of a naive
 * O(N^2) DFT along the inner axis of a [batch, N] complex input.
 *
 * @param {string} component 'real' or 'imag'; anything else throws.
 * @param {number[]} inputShape 2-D shape [batch, N]; also the output shape.
 * @param {boolean} inverse when true, uses +2*pi in the exponent and divides
 *     each term by N; otherwise uses -2*pi and divides by 1.
 */
var FFTProgram = /*#__PURE__*/_createClass(function FFTProgram(component, inputShape, inverse) {
  _classCallCheck(this, FFTProgram);
  this.variableNames = ['real', 'imag'];
  var innerDim = inputShape[1];
  this.outputShape = inputShape;
  var exponentMultiplierSnippet = (inverse ? "2.0" : "-2.0") + " * " + Math.PI;
  var resultDenominator = inverse ? innerDim + ".0" : '1.0';
  // Complex multiply (real + i*imag) * (expR + i*expI), keeping one part.
  var opString;
  switch (component) {
    case 'real':
      opString = 'return real * expR - imag * expI;';
      break;
    case 'imag':
      opString = 'return real * expI + imag * expR;';
      break;
    default:
      throw new Error("FFT component must be either \"real\" or \"imag\", got ".concat(component, "."));
  }
  this.userCode = "\n const float exponentMultiplier = ".concat(exponentMultiplierSnippet, ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n ").concat(opString, "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(").concat(innerDim, ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < ").concat(innerDim, "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / ").concat(resultDenominator, ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n ");
});
121504
121505 /**
121506 * @license
121507 * Copyright 2020 Google LLC. All Rights Reserved.
121508 * Licensed under the Apache License, Version 2.0 (the "License");
121509 * you may not use this file except in compliance with the License.
121510 * You may obtain a copy of the License at
121511 *
121512 * http://www.apache.org/licenses/LICENSE-2.0
121513 *
121514 * Unless required by applicable law or agreed to in writing, software
121515 * distributed under the License is distributed on an "AS IS" BASIS,
121516 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121517 * See the License for the specific language governing permissions and
121518 * limitations under the License.
121519 * =============================================================================
121520 */
/**
 * Shared implementation of FFT / IFFT on the WebGL backend.
 *
 * The transform runs along the innermost axis of the complex tensor `x`. All
 * outer axes are collapsed into one batch axis first and restored afterwards.
 *
 * @param {Object} x complex TensorInfo. Its texData must carry
 *     `complexTensorInfos.real` and `.imag`.
 * @param {boolean} inverse true for the inverse transform.
 * @param {Object} backend WebGL backend (texData, runWebGLProgram,
 *     disposeIntermediateTensorInfo are used).
 * @returns {Object} complex TensorInfo reshaped back to x.shape.
 */
function fftImpl(x, inverse, backend) {
  var xData = backend.texData.get(x.dataId);
  var innerDimensionSize = x.shape[x.shape.length - 1];
  // Collapse all outer dimensions to a single batch dimension.
  var batch = sizeFromShape(x.shape) / innerDimensionSize;
  var input2D = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: [batch, innerDimensionSize] } });
  var xShape = input2D.shape;
  var realProgram = new FFTProgram('real', xShape, inverse);
  var imagProgram = new FFTProgram('imag', xShape, inverse);
  // Both programs read x's real and imaginary planes, viewed with the 2-D shape.
  var parts = xData.complexTensorInfos;
  var inputs = [parts.real, parts.imag].map(function (part) {
    return { dataId: part.dataId, dtype: part.dtype, shape: xShape };
  });
  var realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
  var imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
  var complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend: backend });
  backend.disposeIntermediateTensorInfo(realPart);
  backend.disposeIntermediateTensorInfo(imagPart);
  var complexOutputReshaped = reshape({ inputs: { x: complexOutput }, backend: backend, attrs: { shape: x.shape } });
  [input2D, complexOutput].forEach(function (intermediate) {
    backend.disposeIntermediateTensorInfo(intermediate);
  });
  return complexOutputReshaped;
}
121572
121573 /**
121574 * @license
121575 * Copyright 2020 Google LLC. All Rights Reserved.
121576 * Licensed under the Apache License, Version 2.0 (the "License");
121577 * you may not use this file except in compliance with the License.
121578 * You may obtain a copy of the License at
121579 *
121580 * http://www.apache.org/licenses/LICENSE-2.0
121581 *
121582 * Unless required by applicable law or agreed to in writing, software
121583 * distributed under the License is distributed on an "AS IS" BASIS,
121584 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121585 * See the License for the specific language governing permissions and
121586 * limitations under the License.
121587 * =============================================================================
121588 */
/**
 * WebGL kernel for FFT: forward transform along the innermost axis.
 *
 * @param {Object} args `inputs.input` (complex tensor) and the WebGL `backend`.
 * @returns {Object} complex TensorInfo with the input's shape.
 */
function fft(args) {
  return fftImpl(args.inputs.input, false /* inverse */, args.backend);
}
// Registers the kernel above for the 'webgl' backend.
var fftConfig = {
  kernelName: FFT,
  backendName: 'webgl',
  kernelFunc: fft
};
121600
121601 /**
121602 * @license
121603 * Copyright 2019 Google LLC. All Rights Reserved.
121604 * Licensed under the Apache License, Version 2.0 (the "License");
121605 * you may not use this file except in compliance with the License.
121606 * You may obtain a copy of the License at
121607 *
121608 * http://www.apache.org/licenses/LICENSE-2.0
121609 *
121610 * Unless required by applicable law or agreed to in writing, software
121611 * distributed under the License is distributed on an "AS IS" BASIS,
121612 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121613 * See the License for the specific language governing permissions and
121614 * limitations under the License.
121615 * =============================================================================
121616 */
/**
 * Shader program writing the float uniform `value` to every output element.
 *
 * @param {number[]} shape output shape.
 * @param {*} value not read here. The fill value is supplied at run time
 *     through the `value` custom uniform.
 * NOTE(review): variableNames lists 'x' although `fill` runs this program
 * with no inputs; kept unchanged.
 */
var FillProgram = /*#__PURE__*/_createClass(function FillProgram(shape, value) {
  _classCallCheck(this, FillProgram);
  this.outputShape = shape;
  this.customUniforms = [{ name: 'value', type: 'float' }];
  this.variableNames = ['x'];
  this.userCode = "\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n ";
});
121628
121629 /**
121630 * @license
121631 * Copyright 2020 Google LLC. All Rights Reserved.
121632 * Licensed under the Apache License, Version 2.0 (the "License");
121633 * you may not use this file except in compliance with the License.
121634 * You may obtain a copy of the License at
121635 *
121636 * http://www.apache.org/licenses/LICENSE-2.0
121637 *
121638 * Unless required by applicable law or agreed to in writing, software
121639 * distributed under the License is distributed on an "AS IS" BASIS,
121640 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121641 * See the License for the specific language governing permissions and
121642 * limitations under the License.
121643 * =============================================================================
121644 */
/**
 * WebGL kernel for Fill: creates a tensor of `shape` filled with `value`.
 *
 * @param {Object} args `attrs.shape`, `attrs.value`, optional `attrs.dtype`
 *     (inferred from `value` when falsy) and the WebGL `backend`.
 * @returns {Object} TensorInfo. String tensors are built in CPU memory via
 *     makeTensorInfo; all other dtypes run FillProgram on the GPU.
 */
function fill(args) {
  var backend = args.backend;
  var shape = args.attrs.shape;
  var value = args.attrs.value;
  var dtype = args.attrs.dtype || inferDtype(value);
  if (dtype !== 'string') {
    // The fill value is passed as the program's single custom uniform.
    return backend.runWebGLProgram(new FillProgram(shape, value), [], dtype, [[value]]);
  }
  // String type should be handled in CPU memory.
  var values = getArrayFromDType(dtype, sizeFromShape(shape));
  values.fill(value);
  return backend.makeTensorInfo(shape, dtype, values);
}
// Registers the kernel above for the 'webgl' backend.
var fillConfig = {
  kernelName: Fill,
  backendName: 'webgl',
  kernelFunc: fill
};
121668
121669 /**
121670 * @license
121671 * Copyright 2020 Google LLC. All Rights Reserved.
121672 * Licensed under the Apache License, Version 2.0 (the "License");
121673 * you may not use this file except in compliance with the License.
121674 * You may obtain a copy of the License at
121675 *
121676 * http://www.apache.org/licenses/LICENSE-2.0
121677 *
121678 * Unless required by applicable law or agreed to in writing, software
121679 * distributed under the License is distributed on an "AS IS" BASIS,
121680 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121681 * See the License for the specific language governing permissions and
121682 * limitations under the License.
121683 * =============================================================================
121684 */
/**
 * Shader program mirroring a 4-D image batch horizontally.
 * Output pixel (b, y, x, c) reads Image(b, y, width - x - 1, c).
 *
 * @param {number[]} imageShape 4-D shape; index 2 is used as the width and
 *     the whole shape is the output shape.
 * NOTE(review): assuming x ranges over [0, width), coordX is always in range,
 * so the shader's else-branch looks unreachable; kept as-is.
 */
var FlipLeftRightProgram = /*#__PURE__*/_createClass(function FlipLeftRightProgram(imageShape) {
  _classCallCheck(this, FlipLeftRightProgram);
  this.variableNames = ['Image'];
  this.outputShape = imageShape;
  var imageWidth = imageShape[2];
  this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n\n int coordX = ".concat(imageWidth, " - x - 1;\n float outputValue;\n if(coordX >= 0 && coordX < ").concat(imageWidth, ") {\n outputValue = getImage(coords[0], coords[1], coordX, coords[3]);\n } else {\n outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);\n }\n setOutput(outputValue);\n }\n ");
});
121693
121694 /**
121695 * @license
121696 * Copyright 2020 Google LLC. All Rights Reserved.
121697 * Licensed under the Apache License, Version 2.0 (the "License");
121698 * you may not use this file except in compliance with the License.
121699 * You may obtain a copy of the License at
121700 *
121701 * http://www.apache.org/licenses/LICENSE-2.0
121702 *
121703 * Unless required by applicable law or agreed to in writing, software
121704 * distributed under the License is distributed on an "AS IS" BASIS,
121705 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121706 * See the License for the specific language governing permissions and
121707 * limitations under the License.
121708 * =============================================================================
121709 */
/**
 * WebGL registration for FlipLeftRight. The kernel runs FlipLeftRightProgram
 * on `inputs.image` and keeps the image's dtype.
 */
var flipLeftRightConfig = {
  kernelName: FlipLeftRight,
  backendName: 'webgl',
  kernelFunc: function kernelFunc(_ref) {
    var image = _ref.inputs.image;
    var program = new FlipLeftRightProgram(image.shape);
    return _ref.backend.runWebGLProgram(program, [image], image.dtype);
  }
};
121723
121724 /**
121725 * @license
121726 * Copyright 2020 Google LLC. All Rights Reserved.
121727 * Licensed under the Apache License, Version 2.0 (the "License");
121728 * you may not use this file except in compliance with the License.
121729 * You may obtain a copy of the License at
121730 *
121731 * http://www.apache.org/licenses/LICENSE-2.0
121732 *
121733 * Unless required by applicable law or agreed to in writing, software
121734 * distributed under the License is distributed on an "AS IS" BASIS,
121735 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121736 * See the License for the specific language governing permissions and
121737 * limitations under the License.
121738 * =============================================================================
121739 */
121740 var FLOOR = "return floor(x);";
121741 var floor = unaryKernelFunc({
121742 opSnippet: FLOOR,
121743 packedOpSnippet: FLOOR,
121744 cpuKernelImpl: floorImplCPU
121745 });
121746 var floorConfig = {
121747 kernelName: Floor,
121748 backendName: 'webgl',
121749 kernelFunc: floor
121750 };
121751
121752 /**
121753 * @license
121754 * Copyright 2020 Google LLC. All Rights Reserved.
121755 * Licensed under the Apache License, Version 2.0 (the "License");
121756 * you may not use this file except in compliance with the License.
121757 * You may obtain a copy of the License at
121758 *
121759 * http://www.apache.org/licenses/LICENSE-2.0
121760 *
121761 * Unless required by applicable law or agreed to in writing, software
121762 * distributed under the License is distributed on an "AS IS" BASIS,
121763 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121764 * See the License for the specific language governing permissions and
121765 * limitations under the License.
121766 * =============================================================================
121767 */
121768 // We use native integer division to deal with floating point imprecision. Since
121769 // we implement floor division and glsl implements truncated division, we
121770 // correct for this by subtracting 1 from result when the result is negative and
121771 // there is a remainder.
121772 var INT_DIV = "\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n";
121773 var INT_DIV_PACKED = "\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n";
121774 var floorDiv = binaryKernelFunc({
121775 opSnippet: INT_DIV,
121776 packedOpSnippet: INT_DIV_PACKED,
121777 dtype: 'int32'
121778 });
121779 var floorDivConfig = {
121780 kernelName: FloorDiv,
121781 backendName: 'webgl',
121782 kernelFunc: floorDiv
121783 };
121784
121785 var FromPixelsProgram = /*#__PURE__*/_createClass(function FromPixelsProgram(outputShape) {
121786 _classCallCheck(this, FromPixelsProgram);
121787 this.variableNames = ['A'];
121788 var glsl = getGlslDifferences();
121789 var _outputShape = _slicedToArray(outputShape, 2),
121790 height = _outputShape[0],
121791 width = _outputShape[1];
121792 this.outputShape = outputShape;
121793 this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(".concat(width, ".0, ").concat(height, ".0);\n\n vec4 values = ").concat(glsl.texture2D, "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n setOutput(floor(value * 255.0 + 0.5));\n }\n ");
121794 });
121795
121796 var FromPixelsPackedProgram = /*#__PURE__*/_createClass(function FromPixelsPackedProgram(outputShape) {
121797 _classCallCheck(this, FromPixelsPackedProgram);
121798 this.variableNames = ['A'];
121799 this.packedInputs = false;
121800 this.packedOutput = true;
121801 var glsl = getGlslDifferences();
121802 var _outputShape = _slicedToArray(outputShape, 2),
121803 height = _outputShape[0],
121804 width = _outputShape[1];
121805 this.outputShape = outputShape;
121806 this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n\n vec4 result = vec4(0.);\n\n for(int row=0; row<=1; row++) {\n for(int col=0; col<=1; col++) {\n texC = coords[1] + row;\n depth = coords[2] + col;\n\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(".concat(width, ".0, ").concat(height, ".0);\n vec4 values = ").concat(glsl.texture2D, "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n result[row * 2 + col] = floor(value * 255.0 + 0.5);\n }\n }\n\n ").concat(glsl.output, " = result;\n }\n ");
121807 });
121808
121809 var fromPixelsConfig = {
121810 kernelName: FromPixels,
121811 backendName: 'webgl',
121812 kernelFunc: fromPixels
121813 };
121814 var fromPixels2DContext;
121815 var willReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
121816 function fromPixels(args) {
121817 var inputs = args.inputs,
121818 backend = args.backend,
121819 attrs = args.attrs;
121820 var pixels = inputs.pixels;
121821 var numChannels = attrs.numChannels;
121822 var isVideo = typeof HTMLVideoElement !== 'undefined' && pixels instanceof HTMLVideoElement;
121823 var isImage = typeof HTMLImageElement !== 'undefined' && pixels instanceof HTMLImageElement;
121824 var _ref = isVideo ? [pixels.videoWidth, pixels.videoHeight] : [pixels.width, pixels.height],
121825 _ref2 = _slicedToArray(_ref, 2),
121826 width = _ref2[0],
121827 height = _ref2[1];
121828 var texShape = [height, width];
121829 var outShape = [height, width, numChannels];
121830 if (isImage || isVideo) {
121831 var newWillReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
121832 if (fromPixels2DContext == null || newWillReadFrequently !== willReadFrequently) {
121833 willReadFrequently = newWillReadFrequently;
121834 fromPixels2DContext = document.createElement('canvas').getContext('2d', {
121835 willReadFrequently: willReadFrequently
121836 });
121837 }
121838 fromPixels2DContext.canvas.width = width;
121839 fromPixels2DContext.canvas.height = height;
121840 fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
121841 pixels = fromPixels2DContext.canvas;
121842 }
121843 var tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');
121844 // This is a byte texture with pixels.
121845 backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
121846 backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
121847 var program = env().getBool('WEBGL_PACK') ? new FromPixelsPackedProgram(outShape) : new FromPixelsProgram(outShape);
121848 var res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
121849 backend.disposeData(tempPixelHandle.dataId);
121850 return res;
121851 }
121852
121853 /**
121854 * @license
121855 * Copyright 2020 Google LLC. All Rights Reserved.
121856 * Licensed under the Apache License, Version 2.0 (the "License");
121857 * you may not use this file except in compliance with the License.
121858 * You may obtain a copy of the License at
121859 *
121860 * http://www.apache.org/licenses/LICENSE-2.0
121861 *
121862 * Unless required by applicable law or agreed to in writing, software
121863 * distributed under the License is distributed on an "AS IS" BASIS,
121864 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121865 * See the License for the specific language governing permissions and
121866 * limitations under the License.
121867 * =============================================================================
121868 */
121869 function fusedConv2d(args) {
121870 var inputs = args.inputs,
121871 backend = args.backend,
121872 attrs = args.attrs;
121873 var x = inputs.x,
121874 filter = inputs.filter,
121875 bias = inputs.bias,
121876 preluActivationWeights = inputs.preluActivationWeights;
121877 var strides = attrs.strides,
121878 pad = attrs.pad,
121879 dataFormat = attrs.dataFormat,
121880 dilations = attrs.dilations,
121881 dimRoundingMode = attrs.dimRoundingMode,
121882 activation = attrs.activation,
121883 leakyreluAlpha = attrs.leakyreluAlpha;
121884 var $dataFormat = convertConv2DDataFormat(dataFormat);
121885 var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
121886 var out;
121887 var intermediates = [];
121888 var hasBias = bias != null;
121889 var hasPreluActivationWeights = preluActivationWeights != null;
121890 var hasLeakyreluAlpha = activation === 'leakyrelu';
121891 var prepareInputs = function prepareInputs() {
121892 var inputs = [x, filter];
121893 // If the input is a 1-D tensor, align it with the channels.
121894 //
121895 // For fusedConv2d, the inputs (x, W, bias, preluActivationWeights) are
121896 // supposed to be aligned with the dataFormat. The 4-D tensor inputs or
121897 // scalar inputs are originally aligned, but the 1-D tensor inputs are
121898 // supposed to be aligned with the channels (only bias and PReLU activation
121899 // weights could be a 1-D tensor).
121900 var alignInputWithDataFormat = function alignInputWithDataFormat(input, dataFormat) {
121901 if (dataFormat === 'NCHW' && input.shape.length === 1 && input.shape[0] !== 1) {
121902 var alignedInput = reshape({
121903 inputs: {
121904 x: input
121905 },
121906 backend: backend,
121907 attrs: {
121908 shape: [input.shape[0], 1, 1]
121909 }
121910 });
121911 intermediates.push(alignedInput);
121912 return alignedInput;
121913 }
121914 return input;
121915 };
121916 if (hasBias) {
121917 inputs.push(alignInputWithDataFormat(bias, dataFormat));
121918 }
121919 if (hasPreluActivationWeights) {
121920 inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));
121921 }
121922 if (hasLeakyreluAlpha) {
121923 var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
121924 inputs.push($leakyreluAlpha);
121925 intermediates.push($leakyreluAlpha);
121926 }
121927 return inputs;
121928 };
121929 if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
121930 out = conv2dByMatMul({
121931 x: x,
121932 filter: filter,
121933 convInfo: convInfo,
121934 backend: backend,
121935 bias: bias,
121936 activation: activation,
121937 preluActivationWeights: preluActivationWeights,
121938 leakyreluAlpha: leakyreluAlpha
121939 });
121940 } else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' && env().getBool('WEBGL_EXP_CONV')) {
121941 var fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
121942 var program = new Conv2DPackedProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
121943 var customValues = [[convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth]];
121944 var _inputs = prepareInputs();
121945 out = backend.runWebGLProgram(program, _inputs, 'float32', customValues);
121946 } else if (env().getBool('WEBGL_CONV_IM2COL')) {
121947 out = conv2dWithIm2Row({
121948 x: x,
121949 filter: filter,
121950 convInfo: convInfo,
121951 backend: backend,
121952 bias: bias,
121953 activation: activation,
121954 preluActivationWeights: preluActivationWeights,
121955 leakyreluAlpha: leakyreluAlpha
121956 });
121957 } else {
121958 var _fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
121959 var _program = new Conv2DProgram(convInfo, hasBias, _fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
121960 var _inputs2 = prepareInputs();
121961 out = backend.runWebGLProgram(_program, _inputs2, 'float32');
121962 }
121963 var outReshaped = reshape({
121964 inputs: {
121965 x: out
121966 },
121967 backend: backend,
121968 attrs: {
121969 shape: convInfo.outShape
121970 }
121971 });
121972 intermediates.push(out);
121973 intermediates.forEach(function (t) {
121974 return backend.disposeIntermediateTensorInfo(t);
121975 });
121976 return outReshaped;
121977 }
121978 var fusedConv2DConfig = {
121979 kernelName: FusedConv2D,
121980 backendName: 'webgl',
121981 kernelFunc: fusedConv2d
121982 };
121983
121984 /**
121985 * @license
121986 * Copyright 2020 Google LLC. All Rights Reserved.
121987 * Licensed under the Apache License, Version 2.0 (the "License");
121988 * you may not use this file except in compliance with the License.
121989 * You may obtain a copy of the License at
121990 *
121991 * http://www.apache.org/licenses/LICENSE-2.0
121992 *
121993 * Unless required by applicable law or agreed to in writing, software
121994 * distributed under the License is distributed on an "AS IS" BASIS,
121995 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121996 * See the License for the specific language governing permissions and
121997 * limitations under the License.
121998 * =============================================================================
121999 */
122000 function fusedDepthwiseConv2D(args) {
122001 var inputs = args.inputs,
122002 backend = args.backend,
122003 attrs = args.attrs;
122004 var x = inputs.x,
122005 filter = inputs.filter,
122006 bias = inputs.bias,
122007 preluActivationWeights = inputs.preluActivationWeights;
122008 var strides = attrs.strides,
122009 pad = attrs.pad,
122010 dilations = attrs.dilations,
122011 dimRoundingMode = attrs.dimRoundingMode,
122012 activation = attrs.activation,
122013 leakyreluAlpha = attrs.leakyreluAlpha;
122014 var intermediates = [];
122015 var $dilations = dilations;
122016 if ($dilations == null) {
122017 $dilations = [1, 1];
122018 }
122019 assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
122020 return 'Error in depthwiseConv2d: Either strides or dilations must be ' + "1. Got strides ".concat(strides, " and dilations '").concat($dilations, "'");
122021 });
122022 var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
122023 var shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1;
122024 var fusedActivation = activation ? mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) : null;
122025 var programInputs = [x, filter];
122026 var hasBias = bias != null;
122027 var hasPreluActivationWeights = preluActivationWeights != null;
122028 var hasLeakyreluAlpha = activation === 'leakyrelu';
122029 if (hasBias) {
122030 programInputs.push(bias);
122031 }
122032 if (hasPreluActivationWeights) {
122033 programInputs.push(preluActivationWeights);
122034 }
122035 if (hasLeakyreluAlpha) {
122036 var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
122037 programInputs.push($leakyreluAlpha);
122038 intermediates.push($leakyreluAlpha);
122039 }
122040 var program;
122041 if (shouldPackDepthwiseConv) {
122042 program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
122043 } else {
122044 program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
122045 }
122046 var customValues = [[convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth]];
122047 var result = backend.runWebGLProgram(program, programInputs, 'float32', customValues);
122048 intermediates.forEach(function (t) {
122049 return backend.disposeIntermediateTensorInfo(t);
122050 });
122051 return result;
122052 }
122053 var fusedDepthwiseConv2DConfig = {
122054 kernelName: FusedDepthwiseConv2D,
122055 backendName: 'webgl',
122056 kernelFunc: fusedDepthwiseConv2D
122057 };
122058
122059 var GatherNDProgram = /*#__PURE__*/_createClass(function GatherNDProgram(sliceDim, strides, shape, paramsShape) {
122060 _classCallCheck(this, GatherNDProgram);
122061 this.sliceDim = sliceDim;
122062 this.strides = strides;
122063 this.paramsShape = paramsShape;
122064 this.variableNames = ['x', 'indices'];
122065 this.outputShape = shape;
122066 var dtype = getCoordsDataType(shape.length);
122067 var mainLoop = "\n int index;";
122068 for (var j = 0; j < this.sliceDim; j++) {
122069 mainLoop += "\n index = round(getIndices(coords[0], ".concat(j, "));\n out_of_bounds = out_of_bounds || index < 0;\n out_of_bounds = out_of_bounds || index >= ").concat(this.paramsShape[j], ";\n flattenIndex += index * ").concat(this.strides[j], ";");
122070 }
122071 this.userCode = "\n void main() {\n ".concat(dtype, " coords = getOutputCoords();\n int flattenIndex = 0;\n bool out_of_bounds = false;\n\n ").concat(mainLoop, "\n\n setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));\n }\n ");
122072 });
122073
122074 function gatherNd(args) {
122075 var inputs = args.inputs,
122076 backend = args.backend;
122077 var params = inputs.params,
122078 indices = inputs.indices;
122079 var indicesShape = indices.shape;
122080 var sliceRank = indicesShape[indicesShape.length - 1];
122081 var paramsSize = sizeFromShape(params.shape);
122082 var _backend_util$prepare = prepareAndValidate(params, indices),
122083 _backend_util$prepare2 = _slicedToArray(_backend_util$prepare, 4),
122084 resultShape = _backend_util$prepare2[0],
122085 numSlices = _backend_util$prepare2[1],
122086 sliceSize = _backend_util$prepare2[2],
122087 strides = _backend_util$prepare2[3];
122088 var flattenIndices = reshape({
122089 inputs: {
122090 x: indices
122091 },
122092 backend: backend,
122093 attrs: {
122094 shape: [numSlices, sliceRank]
122095 }
122096 });
122097 var flattenX = reshape({
122098 inputs: {
122099 x: params
122100 },
122101 backend: backend,
122102 attrs: {
122103 shape: [sizeFromShape(params.shape) / sliceSize, sliceSize]
122104 }
122105 });
122106 if (backend.shouldExecuteOnCPU([params, indices]) || params.dtype === 'string') {
122107 var indicesData = backend.readSync(indices.dataId);
122108 var paramsBuf = backend.bufferSync(params);
122109 var outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
122110 return backend.makeTensorInfo(resultShape, params.dtype, outValue.values);
122111 }
122112 var program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize], params.shape);
122113 var res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype);
122114 var reshaped = reshape({
122115 inputs: {
122116 x: res
122117 },
122118 backend: backend,
122119 attrs: {
122120 shape: resultShape
122121 }
122122 });
122123 backend.disposeIntermediateTensorInfo(flattenIndices);
122124 backend.disposeIntermediateTensorInfo(flattenX);
122125 backend.disposeIntermediateTensorInfo(res);
122126 return reshaped;
122127 }
122128 var gatherNdConfig = {
122129 kernelName: GatherNd,
122130 backendName: 'webgl',
122131 kernelFunc: gatherNd
122132 };
122133
122134 var GatherProgram = /*#__PURE__*/_createClass(function GatherProgram(aShape, outputShape) {
122135 _classCallCheck(this, GatherProgram);
122136 this.variableNames = ['A', 'indices'];
122137 this.outputShape = outputShape;
122138 this.rank = outputShape.length;
122139 var dtype = getCoordsDataType(this.rank);
122140 var sourceCoords = getSourceCoords$1(aShape, 2);
122141 this.userCode = "\n void main() {\n ".concat(dtype, " resRC = getOutputCoords();\n int index = int(getIndices(resRC.x, resRC.z));\n float inBounds = (index >= 0) && (index < ").concat(aShape[2], ") ? 1.0 : 0.0;\n setOutput(inBounds * getA(").concat(sourceCoords, "));\n }\n ");
122142 });
122143 // The input and output are always flattened into rank 4 tensors.
122144 function getSourceCoords$1(aShape, axis) {
122145 var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
122146 var sourceCoords = [];
122147 for (var i = 0; i < aShape.length; i++) {
122148 if (i === 2) {
122149 sourceCoords.push('index');
122150 } else {
122151 sourceCoords.push("".concat(currentCoords[i]));
122152 }
122153 }
122154 return sourceCoords.join();
122155 }
122156
122157 /**
122158 * @license
122159 * Copyright 2020 Google LLC. All Rights Reserved.
122160 * Licensed under the Apache License, Version 2.0 (the "License");
122161 * you may not use this file except in compliance with the License.
122162 * You may obtain a copy of the License at
122163 *
122164 * http://www.apache.org/licenses/LICENSE-2.0
122165 *
122166 * Unless required by applicable law or agreed to in writing, software
122167 * distributed under the License is distributed on an "AS IS" BASIS,
122168 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122169 * See the License for the specific language governing permissions and
122170 * limitations under the License.
122171 * =============================================================================
122172 */
122173 function gatherV2(args) {
122174 var inputs = args.inputs,
122175 backend = args.backend,
122176 attrs = args.attrs;
122177 var x = inputs.x,
122178 indices = inputs.indices;
122179 var axis = attrs.axis,
122180 batchDims = attrs.batchDims;
122181 var parsedAxis = parseAxisParam(axis, x.shape)[0];
122182 if (env().get('DEBUG')) {
122183 // In debug mode, throw error when any index is out of bound.
122184 // Otherwise, just fill out of bounds with zeroes.
122185 var indicesVals = backend.readSync(indices.dataId);
122186 var axisDim = x.shape[parsedAxis];
122187 var _loop = function _loop() {
122188 var index = indicesVals[i];
122189 assert$1(index <= axisDim - 1 && index >= 0, function () {
122190 return "GatherV2: the index value ".concat(index, " is not in [0, ").concat(axisDim - 1, "]");
122191 });
122192 };
122193 for (var i = 0; i < indicesVals.length; ++i) {
122194 _loop();
122195 }
122196 }
122197 var shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims);
122198 var indicesSize = sizeFromShape(indices.shape);
122199 var toDispose = [];
122200 var flattenX = reshape({
122201 inputs: {
122202 x: x
122203 },
122204 backend: backend,
122205 attrs: {
122206 shape: [shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize, shapeInfo.sliceSize]
122207 }
122208 });
122209 var flattenIndex = reshape({
122210 inputs: {
122211 x: indices
122212 },
122213 backend: backend,
122214 attrs: {
122215 shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize]
122216 }
122217 });
122218 toDispose.push(flattenX);
122219 toDispose.push(flattenIndex);
122220 var flattenOutputShape = [shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize, shapeInfo.sliceSize];
122221 if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') {
122222 var indicesBuf = backend.bufferSync(flattenIndex);
122223 var xBuf = backend.bufferSync(flattenX);
122224 var outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape);
122225 toDispose.forEach(function (t) {
122226 return backend.disposeIntermediateTensorInfo(t);
122227 });
122228 return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
122229 }
122230 var program = new GatherProgram(flattenX.shape, flattenOutputShape);
122231 var res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype);
122232 toDispose.push(res);
122233 var reshaped = reshape({
122234 inputs: {
122235 x: res
122236 },
122237 backend: backend,
122238 attrs: {
122239 shape: shapeInfo.outputShape
122240 }
122241 });
122242 toDispose.forEach(function (t) {
122243 return backend.disposeIntermediateTensorInfo(t);
122244 });
122245 return reshaped;
122246 }
122247 var gatherV2Config = {
122248 kernelName: GatherV2,
122249 backendName: 'webgl',
122250 kernelFunc: gatherV2
122251 };
122252
122253 /**
122254 * @license
122255 * Copyright 2020 Google LLC. All Rights Reserved.
122256 * Licensed under the Apache License, Version 2.0 (the "License");
122257 * you may not use this file except in compliance with the License.
122258 * You may obtain a copy of the License at
122259 *
122260 * http://www.apache.org/licenses/LICENSE-2.0
122261 *
122262 * Unless required by applicable law or agreed to in writing, software
122263 * distributed under the License is distributed on an "AS IS" BASIS,
122264 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122265 * See the License for the specific language governing permissions and
122266 * limitations under the License.
122267 * =============================================================================
122268 */
122269 var GREATER = "return float(a > b);";
122270 var GREATER_PACKED = "\n return vec4(greaterThan(a, b));\n";
122271 var greater = binaryKernelFunc({
122272 opSnippet: GREATER,
122273 packedOpSnippet: GREATER_PACKED,
122274 cpuKernelImpl: greaterImplCPU,
122275 dtype: 'bool'
122276 });
122277 var greaterConfig = {
122278 kernelName: Greater,
122279 backendName: 'webgl',
122280 kernelFunc: greater
122281 };
122282
122283 /**
122284 * @license
122285 * Copyright 2020 Google LLC. All Rights Reserved.
122286 * Licensed under the Apache License, Version 2.0 (the "License");
122287 * you may not use this file except in compliance with the License.
122288 * You may obtain a copy of the License at
122289 *
122290 * http://www.apache.org/licenses/LICENSE-2.0
122291 *
122292 * Unless required by applicable law or agreed to in writing, software
122293 * distributed under the License is distributed on an "AS IS" BASIS,
122294 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122295 * See the License for the specific language governing permissions and
122296 * limitations under the License.
122297 * =============================================================================
122298 */
122299 var GREATER_EQUAL = "return float(a >= b);";
122300 var GREATER_EQUAL_PACKED = "\n return vec4(greaterThanEqual(a, b));\n";
122301 var greaterEqual = binaryKernelFunc({
122302 opSnippet: GREATER_EQUAL,
122303 packedOpSnippet: GREATER_EQUAL_PACKED,
122304 dtype: 'bool',
122305 cpuKernelImpl: greaterEqualImplCPU
122306 });
122307 var greaterEqualConfig = {
122308 kernelName: GreaterEqual,
122309 backendName: 'webgl',
122310 kernelFunc: greaterEqual
122311 };
122312
122313 /**
122314 * @license
122315 * Copyright 2020 Google LLC. All Rights Reserved.
122316 * Licensed under the Apache License, Version 2.0 (the "License");
122317 * you may not use this file except in compliance with the License.
122318 * You may obtain a copy of the License at
122319 *
122320 * http://www.apache.org/licenses/LICENSE-2.0
122321 *
122322 * Unless required by applicable law or agreed to in writing, software
122323 * distributed under the License is distributed on an "AS IS" BASIS,
122324 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122325 * See the License for the specific language governing permissions and
122326 * limitations under the License.
122327 * =============================================================================
122328 */
122329 function ifft(args) {
122330 var inputs = args.inputs,
122331 backend = args.backend;
122332 var input = inputs.input;
122333 return fftImpl(input, true /* inverse */, backend);
122334 }
122335 var ifftConfig = {
122336 kernelName: IFFT,
122337 backendName: 'webgl',
122338 kernelFunc: ifft
122339 };
122340
122341 /**
122342 * @license
122343 * Copyright 2020 Google LLC. All Rights Reserved.
122344 * Licensed under the Apache License, Version 2.0 (the "License");
122345 * you may not use this file except in compliance with the License.
122346 * You may obtain a copy of the License at
122347 *
122348 * http://www.apache.org/licenses/LICENSE-2.0
122349 *
122350 * Unless required by applicable law or agreed to in writing, software
122351 * distributed under the License is distributed on an "AS IS" BASIS,
122352 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122353 * See the License for the specific language governing permissions and
122354 * limitations under the License.
122355 * =============================================================================
122356 */
122357 var IS_FINITE = "return float(!isnan(x) && !isinf(x));";
122358 var isFinite$1 = unaryKernelFunc({
122359 opSnippet: IS_FINITE,
122360 dtype: 'bool'
122361 });
122362 var isFiniteConfig = {
122363 kernelName: IsFinite,
122364 backendName: 'webgl',
122365 kernelFunc: isFinite$1
122366 };
122367
122368 /**
122369 * @license
122370 * Copyright 2020 Google LLC. All Rights Reserved.
122371 * Licensed under the Apache License, Version 2.0 (the "License");
122372 * you may not use this file except in compliance with the License.
122373 * You may obtain a copy of the License at
122374 *
122375 * http://www.apache.org/licenses/LICENSE-2.0
122376 *
122377 * Unless required by applicable law or agreed to in writing, software
122378 * distributed under the License is distributed on an "AS IS" BASIS,
122379 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122380 * See the License for the specific language governing permissions and
122381 * limitations under the License.
122382 * =============================================================================
122383 */
122384 var IS_INF = "return float(isinf(x));";
122385 var isInf = unaryKernelFunc({
122386 opSnippet: IS_INF,
122387 dtype: 'bool'
122388 });
122389 var isInfConfig = {
122390 kernelName: IsInf,
122391 backendName: 'webgl',
122392 kernelFunc: isInf
122393 };
122394
122395 /**
122396 * @license
122397 * Copyright 2020 Google LLC. All Rights Reserved.
122398 * Licensed under the Apache License, Version 2.0 (the "License");
122399 * you may not use this file except in compliance with the License.
122400 * You may obtain a copy of the License at
122401 *
122402 * http://www.apache.org/licenses/LICENSE-2.0
122403 *
122404 * Unless required by applicable law or agreed to in writing, software
122405 * distributed under the License is distributed on an "AS IS" BASIS,
122406 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122407 * See the License for the specific language governing permissions and
122408 * limitations under the License.
122409 * =============================================================================
122410 */
122411 var IS_NAN = "return float(isnan(x));";
122412 var isNaN$1 = unaryKernelFunc({
122413 opSnippet: IS_NAN,
122414 dtype: 'bool'
122415 });
122416 var isNaNConfig = {
122417 kernelName: IsNan,
122418 backendName: 'webgl',
122419 kernelFunc: isNaN$1
122420 };
122421
122422 /**
122423 * @license
122424 * Copyright 2020 Google LLC. All Rights Reserved.
122425 * Licensed under the Apache License, Version 2.0 (the "License");
122426 * you may not use this file except in compliance with the License.
122427 * You may obtain a copy of the License at
122428 *
122429 * http://www.apache.org/licenses/LICENSE-2.0
122430 *
122431 * Unless required by applicable law or agreed to in writing, software
122432 * distributed under the License is distributed on an "AS IS" BASIS,
122433 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122434 * See the License for the specific language governing permissions and
122435 * limitations under the License.
122436 * =============================================================================
122437 */
/**
 * WebGL `Less` kernel: element-wise `a < b`, built with dtype 'bool'.
 * LESS is the scalar shader snippet; LESS_PACKED works on vec4 operands via
 * `lessThan`. `lessImplCPU` is supplied as the CPU implementation.
 */
var LESS = 'return float(a < b);';
var LESS_PACKED = '\n return vec4(lessThan(a, b));\n';
var less = binaryKernelFunc({
  dtype: 'bool',
  cpuKernelImpl: lessImplCPU,
  packedOpSnippet: LESS_PACKED,
  opSnippet: LESS
});
// Registration record binding the kernel above to `Less` on 'webgl'.
var lessConfig = {
  kernelFunc: less,
  backendName: 'webgl',
  kernelName: Less
};
122451
122452 /**
122453 * @license
122454 * Copyright 2020 Google LLC. All Rights Reserved.
122455 * Licensed under the Apache License, Version 2.0 (the "License");
122456 * you may not use this file except in compliance with the License.
122457 * You may obtain a copy of the License at
122458 *
122459 * http://www.apache.org/licenses/LICENSE-2.0
122460 *
122461 * Unless required by applicable law or agreed to in writing, software
122462 * distributed under the License is distributed on an "AS IS" BASIS,
122463 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122464 * See the License for the specific language governing permissions and
122465 * limitations under the License.
122466 * =============================================================================
122467 */
/**
 * WebGL `LessEqual` kernel: element-wise `a <= b`, built with dtype 'bool'.
 * LESS_EQUAL is the scalar shader snippet; LESS_EQUAL_PACKED works on vec4
 * operands via `lessThanEqual`. `lessEqualImplCPU` is supplied as the CPU
 * implementation.
 */
var LESS_EQUAL = 'return float(a <= b);';
var LESS_EQUAL_PACKED = '\n return vec4(lessThanEqual(a, b));\n';
var lessEqual = binaryKernelFunc({
  dtype: 'bool',
  cpuKernelImpl: lessEqualImplCPU,
  packedOpSnippet: LESS_EQUAL_PACKED,
  opSnippet: LESS_EQUAL
});
// Registration record binding the kernel above to `LessEqual` on 'webgl'.
var lessEqualConfig = {
  kernelFunc: lessEqual,
  backendName: 'webgl',
  kernelName: LessEqual
};
122481
122482 /**
122483 * @license
122484 * Copyright 2020 Google LLC. All Rights Reserved.
122485 * Licensed under the Apache License, Version 2.0 (the "License");
122486 * you may not use this file except in compliance with the License.
122487 * You may obtain a copy of the License at
122488 *
122489 * http://www.apache.org/licenses/LICENSE-2.0
122490 *
122491 * Unless required by applicable law or agreed to in writing, software
122492 * distributed under the License is distributed on an "AS IS" BASIS,
122493 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122494 * See the License for the specific language governing permissions and
122495 * limitations under the License.
122496 * =============================================================================
122497 */
/**
 * WebGL `LinSpace` kernel. Values are produced on the CPU by
 * `linSpaceImplCPU(start, stop, num)` and uploaded as a 1-D 'float32' tensor
 * whose length equals the number of values returned.
 *
 * @param {{backend: Object, attrs: {start: number, stop: number, num: number}}} args
 * @return {Object} TensorInfo from `backend.makeTensorInfo`.
 */
function linSpace(args) {
  var attrs = args.attrs;
  // TODO: this uses the CPU implementation because of a precision problem in
  // Safari; revisit a GPU implementation.
  var values = linSpaceImplCPU(attrs.start, attrs.stop, attrs.num);
  return args.backend.makeTensorInfo([values.length], 'float32', values);
}
// Registration record binding the kernel above to `LinSpace` on 'webgl'.
var linSpaceConfig = {
  kernelFunc: linSpace,
  backendName: 'webgl',
  kernelName: LinSpace
};
122513
122514 /**
122515 * @license
122516 * Copyright 2020 Google LLC. All Rights Reserved.
122517 * Licensed under the Apache License, Version 2.0 (the "License");
122518 * you may not use this file except in compliance with the License.
122519 * You may obtain a copy of the License at
122520 *
122521 * http://www.apache.org/licenses/LICENSE-2.0
122522 *
122523 * Unless required by applicable law or agreed to in writing, software
122524 * distributed under the License is distributed on an "AS IS" BASIS,
122525 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122526 * See the License for the specific language governing permissions and
122527 * limitations under the License.
122528 * =============================================================================
122529 */
// WebGL `Log` kernel. Per the original note, Chrome on Windows returns 0 from
// log() for negative inputs, so both snippets explicitly return NaN (0./0.)
// when x < 0.0 instead of trusting the driver. The packed snippet also passes
// NaN inputs through unchanged; the scalar snippet prepends
// CHECK_NAN_SNIPPET_UNARY, presumably for the same purpose (confirm against
// its definition). `logImplCPU` is supplied as the CPU implementation.
var LOG = CHECK_NAN_SNIPPET_UNARY + '\n return x < 0.0 ? 0./0. : log(x);\n';
var LOG_PACKED = "\n vec4 result = log(x);\n bvec4 isNaN = isnan(x);\n result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r);\n result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g);\n result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b);\n result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. : result.a);\n return result;\n";
var log = unaryKernelFunc({
  cpuKernelImpl: logImplCPU,
  packedOpSnippet: LOG_PACKED,
  opSnippet: LOG
});
// Registration record binding the kernel above to `Log` on 'webgl'.
var logConfig = {
  kernelFunc: log,
  backendName: 'webgl',
  kernelName: Log
};
122544
122545 /**
122546 * @license
122547 * Copyright 2020 Google LLC. All Rights Reserved.
122548 * Licensed under the Apache License, Version 2.0 (the "License");
122549 * you may not use this file except in compliance with the License.
122550 * You may obtain a copy of the License at
122551 *
122552 * http://www.apache.org/licenses/LICENSE-2.0
122553 *
122554 * Unless required by applicable law or agreed to in writing, software
122555 * distributed under the License is distributed on an "AS IS" BASIS,
122556 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122557 * See the License for the specific language governing permissions and
122558 * limitations under the License.
122559 * =============================================================================
122560 */
/**
 * WebGL `Log1p` kernel: evaluates `log(1.0 + x)` directly in the shader
 * (no special small-x handling), after CHECK_NAN_SNIPPET_UNARY. Only a scalar
 * snippet is provided.
 */
var LOG1P = CHECK_NAN_SNIPPET_UNARY + '\n return log(1.0 + x);\n';
var log1p = unaryKernelFunc({ opSnippet: LOG1P });
// Registration record binding the kernel above to `Log1p` on 'webgl'.
var log1pConfig = {
  kernelFunc: log1p,
  backendName: 'webgl',
  kernelName: Log1p
};
122570
122571 /**
122572 * @license
122573 * Copyright 2020 Google LLC. All Rights Reserved.
122574 * Licensed under the Apache License, Version 2.0 (the "License");
122575 * you may not use this file except in compliance with the License.
122576 * You may obtain a copy of the License at
122577 *
122578 * http://www.apache.org/licenses/LICENSE-2.0
122579 *
122580 * Unless required by applicable law or agreed to in writing, software
122581 * distributed under the License is distributed on an "AS IS" BASIS,
122582 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122583 * See the License for the specific language governing permissions and
122584 * limitations under the License.
122585 * =============================================================================
122586 */
/**
 * WebGL `LogicalAnd` kernel, built with dtype 'bool'. A value counts as true
 * when it is >= 1.0. The packed snippet multiplies the two 0/1 comparison
 * vectors component-wise, which yields 1.0 only where both are true.
 */
var LOGICAL_AND = 'return float(a >= 1.0 && b >= 1.0);';
var LOGICAL_AND_PACKED = '\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n';
var logicalAnd = binaryKernelFunc({
  dtype: 'bool',
  packedOpSnippet: LOGICAL_AND_PACKED,
  opSnippet: LOGICAL_AND
});
// Registration record binding the kernel above to `LogicalAnd` on 'webgl'.
var logicalAndConfig = {
  kernelFunc: logicalAnd,
  backendName: 'webgl',
  kernelName: LogicalAnd
};
122599
122600 /**
122601 * @license
122602 * Copyright 2020 Google LLC. All Rights Reserved.
122603 * Licensed under the Apache License, Version 2.0 (the "License");
122604 * you may not use this file except in compliance with the License.
122605 * You may obtain a copy of the License at
122606 *
122607 * http://www.apache.org/licenses/LICENSE-2.0
122608 *
122609 * Unless required by applicable law or agreed to in writing, software
122610 * distributed under the License is distributed on an "AS IS" BASIS,
122611 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122612 * See the License for the specific language governing permissions and
122613 * limitations under the License.
122614 * =============================================================================
122615 */
/**
 * WebGL `LogicalNot` kernel: yields 1.0 where x < 1.0 (i.e. "not true", with
 * true meaning >= 1.0) and 0.0 otherwise. No dtype is passed, so the output
 * dtype is whatever `unaryKernelFunc` defaults to (presumably the input's —
 * confirm).
 */
var LOGICAL_NOT = 'return float(!(x >= 1.0));';
var logicalNot = unaryKernelFunc({ opSnippet: LOGICAL_NOT });
// Registration record binding the kernel above to `LogicalNot` on 'webgl'.
var logicalNotConfig = {
  kernelFunc: logicalNot,
  backendName: 'webgl',
  kernelName: LogicalNot
};
122625
122626 /**
122627 * @license
122628 * Copyright 2020 Google LLC. All Rights Reserved.
122629 * Licensed under the Apache License, Version 2.0 (the "License");
122630 * you may not use this file except in compliance with the License.
122631 * You may obtain a copy of the License at
122632 *
122633 * http://www.apache.org/licenses/LICENSE-2.0
122634 *
122635 * Unless required by applicable law or agreed to in writing, software
122636 * distributed under the License is distributed on an "AS IS" BASIS,
122637 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122638 * See the License for the specific language governing permissions and
122639 * limitations under the License.
122640 * =============================================================================
122641 */
/**
 * WebGL `LogicalOr` kernel, built with dtype 'bool'. A value counts as true
 * when it is >= 1.0. The packed snippet adds the two 0/1 comparison vectors
 * and clamps the sum to 1.0 with `min`.
 */
var LOGICAL_OR = 'return float(a >= 1.0 || b >= 1.0);';
var LOGICAL_OR_PACKED = '\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n';
var logicalOr = binaryKernelFunc({
  dtype: 'bool',
  packedOpSnippet: LOGICAL_OR_PACKED,
  opSnippet: LOGICAL_OR
});
// Registration record binding the kernel above to `LogicalOr` on 'webgl'.
var logicalOrConfig = {
  kernelFunc: logicalOr,
  backendName: 'webgl',
  kernelName: LogicalOr
};
122654
122655 /**
122656 * @license
122657 * Copyright 2017 Google LLC. All Rights Reserved.
122658 * Licensed under the Apache License, Version 2.0 (the "License");
122659 * you may not use this file except in compliance with the License.
122660 * You may obtain a copy of the License at
122661 *
122662 * http://www.apache.org/licenses/LICENSE-2.0
122663 *
122664 * Unless required by applicable law or agreed to in writing, software
122665 * distributed under the License is distributed on an "AS IS" BASIS,
122666 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122667 * See the License for the specific language governing permissions and
122668 * limitations under the License.
122669 * =============================================================================
122670 */
/**
 * Unpacked WebGL program for local response normalization (LRN).
 *
 * For each output coordinate (b, r, c, d) the shader sums x[b, r, c, d + j]^2
 * for j in [-radius, radius], skipping channels outside [0, xShape[3] - 1],
 * and writes x * (bias + alpha * sum)^(-beta).
 *
 * @param {number[]} xShape Input shape, also used as the output shape;
 *     indexed as 4-D with xShape[3] used as the channel count.
 * @param {number} radius Half-width of the channel window.
 * @param {number} bias Additive offset inside the power term.
 * @param {number} alpha Scale applied to the sum of squares.
 * @param {number} beta Exponent; 0.5 and 1.0 get cheaper special cases.
 */
var LRNProgram = /*#__PURE__*/_createClass(function LRNProgram(xShape, radius, bias, alpha, beta) {
  _classCallCheck(this, LRNProgram);
  this.variableNames = ['x'];
  this.outputShape = [];
  var rad = radius;
  var maxD = xShape[3] - 1;
  this.outputShape = xShape;
  // optimize pow(bias + alpha * sum, -beta)
  // src: https://github.com/tensorflow/tensorflow/..
  // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
  // tensorflow/core/kernels/mkl_lrn_op.cc#L320
  var powOperator;
  var basis = "float(".concat(bias, ") + float(").concat(alpha, ") * sum");
  if (beta === 0.5) {
    powOperator = "inversesqrt(".concat(basis, ")");
  } else if (beta === 1.0) {
    powOperator = "1.0/(".concat(basis, ")");
  } else {
    // Negate beta in JS rather than emitting "-<beta>". For a negative beta the
    // latter produces "--" (the GLSL decrement operator) and fails to compile.
    // The expression must not end in ';' because it is spliced into a
    // statement below.
    powOperator = "exp(log(".concat(basis, ") * float(").concat(-beta, "))");
  }
  this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -".concat(rad, "; j <= ").concat(rad, "; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= ").concat(maxD, ") {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * ").concat(powOperator, ";\n setOutput(val);\n }\n ");
});
122693
122694 /**
122695 * @license
122696 * Copyright 2019 Google LLC. All Rights Reserved.
122697 * Licensed under the Apache License, Version 2.0 (the "License");
122698 * you may not use this file except in compliance with the License.
122699 * You may obtain a copy of the License at
122700 *
122701 * http://www.apache.org/licenses/LICENSE-2.0
122702 *
122703 * Unless required by applicable law or agreed to in writing, software
122704 * distributed under the License is distributed on an "AS IS" BASIS,
122705 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122706 * See the License for the specific language governing permissions and
122707 * limitations under the License.
122708 * =============================================================================
122709 */
/**
 * Packed WebGL program for local response normalization (LRN).
 *
 * Each invocation reads and writes a vec4. Judging by the getChannel calls in
 * the shader, it covers output positions (c, d), (c, d + 1), (c + 1, d) and
 * (c + 1, d + 1). The shader walks j in [-radius, radius] over channels
 * d + j and d + 1 + j, ignoring channels outside [0, xShape[3] - 1]. It
 * carries values between iterations in `cache` and accumulates their squares
 * into `sum`. The result is x * (bias + alpha * sum)^(-beta), component-wise.
 *
 * @param {number[]} xShape Input shape, also used as the output shape;
 *     indexed as 4-D with xShape[3] used as the channel count.
 * @param {number} radius Half-width of the channel window.
 * @param {number} bias Additive offset inside the power term.
 * @param {number} alpha Scale applied to the sum of squares.
 * @param {number} beta Exponent; 0.5 and 1.0 get cheaper special cases.
 */
var LRNPackedProgram = /*#__PURE__*/_createClass(function LRNPackedProgram(xShape, radius, bias, alpha, beta) {
  _classCallCheck(this, LRNPackedProgram);
  this.variableNames = ['x'];
  this.outputShape = [];
  this.packedInputs = true;
  this.packedOutput = true;
  var rad = radius;
  var maxD = xShape[3] - 1;
  this.outputShape = xShape;
  // optimize pow(bias + alpha * sum, -beta)
  // src: https://github.com/tensorflow/tensorflow/..
  // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
  // tensorflow/core/kernels/mkl_lrn_op.cc#L320
  var powOperator;
  var basis = "float(".concat(bias, ") + float(").concat(alpha, ") * sum");
  if (beta === 0.5) {
    powOperator = "inversesqrt(".concat(basis, ")");
  } else if (beta === 1.0) {
    powOperator = "1.0/(".concat(basis, ")");
  } else {
    // Negate beta in JS rather than emitting "-<beta>". For a negative beta the
    // latter produces "--" (the GLSL decrement operator) and fails to compile.
    // The expression must not end in ';' because it is spliced into a
    // statement below.
    powOperator = "exp(log(".concat(basis, ") * float(").concat(-beta, "))");
  }
  this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < ".concat(this.outputShape[3], ";\n bool hasNextRow = c < ").concat(this.outputShape[2], ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - ").concat(rad, ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - ").concat(rad, "; j <= ").concat(rad, "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(").concat(maxD, "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * ").concat(powOperator, ";\n setOutput(result);\n }\n ");
});
122734
122735 /**
122736 * @license
122737 * Copyright 2020 Google LLC. All Rights Reserved.
122738 * Licensed under the Apache License, Version 2.0 (the "License");
122739 * you may not use this file except in compliance with the License.
122740 * You may obtain a copy of the License at
122741 *
122742 * http://www.apache.org/licenses/LICENSE-2.0
122743 *
122744 * Unless required by applicable law or agreed to in writing, software
122745 * distributed under the License is distributed on an "AS IS" BASIS,
122746 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122747 * See the License for the specific language governing permissions and
122748 * limitations under the License.
122749 * =============================================================================
122750 */
/**
 * WebGL `LRN` kernel. Chooses LRNPackedProgram when the
 * 'WEBGL_PACK_NORMALIZATION' flag is set and LRNProgram otherwise. The chosen
 * program is built from x's shape and the depthRadius/bias/alpha/beta
 * attributes, then run on [x] with x's dtype.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: Object}} args
 * @return {Object} Result of `backend.runWebGLProgram`.
 */
var lrn = function lrn(args) {
  var x = args.inputs.x;
  var lrnAttrs = args.attrs;
  var ProgramCtor = env().getBool('WEBGL_PACK_NORMALIZATION') ? LRNPackedProgram : LRNProgram;
  var program = new ProgramCtor(x.shape, lrnAttrs.depthRadius, lrnAttrs.bias, lrnAttrs.alpha, lrnAttrs.beta);
  return args.backend.runWebGLProgram(program, [x], x.dtype);
};
// tslint:disable-next-line: variable-name
var LRNConfig = {
  kernelFunc: lrn,
  backendName: 'webgl',
  kernelName: LRN
};
122769
122770 /**
122771 * @license
122772 * Copyright 2018 Google LLC. All Rights Reserved.
122773 * Licensed under the Apache License, Version 2.0 (the "License");
122774 * you may not use this file except in compliance with the License.
122775 * You may obtain a copy of the License at
122776 *
122777 * http://www.apache.org/licenses/LICENSE-2.0
122778 *
122779 * Unless required by applicable law or agreed to in writing, software
122780 * distributed under the License is distributed on an "AS IS" BASIS,
122781 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122782 * See the License for the specific language governing permissions and
122783 * limitations under the License.
122784 * =============================================================================
122785 */
/**
 * WebGL program for the gradient of local response normalization.
 *
 * Inputs: 'inputImage' (x), 'outputImage' (y) and 'dy'. For each output
 * coordinate (b, r, c, coords[3]) the shader loops over every channel d. For
 * each d it recomputes norm = alpha * sum of x^2 over channels
 * [max(0, d - depthRadius), min(depth, d + depthRadius + 1)) plus bias, then
 * accumulates the term belonging to k == coords[3], scaled by dy[b, r, c, d].
 *
 * @param {number[]} inputShape Shape of x, also the output shape; indexed as
 *     4-D with inputShape[3] used as the channel depth.
 * @param {number} depthRadius Half-width of the channel window.
 * @param {number} bias Additive offset of the normalizer.
 * @param {number} alpha Scale applied to the sum of squares.
 * @param {number} beta Exponent of the normalizer.
 */
var LRNGradProgram = /*#__PURE__*/_createClass(function LRNGradProgram(inputShape, depthRadius, bias, alpha, beta) {
  _classCallCheck(this, LRNGradProgram);
  this.variableNames = ['inputImage', 'outputImage', 'dy'];
  this.outputShape = [];
  this.outputShape = inputShape;
  this.depth = inputShape[3];
  this.depthRadius = depthRadius;
  this.bias = bias;
  this.alpha = alpha;
  this.beta = beta;
  // Every attribute spliced into float arithmetic is wrapped in float(...).
  // GLSL ES has no implicit int->float conversion, so an integer-valued beta
  // (e.g. 1) would otherwise turn `-1.0 * 1` into a compile error.
  this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < ".concat(this.depth, "; ++d) {\n int depthBegin = int(max(0.0, float(d - ").concat(depthRadius, ")));\n int depthEnd = int(min(float(").concat(this.depth, "),\n float(d + ").concat(depthRadius, " + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = ").concat(this.depth, ";\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(").concat(alpha, ") * norm + float(").concat(bias, ");\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(").concat(alpha, ")\n * float(").concat(beta, ")\n * getInputImage(b, r, c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * float(").concat(beta, "));\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n ");
});
122798
122799 /**
122800 * @license
122801 * Copyright 2020 Google LLC. All Rights Reserved.
122802 * Licensed under the Apache License, Version 2.0 (the "License");
122803 * you may not use this file except in compliance with the License.
122804 * You may obtain a copy of the License at
122805 *
122806 * http://www.apache.org/licenses/LICENSE-2.0
122807 *
122808 * Unless required by applicable law or agreed to in writing, software
122809 * distributed under the License is distributed on an "AS IS" BASIS,
122810 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122811 * See the License for the specific language governing permissions and
122812 * limitations under the License.
122813 * =============================================================================
122814 */
/**
 * WebGL `LRNGrad` kernel. Builds an LRNGradProgram for x's shape from the
 * depthRadius/bias/alpha/beta attributes and runs it on [x, y, dy]. The result
 * has x's dtype.
 *
 * @param {{inputs: {x: Object, y: Object, dy: Object}, backend: Object, attrs: Object}} args
 * @return {Object} Result of `backend.runWebGLProgram`.
 */
var lrnGrad = function lrnGrad(args) {
  var inputs = args.inputs;
  var gradAttrs = args.attrs;
  var x = inputs.x;
  var gradProgram = new LRNGradProgram(x.shape, gradAttrs.depthRadius, gradAttrs.bias, gradAttrs.alpha, gradAttrs.beta);
  var programInputs = [x, inputs.y, inputs.dy];
  return args.backend.runWebGLProgram(gradProgram, programInputs, x.dtype);
};
// tslint:disable-next-line: variable-name
var LRNGradConfig = {
  kernelFunc: lrnGrad,
  backendName: 'webgl',
  kernelName: LRNGrad
};
122835
122836 /**
122837 * @license
122838 * Copyright 2020 Google LLC. All Rights Reserved.
122839 * Licensed under the Apache License, Version 2.0 (the "License");
122840 * you may not use this file except in compliance with the License.
122841 * You may obtain a copy of the License at
122842 *
122843 * http://www.apache.org/licenses/LICENSE-2.0
122844 *
122845 * Unless required by applicable law or agreed to in writing, software
122846 * distributed under the License is distributed on an "AS IS" BASIS,
122847 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122848 * See the License for the specific language governing permissions and
122849 * limitations under the License.
122850 * =============================================================================
122851 */
/**
 * GPU max-reduction helper.
 *
 * Reshapes x to 2-D [size(x) / size(reduceShape), size(reduceShape)], reduces
 * it with the shared `reduce(..., 'max', ...)` helper, and reshapes the result
 * to `outShape`. The two intermediate tensors are disposed before returning.
 *
 * @param {Object} x TensorInfo to reduce.
 * @param {number[]} reduceShape Shape of the reduced (trailing) dimensions.
 * @param {number[]} outShape Desired output shape.
 * @param {Object} backend WebGL backend.
 * @return {Object} TensorInfo of shape `outShape`.
 */
function maxImpl(x, reduceShape, outShape, backend) {
  var reduceSize = sizeFromShape(reduceShape);
  var rowCount = sizeFromShape(x.shape) / reduceSize;
  var as2D = reshape({
    inputs: { x: x },
    attrs: { shape: [rowCount, reduceSize] },
    backend: backend
  });
  var reducedRows = reduce(as2D, x.dtype, 'max', backend);
  var result = reshape({
    inputs: { x: reducedRows },
    attrs: { shape: outShape },
    backend: backend
  });
  // Only the temporaries are released; `result` belongs to the caller.
  backend.disposeIntermediateTensorInfo(as2D);
  backend.disposeIntermediateTensorInfo(reducedRows);
  return result;
}
122879
/**
 * WebGL `Max` kernel: max-reduces `inputs.x` over `attrs.reductionIndices`,
 * expanding the result shape when `attrs.keepDims` is truthy.
 *
 * If the reduction axes are not already innermost (per
 * `getAxesPermutation`), x is transposed first. The transpose runs on the CPU
 * when `backend.shouldExecuteOnCPU([x])` is true and uses `transposeImpl`
 * otherwise. The reduction then runs through `maxImplCPU` (CPU path, values
 * read and written via `backend.texData`) or `maxImpl` (GPU path). A
 * transposed temporary is disposed before returning.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: Object}} args
 * @return {Object} TensorInfo holding the reduced values.
 */
function max(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var attrs = args.attrs;
  var rank = x.shape.length;
  var requestedAxes = parseAxisParam(attrs.reductionIndices, x.shape);
  var permutation = getAxesPermutation(requestedAxes, rank);
  var needsTranspose = permutation != null;
  var onCPU = backend.shouldExecuteOnCPU([x]);

  // CPU-path helpers: tensor values live on the texData entry of a dataId.
  function valuesOf(info) {
    return backend.texData.get(info.dataId).values;
  }
  function tensorWithValues(shape, vals) {
    var info = backend.makeTensorInfo(shape, x.dtype);
    backend.texData.get(info.dataId).values = vals;
    return info;
  }

  var source = x;
  var reduceAxes = requestedAxes;
  if (needsTranspose) {
    if (!onCPU) {
      source = transposeImpl(x, permutation, backend);
    } else {
      var transposedShape = [];
      for (var axisIdx = 0; axisIdx < rank; axisIdx++) {
        transposedShape.push(x.shape[permutation[axisIdx]]);
      }
      source = tensorWithValues(transposedShape, transposeImplCPU(valuesOf(x), x.shape, x.dtype, permutation, transposedShape));
    }
    // After the transpose, the reduced axes are the trailing ones.
    reduceAxes = getInnerMostAxes(reduceAxes.length, rank);
  }
  assertAxesAreInnerMostDims('max', reduceAxes, rank);

  var shapes = computeOutAndReduceShapes(source.shape, reduceAxes);
  var reduceShape = shapes[1];
  // With keepDims, build the final expanded shape now rather than reshaping
  // the result afterwards.
  var outShape = attrs.keepDims ? expandShapeToKeepDim(shapes[0], requestedAxes) : shapes[0];

  var result = onCPU
    ? tensorWithValues(outShape, maxImplCPU(valuesOf(source), sizeFromShape(reduceShape), outShape, x.dtype))
    : maxImpl(source, reduceShape, outShape, backend);

  if (needsTranspose) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return result;
}
// Registration record binding the kernel above to `Max` on 'webgl'.
var maxConfig = {
  kernelFunc: max,
  backendName: 'webgl',
  kernelName: Max
};
122942
122943 /**
122944 * @license
122945 * Copyright 2020 Google LLC. All Rights Reserved.
122946 * Licensed under the Apache License, Version 2.0 (the "License");
122947 * you may not use this file except in compliance with the License.
122948 * You may obtain a copy of the License at
122949 *
122950 * http://www.apache.org/licenses/LICENSE-2.0
122951 *
122952 * Unless required by applicable law or agreed to in writing, software
122953 * distributed under the License is distributed on an "AS IS" BASIS,
122954 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122955 * See the License for the specific language governing permissions and
122956 * limitations under the License.
122957 * =============================================================================
122958 */
/**
 * WebGL `Maximum` kernel: element-wise max(a, b).
 * The scalar snippet is prefixed with CHECK_NAN_SNIPPET. The packed snippet
 * first computes an `isNaN` bvec4 as the per-component OR of isnan(a) and
 * isnan(b), then splices in CHECK_NAN_SNIPPET_PACKED (presumably consuming
 * that `isNaN` — confirm against its definition). `maximumImplCPU` is
 * supplied as the CPU implementation.
 */
var MAXIMUM = CHECK_NAN_SNIPPET + '\n return max(a, b);\n';
var MAXIMUM_PACKED = "\n vec4 result = vec4(max(a, b));\n bvec4 isNaNA = isnan(a);\n bvec4 isNaNB = isnan(b);\n bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n";
var maximum = binaryKernelFunc({
  cpuKernelImpl: maximumImplCPU,
  packedOpSnippet: MAXIMUM_PACKED,
  opSnippet: MAXIMUM
});
// Registration record binding the kernel above to `Maximum` on 'webgl'.
var maximumConfig = {
  kernelFunc: maximum,
  backendName: 'webgl',
  kernelName: Maximum$1
};
122971
122972 /**
122973 * @license
122974 * Copyright 2020 Google LLC. All Rights Reserved.
122975 * Licensed under the Apache License, Version 2.0 (the "License");
122976 * you may not use this file except in compliance with the License.
122977 * You may obtain a copy of the License at
122978 *
122979 * http://www.apache.org/licenses/LICENSE-2.0
122980 *
122981 * Unless required by applicable law or agreed to in writing, software
122982 * distributed under the License is distributed on an "AS IS" BASIS,
122983 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122984 * See the License for the specific language governing permissions and
122985 * limitations under the License.
122986 * =============================================================================
122987 */
/**
 * WebGL `MaxPool` kernel (2-D).
 *
 * Dilation is fixed at 1; the `assert$1` check then requires the strides (or
 * dilations) to be 1 per `eitherStridesOrDilationsAreOne`. When the pooling
 * window is 1x1 and leaves the shape unchanged, x is returned through
 * `identity`. Otherwise a Pool2DProgram in 'max' mode is run on [x].
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: Object}} args
 * @return {Object} TensorInfo of the pooled result.
 */
function maxPool(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  assertNotComplex(x, 'maxPool');
  var poolAttrs = args.attrs;
  var strides = poolAttrs.strides;
  var dilations = 1;
  assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
    return 'Error in maxPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'");
  });
  var convInfo = computePool2DInfo(x.shape, poolAttrs.filterSize, strides, dilations, poolAttrs.pad, poolAttrs.dimRoundingMode);
  var isUnitWindow = convInfo.filterWidth === 1 && convInfo.filterHeight === 1;
  if (isUnitWindow && arraysEqual(convInfo.inShape, convInfo.outShape)) {
    // A 1x1 window that preserves the shape is a no-op.
    return identity({ inputs: { x: x }, backend: backend });
  }
  return backend.runWebGLProgram(new Pool2DProgram(convInfo, 'max', false), [x], x.dtype);
}
// Registration record binding the kernel above to `MaxPool` on 'webgl'.
var maxPoolConfig = {
  kernelFunc: maxPool,
  backendName: 'webgl',
  kernelName: MaxPool
};
123019
123020 /**
123021 * @license
123022 * Copyright 2020 Google LLC. All Rights Reserved.
123023 * Licensed under the Apache License, Version 2.0 (the "License");
123024 * you may not use this file except in compliance with the License.
123025 * You may obtain a copy of the License at
123026 *
123027 * http://www.apache.org/licenses/LICENSE-2.0
123028 *
123029 * Unless required by applicable law or agreed to in writing, software
123030 * distributed under the License is distributed on an "AS IS" BASIS,
123031 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123032 * See the License for the specific language governing permissions and
123033 * limitations under the License.
123034 * =============================================================================
123035 */
/**
 * WebGL `MaxPool3D` kernel. Computes the 3-D pooling geometry with dilations
 * fixed at [1, 1, 1], then runs a Pool3DProgram in 'max' mode on [x] with x's
 * dtype.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: Object}} args
 * @return {Object} TensorInfo of the pooled result.
 */
function maxPool3d(args) {
  var x = args.inputs.x;
  var poolAttrs = args.attrs;
  var unitDilations = [1, 1, 1];
  var convInfo = computePool3DInfo(x.shape, poolAttrs.filterSize, poolAttrs.strides, unitDilations, poolAttrs.pad, poolAttrs.dimRoundingMode, poolAttrs.dataFormat);
  var program = new Pool3DProgram(convInfo, 'max', false);
  return args.backend.runWebGLProgram(program, [x], x.dtype);
}
// Registration record binding the kernel above to `MaxPool3D` on 'webgl'.
var maxPool3DConfig = {
  kernelFunc: maxPool3d,
  backendName: 'webgl',
  kernelName: MaxPool3D
};
123056
123057 /**
123058 * @license
123059 * Copyright 2017 Google LLC. All Rights Reserved.
123060 * Licensed under the Apache License, Version 2.0 (the "License");
123061 * you may not use this file except in compliance with the License.
123062 * You may obtain a copy of the License at
123063 *
123064 * http://www.apache.org/licenses/LICENSE-2.0
123065 *
123066 * Unless required by applicable law or agreed to in writing, software
123067 * distributed under the License is distributed on an "AS IS" BASIS,
123068 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123069 * See the License for the specific language governing permissions and
123070 * limitations under the License.
123071 * =============================================================================
123072 */
/**
 * WebGL program for the gradient of 2-D max pooling.
 *
 * Inputs: 'dy' (upstream gradient) and 'maxPos' (per-dy-position index of the
 * winning window offset, presumably produced by the forward pooling pass with
 * position tracking — confirm). The output has convInfo.inShape.
 *
 * For each input position (b, xR, xC, d) the shader walks the window offsets
 * (wR, wC), stepping by the dilation in both directions. It maps each offset
 * back to a dy coordinate, skipping non-integral or out-of-range ones, and
 * adds dy there when `lastIndex - maxPos` equals the offset's flat index
 * wR * effectiveFilterWidth + wC.
 *
 * @param {Object} convInfo Pooling geometry (strides, dilations, effective
 *     filter sizes, padInfo, in/out shapes).
 */
var MaxPool2DBackpropProgram = /*#__PURE__*/_createClass(function MaxPool2DBackpropProgram(convInfo) {
  _classCallCheck(this, MaxPool2DBackpropProgram);
  this.variableNames = ['dy', 'maxPos'];
  this.outputShape = convInfo.inShape;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight;
  var effectiveFilterWidth = convInfo.effectiveFilterWidth;
  var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
  var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
  var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
  // The column loop steps by dilationWidth, matching the row loop and the 3-D
  // backprop program. Only window taps (multiples of the dilation) can equal
  // a recorded max position, so skipping the in-between columns changes no
  // result; it only avoids dead iterations.
  this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = ").concat(lastIndex, " - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * ").concat(effectiveFilterWidth, " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n ");
});
/**
 * WebGL program computing the gradient of 3D max pooling with respect to the
 * pooled input.
 *
 * Shader inputs (`variableNames`): `dy`, the upstream gradient, and `maxPos`,
 * the argmax position inside each pooling window (in this file
 * `maxPool3DGrad` produces it with `Pool3DProgram(convInfo, 'max', true)`).
 * The output has the shape of the pooled input (`convInfo.inShape`).
 *
 * For every input element the shader walks every `dy` element whose window
 * could contain it, and sums `dy` wherever the stored argmax equals this
 * element's offset inside that window.
 *
 * @param {Object} convInfo Pooling geometry (strides, dilations, effective
 *     filter sizes, padding, output sizes), presumably from `computePool3DInfo`.
 */
var MaxPool3DBackpropProgram = /*#__PURE__*/_createClass(function MaxPool3DBackpropProgram(convInfo) {
  _classCallCheck(this, MaxPool3DBackpropProgram);
  this.variableNames = ['dy', 'maxPos'];
  this.outputShape = convInfo.inShape;
  var strideDepth = convInfo.strideDepth;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var dilationDepth = convInfo.dilationDepth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  var effectiveFilterDepth = convInfo.effectiveFilterDepth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight;
  var effectiveFilterWidth = convInfo.effectiveFilterWidth;
  // Padding for the transposed traversal: (effective filter size - 1) minus
  // the forward pass's leading padding on each axis.
  var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
  var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
  var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
  // Largest flattened in-window offset. The shader mirrors the stored argmax
  // as (lastIndex - maxPos) because it walks each window in flipped order.
  var lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
  // GLSL source. `ivec5` and its `.u` component are presumably supplied by the
  // shared shader preamble for rank-5 outputs; the dy index along each axis is
  // (corner + w) / stride and is skipped when out of range or not integral.
  this.userCode = "\n const ivec3 pads = ivec3(".concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n float dyD = float(dyDCorner + wD) / ").concat(strideDepth, ".0;\n\n if (dyD < 0.0 || dyD >= ").concat(convInfo.outDepth, ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = ").concat(lastIndex, " -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * ").concat(effectiveFilterHeight, " * ").concat(effectiveFilterWidth, " +\n wR * ").concat(effectiveFilterWidth, " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n ");
});
123106
123107 /**
123108 * @license
123109 * Copyright 2020 Google LLC. All Rights Reserved.
123110 * Licensed under the Apache License, Version 2.0 (the "License");
123111 * you may not use this file except in compliance with the License.
123112 * You may obtain a copy of the License at
123113 *
123114 * http://www.apache.org/licenses/LICENSE-2.0
123115 *
123116 * Unless required by applicable law or agreed to in writing, software
123117 * distributed under the License is distributed on an "AS IS" BASIS,
123118 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123119 * See the License for the specific language governing permissions and
123120 * limitations under the License.
123121 * =============================================================================
123122 */
/**
 * WebGL kernel for `MaxPool3DGrad`: the gradient of 3D max pooling with
 * respect to its input.
 *
 * Two GPU passes are used. The first re-runs the forward max pool in
 * "positions" mode to find the argmax of every window. The second routes each
 * upstream gradient value back to the input element that won its window. The
 * positions tensor is released before returning.
 *
 * @param {{inputs: {dy: Object, input: Object}, backend: Object,
 *     attrs: {filterSize: *, strides: *, pad: *, dimRoundingMode: *}}} args
 * @returns {Object} TensorInfo shaped like `inputs.input`, with its dtype.
 */
function maxPool3DGrad(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var dy = args.inputs.dy;
  var x = args.inputs.input;
  // The gradient is always computed with unit dilation on all three axes.
  var convInfo = computePool3DInfo(x.shape, attrs.filterSize, attrs.strides, [1, 1, 1], attrs.pad, attrs.dimRoundingMode);
  // Pass 1: argmax position inside each pooling window.
  var argmaxPositions = backend.runWebGLProgram(new Pool3DProgram(convInfo, 'max', true /* get positions */), [x], x.dtype);
  // Pass 2: scatter dy back onto the winning input elements.
  var dx = backend.runWebGLProgram(new MaxPool3DBackpropProgram(convInfo), [dy, argmaxPositions], x.dtype);
  backend.disposeIntermediateTensorInfo(argmaxPositions);
  return dx;
}
var maxPool3DGradConfig = {
  kernelName: MaxPool3DGrad,
  backendName: 'webgl',
  kernelFunc: maxPool3DGrad
};
123148
123149 /**
123150 * @license
123151 * Copyright 2020 Google LLC. All Rights Reserved.
123152 * Licensed under the Apache License, Version 2.0 (the "License");
123153 * you may not use this file except in compliance with the License.
123154 * You may obtain a copy of the License at
123155 *
123156 * http://www.apache.org/licenses/LICENSE-2.0
123157 *
123158 * Unless required by applicable law or agreed to in writing, software
123159 * distributed under the License is distributed on an "AS IS" BASIS,
123160 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123161 * See the License for the specific language governing permissions and
123162 * limitations under the License.
123163 * =============================================================================
123164 */
/**
 * WebGL kernel for `MaxPoolGrad`: the gradient of 2D max pooling with respect
 * to its input.
 *
 * Complex `input`/`output` tensors are rejected up front. The forward max pool
 * is then re-run in "positions" mode to find each window's argmax, and a
 * backprop program routes `dy` onto those winners. The positions tensor is
 * released before returning.
 *
 * @param {{inputs: {dy: Object, input: Object, output: Object},
 *     backend: Object,
 *     attrs: {filterSize: *, strides: *, pad: *, dimRoundingMode: *}}} args
 * @returns {Object} TensorInfo shaped like `inputs.input`, with its dtype.
 */
function maxPoolGrad(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var dy = args.inputs.dy;
  var x = args.inputs.input;
  var output = args.inputs.output;
  assertNotComplex([x, output], 'maxPoolGrad');
  var convInfo = computePool2DInfo(x.shape, attrs.filterSize, attrs.strides, 1 /* dilations */, attrs.pad, attrs.dimRoundingMode);
  // Pass 1: argmax position inside each pooling window.
  var argmaxPositions = backend.runWebGLProgram(new Pool2DProgram(convInfo, 'max', true /* getPositions */), [x], x.dtype);
  // Pass 2: scatter dy back onto the winning input elements.
  var dx = backend.runWebGLProgram(new MaxPool2DBackpropProgram(convInfo), [dy, argmaxPositions], x.dtype);
  backend.disposeIntermediateTensorInfo(argmaxPositions);
  return dx;
}
var maxPoolGradConfig = {
  kernelName: MaxPoolGrad,
  backendName: 'webgl',
  kernelFunc: maxPoolGrad
};
123192
123193 /**
123194 * @license
123195 * Copyright 2020 Google LLC. All Rights Reserved.
123196 * Licensed under the Apache License, Version 2.0 (the "License");
123197 * you may not use this file except in compliance with the License.
123198 * You may obtain a copy of the License at
123199 *
123200 * http://www.apache.org/licenses/LICENSE-2.0
123201 *
123202 * Unless required by applicable law or agreed to in writing, software
123203 * distributed under the License is distributed on an "AS IS" BASIS,
123204 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123205 * See the License for the specific language governing permissions and
123206 * limitations under the License.
123207 * =============================================================================
123208 */
/**
 * Runs 2D max pooling twice on the GPU: once for the pooled maxima and once
 * for their argmax indices.
 *
 * Both outputs are produced as 'float32'. `includeBatchInIndex` is forwarded
 * to the index-producing `Pool2DProgram`.
 *
 * @returns {Array<Object>} `[pooledValues, argmaxIndices]` TensorInfos.
 */
function maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend) {
  var maxima = backend.runWebGLProgram(new Pool2DProgram(convInfo, 'max', false), [x], 'float32');
  var argmax = backend.runWebGLProgram(new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex), [x], 'float32');
  return [maxima, argmax];
}
123216
var maxPoolWithArgmaxConfig = {
  kernelName: MaxPoolWithArgmax,
  backendName: 'webgl',
  /**
   * WebGL kernel for `MaxPoolWithArgmax`.
   *
   * It requires a rank-4 input and always uses unit dilation, so the stride
   * check below reduces to validating that combination. It then delegates to
   * `maxPoolWithArgmaxImpl`.
   *
   * @returns {Array<Object>} `[pooledValues, argmaxIndices]` TensorInfos.
   */
  kernelFunc: function kernelFunc(_ref) {
    var x = _ref.inputs.x;
    var attrs = _ref.attrs;
    var rank = x.shape.length;
    assert$1(rank === 4, function () {
      return 'Error in maxPool: input must be rank 4 but got rank ' + rank + '.';
    });
    var strides = attrs.strides;
    var dilations = [1, 1];
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), function () {
      return 'Error in maxPool: Either strides or dilations must be 1. ' + 'Got strides ' + strides + " and dilations '" + dilations + "'";
    });
    var convInfo = computePool2DInfo(x.shape, attrs.filterSize, strides, dilations, attrs.pad);
    var pair = maxPoolWithArgmaxImpl(x, attrs.includeBatchInIndex, convInfo, _ref.backend);
    // Hand back a fresh two-element array, as the destructure-and-rebuild did.
    return [pair[0], pair[1]];
  }
};
123245
123246 /**
123247 * @license
123248 * Copyright 2020 Google LLC. All Rights Reserved.
123249 * Licensed under the Apache License, Version 2.0 (the "License");
123250 * you may not use this file except in compliance with the License.
123251 * You may obtain a copy of the License at
123252 *
123253 * http://www.apache.org/licenses/LICENSE-2.0
123254 *
123255 * Unless required by applicable law or agreed to in writing, software
123256 * distributed under the License is distributed on an "AS IS" BASIS,
123257 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123258 * See the License for the specific language governing permissions and
123259 * limitations under the License.
123260 * =============================================================================
123261 */
/**
 * Averages `x` over its trailing `reduceShape` dimensions.
 *
 * `x` is viewed as a 2D matrix `[rows, reduceSize]`, where `reduceSize` is the
 * element count of `reduceShape`. It is row-reduced with the shared 'mean'
 * reduction into 'float32' and reshaped to `outShape`. Both temporaries are
 * disposed.
 *
 * @param {Object} x Input TensorInfo whose innermost dims match `reduceShape`.
 * @param {number[]} reduceShape Shape of the dimensions being averaged.
 * @param {number[]} outShape Final output shape.
 * @param {Object} backend WebGL backend.
 * @returns {Object} TensorInfo of shape `outShape`.
 */
function meanImpl(x, reduceShape, outShape, backend) {
  var reduceSize = sizeFromShape(reduceShape);
  var rows = sizeFromShape(x.shape) / reduceSize;
  var matrix = reshape({ inputs: { x: x }, attrs: { shape: [rows, reduceSize] }, backend: backend });
  var rowMeans = reduce(matrix, 'float32', 'mean', backend);
  var result = reshape({ inputs: { x: rowMeans }, attrs: { shape: outShape }, backend: backend });
  [matrix, rowMeans].forEach(function (temp) {
    backend.disposeIntermediateTensorInfo(temp);
  });
  return result;
}
123289
var meanConfig = {
  kernelName: Mean,
  backendName: 'webgl',
  /**
   * WebGL kernel for `Mean`.
   *
   * Averages `x` over `attrs.axis`, as parsed by `parseAxisParam`. When the
   * reduction axes are not already the innermost dimensions, `x` is first
   * transposed so that they are. If `shouldExecuteOnCPU([x])` is true, this
   * uses `transposeImplCPU` on the texture's CPU-side values; otherwise it
   * uses `transposeImpl`. The transposed copy is disposed before returning.
   * With `attrs.keepDims`, the reduced axes are kept as size-1 dimensions.
   *
   * @param {{inputs: {x: Object}, attrs: {axis: *, keepDims: boolean},
   *     backend: Object}} _ref
   * @returns {Object} TensorInfo produced by `meanImpl`.
   */
  kernelFunc: function kernelFunc(_ref) {
    var inputs = _ref.inputs,
      attrs = _ref.attrs,
      backend = _ref.backend;
    var x = inputs.x;
    var keepDims = attrs.keepDims,
      axis = attrs.axis;
    var webglBackend = backend;
    var xRank = x.shape.length;
    var origAxes = parseAxisParam(axis, x.shape);
    var axes = origAxes;
    // Non-null when the reduction axes must first be moved to the end.
    var permutedAxes = getAxesPermutation(axes, xRank);
    var meanInputIsTransposed = permutedAxes != null;
    var shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
    // Tensors created here that must be released before returning.
    var intermediates = [];
    var meanInput = x;
    if (meanInputIsTransposed) {
      if (shouldExecuteOnCPU) {
        // Transpose the CPU-side values directly and wrap them in a new
        // tensor instead of running a GPU transpose.
        var xTexData = webglBackend.texData.get(meanInput.dataId);
        var values = xTexData.values;
        var newShape = new Array(xRank);
        for (var i = 0; i < newShape.length; i++) {
          newShape[i] = x.shape[permutedAxes[i]];
        }
        var meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
        meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
        var meanInputData = webglBackend.texData.get(meanInput.dataId);
        meanInputData.values = meanInputValues;
      } else {
        meanInput = transposeImpl(x, permutedAxes, webglBackend);
      }
      intermediates.push(meanInput);
      // After the transpose the reduction axes are the trailing ones.
      axes = getInnerMostAxes(axes.length, xRank);
    }
    // Report axis-layout failures under this kernel's own op name.
    assertAxesAreInnerMostDims('mean', axes, xRank);
    var _backend_util$compute = computeOutAndReduceShapes(meanInput.shape, axes),
      _backend_util$compute2 = _slicedToArray(_backend_util$compute, 2),
      meanOutShape = _backend_util$compute2[0],
      reduceShape = _backend_util$compute2[1];
    var outShape = meanOutShape;
    if (keepDims) {
      // rather than reshape at the end, set the target shape here.
      outShape = expandShapeToKeepDim(meanOutShape, origAxes);
    }
    var out = meanImpl(meanInput, reduceShape, outShape, webglBackend);
    for (var _i = 0; _i < intermediates.length; _i++) {
      webglBackend.disposeIntermediateTensorInfo(intermediates[_i]);
    }
    return out;
  }
};
123345
/**
 * WebGL kernel for `Min`: the minimum of `x` over `attrs.axis`.
 *
 * If the reduction axes are not innermost, `x` is transposed so they are. The
 * tensor is then viewed as `[-1, reduceSize]`, row-reduced with 'min' in the
 * input's dtype, and reshaped to the output shape. With `attrs.keepDims`, the
 * reduced axes are kept as size-1 dims. All temporaries are disposed.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {axis: *, keepDims: boolean}}} args
 * @returns {Object} Reduced TensorInfo.
 */
function min(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var keepDims = args.attrs.keepDims;
  var xRank = x.shape.length;
  var origAxes = parseAxisParam(args.attrs.axis, x.shape);
  var axes = origAxes;
  var perm = getAxesPermutation(axes, xRank);
  var needsTranspose = perm != null;
  var source = x;
  if (needsTranspose) {
    source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
    axes = getInnerMostAxes(axes.length, xRank);
  }
  assertAxesAreInnerMostDims('min', axes, xRank);
  var shapes = computeOutAndReduceShapes(source.shape, axes);
  var outShape = shapes[0];
  var reduceShape = shapes[1];
  var matrix = reshape({ inputs: { x: source }, backend: backend, attrs: { shape: [-1, sizeFromShape(reduceShape)] } });
  var reduced = reduce(matrix, matrix.dtype, 'min', backend);
  // One final reshape either way; only the target shape depends on keepDims.
  var targetShape = keepDims ? expandShapeToKeepDim(outShape, origAxes) : outShape;
  var res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: targetShape } });
  backend.disposeIntermediateTensorInfo(matrix);
  backend.disposeIntermediateTensorInfo(reduced);
  if (needsTranspose) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return res;
}
var minConfig = {
  kernelName: Min,
  backendName: 'webgl',
  kernelFunc: min
};
123421
123422 /**
123423 * @license
123424 * Copyright 2020 Google LLC. All Rights Reserved.
123425 * Licensed under the Apache License, Version 2.0 (the "License");
123426 * you may not use this file except in compliance with the License.
123427 * You may obtain a copy of the License at
123428 *
123429 * http://www.apache.org/licenses/LICENSE-2.0
123430 *
123431 * Unless required by applicable law or agreed to in writing, software
123432 * distributed under the License is distributed on an "AS IS" BASIS,
123433 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123434 * See the License for the specific language governing permissions and
123435 * limitations under the License.
123436 * =============================================================================
123437 */
// GLSL body for elementwise `minimum(a, b)` on the unpacked (scalar) path.
// CHECK_NAN_SNIPPET (defined elsewhere) is prepended; presumably it returns
// NaN when either operand is NaN, since GLSL `min` does not propagate NaN
// reliably.
var MINIMUM = CHECK_NAN_SNIPPET + "\n return min(a, b);\n";
// Packed (vec4) variant: computes the lane-wise min, then builds a per-lane
// `isNaN` mask (NaN in either a or b). CHECK_NAN_SNIPPET_PACKED (defined
// elsewhere) presumably writes NaN into `result` for the flagged lanes.
var MINIMUM_PACKED = "\n vec4 result = vec4(min(a, b));\n bvec4 isNaNA = isnan(a);\n bvec4 isNaNB = isnan(b);\n bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n";
// Kernel function built from the snippets above. minimumImplCPU is passed as
// the CPU implementation for cases where binaryKernelFunc chooses not to run
// the shader.
var minimum = binaryKernelFunc({
  opSnippet: MINIMUM,
  packedOpSnippet: MINIMUM_PACKED,
  cpuKernelImpl: minimumImplCPU
});
// Registers `minimum` for the Minimum op on the 'webgl' backend.
var minimumConfig = {
  kernelName: Minimum$1,
  backendName: 'webgl',
  kernelFunc: minimum
};
123450
/**
 * Unpacked WebGL program for `mirrorPad`.
 *
 * The output size along axis i is `paddings[i][0] + xShape[i] +
 * paddings[i][1]`. An output coordinate inside `[start, end)` (the unpadded
 * region) maps straight to `x`. A coordinate outside that range is folded back
 * across the nearest edge:
 *   - `mode === 'reflect'` (offset 0): the edge element is not repeated, e.g.
 *     start-1 maps to start+1.
 *   - any other mode (offset 1, presumably 'symmetric'): the edge element is
 *     repeated, e.g. start-1 maps to start.
 *
 * NOTE(review): `unpackedCoords` only names coords[0..3], so this program
 * appears to support ranks 1-4 only. Confirm callers never pass higher ranks.
 *
 * @param {number[]} xShape Input shape.
 * @param {Array<[number, number]>} paddings `[before, after]` per axis.
 * @param {string} mode 'reflect' or the alternative mirror mode.
 */
var MirrorPadProgram = /*#__PURE__*/_createClass(function MirrorPadProgram(xShape, paddings, mode) {
  _classCallCheck(this, MirrorPadProgram);
  this.variableNames = ['x'];
  this.outputShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + xShape[i] + p[1];
  } /* afterPad */);
  var rank = xShape.length;
  var dtype = getCoordsDataType(rank);
  // Comma-joined per-axis bounds of the unpadded region: [start, end).
  var start = paddings.map(function (p) {
    return p[0];
  }).join(',');
  var end = paddings.map(function (p, i) {
    return p[0] + xShape[i];
  }).join(',');
  // GLSL argument list passed to getX(); stringified with commas by concat().
  var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
  var offset = mode === 'reflect' ? 0 : 1;
  if (rank === 1) {
    // Scalar-int coordinates: no per-axis loop is needed.
    this.userCode = "\n int start = ".concat(start, ";\n int end = ").concat(end, ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start) {\n outC = start * 2 - outC - ").concat(offset, ";\n } else if(outC >= end) {\n outC = (end - 1) * 2 - outC + ").concat(offset, ";\n }\n setOutput(getX(outC - start));\n }\n ");
    return;
  }
  // Vector coordinates: fold each axis independently, then shift by start.
  this.userCode = "\n ".concat(dtype, " start = ").concat(dtype, "(").concat(start, ");\n ").concat(dtype, " end = ").concat(dtype, "(").concat(end, ");\n\n void main() {\n ").concat(dtype, " outC = getOutputCoords();\n for (int i = 0; i < ").concat(rank, "; i++) {\n if (outC[i] < start[i]) {\n outC[i] = start[i] * 2 - outC[i] - ").concat(offset, ";\n } else if(outC[i] >= end[i]) {\n outC[i] = (end[i] - 1) * 2 - outC[i] + ").concat(offset, ";\n }\n }\n ").concat(dtype, " coords = outC - start;\n setOutput(getX(").concat(unpackedCoords, "));\n }\n ");
});
123473
123474 /**
123475 * Example shader code for
123476 * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
123477 * ```
123478 * const int start = int(2);
123479 * const int end = int(5);
123480 *
123481 * void main() {
123482 * int outputLoc = getOutputCoords();
123483 * vec4 result = vec4(0.);
123484 *
123485 * int rc = outputLoc;
123486 *
123487 * int source = rc;
123488 * if (source < start) {
123489 * source = start * 2 - source - 0;
123490 * } else if (source >= end) {
123491 * source = (end - 1) * 2 - source + 0;
123492 * }
123493 * source -= start;
123494 *
123495 * result[0] = getChannel(getX(source), source);
123496 * rc += 1;
 * if(rc < 7) {
123498 * int source = rc;
123499 * if (source < start) {
123500 * source = start * 2 - source - 0;
123501 * } else if (source >= end) {
123502 * source = (end - 1) * 2 - source + 0;
123503 * }
123504 * source -= start;
123505 *
123506 * result[1] = getChannel(getX(source), source);
123507 * }
123508 *
123509 * setOutput(result);
123510 * }
123511 * ```
123512 */
/**
 * Packed (vec4 texel) WebGL program for `mirrorPad`, using the same folding
 * rule and `mode` offset as the unpacked MirrorPadProgram.
 *
 * Each invocation fills up to four output elements of one texel:
 * result[0] at outputLoc, result[1] one step along the last axis, result[2]
 * one step along the second-to-last axis, and result[3] diagonally. Steps
 * that fall past outputShape are skipped and stay 0. Rank 1 only fills
 * result[0] and result[1].
 *
 * @param {number[]} xShape Input shape.
 * @param {Array<[number, number]>} paddings `[before, after]` per axis.
 * @param {string} mode 'reflect' or the alternative mirror mode.
 */
var MirrorPadPackedProgram = /*#__PURE__*/_createClass(function MirrorPadPackedProgram(xShape, paddings, mode) {
  _classCallCheck(this, MirrorPadPackedProgram);
  this.variableNames = ['x'];
  this.packedInputs = true;
  this.packedOutput = true;
  this.outputShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + xShape[i] + p[1];
  } /* afterPad */);
  var rank = xShape.length;
  var dtype = getCoordsDataType(rank);
  // Comma-joined per-axis bounds of the unpadded region: [start, end).
  var start = paddings.map(function (p) {
    return p[0];
  }).join(',');
  var end = paddings.map(function (p, i) {
    return p[0] + xShape[i];
  }).join(',');
  // GLSL component names for the output ('rc') and source coordinates.
  var coords = getChannels('rc', rank);
  var source = getChannels('source', rank);
  // Guard for stepping one element along the last axis.
  var cLimit = "".concat(coords[rank - 1], " < ").concat(this.outputShape[rank - 1]);
  // Second argument to getChannel(): the innermost one or two source coords.
  var innerDims = rank === 1 ? 'source' : "vec2(".concat(source.slice(-2).join(), ")");
  var offset = mode === 'reflect' ? 0 : 1;
  var mainLoop = '';
  if (rank === 1) {
    // Scalar fold, identical in form to the unpacked rank-1 shader.
    var padSetup = "\n ".concat(dtype, " source = rc;\n if (source < start) {\n source = start * 2 - source - ").concat(offset, ";\n } else if (source >= end) {\n source = (end - 1) * 2 - source + ").concat(offset, ";\n }\n source -= start;\n ");
    mainLoop = "\n ".concat(dtype, " rc = outputLoc;\n ").concat(padSetup, "\n result[0] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n ").concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n ").concat(padSetup, "\n result[1] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n ");
  } else {
    // Branch-free vector fold: lt/gte are 0/1 masks selecting which formula
    // applies on each axis, and orig selects the unchanged coordinate.
    var _padSetup = "\n ".concat(dtype, " source = rc;\n ").concat(dtype, " lt = ").concat(dtype, "(lessThan(source, start));\n ").concat(dtype, " gte = ").concat(dtype, "(greaterThanEqual(source, end));\n ").concat(dtype, " orig = 1 - (lt + gte);\n source = orig * source +\n lt * (start * 2 - source - ").concat(offset, ") +\n gte * ((end - 1) * 2 - source + ").concat(offset, ");\n source -= start;\n ");
    mainLoop = "\n ".concat(dtype, " rc = outputLoc;\n ").concat(_padSetup, "\n result[0] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n ").concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n ").concat(_padSetup, "\n result[1] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n rc = outputLoc;\n ").concat(coords[rank - 2], " += 1;\n if(").concat(coords[rank - 2], " < ").concat(this.outputShape[rank - 2], ") {\n ").concat(_padSetup, "\n result[2] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n ").concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n ").concat(_padSetup, "\n result[3] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n }\n ");
  }
  this.userCode = "\n const ".concat(dtype, " start = ").concat(dtype, "(").concat(start, ");\n const ").concat(dtype, " end = ").concat(dtype, "(").concat(end, ");\n\n void main() {\n ").concat(dtype, " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n ").concat(mainLoop, "\n setOutput(result);\n }\n ");
});
123544
123545 /**
123546 * @license
123547 * Copyright 2020 Google LLC. All Rights Reserved.
123548 * Licensed under the Apache License, Version 2.0 (the "License");
123549 * you may not use this file except in compliance with the License.
123550 * You may obtain a copy of the License at
123551 *
123552 * http://www.apache.org/licenses/LICENSE-2.0
123553 *
123554 * Unless required by applicable law or agreed to in writing, software
123555 * distributed under the License is distributed on an "AS IS" BASIS,
123556 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123557 * See the License for the specific language governing permissions and
123558 * limitations under the License.
123559 * =============================================================================
123560 */
/**
 * WebGL kernel for `MirrorPad`.
 *
 * Picks the packed shader when the 'WEBGL_PACK_ARRAY_OPERATIONS' flag is on,
 * otherwise the unpacked one. The output keeps `x`'s dtype.
 *
 * @param {{inputs: {x: Object}, backend: Object,
 *     attrs: {paddings: Array<[number, number]>, mode: string}}} _ref
 * @returns {Object} Padded TensorInfo.
 */
var mirrorPadKernelFunc = function mirrorPadKernelFunc(_ref) {
  var x = _ref.inputs.x;
  var paddings = _ref.attrs.paddings;
  var mode = _ref.attrs.mode;
  var ProgramClass = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? MirrorPadPackedProgram : MirrorPadProgram;
  return _ref.backend.runWebGLProgram(new ProgramClass(x.shape, paddings, mode), [x], x.dtype);
};
var mirrorPadConfig = {
  kernelName: MirrorPad,
  backendName: 'webgl',
  kernelFunc: mirrorPadKernelFunc
};
123577
123578 /**
123579 * @license
123580 * Copyright 2020 Google LLC. All Rights Reserved.
123581 * Licensed under the Apache License, Version 2.0 (the "License");
123582 * you may not use this file except in compliance with the License.
123583 * You may obtain a copy of the License at
123584 *
123585 * http://www.apache.org/licenses/LICENSE-2.0
123586 *
123587 * Unless required by applicable law or agreed to in writing, software
123588 * distributed under the License is distributed on an "AS IS" BASIS,
123589 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123590 * See the License for the specific language governing permissions and
123591 * limitations under the License.
123592 * =============================================================================
123593 */
// GLSL body for elementwise `mod(a, b)` on the unpacked path. A zero divisor
// yields NaN; otherwise GLSL's built-in `mod` is used, which the GLSL spec
// defines as a - b * floor(a / b) (the result takes the sign of b).
var MOD = "if (b == 0.0) return NAN;\n return mod(a, b);";
// Packed (vec4) variant: flags lanes whose divisor is 0 in `isNaN`.
// CHECK_NAN_SNIPPET_PACKED (defined elsewhere) presumably writes NaN into
// those lanes of `result`.
var MOD_PACKED = "\n vec4 result = mod(a, b);\n bvec4 isNaN = equal(b, vec4(0.0));\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n";
// Kernel function built from the snippets above. Unlike some sibling binary
// ops, no cpuKernelImpl is supplied here.
var mod = binaryKernelFunc({
  opSnippet: MOD,
  packedOpSnippet: MOD_PACKED
});
// Registers `mod` for the Mod op on the 'webgl' backend.
var modConfig = {
  kernelName: Mod,
  backendName: 'webgl',
  kernelFunc: mod
};
123605
123606 /**
123607 * @license
123608 * Copyright 2017 Google LLC. All Rights Reserved.
123609 * Licensed under the Apache License, Version 2.0 (the "License");
123610 * you may not use this file except in compliance with the License.
123611 * You may obtain a copy of the License at
123612 *
123613 * http://www.apache.org/licenses/LICENSE-2.0
123614 *
123615 * Unless required by applicable law or agreed to in writing, software
123616 * distributed under the License is distributed on an "AS IS" BASIS,
123617 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123618 * See the License for the specific language governing permissions and
123619 * limitations under the License.
123620 * =============================================================================
123621 */
/**
 * WebGL program that draws `numSamples` categorical samples per batch row.
 *
 * Output shape is `[batchSize, numSamples]`, and each element holds a sampled
 * outcome index (as a float). Each output element draws `r = random(seed)` and
 * walks the running sum (CDF) of `probs[batch, i]`. It emits the first `i`
 * with `r < cdf`, or falls back to the last outcome `numOutcomes - 1`.
 *
 * Assumes each row of `probs` holds normalized probabilities. The visible
 * caller `multinomial` applies softmax first when `normalized` is false.
 * `random` is a shader helper defined elsewhere, presumably returning a
 * pseudo-random value in [0, 1) derived from `seed` and the output
 * coordinates. TODO confirm.
 *
 * @param {number} batchSize Number of probability rows.
 * @param {number} numOutcomes Number of categories per row.
 * @param {number} numSamples Samples to draw per row.
 */
var MultinomialProgram = /*#__PURE__*/_createClass(function MultinomialProgram(batchSize, numOutcomes, numSamples) {
  _classCallCheck(this, MultinomialProgram);
  this.variableNames = ['probs'];
  // `seed` is supplied as a float uniform when the program is run.
  this.customUniforms = [{
    name: 'seed',
    type: 'float'
  }];
  this.outputShape = [batchSize, numSamples];
  this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n float r = random(seed);\n float cdf = 0.0;\n\n for (int i = 0; i < ".concat(numOutcomes - 1, "; i++) {\n cdf += getProbs(batch, i);\n\n if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n // If no other event happened, last event happened.\n setOutput(float(").concat(numOutcomes - 1, "));\n }\n ");
});
123632
123633 /**
123634 * @license
123635 * Copyright 2020 Google LLC. All Rights Reserved.
123636 * Licensed under the Apache License, Version 2.0 (the "License");
123637 * you may not use this file except in compliance with the License.
123638 * You may obtain a copy of the License at
123639 *
123640 * http://www.apache.org/licenses/LICENSE-2.0
123641 *
123642 * Unless required by applicable law or agreed to in writing, software
123643 * distributed under the License is distributed on an "AS IS" BASIS,
123644 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123645 * See the License for the specific language governing permissions and
123646 * limitations under the License.
123647 * =============================================================================
123648 */
// GLSL body for elementwise real division `a / b` on the unpacked path.
// Without the equality check div produces 0.9999 for a = b, which when
// floored can cause errors.
// NOTE(review): the `a == b` shortcut also maps 0/0 and Inf/Inf to 1.0
// rather than NaN. The trailing `;` after the `if` block is an empty
// statement and harmless.
var DIV = "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;";
// Packed (vec4) variant applying the same a == b -> 1.0 rule per lane.
// We do the same as in ./binaryop_gpu, with vec4 and ivec4.
// The commented-out vectorized form inside the snippet produced NaNs on
// Linux when a and b are both 0, so each lane is patched individually.
var DIV_PACKED = "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n";
// Kernel function built from the snippets above. `checkOutOfBounds: true` is
// forwarded to binaryKernelFunc; its exact effect is defined there.
var realDiv = binaryKernelFunc({
  opSnippet: DIV,
  packedOpSnippet: DIV_PACKED,
  checkOutOfBounds: true
});
// Registers `realDiv` for the RealDiv op on the 'webgl' backend.
var realDivConfig = {
  kernelName: RealDiv,
  backendName: 'webgl',
  kernelFunc: realDiv
};
123665
123666 /**
123667 * @license
123668 * Copyright 2020 Google LLC. All Rights Reserved.
123669 * Licensed under the Apache License, Version 2.0 (the "License");
123670 * you may not use this file except in compliance with the License.
123671 * You may obtain a copy of the License at
123672 *
123673 * http://www.apache.org/licenses/LICENSE-2.0
123674 *
123675 * Unless required by applicable law or agreed to in writing, software
123676 * distributed under the License is distributed on an "AS IS" BASIS,
123677 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123678 * See the License for the specific language governing permissions and
123679 * limitations under the License.
123680 * =============================================================================
123681 */
// Elementwise subtraction. The same GLSL body serves both the unpacked
// (float) and packed (vec4) shader paths, since `-` works on either.
var SUB = 'return a - b;';
// Kernel function for `a - b`. supportsComplex is forwarded so that
// binaryKernelFunc can accept complex inputs, and subImplCPU is the CPU
// implementation it may use instead of the shader.
var sub = binaryKernelFunc({
  opSnippet: SUB,
  packedOpSnippet: SUB,
  supportsComplex: true,
  cpuKernelImpl: subImplCPU
});
// Registers `sub` for the Sub op on the 'webgl' backend.
var subConfig = {
  kernelName: Sub,
  backendName: 'webgl',
  kernelFunc: sub
};
123694
123695 /**
123696 * @license
123697 * Copyright 2020 Google LLC. All Rights Reserved.
123698 * Licensed under the Apache License, Version 2.0 (the "License");
123699 * you may not use this file except in compliance with the License.
123700 * You may obtain a copy of the License at
123701 *
123702 * http://www.apache.org/licenses/LICENSE-2.0
123703 *
123704 * Unless required by applicable law or agreed to in writing, software
123705 * distributed under the License is distributed on an "AS IS" BASIS,
123706 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123707 * See the License for the specific language governing permissions and
123708 * limitations under the License.
123709 * =============================================================================
123710 */
/**
 * WebGL kernel for `Softmax` along `attrs.dim`.
 *
 * Uses the numerically stable form exp(x - max) / sum(exp(x - max)). The max
 * and the sum are reduced without keepDims and then reshaped back to the
 * kept-dims shape so they broadcast against `logits`. All six intermediates
 * are disposed before returning.
 *
 * @param {{inputs: {logits: Object}, backend: Object, attrs: {dim: number}}} args
 * @returns {Object} TensorInfo of softmax probabilities, shaped like logits.
 */
function softmax(args) {
  var backend = args.backend;
  var logits = args.inputs.logits;
  var axes = parseAxisParam([args.attrs.dim], logits.shape);
  var rowMax = max({ inputs: { x: logits }, backend: backend, attrs: { reductionIndices: axes, keepDims: false } });
  var keptShape = expandShapeToKeepDim(rowMax.shape, axes);
  var rowMaxKept = reshape({ inputs: { x: rowMax }, backend: backend, attrs: { shape: keptShape } });
  var shifted = sub({ inputs: { a: logits, b: rowMaxKept }, backend: backend });
  var exps = exp({ inputs: { x: shifted }, backend: backend });
  var total = sum({ inputs: { x: exps }, backend: backend, attrs: { axis: axes, keepDims: false } });
  var totalKept = reshape({ inputs: { x: total }, backend: backend, attrs: { shape: keptShape } });
  var probs = realDiv({ inputs: { a: exps, b: totalKept }, backend: backend });
  var temporaries = [rowMax, rowMaxKept, shifted, exps, total, totalKept];
  for (var t = 0; t < temporaries.length; t++) {
    backend.disposeIntermediateTensorInfo(temporaries[t]);
  }
  return probs;
}
var softmaxConfig = {
  kernelName: Softmax$2,
  backendName: 'webgl',
  kernelFunc: softmax
};
123790
123791 /**
123792 * @license
123793 * Copyright 2020 Google LLC. All Rights Reserved.
123794 * Licensed under the Apache License, Version 2.0 (the "License");
123795 * you may not use this file except in compliance with the License.
123796 * You may obtain a copy of the License at
123797 *
123798 * http://www.apache.org/licenses/LICENSE-2.0
123799 *
123800 * Unless required by applicable law or agreed to in writing, software
123801 * distributed under the License is distributed on an "AS IS" BASIS,
123802 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123803 * See the License for the specific language governing permissions and
123804 * limitations under the License.
123805 * =============================================================================
123806 */
123807 function multinomial(args) {
123808 var inputs = args.inputs,
123809 backend = args.backend,
123810 attrs = args.attrs;
123811 var logits = inputs.logits;
123812 var numSamples = attrs.numSamples,
123813 seed = attrs.seed,
123814 normalized = attrs.normalized;
123815 var probs = normalized ? logits : softmax({
123816 inputs: {
123817 logits: logits
123818 },
123819 backend: backend,
123820 attrs: {
123821 dim: logits.shape.length - 1
123822 }
123823 });
123824 var batchSize = probs.shape[0];
123825 var numOutcomes = probs.shape[1];
123826 var program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
123827 var customValues = [[seed]];
123828 var res = backend.runWebGLProgram(program, [probs], 'int32', customValues);
123829 if (!normalized) {
123830 backend.disposeIntermediateTensorInfo(probs);
123831 }
123832 return res;
123833 }
123834 var multinomialConfig = {
123835 kernelName: Multinomial,
123836 backendName: 'webgl',
123837 kernelFunc: multinomial
123838 };
123839
123840 var NEG = CHECK_NAN_SNIPPET$1 + "\n return -x;\n";
123841 var NEG_PACKED = "\n vec4 result = -x;\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
123842 // This doesn't use unaryKernelFunc because negImplCPU is not of type
123843 // SimpleUnaryKernelImplCPU.
123844 function neg(args) {
123845 var inputs = args.inputs,
123846 backend = args.backend;
123847 var x = inputs.x;
123848 if (backend.shouldExecuteOnCPU([x])) {
123849 var xData = backend.texData.get(x.dataId);
123850 var _negImplCPU = negImplCPU(xData.values, x.shape, x.dtype),
123851 _negImplCPU2 = _slicedToArray(_negImplCPU, 2),
123852 outValues = _negImplCPU2[0],
123853 newShape = _negImplCPU2[1];
123854 return backend.makeTensorInfo(newShape, x.dtype, outValues);
123855 }
123856 var program;
123857 if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
123858 program = new UnaryOpPackedProgram(x.shape, NEG_PACKED);
123859 } else {
123860 program = new UnaryOpProgram(x.shape, NEG);
123861 }
123862 return backend.runWebGLProgram(program, [x], x.dtype);
123863 }
123864 var negConfig = {
123865 kernelName: Neg,
123866 backendName: 'webgl',
123867 kernelFunc: neg
123868 };
123869
123870 /**
123871 * @license
123872 * Copyright 2020 Google LLC. All Rights Reserved.
123873 * Licensed under the Apache License, Version 2.0 (the "License");
123874 * you may not use this file except in compliance with the License.
123875 * You may obtain a copy of the License at
123876 *
123877 * http://www.apache.org/licenses/LICENSE-2.0
123878 *
123879 * Unless required by applicable law or agreed to in writing, software
123880 * distributed under the License is distributed on an "AS IS" BASIS,
123881 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123882 * See the License for the specific language governing permissions and
123883 * limitations under the License.
123884 * =============================================================================
123885 */
123886 var nonMaxSuppressionV3Impl = nonMaxSuppressionV3Impl$2;
123887 function nonMaxSuppressionV3(args) {
123888 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead');
123889 var inputs = args.inputs,
123890 backend = args.backend,
123891 attrs = args.attrs;
123892 var boxes = inputs.boxes,
123893 scores = inputs.scores;
123894 var maxOutputSize = attrs.maxOutputSize,
123895 iouThreshold = attrs.iouThreshold,
123896 scoreThreshold = attrs.scoreThreshold;
123897 var boxesVals = backend.readSync(boxes.dataId);
123898 var scoresVals = backend.readSync(scores.dataId);
123899 var _nonMaxSuppressionV3I = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold),
123900 selectedIndices = _nonMaxSuppressionV3I.selectedIndices;
123901 return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
123902 }
123903 var nonMaxSuppressionV3Config = {
123904 kernelName: NonMaxSuppressionV3,
123905 backendName: 'webgl',
123906 kernelFunc: nonMaxSuppressionV3
123907 };
123908
123909 /**
123910 * @license
123911 * Copyright 2020 Google LLC. All Rights Reserved.
123912 * Licensed under the Apache License, Version 2.0 (the "License");
123913 * you may not use this file except in compliance with the License.
123914 * You may obtain a copy of the License at
123915 *
123916 * http://www.apache.org/licenses/LICENSE-2.0
123917 *
123918 * Unless required by applicable law or agreed to in writing, software
123919 * distributed under the License is distributed on an "AS IS" BASIS,
123920 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123921 * See the License for the specific language governing permissions and
123922 * limitations under the License.
123923 * =============================================================================
123924 */
123925 var nonMaxSuppressionV4Impl = nonMaxSuppressionV4Impl$2;
123926 function nonMaxSuppressionV4(args) {
123927 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead');
123928 var inputs = args.inputs,
123929 backend = args.backend,
123930 attrs = args.attrs;
123931 var boxes = inputs.boxes,
123932 scores = inputs.scores;
123933 var maxOutputSize = attrs.maxOutputSize,
123934 iouThreshold = attrs.iouThreshold,
123935 scoreThreshold = attrs.scoreThreshold,
123936 padToMaxOutputSize = attrs.padToMaxOutputSize;
123937 var boxesVals = backend.readSync(boxes.dataId);
123938 var scoresVals = backend.readSync(scores.dataId);
123939 var _nonMaxSuppressionV4I = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize),
123940 selectedIndices = _nonMaxSuppressionV4I.selectedIndices,
123941 validOutputs = _nonMaxSuppressionV4I.validOutputs;
123942 return [backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))];
123943 }
123944 var nonMaxSuppressionV4Config = {
123945 kernelName: NonMaxSuppressionV4,
123946 backendName: 'webgl',
123947 kernelFunc: nonMaxSuppressionV4
123948 };
123949
123950 /**
123951 * @license
123952 * Copyright 2020 Google LLC. All Rights Reserved.
123953 * Licensed under the Apache License, Version 2.0 (the "License");
123954 * you may not use this file except in compliance with the License.
123955 * You may obtain a copy of the License at
123956 *
123957 * http://www.apache.org/licenses/LICENSE-2.0
123958 *
123959 * Unless required by applicable law or agreed to in writing, software
123960 * distributed under the License is distributed on an "AS IS" BASIS,
123961 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123962 * See the License for the specific language governing permissions and
123963 * limitations under the License.
123964 * =============================================================================
123965 */
123966 var nonMaxSuppressionV5Impl = nonMaxSuppressionV5Impl$2;
123967 function nonMaxSuppressionV5(args) {
123968 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead');
123969 var inputs = args.inputs,
123970 backend = args.backend,
123971 attrs = args.attrs;
123972 var boxes = inputs.boxes,
123973 scores = inputs.scores;
123974 var maxOutputSize = attrs.maxOutputSize,
123975 iouThreshold = attrs.iouThreshold,
123976 scoreThreshold = attrs.scoreThreshold,
123977 softNmsSigma = attrs.softNmsSigma;
123978 var boxesVals = backend.readSync(boxes.dataId);
123979 var scoresVals = backend.readSync(scores.dataId);
123980 var maxOutputSizeVal = maxOutputSize;
123981 var iouThresholdVal = iouThreshold;
123982 var scoreThresholdVal = scoreThreshold;
123983 var softNmsSigmaVal = softNmsSigma;
123984 var _nonMaxSuppressionV5I = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal),
123985 selectedIndices = _nonMaxSuppressionV5I.selectedIndices,
123986 selectedScores = _nonMaxSuppressionV5I.selectedScores;
123987 return [backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))];
123988 }
123989 var nonMaxSuppressionV5Config = {
123990 kernelName: NonMaxSuppressionV5,
123991 backendName: 'webgl',
123992 kernelFunc: nonMaxSuppressionV5
123993 };
123994
123995 /**
123996 * @license
123997 * Copyright 2017 Google LLC. All Rights Reserved.
123998 * Licensed under the Apache License, Version 2.0 (the "License");
123999 * you may not use this file except in compliance with the License.
124000 * You may obtain a copy of the License at
124001 *
124002 * http://www.apache.org/licenses/LICENSE-2.0
124003 *
124004 * Unless required by applicable law or agreed to in writing, software
124005 * distributed under the License is distributed on an "AS IS" BASIS,
124006 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124007 * See the License for the specific language governing permissions and
124008 * limitations under the License.
124009 * =============================================================================
124010 */
124011 var OneHotProgram = /*#__PURE__*/_createClass(function OneHotProgram(numIndices, depth, onValue, offValue) {
124012 _classCallCheck(this, OneHotProgram);
124013 this.variableNames = ['indices'];
124014 this.outputShape = [numIndices, depth];
124015 this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int index = round(getIndices(coords.x));\n setOutput(mix(float(".concat(offValue, "), float(").concat(onValue, "),\n float(index == coords.y)));\n }\n ");
124016 });
124017
124018 var oneHot = function oneHot(args) {
124019 var inputs = args.inputs,
124020 backend = args.backend,
124021 attrs = args.attrs;
124022 var indices = inputs.indices;
124023 var dtype = attrs.dtype,
124024 depth = attrs.depth,
124025 onValue = attrs.onValue,
124026 offValue = attrs.offValue;
124027 var indicesSize = sizeFromShape(indices.shape);
124028 var program = new OneHotProgram(indicesSize, depth, onValue, offValue);
124029 var reshaped = reshape({
124030 inputs: {
124031 x: indices
124032 },
124033 backend: backend,
124034 attrs: {
124035 shape: [indicesSize]
124036 }
124037 });
124038 var result = backend.runWebGLProgram(program, [reshaped], dtype);
124039 backend.disposeIntermediateTensorInfo(reshaped);
124040 var outShape = [].concat(_toConsumableArray(indices.shape), [depth]);
124041 var out = reshape({
124042 inputs: {
124043 x: result
124044 },
124045 backend: backend,
124046 attrs: {
124047 shape: outShape
124048 }
124049 });
124050 backend.disposeIntermediateTensorInfo(result);
124051 return out;
124052 };
124053 var oneHotConfig = {
124054 kernelName: OneHot,
124055 backendName: 'webgl',
124056 kernelFunc: oneHot
124057 };
124058
124059 /**
124060 * @license
124061 * Copyright 2020 Google LLC. All Rights Reserved.
124062 * Licensed under the Apache License, Version 2.0 (the "License");
124063 * you may not use this file except in compliance with the License.
124064 * You may obtain a copy of the License at
124065 *
124066 * http://www.apache.org/licenses/LICENSE-2.0
124067 *
124068 * Unless required by applicable law or agreed to in writing, software
124069 * distributed under the License is distributed on an "AS IS" BASIS,
124070 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124071 * See the License for the specific language governing permissions and
124072 * limitations under the License.
124073 * =============================================================================
124074 */
124075 function zerosLike(args) {
124076 var inputs = args.inputs,
124077 backend = args.backend;
124078 var x = inputs.x;
124079 if (x.dtype === 'complex64') {
124080 var realPart = real({
124081 inputs: {
124082 input: x
124083 },
124084 backend: backend
124085 });
124086 var r = zerosLike({
124087 inputs: {
124088 x: realPart
124089 },
124090 backend: backend
124091 });
124092 var imagPart = imag({
124093 inputs: {
124094 input: x
124095 },
124096 backend: backend
124097 });
124098 var i = zerosLike({
124099 inputs: {
124100 x: imagPart
124101 },
124102 backend: backend
124103 });
124104 var result = complex({
124105 inputs: {
124106 real: r,
124107 imag: i
124108 },
124109 backend: backend
124110 });
124111 backend.disposeIntermediateTensorInfo(realPart);
124112 backend.disposeIntermediateTensorInfo(r);
124113 backend.disposeIntermediateTensorInfo(imagPart);
124114 backend.disposeIntermediateTensorInfo(i);
124115 return result;
124116 } else {
124117 return fill({
124118 attrs: {
124119 shape: x.shape,
124120 dtype: x.dtype,
124121 value: x.dtype === 'string' ? '' : 0
124122 },
124123 backend: backend
124124 });
124125 }
124126 }
124127 var zerosLikeConfig = {
124128 kernelName: ZerosLike,
124129 backendName: 'webgl',
124130 kernelFunc: zerosLike
124131 };
124132
124133 /**
124134 * @license
124135 * Copyright 2020 Google LLC. All Rights Reserved.
124136 * Licensed under the Apache License, Version 2.0 (the "License");
124137 * you may not use this file except in compliance with the License.
124138 * You may obtain a copy of the License at
124139 *
124140 * http://www.apache.org/licenses/LICENSE-2.0
124141 *
124142 * Unless required by applicable law or agreed to in writing, software
124143 * distributed under the License is distributed on an "AS IS" BASIS,
124144 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124145 * See the License for the specific language governing permissions and
124146 * limitations under the License.
124147 * =============================================================================
124148 */
124149 function onesLike(args) {
124150 var inputs = args.inputs,
124151 backend = args.backend;
124152 var x = inputs.x;
124153 if (x.dtype === 'string') {
124154 throw new Error('onesLike is not supported under string dtype');
124155 } else if (x.dtype === 'complex64') {
124156 var realPart = real({
124157 inputs: {
124158 input: x
124159 },
124160 backend: backend
124161 });
124162 var r = onesLike({
124163 inputs: {
124164 x: realPart
124165 },
124166 backend: backend
124167 });
124168 var imagPart = imag({
124169 inputs: {
124170 input: x
124171 },
124172 backend: backend
124173 });
124174 var i = zerosLike({
124175 inputs: {
124176 x: imagPart
124177 },
124178 backend: backend
124179 });
124180 var result = complex({
124181 inputs: {
124182 real: r,
124183 imag: i
124184 },
124185 backend: backend
124186 });
124187 backend.disposeIntermediateTensorInfo(realPart);
124188 backend.disposeIntermediateTensorInfo(r);
124189 backend.disposeIntermediateTensorInfo(imagPart);
124190 backend.disposeIntermediateTensorInfo(i);
124191 return result;
124192 } else {
124193 // TODO(cais, smilkov): Add WebGL shader for onesLike:
124194 // https://github.com/tensorflow/tfjs/issues/1293
124195 return fill({
124196 attrs: {
124197 shape: x.shape,
124198 dtype: x.dtype,
124199 value: 1
124200 },
124201 backend: backend
124202 });
124203 }
124204 }
124205 var onesLikeConfig = {
124206 kernelName: OnesLike,
124207 backendName: 'webgl',
124208 kernelFunc: onesLike
124209 };
124210
124211 /**
124212 * @license
124213 * Copyright 2020 Google LLC. All Rights Reserved.
124214 * Licensed under the Apache License, Version 2.0 (the "License");
124215 * you may not use this file except in compliance with the License.
124216 * You may obtain a copy of the License at
124217 *
124218 * http://www.apache.org/licenses/LICENSE-2.0
124219 *
124220 * Unless required by applicable law or agreed to in writing, software
124221 * distributed under the License is distributed on an "AS IS" BASIS,
124222 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124223 * See the License for the specific language governing permissions and
124224 * limitations under the License.
124225 * =============================================================================
124226 */
124227 function pack(args) {
124228 var inputs = args.inputs,
124229 backend = args.backend,
124230 attrs = args.attrs;
124231 var axis = attrs.axis;
124232 if (inputs.length === 1) {
124233 return expandDims({
124234 inputs: {
124235 input: inputs[0]
124236 },
124237 backend: backend,
124238 attrs: {
124239 dim: axis
124240 }
124241 });
124242 }
124243 var shape = inputs[0].shape;
124244 var dtype = inputs[0].dtype;
124245 inputs.forEach(function (t) {
124246 assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
124247 assert$1(dtype === t.dtype, function () {
124248 return 'All tensors passed to stack must have matching dtypes';
124249 });
124250 });
124251 var intermediateTensorInfos = [];
124252 var expandedTensors = inputs.map(function (t) {
124253 var expandedT = expandDims({
124254 inputs: {
124255 input: t
124256 },
124257 backend: backend,
124258 attrs: {
124259 dim: axis
124260 }
124261 });
124262 intermediateTensorInfos.push(expandedT);
124263 return expandedT;
124264 });
124265 var result = concat({
124266 inputs: expandedTensors,
124267 backend: backend,
124268 attrs: {
124269 axis: axis
124270 }
124271 });
124272 intermediateTensorInfos.forEach(function (t) {
124273 return backend.disposeIntermediateTensorInfo(t);
124274 });
124275 return result;
124276 }
124277 var packConfig = {
124278 kernelName: Pack,
124279 backendName: 'webgl',
124280 kernelFunc: pack
124281 };
124282
124283 var PadProgram = /*#__PURE__*/_createClass(function PadProgram(xShape, paddings, constantValue) {
124284 _classCallCheck(this, PadProgram);
124285 this.variableNames = ['x'];
124286 this.customUniforms = [{
124287 name: 'value',
124288 type: 'float'
124289 }];
124290 this.outputShape = paddings.map(function (p, i) {
124291 return p[0] /* beforePad */ + xShape[i] + p[1];
124292 } /* afterPad */);
124293 var rank = xShape.length;
124294 var type = getCoordsDataType(rank);
124295 var start = paddings.map(function (p) {
124296 return p[0];
124297 }).join(',');
124298 var end = paddings.map(function (p, i) {
124299 return p[0] + xShape[i];
124300 }).join(',');
124301 var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
124302 if (rank === 1) {
124303 this.userCode = "\n int start = ".concat(start, ";\n int end = ").concat(end, ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start || outC >= end) {\n setOutput(value);\n } else {\n setOutput(getX(outC - start));\n }\n }\n ");
124304 return;
124305 }
124306 this.userCode = "\n ".concat(type, " start = ").concat(type, "(").concat(start, ");\n ").concat(type, " end = ").concat(type, "(").concat(end, ");\n\n void main() {\n ").concat(type, " outC = getOutputCoords();\n if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n setOutput(value);\n } else {\n ").concat(type, " coords = outC - start;\n setOutput(getX(").concat(unpackedCoords, "));\n }\n }\n ");
124307 });
124308
124309 var PadPackedProgram = /*#__PURE__*/_createClass(function PadPackedProgram(xShape, paddings, constantValue) {
124310 _classCallCheck(this, PadPackedProgram);
124311 this.variableNames = ['x'];
124312 this.packedInputs = true;
124313 this.packedOutput = true;
124314 this.customUniforms = [{
124315 name: 'value',
124316 type: 'float'
124317 }];
124318 this.outputShape = paddings.map(function (p, i) {
124319 return p[0] /* beforePad */ + xShape[i] + p[1];
124320 } /* afterPad */);
124321 var rank = xShape.length;
124322 var dtype = getCoordsDataType(rank);
124323 var start = paddings.map(function (p) {
124324 return p[0];
124325 }).join(',');
124326 var end = paddings.map(function (p, i) {
124327 return p[0] + xShape[i];
124328 }).join(',');
124329 var coords = getChannels('rc', rank);
124330 var source = getChannels('source', rank);
124331 var cLimit = "".concat(coords[rank - 1], " < ").concat(this.outputShape[rank - 1]);
124332 var innerDims = rank === 1 ? 'source' : "vec2(".concat(source.slice(-2).join(), ")");
124333 var componentSetup = ["".concat(dtype, " rc = outputLoc;"), "".concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n "), rank === 1 ? '' : "}\n rc = outputLoc;\n ".concat(coords[rank - 2], " += 1;\n if(").concat(coords[rank - 2], " < ").concat(this.outputShape[rank - 2], ") {"), rank === 1 ? '' : " ".concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {")];
124334 var paddingArea = rank === 1 ? 'rc < start || rc >= end' : 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
124335 var mainLoop = '';
124336 for (var i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
124337 mainLoop += "\n ".concat(componentSetup[i], "\n if (").concat(paddingArea, ") {\n result[").concat(i, "] = float(value);\n } else {\n ").concat(dtype, " source = rc - start;\n result[").concat(i, "] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n ");
124338 }
124339 mainLoop += rank === 1 ? "} " : "}}";
124340 this.userCode = "\n const ".concat(dtype, " start = ").concat(dtype, "(").concat(start, ");\n const ").concat(dtype, " end = ").concat(dtype, "(").concat(end, ");\n\n void main() {\n ").concat(dtype, " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n ").concat(mainLoop, "\n setOutput(result);\n }\n ");
124341 });
124342
124343 /**
124344 * @license
124345 * Copyright 2020 Google LLC. All Rights Reserved.
124346 * Licensed under the Apache License, Version 2.0 (the "License");
124347 * you may not use this file except in compliance with the License.
124348 * You may obtain a copy of the License at
124349 *
124350 * http://www.apache.org/licenses/LICENSE-2.0
124351 *
124352 * Unless required by applicable law or agreed to in writing, software
124353 * distributed under the License is distributed on an "AS IS" BASIS,
124354 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124355 * See the License for the specific language governing permissions and
124356 * limitations under the License.
124357 * =============================================================================
124358 */
124359 var padV2 = function padV2(args) {
124360 var inputs = args.inputs,
124361 backend = args.backend,
124362 attrs = args.attrs;
124363 var x = inputs.x;
124364 var paddings = attrs.paddings,
124365 constantValue = attrs.constantValue;
124366 if (sizeFromShape(x.shape) === 0) {
124367 // Short-circuit the computation, since x doesn't have value, only
124368 // the shape is used to compute output shape to pad.
124369 var outputShape = paddings.map(function (p, i) {
124370 return p[0] /* beforePad */ + x.shape[i] + p[1];
124371 } /* afterPad */);
124372 return fill({
124373 backend: backend,
124374 attrs: {
124375 shape: outputShape,
124376 value: constantValue,
124377 dtype: x.dtype
124378 }
124379 });
124380 }
124381 var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new PadPackedProgram(x.shape, paddings, constantValue) : new PadProgram(x.shape, paddings, constantValue);
124382 var customValues = [[constantValue]];
124383 return backend.runWebGLProgram(program, [x], x.dtype, customValues);
124384 };
124385 var padV2Config = {
124386 kernelName: PadV2,
124387 backendName: 'webgl',
124388 kernelFunc: padV2
124389 };
124390
124391 /**
124392 * @license
124393 * Copyright 2020 Google LLC. All Rights Reserved.
124394 * Licensed under the Apache License, Version 2.0 (the "License");
124395 * you may not use this file except in compliance with the License.
124396 * You may obtain a copy of the License at
124397 *
124398 * http://www.apache.org/licenses/LICENSE-2.0
124399 *
124400 * Unless required by applicable law or agreed to in writing, software
124401 * distributed under the License is distributed on an "AS IS" BASIS,
124402 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124403 * See the License for the specific language governing permissions and
124404 * limitations under the License.
124405 * =============================================================================
124406 */
124407 var POW = "\n if(a < 0.0 && floor(b) < b){\n return NAN;\n }\n if (b == 0.0) {\n return 1.0;\n }\n return (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n";
124408 var POW_PACKED = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n bvec4 isNaN1 = lessThan(a, vec4(0.0));\n bvec4 isNaN2 = lessThan(floor(b), b);\n bvec4 isNaN = bvec4(isNaN1.x && isNaN2.x, isNaN1.y && isNaN2.y, isNaN1.z && isNaN2.z, isNaN1.w && isNaN2.w);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n";
124409 var pow = binaryKernelFunc({
124410 opSnippet: POW,
124411 packedOpSnippet: POW_PACKED
124412 });
124413 var powConfig = {
124414 kernelName: Pow,
124415 backendName: 'webgl',
124416 kernelFunc: pow
124417 };
124418
124419 function prod(args) {
124420 var inputs = args.inputs,
124421 backend = args.backend,
124422 attrs = args.attrs;
124423 var x = inputs.x;
124424 var axis = attrs.axis,
124425 keepDims = attrs.keepDims;
124426 var xRank = x.shape.length;
124427 var toDispose = [];
124428 var origAxes = parseAxisParam(axis, x.shape);
124429 var axes = origAxes;
124430 var permutedAxes = getAxesPermutation(axes, xRank);
124431 var permutedX = x;
124432 if (permutedAxes != null) {
124433 permutedX = transpose({
124434 inputs: {
124435 x: x
124436 },
124437 backend: backend,
124438 attrs: {
124439 perm: permutedAxes
124440 }
124441 });
124442 axes = getInnerMostAxes(axes.length, xRank);
124443 toDispose.push(permutedX);
124444 }
124445 assertAxesAreInnerMostDims('prod', axes, xRank);
124446 var res;
124447 if (backend.shouldExecuteOnCPU([permutedX])) {
124448 var xVals = backend.texData.get(permutedX.dataId).values;
124449 var _prodImplCPU = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes),
124450 outVals = _prodImplCPU.outVals,
124451 outShape = _prodImplCPU.outShape,
124452 outDtype = _prodImplCPU.outDtype;
124453 res = backend.makeTensorInfo(outShape, outDtype, outVals);
124454 } else {
124455 var _backend_util$compute = computeOutAndReduceShapes(permutedX.shape, axes),
124456 _backend_util$compute2 = _slicedToArray(_backend_util$compute, 2),
124457 _outShape = _backend_util$compute2[0],
124458 reduceShape = _backend_util$compute2[1];
124459 var inSize = sizeFromShape(reduceShape);
124460 var a2D = reshape({
124461 inputs: {
124462 x: permutedX
124463 },
124464 backend: backend,
124465 attrs: {
124466 shape: [-1, inSize]
124467 }
124468 });
124469 var outputDType = sumOutType(x.dtype);
124470 var reduced = reduce(a2D, outputDType, 'prod', backend);
124471 res = reshape({
124472 inputs: {
124473 x: reduced
124474 },
124475 backend: backend,
124476 attrs: {
124477 shape: _outShape
124478 }
124479 });
124480 toDispose.push(a2D);
124481 toDispose.push(reduced);
124482 }
124483 if (keepDims) {
124484 toDispose.push(res);
124485 var newShape = expandShapeToKeepDim(res.shape, origAxes);
124486 res = reshape({
124487 inputs: {
124488 x: res
124489 },
124490 backend: backend,
124491 attrs: {
124492 shape: newShape
124493 }
124494 });
124495 }
124496 toDispose.forEach(function (t) {
124497 return backend.disposeIntermediateTensorInfo(t);
124498 });
124499 return res;
124500 }
124501 var prodConfig = {
124502 kernelName: Prod,
124503 backendName: 'webgl',
124504 kernelFunc: prod
124505 };
124506
124507 function raggedGather(args) {
124508 var inputs = args.inputs,
124509 backend = args.backend,
124510 attrs = args.attrs;
124511 var paramsNestedSplits = inputs.paramsNestedSplits,
124512 paramsDenseValues = inputs.paramsDenseValues,
124513 indices = inputs.indices;
124514 var outputRaggedRank = attrs.outputRaggedRank;
124515 var $paramsNestedSplits = paramsNestedSplits.map(function (t) {
124516 return backend.readSync(t.dataId);
124517 });
124518 var $paramsNestedSplitsShapes = paramsNestedSplits.map(function (t) {
124519 return t.shape;
124520 });
124521 var $paramsDenseValues = backend.readSync(paramsDenseValues.dataId);
124522 var $indices = backend.readSync(indices.dataId);
124523 var _raggedGatherImplCPU = raggedGatherImplCPU($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape, outputRaggedRank),
124524 _raggedGatherImplCPU2 = _slicedToArray(_raggedGatherImplCPU, 3),
124525 outputNestedSplits = _raggedGatherImplCPU2[0],
124526 outputDenseValues = _raggedGatherImplCPU2[1],
124527 outputDenseValuesShape = _raggedGatherImplCPU2[2];
124528 var outputNestedSplitsTensors = outputNestedSplits.map(function (splits) {
124529 return backend.makeTensorInfo([splits.length], 'int32', splits);
124530 });
124531 var outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues);
124532 return outputNestedSplitsTensors.concat([outputDenseValuesTensor]);
124533 }
124534 var raggedGatherConfig = {
124535 kernelName: RaggedGather,
124536 backendName: 'webgl',
124537 kernelFunc: raggedGather
124538 };
124539
/**
 * WebGL registration of RaggedRange. There is no shader for this op: the
 * starts/limits/deltas tensors are read back synchronously and the work is
 * delegated to raggedRangeImplCPU.
 *
 * @param {{inputs: {starts: Object, limits: Object, deltas: Object},
 *     backend: Object}} args Kernel arguments.
 * @returns {Array} Two tensor infos: the nested row splits (1-D, int32) and
 *     the dense values (1-D, with the dtype of `starts`).
 */
function raggedRange(args) {
  var backend = args.backend;
  var starts = args.inputs.starts;
  var limits = args.inputs.limits;
  var deltas = args.inputs.deltas;
  var splitsAndValues = _slicedToArray(
      raggedRangeImplCPU(
          backend.readSync(starts.dataId), starts.shape, starts.dtype,
          backend.readSync(limits.dataId), limits.shape,
          backend.readSync(deltas.dataId), deltas.shape),
      2);
  var splitsData = splitsAndValues[0];
  var valuesData = splitsAndValues[1];
  return [
    backend.makeTensorInfo([splitsData.length], 'int32', splitsData),
    backend.makeTensorInfo([valuesData.length], starts.dtype, valuesData)
  ];
}
var raggedRangeConfig = {
  kernelName: RaggedRange,
  backendName: 'webgl',
  kernelFunc: raggedRange
};
124562
/**
 * WebGL registration of RaggedTensorToTensor, executed on the CPU: the shape,
 * values, default value and every row-partition tensor are read back
 * synchronously and passed, together with their shapes, to
 * raggedTensorToTensorImplCPU.
 *
 * @param {{inputs: Object, backend: Object,
 *     attrs: {rowPartitionTypes: Array}}} args Kernel arguments.
 * @returns {Object} Tensor info with the shape computed by the CPU
 *     implementation and the dtype of `values`.
 */
function raggedTensorToTensor(args) {
  var backend = args.backend;
  var inputs = args.inputs;
  var shape = inputs.shape;
  var values = inputs.values;
  var defaultValue = inputs.defaultValue;
  var partitions = inputs.rowPartitionTensors;
  var rowPartitionTypes = args.attrs.rowPartitionTypes;
  function readData(tensorInfo) {
    return backend.readSync(tensorInfo.dataId);
  }
  function shapeOf(tensorInfo) {
    return tensorInfo.shape;
  }
  var shapeData = readData(shape);
  var valuesData = readData(values);
  var defaultData = readData(defaultValue);
  var partitionData = partitions.map(readData);
  var partitionShapes = partitions.map(shapeOf);
  var converted = _slicedToArray(
      raggedTensorToTensorImplCPU(
          shapeData, shape.shape, valuesData, values.shape, values.dtype,
          defaultData, defaultValue.shape, partitionData, partitionShapes,
          rowPartitionTypes),
      2);
  return backend.makeTensorInfo(converted[0], values.dtype, converted[1]);
}
var raggedTensorToTensorConfig = {
  kernelName: RaggedTensorToTensor,
  backendName: 'webgl',
  kernelFunc: raggedTensorToTensor
};
124592
124593 /**
124594 * @license
124595 * Copyright 2020 Google LLC. All Rights Reserved.
124596 * Licensed under the Apache License, Version 2.0 (the "License");
124597 * you may not use this file except in compliance with the License.
124598 * You may obtain a copy of the License at
124599 *
124600 * http://www.apache.org/licenses/LICENSE-2.0
124601 *
124602 * Unless required by applicable law or agreed to in writing, software
124603 * distributed under the License is distributed on an "AS IS" BASIS,
124604 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124605 * See the License for the specific language governing permissions and
124606 * limitations under the License.
124607 * =============================================================================
124608 */
/**
 * WebGL registration of Range. The sequence is generated on the CPU by
 * rangeImplCPU from the start/stop/step/dtype attributes and wrapped in a 1-D
 * tensor info whose length is the number of generated values.
 */
var range = function range(args) {
  var attrs = args.attrs;
  var dtype = attrs.dtype;
  var values = rangeImplCPU(attrs.start, attrs.stop, attrs.step, dtype);
  return args.backend.makeTensorInfo([values.length], dtype, values);
};
var rangeConfig = {
  kernelName: Range,
  backendName: 'webgl',
  kernelFunc: range
};
124624
124625 /**
124626 * @license
124627 * Copyright 2020 Google LLC. All Rights Reserved.
124628 * Licensed under the Apache License, Version 2.0 (the "License");
124629 * you may not use this file except in compliance with the License.
124630 * You may obtain a copy of the License at
124631 *
124632 * http://www.apache.org/licenses/LICENSE-2.0
124633 *
124634 * Unless required by applicable law or agreed to in writing, software
124635 * distributed under the License is distributed on an "AS IS" BASIS,
124636 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124637 * See the License for the specific language governing permissions and
124638 * limitations under the License.
124639 * =============================================================================
124640 */
// Reciprocal (1 / x): an element-wise shader produced by unaryKernelFunc.
// Only the unpacked snippet is supplied here.
var RECIPROCAL = 'return 1.0 / x;';
var reciprocal = unaryKernelFunc({opSnippet: RECIPROCAL});
var reciprocalConfig = {
  kernelName: Reciprocal,
  backendName: 'webgl',
  kernelFunc: reciprocal,
};
124650
124651 /**
124652 * @license
124653 * Copyright 2020 Google LLC. All Rights Reserved.
124654 * Licensed under the Apache License, Version 2.0 (the "License");
124655 * you may not use this file except in compliance with the License.
124656 * You may obtain a copy of the License at
124657 *
124658 * http://www.apache.org/licenses/LICENSE-2.0
124659 *
124660 * Unless required by applicable law or agreed to in writing, software
124661 * distributed under the License is distributed on an "AS IS" BASIS,
124662 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124663 * See the License for the specific language governing permissions and
124664 * limitations under the License.
124665 * =============================================================================
124666 */
// ReLU shaders for unaryKernelFunc.
// Unpacked: CHECK_NAN_SNIPPET$1 (defined elsewhere in this bundle; presumably
// an early-out for NaN inputs -- confirm) is prepended to the clamp at 0.
// Packed: a greaterThanEqual mask zeroes negative lanes, then the isnan()
// checks copy NaN inputs through unchanged.
var RELU = CHECK_NAN_SNIPPET$1 + '\n return (x < 0.0) ? 0.0 : x;\n';
var RELU_PACKED =
    '\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n' +
    ' bvec4 isNaN = isnan(x);\n\n' +
    ' result.r = isNaN.r ? x.r : result.r;\n' +
    ' result.g = isNaN.g ? x.g : result.g;\n' +
    ' result.b = isNaN.b ? x.b : result.b;\n' +
    ' result.a = isNaN.a ? x.a : result.a;\n\n' +
    ' return result;\n';
var relu = unaryKernelFunc({opSnippet: RELU, packedOpSnippet: RELU_PACKED});
var reluConfig = {
  kernelName: Relu$1,
  backendName: 'webgl',
  kernelFunc: relu,
};
124678
124679 /**
124680 * @license
124681 * Copyright 2020 Google LLC. All Rights Reserved.
124682 * Licensed under the Apache License, Version 2.0 (the "License");
124683 * you may not use this file except in compliance with the License.
124684 * You may obtain a copy of the License at
124685 *
124686 * http://www.apache.org/licenses/LICENSE-2.0
124687 *
124688 * Unless required by applicable law or agreed to in writing, software
124689 * distributed under the License is distributed on an "AS IS" BASIS,
124690 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124691 * See the License for the specific language governing permissions and
124692 * limitations under the License.
124693 * =============================================================================
124694 */
// ReLU6 shaders for unaryKernelFunc: clamp to [0, 6].
// Unpacked: CHECK_NAN_SNIPPET$1 (defined elsewhere in this bundle; presumably
// an early-out for NaN inputs -- confirm) is prepended to the clamp.
// Packed: min(x, 6) masked by x >= 0, then the isnan() checks copy NaN inputs
// through unchanged.
var RELU6 = CHECK_NAN_SNIPPET$1 + '\n return (x < 0.0) ? 0.0 : min(6.0, x);\n';
var RELU6_PACKED =
    '\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n' +
    ' bvec4 isNaN = isnan(x);\n\n' +
    ' result.r = isNaN.r ? x.r : result.r;\n' +
    ' result.g = isNaN.g ? x.g : result.g;\n' +
    ' result.b = isNaN.b ? x.b : result.b;\n' +
    ' result.a = isNaN.a ? x.a : result.a;\n\n' +
    ' return result;\n';
var relu6 = unaryKernelFunc({opSnippet: RELU6, packedOpSnippet: RELU6_PACKED});
var relu6Config = {
  kernelName: Relu6$1,
  backendName: 'webgl',
  kernelFunc: relu6,
};
124706
124707 /**
124708 * @license
124709 * Copyright 2017 Google LLC. All Rights Reserved.
124710 * Licensed under the Apache License, Version 2.0 (the "License");
124711 * you may not use this file except in compliance with the License.
124712 * You may obtain a copy of the License at
124713 *
124714 * http://www.apache.org/licenses/LICENSE-2.0
124715 *
124716 * Unless required by applicable law or agreed to in writing, software
124717 * distributed under the License is distributed on an "AS IS" BASIS,
124718 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124719 * See the License for the specific language governing permissions and
124720 * limitations under the License.
124721 * =============================================================================
124722 */
/**
 * Unpacked WebGL program for bilinear image resizing.
 *
 * The output shape is [batch, newHeight, newWidth, depth]. Each output texel
 * is mapped back to a fractional source (row, col); the four surrounding
 * input texels of `A` are read and blended bilinearly.
 *
 * @param {number[]} inputShape [batch, oldHeight, oldWidth, depth].
 * @param {number} newHeight Output height.
 * @param {number} newWidth Output width.
 * @param {boolean} alignCorners When set (and the output dimension is > 1),
 *     the scale is computed from (size - 1) on both sides.
 * @param {boolean} halfPixelCenters When set, the source index is
 *     (y + 0.5) * scale - 0.5 instead of y * scale.
 */
var ResizeBilinearProgram = /*#__PURE__*/_createClass(function ResizeBilinearProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
  _classCallCheck(this, ResizeBilinearProgram);
  this.variableNames = ['A'];
  var dims = _slicedToArray(inputShape, 4);
  var batch = dims[0];
  var oldHeight = dims[1];
  var oldWidth = dims[2];
  var depth = dims[3];
  this.outputShape = [batch, newHeight, newWidth, depth];
  // alignCorners only takes effect along a dimension whose output has > 1 pixel.
  var alignRows = alignCorners && newHeight > 1;
  var alignCols = alignCorners && newWidth > 1;
  var rowRatio = (alignRows ? oldHeight - 1 : oldHeight) / (alignRows ? newHeight - 1 : newHeight);
  var colRatio = (alignCols ? oldWidth - 1 : oldWidth) / (alignCols ? newWidth - 1 : newWidth);
  var sourceFracIndexRC = halfPixelCenters ?
      "(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" + " - vec2(0.5)" :
      "vec2(yRC) * effectiveInputOverOutputRatioRC";
  this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n ".concat(rowRatio, ",\n ").concat(colRatio, ");\n const vec2 inputShapeRC = vec2(").concat(oldHeight, ".0, ").concat(oldWidth, ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n ");
});
124743
124744 /**
124745 * @license
124746 * Copyright 2019 Google LLC. All Rights Reserved.
124747 * Licensed under the Apache License, Version 2.0 (the "License");
124748 * you may not use this file except in compliance with the License.
124749 * You may obtain a copy of the License at
124750 *
124751 * http://www.apache.org/licenses/LICENSE-2.0
124752 *
124753 * Unless required by applicable law or agreed to in writing, software
124754 * distributed under the License is distributed on an "AS IS" BASIS,
124755 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124756 * See the License for the specific language governing permissions and
124757 * limitations under the License.
124758 * =============================================================================
124759 */
/**
 * Packed WebGL program for bilinear image resizing.
 *
 * Each invocation fills one packed 2x2 cell: output width positions z and
 * z + 1 times depth channels d and d + 1. Lanes past the last depth channel
 * or the last output column are written as 0.0.
 *
 * @param {number[]} inputShape [batch, oldHeight, oldWidth, depth].
 * @param {number} newHeight Output height.
 * @param {number} newWidth Output width.
 * @param {boolean} alignCorners When set (and the output dimension is > 1),
 *     the scale is computed from (size - 1) on both sides.
 * @param {boolean} halfPixelCenters When set, the source index is
 *     (y + 0.5) * scale - 0.5 instead of y * scale.
 */
var ResizeBilinearPackedProgram = /*#__PURE__*/_createClass(function ResizeBilinearPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
  _classCallCheck(this, ResizeBilinearPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  var dims = _slicedToArray(inputShape, 4);
  var batch = dims[0];
  var oldHeight = dims[1];
  var oldWidth = dims[2];
  var depth = dims[3];
  this.outputShape = [batch, newHeight, newWidth, depth];
  // alignCorners only takes effect along a dimension whose output has > 1 pixel.
  var alignRows = alignCorners && newHeight > 1;
  var alignCols = alignCorners && newWidth > 1;
  var rowRatio = (alignRows ? oldHeight - 1 : oldHeight) / (alignRows ? newHeight - 1 : newHeight);
  var colRatio = (alignCols ? oldWidth - 1 : oldWidth) / (alignCols ? newWidth - 1 : newWidth);
  var sourceFracIndexRC = halfPixelCenters ?
      "(vec3(yRC) + vec3(0.5)) * " + "effectiveInputOverOutputRatioRC - vec3(0.5)" :
      "vec3(yRC) * effectiveInputOverOutputRatioRC";
  // The column ratio is emitted twice: the vec3 carries (row, col, next col).
  this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n ".concat(rowRatio, ",\n ").concat(colRatio, ",\n ").concat(colRatio, ");\n const vec3 inputShapeRC = vec3(").concat(oldHeight, ".0, ").concat(oldWidth, ".0,\n ").concat(oldWidth, ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));\n ivec3 sourceCeilRC = ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < ").concat(depth - 1, ";\n bool hasNextRow = coords.z < ").concat(newWidth - 1, ";\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n ");
});
124782
/**
 * WebGL kernel for ResizeBilinear.
 *
 * Uses ResizeBilinearPackedProgram when the WEBGL_PACK_IMAGE_OPERATIONS flag
 * is on and ResizeBilinearProgram otherwise. The result dtype is always
 * float32, whatever the dtype of `images`.
 *
 * @param {{inputs: {images: Object}, backend: Object,
 *     attrs: {alignCorners: boolean, halfPixelCenters: boolean,
 *     size: number[]}}} args `size` holds [newHeight, newWidth].
 * @returns {Object} Tensor info produced by runWebGLProgram.
 */
function resizeBilinear(args) {
  var images = args.inputs.images;
  var attrs = args.attrs;
  var newSize = _slicedToArray(attrs.size, 2);
  var usePacked = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS');
  var Program = usePacked ? ResizeBilinearPackedProgram : ResizeBilinearProgram;
  var program = new Program(images.shape, newSize[0], newSize[1], attrs.alignCorners, attrs.halfPixelCenters);
  return args.backend.runWebGLProgram(program, [images], 'float32');
}
var resizeBilinearConfig = {
  kernelName: ResizeBilinear,
  backendName: 'webgl',
  kernelFunc: resizeBilinear
};
124802
124803 /**
124804 * @license
124805 * Copyright 2018 Google LLC. All Rights Reserved.
124806 * Licensed under the Apache License, Version 2.0 (the "License");
124807 * you may not use this file except in compliance with the License.
124808 * You may obtain a copy of the License at
124809 *
124810 * http://www.apache.org/licenses/LICENSE-2.0
124811 *
124812 * Unless required by applicable law or agreed to in writing, software
124813 * distributed under the License is distributed on an "AS IS" BASIS,
124814 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124815 * See the License for the specific language governing permissions and
124816 * limitations under the License.
124817 * =============================================================================
124818 */
/**
 * WebGL program computing the gradient of a bilinear resize with respect to
 * its input image.
 *
 * Each output texel is one input pixel (b, r, c, d). The shader scans a
 * window of `dy` rows/columns that may have sampled that pixel in the forward
 * pass and accumulates each `dy` value weighted by the bilinear coefficient
 * the pixel received.
 *
 * @param {number[]} dyShape Shape of the incoming gradient; entries 1 and 2
 *     are read as yHeight and yWidth.
 * @param {number[]} inputShape Shape of the forward input. It is used as the
 *     output shape; entries 1 and 2 are read as xHeight and xWidth.
 * @param {boolean} alignCorners The alignCorners flag of the forward resize.
 */
var ResizeBilinearBackpropProgram = /*#__PURE__*/_createClass(function ResizeBilinearBackpropProgram(dyShape, inputShape, alignCorners) {
  _classCallCheck(this, ResizeBilinearBackpropProgram);
  this.variableNames = ['dy'];
  this.outputShape = inputShape;
  var _inputShape = _slicedToArray(inputShape, 3),
    xHeight = _inputShape[1],
    xWidth = _inputShape[2];
  var _dyShape = _slicedToArray(dyShape, 3),
    yHeight = _dyShape[1],
    yWidth = _dyShape[2];
  // In the backwards pass, we want to find the pixels that were generated for
  // each pixel in the input image in the forward pass and add the
  // corresponding coefficient from dy to the gradient (with some
  // interpolation).
  var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
  var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
  var heightScale = effectiveXSize[0] / effectiveYSize[0];
  var widthScale = effectiveXSize[1] / effectiveYSize[1];
  // A scale of 0 occurs when alignCorners is set and an input dimension has a
  // single pixel while dy has more than one: the effective input size is 0.
  // Its reciprocal would be Infinity, which would be spliced into the shader
  // as the bare token `Infinity` and fail to compile. In that case every dy
  // index maps to x index 0, so use 0 for the inverse scale (the window then
  // starts at -(dySize + 1)) and size the window as 2 * dySize + 2 so that it
  // spans all of dy.
  var invHeightScale = heightScale === 0 ? 0 : 1 / heightScale;
  var invWidthScale = widthScale === 0 ? 0 : 1 / widthScale;
  // This defines the size of the window of values around a particular
  // index in dy that we want to search for contributions to dx.
  var winHeight = heightScale === 0 ? 2 * yHeight + 2 : Math.ceil(invHeightScale) * 2 + 2;
  var winWidth = widthScale === 0 ? 2 * yWidth + 2 : Math.ceil(invWidthScale) * 2 + 2;
  this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(".concat(heightScale, ");\n const float widthScale = float(").concat(widthScale, ");\n\n const float invHeightScale = float(").concat(invHeightScale, ");\n const float invWidthScale = float(").concat(invWidthScale, ");\n\n const int winHeight = int(").concat(winHeight, ");\n const int winWidth = int(").concat(winWidth, ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= ").concat(yHeight, ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= ").concat(yWidth, ") {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), ").concat(xHeight - 1, ".0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), ").concat(xWidth - 1, ".0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ");
});
124845
124846 /**
124847 * @license
124848 * Copyright 2020 Google LLC. All Rights Reserved.
124849 * Licensed under the Apache License, Version 2.0 (the "License");
124850 * you may not use this file except in compliance with the License.
124851 * You may obtain a copy of the License at
124852 *
124853 * http://www.apache.org/licenses/LICENSE-2.0
124854 *
124855 * Unless required by applicable law or agreed to in writing, software
124856 * distributed under the License is distributed on an "AS IS" BASIS,
124857 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124858 * See the License for the specific language governing permissions and
124859 * limitations under the License.
124860 * =============================================================================
124861 */
/**
 * WebGL kernel for ResizeBilinearGrad.
 *
 * Builds a ResizeBilinearBackpropProgram from the shapes of `dy` and `images`
 * plus the alignCorners attribute, and runs it on `dy`. The result has the
 * dtype of `dy`.
 *
 * @param {{inputs: {images: Object, dy: Object}, backend: Object,
 *     attrs: {alignCorners: boolean}}} args Kernel arguments.
 * @returns {Object} Tensor info produced by runWebGLProgram.
 */
function resizeBilinearGrad(args) {
  var dy = args.inputs.dy;
  var imageShape = args.inputs.images.shape;
  var program = new ResizeBilinearBackpropProgram(dy.shape, imageShape, args.attrs.alignCorners);
  return args.backend.runWebGLProgram(program, [dy], dy.dtype);
}
var resizeBilinearGradConfig = {
  kernelName: ResizeBilinearGrad,
  backendName: 'webgl',
  kernelFunc: resizeBilinearGrad
};
124877
124878 /**
124879 * @license
124880 * Copyright 2018 Google LLC. All Rights Reserved.
124881 * Licensed under the Apache License, Version 2.0 (the "License");
124882 * you may not use this file except in compliance with the License.
124883 * You may obtain a copy of the License at
124884 *
124885 * http://www.apache.org/licenses/LICENSE-2.0
124886 *
124887 * Unless required by applicable law or agreed to in writing, software
124888 * distributed under the License is distributed on an "AS IS" BASIS,
124889 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124890 * See the License for the specific language governing permissions and
124891 * limitations under the License.
124892 * =============================================================================
124893 */
/**
 * Unpacked WebGL program for nearest-neighbour image resizing.
 *
 * The output shape is [batch, newHeight, newWidth, depth]. Each output texel
 * is mapped to a fractional source (row, col). That position is rounded when
 * alignCorners is set and floored otherwise, clamped to the input bounds, and
 * the single texel there is read from `A`.
 *
 * @param {number[]} inputShape [batch, oldHeight, oldWidth, depth].
 * @param {number} newHeight Output height.
 * @param {number} newWidth Output width.
 * @param {boolean} alignCorners When set (and the output dimension is > 1),
 *     the scale is computed from (size - 1) on both sides.
 * @param {boolean} halfPixelCenters When set, the source index is
 *     max((y + 0.5) * scale, 0) instead of y * scale.
 */
var ResizeNearestNeighborProgram = /*#__PURE__*/_createClass(function ResizeNearestNeighborProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
  _classCallCheck(this, ResizeNearestNeighborProgram);
  this.variableNames = ['A'];
  var dims = _slicedToArray(inputShape, 4);
  var batch = dims[0];
  var oldHeight = dims[1];
  var oldWidth = dims[2];
  var depth = dims[3];
  this.outputShape = [batch, newHeight, newWidth, depth];
  // alignCorners only takes effect along a dimension whose output has > 1 pixel.
  var alignRows = alignCorners && newHeight > 1;
  var alignCols = alignCorners && newWidth > 1;
  var rowRatio = (alignRows ? oldHeight - 1 : oldHeight) / (alignRows ? newHeight - 1 : newHeight);
  var colRatio = (alignCols ? oldWidth - 1 : oldWidth) / (alignCols ? newWidth - 1 : newWidth);
  // Offset added before floor(): 0.5 rounds to nearest (alignCorners),
  // 0.0 simply floors.
  var roundBase = alignCorners ? '0.5' : '0.0';
  var sourceFracIndexRC = halfPixelCenters ?
      "max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" + ", vec2(0.0))" :
      "vec2(yRC) * effectiveInputOverOutputRatioRC";
  this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n ".concat(rowRatio, ",\n ").concat(colRatio, ");\n const vec2 inputShapeRC = vec2(").concat(oldHeight, ".0, ").concat(oldWidth, ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ").concat(roundBase, ")));\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n ");
});
124916
124917 /**
124918 * @license
124919 * Copyright 2019 Google LLC. All Rights Reserved.
124920 * Licensed under the Apache License, Version 2.0 (the "License");
124921 * you may not use this file except in compliance with the License.
124922 * You may obtain a copy of the License at
124923 *
124924 * http://www.apache.org/licenses/LICENSE-2.0
124925 *
124926 * Unless required by applicable law or agreed to in writing, software
124927 * distributed under the License is distributed on an "AS IS" BASIS,
124928 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124929 * See the License for the specific language governing permissions and
124930 * limitations under the License.
124931 * =============================================================================
124932 */
/**
 * Packed WebGL program for nearest-neighbour image resizing.
 *
 * Each invocation fills one packed 2x2 cell: output width positions z and
 * z + 1 times depth channels d and d + 1. Lanes past the last depth channel
 * or the last output column are written as 0.0.
 *
 * @param {number[]} inputShape [batch, oldHeight, oldWidth, depth].
 * @param {number} newHeight Output height.
 * @param {number} newWidth Output width.
 * @param {boolean} alignCorners When set (and the output dimension is > 1),
 *     the scale is computed from (size - 1) on both sides, and the source
 *     index is rounded rather than floored.
 * @param {boolean} halfPixelCenters When set, the source index is
 *     max((y + 0.5) * scale, 0) instead of y * scale.
 */
var ResizeNearestNeighborPackedProgram = /*#__PURE__*/_createClass(function ResizeNearestNeighborPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
  _classCallCheck(this, ResizeNearestNeighborPackedProgram);
  this.variableNames = ['A'];
  this.packedInputs = true;
  this.packedOutput = true;
  var dims = _slicedToArray(inputShape, 4);
  var batch = dims[0];
  var oldHeight = dims[1];
  var oldWidth = dims[2];
  var depth = dims[3];
  this.outputShape = [batch, newHeight, newWidth, depth];
  // alignCorners only takes effect along a dimension whose output has > 1 pixel.
  var alignRows = alignCorners && newHeight > 1;
  var alignCols = alignCorners && newWidth > 1;
  var rowRatio = (alignRows ? oldHeight - 1 : oldHeight) / (alignRows ? newHeight - 1 : newHeight);
  var colRatio = (alignCols ? oldWidth - 1 : oldWidth) / (alignCols ? newWidth - 1 : newWidth);
  // Offset added before floor(): 0.5 rounds to nearest (alignCorners),
  // 0.0 simply floors.
  var roundBase = alignCorners ? '0.5' : '0.0';
  var sourceFracIndexRC = halfPixelCenters ?
      "max((vec3(yRC) + vec3(0.5)) * " + "effectiveInputOverOutputRatioRC, vec3(0.0))" :
      "vec3(yRC) * effectiveInputOverOutputRatioRC";
  // The column ratio is emitted twice: the vec3 carries (row, col, next col).
  this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n ".concat(rowRatio, ",\n ").concat(colRatio, ",\n ").concat(colRatio, ");\n const vec3 inputShapeRC = vec3(").concat(oldHeight, ".0, ").concat(oldWidth, ".0,\n ").concat(oldWidth, ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec3 sourceNearestRC = ivec3(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ").concat(roundBase, ")));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < ").concat(depth - 1, ";\n bool hasNextRow = coords.z < ").concat(newWidth - 1, ";\n\n vec4 newValue = vec4(\n getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),\n hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);\n\n setOutput(newValue);\n }\n ");
});
124957
/**
 * WebGL kernel for ResizeNearestNeighbor.
 *
 * Uses ResizeNearestNeighborPackedProgram when the WEBGL_PACK_IMAGE_OPERATIONS
 * flag is on and ResizeNearestNeighborProgram otherwise. Unlike the bilinear
 * kernel, the result keeps the dtype of `images`.
 *
 * @param {{inputs: {images: Object}, backend: Object,
 *     attrs: {alignCorners: boolean, halfPixelCenters: boolean,
 *     size: number[]}}} args `size` holds [newHeight, newWidth].
 * @returns {Object} Tensor info produced by runWebGLProgram.
 */
function resizeNearestNeighbor(args) {
  var images = args.inputs.images;
  var attrs = args.attrs;
  var newSize = _slicedToArray(attrs.size, 2);
  var usePacked = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS');
  var Program = usePacked ? ResizeNearestNeighborPackedProgram : ResizeNearestNeighborProgram;
  var program = new Program(images.shape, newSize[0], newSize[1], attrs.alignCorners, attrs.halfPixelCenters);
  return args.backend.runWebGLProgram(program, [images], images.dtype);
}
var resizeNearestNeighborConfig = {
  kernelName: ResizeNearestNeighbor,
  backendName: 'webgl',
  kernelFunc: resizeNearestNeighbor
};
124977
124978 /**
124979 * @license
124980 * Copyright 2018 Google LLC. All Rights Reserved.
124981 * Licensed under the Apache License, Version 2.0 (the "License");
124982 * you may not use this file except in compliance with the License.
124983 * You may obtain a copy of the License at
124984 *
124985 * http://www.apache.org/licenses/LICENSE-2.0
124986 *
124987 * Unless required by applicable law or agreed to in writing, software
124988 * distributed under the License is distributed on an "AS IS" BASIS,
124989 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124990 * See the License for the specific language governing permissions and
124991 * limitations under the License.
124992 * =============================================================================
124993 */
/**
 * WebGL program computing the gradient of a nearest-neighbour resize with
 * respect to its input image.
 *
 * Each output texel is one input pixel (b, r, c, d). The shader scans a
 * window of `dy` around the positions that could have sampled that pixel,
 * recomputes which input pixel each `dy` entry was taken from (rounding under
 * alignCorners, flooring otherwise) and sums the `dy` values whose source is
 * (r, c).
 *
 * @param {number[]} dyShape Shape of the incoming gradient; entries 1 and 2
 *     are read as yHeight and yWidth.
 * @param {number[]} inputShape Shape of the forward input. It is used as the
 *     output shape; entries 1 and 2 are read as xHeight and xWidth.
 * @param {boolean} alignCorners The alignCorners flag of the forward resize.
 */
var ResizeNearestNeigborBackpropProgram = /*#__PURE__*/_createClass(function ResizeNearestNeigborBackpropProgram(dyShape, inputShape, alignCorners) {
  _classCallCheck(this, ResizeNearestNeigborBackpropProgram);
  this.variableNames = ['dy'];
  this.outputShape = inputShape;
  var _inputShape = _slicedToArray(inputShape, 3),
    xHeight = _inputShape[1],
    xWidth = _inputShape[2];
  var _dyShape = _slicedToArray(dyShape, 3),
    yHeight = _dyShape[1],
    yWidth = _dyShape[2];
  // In the backwards pass, we want to find the pixels that were generated for
  // each pixel in the input image in the forward pass and add the
  // corresponding coefficient from dy to the gradient.
  var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
  var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
  var heightScale = effectiveXSize[0] / effectiveYSize[0];
  var widthScale = effectiveXSize[1] / effectiveYSize[1];
  // A scale of 0 occurs when alignCorners is set and an input dimension has a
  // single pixel while dy has more than one: the effective input size is 0.
  // Its reciprocal would be Infinity, which would be spliced into the shader
  // as the bare token `Infinity` and fail to compile. In that case every dy
  // index maps to x index 0, so use 0 for the inverse scale (the window then
  // starts at -(dySize + 1)) and size the window as 2 * dySize + 2 so that it
  // spans all of dy.
  var invHeightScale = heightScale === 0 ? 0 : 1 / heightScale;
  var invWidthScale = widthScale === 0 ? 0 : 1 / widthScale;
  // This defines the size of the window of values around a particular
  // index in dy that we want to search for contributions to dx.
  var winHeight = heightScale === 0 ? 2 * yHeight + 2 : Math.ceil(invHeightScale) * 2 + 2;
  var winWidth = widthScale === 0 ? 2 * yWidth + 2 : Math.ceil(invWidthScale) * 2 + 2;
  this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(".concat(heightScale, ");\n const float widthScale = float(").concat(widthScale, ");\n\n const float invHeightScale = float(").concat(invHeightScale, ");\n const float invWidthScale = float(").concat(invWidthScale, ");\n\n const int winHeight = int(").concat(winHeight, ");\n const int winWidth = int(").concat(winWidth, ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= ").concat(yHeight, ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= ").concat(yWidth, ") {\n continue;\n }\n\n float sourceFracRow =\n float(").concat(effectiveXSize[0], ") *\n (float(dyR) / float(").concat(effectiveYSize[0], "));\n\n float sourceFracCol =\n float(").concat(effectiveXSize[1], ") *\n (float(dyC) / float(").concat(effectiveYSize[1], "));\n\n int sourceNearestRow = int(min(\n float(int(").concat(xHeight, ") - 1),\n ").concat(alignCorners, " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(").concat(xWidth, ") - 1),\n ").concat(alignCorners, " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ");
});
125020
125021 /**
125022 * @license
125023 * Copyright 2020 Google LLC. All Rights Reserved.
125024 * Licensed under the Apache License, Version 2.0 (the "License");
125025 * you may not use this file except in compliance with the License.
125026 * You may obtain a copy of the License at
125027 *
125028 * http://www.apache.org/licenses/LICENSE-2.0
125029 *
125030 * Unless required by applicable law or agreed to in writing, software
125031 * distributed under the License is distributed on an "AS IS" BASIS,
125032 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125033 * See the License for the specific language governing permissions and
125034 * limitations under the License.
125035 * =============================================================================
125036 */
125037 function resizeNearestNeighborGrad(args) {
125038 var inputs = args.inputs,
125039 backend = args.backend,
125040 attrs = args.attrs;
125041 var images = inputs.images,
125042 dy = inputs.dy;
125043 var alignCorners = attrs.alignCorners;
125044 var program = new ResizeNearestNeigborBackpropProgram(dy.shape, images.shape, alignCorners);
125045 return backend.runWebGLProgram(program, [dy], dy.dtype);
125046 }
125047 var resizeNearestNeighborGradConfig = {
125048 kernelName: ResizeNearestNeighborGrad,
125049 backendName: 'webgl',
125050 kernelFunc: resizeNearestNeighborGrad
125051 };
125052
125053 var ReverseProgram = /*#__PURE__*/_createClass(function ReverseProgram(xShape, axis) {
125054 _classCallCheck(this, ReverseProgram);
125055 this.variableNames = ['x'];
125056 var rank = xShape.length;
125057 if (rank > 4) {
125058 throw new Error("WebGL backend: Reverse of rank-".concat(rank, " tensor is not yet supported"));
125059 }
125060 this.outputShape = xShape;
125061 if (rank === 1) {
125062 this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(".concat(xShape[0], " - coord - 1));\n }\n ");
125063 return;
125064 }
125065 var getInCoord = function getInCoord(i) {
125066 if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
125067 return "".concat(xShape[i], " - coords[").concat(i, "] - 1");
125068 }
125069 return "coords[".concat(i, "]");
125070 };
125071 var inCoords = xShape.map(function (_, i) {
125072 return getInCoord(i);
125073 }).join(',');
125074 var type = getCoordsDataType(rank);
125075 this.userCode = "\n void main() {\n ".concat(type, " coords = getOutputCoords();\n setOutput(getX(").concat(inCoords, "));\n }\n ");
125076 });
125077
125078 var ReversePackedProgram = /*#__PURE__*/_createClass(function ReversePackedProgram(xShape, axis) {
125079 _classCallCheck(this, ReversePackedProgram);
125080 this.variableNames = ['x'];
125081 this.packedInputs = true;
125082 this.packedOutput = true;
125083 var rank = xShape.length;
125084 if (rank > 4) {
125085 throw new Error("WebGL backend: Reverse of rank-".concat(rank, " tensor is not yet supported"));
125086 }
125087 this.outputShape = xShape;
125088 var channels = getChannels('rc', rank);
125089 var nextColumn = "".concat(channels[rank - 1], " + 1 < ").concat(this.outputShape[rank - 1]);
125090 var nextRow = "".concat(channels[rank - 2], " + 1 < ").concat(this.outputShape[rank - 2]);
125091 var type = getCoordsDataType(rank);
125092 if (rank === 1) {
125093 this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(".concat(xShape[0], " - rc - 1),\n ").concat(xShape[0], " - rc - 1);\n if(").concat(nextColumn, "){\n result.g = getChannel(getX(").concat(xShape[0], " - (rc + 1) - 1),\n ").concat(xShape[0], " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n ");
125094 } else {
125095 this.userCode = "\n void main() {\n ".concat(type, " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = ").concat(getR(channels.slice()), ";\n if(").concat(nextColumn, "){\n result.g = ").concat(getG(channels.slice()), ";\n }\n if(").concat(nextRow, ") {\n result.b = ").concat(getB(channels.slice()), ";\n if(").concat(nextColumn, ") {\n result.a = ").concat(getA(channels.slice()), ";\n }\n }\n setOutput(result);\n }\n ");
125096 }
125097 function getR(channels) {
125098 return getChannel(channels);
125099 }
125100 function getG(channels) {
125101 channels[rank - 1] = '(' + channels[rank - 1] + " + 1)";
125102 return getChannel(channels);
125103 }
125104 function getB(channels) {
125105 channels[rank - 2] = '(' + channels[rank - 2] + " + 1)";
125106 return getChannel(channels);
125107 }
125108 function getA(channels) {
125109 channels[rank - 1] = '(' + channels[rank - 1] + " + 1)";
125110 channels[rank - 2] = '(' + channels[rank - 2] + " + 1)";
125111 return getChannel(channels);
125112 }
125113 function getChannel(channels) {
125114 var inCoordsArray = xShape.map(function (_, i) {
125115 return getInCoord(i, channels);
125116 });
125117 var inCoords = inCoordsArray.join(',');
125118 var innerDims = inCoordsArray.slice(-2).join(',');
125119 return "getChannel(getX(".concat(inCoords, "), vec2(").concat(innerDims, "))");
125120 }
125121 function getInCoord(i, channels1) {
125122 if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
125123 return "".concat(xShape[i], " - ").concat(channels1[i], " - 1");
125124 } else {
125125 return "".concat(channels1[i]);
125126 }
125127 }
125128 });
125129
125130 /**
125131 * @license
125132 * Copyright 2020 Google LLC. All Rights Reserved.
125133 * Licensed under the Apache License, Version 2.0 (the "License");
125134 * you may not use this file except in compliance with the License.
125135 * You may obtain a copy of the License at
125136 *
125137 * http://www.apache.org/licenses/LICENSE-2.0
125138 *
125139 * Unless required by applicable law or agreed to in writing, software
125140 * distributed under the License is distributed on an "AS IS" BASIS,
125141 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125142 * See the License for the specific language governing permissions and
125143 * limitations under the License.
125144 * =============================================================================
125145 */
125146 function reverse(args) {
125147 var inputs = args.inputs,
125148 backend = args.backend,
125149 attrs = args.attrs;
125150 var x = inputs.x;
125151 var dims = attrs.dims;
125152 var xRank = x.shape.length;
125153 var $dims = parseAxisParam(dims, x.shape);
125154 if (xRank === 0) {
125155 return identity({
125156 inputs: {
125157 x: x
125158 },
125159 backend: backend
125160 });
125161 }
125162 var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new ReversePackedProgram(x.shape, $dims) : new ReverseProgram(x.shape, $dims);
125163 return backend.runWebGLProgram(program, [x], x.dtype);
125164 }
125165 var reverseConfig = {
125166 kernelName: Reverse,
125167 backendName: 'webgl',
125168 kernelFunc: reverse
125169 };
125170
125171 /**
125172 * @license
125173 * Copyright 2020 Google LLC. All Rights Reserved.
125174 * Licensed under the Apache License, Version 2.0 (the "License");
125175 * you may not use this file except in compliance with the License.
125176 * You may obtain a copy of the License at
125177 *
125178 * http://www.apache.org/licenses/LICENSE-2.0
125179 *
125180 * Unless required by applicable law or agreed to in writing, software
125181 * distributed under the License is distributed on an "AS IS" BASIS,
125182 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125183 * See the License for the specific language governing permissions and
125184 * limitations under the License.
125185 * =============================================================================
125186 */
125187 var RotateProgram = /*#__PURE__*/_createClass(function RotateProgram(imageShape, fillValue) {
125188 _classCallCheck(this, RotateProgram);
125189 this.variableNames = ['Image'];
125190 this.outputShape = [];
125191 this.customUniforms = [{
125192 name: 'params',
125193 type: 'vec4'
125194 }];
125195 var imageHeight = imageShape[1];
125196 var imageWidth = imageShape[2];
125197 this.outputShape = imageShape;
125198 var fillSnippet = '';
125199 if (typeof fillValue === 'number') {
125200 fillSnippet = "float outputValue = ".concat(fillValue.toFixed(2), ";");
125201 } else {
125202 fillSnippet = "\n vec3 fill = vec3(".concat(fillValue.join(','), ");\n float outputValue = fill[coords[3]];");
125203 }
125204 this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n int y = coords[1];\n float coordXFloat = (float(x) - params[0]) * params[3] -\n (float(y) - params[1]) * params[2];\n float coordYFloat = (float(x) - params[0]) * params[2] +\n (float(y) - params[1]) * params[3];\n int coordX = int(round(coordXFloat + params[0]));\n int coordY = int(round(coordYFloat + params[1]));\n ".concat(fillSnippet, "\n if(coordX >= 0 && coordX < ").concat(imageWidth, " && coordY >= 0 && coordY < ").concat(imageHeight, ") {\n outputValue = getImage(coords[0], coordY, coordX, coords[3]);\n }\n setOutput(outputValue);\n }\n ");
125205 });
125206
125207 var rotateWithOffsetConfig = {
125208 kernelName: RotateWithOffset,
125209 backendName: 'webgl',
125210 kernelFunc: function kernelFunc(_ref) {
125211 var inputs = _ref.inputs,
125212 attrs = _ref.attrs,
125213 backend = _ref.backend;
125214 var image = inputs.image;
125215 var radians = attrs.radians,
125216 fillValue = attrs.fillValue,
125217 center = attrs.center;
125218 var webglBackend = backend;
125219 var program = new RotateProgram(image.shape, fillValue);
125220 var _backend_util$getImag = getImageCenter(center, image.shape[1], image.shape[2]),
125221 _backend_util$getImag2 = _slicedToArray(_backend_util$getImag, 2),
125222 centerX = _backend_util$getImag2[0],
125223 centerY = _backend_util$getImag2[1];
125224 var customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]];
125225 var output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues);
125226 return output;
125227 }
125228 };
125229
125230 /**
125231 * @license
125232 * Copyright 2020 Google LLC. All Rights Reserved.
125233 * Licensed under the Apache License, Version 2.0 (the "License");
125234 * you may not use this file except in compliance with the License.
125235 * You may obtain a copy of the License at
125236 *
125237 * http://www.apache.org/licenses/LICENSE-2.0
125238 *
125239 * Unless required by applicable law or agreed to in writing, software
125240 * distributed under the License is distributed on an "AS IS" BASIS,
125241 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125242 * See the License for the specific language governing permissions and
125243 * limitations under the License.
125244 * =============================================================================
125245 */
125246 var ROUND = "\n // OpenGL ES does not support round function.\n // The algorithm is based on banker's rounding.\n float base = floor(x);\n if ((x - base) < 0.5) {\n return floor(x);\n } else if ((x - base) > 0.5) {\n return ceil(x);\n } else {\n if (mod(base, 2.0) == 0.0) {\n return base;\n } else {\n return base + 1.0;\n }\n }\n";
125247 var round = unaryKernelFunc({
125248 opSnippet: ROUND
125249 });
125250 var roundConfig = {
125251 kernelName: Round,
125252 backendName: 'webgl',
125253 kernelFunc: round
125254 };
125255
125256 /**
125257 * @license
125258 * Copyright 2020 Google LLC. All Rights Reserved.
125259 * Licensed under the Apache License, Version 2.0 (the "License");
125260 * you may not use this file except in compliance with the License.
125261 * You may obtain a copy of the License at
125262 *
125263 * http://www.apache.org/licenses/LICENSE-2.0
125264 *
125265 * Unless required by applicable law or agreed to in writing, software
125266 * distributed under the License is distributed on an "AS IS" BASIS,
125267 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125268 * See the License for the specific language governing permissions and
125269 * limitations under the License.
125270 * =============================================================================
125271 */
125272 var RSQRT = "return inversesqrt(x);";
125273 var rsqrt = unaryKernelFunc({
125274 opSnippet: RSQRT,
125275 cpuKernelImpl: rsqrtImplCPU
125276 });
125277 var rsqrtConfig = {
125278 kernelName: Rsqrt,
125279 backendName: 'webgl',
125280 kernelFunc: rsqrt
125281 };
125282
125283 var ScatterProgram = /*#__PURE__*/_createClass(function ScatterProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape) {
125284 var summingDupeIndex = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : true;
125285 var defaultIsTensor = arguments.length > 7 && arguments[7] !== undefined ? arguments[7] : false;
125286 _classCallCheck(this, ScatterProgram);
125287 this.variableNames = ['updates', 'indices', 'defaultValue'];
125288 this.outputShape = shape;
125289 var stridesType = getCoordsDataType(strides.length);
125290 var dtype = getCoordsDataType(shape.length);
125291 var indicesString = '';
125292 if (indicesRank === 1) {
125293 indicesString = 'i';
125294 } else if (indicesRank === 2) {
125295 indicesString = 'i, j';
125296 }
125297 var indicesSnippet = "getIndices(".concat(indicesString, ")");
125298 var updatesString = '';
125299 if (updatesRank === 1) {
125300 updatesString = 'i';
125301 } else if (updatesRank === 2) {
125302 updatesString = 'i, coords[1]';
125303 }
125304 var updatesSnippet = "getUpdates(".concat(updatesString, ")");
125305 var defaultValuesString = '';
125306 if (defaultIsTensor) {
125307 defaultValuesString = 'coords[0], coords[1]';
125308 }
125309 var defaultValueSnippet = "getDefaultValue(".concat(defaultValuesString, ")");
125310 var strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
125311 this.userCode = "\n ".concat(stridesType, " strides = ").concat(stridesType, "(").concat(strides, ");\n\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < ").concat(updateSize, "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < ").concat(sliceDim, "; j++) {\n int index = round(").concat(indicesSnippet, ");\n flattenedIndex += index * ").concat(strideString, ";\n }\n if (flattenedIndex == coords[0]) {\n sum += ").concat(updatesSnippet, ";\n found = true;\n }\n }\n setOutput(mix(").concat(defaultValueSnippet, ", sum, float(found)));\n }\n ");
125312 });
125313
125314 var ScatterPackedProgram = /*#__PURE__*/_createClass(function ScatterPackedProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape) {
125315 var summingDupeIndex = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : true;
125316 var defaultIsTensor = arguments.length > 7 && arguments[7] !== undefined ? arguments[7] : false;
125317 _classCallCheck(this, ScatterPackedProgram);
125318 this.variableNames = ['updates', 'indices', 'defaultValue'];
125319 this.packedInputs = true;
125320 this.packedOutput = true;
125321 this.outputShape = shape;
125322 var stridesType = getCoordsDataType(strides.length);
125323 var dtype = getCoordsDataType(shape.length);
125324 var indicesString = '';
125325 if (indicesRank === 1) {
125326 indicesString = 'i';
125327 } else if (indicesRank === 2) {
125328 indicesString = 'i, j';
125329 }
125330 var indicesSnippet = "getIndices(".concat(indicesString, ")");
125331 var updatesString = '';
125332 if (updatesRank === 1) {
125333 updatesString = 'i';
125334 } else if (updatesRank === 2) {
125335 updatesString = 'i, coords[1]';
125336 }
125337 var updatesSnippet = "getUpdates(".concat(updatesString, ")");
125338 var defaultValuesString = '';
125339 if (defaultIsTensor) {
125340 defaultValuesString = 'coords[0], coords[1]';
125341 }
125342 var defaultValueSnippet = "getDefaultValue(".concat(defaultValuesString, ")");
125343 var strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
125344 var strideString2 = sliceDim > 1 ? 'strides[j + 1]' : 'strides';
125345 this.userCode = "\n ".concat(stridesType, " strides = ").concat(stridesType, "(").concat(strides, ");\n\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n vec4 sum = vec4(0.);\n vec4 found = vec4(0.);\n for (int i = 0; i < ").concat(updateSize, "; i+=2) {\n ivec2 flattenedIndex = ivec2(0);\n for (int j = 0; j < ").concat(sliceDim, "; j+=2) {\n ivec4 index = round(").concat(indicesSnippet, ");\n flattenedIndex += index.xz * ").concat(strideString, ";\n if (j + 1 < ").concat(sliceDim, ") {\n flattenedIndex += index.yw * ").concat(strideString2, ";\n }\n }\n if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] ||\n flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) {\n vec4 updVals = ").concat(updatesSnippet, ";\n if (flattenedIndex[0] == coords[0]) {\n sum.xy += updVals.xy;\n found.xy = vec2(1.);\n } else if (flattenedIndex[0] == coords[0] + 1) {\n sum.zw += updVals.xy;\n found.zw = vec2(1.);\n }\n if (flattenedIndex[1] == coords[0]) {\n sum.xy += updVals.zw;\n found.xy = vec2(1.);\n } else if (flattenedIndex[1] == coords[0] + 1) {\n sum.zw += updVals.zw;\n found.zw = vec2(1.);\n }\n }\n }\n setOutput(mix(").concat(defaultValueSnippet, ", sum, found));\n }\n ");
125346 });
125347
125348 /**
125349 * @license
125350 * Copyright 2020 Google LLC. All Rights Reserved.
125351 * Licensed under the Apache License, Version 2.0 (the "License");
125352 * you may not use this file except in compliance with the License.
125353 * You may obtain a copy of the License at
125354 *
125355 * http://www.apache.org/licenses/LICENSE-2.0
125356 *
125357 * Unless required by applicable law or agreed to in writing, software
125358 * distributed under the License is distributed on an "AS IS" BASIS,
125359 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125360 * See the License for the specific language governing permissions and
125361 * limitations under the License.
125362 * =============================================================================
125363 */
125364 function scatterNd(args) {
125365 var inputs = args.inputs,
125366 backend = args.backend,
125367 attrs = args.attrs;
125368 var indices = inputs.indices,
125369 updates = inputs.updates;
125370 var shape = attrs.shape;
125371 var _backend_util$calcula = calculateShapes(updates, indices, shape),
125372 sliceRank = _backend_util$calcula.sliceRank,
125373 numUpdates = _backend_util$calcula.numUpdates,
125374 sliceSize = _backend_util$calcula.sliceSize,
125375 strides = _backend_util$calcula.strides,
125376 outputSize = _backend_util$calcula.outputSize;
125377 var flattenShape = [outputSize / sliceSize, sliceSize];
125378 if (outputSize === 0) {
125379 return backend.makeTensorInfo(shape, indices.dtype);
125380 }
125381 var flattenIndices = reshape({
125382 inputs: {
125383 x: indices
125384 },
125385 backend: backend,
125386 attrs: {
125387 shape: [numUpdates, sliceRank]
125388 }
125389 });
125390 var flattenX = reshape({
125391 inputs: {
125392 x: updates
125393 },
125394 backend: backend,
125395 attrs: {
125396 shape: [numUpdates, sliceSize]
125397 }
125398 });
125399 var defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0])); // scalar(0)
125400 var program;
125401 if (env().getBool('WEBGL_PACK')) {
125402 program = new ScatterPackedProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
125403 } else {
125404 program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
125405 }
125406 var res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype);
125407 var reshaped = reshape({
125408 inputs: {
125409 x: res
125410 },
125411 backend: backend,
125412 attrs: {
125413 shape: shape
125414 }
125415 });
125416 backend.disposeIntermediateTensorInfo(flattenIndices);
125417 backend.disposeIntermediateTensorInfo(flattenX);
125418 backend.disposeIntermediateTensorInfo(res);
125419 backend.disposeIntermediateTensorInfo(defaultValue);
125420 return reshaped;
125421 }
125422 var scatterNdConfig = {
125423 kernelName: ScatterNd,
125424 backendName: 'webgl',
125425 kernelFunc: scatterNd
125426 };
125427
125428 var SearchSortedProgram = /*#__PURE__*/_createClass(function SearchSortedProgram(batchSize, numInputs, numValues, side) {
125429 _classCallCheck(this, SearchSortedProgram);
125430 this.variableNames = ['sortedSequence', 'values'];
125431 this.customUniforms = [{
125432 name: 'numInputs',
125433 type: 'int'
125434 }];
125435 this.outputShape = [batchSize, numValues];
125436 var webGL2LoopHead = 'while (left < right) {';
125437 // WebGL1 doesn't accept non constant loop conditions, so upper bound loop
125438 // iterations.
125439 var webGL1LoopHead = "for (int i = 0; i < ".concat(Math.ceil(Math.log2(numInputs + 1)), "; ++i) { if (left >= right) break;");
125440 var loopHead = env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead : webGL1LoopHead;
125441 // left corresponds to lower bound and right to upper bound.
125442 var boundComparator = side === 'left' ? '<' : '<=';
125443 this.userCode = "\n int findBound(int batch, float value) {\n int left = 0;\n int right = numInputs;\n int mid;\n ".concat(loopHead, "\n mid = (left + right) / 2;\n if (getSortedSequence(batch, mid) ").concat(boundComparator, " value) {\n left = mid + 1;\n } else {\n right = mid;\n }\n }\n return right;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int valueIndex = coords[1];\n\n float value = getValues(batch, valueIndex);\n\n setOutput(float(findBound(batch, value)));\n }\n ");
125444 });
125445
125446 /**
125447 * @license
125448 * Copyright 2022 Google LLC. All Rights Reserved.
125449 * Licensed under the Apache License, Version 2.0 (the "License");
125450 * you may not use this file except in compliance with the License.
125451 * You may obtain a copy of the License at
125452 *
125453 * http://www.apache.org/licenses/LICENSE-2.0
125454 *
125455 * Unless required by applicable law or agreed to in writing, software
125456 * distributed under the License is distributed on an "AS IS" BASIS,
125457 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125458 * See the License for the specific language governing permissions and
125459 * limitations under the License.
125460 * =============================================================================
125461 */
125462 function searchSorted(args) {
125463 var inputs = args.inputs,
125464 backend = args.backend,
125465 attrs = args.attrs;
125466 var sortedSequence = inputs.sortedSequence,
125467 values = inputs.values;
125468 var side = attrs.side;
125469 var program = new SearchSortedProgram(sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
125470 var customValues = [[sortedSequence.shape[1]]];
125471 return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', customValues);
125472 }
125473 var searchSortedConfig = {
125474 kernelName: SearchSorted,
125475 backendName: 'webgl',
125476 kernelFunc: searchSorted
125477 };
125478
125479 var SelectProgram = /*#__PURE__*/_createClass(function SelectProgram(cRank, shape, rank) {
125480 _classCallCheck(this, SelectProgram);
125481 this.variableNames = ['c', 'a', 'b'];
125482 this.outputShape = shape;
125483 var cCoords;
125484 var abCoords;
125485 if (rank > 4) {
125486 throw Error("Where for rank ".concat(rank, " is not yet supported"));
125487 }
125488 if (rank === 1) {
125489 abCoords = "resRC";
125490 cCoords = "resRC";
125491 } else {
125492 var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
125493 var cCoordVars = [];
125494 var abCoordVars = [];
125495 for (var i = 0; i < shape.length; i++) {
125496 abCoordVars.push("".concat(currentCoords[i]));
125497 if (i < cRank) {
125498 cCoordVars.push("".concat(currentCoords[i]));
125499 }
125500 }
125501 cCoords = cCoordVars.join();
125502 abCoords = abCoordVars.join();
125503 }
125504 var dtype = getCoordsDataType(rank);
125505 this.userCode = "\n void main() {\n ".concat(dtype, " resRC = getOutputCoords();\n float cVal = getC(").concat(cCoords, ");\n if (cVal >= 1.0) {\n setOutput(getA(").concat(abCoords, "));\n } else {\n setOutput(getB(").concat(abCoords, "));\n }\n }\n ");
125506 });
125507
125508 /**
125509 * @license
125510 * Copyright 2020 Google LLC. All Rights Reserved.
125511 * Licensed under the Apache License, Version 2.0 (the "License");
125512 * you may not use this file except in compliance with the License.
125513 * You may obtain a copy of the License at
125514 *
125515 * http://www.apache.org/licenses/LICENSE-2.0
125516 *
125517 * Unless required by applicable law or agreed to in writing, software
125518 * distributed under the License is distributed on an "AS IS" BASIS,
125519 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125520 * See the License for the specific language governing permissions and
125521 * limitations under the License.
125522 * =============================================================================
125523 */
125524 function select(args) {
125525 var inputs = args.inputs,
125526 backend = args.backend;
125527 var condition = inputs.condition,
125528 t = inputs.t,
125529 e = inputs.e;
125530 var program = new SelectProgram(condition.shape.length, t.shape, t.shape.length);
125531 return backend.runWebGLProgram(program, [condition, t, e], upcastType(t.dtype, e.dtype));
125532 }
125533 var selectConfig = {
125534 kernelName: Select,
125535 backendName: 'webgl',
125536 kernelFunc: select
125537 };
125538
125539 /**
125540 * @license
125541 * Copyright 2020 Google LLC. All Rights Reserved.
125542 * Licensed under the Apache License, Version 2.0 (the "License");
125543 * you may not use this file except in compliance with the License.
125544 * You may obtain a copy of the License at
125545 *
125546 * http://www.apache.org/licenses/LICENSE-2.0
125547 *
125548 * Unless required by applicable law or agreed to in writing, software
125549 * distributed under the License is distributed on an "AS IS" BASIS,
125550 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125551 * See the License for the specific language governing permissions and
125552 * limitations under the License.
125553 * =============================================================================
125554 */
125555 var SELU = "\n // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.\n // see: https://arxiv.org/abs/1706.02515\n float scaleAlpha = ".concat(SELU_SCALEALPHA, ";\n float scale = ").concat(SELU_SCALE, ";\n return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);\n");
125556 var selu = unaryKernelFunc({
125557 opSnippet: SELU
125558 });
125559 var seluConfig = {
125560 kernelName: Selu$1,
125561 backendName: 'webgl',
125562 kernelFunc: selu
125563 };
125564
125565 /**
125566 * @license
125567 * Copyright 2020 Google LLC. All Rights Reserved.
125568 * Licensed under the Apache License, Version 2.0 (the "License");
125569 * you may not use this file except in compliance with the License.
125570 * You may obtain a copy of the License at
125571 *
125572 * http://www.apache.org/licenses/LICENSE-2.0
125573 *
125574 * Unless required by applicable law or agreed to in writing, software
125575 * distributed under the License is distributed on an "AS IS" BASIS,
125576 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125577 * See the License for the specific language governing permissions and
125578 * limitations under the License.
125579 * =============================================================================
125580 */
125581 var SIGMOID = CHECK_NAN_SNIPPET_UNARY + "\n return 1.0 / (1.0 + exp(-1.0 * x));\n";
125582 var SIGMOID_PACKED = "\n vec4 result = 1.0 / (1.0 + exp(-1.0 * x));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
125583 var sigmoid = unaryKernelFunc({
125584 opSnippet: SIGMOID,
125585 packedOpSnippet: SIGMOID_PACKED,
125586 cpuKernelImpl: sigmoidImplCPU
125587 });
125588 var sigmoidConfig = {
125589 kernelName: Sigmoid$1,
125590 backendName: 'webgl',
125591 kernelFunc: sigmoid
125592 };
125593
125594 /**
125595 * @license
125596 * Copyright 2020 Google LLC. All Rights Reserved.
125597 * Licensed under the Apache License, Version 2.0 (the "License");
125598 * you may not use this file except in compliance with the License.
125599 * You may obtain a copy of the License at
125600 *
125601 * http://www.apache.org/licenses/LICENSE-2.0
125602 *
125603 * Unless required by applicable law or agreed to in writing, software
125604 * distributed under the License is distributed on an "AS IS" BASIS,
125605 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125606 * See the License for the specific language governing permissions and
125607 * limitations under the License.
125608 * =============================================================================
125609 */
125610 // Sign does not propagate NANs.
125611 var SIGN = "\n if (isnan(x)) { return 0.0; }\n return sign(x);\n";
125612 var sign = unaryKernelFunc({
125613 opSnippet: SIGN
125614 });
125615 var signConfig = {
125616 kernelName: Sign,
125617 backendName: 'webgl',
125618 kernelFunc: sign
125619 };
125620
125621 /**
125622 * @license
125623 * Copyright 2020 Google LLC. All Rights Reserved.
125624 * Licensed under the Apache License, Version 2.0 (the "License");
125625 * you may not use this file except in compliance with the License.
125626 * You may obtain a copy of the License at
125627 *
125628 * http://www.apache.org/licenses/LICENSE-2.0
125629 *
125630 * Unless required by applicable law or agreed to in writing, software
125631 * distributed under the License is distributed on an "AS IS" BASIS,
125632 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125633 * See the License for the specific language governing permissions and
125634 * limitations under the License.
125635 * =============================================================================
125636 */
125637 var SIN = CHECK_NAN_SNIPPET_UNARY + "\n return sin(x);\n";
125638 var SIN_PACKED = "\n vec4 result = sin(x);\n bvec4 isNaN = isnan(x);\n ".concat(CHECK_NAN_SNIPPET_PACKED, "\n return result;\n");
125639 var sin = unaryKernelFunc({
125640 opSnippet: SIN,
125641 packedOpSnippet: SIN_PACKED
125642 });
125643 var sinConfig = {
125644 kernelName: Sin,
125645 backendName: 'webgl',
125646 kernelFunc: sin
125647 };
125648
125649 /**
125650 * @license
125651 * Copyright 2020 Google LLC. All Rights Reserved.
125652 * Licensed under the Apache License, Version 2.0 (the "License");
125653 * you may not use this file except in compliance with the License.
125654 * You may obtain a copy of the License at
125655 *
125656 * http://www.apache.org/licenses/LICENSE-2.0
125657 *
125658 * Unless required by applicable law or agreed to in writing, software
125659 * distributed under the License is distributed on an "AS IS" BASIS,
125660 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125661 * See the License for the specific language governing permissions and
125662 * limitations under the License.
125663 * =============================================================================
125664 */
// Hyperbolic sine, sinh(x) = (e^x - e^-x) / 2. Despite its name, `e2x`
// holds exp(x); `1.0 / e2x` supplies the e^-x term.
var SINH = [
  '',
  ' float e2x = exp(x);',
  ' return (e2x - 1.0 / e2x) / 2.0;',
  ''
].join('\n');
var sinh = unaryKernelFunc({ opSnippet: SINH });
var sinhConfig = { kernelName: Sinh, backendName: 'webgl', kernelFunc: sinh };
125674
125675 /**
125676 * @license
125677 * Copyright 2020 Google LLC. All Rights Reserved.
125678 * Licensed under the Apache License, Version 2.0 (the "License");
125679 * you may not use this file except in compliance with the License.
125680 * You may obtain a copy of the License at
125681 *
125682 * http://www.apache.org/licenses/LICENSE-2.0
125683 *
125684 * Unless required by applicable law or agreed to in writing, software
125685 * distributed under the License is distributed on an "AS IS" BASIS,
125686 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125687 * See the License for the specific language governing permissions and
125688 * limitations under the License.
125689 * =============================================================================
125690 */
// Softplus, log(1 + exp(x)), with two cut-offs derived from
// epsilon = 2^-23 (1.1920928955078125e-7):
//   threshold = log(epsilon) + 2.0 (roughly -13.94);
//   x > -threshold  -> x           (log(1 + e^x) ~ x)
//   x <  threshold  -> exp(x)      (log(1 + e^x) ~ e^x)
//   otherwise       -> log(exp(x) + 1.0)
var SOFTPLUS = [
  '',
  ' float epsilon = 1.1920928955078125e-7;',
  ' float threshold = log(epsilon) + 2.0;',
  '',
  ' bool too_large = x > -threshold;',
  ' bool too_small = x < threshold;',
  '',
  ' float result;',
  ' float exp_x = exp(x);',
  '',
  ' if (too_large){',
  ' result = x;',
  ' }',
  ' else if (too_small){',
  ' result = exp_x;',
  ' }',
  ' else{',
  ' result = log(exp_x + 1.0);',
  ' }',
  ' return result;',
  ''
].join('\n');
var softplus = unaryKernelFunc({ opSnippet: SOFTPLUS });
var softplusConfig = { kernelName: Softplus$1, backendName: 'webgl', kernelFunc: softplus };
125700
/**
 * SpaceToBatchND kernel for the WebGL backend (inputs of rank <= 4 only).
 *
 * Built from existing kernels: padV2 -> reshape -> transpose -> reshape.
 * The intermediate shapes and the permutation come from getReshaped,
 * getPermuted and getReshapedPermuted. The three intermediate tensor infos
 * are disposed before the final result is returned.
 */
var spaceToBatchND = function spaceToBatchND(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var blockShape = args.attrs.blockShape;
  var paddings = args.attrs.paddings;
  assert$1(x.shape.length <= 4, function () {
    return 'spaceToBatchND for rank > 4 with a WebGL backend not ' + 'implemented yet';
  });
  // reduce() without an initial value: an empty blockShape throws TypeError.
  var prod = blockShape.reduce(function (acc, dim) {
    return acc * dim;
  });

  // Zero padding for the batch axis and every trailing (non-block) axis;
  // the block axes use the caller-supplied paddings.
  var completePaddings = [[0, 0]].concat(_toConsumableArray(paddings));
  for (var axis = blockShape.length + 1; axis < x.shape.length; axis++) {
    completePaddings.push([0, 0]);
  }

  var paddedX = padV2({
    inputs: { x: x },
    backend: backend,
    attrs: { paddings: completePaddings, constantValue: 0 }
  });
  var reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
  var permutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
  var flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);

  var reshapedPaddedX = reshape({
    inputs: { x: paddedX },
    backend: backend,
    attrs: { shape: reshapedPaddedShape }
  });
  var paddedXT = transpose({
    inputs: { x: reshapedPaddedX },
    backend: backend,
    attrs: { perm: permutation }
  });
  var result = reshape({
    inputs: { x: paddedXT },
    backend: backend,
    attrs: { shape: flattenShape }
  });

  [paddedX, reshapedPaddedX, paddedXT].forEach(function (intermediate) {
    return backend.disposeIntermediateTensorInfo(intermediate);
  });
  return result;
};
var spaceToBatchNDConfig = {
  kernelName: SpaceToBatchND,
  backendName: 'webgl',
  kernelFunc: spaceToBatchND
};
125773
/**
 * SparseFillEmptyRows kernel for the WebGL backend.
 *
 * Validates the input ranks, reads every input back from the backend and
 * delegates to sparseFillEmptyRowsImplCPU. Returns four tensor infos:
 * [outputIndices, outputValues, emptyRowIndicator ('bool'), reverseIndexMap].
 */
function sparseFillEmptyRows(args) {
  var backend = args.backend;
  var indices = args.inputs.indices;
  var values = args.inputs.values;
  var denseShape = args.inputs.denseShape;
  var defaultValue = args.inputs.defaultValue;

  if (denseShape.shape.length !== 1) {
    throw new Error("Dense shape must be a vector, saw:\n " + denseShape.shape);
  }
  if (indices.shape.length !== 2) {
    throw new Error("Indices must be a matrix, saw:\n " + indices.shape);
  }
  if (values.shape.length !== 1) {
    throw new Error("Values must be a vector, saw:\n " + values.shape);
  }
  if (defaultValue.shape.length !== 0) {
    throw new Error("Default value must be a scalar, saw:\n " + defaultValue.shape);
  }

  var indicesData = backend.readSync(indices.dataId);
  var valuesData = backend.readSync(values.dataId);
  var denseShapeData = backend.readSync(denseShape.dataId);
  var fillValue = backend.readSync(defaultValue.dataId)[0];

  var parts = _slicedToArray(sparseFillEmptyRowsImplCPU(indicesData, indices.shape, indices.dtype, valuesData, values.dtype, denseShapeData, fillValue), 5);
  var outputIndices = parts[0];
  var outputIndicesShape = parts[1];
  var outputValues = parts[2];
  var emptyRowIndicator = parts[3];
  var reverseIndexMap = parts[4];

  var indicesInfo = backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices);
  var valuesInfo = backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues);
  // The indicator entries are converted with Number() into a Uint8Array.
  var emptyRowInfo = backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map(function (flag) {
    return Number(flag);
  })));
  var reverseMapInfo = backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap));
  return [indicesInfo, valuesInfo, emptyRowInfo, reverseMapInfo];
}
var sparseFillEmptyRowsConfig = {
  kernelName: SparseFillEmptyRows,
  backendName: 'webgl',
  kernelFunc: sparseFillEmptyRows
};
125813
/**
 * SparseReshape kernel for the WebGL backend.
 *
 * Requires inputIndices to be a matrix and inputShape/newShape to be
 * vectors, then runs sparseReshapeImplCPU on the values read back from the
 * backend. Returns [newIndices, outputShape]; outputShape is stored in an
 * Int32Array with newShape's dtype.
 */
function sparseReshape(args) {
  var backend = args.backend;
  var inputIndices = args.inputs.inputIndices;
  var inputShape = args.inputs.inputShape;
  var newShape = args.inputs.newShape;

  var expectRank = function (info, rank, description) {
    if (info.shape.length !== rank) {
      throw new Error(description + " but received shape " + info.shape);
    }
  };
  expectRank(inputIndices, 2, "Input indices should be a matrix");
  expectRank(inputShape, 1, "Input shape should be a vector");
  expectRank(newShape, 1, "Target shape should be a vector");

  var currentShape = Array.from(backend.readSync(inputShape.dataId));
  var indicesData = backend.readSync(inputIndices.dataId);
  var targetShape = Array.from(backend.readSync(newShape.dataId));

  var reshaped = _slicedToArray(sparseReshapeImplCPU(indicesData, inputIndices.shape, inputIndices.dtype, currentShape, targetShape), 3);
  var newIndices = reshaped[0];
  var newIndicesShape = reshaped[1];
  var resolvedShape = reshaped[2];
  return [
    backend.makeTensorInfo(newIndicesShape, inputIndices.dtype, newIndices),
    backend.makeTensorInfo([resolvedShape.length], newShape.dtype, new Int32Array(resolvedShape))
  ];
}
var sparseReshapeConfig = {
  kernelName: SparseReshape,
  backendName: 'webgl',
  kernelFunc: sparseReshape
};
125844
/**
 * SparseSegmentMean kernel for the WebGL backend.
 *
 * Validates ranks, reads data/indices/segmentIds back from the backend and
 * calls sparseSegmentReductionImplCPU with a trailing `true` argument (the
 * Sum kernel omits it; presumably it selects the mean reduction).
 */
function sparseSegmentMean(args) {
  var backend = args.backend;
  var data = args.inputs.data;
  var indices = args.inputs.indices;
  var segmentIds = args.inputs.segmentIds;

  if (data.shape.length < 1) {
    throw new Error("Data should be at least 1 dimensional but received scalar");
  }
  if (indices.shape.length !== 1) {
    throw new Error("Indices should be a vector but received shape\n " + indices.shape);
  }
  if (segmentIds.shape.length !== 1) {
    throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
  }

  var dataValues = backend.readSync(data.dataId);
  var indexValues = backend.readSync(indices.dataId);
  var segmentValues = backend.readSync(segmentIds.dataId);
  var reduced = _slicedToArray(sparseSegmentReductionImplCPU(dataValues, data.shape, data.dtype, indexValues, segmentValues, true), 2);
  return backend.makeTensorInfo(reduced[1], data.dtype, reduced[0]);
}
var sparseSegmentMeanConfig = {
  kernelName: SparseSegmentMean,
  backendName: 'webgl',
  kernelFunc: sparseSegmentMean
};
125874
/**
 * SparseSegmentSum kernel for the WebGL backend.
 *
 * Validates ranks, reads data/indices/segmentIds back from the backend and
 * calls sparseSegmentReductionImplCPU with five arguments (no mean flag).
 */
function sparseSegmentSum(args) {
  var backend = args.backend;
  var data = args.inputs.data;
  var indices = args.inputs.indices;
  var segmentIds = args.inputs.segmentIds;

  if (data.shape.length < 1) {
    throw new Error("Data should be at least 1 dimensional but received scalar");
  }
  if (indices.shape.length !== 1) {
    throw new Error("Indices should be a vector but received shape\n " + indices.shape);
  }
  if (segmentIds.shape.length !== 1) {
    throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
  }

  var dataValues = backend.readSync(data.dataId);
  var indexValues = backend.readSync(indices.dataId);
  var segmentValues = backend.readSync(segmentIds.dataId);
  var reduced = _slicedToArray(sparseSegmentReductionImplCPU(dataValues, data.shape, data.dtype, indexValues, segmentValues), 2);
  return backend.makeTensorInfo(reduced[1], data.dtype, reduced[0]);
}
var sparseSegmentSumConfig = {
  kernelName: SparseSegmentSum,
  backendName: 'webgl',
  kernelFunc: sparseSegmentSum
};
125904
125905 /**
125906 * @license
125907 * Copyright 2020 Google LLC. All Rights Reserved.
125908 * Licensed under the Apache License, Version 2.0 (the "License");
125909 * you may not use this file except in compliance with the License.
125910 * You may obtain a copy of the License at
125911 *
125912 * http://www.apache.org/licenses/LICENSE-2.0
125913 *
125914 * Unless required by applicable law or agreed to in writing, software
125915 * distributed under the License is distributed on an "AS IS" BASIS,
125916 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125917 * See the License for the specific language governing permissions and
125918 * limitations under the License.
125919 * =============================================================================
125920 */
/**
 * SparseToDense kernel for the WebGL backend.
 *
 * String values are scattered on the CPU with scatterImplCPU (the default
 * value is decoded with decodeString). Every other dtype runs a
 * ScatterProgram into a flat [outputSize, 1] target, which is then reshaped
 * to outputShape. sumDupeIndices is false on both paths.
 */
function sparseToDense(args) {
  var backend = args.backend;
  var sparseIndices = args.inputs.sparseIndices;
  var sparseValues = args.inputs.sparseValues;
  var defaultValue = args.inputs.defaultValue;
  var outputShape = args.attrs.outputShape;
  var shapes = calculateShapes(sparseValues, sparseIndices, outputShape);
  var sliceRank = shapes.sliceRank;
  var numUpdates = shapes.numUpdates;
  var sliceSize = shapes.sliceSize;
  var strides = shapes.strides;
  var outputSize = shapes.outputSize;
  var sumDupeIndices = false;

  if (sparseValues.dtype !== 'string') {
    var program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices);
    var flat = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype);
    var dense = reshape({
      inputs: { x: flat },
      backend: backend,
      attrs: { shape: outputShape }
    });
    backend.disposeIntermediateTensorInfo(flat);
    return dense;
  }

  var indicesBuf = backend.bufferSync(sparseIndices);
  var updatesBuf = backend.bufferSync(sparseValues);
  var fill = decodeString(backend.readSync(defaultValue.dataId)[0]);
  var outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, fill, sumDupeIndices);
  return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
}
var sparseToDenseConfig = {
  kernelName: SparseToDense,
  backendName: 'webgl',
  kernelFunc: sparseToDense
};
125962
/**
 * SplitV kernel for the WebGL backend.
 *
 * Slices `x` along the resolved axis into consecutive pieces whose sizes come
 * from prepareSplitSize, and returns one tensor info per piece.
 */
function splitV(args) {
  var x = args.inputs.x;
  var backend = args.backend;
  var numOrSizeSplits = args.attrs.numOrSizeSplits;
  var axis = args.attrs.axis;
  var splitAxis = parseAxisParam(axis, x.shape)[0];
  var splitSizes = prepareSplitSize(x, numOrSizeSplits, splitAxis);

  // Running start offset: only the entry on splitAxis advances.
  var begin = new Array(x.shape.length).fill(0);
  var fullSize = x.shape.slice();
  var pieces = [];
  for (var k = 0; k < splitSizes.length; k++) {
    var pieceSize = fullSize.slice();
    pieceSize[splitAxis] = splitSizes[k];
    pieces.push(slice({
      inputs: { x: x },
      backend: backend,
      attrs: { begin: begin, size: pieceSize }
    }));
    begin[splitAxis] += splitSizes[k];
  }
  return pieces;
}
var splitVConfig = {
  kernelName: SplitV,
  backendName: 'webgl',
  kernelFunc: splitV
};
125997
125998 /**
125999 * @license
126000 * Copyright 2020 Google LLC. All Rights Reserved.
126001 * Licensed under the Apache License, Version 2.0 (the "License");
126002 * you may not use this file except in compliance with the License.
126003 * You may obtain a copy of the License at
126004 *
126005 * http://www.apache.org/licenses/LICENSE-2.0
126006 *
126007 * Unless required by applicable law or agreed to in writing, software
126008 * distributed under the License is distributed on an "AS IS" BASIS,
126009 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126010 * See the License for the specific language governing permissions and
126011 * limitations under the License.
126012 * =============================================================================
126013 */
// Square root. The same GLSL body is used for packed and unpacked inputs;
// sqrtImplCPU is passed as the cpuKernelImpl.
var SQRT = 'return sqrt(x);';
var sqrt = unaryKernelFunc({
  cpuKernelImpl: sqrtImplCPU,
  opSnippet: SQRT,
  packedOpSnippet: SQRT
});
var sqrtConfig = { kernelName: Sqrt, backendName: 'webgl', kernelFunc: sqrt };
126025
126026 /**
126027 * @license
126028 * Copyright 2019 Google LLC. All Rights Reserved.
126029 * Licensed under the Apache License, Version 2.0 (the "License");
126030 * you may not use this file except in compliance with the License.
126031 * You may obtain a copy of the License at
126032 *
126033 * http://www.apache.org/licenses/LICENSE-2.0
126034 *
126035 * Unless required by applicable law or agreed to in writing, software
126036 * distributed under the License is distributed on an "AS IS" BASIS,
126037 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126038 * See the License for the specific language governing permissions and
126039 * limitations under the License.
126040 * =============================================================================
126041 */
// Element-wise square (unpacked snippet only).
var SQUARE = 'return x * x;';
var square = unaryKernelFunc({ opSnippet: SQUARE });
var squareConfig = { kernelName: Square, backendName: 'webgl', kernelFunc: square };
126051
126052 /**
126053 * @license
126054 * Copyright 2020 Google LLC. All Rights Reserved.
126055 * Licensed under the Apache License, Version 2.0 (the "License");
126056 * you may not use this file except in compliance with the License.
126057 * You may obtain a copy of the License at
126058 *
126059 * http://www.apache.org/licenses/LICENSE-2.0
126060 *
126061 * Unless required by applicable law or agreed to in writing, software
126062 * distributed under the License is distributed on an "AS IS" BASIS,
126063 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126064 * See the License for the specific language governing permissions and
126065 * limitations under the License.
126066 * =============================================================================
126067 */
// (a - b)^2, element-wise; one GLSL body serves both packed and unpacked inputs.
var SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
var squaredDifference = binaryKernelFunc({
  packedOpSnippet: SQUARED_DIFFERENCE,
  opSnippet: SQUARED_DIFFERENCE
});
var squaredDifferenceConfig = { kernelName: SquaredDifference, backendName: 'webgl', kernelFunc: squaredDifference };
126078
126079 /**
126080 * @license
126081 * Copyright 2023 Google LLC.
126082 * Licensed under the Apache License, Version 2.0 (the "License");
126083 * you may not use this file except in compliance with the License.
126084 * You may obtain a copy of the License at
126085 *
126086 * http://www.apache.org/licenses/LICENSE-2.0
126087 *
126088 * Unless required by applicable law or agreed to in writing, software
126089 * distributed under the License is distributed on an "AS IS" BASIS,
126090 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126091 * See the License for the specific language governing permissions and
126092 * limitations under the License.
126093 * =============================================================================
126094 */
/**
 * StaticRegexReplace kernel for the WebGL backend.
 *
 * Only accepts string tensors. The stored bytes are decoded with
 * fromUint8ToStringArray and passed, together with the kernel attrs, to
 * staticRegexReplaceImplCPU. The output keeps x's shape.
 */
function staticRegexReplace(args) {
  var x = args.inputs.x;
  var isString = x.dtype === 'string';
  if (!isString) {
    throw new Error('Input must be of datatype string');
  }
  var backend = args.backend;
  var decoded = fromUint8ToStringArray(backend.readSync(x.dataId));
  var replaced = staticRegexReplaceImplCPU(decoded, 'string', args.attrs);
  return backend.makeTensorInfo(x.shape, 'string', replaced);
}
var staticRegexReplaceConfig = {
  kernelName: StaticRegexReplace,
  backendName: 'webgl',
  kernelFunc: staticRegexReplace
};
126113
126114 /**
126115 * @license
126116 * Copyright 2020 Google LLC. All Rights Reserved.
126117 * Licensed under the Apache License, Version 2.0 (the "License");
126118 * you may not use this file except in compliance with the License.
126119 * You may obtain a copy of the License at
126120 *
126121 * http://www.apache.org/licenses/LICENSE-2.0
126122 *
126123 * Unless required by applicable law or agreed to in writing, software
126124 * distributed under the License is distributed on an "AS IS" BASIS,
126125 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126126 * See the License for the specific language governing permissions and
126127 * limitations under the License.
126128 * =============================================================================
126129 */
/**
 * Step kernel for the WebGL backend.
 *
 * After CHECK_NAN_SNIPPET$1, the shader returns 1.0 where x > 0.0 and
 * float(alpha) otherwise. attrs.alpha is spliced into the shader source as
 * text.
 */
function step(args) {
  var x = args.inputs.x;
  var alpha = args.attrs.alpha;
  var opSnippet = CHECK_NAN_SNIPPET$1 + "\n return x > 0.0 ? 1.0 : float(" + alpha + ");\n ";
  var program = new UnaryOpProgram(x.shape, opSnippet);
  return args.backend.runWebGLProgram(program, [x], x.dtype);
}
var stepConfig = {
  kernelName: Step,
  backendName: 'webgl',
  kernelFunc: step
};
126144
/**
 * GPU program that gathers a strided slice of input `x`.
 *
 * For every output coordinate `c`, the shader reads the input at
 * `c * strides + begin` (per axis).
 *
 * @param {number[]} begin   per-axis start offsets into `x`
 * @param {number[]} strides per-axis strides
 * @param {number[]} size    output shape
 */
var StridedSliceProgram = /*#__PURE__*/_createClass(function StridedSliceProgram(begin, strides, size) {
  _classCallCheck(this, StridedSliceProgram);
  this.variableNames = ['x'];
  this.outputShape = size;
  var rank = size.length;
  // begin, strides and the output coordinates share one GLSL coordinate type.
  var dtype = getCoordsDataType(rank);
  var newCoords;
  if (rank === 1) {
    // Scalar coordinate: the whole expression is element-wise on one int.
    newCoords = 'coords * strides + begin';
  } else {
    newCoords = size.map(function (_, i) {
      return "coords[" + i + "] * strides[" + i + "] + begin[" + i + "]";
    }).join(',');
  }
  this.userCode = "\n " + dtype + " begin = " + dtype + "(" + begin + ");\n " + dtype + " strides = " + dtype + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n setOutput(getX(" + newCoords + "));\n }\n ";
});
126164
126165 /**
126166 * @license
126167 * Copyright 2020 Google LLC. All Rights Reserved.
126168 * Licensed under the Apache License, Version 2.0 (the "License");
126169 * you may not use this file except in compliance with the License.
126170 * You may obtain a copy of the License at
126171 *
126172 * http://www.apache.org/licenses/LICENSE-2.0
126173 *
126174 * Unless required by applicable law or agreed to in writing, software
126175 * distributed under the License is distributed on an "AS IS" BASIS,
126176 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126177 * See the License for the specific language governing permissions and
126178 * limitations under the License.
126179 * =============================================================================
126180 */
/**
 * StridedSlice kernel for the WebGL backend.
 *
 * sliceInfo resolves the masks into concrete begin/end/strides. The kernel
 * then picks the cheapest strategy:
 *  1. identity slice: a single reshape of `x` to finalShape;
 *  2. contiguous slice (sliceDim0 || isSimpleSlice): `slice` followed by one
 *     reshape to finalShape;
 *  3. general case: stridedSliceImplCPU when the backend prefers the CPU
 *     for `x`, otherwise a StridedSliceProgram producing finalShapeSparse,
 *     followed by a reshape to finalShape.
 * Every intermediate tensor info is disposed before returning.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: Object}} args
 * @returns {Object} tensor info with shape finalShape
 */
function stridedSlice(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var x = inputs.x;
  var begin = attrs.begin,
    end = attrs.end,
    strides = attrs.strides,
    beginMask = attrs.beginMask,
    endMask = attrs.endMask,
    ellipsisMask = attrs.ellipsisMask,
    newAxisMask = attrs.newAxisMask,
    shrinkAxisMask = attrs.shrinkAxisMask;
  var _slice_util$sliceInfo = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask),
    finalShapeSparse = _slice_util$sliceInfo.finalShapeSparse,
    finalShape = _slice_util$sliceInfo.finalShape,
    isIdentity = _slice_util$sliceInfo.isIdentity,
    sliceDim0 = _slice_util$sliceInfo.sliceDim0,
    isSimpleSlice = _slice_util$sliceInfo.isSimpleSlice,
    $begin = _slice_util$sliceInfo.begin,
    $end = _slice_util$sliceInfo.end,
    $strides = _slice_util$sliceInfo.strides;

  if (isIdentity) {
    // Optimization #1: the slice is a no-op; only the shape changes.
    // A single reshape already has finalShape.
    return reshape({
      inputs: { x: x },
      backend: backend,
      attrs: { shape: finalShape }
    });
  }

  if (sliceDim0 || isSimpleSlice) {
    // Optimization #2: the slice is memory contiguous, so a plain slice
    // suffices.
    assert$1(x.shape.length >= 1, function () {
      return "Input must have rank at least 1, got: ".concat(x.shape.length);
    });
    var size = computeOutShape$2($begin, $end, $strides);
    var sliced = slice({
      inputs: { x: x },
      backend: backend,
      attrs: { begin: $begin, size: size }
    });
    var reshapedSlice = reshape({
      inputs: { x: sliced },
      backend: backend,
      attrs: { shape: finalShape }
    });
    backend.disposeIntermediateTensorInfo(sliced);
    return reshapedSlice;
  }

  var result;
  if (backend.shouldExecuteOnCPU([x])) {
    // tslint:disable-next-line: no-unnecessary-type-assertion
    var values = backend.readSync(x.dataId);
    // tslint:disable-next-line: no-unnecessary-type-assertion
    var xBuf = buffer(x.shape, x.dtype, values);
    var resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin);
    result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values);
  } else {
    // The program emits finalShapeSparse.
    var program = new StridedSliceProgram($begin, $strides, finalShapeSparse);
    result = backend.runWebGLProgram(program, [x], x.dtype);
  }
  // Bring the general-path result to the user-visible finalShape.
  var resultReshaped = reshape({
    inputs: { x: result },
    backend: backend,
    attrs: { shape: finalShape }
  });
  backend.disposeIntermediateTensorInfo(result);
  return resultReshaped;
}
var stridedSliceConfig = {
  kernelName: StridedSlice,
  backendName: 'webgl',
  kernelFunc: stridedSlice
};
126273
/**
 * StringNGrams kernel for the WebGL backend.
 *
 * Reads data and dataSplits back from the backend and delegates to
 * stringNGramsImplCPU. Returns [nGrams (string vector), nGramsSplits (int32,
 * same shape as dataSplits)].
 */
function stringNGrams(args) {
  var backend = args.backend;
  var attrs = args.attrs;
  var data = args.inputs.data;
  var dataSplits = args.inputs.dataSplits;
  var dataValues = backend.readSync(data.dataId);
  var splitValues = backend.readSync(dataSplits.dataId);
  var generated = _slicedToArray(stringNGramsImplCPU(dataValues, splitValues, attrs.separator, attrs.nGramWidths, attrs.leftPad, attrs.rightPad, attrs.padWidth, attrs.preserveShortSequences), 2);
  var nGrams = generated[0];
  var nGramsSplits = generated[1];
  var nGramsInfo = backend.makeTensorInfo([nGrams.length], 'string', nGrams);
  var splitsInfo = backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits);
  return [nGramsInfo, splitsInfo];
}
var stringNGramsConfig = {
  kernelName: StringNGrams,
  backendName: 'webgl',
  kernelFunc: stringNGrams
};
126299
/**
 * StringSplit kernel for the WebGL backend.
 *
 * Requires a string vector `input` and a scalar `delimiter`, then delegates to
 * stringSplitImplCPU. Returns [indices ([n, 2] int32), values ([n] string),
 * shape ([2] int32)], where n is the number of produced values.
 */
function stringSplit(args) {
  var backend = args.backend;
  var skipEmpty = args.attrs.skipEmpty;
  var input = args.inputs.input;
  var delimiter = args.inputs.delimiter;

  if (input.dtype !== 'string') {
    throw new Error('Input must be of datatype string');
  }
  if (input.shape.length !== 1) {
    throw new Error("Input must be a vector, got shape: " + input.shape);
  }
  if (delimiter.shape.length !== 0) {
    throw new Error("Delimiter must be a scalar, got shape: " + delimiter.shape);
  }

  var inputValues = backend.readSync(input.dataId);
  var delimiterValue = backend.readSync(delimiter.dataId)[0];
  var pieces = _slicedToArray(stringSplitImplCPU(inputValues, delimiterValue, skipEmpty), 3);
  var splitIndices = pieces[0];
  var splitValues = pieces[1];
  var denseShape = pieces[2];
  var count = splitValues.length;
  return [
    backend.makeTensorInfo([count, 2], 'int32', splitIndices),
    backend.makeTensorInfo([count], 'string', splitValues),
    backend.makeTensorInfo([2], 'int32', new Int32Array(denseShape))
  ];
}
var stringSplitConfig = {
  kernelName: StringSplit,
  backendName: 'webgl',
  kernelFunc: stringSplit
};
126331
126332 /**
126333 * @license
126334 * Copyright 2021 Google LLC. All Rights Reserved.
126335 * Licensed under the Apache License, Version 2.0 (the "License");
126336 * you may not use this file except in compliance with the License.
126337 * You may obtain a copy of the License at
126338 *
126339 * http://www.apache.org/licenses/LICENSE-2.0
126340 *
126341 * Unless required by applicable law or agreed to in writing, software
126342 * distributed under the License is distributed on an "AS IS" BASIS,
126343 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126344 * See the License for the specific language governing permissions and
126345 * limitations under the License.
126346 * =============================================================================
126347 */
/**
 * StringToHashBucketFast kernel for the WebGL backend.
 *
 * Requires a string input and numBuckets >= 1 (any value <= 0 is rejected).
 * Hashing is done by stringToHashBucketFastImplCPU; the int32 output keeps
 * the input's shape.
 */
function stringToHashBucketFast(args) {
  var backend = args.backend;
  var numBuckets = args.attrs.numBuckets;
  var input = args.inputs.input;
  if (input.dtype !== 'string') {
    throw new Error('Input must be of datatype string');
  }
  if (!(numBuckets > 0)) {
    if (numBuckets <= 0) {
      throw new Error("Number of buckets must be at least 1");
    }
  }
  var hashed = stringToHashBucketFastImplCPU(backend.readSync(input.dataId), numBuckets);
  return backend.makeTensorInfo(input.shape, 'int32', hashed);
}
var stringToHashBucketFastConfig = {
  kernelName: StringToHashBucketFast,
  backendName: 'webgl',
  kernelFunc: stringToHashBucketFast
};
126369
126370 /**
126371 * @license
126372 * Copyright 2020 Google LLC. All Rights Reserved.
126373 * Licensed under the Apache License, Version 2.0 (the "License");
126374 * you may not use this file except in compliance with the License.
126375 * You may obtain a copy of the License at
126376 *
126377 * http://www.apache.org/licenses/LICENSE-2.0
126378 *
126379 * Unless required by applicable law or agreed to in writing, software
126380 * distributed under the License is distributed on an "AS IS" BASIS,
126381 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126382 * See the License for the specific language governing permissions and
126383 * limitations under the License.
126384 * =============================================================================
126385 */
// Element-wise tangent (unpacked snippet only).
var TAN = 'return tan(x);';
var tan = unaryKernelFunc({ opSnippet: TAN });
var tanConfig = { kernelName: Tan, backendName: 'webgl', kernelFunc: tan };
126395
126396 /**
126397 * @license
126398 * Copyright 2020 Google LLC. All Rights Reserved.
126399 * Licensed under the Apache License, Version 2.0 (the "License");
126400 * you may not use this file except in compliance with the License.
126401 * You may obtain a copy of the License at
126402 *
126403 * http://www.apache.org/licenses/LICENSE-2.0
126404 *
126405 * Unless required by applicable law or agreed to in writing, software
126406 * distributed under the License is distributed on an "AS IS" BASIS,
126407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126408 * See the License for the specific language governing permissions and
126409 * limitations under the License.
126410 * =============================================================================
126411 */
// Hyperbolic tangent evaluated on |x| and re-signed:
//   e2x = exp(-2|x|), tanh(x) = sign(x) * (1 - e2x) / (1 + e2x).
// Using -2|x| keeps the exponent non-positive.
var TANH = [
  '',
  ' float e2x = exp(-2.0 * abs(x));',
  ' return sign(x) * (1.0 - e2x) / (1.0 + e2x);',
  ''
].join('\n');
var tanh = unaryKernelFunc({ opSnippet: TANH });
var tanhConfig = { kernelName: Tanh$1, backendName: 'webgl', kernelFunc: tanh };
126421
/**
 * TensorScatterUpdate kernel for the WebGL backend.
 *
 * The kernel works on flattened views and reshapes the result back to
 * tensor.shape; all four intermediates are disposed before returning.
 *  - indices are flattened to [numUpdates, sliceRank];
 *  - updates are flattened to [numUpdates, sliceSize];
 *  - tensor is flattened to [outputSize / sliceSize, sliceSize].
 * These are passed to a ScatterProgram, with the flattened tensor as its
 * third input.
 *
 * @param {{inputs: {tensor: Object, indices: Object, updates: Object},
 *          backend: Object, attrs: Object}} args
 * @returns {Object} tensor info with tensor's shape and dtype
 */
function tensorScatterUpdate(args) {
  var inputs = args.inputs,
    backend = args.backend,
    attrs = args.attrs;
  var tensor = inputs.tensor,
    indices = inputs.indices,
    updates = inputs.updates;
  // NOTE(review): appears to be the transpiled form of an empty `{}`
  // destructuring of attrs; kept unchanged.
  _objectDestructuringEmpty(attrs);
  var _backend_util$calcula = calculateShapes(updates, indices, tensor.shape),
    sliceRank = _backend_util$calcula.sliceRank,
    numUpdates = _backend_util$calcula.numUpdates,
    sliceSize = _backend_util$calcula.sliceSize,
    strides = _backend_util$calcula.strides,
    outputSize = _backend_util$calcula.outputSize;
  if (outputSize === 0) {
    // Empty result. Use tensor's dtype, matching the non-empty path (which
    // runs the program with the reshaped tensor's dtype), rather than the
    // dtype of the indices.
    return backend.makeTensorInfo(tensor.shape, tensor.dtype);
  }
  // Computed after the empty check, so sliceSize is only divided into here
  // when there is data.
  var flattenShape = [outputSize / sliceSize, sliceSize];
  var flattenIndices = reshape({
    inputs: { x: indices },
    backend: backend,
    attrs: { shape: [numUpdates, sliceRank] }
  });
  var flattenX = reshape({
    inputs: { x: updates },
    backend: backend,
    attrs: { shape: [numUpdates, sliceSize] }
  });
  var flattenTensor = reshape({
    inputs: { x: tensor },
    backend: backend,
    attrs: { shape: flattenShape }
  });
  var program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, false, true);
  var res = backend.runWebGLProgram(program, [flattenX, flattenIndices, flattenTensor], flattenTensor.dtype);
  var reshaped = reshape({
    inputs: { x: res },
    backend: backend,
    attrs: { shape: tensor.shape }
  });
  backend.disposeIntermediateTensorInfo(flattenIndices);
  backend.disposeIntermediateTensorInfo(flattenX);
  backend.disposeIntermediateTensorInfo(flattenTensor);
  backend.disposeIntermediateTensorInfo(res);
  return reshaped;
}
var tensorScatterUpdateConfig = {
  kernelName: TensorScatterUpdate,
  backendName: 'webgl',
  kernelFunc: tensorScatterUpdate
};
126489
126490 var TileProgram = /*#__PURE__*/_createClass(function TileProgram(aShape, reps) {
126491 _classCallCheck(this, TileProgram);
126492 this.variableNames = ['A'];
126493 var outputShape = new Array(aShape.length);
126494 for (var i = 0; i < outputShape.length; i++) {
126495 outputShape[i] = aShape[i] * reps[i];
126496 }
126497 this.outputShape = outputShape;
126498 this.rank = outputShape.length;
126499 var dtype = getCoordsDataType(this.rank);
126500 var sourceCoords = getSourceCoords(aShape);
126501 this.userCode = "\n void main() {\n ".concat(dtype, " resRC = getOutputCoords();\n setOutput(getA(").concat(sourceCoords, "));\n }\n ");
126502 });
126503 function getSourceCoords(aShape) {
126504 var rank = aShape.length;
126505 if (rank > 5) {
126506 throw Error("Tile for rank ".concat(rank, " is not yet supported"));
126507 }
126508 if (rank === 1) {
126509 return "imod(resRC, ".concat(aShape[0], ")");
126510 }
126511 var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
126512 var sourceCoords = [];
126513 for (var i = 0; i < aShape.length; i++) {
126514 sourceCoords.push("imod(".concat(currentCoords[i], ", ").concat(aShape[i], ")"));
126515 }
126516 return sourceCoords.join();
126517 }
126518
126519 /**
126520 * @license
126521 * Copyright 2020 Google LLC. All Rights Reserved.
126522 * Licensed under the Apache License, Version 2.0 (the "License");
126523 * you may not use this file except in compliance with the License.
126524 * You may obtain a copy of the License at
126525 *
126526 * http://www.apache.org/licenses/LICENSE-2.0
126527 *
126528 * Unless required by applicable law or agreed to in writing, software
126529 * distributed under the License is distributed on an "AS IS" BASIS,
126530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126531 * See the License for the specific language governing permissions and
126532 * limitations under the License.
126533 * =============================================================================
126534 */
126535 function tile(params) {
126536 var inputs = params.inputs,
126537 backend = params.backend,
126538 attrs = params.attrs;
126539 var x = inputs.x;
126540 var reps = attrs.reps;
126541 // tile gpu program cannot handle rank > 5 case.
126542 if (x.dtype === 'string' || x.shape.length > 5) {
126543 // Even thought string tensor is always on CPU, just to be consistent on how
126544 // to access tensor data.
126545 var data = backend.readSync(x.dataId);
126546 var value = x.dtype === 'string' ? data.map(function (d) {
126547 return decodeString(d);
126548 }) : data;
126549 var buf = buffer(x.shape, x.dtype, value);
126550 var outBuf = tileImplCPU(buf, reps);
126551 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
126552 }
126553 var program = new TileProgram(x.shape, reps);
126554 var output = backend.runWebGLProgram(program, [x], x.dtype);
126555 return output;
126556 }
126557 var tileConfig = {
126558 kernelName: Tile,
126559 backendName: 'webgl',
126560 kernelFunc: tile
126561 };
126562
126563 // Based on Algorithm 2 of Bitonic Top K, ref:
126564 // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
126565 // The original algorithm is based on computing the top K only, however
126566 // since for TFJS we require the indices of the top K values as well then the
126567 // algorithm found here is a bit modified. Rather than producing the values
126568 // at each step, the indices containing the top K are generated instead.
126569 // The output values are not generated to reduce the number of outputs in the
126570 // GPU, the values can easily be retrieved from the indices using a gather
126571 // op.
126572 var SwapProgram = /*#__PURE__*/_createClass(
126573 /**
126574 * @param shape desired output shape (can be larger than input shape, output
126575 * will be padded with -Infinity)
126576 */
126577 function SwapProgram(shape) {
126578 _classCallCheck(this, SwapProgram);
126579 this.variableNames = ['x', 'indices'];
126580 // |n| Size of the original input of TopK.
126581 // |firstPass|indicates if this is the first time swap is being used which
126582 // means no indices input containing the top K is present yet.
126583 // |inc| Swaps pairs of indices (0, inc), (1, inc + 1), (2, inc + 2) ...
126584 this.customUniforms = [{
126585 name: 'n',
126586 type: 'int'
126587 }, {
126588 name: 'firstPass',
126589 type: 'int'
126590 }, {
126591 name: 'negativeInf',
126592 type: 'float'
126593 }, {
126594 name: 'dir',
126595 type: 'int'
126596 }, {
126597 name: 'inc',
126598 type: 'int'
126599 }];
126600 this.outputShape = shape;
126601 this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int elemIdx = coords[1];\n\n // We compare elements pair-wise within a group of size 2 * inc.\n // The comparing rule for each group alternates between ascending\n // and descending. Within each group, we compare each pair at\n // positions i and i+inc. To decide whether an element at position i\n // is x0 or x1, we mod it by 2 * inc, if the result is smaller than\n // inc, it is in the first half of the group, we denote it as x0,\n // otherwise we denote it as x1.\n // For example, as shown in the Bitonic top K paper referenced above,\n // Figure5(a) shows that element[1] is in the\n // second half of the group when group size is 2, but it is in the\n // first half of the group when group size is 4.\n\n bool isFirstInPair = imod(elemIdx, 2 * inc) < inc;\n int i = isFirstInPair ? elemIdx : elemIdx - inc;\n\n int i0 = firstPass == 1 ? i : int(getIndices(batch, i));\n int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc));\n float x0 = i0 < n ? getX(batch, i0) : negativeInf;\n float x1 = i1 < n ? getX(batch, i1) : negativeInf;\n\n // Denotes which direction indices are in (ascending or descending).\n bool reverse = imod(elemIdx, 2 * dir) >= dir;\n bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0);\n if (reverse == isGreater) { // Elements in opposite order of direction\n int iTemp = i0;\n i0 = i1;\n i1 = iTemp;\n }\n if (isFirstInPair) {\n setOutput(float(i0));\n } else {\n setOutput(float(i1));\n }\n }\n ";
126602 });
126603 var MergeProgram = /*#__PURE__*/_createClass(
126604 /**
126605 * @param shape desired output shape (must be half of the input size)
126606 */
126607 function MergeProgram(shape) {
126608 _classCallCheck(this, MergeProgram);
126609 this.variableNames = ['x', 'indices'];
126610 // |n| Size of the original input of TopK
126611 // |firstPass| indicates if this is the first time swap is being used which
126612 // means no indices input containing the top K is present yet.
126613 // |k| Top k elements desired
126614 this.customUniforms = [{
126615 name: 'n',
126616 type: 'int'
126617 }, {
126618 name: 'firstPass',
126619 type: 'int'
126620 }, {
126621 name: 'k',
126622 type: 'int'
126623 }];
126624 this.outputShape = shape;
126625 this.userCode = "\n void main() {\n // Takes max of indices (0, k), (1, k + 1), (2, k + 2) ...\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int elemIdx = coords[1];\n\n // The output size is half of the previous size.\n // If the previous sequence is | | | | _ _ _ _ | | | | _ _ _ _ (k=4),\n // we only need to output the indices at positions |, the indices at\n // positions _ can be thrown away, see Figure5(b) After Phase 2\n // (Merge phase) in the Bitonic Top K paper referenced above.\n // For example, the paper shows we only need to output the orange bars.\n // The output sequence should look like this | | | | | | | |.\n // Because the sequence is halved, to map the output index back\n // to the previous sequence to find the corresponding value,\n // we need to double the index. When we double the index,\n // we basically interpolate a position, so 2i looks like\n // | _ | _ | _ | _ | _ | _ | _. We move the | to the first k position\n // of each 2k positions by - elemIdx % k. E.g. for output at\n // index 4,5,6,7, we want to get the corresponding element at\n // original index 8,9,10,11, for output at index 8,9,10,11,\n // we want to get the corresponding element at original index\n // 16,17,18,19, so on and so forth.\n\n int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k));\n int i0 = firstPass == 1 ? i : int(getIndices(batch, i));\n int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k));\n\n float x0 = getX(batch, i0);\n float x1 = i1 < n ? getX(batch, i1) : x0;\n\n setOutput(x0 >= x1 ? float(i0) : float(i1));\n }\n ";
126626 });
126627
126628 function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
126629 if (tensorInfo !== null) {
126630 backend.disposeIntermediateTensorInfo(tensorInfo);
126631 }
126632 }
126633 function roundUpToPow2(num) {
126634 var pow2 = 1;
126635 while (pow2 < num) {
126636 pow2 *= 2;
126637 }
126638 return pow2;
126639 }
126640 // Based on Algorithm 2 of Bitonic Top K, ref:
126641 // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
126642 function topK(args) {
126643 var inputs = args.inputs,
126644 backend = args.backend,
126645 attrs = args.attrs;
126646 var x = inputs.x;
126647 var k = attrs.k,
126648 sorted = attrs.sorted;
126649 // Empirically determined constant used to determine last dim threshold for
126650 // handing off execution to the CPU.
126651 var TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');
126652 // Empirically determined constant used to determine k threshold for handing
126653 // off execution to the CPU.
126654 var TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');
126655 var xShape = x.shape;
126656 var lastDim = xShape[xShape.length - 1];
126657 if (backend.shouldExecuteOnCPU([x]) || lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD || k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
126658 var xVals = backend.readSync(x.dataId);
126659 var _topKImplCPU = topKImplCPU(xVals, xShape, x.dtype, k, sorted),
126660 _topKImplCPU2 = _slicedToArray(_topKImplCPU, 2),
126661 allTopKVals = _topKImplCPU2[0],
126662 allTopKIndices = _topKImplCPU2[1];
126663 return [backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values), backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)];
126664 }
126665 if (k === 0) {
126666 xShape[xShape.length - 1] = 0;
126667 return [backend.makeTensorInfo(xShape, x.dtype, []), backend.makeTensorInfo(xShape, 'int32', [])];
126668 }
126669 if (lastDim === 1 /* firstPass */) {
126670 return [x, fill({
126671 attrs: {
126672 shape: xShape,
126673 dtype: 'int32',
126674 value: 0
126675 },
126676 backend: backend
126677 })];
126678 }
126679 // Eagerly unpack x input since it is passed in to all the shaders which
126680 // require unpacked inputs.
126681 var xtexData = backend.texData.get(x.dataId);
126682 var xIsPacked = xtexData !== null && xtexData.isPacked;
126683 var xUnPacked = xIsPacked ? backend.unpackTensor(x) : x;
126684 // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
126685 var xSize = sizeFromShape(xShape);
126686 var batch = xSize / lastDim;
126687 var x2D = reshape({
126688 inputs: {
126689 x: xUnPacked
126690 },
126691 attrs: {
126692 shape: [batch, lastDim]
126693 },
126694 backend: backend
126695 });
126696 if (xIsPacked) {
126697 disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
126698 }
126699 var kPow2 = roundUpToPow2(k);
126700 var lastDimPow2 = roundUpToPow2(lastDim);
126701 // Only the indices containing the top K are kept at every step to reduce
126702 // number of outputs in the GPU algorithms, so once the final set of indices
126703 // is computed then gather is used to grab the corresponding values
126704 // from the original input.
126705 var indices = null;
126706 // GPU algorithm always takes in an indices input but this input is not used
126707 // on the first run of a GPU algorithm, therefore if indices is null we simply
126708 // pass in x2D instead of it but the value will not actually be used
126709 var getInputs = function getInputs() {
126710 return indices === null ? [x2D, x2D] : [x2D, indices];
126711 };
126712 var runSwap = function runSwap(dir, inc, shape) {
126713 var inputs = getInputs();
126714 var program = new SwapProgram(shape);
126715 var fistPass = indices === null ? 1 : 0;
126716 var customValues = [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
126717 var prevIndices = indices;
126718 indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);
126719 disposeIntermediateTensorInfoOrNull(backend, prevIndices);
126720 };
126721 // Step 1: local sort
126722 for (var len = 1; len < kPow2; len *= 2) {
126723 var dir = len * 2;
126724 for (var inc = len; inc >= 1; inc /= 2) {
126725 runSwap(dir, inc, [batch, lastDimPow2]);
126726 }
126727 }
126728 // Step 2: merge
126729 for (var indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
126730 var _inputs = getInputs();
126731 var mergeProgram = new MergeProgram([batch, indicesSize / 2]);
126732 var firstPass = indices === null ? 1 : 0;
126733 var customValues = [[lastDim], [firstPass], [kPow2]];
126734 var _prevIndices = indices;
126735 indices = backend.runWebGLProgram(mergeProgram, _inputs, 'int32', customValues);
126736 disposeIntermediateTensorInfoOrNull(backend, _prevIndices);
126737 // Step 3: rebuild
126738 var _len = kPow2 / 2;
126739 var _dir = _len * 2;
126740 for (var _inc = _len; _inc >= 1; _inc /= 2) {
126741 runSwap(_dir, _inc, indices.shape);
126742 }
126743 }
126744 // Keep only the requested top K results instead of kPow2
126745 var prevIndices = indices;
126746 indices = slice({
126747 inputs: {
126748 x: indices
126749 },
126750 backend: backend,
126751 attrs: {
126752 begin: 0,
126753 size: [batch, k]
126754 }
126755 });
126756 disposeIntermediateTensorInfoOrNull(backend, prevIndices);
126757 // Gather values on last dimension
126758 var values = gatherV2({
126759 inputs: {
126760 x: x2D,
126761 indices: indices
126762 },
126763 backend: backend,
126764 attrs: {
126765 axis: 1,
126766 batchDims: 1
126767 }
126768 });
126769 disposeIntermediateTensorInfoOrNull(backend, x2D);
126770 // Reshape back to the original input shape, except that the last
126771 // dimension is k.
126772 var newShape = xShape.slice(0, -1);
126773 newShape.push(k);
126774 prevIndices = indices;
126775 indices = reshape({
126776 inputs: {
126777 x: indices
126778 },
126779 attrs: {
126780 shape: newShape
126781 },
126782 backend: backend
126783 });
126784 disposeIntermediateTensorInfoOrNull(backend, prevIndices);
126785 var prevValues = values;
126786 values = reshape({
126787 inputs: {
126788 x: values
126789 },
126790 attrs: {
126791 shape: newShape
126792 },
126793 backend: backend
126794 });
126795 disposeIntermediateTensorInfoOrNull(backend, prevValues);
126796 return [values, indices];
126797 }
126798 var topKConfig = {
126799 kernelName: TopK,
126800 backendName: 'webgl',
126801 kernelFunc: topK
126802 };
126803
126804 /**
126805 * @license
126806 * Copyright 2021 Google LLC. All Rights Reserved.
126807 * Licensed under the Apache License, Version 2.0 (the "License");
126808 * you may not use this file except in compliance with the License.
126809 * You may obtain a copy of the License at
126810 *
126811 * http://www.apache.org/licenses/LICENSE-2.0
126812 *
126813 * Unless required by applicable law or agreed to in writing, software
126814 * distributed under the License is distributed on an "AS IS" BASIS,
126815 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126816 * See the License for the specific language governing permissions and
126817 * limitations under the License.
126818 * =============================================================================
126819 */
126820 var TransformProgram = /*#__PURE__*/_createClass(function TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
126821 _classCallCheck(this, TransformProgram);
126822 this.variableNames = ['Image', 'Transforms'];
126823 this.outputShape = outShape;
126824 var interpolationModeId = interpolation === 'nearest' ? 1 : 2;
126825 var fillModeId;
126826 switch (fillMode) {
126827 case 'constant':
126828 fillModeId = 1;
126829 break;
126830 case 'reflect':
126831 fillModeId = 2;
126832 break;
126833 case 'wrap':
126834 fillModeId = 3;
126835 break;
126836 case 'nearest':
126837 fillModeId = 4;
126838 break;
126839 default:
126840 fillModeId = 1;
126841 break;
126842 }
126843 this.userCode = "\n float mapCoord(float outCoord, float len) {\n float inCoord = outCoord;\n if(".concat(fillModeId, " == 2) {\n if (inCoord < 0.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz2 = 2.0 * len;\n if (inCoord < sz2) {\n inCoord = sz2 * float(int(float(-inCoord / sz2))) +\n inCoord;\n }\n inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;\n }\n } else if (inCoord > len - 1.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz2 = 2.0 * len;\n inCoord -= sz2 * float(int(float(inCoord / sz2)));\n if (inCoord >= len) {\n inCoord = sz2 - inCoord - 1.0;\n }\n }\n }\n return clamp(inCoord, 0.0, len - 1.0);\n } else if (").concat(fillModeId, " == 3) {\n if (inCoord < 0.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz = len - 1.0;\n inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);\n }\n } else if (inCoord > len - 1.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz = len - 1.0;\n inCoord -= len * float(int(float(inCoord / sz)));\n }\n }\n return clamp(inCoord, 0.0, len - 1.0);\n } else if (").concat(fillModeId, " == 4) {\n return clamp(outCoord, 0.0, len - 1.0);\n } else {\n return outCoord;\n }\n }\n\n float readWithFillValue(int batch, int coordY, int coordX,\n int channel) {\n float outputValue;\n if (0 <= coordY && coordY < ").concat(imageHeight, " && 0 <= coordX && coordX < ").concat(imageWidth, ") {\n outputValue = getImage(batch, coordY, coordX, channel);\n } else {\n outputValue = float(").concat(fillValue, ");\n }\n return outputValue;\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n float outputValue;\n int batch = coords[0];\n int x = coords[2];\n int y = coords[1];\n int channel = coords[3];\n float xf = float(x);\n float yf = float(y);\n float a1 = getTransforms(batch, 0);\n float a2 = getTransforms(batch, 1);\n float a3 = getTransforms(batch, 2);\n float b1 = getTransforms(batch, 3);\n float b2 = getTransforms(batch, 4);\n float b3 = getTransforms(batch, 5);\n 
float c1 = getTransforms(batch, 6);\n float c2 = getTransforms(batch, 7);\n float projection = c1 * xf + c2 * yf + 1.0;\n if (projection == 0.0) {\n outputValue = float(").concat(fillValue, ");\n } else {\n float inX = (a1 * xf + a2 * yf + a3) / projection;\n float inY = (b1 * xf + b2 * yf + b3) / projection;\n float mapX = mapCoord(inX, float(").concat(imageWidth, "));\n float mapY = mapCoord(inY, float(").concat(imageHeight, "));\n\n if (").concat(interpolationModeId, " == 1) {\n int coordY = int(round(mapY));\n int coordX = int(round(mapX));\n outputValue = readWithFillValue(batch, coordY, coordX,\n channel);\n } else {\n float yFloor = floor(mapY);\n float xFloor = floor(mapX);\n float yCeil = yFloor + 1.0;\n float xCeil = xFloor + 1.0;\n float valueYFloor = (xCeil - mapX) *\n readWithFillValue(batch, int(yFloor), int(xFloor), channel) +\n (mapX - xFloor) *\n readWithFillValue(batch, int(yFloor), int(xCeil), channel);\n float valueYCeil = (xCeil - mapX) *\n readWithFillValue(batch, int(yCeil), int(xFloor), channel) +\n (mapX - xFloor) *\n readWithFillValue(batch, int(yCeil), int(xCeil), channel);\n outputValue = (yCeil - mapY) * valueYFloor +\n (mapY - yFloor) * valueYCeil;\n }\n }\n setOutput(outputValue);\n }\n ");
126844 });
126845
126846 function transform(args) {
126847 var inputs = args.inputs,
126848 backend = args.backend,
126849 attrs = args.attrs;
126850 var image = inputs.image,
126851 transforms = inputs.transforms;
126852 var interpolation = attrs.interpolation,
126853 fillMode = attrs.fillMode,
126854 fillValue = attrs.fillValue,
126855 outputShape = attrs.outputShape;
126856 var _image$shape = _slicedToArray(image.shape, 4),
126857 batch = _image$shape[0],
126858 imageHeight = _image$shape[1],
126859 imageWidth = _image$shape[2],
126860 numChannels = _image$shape[3];
126861 var _ref = outputShape != null ? outputShape : [imageHeight, imageWidth],
126862 _ref2 = _slicedToArray(_ref, 2),
126863 outHeight = _ref2[0],
126864 outWidth = _ref2[1];
126865 var outShape = [batch, outHeight, outWidth, numChannels];
126866 var program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
126867 return backend.runWebGLProgram(program, [image, transforms], 'float32');
126868 }
126869 var transformConfig = {
126870 kernelName: Transform,
126871 backendName: 'webgl',
126872 kernelFunc: transform
126873 };
126874
126875 /**
126876 * @license
126877 * Copyright 2020 Google LLC. All Rights Reserved.
126878 * Licensed under the Apache License, Version 2.0 (the License);
126879 * you may not use this file except in compliance with the License.
126880 * You may obtain a copy of the License at
126881 *
126882 * http://www.apache.org/licenses/LICENSE-2.0
126883 *
126884 * Unless required by applicable law or agreed to in writing, software
126885 * distributed under the License is distributed on an AS IS BASIS,
126886 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126887 * See the License for the specific language governing permissions and
126888 * limitations under the License.
126889 * =============================================================================
126890 */
126891 function unique(args) {
126892 var inputs = args.inputs,
126893 attrs = args.attrs,
126894 backend = args.backend;
126895 var axis = attrs.axis;
126896 var x = inputs.x;
126897 assertNotComplex(x, 'unique');
126898 // For now, always forward calculation to the CPU backend.
126899 console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
126900 var values = backend.readSync(x.dataId);
126901 var _uniqueImplCPU = uniqueImplCPU(values, axis, x.shape, x.dtype),
126902 outputValues = _uniqueImplCPU.outputValues,
126903 outputShape = _uniqueImplCPU.outputShape,
126904 indices = _uniqueImplCPU.indices;
126905 return [backend.makeTensorInfo(outputShape, x.dtype, outputValues), backend.makeTensorInfo([indices.length], 'int32', indices)];
126906 }
126907 var uniqueConfig = {
126908 kernelName: Unique,
126909 backendName: 'webgl',
126910 kernelFunc: unique
126911 };
126912
126913 /**
126914 * @license
126915 * Copyright 2020 Google LLC. All Rights Reserved.
126916 * Licensed under the Apache License, Version 2.0 (the "License");
126917 * you may not use this file except in compliance with the License.
126918 * You may obtain a copy of the License at
126919 *
126920 * http://www.apache.org/licenses/LICENSE-2.0
126921 *
126922 * Unless required by applicable law or agreed to in writing, software
126923 * distributed under the License is distributed on an "AS IS" BASIS,
126924 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126925 * See the License for the specific language governing permissions and
126926 * limitations under the License.
126927 * =============================================================================
126928 */
126929 function unpack(args) {
126930 var inputs = args.inputs,
126931 backend = args.backend,
126932 attrs = args.attrs;
126933 var value = inputs.value;
126934 var axis = attrs.axis;
126935 if (axis < 0) {
126936 axis += value.shape.length;
126937 }
126938 var x = value;
126939 var xRank = x.shape.length;
126940 var num = value.shape[axis];
126941 var outShape = new Array(xRank - 1);
126942 var outIndex = 0;
126943 for (var i = 0; i < xRank; i++) {
126944 if (i !== axis) {
126945 outShape[outIndex++] = x.shape[i];
126946 }
126947 }
126948 var toDispose = [];
126949 var begin = new Array(xRank).fill(0);
126950 var size = x.shape.slice();
126951 size[axis] = 1;
126952 var res = new Array(num);
126953 for (var _i = 0; _i < res.length; _i++) {
126954 begin[axis] = _i;
126955 var sliced = slice({
126956 inputs: {
126957 x: x
126958 },
126959 backend: backend,
126960 attrs: {
126961 begin: begin,
126962 size: size
126963 }
126964 });
126965 var reshaped = reshape({
126966 inputs: {
126967 x: sliced
126968 },
126969 backend: backend,
126970 attrs: {
126971 shape: outShape
126972 }
126973 });
126974 res[_i] = reshaped;
126975 toDispose.push(sliced);
126976 }
126977 toDispose.forEach(function (t) {
126978 return backend.disposeIntermediateTensorInfo(t);
126979 });
126980 return res;
126981 }
126982 var unpackConfig = {
126983 kernelName: Unpack,
126984 backendName: 'webgl',
126985 kernelFunc: unpack
126986 };
126987
126988 /**
126989 * @license
126990 * Copyright 2018 Google LLC. All Rights Reserved.
126991 * Licensed under the Apache License, Version 2.0 (the "License");
126992 * you may not use this file except in compliance with the License.
126993 * You may obtain a copy of the License at
126994 *
126995 * http://www.apache.org/licenses/LICENSE-2.0
126996 *
126997 * Unless required by applicable law or agreed to in writing, software
126998 * distributed under the License is distributed on an "AS IS" BASIS,
126999 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127000 * See the License for the specific language governing permissions and
127001 * limitations under the License.
127002 * =============================================================================
127003 */
127004 var SegmentOpProgram = /*#__PURE__*/_createClass(function SegmentOpProgram(segOpInfo, segOpType) {
127005 _classCallCheck(this, SegmentOpProgram);
127006 this.variableNames = ['x', 'segmentIds'];
127007 var windowSize = segOpInfo.windowSize;
127008 var batchSize = segOpInfo.batchSize;
127009 var inSize = segOpInfo.inSize;
127010 var numSegments = segOpInfo.numSegments;
127011 var outSize = numSegments * Math.ceil(inSize / windowSize);
127012 this.outputShape = [batchSize, outSize];
127013 var initializationValue = '0.0';
127014 var returnValue = "sumValue";
127015 var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
127016 var windowSizeVec4Remainder = windowSize % 4;
127017 var updateSnippet = "\n sumValue += dot(values, segFilter);\n ";
127018 var checkValueOutOfBounds = '';
127019 if (inSize % windowSize > 0) {
127020 checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return initializationValue;\n }\n ");
127021 }
127022 var checkSegmentIdOutOfBounds = '';
127023 if (inSize % windowSize > 0) {
127024 checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return -1.0;\n }\n ");
127025 }
127026 this.userCode = "\n const float initializationValue = ".concat(initializationValue, ";\n\n float getValue(int batch, int inIdx) {\n ").concat(checkValueOutOfBounds, "\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n ").concat(checkSegmentIdOutOfBounds, "\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n ").concat(numSegments, ")) * float(").concat(windowSize, "));\n int currentSeg = int(mod(float(outIdx), float(").concat(numSegments, ")));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < ").concat(windowSizeNearestVec4, "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n ").concat(updateSnippet, "\n }\n\n int inIdx = inOffset + ").concat(windowSizeNearestVec4, ";\n if (").concat(windowSizeVec4Remainder === 1, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 
1 : 0,\n 0,\n 0\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n ").concat(updateSnippet, "\n }\n setOutput(").concat(returnValue, ");\n }\n ");
127027 });
127028
127029 /**
127030 * @license
127031 * Copyright 2020 Google LLC. All Rights Reserved.
127032 * Licensed under the Apache License, Version 2.0 (the "License");
127033 * you may not use this file except in compliance with the License.
127034 * You may obtain a copy of the License at
127035 *
127036 * http://www.apache.org/licenses/LICENSE-2.0
127037 *
127038 * Unless required by applicable law or agreed to in writing, software
127039 * distributed under the License is distributed on an "AS IS" BASIS,
127040 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127041 * See the License for the specific language governing permissions and
127042 * limitations under the License.
127043 * =============================================================================
127044 */
/**
 * WebGL kernel for UnsortedSegmentSum.
 *
 * The segmented axis is axis 0. When `getAxesPermutation` reports that axis 0
 * is not already innermost, `x` is transposed so that the segmented axis
 * becomes the last one, and the result is transposed back at the end. The
 * (possibly transposed) input is viewed as a 2-D `[-1, inSize]` matrix and fed
 * to `SegmentOpProgram` passes. Whenever a pass produces more than
 * `numSegments` columns, another pass is run on its output. That next pass
 * uses segment ids `0..numSegments-1` tiled `inSize / windowSize` times.
 *
 * @param {{inputs: {x: Object, segmentIds: Object}, backend: Object,
 *          attrs: {numSegments: number}}} args Kernel invocation arguments
 *     (TensorInfo-like inputs and the WebGL backend instance).
 * @returns {Object} The result TensorInfo, shaped by `computeOutShape` (and
 *     transposed back if a transpose was applied). Every intermediate
 *     TensorInfo created here is released via
 *     `backend.disposeIntermediateTensorInfo` before returning.
 */
function unsortedSegmentSum(args) {
  var backend = args.backend;
  var x = args.inputs.x;
  var segmentIds = args.inputs.segmentIds;
  var numSegments = args.attrs.numSegments;
  var rank = x.shape.length;
  // Everything pushed here is released just before returning.
  var intermediates = [];

  // Bring the segmented axis (0) to the innermost position when needed.
  var reduceAxis = 0;
  var forwardPerm = getAxesPermutation([reduceAxis], rank);
  var source = x;
  if (forwardPerm != null) {
    source = transpose({
      inputs: {x: x},
      backend: backend,
      attrs: {perm: forwardPerm}
    });
    intermediates.push(source);
    reduceAxis = getInnerMostAxes(1, rank)[0];
  }

  var finalShape = computeOutShape(source.shape, reduceAxis, numSegments);
  var reduceSize = sizeFromShape([source.shape[reduceAxis]]);
  var matrix = reshape({
    inputs: {x: source},
    backend: backend,
    attrs: {shape: [-1, reduceSize]}
  });
  intermediates.push(matrix);
  var outDType = sumOutType(x.dtype);

  // Each pass sums windows of `windowSize` columns per segment; repeat until
  // a pass yields exactly `numSegments` columns.
  var passInput = matrix;
  var passIds = segmentIds;
  var summed;
  for (;;) {
    var passRows = passInput.shape[0];
    var passCols = passInput.shape[1];
    var windowSize = segOpComputeOptimalWindowSize(passCols, numSegments);
    var program = new SegmentOpProgram({
      windowSize: windowSize,
      inSize: passCols,
      batchSize: passRows,
      numSegments: numSegments
    }, 'unsortedSegmentSum');
    summed = backend.compileAndRun(program, [passInput, passIds], outDType);
    intermediates.push(summed);
    if (summed.shape[1] === numSegments) {
      // No need to run another GPGPU program.
      break;
    }
    // The next pass labels the partial sums with segment ids
    // 0..numSegments-1, repeated `passCols / windowSize` times.
    var segmentRange = range({
      backend: backend,
      attrs: {start: 0, stop: numSegments, step: 1, dtype: 'float32'}
    });
    var tiledIds = tile({
      inputs: {x: segmentRange},
      backend: backend,
      attrs: {reps: [passCols / windowSize]}
    });
    intermediates.push(segmentRange, tiledIds);
    passInput = summed;
    passIds = tiledIds;
  }

  var restored = reshape({
    inputs: {x: summed},
    backend: backend,
    attrs: {shape: finalShape}
  });
  var result = restored;
  if (forwardPerm != null) {
    intermediates.push(restored);
    result = transpose({
      inputs: {x: restored},
      backend: backend,
      attrs: {perm: getUndoAxesPermutation(forwardPerm)}
    });
  }

  for (var k = 0; k < intermediates.length; k++) {
    backend.disposeIntermediateTensorInfo(intermediates[k]);
  }
  return result;
}
/**
 * Registration record that binds the `UnsortedSegmentSum` kernel name to the
 * WebGL implementation `unsortedSegmentSum`.
 */
var unsortedSegmentSumConfig = {
  kernelName: UnsortedSegmentSum,
  backendName: 'webgl',
  kernelFunc: unsortedSegmentSum
};
127157
127158 /**
127159 * @license
127160 * Copyright 2020 Google LLC. All Rights Reserved.
127161 * Licensed under the Apache License, Version 2.0 (the "License");
127162 * you may not use this file except in compliance with the License.
127163 * You may obtain a copy of the License at
127164 *
127165 * http://www.apache.org/licenses/LICENSE-2.0
127166 *
127167 * Unless required by applicable law or agreed to in writing, software
127168 * distributed under the License is distributed on an "AS IS" BASIS,
127169 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127170 * See the License for the specific language governing permissions and
127171 * limitations under the License.
127172 * =============================================================================
127173 */
/**
 * Kernel configurations for the WebGL backend. Each entry is handed to
 * `registerKernel` exactly once, in list order, when this bundle is evaluated.
 * List every kernel config here.
 */
var kernelConfigs = [
  _fusedMatMulConfig, absConfig, acosConfig, acoshConfig, addConfig,
  addNConfig, allConfig, anyConfig, argMaxConfig, argMinConfig,
  asinConfig, asinhConfig, atanConfig, atan2Config, atanhConfig,
  avgPoolConfig, avgPool3DConfig, avgPool3DGradConfig, avgPoolGradConfig,
  batchMatMulConfig, batchNormConfig, batchToSpaceNDConfig, bincountConfig,
  bitwiseAndConfig, broadcastArgsConfig, castConfig, ceilConfig,
  clipByValueConfig, complexConfig, complexAbsConfig, concatConfig,
  conv2DConfig, conv2DBackpropFilterConfig, conv2DBackpropInputConfig,
  conv3DConfig, conv3DBackpropFilterV2Config, conv3DBackpropInputConfig,
  cosConfig, coshConfig, cropAndResizeConfig, cumprodConfig, cumsumConfig,
  denseBincountConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig,
  depthwiseConv2dNativeBackpropFilterConfig,
  depthwiseConv2dNativeBackpropInputConfig, diagConfig, dilation2DConfig,
  einsumConfig, eluConfig, eluGradConfig, equalConfig, erfConfig, expConfig,
  expandDimsConfig, expm1Config, fftConfig, fillConfig, flipLeftRightConfig,
  floorConfig, floorDivConfig, fromPixelsConfig, fusedConv2DConfig,
  fusedDepthwiseConv2DConfig, gatherNdConfig, gatherV2Config, greaterConfig,
  greaterEqualConfig, identityConfig, ifftConfig, imagConfig, isFiniteConfig,
  isInfConfig, isNaNConfig, leakyReluConfig, lessConfig, lessEqualConfig,
  linSpaceConfig, logConfig, log1pConfig, logicalAndConfig, logicalNotConfig,
  logicalOrConfig, LRNConfig, LRNGradConfig, maxConfig, maximumConfig,
  maxPoolConfig, maxPool3DConfig, maxPool3DGradConfig, maxPoolGradConfig,
  maxPoolWithArgmaxConfig, meanConfig, minConfig, minimumConfig,
  mirrorPadConfig, modConfig, multinomialConfig, multiplyConfig, negConfig,
  nonMaxSuppressionV3Config, nonMaxSuppressionV4Config,
  nonMaxSuppressionV5Config, notEqualConfig, oneHotConfig, onesLikeConfig,
  packConfig, padV2Config, powConfig, preluConfig, prodConfig,
  raggedGatherConfig, raggedRangeConfig, raggedTensorToTensorConfig,
  rangeConfig, realConfig, realDivConfig, reciprocalConfig, reluConfig,
  relu6Config, reshapeConfig, resizeBilinearConfig, resizeBilinearGradConfig,
  resizeNearestNeighborConfig, resizeNearestNeighborGradConfig, reverseConfig,
  rotateWithOffsetConfig, roundConfig, rsqrtConfig, scatterNdConfig,
  searchSortedConfig, selectConfig, seluConfig, sigmoidConfig, signConfig,
  sinConfig, sinhConfig, sliceConfig, softmaxConfig, softplusConfig,
  spaceToBatchNDConfig, sparseFillEmptyRowsConfig, sparseReshapeConfig,
  sparseSegmentMeanConfig, sparseSegmentSumConfig, sparseToDenseConfig,
  splitVConfig, sqrtConfig, squareConfig, squaredDifferenceConfig,
  staticRegexReplaceConfig, stepConfig, stridedSliceConfig,
  stringNGramsConfig, stringSplitConfig, stringToHashBucketFastConfig,
  subConfig, sumConfig, tanConfig, tanhConfig, tensorScatterUpdateConfig,
  tileConfig, topKConfig, transformConfig, transposeConfig, uniqueConfig,
  unpackConfig, unsortedSegmentSumConfig, zerosLikeConfig
];
kernelConfigs.forEach(function (config) {
  // Pass only the config, exactly as the original index loop did.
  registerKernel(config);
});
127180
127181 /**
127182 * @license
127183 * Copyright 2020 Google LLC. All Rights Reserved.
127184 * Licensed under the Apache License, Version 2.0 (the "License");
127185 * you may not use this file except in compliance with the License.
127186 * You may obtain a copy of the License at
127187 *
127188 * http://www.apache.org/licenses/LICENSE-2.0
127189 *
127190 * Unless required by applicable law or agreed to in writing, software
127191 * distributed under the License is distributed on an "AS IS" BASIS,
127192 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127193 * See the License for the specific language governing permissions and
127194 * limitations under the License.
127195 * =============================================================================
127196 */
127197
127198 /** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
/**
 * Release version of the umbrella `tfjs` package; this is the value stored
 * under the 'tfjs' key of the per-package version map.
 */
var version$1 = '4.22.0';
127201
127202 /**
127203 * @license
127204 * Copyright 2018 Google LLC. All Rights Reserved.
127205 * Licensed under the Apache License, Version 2.0 (the "License");
127206 * you may not use this file except in compliance with the License.
127207 * You may obtain a copy of the License at
127208 *
127209 * http://www.apache.org/licenses/LICENSE-2.0
127210 *
127211 * Unless required by applicable law or agreed to in writing, software
127212 * distributed under the License is distributed on an "AS IS" BASIS,
127213 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127214 * See the License for the specific language governing permissions and
127215 * limitations under the License.
127216 * =============================================================================
127217 */
/**
 * Versions of the TensorFlow.js packages combined in this bundle, keyed by
 * package name. Keys are listed in a fixed order so that enumerating the map
 * (e.g. when printing it) is stable.
 */
var version = {
  'tfjs-core': version$7,
  'tfjs-backend-cpu': version$3,
  'tfjs-backend-webgl': version$2,
  'tfjs-data': version$4,
  'tfjs-layers': version$6,
  'tfjs-converter': version$5,
  tfjs: version$1
};
127227
127228 /**
127229 * @license
127230 * Copyright 2020 Google LLC. All Rights Reserved.
127231 * Licensed under the Apache License, Version 2.0 (the "License");
127232 * you may not use this file except in compliance with the License.
127233 * You may obtain a copy of the License at
127234 *
127235 * http://www.apache.org/licenses/LICENSE-2.0
127236 *
127237 * Unless required by applicable law or agreed to in writing, software
127238 * distributed under the License is distributed on an "AS IS" BASIS,
127239 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127240 * See the License for the specific language governing permissions and
127241 * limitations under the License.
127242 * =============================================================================
127243 */
127244
// Public API of the bundle, part 1: capitalized names, then `_FusedMatMul`.
// Per the UMD wrapper at the top of this file, `exports` is the CommonJS/AMD
// exports object or `globalThis.tf`.
// Entries are assigned in ASCII (code-unit) order of the exported name. That
// assignment order also fixes the enumeration order of the resulting object,
// so keep any new entry in sorted position.
// Right-hand sides with a `$N` suffix (e.g. `Add$1`, `Softmax$2`) are locally
// renamed identifiers, presumably renamed by the bundler to avoid clashes
// between packages. Confirm before renaming anything.
// Capitalized names appear to mix kernel-name constants with classes. For
// example, `UnsortedSegmentSum` is used as `kernelName` by the WebGL kernel
// config in this file, while the optimizers and models are classes (judging
// by name).
exports.Abs = Abs;
exports.Acos = Acos;
exports.Acosh = Acosh;
exports.AdadeltaOptimizer = AdadeltaOptimizer;
exports.AdagradOptimizer = AdagradOptimizer;
exports.AdamOptimizer = AdamOptimizer;
exports.AdamaxOptimizer = AdamaxOptimizer;
exports.Add = Add$1;
exports.AddN = AddN;
exports.All = All;
exports.Any = Any;
exports.ArgMax = ArgMax;
exports.ArgMin = ArgMin;
exports.Asin = Asin;
exports.Asinh = Asinh;
exports.Atan = Atan;
exports.Atan2 = Atan2;
exports.Atanh = Atanh;
exports.AvgPool = AvgPool;
exports.AvgPool3D = AvgPool3D;
exports.AvgPool3DGrad = AvgPool3DGrad;
exports.AvgPoolGrad = AvgPoolGrad;
exports.BatchMatMul = BatchMatMul;
exports.BatchToSpaceND = BatchToSpaceND;
exports.Bincount = Bincount;
exports.BitwiseAnd = BitwiseAnd;
exports.BroadcastArgs = BroadcastArgs;
exports.BroadcastTo = BroadcastTo;
exports.Callback = Callback;
exports.CallbackList = CallbackList;
exports.Cast = Cast;
exports.Ceil = Ceil;
exports.ClipByValue = ClipByValue;
exports.Complex = Complex;
exports.ComplexAbs = ComplexAbs;
exports.Concat = Concat;
exports.Conv2D = Conv2D$1;
exports.Conv2DBackpropFilter = Conv2DBackpropFilter;
exports.Conv2DBackpropInput = Conv2DBackpropInput;
exports.Conv3D = Conv3D$1;
exports.Conv3DBackpropFilterV2 = Conv3DBackpropFilterV2;
exports.Conv3DBackpropInputV2 = Conv3DBackpropInputV2;
exports.Cos = Cos;
exports.Cosh = Cosh;
exports.CropAndResize = CropAndResize;
exports.Cumprod = Cumprod;
exports.Cumsum = Cumsum;
exports.CustomCallback = CustomCallback;
exports.DataStorage = DataStorage;
exports.DenseBincount = DenseBincount;
exports.DepthToSpace = DepthToSpace;
exports.DepthwiseConv2dNative = DepthwiseConv2dNative;
exports.DepthwiseConv2dNativeBackpropFilter = DepthwiseConv2dNativeBackpropFilter;
exports.DepthwiseConv2dNativeBackpropInput = DepthwiseConv2dNativeBackpropInput;
exports.Diag = Diag;
exports.Dilation2D = Dilation2D;
exports.Dilation2DBackpropFilter = Dilation2DBackpropFilter;
exports.Dilation2DBackpropInput = Dilation2DBackpropInput;
exports.Draw = Draw;
exports.EarlyStopping = EarlyStopping;
exports.Einsum = Einsum;
exports.Elu = Elu$1;
exports.EluGrad = EluGrad;
exports.Environment = Environment;
exports.Equal = Equal;
exports.Erf = Erf;
exports.Exp = Exp;
exports.ExpandDims = ExpandDims;
exports.Expm1 = Expm1;
exports.FFT = FFT;
exports.Fill = Fill;
exports.FlipLeftRight = FlipLeftRight;
exports.Floor = Floor;
exports.FloorDiv = FloorDiv;
exports.FromPixels = FromPixels;
exports.FusedBatchNorm = FusedBatchNorm;
exports.FusedConv2D = FusedConv2D;
exports.FusedDepthwiseConv2D = FusedDepthwiseConv2D;
exports.GPGPUContext = GPGPUContext;
exports.GatherNd = GatherNd;
exports.GatherV2 = GatherV2;
exports.GraphModel = GraphModel;
exports.Greater = Greater;
exports.GreaterEqual = GreaterEqual;
exports.History = History;
exports.IFFT = IFFT;
exports.Identity = Identity$1;
exports.Imag = Imag;
exports.InputSpec = InputSpec;
exports.IsFinite = IsFinite;
exports.IsInf = IsInf;
exports.IsNan = IsNan;
exports.KernelBackend = KernelBackend;
exports.LRN = LRN;
exports.LRNGrad = LRNGrad;
exports.LayerVariable = LayerVariable;
exports.LayersModel = LayersModel;
exports.LeakyRelu = LeakyRelu;
exports.Less = Less;
exports.LessEqual = LessEqual;
exports.LinSpace = LinSpace;
exports.Log = Log;
exports.Log1p = Log1p;
exports.LogSoftmax = LogSoftmax$1;
exports.LogicalAnd = LogicalAnd;
exports.LogicalNot = LogicalNot;
exports.LogicalOr = LogicalOr;
exports.LogicalXor = LogicalXor;
exports.LowerBound = LowerBound;
exports.MathBackendCPU = MathBackendCPU;
exports.MathBackendWebGL = MathBackendWebGL;
exports.MatrixBandPart = MatrixBandPart;
exports.Max = Max;
exports.MaxPool = MaxPool;
exports.MaxPool3D = MaxPool3D;
exports.MaxPool3DGrad = MaxPool3DGrad;
exports.MaxPoolGrad = MaxPoolGrad;
exports.MaxPoolWithArgmax = MaxPoolWithArgmax;
exports.Maximum = Maximum$1;
exports.Mean = Mean;
exports.Min = Min;
exports.Minimum = Minimum$1;
exports.MirrorPad = MirrorPad;
exports.Mod = Mod;
exports.MomentumOptimizer = MomentumOptimizer;
exports.Multinomial = Multinomial;
exports.Multiply = Multiply$1;
exports.Neg = Neg;
exports.NonMaxSuppressionV3 = NonMaxSuppressionV3;
exports.NonMaxSuppressionV4 = NonMaxSuppressionV4;
exports.NonMaxSuppressionV5 = NonMaxSuppressionV5;
exports.NotEqual = NotEqual;
exports.OP_SCOPE_SUFFIX = OP_SCOPE_SUFFIX;
exports.OneHot = OneHot;
exports.OnesLike = OnesLike;
exports.Optimizer = Optimizer;
exports.OptimizerConstructors = OptimizerConstructors;
exports.Pack = Pack;
exports.PadV2 = PadV2;
exports.Pool = Pool;
exports.Pow = Pow;
exports.Prelu = Prelu;
exports.Prod = Prod;
exports.RMSPropOptimizer = RMSPropOptimizer;
exports.RNN = RNN;
exports.RaggedGather = RaggedGather;
exports.RaggedRange = RaggedRange;
exports.RaggedTensorToTensor = RaggedTensorToTensor;
exports.Range = Range;
exports.Real = Real;
exports.RealDiv = RealDiv;
exports.Reciprocal = Reciprocal;
exports.Relu = Relu$1;
exports.Relu6 = Relu6$1;
exports.Reshape = Reshape$1;
exports.ResizeBilinear = ResizeBilinear;
exports.ResizeBilinearGrad = ResizeBilinearGrad;
exports.ResizeNearestNeighbor = ResizeNearestNeighbor;
exports.ResizeNearestNeighborGrad = ResizeNearestNeighborGrad;
exports.Reverse = Reverse;
exports.RotateWithOffset = RotateWithOffset;
exports.Round = Round;
exports.Rsqrt = Rsqrt;
exports.SGDOptimizer = SGDOptimizer;
exports.ScatterNd = ScatterNd;
exports.SearchSorted = SearchSorted;
exports.Select = Select;
exports.Selu = Selu$1;
exports.Sequential = Sequential;
exports.Sigmoid = Sigmoid$1;
exports.Sign = Sign;
exports.Sin = Sin;
exports.Sinh = Sinh;
exports.Slice = Slice;
exports.Softmax = Softmax$2;
exports.Softplus = Softplus$1;
exports.SpaceToBatchND = SpaceToBatchND;
exports.SparseFillEmptyRows = SparseFillEmptyRows;
exports.SparseReshape = SparseReshape;
exports.SparseSegmentMean = SparseSegmentMean;
exports.SparseSegmentSum = SparseSegmentSum;
exports.SparseToDense = SparseToDense;
exports.SplitV = SplitV;
exports.Sqrt = Sqrt;
exports.Square = Square;
exports.SquaredDifference = SquaredDifference;
exports.StaticRegexReplace = StaticRegexReplace;
exports.Step = Step;
exports.StridedSlice = StridedSlice;
exports.StringNGrams = StringNGrams;
exports.StringSplit = StringSplit;
exports.StringToHashBucketFast = StringToHashBucketFast;
exports.Sub = Sub;
exports.Sum = Sum;
exports.SymbolicTensor = SymbolicTensor;
exports.Tan = Tan;
exports.Tanh = Tanh$1;
exports.Tensor = Tensor;
exports.TensorBuffer = TensorBuffer;
exports.TensorScatterUpdate = TensorScatterUpdate;
exports.Tile = Tile;
exports.TopK = TopK;
exports.Transform = Transform;
exports.Transpose = Transpose;
exports.Unique = Unique;
exports.Unpack = Unpack;
exports.UnsortedSegmentSum = UnsortedSegmentSum;
exports.UpperBound = UpperBound;
exports.Variable = Variable;
exports.ZerosLike = ZerosLike;
// `_` sorts after the capital letters and before lower-case in ASCII.
exports._FusedMatMul = _FusedMatMul;
// Public API of the bundle, part 2: lower-case names (ops, namespaces and
// utilities), continuing the ASCII ordering of the exported names. The
// assignment order fixes the enumeration order of `exports`, so keep new
// entries sorted. Right-hand sides with a `$N` suffix are locally renamed
// identifiers (presumably bundler de-duplication; confirm before renaming).
exports.abs = abs$2;
exports.acos = acos$2;
exports.acosh = acosh$2;
exports.add = add$3;
exports.addN = addN$2;
exports.all = all$2;
exports.any = any$2;
exports.argMax = argMax$2;
exports.argMin = argMin$2;
exports.asin = asin$2;
exports.asinh = asinh$2;
exports.atan = atan$2;
exports.atan2 = atan2$2;
exports.atanh = atanh$2;
exports.avgPool = avgPool$2;
exports.avgPool3d = avgPool3d$1;
exports.backend = backend$1;
exports.backend_util = backend_util;
exports.basicLSTMCell = basicLSTMCell;
exports.batchNorm = batchNorm$2;
exports.batchNorm2d = batchNorm2d;
exports.batchNorm3d = batchNorm3d;
exports.batchNorm4d = batchNorm4d;
exports.batchToSpaceND = batchToSpaceND$2;
exports.bincount = bincount$2;
exports.bitwiseAnd = bitwiseAnd$2;
exports.booleanMaskAsync = booleanMaskAsync;
exports.broadcastArgs = broadcastArgs$2;
exports.broadcastTo = broadcastTo;
exports.broadcast_util = broadcast_util;
exports.browser = browser;
exports.buffer = buffer;
exports.callbacks = callbacks;
exports.cast = cast$3;
exports.ceil = ceil$2;
exports.clipByValue = clipByValue$2;
exports.clone = clone;
exports.complex = complex$2;
exports.concat = concat$2;
exports.concat1d = concat1d;
exports.concat2d = concat2d;
exports.concat3d = concat3d;
exports.concat4d = concat4d;
exports.constraints = exports_constraints;
exports.conv1d = conv1d$2;
exports.conv2d = conv2d$4;
exports.conv2dTranspose = conv2dTranspose$1;
exports.conv3d = conv3d$2;
exports.conv3dTranspose = conv3dTranspose$1;
exports.copyRegisteredKernels = copyRegisteredKernels;
exports.cos = cos$2;
exports.cosh = cosh$2;
exports.cosineWindow = cosineWindow;
exports.cumprod = cumprod$2;
exports.cumsum = cumsum$2;
exports.customGrad = customGrad;
// NOTE(review): `index` is presumably the tfjs-data namespace object (the
// package version map lists 'tfjs-data'); confirm before renaming.
exports.data = index;
exports.denseBincount = denseBincount$2;
exports.deprecationWarn = deprecationWarn;
exports.depthToSpace = depthToSpace$2;
exports.depthwiseConv2d = depthwiseConv2d$3;
exports.deregisterOp = deregisterOp;
exports.device_util = device_util;
exports.diag = diag$2;
exports.dilation2d = dilation2d;
exports.disableDeprecationWarnings = disableDeprecationWarnings;
exports.dispose = dispose;
exports.disposeVariables = disposeVariables;
exports.div = div$1;
exports.divNoNan = divNoNan;
exports.dot = dot$2;
exports.dropout = dropout$2;
exports.einsum = einsum$2;
exports.elu = elu$4;
exports.enableDebugMode = enableDebugMode;
exports.enableProdMode = enableProdMode;
exports.enclosingPowerOfTwo = enclosingPowerOfTwo;
exports.engine = engine;
exports.ensureShape = ensureShape;
exports.env = env;
exports.equal = equal$2;
exports.erf = erf$2;
exports.euclideanNorm = euclideanNorm;
exports.exp = exp$2;
exports.expandDims = expandDims$3;
exports.expm1 = expm1$2;
exports.eye = eye;
exports.fft = fft$2;
exports.fill = fill$2;
exports.findBackend = findBackend;
exports.findBackendFactory = findBackendFactory;
exports.floor = floor$2;
exports.floorDiv = floorDiv$2;
exports.forceHalfFloat = forceHalfFloat;
exports.fused = fused_ops;
exports.gather = gather$1;
exports.gatherND = gatherND;
exports.gather_util = gather_nd_util;
exports.getBackend = getBackend$1;
exports.getGradient = getGradient;
exports.getKernel = getKernel;
exports.getKernelsForBackend = getKernelsForBackend;
exports.gpgpu_util = gpgpu_util;
exports.grad = grad;
exports.grads = grads;
exports.greater = greater$3;
exports.greaterEqual = greaterEqual$2;
exports.ifft = ifft$2;
exports.imag = imag$2;
exports.image = image$1;
exports.inTopKAsync = inTopKAsync;
exports.initializers = exports_initializers;
exports.input = input;
exports.io = io;
exports.irfft = irfft;
exports.isFinite = isFinite$3;
exports.isInf = isInf$2;
exports.isNaN = isNaN$3;
exports.keep = keep;
exports.kernel_impls = kernel_impls;
exports.layers = exports_layers;
exports.leakyRelu = leakyRelu$2;
exports.less = less$3;
exports.lessEqual = lessEqual$2;
exports.linalg = linalg;
exports.linspace = linspace;
exports.loadGraphModel = loadGraphModel;
exports.loadGraphModelSync = loadGraphModelSync;
exports.loadLayersModel = loadLayersModel;
exports.localResponseNormalization = localResponseNormalization;
exports.log = log$2;
exports.log1p = log1p$2;
exports.logSigmoid = logSigmoid;
exports.logSoftmax = logSoftmax;
exports.logSumExp = logSumExp;
exports.logicalAnd = logicalAnd$2;
exports.logicalNot = logicalNot$2;
exports.logicalOr = logicalOr$2;
exports.logicalXor = logicalXor;
exports.losses = losses;
exports.lowerBound = lowerBound$1;
exports.matMul = matMul$1;
exports.math = math;
exports.max = max$3;
exports.maxPool = maxPool$2;
exports.maxPool3d = maxPool3d$1;
exports.maxPoolWithArgmax = maxPoolWithArgmax;
exports.maximum = maximum$4;
exports.mean = mean$3;
exports.memory = memory;
exports.meshgrid = meshgrid;
exports.metrics = exports_metrics;
exports.min = min$3;
exports.minimum = minimum$4;
exports.mirrorPad = mirrorPad$1;
exports.mod = mod$2;
exports.model = model;
exports.models = exports_models;
exports.moments = moments;
exports.movingAverage = movingAverage;
exports.mul = mul;
exports.multiRNNCell = multiRNNCell;
exports.multinomial = multinomial$2;
exports.neg = neg$2;
exports.nextFrame = nextFrame;
exports.norm = norm;
exports.notEqual = notEqual$2;
exports.oneHot = oneHot$3;
exports.ones = ones$1;
exports.onesLike = onesLike$3;
exports.op = op;
exports.outerProduct = outerProduct;
exports.pad = pad;
exports.pad1d = pad1d;
exports.pad2d = pad2d;
exports.pad3d = pad3d;
exports.pad4d = pad4d;
exports.pool = pool$1;
exports.pow = pow$3;
exports.prelu = prelu$3;
exports.print = print;
exports.prod = prod$2;
exports.profile = profile;
exports.raggedGather = raggedGather$2;
exports.raggedRange = raggedRange$2;
exports.raggedTensorToTensor = raggedTensorToTensor$2;
exports.rand = rand;
exports.randomGamma = randomGamma;
exports.randomNormal = randomNormal$2;
exports.randomStandardNormal = randomStandardNormal;
exports.randomUniform = randomUniform$1;
exports.randomUniformInt = randomUniformInt;
// The public `range` op is `range$3`, distinct from the `range` kernel helper
// called by the WebGL unsortedSegmentSum kernel in this file.
exports.range = range$3;
exports.ready = ready;
exports.real = real$2;
exports.reciprocal = reciprocal$2;
exports.registerBackend = registerBackend;
exports.registerCallbackConstructor = registerCallbackConstructor;
exports.registerGradient = registerGradient;
exports.registerKernel = registerKernel;
exports.registerOp = registerOp;
exports.regularizers = exports_regularizers;
exports.relu = relu$2;
exports.relu6 = relu6$2;
exports.removeBackend = removeBackend;
exports.reshape = reshape$3;
exports.reverse = reverse$2;
exports.reverse1d = reverse1d;
exports.reverse2d = reverse2d;
exports.reverse3d = reverse3d;
exports.reverse4d = reverse4d;
exports.rfft = rfft;
exports.round = round$2;
exports.rsqrt = rsqrt$2;
exports.scalar = scalar;
exports.scatterND = scatterND;
exports.scatter_util = scatter_nd_util;
exports.searchSorted = searchSorted$2;
exports.selu = selu$2;
exports.separableConv2d = separableConv2d$1;
exports.sequential = sequential;
exports.serialization = serialization;
exports.setBackend = setBackend$1;
exports.setPlatform = setPlatform;
exports.setWebGLContext = setWebGLContext;
exports.setdiff1dAsync = setdiff1dAsync;
exports.shared = shared;
exports.sigmoid = sigmoid$2;
exports.sign = sign$3;
exports.signal = signal;
exports.sin = sin$2;
exports.sinh = sinh$2;
exports.slice = slice$2;
exports.slice1d = slice1d;
exports.slice2d = slice2d;
exports.slice3d = slice3d;
exports.slice4d = slice4d;
exports.slice_util = slice_util;
exports.softmax = softmax$3;
exports.softplus = softplus$2;
exports.spaceToBatchND = spaceToBatchND$2;
exports.sparse = sparse$1;
exports.sparseToDense = sparseToDense$2;
exports.spectral = spectral$1;
exports.split = split$3;
exports.sqrt = sqrt$2;
exports.square = square$2;
exports.squaredDifference = squaredDifference$2;
exports.squeeze = squeeze;
exports.stack = stack;
exports.step = step$2;
exports.stridedSlice = stridedSlice$2;
exports.string = string$1;
exports.sub = sub$2;
exports.sum = sum$3;
exports.sumOutType = sumOutType;
exports.tan = tan$2;
exports.tanh = tanh$2;
exports.tensor = tensor;
exports.tensor1d = tensor1d;
exports.tensor2d = tensor2d;
exports.tensor3d = tensor3d;
exports.tensor4d = tensor4d;
exports.tensor5d = tensor5d;
exports.tensor6d = tensor6d;
exports.tensorScatterUpdate = tensorScatterUpdate$2;
exports.tensor_util = tensor_util;
exports.test_util = test_util;
exports.tidy = tidy;
exports.tile = tile$3;
exports.time = time;
exports.topk = topk;
exports.train = train;
exports.transpose = transpose$2;
exports.truncatedNormal = truncatedNormal$1;
exports.unique = unique$3;
exports.unregisterGradient = unregisterGradient;
exports.unregisterKernel = unregisterKernel;
// The public op `unsortedSegmentSum$2`, not the WebGL kernel function
// `unsortedSegmentSum` defined earlier in this file.
exports.unsortedSegmentSum = unsortedSegmentSum$2;
exports.unstack = unstack;
exports.upcastType = upcastType;
exports.upperBound = upperBound$1;
exports.util = util;
exports.valueAndGrad = valueAndGrad;
exports.valueAndGrads = valueAndGrads;
exports.variable = variable$1;
exports.variableGrads = variableGrads;
// Version info: `version` is the per-package map defined earlier in this file.
// Each `version_*` alias exposes the same value as the matching map entry.
// No `version_data` alias is exported here.
exports.version = version;
exports.version_converter = version$5;
exports.version_core = version$7;
exports.version_cpu = version$3;
exports.version_layers = version$6;
exports.version_webgl = version$2;
exports.webgl = webgl;
exports.webgl_util = webgl_util;
exports.where = where;
exports.whereAsync = whereAsync;
exports.zeros = zeros$2;
exports.zerosLike = zerosLike$3;
127755
127756}));
127757//# sourceMappingURL=tf.js.map