ZhenweiWang's picture
Upload folder using huggingface_hub
0ca05b5 verified
# inspired by https://github.com/DepthAnything/Depth-Anything-V2
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.utils.grid import create_uv_grid, position_grid_to_embed
class DPTHead(nn.Module):
    """
    DPT head for dense (per-pixel) prediction.

    Implements the DPT (Dense Prediction Transformer) head from
    "Vision Transformers for Dense Prediction" (https://arxiv.org/abs/2103.13413).
    Patch tokens from four transformer layers are layer-normalised, reshaped to
    2D maps, projected and resized to four resolutions (4x, 2x, 1x and 0.5x the
    patch grid), fused coarse-to-fine by refinement blocks, upsampled to pixel
    resolution and decoded into an attribute map plus a confidence map.

    Args:
        dim_in (int): Channel dimension of the incoming tokens.
        patch_size (int, optional): Patch size used by the backbone. Default: 14.
        output_dim (int, optional): Channels of the final conv. The last channel
            is decoded as confidence, the remaining ``output_dim - 1`` channels
            as the predicted attribute. Default: 4.
        activation (str, optional): ``"<attr_act>+<conf_act>"`` spec forwarded to
            :meth:`activate_head`; ``conf_act`` defaults to ``"expp1"`` when the
            ``"+"`` part is omitted. Default: ``"inv_log+expp1"``.
        features (int, optional): Channel width of the fusion stages. Default: 256.
        out_channels (Sequence[int], optional): Channels of the four projected
            feature levels. Default: ``(256, 512, 1024, 1024)``.
        pos_embed (bool, optional): Add a scaled UV positional embedding to the
            projected levels and to the upsampled fused map. Default: True.
        down_ratio (int, optional): The output resolution is the patch-aligned
            input resolution divided by this factor. Default: 1.
        is_gsdpt (bool, optional): Additionally return the fused feature map with
            features computed directly from the input images added to it.
            Default: False.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log+expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        down_ratio: int = 1,
        is_gsdpt: bool = False,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio
        self.is_gsdpt = is_gsdpt

        self.norm = nn.LayerNorm(dim_in)

        # 1x1 projections from token channels to each level's channel count.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0)
                for oc in out_channels
            ]
        )

        # Bring the four levels to 4x, 2x, 1x and 0.5x of the patch-grid resolution.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(out_channels, features, expand=False)
        # Additional modules attached to scratch (attribute names are part of the state_dict layout).
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32
        conv2_in_channels = head_features_1 // 2

        # Output head shared by both variants; created before input_merger so
        # module creation order matches the original layout.
        self.scratch.output_conv1 = nn.Conv2d(
            head_features_1, conv2_in_channels, kernel_size=3, stride=1, padding=1
        )
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
        )
        if self.is_gsdpt:
            # Lifts the 3-channel images to the fused-feature width so the two can be summed.
            self.input_merger = nn.Sequential(
                nn.Conv2d(3, conv2_in_channels, 7, 1, 3),
                nn.ReLU(),
            )

    def forward(
        self,
        token_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: Optional[int] = 8,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass with optional frame chunking to bound peak memory.

        Args:
            token_list: Token tensors from the transformer, each indexed as
                [B, S, N, C]; only the first four entries are used.
            images: Input images [B, S, 3, H, W].
            patch_start_idx: Index of the first patch token along N (earlier
                tokens are skipped).
            frames_chunk_size: Number of frames processed per chunk. If None or
                >= S, all frames are processed at once.

        Returns:
            If ``is_gsdpt``: ``(features, preds, conf)``; otherwise ``(preds, conf)``.
            ``preds`` is [B, S, H', W', output_dim - 1], ``conf`` is [B, S, H', W']
            and ``features`` is [B, S, features // 2, H', W'], where
            H' = (H // patch_size) * patch_size / down_ratio (likewise W').

        Raises:
            ValueError: If chunking is requested with a non-positive chunk size.
        """
        S = images.shape[1]

        # Process all frames together if chunk size not specified or large enough.
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(token_list, images, patch_start_idx)

        if frames_chunk_size <= 0:
            raise ValueError(f"frames_chunk_size must be positive, got {frames_chunk_size}")

        chunk_outputs = [
            self._forward_impl(
                token_list, images, patch_start_idx, frame_start, min(frame_start + frames_chunk_size, S)
            )
            for frame_start in range(0, S, frames_chunk_size)
        ]
        # Every chunk returns the same tuple layout; concatenate each field along the frame axis.
        return tuple(torch.cat(parts, dim=1) for parts in zip(*chunk_outputs))

    def _forward_impl(
        self,
        token_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frame_start: Optional[int] = None,
        frame_end: Optional[int] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Core forward implementation for the DPT head.

        Args:
            token_list: Transformer tokens per layer, indexed as [B, S, N, C].
            images: Input images [B, S, 3, H, W].
            patch_start_idx: Index of the first patch token along N.
            frame_start: Start frame for chunking (used only together with frame_end).
            frame_end: End frame (exclusive) for chunking.

        Returns:
            ``(features, preds, conf)`` if ``is_gsdpt`` else ``(preds, conf)``.
        """
        chunked = frame_start is not None and frame_end is not None
        if chunked:
            images = images[:, frame_start:frame_end].contiguous()

        B, S, _, H, W = images.shape
        ph = H // self.patch_size  # patch-grid height
        pw = W // self.patch_size  # patch-grid width

        # Extract and project multi-level features.
        feats = []
        for proj, resize, tokens in zip(self.projects, self.resize_layers, token_list):
            patch_tokens = tokens[:, :, patch_start_idx:]
            if chunked:
                patch_tokens = patch_tokens[:, frame_start:frame_end]

            # [B, S, N, C] -> [B*S, N, C]; N must equal ph * pw for the 2D reshape below.
            patch_tokens = patch_tokens.reshape(B * S, -1, patch_tokens.shape[-1])
            patch_tokens = self.norm(patch_tokens)
            feat = patch_tokens.permute(0, 2, 1).reshape(B * S, patch_tokens.shape[-1], ph, pw)

            feat = proj(feat)
            if self.pos_embed:
                feat = self._apply_pos_embed(feat, W, H)
            feats.append(resize(feat))

        # Fuse multi-level features, then upsample to (patch-aligned) pixel resolution.
        fused = self.scratch_forward(feats)
        fused = custom_interpolate(
            fused,
            size=(
                int(ph * self.patch_size / self.down_ratio),
                int(pw * self.patch_size / self.down_ratio),
            ),
            mode="bilinear",
            align_corners=True,
        )
        if self.pos_embed:
            fused = self._apply_pos_embed(fused, W, H)

        # Decode predictions and confidence (common to both variants).
        out = self.scratch.output_conv2(fused)
        preds, conf = self.activate_head(out, activation=self.activation)
        preds = preds.reshape(B, S, *preds.shape[1:])
        conf = conf.reshape(B, S, *conf.shape[1:])

        if not self.is_gsdpt:
            return preds, conf

        # GSDPT: merge direct image features into the fused map (after decoding,
        # so preds/conf are unaffected by the merge).
        # NOTE(review): the sum requires the fused map to be exactly H x W, i.e.
        # H and W divisible by patch_size and down_ratio == 1 — confirm callers guarantee this.
        img_feat = self.input_merger(images.reshape(B * S, -1, H, W))
        fused = fused + img_feat
        fused = fused.reshape(B, S, *fused.shape[1:])
        return fused, preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Add a UV-grid positional embedding to the [N, C, h, w] feature map ``x``.

        The grid is built at x's own spatial size with aspect ratio ``W / H``,
        embedded to ``x.shape[1]`` channels, scaled by ``ratio`` and broadcast
        over the batch.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        # presumably (h, w, C) -> (1, C, h, w) -> (N, C, h, w), matching x for the sum.
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Run the coarse-to-fine fusion blocks.

        Args:
            features (List[Tensor]): The four resized feature levels, finest first.

        Returns:
            Tensor: Fused feature map after ``output_conv1`` (``features // 2`` channels).
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Each refinement stage upsamples to the next finer level's size; locals
        # are dropped as soon as they are consumed to release memory early.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4
        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3
        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2
        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out

    def activate_head(self, out_head: torch.Tensor, activation: str = "inv_log+expp1") -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Split raw head output into an activated attribute map and a confidence map.

        Args:
            out_head: Network output (B, C, H, W). The last channel is confidence.
            activation: ``"<attr_act>+<conf_act>"``; ``conf_act`` defaults to
                ``"expp1"`` when no ``"+"`` is present.
                attr_act: norm_exp, norm, exp, relu, inv_log, xy_inv_log, sigmoid, linear.
                conf_act: expp1 (1 + exp), expp0 (exp), sigmoid.

        Returns:
            Tuple of (attribute (B, H, W, C-1), confidence (B, H, W)).

        Raises:
            ValueError: If either activation name is unknown.
        """
        act_attr, sep, act_conf = activation.partition("+")
        if not sep:
            act_conf = "expp1"

        # (B, C, H, W) -> (B, H, W, C)
        feat = out_head.permute(0, 2, 3, 1)
        attr, conf = feat[..., :-1], feat[..., -1]

        def _norm_exp(x: torch.Tensor) -> torch.Tensor:
            # Unit direction scaled by expm1 of the vector length.
            length = x.norm(dim=-1, keepdim=True)
            return (x / length.clamp(min=1e-8)) * torch.expm1(length)

        def _xy_inv_log(x: torch.Tensor) -> torch.Tensor:
            # Channels 2: are treated as depth-like; xy channels are scaled by it.
            z = self._apply_inverse_log_transform(x[..., 2:])
            return torch.cat([x[..., :2] * z, z], dim=-1)

        attr_activations = {
            "norm_exp": _norm_exp,
            # Clamp like norm_exp so all-zero vectors map to 0 instead of NaN.
            "norm": lambda x: x / x.norm(dim=-1, keepdim=True).clamp(min=1e-8),
            "exp": torch.exp,
            "relu": F.relu,
            "inv_log": self._apply_inverse_log_transform,
            "xy_inv_log": _xy_inv_log,
            "sigmoid": torch.sigmoid,
            "linear": lambda x: x,
        }
        if act_attr not in attr_activations:
            raise ValueError(f"Unknown attribute activation: {act_attr}")
        attr_out = attr_activations[act_attr](attr)

        conf_activations = {
            "expp1": lambda c: 1 + c.exp(),
            "expp0": torch.exp,
            "sigmoid": torch.sigmoid,
        }
        if act_conf not in conf_activations:
            raise ValueError(f"Unknown confidence activation: {act_conf}")
        conf_out = conf_activations[act_conf](conf)

        return attr_out, conf_out

    def _apply_inverse_log_transform(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse log transform ``sign(y) * (exp(|y|) - 1)`` element-wise.

        Args:
            input_tensor: Input tensor.

        Returns:
            Transformed tensor of the same shape.
        """
        return torch.sign(input_tensor) * torch.expm1(torch.abs(input_tensor))
################################################################################
# DPT Modules
################################################################################
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """
    Build a FeatureFusionBlock with the fixed settings used by DPTHead.

    Args:
        features: Channel width of the block (input and output).
        size: Optional fixed upsampling target size for the block.
        has_residual: Whether the block accepts and fuses a second (skip) input.
        groups: Group count for the block's convolutions.

    Returns:
        The configured FeatureFusionBlock (no deconv, no batch-norm, no channel
        expansion, bilinear upsampling with align_corners=True).

    NOTE: the activation passed in is ``nn.ReLU(inplace=True)``.
    ResidualConvUnit.forward applies its activation directly to its input ``x``
    and later adds ``x`` back, so with an in-place ReLU the skip path carries
    ReLU(x) rather than x. Pretrained weights presumably depend on this; do not
    switch to inplace=False without checking (TODO confirm).
    """
    relu = nn.ReLU(inplace=True)
    fixed_options = dict(deconv=False, bn=False, expand=False, align_corners=True)
    return FeatureFusionBlock(
        features,
        relu,
        size=size,
        has_residual=has_residual,
        groups=groups,
        **fixed_options,
    )
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
    """Two activation->3x3-conv stages with the unit's input added back at the end."""

    def __init__(self, features, activation, bn, groups=1):
        """Set up the residual unit.

        Args:
            features (int): Channel count of both input and output.
            activation: Module/callable applied before each conv.
            bn (bool): Stored on the instance; no normalisation layer is built
                from it here (norm1/norm2 stay None).
            groups (int): Group count for both convs.
        """
        super().__init__()
        self.bn = bn
        self.groups = groups

        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_options)
        self.conv2 = nn.Conv2d(features, features, **conv_options)

        self.norm1 = self.norm2 = None
        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply the two stages and add the input back.

        Args:
            x (tensor): Input tensor (B, C, H, W).

        Returns:
            tensor: Output tensor of the same shape.

        NOTE: the activation is applied directly to ``x``. If it operates in place
        (e.g. ``nn.ReLU(inplace=True)``), ``x`` itself is overwritten, so the
        residual term added at the end is activation(x), not the original x.
        """
        hidden = x
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            hidden = conv(self.activation(hidden))
            if norm is not None:
                hidden = norm(hidden)
        return self.skip_add.add(hidden, x)
class FeatureFusionBlock(nn.Module):
    """Fuse a feature map with an optional skip input, refine it, upsample it and apply a 1x1 conv."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Set up the fusion block.

        Args:
            features (int): Channel count of the inputs.
            activation: Activation handed to the ResidualConvUnits.
            deconv (bool): Stored on the instance; not read by this class.
            bn (bool): Forwarded to the ResidualConvUnits.
            expand (bool): If equal to True, the output 1x1 conv halves the channels.
            align_corners (bool): align_corners flag for the bilinear upsampling.
            size: Default upsampling target size used when forward() gets none.
            has_residual (bool): Whether forward() expects and fuses a second input.
            groups (int): Group count for the convolutions.
        """
        super().__init__()
        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        halve_channels = self.expand == True  # noqa: E712 -- equality test kept deliberately
        out_features = features // 2 if halve_channels else features
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        # Submodules are created in the original order (out_conv, resConfUnit1, resConfUnit2).
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Fuse, refine and upsample.

        Args:
            *xs: ``xs[0]`` is the main input; ``xs[1]`` is the skip input, read
                only when the block was built with ``has_residual=True``.
            size: Upsampling target size. Falls back to the constructor's ``size``
                and, if that is also None, to 2x upsampling.

        Returns:
            torch.Tensor: The upsampled, 1x1-convolved result.
        """
        fused = xs[0]
        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))
        fused = self.resConfUnit2(fused)

        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        fused = custom_interpolate(fused, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(fused)
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
    max_elements: int = 1610612736,
) -> torch.Tensor:
    """
    Interpolate ``x``, splitting along the batch axis when the output is very large.

    ``nn.functional.interpolate`` can overflow INT_MAX-based indexing on very
    large tensors. When the number of output elements exceeds ``max_elements``,
    ``x`` is split into chunks along dim 0, each chunk is interpolated
    separately and the results are concatenated.

    Args:
        x: Input tensor (N, C, H, W); dims 0 and 1 enter the size estimate.
        size: Target output size (H, W).
        scale_factor: Scaling factor, used only when ``size`` is None
            (the target size is floored to integers).
        mode: Interpolation mode (default: "bilinear").
        align_corners: Align-corners flag. It is only forwarded for the
            linear-family modes; other modes (e.g. "nearest") reject it.
        max_elements: Output-element threshold above which chunking is used.
            The default, 1.5 * 2**30, keeps a margin below 2**31 - 1.

    Returns:
        Interpolated tensor.

    Raises:
        TypeError: If neither ``size`` nor ``scale_factor`` is given.
    """
    if size is None:
        if scale_factor is None:
            raise TypeError("custom_interpolate() requires either `size` or `scale_factor`")
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    # F.interpolate raises if align_corners is set for non-linear modes.
    if mode not in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = None

    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
    if output_elements <= max_elements:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    # torch.chunk may return fewer chunks than requested if the batch is small.
    chunks = torch.chunk(x, chunks=(output_elements // max_elements) + 1, dim=0)
    interpolated_chunks = [
        nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
    ]
    return torch.cat(interpolated_chunks, dim=0).contiguous()