# OMG4 / gaussian_renderer / diff_gaussian_rasterization.py
# MinShirley
# {update code}
# 11f5b0a
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
from typing import NamedTuple
import torch.nn as nn
import torch
# from . import _C
import os
from torch.utils.cpp_extension import load
# Directory holding the CUDA/C++ sources: the "diff-gaussian-rasterization"
# folder that sits next to the package directory containing this file
# (i.e. two dirname() hops up from this module).
parent_dir = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "diff-gaussian-rasterization",
)

# Build the extension through torch.utils.cpp_extension.load at import time
# and bind the resulting module to `_C`, replacing the pre-built
# `from . import _C` that is commented out above.
_C = load(
    name='diff_gaussian_rasterization',
    sources=[
        os.path.join(parent_dir, relative_path)
        for relative_path in (
            "cuda_rasterizer/rasterizer_impl.cu",
            "cuda_rasterizer/forward.cu",
            "cuda_rasterizer/backward.cu",
            "rasterize_points.cu",
            "ext.cpp",
        )
    ],
    # Put the vendored GLM headers on the include path and keep debug symbols.
    extra_cuda_cflags=[
        "-I " + os.path.join(parent_dir, "third_party/glm/"),
        "-g",
    ],
    verbose=True,
)
def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple mirroring *input_tuple* with every tensor copied to CPU.

    Each ``torch.Tensor`` element is replaced by ``item.cpu().clone()``, so the
    result does not share storage with the input. Elements that are not
    tensors are placed in the result unchanged (same object).
    """

    def _snapshot(entry):
        # Only tensors need an independent copy; everything else passes through.
        if isinstance(entry, torch.Tensor):
            return entry.cpu().clone()
        return entry

    return tuple(map(_snapshot, input_tuple))
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    flow_2d,
    opacities,
    ts,
    scales,
    scales_t,
    rotations,
    rotations_r,
    cov3Ds_precomp,
    raster_settings,
):
    """Run the differentiable rasterizer through ``_RasterizeGaussians``.

    This is a plain function front-end for ``_RasterizeGaussians.apply``: all
    thirteen arguments are handed over positionally, in the order of this
    signature, and whatever ``apply`` returns is returned unchanged.
    """
    # torch.autograd.Function.apply only takes positional arguments, so the
    # order of this tuple must match _RasterizeGaussians.forward (after ctx).
    forwarded = (
        means3D,
        means2D,
        sh,
        colors_precomp,
        flow_2d,
        opacities,
        ts,
        scales,
        scales_t,
        rotations,
        rotations_r,
        cov3Ds_precomp,
        raster_settings,
    )
    return _RasterizeGaussians.apply(*forwarded)
class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge to the compiled ``_C`` rasterizer.

    ``forward`` calls ``_C.rasterize_gaussians`` and ``backward`` calls
    ``_C.rasterize_gaussians_backward``. Both build a positional argument
    tuple whose order is dictated by the extension's signatures, so entries
    must not be reordered.

    When ``raster_settings.debug`` is truthy, a CPU copy of the arguments is
    taken before the call and written with ``torch.save`` to
    ``snapshot_fw.dump`` / ``snapshot_bw.dump`` if the call raises.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        flow_2d,
        opacities,
        ts,
        scales,
        scales_t,
        rotations,
        rotations_r,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Rasterize and stash what ``backward`` needs on ``ctx``.

        ``means2D`` is not passed to the extension; it only exists as an
        autograd input so that ``backward`` can return a gradient for it.

        Returns the 9-tuple ``(color, radii, depth, 1 - T, flow, covs_com,
        accum_weights_ptr, accum_weights_count, accum_max_count)``.
        NOTE(review): ``T`` is presumably the transmittance (so ``1 - T`` is
        alpha; the matching ``backward`` parameter is named ``grad_alpha``)
        -- confirm against the CUDA code.
        """
        cfg = raster_settings

        # Positional layout expected by _C.rasterize_gaussians.
        kernel_args = (
            cfg.bg,
            means3D,
            colors_precomp,
            flow_2d,
            opacities,
            ts,
            scales,
            scales_t,
            rotations,
            rotations_r,
            cfg.scale_modifier,
            cov3Ds_precomp,
            cfg.viewmatrix,
            cfg.projmatrix,
            cfg.tanfovx,
            cfg.tanfovy,
            cfg.image_height,
            cfg.image_width,
            sh,
            cfg.sh_degree,
            cfg.sh_degree_t,
            cfg.campos,
            cfg.timestamp,
            cfg.time_duration,
            cfg.rot_4d,
            cfg.gaussian_dim,
            cfg.force_sh_3d,
            cfg.prefiltered,
            cfg.debug,
        )

        # In debug mode, copy the inputs *before* the call so that a failing
        # kernel cannot corrupt what ends up in the snapshot.
        snapshot = cpu_deep_copy_tuple(kernel_args) if cfg.debug else None
        try:
            (num_rendered, color, flow, depth, T,
             accum_weights_ptr, accum_weights_count, accum_max_count,
             radii, geom_buffer, binning_buffer, img_buffer,
             covs_com, out_means3D) = _C.rasterize_gaussians(*kernel_args)
        except Exception as ex:
            if snapshot is None:
                # Not debugging: let the error propagate untouched.
                raise
            torch.save(snapshot, "snapshot_fw.dump")
            print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
            raise ex

        # State for backward: settings and the rendered count are plain
        # attributes, tensors go through save_for_backward.
        ctx.raster_settings = cfg
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp, means3D, out_means3D, scales, rotations,
            cov3Ds_precomp, radii, sh,
            flow_2d, opacities, ts, scales_t, rotations_r,
            geom_buffer, binning_buffer, img_buffer,
        )

        return (
            color,
            radii,
            depth,
            1 - T,
            flow,
            covs_com,
            accum_weights_ptr,
            accum_weights_count,
            accum_max_count,
        )

    @staticmethod
    def backward(ctx, grad_out_color, grad_radii, grad_depth, grad_alpha, grad_flow, grad_covs_com, a, b, c):
        """Compute input gradients via ``_C.rasterize_gaussians_backward``.

        One incoming gradient arrives per ``forward`` output. Only the color,
        depth, alpha and flow gradients are handed to the extension;
        ``grad_radii``, ``grad_covs_com`` and ``a``/``b``/``c`` (the three
        accumulation outputs) are ignored.

        Returns one entry per ``forward`` input (after ``ctx``), in the same
        order; the last entry, for ``raster_settings``, is ``None``.
        """
        cfg = ctx.raster_settings
        num_rendered = ctx.num_rendered
        (colors_precomp, means3D, out_means3D, scales, rotations,
         cov3Ds_precomp, radii, sh,
         flow_2d, opacities, ts, scales_t, rotations_r,
         geom_buffer, binning_buffer, img_buffer) = ctx.saved_tensors

        # Positional layout expected by _C.rasterize_gaussians_backward.
        kernel_args = (
            cfg.bg,
            means3D,
            out_means3D,
            radii,
            colors_precomp,
            flow_2d,
            opacities,
            ts,
            scales,
            scales_t,
            rotations,
            rotations_r,
            cfg.scale_modifier,
            cov3Ds_precomp,
            cfg.viewmatrix,
            cfg.projmatrix,
            cfg.tanfovx,
            cfg.tanfovy,
            grad_out_color,
            grad_depth,
            grad_alpha,
            grad_flow,
            sh,
            cfg.sh_degree,
            cfg.sh_degree_t,
            cfg.campos,
            cfg.timestamp,
            cfg.time_duration,
            cfg.rot_4d,
            cfg.gaussian_dim,
            cfg.force_sh_3d,
            geom_buffer,
            num_rendered,
            binning_buffer,
            img_buffer,
            cfg.debug,
        )

        # Same debug protocol as forward: snapshot first, dump only on failure.
        snapshot = cpu_deep_copy_tuple(kernel_args) if cfg.debug else None
        try:
            (grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D,
             grad_cov3Ds_precomp, grad_sh,
             grad_flows, grad_ts, grad_scales, grad_scales_t,
             grad_rotations, grad_rotations_r) = _C.rasterize_gaussians_backward(*kernel_args)
        except Exception as ex:
            if snapshot is None:
                # Not debugging: let the error propagate untouched.
                raise
            torch.save(snapshot, "snapshot_bw.dump")
            print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
            raise ex

        # Re-order the extension's outputs to line up with forward's inputs.
        return (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_flows,
            grad_opacities,
            grad_ts,
            grad_scales,
            grad_scales_t,
            grad_rotations,
            grad_rotations_r,
            grad_cov3Ds_precomp,
            None,  # raster_settings carries no gradient
        )
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-render settings consumed by the rasterizer.

    ``_RasterizeGaussians`` reads these fields by name and forwards them to the
    compiled extension; ``GaussianRasterizer.markVisible`` reads ``viewmatrix``
    and ``projmatrix``. Field order is part of the tuple layout and must stay
    as declared.
    """

    # Output resolution; forwarded to the forward pass only.
    image_height: int
    image_width: int

    # Camera description.
    # NOTE(review): tanfovx/tanfovy are presumably tan(fov / 2) per axis --
    # confirm against the code that builds these settings.
    tanfovx: float
    tanfovy: float
    bg: torch.Tensor
    scale_modifier: float
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor

    # Spherical-harmonics degrees (sh_degree_t presumably the temporal one -- TODO confirm).
    sh_degree: int
    sh_degree_t: int
    campos: torch.Tensor

    # Time-dependent rendering parameters, passed through to the extension as-is.
    timestamp: float
    time_duration: float
    rot_4d: bool  # also gates the rotations_r/scales_t/ts check in GaussianRasterizer.forward
    gaussian_dim: int
    force_sh_3d: bool

    prefiltered: bool  # forwarded to the forward pass only
    debug: bool  # truthy => dump a CPU snapshot of the arguments when a kernel call raises
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front-end that validates inputs and calls ``rasterize_gaussians``."""

    def __init__(self, raster_settings):
        """Store *raster_settings* for use by ``forward`` and ``markVisible``."""
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return the result of ``_C.mark_visible`` for *positions*.

        The call is made under ``torch.no_grad()`` with the stored view and
        projection matrices. Per the original comment, the result is a boolean
        visibility mask based on camera frustum culling.
        """
        settings = self.raster_settings
        with torch.no_grad():
            return _C.mark_visible(
                positions,
                settings.viewmatrix,
                settings.projmatrix,
            )

    def forward(
        self,
        means3D,
        means2D,
        opacities,
        shs=None,
        colors_precomp=None,
        flow_2d=None,
        ts=None,
        scales=None,
        scales_t=None,
        rotations=None,
        rotations_r=None,
        cov3D_precomp=None,
    ):
        """Validate the argument combination and rasterize.

        Raises:
            Exception: unless exactly one of ``shs`` / ``colors_precomp`` is
                given; unless exactly one of (``scales`` and ``rotations``
                together) / ``cov3D_precomp`` is given; or when
                ``raster_settings.rot_4d`` is truthy, ``cov3D_precomp`` is
                ``None`` and any of ``rotations_r``, ``scales_t``, ``ts`` is
                missing.

        Returns:
            Whatever ``rasterize_gaussians`` returns.
        """
        raster_settings = self.raster_settings

        # Colour source: SH coefficients XOR precomputed colours.
        if (shs is None) == (colors_precomp is None):
            raise Exception('Please provide excatly one of either SHs or precomputed colors!')

        # Covariance source: a complete scale/rotation pair XOR a precomputed covariance.
        has_cov = cov3D_precomp is not None
        has_both = scales is not None and rotations is not None
        has_either = scales is not None or rotations is not None
        if (not has_both and not has_cov) or (has_either and has_cov):
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # The rot_4d path without a precomputed covariance needs all three extra inputs.
        if raster_settings.rot_4d and not has_cov and any(
            extra is None for extra in (rotations_r, scales_t, ts)
        ):
            raise Exception(
                'Please provide exactly rotations_r and scales_t and ts if rot_4d and cov3D_precomp is None!')

        def _or_empty(value):
            # Omitted optional inputs are replaced by an empty tensor, since
            # the call below always passes a tensor in every slot.
            return torch.Tensor([]) if value is None else value

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            _or_empty(shs),
            _or_empty(colors_precomp),
            _or_empty(flow_2d),
            opacities,
            _or_empty(ts),
            _or_empty(scales),
            _or_empty(scales_t),
            _or_empty(rotations),
            _or_empty(rotations_r),
            _or_empty(cov3D_precomp),
            raster_settings,
        )