Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from detectron2.layers import batched_nms | |
| from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads | |
| from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads | |
| from detectron2.structures import Instances | |
def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
    """
    Merge detection results from different branches of TridentNet.

    Return detection results by applying non-maximum suppression (NMS) on bounding boxes
    and keep the unsuppressed boxes and other instances (e.g. masks) if any.

    Args:
        instances (list[Instances]): A list of N * num_branch instances that store detection
            results. Contains N images and each image has num_branch instances. The list is
            branch-major: the result for image ``i`` from branch ``j`` is at index
            ``i + N * j``.
        num_branch (int): Number of branches used for merging detection results for each image.
        nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
        topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
            all detections.

    Returns:
        results (list[Instances]): A list of N instances, one for each image in the batch,
            that stores the topk most confident detections after merging results from
            multiple branches. When ``num_branch == 1`` the input list is returned unchanged.
    """
    if num_branch == 1:
        return instances
    batch_size = len(instances) // num_branch
    results = []
    for i in range(batch_size):
        # Gather every branch's output for image i (branch-major layout).
        instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])

        # Apply per-class NMS across the concatenated branch outputs.
        keep = batched_nms(
            instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh
        )
        # A negative topk means "keep everything". Slicing unconditionally would turn
        # topk_per_image == -1 into keep[:-1] and silently drop the last detection.
        if topk_per_image >= 0:
            keep = keep[:topk_per_image]
        results.append(instance[keep])
    return results
@ROI_HEADS_REGISTRY.register()
class TridentRes5ROIHeads(Res5ROIHeads):
    """
    The TridentNet ROIHeads in a typical "C4" R-CNN model.
    See :class:`Res5ROIHeads`.

    Registered in ``ROI_HEADS_REGISTRY`` so it can be selected by name from the config.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        # Number of trident branches whose predictions are merged at inference.
        self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
        # When TEST_BRANCH_IDX != -1, inference runs with a single branch (see forward).
        self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`Res5ROIHeads.forward`.

        Ground-truth ``targets`` are repeated once per branch. This assumes ``proposals``
        are ordered branch-major, the same layout :func:`merge_branch_instances` expects.
        At inference the per-branch predictions are merged with NMS.
        """
        # Use 1 branch if using trident_fast during inference.
        num_branch = self.num_branch if self.training or not self.trident_fast else 1
        # Duplicate targets for all branches in TridentNet.
        all_targets = targets * num_branch if targets is not None else None
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Drop references early so the (possibly large) inputs can be freed.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses
        else:
            pred_instances = merge_branch_instances(
                pred_instances,
                num_branch,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            return pred_instances, {}
@ROI_HEADS_REGISTRY.register()
class TridentStandardROIHeads(StandardROIHeads):
    """
    The `StandardROIHeads` for TridentNet.
    See :class:`StandardROIHeads`.

    Registered in ``ROI_HEADS_REGISTRY`` so it can be selected by name from the config.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        # Number of trident branches whose predictions are merged at inference.
        self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
        # When TEST_BRANCH_IDX != -1, inference runs with a single branch (see forward).
        self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`StandardROIHeads.forward`.

        Ground-truth ``targets`` are repeated once per branch. This assumes ``proposals``
        are ordered branch-major, the same layout :func:`merge_branch_instances` expects.
        At inference the per-branch predictions are merged with NMS.
        """
        # Use 1 branch if using trident_fast during inference.
        num_branch = self.num_branch if self.training or not self.trident_fast else 1
        # Duplicate targets for all branches in TridentNet.
        all_targets = targets * num_branch if targets is not None else None
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Drop references early so the (possibly large) inputs can be freed.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses
        else:
            pred_instances = merge_branch_instances(
                pred_instances,
                num_branch,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            return pred_instances, {}