Update utils/refinement.py
Browse files- utils/refinement.py +181 -94
utils/refinement.py
CHANGED
|
@@ -5,7 +5,7 @@
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
-
from typing import Any, Optional, Tuple
|
| 9 |
import logging
|
| 10 |
|
| 11 |
import cv2
|
|
@@ -26,11 +26,12 @@ class MaskRefinementError(Exception):
|
|
| 26 |
# ============================================================================
|
| 27 |
__all__ = [
|
| 28 |
"refine_mask_hq",
|
|
|
|
| 29 |
"MaskRefinementError",
|
| 30 |
]
|
| 31 |
|
| 32 |
# ============================================================================
|
| 33 |
-
# MAIN API
|
| 34 |
# ============================================================================
|
| 35 |
def refine_mask_hq(
|
| 36 |
image: np.ndarray,
|
|
@@ -77,7 +78,50 @@ def refine_mask_hq(
|
|
| 77 |
return mask
|
| 78 |
|
| 79 |
# ============================================================================
|
| 80 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# ============================================================================
|
| 82 |
def _refine_with_matanyone(
|
| 83 |
image: np.ndarray,
|
|
@@ -86,7 +130,6 @@ def _refine_with_matanyone(
|
|
| 86 |
) -> np.ndarray:
|
| 87 |
"""Use MatAnyone model for mask refinement."""
|
| 88 |
try:
|
| 89 |
-
# MatAnyone's InferenceCore expects torch tensors
|
| 90 |
# Convert BGR to RGB and normalize
|
| 91 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 92 |
h, w = image_rgb.shape[:2]
|
|
@@ -105,74 +148,42 @@ def _refine_with_matanyone(
|
|
| 105 |
mask_tensor = torch.from_numpy(mask).float() / 255.0
|
| 106 |
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
# The model should have been initialized as InferenceCore(matanyone_model)
|
| 110 |
result = None
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# Initialize memory with first frame
|
| 120 |
-
model.reset()
|
| 121 |
-
# Process frame with mask
|
| 122 |
result = model.step(image_tensor, mask_tensor)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
result = model.forward(image_tensor, mask_tensor)
|
| 127 |
-
|
| 128 |
-
# Predict method
|
| 129 |
-
with torch.no_grad():
|
| 130 |
-
result = model.predict(image_tensor, mask_tensor)
|
| 131 |
-
elif hasattr(model, '__call__'):
|
| 132 |
-
# Callable model
|
| 133 |
-
with torch.no_grad():
|
| 134 |
result = model(image_tensor, mask_tensor)
|
| 135 |
-
else:
|
| 136 |
-
# Try to find any method that might work
|
| 137 |
-
methods = [m for m in dir(model) if not m.startswith('_')]
|
| 138 |
-
processing_methods = [m for m in methods if any(keyword in m.lower()
|
| 139 |
-
for keyword in ['process', 'refine', 'matte', 'alpha', 'predict'])]
|
| 140 |
-
if processing_methods:
|
| 141 |
-
method = getattr(model, processing_methods[0])
|
| 142 |
-
with torch.no_grad():
|
| 143 |
-
result = method(image_tensor, mask_tensor)
|
| 144 |
else:
|
| 145 |
-
raise MaskRefinementError(f"
|
| 146 |
|
| 147 |
if result is None:
|
| 148 |
raise MaskRefinementError("MatAnyone returned None")
|
| 149 |
|
| 150 |
-
#
|
| 151 |
-
|
| 152 |
-
# Extract alpha matte from tuple/list result
|
| 153 |
-
alpha = result[0] if len(result) > 0 else None
|
| 154 |
-
elif isinstance(result, dict):
|
| 155 |
-
# Extract from dictionary result
|
| 156 |
-
alpha = result.get('alpha', result.get('matte', result.get('mask', None)))
|
| 157 |
-
else:
|
| 158 |
-
alpha = result
|
| 159 |
-
|
| 160 |
-
if alpha is None:
|
| 161 |
-
raise MaskRefinementError("Could not extract alpha matte from MatAnyone result")
|
| 162 |
|
| 163 |
-
# Convert back to numpy
|
| 164 |
if isinstance(alpha, torch.Tensor):
|
| 165 |
-
alpha = alpha.squeeze().cpu().numpy()
|
| 166 |
-
|
| 167 |
-
# Ensure proper shape
|
| 168 |
if alpha.ndim == 3:
|
| 169 |
alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0)
|
| 170 |
-
|
| 171 |
-
# Convert to uint8
|
| 172 |
if alpha.dtype != np.uint8:
|
| 173 |
alpha = (alpha * 255).clip(0, 255).astype(np.uint8)
|
| 174 |
-
|
| 175 |
-
# Resize if needed
|
| 176 |
if alpha.shape != (h, w):
|
| 177 |
alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)
|
| 178 |
|
|
@@ -183,44 +194,121 @@ def _refine_with_matanyone(
|
|
| 183 |
raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}")
|
| 184 |
|
| 185 |
# ============================================================================
|
| 186 |
-
#
|
| 187 |
# ============================================================================
|
| 188 |
-
def
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
-
|
| 198 |
-
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
|
| 206 |
-
return
|
| 207 |
|
| 208 |
-
# ============================================================================
|
| 209 |
-
# HELPER FUNCTIONS
|
| 210 |
-
# ============================================================================
|
| 211 |
def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool:
|
| 212 |
"""Check if refined mask is reasonable."""
|
| 213 |
if refined is None or refined.size == 0:
|
| 214 |
return False
|
| 215 |
|
| 216 |
-
# Check if mask has reasonable coverage
|
| 217 |
refined_area = np.sum(refined > 127)
|
| 218 |
original_area = np.sum(original > 127)
|
| 219 |
|
| 220 |
if refined_area == 0:
|
| 221 |
return False
|
| 222 |
|
| 223 |
-
# Allow some variation but not extreme changes
|
| 224 |
ratio = refined_area / max(original_area, 1)
|
| 225 |
return 0.5 <= ratio <= 2.0
|
| 226 |
|
|
@@ -239,41 +327,45 @@ def _process_mask(mask: np.ndarray) -> np.ndarray:
|
|
| 239 |
_, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
|
| 240 |
return binary
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 243 |
"""Apply edge-aware smoothing using guided filter."""
|
| 244 |
-
# Convert to float for processing
|
| 245 |
mask_float = mask.astype(np.float32) / 255.0
|
| 246 |
-
|
| 247 |
-
# Simple guided filter approximation
|
| 248 |
radius = 5
|
| 249 |
eps = 0.01
|
| 250 |
|
| 251 |
-
# Use image as guide
|
| 252 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
|
| 253 |
|
| 254 |
-
# Box filter for mean
|
| 255 |
mean_I = cv2.boxFilter(gray, -1, (radius, radius))
|
| 256 |
mean_p = cv2.boxFilter(mask_float, -1, (radius, radius))
|
| 257 |
mean_Ip = cv2.boxFilter(gray * mask_float, -1, (radius, radius))
|
| 258 |
|
| 259 |
-
# Covariance
|
| 260 |
cov_Ip = mean_Ip - mean_I * mean_p
|
| 261 |
-
|
| 262 |
-
# Variance
|
| 263 |
mean_II = cv2.boxFilter(gray * gray, -1, (radius, radius))
|
| 264 |
var_I = mean_II - mean_I * mean_I
|
| 265 |
|
| 266 |
-
# Coefficients
|
| 267 |
a = cov_Ip / (var_I + eps)
|
| 268 |
b = mean_p - a * mean_I
|
| 269 |
|
| 270 |
-
# Filter
|
| 271 |
mean_a = cv2.boxFilter(a, -1, (radius, radius))
|
| 272 |
mean_b = cv2.boxFilter(b, -1, (radius, radius))
|
| 273 |
|
| 274 |
refined = mean_a * gray + mean_b
|
| 275 |
-
|
| 276 |
-
# Convert back to binary
|
| 277 |
return (refined * 255).clip(0, 255).astype(np.uint8)
|
| 278 |
|
| 279 |
def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
|
|
@@ -281,10 +373,8 @@ def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
|
|
| 281 |
if radius <= 0:
|
| 282 |
return mask
|
| 283 |
|
| 284 |
-
# Blur then threshold to maintain binary nature
|
| 285 |
blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2)
|
| 286 |
_, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)
|
| 287 |
-
|
| 288 |
return binary
|
| 289 |
|
| 290 |
def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
|
|
@@ -294,18 +384,15 @@ def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) ->
|
|
| 294 |
if num_labels <= 1:
|
| 295 |
return mask
|
| 296 |
|
| 297 |
-
# Calculate minimum area
|
| 298 |
total_area = mask.shape[0] * mask.shape[1]
|
| 299 |
min_area = int(total_area * min_area_ratio)
|
| 300 |
|
| 301 |
-
# Find largest component (excluding background)
|
| 302 |
areas = stats[1:, cv2.CC_STAT_AREA]
|
| 303 |
if len(areas) == 0:
|
| 304 |
return mask
|
| 305 |
|
| 306 |
max_label = np.argmax(areas) + 1
|
| 307 |
|
| 308 |
-
# Keep only components above threshold or the largest one
|
| 309 |
cleaned = np.zeros_like(mask)
|
| 310 |
for label in range(1, num_labels):
|
| 311 |
if stats[label, cv2.CC_STAT_AREA] >= min_area or label == max_label:
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
+
from typing import Any, Optional, Tuple, List
|
| 9 |
import logging
|
| 10 |
|
| 11 |
import cv2
|
|
|
|
| 26 |
# ============================================================================
|
| 27 |
__all__ = [
|
| 28 |
"refine_mask_hq",
|
| 29 |
+
"refine_masks_batch",
|
| 30 |
"MaskRefinementError",
|
| 31 |
]
|
| 32 |
|
| 33 |
# ============================================================================
|
| 34 |
+
# MAIN API - SINGLE FRAME
|
| 35 |
# ============================================================================
|
| 36 |
def refine_mask_hq(
|
| 37 |
image: np.ndarray,
|
|
|
|
| 78 |
return mask
|
| 79 |
|
| 80 |
# ============================================================================
|
| 81 |
+
# BATCH PROCESSING FOR TEMPORAL CONSISTENCY
|
| 82 |
+
# ============================================================================
|
| 83 |
+
def refine_masks_batch(
|
| 84 |
+
frames: List[np.ndarray],
|
| 85 |
+
masks: List[np.ndarray],
|
| 86 |
+
matanyone_model: Optional[Any] = None,
|
| 87 |
+
fallback_enabled: bool = True
|
| 88 |
+
) -> List[np.ndarray]:
|
| 89 |
+
"""
|
| 90 |
+
Refine multiple masks using MatAnyone's temporal consistency.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
frames: List of BGR images
|
| 94 |
+
masks: List of initial binary masks
|
| 95 |
+
matanyone_model: MatAnyone InferenceCore model
|
| 96 |
+
fallback_enabled: Whether to use fallback methods
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
List of refined binary masks
|
| 100 |
+
"""
|
| 101 |
+
if not frames or not masks:
|
| 102 |
+
return masks
|
| 103 |
+
|
| 104 |
+
if len(frames) != len(masks):
|
| 105 |
+
raise MaskRefinementError(f"Frame count {len(frames)} doesn't match mask count {len(masks)}")
|
| 106 |
+
|
| 107 |
+
if matanyone_model is not None:
|
| 108 |
+
try:
|
| 109 |
+
refined = _refine_batch_with_matanyone(frames, masks, matanyone_model)
|
| 110 |
+
# Validate all masks
|
| 111 |
+
if all(_validate_refined_mask(r, m) for r, m in zip(refined, masks)):
|
| 112 |
+
return refined
|
| 113 |
+
log.warning("Batch MatAnyone refinement failed validation")
|
| 114 |
+
except Exception as e:
|
| 115 |
+
log.warning(f"Batch MatAnyone refinement failed: {e}")
|
| 116 |
+
|
| 117 |
+
# Fallback to frame-by-frame classical refinement
|
| 118 |
+
if fallback_enabled:
|
| 119 |
+
return [_classical_refinement(f, m) for f, m in zip(frames, masks)]
|
| 120 |
+
|
| 121 |
+
return masks
|
| 122 |
+
|
| 123 |
+
# ============================================================================
|
| 124 |
+
# AI-BASED REFINEMENT - SINGLE FRAME
|
| 125 |
# ============================================================================
|
| 126 |
def _refine_with_matanyone(
|
| 127 |
image: np.ndarray,
|
|
|
|
| 130 |
) -> np.ndarray:
|
| 131 |
"""Use MatAnyone model for mask refinement."""
|
| 132 |
try:
|
|
|
|
| 133 |
# Convert BGR to RGB and normalize
|
| 134 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 135 |
h, w = image_rgb.shape[:2]
|
|
|
|
| 148 |
mask_tensor = torch.from_numpy(mask).float() / 255.0
|
| 149 |
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
|
| 150 |
|
| 151 |
+
# Try different methods on InferenceCore
|
|
|
|
| 152 |
result = None
|
| 153 |
|
| 154 |
+
# Log available methods for debugging
|
| 155 |
+
methods = [m for m in dir(model) if not m.startswith('_')]
|
| 156 |
+
log.debug(f"MatAnyone InferenceCore methods: {methods}")
|
| 157 |
+
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
if hasattr(model, 'step'):
|
| 160 |
+
# Step method for iterative processing (don't call reset)
|
|
|
|
|
|
|
|
|
|
| 161 |
result = model.step(image_tensor, mask_tensor)
|
| 162 |
+
elif hasattr(model, 'process_frame'):
|
| 163 |
+
result = model.process_frame(image_tensor, mask_tensor)
|
| 164 |
+
elif hasattr(model, 'forward'):
|
| 165 |
result = model.forward(image_tensor, mask_tensor)
|
| 166 |
+
elif hasattr(model, '__call__'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
result = model(image_tensor, mask_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
else:
|
| 169 |
+
raise MaskRefinementError(f"No recognized method. Available: {methods}")
|
| 170 |
|
| 171 |
if result is None:
|
| 172 |
raise MaskRefinementError("MatAnyone returned None")
|
| 173 |
|
| 174 |
+
# Extract alpha matte from result
|
| 175 |
+
alpha = _extract_alpha_from_result(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
# Convert back to numpy and resize if needed
|
| 178 |
if isinstance(alpha, torch.Tensor):
|
| 179 |
+
alpha = alpha.squeeze().cpu().numpy()
|
| 180 |
+
|
|
|
|
| 181 |
if alpha.ndim == 3:
|
| 182 |
alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0)
|
| 183 |
+
|
|
|
|
| 184 |
if alpha.dtype != np.uint8:
|
| 185 |
alpha = (alpha * 255).clip(0, 255).astype(np.uint8)
|
| 186 |
+
|
|
|
|
| 187 |
if alpha.shape != (h, w):
|
| 188 |
alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)
|
| 189 |
|
|
|
|
| 194 |
raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}")
|
| 195 |
|
| 196 |
# ============================================================================
|
| 197 |
+
# AI-BASED REFINEMENT - BATCH
|
| 198 |
# ============================================================================
|
| 199 |
+
def _refine_batch_with_matanyone(
|
| 200 |
+
frames: List[np.ndarray],
|
| 201 |
+
masks: List[np.ndarray],
|
| 202 |
+
model: Any
|
| 203 |
+
) -> List[np.ndarray]:
|
| 204 |
+
"""Process batch of frames through MatAnyone for temporal consistency."""
|
| 205 |
+
try:
|
| 206 |
+
batch_size = len(frames)
|
| 207 |
+
h, w = frames[0].shape[:2]
|
| 208 |
+
|
| 209 |
+
# Convert frames to tensor batch
|
| 210 |
+
frame_tensors = []
|
| 211 |
+
for frame in frames:
|
| 212 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 213 |
+
tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
|
| 214 |
+
frame_tensors.append(tensor)
|
| 215 |
+
|
| 216 |
+
# Stack into batch (N, C, H, W)
|
| 217 |
+
batch_tensor = torch.stack(frame_tensors)
|
| 218 |
+
|
| 219 |
+
# Prepare first mask for initialization
|
| 220 |
+
first_mask = masks[0]
|
| 221 |
+
if first_mask.dtype != np.uint8:
|
| 222 |
+
first_mask = (first_mask * 255).astype(np.uint8)
|
| 223 |
+
if first_mask.ndim == 3:
|
| 224 |
+
first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY)
|
| 225 |
+
|
| 226 |
+
# Convert first mask to tensor
|
| 227 |
+
first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
|
| 228 |
+
first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0)
|
| 229 |
+
|
| 230 |
+
refined_masks = []
|
| 231 |
+
|
| 232 |
+
with torch.no_grad():
|
| 233 |
+
# Check for batch processing methods
|
| 234 |
+
if hasattr(model, 'process_batch'):
|
| 235 |
+
# Direct batch processing
|
| 236 |
+
results = model.process_batch(batch_tensor, first_mask_tensor)
|
| 237 |
+
for result in results:
|
| 238 |
+
alpha = _extract_alpha_from_result(result)
|
| 239 |
+
refined_masks.append(_tensor_to_mask(alpha, h, w))
|
| 240 |
+
|
| 241 |
+
elif hasattr(model, 'step'):
|
| 242 |
+
# Process frames sequentially with memory
|
| 243 |
+
for i, frame_tensor in enumerate(frame_tensors):
|
| 244 |
+
if i == 0:
|
| 245 |
+
# First frame with mask
|
| 246 |
+
result = model.step(frame_tensor.unsqueeze(0), first_mask_tensor)
|
| 247 |
+
else:
|
| 248 |
+
# Subsequent frames use memory from previous
|
| 249 |
+
result = model.step(frame_tensor.unsqueeze(0), None)
|
| 250 |
+
|
| 251 |
+
alpha = _extract_alpha_from_result(result)
|
| 252 |
+
refined_masks.append(_tensor_to_mask(alpha, h, w))
|
| 253 |
+
|
| 254 |
+
else:
|
| 255 |
+
# Fallback to processing each frame with its mask
|
| 256 |
+
log.warning("MatAnyone batch processing not available, using frame-by-frame")
|
| 257 |
+
for frame_tensor, mask in zip(frame_tensors, masks):
|
| 258 |
+
mask_tensor = torch.from_numpy(mask).float() / 255.0
|
| 259 |
+
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)
|
| 260 |
+
|
| 261 |
+
result = model(frame_tensor.unsqueeze(0), mask_tensor)
|
| 262 |
+
alpha = _extract_alpha_from_result(result)
|
| 263 |
+
refined_masks.append(_tensor_to_mask(alpha, h, w))
|
| 264 |
+
|
| 265 |
+
return refined_masks
|
| 266 |
+
|
| 267 |
+
except Exception as e:
|
| 268 |
+
log.error(f"Batch MatAnyone processing error: {str(e)}")
|
| 269 |
+
raise MaskRefinementError(f"Batch processing failed: {str(e)}")
|
| 270 |
+
|
| 271 |
+
# ============================================================================
|
| 272 |
+
# HELPER FUNCTIONS
|
| 273 |
+
# ============================================================================
|
| 274 |
+
def _extract_alpha_from_result(result):
|
| 275 |
+
"""Extract alpha matte from various result formats."""
|
| 276 |
+
if isinstance(result, (tuple, list)):
|
| 277 |
+
return result[0] if len(result) > 0 else None
|
| 278 |
+
elif isinstance(result, dict):
|
| 279 |
+
return result.get('alpha', result.get('matte', result.get('mask', None)))
|
| 280 |
+
else:
|
| 281 |
+
return result
|
| 282 |
+
|
| 283 |
+
def _tensor_to_mask(tensor, target_h, target_w):
|
| 284 |
+
"""Convert tensor to numpy mask with proper sizing."""
|
| 285 |
+
if isinstance(tensor, torch.Tensor):
|
| 286 |
+
mask = tensor.squeeze().cpu().numpy()
|
| 287 |
+
else:
|
| 288 |
+
mask = tensor
|
| 289 |
|
| 290 |
+
if mask.ndim == 3:
|
| 291 |
+
mask = mask[0] if mask.shape[0] == 1 else mask.mean(axis=0)
|
| 292 |
|
| 293 |
+
if mask.dtype != np.uint8:
|
| 294 |
+
mask = (mask * 255).clip(0, 255).astype(np.uint8)
|
| 295 |
|
| 296 |
+
if mask.shape != (target_h, target_w):
|
| 297 |
+
mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
|
| 298 |
|
| 299 |
+
return mask
|
| 300 |
|
|
|
|
|
|
|
|
|
|
| 301 |
def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool:
|
| 302 |
"""Check if refined mask is reasonable."""
|
| 303 |
if refined is None or refined.size == 0:
|
| 304 |
return False
|
| 305 |
|
|
|
|
| 306 |
refined_area = np.sum(refined > 127)
|
| 307 |
original_area = np.sum(original > 127)
|
| 308 |
|
| 309 |
if refined_area == 0:
|
| 310 |
return False
|
| 311 |
|
|
|
|
| 312 |
ratio = refined_area / max(original_area, 1)
|
| 313 |
return 0.5 <= ratio <= 2.0
|
| 314 |
|
|
|
|
| 327 |
_, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
|
| 328 |
return binary
|
| 329 |
|
| 330 |
+
# ============================================================================
|
| 331 |
+
# CLASSICAL REFINEMENT
|
| 332 |
+
# ============================================================================
|
| 333 |
+
def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 334 |
+
"""Apply classical CV techniques for mask refinement."""
|
| 335 |
+
refined = mask.copy()
|
| 336 |
+
|
| 337 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
| 338 |
+
refined = cv2.morphologyEx(refined, cv2.MORPH_CLOSE, kernel)
|
| 339 |
+
refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)
|
| 340 |
+
refined = _edge_aware_smooth(image, refined)
|
| 341 |
+
refined = _feather_edges(refined, radius=3)
|
| 342 |
+
refined = _remove_small_components(refined, min_area_ratio=0.005)
|
| 343 |
+
|
| 344 |
+
return refined
|
| 345 |
+
|
| 346 |
def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 347 |
"""Apply edge-aware smoothing using guided filter."""
|
|
|
|
| 348 |
mask_float = mask.astype(np.float32) / 255.0
|
|
|
|
|
|
|
| 349 |
radius = 5
|
| 350 |
eps = 0.01
|
| 351 |
|
|
|
|
| 352 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
|
| 353 |
|
|
|
|
| 354 |
mean_I = cv2.boxFilter(gray, -1, (radius, radius))
|
| 355 |
mean_p = cv2.boxFilter(mask_float, -1, (radius, radius))
|
| 356 |
mean_Ip = cv2.boxFilter(gray * mask_float, -1, (radius, radius))
|
| 357 |
|
|
|
|
| 358 |
cov_Ip = mean_Ip - mean_I * mean_p
|
|
|
|
|
|
|
| 359 |
mean_II = cv2.boxFilter(gray * gray, -1, (radius, radius))
|
| 360 |
var_I = mean_II - mean_I * mean_I
|
| 361 |
|
|
|
|
| 362 |
a = cov_Ip / (var_I + eps)
|
| 363 |
b = mean_p - a * mean_I
|
| 364 |
|
|
|
|
| 365 |
mean_a = cv2.boxFilter(a, -1, (radius, radius))
|
| 366 |
mean_b = cv2.boxFilter(b, -1, (radius, radius))
|
| 367 |
|
| 368 |
refined = mean_a * gray + mean_b
|
|
|
|
|
|
|
| 369 |
return (refined * 255).clip(0, 255).astype(np.uint8)
|
| 370 |
|
| 371 |
def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
|
|
|
|
| 373 |
if radius <= 0:
|
| 374 |
return mask
|
| 375 |
|
|
|
|
| 376 |
blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2)
|
| 377 |
_, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)
|
|
|
|
| 378 |
return binary
|
| 379 |
|
| 380 |
def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
|
|
|
|
| 384 |
if num_labels <= 1:
|
| 385 |
return mask
|
| 386 |
|
|
|
|
| 387 |
total_area = mask.shape[0] * mask.shape[1]
|
| 388 |
min_area = int(total_area * min_area_ratio)
|
| 389 |
|
|
|
|
| 390 |
areas = stats[1:, cv2.CC_STAT_AREA]
|
| 391 |
if len(areas) == 0:
|
| 392 |
return mask
|
| 393 |
|
| 394 |
max_label = np.argmax(areas) + 1
|
| 395 |
|
|
|
|
| 396 |
cleaned = np.zeros_like(mask)
|
| 397 |
for label in range(1, num_labels):
|
| 398 |
if stats[label, cv2.CC_STAT_AREA] >= min_area or label == max_label:
|