Commit 209ae73 · Parent: 2c968b4

Remove hardcoded .cuda() calls to support single forward pass on CPU and ensure DeepSeekOCR model compatibility with transformers==4.52.4

Files changed:
- modeling_deepseekocr.py (+1 -1)
- modeling_deepseekv2.py (+8 -6)
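With the hardcoded .cuda() calls gone, the checkpoint can be loaded and run for a single forward pass on a CPU-only machine. A minimal loading sketch, assuming the model is published under the repo id deepseek-ai/DeepSeek-OCR and is loaded with trust_remote_code=True (both assumptions, not taken from this commit):

# Sketch only: repo id and dtype choice are assumptions, not part of this commit.
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "deepseek-ai/DeepSeek-OCR"  # assumed repo id
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # CPU-friendly dtype
)
model.eval()  # stays on CPU; no .cuda()/.to("cuda") call is required after this commit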
modeling_deepseekocr.py CHANGED
@@ -502,7 +502,7 @@ class DeepseekOCRModel(DeepseekV2Model):
             images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
             # exit()
 
-            inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+            inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch)
 
         idx += 1
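The single functional change above drops the .cuda() on the mask, so masked_scatter_ receives a mask on the same device as inputs_embeds. A toy sketch of the same pattern on CPU, with invented shapes standing in for inputs_embeds[idx], images_seq_mask[idx], and images_in_this_batch:

# Toy illustration of the masked_scatter_ pattern; the tensors and shapes are invented.
import torch

inputs_embeds = torch.zeros(4, 3)                            # (seq_len, hidden) text embeddings on CPU
images_seq_mask = torch.tensor([False, True, True, False])   # positions reserved for image tokens
images_in_this_batch = torch.ones(2, 3)                      # vision features for the two masked positions

# The mask must live on the same device as inputs_embeds; a hardcoded .cuda()
# here fails on a CPU-only machine, which is what this commit fixes.
inputs_embeds.masked_scatter_(images_seq_mask.unsqueeze(-1), images_in_this_batch)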
modeling_deepseekv2.py CHANGED
@@ -36,7 +36,6 @@ from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
-    LlamaFlashAttention2
 )
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -60,6 +59,8 @@ from transformers.utils.import_utils import is_torch_fx_available
 
 from .configuration_deepseek_v2 import DeepseekV2Config
 
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1235,7 +1236,6 @@ ATTENTION_CLASSES = {
     "mla_flash_attention_2": DeepseekV2FlashAttention2,
 
     "mha_eager": LlamaAttention,
-    "mha_flash_attention_2": LlamaFlashAttention2
 }
 
 
@@ -1269,6 +1269,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.post_attention_layernorm = DeepseekV2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
+        # Compute position_embeddings
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
     def forward(
         self,
@@ -1303,15 +1305,18 @@
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -1327,9 +1332,6 @@
         if output_attentions:
             outputs += (self_attn_weights,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
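The modeling_deepseekv2.py changes track the transformers 4.52.4 LlamaAttention interface: the attention module now returns only (hidden_states, attn_weights), the KV cache is handled through the past_key_value object rather than a returned present_key_value, and rotary position embeddings are expected as a precomputed (cos, sin) pair passed in via position_embeddings, which is why the decoder layer gains a LlamaRotaryEmbedding. A small sketch of what self.rotary_emb produces, using a throwaway LlamaConfig in place of the real DeepseekV2Config (an assumption for illustration only):

# Sketch assuming the transformers 4.52.x LlamaRotaryEmbedding API; config values are placeholders.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig(hidden_size=64, num_attention_heads=4, max_position_embeddings=128)
rotary_emb = LlamaRotaryEmbedding(config=config)

hidden_states = torch.randn(1, 8, 64)         # (batch, seq_len, hidden_size)
position_ids = torch.arange(8).unsqueeze(0)   # (batch, seq_len)

# Produces the (cos, sin) pair that the decoder layer forwards to attention
# through the position_embeddings keyword argument.
cos, sin = rotary_emb(hidden_states, position_ids)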