# inspired by https://github.com/facebookresearch/vggt/blob/main/src/models/heads/camera_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.layers import Mlp
from src.models.layers.block import Block


class CameraHead(nn.Module):
"""
Camera head module: predicts camera parameters from token representations using iterative refinement
Processes dedicated camera tokens through a series of transformer blocks
"""
def __init__(
self,
dim_in: int = 2048,
trunk_depth: int = 4,
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
trans_act: str = "linear",
quat_act: str = "linear",
fl_act: str = "relu",
):
super().__init__()
        self.out_dim = 9  # 3-D translation + 4-D quaternion + 2-D focal length / field of view
self.trans_act = trans_act
self.quat_act = quat_act
self.fl_act = fl_act
self.depth = trunk_depth
# Build refinement network using transformer block sequence
self.refine_net = nn.Sequential(
*[
Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
for _ in range(trunk_depth)
]
)
# Normalization for camera tokens and network output
self.token_norm = nn.LayerNorm(dim_in)
self.out_norm = nn.LayerNorm(dim_in)
# Learnable initial camera parameter token
self.init_token = nn.Parameter(torch.zeros(1, 1, self.out_dim))
self.param_embed = nn.Linear(self.out_dim, dim_in)
# Generate adaptive normalization parameters: shift, scale, and gate
self.adapt_norm_gen = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
# Adaptive layer normalization (no learnable parameters)
self.adapt_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
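        # MLP head that maps refined features to a 9-D camera parameter (delta) vector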
        self.param_predictor = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)

    def forward(self, feat_seq: list, steps: int = 4) -> list:
"""
Forward pass to predict camera parameters
Args:
feat_seq: List of token tensors from network, last one used for prediction
steps: Number of iterative refinement steps, default 4
Returns:
List of predicted camera encodings (post-activation) from each iteration
"""
# Use tokens from last block for camera prediction
latest_feat = feat_seq[-1]
        # Extract the dedicated camera token (index 0 of each frame's token sequence)
        cam_tokens = latest_feat[:, :, 0]
cam_tokens = self.token_norm(cam_tokens)
# Iteratively refine camera pose predictions
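        # Each step embeds the current estimate, conditions the camera tokens on it via
        # adaptive layer norm, and predicts a delta that updates the estimate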
b, seq_len, feat_dim = cam_tokens.shape # seq_len expected to be 1
curr_pred = None
pred_seq = []
for step in range(steps):
# Use learned initial token for first iteration
if curr_pred is None:
net_input = self.param_embed(self.init_token.expand(b, seq_len, -1))
            else:
                # Stop gradients from flowing through earlier refinement steps
                curr_pred = curr_pred.detach()
                net_input = self.param_embed(curr_pred)
            norm_shift, norm_scale, norm_gate = self.adapt_norm_gen(net_input).chunk(3, dim=-1)
            # Gate the adaptively modulated camera tokens and add a residual connection
            mod_cam_feat = norm_gate * self.apply_adaptive_modulation(
                self.adapt_norm(cam_tokens), norm_shift, norm_scale
            )
            mod_cam_feat = mod_cam_feat + cam_tokens
proc_feat = self.refine_net(mod_cam_feat)
param_delta = self.param_predictor(self.out_norm(proc_feat))
if curr_pred is None:
curr_pred = param_delta
else:
curr_pred = curr_pred + param_delta
# Apply final activation functions for translation, quaternion, and field-of-view
activated_params = self.apply_camera_parameter_activation(curr_pred)
pred_seq.append(activated_params)
        return pred_seq

    def apply_camera_parameter_activation(self, params: torch.Tensor) -> torch.Tensor:
"""
Apply activation functions to camera parameter components
Args:
params: Tensor containing camera parameters [translation, quaternion, focal_length]
Returns:
Activated camera parameters tensor
"""
trans_vec = params[..., :3]
quat_vec = params[..., 3:7]
        fl_vec = params[..., 7:]  # focal length (or field of view)
trans_vec = self.apply_parameter_activation(trans_vec, self.trans_act)
quat_vec = self.apply_parameter_activation(quat_vec, self.quat_act)
fl_vec = self.apply_parameter_activation(fl_vec, self.fl_act)
activated_params = torch.cat([trans_vec, quat_vec, fl_vec], dim=-1)
        return activated_params

    def apply_parameter_activation(self, tensor: torch.Tensor, act_type: str) -> torch.Tensor:
"""
Apply specified activation function to parameter tensor
Args:
tensor: Tensor containing parameter values
act_type: Activation type ("linear", "inv_log", "exp", "relu")
Returns:
Activated parameter tensor
"""
if act_type == "linear":
return tensor
elif act_type == "inv_log":
return self.apply_inverse_logarithm_transform(tensor)
elif act_type == "exp":
return torch.exp(tensor)
elif act_type == "relu":
return F.relu(tensor)
else:
            raise ValueError(f"Unknown act_type: {act_type}")

    def apply_inverse_logarithm_transform(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply inverse logarithm transform: sign(y) * (exp(|y|) - 1)
Args:
x: Input tensor
Returns:
Transformed tensor
"""
        return torch.sign(x) * torch.expm1(torch.abs(x))

    def apply_adaptive_modulation(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""
        Apply adaptive modulation to the input tensor using scaling and shifting parameters.
"""
# Modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
return x * (1 + scale) + shift
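

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustration, not part of the original module).
    # It assumes this file is run from the repo root so that src.models.layers
    # resolves. The tensor shapes below are assumptions chosen for demonstration:
    # batch of 2, a single frame (seq_len 1, as the forward pass expects), 8 tokens
    # per frame with the camera token at index 0, and dim_in = 2048.
    head = CameraHead(dim_in=2048)
    dummy_tokens = torch.randn(2, 1, 8, 2048)  # [batch, frames, tokens, dim]
    preds = head([dummy_tokens], steps=4)
    # One 9-D camera encoding per refinement step:
    # len(preds) == 4 and each entry has shape [2, 1, 9]
    print(len(preds), preds[-1].shape)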