Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from maskrcnn_benchmark.modeling import registry | |
| from maskrcnn_benchmark.layers import Scale, DFConv2d | |
| from .loss import make_fcos_loss_evaluator | |
| from .anchor_generator import make_center_anchor_generator | |
| from .inference import make_fcos_postprocessor | |
class FCOSHead(torch.nn.Module):
    """Shared FCOS prediction head applied independently to every input level.

    Each feature map goes through two parallel conv towers (classification
    and box regression). Three 3x3 conv predictors follow:
    ``num_classes`` logits, 4 box-regression channels and 1 centerness
    channel.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: config node. Reads ``MODEL.FCOS.{NUM_CLASSES, USE_GN, USE_BN,
                USE_DFCONV, FPN_STRIDES, NORM_REG_TARGETS, CENTERNESS_ON_REG,
                NUM_CONVS, PRIOR_PROB}`` and ``MODEL.BACKBONE.OUT_CHANNELS``.
        """
        super(FCOSHead, self).__init__()
        # TODO: Implement the sigmoid version first.
        # NOTE(review): presumably NUM_CLASSES counts a background class that
        # gets no output channel here — confirm against the loss.
        num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        use_gn = cfg.MODEL.FCOS.USE_GN
        use_bn = cfg.MODEL.FCOS.USE_BN
        use_dcn_in_tower = cfg.MODEL.FCOS.USE_DFCONV
        num_convs = cfg.MODEL.FCOS.NUM_CONVS
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
        self.centerness_on_reg = cfg.MODEL.FCOS.CENTERNESS_ON_REG

        cls_tower = []
        bbox_tower = []
        for i in range(num_convs):
            # Only the last conv of each tower is deformable (when enabled).
            if use_dcn_in_tower and i == num_convs - 1:
                conv_func = DFConv2d
            else:
                conv_func = nn.Conv2d
            # The cls and bbox layers of this depth are built back to back,
            # exactly as before, so seeded construction and state-dict indices
            # (cls_tower.0, cls_tower.1, ...) are unchanged.
            cls_tower.extend(
                self._conv_norm_relu(conv_func, in_channels, use_gn, use_bn)
            )
            bbox_tower.extend(
                self._conv_norm_relu(conv_func, in_channels, use_gn, use_bn)
            )
        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))

        self.cls_logits = nn.Conv2d(
            in_channels, num_classes, kernel_size=3, stride=1, padding=1
        )
        self.bbox_pred = nn.Conv2d(
            in_channels, 4, kernel_size=3, stride=1, padding=1
        )
        self.centerness = nn.Conv2d(
            in_channels, 1, kernel_size=3, stride=1, padding=1
        )

        # initialization: every nn.Conv2d reachable through .modules() gets
        # N(0, 0.01) weights and a zero bias.
        # NOTE(review): .modules() recurses, so any nn.Conv2d nested inside a
        # DFConv2d (if it has one) is re-initialized here too — confirm intended.
        for modules in [self.cls_tower, self.bbox_tower,
                        self.cls_logits, self.bbox_pred,
                        self.centerness]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # initialize the bias for focal loss: sigmoid(bias_value) == PRIOR_PROB,
        # so with near-zero weights the initial class scores are ~PRIOR_PROB.
        prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

        # One Scale per FPN level, applied to the box regression of that level.
        # Sized from FPN_STRIDES (previously hard-coded to 5): more levels used
        # to raise IndexError in forward(), fewer left parameters that never
        # get gradients (which breaks DDP without find_unused_parameters).
        # With the default 5 strides this is identical to the old behavior.
        self.scales = nn.ModuleList(
            [Scale(init_value=1.0) for _ in self.fpn_strides]
        )

    @staticmethod
    def _conv_norm_relu(conv_func, channels, use_gn, use_bn):
        """Return [3x3 conv, optional GroupNorm(32), optional BatchNorm2d, ReLU]."""
        layers = [
            conv_func(
                channels,
                channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True
            )
        ]
        if use_gn:
            layers.append(nn.GroupNorm(32, channels))
        if use_bn:
            layers.append(nn.BatchNorm2d(channels))
        layers.append(nn.ReLU())
        return layers

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): one feature map per level. Level ``i`` is paired
                with ``self.scales[i]`` and ``self.fpn_strides[i]``. Assumes
                (N, C, H, W) maps with C == MODEL.BACKBONE.OUT_CHANNELS —
                TODO confirm.

        Returns:
            logits, bbox_reg, centerness (list[Tensor] each): per-level
            outputs, in the same order as ``x``.
        """
        logits = []
        bbox_reg = []
        centerness = []
        for level, feature in enumerate(x):
            cls_tower = self.cls_tower(feature)
            box_tower = self.bbox_tower(feature)
            logits.append(self.cls_logits(cls_tower))
            if self.centerness_on_reg:
                centerness.append(self.centerness(box_tower))
            else:
                centerness.append(self.centerness(cls_tower))
            bbox_pred = self.scales[level](self.bbox_pred(box_tower))
            if self.norm_reg_targets:
                # Non-negative distances. In training they are returned as-is
                # (presumably matching stride-normalized targets in the loss).
                # At inference they are multiplied by this level's stride.
                bbox_pred = F.relu(bbox_pred)
                if self.training:
                    bbox_reg.append(bbox_pred)
                else:
                    bbox_reg.append(bbox_pred * self.fpn_strides[level])
            else:
                # exp keeps the predicted distances strictly positive.
                bbox_reg.append(torch.exp(bbox_pred))
        return logits, bbox_reg, centerness
class FCOSModule(torch.nn.Module):
    """
    Module for FCOS computation. Takes feature maps from the backbone and
    computes the FCOS outputs and losses. Only tested on FPN for now.
    """

    def __init__(self, cfg):
        super(FCOSModule, self).__init__()
        head = FCOSHead(cfg)
        box_selector_train = make_fcos_postprocessor(cfg, is_train=True)
        box_selector_test = make_fcos_postprocessor(cfg, is_train=False)
        loss_evaluator = make_fcos_loss_evaluator(cfg)
        self.cfg = cfg
        self.head = head
        self.box_selector_train = box_selector_train
        self.box_selector_test = box_selector_test
        self.loss_evaluator = loss_evaluator
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        # Only needed when FCOS is not the whole model: the selected boxes are
        # passed through this generator to produce the returned proposals.
        if not cfg.MODEL.RPN_ONLY:
            self.anchor_generator = make_center_anchor_generator(cfg)

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList] or None): the predicted boxes, one BoxList per
                image (passed through the anchor generator unless
                MODEL.RPN_ONLY). None during training when MODEL.RPN_ONLY.
            losses (dict[Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        box_cls, box_regression, centerness = self.head(features)
        locations = self.compute_locations(features)
        # In training mode without targets no loss can be computed, so the
        # inference path is used instead.
        if self.training and targets is not None:
            return self._forward_train(
                locations, box_cls, box_regression,
                centerness, targets, images.image_sizes
            )
        else:
            return self._forward_test(
                locations, box_cls, box_regression,
                centerness, images.image_sizes
            )

    def _forward_train(self, locations, box_cls, box_regression, centerness, targets, image_sizes=None):
        """Compute the FCOS losses and, unless MODEL.RPN_ONLY, the proposals.

        Returns:
            (proposals or None, losses dict with keys "loss_cls", "loss_reg",
            "loss_centerness").
        """
        loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator(
            locations, box_cls, box_regression, centerness, targets
        )
        losses = {
            "loss_cls": loss_box_cls,
            "loss_reg": loss_box_reg,
            "loss_centerness": loss_centerness
        }
        if self.cfg.MODEL.RPN_ONLY:
            return None, losses
        else:
            boxes = self.box_selector_train(
                locations, box_cls, box_regression,
                centerness, image_sizes
            )
            proposals = self.anchor_generator(boxes, image_sizes, centerness)
            return proposals, losses

    def _forward_test(self, locations, box_cls, box_regression, centerness, image_sizes):
        """Select boxes with the test post-processor; returns (boxes, {})."""
        boxes = self.box_selector_test(
            locations, box_cls, box_regression,
            centerness, image_sizes
        )
        if not self.cfg.MODEL.RPN_ONLY:
            boxes = self.anchor_generator(boxes, image_sizes, centerness)
        return boxes, {}

    def compute_locations(self, features):
        """Return one (H*W, 2) tensor of (x, y) locations per feature level.

        Level ``i`` uses ``self.fpn_strides[i]``, so ``features`` must not
        have more levels than there are strides.
        """
        locations = []
        for level, feature in enumerate(features):
            h, w = feature.size()[-2:]
            locations_per_level = self.compute_locations_per_level(
                h, w, self.fpn_strides[level],
                feature.device
            )
            locations.append(locations_per_level)
        return locations

    def compute_locations_per_level(self, h, w, stride, device):
        """Image-plane (x, y) of every cell of an h x w map with the given stride.

        Each cell maps to ``(col * stride + stride // 2, row * stride + stride // 2)``.
        Rows are in row-major order (y outer, x inner), as a float32 tensor of
        shape (h * w, 2) on ``device``.
        """
        shifts_x = torch.arange(
            0, w * stride, step=stride,
            dtype=torch.float32, device=device
        )
        shifts_y = torch.arange(
            0, h * stride, step=stride,
            dtype=torch.float32, device=device
        )
        # Broadcast instead of torch.meshgrid: gives exactly the "ij"-indexed
        # grid, but avoids the missing-`indexing` deprecation warning emitted
        # on every call by torch>=1.10, and works on every torch version.
        num_x = shifts_x.numel()
        num_y = shifts_y.numel()
        shift_x = shifts_x[None, :].expand(num_y, num_x).reshape(-1)
        shift_y = shifts_y[:, None].expand(num_y, num_x).reshape(-1)
        locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
        return locations