"""
Professional Edge Detection & Refinement Module
===============================================

This module provides advanced edge detection, refinement, and processing
specifically optimized for hair segmentation in video processing pipelines.

Features:
- Multi-scale edge detection
- Hair-specific edge refinement
- Temporal edge consistency
- Sub-pixel edge accuracy
- GPU-accelerated processing

Author: BackgroundFX Pro
License: MIT
"""

import os
import cv2
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
from enum import Enum
import time

try:
    import torch
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

# Configure logging.
# NOTE(review): calling basicConfig at import time configures the root logger
# for every application that imports this module; kept for compatibility.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Emitted only after basicConfig: logging through the root logger earlier would
# implicitly configure it at WARNING level and turn basicConfig into a no-op.
if not TORCH_AVAILABLE:
    logger.warning("PyTorch not available - using CPU-only edge detection")


class EdgeDetectionMethod(Enum):
    """Available edge detection methods.

    Only CANNY, HAIR_OPTIMIZED, MULTISCALE and (when PyTorch is importable)
    GPU have detectors registered in EdgeDetectionPipeline; requesting any
    other member makes the pipeline fall back to CANNY with a warning.
    """
    CANNY = "canny"
    SOBEL = "sobel"
    LAPLACIAN = "laplacian"
    SCHARR = "scharr"
    PREWITT = "prewitt"
    ROBERTS = "roberts"
    MULTISCALE = "multiscale"
    HAIR_OPTIMIZED = "hair_optimized"
    # Referenced by EdgeDetectionPipeline._initialize_detectors; without this
    # member constructing a pipeline raised AttributeError whenever torch was
    # installed.
    GPU = "gpu"


@dataclass
class EdgeDetectionResult:
    """Result container for edge detection."""
    edges: np.ndarray            # final (optionally refined) edge map
    confidence_map: np.ndarray   # currently a copy of `edges`
    edge_strength: float         # mean of the positive edge values
    processing_time: float       # wall-clock seconds for detect_edges
    method_used: str             # EdgeDetectionMethod.value actually used
    quality_score: float         # EdgeQualityMetrics.calculate_overall_quality


class EdgeQualityMetrics:
    """Calculate edge quality metrics."""

    @staticmethod
    def calculate_edge_strength(edges: np.ndarray) -> float:
        """Return the mean of the strictly positive edge values (0.0 if none)."""
        positive = edges[edges > 0]
        return float(np.mean(positive)) if positive.size else 0.0

    @staticmethod
    def calculate_edge_density(edges: np.ndarray) -> float:
        """Return the fraction of pixels whose edge value is positive."""
        if edges.size == 0:
            return 0.0
        return float(np.count_nonzero(edges > 0)) / edges.size

    @staticmethod
    def calculate_edge_continuity(edges: np.ndarray) -> float:
        """Return (positive pixels after a 3x3 closing) / (positive pixels before).

        A morphological closing (dilate then erode) bridges small gaps, so
        the ratio grows when the edge map contains many short breaks.
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        dilated = cv2.dilate(edges, kernel, iterations=1)
        eroded = cv2.erode(dilated, kernel, iterations=1)

        original_pixels = np.sum(edges > 0)
        preserved_pixels = np.sum(eroded > 0)
        return float(preserved_pixels / max(original_pixels, 1))

    @staticmethod
    def calculate_edge_thickness_variation(edges: np.ndarray) -> float:
        """Return the coefficient of variation of the distance transform on edges."""
        edge_pixels = edges > 0
        if not np.any(edge_pixels):
            return 0.0
        dist_transform = cv2.distanceTransform(
            edge_pixels.astype(np.uint8), cv2.DIST_L2, 5
        )
        thicknesses = dist_transform[edge_pixels]
        return float(np.std(thicknesses) / (np.mean(thicknesses) + 1e-6))

    @staticmethod
    def calculate_overall_quality(edges: np.ndarray) -> float:
        """Combine strength, density, continuity and thickness into a score <= 1."""
        strength = EdgeQualityMetrics.calculate_edge_strength(edges)
        density = EdgeQualityMetrics.calculate_edge_density(edges)
        continuity = EdgeQualityMetrics.calculate_edge_continuity(edges)
        thickness_var = EdgeQualityMetrics.calculate_edge_thickness_variation(edges)

        # Lower thickness variation is better, hence the (1 - var) term.
        quality = (
            strength * 0.3
            + density * 0.2
            + continuity * 0.4
            + (1.0 - min(thickness_var, 1.0)) * 0.1
        )
        return min(quality, 1.0)


def _resolve_scale_weights(num_scales: int,
                           default_weights: Tuple[float, ...],
                           override: Optional[List[float]] = None) -> List[float]:
    """Return exactly one fusion weight per scale.

    Args:
        num_scales: Number of scales being fused (must be >= 1).
        default_weights: The detector's tuned weights, used when their length
            matches ``num_scales``.
        override: Optional explicit weights; must contain ``num_scales`` values.

    Returns:
        The override, the tuned defaults, or uniform weights for any other
        number of scales (so no scale is silently dropped by ``zip``).

    Raises:
        ValueError: if no scales are given or the override has the wrong length.
    """
    if num_scales <= 0:
        raise ValueError("at least one scale is required")
    if override is not None:
        weights = [float(w) for w in override]
        if len(weights) != num_scales:
            raise ValueError(
                f"expected {num_scales} scale weights, got {len(weights)}"
            )
        return weights
    if len(default_weights) == num_scales:
        return list(default_weights)
    return [1.0 / num_scales] * num_scales


class BaseEdgeDetector:
    """Base class for edge detectors."""

    def __init__(self, name: str):
        self.name = name

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Detect edges in image."""
        raise NotImplementedError

    def get_default_params(self) -> Dict:
        """Get default parameters."""
        return {}


class CannyEdgeDetector(BaseEdgeDetector):
    """Canny edge detector with Otsu-derived adaptive thresholds."""

    def __init__(self):
        super().__init__("Canny")

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Detect edges using Canny and return a float32 map in {0.0, 1.0}.

        Keyword Args:
            low_threshold / high_threshold: Canny hysteresis thresholds. If
                either is None, both are derived from Otsu's threshold T as
                (0.5 * T, T).
            blur_kernel: Gaussian kernel size applied before Canny (0 disables).
            aperture_size: Sobel aperture passed to cv2.Canny.
            l2_gradient: Whether cv2.Canny uses the L2 gradient norm.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        low_threshold = kwargs.get('low_threshold', None)
        high_threshold = kwargs.get('high_threshold', None)

        if low_threshold is None or high_threshold is None:
            # cv2.threshold returns (threshold_value, thresholded_image); with
            # THRESH_OTSU the first element is the computed Otsu threshold.
            otsu_thresh, _ = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
            )
            low_threshold = 0.5 * otsu_thresh
            high_threshold = otsu_thresh

        blur_kernel = kwargs.get('blur_kernel', 5)
        if blur_kernel > 0:
            gray = cv2.GaussianBlur(gray, (blur_kernel, blur_kernel), 0)

        edges = cv2.Canny(
            gray,
            int(low_threshold),
            int(high_threshold),
            apertureSize=kwargs.get('aperture_size', 3),
            L2gradient=kwargs.get('l2_gradient', False)
        )

        return edges.astype(np.float32) / 255.0

    def get_default_params(self) -> Dict:
        return {
            'low_threshold': None,
            'high_threshold': None,
            'blur_kernel': 5,
            'aperture_size': 3,
            'l2_gradient': False
        }


class HairOptimizedEdgeDetector(BaseEdgeDetector):
    """Hair-specific edge detection optimized for fine details."""

    # Fusion weights for the default scales (1.0, 0.7, 1.4): favour scale 1.0.
    _DEFAULT_SCALE_WEIGHTS = (0.5, 0.25, 0.25)

    def __init__(self):
        super().__init__("HairOptimized")

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Detect hair edges by fusing gradient magnitudes over several scales.

        Keyword Args:
            scales: Resize factors to evaluate (default (1.0, 0.7, 1.4)).
            scale_weights: Optional fusion weight per scale; defaults to the
                tuned weights for three scales, uniform otherwise.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        scales = kwargs.get('scales', (1.0, 0.7, 1.4))
        weights = _resolve_scale_weights(
            len(scales), self._DEFAULT_SCALE_WEIGHTS, kwargs.get('scale_weights')
        )
        h, w = gray.shape[:2]
        edge_maps = []

        for scale in scales:
            if scale != 1.0:
                # Clamp to 1 pixel so tiny inputs never request a 0-sized image.
                new_h = max(1, int(h * scale))
                new_w = max(1, int(w * scale))
                scaled_gray = cv2.resize(gray, (new_w, new_h))
            else:
                scaled_gray = gray

            scale_edges = self._detect_single_scale(scaled_gray, **kwargs)

            # Bring every scale back to the input resolution before fusion.
            if scale != 1.0:
                scale_edges = cv2.resize(scale_edges, (w, h))

            edge_maps.append(scale_edges)

        combined_edges = self._combine_edge_maps(edge_maps, weights)
        return self._hair_specific_refinement(combined_edges, gray)

    def _detect_single_scale(self, gray: np.ndarray, **kwargs) -> np.ndarray:
        """Return a max-normalised blend of Sobel and Scharr gradient magnitudes."""
        sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        sobel_magnitude = np.sqrt(sobel_x**2 + sobel_y**2)

        # Scharr has better rotational accuracy for fine detail.
        scharr_x = cv2.Scharr(gray, cv2.CV_64F, 1, 0)
        scharr_y = cv2.Scharr(gray, cv2.CV_64F, 0, 1)
        scharr_magnitude = np.sqrt(scharr_x**2 + scharr_y**2)

        combined = 0.6 * sobel_magnitude + 0.4 * scharr_magnitude
        combined = combined / (np.max(combined) + 1e-6)
        return combined.astype(np.float32)

    def _combine_edge_maps(self, edge_maps: List[np.ndarray],
                           weights: Optional[List[float]] = None) -> np.ndarray:
        """Return the weighted sum of same-shaped edge maps.

        ``weights`` defaults to the tuned three-scale weights (uniform for any
        other count) and must have one entry per map.
        """
        weights = _resolve_scale_weights(
            len(edge_maps), self._DEFAULT_SCALE_WEIGHTS, weights
        )
        combined = np.zeros_like(edge_maps[0])
        for edge_map, weight in zip(edge_maps, weights):
            combined += edge_map * weight
        return combined

    def _hair_specific_refinement(self, edges: np.ndarray, original: np.ndarray) -> np.ndarray:
        """Boost thin horizontal line responses, then apply 3x3 non-max suppression."""
        # Zero-sum line kernel: responds to a bright row between darker rows.
        kernel_thin = np.array([[-1, -1, -1],
                                [ 2,  2,  2],
                                [-1, -1, -1]]) / 3.0
        thin_enhanced = cv2.filter2D(edges, -1, kernel_thin)

        refined = 0.7 * edges + 0.3 * np.abs(thin_enhanced)
        return self._thin_edge_nms(refined)

    def _thin_edge_nms(self, edges: np.ndarray) -> np.ndarray:
        """Keep only pixels equal to the maximum of their 3x3 neighbourhood."""
        kernel = np.ones((3, 3), np.uint8)
        dilated = cv2.dilate(edges, kernel, iterations=1)
        return np.where(edges == dilated, edges, 0)


class MultiScaleEdgeDetector(BaseEdgeDetector):
    """Multi-scale Canny edge detection with weighted scale fusion."""

    # Fusion weights for the default scales (0.5, 1.0, 1.5, 2.0): favour the middle.
    _DEFAULT_SCALE_WEIGHTS = (0.1, 0.4, 0.3, 0.2)

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Blur at sigma = sigma_base * scale, run Canny per scale, and fuse.

        Keyword Args:
            scales: Blur scales (default (0.5, 1.0, 1.5, 2.0)); must be > 0.
            sigma_base: Base Gaussian sigma (default 1.0).
            scale_weights: Optional fusion weight per scale; defaults to the
                tuned weights for four scales, uniform otherwise.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        scales = kwargs.get('scales', (0.5, 1.0, 1.5, 2.0))
        sigma_base = kwargs.get('sigma_base', 1.0)
        weights = _resolve_scale_weights(
            len(scales), self._DEFAULT_SCALE_WEIGHTS, kwargs.get('scale_weights')
        )

        combined_edges = None
        for scale, weight in zip(scales, weights):
            blurred = cv2.GaussianBlur(gray, (0, 0), sigma_base * scale)
            # Thresholds shrink as blur grows, since blurring weakens gradients.
            edges = cv2.Canny(
                blurred,
                int(50 / scale),
                int(150 / scale),
                apertureSize=3
            ).astype(np.float32) / 255.0
            if combined_edges is None:
                combined_edges = np.zeros_like(edges)
            combined_edges += edges * weight

        return combined_edges

    def __init__(self):
        super().__init__("MultiScale")


class GPUEdgeDetector(BaseEdgeDetector):
    """Sobel-magnitude edge detection on GPU (or CPU) via PyTorch."""

    def __init__(self):
        super().__init__("GPU")
        if TORCH_AVAILABLE:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            # torch is unbound here; touching it would raise NameError.
            self.device = None
            logger.warning("PyTorch not available - GPU edge detection disabled")

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Return a float32 binary map of Sobel magnitude > ``threshold``.

        Falls back to CannyEdgeDetector when PyTorch is unavailable.
        NOTE(review): the /255 scaling assumes 8-bit input images.
        """
        if not TORCH_AVAILABLE:
            return CannyEdgeDetector().detect(image, **kwargs)

        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        tensor = torch.from_numpy(np.ascontiguousarray(gray)).float()
        tensor = tensor.unsqueeze(0).unsqueeze(0).to(self.device) / 255.0

        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=torch.float32).view(1, 1, 3, 3).to(self.device)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=torch.float32).view(1, 1, 3, 3).to(self.device)

        grad_x = F.conv2d(tensor, sobel_x, padding=1)
        grad_y = F.conv2d(tensor, sobel_y, padding=1)
        magnitude = torch.sqrt(grad_x**2 + grad_y**2)

        threshold = kwargs.get('threshold', 0.1)
        edges = (magnitude > threshold).float()

        # Index the batch/channel dims explicitly; squeeze() would also drop a
        # spatial dimension of size 1.
        return edges[0, 0].cpu().numpy()


class TemporalEdgeConsistency:
    """Blend each frame's edge map with a weighted history of earlier frames."""

    def __init__(self, memory_frames: int = 3, consistency_threshold: float = 0.1,
                 consistency_factor: float = 0.3):
        """
        Args:
            memory_frames: Maximum number of past edge maps kept for blending.
            consistency_threshold: Stored for configuration compatibility; not
                used by the current blending logic.
            consistency_factor: Weight of the historical average in the blend
                (0 returns the current frame, 1 returns the history).
        """
        self.memory_frames = memory_frames
        self.consistency_threshold = consistency_threshold
        self.consistency_factor = consistency_factor
        self.frame_buffer = []

    def reset(self) -> None:
        """Forget all stored frames (e.g. when starting a new video)."""
        self.frame_buffer.clear()

    def apply_temporal_consistency(self, current_edges: np.ndarray) -> np.ndarray:
        """Return the current edges blended with history, then store them."""
        # History of a different resolution cannot be blended (it would fail
        # to broadcast), so start over after a resolution change.
        if self.frame_buffer and self.frame_buffer[-1].shape != current_edges.shape:
            self.reset()

        if not self.frame_buffer:
            self.frame_buffer.append(current_edges.copy())
            return current_edges

        consistent_edges = self._calculate_consistent_edges(current_edges)

        self.frame_buffer.append(current_edges.copy())
        if len(self.frame_buffer) > self.memory_frames:
            self.frame_buffer.pop(0)

        return consistent_edges

    def _calculate_consistent_edges(self, current_edges: np.ndarray) -> np.ndarray:
        """Blend current edges with a recency-weighted average of the buffer."""
        # Linearly increasing weights: the most recent stored frame counts most.
        weights = np.linspace(0.1, 0.9, len(self.frame_buffer))
        weights = weights / np.sum(weights)

        # Float accumulator so integer-typed maps do not break in-place adds.
        avg_previous = np.zeros_like(
            current_edges, dtype=np.result_type(current_edges.dtype, np.float32)
        )
        for frame, weight in zip(self.frame_buffer, weights):
            avg_previous += frame * weight

        factor = self.consistency_factor
        return (1 - factor) * current_edges + factor * avg_previous


class EdgeRefinementProcessor:
    """Post-process edge maps for better quality."""

    @staticmethod
    def remove_noise(edges: np.ndarray, min_area: int = 10) -> np.ndarray:
        """Zero out 8-connected edge components smaller than ``min_area`` pixels.

        Values are clipped to [0, 1] before the uint8 conversion used for
        labelling, so responses above 1 no longer wrap around (e.g. 1.2 * 255
        would otherwise become 50). Surviving pixels keep their original values.
        """
        edges_uint8 = (np.clip(edges, 0.0, 1.0) * 255).astype(np.uint8)
        _num_labels, labels, stats, _centroids = cv2.connectedComponentsWithStats(
            edges_uint8, connectivity=8
        )

        # Vectorised per-label filter instead of one full-image mask per label.
        keep = stats[:, cv2.CC_STAT_AREA] >= min_area
        keep[0] = False  # label 0 is the background
        mask = keep[labels]

        filtered_edges = np.zeros_like(edges)
        filtered_edges[mask] = edges[mask]
        return filtered_edges

    @staticmethod
    def smooth_edges(edges: np.ndarray, iterations: int = 1) -> np.ndarray:
        """Apply a 3x3 Gaussian blur (sigma 0.5) ``iterations`` times."""
        smoothed = edges.copy()
        for _ in range(iterations):
            smoothed = cv2.GaussianBlur(smoothed, (3, 3), 0.5)
        return smoothed

    @staticmethod
    def enhance_hair_edges(edges: np.ndarray, original_image: np.ndarray) -> np.ndarray:
        """Scale edges by up to 1.5x where the image has strongly oriented structure.

        Orientation strength comes from the eigenvalues l1, l2 of the smoothed
        gradient structure tensor:
        ((sqrt(l1) - sqrt(l2))^2 / (l1 + l2))^2, which is 0 for isotropic
        regions and 1 for a single dominant direction.
        """
        if len(original_image.shape) == 3:
            gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
        else:
            gray = original_image

        grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)

        # Structure tensor components.
        J11 = cv2.GaussianBlur(grad_x * grad_x, (5, 5), 1.0)
        J22 = cv2.GaussianBlur(grad_y * grad_y, (5, 5), 1.0)
        J12 = cv2.GaussianBlur(grad_x * grad_y, (5, 5), 1.0)

        trace = J11 + J22
        # det >= 0 mathematically, but rounding in nearly one-directional
        # regions can make it slightly negative, which would give NaN in sqrt.
        det = np.maximum(J11 * J22 - J12 * J12, 0.0)

        # Epsilon only in the denominator: flat regions (trace == 0) now give 0.
        # Previously they evaluated to ~4e6 and swamped the normalisation below.
        coherence = (trace - 2.0 * np.sqrt(det)) ** 2 / (trace ** 2 + 1e-6)
        coherence = coherence / (np.max(coherence) + 1e-6)

        enhanced_edges = edges * (1.0 + coherence * 0.5)
        return np.clip(enhanced_edges, 0, 1)


class EdgeDetectionPipeline:
    """Main edge detection pipeline with multiple methods and post-processing.

    Config keys read: temporal_memory, consistency_threshold,
    consistency_factor, min_edge_area, smoothing_iterations.
    """

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or {}
        self.detectors = {}
        self.temporal_processor = TemporalEdgeConsistency(
            memory_frames=self.config.get('temporal_memory', 3),
            consistency_threshold=self.config.get('consistency_threshold', 0.1),
            consistency_factor=self.config.get('consistency_factor', 0.3)
        )
        self.refinement_processor = EdgeRefinementProcessor()
        self._initialize_detectors()

    def _initialize_detectors(self):
        """Register the available edge detectors, keyed by method."""
        self.detectors[EdgeDetectionMethod.CANNY] = CannyEdgeDetector()
        self.detectors[EdgeDetectionMethod.HAIR_OPTIMIZED] = HairOptimizedEdgeDetector()
        self.detectors[EdgeDetectionMethod.MULTISCALE] = MultiScaleEdgeDetector()

        if TORCH_AVAILABLE:
            self.detectors[EdgeDetectionMethod.GPU] = GPUEdgeDetector()

    def detect_edges(self, image: np.ndarray,
                     method: EdgeDetectionMethod = EdgeDetectionMethod.HAIR_OPTIMIZED,
                     apply_temporal_consistency: bool = True,
                     apply_refinement: bool = True,
                     **kwargs) -> EdgeDetectionResult:
        """Detect edges with ``method``, then optionally blend and refine.

        Unregistered methods, and detectors that raise, fall back to Canny.
        ``kwargs`` are forwarded to the detector. Refinement runs noise
        removal, smoothing and hair enhancement, in that order.
        """
        start_time = time.perf_counter()

        if method not in self.detectors:
            logger.warning(f"Method {method} not available, using Canny")
            method = EdgeDetectionMethod.CANNY

        detector = self.detectors[method]

        try:
            edges = detector.detect(image, **kwargs)
        except Exception:
            logger.exception(f"Edge detection failed with {method.value}")
            edges = self.detectors[EdgeDetectionMethod.CANNY].detect(image, **kwargs)
            method = EdgeDetectionMethod.CANNY

        if apply_temporal_consistency:
            edges = self.temporal_processor.apply_temporal_consistency(edges)

        if apply_refinement:
            edges = self.refinement_processor.remove_noise(
                edges,
                min_area=self.config.get('min_edge_area', 10)
            )
            edges = self.refinement_processor.smooth_edges(
                edges,
                iterations=self.config.get('smoothing_iterations', 1)
            )
            edges = self.refinement_processor.enhance_hair_edges(edges, image)

        processing_time = time.perf_counter() - start_time
        quality_score = EdgeQualityMetrics.calculate_overall_quality(edges)
        edge_strength = EdgeQualityMetrics.calculate_edge_strength(edges)

        # Confidence is currently just the edge map itself.
        confidence_map = edges.copy()

        return EdgeDetectionResult(
            edges=edges,
            confidence_map=confidence_map,
            edge_strength=edge_strength,
            processing_time=processing_time,
            method_used=method.value,
            quality_score=quality_score
        )

    def get_best_method_for_image(self, image: np.ndarray) -> EdgeDetectionMethod:
        """Pick a method from the grayscale standard deviation and mean."""
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        contrast = np.std(gray)
        brightness = np.mean(gray)

        # High contrast images work well with Canny.
        if contrast > 50:
            return EdgeDetectionMethod.CANNY

        # Low contrast or dark images benefit from the hair-optimized detector.
        if contrast < 20 or brightness < 50:
            return EdgeDetectionMethod.HAIR_OPTIMIZED

        return EdgeDetectionMethod.MULTISCALE


# Convenience functions
def detect_hair_edges(image: np.ndarray, config: Optional[Dict] = None) -> EdgeDetectionResult:
    """Detect hair edges on a single image (no temporal blending, with refinement)."""
    pipeline = EdgeDetectionPipeline(config)
    return pipeline.detect_edges(
        image,
        method=EdgeDetectionMethod.HAIR_OPTIMIZED,
        apply_temporal_consistency=False,
        apply_refinement=True
    )


def detect_video_edges(frames: List[np.ndarray], config: Optional[Dict] = None) -> List[EdgeDetectionResult]:
    """Detect edges in a frame sequence, sharing temporal state across frames."""
    pipeline = EdgeDetectionPipeline(config)
    return [
        pipeline.detect_edges(
            frame,
            method=EdgeDetectionMethod.HAIR_OPTIMIZED,
            apply_temporal_consistency=True,
            apply_refinement=True
        )
        for frame in frames
    ]


# Example usage and testing
if __name__ == "__main__":
    # Test with synthetic image
    test_image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

    config = {
        'temporal_memory': 3,
        'consistency_threshold': 0.1,
        'min_edge_area': 10,
        'smoothing_iterations': 1
    }

    pipeline = EdgeDetectionPipeline(config)

    methods = [
        EdgeDetectionMethod.CANNY,
        EdgeDetectionMethod.HAIR_OPTIMIZED,
        EdgeDetectionMethod.MULTISCALE
    ]

    for method in methods:
        if method in pipeline.detectors:
            result = pipeline.detect_edges(test_image, method=method)
            print(f"\n{method.value} Results:")
            print(f"  Edge strength: {result.edge_strength:.3f}")
            print(f"  Quality score: {result.quality_score:.3f}")
            print(f"  Processing time: {result.processing_time:.3f}s")

    # Test automatic method selection
    best_method = pipeline.get_best_method_for_image(test_image)
    print(f"\nBest method for this image: {best_method.value}")