#!/usr/bin/env python3
"""
utils.refinement
────────────────
Single-frame mask refinement for BackgroundFX Pro.

Public API
----------
refine_mask_hq(image, mask, matanyone_processor, fallback_enabled=True) -> np.ndarray
"""
from __future__ import annotations

import logging
from typing import Any, Optional, Tuple

import cv2
import numpy as np
import torch

log = logging.getLogger(__name__)
# Quality thresholds: a refined mask is accepted only if its foreground
# covers between 1.5% and 97% of the frame.
MIN_AREA_RATIO = 0.015
MAX_AREA_RATIO = 0.97

# ──────────────────────────────────────────────────────────────────────────────
# Public
# ──────────────────────────────────────────────────────────────────────────────
__all__ = ["refine_mask_hq"]
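
# Example usage (a minimal sketch; file names are placeholders, and the
# MatAnyOne processor is optional; pass None to rely on the OpenCV paths):
#
#   import cv2
#   from utils.refinement import refine_mask_hq
#
#   frame = cv2.imread("frame.png")                        # BGR uint8
#   rough = cv2.imread("rough.png", cv2.IMREAD_GRAYSCALE)  # coarse mask
#   clean = refine_mask_hq(frame, rough, matanyone_processor=None)
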
def refine_mask_hq(
    image: np.ndarray,
    mask: np.ndarray,
    matanyone_processor: Any,
    fallback_enabled: bool = True,
) -> np.ndarray:
    """
    Refine a coarse segmentation mask, trying each backend in turn:

    1) MatAnyOne high-quality refinement (if a processor is supplied).
    2) OpenCV "enhanced" filter chain (bilateral + guided + morphology).
    3) GrabCut and saliency fallbacks.

    Always returns a binary uint8 mask (0/255).
    """
    mask = _process_mask(mask)

    # 1) MatAnyOne
    if matanyone_processor is not None:
        try:
            refined = _matanyone_refine(image, mask, matanyone_processor)
            if refined is not None and _validate_mask_quality(refined, image.shape[:2]):
                return refined
            log.warning("MatAnyOne produced poor mask; fallback")
        except Exception as e:
            log.warning(f"MatAnyOne error: {e}")

    # 2) OpenCV "enhanced" bilateral + guided + morphology
    try:
        refined = _opencv_enhance(image, mask)
        if _validate_mask_quality(refined, image.shape[:2]):
            return refined
    except Exception as e:
        log.debug(f"OpenCV enhance error: {e}")

    # 3) GrabCut + saliency double-fallback
    try:
        gc = _refine_with_grabcut(image, mask)
        if _validate_mask_quality(gc, image.shape[:2]):
            return gc
        sal = _refine_with_saliency(image, mask)
        if _validate_mask_quality(sal, image.shape[:2]):
            return sal
    except Exception as e:
        log.debug(f"GrabCut/saliency fallback error: {e}")

    # Last resort: return the cleaned input mask. With fallback disabled,
    # re-run the OpenCV chain and let any exception propagate to the caller.
    return mask if fallback_enabled else _opencv_enhance(image, mask)

# ──────────────────────────────────────────────────────────────────────────────
# MatAnyOne wrapper (safe)
# ──────────────────────────────────────────────────────────────────────────────
def _matanyone_refine(img, mask, proc) -> Optional[np.ndarray]:
    """Run one MatAnyOne step; returns None if the processor lacks the API."""
    if not (hasattr(proc, "step") and hasattr(proc, "output_prob_to_mask")):
        return None
    # Image tensor: (1, C, H, W) float32 in [0, 1].
    anp = img.astype(np.float32)
    if anp.max() > 1:
        anp /= 255.0
    anp = np.transpose(anp, (2, 0, 1))
    device = proc.device if hasattr(proc, "device") else "cpu"
    img_t = torch.from_numpy(anp).unsqueeze(0).to(device)
    # Mask tensor: (1, H, W) float32 in [0, 1].
    mask_f = mask.astype(np.float32) / 255.0
    mask_t = torch.from_numpy(mask_f).unsqueeze(0).to(img_t.device)
    with torch.no_grad():
        prob = proc.step(img_t, mask_t, objects=[1])
        m = proc.output_prob_to_mask(prob).squeeze().cpu().numpy()
    # Scale probabilities back to 0-255 before downstream binarisation.
    if m.max() <= 1:
        m *= 255
    return m.astype(np.uint8)

# ──────────────────────────────────────────────────────────────────────────────
# OpenCV enhanced filter chain
# ──────────────────────────────────────────────────────────────────────────────
def _opencv_enhance(img, mask):
    """Edge-aware smoothing plus morphological cleanup of a rough mask."""
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.max() <= 1:
        mask = (mask * 255).astype(np.uint8)
    # Edge-preserving smoothing of the mask itself.
    m = cv2.bilateralFilter(mask, 9, 75, 75)
    # Snap mask edges to image edges via a guided filter.
    m = _guided_filter(img, m, r=8, eps=0.2)
    # Close pinholes, then remove speckles.
    m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
    m = cv2.morphologyEx(m, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)))
    # Soften, then re-binarise.
    m = cv2.GaussianBlur(m, (3, 3), 0.8)
    _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
    return m
def _guided_filter(guide, mask, r=8, eps=0.2):
    """Single-channel guided filter with the grayscale image as guide."""
    g = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
    m = mask.astype(np.float32) / 255.0
    k = 2 * r + 1
    # Local means of guide and mask over a (k x k) window.
    mean_g = cv2.boxFilter(g, -1, (k, k))
    mean_m = cv2.boxFilter(m, -1, (k, k))
    # Local covariance of (guide, mask) and variance of the guide.
    corr_gm = cv2.boxFilter(g * m, -1, (k, k))
    cov = corr_gm - mean_g * mean_m
    var_g = cv2.boxFilter(g * g, -1, (k, k)) - mean_g * mean_g
    # Per-window linear fit m ≈ a*g + b, regularised by eps.
    a = cov / (var_g + eps)
    b = mean_m - a * mean_g
    mean_a = cv2.boxFilter(a, -1, (k, k))
    mean_b = cv2.boxFilter(b, -1, (k, k))
    out = (mean_a * g + mean_b) * 255
    return out.astype(np.uint8)
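
# For reference: the closed form above is the standard guided filter,
#   a = cov(g, m) / (var(g) + eps),  b = mean(m) - a * mean(g),
#   out = mean(a) * g + mean(b)
# Larger eps yields a smoother, less edge-sensitive result.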

# ──────────────────────────────────────────────────────────────────────────────
# GrabCut & saliency fallbacks
# ──────────────────────────────────────────────────────────────────────────────
def _refine_with_grabcut(img, seed):
    """GrabCut refinement seeded from the confident foreground of `seed`."""
    h, w = img.shape[:2]
    # Everything starts as probable background; strong seed pixels are
    # locked as definite foreground.
    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    gc[seed > 200] = cv2.GC_FGD
    # rect is ignored under GC_INIT_WITH_MASK but required by the signature.
    rect = (w // 4, h // 6, w // 2, int(h * 0.7))
    bgd, fgd = np.zeros((1, 65), np.float64), np.zeros((1, 65), np.float64)
    cv2.grabCut(img, gc, rect, bgd, fgd, 3, cv2.GC_INIT_WITH_MASK)
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
def _refine_with_saliency(img, seed):
    """Saliency fallback: flood-fill the salient region from the seed centroid."""
    sal = _compute_saliency(img)
    if sal is None:
        return seed
    high = (sal > 0.6).astype(np.uint8) * 255
    # Flood-fill from the seed centroid (frame centre if the seed is empty).
    cy, cx = img.shape[0] // 2, img.shape[1] // 2
    if np.any(seed > 127):
        ys, xs = np.where(seed > 127)
        cy, cx = int(np.mean(ys)), int(np.mean(xs))
    ff = high.copy()
    cv2.floodFill(ff, None, (cx, cy), 255, loDiff=5, upDiff=5)
    return ff
def _compute_saliency(img):
    """Spectral-residual saliency map normalised to [0, 1], or None if unavailable."""
    try:
        if hasattr(cv2, "saliency"):
            s = cv2.saliency.StaticSaliencySpectralResidual_create()
            ok, sm = s.computeSaliency(img)
            if ok:
                return (sm - sm.min()) / max(1e-6, sm.max() - sm.min())
    except Exception:
        pass
    return None

# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _process_mask(mask):
    """Normalise any incoming mask to a single-channel binary uint8 (0/255)."""
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    return mask
def _validate_mask_quality(mask, shape: Tuple[int, int]) -> bool:
    """Accept the mask only if its foreground area ratio is plausible."""
    h, w = shape
    ratio = np.sum(mask > 127) / (h * w)
    return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO
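
# ──────────────────────────────────────────────────────────────────────────────
# Smoke test: a minimal sketch using synthetic data and no MatAnyOne
# processor; real callers pass actual frames and a configured processor.
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    h, w = 240, 320
    rng = np.random.default_rng(0)
    # Dark noisy background with a bright disc as the "subject".
    frame = rng.integers(0, 40, (h, w, 3)).astype(np.uint8)
    cv2.circle(frame, (w // 2, h // 2), 60, (200, 200, 200), -1)
    # Deliberately offset rough mask, giving the refiners work to do.
    rough = np.zeros((h, w), np.uint8)
    cv2.circle(rough, (w // 2 + 5, h // 2 - 5), 55, 255, -1)
    out = refine_mask_hq(frame, rough, matanyone_processor=None)
    fg_ratio = float(np.sum(out > 127)) / (h * w)
    print(f"refined mask: {out.shape} {out.dtype}, fg ratio: {fg_ratio:.3f}")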