Update VisionSdpaAttention to support memory efficient backend.
Browse filesBecause of a bug (https://github.com/pytorch/pytorch/issues/127523), the memory-efficient backend for scaled_dot_product_attention currently only supports 4D data.
Typically, users only switch to VisionSdpaAttention when their hardware does not support FlashAttention2 (such as Turing architecture GPUs, e.g., the 2080 Ti, and earlier models). However, memory usage increases dramatically with input size. This implementation helps reduce memory consumption, which is the bottleneck in 99% of cases.
- modeling_dots_vision.py +14 -5
    	
        modeling_dots_vision.py
    CHANGED
    
    | @@ -274,12 +274,21 @@ class VisionSdpaAttention(nn.Module): | |
| 274 | 
             
                    for i in range(1, len(cu_seqlens)):
         | 
| 275 | 
             
                        attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
         | 
| 276 |  | 
| 277 | 
            -
                    q  | 
| 278 | 
            -
                     | 
| 279 | 
            -
                     | 
|  | |
| 280 |  | 
| 281 | 
            -
                     | 
| 282 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 283 | 
             
                    attn_output = attn_output.reshape(seq_length, -1)
         | 
| 284 |  | 
| 285 | 
             
                    attn_output = self.proj(attn_output)
         | 
|  | |
| 274 | 
             
                    for i in range(1, len(cu_seqlens)):
         | 
| 275 | 
             
                        attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
         | 
| 276 |  | 
| 277 | 
            +
                    # Convert q, k, v to 4D to enable : (1, num_heads, seq_length, head_dim)
         | 
| 278 | 
            +
                    q = q.transpose(0, 1).unsqueeze(0)   # (1, num_heads, seq_length, head_dim)
         | 
| 279 | 
            +
                    k = k.transpose(0, 1).unsqueeze(0)
         | 
| 280 | 
            +
                    v = v.transpose(0, 1).unsqueeze(0)
         | 
| 281 |  | 
| 282 | 
            +
                    # See: https://github.com/pytorch/pytorch/issues/127523
         | 
| 283 | 
            +
                    if attention_mask.stride(-1) != 1:
         | 
| 284 | 
            +
                        attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    # use memory efficient backend
         | 
| 287 | 
            +
                    from torch.nn.attention import SDPBackend, sdpa_kernel
         | 
| 288 | 
            +
                    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
         | 
| 289 | 
            +
                        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    attn_output = attn_output.squeeze(0).transpose(0, 1)  # (seq_length, num_heads, head_dim)
         | 
| 292 | 
             
                    attn_output = attn_output.reshape(seq_length, -1)
         | 
| 293 |  | 
| 294 | 
             
                    attn_output = self.proj(attn_output)
         | 
