Spaces:
Running
on
Zero
Running
on
Zero
| # This file is copied from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py | |
| # such that users do not need to install detectron2 just for these two functions | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from typing import List | |
| import torch | |
| from torch.nn import functional as F | |
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Concatenate `tensors` along `dim`, skipping the copy for a one-element sequence.

    Args:
        tensors (list[Tensor] or tuple[Tensor]): The tensors to join.
        dim (int): The dimension along which to concatenate.
    Returns:
        Tensor: `tensors[0]` itself (the same object, not a copy) when exactly one tensor
            is given, otherwise the result of `torch.cat(tensors, dim)`.
    """
    assert isinstance(tensors, (list, tuple))
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim)
def calculate_uncertainty(sem_seg_logits):
    """
    Estimate the uncertainty of the prediction `sem_seg_logits` at every location.

    The score depends on the number of channels C:
      * C == 1: the negated absolute value of the single logit.
      * C == 2: the negated absolute value of the channel-1 logit.
      * C >= 3: the second-highest logit minus the highest logit.
    In every case the score is <= 0 and approaches 0 as the prediction becomes less certain.

    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size
            and C is the number of predicted channels. The values are logits.
    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    num_channels = sem_seg_logits.shape[1]
    if num_channels == 1:
        # single-logit (class-agnostic) prediction: topk(k=2) below would raise on a
        # size-1 channel dimension, so score by the distance of the logit from zero
        return -torch.abs(sem_seg_logits)
    if num_channels == 2:
        # binary segmentation
        return -(torch.abs(sem_seg_logits[:, 1:2]))
    top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
    # second best minus best: a near-tie between the two leading classes gives a score near 0
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
def point_sample(input, point_coords, **kwargs):
    """
    Sample `input` at `point_coords` via :function:`torch.nn.functional.grid_sample`, with
    added support for 3D `point_coords` tensors. Unlike
    :function:`torch.nn.functional.grid_sample`, the coordinates are expected to lie inside
    the [0, 1] x [0, 1] square; they are rescaled to [-1, 1] before sampling.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on a
            H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that
            contains [0, 1] x [0, 1] normalized point coordinates.
        **kwargs: Forwarded unchanged to :function:`torch.nn.functional.grid_sample`.
    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for the points in `point_coords`, interpolated from `input` by
            :function:`torch.nn.functional.grid_sample`.
    """
    is_point_list = point_coords.dim() == 3
    # grid_sample needs a 4D grid, so a list of points becomes a (N, P, 1, 2) grid
    grid = point_coords.unsqueeze(2) if is_point_list else point_coords
    sampled = F.grid_sample(input, 2.0 * grid - 1.0, **kwargs)
    # drop the dummy grid dimension again so that the output is (N, C, P)
    return sampled.squeeze(3) if is_point_list else sampled
def get_uncertain_point_coords_with_randomness(coarse_logits, uncertainty_func, num_points,
                                               oversample_ratio, importance_sample_ratio):
    """
    Pick points in [0, 1] x [0, 1] coordinate space, favoring those with high uncertainty.
    Uncertainty is computed per point by `uncertainty_func` from the logit prediction
    sampled at that point.
    See PointRend paper for details.

    Args:
        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask)
            for class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a
            Tensor of shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter; `int(num_points * oversample_ratio)`
            random candidate points are drawn before selection.
        importance_sample_ratio (float): Ratio of points that are sampled via importance
            sampling; the remaining points are drawn uniformly at random.
    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1

    device = coarse_logits.device
    batch_size = coarse_logits.shape[0]
    num_candidates = int(num_points * oversample_ratio)

    candidates = torch.rand(batch_size, num_candidates, 2, device=device)
    candidate_logits = point_sample(coarse_logits, candidates, align_corners=False)
    # Uncertainty has to be computed from the logits sampled at each point. Computing it on
    # the coarse grid first and interpolating the result afterwards gives wrong answers:
    # with uncertainty_func(logits) = -abs(logits), a point halfway between coarse logits
    # -1 and 1 has logit 0 and thus uncertainty 0, whereas interpolating the two coarse
    # uncertainties (-1 each) would assign it -1.
    candidate_uncertainties = uncertainty_func(candidate_logits)

    num_important = int(importance_sample_ratio * num_points)
    num_uniform = num_points - num_important

    _, top_idx = torch.topk(candidate_uncertainties[:, 0, :], k=num_important, dim=1)
    # turn per-sample candidate indices into indices into the flattened (N * S, 2) view
    row_offsets = num_candidates * torch.arange(batch_size, dtype=torch.long, device=device)
    flat_idx = (top_idx + row_offsets[:, None]).view(-1)
    selected = candidates.view(-1, 2)[flat_idx, :].view(batch_size, num_important, 2)

    if num_uniform <= 0:
        return selected
    # fill the remaining budget with uniformly random points
    uniform = torch.rand(batch_size, num_uniform, 2, device=device)
    return cat([selected, uniform], dim=1)