MogensR's picture
Update app.py
2450c76
raw
history blame
26.1 kB
#!/usr/bin/env python3
"""
Video Background Replacement - Main Application
Refactored version with improved error handling, memory management, and configuration
"""
import os
import cv2
import numpy as np
import torch
import time
import logging
import threading
import subprocess
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, Callable
from dataclasses import dataclass
# Configure logging: root logger emits INFO and above with a timestamped,
# logger-name-qualified format; this module logs through its own named logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Apply the Gradio schema patch early, before any interface is built.
# gradio_client.utils.get_type is wrapped so that a schema which is not a dict
# (for example a bare boolean) is mapped to a type name instead of being
# handed to the original implementation.
try:
    import gradio_client.utils as gc_utils

    original_get_type = gc_utils.get_type

    def patched_get_type(schema):
        """Return a type name for *schema*, tolerating non-dict schemas.

        Dict schemas are delegated to the original ``get_type``. Anything else
        is classified here: booleans as "boolean", ints/floats as "number" and
        every other value (including strings) as "string".
        """
        if isinstance(schema, dict):
            return original_get_type(schema)
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(schema, bool):
            return "boolean"
        if isinstance(schema, (int, float)):
            return "number"
        return "string"

    gc_utils.get_type = patched_get_type
    logger.info("Gradio schema patch applied successfully")
except Exception as e:
    # The patch is best-effort: startup continues without it.
    logger.error(f"Gradio patch failed: {e}")
# Import core modules
from utilities import (
segment_person_hq,
refine_mask_hq,
replace_background_hq,
create_professional_background,
PROFESSIONAL_BACKGROUNDS,
validate_video_file
)
# The two-stage pipeline is optional. When its module cannot be imported the
# feature flag is switched off and a minimal preset table is provided so that
# lookups of the 'standard' preset still succeed.
try:
    from two_stage_processor import TwoStageProcessor, CHROMA_PRESETS
except ImportError:
    CHROMA_PRESETS = {'standard': {}}
    TWO_STAGE_AVAILABLE = False
else:
    TWO_STAGE_AVAILABLE = True
# Configuration
def _env_int(name: str, default: int, minimum: Optional[int] = None) -> int:
    """Read an integer setting from the environment without ever raising.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or not an integer.
        minimum: Optional lower bound; smaller values are raised to it.

    Returns:
        The parsed (and possibly clamped) integer.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring invalid %s=%r; using default %d", name, raw, default
        )
        return default
    if minimum is not None and value < minimum:
        logging.getLogger(__name__).warning(
            "%s=%d is below the minimum %d; using %d", name, value, minimum, minimum
        )
        return minimum
    return value


@dataclass
class ProcessingConfig:
    """Tunable processing settings, defaulted from environment variables.

    The environment is read once, when the class is defined. A malformed
    value falls back to the built-in default instead of aborting the import.
    ``keyframe_interval`` and ``memory_cleanup_interval`` are used as modulo
    divisors by the frame loop, so they are kept at 1 or above.
    """

    # Refine the mask every N-th frame; other frames blend with the last one.
    keyframe_interval: int = _env_int('KEYFRAME_INTERVAL', 5, minimum=1)
    frame_skip: int = _env_int('FRAME_SKIP', 1)
    # Run the aggressive memory cleanup every N processed frames.
    memory_cleanup_interval: int = _env_int('MEMORY_CLEANUP_INTERVAL', 30, minimum=1)
    max_video_length: int = _env_int('MAX_VIDEO_LENGTH', 300)  # seconds
    quality_preset: str = os.getenv('QUALITY_PRESET', 'balanced')

    def __post_init__(self) -> None:
        """Clamp explicitly passed divisor intervals to at least 1."""
        for attr in ("keyframe_interval", "memory_cleanup_interval"):
            value = getattr(self, attr)
            if isinstance(value, int) and value < 1:
                logging.getLogger(__name__).warning(
                    "%s=%d is not a usable interval; using 1", attr, value
                )
                setattr(self, attr, 1)
class DeviceManager:
    """Select the torch device the application runs on."""

    @staticmethod
    def get_optimal_device():
        """Return a CUDA device when it demonstrably works, otherwise CPU.

        CUDA is only chosen after a tiny tensor has been allocated on it and
        the device name could be queried; any failure during that probe is
        logged and the CPU device is returned instead.
        """
        selected = None
        if torch.cuda.is_available():
            try:
                # Allocate and free a tiny tensor to prove CUDA is functional.
                probe = torch.tensor([1.0], device='cuda')
                del probe
                torch.cuda.empty_cache()
                gpu_name = torch.cuda.get_device_name(0)
                logger.info(f"Using GPU: {gpu_name}")
                selected = torch.device("cuda")
            except Exception as e:
                logger.warning(f"CUDA test failed: {e}, falling back to CPU")
        if selected is None:
            logger.info("Using CPU device")
            selected = torch.device("cpu")
        return selected
class MemoryManager:
    """Free cached memory and report GPU usage for one torch device."""

    def __init__(self, device):
        """Remember *device* and whether it is a CUDA device."""
        self.device = device
        self.gpu_available = device.type == 'cuda'

    def cleanup_aggressive(self):
        """Run the garbage collector and, on CUDA, release cached blocks."""
        import gc
        gc.collect()
        if self.gpu_available:
            torch.cuda.empty_cache()
            # Synchronize the managed device rather than assuming index 0.
            torch.cuda.synchronize(self.device)

    def get_memory_usage(self):
        """Return GPU memory statistics for the managed device.

        Returns:
            An empty dict on CPU. On CUDA, a dict with ``gpu_percent``
            (allocated / total * 100) and ``gpu_allocated_gb``.
        """
        usage = {}
        if self.gpu_available:
            # Query the device this manager was built for; the previous
            # hard-coded index 0 reported the wrong GPU on multi-GPU hosts.
            total_bytes = torch.cuda.get_device_properties(self.device).total_memory
            allocated_bytes = torch.cuda.memory_allocated(self.device)
            usage['gpu_percent'] = (allocated_bytes / total_bytes) * 100
            usage['gpu_allocated_gb'] = allocated_bytes / (1024**3)
        return usage
class ProgressTracker:
    """Turn frame counts into progress fractions and status messages."""

    def __init__(self, total_frames: int, callback: Optional[Callable] = None):
        """Start timing.

        Args:
            total_frames: Expected number of frames; may be 0 or inaccurate.
            callback: Optional ``callback(progress_fraction, message)``.
        """
        self.total_frames = total_frames
        self.callback = callback
        self.start_time = time.time()
        self.processed_frames = 0
        self.frame_times = []  # kept for compatibility; never populated here

    def update(self, frame_number: int, stage: str = ""):
        """Record *frame_number* as processed and notify the callback.

        The reported fraction is clamped to [0, 1] and the ETA never goes
        negative, because the expected frame count can be lower than the
        number of frames actually processed. Callback errors are logged and
        swallowed so that reporting can never abort processing.
        """
        self.processed_frames = frame_number
        elapsed_time = time.time() - self.start_time
        current_fps = self.processed_frames / elapsed_time if elapsed_time > 0 else 0
        remaining_frames = max(self.total_frames - self.processed_frames, 0)
        eta_seconds = remaining_frames / current_fps if current_fps > 0 else 0
        if self.total_frames > 0:
            progress_pct = min(max(self.processed_frames / self.total_frames, 0.0), 1.0)
        else:
            progress_pct = 0
        message = (
            f"Frame {self.processed_frames}/{self.total_frames} | "
            f"Elapsed: {self._format_time(elapsed_time)} | "
            f"Speed: {current_fps:.1f} fps | "
            f"ETA: {self._format_time(eta_seconds)}"
        )
        if stage:
            message = f"{stage} | {message}"
        if self.callback:
            try:
                self.callback(progress_pct, message)
            except Exception as e:
                logger.warning(f"Progress callback failed: {e}")

    def _format_time(self, seconds: float) -> str:
        """Format a duration as ``Ns``, ``Nm Ns`` or ``Nh Nm``."""
        seconds = max(seconds, 0)  # negative durations are shown as 0s
        if seconds < 60:
            return f"{int(seconds)}s"
        if seconds < 3600:
            return f"{int(seconds // 60)}m {int(seconds % 60)}s"
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        return f"{hours}h {minutes}m"
class VideoProcessor:
"""Main video processing class with error recovery"""
def __init__(self):
    """Detect the device and create an empty, not-yet-loaded processor."""
    # Coordination primitives: one lock serialising model loading and one
    # event used to request cancellation.
    self.cancel_event = threading.Event()
    self.loading_lock = threading.Lock()

    # Model slots stay empty until load_models() fills them.
    self.models_loaded = False
    self.sam2_predictor = None
    self.matanyone_model = None
    self.two_stage_processor = None

    # Runtime environment and settings.
    self.config = ProcessingConfig()
    self.device = DeviceManager.get_optimal_device()
    self.memory_manager = MemoryManager(self.device)
def load_models(self, progress_callback: Optional[Callable] = None) -> str:
    """Load SAM2 and MatAnyone once and report the outcome as text.

    Loading is serialised by ``loading_lock``; a second call after a
    successful load returns immediately. The cancel flag is cleared at the
    start and checked after each model.

    Args:
        progress_callback: Optional ``callback(fraction, message)``.

    Returns:
        A human-readable status message. Failures are reported in the
        message rather than raised.

    NOTE(review): ``models_loaded`` becomes True even when
    ``_load_matanyone`` returned None; presumably the utilities tolerate a
    missing matting model. Confirm.
    """
    with self.loading_lock:
        if self.models_loaded:
            return "Models already loaded and validated"
        try:
            self.cancel_event.clear()
            start_time = time.time()
            if progress_callback:
                progress_callback(0.0, f"Starting model loading on {self.device}")
            # Load SAM2
            self.sam2_predictor = self._load_sam2(progress_callback)
            if self.cancel_event.is_set():
                return "Model loading cancelled"
            # Load MatAnyone
            self.matanyone_model = self._load_matanyone(progress_callback)
            if self.cancel_event.is_set():
                return "Model loading cancelled"
            # Initialize two-stage processor if available
            if TWO_STAGE_AVAILABLE:
                try:
                    self.two_stage_processor = TwoStageProcessor(
                        self.sam2_predictor, self.matanyone_model
                    )
                    logger.info("Two-stage processor initialized")
                except Exception as e:
                    self.two_stage_processor = None
                    logger.warning(f"Two-stage processor init failed: {e}")
            self.models_loaded = True
            load_time = time.time() - start_time
            message = f"Models loaded successfully in {load_time:.1f}s on {self.device}"
            # Advertise two-stage mode only when the processor really exists,
            # matching what get_status() reports.
            if self.two_stage_processor is not None:
                message += " (Two-stage mode available)"
            logger.info(message)
            return message
        except Exception as e:
            self.models_loaded = False
            error_msg = f"Model loading failed: {str(e)}"
            # logger.exception keeps the traceback for diagnosis.
            logger.exception(error_msg)
            return error_msg
def _load_sam2(self, progress_callback: Optional[Callable]) -> Any:
    """Download, build and smoke-test the SAM2 image predictor.

    Args:
        progress_callback: Optional ``callback(fraction, message)``.

    Returns:
        A ``SAM2ImagePredictor`` that produced a mask for a dummy image.

    Raises:
        Exception: Any download, build or validation error is logged and
            re-raised; a failed validation raises ``RuntimeError``.
    """
    if progress_callback:
        progress_callback(0.1, "Loading SAM2...")
    try:
        from huggingface_hub import hf_hub_download
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        # Download checkpoint
        checkpoint_path = hf_hub_download(
            repo_id="facebook/sam2-hiera-large",
            filename="sam2_hiera_large.pt",
            cache_dir=str(Path("/tmp/model_cache/sam2_checkpoint")),
            force_download=False
        )
        # Build the model directly on the selected device. build_sam2's own
        # ``device`` argument defaults to "cuda", which fails on CPU-only
        # hosts before the model could be moved.
        sam2_model = build_sam2(
            "sam2_hiera_l.yaml", checkpoint_path, device=self.device
        )
        sam2_model.to(self.device)
        sam2_model.eval()
        predictor = SAM2ImagePredictor(sam2_model)
        # Validate with a single positive click in the centre of a blank image.
        test_image = np.zeros((256, 256, 3), dtype=np.uint8)
        predictor.set_image(test_image)
        test_points = np.array([[128.0, 128.0]], dtype=np.float32)
        test_labels = np.array([1], dtype=np.int32)
        with torch.no_grad():
            masks, _scores, _ = predictor.predict(
                point_coords=test_points,
                point_labels=test_labels,
                multimask_output=False
            )
        if masks is None or len(masks) == 0:
            raise RuntimeError("SAM2 validation failed")
        if progress_callback:
            progress_callback(0.5, "SAM2 loaded and validated")
        return predictor
    except Exception as e:
        logger.error(f"SAM2 loading failed: {e}")
        raise
def _load_matanyone(self, progress_callback: Optional[Callable]) -> Any:
    """Load the MatAnyone processor and run a non-fatal smoke test.

    Two import routes are tried in turn. A failing smoke test is only
    logged; the processor is still returned.

    Args:
        progress_callback: Optional ``callback(fraction, message)``.

    Returns:
        The processor object, or None when it could not be imported or
        loading raised.
    """
    if progress_callback:
        progress_callback(0.6, "Loading MatAnyone...")
    try:
        # Preferred route: the InferenceCore class.
        try:
            from matanyone import InferenceCore
            processor = InferenceCore("PeiqingYang/MatAnyone")
            logger.info("MatAnyone loaded via InferenceCore")
        except ImportError:
            # Alternative route: module-level loader.
            try:
                import matanyone
                processor = matanyone.load_model("PeiqingYang/MatAnyone")
                logger.info("MatAnyone loaded via direct import")
            except ImportError as e:
                logger.error(f"MatAnyone import failed: {e}")
                logger.error("Ensure all dependencies are installed: timm>=0.9.16, einops==0.8.0")
                return None

        # Dummy inputs: a black image and a mask with a filled centre square.
        probe_image = np.zeros((256, 256, 3), dtype=np.uint8)
        probe_mask = np.zeros((256, 256), dtype=np.uint8)
        probe_mask[64:192, 64:192] = 255
        try:
            # Pick whichever calling convention the processor exposes.
            if hasattr(processor, 'infer'):
                run_probe = processor.infer
            elif hasattr(processor, 'process'):
                run_probe = processor.process
            elif callable(processor):
                run_probe = processor
            else:
                logger.warning("MatAnyone processor has unknown interface")
                return processor  # Return anyway, utilities will handle
            probe_result = run_probe(probe_image, probe_mask)
            if probe_result is None:
                logger.warning("MatAnyone test returned None")
            else:
                logger.info("MatAnyone test successful")
        except Exception as test_error:
            # Still return the processor - it might work in actual use.
            logger.warning(f"MatAnyone test failed: {test_error}")
        if progress_callback:
            progress_callback(0.9, "MatAnyone loaded successfully")
        return processor
    except Exception as e:
        logger.error(f"MatAnyone loading failed: {e}")
        return None
def process_video(
self,
video_path: str,
background_choice: str,
custom_background_path: Optional[str] = None,
progress_callback: Optional[Callable] = None,
use_two_stage: bool = False,
chroma_preset: str = "standard",
preview_mask: bool = False,
preview_greenscreen: bool = False
) -> Tuple[Optional[str], str]:
"""Process video with comprehensive error handling"""
if not self.models_loaded:
return None, "Models not loaded. Please load models first."
if self.cancel_event.is_set():
return None, "Processing cancelled"
# Validate input
is_valid, validation_msg = validate_video_file(video_path)
if not is_valid:
return None, f"Invalid video: {validation_msg}"
try:
if use_two_stage and TWO_STAGE_AVAILABLE and self.two_stage_processor:
return self._process_two_stage(
video_path, background_choice, custom_background_path,
progress_callback, chroma_preset
)
else:
return self._process_single_stage(
video_path, background_choice, custom_background_path,
progress_callback, preview_mask, preview_greenscreen
)
except Exception as e:
logger.error(f"Video processing failed: {e}")
return None, f"Processing failed: {str(e)}"
def _process_single_stage(
    self,
    video_path: str,
    background_choice: str,
    custom_background_path: Optional[str],
    progress_callback: Optional[Callable],
    preview_mask: bool,
    preview_greenscreen: bool
) -> Tuple[Optional[str], str]:
    """Process the video frame by frame and write the result to /tmp.

    Every frame is segmented. The mask is refined on keyframes (every
    ``config.keyframe_interval`` frames, or when no refined mask exists
    yet); other frames blend the raw mask with the last refined one. A frame
    that fails is written unmodified. Audio is copied from the input unless
    a preview mode is active.

    Returns:
        ``(output_path, message)``; ``output_path`` is None when the video
        could not be opened, the background or writer could not be created,
        processing was cancelled, or no frame succeeded.
    """

    def _discard(path: str) -> None:
        # Best-effort removal of an output file nobody will consume.
        try:
            os.remove(path)
        except OSError as cleanup_error:
            logger.debug(f"Could not remove {path}: {cleanup_error}")

    cap = cv2.VideoCapture(video_path)
    out = None
    frame_count = 0
    successful_frames = 0
    # One try/finally owns both handles so they are released on every exit
    # path, including exceptions raised while preparing the background.
    try:
        if not cap.isOpened():
            return None, "Could not open video file"
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Prepare background
        background = self._prepare_background(
            background_choice, custom_background_path, frame_width, frame_height
        )
        if background is None:
            return None, "Failed to prepare background"
        # Setup output
        timestamp = int(time.time())
        output_path = f"/tmp/output_{timestamp}.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            return None, "Could not create output video"
        # Process frames
        progress_tracker = ProgressTracker(total_frames, progress_callback)
        last_refined_mask = None
        while True:
            if self.cancel_event.is_set():
                break
            ret, frame = cap.read()
            if not ret:
                break
            try:
                progress_tracker.update(frame_count, "Processing")
                # Segmentation
                mask = segment_person_hq(frame, self.sam2_predictor)
                # Mask refinement (keyframe-based)
                if (frame_count % self.config.keyframe_interval == 0) or (last_refined_mask is None):
                    refined_mask = refine_mask_hq(frame, mask, self.matanyone_model)
                    last_refined_mask = refined_mask.copy()
                else:
                    # Blend with previous refined mask for temporal consistency
                    alpha = 0.7
                    refined_mask = cv2.addWeighted(mask, alpha, last_refined_mask, 1 - alpha, 0)
                # Generate output based on mode
                if preview_mask:
                    result_frame = self._create_mask_preview(frame, refined_mask)
                elif preview_greenscreen:
                    result_frame = self._create_greenscreen_preview(frame, refined_mask)
                else:
                    result_frame = replace_background_hq(frame, refined_mask, background)
                out.write(result_frame)
                successful_frames += 1
            except Exception as frame_error:
                logger.warning(f"Frame {frame_count} processing failed: {frame_error}")
                out.write(frame)  # Write original frame as fallback
            frame_count += 1
            # Memory cleanup
            if frame_count % self.config.memory_cleanup_interval == 0:
                self.memory_manager.cleanup_aggressive()
    finally:
        cap.release()
        if out is not None:
            out.release()

    if self.cancel_event.is_set():
        _discard(output_path)
        return None, "Processing cancelled"
    if successful_frames == 0:
        # The file holds only fallback frames (or nothing) and is never
        # returned, so do not leave it behind in /tmp.
        _discard(output_path)
        return None, "No frames processed successfully"
    # Add audio if not preview mode
    if not (preview_mask or preview_greenscreen):
        final_output = self._add_audio(video_path, output_path)
    else:
        final_output = output_path
    success_msg = (
        f"Success! Processed {successful_frames}/{frame_count} frames\n"
        f"Background: {background_choice}\n"
        f"Mode: Single-stage\n"
        f"Device: {self.device}"
    )
    return final_output, success_msg
def _process_two_stage(
    self,
    video_path: str,
    background_choice: str,
    custom_background_path: Optional[str],
    progress_callback: Optional[Callable],
    chroma_preset: str
) -> Tuple[Optional[str], str]:
    """Run the two-stage pipeline (green screen intermediate).

    The video is opened only to read its frame size for the background; the
    actual processing is delegated to
    ``two_stage_processor.process_full_pipeline``.

    Returns:
        ``(output_path, message)``; ``output_path`` is None when the video
        cannot be opened, the background cannot be prepared, or the
        pipeline reports failure.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # Same guard as the single-stage path: an unopened capture would
        # report a 0x0 frame size and break the background resize later.
        if not cap.isOpened():
            return None, "Could not open video file"
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    # Prepare background
    background = self._prepare_background(
        background_choice, custom_background_path, frame_width, frame_height
    )
    if background is None:
        return None, "Failed to prepare background"
    # Process with two-stage pipeline
    timestamp = int(time.time())
    final_output = f"/tmp/twostage_final_{timestamp}.mp4"
    # Unknown preset names fall back to the 'standard' preset.
    chroma_settings = CHROMA_PRESETS.get(chroma_preset, CHROMA_PRESETS['standard'])
    result, message = self.two_stage_processor.process_full_pipeline(
        video_path,
        background,
        final_output,
        chroma_settings=chroma_settings,
        progress_callback=progress_callback
    )
    if result is None:
        return None, message
    success_msg = (
        f"Two-stage success!\n"
        f"Background: {background_choice}\n"
        f"Preset: {chroma_preset}\n"
        f"Quality: Cinema-grade\n"
        f"Device: {self.device}"
    )
    return result, success_msg
def _prepare_background(
    self,
    background_choice: str,
    custom_background_path: Optional[str],
    width: int,
    height: int
) -> Optional[np.ndarray]:
    """Return the background image resized to ``width`` x ``height``.

    A custom image is read from disk when ``background_choice`` is "custom"
    and a path is given; otherwise the choice is looked up in
    ``PROFESSIONAL_BACKGROUNDS`` and generated.

    Returns:
        The resized image, or None (with the reason logged) when the target
        size is not positive, the custom file is missing or unreadable, the
        choice is unknown, or no background was generated.
    """
    # cv2.resize raises on an empty target size; report it as a failure.
    if width <= 0 or height <= 0:
        logger.error(f"Invalid background size: {width}x{height}")
        return None
    if background_choice == "custom" and custom_background_path:
        if not os.path.exists(custom_background_path):
            logger.error(f"Custom background not found: {custom_background_path}")
            return None
        background = cv2.imread(custom_background_path)
        if background is None:
            logger.error("Could not read custom background")
            return None
    else:
        if background_choice not in PROFESSIONAL_BACKGROUNDS:
            if background_choice == "custom":
                # "custom" was selected but no file accompanied it.
                logger.error("Custom background selected but no image path was provided")
            else:
                logger.error(f"Unknown background: {background_choice}")
            return None
        bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
        background = create_professional_background(bg_config, width, height)
        # Defensive: the generator's failure mode is not visible from here.
        if background is None:
            logger.error(f"Could not generate background: {background_choice}")
            return None
    return cv2.resize(background, (width, height))
def _create_mask_preview(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
"""Create mask preview visualization"""
mask_vis = np.zeros_like(frame)
mask_vis[..., 1] = mask # Green channel
return mask_vis
def _create_greenscreen_preview(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Composite the frame over pure green, weighted by the mask.

    The mask is expanded to three channels and divided by 255 to form the
    per-pixel weight of the original frame; the remainder is filled with
    green. The result is returned as uint8.
    """
    # Solid backdrop with the same shape and dtype as the frame.
    backdrop = np.zeros_like(frame)
    backdrop[:, :] = [0, 255, 0]  # Pure green

    # Per-pixel weight of the original frame.
    weight = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR).astype(float) / 255

    foreground_part = frame * weight
    backdrop_part = backdrop * (1 - weight)
    return (foreground_part + backdrop_part).astype(np.uint8)
def _add_audio(self, input_video: str, processed_video: str) -> str:
    """Mux the input video's first audio stream into the processed video.

    Args:
        input_video: Path of the original video (audio source).
        processed_video: Path of the silent processed video.

    Returns:
        Path of the muxed file on success (the silent intermediate is then
        removed); otherwise ``processed_video`` unchanged. This method never
        raises: every failure is logged and degrades to the silent video.
    """
    timestamp = int(time.time())
    final_output = f"/tmp/final_with_audio_{timestamp}.mp4"
    try:
        # Check if input has audio
        probe = subprocess.run([
            'ffprobe', '-v', 'quiet', '-select_streams', 'a:0',
            '-show_entries', 'stream=codec_name', '-of', 'csv=p=0', input_video
        ], capture_output=True, text=True, timeout=30)
        # ffprobe exits 0 even when no stream matches the selection and then
        # prints nothing, so empty output also means "no audio".
        if probe.returncode != 0 or not probe.stdout.strip():
            logger.info("Input video has no audio")
            return processed_video
        # Add audio: copy the video stream, re-encode audio to AAC.
        mux = subprocess.run([
            'ffmpeg', '-y', '-i', processed_video, '-i', input_video,
            '-c:v', 'copy', '-c:a', 'aac', '-b:a', '192k',
            '-map', '0:v:0', '-map', '1:a:0', '-shortest', final_output
        ], capture_output=True, text=True, timeout=300)
        if mux.returncode == 0 and os.path.exists(final_output):
            try:
                os.remove(processed_video)
            except OSError as cleanup_error:
                logger.debug(f"Could not remove {processed_video}: {cleanup_error}")
            return final_output
        logger.warning("Audio processing failed, using video without audio")
    except Exception as e:
        logger.warning(f"Audio processing error: {e}")
    # Failure path: do not leave a partial muxed file behind.
    if os.path.exists(final_output):
        try:
            os.remove(final_output)
        except OSError as cleanup_error:
            logger.debug(f"Could not remove {final_output}: {cleanup_error}")
    return processed_video
def get_status(self) -> Dict[str, Any]:
    """Return a snapshot of model availability, device, memory and config."""
    # Two-stage counts as available only when the module imported and the
    # processor object was actually created.
    two_stage_ready = TWO_STAGE_AVAILABLE and self.two_stage_processor is not None
    config_snapshot = {
        'keyframe_interval': self.config.keyframe_interval,
        'quality_preset': self.config.quality_preset
    }
    status: Dict[str, Any] = {}
    status['models_loaded'] = self.models_loaded
    status['sam2_available'] = self.sam2_predictor is not None
    status['matanyone_available'] = self.matanyone_model is not None
    status['two_stage_available'] = two_stage_ready
    status['device'] = str(self.device)
    status['memory_usage'] = self.memory_manager.get_memory_usage()
    status['config'] = config_snapshot
    return status
def cancel_processing(self):
    """Request cancellation by setting the shared cancel flag.

    The single-stage frame loop checks the flag before reading each frame
    and load_models() checks it after each model.
    """
    cancel_flag = self.cancel_event
    cancel_flag.set()
    logger.info("Processing cancellation requested")
# Global processor instance shared by the compatibility functions below.
# It is created at import time, so device detection runs on import.
processor: VideoProcessor = VideoProcessor()
# Compatibility functions for existing UI
def load_models_with_validation(progress_callback: Optional[Callable] = None) -> str:
    """Load models on the shared processor and return its status message."""
    status_message = processor.load_models(progress_callback)
    return status_message
def process_video_fixed(
    video_path: str,
    background_choice: str,
    custom_background_path: Optional[str],
    progress_callback: Optional[Callable] = None,
    use_two_stage: bool = False,
    chroma_preset: str = "standard",
    preview_mask: bool = False,
    preview_greenscreen: bool = False
) -> Tuple[Optional[str], str]:
    """Forward a processing request to the shared processor.

    Returns:
        The ``(output_path, message)`` pair from
        ``VideoProcessor.process_video``.
    """
    outcome = processor.process_video(
        video_path=video_path,
        background_choice=background_choice,
        custom_background_path=custom_background_path,
        progress_callback=progress_callback,
        use_two_stage=use_two_stage,
        chroma_preset=chroma_preset,
        preview_mask=preview_mask,
        preview_greenscreen=preview_greenscreen,
    )
    return outcome
def get_model_status() -> Dict[str, Any]:
    """Return the shared processor's status dictionary."""
    model_status = processor.get_status()
    return model_status
def get_cache_status() -> Dict[str, Any]:
    """Return the same status dictionary as ``get_model_status``."""
    cache_status = processor.get_status()
    return cache_status
# For backward compatibility: alias of the shared processor's cancel event,
# so setting or clearing it acts on the very same flag.
PROCESS_CANCELLED: threading.Event = processor.cancel_event
def main(server_name: str = "0.0.0.0", server_port: int = 7860, share: bool = True):
    """Build the UI and launch the application.

    The defaults reproduce the previous hard-coded launch settings.

    Args:
        server_name: Address to bind; "0.0.0.0" listens on all interfaces.
        server_port: TCP port of the web server.
        share: Passed to Gradio's ``launch``; True requests a public share
            link. NOTE(review): confirm public exposure is intended.

    Raises:
        Exception: Any startup error is logged with its traceback and
            re-raised.
    """
    try:
        logger.info("Starting Video Background Replacement application")
        logger.info(f"Device: {processor.device}")
        logger.info(f"Two-stage available: {TWO_STAGE_AVAILABLE}")
        # Import and create UI (imported lazily, only when launching)
        from ui_components import create_interface
        demo = create_interface()
        # Launch application
        demo.queue().launch(
            server_name=server_name,
            server_port=server_port,
            share=share,
            show_error=True,
            debug=False
        )
    except Exception as e:
        logger.exception(f"Application startup failed: {e}")
        raise


if __name__ == "__main__":
    main()