Spaces:
Running
on
Zero
Running
on
Zero
| from math import atan2, cos, exp, floor, sin, sqrt | |
| import numpy as np | |
| import torch | |
| def fract(x): | |
| """Get fractional part of a number""" | |
| if isinstance(x, torch.Tensor): | |
| return x - torch.floor(x) | |
| return x - floor(x) | |
| class GSEffects: | |
| """Convert GLSL GS render effects to PyTorch - vectorized for batch processing""" | |
| def __init__(self, start_time=0.0, end_time=10.0): | |
| """ | |
| Initialize effects with time range | |
| Args: | |
| start_time: Animation start time | |
| end_time: Animation end time | |
| """ | |
| self.start_time = start_time | |
| self.end_time = end_time | |
| def smoothstep(edge0, edge1, x): | |
| """GLSL smoothstep function (vectorized)""" | |
| if isinstance(x, torch.Tensor): | |
| result = torch.zeros_like(x, dtype=x.dtype) | |
| mask_low = x < edge0 | |
| mask_high = x > edge1 | |
| mask_mid = ~(mask_low | mask_high) | |
| t = (x[mask_mid] - edge0) / (edge1 - edge0) | |
| result[mask_mid] = t * t * (3.0 - 2.0 * t) | |
| result[mask_low] = 0.0 | |
| result[mask_high] = 1.0 | |
| return result | |
| else: | |
| if x < edge0: | |
| return 0.0 | |
| if x > edge1: | |
| return 1.0 | |
| t = (x - edge0) / (edge1 - edge0) | |
| return t * t * (3.0 - 2.0 * t) | |
| def step(edge, x): | |
| """GLSL step function (vectorized)""" | |
| if isinstance(x, torch.Tensor): | |
| return (x >= edge).to(x.dtype) | |
| if isinstance(edge, torch.Tensor): | |
| return (x >= edge).to(edge.dtype) | |
| return 1.0 if x >= edge else 0.0 | |
| def mix(x, y, a): | |
| """GLSL mix function (linear interpolation, vectorized)""" | |
| return x * (1.0 - a) + y * a | |
| def clamp(x, min_val, max_val): | |
| """Clamp value between min and max (vectorized)""" | |
| if isinstance(x, torch.Tensor): | |
| return torch.clamp(x, min_val, max_val) | |
| return max(min_val, min(max_val, x)) | |
| def length_xz(pos): | |
| """Calculate length of XZ components (vectorized)""" | |
| if pos.dim() == 1: | |
| return torch.sqrt(pos[0]**2 + pos[2]**2) | |
| return torch.sqrt(pos[:, 0]**2 + pos[:, 2]**2) | |
| def length_vec(v): | |
| """Calculate vector length (vectorized)""" | |
| if v.dim() == 1: | |
| return torch.sqrt(torch.sum(v**2)) | |
| return torch.sqrt(torch.sum(v**2, dim=1)) | |
| def hash(p): | |
| """Pseudo-random hash function (vectorized)""" | |
| p = fract(p * 0.3183099 + 0.1) | |
| p = p * 17.0 | |
| return torch.stack([ | |
| fract(p[:, 0] * p[:, 1] * p[:, 2]), | |
| fract(p[:, 0] + p[:, 1] * p[:, 2]), | |
| fract(p[:, 0] * p[:, 1] + p[:, 2]) | |
| ], dim=1) | |
| def noise(p): | |
| """3D Perlin-style noise function (vectorized)""" | |
| i = torch.floor(p).to(torch.long) | |
| f = fract(p) | |
| f = f * f * (3.0 - 2.0 * f) | |
| def get_hash_offset(offset): | |
| return GSEffects.hash(i.to(p.dtype) + offset) | |
| n000 = get_hash_offset(torch.tensor([0, 0, 0], dtype=p.dtype, device=p.device)) | |
| n100 = get_hash_offset(torch.tensor([1, 0, 0], dtype=p.dtype, device=p.device)) | |
| n010 = get_hash_offset(torch.tensor([0, 1, 0], dtype=p.dtype, device=p.device)) | |
| n110 = get_hash_offset(torch.tensor([1, 1, 0], dtype=p.dtype, device=p.device)) | |
| n001 = get_hash_offset(torch.tensor([0, 0, 1], dtype=p.dtype, device=p.device)) | |
| n101 = get_hash_offset(torch.tensor([1, 0, 1], dtype=p.dtype, device=p.device)) | |
| n011 = get_hash_offset(torch.tensor([0, 1, 1], dtype=p.dtype, device=p.device)) | |
| n111 = get_hash_offset(torch.tensor([1, 1, 1], dtype=p.dtype, device=p.device)) | |
| x0 = GSEffects.mix(n000, n100, f[:, 0:1]) | |
| x1 = GSEffects.mix(n010, n110, f[:, 0:1]) | |
| x2 = GSEffects.mix(n001, n101, f[:, 0:1]) | |
| x3 = GSEffects.mix(n011, n111, f[:, 0:1]) | |
| y0 = GSEffects.mix(x0, x1, f[:, 1:2]) | |
| y1 = GSEffects.mix(x2, x3, f[:, 1:2]) | |
| return GSEffects.mix(y0, y1, f[:, 2:3]) | |
| def rot_2d(angle): | |
| """2D rotation (vectorized)""" | |
| if isinstance(angle, torch.Tensor): | |
| s = torch.sin(angle) | |
| c = torch.cos(angle) | |
| rot = torch.stack([torch.stack([c, -s], dim=-1), | |
| torch.stack([s, c], dim=-1)], dim=-2).squeeze() | |
| else: | |
| s = np.sin(angle) | |
| c = np.cos(angle) | |
| rot = torch.tensor([[c, -s], | |
| [s, c]]).cuda().float() | |
| return rot | |
| def twister(self, pos, scale, t): | |
| h = self.hash(pos)[:, 0:1] + 0.1 | |
| pos_xz_len = self.length_xz(pos) | |
| s = self.smoothstep(0.0, 8.0, t * t * 0.1 - pos_xz_len * 2.0 + 2.0)[:, None] | |
| mask = (torch.linalg.norm(scale, dim=-1, keepdim=True) < 0.05) | |
| pos_y = torch.where(mask, (-10. + pos[:, 1:2]) * (s ** (2 * h)), pos[:, 1:2]) | |
| pos_xz = pos[:, [0, 2]] * torch.exp(-1 * torch.linalg.norm(pos[:, [0, 2]], dim=-1, keepdim=True)) | |
| pos_xz = torch.einsum("n i, n i j -> n j", pos_xz, self.rot_2d(t * 0.2 + pos[:, 1:2] * 20. * (1 - s))) | |
| pos_new = torch.cat([pos_xz[:, 0:1], pos_y, pos_xz[:, 1:2]], dim=-1) | |
| return pos_new, s ** 4 | |
| def rain(self, pos, scale, t): | |
| h = self.hash(pos) | |
| pos_xz_len = self.length_xz(pos) | |
| s = self.smoothstep(0.0, 5.0, t * t * 0.1 - pos_xz_len * 2.0 + 1.0) ** (0.5 + h[:, 0]) | |
| y = pos[:, 1:2] | |
| pos_y = torch.minimum(-10. + s[:, None] * 15., pos[:, 1:2]) | |
| pos_x = pos[:, 0:1] + pos_y * 0.2 | |
| pos_xz = torch.cat([pos_x, pos[:, 2:3]], dim=-1) | |
| pos_xz = pos_xz * torch.matmul(self.rot_2d(t * 0.3), torch.ones_like(pos_xz).unsqueeze(-1)).squeeze(-1) | |
| pos_new = torch.cat([pos_xz[:, 0:1], pos_y, pos_xz[:, 1:2]], dim=-1) | |
| a = self.smoothstep(-10.0, y.squeeze(), pos_y.squeeze())[:, None] | |
| return pos_new, a | |
| def apply_effect(self, gsplat, t, effect_type, ignore_scale=False): | |
| """ | |
| Apply the effect shader logic (vectorized for batch processing) | |
| Args: | |
| gsplat: Dictionary with: | |
| 'means': (n, 3) tensor | |
| 'scales': (n, 3) tensor | |
| 'colors': (n, 3) tensor | |
| 'quats': (n, 4) tensor | |
| 'opacities': (n,) tensor | |
| t: Current time (normalized based on start_time and end_time) | |
| effect_type: 2=Spread | |
| Returns: | |
| Modified gsplat dictionary | |
| """ | |
| # Normalize time to animation range | |
| normalized_t = t - self.start_time | |
| device = gsplat['means'].device | |
| dtype = gsplat['means'].dtype | |
| output = { | |
| 'means': gsplat['means'].clone(), | |
| 'quats': gsplat['quats'].clone(), | |
| 'scales': gsplat['scales'].clone(), | |
| 'opacities': gsplat['opacities'].clone(), | |
| 'colors': gsplat['colors'].clone() | |
| } | |
| s = self.smoothstep(0.0, 10.0, normalized_t - 3.2) * 10.0 | |
| scales = output['scales'] | |
| local_pos = output['means'].clone() | |
| l = self.length_xz(local_pos) | |
| smoothstep_val = None | |
| if effect_type == 2: # Spread Effect | |
| border = torch.abs(s - l - 0.5) | |
| decay = 1.0 - 0.2 * torch.exp(-20.0 * border) | |
| # decay = 1.0 - 0.7 * torch.exp(-10.0 * border) | |
| local_pos = local_pos * decay[:, None] | |
| smoothstep_val = self.smoothstep(s - 0.5, s, l + 0.5) | |
| # final_scales = self.mix(scales, 0.002, smoothstep_val[:, None]) | |
| if not ignore_scale: | |
| final_scales = self.mix(scales, 1e-9, smoothstep_val[:, None]) | |
| else: | |
| final_scales = scales | |
| noise_input = torch.stack([ | |
| local_pos[:, 0] * 2.0 + normalized_t * 0.5, | |
| local_pos[:, 1] * 2.0 + normalized_t * 0.5, | |
| local_pos[:, 2] * 2.0 + normalized_t * 0.5 | |
| ], dim=1) | |
| noise_val = self.noise(noise_input) | |
| output['means'] = local_pos + 0.0 * noise_val * smoothstep_val[:, None] | |
| output['scales'] = final_scales | |
| at = torch.atan2(local_pos[:, 0], local_pos[:, 2]) / 3.1416 | |
| output['colors'] *= self.step(at, normalized_t - 3.1416)[:, None] | |
| output['colors'] += (torch.exp(-20.0 * border) + | |
| torch.exp(-50.0 * torch.abs(normalized_t - at - 3.1416)) * 0.5)[:, None] | |
| output['opacities'] *= self.step(at, normalized_t - 3.1416) | |
| output['opacities'] += (torch.exp(-20.0 * border) + | |
| torch.exp(-50.0 * torch.abs(normalized_t - at - 3.1416)) * 0.5) | |
| # ===== New feature: Randomly mask points based on smoothstep_val ===== | |
| # Higher smoothstep_val means higher probability of masking | |
| mask_prob = smoothstep_val.squeeze() if smoothstep_val.dim() > 1 else smoothstep_val | |
| if not hasattr(self, "random_vals"): | |
| self.random_vals = torch.rand(mask_prob.shape, device=device, dtype=dtype) | |
| mask = self.random_vals < mask_prob*0.8 # True indicates the point is masked | |
| # Apply mask to various attributes | |
| if not ignore_scale: | |
| output['means'][mask] *= 0 # Or can be set to other values | |
| output['scales'][mask] *= 0 # Set scales to 0 to make points invisible | |
| output['opacities'][mask] *= 0 # Set opacity to 0 to make points transparent | |
| return output, smoothstep_val | |
| # Usage example | |
| if __name__ == "__main__": | |
| # Create effects processor with time range from 0 to 10 seconds | |
| effects = GSEffects(start_time=0.0, end_time=10.0) | |
| # Sample gsplat data (batch) | |
| n_points = 100 | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| sample_gsplat = { | |
| 'means': torch.randn(n_points, 3, dtype=torch.float32, device=device), | |
| 'quats': torch.randn(n_points, 4, dtype=torch.float32, device=device), | |
| 'scales': torch.rand(n_points, 3, dtype=torch.float32, device=device), | |
| 'opacities': torch.rand(n_points, dtype=torch.float32, device=device), | |
| 'colors': torch.rand(n_points, 3, dtype=torch.float32, device=device) | |
| } | |
| # Apply Magic effect at different time points | |
| for t in [0.0, 2.5, 5.0, 7.5, 10.0]: | |
| result = effects.apply_effect(sample_gsplat, t, effect_type=2) | |
| print(f"\nTime: {t}s") | |
| print(f"Center shape: {result['means'].shape}") | |
| print(f"Center[0]: {result['means'][0]}") | |
| print(f"Scales shape: {result['scales'].shape}") | |
| print(f"Scales[0]: {result['scales'][0]}") | |
| print(f"RGB shape: {result['colors'].shape}") | |
| print(f"RGB[0]: {result['colors'][0]}") | |
| print(f"Opacity shape: {result['opacities'].shape}") | |
| print(f"Opacity[0]: {result['opacities'][0]}") |