"""
utils.segmentation
─────────────────────────────────────────────────────────────────────────────
All high-quality person-segmentation code for BackgroundFX Pro.

Exports
-------
segment_person_hq(image, predictor, fallback_enabled=True) → np.ndarray
segment_person_hq_original(image, predictor, fallback_enabled=True) → np.ndarray
SegmentationError - custom exception for segmentation errors

Everything else is prefixed "_" and considered private.
"""
from __future__ import annotations

import logging
import math
import os
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np
import torch

log = logging.getLogger(__name__)


class SegmentationError(Exception):
    """Custom exception for segmentation-related errors."""


# Feature flags for the enhanced segmentation pipeline.
USE_ENHANCED_SEGMENTATION = True
USE_INTELLIGENT_PROMPTING = True
USE_ITERATIVE_REFINEMENT = True

# Mask-quality bounds and refinement tunables.
MIN_AREA_RATIO = 0.015   # reject masks covering less than 1.5 % of the frame
MAX_AREA_RATIO = 0.97    # reject masks covering more than 97 % of the frame
SALIENCY_THRESH = 0.65   # saliency score treated as "probably foreground"
GRABCUT_ITERS = 3        # iteration count passed to cv2.grabCut

__all__ = [
    "segment_person_hq",
    "segment_person_hq_original",
    "SegmentationError",
]
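
# Usage sketch (illustrative only; `load_sam2_predictor` is hypothetical and
# not part of this module; any object exposing set_image()/predict() works):
#
#     import cv2
#     from utils.segmentation import segment_person_hq
#
#     image = cv2.imread("frame.png")              # BGR uint8
#     predictor = load_sam2_predictor()            # hypothetical loader
#     mask = segment_person_hq(image, predictor)   # uint8 mask, 0 or 255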


def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray:
    """
    Convert SAM2 multi-mask output to the single best mask for MatAnyone.
    SAM2 returns (N, H, W) where N is typically 3 masks; we must return a
    single (H, W) uint8 mask (0/255).
    """
    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks returned from SAM2")

    # Move tensors to the CPU and convert to NumPy.
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    if scores is not None and isinstance(scores, torch.Tensor):
        scores = scores.cpu().numpy()

    # Drop a leading batch dimension if present.
    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim != 3:
        raise SegmentationError(f"Unexpected mask shape: {masks.shape}")

    # Pick the highest-scoring mask, or the largest one when no scores exist.
    if scores is not None and len(scores) > 0:
        best_idx = int(np.argmax(scores))
    else:
        areas = [np.sum(m > 0.5) for m in masks]
        best_idx = int(np.argmax(areas))

    mask = masks[best_idx]

    # Binarize float masks; coerce anything else to uint8.
    if mask.dtype in (np.float32, np.float64):
        mask = (mask > 0.5).astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Collapse any stray channel dimension.
    if mask.ndim == 3:
        mask = mask[:, :, 0] if mask.shape[2] > 1 else mask.squeeze()

    # Guarantee a clean 0/255 binary mask.
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    assert mask.ndim == 2, f"Output mask must be 2D, got shape {mask.shape}"
    return mask
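
# Illustrative sanity check for _sam2_to_matanyone_mask (not part of the
# pipeline; safe to delete). Builds three synthetic float masks and verifies
# that the highest-scoring one comes back as a binary (H, W) uint8 array.
def _demo_mask_conversion() -> None:
    fake = np.zeros((3, 4, 4), np.float32)
    fake[1, 1:3, 1:3] = 0.9  # mask index 1 is the "person"
    out = _sam2_to_matanyone_mask(fake, scores=np.array([0.1, 0.8, 0.3]))
    assert out.shape == (4, 4) and out.dtype == np.uint8
    assert out[1, 1] == 255 and out[0, 0] == 0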


def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    High-quality person segmentation. Tries SAM-2 with smart prompts first,
    then a classical CV cascade, then a geometric fallback.
    Returns a uint8 mask (0/255). With fallback_enabled=True this never raises
    on segmentation failure; invalid input always raises SegmentationError.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    # Stage 1: SAM2 with intelligent (or basic) point prompts.
    if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
        try:
            predictor.set_image(image)
            mask = (
                _segment_with_intelligent_prompts(image, predictor)
                if USE_INTELLIGENT_PROMPTING
                else _segment_with_basic_prompts(image, predictor)
            )
            if USE_ITERATIVE_REFINEMENT:
                mask = _auto_refine_mask_iteratively(image, mask, predictor)
            if _validate_mask_quality(mask, image.shape[:2]):
                return mask
            log.warning("SAM2 mask failed validation → fallback")
        except Exception as e:
            log.warning(f"SAM2 path failed: {e}")

    if not fallback_enabled:
        raise SegmentationError("SAM2 path failed and fallback is disabled")

    # Stage 2: classical CV cascade.
    try:
        mask = _classical_segmentation_cascade(image)
        if _validate_mask_quality(mask, image.shape[:2]):
            return mask
        log.warning("Classical cascade weak → geometric fallback")
    except Exception as e:
        log.debug(f"Classical cascade error: {e}")

    # Stage 3: geometric last resort (always succeeds).
    return _geometric_person_mask(image)
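
# Behaviour sketch: with fallback_enabled=False the caller opts into strict
# mode and must handle SegmentationError itself.
#
#     try:
#         mask = segment_person_hq(image, predictor, fallback_enabled=False)
#     except SegmentationError:
#         mask = None  # caller-defined handling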


def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    The very first implementation, kept for rollback. Fewer smarts, still robust.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
            h, w = image.shape[:2]
            predictor.set_image(image)

            # Five positive points: head, torso, hips, and both sides of the chest.
            points = np.array([
                [w // 2, h // 4],
                [w // 2, h // 2],
                [w // 2, 3 * h // 4],
                [w // 3, h // 2],
                [2 * w // 3, h // 2],
            ], dtype=np.float32)
            labels = np.ones(len(points), np.int32)

            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )

            if masks is not None and len(masks):
                mask = _sam2_to_matanyone_mask(masks, scores)
                if _validate_mask_quality(mask, image.shape[:2]):
                    return mask
    except Exception as e:
        log.warning(f"segment_person_hq_original error: {e}")

    if fallback_enabled:
        return _classical_segmentation_cascade(image)
    raise SegmentationError("SAM2 failed and fallback disabled")


def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    pos, neg = _generate_smart_prompts(image)
    return _sam2_predict(image, predictor, pos, neg)


def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    # Three positive points down the vertical centre line, one negative
    # point near each corner.
    h, w = image.shape[:2]
    pos = np.array([[w // 2, h // 3], [w // 2, h // 2], [w // 2, 2 * h // 3]], np.float32)
    neg = np.array([[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]], np.float32)
    return _sam2_predict(image, predictor, pos, neg)


def _sam2_predict(image: np.ndarray, predictor: Any,
                  pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
    # Guarantee at least one positive point (the image centre) so SAM2
    # always has a foreground hint.
    if pos_points.size == 0:
        pos_points = np.array([[image.shape[1] // 2, image.shape[0] // 2]], np.float32)
    points = np.vstack([pos_points, neg_points])
    labels = np.hstack([np.ones(len(pos_points)), np.zeros(len(neg_points))]).astype(np.int32)
    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )
    return _sam2_to_matanyone_mask(masks, scores)
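
# Shape sketch for the arrays built above: with 3 positive and 4 negative
# points, `points` is (7, 2) float32 and `labels` is the (7,) int32 vector
# [1, 1, 1, 0, 0, 0, 0], matching SAM2's point-prompt convention
# (1 = foreground, 0 = background).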


def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simple saliency-based heuristic to auto-place positive / negative points.
    """
    h, w = image.shape[:2]
    sal = _compute_saliency(image)
    pos, neg = [], []
    if sal is not None:
        # Use a slightly looser threshold than SALIENCY_THRESH so blobs are
        # still found on low-contrast frames.
        high = sal > (SALIENCY_THRESH - 0.1)
        contours, _ = cv2.findContours((high * 255).astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Positive points at the centroids of the three largest salient blobs.
        for c in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
            M = cv2.moments(c)
            if M["m00"]:
                pos.append([int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])])
    if not pos:
        pos = [[w // 2, h // 2]]
    # Negative points near the corners, which are almost never the person.
    neg = [[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]]
    return np.asarray(pos, np.float32), np.asarray(neg, np.float32)


def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
    """
    Edge-median background subtraction → saliency flood-fill → GrabCut.
    """
    # Stage 1: assume the border pixels are background; keep everything that
    # differs clearly from their median intensity.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edge_px = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    diff = np.abs(gray.astype(float) - np.median(edge_px))
    mask = (diff > 30).astype(np.uint8) * 255
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE,
                            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask

    # Stage 2: saliency-guided flood fill.
    mask = _refine_with_saliency(image, mask)
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask

    # Stage 3: GrabCut seeded with the current mask.
    mask = _refine_with_grabcut(image, mask)
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask

    # Last resort: a fixed geometric prior.
    return _geometric_person_mask(image)


def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
    # cv2.saliency ships with opencv-contrib; return None when unavailable.
    try:
        if hasattr(cv2, "saliency"):
            s = cv2.saliency.StaticSaliencySpectralResidual_create()
            ok, smap = s.computeSaliency(image)
            if ok:
                # Normalise to [0, 1]; the epsilon guards against flat maps.
                smap = (smap - smap.min()) / max(1e-6, smap.max() - smap.min())
                return smap
    except Exception:
        pass
    return None


def _auto_person_rect(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    # Bounding rect of the largest salient blob, padded by 5 % on each side.
    sal = _compute_saliency(image)
    if sal is None:
        return None
    m = (sal > SALIENCY_THRESH).astype(np.uint8)
    cnts, _ = cv2.findContours(m * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None
    x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
    H, W = image.shape[:2]
    pad = 0.05
    x = max(0, int(x - W * pad))
    y = max(0, int(y - H * pad))
    w = min(W - x, int(w * (1 + 2 * pad)))
    h = min(H - y, int(h * (1 + 2 * pad)))
    return x, y, w, h


def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    h, w = image.shape[:2]
    # Everything starts as "probably background"; strong seed pixels become
    # definite foreground.
    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    gc[seed > 200] = cv2.GC_FGD
    # Fall back to a centred rect when no salient region is found.
    rect = _auto_person_rect(image) or (w // 4, h // 6, w // 2, int(h * 0.7))
    bgd, fgd = np.zeros((1, 65), np.float64), np.zeros((1, 65), np.float64)
    cv2.grabCut(image, gc, rect, bgd, fgd, GRABCUT_ITERS, cv2.GC_INIT_WITH_MASK)
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)


def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    sal = _compute_saliency(image)
    if sal is None:
        return seed
    high = (sal > SALIENCY_THRESH).astype(np.uint8) * 255
    # Flood-fill the high-saliency map from the seed mask's centroid
    # (falling back to the image centre when the seed is empty).
    ys, xs = np.where(seed > 127)
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2
    ff = high.copy()
    cv2.floodFill(ff, None, (cx, cy), 255, loDiff=5, upDiff=5)
    return ff


def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int, int]) -> bool:
    # A plausible person mask covers neither a sliver nor the whole frame.
    h, w = shape
    ratio = np.sum(mask > 127) / (h * w)
    return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO
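
# Example of the area check (illustrative): a centred square covering 25 %
# of a 100×100 frame passes, while an empty mask fails.
#
#     m = np.zeros((100, 100), np.uint8)
#     m[25:75, 25:75] = 255
#     _validate_mask_quality(m, (100, 100))                              # True
#     _validate_mask_quality(np.zeros((100, 100), np.uint8), (100, 100)) # False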


def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Legacy mask processor - kept for compatibility but mostly replaced by
    _sam2_to_matanyone_mask."""
    # Scale [0, 1] float masks up to [0, 255] before the uint8 cast.
    if mask.dtype in (np.float32, np.float64):
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    # Squeeze singleton dimensions, then drop any remaining channel axis.
    if mask.ndim == 3:
        mask = mask.squeeze()
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    return mask


def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
    # Zero-knowledge fallback: a filled ellipse centred in the frame,
    # roughly where a person usually stands.
    h, w = image.shape[:2]
    mask = np.zeros((h, w), np.uint8)
    cv2.ellipse(mask, (w // 2, h // 2), (w // 3, int(h / 2.5)), 0, 0, 360, 255, -1)
    return mask


def _auto_refine_mask_iteratively(image: np.ndarray, mask: np.ndarray,
                                  predictor: Any, max_iterations: int = 1) -> np.ndarray:
    """Placeholder for iterative SAM2-based refinement; currently returns the
    mask unchanged."""
    return mask