#!/usr/bin/env python3
"""
Advanced matting algorithms for BackgroundFX Pro.

Implements multiple matting techniques with automatic fallback.
"""

from dataclasses import dataclass
from typing import Any, Dict, Optional

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.logger import get_logger
from utils.hardware.device_manager import DeviceManager
from utils.config import ConfigManager  # kept for forward compatibility / config hook
from core.models import ModelFactory, ModelType  # not used directly here but kept for API consistency
from core.quality import QualityAnalyzer
from core.edge import EdgeRefinement

logger = get_logger(__name__)


def _to_unit_float(arr: np.ndarray) -> np.ndarray:
    """Return a float32 copy of ``arr`` scaled to [0, 1].

    Arrays whose maximum exceeds 1.0 are treated as 0..255 data and divided
    by 255; bool, {0, 1} and already-normalised inputs are only cast.
    Empty arrays are cast without scaling (``max`` is undefined for them).
    """
    out = np.asarray(arr).astype(np.float32)
    if out.size and out.max() > 1.0:
        out = out / 255.0
    return out


@dataclass
class MattingConfig:
    """Configuration for matting operations."""

    alpha_threshold: float = 0.5  # not read by AlphaMatting
    erode_iterations: int = 2  # used as the MORPH_CLOSE iteration count
    dilate_iterations: int = 2  # used as the MORPH_OPEN iteration count
    blur_radius: int = 3  # Gaussian kernel side = 2 * blur_radius + 1; 0 disables
    trimap_size: int = 30  # default structuring-element size for create_trimap
    confidence_threshold: float = 0.7  # not read by AlphaMatting
    use_guided_filter: bool = True
    guided_filter_radius: int = 8  # box window side = 2 * radius + 1
    guided_filter_eps: float = 1e-6
    use_temporal_smoothing: bool = False  # not read by AlphaMatting
    temporal_window: int = 5  # not read by AlphaMatting


class AlphaMatting:
    """Advanced alpha matting using multiple techniques."""

    def __init__(self, config: Optional[MattingConfig] = None):
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()

    def create_trimap(self, mask: np.ndarray, dilation_size: Optional[int] = None) -> np.ndarray:
        """
        Create a trimap from a (soft or binary) mask.

        Args:
            mask: Mask (H, W) in {0, 255}, [0, 255] or [0, 1]; any numeric or
                bool dtype. Values above the 0.5 midpoint count as foreground.
            dilation_size: Size of the structuring element that forms the
                unknown band (pixels). Falsy values fall back to
                ``config.trimap_size``.

        Returns:
            uint8 trimap with values 0 (background), 128 (unknown),
            255 (foreground).
        """
        dilation_size = dilation_size or self.config.trimap_size

        # Normalise first so int32/float masks in 0..255 do not overflow when
        # cast to uint8, and uint8 {0,1} masks are not read as all-background.
        binary = (_to_unit_float(mask) > 0.5).astype(np.uint8) * 255

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(binary, kernel, iterations=1)

        trimap = np.zeros_like(binary)
        trimap[eroded == 255] = 255  # certain foreground survives erosion
        # Unknown band: reached by dilation but removed by erosion.
        trimap[(dilated == 255) & (eroded == 0)] = 128
        return trimap

    def guided_filter(
        self,
        image: np.ndarray,
        guide: np.ndarray,
        radius: Optional[int] = None,
        eps: Optional[float] = None,
    ) -> np.ndarray:
        """
        Apply a guided filter for edge-preserving smoothing.

        Args:
            image: Input image to filter (H, W) uint8.
            guide: Guide image (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) uint8.
            radius: Filter radius; the box window is (2 * radius + 1) square.
                Falsy values fall back to ``config.guided_filter_radius``.
            eps: Regularization parameter. Falsy values fall back to
                ``config.guided_filter_eps``.

        Returns:
            Filtered image (H, W) uint8.
        """
        radius = radius or self.config.guided_filter_radius
        eps = eps or self.config.guided_filter_eps

        if guide.ndim == 3 and guide.shape[2] == 1:
            guide = guide[:, :, 0]
        if guide.ndim == 3:
            # NOTE(review): conversion assumes OpenCV BGR(A) channel order while
            # other docstrings in this class say RGB — confirm with callers.
            code = cv2.COLOR_BGRA2GRAY if guide.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            guide_gray = cv2.cvtColor(guide, code)
        else:
            guide_gray = guide

        # Work in float32 on [0, 1].
        I = guide_gray.astype(np.float32) / 255.0
        p = image.astype(np.float32) / 255.0

        # Odd, centred window: an even box kernel is asymmetric about its
        # anchor and shifts the result by half a pixel.
        ksize = (2 * radius + 1, 2 * radius + 1)

        def box_filter(img: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(img, -1, ksize)

        mean_I = box_filter(I)
        mean_p = box_filter(p)
        cov_Ip = box_filter(I * p) - mean_I * mean_p
        var_I = box_filter(I * I) - mean_I * mean_I

        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I

        q = box_filter(a) * I + box_filter(b)
        return np.clip(q * 255.0, 0, 255).astype(np.uint8)

    def closed_form_matting(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray:
        """
        Closed-form-inspired fast matting using distance transforms plus
        optional guided filtering.

        Unknown pixels get alpha = d_bg / (d_fg + d_bg), where d_fg / d_bg are
        the Euclidean distances to the nearest known foreground / background
        pixel, so alpha ramps from 1 next to the foreground to 0 next to the
        background.

        Args:
            image: Colour image (H, W, 3) uint8, used as the guided-filter guide.
            trimap: Trimap (H, W) with values {0, 128, 255}.

        Returns:
            Alpha matte in [0, 1] float32.
        """
        alpha = trimap.astype(np.float32) / 255.0

        is_fg = trimap == 255
        is_bg = trimap == 0
        is_unknown = trimap == 128

        if not np.any(is_unknown):
            return np.clip(alpha, 0.0, 1.0)

        if not np.any(is_fg):
            # No foreground seeds: nothing to interpolate towards.
            alpha[is_unknown] = 0.0
        elif not np.any(is_bg):
            alpha[is_unknown] = 1.0
        else:
            # distanceTransform gives each NON-zero pixel its distance to the
            # nearest zero pixel, so pass the complement: the zero pixels are
            # then the FG (resp. BG) seeds we measure towards.
            dist_to_fg = cv2.distanceTransform((~is_fg).astype(np.uint8), cv2.DIST_L2, 5)
            dist_to_bg = cv2.distanceTransform((~is_bg).astype(np.uint8), cv2.DIST_L2, 5)
            blend = dist_to_bg / (dist_to_fg + dist_to_bg + 1e-10)
            alpha[is_unknown] = blend[is_unknown]

        if self.config.use_guided_filter:
            alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
            alpha = self.guided_filter(alpha_u8, image).astype(np.float32) / 255.0

        return np.clip(alpha, 0.0, 1.0)

    def deep_matting(
        self,
        image: np.ndarray,
        mask: np.ndarray,
        model: Optional[nn.Module] = None,
    ) -> np.ndarray:
        """
        Apply deep-learning-based matting refinement at 512x512.

        Args:
            image: Colour image (H, W, 3) uint8.
            mask: Initial mask (H, W) in 0..255 or [0, 1]; bool is accepted.
            model: Optional model called as ``model(img, mask)`` on
                (1, 3, 512, 512) and (1, 1, 512, 512) float tensors in [0, 1];
                when None, the non-learned ``_simple_refine_network`` is used.

        Returns:
            Refined alpha matte (H, W) in [0, 1] float32.
        """
        device = self.device_manager.get_device()
        h, w = image.shape[:2]
        input_size = (512, 512)

        img_rs = cv2.resize(image, input_size)
        # Normalise before resizing: cv2.resize rejects bool arrays.
        msk_rs = cv2.resize(_to_unit_float(mask), input_size)

        img_t = torch.from_numpy(img_rs.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
        msk_t = torch.from_numpy(msk_rs).float().unsqueeze(0).unsqueeze(0)

        img_t = img_t.to(device)
        msk_t = msk_t.to(device)

        with torch.no_grad():
            if model is None:
                refined = self._simple_refine_network(torch.cat([img_t, msk_t], dim=1))
            else:
                refined = model(img_t, msk_t)

        alpha = refined.squeeze().float().cpu().numpy()
        alpha = cv2.resize(alpha, (w, h))
        return np.clip(alpha, 0.0, 1.0)

    def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
        """Tiny non-learned refinement block (demo-quality).

        Takes channel 3 of ``x`` (the mask, when ``x`` is RGB + mask), box-blurs
        it three times with a 3x3 average pool, then sharpens it with a steep
        sigmoid centred at 0.5.
        """
        # x: [B, 4, H, W] (RGB + mask)
        mask = x[:, 3:4, :, :]
        refined = mask
        for _ in range(3):
            refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
        refined = torch.sigmoid((refined - 0.5) * 10.0)
        return refined

    def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
        """
        Apply morphological operations and boundary smoothing.

        Args:
            alpha: Alpha matte in [0, 1] float32.

        Returns:
            Refined alpha in [0, 1] float32 (quantised through uint8).
        """
        alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

        # Close small holes in FG. NOTE(review): the CLOSE count comes from
        # erode_iterations and the OPEN count from dilate_iterations — confirm
        # this mapping is intended.
        alpha_u8 = cv2.morphologyEx(
            alpha_u8, cv2.MORPH_CLOSE, kernel, iterations=self.config.erode_iterations
        )
        # Remove small specks
        alpha_u8 = cv2.morphologyEx(
            alpha_u8, cv2.MORPH_OPEN, kernel, iterations=self.config.dilate_iterations
        )

        if self.config.blur_radius > 0:
            r = self.config.blur_radius * 2 + 1
            alpha_u8 = cv2.GaussianBlur(alpha_u8, (r, r), 0)

        return alpha_u8.astype(np.float32) / 255.0

    def process(self, image: np.ndarray, mask: np.ndarray, method: str = "auto") -> Dict[str, Any]:
        """
        Process an image with the selected matting method.

        Args:
            image: Colour image (H, W, 3) uint8.
            mask: Initial segmentation mask (H, W) in 0..255 or [0, 1].
            method: 'auto' | 'trimap' | 'deep' | 'guided'. Any other value uses
                the normalised mask unchanged before refinement.

        Returns:
            dict with keys alpha, confidence, method_used, quality_metrics and,
            when an exception was caught, error. On failure the alpha is the
            normalised input mask, confidence is 0.0, method_used is
            "fallback", and quality_metrics is whatever was computed before the
            failure (an empty dict if analysis itself failed).
        """
        quality_metrics: Dict = {}
        try:
            quality_metrics = self.quality_analyzer.analyze_frame(image)

            chosen = method
            if method == "auto":
                # Heuristic selection. NOTE(review): thresholds assume the
                # QualityAnalyzer metric scales — confirm there.
                blur_score = quality_metrics.get("blur_score", 0.0)
                edge_clarity = quality_metrics.get("edge_clarity", 0.0)
                if blur_score > 50:
                    chosen = "guided"
                elif edge_clarity > 0.7:
                    chosen = "trimap"
                else:
                    chosen = "deep"

            logger.info(f"Using matting method: {chosen}")

            if chosen == "trimap":
                alpha = self.closed_form_matting(image, self.create_trimap(mask))
            elif chosen == "deep":
                alpha = self.deep_matting(image, mask)
            else:
                # "guided" and unrecognised methods both start from the mask.
                alpha = _to_unit_float(mask)
                if chosen == "guided" and self.config.use_guided_filter:
                    alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
                    alpha = self.guided_filter(alpha_u8, image).astype(np.float32) / 255.0

            # Morphological + edge refinement
            alpha = self.morphological_refinement(alpha)
            refined_u8 = self.edge_refinement.refine_edges(
                image, np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
            )
            alpha = refined_u8.astype(np.float32) / 255.0

            confidence = self._calculate_confidence(alpha, quality_metrics)

            return {
                "alpha": np.clip(alpha, 0.0, 1.0),
                "confidence": float(np.clip(confidence, 0.0, 1.0)),
                "method_used": chosen,
                "quality_metrics": quality_metrics,
            }

        except Exception as e:
            logger.error(f"Matting processing failed: {e}")
            return {
                "alpha": np.clip(_to_unit_float(mask), 0.0, 1.0),
                "confidence": 0.0,
                "method_used": "fallback",
                # Always present so callers can index it like the success path.
                "quality_metrics": quality_metrics,
                "error": str(e),
            }

    def _calculate_confidence(self, alpha: np.ndarray, quality_metrics: Dict) -> float:
        """Calculate a confidence score in [0, 1] for the matting result.

        Starts from ``quality_metrics["overall_quality"]`` (0.5 if missing),
        boosts it by 1.2x when alpha is balanced with high spread, and by 1.1x
        when fewer than 10% of pixels are Canny edges in the alpha.
        """
        confidence = float(quality_metrics.get("overall_quality", 0.5))

        alpha_mean = float(np.mean(alpha))
        alpha_std = float(np.std(alpha))

        # Prefer clear separation
        if 0.3 < alpha_mean < 0.7 and alpha_std > 0.3:
            confidence *= 1.2

        edges = cv2.Canny(np.clip(alpha * 255.0, 0, 255).astype(np.uint8), 50, 150)
        edge_ratio = float(np.sum(edges > 0) / edges.size)
        if edge_ratio < 0.1:
            confidence *= 1.1

        return float(np.clip(confidence, 0.0, 1.0))


class CompositingEngine:
    """Handle alpha compositing and blending."""

    def __init__(self):
        self.logger = get_logger(f"{__name__}.CompositingEngine")

    @staticmethod
    def _alpha_weights(alpha: np.ndarray) -> np.ndarray:
        """Return alpha as float32 in [0, 1] with a trailing channel axis.

        A (H, W) alpha becomes (H, W, 1) so it broadcasts across the colour
        channels; values are divided by 255 when the maximum exceeds 1.0.
        """
        a = np.asarray(alpha)
        if a.ndim == 2:
            a = a[:, :, np.newaxis]
        return _to_unit_float(a)

    def composite(self, foreground: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image (H, W, 3) uint8.
            background: Background image (H, W, 3) uint8.
            alpha: Alpha matte (H, W), (H, W, 1) or (H, W, 3) in [0..255] or
                [0..1].

        Returns:
            Composited image (H, W, 3) uint8 (values truncated, not rounded).
        """
        a = self._alpha_weights(alpha)
        fg = foreground.astype(np.float32) / 255.0
        bg = background.astype(np.float32) / 255.0
        result = fg * a + bg * (1.0 - a)
        return np.clip(result * 255.0, 0, 255).astype(np.uint8)

    def premultiply_alpha(self, image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Premultiply an image by its alpha channel.

        Args:
            image: (H, W, 3) uint8.
            alpha: (H, W), (H, W, 1) or (H, W, 3) in [0..255] or [0..1].

        Returns:
            Premultiplied (H, W, 3) uint8 (values truncated, not rounded).
        """
        a = self._alpha_weights(alpha)
        premul = image.astype(np.float32) * a
        return np.clip(premul, 0.0, 255.0).astype(np.uint8)