|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <math.h> |
|
|
#include <torch/extension.h> |
|
|
#include <cstdio> |
|
|
#include <sstream> |
|
|
#include <iostream> |
|
|
#include <tuple> |
|
|
#include <stdio.h> |
|
|
#include <cuda_runtime_api.h> |
|
|
#include <memory> |
|
|
#include "cuda_rasterizer/config.h" |
|
|
#include "cuda_rasterizer/rasterizer.h" |
|
|
#include "cuda_rasterizer/rasterizer_impl.h" |
|
|
#include <fstream> |
|
|
#include <string> |
|
|
#include <functional> |
|
|
|
|
|
// Wraps tensor `t` in an allocation callback.
//
// The returned callable resizes `t` in place to a 1-D tensor of N elements
// and yields the address of its storage as a char*. `t` is captured by
// reference, so the tensor must outlive every invocation of the callback.
std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
    return [&t](size_t N) {
        const auto num_elements = static_cast<long long>(N);
        t.resize_({num_elements});
        void* storage = t.contiguous().data_ptr();
        return reinterpret_cast<char*>(storage);
    };
}
|
|
|
|
|
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> |
|
|
RasterizeGaussiansCUDA( |
|
|
const torch::Tensor& background, |
|
|
const torch::Tensor& means3D, |
|
|
const torch::Tensor& colors, |
|
|
const torch::Tensor& flows, |
|
|
const torch::Tensor& opacity, |
|
|
const torch::Tensor& ts, |
|
|
const torch::Tensor& scales, |
|
|
const torch::Tensor& scales_t, |
|
|
const torch::Tensor& rotations, |
|
|
const torch::Tensor& rotations_r, |
|
|
const float scale_modifier, |
|
|
const torch::Tensor& cov3D_precomp, |
|
|
const torch::Tensor& viewmatrix, |
|
|
const torch::Tensor& projmatrix, |
|
|
const float tan_fovx, |
|
|
const float tan_fovy, |
|
|
const int image_height, |
|
|
const int image_width, |
|
|
const torch::Tensor& sh, |
|
|
const int degree, |
|
|
const int degree_t, |
|
|
const torch::Tensor& campos, |
|
|
const float timestamp, |
|
|
const float time_duration, |
|
|
const bool rot_4d, |
|
|
const int gaussian_dim, |
|
|
const bool force_sh_3d, |
|
|
const bool prefiltered, |
|
|
const bool debug) |
|
|
{ |
|
|
if (means3D.ndimension() != 2 || means3D.size(1) != 3) { |
|
|
AT_ERROR("means3D must have dimensions (num_points, 3)"); |
|
|
} |
|
|
|
|
|
const int P = means3D.size(0); |
|
|
const int H = image_height; |
|
|
const int W = image_width; |
|
|
|
|
|
auto int_opts = means3D.options().dtype(torch::kInt32); |
|
|
auto float_opts = means3D.options().dtype(torch::kFloat32); |
|
|
|
|
|
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); |
|
|
torch::Tensor out_flow = torch::full({2, H, W}, 0.0, float_opts); |
|
|
torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts); |
|
|
torch::Tensor out_T = torch::full({1, H, W}, 0.0, float_opts); |
|
|
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); |
|
|
torch::Tensor out_means3D = means3D.clone(); |
|
|
|
|
|
|
|
|
torch::Tensor accum_weights_ptr = torch::full({P}, 0, float_opts); |
|
|
torch::Tensor accum_weights_count = torch::full({P}, 0, int_opts); |
|
|
torch::Tensor accum_max_count = torch::full({P}, 0, float_opts); |
|
|
|
|
|
|
|
|
torch::Device device(torch::kCUDA); |
|
|
torch::TensorOptions options(torch::kByte); |
|
|
torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); |
|
|
torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); |
|
|
torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); |
|
|
std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer); |
|
|
std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer); |
|
|
std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer); |
|
|
|
|
|
int rendered = 0; |
|
|
if(P != 0) |
|
|
{ |
|
|
int M = 0; |
|
|
if(sh.size(0) != 0) |
|
|
{ |
|
|
M = sh.size(1); |
|
|
} |
|
|
|
|
|
rendered = CudaRasterizer::Rasterizer::forward( |
|
|
geomFunc, |
|
|
binningFunc, |
|
|
imgFunc, |
|
|
P, degree, degree_t, M, |
|
|
background.contiguous().data<float>(), |
|
|
W, H, |
|
|
means3D.contiguous().data<float>(), |
|
|
out_means3D.contiguous().data<float>(), |
|
|
sh.contiguous().data_ptr<float>(), |
|
|
colors.contiguous().data<float>(), |
|
|
flows.contiguous().data<float>(), |
|
|
opacity.contiguous().data<float>(), |
|
|
ts.contiguous().data_ptr<float>(), |
|
|
scales.contiguous().data_ptr<float>(), |
|
|
scales_t.contiguous().data_ptr<float>(), |
|
|
scale_modifier, |
|
|
rotations.contiguous().data_ptr<float>(), |
|
|
rotations_r.contiguous().data_ptr<float>(), |
|
|
cov3D_precomp.contiguous().data<float>(), |
|
|
viewmatrix.contiguous().data<float>(), |
|
|
projmatrix.contiguous().data<float>(), |
|
|
campos.contiguous().data<float>(), |
|
|
timestamp, |
|
|
time_duration, |
|
|
rot_4d, |
|
|
gaussian_dim, |
|
|
force_sh_3d, |
|
|
tan_fovx, |
|
|
tan_fovy, |
|
|
prefiltered, |
|
|
out_color.contiguous().data<float>(), |
|
|
out_flow.contiguous().data<float>(), |
|
|
out_depth.contiguous().data<float>(), |
|
|
out_T.contiguous().data<float>(), |
|
|
|
|
|
|
|
|
|
|
|
accum_weights_ptr.contiguous().data<float>(), |
|
|
accum_weights_count.contiguous().data<int>(), |
|
|
accum_max_count.contiguous().data<float>(), |
|
|
|
|
|
|
|
|
|
|
|
radii.contiguous().data<int>(), |
|
|
debug); |
|
|
} |
|
|
char* geo_ptr = reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()); |
|
|
CudaRasterizer::GeometryState geoState = CudaRasterizer::GeometryState::fromChunk(geo_ptr, P); |
|
|
|
|
|
torch::Tensor covs3D_com = torch::from_blob(geoState.cov3D, {P, 6}, float_opts); |
|
|
|
|
|
return std::make_tuple(rendered, out_color, out_flow, out_depth, out_T, accum_weights_ptr, accum_weights_count, accum_max_count, radii, geomBuffer, binningBuffer, imgBuffer, covs3D_com, out_means3D); |
|
|
} |
|
|
|
|
|
|
|
|
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> |
|
|
RasterizeGaussiansBackwardCUDA( |
|
|
const torch::Tensor& background, |
|
|
const torch::Tensor& means3D, |
|
|
const torch::Tensor& out_means3D, |
|
|
const torch::Tensor& radii, |
|
|
const torch::Tensor& colors, |
|
|
const torch::Tensor& flows_2d, |
|
|
const torch::Tensor& opacities, |
|
|
const torch::Tensor& ts, |
|
|
const torch::Tensor& scales, |
|
|
const torch::Tensor& scales_t, |
|
|
const torch::Tensor& rotations, |
|
|
const torch::Tensor& rotations_r, |
|
|
const float scale_modifier, |
|
|
const torch::Tensor& cov3D_precomp, |
|
|
const torch::Tensor& viewmatrix, |
|
|
const torch::Tensor& projmatrix, |
|
|
const float tan_fovx, |
|
|
const float tan_fovy, |
|
|
const torch::Tensor& dL_dout_color, |
|
|
const torch::Tensor& dL_dout_depth, |
|
|
const torch::Tensor& dL_dout_mask, |
|
|
const torch::Tensor& dL_dout_flow, |
|
|
const torch::Tensor& sh, |
|
|
const int degree, |
|
|
const int degree_t, |
|
|
const torch::Tensor& campos, |
|
|
const float timestamp, |
|
|
const float time_duration, |
|
|
const bool rot_4d, |
|
|
const int gaussian_dim, |
|
|
const bool force_sh_3d, |
|
|
const torch::Tensor& geomBuffer, |
|
|
const int R, |
|
|
const torch::Tensor& binningBuffer, |
|
|
const torch::Tensor& imageBuffer, |
|
|
const bool debug) |
|
|
{ |
|
|
const int P = means3D.size(0); |
|
|
const int H = dL_dout_color.size(1); |
|
|
const int W = dL_dout_color.size(2); |
|
|
|
|
|
int M = 0; |
|
|
if(sh.size(0) != 0) |
|
|
{ |
|
|
M = sh.size(1); |
|
|
} |
|
|
|
|
|
torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); |
|
|
torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); |
|
|
torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); |
|
|
torch::Tensor dL_dflows = torch::zeros({P, 2}, means3D.options()); |
|
|
torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); |
|
|
torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); |
|
|
torch::Tensor dL_dts = torch::zeros({P, 1}, means3D.options()); |
|
|
torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); |
|
|
torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); |
|
|
torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); |
|
|
torch::Tensor dL_dscales_t = torch::zeros({P, 1}, means3D.options()); |
|
|
torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); |
|
|
torch::Tensor dL_drotations_r = torch::zeros({P, 4}, means3D.options()); |
|
|
|
|
|
if(P != 0) |
|
|
{ |
|
|
CudaRasterizer::Rasterizer::backward(P, degree, degree_t, M, R, |
|
|
background.contiguous().data<float>(), |
|
|
W, H, |
|
|
|
|
|
out_means3D.contiguous().data<float>(), |
|
|
sh.contiguous().data<float>(), |
|
|
colors.contiguous().data<float>(), |
|
|
flows_2d.contiguous().data<float>(), |
|
|
opacities.contiguous().data<float>(), |
|
|
ts.contiguous().data<float>(), |
|
|
scales.data_ptr<float>(), |
|
|
scales_t.data_ptr<float>(), |
|
|
scale_modifier, |
|
|
rotations.data_ptr<float>(), |
|
|
rotations_r.data_ptr<float>(), |
|
|
cov3D_precomp.contiguous().data<float>(), |
|
|
viewmatrix.contiguous().data<float>(), |
|
|
projmatrix.contiguous().data<float>(), |
|
|
campos.contiguous().data<float>(), |
|
|
timestamp, |
|
|
time_duration, |
|
|
rot_4d, |
|
|
gaussian_dim, |
|
|
force_sh_3d, |
|
|
tan_fovx, |
|
|
tan_fovy, |
|
|
radii.contiguous().data<int>(), |
|
|
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()), |
|
|
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()), |
|
|
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()), |
|
|
dL_dout_color.contiguous().data<float>(), |
|
|
dL_dout_depth.contiguous().data<float>(), |
|
|
dL_dout_mask.contiguous().data<float>(), |
|
|
dL_dout_flow.contiguous().data<float>(), |
|
|
dL_dmeans2D.contiguous().data<float>(), |
|
|
dL_dconic.contiguous().data<float>(), |
|
|
dL_dopacity.contiguous().data<float>(), |
|
|
dL_dcolors.contiguous().data<float>(), |
|
|
dL_dmeans3D.contiguous().data<float>(), |
|
|
dL_dcov3D.contiguous().data<float>(), |
|
|
dL_dsh.contiguous().data<float>(), |
|
|
dL_dflows.contiguous().data<float>(), |
|
|
dL_dts.contiguous().data<float>(), |
|
|
dL_dscales.contiguous().data<float>(), |
|
|
dL_dscales_t.contiguous().data<float>(), |
|
|
dL_drotations.contiguous().data<float>(), |
|
|
dL_drotations_r.contiguous().data<float>(), |
|
|
debug); |
|
|
} |
|
|
|
|
|
return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, |
|
|
dL_dsh, dL_dflows, dL_dts, dL_dscales, dL_dscales_t, dL_drotations, dL_drotations_r); |
|
|
} |
|
|
|
|
|
// Returns a boolean tensor with one entry per row of `means3D`.
//
// The tensor is created with every entry false (same device as `means3D`)
// and, when there is at least one point, is handed to
// CudaRasterizer::Rasterizer::markVisible together with the point positions
// and the view / projection matrices so that it can be filled in.
torch::Tensor markVisible(
    torch::Tensor& means3D,
    torch::Tensor& viewmatrix,
    torch::Tensor& projmatrix)
{
    const int num_points = means3D.size(0);

    const auto bool_opts = means3D.options().dtype(at::kBool);
    torch::Tensor present = torch::full({num_points}, false, bool_opts);

    // Nothing to test: return the empty mask without touching the rasterizer.
    if (num_points == 0)
    {
        return present;
    }

    CudaRasterizer::Rasterizer::markVisible(
        num_points,
        means3D.contiguous().data_ptr<float>(),
        viewmatrix.contiguous().data_ptr<float>(),
        projmatrix.contiguous().data_ptr<float>(),
        present.contiguous().data_ptr<bool>());

    return present;
}