|
|
|
|
|
""" |
|
|
utils.refinement |
|
|
High-quality mask refinement for BackgroundFX Pro. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
from typing import Any, Optional, Tuple, List |
|
|
import logging |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
# Module-level logger, named after this module so handlers can filter on it.
log: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaskRefinementError(Exception):
    """Error raised by this module when mask refinement cannot proceed.

    Used for invalid inputs (missing image/mask, mismatched shapes or
    counts) and to wrap failures that occur while running the MatAnyone
    model.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module; everything else is a private helper.
__all__ = ["refine_mask_hq", "refine_masks_batch", "MaskRefinementError"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def refine_mask_hq(
    image: np.ndarray,
    mask: np.ndarray,
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> np.ndarray:
    """
    Refine a single mask, preferring the AI model and degrading gracefully.

    Strategy order:
      1. MatAnyone refinement, when a model is supplied and its output
         passes the sanity check against the input mask.
      2. Classical OpenCV refinement, when ``fallback_enabled`` is true.
      3. The untouched input mask.

    Args:
        image: Original BGR image
        mask: Initial binary mask (0/255)
        matanyone_model: Optional MatAnyone model for AI refinement
        fallback_enabled: Whether to use fallback methods if AI fails

    Returns:
        Refined binary mask (0/255), or the input mask when every
        strategy is unavailable or fails.

    Raises:
        MaskRefinementError: If either input is None or the spatial
            dimensions of image and mask differ.
    """
    if mask is None or image is None:
        raise MaskRefinementError("Invalid input image or mask")

    shapes_agree = image.shape[:2] == mask.shape[:2]
    if not shapes_agree:
        raise MaskRefinementError(f"Image shape {image.shape[:2]} doesn't match mask shape {mask.shape[:2]}")

    # Strategy 1: AI refinement. Any failure here is logged and swallowed
    # so the remaining strategies still get a chance.
    use_ai = matanyone_model is not None
    if use_ai:
        try:
            candidate = _refine_with_matanyone(image, mask, matanyone_model)
            accepted = _validate_refined_mask(candidate, mask)
            if accepted:
                return candidate
            log.warning("MatAnyone refinement failed validation")
        except Exception as e:
            log.warning(f"MatAnyone refinement failed: {e}")

    # Strategy 3 directly when the classical fallback is switched off.
    if not fallback_enabled:
        return mask

    # Strategy 2: classical refinement, best effort.
    try:
        return _classical_refinement(image, mask)
    except Exception as e:
        log.warning(f"Classical refinement failed: {e}")

    return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def refine_masks_batch(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> List[np.ndarray]:
    """
    Refine multiple masks using MatAnyone's temporal consistency.

    The whole batch is first offered to MatAnyone (when a model is given).
    If that fails, returns the wrong number of masks, or any refined mask
    fails validation, each frame is refined independently with the
    classical pipeline (when ``fallback_enabled`` is true).

    Args:
        frames: List of BGR images
        masks: List of initial binary masks
        matanyone_model: MatAnyone InferenceCore model
        fallback_enabled: Whether to use fallback methods

    Returns:
        List of refined masks, one per input frame. When either input is
        empty, or no strategy is available, ``masks`` is returned as-is.
        In the classical fallback a frame whose refinement raises keeps
        its original mask.

    Raises:
        MaskRefinementError: If the number of frames and masks differ.
    """
    if not frames or not masks:
        return masks

    if len(frames) != len(masks):
        raise MaskRefinementError(f"Frame count {len(frames)} doesn't match mask count {len(masks)}")

    if matanyone_model is not None:
        try:
            refined = _refine_batch_with_matanyone(frames, masks, matanyone_model)

            # zip() would silently truncate, so a short result list must be
            # rejected explicitly before validating mask by mask.
            if len(refined) == len(masks) and all(
                _validate_refined_mask(r, m) for r, m in zip(refined, masks)
            ):
                return refined
            log.warning("Batch MatAnyone refinement failed validation")
        except Exception as e:
            log.warning(f"Batch MatAnyone refinement failed: {e}")

    if not fallback_enabled:
        return masks

    # Mirror refine_mask_hq: the classical fallback is best effort, so one
    # bad frame keeps its original mask instead of aborting the whole batch.
    results: List[np.ndarray] = []
    for index, (frame, mask) in enumerate(zip(frames, masks)):
        try:
            results.append(_classical_refinement(frame, mask))
        except Exception as e:
            log.warning(f"Classical refinement failed for frame {index}: {e}")
            results.append(mask)
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _refine_with_matanyone(
    image: np.ndarray,
    mask: np.ndarray,
    model: Any
) -> np.ndarray:
    """Use MatAnyone model for mask refinement.

    The image is converted to an RGB float tensor of shape (1, 3, H, W) and
    the mask to a float tensor of shape (1, 1, H, W), both scaled to [0, 1]
    and moved to CUDA when available. The model is then invoked through the
    first entry point it exposes among ``step``, ``process_frame``,
    ``forward`` or a plain call.

    Args:
        image: BGR image as read by OpenCV.
        mask: Initial mask; 2D, or 3D with the channel axis last. Non-uint8
            masks whose maximum is <= 1 are scaled by 255.
        model: Model object exposing one of the entry points listed above.

    Returns:
        Binary (0/255) uint8 mask with the same height/width as ``image``.

    Raises:
        MaskRefinementError: On any failure (bad shapes, no usable entry
            point, empty model output, or an exception raised by the
            model); the original exception is chained as ``__cause__``.
    """
    try:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image_rgb.shape[:2]

        # HWC uint8 -> 1xCxHxW float in [0, 1]
        image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
        image_tensor = image_tensor.unsqueeze(0).to(device)

        # Collapse a channel axis so the mask is strictly 2D.
        if mask.ndim == 3:
            if mask.shape[2] == 3:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            else:
                mask = mask[:, :, 0]

        if mask.dtype != np.uint8:
            mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)

        # Explicit checks instead of `assert`, which is stripped under -O.
        if mask.ndim != 2:
            raise MaskRefinementError(f"Mask must be 2D after conversion, got shape {mask.shape}")
        if mask.shape != (h, w):
            raise MaskRefinementError(f"Mask shape {mask.shape} doesn't match image shape ({h}, {w})")

        mask_tensor = torch.from_numpy(mask).float() / 255.0
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)

        methods = [m for m in dir(model) if not m.startswith('_')]
        log.debug("MatAnyone InferenceCore methods: %s", methods)

        # Pick the first supported entry point, in the same priority order
        # as before: step, process_frame, forward, then a plain call.
        entry_point = None
        for name in ('step', 'process_frame', 'forward'):
            if hasattr(model, name):
                entry_point = getattr(model, name)
                break
        if entry_point is None and callable(model):
            entry_point = model
        if entry_point is None:
            raise MaskRefinementError(f"No recognized method. Available: {methods}")

        with torch.no_grad():
            result = entry_point(image_tensor, mask_tensor)

        if result is None:
            raise MaskRefinementError("MatAnyone returned None")

        alpha = _extract_alpha_from_result(result)
        if alpha is None:
            # Empty tuple/list or a dict without alpha/matte/mask keys.
            raise MaskRefinementError("MatAnyone result contained no alpha matte")

        # Shared helper does tensor -> uint8 (H, W) conversion and resizing;
        # _process_mask then binarizes to 0/255.
        return _process_mask(_tensor_to_mask(alpha, h, w))

    except Exception as e:
        log.error("MatAnyone processing error: %s", e)
        raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mask_to_unit_tensor(mask: np.ndarray, device: torch.device, label: str) -> torch.Tensor:
    """Convert a mask to a (1, 1, H, W) float tensor in [0, 1] on ``device``.

    ``label`` is only used to name the mask in the error message.
    """
    if mask.ndim == 3:
        if mask.shape[2] == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        else:
            mask = mask[:, :, 0]

    if mask.dtype != np.uint8:
        mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)

    # Explicit check instead of `assert`, which is stripped under -O.
    if mask.ndim != 2:
        raise MaskRefinementError(f"{label} must be 2D, got shape {mask.shape}")

    tensor = torch.from_numpy(mask).float() / 255.0
    return tensor.unsqueeze(0).unsqueeze(0).to(device)


def _refine_batch_with_matanyone(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    model: Any
) -> List[np.ndarray]:
    """Process batch of frames through MatAnyone for temporal consistency.

    Three model interfaces are supported, tried in this order:
      * ``process_batch(batch, first_mask)``: all frames stacked at once.
      * ``step(frame, mask)``: called per frame; the mask is passed only
        for the first frame and ``None`` afterwards.
      * plain call ``model(frame, mask)``: per frame with its own mask.

    Args:
        frames: BGR images.
        masks: Initial masks, one per frame.
        model: MatAnyone model object.

    Returns:
        One uint8 mask per frame, sized to that frame's height/width.

    Raises:
        MaskRefinementError: On any failure, including the model returning
            a different number of results than frames; the original
            exception is chained as ``__cause__``.
    """
    try:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # HWC uint8 BGR -> CxHxW float RGB in [0, 1]; kept on CPU until used.
        frame_tensors = [
            torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1).float() / 255.0
            for frame in frames
        ]

        first_mask_tensor = _mask_to_unit_tensor(masks[0], device, "First mask")

        refined_masks = []

        with torch.no_grad():
            if hasattr(model, 'process_batch'):
                # Only this path needs the whole batch on the device at once.
                batch_tensor = torch.stack(frame_tensors).to(device)
                results = model.process_batch(batch_tensor, first_mask_tensor)
                for result, frame in zip(results, frames):
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, *frame.shape[:2]))

            elif hasattr(model, 'step'):
                for i, (frame, frame_tensor) in enumerate(zip(frames, frame_tensors)):
                    frame_on_device = frame_tensor.unsqueeze(0).to(device)
                    # Only the first frame carries a mask; later frames pass
                    # None (presumably the model propagates it internally;
                    # confirm against the MatAnyone API).
                    step_mask = first_mask_tensor if i == 0 else None
                    result = model.step(frame_on_device, step_mask)

                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, *frame.shape[:2]))

            else:
                log.warning("MatAnyone batch processing not available, using frame-by-frame")
                for frame, frame_tensor, mask in zip(frames, frame_tensors, masks):
                    # Same normalisation as the first mask, so float/bool
                    # masks are scaled consistently on this path too.
                    mask_tensor = _mask_to_unit_tensor(mask, device, "Mask")
                    frame_on_device = frame_tensor.unsqueeze(0).to(device)

                    result = model(frame_on_device, mask_tensor)
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, *frame.shape[:2]))

        if len(refined_masks) != len(frames):
            raise MaskRefinementError(
                f"MatAnyone returned {len(refined_masks)} masks for {len(frames)} frames"
            )

        return refined_masks

    except Exception as e:
        log.error("Batch MatAnyone processing error: %s", e)
        raise MaskRefinementError(f"Batch processing failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_alpha_from_result(result): |
|
|
"""Extract alpha matte from various result formats.""" |
|
|
if isinstance(result, (tuple, list)): |
|
|
return result[0] if len(result) > 0 else None |
|
|
elif isinstance(result, dict): |
|
|
return result.get('alpha', result.get('matte', result.get('mask', None))) |
|
|
else: |
|
|
return result |
|
|
|
|
|
def _tensor_to_mask(tensor, target_h, target_w): |
|
|
"""Convert tensor to numpy mask with proper sizing.""" |
|
|
if isinstance(tensor, torch.Tensor): |
|
|
mask = tensor.squeeze().cpu().numpy() |
|
|
else: |
|
|
mask = tensor |
|
|
|
|
|
if mask.ndim == 3: |
|
|
mask = mask[0] if mask.shape[0] == 1 else mask.mean(axis=0) |
|
|
|
|
|
if mask.dtype != np.uint8: |
|
|
mask = (mask * 255).clip(0, 255).astype(np.uint8) |
|
|
|
|
|
if mask.shape != (target_h, target_w): |
|
|
mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
return mask |
|
|
|
|
|
def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool: |
|
|
"""Check if refined mask is reasonable.""" |
|
|
if refined is None or refined.size == 0: |
|
|
return False |
|
|
|
|
|
refined_area = np.sum(refined > 127) |
|
|
original_area = np.sum(original > 127) |
|
|
|
|
|
if refined_area == 0: |
|
|
return False |
|
|
|
|
|
ratio = refined_area / max(original_area, 1) |
|
|
return 0.5 <= ratio <= 2.0 |
|
|
|
|
|
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Convert any mask format to binary 0/255.

    Handling by dtype:
      * bool: True -> 255, False -> 0.
      * floating point: scaled by 255 when the maximum is <= 1.0, then
        clipped to [0, 255] so out-of-range values cannot wrap around.
      * other non-uint8 types: cast to uint8 as-is.

    A 3D mask with a single trailing channel is squeezed; any other 3D
    mask is converted with ``cv2.COLOR_BGR2GRAY``.

    Args:
        mask: 2D mask, or 3D mask with the channel axis last.

    Returns:
        uint8 array where pixels strictly greater than 127 are 255 and all
        others are 0.
    """
    if mask.dtype == np.bool_:
        # A plain uint8 cast would give 0/1, which thresholds to all zeros.
        mask = mask.astype(np.uint8) * 255
    elif np.issubdtype(mask.dtype, np.floating):
        if mask.max() <= 1.0:
            mask = mask * 255
        mask = np.clip(mask, 0, 255).astype(np.uint8)
    elif mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    if mask.ndim == 3:
        if mask.shape[2] == 1:
            # BGR2GRAY rejects single-channel input; just drop the axis.
            mask = np.ascontiguousarray(mask[:, :, 0])
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    return binary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Refine a mask with a fixed pipeline of classical OpenCV operations.

    Steps, in order: morphological close then open with a 5x5 elliptical
    kernel, edge-aware smoothing guided by ``image``, edge feathering
    (radius 3), and removal of connected components smaller than 0.5% of
    the image area. The input mask is not modified.
    """
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    # Close fills small holes; open then removes small specks.
    closed = cv2.morphologyEx(mask.copy(), cv2.MORPH_CLOSE, ellipse)
    opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, ellipse)

    smoothed = _edge_aware_smooth(image, opened)
    feathered = _feather_edges(smoothed, radius=3)
    return _remove_small_components(feathered, min_area_ratio=0.005)
|
|
|
|
|
def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Smooth ``mask`` with a guided filter driven by the grayscale image.

    Both the guide (grayscale of ``image``) and the mask are scaled to
    float32 in [0, 1]. Local statistics use a 5x5 box filter with
    regularisation ``eps = 0.01``. The filtered result is rescaled to
    [0, 255], clipped and returned as uint8 (not thresholded).
    """
    window = (5, 5)
    eps = 0.01

    def box_mean(values):
        # Normalised box filter, output depth same as input.
        return cv2.boxFilter(values, -1, window)

    target = mask.astype(np.float32) / 255.0
    guide = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    mean_guide = box_mean(guide)
    mean_target = box_mean(target)

    # Local covariance of guide/target and local variance of the guide.
    covariance = box_mean(guide * target) - mean_guide * mean_target
    variance = box_mean(guide * guide) - mean_guide * mean_guide

    # Per-window linear model: target ~= slope * guide + offset.
    slope = covariance / (variance + eps)
    offset = mean_target - slope * mean_guide

    filtered = box_mean(slope) * guide + box_mean(offset)
    return (filtered * 255).clip(0, 255).astype(np.uint8)
|
|
|
|
|
def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray: |
|
|
"""Slightly blur edges for smoother transitions.""" |
|
|
if radius <= 0: |
|
|
return mask |
|
|
|
|
|
blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2) |
|
|
_, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY) |
|
|
return binary |
|
|
|
|
|
def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
    """Remove small disconnected components.

    Components are found with 8-connectivity. A component is kept when its
    pixel area is at least ``min_area_ratio`` of the full image area; the
    largest component is always kept regardless of size.

    Args:
        mask: Mask accepted by ``cv2.connectedComponentsWithStats``
            (non-zero pixels are foreground).
        min_area_ratio: Minimum component area as a fraction of H * W.

    Returns:
        Array like ``mask`` with kept components set to 255 and everything
        else 0. When the mask has no foreground component, ``mask`` itself
        is returned.
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # Label 0 is the background, so <= 1 means there is no foreground.
    if num_labels <= 1:
        return mask

    min_area = int(mask.shape[0] * mask.shape[1] * min_area_ratio)
    areas = stats[1:, cv2.CC_STAT_AREA]

    # Lookup table indexed by label: True for components to keep.
    keep = np.zeros(num_labels, dtype=bool)
    keep[1:] = areas >= min_area
    keep[int(np.argmax(areas)) + 1] = True  # always keep the largest

    # One vectorised pass instead of rescanning the image per label.
    cleaned = np.zeros_like(mask)
    cleaned[keep[labels]] = 255
    return cleaned