VideoBackgroundReplacer / utils /segmentation.py
MogensR's picture
Update utils/segmentation.py
19a2b07
#!/usr/bin/env python3
"""
utils.segmentation
─────────────────────────────────────────────────────────────────────────────
All high-quality person-segmentation code for BackgroundFX Pro.
Exports
-------
segment_person_hq(image, predictor, fallback_enabled=True) → np.ndarray
segment_person_hq_original(image, predictor, fallback_enabled=True) → np.ndarray
SegmentationError - Custom exception for segmentation errors
Everything else is prefixed "_" and considered private.
"""
from __future__ import annotations
from typing import Any, Tuple, Optional, Dict
import logging, os, math
import cv2
import numpy as np
import torch
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
# ============================================================================
# CUSTOM EXCEPTION
# ============================================================================
class SegmentationError(Exception):
    """Error type this module raises for segmentation failures.

    Used, for example, when the input image is invalid or when SAM2
    returns no usable masks.
    """
# ============================================================================
# TUNABLE CONSTANTS
# ============================================================================
# Feature switches read by segment_person_hq().
# False -> segment_person_hq() delegates to segment_person_hq_original().
USE_ENHANCED_SEGMENTATION = True
# True -> saliency-derived prompt points; False -> fixed "basic" prompt points.
USE_INTELLIGENT_PROMPTING = True
# True -> the SAM2 mask is passed through _auto_refine_mask_iteratively().
USE_ITERATIVE_REFINEMENT = True
# Accepted foreground fraction of the frame (see _validate_mask_quality()).
MIN_AREA_RATIO = 0.015
MAX_AREA_RATIO = 0.97
# Threshold applied to the min-max normalised saliency map.
SALIENCY_THRESH = 0.65
# Iteration count passed to cv2.grabCut().
GRABCUT_ITERS = 3
# ----------------------------------------------------------------------------
# Public -- main entry-points
# ----------------------------------------------------------------------------
# Names exported by `from utils.segmentation import *`; every other name in
# this module is underscore-prefixed and treated as private.
__all__ = [
"segment_person_hq",
"segment_person_hq_original",
"SegmentationError",
]
# ============================================================================
# SAM2 TO MATANYONE MASK BRIDGE
# ============================================================================
def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray:
"""
Convert SAM2 multi-mask output to single best mask for MatAnyone.
SAM2 returns (N, H, W) where N is typically 3 masks.
We need to return a single (H, W) mask.
"""
if masks is None or len(masks) == 0:
raise SegmentationError("No masks returned from SAM2")
# Handle torch tensors
if isinstance(masks, torch.Tensor):
masks = masks.cpu().numpy()
if scores is not None and isinstance(scores, torch.Tensor):
scores = scores.cpu().numpy()
# Ensure we have the right shape
if masks.ndim == 4: # (B, N, H, W)
masks = masks[0] # Take first batch
if masks.ndim != 3: # Should be (N, H, W)
raise SegmentationError(f"Unexpected mask shape: {masks.shape}")
# Select best mask
if scores is not None and len(scores) > 0:
best_idx = int(np.argmax(scores))
else:
# Fallback: pick mask with largest area
areas = [np.sum(m > 0.5) for m in masks]
best_idx = int(np.argmax(areas))
mask = masks[best_idx]
# Convert to uint8 binary mask
if mask.dtype in (np.float32, np.float64):
mask = (mask > 0.5).astype(np.uint8) * 255
elif mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
# Ensure single channel
if mask.ndim == 3:
mask = mask[:, :, 0] if mask.shape[2] > 1 else mask.squeeze()
# Binary threshold
_, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
# Verify output shape
assert mask.ndim == 2, f"Output mask must be 2D, got shape {mask.shape}"
return mask
# ============================================================================
# MAIN API
# ============================================================================
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    High-quality person segmentation.

    Tries SAM-2 with smart prompts first, then a classical CV cascade, then a
    geometric fallback. When USE_ENHANCED_SEGMENTATION is False the call is
    delegated to segment_person_hq_original().

    Args:
        image: Input frame. The classical cascade converts it with
            cv2.COLOR_BGR2GRAY, so a BGR image is expected.
        predictor: SAM-2 style predictor exposing ``set_image`` and
            ``predict``. Any other object (including None) skips the SAM-2
            path.
        fallback_enabled: When True, a failed or skipped SAM-2 path falls
            through to the classical cascade and finally the geometric
            ellipse. When False, no fallback is attempted and the failure is
            raised instead.

    Returns:
        uint8 mask (0/255).

    Raises:
        SegmentationError: If the image is None/empty, or if the SAM-2 path
            produced no valid mask and ``fallback_enabled`` is False.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    frame_shape = image.shape[:2]
    sam2_error = None

    # 1) — SAM-2 path -------------------------------------------------------
    if predictor is not None and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
        try:
            predictor.set_image(image)
            mask = (
                _segment_with_intelligent_prompts(image, predictor)
                if USE_INTELLIGENT_PROMPTING
                else _segment_with_basic_prompts(image, predictor)
            )
            if USE_ITERATIVE_REFINEMENT:
                mask = _auto_refine_mask_iteratively(image, mask, predictor)
            if _validate_mask_quality(mask, frame_shape):
                return mask
            log.warning("SAM2 mask failed validation → fallback")
        except Exception as e:
            # Best-effort: any predictor failure just moves us on to the
            # fallbacks (or to the explicit error below).
            sam2_error = e
            log.warning("SAM2 path failed: %s", e)

    # Honour the caller's request not to fall back to heuristic masks.
    if not fallback_enabled:
        raise SegmentationError("SAM2 failed and fallback disabled") from sam2_error

    # 2) — Classical cascade ----------------------------------------------
    try:
        mask = _classical_segmentation_cascade(image)
        if _validate_mask_quality(mask, frame_shape):
            return mask
        log.warning("Classical cascade weak → geometric fallback")
    except Exception as e:
        log.debug("Classical cascade error: %s", e)

    # 3) — Last-chance geometric ellipse ----------------------------------
    return _geometric_person_mask(image)
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    Very first implementation kept for rollback. Fewer smarts, still robust.

    Prompts the predictor with five fixed positive points around the frame
    centre, keeps the best mask if it passes _validate_mask_quality(), and
    otherwise falls back to the classical cascade.

    Args:
        image: Input frame (BGR is expected by the classical cascade).
        predictor: SAM-2 style predictor exposing ``set_image`` and
            ``predict``. Any other object (including None) skips the SAM-2
            attempt.
        fallback_enabled: When True, the classical cascade is used whenever
            the SAM-2 attempt is skipped, fails, or yields an invalid mask.
            When False, that situation raises instead.

    Returns:
        uint8 mask (0/255).

    Raises:
        SegmentationError: If the image is None/empty, or if SAM-2 produced
            no valid mask and ``fallback_enabled`` is False.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    failure = None
    try:
        if predictor is not None and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
            h, w = image.shape[:2]
            predictor.set_image(image)
            # Fixed positive prompts: a vertical line through the centre plus
            # one point either side of it.
            points = np.array([
                [w // 2, h // 4],
                [w // 2, h // 2],
                [w // 2, 3 * h // 4],
                [w // 3, h // 2],
                [2 * w // 3, h // 2],
            ], dtype=np.float32)
            labels = np.ones(len(points), np.int32)
            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )
            # Use the bridge function to get single best mask
            if masks is not None and len(masks):
                mask = _sam2_to_matanyone_mask(masks, scores)
                if _validate_mask_quality(mask, image.shape[:2]):
                    return mask
    except Exception as e:
        failure = e
        log.warning("segment_person_hq_original error: %s", e)

    # The fallback decision lives outside the try-block so that the
    # "fallback disabled" error is not swallowed by the handler above.
    if fallback_enabled:
        return _classical_segmentation_cascade(image)
    raise SegmentationError("SAM2 failed and fallback disabled") from failure
# ============================================================================
# INTELLIGENT + BASIC PROMPTING
# ============================================================================
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run the SAM2 prediction using prompt points from _generate_smart_prompts()."""
    positive_pts, negative_pts = _generate_smart_prompts(image)
    mask = _sam2_predict(image, predictor, positive_pts, negative_pts)
    return mask
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run the SAM2 prediction with fixed prompts.

    Three positive points lie on the vertical centre line (at 1/3, 1/2 and
    2/3 of the height); four negative points sit 10 px in from each corner.
    """
    height, width = image.shape[:2]
    centre_x = width // 2
    right, bottom = width - 10, height - 10

    positive_pts = np.array(
        [
            [centre_x, height // 3],
            [centre_x, height // 2],
            [centre_x, 2 * height // 3],
        ],
        np.float32,
    )
    negative_pts = np.array(
        [[10, 10], [right, 10], [10, bottom], [right, bottom]],
        np.float32,
    )
    return _sam2_predict(image, predictor, positive_pts, negative_pts)
def _sam2_predict(image: np.ndarray, predictor: Any,
                  pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
    """
    Prompt the predictor with positive/negative points and return one mask.

    Args:
        image: Frame being segmented; only its shape is read here, to place
            a default centre prompt when no positive point is supplied.
        predictor: Object whose ``predict(point_coords=..., point_labels=...,
            multimask_output=True)`` returns ``(masks, scores, _)``.
        pos_points: Positive (label 1) prompt points as (x, y) pairs.
        neg_points: Negative (label 0) prompt points as (x, y) pairs; may be
            empty.

    Returns:
        The single best mask as produced by _sam2_to_matanyone_mask().
    """
    # Normalise both inputs to (K, 2) float32 so that an empty array of
    # shape (0,) can still be stacked with the other set of points.
    pos_points = np.asarray(pos_points, np.float32).reshape(-1, 2)
    neg_points = np.asarray(neg_points, np.float32).reshape(-1, 2)
    if len(pos_points) == 0:
        # Without any positive prompt, fall back to the frame centre.
        pos_points = np.array([[image.shape[1] // 2, image.shape[0] // 2]], np.float32)

    points = np.concatenate([pos_points, neg_points])
    labels = np.concatenate([
        np.ones(len(pos_points), np.int32),
        np.zeros(len(neg_points), np.int32),
    ])
    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )
    # Use the bridge function to convert multi-mask to single mask
    return _sam2_to_matanyone_mask(masks, scores)
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Place positive / negative prompt points with a saliency heuristic.

    Positive points are the centroids of up to three of the largest salient
    regions (frame centre when none is found); negative points are fixed
    10 px in from each image corner. Both are returned as float32 arrays.
    """
    height, width = image.shape[:2]
    saliency = _compute_saliency(image)

    positives = []
    if saliency is not None:
        # Slightly more permissive than the threshold used elsewhere.
        salient = saliency > (SALIENCY_THRESH - .1)
        binary = (salient * 255).astype(np.uint8)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        largest_first = sorted(contours, key=cv2.contourArea, reverse=True)
        for contour in largest_first[:3]:
            moments = cv2.moments(contour)
            area = moments["m00"]
            if area:  # zero-area contours have no defined centroid
                positives.append([int(moments["m10"] / area), int(moments["m01"] / area)])

    if not positives:
        positives = [[width // 2, height // 2]]

    right, bottom = width - 10, height - 10
    negatives = [[10, 10], [right, 10], [10, bottom], [right, bottom]]
    return np.asarray(positives, np.float32), np.asarray(negatives, np.float32)
# ============================================================================
# CLASSICAL SEGMENTATION CASCADE
# ============================================================================
def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
    """
    Edge-median background subtraction → saliency flood-fill → GrabCut.

    Each stage's mask is returned as soon as it passes
    _validate_mask_quality(); if none does, the geometric ellipse mask is
    returned.
    """
    frame_shape = image.shape[:2]

    # Stage 1: treat the median grey level of the image border as background
    # and keep pixels that differ from it by more than 30 levels.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    border = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    background_level = np.median(border)
    deviation = np.abs(gray.astype(float) - background_level)
    raw = (deviation > 30).astype(np.uint8) * 255
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    mask = cv2.morphologyEx(raw, cv2.MORPH_CLOSE, kernel)
    if _validate_mask_quality(mask, frame_shape):
        return mask

    # Stages 2 and 3: each refiner takes the previous stage's mask as seed.
    for refine in (_refine_with_saliency, _refine_with_grabcut):
        mask = refine(image, mask)
        if _validate_mask_quality(mask, frame_shape):
            return mask

    # Geometric fallback
    return _geometric_person_mask(image)
# Saliency, GrabCut helpers --------------------------------------------------
def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
    """
    Compute a spectral-residual saliency map, min-max normalised to [0, 1].

    This is a best-effort helper: it returns None (never raises) when the
    ``cv2.saliency`` module is unavailable, when the detector reports
    failure, or when OpenCV raises while computing the map.
    """
    if not hasattr(cv2, "saliency"):
        return None
    try:
        detector = cv2.saliency.StaticSaliencySpectralResidual_create()
        ok, smap = detector.computeSaliency(image)
        if ok:
            lowest = smap.min()
            # Guard against division by zero on a constant map.
            span = max(1e-6, smap.max() - lowest)
            return (smap - lowest) / span
    except Exception as e:
        # Saliency is optional; record why it was skipped instead of
        # silently discarding the error.
        log.debug("Saliency computation failed: %s", e)
    return None
def _auto_person_rect(image):
    """
    Estimate a bounding rectangle for the subject from the saliency map.

    Returns:
        ``(x, y, w, h)`` of the largest salient region, grown by 5% of the
        image size on every side and clipped to the image, or None when
        saliency is unavailable or no salient region is found.
    """
    sal = _compute_saliency(image)
    if sal is None:
        return None
    salient = (sal > SALIENCY_THRESH).astype(np.uint8)
    cnts, _ = cv2.findContours(salient * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None
    x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))

    H, W = image.shape[:2]
    pad = 0.05
    pad_x, pad_y = int(W * pad), int(H * pad)
    # Pad symmetrically by moving each edge outwards. Shifting the origin
    # while scaling only the box size could leave the far edge inside the
    # original detection.
    left = max(0, x - pad_x)
    top = max(0, y - pad_y)
    right = min(W, x + w + pad_x)
    bottom = min(H, y + h + pad_y)
    return left, top, right - left, bottom - top
def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """
    Refine a seed mask with GrabCut.

    Args:
        image: Frame passed to cv2.grabCut (which expects an 8-bit,
            3-channel image).
        seed: Seed mask; pixels above 200 are treated as definite foreground.

    Returns:
        uint8 mask (0/255) of the pixels GrabCut labels (probable)
        foreground. If GrabCut itself fails, ``seed`` is returned unchanged.
    """
    h, w = image.shape[:2]
    bgd = np.zeros((1, 65), np.float64)
    fgd = np.zeros((1, 65), np.float64)
    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)

    confident_fg = seed > 200
    if confident_fg.any():
        # Mask initialisation: everything is probable background except the
        # confident seed pixels. The rect argument is ignored in this mode.
        gc[confident_fg] = cv2.GC_FGD
        rect, mode = None, cv2.GC_INIT_WITH_MASK
    else:
        # No foreground samples to initialise from, so start from a
        # rectangle around the most salient region (or a central default).
        rect = _auto_person_rect(image) or (
            w // 4, h // 6, max(1, w // 2), max(1, int(h * 0.7))
        )
        mode = cv2.GC_INIT_WITH_RECT

    try:
        cv2.grabCut(image, gc, rect, bgd, fgd, GRABCUT_ITERS, mode)
    except cv2.error as e:
        # Keep the cascade going: the caller re-validates the mask and moves
        # on to its next fallback.
        log.debug("GrabCut failed: %s", e)
        return seed
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """
    Refine a seed mask using the thresholded saliency map.

    The centroid of the seed's foreground (frame centre when the seed is
    empty) picks one salient region: when that point lies on a salient
    pixel, only the connected salient region containing it is returned.
    When it does not, the full thresholded saliency mask is returned.

    Args:
        image: Frame to compute saliency on.
        seed: Seed mask; pixels above 127 count as foreground.

    Returns:
        uint8 mask (0/255), or ``seed`` unchanged when saliency is
        unavailable.
    """
    sal = _compute_saliency(image)
    if sal is None:
        return seed
    # NOTE(review): assumes the saliency map has the same height/width as
    # `image`, since image-space coordinates index into it below.
    high = (sal > SALIENCY_THRESH).astype(np.uint8) * 255

    ys, xs = np.where(seed > 127)
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2

    # Flood-filling from a non-salient pixel would select the background
    # region instead of the subject, so keep the plain saliency mask then.
    if high[cy, cx] == 0:
        return high

    # Collect the connected component under the centroid into a separate
    # mask (which floodFill requires to be 2 px larger than the image);
    # `high` itself is left untouched by FLOODFILL_MASK_ONLY.
    rows, cols = high.shape[:2]
    component = np.zeros((rows + 2, cols + 2), np.uint8)
    flags = 4 | cv2.FLOODFILL_MASK_ONLY | (255 << 8)
    cv2.floodFill(high, component, (cx, cy), 255, loDiff=0, upDiff=0, flags=flags)
    return component[1:-1, 1:-1].copy()
# ============================================================================
# QUALITY / HELPER FUNCTIONS
# ============================================================================
def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int,int]) -> bool:
    """Check that the mask's foreground covers a plausible share of the frame.

    Foreground is every pixel above 127; the mask is accepted when its share
    of ``shape``'s area lies within [MIN_AREA_RATIO, MAX_AREA_RATIO].
    """
    height, width = shape
    foreground_px = (mask > 127).sum()
    coverage = foreground_px / (height * width)
    return (MIN_AREA_RATIO <= coverage) and (coverage <= MAX_AREA_RATIO)
def _process_mask(mask: np.ndarray) -> np.ndarray:
"""Legacy mask processor - kept for compatibility but mostly replaced by _sam2_to_matanyone_mask"""
if mask.dtype in (np.float32, np.float64):
if mask.max() <= 1.0:
mask = (mask*255).astype(np.uint8)
if mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
if mask.ndim == 3:
mask = mask.squeeze()
if mask.ndim == 3: # multi-channel mask β†’ collapse
mask = mask[:,:,0]
_,mask = cv2.threshold(mask,127,255,cv2.THRESH_BINARY)
return mask
def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
    """Return a filled, centred ellipse mask (0/255) sized to the frame.

    The ellipse's semi-axes are a third of the width and 1/2.5 of the height.
    """
    height, width = image.shape[:2]
    canvas = np.zeros((height, width), np.uint8)
    centre = (width // 2, height // 2)
    semi_axes = (width // 3, int(height / 2.5))
    # angle=0, full 0-360 sweep, colour 255, thickness -1 (filled)
    cv2.ellipse(canvas, centre, semi_axes, 0, 0, 360, 255, -1)
    return canvas
# ============================================================================
# OPTIONAL: Iterative auto-refinement (lightweight)
# ============================================================================
def _auto_refine_mask_iteratively(image, mask, predictor, max_iterations=1):
# Simple one-pass hook (full version lives in refinement.py)
return mask