UNPKG

26.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../../index';
18import { ALL_ENVS, describeWithFlags } from '../../jasmine_util';
19import { expectArraysClose } from '../../test_util';
/**
 * Fixture: a 3x4 tensor whose first two rows cancel each other out
 * (row 1 is the negation of row 0), followed by a row of 5..8.
 *
 * @returns {tf.Tensor2D} A fresh tensor on every call.
 */
function TensorValue3x4() {
  const rows = [
    [1, 2, 3, 4],
    [-1, -2, -3, -4],
    [5, 6, 7, 8],
  ];
  return tf.tensor2d(rows);
}
/**
 * Fixture: the 1-D tensor [1, 2, ..., 10], so element `i` equals `i + 1`.
 *
 * @returns {tf.Tensor1D} A fresh tensor on every call.
 */
function TensorValue10() {
  const values = [];
  for (let v = 1; v <= 10; v++) {
    values.push(v);
  }
  return tf.tensor1d(values);
}
/**
 * Fixture: a 10x4 tensor filled with 1..40 in order, so row `i` holds
 * 4 * i + 1 through 4 * i + 4.
 *
 * @returns {tf.Tensor2D} A fresh tensor on every call.
 */
function TensorValue10x4() {
  const shape = [10, 4];
  const count = shape[0] * shape[1];
  const values = Array.from({ length: count }, (_, idx) => idx + 1);
  return tf.tensor2d(values, shape);
}
/**
 * Fixture: a 10x2x4 tensor filled with 1..80 in order, so outer slice `i`
 * holds 8 * i + 1 through 8 * i + 8.
 *
 * @returns {tf.Tensor3D} A fresh tensor on every call.
 */
function TensorValue10x2x4() {
  const dims = [10, 2, 4];
  const total = dims.reduce((acc, d) => acc * d, 1);
  const ascending = Array.from({ length: total }, (_, k) => k + 1);
  return tf.tensor3d(ascending, dims);
}
/**
 * Specs for tf.sparse.sparseSegmentSum(data, indices, segmentIds): rows of
 * `data` selected by `indices` are summed into the output row named by the
 * matching entry of `segmentIds`. Segments that receive no rows are zero.
 *
 * `result.data()` only exposes element values, not the tensor's shape, so
 * every value-checking spec also pins `result.shape`. As the expected values
 * below show, the output has (last segment id + 1) rows and keeps the
 * remaining dimensions of `data`.
 */
describeWithFlags('sparseSegmentSum', ALL_ENVS, () => {
  it('two rows one segment', async () => {
    // Row 0 and its negation (row 1) cancel out.
    const result =
        tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1], [0, 0]);
    expect(result.shape).toEqual([1, 4]);
    expectArraysClose(await result.data(), [[0, 0, 0, 0]]);
  });

  it('two rows two segments', async () => {
    const result =
        tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1], [0, 1]);
    expect(result.shape).toEqual([2, 4]);
    expectArraysClose(await result.data(), [[1, 2, 3, 4], [-1, -2, -3, -4]]);
  });

  // NOTE: despite the name, segmentIds [0, 0, 1] produce two segments. The
  // name is kept so existing spec filters keep matching.
  it('all rows one segment', async () => {
    const result =
        tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1, 2], [0, 0, 1]);
    expect(result.shape).toEqual([2, 4]);
    expectArraysClose(await result.data(), [[0, 0, 0, 0], [5, 6, 7, 8]]);
  });

  it('0 dimensional input invalid', () => {
    expect(() => tf.sparse.sparseSegmentSum(tf.scalar(1), [], []))
        .toThrowError(/should be at least 1 dimensional/);
  });

  it('1 dimensional input', async () => {
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10(), [8, 3, 0, 9], [0, 1, 2, 2]);
    expect(result.shape).toEqual([3]);
    expectArraysClose(await result.data(), [9, 4, 11]);
  });

  it('3 dimensional input', async () => {
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x2x4(), [8, 3, 0, 9], [0, 1, 2, 2]);
    expect(result.shape).toEqual([3, 2, 4]);
    expectArraysClose(await result.data(), [
      [[65, 66, 67, 68], [69, 70, 71, 72]],
      [[25, 26, 27, 28], [29, 30, 31, 32]],
      [[74, 76, 78, 80], [82, 84, 86, 88]],
    ]);
  });

  it('segment ids hole', async () => {
    // Segments 1 and 2 receive no rows and must come back as zeros.
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x4(), [8, 3, 0, 9], [0, 3, 3, 3]);
    expect(result.shape).toEqual([4, 4]);
    expectArraysClose(
        await result.data(),
        [[33, 34, 35, 36], [0, 0, 0, 0], [0, 0, 0, 0], [51, 54, 57, 60]]);
  });

  it('segment ids > zero', async () => {
    // Leading segments 0 and 1 are never referenced and must be zeros.
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x4(), [8, 3, 0, 9], [2, 3, 3, 3]);
    expect(result.shape).toEqual([4, 4]);
    expectArraysClose(
        await result.data(),
        [[0, 0, 0, 0], [0, 0, 0, 0], [33, 34, 35, 36], [51, 54, 57, 60]]);
  });

  it('baseline valid', async () => {
    // Baseline for the *invalid* tests below.
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x4(), [8, 3, 0, 9], [0, 1, 2, 2]);
    expect(result.shape).toEqual([3, 4]);
    expectArraysClose(
        await result.data(),
        [[33, 34, 35, 36], [13, 14, 15, 16], [38, 40, 42, 44]]);
  });

  it('does not have memory leak.', async () => {
    const beforeDataIds = tf.engine().backend.numDataIds();

    const data = TensorValue3x4();
    const indices = tf.tensor1d([0, 1], 'int32');
    const segmentIds = tf.tensor1d([0, 0], 'int32');
    const result = tf.sparse.sparseSegmentSum(data, indices, segmentIds);

    await result.data();

    // Exactly the four tensors created above (three inputs plus the result)
    // may be alive. Any extra data id is an intermediate the op failed to
    // release.
    const afterResDataIds = tf.engine().backend.numDataIds();
    expect(afterResDataIds).toEqual(beforeDataIds + 4);

    data.dispose();
    indices.dispose();
    segmentIds.dispose();
    result.dispose();

    const afterDisposeDataIds = tf.engine().backend.numDataIds();
    expect(afterDisposeDataIds).toEqual(beforeDataIds);
  });

  it('indices invalid 1', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, -1, 0, 9], [0, 1, 2, 2]))
        .toThrowError(/indices\[1\] == -1 out of range \[0, 10\)/);
  });

  it('indices invalid 2', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 10], [0, 1, 2, 2]))
        .toThrowError(/indices\[3\] == 10 out of range \[0, 10\)/);
  });

  it('segments invalid 2', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [0, 1, 0, 1]))
        .toThrowError('segment ids are not increasing');
  });

  it('segments invalid 3', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [0, 1, 2, 0]))
        .toThrowError(
            'Segment id 1 out of range [0, 1), possibly because segmentIds ' +
            'input is not sorted.');
  });

  it('segments invalid 4', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [-1, 0, 1, 1]))
        .toThrowError(
            'Segment id -1 out of range [0, 2), possibly because segmentIds ' +
            'input is not sorted.');
  });

  it('segments invalid 6', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [0, 0, 0, -1]))
        .toThrowError('segment ids must be >= 0');
  });

  it('segments invalid 7', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [0, 0, 0, -2]))
        .toThrowError('segment ids must be >= 0');
  });

  it('indices invalid rank', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [[8, 3, 0, 9]], [0, 0, 0, -2]))
        .toThrowError(/should be Tensor1D/);
  });

  it('segments invalid rank', () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [[0, 0, 0, -2]]))
        .toThrowError(/should be Tensor1D/);
  });
});
135//# sourceMappingURL=data:application/json;base64,
\No newline at end of file