🩹 [Fix] BoxMatcher, change eps and filter function
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
|
@@ -14,7 +14,7 @@ from yolo.utils.logger import logger
|
|
| 14 |
|
| 15 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
| 16 |
metrics = metrics.lower()
|
| 17 |
-
EPS = 1e-
|
| 18 |
dtype = bbox1.dtype
|
| 19 |
bbox1 = bbox1.to(torch.float32)
|
| 20 |
bbox2 = bbox2.to(torch.float32)
|
|
@@ -210,7 +210,7 @@ class BoxMatcher:
|
|
| 210 |
topk_masks = topk_targets > 0
|
| 211 |
return topk_targets, topk_masks
|
| 212 |
|
| 213 |
-
def filter_duplicates(self, target_matrix: Tensor):
|
| 214 |
"""
|
| 215 |
Filter the maximum suitability target index of each anchor.
|
| 216 |
|
|
@@ -220,9 +220,11 @@ class BoxMatcher:
|
|
| 220 |
Returns:
|
| 221 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
| 222 |
"""
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
|
| 227 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
| 228 |
"""
|
|
@@ -249,16 +251,15 @@ class BoxMatcher:
|
|
| 249 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
| 250 |
|
| 251 |
# delete one anchor pred assign to mutliple gts
|
| 252 |
-
unique_indices = self.filter_duplicates(
|
| 253 |
-
|
| 254 |
-
# TODO: do we need grid_mask? Filter the valid groud truth
|
| 255 |
-
valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
|
| 256 |
|
| 257 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
| 258 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|
| 259 |
align_cls = F.one_hot(align_cls, self.class_num)
|
| 260 |
|
| 261 |
# normalize class ditribution
|
|
|
|
|
|
|
| 262 |
max_target = target_matrix.amax(dim=-1, keepdim=True)
|
| 263 |
max_iou = iou_mat.amax(dim=-1, keepdim=True)
|
| 264 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|
|
|
|
| 14 |
|
| 15 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
| 16 |
metrics = metrics.lower()
|
| 17 |
+
EPS = 1e-7
|
| 18 |
dtype = bbox1.dtype
|
| 19 |
bbox1 = bbox1.to(torch.float32)
|
| 20 |
bbox2 = bbox2.to(torch.float32)
|
|
|
|
| 210 |
topk_masks = topk_targets > 0
|
| 211 |
return topk_targets, topk_masks
|
| 212 |
|
| 213 |
+
def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
|
| 214 |
"""
|
| 215 |
Filter the maximum suitability target index of each anchor.
|
| 216 |
|
|
|
|
| 220 |
Returns:
|
| 221 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
| 222 |
"""
|
| 223 |
+
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
| 224 |
+
max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
|
| 225 |
+
topk_mask = torch.where(duplicates, max_idx, topk_mask)
|
| 226 |
+
unique_indices = topk_mask.argmax(dim=1)
|
| 227 |
+
return unique_indices[..., None], topk_mask.sum(1), topk_mask
|
| 228 |
|
| 229 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
| 230 |
"""
|
|
|
|
| 251 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
| 252 |
|
| 253 |
# delete one anchor pred assign to mutliple gts
|
| 254 |
+
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
| 257 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|
| 258 |
align_cls = F.one_hot(align_cls, self.class_num)
|
| 259 |
|
| 260 |
# normalize class ditribution
|
| 261 |
+
iou_mat *= topk_mask
|
| 262 |
+
target_matrix *= topk_mask
|
| 263 |
max_target = target_matrix.amax(dim=-1, keepdim=True)
|
| 264 |
max_iou = iou_mat.amax(dim=-1, keepdim=True)
|
| 265 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|