# UniPixel / app.py
# Last commit: "Fix model loading" (0ce76cd) by yeliudev
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause license.
import re
import uuid
from functools import partial
import gradio as gr
import imageio.v3 as iio
import spaces
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as T
from PIL import Image
from unipixel.constants import MEM_TOKEN, SEG_TOKEN
from unipixel.dataset.utils import process_vision_info
from unipixel.model.builder import build_model
from unipixel.utils.io import load_image, load_video
from unipixel.utils.transforms import get_sam2_transform
from unipixel.utils.visualizer import draw_mask, sample_color
# Checkpoint identifier handed to `build_model` when the module is imported.
MODEL = 'PolyU-ChenLab/UniPixel-3B'
# Title passed to `gr.Blocks` in `build_demo`.
TITLE = 'UniPixel: Unified Object Referring and Segmentation for Pixel-Level Visual Reasoning'
# Static HTML banner rendered at the top of the page via `gr.HTML`: project logo,
# subtitle, a row of badge links (paper, project page, model, dataset, license,
# GitHub stars) and a short description with a link for reporting issues.
HEADER = """
<p align="center" style="margin: 1em 0 2em;"><img width="260" src="https://raw.githubusercontent.com/PolyU-ChenLab/UniPixel/refs/heads/main/.github/logo.png"></p>
<h3 align="center">Unified Object Referring and Segmentation for Pixel-Level Visual Reasoning</h3>
<div style="display: flex; justify-content: center; gap: 5px;">
<a href="https://arxiv.org/abs/2509.18094" target="_blank"><img src="https://img.shields.io/badge/arXiv-2509.18094-red"></a>
<a href="https://polyu-chenlab.github.io/unipixel/" target="_blank"><img src="https://img.shields.io/badge/Project-Page-brightgreen"></a>
<a href="https://huggingface.co/collections/PolyU-ChenLab/unipixel-68cf7137013455e5b15962e8" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue"></a>
<a href="https://huggingface.co/datasets/PolyU-ChenLab/UniPixel-SFT-1M" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset-orange"></a>
<a href="https://github.com/PolyU-ChenLab/UniPixel/blob/main/README.md" target="_blank"><img src="https://img.shields.io/badge/License-BSD--3--Clause-purple"></a>
<a href="https://github.com/PolyU-ChenLab/UniPixel" target="_blank"><img src="https://img.shields.io/github/stars/PolyU-ChenLab/UniPixel"></a>
</div>
<p style="margin-top: 1em;">UniPixel is a unified MLLM for pixel-level vision-language understanding. It flexibly supports a variety of fine-grained tasks, including image/video segmentation, regional understanding, and a novel PixelQA task that jointly requires object-centric referring, segmentation, and question-answering in videos. Please open an <a href="https://github.com/PolyU-ChenLab/UniPixel/issues/new" target="_blank">issue</a> if you meet any problems.</p>
"""
# Front-end script passed to `gr.Blocks(js=...)` in `build_demo`. It does two things:
#   1. On windows at least 1536px wide, caps the <main> element at 1536px.
#   2. Makes pressing Enter inside each `query_N` textbox click the matching
#      `submit_N` button (the element ids are assigned in `build_demo`).
# Background for the workaround:
# https://github.com/gradio-app/gradio/pull/10552
JS = """
function init() {
if (window.innerWidth >= 1536) {
document.querySelector('main').style.maxWidth = '1536px'
}
document.getElementById('query_1').addEventListener('keydown', function f1(e) { if (e.key === 'Enter') { document.getElementById('submit_1').click() } })
document.getElementById('query_2').addEventListener('keydown', function f2(e) { if (e.key === 'Enter') { document.getElementById('submit_2').click() } })
document.getElementById('query_3').addEventListener('keydown', function f3(e) { if (e.key === 'Enter') { document.getElementById('submit_3').click() } })
document.getElementById('query_4').addEventListener('keydown', function f4(e) { if (e.key === 'Enter') { document.getElementById('submit_4').click() } })
}
"""
# Module-level state shared by every handler. The model and processor are built
# once at import time; the handlers move the model to `device` before generating.
model, processor = build_model(MODEL, attn_implementation='sdpa')
sam2_transform = get_sam2_transform(model.config.sam2_image_size)
device = torch.device('cuda')
colors = sample_color()


def _rgb_hex(red, green, blue):
    """Format three channel values as a '#rrggbb' string, truncating each to int."""
    return f'#{int(red):02x}{int(green):02x}{int(blue):02x}'


# Per-target display colors keyed by the 'Target N' labels used in the UI.
# NOTE(review): the scaling below presumes `sample_color` yields channels in [0, 1].
color_map = {f'Target {n + 1}': _rgb_hex(rgb[0], rgb[1], rgb[2]) for n, rgb in enumerate(colors * 255)}
# Lighter variant of the same palette: each channel is mapped to c * 127.5 + 127.5.
color_map_light = {
    f'Target {n + 1}': _rgb_hex(rgb[0] * 127.5 + 127.5, rgb[1] * 127.5 + 127.5, rgb[2] * 127.5 + 127.5)
    for n, rgb in enumerate(colors)
}
def enable_btns():
    """Return one interactive-button update repeated for the four action buttons."""
    unlocked = gr.Button(interactive=True)
    return (unlocked, unlocked, unlocked, unlocked)
def disable_btns():
    """Return one non-interactive-button update repeated for the four action buttons."""
    locked = gr.Button(interactive=False)
    return (locked, locked, locked, locked)
def reset_seg():
    """Reset a segmentation tab: frame count back to 16 and a disabled button update."""
    default_sample_frames = 16
    locked_download = gr.Button(interactive=False)
    return default_sample_frames, locked_download
def reset_reg():
    """Reset a regional tab: prompt frame index back to 1 and a disabled button update."""
    default_prompt_idx = 1
    locked_download = gr.Button(interactive=False)
    return default_prompt_idx, locked_download
def update_region(blob):
    """Return the part of the background image covered by the painted mask.

    Args:
        blob: Image editor value, a dict with a 'background' array and a
            'layers' list of arrays, or ``None`` when the editor is empty.
            NOTE(review): arrays are assumed to be H x W x 4 (RGBA); confirm
            against the editor's image mode.

    Returns:
        A copy of the background in which every pixel whose first-layer alpha
        is zero has been set to ``[0, 0, 0, 0]``, or ``None`` when there is no
        background or nothing has been painted.
    """
    # The editor may hand over no value at all (e.g. after it has been cleared).
    if blob is None or blob['background'] is None:
        return None

    layers = blob['layers']
    # No layer yet, or a layer without any painted pixel: nothing to highlight.
    if len(layers) == 0 or not layers[0].any():
        return None

    region = blob['background'].copy()
    # The last channel of the layer is its alpha; zero means "not painted".
    region[layers[0][:, :, -1] == 0] = [0, 0, 0, 0]
    return region
def update_video(video, prompt_idx):
    """Show the selected frame of ``video`` in the mask editor.

    Args:
        video: Path of the uploaded video, or ``None`` when nothing is loaded.
        prompt_idx: 1-based index of the frame to display.

    Returns:
        An ``gr.ImageEditor`` update: the chosen frame (editable) when a video
        is present, otherwise an empty, non-interactive editor.
    """
    if video is not None:
        # Sample the same 16 frames that `infer_reg` uses for video input.
        _, sampled = load_video(video, sample_frames=16)
        return gr.ImageEditor(value=sampled[prompt_idx - 1], interactive=True)

    return gr.ImageEditor(value=None, interactive=False)
@spaces.GPU
def infer_seg(media, query, sample_frames=16, media_type=None):
    """Segment the objects described by ``query`` in an image or a video.

    Args:
        media: Path of the uploaded image or video file.
        query: Text prompt sent to the model.
        sample_frames: Number of frames requested from ``load_video``; also
            scales the per-frame pixel budget.
        media_type: ``'image'`` or ``'video'``. Selects how the inputs are
            loaded and the kind of value returned for the result component.
            When ``None`` the file suffix decides how to load the input.

    Returns:
        A tuple ``(answer, masks, path)``:
        ``answer`` is a dict with the response text and one entity per
        segmentation token; ``masks`` is ``(media, [(mask, label), ...])`` for
        images (``None`` if nothing was segmented) or the rendered file path
        otherwise; ``path`` is the rendered PNG/GIF written under ``/tmp``.
        All three are ``None`` when an input is missing.
    """
    global model

    if not media:
        gr.Warning('Please upload an image or a video.')
        return None, None, None

    if not query:
        gr.Warning('Please provide a text prompt.')
        return None, None, None

    # Trust the explicit hint from the calling tab first; otherwise fall back to
    # a case-insensitive suffix test so that e.g. '.JPG' and '.jpeg' count too.
    is_image = media_type == 'image' or media.lower().endswith(('jpg', 'jpeg', 'png'))

    if is_image:
        frames, images = load_image(media), [media]
    else:
        frames, images = load_video(media, sample_frames=sample_frames)

    # Never let the budget multiplier truncate to zero.
    budget_scale = max(1, int(sample_frames / len(images)))

    messages = [{
        'role':
        'user',
        'content': [{
            'type': 'video',
            'video': images,
            'min_pixels': 128 * 28 * 28,
            'max_pixels': 256 * 28 * 28 * budget_scale
        }, {
            'type': 'text',
            'text': query
        }]
    }]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    images, videos, kwargs = process_vision_info(messages, return_video_kwargs=True)
    data = processor(text=[text], images=images, videos=videos, return_tensors='pt', **kwargs)

    data['frames'] = [sam2_transform(frames).to(model.sam2.dtype)]
    data['frame_size'] = [frames.shape[1:3]]

    # The model is moved to the GPU inside the handler, where `spaces.GPU` applies.
    model = model.to(device)

    output_ids = model.generate(
        **data.to(device),
        do_sample=False,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        max_new_tokens=512)

    assert data.input_ids.size(0) == output_ids.size(0) == 1
    # Keep only the newly generated tokens and drop a trailing EOS, if any.
    output_ids = output_ids[0, data.input_ids.size(1):]
    if output_ids.numel() > 0 and output_ids[-1] == processor.tokenizer.eos_token_id:
        output_ids = output_ids[:-1]

    response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
    # Remove the spaces around segmentation tokens so the highlights sit tight.
    response = response.replace(f' {SEG_TOKEN}', SEG_TOKEN).replace(f'{SEG_TOKEN} ', SEG_TOKEN)

    # One highlighted entity per segmentation token, numbered in order of appearance.
    entities = [
        dict(entity=f'Target {i + 1}', start=m.start(), end=m.end())
        for i, m in enumerate(re.finditer(re.escape(SEG_TOKEN), response))
    ]

    answer = dict(text=response, entities=entities)

    imgs = draw_mask(frames, model.seg, colors=colors)

    # Several rendered frames become an animated GIF, a single one a PNG.
    path = f"/tmp/{uuid.uuid4().hex}.{'gif' if len(imgs) > 1 else 'png'}"
    iio.imwrite(path, imgs, duration=100, loop=0)

    if media_type == 'image':
        if len(model.seg) >= 1:
            masks = media, [(m[0, 0].numpy(), f'Target {i + 1}') for i, m in enumerate(model.seg)]
        else:
            masks = None
    else:
        masks = path

    return answer, masks, path
# Tab-specific entry points: the same `infer_seg` handler with `media_type`
# pre-bound, so each tab gets the result format its output component expects.
infer_seg_image = partial(infer_seg, media_type='image')
infer_seg_video = partial(infer_seg, media_type='video')
@spaces.GPU
def infer_reg(blob, query, prompt_idx=1, video=None):
    """Answer ``query`` about a region painted on an image or on a video frame.

    Args:
        blob: Image editor value, a dict with a 'background' array and a
            'layers' list whose first entry carries the painted mask, or
            ``None`` when the editor is empty.
        query: Text prompt sent to the model after the region description.
        prompt_idx: 1-based index of the frame the mask was painted on.
        video: Path of the source video, or ``None`` when ``blob`` holds a
            standalone image.

    Returns:
        A dict with the response text and one 'region' entity for each
        reference to the region in it, or ``None`` when an input is missing.
    """
    global model

    # The editor may hand over no value at all (e.g. after it has been cleared).
    if blob is None or blob['background'] is None:
        gr.Warning('Please upload an image or a video.')
        return

    layers = blob['layers']
    if len(layers) == 0 or not layers[0].any():
        gr.Warning('Please provide a mask prompt.')
        return

    if not query:
        gr.Warning('Please provide a text prompt.')
        return

    if video is None:
        frames = torch.from_numpy(blob['background'][:, :, :3]).unsqueeze(0)
        images = [Image.fromarray(blob['background'], mode='RGBA')]
    else:
        frames, images = load_video(video, sample_frames=16)

    frame_size = frames.shape[1:3]

    # A pixel belongs to the region wherever the layer's alpha is non-zero.
    mask = torch.from_numpy(layers[0][:, :, -1]).unsqueeze(0) > 0

    # One mask slot per frame; only the prompted frame is filled in.
    refer_mask = torch.zeros(frames.size(0), 1, *frame_size)
    refer_mask[prompt_idx - 1] = mask

    # Pad to an even number of frames, then merge consecutive frame pairs with a max.
    if refer_mask.size(0) % 2 != 0:
        refer_mask = torch.cat((refer_mask, refer_mask[-1, None]))

    refer_mask = refer_mask.flatten(1)
    refer_mask = F.max_pool1d(refer_mask.transpose(-1, -2), kernel_size=2, stride=2).transpose(-1, -2)
    refer_mask = refer_mask.view(-1, 1, *frame_size)

    if video is None:
        prefix = f'Here is an image with the following highlighted regions:\n[0]: <{prompt_idx}> {MEM_TOKEN}\n'
    else:
        # NOTE(review): the range is always <idx>-<idx+1>, which reaches <17> for
        # idx 16 and may not match the pooled frame pair for even idx - confirm.
        prefix = f'Here is a video with {len(images)} frames denoted as <1> to <{len(images)}>. The highlighted regions are as follows:\n[0]: <{prompt_idx}>-<{prompt_idx + 1}> {MEM_TOKEN}\n'

    messages = [{
        'role':
        'user',
        'content': [{
            'type': 'video',
            'video': images,
            'min_pixels': 128 * 28 * 28,
            'max_pixels': 256 * 28 * 28 * int(16 / len(images))
        }, {
            'type': 'text',
            'text': prefix + query
        }]
    }]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    images, videos, kwargs = process_vision_info(messages, return_video_kwargs=True)
    data = processor(text=[text], images=images, videos=videos, return_tensors='pt', **kwargs)

    # Resize the mask to 14x the processor's grid size, then max-pool by 28 and binarize.
    refer_mask = T.resize(refer_mask, (data['video_grid_thw'][0][1] * 14, data['video_grid_thw'][0][2] * 14))
    refer_mask = F.max_pool2d(refer_mask, kernel_size=28, stride=28)
    refer_mask = refer_mask > 0

    data['frames'] = [sam2_transform(frames).to(model.sam2.dtype)]
    data['frame_size'] = [frames.shape[1:3]]
    data['refer_mask'] = [refer_mask]

    # The model is moved to the GPU inside the handler, where `spaces.GPU` applies.
    model = model.to(device)

    output_ids = model.generate(
        **data.to(device),
        do_sample=False,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        max_new_tokens=512)

    assert data.input_ids.size(0) == output_ids.size(0) == 1
    # Keep only the newly generated tokens and drop a trailing EOS, if any.
    output_ids = output_ids[0, data.input_ids.size(1):]
    if output_ids.numel() > 0 and output_ids[-1] == processor.tokenizer.eos_token_id:
        output_ids = output_ids[:-1]

    response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
    # Show references to region [0] as a tight '<REGION>' tag in the answer.
    response = response.replace(' [0]', '[0]').replace('[0] ', '[0]').replace('[0]', '<REGION>')

    entities = [
        dict(entity='region', start=m.start(), end=m.end(), color="#f85050")
        for m in re.finditer(re.escape('<REGION>'), response)
    ]

    answer = dict(text=response, entities=entities)

    return answer
def build_demo():
    """Build the four-tab Gradio interface and wire up its event handlers.

    The tabs are image segmentation, video segmentation, image regional
    understanding and video regional understanding. In every tab the submit
    button first disables the action buttons, then runs the inference handler,
    then re-enables the buttons.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title=TITLE, js=JS, theme=gr.themes.Soft()) as demo:
        gr.HTML(HEADER)

        with gr.Tab('Image Segmentation'):
            # Created with render=False so they can be placed later in the layout
            # while still being referenced by the components defined before them.
            download_btn_1 = gr.DownloadButton(label='📦 Download', interactive=False, render=False)
            msk_1 = gr.AnnotatedImage(label='Segmentation Results', color_map=color_map, render=False)
            ans_1 = gr.HighlightedText(
                label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)

            with gr.Row():
                with gr.Column():
                    media_1 = gr.Image(type='filepath')

                    # Hidden: only present so both segmentation tabs share one handler signature.
                    sample_frames_1 = gr.Slider(1, 32, value=16, step=1, visible=False)

                    query_1 = gr.Textbox(label='Text Prompt', placeholder='Please segment the...', elem_id='query_1')

                    with gr.Row():
                        random_btn_1 = gr.Button(value='🔮 Random', visible=False)

                        reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1], value='🗑️ Reset')
                        reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])

                        download_btn_1.render()

                        submit_btn_1 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_1')

                with gr.Column():
                    msk_1.render()
                    ans_1.render()

            ctx_1 = submit_btn_1.click(disable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
            ctx_1 = ctx_1.then(infer_seg_image, [media_1, query_1, sample_frames_1], [ans_1, msk_1, download_btn_1])
            ctx_1.then(enable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])

        with gr.Tab('Video Segmentation'):
            download_btn_2 = gr.DownloadButton(label='📦 Download', interactive=False, render=False)
            msk_2 = gr.Image(label='Segmentation Results', render=False)
            ans_2 = gr.HighlightedText(
                label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)

            with gr.Row():
                with gr.Column():
                    media_2 = gr.Video()

                    with gr.Accordion(label='Hyperparameters', open=False):
                        sample_frames_2 = gr.Slider(
                            1,
                            32,
                            value=16,
                            step=1,
                            interactive=True,
                            label='Sample Frames',
                            info='The number of frames to sample from a video (Default: 16)')

                    query_2 = gr.Textbox(label='Text Prompt', placeholder='Please segment the...', elem_id='query_2')

                    with gr.Row():
                        random_btn_2 = gr.Button(value='🔮 Random', visible=False)

                        reset_btn_2 = gr.ClearButton([media_2, query_2, msk_2, ans_2], value='🗑️ Reset')
                        reset_btn_2.click(reset_seg, None, [sample_frames_2, download_btn_2])

                        download_btn_2.render()

                        submit_btn_2 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_2')

                with gr.Column():
                    msk_2.render()
                    ans_2.render()

            ctx_2 = submit_btn_2.click(disable_btns, None, [random_btn_2, reset_btn_2, download_btn_2, submit_btn_2])
            ctx_2 = ctx_2.then(infer_seg_video, [media_2, query_2, sample_frames_2], [ans_2, msk_2, download_btn_2])
            ctx_2.then(enable_btns, None, [random_btn_2, reset_btn_2, download_btn_2, submit_btn_2])

        with gr.Tab('Image Regional Understanding'):
            # Invisible placeholder so the shared enable/disable helpers always get four outputs.
            download_btn_3 = gr.DownloadButton(visible=False)
            msk_3 = gr.Image(label='Highlighted Region', render=False)
            ans_3 = gr.HighlightedText(label='Model Response', show_inline_category=False, render=False)

            with gr.Row():
                with gr.Column():
                    media_3 = gr.ImageEditor(
                        label='Image & Mask Prompt',
                        brush=gr.Brush(colors=['#ff000080'], color_mode='fixed'),
                        transforms=None,
                        layers=False)
                    # Preview the painted region next to the editor on every edit.
                    media_3.change(update_region, media_3, msk_3)

                    # Hidden: a single image always uses frame index 1.
                    prompt_frame_index_3 = gr.Slider(1, 16, value=1, step=1, visible=False)

                    query_3 = gr.Textbox(
                        label='Text Prompt', placeholder='Please describe the highlighted region...', elem_id='query_3')

                    with gr.Row():
                        random_btn_3 = gr.Button(value='🔮 Random', visible=False)

                        reset_btn_3 = gr.ClearButton([media_3, query_3, msk_3, ans_3], value='🗑️ Reset')
                        reset_btn_3.click(reset_reg, None, [prompt_frame_index_3, download_btn_3])

                        submit_btn_3 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_3')

                with gr.Column():
                    msk_3.render()
                    ans_3.render()

            ctx_3 = submit_btn_3.click(disable_btns, None, [random_btn_3, reset_btn_3, download_btn_3, submit_btn_3])
            ctx_3 = ctx_3.then(infer_reg, [media_3, query_3, prompt_frame_index_3], ans_3)
            ctx_3.then(enable_btns, None, [random_btn_3, reset_btn_3, download_btn_3, submit_btn_3])

        with gr.Tab('Video Regional Understanding'):
            # Invisible placeholder so the shared enable/disable helpers always get four outputs.
            download_btn_4 = gr.DownloadButton(visible=False)
            prompt_frame_index_4 = gr.Slider(
                1,
                16,
                value=1,
                step=1,
                interactive=True,
                label='Prompt Frame Index',
                info='The index of the frame to apply mask prompts (Default: 1)',
                render=False)
            msk_4 = gr.ImageEditor(
                label='Mask Prompt',
                brush=gr.Brush(colors=['#ff000080'], color_mode='fixed'),
                transforms=None,
                layers=False,
                interactive=False,
                render=False)
            ans_4 = gr.HighlightedText(label='Model Response', show_inline_category=False, render=False)

            with gr.Row():
                with gr.Column():
                    media_4 = gr.Video()
                    # Load the selected frame into the mask editor when the video changes...
                    media_4.change(update_video, [media_4, prompt_frame_index_4], msk_4)

                    with gr.Accordion(label='Hyperparameters', open=False):
                        prompt_frame_index_4.render()
                        # ...and again whenever a different frame index is chosen.
                        prompt_frame_index_4.change(update_video, [media_4, prompt_frame_index_4], msk_4)

                    query_4 = gr.Textbox(
                        label='Text Prompt', placeholder='Please describe the highlighted region...', elem_id='query_4')

                    with gr.Row():
                        random_btn_4 = gr.Button(value='🔮 Random', visible=False)

                        reset_btn_4 = gr.ClearButton([media_4, query_4, msk_4, ans_4], value='🗑️ Reset')
                        reset_btn_4.click(reset_reg, None, [prompt_frame_index_4, download_btn_4])

                        submit_btn_4 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_4')

                with gr.Column():
                    msk_4.render()
                    ans_4.render()

            ctx_4 = submit_btn_4.click(disable_btns, None, [random_btn_4, reset_btn_4, download_btn_4, submit_btn_4])
            ctx_4 = ctx_4.then(infer_reg, [msk_4, query_4, prompt_frame_index_4, media_4], ans_4)
            ctx_4.then(enable_btns, None, [random_btn_4, reset_btn_4, download_btn_4, submit_btn_4])

    return demo
# Script entry point: build the interface, enable request queuing and start
# the server listening on all network interfaces.
if __name__ == '__main__':
    demo = build_demo()
    demo.queue()
    demo.launch(server_name='0.0.0.0')