1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 |
|
18 |
|
19 |
|
20 |
|
21 |
|
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
|
30 |
|
/**
 * Shuffles `array` in place using the Fisher–Yates algorithm.
 *
 * Walks from the last slot to the first; each slot is swapped with a slot
 * chosen uniformly at random from the not-yet-fixed prefix (itself included).
 *
 * @param array Array or typed array to permute in place. Nothing is returned.
 */
export function shuffle(array) {
  for (let remaining = array.length; remaining > 0;) {
    // `| 0` truncates the scaled random value to an integer index.
    const pick = (Math.random() * remaining) | 0;
    remaining -= 1;
    const held = array[remaining];
    array[remaining] = array[pick];
    array[pick] = held;
  }
}
|
47 |
|
48 |
|
49 |
|
50 |
|
51 |
|
52 |
|
53 |
|
54 |
|
55 |
|
56 |
|
57 |
|
58 |
|
59 |
|
60 |
|
61 |
|
62 |
|
/**
 * Shuffles two equal-length arrays in place with the same random permutation
 * (Fisher–Yates), so elements at matching indices stay paired.
 *
 * @param array First array or typed array to shuffle in place.
 * @param array2 Second array; must have the same length as `array`.
 * @throws {Error} If the two arrays have different lengths.
 */
export function shuffleCombo(array, array2) {
  if (array.length !== array2.length) {
    // Each sentence is terminated explicitly; the parts used to run together
    // ("...togetherFirst ... was 3Second ...").
    throw new Error(`Array sizes must match to be shuffled together. ` +
        `First array length was ${array.length}. ` +
        `Second array length was ${array2.length}.`);
  }
  let counter = array.length;
  let index = 0;
  while (counter > 0) {
    // Random index in [0, counter); `| 0` truncates to an integer.
    index = (Math.random() * counter) | 0;
    counter--;
    // Apply the identical swap to both arrays to keep them aligned.
    const temp = array[counter];
    const temp2 = array2[counter];
    array[counter] = array[index];
    array2[counter] = array2[index];
    array[index] = temp;
    array2[index] = temp2;
  }
}
|
91 |
|
/**
 * Restricts `x` to the closed interval [min, max].
 *
 * @param min Lower bound.
 * @param x Value to clamp.
 * @param max Upper bound.
 * @returns `x` limited from above by `max`, then from below by `min`.
 */
export function clamp(min, x, max) {
  const cappedAbove = Math.min(x, max);
  return Math.max(min, cappedAbove);
}
|
/**
 * Returns `val` if it is even, otherwise `val + 1`.
 *
 * @param val Number to round up to an even value.
 * @returns The smallest even number >= `val` for integer inputs.
 */
export function nearestLargerEven(val) {
  if (val % 2 === 0) {
    return val;
  }
  return val + 1;
}
|
/**
 * Adds up all elements of `arr`, in index order.
 *
 * @param arr Array or typed array of numbers.
 * @returns The total; 0 for an empty input.
 */
export function sum(arr) {
  let total = 0;
  let i = 0;
  while (i < arr.length) {
    total += arr[i];
    i++;
  }
  return total;
}
|
105 |
|
106 |
|
107 |
|
108 |
|
109 |
|
110 |
|
111 |
|
/**
 * Draws a pseudo-random number by interpolating between `a` and `b` with a
 * single `Math.random()` sample (so the result lies between `a` and `b`).
 *
 * Not suitable for security-sensitive use.
 *
 * @param a One endpoint (returned when the sample is 0).
 * @param b Other endpoint.
 * @returns b * t + (1 - t) * a for a uniform sample t in [0, 1).
 */
export function randUniform(a, b) {
  const t = Math.random();
  // Same two products as b * t + (1 - t) * a; floating-point addition is
  // commutative, so swapping the operands leaves the result bit-identical.
  return (1 - t) * a + b * t;
}
|
116 |
|
/**
 * Squared Euclidean distance between two vectors.
 *
 * Iterates over the length of `a`. Elements are coerced with `Number` before
 * subtracting.
 *
 * @param a First vector (array or typed array).
 * @param b Second vector; expected to be at least as long as `a`.
 * @returns Sum over i of (a[i] - b[i])^2.
 */
export function distSquared(a, b) {
  let accumulated = 0;
  for (let k = 0; k < a.length; k++) {
    const delta = Number(a[k]) - Number(b[k]);
    accumulated += delta * delta;
  }
  return accumulated;
}
|
125 |
|
126 |
|
127 |
|
128 |
|
129 |
|
130 |
|
131 |
|
132 |
|
133 |
|
134 |
|
135 |
|
136 |
|
137 |
|
138 |
|
139 |
|
/**
 * Throws an Error when `expr` is falsy.
 *
 * @param expr Condition that must hold.
 * @param msg Either a string, or a zero-argument function that builds the
 *     message lazily (so costly messages are only formatted on failure).
 *     Any other non-null value is stringified; when omitted, a generic
 *     message is used.
 * @throws {Error} If `expr` is falsy.
 */
export function assert(expr, msg) {
  if (!expr) {
    // Previously a missing/non-function message produced
    // "TypeError: msg is not a function", masking the real failure.
    let message = 'Assertion failed';
    if (typeof msg === 'string') {
      message = msg;
    } else if (typeof msg === 'function') {
      message = msg();
    } else if (msg != null) {
      message = String(msg);
    }
    throw new Error(message);
  }
}
|
/**
 * Asserts that two shapes are element-wise equal (via `arraysEqual`).
 *
 * @param shapeA First shape.
 * @param shapeB Second shape.
 * @param errorMessagePrefix Text prepended to the failure message.
 */
export function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
  const shapesAreEqual = arraysEqual(shapeA, shapeB);
  // Lazy message: only formatted when the assertion fails.
  assert(shapesAreEqual, () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
}
|
/**
 * Asserts that `a` is neither `null` nor `undefined`.
 *
 * @param a Value passed to the tensor constructor.
 */
export function assertNonNull(a) {
  const isPresent = a != null;
  assert(isPresent, () => `The input to the tensor constructor must be a non-null value.`);
}
|
151 |
|
152 |
|
153 |
|
154 |
|
155 |
|
156 |
|
157 |
|
158 |
|
159 |
|
160 |
|
161 |
|
162 |
|
163 |
|
164 |
|
165 |
|
166 |
|
167 |
|
168 |
|
169 |
|
/**
 * Recursively flattens nested arrays (and, unless `skipTypedArray` is set,
 * typed arrays) into a single flat array, appending into `result`.
 *
 * @param arr Value or nested array to flatten. A non-array leaf is pushed as-is.
 * @param result Accumulator array to append into; `null` means start fresh.
 * @param skipTypedArray When true, typed arrays are treated as leaves.
 * @returns The accumulator containing all leaves in depth-first order.
 */
export function flatten(arr, result = [], skipTypedArray = false) {
  const out = result == null ? [] : result;
  const shouldDescend =
      Array.isArray(arr) || (isTypedArray(arr) && !skipTypedArray);
  if (!shouldDescend) {
    out.push(arr);
    return out;
  }
  for (let i = 0; i < arr.length; ++i) {
    flatten(arr[i], out, skipTypedArray);
  }
  return out;
}
|
184 |
|
185 |
|
186 |
|
187 |
|
188 |
|
189 |
|
190 |
|
191 |
|
192 |
|
193 |
|
194 |
|
/**
 * Number of elements described by `shape` (the product of its dimensions).
 *
 * @param shape Array of dimension sizes.
 * @returns The product of all dimensions; 1 for an empty (rank-0) shape.
 */
export function sizeFromShape(shape) {
  // A rank-0 shape describes a scalar, which holds one element.
  if (shape.length === 0) {
    return 1;
  }
  // No initial value: the first dimension seeds the product.
  return shape.reduce((product, dim) => product * dim);
}
|
/**
 * Whether `shape` has rank 0.
 *
 * @param shape Array of dimension sizes.
 * @returns True when `shape.length` is exactly 0.
 */
export function isScalarShape(shape) {
  const rank = shape.length;
  return rank === 0;
}
|
/**
 * Shallow element-wise equality using `===`.
 *
 * @param n1 First array-like (may be null/undefined).
 * @param n2 Second array-like (may be null/undefined).
 * @returns True if both are the same reference, or both are non-null with
 *     equal lengths and `===` elements at every index.
 */
export function arraysEqual(n1, n2) {
  if (n1 === n2) {
    return true;
  }
  const comparable = n1 != null && n2 != null && n1.length === n2.length;
  if (!comparable) {
    return false;
  }
  for (let idx = 0; idx < n1.length; idx++) {
    if (n1[idx] !== n2[idx]) {
      return false;
    }
  }
  return true;
}
|
/**
 * Whether `a` has no fractional part (`a % 1 === 0`).
 *
 * Note: uses `%`, so numeric strings such as '3' also pass, while
 * Infinity and NaN do not.
 *
 * @param a Value to test.
 * @returns True if the remainder modulo 1 is exactly zero.
 */
export function isInt(a) {
  const fractional = a % 1;
  return fractional === 0;
}
|
/**
 * Hyperbolic tangent. Delegates to `Math.tanh` when available; otherwise
 * computes it from `Math.exp`.
 *
 * @param x Input value.
 * @returns tanh(x), in [-1, 1] (NaN for NaN input).
 */
export function tanh(x) {
  if (Math.tanh != null) {
    return Math.tanh(x);
  }
  // Fallback for engines without Math.tanh. Evaluate on |x| via exp(-2|x|),
  // which stays in (0, 1]. The naive (e^{2x} - 1) / (e^{2x} + 1) overflows to
  // Infinity / Infinity = NaN for x above ~355. This form also yields +/-1
  // for +/-Infinity without special cases.
  const decay = Math.exp(-2 * Math.abs(x));
  const magnitude = (1 - decay) / (1 + decay);
  return x < 0 ? -magnitude : magnitude;
}
|
/**
 * Picks a near-square 2D shape able to hold `size` elements.
 *
 * @param size Number of elements.
 * @returns [width, height] with width = ceil(sqrt(size)) and
 *     height = ceil(size / width).
 */
export function sizeToSquarishShape(size) {
  const width = Math.ceil(Math.sqrt(size));
  const height = Math.ceil(size / width);
  return [width, height];
}
|
250 |
|
251 |
|
252 |
|
253 |
|
254 |
|
255 |
|
256 |
|
257 |
|
258 |
|
259 |
|
260 |
|
261 |
|
/**
 * Builds the indices 0..n-1 in a `Uint32Array` and shuffles them in place
 * with `shuffle`.
 *
 * @param n Number of indices.
 * @returns A random permutation of 0..n-1.
 */
export function createShuffledIndices(n) {
  // Identity permutation first; typed-array map keeps the Uint32Array type.
  const indices = new Uint32Array(n).map((_, position) => position);
  shuffle(indices);
  return indices;
}
|
/**
 * Pads string `a` with trailing spaces up to length `size`.
 *
 * @param a String to pad.
 * @param size Target length.
 * @returns `a` unchanged if it is already at least `size` long, otherwise `a`
 *     followed by `size - a.length` spaces.
 */
export function rightPad(a, size) {
  const missing = size - a.length;
  if (!(missing > 0)) {
    return a;
  }
  return a + ' '.repeat(missing);
}
|
/**
 * Polls `checkFn` until it returns a truthy value.
 *
 * The first check runs synchronously. After each failed check, the next one
 * is scheduled with `setTimeout` after `delayFn(failedAttempts)` ms.
 *
 * @param checkFn Predicate to poll.
 * @param delayFn Maps the number of failed attempts so far to the delay in ms
 *     before the next attempt. Defaults to 0.
 * @param maxCounter Optional cap on failed attempts; once reached, the
 *     promise rejects.
 * @returns A promise that resolves when `checkFn` passes. It rejects with an
 *     Error when `maxCounter` is reached, and with the thrown value if
 *     `checkFn` or `delayFn` throws.
 */
export function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) {
  return new Promise((resolve, reject) => {
    let tryCount = 0;
    const tryFn = () => {
      // Retries run inside a timer callback, where an escaping exception would
      // be uncaught and leave this promise pending forever. Route it to reject.
      try {
        if (checkFn()) {
          resolve();
          return;
        }
        tryCount++;
        const nextBackoff = delayFn(tryCount);
        if (maxCounter != null && tryCount >= maxCounter) {
          reject(new Error(
              `repeatedTry: check did not pass after ${tryCount} attempts`));
          return;
        }
        setTimeout(tryFn, nextBackoff);
      } catch (err) {
        reject(err);
      }
    };
    tryFn();
  });
}
|
295 |
|
296 |
|
297 |
|
298 |
|
299 |
|
300 |
|
301 |
|
302 |
|
303 |
|
/**
 * Resolves at most one `-1` ("infer this dimension") entry in `shape` so that
 * the product of the dimensions equals `size`.
 *
 * @param shape Dimension sizes; non-negative, except at most one -1.
 * @param size Total number of elements.
 * @returns `shape` itself when it has no -1 entry, otherwise a copy with the
 *     -1 replaced by `size / product(other dims)`.
 * @throws {Error} On multiple -1 entries, other negative entries, a product
 *     mismatch (when `size > 0` and nothing is inferred), a zero product with
 *     an implicit dimension, or a non-integral inferred dimension.
 */
export function inferFromImplicitShape(shape, size) {
  let knownProduct = 1;
  let inferredDim = -1;
  for (let dim = 0; dim < shape.length; ++dim) {
    const extent = shape[dim];
    if (extent === -1) {
      if (inferredDim !== -1) {
        throw new Error(`Shapes can only have 1 implicit size. ` +
            `Found -1 at dim ${inferredDim} and dim ${dim}`);
      }
      inferredDim = dim;
    } else if (extent < 0) {
      throw new Error(`Shapes can not be < 0. Found ${extent} at dim ${dim}`);
    } else if (extent >= 0) {
      knownProduct *= extent;
    }
  }

  // Fully specified shape: only validate (a non-positive size skips the check).
  if (inferredDim === -1) {
    if (size > 0 && size !== knownProduct) {
      throw new Error(`Size(${size}) must match the product of shape ${shape}`);
    }
    return shape;
  }

  if (knownProduct === 0) {
    throw new Error(`Cannot infer the missing size in [${shape}] when ` +
        `there are 0 elements`);
  }
  if (size % knownProduct !== 0) {
    throw new Error(`The implicit shape can't be a fractional number. ` +
        `Got ${size} / ${knownProduct}`);
  }
  const resolved = shape.slice();
  resolved[inferredDim] = size / knownProduct;
  return resolved;
}
|
/**
 * Normalizes an axis argument to an array of non-negative axis indices.
 *
 * @param axis `null`/`undefined` for all axes, a single axis, or an array of
 *     axes. Negative values count from the end.
 * @param shape Shape whose rank bounds the valid axes.
 * @returns Axis indices in [0, rank).
 * @throws If any axis lies outside [-rank, rank) or is not an integer.
 */
export function parseAxisParam(axis, shape) {
  const rank = shape.length;
  // Absent axis means every axis; otherwise wrap a scalar into an array.
  const axes = axis == null ? shape.map((_, i) => i) : [].concat(axis);

  const inRange = axes.every(ax => ax >= -rank && ax < rank);
  assert(inRange, () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
      `got axis ${axes}`);

  const allIntegers = axes.every(ax => isInt(ax));
  assert(allIntegers, () => `All values in axis param must be integers but ` +
      `got axis ${axes}`);

  // Negative axes count from the end.
  return axes.map(ax => (ax < 0 ? rank + ax : ax));
}
|
353 |
|
/**
 * Computes the shape that results from removing size-1 dimensions.
 *
 * @param shape Input dimension sizes.
 * @param axis Optional axes to squeeze. When `null`/`undefined` or an empty
 *     array, every size-1 dimension is removed; otherwise only the listed
 *     axes (negative values allowed) are removed.
 * @returns `{newShape, keptDims}`: the squeezed shape and, for each of its
 *     entries, the index of the originating input dimension.
 * @throws {Error} If a listed axis does not have size 1.
 */
export function squeezeShape(shape, axis) {
  const newShape = [];
  const keptDims = [];
  const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
  // The walk below advances a single cursor `j` through `axes`, so they must
  // be ascending and unique. Array#sort without a comparator compares as
  // strings ([10, 2] stays [10, 2]), which broke tensors of rank > 10.
  // Duplicate axes would leave `j` on a stale entry and wrongly drop later
  // size-1 dims.
  const axes = (axis == null || isEmptyArray) ?
      null :
      Array.from(new Set(parseAxisParam(axis, shape))).sort((a, b) => a - b);
  let j = 0;
  for (let i = 0; i < shape.length; ++i) {
    if (axes != null) {
      if (axes[j] === i && shape[i] !== 1) {
        throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
      }
      // A size-1 dim that was not requested for squeezing is kept.
      if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
        newShape.push(shape[i]);
        keptDims.push(i);
      }
      if (axes[j] <= i) {
        j++;
      }
    }
    if (shape[i] !== 1) {
      newShape.push(shape[i]);
      keptDims.push(i);
    }
  }
  return { newShape, keptDims };
}
|
/**
 * Allocates a zero-filled typed array for a numeric dtype.
 *
 * @param dtype 'float32' (also the default for null/undefined), 'int32' or 'bool'.
 * @param size Number of elements.
 * @returns Float32Array, Int32Array or Uint8Array respectively.
 * @throws {Error} For any other dtype.
 */
export function getTypedArrayFromDType(dtype, size) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
|
/**
 * Allocates backing storage for a dtype: a typed array for numeric dtypes,
 * a plain `Array` for 'string'.
 *
 * @param dtype 'float32' (default for null/undefined), 'int32', 'bool' or 'string'.
 * @param size Number of elements.
 * @returns Float32Array, Int32Array, Uint8Array, or `new Array(size)`.
 * @throws {Error} For any other dtype.
 */
export function getArrayFromDType(dtype, size) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    case 'string':
      return new Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
|
/**
 * Throws on the first value that is NaN or infinite (using the coercing
 * global `isNaN`/`isFinite`).
 *
 * @param vals Array-like of values to check.
 * @param dtype Dtype name, used only in the error message.
 * @throws {Error} If any value is NaN or +/-Infinity.
 */
export function checkConversionForErrors(vals, dtype) {
  const count = vals.length;
  for (let pos = 0; pos < count; pos++) {
    const value = vals[pos];
    const isBad = isNaN(value) || !isFinite(value);
    if (isBad) {
      throw new Error(`A tensor of type ${dtype} being uploaded contains ${value}.`);
    }
  }
}
|
425 |
|
/**
 * Whether `dtype` is one of the supported dtype names.
 *
 * @param dtype Candidate dtype.
 * @returns True for 'bool', 'complex64', 'float32', 'int32' or 'string'.
 */
export function isValidDtype(dtype) {
  switch (dtype) {
    case 'bool':
    case 'complex64':
    case 'float32':
    case 'int32':
    case 'string':
      return true;
    default:
      return false;
  }
}
|
430 |
|
431 |
|
432 |
|
433 |
|
/**
 * Whether casting from `oldType` to `newType` is treated as lossy.
 *
 * Rules, by target type:
 *   complex64: never lossy.
 *   float32: lossy only from complex64.
 *   int32: lossy from float32 or complex64.
 *   bool: lossless only from bool.
 *   anything else: lossy.
 *
 * @param oldType Source dtype.
 * @param newType Target dtype.
 * @returns True if the conversion may lose information.
 */
export function hasEncodingLoss(oldType, newType) {
  switch (newType) {
    case 'complex64':
      return false;
    case 'float32':
      return oldType === 'complex64';
    case 'int32':
      return oldType === 'float32' || oldType === 'complex64';
    case 'bool':
      return oldType !== 'bool';
    default:
      return true;
  }
}
|
/**
 * Whether `a` is one of the typed-array kinds used in this module.
 *
 * Only Float32Array, Int32Array and Uint8Array are recognized.
 *
 * @param a Value to test.
 * @returns True for an instance of one of those three constructors.
 */
export function isTypedArray(a) {
  if (a instanceof Float32Array) {
    return true;
  }
  if (a instanceof Int32Array) {
    return true;
  }
  return a instanceof Uint8Array;
}
|
/**
 * Bytes used to store one element of a numeric dtype.
 *
 * @param dtype 'float32' or 'int32' (4), 'complex64' (8) or 'bool' (1).
 * @returns The per-element byte count.
 * @throws {Error} For any other dtype.
 */
export function bytesPerElement(dtype) {
  switch (dtype) {
    case 'float32':
    case 'int32':
      return 4;
    case 'complex64':
      return 8;
    case 'bool':
      return 1;
    default:
      throw new Error(`Unknown dtype ${dtype}`);
  }
}
|
467 |
|
468 |
|
469 |
|
470 |
|
471 |
|
472 |
|
/**
 * Sums the `length` of every entry of `arr`.
 *
 * NOTE(review): presumably the entries are encoded byte arrays (so `length`
 * is a byte count); for JS strings `length` counts UTF-16 code units, not
 * bytes. Confirm against callers.
 *
 * @param arr Collection with `forEach` (null/undefined yields 0).
 * @returns Total of the entries' `length` values.
 */
export function bytesFromStringArray(arr) {
  let total = 0;
  if (arr != null) {
    arr.forEach(entry => {
      total += entry.length;
    });
  }
  return total;
}
|
481 |
|
/**
 * Whether `value` is a string primitive or a `String` wrapper object.
 *
 * @param value Value to test.
 * @returns True for strings of either kind.
 */
export function isString(value) {
  if (typeof value === 'string') {
    return true;
  }
  return value instanceof String;
}
|
/**
 * Whether `value` is a boolean primitive (wrapper objects are not accepted).
 *
 * @param value Value to test.
 * @returns True for `true` or `false`.
 */
export function isBoolean(value) {
  const kind = typeof value;
  return kind === 'boolean';
}
|
/**
 * Whether `value` is a number primitive (NaN and Infinity included).
 *
 * @param value Value to test.
 * @returns True when `typeof value` is 'number'.
 */
export function isNumber(value) {
  const kind = typeof value;
  return kind === 'number';
}
|
/**
 * Infers a dtype name from a value or a (possibly nested) array of values.
 *
 * For plain arrays only the first element (recursively) is inspected.
 * Int32Array and Uint8Array map to 'int32'; strings map to 'string' and
 * booleans to 'bool'. Everything else, including empty arrays, maps to
 * 'float32'.
 *
 * @param values Value, plain array or typed array.
 * @returns 'float32', 'int32', 'string' or 'bool'.
 */
export function inferDtype(values) {
  if (Array.isArray(values)) {
    return inferDtype(values[0]);
  }
  if (values instanceof Float32Array) {
    return 'float32';
  }
  if (values instanceof Int32Array || values instanceof Uint8Array) {
    return 'int32';
  }
  if (isNumber(values)) {
    return 'float32';
  }
  if (isString(values)) {
    return 'string';
  }
  return isBoolean(values) ? 'bool' : 'float32';
}
|
/**
 * Duck-typed function check: truthy `f` with truthy `constructor`, `call`
 * and `apply` members.
 *
 * Note: any object exposing those three truthy properties also passes.
 *
 * @param f Value to test.
 * @returns True if `f` looks callable by the rule above.
 */
export function isFunction(f) {
  if (!f) {
    return false;
  }
  return Boolean(f.constructor && f.call && f.apply);
}
|
/**
 * Smallest divisor of `size` in [start, size), or `size` if there is none.
 *
 * @param size Number to divide.
 * @param start First candidate divisor.
 * @returns The first `i >= start` with `size % i === 0`, falling back to `size`.
 */
export function nearestDivisor(size, start) {
  let candidate = start;
  while (candidate < size) {
    if (size % candidate === 0) {
      return candidate;
    }
    ++candidate;
  }
  return size;
}
|
/**
 * Strides for every dimension except the last, for a layout where the last
 * dimension is contiguous.
 *
 * @param shape Dimension sizes.
 * @returns Array of length `rank - 1` where entry i is the product of
 *     `shape[i + 1 ..]`; empty for rank < 2.
 */
export function computeStrides(shape) {
  const rank = shape.length;
  if (rank < 2) {
    return [];
  }
  const strides = new Array(rank - 1);
  let running = shape[rank - 1];
  strides[rank - 2] = running;
  for (let dim = rank - 3; dim >= 0; --dim) {
    running = running * shape[dim + 1];
    strides[dim] = running;
  }
  return strides;
}
|
/**
 * Recursively copies a flat buffer into nested plain arrays following `shape`.
 *
 * With `isComplex`, every innermost entry occupies two consecutive values,
 * so innermost rows are twice as long.
 *
 * @param offset Index in `a` at which this sub-array starts.
 * @param shape Non-empty list of dimension sizes for this level and below.
 * @param a Flat source buffer.
 * @param isComplex Whether values are interleaved pairs.
 * @returns Nested arrays holding the copied values.
 */
function createNestedArray(offset, shape, a, isComplex = false) {
  const widen = isComplex ? 2 : 1;
  const nested = [];
  if (shape.length === 1) {
    const rowLength = shape[0] * widen;
    for (let k = 0; k < rowLength; k++) {
      nested[k] = a[offset + k];
    }
    return nested;
  }
  const inner = shape.slice(1);
  // Number of flat values covered by one slice along the leading dimension.
  const blockSize = inner.reduce((acc, c) => acc * c) * widen;
  for (let k = 0; k < shape[0]; k++) {
    nested[k] = createNestedArray(offset + k * blockSize, inner, a, isComplex);
  }
  return nested;
}
|
555 |
|
/**
 * Converts a flat buffer into nested arrays matching `shape`.
 *
 * @param shape Dimension sizes. An empty shape yields the scalar `a[0]`.
 * @param a Flat buffer; its length must equal the product of `shape`
 *     (doubled when `isComplex`).
 * @param isComplex Whether values are interleaved pairs.
 * @returns `a[0]` for rank 0, `[]` when the shape has zero elements,
 *     otherwise nested arrays.
 * @throws {Error} If the buffer length does not match the shape.
 */
export function toNestedArray(shape, a, isComplex = false) {
  if (shape.length === 0) {
    return a[0];
  }
  const expected = shape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
  if (expected === 0) {
    return [];
  }
  if (expected !== a.length) {
    throw new Error(`[${shape}] does not match the input size ${a.length}${isComplex ? ' for a complex tensor' : ''}.`);
  }
  return createNestedArray(0, shape, a, isComplex);
}
|
/**
 * Allocates a typed array (via `makeZerosTypedArray`) with every element set to 1.
 *
 * @param size Number of elements.
 * @param dtype Dtype accepted by `makeZerosTypedArray`.
 * @returns The filled typed array.
 */
export function makeOnesTypedArray(size, dtype) {
  return makeZerosTypedArray(size, dtype).fill(1);
}
|
/**
 * Allocates a zero-filled typed array for a dtype.
 *
 * @param size Number of elements.
 * @param dtype 'float32'/'complex64' (and null/undefined) give Float32Array,
 *     'int32' gives Int32Array, 'bool' gives Uint8Array.
 * @returns The new typed array.
 * @throws {Error} For any other dtype.
 */
export function makeZerosTypedArray(size, dtype) {
  switch (dtype) {
    case undefined:
    case null:
    case 'float32':
    case 'complex64':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
|
592 |
|
593 |
|
594 |
|
595 |
|
596 |
|
/**
 * Builds a zero-filled nested array for `shape` by allocating a flat typed
 * array and passing it to `toNestedArray`.
 *
 * @param shape Dimension sizes (an empty shape gives one element).
 * @param dtype 'float32' (default for null/undefined), 'int32' or 'bool'.
 * @returns Whatever `toNestedArray` returns for the zero buffer.
 * @throws {Error} For any other dtype.
 */
export function makeZerosNestedTypedArray(shape, dtype) {
  const size = shape.reduce((prev, curr) => prev * curr, 1);
  let zeros;
  if (dtype == null || dtype === 'float32') {
    zeros = new Float32Array(size);
  } else if (dtype === 'int32') {
    zeros = new Int32Array(size);
  } else if (dtype === 'bool') {
    zeros = new Uint8Array(size);
  } else {
    throw new Error(`Unknown data type ${dtype}`);
  }
  return toNestedArray(shape, zeros);
}
|
/**
 * Asserts that every dimension in `shape` is an integer >= 0.
 *
 * Note: zero-sized dimensions are accepted even though the message says
 * "positive".
 *
 * @param shape Dimension sizes to validate.
 */
export function assertNonNegativeIntegerDimensions(shape) {
  const isValidDim = (dimSize) => Number.isInteger(dimSize) && dimSize >= 0;
  shape.forEach(dimSize => {
    assert(isValidDim(dimSize), () => `Tensor must have a shape comprised of positive integers but got ` +
        `shape [${shape}].`);
  });
}
|
618 |
|
619 |
|
620 |
|
621 |
|
622 |
|
623 |
|
624 |
|
625 |
|
/**
 * Converts a multi-dimensional location to a flat index.
 *
 * @param locs Coordinates, one per dimension.
 * @param rank Number of dimensions.
 * @param strides Per-dimension strides for all but the last dimension (as
 *     produced by `computeStrides`); the last dimension has stride 1.
 * @returns The flat offset.
 */
export function locToIndex(locs, rank, strides) {
  switch (rank) {
    case 0:
      return 0;
    case 1:
      return locs[0];
  }
  const last = locs.length - 1;
  let flat = locs[last];
  for (let dim = 0; dim < last; ++dim) {
    flat += strides[dim] * locs[dim];
  }
  return flat;
}
|
639 |
|
640 |
|
641 |
|
642 |
|
643 |
|
644 |
|
645 |
|
646 |
|
/**
 * Converts a flat index back to multi-dimensional coordinates (the inverse
 * of `locToIndex`).
 *
 * @param index Flat offset.
 * @param rank Number of dimensions.
 * @param strides Per-dimension strides for all but the last dimension.
 * @returns Array of `rank` coordinates.
 */
export function indexToLoc(index, rank, strides) {
  switch (rank) {
    case 0:
      return [];
    case 1:
      return [index];
  }
  const coords = new Array(rank);
  let remainder = index;
  const last = coords.length - 1;
  for (let dim = 0; dim < last; ++dim) {
    coords[dim] = Math.floor(remainder / strides[dim]);
    remainder -= coords[dim] * strides[dim];
  }
  // Whatever is left is the position along the contiguous last dimension.
  coords[last] = remainder;
  return coords;
}
|
662 |
|
663 |
|
664 |
|
665 |
|
666 |
|
/**
 * Whether `object` is a thenable (has a callable `then` member).
 *
 * Always returns a boolean. Previously, falsy inputs leaked through as
 * `null`/`undefined`/`0`/`''`, and `.then` was read twice, which matters
 * for getters.
 *
 * @param object Value to test.
 * @returns True if `object.then` is a function.
 */
export function isPromise(object) {
  return object != null && typeof object.then === 'function';
}
|
676 |
|
\ | No newline at end of file |