|
|
|
|
|
""" |
|
|
Advanced matting algorithms for BackgroundFX Pro. |
|
|
Implements multiple matting techniques with automatic fallback. |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Optional |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from utils.logger import get_logger |
|
|
from utils.hardware.device_manager import DeviceManager |
|
|
from utils.config import ConfigManager |
|
|
from core.models import ModelFactory, ModelType |
|
|
from core.quality import QualityAnalyzer |
|
|
from core.edge import EdgeRefinement |
|
|
|
|
|
# Module-level logger; AlphaMatting.process uses it to report the chosen
# matting method and any processing failure.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class MattingConfig:
    """Tunable parameters consumed by the matting pipeline.

    Fields are declared in a fixed order and the generated ``__init__``
    accepts them positionally, so keep that order stable when editing.
    """

    # --- thresholds -------------------------------------------------------
    # NOTE(review): alpha_threshold is not read by any code in this file;
    # confirm whether callers elsewhere depend on it.
    alpha_threshold: float = 0.5

    # --- AlphaMatting.morphological_refinement ------------------------------
    # NOTE(review): despite the names, erode_iterations is the iteration count
    # for MORPH_CLOSE and dilate_iterations the count for MORPH_OPEN.
    erode_iterations: int = 2
    dilate_iterations: int = 2

    # Gaussian blur half-width: the kernel side is 2 * blur_radius + 1;
    # values <= 0 skip the blur entirely.
    blur_radius: int = 3

    # --- AlphaMatting.create_trimap -----------------------------------------
    # Default structuring-element size (pixels) that sets the width of the
    # unknown band when no explicit dilation_size is passed.
    trimap_size: int = 30

    # NOTE(review): not read by any code in this file.
    confidence_threshold: float = 0.7

    # --- guided filtering ---------------------------------------------------
    # Toggles guided filtering in the trimap and "guided" paths; radius and
    # eps are the defaults used by AlphaMatting.guided_filter.
    use_guided_filter: bool = True
    guided_filter_radius: int = 8
    guided_filter_eps: float = 1e-6

    # --- temporal smoothing -------------------------------------------------
    # NOTE(review): neither field is read by any code in this file.
    use_temporal_smoothing: bool = False
    temporal_window: int = 5
|
|
|
|
|
|
|
|
class AlphaMatting:
    """Advanced alpha matting using multiple techniques.

    Methods selectable through :meth:`process`:

    * ``"trimap"`` - build a trimap from the mask, then interpolate the unknown
      band from distance transforms (:meth:`create_trimap`,
      :meth:`closed_form_matting`).
    * ``"deep"`` - refine the mask at 512x512 with a caller-supplied model or
      a small non-learned smoothing block (:meth:`deep_matting`).
    * ``"guided"`` - smooth the mask with a guided filter driven by the image.

    Every path then runs :meth:`morphological_refinement` and
    ``EdgeRefinement.refine_edges``. If any step raises, :meth:`process`
    falls back to the normalised input mask.
    """

    def __init__(self, config: Optional[MattingConfig] = None):
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()

    @staticmethod
    def _mask_to_float(mask: np.ndarray) -> np.ndarray:
        """Return a float32 copy of *mask*, rescaled to [0, 1] when it looks 0..255.

        A mask whose maximum exceeds 1.0 is treated as being on the 0..255
        scale. Bool, {0, 1} and [0, 1] masks are returned unscaled. Empty
        arrays are returned as-is (no ``max()`` on an empty array).
        """
        out = np.array(mask, dtype=np.float32)
        if out.size and out.max() > 1.0:
            out = out / 255.0
        return out

    def create_trimap(self, mask: np.ndarray, dilation_size: Optional[int] = None) -> np.ndarray:
        """
        Create trimap from a binary mask.

        Args:
            mask: Mask (H, W) as bool, uint8 {0, 255}, or any dtype on the
                [0, 1] or [0, 255] scale. It is binarised at 0.5 after
                normalising to [0, 1].
            dilation_size: Side of the elliptical structuring element, which
                sets the width of the unknown band. Falls back to
                ``config.trimap_size`` when falsy.

        Returns:
            uint8 trimap with values 0 (background), 128 (unknown), 255 (foreground)
        """
        dilation_size = dilation_size or self.config.trimap_size

        # Binarise on the normalised scale. Multiplying an already-0..255 float
        # mask by 255 and casting to uint8 would wrap values around, and
        # thresholding a uint8 {0, 1} mask at 127 would discard it entirely.
        binary = (self._mask_to_float(mask) > 0.5).astype(np.uint8) * 255

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(binary, kernel, iterations=1)

        # Surely-foreground = survives erosion; unknown = gained by dilation
        # but lost by erosion; everything else stays background (0).
        trimap = np.zeros_like(binary)
        trimap[eroded == 255] = 255
        trimap[(dilated == 255) & (eroded == 0)] = 128
        return trimap

    def guided_filter(
        self,
        image: np.ndarray,
        guide: np.ndarray,
        radius: Optional[int] = None,
        eps: Optional[float] = None,
    ) -> np.ndarray:
        """
        Apply guided filter for edge-preserving smoothing.

        Args:
            image: Single-channel uint8 image to filter (H, W)
            guide: Guide image (H, W, 3), converted with BGR2GRAY, or (H, W)
            radius: Window radius. The averaging window is
                (2 * radius + 1) x (2 * radius + 1). Defaults to
                ``config.guided_filter_radius`` when falsy.
            eps: Regularisation parameter. Defaults to
                ``config.guided_filter_eps`` when falsy.

        Returns:
            Filtered image (H, W) uint8
        """
        radius = radius or self.config.guided_filter_radius
        eps = eps or self.config.guided_filter_eps

        guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY) if guide.ndim == 3 else guide

        I = guide_gray.astype(np.float32) / 255.0
        p = image.astype(np.float32) / 255.0

        # An odd, centred window (side 2r+1, as in He et al.'s definition).
        # An even r x r box is anchored off-centre and shifts the output
        # relative to the guide on each averaging pass.
        ksize = (2 * radius + 1, 2 * radius + 1)

        def box(x: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(x, -1, ksize)

        mean_I = box(I)
        mean_p = box(p)
        cov_Ip = box(I * p) - mean_I * mean_p
        var_I = box(I * I) - mean_I * mean_I

        # Per-window linear model p ~ a * I + b, then average the coefficients.
        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I

        q = box(a) * I + box(b)
        return np.clip(q * 255.0, 0, 255).astype(np.uint8)

    def closed_form_matting(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray:
        """
        Closed-form-inspired fast matting using distance transforms + optional guided filtering.

        Each unknown pixel gets ``d_bg / (d_fg + d_bg)``, where ``d_fg`` and
        ``d_bg`` are its distances to the nearest foreground and background
        pixels. Pixels next to the foreground therefore approach 1 and pixels
        next to the background approach 0.

        Args:
            image: Guide image (H, W, 3) uint8, used only when guided filtering is enabled
            trimap: Trimap with values {0, 128, 255}

        Returns:
            Alpha matte in [0,1] float32
        """
        alpha = trimap.astype(np.float32) / 255.0

        is_fg = trimap == 255
        is_bg = trimap == 0
        is_unknown = trimap == 128

        if not np.any(is_unknown):
            return np.clip(alpha, 0.0, 1.0)

        if not is_fg.any():
            # Nothing to be close to on the foreground side.
            alpha[is_unknown] = 0.0
        elif not is_bg.any():
            alpha[is_unknown] = 1.0
        else:
            # distanceTransform measures, for every NON-zero pixel, the distance
            # to the nearest ZERO pixel. Transform the complements so that the
            # FG (resp. BG) pixels are the zeros being measured to.
            dist_to_fg = cv2.distanceTransform((~is_fg).astype(np.uint8), cv2.DIST_L2, 5)
            dist_to_bg = cv2.distanceTransform((~is_bg).astype(np.uint8), cv2.DIST_L2, 5)
            alpha_unknown = dist_to_bg / (dist_to_fg + dist_to_bg + 1e-10)
            alpha[is_unknown] = alpha_unknown[is_unknown]

        if self.config.use_guided_filter:
            alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
            alpha = self.guided_filter(alpha_u8, image).astype(np.float32) / 255.0

        return np.clip(alpha, 0.0, 1.0)

    def deep_matting(
        self,
        image: np.ndarray,
        mask: np.ndarray,
        model: Optional[nn.Module] = None,
    ) -> np.ndarray:
        """
        Apply deep learning-based matting refinement.

        Args:
            image: RGB image (H, W, 3) uint8
            mask: Initial mask (H, W) as bool, [0, 1] or [0, 255]
            model: Optional module called as ``model(img_t, msk_t)``, with a
                (1, 3, 512, 512) image tensor and a (1, 1, 512, 512) mask
                tensor, both in [0, 1] on the DeviceManager's device. Without
                one, a small non-learned smoothing block is used.

        Returns:
            Refined alpha matte in [0,1] float32, resized back to (H, W)
        """
        device = self.device_manager.get_device()

        h, w = image.shape[:2]
        input_size = (512, 512)

        img_rs = cv2.resize(image, input_size)
        # Normalise before resizing: cv2.resize rejects bool arrays, and the
        # [0, 1] scale is then fixed regardless of interpolation.
        msk_rs = cv2.resize(self._mask_to_float(mask), input_size)

        img_t = torch.from_numpy(img_rs.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
        msk_t = torch.from_numpy(msk_rs).float().unsqueeze(0).unsqueeze(0)

        img_t = img_t.to(device)
        msk_t = msk_t.to(device)

        with torch.no_grad():
            if model is None:
                refined = self._simple_refine_network(torch.cat([img_t, msk_t], dim=1))
            else:
                refined = model(img_t, msk_t)
            alpha = refined.squeeze().float().cpu().numpy()

        alpha = cv2.resize(alpha, (w, h))
        return np.clip(alpha, 0.0, 1.0)

    def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
        """Tiny non-learned refinement block (demo-quality).

        Takes channel 3 of *x* (the mask), applies three rounds of 3x3 average
        pooling, each followed by a steep sigmoid centred at 0.5.
        """
        mask = x[:, 3:4, :, :]

        refined = mask
        for _ in range(3):
            refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
            refined = torch.sigmoid((refined - 0.5) * 10.0)

        return refined

    def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
        """
        Apply morphological operations and boundary smoothing.

        Runs a 5x5 elliptical closing (``config.erode_iterations`` times),
        then an opening (``config.dilate_iterations`` times), then an optional
        Gaussian blur with kernel side ``2 * config.blur_radius + 1``.

        Args:
            alpha: Alpha matte in [0,1] float32

        Returns:
            Refined alpha in [0,1] float32 (quantised through uint8)
        """
        alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

        # Closing fills small holes; opening removes small specks.
        alpha_u8 = cv2.morphologyEx(
            alpha_u8, cv2.MORPH_CLOSE, kernel, iterations=self.config.erode_iterations
        )
        alpha_u8 = cv2.morphologyEx(
            alpha_u8, cv2.MORPH_OPEN, kernel, iterations=self.config.dilate_iterations
        )

        if self.config.blur_radius > 0:
            r = self.config.blur_radius * 2 + 1
            alpha_u8 = cv2.GaussianBlur(alpha_u8, (r, r), 0)

        return alpha_u8.astype(np.float32) / 255.0

    @staticmethod
    def _auto_select_method(quality_metrics: Dict) -> str:
        """Pick a matting method from frame-quality metrics.

        Uses ``blur_score`` (> 50 selects "guided") and then ``edge_clarity``
        (> 0.7 selects "trimap"), otherwise "deep".
        NOTE(review): the metric semantics come from QualityAnalyzer and are not
        visible here; the thresholds are heuristics.
        """
        if quality_metrics.get("blur_score", 0.0) > 50:
            return "guided"
        if quality_metrics.get("edge_clarity", 0.0) > 0.7:
            return "trimap"
        return "deep"

    def process(self, image: np.ndarray, mask: np.ndarray, method: str = "auto") -> Dict[str, np.ndarray]:
        """
        Process image with selected matting method.

        Args:
            image: RGB image (H, W, 3) uint8
            mask: Initial segmentation mask (H, W) as bool, [0, 1] or [0, 255]
            method: 'auto' | 'trimap' | 'deep' | 'guided'. Any other value
                passes the normalised mask straight to the refinement stage.

        Returns:
            dict with ``alpha`` (float32 in [0, 1]), ``confidence`` (float in
            [0, 1]), ``method_used`` and ``quality_metrics``.

            If any step raises, ``method_used`` is ``"fallback"``,
            ``confidence`` is 0.0, ``alpha`` is the normalised input mask, and
            ``error`` holds the exception message. ``quality_metrics`` then
            holds whatever was computed before the failure, or ``{}`` if the
            analysis itself failed.
        """
        quality_metrics: Dict = {}
        try:
            quality_metrics = self.quality_analyzer.analyze_frame(image)

            chosen = self._auto_select_method(quality_metrics) if method == "auto" else method
            logger.info(f"Using matting method: {chosen}")

            if chosen == "trimap":
                alpha = self.closed_form_matting(image, self.create_trimap(mask))
            elif chosen == "deep":
                alpha = self.deep_matting(image, mask)
            elif chosen == "guided":
                alpha = self._mask_to_float(mask)
                if self.config.use_guided_filter:
                    alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
                    alpha = self.guided_filter(alpha_u8, image).astype(np.float32) / 255.0
            else:
                alpha = self._mask_to_float(mask)

            # Shared post-processing for every method.
            alpha = self.morphological_refinement(alpha)
            alpha = self.edge_refinement.refine_edges(
                image, np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
            ).astype(np.float32) / 255.0

            confidence = self._calculate_confidence(alpha, quality_metrics)

            return {
                "alpha": np.clip(alpha, 0.0, 1.0),
                "confidence": float(np.clip(confidence, 0.0, 1.0)),
                "method_used": chosen,
                "quality_metrics": quality_metrics,
            }

        except Exception as e:
            # Best-effort boundary: never propagate, return the raw mask instead.
            logger.error(f"Matting processing failed: {e}")
            fallback = self._mask_to_float(mask)
            return {
                "alpha": np.clip(fallback, 0.0, 1.0),
                "confidence": 0.0,
                "method_used": "fallback",
                "quality_metrics": quality_metrics,
                "error": str(e),
            }

    def _calculate_confidence(self, alpha: np.ndarray, quality_metrics: Dict) -> float:
        """Calculate confidence score for the matting result.

        Starts from ``overall_quality`` (default 0.5). It is boosted x1.2 when
        the matte's mean lies in (0.3, 0.7) with std > 0.3, and x1.1 when
        fewer than 10% of pixels are Canny edges. The result is clipped to
        [0, 1].
        """
        confidence = float(quality_metrics.get("overall_quality", 0.5))

        alpha_mean = float(np.mean(alpha))
        alpha_std = float(np.std(alpha))

        if 0.3 < alpha_mean < 0.7 and alpha_std > 0.3:
            confidence *= 1.2

        edges = cv2.Canny(np.clip(alpha * 255.0, 0, 255).astype(np.uint8), 50, 150)
        edge_ratio = float(np.sum(edges > 0) / edges.size)
        if edge_ratio < 0.1:
            confidence *= 1.1

        return float(np.clip(confidence, 0.0, 1.0))
|
|
|
|
|
|
|
|
class CompositingEngine:
    """Handle alpha compositing and blending.

    Alpha inputs may be (H, W), (H, W, 1) or (H, W, C), on either the [0, 1]
    or the [0, 255] scale. An alpha whose maximum exceeds 1.0 is treated as
    0..255. A 2-D or single-channel alpha broadcasts over any number of
    image channels.
    """

    def __init__(self):
        self.logger = get_logger(f"{__name__}.CompositingEngine")

    @staticmethod
    def _alpha_to_float(alpha: np.ndarray) -> np.ndarray:
        """Return *alpha* as float32 on the [0, 1] scale with a channel axis.

        A 2-D alpha gains a trailing singleton axis so it broadcasts across the
        image's channels. Empty arrays are returned without scaling.
        """
        a = np.asarray(alpha, dtype=np.float32)
        if a.ndim == 2:
            a = a[:, :, np.newaxis]
        if a.size and a.max() > 1.0:
            a = a / 255.0
        return a

    def composite(self, foreground: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image (H, W, C) uint8
            background: Background image (H, W, C) uint8
            alpha: Alpha matte (H, W) or (H, W, 1) in [0..255] or [0..1]

        Returns:
            Composited image (H, W, C) uint8, rounded to the nearest level
        """
        a = self._alpha_to_float(alpha)
        fg = foreground.astype(np.float32)
        bg = background.astype(np.float32)

        blended = fg * a + bg * (1.0 - a)
        # Round rather than truncate: astype(uint8) alone floors the value and
        # biases every blended pixel darker by up to one level.
        return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    def premultiply_alpha(self, image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Premultiply an image by its alpha channel.

        Args:
            image: (H, W, C) uint8
            alpha: (H, W) or (H, W, 1) in [0..255] or [0..1]

        Returns:
            Premultiplied (H, W, C) uint8, rounded to the nearest level
        """
        a = self._alpha_to_float(alpha)
        premul = image.astype(np.float32) * a
        return np.clip(np.rint(premul), 0, 255).astype(np.uint8)
|
|
|