✨ [Add] yolov9 loss function, align to origin v9
Browse files- config/config.py +17 -1
- config/hyper/default.yaml +16 -0
- config/model/v7-base.yaml +4 -0
- tools/bbox_helper.py +251 -0
- utils/loss.py +182 -0
config/config.py
CHANGED
|
@@ -2,9 +2,15 @@ from dataclasses import dataclass
|
|
| 2 |
from typing import Dict, List, Union
|
| 3 |
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
@dataclass
|
| 6 |
class Model:
|
| 7 |
-
anchor:
|
| 8 |
model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
|
| 9 |
|
| 10 |
|
|
@@ -20,6 +26,8 @@ class DataLoaderConfig:
|
|
| 20 |
shuffle: bool
|
| 21 |
num_workers: int
|
| 22 |
pin_memory: bool
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@dataclass
|
|
@@ -52,11 +60,19 @@ class EMAConfig:
|
|
| 52 |
decay: float
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
@dataclass
|
| 56 |
class TrainConfig:
|
| 57 |
optimizer: OptimizerConfig
|
| 58 |
scheduler: SchedulerConfig
|
| 59 |
ema: EMAConfig
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
@dataclass
|
|
|
|
| 2 |
from typing import Dict, List, Union
|
| 3 |
|
| 4 |
|
| 5 |
+
@dataclass
|
| 6 |
+
class AnchorConfig:
|
| 7 |
+
reg_max: int
|
| 8 |
+
strides: List[int]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
@dataclass
|
| 12 |
class Model:
|
| 13 |
+
anchor: AnchorConfig
|
| 14 |
model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
|
| 15 |
|
| 16 |
|
|
|
|
| 26 |
shuffle: bool
|
| 27 |
num_workers: int
|
| 28 |
pin_memory: bool
|
| 29 |
+
image_size: List[int]
|
| 30 |
+
class_num: int
|
| 31 |
|
| 32 |
|
| 33 |
@dataclass
|
|
|
|
| 60 |
decay: float
|
| 61 |
|
| 62 |
|
| 63 |
+
@dataclass
|
| 64 |
+
class MatcherConfig:
|
| 65 |
+
iou: str
|
| 66 |
+
topk: int
|
| 67 |
+
factor: Dict[str, int]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
@dataclass
|
| 71 |
class TrainConfig:
|
| 72 |
optimizer: OptimizerConfig
|
| 73 |
scheduler: SchedulerConfig
|
| 74 |
ema: EMAConfig
|
| 75 |
+
matcher: MatcherConfig
|
| 76 |
|
| 77 |
|
| 78 |
@dataclass
|
config/hyper/default.yaml
CHANGED
|
@@ -3,12 +3,28 @@ data:
|
|
| 3 |
shuffle: True
|
| 4 |
num_workers: 4
|
| 5 |
pin_memory: True
|
|
|
|
|
|
|
| 6 |
train:
|
| 7 |
optimizer:
|
| 8 |
type: Adam
|
| 9 |
args:
|
| 10 |
lr: 0.001
|
| 11 |
weight_decay: 0.0001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
scheduler:
|
| 13 |
type: StepLR
|
| 14 |
args:
|
|
|
|
| 3 |
shuffle: True
|
| 4 |
num_workers: 4
|
| 5 |
pin_memory: True
|
| 6 |
+
class_num: 80
|
| 7 |
+
image_size: [640, 640]
|
| 8 |
train:
|
| 9 |
optimizer:
|
| 10 |
type: Adam
|
| 11 |
args:
|
| 12 |
lr: 0.001
|
| 13 |
weight_decay: 0.0001
|
| 14 |
+
loss:
|
| 15 |
+
BCELoss:
|
| 16 |
+
args:
|
| 17 |
+
BoxLoss:
|
| 18 |
+
args:
|
| 19 |
+
alpha: 0.1
|
| 20 |
+
DFLoss:
|
| 21 |
+
args:
|
| 22 |
+
matcher:
|
| 23 |
+
iou: CIoU
|
| 24 |
+
topk: 10
|
| 25 |
+
factor:
|
| 26 |
+
iou: 6.0
|
| 27 |
+
cls: 0.5
|
| 28 |
scheduler:
|
| 29 |
type: StepLR
|
| 30 |
args:
|
config/model/v7-base.yaml
CHANGED
|
@@ -1,5 +1,9 @@
|
|
| 1 |
nc: 80
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
model:
|
| 4 |
backbone:
|
| 5 |
- Conv:
|
|
|
|
| 1 |
nc: 80
|
| 2 |
|
| 3 |
+
anchor:
|
| 4 |
+
reg_max: 16
|
| 5 |
+
strides: [8, 16, 32]
|
| 6 |
+
|
| 7 |
model:
|
| 8 |
backbone:
|
| 9 |
- Conv:
|
tools/bbox_helper.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
from config.config import MatcherConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
| 12 |
+
metrics = metrics.lower()
|
| 13 |
+
EPS = 1e-9
|
| 14 |
+
dtype = bbox1.dtype
|
| 15 |
+
bbox1 = bbox1.to(torch.float32)
|
| 16 |
+
bbox2 = bbox2.to(torch.float32)
|
| 17 |
+
|
| 18 |
+
# Expand dimensions if necessary
|
| 19 |
+
if bbox1.ndim == 2 and bbox2.ndim == 2:
|
| 20 |
+
bbox1 = bbox1.unsqueeze(1) # (Ax4) -> (Ax1x4)
|
| 21 |
+
bbox2 = bbox2.unsqueeze(0) # (Bx4) -> (1xBx4)
|
| 22 |
+
elif bbox1.ndim == 3 and bbox2.ndim == 3:
|
| 23 |
+
bbox1 = bbox1.unsqueeze(2) # (BZxAx4) -> (BZxAx1x4)
|
| 24 |
+
bbox2 = bbox2.unsqueeze(1) # (BZxBx4) -> (BZx1xBx4)
|
| 25 |
+
|
| 26 |
+
# Calculate intersection coordinates
|
| 27 |
+
xmin_inter = torch.max(bbox1[..., 0], bbox2[..., 0])
|
| 28 |
+
ymin_inter = torch.max(bbox1[..., 1], bbox2[..., 1])
|
| 29 |
+
xmax_inter = torch.min(bbox1[..., 2], bbox2[..., 2])
|
| 30 |
+
ymax_inter = torch.min(bbox1[..., 3], bbox2[..., 3])
|
| 31 |
+
|
| 32 |
+
# Calculate intersection area
|
| 33 |
+
intersection_area = torch.clamp(xmax_inter - xmin_inter, min=0) * torch.clamp(ymax_inter - ymin_inter, min=0)
|
| 34 |
+
|
| 35 |
+
# Calculate area of each bbox
|
| 36 |
+
area_bbox1 = (bbox1[..., 2] - bbox1[..., 0]) * (bbox1[..., 3] - bbox1[..., 1])
|
| 37 |
+
area_bbox2 = (bbox2[..., 2] - bbox2[..., 0]) * (bbox2[..., 3] - bbox2[..., 1])
|
| 38 |
+
|
| 39 |
+
# Calculate union area
|
| 40 |
+
union_area = area_bbox1 + area_bbox2 - intersection_area
|
| 41 |
+
|
| 42 |
+
# Calculate IoU
|
| 43 |
+
iou = intersection_area / (union_area + EPS)
|
| 44 |
+
if metrics == "iou":
|
| 45 |
+
return iou
|
| 46 |
+
|
| 47 |
+
# Calculate centroid distance
|
| 48 |
+
cx1 = (bbox1[..., 2] + bbox1[..., 0]) / 2
|
| 49 |
+
cy1 = (bbox1[..., 3] + bbox1[..., 1]) / 2
|
| 50 |
+
cx2 = (bbox2[..., 2] + bbox2[..., 0]) / 2
|
| 51 |
+
cy2 = (bbox2[..., 3] + bbox2[..., 1]) / 2
|
| 52 |
+
cent_dis = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2
|
| 53 |
+
|
| 54 |
+
# Calculate diagonal length of the smallest enclosing box
|
| 55 |
+
c_x = torch.max(bbox1[..., 2], bbox2[..., 2]) - torch.min(bbox1[..., 0], bbox2[..., 0])
|
| 56 |
+
c_y = torch.max(bbox1[..., 3], bbox2[..., 3]) - torch.min(bbox1[..., 1], bbox2[..., 1])
|
| 57 |
+
diag_dis = c_x**2 + c_y**2 + EPS
|
| 58 |
+
|
| 59 |
+
diou = iou - (cent_dis / diag_dis)
|
| 60 |
+
if metrics == "diou":
|
| 61 |
+
return diou
|
| 62 |
+
|
| 63 |
+
# Compute aspect ratio penalty term
|
| 64 |
+
arctan = torch.atan((bbox1[..., 2] - bbox1[..., 0]) / (bbox1[..., 3] - bbox1[..., 1] + EPS)) - torch.atan(
|
| 65 |
+
(bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
|
| 66 |
+
)
|
| 67 |
+
v = (4 / (math.pi**2)) * (arctan**2)
|
| 68 |
+
alpha = v / (v - iou + 1 + EPS)
|
| 69 |
+
# Compute CIoU
|
| 70 |
+
ciou = diou - alpha * v
|
| 71 |
+
return ciou.to(dtype)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
|
| 75 |
+
data_type = bbox.dtype
|
| 76 |
+
in_type, out_type = indicator.replace(" ", "").split("->")
|
| 77 |
+
|
| 78 |
+
if in_type not in ["xyxy", "xywh", "xycwh"] or out_type not in ["xyxy", "xywh", "xycwh"]:
|
| 79 |
+
raise ValueError("Invalid input or output format")
|
| 80 |
+
|
| 81 |
+
if in_type == "xywh":
|
| 82 |
+
x_min = bbox[..., 0]
|
| 83 |
+
y_min = bbox[..., 1]
|
| 84 |
+
x_max = bbox[..., 0] + bbox[..., 2]
|
| 85 |
+
y_max = bbox[..., 1] + bbox[..., 3]
|
| 86 |
+
elif in_type == "xyxy":
|
| 87 |
+
x_min = bbox[..., 0]
|
| 88 |
+
y_min = bbox[..., 1]
|
| 89 |
+
x_max = bbox[..., 2]
|
| 90 |
+
y_max = bbox[..., 3]
|
| 91 |
+
elif in_type == "xycwh":
|
| 92 |
+
x_min = bbox[..., 0] - bbox[..., 2] / 2
|
| 93 |
+
y_min = bbox[..., 1] - bbox[..., 3] / 2
|
| 94 |
+
x_max = bbox[..., 0] + bbox[..., 2] / 2
|
| 95 |
+
y_max = bbox[..., 1] + bbox[..., 3] / 2
|
| 96 |
+
|
| 97 |
+
if out_type == "xywh":
|
| 98 |
+
bbox = torch.stack([x_min, y_min, x_max - x_min, y_max - y_min], dim=-1)
|
| 99 |
+
elif out_type == "xyxy":
|
| 100 |
+
bbox = torch.stack([x_min, y_min, x_max, y_max], dim=-1)
|
| 101 |
+
elif out_type == "xycwh":
|
| 102 |
+
bbox = torch.stack([(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min], dim=-1)
|
| 103 |
+
|
| 104 |
+
return bbox.to(dtype=data_type)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def make_anchor(image_size: List[int], strides: List[int], device):
|
| 108 |
+
W, H = image_size
|
| 109 |
+
anchors = []
|
| 110 |
+
scaler = []
|
| 111 |
+
for stride in strides:
|
| 112 |
+
anchor_num = W // stride * H // stride
|
| 113 |
+
scaler.append(torch.full((anchor_num,), stride, device=device))
|
| 114 |
+
shift = stride // 2
|
| 115 |
+
x = torch.arange(0, W, stride, device=device) + shift
|
| 116 |
+
y = torch.arange(0, H, stride, device=device) + shift
|
| 117 |
+
anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
|
| 118 |
+
anchor = torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1)
|
| 119 |
+
anchors.append(anchor)
|
| 120 |
+
all_anchors = torch.cat(anchors, dim=0)
|
| 121 |
+
all_scalers = torch.cat(scaler, dim=0)
|
| 122 |
+
return all_anchors, all_scalers
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class BoxMatcher:
|
| 126 |
+
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
|
| 127 |
+
self.class_num = class_num
|
| 128 |
+
self.anchors = anchors
|
| 129 |
+
for attr_name in cfg:
|
| 130 |
+
setattr(self, attr_name, cfg[attr_name])
|
| 131 |
+
|
| 132 |
+
def get_valid_matrix(self, target_bbox: Tensor):
|
| 133 |
+
"""
|
| 134 |
+
Get a boolean mask that indicates whether each target bounding box overlaps with each anchor.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
target_bbox [batch x targets x 4]: The bounding box of each targets.
|
| 138 |
+
Returns:
|
| 139 |
+
[batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps with anchors.
|
| 140 |
+
"""
|
| 141 |
+
Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
|
| 142 |
+
anchors = self.anchors[None, None] # add a axis at first, second dimension
|
| 143 |
+
anchors_x, anchors_y = anchors.unbind(dim=3)
|
| 144 |
+
target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
|
| 145 |
+
target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
|
| 146 |
+
target_on_anchor = target_in_x & target_in_y
|
| 147 |
+
return target_on_anchor
|
| 148 |
+
|
| 149 |
+
def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
|
| 150 |
+
"""
|
| 151 |
+
Get the (predicted class' probabilities) corresponding to the target classes across all anchors
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
predict_cls [batch x class x anchors]: The predicted probabilities for each class across each anchor.
|
| 155 |
+
target_cls [batch x targets]: The class index for each target.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
[batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
|
| 159 |
+
"""
|
| 160 |
+
target_cls = target_cls.expand(-1, -1, 8400)
|
| 161 |
+
predict_cls = predict_cls.transpose(1, 2)
|
| 162 |
+
cls_probabilities = torch.gather(predict_cls, 1, target_cls)
|
| 163 |
+
return cls_probabilities
|
| 164 |
+
|
| 165 |
+
def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
|
| 166 |
+
"""
|
| 167 |
+
Get the IoU between each target bounding box and each predicted bounding box.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
predict_bbox [batch x predicts x 4]: Bounding box with [x1, y1, x2, y2].
|
| 171 |
+
target_bbox [batch x targets x 4]: Bounding box with [x1, y1, x2, y2].
|
| 172 |
+
Returns:
|
| 173 |
+
[batch x targets x predicts]: The IoU scores between each target and predicted.
|
| 174 |
+
"""
|
| 175 |
+
return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
|
| 176 |
+
|
| 177 |
+
def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
|
| 178 |
+
"""
|
| 179 |
+
Filter the top-k suitability of targets for each anchor.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
|
| 183 |
+
topk (int, optional): Number of top scores to retain per anchor.
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
|
| 187 |
+
topk_masks [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
|
| 188 |
+
"""
|
| 189 |
+
values, indices = target_matrix.topk(topk, dim=-1)
|
| 190 |
+
topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
|
| 191 |
+
topk_targets.scatter_(dim=-1, index=indices, src=values)
|
| 192 |
+
topk_masks = topk_targets > 0
|
| 193 |
+
return topk_targets, topk_masks
|
| 194 |
+
|
| 195 |
+
def filter_duplicates(self, target_matrix: Tensor):
|
| 196 |
+
"""
|
| 197 |
+
Filter the maximum suitability target index of each anchor.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
| 204 |
+
"""
|
| 205 |
+
unique_indices = target_matrix.argmax(dim=1)
|
| 206 |
+
return unique_indices[..., None]
|
| 207 |
+
|
| 208 |
+
def __call__(self, target: Tensor, predict: Tensor) -> Tuple[Tensor, Tensor]:
|
| 209 |
+
"""
|
| 210 |
+
1. For each anchor prediction, find the highest suitability targets
|
| 211 |
+
2. Select the targets
|
| 212 |
+
2. Noramlize the class probilities of targets
|
| 213 |
+
"""
|
| 214 |
+
predict_cls, predict_bbox = predict.split(self.class_num, dim=-1) # B, HW x (C B) -> B x HW x C, B x HW x B
|
| 215 |
+
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
|
| 216 |
+
target_cls = target_cls.long()
|
| 217 |
+
|
| 218 |
+
# get valid matrix (each gt appear in which anchor grid)
|
| 219 |
+
grid_mask = self.get_valid_matrix(target_bbox)
|
| 220 |
+
|
| 221 |
+
# get iou matrix (iou with each gt bbox and each predict anchor)
|
| 222 |
+
iou_mat = self.get_iou_matrix(predict_bbox, target_bbox)
|
| 223 |
+
|
| 224 |
+
# get cls matrix (cls prob with each gt class and each predict class)
|
| 225 |
+
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
|
| 226 |
+
|
| 227 |
+
# TODO: alpha and beta should be set at hydra
|
| 228 |
+
target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
|
| 229 |
+
|
| 230 |
+
# choose topk
|
| 231 |
+
# TODO: topk should be set at hydra
|
| 232 |
+
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
| 233 |
+
|
| 234 |
+
# delete one anchor pred assign to mutliple gts
|
| 235 |
+
unique_indices = self.filter_duplicates(topk_targets)
|
| 236 |
+
|
| 237 |
+
# TODO: do we need grid_mask? Filter the valid groud truth
|
| 238 |
+
valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
|
| 239 |
+
|
| 240 |
+
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
| 241 |
+
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|
| 242 |
+
align_cls = F.one_hot(align_cls, self.class_num)
|
| 243 |
+
|
| 244 |
+
# normalize class ditribution
|
| 245 |
+
max_target = target_matrix.amax(dim=-1, keepdim=True)
|
| 246 |
+
max_iou = iou_mat.amax(dim=-1, keepdim=True)
|
| 247 |
+
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|
| 248 |
+
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
|
| 249 |
+
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
| 250 |
+
|
| 251 |
+
return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
|
utils/loss.py
CHANGED
|
@@ -1,2 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def get_loss_function(*args, **kwargs):
|
| 2 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import time
|
| 3 |
+
from typing import Any, List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from hydra import main
|
| 10 |
+
from loguru import logger
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
from torch.nn import BCEWithLogitsLoss
|
| 13 |
+
|
| 14 |
+
sys.path.append("./")
|
| 15 |
+
from config.config import Config
|
| 16 |
+
from tools.bbox_helper import BoxMatcher, calculate_iou, make_anchor, transform_bbox
|
| 17 |
+
|
| 18 |
+
|
| 19 |
def get_loss_function(*args, **kwargs):
|
| 20 |
raise NotImplementedError
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BCELoss(nn.Module):
|
| 24 |
+
def __init__(self) -> None:
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=torch.device("cuda")), reduction="none")
|
| 27 |
+
|
| 28 |
+
def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
|
| 29 |
+
return self.bce(predicts_cls, targets_cls).sum() / cls_norm
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BoxLoss(nn.Module):
|
| 33 |
+
def __init__(self) -> None:
|
| 34 |
+
super().__init__()
|
| 35 |
+
|
| 36 |
+
def forward(
|
| 37 |
+
self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
|
| 38 |
+
) -> Any:
|
| 39 |
+
valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
|
| 40 |
+
picked_predict = predicts_bbox[valid_bbox].view(-1, 4)
|
| 41 |
+
picked_targets = targets_bbox[valid_bbox].view(-1, 4)
|
| 42 |
+
|
| 43 |
+
iou = calculate_iou(picked_predict, picked_targets, "ciou").diag()
|
| 44 |
+
loss_iou = 1.0 - iou
|
| 45 |
+
loss_iou = (loss_iou * box_norm).sum() / cls_norm
|
| 46 |
+
return loss_iou
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DFLoss(nn.Module):
|
| 50 |
+
def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.anchors = anchors
|
| 53 |
+
self.scaler = scaler
|
| 54 |
+
self.reg_max = reg_max
|
| 55 |
+
|
| 56 |
+
def forward(
|
| 57 |
+
self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
|
| 58 |
+
) -> Any:
|
| 59 |
+
valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
|
| 60 |
+
bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
|
| 61 |
+
anchors_norm = (self.anchors / self.scaler[:, None])[None]
|
| 62 |
+
targets_dist = torch.cat(((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1).clamp(0, self.reg_max - 1.01)
|
| 63 |
+
picked_targets = targets_dist[valid_bbox].view(-1)
|
| 64 |
+
picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
|
| 65 |
+
|
| 66 |
+
label_left, label_right = picked_targets.floor(), picked_targets.floor() + 1
|
| 67 |
+
weight_left, weight_right = label_right - picked_targets, picked_targets - label_left
|
| 68 |
+
|
| 69 |
+
loss_left = F.cross_entropy(picked_predict, label_left.to(torch.long), reduction="none")
|
| 70 |
+
loss_right = F.cross_entropy(picked_predict, label_right.to(torch.long), reduction="none")
|
| 71 |
+
loss_dfl = loss_left * weight_left + loss_right * weight_right
|
| 72 |
+
loss_dfl = loss_dfl.view(-1, 4).mean(-1)
|
| 73 |
+
loss_dfl = (loss_dfl * box_norm).sum() / cls_norm
|
| 74 |
+
return loss_dfl
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class YOLOLoss:
|
| 78 |
+
def __init__(self, cfg: Config) -> None:
|
| 79 |
+
self.reg_max = cfg.model.anchor.reg_max
|
| 80 |
+
self.class_num = cfg.hyper.data.class_num
|
| 81 |
+
self.image_size = list(cfg.hyper.data.image_size)
|
| 82 |
+
self.strides = cfg.model.anchor.strides
|
| 83 |
+
device = torch.device("cuda")
|
| 84 |
+
|
| 85 |
+
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
|
| 86 |
+
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|
| 87 |
+
|
| 88 |
+
self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
|
| 89 |
+
|
| 90 |
+
self.cls = BCELoss()
|
| 91 |
+
self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
|
| 92 |
+
self.iou = BoxLoss()
|
| 93 |
+
|
| 94 |
+
self.matcher = BoxMatcher(cfg.hyper.train.matcher, self.class_num, self.anchors)
|
| 95 |
+
|
| 96 |
+
def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
|
| 97 |
+
"""
|
| 98 |
+
args:
|
| 99 |
+
[B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
|
| 100 |
+
return:
|
| 101 |
+
[B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
|
| 102 |
+
"""
|
| 103 |
+
preds = []
|
| 104 |
+
for pred in predicts:
|
| 105 |
+
preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
|
| 106 |
+
preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
|
| 107 |
+
|
| 108 |
+
preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
|
| 109 |
+
preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
|
| 110 |
+
|
| 111 |
+
pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
|
| 112 |
+
|
| 113 |
+
lt, rb = pred_LTRB.chunk(2, dim=-1)
|
| 114 |
+
pred_minXY = self.anchors - lt
|
| 115 |
+
pred_maxXY = self.anchors + rb
|
| 116 |
+
predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
|
| 117 |
+
|
| 118 |
+
return predicts, preds_anc
|
| 119 |
+
|
| 120 |
+
def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
|
| 121 |
+
"""
|
| 122 |
+
return List:
|
| 123 |
+
"""
|
| 124 |
+
targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
|
| 125 |
+
bbox_num = targets[:, 0].int().bincount()
|
| 126 |
+
batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
|
| 127 |
+
for instance_idx, bbox_num in enumerate(bbox_num):
|
| 128 |
+
instance_targets = targets[targets[:, 0] == instance_idx]
|
| 129 |
+
batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
|
| 130 |
+
return batch_targets
|
| 131 |
+
|
| 132 |
+
def separate_anchor(self, anchors):
|
| 133 |
+
"""
|
| 134 |
+
separate anchor and bbouding box
|
| 135 |
+
"""
|
| 136 |
+
anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
|
| 137 |
+
anchors_box = anchors_box / self.scaler[None, :, None]
|
| 138 |
+
return anchors_cls, anchors_box
|
| 139 |
+
|
| 140 |
+
@torch.autocast("cuda")
|
| 141 |
+
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tensor:
|
| 142 |
+
# Batch_Size x (Anchor + Class) x H x W
|
| 143 |
+
tlist = [time.time()]
|
| 144 |
+
# TODO: check datatype, why targets has a little bit error with origin version
|
| 145 |
+
predicts, predicts_anc = self.parse_predicts(predicts[0])
|
| 146 |
+
targets = self.parse_targets(targets)
|
| 147 |
+
|
| 148 |
+
align_targets, valid_masks = self.matcher(targets, predicts)
|
| 149 |
+
# calculate loss between with instance and predict
|
| 150 |
+
|
| 151 |
+
targets_cls, targets_bbox = self.separate_anchor(align_targets)
|
| 152 |
+
predicts_cls, predicts_bbox = self.separate_anchor(predicts)
|
| 153 |
+
|
| 154 |
+
cls_norm = targets_cls.sum()
|
| 155 |
+
box_norm = targets_cls.sum(-1)[valid_masks]
|
| 156 |
+
|
| 157 |
+
## -- CLS -- ##
|
| 158 |
+
loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
|
| 159 |
+
## -- IOU -- ##
|
| 160 |
+
loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
|
| 161 |
+
## -- DFL -- ##
|
| 162 |
+
loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
|
| 163 |
+
|
| 164 |
+
logger.info("Loss IoU: {:.5f}, DFL: {:.5f}, CLS: {:.5f}", loss_iou, loss_dfl, loss_cls)
|
| 165 |
+
tlist.append(time.time())
|
| 166 |
+
logger.info(f"Calculate Loss Run Time {np.diff(np.array(tlist)) * 1e3} ms")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@main(config_path="../config", config_name="config", version_base=None)
|
| 170 |
+
def main(cfg):
|
| 171 |
+
losser = YOLOLoss(cfg)
|
| 172 |
+
targets = torch.load("targets.pt")
|
| 173 |
+
predicts = torch.load("predicts.pt")
|
| 174 |
+
losser(predicts, targets)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
import sys
|
| 179 |
+
|
| 180 |
+
sys.path.append("./")
|
| 181 |
+
from tools.log_helper import custom_logger
|
| 182 |
+
|
| 183 |
+
custom_logger()
|
| 184 |
+
main()
|