Spaces:
Runtime error
Runtime error
| """ Binary Cross Entropy w/ a few extras | |
| Hacked together by / Copyright 2021 Ross Wightman | |
| """ | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class BinaryCrossEntropy(nn.Module):
    """ BCE with optional one-hot from dense targets, label smoothing, thresholding.

    Thin wrapper around ``F.binary_cross_entropy_with_logits``. If ``target`` does
    not have the same shape as the logits ``x``, it is treated as integer class
    indices and expanded into a label-smoothed one-hot target over the last dim of
    ``x``. Targets that already match ``x``'s shape are used as-is.

    NOTE for experiments comparing CE to BCE w/ label smoothing, may remove

    Args:
        smoothing: label smoothing amount, must satisfy 0 <= smoothing < 1. Only
            applied when targets are class indices.
        target_threshold: if set, targets are binarized to 1 where
            ``target > target_threshold`` and 0 elsewhere (after any smoothing).
        weight: optional per-element rescaling weight passed through to
            ``F.binary_cross_entropy_with_logits``; registered as a buffer.
        reduction: reduction mode passed through to
            ``F.binary_cross_entropy_with_logits``.
        pos_weight: optional positive-example weight passed through to
            ``F.binary_cross_entropy_with_logits``; registered as a buffer.
    """

    def __init__(
            self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
            reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
        super().__init__()
        assert 0. <= smoothing < 1.0, f'smoothing must be in [0, 1), got {smoothing}'
        self.smoothing = smoothing
        self.target_threshold = target_threshold
        self.reduction = reduction
        # Buffers (not plain attributes) so they follow the module across .to()/.cuda()
        # and are included in state_dict; None is allowed and registers an empty slot.
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the binary cross entropy with logits.

        Args:
            x: logits; the last dimension is the class dimension.
            target: either a tensor with exactly ``x.shape`` (used as-is), or
                integer class indices with one index per position of
                ``x.shape[:-1]`` (e.g. ``(B,)`` or ``(B, 1)`` for ``x`` of shape
                ``(B, C)``, or ``(B, T)`` for ``x`` of shape ``(B, T, C)``).

        Returns:
            The loss, reduced according to ``self.reduction``.
        """
        assert x.shape[0] == target.shape[0], \
            f'batch size mismatch: x {tuple(x.shape)} vs target {tuple(target.shape)}'
        if target.shape != x.shape:
            # NOTE smoothing is only applied here; targets that already match x's shape
            # are assumed to have any smoothing / label softening applied upstream.
            num_classes = x.shape[-1]
            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
            off_value = self.smoothing / num_classes
            on_value = 1. - self.smoothing + off_value
            # Lay indices out as x.shape[:-1] + (1,) so the one-hot matches x for any
            # number of leading dims (the old view(-1, 1) only matched 2-D x);
            # reshape rather than view so non-contiguous index tensors are accepted.
            index = target.long().reshape(*x.shape[:-1], 1)
            target = torch.full(
                x.shape,
                off_value,
                device=x.device, dtype=x.dtype).scatter_(-1, index, on_value)
        if self.target_threshold is not None:
            # Make target 0, or 1 if threshold set
            target = target.gt(self.target_threshold).to(dtype=target.dtype)
        return F.binary_cross_entropy_with_logits(
            x, target,
            self.weight,
            pos_weight=self.pos_weight,
            reduction=self.reduction)