# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
This file provides the definition of the convolutional heads used to predict masks, as well as the losses
"""
import io
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from .util import box_ops
from .util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list

try:
    from panopticapi.utils import id2rgb, rgb2id
except ImportError:
    pass
class DETRsegm(nn.Module):
    def __init__(self, detr, freeze_detr=False):
        super().__init__()
        self.detr = detr

        if freeze_detr:
            for p in self.parameters():
                p.requires_grad_(False)

        hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0)
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

    def forward(self, samples: NestedTensor):
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.detr.backbone(samples)

        bs = features[-1].tensors.shape[0]

        src, mask = features[-1].decompose()
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.detr.aux_loss:
            out["aux_outputs"] = [
                {"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
            ]

        # FIXME h_boxes takes the last one computed, keep this in mind
        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)

        seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        out["pred_masks"] = outputs_seg_masks
        return out
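# A minimal usage sketch (hypothetical, not part of the original file): it assumes an
# already-built DETR model `detr_model` exposing .backbone, .transformer, .input_proj,
# .query_embed, .class_embed, .bbox_embed, .num_queries and .aux_loss, as in the DETR codebase.
#
#     model = DETRsegm(detr_model, freeze_detr=True)
#     images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
#     out = model(images)        # the list of tensors is wrapped into a NestedTensor internally
#     out["pred_logits"].shape   # (batch, num_queries, num_classes + 1)
#     out["pred_masks"].shape    # (batch, num_queries, H', W') per-query mask logits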
class MaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm.
    Upsampling is done using an FPN approach.
    """

    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()

        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(8, dim)
        self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
        self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
        self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
        self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
        self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, bbox_mask, fpns):
        def expand(tensor, length):
            return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)

        x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        x = self.out_lay(x)
        return x
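# Shape sketch for MaskHeadSmallConv (hypothetical numbers, assuming hidden_dim=256 and
# nheads=8, so the concatenated input has 256 + 8 = 264 channels, and a stride-32 feature
# map for an ~800x1088 image):
#
#     head = MaskHeadSmallConv(dim=264, fpn_dims=[1024, 512, 256], context_dim=256)
#     x = torch.rand(2, 256, 25, 34)                 # projected backbone features (B, C, H/32, W/32)
#     bbox_mask = torch.rand(2, 100, 8, 25, 34)      # per-query attention maps (B, Q, nheads, H/32, W/32)
#     fpns = [torch.rand(2, 1024, 50, 68),           # stride-16 features
#             torch.rand(2, 512, 100, 136),          # stride-8 features
#             torch.rand(2, 256, 200, 272)]          # stride-4 features
#     out = head(x, bbox_mask, fpns)                 # (B * Q, 1, 200, 272) mask logits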
class MHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask=None):
        q = self.q_linear(q)
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
        weights = self.dropout(weights)
        return weights
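# Minimal sketch of MHAttentionMap on its own (hypothetical shapes): the queries are the
# decoder outputs for the object queries, the keys are the encoder memory laid out as a
# feature map, and the result is one spatial attention map per query and head.
#
#     attn = MHAttentionMap(query_dim=256, hidden_dim=256, num_heads=8, dropout=0)
#     q = torch.rand(2, 100, 256)     # (batch, num_queries, d_model)
#     k = torch.rand(2, 256, 25, 34)  # (batch, d_model, H, W) encoder memory
#     w = attn(q, k)                  # (2, 100, 8, 25, 34) softmax-normalized weights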
def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes
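# Minimal sketch of dice_loss on dummy data (hypothetical shapes; during training the
# inputs are the matched predicted mask logits and the targets the corresponding binary
# ground-truth masks, both arranged as one row per mask):
#
#     pred_logits = torch.randn(4, 64 * 64)             # 4 matched masks
#     gt = (torch.rand(4, 64 * 64) > 0.5).float()
#     loss = dice_loss(pred_logits, gt, num_boxes=4)     # scalar tensor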
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = 0.25; a negative value
               disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes
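# Minimal sketch of sigmoid_focal_loss on dummy data (hypothetical shapes, mirroring the
# dice_loss example above):
#
#     pred_logits = torch.randn(4, 64 * 64)
#     gt = (torch.rand(4, 64 * 64) > 0.5).float()
#     loss = sigmoid_focal_loss(pred_logits, gt, num_boxes=4)   # scalar tensor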
class PostProcessSegm(nn.Module):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()
        outputs_masks = outputs["pred_masks"].squeeze(2)
        outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
        outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()

        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = t[0], t[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            results[i]["masks"] = F.interpolate(
                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
            ).byte()

        return results
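# Usage sketch (hypothetical: `results` would normally come from the box post-processor,
# one dict per image, and `outputs` is the raw dict returned by DETRsegm):
#
#     postprocess_segm = PostProcessSegm(threshold=0.5)
#     orig_sizes = torch.tensor([[480, 640]])   # (h, w) of the original images
#     max_sizes = torch.tensor([[800, 1066]])   # (h, w) after augmentation, before padding
#     results = postprocess_segm(results, outputs, orig_sizes, max_sizes)
#     results[0]["masks"]                       # (num_queries, 1, 480, 640) uint8 masks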
class PostProcessPanoptic(nn.Module):
    """This class converts the output of the model to the final panoptic result, in the format expected by the
    coco panoptic API"""

    def __init__(self, is_thing_map, threshold=0.85):
        """
        Parameters:
            is_thing_map: This is a dict whose keys are the class ids, and the values a boolean indicating whether
                          the class is a thing (True) or a stuff (False) class
            threshold: confidence threshold: segments with confidence lower than this will be deleted
        """
        super().__init__()
        self.threshold = threshold
        self.is_thing_map = is_thing_map
    def forward(self, outputs, processed_sizes, target_sizes=None):
        """This function computes the panoptic prediction from the model's predictions.
        Parameters:
            outputs: This is a dict coming directly from the model. See the model doc for the content.
            processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
                             model, ie the size after data augmentation but before batching.
            target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
                          of each prediction. If left to None, it will default to the processed_sizes
        """
        if target_sizes is None:
            target_sizes = processed_sizes
        assert len(processed_sizes) == len(target_sizes)
        out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
        assert len(out_logits) == len(raw_masks) == len(target_sizes)
        preds = []

        def to_tuple(tup):
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.cpu().tolist())

        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
        ):
            # we filter empty queries and detections below the threshold
            scores, labels = cur_logits.softmax(-1).max(-1)
            keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
            cur_scores = cur_scores[keep]
            cur_classes = cur_classes[keep]
            cur_masks = cur_masks[keep]
            cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0)
            cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])

            h, w = cur_masks.shape[-2:]
            assert len(cur_boxes) == len(cur_classes)

            # It may be that we have several predicted masks for the same stuff class.
            # In the following, we track the list of mask ids for each stuff class (they are merged later on)
            cur_masks = cur_masks.flatten(1)
            stuff_equiv_classes = defaultdict(lambda: [])
            for k, label in enumerate(cur_classes):
                if not self.is_thing_map[label.item()]:
                    stuff_equiv_classes[label.item()].append(k)

            def get_ids_area(masks, scores, dedup=False):
                # This helper function creates the final panoptic segmentation image
                # It also returns the area of the masks that appears on the image

                m_id = masks.transpose(0, 1).softmax(-1)

                if m_id.shape[-1] == 0:
                    # We didn't detect any mask :(
                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
                else:
                    m_id = m_id.argmax(-1).view(h, w)

                if dedup:
                    # Merge the masks corresponding to the same stuff class
                    for equiv in stuff_equiv_classes.values():
                        if len(equiv) > 1:
                            for eq_id in equiv:
                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

                final_h, final_w = to_tuple(target_size)

                seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)

                np_seg_img = (
                    torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
                )
                m_id = torch.from_numpy(rgb2id(np_seg_img))

                area = []
                for i in range(len(scores)):
                    area.append(m_id.eq(i).sum().item())
                return area, seg_img

            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
            if cur_classes.numel() > 0:
                # We now filter empty masks as long as we find some
                while True:
                    filtered_small = torch.as_tensor(
                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
                    )
                    if filtered_small.any().item():
                        cur_scores = cur_scores[~filtered_small]
                        cur_classes = cur_classes[~filtered_small]
                        cur_masks = cur_masks[~filtered_small]
                        area, seg_img = get_ids_area(cur_masks, cur_scores)
                    else:
                        break
            else:
                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

            segments_info = []
            for i, a in enumerate(area):
                cat = cur_classes[i].item()
                segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
            del cur_classes

            with io.BytesIO() as out:
                seg_img.save(out, format="PNG")
                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
            preds.append(predictions)
        return preds
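# Usage sketch (hypothetical: `outputs` is the dict produced by DETRsegm for a batch, and
# the is_thing_map would normally come from the dataset metadata, e.g. COCO panoptic
# categories; the mapping below is a toy stand-in):
#
#     is_thing_map = {i: i < 80 for i in range(250)}     # toy mapping: first 80 ids are "things"
#     postprocess = PostProcessPanoptic(is_thing_map, threshold=0.85)
#     processed_sizes = [(800, 1066)]                    # (h, w) fed to the model
#     target_sizes = [(480, 640)]                        # requested output resolution
#     preds = postprocess(outputs, processed_sizes, target_sizes)
#     preds[0]["png_string"], preds[0]["segments_info"]  # PNG-encoded id map + segment metadata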