1 | /**
|
2 | * @license
|
3 | * Copyright 2021 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { ENGINE } from '../../engine';
|
18 | import { SparseSegmentMean } from '../../kernel_names';
|
19 | import { convertToTensor } from '../../tensor_util_env';
|
20 | import { op } from '../operation';
|
/**
 * Computes the mean along sparse segments of a tensor.
 *
 * ```js
 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
 * // Select two rows, one segment.
 * const result1 = tf.sparse.sparseSegmentMean(c,
 *                                             tf.tensor1d([0, 1], 'int32'),
 *                                             tf.tensor1d([0, 0], 'int32'));
 * result1.print(); // [[0, 0, 0, 0]]
 *
 * // Select two rows, two segments.
 * const result2 = tf.sparse.sparseSegmentMean(c,
 *                                             tf.tensor1d([0, 1], 'int32'),
 *                                             tf.tensor1d([0, 1], 'int32'));
 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
 *
 * // Select all rows, two segments.
 * const result3 = tf.sparse.sparseSegmentMean(c,
 *                                             tf.tensor1d([0, 1, 2], 'int32'),
 *                                             tf.tensor1d([0, 1, 1], 'int32'));
 * result3.print(); // [[1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]]
 * ```
 * @param data: A Tensor of at least one dimension with data that will be
 *     assembled in the output.
 * @param indices: A 1-D Tensor with indices into data. Has same rank as
 *     segmentIds.
 * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
 *     should be sorted and can be repeated.
 * @return Has same shape as data, except for dimension 0 which has size equal
 *     to the number of segments.
 * @throws {Error} If `data` is a scalar (rank 0), or if `indices` or
 *     `segmentIds` is not rank 1.
 *
 * @doc {heading: 'Operations', subheading: 'Sparse'}
 */
function sparseSegmentMean_(data, indices, segmentIds) {
  // `indices` and `segmentIds` are converted with an explicit 'int32' dtype;
  // `data` is converted without a dtype argument.
  const $data = convertToTensor(data, 'data', 'sparseSegmentMean');
  const $indices =
      convertToTensor(indices, 'indices', 'sparseSegmentMean', 'int32');
  const $segmentIds =
      convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentMean', 'int32');

  // Validate ranks up front so callers get a readable error before the kernel
  // is dispatched.
  if ($data.rank < 1) {
    throw new Error(
        `Data should be at least 1 dimensional but received scalar`);
  }
  // The messages below are kept on a single line on purpose: a template
  // literal that spans source lines would embed a newline and indentation
  // in the thrown message.
  if ($indices.rank !== 1) {
    throw new Error(
        `Indices should be Tensor1D but received shape ${$indices.shape}`);
  }
  if ($segmentIds.rank !== 1) {
    throw new Error(
        `Segment ids should be Tensor1D but received shape ${
            $segmentIds.shape}`);
  }

  const inputs = {
    data: $data,
    indices: $indices,
    segmentIds: $segmentIds
  };
  return ENGINE.runKernel(SparseSegmentMean, inputs);
}
// NOTE(review): `op` comes from '../operation'; presumably it derives the
// public op name from the object key, so keep the shorthand property as-is —
// confirm before renaming.
export const sparseSegmentMean = op({ sparseSegmentMean_ });
|
78 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoic3BhcnNlX3NlZ21lbnRfbWVhbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL3NwYXJzZS9zcGFyc2Vfc2VnbWVudF9tZWFuLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFDcEMsT0FBTyxFQUFDLGlCQUFpQixFQUEwQixNQUFNLG9CQUFvQixDQUFDO0FBRTlFLE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUV0RCxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBRWhDOzs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0FpQ0c7QUFDSCxTQUFTLGtCQUFrQixDQUN2QixJQUF1QixFQUFFLE9BQTRCLEVBQ3JELFVBQStCO0lBQ2pDLE1BQU0sS0FBSyxHQUFHLGVBQWUsQ0FBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLG1CQUFtQixDQUFDLENBQUM7SUFDakUsTUFBTSxRQUFRLEdBQ1YsZUFBZSxDQUFDLE9BQU8sRUFBRSxTQUFTLEVBQUUsbUJBQW1CLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFDdEUsTUFBTSxXQUFXLEdBQ2IsZUFBZSxDQUFDLFVBQVUsRUFBRSxZQUFZLEVBQUUsbUJBQW1CLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFFNUUsSUFBSSxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsRUFBRTtRQUNsQixNQUFNLElBQUksS0FBSyxDQUNYLDJEQUEyRCxDQUFDLENBQUM7S0FDbEU7SUFDRCxJQUFJLFFBQVEsQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUFFO1FBQ3ZCLE1BQU0sSUFBSSxLQUFLLENBQUM7WUFDUixRQUFRLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQztLQUMzQjtJQUNELElBQUksV0FBVyxDQUFDLElBQUksS0FBSyxDQUFDLEVBQUU7UUFDMUIsTUFBTSxJQUFJLEtBQUssQ0FBQztZQUNSLFdBQVcsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0tBQzlCO0lBRUQsTUFBTSxNQUFNLEdBQTRCO1FBQ3RDLElBQUksRUFBRSxLQUFLO1FBQ1gsT0FBTyxFQUFFLFFBQVE7UUFDakIsVUFBVSxFQUFFLFdBQVc7S0FDeEIsQ0FBQztJQUVGLE9BQU8sTUFBTSxDQUFDLFNBQVMsQ0FBQyxpQkFBaUIsRUFBRSxNQUFZLENBQUMsQ0FBQztBQUMzRCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0saUJBQWlCLEdBQUcsRUFBRSxDQUFDLEVBQUMsa0JBQWtCLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjEgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9id
GFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0VOR0lORX0gZnJvbSAnLi4vLi4vZW5naW5lJztcbmltcG9ydCB7U3BhcnNlU2VnbWVudE1lYW4sIFNwYXJzZVNlZ21lbnRNZWFuSW5wdXRzfSBmcm9tICcuLi8uLi9rZXJuZWxfbmFtZXMnO1xuaW1wb3J0IHtUZW5zb3IsIFRlbnNvcjFEfSBmcm9tICcuLi8uLi90ZW5zb3InO1xuaW1wb3J0IHtjb252ZXJ0VG9UZW5zb3J9IGZyb20gJy4uLy4uL3RlbnNvcl91dGlsX2Vudic7XG5pbXBvcnQge1RlbnNvckxpa2V9IGZyb20gJy4uLy4uL3R5cGVzJztcbmltcG9ydCB7b3B9IGZyb20gJy4uL29wZXJhdGlvbic7XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIG1lYW4gYWxvbmcgc3BhcnNlIHNlZ21lbnRzIG9mIGEgdGVuc29yLlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCBjID0gdGYudGVuc29yMmQoW1sxLDIsMyw0XSwgWy0xLC0yLC0zLC00XSwgWzYsNyw4LDldXSk7XG4gKiAvLyBTZWxlY3QgdHdvIHJvd3MsIG9uZSBzZWdtZW50LlxuICogY29uc3QgcmVzdWx0MSA9IHRmLnNwYXJzZS5zcGFyc2VTZWdtZW50TWVhbihjLFxuICogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdGYudGVuc29yMWQoWzAsIDFdLCAnaW50MzInKSxcbiAqICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHRmLnRlbnNvcjFkKFswLCAwXSwgJ2ludDMyJykpO1xuICogcmVzdWx0MS5wcmludCgpOyAvLyBbWzAsIDAsIDAsIDBdXVxuICpcbiAqIC8vIFNlbGVjdCB0d28gcm93cywgdHdvIHNlZ21lbnRzLlxuICogY29uc3QgcmVzdWx0MiA9IHRmLnNwYXJzZS5zcGFyc2VTZWdtZW50TWVhbihjLFxuICogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0Zi50ZW5zb3IxZChbMCwgMV0sICdpbnQzMicpLFxuICogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0Zi50ZW5zb3IxZChbMCwgMV0sICdpbnQzMicpKTtcbiAqIHJlc3VsdDIucHJpbnQoKTsgLy8gW1sxL
CAyLCAzLCA0XSwgWy0xLCAtMiwgLTMsIC00XV1cbiAqXG4gKiAvLyBTZWxlY3QgYWxsIHJvd3MsIHR3byBzZWdtZW50cy5cbiAqIGNvbnN0IHJlc3VsdDMgPSB0Zi5zcGFyc2Uuc3BhcnNlU2VnbWVudE1lYW4oYyxcbiAqICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdGYudGVuc29yMWQoWzAsIDEsIDJdLCAnaW50MzInKSxcbiAqICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdGYudGVuc29yMWQoWzAsIDEsIDFdLCAnaW50MzInKSk7XG4gKiByZXN1bHQzLnByaW50KCk7IC8vIFtbMS4wLCAyLjAsIDMuMCwgNC4wXSwgWzIuNSwgMi41LCAyLjUsIDIuNV1dXG4gKiBgYGBcbiAqIEBwYXJhbSBkYXRhOiBBIFRlbnNvciBvZiBhdCBsZWFzdCBvbmUgZGltZW5zaW9uIHdpdGggZGF0YSB0aGF0IHdpbGwgYmVcbiAqICAgICBhc3NlbWJsZWQgaW4gdGhlIG91dHB1dC5cbiAqIEBwYXJhbSBpbmRpY2VzOiBBIDEtRCBUZW5zb3Igd2l0aCBpbmRpY2VzIGludG8gZGF0YS4gSGFzIHNhbWUgcmFuayBhc1xuICogICAgIHNlZ21lbnRJZHMuXG4gKiBAcGFyYW0gc2VnbWVudElkczogQSAxLUQgVGVuc29yIHdpdGggaW5kaWNlcyBpbnRvIHRoZSBvdXRwdXQgVGVuc29yLiBWYWx1ZXNcbiAqICAgICBzaG91bGQgYmUgc29ydGVkIGFuZCBjYW4gYmUgcmVwZWF0ZWQuXG4gKiBAcmV0dXJuIEhhcyBzYW1lIHNoYXBlIGFzIGRhdGEsIGV4Y2VwdCBmb3IgZGltZW5zaW9uIDAgd2hpY2ggaGFzIGVxdWFsIHRvXG4gKiAgICAgICAgIHRoZSBudW1iZXIgb2Ygc2VnbWVudHMuXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ09wZXJhdGlvbnMnLCBzdWJoZWFkaW5nOiAnU3BhcnNlJ31cbiAqL1xuZnVuY3Rpb24gc3BhcnNlU2VnbWVudE1lYW5fKFxuICAgIGRhdGE6IFRlbnNvcnxUZW5zb3JMaWtlLCBpbmRpY2VzOiBUZW5zb3IxRHxUZW5zb3JMaWtlLFxuICAgIHNlZ21lbnRJZHM6IFRlbnNvcjFEfFRlbnNvckxpa2UpOiBUZW5zb3Ige1xuICBjb25zdCAkZGF0YSA9IGNvbnZlcnRUb1RlbnNvcihkYXRhLCAnZGF0YScsICdzcGFyc2VTZWdtZW50TWVhbicpO1xuICBjb25zdCAkaW5kaWNlcyA9XG4gICAgICBjb252ZXJ0VG9UZW5zb3IoaW5kaWNlcywgJ2luZGljZXMnLCAnc3BhcnNlU2VnbWVudE1lYW4nLCAnaW50MzInKTtcbiAgY29uc3QgJHNlZ21lbnRJZHMgPVxuICAgICAgY29udmVydFRvVGVuc29yKHNlZ21lbnRJZHMsICdzZWdtZW50SWRzJywgJ3NwYXJzZVNlZ21lbnRNZWFuJywgJ2ludDMyJyk7XG5cbiAgaWYgKCRkYXRhLnJhbmsgPCAxKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICBgRGF0YSBzaG91bGQgYmUgYXQgbGVhc3QgMSBkaW1lbnNpb25hbCBidXQgcmVjZWl2ZWQgc2NhbGFyYCk7XG4gIH1cbiAgaWYgKCRpbmRpY2VzLnJhbmsgIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYEluZGljZXMgc2hvdWxkIGJlIFRlbnNvcjFEIGJ1dCByZWNlaXZlZCBzaGFwZVxuICAgI
CAgICAgICR7JGluZGljZXMuc2hhcGV9YCk7XG4gIH1cbiAgaWYgKCRzZWdtZW50SWRzLnJhbmsgIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYFNlZ21lbnQgaWRzIHNob3VsZCBiZSBUZW5zb3IxRCBidXQgcmVjZWl2ZWQgc2hhcGVcbiAgICAgICAgICAkeyRzZWdtZW50SWRzLnNoYXBlfWApO1xuICB9XG5cbiAgY29uc3QgaW5wdXRzOiBTcGFyc2VTZWdtZW50TWVhbklucHV0cyA9IHtcbiAgICBkYXRhOiAkZGF0YSxcbiAgICBpbmRpY2VzOiAkaW5kaWNlcyxcbiAgICBzZWdtZW50SWRzOiAkc2VnbWVudElkc1xuICB9O1xuXG4gIHJldHVybiBFTkdJTkUucnVuS2VybmVsKFNwYXJzZVNlZ21lbnRNZWFuLCBpbnB1dHMgYXMge30pO1xufVxuXG5leHBvcnQgY29uc3Qgc3BhcnNlU2VnbWVudE1lYW4gPSBvcCh7c3BhcnNlU2VnbWVudE1lYW5ffSk7XG4iXX0= |
\ | No newline at end of file |