1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
/**
 * Tests for tf.transpose across every registered backend (ALL_ENVS).
 *
 * Convention used by the expected values below: output axis i is input
 * axis perm[i], and tensors are laid out row-major, so each expected flat
 * array lists the input elements in the order the permuted indices visit
 * them.
 */
describeWithFlags('transpose', ALL_ENVS, () => {
  // Produces [1, 2, ..., n]. Filling test tensors with distinct values makes
  // every expected output readable as "which input element landed here".
  const oneThrough = (n) => Array.from({length: n}, (_, i) => i + 1);

  // Rank 0 and rank 1: transposing without a perm leaves the data as-is.
  it('of scalar is no-op', async () => {
    const a = tf.scalar(3);
    expectArraysClose(await tf.transpose(a).data(), [3]);
  });

  it('of 1D is no-op', async () => {
    const a = tf.tensor1d([1, 2, 3]);
    expectArraysClose(await tf.transpose(a).data(), [1, 2, 3]);
  });

  // Invalid perms: wrong length for the tensor's rank, or an axis index that
  // does not exist in the tensor.
  it('of scalar with perm of incorrect rank throws error', () => {
    const a = tf.scalar(3);
    const perm = [0];
    expect(() => tf.transpose(a, perm)).toThrowError();
  });

  it('of 1d with perm out of bounds throws error', () => {
    const a = tf.tensor1d([1, 2, 3]);
    const perm = [1];
    expect(() => tf.transpose(a, perm)).toThrowError();
  });

  it('of 1d with perm incorrect rank throws error', () => {
    const a = tf.tensor1d([1, 2, 3]);
    const perm = [0, 0];
    expect(() => tf.transpose(a, perm)).toThrowError();
  });

  // Rank 2.
  it('2D (no change)', async () => {
    const t = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]);
    // The identity perm must preserve both shape and values.
    const t2 = tf.transpose(t, [0, 1]);
    expect(t2.shape).toEqual(t.shape);
    expectArraysClose(await t2.array(), await t.array());
  });

  it('2D (transpose)', async () => {
    const t = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]);
    const t2 = tf.transpose(t, [1, 0]);
    expect(t2.shape).toEqual([4, 2]);
    expectArraysClose(await t2.data(), [1, 3, 11, 33, 2, 4, 22, 44]);
  });

  it('2D, shape has ones', async () => {
    const t = tf.tensor2d([1, 2, 3, 4], [1, 4]);
    // With a size-1 axis the flat order is unchanged; only the shape flips.
    const t2 = tf.transpose(t, [1, 0]);
    expect(t2.shape).toEqual([4, 1]);
    expectArraysClose(await t2.data(), [1, 2, 3, 4]);
  });

  // Rank 3.
  it('3D [r, c, d] => [d, r, c]', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const t2 = tf.transpose(t, [2, 0, 1]);
    expect(t2.shape).toEqual([2, 2, 2]);
    expectArraysClose(await t2.data(), [1, 2, 3, 4, 11, 22, 33, 44]);
  });

  it('3D [r, c, d] => [d, c, r]', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const t2 = tf.transpose(t, [2, 1, 0]);
    expect(t2.shape).toEqual([2, 2, 2]);
    expectArraysClose(await t2.data(), [1, 3, 2, 4, 11, 33, 22, 44]);
  });

  it('3D [r, c, d] => [d, r, c], shape has ones', async () => {
    const perm = [2, 0, 1];

    // The size-1 axis is placed in each of the three positions in turn.
    const t = tf.tensor3d([1, 2, 3, 4], [2, 1, 2]);
    const tt = tf.transpose(t, perm);
    expect(tt.shape).toEqual([2, 2, 1]);
    expectArraysClose(await tt.data(), [1, 3, 2, 4]);

    const t2 = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
    const tt2 = tf.transpose(t2, perm);
    expect(tt2.shape).toEqual([1, 2, 2]);
    expectArraysClose(await tt2.data(), [1, 2, 3, 4]);

    const t3 = tf.tensor3d([1, 2, 3, 4], [1, 2, 2]);
    const tt3 = tf.transpose(t3, perm);
    expect(tt3.shape).toEqual([2, 1, 2]);
    expectArraysClose(await tt3.data(), [1, 3, 2, 4]);
  });

  it('3D [r, c, d] => [r, d, c]', async () => {
    const perm = [0, 2, 1];
    const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
    const tt = tf.transpose(t, perm);
    expect(tt.shape).toEqual([2, 2, 2]);
    expectArraysClose(await tt.data(), [1, 3, 2, 4, 5, 7, 6, 8]);
  });

  it('5D [r, c, d, e, f] => [r, c, d, f, e]', async () => {
    const t = tf.tensor5d(oneThrough(32), [2, 2, 2, 2, 2]);
    // Swapping only the two innermost axes transposes each trailing 2x2 block.
    const t2 = tf.transpose(t, [0, 1, 2, 4, 3]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2]);
    expectArraysClose(await t2.data(), [
      1,  3,  2,  4,  5,  7,  6,  8,  9,  11, 10, 12, 13, 15, 14, 16,
      17, 19, 18, 20, 21, 23, 22, 24, 25, 27, 26, 28, 29, 31, 30, 32
    ]);
  });

  // Rank 4.
  it('4D [r, c, d, e] => [c, r, d, e]', async () => {
    const t = tf.tensor4d(oneThrough(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [1, 0, 2, 3]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 2, 3, 4, 9, 10, 11, 12, 5, 6, 7, 8, 13, 14, 15, 16]);
  });

  it('4D [r, c, d, e] => [c, r, e, d]', async () => {
    const t = tf.tensor4d(oneThrough(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [1, 0, 3, 2]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 3, 2, 4, 9, 11, 10, 12, 5, 7, 6, 8, 13, 15, 14, 16]);
  });

  it('4D [r, c, d, e] => [e, r, c, d]', async () => {
    const t = tf.tensor4d(oneThrough(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [3, 0, 1, 2]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 3, 5, 7, 9, 11, 13, 15, 2, 4, 6, 8, 10, 12, 14, 16]);
  });

  it('4D [r, c, d, e] => [d, c, e, r]', async () => {
    const t = tf.tensor4d(oneThrough(16), [2, 2, 2, 2]);
    const t2 = tf.transpose(t, [2, 1, 3, 0]);
    expect(t2.shape).toEqual([2, 2, 2, 2]);
    expectArraysClose(
        await t2.data(),
        [1, 9, 2, 10, 5, 13, 6, 14, 3, 11, 4, 12, 7, 15, 8, 16]);
  });

  // Rank 5 and rank 6.
  it('5D [r, c, d, e, f] => [c, r, d, e, f]', async () => {
    const t = tf.tensor5d(oneThrough(32), [2, 2, 2, 2, 2]);
    // Swapping the two outermost axes reorders whole 8-element blocks.
    const t2 = tf.transpose(t, [1, 0, 2, 3, 4]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2]);
    expectArraysClose(await t2.data(), [
      1, 2,  3,  4,  5,  6,  7,  8,  17, 18, 19, 20, 21, 22, 23, 24,
      9, 10, 11, 12, 13, 14, 15, 16, 25, 26, 27, 28, 29, 30, 31, 32
    ]);
  });

  it('6D [r, c, d, e, f, g] => [r, c, d, e, g, f]', async () => {
    const t = tf.tensor6d(oneThrough(64), [2, 2, 2, 2, 2, 2]);
    const t2 = tf.transpose(t, [0, 1, 2, 3, 5, 4]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2, 2]);
    expectArraysClose(await t2.data(), [
      1,  3,  2,  4,  5,  7,  6,  8,  9,  11, 10, 12, 13, 15, 14, 16,
      17, 19, 18, 20, 21, 23, 22, 24, 25, 27, 26, 28, 29, 31, 30, 32,
      33, 35, 34, 36, 37, 39, 38, 40, 41, 43, 42, 44, 45, 47, 46, 48,
      49, 51, 50, 52, 53, 55, 54, 56, 57, 59, 58, 60, 61, 63, 62, 64
    ]);
  });

  it('6D [r, c, d, e, f, g] => [c, r, d, e, f, g]', async () => {
    const t = tf.tensor6d(oneThrough(64), [2, 2, 2, 2, 2, 2]);
    const t2 = tf.transpose(t, [1, 0, 2, 3, 4, 5]);
    expect(t2.shape).toEqual([2, 2, 2, 2, 2, 2]);
    expectArraysClose(await t2.data(), [
      1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16,
      33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
      17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
      49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64
    ]);
  });

  // Gradients. The gradient of a transpose is dy transposed by the inverse
  // perm; [2, 1, 0] is its own inverse, so dt[r][c][d] === dy[d][c][r].
  it('gradient 3D [r, c, d] => [d, c, r]', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const perm = [2, 1, 0];
    const dy =
        tf.tensor3d([111, 211, 121, 221, 112, 212, 122, 222], [2, 2, 2]);
    const dt = tf.grad(x => x.transpose(perm))(t, dy);
    expect(dt.shape).toEqual(t.shape);
    expect(dt.dtype).toEqual('float32');
    expectArraysClose(
        await dt.data(), [111, 112, 121, 122, 211, 212, 221, 222]);
  });

  it('gradient with clones', async () => {
    const t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const perm = [2, 1, 0];
    const dy =
        tf.tensor3d([111, 211, 121, 221, 112, 212, 122, 222], [2, 2, 2]);
    // Wrapping the op in clone() on both sides must not alter the gradient.
    const dt = tf.grad(x => x.clone().transpose(perm).clone())(t, dy);
    expect(dt.shape).toEqual(t.shape);
    expect(dt.dtype).toEqual('float32');
    expectArraysClose(
        await dt.data(), [111, 112, 121, 122, 211, 212, 221, 222]);
  });

  // Input validation and conversion.
  it('throws when passed a non-tensor', () => {
    expect(() => tf.transpose({}))
        .toThrowError(/Argument 'x' passed to 'transpose' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A nested JS array is accepted in place of a Tensor.
    const t = [[1, 11, 2, 22], [3, 33, 4, 44]];
    const res = tf.transpose(t, [1, 0]);
    expect(res.shape).toEqual([4, 2]);
    expectArraysClose(await res.data(), [1, 3, 11, 33, 2, 4, 22, 44]);
  });
});
|
189 |
|
\ | No newline at end of file |