import os
import uuid
import time
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
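# Assumed environment (not pinned by this file): a Hugging Face ZeroGPU Space
# with gradio, spaces, torch, transformers, opencv-python, pillow, and numpy
# installed via requirements.txt.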
# Constants
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "1024"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the multimodal processor and model (Cosmos-Reason1, a Qwen2.5-VL-based model)
MODEL_ID = "nvidia/Cosmos-Reason1-7B"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device).eval()
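# Note: float16 halves memory versus float32, so the 7B checkpoint comfortably
# fits a single ZeroGPU slice; bfloat16 is a plausible alternative on Ampere+
# GPUs. This is a deployment choice, not a requirement of the checkpoint.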
def downsample_video(video_path: str, num_frames: int = 10):
    """Uniformly sample `num_frames` frames, returning (PIL image, timestamp) pairs."""
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against decoders that report 0 FPS
    idxs = np.linspace(0, total - 1, num_frames, dtype=int)
    frames = []
    for i in idxs:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, img = vidcap.read()
        if ok:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR; PIL expects RGB
            pil = Image.fromarray(rgb)
            timestamp = round(i / fps, 2)
            frames.append((pil, timestamp))
    vidcap.release()
    return frames
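# Usage sketch (the path is hypothetical): yields (PIL.Image, seconds) pairs
# ready to interleave with text in a chat template.
#   frames = downsample_video("clip.mp4", num_frames=10)
#   for img, ts in frames:
#       print(f"{ts:>7.2f}s -> {img.size}")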
def progress_bar_html(label: str) -> str:
    return f'''<div style="display:flex; align-items:center;">
  <span style="margin-right:10px; font-size:14px;">{label}</span>
  <div style="width:110px; height:5px; background:#B0E0E6; border-radius:2px; overflow:hidden;">
    <div style="width:100%; height:100%; background:#00FFFF; animation:load 1.5s linear infinite;"></div>
  </div>
</div>
<style>@keyframes load{{0%{{transform:translateX(-100%)}}100%{{transform:translateX(100%)}}}}</style>'''
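# The HTML above is streamed as the first chunk of the chat response; Gradio
# renders HTML inside chat messages, and each subsequent yield replaces it
# with the accumulated model output.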
@spaces.GPU
def generate(message: dict, history):
    """Stream a response for an image- or video-grounded prompt."""
    # gr.ChatInterface(multimodal=True) delivers a dict with "text" and "files".
    prompt = message.get("text", "")
    files = message.get("files") or []
    # Determine mode from the uploaded file extensions
    is_video = any(f.lower().endswith((".mp4", ".avi", ".mov")) for f in files)
    is_image = any(f.lower().endswith((".jpg", ".png", ".jpeg", ".bmp")) for f in files)

    if is_video:
        yield progress_bar_html("Processing video with cosmos-reason1")
        video = files[0]
        frames = downsample_video(video)
        # Build messages: the user prompt followed by timestamped frames.
        # Frames are saved to disk so the chat template can load them by path.
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]
        for img, ts in frames:
            path = f"frame_{uuid.uuid4().hex}.png"
            img.save(path)
            messages[1]["content"].extend([
                {"type": "text", "text": f"Frame {ts}:"},
                {"type": "image", "url": path},
            ])
        inputs = processor.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_dict=True, return_tensors="pt",
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH,
        ).to(device)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # Run generation in a worker thread so decoded tokens can be streamed here.
        Thread(
            target=model.generate,
            kwargs={**inputs, "streamer": streamer, "max_new_tokens": MAX_NEW_TOKENS},
        ).start()
        buffer = ""
        for txt in streamer:
            buffer += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
        return
    if is_image:
        yield progress_bar_html("Processing image with cosmos-reason1")
        imgs = [Image.open(f) for f in files]
        messages = [{
            "role": "user",
            "content": [*[{"type": "image", "image": i} for i in imgs],
                        {"type": "text", "text": prompt}],
        }]
        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt_full], images=imgs,
            return_tensors="pt", padding=True,
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH,
        ).to(device)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        Thread(
            target=model.generate,
            kwargs={**inputs, "streamer": streamer, "max_new_tokens": MAX_NEW_TOKENS},
        ).start()
        out = ""
        for txt in streamer:
            out += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield out
        return
    # No valid media
    yield "Please upload at least one image or a video for inference."
def main():
    demo = gr.ChatInterface(
        fn=generate,
        # multimodal=True requires a MultimodalTextbox so uploads arrive in the
        # same message dict as the prompt text.
        textbox=gr.MultimodalTextbox(
            label="Prompt",
            file_types=["image", "video"],
            file_count="multiple",
        ),
        description="# **cosmos-reason1 by nvidia**",
        cache_examples=False,
        type="messages",
        multimodal=True,
        stop_btn="Stop Generation",
    )
    demo.queue(max_size=10).launch(share=True)
if __name__ == "__main__":
    main()
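# To run locally (assuming this file is saved as app.py): `python app.py`.
# On a ZeroGPU Space, the @spaces.GPU decorator attaches a GPU only for the
# duration of each generate() call, and share=True is ignored on Spaces.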