[Move] parse_predicts to Anchor2Box class
Browse files- yolo/tools/bbox_helper.py +42 -1
- yolo/utils/loss.py +4 -44
yolo/tools/bbox_helper.py
CHANGED
|
@@ -3,9 +3,10 @@ from typing import List, Tuple
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn.functional as F
|
|
|
|
| 6 |
from torch import Tensor
|
| 7 |
|
| 8 |
-
from yolo.config.config import MatcherConfig
|
| 9 |
|
| 10 |
|
| 11 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
|
@@ -122,6 +123,46 @@ def make_anchor(image_size: List[int], strides: List[int], device):
|
|
| 122 |
return all_anchors, all_scalers
|
| 123 |
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
class BoxMatcher:
|
| 126 |
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
|
| 127 |
self.class_num = class_num
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
from torch import Tensor
|
| 8 |
|
| 9 |
+
from yolo.config.config import Config, MatcherConfig
|
| 10 |
|
| 11 |
|
| 12 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
|
|
|
| 123 |
return all_anchors, all_scalers
|
| 124 |
|
| 125 |
|
| 126 |
+
class Anchor2Box:
    """Decode raw multi-scale head outputs into class scores and corner boxes.

    The per-anchor grid (``anchors``) and per-anchor scale (``scaler``) are built
    once in ``__init__`` via ``make_anchor`` and reused on every call.
    """

    def __init__(self, cfg: Config, device: torch.device) -> None:
        """Read decoding parameters from ``cfg`` and pre-compute constant tensors on ``device``.

        args:
            cfg: project config; reads ``model.anchor.reg_max``, ``model.anchor.strides``,
                ``hyper.data.class_num`` and ``hyper.data.image_size``.
            device: device on which the constant tensors are created.
        """
        self.reg_max = cfg.model.anchor.reg_max
        self.class_num = cfg.hyper.data.class_num
        self.image_size = list(cfg.hyper.data.image_size)
        self.strides = cfg.model.anchor.strides

        # image_size repeated twice (4 values); presumably used to rescale
        # normalised xyxy boxes to pixels -- TODO confirm against callers.
        self.scale_up = torch.tensor(self.image_size * 2, device=device)
        self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
        # Bin indices 0 .. reg_max-1; the dot product with a softmax over the bins
        # gives the expected bin index for each box side.
        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)

    def __call__(self, predicts: List[Tensor], with_logits=False) -> Tuple[Tensor, Tensor]:
        """
        args:
            predicts:
                [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
            with_logits:
                when True, a sigmoid is applied to the class channels before they are returned;
                when False (default), the class channels are returned unchanged.
        return:
            predicts:
                [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
            preds_anc:
                [B x HW x 4 x reg_max] // raw per-side distribution values, before softmax
        """
        # Flatten every scale to B x hw x AC, then stack all scales along the anchor axis.
        preds = [rearrange(pred, "B AC h w -> B (h w) AC") for pred in predicts]
        preds = torch.cat(preds, dim=1)  # -> B x (H W) x AC

        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
        preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
        if with_logits:
            preds_cls = preds_cls.sigmoid()

        # Expected bin index per side, multiplied by the per-anchor scaler.
        # `@` and `*` share precedence and associate left-to-right, so this is (softmax @ bins) * scaler.
        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)

        # First half of the four sides is subtracted from the anchor point (min corner),
        # second half is added to it (max corner).
        lt, rb = pred_LTRB.chunk(2, dim=-1)
        pred_minXY = self.anchors - lt
        pred_maxXY = self.anchors + rb
        preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)

        predicts = torch.cat([preds_cls, preds_box], dim=-1)

        return predicts, preds_anc
|
| 164 |
+
|
| 165 |
+
|
| 166 |
class BoxMatcher:
|
| 167 |
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
|
| 168 |
self.class_num = class_num
|
yolo/utils/loss.py
CHANGED
|
@@ -8,12 +8,7 @@ from torch import Tensor, nn
|
|
| 8 |
from torch.nn import BCEWithLogitsLoss
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
-
from yolo.tools.bbox_helper import
|
| 12 |
-
BoxMatcher,
|
| 13 |
-
calculate_iou,
|
| 14 |
-
make_anchor,
|
| 15 |
-
transform_bbox,
|
| 16 |
-
)
|
| 17 |
from yolo.tools.module_helper import make_chunk
|
| 18 |
|
| 19 |
|
|
@@ -90,42 +85,7 @@ class YOLOLoss:
|
|
| 90 |
self.iou = BoxLoss()
|
| 91 |
|
| 92 |
self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
|
| 93 |
-
|
| 94 |
-
def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
|
| 95 |
-
"""
|
| 96 |
-
args:
|
| 97 |
-
[B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
|
| 98 |
-
return:
|
| 99 |
-
[B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
|
| 100 |
-
"""
|
| 101 |
-
preds = []
|
| 102 |
-
for pred in predicts:
|
| 103 |
-
preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
|
| 104 |
-
preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
|
| 105 |
-
|
| 106 |
-
preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
|
| 107 |
-
preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
|
| 108 |
-
|
| 109 |
-
pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
|
| 110 |
-
|
| 111 |
-
lt, rb = pred_LTRB.chunk(2, dim=-1)
|
| 112 |
-
pred_minXY = self.anchors - lt
|
| 113 |
-
pred_maxXY = self.anchors + rb
|
| 114 |
-
predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
|
| 115 |
-
|
| 116 |
-
return predicts, preds_anc
|
| 117 |
-
|
| 118 |
-
def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
|
| 119 |
-
"""
|
| 120 |
-
return List:
|
| 121 |
-
"""
|
| 122 |
-
targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
|
| 123 |
-
bbox_num = targets[:, 0].int().bincount()
|
| 124 |
-
batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
|
| 125 |
-
for instance_idx, bbox_num in enumerate(bbox_num):
|
| 126 |
-
instance_targets = targets[targets[:, 0] == instance_idx]
|
| 127 |
-
batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
|
| 128 |
-
return batch_targets
|
| 129 |
|
| 130 |
def separate_anchor(self, anchors):
|
| 131 |
"""
|
|
@@ -138,10 +98,10 @@ class YOLOLoss:
|
|
| 138 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
| 139 |
# Batch_Size x (Anchor + Class) x H x W
|
| 140 |
# TODO: check datatype, why targets has a little bit error with origin version
|
| 141 |
-
predicts, predicts_anc = self.
|
| 142 |
|
|
|
|
| 143 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
| 144 |
-
# calculate loss between with instance and predict
|
| 145 |
|
| 146 |
targets_cls, targets_bbox = self.separate_anchor(align_targets)
|
| 147 |
predicts_cls, predicts_bbox = self.separate_anchor(predicts)
|
|
|
|
| 8 |
from torch.nn import BCEWithLogitsLoss
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
+
from yolo.tools.bbox_helper import Anchor2Box, BoxMatcher, calculate_iou, make_anchor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from yolo.tools.module_helper import make_chunk
|
| 13 |
|
| 14 |
|
|
|
|
| 85 |
self.iou = BoxLoss()
|
| 86 |
|
| 87 |
self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
|
| 88 |
+
self.box_converter = Anchor2Box(cfg, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
def separate_anchor(self, anchors):
|
| 91 |
"""
|
|
|
|
| 98 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
| 99 |
# Batch_Size x (Anchor + Class) x H x W
|
| 100 |
# TODO: check datatype, why targets has a little bit error with origin version
|
| 101 |
+
predicts, predicts_anc = self.box_converter(predicts)
|
| 102 |
|
| 103 |
+
# For each predicted target, assign the best-matching ground-truth box.
|
| 104 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
|
|
|
| 105 |
|
| 106 |
targets_cls, targets_bbox = self.separate_anchor(align_targets)
|
| 107 |
predicts_cls, predicts_bbox = self.separate_anchor(predicts)
|