from pathlib import Path

import moviepy.editor as mpy
import numpy as np
import torch
from tqdm import tqdm

from src.models.models.rasterization import GaussianSplatRenderer
from src.models.utils.sh_utils import RGB2SH, SH2RGB
from src.utils.gs_effects import GSEffects
from src.utils.color_map import apply_color_map_to_image

# Effect-clock increment applied per rendered frame when `effects` is used.
_EFFECT_TIME_STEP = 0.04
# Number of frames rendered at the first camera pose before the trajectory starts.
_EFFECT_PREROLL_FRAMES = 320
# Clamp used before logit() when converting opacities for the rasterizer runner.
_OPACITY_EPS = 1e-6


def rotation_matrix_to_quaternion(R):
    """Convert rotation matrices to quaternions.

    Args:
        R: tensor of shape [..., 3, 3] holding rotation matrices.

    Returns:
        Tensor of shape [..., 4] in (w, x, y, z) order with R's device/dtype.
        Each element uses the branch keyed on the largest of the trace and the
        diagonal entries, so the divisor `s` is never close to zero.
    """
    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
    q = torch.zeros(R.shape[:-2] + (4,), device=R.device, dtype=R.dtype)

    # Case where trace > 0
    mask1 = trace > 0
    s = torch.sqrt(trace[mask1] + 1.0) * 2  # s=4*qw
    q[mask1, 0] = 0.25 * s  # qw
    q[mask1, 1] = (R[mask1, 2, 1] - R[mask1, 1, 2]) / s  # qx
    q[mask1, 2] = (R[mask1, 0, 2] - R[mask1, 2, 0]) / s  # qy
    q[mask1, 3] = (R[mask1, 1, 0] - R[mask1, 0, 1]) / s  # qz

    # Case where R[0,0] > R[1,1] and R[0,0] > R[2,2]
    mask2 = (~mask1) & (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
    s = torch.sqrt(1.0 + R[mask2, 0, 0] - R[mask2, 1, 1] - R[mask2, 2, 2]) * 2  # s=4*qx
    q[mask2, 0] = (R[mask2, 2, 1] - R[mask2, 1, 2]) / s  # qw
    q[mask2, 1] = 0.25 * s  # qx
    q[mask2, 2] = (R[mask2, 0, 1] + R[mask2, 1, 0]) / s  # qy
    q[mask2, 3] = (R[mask2, 0, 2] + R[mask2, 2, 0]) / s  # qz

    # Case where R[1,1] > R[2,2]
    mask3 = (~mask1) & (~mask2) & (R[..., 1, 1] > R[..., 2, 2])
    s = torch.sqrt(1.0 + R[mask3, 1, 1] - R[mask3, 0, 0] - R[mask3, 2, 2]) * 2  # s=4*qy
    q[mask3, 0] = (R[mask3, 0, 2] - R[mask3, 2, 0]) / s  # qw
    q[mask3, 1] = (R[mask3, 0, 1] + R[mask3, 1, 0]) / s  # qx
    q[mask3, 2] = 0.25 * s  # qy
    q[mask3, 3] = (R[mask3, 1, 2] + R[mask3, 2, 1]) / s  # qz

    # Remaining case (R[2,2] is the largest diagonal entry)
    mask4 = (~mask1) & (~mask2) & (~mask3)
    s = torch.sqrt(1.0 + R[mask4, 2, 2] - R[mask4, 0, 0] - R[mask4, 1, 1]) * 2  # s=4*qz
    q[mask4, 0] = (R[mask4, 1, 0] - R[mask4, 0, 1]) / s  # qw
    q[mask4, 1] = (R[mask4, 0, 2] + R[mask4, 2, 0]) / s  # qx
    q[mask4, 2] = (R[mask4, 1, 2] + R[mask4, 2, 1]) / s  # qy
    q[mask4, 3] = 0.25 * s  # qz

    return q


def quaternion_to_rotation_matrix(q):
    """Convert (w, x, y, z) quaternions to rotation matrices.

    Args:
        q: tensor of shape [..., 4]; it does not need to be unit length,
            because it is normalized here first.

    Returns:
        Tensor of shape [..., 3, 3] with q's device/dtype.
    """
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    # Normalize quaternion
    norm = torch.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / norm, x / norm, y / norm, z / norm

    R = torch.zeros(q.shape[:-1] + (3, 3), device=q.device, dtype=q.dtype)
    R[..., 0, 0] = 1 - 2 * (y * y + z * z)
    R[..., 0, 1] = 2 * (x * y - w * z)
    R[..., 0, 2] = 2 * (x * z + w * y)
    R[..., 1, 0] = 2 * (x * y + w * z)
    R[..., 1, 1] = 1 - 2 * (x * x + z * z)
    R[..., 1, 2] = 2 * (y * z - w * x)
    R[..., 2, 0] = 2 * (x * z - w * y)
    R[..., 2, 1] = 2 * (y * z + w * x)
    R[..., 2, 2] = 1 - 2 * (x * x + y * y)
    return R


def slerp_quaternions(q1, q2, t):
    """Spherically interpolate between quaternions.

    Args:
        q1, q2: tensors of shape [..., 4].
        t: interpolation weight (0 -> q1, 1 -> q2). Callers in this file
            pass a Python float.

    Returns:
        Tensor of shape [..., 4]. Nearly-parallel pairs (|dot| > 0.9995) are
        interpolated linearly and re-normalized; the rest use true slerp.
    """
    dot = (q1 * q2).sum(dim=-1, keepdim=True)

    # q and -q represent the same rotation; flip q2 so slerp takes the
    # shorter arc.
    mask = dot < 0
    q2 = torch.where(mask, -q2, q2)
    dot = torch.where(mask, -dot, dot)

    # If the inputs are too close for comfort, linearly interpolate
    # and normalize the result.
    DOT_THRESHOLD = 0.9995
    mask_linear = dot > DOT_THRESHOLD

    result = torch.zeros_like(q1)

    if mask_linear.any():
        result_linear = q1 + t * (q2 - q1)
        norm = torch.norm(result_linear, dim=-1, keepdim=True)
        result_linear = result_linear / norm
        result = torch.where(mask_linear, result_linear, result)

    mask_slerp = ~mask_linear
    if mask_slerp.any():
        # Clamp so floating-point drift above 1 cannot turn acos into NaN.
        theta_0 = torch.acos(torch.abs(dot).clamp(max=1.0))
        sin_theta_0 = torch.sin(theta_0)
        theta = theta_0 * t
        sin_theta = torch.sin(theta)
        s0 = torch.cos(theta) - dot * sin_theta / sin_theta_0
        s1 = sin_theta / sin_theta_0
        result_slerp = (s0 * q1) + (s1 * q2)
        result = torch.where(mask_slerp, result_slerp, result)

    return result


def _build_interpolated_traj(camtoworlds, intrinsics, index, nums):
    """Interpolate a camera path through the views listed in `index`.

    Visits every selected view in order and inserts `nums` frames between
    consecutive views. Translation and intrinsics are blended linearly;
    rotation uses quaternion slerp. The final view is included, so the path
    ends exactly on the last keyframe.

    Returns:
        (exts, ints) for batch element 0 only: [1, N, 4, 4] and [1, N, 3, 3].
    """
    b = camtoworlds.shape[0]
    key_c2w = camtoworlds[:, index]
    key_k = intrinsics[:, index]
    exts, ints = [], []
    for i in range(len(index) - 1):
        exts.append(key_c2w[:, i:i + 1])
        ints.append(key_k[:, i:i + 1])

        R0, t0 = key_c2w[:, i, :3, :3], key_c2w[:, i, :3, 3]
        R1, t1 = key_c2w[:, i + 1, :3, :3], key_c2w[:, i + 1, :3, 3]
        q0 = rotation_matrix_to_quaternion(R0)
        q1 = rotation_matrix_to_quaternion(R1)
        K0, K1 = key_k[:, i], key_k[:, i + 1]

        for j in range(1, nums + 1):
            alpha = j / (nums + 1)
            t_interp = (1 - alpha) * t0 + alpha * t1
            R_interp = quaternion_to_rotation_matrix(slerp_quaternions(q0, q1, alpha))

            ext = torch.eye(4, device=R_interp.device, dtype=R_interp.dtype)[None].repeat(b, 1, 1)
            ext[:, :3, :3] = R_interp
            ext[:, :3, 3] = t_interp
            K = (1 - alpha) * K0 + alpha * K1

            exts.append(ext[:, None])
            ints.append(K[:, None])

    # Close the path on the last keyframe; the loop above only emits view i
    # plus the in-betweens toward view i + 1.
    exts.append(key_c2w[:, -1:])
    ints.append(key_k[:, -1:])

    exts = torch.cat(exts, dim=1)[:1]
    ints = torch.cat(ints, dim=1)[:1]
    return exts, ints


def _build_wobble_traj(camtoworlds, intrinsics, nums, delta):
    """Build a `nums`-frame spiral offset around a single input view.

    The radius grows along a cosine ease from 0 to `delta * 0.15`, and the
    offset is applied in the camera's local x/y plane (right-multiplied).

    Returns:
        (exts, ints) with `nums` frames along dim 1.
    """
    assert camtoworlds.shape[1] == 1
    t = torch.linspace(0, 1, nums, dtype=torch.float32, device=camtoworlds.device)
    t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
    tf = torch.eye(4, dtype=torch.float32, device=camtoworlds.device)
    radius = delta * 0.15
    tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
    radius = radius[..., None] * t
    tf[..., 0, 3] = torch.sin(2 * torch.pi * t) * radius
    tf[..., 1, 3] = -torch.cos(2 * torch.pi * t) * radius
    exts = camtoworlds @ tf
    ints = intrinsics.repeat(1, exts.shape[1], 1, 1)
    return exts, ints


def _depth_vis(d: torch.Tensor) -> torch.Tensor:
    """Colorize one [H, W] depth map with a log-scaled, inverted turbo map."""
    valid = d > 0
    if valid.any():
        near = d[valid].float().quantile(0.01).log()
    else:
        near = torch.tensor(0.0, device=d.device)
    # Clamp before log so an all-zero/mostly-zero map cannot yield -inf (and
    # then NaN after the division below).
    far = d.flatten().float().quantile(0.99).clamp(min=1e-9).log()
    x = d.float().clamp(min=1e-9).log()
    x = 1.0 - (x - near) / (far - near + 1e-9)
    return apply_color_map_to_image(x, "turbo")


def _write_video(frames, path, loop_reverse):
    """Encode a list of [3, H, W] frames in [0, 1] to a 30 fps video at `path`.

    With `loop_reverse`, the reversed sequence (without its endpoints) is
    appended so the clip plays forward then backward.
    """
    video = torch.stack(frames).clamp(0, 1)  # [N, 3, H, W]
    video = video.permute(0, 2, 3, 1)  # [N, H, W, 3] for moviepy
    video = (video * 255).to(torch.uint8).cpu().numpy()
    if loop_reverse and video.shape[0] > 1:
        video = np.concatenate([video, video[::-1][1:-1]], axis=0)
    clip = mpy.ImageSequenceClip(list(video), fps=30)
    clip.write_videofile(str(path), logger=None)


def _probe_runner(gs_renderer):
    """Return `gs_renderer.rasterizer.runner` if it exposes `splats`, else None."""
    try:
        runner = gs_renderer.rasterizer.runner
        _ = runner.splats
    except AttributeError:
        return None
    return runner


def _rasterize_effect_frame(gs_renderer, runner, effects_splats, ext, intr, hw, sh_degree):
    """Rasterize one batch of poses for splats produced by an effect.

    If `runner` is available, the splats are installed on it in the
    parameterization this file feeds it: log scales, logit opacities, and the
    SH split into "sh0"/"shN". Otherwise, or if that path raises, the splats
    are passed directly to `rasterize_batches`.

    Returns:
        (colors, depths) as returned by `rasterize_batches`.
    """
    h, w = hw
    ext = ext.to(torch.float32)
    intr = intr.to(torch.float32)
    effects_splats["sh"] = RGB2SH(effects_splats["colors"]).reshape(-1, 1, 3)

    if runner is not None:
        scales, opacities = effects_splats["scales"], effects_splats["opacities"]
        effects_splats["sh0"] = effects_splats["sh"][:, :1, :]
        effects_splats["shN"] = effects_splats["sh"][:, 1:, :]
        effects_splats["scales"] = scales.log()
        effects_splats["opacities"] = torch.logit(
            torch.clamp(opacities, _OPACITY_EPS, 1 - _OPACITY_EPS))
        try:
            runner.splats = effects_splats
            colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
                None, None, None, None, None,
                ext, intr,
                width=w, height=h,
                sh_degree=gs_renderer.sh_degree,
            )
            return colors, depths
        except Exception:
            # Best-effort fallback (as before), but restore the activated
            # values first: the direct path must not receive log/logit ones.
            effects_splats["scales"], effects_splats["opacities"] = scales, opacities

    colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
        effects_splats["means"][None],
        effects_splats["quats"][None],
        effects_splats["scales"][None],
        effects_splats["opacities"][None],
        effects_splats["sh"][None],
        ext, intr,
        width=w, height=h,
        sh_degree=sh_degree,
    )
    return colors, depths


def render_interpolated_video(gs_renderer: GaussianSplatRenderer,
                              splats: dict,
                              camtoworlds: torch.Tensor,
                              intrinsics: torch.Tensor,
                              hw: tuple[int, int],
                              out_path: Path,
                              interp_per_pair: int = 20,
                              loop_reverse: bool = True,
                              effects: GSEffects = None,
                              effect_type: int = 2,
                              save_mode: str = "split") -> None:
    """Render a camera trajectory through `splats` and save it as MP4.

    Trajectory:
        * S > 1 input views: visit every view in order, inserting
          `interp_per_pair` frames between consecutive views.
        * S == 1: a spiral "wobble" of `interp_per_pair * 12` frames around
          the single view.
    Only batch element 0 is rendered.

    Effects:
        When `effects` is given, each frame is rendered from
        `effects.apply_effect(...)` applied to a re-centred, re-scaled copy of
        the (best-effort pruned) splats. Camera translations are normalized
        the same way. Up to 320 pre-roll frames are rendered at the first
        pose, stopping once `(flag < 0.99).any()`. The effect clock then
        ping-pongs over the main trajectory. This path requires
        `gs_renderer.sh_degree == 0` and an "sh" entry in the splats. The
        meaning of `flag` comes from `GSEffects` (TODO confirm).

    Args:
        gs_renderer: renderer exposing `rasterizer.rasterize_batches`,
            `sh_degree`, and (for effects) `prune_gs`/`voxel_size`.
        splats: dict with "means", "quats", "scales", "opacities" and
            either "sh" or "colors", each with a leading batch dim.
        camtoworlds: [B, S, 4, 4] camera-to-world matrices.
        intrinsics: [B, S, 3, 3] intrinsics.
        hw: output (height, width).
        out_path: output path prefix (".mp4" suffixes are appended).
        interp_per_pair: in-between frames per view pair (S > 1), or 1/12 of
            the wobble length (S == 1).
        loop_reverse: append the reversed clip for forward/backward playback.
        effects: optional effect driver; see "Effects" above.
        effect_type: forwarded to `effects.apply_effect`.
        save_mode: "both" writes `{out_path}.mp4` with RGB stacked above
            depth; "split" writes `{out_path}_rgb.mp4` and
            `{out_path}_depth.mp4`.

    Raises:
        ValueError: if `save_mode` is not "both" or "split".
    """
    # Validate up front so a bad mode does not cost a full render.
    if save_mode not in ("both", "split"):
        raise ValueError("save_mode must be 'both' or 'split'")

    # camtoworlds: [B, S, 4, 4], intrinsics: [B, S, 3, 3]
    _, s, _, _ = camtoworlds.shape
    h, w = hw

    if s > 1:
        all_ext, all_int = _build_interpolated_traj(
            camtoworlds, intrinsics, list(range(s)), interp_per_pair)
    else:
        # Wobble radius scales with the norm of the median Gaussian position.
        delta = splats["means"][0].median(dim=0).values.norm(dim=-1)[None]
        all_ext, all_int = _build_wobble_traj(
            camtoworlds, intrinsics, interp_per_pair * 12, delta)

    rendered_rgbs, rendered_depths = [], []
    # Effects change per frame, so they are rendered one pose at a time.
    chunk = 40 if effects is None else 1
    t = 0
    runner = None
    raw_splats = None
    try:
        if effects is not None:
            # The effect path converts SH <-> RGB with a single coefficient.
            assert gs_renderer.sh_degree == 0
            try:
                pruned_splats = gs_renderer.prune_gs(splats, gs_renderer.voxel_size)
            except Exception:
                # Pruning is best-effort; fall back to the unpruned splats.
                pruned_splats = splats

            pre_ext = all_ext[:, :1, :, :].repeat(1, _EFFECT_PREROLL_FRAMES, 1, 1)
            pre_int = all_int[:, :1, :, :].repeat(1, _EFFECT_PREROLL_FRAMES, 1, 1)

            # Re-centre on the median Gaussian and scale by the 95th-percentile
            # extent. Camera translations get the same transform.
            shift = pruned_splats["means"][0].median(dim=0).values
            scale_factor = (pruned_splats["means"][0] - shift).abs().quantile(0.95, dim=0).max()
            all_ext[0, :, :3, -1] = (all_ext[0, :, :3, -1] - shift) / scale_factor
            pre_ext[0, :, :3, -1] = (pre_ext[0, :, :3, -1] - shift) / scale_factor
            direct_sh_degree = gs_renderer.sh_degree if "sh" in pruned_splats else None

            def base_gsplat():
                # Build a fresh dict each frame, as before, because
                # apply_effect's mutation behaviour is not visible here.
                return {
                    "means": (pruned_splats["means"][0] - shift) / scale_factor,
                    "quats": pruned_splats["quats"][0],
                    "scales": pruned_splats["scales"][0] / scale_factor,
                    "opacities": pruned_splats["opacities"][0],
                    "colors": SH2RGB(pruned_splats["sh"][0].reshape(-1, 3)),
                }

            # Remember the runner's own splats so they can be restored below.
            runner = _probe_runner(gs_renderer)
            if runner is not None:
                raw_splats = runner.splats

            # Pre-roll at the first pose.
            flag = None
            for st in range(pre_ext.shape[1]):
                if flag is not None and (flag < 0.99).any():
                    break
                effects_splats, flag = effects.apply_effect(
                    base_gsplat(), t, effect_type=effect_type)
                t += _EFFECT_TIME_STEP
                colors, depths = _rasterize_effect_frame(
                    gs_renderer, runner, effects_splats,
                    pre_ext[:, st:st + 1], pre_int[:, st:st + 1],
                    hw, direct_sh_degree)
                rendered_rgbs.append(colors)
                rendered_depths.append(depths)

        t_st = t
        t_ed = 0
        loop_dir = 1
        ignore_scale = False
        n_frames = all_ext.shape[1]
        for st in tqdm(range(0, n_frames, chunk)):
            ed = min(st + chunk, n_frames)
            if effects is not None:
                effects_splats, flag = effects.apply_effect(
                    base_gsplat(), t, effect_type=effect_type, ignore_scale=ignore_scale)
                if loop_dir < 0:
                    t -= _EFFECT_TIME_STEP
                else:
                    t += _EFFECT_TIME_STEP
                # Record the clock value at which flag's mean first drops below 0.01.
                if flag.mean() < 0.01 and t_ed == 0:
                    t_ed = t
                colors, depths = _rasterize_effect_frame(
                    gs_renderer, runner, effects_splats,
                    all_ext[:, st:ed], all_int[:, st:ed],
                    hw, direct_sh_degree)
                # Reverse the effect clock near the end of the trajectory, or
                # when it runs back past its starting value.
                turn_at = (n_frames * _EFFECT_TIME_STEP + t_st
                           - (t_ed - t_st) * 2 - 15 * _EFFECT_TIME_STEP)
                if t > turn_at or t < t_st:
                    loop_dir *= -1
                    if loop_dir == -1:
                        t = t_ed
            else:
                colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
                    splats["means"][:1],
                    splats["quats"][:1],
                    splats["scales"][:1],
                    splats["opacities"][:1],
                    splats["sh"][:1] if "sh" in splats else splats["colors"][:1],
                    all_ext[:, st:ed].to(torch.float32),
                    all_int[:, st:ed].to(torch.float32),
                    width=w,
                    height=h,
                    sh_degree=gs_renderer.sh_degree if "sh" in splats else None,
                )
            rendered_rgbs.append(colors)
            rendered_depths.append(depths)
    finally:
        # Undo the runner.splats swap even if rendering raised.
        if runner is not None:
            runner.splats = raw_splats

    rgbs = torch.cat(rendered_rgbs, dim=1)[0]  # [N, H, W, 3]
    depths = torch.cat(rendered_depths, dim=1)[0, ..., 0]  # [N, H, W]

    rgb_frames = [rgb.permute(2, 0, 1) for rgb in rgbs]  # each [3, H, W]
    depth_frames = [_depth_vis(dep) for dep in depths]  # each [3, H, W]

    if save_mode == "both":
        frames = [torch.cat([r, d], dim=1) for r, d in zip(rgb_frames, depth_frames)]  # [3, 2*H, W]
        _write_video(frames, f"{out_path}.mp4", loop_reverse)
    else:
        _write_video(rgb_frames, f"{out_path}_rgb.mp4", loop_reverse)
        _write_video(depth_frames, f"{out_path}_depth.mp4", loop_reverse)

    print(f"Video saved to {out_path} (mode: {save_mode})")
    torch.cuda.empty_cache()