// OMG4 / diff-gaussian-rasterization / rasterize_points.cu
// MinShirley
// {update code}
// 11f5b0a
/*
* Copyright (C) 2023, Inria
* GRAPHDECO research group, https://team.inria.fr/graphdeco
* All rights reserved.
*
* This software is free for non-commercial, research and evaluation use
* under the terms of the LICENSE.md file.
*
* For inquiries contact george.drettakis@inria.fr
*/
#include <math.h>
#include <torch/extension.h>
#include <cstdio>
#include <sstream>
#include <iostream>
#include <tuple>
#include <stdio.h>
#include <cuda_runtime_api.h>
#include <memory>
#include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/rasterizer.h"
#include "cuda_rasterizer/rasterizer_impl.h"
#include <fstream>
#include <string>
#include <functional>
// Builds a callback that resizes `t` to N elements (1-D) and returns a raw
// char* to its data. The callback captures `t` by reference, so `t` must
// outlive every invocation of the returned std::function.
std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
    return [&t](size_t N) -> char* {
        const long long numel = static_cast<long long>(N);
        t.resize_({numel});
        void* raw = t.contiguous().data_ptr();
        return reinterpret_cast<char*>(raw);
    };
}
/**
 * Forward rasterization entry point exposed to the Python side.
 *
 * Validates that `means3D` is (num_points, 3), allocates the output tensors,
 * and forwards every argument to CudaRasterizer::Rasterizer::forward. The
 * three scratch buffers (geometry / binning / image) start empty and are
 * resized on demand by the rasterizer through the resize callbacks.
 *
 * Returns, in order:
 *   rendered             int returned by Rasterizer::forward (0 when P == 0)
 *   out_color            (NUM_CHANNELS, H, W) float
 *   out_flow             (2, H, W) float
 *   out_depth            (1, H, W) float
 *   out_T                (1, H, W) float
 *   accum_weights_ptr    (P) float
 *   accum_weights_count  (P) int32
 *   accum_max_count      (P) float
 *   radii                (P) int32
 *   geomBuffer, binningBuffer, imgBuffer   byte scratch buffers
 *   covs3D_com           (P, 6) float, an owning copy of the cov3D region of
 *                        the geometry buffer
 *   out_means3D          clone of means3D that is handed to the rasterizer
 *
 * Throws c10::Error if `means3D` does not have shape (num_points, 3).
 */
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
    const torch::Tensor& background,
    const torch::Tensor& means3D,
    const torch::Tensor& colors,
    const torch::Tensor& flows,
    const torch::Tensor& opacity,
    const torch::Tensor& ts,
    const torch::Tensor& scales,
    const torch::Tensor& scales_t,
    const torch::Tensor& rotations,
    const torch::Tensor& rotations_r,
    const float scale_modifier,
    const torch::Tensor& cov3D_precomp,
    const torch::Tensor& viewmatrix,
    const torch::Tensor& projmatrix,
    const float tan_fovx,
    const float tan_fovy,
    const int image_height,
    const int image_width,
    const torch::Tensor& sh,
    const int degree,
    const int degree_t,
    const torch::Tensor& campos,
    const float timestamp,
    const float time_duration,
    const bool rot_4d,
    const int gaussian_dim,
    const bool force_sh_3d,
    const bool prefiltered,
    const bool debug)
{
    TORCH_CHECK(means3D.ndimension() == 2 && means3D.size(1) == 3,
                "means3D must have dimensions (num_points, 3)");

    const int P = means3D.size(0);
    const int H = image_height;
    const int W = image_width;

    auto int_opts = means3D.options().dtype(torch::kInt32);
    auto float_opts = means3D.options().dtype(torch::kFloat32);

    // Zero-initialised outputs; they are returned as-is when P == 0.
    torch::Tensor out_color = torch::zeros({NUM_CHANNELS, H, W}, float_opts);
    torch::Tensor out_flow = torch::zeros({2, H, W}, float_opts);
    torch::Tensor out_depth = torch::zeros({1, H, W}, float_opts);
    torch::Tensor out_T = torch::zeros({1, H, W}, float_opts);
    torch::Tensor radii = torch::zeros({P}, int_opts);
    torch::Tensor out_means3D = means3D.clone();

    // OMG: per-point accumulators filled by the rasterizer.
    torch::Tensor accum_weights_ptr = torch::zeros({P}, float_opts);
    torch::Tensor accum_weights_count = torch::zeros({P}, int_opts);
    torch::Tensor accum_max_count = torch::zeros({P}, float_opts);

    // NOTE(review): the scratch buffers are created on the current CUDA device
    // (no device index given), not on means3D's device - confirm this is the
    // intended behaviour for multi-GPU use.
    torch::Device device(torch::kCUDA);
    torch::TensorOptions options(torch::kByte);
    torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
    torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
    torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
    std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
    std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
    std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);

    int rendered = 0;
    torch::Tensor covs3D_com;
    if (P != 0)
    {
        int M = 0;
        if (sh.size(0) != 0)
        {
            M = sh.size(1);
        }

        // The .contiguous() temporaries live until the end of the call
        // expression, so the raw pointers stay valid for the whole forward().
        rendered = CudaRasterizer::Rasterizer::forward(
            geomFunc,
            binningFunc,
            imgFunc,
            P, degree, degree_t, M,
            background.contiguous().data_ptr<float>(),
            W, H,
            means3D.contiguous().data_ptr<float>(),
            out_means3D.contiguous().data_ptr<float>(),
            sh.contiguous().data_ptr<float>(),
            colors.contiguous().data_ptr<float>(),
            flows.contiguous().data_ptr<float>(),
            opacity.contiguous().data_ptr<float>(),
            ts.contiguous().data_ptr<float>(),
            scales.contiguous().data_ptr<float>(),
            scales_t.contiguous().data_ptr<float>(),
            scale_modifier,
            rotations.contiguous().data_ptr<float>(),
            rotations_r.contiguous().data_ptr<float>(),
            cov3D_precomp.contiguous().data_ptr<float>(),
            viewmatrix.contiguous().data_ptr<float>(),
            projmatrix.contiguous().data_ptr<float>(),
            campos.contiguous().data_ptr<float>(),
            timestamp,
            time_duration,
            rot_4d,
            gaussian_dim,
            force_sh_3d,
            tan_fovx,
            tan_fovy,
            prefiltered,
            out_color.contiguous().data_ptr<float>(),
            out_flow.contiguous().data_ptr<float>(),
            out_depth.contiguous().data_ptr<float>(),
            out_T.contiguous().data_ptr<float>(),
            // OMG: accum_weights outputs added
            accum_weights_ptr.contiguous().data_ptr<float>(),
            accum_weights_count.contiguous().data_ptr<int>(),
            accum_max_count.contiguous().data_ptr<float>(),
            radii.contiguous().data_ptr<int>(),
            debug);

        // Locate the cov3D region inside the geometry buffer. from_blob only
        // wraps the pointer without owning it, so clone() to give the returned
        // tensor its own storage; otherwise it would dangle as soon as
        // geomBuffer is released.
        char* geo_ptr = reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr());
        CudaRasterizer::GeometryState geoState = CudaRasterizer::GeometryState::fromChunk(geo_ptr, P);
        covs3D_com = torch::from_blob(geoState.cov3D, {P, 6}, float_opts).clone();
    }
    else
    {
        // No points: the geometry buffer was never resized, so there is no
        // chunk to interpret. Return an empty (0, 6) tensor instead.
        covs3D_com = torch::empty({0, 6}, float_opts);
    }

    // Return values were extended with the accumulators (done).
    return std::make_tuple(rendered, out_color, out_flow, out_depth, out_T, accum_weights_ptr, accum_weights_count, accum_max_count, radii, geomBuffer, binningBuffer, imgBuffer, covs3D_com, out_means3D);
}
/**
 * Backward rasterization entry point exposed to the Python side.
 *
 * Allocates zero-filled gradient tensors (using means3D's tensor options) and
 * forwards everything to CudaRasterizer::Rasterizer::backward. `means3D` is
 * only used for the point count P and the tensor options; the positions
 * handed to the rasterizer come from `out_means3D`. H and W are read from
 * dimensions 1 and 2 of `dL_dout_color`.
 *
 * `R` is passed straight through to Rasterizer::backward (presumably the
 * `rendered` count returned by the forward pass - confirm against caller).
 *
 * Returns, in order: dL_dmeans2D (P,3), dL_dcolors (P,NUM_CHANNELS),
 * dL_dopacity (P,1), dL_dmeans3D (P,3), dL_dcov3D (P,6), dL_dsh (P,M,3),
 * dL_dflows (P,2), dL_dts (P,1), dL_dscales (P,3), dL_dscales_t (P,1),
 * dL_drotations (P,4), dL_drotations_r (P,4). All are zeros when P == 0.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
    const torch::Tensor& background,
    const torch::Tensor& means3D,
    const torch::Tensor& out_means3D,
    const torch::Tensor& radii,
    const torch::Tensor& colors,
    const torch::Tensor& flows_2d,
    const torch::Tensor& opacities,
    const torch::Tensor& ts,
    const torch::Tensor& scales,
    const torch::Tensor& scales_t,
    const torch::Tensor& rotations,
    const torch::Tensor& rotations_r,
    const float scale_modifier,
    const torch::Tensor& cov3D_precomp,
    const torch::Tensor& viewmatrix,
    const torch::Tensor& projmatrix,
    const float tan_fovx,
    const float tan_fovy,
    const torch::Tensor& dL_dout_color,
    const torch::Tensor& dL_dout_depth,
    const torch::Tensor& dL_dout_mask,
    const torch::Tensor& dL_dout_flow,
    const torch::Tensor& sh,
    const int degree,
    const int degree_t,
    const torch::Tensor& campos,
    const float timestamp,
    const float time_duration,
    const bool rot_4d,
    const int gaussian_dim,
    const bool force_sh_3d,
    const torch::Tensor& geomBuffer,
    const int R,
    const torch::Tensor& binningBuffer,
    const torch::Tensor& imageBuffer,
    const bool debug)
{
    const int P = means3D.size(0);
    const int H = dL_dout_color.size(1);
    const int W = dL_dout_color.size(2);

    int M = 0;
    if (sh.size(0) != 0)
    {
        M = sh.size(1);
    }

    const auto grad_opts = means3D.options();
    torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, grad_opts);
    torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, grad_opts);
    torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, grad_opts);
    torch::Tensor dL_dflows = torch::zeros({P, 2}, grad_opts);
    // dL_dconic is scratch for the rasterizer; it is not part of the result.
    torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, grad_opts);
    torch::Tensor dL_dopacity = torch::zeros({P, 1}, grad_opts);
    torch::Tensor dL_dts = torch::zeros({P, 1}, grad_opts);
    torch::Tensor dL_dcov3D = torch::zeros({P, 6}, grad_opts);
    torch::Tensor dL_dsh = torch::zeros({P, M, 3}, grad_opts);
    torch::Tensor dL_dscales = torch::zeros({P, 3}, grad_opts);
    torch::Tensor dL_dscales_t = torch::zeros({P, 1}, grad_opts);
    torch::Tensor dL_drotations = torch::zeros({P, 4}, grad_opts);
    torch::Tensor dL_drotations_r = torch::zeros({P, 4}, grad_opts);

    if (P != 0)
    {
        // Every input goes through .contiguous() before its raw pointer is
        // taken (scales / scales_t / rotations / rotations_r included), so a
        // strided view is never read with the wrong memory layout. The
        // temporaries live until the end of the call expression.
        CudaRasterizer::Rasterizer::backward(P, degree, degree_t, M, R,
            background.contiguous().data_ptr<float>(),
            W, H,
            // out_means3D is passed here in place of the raw means3D.
            out_means3D.contiguous().data_ptr<float>(),
            sh.contiguous().data_ptr<float>(),
            colors.contiguous().data_ptr<float>(),
            flows_2d.contiguous().data_ptr<float>(),
            opacities.contiguous().data_ptr<float>(),
            ts.contiguous().data_ptr<float>(),
            scales.contiguous().data_ptr<float>(),
            scales_t.contiguous().data_ptr<float>(),
            scale_modifier,
            rotations.contiguous().data_ptr<float>(),
            rotations_r.contiguous().data_ptr<float>(),
            cov3D_precomp.contiguous().data_ptr<float>(),
            viewmatrix.contiguous().data_ptr<float>(),
            projmatrix.contiguous().data_ptr<float>(),
            campos.contiguous().data_ptr<float>(),
            timestamp,
            time_duration,
            rot_4d,
            gaussian_dim,
            force_sh_3d,
            tan_fovx,
            tan_fovy,
            radii.contiguous().data_ptr<int>(),
            reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
            reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
            reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
            dL_dout_color.contiguous().data_ptr<float>(),
            dL_dout_depth.contiguous().data_ptr<float>(),
            dL_dout_mask.contiguous().data_ptr<float>(),
            dL_dout_flow.contiguous().data_ptr<float>(),
            dL_dmeans2D.contiguous().data_ptr<float>(),
            dL_dconic.contiguous().data_ptr<float>(),
            dL_dopacity.contiguous().data_ptr<float>(),
            dL_dcolors.contiguous().data_ptr<float>(),
            dL_dmeans3D.contiguous().data_ptr<float>(),
            dL_dcov3D.contiguous().data_ptr<float>(),
            dL_dsh.contiguous().data_ptr<float>(),
            dL_dflows.contiguous().data_ptr<float>(),
            dL_dts.contiguous().data_ptr<float>(),
            dL_dscales.contiguous().data_ptr<float>(),
            dL_dscales_t.contiguous().data_ptr<float>(),
            dL_drotations.contiguous().data_ptr<float>(),
            dL_drotations_r.contiguous().data_ptr<float>(),
            debug);
    }

    return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D,
        dL_dsh, dL_dflows, dL_dts, dL_dscales, dL_dscales_t, dL_drotations, dL_drotations_r);
}
// Returns a bool tensor with one entry per row of `means3D`, filled in by
// CudaRasterizer::Rasterizer::markVisible using the given view and projection
// matrices. When there are no points the (empty) mask is returned without
// calling into the rasterizer.
torch::Tensor markVisible(
    torch::Tensor& means3D,
    torch::Tensor& viewmatrix,
    torch::Tensor& projmatrix)
{
    const int num_points = means3D.size(0);
    const auto mask_opts = means3D.options().dtype(at::kBool);
    torch::Tensor present = torch::full({num_points}, false, mask_opts);

    if (num_points == 0)
    {
        return present;
    }

    CudaRasterizer::Rasterizer::markVisible(
        num_points,
        means3D.contiguous().data_ptr<float>(),
        viewmatrix.contiguous().data_ptr<float>(),
        projmatrix.contiguous().data_ptr<float>(),
        present.contiguous().data_ptr<bool>());

    return present;
}