# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor
| _XYWH2XYXY = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], | |
| [-0.5, 0.0, 0.5, 0.0], [0.0, -0.5, 0.0, 0.5]], | |
| dtype=torch.float32) | |
def select_nms_index(scores: Tensor,
                     boxes: Tensor,
                     nms_index: Tensor,
                     batch_size: int,
                     keep_top_k: int = -1):
    """Gather NMS-selected detections into batched, padded outputs.

    Args:
        scores (Tensor): Class scores of shape [batch, num_classes,
            num_boxes] (NonMaxSuppression layout).
        boxes (Tensor): Boxes of shape [batch, num_boxes, 4].
        nms_index (Tensor): [num_selected, 3] rows of
            (batch_index, class_index, box_index), as produced by the
            ONNX ``NonMaxSuppression`` op.
        batch_size (int): Number of images in the batch.
        keep_top_k (int): If > 0, keep at most this many highest-scoring
            detections per image; -1 keeps everything.

    Returns:
        tuple: ``(num_dets, dets, scores, labels)`` where ``num_dets`` is
        [batch, 1], ``dets`` is [batch, K, 4], ``scores`` is [batch, K]
        and ``labels`` is [batch, K]. Padded slots carry score 0 and
        label -1.
    """
    batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1]
    box_inds = nms_index[:, 2]

    # Gather every selected (box, score) pair into one flat det list.
    scores = scores[batch_inds, cls_inds, box_inds].unsqueeze(1)
    boxes = boxes[batch_inds, box_inds, ...]
    dets = torch.cat([boxes, scores], dim=1)

    # Broadcast the flat list to every image, then zero the rows that
    # belong to a different image (keeps shapes static for ONNX export).
    batched_dets = dets.unsqueeze(0).repeat(batch_size, 1, 1)
    batch_template = torch.arange(
        0, batch_size, dtype=batch_inds.dtype, device=batch_inds.device)
    batched_dets = batched_dets.where(
        (batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1),
        batched_dets.new_zeros(1))

    batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1)
    batched_labels = batched_labels.where(
        (batch_inds == batch_template.unsqueeze(1)),
        batched_labels.new_ones(1) * -1)

    N = batched_dets.shape[0]
    # Append one all-zero det / -1 label per image so the sort below is
    # well-defined even when an image has no detections at all.
    batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))),
                             1)
    batched_labels = torch.cat((batched_labels, -batched_labels.new_ones(
        (N, 1))), 1)

    # Sort by score (last column of each det) so real detections lead.
    _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)
    if keep_top_k > 0:
        # Fix: honour keep_top_k, which was previously accepted but
        # silently ignored. Slicing after the sort keeps at most
        # keep_top_k highest-scoring rows per image.
        topk_inds = topk_inds[:, :keep_top_k]
    topk_batch_inds = torch.arange(
        batch_size, dtype=topk_inds.dtype,
        device=topk_inds.device).view(-1, 1)
    batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
    batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]
    batched_dets, batched_scores = batched_dets.split([4, 1], 2)
    batched_scores = batched_scores.squeeze(-1)
    # A slot counts as a real detection iff its score is positive.
    num_dets = (batched_scores > 0).sum(1, keepdim=True)
    return num_dets, batched_dets, batched_scores, batched_labels
class ONNXNMSop(torch.autograd.Function):
    """Exportable wrapper around the ONNX ``NonMaxSuppression`` op.

    ``forward`` fabricates dummy selection indices with a fixed shape so
    that the model can be traced; the real suppression happens inside the
    exported graph through the ``NonMaxSuppression`` node emitted by
    ``symbolic``.
    """

    # Fix: both methods must be @staticmethod. Modern PyTorch rejects
    # autograd Functions with a non-static forward at .apply() time, and
    # the exporter looks up `symbolic` as a static member of the class.
    @staticmethod
    def forward(
            ctx,
            boxes: Tensor,
            scores: Tensor,
            max_output_boxes_per_class: Tensor = torch.tensor([100]),
            iou_threshold: Tensor = torch.tensor([0.5]),
            score_threshold: Tensor = torch.tensor([0.05])
    ) -> Tensor:
        """Return fake (batch_idx, class_idx, box_idx) triples.

        The result is [20, 3] int64: random sorted batch ids in
        [0, batch), class id 0, and box ids 100..119. Only the shape and
        dtype matter for tracing; values are placeholders.
        """
        device = boxes.device
        batch = scores.shape[0]
        num_det = 20
        batches = torch.randint(0, batch, (num_det, )).sort()[0].to(device)
        idxs = torch.arange(100, 100 + num_det).to(device)
        zeros = torch.zeros((num_det, ), dtype=torch.int64).to(device)
        selected_indices = torch.cat([batches[None], zeros[None], idxs[None]],
                                     0).T.contiguous()
        selected_indices = selected_indices.to(torch.int64)
        return selected_indices

    @staticmethod
    def symbolic(
            g,
            boxes: Tensor,
            scores: Tensor,
            max_output_boxes_per_class: Tensor = torch.tensor([100]),
            iou_threshold: Tensor = torch.tensor([0.5]),
            score_threshold: Tensor = torch.tensor([0.05]),
    ):
        """Emit the ONNX ``NonMaxSuppression`` node during export."""
        return g.op(
            'NonMaxSuppression',
            boxes,
            scores,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            outputs=1)
def onnx_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    max_output_boxes_per_class: int = 100,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Run exportable NMS and pack the results into batched tensors.

    Args:
        boxes: [batch, num_boxes, 4] boxes — corner format
            (x1, y1, x2, y2) when ``box_coding == 0``, center format
            (cx, cy, w, h) when ``box_coding == 1``.
        scores: Per-box class scores; transposed on dims 1/2 here into
            the layout the ``NonMaxSuppression`` op expects.
        max_output_boxes_per_class: Per-class detection cap for NMS.
        iou_threshold: IoU threshold for suppression.
        score_threshold: Minimum score for a box to be considered.
        pre_top_k: Accepted for interface compatibility; not applied in
            this function.
        keep_top_k: Forwarded to ``select_nms_index``.
        box_coding: 0 = corner boxes, 1 = center boxes.

    Returns:
        tuple: ``(num_dets, boxes, scores, labels)`` with int32 labels.
    """
    # The ONNX op takes its thresholds as 1-element tensors.
    max_det_t = torch.tensor([max_output_boxes_per_class])
    iou_t = torch.tensor([iou_threshold])
    score_t = torch.tensor([score_threshold])

    batch_size, _, _ = scores.shape
    if box_coding == 1:
        # Convert (cx, cy, w, h) to corner format via the shared matrix.
        boxes = boxes @ (_XYWH2XYXY.to(boxes.device))
    scores = scores.transpose(1, 2).contiguous()

    selected = ONNXNMSop.apply(boxes, scores, max_det_t, iou_t, score_t)
    num_dets, out_boxes, out_scores, out_labels = select_nms_index(
        scores, boxes, selected, batch_size, keep_top_k=keep_top_k)
    return num_dets, out_boxes, out_scores, out_labels.to(torch.int32)