# inspired by https://github.com/facebookresearch/vggt/blob/main/src/models/heads/camera_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.layers import Mlp
from src.models.layers.block import Block


class CameraHead(nn.Module):
    """
    Camera head module: predicts camera parameters from token representations using iterative refinement.

    Processes dedicated camera tokens through a series of transformer blocks.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",
    ):
        super().__init__()

        self.out_dim = 9

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.depth = trunk_depth

        # Build refinement network using transformer block sequence
        self.refine_net = nn.Sequential(
            *[
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            ]
        )

        # Normalization for camera tokens and network output
        self.token_norm = nn.LayerNorm(dim_in)
        self.out_norm = nn.LayerNorm(dim_in)

        # Learnable initial camera parameter token
        self.init_token = nn.Parameter(torch.zeros(1, 1, self.out_dim))
        self.param_embed = nn.Linear(self.out_dim, dim_in)

        # Generate adaptive normalization parameters: shift, scale, and gate
        self.adapt_norm_gen = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization (no learnable parameters)
        self.adapt_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.param_predictor = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)

    def forward(self, feat_seq: list, steps: int = 4) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            feat_seq: List of token tensors from the network; the last one is used for prediction
            steps: Number of iterative refinement steps, default 4

        Returns:
            List of predicted camera encodings (post-activation) from each iteration
        """
        # Use tokens from the last block for camera prediction
        latest_feat = feat_seq[-1]

        # Extract camera tokens
        cam_tokens = latest_feat[:, :, 0]
        cam_tokens = self.token_norm(cam_tokens)

        # Iteratively refine camera pose predictions
        b, seq_len, feat_dim = cam_tokens.shape  # seq_len expected to be 1
        curr_pred = None
        pred_seq = []

        for step in range(steps):
            # Use learned initial token for the first iteration
            if curr_pred is None:
                net_input = self.param_embed(self.init_token.expand(b, seq_len, -1))
            else:
                curr_pred = curr_pred.detach()
                net_input = self.param_embed(curr_pred)

            norm_shift, norm_scale, norm_gate = self.adapt_norm_gen(net_input).chunk(3, dim=-1)
            mod_cam_feat = norm_gate * self.apply_adaptive_modulation(
                self.adapt_norm(cam_tokens), norm_shift, norm_scale
            )
            mod_cam_feat = mod_cam_feat + cam_tokens

            proc_feat = self.refine_net(mod_cam_feat)
            param_delta = self.param_predictor(self.out_norm(proc_feat))

            if curr_pred is None:
                curr_pred = param_delta
            else:
                curr_pred = curr_pred + param_delta

            # Apply final activation functions for translation, quaternion, and field-of-view
            activated_params = self.apply_camera_parameter_activation(curr_pred)
            pred_seq.append(activated_params)

        return pred_seq

    def apply_camera_parameter_activation(self, params: torch.Tensor) -> torch.Tensor:
        """
        Apply activation functions to camera parameter components.

        Args:
            params: Tensor containing camera parameters [translation, quaternion, focal_length]

        Returns:
            Activated camera parameters tensor
        """
        trans_vec = params[..., :3]
        quat_vec = params[..., 3:7]
        fl_vec = params[..., 7:]  # focal length, or field of view

        trans_vec = self.apply_parameter_activation(trans_vec, self.trans_act)
        quat_vec = self.apply_parameter_activation(quat_vec, self.quat_act)
        fl_vec = self.apply_parameter_activation(fl_vec, self.fl_act)

        activated_params = torch.cat([trans_vec, quat_vec, fl_vec], dim=-1)
        return activated_params

    def apply_parameter_activation(self, tensor: torch.Tensor, act_type: str) -> torch.Tensor:
        """
        Apply the specified activation function to a parameter tensor.

        Args:
            tensor: Tensor containing parameter values
            act_type: Activation type ("linear", "inv_log", "exp", "relu")

        Returns:
            Activated parameter tensor
        """
        if act_type == "linear":
            return tensor
        elif act_type == "inv_log":
            return self.apply_inverse_logarithm_transform(tensor)
        elif act_type == "exp":
            return torch.exp(tensor)
        elif act_type == "relu":
            return F.relu(tensor)
        else:
            raise ValueError(f"Unknown act_type: {act_type}")

    def apply_inverse_logarithm_transform(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply inverse logarithm transform: sign(x) * (exp(|x|) - 1).

        Args:
            x: Input tensor

        Returns:
            Transformed tensor
        """
        return torch.sign(x) * torch.expm1(torch.abs(x))

    def apply_adaptive_modulation(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """
        Apply adaptive modulation to the input tensor using scaling and shifting parameters.
        """
        # Modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
        return x * (1 + scale) + shift
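

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). It assumes the
# project-local `Block` and `Mlp` layers imported above are available, and
# that the backbone emits tokens of shape (B, S, P, C) with the camera token
# at patch index 0, matching the slicing in `forward`. The small `dim_in` and
# dummy shapes below are arbitrary choices for a quick smoke test.
if __name__ == "__main__":
    head = CameraHead(dim_in=128, trunk_depth=2, num_heads=4)

    batch, frames, patches, channels = 2, 1, 5, 128
    tokens = torch.randn(batch, frames, patches, channels)

    # `forward` only consumes the last element of the feature list.
    preds = head([tokens], steps=4)

    # One prediction per refinement step, each of shape (batch, frames, 9):
    # 3 translation + 4 quaternion + 2 focal-length / field-of-view values.
    for i, p in enumerate(preds):
        print(f"step {i}: {tuple(p.shape)}")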