warrenwjk committed · Commit fc8b0b1 · verified · 1 Parent(s): 325ed02

Update VisionSdpaAttention to support the memory-efficient backend.


Because of a bug (https://github.com/pytorch/pytorch/issues/127523), the memory-efficient backend for scaled_dot_product_attention currently only supports 4D data.

Typically, users switch to VisionSdpaAttention only when their hardware does not support FlashAttention2 (e.g., Turing-architecture GPUs such as the 2080 Ti, and earlier). On such hardware, however, memory usage increases dramatically with input size. This implementation helps reduce memory consumption, which is the bottleneck in 99% of cases.
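As a self-contained sketch of the workaround (shapes, sizes, and the CUDA device below are illustrative assumptions, not values from the repository):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Illustrative sizes; VisionSdpaAttention itself works with 3D
    # (seq_length, num_heads, head_dim) tensors.
    seq_length, num_heads, head_dim = 128, 8, 64
    q = torch.randn(seq_length, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Block-diagonal boolean mask over two packed sequences, analogous to the
    # cu_seqlens-based mask in the model.
    attention_mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool, device="cuda")
    attention_mask[..., :64, :64] = True
    attention_mask[..., 64:, 64:] = True

    # The memory-efficient kernel only accepts 4D q/k/v, so add a batch dim:
    # (1, num_heads, seq_length, head_dim).
    q4 = q.transpose(0, 1).unsqueeze(0)
    k4 = k.transpose(0, 1).unsqueeze(0)
    v4 = v.transpose(0, 1).unsqueeze(0)

    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        out = F.scaled_dot_product_attention(q4, k4, v4, attention_mask, dropout_p=0.0)

    # Back to (seq_length, num_heads * head_dim).
    out = out.squeeze(0).transpose(0, 1).reshape(seq_length, -1)

Note that sdpa_kernel restricts dispatch to the listed backends, so this raises at call time if the memory-efficient kernel cannot handle the inputs, rather than silently falling back to the math backend.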

Files changed (1)
  1. modeling_dots_vision.py +14 -5
modeling_dots_vision.py CHANGED
@@ -274,12 +274,21 @@ class VisionSdpaAttention(nn.Module):
         for i in range(1, len(cu_seqlens)):
             attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
 
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
+        # Convert q, k, v to 4D so the memory-efficient backend can be used: (1, num_heads, seq_length, head_dim)
+        q = q.transpose(0, 1).unsqueeze(0)  # (1, num_heads, seq_length, head_dim)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
 
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        # The kernel requires the mask's last dim to be contiguous; see: https://github.com/pytorch/pytorch/issues/127523
+        if attention_mask.stride(-1) != 1:
+            attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)
+
+        # Use the memory-efficient backend.
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+            attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+
+        attn_output = attn_output.squeeze(0).transpose(0, 1)  # (seq_length, num_heads, head_dim)
         attn_output = attn_output.reshape(seq_length, -1)
 
         attn_output = self.proj(attn_output)
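For reference, the stride check above guards against masks whose last dimension is not contiguous, which the memory-efficient kernel rejects. A minimal sketch of the failure mode and the normalization (the mask shape here is made up for illustration):

    import torch

    # A boolean mask can end up with a non-contiguous last dimension,
    # e.g. after a transpose.
    mask = torch.zeros(1, 8, 16, dtype=torch.bool).transpose(-1, -2)
    print(mask.stride(-1))  # 16 -- the last dim is strided, not contiguous

    # Same normalization as in the commit: materialize a contiguous copy
    # before handing the mask to scaled_dot_product_attention.
    if mask.stride(-1) != 1:
        mask = torch.empty_like(mask, memory_format=torch.contiguous_format).copy_(mask)
    print(mask.stride(-1))  # 1 -- safe to pass to the memory-efficient kernel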