UNPKG

20.3 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2020 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
/**
 * Shuffles the array in-place using the Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1, 2, 3, 4, 5];
 * tf.util.shuffle(a);
 * console.log(a);
 * ```
 *
 * @param array The array (or typed array) to shuffle in-place.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
// tslint:disable-next-line:no-any
export function shuffle(array) {
  // Walk from the back towards the front. At each step swap the last
  // not-yet-fixed slot with a uniformly chosen slot at or before it.
  for (let last = array.length - 1; last >= 0; last--) {
    const pick = (Math.random() * (last + 1)) | 0;
    const held = array[last];
    array[last] = array[pick];
    array[pick] = held;
  }
}
/**
 * Shuffles two arrays in-place with the same permutation, using the
 * Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1,2,3,4,5];
 * const b = [11,22,33,44,55];
 * tf.util.shuffleCombo(a, b);
 * console.log(a, b);
 * ```
 *
 * @param array The first array to shuffle in-place.
 * @param array2 The second array to shuffle in-place with the same permutation
 *     as the first array.
 * @throws Error if the two arrays have different lengths.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
export function shuffleCombo(array, array2) {
  if (array.length !== array2.length) {
    // Separators between the fragments keep the message readable.
    throw new Error(
        `Array sizes must match to be shuffled together. ` +
        `First array length was ${array.length}. ` +
        `Second array length was ${array2.length}.`);
  }
  let counter = array.length;
  while (counter > 0) {
    // Pick a random index in [0, counter).
    const index = (Math.random() * counter) | 0;
    counter--;
    // Apply the identical swap to both arrays so their elements stay paired.
    const temp = array[counter];
    const temp2 = array2[counter];
    array[counter] = array[index];
    array2[counter] = array2[index];
    array[index] = temp;
    array2[index] = temp2;
  }
}
/**
 * Clamps `x` into the inclusive range [min, max]. If `min > max`, the result
 * is `min`.
 */
export function clamp(min, x, max) {
  const cappedAbove = Math.min(x, max);
  return Math.max(min, cappedAbove);
}
/** Returns `val` if it is even, otherwise `val + 1`. */
export function nearestLargerEven(val) {
  if (val % 2 === 0) {
    return val;
  }
  return val + 1;
}
/**
 * Sums the elements of an array-like (plain array or typed array).
 * Returns 0 for an empty input.
 */
export function sum(arr) {
  let total = 0;
  let idx = 0;
  while (idx < arr.length) {
    total += arr[idx++];
  }
  return total;
}
/**
 * Returns a sample from a uniform [a, b) distribution.
 *
 * @param a The minimum support (inclusive).
 * @param b The maximum support (exclusive).
 * @return A pseudorandom number on the half-open interval [a, b).
 */
export function randUniform(a, b) {
  const r = Math.random();
  // Linear interpolation between a and b, weighted by r in [0, 1).
  const towardB = b * r;
  const towardA = (1 - r) * a;
  return towardB + towardA;
}
/**
 * Returns the squared Euclidean distance between two vectors.
 *
 * Iterates over `a.length` elements and converts each element with `Number()`
 * before subtracting. NOTE(review): if `b` is shorter than `a`, the result is
 * NaN.
 */
export function distSquared(a, b) {
  let total = 0;
  const len = a.length;
  for (let k = 0; k < len; k++) {
    const delta = Number(a[k]) - Number(b[k]);
    total += delta * delta;
  }
  return total;
}
/**
 * Asserts that the expression is true. Otherwise throws an error with the
 * provided message.
 *
 * ```js
 * const x = 2;
 * tf.util.assert(x === 2, 'x is not 2');
 * ```
 *
 * @param expr The expression to assert (as a boolean).
 * @param msg Either the message string, or a function returning it. The
 *     function form lets callers avoid building the message unless the
 *     assertion actually fails.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
export function assert(expr, msg) {
  if (expr) {
    return;
  }
  const message = typeof msg === 'string' ? msg : msg();
  throw new Error(message);
}
/**
 * Throws (via `assert`) unless `shapeA` and `shapeB` are element-wise equal.
 * `errorMessagePrefix` is prepended to the error message.
 */
export function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
  const describeMismatch = () =>
      errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`;
  assert(arraysEqual(shapeA, shapeB), describeMismatch);
}
/** Throws (via `assert`) if `a` is null or undefined. */
export function assertNonNull(a) {
  const isPresent = a != null;
  assert(isPresent, () => `The input to the tensor constructor must be a non-null value.`);
}
// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
 * Flattens an arbitrarily nested array.
 *
 * ```js
 * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
 * const flat = tf.util.flatten(a);
 * console.log(flat);
 * ```
 *
 * @param arr The nested array to flatten.
 * @param result The destination array that receives the leaf elements. A
 *     null value is treated like an omitted argument (a new array is used).
 * @param skipTypedArray If true, typed arrays are treated as leaves instead of
 *     being flattened. Defaults to false.
 * @return The destination array.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
export function flatten(arr, result = [], skipTypedArray = false) {
  const out = result == null ? [] : result;
  const shouldDescend =
      Array.isArray(arr) || (isTypedArray(arr) && !skipTypedArray);
  if (!shouldDescend) {
    out.push(arr);
    return out;
  }
  for (let i = 0; i < arr.length; ++i) {
    flatten(arr[i], out, skipTypedArray);
  }
  return out;
}
/**
 * Returns the size (number of elements) of the tensor given its shape. An
 * empty shape (a scalar) has size 1.
 *
 * ```js
 * const shape = [3, 4, 2];
 * const size = tf.util.sizeFromShape(shape);
 * console.log(size);
 * ```
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
export function sizeFromShape(shape) {
  let product = shape.length === 0 ? 1 : shape[0];
  for (let dim = 1; dim < shape.length; dim++) {
    product *= shape[dim];
  }
  return product;
}
/** Returns true when `shape` has rank 0 (i.e. describes a scalar). */
export function isScalarShape(shape) {
  const rank = shape.length;
  return rank === 0;
}
/**
 * Element-wise strict (===) comparison of two array-likes.
 * The same reference, including null === null, counts as equal. A
 * null/undefined value is never equal to anything else.
 */
export function arraysEqual(n1, n2) {
  if (n1 === n2) {
    return true;
  }
  if (n1 == null || n2 == null || n1.length !== n2.length) {
    return false;
  }
  for (let k = 0; k < n1.length; k++) {
    if (n1[k] !== n2[k]) {
      return false;
    }
  }
  return true;
}
/** Returns true when `a` has no fractional part. */
export function isInt(a) {
  const fractionalPart = a % 1;
  return fractionalPart === 0;
}
/**
 * Hyperbolic tangent. Uses `Math.tanh` when the runtime provides it, and
 * otherwise evaluates (e^{2x} - 1) / (e^{2x} + 1).
 */
export function tanh(x) {
  // tslint:disable-next-line:no-any
  if (Math.tanh != null) {
    // tslint:disable-next-line:no-any
    return Math.tanh(x);
  }
  switch (x) {
    // The closed-form formula gives NaN for +Infinity (Inf / Inf), so the
    // infinite limits are returned explicitly.
    case Infinity:
      return 1;
    case -Infinity:
      return -1;
    default: {
      const e2x = Math.exp(2 * x);
      return (e2x - 1) / (e2x + 1);
    }
  }
}
/**
 * Returns a near-square 2D shape [width, height] able to hold `size`
 * elements: width = ceil(sqrt(size)), height = ceil(size / width).
 */
export function sizeToSquarishShape(size) {
  const side = Math.ceil(Math.sqrt(size));
  const rows = Math.ceil(size / side);
  return [side, rows];
}
/**
 * Creates a new array containing the indices 0..n-1 in random order.
 *
 * ```js
 * const randomTen = tf.util.createShuffledIndices(10);
 * console.log(randomTen);
 * ```
 *
 * @param n Number of shuffled indices to create.
 * @return A `Uint32Array` of length `n` holding a random permutation of
 *     0..n-1.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
export function createShuffledIndices(n) {
  const indices = new Uint32Array(n);
  let i = 0;
  while (i < n) {
    indices[i] = i;
    i++;
  }
  shuffle(indices);
  return indices;
}
/**
 * Pads string `a` on the right with spaces up to length `size`. The input is
 * returned unchanged when it is already at least `size` long.
 */
export function rightPad(a, size) {
  return size <= a.length ? a : a + ' '.repeat(size - a.length);
}
/**
 * Polls `checkFn` until it returns a truthy value.
 *
 * The first check runs synchronously. After the k-th failed check, the next
 * one is scheduled with `setTimeout` after `delayFn(k)` milliseconds.
 *
 * @param checkFn Predicate that is polled until it returns a truthy value.
 * @param delayFn Maps the number of failed attempts so far to the delay (ms)
 *     before the next attempt. Defaults to no delay.
 * @param maxCounter Optional maximum number of failed attempts. When it is
 *     reached, the returned promise rejects.
 * @return A promise that resolves (with `undefined`) once `checkFn` passes.
 *     It rejects with an `Error` after `maxCounter` failed attempts, or with
 *     whatever `checkFn`/`delayFn` throws.
 */
export function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) {
  return new Promise((resolve, reject) => {
    let tryCount = 0;
    const tryFn = () => {
      // Later attempts run inside a timer callback. Without this try/catch,
      // an exception there would escape as an uncaught error and leave the
      // promise pending forever.
      try {
        if (checkFn()) {
          resolve();
          return;
        }
        tryCount++;
        const nextBackoff = delayFn(tryCount);
        if (maxCounter != null && tryCount >= maxCounter) {
          // Reject with an Error (not `undefined`) so callers get a message
          // and a stack trace.
          reject(new Error(
              `Condition was not met after ${tryCount} attempt(s) ` +
              `(maxCounter: ${maxCounter}).`));
          return;
        }
        setTimeout(tryFn, nextBackoff);
      } catch (err) {
        reject(err);
      }
    };
    tryFn();
  });
}
/**
 * Given the full size of the array and a shape that may contain -1 as the
 * implicit dimension, returns the inferred shape where -1 is replaced.
 * E.g. for shape=[2, -1, 3] and size=24, it returns [2, 4, 3].
 *
 * @param shape The shape, which may contain -1 in at most one dimension.
 * @param size The full size (number of elements) of the array.
 * @return The inferred shape with -1 replaced by the inferred size. The input
 *     array itself is returned when it contains no -1, and a copy otherwise.
 * @throws Error for more than one -1, for other negative dims, for a size that
 *     does not match, or for a size that cannot be divided evenly.
 */
export function inferFromImplicitShape(shape, size) {
  let knownProduct = 1;
  let implicitDim = -1;
  shape.forEach((dim, i) => {
    if (dim >= 0) {
      knownProduct *= dim;
      return;
    }
    if (dim === -1) {
      if (implicitDim !== -1) {
        throw Error(`Shapes can only have 1 implicit size. ` +
            `Found -1 at dim ${implicitDim} and dim ${i}`);
      }
      implicitDim = i;
      return;
    }
    if (dim < 0) {
      throw Error(`Shapes can not be < 0. Found ${dim} at dim ${i}`);
    }
  });
  if (implicitDim === -1) {
    // Nothing to infer. Validate the size only when it is positive.
    if (size > 0 && size !== knownProduct) {
      throw Error(`Size(${size}) must match the product of shape ${shape}`);
    }
    return shape;
  }
  if (knownProduct === 0) {
    throw Error(`Cannot infer the missing size in [${shape}] when ` +
        `there are 0 elements`);
  }
  if (size % knownProduct !== 0) {
    throw Error(`The implicit shape can't be a fractional number. ` +
        `Got ${size} / ${knownProduct}`);
  }
  const inferred = shape.slice();
  inferred[implicitDim] = size / knownProduct;
  return inferred;
}
/**
 * Normalizes an axis argument into an array of non-negative axis indices.
 *
 * @param axis A single axis, an array of axes, or null/undefined (meaning
 *     every axis of `shape`). Negative values count from the end.
 * @param shape The shape whose rank bounds the valid axes.
 * @return The axes, with negative values converted to `rank + axis`.
 * @throws Error (via `assert`) if any axis is outside [-rank, rank) or is not
 *     an integer.
 */
export function parseAxisParam(axis, shape) {
  const rank = shape.length;
  // null/undefined selects every axis; a single value is wrapped in an array.
  const axes = axis == null ? shape.map((s, i) => i) : [].concat(axis);
  const allInRange = axes.every(ax => ax >= -rank && ax < rank);
  assert(allInRange, () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
      `got axis ${axes}`);
  const allIntegers = axes.every(ax => isInt(ax));
  assert(allIntegers, () => `All values in axis param must be integers but ` +
      `got axis ${axes}`);
  return axes.map(ax => (ax < 0 ? rank + ax : ax));
}
/**
 * Reduces the shape by removing dimensions of size 1.
 *
 * With `axis` null/undefined or an empty array, every size-1 dimension is
 * removed. Otherwise only the listed axes are removed. Negative values are
 * normalized via `parseAxisParam`, and each listed axis must have size 1.
 *
 * @param shape The input shape.
 * @param axis Optional axis or axes to squeeze.
 * @return `{newShape, keptDims}`: the squeezed shape, and the indices of the
 *     input dimensions that were kept.
 * @throws Error if a requested axis does not have size 1 (and whatever
 *     `parseAxisParam` throws for invalid axes).
 */
export function squeezeShape(shape, axis) {
  const newShape = [];
  const keptDims = [];
  const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
  // A Set gives order-independent membership. The previous `.sort()` had no
  // comparator, so it sorted numbers lexicographically (e.g. [10, 2]). That
  // broke the index-walk for rank >= 11, and duplicate axes squeezed an
  // unrequested neighbouring dimension.
  const axes = (axis == null || isEmptyArray) ?
      null :
      new Set(parseAxisParam(axis, shape));
  for (let i = 0; i < shape.length; ++i) {
    if (axes == null) {
      // Implicit mode: drop every size-1 dimension.
      if (shape[i] !== 1) {
        newShape.push(shape[i]);
        keptDims.push(i);
      }
      continue;
    }
    if (axes.has(i)) {
      if (shape[i] !== 1) {
        throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
      }
      continue;
    }
    // Dimensions that were not requested are always kept, even if size 1.
    newShape.push(shape[i]);
    keptDims.push(i);
  }
  return { newShape, keptDims };
}
/**
 * Allocates a zero-filled typed array of length `size` for a numeric dtype.
 * A null/undefined dtype is treated as 'float32'.
 *
 * @throws Error for any other dtype (including 'string' and 'complex64').
 */
export function getTypedArrayFromDType(dtype, size) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Allocates backing storage of length `size` for `dtype`: a typed array for
 * numeric dtypes (null/undefined means 'float32'), or a plain `Array` for
 * 'string'.
 *
 * @throws Error for any other dtype (including 'complex64').
 */
export function getArrayFromDType(dtype, size) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    case 'string':
      return new Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Throws if any element of `vals` is NaN or +/-Infinity.
 *
 * @param vals Array-like of values about to be uploaded.
 * @param dtype Used only in the error message.
 */
export function checkConversionForErrors(vals, dtype) {
  const count = vals.length;
  for (let idx = 0; idx < count; ++idx) {
    const value = vals[idx];
    const invalid = isNaN(value) || !isFinite(value);
    if (invalid) {
      throw Error(`A tensor of type ${dtype} being uploaded contains ${value}.`);
    }
  }
}
/**
 * Returns true if `dtype` is one of 'bool', 'complex64', 'float32', 'int32'
 * or 'string'.
 */
export function isValidDtype(dtype) {
  switch (dtype) {
    case 'bool':
    case 'complex64':
    case 'float32':
    case 'int32':
    case 'string':
      return true;
    default:
      return false;
  }
}
/**
 * Returns true if the new type can't encode the old type without loss of
 * precision.
 *
 * - 'complex64' can hold anything.
 * - 'float32' loses information only when coming from 'complex64'.
 * - 'int32' loses information when coming from 'float32' or 'complex64'.
 * - 'bool' can only hold 'bool'.
 * - Any other target type is reported as lossy.
 */
export function hasEncodingLoss(oldType, newType) {
  switch (newType) {
    case 'complex64':
      return false;
    case 'float32':
      return oldType === 'complex64';
    case 'int32':
      return oldType === 'float32' || oldType === 'complex64';
    case 'bool':
      return oldType !== 'bool';
    default:
      return true;
  }
}
/**
 * Returns true if `a` is a Float32Array, Int32Array or Uint8Array (the typed
 * arrays used as tensor storage here). Other typed arrays return false.
 */
export function isTypedArray(a) {
  const storageTypes = [Float32Array, Int32Array, Uint8Array];
  return storageTypes.some(ctor => a instanceof ctor);
}
/**
 * Returns the storage size in bytes of one element of `dtype`.
 *
 * @throws Error for dtypes without a fixed element size (e.g. 'string').
 */
export function bytesPerElement(dtype) {
  switch (dtype) {
    case 'float32':
    case 'int32':
      return 4;
    case 'complex64':
      return 8;
    case 'bool':
      return 1;
    default:
      throw new Error(`Unknown dtype ${dtype}`);
  }
}
/**
 * Returns the sum of the `.length` of every element of `arr`, or 0 when `arr`
 * is null/undefined. This is used as an approximate byte count.
 *
 * NOTE(review): if the elements are encoded byte arrays (e.g. Uint8Array),
 * this is their exact byte size. If they are JS strings, `.length` counts
 * UTF-16 code units rather than bytes. Confirm the element type with callers.
 */
export function bytesFromStringArray(arr) {
  let total = 0;
  if (arr != null) {
    arr.forEach((entry) => {
      total += entry.length;
    });
  }
  return total;
}
/** Returns true if the value is a string primitive or a `String` object. */
export function isString(value) {
  if (typeof value === 'string') {
    return true;
  }
  return value instanceof String;
}
/** Returns true only for boolean primitives. */
export function isBoolean(value) {
  const kind = typeof value;
  return kind === 'boolean';
}
/** Returns true only for number primitives (including NaN and Infinity). */
export function isNumber(value) {
  const kind = typeof value;
  return kind === 'number';
}
/**
 * Infers a dtype from a value or a (nested) array of values.
 *
 * For plain arrays, only the first element is inspected (recursively).
 * Float32Array maps to 'float32'. Int32Array and Uint8Array map to 'int32'.
 * Numbers map to 'float32', strings to 'string' and booleans to 'bool'.
 * Anything else (including an empty array) falls back to 'float32'.
 */
export function inferDtype(values) {
  if (Array.isArray(values)) {
    return inferDtype(values[0]);
  }
  if (values instanceof Float32Array) {
    return 'float32';
  }
  if (values instanceof Int32Array || values instanceof Uint8Array) {
    return 'int32';
  }
  if (isNumber(values)) {
    return 'float32';
  }
  if (isString(values)) {
    return 'string';
  }
  if (isBoolean(values)) {
    return 'bool';
  }
  return 'float32';
}
/**
 * Duck-typed function check: true when `f` is truthy and has truthy
 * `constructor`, `call` and `apply` properties.
 */
export function isFunction(f) {
  return Boolean(f && f.constructor && f.call && f.apply);
}
/**
 * Returns the smallest divisor of `size` in [start, size). Returns `size`
 * itself when there is none.
 */
export function nearestDivisor(size, start) {
  let candidate = start;
  while (candidate < size) {
    if (size % candidate === 0) {
      return candidate;
    }
    candidate++;
  }
  return size;
}
/**
 * Computes row-major strides for `shape`, omitting the innermost dimension.
 * Its stride is always 1 and is left implicit, so a rank-D shape yields D-1
 * strides. Shapes of rank < 2 return [].
 */
export function computeStrides(shape) {
  const rank = shape.length;
  if (rank < 2) {
    return [];
  }
  const strides = new Array(rank - 1);
  let running = shape[rank - 1];
  strides[rank - 2] = running;
  for (let i = rank - 3; i >= 0; i--) {
    running *= shape[i + 1];
    strides[i] = running;
  }
  return strides;
}
/**
 * Recursively builds a nested plain-JS array of the given `shape` from the
 * flat array-like `a`, starting at `offset`.
 *
 * When `isComplex` is true, each logical element occupies two consecutive
 * slots of `a`. The innermost arrays then hold 2 * shape[last] values.
 */
function createNestedArray(offset, shape, a, isComplex = false) {
  const factor = isComplex ? 2 : 1;
  const nested = [];
  if (shape.length === 1) {
    // Innermost dimension: copy a contiguous run of values.
    const count = shape[0] * factor;
    for (let k = 0; k < count; k++) {
      nested[k] = a[offset + k];
    }
    return nested;
  }
  const rest = shape.slice(1);
  // Number of flat slots covered by one entry of the outermost dimension.
  const span = rest.reduce((acc, c) => acc * c) * factor;
  for (let k = 0; k < shape[0]; k++) {
    nested[k] = createNestedArray(offset + k * span, rest, a, isComplex);
  }
  return nested;
}
/**
 * Converts the flat array-like `a` into a nested plain-JS array of `shape`.
 *
 * A rank-0 shape returns the single value `a[0]`. A shape with zero elements
 * returns []. With `isComplex`, `a` must hold two values per element.
 *
 * @throws Error if the element count implied by `shape` differs from
 *     `a.length`.
 */
export function toNestedArray(shape, a, isComplex = false) {
  if (shape.length === 0) {
    return a[0];
  }
  const expectedLength = shape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
  if (expectedLength === 0) {
    return [];
  }
  if (expectedLength !== a.length) {
    const suffix = isComplex ? ' for a complex tensor' : '';
    throw new Error(`[${shape}] does not match the input size ${a.length}${suffix}.`);
  }
  return createNestedArray(0, shape, a, isComplex);
}
/**
 * Returns a typed array of length `size` for `dtype`, filled with ones.
 * Storage selection and dtype validation are delegated to
 * `makeZerosTypedArray`.
 */
export function makeOnesTypedArray(size, dtype) {
  return makeZerosTypedArray(size, dtype).fill(1);
}
/**
 * Returns a zero-filled typed array of length `size` for `dtype`.
 * A null/undefined dtype means 'float32', and 'complex64' is also backed by
 * a Float32Array.
 *
 * @throws Error for unsupported dtypes (e.g. 'string').
 */
export function makeZerosTypedArray(size, dtype) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
    case 'complex64':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Makes a nested plain-JS array of `shape` filled with zeros.
 *
 * A zero-filled typed array matching `dtype` (null/undefined means 'float32')
 * is allocated and then reshaped with `toNestedArray`.
 *
 * @param shape The shape information for the nested array.
 * @param dtype dtype of the array element: 'float32', 'int32' or 'bool'.
 * @throws Error for any other dtype.
 */
export function makeZerosNestedTypedArray(shape, dtype) {
  const size = shape.reduce((prev, curr) => prev * curr, 1);
  let flat;
  if (dtype == null || dtype === 'float32') {
    flat = new Float32Array(size);
  } else if (dtype === 'int32') {
    flat = new Int32Array(size);
  } else if (dtype === 'bool') {
    flat = new Uint8Array(size);
  } else {
    throw new Error(`Unknown data type ${dtype}`);
  }
  return toNestedArray(shape, flat);
}
/**
 * Throws (via `assert`) unless every dimension of `shape` is an integer >= 0.
 * Zero-sized dimensions are allowed, so the message says "non-negative"
 * rather than the previous, misleading "positive".
 */
export function assertNonNegativeIntegerDimensions(shape) {
  shape.forEach(dimSize => {
    assert(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of non-negative integers but got ` +
        `shape [${shape}].`);
  });
}
/**
 * Computes the flat index for a given location (multidimensional index) in a
 * tensor/multidimensional array.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides, excluding the implicit innermost stride of 1.
 * @return The flat (row-major) index.
 */
export function locToIndex(locs, rank, strides) {
  if (rank === 0) {
    return 0;
  }
  if (rank === 1) {
    return locs[0];
  }
  const lastAxis = locs.length - 1;
  let flat = locs[lastAxis];
  let axis = 0;
  while (axis < lastAxis) {
    flat += strides[axis] * locs[axis];
    axis++;
  }
  return flat;
}
/**
 * Computes the location (multidimensional index) in a tensor/multidimensional
 * array for a given flat index. This is the inverse of `locToIndex`.
 *
 * @param index Index in the flat array.
 * @param rank Rank of the tensor.
 * @param strides Strides of the tensor, excluding the implicit innermost
 *     stride of 1.
 * @return An array of `rank` coordinates.
 */
export function indexToLoc(index, rank, strides) {
  if (rank === 0) {
    return [];
  }
  if (rank === 1) {
    return [index];
  }
  const locs = new Array(rank);
  const lastAxis = locs.length - 1;
  let remaining = index;
  for (let axis = 0; axis < lastAxis; axis++) {
    const coord = Math.floor(remaining / strides[axis]);
    locs[axis] = coord;
    remaining -= coord * strides[axis];
  }
  locs[lastAxis] = remaining;
  return locs;
}
/**
 * Returns true if `object` looks like a Promise, meaning it is non-null and
 * has a callable `then` property. Always returns a boolean.
 *
 * @param object Any value.
 */
// tslint:disable-next-line: no-any
export function isPromise(object) {
  // We chose to not use 'obj instanceOf Promise' for two reasons:
  // 1. It only reliably works for es6 Promise, not other Promise
  //    implementations.
  // 2. It doesn't work with framework that uses zone.js. zone.js monkey patch
  //    the async calls, so it is possible the obj (patched) is comparing to a
  //    pre-patched Promise.
  // The previous `object && object.then && ...` returned null/undefined/0
  // for falsy inputs instead of `false`, and read `then` twice.
  return object != null && typeof object.then === 'function';
}
676//# sourceMappingURL=util_base.js.map
\No newline at end of file