π [Fix] BoxMatcher for filter outsided bbox
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
|
@@ -212,19 +212,20 @@ class BoxMatcher:
|
|
| 212 |
topk_masks = topk_targets > 0
|
| 213 |
return topk_targets, topk_masks
|
| 214 |
|
| 215 |
-
def filter_duplicates(self,
|
| 216 |
"""
|
| 217 |
Filter the maximum suitability target index of each anchor.
|
| 218 |
|
| 219 |
Args:
|
| 220 |
-
|
| 221 |
|
| 222 |
Returns:
|
| 223 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
| 224 |
"""
|
| 225 |
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
| 226 |
-
max_idx = F.one_hot(
|
| 227 |
topk_mask = torch.where(duplicates, max_idx, topk_mask)
|
|
|
|
| 228 |
unique_indices = topk_mask.argmax(dim=1)
|
| 229 |
return unique_indices[..., None], topk_mask.sum(1), topk_mask
|
| 230 |
|
|
@@ -278,7 +279,7 @@ class BoxMatcher:
|
|
| 278 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
| 279 |
|
| 280 |
# delete one anchor pred assign to mutliple gts
|
| 281 |
-
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
|
| 282 |
|
| 283 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
| 284 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|
|
|
|
| 212 |
topk_masks = topk_targets > 0
|
| 213 |
return topk_targets, topk_masks
|
| 214 |
|
| 215 |
+
def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
|
| 216 |
"""
|
| 217 |
Filter the maximum suitability target index of each anchor.
|
| 218 |
|
| 219 |
Args:
|
| 220 |
+
iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
|
| 221 |
|
| 222 |
Returns:
|
| 223 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
| 224 |
"""
|
| 225 |
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
| 226 |
+
max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
|
| 227 |
topk_mask = torch.where(duplicates, max_idx, topk_mask)
|
| 228 |
+
topk_mask &= grid_mask
|
| 229 |
unique_indices = topk_mask.argmax(dim=1)
|
| 230 |
return unique_indices[..., None], topk_mask.sum(1), topk_mask
|
| 231 |
|
|
|
|
| 279 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
| 280 |
|
| 281 |
# delete one anchor pred assign to mutliple gts
|
| 282 |
+
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
|
| 283 |
|
| 284 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
| 285 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|