import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config


class DiffMaskPredictor(nn.Module):
    def __init__(self, input_dim=4608, patch_grid=(10, 15, 189), output_grid=(81, 480, 720), hidden_dim=256):
        """
        Args:
            input_dim (int): concatenated feature dimension, e.g. 1536 * num_selected_layers
            patch_grid (tuple): (F_p, H_p, W_p) - patch token grid shape (e.g., from a transformer block)
            output_grid (tuple): (F, H, W) - final full-resolution shape of the mask
            hidden_dim (int): intermediate linear hidden dim
        """
        super().__init__()
        self.F_p, self.H_p, self.W_p = patch_grid
        self.F, self.H, self.W = output_grid
        # Per-token projection: D_total -> hidden_dim -> scalar mask logit
        self.project = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): shape [B, L, D_total], where L = F_p * H_p * W_p
        Returns:
            Tensor: predicted diff mask, shape [B, 1, F, H, W]
        """
        B, L, D = x.shape
        assert L == self.F_p * self.H_p * self.W_p, \
            f"Input token length {L} doesn't match patch grid ({self.F_p}, {self.H_p}, {self.W_p})"
        x = self.project(x)  # [B, L, 1]
        x = x.view(B, 1, self.F_p, self.H_p, self.W_p)  # [B, 1, F_p, H_p, W_p]
        # Upsample the patch-level mask to the full ground-truth resolution
        x = F.interpolate(
            x, size=(self.F, self.H, self.W),
            mode="trilinear", align_corners=False
        )
        return x  # [B, 1, F, H, W]
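

# --- Usage sketch (added for illustration; not part of the original Space file) ---
# A minimal smoke test of DiffMaskPredictor. The constructor arguments, batch size,
# and tensor shapes below are small, made-up values chosen so the test runs quickly;
# they are not the defaults used above.
if __name__ == "__main__":
    predictor = DiffMaskPredictor(
        input_dim=64, patch_grid=(2, 3, 4), output_grid=(8, 24, 32), hidden_dim=32
    )
    tokens = torch.randn(1, 2 * 3 * 4, 64)  # [B, L, D_total] with L = F_p * H_p * W_p
    mask = predictor(tokens)
    print(mask.shape)  # expected: torch.Size([1, 1, 8, 24, 32])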