UNPKG

26.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../../index';
18import { ALL_ENVS, describeWithFlags } from '../../jasmine_util';
19import { expectArraysClose } from '../../test_util';
/**
 * Builds the small 3x4 fixture used by the first sparseSegmentSum cases.
 * Row 1 is the element-wise negation of row 0, so summing rows 0 and 1 into
 * one segment yields all zeros.
 *
 * @returns {tf.Tensor2D} [[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]]
 */
function TensorValue3x4() {
  return tf.tensor2d([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]]);
}
/**
 * Builds a 1-D fixture holding the values 1..10.
 *
 * @returns {tf.Tensor1D} [1, 2, ..., 10]
 */
function TensorValue10() {
  return tf.tensor1d(Array.from({length: 10}, (_, i) => i + 1));
}
/**
 * Builds a 10x4 fixture holding the values 1..40 in row-major order, so row
 * `r` is [4r + 1, 4r + 2, 4r + 3, 4r + 4].
 *
 * @returns {tf.Tensor2D} values 1..40 reshaped to [10, 4]
 */
function TensorValue10x4() {
  return tf.tensor2d(Array.from({length: 40}, (_, i) => i + 1), [10, 4]);
}
/**
 * Builds a 10x2x4 fixture holding the values 1..80 in row-major order, so
 * outer row `r` covers the values 8r + 1 .. 8r + 8.
 *
 * @returns {tf.Tensor3D} values 1..80 reshaped to [10, 2, 4]
 */
function TensorValue10x2x4() {
  return tf.tensor3d(Array.from({length: 80}, (_, i) => i + 1), [10, 2, 4]);
}
/**
 * Tests for tf.sparse.sparseSegmentSum(data, indices, segmentIds).
 *
 * Each case gathers the rows of `data` selected by `indices` and sums rows
 * that share the same entry in `segmentIds`. The expected outputs below
 * encode that segment ids receiving no rows (holes, or ids below the first
 * one used) yield all-zero rows. The invalid-input cases pin the exact error
 * messages thrown for bad ranks, out-of-range indices and unsorted or
 * negative segment ids.
 */
describeWithFlags('sparseSegmentSum', ALL_ENVS, () => {
  it('two rows one segment', async () => {
    // Rows 0 and 1 of the 3x4 fixture are negations of each other.
    const result = tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1], [0, 0]);
    expectArraysClose(await result.data(), [[0, 0, 0, 0]]);
  });

  it('two rows two segments', async () => {
    const result = tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1], [0, 1]);
    expectArraysClose(await result.data(), [[1, 2, 3, 4], [-1, -2, -3, -4]]);
  });

  it('all rows one segment', async () => {
    const result =
        tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1, 2], [0, 0, 1]);
    expectArraysClose(await result.data(), [[0, 0, 0, 0], [5, 6, 7, 8]]);
  });

  it('0 dimensional input invalid', async () => {
    expect(() => tf.sparse.sparseSegmentSum(tf.scalar(1), [], []))
        .toThrowError(/should be at least 1 dimensional/);
  });

  it('1 dimensional input', async () => {
    // Selected values are 9, 4, 1, 10; the last two share segment 2.
    const result =
        tf.sparse.sparseSegmentSum(TensorValue10(), [8, 3, 0, 9], [0, 1, 2, 2]);
    expectArraysClose(await result.data(), [9, 4, 11]);
  });

  it('3 dimensional input', async () => {
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x2x4(), [8, 3, 0, 9], [0, 1, 2, 2]);
    expectArraysClose(await result.data(), [
      [[65, 66, 67, 68], [69, 70, 71, 72]],
      [[25, 26, 27, 28], [29, 30, 31, 32]], [[74, 76, 78, 80], [82, 84, 86, 88]]
    ]);
  });

  it('segment ids hole', async () => {
    // Segments 1 and 2 receive no rows and are expected to be zero-filled.
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x4(), [8, 3, 0, 9], [0, 3, 3, 3]);
    expectArraysClose(
        await result.data(),
        [[33, 34, 35, 36], [0, 0, 0, 0], [0, 0, 0, 0], [51, 54, 57, 60]]);
  });

  it('segment ids > zero', async () => {
    // Segments 0 and 1 precede the first used id and are expected to be zero.
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x4(), [8, 3, 0, 9], [2, 3, 3, 3]);
    expectArraysClose(
        await result.data(),
        [[0, 0, 0, 0], [0, 0, 0, 0], [33, 34, 35, 36], [51, 54, 57, 60]]);
  });

  it('baseline valid', async () => {
    // Baseline for the *invalid* tests below.
    const result = tf.sparse.sparseSegmentSum(
        TensorValue10x4(), [8, 3, 0, 9], [0, 1, 2, 2]);
    expectArraysClose(
        await result.data(),
        [[33, 34, 35, 36], [13, 14, 15, 16], [38, 40, 42, 44]]);
  });

  it('does not have memory leak.', async () => {
    const beforeDataIds = tf.engine().backend.numDataIds();

    const data = TensorValue3x4();
    const indices = tf.tensor1d([0, 1], 'int32');
    const segmentIds = tf.tensor1d([0, 0], 'int32');
    const result = tf.sparse.sparseSegmentSum(data, indices, segmentIds);

    await result.data();

    // Exactly one data id per tensor created above (data, indices,
    // segmentIds, result): the op must not leave intermediates alive.
    const afterResDataIds = tf.engine().backend.numDataIds();
    expect(afterResDataIds).toEqual(beforeDataIds + 4);

    data.dispose();
    indices.dispose();
    segmentIds.dispose();
    result.dispose();

    const afterDisposeDataIds = tf.engine().backend.numDataIds();
    expect(afterDisposeDataIds).toEqual(beforeDataIds);
  });

  it('indices invalid 1', async () => {
    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, -1, 0, 9], [
      0, 1, 2, 2
    ])).toThrowError(/indices\[1\] == -1 out of range \[0, 10\)/);
  });

  it('indices invalid 2', async () => {
    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 10], [
      0, 1, 2, 2
    ])).toThrowError(/indices\[3\] == 10 out of range \[0, 10\)/);
  });

  it('segments invalid 2', async () => {
    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [
      0, 1, 0, 1
    ])).toThrowError('segment ids are not increasing');
  });

  it('segments invalid 3', async () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [0, 1, 2, 0]))
        .toThrowError(
            'Segment id 1 out of range [0, 1), possibly because segmentIds ' +
            'input is not sorted.');
  });

  it('segments invalid 4', async () => {
    expect(
        () => tf.sparse.sparseSegmentSum(
            TensorValue10x4(), [8, 3, 0, 9], [-1, 0, 1, 1]))
        .toThrowError(
            'Segment id -1 out of range [0, 2), possibly because segmentIds ' +
            'input is not sorted.');
  });

  it('segments invalid 6', async () => {
    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [
      0, 0, 0, -1
    ])).toThrowError('segment ids must be >= 0');
  });

  it('segments invalid 7', async () => {
    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [
      0, 0, 0, -2
    ])).toThrowError('segment ids must be >= 0');
  });

  it('indices invalid rank', async () => {
    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [[8, 3, 0, 9]], [
      0, 0, 0, -2
    ])).toThrowError(/should be Tensor1D/);
  });

  it('segments invalid rank', async () => {
    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [
      [0, 0, 0, -2]
    ])).toThrowError(/should be Tensor1D/);
  });
});
135//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"sparse_segment_sum_test.js","sourceRoot":"","sources":["../../../../../../../tfjs-core/src/ops/sparse/sparse_segment_sum_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,aAAa,CAAC;AAClC,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,oBAAoB,CAAC;AAC/D,OAAO,EAAC,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAElD,SAAS,cAAc;IACrB,OAAO,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;AACrE,CAAC;AAED,SAAS,aAAa;IACpB,OAAO,EAAE,CAAC,QAAQ,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;AAC7D,CAAC;AAED,SAAS,eAAe;IACtB,OAAO,EAAE,CAAC,QAAQ,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;AACtE,CAAC;AAED,SAAS,iBAAiB;IACxB,OAAO,EAAE,CAAC,QAAQ,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;AACzE,CAAC;AAED,iBAAiB,CAAC,kBAAkB,EAAE,QAAQ,EAAE,GAAG,EAAE;IACnD,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,cAAc,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5E,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,uBAAuB,EAAE,KAAK,IAAI,EAAE;QACrC,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,cAAc,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5E,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC3E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,MAAM,G
ACR,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,cAAc,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACvE,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IACvE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,6BAA6B,EAAE,KAAK,IAAI,EAAE;QAC3C,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC;aACzD,YAAY,CAAC,kCAAkC,CAAC,CAAC;IACxD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,MAAM,MAAM,GACR,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,aAAa,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5E,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,gBAAgB,CACrC,iBAAiB,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrD,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE;YACrC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC;YACpC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC;SAC3E,CAAC,CAAC;IACL,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,kBAAkB,EAAE,KAAK,IAAI,EAAE;QAChC,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,gBAAgB,CACrC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnD,iBAAiB,CACb,MAAM,MAAM,CAAC,IAAI,EAAE,EACnB,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IACxE,CAAC,CAAC,CA
AC;IAEH,EAAE,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,gBAAgB,CACrC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnD,iBAAiB,CACb,MAAM,MAAM,CAAC,IAAI,EAAE,EACnB,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IACxE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,gBAAgB,EAAE,KAAK,IAAI,EAAE;QAC9B,0CAA0C;QAC1C,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,gBAAgB,CACrC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnD,iBAAiB,CACb,MAAM,MAAM,CAAC,IAAI,EAAE,EACnB,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAC9D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,4BAA4B,EAAE,KAAK,IAAI,EAAE;QAC1C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,OAAO,CAAC,UAAU,EAAE,CAAC;QAEvD,MAAM,IAAI,GAAG,cAAc,EAAE,CAAC;QAC9B,MAAM,OAAO,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAC7C,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAChD,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,IAAI,EAAE,OAAO,EAAE,UAAU,CAAC,CAAC;QAErE,MAAM,MAAM,CAAC,IAAI,EAAE,CAAC;QAEpB,MAAM,eAAe,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,OAAO,CAAC,UAAU,EAAE,CAAC;QACzD,MAAM,CAAC,eAAe,CAAC,CAAC,OAAO,CAAC,aAAa,GAAG,CAAC,CAAC,CAAC;QAEnD,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,OAAO,CAAC,OAAO,EAAE,CAAC;QAClB,UAAU,CAAC,OAAO,EAAE,CAAC;QACrB,MAAM,CAAC,OAAO,EAAE,CAAC;QAEjB,MAAM,mBAAmB,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,OAAO,CAAC,UAAU,EAAE,CAAC;QAC7D,MAAM,CAAC,mBAAmB,CAAC,CAAC,OAAO,CAAC,aAAa,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE;
YACxE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;SACX,CAAC,CAAC,CAAC,YAAY,CAAC,2CAA2C,CAAC,CAAC;IAChE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE;YACxE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;SACX,CAAC,CAAC,CAAC,YAAY,CAAC,2CAA2C,CAAC,CAAC;IAChE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE;YACvE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;SACX,CAAC,CAAC,CAAC,YAAY,CAAC,gCAAgC,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAC5B,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;aAClD,YAAY,CACT,gEAAgE;YAChE,sBAAsB,CAAC,CAAC;IAClC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAC5B,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;aACnD,YAAY,CACT,iEAAiE;YACjE,sBAAsB,CAAC,CAAC;IAClC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE;YACvE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACZ,CAAC,CAAC,CAAC,YAAY,CAAC,0BAA0B,CAAC,CAAC;IAC/C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE;YACvE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACZ,CAAC,CAAC,CAAC,YAAY,CAAC,0BAA0B,CAAC,CAAC;IAC/C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,eAAe,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC
,EAAE;YACzE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACZ,CAAC,CAAC,CAAC,YAAY,CAAC,oBAAoB,CAAC,CAAC;IACzC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,uBAAuB,EAAE,KAAK,IAAI,EAAE;QACrC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,gBAAgB,CAAC,eAAe,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE;YACvE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;SACd,CAAC,CAAC,CAAC,YAAY,CAAC,oBAAoB,CAAC,CAAC;IACzC,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../../index';\nimport {ALL_ENVS, describeWithFlags} from '../../jasmine_util';\nimport {expectArraysClose} from '../../test_util';\n\nfunction TensorValue3x4() {\n  return tf.tensor2d([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]]);\n}\n\nfunction TensorValue10() {\n  return tf.tensor1d(Array.from(Array(10), (_, i) => i + 1));\n}\n\nfunction TensorValue10x4() {\n  return tf.tensor2d(Array.from(Array(40), (_, i) => i + 1), [10, 4]);\n}\n\nfunction TensorValue10x2x4() {\n  return tf.tensor3d(Array.from(Array(80), (_, i) => i + 1), [10, 2, 4]);\n}\n\ndescribeWithFlags('sparseSegmentSum', ALL_ENVS, () => {\n  it('two rows one segment', async () => {\n    const result = tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1], [0, 0]);\n    expectArraysClose(await result.data(), [[0, 0, 0, 0]]);\n  
});\n\n  it('two rows two segments', async () => {\n    const result = tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1], [0, 1]);\n    expectArraysClose(await result.data(), [[1, 2, 3, 4], [-1, -2, -3, -4]]);\n  });\n\n  it('all rows one segment', async () => {\n    const result =\n        tf.sparse.sparseSegmentSum(TensorValue3x4(), [0, 1, 2], [0, 0, 1]);\n    expectArraysClose(await result.data(), [[0, 0, 0, 0], [5, 6, 7, 8]]);\n  });\n\n  it('0 dimensional input invalid', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(tf.scalar(1), [], []))\n        .toThrowError(/should be at least 1 dimensional/);\n  });\n\n  it('1 dimensional input', async () => {\n    const result =\n        tf.sparse.sparseSegmentSum(TensorValue10(), [8, 3, 0, 9], [0, 1, 2, 2]);\n    expectArraysClose(await result.data(), [9, 4, 11]);\n  });\n\n  it('3 dimensional input', async () => {\n    const result = tf.sparse.sparseSegmentSum(\n        TensorValue10x2x4(), [8, 3, 0, 9], [0, 1, 2, 2]);\n    expectArraysClose(await result.data(), [\n      [[65, 66, 67, 68], [69, 70, 71, 72]],\n      [[25, 26, 27, 28], [29, 30, 31, 32]], [[74, 76, 78, 80], [82, 84, 86, 88]]\n    ]);\n  });\n\n  it('segment ids hole', async () => {\n    const result = tf.sparse.sparseSegmentSum(\n        TensorValue10x4(), [8, 3, 0, 9], [0, 3, 3, 3]);\n    expectArraysClose(\n        await result.data(),\n        [[33, 34, 35, 36], [0, 0, 0, 0], [0, 0, 0, 0], [51, 54, 57, 60]]);\n  });\n\n  it('segment ids > zero', async () => {\n    const result = tf.sparse.sparseSegmentSum(\n        TensorValue10x4(), [8, 3, 0, 9], [2, 3, 3, 3]);\n    expectArraysClose(\n        await result.data(),\n        [[0, 0, 0, 0], [0, 0, 0, 0], [33, 34, 35, 36], [51, 54, 57, 60]]);\n  });\n\n  it('baseline valid', async () => {\n    // Baseline for the *invalid* tests below.\n    const result = tf.sparse.sparseSegmentSum(\n        TensorValue10x4(), [8, 3, 0, 9], [0, 1, 2, 2]);\n    expectArraysClose(\n        await 
result.data(),\n        [[33, 34, 35, 36], [13, 14, 15, 16], [38, 40, 42, 44]]);\n  });\n\n  it('does not have memory leak.', async () => {\n    const beforeDataIds = tf.engine().backend.numDataIds();\n\n    const data = TensorValue3x4();\n    const indices = tf.tensor1d([0, 1], 'int32');\n    const segmentIds = tf.tensor1d([0, 0], 'int32');\n    const result = tf.sparse.sparseSegmentSum(data, indices, segmentIds);\n\n    await result.data();\n\n    const afterResDataIds = tf.engine().backend.numDataIds();\n    expect(afterResDataIds).toEqual(beforeDataIds + 4);\n\n    data.dispose();\n    indices.dispose();\n    segmentIds.dispose();\n    result.dispose();\n\n    const afterDisposeDataIds = tf.engine().backend.numDataIds();\n    expect(afterDisposeDataIds).toEqual(beforeDataIds);\n  });\n\n  it('indices invalid 1', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, -1, 0, 9], [\n      0, 1, 2, 2\n    ])).toThrowError(/indices\\[1\\] == -1 out of range \\[0, 10\\)/);\n  });\n\n  it('indices invalid 2', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 10], [\n      0, 1, 2, 2\n    ])).toThrowError(/indices\\[3\\] == 10 out of range \\[0, 10\\)/);\n  });\n\n  it('segments invalid 2', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [\n      0, 1, 0, 1\n    ])).toThrowError('segment ids are not increasing');\n  });\n\n  it('segments invalid 3', async () => {\n    expect(\n        () => tf.sparse.sparseSegmentSum(\n            TensorValue10x4(), [8, 3, 0, 9], [0, 1, 2, 0]))\n        .toThrowError(\n            'Segment id 1 out of range [0, 1), possibly because segmentIds ' +\n            'input is not sorted.');\n  });\n\n  it('segments invalid 4', async () => {\n    expect(\n        () => tf.sparse.sparseSegmentSum(\n            TensorValue10x4(), [8, 3, 0, 9], [-1, 0, 1, 1]))\n        .toThrowError(\n            'Segment id -1 out of range [0, 2), 
possibly because segmentIds ' +\n            'input is not sorted.');\n  });\n\n  it('segments invalid 6', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [\n      0, 0, 0, -1\n    ])).toThrowError('segment ids must be >= 0');\n  });\n\n  it('segments invalid 7', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [\n      0, 0, 0, -2\n    ])).toThrowError('segment ids must be >= 0');\n  });\n\n  it('indices invalid rank', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [[8, 3, 0, 9]], [\n      0, 0, 0, -2\n    ])).toThrowError(/should be Tensor1D/);\n  });\n\n  it('segments invalid rank', async () => {\n    expect(() => tf.sparse.sparseSegmentSum(TensorValue10x4(), [8, 3, 0, 9], [\n      [0, 0, 0, -2]\n    ])).toThrowError(/should be Tensor1D/);\n  });\n});\n"]}
\No newline at end of file