UNPKG

25.5 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
/**
 * Tests for `tf.expandDims` / `Tensor.expandDims`, run in every environment
 * listed in ALL_ENVS. Covers the default axis, explicit positive and negative
 * axes, out-of-range axes at both ends of the valid interval
 * [-(rank + 1), rank], string tensors, tensor-like inputs, zero-size
 * dimensions, and tensor / data-id bookkeeping.
 */
describeWithFlags('expandDims', ALL_ENVS, () => {
  it('scalar, default axis is 0', async () => {
    const res = tf.scalar(1).expandDims();
    expect(res.shape).toEqual([1]);
    expectArraysClose(await res.data(), [1]);
  });

  it('scalar, axis is out of bounds throws error', () => {
    // For a rank-0 tensor the only non-negative valid axis is 0.
    const f = () => tf.scalar(1).expandDims(1);
    expect(f).toThrowError();
  });

  it('1d, axis=-3', () => {
    // Valid axes for rank 1 are [-2, 1]; -3 is just below that range.
    expect(() => {
      tf.tensor1d([1, 2, 3]).expandDims(-3);
    }).toThrowError();
  });

  it('1d, axis=-2', async () => {
    const res = tf.tensor1d([1, 2, 3]).expandDims(-2 /* axis */);
    expect(res.shape).toEqual([1, 3]);
    expectArraysClose(await res.data(), [1, 2, 3]);
  });

  it('1d, axis=-1', async () => {
    const res = tf.tensor1d([1, 2, 3]).expandDims(-1 /* axis */);
    expect(res.shape).toEqual([3, 1]);
    expectArraysClose(await res.data(), [1, 2, 3]);
  });

  it('1d, axis=0', async () => {
    const res = tf.tensor1d([1, 2, 3]).expandDims(0 /* axis */);
    expect(res.shape).toEqual([1, 3]);
    expectArraysClose(await res.data(), [1, 2, 3]);
  });

  it('1d, axis=1', async () => {
    const res = tf.tensor1d([1, 2, 3]).expandDims(1 /* axis */);
    expect(res.shape).toEqual([3, 1]);
    expectArraysClose(await res.data(), [1, 2, 3]);
  });

  it('1d, axis=2 throws', () => {
    // 2 is just above the valid range [-2, 1] for a rank-1 tensor.
    expect(() => {
      tf.tensor1d([1, 2, 3]).expandDims(2 /* axis */);
    }).toThrowError();
  });

  it('2d, axis=-4', () => {
    // Valid axes for rank 2 are [-3, 2]; -4 is just below that range.
    expect(() => {
      tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-4 /* axis */);
    }).toThrowError();
  });

  it('2d, axis=-3', async () => {
    const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-3 /* axis */);
    expect(res.shape).toEqual([1, 3, 2]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('2d, axis=-2', async () => {
    const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-2 /* axis */);
    expect(res.shape).toEqual([3, 1, 2]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('2d, axis=-1', async () => {
    const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-1 /* axis */);
    expect(res.shape).toEqual([3, 2, 1]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('2d, axis=0', async () => {
    const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(0 /* axis */);
    expect(res.shape).toEqual([1, 3, 2]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('2d, axis=1', async () => {
    const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(1 /* axis */);
    expect(res.shape).toEqual([3, 1, 2]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('2d, axis=2', async () => {
    const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(2 /* axis */);
    expect(res.shape).toEqual([3, 2, 1]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('2d, axis=3 throws', () => {
    // 3 is just above the valid range [-3, 2] for a rank-2 tensor.
    expect(() => {
      tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(3 /* axis */);
    }).toThrowError();
  });

  it('4d, axis=0', async () => {
    // No axis argument: exercises the default axis on a higher-rank tensor.
    const res = tf.tensor4d([[[[4]]]]).expandDims();
    expect(res.shape).toEqual([1, 1, 1, 1, 1]);
    expectArraysClose(await res.data(), [4]);
  });

  it('1d string tensor', async () => {
    const t = tf.tensor(['hello', 'world']);
    const res = t.expandDims();
    expect(res.shape).toEqual([1, 2]);
    expect(res.dtype).toBe('string');
    // Compare exactly. expectArraysClose is a numeric-tolerance helper and
    // does not meaningfully compare string elements (NOTE(review): in tfjs
    // test_util two non-numeric values are treated as "close").
    expect(await res.data()).toEqual(['hello', 'world']);
  });

  it('2d string tensor, axis=1', async () => {
    const t = tf.tensor([['a', 'b'], ['c', 'd']]);
    const res = t.expandDims(1);
    expect(res.shape).toEqual([2, 1, 2]);
    expect(res.dtype).toBe('string');
    // Exact comparison for the same reason as the 1d string case.
    expect(await res.data()).toEqual(['a', 'b', 'c', 'd']);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.expandDims({}))
      .toThrowError(/Argument 'x' passed to 'expandDims' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A bare number is converted to a scalar before the axis is inserted.
    const res = tf.expandDims(7);
    expect(res.shape).toEqual([1]);
    expectArraysClose(await res.data(), [7]);
  });

  it('works with 0 in shape', async () => {
    const a = tf.tensor2d([], [0, 3]);

    const res = a.expandDims();
    expect(res.shape).toEqual([1, 0, 3]);
    expectArraysClose(await res.data(), []);

    const res2 = a.expandDims(1);
    expect(res2.shape).toEqual([0, 1, 3]);
    expectArraysClose(await res2.data(), []);

    const res3 = a.expandDims(2);
    expect(res3.shape).toEqual([0, 3, 1]);
    expectArraysClose(await res3.data(), []);
  });

  it('ensure no memory leak', async () => {
    // Snapshot both the engine's tensor count and the backend's data-id
    // count so a leak at either level is detected.
    const numTensorsBefore = tf.memory().numTensors;
    const numDataIdBefore = tf.engine().backend.numDataIds();

    const t = tf.scalar(1);
    const res = t.expandDims();
    expect(res.shape).toEqual([1]);
    expectArraysClose(await res.data(), [1]);
    res.dispose();
    t.dispose();

    const numTensorsAfter = tf.memory().numTensors;
    const numDataIdAfter = tf.engine().backend.numDataIds();
    expect(numTensorsAfter).toBe(numTensorsBefore);
    expect(numDataIdAfter).toBe(numDataIdBefore);
  });
});
143//# sourceMappingURL=data:application/json;base64,
\No newline at end of file