UNPKG

9.34 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
/**
 * Test suite for `tf.transpose`, run in every environment listed in ALL_ENVS.
 *
 * Covers:
 *  - rank 0/1 no-ops and invalid `perm` arguments;
 *  - value/shape checks for ranks 2 through 6;
 *  - the gradient of transpose;
 *  - argument validation (non-tensor input, tensor-like nested arrays).
 */
describeWithFlags('transpose', ALL_ENVS, () => {
  /**
   * Builds the sequence [1, 2, ..., size].
   *
   * The higher-rank tests fill their input tensors with these distinct values
   * so every element's destination after the permutation can be traced in the
   * expected output.
   *
   * @param {number} size - Number of elements to generate.
   * @returns {number[]} The values 1 through `size`, in order.
   */
  const oneBasedRange = (size) => Array.from({ length: size }, (_, i) => i + 1);

  // Without an explicit perm, transposing a scalar or a 1D tensor is expected
  // to return the input values unchanged.
  it('of scalar is no-op', async () => {
    const a = tf.scalar(3);
    expectArraysClose(await tf.transpose(a).data(), [3]);
  });

  it('of 1D is no-op', async () => {
    const a = tf.tensor1d([1, 2, 3]);
    expectArraysClose(await tf.transpose(a).data(), [1, 2, 3]);
  });

  it('of scalar with perm of incorrect rank throws error', () => {
    const a = tf.scalar(3);
    const perm = [0]; // Should be empty array.
    expect(() => tf.transpose(a, perm)).toThrowError();
  });

  it('of 1d with perm out of bounds throws error', () => {
    const a = tf.tensor1d([1, 2, 3]);
    const perm = [1]; // The only valid axis of a 1D tensor is 0.
    expect(() => tf.transpose(a, perm)).toThrowError();
  });

  it('of 1d with perm incorrect rank throws error', () => {
    const a = tf.tensor1d([1, 2, 3]);
    const perm = [0, 0]; // Should be of length 1.
    expect(() => tf.transpose(a, perm)).toThrowError();
  });

  it('2D (no change)', async () => {
    const t = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]);
    // The identity permutation must preserve both shape and values.
    const t2 = tf.transpose(t, [0, 1]);
    expect(t2.shape).toEqual(t.shape);
    expectArraysClose(await t2.array(), await t.array());
  });

  it('2D (transpose)', async () => {
    const t = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]);
    const t2 = tf.transpose(t, [1, 0]);
    expect(t2.shape).toEqual([4, 2]);
    expectArraysClose(await t2.data(), [1, 3, 11, 33, 2, 4, 22, 44]);
  });

  it('2D, shape has ones', async () => {
    const t = tf.tensor2d([1, 2, 3, 4], [1, 4]);
    const t2 = tf.transpose(t, [1, 0]);
    expect(t2.shape).toEqual([4, 1]);
    // Swapping with a size-1 axis does not reorder the flat data.
    expectArraysClose(await t2.data(), [1, 2, 3, 4]);
  });

  it('3D [r, c, d] => [d, r, c]', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const t2 = tf.transpose(t, [2, 0, 1]);
    expect(t2.shape).toEqual([2, 2, 2]);
    expectArraysClose(await t2.data(), [1, 2, 3, 4, 11, 22, 33, 44]);
  });

  it('3D [r, c, d] => [d, c, r]', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const t2 = tf.transpose(t, [2, 1, 0]);
    expect(t2.shape).toEqual([2, 2, 2]);
    expectArraysClose(await t2.data(), [1, 3, 2, 4, 11, 33, 22, 44]);
  });

  it('3D [r, c, d] => [d, r, c], shape has ones', async () => {
    const perm = [2, 0, 1];

    // Size-1 axis in the middle.
    const t = tf.tensor3d([1, 2, 3, 4], [2, 1, 2]);
    const tt = tf.transpose(t, perm);
    expect(tt.shape).toEqual([2, 2, 1]);
    expectArraysClose(await tt.data(), [1, 3, 2, 4]);

    // Size-1 axis last (it becomes the leading axis).
    const t2 = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
    const tt2 = tf.transpose(t2, perm);
    expect(tt2.shape).toEqual([1, 2, 2]);
    expectArraysClose(await tt2.data(), [1, 2, 3, 4]);

    // Size-1 axis first.
    const t3 = tf.tensor3d([1, 2, 3, 4], [1, 2, 2]);
    const tt3 = tf.transpose(t3, perm);
    expect(tt3.shape).toEqual([2, 1, 2]);
    expectArraysClose(await tt3.data(), [1, 3, 2, 4]);
  });

  it('3D [r, c, d] => [r, d, c]', async () => {
    const perm = [0, 2, 1];
    const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
    const tt = tf.transpose(t, perm);
    expect(tt.shape).toEqual([2, 2, 2]);
    expectArraysClose(await tt.data(), [1, 3, 2, 4, 5, 7, 6, 8]);
  });

  it('4D [r, c, d, e] => [c, r, d, e]', async () => {
    const t = tf.tensor4d(oneBasedRange(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [1, 0, 2, 3]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 2, 3, 4, 9, 10, 11, 12, 5, 6, 7, 8, 13, 14, 15, 16]);
  });

  it('4D [r, c, d, e] => [c, r, e, d]', async () => {
    const t = tf.tensor4d(oneBasedRange(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [1, 0, 3, 2]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 3, 2, 4, 9, 11, 10, 12, 5, 7, 6, 8, 13, 15, 14, 16]);
  });

  it('4D [r, c, d, e] => [e, r, c, d]', async () => {
    const t = tf.tensor4d(oneBasedRange(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [3, 0, 1, 2]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 3, 5, 7, 9, 11, 13, 15, 2, 4, 6, 8, 10, 12, 14, 16]);
  });

  it('4D [r, c, d, e] => [d, c, e, r]', async () => {
    const t = tf.tensor4d(oneBasedRange(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [2, 1, 3, 0]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 9, 2, 10, 5, 13, 6, 14, 3, 11, 4, 12, 7, 15, 8, 16]);
  });

  it('5D [r, c, d, e, f] => [r, c, d, f, e]', async () => {
    const t = tf.tensor5d(oneBasedRange(32), [2, 2, 2, 2, 2]);
    const t2 = tf.transpose(t, [0, 1, 2, 4, 3]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2]);
    // Swapping the two innermost axes transposes each 2x2 block in place.
    expectArraysClose(await t2.data(), [
      1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16,
      17, 19, 18, 20, 21, 23, 22, 24, 25, 27, 26, 28, 29, 31, 30, 32
    ]);
  });

  it('5D [r, c, d, e, f] => [c, r, d, e, f]', async () => {
    const t = tf.tensor5d(oneBasedRange(32), [2, 2, 2, 2, 2]);
    const t2 = tf.transpose(t, [1, 0, 2, 3, 4]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2]);
    // Swapping the two outermost axes swaps the middle two 8-element chunks.
    expectArraysClose(await t2.data(), [
      1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24,
      9, 10, 11, 12, 13, 14, 15, 16, 25, 26, 27, 28, 29, 30, 31, 32
    ]);
  });

  it('6D [r, c, d, e, f, g] => [r, c, d, e, g, f]', async () => {
    const t = tf.tensor6d(oneBasedRange(64), [2, 2, 2, 2, 2, 2]);
    const t2 = tf.transpose(t, [0, 1, 2, 3, 5, 4]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2, 2]);
    expectArraysClose(await t2.data(), [
      1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16,
      17, 19, 18, 20, 21, 23, 22, 24, 25, 27, 26, 28, 29, 31, 30, 32,
      33, 35, 34, 36, 37, 39, 38, 40, 41, 43, 42, 44, 45, 47, 46, 48,
      49, 51, 50, 52, 53, 55, 54, 56, 57, 59, 58, 60, 61, 63, 62, 64
    ]);
  });

  it('6D [r, c, d, e, f, g] => [c, r, d, e, f, g]', async () => {
    const t = tf.tensor6d(oneBasedRange(64), [2, 2, 2, 2, 2, 2]);
    const t2 = tf.transpose(t, [1, 0, 2, 3, 4, 5]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2, 2]);
    expectArraysClose(await t2.data(), [
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
      33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
      17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
      49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64
    ]);
  });

  it('gradient 3D [r, c, d] => [d, c, r]', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const perm = [2, 1, 0];
    const dy = tf.tensor3d([111, 211, 121, 221, 112, 212, 122, 222], [2, 2, 2]);
    const dt = tf.grad((x) => x.transpose(perm))(t, dy);
    expect(dt.shape).toEqual(t.shape);
    expect(dt.dtype).toEqual('float32');
    // The expected gradient is dy with axes 0 and 2 swapped back.
    expectArraysClose(
        await dt.data(), [111, 112, 121, 122, 211, 212, 221, 222]);
  });

  it('gradient with clones', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const perm = [2, 1, 0];
    const dy = tf.tensor3d([111, 211, 121, 221, 112, 212, 122, 222], [2, 2, 2]);
    // Clones around the transpose must not change the resulting gradient.
    const dt = tf.grad((x) => x.clone().transpose(perm).clone())(t, dy);
    expect(dt.shape).toEqual(t.shape);
    expect(dt.dtype).toEqual('float32');
    expectArraysClose(
        await dt.data(), [111, 112, 121, 122, 211, 212, 221, 222]);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.transpose({}))
        .toThrowError(/Argument 'x' passed to 'transpose' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const t = [[1, 11, 2, 22], [3, 33, 4, 44]];
    const res = tf.transpose(t, [1, 0]);
    expect(res.shape).toEqual([4, 2]);
    expectArraysClose(await res.data(), [1, 3, 11, 33, 2, 4, 22, 44]);
  });
});
//# sourceMappingURL=transpose_test.js.map
\No newline at end of file