1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { ENGINE } from '../engine';
|
18 | import { Transpose } from '../kernel_names';
|
19 | import { convertToTensor } from '../tensor_util_env';
|
20 | import * as util from '../util';
|
21 | import { op } from './operation';
|
/**
 * Transposes the `tf.Tensor` by permuting its dimensions according to `perm`.
 *
 * Dimension `i` of the returned `tf.Tensor` corresponds to dimension
 * `perm[i]` of the input. When `perm` is omitted it defaults to `[n-1...0]`,
 * where `n` is the rank of the input `tf.Tensor`, so by default a 2-D input
 * gets a regular matrix transpose. Inputs of rank 0 or 1 are returned as a
 * clone of the input.
 *
 * ```js
 * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
 *
 * a.transpose().print();  // or tf.transpose(a)
 * ```
 *
 * @param x The tensor to transpose.
 * @param perm The permutation of the dimensions of x. Its length must equal
 *     the rank of x, and every entry must be between 0 and `rank - 1`.
 *
 * @doc {heading: 'Operations', subheading: 'Matrices'}
 */
function transpose_(x, perm) {
    const $input = convertToTensor(x, 'x', 'transpose');

    // Without an explicit permutation, reverse the order of the dimensions.
    const axes = perm == null ?
        $input.shape.map((_, dim) => dim).reverse() :
        perm;

    // Messages are built lazily so they are only formatted when a check fails.
    const rankMismatchMessage = () => `Error in transpose: rank of input ${$input.rank} ` +
        `must match length of perm ${axes}.`;
    const outOfRangeMessage = () => `All entries in 'perm' must be between 0 and ${$input.rank - 1}` +
        ` but got ${axes}`;

    util.assert($input.rank === axes.length, rankMismatchMessage);
    axes.forEach((axis) => {
        const isValidAxis = axis >= 0 && axis < $input.rank;
        util.assert(isValidAxis, outOfRangeMessage);
    });

    // Only tensors of rank 2 or higher are dispatched to the Transpose kernel.
    if ($input.rank > 1) {
        return ENGINE.runKernel(Transpose, { x: $input }, { perm: axes });
    }
    return $input.clone();
}
|
/**
 * Public `transpose` op: the result of passing `transpose_` through `op`.
 */
const transpose = op({ transpose_ });
export { transpose };
|
60 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidHJhbnNwb3NlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvdHJhbnNwb3NlLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLFNBQVMsRUFBa0MsTUFBTSxpQkFBaUIsQ0FBQztBQUkzRSxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFFbkQsT0FBTyxLQUFLLElBQUksTUFBTSxTQUFTLENBQUM7QUFFaEMsT0FBTyxFQUFDLEVBQUUsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUUvQjs7Ozs7Ozs7Ozs7Ozs7Ozs7O0dBa0JHO0FBQ0gsU0FBUyxVQUFVLENBQW1CLENBQWUsRUFBRSxJQUFlO0lBQ3BFLE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBRWhELElBQUksSUFBSSxJQUFJLElBQUksRUFBRTtRQUNoQixJQUFJLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxPQUFPLEVBQUUsQ0FBQztLQUM1QztJQUNELElBQUksQ0FBQyxNQUFNLENBQ1AsRUFBRSxDQUFDLElBQUksS0FBSyxJQUFJLENBQUMsTUFBTSxFQUN2QixHQUFHLEVBQUUsQ0FBQyxxQ0FBcUMsRUFBRSxDQUFDLElBQUksR0FBRztRQUNqRCw2QkFBNkIsSUFBSSxHQUFHLENBQUMsQ0FBQztJQUM5QyxJQUFJLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxFQUFFO1FBQ2xCLElBQUksQ0FBQyxNQUFNLENBQ1AsSUFBSSxJQUFJLENBQUMsSUFBSSxJQUFJLEdBQUcsRUFBRSxDQUFDLElBQUksRUFDM0IsR0FBRyxFQUFFLENBQUMsK0NBQStDLEVBQUUsQ0FBQyxJQUFJLEdBQUcsQ0FBQyxFQUFFO1lBQzlELFlBQVksSUFBSSxFQUFFLENBQUMsQ0FBQztJQUM5QixDQUFDLENBQUMsQ0FBQztJQUVILElBQUksRUFBRSxDQUFDLElBQUksSUFBSSxDQUFDLEVBQUU7UUFDaEIsT0FBTyxFQUFFLENBQUMsS0FBSyxFQUFFLENBQUM7S0FDbkI7SUFFRCxNQUFNLE1BQU0sR0FBb0IsRUFBQyxDQUFDLEVBQUUsRUFBRSxFQUFDLENBQUM7SUFDeEMsTUFBTSxLQUFLLEdBQW1CLEVBQUMsSUFBSSxFQUFDLENBQUM7SUFFckMsT0FBTyxNQUFNLENBQUMsU0FBUyxDQUNuQixTQUFTLEVBQUUsTUFBOEIsRUFBRSxLQUEyQixDQUFDLENBQUM7QUFDOUUsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLFNBQVMsR0FBRyxFQUFFLENBQUMsRUFBQyxVQUFVLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc
2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0VOR0lORX0gZnJvbSAnLi4vZW5naW5lJztcbmltcG9ydCB7VHJhbnNwb3NlLCBUcmFuc3Bvc2VBdHRycywgVHJhbnNwb3NlSW5wdXRzfSBmcm9tICcuLi9rZXJuZWxfbmFtZXMnO1xuaW1wb3J0IHtOYW1lZEF0dHJNYXB9IGZyb20gJy4uL2tlcm5lbF9yZWdpc3RyeSc7XG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7TmFtZWRUZW5zb3JNYXB9IGZyb20gJy4uL3RlbnNvcl90eXBlcyc7XG5pbXBvcnQge2NvbnZlcnRUb1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yX3V0aWxfZW52JztcbmltcG9ydCB7VGVuc29yTGlrZX0gZnJvbSAnLi4vdHlwZXMnO1xuaW1wb3J0ICogYXMgdXRpbCBmcm9tICcuLi91dGlsJztcblxuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuXG4vKipcbiAqIFRyYW5zcG9zZXMgdGhlIGB0Zi5UZW5zb3JgLiBQZXJtdXRlcyB0aGUgZGltZW5zaW9ucyBhY2NvcmRpbmcgdG8gYHBlcm1gLlxuICpcbiAqIFRoZSByZXR1cm5lZCBgdGYuVGVuc29yYCdzIGRpbWVuc2lvbiBgaWAgd2lsbCBjb3JyZXNwb25kIHRvIHRoZSBpbnB1dFxuICogZGltZW5zaW9uIGBwZXJtW2ldYC4gSWYgYHBlcm1gIGlzIG5vdCBnaXZlbiwgaXQgaXMgc2V0IHRvIGBbbi0xLi4uMF1gLFxuICogd2hlcmUgYG5gIGlzIHRoZSByYW5rIG9mIHRoZSBpbnB1dCBgdGYuVGVuc29yYC4gSGVuY2UgYnkgZGVmYXVsdCwgdGhpc1xuICogb3BlcmF0aW9uIHBlcmZvcm1zIGEgcmVndWxhciBtYXRyaXggdHJhbnNwb3NlIG9uIDItRCBpbnB1dCBgdGYuVGVuc29yYHMuXG4gKlxuICogYGBganNcbiAqIGNvbnN0IGEgPSB0Zi50ZW5zb3IyZChbMSwgMiwgMywgNCwgNSwgNl0sIFsyLCAzXSk7XG4gKlxuICogYS50cmFuc3Bvc2UoKS5wc
mludCgpOyAgLy8gb3IgdGYudHJhbnNwb3NlKGEpXG4gKiBgYGBcbiAqXG4gKiBAcGFyYW0geCBUaGUgdGVuc29yIHRvIHRyYW5zcG9zZS5cbiAqIEBwYXJhbSBwZXJtIFRoZSBwZXJtdXRhdGlvbiBvZiB0aGUgZGltZW5zaW9ucyBvZiBhLlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdPcGVyYXRpb25zJywgc3ViaGVhZGluZzogJ01hdHJpY2VzJ31cbiAqL1xuZnVuY3Rpb24gdHJhbnNwb3NlXzxUIGV4dGVuZHMgVGVuc29yPih4OiBUfFRlbnNvckxpa2UsIHBlcm0/OiBudW1iZXJbXSk6IFQge1xuICBjb25zdCAkeCA9IGNvbnZlcnRUb1RlbnNvcih4LCAneCcsICd0cmFuc3Bvc2UnKTtcblxuICBpZiAocGVybSA9PSBudWxsKSB7XG4gICAgcGVybSA9ICR4LnNoYXBlLm1hcCgocywgaSkgPT4gaSkucmV2ZXJzZSgpO1xuICB9XG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgJHgucmFuayA9PT0gcGVybS5sZW5ndGgsXG4gICAgICAoKSA9PiBgRXJyb3IgaW4gdHJhbnNwb3NlOiByYW5rIG9mIGlucHV0ICR7JHgucmFua30gYCArXG4gICAgICAgICAgYG11c3QgbWF0Y2ggbGVuZ3RoIG9mIHBlcm0gJHtwZXJtfS5gKTtcbiAgcGVybS5mb3JFYWNoKGF4aXMgPT4ge1xuICAgIHV0aWwuYXNzZXJ0KFxuICAgICAgICBheGlzID49IDAgJiYgYXhpcyA8ICR4LnJhbmssXG4gICAgICAgICgpID0+IGBBbGwgZW50cmllcyBpbiAncGVybScgbXVzdCBiZSBiZXR3ZWVuIDAgYW5kICR7JHgucmFuayAtIDF9YCArXG4gICAgICAgICAgICBgIGJ1dCBnb3QgJHtwZXJtfWApO1xuICB9KTtcblxuICBpZiAoJHgucmFuayA8PSAxKSB7XG4gICAgcmV0dXJuICR4LmNsb25lKCk7XG4gIH1cblxuICBjb25zdCBpbnB1dHM6IFRyYW5zcG9zZUlucHV0cyA9IHt4OiAkeH07XG4gIGNvbnN0IGF0dHJzOiBUcmFuc3Bvc2VBdHRycyA9IHtwZXJtfTtcblxuICByZXR1cm4gRU5HSU5FLnJ1bktlcm5lbChcbiAgICAgIFRyYW5zcG9zZSwgaW5wdXRzIGFzIHt9IGFzIE5hbWVkVGVuc29yTWFwLCBhdHRycyBhcyB7fSBhcyBOYW1lZEF0dHJNYXApO1xufVxuXG5leHBvcnQgY29uc3QgdHJhbnNwb3NlID0gb3Aoe3RyYW5zcG9zZV99KTtcbiJdfQ== |
\ | No newline at end of file |