MogensR commited on
Commit
efe9b1b
·
1 Parent(s): 3790c48

Update utils/cv_processing.py

Browse files
Files changed (1) hide show
  1. utils/cv_processing.py +90 -1349
utils/cv_processing.py CHANGED
@@ -1,1370 +1,111 @@
 
1
  """
2
- Computer Vision Processing Module for BackgroundFX Pro
3
- Contains segmentation, mask refinement, background replacement, and helper functions
 
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
- # ---- Early thread env (defensive) ----
7
- import os
8
- if 'OMP_NUM_THREADS' not in os.environ:
9
- os.environ['OMP_NUM_THREADS'] = '4'
10
- os.environ['MKL_NUM_THREADS'] = '4'
11
-
12
- import logging
13
- from typing import Optional, Tuple, Dict, Any
14
- import numpy as np
15
- import cv2
16
- import torch
 
 
 
 
 
 
 
 
 
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
- # ============================================================================
21
- # CONFIGURATION AND CONSTANTS
22
- # ============================================================================
23
-
24
- # Version control flags for CV functions
25
- USE_ENHANCED_SEGMENTATION = True
26
- USE_AUTO_TEMPORAL_CONSISTENCY = True # reserved for future temporal smoothing
27
- USE_INTELLIGENT_PROMPTING = True
28
- USE_ITERATIVE_REFINEMENT = True
29
-
30
- # Validator thresholds (softened to avoid false negatives)
31
- MIN_AREA_RATIO = 0.015 # 1.5% of frame
32
- MAX_AREA_RATIO = 0.97 # 97% of frame
33
-
34
- # GrabCut / saliency config
35
- GRABCUT_ITERS = 3
36
- SALIENCY_THRESH = 0.65
37
-
38
- # Professional background templates
39
- PROFESSIONAL_BACKGROUNDS = {
40
- "office_modern": {
41
- "name": "Modern Office",
42
- "type": "gradient",
43
- "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
44
- "direction": "diagonal",
45
- "description": "Clean, contemporary office environment",
46
- "brightness": 0.95,
47
- "contrast": 1.1
48
- },
49
- "studio_blue": {
50
- "name": "Professional Blue",
51
- "type": "gradient",
52
- "colors": ["#1e3c72", "#2a5298", "#3498db"],
53
- "direction": "radial",
54
- "description": "Broadcast-quality blue studio",
55
- "brightness": 0.9,
56
- "contrast": 1.2
57
- },
58
- "studio_green": {
59
- "name": "Broadcast Green",
60
- "type": "color",
61
- "colors": ["#00b894"],
62
- "chroma_key": True,
63
- "description": "Professional green screen replacement",
64
- "brightness": 1.0,
65
- "contrast": 1.0
66
- },
67
- "minimalist": {
68
- "name": "Minimalist White",
69
- "type": "gradient",
70
- "colors": ["#ffffff", "#f1f2f6", "#ddd"],
71
- "direction": "soft_radial",
72
- "description": "Clean, minimal background",
73
- "brightness": 0.98,
74
- "contrast": 0.9
75
- },
76
- "warm_gradient": {
77
- "name": "Warm Sunset",
78
- "type": "gradient",
79
- "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
80
- "direction": "diagonal",
81
- "description": "Warm, inviting atmosphere",
82
- "brightness": 0.85,
83
- "contrast": 1.15
84
- },
85
- "tech_dark": {
86
- "name": "Tech Dark",
87
- "type": "gradient",
88
- "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
89
- "direction": "vertical",
90
- "description": "Modern tech/gaming setup",
91
- "brightness": 0.7,
92
- "contrast": 1.3
93
- }
94
- }
95
-
96
- # ============================================================================
97
- # CUSTOM EXCEPTIONS
98
- # ============================================================================
99
-
100
- class SegmentationError(Exception):
101
- """Custom exception for segmentation failures"""
102
- pass
103
-
104
- class MaskRefinementError(Exception):
105
- """Custom exception for mask refinement failures"""
106
- pass
107
-
108
- class BackgroundReplacementError(Exception):
109
- """Custom exception for background replacement failures"""
110
- pass
111
-
112
- # ============================================================================
113
- # LETTERBOX FIT (RGB in, RGB out) for custom background images
114
- # ============================================================================
115
-
116
- def _fit_image_letterbox(img_rgb: np.ndarray, dst_w: int, dst_h: int, fill=(32, 32, 32)) -> np.ndarray:
117
- h, w = img_rgb.shape[:2]
118
- if h == 0 or w == 0:
119
- return np.full((dst_h, dst_w, 3), fill, dtype=np.uint8)
120
-
121
- src_aspect = w / max(1, h)
122
- dst_aspect = dst_w / max(1, dst_h)
123
-
124
- if src_aspect > dst_aspect:
125
- new_w = dst_w
126
- new_h = int(round(dst_w / src_aspect))
127
- else:
128
- new_h = dst_h
129
- new_w = int(round(dst_h * src_aspect))
130
-
131
- resized = cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
132
- canvas = np.full((dst_h, dst_w, 3), fill, dtype=np.uint8)
133
- y0 = (dst_h - new_h) // 2
134
- x0 = (dst_w - new_w) // 2
135
- canvas[y0:y0+new_h, x0:x0+new_w] = resized
136
- return canvas
137
-
138
- # ============================================================================
139
- # MAIN SEGMENTATION FUNCTIONS
140
- # ============================================================================
141
-
142
- def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
143
- """High-quality person segmentation with intelligent automation and robust cascade"""
144
- if not USE_ENHANCED_SEGMENTATION:
145
- return segment_person_hq_original(image, predictor, fallback_enabled)
146
-
147
- logger.debug("Using ENHANCED segmentation with intelligent automation")
148
-
149
- if image is None or image.size == 0:
150
- raise SegmentationError("Invalid input image")
151
-
152
- try:
153
- # 1) SAM2 (if available)
154
- if predictor and hasattr(predictor, 'set_image') and hasattr(predictor, 'predict'):
155
- try:
156
- predictor.set_image(image)
157
- if USE_INTELLIGENT_PROMPTING:
158
- mask = _segment_with_intelligent_prompts(image, predictor, fallback_enabled=True)
159
- else:
160
- mask = _segment_with_basic_prompts(image, predictor, fallback_enabled=True)
161
-
162
- if USE_ITERATIVE_REFINEMENT and mask is not None:
163
- mask = _auto_refine_mask_iteratively(image, mask, predictor)
164
-
165
- if _validate_mask_quality(mask, image.shape[:2]):
166
- logger.debug("SAM2 mask accepted by validator")
167
- return mask
168
- logger.warning("SAM2 mask failed validation; cascading to classical methods.")
169
- except Exception as e:
170
- logger.warning(f"SAM2 segmentation error: {e}")
171
-
172
- # 2) Classical cascade when SAM2 is absent/weak
173
- classical = _classical_segmentation_cascade(image)
174
- if _validate_mask_quality(classical, image.shape[:2]):
175
- logger.debug("Classical cascade mask accepted by validator")
176
- return classical
177
-
178
- logger.warning("Classical cascade produced weak mask; using geometric fallback.")
179
- return _geometric_person_mask(image)
180
-
181
- except Exception as e:
182
- logger.error(f"Unexpected segmentation error: {e}")
183
- if fallback_enabled:
184
- return _geometric_person_mask(image)
185
- else:
186
- raise SegmentationError(f"Unexpected error: {e}")
187
-
188
- def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
189
- """Original version of person segmentation for rollback"""
190
- if image is None or image.size == 0:
191
- raise SegmentationError("Invalid input image")
192
-
193
- try:
194
- # SAFE PREDICTOR CHECK
195
- if predictor and hasattr(predictor, 'set_image') and hasattr(predictor, 'predict'):
196
- h, w = image.shape[:2]
197
- predictor.set_image(image)
198
-
199
- points = np.array([
200
- [w//2, h//4],
201
- [w//2, h//2],
202
- [w//2, 3*h//4],
203
- [w//3, h//2],
204
- [2*w//3, h//2],
205
- [w//2, h//6],
206
- [w//4, 2*h//3],
207
- [3*w//4, 2*h//3],
208
- ], dtype=np.float32)
209
-
210
- labels = np.ones(len(points), dtype=np.int32)
211
-
212
- with torch.no_grad():
213
- masks, scores, _ = predictor.predict(
214
- point_coords=points,
215
- point_labels=labels,
216
- multimask_output=True
217
- )
218
-
219
- if masks is None or len(masks) == 0:
220
- logger.warning("SAM2 returned no masks")
221
- else:
222
- best_idx = np.argmax(scores) if (scores is not None and len(scores) > 0) else 0
223
- best_mask = masks[best_idx]
224
- mask = _process_mask(best_mask)
225
- if _validate_mask_quality(mask, image.shape[:2]):
226
- logger.debug("Original SAM2 mask accepted by validator")
227
- return mask
228
-
229
- if fallback_enabled:
230
- logger.warning("Falling back to classical segmentation")
231
- return _classical_segmentation_cascade(image)
232
- else:
233
- raise SegmentationError("SAM2 failed and fallback disabled")
234
-
235
- except Exception as e:
236
- logger.error(f"Unexpected segmentation error: {e}")
237
- if fallback_enabled:
238
- return _classical_segmentation_cascade(image)
239
- else:
240
- raise SegmentationError(f"Unexpected error: {e}")
241
-
242
- # ============================================================================
243
- # MASK REFINEMENT FUNCTIONS
244
- # ============================================================================
245
-
246
- def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
247
- fallback_enabled: bool = True) -> np.ndarray:
248
- """Enhanced mask refinement with MatAnyone and robust fallbacks"""
249
- if image is None or mask is None:
250
- raise MaskRefinementError("Invalid input image or mask")
251
-
252
- try:
253
- mask = _process_mask(mask)
254
-
255
- # 1) MatAnyOne (if present)
256
- if matanyone_processor is not None:
257
- try:
258
- logger.debug("Attempting MatAnyone refinement")
259
- refined_mask = _matanyone_refine(image, mask, matanyone_processor)
260
-
261
- if refined_mask is not None and _validate_mask_quality(refined_mask, image.shape[:2]):
262
- logger.debug("MatAnyone refinement successful")
263
- return refined_mask
264
- else:
265
- logger.warning("MatAnyOne produced poor quality mask")
266
-
267
- except Exception as e:
268
- logger.warning(f"MatAnyOne refinement failed: {e}")
269
-
270
- # 2) Advanced OpenCV refinement
271
- try:
272
- logger.debug("Using enhanced OpenCV refinement")
273
- opencv_mask = enhance_mask_opencv_advanced(image, mask)
274
- if _validate_mask_quality(opencv_mask, image.shape[:2]):
275
- return opencv_mask
276
- except Exception as e:
277
- logger.warning(f"OpenCV advanced refinement failed: {e}")
278
-
279
- # 3) GrabCut refinement (auto rect from saliency)
280
- try:
281
- logger.debug("Using GrabCut refinement fallback")
282
- gc_mask = _refine_with_grabcut(image, mask)
283
- if _validate_mask_quality(gc_mask, image.shape[:2]):
284
- return gc_mask
285
- except Exception as e:
286
- logger.warning(f"GrabCut refinement failed: {e}")
287
-
288
- # 4) Saliency flood-fill refinement
289
- try:
290
- logger.debug("Using saliency refinement fallback")
291
- sal_mask = _refine_with_saliency(image, mask)
292
- if _validate_mask_quality(sal_mask, image.shape[:2]):
293
- return sal_mask
294
- except Exception as e:
295
- logger.warning(f"Saliency refinement failed: {e}")
296
-
297
- if fallback_enabled:
298
- logger.debug("Returning original mask after failed refinements")
299
- return mask
300
- else:
301
- raise MaskRefinementError("All refinements failed")
302
-
303
- except MaskRefinementError:
304
- raise
305
- except Exception as e:
306
- logger.error(f"Unexpected mask refinement error: {e}")
307
- if fallback_enabled:
308
- return enhance_mask_opencv_advanced(image, mask)
309
- else:
310
- raise MaskRefinementError(f"Unexpected error: {e}")
311
-
312
- def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
313
- """Advanced OpenCV-based mask enhancement with multiple techniques"""
314
- try:
315
- if len(mask.shape) == 3:
316
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
317
-
318
- if mask.max() <= 1.0:
319
- mask = (mask * 255).astype(np.uint8)
320
-
321
- refined_mask = cv2.bilateralFilter(mask, 9, 75, 75)
322
- refined_mask = _guided_filter_approx(image, refined_mask, radius=8, eps=0.2)
323
-
324
- kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
325
- refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel_close)
326
-
327
- kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
328
- refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_OPEN, kernel_open)
329
-
330
- refined_mask = cv2.GaussianBlur(refined_mask, (3, 3), 0.8)
331
-
332
- _, refined_mask = cv2.threshold(refined_mask, 127, 255, cv2.THRESH_BINARY)
333
-
334
- return refined_mask
335
-
336
- except Exception as e:
337
- logger.warning(f"Enhanced OpenCV refinement failed: {e}")
338
- return cv2.GaussianBlur(mask, (5, 5), 1.0)
339
-
340
- # ============================================================================
341
- # MATANYONE REFINEMENT (SAFE)
342
- # ============================================================================
343
-
344
- def _matanyone_refine(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any) -> Optional[np.ndarray]:
345
- """Safe MatAnyOne refinement for a single frame with correct interface."""
346
- try:
347
- # Check for correct MatAnyOne interface
348
- if not hasattr(matanyone_processor, 'step') or not hasattr(matanyone_processor, 'output_prob_to_mask'):
349
- logger.warning("MatAnyOne processor missing required methods (step, output_prob_to_mask)")
350
- return None
351
-
352
- # Preprocess image: ensure float32, RGB, (C, H, W)
353
- if isinstance(image, np.ndarray):
354
- img = image.astype(np.float32)
355
- if img.max() > 1.0:
356
- img /= 255.0
357
- if img.shape[2] == 3:
358
- img = np.transpose(img, (2, 0, 1)) # (H, W, C) → (C, H, W)
359
- img_tensor = torch.from_numpy(img)
360
- else:
361
- img_tensor = image # assume already tensor
362
-
363
- # Preprocess mask: ensure float32, (H, W)
364
- if isinstance(mask, np.ndarray):
365
- mask_tensor = mask.astype(np.float32)
366
- if mask_tensor.max() > 1.0:
367
- mask_tensor /= 255.0
368
- if mask_tensor.ndim > 2:
369
- mask_tensor = mask_tensor.squeeze()
370
- mask_tensor = torch.from_numpy(mask_tensor)
371
- else:
372
- mask_tensor = mask
373
-
374
- # Move tensors to processor's device if available
375
- device = getattr(matanyone_processor, 'device', 'cpu')
376
- img_tensor = img_tensor.to(device)
377
- mask_tensor = mask_tensor.to(device)
378
-
379
- # Step: encode mask on this frame
380
- objects = [1] # single object id
381
- with torch.no_grad():
382
- output_prob = matanyone_processor.step(img_tensor, mask_tensor, objects=objects)
383
- # MatAnyOne returns output_prob as tensor
384
-
385
- refined_mask_tensor = matanyone_processor.output_prob_to_mask(output_prob)
386
-
387
- # Convert to numpy and to uint8
388
- refined_mask = refined_mask_tensor.squeeze().detach().cpu().numpy()
389
- if refined_mask.max() <= 1.0:
390
- refined_mask = (refined_mask * 255).astype(np.uint8)
391
- else:
392
- refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)
393
-
394
- logger.debug("MatAnyOne refinement successful")
395
- return refined_mask
396
-
397
- except Exception as e:
398
- logger.warning(f"MatAnyOne refinement error: {e}")
399
- return None
400
-
401
- # ============================================================================
402
- # BACKGROUND REPLACEMENT FUNCTIONS
403
- # ============================================================================
404
-
405
- def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
406
- fallback_enabled: bool = True) -> np.ndarray:
407
- """Enhanced background replacement with comprehensive error handling"""
408
- if frame is None or mask is None or background is None:
409
- raise BackgroundReplacementError("Invalid input frame, mask, or background")
410
-
411
- try:
412
- background = cv2.resize(background, (frame.shape[1], frame.shape[0]),
413
- interpolation=cv2.INTER_LANCZOS4)
414
-
415
- if len(mask.shape) == 3:
416
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
417
-
418
- if mask.dtype != np.uint8:
419
- mask = mask.astype(np.uint8)
420
-
421
- if mask.max() <= 1.0:
422
- logger.debug("Converting normalized mask to 0-255 range")
423
- mask = (mask * 255).astype(np.uint8)
424
-
425
- try:
426
- result = _advanced_compositing(frame, mask, background)
427
- logger.debug("Advanced compositing successful")
428
- return result
429
-
430
- except Exception as e:
431
- logger.warning(f"Advanced compositing failed: {e}")
432
- if fallback_enabled:
433
- return _simple_compositing(frame, mask, background)
434
- else:
435
- raise BackgroundReplacementError(f"Advanced compositing failed: {e}")
436
-
437
- except BackgroundReplacementError:
438
- raise
439
- except Exception as e:
440
- logger.error(f"Unexpected background replacement error: {e}")
441
- if fallback_enabled:
442
- return _simple_compositing(frame, mask, background)
443
- else:
444
- raise BackgroundReplacementError(f"Unexpected error: {e}")
445
-
446
- def create_professional_background(bg_config: Dict[str, Any] | str, width: int, height: int) -> np.ndarray:
447
  """
448
- Enhanced professional background creation with quality improvements.
449
- Accepts style string or dict (can include custom_path). Returns BGR (OpenCV).
450
  """
451
- try:
452
- choice = "minimalist"
453
- custom_path = None
454
-
455
- if isinstance(bg_config, dict):
456
- choice = bg_config.get("background_choice", bg_config.get("name", "minimalist"))
457
- custom_path = bg_config.get("custom_path")
458
-
459
- # Custom background path (letterboxed + BGR out)
460
- if custom_path and os.path.exists(custom_path):
461
- img_bgr = cv2.imread(custom_path, cv2.IMREAD_COLOR)
462
- if img_bgr is not None:
463
- img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
464
- fitted_rgb = _fit_image_letterbox(img_rgb, width, height, fill=(32, 32, 32))
465
- fitted_bgr = cv2.cvtColor(fitted_rgb, cv2.COLOR_RGB2BGR)
466
- return fitted_bgr
467
- else:
468
- logger.warning(f"Failed to read custom background at {custom_path}. Falling back to style.")
469
-
470
- # Direct dict colors/type form support
471
- if "type" in bg_config and "colors" in bg_config:
472
- if bg_config["type"] == "color":
473
- background = _create_solid_background(bg_config, width, height)
474
- else:
475
- background = _create_gradient_background_enhanced(bg_config, width, height)
476
- background = _apply_background_adjustments(background, bg_config)
477
- return background
478
-
479
- elif isinstance(bg_config, str):
480
- choice = bg_config
481
-
482
- choice = (choice or "minimalist").lower()
483
- if choice not in PROFESSIONAL_BACKGROUNDS:
484
- choice = "minimalist"
485
-
486
- cfg = PROFESSIONAL_BACKGROUNDS[choice]
487
-
488
- if cfg.get("type") == "color":
489
- background = _create_solid_background(cfg, width, height)
490
- else:
491
- background = _create_gradient_background_enhanced(cfg, width, height)
492
-
493
- background = _apply_background_adjustments(background, cfg)
494
- return background
495
-
496
- except Exception as e:
497
- logger.error(f"Background creation error: {e}")
498
- return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
499
-
500
- # ============================================================================
501
- # VALIDATION FUNCTION
502
- # ============================================================================
503
-
504
- def validate_video_file(video_path: str) -> Tuple[bool, str]:
505
- """Enhanced video file validation with detailed checks"""
506
- if not video_path or not os.path.exists(video_path):
507
  return False, "Video file not found"
508
 
509
  try:
510
- file_size = os.path.getsize(video_path)
511
- if file_size == 0:
512
- return False, "Video file is empty"
513
-
514
- if file_size > 2 * 1024 * 1024 * 1024:
515
- return False, "Video file too large (>2GB)"
516
 
517
  cap = cv2.VideoCapture(video_path)
518
  if not cap.isOpened():
519
- return False, "Cannot open video file"
520
-
521
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
522
- fps = cap.get(cv2.CAP_PROP_FPS)
523
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
524
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
525
 
 
 
 
 
526
  cap.release()
527
 
528
- if frame_count == 0:
529
- return False, "Video appears to be empty (0 frames)"
530
-
531
  if fps <= 0 or fps > 120:
532
- return False, f"Invalid frame rate: {fps}"
533
-
534
- if width <= 0 or height <= 0:
535
- return False, f"Invalid resolution: {width}x{height}"
536
-
537
- if width > 4096 or height > 4096:
538
- return False, f"Resolution too high: {width}x{height} (max 4096x4096)"
539
-
540
- duration = frame_count / fps
541
- if duration > 300:
542
- return False, f"Video too long: {duration:.1f}s (max 300s)"
543
-
544
- return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"
545
-
546
- except Exception as e:
547
- return False, f"Error validating video: {str(e)}"
548
-
549
- # ============================================================================
550
- # HELPER FUNCTIONS - SEGMENTATION
551
- # ============================================================================
552
-
553
- def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
554
- """Intelligent automatic prompt generation for segmentation with safe predictor access"""
555
- try:
556
- # Double-check predictor validity
557
- if predictor is None or not hasattr(predictor, 'predict'):
558
- if fallback_enabled:
559
- return _classical_segmentation_cascade(image)
560
- else:
561
- raise SegmentationError("Invalid predictor in intelligent prompts")
562
-
563
- h, w = image.shape[:2]
564
- pos_points, neg_points = _generate_smart_prompts(image)
565
-
566
- if len(pos_points) == 0:
567
- pos_points = np.array([[w//2, h//2]], dtype=np.float32)
568
-
569
- points = np.vstack([pos_points, neg_points])
570
- labels = np.hstack([
571
- np.ones(len(pos_points), dtype=np.int32),
572
- np.zeros(len(neg_points), dtype=np.int32)
573
- ])
574
-
575
- logger.debug(f"Using {len(pos_points)} positive, {len(neg_points)} negative points")
576
-
577
- with torch.no_grad():
578
- masks, scores, _ = predictor.predict(
579
- point_coords=points,
580
- point_labels=labels,
581
- multimask_output=True
582
- )
583
-
584
- if masks is None or len(masks) == 0:
585
- raise SegmentationError("No masks generated")
586
-
587
- if scores is not None and len(scores) > 0:
588
- best_idx = np.argmax(scores)
589
- best_mask = masks[best_idx]
590
- logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
591
- else:
592
- best_mask = masks[0]
593
-
594
- return _process_mask(best_mask)
595
-
596
- except Exception as e:
597
- logger.error(f"Intelligent prompting failed: {e}")
598
- if fallback_enabled:
599
- return _classical_segmentation_cascade(image)
600
- else:
601
- raise
602
-
603
- def _segment_with_basic_prompts(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
604
- """Basic prompting method for segmentation with safe predictor access"""
605
- try:
606
- # Double-check predictor validity
607
- if predictor is None or not hasattr(predictor, 'predict'):
608
- if fallback_enabled:
609
- return _classical_segmentation_cascade(image)
610
- else:
611
- raise SegmentationError("Invalid predictor in basic prompts")
612
-
613
- h, w = image.shape[:2]
614
-
615
- positive_points = np.array([
616
- [w//2, h//3],
617
- [w//2, h//2],
618
- [w//2, 2*h//3],
619
- ], dtype=np.float32)
620
-
621
- negative_points = np.array([
622
- [w//10, h//10],
623
- [9*w//10, h//10],
624
- [w//10, 9*h//10],
625
- [9*w//10, 9*h//10],
626
- ], dtype=np.float32)
627
-
628
- points = np.vstack([positive_points, negative_points])
629
- labels = np.array([1, 1, 1, 0, 0, 0, 0], dtype=np.int32)
630
-
631
- with torch.no_grad():
632
- masks, scores, _ = predictor.predict(
633
- point_coords=points,
634
- point_labels=labels,
635
- multimask_output=True
636
- )
637
-
638
- if masks is None or len(masks) == 0:
639
- raise SegmentationError("No masks generated")
640
-
641
- best_idx = np.argmax(scores) if scores is not None and len(scores) > 0 else 0
642
- best_mask = masks[best_idx]
643
-
644
- return _process_mask(best_mask)
645
-
646
- except Exception as e:
647
- logger.error(f"Basic prompting failed: {e}")
648
- if fallback_enabled:
649
- return _classical_segmentation_cascade(image)
650
- else:
651
- raise
652
-
653
- def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
654
- """Generate optimal positive/negative points automatically"""
655
- try:
656
- h, w = image.shape[:2]
657
-
658
- saliency = _compute_saliency(image)
659
- positive_points = []
660
- if saliency is not None:
661
- saliency_thresh = (saliency > (SALIENCY_THRESH - 0.1)).astype(np.uint8) * 255
662
- contours, _ = cv2.findContours(saliency_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
663
-
664
- if contours:
665
- for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
666
- M = cv2.moments(contour)
667
- if M["m00"] != 0:
668
- cx = int(M["m10"] / M["m00"])
669
- cy = int(M["m01"] / M["m00"])
670
- if 0 < cx < w and 0 < cy < h:
671
- positive_points.append([cx, cy])
672
-
673
- if not positive_points:
674
- positive_points = [
675
- [w//2, h//3],
676
- [w//2, h//2],
677
- [w//2, 2*h//3],
678
- ]
679
-
680
- negative_points = [
681
- [10, 10],
682
- [w-10, 10],
683
- [10, h-10],
684
- [w-10, h-10],
685
- [w//2, 5],
686
- [w//2, h-5],
687
- ]
688
-
689
- return np.array(positive_points, dtype=np.float32), np.array(negative_points, dtype=np.float32)
690
-
691
- except Exception as e:
692
- logger.warning(f"Smart prompt generation failed: {e}")
693
- h, w = image.shape[:2]
694
- positive_points = np.array([[w//2, h//2]], dtype=np.float32)
695
- negative_points = np.array([[10, 10], [w-10, 10]], dtype=np.float32)
696
- return positive_points, negative_points
697
-
698
- # ============================================================================
699
- # CLASSICAL SEGMENTATION CASCADE
700
- # ============================================================================
701
-
702
- def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
703
- """
704
- Robust non-AI cascade:
705
- 1) Background subtraction via edge-median
706
- 2) Saliency flood-fill
707
- 3) GrabCut from auto-rect
708
- 4) Geometric ellipse (final fallback)
709
- """
710
- # 1) Background subtraction
711
- try:
712
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
713
-
714
- edge_pixels = np.concatenate([
715
- gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]
716
- ])
717
- bg_color = np.median(edge_pixels)
718
-
719
- diff = np.abs(gray.astype(float) - bg_color)
720
- mask = (diff > 30).astype(np.uint8) * 255
721
-
722
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
723
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
724
-
725
- if _validate_mask_quality(mask, image.shape[:2]):
726
- logger.info("Background subtraction fallback successful")
727
- return mask
728
-
729
- except Exception as e:
730
- logger.debug(f"Background subtraction fallback failed: {e}")
731
-
732
- # 2) Saliency flood-fill refinement
733
- try:
734
- sal_ref = _refine_with_saliency(image, mask if 'mask' in locals() else np.zeros(image.shape[:2], np.uint8))
735
- if _validate_mask_quality(sal_ref, image.shape[:2]):
736
- return sal_ref
737
- except Exception as e:
738
- logger.debug(f"Saliency cascade failed: {e}")
739
-
740
- # 3) GrabCut refinement
741
- try:
742
- gc_mask = _refine_with_grabcut(image, mask if 'mask' in locals() else np.zeros(image.shape[:2], np.uint8))
743
- if _validate_mask_quality(gc_mask, image.shape[:2]):
744
- return gc_mask
745
- except Exception as e:
746
- logger.debug(f"GrabCut cascade failed: {e}")
747
-
748
- # 4) Geometric final fallback
749
- logger.info("Using geometric fallback mask")
750
- return _geometric_person_mask(image)
751
-
752
- # ============================================================================
753
- # SALIENCY / GRABCUT HELPERS
754
- # ============================================================================
755
-
756
- def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
757
- try:
758
- if hasattr(cv2, "saliency"):
759
- sal = cv2.saliency.StaticSaliencySpectralResidual_create()
760
- ok, smap = sal.computeSaliency(image)
761
- if ok:
762
- smap = (smap - smap.min()) / max(1e-6, (smap.max() - smap.min()))
763
- return smap
764
- except Exception:
765
- pass
766
- # Fallback spectral-ish hint using DCT trick
767
- try:
768
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
769
- log = np.log(gray + 1e-6)
770
- dct = cv2.dct(log)
771
- dct[:5, :5] = 0
772
- recon = cv2.idct(dct)
773
- recon = (recon - recon.min()) / max(1e-6, (recon.max() - recon.min()))
774
- return recon
775
- except Exception:
776
- return None
777
-
778
- def _auto_person_rect(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
779
- sal = _compute_saliency(image)
780
- if sal is None:
781
- return None
782
- th = (sal > SALIENCY_THRESH).astype(np.uint8) * 255
783
- contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
784
- if not contours:
785
- return None
786
- c = max(contours, key=cv2.contourArea)
787
- x, y, w, h = cv2.boundingRect(c)
788
- # Inflate
789
- pad_x, pad_y = int(0.05*w), int(0.05*h)
790
- H, W = image.shape[:2]
791
- x = max(0, x - pad_x); y = max(0, y - pad_y)
792
- w = min(W - x, w + 2*pad_x); h = min(H - y, h + 2*pad_y)
793
- return (x, y, w, h)
794
-
795
- def _refine_with_grabcut(image: np.ndarray, seed_mask: np.ndarray) -> np.ndarray:
796
- h, w = image.shape[:2]
797
- gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8)
798
- sure_fg = (seed_mask > 200)
799
- gc_mask[sure_fg] = cv2.GC_FGD
800
-
801
- rect = _auto_person_rect(image)
802
- if rect is not None:
803
- x, y, rw, rh = rect
804
- else:
805
- rw, rh = int(w * 0.5), int(h * 0.7)
806
- x, y = (w - rw)//2, int(h*0.15)
807
-
808
- bgdModel = np.zeros((1, 65), np.float64)
809
- fgdModel = np.zeros((1, 65), np.float64)
810
-
811
- cv2.grabCut(image, gc_mask, (x, y, rw, rh), bgdModel, fgdModel, GRABCUT_ITERS, cv2.GC_INIT_WITH_MASK)
812
-
813
- mask_bin = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
814
- mask_bin = cv2.morphologyEx(mask_bin, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
815
- return mask_bin
816
-
817
- def _refine_with_saliency(image: np.ndarray, seed_mask: np.ndarray) -> np.ndarray:
818
- sal = _compute_saliency(image)
819
- if sal is None:
820
- return seed_mask
821
- th = (sal > SALIENCY_THRESH).astype(np.uint8) * 255
822
-
823
- # Anchor from seed center mass or center fallback
824
- ys, xs = np.where(seed_mask > 127)
825
- if len(ys) > 0:
826
- cx, cy = int(np.mean(xs)), int(np.mean(ys))
827
- else:
828
- h, w = image.shape[:2]
829
- cx, cy = w//2, h//2
830
-
831
- ff = th.copy()
832
- h, w = th.shape
833
- mask = np.zeros((h+2, w+2), np.uint8)
834
- cv2.floodFill(ff, mask, (cx, cy), 255, loDiff=5, upDiff=5, flags=4)
835
- ff = cv2.morphologyEx(ff, cv2.MORPH_CLOSE, np.ones((5,5), np.uint8))
836
- return ff
837
-
838
- # ============================================================================
839
- # HELPER FUNCTIONS - REFINEMENT
840
- # ============================================================================
841
-
842
- def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
843
- predictor: Any, max_iterations: int = 2) -> np.ndarray:
844
- """Automatically refine mask based on quality assessment with safe predictor access"""
845
- try:
846
- if predictor is None or not hasattr(predictor, 'predict'):
847
- logger.warning("Predictor invalid for iterative refinement, returning initial mask")
848
- return initial_mask
849
-
850
- current_mask = initial_mask.copy()
851
-
852
- for iteration in range(max_iterations):
853
- quality_score = _assess_mask_quality(current_mask, image)
854
- logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")
855
-
856
- if quality_score > 0.85:
857
- logger.debug(f"Quality sufficient after {iteration} iterations")
858
- break
859
-
860
- problem_areas = _find_mask_errors(current_mask, image)
861
-
862
- if np.any(problem_areas):
863
- corrective_points, corrective_labels = _generate_corrective_prompts(
864
- image, current_mask, problem_areas
865
- )
866
-
867
- if len(corrective_points) > 0:
868
- try:
869
- with torch.no_grad():
870
- masks, scores, _ = predictor.predict(
871
- point_coords=corrective_points,
872
- point_labels=corrective_labels,
873
- mask_input=current_mask[None, :, :],
874
- multimask_output=False
875
- )
876
-
877
- if masks is not None and len(masks) > 0:
878
- refined_mask = _process_mask(masks[0])
879
-
880
- if _assess_mask_quality(refined_mask, image) > quality_score:
881
- current_mask = refined_mask
882
- logger.debug(f"Improved mask in iteration {iteration}")
883
- else:
884
- logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
885
- break
886
-
887
- except Exception as e:
888
- logger.debug(f"Refinement iteration {iteration} failed: {e}")
889
- break
890
- else:
891
- logger.debug("No problem areas detected")
892
- break
893
-
894
- return current_mask
895
-
896
- except Exception as e:
897
- logger.warning(f"Iterative refinement failed: {e}")
898
- return initial_mask
899
-
900
- def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
901
- """Assess mask quality automatically"""
902
- try:
903
- h, w = image.shape[:2]
904
- scores = []
905
-
906
- mask_area = np.sum(mask > 127)
907
- total_area = h * w
908
- area_ratio = mask_area / total_area
909
-
910
- if 0.05 <= area_ratio <= 0.8:
911
- area_score = 1.0
912
- elif area_ratio < 0.05:
913
- area_score = area_ratio / 0.05
914
- else:
915
- area_score = max(0, 1.0 - (area_ratio - 0.8) / 0.2)
916
- scores.append(area_score)
917
-
918
- mask_binary = mask > 127
919
- if np.any(mask_binary):
920
- mask_center_y, mask_center_x = np.where(mask_binary)
921
- center_y = np.mean(mask_center_y) / h
922
- center_x = np.mean(mask_center_x) / w
923
-
924
- center_score = 1.0 - min(abs(center_x - 0.5), abs(center_y - 0.5))
925
- scores.append(center_score)
926
- else:
927
- scores.append(0.0)
928
-
929
- edges = cv2.Canny(mask, 50, 150)
930
- edge_density = np.sum(edges > 0) / total_area
931
- smoothness_score = max(0, 1.0 - edge_density * 10)
932
- scores.append(smoothness_score)
933
-
934
- num_labels, _ = cv2.connectedComponents(mask)
935
- connectivity_score = max(0, 1.0 - (num_labels - 2) * 0.2)
936
- scores.append(connectivity_score)
937
-
938
- weights = [0.3, 0.2, 0.3, 0.2]
939
- overall_score = np.average(scores, weights=weights)
940
-
941
- return overall_score
942
-
943
- except Exception as e:
944
- logger.warning(f"Quality assessment failed: {e}")
945
- return 0.5
946
-
947
- def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
948
- """Identify problematic areas in mask"""
949
- try:
950
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
951
- edges = cv2.Canny(gray, 50, 150)
952
- mask_edges = cv2.Canny(mask, 50, 150)
953
- edge_discrepancy = cv2.bitwise_xor(edges, mask_edges)
954
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
955
- error_regions = cv2.dilate(edge_discrepancy, kernel, iterations=1)
956
- return error_regions > 0
957
- except Exception as e:
958
- logger.warning(f"Error detection failed: {e}")
959
- return np.zeros_like(mask, dtype=bool)
960
-
961
- def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
962
- problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
963
- """Generate corrective prompts based on problem areas"""
964
- try:
965
- contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
966
- cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
967
-
968
- corrective_points = []
969
- corrective_labels = []
970
-
971
- for contour in contours:
972
- if cv2.contourArea(contour) > 100:
973
- M = cv2.moments(contour)
974
- if M["m00"] != 0:
975
- cx = int(M["m10"] / M["m00"])
976
- cy = int(M["m01"] / M["m00"])
977
-
978
- current_mask_value = mask[cy, cx]
979
-
980
- if current_mask_value < 127:
981
- corrective_points.append([cx, cy])
982
- corrective_labels.append(1)
983
- else:
984
- corrective_points.append([cx, cy])
985
- corrective_labels.append(0)
986
-
987
- return (np.array(corrective_points, dtype=np.float32) if corrective_points else np.array([]).reshape(0, 2),
988
- np.array(corrective_labels, dtype=np.int32) if corrective_labels else np.array([], dtype=np.int32))
989
-
990
- except Exception as e:
991
- logger.warning(f"Corrective prompt generation failed: {e}")
992
- return np.array([]).reshape(0, 2), np.array([], dtype=np.int32)
993
-
994
- # ============================================================================
995
- # HELPER FUNCTIONS - PROCESSING
996
- # ============================================================================
997
-
998
- def _process_mask(mask: np.ndarray) -> np.ndarray:
999
- """Process raw mask to ensure correct format and range"""
1000
- try:
1001
- if len(mask.shape) > 2:
1002
- mask = mask.squeeze()
1003
-
1004
- if len(mask.shape) > 2:
1005
- mask = mask[:, :, 0] if mask.shape[2] > 0 else mask.sum(axis=2)
1006
-
1007
- if mask.dtype == bool:
1008
- mask = mask.astype(np.uint8) * 255
1009
- elif mask.dtype == np.float32 or mask.dtype == np.float64:
1010
- if mask.max() <= 1.0:
1011
- mask = (mask * 255).astype(np.uint8)
1012
- else:
1013
- mask = np.clip(mask, 0, 255).astype(np.uint8)
1014
- else:
1015
- mask = mask.astype(np.uint8)
1016
-
1017
- kernel = np.ones((3, 3), np.uint8)
1018
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
1019
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
1020
-
1021
- _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
1022
-
1023
- return mask
1024
-
1025
- except Exception as e:
1026
- logger.error(f"Mask processing failed: {e}")
1027
- h, w = mask.shape[:2] if (mask is not None and hasattr(mask, 'shape') and len(mask.shape) >= 2) else (256, 256)
1028
- fallback = np.zeros((h, w), dtype=np.uint8)
1029
- fallback[h//4:3*h//4, w//4:3*w//4] = 255
1030
- return fallback
1031
-
1032
- def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool:
1033
- """Validate that the mask meets quality criteria (soft reject policy)"""
1034
- try:
1035
- h, w = image_shape
1036
- mask_area = np.sum(mask > 127)
1037
- total_area = h * w
1038
-
1039
- area_ratio = mask_area / total_area
1040
- if area_ratio < MIN_AREA_RATIO or area_ratio > MAX_AREA_RATIO:
1041
- logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}")
1042
- return False
1043
-
1044
- mask_binary = mask > 127
1045
- mask_center_y, mask_center_x = np.where(mask_binary)
1046
-
1047
- if len(mask_center_y) == 0:
1048
- logger.warning("Empty mask")
1049
- return False
1050
-
1051
- center_y = np.mean(mask_center_y)
1052
- # Advisory only (we no longer hard-reject based on center)
1053
- if center_y < h * 0.08 or center_y > h * 0.98:
1054
- logger.warning(f"Mask center unusual (advisory): y={center_y/h:.2f}")
1055
-
1056
- return True
1057
-
1058
- except Exception as e:
1059
- logger.warning(f"Mask validation error: {e}")
1060
- return True
1061
-
1062
- def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
1063
- """Legacy fallback segmentation; prefer _classical_segmentation_cascade"""
1064
- try:
1065
- logger.info("Using fallback segmentation strategy")
1066
- h, w = image.shape[:2]
1067
-
1068
- try:
1069
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
1070
-
1071
- edge_pixels = np.concatenate([
1072
- gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]
1073
- ])
1074
- bg_color = np.median(edge_pixels)
1075
-
1076
- diff = np.abs(gray.astype(float) - bg_color)
1077
- mask = (diff > 30).astype(np.uint8) * 255
1078
-
1079
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
1080
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
1081
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
1082
-
1083
- if _validate_mask_quality(mask, image.shape[:2]):
1084
- logger.info("Background subtraction fallback successful")
1085
- return mask
1086
-
1087
- except Exception as e:
1088
- logger.warning(f"Background subtraction fallback failed: {e}")
1089
-
1090
- # Geometric ellipse fallback
1091
- mask = _geometric_person_mask(image)
1092
- logger.info("Using geometric fallback mask")
1093
- return mask
1094
-
1095
- except Exception as e:
1096
- logger.error(f"All fallback strategies failed: {e}")
1097
- h, w = image.shape[:2]
1098
- mask = np.zeros((h, w), dtype=np.uint8)
1099
- mask[h//6:5*h//6, w//4:3*w//4] = 255
1100
- return mask
1101
-
1102
- def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
1103
- """Approximation of guided filter for edge-aware smoothing"""
1104
- try:
1105
- guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY) if len(guide.shape) == 3 else guide
1106
- guide_gray = guide_gray.astype(np.float32) / 255.0
1107
- mask_float = mask.astype(np.float32) / 255.0
1108
-
1109
- kernel_size = 2 * radius + 1
1110
-
1111
- mean_guide = cv2.boxFilter(guide_gray, -1, (kernel_size, kernel_size))
1112
- mean_mask = cv2.boxFilter(mask_float, -1, (kernel_size, kernel_size))
1113
- corr_guide_mask = cv2.boxFilter(guide_gray * mask_float, -1, (kernel_size, kernel_size))
1114
-
1115
- cov_guide_mask = corr_guide_mask - mean_guide * mean_mask
1116
- mean_guide_sq = cv2.boxFilter(guide_gray * guide_gray, -1, (kernel_size, kernel_size))
1117
- var_guide = mean_guide_sq - mean_guide * mean_guide
1118
-
1119
- a = cov_guide_mask / (var_guide + eps)
1120
- b = mean_mask - a * mean_guide
1121
-
1122
- mean_a = cv2.boxFilter(a, -1, (kernel_size, kernel_size))
1123
- mean_b = cv2.boxFilter(b, -1, (kernel_size, kernel_size))
1124
-
1125
- output = mean_a * guide_gray + mean_b
1126
- output = np.clip(output * 255, 0, 255).astype(np.uint8)
1127
-
1128
- return output
1129
-
1130
- except Exception as e:
1131
- logger.warning(f"Guided filter approximation failed: {e}")
1132
- return mask
1133
-
1134
- # ============================================================================
1135
- # HELPER FUNCTIONS - COMPOSITING
1136
- # ============================================================================
1137
-
1138
- def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
1139
- """Advanced compositing with edge feathering and color correction"""
1140
- try:
1141
- threshold = 100
1142
- _, mask_binary = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
1143
-
1144
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
1145
- mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_CLOSE, kernel)
1146
- mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_OPEN, kernel)
1147
-
1148
- mask_smooth = cv2.GaussianBlur(mask_binary.astype(np.float32), (5, 5), 1.0) / 255.0
1149
- mask_smooth = np.power(mask_smooth, 0.8)
1150
-
1151
- mask_smooth = np.where(mask_smooth > 0.5,
1152
- np.minimum(mask_smooth * 1.1, 1.0),
1153
- mask_smooth * 0.9)
1154
-
1155
- frame_adjusted = _color_match_edges(frame, background, mask_smooth)
1156
-
1157
- alpha_3ch = np.stack([mask_smooth] * 3, axis=2)
1158
-
1159
- frame_float = frame_adjusted.astype(np.float32)
1160
- background_float = background.astype(np.float32)
1161
-
1162
- result = frame_float * alpha_3ch + background_float * (1 - alpha_3ch)
1163
- result = np.clip(result, 0, 255).astype(np.uint8)
1164
-
1165
- return result
1166
-
1167
- except Exception as e:
1168
- logger.error(f"Advanced compositing error: {e}")
1169
- raise
1170
-
1171
- def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
1172
- """Subtle color matching at edges to reduce halos"""
1173
- try:
1174
- edge_mask = cv2.Sobel(alpha, cv2.CV_64F, 1, 1, ksize=3)
1175
- edge_mask = np.abs(edge_mask)
1176
- edge_mask = (edge_mask > 0.1).astype(np.float32)
1177
-
1178
- edge_areas = edge_mask > 0
1179
- if not np.any(edge_areas):
1180
- return frame
1181
-
1182
- frame_adjusted = frame.copy().astype(np.float32)
1183
- background_float = background.astype(np.float32)
1184
-
1185
- adjustment_strength = 0.1
1186
- for c in range(3):
1187
- frame_adjusted[:, :, c] = np.where(
1188
- edge_areas,
1189
- frame_adjusted[:, :, c] * (1 - adjustment_strength) +
1190
- background_float[:, :, c] * adjustment_strength,
1191
- frame_adjusted[:, :, c]
1192
- )
1193
-
1194
- return np.clip(frame_adjusted, 0, 255).astype(np.uint8)
1195
-
1196
- except Exception as e:
1197
- logger.warning(f"Color matching failed: {e}")
1198
- return frame
1199
-
1200
- def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
1201
- """Simple fallback compositing method"""
1202
- try:
1203
- logger.info("Using simple compositing fallback")
1204
-
1205
- background = cv2.resize(background, (frame.shape[1], frame.shape[0]))
1206
-
1207
- if len(mask.shape) == 3:
1208
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
1209
- if mask.max() <= 1.0:
1210
- mask = (mask * 255).astype(np.uint8)
1211
-
1212
- _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
1213
-
1214
- mask_norm = mask_binary.astype(np.float32) / 255.0
1215
- mask_3ch = np.stack([mask_norm] * 3, axis=2)
1216
-
1217
- result = frame * mask_3ch + background * (1 - mask_3ch)
1218
- return result.astype(np.uint8)
1219
-
1220
- except Exception as e:
1221
- logger.error(f"Simple compositing failed: {e}")
1222
- return frame
1223
-
1224
- # ============================================================================
1225
- # HELPER FUNCTIONS - BACKGROUND CREATION
1226
- # ============================================================================
1227
-
1228
- def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
1229
- """Create solid color background (BGR)"""
1230
- color_hex = bg_config["colors"][0].lstrip('#')
1231
- color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
1232
- color_bgr = color_rgb[::-1]
1233
- return np.full((height, width, 3), color_bgr, dtype=np.uint8)
1234
-
1235
- def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
1236
- """Create enhanced gradient background with better quality (BGR out)"""
1237
- try:
1238
- colors = bg_config["colors"]
1239
- direction = bg_config.get("direction", "vertical")
1240
-
1241
- rgb_colors = []
1242
- for color_hex in colors:
1243
- color_hex = color_hex.lstrip('#')
1244
- rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
1245
- rgb_colors.append(rgb)
1246
-
1247
- if not rgb_colors:
1248
- rgb_colors = [(128, 128, 128)]
1249
-
1250
- if direction == "vertical":
1251
- background = _create_vertical_gradient(rgb_colors, width, height)
1252
- elif direction == "horizontal":
1253
- background = _create_horizontal_gradient(rgb_colors, width, height)
1254
- elif direction == "diagonal":
1255
- background = _create_diagonal_gradient(rgb_colors, width, height)
1256
- elif direction in ["radial", "soft_radial"]:
1257
- background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
1258
- else:
1259
- background = _create_vertical_gradient(rgb_colors, width, height)
1260
-
1261
- return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)
1262
-
1263
- except Exception as e:
1264
- logger.error(f"Gradient creation error: {e}")
1265
- return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
1266
-
1267
- def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
1268
- """Create vertical gradient using NumPy for performance (RGB)"""
1269
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
1270
- for y in range(height):
1271
- progress = y / max(1, height)
1272
- gradient[y, :] = _interpolate_color(colors, progress)
1273
- return gradient
1274
-
1275
- def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
1276
- """Create horizontal gradient using NumPy for performance (RGB)"""
1277
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
1278
- for x in range(width):
1279
- progress = x / max(1, width)
1280
- gradient[:, x] = _interpolate_color(colors, progress)
1281
- return gradient
1282
-
1283
- def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
1284
- """Create diagonal gradient using vectorized operations (RGB)"""
1285
- y_coords, x_coords = np.mgrid[0:height, 0:width]
1286
- max_distance = width + height
1287
- progress = (x_coords + y_coords) / max(1, max_distance)
1288
- progress = np.clip(progress, 0, 1)
1289
-
1290
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
1291
- for c in range(3):
1292
- gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
1293
- return gradient
1294
-
1295
- def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
1296
- """Create radial gradient using vectorized operations (RGB)"""
1297
- center_x, center_y = width // 2, height // 2
1298
- max_distance = np.sqrt(center_x**2 + center_y**2)
1299
-
1300
- y_coords, x_coords = np.mgrid[0:height, 0:width]
1301
- distances = np.sqrt((x_coords - center_x)**2 + (y_coords - center_y)**2)
1302
- progress = distances / max(1e-6, max_distance)
1303
- progress = np.clip(progress, 0, 1)
1304
-
1305
- if soft:
1306
- progress = np.power(progress, 0.7)
1307
-
1308
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
1309
- for c in range(3):
1310
- gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
1311
-
1312
- return gradient
1313
-
1314
- def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray:
1315
- """Vectorized color interpolation for performance"""
1316
- if len(colors) == 1:
1317
- return np.full_like(progress, colors[0][channel], dtype=np.uint8)
1318
-
1319
- num_segments = len(colors) - 1
1320
- segment_progress = progress * num_segments
1321
- segment_indices = np.floor(segment_progress).astype(int)
1322
- segment_indices = np.clip(segment_indices, 0, num_segments - 1)
1323
- local_progress = segment_progress - segment_indices
1324
-
1325
- start_colors = np.array([colors[i][channel] for i in range(len(colors))])
1326
- end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))])
1327
-
1328
- start_vals = start_colors[segment_indices]
1329
- end_vals = end_colors[segment_indices]
1330
-
1331
- result = start_vals + (end_vals - start_vals) * local_progress
1332
- return np.clip(result, 0, 255).astype(np.uint8)
1333
-
1334
- def _interpolate_color(colors: list, progress: float) -> tuple:
1335
- """Interpolate between multiple colors (RGB tuple)"""
1336
- if len(colors) == 1:
1337
- return colors[0]
1338
- elif len(colors) == 2:
1339
- r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress)
1340
- g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress)
1341
- b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress)
1342
- return (r, g, b)
1343
- else:
1344
- segment = progress * (len(colors) - 1)
1345
- idx = int(segment)
1346
- local_progress = max(0.0, min(1.0, segment - idx))
1347
- if idx >= len(colors) - 1:
1348
- return colors[-1]
1349
- c1, c2 = colors[idx], colors[idx + 1]
1350
- r = int(c1[0] + (c2[0] - c1[0]) * local_progress)
1351
- g = int(c1[1] + (c2[1] - c1[1]) * local_progress)
1352
- b = int(c1[2] + (c2[2] - c1[2]) * local_progress)
1353
- return (r, g, b)
1354
-
1355
- def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray:
1356
- """Apply brightness and contrast adjustments to background"""
1357
- try:
1358
- brightness = bg_config.get("brightness", 1.0)
1359
- contrast = bg_config.get("contrast", 1.0)
1360
-
1361
- if brightness != 1.0 or contrast != 1.0:
1362
- background = background.astype(np.float32)
1363
- background = background * contrast * brightness
1364
- background = np.clip(background, 0, 255).astype(np.uint8)
1365
 
1366
- return background
1367
 
1368
  except Exception as e:
1369
- logger.warning(f"Background adjustment failed: {e}")
1370
- return background
 
1
+ #!/usr/bin/env python3
2
  """
3
+ cv_processing.py · slim orchestrator layer
4
+ ──────────────────────────────────────────────────────────────────────────────
5
+ Keeps the public API (segment_person_hq, refine_mask_hq, replace_background_hq,
6
+ create_professional_background, validate_video_file) exactly the same so that
7
+ existing callers do **not** need to change their imports.
8
+
9
+ All heavy-lifting implementations live in:
10
+ utils.segmentation
11
+ utils.refinement
12
+ utils.compositing
13
+ utils.background_factory
14
+ utils.background_presets
15
  """
16
 
17
+ from __future__ import annotations
18
+
19
+ # ── std / 3rd-party ────────────────────────────────────────────────────────
20
+ import os, logging, cv2, numpy as np
21
+ from pathlib import Path
22
+ from typing import Tuple, Dict, Any, Optional
23
+
24
+ # ── project helpers (new modules we split out) ─────────────────────────────
25
+ from utils.segmentation import (
26
+ segment_person_hq,
27
+ segment_person_hq_original,
28
+ SegmentationError,
29
+ )
30
+ from utils.refinement import (
31
+ refine_mask_hq, MaskRefinementError,
32
+ )
33
+ from utils.compositing import (
34
+ replace_background_hq, BackgroundReplacementError,
35
+ )
36
+ from utils.background_factory import create_professional_background
37
+ from utils.background_presets import PROFESSIONAL_BACKGROUNDS # still used in the UI
38
 
39
  logger = logging.getLogger(__name__)
40
 
41
+ # ----------------------------------------------------------------------------
42
+ # LIGHT CONFIG – only what the UI still needs
43
+ # ----------------------------------------------------------------------------
44
+ USE_AUTO_TEMPORAL_CONSISTENCY = True # placeholder for future smoothing
45
+
46
+ # Validator soft-limits (kept here because validate_video_file still lives here)
47
+ MIN_AREA_RATIO = 0.015
48
+ MAX_AREA_RATIO = 0.97
49
+
50
+ # ----------------------------------------------------------------------------
51
+ # PUBLIC 1-LINERS to keep old call-sites working
52
+ # ----------------------------------------------------------------------------
53
+ # (They’re just re-exports from their new homes.)
54
+
55
+ __all__ = [
56
+ "segment_person_hq",
57
+ "segment_person_hq_original",
58
+ "refine_mask_hq",
59
+ "replace_background_hq",
60
+ "create_professional_background",
61
+ "validate_video_file",
62
+ "SegmentationError",
63
+ "MaskRefinementError",
64
+ "BackgroundReplacementError",
65
+ "PROFESSIONAL_BACKGROUNDS",
66
+ ]
67
+
68
+ # ----------------------------------------------------------------------------
69
+ # VIDEO VALIDATION (unchanged)
70
+ # ----------------------------------------------------------------------------
71
+ def validate_video_file(video_path: str) -> Tuple[bool, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """
73
+ Quick sanity-check before passing a file to OpenCV / FFmpeg.
74
+ Returns (ok, human_readable_reason)
75
  """
76
+ if not video_path or not Path(video_path).exists():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  return False, "Video file not found"
78
 
79
  try:
80
+ size = Path(video_path).stat().st_size
81
+ if size == 0:
82
+ return False, "File is empty"
83
+ if size > 2 * 1024 * 1024 * 1024:
84
+ return False, "File > 2 GB too large for the Space quota"
 
85
 
86
  cap = cv2.VideoCapture(video_path)
87
  if not cap.isOpened():
88
+ return False, "OpenCV cannot read the file"
 
 
 
 
 
89
 
90
+ n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
91
+ fps = cap.get(cv2.CAP_PROP_FPS)
92
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
93
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
94
  cap.release()
95
 
96
+ if n_frames == 0:
97
+ return False, "No frames detected"
 
98
  if fps <= 0 or fps > 120:
99
+ return False, f"Suspicious FPS: {fps}"
100
+ if w <= 0 or h <= 0:
101
+ return False, "Zero resolution"
102
+ if w > 4096 or h > 4096:
103
+ return False, f"Resolution {w}×{h} too high (max 4 096²)"
104
+ if (n_frames / fps) > 300:
105
+ return False, "Video longer than 5 minutes"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ return True, f"OK → {w}×{h}, {fps:.1f} fps, {n_frames/fps:.1f} s"
108
 
109
  except Exception as e:
110
+ logger.error(f"validate_video_file: {e}")
111
+ return False, f"Validation error: {e}"