Remove hardcoded .cuda() calls to support single forward pass on CPU and ensure DeepSeekOCR model compatibility with transformers==4.52.4
		209ae73
		
import ast
import math
import os
import re
import time
from abc import ABC
from io import BytesIO
from typing import List, Optional, Tuple, Union

import numpy as np
import requests
import torch
import torch.nn as nn
from addict import Dict
from PIL import Image, ImageOps, ImageDraw, ImageFont
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import TextStreamer
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from .configuration_deepseek_v2 import DeepseekV2Config
from .conversation import get_conv_template
from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
def load_image(image_path):
    """Open an image and apply its EXIF orientation.

    Args:
        image_path: Anything accepted by ``PIL.Image.open``.

    Returns:
        The EXIF-transposed image. If opening or transposing fails, the error
        is printed and a plain re-open is attempted; ``None`` is returned when
        that also fails.
    """
    try:
        image = Image.open(image_path)
        return ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"error: {e}")
        try:
            return Image.open(image_path)
        except Exception:
            # Best-effort loader: unreadable input yields None, but
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            return None
def re_match(text):
    """Collect every ``<|ref|>…<|/ref|><|det|>…<|/det|>`` span in ``text``.

    Returns:
        A 3-tuple ``(matches, image_spans, other_spans)``:
        * ``matches`` - the ``re.findall`` result, one
          ``(full_span, ref_text, det_text)`` tuple per occurrence;
        * ``image_spans`` - full spans containing ``<|ref|>image<|/ref|>``;
        * ``other_spans`` - all remaining full spans.
    """
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)

    image_marker = '<|ref|>image<|/ref|>'
    full_spans = [found[0] for found in matches]
    image_spans = [span for span in full_spans if image_marker in span]
    other_spans = [span for span in full_spans if image_marker not in span]
    return matches, image_spans, other_spans
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Split one ``re_match`` tuple into its label and parsed coordinates.

    Args:
        ref_text: Indexable where ``ref_text[1]`` is the label and
            ``ref_text[2]`` is the textual coordinate list.
        image_width: Unused; kept for interface compatibility.
        image_height: Unused; kept for interface compatibility.

    Returns:
        ``(label_type, cor_list)`` on success, or ``None`` (after printing the
        error) when the tuple is too short or the coordinates do not parse.
    """
    try:
        label_type = ref_text[1]
        # The coordinate text comes from model output, so it is parsed as a
        # Python literal only; arbitrary expressions are rejected rather than
        # executed (the previous implementation used eval()).
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None
    return (label_type, cor_list)
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw labelled boxes for every reference and save 'image' crops.

    Coordinates in ``refs`` are scaled as ``value / 999 * image_dimension``
    to obtain pixel positions.

    Args:
        image: PIL image to annotate (it is copied, not modified).
        refs: Iterable of reference tuples as produced by ``re_match``.
        ouput_path: Directory under which ``images/<n>.jpg`` crops are written
            for boxes whose label is ``'image'``.

    Returns:
        A copy of ``image`` with box outlines, labels and a translucent fill
        overlay pasted on top.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Translucent fills are drawn on a separate RGBA layer and pasted at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    img_idx = 0

    for ref in refs:
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result

                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
                color_a = color + (20, )

                for points in points_list:
                    x1, y1, x2, y2 = points

                    x1 = int(x1 / 999 * image_width)
                    y1 = int(y1 / 999 * image_height)
                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)

                    if label_type == 'image':
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                        except Exception as e:
                            print(e)
                        # The index advances even when saving failed, so crop
                        # numbering tracks the order of 'image' boxes.
                        img_idx += 1

                    try:
                        if label_type == 'title':
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)

                        text_x = x1
                        text_y = max(0, y1 - 15)

                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                       fill=(255, 255, 255, 30))

                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except Exception:
                        # Drawing is best-effort: a box that cannot be drawn is skipped.
                        pass
        except Exception:
            # A malformed reference must not abort the remaining ones.
            continue

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
def process_image_with_refs(image, ref_texts, output_path):
    """Annotate ``image`` with the boxes described by ``ref_texts``.

    Thin wrapper that forwards its arguments to ``draw_bounding_boxes`` and
    returns the annotated image it produces.
    """
    return draw_bounding_boxes(image, ref_texts, output_path)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid shape whose aspect ratio is nearest to ``aspect_ratio``.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(cols, rows)`` candidates, visited in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile.

    Returns:
        The best candidate, or ``(1, 1)`` if ``target_ratios`` is empty. On an
        exact tie with the current best distance, the later candidate wins
        only when the image area exceeds half the area that candidate's tiles
        would cover.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    pixel_area = width * height

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        if gap == smallest_gap:
            if pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
                chosen = candidate
    return chosen
def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
    """Resize ``image`` to a grid of square tiles and cut it into those tiles.

    Args:
        image: PIL image (``size``, ``resize`` and ``crop`` are used).
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        image_size: Side length of each tile in pixels.
        use_thumbnail: When true and more than one tile was produced, append a
            ``image_size`` x ``image_size`` resize of the whole image.

    Returns:
        ``(tiles, (cols, rows))`` where tiles are ordered row by row.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every (cols, rows) grid whose tile count lies in [min_num, max_num].
    grid_shapes = set()
    for n in range(min_num, max_num + 1):
        for cols in range(1, n + 1):
            for rows in range(1, n + 1):
                if cols * rows <= max_num and cols * rows >= min_num:
                    grid_shapes.add((cols, rows))
    target_ratios = sorted(grid_shapes, key=lambda shape: shape[0] * shape[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))

    tiles_per_row = target_width // image_size
    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio
def normalize_transform(mean, std):
    """Build a ``transforms.Normalize`` from optional statistics.

    Args:
        mean: Per-channel means, or ``None`` to use zeros (one per ``std`` entry).
        std: Per-channel standard deviations, or ``None`` to use ones (one per
            ``mean`` entry).

    Returns:
        ``None`` when both arguments are ``None``; otherwise the Normalize
        transform.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    elif std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)
def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """Render a conversation through the named SFT template.

    Args:
        conversations (List[Dict]): Messages, each with ``"role"`` and
            ``"content"`` entries.
        sft_format (str, optional): Name passed to ``get_conv_template``.
            Defaults to "deepseek".
        system_prompt (str, optional): System message set on the template.
            Defaults to "".

    Returns:
        sft_prompt (str): The formatted prompt with surrounding whitespace
        stripped.
    """
    template = get_conv_template(sft_format)
    template.set_system_message(system_prompt)
    for turn in conversations:
        template.append_message(turn["role"], turn["content"].strip())
    return template.get_prompt().strip()
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize ``text`` without special tokens, optionally adding BOS/EOS ids.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        text: Text to encode.
        bos: Prepend the BOS id (hard-coded as 0) when true.
        eos: Append the EOS id (hard-coded as 1) when true.

    Returns:
        The list of token ids.
    """
    bos_id, eos_id = 0, 1
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if bos:
        token_ids = [bos_id, *token_ids]
    if eos:
        token_ids = [*token_ids, eos_id]
    return token_ids
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """Load every image referenced by a conversation as an RGB PIL image.

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images, in message
        order. Messages without an ``"images"`` entry are skipped.

    Raises:
        ValueError: if an image path cannot be loaded (``load_image`` returned
            ``None``). Previously this surfaced as an AttributeError on
            ``None.convert``.
    """
    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_path in message["images"]:
            pil_img = load_image(image_path)
            if pil_img is None:
                raise ValueError(f"could not load image: {image_path!r}")
            pil_images.append(pil_img.convert("RGB"))

    return pil_images
class BaseTransform(ABC):
    """Base class for image transforms; subclasses override the hooks below."""

    def set_rng(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Placeholder call; the base implementation returns ``None``."""
        return None

    def default_shape(self):
        """Not provided by the base class."""
        raise NotImplementedError
class BasicImageTransform(BaseTransform):
    """``ToTensor`` followed by an optional normalization step."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        """Build the transform pipeline.

        Args:
            mean: Per-channel means; stored on the instance as ``self.mean``.
            std: Per-channel standard deviations; stored as ``self.std``.
            normalize: When true, append ``normalize_transform(mean, std)``
                (skipped if that returns ``None``); when false, append an
                ``nn.Identity`` instead.
        """
        self.mean = mean
        self.std = std

        steps = [transforms.ToTensor()]
        final_step = normalize_transform(mean, std) if normalize else nn.Identity()
        if final_step is not None:
            steps.append(final_step)

        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        return self.transform(x)
class NoEOSTextStreamer(TextStreamer):
    """Text streamer that prints a newline in place of the EOS token text."""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Print ``text`` immediately, with each EOS occurrence replaced by a newline."""
        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        print(text.replace(eos_text, "\n"), flush=True, end="")
class DeepseekOCRConfig(DeepseekV2Config):
    """DeepseekV2 configuration registered under the ``DeepseekOCR`` model type."""

    model_type = "DeepseekOCR"
class DeepseekOCRModel(DeepseekV2Model):
    """DeepseekV2 decoder whose input embeddings can be overwritten with vision features.

    Images are encoded by a SAM encoder followed by a CLIP encoder, projected
    to the embedding width, laid out row by row with a learned newline
    embedding, and scattered into ``inputs_embeds`` at the positions marked by
    ``images_seq_mask``.
    """

    config_class = DeepseekOCRConfig

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCRModel, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()
        n_embed = 1280
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        # Learned embeddings appended after each feature row and after each image.
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

    def _encode_views(self, sam_model, vision_model, pixels):
        """Encode ``pixels`` with SAM then CLIP and project the fused features."""
        sam_features = sam_model(pixels)
        clip_features = vision_model(pixels, sam_features)
        # Drop the first CLIP token and concatenate with the flattened SAM map
        # along the channel dimension.
        fused = torch.cat(
            (clip_features[:, 1:], sam_features.flatten(2).permute(0, 2, 1)), dim=-1
        )
        return self.projector(fused)

    def _flatten_with_newlines(self, grid):
        """Append the newline embedding to each row of a (rows, cols, dim) grid and flatten."""
        rows, _, dim = grid.shape
        newline = self.image_newline[None, None, :].expand(rows, 1, dim)
        return torch.cat([grid, newline], dim=1).view(-1, dim)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the tokens, splice in image features, and run the decoder.

        Args:
            images: One ``(patches, image_ori)`` pair per batch row. An
                all-zero ``patches`` tensor means "no local crops"; an
                all-zero first ``image_ori`` disables vision encoding.
            images_seq_mask: Boolean mask per batch row marking the sequence
                positions to overwrite with image features.
            images_spatial_crop: Per batch row ``(width_crop_num,
                height_crop_num)`` describing the crop grid.
            Remaining arguments are forwarded to ``DeepseekV2Model.forward``.

        Returns:
            Whatever the parent ``forward`` returns for the resulting
            ``inputs_embeds``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)

        # When only embeddings were supplied, take the sequence length from them.
        seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Vision encoding is skipped for single-token decode steps (outside
        # training), when no images were passed, and for the all-zero placeholder.
        run_vision = (
            sam_model is not None
            and images is not None
            and len(images) > 0
            and (seq_len != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )

        if run_vision:
            for idx, (image, crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches = image[0]
                image_ori = image[1]

                with torch.no_grad():
                    if torch.sum(patches).item() != 0:
                        local_features = self._encode_views(sam_model, vision_model, patches)
                        global_features = self._encode_views(sam_model, vision_model, image_ori)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('PATCHES: ', local_features.shape)
                        print('=====================')

                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        _, hw2, n_dim2 = local_features.shape
                        h2 = w2 = int(hw2 ** 0.5)

                        width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

                        global_features = self._flatten_with_newlines(
                            global_features.view(h, w, n_dim)
                        )

                        # Stitch the per-crop feature maps back into a single
                        # (height_crop_num*h2, width_crop_num*w2) grid.
                        local_grid = (
                            local_features
                            .view(height_crop_num, width_crop_num, h2, w2, n_dim2)
                            .permute(0, 2, 1, 3, 4)
                            .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
                        )
                        local_features = self._flatten_with_newlines(local_grid)

                        image_features = torch.cat(
                            [local_features, global_features, self.view_seperator[None, :]], dim=0
                        )
                    else:
                        global_features = self._encode_views(sam_model, vision_model, image_ori)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('NO PATCHES')
                        print('=====================')

                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        global_features = self._flatten_with_newlines(
                            global_features.view(h, w, n_dim)
                        )

                        image_features = torch.cat(
                            [global_features, self.view_seperator[None, :]], dim=0
                        )

                # masked_scatter_ needs mask and source on the same device as,
                # and the source in the same dtype as, the destination; align
                # them explicitly instead of assuming CUDA.
                mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)
                image_features = image_features.to(
                    device=inputs_embeds.device, dtype=inputs_embeds.dtype
                )
                inputs_embeds[idx].masked_scatter_(mask, image_features)

        return super(DeepseekOCRModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
    """Causal LM head on top of ``DeepseekOCRModel`` plus an end-to-end ``infer`` helper."""

    config_class = DeepseekOCRConfig

    def __init__(self, config):
        # Deliberately start the MRO lookup *after* DeepseekV2ForCausalLM, so
        # that class's own __init__ is skipped and ``self.model`` can be the
        # OCR variant built below.
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCRModel(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped ``DeepseekOCRModel``."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the OCR model and project hidden states to float32 vocabulary logits.

        When ``labels`` is given, a next-token cross-entropy loss is computed
        (logits shifted left by one position against the labels).

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *rest_of_model_outputs)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the per-step model inputs used by ``generate``.

        Tokens already covered by ``past_key_values`` are dropped from
        ``input_ids``, position ids are derived from the attention mask when
        not supplied, and the image-related kwargs are passed through.
        """
        # Omit tokens covered by past_key_values
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # NOTE(review): `seen_tokens` / `get_max_length` are read
                # defensively so a Cache object lacking them falls back to
                # the cached sequence length / `get_max_cache_shape` instead
                # of raising - confirm against the pinned transformers version.
                past_length = getattr(past_key_values, "seen_tokens", None)
                if past_length is None:
                    past_length = cache_length
                max_length_fn = getattr(past_key_values, "get_max_length", None)
                if max_length_fn is None:
                    max_length_fn = getattr(past_key_values, "get_max_cache_shape", None)
                max_cache_length = max_length_fn() if max_length_fn is not None else None
                if max_cache_length is not None and max_cache_length < 0:
                    # Treat a negative sentinel as "unbounded".
                    max_cache_length = None
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # The previous implementation also built an unused `cache_position`
        # tensor from `position_ids`, which raised when `position_ids` was None.

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        Note: this replaces ``reset_parameters`` on ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` themselves, so it affects every such layer
        created afterwards in the process.
        """
        import torch
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
        """Run OCR generation for one prompt (and optionally one image).

        Args:
            tokenizer: Tokenizer used for encoding the prompt and decoding output.
            prompt: User prompt; ``<image>`` marks where the image tokens go.
            image_file: Path of the image to attach; empty for text-only.
            output_path: Directory for results; it and ``images/`` below it
                are created.
            base_size: Side length of the padded global view (crop mode).
            image_size: Side length used for the non-crop global view and for
                counting local-crop tokens.
            crop_mode: When true, images larger than 640 px on either side are
                additionally tiled by ``dynamic_preprocess``.
            test_compress: Print token-count / compression statistics.
            save_results: Write ``result.mmd``, ``result_with_boxes.jpg`` and,
                for line drawings, ``geo.jpg`` under ``output_path``.
            eval_mode: Generate without streaming and return the decoded text.

        Returns:
            The stripped decoded text when ``eval_mode`` is true and the prompt
            contains ``<image>``; otherwise ``None``.

        Raises:
            AssertionError: if ``prompt`` is empty.
        """
        self.disable_torch_init()

        os.makedirs(output_path, exist_ok=True)
        os.makedirs(f'{output_path}/images', exist_ok=True)

        if not prompt:
            # Same exception type as the former `assert False`, but not
            # stripped when Python runs with -O.
            raise AssertionError('prompt is none!')

        user_turn = {"role": "<|User|>", "content": f'{prompt}'}
        if image_file:
            user_turn["images"] = [f'{image_file}']
        conversation = [user_turn, {"role": "<|Assistant|>", "content": ""}]

        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0

        if images:
            image_draw = images[0].copy()
            w, h = image_draw.size
            # 1.0 for a square image, approaching 0 for very elongated ones.
            ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
        else:
            # Text-only prompt: nothing to draw on and no image statistics.
            image_draw = None
            w = h = 0
            ratio = 1

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)

        image_token = '<image>'
        image_token_id = 128815
        text_splits = prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        pad_color = tuple(int(x * 255) for x in image_transform.mean)

        for text_sep, image in zip(text_splits, images):
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]
                else:
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                # process the global view
                global_view = ImageOps.pad(image, (base_size, base_size), color=pad_color)

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])

                if width_crop_num > 1 or height_crop_num > 1:
                    # process the local views
                    for crop in images_crop_raw:
                        images_crop_list.append(image_transform(crop).to(torch.bfloat16))

                if image_size == 640:
                    valid_img_tokens += len(images_crop_list) * 100

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # add image tokens: one extra id per row, one extra id after
                # the global view, then the tiled local views.
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)
            else:
                # process the global view
                if image_size <= 640:
                    print('directly resize')
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size), color=pad_color)
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)

                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])

                # add image tokens
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        # process the last text split
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        # add the bos token
        bos_id = 0
        tokenized_str = [bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            # All-zero placeholders tell the model's forward to skip vision encoding.
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        # Run on whatever device the model lives on instead of assuming CUDA.
        device = next(self.parameters()).device

        generate_kwargs = dict(
            images=[(images_crop.to(device), images_ori.to(device))],
            images_seq_mask=images_seq_mask.unsqueeze(0).to(device),
            images_spatial_crop=images_spatial_crop,
            temperature=0.0,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=8192,
            no_repeat_ngram_size=35 if eval_mode else 20,
            use_cache=True,
        )
        if not eval_mode:
            generate_kwargs["streamer"] = NoEOSTextStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=False
            )

        with torch.autocast(device.type, dtype=torch.bfloat16):
            with torch.no_grad():
                output_ids = self.generate(input_ids.unsqueeze(0).to(device), **generate_kwargs)

        has_image_prompt = '<image>' in conversation[0]['content']
        # Only the tokens generated after the prompt are decoded.
        new_tokens = output_ids[0, input_ids.shape[0]:]
        # NOTE(review): confirm this literal matches the tokenizer's decoded
        # EOS text exactly (including the bar characters).
        stop_str = '<|end▁of▁sentence|>'

        if has_image_prompt and eval_mode:
            outputs = tokenizer.decode(new_tokens)
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
            return outputs

        if has_image_prompt and test_compress:
            outputs = tokenizer.decode(new_tokens)
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('='*50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
            print('='*50)

        if has_image_prompt and save_results and image_draw is not None:
            outputs = tokenizer.decode(new_tokens)

            print('='*15 + 'save results:' + '='*15)

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            matches_ref, matches_images, mathes_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)

            # Replace each image reference with a markdown link to the crop
            # saved as images/<idx>.jpg by draw_bounding_boxes.
            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                import matplotlib.pyplot as plt

                # SECURITY(review): eval() executes model-generated text. It
                # is kept because the exact output grammar is not visible
                # here; prefer ast.literal_eval once the format is confirmed
                # to be plain literals.
                line_data = eval(outputs)['Line']
                lines = line_data['line']
                line_type = line_data['line_type']
                endpoints = line_data['line_endpoint']

                fig, ax = plt.subplots(figsize=(3,3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)

                for idx, line in enumerate(lines):
                    try:
                        p0 = eval(line.split(' -- ')[0])
                        p1 = eval(line.split(' -- ')[-1])

                        if line_type[idx] == '--':
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                        else:
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')

                        ax.scatter(p0[0], p0[1], s=5, color = 'k')
                        ax.scatter(p1[0], p1[1], s=5, color = 'k')
                    except Exception:
                        # Best-effort plotting: skip segments that fail to parse.
                        pass

                for endpoint in endpoints:
                    label = endpoint.split(': ')[0]
                    (x, y) = eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')

                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()

            result.save(f"{output_path}/result_with_boxes.jpg")