Spaces:
Sleeping
Sleeping
| from typing import Any, Dict, List, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from loguru import logger | |
| from torch import Tensor, nn | |
| from torch.nn import BCEWithLogitsLoss | |
| from yolo.config.config import Config | |
| from yolo.utils.bounding_box_utils import ( | |
| AnchorBoxConverter, | |
| BoxMatcher, | |
| calculate_iou, | |
| generate_anchors, | |
| ) | |
| from yolo.utils.module_utils import divide_into_chunks | |
| class BCELoss(nn.Module): | |
| def __init__(self) -> None: | |
| super().__init__() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=device), reduction="none") | |
| def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any: | |
| return self.bce(predicts_cls, targets_cls).sum() / cls_norm | |
| class BoxLoss(nn.Module): | |
| def __init__(self) -> None: | |
| super().__init__() | |
| def forward( | |
| self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor | |
| ) -> Any: | |
| valid_bbox = valid_masks[..., None].expand(-1, -1, 4) | |
| picked_predict = predicts_bbox[valid_bbox].view(-1, 4) | |
| picked_targets = targets_bbox[valid_bbox].view(-1, 4) | |
| iou = calculate_iou(picked_predict, picked_targets, "ciou").diag() | |
| loss_iou = 1.0 - iou | |
| loss_iou = (loss_iou * box_norm).sum() / cls_norm | |
| return loss_iou | |
| class DFLoss(nn.Module): | |
| def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None: | |
| super().__init__() | |
| self.anchors = anchors | |
| self.scaler = scaler | |
| self.reg_max = reg_max | |
| def forward( | |
| self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor | |
| ) -> Any: | |
| valid_bbox = valid_masks[..., None].expand(-1, -1, 4) | |
| bbox_lt, bbox_rb = targets_bbox.chunk(2, -1) | |
| anchors_norm = (self.anchors / self.scaler[:, None])[None] | |
| targets_dist = torch.cat(((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1).clamp(0, self.reg_max - 1.01) | |
| picked_targets = targets_dist[valid_bbox].view(-1) | |
| picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max) | |
| label_left, label_right = picked_targets.floor(), picked_targets.floor() + 1 | |
| weight_left, weight_right = label_right - picked_targets, picked_targets - label_left | |
| loss_left = F.cross_entropy(picked_predict, label_left.to(torch.long), reduction="none") | |
| loss_right = F.cross_entropy(picked_predict, label_right.to(torch.long), reduction="none") | |
| loss_dfl = loss_left * weight_left + loss_right * weight_right | |
| loss_dfl = loss_dfl.view(-1, 4).mean(-1) | |
| loss_dfl = (loss_dfl * box_norm).sum() / cls_norm | |
| return loss_dfl | |
| class YOLOLoss: | |
| def __init__(self, cfg: Config) -> None: | |
| self.reg_max = cfg.model.anchor.reg_max | |
| self.class_num = cfg.class_num | |
| self.image_size = list(cfg.image_size) | |
| self.strides = cfg.model.anchor.strides | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device) | |
| self.scale_up = torch.tensor(self.image_size * 2, device=device) | |
| self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device) | |
| self.cls = BCELoss() | |
| self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max) | |
| self.iou = BoxLoss() | |
| self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors) | |
| self.box_converter = AnchorBoxConverter(cfg, device) | |
| def separate_anchor(self, anchors): | |
| """ | |
| separate anchor and bbouding box | |
| """ | |
| anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1) | |
| anchors_box = anchors_box / self.scaler[None, :, None] | |
| return anchors_cls, anchors_box | |
| def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]: | |
| # Batch_Size x (Anchor + Class) x H x W | |
| # TODO: check datatype, why targets has a little bit error with origin version | |
| predicts, predicts_anc = self.box_converter(predicts) | |
| # For each predicted targets, assign a best suitable ground truth box. | |
| align_targets, valid_masks = self.matcher(targets, predicts) | |
| targets_cls, targets_bbox = self.separate_anchor(align_targets) | |
| predicts_cls, predicts_bbox = self.separate_anchor(predicts) | |
| cls_norm = targets_cls.sum() | |
| box_norm = targets_cls.sum(-1)[valid_masks] | |
| ## -- CLS -- ## | |
| loss_cls = self.cls(predicts_cls, targets_cls, cls_norm) | |
| ## -- IOU -- ## | |
| loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm) | |
| ## -- DFL -- ## | |
| loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm) | |
| return loss_iou, loss_dfl, loss_cls | |
| class DualLoss: | |
| def __init__(self, cfg: Config) -> None: | |
| self.loss = YOLOLoss(cfg) | |
| self.aux_rate = cfg.task.loss.aux | |
| self.iou_rate = cfg.task.loss.objective["BoxLoss"] | |
| self.dfl_rate = cfg.task.loss.objective["DFLoss"] | |
| self.cls_rate = cfg.task.loss.objective["BCELoss"] | |
| def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: | |
| targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up | |
| # TODO: Need Refactor this region, make it flexible! | |
| predicts = divide_into_chunks(predicts[0], 2) | |
| aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets) | |
| main_iou, main_dfl, main_cls = self.loss(predicts[1], targets) | |
| loss_dict = { | |
| "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou), | |
| "DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl), | |
| "BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls), | |
| } | |
| loss_sum = sum(list(loss_dict.values())) / len(loss_dict) | |
| return loss_sum, loss_dict | |
| def get_loss_function(cfg: Config) -> YOLOLoss: | |
| loss_function = DualLoss(cfg) | |
| logger.info("✅ Success load loss function") | |
| return loss_function | |