| 
							 | 
						from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM | 
					
					
						
						| 
							 | 
						from .configuration_deepseek_v2 import DeepseekV2Config | 
					
					
						
						| 
							 | 
						from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast | 
					
					
						
						| 
							 | 
						from typing import List, Optional, Tuple, Union | 
					
					
						
						| 
							 | 
						from transformers.cache_utils import Cache | 
					
					
						
						| 
							 | 
						import requests | 
					
					
						
						| 
							 | 
						from PIL import Image, ImageOps, ImageDraw, ImageFont | 
					
					
						
						| 
							 | 
						from io import BytesIO | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import torch.nn as nn | 
					
					
						
						| 
							 | 
						from torch.nn import CrossEntropyLoss | 
					
					
						
						| 
							 | 
						from torchvision import transforms | 
					
					
						
						| 
							 | 
						from torchvision.transforms.functional import InterpolationMode | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector | 
					
					
						
						| 
							 | 
						from addict import Dict | 
					
					
						
						| 
							 | 
						from transformers import TextStreamer | 
					
					
						
						| 
							 | 
						from .conversation import get_conv_template | 
					
					
						
						| 
							 | 
						from abc import ABC | 
					
					
						
						| 
							 | 
						import math | 
					
					
						
						| 
							 | 
						import re | 
					
					
						
						| 
							 | 
						from tqdm import tqdm | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						import time | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def load_image(image_path): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						        image = Image.open(image_path) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        corrected_image = ImageOps.exif_transpose(image) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        return corrected_image | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						    except Exception as e: | 
					
					
						
						| 
							 | 
						        print(f"error: {e}") | 
					
					
						
						| 
							 | 
						        try: | 
					
					
						
						| 
							 | 
						            return Image.open(image_path) | 
					
					
						
						| 
							 | 
						        except: | 
					
					
						
						| 
							 | 
						            return None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def re_match(text): | 
					
					
						
						| 
							 | 
						    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' | 
					
					
						
						| 
							 | 
						    matches = re.findall(pattern, text, re.DOTALL) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    mathes_image = [] | 
					
					
						
						| 
							 | 
						    mathes_other = [] | 
					
					
						
						| 
							 | 
						    for a_match in matches: | 
					
					
						
						| 
							 | 
						        if '<|ref|>image<|/ref|>' in a_match[0]: | 
					
					
						
						| 
							 | 
						            mathes_image.append(a_match[0]) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            mathes_other.append(a_match[0]) | 
					
					
						
						| 
							 | 
						    return matches, mathes_image, mathes_other | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def extract_coordinates_and_label(ref_text, image_width, image_height): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						        label_type = ref_text[1] | 
					
					
						
						| 
							 | 
						        cor_list = eval(ref_text[2]) | 
					
					
						
						| 
							 | 
						    except Exception as e: | 
					
					
						
						| 
							 | 
						        print(e) | 
					
					
						
						| 
							 | 
						        return None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return (label_type, cor_list) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def draw_bounding_boxes(image, refs, ouput_path): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    image_width, image_height = image.size | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    img_draw = image.copy() | 
					
					
						
						| 
							 | 
						    draw = ImageDraw.Draw(img_draw) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) | 
					
					
						
						| 
							 | 
						    draw2 = ImageDraw.Draw(overlay) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    font = ImageFont.load_default() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    img_idx = 0 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    for i, ref in enumerate(refs): | 
					
					
						
						| 
							 | 
						        try: | 
					
					
						
						| 
							 | 
						            result = extract_coordinates_and_label(ref, image_width, image_height) | 
					
					
						
						| 
							 | 
						            if result: | 
					
					
						
						| 
							 | 
						                label_type, points_list = result | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                color_a = color + (20, ) | 
					
					
						
						| 
							 | 
						                for points in points_list: | 
					
					
						
						| 
							 | 
						                    x1, y1, x2, y2 = points | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                    x1 = int(x1 / 999 * image_width) | 
					
					
						
						| 
							 | 
						                    y1 = int(y1 / 999 * image_height) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                    x2 = int(x2 / 999 * image_width) | 
					
					
						
						| 
							 | 
						                    y2 = int(y2 / 999 * image_height) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                    if label_type == 'image': | 
					
					
						
						| 
							 | 
						                        try: | 
					
					
						
						| 
							 | 
						                            cropped = image.crop((x1, y1, x2, y2)) | 
					
					
						
						| 
							 | 
						                            cropped.save(f"{ouput_path}/images/{img_idx}.jpg") | 
					
					
						
						| 
							 | 
						                        except Exception as e: | 
					
					
						
						| 
							 | 
						                            print(e) | 
					
					
						
						| 
							 | 
						                            pass | 
					
					
						
						| 
							 | 
						                        img_idx += 1 | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                    try: | 
					
					
						
						| 
							 | 
						                        if label_type == 'title': | 
					
					
						
						| 
							 | 
						                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4) | 
					
					
						
						| 
							 | 
						                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) | 
					
					
						
						| 
							 | 
						                        else: | 
					
					
						
						| 
							 | 
						                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2) | 
					
					
						
						| 
							 | 
						                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) | 
					
					
						
						| 
							 | 
						                        text_x = x1 | 
					
					
						
						| 
							 | 
						                        text_y = max(0, y1 - 15) | 
					
					
						
						| 
							 | 
						                             | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        text_bbox = draw.textbbox((0, 0), label_type, font=font) | 
					
					
						
						| 
							 | 
						                        text_width = text_bbox[2] - text_bbox[0] | 
					
					
						
						| 
							 | 
						                        text_height = text_bbox[3] - text_bbox[1] | 
					
					
						
						| 
							 | 
						                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],  | 
					
					
						
						| 
							 | 
						                                    fill=(255, 255, 255, 30)) | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        draw.text((text_x, text_y), label_type, font=font, fill=color) | 
					
					
						
						| 
							 | 
						                    except: | 
					
					
						
						| 
							 | 
						                        pass | 
					
					
						
						| 
							 | 
						        except: | 
					
					
						
						| 
							 | 
						            continue | 
					
					
						
						| 
							 | 
						    img_draw.paste(overlay, (0, 0), overlay) | 
					
					
						
						| 
							 | 
						    return img_draw | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def process_image_with_refs(image, ref_texts, output_path): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    result_image = draw_bounding_boxes(image, ref_texts, output_path) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    return result_image | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): | 
					
					
						
						| 
							 | 
						    best_ratio_diff = float('inf') | 
					
					
						
						| 
							 | 
						    best_ratio = (1, 1) | 
					
					
						
						| 
							 | 
						    area = width * height | 
					
					
						
						| 
							 | 
						    for ratio in target_ratios: | 
					
					
						
						| 
							 | 
						        target_aspect_ratio = ratio[0] / ratio[1] | 
					
					
						
						| 
							 | 
						        ratio_diff = abs(aspect_ratio - target_aspect_ratio) | 
					
					
						
						| 
							 | 
						        if ratio_diff < best_ratio_diff: | 
					
					
						
						| 
							 | 
						            best_ratio_diff = ratio_diff | 
					
					
						
						| 
							 | 
						            best_ratio = ratio | 
					
					
						
						| 
							 | 
						        elif ratio_diff == best_ratio_diff: | 
					
					
						
						| 
							 | 
						            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: | 
					
					
						
						| 
							 | 
						                best_ratio = ratio | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    return best_ratio | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False): | 
					
					
						
						| 
							 | 
						    orig_width, orig_height = image.size | 
					
					
						
						| 
							 | 
						    aspect_ratio = orig_width / orig_height | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    target_ratios = set( | 
					
					
						
						| 
							 | 
						        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if | 
					
					
						
						| 
							 | 
						        i * j <= max_num and i * j >= min_num) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    target_aspect_ratio = find_closest_aspect_ratio( | 
					
					
						
						| 
							 | 
						        aspect_ratio, target_ratios, orig_width, orig_height, image_size) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    target_width = image_size * target_aspect_ratio[0] | 
					
					
						
						| 
							 | 
						    target_height = image_size * target_aspect_ratio[1] | 
					
					
						
						| 
							 | 
						    blocks = target_aspect_ratio[0] * target_aspect_ratio[1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    resized_img = image.resize((target_width, target_height)) | 
					
					
						
						| 
							 | 
						    processed_images = [] | 
					
					
						
						| 
							 | 
						    for i in range(blocks): | 
					
					
						
						| 
							 | 
						        box = ( | 
					
					
						
						| 
							 | 
						            (i % (target_width // image_size)) * image_size, | 
					
					
						
						| 
							 | 
						            (i // (target_width // image_size)) * image_size, | 
					
					
						
						| 
							 | 
						            ((i % (target_width // image_size)) + 1) * image_size, | 
					
					
						
						| 
							 | 
						            ((i // (target_width // image_size)) + 1) * image_size | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        split_img = resized_img.crop(box) | 
					
					
						
						| 
							 | 
						        processed_images.append(split_img) | 
					
					
						
						| 
							 | 
						    assert len(processed_images) == blocks | 
					
					
						
						| 
							 | 
						    if use_thumbnail and len(processed_images) != 1: | 
					
					
						
						| 
							 | 
						        thumbnail_img = image.resize((image_size, image_size)) | 
					
					
						
						| 
							 | 
						        processed_images.append(thumbnail_img) | 
					
					
						
						| 
							 | 
						    return processed_images, target_aspect_ratio | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def normalize_transform(mean, std): | 
					
					
						
						| 
							 | 
						    if mean is None and std is None: | 
					
					
						
						| 
							 | 
						        transform = None | 
					
					
						
						| 
							 | 
						    elif mean is None and std is not None: | 
					
					
						
						| 
							 | 
						        mean = [0.] * len(std) | 
					
					
						
						| 
							 | 
						        transform = transforms.Normalize(mean=mean, std=std) | 
					
					
						
						| 
							 | 
						    elif mean is not None and std is None: | 
					
					
						
						| 
							 | 
						        std = [1.] * len(mean) | 
					
					
						
						| 
							 | 
						        transform = transforms.Normalize(mean=mean, std=std) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        transform = transforms.Normalize(mean=mean, std=std) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return transform | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def format_messages( | 
					
					
						
						| 
							 | 
						        conversations: List[Dict[str, str]], | 
					
					
						
						| 
							 | 
						        sft_format: str = "deepseek", | 
					
					
						
						| 
							 | 
						        system_prompt: str = "", | 
					
					
						
						| 
							 | 
						): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Applies the SFT template to conversation. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        conversations (List[Dict]): A List of messages. | 
					
					
						
						| 
							 | 
						        sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". | 
					
					
						
						| 
							 | 
						        system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        sft_prompt (str): The formatted text. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    conv = get_conv_template(sft_format) | 
					
					
						
						| 
							 | 
						    conv.set_system_message(system_prompt) | 
					
					
						
						| 
							 | 
						    for message in conversations: | 
					
					
						
						| 
							 | 
						        conv.append_message(message["role"], message["content"].strip()) | 
					
					
						
						| 
							 | 
						    sft_prompt = conv.get_prompt().strip() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return sft_prompt | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False): | 
					
					
						
						| 
							 | 
						    t = tokenizer.encode(text, add_special_tokens=False) | 
					
					
						
						| 
							 | 
						    bos_id = 0 | 
					
					
						
						| 
							 | 
						    eos_id = 1 | 
					
					
						
						| 
							 | 
						    if bos: | 
					
					
						
						| 
							 | 
						        t = [bos_id] + t | 
					
					
						
						| 
							 | 
						    if eos: | 
					
					
						
						| 
							 | 
						        t = t + [eos_id] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return t | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]: | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : | 
					
					
						
						| 
							 | 
						            [ | 
					
					
						
						| 
							 | 
						                { | 
					
					
						
						| 
							 | 
						                    "role": "User", | 
					
					
						
						| 
							 | 
						                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.", | 
					
					
						
						| 
							 | 
						                    "images": ["./examples/table_datasets.png"] | 
					
					
						
						| 
							 | 
						                }, | 
					
					
						
						| 
							 | 
						                {"role": "Assistant", "content": ""}, | 
					
					
						
						| 
							 | 
						            ] | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        pil_images (List[PIL.Image.Image]): the list of PIL images. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    pil_images = [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    for message in conversations: | 
					
					
						
						| 
							 | 
						        if "images" not in message: | 
					
					
						
						| 
							 | 
						            continue | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        for image_path in message["images"]: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            pil_img = load_image(image_path) | 
					
					
						
						| 
							 | 
						            pil_img = pil_img.convert("RGB") | 
					
					
						
						| 
							 | 
						            pil_images.append(pil_img) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return pil_images | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class BaseTransform(ABC): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def set_rng(self, *args, **kwargs): | 
					
					
						
						| 
							 | 
						        pass | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __call__(self, *args, **kwargs) -> torch.Tensor: | 
					
					
						
						| 
							 | 
						        pass | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    @property | 
					
					
						
						| 
							 | 
						    def default_shape(self): | 
					
					
						
						| 
							 | 
						        raise NotImplementedError | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class BasicImageTransform(BaseTransform): | 
					
					
						
						| 
							 | 
						    def __init__( | 
					
					
						
						| 
							 | 
						        self,  | 
					
					
						
						| 
							 | 
						        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), | 
					
					
						
						| 
							 | 
						        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), | 
					
					
						
						| 
							 | 
						        normalize: bool = True | 
					
					
						
						| 
							 | 
						    ): | 
					
					
						
						| 
							 | 
						        self.mean = mean | 
					
					
						
						| 
							 | 
						        self.std = std | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						        transform_pipelines = [ | 
					
					
						
						| 
							 | 
						            transforms.ToTensor() | 
					
					
						
						| 
							 | 
						        ] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        normalize = normalize_transform(mean, std) if normalize else nn.Identity() | 
					
					
						
						| 
							 | 
						        if normalize is not None: | 
					
					
						
						| 
							 | 
						            transform_pipelines.append(normalize) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.transform = transforms.Compose(transform_pipelines) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    def __call__(self, x): | 
					
					
						
						| 
							 | 
						        x = self.transform(x) | 
					
					
						
						| 
							 | 
						        return x | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class NoEOSTextStreamer(TextStreamer): | 
					
					
						
						| 
							 | 
						    def on_finalized_text(self, text: str, stream_end: bool = False): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False) | 
					
					
						
						| 
							 | 
						        text = text.replace(eos_text, "\n") | 
					
					
						
						| 
							 | 
						        print(text, flush=True, end="") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class DeepseekOCRConfig(DeepseekV2Config): | 
					
					
						
						| 
							 | 
						    model_type = "DeepseekOCR" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class DeepseekOCRModel(DeepseekV2Model): | 
					
					
						
						| 
							 | 
						    config_class = DeepseekOCRConfig | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
    def __init__(self, config: DeepseekV2Config):
        """Build the DeepseekV2 language model plus the vision encoders and projector.

        Args:
            config: Model configuration, passed unchanged to ``DeepseekV2Model``.
        """
        super(DeepseekOCRModel, self).__init__(config)

        # Vision encoders from the project-local ``deepencoder`` module.
        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        # Width of the projected visual embeddings.
        # NOTE(review): presumably must equal the language model's hidden size - confirm against config.
        n_embed = 1280
        # Linear projector; forward() feeds it the concatenation of the CLIP and SAM features (input_dim=2048).
        self.projector =  MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        # Learned vectors initialised with std 1/sqrt(n_embed).
        # Statement order matters here: each torch.randn call consumes the global RNG.
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        # NOTE: the "seperator" spelling is part of the parameter's attribute name
        # (and therefore its state_dict key); renaming it would break checkpoint loading.
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    def forward( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        input_ids: torch.LongTensor = None, | 
					
					
						
						| 
							 | 
						        attention_mask: Optional[torch.Tensor] = None, | 
					
					
						
						| 
							 | 
						        position_ids: Optional[torch.LongTensor] = None, | 
					
					
						
						| 
							 | 
						        past_key_values: Optional[List[torch.FloatTensor]] = None, | 
					
					
						
						| 
							 | 
						        inputs_embeds: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        use_cache: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						        output_attentions: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						        output_hidden_states: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						        images: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        images_seq_mask: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        images_spatial_crop: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        return_dict: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						    ) -> Union[Tuple, BaseModelOutputWithPast]: | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if inputs_embeds is None: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            inputs_embeds = self.get_input_embeddings()(input_ids) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        sam_model = getattr(self, 'sam_model', None) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        vision_model = getattr(self, 'vision_model', None) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0: | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            idx = 0 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            for image, crop_shape in zip(images, images_spatial_crop): | 
					
					
						
						| 
							 | 
						                images_in_this_batch = [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                patches = image[0] | 
					
					
						
						| 
							 | 
						                image_ori = image[1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                with torch.no_grad(): | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                    if torch.sum(patches).item() != 0: | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        crop_flag = 1 | 
					
					
						
						| 
							 | 
						                        local_features_1 = sam_model(patches) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        local_features_2 = vision_model(patches, local_features_1)   | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)  | 
					
					
						
						| 
							 | 
						                        local_features = self.projector(local_features) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_features_1 = sam_model(image_ori) | 
					
					
						
						| 
							 | 
						                        global_features_2 = vision_model(image_ori, global_features_1)  | 
					
					
						
						| 
							 | 
						                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)  | 
					
					
						
						| 
							 | 
						                        global_features = self.projector(global_features) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        print('=====================') | 
					
					
						
						| 
							 | 
						                        print('BASE: ', global_features.shape) | 
					
					
						
						| 
							 | 
						                        print('PATCHES: ', local_features.shape) | 
					
					
						
						| 
							 | 
						                        print('=====================') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        _, hw, n_dim = global_features.shape | 
					
					
						
						| 
							 | 
						                        h = w = int(hw ** 0.5) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        _2, hw2, n_dim2 = local_features.shape | 
					
					
						
						| 
							 | 
						                        h2 = w2 = int(hw2 ** 0.5) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        width_crop_num, height_crop_num = crop_shape[0], crop_shape[1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_features = global_features.view(h, w, n_dim) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_features = torch.cat( | 
					
					
						
						| 
							 | 
						                            [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 | 
					
					
						
						| 
							 | 
						                        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_features = global_features.view(-1, n_dim) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2) | 
					
					
						
						| 
							 | 
						                        local_features = torch.cat( | 
					
					
						
						| 
							 | 
						                            [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1 | 
					
					
						
						| 
							 | 
						                        ) | 
					
					
						
						| 
							 | 
						                        local_features = local_features.view(-1, n_dim2) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                    | 
					
					
						
						| 
							 | 
						                    else: | 
					
					
						
						| 
							 | 
						                        global_features_1 = sam_model(image_ori) | 
					
					
						
						| 
							 | 
						                        global_features_2 = vision_model(image_ori, global_features_1)  | 
					
					
						
						| 
							 | 
						                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)  | 
					
					
						
						| 
							 | 
						                        global_features = self.projector(global_features) | 
					
					
						
						| 
							 | 
						                        print('=====================') | 
					
					
						
						| 
							 | 
						                        print('BASE: ', global_features.shape) | 
					
					
						
						| 
							 | 
						                        print('NO PATCHES') | 
					
					
						
						| 
							 | 
						                        print('=====================') | 
					
					
						
						| 
							 | 
						                        _, hw, n_dim = global_features.shape | 
					
					
						
						| 
							 | 
						                        h = w = int(hw ** 0.5) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_features = global_features.view(h, w, n_dim) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_features = torch.cat( | 
					
					
						
						| 
							 | 
						                            [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 | 
					
					
						
						| 
							 | 
						                        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_features = global_features.view(-1, n_dim) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                    images_in_this_batch.append(global_local_features) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                if images_in_this_batch: | 
					
					
						
						| 
							 | 
						                    images_in_this_batch = torch.cat(images_in_this_batch, dim=0) | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                idx += 1 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return super(DeepseekOCRModel, self).forward( | 
					
					
						
						| 
							 | 
						            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, | 
					
					
						
						| 
							 | 
						            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids, | 
					
					
						
						| 
							 | 
						            output_attentions=output_attentions, output_hidden_states=output_hidden_states, | 
					
					
						
						| 
							 | 
						            return_dict=return_dict | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    config_class = DeepseekOCRConfig | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __init__(self, config): | 
					
					
						
						| 
							 | 
						        super(DeepseekV2ForCausalLM, self).__init__(config) | 
					
					
						
						| 
							 | 
						        self.model = DeepseekOCRModel(config) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.vocab_size = config.vocab_size | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        self.post_init() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def get_model(self): | 
					
					
						
						| 
							 | 
						        return self.model | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def forward( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        input_ids: torch.LongTensor = None, | 
					
					
						
						| 
							 | 
						        attention_mask: Optional[torch.Tensor] = None, | 
					
					
						
						| 
							 | 
						        position_ids: Optional[torch.LongTensor] = None, | 
					
					
						
						| 
							 | 
						        past_key_values: Optional[List[torch.FloatTensor]] = None, | 
					
					
						
						| 
							 | 
						        inputs_embeds: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        labels: Optional[torch.LongTensor] = None, | 
					
					
						
						| 
							 | 
						        use_cache: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						        output_attentions: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						        output_hidden_states: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						        images: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        images_seq_mask: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        images_spatial_crop: Optional[torch.FloatTensor] = None, | 
					
					
						
						| 
							 | 
						        return_dict: Optional[bool] = None, | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						    ) -> Union[Tuple, CausalLMOutputWithPast]: | 
					
					
						
						| 
							 | 
						        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | 
					
					
						
						| 
							 | 
						        output_hidden_states = ( | 
					
					
						
						| 
							 | 
						            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        return_dict = return_dict if return_dict is not None else self.config.use_return_dict | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        outputs  = self.model( | 
					
					
						
						| 
							 | 
						            input_ids=input_ids, | 
					
					
						
						| 
							 | 
						            past_key_values=past_key_values, | 
					
					
						
						| 
							 | 
						            attention_mask=attention_mask, | 
					
					
						
						| 
							 | 
						            position_ids=position_ids, | 
					
					
						
						| 
							 | 
						            inputs_embeds=inputs_embeds, | 
					
					
						
						| 
							 | 
						            use_cache=use_cache, | 
					
					
						
						| 
							 | 
						            output_attentions=output_attentions, | 
					
					
						
						| 
							 | 
						            output_hidden_states=output_hidden_states, | 
					
					
						
						| 
							 | 
						            images=images, | 
					
					
						
						| 
							 | 
						            images_seq_mask = images_seq_mask, | 
					
					
						
						| 
							 | 
						            images_spatial_crop = images_spatial_crop, | 
					
					
						
						| 
							 | 
						            return_dict=return_dict | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        hidden_states = outputs[0] | 
					
					
						
						| 
							 | 
						        logits = self.lm_head(hidden_states) | 
					
					
						
						| 
							 | 
						        logits = logits.float() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        loss = None | 
					
					
						
						| 
							 | 
						        if labels is not None: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            shift_logits = logits[..., :-1, :].contiguous() | 
					
					
						
						| 
							 | 
						            shift_labels = labels[..., 1:].contiguous() | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            loss_fct = CrossEntropyLoss() | 
					
					
						
						| 
							 | 
						            shift_logits = shift_logits.view(-1, self.config.vocab_size) | 
					
					
						
						| 
							 | 
						            shift_labels = shift_labels.view(-1) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            shift_labels = shift_labels.to(shift_logits.device) | 
					
					
						
						| 
							 | 
						            loss = loss_fct(shift_logits, shift_labels) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if not return_dict: | 
					
					
						
						| 
							 | 
						            output = (logits,) + outputs[1:] | 
					
					
						
						| 
							 | 
						            return (loss,) + output if loss is not None else output | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return CausalLMOutputWithPast( | 
					
					
						
						| 
							 | 
						            loss=loss, | 
					
					
						
						| 
							 | 
						            logits=logits, | 
					
					
						
						| 
							 | 
						            past_key_values=outputs.past_key_values, | 
					
					
						
						| 
							 | 
						            hidden_states=outputs.hidden_states, | 
					
					
						
						| 
							 | 
						            attentions=outputs.attentions, | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def prepare_inputs_for_generation( | 
					
					
						
						| 
							 | 
						        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs | 
					
					
						
						| 
							 | 
						    ): | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        past_length = 0 | 
					
					
						
						| 
							 | 
						        if past_key_values is not None: | 
					
					
						
						| 
							 | 
						            if isinstance(past_key_values, Cache): | 
					
					
						
						| 
							 | 
						                cache_length = past_key_values.get_seq_length() | 
					
					
						
						| 
							 | 
						                past_length = past_key_values.seen_tokens | 
					
					
						
						| 
							 | 
						                max_cache_length = past_key_values.get_max_length() | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                cache_length = past_length = past_key_values[0][0].shape[2] | 
					
					
						
						| 
							 | 
						                max_cache_length = None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: | 
					
					
						
						| 
							 | 
						                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            elif past_length < input_ids.shape[1]: | 
					
					
						
						| 
							 | 
						                input_ids = input_ids[:, past_length:] | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            if ( | 
					
					
						
						| 
							 | 
						                max_cache_length is not None | 
					
					
						
						| 
							 | 
						                and attention_mask is not None | 
					
					
						
						| 
							 | 
						                and cache_length + input_ids.shape[1] > max_cache_length | 
					
					
						
						| 
							 | 
						            ): | 
					
					
						
						| 
							 | 
						                attention_mask = attention_mask[:, -max_cache_length:] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        position_ids = kwargs.get("position_ids", None) | 
					
					
						
						| 
							 | 
						        if attention_mask is not None and position_ids is None: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            position_ids = attention_mask.long().cumsum(-1) - 1 | 
					
					
						
						| 
							 | 
						            position_ids.masked_fill_(attention_mask == 0, 1) | 
					
					
						
						| 
							 | 
						            if past_key_values: | 
					
					
						
						| 
							 | 
						                position_ids = position_ids[:, -input_ids.shape[1] :] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if inputs_embeds is not None and past_key_values is None: | 
					
					
						
						| 
							 | 
						            model_inputs = {"inputs_embeds": inputs_embeds} | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            model_inputs = {"input_ids": input_ids} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        model_inputs.update( | 
					
					
						
						| 
							 | 
						            { | 
					
					
						
						| 
							 | 
						                "position_ids": position_ids, | 
					
					
						
						| 
							 | 
						                "past_key_values": past_key_values, | 
					
					
						
						| 
							 | 
						                "use_cache": kwargs.get("use_cache"), | 
					
					
						
						| 
							 | 
						                "attention_mask": attention_mask, | 
					
					
						
						| 
							 | 
						                "images": kwargs.get("images", None), | 
					
					
						
						| 
							 | 
						                "images_seq_mask": kwargs.get("images_seq_mask", None), | 
					
					
						
						| 
							 | 
						                "images_spatial_crop": kwargs.get("images_spatial_crop", None), | 
					
					
						
						| 
							 | 
						            } | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        return model_inputs | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def disable_torch_init(self): | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        Disable the redundant torch default initialization to accelerate model creation. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        import torch | 
					
					
						
						| 
							 | 
						        setattr(torch.nn.Linear, "reset_parameters", lambda self: None) | 
					
					
						
						| 
							 | 
						        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False): | 
					
					
						
						| 
							 | 
						        self.disable_torch_init() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        os.makedirs(output_path, exist_ok=True) | 
					
					
						
						| 
							 | 
						        os.makedirs(f'{output_path}/images', exist_ok=True) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if prompt and image_file: | 
					
					
						
						| 
							 | 
						            conversation = [ | 
					
					
						
						| 
							 | 
						                { | 
					
					
						
						| 
							 | 
						                    "role": "<|User|>", | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                    "content": f'{prompt}', | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                    "images": [f'{image_file}'], | 
					
					
						
						| 
							 | 
						                }, | 
					
					
						
						| 
							 | 
						                {"role": "<|Assistant|>", "content": ""}, | 
					
					
						
						| 
							 | 
						            ] | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        elif prompt: | 
					
					
						
						| 
							 | 
						            conversation = [ | 
					
					
						
						| 
							 | 
						                { | 
					
					
						
						| 
							 | 
						                    "role": "<|User|>", | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                    "content": f'{prompt}', | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                }, | 
					
					
						
						| 
							 | 
						                {"role": "<|Assistant|>", "content": ""}, | 
					
					
						
						| 
							 | 
						            ] | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            assert False, f'prompt is none!' | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        patch_size = 16 | 
					
					
						
						| 
							 | 
						        downsample_ratio = 4 | 
					
					
						
						| 
							 | 
						        images = load_pil_images(conversation) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        valid_img_tokens = 0 | 
					
					
						
						| 
							 | 
						        ratio = 1 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        image_draw = images[0].copy() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        w,h = image_draw.size | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h))) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True) | 
					
					
						
						| 
							 | 
						        images_seq_mask = [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        image_token = '<image>' | 
					
					
						
						| 
							 | 
						        image_token_id = 128815 | 
					
					
						
						| 
							 | 
						        text_splits = prompt.split(image_token) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        images_list, images_crop_list, images_seq_mask = [], [], [] | 
					
					
						
						| 
							 | 
						        tokenized_str = [] | 
					
					
						
						| 
							 | 
						        images_spatial_crop = [] | 
					
					
						
						| 
							 | 
						        for text_sep, image in zip(text_splits, images): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False) | 
					
					
						
						| 
							 | 
						            tokenized_str += tokenized_sep | 
					
					
						
						| 
							 | 
						            images_seq_mask += [False] * len(tokenized_sep) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            if crop_mode: | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                if image.size[0] <= 640 and image.size[1] <= 640: | 
					
					
						
						| 
							 | 
						                    crop_ratio = [1, 1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    if crop_mode: | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        images_crop_raw, crop_ratio = dynamic_preprocess(image) | 
					
					
						
						| 
							 | 
						                    else: | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        crop_ratio = [1, 1] | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                """process the global view""" | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                global_view = ImageOps.pad(image, (base_size, base_size), | 
					
					
						
						| 
							 | 
						                                        color=tuple(int(x * 255) for x in image_transform.mean)) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                if base_size == 1024: | 
					
					
						
						| 
							 | 
						                    valid_img_tokens += int(256 * ratio) | 
					
					
						
						| 
							 | 
						                elif base_size == 1280: | 
					
					
						
						| 
							 | 
						                    valid_img_tokens += int(400 * ratio) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                images_list.append(image_transform(global_view).to(torch.bfloat16)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                width_crop_num, height_crop_num = crop_ratio | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                images_spatial_crop.append([width_crop_num, height_crop_num]) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                if width_crop_num > 1 or height_crop_num > 1: | 
					
					
						
						| 
							 | 
						                    """process the local views""" | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                    for i in range(len(images_crop_raw)): | 
					
					
						
						| 
							 | 
						                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16)) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                if image_size == 640: | 
					
					
						
						| 
							 | 
						                    valid_img_tokens += len(images_crop_list) * 100 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                num_queries = math.ceil((image_size // patch_size) / downsample_ratio) | 
					
					
						
						| 
							 | 
						                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                """add image tokens""" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base | 
					
					
						
						| 
							 | 
						                tokenized_image += [image_token_id] | 
					
					
						
						| 
							 | 
						                if width_crop_num > 1 or height_crop_num > 1: | 
					
					
						
						| 
							 | 
						                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * ( | 
					
					
						
						| 
							 | 
						                                num_queries * height_crop_num) | 
					
					
						
						| 
							 | 
						                tokenized_str += tokenized_image | 
					
					
						
						| 
							 | 
						                images_seq_mask += [True] * len(tokenized_image) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                """process the global view""" | 
					
					
						
						| 
							 | 
						                if image_size <= 640: | 
					
					
						
						| 
							 | 
						                    print('directly resize') | 
					
					
						
						| 
							 | 
						                    image = image.resize((image_size, image_size)) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                global_view = ImageOps.pad(image, (image_size, image_size), | 
					
					
						
						| 
							 | 
						                                        color=tuple(int(x * 255) for x in image_transform.mean)) | 
					
					
						
						| 
							 | 
						                images_list.append(image_transform(global_view).to(torch.bfloat16)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                if base_size == 1024: | 
					
					
						
						| 
							 | 
						                    valid_img_tokens += int(256 * ratio) | 
					
					
						
						| 
							 | 
						                elif base_size == 1280: | 
					
					
						
						| 
							 | 
						                    valid_img_tokens += int(400 * ratio) | 
					
					
						
						| 
							 | 
						                elif base_size == 640: | 
					
					
						
						| 
							 | 
						                    valid_img_tokens += int(100 * 1) | 
					
					
						
						| 
							 | 
						                elif base_size == 512: | 
					
					
						
						| 
							 | 
						                    valid_img_tokens += int(64 * 1) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                width_crop_num, height_crop_num = 1, 1 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                images_spatial_crop.append([width_crop_num, height_crop_num]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                """add image tokens""" | 
					
					
						
						| 
							 | 
						                num_queries = math.ceil((image_size // patch_size) / downsample_ratio) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries | 
					
					
						
						| 
							 | 
						                tokenized_image += [image_token_id] | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                tokenized_str += tokenized_image | 
					
					
						
						| 
							 | 
						                images_seq_mask += [True] * len(tokenized_image) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        """process the last text split""" | 
					
					
						
						| 
							 | 
						        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False) | 
					
					
						
						| 
							 | 
						        tokenized_str += tokenized_sep | 
					
					
						
						| 
							 | 
						        images_seq_mask += [False] * len(tokenized_sep) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        """add the bos tokens""" | 
					
					
						
						| 
							 | 
						        bos_id = 0 | 
					
					
						
						| 
							 | 
						        tokenized_str = [bos_id] + tokenized_str  | 
					
					
						
						| 
							 | 
						        images_seq_mask = [False] + images_seq_mask | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        input_ids = torch.LongTensor(tokenized_str) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if len(images_list) == 0: | 
					
					
						
						| 
							 | 
						            images_ori = torch.zeros((1, 3, image_size, image_size)) | 
					
					
						
						| 
							 | 
						            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) | 
					
					
						
						| 
							 | 
						            images_crop = torch.zeros((1, 3, base_size, base_size)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            images_ori = torch.stack(images_list, dim=0) | 
					
					
						
						| 
							 | 
						            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) | 
					
					
						
						| 
							 | 
						            if images_crop_list: | 
					
					
						
						| 
							 | 
						                images_crop = torch.stack(images_crop_list, dim=0) | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                images_crop = torch.zeros((1, 3, base_size, base_size)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if not eval_mode: | 
					
					
						
						| 
							 | 
						            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) | 
					
					
						
						| 
							 | 
						            with torch.autocast("cuda", dtype=torch.bfloat16): | 
					
					
						
						| 
							 | 
						                with torch.no_grad(): | 
					
					
						
						| 
							 | 
						                    output_ids = self.generate( | 
					
					
						
						| 
							 | 
						                        input_ids.unsqueeze(0).cuda(), | 
					
					
						
						| 
							 | 
						                        images=[(images_crop.cuda(), images_ori.cuda())], | 
					
					
						
						| 
							 | 
						                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), | 
					
					
						
						| 
							 | 
						                        images_spatial_crop = images_spatial_crop, | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        temperature=0.0, | 
					
					
						
						| 
							 | 
						                        eos_token_id=tokenizer.eos_token_id, | 
					
					
						
						| 
							 | 
						                        streamer=streamer, | 
					
					
						
						| 
							 | 
						                        max_new_tokens=8192, | 
					
					
						
						| 
							 | 
						                        no_repeat_ngram_size = 20, | 
					
					
						
						| 
							 | 
						                        use_cache = True | 
					
					
						
						| 
							 | 
						                        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            with torch.autocast("cuda", dtype=torch.bfloat16): | 
					
					
						
						| 
							 | 
						                with torch.no_grad(): | 
					
					
						
						| 
							 | 
						                    output_ids = self.generate( | 
					
					
						
						| 
							 | 
						                        input_ids.unsqueeze(0).cuda(), | 
					
					
						
						| 
							 | 
						                        images=[(images_crop.cuda(), images_ori.cuda())], | 
					
					
						
						| 
							 | 
						                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), | 
					
					
						
						| 
							 | 
						                        images_spatial_crop = images_spatial_crop, | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                         | 
					
					
						
						| 
							 | 
						                        temperature=0.0, | 
					
					
						
						| 
							 | 
						                        eos_token_id=tokenizer.eos_token_id, | 
					
					
						
						| 
							 | 
						                        max_new_tokens=8192, | 
					
					
						
						| 
							 | 
						                        no_repeat_ngram_size = 35, | 
					
					
						
						| 
							 | 
						                        use_cache = True | 
					
					
						
						| 
							 | 
						                        ) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if '<image>' in conversation[0]['content'] and eval_mode: | 
					
					
						
						| 
							 | 
						                outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) | 
					
					
						
						| 
							 | 
						                stop_str = '<|end▁of▁sentence|>' | 
					
					
						
						| 
							 | 
						                if outputs.endswith(stop_str): | 
					
					
						
						| 
							 | 
						                    outputs = outputs[:-len(stop_str)] | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                outputs = outputs.strip() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                return outputs | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if '<image>' in conversation[0]['content'] and test_compress: | 
					
					
						
						| 
							 | 
						            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) | 
					
					
						
						| 
							 | 
						            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False)) | 
					
					
						
						| 
							 | 
						            print('='*50) | 
					
					
						
						| 
							 | 
						            print('image size: ', (w, h)) | 
					
					
						
						| 
							 | 
						            print('valid image tokens: ', int(valid_img_tokens)) | 
					
					
						
						| 
							 | 
						            print('output texts tokens (valid): ', pure_texts_outputs_token_length) | 
					
					
						
						| 
							 | 
						            print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2)) | 
					
					
						
						| 
							 | 
						            print('='*50) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if '<image>' in conversation[0]['content'] and save_results: | 
					
					
						
						| 
							 | 
						            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) | 
					
					
						
						| 
							 | 
						            stop_str = '<|end▁of▁sentence|>' | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            print('='*15 + 'save results:' + '='*15) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            if outputs.endswith(stop_str): | 
					
					
						
						| 
							 | 
						                outputs = outputs[:-len(stop_str)] | 
					
					
						
						| 
							 | 
						            outputs = outputs.strip() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            matches_ref, matches_images, mathes_other = re_match(outputs) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            result = process_image_with_refs(image_draw, matches_ref, output_path) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")): | 
					
					
						
						| 
							 | 
						                outputs = outputs.replace(a_match_image, ' + '.jpg)\n') | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")): | 
					
					
						
						| 
							 | 
						                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile: | 
					
					
						
						| 
							 | 
						                afile.write(outputs) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            if 'line_type' in outputs: | 
					
					
						
						| 
							 | 
						                import matplotlib.pyplot as plt | 
					
					
						
						| 
							 | 
						                lines = eval(outputs)['Line']['line'] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                line_type = eval(outputs)['Line']['line_type'] | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                endpoints = eval(outputs)['Line']['line_endpoint'] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                fig, ax = plt.subplots(figsize=(3,3), dpi=200) | 
					
					
						
						| 
							 | 
						                ax.set_xlim(-15, 15) | 
					
					
						
						| 
							 | 
						                ax.set_ylim(-15, 15) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                for idx, line in enumerate(lines): | 
					
					
						
						| 
							 | 
						                    try: | 
					
					
						
						| 
							 | 
						                        p0 = eval(line.split(' -- ')[0]) | 
					
					
						
						| 
							 | 
						                        p1 = eval(line.split(' -- ')[-1]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        if line_type[idx] == '--': | 
					
					
						
						| 
							 | 
						                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k') | 
					
					
						
						| 
							 | 
						                        else: | 
					
					
						
						| 
							 | 
						                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                        ax.scatter(p0[0], p0[1], s=5, color = 'k') | 
					
					
						
						| 
							 | 
						                        ax.scatter(p1[0], p1[1], s=5, color = 'k') | 
					
					
						
						| 
							 | 
						                    except: | 
					
					
						
						| 
							 | 
						                        pass | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                for endpoint in endpoints: | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                    label = endpoint.split(': ')[0] | 
					
					
						
						| 
							 | 
						                    (x, y) = eval(endpoint.split(': ')[1]) | 
					
					
						
						| 
							 | 
						                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',  | 
					
					
						
						| 
							 | 
						                                fontsize=5, fontweight='light') | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						                plt.savefig(f'{output_path}/geo.jpg') | 
					
					
						
						| 
							 | 
						                plt.close() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            result.save(f"{output_path}/result_with_boxes.jpg") | 
					
					
						
						| 
							 | 
						
 |