# MogensR's picture
# Update utils/refinement.py
# e94d263
#!/usr/bin/env python3
"""
utils.refinement
High-quality mask refinement for BackgroundFX Pro.
"""
from __future__ import annotations
from typing import Any, Optional, Tuple, List
import logging
import cv2
import numpy as np
import torch
log = logging.getLogger(__name__)
# ============================================================================
# CUSTOM EXCEPTION
# ============================================================================
class MaskRefinementError(Exception):
    """Raised when mask refinement cannot produce a valid result."""
# ============================================================================
# EXPORTS
# ============================================================================
# Public API surface exported via ``from utils.refinement import *``.
__all__ = [
    "refine_mask_hq",
    "refine_masks_batch",
    "MaskRefinementError",
]
# ============================================================================
# MAIN API - SINGLE FRAME
# ============================================================================
def refine_mask_hq(
    image: np.ndarray,
    mask: np.ndarray,
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> np.ndarray:
    """
    High-quality mask refinement with multiple strategies.

    Tries AI-based (MatAnyone) refinement first when a model is supplied and
    the result passes validation; otherwise falls back to classical CV
    refinement, and finally to returning the input mask unchanged.

    Args:
        image: Original BGR image.
        mask: Initial binary mask (0/255).
        matanyone_model: Optional MatAnyone model for AI refinement.
        fallback_enabled: Whether to use fallback methods if AI fails.

    Returns:
        Refined binary mask (0/255).

    Raises:
        MaskRefinementError: If an input is missing or shapes disagree.
    """
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")
    if image.shape[:2] != mask.shape[:2]:
        raise MaskRefinementError(
            f"Image shape {image.shape[:2]} doesn't match mask shape {mask.shape[:2]}"
        )

    # Try AI-based refinement first if model available.
    if matanyone_model is not None:
        try:
            refined = _refine_with_matanyone(image, mask, matanyone_model)
            if _validate_refined_mask(refined, mask):
                return refined
            log.warning("MatAnyone refinement failed validation")
        except Exception as e:
            # Lazy %-args: the message is only built if the record is emitted.
            log.warning("MatAnyone refinement failed: %s", e)

    # Fallback to classical refinement methods.
    if fallback_enabled:
        try:
            return _classical_refinement(image, mask)
        except Exception as e:
            log.warning("Classical refinement failed: %s", e)
            return mask  # Return original if all fails

    return mask
# ============================================================================
# BATCH PROCESSING FOR TEMPORAL CONSISTENCY
# ============================================================================
def refine_masks_batch(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> List[np.ndarray]:
    """
    Refine multiple masks using MatAnyone's temporal consistency.

    Args:
        frames: List of BGR images.
        masks: List of initial binary masks.
        matanyone_model: MatAnyone InferenceCore model (or None).
        fallback_enabled: Whether to use frame-by-frame classical fallback.

    Returns:
        List of refined binary masks (the input list if nothing applies).

    Raises:
        MaskRefinementError: If frame and mask counts differ.
    """
    if not frames or not masks:
        return masks
    if len(frames) != len(masks):
        raise MaskRefinementError(
            f"Frame count {len(frames)} doesn't match mask count {len(masks)}"
        )

    if matanyone_model is not None:
        try:
            refined = _refine_batch_with_matanyone(frames, masks, matanyone_model)
            # Every refined mask must individually pass validation.
            if all(_validate_refined_mask(r, m) for r, m in zip(refined, masks)):
                return refined
            log.warning("Batch MatAnyone refinement failed validation")
        except Exception as e:
            # Lazy %-args: the message is only built if the record is emitted.
            log.warning("Batch MatAnyone refinement failed: %s", e)

    # Fallback to frame-by-frame classical refinement.
    if fallback_enabled:
        return [_classical_refinement(f, m) for f, m in zip(frames, masks)]
    return masks
# ============================================================================
# AI-BASED REFINEMENT - SINGLE FRAME
# ============================================================================
def _refine_with_matanyone(
    image: np.ndarray,
    mask: np.ndarray,
    model: Any
) -> np.ndarray:
    """
    Use a MatAnyone model to refine a single mask.

    Args:
        image: BGR image matching the mask's spatial size.
        mask: Initial mask; multi-channel and non-uint8 inputs are normalized.
        model: MatAnyone InferenceCore (or any object exposing ``step`` /
            ``process_frame`` / ``forward`` / ``__call__``).

    Returns:
        Binary 0/255 mask the same size as ``image``.

    Raises:
        MaskRefinementError: On invalid mask geometry, unrecognized model
            interface, or any underlying processing failure.
    """
    try:
        # NOTE(review): hard-coded to the first CUDA device when available.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Convert BGR to RGB, then to a (1, C, H, W) float tensor in [0, 1].
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image_rgb.shape[:2]
        image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
        image_tensor = image_tensor.unsqueeze(0).to(device)

        # Collapse any multi-channel mask to a single 2D channel.
        if mask.ndim == 3:
            if mask.shape[2] == 3:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            else:
                mask = mask[:, :, 0]
        # Normalize dtype: scale [0, 1] floats up, otherwise cast directly.
        if mask.dtype != np.uint8:
            mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)

        # Explicit raises instead of assert: asserts vanish under ``python -O``.
        if mask.ndim != 2:
            raise MaskRefinementError(f"Mask must be 2D after conversion, got shape {mask.shape}")
        if mask.shape != (h, w):
            raise MaskRefinementError(f"Mask shape {mask.shape} doesn't match image shape ({h}, {w})")

        # Mask as a (1, 1, H, W) float tensor in [0, 1] on the same device.
        mask_tensor = torch.from_numpy(mask).float() / 255.0
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
        if mask_tensor.shape != (1, 1, h, w):
            raise MaskRefinementError(
                f"Mask tensor wrong shape: {mask_tensor.shape}, expected (1, 1, {h}, {w})"
            )

        # Log available methods for debugging (lazy %-formatting).
        methods = [m for m in dir(model) if not m.startswith('_')]
        log.debug("MatAnyone InferenceCore methods: %s", methods)

        # Probe the model for a recognized entry point, preferring ``step``.
        result = None
        with torch.no_grad():
            if hasattr(model, 'step'):
                result = model.step(image_tensor, mask_tensor)
            elif hasattr(model, 'process_frame'):
                result = model.process_frame(image_tensor, mask_tensor)
            elif hasattr(model, 'forward'):
                result = model.forward(image_tensor, mask_tensor)
            elif callable(model):
                result = model(image_tensor, mask_tensor)
            else:
                raise MaskRefinementError(f"No recognized method. Available: {methods}")

        if result is None:
            raise MaskRefinementError("MatAnyone returned None")

        # Extract the alpha matte and normalize it back to a 2D uint8 image.
        alpha = _extract_alpha_from_result(result)
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.squeeze().cpu().numpy()
        if alpha.ndim == 3:
            alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0)
        if alpha.dtype != np.uint8:
            alpha = (alpha * 255).clip(0, 255).astype(np.uint8)
        if alpha.shape != (h, w):
            alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)

        return _process_mask(alpha)
    except MaskRefinementError:
        # Already our error type: don't double-wrap the message.
        raise
    except Exception as e:
        log.error("MatAnyone processing error: %s", e)
        # Chain the cause so the original traceback is preserved.
        raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}") from e
# ============================================================================
# AI-BASED REFINEMENT - BATCH
# ============================================================================
def _sanitize_mask_2d(mask: np.ndarray) -> np.ndarray:
    """Collapse a mask to single-channel and normalize its dtype to uint8 0-255."""
    if mask.ndim == 3:
        if mask.shape[2] == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        else:
            mask = mask[:, :, 0]
    if mask.dtype != np.uint8:
        # Scale [0, 1] masks up; anything else is cast directly.
        mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
    return mask


def _refine_batch_with_matanyone(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    model: Any
) -> List[np.ndarray]:
    """
    Process a batch of frames through MatAnyone for temporal consistency.

    Args:
        frames: BGR images, all assumed the same size as ``frames[0]``.
        masks: Initial masks, one per frame.
        model: MatAnyone model; batch, sequential-``step``, and plain
            callable interfaces are supported in that order of preference.

    Returns:
        One refined uint8 mask per input frame.

    Raises:
        MaskRefinementError: On invalid mask geometry or processing failure.
    """
    try:
        # NOTE(review): hard-coded to the first CUDA device when available.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        h, w = frames[0].shape[:2]

        # Convert each frame to a (C, H, W) float tensor in [0, 1].
        frame_tensors = []
        for frame in frames:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_tensors.append(
                torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
            )
        # Stacked batch (N, C, H, W); only the process_batch path consumes it.
        batch_tensor = torch.stack(frame_tensors).to(device)

        # First mask initializes the model's memory; normalize it first.
        first_mask = _sanitize_mask_2d(masks[0])
        if first_mask.ndim != 2:
            # Explicit raise instead of assert: asserts vanish under ``python -O``.
            raise MaskRefinementError(f"First mask must be 2D, got shape {first_mask.shape}")
        first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
        first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0).to(device)

        refined_masks: List[np.ndarray] = []
        with torch.no_grad():
            if hasattr(model, 'process_batch'):
                # Direct batch processing.
                results = model.process_batch(batch_tensor, first_mask_tensor)
                for result in results:
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))
            elif hasattr(model, 'step'):
                # Sequential processing: only the first frame gets a mask,
                # later frames rely on the model's internal memory.
                for i, frame_tensor in enumerate(frame_tensors):
                    frame_on_device = frame_tensor.unsqueeze(0).to(device)
                    if i == 0:
                        result = model.step(frame_on_device, first_mask_tensor)
                    else:
                        result = model.step(frame_on_device, None)
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))
            else:
                # Fallback: process each frame independently with its own mask.
                log.warning("MatAnyone batch processing not available, using frame-by-frame")
                for frame_tensor, mask in zip(frame_tensors, masks):
                    # BUG FIX: also normalize dtype (was only collapsed to 2D
                    # before, so float 0-1 masks were divided by 255 twice).
                    mask = _sanitize_mask_2d(mask)
                    mask_tensor = torch.from_numpy(mask).float() / 255.0
                    mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
                    frame_on_device = frame_tensor.unsqueeze(0).to(device)
                    result = model(frame_on_device, mask_tensor)
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))
        return refined_masks
    except MaskRefinementError:
        # Already our error type: don't double-wrap the message.
        raise
    except Exception as e:
        log.error("Batch MatAnyone processing error: %s", e)
        # Chain the cause so the original traceback is preserved.
        raise MaskRefinementError(f"Batch processing failed: {str(e)}") from e
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def _extract_alpha_from_result(result):
"""Extract alpha matte from various result formats."""
if isinstance(result, (tuple, list)):
return result[0] if len(result) > 0 else None
elif isinstance(result, dict):
return result.get('alpha', result.get('matte', result.get('mask', None)))
else:
return result
def _tensor_to_mask(tensor, target_h, target_w):
"""Convert tensor to numpy mask with proper sizing."""
if isinstance(tensor, torch.Tensor):
mask = tensor.squeeze().cpu().numpy()
else:
mask = tensor
if mask.ndim == 3:
mask = mask[0] if mask.shape[0] == 1 else mask.mean(axis=0)
if mask.dtype != np.uint8:
mask = (mask * 255).clip(0, 255).astype(np.uint8)
if mask.shape != (target_h, target_w):
mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
return mask
def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool:
"""Check if refined mask is reasonable."""
if refined is None or refined.size == 0:
return False
refined_area = np.sum(refined > 127)
original_area = np.sum(original > 127)
if refined_area == 0:
return False
ratio = refined_area / max(original_area, 1)
return 0.5 <= ratio <= 2.0
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Normalize any mask format to a strict binary 0/255 uint8 image."""
    # Floats already in [0, 1] are rescaled before the uint8 cast below.
    if mask.dtype in (np.float32, np.float64) and mask.max() <= 1.0:
        mask = (mask * 255).astype(np.uint8)
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Hard threshold at 127 yields the final binary mask.
    return cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
# ============================================================================
# CLASSICAL REFINEMENT
# ============================================================================
def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Refine a mask with classical CV only.

    Pipeline: morphological close + open, edge-aware (guided) smoothing,
    edge feathering, then small-component removal.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    # Close small holes, then open to drop speckle noise.
    result = cv2.morphologyEx(mask.copy(), cv2.MORPH_CLOSE, kernel)
    result = cv2.morphologyEx(result, cv2.MORPH_OPEN, kernel)
    # Align mask edges with image edges, soften, then prune tiny blobs.
    result = _edge_aware_smooth(image, result)
    result = _feather_edges(result, radius=3)
    return _remove_small_components(result, min_area_ratio=0.005)
def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Smooth the mask with a guided filter, using the image as the guide.

    Classic single-channel guided filter (box window radius 5, eps 0.01);
    returns a uint8 grayscale alpha aligned to the image's edges.
    """
    radius, eps = 5, 0.01
    window = (radius, radius)

    guide = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
    src = mask.astype(np.float32) / 255.0

    # Local means of guide, source, their product, and the squared guide.
    mu_g = cv2.boxFilter(guide, -1, window)
    mu_s = cv2.boxFilter(src, -1, window)
    mu_gs = cv2.boxFilter(guide * src, -1, window)
    mu_gg = cv2.boxFilter(guide * guide, -1, window)

    # Per-window linear model coefficients a, b (eps regularizes flat areas).
    cov_gs = mu_gs - mu_g * mu_s
    var_g = mu_gg - mu_g * mu_g
    a = cov_gs / (var_g + eps)
    b = mu_s - a * mu_g

    # Average the coefficients, then reconstruct the filtered output.
    out = cv2.boxFilter(a, -1, window) * guide + cv2.boxFilter(b, -1, window)
    return (out * 255).clip(0, 255).astype(np.uint8)
def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
"""Slightly blur edges for smoother transitions."""
if radius <= 0:
return mask
blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2)
_, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)
return binary
def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
    """Drop tiny disconnected blobs, always keeping the largest component."""
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    if num_labels <= 1:
        return mask  # background only: nothing to prune
    component_areas = stats[1:, cv2.CC_STAT_AREA]
    if component_areas.size == 0:
        return mask
    min_area = int(mask.shape[0] * mask.shape[1] * min_area_ratio)
    largest = int(np.argmax(component_areas)) + 1  # +1 skips background label 0
    keep = np.zeros_like(mask)
    for lbl in range(1, num_labels):
        # Keep components meeting the area floor, and the largest regardless.
        if lbl == largest or stats[lbl, cv2.CC_STAT_AREA] >= min_area:
            keep[labels == lbl] = 255
    return keep