MogensR committed
Commit ee38ee4 · 1 Parent(s): d4f305e

Update app.py

Files changed (1):
  app.py +534 -756
app.py CHANGED
@@ -1,38 +1,34 @@
 #!/usr/bin/env python3
 """
-Final Fixed Video Background Replacement
-Uses proper functions from utilities.py to avoid transparency issues
-NEW: Added automatic device detection for Hugging Face Spaces compatibility,
-improved error handling, and better resource management
-FIXED: All issues identified by Grok4 - robust error handling, variable scope, codec fallbacks
-FIXED: Added SSR mode disable for Gradio compatibility
-FIXED: Audio preservation - no more missing audio in processed videos
-UPDATE: Enhanced logging for initialization errors, isolated matanyone imports to avoid GUI crashes
 """
 import cv2
 import numpy as np
-from pathlib import Path
 import torch
-import traceback
 import time
-import shutil
-import gc
 import threading
 import subprocess
-from typing import Optional, Tuple, Dict, Any
-import logging
-from huggingface_hub import hf_hub_download
-import os
 
-# ============================================================================ #
-# CRITICAL: GRADIO SCHEMA VALIDATION FIX - MUST BE FIRST
-# ============================================================================ #
 try:
     import gradio_client.utils as gc_utils
     original_get_type = gc_utils.get_type
 
     def patched_get_type(schema):
-        """Fixed get_type function that handles boolean schemas properly"""
         if not isinstance(schema, dict):
             if isinstance(schema, bool):
                 return "boolean"
@@ -40,843 +36,625 @@ def patched_get_type(schema):
                 return "string"
             if isinstance(schema, (int, float)):
                 return "number"
-            return "string"  # fallback
-
-        # If it's a dict, use original function
         return original_get_type(schema)
 
     gc_utils.get_type = patched_get_type
-    print("✅ CRITICAL: Gradio schema patch applied successfully!")
-
-except (ImportError, AttributeError) as e:
-    print(f"❌ CRITICAL: Gradio patch failed: {e}")
     logger.error(f"Gradio patch failed: {e}")
 
-# Import utilities - CRITICAL: Use these functions, don't duplicate!
 from utilities import (
     segment_person_hq,
     refine_mask_hq,
-    enhance_mask_opencv,
     replace_background_hq,
     create_professional_background,
     PROFESSIONAL_BACKGROUNDS,
     validate_video_file
 )
 
-# Import two-stage processor if available
 try:
     from two_stage_processor import TwoStageProcessor, CHROMA_PRESETS
     TWO_STAGE_AVAILABLE = True
-    logger.info("Two-stage processor available")
-except ImportError as e:
     TWO_STAGE_AVAILABLE = False
-    logger.warning(f"Two-stage processor not available: {e}")
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# ============================================================================ #
-# OPTIMIZATION SETTINGS
-# ============================================================================ #
-KEYFRAME_INTERVAL = 5  # Process MatAnyone every 5th frame
-FRAME_SKIP = 1  # Process every frame (set to 2 for every other frame)
-MEMORY_CLEANUP_INTERVAL = 30  # Clean memory every 30 frames
-
-# ============================================================================ #
-# MODEL CACHING SYSTEM
-# ============================================================================ #
-CACHE_DIR = Path("/tmp/model_cache")
-CACHE_DIR.mkdir(exist_ok=True, parents=True)
-
-# ============================================================================ #
-# GLOBAL MODEL STATE
-# ============================================================================ #
-sam2_predictor = None
-matanyone_model = None
-models_loaded = False
-loading_lock = threading.Lock()
-two_stage_processor = None
-PROCESS_CANCELLED = threading.Event()
-
-# ============================================================================ #
-# DEVICE DETECTION FOR HUGGING FACE SPACES - ROBUST
-# ============================================================================ #
-def get_device():
-    """Automatically detect the best available device (CPU or GPU) with robust error handling"""
-    try:
         if torch.cuda.is_available():
             try:
-                device_name = torch.cuda.get_device_name(0)
-                logger.info(f"Using GPU: {device_name}")
-            except Exception as e:
-                logger.warning(f"Could not get GPU name: {e}, but CUDA is available")
-                device_name = "CUDA GPU"
-
-            try:
                 test_tensor = torch.tensor([1.0], device='cuda')
                 del test_tensor
                 torch.cuda.empty_cache()
-                return torch.device("cuda")
             except Exception as e:
-                logger.error(f"CUDA test failed: {e}, falling back to CPU")
-                return torch.device("cpu")
-        else:
-            logger.info("Using CPU (no GPU available)")
-            return torch.device("cpu")
-    except Exception as e:
-        logger.error(f"Device detection failed: {e}, defaulting to CPU")
         return torch.device("cpu")
 
-# Set the device globally
-DEVICE = get_device()
-
-# ============================================================================ #
-# ROBUST FFMPEG OPERATIONS
-# ============================================================================ #
-def run_ffmpeg_command(command_args: list, description: str = "FFmpeg operation") -> bool:
-    """Run ffmpeg command with proper error handling"""
-    try:
-        logger.info(f"Running {description}: {' '.join(command_args)}")
-        result = subprocess.run(
-            command_args,
-            check=True,
-            capture_output=True,
-            text=True,
-            timeout=300  # 5 minute timeout
         )
-        logger.info(f"{description} completed successfully")
-        return True
-    except subprocess.CalledProcessError as e:
-        logger.error(f"{description} failed with exit code {e.returncode}")
-        logger.error(f"STDERR: {e.stderr}")
-        return False
-    except subprocess.TimeoutExpired:
-        logger.error(f"{description} timed out")
-        return False
-    except Exception as e:
-        logger.error(f"{description} failed: {e}")
-        return False
 
-# ============================================================================ #
-# ROBUST VIDEO WRITER WITH CODEC FALLBACK
-# ============================================================================ #
-def create_video_writer(output_path: str, fps: float, width: int, height: int) -> Tuple[Optional[cv2.VideoWriter], Optional[str]]:
-    """Create video writer with codec fallback"""
-    codecs_to_try = [
-        ('mp4v', '.mp4'),  # Most compatible
-        ('avc1', '.mp4'),  # H.264 if available
-        ('XVID', '.avi'),  # Fallback
-    ]
-
-    for fourcc_str, ext in codecs_to_try:
-        try:
-            fourcc = cv2.VideoWriter_fourcc(*fourcc_str)
-            if not output_path.endswith(ext):
-                base = os.path.splitext(output_path)[0]
-                test_path = base + ext
-            else:
-                test_path = output_path
-
-            writer = cv2.VideoWriter(test_path, fourcc, fps, (width, height))
-            if writer.isOpened():
-                logger.info(f"Successfully created video writer with {fourcc_str} codec")
-                return writer, test_path
-            else:
-                writer.release()
-        except Exception as e:
-            logger.warning(f"Failed to create writer with {fourcc_str}: {e}")
-
-    logger.error("All video codecs failed")
-    return None, None
-
-# ============================================================================ #
-# SAM2 LOADER WITH VALIDATION - ROBUST
-# ============================================================================ #
-def load_sam2_predictor_fixed(device: torch.device = DEVICE, progress_callback: Optional[callable] = None) -> Any:
-    """Load SAM2 with proper error handling and validation"""
-    def _prog(pct: float, desc: str):
         if progress_callback:
-            progress_callback(pct, desc)
-
-        if "Frame" in desc and "|" in desc:
-            parts = desc.split("|")
-            frame_info = parts[0].strip() if len(parts) > 0 else ""
-            time_info = parts[1].strip() if len(parts) > 1 else ""
-            fps_info = parts[2].strip() if len(parts) > 2 else ""
-            eta_info = parts[3].strip() if len(parts) > 3 else ""
-            display_text = f"""📊 PROCESSING STATUS
-━━━━━━━━━━━━━━━━━━━━━━━━━━
-🎬 {frame_info}
-⏱️ Elapsed: {time_info}
-⚡ Speed: {fps_info}
-🎯 {eta_info}
-━━━━━━━━━━━━━━━━━━━━━━━━━━
-📈 Progress: {pct*100:.1f}%"""
-            try:
-                with open("/tmp/processing_info.txt", 'w') as f:
-                    f.write(display_text)
-            except Exception as e:
-                logger.warning(f"Error writing processing info: {e}")
-
-    try:
-        _prog(0.1, "Initializing SAM2...")
-
-        hf_token = os.getenv('HF_TOKEN')
-        if not hf_token:
-            logger.warning("No HF_TOKEN found, downloads may be rate limited")
-
         try:
             checkpoint_path = hf_hub_download(
                 repo_id="facebook/sam2-hiera-large",
                 filename="sam2_hiera_large.pt",
-                cache_dir=str(CACHE_DIR / "sam2_checkpoint"),
-                force_download=False,
-                token=hf_token
             )
-            logger.info(f"SAM2 checkpoint downloaded to {checkpoint_path}")
-        except Exception as e:
-            logger.error(f"Failed to download SAM2 checkpoint: {e}")
-            raise Exception(f"SAM2 checkpoint download failed: {e}")
-
-        try:
-            from sam2.build_sam import build_sam2
-            from sam2.sam2_image_predictor import SAM2ImagePredictor
-            logger.info("SAM2 modules imported successfully")
-        except ImportError as e:
-            logger.error(f"SAM2 import failed: {e}")
-            raise Exception(f"SAM2 import failed: {e}. Make sure SAM2 is properly installed.")
-
-        try:
             sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
-            sam2_model.to(device)
             sam2_model.eval()
             predictor = SAM2ImagePredictor(sam2_model)
-            logger.info(f"SAM2 model built and moved to {device}")
-        except Exception as e:
-            logger.error(f"SAM2 model creation failed: {e}")
-            raise Exception(f"SAM2 model creation failed: {e}")
-
-        _prog(0.8, "Testing SAM2 functionality...")
-        test_image = np.zeros((256, 256, 3), dtype=np.uint8)
-        predictor.set_image(test_image)
-
-        test_points = np.array([[128.0, 128.0]], dtype=np.float32)
-        test_labels = np.array([1], dtype=np.int32)
-
-        try:
             with torch.no_grad():
                 masks, scores, _ = predictor.predict(
                     point_coords=test_points,
                     point_labels=test_labels,
                     multimask_output=False
                 )
         except Exception as e:
-            logger.error(f"SAM2 prediction test failed: {e}")
-            raise Exception(f"SAM2 prediction test failed: {e}")
-
-        if masks is None or len(masks) == 0:
-            logger.error("SAM2 predictor test failed - no masks generated")
-            raise Exception("SAM2 predictor test failed - no masks generated")
-
-        _prog(1.0, "SAM2 loaded and validated successfully!")
-        logger.info(f"SAM2 predictor loaded and tested successfully on {device}")
-        return predictor
-
-    except Exception as e:
-        logger.error(f"SAM2 loading failed: {str(e)}")
-        logger.error(f"Full traceback: {traceback.format_exc()}")
-        raise Exception(f"SAM2 loading failed: {str(e)}")
-
-# ============================================================================ #
-# MATANYONE LOADER WITH VALIDATION - ROBUST
-# ============================================================================ #
-def load_matanyone_fixed(progress_callback: Optional[callable] = None) -> Any:
-    """Load MatAnyone with proper error handling and validation"""
-    def _prog(pct: float, desc: str):
         if progress_callback:
-            progress_callback(pct, desc)
-
-    try:
-        _prog(0.2, "Loading MatAnyone...")
-
         try:
             from matanyone import InferenceCore
-            logger.info("Successfully imported MatAnyone InferenceCore")
-        except ImportError as e:
-            logger.error(f"MatAnyone import failed: {e}")
-            raise Exception(f"MatAnyone import failed: {e}. Make sure MatAnyone is properly installed.")
-
-        try:
             processor = InferenceCore("PeiqingYang/MatAnyone")
-            logger.info("MatAnyone InferenceCore initialized")
-        except Exception as e:
-            logger.error(f"MatAnyone model loading failed: {e}")
-            raise Exception(f"MatAnyone model loading failed: {e}")
-
-        _prog(0.8, "Testing MatAnyone functionality...")
-        test_image = np.zeros((256, 256, 3), dtype=np.uint8)
-        test_mask = np.zeros((256, 256), dtype=np.uint8)
-        test_mask[64:192, 64:192] = 255
-
-        try:
-            if hasattr(processor, 'process') or hasattr(processor, '__call__'):
-                logger.info("MatAnyone processor interface detected")
-            else:
-                logger.warning("MatAnyone interface unclear, will use fallback refinement")
-        except Exception as test_e:
-            logger.warning(f"MatAnyone test failed: {test_e}, will use enhanced OpenCV")
-
-        _prog(1.0, "MatAnyone loaded successfully!")
-        logger.info(f"MatAnyone processor loaded successfully on {DEVICE}")
-        return processor
-
-    except Exception as e:
-        logger.error(f"MatAnyone loading failed: {str(e)}")
-        logger.error(f"Full traceback: {traceback.format_exc()}")
-        raise Exception(f"MatAnyone loading failed: {str(e)}")
-
-# ============================================================================ #
-# MODEL MANAGEMENT FUNCTIONS
-# ============================================================================ #
-def get_model_status() -> Dict[str, str]:
-    """Return current model status for UI"""
-    global sam2_predictor, matanyone_model, models_loaded
-    return {
-        'sam2': 'Ready' if sam2_predictor is not None else 'Not loaded',
-        'matanyone': 'Ready' if matanyone_model is not None else 'Not loaded',
-        'validated': models_loaded,
-        'device': str(DEVICE)
-    }
-
-def get_cache_status() -> Dict[str, Any]:
-    """Get current cache status"""
-    return {
-        "sam2_loaded": sam2_predictor is not None,
-        "matanyone_loaded": matanyone_model is not None,
-        "models_validated": models_loaded,
-        "two_stage_available": TWO_STAGE_AVAILABLE,
-        "device": str(DEVICE)
-    }
-
-def load_models_with_validation(progress_callback: Optional[callable] = None) -> str:
-    """Load models with comprehensive validation"""
-    global sam2_predictor, matanyone_model, models_loaded, two_stage_processor, PROCESS_CANCELLED
-
-    with loading_lock:
-        if models_loaded and not PROCESS_CANCELLED.is_set():
-            logger.info("Models already loaded and validated")
-            return "Models already loaded and validated"
-
-        try:
-            PROCESS_CANCELLED.clear()
-            start_time = time.time()
-            logger.info(f"Starting model loading on {DEVICE}")
-
             if progress_callback:
-                progress_callback(0.0, f"Starting model loading on {DEVICE}...")
-
-            sam2_predictor = load_sam2_predictor_fixed(device=DEVICE, progress_callback=progress_callback)
-
-            if PROCESS_CANCELLED.is_set():
-                logger.info("Model loading cancelled by user")
-                return "Model loading cancelled by user"
-
-            matanyone_model = load_matanyone_fixed(progress_callback=progress_callback)
-
-            if PROCESS_CANCELLED.is_set():
-                logger.info("Model loading cancelled by user")
-                return "Model loading cancelled by user"
-
-            models_loaded = True
-
-            if TWO_STAGE_AVAILABLE:
-                try:
-                    two_stage_processor = TwoStageProcessor(sam2_predictor, matanyone_model)
-                    logger.info("Two-stage processor initialized")
-                except Exception as e:
-                    logger.warning(f"Two-stage processor initialization failed: {e}")
-                    TWO_STAGE_AVAILABLE = False
-
-            load_time = time.time() - start_time
-            message = f"SUCCESS: SAM2 + MatAnyone loaded and validated in {load_time:.1f}s on {DEVICE}"
-            if TWO_STAGE_AVAILABLE:
-                message += " (Two-stage mode available)"
-            logger.info(message)
-            return message
-
         except Exception as e:
-            models_loaded = False
-            error_msg = f"Model loading failed: {str(e)}"
-            logger.error(error_msg)
-            return error_msg
-
-# ============================================================================ #
-# MAIN VIDEO PROCESSING - USING UTILITIES FUNCTIONS - ROBUST
-# ============================================================================ #
-def process_video_fixed(
-    video_path: str,
-    background_choice: str,
-    custom_background_path: Optional[str],
-    progress_callback: Optional[callable] = None,
-    use_two_stage: bool = False,
-    chroma_preset: str = "standard",
-    preview_mask: bool = False,
-    preview_greenscreen: bool = False
-) -> Tuple[Optional[str], str]:
-    """Optimized video processing using proper functions from utilities - ROBUST VERSION"""
-    global PROCESS_CANCELLED
-
-    if PROCESS_CANCELLED.is_set():
-        logger.info("Processing cancelled by user")
-        return None, "Processing cancelled by user"
-
-    if not models_loaded:
-        logger.error("Models not loaded")
-        return None, "Models not loaded. Call load_models_with_validation() first."
-
-    if not video_path or not os.path.exists(video_path):
-        logger.error(f"Video file not found: {video_path}")
-        return None, f"Video file not found: {video_path}"
-
-    is_valid, validation_msg = validate_video_file(video_path)
-    if not is_valid:
-        logger.error(f"Invalid video: {validation_msg}")
-        return None, f"Invalid video: {validation_msg}"
-
-    def _prog(pct: float, desc: str):
-        if PROCESS_CANCELLED.is_set():
-            raise Exception("Processing cancelled by user")
-
-        if progress_callback:
-            progress_callback(pct, desc)
-
-        if "Frame" in desc and "|" in desc:
-            parts = desc.split("|")
-            frame_info = parts[0].strip() if len(parts) > 0 else ""
-            time_info = parts[1].strip() if len(parts) > 1 else ""
-            fps_info = parts[2].strip() if len(parts) > 2 else ""
-            eta_info = parts[3].strip() if len(parts) > 3 else ""
-
-            display_text = f"""📊 PROCESSING STATUS
-━━━━━━━━━━━━━━━━━━━━━━━━━━
-🎬 {frame_info}
-⏱️ Elapsed: {time_info}
-⚡ Speed: {fps_info}
-🎯 {eta_info}
-━━━━━━━━━━━━━━━━━━━━━━━━━━
-📈 Progress: {pct*100:.1f}%"""
-            try:
-                with open("/tmp/processing_info.txt", 'w') as f:
-                    f.write(display_text)
-            except Exception as e:
-                logger.warning(f"Error writing processing info: {e}")
-
-    try:
-        _prog(0.0, f"Starting {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'} processing on {DEVICE}...")
-
-        if use_two_stage:
-            if not TWO_STAGE_AVAILABLE:
-                logger.error("Two-stage mode not available")
-                return None, "Two-stage mode not available. Please add two_stage_processor.py file."
-
-            if two_stage_processor is None:
-                logger.error("Two-stage processor not initialized")
-                return None, "Two-stage processor not initialized. Please reload models."
-
-            _prog(0.05, "Starting TWO-STAGE green screen processing...")
-
-            cap = cv2.VideoCapture(video_path)
-            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-            cap.release()
-
-            if background_choice == "custom" and custom_background_path:
-                if not os.path.exists(custom_background_path):
-                    logger.error(f"Custom background not found: {custom_background_path}")
-                    return None, f"Custom background not found: {custom_background_path}"
-
-                background = cv2.imread(custom_background_path)
-                if background is None:
-                    logger.error("Could not read custom background image")
-                    return None, "Could not read custom background image."
-                background_name = "Custom Image"
             else:
-                if background_choice in PROFESSIONAL_BACKGROUNDS:
-                    bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
-                    background = create_professional_background(bg_config, frame_width, frame_height)
-                    background_name = bg_config["name"]
-                else:
-                    logger.error(f"Invalid background selection: {background_choice}")
-                    return None, f"Invalid background selection: {background_choice}"
-
-            chroma_settings = CHROMA_PRESETS.get(chroma_preset, CHROMA_PRESETS['standard'])
-
-            timestamp = int(time.time())
-            final_output = f"/tmp/twostage_final_{timestamp}.mp4"
-
-            result, message = two_stage_processor.process_full_pipeline(
-                video_path,
-                background,
-                final_output,
-                chroma_settings=chroma_settings,
-                progress_callback=_prog
-            )
-
-            if PROCESS_CANCELLED.is_set():
-                logger.info("Processing cancelled by user")
-                return None, "Processing cancelled by user"
-
-            if result is None:
-                logger.error(f"Two-stage processing failed: {message}")
-                return None, message
-
-            _prog(0.9, "Adding audio...")
-            final_with_audio = f"/tmp/twostage_audio_{timestamp}.mp4"
-
-            audio_check_success = run_ffmpeg_command([
-                'ffprobe', '-v', 'quiet', '-select_streams', 'a:0',
-                '-show_entries', 'stream=codec_name', '-of', 'csv=p=0', video_path
-            ], "Checking for audio stream")
-
-            if audio_check_success:
-                audio_success = run_ffmpeg_command([
-                    'ffmpeg', '-y', '-i', final_output, '-i', video_path,
-                    '-c:v', 'copy',
-                    '-c:a', 'aac', '-b:a', '192k', '-ac', '2', '-ar', '48000',
-                    '-map', '0:v:0', '-map', '1:a:0', '-shortest', final_with_audio
-                ], "Two-stage audio processing with original audio")
-
-                if not audio_success or not os.path.exists(final_with_audio):
-                    logger.warning("Failed with original audio, trying fallback method...")
-                    audio_success = run_ffmpeg_command([
-                        'ffmpeg', '-y', '-i', video_path, '-i', final_output,
-                        '-c:v', 'libx264', '-crf', '18', '-preset', 'fast',
-                        '-c:a', 'copy',
-                        '-map', '1:v:0', '-map', '0:a:0', '-shortest', final_with_audio
-                    ], "Fallback two-stage audio processing")
-            else:
-                logger.info("Input video has no audio stream")
-                try:
-                    shutil.copy2(final_output, final_with_audio)
-                    audio_success = True
-                except Exception as e:
-                    logger.error(f"Failed to copy video: {e}")
-                    audio_success = False
-                    final_with_audio = final_output
-
-            if audio_success and os.path.exists(final_with_audio):
-                try:
-                    os.remove(final_output)
-                except:
-                    pass
-                final_output = final_with_audio
-            else:
-                logger.warning("Audio processing failed, using video without audio")
-
-            _prog(1.0, "TWO-STAGE processing complete!")
-
-            success_message = (
-                f"TWO-STAGE Success!\n"
-                f"Background: {background_name}\n"
-                f"Method: Green Screen Chroma Key\n"
-                f"Preset: {chroma_preset}\n"
-                f"Quality: Professional cinema-grade\n"
-                f"Device: {DEVICE}"
-            )
-
-            return final_output, success_message
-
-        _prog(0.05, f"Starting SINGLE-STAGE processing on {DEVICE}...")
-
         cap = cv2.VideoCapture(video_path)
         if not cap.isOpened():
-            logger.error("Could not open video file")
-            return None, "Could not open video file."
-
         fps = cap.get(cv2.CAP_PROP_FPS)
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
         frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
         frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
-        if total_frames == 0:
-            cap.release()
-            logger.error("Video appears to be empty")
-            return None, "Video appears to be empty."
-
-        logger.info(f"Video info: {frame_width}x{frame_height}, {fps}fps, {total_frames} frames, processing on {DEVICE}")
-
-        background = None
-        background_name = ""
-
-        if background_choice == "custom" and custom_background_path:
-            if not os.path.exists(custom_background_path):
-                cap.release()
-                logger.error(f"Custom background not found: {custom_background_path}")
-                return None, f"Custom background not found: {custom_background_path}"
-
-            background = cv2.imread(custom_background_path)
-            if background is None:
-                cap.release()
-                logger.error("Could not read custom background image")
-                return None, "Could not read custom background image."
-            background_name = "Custom Image"
-        else:
-            if background_choice in PROFESSIONAL_BACKGROUNDS:
-                bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
-                background = create_professional_background(bg_config, frame_width, frame_height)
-                background_name = bg_config["name"]
-            else:
-                cap.release()
-                logger.error(f"Invalid background selection: {background_choice}")
-                return None, f"Invalid background selection: {background_choice}"
-
         if background is None:
             cap.release()
-            logger.error("Failed to create background")
-            return None, "Failed to create background."
-
         timestamp = int(time.time())
-
-        _prog(0.1, f"Processing {total_frames} frames with {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'} processing on {DEVICE}...")
-
-        if preview_mask or preview_greenscreen:
-            output_path = f"/tmp/preview_{timestamp}.mp4"
-        else:
-            output_path = f"/tmp/output_{timestamp}.mp4"
-
-        final_writer, actual_output_path = create_video_writer(output_path, fps, frame_width, frame_height)
-        if final_writer is None:
             cap.release()
-            logger.error("Could not create output video file with any codec")
-            return None, "Could not create output video file with any codec."
-
-        output_path = actual_output_path
-
         frame_count = 0
         successful_frames = 0
         last_refined_mask = None
-
-        start_time = time.time()
-
         try:
             while True:
-                if PROCESS_CANCELLED.is_set():
                     break
-
                 ret, frame = cap.read()
                 if not ret:
                     break
-
-                if frame_count % FRAME_SKIP != 0:
-                    frame_count += 1
-                    continue
-
                 try:
-                    elapsed_time = time.time() - start_time
-                    current_fps = frame_count / elapsed_time if elapsed_time > 0 else 0
-                    remaining_frames = total_frames - frame_count
-                    eta_seconds = remaining_frames / current_fps if current_fps > 0 else 0
-                    eta_display = f"{int(eta_seconds//60)}m {int(eta_seconds%60)}s" if eta_seconds > 60 else f"{int(eta_seconds)}s"
-
-                    progress_msg = f"Frame {frame_count + 1}/{total_frames} | {elapsed_time:.1f}s | {current_fps:.1f} fps | ETA: {eta_display} | Device: {DEVICE}"
-
-                    logger.info(progress_msg)
-                    pct = min(1.0, 0.1 + (frame_count / max(1, total_frames)) * 0.8)
-                    _prog(pct, progress_msg)
-
-                    mask = segment_person_hq(frame, sam2_predictor)
-
-                    if preview_mask:
-                        mask_vis = np.zeros_like(frame)
-                        mask_vis[..., 1] = mask
-                        final_writer.write(mask_vis.astype(np.uint8))
-                        frame_count += 1
-                        continue
-
-                    if (frame_count % KEYFRAME_INTERVAL == 0) or (last_refined_mask is None):
-                        refined_mask = refine_mask_hq(frame, mask, matanyone_model)
                         last_refined_mask = refined_mask.copy()
-                        logger.info(f"Keyframe refinement at frame {frame_count} on {DEVICE}")
                     else:
                         alpha = 0.7
                         refined_mask = cv2.addWeighted(mask, alpha, last_refined_mask, 1-alpha, 0)
-
-                    if preview_greenscreen:
-                        green_bg = np.zeros_like(frame)
-                        green_bg[:, :] = [0, 255, 0]
-                        preview_frame = frame.copy()
-                        mask_3ch = cv2.cvtColor(refined_mask, cv2.COLOR_GRAY2BGR)
-                        mask_norm = mask_3ch.astype(float) / 255
-                        preview_frame = preview_frame * mask_norm + green_bg * (1 - mask_norm)
-                        final_writer.write(preview_frame.astype(np.uint8))
-                        frame_count += 1
-                        continue
-
-                    result_frame = replace_background_hq(frame, refined_mask, background)
-                    final_writer.write(result_frame.astype(np.uint8))
                     successful_frames += 1
-
                 except Exception as frame_error:
-                    logger.warning(f"Error processing frame {frame_count}: {frame_error}")
-                    final_writer.write(frame)
-
                 frame_count += 1
-
-                if frame_count % MEMORY_CLEANUP_INTERVAL == 0:
-                    gc.collect()
-                    if DEVICE.type == 'cuda':
-                        torch.cuda.empty_cache()
-                    elapsed = time.time() - start_time
-                    fps_actual = frame_count / elapsed
-                    eta = (total_frames - frame_count) / fps_actual if fps_actual > 0 else 0
-                    logger.info(f"Progress: {frame_count}/{total_frames}, FPS: {fps_actual:.1f}, ETA: {eta:.0f}s, Device: {DEVICE}")
-
         finally:
             cap.release()
-            final_writer.release()
-
-        if PROCESS_CANCELLED.is_set():
-            _prog(0.95, "Cleaning up cancelled process...")
             try:
-                if os.path.exists(output_path):
-                    os.remove(output_path)
             except:
                 pass
-            logger.info("Processing cancelled by user")
-            return None, "Processing cancelled by user"
-
         if successful_frames == 0:
-            logger.error("No frames were processed successfully with AI")
-            return None, "No frames were processed successfully with AI."
-
-        total_time = time.time() - start_time
-        avg_fps = frame_count / total_time if total_time > 0 else 0
-
-        _prog(0.9, "Finalizing output...")
-
-        if preview_mask or preview_greenscreen:
            final_output = output_path
        else:
-            _prog(0.9, "Adding audio...")
-            final_output = f"/tmp/final_{timestamp}.mp4"
-
-            audio_check_success = run_ffmpeg_command([
-                'ffprobe', '-v', 'quiet', '-select_streams', 'a:0',
-                '-show_entries', 'stream=codec_name', '-of', 'csv=p=0', video_path
-            ], "Checking for audio stream")
-
-            if audio_check_success:
-                audio_success = run_ffmpeg_command([
-                    'ffmpeg', '-y', '-i', output_path, '-i', video_path,
-                    '-c:v', 'copy',
-                    '-c:a', 'aac', '-b:a', '192k', '-ac', '2', '-ar', '48000',
-                    '-map', '0:v:0', '-map', '1:a:0', '-shortest', final_output
-                ], "Audio processing with original audio")
-
-                if not audio_success or not os.path.exists(final_output):
-                    logger.warning("Failed with original audio, trying fallback method...")
-                    audio_success = run_ffmpeg_command([
-                        'ffmpeg', '-y', '-i', video_path, '-i', output_path,
-                        '-c:v', 'libx264', '-crf', '18', '-preset', 'fast',
-                        '-c:a', 'copy',
-                        '-map', '1:v:0', '-map', '0:a:0', '-shortest', final_output
-                    ], "Fallback audio processing")
-            else:
-                logger.info("Input video has no audio stream")
                 try:
-                    shutil.copy2(output_path, final_output)
-                    audio_success = True
-                except Exception as e:
-                    logger.error(f"Failed to copy video: {e}")
-                    audio_success = False
-                    final_output = output_path
-
-            if not audio_success or not os.path.exists(final_output):
-                logger.warning("All audio processing failed, using video without audio")
-                try:
-                    shutil.copy2(output_path, final_output)
-                except Exception as e:
-                    logger.error(f"Failed to copy video: {e}")
-                    final_output = output_path
-
-            try:
-                if os.path.exists(output_path) and output_path != final_output:
-                    os.remove(output_path)
-            except Exception as e:
-                logger.warning(f"Cleanup error: {e}")
-
-        _prog(1.0, "Processing complete!")
-
-        success_message = (
-            f"Success!\n"
-            f"Background: {background_name}\n"
-            f"Resolution: {frame_width}x{frame_height}\n"
-            f"Total frames: {frame_count}\n"
-            f"Successfully processed: {successful_frames}\n"
-            f"Processing time: {total_time:.1f}s\n"
-            f"Average FPS: {avg_fps:.1f}\n"
-            f"Keyframe interval: {KEYFRAME_INTERVAL}\n"
-            f"Mode: {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'}\n"
-            f"Device: {DEVICE}"
-        )
-
-        return final_output, success_message
-
-    except Exception as e:
-        logger.error(f"Processing error: {traceback.format_exc()}")
-        return None, f"Processing Error: {str(e)}"
-
-# ============================================================================ #
-# MAIN - IMPORT UI COMPONENTS
-# ============================================================================ #
 def main():
     try:
-        print("===== FINAL FIXED VIDEO BACKGROUND REPLACEMENT =====")
-        print(f"Keyframe interval: {KEYFRAME_INTERVAL} frames")
-        print(f"Frame skip: {FRAME_SKIP} (1=all frames, 2=every other)")
-        print(f"Two-stage mode: {'AVAILABLE' if TWO_STAGE_AVAILABLE else 'NOT AVAILABLE'}")
-        print(f"Device: {DEVICE}")
-        print("Loading UI components...")
-
-        try:
-            from ui_components import create_interface
-            logger.info("Successfully imported ui_components")
-        except ImportError as e:
-            logger.error(f"Failed to import ui_components: {e}")
-            logger.error(f"Full traceback: {traceback.format_exc()}")
-            raise Exception(f"UI components import failed: {e}")
-
-        os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
-        CACHE_DIR.mkdir(exist_ok=True, parents=True)
-
-        print("Creating interface...")
-        try:
-            demo = create_interface()
-            logger.info("Gradio interface created successfully")
-        except Exception as e:
-            logger.error(f"Failed to create Gradio interface: {e}")
-            logger.error(f"Full traceback: {traceback.format_exc()}")
-            raise Exception(f"Gradio interface creation failed: {e}")
-
-        print("Launching...")
-        try:
-            demo.queue().launch(
-                server_name="0.0.0.0",
-                server_port=7860,
-                share=True,
-                show_error=True,
-                debug=True
-            )
-            logger.info("Gradio server launched successfully")
-        except Exception as e:
-            logger.error(f"Gradio launch failed: {e}")
-            logger.error(f"Full traceback: {traceback.format_exc()}")
-            raise Exception(f"Gradio launch failed: {e}")
-
     except Exception as e:
-        logger.error(f"Startup failed: {e}")
-        logger.error(f"Full traceback: {traceback.format_exc()}")
-        print(f"Startup failed: {e}")
         raise
 
 if __name__ == "__main__":
 #!/usr/bin/env python3
 """
+Video Background Replacement - Main Application
+Refactored version with improved error handling, memory management, and configuration
 """
+
+import os
 import cv2
 import numpy as np
 import torch
 import time
+import logging
 import threading
 import subprocess
+from pathlib import Path
+from typing import Optional, Tuple, Dict, Any, Callable
+from dataclasses import dataclass
 
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Apply Gradio schema patch early
 try:
     import gradio_client.utils as gc_utils
     original_get_type = gc_utils.get_type
 
     def patched_get_type(schema):
         if not isinstance(schema, dict):
             if isinstance(schema, bool):
                 return "boolean"
                 return "string"
             if isinstance(schema, (int, float)):
                 return "number"
+            return "string"
         return original_get_type(schema)
 
     gc_utils.get_type = patched_get_type
+    logger.info("Gradio schema patch applied successfully")
+except Exception as e:
     logger.error(f"Gradio patch failed: {e}")
 
+# Import core modules
 from utilities import (
     segment_person_hq,
     refine_mask_hq,
     replace_background_hq,
     create_professional_background,
     PROFESSIONAL_BACKGROUNDS,
     validate_video_file
 )
 
 try:
     from two_stage_processor import TwoStageProcessor, CHROMA_PRESETS
     TWO_STAGE_AVAILABLE = True
+except ImportError:
     TWO_STAGE_AVAILABLE = False
+    CHROMA_PRESETS = {'standard': {}}
+
+# Configuration
+@dataclass
+class ProcessingConfig:
+    keyframe_interval: int = int(os.getenv('KEYFRAME_INTERVAL', '5'))
+    frame_skip: int = int(os.getenv('FRAME_SKIP', '1'))
+    memory_cleanup_interval: int = int(os.getenv('MEMORY_CLEANUP_INTERVAL', '30'))
+    max_video_length: int = int(os.getenv('MAX_VIDEO_LENGTH', '300'))  # seconds
+    quality_preset: str = os.getenv('QUALITY_PRESET', 'balanced')
+
+class DeviceManager:
+    """Manage device detection and switching"""
+
+    @staticmethod
+    def get_optimal_device():
         if torch.cuda.is_available():
             try:
+                # Test CUDA functionality
                 test_tensor = torch.tensor([1.0], device='cuda')
                 del test_tensor
                 torch.cuda.empty_cache()
+                device = torch.device("cuda")
+                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
+                return device
             except Exception as e:
+                logger.warning(f"CUDA test failed: {e}, falling back to CPU")
+
+        logger.info("Using CPU device")
         return torch.device("cpu")
 
+class MemoryManager:
+    """Enhanced memory management"""
+
+    def __init__(self, device):
+        self.device = device
+        self.gpu_available = device.type == 'cuda'
+
+    def cleanup_aggressive(self):
+        import gc
+        gc.collect()
+        if self.gpu_available:
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+    def get_memory_usage(self):
+        usage = {}
+        if self.gpu_available:
+            gpu_memory = torch.cuda.get_device_properties(0).total_memory
+            gpu_allocated = torch.cuda.memory_allocated(0)
+            usage['gpu_percent'] = (gpu_allocated / gpu_memory) * 100
+            usage['gpu_allocated_gb'] = gpu_allocated / (1024**3)
+        return usage
+
+class ProgressTracker:
+    """Enhanced progress tracking with detailed statistics"""
+
+    def __init__(self, total_frames: int, callback: Optional[Callable] = None):
+        self.total_frames = total_frames
+        self.callback = callback
+        self.start_time = time.time()
+        self.processed_frames = 0
+        self.frame_times = []
+
+    def update(self, frame_number: int, stage: str = ""):
+        current_time = time.time()
+        self.processed_frames = frame_number
+
+        elapsed_time = current_time - self.start_time
+        current_fps = self.processed_frames / elapsed_time if elapsed_time > 0 else 0
+
+        remaining_frames = self.total_frames - self.processed_frames
+        eta_seconds = remaining_frames / current_fps if current_fps > 0 else 0
+
+        progress_pct = self.processed_frames / self.total_frames if self.total_frames > 0 else 0
+
+        message = (
+            f"Frame {self.processed_frames}/{self.total_frames} | "
+            f"Elapsed: {self._format_time(elapsed_time)} | "
+            f"Speed: {current_fps:.1f} fps | "
+            f"ETA: {self._format_time(eta_seconds)}"
        )
+
+        if stage:
+            message = f"{stage} | {message}"
+
+        if self.callback:
+            try:
+                self.callback(progress_pct, message)
+            except Exception as e:
+                logger.warning(f"Progress callback failed: {e}")
+
+    def _format_time(self, seconds: float) -> str:
+        if seconds < 60:
+            return f"{int(seconds)}s"
+        elif seconds < 3600:
+            return f"{int(seconds//60)}m {int(seconds%60)}s"
+        else:
+            hours = int(seconds // 3600)
+            minutes = int((seconds % 3600) // 60)
+            return f"{hours}h {minutes}m"
 
+class VideoProcessor:
+    """Main video processing class with error recovery"""
 
+    def __init__(self):
+        self.device = DeviceManager.get_optimal_device()
+        self.memory_manager = MemoryManager(self.device)
+        self.config = ProcessingConfig()
+        self.sam2_predictor = None
+        self.matanyone_model = None
+        self.two_stage_processor = None
+        self.models_loaded = False
+        self.loading_lock = threading.Lock()
+        self.cancel_event = threading.Event()
+
+    def load_models(self, progress_callback: Optional[Callable] = None) -> str:
+        """Load AI models with comprehensive validation"""
+        with self.loading_lock:
+            if self.models_loaded:
+                return "Models already loaded and validated"
+
+            try:
+                self.cancel_event.clear()
+                start_time = time.time()
 
+                if progress_callback:
+                    progress_callback(0.0, f"Starting model loading on {self.device}")
+
+                # Load SAM2
+                self.sam2_predictor = self._load_sam2(progress_callback)
+                if self.cancel_event.is_set():
+                    return "Model loading cancelled"
+
+                # Load MatAnyone
+                self.matanyone_model = self._load_matanyone(progress_callback)
+                if self.cancel_event.is_set():
+                    return "Model loading cancelled"
+
+                # Initialize two-stage processor if available
+                if TWO_STAGE_AVAILABLE:
+                    try:
+                        self.two_stage_processor = TwoStageProcessor(
+                            self.sam2_predictor, self.matanyone_model
+                        )
+                        logger.info("Two-stage processor initialized")
+                    except Exception as e:
+                        logger.warning(f"Two-stage processor init failed: {e}")
+
+                self.models_loaded = True
+                load_time = time.time() - start_time
+
+                message = f"Models loaded successfully in {load_time:.1f}s on {self.device}"
+                if TWO_STAGE_AVAILABLE:
+                    message += " (Two-stage mode available)"
+
+                logger.info(message)
+                return message
+
+            except Exception as e:
+                self.models_loaded = False
+                error_msg = f"Model loading failed: {str(e)}"
+                logger.error(error_msg)
+                return error_msg
 
+    def _load_sam2(self, progress_callback: Optional[Callable]) -> Any:
+        """Load SAM2 predictor with validation"""
        if progress_callback:
+            progress_callback(0.1, "Loading SAM2...")
+
        try:
+            from huggingface_hub import hf_hub_download
+            from sam2.build_sam import build_sam2
+            from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+            # Download checkpoint
            checkpoint_path = hf_hub_download(
                repo_id="facebook/sam2-hiera-large",
                filename="sam2_hiera_large.pt",
+                cache_dir=str(Path("/tmp/model_cache/sam2_checkpoint")),
+                force_download=False
            )
+
+            # Build model
            sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
+            sam2_model.to(self.device)
            sam2_model.eval()
            predictor = SAM2ImagePredictor(sam2_model)
+
+            # Validate with test
+            test_image = np.zeros((256, 256, 3), dtype=np.uint8)
+            predictor.set_image(test_image)
+            test_points = np.array([[128.0, 128.0]], dtype=np.float32)
+            test_labels = np.array([1], dtype=np.int32)
+
            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=test_points,
                    point_labels=test_labels,
                    multimask_output=False
                )
+
+            if masks is None or len(masks) == 0:
+                raise Exception("SAM2 validation failed")
+
+            if progress_callback:
+                progress_callback(0.5, "SAM2 loaded and validated")
+
+            return predictor
+
        except Exception as e:
+            logger.error(f"SAM2 loading failed: {e}")
+            raise
+
+    def _load_matanyone(self, progress_callback: Optional[Callable]) -> Any:
+        """Load MatAnyone processor with validation"""
        if progress_callback:
+            progress_callback(0.6, "Loading MatAnyone...")
+
        try:
            from matanyone import InferenceCore
            processor = InferenceCore("PeiqingYang/MatAnyone")
+
            if progress_callback:
+                progress_callback(0.9, "MatAnyone loaded successfully")
+
+            return processor
+
        except Exception as e:
+            logger.warning(f"MatAnyone loading failed: {e}")
+            # Return None to use fallback refinement
+            return None
+
+    def process_video(
+        self,
+        video_path: str,
+        background_choice: str,
+        custom_background_path: Optional[str] = None,
+        progress_callback: Optional[Callable] = None,
+        use_two_stage: bool = False,
+        chroma_preset: str = "standard",
+        preview_mask: bool = False,
+        preview_greenscreen: bool = False
+    ) -> Tuple[Optional[str], str]:
+        """Process video with comprehensive error handling"""
+
+        if not self.models_loaded:
+            return None, "Models not loaded. Please load models first."
+
+        if self.cancel_event.is_set():
+            return None, "Processing cancelled"
+
+        # Validate input
+        is_valid, validation_msg = validate_video_file(video_path)
+        if not is_valid:
+            return None, f"Invalid video: {validation_msg}"
+
+        try:
+            if use_two_stage and TWO_STAGE_AVAILABLE and self.two_stage_processor:
+                return self._process_two_stage(
+                    video_path, background_choice, custom_background_path,
+                    progress_callback, chroma_preset
+                )
            else:
+                return self._process_single_stage(
+                    video_path, background_choice, custom_background_path,
+                    progress_callback, preview_mask, preview_greenscreen
+                )
 
+        except Exception as e:
+            logger.error(f"Video processing failed: {e}")
+            return None, f"Processing failed: {str(e)}"
+
+    def _process_single_stage(
+        self,
+        video_path: str,
+        background_choice: str,
+        custom_background_path: Optional[str],
+        progress_callback: Optional[Callable],
+        preview_mask: bool,
+        preview_greenscreen: bool
+    ) -> Tuple[Optional[str], str]:
+        """Single-stage video processing"""
+
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
+            return None, "Could not open video file"
+
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+        # Prepare background
+        background = self._prepare_background(
+            background_choice, custom_background_path, frame_width, frame_height
+        )
        if background is None:
            cap.release()
+            return None, "Failed to prepare background"
+
+        # Setup output
        timestamp = int(time.time())
+        output_path = f"/tmp/output_{timestamp}.mp4"
+
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
+
+        if not out.isOpened():
            cap.release()
+            return None, "Could not create output video"
+
+        # Process frames
+        progress_tracker = ProgressTracker(total_frames, progress_callback)
        frame_count = 0
        successful_frames = 0
        last_refined_mask = None
+
        try:
            while True:
+                if self.cancel_event.is_set():
                    break
+
                ret, frame = cap.read()
                if not ret:
                    break
+
                try:
+                    progress_tracker.update(frame_count, "Processing")
+
+                    # Segmentation
+                    mask = segment_person_hq(frame, self.sam2_predictor)
+
+                    # Mask refinement (keyframe-based)
+                    if (frame_count % self.config.keyframe_interval == 0) or (last_refined_mask is None):
+                        refined_mask = refine_mask_hq(frame, mask, self.matanyone_model)
                        last_refined_mask = refined_mask.copy()
                    else:
+                        # Blend with previous refined mask for temporal consistency
                        alpha = 0.7
                        refined_mask = cv2.addWeighted(mask, alpha, last_refined_mask, 1-alpha, 0)
+
+                    # Generate output based on mode
+                    if preview_mask:
+                        result_frame = self._create_mask_preview(frame, refined_mask)
+                    elif preview_greenscreen:
+                        result_frame = self._create_greenscreen_preview(frame, refined_mask)
+                    else:
+                        result_frame = replace_background_hq(frame, refined_mask, background)
+
+                    out.write(result_frame)
                    successful_frames += 1
+
                except Exception as frame_error:
+                    logger.warning(f"Frame {frame_count} processing failed: {frame_error}")
+                    out.write(frame)  # Write original frame as fallback
+
                frame_count += 1
+
+                # Memory cleanup
+                if frame_count % self.config.memory_cleanup_interval == 0:
+                    self.memory_manager.cleanup_aggressive()
+
        finally:
            cap.release()
+            out.release()
+
+        if self.cancel_event.is_set():
            try:
+                os.remove(output_path)
            except:
                pass
+            return None, "Processing cancelled"
+
        if successful_frames == 0:
+            return None, "No frames processed successfully"
+
+        # Add audio if not preview mode
+        if not (preview_mask or preview_greenscreen):
+            final_output = self._add_audio(video_path, output_path)
+        else:
            final_output = output_path
+
+        success_msg = (
+            f"Success! Processed {successful_frames}/{frame_count} frames\n"
+            f"Background: {background_choice}\n"
+            f"Mode: Single-stage\n"
+            f"Device: {self.device}"
+        )
+
+        return final_output, success_msg
+
+    def _process_two_stage(
+        self,
+        video_path: str,
+        background_choice: str,
+        custom_background_path: Optional[str],
+        progress_callback: Optional[Callable],
+        chroma_preset: str
+    ) -> Tuple[Optional[str], str]:
+        """Two-stage processing using green screen intermediate"""
+
+        cap = cv2.VideoCapture(video_path)
+        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        cap.release()
+
+        # Prepare background
+        background = self._prepare_background(
+            background_choice, custom_background_path, frame_width, frame_height
+        )
+        if background is None:
+            return None, "Failed to prepare background"
+
+        # Process with two-stage pipeline
+        timestamp = int(time.time())
+        final_output = f"/tmp/twostage_final_{timestamp}.mp4"
+
+        chroma_settings = CHROMA_PRESETS.get(chroma_preset, CHROMA_PRESETS['standard'])
+
+        result, message = self.two_stage_processor.process_full_pipeline(
+            video_path,
+            background,
+            final_output,
+            chroma_settings=chroma_settings,
+            progress_callback=progress_callback
+        )
+
+        if result is None:
+            return None, message
+
+        success_msg = (
+            f"Two-stage success!\n"
+            f"Background: {background_choice}\n"
+            f"Preset: {chroma_preset}\n"
+            f"Quality: Cinema-grade\n"
+            f"Device: {self.device}"
+        )
+
+        return result, success_msg
+
+    def _prepare_background(
+        self,
+        background_choice: str,
+        custom_background_path: Optional[str],
+        width: int,
+        height: int
+    ) -> Optional[np.ndarray]:
+        """Prepare background image"""
+
+        if background_choice == "custom" and custom_background_path:
+            if not os.path.exists(custom_background_path):
+                logger.error(f"Custom background not found: {custom_background_path}")
+                return None
+
+            background = cv2.imread(custom_background_path)
+            if background is None:
+                logger.error("Could not read custom background")
+                return None
        else:
+            if background_choice not in PROFESSIONAL_BACKGROUNDS:
+                logger.error(f"Unknown background: {background_choice}")
+                return None
+
+            bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
+            background = create_professional_background(bg_config, width, height)
+
+        return cv2.resize(background, (width, height))
+
+    def _create_mask_preview(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
+        """Create mask preview visualization"""
+        mask_vis = np.zeros_like(frame)
+        mask_vis[..., 1] = mask  # Green channel
+        return mask_vis
+
+    def _create_greenscreen_preview(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
+        """Create green screen preview"""
+        green_bg = np.zeros_like(frame)
+        green_bg[:, :] = [0, 255, 0]  # Pure green
+
+        mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+        mask_norm = mask_3ch.astype(float) / 255
+
+        return (frame * mask_norm + green_bg * (1 - mask_norm)).astype(np.uint8)
+
+    def _add_audio(self, input_video: str, processed_video: str) -> str:
+        """Add audio from original video to processed video"""
+        timestamp = int(time.time())
+        final_output = f"/tmp/final_with_audio_{timestamp}.mp4"
+
+        try:
+            # Check if input has audio
+            result = subprocess.run([
+                'ffprobe', '-v', 'quiet', '-select_streams', 'a:0',
+                '-show_entries', 'stream=codec_name', '-of', 'csv=p=0', input_video
+            ], capture_output=True, text=True, timeout=30)
+
+            if result.returncode != 0:
+                logger.info("Input video has no audio")
+                return processed_video
+
+            # Add audio
+            result = subprocess.run([
+                'ffmpeg', '-y', '-i', processed_video, '-i', input_video,
+                '-c:v', 'copy', '-c:a', 'aac', '-b:a', '192k',
+                '-map', '0:v:0', '-map', '1:a:0', '-shortest', final_output
+            ], capture_output=True, text=True, timeout=300)
+
+            if result.returncode == 0 and os.path.exists(final_output):
                try:
+                    os.remove(processed_video)
+                except:
+                    pass
+                return final_output
+            else:
+                logger.warning("Audio processing failed, using video without audio")
+                return processed_video
+
+        except Exception as e:
+            logger.warning(f"Audio processing error: {e}")
+            return processed_video
+
+    def get_status(self) -> Dict[str, Any]:
+        """Get current processor status"""
+        return {
+            'models_loaded': self.models_loaded,
+            'sam2_available': self.sam2_predictor is not None,
+            'matanyone_available': self.matanyone_model is not None,
+            'two_stage_available': TWO_STAGE_AVAILABLE and self.two_stage_processor is not None,
+            'device': str(self.device),
+            'memory_usage': self.memory_manager.get_memory_usage(),
+            'config': {
+                'keyframe_interval': self.config.keyframe_interval,
+                'quality_preset': self.config.quality_preset
+            }
+        }
+
+    def cancel_processing(self):
+        """Cancel current processing"""
+        self.cancel_event.set()
+        logger.info("Processing cancellation requested")
 
+# Global processor instance
+processor = VideoProcessor()
 
+# Compatibility functions for existing UI
+def load_models_with_validation(progress_callback: Optional[Callable] = None) -> str:
+    return processor.load_models(progress_callback)
 
+def process_video_fixed(
+    video_path: str,
+    background_choice: str,
+    custom_background_path: Optional[str],
+    progress_callback: Optional[Callable] = None,
+    use_two_stage: bool = False,
+    chroma_preset: str = "standard",
+    preview_mask: bool = False,
+    preview_greenscreen: bool = False
+) -> Tuple[Optional[str], str]:
+    return processor.process_video(
+        video_path, background_choice, custom_background_path,
+        progress_callback, use_two_stage, chroma_preset,
+        preview_mask, preview_greenscreen
+    )
 
+def get_model_status() -> Dict[str, Any]:
+    return processor.get_status()
 
+def get_cache_status() -> Dict[str, Any]:
+    return processor.get_status()
 
+# For backward compatibility
+PROCESS_CANCELLED = processor.cancel_event
 
 def main():
+    """Main application entry point"""
     try:
+        logger.info("Starting Video Background Replacement application")
+        logger.info(f"Device: {processor.device}")
+        logger.info(f"Two-stage available: {TWO_STAGE_AVAILABLE}")
+
+        # Import and create UI
+        from ui_components import create_interface
+        demo = create_interface()
+
+        # Launch application
+        demo.queue().launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            share=True,
+            show_error=True,
+            debug=False
+        )
+
    except Exception as e:
+        logger.error(f"Application startup failed: {e}")
        raise
 
  if __name__ == "__main__":
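
For reviewers who want to exercise the refactored entry points, a minimal usage sketch follows. It is not part of the commit: it assumes the new file ships as app.py alongside its dependencies (utilities.py, ui_components.py, SAM2, MatAnyone), and "office" is a hypothetical key standing in for a real PROFESSIONAL_BACKGROUNDS entry.

# Usage sketch only -- not part of this commit.
from app import load_models_with_validation, process_video_fixed, get_model_status

def report(pct: float, msg: str) -> None:
    # Progress callbacks receive a completion fraction in [0, 1] and a status string.
    print(f"[{pct:5.1%}] {msg}")

print(load_models_with_validation(progress_callback=report))
print(get_model_status())

# Single-stage replacement with a built-in background.
output_path, status = process_video_fixed(
    video_path="input.mp4",
    background_choice="office",  # placeholder background key
    custom_background_path=None,
    progress_callback=report,
)
print(status)
print(output_path)

Long-running jobs can be stopped from another thread via processor.cancel_processing(); the module-level PROCESS_CANCELLED alias exposes the same threading.Event for older UI code.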