1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { nonMaxSuppressionV4Impl } from '../../backends/non_max_suppression_impl';
|
18 | import { convertToTensor } from '../../tensor_util_env';
|
19 | import { nonMaxSuppSanityCheck } from '../nonmax_util';
|
20 | import { scalar } from '../scalar';
|
21 | import { tensor1d } from '../tensor1d';
|
/**
 * Asynchronously performs non maximum suppression of bounding boxes based on
 * iou (intersection over union), with an option to pad results.
 *
 * If `convertToTensor` has to create a new tensor for `boxes` or `scores`
 * (i.e. its result is not the caller's own object), that tensor is disposed
 * before this function settles. This also happens when argument validation,
 * reading the tensor data, or the suppression step throws. Objects passed in
 * by the caller are never disposed here.
 *
 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
 *     `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
 *     the bounding box.
 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
 * @param maxOutputSize The maximum number of boxes to be selected.
 * @param iouThreshold A float representing the threshold for deciding whether
 *     boxes overlap too much with respect to IOU. Must be between [0, 1].
 *     Defaults to 0.5 (50% box overlap).
 * @param scoreThreshold A threshold for deciding when to remove boxes based
 *     on score. Defaults to -inf, which means any score is accepted.
 * @param padToMaxOutputSize Defaults to false. If true, size of output
 *     `selectedIndices` is padded to maxOutputSize.
 * @return A map with the following properties:
 *     - selectedIndices: A 1D tensor with the selected box indices.
 *     - validOutputs: A scalar denoting how many elements in `selectedIndices`
 *       are valid. Valid elements occur first, then padding.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
async function nonMaxSuppressionPaddedAsync_(
    boxes, scores, maxOutputSize, iouThreshold = 0.5,
    scoreThreshold = Number.NEGATIVE_INFINITY, padToMaxOutputSize = false) {
  // Declared outside the try so the finally clause can release whichever of
  // the two conversions actually completed, even if a later step throws.
  let $boxes;
  let $scores;
  let selection;
  try {
    $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
    $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');

    const params = nonMaxSuppSanityCheck(
        $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold,
        null /* softNmsSigma */);
    const $maxOutputSize = params.maxOutputSize;
    const $iouThreshold = params.iouThreshold;
    const $scoreThreshold = params.scoreThreshold;

    const [boxesVals, scoresVals] =
        await Promise.all([$boxes.data(), $scores.data()]);

    // We call a cpu based impl directly with the typedarray data here rather
    // than a kernel because all kernels are synchronous (and thus cannot await
    // .data()).
    selection = nonMaxSuppressionV4Impl(
        boxesVals, scoresVals, $maxOutputSize, $iouThreshold, $scoreThreshold,
        padToMaxOutputSize);
  } finally {
    // Dispose only tensors created above; the caller still owns its inputs.
    if ($boxes !== undefined && $boxes !== boxes) {
      $boxes.dispose();
    }
    if ($scores !== undefined && $scores !== scores) {
      $scores.dispose();
    }
  }

  // Output tensors are built after the inputs are released, as before.
  const {selectedIndices, validOutputs} = selection;
  return {
    selectedIndices: tensor1d(selectedIndices, 'int32'),
    validOutputs: scalar(validOutputs, 'int32')
  };
}

export const nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
|
69 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibm9uX21heF9zdXBwcmVzc2lvbl9wYWRkZWRfYXN5bmMuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy9pbWFnZS9ub25fbWF4X3N1cHByZXNzaW9uX3BhZGRlZF9hc3luYy50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFDSCxPQUFPLEVBQUMsdUJBQXVCLEVBQUMsTUFBTSx5Q0FBeUMsQ0FBQztBQUdoRixPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFFdEQsT0FBTyxFQUFDLHFCQUFxQixFQUFDLE1BQU0sZ0JBQWdCLENBQUM7QUFDckQsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRXJDOzs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7O0dBc0JHO0FBQ0gsS0FBSyxVQUFVLDZCQUE2QixDQUN4QyxLQUEwQixFQUFFLE1BQTJCLEVBQ3ZELGFBQXFCLEVBQUUsWUFBWSxHQUFHLEdBQUcsRUFDekMsY0FBYyxHQUFHLE1BQU0sQ0FBQyxpQkFBaUIsRUFDekMsa0JBQWtCLEdBQUcsS0FBSztJQUM1QixNQUFNLE1BQU0sR0FBRyxlQUFlLENBQUMsS0FBSyxFQUFFLE9BQU8sRUFBRSx3QkFBd0IsQ0FBQyxDQUFDO0lBQ3pFLE1BQU0sT0FBTyxHQUFHLGVBQWUsQ0FBQyxNQUFNLEVBQUUsUUFBUSxFQUFFLHdCQUF3QixDQUFDLENBQUM7SUFFNUUsTUFBTSxNQUFNLEdBQUcscUJBQXFCLENBQ2hDLE1BQU0sRUFBRSxPQUFPLEVBQUUsYUFBYSxFQUFFLFlBQVksRUFBRSxjQUFjLEVBQzVELElBQUksQ0FBQyxrQkFBa0IsQ0FBQyxDQUFDO0lBQzdCLE1BQU0sY0FBYyxHQUFHLE1BQU0sQ0FBQyxhQUFhLENBQUM7SUFDNUMsTUFBTSxhQUFhLEdBQUcsTUFBTSxDQUFDLFlBQVksQ0FBQztJQUMxQyxNQUFNLGVBQWUsR0FBRyxNQUFNLENBQUMsY0FBYyxDQUFDO0lBRTlDLE1BQU0sQ0FBQyxTQUFTLEVBQUUsVUFBVSxDQUFDLEdBQ3pCLE1BQU0sT0FBTyxDQUFDLEdBQUcsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxJQUFJLEVBQUUsRUFBRSxPQUFPLENBQUMsSUFBSSxFQUFFLENBQUMsQ0FBQyxDQUFDO0lBRXZELHlFQUF5RTtJQUN6RSwyRUFBMkU7SUFDM0UsWUFBWTtJQUNaLE1BQU0sRUFBQyxlQUFlLEVBQUUsWUFBWSxFQUFDLEdBQUcsdUJBQXVCLENBQzNELFNBQVMsRUFBRSxVQUFVLEVBQUUsY0FBYyxFQUFFLGFBQWEsRUFBRSxlQUFlLEVBQ3JFLGtCQUFrQixDQUFDLENBQUM7SUFFeEIsSUFBSSxNQUFNLEtBQUssS0FBSyxFQUFFO1FBQ3BCLE1BQU0sQ0FBQyxPQUFPLEVBQUUsQ0FBQztLQUNsQjtJQUNELElBQUksT0FBTyxLQUFLLE1BQU0sRUFBRTtRQUN0QixPQUFPLENBQUMsT0FBTyxFQUFFLENBQUM7S0FDbkI7SUFFRCxPQUFPO1FBQ0wsZUFBZSxFQUFFLFFBQVEsQ0FBQyxlQUFlLEVBQUUsT0FBTyxDQUFDO1FBQ25ELFlBQVksR
UFBRSxNQUFNLENBQUMsWUFBWSxFQUFFLE9BQU8sQ0FBQztLQUM1QyxDQUFDO0FBQ0osQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLDRCQUE0QixHQUFHLDZCQUE2QixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtub25NYXhTdXBwcmVzc2lvblY0SW1wbH0gZnJvbSAnLi4vLi4vYmFja2VuZHMvbm9uX21heF9zdXBwcmVzc2lvbl9pbXBsJztcbmltcG9ydCB7VGVuc29yMUQsIFRlbnNvcjJEfSBmcm9tICcuLi8uLi90ZW5zb3InO1xuaW1wb3J0IHtOYW1lZFRlbnNvck1hcH0gZnJvbSAnLi4vLi4vdGVuc29yX3R5cGVzJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi8uLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi8uLi90eXBlcyc7XG5pbXBvcnQge25vbk1heFN1cHBTYW5pdHlDaGVja30gZnJvbSAnLi4vbm9ubWF4X3V0aWwnO1xuaW1wb3J0IHtzY2FsYXJ9IGZyb20gJy4uL3NjYWxhcic7XG5pbXBvcnQge3RlbnNvcjFkfSBmcm9tICcuLi90ZW5zb3IxZCc7XG5cbi8qKlxuICogQXN5bmNocm9ub3VzbHkgcGVyZm9ybXMgbm9uIG1heGltdW0gc3VwcHJlc3Npb24gb2YgYm91bmRpbmcgYm94ZXMgYmFzZWQgb25cbiAqIGlvdSAoaW50ZXJzZWN0aW9uIG92ZXIgdW5pb24pLCB3aXRoIGFuIG9wdGlvbiB0byBwYWQgcmVzdWx0cy5cbiAqXG4gKiBAcGFyYW0gYm94ZXMgYSAyZCB0ZW5zb3Igb2Ygc2hhcGUgYFtudW1Cb3hlcywgNF1gLiBFYWNoIGVudHJ5IGlzXG4gKiAgICAgYFt5MSwgeDEsIHkyLCB4M
l1gLCB3aGVyZSBgKHkxLCB4MSlgIGFuZCBgKHkyLCB4MilgIGFyZSB0aGUgY29ybmVycyBvZlxuICogICAgIHRoZSBib3VuZGluZyBib3guXG4gKiBAcGFyYW0gc2NvcmVzIGEgMWQgdGVuc29yIHByb3ZpZGluZyB0aGUgYm94IHNjb3JlcyBvZiBzaGFwZSBgW251bUJveGVzXWAuXG4gKiBAcGFyYW0gbWF4T3V0cHV0U2l6ZSBUaGUgbWF4aW11bSBudW1iZXIgb2YgYm94ZXMgdG8gYmUgc2VsZWN0ZWQuXG4gKiBAcGFyYW0gaW91VGhyZXNob2xkIEEgZmxvYXQgcmVwcmVzZW50aW5nIHRoZSB0aHJlc2hvbGQgZm9yIGRlY2lkaW5nIHdoZXRoZXJcbiAqICAgICBib3hlcyBvdmVybGFwIHRvbyBtdWNoIHdpdGggcmVzcGVjdCB0byBJT1UuIE11c3QgYmUgYmV0d2VlbiBbMCwgMV0uXG4gKiAgICAgRGVmYXVsdHMgdG8gMC41ICg1MCUgYm94IG92ZXJsYXApLlxuICogQHBhcmFtIHNjb3JlVGhyZXNob2xkIEEgdGhyZXNob2xkIGZvciBkZWNpZGluZyB3aGVuIHRvIHJlbW92ZSBib3hlcyBiYXNlZFxuICogICAgIG9uIHNjb3JlLiBEZWZhdWx0cyB0byAtaW5mLCB3aGljaCBtZWFucyBhbnkgc2NvcmUgaXMgYWNjZXB0ZWQuXG4gKiBAcGFyYW0gcGFkVG9NYXhPdXRwdXRTaXplIERlZmFsdHMgdG8gZmFsc2UuIElmIHRydWUsIHNpemUgb2Ygb3V0cHV0XG4gKiAgICAgYHNlbGVjdGVkSW5kaWNlc2AgaXMgcGFkZGVkIHRvIG1heE91dHB1dFNpemUuXG4gKiBAcmV0dXJuIEEgbWFwIHdpdGggdGhlIGZvbGxvd2luZyBwcm9wZXJ0aWVzOlxuICogICAgIC0gc2VsZWN0ZWRJbmRpY2VzOiBBIDFEIHRlbnNvciB3aXRoIHRoZSBzZWxlY3RlZCBib3ggaW5kaWNlcy5cbiAqICAgICAtIHZhbGlkT3V0cHV0czogQSBzY2FsYXIgZGVub3RpbmcgaG93IG1hbnkgZWxlbWVudHMgaW4gYHNlbGVjdGVkSW5kaWNlc2BcbiAqICAgICAgIGFyZSB2YWxpZC4gVmFsaWQgZWxlbWVudHMgb2NjdXIgZmlyc3QsIHRoZW4gcGFkZGluZy5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdJbWFnZXMnLCBuYW1lc3BhY2U6ICdpbWFnZSd9XG4gKi9cbmFzeW5jIGZ1bmN0aW9uIG5vbk1heFN1cHByZXNzaW9uUGFkZGVkQXN5bmNfKFxuICAgIGJveGVzOiBUZW5zb3IyRHxUZW5zb3JMaWtlLCBzY29yZXM6IFRlbnNvcjFEfFRlbnNvckxpa2UsXG4gICAgbWF4T3V0cHV0U2l6ZTogbnVtYmVyLCBpb3VUaHJlc2hvbGQgPSAwLjUsXG4gICAgc2NvcmVUaHJlc2hvbGQgPSBOdW1iZXIuTkVHQVRJVkVfSU5GSU5JVFksXG4gICAgcGFkVG9NYXhPdXRwdXRTaXplID0gZmFsc2UpOiBQcm9taXNlPE5hbWVkVGVuc29yTWFwPiB7XG4gIGNvbnN0ICRib3hlcyA9IGNvbnZlcnRUb1RlbnNvcihib3hlcywgJ2JveGVzJywgJ25vbk1heFN1cHByZXNzaW9uQXN5bmMnKTtcbiAgY29uc3QgJHNjb3JlcyA9IGNvbnZlcnRUb1RlbnNvcihzY29yZXMsICdzY29yZXMnLCAnbm9uTWF4U3VwcHJlc3Npb25Bc3luYycpO1xuXG4gIGNvbnN0IHBhcmFtcyA9IG5vbk1heFN1c
HBTYW5pdHlDaGVjayhcbiAgICAgICRib3hlcywgJHNjb3JlcywgbWF4T3V0cHV0U2l6ZSwgaW91VGhyZXNob2xkLCBzY29yZVRocmVzaG9sZCxcbiAgICAgIG51bGwgLyogc29mdE5tc1NpZ21hICovKTtcbiAgY29uc3QgJG1heE91dHB1dFNpemUgPSBwYXJhbXMubWF4T3V0cHV0U2l6ZTtcbiAgY29uc3QgJGlvdVRocmVzaG9sZCA9IHBhcmFtcy5pb3VUaHJlc2hvbGQ7XG4gIGNvbnN0ICRzY29yZVRocmVzaG9sZCA9IHBhcmFtcy5zY29yZVRocmVzaG9sZDtcblxuICBjb25zdCBbYm94ZXNWYWxzLCBzY29yZXNWYWxzXSA9XG4gICAgICBhd2FpdCBQcm9taXNlLmFsbChbJGJveGVzLmRhdGEoKSwgJHNjb3Jlcy5kYXRhKCldKTtcblxuICAvLyBXZSBjYWxsIGEgY3B1IGJhc2VkIGltcGwgZGlyZWN0bHkgd2l0aCB0aGUgdHlwZWRhcnJheSBkYXRhIGhlcmUgcmF0aGVyXG4gIC8vIHRoYW4gYSBrZXJuZWwgYmVjYXVzZSBhbGwga2VybmVscyBhcmUgc3luY2hyb25vdXMgKGFuZCB0aHVzIGNhbm5vdCBhd2FpdFxuICAvLyAuZGF0YSgpKS5cbiAgY29uc3Qge3NlbGVjdGVkSW5kaWNlcywgdmFsaWRPdXRwdXRzfSA9IG5vbk1heFN1cHByZXNzaW9uVjRJbXBsKFxuICAgICAgYm94ZXNWYWxzLCBzY29yZXNWYWxzLCAkbWF4T3V0cHV0U2l6ZSwgJGlvdVRocmVzaG9sZCwgJHNjb3JlVGhyZXNob2xkLFxuICAgICAgcGFkVG9NYXhPdXRwdXRTaXplKTtcblxuICBpZiAoJGJveGVzICE9PSBib3hlcykge1xuICAgICRib3hlcy5kaXNwb3NlKCk7XG4gIH1cbiAgaWYgKCRzY29yZXMgIT09IHNjb3Jlcykge1xuICAgICRzY29yZXMuZGlzcG9zZSgpO1xuICB9XG5cbiAgcmV0dXJuIHtcbiAgICBzZWxlY3RlZEluZGljZXM6IHRlbnNvcjFkKHNlbGVjdGVkSW5kaWNlcywgJ2ludDMyJyksXG4gICAgdmFsaWRPdXRwdXRzOiBzY2FsYXIodmFsaWRPdXRwdXRzLCAnaW50MzInKVxuICB9O1xufVxuXG5leHBvcnQgY29uc3Qgbm9uTWF4U3VwcHJlc3Npb25QYWRkZWRBc3luYyA9IG5vbk1heFN1cHByZXNzaW9uUGFkZGVkQXN5bmNfO1xuIl19 |
\ | No newline at end of file |