Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- encoding: utf-8 -*- | |
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class IOUloss(nn.Module):
    """IoU-based box-regression loss for boxes given as (cx, cy, w, h).

    Args:
        reduction: ``"mean"`` or ``"sum"`` reduce the per-box loss to a scalar;
            any other value (default ``"none"``) returns the per-box loss.
        loss_type: ``"iou"`` -> ``1 - IoU**2``;
            ``"giou"`` -> ``1 - GIoU`` with GIoU clamped to ``[-1, 1]``.
    """

    def __init__(self, reduction="none", loss_type="iou"):
        super().__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """Compute the loss between predicted and target boxes.

        Args:
            pred: tensor whose elements can be viewed as ``(-1, 4)`` rows of
                (cx, cy, w, h); leading dimension must match ``target``.
            target: tensor with the same layout as ``pred``.

        Returns:
            Per-box loss of shape ``(N,)`` (N = number of 4-element rows), or a
            scalar when ``reduction`` is ``"mean"`` / ``"sum"``.

        Raises:
            ValueError: if ``loss_type`` is not ``"iou"`` or ``"giou"``.
        """
        assert pred.shape[0] == target.shape[0]
        if self.loss_type not in ("iou", "giou"):
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(
                f"unsupported loss_type {self.loss_type!r}; expected 'iou' or 'giou'"
            )

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # (cx, cy, w, h) -> top-left / bottom-right corners, computed once.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # Mask out pairs that do not overlap on both axes (the raw product of
        # two negative extents would otherwise be positive).
        en = (tl < br).to(tl.dtype).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = area_i / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        else:  # "giou"
            # Smallest enclosing box of the pair.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)
            area_c = torch.prod(c_br - c_tl, 1)
            # GIoU = IoU - |C \ (A u B)| / |C|: the penalty uses the UNION,
            # not the intersection.
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape holding raw logits
            (a sigmoid is applied internally).
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_boxes: Normalizer; the summed loss is divided by this value.
            The caller must ensure it is non-zero.
        alpha: (optional) Weighting factor in range (0, 1) to balance
            positive vs negative examples. Default = 0.25; pass a negative
            value (e.g. -1) to disable the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples. Default = 2.

    Returns:
        Scalar loss tensor: the element-wise focal loss summed over all
        elements and divided by ``num_boxes``.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the model's probability for the ground-truth class.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    # Down-weight well-classified elements (p_t close to 1).
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.sum() / num_boxes