MogensR commited on
Commit
c3211a4
·
verified ·
1 Parent(s): d5540c9

Update utils/cv_processing.py

Browse files
Files changed (1) hide show
  1. utils/cv_processing.py +150 -19
utils/cv_processing.py CHANGED
@@ -1,7 +1,8 @@
1
  #!/usr/bin/env python3
2
  """
3
- cv_processing.py · FIXED VERSION with proper SAM2 handling + MatAnyone stateful integration
4
- Now with environment variable support for USE_SAM2 and USE_MATANYONE
 
5
 
6
  All public functions in this module expect RGB images (H,W,3) unless stated otherwise.
7
  CoreVideoProcessor already converts BGR→RGB before calling into this module.
@@ -32,6 +33,11 @@ def _use_matanyone_enabled() -> bool:
32
  val = os.getenv("USE_MATANYONE", "1")
33
  return val.lower() in ("1", "true", "yes", "on")
34
 
 
 
 
 
 
35
  # ----------------------------------------------------------------------------
36
  # Background presets
37
  # ----------------------------------------------------------------------------
@@ -139,6 +145,46 @@ def _vertical_gradient(top: Tuple[int,int,int], bottom: Tuple[int,int,int], widt
139
  bg[y, :] = (r, g, b)
140
  return bg
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # ----------------------------------------------------------------------------
143
  # Background creation
144
  # ----------------------------------------------------------------------------
@@ -160,7 +206,7 @@ def create_professional_background(key_or_cfg: Any, width: int, height: int) ->
160
  return _vertical_gradient(dark, color, width, height)
161
 
162
  # ----------------------------------------------------------------------------
163
- # Improved Segmentation (expects RGB input)
164
  # ----------------------------------------------------------------------------
165
  def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
166
  """Basic fallback segmentation using color detection on RGB frames."""
@@ -189,7 +235,14 @@ def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
189
  person_mask = np.zeros_like(person_mask)
190
  cv2.drawContours(person_mask, [largest_contour], -1, 255, -1)
191
 
192
- return (person_mask.astype(np.float32) / 255.0)
 
 
 
 
 
 
 
193
 
194
  def segment_person_hq(
195
  frame: np.ndarray,
@@ -199,7 +252,8 @@ def segment_person_hq(
199
  **_compat_kwargs,
200
  ) -> np.ndarray:
201
  """
202
- High-quality person segmentation with proper SAM2 handling.
 
203
  Expects RGB frame (H,W,3), uint8 or float in [0,1].
204
  """
205
  # Override with environment variable if not explicitly set
@@ -215,7 +269,38 @@ def segment_person_hq(
215
 
216
  if predictor is not None:
217
  try:
218
- if hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  # Predictor adapter expects RGB uint8; convert if needed
220
  if frame_rgb.dtype != np.uint8:
221
  rgb_u8 = np.clip(frame_rgb * (255.0 if frame_rgb.dtype != np.uint8 else 1.0), 0, 255).astype(np.uint8) \
@@ -272,13 +357,20 @@ def segment_person_hq(
272
  if mask is not None:
273
  mask = _to_mask01(mask)
274
  # Add debug logging
275
- logger.info(f"SAM2 mask stats: shape={mask.shape}, min={mask.min():.3f}, max={mask.max():.3f}, mean={mask.mean():.3f}")
 
276
  if float(mask.max()) > 0.1:
 
 
 
 
277
  return np.ascontiguousarray(mask)
278
  else:
279
- logger.warning("SAM2 mask too weak, using fallback")
280
- else:
281
- logger.warning("SAM2 returned no masks")
 
 
282
 
283
  except Exception as e:
284
  logger.warning(f"SAM2 segmentation error: {e}")
@@ -292,7 +384,7 @@ def segment_person_hq(
292
  segment_person_hq_original = segment_person_hq
293
 
294
  # ----------------------------------------------------------------------------
295
- # MatAnyone Refinement (Stateful-capable)
296
  # ----------------------------------------------------------------------------
297
  def refine_mask_hq(
298
  frame: np.ndarray,
@@ -305,7 +397,7 @@ def refine_mask_hq(
305
  **_compat_kwargs,
306
  ) -> np.ndarray:
307
  """
308
- Refine mask with MatAnyone.
309
 
310
  Modes:
311
  • Stateful (preferred): provide `frame_idx`. On frame_idx==0, the session encodes with the mask.
@@ -324,6 +416,10 @@ def refine_mask_hq(
324
 
325
  if use_matanyone is False:
326
  logger.info("MatAnyone disabled by environment variable, returning unrefined mask")
 
 
 
 
327
  return mask01
328
 
329
  if matanyone is not None and callable(matanyone):
@@ -338,7 +434,8 @@ def refine_mask_hq(
338
  refined = matanyone(rgb01) # propagate without mask
339
  refined = _mask_to_2d(refined)
340
  if float(refined.max()) > 0.1:
341
- return _postprocess_mask(refined)
 
342
  logger.warning("MatAnyone stateful refinement produced empty/weak mask; falling back.")
343
 
344
  # Backward-compat (stateless) path
@@ -368,7 +465,8 @@ def refine_mask_hq(
368
  logger.debug(f"MatAnyone process failed: {e}")
369
 
370
  if refined is not None and float(refined.max()) > 0.1:
371
- return _postprocess_mask(refined)
 
372
  else:
373
  logger.warning("MatAnyone refinement failed or produced empty mask")
374
 
@@ -377,12 +475,27 @@ def refine_mask_hq(
377
 
378
  # Fallback refinement
379
  if fallback_enabled:
380
- return _fallback_refine(mask01)
381
  else:
 
 
 
 
382
  return mask01
383
 
 
 
 
 
 
 
 
 
 
 
 
384
  def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
385
- """Post-process mask to clean edges and remove artifacts"""
386
  mask_uint8 = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)
387
 
388
  kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
@@ -397,6 +510,17 @@ def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
397
  out = mask_uint8.astype(np.float32) / 255.0
398
  return np.ascontiguousarray(out)
399
 
 
 
 
 
 
 
 
 
 
 
 
400
  def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
401
  """Simple fallback refinement"""
402
  mask_uint8 = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)
@@ -413,7 +537,7 @@ def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
413
  return np.ascontiguousarray(out)
414
 
415
  # ----------------------------------------------------------------------------
416
- # Compositing (expects RGB inputs)
417
  # ----------------------------------------------------------------------------
418
  def replace_background_hq(
419
  frame: np.ndarray,
@@ -422,7 +546,7 @@ def replace_background_hq(
422
  fallback_enabled: bool = True,
423
  **_compat,
424
  ) -> np.ndarray:
425
- """High-quality background replacement with alpha blending (RGB in/out)."""
426
  try:
427
  H, W = frame.shape[:2]
428
 
@@ -431,7 +555,14 @@ def replace_background_hq(
431
 
432
  m = _mask_to_2d(_to_mask01(mask01))
433
 
434
- m = _feather(m, k=1)
 
 
 
 
 
 
 
435
 
436
  m3 = np.repeat(m[:, :, None], 3, axis=2)
437
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ cv_processing.py · MAXIMUM QUALITY VERSION with enhanced SAM2Handler integration
4
+ Updated to work with enhanced SAM2Handler that has full-body detection strategies
5
+ Now includes maximum quality mask cleaning and aggressive post-processing
6
 
7
  All public functions in this module expect RGB images (H,W,3) unless stated otherwise.
8
  CoreVideoProcessor already converts BGR→RGB before calling into this module.
 
33
  val = os.getenv("USE_MATANYONE", "1")
34
  return val.lower() in ("1", "true", "yes", "on")
35
 
36
+ def _use_max_quality_enabled() -> bool:
37
+ """Check if maximum quality processing should be used"""
38
+ val = os.getenv("BFX_QUALITY", "max")
39
+ return val.lower() == "max"
40
+
41
  # ----------------------------------------------------------------------------
42
  # Background presets
43
  # ----------------------------------------------------------------------------
 
145
  bg[y, :] = (r, g, b)
146
  return bg
147
 
148
+ # ----------------------------------------------------------------------------
149
+ # Maximum Quality Mask Cleaning (integrated from TwoStageProcessor)
150
+ # ----------------------------------------------------------------------------
151
+ def _maximum_quality_mask_cleaning(mask: np.ndarray) -> np.ndarray:
152
+ """Maximum quality mask cleaning and refinement - same as TwoStageProcessor."""
153
+ try:
154
+ # Ensure uint8 format
155
+ if mask.max() <= 1.0:
156
+ mask_uint8 = (mask * 255).astype(np.uint8)
157
+ else:
158
+ mask_uint8 = mask.astype(np.uint8)
159
+
160
+ # Step 1: Fill small holes aggressively
161
+ kernel_fill = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
162
+ mask_filled = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel_fill)
163
+
164
+ # Step 2: Connect nearby regions
165
+ kernel_connect = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
166
+ mask_connected = cv2.morphologyEx(mask_filled, cv2.MORPH_CLOSE, kernel_connect)
167
+
168
+ # Step 3: Smooth boundaries heavily
169
+ mask_smooth1 = cv2.GaussianBlur(mask_connected, (7, 7), 2.0)
170
+
171
+ # Step 4: Re-threshold to crisp edges
172
+ _, mask_thresh = cv2.threshold(mask_smooth1, 127, 255, cv2.THRESH_BINARY)
173
+
174
+ # Step 5: Final edge smoothing
175
+ mask_final = cv2.GaussianBlur(mask_thresh, (5, 5), 1.0)
176
+
177
+ # Step 6: Dilate slightly to ensure full coverage
178
+ kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
179
+ mask_dilated = cv2.dilate(mask_final, kernel_dilate, iterations=1)
180
+
181
+ logger.info("Maximum quality mask cleaning applied successfully")
182
+ return (mask_dilated.astype(np.float32) / 255.0)
183
+
184
+ except Exception as e:
185
+ logger.warning(f"Maximum quality mask cleaning failed: {e}")
186
+ return mask
187
+
188
  # ----------------------------------------------------------------------------
189
  # Background creation
190
  # ----------------------------------------------------------------------------
 
206
  return _vertical_gradient(dark, color, width, height)
207
 
208
  # ----------------------------------------------------------------------------
209
+ # Improved Segmentation (expects RGB input) - ENHANCED FOR SAM2Handler
210
  # ----------------------------------------------------------------------------
211
  def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
212
  """Basic fallback segmentation using color detection on RGB frames."""
 
235
  person_mask = np.zeros_like(person_mask)
236
  cv2.drawContours(person_mask, [largest_contour], -1, 255, -1)
237
 
238
+ mask_result = (person_mask.astype(np.float32) / 255.0)
239
+
240
+ # Apply maximum quality cleaning if enabled
241
+ if _use_max_quality_enabled():
242
+ mask_result = _maximum_quality_mask_cleaning(mask_result)
243
+ logger.info("Applied maximum quality cleaning to fallback segmentation")
244
+
245
+ return mask_result
246
 
247
  def segment_person_hq(
248
  frame: np.ndarray,
 
252
  **_compat_kwargs,
253
  ) -> np.ndarray:
254
  """
255
+ High-quality person segmentation with ENHANCED SAM2Handler integration.
256
+ Now uses enhanced SAM2Handler.create_mask() for full-body detection.
257
  Expects RGB frame (H,W,3), uint8 or float in [0,1].
258
  """
259
  # Override with environment variable if not explicitly set
 
269
 
270
  if predictor is not None:
271
  try:
272
+ # ENHANCED: Check if this is the new SAM2Handler with create_mask method
273
+ if hasattr(predictor, 'create_mask'):
274
+ logger.info("Using ENHANCED SAM2Handler.create_mask() with full-body detection")
275
+ # SAM2Handler expects RGB uint8
276
+ if frame_rgb.dtype != np.uint8:
277
+ rgb_u8 = np.clip(frame_rgb * (255.0 if frame_rgb.dtype != np.uint8 else 1.0), 0, 255).astype(np.uint8) \
278
+ if np.issubdtype(frame_rgb.dtype, np.floating) else frame_rgb.astype(np.uint8)
279
+ else:
280
+ rgb_u8 = frame_rgb
281
+
282
+ # Use enhanced SAM2Handler with full-body detection strategies
283
+ mask = predictor.create_mask(rgb_u8)
284
+
285
+ if mask is not None:
286
+ # Convert to float format
287
+ mask_float = _to_mask01(mask)
288
+ logger.info(f"Enhanced SAM2Handler mask stats: shape={mask_float.shape}, min={mask_float.min():.3f}, max={mask_float.max():.3f}, mean={mask_float.mean():.3f}")
289
+
290
+ if float(mask_float.max()) > 0.1:
291
+ # Apply additional maximum quality cleaning if enabled
292
+ if _use_max_quality_enabled():
293
+ mask_float = _maximum_quality_mask_cleaning(mask_float)
294
+ logger.info("Applied additional maximum quality cleaning to enhanced SAM2 result")
295
+ return np.ascontiguousarray(mask_float)
296
+ else:
297
+ logger.warning("Enhanced SAM2Handler mask too weak, using fallback")
298
+ else:
299
+ logger.warning("Enhanced SAM2Handler returned None mask")
300
+
301
+ # FALLBACK: Basic SAM2 predictor handling (legacy compatibility)
302
+ elif hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
303
+ logger.info("Using legacy SAM2 predictor interface")
304
  # Predictor adapter expects RGB uint8; convert if needed
305
  if frame_rgb.dtype != np.uint8:
306
  rgb_u8 = np.clip(frame_rgb * (255.0 if frame_rgb.dtype != np.uint8 else 1.0), 0, 255).astype(np.uint8) \
 
357
  if mask is not None:
358
  mask = _to_mask01(mask)
359
  # Add debug logging
360
+ logger.info(f"Legacy SAM2 mask stats: shape={mask.shape}, min={mask.min():.3f}, max={mask.max():.3f}, mean={mask.mean():.3f}")
361
+
362
  if float(mask.max()) > 0.1:
363
+ # Apply maximum quality cleaning if enabled
364
+ if _use_max_quality_enabled():
365
+ mask = _maximum_quality_mask_cleaning(mask)
366
+ logger.info("Applied maximum quality cleaning to legacy SAM2 result")
367
  return np.ascontiguousarray(mask)
368
  else:
369
+ logger.warning("Legacy SAM2 mask too weak, using fallback")
370
+ else:
371
+ logger.warning("Legacy SAM2 returned no masks")
372
+ else:
373
+ logger.warning("Predictor doesn't have expected SAM2 interface")
374
 
375
  except Exception as e:
376
  logger.warning(f"SAM2 segmentation error: {e}")
 
384
  segment_person_hq_original = segment_person_hq
385
 
386
  # ----------------------------------------------------------------------------
387
+ # MatAnyone Refinement (Stateful-capable) - ENHANCED WITH MAX QUALITY
388
  # ----------------------------------------------------------------------------
389
  def refine_mask_hq(
390
  frame: np.ndarray,
 
397
  **_compat_kwargs,
398
  ) -> np.ndarray:
399
  """
400
+ Refine mask with MatAnyone + maximum quality post-processing.
401
 
402
  Modes:
403
  • Stateful (preferred): provide `frame_idx`. On frame_idx==0, the session encodes with the mask.
 
416
 
417
  if use_matanyone is False:
418
  logger.info("MatAnyone disabled by environment variable, returning unrefined mask")
419
+ # Still apply maximum quality cleaning if enabled
420
+ if _use_max_quality_enabled():
421
+ mask01 = _maximum_quality_mask_cleaning(mask01)
422
+ logger.info("Applied maximum quality cleaning to unrefined mask")
423
  return mask01
424
 
425
  if matanyone is not None and callable(matanyone):
 
434
  refined = matanyone(rgb01) # propagate without mask
435
  refined = _mask_to_2d(refined)
436
  if float(refined.max()) > 0.1:
437
+ result = _postprocess_mask_max_quality(refined)
438
+ return result
439
  logger.warning("MatAnyone stateful refinement produced empty/weak mask; falling back.")
440
 
441
  # Backward-compat (stateless) path
 
465
  logger.debug(f"MatAnyone process failed: {e}")
466
 
467
  if refined is not None and float(refined.max()) > 0.1:
468
+ result = _postprocess_mask_max_quality(refined)
469
+ return result
470
  else:
471
  logger.warning("MatAnyone refinement failed or produced empty mask")
472
 
 
475
 
476
  # Fallback refinement
477
  if fallback_enabled:
478
+ return _fallback_refine_max_quality(mask01)
479
  else:
480
+ # Still apply maximum quality cleaning if enabled
481
+ if _use_max_quality_enabled():
482
+ mask01 = _maximum_quality_mask_cleaning(mask01)
483
+ logger.info("Applied maximum quality cleaning to fallback mask")
484
  return mask01
485
 
486
+ def _postprocess_mask_max_quality(mask01: np.ndarray) -> np.ndarray:
487
+ """Post-process mask with maximum quality cleaning"""
488
+ if _use_max_quality_enabled():
489
+ # Use the aggressive maximum quality cleaning
490
+ result = _maximum_quality_mask_cleaning(mask01)
491
+ logger.info("Applied maximum quality post-processing to MatAnyone result")
492
+ return result
493
+ else:
494
+ # Use standard post-processing
495
+ return _postprocess_mask(mask01)
496
+
497
  def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
498
+ """Standard post-process mask to clean edges and remove artifacts"""
499
  mask_uint8 = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)
500
 
501
  kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
 
510
  out = mask_uint8.astype(np.float32) / 255.0
511
  return np.ascontiguousarray(out)
512
 
513
+ def _fallback_refine_max_quality(mask01: np.ndarray) -> np.ndarray:
514
+ """Fallback refinement with maximum quality option"""
515
+ if _use_max_quality_enabled():
516
+ # Use aggressive maximum quality cleaning
517
+ result = _maximum_quality_mask_cleaning(mask01)
518
+ logger.info("Applied maximum quality cleaning to fallback refinement")
519
+ return result
520
+ else:
521
+ # Use standard fallback refinement
522
+ return _fallback_refine(mask01)
523
+
524
  def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
525
  """Simple fallback refinement"""
526
  mask_uint8 = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)
 
537
  return np.ascontiguousarray(out)
538
 
539
  # ----------------------------------------------------------------------------
540
+ # Compositing (expects RGB inputs) - ENHANCED WITH MAX QUALITY
541
  # ----------------------------------------------------------------------------
542
  def replace_background_hq(
543
  frame: np.ndarray,
 
546
  fallback_enabled: bool = True,
547
  **_compat,
548
  ) -> np.ndarray:
549
+ """High-quality background replacement with alpha blending (RGB in/out) - enhanced with max quality."""
550
  try:
551
  H, W = frame.shape[:2]
552
 
 
555
 
556
  m = _mask_to_2d(_to_mask01(mask01))
557
 
558
+ # Apply maximum quality cleaning to mask before compositing
559
+ if _use_max_quality_enabled():
560
+ m = _maximum_quality_mask_cleaning(m)
561
+ logger.debug("Applied maximum quality cleaning to compositing mask")
562
+
563
+ # Enhanced feathering for maximum quality
564
+ feather_strength = 3 if _use_max_quality_enabled() else 1
565
+ m = _feather(m, k=feather_strength)
566
 
567
  m3 = np.repeat(m[:, :, None], 3, axis=2)
568