UNPKG

29.8 kB · JavaScript · View Raw
1/**
2 * @license
3 * Copyright 2021 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import * as tf from '../../index';
18import { ALL_ENVS, describeWithFlags } from '../../jasmine_util';
19import { expectArraysClose } from '../../test_util';
/**
 * Builds the 5x6 sparse fixture shared by the sparseReshape specs.
 *
 * @returns {{ind: *, val: number[], shape: number[]}} `ind` is an int32
 *     tf.tensor2d of shape [6, 2] holding the [row, col] coordinates of the
 *     non-zero entries, `val` holds the matching values, and `shape` is the
 *     dense shape [5, 6].
 */
function sparseTensorValue5x6() {
  const coordinates = [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]];
  return {
    ind: tf.tensor2d(coordinates, [coordinates.length, 2], 'int32'),
    // Each value spells out its own coordinate: 10 * row + col.
    val: coordinates.map(([row, col]) => 10 * row + col),
    shape: [5, 6],
  };
}
/**
 * Builds the 2x3x4 sparse fixture shared by the sparseReshape specs.
 *
 * @returns {{ind: *, val: number[], shape: number[]}} `ind` is an int32
 *     tf.tensor2d of shape [7, 3] holding the [i, j, k] coordinates of the
 *     non-zero entries, `val` holds the matching values, and `shape` is the
 *     dense shape [2, 3, 4].
 */
function sparseTensorValue2x3x4() {
  const denseShape = [2, 3, 4];
  const coordinates = [
    [0, 0, 1], [0, 1, 0], [0, 1, 2], [1, 0, 3], [1, 1, 1], [1, 1, 3],
    [1, 2, 2],
  ];
  return {
    // One index column per dense dimension.
    ind: tf.tensor2d(
        coordinates, [coordinates.length, denseShape.length], 'int32'),
    // Each value spells out its own coordinate: 100 * i + 10 * j + k.
    val: coordinates.map(([i, j, k]) => 100 * i + 10 * j + k),
    shape: denseShape,
  };
}
/**
 * Specs for tf.sparse.sparseReshape, run on every environment (ALL_ENVS).
 *
 * Each spec builds a fixture with sparseTensorValue5x6 or
 * sparseTensorValue2x3x4. It then checks one of three things:
 *   - the reshaped indices and/or output shape;
 *   - that data ids are released once all tensors are disposed;
 *   - the error raised for an invalid target shape.
 */
describeWithFlags('sparseReshape', ALL_ENVS, () => {
  // --- Output shape only -------------------------------------------------

  it('preserve static shape info', async () => {
    const sparseTensor = sparseTensorValue5x6();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [1, 5, 2, 3]);
    expectArraysClose(await result.outputShape.data(), [1, 5, 2, 3]);
  });

  it('preserve shape info with inferred dim', async () => {
    const sparseTensor = sparseTensorValue2x3x4();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [2, -1]);
    // The single -1 must resolve to the remaining 3 * 4 elements.
    expectArraysClose(await result.outputShape.data(), [2, 3 * 4]);
  });

  // --- Memory accounting -------------------------------------------------

  it('does not have memory leak.', async () => {
    const beforeDataIds = tf.engine().backend.numDataIds();

    const sparseTensor = sparseTensorValue5x6();
    const indices = sparseTensor.ind;
    const shape = tf.tensor1d(sparseTensor.shape, 'int32');
    const newShape = tf.tensor1d([1, 5, 2, 3], 'int32');
    const result = tf.sparse.sparseReshape(indices, shape, newShape);
    // The two reads are independent, so await them together.
    await Promise.all(
        [result.outputIndices.data(), result.outputShape.data()]);

    // The spec expects exactly five extra data ids. That matches the five
    // tensors created here and disposed below: indices, shape, newShape and
    // the two outputs.
    const afterResDataIds = tf.engine().backend.numDataIds();
    expect(afterResDataIds).toEqual(beforeDataIds + 5);

    indices.dispose();
    shape.dispose();
    newShape.dispose();
    result.outputIndices.dispose();
    result.outputShape.dispose();

    const afterDisposeDataIds = tf.engine().backend.numDataIds();
    expect(afterDisposeDataIds).toEqual(beforeDataIds);
  });

  // --- Invalid target shapes (synchronous throws, so no `async`) ---------

  it('throw error if more than one inferred dim', () => {
    const sparseTensor = sparseTensorValue2x3x4();
    expect(
        () => tf.sparse.sparseReshape(
            sparseTensor.ind, sparseTensor.shape, [-1, 2, -1]))
        .toThrowError(/only one output dimension may be -1/);
  });

  it('throw error if impossible new shape', () => {
    const sparseTensor = sparseTensorValue2x3x4();
    // 2 * 3 * 4 = 24 elements is not a multiple of 7.
    expect(
        () => tf.sparse.sparseReshape(
            sparseTensor.ind, sparseTensor.shape, [-1, 7]))
        .toThrowError(/multiple of 7/);
  });

  it('throw error if negative output dim', () => {
    const sparseTensor = sparseTensorValue2x3x4();
    expect(
        () => tf.sparse.sparseReshape(
            sparseTensor.ind, sparseTensor.shape, [1, -7]))
        .toThrowError('size 1 must be non-negative, not -7');
  });

  // Previously this spec reused the name 'throw error if negative output dim'.
  // It actually covers inferring -1 next to a zero-size dimension.
  it('throw error if inferred dim is combined with a zero-size dim', () => {
    const sparseTensor = sparseTensorValue2x3x4();
    expect(
        () => tf.sparse.sparseReshape(
            sparseTensor.ind, sparseTensor.shape, [-1, 0]))
        .toThrowError(/unless all specified input sizes are non-zero/);
  });

  // --- Reshaped indices --------------------------------------------------

  it('same shape', async () => {
    const sparseTensor = sparseTensorValue5x6();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [5, 6]);
    expectArraysClose(
        await result.outputIndices.data(), await sparseTensor.ind.data());
    expectArraysClose(await result.outputShape.data(), sparseTensor.shape);
  });

  it('same shape with inferred dim', async () => {
    const sparseTensor = sparseTensorValue5x6();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [-1, 6]);
    expectArraysClose(
        await result.outputIndices.data(), await sparseTensor.ind.data());
    expectArraysClose(await result.outputShape.data(), sparseTensor.shape);
  });

  it('new shape with same rank', async () => {
    const sparseTensor = sparseTensorValue5x6();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [3, 10]);
    expectArraysClose(
        await result.outputIndices.data(),
        [[0, 0], [0, 6], [0, 9], [1, 0], [2, 0], [2, 1]]);
    expectArraysClose(await result.outputShape.data(), [3, 10]);
  });

  it('new shape with same rank with inferred dim', async () => {
    const sparseTensor = sparseTensorValue5x6();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [3, -1]);
    expectArraysClose(
        await result.outputIndices.data(),
        [[0, 0], [0, 6], [0, 9], [1, 0], [2, 0], [2, 1]]);
    expectArraysClose(await result.outputShape.data(), [3, 10]);
  });

  it('up rank', async () => {
    const sparseTensor = sparseTensorValue5x6();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [2, 3, 5]);
    expectArraysClose(
        await result.outputIndices.data(),
        [[0, 0, 0], [0, 1, 1], [0, 1, 4], [0, 2, 0], [1, 1, 0], [1, 1, 1]]);
    expectArraysClose(await result.outputShape.data(), [2, 3, 5]);
  });

  it('up rank with inferred dim', async () => {
    const sparseTensor = sparseTensorValue5x6();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [2, -1, 5]);
    expectArraysClose(
        await result.outputIndices.data(),
        [[0, 0, 0], [0, 1, 1], [0, 1, 4], [0, 2, 0], [1, 1, 0], [1, 1, 1]]);
    expectArraysClose(await result.outputShape.data(), [2, 3, 5]);
  });

  it('down rank', async () => {
    const sparseTensor = sparseTensorValue2x3x4();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [6, 4]);
    expectArraysClose(
        await result.outputIndices.data(),
        [[0, 1], [1, 0], [1, 2], [3, 3], [4, 1], [4, 3], [5, 2]]);
    expectArraysClose(await result.outputShape.data(), [6, 4]);
  });

  it('down rank with inferred dim', async () => {
    const sparseTensor = sparseTensorValue2x3x4();
    const result = tf.sparse.sparseReshape(
        sparseTensor.ind, sparseTensor.shape, [6, -1]);
    expectArraysClose(
        await result.outputIndices.data(),
        [[0, 1], [1, 0], [1, 2], [3, 3], [4, 1], [4, 3], [5, 2]]);
    expectArraysClose(await result.outputShape.data(), [6, 4]);
  });

  it('throw error if mismatch size', () => {
    const sparseTensor = sparseTensorValue5x6();
    // 4 * 7 = 28 does not match the 5 * 6 = 30 dense elements.
    expect(
        () => tf.sparse.sparseReshape(
            sparseTensor.ind, sparseTensor.shape, [4, 7]))
        .toThrowError(/Input to reshape is a tensor with 30 dense values/);
  });
});
144//# sourceMappingURL=data:application/json;base64,
\No newline at end of file