MogensR committed on
Commit d008b27 · verified · 1 Parent(s): ae33908

Create test_positioning.py

Files changed (1)
  1. test_positioning.py +1502 -0
test_positioning.py ADDED
@@ -0,0 +1,1502 @@
#!/usr/bin/env python3
# =============================================================================
# CHAPTER 0: INTRO & OVERVIEW
# =============================================================================
"""
Enhanced Video Background Replacement (SAM2 + MatAnyone + AI Backgrounds)
- Strict tensor shapes for MatAnyone (image: 3xHxW, first-frame prob mask: 1xHxW)
- First frame uses PROB path (no idx_mask / objects) to avoid assertion
- Memory management & cleanup
- SDXL / Playground / OpenAI backgrounds
- Gradio UI with "CHAPTER" dividers
- FIXED: Enhanced positioning with debug logging and coordinate precision
"""

# =============================================================================
# CHAPTER 1: IMPORTS & GLOBALS
# =============================================================================
import os
import sys
import gc
import cv2
import psutil
import time
import json
import base64
import random
import shutil
import logging
import traceback
import subprocess
import tempfile
import threading
from dataclasses import dataclass
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Tuple, List

import numpy as np
from PIL import Image
import gradio as gr
from moviepy.editor import VideoFileClip

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("bgx")

# Environment tuning (safe defaults)
os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY")
os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1")
os.environ.setdefault("PYTHONUNBUFFERED", "1")
os.environ.setdefault("MKL_NUM_THREADS", "4")
os.environ.setdefault("BFX_QUALITY", "max")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,roundup_power2_divisions:16")
os.environ.setdefault("HYDRA_FULL_ERROR", "1")
os.environ["OMP_NUM_THREADS"] = "2"

# Paths
BASE_DIR = Path(__file__).resolve().parent
CHECKPOINTS = BASE_DIR / "checkpoints"
TEMP_DIR = BASE_DIR / "temp"
OUT_DIR = BASE_DIR / "outputs"
BACKGROUND_DIR = OUT_DIR / "backgrounds"
for p in (CHECKPOINTS, TEMP_DIR, OUT_DIR, BACKGROUND_DIR):
    p.mkdir(parents=True, exist_ok=True)

# Torch/device
try:
    import torch
    TORCH_AVAILABLE = True
    CUDA_AVAILABLE = torch.cuda.is_available()
    DEVICE = "cuda" if CUDA_AVAILABLE else "cpu"
    try:
        if torch.backends.cuda.is_built():
            torch.backends.cuda.matmul.allow_tf32 = True
        if hasattr(torch.backends, "cudnn"):
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
        if CUDA_AVAILABLE:
            torch.cuda.set_per_process_memory_fraction(0.8)
    except Exception:
        pass
except Exception:
    TORCH_AVAILABLE = False
    CUDA_AVAILABLE = False
    DEVICE = "cpu"

# =============================================================================
# CHAPTER 2: UI CONSTANTS & UTILS
# =============================================================================
GRADIENT_PRESETS = {
    "Blue Fade": ((128, 64, 0), (255, 128, 0)),
    "Sunset": ((255, 128, 0), (255, 0, 128)),
    "Green Field": ((64, 128, 64), (160, 255, 160)),
    "Slate": ((40, 40, 48), (96, 96, 112)),
    "Ocean": ((255, 140, 0), (255, 215, 0)),
    "Forest": ((34, 139, 34), (144, 238, 144)),
    "Sunset Pink": ((255, 182, 193), (255, 105, 180)),
    "Cool Blue": ((173, 216, 230), (0, 191, 255)),
}

AI_PROMPT_SUGGESTIONS = [
    "Custom (write your own)",
    "modern minimalist office with soft lighting, clean desk, blurred background",
    "elegant conference room with large windows and city view",
    "contemporary workspace with plants and natural light",
    "luxury hotel lobby with marble floors and warm ambient lighting",
    "professional studio with clean white background and soft lighting",
    "modern corporate meeting room with glass walls and city skyline",
    "sophisticated home office with bookshelf and warm wood tones",
    "sleek coworking space with industrial design elements",
    "abstract geometric patterns in blue and gold, modern art style",
    "soft watercolor texture with pastel colors, dreamy atmosphere",
]

def _make_vertical_gradient(width: int, height: int, c1, c2) -> np.ndarray:
    width = max(1, int(width))
    height = max(1, int(height))
    top = np.array(c1, dtype=np.float32)
    bot = np.array(c2, dtype=np.float32)
    rows = np.linspace(top, bot, num=height, dtype=np.float32)
    grad = np.repeat(rows[:, None, :], repeats=width, axis=1)
    return np.clip(grad, 0, 255).astype(np.uint8)
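
# Illustrative sketch (added in this write-up, not in the original commit):
# exercising the gradient helper standalone. The presets store BGR tuples,
# which is what cv2.imwrite expects; the output path is hypothetical.
def _demo_gradient_preview(path: str = "slate_preview.jpg") -> str:
    grad = _make_vertical_gradient(1280, 720, *GRADIENT_PRESETS["Slate"])  # (720, 1280, 3) uint8
    cv2.imwrite(path, grad)
    return path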

def run_ffmpeg(args: list, fail_ok=False) -> bool:
    cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error"] + args
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return True
    except Exception as e:
        if not fail_ok:
            logger.error(f"ffmpeg failed: {e}")
        return False

def write_video_h264(clip, path: str, fps: Optional[int] = None, crf: int = 18, preset: str = "medium"):
    fps = fps or max(1, int(round(getattr(clip, "fps", None) or 24)))
    clip.write_videofile(
        path,
        audio=False,
        fps=fps,
        codec="libx264",
        preset=preset,
        ffmpeg_params=["-crf", str(crf), "-pix_fmt", "yuv420p", "-profile:v", "high", "-movflags", "+faststart"],
        logger=None,
        verbose=False,
    )

def download_file(url: str, dest: Path, name: str) -> bool:
    if dest.exists():
        logger.info(f"{name} already exists")
        return True
    try:
        import requests
        logger.info(f"Downloading {name} ...")
        with requests.get(url, stream=True, timeout=300) as r:
            r.raise_for_status()
            with open(dest, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        return True
    except Exception as e:
        logger.error(f"Failed to download {name}: {e}")
        if dest.exists():
            try: dest.unlink()
            except Exception: pass
        return False

def ensure_repo(repo_name: str, git_url: str) -> Optional[Path]:
    repo_path = CHECKPOINTS / f"{repo_name}_repo"
    if not repo_path.exists():
        try:
            subprocess.run(["git", "clone", "--depth", "1", git_url, str(repo_path)],
                           check=True, timeout=300, capture_output=True)
            logger.info(f"{repo_name} cloned")
        except Exception as e:
            logger.error(f"Failed to clone {repo_name}: {e}")
            return None
    repo_str = str(repo_path)
    if repo_str not in sys.path:
        sys.path.insert(0, repo_str)
    return repo_path

def _reset_hydra():
    try:
        from hydra.core.global_hydra import GlobalHydra
        if GlobalHydra().is_initialized():
            GlobalHydra.instance().clear()
    except Exception:
        pass

# =============================================================================
# CHAPTER 3: MEMORY MANAGER
# =============================================================================
@dataclass
class MemoryStats:
    cpu_percent: float
    cpu_memory_mb: float
    gpu_memory_mb: float = 0.0
    gpu_memory_reserved_mb: float = 0.0
    temp_files_count: int = 0
    temp_files_size_mb: float = 0.0

class MemoryManager:
    def __init__(self):
        self.temp_files: List[str] = []
        self.cleanup_lock = threading.Lock()
        self.torch_available = TORCH_AVAILABLE
        self.cuda_available = CUDA_AVAILABLE

    def get_memory_stats(self) -> MemoryStats:
        process = psutil.Process()
        cpu_percent = psutil.cpu_percent(interval=0.1)
        cpu_memory_mb = process.memory_info().rss / (1024 * 1024)
        gpu_memory_mb = 0.0
        gpu_memory_reserved_mb = 0.0
        if self.torch_available and self.cuda_available:
            try:
                import torch
                gpu_memory_mb = torch.cuda.memory_allocated() / (1024 * 1024)
                gpu_memory_reserved_mb = torch.cuda.memory_reserved() / (1024 * 1024)
            except Exception:
                pass

        temp_count, temp_size_mb = 0, 0.0
        for tf in self.temp_files:
            if os.path.exists(tf):
                temp_count += 1
                try:
                    temp_size_mb += os.path.getsize(tf) / (1024 * 1024)
                except Exception:
                    pass
        return MemoryStats(cpu_percent, cpu_memory_mb, gpu_memory_mb, gpu_memory_reserved_mb, temp_count, temp_size_mb)

    def register_temp_file(self, path: str):
        with self.cleanup_lock:
            if path not in self.temp_files:
                self.temp_files.append(path)

    def cleanup_temp_files(self):
        with self.cleanup_lock:
            cleaned = 0
            for tf in self.temp_files[:]:
                try:
                    if os.path.isdir(tf):
                        shutil.rmtree(tf, ignore_errors=True)
                    elif os.path.exists(tf):
                        os.unlink(tf)
                    cleaned += 1
                except Exception as e:
                    logger.warning(f"Failed to cleanup {tf}: {e}")
                finally:
                    try: self.temp_files.remove(tf)
                    except Exception: pass
            if cleaned:
                logger.info(f"Cleaned {cleaned} temp paths")

    def aggressive_cleanup(self):
        logger.info("Aggressive cleanup...")
        gc.collect()
        if self.torch_available and self.cuda_available:
            try:
                import torch
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            except Exception:
                pass
        self.cleanup_temp_files()
        gc.collect()

    @contextmanager
    def mem_context(self, name="op"):
        stats = self.get_memory_stats()
        logger.info(f"Start {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
        try:
            yield self
        finally:
            self.aggressive_cleanup()
            stats = self.get_memory_stats()
            logger.info(f"End {name} | CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")

memory_manager = MemoryManager()
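
# Illustrative sketch (not in the original commit): any heavyweight section
# can be wrapped in mem_context, which logs CPU/GPU usage on entry and runs
# aggressive_cleanup() on exit, even if the body raises.
def _demo_mem_context() -> float:
    with memory_manager.mem_context("demo op"):
        buf = np.zeros((4096, 4096), dtype=np.float32)  # ~64 MB scratch allocation
        return float(buf.sum())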

# =============================================================================
# CHAPTER 4: SYSTEM STATE
# =============================================================================
class SystemState:
    def __init__(self):
        self.torch_available = TORCH_AVAILABLE
        self.cuda_available = CUDA_AVAILABLE
        self.device = DEVICE
        self.sam2_ready = False
        self.matanyone_ready = False
        self.sam2_error = None
        self.matanyone_error = None

    def status_text(self) -> str:
        stats = memory_manager.get_memory_stats()
        return (
            "=== SYSTEM STATUS ===\n"
            f"PyTorch: {'✅' if self.torch_available else '❌'}\n"
            f"CUDA: {'✅' if self.cuda_available else '❌'}\n"
            f"Device: {self.device}\n"
            f"SAM2: {'✅' if self.sam2_ready else ('❌' if self.sam2_error else '⏳')}\n"
            f"MatAnyone: {'✅' if self.matanyone_ready else ('❌' if self.matanyone_error else '⏳')}\n\n"
            "=== MEMORY ===\n"
            f"CPU: {stats.cpu_percent:.1f}% ({stats.cpu_memory_mb:.1f} MB)\n"
            f"GPU: {stats.gpu_memory_mb:.1f} MB (Reserved {stats.gpu_memory_reserved_mb:.1f} MB)\n"
            f"Temp: {stats.temp_files_count} files ({stats.temp_files_size_mb:.1f} MB)\n"
        )

state = SystemState()

# =============================================================================
# CHAPTER 5: SAM2 HANDLER (CUDA-only)
# =============================================================================
class SAM2Handler:
    def __init__(self):
        self.predictor = None
        self.initialized = False

    def initialize(self) -> bool:
        if not (TORCH_AVAILABLE and CUDA_AVAILABLE):
            state.sam2_error = "SAM2 requires CUDA"
            return False

        with memory_manager.mem_context("SAM2 init"):
            try:
                _reset_hydra()
                repo_path = ensure_repo("sam2", "https://github.com/facebookresearch/segment-anything-2.git")
                if not repo_path:
                    state.sam2_error = "Clone failed"
                    return False

                ckpt = CHECKPOINTS / "sam2.1_hiera_large.pt"
                url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
                if not download_file(url, ckpt, "SAM2 Large"):
                    state.sam2_error = "SAM2 ckpt download failed"
                    return False

                from hydra.core.global_hydra import GlobalHydra
                from hydra import initialize_config_dir
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                config_dir = (repo_path / "sam2" / "configs").as_posix()
                if GlobalHydra().is_initialized():
                    GlobalHydra.instance().clear()
                initialize_config_dir(config_dir=config_dir, version_base=None)

                model = build_sam2("sam2.1/sam2.1_hiera_l.yaml", str(ckpt), device="cuda")
                self.predictor = SAM2ImagePredictor(model)

                # Smoke test
                test = np.zeros((64, 64, 3), dtype=np.uint8)
                self.predictor.set_image(test)
                masks, scores, _ = self.predictor.predict(
                    point_coords=np.array([[32, 32]]),
                    point_labels=np.ones(1, dtype=np.int64),
                    multimask_output=True,
                )
                ok = masks is not None and len(masks) > 0
                self.initialized = ok
                state.sam2_ready = ok
                if not ok:
                    state.sam2_error = "SAM2 verify failed"
                return ok

            except Exception as e:
                state.sam2_error = f"SAM2 init error: {e}"
                return False

    def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]:
        if not self.initialized:
            return None
        with memory_manager.mem_context("SAM2 mask"):
            try:
                self.predictor.set_image(image_rgb)
                h, w = image_rgb.shape[:2]
                strategies = [
                    np.array([[w // 2, h // 2]]),
                    np.array([[w // 2, h // 3]]),
                    np.array([[w // 2, h // 3], [w // 2, (2 * h) // 3]]),
                ]
                best, best_score = None, -1.0
                for pc in strategies:
                    masks, scores, _ = self.predictor.predict(
                        point_coords=pc,
                        point_labels=np.ones(len(pc), dtype=np.int64),
                        multimask_output=True,
                    )
                    if masks is not None and len(masks) > 0:
                        i = int(np.argmax(scores))
                        sc = float(scores[i])
                        if sc > best_score:
                            best_score, best = sc, masks[i]

                if best is None:
                    return None

                mask_u8 = (best * 255).astype(np.uint8)
                k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
                mask_clean = cv2.morphologyEx(mask_u8, cv2.MORPH_CLOSE, k)
                mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_OPEN, k)
                mask_clean = cv2.GaussianBlur(mask_clean, (3, 3), 1.0)
                return mask_clean
            except Exception as e:
                logger.error(f"SAM2 mask error: {e}")
                return None
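
# Illustrative sketch (not in the original commit): minimal standalone use of
# SAM2Handler on a single image file; the path argument is hypothetical.
# initialize() fails fast on machines without CUDA.
def _demo_sam2_mask(image_path: str) -> Optional[np.ndarray]:
    handler = SAM2Handler()
    if not handler.initialize():
        logger.error(f"SAM2 unavailable: {state.sam2_error}")
        return None
    bgr = cv2.imread(image_path)
    if bgr is None:
        return None
    return handler.create_mask(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))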

# =============================================================================
# CHAPTER 6: MATANYONE HANDLER (FIXED - Uses existing matanyone_fixed files)
# =============================================================================
class MatAnyoneHandler:
    """
    FIXED MatAnyone handler using existing matanyone_fixed files
    """
    def __init__(self):
        self.core = None
        self.initialized = False

    # ----- tensor helpers -----
    def _to_chw_float(self, img01: np.ndarray) -> "torch.Tensor":
        """img01: HxWx3 in [0,1] -> torch float (3,H,W) on DEVICE (no batch)."""
        assert img01.ndim == 3 and img01.shape[2] == 3, f"Expected HxWx3, got {img01.shape}"
        t = torch.from_numpy(img01.transpose(2, 0, 1)).contiguous().float()  # (3,H,W)
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _prob_hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        """mask_u8: HxW -> torch float (H,W) in [0,1] on DEVICE (no batch, no channel)."""
        if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
            mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
        prob = (mask_u8.astype(np.float32) / 255.0)  # (H,W)
        t = torch.from_numpy(prob).contiguous().float()
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _prob_1hw_from_mask_u8(self, mask_u8: np.ndarray, w: int, h: int) -> "torch.Tensor":
        """Optional: 1xHxW (channel-first, still unbatched)."""
        if mask_u8.shape[0] != h or mask_u8.shape[1] != w:
            mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
        prob = (mask_u8.astype(np.float32) / 255.0)[None, ...]  # (1,H,W)
        t = torch.from_numpy(prob).contiguous().float()
        return t.to(DEVICE, non_blocking=CUDA_AVAILABLE)

    def _alpha_to_u8_hw(self, alpha_like) -> np.ndarray:
        """
        Accepts torch / numpy / tuple(list) outputs.
        Returns uint8 HxW (0..255). Squeezes common shapes down to HxW.
        """
        if isinstance(alpha_like, (list, tuple)) and len(alpha_like) > 1:
            alpha_like = alpha_like[1]  # (indices, probs) -> take probs

        if isinstance(alpha_like, torch.Tensor):
            t = alpha_like.detach()
            if t.is_cuda:
                t = t.cpu()
            a = t.float().clamp(0, 1).numpy()
        else:
            a = np.asarray(alpha_like, dtype=np.float32)
            a = np.clip(a, 0, 1)

        a = np.squeeze(a)
        if a.ndim == 3 and a.shape[0] >= 1:  # (1,H,W) -> (H,W)
            a = a[0]
        if a.ndim != 2:
            raise ValueError(f"Alpha must be HxW; got {a.shape}")

        return np.clip(a * 255.0, 0, 255).astype(np.uint8)

    def initialize(self) -> bool:
        """
        FIXED MatAnyone initialization using existing matanyone_fixed files
        """
        if not TORCH_AVAILABLE:
            state.matanyone_error = "PyTorch required"
            return False

        with memory_manager.mem_context("MatAnyone init"):
            try:
                # Use existing matanyone_fixed directory
                local_matanyone = BASE_DIR / "matanyone_fixed"

                if not local_matanyone.exists():
                    state.matanyone_error = "matanyone_fixed directory not found"
                    return False

                # Add the fixed matanyone path to Python path
                matanyone_str = str(local_matanyone)
                if matanyone_str not in sys.path:
                    sys.path.insert(0, matanyone_str)

                # Import fixed modules
                try:
                    from inference.inference_core import InferenceCore
                    from utils.get_default_model import get_matanyone_model
                except Exception as e:
                    state.matanyone_error = f"Import error: {e}"
                    return False

                # Download model checkpoint if needed
                ckpt = CHECKPOINTS / "matanyone.pth"
                if not ckpt.exists():
                    url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
                    if not download_file(url, ckpt, "MatAnyone"):
                        logger.warning("MatAnyone checkpoint download failed, using random weights")

                # Load model using fixed interface
                net = get_matanyone_model(str(ckpt), device=DEVICE)

                if net is None:
                    state.matanyone_error = "Model creation failed"
                    return False

                # Create inference core with fixed implementation
                self.core = InferenceCore(net)
                self.initialized = True
                state.matanyone_ready = True

                logger.info("Fixed MatAnyone initialized successfully")
                return True

            except Exception as e:
                state.matanyone_error = f"MatAnyone init error: {e}"
                logger.error(f"MatAnyone initialization failed: {e}")
                return False

    def _try_step_variants_seed(self,
                                img_chw_t: "torch.Tensor",
                                prob_hw_t: "torch.Tensor",
                                prob_1hw_t: "torch.Tensor"):
        """
        Simplified step variants using the fixed MatAnyone: try the (H,W)
        prob, then the (1,H,W) prob, then fall back to the image alone.
        """
        # The fixed MatAnyone handles tensor format internally
        try:
            return self.core.step(img_chw_t, prob_hw_t)
        except Exception as e:
            logger.debug(f"step(image, prob_hw) failed: {e}")
            try:
                return self.core.step(img_chw_t, prob_1hw_t)
            except Exception as e2:
                logger.debug(f"step(image, prob_1hw) failed: {e2}")
                # Final fallback: no probability guidance
                return self.core.step(img_chw_t)

    def _try_step_variants_noseed(self, img_chw_t: "torch.Tensor"):
        """
        Simplified noseed variant using fixed MatAnyone
        """
        return self.core.step(img_chw_t)

    # ----- video matting using first-frame PROB mask --------------------------
    def process_video(self, input_path: str, mask_path: str, output_path: str) -> str:
        """
        Produce a single-channel alpha mp4 matching input fps & size.

        First frame: pass a soft seed prob (~HW) alongside the image.
        Remaining frames: call step(image) only.
        """
        if not self.initialized or self.core is None:
            raise RuntimeError("MatAnyone not initialized")

        out_dir = Path(output_path)
        out_dir.mkdir(parents=True, exist_ok=True)
        alpha_path = out_dir / "alpha.mp4"

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise RuntimeError("Could not open input video")

        fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # soft seed prob - prepare tensor versions only
        seed_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if seed_mask is None:
            cap.release()
            raise RuntimeError("Seed mask read failed")

        prob_hw_t = self._prob_hw_from_mask_u8(seed_mask, w, h)    # (H,W) torch
        prob_1hw_t = self._prob_1hw_from_mask_u8(seed_mask, w, h)  # (1,H,W) torch

        # temp frames
        tmp_dir = TEMP_DIR / f"ma_{int(time.time())}_{random.randint(1000,9999)}"
        tmp_dir.mkdir(parents=True, exist_ok=True)
        memory_manager.register_temp_file(str(tmp_dir))

        frame_idx = 0

        # --- first frame (with soft prob) ---
        ok, frame_bgr = cap.read()
        if not ok or frame_bgr is None:
            cap.release()
            raise RuntimeError("Empty first frame")
        frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        img_chw_t = self._to_chw_float(frame_rgb01)  # (3,H,W) torch

        with torch.no_grad():
            out_prob = self._try_step_variants_seed(
                img_chw_t, prob_hw_t, prob_1hw_t
            )

        alpha_u8 = self._alpha_to_u8_hw(out_prob)
        cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
        frame_idx += 1

        # --- remaining frames (no seed) ---
        while True:
            ok, frame_bgr = cap.read()
            if not ok or frame_bgr is None:
                break

            frame_rgb01 = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            img_chw_t = self._to_chw_float(frame_rgb01)

            with torch.no_grad():
                out_prob = self._try_step_variants_noseed(img_chw_t)

            alpha_u8 = self._alpha_to_u8_hw(out_prob)
            cv2.imwrite(str(tmp_dir / f"{frame_idx:06d}.png"), alpha_u8)
            frame_idx += 1

        cap.release()

        # --- encode PNGs → alpha mp4 ---
        list_file = tmp_dir / "list.txt"
        with open(list_file, "w") as f:
            for i in range(frame_idx):
                f.write(f"file '{(tmp_dir / f'{i:06d}.png').as_posix()}'\n")

        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-f", "concat", "-safe", "0",
            "-r", f"{fps:.6f}",
            "-i", str(list_file),
            "-vf", f"format=gray,scale={w}:{h}:flags=area",
            "-pix_fmt", "yuv420p",
            "-c:v", "libx264", "-preset", "medium", "-crf", "18",
            str(alpha_path)
        ]
        subprocess.run(cmd, check=True)
        return str(alpha_path)
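
# Illustrative sketch (not in the original commit): alpha-matting a clip from
# a single seed-mask PNG, mirroring how the pipeline below drives
# process_video(). Both path arguments are hypothetical.
def _demo_matanyone_alpha(video_path: str, seed_mask_png: str) -> Optional[str]:
    handler = MatAnyoneHandler()
    if not handler.initialize():
        logger.error(f"MatAnyone unavailable: {state.matanyone_error}")
        return None
    out_dir = TEMP_DIR / f"demo_alpha_{int(time.time())}"
    return handler.process_video(video_path, seed_mask_png, str(out_dir))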

# =============================================================================
# CHAPTER 7: AI BACKGROUNDS
# =============================================================================
def _maybe_enable_xformers(pipe):
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

def _setup_memory_efficient_pipeline(pipe, require_gpu: bool):
    _maybe_enable_xformers(pipe)
    if not require_gpu:
        try:
            if hasattr(pipe, "enable_attention_slicing"):
                pipe.enable_attention_slicing("auto")
            if hasattr(pipe, "enable_model_cpu_offload"):
                pipe.enable_model_cpu_offload()
            if hasattr(pipe, "enable_sequential_cpu_offload"):
                pipe.enable_sequential_cpu_offload()
        except Exception:
            pass

def generate_sdxl_background(width:int, height:int, prompt:str, steps:int=30, guidance:float=7.0,
                             seed:Optional[int]=None, require_gpu:bool=False) -> str:
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for SDXL")
    with memory_manager.mem_context("SDXL background"):
        try:
            from diffusers import StableDiffusionXLPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch_dtype,
            add_watermarker=False,
        ).to(device)

        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional studio lighting, high detail, clean composition"
        img = pipe(
            prompt=enhanced,
            height=int(height),
            width=int(width),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]

        out = TEMP_DIR / f"sdxl_bg_{int(time.time())}_{seed or 0:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)

def generate_playground_v25_background(width:int, height:int, prompt:str, steps:int=30, guidance:float=7.0,
                                       seed:Optional[int]=None, require_gpu:bool=False) -> str:
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for Playground v2.5")
    with memory_manager.mem_context("Playground v2.5 background"):
        try:
            from diffusers import DiffusionPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        repo_id = "playgroundai/playground-v2.5-1024px-aesthetic"
        pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional quality, soft light, minimal distractions"
        img = pipe(
            prompt=enhanced,
            height=int(height),
            width=int(width),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]

        out = TEMP_DIR / f"pg25_bg_{int(time.time())}_{seed or 0:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)

def generate_sd15_background(width:int, height:int, prompt:str, steps:int=25, guidance:float=7.5,
                             seed:Optional[int]=None, require_gpu:bool=False) -> str:
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch required for SD 1.5")
    with memory_manager.mem_context("SD1.5 background"):
        try:
            from diffusers import StableDiffusionPipeline
        except ImportError as e:
            raise RuntimeError("Install diffusers/transformers/accelerate") from e

        if require_gpu and not CUDA_AVAILABLE:
            raise RuntimeError("Force GPU enabled but CUDA not available")

        device = "cuda" if CUDA_AVAILABLE else "cpu"
        torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32

        generator = torch.Generator(device=device)
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        generator.manual_seed(int(seed))

        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch_dtype,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device)

        _setup_memory_efficient_pipeline(pipe, require_gpu)

        enhanced = f"{prompt}, professional background, clean composition"
        img = pipe(
            prompt=enhanced,
            height=int(height),
            width=int(width),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator
        ).images[0]

        out = TEMP_DIR / f"sd15_bg_{int(time.time())}_{seed or 0:08d}.jpg"
        img.save(out, quality=95, optimize=True)
        memory_manager.register_temp_file(str(out))
        del pipe, img
        return str(out)

def generate_openai_background(width:int, height:int, prompt:str, api_key:str, model:str="gpt-image-1") -> str:
    if not api_key or not isinstance(api_key, str) or len(api_key) < 10:
        raise RuntimeError("Missing or invalid OpenAI API key")
    with memory_manager.mem_context("OpenAI background"):
        target = "1024x1024"
        url = "https://api.openai.com/v1/images/generations"
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        body = {"model": model, "prompt": f"{prompt}, professional background, studio lighting, minimal distractions, high detail",
                "size": target, "n": 1, "quality": "high"}
        import requests
        r = requests.post(url, headers=headers, data=json.dumps(body), timeout=120)
        if r.status_code != 200:
            raise RuntimeError(f"OpenAI API error: {r.status_code} {r.text}")
        data = r.json()
        b64 = data["data"][0]["b64_json"]
        raw = base64.b64decode(b64)
        tmp_png = TEMP_DIR / f"openai_raw_{int(time.time())}_{random.randint(1000,9999)}.png"
        with open(tmp_png, "wb") as f:
            f.write(raw)
        img = Image.open(tmp_png).convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
        out = TEMP_DIR / f"openai_bg_{int(time.time())}_{random.randint(1000,9999)}.jpg"
        img.save(out, quality=95, optimize=True)
        try: os.unlink(tmp_png)
        except Exception: pass
        memory_manager.register_temp_file(str(out))
        return str(out)

def generate_ai_background_router(width:int, height:int, prompt:str, model:str="SDXL",
                                  steps:int=30, guidance:float=7.0, seed:Optional[int]=None,
                                  openai_key:Optional[str]=None, require_gpu:bool=False) -> str:
    try:
        if model == "OpenAI (gpt-image-1)":
            if not openai_key:
                raise RuntimeError("OpenAI API key not provided")
            return generate_openai_background(width, height, prompt, openai_key, model="gpt-image-1")
        elif model == "Playground v2.5":
            return generate_playground_v25_background(width, height, prompt, steps, guidance, seed, require_gpu)
        elif model == "SDXL":
            return generate_sdxl_background(width, height, prompt, steps, guidance, seed, require_gpu)
        else:
            return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu)
    except Exception as e:
        logger.warning(f"{model} generation failed: {e}; falling back to SD1.5/gradient")
        try:
            return generate_sd15_background(width, height, prompt, steps, guidance, seed, require_gpu=False)
        except Exception:
            grad = _make_vertical_gradient(width, height, (235, 240, 245), (210, 220, 230))
            out = TEMP_DIR / f"bg_fallback_{int(time.time())}.jpg"
            cv2.imwrite(str(out), grad)
            memory_manager.register_temp_file(str(out))
            return str(out)
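
# Illustrative sketch (not in the original commit): the router resolves the
# UI's display names to a backend and always returns a usable path, degrading
# to SD 1.5 and finally to a flat gradient on failure.
def _demo_background_720p() -> str:
    return generate_ai_background_router(
        width=1280, height=720,
        prompt="modern minimalist office with soft lighting",
        model="SDXL", steps=20, guidance=7.0, seed=1234,
    )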

# =============================================================================
# CHAPTER 8: CHUNKED PROCESSOR (optional)
# =============================================================================
class ChunkedVideoProcessor:
    def __init__(self, chunk_size_frames: int = 60):
        self.chunk_size = int(chunk_size_frames)

    def _extract_chunk(self, video_path: str, start_frame: int, end_frame: int, fps: float) -> str:
        chunk_path = str(TEMP_DIR / f"chunk_{start_frame}_{end_frame}_{random.randint(1000,9999)}.mp4")
        start_time = start_frame / fps
        duration = max(0.001, (end_frame - start_frame) / fps)
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-ss", f"{start_time:.6f}", "-i", video_path,
            "-t", f"{duration:.6f}",
            "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
            "-c:v", "libx264", "-preset", "veryfast", "-crf", "20",
            "-an", chunk_path
        ]
        subprocess.run(cmd, check=True)
        return chunk_path

    def _merge_chunks(self, chunk_paths: List[str], fps: float, width: int, height: int) -> str:
        if not chunk_paths:
            raise ValueError("No chunks to merge")
        if len(chunk_paths) == 1:
            return chunk_paths[0]
        concat_file = TEMP_DIR / f"concat_{random.randint(1000,9999)}.txt"
        with open(concat_file, "w") as f:
            for c in chunk_paths:
                f.write(f"file '{c}'\n")
        out = TEMP_DIR / f"merged_{random.randint(1000,9999)}.mp4"
        cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
               "-f", "concat", "-safe", "0", "-i", str(concat_file),
               "-c", "copy", str(out)]
        subprocess.run(cmd, check=True)
        return str(out)

    def process_video_chunks(self, video_path: str, processor_func, **kwargs) -> str:
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()

        processed: List[str] = []
        for start in range(0, total, self.chunk_size):
            end = min(start + self.chunk_size, total)
            with memory_manager.mem_context(f"chunk {start}-{end}"):
                ch = self._extract_chunk(video_path, start, end, fps)
                memory_manager.register_temp_file(ch)
                out = processor_func(ch, **kwargs)
                memory_manager.register_temp_file(out)
                processed.append(out)
        return self._merge_chunks(processed, fps, width, height)
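
# Illustrative sketch (not in the original commit): bounding peak memory by
# matting ~60-frame pieces. process_video_chunks hands each extracted chunk
# path to the callable and concatenates the per-chunk outputs with ffmpeg.
def _demo_chunked_alpha(video_path: str, seed_mask_png: str) -> str:
    handler = MatAnyoneHandler()
    if not handler.initialize():
        raise RuntimeError(state.matanyone_error or "MatAnyone init failed")
    chunker = ChunkedVideoProcessor(chunk_size_frames=60)
    return chunker.process_video_chunks(
        video_path,
        lambda chunk_path, **_k: handler.process_video(
            chunk_path, seed_mask_png,
            str(TEMP_DIR / f"demo_chunk_{random.randint(1000, 9999)}"),
        ),
    )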

# =============================================================================
# CHAPTER 9: MAIN PIPELINE (SAM2 → MatAnyone → Composite) - FIXED VERSION
# =============================================================================
def process_video_main(
    video_path: str,
    background_path: Optional[str] = None,
    trim_duration: Optional[float] = None,
    crf: int = 18,
    preserve_audio_flag: bool = True,
    placement: Optional[dict] = None,
    use_chunked_processing: bool = False,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Optional[str], str]:

    messages: List[str] = []
    with memory_manager.mem_context("Pipeline"):
        try:
            progress(0, desc="Initializing models")
            sam2 = SAM2Handler()
            matanyone = MatAnyoneHandler()

            if not sam2.initialize():
                return None, f"SAM2 init failed: {state.sam2_error}"
            if not matanyone.initialize():
                return None, f"MatAnyone init failed: {state.matanyone_error}"
            messages.append("✅ SAM2 & MatAnyone initialized")

            progress(0.1, desc="Preparing video")
            input_video = video_path

            # Optional trim
            if trim_duration and float(trim_duration) > 0:
                trimmed = TEMP_DIR / f"trimmed_{int(time.time())}_{random.randint(1000,9999)}.mp4"
                memory_manager.register_temp_file(str(trimmed))
                with VideoFileClip(video_path) as clip:
                    d = min(float(trim_duration), float(clip.duration or trim_duration))
                    sub = clip.subclip(0, d)
                    write_video_h264(sub, str(trimmed), crf=int(crf))
                    sub.close()
                input_video = str(trimmed)
                messages.append(f"✂️ Trimmed to {d:.1f}s")
            else:
                with VideoFileClip(video_path) as clip:
                    messages.append(f"🎞️ Full video: {clip.duration:.1f}s")

            progress(0.2, desc="Creating SAM2 mask")
            cap = cv2.VideoCapture(input_video)
            ret, first_frame = cap.read()
            cap.release()
            if not ret or first_frame is None:
                return None, "Could not read video"
            h, w = first_frame.shape[:2]
            rgb0 = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
            mask = sam2.create_mask(rgb0)
            if mask is None:
                return None, "SAM2 mask failed"

            mask_path = TEMP_DIR / f"mask_{int(time.time())}_{random.randint(1000,9999)}.png"
            memory_manager.register_temp_file(str(mask_path))
            cv2.imwrite(str(mask_path), mask)
            messages.append("✅ Person mask created")

            progress(0.35, desc="Matting video")
            if use_chunked_processing:
                chunker = ChunkedVideoProcessor(chunk_size_frames=60)
                alpha_video = chunker.process_video_chunks(
                    input_video,
                    lambda chunk_path, **_k: matanyone.process_video(
                        input_path=chunk_path,
                        mask_path=str(mask_path),
                        output_path=str(TEMP_DIR / f"matanyone_chunk_{int(time.time())}_{random.randint(1000,9999)}")
                    )
                )
                memory_manager.register_temp_file(alpha_video)
            else:
                out_dir = TEMP_DIR / f"matanyone_out_{int(time.time())}_{random.randint(1000,9999)}"
                out_dir.mkdir(parents=True, exist_ok=True)
                memory_manager.register_temp_file(str(out_dir))
                alpha_video = matanyone.process_video(
                    input_path=input_video,
                    mask_path=str(mask_path),
                    output_path=str(out_dir)
                )

            if not alpha_video or not os.path.exists(alpha_video):
                return None, "MatAnyone did not produce alpha video"
            messages.append("✅ Alpha video generated")

            progress(0.55, desc="Preparing background")
            original_clip = VideoFileClip(input_video)
            alpha_clip = VideoFileClip(alpha_video)

            if background_path and os.path.exists(background_path):
                messages.append("🖼️ Using background file")
                bg_bgr = cv2.imread(background_path)
                bg_bgr = cv2.resize(bg_bgr, (w, h))
                bg_rgb = cv2.cvtColor(bg_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            else:
                messages.append("🖼️ Using gradient background")
                grad = _make_vertical_gradient(w, h, (200, 205, 215), (160, 170, 190))
                bg_rgb = cv2.cvtColor(grad, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

            # FIXED: Enhanced placement parameters with validation and debugging
            placement = placement or {}
            px = max(0.0, min(1.0, float(placement.get("x", 0.5))))
            py = max(0.0, min(1.0, float(placement.get("y", 0.75))))
            ps = max(0.3, min(2.0, float(placement.get("scale", 1.0))))
            feather_px = max(0, min(50, int(placement.get("feather", 3))))

            # Debug logging for placement parameters
            logger.info(f"POSITIONING DEBUG: px={px:.3f}, py={py:.3f}, ps={ps:.3f}, feather={feather_px}")
            logger.info(f"VIDEO DIMENSIONS: {w}x{h}")
            logger.info(f"TARGET CENTER: ({int(px * w)}, {int(py * h)})")

            frame_count = 0

            def composite_frame(get_frame, t):
                nonlocal frame_count
                frame_count += 1

                # Get original frame
                frame = get_frame(t).astype(np.float32) / 255.0
                hh, ww = frame.shape[:2]

                # FIXED: Better alpha temporal synchronization
                alpha_duration = getattr(alpha_clip, 'duration', None)
                if alpha_duration and alpha_duration > 0:
                    # Ensure we don't go beyond alpha video duration
                    alpha_t = min(t, alpha_duration - 0.01)
                    alpha_t = max(0.0, alpha_t)
                else:
                    alpha_t = 0.0

                try:
                    a = alpha_clip.get_frame(alpha_t)
                    # Handle multi-channel alpha
                    if a.ndim == 3:
                        a = a[:, :, 0]
                    a = a.astype(np.float32) / 255.0

                    # FIXED: Ensure alpha matches frame dimensions exactly
                    if a.shape != (hh, ww):
                        logger.warning(f"Alpha size mismatch: {a.shape} vs {(hh, ww)}, resizing...")
                        a = cv2.resize(a, (ww, hh), interpolation=cv2.INTER_LINEAR)

                except Exception as e:
                    logger.error(f"Alpha frame error at t={t:.3f}: {e}")
                    return (bg_rgb * 255).astype(np.uint8)

                # FIXED: Calculate scaled dimensions with better rounding
                sw = max(1, round(ww * ps))  # Use round instead of int for better precision
                sh = max(1, round(hh * ps))

                # FIXED: Scale both frame and alpha consistently
                try:
                    fg_scaled = cv2.resize(frame, (sw, sh), interpolation=cv2.INTER_AREA if ps < 1.0 else cv2.INTER_LINEAR)
                    a_scaled = cv2.resize(a, (sw, sh), interpolation=cv2.INTER_AREA if ps < 1.0 else cv2.INTER_LINEAR)
                except Exception as e:
                    logger.error(f"Scaling error: {e}")
                    return (bg_rgb * 255).astype(np.uint8)

                # Create canvases
                fg_canvas = np.zeros_like(frame, dtype=np.float32)
                a_canvas = np.zeros((hh, ww), dtype=np.float32)

                # FIXED: More precise center calculations
                cx = round(px * ww)
                cy = round(py * hh)

                # FIXED: Use floor division for consistent centering
                x0 = cx - sw // 2
                y0 = cy - sh // 2

                # Debug logging for first few frames
                if frame_count <= 3:
                    logger.info(f"FRAME {frame_count}: scaled_size=({sw}, {sh}), center=({cx}, {cy}), top_left=({x0}, {y0})")

                # FIXED: Robust bounds checking with edge case handling
                xs0 = max(0, x0)
                ys0 = max(0, y0)
                xs1 = min(ww, x0 + sw)
                ys1 = min(hh, y0 + sh)

                # Check for valid placement region
                if xs1 <= xs0 or ys1 <= ys0:
                    if frame_count <= 3:
                        logger.warning(f"Subject outside bounds: dest=({xs0},{ys0})-({xs1},{ys1})")
                    return (bg_rgb * 255).astype(np.uint8)

                # FIXED: Calculate source region with bounds validation
                src_x0 = xs0 - x0  # Will be 0 if x0 >= 0, positive if x0 < 0
                src_y0 = ys0 - y0  # Will be 0 if y0 >= 0, positive if y0 < 0
                src_x1 = src_x0 + (xs1 - xs0)
                src_y1 = src_y0 + (ys1 - ys0)

                # Validate source bounds
                if (src_x1 > sw or src_y1 > sh or src_x0 < 0 or src_y0 < 0 or
                        src_x1 <= src_x0 or src_y1 <= src_y0):
                    if frame_count <= 3:
                        logger.error(f"Invalid source region: ({src_x0},{src_y0})-({src_x1},{src_y1}) for {sw}x{sh} scaled")
                    return (bg_rgb * 255).astype(np.uint8)

                # FIXED: Safe canvas placement with error handling
                try:
                    fg_canvas[ys0:ys1, xs0:xs1, :] = fg_scaled[src_y0:src_y1, src_x0:src_x1, :]
                    a_canvas[ys0:ys1, xs0:xs1] = a_scaled[src_y0:src_y1, src_x0:src_x1]
                except Exception as e:
                    logger.error(f"Canvas placement failed: {e}")
                    logger.error(f"Dest: [{ys0}:{ys1}, {xs0}:{xs1}], Src: [{src_y0}:{src_y1}, {src_x0}:{src_x1}]")
                    return (bg_rgb * 255).astype(np.uint8)

                # FIXED: Apply feathering with bounds checking
                if feather_px > 0:
                    kernel_size = max(3, feather_px * 2 + 1)
                    if kernel_size % 2 == 0:
                        kernel_size += 1  # Ensure odd kernel size
                    try:
                        a_canvas = cv2.GaussianBlur(a_canvas, (kernel_size, kernel_size), feather_px / 3.0)
                    except Exception as e:
                        logger.warning(f"Feathering failed: {e}")

                # FIXED: Composite with proper alpha handling
                a3 = np.expand_dims(a_canvas, axis=2)  # More explicit than [:, :, None]
                comp = a3 * fg_canvas + (1.0 - a3) * bg_rgb
                result = np.clip(comp * 255, 0, 255).astype(np.uint8)

                return result

            progress(0.7, desc="Compositing")
            final_clip = original_clip.fl(composite_frame)

            output_path = OUT_DIR / f"processed_{int(time.time())}_{random.randint(1000,9999)}.mp4"
            temp_video_path = TEMP_DIR / f"temp_video_{int(time.time())}_{random.randint(1000,9999)}.mp4"
            memory_manager.register_temp_file(str(temp_video_path))

            write_video_h264(final_clip, str(temp_video_path), crf=int(crf))
            original_clip.close(); alpha_clip.close(); final_clip.close()

            progress(0.85, desc="Merging audio")
            if preserve_audio_flag:
                success = run_ffmpeg([
                    "-i", str(temp_video_path),
                    "-i", video_path,
                    "-map", "0:v:0",
                    "-map", "1:a:0?",
                    "-c:v", "copy",
                    "-c:a", "aac",
                    "-b:a", "192k",
                    "-shortest",
                    str(output_path)
                ], fail_ok=True)
                if success:
                    messages.append("🔊 Original audio preserved")
                else:
                    shutil.copy2(str(temp_video_path), str(output_path))
                    messages.append("⚠️ Audio merge failed, saved w/o audio")
            else:
                shutil.copy2(str(temp_video_path), str(output_path))
                messages.append("🔇 Saved without audio")

            messages.append("✅ Done")
            stats = memory_manager.get_memory_stats()
            messages.append(f"📊 CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
            messages.append(f"🎯 Processed {frame_count} frames with placement ({px:.2f}, {py:.2f}) @ {ps:.2f}x scale")
            progress(1.0, desc="Done")
            return str(output_path), "\n".join(messages)

        except Exception as e:
            err = f"Processing failed: {str(e)}\n\n{traceback.format_exc()}"
            return None, err
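
# Illustrative sketch (not in the original commit): driving the pipeline
# headlessly. placement maps normalized (x, y) to the subject's center, scale
# multiplies subject size, feather softens the matte edge in pixels. Whether
# the default gr.Progress() is a safe no-op outside a Gradio event handler
# depends on the Gradio version; treat that as an assumption.
def _demo_headless(video_path: str) -> Optional[str]:
    out_path, log = process_video_main(
        video_path=video_path,
        background_path=None,  # falls back to the built-in gradient
        trim_duration=5.0,
        placement={"x": 0.5, "y": 0.75, "scale": 1.0, "feather": 3},
    )
    logger.info(log)
    return out_path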

# =============================================================================
# CHAPTER 10: GRADIO UI
# =============================================================================
def create_interface():
    def diag():
        return state.status_text()

    def cleanup():
        memory_manager.aggressive_cleanup()
        s = memory_manager.get_memory_stats()
        return f"🧹 Cleanup\nCPU: {s.cpu_memory_mb:.1f}MB\nGPU: {s.gpu_memory_mb:.1f}MB\nTemp: {s.temp_files_count} files"

    def preload(ai_model, openai_key, force_gpu, progress=gr.Progress()):
        try:
            progress(0, desc="Preloading...")
            msg = ""
            if ai_model in ("SDXL", "Playground v2.5", "SD 1.5 (fallback)"):
                try:
                    if ai_model == "SDXL":
                        _ = generate_sdxl_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
                    elif ai_model == "Playground v2.5":
                        _ = generate_playground_v25_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
                    else:
                        _ = generate_sd15_background(64, 64, "plain", steps=2, guidance=3.5, seed=42, require_gpu=bool(force_gpu))
                    msg += f"{ai_model} preloaded.\n"
                except Exception as e:
                    msg += f"{ai_model} preload failed: {e}\n"

            _reset_hydra()
            s, m = SAM2Handler(), MatAnyoneHandler()
            ok_s = s.initialize()
            _reset_hydra()
            ok_m = m.initialize()
            progress(1.0, desc="Preload complete")
            return f"✅ Preload\n{msg}SAM2: {'ready' if ok_s else 'failed'}\nMatAnyone: {'ready' if ok_m else 'failed'}"
        except Exception as e:
            return f"❌ Preload error: {e}"

    def generate_background_safe(video_file, ai_prompt, ai_steps, ai_guidance, ai_seed,
                                 ai_model, openai_key, force_gpu, progress=gr.Progress()):
        if not video_file:
            return None, "Upload a video first", gr.update(visible=False), None
        with memory_manager.mem_context("Background generation"):
            try:
                video_path = video_file.name if hasattr(video_file, 'name') else str(video_file)
                if not os.path.exists(video_path):
                    return None, "Video not found", gr.update(visible=False), None
                cap = cv2.VideoCapture(video_path)
                if not cap.isOpened():
                    return None, "Could not open video", gr.update(visible=False), None
                ret, frame = cap.read()
                cap.release()
                if not ret or frame is None:
                    return None, "Could not read frame", gr.update(visible=False), None
                h, w = int(frame.shape[0]), int(frame.shape[1])

                steps = max(1, min(50, int(ai_steps or 30)))
                guidance = max(1.0, min(15.0, float(ai_guidance or 7.0)))
                try:
                    seed_val = int(ai_seed) if ai_seed and str(ai_seed).strip() else None
                except Exception:
                    seed_val = None

                progress(0.1, desc=f"Generating {ai_model}")
                bg_path = generate_ai_background_router(
                    width=w, height=h, prompt=str(ai_prompt or "professional office background").strip(),
                    model=str(ai_model or "SDXL"), steps=steps, guidance=guidance,
                    seed=seed_val, openai_key=openai_key, require_gpu=bool(force_gpu)
                )
                progress(1.0, desc="Background ready")
                if bg_path and os.path.exists(bg_path):
                    return bg_path, f"AI background generated with {ai_model}", gr.update(visible=True), bg_path
                else:
                    return None, "No output file", gr.update(visible=False), None
            except Exception as e:
                logger.error(f"Background generation error: {e}")
                return None, f"Background generation failed: {str(e)}", gr.update(visible=False), None

    def approve_background(bg_path):
        try:
            if not bg_path or not (isinstance(bg_path, str) and os.path.exists(bg_path)):
                return None, "Generate a background first", gr.update(visible=False)
            ext = os.path.splitext(bg_path)[1].lower() or ".jpg"
            safe_name = f"approved_{int(time.time())}_{random.randint(1000,9999)}{ext}"
            dest = BACKGROUND_DIR / safe_name
            shutil.copy2(bg_path, dest)
            return str(dest), f"✅ Background approved → {dest.name}", gr.update(visible=False)
        except Exception as e:
            return None, f"⚠️ Approve failed: {e}", gr.update(visible=False)

    css = """
    .gradio-container { font-size: 16px !important; }
    label { font-size: 18px !important; font-weight: 600 !important; color: #2d3748 !important; }
    .process-button { font-size: 20px !important; font-weight: 700 !important; padding: 16px 28px !important; }
    .memory-info { background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 12px; }
    """

    with gr.Blocks(title="Enhanced Video Background Replacement", theme=gr.themes.Soft(), css=css) as interface:
        gr.Markdown("# 🎬 Enhanced Video Background Replacement")
        gr.Markdown("_SAM2 + MatAnyone + AI Backgrounds — with strict tensor shapes & memory management_")

        gr.HTML(f"""
        <div class='memory-info'>
        <strong>Device:</strong> {DEVICE} &nbsp;&nbsp;
        <strong>PyTorch:</strong> {'✅' if TORCH_AVAILABLE else '❌'} &nbsp;&nbsp;
        <strong>CUDA:</strong> {'✅' if CUDA_AVAILABLE else '❌'}
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                video_input = gr.Video(label="Input Video")

                gr.Markdown("### Background")
                bg_method = gr.Radio(choices=["Upload Image", "Gradients", "AI Generated"],
                                     value="AI Generated", label="Background Method")

                # Upload group (hidden by default)
                with gr.Group(visible=False) as upload_group:
                    upload_img = gr.Image(label="Background Image", type="filepath")

                # Gradient group (hidden by default)
                with gr.Group(visible=False) as gradient_group:
                    gradient_choice = gr.Dropdown(label="Gradient Style",
                                                  choices=list(GRADIENT_PRESETS.keys()),
                                                  value="Slate")

                # AI group (visible by default)
                with gr.Group(visible=True) as ai_group:
                    prompt_suggestions = gr.Dropdown(label="💡 Prompt Inspiration",
                                                     choices=AI_PROMPT_SUGGESTIONS,
                                                     value="Custom (write your own)")
                    ai_prompt = gr.Textbox(label="Background Description",
                                           value="professional office background", lines=3)
                    ai_model = gr.Radio(["SDXL", "Playground v2.5", "SD 1.5 (fallback)", "OpenAI (gpt-image-1)"],
                                        value="SDXL", label="AI Model")
                    with gr.Accordion("Connect services (optional)", open=False):
                        openai_api_key = gr.Textbox(label="OpenAI API Key", type="password",
                                                    placeholder="sk-... (kept only in this session)")
                    with gr.Row():
                        ai_steps = gr.Slider(10, 50, value=30, step=1, label="Quality (steps)")
                        ai_guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.1, label="Guidance")
                    ai_seed = gr.Number(label="Seed (optional)", precision=0)
                    force_gpu_ai = gr.Checkbox(value=True, label="Force GPU for AI background")
                    preload_btn = gr.Button("📦 Preload Models")
                    preload_status = gr.Textbox(label="Preload Status", lines=4)
                    generate_bg_btn = gr.Button("Generate AI Background", variant="primary")
                    ai_generated_bg = gr.Image(label="Generated Background", type="filepath")
                    approve_bg_btn = gr.Button("✅ Approve Background", visible=False)
                    approved_background_path = gr.State(value=None)
                    last_generated_bg = gr.State(value=None)
                    ai_status = gr.Textbox(label="Generation Status", lines=2)

                gr.Markdown("### Processing")
                with gr.Row():
                    trim_enabled = gr.Checkbox(label="Trim Video", value=False)
                    trim_seconds = gr.Number(label="Trim Duration (seconds)", value=5, precision=1)
                with gr.Row():
                    crf_value = gr.Slider(0, 30, value=18, step=1, label="Quality (CRF - lower=better)")
                    audio_enabled = gr.Checkbox(label="Preserve Audio", value=True)
                with gr.Row():
                    use_chunked = gr.Checkbox(label="Use Chunked Processing", value=False)

                gr.Markdown("### Subject Placement")
                with gr.Row():
                    place_x = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Horizontal")
                    place_y = gr.Slider(0.0, 1.0, value=0.75, step=0.01, label="Vertical")
                with gr.Row():
                    place_scale = gr.Slider(0.3, 2.0, value=1.0, step=0.01, label="Scale")
                    place_feather = gr.Slider(0, 15, value=3, step=1, label="Edge feather (px)")

                process_btn = gr.Button("🚀 Process Video", variant="primary", elem_classes=["process-button"])

                gr.Markdown("### System")
                with gr.Row():
                    diagnostics_btn = gr.Button("📊 System Diagnostics")
                    cleanup_btn = gr.Button("🧹 Memory Cleanup")
                diagnostics_output = gr.Textbox(label="System Status", lines=10)

            with gr.Column(scale=1):
                output_video = gr.Video(label="Processed Video")
                # NB: this local name shadows the module-level download_file() helper
                # inside create_interface; the helper is not used in this scope.
                download_file = gr.File(label="Download Processed Video")
                status_output = gr.Textbox(label="Processing Status", lines=20)

        # --- Wiring ---
        def update_background_visibility(method):
            return (
                gr.update(visible=(method == "Upload Image")),
                gr.update(visible=(method == "Gradients")),
                gr.update(visible=(method == "AI Generated")),
            )

        def update_prompt_from_suggestion(suggestion):
            if suggestion == "Custom (write your own)":
                return gr.update(value="", placeholder="Describe the background you want...")
            return gr.update(value=suggestion)

        bg_method.change(
            update_background_visibility,
            inputs=[bg_method],
            outputs=[upload_group, gradient_group, ai_group]
        )
        prompt_suggestions.change(update_prompt_from_suggestion, inputs=[prompt_suggestions], outputs=[ai_prompt])

        preload_btn.click(preload,
                          inputs=[ai_model, openai_api_key, force_gpu_ai],
                          outputs=[preload_status],
                          show_progress=True
                          )

        generate_bg_btn.click(
            generate_background_safe,
            inputs=[video_input, ai_prompt, ai_steps, ai_guidance, ai_seed, ai_model, openai_api_key, force_gpu_ai],
            outputs=[ai_generated_bg, ai_status, approve_bg_btn, last_generated_bg],
            show_progress=True
        )
        approve_bg_btn.click(
            approve_background,
            inputs=[ai_generated_bg],
            outputs=[approved_background_path, ai_status, approve_bg_btn]
        )

        diagnostics_btn.click(diag, outputs=[diagnostics_output])
        cleanup_btn.click(cleanup, outputs=[diagnostics_output])

        def process_video(
            video_file,
            bg_method,
            upload_img,
            gradient_choice,
            approved_background_path,
            last_generated_bg,
            trim_enabled, trim_seconds, crf_value, audio_enabled,
            use_chunked,
            place_x, place_y, place_scale, place_feather,
            progress=gr.Progress(track_tqdm=True),
        ):
            try:
                if not video_file:
                    return None, None, "Please upload a video file"
                video_path = video_file.name if hasattr(video_file, 'name') else str(video_file)

                # Resolve background
                bg_path = None
                try:
                    if bg_method == "Upload Image" and upload_img:
                        bg_path = upload_img if isinstance(upload_img, str) else getattr(upload_img, "name", None)
                    elif bg_method == "Gradients":
                        cap = cv2.VideoCapture(video_path)
                        ret, frame = cap.read(); cap.release()
                        if ret and frame is not None:
                            h, w = frame.shape[:2]
                            if gradient_choice in GRADIENT_PRESETS:
                                grad = _make_vertical_gradient(w, h, *GRADIENT_PRESETS[gradient_choice])
                                tmp_bg = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, dir=TEMP_DIR).name
                                cv2.imwrite(tmp_bg, grad)
                                memory_manager.register_temp_file(tmp_bg)
                                bg_path = tmp_bg
                    else:  # AI Generated
                        if approved_background_path:
                            bg_path = approved_background_path
                        elif last_generated_bg and isinstance(last_generated_bg, str) and os.path.exists(last_generated_bg):
                            bg_path = last_generated_bg
                except Exception as e:
                    logger.error(f"Background setup error: {e}")
                    return None, None, f"Background setup failed: {str(e)}"

                result_path, status = process_video_main(
                    video_path=video_path,
                    background_path=bg_path,
                    trim_duration=float(trim_seconds) if (trim_enabled and float(trim_seconds) > 0) else None,
                    crf=int(crf_value),
                    preserve_audio_flag=bool(audio_enabled),
                    placement=dict(x=float(place_x), y=float(place_y), scale=float(place_scale), feather=int(place_feather)),
                    use_chunked_processing=bool(use_chunked),
                    progress=progress,
                )

                if result_path and os.path.exists(result_path):
                    return result_path, result_path, f"✅ Success\n\n{status}"
                else:
                    return None, None, f"❌ Failed\n\n{status or 'Unknown error'}"
            except Exception as e:
                tb = traceback.format_exc()
                return None, None, f"❌ Crash: {e}\n\n{tb}"

        process_btn.click(
            process_video,
            inputs=[
                video_input,
                bg_method,
                upload_img,
                gradient_choice,
                approved_background_path, last_generated_bg,
                trim_enabled, trim_seconds, crf_value, audio_enabled,
                use_chunked,
                place_x, place_y, place_scale, place_feather,
            ],
            outputs=[output_video, download_file, status_output],
            show_progress=True
        )

    return interface

# =============================================================================
# CHAPTER 11: MAIN
# =============================================================================
def main():
    logger.info("Starting Enhanced Background Replacement")
    stats = memory_manager.get_memory_stats()
    logger.info(f"Initial memory: CPU {stats.cpu_memory_mb:.1f}MB, GPU {stats.gpu_memory_mb:.1f}MB")
    interface = create_interface()
    interface.queue(max_size=3)
    try:
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            inbrowser=False,
            show_error=True
        )
    finally:
        logger.info("Shutting down - cleanup")
        memory_manager.cleanup_temp_files()
        memory_manager.aggressive_cleanup()

if __name__ == "__main__":
    main()