Spaces:
Sleeping
Sleeping
| # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py | |
| # Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py | |
| # Modified by Qihang Yu | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| # Large negative offset added to the sampling logits of void pixels so they rank below labelled pixels when sampling. | |
| _SOFTMAX_MASKING_CONSTANT = -99999.0 | |
# https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan
def divide_no_nan(x: torch.Tensor, y: torch.Tensor):
    """Divide ``x`` by ``y`` element-wise, mapping non-finite quotients to zero.

    Every NaN, +inf or -inf in the quotient (for example the result of
    dividing by zero) is replaced by 0.0; finite quotients are returned as is.
    """
    quotient = torch.div(x, y)
    return torch.nan_to_num(quotient, nan=0.0, posinf=0.0, neginf=0.0)
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L393
def focal_cross_entropy_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    weight: torch.Tensor,  # This is for PQ-loss weighting
    focal_loss_alpha: float = 0.75,
    focal_loss_gamma: float = 0.0,
    background_channel_index: int = -1):
    """Weighted (optionally focal) softmax cross-entropy over class logits.

    Args:
        pred: B x N x C class logits.
        gt: B x N integer class indices (one-hot encoded internally with C classes).
        weight: B x N per-item multiplier applied to the loss.
        focal_loss_alpha: When >= 0, non-background items are scaled by alpha
            and background items by (1 - alpha). A negative value disables it.
        focal_loss_gamma: Focal exponent; 0.0 gives plain cross-entropy.
        background_channel_index: Channel treated as background for alpha weighting.

    Returns:
        Scalar: per-sample sum divided by the number of non-zero entries
        (at least 1), averaged over the batch.
    """
    logits = pred.transpose(1, 2)  # B x C x N
    num_channels = logits.shape[1]
    one_hot = F.one_hot(gt, num_classes=num_channels).transpose(1, 2).to(logits)  # B x C x N
    cross_entropy = F.cross_entropy(logits, one_hot, reduction="none")  # B x N

    per_item = cross_entropy
    if focal_loss_gamma != 0.0:
        # Down-weight items whose true class already has a high probability.
        probs = F.softmax(logits, dim=1)  # B x C x N
        true_class_prob = (probs * one_hot).sum(1)  # B x N
        per_item = torch.pow(1.0 - true_class_prob, focal_loss_gamma) * cross_entropy  # B x N

    if focal_loss_alpha >= 0:
        is_background = one_hot[:, background_channel_index]  # B x N
        alpha_weights = (
            focal_loss_alpha * (1.0 - is_background)
            + (1 - focal_loss_alpha) * is_background)  # B x N
        per_item = alpha_weights * per_item  # B x N

    per_item = (per_item * weight).flatten(1)  # B x N

    # Normalize each sample by the count of entries that actually contribute.
    contributing = (per_item != 0.0).to(per_item).sum(-1)  # B
    contributing = torch.clamp(contributing, min=1.0)
    per_sample_total = per_item.sum(-1)  # B
    return divide_no_nan(per_sample_total, contributing).mean()  # 1
# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L50
def _gumbel_topk_sample(logits: torch.Tensor, k: int):
    """Samples k points from the softmax distribution with Gumbel-Top-k trick.

    Args:
        logits: Floating-point unnormalized log-probabilities; sampling is done
            along the last dimension.
        k: Number of indices to draw without replacement. Capped at the size of
            the last dimension, since ``torch.topk`` rejects a larger ``k``.

    Returns:
        Integer tensor of sampled indices with shape
        ``logits.shape[:-1] + (min(k, logits.shape[-1]),)``.
    """
    finfo = torch.finfo(logits.dtype)
    uniform = torch.rand(size=logits.shape, dtype=logits.dtype, device=logits.device)
    # torch.rand draws from [0, 1). Clamp into the open interval (0, 1) so that
    # both logarithms below stay finite: u == 0 would give -inf noise, and a
    # value rounding to 1 would give +inf noise.
    uniform = uniform.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
    gumbel_noise = -torch.log(-torch.log(uniform))
    k = min(k, logits.shape[-1])
    _, indices = torch.topk(logits + gumbel_noise, k)
    return indices
# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L576
def pixelwise_insdis_loss(
    pixel_feature: torch.Tensor,
    gt_mask: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    instance_discrimination_temperature: float,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor
    ):
    """Pixel-wise instance discrimination loss on a sampled subset of pixels.

    Pixels are sampled with logits ``log(inverse_gt_mask_area) * sample_temperature``
    (void pixels pushed down by a large negative constant). The pairwise feature
    similarity of the sampled pixels is then trained, via soft-target
    cross-entropy, to match their normalized ground-truth mask similarity.

    Args:
        pixel_feature: B x C x H x W pixel features.
        gt_mask: B x N x H x W ground-truth masks.
        sample_temperature: Multiplier on the log inverse-area sampling logits.
        sample_k: Number of pixels to sample per image.
        instance_discrimination_temperature: Divisor applied to the predicted similarity.
        pixel_gt_void_mask: Mask of unlabelled pixels, flattened to B x HW.
        inverse_gt_mask_area: Per-pixel sampling weight, flattened to B x HW.

    Returns:
        Scalar: per-sample sum divided by the number of non-zero entries
        (at least 1), averaged over the batch.
    """
    features = pixel_feature.flatten(2)  # B x C x HW
    masks = gt_mask.flatten(2)  # B x N x HW
    void = pixel_gt_void_mask.flatten(1)  # B x HW
    inverse_area = inverse_gt_mask_area.flatten(1)  # B x HW

    # A finite offset (rather than -inf) pushes void pixels to the bottom of the ranking.
    sample_logits = torch.log(inverse_area) * sample_temperature  # B x HW
    sample_logits = sample_logits + void.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
    picked = _gumbel_topk_sample(sample_logits, sample_k)  # B x K
    picked = picked.unsqueeze(1)  # B x 1 x K

    # Ground-truth mask membership at the sampled pixels and its pairwise similarity.
    gt_at_samples = torch.gather(
        masks, dim=2, index=picked.expand(-1, masks.shape[1], -1))  # B x N x K
    gt_similarity = torch.einsum('bnk,bnj->bkj', gt_at_samples, gt_at_samples)  # B x K x K
    # Normalize the ground truth similarity into a distribution (sum to 1).
    column_totals = gt_similarity.sum(dim=1, keepdim=True)  # B x 1 x K
    gt_similarity = gt_similarity / torch.clamp(column_totals, min=1.0)  # B x K x K

    # Predicted features at the sampled pixels and their pairwise similarity.
    feat_at_samples = torch.gather(
        features, dim=2, index=picked.expand(-1, features.shape[1], -1))  # B x C x K
    pred_similarity = torch.einsum('bck,bcj->bkj', feat_at_samples, feat_at_samples)  # B x K x K
    pred_similarity = pred_similarity / instance_discrimination_temperature  # B x K x K

    per_pixel = F.cross_entropy(pred_similarity, gt_similarity, reduction="none")  # B x K

    contributing = (per_pixel != 0.0).to(per_pixel).sum(-1)  # B
    contributing = torch.clamp(contributing, min=1.0)
    per_sample_total = per_pixel.sum(-1)  # B
    return divide_no_nan(per_sample_total, contributing).mean()  # 1
def aux_semantic_loss(
    pred_semantic_logits: torch.Tensor,
    ground_truth_semantic: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor,
    num_classes: int):
    """Cross-entropy on semantic logits evaluated at a sampled subset of pixels.

    Args:
        pred_semantic_logits: Semantic logits, flattened to B x C x HW.
        ground_truth_semantic: Integer semantic labels, flattened to B x HW.
            The value ``num_classes`` marks pixels to ignore.
        sample_temperature: Multiplier on the log inverse-area sampling logits.
        sample_k: Number of pixels to sample per image.
        pixel_gt_void_mask: Mask of unlabelled pixels, flattened to B x HW.
        inverse_gt_mask_area: Per-pixel sampling weight, flattened to B x HW.
        num_classes: Label value used as the ignore index.

    Returns:
        Scalar: per-sample sum divided by the number of non-zero entries
        (at least 1), averaged over the batch.
    """
    logits = pred_semantic_logits.flatten(2)  # B x C x HW
    labels = ground_truth_semantic.flatten(1)  # B x HW
    void = pixel_gt_void_mask.flatten(1)  # B x HW
    inverse_area = inverse_gt_mask_area.flatten(1)  # B x HW

    sample_logits = torch.log(inverse_area) * sample_temperature  # B x HW
    sample_logits = sample_logits + void.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
    picked = _gumbel_topk_sample(sample_logits, sample_k)  # B x K

    picked_labels = torch.gather(labels, dim=1, index=picked)  # B x K
    channel_index = picked.unsqueeze(1).expand(-1, logits.shape[1], -1)
    picked_logits = torch.gather(logits, dim=2, index=channel_index)  # B x C x K

    per_pixel = F.cross_entropy(
        picked_logits, picked_labels, ignore_index=num_classes, reduction='none')  # B x K
    # ignore the class index num_classes.
    valid = picked_labels != num_classes  # B x K
    per_pixel = per_pixel * valid.to(per_pixel)

    contributing = (per_pixel != 0.0).to(per_pixel).sum(-1)  # B
    contributing = torch.clamp(contributing, min=1.0)
    per_sample_total = per_pixel.sum(-1)  # B
    return divide_no_nan(per_sample_total, contributing).mean()  # 1
# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L56
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L510
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    pixel_gt_void_mask: torch.Tensor,
    matched_cls_prob: torch.Tensor
    ):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: B x N x HW mask logits; a softmax over the mask dimension
                (dim 1) is applied here.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        pixel_gt_void_mask: B x HW boolean mask of void pixels; predictions at
                 these pixels are zeroed before the loss is computed.
        matched_cls_prob: B x N weight multiplied into each mask's loss.

    Returns:
        Scalar loss: summed over masks, scaled by 0.75 / 128, averaged over the batch.
    """
    smooth = 1.0

    probs = inputs.softmax(1)  # B N HW
    # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L111
    probs = probs.masked_fill(pixel_gt_void_mask.unsqueeze(1), 0)  # remove void pixels.

    overlap = 2 * (probs * targets).sum(-1) + smooth  # B x N
    total = probs.sum(-1) + targets.sum(-1) + smooth  # B x N
    per_mask = 1.0 - divide_no_nan(overlap, total)
    per_mask.mul_(matched_cls_prob)

    # Note: kMaX-DeepLab sums over num_masks and averages over batches.
    # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L559
    # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L402
    # Because of the modifier in the reference code, this is equivalent to multiplying by 0.75.
    per_sample = per_mask.sum(1) * 0.75 / 128  # sum over masks
    return per_sample.mean()  # mean over batches.
def softmax_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    pixel_gt_void_mask: torch.Tensor,
    ):
    """
    Per-pixel softmax cross-entropy across masks, skipping void pixels.

    Args:
        inputs: B x N x HW mask logits; dim 1 is treated as the class dimension.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        pixel_gt_void_mask: B x HW boolean mask of void pixels whose loss is zeroed.
    Returns:
        Scalar loss: per-sample sum divided by the number of non-zero entries
        (at least 1), averaged over the batch.
    """
    per_pixel = F.cross_entropy(inputs, targets, reduction="none")  # B x HW
    per_pixel = per_pixel.masked_fill(pixel_gt_void_mask, 0)  # remove void pixels.

    contributing = (per_pixel != 0.0).to(per_pixel).sum(-1)  # B
    contributing = torch.clamp(contributing, min=1.0)
    per_sample_total = per_pixel.sum(-1)  # B
    return divide_no_nan(per_sample_total, contributing).mean()  # 1
class SetCriterion(nn.Module):
    """Computes the training losses for a set of class/mask predictions.

    The process happens in two steps:
        1) the matcher assigns ground-truth masks to predictions;
        2) the ground truth is padded and permuted to the prediction layout
           (see ``process_gt``) and every requested loss is computed against it.
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, share_final_matching,
                 pixel_insdis_temperature=1.5, pixel_insdis_sample_k=4096,
                 aux_semantic_temperature=2.0, aux_semantic_sample_k=4096):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: callable computing a matching between targets and proposals; it is
                called as ``matcher(outputs, targets)`` and must return
                ``(indices, matched_dice, matched_cls_prob)``
            weight_dict: stored and shown in ``repr``; not used by the loss computation here
            eos_coef: lower bound applied to the PQ-style loss weights of matched predictions
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            share_final_matching: if true, auxiliary outputs reuse the matching and processed
                targets computed for the final output
            pixel_insdis_temperature: sampling temperature of the pixel instance-discrimination loss
            pixel_insdis_sample_k: number of pixels sampled for the pixel instance-discrimination loss
            aux_semantic_temperature: sampling temperature of the auxiliary semantic loss
            aux_semantic_sample_k: number of pixels sampled for the auxiliary semantic loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.share_final_matching = share_final_matching
        self.pixel_insdis_temperature = pixel_insdis_temperature
        self.pixel_insdis_sample_k = pixel_insdis_sample_k
        self.aux_semantic_temperature = aux_semantic_temperature
        self.aux_semantic_sample_k = aux_semantic_sample_k

    def loss_labels(self, outputs, targets):
        """Classification loss.

        ``targets`` is the dict produced by ``process_gt``; it must contain
        "labels" (B x N) and "pq_loss_class_weight" (B x N).
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]  # B x N x C
        target_classes = targets["labels"]  # B x N
        pq_loss_class_weight = targets["pq_loss_class_weight"]
        return {"loss_ce": focal_cross_entropy_loss(src_logits, target_classes, pq_loss_class_weight)}

    def loss_masks(self, outputs, targets):
        """Compute the losses related to the masks: the softmax cross-entropy loss and the dice loss.

        ``targets`` is the dict produced by ``process_gt``; it must contain
        "masks", "pq_loss_mask_weight" and "pixel_gt_void_mask".
        """
        src_masks = outputs["pred_masks"].flatten(2)  # B x N x HW
        target_masks = targets["masks"].flatten(2)  # B x N x HW
        pixel_gt_void_mask = targets["pixel_gt_void_mask"].flatten(1)  # B x HW
        pq_loss_mask_weight = targets["pq_loss_mask_weight"]

        return {
            "loss_mask": softmax_ce_loss(src_masks, target_masks, pixel_gt_void_mask),
            "loss_dice": dice_loss(src_masks, target_masks, pixel_gt_void_mask, pq_loss_mask_weight),
        }

    def loss_pixels(self, outputs, targets):
        """Pixel-wise instance discrimination loss on ``outputs["pixel_feature"]``."""
        return {"loss_pixel_insdis": pixelwise_insdis_loss(
            pixel_feature=outputs["pixel_feature"],
            gt_mask=targets["masks"],
            sample_temperature=self.pixel_insdis_temperature,
            sample_k=self.pixel_insdis_sample_k,
            instance_discrimination_temperature=0.3,
            pixel_gt_void_mask=targets["pixel_gt_void_mask"],
            inverse_gt_mask_area=targets["inverse_gt_mask_area"],
        )}

    def loss_semantic(self, outputs, targets):
        """Auxiliary semantic segmentation loss on ``outputs["aux_semantic_pred"]``."""
        return {"loss_aux_semantic": aux_semantic_loss(
            pred_semantic_logits=outputs["aux_semantic_pred"],
            ground_truth_semantic=targets["ground_truth_semantic"],
            sample_temperature=self.aux_semantic_temperature,
            sample_k=self.aux_semantic_sample_k,
            pixel_gt_void_mask=targets["pixel_gt_void_mask"].flatten(1),
            inverse_gt_mask_area=targets["inverse_gt_mask_area"].flatten(1),
            num_classes=self.num_classes,
        )}

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image matches into ``(batch_idx, src_idx)`` index tensors.

        ``indices`` holds one ``(src, tgt)`` pair of index tensors per image.
        ``batch_idx`` repeats the image index once per matched prediction, i.e.
        (0, 0, ..., 1, 1, ...) with total length N0 + N1 + ... ; ``src_idx`` is
        the concatenation of the matched prediction indices in the same order.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def get_loss(self, loss, outputs, targets):
        """Dispatch to the loss named ``loss`` ('labels', 'masks', 'pixels' or 'aux_semantic')."""
        loss_map = {
            'labels': self.loss_labels,
            'masks': self.loss_masks,
            'pixels': self.loss_pixels,
            'aux_semantic': self.loss_semantic,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets)

    def process_gt(self, outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=False):
        """Pad and permute the ground truth to the prediction layout for loss computation.

        Computing this once lets the matching result be shared by all predictions.

        Parameters:
            outputs: dict with "pred_masks" (B x N x H x W) and "pred_logits" (B x N x C)
            targets: list of per-image dicts with "masks" and "labels" (and
                "semantic_masks" when ``process_semantic`` is true)
            indices: per-image ``(src, tgt)`` index pairs from the matcher
            matched_dice: per-image dice values of the matched pairs
            matched_cls_prob: per-image class probabilities of the matched pairs
            process_semantic: also build "ground_truth_semantic"

        Returns:
            dict with "masks", "labels", "pq_loss_mask_weight", "pq_loss_class_weight",
            "pixel_gt_void_mask", "inverse_gt_mask_area" and, optionally,
            "ground_truth_semantic".
        """
        src_idx = self._get_src_permutation_idx(indices)
        src_masks = outputs["pred_masks"].detach()  # B x N x H x W

        # Pad and permute the target masks to B x N x H x W.
        target_masks = torch.zeros_like(src_masks)
        target_masks_o = torch.cat([t["masks"][J] for t, (_, J) in zip(targets, indices)]).to(target_masks)
        target_masks[src_idx] = target_masks_o

        # Pad and permute the matched class probabilities to B x N.
        matched_cls_prob_o = torch.clamp(torch.cat(list(matched_cls_prob)), min=self.eos_coef)
        # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L1034
        # no penalty for unmatched masks.
        mask_weight = torch.full(
            src_masks.shape[:2], 0, dtype=src_masks.dtype, device=src_masks.device
        )  # B x N
        mask_weight[src_idx] = matched_cls_prob_o.to(mask_weight)

        # pixel_gt_void_mask indicates the pixels without labels.
        pixel_gt_void_mask = (target_masks.sum(1) < 1)  # B x H x W

        # inverse_gt_mask_area is used to sample pixels.
        mask_gt_area = target_masks.sum(2).sum(2)  # B x N
        pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area)  # B x H x W
        num_pixels = pixel_gt_area.shape[1] * pixel_gt_area.shape[2]
        inverse_gt_mask_area = num_pixels / torch.clamp(pixel_gt_area, min=1.0)  # B x H x W

        src_logits = outputs["pred_logits"]  # B x N x C
        # Pad and permute the target classes to B x N: real labels go to the
        # matched slots, every other slot holds the void label num_classes.
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[src_idx] = target_classes_o

        src_masks_prob = src_masks.softmax(1)
        void_mask = pixel_gt_void_mask.to(src_masks_prob)  # B x H x W

        # compute iou instead of dice for void overlapping.
        def computer_iou_score(x, y):
            # x : B x N x H x W
            # y : B x H x W
            x = x.flatten(2)  # B x N x L
            y = y.flatten(1)  # B x L
            intersection = torch.einsum('bnl,bl->bn', x, y)  # B x N
            denominator = x.sum(-1)  # B x N
            return intersection / (denominator + 1e-5)  # B x N

        # Pad and permute the matched dice to B x N.
        matched_dice_o = torch.cat(list(matched_dice))
        class_weight = computer_iou_score(src_masks_prob, void_mask)  # unmatched masks use their overlap with void
        class_weight[src_idx] = matched_dice_o.to(class_weight)
        class_weight = torch.clamp(class_weight, min=self.eos_coef)

        processed_gt = {
            "masks": target_masks,
            "labels": target_classes,
            "pq_loss_mask_weight": mask_weight,
            "pq_loss_class_weight": class_weight,
            "pixel_gt_void_mask": pixel_gt_void_mask,
            "inverse_gt_mask_area": inverse_gt_mask_area,
        }

        if process_semantic:
            # To obtain semantic gt
            ground_truth_semantic = torch.stack([t["semantic_masks"] for t in targets], dim=0)  # B x H x W
            # self.num_classes is set to ignore label
            ground_truth_semantic[ground_truth_semantic == -1] = self.num_classes
            processed_gt.update({"ground_truth_semantic": ground_truth_semantic})

        return processed_gt

    def forward(self, outputs, targets):
        """This performs the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depends on the losses applied, see each loss' doc

        Returns:
            dict mapping loss names to scalar tensors; losses of the i-th
            auxiliary output carry the suffix ``_{i}``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
        indices, matched_dice, matched_cls_prob = self.matcher(outputs_without_aux, targets)
        # Pad GT to the same number of prediction.
        processed_targets = self.process_gt(
            outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=True)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, processed_targets))

        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                # Unless the final matching is shared across predictions,
                # re-match and re-process the targets for this auxiliary output.
                if not self.share_final_matching:
                    indices, matched_dice, matched_cls_prob = self.matcher(aux_outputs, targets)
                    processed_targets = self.process_gt(
                        aux_outputs, targets, indices, matched_dice, matched_cls_prob)
                for loss in self.losses:
                    if loss == 'aux_semantic':
                        # Only for final output.
                        continue
                    l_dict = self.get_loss(loss, aux_outputs, processed_targets)
                    losses.update({k + f"_{i}": v for k, v in l_dict.items()})

        return losses

    def __repr__(self):
        head = "Criterion " + self.__class__.__name__
        try:
            matcher_repr = self.matcher.__repr__(_repr_indent=8)
        except TypeError:
            # The matcher's __repr__ does not accept `_repr_indent`; use the plain repr.
            matcher_repr = repr(self.matcher)
        body = [
            "matcher: {}".format(matcher_repr),
            "losses: {}".format(self.losses),
            "weight_dict: {}".format(self.weight_dict),
            "num_classes: {}".format(self.num_classes),
            "eos_coef: {}".format(self.eos_coef),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)