UNPKG

28.6 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2017 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
20import * as axis_util from './axis_util';
/**
 * Table-driven checks for axis_util.combineLocations.
 *
 * Judging by the expectations below, the entries of the second argument end
 * up at the positions listed in the third argument, while the entries of the
 * first argument fill the remaining positions in order.
 */
describe('axis_util combineLocations', () => {
  // Each row: [test name, first location, second location, axes, expected].
  const scenarios = [
    ['rank 4, reduce last 2 dims', [4, 1], [3, 7], [2, 3], [4, 1, 3, 7]],
    ['rank 4, reduce first two dims', [4, 1], [3, 7], [0, 1], [3, 7, 4, 1]],
    ['rank 4, reduce 1st and 3rd dims', [4, 1], [3, 7], [0, 2], [3, 4, 7, 1]],
    ['rank 4, reduce 1st and 4th dims', [4, 1], [3, 7], [0, 3], [3, 4, 1, 7]],
    ['rank 3, reduce all dims', [], [3, 7, 1], [0, 1, 2], [3, 7, 1]],
    ['rank 2, reduce last dim', [3], [5], [1], [3, 5]],
    ['rank 2, reduce first dim', [3], [5], [0], [5, 3]],
  ];

  for (const [title, keptLoc, reducedLoc, axes, expected] of scenarios) {
    it(title, () => {
      const combined = axis_util.combineLocations(keptLoc, reducedLoc, axes);
      expect(combined).toEqual(expected);
    });
  }
});
/**
 * Table-driven checks for axis_util.computeOutAndReduceShapes.
 *
 * The function is called with a shape and a list of axes; the tests unpack
 * its result into two arrays and compare each against the expected value.
 */
describe('axis_util computeOutAndReduceShapes', () => {
  // Each row: [test name, input shape, axes, expected 1st result, expected 2nd result].
  const scenarios = [
    ['rank 4, reduce all dims', [3, 7, 2, 4], [0, 1, 2, 3], [], [3, 7, 2, 4]],
    ['rank 4, reduce last 2 dims', [3, 7, 2, 4], [2, 3], [3, 7], [2, 4]],
    ['rank 4, reduce first 2 dims', [3, 7, 2, 4], [0, 1], [2, 4], [3, 7]],
    ['rank 4, reduce last 3 dims', [3, 7, 2, 4], [1, 2, 3], [3], [7, 2, 4]],
    ['rank 4, reduce 1st and 3rd dims', [3, 7, 2, 4], [0, 2], [7, 4], [3, 2]],
    ['rank 3, reduce all dims', [3, 7, 2], [0, 1, 2], [], [3, 7, 2]],
  ];

  for (const [title, shape, axes, wantOut, wantReduce] of scenarios) {
    it(title, () => {
      const [outShape, reduceShape] =
          axis_util.computeOutAndReduceShapes(shape, axes);
      expect(outShape).toEqual(wantOut);
      expect(reduceShape).toEqual(wantReduce);
    });
  }
});
/**
 * Table-driven checks for axis_util.axesAreInnerMostDims.
 *
 * Every case uses rank 4. The expectations are true when the axes form the
 * trailing run of dimensions and false as soon as a trailing dimension is
 * missing from the list.
 */
describe('axis_util axesAreInnerMostDims', () => {
  const RANK = 4;
  // Each row: [test name, axes, expected boolean].
  const scenarios = [
    ['rank 4, reduce last dim', [3], true],
    ['rank 4, reduce last 2 dims', [2, 3], true],
    ['rank 4, reduce last 3 dims', [1, 2, 3], true],
    ['rank 4, reduce all dims', [0, 1, 2, 3], true],
    ['rank 4, reduce all but 2nd', [0, 2, 3], false],
    ['rank 4, reduce all but 3rd', [0, 1, 3], false],
    ['rank 4, reduce all but last', [0, 1, 2], false],
  ];

  for (const [title, axes, expected] of scenarios) {
    it(title, () => {
      expect(axis_util.axesAreInnerMostDims(axes, RANK)).toBe(expected);
    });
  }
});
/**
 * Table-driven checks for axis_util.expandShapeToKeepDim.
 *
 * In each expectation the returned shape has a 1 at every listed axis and
 * the entries of the input shape in the remaining positions.
 */
describe('axis_util expandShapeToKeepDim', () => {
  // Each row: [test name, input shape, axes, expected expanded shape].
  const scenarios = [
    ['2d -> 1d axis=0', [2], [0], [1, 2]],
    ['2d -> 1d axis=1', [4], [1], [4, 1]],
    ['3d -> 1d axis=1,2', [7], [1, 2], [7, 1, 1]],
    ['3d -> 2d axis=1', [7, 3], [1], [7, 1, 3]],
  ];

  for (const [title, inputShape, axes, expected] of scenarios) {
    it(title, () => {
      const expanded = axis_util.expandShapeToKeepDim(inputShape, axes);
      expect(expanded).toEqual(expected);
    });
  }
});
/**
 * Tests for axis_util.getAxesPermutation(axes, rank).
 *
 * The suite label names the function actually under test, matching how the
 * other suites in this file are labelled.
 *
 * Expectations covered here:
 *  - null when the axes are all dimensions, no dimensions, or already the
 *    trailing dimensions;
 *  - otherwise an array that lists the remaining dimensions first and the
 *    given axes last.
 */
describe('axis_util getAxesPermutation', () => {
  it('all axes, no perm is needed', () => {
    const perm = axis_util.getAxesPermutation([0, 1, 2], 3);
    expect(perm).toBeNull();
  });
  it('no axes, no perm is needed', () => {
    const perm = axis_util.getAxesPermutation([], 3);
    expect(perm).toBeNull();
  });
  it('inner most 2 axes, no perm is needed', () => {
    const perm = axis_util.getAxesPermutation([2, 3], 4);
    expect(perm).toBeNull();
  });
  it('outer most axis, perm is needed', () => {
    // Axis 0 is expected to be moved behind dimensions 1, 2 and 3.
    const perm = axis_util.getAxesPermutation([0], 4);
    expect(perm).toEqual([1, 2, 3, 0]);
  });
  it('2 outer most axes, perm is needed', () => {
    // Axes 0 and 1 are expected to be moved behind dimensions 2 and 3.
    const perm = axis_util.getAxesPermutation([0, 1], 4);
    expect(perm).toEqual([2, 3, 0, 1]);
  });
});
/**
 * Checks for axis_util.getUndoAxesPermutation, registered through
 * describeWithFlags with ALL_ENVS.
 *
 * The first three cases compare the returned permutation against a literal;
 * the last one transposes a random tensor with a permutation and then with
 * the returned undo permutation, and expects the original values back.
 */
describeWithFlags('axis_util getUndoAxesPermutation', ALL_ENVS, () => {
  // Each row: [test name, permutation, expected undo permutation].
  const scenarios = [
    ['4d axes', [2, 0, 1, 3], [1, 2, 0, 3]],
    ['3d axes, no perm', [0, 1, 2], [0, 1, 2]],
    ['3d axes, complete flip', [2, 1, 0], [2, 1, 0]],
  ];

  for (const [title, perm, expectedUndo] of scenarios) {
    it(title, () => {
      expect(axis_util.getUndoAxesPermutation(perm)).toEqual(expectedUndo);
    });
  }

  it('4d array with values', async () => {
    const perm = [2, 0, 1, 3];
    const source = tf.randomNormal([2, 3, 4, 5]);
    const permuted = tf.transpose(source, perm);
    const restored =
        tf.transpose(permuted, axis_util.getUndoAxesPermutation(perm));
    // Read the original values first, then the round-tripped ones.
    const sourceValues = await source.data();
    const restoredValues = await restored.data();
    expectArraysClose(sourceValues, restoredValues);
  });
});
175//# sourceMappingURL=data:application/json;base64,
\No newline at end of file