import torch
from torch import nn
from transformers import (AutoModel, GenerationConfig,
                          Qwen2_5_VLForConditionalGeneration, Qwen2ForCausalLM)
from transformers.modeling_utils import PreTrainedModel
from .configuration_sa2va_chat import Sa2VAChatConfigQwen
from .sam2 import SAM2
import numpy as np
from torchvision.transforms.functional import to_pil_image
import torch.nn.functional as F
from qwen_vl_utils import process_vision_info


class DirectResize:
    """Resize an image to a fixed square resolution for the grounding encoder."""

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        img = to_pil_image(image, mode='RGB')
        return np.array(img.resize((self.target_length, self.target_length)))


class Sa2VAChatModelQwen(PreTrainedModel):
    config_class = Sa2VAChatConfigQwen
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _no_split_modules = ['Qwen2_5_VisionTransformerPretrainedModel',
                         'Qwen2_5_VLDecoderLayer', 'SAM2']
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    def __init__(self, config: Sa2VAChatConfigQwen, model=None, use_flash_attn=True):
        super().__init__(config)

        # Fixed-size resize used only for the SAM2 grounding branch.
        self.extra_image_processor = DirectResize(target_length=1024)

        self.min_pixels = 512 * 28 * 28
        self.max_pixels = 2048 * 28 * 28

        self.torch_dtype = torch.bfloat16

        if model is not None:
            self.model = model
        else:
            self.model = Qwen2_5_VLForConditionalGeneration(config)
        self.model._tied_weights_keys = None

        llm_hidden_size = config.text_config.hidden_size

        self.grounding_encoder = SAM2()
        out_dim = self.grounding_encoder.hidden_dim
        in_dim = llm_hidden_size
        # Projects the LLM hidden states of [SEG] tokens into the SAM2
        # prompt-embedding space.
        self.text_hidden_fcs = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
        )

    @property
    def lm_head(self):
        return self.model.lm_head

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def predict_forward(
            self,
            image=None,
            video=None,
            text=None,
            past_text='',
            mask_prompts=None,
            tokenizer=None,
            processor=None,
    ):
        """Run one chat round; if the answer contains [SEG] tokens, also
        predict the corresponding segmentation masks with SAM2."""
        assert processor is not None
        self.processor = processor
        self.seg_token_idx = self.processor.tokenizer.convert_tokens_to_ids('[SEG]')

        # Strip the Sa2VA '<image>' placeholder; the Qwen chat template
        # inserts its own vision tokens.
        text = text.replace('<image>', '')

        if image is None and video is None and '<image>' not in past_text:
            # Text-only conversation: no vision inputs, no masks.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": past_text + text},
                    ],
                }
            ]
            # Preparation for inference
            processed_text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            mm_inputs = self.processor(
                text=[processed_text],
                images=None,
                videos=None,
                padding=True,
                return_tensors="pt",
            )
            mm_inputs = mm_inputs.to(self.device)

            ret_masks = []
        else:
            input_dict = {}
            if video is not None:
                pixel_values = []
                extra_pixel_values = []
                images = []
                content = []
                ori_image_size = video[0].size
                for frame_idx, frame_image in enumerate(video):
                    # assert ori_image_size == frame_image.size
                    g_image = np.array(frame_image)  # for grounding
                    g_image = self.extra_image_processor.apply_image(g_image)
                    g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                    extra_pixel_values.append(g_image)
                    # Only the first 5 frames are fed to the language model.
                    if frame_idx < 5:
                        content.append({"type": "image", "image": frame_image})

                content.append({"type": "text", "text": text})
                messages = [
                    {
                        "role": "user",
                        "content": content,
                    }
                ]
                # Preparation for inference
                processed_text = self.processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                image_inputs, video_inputs = process_vision_info(messages)
                mm_inputs = self.processor(
                    text=[processed_text],
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    return_tensors="pt",
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                )
                mm_inputs = mm_inputs.to(self.device)

                g_pixel_values = torch.stack([
                    self.grounding_encoder.preprocess_image(pixel)
                    for pixel in extra_pixel_values
                ]).to(self.torch_dtype)
                num_frames = min(5, len(video))
            else:
                ori_image_size = image.size

                # prepare grounding images
                g_image = np.array(image)  # for grounding
                g_image = self.extra_image_processor.apply_image(g_image)
                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
                extra_pixel_values = [g_pixel_values]
                g_pixel_values = torch.stack([
                    self.grounding_encoder.preprocess_image(pixel)
                    for pixel in extra_pixel_values
                ]).to(self.torch_dtype)

                messages = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "image": image,
                            },
                            {"type": "text", "text": text},
                        ],
                    }
                ]
                # Preparation for inference
                processed_text = self.processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                image_inputs, video_inputs = process_vision_info(messages)
                mm_inputs = self.processor(
                    text=[processed_text],
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    return_tensors="pt",
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                )
                mm_inputs = mm_inputs.to(self.device)

                num_frames = 1

            input_dict['g_pixel_values'] = g_pixel_values

            ret_masks = []

        generate_output = self.model.generate(
            **mm_inputs,
            max_new_tokens=2048,
            do_sample=False,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )

        # Drop the prompt tokens and decode only the newly generated part.
        generate_output_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(mm_inputs.input_ids, generate_output.sequences)
        ]
        predict = self.processor.batch_decode(
            generate_output_trimmed, skip_special_tokens=False
        )[0].strip()

        if image is None and video is None and '<image>' not in past_text:
            return {'prediction': predict, 'prediction_masks': ret_masks}

        # If the answer contains segmentation results, locate the [SEG] hidden states.
        hidden_states = generate_output.hidden_states
        last_hidden_states = [item[-1][0] for item in hidden_states]
        last_hidden_states = torch.cat(last_hidden_states, dim=0)
        seg_hidden_states = get_seg_hidden_states(
            last_hidden_states, generate_output.sequences[0][:-1],
            seg_id=self.seg_token_idx
        )
        all_seg_hidden_states = self.text_hidden_fcs(seg_hidden_states)

        for seg_hidden_states in all_seg_hidden_states:
            seg_hidden_states = seg_hidden_states.unsqueeze(0)
            g_pixel_values = input_dict['g_pixel_values']
            sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values)
            # One [SEG] embedding prompts SAM2 for every frame.
            pred_masks = self.grounding_encoder.language_embd_inference(
                sam_states, [seg_hidden_states] * num_frames)
            w, h = ori_image_size
            masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False)
            masks = masks[:, 0]
            masks = masks.sigmoid() > 0.5
            masks = masks.cpu().numpy()
            ret_masks.append(masks)

        return {'prediction': predict, 'prediction_masks': ret_masks}


def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    # Select the hidden states at positions whose output id is the [SEG] token.
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    if n_out == 0:
        return hidden_states[0:0]
    return hidden_states[-n_out:][seg_mask]
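

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model definition): a minimal example of how
# predict_forward might be called for referring segmentation on a single
# image. The checkpoint ID "ByteDance/Sa2VA-Qwen2_5-VL-3B", the image path,
# and loading via AutoModel/AutoProcessor with trust_remote_code=True are
# assumptions for illustration; substitute the paths and loading code used
# in your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor

    model_path = "ByteDance/Sa2VA-Qwen2_5-VL-3B"  # assumed checkpoint ID
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    ).eval()
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    result = model.predict_forward(
        image=image,
        text="<image>Please segment the person on the left.",
        processor=processor,
    )

    # 'prediction' is the decoded answer and may contain [SEG] tokens;
    # 'prediction_masks' holds one boolean array of shape
    # (num_frames, H, W) per [SEG] token (num_frames == 1 for a single image).
    print(result["prediction"])
    for mask in result["prediction_masks"]:
        print(mask.shape, mask.dtype)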