# Adapted from https://github.com/magic-research/Sa2VA.
# Below is the original copyright:
# coding=utf-8
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
import warnings
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch import Tensor

from rynnec.constants import IGNORE_INDEX

# Optional fused linear-layer + cross-entropy kernel. Judging by the call
# signature used below, this is presumably meant to come from an optimized
# implementation such as the cut-cross-entropy package; it is left as None
# here so the plain logits-then-cross-entropy fallback is used.
linear_cross_entropy = None


def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Returns:
        Tensor: Reduced loss tensor.
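
    Example (illustrative sketch with a toy loss tensor):
        >>> loss = torch.tensor([1.0, 2.0, 3.0])
        >>> assert reduce_loss(loss, 'sum').item() == 6.0
        >>> assert reduce_loss(loss, 'mean').item() == 2.0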
| """ | |
| reduction_enum = F._Reduction.get_enum(reduction) | |
| # none: 0, elementwise_mean:1, sum: 2 | |
| if reduction_enum == 0: | |
| return loss | |
| elif reduction_enum == 1: | |
| return loss.mean() | |
| elif reduction_enum == 2: | |
| return loss.sum() | |


def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (Optional[float], optional): Average factor when
            computing the mean of losses. Defaults to None.

    Returns:
        Tensor: Processed loss values.
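
    Example (illustrative sketch; only the first two elements carry weight):
        >>> loss = torch.ones(4)
        >>> weight = torch.tensor([1., 1., 0., 0.])
        >>> out = weight_reduce_loss(loss, weight, reduction='mean', avg_factor=2)
        >>> assert torch.allclose(out, torch.tensor(1.), atol=1e-4)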
| """ | |
| # if weight is specified, apply element-wise weight | |
| if weight is not None: | |
| loss = loss * weight | |
| # if avg_factor is not specified, just reduce the loss | |
| if avg_factor is None: | |
| loss = reduce_loss(loss, reduction) | |
| else: | |
| # if reduction is mean, then average the loss by avg_factor | |
| if reduction == 'mean': | |
| # Avoid causing ZeroDivisionError when avg_factor is 0.0, | |
| # i.e., all labels of an image belong to ignore index. | |
| eps = torch.finfo(torch.float32).eps | |
| loss = loss.sum() / (avg_factor + eps) | |
| # if reduction is 'none', then do nothing, otherwise raise an error | |
| elif reduction != 'none': | |
| raise ValueError('avg_factor can not be used with reduction="sum"') | |
| return loss | |


def dice_loss(pred,
              target,
              weight=None,
              eps=1e-3,
              reduction='mean',
              naive_dice=False,
              avg_factor=None):
    """Calculate dice loss. Two forms of dice loss are supported:

    - the one proposed in `V-Net: Fully Convolutional Neural
      Networks for Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_.
    - the dice loss in which the power of the number in the
      denominator is the first power instead of the second
      power.

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape as pred.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        naive_dice (bool, optional): If False, use the dice
            loss defined in the V-Net paper; otherwise, use the
            naive dice loss in which the power of the number in the
            denominator is the first power instead of the second
            power. Defaults to False.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
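
    Example (illustrative sketch; ``pred`` assumed already activated):
        >>> pred = torch.rand(2, 16)
        >>> target = torch.randint(0, 2, (2, 16))
        >>> loss = dice_loss(pred, target, naive_dice=True)
        >>> assert loss.ndim == 0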
| """ | |
| input = pred.flatten(1) | |
| target = target.flatten(1).float() | |
| a = torch.sum(input * target, 1) | |
| if naive_dice: | |
| b = torch.sum(input, 1) | |
| c = torch.sum(target, 1) | |
| d = (2 * a + eps) / (b + c + eps) | |
| else: | |
| b = torch.sum(input * input, 1) + eps | |
| c = torch.sum(target * target, 1) + eps | |
| d = (2 * a) / (b + c) | |
| loss = 1 - d | |
| if weight is not None: | |
| assert weight.ndim == loss.ndim | |
| assert len(weight) == len(pred) | |
| loss = weight_reduce_loss(loss, weight, reduction, avg_factor) | |
| return loss | |


class DiceLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 eps=1e-3):
        """Compute dice loss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                activated with sigmoid instead of softmax. Defaults to True.
            activate (bool): Whether to apply the activation inside this
                loss. If False, ``pred`` is assumed to be already activated.
                Defaults to True.
            reduction (str, optional): The method used
                to reduce the loss. Options are "none",
                "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If False, use the dice
                loss defined in the V-Net paper; otherwise, use the
                naive dice loss in which the power of the number in the
                denominator is the first power instead of the second
                power. Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
        """
        super(DiceLoss, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps
        self.activate = activate

    def forward(self,
                pred,
                target,
                weight=None,
                reduction_override=None,
                avg_factor=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction,
                shape (n, *), same shape as pred.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
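
        Example (illustrative sketch on toy binary masks):
            >>> criterion = DiceLoss(use_sigmoid=True, naive_dice=True)
            >>> mask_logits = torch.randn(3, 32, 32)
            >>> gt_masks = torch.randint(0, 2, (3, 32, 32))
            >>> loss = criterion(mask_logits, gt_masks)
            >>> assert loss.ndim == 0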
| """ | |
| assert reduction_override in (None, 'none', 'mean', 'sum') | |
| reduction = ( | |
| reduction_override if reduction_override else self.reduction) | |
| if self.activate: | |
| if self.use_sigmoid: | |
| pred = pred.sigmoid() | |
| else: | |
| raise NotImplementedError | |
| loss = self.loss_weight * dice_loss( | |
| pred, | |
| target, | |
| weight, | |
| eps=self.eps, | |
| reduction=reduction, | |
| naive_dice=self.naive_dice, | |
| avg_factor=avg_factor) | |
| return loss | |


def cross_entropy_loss(
    hidden_states,
    lm_head,
    position_ids,
    labels,
    reduction_scope="sequence",
    **loss_kwargs
):
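    """Next-token cross-entropy for language modeling over non-ignored labels.

    A summary of the reduction logic below, added here for clarity: when the
    trainer does not supply ``num_items_in_batch``, a plain mean over valid
    tokens is used. With ``reduction_scope="batch"``, token losses are summed
    and divided by ``num_items_in_batch``. With ``reduction_scope="sequence"``,
    each token's loss is divided by the number of valid tokens in its own
    sequence times ``num_items_in_batch`` before summing; packed sequences
    (batch size 1) are split at the positions where ``position_ids`` reset
    to 0.
    """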
    batch_size = hidden_states.size(0)

    # Shift so that hidden state t predicts token t + 1, then drop positions
    # whose label is IGNORE_INDEX (e.g. prompt and padding tokens).
    shift_hidden_states = hidden_states[..., :-1, :]
    shift_labels = labels[..., 1:]
    mask = shift_labels != IGNORE_INDEX
    shift_hidden_states = shift_hidden_states[mask].contiguous()
    shift_labels = shift_labels[mask].contiguous()

    if mask.sum() == 0:
        print(f"Get labels={labels}. Found no sample to calculate loss!")
        # Keep the graph connected so backward still works: a zero loss that
        # depends on the lm_head output.
        pseudo_logits = lm_head(hidden_states[:, 0:1])
        loss = 0.0 * pseudo_logits.mean()
        return loss

    if "num_items_in_batch" not in loss_kwargs:
        reduction = "mean"
        denominator = None
    elif reduction_scope == "batch":
        reduction = "sum"
        denominator = loss_kwargs["num_items_in_batch"]
    elif reduction_scope == "sequence":
        reduction = "none"
        if batch_size == 1:
            # NOTE: packed sequence; recover per-sequence boundaries from the
            # positions where position_ids restart at 0.
            start_indices = torch.nonzero(position_ids[0] == 0)[:, 0]
            end_indices = F.pad(start_indices[1:], (0, 1), value=position_ids.size(1))
            batch_indices = torch.cat(
                [
                    torch.full((e - s,), fill_value=i, device=position_ids.device, dtype=torch.long)
                    for i, (s, e) in enumerate(zip(start_indices, end_indices))
                ],
            ).unsqueeze(0)
        else:
            batch_indices = torch.arange(batch_size, device=position_ids.device)
            batch_indices = batch_indices.unsqueeze(1).expand(-1, hidden_states.size(1))
        shift_batch_indices = batch_indices[..., :-1]
        shift_batch_indices = shift_batch_indices[mask].contiguous()
        # Number of valid tokens per sequence, gathered back per token.
        num_tokens = F.one_hot(shift_batch_indices).sum(dim=0)
        denominator = num_tokens[shift_batch_indices] * loss_kwargs["num_items_in_batch"]
    else:
        raise ValueError(f"Unknown reduction scope: {reduction_scope}")

    if linear_cross_entropy is None:
        shift_logits = lm_head(shift_hidden_states)
        loss = torch.nn.functional.cross_entropy(
            shift_logits,
            shift_labels,
            reduction=reduction,
        )
    else:
        # Fused kernel path: computes the loss directly from hidden states
        # and the lm_head weights without materializing the full logits.
        loss = linear_cross_entropy(
            shift_hidden_states,
            lm_head.weight,
            shift_labels,
            bias=lm_head.bias,
            reduction=reduction,
            accum_e_fp32=True,
            accum_c_fp32=True,
        )

    if denominator is not None:
        loss = loss / denominator
    if loss.ndim > 0:
        loss = loss.sum()
    return loss
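
# Minimal usage sketch for cross_entropy_loss (illustrative only; the shapes
# and the toy lm_head below are assumptions, not code from this repo):
#
#   hidden = torch.randn(2, 8, 64)                      # (batch, seq, hidden)
#   head = nn.Linear(64, 1000)                          # toy vocabulary head
#   pos = torch.arange(8).unsqueeze(0).expand(2, -1)    # per-sample positions
#   labels = torch.randint(0, 1000, (2, 8))
#   labels[:, :4] = IGNORE_INDEX                        # mask out prompt tokens
#   loss = cross_entropy_loss(hidden, head, pos, labels)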


def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to the default value. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss
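
    Example (illustrative sketch; -100 entries are ignored):
        >>> pred = torch.randn(4, 10)
        >>> label = torch.tensor([0, 3, -100, 9])
        >>> loss = cross_entropy(pred, label, avg_non_ignore=True)
        >>> assert loss.ndim == 0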
| """ | |
| # The default value of ignore_index is the same as F.cross_entropy | |
| ignore_index = -100 if ignore_index is None else ignore_index | |
| # element-wise losses | |
| loss = F.cross_entropy( | |
| pred, | |
| label, | |
| weight=class_weight, | |
| reduction='none', | |
| ignore_index=ignore_index) | |
| # average loss over non-ignored elements | |
| # pytorch's official cross_entropy average loss over non-ignored elements | |
| # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa | |
| if (avg_factor is None) and avg_non_ignore and reduction == 'mean': | |
| avg_factor = label.numel() - (label == ignore_index).sum().item() | |
| # apply weights and do the reduction | |
| if weight is not None: | |
| weight = weight.float() | |
| loss = weight_reduce_loss( | |
| loss, weight=weight, reduction=reduction, avg_factor=avg_factor) | |
| return loss | |


def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(
        valid_mask & (labels < label_channels), as_tuple=False)

    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1

    valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
                                               label_channels).float()
    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
        bin_label_weights *= valid_mask

    return bin_labels, bin_label_weights, valid_mask


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
            When the shape of pred is (N, 1), label will be expanded to
            one-hot format, and when the shape of pred is (N, ), label
            will not be expanded to one-hot format.
        label (torch.Tensor): The learning label of the prediction,
            with shape (N, ).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to the default value. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
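
    Example (illustrative sketch with per-sample logits):
        >>> pred = torch.randn(6)
        >>> label = torch.randint(0, 2, (6,))
        >>> loss = binary_cross_entropy(pred, label)
        >>> assert loss.ndim == 0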
| """ | |
| # The default value of ignore_index is the same as F.cross_entropy | |
| ignore_index = -100 if ignore_index is None else ignore_index | |
| if pred.dim() != label.dim(): | |
| label, weight, valid_mask = _expand_onehot_labels( | |
| label, weight, pred.size(-1), ignore_index) | |
| else: | |
| # should mask out the ignored elements | |
| valid_mask = ((label >= 0) & (label != ignore_index)).float() | |
| if weight is not None: | |
| # The inplace writing method will have a mismatched broadcast | |
| # shape error if the weight and valid_mask dimensions | |
| # are inconsistent such as (B,N,1) and (B,N,C). | |
| weight = weight * valid_mask | |
| else: | |
| weight = valid_mask | |
| # average loss over non-ignored elements | |
| if (avg_factor is None) and avg_non_ignore and reduction == 'mean': | |
| avg_factor = valid_mask.sum().item() | |
| # weighted element-wise losses | |
| weight = weight.float() | |
| loss = F.binary_cross_entropy_with_logits( | |
| pred, label.float(), pos_weight=class_weight, reduction='none') | |
| # do the reduction for the weighted loss | |
| loss = weight_reduce_loss( | |
| loss, weight, reduction=reduction, avg_factor=avg_factor) | |
| return loss | |


def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C, *), C is the
            number of classes. The trailing * indicates arbitrary shape.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the
            mask's corresponding object. It is used to select the mask of
            the class the object belongs to when the mask prediction is not
            class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other losses.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        >>>                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]


class CrossEntropyLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """CrossEntropyLoss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy loss.
                Defaults to False.
            reduction (str, optional): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
            avg_non_ignore (bool): Whether the loss is only averaged over
                non-ignored targets. Default: False.
        """
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore
        if ((ignore_index is not None) and not self.avg_non_ignore
                and self.reduction == 'mean'):
            warnings.warn(
                'Default ``avg_non_ignore`` is False; if you would like to '
                'ignore certain labels and average the loss over non-ignored '
                'labels (matching PyTorch\'s official cross_entropy), set '
                '``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def extra_repr(self):
        """Extra repr."""
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".
            ignore_index (int | None): The label index to be ignored.
                If not None, it will override the default value. Default: None.

        Returns:
            torch.Tensor: The calculated loss.
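
        Example (illustrative sketch with toy class scores):
            >>> criterion = CrossEntropyLoss()
            >>> cls_score = torch.randn(5, 7)
            >>> label = torch.randint(0, 7, (5,))
            >>> loss = criterion(cls_score, label)
            >>> assert loss.ndim == 0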
| """ | |
| assert reduction_override in (None, 'none', 'mean', 'sum') | |
| reduction = ( | |
| reduction_override if reduction_override else self.reduction) | |
| if ignore_index is None: | |
| ignore_index = self.ignore_index | |
| if self.class_weight is not None: | |
| class_weight = cls_score.new_tensor( | |
| self.class_weight, device=cls_score.device) | |
| else: | |
| class_weight = None | |
| loss_cls = self.loss_weight * self.cls_criterion( | |
| cls_score, | |
| label, | |
| weight, | |
| class_weight=class_weight, | |
| reduction=reduction, | |
| avg_factor=avg_factor, | |
| ignore_index=ignore_index, | |
| avg_non_ignore=self.avg_non_ignore, | |
| **kwargs) | |
| return loss_cls | |