File size: 8,062 Bytes
a0ffb03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env python3
"""
utils.refinement
─────────────────────────────────────────────────────────────────────────────
Single-frame mask refinement for BackgroundFX Pro.

Public API
----------
refine_mask_hq(image, mask, matanyone_processor, fallback_enabled=True) -> np.ndarray
"""

from __future__ import annotations
from typing import Any, Tuple, Optional
import logging, cv2, torch, numpy as np

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# Quality thresholds: a candidate mask is accepted by _validate_mask_quality
# only when the fraction of the frame with mask value > 127 lies within
# [MIN_AREA_RATIO, MAX_AREA_RATIO].
MIN_AREA_RATIO = 0.015
MAX_AREA_RATIO = 0.97

# ────────────────────────────────────────────────────────────────────────────
# Public
# ────────────────────────────────────────────────────────────────────────────
# Only refine_mask_hq is exported; every underscore-prefixed function below
# is an internal helper.
__all__ = ["refine_mask_hq"]

def refine_mask_hq(
    image: np.ndarray,
    mask:  np.ndarray,
    matanyone_processor: Any,
    fallback_enabled: bool = True,
) -> np.ndarray:
    """
    Refine a single-frame foreground mask, trying each method in turn.

    1) MatAnyOne refinement (skipped when ``matanyone_processor`` is None).
    2) OpenCV "enhanced" filter chain.
    3) GrabCut, then saliency, each attempted independently.

    The first candidate accepted by ``_validate_mask_quality`` is returned.
    When no candidate is accepted:

    * ``fallback_enabled=True``  -> the binarised input mask is returned;
    * ``fallback_enabled=False`` -> the (unvalidated) OpenCV-enhanced mask is
      returned, or the binarised input mask if that step itself failed.

    Errors raised by any refinement step are logged and never propagate.

    Parameters
    ----------
    image : frame the mask belongs to; only ``image.shape[:2]`` is read here,
        the array itself is handed to the refinement helpers.
    mask : initial mask; normalised by ``_process_mask`` before use.
    matanyone_processor : optional processor object passed to
        ``_matanyone_refine``.
    fallback_enabled : selects the last-resort result described above.

    Returns
    -------
    np.ndarray
        uint8 mask. The input-mask, OpenCV and GrabCut paths are strictly
        0/255.
    """
    mask = _process_mask(mask)
    frame_hw = image.shape[:2]

    # 1 — MatAnyOne
    if matanyone_processor is not None:
        try:
            refined = _matanyone_refine(image, mask, matanyone_processor)
            if refined is not None and _validate_mask_quality(refined, frame_hw):
                return refined
            log.warning("MatAnyOne produced poor mask; fallback")
        except Exception as e:
            log.warning("MatAnyOne error: %s", e)

    # 2 — OpenCV "enhanced" bilateral+guided+MORPH
    # The result is kept so the last-resort branch can reuse it instead of
    # recomputing it (and possibly raising) outside any try block.
    enhanced = None
    try:
        enhanced = _opencv_enhance(image, mask)
        if _validate_mask_quality(enhanced, frame_hw):
            return enhanced
    except Exception as e:
        log.debug("OpenCV enhance error: %s", e)

    # 3 — GrabCut + saliency double-fallback
    # Each step has its own try so that a GrabCut failure does not prevent
    # the saliency fallback from being attempted.
    fallbacks = (
        ("GrabCut", _refine_with_grabcut),
        ("saliency", _refine_with_saliency),
    )
    for label, refine in fallbacks:
        try:
            candidate = refine(image, mask)
            if _validate_mask_quality(candidate, frame_hw):
                return candidate
        except Exception as e:
            log.debug("%s fallback error: %s", label, e)

    # last resort
    if fallback_enabled or enhanced is None:
        return mask
    return enhanced

# ────────────────────────────────────────────────────────────────────────────
# MatAnyOne wrapper (safe)
# ────────────────────────────────────────────────────────────────────────────
def _matanyone_refine(img, mask, proc) -> Optional[np.ndarray]:
    if not (hasattr(proc, "step") and hasattr(proc, "output_prob_to_mask")):
        return None
    # image tensor (C,H,W) float32 0-1
    anp = img.astype(np.float32)
    if anp.max() > 1: anp /= 255.0
    anp = np.transpose(anp, (2,0,1))
    img_t  = torch.from_numpy(anp).unsqueeze(0).to(proc.device if hasattr(proc,"device") else "cpu")
    mask_f = mask.astype(np.float32)/255.0
    mask_t = torch.from_numpy(mask_f).unsqueeze(0).to(img_t.device)

    with torch.no_grad():
        prob = proc.step(img_t, mask_t, objects=[1])
        m = proc.output_prob_to_mask(prob).squeeze().cpu().numpy()
    if m.max() <= 1: m *= 255
    return m.astype(np.uint8)

# ────────────────────────────────────────────────────────────────────────────
# OpenCV enhanced filter chain
# ────────────────────────────────────────────────────────────────────────────
def _opencv_enhance(img, mask):
    """
    Smooth and re-binarise ``mask`` with an OpenCV filter chain.

    Steps, in order: bilateral filter, guided filter (``img`` as guide),
    morphological close (5x5 ellipse), morphological open (3x3 ellipse),
    Gaussian blur, and a binary threshold at 127.

    A 3-D mask is first collapsed with a BGR->gray conversion, and a mask
    whose maximum is <= 1 is scaled to 0-255 uint8.

    Returns the thresholded mask (values 0 or 255).
    """
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if mask.max() <= 1:
        mask = (mask * 255).astype(np.uint8)

    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    smoothed = cv2.bilateralFilter(mask, 9, 75, 75)
    guided = _guided_filter(img, smoothed, r=8, eps=0.2)
    closed = cv2.morphologyEx(guided, cv2.MORPH_CLOSE, close_kernel)
    opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, open_kernel)
    blurred = cv2.GaussianBlur(opened, (3, 3), 0.8)

    binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)[1]
    return binary

def _guided_filter(guide, mask, r=8, eps=0.2):
    """
    Guided filter of ``mask`` using ``guide`` as the guidance image.

    Within each (2r+1) x (2r+1) box window the mask is modelled as a linear
    function ``a * guide + b`` of the gray guide; the per-window coefficients
    are then box-averaged and applied to the guide.

    Parameters
    ----------
    guide : guidance image; a 3-D array is converted with BGR->gray, a 2-D
        array is used as-is. Values are divided by 255
        (assumes a 0-255 range — TODO confirm for float inputs).
    mask : array in 0-255, same spatial size as ``guide``.
    r : window radius; the box kernel side is ``2*r + 1``.
    eps : regulariser added to the guide variance.

    Returns
    -------
    np.ndarray
        uint8 array, saturated to [0, 255].
    """
    # Accept an already-grayscale guide instead of failing in cvtColor.
    gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY) if guide.ndim == 3 else guide
    g = gray.astype(np.float32) / 255.0
    m = mask.astype(np.float32) / 255.0
    ksize = (2 * r + 1, 2 * r + 1)

    def box(arr):
        return cv2.boxFilter(arr, -1, ksize)

    mean_g = box(g)
    mean_m = box(m)
    cov_gm = box(g * m) - mean_g * mean_m
    var_g = box(g * g) - mean_g * mean_g

    a = cov_gm / (var_g + eps)
    b = mean_m - a * mean_g

    out = (box(a) * g + box(b)) * 255
    # The linear model can overshoot slightly outside [0, 255]; clip so the
    # uint8 cast saturates rather than wrapping around.
    return np.clip(out, 0, 255).astype(np.uint8)

# ────────────────────────────────────────────────────────────────────────────
# GrabCut & saliency fallbacks
# ────────────────────────────────────────────────────────────────────────────
def _refine_with_grabcut(img, seed):
    """
    Refine ``seed`` with three GrabCut iterations initialised from a mask.

    Pixels where ``seed > 200`` are marked definite foreground; every other
    pixel starts as probable background. The result is 255 where GrabCut
    ends with definite or probable foreground, 0 elsewhere (uint8).

    NOTE(review): a central rectangle is computed and passed to
    ``cv2.grabCut``, but the call uses ``GC_INIT_WITH_MASK``; per the OpenCV
    documentation the rect is only used with ``GC_INIT_WITH_RECT`` — confirm
    whether the rectangle was meant to influence the initial labels.
    """
    height, width = img.shape[:2]

    labels = np.full((height, width), cv2.GC_PR_BGD, np.uint8)
    labels[seed > 200] = cv2.GC_FGD

    roi = (width // 4, height // 6, width // 2, int(height * 0.7))
    bg_model = np.zeros((1, 65), np.float64)
    fg_model = np.zeros((1, 65), np.float64)
    cv2.grabCut(img, labels, roi, bg_model, fg_model, 3, cv2.GC_INIT_WITH_MASK)

    is_foreground = (labels == cv2.GC_FGD) | (labels == cv2.GC_PR_FGD)
    return np.where(is_foreground, 255, 0).astype(np.uint8)

def _refine_with_saliency(img, seed):
    """
    Build a mask from the saliency map of ``img``, flood-filled from the seed.

    The normalised saliency map is thresholded at 0.6 into a 0/255 image and
    a flood fill with value 255 is started at the centroid of the
    ``seed > 127`` pixels (or the image centre when there are none).

    Returns ``seed`` unchanged when no saliency map is available.

    NOTE(review): when the start point lies on a non-salient (0) pixel, the
    flood fill paints that connected zero region with 255 in addition to the
    thresholded salient pixels; when it lies on a salient pixel the image is
    left as thresholded. Confirm this is the intended behaviour.
    """
    saliency = _compute_saliency(img)
    if saliency is None:
        return seed

    # Default start point: image centre; overridden by the seed centroid.
    start_y = img.shape[0] // 2
    start_x = img.shape[1] // 2
    seed_fg = seed > 127
    if seed_fg.any():
        rows, cols = np.nonzero(seed_fg)
        start_y = int(rows.mean())
        start_x = int(cols.mean())

    salient = (saliency > 0.6).astype(np.uint8) * 255
    filled = salient.copy()
    cv2.floodFill(filled, None, (start_x, start_y), 255, loDiff=5, upDiff=5)
    return filled

def _compute_saliency(img):
    """
    Return a min-max normalised spectral-residual saliency map, or None.

    None is returned when the ``cv2.saliency`` module is not available in
    this OpenCV build, when ``computeSaliency`` reports failure, or when the
    computation raises; in the last case the error is logged at debug level
    rather than silently discarded.
    """
    if not hasattr(cv2, "saliency"):
        return None
    try:
        detector = cv2.saliency.StaticSaliencySpectralResidual_create()
        ok, sal_map = detector.computeSaliency(img)
        if ok:
            lo = sal_map.min()
            # Lower bound on the range avoids division by zero on flat maps.
            return (sal_map - lo) / max(1e-6, sal_map.max() - lo)
    except Exception as e:
        log.debug("Saliency computation failed: %s", e)
    return None

# ────────────────────────────────────────────────────────────────────────────
# Helpers
# ────────────────────────────────────────────────────────────────────────────
def _process_mask(mask):
    if mask.ndim==3: mask=cv2.cvtColor(mask,cv2.COLOR_BGR2GRAY)
    if mask.dtype!=np.uint8:
        mask = (mask*255).astype(np.uint8) if mask.max()<=1 else mask.astype(np.uint8)
    _,mask=cv2.threshold(mask,127,255,cv2.THRESH_BINARY)
    return mask

def _validate_mask_quality(mask, shape: Tuple[int,int]) -> bool:
    """
    Return True when ``mask`` is a plausible foreground mask for a frame.

    The mask is rejected when it is None, when the frame area is zero, or
    when its leading two dimensions differ from ``shape``. Otherwise it is
    accepted if the fraction of the frame with mask value > 127 lies within
    [MIN_AREA_RATIO, MAX_AREA_RATIO].

    Parameters
    ----------
    mask : candidate mask array (or None).
    shape : (height, width) of the frame the mask must cover.
    """
    if mask is None:
        return False
    h, w = shape
    area = h * w
    # A zero-area frame or a mask of the wrong spatial size cannot be valid
    # and would make the ratio meaningless (or divide by zero).
    if area <= 0 or tuple(mask.shape[:2]) != (h, w):
        return False
    ratio = np.count_nonzero(mask > 127) / area
    return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO