UNPKG

14.9 kBSource Map (JSON)View Raw
1{"version":3,"file":"unsorted_segment_sum_test.js","sourceRoot":"","sources":["../../src/ops/unsorted_segment_sum_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAC,MAAM,cAAc,CAAC;AAC/C,OAAO,EAAC,qBAAqB,EAAC,MAAM,eAAe,CAAC;AAEpD,iBAAiB,CAAC,oBAAoB,EAAE,QAAQ,EAAE,GAAG,EAAE;IACrD,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACtD,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;QACzC,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAChD,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5C,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1E,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACnD,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC/C,iBAAiB,CACb,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,8CAA8C,EAAE,KAAK,IAAI,EAAE;QAC5D,MAAM,CAAC,GAAG,qBAAqB,GAAG,CAAC,CAAC;QACpC,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,eAAe,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,WAAW,CAAC,CAAC;QAC3C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE;YAC1B,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;YACd,eAAe,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,WAAW,CAAC;YACrC,IAAI,CAAC,CAAC,GAAG,WAAW,CAAC,IAAI,CAAC,CAAC;SAC5B;QACD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;QAC9B,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,eAAe,EAAE,OAAO,CAAC,CAAC;QACzD,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;QACzC,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,IAAI,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,6BAA6B,EAAE,KAAK,IAAI,EAAE;QAC3C,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACvD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAE9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;QACzC,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACvD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACtD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACtD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,QAAQ,GAAG,EAAE,CAAC,IAAI,CACpB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,CAAC,KAAK,EAAE,EAAE,UAAU,CAAC,KAAK,EAAE,EAAE,WAAW,CAAC;aAC5D,KAAK,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE9B,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAChD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1E,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACnD,MAAM,WAAW,GAAG,CAAC,CAAC;QAEtB,MAAM,EAAE,GACJ,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACvE,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE3E,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACxC,iBAAiB,CACb,MAAM,QAAQ,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;IAC3E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,8BAA8B,EAAE,KAAK,IAAI,EAAE;QAC5C,MAAM,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QACvB,MAAM,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAChC,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,EAAE,UAAU,EAAE,WAAW,CAAC,CAAC;QAC9D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/B,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAChC,MAAM,WAAW,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,CAAC,CAAC,kBAAkB,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;QAE1D,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/B,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysClose} from '../test_util';\nimport {PARALLELIZE_THRESHOLD} from './reduce_util';\n\ndescribeWithFlags('unsortedSegmentSum', ALL_ENVS, () => {\n it('tensor1D', async () => {\n const t = tf.tensor1d([1, 2, 3, 4]);\n const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');\n const numSegments = 3;\n const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n expect(res.shape).toEqual([numSegments]);\n expectArraysClose(await res.data(), [4, 4, 2]);\n });\n\n it('tensor2D', async () => {\n const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n const segmentIds = tf.tensor1d([0, 0], 'int32');\n const numSegments = 2;\n const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n expect(res.shape).toEqual([numSegments, 2]);\n expectArraysClose(await res.data(), [4, 6, 0, 0]);\n });\n\n it('tensor3D', async () => {\n const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);\n const segmentIds = tf.tensor1d([2, 1, 2], 'int32');\n const numSegments = 3;\n const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n expect(res.shape).toEqual([numSegments, 2, 2]);\n expectArraysClose(\n await res.data(), [0, 0, 0, 0, 5, 6, 7, 8, 10, 12, 14, 16]);\n });\n\n it('N > than parallelization threshold, tensor1D', async () => {\n const n = PARALLELIZE_THRESHOLD * 2;\n const values = new Float32Array(n);\n const numSegments = 5;\n const segmentIdValues = new Float32Array(n);\n const vals = new Float32Array(numSegments);\n for (let i = 0; i < n; i++) {\n values[i] = i;\n segmentIdValues[i] = i % numSegments;\n vals[i % numSegments] += i;\n }\n const t = tf.tensor1d(values);\n const segmentIds = tf.tensor1d(segmentIdValues, 'int32');\n const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n expect(res.shape).toEqual([numSegments]);\n expectArraysClose(await res.data(), vals);\n });\n\n it('ignores negative segmentIds', async () => {\n const t = tf.tensor1d([1, 2, 3, 4]);\n const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');\n const numSegments = 3;\n\n const res = tf.unsortedSegmentSum(t, segmentIds, numSegments);\n\n expect(res.shape).toEqual([numSegments]);\n expectArraysClose(await res.data(), [1, 4, 2]);\n });\n\n it('gradient ignores negative segmentIds', async () => {\n const t = tf.tensor1d([1, 2, 3, 4]);\n const segmentIds = tf.tensor1d([0, 2, -1, 1], 'int32');\n const numSegments = 3;\n\n const dy = tf.tensor1d([11, 2, 7]);\n const gradient =\n tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n expect(gradient.shape).toEqual(t.shape);\n expectArraysClose(await gradient.data(), [11, 7, 0, 2]);\n });\n\n it('tensor1D gradient', async () => {\n const t = tf.tensor1d([1, 2, 3, 4]);\n const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');\n const numSegments = 3;\n\n const dy = tf.tensor1d([11, 2, 7]);\n const gradient =\n tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n expect(gradient.shape).toEqual(t.shape);\n expectArraysClose(await gradient.data(), [11, 7, 11, 2]);\n });\n\n it('gradient with clones', async () => {\n const t = tf.tensor1d([1, 2, 3, 4]);\n const segmentIds = tf.tensor1d([0, 2, 0, 1], 'int32');\n const numSegments = 3;\n\n const dy = tf.tensor1d([11, 2, 7]);\n const gradient = tf.grad(\n a => tf.unsortedSegmentSum(a.clone(), segmentIds.clone(), numSegments)\n .clone())(t, dy);\n\n expect(gradient.shape).toEqual(t.shape);\n expectArraysClose(await gradient.data(), [11, 7, 11, 2]);\n });\n\n it('tensor2D gradient', async () => {\n const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n const segmentIds = tf.tensor1d([0, 0], 'int32');\n const numSegments = 2;\n\n const dy = tf.tensor2d([11, 2, 4, 5], [2, 2]);\n const gradient =\n tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n expect(gradient.shape).toEqual(t.shape);\n expectArraysClose(await gradient.data(), [11, 2, 11, 2]);\n });\n\n it('tensor3D gradient', async () => {\n const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]);\n const segmentIds = tf.tensor1d([2, 1, 2], 'int32');\n const numSegments = 3;\n\n const dy =\n tf.tensor3d([11, 2, 4, 5, 17, 31, 1, 0, -1, 14, 3, 28], [3, 2, 2]);\n const gradient =\n tf.grad(a => tf.unsortedSegmentSum(a, segmentIds, numSegments))(t, dy);\n\n expect(gradient.shape).toEqual(t.shape);\n expectArraysClose(\n await gradient.data(), [-1, 14, 3, 28, 17, 31, 1, 0, -1, 14, 3, 28]);\n });\n\n it('accepts a tensor-like object', async () => {\n const x = [1, 2, 3, 4];\n const segmentIds = [0, 2, 0, 1];\n const numSegments = 3;\n const res = tf.unsortedSegmentSum(x, segmentIds, numSegments);\n expect(res.shape).toEqual([3]);\n expectArraysClose(await res.data(), [4, 4, 2]);\n });\n\n it('accepts a tensor-like object chained', async () => {\n const x = tf.tensor1d([1, 2, 3, 4]);\n const segmentIds = [0, 2, 0, 1];\n const numSegments = 3;\n const res = x.unsortedSegmentSum(segmentIds, numSegments);\n\n expect(res.shape).toEqual([3]);\n expectArraysClose(await res.data(), [4, 4, 2]);\n });\n});\n"]}
\No newline at end of file