UNPKG

11.5 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as util from '../util';
/**
 * Returns true if `axes` names exactly the inner-most (trailing) dimensions of
 * an array of rank `rank`, listed in ascending order.
 *
 * For example, with rank 4: `[3]` and `[2, 3]` return true, while `[1, 3]` and
 * `[3, 2]` return false. An empty `axes` list returns true.
 *
 * @param {number[]} axes Axis indices to check.
 * @param {number} rank Rank of the array.
 * @returns {boolean}
 */
export function axesAreInnerMostDims(axes, rank) {
  // Walk `axes` from its end: the last entry must be rank-1, the one before it
  // rank-2, and so on.
  for (let i = 0; i < axes.length; ++i) {
    if (axes[axes.length - i - 1] !== rank - 1 - i) {
      return false;
    }
  }
  return true;
}
/**
 * Interleaves an output location and a reduction location into a single
 * location of rank `outputLoc.length + reduceLoc.length`.
 *
 * Dimensions whose index appears in `axes` are filled, in order, from
 * `reduceLoc`. All other dimensions are filled, in order, from `outputLoc`.
 *
 * @param {number[]} outputLoc Values for the dimensions not listed in `axes`.
 * @param {number[]} reduceLoc Values for the dimensions listed in `axes`.
 * @param {number[]} axes Indices (in the combined space) taken from reduceLoc.
 * @returns {number[]} The combined location.
 */
export function combineLocations(outputLoc, reduceLoc, axes) {
  const rank = outputLoc.length + reduceLoc.length;
  // Build the membership set once instead of scanning `axes` for every dim.
  const reducedAxes = new Set(axes);
  const loc = [];
  let outIdx = 0;
  let reduceIdx = 0;
  for (let dim = 0; dim < rank; dim++) {
    if (reducedAxes.has(dim)) {
      loc.push(reduceLoc[reduceIdx++]);
    } else {
      loc.push(outputLoc[outIdx++]);
    }
  }
  return loc;
}
/**
 * Splits `aShape` into the shape left after removing `axes` and the shape of
 * the removed (reduced) dimensions.
 *
 * `outShape` keeps the dimensions not in `axes`, in their original order.
 * `reduceShape` lists `aShape[axis]` for each entry of `axes`, in the order
 * `axes` gives them.
 *
 * @param {number[]} aShape Shape of the input.
 * @param {number[]} axes Axes being reduced.
 * @returns {[number[], number[]]} `[outShape, reduceShape]`.
 */
export function computeOutAndReduceShapes(aShape, axes) {
  const reducedAxes = new Set(axes);
  const outShape = [];
  const rank = aShape.length;
  for (let dim = 0; dim < rank; dim++) {
    if (!reducedAxes.has(dim)) {
      outShape.push(aShape[dim]);
    }
  }
  const reduceShape = axes.map(dim => aShape[dim]);
  return [outShape, reduceShape];
}
/**
 * Re-inserts size-1 dimensions at the positions listed in `axes` into `shape`.
 *
 * The name suggests this is the shape a reduction over `axes` would have if
 * the reduced dims were kept (keepDims); confirm against callers.
 *
 * @param {number[]} shape Shape with the `axes` dimensions removed.
 * @param {number[]} axes Positions (in the expanded shape) to fill with 1.
 * @returns {number[]} The expanded shape.
 */
export function expandShapeToKeepDim(shape, axes) {
  // Every re-inserted axis has size 1.
  const reduceSubShape = axes.map(() => 1);
  return combineLocations(shape, reduceSubShape, axes);
}
/**
 * Asserts, via `util.assert`, that `axes` are the inner-most dimensions of a
 * rank-`rank` input, as decided by `axesAreInnerMostDims`.
 *
 * @param {string} msg Prefix for the failure message (e.g. the op name).
 * @param {number[]} axes Axes requested by the caller.
 * @param {number} rank Rank of the input.
 */
export function assertAxesAreInnerMostDims(msg, axes, rank) {
  util.assert(
      axesAreInnerMostDims(axes, rank),
      // The message is passed as a thunk; presumably util.assert only builds
      // it on failure.
      () => `${msg} supports only inner-most axes for now. ` +
          `Got axes ${axes} and rank-${rank} input.`);
}
/**
 * Returns the axes permutation to be used with `tf.transpose`, if such a
 * permutation is necessary. Otherwise it returns null. This method is used by
 * operations that operate only on inner-most axes.
 *
 * When a permutation is needed, it lists every axis not in `axes` in ascending
 * order, followed by `axes` in the order given. After transposing, `axes`
 * therefore become the trailing dimensions.
 *
 * @param {number[]} axes Axes that must end up inner-most.
 * @param {number} rank Rank of the input.
 * @returns {number[]|null} The permutation, or null if none is needed.
 */
export function getAxesPermutation(axes, rank) {
  if (axesAreInnerMostDims(axes, rank)) {
    return null;
  }
  const movedAxes = new Set(axes);
  const result = [];
  for (let i = 0; i < rank; ++i) {
    if (!movedAxes.has(i)) {
      result.push(i);
    }
  }
  axes.forEach(axis => result.push(axis));
  return result;
}
/**
 * Returns the axes permutation that undoes the original permutation `axes`.
 * For a valid permutation this is its inverse: `result[axes[i]] === i`.
 *
 * @param {number[]} axes Expected to be a permutation of 0..axes.length-1.
 * @returns {number[]} The inverse permutation.
 */
export function getUndoAxesPermutation(axes) {
  // Pair each axis with its position, order the pairs by axis, then read the
  // positions back out. Slot j then holds the position i where axes[i] === j.
  return axes.map((axis, i) => [i, axis])
      .sort((a, b) => a[1] - b[1])
      .map(x => x[0]);
}
/**
 * Returns the indices of the last `numAxes` dimensions of a rank-`rank` array,
 * in ascending order: `[rank - numAxes, ..., rank - 1]`.
 *
 * NOTE(review): if `numAxes > rank` the result starts with negative indices.
 * Callers presumably guarantee `numAxes <= rank`; confirm.
 *
 * @param {number} numAxes How many trailing axes to return.
 * @param {number} rank Rank of the array.
 * @returns {number[]}
 */
export function getInnerMostAxes(numAxes, rank) {
  const res = [];
  for (let i = rank - numAxes; i < rank; ++i) {
    res.push(i);
  }
  return res;
}
95//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYXhpc191dGlsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYXhpc191dGlsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sS0FBSyxJQUFJLE1BQU0sU0FBUyxDQUFDO0FBRWhDOzs7R0FHRztBQUNILE1BQU0sVUFBVSxvQkFBb0IsQ0FBQyxJQUFjLEVBQUUsSUFBWTtJQUMvRCxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxDQUFDLE1BQU0sRUFBRSxFQUFFLENBQUMsRUFBRTtRQUNwQyxJQUFJLElBQUksQ0FBQyxJQUFJLENBQUMsTUFBTSxHQUFHLENBQUMsR0FBRyxDQUFDLENBQUMsS0FBSyxJQUFJLEdBQUcsQ0FBQyxHQUFHLENBQUMsRUFBRTtZQUM5QyxPQUFPLEtBQUssQ0FBQztTQUNkO0tBQ0Y7SUFDRCxPQUFPLElBQUksQ0FBQztBQUNkLENBQUM7QUFFRCxNQUFNLFVBQVUsZ0JBQWdCLENBQzVCLFNBQW1CLEVBQUUsU0FBbUIsRUFBRSxJQUFjO0lBQzFELE1BQU0sSUFBSSxHQUFHLFNBQVMsQ0FBQyxNQUFNLEdBQUcsU0FBUyxDQUFDLE1BQU0sQ0FBQztJQUNqRCxNQUFNLEdBQUcsR0FBRyxFQUFFLENBQUM7SUFDZixJQUFJLE1BQU0sR0FBRyxDQUFDLENBQUM7SUFDZixJQUFJLFNBQVMsR0FBRyxDQUFDLENBQUM7SUFDaEIsS0FBSyxJQUFJLEdBQUcsR0FBRyxDQUFDLEVBQUUsR0FBRyxHQUFHLElBQUksRUFBRSxHQUFHLEVBQUUsRUFBRTtRQUNyQyxJQUFJLElBQUksQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLEVBQUU7WUFDNUIsR0FBRyxDQUFDLElBQUksQ0FBQyxTQUFTLENBQUMsTUFBTSxFQUFFLENBQUMsQ0FBQyxDQUFDO1NBQy9CO2FBQU07WUFDTCxHQUFHLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxDQUFDLENBQUM7U0FDbEM7S0FDRjtJQUNELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQztBQUVELE1BQU0sVUFBVSx5QkFBeUIsQ0FDckMsTUFBZ0IsRUFBRSxJQUFjO0lBQ2xDLE1BQU0sUUFBUSxHQUFHLEVBQUUsQ0FBQztJQUNwQixNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsTUFBTSxDQUFDO0lBQzNCLEtBQUssSUFBSSxHQUFHLEdBQUcsQ0FBQyxFQUFFLEdBQUcsR0FBRyxJQUFJLEVBQUUsR0FBRyxFQUFFLEVBQUU7UUFDbkMsSUFBSSxJQUFJLENBQUMsT0FBTyxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFO1lBQzVCLFFBQVEsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7U0FDNUI7S0FDRjtJQUNELE1BQU0sV0FBVyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxNQUFNLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztJQUNqRCxPQUFPLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxDQUFDO0FBQ2pDLENBQUM7QUFF
RCxNQUFNLFVBQVUsb0JBQW9CLENBQ2hDLEtBQWUsRUFBRSxJQUFjO0lBQ2pDLE1BQU0sY0FBYyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN4QyxPQUFPLGdCQUFnQixDQUFDLEtBQUssRUFBRSxjQUFjLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdkQsQ0FBQztBQUVELE1BQU0sVUFBVSwwQkFBMEIsQ0FDdEMsR0FBVyxFQUFFLElBQWMsRUFBRSxJQUFZO0lBQzNDLElBQUksQ0FBQyxNQUFNLENBQ1Asb0JBQW9CLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxFQUNoQyxHQUFHLEVBQUUsQ0FBQyxHQUFHLEdBQUcsMENBQTBDO1FBQ2xELFlBQVksSUFBSSxhQUFhLElBQUksU0FBUyxDQUFDLENBQUM7QUFDdEQsQ0FBQztBQUVEOzs7O0dBSUc7QUFDSCxNQUFNLFVBQVUsa0JBQWtCLENBQUMsSUFBYyxFQUFFLElBQVk7SUFFN0QsSUFBSSxvQkFBb0IsQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLEVBQUU7UUFDcEMsT0FBTyxJQUFJLENBQUM7S0FDYjtJQUNELE1BQU0sTUFBTSxHQUFhLEVBQUUsQ0FBQztJQUM1QixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzdCLElBQUksSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRTtZQUMxQixNQUFNLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO1NBQ2hCO0tBQ0Y7SUFDRCxJQUFJLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxFQUFFLENBQUMsTUFBTSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO0lBQ3hDLE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCx5RUFBeUU7QUFDekUsTUFBTSxVQUFVLHNCQUFzQixDQUFDLElBQWM7SUFDbkQsT0FBTyxJQUFJLENBQUMsR0FBRyxDQUFDLENBQUMsSUFBSSxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsSUFBSSxDQUFDLENBQUM7U0FDbEMsSUFBSSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztTQUMzQixHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztBQUN0QixDQUFDO0FBRUQsTUFBTSxVQUFVLGdCQUFnQixDQUFDLE9BQWUsRUFBRSxJQUFZO0lBQzVELE1BQU0sR0FBRyxHQUFhLEVBQUUsQ0FBQztJQUN6QixLQUFLLElBQUksQ0FBQyxHQUFHLElBQUksR0FBRyxPQUFPLEVBQUUsQ0FBQyxHQUFHLElBQUksRUFBRSxFQUFFLENBQUMsRUFBRTtRQUMxQyxHQUFHLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ2I7SUFDRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxNyBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5k
ZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCAqIGFzIHV0aWwgZnJvbSAnLi4vdXRpbCc7XG5cbi8qKlxuICogUmV0dXJucyB0cnVlIGlmIHRoZSBheGlzIHNwZWNpZmllcyB0aGUgaW5uZXIgbW9zdCBkaW1lbnNpb25zIG9mIHRoZVxuICogYXJyYXkuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzOiBudW1iZXJbXSwgcmFuazogbnVtYmVyKTogYm9vbGVhbiB7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgYXhlcy5sZW5ndGg7ICsraSkge1xuICAgIGlmIChheGVzW2F4ZXMubGVuZ3RoIC0gaSAtIDFdICE9PSByYW5rIC0gMSAtIGkpIHtcbiAgICAgIHJldHVybiBmYWxzZTtcbiAgICB9XG4gIH1cbiAgcmV0dXJuIHRydWU7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBjb21iaW5lTG9jYXRpb25zKFxuICAgIG91dHB1dExvYzogbnVtYmVyW10sIHJlZHVjZUxvYzogbnVtYmVyW10sIGF4ZXM6IG51bWJlcltdKTogbnVtYmVyW10ge1xuICBjb25zdCByYW5rID0gb3V0cHV0TG9jLmxlbmd0aCArIHJlZHVjZUxvYy5sZW5ndGg7XG4gIGNvbnN0IGxvYyA9IFtdO1xuICBsZXQgb3V0SWR4ID0gMDtcbiAgbGV0IHJlZHVjZUlkeCA9IDA7XG4gIMKgIGZvciAobGV0IGRpbSA9IDA7IGRpbSA8IHJhbms7IGRpbSsrKSB7XG4gICAgaWYgKGF4ZXMuaW5kZXhPZihkaW0pID09PSAtMSkge1xuICAgICAgbG9jLnB1c2gob3V0cHV0TG9jW291dElkeCsrXSk7XG4gICAgfSBlbHNlIHtcbiAgICAgIGxvYy5wdXNoKHJlZHVjZUxvY1tyZWR1Y2VJZHgrK10pO1xuICAgIH1cbiAgfVxuICByZXR1cm4gbG9jO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY29tcHV0ZU91dEFuZFJlZHVjZVNoYXBlcyhcbiAgICBhU2hhcGU6IG51bWJlcltdLCBheGVz
OiBudW1iZXJbXSk6IFtudW1iZXJbXSwgbnVtYmVyW11dIHtcbiAgY29uc3Qgb3V0U2hhcGUgPSBbXTtcbiAgY29uc3QgcmFuayA9IGFTaGFwZS5sZW5ndGg7XG4gIGZvciAobGV0IGRpbSA9IDA7IGRpbSA8IHJhbms7IGRpbSsrKSB7XG4gICAgaWYgKGF4ZXMuaW5kZXhPZihkaW0pID09PSAtMSkge1xuICAgICAgb3V0U2hhcGUucHVzaChhU2hhcGVbZGltXSk7XG4gICAgfVxuICB9XG4gIGNvbnN0IHJlZHVjZVNoYXBlID0gYXhlcy5tYXAoZGltID0+IGFTaGFwZVtkaW1dKTtcbiAgcmV0dXJuIFtvdXRTaGFwZSwgcmVkdWNlU2hhcGVdO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gZXhwYW5kU2hhcGVUb0tlZXBEaW0oXG4gICAgc2hhcGU6IG51bWJlcltdLCBheGVzOiBudW1iZXJbXSk6IG51bWJlcltdIHtcbiAgY29uc3QgcmVkdWNlU3ViU2hhcGUgPSBheGVzLm1hcCh4ID0+IDEpO1xuICByZXR1cm4gY29tYmluZUxvY2F0aW9ucyhzaGFwZSwgcmVkdWNlU3ViU2hhcGUsIGF4ZXMpO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gYXNzZXJ0QXhlc0FyZUlubmVyTW9zdERpbXMoXG4gICAgbXNnOiBzdHJpbmcsIGF4ZXM6IG51bWJlcltdLCByYW5rOiBudW1iZXIpOiB2b2lkIHtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzLCByYW5rKSxcbiAgICAgICgpID0+IGAke21zZ30gc3VwcG9ydHMgb25seSBpbm5lci1tb3N0IGF4ZXMgZm9yIG5vdy4gYCArXG4gICAgICAgICAgYEdvdCBheGVzICR7YXhlc30gYW5kIHJhbmstJHtyYW5rfSBpbnB1dC5gKTtcbn1cblxuLyoqXG4gKiBSZXR1cm5zIHRoZSBheGVzIHBlcm11dGF0aW9uIHRvIGJlIHVzZWQgd2l0aCBgdGYudHJhbnNwb3NlYCwgaWYgc3VjaFxuICogcGVybXV0YXRpb24gaXMgbmVjZXNzYXJ5LiBPdGhlcndpc2UgaXQgcmV0dXJucyBudWxsLiBUaGlzIG1ldGhvZCBpcyB1c2VkIGJ5XG4gKiBvcGVyYXRpb25zIHRoYXQgb3BlcmF0ZSBvbmx5IG9uIGlubmVyLW1vc3QgYXhlcy5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldEF4ZXNQZXJtdXRhdGlvbihheGVzOiBudW1iZXJbXSwgcmFuazogbnVtYmVyKTogbnVtYmVyW118XG4gICAgbnVsbCB7XG4gIGlmIChheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzLCByYW5rKSkge1xuICAgIHJldHVybiBudWxsO1xuICB9XG4gIGNvbnN0IHJlc3VsdDogbnVtYmVyW10gPSBbXTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCByYW5rOyArK2kpIHtcbiAgICBpZiAoYXhlcy5pbmRleE9mKGkpID09PSAtMSkge1xuICAgICAgcmVzdWx0LnB1c2goaSk7XG4gICAgfVxuICB9XG4gIGF4ZXMuZm9yRWFjaChheGlzID0+IHJlc3VsdC5wdXNoKGF4aXMpKTtcbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuLyoqIFJldHVybnMgdGhlIGF4ZXMgcGVybXV0YXRpb24gdGhhdCB1bmRvZXMgdGhlIG9yaWdpbmFsIHBlcm11dGF0aW9uLiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldFVuZG9BeGVzUGVybXV0YXRpb24oYXhlczogbnVtYmVyW10pOiBudW1i
ZXJbXSB7XG4gIHJldHVybiBheGVzLm1hcCgoYXhpcywgaSkgPT4gW2ksIGF4aXNdKVxuICAgICAgLnNvcnQoKGEsIGIpID0+IGFbMV0gLSBiWzFdKVxuICAgICAgLm1hcCh4ID0+IHhbMF0pO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gZ2V0SW5uZXJNb3N0QXhlcyhudW1BeGVzOiBudW1iZXIsIHJhbms6IG51bWJlcik6IG51bWJlcltdIHtcbiAgY29uc3QgcmVzOiBudW1iZXJbXSA9IFtdO1xuICBmb3IgKGxldCBpID0gcmFuayAtIG51bUF4ZXM7IGkgPCByYW5rOyArK2kpIHtcbiAgICByZXMucHVzaChpKTtcbiAgfVxuICByZXR1cm4gcmVzO1xufVxuIl19
\No newline at end of file