UNPKG

26.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
20import { PARALLELIZE_THRESHOLD } from './reduce_util';
/**
 * Tests for `tf.unsortedSegmentSum(x, segmentIds, numSegments)`.
 *
 * Behaviour pinned by the expectations below:
 * - Slices of `x` along the first axis are summed into the output slot named
 *   by the matching entry of `segmentIds`; slots that receive nothing are 0.
 * - The output shape is `[numSegments, ...x.shape.slice(1)]`.
 * - Entries whose segment id is negative are dropped from the sum.
 * - The gradient w.r.t. `x` copies `dy[segmentIds[i]]` back into slice `i`,
 *   and is 0 for slices whose segment id is negative.
 * - `x` and `segmentIds` may be passed as plain arrays, and the op is also
 *   available as a Tensor method.
 */
describeWithFlags('unsortedSegmentSum', ALL_ENVS, () => {
  it('tensor1D', async () => {
    const t = tf.tensor1d([1, 2, 3, 4]);
    const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');
    const numSegments = 3;

    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);

    expect(res.shape).toEqual([numSegments]);
    // segment 0 = 1 + 3, segment 1 = 4, segment 2 = 2.
    expectArraysClose(await res.data(), [4, 4, 2]);
  });

  it('tensor2D', async () => {
    const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const segmentIds = tf.tensor1d([0, 0], 'int32');
    const numSegments = 2;

    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);

    expect(res.shape).toEqual([numSegments, 2]);
    // Both rows land in segment 0; segment 1 is empty and stays zero.
    expectArraysClose(await res.data(), [4, 6, 0, 0]);
  });

  it('tensor3D', async () => {
    const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);
    const segmentIds = tf.tensor1d([2, 1, 2], 'int32');
    const numSegments = 3;

    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);

    expect(res.shape).toEqual([numSegments, 2, 2]);
    // segment 0: empty; segment 1: slice 1; segment 2: slice 0 + slice 2.
    expectArraysClose(
        await res.data(), [0, 0, 0, 0, 5, 6, 7, 8, 10, 12, 14, 16]);
  });

  it('N > than parallelization threshold, tensor1D', async () => {
    // Size the input past PARALLELIZE_THRESHOLD (from reduce_util) —
    // presumably so kernels that branch on it take their large-input path;
    // confirm against the backend implementations.
    const n = PARALLELIZE_THRESHOLD * 2;
    const numSegments = 5;
    const values = new Float32Array(n);
    // Integral ids; Int32Array matches the 'int32' dtype requested below.
    const segmentIdValues = new Int32Array(n);
    // Reference result, accumulated alongside the inputs.
    const expected = new Float32Array(numSegments);
    for (let i = 0; i < n; i++) {
      values[i] = i;
      segmentIdValues[i] = i % numSegments;
      expected[i % numSegments] += i;
    }
    const t = tf.tensor1d(values);
    const segmentIds = tf.tensor1d(segmentIdValues, 'int32');

    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);

    expect(res.shape).toEqual([numSegments]);
    expectArraysClose(await res.data(), expected);
  });

  it('ignores negative segmentIds', async () => {
    const t = tf.tensor1d([1, 2, 3, 4]);
    // The value 3 (id -1) must not contribute to any segment.
    const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');
    const numSegments = 3;

    const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);

    expect(res.shape).toEqual([numSegments]);
    expectArraysClose(await res.data(), [1, 4, 2]);
  });

  it('gradient ignores negative segmentIds', async () => {
    const t = tf.tensor1d([1, 2, 3, 4]);
    const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');
    const numSegments = 3;
    const dy = tf.tensor1d([11, 2, 7]);

    const gradient = tf.grad(
        a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);

    expect(gradient.shape).toEqual(t.shape);
    // Each element receives dy[segmentId]; the id -1 element receives 0.
    expectArraysClose(await gradient.data(), [11, 7, 0, 2]);
  });

  it('tensor1D gradient', async () => {
    const t = tf.tensor1d([1, 2, 3, 4]);
    const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');
    const numSegments = 3;
    const dy = tf.tensor1d([11, 2, 7]);

    const gradient = tf.grad(
        a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);

    expect(gradient.shape).toEqual(t.shape);
    // dy gathered by segment id: [dy[0], dy[2], dy[0], dy[1]].
    expectArraysClose(await gradient.data(), [11, 7, 11, 2]);
  });

  it('gradient with clones', async () => {
    const t = tf.tensor1d([1, 2, 3, 4]);
    const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');
    const numSegments = 3;
    const dy = tf.tensor1d([11, 2, 7]);

    // Cloning the input, the ids and the output must not break the gradient.
    const gradient = tf.grad(
        a => tf.unsortedSegmentSum(a.clone(), segmentIds.clone(), numSegments)
                 .clone())(t, dy);

    expect(gradient.shape).toEqual(t.shape);
    expectArraysClose(await gradient.data(), [11, 7, 11, 2]);
  });

  it('tensor2D gradient', async () => {
    const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const segmentIds = tf.tensor1d([0, 0], 'int32');
    const numSegments = 2;
    const dy = tf.tensor2d([11, 2, 4, 5], [2, 2]);

    const gradient = tf.grad(
        a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);

    expect(gradient.shape).toEqual(t.shape);
    // Both rows belong to segment 0, so both receive dy's first row.
    expectArraysClose(await gradient.data(), [11, 2, 11, 2]);
  });

  it('tensor3D gradient', async () => {
    const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);
    const segmentIds = tf.tensor1d([2, 1, 2], 'int32');
    const numSegments = 3;
    const dy =
        tf.tensor3d([11, 2, 4, 5, 17, 31, 1, 0, -1, 14, 3, 28], [3, 2, 2]);

    const gradient = tf.grad(
        a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);

    expect(gradient.shape).toEqual(t.shape);
    // Slices 0 and 2 (segment 2) receive dy[2]; slice 1 (segment 1) gets dy[1].
    expectArraysClose(
        await gradient.data(), [-1, 14, 3, 28, 17, 31, 1, 0, -1, 14, 3, 28]);
  });

  it('accepts a tensor-like object', async () => {
    // Both the values and the segment ids are plain JS arrays.
    const x = [1, 2, 3, 4];
    const segmentIds = [0, 2, 0, 1];
    const numSegments = 3;

    const res = tf.unsortedSegmentSum(x, segmentIds, numSegments);

    expect(res.shape).toEqual([3]);
    expectArraysClose(await res.data(), [4, 4, 2]);
  });

  it('accepts a tensor-like object chained', async () => {
    // Tensor method form, with the segment ids given as a plain array.
    const x = tf.tensor1d([1, 2, 3, 4]);
    const segmentIds = [0, 2, 0, 1];
    const numSegments = 3;

    const res = x.unsortedSegmentSum(segmentIds, numSegments);

    expect(res.shape).toEqual([3]);
    expectArraysClose(await res.data(), [4, 4, 2]);
  });
});
134//# sourceMappingURL=data:application/json;base64,
\No newline at end of file