# Copyright © 2025, Adobe Inc. and its licensors. All rights reserved.
#
# This file is licensed under the Adobe Research License. You may obtain a copy
# of the license at https://raw.githubusercontent.com/adobe-research/FaceLift/main/LICENSE.md
"""
Transformer utilities for GSLRM.
This module contains the core transformer components used by the GSLRM model,
including self-attention, MLP layers, and transformer blocks.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
try:
import xformers.ops as xops
except ImportError as e:
print("Please install xformers to use flashatt v2")
raise e
def _init_weights(module):
"""
Initialize weights for transformer modules.
Reference: https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/model.py#L162-L168
Args:
module: Neural network module to initialize
"""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
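
# Example (illustrative, not part of the original module): a parent nn.Module
# typically applies this initializer recursively after building its layers:
#
#     encoder = nn.Sequential(nn.Linear(256, 1024), nn.Linear(1024, 256))
#     encoder.apply(_init_weights)  # N(0, 0.02) weights, zero biases
#
# `nn.Module.apply` visits every submodule, so nested Linear/Embedding layers
# are initialized as well.
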
class MLP(nn.Module):
"""
Multi-layer perceptron with GELU activation.
Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65
"""
def __init__(
self,
d,
mlp_ratio=4,
mlp_bias=False,
mlp_dropout=0.0,
mlp_dim=None,
):
"""
Initialize MLP layer.
Args:
d: Input/output dimension
mlp_ratio: Hidden dimension ratio (hidden_dim = d * mlp_ratio)
mlp_bias: Whether to use bias in linear layers
mlp_dropout: Dropout probability
mlp_dim: Explicit hidden dimension (overrides mlp_ratio if provided)
"""
super().__init__()
if mlp_dim is None:
mlp_dim = d * mlp_ratio
self.mlp = nn.Sequential(
nn.Linear(d, mlp_dim, bias=mlp_bias),
nn.GELU(),
nn.Linear(mlp_dim, d, bias=mlp_bias),
nn.Dropout(mlp_dropout),
)
def forward(self, x):
"""
Forward pass through MLP.
Args:
x: Input tensor of shape (batch, seq_len, d)
Returns:
Output tensor of shape (batch, seq_len, d)
"""
return self.mlp(x)
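
# Example (illustrative sketch, not part of the original module; the dimensions
# below are arbitrary assumptions):
#
#     mlp = MLP(d=256, mlp_ratio=4)        # hidden width 256 * 4 = 1024
#     y = mlp(torch.randn(2, 16, 256))     # -> shape (2, 16, 256)
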
class SelfAttention(nn.Module):
"""
Multi-head self-attention with flash attention support.
Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92
"""
def __init__(
self,
d,
d_head,
attn_qkv_bias=False,
attn_dropout=0.0,
attn_fc_bias=False,
attn_fc_dropout=0.0,
use_flashatt_v2=True,
):
"""
Initialize self-attention layer.
Args:
d: Token dimension
d_head: Head dimension
attn_qkv_bias: Whether to use bias in QKV projection
attn_dropout: Attention dropout probability
attn_fc_bias: Whether to use bias in output projection
attn_fc_dropout: Output projection dropout probability
            use_flashatt_v2: Whether to use xformers flash attention v2;
                otherwise PyTorch's scaled_dot_product_attention is used
"""
super().__init__()
assert d % d_head == 0, f"Token dimension {d} should be divisible by head dimension {d_head}"
self.d = d
self.d_head = d_head
self.attn_dropout = attn_dropout
self.use_flashatt_v2 = use_flashatt_v2
# QKV projection (projects to 3*d for Q, K, V)
self.to_qkv = nn.Linear(d, 3 * d, bias=attn_qkv_bias)
# Output projection
self.fc = nn.Linear(d, d, bias=attn_fc_bias)
self.attn_fc_dropout = nn.Dropout(attn_fc_dropout)
def forward(self, x, subset_attention_size=None):
"""
Forward pass through self-attention.
Args:
x: Input tensor of shape (batch, seq_len, d)
            subset_attention_size: If set and smaller than the sequence length,
                the first `subset_attention_size` tokens attend only among
                themselves, while the remaining tokens attend over the full
                sequence
Returns:
Output tensor of shape (batch, seq_len, d)
"""
# Generate Q, K, V
q, k, v = self.to_qkv(x).split(self.d, dim=2)
if self.use_flashatt_v2:
# Use xformers flash attention
q, k, v = map(
lambda t: rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.d_head),
(q, k, v),
)
if subset_attention_size is not None and subset_attention_size < q.shape[1]:
# Handle subset attention for memory efficiency
x_subset = xops.memory_efficient_attention(
q[:, :subset_attention_size, :, :].contiguous(),
k[:, :subset_attention_size, :, :].contiguous(),
v[:, :subset_attention_size, :, :].contiguous(),
attn_bias=None,
op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
)
x_rest = xops.memory_efficient_attention(
q[:, subset_attention_size:, :, :].contiguous(),
k,
v,
attn_bias=None,
op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
)
x = torch.cat([x_subset, x_rest], dim=1)
else:
# Standard flash attention
x = xops.memory_efficient_attention(
q, k, v,
attn_bias=None,
op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
)
x = rearrange(x, "b l nh dh -> b l (nh dh)")
else:
# Use PyTorch scaled dot product attention
q, k, v = (
rearrange(q, "b l (nh dh) -> b nh l dh", dh=self.d_head),
rearrange(k, "b l (nh dh) -> b nh l dh", dh=self.d_head),
rearrange(v, "b l (nh dh) -> b nh l dh", dh=self.d_head),
)
dropout_p = self.attn_dropout if self.training else 0.0
if subset_attention_size is not None and subset_attention_size < q.shape[2]:
# Handle subset attention
x_subset = F.scaled_dot_product_attention(
q[:, :, :subset_attention_size, :].contiguous(),
k[:, :, :subset_attention_size, :].contiguous(),
v[:, :, :subset_attention_size, :].contiguous(),
dropout_p=dropout_p,
)
x_rest = F.scaled_dot_product_attention(
q[:, :, subset_attention_size:, :].contiguous(),
k, v,
dropout_p=dropout_p,
)
x = torch.cat([x_subset, x_rest], dim=2)
else:
# Standard attention
x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
x = rearrange(x, "b nh l dh -> b l (nh dh)")
# Apply output projection and dropout
return self.attn_fc_dropout(self.fc(x))
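
# Example (illustrative, not part of the original module). The xformers flash
# path needs a CUDA device and half precision, so this sketch uses the PyTorch
# scaled_dot_product_attention fallback; dimensions are arbitrary assumptions:
#
#     attn = SelfAttention(d=256, d_head=64, use_flashatt_v2=False)  # 4 heads
#     x = torch.randn(2, 128, 256)
#     y = attn(x)                               # full attention, (2, 128, 256)
#     y = attn(x, subset_attention_size=64)     # first 64 tokens attend only
#                                               # among themselves; the rest
#                                               # attend over the full sequence
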
class TransformerBlock(nn.Module):
"""
Standard transformer block with pre-normalization.
Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113
"""
def __init__(
self,
d,
d_head,
ln_bias=False,
attn_qkv_bias=False,
attn_dropout=0.0,
attn_fc_bias=False,
attn_fc_dropout=0.0,
mlp_ratio=4,
mlp_bias=False,
mlp_dropout=0.0,
):
"""
Initialize transformer block.
Args:
d: Token dimension
d_head: Attention head dimension
ln_bias: Whether to use bias in layer norm
attn_qkv_bias: Whether to use bias in attention QKV projection
attn_dropout: Attention dropout probability
attn_fc_bias: Whether to use bias in attention output projection
attn_fc_dropout: Attention output dropout probability
mlp_ratio: MLP hidden dimension ratio
mlp_bias: Whether to use bias in MLP layers
mlp_dropout: MLP dropout probability
"""
super().__init__()
# Layer normalization
self.norm1 = nn.LayerNorm(d, bias=ln_bias)
self.norm2 = nn.LayerNorm(d, bias=ln_bias)
# Self-attention
self.attn = SelfAttention(
d=d,
d_head=d_head,
attn_qkv_bias=attn_qkv_bias,
attn_dropout=attn_dropout,
attn_fc_bias=attn_fc_bias,
attn_fc_dropout=attn_fc_dropout,
)
# MLP
self.mlp = MLP(
d=d,
mlp_ratio=mlp_ratio,
mlp_bias=mlp_bias,
mlp_dropout=mlp_dropout,
)
def forward(self, x, subset_attention_size=None):
"""
Forward pass through transformer block.
Args:
x: Input tensor of shape (batch, seq_len, d)
            subset_attention_size: If set and smaller than the sequence length,
                the first `subset_attention_size` tokens attend only among
                themselves, while the remaining tokens attend over the full
                sequence
Returns:
Output tensor of shape (batch, seq_len, d)
"""
# Pre-norm attention with residual connection
x = x + self.attn(self.norm1(x), subset_attention_size=subset_attention_size)
# Pre-norm MLP with residual connection
x = x + self.mlp(self.norm2(x))
return x
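
if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module;
    # shapes and dimensions are arbitrary assumptions). The xformers flash path
    # requires a CUDA device and half precision, so it is only exercised when a
    # GPU is available; otherwise the PyTorch scaled_dot_product_attention
    # fallback of SelfAttention is checked instead.
    torch.manual_seed(0)
    if torch.cuda.is_available():
        block = TransformerBlock(d=256, d_head=64).cuda().half()
        block.apply(_init_weights)
        tokens = torch.randn(2, 128, 256, device="cuda", dtype=torch.half)
        out = block(tokens, subset_attention_size=64)
        print("flash attention path:", out.shape)  # expected: (2, 128, 256)
    else:
        attn = SelfAttention(d=256, d_head=64, use_flashatt_v2=False)
        attn.apply(_init_weights)
        tokens = torch.randn(2, 128, 256)
        out = attn(tokens, subset_attention_size=64)
        print("SDPA fallback path:", out.shape)  # expected: (2, 128, 256)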