import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.transforms.functional import to_pil_image

from transformers import (AutoModel, GenerationConfig, Qwen2ForCausalLM,
                          Qwen3VLForConditionalGeneration)
from transformers.modeling_utils import PreTrainedModel
from qwen_vl_utils import process_vision_info

from .configuration_sa2va_chat import Sa2VAChatConfigQwen
from .sam2 import SAM2


class DirectResize:
    """Resize an image to a fixed square resolution, as expected by the SAM2 grounding encoder."""

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """Expects a numpy array with shape HxWxC in uint8 format; returns the resized array."""
        img = to_pil_image(image, mode='RGB')
        return np.array(img.resize((self.target_length, self.target_length)))
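
# Illustrative sketch of DirectResize (not executed at import time; the 640x480 input
# size below is a hypothetical example):
#
#   resizer = DirectResize(target_length=1024)
#   frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   resized = resizer.apply_image(frame)   # -> ndarray of shape (1024, 1024, 3), dtype uint8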


class Sa2VAChatModelQwen(PreTrainedModel):
    config_class = Sa2VAChatConfigQwen
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _no_split_modules = ['Qwen3VisionTransformerPretrainedModel', 'Qwen3VLDecoderLayer', 'SAM2']
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    def __init__(self, config: Sa2VAChatConfigQwen, model=None, use_flash_attn=True):
        super().__init__(config)
        # Resizes raw frames to the fixed 1024x1024 input expected by SAM2.
        self.extra_image_processor = DirectResize(target_length=1024)

        # Pixel budget handed to the Qwen-VL processor (expressed as multiples of the 28x28 patch area).
        self.min_pixels = 512 * 28 * 28
        self.max_pixels = 2048 * 28 * 28

        self.torch_dtype = torch.bfloat16

        if model is not None:
            self.model = model
        else:
            self.model = Qwen3VLForConditionalGeneration(config)

        llm_hidden_size = config.text_config.hidden_size

        # SAM2 grounding encoder plus a small MLP that projects the LLM hidden states of
        # [SEG] tokens into SAM2's prompt-embedding space.
        self.grounding_encoder = SAM2()
        out_dim = self.grounding_encoder.hidden_dim
        in_dim = llm_hidden_size
        self.text_hidden_fcs = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
        )

    @property
    def lm_head(self):
        return self.model.lm_head

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def predict_forward(
        self,
        image=None,
        video=None,
        text=None,
        past_text='',
        mask_prompts=None,
        tokenizer=None,
        processor=None,
    ):
        assert processor is not None
        self.processor = processor

        # Token id of the special [SEG] token; its hidden states are later decoded into masks by SAM2.
        self.seg_token_idx = self.processor.tokenizer.convert_tokens_to_ids('[SEG]')

        # Image placeholders are handled via the chat template, so strip any literal '<image>' tags.
        text = text.replace('<image>', "")

        if image is None and video is None and '<image>' not in past_text:
            # Pure-text turn: build a chat prompt without any visual inputs.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": past_text + text},
                    ],
                }
            ]

            processed_text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            mm_inputs = self.processor(
                text=[processed_text],
                images=None,
                videos=None,
                padding=True,
                return_tensors="pt",
            )
            mm_inputs = mm_inputs.to(self.device)

            ret_masks = []
        else:
            input_dict = {}
            if video is not None:
                extra_pixel_values = []
                content = []
                ori_image_size = video[0].size
                for frame_idx, frame_image in enumerate(video):
                    # High-resolution copy of every frame for the SAM2 grounding encoder.
                    g_image = np.array(frame_image)
                    g_image = self.extra_image_processor.apply_image(g_image)
                    g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                    extra_pixel_values.append(g_image)
                    # Only the first 5 frames are passed to the language model as images.
                    if frame_idx < 5:
                        content.append({"type": "image", "image": frame_image})

                content.append({"type": "text", "text": text})
                messages = [
                    {
                        "role": "user",
                        "content": content,
                    }
                ]

                processed_text = self.processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )

                image_inputs, video_inputs = process_vision_info(messages)
                mm_inputs = self.processor(
                    text=[processed_text],
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    return_tensors="pt",
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels
                )
                mm_inputs = mm_inputs.to(self.device)

                g_pixel_values = torch.stack([
                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
                ]).to(self.torch_dtype)

                num_frames = min(5, len(video))

            else:
                ori_image_size = image.size

                # High-resolution copy of the single image for the SAM2 grounding encoder.
                g_image = np.array(image)
                g_image = self.extra_image_processor.apply_image(g_image)
                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
                extra_pixel_values = [g_pixel_values]
                g_pixel_values = torch.stack([
                    self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
                ]).to(self.torch_dtype)

                messages = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "image": image,
                            },
                            {"type": "text", "text": text},
                        ],
                    }
                ]

                processed_text = self.processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )

                image_inputs, video_inputs = process_vision_info(messages)
                mm_inputs = self.processor(
                    text=[processed_text],
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    return_tensors="pt",
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels
                )
                mm_inputs = mm_inputs.to(self.device)

                num_frames = 1

            input_dict['g_pixel_values'] = g_pixel_values
            ret_masks = []

        generate_output = self.model.generate(
            **mm_inputs,
            max_new_tokens=2048,
            do_sample=False,
            output_hidden_states=True,
            return_dict_in_generate=True
        )

        # Drop the prompt tokens so only the newly generated ids are decoded.
        generate_output_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(mm_inputs.input_ids, generate_output.sequences)
        ]

        predict = self.processor.batch_decode(generate_output_trimmed, skip_special_tokens=False)[0].strip()

        if image is None and video is None and '<image>' not in past_text:
            # No visual input: nothing to segment, return the text prediction only.
            return {'prediction': predict, 'prediction_masks': ret_masks, }

        # Collect the last-layer hidden state of every processed position and keep only
        # those positions where a [SEG] token appears in the output sequence.
        hidden_states = generate_output.hidden_states
        last_hidden_states = [item[-1][0] for item in hidden_states]
        last_hidden_states = torch.cat(last_hidden_states, dim=0)
        seg_hidden_states = get_seg_hidden_states(
            last_hidden_states, generate_output.sequences[0][:-1],
            seg_id=self.seg_token_idx
        )

        # Project [SEG] hidden states into SAM2's prompt-embedding space.
        all_seg_hidden_states = self.text_hidden_fcs(seg_hidden_states)

        for seg_hidden_states in all_seg_hidden_states:
            seg_hidden_states = seg_hidden_states.unsqueeze(0)
            g_pixel_values = input_dict['g_pixel_values']
            sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values)
            # One language embedding per frame drives SAM2's mask decoder.
            pred_masks = self.grounding_encoder.language_embd_inference(sam_states, [seg_hidden_states] * num_frames)
            w, h = ori_image_size
            masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False)
            masks = masks[:, 0]
            masks = masks.sigmoid() > 0.5
            masks = masks.cpu().numpy()
            ret_masks.append(masks)

        return {'prediction': predict, 'prediction_masks': ret_masks, }
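
    # predict_forward return-value sketch (illustrative; assumes a single-image query of
    # original size (w, h) whose answer contains one [SEG] token):
    #   {
    #       'prediction':       '... [SEG] ...',   # decoded text, special tokens kept
    #       'prediction_masks': [np.ndarray of shape (1, h, w), dtype bool],
    #   }
    # For video queries each mask array has shape (num_frames, h, w) with
    # num_frames = min(5, len(video)), and there is one array per generated [SEG] token.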


def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    """Select the hidden states at the positions where ``output_ids`` equals the [SEG] token id."""
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    if n_out == 0:
        return hidden_states[0:0]
    return hidden_states[-n_out:][seg_mask]
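

# Minimal usage sketch. Assumptions: the published checkpoint registers this class via
# `auto_map` so `AutoModel.from_pretrained(..., trust_remote_code=True)` resolves to it,
# the bundled processor's tokenizer already contains the '[SEG]' special token, and
# "path/to/sa2va-qwen-checkpoint" / "demo.jpg" are placeholder paths, not real assets.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("path/to/sa2va-qwen-checkpoint", trust_remote_code=True)
    model = AutoModel.from_pretrained(
        "path/to/sa2va-qwen-checkpoint", torch_dtype=torch.bfloat16, trust_remote_code=True
    ).eval().cuda()

    result = model.predict_forward(
        image=Image.open("demo.jpg").convert("RGB"),
        text="Please segment the person in the image.",
        processor=processor,
    )
    print(result["prediction"])              # text answer containing [SEG] token(s)
    print(len(result["prediction_masks"]))   # one boolean mask array per [SEG] token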