# Copyright © 2025, Adobe Inc. and its licensors. All rights reserved.
#
# This file is licensed under the Adobe Research License. You may obtain a copy
# of the license at https://raw.githubusercontent.com/adobe-research/FaceLift/main/LICENSE.md

import math
import os

import cv2
import matplotlib
import numpy as np
import torch
from diff_gaussian_rasterization import (
    GaussianRasterizationSettings,
    GaussianRasterizer,
)
from einops import rearrange
from plyfile import PlyData, PlyElement
from torch import nn
from collections import OrderedDict

import videoio


@torch.no_grad()
def get_turntable_cameras(
    hfov=50,
    num_views=8,
    w=384,
    h=384,
    radius=2.7,
    elevation=20,
    up_vector=np.array([0, 0, 1]),
):
    fx = w / (2 * np.tan(np.deg2rad(hfov) / 2.0))
    fy = fx
    cx, cy = w / 2.0, h / 2.0
    fxfycxcy = (
        np.array([fx, fy, cx, cy]).reshape(1, 4).repeat(num_views, axis=0)
    )  # [num_views, 4]

    # azimuths = np.linspace(0, 360, num_views, endpoint=False)
    azimuths = np.linspace(270, 630, num_views, endpoint=False)
    elevations = np.ones_like(azimuths) * elevation

    c2ws = []
    for elev, azim in zip(elevations, azimuths):
        elev, azim = np.deg2rad(elev), np.deg2rad(azim)
        z = radius * np.sin(elev)
        base = radius * np.cos(elev)
        x = base * np.cos(azim)
        y = base * np.sin(azim)
        cam_pos = np.array([x, y, z])

        forward = -cam_pos / np.linalg.norm(cam_pos)
        right = np.cross(forward, up_vector)
        right = right / np.linalg.norm(right)
        up = np.cross(right, forward)
        up = up / np.linalg.norm(up)

        R = np.stack((right, -up, forward), axis=1)
        c2w = np.eye(4)
        c2w[:3, :4] = np.concatenate((R, cam_pos[:, None]), axis=1)
        c2ws.append(c2w)
    c2ws = np.stack(c2ws, axis=0)  # [num_views, 4, 4]
    return w, h, num_views, fxfycxcy, c2ws


def imageseq2video(images, filename, fps=24):
    # if images is uint8, convert to float32
    if images.dtype == np.uint8:
        images = images.astype(np.float32) / 255.0
    videoio.videosave(filename, images, lossless=True, preset="veryfast", fps=fps)
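

# A minimal sanity-check sketch (not called by the pipeline): each turntable
# pose should have an orthonormal rotation block and sit exactly `radius`
# from the origin. Values below assume the defaults of get_turntable_cameras.
def _demo_turntable_cameras():
    w, h, v, fxfycxcy, c2ws = get_turntable_cameras(num_views=4)
    assert fxfycxcy.shape == (4, 4) and c2ws.shape == (4, 4, 4)
    for c2w in c2ws:
        R, t = c2w[:3, :3], c2w[:3, 3]
        assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)  # orthonormal rotation
        assert np.isclose(np.linalg.norm(t), 2.7, atol=1e-6)  # on the orbit sphere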


# copied from: utils.general_utils
def strip_lowerdiag(L):
    # Pack the 6 unique entries of symmetric 3x3 matrices in the order
    # (xx, xy, xz, yy, yz, zz)
    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=L.device)
    uncertainty[:, 0] = L[:, 0, 0]
    uncertainty[:, 1] = L[:, 0, 1]
    uncertainty[:, 2] = L[:, 0, 2]
    uncertainty[:, 3] = L[:, 1, 1]
    uncertainty[:, 4] = L[:, 1, 2]
    uncertainty[:, 5] = L[:, 2, 2]
    return uncertainty


def strip_symmetric(sym):
    return strip_lowerdiag(sym)


def build_rotation(r):
    # Normalize quaternions (r, x, y, z) and convert them to rotation matrices
    norm = torch.sqrt(
        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
    )
    q = r / norm[:, None]
    R = torch.zeros((q.size(0), 3, 3), device=r.device)
    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]
    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - r * z)
    R[:, 0, 2] = 2 * (x * z + r * y)
    R[:, 1, 0] = 2 * (x * y + r * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - r * x)
    R[:, 2, 0] = 2 * (x * z - r * y)
    R[:, 2, 1] = 2 * (y * z + r * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R


def build_scaling_rotation(s, r):
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
    R = build_rotation(r)
    L[:, 0, 0] = s[:, 0]
    L[:, 1, 1] = s[:, 1]
    L[:, 2, 2] = s[:, 2]
    L = R @ L
    return L


def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
    L = build_scaling_rotation(scaling_modifier * scaling, rotation)
    actual_covariance = L @ L.transpose(1, 2)
    symm = strip_symmetric(actual_covariance)
    return symm
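

# A small illustrative check (a direct consequence of the math above, not part
# of the pipeline): with an identity quaternion, the packed covariance reduces
# to diag(s^2) in the (xx, xy, xz, yy, yz, zz) layout used by strip_lowerdiag.
def _demo_covariance_identity_rotation():
    s = torch.tensor([[0.1, 0.2, 0.3]])
    q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])  # identity rotation (r, x, y, z)
    cov = build_covariance_from_scaling_rotation(s, 1.0, q)
    expected = torch.tensor([[0.01, 0.0, 0.0, 0.04, 0.0, 0.09]])
    assert torch.allclose(cov, expected, atol=1e-6)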


# copied from: utils.sh_utils
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396,
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435,
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions using hardcoded SH polynomials.
    Works with torch/np/jnp. ... Can be 0 or more batch dimensions.

    Args:
        deg: int SH degree. Currently, 0-4 supported.
        sh: SH coefficients [..., C, (deg + 1) ** 2]
        dirs: unit directions [..., 3]

    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (
            result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
        )
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (
                result
                + C2[0] * xy * sh[..., 4]
                + C2[1] * yz * sh[..., 5]
                + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
                + C2[3] * xz * sh[..., 7]
                + C2[4] * (xx - yy) * sh[..., 8]
            )
            if deg > 2:
                result = (
                    result
                    + C3[0] * y * (3 * xx - yy) * sh[..., 9]
                    + C3[1] * xy * z * sh[..., 10]
                    + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
                    + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
                    + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
                    + C3[5] * z * (xx - yy) * sh[..., 14]
                    + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
                )
                if deg > 3:
                    result = (
                        result
                        + C4[0] * xy * (xx - yy) * sh[..., 16]
                        + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
                        + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
                        + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
                        + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
                        + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
                        + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
                        + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
                        + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]
                    )
    return result


def RGB2SH(rgb):
    return (rgb - 0.5) / C0


def SH2RGB(sh):
    return sh * C0 + 0.5
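

# A minimal sketch of the RGB <-> SH DC relationship: at degree 0, eval_sh
# reduces to C0 * sh[..., 0], so evaluating the coefficient produced by RGB2SH
# recovers the color up to the 0.5 offset that SH2RGB adds back.
def _demo_sh_roundtrip():
    rgb = torch.tensor([[0.2, 0.5, 0.9]])
    sh_dc = RGB2SH(rgb)  # [1, 3]
    dirs = torch.tensor([[0.0, 0.0, 1.0]])  # direction is irrelevant at degree 0
    out = eval_sh(0, sh_dc[..., None], dirs)  # coefficient axis last: [1, 3, 1]
    assert torch.allclose(out + 0.5, rgb, atol=1e-6)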


def create_video(image_folder, output_video_file, framerate=30):
    # Get all image file paths to a list.
    images = [img for img in os.listdir(image_folder) if img.endswith(".png")]
    images.sort()

    # Read the first image to know the height and width
    frame = cv2.imread(os.path.join(image_folder, images[0]))
    height, width, layers = frame.shape

    video = cv2.VideoWriter(
        output_video_file, cv2.VideoWriter_fourcc(*"mp4v"), framerate, (width, height)
    )

    # iterate over each image and add it to the video sequence
    for image in images:
        video.write(cv2.imread(os.path.join(image_folder, image)))

    cv2.destroyAllWindows()
    video.release()


class Camera(nn.Module):
    def __init__(self, C2W, fxfycxcy, h, w):
        """
        C2W: 4x4 camera-to-world matrix; opencv convention
        fxfycxcy: 4
        """
        super().__init__()
        self.C2W = C2W.clone().float()
        self.W2C = self.C2W.inverse()
        self.h = h
        self.w = w
        self.znear = 0.01
        self.zfar = 100.0
        fx, fy, cx, cy = fxfycxcy[0], fxfycxcy[1], fxfycxcy[2], fxfycxcy[3]
        self.tanfovX = w / (2 * fx)
        self.tanfovY = h / (2 * fy)

        def getProjectionMatrix(W, H, fx, fy, cx, cy, znear, zfar):
            P = torch.zeros(4, 4, device=fx.device)
            P[0, 0] = 2 * fx / W
            P[1, 1] = 2 * fy / H
            P[0, 2] = 2 * (cx / W) - 1
            P[1, 2] = 2 * (cy / H) - 1
            P[2, 2] = -(zfar + znear) / (zfar - znear)
            P[3, 2] = 1.0
            P[2, 3] = -(2 * zfar * znear) / (zfar - znear)
            return P

        self.world_view_transform = self.W2C.transpose(0, 1)
        self.projection_matrix = getProjectionMatrix(
            self.w, self.h, fx, fy, cx, cy, self.znear, self.zfar
        ).transpose(0, 1)
        self.full_proj_transform = (
            self.world_view_transform.unsqueeze(0).bmm(
                self.projection_matrix.unsqueeze(0)
            )
        ).squeeze(0)
        self.camera_center = self.C2W[:3, 3]
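

# A minimal usage sketch for Camera (illustrative only): build one turntable
# pose and check the matrices the constructor precomputes for the rasterizer.
def _demo_camera():
    w, h, _, fxfycxcy, c2ws = get_turntable_cameras(num_views=1)
    cam = Camera(
        C2W=torch.from_numpy(c2ws[0]).float(),
        fxfycxcy=torch.from_numpy(fxfycxcy[0]).float(),
        h=h,
        w=w,
    )
    assert cam.world_view_transform.shape == (4, 4)
    assert cam.full_proj_transform.shape == (4, 4)
    assert torch.allclose(cam.camera_center, cam.C2W[:3, 3])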


# modified from scene/gaussian_model.py
class GaussianModel:
    def setup_functions(self):
        self.scaling_activation = torch.exp
        self.inv_scaling_activation = torch.log
        self.rotation_activation = torch.nn.functional.normalize
        self.opacity_activation = torch.sigmoid
        self.covariance_activation = build_covariance_from_scaling_rotation

    def __init__(self, sh_degree: int, scaling_modifier=None):
        self.sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        if self.sh_degree > 0:
            self._features_rest = torch.empty(0)
        else:
            self._features_rest = None
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.setup_functions()
        self.scaling_modifier = scaling_modifier

    def empty(self):
        self.__init__(self.sh_degree, self.scaling_modifier)

    def set_data(self, xyz, features, scaling, rotation, opacity):
        """
        xyz      : torch.Tensor of shape (N, 3)
        features : torch.Tensor of shape (N, (self.sh_degree + 1) ** 2, 3)
        scaling  : torch.Tensor of shape (N, 3)
        rotation : torch.Tensor of shape (N, 4)
        opacity  : torch.Tensor of shape (N, 1)
        """
        self._xyz = xyz
        self._features_dc = features[:, :1, :].contiguous()
        if self.sh_degree > 0:
            self._features_rest = features[:, 1:, :].contiguous()
        else:
            self._features_rest = None
        self._scaling = scaling
        self._rotation = rotation
        self._opacity = opacity
        return self

    def to(self, device):
        self._xyz = self._xyz.to(device)
        self._features_dc = self._features_dc.to(device)
        if self.sh_degree > 0:
            self._features_rest = self._features_rest.to(device)
        self._scaling = self._scaling.to(device)
        self._rotation = self._rotation.to(device)
        self._opacity = self._opacity.to(device)
        return self

    def filter(self, valid_mask):
        self._xyz = self._xyz[valid_mask]
        self._features_dc = self._features_dc[valid_mask]
        if self.sh_degree > 0:
            self._features_rest = self._features_rest[valid_mask]
        self._scaling = self._scaling[valid_mask]
        self._rotation = self._rotation[valid_mask]
        self._opacity = self._opacity[valid_mask]
        return self

    def crop(self, crop_bbx=[-1, 1, -1, 1, -1, 1]):
        x_min, x_max, y_min, y_max, z_min, z_max = crop_bbx
        xyz = self._xyz
        invalid_mask = (
            (xyz[:, 0] < x_min)
            | (xyz[:, 0] > x_max)
            | (xyz[:, 1] < y_min)
            | (xyz[:, 1] > y_max)
            | (xyz[:, 2] < z_min)
            | (xyz[:, 2] > z_max)
        )
        valid_mask = ~invalid_mask
        return self.filter(valid_mask)

    def crop_by_xyz(self, floater_thres=0.75):
        # Remove floaters: points in one of the four corner quadrants of the
        # xy-plane that also lie below -floater_thres in z
        xyz = self._xyz
        invalid_mask = (
            (
                ((xyz[:, 0] < -floater_thres) & (xyz[:, 1] < -floater_thres))
                | ((xyz[:, 0] < -floater_thres) & (xyz[:, 1] > floater_thres))
                | ((xyz[:, 0] > floater_thres) & (xyz[:, 1] < -floater_thres))
                | ((xyz[:, 0] > floater_thres) & (xyz[:, 1] > floater_thres))
            )
            & (xyz[:, 2] < -floater_thres)
        )
        valid_mask = ~invalid_mask
        return self.filter(valid_mask)

    def prune(self, opacity_thres=0.05):
        opacity = self.get_opacity.squeeze(1)
        valid_mask = opacity > opacity_thres
        return self.filter(valid_mask)

    def prune_by_scaling(self, scaling_thres=0.1):
        scaling = self.get_scaling
        valid_mask = scaling.max(dim=1).values < scaling_thres
        position_mask = self._xyz[:, 2] > 0
        valid_mask = valid_mask | position_mask
        return self.filter(valid_mask)

    def prune_by_nearfar(self, cam_origins, nearfar_percent=(0.01, 0.99)):
        # cam_origins: [num_cams, 3]
        # nearfar_percent: (near, far)
        assert len(nearfar_percent) == 2
        assert nearfar_percent[0] < nearfar_percent[1]
        assert nearfar_percent[0] >= 0 and nearfar_percent[1] <= 1
        device = self._xyz.device
        # compute distance of all points to all cameras
        # [num_points, num_cams]
        dists = torch.cdist(self._xyz[None], cam_origins[None].to(device))[0]
        # [2, num_cams]
        dists_percentile = torch.quantile(
            dists, torch.tensor(nearfar_percent).to(device), dim=0
        )
        # prune all points that are outside the percentile range
        # [num_points, num_cams]
        # goal: prune points that are too close or too far from any camera
        reject_mask = (dists < dists_percentile[0:1, :]) | (
            dists > dists_percentile[1:2, :]
        )
        reject_mask = reject_mask.any(dim=1)
        valid_mask = ~reject_mask
        return self.filter(valid_mask)

    def apply_all_filters(
        self,
        opacity_thres=0.05,
        scaling_thres=None,
        floater_thres=None,
        crop_bbx=[-1, 1, -1, 1, -1, 1],
        cam_origins=None,
        nearfar_percent=(0.005, 1.0),
    ):
        self.prune(opacity_thres)
        if scaling_thres is not None:
            self.prune_by_scaling(scaling_thres)
        if floater_thres is not None:
            self.crop_by_xyz(floater_thres)
        if crop_bbx is not None:
            self.crop(crop_bbx)
        if cam_origins is not None:
            self.prune_by_nearfar(cam_origins, nearfar_percent)
        return self

    def shrink_bbx(self, drop_ratio=0.05):
        xyz = self._xyz
        xyz_min, xyz_max = torch.quantile(
            xyz,
            torch.tensor([drop_ratio, 1 - drop_ratio]).float().to(xyz.device),
            dim=0,
        )  # [2, 3]
        xyz_min = xyz_min.detach().cpu().numpy()
        xyz_max = xyz_max.detach().cpu().numpy()
        crop_bbx = [
            xyz_min[0],
            xyz_max[0],
            xyz_min[1],
            xyz_max[1],
            xyz_min[2],
            xyz_max[2],
        ]
        print(f"Shrinking bbx to {crop_bbx}")
        return self.crop(crop_bbx)

    def report_stats(self):
        print(
            f"xyz: {self._xyz.shape}, {self._xyz.min().item()}, {self._xyz.max().item()}"
        )
        print(
            f"features_dc: {self._features_dc.shape}, {self._features_dc.min().item()}, {self._features_dc.max().item()}"
        )
        if self.sh_degree > 0:
            print(
                f"features_rest: {self._features_rest.shape}, {self._features_rest.min().item()}, {self._features_rest.max().item()}"
            )
        print(
            f"scaling: {self._scaling.shape}, {self._scaling.min().item()}, {self._scaling.max().item()}"
        )
        print(
            f"rotation: {self._rotation.shape}, {self._rotation.min().item()}, {self._rotation.max().item()}"
        )
        print(
            f"opacity: {self._opacity.shape}, {self._opacity.min().item()}, {self._opacity.max().item()}"
        )
        print(
            f"after activation, xyz: {self.get_xyz.shape}, {self.get_xyz.min().item()}, {self.get_xyz.max().item()}"
        )
        print(
            f"after activation, features: {self.get_features.shape}, {self.get_features.min().item()}, {self.get_features.max().item()}"
        )
        print(
            f"after activation, scaling: {self.get_scaling.shape}, {self.get_scaling.min().item()}, {self.get_scaling.max().item()}"
        )
        print(
            f"after activation, rotation: {self.get_rotation.shape}, {self.get_rotation.min().item()}, {self.get_rotation.max().item()}"
        )
        print(
            f"after activation, opacity: {self.get_opacity.shape}, {self.get_opacity.min().item()}, {self.get_opacity.max().item()}"
        )
        print(
            f"after activation, covariance: {self.get_covariance().shape}, {self.get_covariance().min().item()}, {self.get_covariance().max().item()}"
        )

    @property
    def get_scaling(self):
        if self.scaling_modifier is not None:
            return self.scaling_activation(self._scaling) * self.scaling_modifier
        else:
            return self.scaling_activation(self._scaling)

    @property
    def get_rotation(self):
        return self.rotation_activation(self._rotation)

    @property
    def get_xyz(self):
        return self._xyz

    @property
    def get_features(self):
        if self.sh_degree > 0:
            features_dc = self._features_dc
            features_rest = self._features_rest
            return torch.cat((features_dc, features_rest), dim=1)
        else:
            return self._features_dc

    @property
    def get_opacity(self):
        return self.opacity_activation(self._opacity)

    def get_covariance(self, scaling_modifier=1):
        return self.covariance_activation(
            self.get_scaling, scaling_modifier, self._rotation
        )

    def construct_dtypes(self, use_fp16=False, enable_gs_viewer=True):
        if not use_fp16:
            l = [
                ("x", "f4"),
                ("y", "f4"),
                ("z", "f4"),
                ("red", "u1"),
                ("green", "u1"),
                ("blue", "u1"),
            ]
            # DC SH coefficients (3 channels)
            for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
                l.append((f"f_dc_{i}", "f4"))
            # All remaining SH channels except the 3 DC
            if enable_gs_viewer:
                assert self.sh_degree <= 3, "GS viewer only supports SH up to degree 3"
                sh_degree = 3
                for i in range(((sh_degree + 1) ** 2 - 1) * 3):
                    l.append((f"f_rest_{i}", "f4"))
            else:
                if self.sh_degree > 0:
                    for i in range(
                        self._features_rest.shape[1] * self._features_rest.shape[2]
                    ):
                        l.append((f"f_rest_{i}", "f4"))
            l.append(("opacity", "f4"))
            for i in range(self._scaling.shape[1]):
                l.append((f"scale_{i}", "f4"))
            for i in range(self._rotation.shape[1]):
                l.append((f"rot_{i}", "f4"))
        else:
            l = [
                ("x", "f2"),
                ("y", "f2"),
                ("z", "f2"),
                ("red", "u1"),
                ("green", "u1"),
                ("blue", "u1"),
            ]
            # DC SH coefficients (3 channels)
            for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
                l.append((f"f_dc_{i}", "f2"))
            # All remaining SH channels except the 3 DC
            if self.sh_degree > 0:
                for i in range(
                    self._features_rest.shape[1] * self._features_rest.shape[2]
                ):
                    l.append((f"f_rest_{i}", "f2"))
            l.append(("opacity", "f2"))
            for i in range(self._scaling.shape[1]):
                l.append((f"scale_{i}", "f2"))
            for i in range(self._rotation.shape[1]):
                l.append((f"rot_{i}", "f2"))
        return l

    def save_ply(
        self,
        path,
        use_fp16=False,
        enable_gs_viewer=True,
        color_code=False,
        filter_mask=None,
    ):
        os.makedirs(os.path.dirname(path), exist_ok=True)

        xyz = self._xyz.detach().cpu().numpy()
        f_dc = (
            self._features_dc.detach()
            .transpose(1, 2)
            .flatten(start_dim=1)
            .contiguous()
            .cpu()
            .numpy()
        )
        if not color_code:
            rgb = (SH2RGB(f_dc) * 255.0).clip(0.0, 255.0).astype(np.uint8)
        else:
            # use a color map to color-code the index of points
            index = np.linspace(0, 1, xyz.shape[0])
            rgb = matplotlib.colormaps["viridis"](index)[..., :3]
            rgb = (rgb * 255.0).clip(0.0, 255.0).astype(np.uint8)
        opacities = self._opacity.detach().cpu().numpy()
        if self.scaling_modifier is not None:
            # bake the scaling modifier into the stored log-scales
            scale = self.inv_scaling_activation(self.get_scaling).detach().cpu().numpy()
        else:
            scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()

        dtype_full = self.construct_dtypes(use_fp16, enable_gs_viewer)
        elements = np.empty(xyz.shape[0], dtype=dtype_full)

        f_rest = None
        if self.sh_degree > 0:
            f_rest = (
                self._features_rest.detach()
                .transpose(1, 2)
                .flatten(start_dim=1)
                .contiguous()
                .cpu()
                .numpy()
            )  # [N, 3 * ((self.sh_degree + 1) ** 2 - 1)]
        if enable_gs_viewer:
            # pad f_rest with zeros up to SH degree 3 for viewer compatibility
            sh_degree = 3
            if f_rest is None:
                f_rest = np.zeros(
                    (xyz.shape[0], 3 * ((sh_degree + 1) ** 2 - 1)), dtype=np.float32
                )
            elif f_rest.shape[1] < 3 * ((sh_degree + 1) ** 2 - 1):
                f_rest_pad = np.zeros(
                    (xyz.shape[0], 3 * ((sh_degree + 1) ** 2 - 1)), dtype=np.float32
                )
                f_rest_pad[:, : f_rest.shape[1]] = f_rest
                f_rest = f_rest_pad
        if f_rest is not None:
            attributes = np.concatenate(
                (xyz, rgb, f_dc, f_rest, opacities, scale, rotation), axis=1
            )
        else:
            attributes = np.concatenate(
                (xyz, rgb, f_dc, opacities, scale, rotation), axis=1
            )
        if filter_mask is not None:
            attributes = attributes[filter_mask]
            elements = elements[filter_mask]
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, "vertex")
        PlyData([el]).write(path)

    def load_ply(self, path):
        plydata = PlyData.read(path)

        xyz = np.stack(
            (
                np.asarray(plydata.elements[0]["x"]),
                np.asarray(plydata.elements[0]["y"]),
                np.asarray(plydata.elements[0]["z"]),
            ),
            axis=1,
        )
        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]

        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])

        if self.sh_degree > 0:
            extra_f_names = [
                p.name
                for p in plydata.elements[0].properties
                if p.name.startswith("f_rest_")
            ]
            extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
            assert len(extra_f_names) == 3 * (self.sh_degree + 1) ** 2 - 3
            features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
            for idx, attr_name in enumerate(extra_f_names):
                features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
            # Reshape (P, F * SH_coeffs) to (P, F, SH_coeffs except DC)
            features_extra = features_extra.reshape(
                (features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1)
            )

        scale_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("scale_")
        ]
        scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])

        rot_names = [
            p.name for p in plydata.elements[0].properties if p.name.startswith("rot")
        ]
        rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])

        self._xyz = torch.from_numpy(xyz.astype(np.float32))
        self._features_dc = (
            torch.from_numpy(features_dc.astype(np.float32))
            .transpose(1, 2)
            .contiguous()
        )
        if self.sh_degree > 0:
            self._features_rest = (
                torch.from_numpy(features_extra.astype(np.float32))
                .transpose(1, 2)
                .contiguous()
            )
        self._opacity = torch.from_numpy(
            np.copy(opacities).astype(np.float32)
        ).contiguous()
        self._scaling = torch.from_numpy(scales.astype(np.float32)).contiguous()
        self._rotation = torch.from_numpy(rots.astype(np.float32)).contiguous()
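

# A minimal sketch (hypothetical values) of populating a GaussianModel from
# raw tensors. Note that set_data expects pre-activation parameters: scaling
# in log-space and opacity in logit-space; the getters apply exp/sigmoid.
def _demo_gaussian_model():
    n, sh_degree = 128, 1
    pc = GaussianModel(sh_degree=sh_degree)
    pc.set_data(
        xyz=torch.randn(n, 3) * 0.3,
        features=torch.randn(n, (sh_degree + 1) ** 2, 3) * 0.1,
        scaling=torch.full((n, 3), -4.0),  # exp(-4) ~ 0.018 world units
        rotation=torch.tensor([[1.0, 0.0, 0.0, 0.0]]).repeat(n, 1),
        opacity=torch.full((n, 1), 2.0),  # sigmoid(2) ~ 0.88
    )
    pc.prune(opacity_thres=0.05).crop([-1, 1, -1, 1, -1, 1])
    assert pc.get_xyz.shape[0] <= n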


def render_opencv_cam(
    pc: GaussianModel,
    height: int,
    width: int,
    C2W: torch.Tensor,
    fxfycxcy: torch.Tensor,
    bg_color=(1.0, 1.0, 1.0),
    scaling_modifier=1.0,
):
    """
    Render the scene from a single OpenCV-convention camera.
    bg_color is a 3-tuple; it is converted to a tensor on C2W's device.
    """
    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.empty_like(
        pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
    )
    # try:
    #     screenspace_points.retain_grad()
    # except:
    #     pass

    viewpoint_camera = Camera(C2W=C2W, fxfycxcy=fxfycxcy, h=height, w=width)
    bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=C2W.device)

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.h),
        image_width=int(viewpoint_camera.w),
        tanfovx=viewpoint_camera.tanfovX,
        tanfovy=viewpoint_camera.tanfovY,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity
    scales = pc.get_scaling
    rotations = pc.get_rotation
    shs = pc.get_features

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    rendered_image, radii = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=None,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=None,
    )

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {
        "render": rendered_image,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
    }
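

# A minimal single-view rendering sketch (requires a CUDA device and the
# diff_gaussian_rasterization extension; the scene content is random and
# purely illustrative).
def _demo_render_single_view():
    w, h, _, fxfycxcy, c2ws = get_turntable_cameras(num_views=1, w=256, h=256)
    n = 1024
    pc = (
        GaussianModel(sh_degree=0)
        .set_data(
            xyz=torch.randn(n, 3) * 0.3,
            features=RGB2SH(torch.rand(n, 1, 3)),  # DC-only SH from random colors
            scaling=torch.full((n, 3), -4.0),
            rotation=torch.tensor([[1.0, 0.0, 0.0, 0.0]]).repeat(n, 1),
            opacity=torch.full((n, 1), 2.0),
        )
        .to("cuda")
    )
    out = render_opencv_cam(
        pc,
        h,
        w,
        torch.from_numpy(c2ws[0]).float().cuda(),
        torch.from_numpy(fxfycxcy[0]).float().cuda(),
    )
    assert out["render"].shape == (3, h, w)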


class DeferredGaussianRender(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        xyz,
        features,
        scaling,
        rotation,
        opacity,
        height,
        width,
        C2W,
        fxfycxcy,
        scaling_modifier=None,
    ):
        """
        xyz: [b, n_gaussians, 3]
        features: [b, n_gaussians, (sh_degree+1)^2, 3]
        scaling: [b, n_gaussians, 3]
        rotation: [b, n_gaussians, 4]
        opacity: [b, n_gaussians, 1]
        height: int
        width: int
        C2W: [b, v, 4, 4]
        fxfycxcy: [b, v, 4]

        output: [b, v, 3, height, width]
        """
        ctx.scaling_modifier = scaling_modifier
        # Infer sh_degree from the number of SH coefficients in features
        sh_degree = int(math.sqrt(features.shape[-2])) - 1
        # Create a temporary model to hold the data for rendering
        gaussians_model = GaussianModel(sh_degree, scaling_modifier)
        with torch.no_grad():
            b, v = C2W.size(0), C2W.size(1)
            renders = []
            for i in range(b):
                pc = gaussians_model.set_data(
                    xyz[i], features[i], scaling[i], rotation[i], opacity[i]
                )
                for j in range(v):
                    renders.append(
                        render_opencv_cam(pc, height, width, C2W[i, j], fxfycxcy[i, j])[
                            "render"
                        ]
                    )
            renders = torch.stack(renders, dim=0)
            renders = renders.reshape(b, v, 3, height, width)
        renders = renders.requires_grad_()
        # save_for_backward only supports tensors
        ctx.save_for_backward(xyz, features, scaling, rotation, opacity, C2W, fxfycxcy)
        ctx.rendering_size = (height, width)
        ctx.sh_degree = sh_degree
        # Release the temporary model; do not save it on ctx.
        del gaussians_model
        return renders

    @staticmethod
    def backward(ctx, grad_output):
        # Restore the saved inputs
        xyz, features, scaling, rotation, opacity, C2W, fxfycxcy = ctx.saved_tensors
        height, width = ctx.rendering_size
        sh_degree = ctx.sh_degree
        # **The order of this dict must not be changed**: it matches the
        # positional order of forward's tensor inputs.
        input_dict = OrderedDict(
            [
                ("xyz", xyz),
                ("features", features),
                ("scaling", scaling),
                ("rotation", rotation),
                ("opacity", opacity),
            ]
        )
        input_dict = {k: v.detach().requires_grad_() for k, v in input_dict.items()}
        # Create a temporary model to hold the data for rendering
        gaussians_model = GaussianModel(sh_degree, ctx.scaling_modifier)
        with torch.enable_grad():
            b, v = C2W.size(0), C2W.size(1)
            for i in range(b):
                for j in range(v):
                    # backward() frees the graph, so rebuild the model per view
                    pc = gaussians_model.set_data(
                        **{k: v[i] for k, v in input_dict.items()}
                    )
                    # Forward
                    render = render_opencv_cam(
                        pc, height, width, C2W[i, j], fxfycxcy[i, j]
                    )["render"]
                    # Backward; assume only the tensors in input_dict receive gradients.
                    render.backward(grad_output[i, j])
        del gaussians_model
        return *[var.grad for var in input_dict.values()], None, None, None, None, None


# Functional entry point for the class above
deferred_gaussian_render = DeferredGaussianRender.apply
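

# A minimal sketch of the deferred (memory-saving) path: the forward pass runs
# under no_grad, and backward re-renders view by view to recompute gradients.
# Batched shapes follow the docstring of DeferredGaussianRender.forward;
# requires CUDA, and the scene content here is random and illustrative.
def _demo_deferred_render():
    b, v, n = 1, 2, 1024
    xyz = (torch.randn(b, n, 3, device="cuda") * 0.3).requires_grad_()
    features = torch.randn(b, n, 1, 3, device="cuda").requires_grad_()  # sh_degree 0
    scaling = torch.full((b, n, 3), -4.0, device="cuda").requires_grad_()
    rotation = torch.zeros(b, n, 4, device="cuda")
    rotation[..., 0] = 1.0  # identity quaternions
    rotation.requires_grad_()
    opacity = torch.full((b, n, 1), 2.0, device="cuda").requires_grad_()
    w, h, _, fxfycxcy, c2ws = get_turntable_cameras(num_views=v, w=128, h=128)
    C2W = torch.from_numpy(c2ws).float().cuda()[None]  # [b, v, 4, 4]
    intr = torch.from_numpy(fxfycxcy).float().cuda()[None]  # [b, v, 4]
    renders = deferred_gaussian_render(
        xyz, features, scaling, rotation, opacity, h, w, C2W, intr
    )
    assert renders.shape == (b, v, 3, h, w)
    renders.mean().backward()  # per-view re-rendering happens here
    assert xyz.grad is not None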


@torch.no_grad()
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def render_turntable(pc: GaussianModel, rendering_resolution=384, num_views=8):
    w, h, v, fxfycxcy, c2w = get_turntable_cameras(
        h=rendering_resolution,
        w=rendering_resolution,
        num_views=num_views,
        elevation=0,  # For MAX SNEAK
    )
    device = pc._xyz.device
    fxfycxcy = torch.from_numpy(fxfycxcy).float().to(device)  # [v, 4]
    c2w = torch.from_numpy(c2w).float().to(device)  # [v, 4, 4]

    renderings = torch.zeros(v, 3, h, w, dtype=torch.float32, device=device)
    for j in range(v):
        renderings[j] = render_opencv_cam(pc, h, w, c2w[j], fxfycxcy[j])["render"]
    torch.cuda.empty_cache()  # free up memory on GPU

    renderings = renderings.detach().cpu().numpy()
    renderings = (renderings * 255).clip(0, 255).astype(np.uint8)
    renderings = rearrange(renderings, "v c h w -> h (v w) c")
    return renderings
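

# A minimal end-to-end sketch (requires CUDA; the .ply path below is a
# placeholder, not a shipped asset): render a turntable strip for a loaded
# model and write it out as a short video via imageseq2video.
def _demo_turntable_video(ply_path="point_cloud.ply"):
    pc = GaussianModel(sh_degree=3)
    pc.load_ply(ply_path)
    pc = pc.to("cuda")
    strip = render_turntable(pc, rendering_resolution=256, num_views=8)
    # strip is a uint8 image of shape [h, num_views * w, 3]; split it back
    # into per-view frames before encoding
    frames = rearrange(strip, "h (v w) c -> v h w c", v=8)
    imageseq2video(frames, "turntable.mp4", fps=4)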
cam["cx"] cy = cam["cy"] cx = cx - 5 cy = cy + 4 fxfycxcy = torch.tensor([fx, fy, cx, cy], dtype=torch.float32, device=device) h = cam["h"] w = cam["w"] im = render_opencv_cam(pc, h, w, c2w, fxfycxcy, bg_color=[0.0, 0.0, 0.0])[ "render" ] im = im.detach().cpu().numpy().transpose(1, 2, 0) im = (im * 255).astype(np.uint8) Image.fromarray(im).save(f"{out_dir}/render_{i:08d}.png") create_video(out_dir, f"{out_dir}/render.mp4", framerate=30) print(f"Saved {out_dir}/render.mp4")