Spaces:
Running
on
T4
Running
on
T4
// Host-side launchers for the CUDA kernels used below. Their definitions are
// not in this translation unit (presumably the accompanying .cu file).

/// Gather launcher. Called with b = batch size, c = channels, n = source
/// points per batch, npoints = number of indices per batch.
/// `points` and `idx` are read-only inputs; `out` is the destination buffer.
void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
                                  const float *points,
                                  const int *idx,
                                  float *out);

/// Backward launcher for the gather. Called with the same size arguments as
/// gather_points_kernel_wrapper; `grad_points` is the destination buffer
/// (the caller passes a zero-initialized tensor).
void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
                                       const float *grad_out,
                                       const int *idx,
                                       float *grad_points);

/// Furthest-point-sampling launcher. Called with b = batch size,
/// n = points per batch, m = samples to select. `temp` is a caller-provided
/// scratch buffer; selected indices are written to `idxs`.
void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
                                            const float *dataset,
                                            float *temp,
                                            int *idxs);
| at::Tensor gather_points(at::Tensor points, at::Tensor idx) { | |
| CHECK_CONTIGUOUS(points); | |
| CHECK_CONTIGUOUS(idx); | |
| CHECK_IS_FLOAT(points); | |
| CHECK_IS_INT(idx); | |
| if (points.is_cuda()) { | |
| CHECK_CUDA(idx); | |
| } | |
| at::Tensor output = | |
| torch::zeros({points.size(0), points.size(1), idx.size(1)}, | |
| at::device(points.device()).dtype(at::ScalarType::Float)); | |
| if (points.is_cuda()) { | |
| gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), | |
| idx.size(1), points.data_ptr<float>(), | |
| idx.data_ptr<int>(), output.data_ptr<float>()); | |
| } else { | |
| AT_ASSERT(false, "CPU not supported"); | |
| } | |
| return output; | |
| } | |
| at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, | |
| const int n) { | |
| CHECK_CONTIGUOUS(grad_out); | |
| CHECK_CONTIGUOUS(idx); | |
| CHECK_IS_FLOAT(grad_out); | |
| CHECK_IS_INT(idx); | |
| if (grad_out.is_cuda()) { | |
| CHECK_CUDA(idx); | |
| } | |
| at::Tensor output = | |
| torch::zeros({grad_out.size(0), grad_out.size(1), n}, | |
| at::device(grad_out.device()).dtype(at::ScalarType::Float)); | |
| if (grad_out.is_cuda()) { | |
| gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, | |
| idx.size(1), grad_out.data_ptr<float>(), | |
| idx.data_ptr<int>(), | |
| output.data_ptr<float>()); | |
| } else { | |
| AT_ASSERT(false, "CPU not supported"); | |
| } | |
| return output; | |
| } | |
| at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { | |
| CHECK_CONTIGUOUS(points); | |
| CHECK_IS_FLOAT(points); | |
| at::Tensor output = | |
| torch::zeros({points.size(0), nsamples}, | |
| at::device(points.device()).dtype(at::ScalarType::Int)); | |
| at::Tensor tmp = | |
| torch::full({points.size(0), points.size(1)}, 1e10, | |
| at::device(points.device()).dtype(at::ScalarType::Float)); | |
| if (points.is_cuda()) { | |
| furthest_point_sampling_kernel_wrapper( | |
| points.size(0), points.size(1), nsamples, points.data_ptr<float>(), | |
| tmp.data_ptr<float>(), output.data_ptr<int>()); | |
| } else { | |
| AT_ASSERT(false, "CPU not supported"); | |
| } | |
| return output; | |
| } | |