| """ | |
| OmniGen2 Attention Processor Module | |
| Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |

import math
from typing import Optional, Tuple, Dict, Any

import torch
import torch.nn.functional as F
from einops import repeat

from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

from diffusers.models.attention_processor import Attention

from .embeddings import apply_rotary_emb


class OmniGen2AttnProcessorFlash2Varlen:
    """
    Processor for scaled dot-product attention using flash attention with variable-length sequences.

    This processor is optimized for PyTorch 2.0 and implements:
    - Flash attention over variable-length (unpadded) sequences
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Raises:
        ImportError: If PyTorch version is less than 2.0.
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
        """
        Unpad the input tensors for flash attention.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Attention mask tensor of shape (batch_size, seq_len)
            query_length: Length of the query sequence
            num_heads: Number of attention heads

        Returns:
            Tuple containing:
                - Unpadded query tensor
                - Unpadded key tensor
                - Unpadded value tensor
                - Query indices
                - Tuple of cumulative sequence lengths for query and key
                - Tuple of maximum sequence lengths for query and key
        """
        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Helper function to get unpadding data from attention mask."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch
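        # Example: for a padding mask
        #   [[1, 1, 1, 0],
        #    [1, 1, 0, 0]]
        # the per-sample lengths are [3, 2], so indices = [0, 1, 2, 4, 5] (positions of
        # valid tokens in the flattened batch), cu_seqlens = [0, 3, 5], and
        # max_seqlen_in_batch = 3 -- the metadata expected by flash_attn_varlen_func.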
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Unpad key and value layers
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        # Handle different query length cases
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
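            # Single-token (decode-style) queries: every sample contributes exactly one
            # query, so the per-sample query lengths are all 1, cu_seqlens_q is simply
            # [0, 1, ..., batch_size], and no unpadding gather is needed.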
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -query_length: slice assumes left padding when queries are shorter than keys
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with flash attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor
            attention_mask: Padding mask of shape (batch_size, seq_len). Defaults to None
                in the signature, but this processor requires it for unpadding.
            image_rotary_emb: Optional rotary embeddings for image tokens
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get query, key, and value projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Infer the number of key-value heads (may be smaller than attn.heads for GQA)
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply rotary position embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
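        # Proportional attention: the usual 1/sqrt(head_dim) factor (attn.scale) is
        # multiplied by sqrt(log_{base_sequence_length}(sequence_length)), so the scale
        # is unchanged when the sequence matches the training length and grows slowly for
        # longer sequences, which is intended to keep attention entropy roughly stable.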
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # Unpad input for flash attention
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        # Repeat key/value heads for grouped-query attention when kv_heads < attn.heads
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)

        # Apply flash attention
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )
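        # attn_output_unpad has shape (total_valid_tokens, heads, head_dim); pad_input
        # scatters it back to (batch_size, sequence_length, heads, head_dim) with zeros
        # at padded positions before the heads are flattened and projected out.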
        # Pad output and apply final transformations
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        # Output projection (linear) followed by dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
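

if __name__ == "__main__":
    # Minimal usage sketch: attach the processor to a plain diffusers ``Attention`` block
    # via ``set_processor`` and run self-attention over a padded batch. This assumes a
    # CUDA device, an installed flash-attn, and arbitrary toy dimensions; run it with
    # ``python -m`` from the package root so the relative import above resolves.
    device, dtype = "cuda", torch.bfloat16

    attn = Attention(query_dim=64, heads=4, dim_head=16).to(device=device, dtype=dtype)
    attn.set_processor(OmniGen2AttnProcessorFlash2Varlen())

    hidden_states = torch.randn(2, 8, 64, device=device, dtype=dtype)
    # Mark the last two tokens of the second sample as padding.
    attention_mask = torch.ones(2, 8, dtype=torch.bool, device=device)
    attention_mask[1, 6:] = False

    out = attn(
        hidden_states,
        encoder_hidden_states=hidden_states,  # self-attention: keys/values from the same tokens
        attention_mask=attention_mask,
    )
    print(out.shape)  # expected: torch.Size([2, 8, 64])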