Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from __future__ import absolute_import, division, print_function, unicode_literals | |
| import unittest | |
| import torch | |
| from detectron2.layers import batched_nms | |
| from detectron2.utils.testing import random_boxes | |
class TestNMS(unittest.TestCase):
    """Checks that a TorchScript-compiled ``batched_nms`` matches the eager one."""

    def _create_tensors(self, N):
        """Generate random inputs for an NMS call.

        Args:
            N: number of boxes (and matching scores) to generate.

        Returns:
            A ``(boxes, scores)`` tuple, where ``boxes`` is the result of
            ``random_boxes(N, 200)`` and ``scores`` is ``torch.rand(N)``.
        """
        # NOTE(review): the second argument (200) is presumably the upper
        # bound on box coordinates -- confirm against ``random_boxes``.
        boxes = random_boxes(N, 200)
        scores = torch.rand(N)
        return boxes, scores

    def test_nms_scriptability(self):
        """Eager and jit-scripted ``batched_nms`` must return identical results.

        For each IoU threshold the test also verifies that the scripted call
        leaves the ``boxes`` tensor unchanged.
        """
        N = 2000
        num_classes = 50
        boxes, scores = self._create_tensors(N)
        # One integer category id in [0, num_classes) per box.
        idxs = torch.randint(0, num_classes, (N,))
        scripted_batched_nms = torch.jit.script(batched_nms)
        err_msg = "NMS is incompatible with jit-scripted NMS for IoU={}"

        for iou in [0.2, 0.5, 0.8]:
            keep_ref = batched_nms(boxes, scores, idxs, iou)
            # Snapshot taken right before the scripted call so that any
            # in-place modification made by it can be detected afterwards.
            backup = boxes.clone()
            scripted_keep = scripted_batched_nms(boxes, scores, idxs, iou)
            # Use a unittest assertion rather than a bare ``assert`` so the
            # check still runs when Python is started with ``-O``.
            self.assertTrue(
                torch.allclose(boxes, backup),
                "boxes modified by jit-scripted batched_nms",
            )
            self.assertTrue(torch.equal(keep_ref, scripted_keep), err_msg.format(iou))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()