|
|
""" |
|
|
Advanced matting algorithms for BackgroundFX Pro. |
|
|
Implements multiple matting techniques with automatic fallback. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from typing import Dict, Tuple, Optional, List |
|
|
from dataclasses import dataclass |
|
|
import logging |
|
|
|
|
|
from ..utils.logger import setup_logger |
|
|
from ..utils.device import DeviceManager |
|
|
from ..utils.config import ConfigManager |
|
|
from ..core.models import ModelFactory, ModelType |
|
|
from ..core.quality import QualityAnalyzer |
|
|
from ..core.edge import EdgeRefinement |
|
|
|
|
|
# Module-level logger used by AlphaMatting.process; CompositingEngine creates
# its own named logger in __init__ instead.
logger = setup_logger(__name__)
|
|
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class MattingConfig:
    """Tunable parameters for the matting pipeline.

    Every field has a default, so ``MattingConfig()`` is a complete,
    usable configuration; override individual values by keyword.
    """

    # Alpha cut-off value (assumed to be on a 0-1 scale).
    alpha_threshold: float = 0.5
    # Iteration counts for the morphological refinement passes.
    erode_iterations: int = 2
    dilate_iterations: int = 2
    # Gaussian feathering radius in pixels; 0 disables the blur.
    blur_radius: int = 3
    # Default kernel size (pixels) of the trimap's uncertain band.
    trimap_size: int = 30
    # Confidence cut-off (presumably applied by callers of the matting API).
    confidence_threshold: float = 0.7

    # Guided-filter settings.
    use_guided_filter: bool = True
    guided_filter_radius: int = 8
    guided_filter_eps: float = 1e-06

    # Temporal smoothing settings (window measured in frames, presumably).
    use_temporal_smoothing: bool = False
    temporal_window: int = 5
|
|
|
|
|
|
|
|
class AlphaMatting:
    """Advanced alpha matting using multiple techniques.

    Offers a distance-interpolated trimap matte, guided filtering, a
    lightweight (or caller-supplied) refinement network and morphological
    clean-up. ``process`` picks a method, refines the result and falls back
    to the raw mask if any step raises.
    """

    def __init__(self, config: Optional[MattingConfig] = None):
        """
        Args:
            config: Matting parameters; ``MattingConfig()`` defaults are used
                when omitted.
        """
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()

    @staticmethod
    def _to_uint8(alpha: np.ndarray) -> np.ndarray:
        """Quantise a [0, 1] float matte to uint8, rounding to nearest.

        Clipping first prevents wrap-around for out-of-range values. Rounding
        rather than truncating avoids off-by-one losses when floating-point
        error leaves values such as 99.9999 that truncation would turn into 99.
        """
        return np.rint(np.clip(alpha, 0.0, 1.0) * 255.0).astype(np.uint8)

    def create_trimap(self, mask: np.ndarray,
                      dilation_size: int = None) -> np.ndarray:
        """
        Create a trimap from a segmentation mask.

        The mask is binarised at ``config.alpha_threshold`` before the
        morphology, so soft masks still yield a clean three-level trimap.

        Args:
            mask: Mask (H, W). uint8 masks are assumed to be 0-255; any other
                dtype (bool, float) is assumed to be on a 0-1 scale.
            dilation_size: Kernel size of the uncertain band; falls back to
                ``config.trimap_size`` when falsy.

        Returns:
            uint8 trimap with 0 (background), 128 (unknown), 255 (foreground).
        """
        dilation_size = dilation_size or self.config.trimap_size
        threshold = self.config.alpha_threshold

        if mask.dtype == np.uint8:
            foreground = mask > threshold * 255
        else:
            foreground = mask.astype(np.float32) > threshold
        binary = foreground.astype(np.uint8) * 255

        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (dilation_size, dilation_size)
        )
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(binary, kernel, iterations=1)

        # Everything within reach of the foreground is "unknown"; only the
        # eroded core is certain foreground. The rest stays background (0).
        trimap = np.zeros_like(binary)
        trimap[dilated == 255] = 128
        trimap[eroded == 255] = 255
        return trimap

    def guided_filter(self, image: np.ndarray,
                      guide: np.ndarray,
                      radius: int = None,
                      eps: float = None) -> np.ndarray:
        """
        Apply a guided filter for edge-preserving smoothing.

        Fits the local linear model ``q = a * I + b`` in every
        ``(2 * radius + 1)``-sized window of the grey guide ``I`` and applies
        the window-averaged coefficients.

        Args:
            image: Single-channel image with values 0-255 to filter (e.g. an
                alpha matte scaled to 0-255).
            guide: Guide image with values 0-255. 3-channel input is treated
                as BGR and 4-channel as BGRA; both are converted to grey.
            radius: Window radius; falls back to
                ``config.guided_filter_radius`` when falsy.
            eps: Regularisation term; falls back to
                ``config.guided_filter_eps`` when falsy, so 0 never causes a
                division by zero in flat regions.

        Returns:
            Filtered image as uint8 in [0, 255].
        """
        radius = radius or self.config.guided_filter_radius
        eps = eps or self.config.guided_filter_eps

        if guide.ndim == 3:
            channels = guide.shape[2]
            if channels == 1:
                guide = guide[:, :, 0]
            elif channels == 4:
                guide = cv2.cvtColor(guide, cv2.COLOR_BGRA2GRAY)
            else:
                guide = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)

        guide = guide.astype(np.float32) / 255.0
        image = image.astype(np.float32) / 255.0

        # An odd (2r + 1) window keeps the box centred on each pixel; an even
        # kernel size would shift the result by half a pixel.
        ksize = (2 * radius + 1, 2 * radius + 1)

        def box_filter(img: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(img, -1, ksize)

        mean_I = box_filter(guide)
        mean_p = box_filter(image)
        cov_Ip = box_filter(guide * image) - mean_I * mean_p
        var_I = box_filter(guide * guide) - mean_I * mean_I

        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I

        output = box_filter(a) * guide + box_filter(b)
        return np.clip(np.rint(output * 255.0), 0, 255).astype(np.uint8)

    def closed_form_matting(self, image: np.ndarray,
                            trimap: np.ndarray) -> np.ndarray:
        """
        Fast trimap-based alpha estimation.

        Despite the name, this does not solve the Laplacian system of
        closed-form matting. It interpolates alpha inside the unknown region
        from the distances to the nearest known foreground and background
        pixels, then optionally smooths the matte with the guided filter.

        Args:
            image: Colour image, used as the guided-filter guide.
            trimap: (H, W) trimap with 0 (background), 128 (unknown) and
                255 (foreground).

        Returns:
            Float alpha matte in [0, 1]. If the trimap has no unknown pixels,
            ``trimap / 255`` is returned directly without filtering.
        """
        alpha = trimap.astype(np.float32) / 255.0

        is_fg = trimap == 255
        is_bg = trimap == 0
        is_unknown = trimap == 128

        if not np.any(is_unknown):
            return alpha

        if not np.any(is_fg):
            alpha[is_unknown] = 0.0
        elif not np.any(is_bg):
            alpha[is_unknown] = 1.0
        else:
            # cv2.distanceTransform gives every non-zero pixel its distance to
            # the nearest zero pixel. Passing the complement of each known
            # region therefore yields "distance to nearest fg / bg pixel".
            dist_to_fg = cv2.distanceTransform(
                (~is_fg).astype(np.uint8), cv2.DIST_L2, 5
            )
            dist_to_bg = cv2.distanceTransform(
                (~is_bg).astype(np.uint8), cv2.DIST_L2, 5
            )
            # Pixels closer to the foreground receive higher alpha.
            ratio = dist_to_bg / (dist_to_fg + dist_to_bg + 1e-10)
            alpha[is_unknown] = ratio[is_unknown]

        if self.config.use_guided_filter:
            alpha = self.guided_filter(self._to_uint8(alpha), image) / 255.0

        return np.clip(alpha, 0, 1)

    def deep_matting(self, image: np.ndarray,
                     mask: np.ndarray,
                     model: Optional[nn.Module] = None) -> np.ndarray:
        """
        Refine a mask with a neural network.

        The image and mask are resized to 512x512 and moved to the device
        reported by ``DeviceManager``. They are run through ``model``, or
        through ``_simple_refine_network`` when no model is given, and the
        output is resized back.

        Args:
            image: (H, W, 3) colour image, assumed 0-255. The HWC layout is
                required by the channel transpose.
            mask: (H, W) mask, assumed 0-255 (it is divided by 255).
            model: Optional module called as ``model(image, mask)`` with
                (1, 3, 512, 512) and (1, 1, 512, 512) float tensors in [0, 1].
                Its output presumably squeezes to (512, 512).

        Returns:
            Float alpha matte resized to (H, W) and clipped to [0, 1].
        """
        device = self.device_manager.get_device()
        h, w = image.shape[:2]

        # Networks run at a fixed resolution; the result is resized back.
        input_size = (512, 512)
        image_resized = cv2.resize(image, input_size)
        mask_resized = cv2.resize(mask, input_size)

        # HWC -> NCHW, scaled to [0, 1].
        image_tensor = torch.from_numpy(
            image_resized.transpose(2, 0, 1)
        ).float().unsqueeze(0) / 255.0
        mask_tensor = torch.from_numpy(
            mask_resized
        ).float().unsqueeze(0).unsqueeze(0) / 255.0

        image_tensor = image_tensor.to(device)
        mask_tensor = mask_tensor.to(device)

        with torch.no_grad():
            if model is None:
                # The built-in refiner expects image and mask stacked as
                # four channels.
                x = torch.cat([image_tensor, mask_tensor], dim=1)
                refined = self._simple_refine_network(x)
            else:
                refined = model(image_tensor, mask_tensor)
            alpha = refined.squeeze().cpu().numpy()

        alpha = cv2.resize(alpha, (w, h))
        return np.clip(alpha, 0, 1)

    def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
        """Placeholder refinement used when no model is supplied.

        Takes channel 3 (the mask) and applies three rounds of 3x3 average
        pooling, each followed by a steep sigmoid centred on 0.5.
        """
        mask = x[:, 3:4, :, :]

        refined = mask
        for _ in range(3):
            refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
            refined = torch.sigmoid((refined - 0.5) * 10)

        return refined

    def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
        """
        Clean up a matte with morphological closing/opening and a blur.

        Args:
            alpha: Float alpha matte in [0, 1]; out-of-range values are clipped.

        Returns:
            Refined float32 alpha matte in [0, 1].
        """
        alpha_uint8 = self._to_uint8(alpha)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

        # NOTE: ``erode_iterations`` drives the closing (fills small holes)
        # and ``dilate_iterations`` drives the opening (removes specks). The
        # names are historical and kept to preserve existing tuning.
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_CLOSE, kernel,
            iterations=self.config.erode_iterations
        )
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_OPEN, kernel,
            iterations=self.config.dilate_iterations
        )

        if self.config.blur_radius > 0:
            ksize = self.config.blur_radius * 2 + 1
            alpha_uint8 = cv2.GaussianBlur(alpha_uint8, (ksize, ksize), 0)

        return alpha_uint8.astype(np.float32) / 255.0

    def process(self, image: np.ndarray,
                mask: np.ndarray,
                method: str = 'auto') -> Dict[str, np.ndarray]:
        """
        Process an image with the selected matting method.

        Args:
            image: Colour image.
            mask: Initial segmentation mask, assumed uint8 0-255 (it is
                divided by 255 on the 'guided' and fallback paths).
            method: 'auto', 'trimap', 'deep' or 'guided'. 'auto' chooses
                from the frame's quality metrics. Any other value uses the
                mask as-is before refinement.

        Returns:
            Dict with 'alpha' (float matte), 'confidence', 'method_used' and
            'quality_metrics'. If any step raises, the mask is returned as
            alpha with confidence 0.0, 'method_used' is set to 'fallback',
            and the exception message is stored under 'error'.
        """
        try:
            quality_metrics = self.quality_analyzer.analyze_frame(image)

            if method == 'auto':
                if quality_metrics['blur_score'] > 50:
                    method = 'guided'
                elif quality_metrics['edge_clarity'] > 0.7:
                    method = 'trimap'
                else:
                    method = 'deep'

            logger.info(f"Using matting method: {method}")

            if method == 'trimap':
                trimap = self.create_trimap(mask)
                alpha = self.closed_form_matting(image, trimap)
            elif method == 'deep':
                alpha = self.deep_matting(image, mask)
            else:
                # 'guided' and unrecognised methods both start from the mask.
                alpha = mask.astype(np.float32) / 255.0
                if method == 'guided' and self.config.use_guided_filter:
                    alpha = self.guided_filter(
                        self._to_uint8(alpha), image
                    ) / 255.0

            alpha = self.morphological_refinement(alpha)

            alpha = self.edge_refinement.refine_edges(
                image,
                self._to_uint8(alpha)
            ) / 255.0

            confidence = self._calculate_confidence(alpha, quality_metrics)

            return {
                'alpha': alpha,
                'confidence': confidence,
                'method_used': method,
                'quality_metrics': quality_metrics
            }

        except Exception as e:
            logger.error(f"Matting processing failed: {e}")

            return {
                'alpha': mask.astype(np.float32) / 255.0,
                'confidence': 0.0,
                'method_used': 'fallback',
                'error': str(e)
            }

    def _calculate_confidence(self, alpha: np.ndarray,
                              quality_metrics: Dict) -> float:
        """Calculate a confidence score in [0, 1] for a matting result.

        Starts from ``quality_metrics['overall_quality']`` (0.5 if absent).
        It is boosted by 1.2 when the matte has a mid-range mean and a high
        spread, and by 1.1 when fewer than 10% of pixels are Canny edges.
        Returns 0.0 for an empty matte.
        """
        if alpha.size == 0:
            return 0.0

        confidence = quality_metrics.get('overall_quality', 0.5)

        if 0.3 < np.mean(alpha) < 0.7 and np.std(alpha) > 0.3:
            confidence *= 1.2

        edges = cv2.Canny(self._to_uint8(alpha), 50, 150)
        edge_ratio = np.count_nonzero(edges) / edges.size
        if edge_ratio < 0.1:
            confidence *= 1.1

        return float(np.clip(confidence, 0.0, 1.0))
|
|
|
|
|
|
|
|
class CompositingEngine:
    """Handle alpha compositing and blending."""

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.CompositingEngine")

    def composite(self, foreground: np.ndarray,
                  background: np.ndarray,
                  alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image (H, W, C), values 0-255.
            background: Background image (H, W, C), values 0-255.
            alpha: Alpha matte (H, W), (H, W, 1) or (H, W, C). Values are
                treated as 0-255 when the maximum exceeds 1.0, otherwise
                as [0, 1].

        Returns:
            uint8 composite ``fg * a + bg * (1 - a)``, rounded to nearest.
        """
        a = alpha.astype(np.float32)
        if a.ndim == 2:
            # A trailing channel axis lets the matte broadcast over any
            # number of colour channels (3, 4, ...).
            a = a[:, :, np.newaxis]

        # The size check keeps max() from raising on an empty matte.
        if a.size and a.max() > 1.0:
            a = a / 255.0

        fg = foreground.astype(np.float32) / 255.0
        bg = background.astype(np.float32) / 255.0

        result = fg * a + bg * (1.0 - a)

        # Round rather than truncate so float error cannot bias values down.
        return np.clip(np.rint(result * 255.0), 0, 255).astype(np.uint8)

    def premultiply_alpha(self, image: np.ndarray,
                          alpha: np.ndarray) -> np.ndarray:
        """Premultiply an image by its alpha channel.

        Args:
            image: Image (H, W) or (H, W, C), values 0-255.
            alpha: Matte (H, W) or (H, W, 1). Values are treated as 0-255
                when the maximum exceeds 1.0, otherwise as [0, 1].

        Returns:
            uint8 premultiplied image, rounded to nearest, same shape as
            ``image``.
        """
        a = alpha.astype(np.float32)
        if a.ndim == 2 and image.ndim == 3:
            # Only add a channel axis for multi-channel images; a 2-D image
            # must stay 2-D rather than broadcasting to (H, W, W).
            a = a[:, :, np.newaxis]

        result = image.astype(np.float32) * a

        if a.size and a.max() > 1.0:
            result = result / 255.0

        return np.clip(np.rint(result), 0, 255).astype(np.uint8)