# ZhenweiWang's picture
# Upload folder using huggingface_hub
# 0ca05b5 verified
from math import atan2, cos, exp, floor, sin, sqrt
import numpy as np
import torch
def fract(x):
    """Return ``x`` minus its floor, i.e. the fractional part of ``x``.

    Tensors are handled element-wise with ``torch.floor``; any other value
    goes through ``math.floor``.
    """
    whole = torch.floor(x) if isinstance(x, torch.Tensor) else floor(x)
    return x - whole
class GSEffects:
    """Convert GLSL GS render effects to PyTorch - vectorized for batch processing.

    The static helpers mirror GLSL built-ins (``smoothstep``, ``step``,
    ``mix``, ``clamp``) and add a hash-based lattice noise. ``twister``,
    ``rain`` and ``apply_effect`` use them to animate per-point splat
    attributes as a function of time.
    """

    def __init__(self, start_time=0.0, end_time=10.0):
        """
        Initialize effects with time range

        Args:
            start_time: Animation start time (subtracted from ``t`` in
                ``apply_effect``)
            end_time: Animation end time (stored; not read by any method
                of this class)
        """
        self.start_time = start_time
        self.end_time = end_time

    @staticmethod
    def smoothstep(edge0, edge1, x):
        """GLSL smoothstep function (vectorized).

        Args:
            edge0: Lower edge (scalar, or tensor broadcastable against ``x``).
            edge1: Upper edge (scalar, or tensor broadcastable against ``x``).
            x: Value(s) to evaluate; a tensor selects the vectorized path.

        Returns:
            0 where ``x < edge0``, 1 where ``x > edge1``, and the Hermite
            polynomial ``t*t*(3-2t)`` with ``t=(x-edge0)/(edge1-edge0)``
            in between. A tensor for tensor ``x``, otherwise a float.
        """
        if isinstance(x, torch.Tensor):
            # Evaluate the polynomial everywhere, then override both tails.
            # Using torch.where (instead of boolean-index assignment) lets
            # edge0/edge1 be per-element tensors, as rain() needs.
            t = (x - edge0) / (edge1 - edge0)
            result = t * t * (3.0 - 2.0 * t)
            result = torch.where(x < edge0, torch.zeros_like(result), result)
            # The upper tail is applied last so it wins if both tests hold.
            return torch.where(x > edge1, torch.ones_like(result), result)

        if x < edge0:
            return 0.0
        if x > edge1:
            return 1.0
        t = (x - edge0) / (edge1 - edge0)
        return t * t * (3.0 - 2.0 * t)

    @staticmethod
    def step(edge, x):
        """GLSL step function (vectorized): 1 where ``x >= edge``, else 0.

        The result takes the dtype of whichever argument is a tensor
        (``x`` first, then ``edge``); with two scalars a float is returned.
        """
        if isinstance(x, torch.Tensor):
            return (x >= edge).to(x.dtype)
        if isinstance(edge, torch.Tensor):
            return (x >= edge).to(edge.dtype)
        return 1.0 if x >= edge else 0.0

    @staticmethod
    def mix(x, y, a):
        """GLSL mix function (linear interpolation, vectorized)."""
        return x * (1.0 - a) + y * a

    @staticmethod
    def clamp(x, min_val, max_val):
        """Clamp value between min and max (vectorized)."""
        if isinstance(x, torch.Tensor):
            return torch.clamp(x, min_val, max_val)
        return max(min_val, min(max_val, x))

    @staticmethod
    def length_xz(pos):
        """Length of the (x, z) components, i.e. columns 0 and 2.

        Accepts a single 1-D position or a 2-D batch with one row per point.
        """
        if pos.dim() == 1:
            return torch.sqrt(pos[0]**2 + pos[2]**2)
        return torch.sqrt(pos[:, 0]**2 + pos[:, 2]**2)

    @staticmethod
    def length_vec(v):
        """Euclidean length of a 1-D vector, or of each row of a 2-D batch."""
        if v.dim() == 1:
            return torch.sqrt(torch.sum(v**2))
        return torch.sqrt(torch.sum(v**2, dim=1))

    @staticmethod
    def hash(p):
        """Pseudo-random hash function (vectorized).

        Args:
            p: 2-D tensor whose first three columns are hashed.

        Returns:
            Tensor with three columns per row, every entry in [0, 1).
        """
        p = p * 0.3183099 + 0.1
        p = (p - torch.floor(p)) * 17.0  # fractional part, rescaled
        x, y, z = p[:, 0], p[:, 1], p[:, 2]
        mixed = torch.stack([x * y * z, x + y * z, x * y + z], dim=1)
        return mixed - torch.floor(mixed)  # keep only the fractional part

    @staticmethod
    def noise(p):
        """Lattice noise: hashed cell corners blended trilinearly (vectorized).

        Args:
            p: (n, 3) tensor of sample positions.

        Returns:
            (n, 3) tensor; each row interpolates the ``hash`` vectors of the
            eight corners of the unit cell containing the sample, using
            smoothstep-shaped weights.
        """
        cell = torch.floor(p)
        frac = p - cell
        weight = frac * frac * (3.0 - 2.0 * frac)

        def corner(dx, dy, dz):
            offset = torch.tensor([dx, dy, dz], dtype=p.dtype, device=p.device)
            return GSEffects.hash(cell + offset)

        wx, wy, wz = weight[:, 0:1], weight[:, 1:2], weight[:, 2:3]
        blend = GSEffects.mix

        # Interpolate along x, then y, then z.
        x0 = blend(corner(0, 0, 0), corner(1, 0, 0), wx)
        x1 = blend(corner(0, 1, 0), corner(1, 1, 0), wx)
        x2 = blend(corner(0, 0, 1), corner(1, 0, 1), wx)
        x3 = blend(corner(0, 1, 1), corner(1, 1, 1), wx)
        y0 = blend(x0, x1, wy)
        y1 = blend(x2, x3, wy)
        return blend(y0, y1, wz)

    @staticmethod
    def rot_2d(angle, device=None, dtype=torch.float32):
        """2D rotation matrix ``[[c, -s], [s, c]]`` (vectorized).

        Args:
            angle: Tensor of angles or a plain number, in radians.
            device: Target device for the scalar path. ``None`` picks CUDA
                when it is available and the CPU otherwise. Ignored for
                tensor angles.
            dtype: Result dtype for the scalar path. Ignored for tensor
                angles.

        Returns:
            For a tensor angle, the stacked matrices with every size-1
            dimension squeezed away; for a scalar, a (2, 2) tensor.
        """
        if isinstance(angle, torch.Tensor):
            s = torch.sin(angle)
            c = torch.cos(angle)
            top = torch.stack([c, -s], dim=-1)
            bottom = torch.stack([s, c], dim=-1)
            return torch.stack([top, bottom], dim=-2).squeeze()

        if device is None:
            # Prefer the GPU when there is one, but do not crash on
            # CPU-only machines.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        s = float(np.sin(angle))
        c = float(np.cos(angle))
        return torch.tensor([[c, -s],
                             [s, c]], dtype=dtype, device=device)

    def twister(self, pos, scale, t):
        """Twister effect: swirl points around the y axis over time.

        Args:
            pos: (n, 3) point positions.
            scale: (n, k) per-point scales; only points whose scale norm is
                below 0.05 have their height animated.
            t: Current time.

        Returns:
            Tuple ``(pos_new, s ** 4)`` where ``pos_new`` is (n, 3) and the
            second item is the (n, 1) progress factor raised to the 4th
            power.
        """
        h = self.hash(pos)[:, 0:1] + 0.1
        pos_xz_len = self.length_xz(pos)
        s = self.smoothstep(0.0, 8.0, t * t * 0.1 - pos_xz_len * 2.0 + 2.0)[:, None]

        small = torch.linalg.norm(scale, dim=-1, keepdim=True) < 0.05
        pos_y = torch.where(small, (-10. + pos[:, 1:2]) * (s ** (2 * h)), pos[:, 1:2])

        xz = pos[:, [0, 2]]
        pos_xz = xz * torch.exp(-1 * torch.linalg.norm(xz, dim=-1, keepdim=True))

        angle = t * 0.2 + pos[:, 1:2] * 20. * (1 - s)
        # rot_2d squeezes every size-1 axis; the reshape restores the batch
        # axis so that a single-point input still works with the einsum.
        rot = self.rot_2d(angle.squeeze(-1)).reshape(-1, 2, 2)
        pos_xz = torch.einsum("n i, n i j -> n j", pos_xz, rot)

        pos_new = torch.cat([pos_xz[:, 0:1], pos_y, pos_xz[:, 1:2]], dim=-1)
        return pos_new, s ** 4

    def rain(self, pos, scale, t):
        """Rain effect: heights are capped by a time-dependent ceiling.

        Args:
            pos: (n, 3) point positions.
            scale: Unused; kept for a signature parallel to ``twister``.
            t: Current time.

        Returns:
            Tuple ``(pos_new, a)`` where ``pos_new`` is (n, 3) and ``a`` is
            an (n, 1) smoothstep of the new height between -10 and the
            original height.
        """
        h = self.hash(pos)
        pos_xz_len = self.length_xz(pos)
        s = self.smoothstep(0.0, 5.0, t * t * 0.1 - pos_xz_len * 2.0 + 1.0) ** (0.5 + h[:, 0])

        y = pos[:, 1:2]
        pos_y = torch.minimum(-10. + s[:, None] * 15., y)
        pos_x = pos[:, 0:1] + pos_y * 0.2
        pos_xz = torch.cat([pos_x, pos[:, 2:3]], dim=-1)

        # Build the matrix on the same device/dtype as the points so the
        # matmul below works on CPU as well as on GPU.
        rot = self.rot_2d(t * 0.3, device=pos.device, dtype=pos.dtype)
        pos_xz = pos_xz * torch.matmul(rot, torch.ones_like(pos_xz).unsqueeze(-1)).squeeze(-1)

        pos_new = torch.cat([pos_xz[:, 0:1], pos_y, pos_xz[:, 1:2]], dim=-1)
        # squeeze(-1) keeps the batch axis even for a single point.
        a = self.smoothstep(-10.0, y.squeeze(-1), pos_y.squeeze(-1))[:, None]
        return pos_new, a

    def apply_effect(self, gsplat, t, effect_type, ignore_scale=False):
        """
        Apply the effect shader logic (vectorized for batch processing)

        Args:
            gsplat: Dictionary with:
                'means': (n, 3) tensor
                'scales': (n, 3) tensor
                'colors': (n, 3) tensor
                'quats': (n, 4) tensor
                'opacities': (n,) tensor
            t: Current time; ``start_time`` is subtracted before use
            effect_type: 2=Spread (the only implemented effect)
            ignore_scale: When True, scales are passed through unchanged
                and the random point masking is skipped

        Returns:
            Tuple ``(output, smoothstep_val)``: the modified copy of
            ``gsplat`` (inputs are cloned, never mutated) and the (n,)
            per-point transition factor in [0, 1].

        Raises:
            ValueError: If ``effect_type`` is not 2.
        """
        if effect_type != 2:
            # Every later step depends on values only the Spread effect
            # defines, so fail early with a clear message.
            raise ValueError(
                f"Unsupported effect_type {effect_type!r}; only 2 (Spread) is implemented"
            )

        # Shift time so the animation starts at zero.
        normalized_t = t - self.start_time

        device = gsplat['means'].device
        dtype = gsplat['means'].dtype

        output = {
            'means': gsplat['means'].clone(),
            'quats': gsplat['quats'].clone(),
            'scales': gsplat['scales'].clone(),
            'opacities': gsplat['opacities'].clone(),
            'colors': gsplat['colors'].clone()
        }

        s = self.smoothstep(0.0, 10.0, normalized_t - 3.2) * 10.0
        scales = output['scales']
        local_pos = output['means'].clone()
        l = self.length_xz(local_pos)

        # Spread Effect
        border = torch.abs(s - l - 0.5)
        glow = torch.exp(-20.0 * border)  # peaks where l == s - 0.5
        decay = 1.0 - 0.2 * glow
        local_pos = local_pos * decay[:, None]
        smoothstep_val = self.smoothstep(s - 0.5, s, l + 0.5)

        if not ignore_scale:
            final_scales = self.mix(scales, 1e-9, smoothstep_val[:, None])
        else:
            final_scales = scales

        noise_input = local_pos[:, :3] * 2.0 + normalized_t * 0.5
        noise_val = self.noise(noise_input)
        # The noise displacement is multiplied by 0.0, i.e. currently disabled.
        output['means'] = local_pos + 0.0 * noise_val * smoothstep_val[:, None]
        output['scales'] = final_scales

        at = torch.atan2(local_pos[:, 0], local_pos[:, 2]) / 3.1416
        sweep = self.step(at, normalized_t - 3.1416)
        highlight = glow + torch.exp(-50.0 * torch.abs(normalized_t - at - 3.1416)) * 0.5

        output['colors'] *= sweep[:, None]
        output['colors'] += highlight[:, None]
        output['opacities'] *= sweep
        output['opacities'] += highlight

        # Randomly mask points: a higher smoothstep_val means a higher
        # probability of masking.
        mask_prob = smoothstep_val.squeeze() if smoothstep_val.dim() > 1 else smoothstep_val

        # The random thresholds are cached on the instance so the same points
        # stay masked across calls; they are redrawn whenever the cached
        # tensor no longer matches the current point count, device or dtype.
        random_vals = getattr(self, "random_vals", None)
        if (
            random_vals is None
            or random_vals.shape != mask_prob.shape
            or random_vals.device != device
            or random_vals.dtype != dtype
        ):
            random_vals = torch.rand(mask_prob.shape, device=device, dtype=dtype)
            self.random_vals = random_vals

        mask = random_vals < mask_prob * 0.8  # True indicates the point is masked

        # NOTE(review): the reviewed source had lost its indentation; all
        # three masked writes are assumed to be guarded by ignore_scale --
        # confirm against the upstream file.
        if not ignore_scale:
            output['means'][mask] *= 0  # Or can be set to other values
            output['scales'][mask] *= 0  # Set scales to 0 to make points invisible
            output['opacities'][mask] *= 0  # Set opacity to 0 to make points transparent

        return output, smoothstep_val
# Usage example
if __name__ == "__main__":
    # Create effects processor with time range from 0 to 10 seconds
    effects = GSEffects(start_time=0.0, end_time=10.0)

    # Sample gsplat data (batch)
    n_points = 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sample_gsplat = {
        'means': torch.randn(n_points, 3, dtype=torch.float32, device=device),
        'quats': torch.randn(n_points, 4, dtype=torch.float32, device=device),
        'scales': torch.rand(n_points, 3, dtype=torch.float32, device=device),
        'opacities': torch.rand(n_points, dtype=torch.float32, device=device),
        'colors': torch.rand(n_points, 3, dtype=torch.float32, device=device)
    }

    # Apply the Spread effect (effect_type=2) at different time points
    for t in [0.0, 2.5, 5.0, 7.5, 10.0]:
        # apply_effect returns (output_dict, smoothstep_val); only the
        # dictionary is printed here.
        result, _ = effects.apply_effect(sample_gsplat, t, effect_type=2)
        print(f"\nTime: {t}s")
        print(f"Center shape: {result['means'].shape}")
        print(f"Center[0]: {result['means'][0]}")
        print(f"Scales shape: {result['scales'].shape}")
        print(f"Scales[0]: {result['scales'][0]}")
        print(f"RGB shape: {result['colors'].shape}")
        print(f"RGB[0]: {result['colors'][0]}")
        print(f"Opacity shape: {result['opacities'].shape}")
        print(f"Opacity[0]: {result['opacities'][0]}")