import os
import copy
import math
import warnings
import shutil
from functools import partial

import torch
import numpy as np

from .model import load_pretrained_model
from .mm_utils import load_images, process_images, load_video, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, DirectResize, sam_preprocess_batch
from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN
from .model.rynnec_qwen2 import Videollama3Qwen2Processor


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def model_init(model_path=None, min_visual_tokens=None, max_visual_tokens=None, **kwargs):
    model_path = "Alibaba-DAMO-Academy/RynnEC-2B" if model_path is None else model_path
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)

    if max_visual_tokens is not None:
        image_processor.max_tokens = max_visual_tokens
    if min_visual_tokens is not None:
        image_processor.min_tokens = min_visual_tokens

    if tokenizer.pad_token is None and tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token

    processor = Videollama3Qwen2Processor(image_processor, tokenizer)

    return model, processor
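

# Hedged usage sketch (illustration only, not part of the original module): loading the
# default RynnEC-2B checkpoint and moving it to the GPU. Extra kwargs are forwarded to
# load_pretrained_model by model_init; which kwargs that function accepts is an
# assumption not shown here.
def _example_model_init():
    disable_torch_init()
    model, processor = model_init("Alibaba-DAMO-Academy/RynnEC-2B")
    model = model.cuda().eval()
    return model, processor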


def mm_infer(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', **kwargs):
    mask_ids = kwargs.pop('mask_ids', None)
    masks = kwargs.pop('masks', None)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
    else:
        raise ValueError(f"Unsupported modal: {modal}")

    # 1. text preprocess (tag process & generate prompt).
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Modal tag not found in the conversation; adding it automatically at the beginning.")
        messages[0]["content"] = modal_token + messages[0]["content"]

    # Interleave text chunks with image/video placeholders in the order the modal tokens appear.
    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            "role": "user",
            "content": []
        })

        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages

    system_message = []

    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # TODO: attention mask?
    messages = system_message + messages
    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )

    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    # images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
    # grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]

    # 3. generate response according to visual signals and prompts.
    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
    if "pixel_values" in data_dict:
        data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
        data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
            pixel_values=data_dict["pixel_values"],
            grid_sizes=data_dict["grid_sizes"],
            merge_sizes=data_dict["merge_sizes"],
            modals=data_dict["modals"],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            masks=[masks],
            mask_ids=mask_ids,
        )

    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return outputs
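

# Hedged usage sketch (illustration only, not part of the original module): single-image
# question answering with mm_infer. The PIL-based loading, the demo file path, and the
# `processor.tokenizer` attribute of Videollama3Qwen2Processor are assumptions.
def _example_mm_infer_image(model, processor, image_path="demo.jpg"):
    from PIL import Image

    image = Image.open(image_path).convert("RGB")  # the processor is assumed to accept PIL images
    question = DEFAULT_IMAGE_TOKEN + "\nDescribe the highlighted object."
    return mm_infer(
        [image],                 # a list of images for modal='image'
        processor,
        question,
        model,
        processor.tokenizer,     # assumed attribute exposing the underlying tokenizer
        modal='image',
        do_sample=False,
    )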


def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, **kwargs):
    image2maskids = kwargs.get('image2maskids', [])

    # Prepare a resize transform for the SAM branch (img_size = 1024).
    img_size = 1024
    sam_transform = DirectResize(img_size)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
    else:
        raise ValueError(f"Unsupported modal: {modal}")

    sam_images = []
    sam_size = None
    if len(images) > 0:
        for image in images:
            sam_image = sam_transform.apply_image(np.array(image))
            sam_images.append(sam_image)
            if sam_size is None:
                sam_size = sam_image.shape[:2]
        sam_images = np.array(sam_images)
        sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
        sam_images = sam_preprocess_batch(sam_images)

    # 1. text preprocess (tag process & generate prompt).
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Modal tag not found in the conversation; adding it automatically at the beginning.")
        messages[0]["content"] = modal_token + messages[0]["content"]

    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            "role": "user",
            "content": []
        })

        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages

    system_message = []

    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # TODO: attention mask?
    messages = system_message + messages
    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )

    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
    if "pixel_values" in data_dict:
        data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
        data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)

    with torch.inference_mode():
        output_ids, pred_masks = model.inference(
            input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
            pixel_values=data_dict["pixel_values"],
            grid_sizes=data_dict["grid_sizes"],
            merge_sizes=data_dict["merge_sizes"],
            modals=data_dict["modals"],
            sam_images=[sam_images],
            sam_size=[sam_size],
            image2maskids=[image2maskids],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            seg_start_idx=seg_start_idx,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Threshold the predicted mask logits into boolean masks.
    pred_masks_sigmoid = pred_masks.sigmoid() > 0.5

    return outputs, pred_masks_sigmoid
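

# Hedged usage sketch (illustration only, not part of the original module): referring
# segmentation over a handful of video frames. The frame format (PIL images), the
# timestamp list, the prompt wording, and the `processor.tokenizer` attribute are
# assumptions; `image2maskids` is left empty, matching the default above.
def _example_mm_infer_segmentation(model, processor, frames, timestamps):
    instruct = DEFAULT_VIDEO_TOKEN + "\nPlease segment the object I am referring to: the red cup."
    text, masks = mm_infer_segmentation(
        (frames, timestamps),    # video input is a (frames, timestamps) pair
        processor,
        instruct,
        model,
        processor.tokenizer,     # assumed attribute exposing the underlying tokenizer
        modal='video',
        image2maskids=[],
        do_sample=False,
    )
    return text, masks           # masks are boolean tensors thresholded at sigmoid > 0.5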