Spaces:
Running
on
T4
Running
on
T4
| import torch | |
| import torch.nn as nn | |
| import warnings | |
| from torch.autograd import Function | |
| from typing import * | |
try:
    import pointnet2_ops._ext as _ext
except ImportError:
    # The prebuilt extension is not importable: compile the C++/CUDA sources
    # found under ``_ext-src`` next to this file at import time instead.
    from torch.utils.cpp_extension import load
    import glob
    import os.path as osp
    import os

    warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")

    _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
    _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
        osp.join(_ext_src_root, "src", "*.cu")
    )
    _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))

    # Only supply a default architecture list when the user has not exported
    # one. Overwriting it unconditionally made it impossible to build with
    # NVCC releases that no longer accept some of these targets (e.g. newer
    # CUDA toolkits that dropped compute_37).
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5")

    _ext = load(
        "_ext",
        sources=_ext_sources,
        extra_include_paths=[osp.join(_ext_src_root, "include")],
        extra_cflags=["-O3"],
        extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
        with_cuda=True,
    )
class FurthestPointSampling(Function):
    """Autograd wrapper around the ``furthest_point_sampling`` extension op.

    The output is marked non-differentiable, so no gradient flows back
    through it.
    """

    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set (presumably indices into
            the N input points -- TODO confirm against the extension)
        """
        out = _ext.furthest_point_sampling(xyz, npoint)
        ctx.mark_non_differentiable(out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # autograd expects one entry per forward input (xyz, npoint);
        # neither receives a gradient.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
class GatherOperation(Function):
    """Autograd wrapper around the ``gather_points`` extension op."""

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor
        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        # Backward only needs the indices and the length N of the gathered
        # axis, so keep N as a plain int instead of holding on to `features`.
        ctx.save_for_backward(idx)
        ctx.N = features.size(2)
        return _ext.gather_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        (idx,) = ctx.saved_tensors
        grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, ctx.N)
        # idx is an index tensor and receives no gradient.
        return grad_features, None


gather_operation = GatherOperation.apply
class ThreeNN(Function):
    """Autograd wrapper around the ``three_nn`` extension op.

    Both outputs are marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
        Find the three nearest neighbors of unknown in known

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of unknown features (the query points)
        known : torch.Tensor
            (B, m, 3) tensor of known features (the points searched over)

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        dist2, idx = _ext.three_nn(unknown, known)
        # `dist2` is square-rooted here so the returned value is the l2
        # distance documented above.
        dist = torch.sqrt(dist2)
        ctx.mark_non_differentiable(dist, idx)
        return dist, idx

    @staticmethod
    def backward(ctx, grad_dist, grad_idx):
        # autograd expects one entry per forward input (unknown, known);
        # neither receives a gradient.
        return None, None


three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
    """Autograd wrapper around the ``three_interpolate`` extension op.

    NOTE(review): the weight gradient below assumes the extension computes
    ``out[b, :, i] = sum_k weight[b, i, k] * features[b, :, idx[b, i, k]]``
    (the weighted linear interpolation the docstring describes) -- confirm
    against the CUDA kernel.
    """

    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Performs weight linear interpolation on 3 features

        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        ctx.save_for_backward(idx, weight, features)
        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[Optional[torch.Tensor], None, Optional[torch.Tensor]]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor or None
            (B, c, m) tensor with gradients of features, or None when
            features does not require grad
        None
            idx is an index tensor and receives no gradient
        grad_weight : torch.Tensor or None
            (B, n, 3) tensor with gradients of weight, or None when weight
            does not require grad
        """
        idx, weight, features = ctx.saved_tensors
        grad_features = None
        grad_weight = None

        if ctx.needs_input_grad[0]:
            m = features.size(2)
            grad_features = _ext.three_interpolate_grad(
                grad_out.contiguous(), idx, weight, m
            )

        if ctx.needs_input_grad[2]:
            # d out[b, :, i] / d weight[b, i, k] == features[b, :, idx[b, i, k]],
            # so gather the neighbour features and contract over channels.
            batch, channels = features.size(0), features.size(1)
            n, k = idx.size(1), idx.size(2)
            flat_idx = idx.long().reshape(batch, 1, n * k).expand(batch, channels, n * k)
            neighbours = torch.gather(features, 2, flat_idx).view(batch, channels, n, k)
            grad_weight = (grad_out.unsqueeze(-1) * neighbours).sum(dim=1).to(weight.dtype)

        return grad_features, None, grad_weight


three_interpolate = ThreeInterpolate.apply
class GroupingOperation(Function):
    """Autograd wrapper around the ``group_points`` extension op."""

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        # Backward only needs the indices and the length N of the feature
        # axis, so keep N as a plain int instead of holding on to `features`.
        ctx.save_for_backward(idx)
        ctx.N = features.size(2)
        return _ext.group_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
            idx is an index tensor and receives no gradient
        """
        (idx,) = ctx.saved_tensors
        grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, ctx.N)
        return grad_features, None


grouping_operation = GroupingOperation.apply
class BallQuery(Function):
    """Autograd wrapper around the ``ball_query`` extension op.

    The output is marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        # Note the extension's argument order: centres first, then the point
        # cloud, then radius and nsample.
        output = _ext.ball_query(new_xyz, xyz, radius, nsample)
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        # autograd expects one entry per forward input
        # (radius, nsample, xyz, new_xyz); none receives a gradient.
        return None, None, None, None


ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
    r"""
    Groups with a ball query of radius

    Parameters
    ----------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of features to gather in the ball
    use_xyz : bool
        If True (default), the centre-relative xyz offsets are concatenated
        in front of the grouped features
    """

    def __init__(self, radius, nsample, use_xyz=True):
        # type: (QueryAndGroup, float, int, bool) -> None
        super(QueryAndGroup, self).__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (QueryAndGroup, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        features : torch.Tensor, optional
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor when features is given and
            use_xyz is True; (B, C, npoint, nsample) when use_xyz is False;
            (B, 3, npoint, nsample) when features is None

        Raises
        ------
        AssertionError
            if features is None and use_xyz is False
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Express every grouped point relative to the centre of its ball.
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert (
                self.use_xyz
            ), "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz

        return new_features
class GroupAll(nn.Module):
    r"""
    Groups all features into a single group

    Parameters
    ----------
    use_xyz : bool
        If True (default) and features are supplied, the point coordinates
        are concatenated in front of the features along the channel axis
    """

    def __init__(self, use_xyz=True):
        # type: (GroupAll, bool) -> None
        super(GroupAll, self).__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor, optional
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, 1, N) when features is given and use_xyz is True,
            (B, C, 1, N) when use_xyz is False, and (B, 3, 1, N) when
            features is None (regardless of use_xyz)
        """
        # (B, N, 3) -> (B, 3, 1, N): a single group that holds every point.
        coords_group = xyz.transpose(1, 2).unsqueeze(2)
        if features is None:
            return coords_group

        feature_group = features.unsqueeze(2)  # (B, C, 1, N)
        if not self.use_xyz:
            return feature_group

        return torch.cat([coords_group, feature_group], dim=1)  # (B, 3 + C, 1, N)