1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
/**
 * Specs for `tf.reverse3d`.
 *
 * Every value-checking spec reverses the same 2x3x4 fixture, whose entries are
 * simply their flat index:
 *
 *   [
 *     [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
 *     [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
 *   ]
 *
 * so each expected array can be read off by hand: axis 0 swaps the two outer
 * blocks, axis 1 reverses the row order inside a block, and axis 2 reverses
 * the four values inside each row.
 */
describeWithFlags('reverse3d', ALL_ENVS, () => {
  const shape = [2, 3, 4];
  const data = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
    12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
  ];

  /**
   * Reverses the shared fixture along `axes` and checks that the shape is
   * preserved and the flattened values equal `expected`.
   *
   * @param {number[]} axes - axes passed straight to `tf.reverse3d`.
   * @param {number[]} expected - expected flattened output values.
   * @returns {Promise<void>} resolves once the data has been read back.
   */
  async function expectReversed(axes, expected) {
    const input = tf.tensor3d(data, shape);
    const result = tf.reverse3d(input, axes);
    expect(result.shape).toEqual(input.shape);
    expectArraysClose(await result.data(), expected);
  }

  it('reverse a 3D array at axis [0]', () => expectReversed([0], [
    12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
  ]));

  it('reverse a 3D array at axis [1]', () => expectReversed([1], [
    8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3,
    20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15
  ]));

  it('reverse a 3D array at axis [2]', () => expectReversed([2], [
    3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
    15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20
  ]));

  it('reverse a 3D array at axis [0, 1]', () => expectReversed([0, 1], [
    20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15,
    8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3
  ]));

  it('reverse a 3D array at axis [0, 2]', () => expectReversed([0, 2], [
    15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20,
    3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8
  ]));

  it('reverse a 3D array at axis [1, 2]', () => expectReversed([1, 2], [
    11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
    23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12
  ]));

  it('reverse a 3D array at all axes [0, 1, 2]', () => expectReversed(
    [0, 1, 2], [
      23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12,
      11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
    ]));

  // -3 is the smallest valid negative axis for rank 3 (-4 is rejected below)
  // and -1 is the last axis, so this must match the axis [0, 2] case exactly.
  it('reverse a 3D array at negative axes [-3, -1]', () => expectReversed(
    [-3, -1], [
      15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20,
      3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8
    ]));

  it('throws error with invalid input', () => {
    // A rank-2 tensor must be rejected by the rank-3 entry point.
    const x = tf.tensor2d([1, 20, 300, 4], [1, 4]);
    expect(() => tf.reverse3d(x, [1])).toThrowError();
  });

  it('throws error with invalid axis param', () => {
    const x = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]);
    // Valid axes for rank 3 are -3..2; one step past each end must throw.
    expect(() => tf.reverse3d(x, [3])).toThrowError();
    expect(() => tf.reverse3d(x, [-4])).toThrowError();
  });

  it('throws error with non integer axis param', () => {
    const x = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]);
    expect(() => tf.reverse3d(x, [0.5])).toThrowError();
  });

  it('accepts a tensor-like object', async () => {
    // Nested arrays of depth 3 are converted to a [2, 3, 1] tensor.
    const input = [[[1], [2], [3]], [[4], [5], [6]]];
    const result = tf.reverse3d(input, [0]);
    expect(result.shape).toEqual([2, 3, 1]);
    expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]);
  });
});
|
113 |
|
\ | No newline at end of file |