1 | /**
|
2 | * @license
|
3 | * Copyright 2017 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import * as util from '../util';
|
/**
 * Checks whether `axes` names exactly the trailing dimensions of a rank-`rank`
 * array, listed in ascending order. With rank 4, for example, `[2, 3]` and
 * `[3]` qualify, while `[3, 2]` and `[1, 3]` do not. An empty `axes` list
 * trivially qualifies.
 *
 * @param axes Axis indices to check.
 * @param rank Rank of the array the axes refer to.
 * @returns True when the last entry of `axes` is `rank - 1`, the entry before
 *     it is `rank - 2`, and so on for every entry of `axes`.
 */
export function axesAreInnerMostDims(axes, rank) {
  // Walk `axes` from its last entry backwards. Each entry must equal the
  // dimension index counted down from the last dimension of the array.
  let expected = rank - 1;
  for (let pos = axes.length - 1; pos >= 0; pos--, expected--) {
    if (axes[pos] !== expected) {
      return false;
    }
  }
  return true;
}
|
/**
 * Interleaves two coordinate lists into one full coordinate list. The
 * dimensions `0 .. outputLoc.length + reduceLoc.length - 1` are visited in
 * order. A dimension listed in `axes` takes the next unused entry of
 * `reduceLoc`; any other dimension takes the next unused entry of `outputLoc`.
 *
 * `reduceLoc` is consumed in ascending dimension order, regardless of the
 * order in which `axes` lists the dimensions.
 *
 * @param outputLoc Values for the dimensions not listed in `axes`.
 * @param reduceLoc Values for the dimensions listed in `axes`.
 * @param axes Dimensions that should be filled from `reduceLoc`.
 * @returns The merged list, of length `outputLoc.length + reduceLoc.length`.
 */
export function combineLocations(outputLoc, reduceLoc, axes) {
  const totalRank = outputLoc.length + reduceLoc.length;
  // Independent read cursors into the two source lists.
  let nextOut = 0;
  let nextReduce = 0;
  return Array.from(
      {length: totalRank},
      (_, dim) => axes.includes(dim) ? reduceLoc[nextReduce++] :
                                       outputLoc[nextOut++]);
}
|
/**
 * Splits `aShape` into the sizes of the dimensions not listed in `axes` and
 * the sizes of the dimensions that are listed in `axes`.
 *
 * @param aShape Shape of the input array.
 * @param axes Dimensions to split out.
 * @returns A pair `[outShape, reduceShape]`. `outShape` holds the sizes of the
 *     dimensions not in `axes`, in dimension order. `reduceShape` holds
 *     `aShape[axis]` for each entry of `axes`, in the order `axes` lists them.
 */
export function computeOutAndReduceShapes(aShape, axes) {
  const allDims = Array.from({length: aShape.length}, (_, dim) => dim);
  const outShape =
      allDims.filter(dim => !axes.includes(dim)).map(dim => aShape[dim]);
  const reduceShape = axes.map(axis => aShape[axis]);
  return [outShape, reduceShape];
}
|
/**
 * Inserts a size-1 dimension at every position listed in `axes`, using
 * `combineLocations` to interleave them with `shape`. For example,
 * `expandShapeToKeepDim([2, 4], [1])` yields `[2, 1, 4]`.
 *
 * @param shape Shape without the size-1 dimensions.
 * @param axes Positions, in the resulting shape, that receive a size of 1.
 * @returns The expanded shape, of length `shape.length + axes.length`.
 */
export function expandShapeToKeepDim(shape, axes) {
  return combineLocations(shape, axes.map(() => 1), axes);
}
|
/**
 * Asserts, via `util.assert`, that `axes` are the inner-most dimensions of a
 * rank-`rank` input, as decided by `axesAreInnerMostDims`.
 *
 * @param msg Prefix for the failure message (e.g. the calling op's name).
 * @param axes Axes the caller was asked to operate on.
 * @param rank Rank of the caller's input.
 */
export function assertAxesAreInnerMostDims(msg, axes, rank) {
  const isInnerMost = axesAreInnerMostDims(axes, rank);
  // The message is handed over as a thunk; presumably util.assert only calls
  // it when the check fails, so the string is not built on the happy path.
  const describeFailure = () =>
      `${msg} supports only inner-most axes for now. ` +
      `Got axes ${axes} and rank-${rank} input.`;
  util.assert(isInnerMost, describeFailure);
}
|
/**
 * Builds the permutation, for use with `tf.transpose`, that moves `axes` to
 * the end of a rank-`rank` array. Dimensions not listed in `axes` come first,
 * in ascending order, followed by the entries of `axes` in the order given.
 * Used by operations that can only act on inner-most axes.
 *
 * @param axes Axes to move to the end.
 * @param rank Rank of the array being permuted.
 * @returns The permutation, or null when `axes` already are the inner-most
 *     dimensions (per `axesAreInnerMostDims`) and no transpose is needed.
 */
export function getAxesPermutation(axes, rank) {
  if (!axesAreInnerMostDims(axes, rank)) {
    const permutation = [];
    // Dimensions not listed in `axes` go first, in ascending order...
    for (let dim = 0; dim < rank; dim++) {
      if (!axes.includes(dim)) {
        permutation.push(dim);
      }
    }
    // ...followed by the requested axes, in the order the caller gave them.
    axes.forEach(axis => permutation.push(axis));
    return permutation;
  }
  // The axes already sit at the end, so no transpose is required.
  return null;
}
|
/**
 * Returns the permutation that undoes `axes`. Entry `k` of the result is the
 * index within `axes` holding the `k`-th smallest value. For a true
 * permutation of `0 .. axes.length - 1`, that is the index holding `k`, so the
 * result is the inverse permutation.
 *
 * @param axes A permutation to invert.
 * @returns The inverse permutation.
 */
export function getUndoAxesPermutation(axes) {
  // Start from the identity list of indices, then order it by the value each
  // index points at in `axes`.
  const positions = axes.map((_, index) => index);
  positions.sort((left, right) => axes[left] - axes[right]);
  return positions;
}
|
/**
 * Lists the last `numAxes` dimension indices of a rank-`rank` array in
 * ascending order, i.e. `rank - numAxes, ..., rank - 1`. For example,
 * `getInnerMostAxes(2, 4)` returns `[2, 3]`.
 *
 * No bounds checking is done. If `numAxes` exceeds `rank`, the result starts
 * at a negative index. If `numAxes` is zero or negative, the result is empty.
 *
 * @param numAxes How many trailing dimensions to list.
 * @param rank Rank of the array.
 * @returns The trailing dimension indices.
 */
export function getInnerMostAxes(numAxes, rank) {
  const innerAxes = [];
  let axis = rank - numAxes;
  while (axis < rank) {
    innerAxes.push(axis);
    axis++;
  }
  return innerAxes;
}
|
95 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"axis_util.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/axis_util.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAEhC;;;GAGG;AACH,MAAM,UAAU,oBAAoB,CAAC,IAAc,EAAE,IAAY;IAC/D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACpC,IAAI,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,CAAC,GAAG,CAAC,CAAC,KAAK,IAAI,GAAG,CAAC,GAAG,CAAC,EAAE;YAC9C,OAAO,KAAK,CAAC;SACd;KACF;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED,MAAM,UAAU,gBAAgB,CAC5B,SAAmB,EAAE,SAAmB,EAAE,IAAc;IAC1D,MAAM,IAAI,GAAG,SAAS,CAAC,MAAM,GAAG,SAAS,CAAC,MAAM,CAAC;IACjD,MAAM,GAAG,GAAG,EAAE,CAAC;IACf,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,IAAI,SAAS,GAAG,CAAC,CAAC;IAChB,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,IAAI,EAAE,GAAG,EAAE,EAAE;QACrC,IAAI,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,EAAE;YAC5B,GAAG,CAAC,IAAI,CAAC,SAAS,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;SAC/B;aAAM;YACL,GAAG,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC;SAClC;KACF;IACD,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,UAAU,yBAAyB,CACrC,MAAgB,EAAE,IAAc;IAClC,MAAM,QAAQ,GAAG,EAAE,CAAC;IACpB,MAAM,IAAI,GAAG,MAAM,CAAC,MAAM,CAAC;IAC3B,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,IAAI,EAAE,GAAG,EAAE,EAAE;QACnC,IAAI,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,EAAE;YAC5B,QAAQ,CAAC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC;SAC5B;KACF;IACD,MAAM,WAAW,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC;IACjD,OAAO,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAC;AACjC,CAAC;AAED,MAAM,UAAU,oBAAoB,CAChC,KAAe,EAAE,IAAc;IACjC,MAAM,cAAc,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACxC,OAAO,gBAAgB,CAAC,KAAK,EAAE,cAAc,EAAE,IAAI,CAAC,CAAC;AACvD,CAAC;AAED,MAAM,UAAU,0BAA0B,CACtC,GAAW,EAAE,IAAc,EAAE,IAAY;IAC3C,IAAI,CAAC,MAAM,CACP,oBAAoB,CAAC,IAAI,EAAE,IAAI,CAAC,EAChC,GAAG,EAAE,CAAC,GAAG,GAAG,0CAA0C;QAClD,YAAY,IAAI,aAAa,IAAI,SAAS,CAAC,CAAC;AACtD,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,kBAAkB,CAAC,IAAc,EAAE,IAAY;IAE7D,IAAI,oBAAoB,CAAC,IAAI,EAAE,IAAI,CAAC,
EAAE;QACpC,OAAO,IAAI,CAAC;KACb;IACD,MAAM,MAAM,GAAa,EAAE,CAAC;IAC5B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,EAAE,CAAC,EAAE;QAC7B,IAAI,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE;YAC1B,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SAChB;KACF;IACD,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC;IACxC,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,yEAAyE;AACzE,MAAM,UAAU,sBAAsB,CAAC,IAAc;IACnD,OAAO,IAAI,CAAC,GAAG,CAAC,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;SAClC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;SAC3B,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;AACtB,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,OAAe,EAAE,IAAY;IAC5D,MAAM,GAAG,GAAa,EAAE,CAAC;IACzB,KAAK,IAAI,CAAC,GAAG,IAAI,GAAG,OAAO,EAAE,CAAC,GAAG,IAAI,EAAE,EAAE,CAAC,EAAE;QAC1C,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;KACb;IACD,OAAO,GAAG,CAAC;AACb,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2017 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as util from '../util';\n\n/**\n * Returns true if the axis specifies the inner most dimensions of the\n * array.\n */\nexport function axesAreInnerMostDims(axes: number[], rank: number): boolean {\n  for (let i = 0; i < axes.length; ++i) {\n    if (axes[axes.length - i - 1] !== rank - 1 - i) {\n      return false;\n    }\n  }\n  
return true;\n}\n\nexport function combineLocations(\n    outputLoc: number[], reduceLoc: number[], axes: number[]): number[] {\n  const rank = outputLoc.length + reduceLoc.length;\n  const loc = [];\n  let outIdx = 0;\n  let reduceIdx = 0;\n    for (let dim = 0; dim < rank; dim++) {\n    if (axes.indexOf(dim) === -1) {\n      loc.push(outputLoc[outIdx++]);\n    } else {\n      loc.push(reduceLoc[reduceIdx++]);\n    }\n  }\n  return loc;\n}\n\nexport function computeOutAndReduceShapes(\n    aShape: number[], axes: number[]): [number[], number[]] {\n  const outShape = [];\n  const rank = aShape.length;\n  for (let dim = 0; dim < rank; dim++) {\n    if (axes.indexOf(dim) === -1) {\n      outShape.push(aShape[dim]);\n    }\n  }\n  const reduceShape = axes.map(dim => aShape[dim]);\n  return [outShape, reduceShape];\n}\n\nexport function expandShapeToKeepDim(\n    shape: number[], axes: number[]): number[] {\n  const reduceSubShape = axes.map(x => 1);\n  return combineLocations(shape, reduceSubShape, axes);\n}\n\nexport function assertAxesAreInnerMostDims(\n    msg: string, axes: number[], rank: number): void {\n  util.assert(\n      axesAreInnerMostDims(axes, rank),\n      () => `${msg} supports only inner-most axes for now. ` +\n          `Got axes ${axes} and rank-${rank} input.`);\n}\n\n/**\n * Returns the axes permutation to be used with `tf.transpose`, if such\n * permutation is necessary. Otherwise it returns null. This method is used by\n * operations that operate only on inner-most axes.\n */\nexport function getAxesPermutation(axes: number[], rank: number): number[]|\n    null {\n  if (axesAreInnerMostDims(axes, rank)) {\n    return null;\n  }\n  const result: number[] = [];\n  for (let i = 0; i < rank; ++i) {\n    if (axes.indexOf(i) === -1) {\n      result.push(i);\n    }\n  }\n  axes.forEach(axis => result.push(axis));\n  return result;\n}\n\n/** Returns the axes permutation that undoes the original permutation. 
*/\nexport function getUndoAxesPermutation(axes: number[]): number[] {\n  return axes.map((axis, i) => [i, axis])\n      .sort((a, b) => a[1] - b[1])\n      .map(x => x[0]);\n}\n\nexport function getInnerMostAxes(numAxes: number, rank: number): number[] {\n  const res: number[] = [];\n  for (let i = rank - numAxes; i < rank; ++i) {\n    res.push(i);\n  }\n  return res;\n}\n"]} |
\ | No newline at end of file |