import logging
import os
import sys
from typing import List, Optional, Tuple, Union, cast

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

logger = logging.getLogger(__name__)
class Qwen2VLForEmbedding(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
    ):
        """Encode the inputs and return one L2-normalized embedding per sample."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                # Encode image patches with the vision tower and scatter them into
                # the <|image_pad|> placeholder positions of the text embeddings.
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
            if pixel_values_videos is not None:
                # Same treatment for video frames and the <|video_pad|> placeholders.
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        # With left padding, the last position of every sequence holds the final
        # (<|endoftext|>) token; its hidden state is pooled as the embedding.
        embeddings = hidden_states[:, -1, :]
        embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        return embeddings
    def set_processor(self, model_name_or_path, max_len=3072, eos_token_id=151643, min_image_token=64, max_image_token=2500):
        """Attach the Qwen2-VL processor; image resolution is bounded via the number of visual tokens (one token per 28x28 patch)."""
        self.max_len = max_len
        self.eos_token_id = eos_token_id
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            min_pixels=min_image_token * 28 * 28,
            max_pixels=max_image_token * 28 * 28,
        )
        # Left padding keeps the last position of every sequence meaningful,
        # which is the position forward() pools for the embedding.
        assert self.processor.tokenizer.padding_side == 'left'
    def prepare_text_input(self, image=None, text=None, q_or_c=None, task_instruction=None):
        """Build the chat-formatted prompt for a query ("q"/"query") or a candidate ("c"/"candidate")."""
        assert q_or_c in ["query", "candidate", "q", "c"]
        prompt_template = "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"

        if "q" in q_or_c:
            # Queries: the task instruction goes into the system prompt.
            if task_instruction is None:
                system_prompt = "You are a helpful assistant."
                task_instruction_example_csr = "Represent the given image with the given query."
                print(
                    "Warning: For optimal performance, UniSE-MLLM requires the task instruction to be specified in the query. "
                    "For example, for the Composed Screenshot Retrieval task, you might use a specific instruction like: "
                    f"{task_instruction_example_csr}."
                )
            else:
                system_prompt = task_instruction
            if image is None:
                user_prompt = text
            elif text is not None:
                user_prompt = f"Query:{text}<|vision_start|><|image_pad|><|vision_end|>"
            else:
                user_prompt = "<|vision_start|><|image_pad|><|vision_end|>"
            text_input = prompt_template.format(system_prompt, user_prompt)
        else:
            # Candidates: fixed system prompts; a text-rich image takes precedence over plain text.
            if text is not None:
                system_prompt = "Represent the given text."
                user_prompt = f"{text}"
            if image is not None:
                system_prompt = "Represent the given text-rich image, focusing on extracting and interpreting both its rich text content and visual features."
                user_prompt = "<|vision_start|><|image_pad|><|vision_end|>"
            text_input = prompt_template.format(system_prompt, user_prompt)
        return text_input
    def data_process(self, images=None, text=None, q_or_c=None, task_instruction=None):
        """Tokenize a single example or a batch of examples; entries of `images` are file paths."""
        if images is not None:
            _is_list = isinstance(images, list)
        elif text is not None:
            _is_list = isinstance(text, list)
        else:
            raise ValueError("images and text cannot be both None.")
        assert q_or_c in ["query", "candidate", "q", "c"]

        if not _is_list:
            # Single example.
            text_input = [self.prepare_text_input(images, text, q_or_c, task_instruction)]
            if images is not None:
                images = [Image.open(images).convert("RGB")]
                inputs = self.processor(images=images, text=text_input, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len)
            else:
                inputs = self.processor(text=text_input, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len)
        else:
            # Batch of examples: pad the missing modality with None so prompts can be built pairwise.
            if text is None:
                text = [None] * len(images)
            _images = images if images is not None else [None] * len(text)
            text_input = [
                self.prepare_text_input(_image, _text, q_or_c, task_instruction)
                for _image, _text in zip(_images, text)
            ]
            if images is not None:
                images = [Image.open(_image).convert("RGB") for _image in images]
                inputs = self.processor(images=images, text=text_input, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len)
            else:
                inputs = self.processor(text=text_input, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len)

        # If a sequence was truncated at max_len, force the final position back to the
        # EOS token so that forward() still pools the intended token.
        if inputs.input_ids.size(-1) == self.max_len:
            inputs.input_ids[:, -1] = self.eos_token_id
        assert (inputs.input_ids[:, -1] == self.eos_token_id).all()
        assert (inputs.attention_mask[:, -1] == 1).all()

        inputs = inputs.to(self.device)
        return inputs
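

# Usage sketch: encode one composed query (screenshot + text) and one candidate
# screenshot, then score them by cosine similarity of the normalized embeddings.
# The checkpoint path, image paths, and query text below are placeholders.
if __name__ == "__main__":
    MODEL_NAME = "path/to/UniSE-MLLM-checkpoint"  # placeholder checkpoint path

    model = Qwen2VLForEmbedding.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
    model.to("cuda").eval()
    model.set_processor(MODEL_NAME)

    with torch.no_grad():
        query_inputs = model.data_process(
            images="query_screenshot.png",  # placeholder image path
            text="the pricing section of this page",  # placeholder query text
            q_or_c="query",
            task_instruction="Represent the given image with the given query.",
        )
        cand_inputs = model.data_process(
            images="candidate_screenshot.png",  # placeholder image path
            q_or_c="candidate",
        )
        query_emb = model(**query_inputs)  # shape (1, hidden_size), L2-normalized
        cand_emb = model(**cand_inputs)
        score = (query_emb @ cand_emb.T).item()  # cosine similarity
        print(f"similarity: {score:.4f}")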