Update VisionSdpaAttention to support memory efficient backend. (#27)
Commit: fc8b0b11b92c381639e616506cca574f1b05af09
Co-authored-by: warren wang <warrenwjk@users.noreply.huggingface.co>
- modeling_dots_vision.py (+14 −5)
modeling_dots_vision.py CHANGED

@@ -274,12 +274,21 @@ class VisionSdpaAttention(nn.Module):
         for i in range(1, len(cu_seqlens)):
             attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
 
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
+        # Convert q, k, v to 4D: (1, num_heads, seq_length, head_dim)
+        q = q.transpose(0, 1).unsqueeze(0)   # (1, num_heads, seq_length, head_dim)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
 
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        # See: https://github.com/pytorch/pytorch/issues/127523
+        if attention_mask.stride(-1) != 1:
+            attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)
+
+        # use memory efficient backend
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+            attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+
+        attn_output = attn_output.squeeze(0).transpose(0, 1)  # (seq_length, num_heads, head_dim)
         attn_output = attn_output.reshape(seq_length, -1)
 
         attn_output = self.proj(attn_output)
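The stride(-1) != 1 guard in the hunk is easy to reproduce in isolation. A minimal sketch (the transposed mask here is a contrived way to get a tensor whose last stride is not 1; the mask the model builds with torch.zeros is already contiguous, so the check is purely defensive):

    import torch

    # A transpose view leaves the last-dim stride != 1.
    mask = torch.zeros(64, 64, dtype=torch.bool).t()
    print(mask.stride())   # (1, 64): stride(-1) != 1

    # Same defensive copy as in the patch: re-materialize contiguously.
    fixed = torch.empty_like(mask, memory_format=torch.contiguous_format).copy_(mask)
    print(fixed.stride())  # (64, 1): safe to hand to SDPA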

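For reviewers who want to exercise the new code path outside the model, here is a self-contained sketch of the same pattern: a block-diagonal mask built from cu_seqlens, the contiguity guard, and scaled_dot_product_attention pinned to the memory-efficient backend via torch.nn.attention.sdpa_kernel (PyTorch 2.3+). The sizes and cu_seqlens values are made up for illustration, and since the memory-efficient kernel is CUDA-only, the sketch falls back to the math backend on CPU:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Illustrative sizes only; not taken from the model config.
    num_heads, seq_length, head_dim = 16, 64, 64
    cu_seqlens = [0, 24, 64]  # two packed sequences

    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(1, num_heads, seq_length, head_dim, device=device)
    k = torch.randn(1, num_heads, seq_length, head_dim, device=device)
    v = torch.randn(1, num_heads, seq_length, head_dim, device=device)

    # Block-diagonal boolean mask: tokens attend only within their own sequence.
    attention_mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool, device=device)
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = True

    # Defensive copy for https://github.com/pytorch/pytorch/issues/127523
    if attention_mask.stride(-1) != 1:
        attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)

    # Restrict SDPA to the memory-efficient kernel on CUDA; math elsewhere.
    backends = [SDPBackend.EFFICIENT_ATTENTION] if device == "cuda" else [SDPBackend.MATH]
    with sdpa_kernel(backends):
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)

    print(attn_output.shape)  # torch.Size([1, 16, 64, 64])

Pinning the backend this way also makes failures explicit: if the requested kernel cannot handle the inputs, SDPA raises instead of silently dispatching to a different implementation.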