Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import spaces | |
| from transformers import AutoModelForImageSegmentation | |
| from torchvision import transforms | |
| import moviepy.editor as mp | |
| from PIL import Image | |
| import numpy as np | |
| import tempfile | |
| import time | |
| import os | |
| import shutil | |
| import ffmpeg | |
| from concurrent.futures import ThreadPoolExecutor | |
| from gradio.themes.base import Base | |
| from gradio.themes.utils import colors, fonts | |
# Custom Theme Definition
class WhiteTheme(Base):
    """Gradio theme that pins the app to a white, light-mode appearance.

    The main surface, text and input-background variables are given the same
    value in both their light form and their ``*_dark`` form, so those
    elements do not change when the browser asks for dark mode.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        ),
    ):
        super().__init__(primary_hue=primary_hue, font=font, font_mono=font_mono)

        white = "white"
        black = "black"

        # Light-mode-only accents (no *_dark counterpart is overridden).
        overrides = {
            "background_fill_primary": "*primary_50",
            "background_fill_secondary": white,
            "border_color_primary": "*primary_300",
        }

        # Variables that must look identical in light and dark mode: set the
        # plain name and its "_dark" twin to the same value.
        constant_in_both_modes = (
            ("body_background_fill", white),
            ("block_background_fill", white),
            ("panel_background_fill", white),
            ("body_text_color", black),
            ("block_label_text_color", black),
            ("input_background_fill", white),
        )
        for variable, value in constant_in_both_modes:
            overrides[variable] = value
            overrides[f"{variable}_dark"] = value

        # Borders and shadows, overridden for the light variant only.
        overrides.update(
            block_border_color=white,
            panel_border_color=white,
            input_border_color="lightgray",
            shadow_drop="none",
        )

        self.set(**overrides)
# Let PyTorch use reduced internal precision for float32 matmuls where the
# hardware supports it (see torch.set_float32_matmul_precision docs).
torch.set_float32_matmul_precision("medium")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def _load_birefnet(repo_id):
    """Load a BiRefNet checkpoint from the Hugging Face Hub and move it to ``device``."""
    model = AutoModelForImageSegmentation.from_pretrained(repo_id, trust_remote_code=True)
    model.to(device)
    return model


# Load models (full model first, then the lite variant used by "Fast Mode").
print("Loading models...")
birefnet = _load_birefnet("ZhengPeng7/BiRefNet")
birefnet_lite = _load_birefnet("ZhengPeng7/BiRefNet_lite")
print("Models loaded successfully!")

# Standard ImageNet per-channel mean / std used to normalise model input.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Model input pipeline: stretch to 1024x1024, convert to a tensor, normalise.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
def process_frame(frame, fast_mode=True):
    """
    Remove the background from a single frame with BiRefNet.

    Inference always runs on a fixed 1024x1024 resize of the frame, so the
    aspect ratio is NOT preserved for the model. The predicted mask is then
    resized back to the frame's original size, so the returned image has the
    same resolution as the input.

    Args:
        frame: A frame array (anything ``PIL.Image.fromarray`` accepts, e.g.
            the arrays yielded by moviepy's ``iter_frames``); converted to RGB.
        fast_mode: Use ``birefnet_lite`` when True, the full ``birefnet`` otherwise.

    Returns:
        An RGBA ``PIL.Image`` whose alpha channel is the predicted foreground
        mask, or ``None`` if anything failed (the error is printed, not raised).
    """
    try:
        image_ori = Image.fromarray(frame).convert('RGB')
        original_size = image_ori.size
        # Stretch to the model's fixed 1024x1024 input and normalise; the
        # distortion is undone when the mask is resized back further down.
        input_images = transform_image(image_ori).unsqueeze(0).to(device)
        model = birefnet_lite if fast_mode else birefnet
        with torch.no_grad():
            # [-1] takes the last model output; presumably BiRefNet's
            # final-stage logits, mapped to [0, 1] by sigmoid — TODO confirm.
            preds = model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        # Resize the mask back to the frame's original resolution.
        pred_pil = transforms.ToPILImage()(pred)
        pred_pil = pred_pil.resize(original_size, Image.BICUBIC)
        # Attach the mask as the alpha channel of a copy of the RGB frame.
        foreground = image_ori.copy()
        foreground.putalpha(pred_pil)
        return foreground
    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(f"Error processing frame: {e}")
        return None
def process_video(video_path, fps=0, fast_mode=True, max_workers=6):
    """
    Remove the background from every frame of a video and export the result.

    This is a generator yielding 5-tuples that match the Gradio outputs
    ``(preview, zip_path, mov_path, background, status)``:

    * one tuple per successfully processed frame, carrying an RGBA preview
      array and a progress message;
    * a final tuple with the PNG-sequence ZIP path and the ProRes 4444 MOV
      path (``mov_path`` is None if FFmpeg failed).

    Errors are reported through the status field rather than raised.

    Args:
        video_path: Path to the input video file.
        fps: Sampling / output frame rate. 0 (or any falsy value) keeps the
            source video's frame rate.
        fast_mode: Use BiRefNet_lite instead of the full BiRefNet.
        max_workers: Number of frames processed concurrently; also passed to
            FFmpeg as its thread count.

    NOTE(review): ``spaces`` is imported by this file but no ``@spaces.GPU``
    decorator is applied here; on ZeroGPU hardware this likely needs one.
    """
    temp_dir = None
    video = None
    try:
        start_time = time.time()
        video = mp.VideoFileClip(video_path)
        # Use the original video FPS if none was requested.
        if not fps:
            fps = video.fps
        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)
        print(f"Processing {total_frames} frames at {fps} FPS...")

        # Temporary directory for the PNG sequence.
        temp_dir = tempfile.mkdtemp()
        png_dir = os.path.join(temp_dir, "frames")
        os.makedirs(png_dir, exist_ok=True)

        saved_frames = 0
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_frame, frame, fast_mode) for frame in frames]
            # The submitted work items hold the frames they need; dropping this
            # list lets each raw frame be freed once it has been processed.
            del frames
            for i, future in enumerate(futures):
                # Release the list's reference so finished results can be freed
                # instead of accumulating for the whole video.
                futures[i] = None
                try:
                    result = future.result()
                    if result is not None:
                        # Number saved frames contiguously: FFmpeg's image
                        # sequence reader stops at the first missing index, so
                        # a gap left by a failed frame would truncate the MOV.
                        frame_path = os.path.join(png_dir, f"frame_{saved_frames:06d}.png")
                        result.save(frame_path, "PNG")
                        saved_frames += 1
                        elapsed_time = time.time() - start_time
                        yield np.array(result), None, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
                        if (i + 1) % 10 == 0:
                            print(f"Processed {i+1}/{total_frames} frames")
                except Exception as e:
                    print(f"Error processing frame {i+1}: {e}")

        if saved_frames == 0:
            # Nothing to package: avoid producing an empty ZIP and a failing
            # FFmpeg run that would still be reported as success.
            yield None, None, None, None, "Error processing video: no frames could be processed"
            return

        print("Creating output files...")
        # Permanent output directory next to the uploaded file.
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)
        # A single timestamp so the ZIP and MOV from one run share a suffix.
        stamp = int(time.time())

        # ZIP of the PNG sequence; make_archive returns the full archive path.
        zip_path = shutil.make_archive(os.path.join(output_dir, f"frames_{stamp}"), 'zip', png_dir)

        print("Creating ProRes 4444 MOV...")
        mov_path = os.path.join(output_dir, f"video_{stamp}.mov")
        try:
            stream = ffmpeg.input(
                os.path.join(png_dir, 'frame_%06d.png'),
                pattern_type='sequence',
                framerate=fps
            )
            # ProRes 4444 settings for maximum quality with alpha
            stream = ffmpeg.output(
                stream,
                mov_path,
                vcodec='prores_ks',       # ProRes encoder
                pix_fmt='yuva444p10le',   # 10-bit 4:4:4:4 pixel format with alpha
                profile='4444',           # ProRes 4444 profile for alpha support
                alpha_bits=16,            # Maximum alpha bit depth
                qscale=1,                 # Highest quality setting
                vendor='ap10',            # Apple vendor tag
                bits_per_mb=8000,         # High bitrate for quality
                threads=max_workers       # Parallel encoding
            )
            ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
            print("MOV video created successfully!")
        except ffmpeg.Error as e:
            print(f"Error creating MOV video: {e.stderr.decode() if e.stderr else str(e)}")
            mov_path = None

        print("Processing complete!")
        yield None, zip_path, mov_path, None, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, None, f"Error processing video: {e}"
    finally:
        # Release the clip's file handle / reader process.
        if video is not None:
            try:
                video.close()
            except Exception as e:
                print(f"Error closing video: {e}")
        # Clean up the temporary PNG directory.
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")
def process_wrapper(video, fps=0, fast_mode=True, max_workers=6):
    """Gradio click handler: validate the upload, then stream process_video's outputs.

    Args:
        video: Uploaded video path from the ``gr.Video`` input, or None.
        fps: Output FPS from the slider (0 keeps the source rate).
        fast_mode: Whether to use the lite model.
        max_workers: Worker count from the slider. Gradio sliders deliver
            floats, so this is normalised to a positive int before it reaches
            the thread pool and FFmpeg's ``-threads`` option.

    Raises:
        gr.Error: If no video was uploaded, or if processing raises.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    max_workers = max(1, int(max_workers))
    try:
        # ``yield from`` forwards close()/throw() to the inner generator, so
        # its cleanup (temp dir, clip handle) runs as soon as this one closes.
        yield from process_video(video, fps, fast_mode, max_workers)
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}") from e
# Custom CSS for styling: animated gradient title/button, plus light-mode
# overrides. Gradio's dark theme defines its variables under the `.dark`
# class, so `.dark` must be in the override selector list; values set only on
# :root would be shadowed there by inheritance.
custom_css = """
.title-container {
text-align: center;
padding: 10px 0;
}
#title {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
font-size: 36px;
font-weight: bold;
color: #000000;
padding: 10px;
border-radius: 10px;
display: inline-block;
background: linear-gradient(
135deg,
#e0f7fa, #e8f5e9, #fff9c4, #ffebee,
#f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
);
background-size: 400% 400%;
animation: gradient-animation 15s ease infinite;
}
@keyframes gradient-animation {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
#submit-button {
background: linear-gradient(
135deg,
#e0f7fa, #e8f5e9, #fff9c4, #ffebee,
#f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
);
background-size: 400% 400%;
animation: gradient-animation 15s ease infinite;
border-radius: 12px;
color: black;
}
/* Force light mode styles */
:root, :root[data-theme='light'], :root[data-theme='dark'], .dark {
--body-background-fill: white !important;
--background-fill-primary: white !important;
--background-fill-secondary: white !important;
--block-background-fill: white !important;
--panel-background-fill: white !important;
--body-text-color: black !important;
--block-label-text-color: black !important;
}
/* Additional overrides for dark mode */
@media (prefers-color-scheme: dark) {
:root {
color-scheme: light;
}
}
"""
# Gradio Interface
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
    # gr.HTML inserts its markup via innerHTML, so browsers do not execute the
    # <script> below. The title therefore ships with "video" already in place
    # as a static fallback. If the script does run, it clears that text first
    # and re-types it, instead of appending a second copy.
    gr.HTML('''
<div class="title-container">
<div id="title">
<span>{.</span><span id="typed-text">video</span><span>}</span>
</div>
</div>
<script>
(function() {
const text = "video";
const typedTextSpan = document.getElementById("typed-text");
if (!typedTextSpan) return;
typedTextSpan.textContent = "";
let charIndex = 0;
function type() {
if (charIndex < text.length) {
typedTextSpan.textContent += text[charIndex];
charIndex++;
setTimeout(type, 150);
}
}
setTimeout(type, 150);
})();
</script>
''')
    with gr.Row():
        # Left column: upload and processing options.
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
                height=360,
                width=640
            )
            with gr.Row():
                fps_slider = gr.Slider(
                    minimum=0,
                    maximum=60,
                    step=1,
                    value=0,
                    label="Output FPS (0 will inherit the original fps value)",
                )
                fast_mode_checkbox = gr.Checkbox(
                    label="Fast Mode (Use BiRefNet_lite)",
                    value=True
                )
                max_workers_slider = gr.Slider(
                    minimum=1,
                    maximum=32,
                    step=1,
                    value=6,
                    label="Max Workers",
                    info="Determines how many frames to process in parallel"
                )
            btn = gr.Button("Process Video", elem_id="submit-button")
        # Right column: live preview, downloads and status.
        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_foreground_zip = gr.File(label="Download PNG Sequence (ZIP)")
            output_foreground_video = gr.File(label="Download Video (ProRes 4444 MOV with transparency)")
            output_background = gr.Video(label="Background (Coming Soon)")
            time_textbox = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("""
### Output Information
- MOV file uses ProRes 4444 codec for professional-grade alpha channel
- Original resolution and framerate are maintained
- PNG sequence provided for maximum compatibility
""")
    # process_wrapper is a generator; each yielded 5-tuple maps onto these
    # outputs in order.
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, fast_mode_checkbox, max_workers_slider],
        outputs=[preview_image, output_foreground_zip, output_foreground_video,
                 output_background, time_textbox]
    )

if __name__ == "__main__":
    demo.launch(debug=True)