MogensR commited on
Commit
a0ffb03
Β·
1 Parent(s): 6fdc616

Create utils/refinement.py

Browse files
Files changed (1) hide show
  1. utils/refinement.py +167 -0
utils/refinement.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ utils.refinement
4
+ ─────────────────────────────────────────────────────────────────────────────
5
+ Single-frame mask refinement for BackgroundFX Pro.
6
+
7
+ Public API
8
+ ----------
9
+ refine_mask_hq(image, mask, matanyone_processor, fallback_enabled=True) -> np.ndarray
10
+ """
11
+
12
+ from __future__ import annotations
13
+ from typing import Any, Tuple, Optional
14
+ import logging, cv2, torch, numpy as np
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+ # Quality thresholds (same as before)
19
+ MIN_AREA_RATIO = 0.015
20
+ MAX_AREA_RATIO = 0.97
21
+
22
+ # ────────────────────────────────────────────────────────────────────────────
23
+ # Public
24
+ # ────────────────────────────────────────────────────────────────────────────
25
+ __all__ = ["refine_mask_hq"]
26
+
27
+ def refine_mask_hq(
28
+ image: np.ndarray,
29
+ mask: np.ndarray,
30
+ matanyone_processor: Any,
31
+ fallback_enabled: bool = True,
32
+ ) -> np.ndarray:
33
+ """
34
+ 1) Try MatAnyOne high-quality refinement.
35
+ 2) Otherwise OpenCV β€œenhanced” filter.
36
+ 3) GrabCut and saliency fallbacks.
37
+ Always returns uint8 mask (0/255).
38
+ """
39
+ mask = _process_mask(mask)
40
+
41
+ # 1 β€” MatAnyOne
42
+ if matanyone_processor is not None:
43
+ try:
44
+ refined = _matanyone_refine(image, mask, matanyone_processor)
45
+ if refined is not None and _validate_mask_quality(refined, image.shape[:2]):
46
+ return refined
47
+ log.warning("MatAnyOne produced poor mask; fallback")
48
+ except Exception as e:
49
+ log.warning(f"MatAnyOne error: {e}")
50
+
51
+ # 2 β€” OpenCV β€œenhanced” bilateral+guided+MORPH
52
+ try:
53
+ refined = _opencv_enhance(image, mask)
54
+ if _validate_mask_quality(refined, image.shape[:2]):
55
+ return refined
56
+ except Exception as e:
57
+ log.debug(f"OpenCV enhance error: {e}")
58
+
59
+ # 3 β€” GrabCut + saliency double-fallback
60
+ try:
61
+ gc = _refine_with_grabcut(image, mask)
62
+ if _validate_mask_quality(gc, image.shape[:2]):
63
+ return gc
64
+ sal = _refine_with_saliency(image, mask)
65
+ if _validate_mask_quality(sal, image.shape[:2]):
66
+ return sal
67
+ except Exception as e:
68
+ log.debug(f"GrabCut/saliency fallback error: {e}")
69
+
70
+ # last resort
71
+ return mask if fallback_enabled else _opencv_enhance(image, mask)
72
+
73
+ # ────────────────────────────────────────────────────────────────────────────
74
+ # MatAnyOne wrapper (safe)
75
+ # ────────────────────────────────────────────────────────────────────────────
76
+ def _matanyone_refine(img, mask, proc) -> Optional[np.ndarray]:
77
+ if not (hasattr(proc, "step") and hasattr(proc, "output_prob_to_mask")):
78
+ return None
79
+ # image tensor (C,H,W) float32 0-1
80
+ anp = img.astype(np.float32)
81
+ if anp.max() > 1: anp /= 255.0
82
+ anp = np.transpose(anp, (2,0,1))
83
+ img_t = torch.from_numpy(anp).unsqueeze(0).to(proc.device if hasattr(proc,"device") else "cpu")
84
+ mask_f = mask.astype(np.float32)/255.0
85
+ mask_t = torch.from_numpy(mask_f).unsqueeze(0).to(img_t.device)
86
+
87
+ with torch.no_grad():
88
+ prob = proc.step(img_t, mask_t, objects=[1])
89
+ m = proc.output_prob_to_mask(prob).squeeze().cpu().numpy()
90
+ if m.max() <= 1: m *= 255
91
+ return m.astype(np.uint8)
92
+
93
+ # ────────────────────────────────────────────────────────────────────────────
94
+ # OpenCV enhanced filter chain
95
+ # ────────────────────────────────────────────────────────────────────────────
96
+ def _opencv_enhance(img, mask):
97
+ if mask.ndim == 3: mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
98
+ if mask.max()<=1: mask = (mask*255).astype(np.uint8)
99
+ m = cv2.bilateralFilter(mask, 9, 75, 75)
100
+ m = _guided_filter(img, m, r=8, eps=0.2)
101
+ m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5)))
102
+ m = cv2.morphologyEx(m, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)))
103
+ m = cv2.GaussianBlur(m,(3,3),0.8)
104
+ _,m = cv2.threshold(m,127,255,cv2.THRESH_BINARY)
105
+ return m
106
+
107
+ def _guided_filter(guide, mask, r=8, eps=0.2):
108
+ g = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY).astype(np.float32)/255.0
109
+ m = mask.astype(np.float32)/255.0
110
+ k = 2*r+1
111
+ mean_g = cv2.boxFilter(g, -1, (k,k))
112
+ mean_m = cv2.boxFilter(m, -1, (k,k))
113
+ corr_gm = cv2.boxFilter(g*m, -1, (k,k))
114
+ cov = corr_gm - mean_g*mean_m
115
+ var_g = cv2.boxFilter(g*g, -1, (k,k)) - mean_g*mean_g
116
+ a = cov/(var_g+eps)
117
+ b = mean_m - a*mean_g
118
+ mean_a = cv2.boxFilter(a, -1, (k,k))
119
+ mean_b = cv2.boxFilter(b, -1, (k,k))
120
+ out = (mean_a*g+mean_b)*255
121
+ return out.astype(np.uint8)
122
+
123
+ # ────────────────────────────────────────────────────────────────────────────
124
+ # GrabCut & saliency fallbacks
125
+ # ────────────────────────────────────────────────────────────────────────────
126
+ def _refine_with_grabcut(img, seed):
127
+ h,w = img.shape[:2]
128
+ gc = np.full((h,w), cv2.GC_PR_BGD, np.uint8)
129
+ gc[seed>200] = cv2.GC_FGD
130
+ rect = (w//4, h//6, w//2, int(h*0.7))
131
+ bgd,fgd = np.zeros((1,65),np.float64), np.zeros((1,65),np.float64)
132
+ cv2.grabCut(img, gc, rect, bgd, fgd, 3, cv2.GC_INIT_WITH_MASK)
133
+ return np.where((gc==cv2.GC_FGD)|(gc==cv2.GC_PR_FGD),255,0).astype(np.uint8)
134
+
135
+ def _refine_with_saliency(img, seed):
136
+ sal = _compute_saliency(img)
137
+ if sal is None: return seed
138
+ high = (sal>0.6).astype(np.uint8)*255
139
+ cy,cx = img.shape[0]//2, img.shape[1]//2
140
+ if np.any(seed>127):
141
+ ys,xs = np.where(seed>127); cy,cx=int(np.mean(ys)),int(np.mean(xs))
142
+ ff = high.copy(); cv2.floodFill(ff,None,(cx,cy),255,loDiff=5,upDiff=5)
143
+ return ff
144
+
145
+ def _compute_saliency(img):
146
+ try:
147
+ if hasattr(cv2,"saliency"):
148
+ s=cv2.saliency.StaticSaliencySpectralResidual_create()
149
+ ok,sm=s.computeSaliency(img)
150
+ if ok: return (sm-sm.min())/max(1e-6,sm.max()-sm.min())
151
+ except Exception: pass
152
+ return None
153
+
154
+ # ────────────────────────────────────────────────────────────────────────────
155
+ # Helpers
156
+ # ────────────────────────────────────────────────────────────────────────────
157
+ def _process_mask(mask):
158
+ if mask.ndim==3: mask=cv2.cvtColor(mask,cv2.COLOR_BGR2GRAY)
159
+ if mask.dtype!=np.uint8:
160
+ mask = (mask*255).astype(np.uint8) if mask.max()<=1 else mask.astype(np.uint8)
161
+ _,mask=cv2.threshold(mask,127,255,cv2.THRESH_BINARY)
162
+ return mask
163
+
164
+ def _validate_mask_quality(mask, shape: Tuple[int,int]) -> bool:
165
+ h,w = shape
166
+ ratio = np.sum(mask>127)/(h*w)
167
+ return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO