# inspired by https://github.com/facebookresearch/vggt/blob/main/src/models/heads/camera_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.layers import Mlp
from src.models.layers.block import Block


class CameraHead(nn.Module):
    """
    Camera head module: predicts camera parameters from token representations using iterative refinement
    
    Processes dedicated camera tokens through a series of transformer blocks
    """
    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",
    ):
        super().__init__()

        self.out_dim = 9
        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.depth = trunk_depth

        # Build refinement network using transformer block sequence
        self.refine_net = nn.Sequential(
            *[
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            ]
        )

        # Normalization for camera tokens and network output
        self.token_norm = nn.LayerNorm(dim_in)
        self.out_norm = nn.LayerNorm(dim_in)

        # Learnable initial camera parameter token
        self.init_token = nn.Parameter(torch.zeros(1, 1, self.out_dim))
        self.param_embed = nn.Linear(self.out_dim, dim_in)

        # Generate adaptive normalization parameters: shift, scale, and gate
        self.adapt_norm_gen = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization (no learnable parameters)
        self.adapt_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.param_predictor = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)

    def forward(self, feat_seq: list, steps: int = 4) -> list:
        """
        Forward pass to predict camera parameters
        
        Args:
            feat_seq: List of token tensors from network, last one used for prediction
            steps: Number of iterative refinement steps, default 4
            
        Returns:
            List of predicted camera encodings (post-activation) from each iteration
        """
        # Use tokens from last block for camera prediction
        latest_feat = feat_seq[-1]

        # Extract camera tokens (stored at patch index 0): (B, S, N, C) -> (B, S, C)
        cam_tokens = latest_feat[:, :, 0]
        cam_tokens = self.token_norm(cam_tokens)

        # Iteratively refine camera pose predictions
        b, seq_len, feat_dim = cam_tokens.shape  # seq_len expected to be 1
        curr_pred = None
        pred_seq = []

        for step in range(steps):
            # Use learned initial token for first iteration
            if curr_pred is None:
                net_input = self.param_embed(self.init_token.expand(b, seq_len, -1))
            else:
                # Detach the previous estimate so gradients do not flow across refinement steps
                curr_pred = curr_pred.detach()
                net_input = self.param_embed(curr_pred)
            # adaLN-style modulation (see DiT reference below): shift/scale condition the
            # normalized camera tokens on the current estimate; the gate scales the residual branch
            norm_shift, norm_scale, norm_gate = self.adapt_norm_gen(net_input).chunk(3, dim=-1)
            mod_cam_feat = norm_gate * self.apply_adaptive_modulation(self.adapt_norm(cam_tokens), norm_shift, norm_scale)
            mod_cam_feat = mod_cam_feat + cam_tokens

            proc_feat = self.refine_net(mod_cam_feat)
            param_delta = self.param_predictor(self.out_norm(proc_feat))

            if curr_pred is None:
                curr_pred = param_delta
            else:
                curr_pred = curr_pred + param_delta

            # Apply final activations to the translation, quaternion, and focal-length components
            activated_params = self.apply_camera_parameter_activation(curr_pred)
            pred_seq.append(activated_params)

        return pred_seq

    def apply_camera_parameter_activation(self, params: torch.Tensor) -> torch.Tensor:
        """
        Apply activation functions to camera parameter components
        
        Args:
            params: Tensor containing camera parameters [translation, quaternion, focal_length]
            
        Returns:
            Activated camera parameters tensor
        """
        trans_vec = params[..., :3]
        quat_vec = params[..., 3:7]
        fl_vec = params[..., 7:]  # or field of view

        trans_vec = self.apply_parameter_activation(trans_vec, self.trans_act)
        quat_vec = self.apply_parameter_activation(quat_vec, self.quat_act)
        fl_vec = self.apply_parameter_activation(fl_vec, self.fl_act)

        activated_params = torch.cat([trans_vec, quat_vec, fl_vec], dim=-1)
        return activated_params

    def apply_parameter_activation(self, tensor: torch.Tensor, act_type: str) -> torch.Tensor:
        """
        Apply specified activation function to parameter tensor
        
        Args:
            tensor: Tensor containing parameter values
            act_type: Activation type ("linear", "inv_log", "exp", "relu")
            
        Returns:
            Activated parameter tensor
        """
        if act_type == "linear":
            return tensor
        elif act_type == "inv_log":
            return self.apply_inverse_logarithm_transform(tensor)
        elif act_type == "exp":
            return torch.exp(tensor)
        elif act_type == "relu":
            return F.relu(tensor)
        else:
            raise ValueError(f"Unknown activation type: {act_type}")

    def apply_inverse_logarithm_transform(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse logarithm transform: sign(x) * (exp(|x|) - 1),
        the inverse of the signed log mapping y = sign(x) * log(1 + |x|).

        Args:
            x: Input tensor

        Returns:
            Transformed tensor
        """
        return torch.sign(x) * torch.expm1(torch.abs(x))

    def apply_adaptive_modulation(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """
        Apply adaptive modulation to input tensor using scaling and shifting parameters
        """
        # Modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
        return x * (1 + scale) + shift
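

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). The input
    # shape (batch, frames, tokens, dim) with the camera token at patch index 0
    # is an assumption, chosen to match the indexing in forward() above.
    B, S, N, C = 2, 4, 17, 2048
    head = CameraHead(dim_in=C)
    dummy_tokens = torch.randn(B, S, N, C)

    with torch.no_grad():
        preds = head([dummy_tokens], steps=4)

    # One prediction per refinement step, each of shape (B, S, 9):
    # 3 translation + 4 quaternion + 2 focal-length channels
    print(len(preds), preds[-1].shape)  # expected: 4 torch.Size([2, 4, 9])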