# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import random

import torch
import torch.nn as nn
from hydra import compose
from hydra.utils import instantiate
from nncore.nn import constant_init_, xavier_init_
from transformers import (AutoConfig, AutoModel, AutoProcessor, Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration,
                          Qwen2_5_VLModel, Qwen2_5_VLProcessor, Qwen2_5_VLTextModel)
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel, Qwen2RMSNorm

from sam2.loss_fns import MultiStepMultiMasksAndIous
from sam2.modeling.position_encoding import PositionEmbedding1DRandom
from sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.sam2_train import BatchedVideoDatapoint


def cache_state_hook(module, inputs, outputs=None):
    # cache the module's input on the module itself so it can be read back after the forward pass
    # (used below to recover intermediate visual and language hidden states)
    module.state = inputs[0] if isinstance(inputs, tuple) else inputs
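
# Minimal usage sketch of the hook above (illustrative only, not executed as part of this module;
# the probe layer is a hypothetical example): attach the hook to any submodule, run a forward
# pass, then read the cached input back from `.state`.
#
#     probe = nn.Linear(8, 8)
#     probe.register_forward_pre_hook(cache_state_hook)
#     probe(torch.randn(2, 8))
#     cached_input = probe.state  # the (2, 8) tensor the layer received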


class PatchedQwen2_5_VLProcessor(Qwen2_5_VLProcessor):

    def _check_special_mm_tokens(self, text, *args, **kwargs):
        self.cache_text = text
        return super()._check_special_mm_tokens(text, *args, **kwargs)


class PixelQwen2_5_VLConfig(Qwen2_5_VLConfig):
    model_type = 'pixel_qwen2_5_vl'


class PixelQwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.merger.mlp.register_forward_pre_hook(cache_state_hook)


class PixelQwen2_5_VLModel(Qwen2_5_VLModel):

    config_class = PixelQwen2_5_VLConfig

    def __init__(self, config):
        super(Qwen2_5_VLModel, self).__init__(config)
        self.visual = PixelQwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.language_model = Qwen2_5_VLTextModel._from_config(config.text_config)
        self.rope_deltas = None
        self.post_init()
        self.language_model.norm.register_forward_pre_hook(cache_state_hook)


class PixelQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):

    config_class = PixelQwen2_5_VLConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = PixelQwen2_5_VLModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if self.config.sam2_config is not None:
            overrides = [f'++model.image_size={self.config.sam2_image_size}']
            if self.config.sam2_inference_mode:
                overrides.append('++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor')

            cfg = compose(config_name=self.config.sam2_config, overrides=overrides)
            self.sam2 = instantiate(cfg.model)

            sam_dim, llm_dim = self.sam2.hidden_dim, self.config.hidden_size

            self.seg_head = nn.Sequential(
                Qwen2RMSNorm(llm_dim), nn.Linear(llm_dim, llm_dim), nn.GELU(),
                nn.Linear(llm_dim, sam_dim * self.config.sam2_hidden_tokens))

            self.ref_encoder = PromptEncoder(
                embed_dim=sam_dim,
                image_embedding_size=(self.sam2.sam_image_embedding_size, self.sam2.sam_image_embedding_size),
                input_image_size=(self.config.sam2_image_size, self.config.sam2_image_size),
                mask_in_chans=16)

            self.ref_proj_single = nn.Linear(sam_dim * 2, sam_dim * 3)
            self.ref_proj_double = nn.Linear(sam_dim * 3, sam_dim * 3)
            self.ref_proj = nn.Sequential(nn.GELU(), nn.Linear(sam_dim * 6, llm_dim))

            self.tem_pe = PositionEmbedding1DRandom(sam_dim // 2)
            self.tem_emb = nn.Embedding(1, sam_dim)
            self.tem_proj = nn.Linear(sam_dim, sam_dim * 3)

            self.msk_proj = nn.Sequential(
                nn.Linear(self.visual.merger.hidden_size, self.visual.merger.hidden_size), nn.GELU(),
                nn.Linear(self.visual.merger.hidden_size, llm_dim))

            self.loss_seg = MultiStepMultiMasksAndIous(
                dict(loss_mask=100, loss_dice=5, loss_iou=5, loss_class=5),
                supervise_all_iou=True,
                iou_use_l1_loss=True,
                pred_obj_scores=True,
                focal_alpha=0.25,
                focal_gamma=2.0,
                focal_alpha_obj_score=-1.0,
                focal_gamma_obj_score=0.0)

        self.post_init()

    def init_parameters(self):
        # initialize ref_encoder with weights from sam2.sam_prompt_encoder
        for p0, p1 in zip(self.ref_encoder.parameters(), self.sam2.sam_prompt_encoder.parameters()):
            p0.copy_(p1)

        # initialize msk_proj with weights from visual.merger.mlp
        for p0, p1 in zip(self.msk_proj.parameters(), self.visual.merger.mlp.parameters()):
            p0.copy_(p1)

        # reset extra parameters
        for s in ('seg_head', 'ref_proj_single', 'ref_proj_double', 'ref_proj', 'tem_proj'):
            b = getattr(self, s, None)
            if b is None:
                continue
            for n, m in b.named_modules():
                if isinstance(m, nn.Linear):
                    print(f'Reset parameters of {b.__class__.__name__} {n} ({m.__class__.__name__})')
                    xavier_init_(m, distribution='uniform')
                elif isinstance(m, nn.LayerNorm):
                    print(f'Reset parameters of {b.__class__.__name__} {n} ({m.__class__.__name__})')
                    constant_init_(m)

    def load_sam2_weights(self):
        state_dict = torch.load(self.config.sam2_checkpoint, map_location=self.sam2.device, weights_only=True)['model']
        # released SAM2 checkpoints store the memory encoder's LayerScale parameters as 'gamma';
        # rename them to 'weight' to match the fuser layers in this build before loading
        state_dict['memory_encoder.fuser.layers.0.weight'] = state_dict.pop('memory_encoder.fuser.layers.0.gamma')
        state_dict['memory_encoder.fuser.layers.1.weight'] = state_dict.pop('memory_encoder.fuser.layers.1.gamma')
        self.sam2.load_state_dict(state_dict)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                pixel_values=None,
                pixel_values_videos=None,
                image_grid_thw=None,
                video_grid_thw=None,
                rope_deltas=None,
                cache_position=None,
                second_per_grid_ts=None,
                frames=None,
                frame_size=None,
                point_coords=None,
                point_labels=None,
                point_frames=None,
                refer_mask=None,
                label_obj_to_frame_idx=None,
                label_mask=None):
        if caching := not self.training and (past_key_values is None or len(past_key_values) == 0):
            self.seg = []

        # move input_ids to the correct device (in case of auto device map)
        input_ids = input_ids.to(self.model.language_model.embed_tokens.weight.device)

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        device, dtype = inputs_embeds.device, inputs_embeds.dtype

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds)
            n_image_tokens = (input_ids == self.config.image_token_id).sum()
            n_image_features = image_embeds.shape[0]
            assert n_image_tokens == n_image_features
            mask = input_ids == self.config.image_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            image_mask = mask_expanded.to(device)
            image_embeds = image_embeds.to(device, dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds)
            n_video_tokens = (input_ids == self.config.video_token_id).sum()
            n_video_features = video_embeds.shape[0]
            assert n_video_tokens == n_video_features
            mask = input_ids == self.config.video_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            video_mask = mask_expanded.to(device)
            video_embeds = video_embeds.to(device, dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if any(k is not None for k in (point_coords, point_labels, point_frames)):
            assert all(k is not None for k in (point_coords, point_labels, point_frames))

            ref = []
            for batch_idx in range(video_grid_thw.size(0)):
                for obj_point_coords, obj_point_labels in zip(point_coords[batch_idx], point_labels[batch_idx]):
                    obj_ref, _ = self.ref_encoder((obj_point_coords, obj_point_labels), None, None, None)
                    assert obj_ref.size(1) in (2, 3), obj_ref.size()
                    if obj_ref.size(1) == 2:
                        obj_ref = self.ref_proj_single(obj_ref.flatten(1))
                    else:
                        obj_ref = self.ref_proj_double(obj_ref.flatten(1))
                    ref.append(obj_ref)
            ref = torch.cat(ref)

            tem = []
            for batch_idx in range(video_grid_thw.size(0)):
                # temporal merge size set to 2
                size = video_grid_thw[batch_idx][0].item() * 2
                for obj_point_frames in point_frames[batch_idx]:
                    obj_tem = obj_point_frames.unsqueeze(0).float()
                    obj_tem = self.tem_pe.forward_with_coords(obj_tem, size)
                    assert obj_tem.size(0) == 1, obj_tem.size()
                    tem.append(obj_tem[0])
            tem = torch.cat(tem)
            tem = tem + self.tem_emb(torch.LongTensor([0]).to(device))
            tem = self.tem_proj(tem)

            ref_emb = self.ref_proj(torch.cat((ref, tem), dim=1)).to(device, dtype)

            ref_mask = input_ids == self.config.ref_token_id
            # replace only the <ref> tokens in the instruction
            # ref_mask = ref_mask * (labels == IGNORE_INDEX) if self.training else ref_mask
            ref_mask = ref_mask.unsqueeze(-1).expand_as(inputs_embeds).to(device)
            inputs_embeds = inputs_embeds.masked_scatter(ref_mask, ref_emb)

        if refer_mask is not None:
            mem, base_idx = [], 0
            for batch_idx in range(video_grid_thw.size(0)):
                size = video_grid_thw[batch_idx].prod().item() // 4
                step = video_grid_thw[batch_idx][1] * video_grid_thw[batch_idx][2] // 4
                # emb = self.model.visual.merger.ln_q.state[base_idx:base_idx + size]
                # map grouped order back to raster scan order
                # dim = emb.size(1)
                # emb = emb.permute(1, 0).reshape(dim, -1, 2, 2).permute(0, 2, 1, 3).reshape(dim, -1).permute(1, 0)
                emb = self.model.visual.merger.mlp.state[base_idx:base_idx + size]
                batch_refer_mask = refer_mask[batch_idx]
                for obj_idx in range(batch_refer_mask.size(1)):
                    mask = batch_refer_mask[:, obj_idx].flatten()
                    assert mask.size(0) == emb.size(0) == size
                    obj_emb = []
                    for i in range(0, size, step):
                        frame_mask = mask[i:i + step]
                        if frame_mask.any():
                            obj_emb.append(emb[i:i + step][frame_mask].mean(dim=0))
                    if len(obj_emb) > 0:
                        obj_emb = torch.stack(obj_emb)
                        mem.append(obj_emb)
                base_idx = base_idx + size

            mem_mask = input_ids == self.config.mem_token_id
            if len(mem) > 0:
                mem_emb = self.msk_proj(torch.cat(mem))
                mem_mask = mem_mask.unsqueeze(-1).expand_as(inputs_embeds).to(device)
                assert mem_emb.size(0) == mem_mask.all(dim=-1).sum(), (mem_emb.size(), mem_mask.all(dim=-1).sum())
                inputs_embeds = inputs_embeds.masked_scatter(mem_mask, mem_emb)
            else:
                assert not mem_mask.any()

        # ensure gradient tracking (in case embed_tokens has been frozen)
        if self.training and not inputs_embeds.requires_grad:
            inputs_embeds.requires_grad = True

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=not self.training,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            rope_deltas=rope_deltas,
            cache_position=cache_position,
            second_per_grid_ts=second_per_grid_ts)

        if self.config.sam2_config is not None and self.config.sam2_enable_decoder and frames is not None:
            # decoder block -> -2 -> decoder block -> state -> norm -> -1
            seg_tokens_all = self.seg_head(self.model.language_model.norm.state)
            seg_tokens_all = seg_tokens_all.reshape(*seg_tokens_all.shape[:2], self.config.sam2_hidden_tokens, -1)

            if self.training and label_obj_to_frame_idx is not None and label_mask is not None:
                loss_seg_all, avg_factor = 0, 0
                shift_inputs = input_ids[..., 1:].contiguous()
                for batch_idx, (obj_to_frame_idx, mask) in enumerate(zip(label_obj_to_frame_idx, label_mask)):
                    # supervise all <seg> tokens (including those in inputs)
                    inds = torch.where(shift_inputs[batch_idx] == self.config.seg_token_id)[0]
                    assert inds.size(0) == mask.size(1)

                    if self.config.sample_objects > 0 and inds.size(0) > self.config.sample_objects:
                        sample_inds = random.sample(list(range(inds.size(0))), self.config.sample_objects)
                        obj_to_frame_idx = obj_to_frame_idx[:, sample_inds]
                        inds = inds[sample_inds]
                        mask = mask[:, sample_inds]

                    if self.config.sam2_batch_mode:
                        seg_tokens = seg_tokens_all[batch_idx][inds].repeat(mask.size(0), 1, 1)  # (t * o) * 2 * c
                        img_batch = frames[batch_idx].unsqueeze(0)  # 1 * t * c * h * w
                        masks = mask.view(1, -1, mask.size(2), mask.size(3))  # 1 * (t * o) * h * w
                    else:
                        seg_tokens = seg_tokens_all[batch_idx][inds]  # o * 2 * c
                        img_batch = frames[batch_idx].unsqueeze(1)  # t * 1 * c * h * w
                        masks = mask  # t * o * h * w

                    data = BatchedVideoDatapoint(img_batch=img_batch, obj_to_frame_idx=obj_to_frame_idx, masks=masks)
                    pred = self.sam2(data, seg_tokens)

                    loss_seg = self.loss_seg(pred, masks)
                    loss_seg = loss_seg['core_loss'] / masks.size(0)

                    loss_seg_all += loss_seg
                    avg_factor += 1

                assert avg_factor > 0
                outputs.loss = outputs.loss + loss_seg_all / avg_factor
            else:
                assert len(frames) == len(frame_size) == 1

                seg_tokens = []

                if caching:
                    # case 1: input contains <seg>
                    shift_inputs = input_ids[..., 1:].contiguous()
                    inds = torch.where(shift_inputs[0] == self.config.seg_token_id)[0].to(seg_tokens_all.device)
                    seg_tokens += [t for t in seg_tokens_all[0][inds].unsqueeze(1)]

                if outputs.logits[0, -1].argmax() == self.config.seg_token_id:
                    # case 2: output contains <seg>
                    seg_tokens.append(seg_tokens_all[0, -1, None])

                for seg_token in seg_tokens:
                    if self.config.sam2_batch_mode:
                        pred_mask = []
                        for idx in range(frames[0].size(0)):
                            state = self.sam2.init_state(frames[0][idx, None], frame_size[0])
                            self.sam2.add_new_hidden_state(state, 0, 0, seg_token)
                            pred_mask += [o[2] for o in self.sam2.propagate_in_video(state, verbose=False)]
                        pred_mask = torch.cat(pred_mask, dim=1)
                    else:
                        state = self.sam2.init_state(frames[0], frame_size[0])
                        self.sam2.add_new_hidden_state(state, 0, 0, seg_token)
                        pred_mask = torch.cat([o[2] for o in self.sam2.propagate_in_video(state, verbose=False)], dim=1)

                    assert pred_mask.size(1) == frames[0].size(0)
                    self.seg.append((pred_mask > 0).cpu())

        return outputs

    def prepare_inputs_for_generation(self,
                                      *args,
                                      cache_position=None,
                                      frames=None,
                                      frame_size=None,
                                      point_coords=None,
                                      point_labels=None,
                                      point_frames=None,
                                      refer_mask=None,
                                      **kwargs):
        model_inputs = super().prepare_inputs_for_generation(*args, cache_position=cache_position, **kwargs)
        model_inputs.update({
            'frames': frames,
            'frame_size': frame_size,
            'point_coords': point_coords if cache_position[0] == 0 else None,
            'point_labels': point_labels if cache_position[0] == 0 else None,
            'point_frames': point_frames if cache_position[0] == 0 else None,
            'refer_mask': refer_mask if cache_position[0] == 0 else None
        })
        return model_inputs


# register the patched model as a vision-to-sequence model so the Auto* classes can resolve it
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES[PixelQwen2_5_VLConfig.model_type] = 'PixelQwen2_5_VLForConditionalGeneration'
AutoConfig.register(PixelQwen2_5_VLConfig.model_type, PixelQwen2_5_VLConfig)
AutoModel.register(PixelQwen2_5_VLConfig, PixelQwen2_5_VLForConditionalGeneration)
AutoProcessor.register(PixelQwen2_5_VLConfig, PatchedQwen2_5_VLProcessor)
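
# Minimal loading sketch (illustrative only; the checkpoint path below is a placeholder, and this
# assumes a directory saved with the 'pixel_qwen2_5_vl' config). Because the config, model, and
# processor are registered above, the Auto* classes resolve them from the saved `model_type`:
#
#     config = AutoConfig.from_pretrained('path/to/pixel_qwen2_5_vl_checkpoint')
#     model = AutoModel.from_pretrained('path/to/pixel_qwen2_5_vl_checkpoint', config=config)
#     processor = AutoProcessor.from_pretrained('path/to/pixel_qwen2_5_vl_checkpoint')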