1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { ENGINE } from '../../engine';
|
18 | import { NonMaxSuppressionV5 } from '../../kernel_names';
|
19 | import { convertToTensor } from '../../tensor_util_env';
|
20 | import { nonMaxSuppSanityCheck } from '../nonmax_util';
|
21 | import { op } from '../operation';
|
/**
 * Selects bounding boxes by non-maximum suppression, using iou
 * (intersection over union) as the overlap measure, and reports the score
 * associated with every selected box.
 *
 * A Soft-NMS mode is also available (c.f. Bodla et al,
 * https://arxiv.org/abs/1704.04503). In that mode a box lowers the score of
 * the boxes overlapping it, which favors different regions of the image with
 * high scores. Soft-NMS is enabled by passing a `softNmsSigma` larger than 0.
 *
 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
 *     `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
 *     the bounding box.
 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
 * @param maxOutputSize The maximum number of boxes to be selected.
 * @param iouThreshold A float representing the threshold for deciding whether
 *     boxes overlap too much with respect to IOU. Must be between [0, 1].
 *     Defaults to 0.5 (50% box overlap).
 * @param scoreThreshold A threshold for deciding when to remove boxes based
 *     on score. Defaults to -inf, which means any score is accepted.
 * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
 *     When sigma is 0, it falls back to nonMaxSuppression.
 * @return A map with the following properties:
 *     - selectedIndices: A 1D tensor with the selected box indices.
 *     - selectedScores: A 1D tensor with the corresponding scores for each
 *       selected box.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0.0) {
    // NOTE(review): both conversions are tagged with the op name
    // 'nonMaxSuppression' rather than 'nonMaxSuppressionWithScore'; presumably
    // that name surfaces in argument errors - confirm before changing it.
    const boxesTensor = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
    const scoresTensor = convertToTensor(scores, 'scores', 'nonMaxSuppression');

    // The kernel receives the values returned by the sanity check, not the raw
    // arguments supplied by the caller.
    const checked = nonMaxSuppSanityCheck(boxesTensor, scoresTensor, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const kernelAttrs = {
        maxOutputSize: checked.maxOutputSize,
        iouThreshold: checked.iouThreshold,
        scoreThreshold: checked.scoreThreshold,
        softNmsSigma: checked.softNmsSigma,
    };
    const kernelInputs = { boxes: boxesTensor, scores: scoresTensor };

    // tslint:disable-next-line: no-unnecessary-type-assertion
    const outputs = ENGINE.runKernel(NonMaxSuppressionV5, kernelInputs, kernelAttrs);

    // Output 0 holds the selected indices, output 1 the matching scores.
    return { selectedIndices: outputs[0], selectedScores: outputs[1] };
}
|
// Public op: the implementation is handed to `op` inside a single-key object
// whose key is the implementation's own name.
export const nonMaxSuppressionWithScore = op({
    nonMaxSuppressionWithScore_: nonMaxSuppressionWithScore_,
});
|
66 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibm9uX21heF9zdXBwcmVzc2lvbl93aXRoX3Njb3JlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvaW1hZ2Uvbm9uX21heF9zdXBwcmVzc2lvbl93aXRoX3Njb3JlLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFDcEMsT0FBTyxFQUFDLG1CQUFtQixFQUFzRCxNQUFNLG9CQUFvQixDQUFDO0FBSTVHLE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUd0RCxPQUFPLEVBQUMscUJBQXFCLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUNyRCxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBRWhDOzs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7O0dBNEJHO0FBQ0gsU0FBUywyQkFBMkIsQ0FDaEMsS0FBMEIsRUFBRSxNQUEyQixFQUN2RCxhQUFxQixFQUFFLFlBQVksR0FBRyxHQUFHLEVBQ3pDLGNBQWMsR0FBRyxNQUFNLENBQUMsaUJBQWlCLEVBQ3pDLFlBQVksR0FBRyxHQUFHO0lBQ3BCLE1BQU0sTUFBTSxHQUFHLGVBQWUsQ0FBQyxLQUFLLEVBQUUsT0FBTyxFQUFFLG1CQUFtQixDQUFDLENBQUM7SUFDcEUsTUFBTSxPQUFPLEdBQUcsZUFBZSxDQUFDLE1BQU0sRUFBRSxRQUFRLEVBQUUsbUJBQW1CLENBQUMsQ0FBQztJQUV2RSxNQUFNLE1BQU0sR0FBRyxxQkFBcUIsQ0FDaEMsTUFBTSxFQUFFLE9BQU8sRUFBRSxhQUFhLEVBQUUsWUFBWSxFQUFFLGNBQWMsRUFDNUQsWUFBWSxDQUFDLENBQUM7SUFDbEIsYUFBYSxHQUFHLE1BQU0sQ0FBQyxhQUFhLENBQUM7SUFDckMsWUFBWSxHQUFHLE1BQU0sQ0FBQyxZQUFZLENBQUM7SUFDbkMsY0FBYyxHQUFHLE1BQU0sQ0FBQyxjQUFjLENBQUM7SUFDdkMsWUFBWSxHQUFHLE1BQU0sQ0FBQyxZQUFZLENBQUM7SUFFbkMsTUFBTSxNQUFNLEdBQThCLEVBQUMsS0FBSyxFQUFFLE1BQU0sRUFBRSxNQUFNLEVBQUUsT0FBTyxFQUFDLENBQUM7SUFDM0UsTUFBTSxLQUFLLEdBQ1AsRUFBQyxhQUFhLEVBQUUsWUFBWSxFQUFFLGNBQWMsRUFBRSxZQUFZLEVBQUMsQ0FBQztJQUVoRSwwREFBMEQ7SUFDMUQsTUFBTSxNQUFNLEdBQUcsTUFBTSxDQUFDLFNBQVMsQ0FDWixtQkFBbUIsRUFBRSxNQUE4QixFQUNuRCxLQUEyQixDQUFhLENBQUM7SUFFNUQsT0FBTyxFQUFDLGVBQWUsRUFBRSxNQUFNLENBQUMsQ0FBQyxDQUFDLEVBQUUsY0FBYyxFQUFFLE1BQU0sQ0FBQyxDQUFDLENBQUMsRUFBQyxDQUFDO0FBQ2pFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSwwQkFBMEIsR0FBRyxFQUFFLENBQUMsRUFBQywyQkFBMkIsRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzI
FJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RU5HSU5FfSBmcm9tICcuLi8uLi9lbmdpbmUnO1xuaW1wb3J0IHtOb25NYXhTdXBwcmVzc2lvblY1LCBOb25NYXhTdXBwcmVzc2lvblY1QXR0cnMsIE5vbk1heFN1cHByZXNzaW9uVjVJbnB1dHN9IGZyb20gJy4uLy4uL2tlcm5lbF9uYW1lcyc7XG5pbXBvcnQge05hbWVkQXR0ck1hcH0gZnJvbSAnLi4vLi4va2VybmVsX3JlZ2lzdHJ5JztcbmltcG9ydCB7VGVuc29yLCBUZW5zb3IxRCwgVGVuc29yMkR9IGZyb20gJy4uLy4uL3RlbnNvcic7XG5pbXBvcnQge05hbWVkVGVuc29yTWFwfSBmcm9tICcuLi8uLi90ZW5zb3JfdHlwZXMnO1xuaW1wb3J0IHtjb252ZXJ0VG9UZW5zb3J9IGZyb20gJy4uLy4uL3RlbnNvcl91dGlsX2Vudic7XG5pbXBvcnQge1RlbnNvckxpa2V9IGZyb20gJy4uLy4uL3R5cGVzJztcblxuaW1wb3J0IHtub25NYXhTdXBwU2FuaXR5Q2hlY2t9IGZyb20gJy4uL25vbm1heF91dGlsJztcbmltcG9ydCB7b3B9IGZyb20gJy4uL29wZXJhdGlvbic7XG5cbi8qKlxuICogUGVyZm9ybXMgbm9uIG1heGltdW0gc3VwcHJlc3Npb24gb2YgYm91bmRpbmcgYm94ZXMgYmFzZWQgb25cbiAqIGlvdSAoaW50ZXJzZWN0aW9uIG92ZXIgdW5pb24pLlxuICpcbiAqIFRoaXMgb3AgYWxzbyBzdXBwb3J0cyBhIFNvZnQtTk1TIG1vZGUgKGMuZi5cbiAqIEJvZGxhIGV0IGFsLCBodHRwczovL2FyeGl2Lm9yZy9hYnMvMTcwNC4wNDUwMykgd2hlcmUgYm94ZXMgcmVkdWNlIHRoZSBzY29yZVxuICogb2Ygb3RoZXIgb3ZlcmxhcHBpbmcgYm94ZXMsIHRoZXJlZm9yZSBmYXZvcmluZyBkaWZmZXJlbnQgcmVnaW9ucyBvZiB0aGUgaW1hZ2VcbiAqIHdpdGggaGlnaCBzY
29yZXMuIFRvIGVuYWJsZSB0aGlzIFNvZnQtTk1TIG1vZGUsIHNldCB0aGUgYHNvZnRObXNTaWdtYWBcbiAqIHBhcmFtZXRlciB0byBiZSBsYXJnZXIgdGhhbiAwLlxuICpcbiAqIEBwYXJhbSBib3hlcyBhIDJkIHRlbnNvciBvZiBzaGFwZSBgW251bUJveGVzLCA0XWAuIEVhY2ggZW50cnkgaXNcbiAqICAgICBgW3kxLCB4MSwgeTIsIHgyXWAsIHdoZXJlIGAoeTEsIHgxKWAgYW5kIGAoeTIsIHgyKWAgYXJlIHRoZSBjb3JuZXJzIG9mXG4gKiAgICAgdGhlIGJvdW5kaW5nIGJveC5cbiAqIEBwYXJhbSBzY29yZXMgYSAxZCB0ZW5zb3IgcHJvdmlkaW5nIHRoZSBib3ggc2NvcmVzIG9mIHNoYXBlIGBbbnVtQm94ZXNdYC5cbiAqIEBwYXJhbSBtYXhPdXRwdXRTaXplIFRoZSBtYXhpbXVtIG51bWJlciBvZiBib3hlcyB0byBiZSBzZWxlY3RlZC5cbiAqIEBwYXJhbSBpb3VUaHJlc2hvbGQgQSBmbG9hdCByZXByZXNlbnRpbmcgdGhlIHRocmVzaG9sZCBmb3IgZGVjaWRpbmcgd2hldGhlclxuICogICAgIGJveGVzIG92ZXJsYXAgdG9vIG11Y2ggd2l0aCByZXNwZWN0IHRvIElPVS4gTXVzdCBiZSBiZXR3ZWVuIFswLCAxXS5cbiAqICAgICBEZWZhdWx0cyB0byAwLjUgKDUwJSBib3ggb3ZlcmxhcCkuXG4gKiBAcGFyYW0gc2NvcmVUaHJlc2hvbGQgQSB0aHJlc2hvbGQgZm9yIGRlY2lkaW5nIHdoZW4gdG8gcmVtb3ZlIGJveGVzIGJhc2VkXG4gKiAgICAgb24gc2NvcmUuIERlZmF1bHRzIHRvIC1pbmYsIHdoaWNoIG1lYW5zIGFueSBzY29yZSBpcyBhY2NlcHRlZC5cbiAqIEBwYXJhbSBzb2Z0Tm1zU2lnbWEgQSBmbG9hdCByZXByZXNlbnRpbmcgdGhlIHNpZ21hIHBhcmFtZXRlciBmb3IgU29mdCBOTVMuXG4gKiAgICAgV2hlbiBzaWdtYSBpcyAwLCBpdCBmYWxscyBiYWNrIHRvIG5vbk1heFN1cHByZXNzaW9uLlxuICogQHJldHVybiBBIG1hcCB3aXRoIHRoZSBmb2xsb3dpbmcgcHJvcGVydGllczpcbiAqICAgICAtIHNlbGVjdGVkSW5kaWNlczogQSAxRCB0ZW5zb3Igd2l0aCB0aGUgc2VsZWN0ZWQgYm94IGluZGljZXMuXG4gKiAgICAgLSBzZWxlY3RlZFNjb3JlczogQSAxRCB0ZW5zb3Igd2l0aCB0aGUgY29ycmVzcG9uZGluZyBzY29yZXMgZm9yIGVhY2hcbiAqICAgICAgIHNlbGVjdGVkIGJveC5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdJbWFnZXMnLCBuYW1lc3BhY2U6ICdpbWFnZSd9XG4gKi9cbmZ1bmN0aW9uIG5vbk1heFN1cHByZXNzaW9uV2l0aFNjb3JlXyhcbiAgICBib3hlczogVGVuc29yMkR8VGVuc29yTGlrZSwgc2NvcmVzOiBUZW5zb3IxRHxUZW5zb3JMaWtlLFxuICAgIG1heE91dHB1dFNpemU6IG51bWJlciwgaW91VGhyZXNob2xkID0gMC41LFxuICAgIHNjb3JlVGhyZXNob2xkID0gTnVtYmVyLk5FR0FUSVZFX0lORklOSVRZLFxuICAgIHNvZnRObXNTaWdtYSA9IDAuMCk6IE5hbWVkVGVuc29yTWFwIHtcbiAgY29uc3QgJGJveGVzID0gY29udmVydFRvVGVuc29yKGJveGVzLCAnYm94ZXMnL
CAnbm9uTWF4U3VwcHJlc3Npb24nKTtcbiAgY29uc3QgJHNjb3JlcyA9IGNvbnZlcnRUb1RlbnNvcihzY29yZXMsICdzY29yZXMnLCAnbm9uTWF4U3VwcHJlc3Npb24nKTtcblxuICBjb25zdCBwYXJhbXMgPSBub25NYXhTdXBwU2FuaXR5Q2hlY2soXG4gICAgICAkYm94ZXMsICRzY29yZXMsIG1heE91dHB1dFNpemUsIGlvdVRocmVzaG9sZCwgc2NvcmVUaHJlc2hvbGQsXG4gICAgICBzb2Z0Tm1zU2lnbWEpO1xuICBtYXhPdXRwdXRTaXplID0gcGFyYW1zLm1heE91dHB1dFNpemU7XG4gIGlvdVRocmVzaG9sZCA9IHBhcmFtcy5pb3VUaHJlc2hvbGQ7XG4gIHNjb3JlVGhyZXNob2xkID0gcGFyYW1zLnNjb3JlVGhyZXNob2xkO1xuICBzb2Z0Tm1zU2lnbWEgPSBwYXJhbXMuc29mdE5tc1NpZ21hO1xuXG4gIGNvbnN0IGlucHV0czogTm9uTWF4U3VwcHJlc3Npb25WNUlucHV0cyA9IHtib3hlczogJGJveGVzLCBzY29yZXM6ICRzY29yZXN9O1xuICBjb25zdCBhdHRyczogTm9uTWF4U3VwcHJlc3Npb25WNUF0dHJzID1cbiAgICAgIHttYXhPdXRwdXRTaXplLCBpb3VUaHJlc2hvbGQsIHNjb3JlVGhyZXNob2xkLCBzb2Z0Tm1zU2lnbWF9O1xuXG4gIC8vIHRzbGludDpkaXNhYmxlLW5leHQtbGluZTogbm8tdW5uZWNlc3NhcnktdHlwZS1hc3NlcnRpb25cbiAgY29uc3QgcmVzdWx0ID0gRU5HSU5FLnJ1bktlcm5lbChcbiAgICAgICAgICAgICAgICAgICAgIE5vbk1heFN1cHByZXNzaW9uVjUsIGlucHV0cyBhcyB7fSBhcyBOYW1lZFRlbnNvck1hcCxcbiAgICAgICAgICAgICAgICAgICAgIGF0dHJzIGFzIHt9IGFzIE5hbWVkQXR0ck1hcCkgYXMgVGVuc29yW107XG5cbiAgcmV0dXJuIHtzZWxlY3RlZEluZGljZXM6IHJlc3VsdFswXSwgc2VsZWN0ZWRTY29yZXM6IHJlc3VsdFsxXX07XG59XG5cbmV4cG9ydCBjb25zdCBub25NYXhTdXBwcmVzc2lvbldpdGhTY29yZSA9IG9wKHtub25NYXhTdXBwcmVzc2lvbldpdGhTY29yZV99KTtcbiJdfQ== |
\ | No newline at end of file |