import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np
from typing import Optional
import tempfile
import os
import spaces

MID = "apple/FastVLM-7B"
IMAGE_TOKEN_INDEX = -200
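# Note: -200 is the LLaVA-style image placeholder id. FastVLM's remote modeling code
# swaps this sentinel position in `input_ids` for the projected image features, so it
# must not collide with any real vocabulary id.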

# Initialize model variables
tok = None
model = None


def load_model():
    global tok, model
    if tok is None or model is None:
        print("Loading FastVLM model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=torch.float16,
            device_map="cuda",
            trust_remote_code=True,
        )
        print("Model loaded successfully!")
    return tok, model
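
# The 7B checkpoint is downloaded and moved to the GPU only on the first call;
# subsequent calls reuse the cached globals, so only the first caption pays the load cost.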


def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Extract frames from video"""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if total_frames == 0:
        cap.release()
        return []

    frames = []

    if sampling_method == "uniform":
        # Uniform sampling across the whole clip
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    elif sampling_method == "first":
        # Take the first N frames
        indices = list(range(min(num_frames, total_frames)))
    elif sampling_method == "last":
        # Take the last N frames
        start = max(0, total_frames - num_frames)
        indices = list(range(start, total_frames))
    else:  # middle
        # Take frames from the middle of the clip
        start = max(0, (total_frames - num_frames) // 2)
        indices = list(range(start, min(start + num_frames, total_frames)))

    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # Convert BGR (OpenCV) to RGB (PIL)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))

    cap.release()
    return frames
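
# Example (hypothetical file): extract_frames("/tmp/clip.mp4", num_frames=4, sampling_method="middle")
# returns up to four RGB PIL images drawn from the middle of the clip.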


# On ZeroGPU Spaces a GPU must be attached per call; the decorator below is assumed to
# be why `spaces` is imported (it is a no-op on dedicated GPU hardware).
@spaces.GPU
def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate caption for a single frame"""
    # Load model on GPU
    tok, model = load_model()

    # Build chat with custom prompt
    messages = [
        {"role": "user", "content": f"<image>\n{prompt}"}
    ]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    pre, post = rendered.split("<image>", 1)

    # Tokenize the text around the image token
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

    # Splice in the IMAGE token id
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)

    # Preprocess image
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)

    # Generate
    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=15,
            temperature=0.7,
            do_sample=True,
        )

    caption = tok.decode(out[0], skip_special_tokens=True)

    # Keep only the generated part (strip the echoed prompt)
    if prompt in caption:
        caption = caption.split(prompt)[-1].strip()

    return caption
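
# Example (hypothetical frame): caption_frame(Image.open("frame.jpg"), "Describe this image.")
# returns a short caption (at most 15 new tokens) for that single frame.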


def process_video(
    video_path: str,
    num_frames: int,
    sampling_method: str,
    caption_mode: str,
    custom_prompt: str,
    progress=gr.Progress()
) -> tuple:
    """Process video and generate captions"""
    if not video_path:
        return "Please upload a video first.", None

    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)

    if not frames:
        return "Failed to extract frames from video.", None

    # Use a brief one-sentence prompt for faster processing
    prompt = "Provide a brief one-sentence description of what's happening in this image."

    captions = []
    frame_previews = []

    for i, frame in enumerate(frames):
        progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
        caption = caption_frame(frame, prompt)
        captions.append(f"Frame {i + 1}: {caption}")
        frame_previews.append(frame)

    progress(1.0, desc="Generating summary...")

    # Combine captions into a simple narrative
    full_caption = "\n".join(captions)

    # Generate overall summary if multiple frames
    if len(frames) > 1:
        video_summary = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
    else:
        video_summary = f"Video Analysis:\n\n{full_caption}"

    return video_summary, frame_previews
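
# Note: the UI below wires up the streaming variant `process_video_with_chat`; this
# helper is a non-streaming equivalent that returns the full summary in one call.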


# Create a custom Apple-inspired theme
class AppleTheme(gr.themes.Base):
    def __init__(self):
        super().__init__(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.gray,
            neutral_hue=gr.themes.colors.gray,
            spacing_size=gr.themes.sizes.spacing_md,
            radius_size=gr.themes.sizes.radius_md,
            text_size=gr.themes.sizes.text_md,
            font=[
                gr.themes.GoogleFont("Inter"),
                "-apple-system",
                "BlinkMacSystemFont",
                "SF Pro Display",
                "SF Pro Text",
                "Helvetica Neue",
                "Helvetica",
                "Arial",
                "sans-serif",
            ],
            font_mono=[
                gr.themes.GoogleFont("SF Mono"),
                "ui-monospace",
                "Consolas",
                "monospace",
            ],
        )
        super().set(
            # Core colors
            body_background_fill="*neutral_50",
            body_background_fill_dark="*neutral_950",
            button_primary_background_fill="*primary_500",
            button_primary_background_fill_hover="*primary_600",
            button_primary_text_color="white",
            button_primary_border_color="*primary_500",
            # Shadows
            block_shadow="0 4px 12px rgba(0, 0, 0, 0.08)",
            # Borders
            block_border_width="1px",
            block_border_color="*neutral_200",
            input_border_width="1px",
            input_border_color="*neutral_300",
            input_border_color_focus="*primary_500",
            # Text
            block_title_text_weight="600",
            block_label_text_weight="500",
            block_label_text_size="13px",
            block_label_text_color="*neutral_600",
            body_text_color="*neutral_900",
            # Spacing
            layout_gap="16px",
            block_padding="20px",
            # Specific components
            slider_color="*primary_500",
        )


# Create the Gradio interface with the custom theme
with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning")

    with gr.Row():
        # Main video display
        with gr.Column(scale=7):
            video_display = gr.Video(
                label="Video Input",
                autoplay=True,
                loop=True
            )

    # Sidebar with chat interface
    with gr.Sidebar(width=400):
        gr.Markdown("## 💬 Video Analysis Chat")
        chatbot = gr.Chatbot(
            value=[[None, "Upload a video and I'll analyze it for you!"]],
            height=400,
            elem_classes=["chatbot"]
        )
        process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")

        with gr.Accordion("🖼️ Analyzed Frames", open=False):
            frame_gallery = gr.Gallery(
                label="Extracted Frames",
                show_label=False,
                columns=2,
                rows=4,
                object_fit="contain",
                height="auto"
            )

    # Hidden parameters with default values
    num_frames = gr.State(value=8)
    sampling_method = gr.State(value="uniform")
    caption_mode = gr.State(value="Brief Summary")
    custom_prompt = gr.State(value="")

    # Upload handler: echo the video back and post a (user, assistant) message pair
    def handle_upload(video, chat_history):
        if video:
            chat_history.append(["Video uploaded", "Video loaded! Click 'Analyze Video' to generate captions."])
            return video, chat_history
        return None, chat_history

    video_display.upload(
        handle_upload,
        inputs=[video_display, chatbot],
        outputs=[video_display, chatbot]
    )

    # Process function that streams caption updates into the chatbot
    def process_video_with_chat(video_path, num_frames, sampling_method, caption_mode, custom_prompt, chat_history, progress=gr.Progress()):
        if not video_path:
            chat_history.append([None, "Please upload a video first."])
            yield chat_history, None
            return

        chat_history.append(["Analyzing video...", None])
        yield chat_history, None

        # Extract frames
        progress(0, desc="Extracting frames...")
        frames = extract_frames(video_path, num_frames, sampling_method)

        if not frames:
            chat_history.append([None, "Failed to extract frames from video."])
            yield chat_history, None
            return

        # Start streaming the assistant response
        chat_history.append([None, ""])
        prompt = "Provide a brief one-sentence description of what's happening in this image."
        captions = []

        for i, frame in enumerate(frames):
            progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
            caption = caption_frame(frame, prompt)
            frame_caption = f"Frame {i + 1}: {caption}\n"
            captions.append(frame_caption)

            # Update the last message with the captions accumulated so far
            current_text = "".join(captions)
            chat_history[-1] = [None, f"Analyzing {len(frames)} frames:\n\n{current_text}"]
            yield chat_history, frames[:i + 1]  # Also update the frame gallery progressively

        progress(1.0, desc="Analysis complete!")

        # Final update with the complete message
        full_caption = "".join(captions)
        final_message = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
        chat_history[-1] = [None, final_message]
        yield chat_history, frames

    # Process button with streaming output
    process_btn.click(
        process_video_with_chat,
        inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt, chatbot],
        outputs=[chatbot, frame_gallery],
        show_progress=True
    )
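
    # Because `process_video_with_chat` is a generator, Gradio streams each yield,
    # so captions and gallery frames appear incrementally as frames are analyzed.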

demo.launch()