# ROSE - rose/models/diff_mask_predictor.py
# (source: weiyuchoumou526, initial commit 080c0c2, 1.7 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
class DiffMaskPredictor(nn.Module):
    """Predict a per-pixel diff mask from concatenated transformer patch features.

    A small token-wise MLP projects each patch token to a single scalar logit;
    the resulting coarse [F_p, H_p, W_p] logit volume is then trilinearly
    upsampled to the full [F, H, W] output resolution. The output is raw
    logits — no activation is applied.
    """

    def __init__(self, input_dim=4608, patch_grid=(10, 15, 189), output_grid=(81, 480, 720), hidden_dim=256):
        """
        Args:
            input_dim (int): concatenated feature dimension, e.g. 1536 * num_selected_layers
            patch_grid (tuple): (F_p, H_p, W_p) - patch token grid shape (e.g., from transformer block)
            output_grid (tuple): (F, H, W) - final full resolution shape for mask
            hidden_dim (int): intermediate conv/linear hidden dim
        """
        super().__init__()
        self.F_p, self.H_p, self.W_p = patch_grid
        self.F, self.H, self.W = output_grid
        # Token-wise 2-layer MLP mapping each D-dim feature vector to one scalar logit.
        self.project = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): shape [B, L, D_total], where L = F_p * H_p * W_p
        Returns:
            Tensor: predicted diff mask logits, shape [B, 1, F, H, W]
        Raises:
            ValueError: if the token length L does not match the patch grid.
        """
        B, L, D = x.shape
        # Validate with an explicit exception rather than `assert`, which is
        # silently stripped when Python runs with the -O flag.
        expected = self.F_p * self.H_p * self.W_p
        if L != expected:
            raise ValueError(
                f"Input token length {L} doesn't match patch grid "
                f"({self.F_p}, {self.H_p}, {self.W_p})"
            )
        x = self.project(x)                             # [B, L, 1]
        x = x.view(B, 1, self.F_p, self.H_p, self.W_p)  # [B, 1, F_p, H_p, W_p]
        # Upsample the coarse logit volume to the ground-truth resolution.
        x = F.interpolate(
            x, size=(self.F, self.H, self.W),
            mode="trilinear", align_corners=False,
        )
        return x                                        # [B, 1, F, H, W]