# Copyright © 2025, Adobe Inc. and its licensors. All rights reserved.
#
# This file is licensed under the Adobe Research License. You may obtain a copy
# of the license at https://raw.githubusercontent.com/adobe-research/FaceLift/main/LICENSE.md

"""
Transformer utilities for GSLRM.

This module contains the core transformer components used by the GSLRM model,
including self-attention, MLP layers, and transformer blocks.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
try:
    import xformers.ops as xops
except ImportError as e:
    print("Please install xformers to use flash attention v2")
    raise e
def _init_weights(module):
    """
    Initialize weights for transformer modules.

    Reference: https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/model.py#L162-L168

    Args:
        module: Neural network module to initialize
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
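# Illustrative usage sketch: `_init_weights` is meant to be passed to `nn.Module.apply`,
# which visits every submodule recursively. The helper below is hypothetical and not
# part of the GSLRM API.
def _example_apply_init_weights():
    """Hypothetical example: re-initialize a small module with `_init_weights`."""
    layer = nn.Linear(16, 32, bias=True)
    layer.apply(_init_weights)  # weight ~ N(0, 0.02), bias zeroed
    return layer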
class MLP(nn.Module):
    """
    Multi-layer perceptron with GELU activation.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65
    """

    def __init__(
        self,
        d,
        mlp_ratio=4,
        mlp_bias=False,
        mlp_dropout=0.0,
        mlp_dim=None,
    ):
        """
        Initialize MLP layer.

        Args:
            d: Input/output dimension
            mlp_ratio: Hidden dimension ratio (hidden_dim = d * mlp_ratio)
            mlp_bias: Whether to use bias in linear layers
            mlp_dropout: Dropout probability
            mlp_dim: Explicit hidden dimension (overrides mlp_ratio if provided)
        """
        super().__init__()

        if mlp_dim is None:
            mlp_dim = d * mlp_ratio

        self.mlp = nn.Sequential(
            nn.Linear(d, mlp_dim, bias=mlp_bias),
            nn.GELU(),
            nn.Linear(mlp_dim, d, bias=mlp_bias),
            nn.Dropout(mlp_dropout),
        )

    def forward(self, x):
        """
        Forward pass through MLP.

        Args:
            x: Input tensor of shape (batch, seq_len, d)

        Returns:
            Output tensor of shape (batch, seq_len, d)
        """
        return self.mlp(x)
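# Illustrative usage sketch for MLP (hypothetical shapes, not from the original file):
def _example_mlp():
    """Hypothetical example: run a batch of tokens through the MLP."""
    mlp = MLP(d=256, mlp_ratio=4)     # hidden dim = 256 * 4 = 1024
    tokens = torch.randn(2, 10, 256)  # (batch, seq_len, d)
    out = mlp(tokens)                 # same shape as the input: (2, 10, 256)
    return out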
class SelfAttention(nn.Module):
    """
    Multi-head self-attention with flash attention support.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92
    """

    def __init__(
        self,
        d,
        d_head,
        attn_qkv_bias=False,
        attn_dropout=0.0,
        attn_fc_bias=False,
        attn_fc_dropout=0.0,
        use_flashatt_v2=True,
    ):
        """
        Initialize self-attention layer.

        Args:
            d: Token dimension
            d_head: Head dimension
            attn_qkv_bias: Whether to use bias in QKV projection
            attn_dropout: Attention dropout probability
            attn_fc_bias: Whether to use bias in output projection
            attn_fc_dropout: Output projection dropout probability
            use_flashatt_v2: Whether to use flash attention v2
        """
        super().__init__()
        assert d % d_head == 0, f"Token dimension {d} should be divisible by head dimension {d_head}"

        self.d = d
        self.d_head = d_head
        self.attn_dropout = attn_dropout
        self.use_flashatt_v2 = use_flashatt_v2

        # QKV projection (projects to 3*d for Q, K, V)
        self.to_qkv = nn.Linear(d, 3 * d, bias=attn_qkv_bias)

        # Output projection
        self.fc = nn.Linear(d, d, bias=attn_fc_bias)
        self.attn_fc_dropout = nn.Dropout(attn_fc_dropout)
    def forward(self, x, subset_attention_size=None):
        """
        Forward pass through self-attention.

        Args:
            x: Input tensor of shape (batch, seq_len, d)
            subset_attention_size: If given and smaller than seq_len, the first
                `subset_attention_size` tokens attend only among themselves,
                while the remaining tokens attend to the full sequence

        Returns:
            Output tensor of shape (batch, seq_len, d)
        """
        # Generate Q, K, V
        q, k, v = self.to_qkv(x).split(self.d, dim=2)

        if self.use_flashatt_v2:
            # Use xformers flash attention
            q, k, v = map(
                lambda t: rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.d_head),
                (q, k, v),
            )
            if subset_attention_size is not None and subset_attention_size < q.shape[1]:
                # Subset attention: the first `subset_attention_size` tokens attend
                # only within the subset; the remaining tokens attend to all tokens.
                x_subset = xops.memory_efficient_attention(
                    q[:, :subset_attention_size, :, :].contiguous(),
                    k[:, :subset_attention_size, :, :].contiguous(),
                    v[:, :subset_attention_size, :, :].contiguous(),
                    attn_bias=None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                )
                x_rest = xops.memory_efficient_attention(
                    q[:, subset_attention_size:, :, :].contiguous(),
                    k,
                    v,
                    attn_bias=None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                )
                x = torch.cat([x_subset, x_rest], dim=1)
            else:
                # Standard flash attention over the full sequence
                x = xops.memory_efficient_attention(
                    q, k, v,
                    attn_bias=None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                )
            x = rearrange(x, "b l nh dh -> b l (nh dh)")
        else:
            # Use PyTorch scaled dot product attention
            q, k, v = (
                rearrange(q, "b l (nh dh) -> b nh l dh", dh=self.d_head),
                rearrange(k, "b l (nh dh) -> b nh l dh", dh=self.d_head),
                rearrange(v, "b l (nh dh) -> b nh l dh", dh=self.d_head),
            )
            # Attention dropout is only applied in this PyTorch SDPA path
            dropout_p = self.attn_dropout if self.training else 0.0
            if subset_attention_size is not None and subset_attention_size < q.shape[2]:
                # Subset attention (same split as the flash-attention branch above)
                x_subset = F.scaled_dot_product_attention(
                    q[:, :, :subset_attention_size, :].contiguous(),
                    k[:, :, :subset_attention_size, :].contiguous(),
                    v[:, :, :subset_attention_size, :].contiguous(),
                    dropout_p=dropout_p,
                )
                x_rest = F.scaled_dot_product_attention(
                    q[:, :, subset_attention_size:, :].contiguous(),
                    k, v,
                    dropout_p=dropout_p,
                )
                x = torch.cat([x_subset, x_rest], dim=2)
            else:
                # Standard attention over the full sequence
                x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
            x = rearrange(x, "b nh l dh -> b l (nh dh)")

        # Apply output projection and dropout
        return self.attn_fc_dropout(self.fc(x))
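# Illustrative usage sketch for SelfAttention (hypothetical sizes; the SDPA path is
# chosen so the sketch does not require a GPU with flash-attention support):
def _example_self_attention():
    """Hypothetical example: full attention vs. subset attention."""
    attn = SelfAttention(d=256, d_head=32, use_flashatt_v2=False)  # 8 heads of size 32
    tokens = torch.randn(2, 12, 256)                               # (batch, seq_len, d)
    full = attn(tokens)                                            # every token attends to all 12 tokens
    subset = attn(tokens, subset_attention_size=8)                 # first 8 tokens attend only to each other
    return full, subset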
class TransformerBlock(nn.Module):
    """
    Standard transformer block with pre-normalization.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113
    """

    def __init__(
        self,
        d,
        d_head,
        ln_bias=False,
        attn_qkv_bias=False,
        attn_dropout=0.0,
        attn_fc_bias=False,
        attn_fc_dropout=0.0,
        mlp_ratio=4,
        mlp_bias=False,
        mlp_dropout=0.0,
    ):
        """
        Initialize transformer block.

        Args:
            d: Token dimension
            d_head: Attention head dimension
            ln_bias: Whether to use bias in layer norm
            attn_qkv_bias: Whether to use bias in attention QKV projection
            attn_dropout: Attention dropout probability
            attn_fc_bias: Whether to use bias in attention output projection
            attn_fc_dropout: Attention output dropout probability
            mlp_ratio: MLP hidden dimension ratio
            mlp_bias: Whether to use bias in MLP layers
            mlp_dropout: MLP dropout probability
        """
        super().__init__()

        # Layer normalization
        self.norm1 = nn.LayerNorm(d, bias=ln_bias)
        self.norm2 = nn.LayerNorm(d, bias=ln_bias)

        # Self-attention
        self.attn = SelfAttention(
            d=d,
            d_head=d_head,
            attn_qkv_bias=attn_qkv_bias,
            attn_dropout=attn_dropout,
            attn_fc_bias=attn_fc_bias,
            attn_fc_dropout=attn_fc_dropout,
        )

        # MLP
        self.mlp = MLP(
            d=d,
            mlp_ratio=mlp_ratio,
            mlp_bias=mlp_bias,
            mlp_dropout=mlp_dropout,
        )
    def forward(self, x, subset_attention_size=None):
        """
        Forward pass through transformer block.

        Args:
            x: Input tensor of shape (batch, seq_len, d)
            subset_attention_size: Optional subset size forwarded to SelfAttention;
                see SelfAttention.forward for the attention pattern it induces

        Returns:
            Output tensor of shape (batch, seq_len, d)
        """
        # Pre-norm attention with residual connection
        x = x + self.attn(self.norm1(x), subset_attention_size=subset_attention_size)
        # Pre-norm MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        return x
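# Minimal smoke-test sketch (hypothetical dimensions, not part of the original file):
# builds one pre-norm block, applies the weight initialization defined above, and
# checks that the output shape matches the input shape.
if __name__ == "__main__":
    block = TransformerBlock(d=256, d_head=32)
    block.apply(_init_weights)
    # The xformers flash-attention path expects CUDA tensors, so switch this check
    # to the PyTorch SDPA path to keep it device-agnostic.
    block.attn.use_flashatt_v2 = False
    x = torch.randn(2, 10, 256)  # (batch, seq_len, d)
    y = block(x)                 # shape is preserved: (2, 10, 256)
    print(y.shape)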