UNPKG

26.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';
import { PARALLELIZE_THRESHOLD } from './reduce_util';
/**
 * Test suite for `tf.unsortedSegmentSum`, run against every environment in
 * `ALL_ENVS`.
 *
 * Covers the forward pass for rank-1/2/3 inputs, an input larger than
 * `PARALLELIZE_THRESHOLD`, negative segment ids, the gradient for each rank,
 * and tensor-like (plain array) arguments in both the functional and the
 * chained form.
 */
describeWithFlags('unsortedSegmentSum', ALL_ENVS, () => {
    it('tensor1D', async () => {
        const t = tf.tensor1d([1, 2, 3, 4]);
        const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');
        const numSegments = 3;
        const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);
        expect(res.shape).toEqual([numSegments]);
        // Segment 0 <- 1 + 3, segment 1 <- 4, segment 2 <- 2.
        expectArraysClose(await res.data(), [4, 4, 2]);
    });

    it('tensor2D', async () => {
        const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);
        const segmentIds = tf.tensor1d([0, 0], 'int32');
        const numSegments = 2;
        const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);
        expect(res.shape).toEqual([numSegments, 2]);
        // Both rows land in segment 0; segment 1 receives nothing and stays 0.
        expectArraysClose(await res.data(), [4, 6, 0, 0]);
    });

    it('tensor3D', async () => {
        const t = tf.tensor3d(
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);
        const segmentIds = tf.tensor1d([2, 1, 2], 'int32');
        const numSegments = 3;
        const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);
        expect(res.shape).toEqual([numSegments, 2, 2]);
        // Segment 0 is empty, segment 1 is slice 1, segment 2 is slice 0 + 2.
        expectArraysClose(
            await res.data(), [0, 0, 0, 0, 5, 6, 7, 8, 10, 12, 14, 16]);
    });

    it('N > than parallelization threshold, tensor1D', async () => {
        // NOTE(review): PARALLELIZE_THRESHOLD comes from ./reduce_util;
        // presumably inputs above it take a different reduction path -
        // confirm against that module.
        const n = PARALLELIZE_THRESHOLD * 2;
        const values = new Float32Array(n);
        const numSegments = 5;
        // Segment ids are integers and the tensor below is declared 'int32',
        // so keep them in an Int32Array rather than a Float32Array.
        const segmentIdValues = new Int32Array(n);
        // Expected per-segment sums, accumulated alongside the inputs.
        const vals = new Float32Array(numSegments);
        for (let i = 0; i < n; i++) {
            values[i] = i;
            segmentIdValues[i] = i % numSegments;
            vals[i % numSegments] += i;
        }
        const t = tf.tensor1d(values);
        const segmentIds = tf.tensor1d(segmentIdValues, 'int32');
        const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);
        expect(res.shape).toEqual([numSegments]);
        expectArraysClose(await res.data(), vals);
    });

    it('ignores negative segmentIds', async () => {
        const t = tf.tensor1d([1, 2, 3, 4]);
        const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');
        const numSegments = 3;
        const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);
        expect(res.shape).toEqual([numSegments]);
        // The value 3 carries id -1 and must not contribute to any segment.
        expectArraysClose(await res.data(), [1, 4, 2]);
    });

    it('gradient ignores negative segmentIds', async () => {
        const t = tf.tensor1d([1, 2, 3, 4]);
        const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');
        const numSegments = 3;
        const dy = tf.tensor1d([11, 2, 7]);
        const gradient = tf.grad(
            a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);
        expect(gradient.shape).toEqual(t.shape);
        // Each input receives dy[segmentId]; the id -1 entry receives 0.
        expectArraysClose(await gradient.data(), [11, 7, 0, 2]);
    });

    it('tensor1D gradient', async () => {
        const t = tf.tensor1d([1, 2, 3, 4]);
        const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');
        const numSegments = 3;
        const dy = tf.tensor1d([11, 2, 7]);
        const gradient = tf.grad(
            a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);
        expect(gradient.shape).toEqual(t.shape);
        // Each input receives dy[segmentId]: ids [0, 2, 0, 1].
        expectArraysClose(await gradient.data(), [11, 7, 11, 2]);
    });

    it('gradient with clones', async () => {
        const t = tf.tensor1d([1, 2, 3, 4]);
        const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');
        const numSegments = 3;
        const dy = tf.tensor1d([11, 2, 7]);
        // Cloning inputs and output must not change the gradient.
        const gradient = tf.grad(
            a => tf.unsortedSegmentSum(
                       a.clone(), segmentIds.clone(), numSegments)
                     .clone())(t, dy);
        expect(gradient.shape).toEqual(t.shape);
        expectArraysClose(await gradient.data(), [11, 7, 11, 2]);
    });

    it('tensor2D gradient', async () => {
        const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);
        const segmentIds = tf.tensor1d([0, 0], 'int32');
        const numSegments = 2;
        const dy = tf.tensor2d([11, 2, 4, 5], [2, 2]);
        const gradient = tf.grad(
            a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);
        expect(gradient.shape).toEqual(t.shape);
        // Both rows map to segment 0, so both receive dy row 0.
        expectArraysClose(await gradient.data(), [11, 2, 11, 2]);
    });

    it('tensor3D gradient', async () => {
        const t = tf.tensor3d(
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);
        const segmentIds = tf.tensor1d([2, 1, 2], 'int32');
        const numSegments = 3;
        const dy = tf.tensor3d(
            [11, 2, 4, 5, 17, 31, 1, 0, -1, 14, 3, 28], [3, 2, 2]);
        const gradient = tf.grad(
            a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);
        expect(gradient.shape).toEqual(t.shape);
        // Slices 0 and 2 receive dy slice 2; slice 1 receives dy slice 1.
        expectArraysClose(
            await gradient.data(),
            [-1, 14, 3, 28, 17, 31, 1, 0, -1, 14, 3, 28]);
    });

    it('accepts a tensor-like object', async () => {
        // Plain arrays are passed for both the data and the segment ids.
        const x = [1, 2, 3, 4];
        const segmentIds = [0, 2, 0, 1];
        const numSegments = 3;
        const res = tf.unsortedSegmentSum(x, segmentIds, numSegments);
        expect(res.shape).toEqual([numSegments]);
        expectArraysClose(await res.data(), [4, 4, 2]);
    });

    it('accepts a tensor-like object chained', async () => {
        // Chained form: the method is called on the tensor, ids are an array.
        const x = tf.tensor1d([1, 2, 3, 4]);
        const segmentIds = [0, 2, 0, 1];
        const numSegments = 3;
        const res = x.unsortedSegmentSum(segmentIds, numSegments);
        expect(res.shape).toEqual([numSegments]);
        expectArraysClose(await res.data(), [4, 4, 2]);
    });
});
134//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"unsorted_segment_sum_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/unsorted_segment_sum_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAC,MAAM,cAAc,CAAC;AAC/C,OAAO,EAAC,qBAAqB,EAAC,MAAM,eAAe,CAAC;AAEpD,iBAAiB,CAAC,oBAAoB,EAAE,QAAQ,EAAE,GAAG,EAAE;IACrD,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACtD,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;QACzC,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAChD,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5C,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1E,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACnD,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,k
BAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC/C,iBAAiB,CACb,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,8CAA8C,EAAE,KAAK,IAAI,EAAE;QAC5D,MAAM,CAAC,GAAG,qBAAqB,GAAG,CAAC,CAAC;QACpC,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,eAAe,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,WAAW,CAAC,CAAC;QAC3C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE;YAC1B,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;YACd,eAAe,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,WAAW,CAAC;YACrC,IAAI,CAAC,CAAC,GAAG,WAAW,CAAC,IAAI,CAAC,CAAC;SAC5B;QACD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;QAC9B,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,eAAe,EAAE,OAAO,CAAC,CAAC;QACzD,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;QACzC,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,IAAI,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,6BAA6B,EAAE,KAAK,IAAI,EAAE;QAC3C,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACvD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;QACzC,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACvD,MAAM
,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACtD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACtD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,QAAQ,GAAG,EAAE,CAAC,IAAI,CACpB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,CAAC,KAAK,EAAE,EAAE,UAAU,CAAC,KAAK,EAAE,EAAE,WAAW,CAAC;aAC5D,KAAK,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE9B,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CA
AC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAChD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1E,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACnD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GACJ,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACvE,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CACb,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;IAC3E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,8BAA8B,EAAE,KAAK,IAAI,EAAE;QAC5C,MAAM,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QACvB,MAAM,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAChC,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE
,WAAW,CAAC,CAAC;QAC9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/B,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAChC,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,CAAC,CAAC,kBAAkB,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;QAE1D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/B,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysClose} from '../test_util';\nimport {PARALLELIZE_THRESHOLD} from './reduce_util';\n\ndescribeWithFlags('unsortedSegmentSum', ALL_ENVS, () => {\n  it('tensor1D', async () => {\n    const t = tf.tensor1d([1, 2, 3, 4]);\n    const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');\n    const numSegments = 3;\n    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n    
expect(res.shape).toEqual([numSegments]);\n    expectArraysClose(await res.data(), [4, 4, 2]);\n  });\n\n  it('tensor2D', async () => {\n    const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n    const segmentIds = tf.tensor1d([0, 0], 'int32');\n    const numSegments = 2;\n    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n    expect(res.shape).toEqual([numSegments, 2]);\n    expectArraysClose(await res.data(), [4, 6, 0, 0]);\n  });\n\n  it('tensor3D', async () => {\n    const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);\n    const segmentIds = tf.tensor1d([2, 1, 2], 'int32');\n    const numSegments = 3;\n    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n    expect(res.shape).toEqual([numSegments, 2, 2]);\n    expectArraysClose(\n        await res.data(), [0, 0, 0, 0, 5, 6, 7, 8, 10, 12, 14, 16]);\n  });\n\n  it('N > than parallelization threshold, tensor1D', async () => {\n    const n = PARALLELIZE_THRESHOLD * 2;\n    const values = new Float32Array(n);\n    const numSegments = 5;\n    const segmentIdValues = new Float32Array(n);\n    const vals = new Float32Array(numSegments);\n    for (let i = 0; i < n; i++) {\n      values[i] = i;\n      segmentIdValues[i] = i % numSegments;\n      vals[i % numSegments] += i;\n    }\n    const t = tf.tensor1d(values);\n    const segmentIds = tf.tensor1d(segmentIdValues, 'int32');\n    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n    expect(res.shape).toEqual([numSegments]);\n    expectArraysClose(await res.data(), vals);\n  });\n\n  it('ignores negative segmentIds', async () => {\n    const t = tf.tensor1d([1, 2, 3, 4]);\n    const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');\n    const numSegments = 3;\n\n    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n    expect(res.shape).toEqual([numSegments]);\n    expectArraysClose(await res.data(), [1, 4, 2]);\n  });\n\n  it('gradient ignores negative segmentIds', async () => {\n    
const t = tf.tensor1d([1, 2, 3, 4]);\n    const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');\n    const numSegments = 3;\n\n    const dy = tf.tensor1d([11, 2, 7]);\n    const gradient =\n        tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n    expect(gradient.shape).toEqual(t.shape);\n    expectArraysClose(await gradient.data(), [11, 7, 0, 2]);\n  });\n\n  it('tensor1D gradient', async () => {\n    const t = tf.tensor1d([1, 2, 3, 4]);\n    const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');\n    const numSegments = 3;\n\n    const dy = tf.tensor1d([11, 2, 7]);\n    const gradient =\n        tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n    expect(gradient.shape).toEqual(t.shape);\n    expectArraysClose(await gradient.data(), [11, 7, 11, 2]);\n  });\n\n  it('gradient with clones', async () => {\n    const t = tf.tensor1d([1, 2, 3, 4]);\n    const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');\n    const numSegments = 3;\n\n    const dy = tf.tensor1d([11, 2, 7]);\n    const gradient = tf.grad(\n        a => tf.unsortedSegmentSum(a.clone(), segmentIds.clone(), numSegments)\n                 .clone())(t, dy);\n\n    expect(gradient.shape).toEqual(t.shape);\n    expectArraysClose(await gradient.data(), [11, 7, 11, 2]);\n  });\n\n  it('tensor2D gradient', async () => {\n    const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n    const segmentIds = tf.tensor1d([0, 0], 'int32');\n    const numSegments = 2;\n\n    const dy = tf.tensor2d([11, 2, 4, 5], [2, 2]);\n    const gradient =\n        tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n    expect(gradient.shape).toEqual(t.shape);\n    expectArraysClose(await gradient.data(), [11, 2, 11, 2]);\n  });\n\n  it('tensor3D gradient', async () => {\n    const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);\n    const segmentIds = tf.tensor1d([2, 1, 2], 'int32');\n    const numSegments = 3;\n\n    const dy =\n        
tf.tensor3d([11, 2, 4, 5, 17, 31, 1, 0, -1, 14, 3, 28], [3, 2, 2]);\n    const gradient =\n        tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n    expect(gradient.shape).toEqual(t.shape);\n    expectArraysClose(\n        await gradient.data(), [-1, 14, 3, 28, 17, 31, 1, 0, -1, 14, 3, 28]);\n  });\n\n  it('accepts a tensor-like object', async () => {\n    const x = [1, 2, 3, 4];\n    const segmentIds = [0, 2, 0, 1];\n    const numSegments = 3;\n    const res = tf.unsortedSegmentSum(x, segmentIds, numSegments);\n    expect(res.shape).toEqual([3]);\n    expectArraysClose(await res.data(), [4, 4, 2]);\n  });\n\n  it('accepts a tensor-like object chained', async () => {\n    const x = tf.tensor1d([1, 2, 3, 4]);\n    const segmentIds = [0, 2, 0, 1];\n    const numSegments = 3;\n    const res = x.unsortedSegmentSum(segmentIds, numSegments);\n\n    expect(res.shape).toEqual([3]);\n    expectArraysClose(await res.data(), [4, 4, 2]);\n  });\n});\n"]}
\No newline at end of file