🔨 fix: BoxMatcher.__call__ now returns all zero anchor matched targets and all False valid mask, if input target has zero annotations in it. (#88)
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
|
@@ -222,12 +222,37 @@ class BoxMatcher:
|
|
| 222 |
return unique_indices[..., None]
|
| 223 |
|
| 224 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
| 225 |
-
"""
|
| 226 |
-
1. For each anchor prediction, find the highest suitability targets
|
| 227 |
-
2.
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
"""
|
| 230 |
predict_cls, predict_bbox = predict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
|
| 232 |
target_cls = target_cls.long().clamp(0)
|
| 233 |
|
|
@@ -261,8 +286,8 @@ class BoxMatcher:
|
|
| 261 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|
| 262 |
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
|
| 263 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
| 264 |
-
|
| 265 |
-
return
|
| 266 |
|
| 267 |
|
| 268 |
class Vec2Box:
|
|
|
|
| 222 |
return unique_indices[..., None]
|
| 223 |
|
| 224 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
| 225 |
+
"""Matches each target to the most suitable anchor.
|
| 226 |
+
1. For each anchor prediction, find the highest suitability targets.
|
| 227 |
+
2. Match target to the best anchor.
|
| 228 |
+
3. Noramlize the class probilities of targets.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
target: The ground truth class and bounding box information
|
| 232 |
+
as tensor of size [batch x targets x 5].
|
| 233 |
+
predict: Tuple of predicted class and bounding box tensors.
|
| 234 |
+
Class tensor is of size [batch x anchors x class]
|
| 235 |
+
Bounding box tensor is of size [batch x anchors x 4].
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
anchor_matched_targets: Tensor of size [batch x anchors x (class + 4)].
|
| 239 |
+
A tensor assigning each target/gt to the best fitting anchor.
|
| 240 |
+
The class probabilities are normalized.
|
| 241 |
+
valid_mask: Bool tensor of shape [batch x anchors].
|
| 242 |
+
True if a anchor has a target/gt assigned to it.
|
| 243 |
"""
|
| 244 |
predict_cls, predict_bbox = predict
|
| 245 |
+
|
| 246 |
+
# return if target has no gt information.
|
| 247 |
+
n_targets = target.shape[1]
|
| 248 |
+
if n_targets == 0:
|
| 249 |
+
device = predict_bbox.device
|
| 250 |
+
align_cls = torch.zeros_like(predict_cls, device=device)
|
| 251 |
+
align_bbox = torch.zeros_like(predict_bbox, device=device)
|
| 252 |
+
valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
|
| 253 |
+
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
|
| 254 |
+
return anchor_matched_targets, valid_mask
|
| 255 |
+
|
| 256 |
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
|
| 257 |
target_cls = target_cls.long().clamp(0)
|
| 258 |
|
|
|
|
| 286 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|
| 287 |
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
|
| 288 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
| 289 |
+
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
|
| 290 |
+
return anchor_matched_targets, valid_mask
|
| 291 |
|
| 292 |
|
| 293 |
class Vec2Box:
|