1 | /**
|
2 | * @license
|
3 | * Copyright 2017 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import * as util from '../util';
|
/**
 * Returns true if `axes` are exactly the trailing (inner-most) dimensions of a
 * rank-`rank` array, listed in ascending order.
 *
 * For example, with rank 4, `[2, 3]` and `[3]` qualify while `[3, 2]` and
 * `[1, 3]` do not. An empty `axes` array trivially qualifies.
 *
 * @param {number[]} axes Axis indices to check.
 * @param {number} rank Rank of the array the axes refer to.
 * @returns {boolean}
 */
export function axesAreInnerMostDims(axes, rank) {
  // The j-th listed axis must be the j-th of the last `axes.length` dims.
  const firstInnerDim = rank - axes.length;
  for (let j = 0; j < axes.length; j++) {
    if (axes[j] !== firstInnerDim + j) {
      return false;
    }
  }
  return true;
}
|
/**
 * Interleaves an output location and a reduction location into one location
 * of rank `outputLoc.length + reduceLoc.length`.
 *
 * Walking dimensions in ascending order, each dimension listed in `axes` takes
 * the next unused entry of `reduceLoc`; every other dimension takes the next
 * unused entry of `outputLoc`.
 *
 * @param {number[]} outputLoc Coordinates for the non-reduced dimensions.
 * @param {number[]} reduceLoc Coordinates for the reduced dimensions.
 * @param {number[]} axes Dimensions (in the combined location) that are reduced.
 * @returns {number[]} The combined location.
 */
export function combineLocations(outputLoc, reduceLoc, axes) {
  const reducedDims = new Set(axes);
  const totalRank = outputLoc.length + reduceLoc.length;
  let nextOut = 0;
  let nextReduce = 0;
  // Array.from invokes the mapper for dim = 0, 1, ... in order, so the two
  // cursors advance exactly as in a sequential loop.
  return Array.from({length: totalRank}, (_, dim) => reducedDims.has(dim) ?
      reduceLoc[nextReduce++] :
      outputLoc[nextOut++]);
}
|
/**
 * Splits `aShape` into the shape left after reducing over `axes` and the shape
 * of the reduced dimensions.
 *
 * @param {number[]} aShape Shape of the input array.
 * @param {number[]} axes Dimensions being reduced.
 * @returns {[number[], number[]]} `[outShape, reduceShape]`: `outShape` holds
 *     the sizes of the dimensions not listed in `axes`, in their original
 *     order; `reduceShape` holds `aShape[axis]` for each entry of `axes`, in
 *     the order `axes` lists them.
 */
export function computeOutAndReduceShapes(aShape, axes) {
  const reducedDims = new Set(axes);
  const keptSizes = [];
  const inputRank = aShape.length;
  for (let d = 0; d < inputRank; d++) {
    if (!reducedDims.has(d)) {
      keptSizes.push(aShape[d]);
    }
  }
  const reducedSizes = axes.map(axis => aShape[axis]);
  return [keptSizes, reducedSizes];
}
|
/**
 * Returns `shape` with a size-1 dimension inserted at each position listed in
 * `axes` (positions refer to the expanded result). For example,
 * `expandShapeToKeepDim([2, 4], [1])` yields `[2, 1, 4]`.
 *
 * @param {number[]} shape Shape without the reduced dimensions.
 * @param {number[]} axes Positions in the expanded shape that get size 1.
 * @returns {number[]} The expanded shape.
 */
export function expandShapeToKeepDim(shape, axes) {
  const unitSizes = axes.map(() => 1);
  return combineLocations(shape, unitSizes, axes);
}
|
/**
 * Asserts, via `util.assert`, that `axes` are the inner-most dimensions of a
 * rank-`rank` array as decided by `axesAreInnerMostDims`.
 *
 * The failure message is built lazily inside a closure, so it is only
 * formatted if `util.assert` decides to use it.
 *
 * @param {string} msg Name of the caller, used as the message prefix.
 * @param {number[]} axes Axis indices to check.
 * @param {number} rank Rank of the array the axes refer to.
 */
export function assertAxesAreInnerMostDims(msg, axes, rank) {
  const isInnerMost = axesAreInnerMostDims(axes, rank);
  util.assert(
      isInnerMost,
      () => `${msg} supports only inner-most axes for now. ` +
          `Got axes ${axes} and rank-${rank} input.`);
}
|
/**
 * Returns the axes permutation to be used with `tf.transpose` so that `axes`
 * end up as the inner-most dimensions, or null when they already are (as
 * decided by `axesAreInnerMostDims`). This method is used by operations that
 * operate only on inner-most axes.
 *
 * The permutation lists every dimension not in `axes` in ascending order,
 * followed by the entries of `axes` in the order given.
 *
 * @param {number[]} axes Axes that must become the inner-most dimensions.
 * @param {number} rank Rank of the array being permuted.
 * @returns {number[]|null} The permutation, or null if none is needed.
 */
export function getAxesPermutation(axes, rank) {
  if (axesAreInnerMostDims(axes, rank)) {
    return null;
  }
  const movedDims = new Set(axes);
  const perm = [];
  for (let dim = 0; dim < rank; dim++) {
    if (!movedDims.has(dim)) {
      perm.push(dim);
    }
  }
  for (const axis of axes) {
    perm.push(axis);
  }
  return perm;
}
|
/**
 * Returns the axes permutation that undoes `axes`. When `axes` is a
 * permutation of `0..n-1`, the result `inv` satisfies `inv[axes[i]] === i`.
 *
 * Computed by ordering the positions of `axes` by the value stored at each
 * position; the sort is stable, so equal values keep their positional order.
 *
 * @param {number[]} axes The permutation to invert.
 * @returns {number[]} The inverse permutation.
 */
export function getUndoAxesPermutation(axes) {
  const positions = axes.map((_, i) => i);
  return positions.sort((p, q) => axes[p] - axes[q]);
}
|
/**
 * Returns the last `numAxes` dimension indices of a rank-`rank` array in
 * ascending order, i.e. `[rank - numAxes, ..., rank - 1]`. Returns an empty
 * array when `numAxes` is 0 or negative.
 *
 * @param {number} numAxes How many trailing dimensions to list.
 * @param {number} rank Rank of the array.
 * @returns {number[]}
 */
export function getInnerMostAxes(numAxes, rank) {
  const innerAxes = [];
  let axis = rank - numAxes;
  while (axis < rank) {
    innerAxes.push(axis);
    axis += 1;
  }
  return innerAxes;
}
|
95 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYXhpc191dGlsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYXhpc191dGlsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sS0FBSyxJQUFJLE1BQU0sU0FBUyxDQUFDO0FBRWhDOzs7R0FHRztBQUNILE1BQU0sVUFBVSxvQkFBb0IsQ0FBQyxJQUFjLEVBQUUsSUFBWTtJQUMvRCxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxDQUFDLE1BQU0sRUFBRSxFQUFFLENBQUMsRUFBRTtRQUNwQyxJQUFJLElBQUksQ0FBQyxJQUFJLENBQUMsTUFBTSxHQUFHLENBQUMsR0FBRyxDQUFDLENBQUMsS0FBSyxJQUFJLEdBQUcsQ0FBQyxHQUFHLENBQUMsRUFBRTtZQUM5QyxPQUFPLEtBQUssQ0FBQztTQUNkO0tBQ0Y7SUFDRCxPQUFPLElBQUksQ0FBQztBQUNkLENBQUM7QUFFRCxNQUFNLFVBQVUsZ0JBQWdCLENBQzVCLFNBQW1CLEVBQUUsU0FBbUIsRUFBRSxJQUFjO0lBQzFELE1BQU0sSUFBSSxHQUFHLFNBQVMsQ0FBQyxNQUFNLEdBQUcsU0FBUyxDQUFDLE1BQU0sQ0FBQztJQUNqRCxNQUFNLEdBQUcsR0FBRyxFQUFFLENBQUM7SUFDZixJQUFJLE1BQU0sR0FBRyxDQUFDLENBQUM7SUFDZixJQUFJLFNBQVMsR0FBRyxDQUFDLENBQUM7SUFDaEIsS0FBSyxJQUFJLEdBQUcsR0FBRyxDQUFDLEVBQUUsR0FBRyxHQUFHLElBQUksRUFBRSxHQUFHLEVBQUUsRUFBRTtRQUNyQyxJQUFJLElBQUksQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLEVBQUU7WUFDNUIsR0FBRyxDQUFDLElBQUksQ0FBQyxTQUFTLENBQUMsTUFBTSxFQUFFLENBQUMsQ0FBQyxDQUFDO1NBQy9CO2FBQU07WUFDTCxHQUFHLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxDQUFDLENBQUM7U0FDbEM7S0FDRjtJQUNELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQztBQUVELE1BQU0sVUFBVSx5QkFBeUIsQ0FDckMsTUFBZ0IsRUFBRSxJQUFjO0lBQ2xDLE1BQU0sUUFBUSxHQUFHLEVBQUUsQ0FBQztJQUNwQixNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsTUFBTSxDQUFDO0lBQzNCLEtBQUssSUFBSSxHQUFHLEdBQUcsQ0FBQyxFQUFFLEdBQUcsR0FBRyxJQUFJLEVBQUUsR0FBRyxFQUFFLEVBQUU7UUFDbkMsSUFBSSxJQUFJLENBQUMsT0FBTyxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFO1lBQzVCLFFBQVEsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7U0FDNUI7S0FDRjtJQUNELE1BQU0sV0FBVyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxNQUFNLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztJQUNqRCxPQUFPLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxDQUFDO0FBQ2pDLENBQUM7Q
UFFRCxNQUFNLFVBQVUsb0JBQW9CLENBQ2hDLEtBQWUsRUFBRSxJQUFjO0lBQ2pDLE1BQU0sY0FBYyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN4QyxPQUFPLGdCQUFnQixDQUFDLEtBQUssRUFBRSxjQUFjLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdkQsQ0FBQztBQUVELE1BQU0sVUFBVSwwQkFBMEIsQ0FDdEMsR0FBVyxFQUFFLElBQWMsRUFBRSxJQUFZO0lBQzNDLElBQUksQ0FBQyxNQUFNLENBQ1Asb0JBQW9CLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxFQUNoQyxHQUFHLEVBQUUsQ0FBQyxHQUFHLEdBQUcsMENBQTBDO1FBQ2xELFlBQVksSUFBSSxhQUFhLElBQUksU0FBUyxDQUFDLENBQUM7QUFDdEQsQ0FBQztBQUVEOzs7O0dBSUc7QUFDSCxNQUFNLFVBQVUsa0JBQWtCLENBQUMsSUFBYyxFQUFFLElBQVk7SUFFN0QsSUFBSSxvQkFBb0IsQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLEVBQUU7UUFDcEMsT0FBTyxJQUFJLENBQUM7S0FDYjtJQUNELE1BQU0sTUFBTSxHQUFhLEVBQUUsQ0FBQztJQUM1QixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzdCLElBQUksSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRTtZQUMxQixNQUFNLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO1NBQ2hCO0tBQ0Y7SUFDRCxJQUFJLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxFQUFFLENBQUMsTUFBTSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO0lBQ3hDLE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCx5RUFBeUU7QUFDekUsTUFBTSxVQUFVLHNCQUFzQixDQUFDLElBQWM7SUFDbkQsT0FBTyxJQUFJLENBQUMsR0FBRyxDQUFDLENBQUMsSUFBSSxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsSUFBSSxDQUFDLENBQUM7U0FDbEMsSUFBSSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztTQUMzQixHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztBQUN0QixDQUFDO0FBRUQsTUFBTSxVQUFVLGdCQUFnQixDQUFDLE9BQWUsRUFBRSxJQUFZO0lBQzVELE1BQU0sR0FBRyxHQUFhLEVBQUUsQ0FBQztJQUN6QixLQUFLLElBQUksQ0FBQyxHQUFHLElBQUksR0FBRyxPQUFPLEVBQUUsQ0FBQyxHQUFHLElBQUksRUFBRSxFQUFFLENBQUMsRUFBRTtRQUMxQyxHQUFHLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ2I7SUFDRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxNyBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgd
W5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCAqIGFzIHV0aWwgZnJvbSAnLi4vdXRpbCc7XG5cbi8qKlxuICogUmV0dXJucyB0cnVlIGlmIHRoZSBheGlzIHNwZWNpZmllcyB0aGUgaW5uZXIgbW9zdCBkaW1lbnNpb25zIG9mIHRoZVxuICogYXJyYXkuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzOiBudW1iZXJbXSwgcmFuazogbnVtYmVyKTogYm9vbGVhbiB7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgYXhlcy5sZW5ndGg7ICsraSkge1xuICAgIGlmIChheGVzW2F4ZXMubGVuZ3RoIC0gaSAtIDFdICE9PSByYW5rIC0gMSAtIGkpIHtcbiAgICAgIHJldHVybiBmYWxzZTtcbiAgICB9XG4gIH1cbiAgcmV0dXJuIHRydWU7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBjb21iaW5lTG9jYXRpb25zKFxuICAgIG91dHB1dExvYzogbnVtYmVyW10sIHJlZHVjZUxvYzogbnVtYmVyW10sIGF4ZXM6IG51bWJlcltdKTogbnVtYmVyW10ge1xuICBjb25zdCByYW5rID0gb3V0cHV0TG9jLmxlbmd0aCArIHJlZHVjZUxvYy5sZW5ndGg7XG4gIGNvbnN0IGxvYyA9IFtdO1xuICBsZXQgb3V0SWR4ID0gMDtcbiAgbGV0IHJlZHVjZUlkeCA9IDA7XG4gIMKgIGZvciAobGV0IGRpbSA9IDA7IGRpbSA8IHJhbms7IGRpbSsrKSB7XG4gICAgaWYgKGF4ZXMuaW5kZXhPZihkaW0pID09PSAtMSkge1xuICAgICAgbG9jLnB1c2gob3V0cHV0TG9jW291dElkeCsrXSk7XG4gICAgfSBlbHNlIHtcbiAgICAgIGxvYy5wdXNoKHJlZHVjZUxvY1tyZWR1Y2VJZHgrK10pO1xuICAgIH1cbiAgfVxuICByZXR1cm4gbG9jO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY29tcHV0ZU91dEFuZFJlZHVjZVNoYXBlcyhcbiAgICBhU2hhcGU6IG51bWJlcltdLCBhe
GVzOiBudW1iZXJbXSk6IFtudW1iZXJbXSwgbnVtYmVyW11dIHtcbiAgY29uc3Qgb3V0U2hhcGUgPSBbXTtcbiAgY29uc3QgcmFuayA9IGFTaGFwZS5sZW5ndGg7XG4gIGZvciAobGV0IGRpbSA9IDA7IGRpbSA8IHJhbms7IGRpbSsrKSB7XG4gICAgaWYgKGF4ZXMuaW5kZXhPZihkaW0pID09PSAtMSkge1xuICAgICAgb3V0U2hhcGUucHVzaChhU2hhcGVbZGltXSk7XG4gICAgfVxuICB9XG4gIGNvbnN0IHJlZHVjZVNoYXBlID0gYXhlcy5tYXAoZGltID0+IGFTaGFwZVtkaW1dKTtcbiAgcmV0dXJuIFtvdXRTaGFwZSwgcmVkdWNlU2hhcGVdO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gZXhwYW5kU2hhcGVUb0tlZXBEaW0oXG4gICAgc2hhcGU6IG51bWJlcltdLCBheGVzOiBudW1iZXJbXSk6IG51bWJlcltdIHtcbiAgY29uc3QgcmVkdWNlU3ViU2hhcGUgPSBheGVzLm1hcCh4ID0+IDEpO1xuICByZXR1cm4gY29tYmluZUxvY2F0aW9ucyhzaGFwZSwgcmVkdWNlU3ViU2hhcGUsIGF4ZXMpO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gYXNzZXJ0QXhlc0FyZUlubmVyTW9zdERpbXMoXG4gICAgbXNnOiBzdHJpbmcsIGF4ZXM6IG51bWJlcltdLCByYW5rOiBudW1iZXIpOiB2b2lkIHtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzLCByYW5rKSxcbiAgICAgICgpID0+IGAke21zZ30gc3VwcG9ydHMgb25seSBpbm5lci1tb3N0IGF4ZXMgZm9yIG5vdy4gYCArXG4gICAgICAgICAgYEdvdCBheGVzICR7YXhlc30gYW5kIHJhbmstJHtyYW5rfSBpbnB1dC5gKTtcbn1cblxuLyoqXG4gKiBSZXR1cm5zIHRoZSBheGVzIHBlcm11dGF0aW9uIHRvIGJlIHVzZWQgd2l0aCBgdGYudHJhbnNwb3NlYCwgaWYgc3VjaFxuICogcGVybXV0YXRpb24gaXMgbmVjZXNzYXJ5LiBPdGhlcndpc2UgaXQgcmV0dXJucyBudWxsLiBUaGlzIG1ldGhvZCBpcyB1c2VkIGJ5XG4gKiBvcGVyYXRpb25zIHRoYXQgb3BlcmF0ZSBvbmx5IG9uIGlubmVyLW1vc3QgYXhlcy5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldEF4ZXNQZXJtdXRhdGlvbihheGVzOiBudW1iZXJbXSwgcmFuazogbnVtYmVyKTogbnVtYmVyW118XG4gICAgbnVsbCB7XG4gIGlmIChheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzLCByYW5rKSkge1xuICAgIHJldHVybiBudWxsO1xuICB9XG4gIGNvbnN0IHJlc3VsdDogbnVtYmVyW10gPSBbXTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCByYW5rOyArK2kpIHtcbiAgICBpZiAoYXhlcy5pbmRleE9mKGkpID09PSAtMSkge1xuICAgICAgcmVzdWx0LnB1c2goaSk7XG4gICAgfVxuICB9XG4gIGF4ZXMuZm9yRWFjaChheGlzID0+IHJlc3VsdC5wdXNoKGF4aXMpKTtcbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuLyoqIFJldHVybnMgdGhlIGF4ZXMgcGVybXV0YXRpb24gdGhhdCB1bmRvZXMgdGhlIG9yaWdpbmFsIHBlcm11dGF0aW9uLiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldFVuZG9BeGVzUGVybXV0YXRpb24oYXhlczogbnVtYmVyW10pOiBud
W1iZXJbXSB7XG4gIHJldHVybiBheGVzLm1hcCgoYXhpcywgaSkgPT4gW2ksIGF4aXNdKVxuICAgICAgLnNvcnQoKGEsIGIpID0+IGFbMV0gLSBiWzFdKVxuICAgICAgLm1hcCh4ID0+IHhbMF0pO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gZ2V0SW5uZXJNb3N0QXhlcyhudW1BeGVzOiBudW1iZXIsIHJhbms6IG51bWJlcik6IG51bWJlcltdIHtcbiAgY29uc3QgcmVzOiBudW1iZXJbXSA9IFtdO1xuICBmb3IgKGxldCBpID0gcmFuayAtIG51bUF4ZXM7IGkgPCByYW5rOyArK2kpIHtcbiAgICByZXMucHVzaChpKTtcbiAgfVxuICByZXR1cm4gcmVzO1xufVxuIl19 |
\ | No newline at end of file |