from typing import List, Optional, Tuple, Union
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2 import Qwen2ForCausalLM
from .configuration_dots import DotsVisionConfig, DotsOCRConfig
from .modeling_dots_vision import DotsVisionTransformer

DOTS_VLM_MAX_IMAGES = 200

class DotsOCRForCausalLM(Qwen2ForCausalLM):
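    """Qwen2-based causal LM with a DotsVisionTransformer vision tower.

    Image patches are encoded by the vision tower and the resulting embeddings
    are scattered into the text embedding sequence at image-token positions.
    """
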
    config_class = DotsOCRConfig

    def __init__(self, config: DotsOCRConfig):
        super().__init__(config)
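        # vision_config may arrive as a plain dict (e.g. loaded from JSON);
        # normalize it to a DotsVisionConfig before building the vision tower.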
        if isinstance(self.config.vision_config, dict):
            vision_config = DotsVisionConfig(**self.config.vision_config)
            self.config.vision_config = vision_config
        else:
            vision_config = self.config.vision_config
        self.vision_tower = DotsVisionTransformer(vision_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_thw: Optional[torch.LongTensor] = None,
        img_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
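        """Embed input_ids, then splice vision-tower embeddings into the positions
        marked True in img_mask (one embedding row per image placeholder token)."""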
        inputs_embeds = self.get_input_embeddings()(input_ids)
        if pixel_values is not None:
            assert img_mask is not None
            if grid_thw.shape[0] > DOTS_VLM_MAX_IMAGES:
                print(
                    f"Number of images ({grid_thw.shape[0]}) exceeds {DOTS_VLM_MAX_IMAGES}, "
                    f"which may cause an FSDP hang"
                )
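            # Encode all image patches in one pass; the result has one embedding
            # row per image token.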
            vision_embeddings = self.vision_tower(pixel_values, grid_thw)
            # nonzero on a 2-D mask yields (num_true, 2) index pairs; do not squeeze,
            # or a single match collapses to 1-D and breaks the 2-D indexing below.
            true_indices = torch.nonzero(img_mask)
            if true_indices.size(0) > vision_embeddings.size(0):
                print(
                    f"img_mask has more True positions than vision embeddings and will be truncated: "
                    f"mask.sum()={true_indices.size(0)}, {vision_embeddings.size(0)=}"
                )
                true_indices = true_indices[: vision_embeddings.size(0)]
                new_img_mask = torch.zeros_like(img_mask)
                new_img_mask[true_indices[:, 0], true_indices[:, 1]] = True
            else:
                new_img_mask = img_mask
            assert (
                vision_embeddings.size(0) == new_img_mask.sum()
            ), f"{vision_embeddings.size(0)=}, {new_img_mask.sum()=}"
            inputs_embeds = inputs_embeds.masked_scatter(
                new_img_mask.to(inputs_embeds.device).unsqueeze(-1).expand_as(inputs_embeds),
                vision_embeddings.to(inputs_embeds.device).type(inputs_embeds.dtype),
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
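        """Standard Qwen2ForCausalLM forward; when the caller does not supply
        inputs_embeds, vision embeddings are merged in first."""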
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert len(input_ids) >= 1, f"empty input_ids ({input_ids.shape=}) would cause a NaN grad norm"
        if inputs_embeds is None:
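            # Positions holding the image placeholder token receive vision embeddings.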
            img_mask = input_ids == self.config.image_token_id
            inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, image_grid_thw, img_mask)
        outputs = super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache if use_cache is not None else self.config.use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # return_dict=return_dict,
            logits_to_keep=logits_to_keep,
            **loss_kwargs,
        )
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
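        """Delegate to Qwen2's implementation, re-attaching pixel_values only on
        the first (prefill) step of generation."""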
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )
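        # pixel_values are only needed at prefill (cache_position starts at 0);
        # later decode steps read the image information back from the KV cache.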
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
        return model_inputs