update modeling_qwen.py
Browse files- assets/wechat.png +0 -0
- modeling_qwen.py +2 -1
assets/wechat.png
CHANGED
|
|
modeling_qwen.py
CHANGED
|
@@ -193,9 +193,10 @@ class FlashSelfAttention(torch.nn.Module):
|
|
| 193 |
if attention_mask is not None:
|
| 194 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
| 195 |
v = v[indices_k]
|
| 196 |
-
if
|
| 197 |
q = q[indices_k]
|
| 198 |
cu_seqlens_q = cu_seqlens_k
|
|
|
|
| 199 |
else:
|
| 200 |
cu_seqlens_k = torch.arange(
|
| 201 |
0,
|
|
|
|
| 193 |
if attention_mask is not None:
|
| 194 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
| 195 |
v = v[indices_k]
|
| 196 |
+
if self.training or q.size(0) == k.size(0):
|
| 197 |
q = q[indices_k]
|
| 198 |
cu_seqlens_q = cu_seqlens_k
|
| 199 |
+
seqlen_q = seqlen_k
|
| 200 |
else:
|
| 201 |
cu_seqlens_k = torch.arange(
|
| 202 |
0,
|