"""
Edge processing and symmetry correction for BackgroundFX Pro.
Fixes hair segmentation asymmetry and improves edge quality.
"""

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from scipy import ndimage, signal
from scipy.spatial import distance

logger = logging.getLogger(__name__)

@dataclass
class EdgeConfig:
    """Configuration for edge processing."""
    edge_thickness: int = 3
    smoothing_iterations: int = 2
    symmetry_threshold: float = 0.3          # mean left/right difference that triggers correction
    hair_detection_sensitivity: float = 0.7  # threshold applied to the hair probability map
    refinement_radius: int = 5               # snap radius (pixels) used during edge refinement
    use_guided_filter: bool = True           # apply guided filtering in the final smoothing pass
    bilateral_d: int = 9                     # bilateral filter neighborhood diameter
    bilateral_sigma_color: float = 75        # bilateral filter color sigma
    bilateral_sigma_space: float = 75        # bilateral filter spatial sigma
    morphology_kernel_size: int = 5          # kernel size for morphological close/open
    edge_preservation_weight: float = 0.8
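
# Illustrative configuration sketch (assumes a BGR uint8 frame and a float mask
# in [0, 1] from an upstream segmentation step; `frame` and `coarse_mask` are
# placeholder names, not part of this module):
#
#     config = EdgeConfig(hair_detection_sensitivity=0.6, refinement_radius=7)
#     processor = EdgeProcessor(config)
#     refined_mask = processor.process(frame, coarse_mask, detect_hair=True)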
class EdgeProcessor:
    """Main edge processing and refinement system."""

    def __init__(self, config: Optional[EdgeConfig] = None):
        self.config = config or EdgeConfig()
        # Pass the resolved config so submodules never receive None.
        self.hair_segmentation = HairSegmentation(self.config)
        self.edge_refinement = EdgeRefinement(self.config)
        self.symmetry_corrector = SymmetryCorrector(self.config)

    def process(self, image: np.ndarray, mask: np.ndarray,
                detect_hair: bool = True) -> np.ndarray:
        """Process edges with full pipeline."""
        edges = self._detect_edges(mask)

        if detect_hair:
            hair_mask = self.hair_segmentation.segment(image, mask)
            mask = self._blend_hair_mask(mask, hair_mask)

        mask = self.symmetry_corrector.correct(mask, image)
        mask = self.edge_refinement.refine(image, mask, edges)
        mask = self._final_smoothing(mask)
        return mask

    def _detect_edges(self, mask: np.ndarray) -> np.ndarray:
        """Detect edges in mask."""
        mask_uint8 = (mask * 255).astype(np.uint8)

        # Multi-threshold Canny to catch both strong and weak boundaries.
        edges1 = cv2.Canny(mask_uint8, 50, 150)
        edges2 = cv2.Canny(mask_uint8, 30, 100)
        edges3 = cv2.Canny(mask_uint8, 70, 200)
        edges = np.maximum(edges1, np.maximum(edges2, edges3))

        return edges / 255.0

    def _blend_hair_mask(self, original_mask: np.ndarray,
                         hair_mask: np.ndarray) -> np.ndarray:
        """Blend hair mask with original mask."""
        hair_regions = hair_mask > 0.5

        # Favor the hair estimate inside detected hair regions.
        alpha = 0.7
        blended = original_mask.copy()
        blended[hair_regions] = (
            alpha * hair_mask[hair_regions] +
            (1 - alpha) * original_mask[hair_regions]
        )
        return blended

    def _final_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Apply final smoothing pass."""
        if self.config.use_guided_filter:
            mask = self._guided_filter(mask, mask)

        # Morphological close/open to remove pinholes and small speckles.
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (self.config.morphology_kernel_size, self.config.morphology_kernel_size)
        )
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        return mask

    def _guided_filter(self, input_img: np.ndarray,
                       guidance: np.ndarray,
                       radius: int = 4,
                       epsilon: float = 0.2**2) -> np.ndarray:
        """Apply guided filter for edge-preserving smoothing."""
        # Local linear model q = a * I + b, with a = cov(I, p) / (var(I) + eps)
        # and b = mean(p) - a * mean(I), estimated over box windows.
        mean_I = cv2.boxFilter(guidance, cv2.CV_64F, (radius, radius))
        mean_p = cv2.boxFilter(input_img, cv2.CV_64F, (radius, radius))
        mean_Ip = cv2.boxFilter(guidance * input_img, cv2.CV_64F, (radius, radius))
        cov_Ip = mean_Ip - mean_I * mean_p

        mean_II = cv2.boxFilter(guidance * guidance, cv2.CV_64F, (radius, radius))
        var_I = mean_II - mean_I * mean_I

        a = cov_Ip / (var_I + epsilon)
        b = mean_p - a * mean_I

        mean_a = cv2.boxFilter(a, cv2.CV_64F, (radius, radius))
        mean_b = cv2.boxFilter(b, cv2.CV_64F, (radius, radius))

        q = mean_a * guidance + mean_b
        return q

class HairSegmentation:
    """Specialized hair segmentation module."""

    def __init__(self, config: EdgeConfig):
        self.config = config
        self.hair_detector = HairDetector()

    def segment(self, image: np.ndarray, initial_mask: np.ndarray) -> np.ndarray:
        """Segment hair regions with improved accuracy."""
        hair_probability = self.hair_detector.detect(image)
        hair_mask = self._refine_with_mask(hair_probability, initial_mask)
        hair_mask = self._fix_hair_asymmetry(hair_mask, image)
        hair_mask = self._enhance_hair_strands(hair_mask, image)
        return hair_mask

    def _refine_with_mask(self, hair_prob: np.ndarray,
                          initial_mask: np.ndarray) -> np.ndarray:
        """Refine hair probability with initial mask."""
        # Restrict hair detection to a dilated band around the initial subject mask.
        kernel = np.ones((15, 15), np.uint8)
        dilated_mask = cv2.dilate(initial_mask, kernel, iterations=2)
        refined = hair_prob * dilated_mask

        threshold = self.config.hair_detection_sensitivity
        hair_mask = (refined > threshold).astype(np.float32)
        hair_mask = cv2.GaussianBlur(hair_mask, (5, 5), 1.0)
        return hair_mask

    def _fix_hair_asymmetry(self, mask: np.ndarray,
                            image: np.ndarray) -> np.ndarray:
        """Fix asymmetry in hair segmentation."""
        h, w = mask.shape[:2]
        center_x = w // 2

        left_mask = mask[:, :center_x]
        right_mask = mask[:, center_x:]
        right_flipped = np.fliplr(right_mask)

        if left_mask.shape[1] == right_flipped.shape[1]:
            diff = np.abs(left_mask - right_flipped)
            asymmetry_score = np.mean(diff)

            if asymmetry_score > self.config.symmetry_threshold:
                logger.info(f"Detected hair asymmetry: {asymmetry_score:.3f}")

                # Blend each half with the mirror of the opposite half.
                balanced_left = 0.5 * left_mask + 0.5 * right_flipped
                balanced_right = 0.5 * right_mask + 0.5 * np.fliplr(left_mask)

                mask[:, :center_x] = balanced_left
                mask[:, center_x:center_x + balanced_right.shape[1]] = balanced_right

        return mask

    def _enhance_hair_strands(self, mask: np.ndarray,
                              image: np.ndarray) -> np.ndarray:
        """Enhance fine hair strands."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        # A Gabor filter bank at four orientations highlights strand-like structure.
        orientations = [0, 45, 90, 135]
        gabor_responses = []
        for angle in orientations:
            theta = np.deg2rad(angle)
            kernel = cv2.getGaborKernel(
                (21, 21), 4.0, theta, 10.0, 0.5, 0, ktype=cv2.CV_32F
            )
            filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
            gabor_responses.append(np.abs(filtered))

        gabor_max = np.max(gabor_responses, axis=0)
        gabor_normalized = gabor_max / (np.max(gabor_max) + 1e-6)

        # Only add strand response outside the current mask, with a modest weight.
        hair_enhancement = gabor_normalized * (1 - mask)
        enhanced_mask = np.clip(mask + 0.3 * hair_enhancement, 0, 1)
        return enhanced_mask

class HairDetector:
    """Detects hair regions in images."""

    def detect(self, image: np.ndarray) -> np.ndarray:
        """Detect hair probability map."""
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # HSV ranges covering common hair tones (black, dark brown,
        # light brown/blonde, red/auburn).
        hair_colors = [
            ((0, 0, 0), (180, 255, 30)),
            ((10, 20, 20), (20, 255, 100)),
            ((15, 30, 50), (25, 255, 200)),
            ((0, 50, 50), (10, 255, 150)),
        ]

        hair_masks = []
        for (lower, upper) in hair_colors:
            mask = cv2.inRange(hsv, np.array(lower), np.array(upper))
            hair_masks.append(mask)

        color_mask = np.max(hair_masks, axis=0) / 255.0
        texture_mask = self._detect_hair_texture(image)

        # Combine color and texture cues, then smooth the probability map.
        hair_probability = 0.6 * color_mask + 0.4 * texture_mask
        hair_probability = cv2.GaussianBlur(hair_probability, (7, 7), 2.0)
        return hair_probability

    def _detect_hair_texture(self, image: np.ndarray) -> np.ndarray:
        """Detect hair-like texture patterns."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        # Gradient magnitude and orientation.
        dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        magnitude = np.sqrt(dx**2 + dy**2)
        orientation = np.arctan2(dy, dx)

        # Local orientation variance: hair shows strong gradients with locally
        # consistent orientation, so low variance boosts the score.
        window_size = 9
        kernel = np.ones((window_size, window_size)) / (window_size**2)
        orient_mean = cv2.filter2D(orientation, -1, kernel)
        orient_sq_mean = cv2.filter2D(orientation**2, -1, kernel)
        orient_var = orient_sq_mean - orient_mean**2

        texture_score = magnitude * np.exp(-orient_var)
        texture_score = texture_score / (np.max(texture_score) + 1e-6)
        return texture_score
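
# Illustrative standalone use of the detector (assumes `frame` is a BGR uint8
# image; the name is a placeholder):
#
#     hair_prob = HairDetector().detect(frame)   # float map, roughly in [0, 1]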
class EdgeRefinement:
    """Refines edges for better quality."""

    def __init__(self, config: EdgeConfig):
        self.config = config

    def refine(self, image: np.ndarray, mask: np.ndarray,
               edges: np.ndarray) -> np.ndarray:
        """Refine mask edges."""
        refined = self._bilateral_smooth(mask, image)
        refined = self._snap_to_edges(refined, image, edges)
        refined = self._subpixel_refinement(refined, image)
        refined = self._apply_feathering(refined)
        return refined

    def _bilateral_smooth(self, mask: np.ndarray,
                          image: np.ndarray) -> np.ndarray:
        """Apply bilateral filtering for edge-aware smoothing."""
        mask_uint8 = (mask * 255).astype(np.uint8)
        smoothed = cv2.bilateralFilter(
            mask_uint8,
            self.config.bilateral_d,
            self.config.bilateral_sigma_color,
            self.config.bilateral_sigma_space
        )
        return smoothed / 255.0

    def _snap_to_edges(self, mask: np.ndarray, image: np.ndarray,
                       detected_edges: np.ndarray) -> np.ndarray:
        """Snap mask boundaries to image edges."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
        image_edges = cv2.Canny(gray, 50, 150) / 255.0
        mask_edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150) / 255.0

        # Distance from every pixel to the nearest image edge.
        dist_transform = cv2.distanceTransform(
            (1 - image_edges).astype(np.uint8),
            cv2.DIST_L2, 5
        )

        snap_radius = self.config.refinement_radius
        refined = mask.copy()

        # Around the mask boundary, harden the mask where an image edge is nearby.
        edge_region = cv2.dilate(mask_edges, np.ones((5, 5))) > 0
        close_to_image_edge = (dist_transform < snap_radius) & edge_region
        refined[close_to_image_edge] = np.where(
            mask[close_to_image_edge] > 0.5, 1.0, 0.0
        )
        return refined

    def _subpixel_refinement(self, mask: np.ndarray,
                             image: np.ndarray) -> np.ndarray:
        """Sharpen mask transitions where the image gradient is strong."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        grad_mag = np.sqrt(grad_x**2 + grad_y**2)
        grad_mag = grad_mag / (np.max(grad_mag) + 1e-6)

        # Push mask values toward 0 or 1 wherever the gradient is strong.
        refined = mask.copy()
        strong_gradient = grad_mag > 0.3
        refined[strong_gradient] = np.where(
            mask[strong_gradient] > 0.5,
            np.minimum(mask[strong_gradient] + 0.1, 1.0),
            np.maximum(mask[strong_gradient] - 0.1, 0.0)
        )
        return refined

    def _apply_feathering(self, mask: np.ndarray,
                          radius: int = 3) -> np.ndarray:
        """Apply feathering to edges."""
        mask_binary = (mask > 0.5).astype(np.uint8)

        # Distance from foreground pixels to the nearest background pixel.
        dist_inside = cv2.distanceTransform(
            mask_binary, cv2.DIST_L2, 5
        )

        # Distance from background pixels to the nearest foreground pixel.
        dist_outside = cv2.distanceTransform(
            1 - mask_binary, cv2.DIST_L2, 5
        )

        feather_region = (dist_inside <= radius) | (dist_outside <= radius)

        if np.any(feather_region):
            # Alpha ramps from 0 at the boundary to 1 deep inside the mask.
            alpha = np.zeros_like(mask)
            alpha[dist_inside > radius] = 1.0
            alpha[feather_region] = dist_inside[feather_region] / radius

            mask = mask * (1 - feather_region) + alpha * feather_region

        return mask

class SymmetryCorrector:
    """Corrects asymmetry in masks."""

    def __init__(self, config: EdgeConfig):
        self.config = config

    def correct(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Correct asymmetry in mask."""
        center = self._find_center(mask)
        asymmetry_score = self._compute_asymmetry(mask, center)

        if asymmetry_score > self.config.symmetry_threshold:
            logger.info(f"Correcting asymmetry: {asymmetry_score:.3f}")
            mask = self._balance_mask(mask, center)

        return mask

    def _find_center(self, mask: np.ndarray) -> int:
        """Find the x-coordinate of the mask centroid (the vertical symmetry axis)."""
        mask_binary = (mask > 0.5).astype(np.uint8)
        moments = cv2.moments(mask_binary)
        if moments['m00'] > 0:
            cx = int(moments['m10'] / moments['m00'])
            return cx
        else:
            return mask.shape[1] // 2

    def _compute_asymmetry(self, mask: np.ndarray, center: int) -> float:
        """Compute asymmetry as the mean left/right mirror difference."""
        h, w = mask.shape[:2]

        left_width = center
        right_width = w - center
        min_width = min(left_width, right_width)
        if min_width <= 0:
            return 0.0

        left = mask[:, center-min_width:center]
        right = mask[:, center:center+min_width]
        right_flipped = np.fliplr(right)

        diff = np.abs(left - right_flipped)
        asymmetry = np.mean(diff)
        return asymmetry

    def _balance_mask(self, mask: np.ndarray, center: int) -> np.ndarray:
        """Balance mask to reduce asymmetry."""
        h, w = mask.shape[:2]
        balanced = mask.copy()

        left_width = center
        right_width = w - center
        min_width = min(left_width, right_width)
        if min_width <= 0:
            return mask

        left = mask[:, center-min_width:center]
        right = mask[:, center:center+min_width]

        # Confidence = how far values sit from the ambiguous 0.5 level.
        left_confidence = np.mean(np.abs(left - 0.5))
        right_confidence = np.mean(np.abs(right - 0.5))

        total_conf = left_confidence + right_confidence + 1e-6
        left_weight = left_confidence / total_conf
        right_weight = right_confidence / total_conf

        # Blend each half with the mirror of the other, weighted by confidence.
        balanced_left = left_weight * left + right_weight * np.fliplr(right)
        balanced_right = right_weight * right + left_weight * np.fliplr(left)

        balanced[:, center-min_width:center] = balanced_left
        balanced[:, center:center+min_width] = balanced_right

        # Smooth the vertical seam along the symmetry axis.
        seam_width = 5
        seam_start = max(0, center - seam_width)
        seam_end = min(w, center + seam_width)
        balanced[:, seam_start:seam_end] = cv2.GaussianBlur(
            balanced[:, seam_start:seam_end], (5, 1), 1.0
        )
        return balanced
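
# Worked example of the confidence weighting in _balance_mask (illustrative
# numbers only): with mean |left - 0.5| = 0.40 and mean |right - 0.5| = 0.10,
# left_weight = 0.40 / 0.50 = 0.8 and right_weight = 0.2, so both halves are
# pulled mostly toward the more decisive left side.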
__all__ = [
    'EdgeProcessor',
    'EdgeConfig',
    'HairSegmentation',
    'EdgeRefinement',
    'SymmetryCorrector',
    'HairDetector'
]
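

# Minimal smoke test on synthetic data (illustrative sketch only; the circle
# "subject" and slightly offset mask below are stand-ins for a real frame and
# an upstream segmentation result).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Gray canvas with a dark filled circle as the subject.
    demo_image = np.full((256, 256, 3), 128, dtype=np.uint8)
    cv2.circle(demo_image, (128, 140), 60, (40, 30, 20), -1)

    # Offset circular mask to exercise the symmetry correction path.
    demo_mask = np.zeros((256, 256), dtype=np.float32)
    cv2.circle(demo_mask, (120, 140), 58, 1.0, -1)

    processor = EdgeProcessor(EdgeConfig())
    refined = processor.process(demo_image, demo_mask, detect_hair=True)
    logger.info("Refined mask: shape=%s, min=%.3f, max=%.3f",
                refined.shape, float(refined.min()), float(refined.max()))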