# Adopted from: https://github.com/DAMO-NLP-SG/VideoLLaMA3.
# Adopted from: https://github.com/haotian-liu/LLaVA.
# Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoImageProcessor,
                          Qwen2Config, Qwen2ForCausalLM, Qwen2Model)
from transformers.generation.utils import GenerateOutput
# from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass
from transformers.utils import ModelOutput

from .loss import cross_entropy_loss, CrossEntropyLoss, DiceLoss
from .processor import Videollama3BaseProcessor
from .rynnec_arch import RynnecMetaForCausalLM, RynnecMetaModel
from .videollama3_encoder import Videollama3ImageProcessor
from rynnec.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN
from .sam2_train import SAM2TrainRunner
from .sam2 import SAM2
from .utils import genetate_video_pred_embeddings, process_video_gt_masks

CHAT_TEMPLATE = """
{%- set identifier = 'im' %}
{% for message in messages %}
    {% if message['role'] == 'stream' %}
        {% set identifier = 'stream' %}
    {% else %}
        {% set identifier = 'im' %}
    {% endif %}
    {% if message['role'] is not none %}
        {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}
    {% endif %}
    {% if message['content'] is string %}
        {{- message['content'] + '<|' + identifier + '_end|>\n' -}}
    {% else %}
        {% for content in message['content'] %}
            {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
                {% if 'time' in content %}
                    {{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}
                {% endif %}
                {{- image_token + '\n' -}}
            {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}
                {% for i in range(content['num_frames']) %}
                    {% if 'timestamps' in content %}
                        {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}
                    {% endif %}
                    {% if i < content['num_frames'] - 1 %}
                        {{- image_token + ',' -}}
                    {% else %}
                        {{- image_token + '\n' -}}
                    {% endif %}
                {% endfor %}
            {% elif content['type'] == 'text' or 'text' in content %}
                {{- content['text'] -}}
            {% endif %}
        {% endfor %}
        {% if message['role'] is not none %}
            {{- '<|' + identifier + '_end|>\n' -}}
        {% endif %}
    {% endif %}
{% endfor %}
{% if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' -}}
    {% if add_think_prompt %}
        {{- '<think>\n' -}}
    {% endif %}
{% endif %}
"""
# NOTE: `ModelOutput` subclasses must be dataclasses; the extra fields below
# expose the individual loss terms alongside the standard LM outputs.
@dataclass
class CausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
    ce_loss: Optional[torch.FloatTensor] = None
    mask_bce_loss: Optional[torch.FloatTensor] = None
    mask_dice_loss: Optional[torch.FloatTensor] = None
    mask_loss: Optional[torch.FloatTensor] = None

class Videollama3Qwen2Processor(Videollama3BaseProcessor):

    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
    chat_template = CHAT_TEMPLATE

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        image_merge_size: int = 1,
        video_merge_size: int = 2,
        fps=1,
        max_frames=180,
        **kwargs
    ):
        super().__init__(image_processor, tokenizer, chat_template, **kwargs)
        self.generation_prompt = self._infer_generation_prompt()
        self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
        self.generation_prompt_length = len(self.generation_prompt_ids[0])

    def _infer_generation_prompt(self):
        pseudo_message = [{"role": "user", "content": ""}]
        instruction = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
        conversation = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
        return instruction.replace(conversation, "")
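
    # With the template above, `_infer_generation_prompt()` should resolve to
    # "<|im_start|>assistant\n" (the suffix emitted only when
    # add_generation_prompt=True). `generation_prompt_length` is the number of
    # tokens in that prefix, which `_process_text_with_label` skips when
    # un-masking assistant targets.
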
    def _process_text_with_label(
        self,
        text: List[Dict],
        grid_sizes: torch.Tensor = None,
        **kwargs,
    ):
        assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
        assert isinstance(text[0], dict), "When return_labels=True, text must be a list of messages."

        input_ids_list = []
        targets_list = []
        image_idx = 0

        for message_idx, message in enumerate(text):
            # 1. set chat template and append image tokens
            prompt = self.apply_chat_template([message], tokenize=False, add_generation_prompt=False)
            prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
            prompt = []
            for chunk_idx in range(len(prompt_chunks) - 1):
                prompt.append(prompt_chunks[chunk_idx])
                thw = grid_sizes[image_idx]
                prompt.append(DEFAULT_IMAGE_TOKEN * thw.prod().long())
                image_idx += 1
            prompt.append(prompt_chunks[-1])
            prompt = "".join(prompt)

            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")[0]
            input_ids_list.append(input_ids)

            targets = torch.full_like(input_ids, IGNORE_INDEX)
            if message["role"] == "assistant" or message["role"] is None:
                targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()

            # NOTE: mask out image tokens
            vision_mask = input_ids == self.image_token_id
            targets[vision_mask] = IGNORE_INDEX
            vision_indices = torch.nonzero(vision_mask, as_tuple=True)[0]
            targets[vision_indices + 1] = IGNORE_INDEX

            # NOTE: mask out <think> or <think>\n
            think_mask = targets == self.think_start_token_id
            targets[think_mask] = IGNORE_INDEX
            think_indices = torch.nonzero(think_mask, as_tuple=True)[0]
            newline_mask = torch.zeros_like(think_mask)
            newline_mask[think_indices + 1] = targets[think_indices + 1] == self.newline_token_id
            targets[newline_mask] = IGNORE_INDEX

            targets_list.append(targets)

        assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."

        text_inputs = {
            "input_ids": torch.cat(input_ids_list),
            "labels": torch.cat(targets_list),
        }

        return text_inputs
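
# Usage sketch (illustrative, not part of the original module). The message
# dicts and `grid_sizes` below are hypothetical placeholders.
#
#   processor = Videollama3Qwen2Processor(image_processor=..., tokenizer=...)
#   messages = [
#       {"role": "user", "content": [{"type": "image"},
#                                    {"type": "text", "text": "What is on the table?"}]},
#       {"role": "assistant", "content": "A red mug."},
#   ]
#   # grid_sizes holds one (T, H, W) patch-grid size per image token emitted by
#   # the chat template, so each DEFAULT_IMAGE_TOKEN expands to T*H*W placeholders.
#   inputs = processor._process_text_with_label(messages, grid_sizes=grid_sizes)
#   # inputs["labels"] keeps assistant tokens and sets everything else
#   # (prompt, image placeholders, <think> markers) to IGNORE_INDEX.
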
class RynnecQwen2Config(Qwen2Config):
    model_type = "rynnec_qwen2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_type = "rynnec_qwen2"

class RynnecQwen2Model(RynnecMetaModel, Qwen2Model):
    config_class = RynnecQwen2Config

    def __init__(self, config: RynnecQwen2Config):
        super(RynnecQwen2Model, self).__init__(config)

        if hasattr(config, "mm_mask_decoder"):  # inference
            self.build_mask_decoder(config)
        else:  # training
            if not hasattr(config, "out_dim"):
                config.out_dim = 256

    def build_mask_decoder(self, config):
        # Projection layer for lisa
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True
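
# Shape sketch for the projection built above (sizes are illustrative; 256 is
# the default `out_dim` and is assumed to match the width of the SAM2 prompt
# embeddings consumed by the grounding encoder):
#
#   hidden_states           : (batch, seq_len, hidden_size)
#   text_hidden_fcs[0](...) : (batch, seq_len, out_dim)    # e.g. (1, 512, 256)
#
# The projected embeddings at segmentation-token positions
# (config.seg_token_index) are later passed to SAM2 as language prompts.
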
class RynnecQwen2ForCausalLM(Qwen2ForCausalLM, RynnecMetaForCausalLM):
    config_class = RynnecQwen2Config

    def __init__(self, config, **kwargs):
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = RynnecQwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        if hasattr(config, "training") and config.training is True:
            self.grounding_encoder = SAM2TrainRunner(ckpt_path=config.mask_decoder_model)
            config.mm_mask_decoder = True
        else:
            self.grounding_encoder = SAM2(ckpt_path=config.mask_decoder_model)

        self.loss_mask = CrossEntropyLoss(
            use_sigmoid=True,
            reduction='mean',
            loss_weight=2.0
        )
        self.loss_dice = DiceLoss(
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=0.5
        )

    def load_sam2_weights(self, model_path):
        sam2_model = torch.load(model_path, map_location='cpu')['model']
        prefix = "sam2_model."
        new_state_dict = {}
        for param_name in sam2_model.keys():
            new_param_name = prefix + param_name
            new_state_dict[new_param_name] = sam2_model[param_name]
        self.grounding_encoder.load_state_dict(new_state_dict, strict=False)

    def get_model(self):
        return self.model
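
    # Example (illustrative, not part of the original module): initializing the
    # SAM2 branch from a separately downloaded checkpoint. The path below is a
    # placeholder; the checkpoint is expected to store its parameters under the
    # "model" key, as assumed by `load_sam2_weights` above.
    #
    #   model = RynnecQwen2ForCausalLM.from_pretrained(llm_path, config=config)
    #   model.load_sam2_weights("./checkpoints/sam2_checkpoint.pt")
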
    # NOTE: arguments are copied from transformers==4.46.3
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        # multimodal inputs
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        masks: Optional[List[torch.LongTensor]] = None,
        mask_ids=None,
        sam_images=None,
        sam_size=None,
        image2maskids=None,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        torch.cuda.empty_cache()

        if inputs_embeds is None:
            input_ids_raw = input_ids.clone()
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=labels,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
                masks=masks,
                mask_ids=mask_ids,
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        loss, logits = None, None
        _valid = True
        seg_valid = True
        if labels is not None:  # training
            ce_loss = cross_entropy_loss(
                hidden_states=hidden_states,
                lm_head=self.lm_head,
                position_ids=position_ids,
                labels=labels,
                reduction_scope=self.config.loss_reduction_scope,
                **loss_kwargs,
            )

            if self.config.has_mask:
                hidden_states_sam = []
                hidden_states_sam.append(self.model.text_hidden_fcs[0](hidden_states))
                hidden_states_sam = torch.stack(hidden_states_sam, dim=-1).sum(dim=-1)

                bs = input_ids_raw.shape[0]
                gt_masks_list = []
                pred_masks_list = []

                mask_bce_loss = 0
                mask_dice_loss = 0
                num_masks = 0
                for i in range(bs):
                    pred_masks = []
                    pred_embeddings = []

                    input_id = input_ids_raw[i]
                    seg_token_mask = input_id[1:] == self.config.seg_token_index
                    seg_token_mask = torch.cat(
                        [
                            seg_token_mask,
                            torch.zeros((1)).bool().cuda(),
                        ],
                        dim=0,
                    )
                    pred_embedding = hidden_states_sam[i][seg_token_mask]
                    if len(pred_embedding) > 0:
                        pred_embeddings.append(pred_embedding)
                    else:
                        pred_embeddings.append(hidden_states_sam[i, :1])

                    gt_masks_video = []  # FIXME: Only support one segmentation now
                    gt_mask = masks[i]
                    mask_valid = True
                    if len(image2maskids[i]) == 0:
                        sam_images[i] = sam_images[i][:1]
                        gt_masks_video.append(torch.zeros((len(sam_images[i]), 224, 224)).to(sam_images[0].device))
                        mask_valid = False
                    else:
                        for mids in image2maskids[i]:
                            for mid in mids:
                                if mid is None:
                                    gt_masks_video.append(torch.zeros((224, 224)).unsqueeze(0).to(gt_mask[0].device))
                                else:
                                    gt_masks_video.append(gt_mask[mid].unsqueeze(0))

                    frames_per_batch = [len(sam_images[i])]

                    try:
                        pred_embeddings_list_video = genetate_video_pred_embeddings(pred_embeddings, frames_per_batch)
                        # pred_embeddings_list_video, gt_masks_video = check_obj_number(pred_embeddings_list_video, gt_masks_video)

                        g_pixel_values = sam_images[i]
                        num_objs = len(pred_embeddings_list_video[0])
                        # with torch.no_grad():
                        sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values, expand_size=num_objs)
                        language_embeddings = torch.cat(pred_embeddings_list_video, dim=0)[:, None]  # .contiguous()

                        num_frames = len(pred_embeddings_list_video)
                        gt_masks_video = process_video_gt_masks(gt_masks_video, num_frames, num_objs)
                        pred_masks = self.grounding_encoder.inject_language_embd(sam_states, language_embeddings, nf_nobj=(num_frames, num_objs))

                        gt_masks = [F.interpolate(gt_mask.unsqueeze(0), size=pred_masks[0].shape[-2:], mode='nearest').squeeze(0) for gt_mask in gt_masks_video]
                        gt_masks = torch.cat(gt_masks, dim=0)
                        pred_masks = pred_masks.flatten(0, 1)

                        if not mask_valid:
                            pred_masks = pred_masks * 0.0

                        if len(pred_masks) != len(gt_masks):
                            # drop this data
                            print(f"Pred mask shape {pred_masks.shape} is not equal to gt_mask shape {gt_masks.shape} !!!")
                            min_num = min(len(pred_masks), len(gt_masks))
                            pred_masks = pred_masks[:min_num]
                            gt_masks = gt_masks[:min_num]
                            seg_valid = False

                        if not seg_valid or not mask_valid:
                            _scale = 0.0
                        else:
                            _scale = 1.0

                        mask_bce_loss_ = self.loss_mask(pred_masks, gt_masks) * len(pred_masks) * _scale
                        mask_dice_loss_ = self.loss_dice(pred_masks, gt_masks) * len(gt_masks) * _scale
                        mask_bce_loss += mask_bce_loss_
                        mask_dice_loss += mask_dice_loss_
                        num_masks += len(pred_masks)
                    except Exception as exp:
                        print(exp)
                        _valid = False

                if num_masks > 0:
                    mask_bce_loss = mask_bce_loss / num_masks
                    mask_dice_loss = mask_dice_loss / num_masks

                mask_bce_loss = self.config.bce_loss_weight * mask_bce_loss
                mask_dice_loss = self.config.dice_loss_weight * mask_dice_loss

                if not _valid:
                    mask_bce_loss = mask_bce_loss * 0.0
                    mask_dice_loss = mask_dice_loss * 0.0

                mask_loss = mask_bce_loss + mask_dice_loss
                loss = mask_loss + ce_loss
            else:
                loss = ce_loss
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        if loss is not None:
            if self.config.has_mask:
                return CausalLMOutputWithPast(
                    loss=loss,
                    ce_loss=ce_loss.detach(),
                    mask_bce_loss=mask_bce_loss.detach(),
                    mask_dice_loss=mask_dice_loss.detach(),
                    mask_loss=mask_loss.detach(),
                    logits=logits,
                    past_key_values=outputs.past_key_values,
                    hidden_states=outputs.hidden_states,
                    attentions=outputs.attentions,
                )
            else:
                return CausalLMOutputWithPast(
                    loss=loss,
                    logits=logits,
                    past_key_values=outputs.past_key_values,
                    hidden_states=outputs.hidden_states,
                    attentions=outputs.attentions,
                )
        else:  # inference
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
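
    # Loss composition in the training branch above (a restatement of the code,
    # with weights taken from the config):
    #
    #   mask_bce_loss  = bce_loss_weight  * sum_i(BCE_i  * n_i) / num_masks
    #   mask_dice_loss = dice_loss_weight * sum_i(Dice_i * n_i) / num_masks
    #   loss           = ce_loss + mask_bce_loss + mask_dice_loss
    #
    # Both mask terms are zeroed when no valid segmentation targets were found
    # or when the SAM2 branch failed for the batch.
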
    def inference(
        self,
        # multimodal inputs
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        masks: Optional[List[torch.LongTensor]] = None,
        mask_ids=None,
        sam_images=None,
        sam_size=None,
        image2maskids=None,
        seg_start_idx=0,
        **kwargs,
    ):
        outputs = self.generate(
            pixel_values=pixel_values,
            grid_sizes=grid_sizes,
            merge_sizes=merge_sizes,
            modals=modals,
            masks=masks,
            mask_ids=mask_ids,
            output_hidden_states=True,
            return_dict_in_generate=True,
            **kwargs
        )
        input_ids = kwargs.pop('input_ids')

        last_hidden_state = []
        for hs in outputs.hidden_states:  # round
            last_hidden_state.append(hs[-1])
        last_hidden_state = torch.cat(last_hidden_state, dim=1)

        output_ids = outputs.sequences
        concat_ids = torch.cat((input_ids, output_ids), dim=1)
        seg_token_mask = concat_ids[:, 1:] == self.config.seg_token_index

        last_hidden_state_sam = self.model.text_hidden_fcs[0](last_hidden_state)
        pred_embeddings = last_hidden_state_sam[seg_token_mask]

        seg_token_counts = seg_token_mask.int().sum()
        pred_masks = []  # stays empty if the model produced no segmentation token
        if seg_token_counts > 0:
            g_pixel_values = torch.cat(sam_images, dim=0).contiguous()
            num_objs = 1  # FIXME: Only support one segmentation now
            if seg_start_idx > 0:
                # before start idx
                g_pixel_values_beg = g_pixel_values[:seg_start_idx + 1].flip(0)
                num_frames = len(g_pixel_values_beg)
                sam_states_beg = self.grounding_encoder.get_sam2_embeddings(g_pixel_values_beg)
                pred_masks_beg = self.grounding_encoder.language_embd_inference(sam_states_beg, [pred_embeddings] * num_frames)
            else:
                pred_masks_beg = torch.zeros((1, 1, 1024, 1024)).to(pixel_values.device)
            if seg_start_idx <= len(g_pixel_values) - 1:
                g_pixel_values_end = g_pixel_values[seg_start_idx:]
                num_frames = len(g_pixel_values_end)
                sam_states_end = self.grounding_encoder.get_sam2_embeddings(g_pixel_values_end)
                pred_masks_end = self.grounding_encoder.language_embd_inference(sam_states_end, [pred_embeddings] * num_frames)
            else:
                pred_masks_end = torch.zeros((0, 1, 1024, 1024)).to(pixel_values.device)
            pred_masks = torch.cat([pred_masks_beg[1:].flip(0), pred_masks_end], dim=0)

        return output_ids, pred_masks
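
    # Inference sketch (illustrative): the keyword arguments are assumed to come
    # from this repo's processor/data pipeline, and the variable names are
    # placeholders.
    #
    #   output_ids, pred_masks = model.inference(
    #       input_ids=inputs["input_ids"],
    #       attention_mask=inputs["attention_mask"],
    #       pixel_values=inputs["pixel_values"],
    #       grid_sizes=inputs["grid_sizes"],
    #       merge_sizes=inputs["merge_sizes"],
    #       modals=inputs["modals"],
    #       sam_images=sam_images,    # per-frame images for the SAM2 branch
    #       seg_start_idx=0,          # frame from which masks are propagated
    #       max_new_tokens=512,
    #   )
    #   text = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
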
    def generate(
        self,
        # multimodal inputs
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        masks: Optional[List[torch.LongTensor]] = None,
        mask_ids=None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        input_ids = kwargs.pop("input_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        position_ids = kwargs.pop("position_ids", None)
        past_key_values = kwargs.pop("past_key_values", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if pixel_values is not None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=None,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
                masks=masks,
                mask_ids=mask_ids,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs

AutoConfig.register("rynnec_qwen2", RynnecQwen2Config)
AutoModelForCausalLM.register(RynnecQwen2Config, RynnecQwen2ForCausalLM)
AutoProcessor.register(RynnecQwen2Config, Videollama3Qwen2Processor)
AutoImageProcessor.register(RynnecQwen2Config, Videollama3ImageProcessor)
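
# Loading sketch (illustrative): once this package is importable (or shipped as
# trust-remote-code alongside a checkpoint), the registrations above let the
# Auto classes resolve the "rynnec_qwen2" model type. The repo id below is a
# placeholder.
#
#   model = AutoModelForCausalLM.from_pretrained("<rynnec-checkpoint>", trust_remote_code=True)
#   processor = AutoProcessor.from_pretrained("<rynnec-checkpoint>", trust_remote_code=True)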