Remove hardcoded .cuda() calls to support a single forward pass on CPU and ensure DeepSeekOCR model compatibility with transformers==4.52.4
#54 by kamalrajkannanmcw (opened)

Files changed:

- modeling_deepseekocr.py +1 -1
- modeling_deepseekv2.py +8 -6
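
For reviewers, a minimal CPU-only smoke test of the intended outcome could look like the sketch below. The checkpoint id, the `attn_implementation` choice, and the text-only forward call are illustrative assumptions rather than part of this PR; the only point is that a forward pass should no longer require CUDA.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint id for the patched remote code; adjust as needed.
model_id = "deepseek-ai/DeepSeek-OCR"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    attn_implementation="eager",  # avoid flash-attn, which is CUDA-only
    torch_dtype=torch.float32,
)
model.to("cpu").eval()

# A text-only forward pass is enough to check that no hardcoded .cuda()
# call fires on a machine without a GPU.
inputs = tokenizer("hello", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(type(outputs).__name__)
```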
    	
modeling_deepseekocr.py (CHANGED)

```diff
@@ -502,7 +502,7 @@ class DeepseekOCRModel(DeepseekV2Model):
                         images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                         # exit()
 
-                        inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+                        inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch)
 
                     idx += 1
 
```
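
What changed here is only the hardcoded `.cuda()` on the mask: `masked_scatter_` now runs on whatever device `inputs_embeds` already lives on, so a CPU-only forward pass no longer fails on a machine without CUDA. A minimal sketch of the pattern with made-up shapes (names and sizes are illustrative, not taken from the model):

```python
import torch

seq_len, hidden = 6, 4
inputs_embeds = torch.zeros(seq_len, hidden)                       # text embeddings (CPU here)
images_seq_mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool)
image_features = torch.randn(int(images_seq_mask.sum()), hidden)   # stand-in vision features

# The boolean mask broadcasts over the hidden dimension; since it is no longer
# forced onto CUDA, the in-place scatter follows inputs_embeds' device.
inputs_embeds.masked_scatter_(images_seq_mask.unsqueeze(-1), image_features)
```

The same line works unchanged when the tensors live on a GPU, since every operand then shares one device.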
    	
modeling_deepseekv2.py (CHANGED)
```diff
@@ -36,7 +36,6 @@ from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
-    LlamaFlashAttention2
 )
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -60,6 +59,8 @@ from transformers.utils.import_utils import is_torch_fx_available
 
 from .configuration_deepseek_v2 import DeepseekV2Config
 
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1235,7 +1236,6 @@ ATTENTION_CLASSES = {
     "mla_flash_attention_2": DeepseekV2FlashAttention2,
 
     "mha_eager": LlamaAttention,
-    "mha_flash_attention_2": LlamaFlashAttention2
 }
```
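
Context for the two removals above: transformers 4.52.4 no longer defines `LlamaFlashAttention2`; the attention backend is selected through `config._attn_implementation` inside the single `LlamaAttention` class instead. If compatibility with older transformers releases were also desired, a guarded import along these lines would be one option (a sketch only, not what this PR does):

```python
# Hypothetical fallback, not part of this PR: tolerate both old and new
# transformers layouts for the Llama attention classes.
try:
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
except ImportError:  # removed in recent transformers releases
    LlamaFlashAttention2 = None
```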
```diff
@@ -1269,6 +1269,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.post_attention_layernorm = DeepseekV2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
+        # Compute position_embeddings
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
     def forward(
         self,
@@ -1303,15 +1305,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -1327,9 +1332,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         if output_attentions:
             outputs += (self_attn_weights,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
```
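
The decoder-layer changes above are needed because `LlamaAttention` in transformers 4.52.4 expects the rotary `(cos, sin)` tensors to be precomputed and passed in as `position_embeddings` rather than derived internally from `position_ids`. A self-contained sketch of that call pattern, using a toy Llama config in place of the model's own `DeepseekV2Config`:

```python
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Toy config standing in for the DeepseekV2Config that the PR actually passes.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, max_position_embeddings=128)

rotary_emb = LlamaRotaryEmbedding(config=config)
hidden_states = torch.randn(1, 10, config.hidden_size)
position_ids = torch.arange(10).unsqueeze(0)

# These (cos, sin) tensors are what the patched decoder layer forwards to
# self.self_attn via the new position_embeddings argument.
cos, sin = rotary_emb(hidden_states, position_ids)
print(cos.shape, sin.shape)  # each is (batch, seq_len, head_dim) -> (1, 10, 16)
```

Note that upstream Llama builds one `LlamaRotaryEmbedding` on the model and passes the tuple down to every layer; instantiating it per decoder layer, as this PR does, is functionally equivalent but duplicates the module in each layer.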
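The dropped `present_key_value` handling follows from the same refactor: in transformers 4.52.4 the attention layer returns only `(hidden_states, attn_weights)` and mutates the `Cache` object passed as `past_key_value` in place, so the decoder layer has nothing separate to append to `outputs`. A small sketch of that in-place cache behaviour (shapes are arbitrary):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 4, 10, 16)    # (batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(1, 4, 10, 16)

# Attention layers call cache.update(...) themselves; the growing KV state
# lives inside the cache object instead of being returned per layer.
cache.update(key, value, layer_idx=0)
print(cache.get_seq_length(0))     # 10
```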
