import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import numpy as np
import os
import tempfile
import gradio as gr
import cv2

try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval()  # device_map="auto" already places the weights; chaining .cuda() on top can conflict with offloaded layers
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

def visualize(pred_mask, image_path, work_dir):
    """Overlay a binary segmentation mask on the input image and save the result."""
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    # draw_binary_masks expects a boolean mask, so cast defensively
    visualizer.draw_binary_masks(pred_mask.astype(bool), colors='g', alphas=0.4)
    visual_result = visualizer.get_image()
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path

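# If mmengine is unavailable (Visualizer is None above), a plain-OpenCV
# overlay can stand in. A minimal sketch: the helper name and blend weights
# are ours, chosen to mirror draw_binary_masks(colors='g', alphas=0.4) above.
def visualize_cv2(pred_mask, image_path, work_dir, alpha=0.4):
    img = cv2.imread(image_path)
    mask = pred_mask.astype(bool)   # boolean mask of shape (H, W)
    overlay = img.copy()
    overlay[mask] = (0, 255, 0)     # paint masked pixels green (BGR)
    blended = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, blended)
    return output_path
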
def image_vision(image_input_path, prompt):
    """Run Sa2VA on a single image; return the text answer and predicted masks."""
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_input_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    answer = return_dict["prediction"]  # the text-format answer
    seg_image = return_dict["prediction_masks"]  # predicted segmentation masks
    return answer, seg_image

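# Standalone usage, independent of the UI (hypothetical file name and prompt):
#   answer, masks = image_vision("demo.jpg", "Please segment the dog.")
# `answer` is the text response; `masks` holds the predicted binary masks
# when the answer contains a [SEG] token (see main_infer below).
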
def main_infer(image_input_path, prompt):
    answer, seg_image = image_vision(image_input_path, prompt)
    if '[SEG]' in answer and Visualizer is not None:
        pred_mask = seg_image[0]  # mask for the first [SEG] token
        temp_dir = tempfile.mkdtemp()  # mkdtemp already creates the directory
        seg_result = visualize(pred_mask, image_input_path, temp_dir)
        return answer, seg_result
    # No [SEG] token (or no visualizer): return the text answer alone
    return answer, None

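# Note: this Space runs on ZeroGPU. There, GPU time is usually requested
# per call via the `spaces` package. A sketch of the standard setup, not
# present in the original code:
#
#   import spaces
#
#   @spaces.GPU
#   def main_infer(image_input_path, prompt):
#       ...
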
# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")
                output_image = gr.Image(label="Segmentation", type="numpy")
        submit_btn.click(
            fn=main_infer,
            inputs=[image_input, instruction],
            outputs=[output_res, output_image],
        )

demo.queue().launch(show_api=False, show_error=True)