# processing/matting.py
# Created by MogensR in commit e5e6fe5 ("Create processing/matting.py"); 14.6 kB.
"""
Advanced matting algorithms for BackgroundFX Pro.
Implements multiple matting techniques with automatic fallback.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from typing import Dict, Tuple, Optional, List
from dataclasses import dataclass
import logging
from ..utils.logger import setup_logger
from ..utils.device import DeviceManager
from ..utils.config import ConfigManager
from ..core.models import ModelFactory, ModelType
from ..core.quality import QualityAnalyzer
from ..core.edge import EdgeRefinement
# Module-level logger built with the project's ``setup_logger`` helper
# (imported from ``..utils.logger``); AlphaMatting.process reports the chosen
# method and any failure through it.
logger = setup_logger(__name__)
@dataclass
class MattingConfig:
    """Configuration for matting operations.

    Plain dataclass of tuning knobs consumed by ``AlphaMatting``. Field order
    is part of the generated positional constructor, so do not reorder.
    """
    # Alpha cut-off; presumably on a [0, 1] scale -- confirm at call sites.
    alpha_threshold: float = 0.5
    # Iteration count of the MORPH_CLOSE pass in
    # AlphaMatting.morphological_refinement (despite the name, it is not a
    # plain erosion there).
    erode_iterations: int = 2
    # Iteration count of the MORPH_OPEN pass in
    # AlphaMatting.morphological_refinement.
    dilate_iterations: int = 2
    # Gaussian blur kernel is (2 * blur_radius + 1) square; values <= 0
    # disable the blur.
    blur_radius: int = 3
    # Default structuring-element size used by AlphaMatting.create_trimap to
    # size the unknown band when no explicit dilation_size is given.
    trimap_size: int = 30
    # Not referenced by AlphaMatting or CompositingEngine; presumably
    # consumed elsewhere -- verify.
    confidence_threshold: float = 0.7
    # Enables guided-filter smoothing in AlphaMatting's trimap and 'guided'
    # matting paths.
    use_guided_filter: bool = True
    # Defaults for AlphaMatting.guided_filter's radius and regularization eps.
    guided_filter_radius: int = 8
    guided_filter_eps: float = 1e-6
    # Not referenced by AlphaMatting or CompositingEngine; presumably for
    # temporal (video) smoothing handled elsewhere -- verify.
    use_temporal_smoothing: bool = False
    temporal_window: int = 5
class AlphaMatting:
    """Advanced alpha matting using multiple techniques.

    Combines distance-based trimap propagation, guided filtering, an optional
    deep-learning refinement and morphological/edge clean-up. Masks passed to
    ``process`` are expected as uint8 on a 0-255 scale; alpha mattes are
    returned as floats in [0, 1].
    """

    def __init__(self, config: Optional["MattingConfig"] = None):
        """Create a matting engine.

        Args:
            config: Matting parameters; a default ``MattingConfig`` is used
                when omitted.
        """
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()

    @staticmethod
    def _to_uint8(alpha: np.ndarray) -> np.ndarray:
        """Convert a [0, 1] float matte to uint8 on a 0-255 scale.

        Rounds instead of truncating: truncation biases every float -> uint8
        round trip downward (e.g. ``mask / 255 * 255`` can become 254). Values
        are clipped first so anything slightly above 1 cannot wrap around.
        """
        scaled = np.rint(np.asarray(alpha, dtype=np.float32) * 255.0)
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _binarize_mask(self, mask: np.ndarray) -> np.ndarray:
        """Threshold a mask into a uint8 {0, 255} binary mask.

        uint8 masks are read on a 0-255 scale; any other dtype (bool, float)
        is read on a [0, 1] scale. Pixels at or above
        ``config.alpha_threshold`` become foreground (255).
        """
        mask = np.asarray(mask)
        if mask.dtype == np.uint8:
            unit = mask.astype(np.float32) / 255.0
        else:
            unit = mask.astype(np.float32)
        return np.where(unit >= self.config.alpha_threshold, 255, 0).astype(np.uint8)

    def create_trimap(self, mask: np.ndarray,
                      dilation_size: int = None) -> np.ndarray:
        """
        Create trimap from a (roughly) binary mask.

        The mask is binarized at ``config.alpha_threshold`` first, so soft
        masks still produce a trimap containing only 0, 128 and 255.

        Args:
            mask: Mask (H, W); uint8 on a 0-255 scale, or bool/float on a
                [0, 1] scale.
            dilation_size: Size of the elliptical structuring element that
                sets the width of the unknown band. Falls back to
                ``config.trimap_size`` when None or 0.

        Returns:
            uint8 trimap with 0 (background), 128 (unknown), 255 (foreground)
        """
        dilation_size = dilation_size or self.config.trimap_size
        binary = self._binarize_mask(mask)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (dilation_size, dilation_size)
        )
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(binary, kernel, iterations=1)
        trimap = np.zeros_like(binary)
        # Everything the dilated mask reaches is at least "unknown" ...
        trimap[dilated > 0] = 128
        # ... and only the eroded core is certain foreground.
        trimap[eroded > 0] = 255
        return trimap

    def guided_filter(self, image: np.ndarray,
                      guide: np.ndarray,
                      radius: int = None,
                      eps: float = None) -> np.ndarray:
        """
        Apply a guided filter (He et al.) for edge-preserving smoothing.

        Args:
            image: Input to filter, uint8 on a 0-255 scale.
            guide: Guide image; 3-channel guides are converted to grayscale
                with ``cv2.COLOR_BGR2GRAY``.
            radius: Window radius; each local window is
                (2 * radius + 1) x (2 * radius + 1). Defaults to
                ``config.guided_filter_radius`` when None or 0.
            eps: Regularization parameter; defaults to
                ``config.guided_filter_eps`` when None or 0.

        Returns:
            Filtered image as uint8.
        """
        radius = radius or self.config.guided_filter_radius
        eps = eps or self.config.guided_filter_eps
        if guide.ndim == 3:
            guide = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)

        guide_f = guide.astype(np.float32) / 255.0
        src = image.astype(np.float32) / 255.0

        # Odd, centred window. An even kernel size puts the anchor off-centre
        # and shifts the output by a fraction of a pixel on every pass.
        ksize = (2 * radius + 1, 2 * radius + 1)

        def box_filter(img):
            return cv2.boxFilter(img, -1, ksize)

        mean_I = box_filter(guide_f)
        mean_p = box_filter(src)
        cov_Ip = box_filter(guide_f * src) - mean_I * mean_p
        var_I = box_filter(guide_f * guide_f) - mean_I * mean_I

        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I
        output = box_filter(a) * guide_f + box_filter(b)
        return self._to_uint8(output)

    def closed_form_matting(self, image: np.ndarray,
                            trimap: np.ndarray) -> np.ndarray:
        """
        Estimate alpha inside the trimap's unknown band.

        Despite the name, this is not Levin et al.'s Laplacian solver. It is
        a fast approximation: each unknown pixel gets
        ``d_bg / (d_fg + d_bg)``, where ``d_fg``/``d_bg`` are its distances
        to the nearest known foreground/background pixel. The result is then
        optionally smoothed with the guided filter.

        Args:
            image: Guide image passed to ``guided_filter``.
            trimap: uint8 trimap (0 = background, 128 = unknown,
                255 = foreground).

        Returns:
            Alpha matte in [0, 1]. If the trimap has no unknown pixels, the
            trimap divided by 255 is returned unchanged.
        """
        alpha = trimap.astype(np.float32) / 255.0
        is_fg = trimap == 255
        is_bg = trimap == 0
        is_unknown = trimap == 128
        if not np.any(is_unknown):
            return alpha

        has_fg = bool(np.any(is_fg))
        has_bg = bool(np.any(is_bg))
        if has_fg and has_bg:
            # cv2.distanceTransform measures the distance to the nearest ZERO
            # pixel, so the known region must be the zeros of its input.
            dist_to_fg = cv2.distanceTransform(
                (~is_fg).astype(np.uint8), cv2.DIST_L2, 5
            )
            dist_to_bg = cv2.distanceTransform(
                (~is_bg).astype(np.uint8), cv2.DIST_L2, 5
            )
            estimate = dist_to_bg / (dist_to_fg + dist_to_bg + 1e-10)
            alpha[is_unknown] = estimate[is_unknown]
        elif has_fg:
            alpha[is_unknown] = 1.0
        elif has_bg:
            alpha[is_unknown] = 0.0
        # With no known pixels at all, the unknown band keeps 128 / 255.

        if self.config.use_guided_filter:
            alpha = self.guided_filter(self._to_uint8(alpha), image) / 255.0
        return np.clip(alpha, 0, 1)

    def deep_matting(self, image: np.ndarray,
                     mask: np.ndarray,
                     model: Optional[nn.Module] = None) -> np.ndarray:
        """
        Apply deep learning-based matting refinement.

        Args:
            image: 3-channel uint8 image (H, W, 3).
            mask: Initial mask, uint8 on a 0-255 scale.
            model: Optional module called as ``model(image, mask)`` on
                (1, 3, 512, 512) / (1, 1, 512, 512) tensors in [0, 1]. When
                omitted, ``_simple_refine_network`` is used.

        Returns:
            Refined alpha matte in [0, 1], resized back to (H, W).
        """
        device = self.device_manager.get_device()
        h, w = image.shape[:2]

        input_size = (512, 512)
        image_resized = cv2.resize(image, input_size)
        mask_resized = cv2.resize(mask, input_size)

        image_tensor = (
            torch.from_numpy(image_resized.transpose(2, 0, 1))
            .float().unsqueeze(0) / 255.0
        ).to(device)
        mask_tensor = (
            torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
        ).to(device)

        with torch.no_grad():
            if model is None:
                x = torch.cat([image_tensor, mask_tensor], dim=1)
                refined = self._simple_refine_network(x)
            else:
                refined = model(image_tensor, mask_tensor)

        # cv2.resize rejects float16, so normalise the model output dtype.
        alpha = refined.squeeze().float().cpu().numpy()
        alpha = cv2.resize(alpha, (w, h))
        return np.clip(alpha, 0, 1)

    def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
        """Simple refinement stand-in used when no model is supplied.

        Takes channel 3 of ``x`` (the mask), smooths it with three 3x3 average
        pools, and re-sharpens it with a steep sigmoid centred on 0.5.
        Returns a (N, 1, H, W) tensor.
        """
        refined = x[:, 3:4, :, :]
        for _ in range(3):
            refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
        return torch.sigmoid((refined - 0.5) * 10)

    def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
        """
        Apply morphological operations for refinement.

        Args:
            alpha: Alpha matte in [0, 1]

        Returns:
            Refined alpha matte in [0, 1] (float32)
        """
        alpha_uint8 = self._to_uint8(alpha)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

        # NOTE: ``erode_iterations`` drives the CLOSE pass and
        # ``dilate_iterations`` the OPEN pass; kept as-is for compatibility.
        # Close first to fill small holes ...
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_CLOSE, kernel,
            iterations=self.config.erode_iterations
        )
        # ... then open to remove small isolated components.
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_OPEN, kernel,
            iterations=self.config.dilate_iterations
        )

        if self.config.blur_radius > 0:
            ksize = self.config.blur_radius * 2 + 1
            alpha_uint8 = cv2.GaussianBlur(alpha_uint8, (ksize, ksize), 0)

        return alpha_uint8.astype(np.float32) / 255.0

    def process(self, image: np.ndarray,
                mask: np.ndarray,
                method: str = 'auto') -> Dict[str, np.ndarray]:
        """
        Process image with selected matting method.

        Args:
            image: Input image
            mask: Initial segmentation mask, uint8 on a 0-255 scale
            method: Matting method ('auto', 'trimap', 'deep', 'guided');
                any other value falls back to the raw mask.

        Returns:
            Dictionary with 'alpha', 'confidence', 'method_used' and
            'quality_metrics'. On failure, 'method_used' is 'fallback', the
            raw mask is returned as alpha, and 'error' holds the message.
        """
        quality_metrics = {}
        try:
            quality_metrics = self.quality_analyzer.analyze_frame(image)

            # Select method based on quality (metric semantics come from
            # QualityAnalyzer -- assumed keys 'blur_score'/'edge_clarity').
            if method == 'auto':
                if quality_metrics['blur_score'] > 50:
                    method = 'guided'
                elif quality_metrics['edge_clarity'] > 0.7:
                    method = 'trimap'
                else:
                    method = 'deep'
            logger.info(f"Using matting method: {method}")

            if method == 'trimap':
                trimap = self.create_trimap(mask)
                alpha = self.closed_form_matting(image, trimap)
            elif method == 'deep':
                alpha = self.deep_matting(image, mask)
            elif method == 'guided':
                alpha = mask.astype(np.float32) / 255.0
                if self.config.use_guided_filter:
                    alpha = self.guided_filter(self._to_uint8(alpha), image) / 255.0
            else:
                # Default to simple refinement
                alpha = mask.astype(np.float32) / 255.0

            alpha = self.morphological_refinement(alpha)
            alpha = self.edge_refinement.refine_edges(
                image,
                self._to_uint8(alpha)
            ) / 255.0

            confidence = self._calculate_confidence(alpha, quality_metrics)
            return {
                'alpha': alpha,
                'confidence': confidence,
                'method_used': method,
                'quality_metrics': quality_metrics
            }
        except Exception as e:
            logger.error(f"Matting processing failed: {e}")
            # Return the original mask as fallback; keep the same keys as the
            # success path so callers can read 'quality_metrics' either way.
            return {
                'alpha': mask.astype(np.float32) / 255.0,
                'confidence': 0.0,
                'method_used': 'fallback',
                'error': str(e),
                'quality_metrics': quality_metrics
            }

    def _calculate_confidence(self, alpha: np.ndarray,
                              quality_metrics: Dict) -> float:
        """Calculate a [0, 1] confidence score for the matting result.

        Starts from ``quality_metrics['overall_quality']`` (0.5 if absent).
        It boosts the score by 1.2x when the matte is balanced with high
        spread, and by 1.1x when Canny edges cover under 10% of the pixels.
        """
        confidence = quality_metrics.get('overall_quality', 0.5)

        # Good matting should have clear separation
        if 0.3 < np.mean(alpha) < 0.7 and np.std(alpha) > 0.3:
            confidence *= 1.2

        edges = cv2.Canny(self._to_uint8(alpha), 50, 150)
        edge_ratio = np.sum(edges > 0) / edges.size
        if edge_ratio < 0.1:  # Clear boundaries
            confidence *= 1.1

        return float(np.clip(confidence, 0.0, 1.0))
class CompositingEngine:
    """Handle alpha compositing and blending."""

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.CompositingEngine")

    @staticmethod
    def _match_alpha_shape(alpha: np.ndarray, image_ndim: int) -> np.ndarray:
        """Return alpha as float32, shaped to broadcast against an image.

        A 2-D alpha gains a trailing axis for (H, W, C) images, so it applies
        to any channel count. An (H, W, 1) alpha is squeezed for (H, W)
        grayscale images. Other shapes are returned unchanged.
        """
        a = np.asarray(alpha, dtype=np.float32)
        if image_ndim == 3 and a.ndim == 2:
            a = a[..., np.newaxis]
        elif image_ndim == 2 and a.ndim == 3 and a.shape[2] == 1:
            a = a[..., 0]
        return a

    def composite(self, foreground: np.ndarray,
                  background: np.ndarray,
                  alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image, uint8 (H, W, C) or (H, W)
            background: Background image with the same shape as foreground
            alpha: Alpha matte (H, W), (H, W, 1) or (H, W, C). Values above
                1 are taken to be on a 0-255 scale.

        Returns:
            Composited uint8 image with the foreground's shape.
        """
        fg = foreground.astype(np.float32) / 255.0
        bg = background.astype(np.float32) / 255.0
        a = self._match_alpha_shape(alpha, fg.ndim)
        # Guard the scale heuristic: max() raises on an empty array.
        if a.size and a.max() > 1.0:
            a = a / 255.0

        result = fg * a + bg * (1 - a)
        return np.clip(result * 255, 0, 255).astype(np.uint8)

    def premultiply_alpha(self, image: np.ndarray,
                          alpha: np.ndarray) -> np.ndarray:
        """Premultiply image by alpha channel.

        Args:
            image: uint8 image (H, W, C) or (H, W)
            alpha: Alpha matte (H, W) or (H, W, 1); values above 1 are taken
                to be on a 0-255 scale.

        Returns:
            Premultiplied uint8 image with the input image's shape.
        """
        a = self._match_alpha_shape(alpha, image.ndim)
        result = image.astype(np.float32) * a
        if a.size and a.max() > 1.0:
            result = result / 255.0
        return np.clip(result, 0, 255).astype(np.uint8)