""" Professional Hair Segmentation Module ===================================== This module provides high-quality hair segmentation for video processing using SAM2 + MatAnyone pipeline with comprehensive error handling and fallbacks. Author: BackgroundFX Pro License: MIT """ import os import torch import cv2 import numpy as np import logging from typing import Dict, List, Tuple, Optional, Union from pathlib import Path import warnings from dataclasses import dataclass from abc import ABC, abstractmethod # Fix threading issues immediately os.environ['OMP_NUM_THREADS'] = '4' os.environ['MKL_NUM_THREADS'] = '4' # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) @dataclass class SegmentationResult: """Result container for hair segmentation""" mask: np.ndarray confidence: float coverage_percent: float asymmetry_score: float processing_time: float fallback_used: bool quality_score: float error_message: Optional[str] = None class BaseSegmentationModel(ABC): """Abstract base class for segmentation models""" @abstractmethod def initialize(self) -> bool: """Initialize the model""" pass @abstractmethod def segment(self, frame: np.ndarray) -> np.ndarray: """Segment hair in frame""" pass @abstractmethod def get_model_name(self) -> str: """Get model name for logging""" pass class SAM2Model(BaseSegmentationModel): """SAM2 segmentation model wrapper""" def __init__(self, model_path: Optional[str] = None, device: str = 'auto'): self.model_path = model_path self.device = self._get_best_device(device) self.predictor = None self.initialized = False def _get_best_device(self, device: str) -> str: """Determine best available device""" if device == 'auto': return 'cuda' if torch.cuda.is_available() else 'cpu' return device def initialize(self) -> bool: """Initialize SAM2 model""" try: logger.info("🤖 Initializing SAM2 model...") # Import SAM2 (handle different installation methods) try: from sam2.build_sam import build_sam2_video_predictor except ImportError: logger.error("SAM2 not found. Please install SAM2.") return False # Build predictor if self.model_path and Path(self.model_path).exists(): self.predictor = build_sam2_video_predictor(self.model_path, device=self.device) else: # Use default model self.predictor = build_sam2_video_predictor("sam2_hiera_large.pt", device=self.device) self.initialized = True logger.info(f"✅ SAM2 initialized on {self.device}") return True except Exception as e: logger.error(f"❌ SAM2 initialization failed: {e}") return False def segment(self, frame: np.ndarray) -> np.ndarray: """Segment using SAM2""" if not self.initialized: raise RuntimeError("SAM2 model not initialized") try: # Convert BGR to RGB if len(frame.shape) == 3: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) else: frame_rgb = frame # Set image for SAM2 self.predictor.set_image(frame_rgb) # Auto-detect person in center (you can make this more sophisticated) height, width = frame_rgb.shape[:2] center_point = np.array([[width//2, height//2]]) # Predict mask masks, scores, _ = self.predictor.predict( point_coords=center_point, point_labels=np.array([1]) ) # Return best mask if len(masks) > 0: best_mask_idx = np.argmax(scores) return masks[best_mask_idx].astype(np.float32) else: return np.zeros((height, width), dtype=np.float32) except Exception as e: logger.error(f"SAM2 segmentation failed: {e}") raise def get_model_name(self) -> str: return "SAM2" class MatAnyoneModel(BaseSegmentationModel): """MatAnyone model wrapper with quality checking""" def __init__(self, use_hf_api: bool = True, hf_token: Optional[str] = None): self.use_hf_api = use_hf_api self.hf_token = hf_token self.client = None self.processor = None self.initialized = False self.quality_threshold = 0.3 def initialize(self) -> bool: """Initialize MatAnyone model""" try: logger.info("🎭 Initializing MatAnyone model...") if self.use_hf_api: from gradio_client import Client self.client = Client("PeiqingYang/MatAnyone", hf_token=self.hf_token) logger.info("✅ MatAnyone HF API initialized") else: # Local MatAnyone initialization would go here logger.warning("Local MatAnyone not implemented yet") return False self.initialized = True return True except Exception as e: logger.error(f"❌ MatAnyone initialization failed: {e}") return False def segment(self, frame: np.ndarray) -> np.ndarray: """MatAnyone is primarily for matting, not segmentation""" raise NotImplementedError("MatAnyone is used for matting, not direct segmentation") def matte(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray: """Apply matting using MatAnyone""" if not self.initialized: raise RuntimeError("MatAnyone model not initialized") try: # Save temporary files import tempfile with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as img_file: cv2.imwrite(img_file.name, image) img_path = img_file.name with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tri_file: cv2.imwrite(tri_file.name, trimap) tri_path = tri_file.name # Process with MatAnyone if self.use_hf_api: result = self._process_hf_api(img_path, tri_path) else: result = self._process_local(img_path, tri_path) # Cleanup temp files os.unlink(img_path) os.unlink(tri_path) return result except Exception as e: logger.error(f"MatAnyone matting failed: {e}") raise def _process_hf_api(self, image_path: str, trimap_path: str) -> np.ndarray: """Process using HuggingFace API""" try: result = self.client.predict( image=image_path, trimap=trimap_path, api_name="/predict" ) # Load result if isinstance(result, str): result_image = cv2.imread(result) return result_image else: return result except Exception as e: logger.error(f"HF API processing failed: {e}") raise def _process_local(self, image_path: str, trimap_path: str) -> np.ndarray: """Process locally - placeholder for implementation""" raise NotImplementedError("Local MatAnyone processing not implemented") def get_model_name(self) -> str: return "MatAnyone" class TraditionalCVModel(BaseSegmentationModel): """Traditional computer vision fallback""" def __init__(self): self.initialized = False def initialize(self) -> bool: """Initialize traditional CV methods""" self.initialized = True return True def segment(self, frame: np.ndarray) -> np.ndarray: """Traditional hair segmentation using color and texture""" try: # Convert to different color spaces hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV) lab = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB) # Hair color detection hair_mask_hsv = self._detect_hair_hsv(hsv) hair_mask_lab = self._detect_hair_lab(lab) # Combine masks combined_mask = cv2.bitwise_or(hair_mask_hsv, hair_mask_lab) # Morphological operations (using OpenCV instead of skimage) kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)) combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel) combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel) return combined_mask.astype(np.float32) / 255.0 except Exception as e: logger.error(f"Traditional CV segmentation failed: {e}") raise def _detect_hair_hsv(self, hsv: np.ndarray) -> np.ndarray: """Detect hair in HSV color space""" # Multiple hair color ranges ranges = [ # Dark hair ([0, 0, 0], [180, 255, 80]), # Brown hair ([8, 50, 20], [25, 255, 200]), # Blonde hair ([15, 30, 100], [35, 255, 255]) ] masks = [] for lower, upper in ranges: mask = cv2.inRange(hsv, np.array(lower), np.array(upper)) masks.append(mask) # Combine all color ranges final_mask = masks[0] for mask in masks[1:]: final_mask = cv2.bitwise_or(final_mask, mask) return final_mask def _detect_hair_lab(self, lab: np.ndarray) -> np.ndarray: """Detect hair in LAB color space""" l_channel = lab[:, :, 0] hair_mask = cv2.inRange(l_channel, 0, 120) return hair_mask def get_model_name(self) -> str: return "TraditionalCV" class TemporalSmoother: """Temporal smoothing for video sequences""" def __init__(self, smoothing_factor: float = 0.7, change_threshold: float = 0.05): self.smoothing_factor = smoothing_factor self.change_threshold = change_threshold self.previous_mask = None self.correction_count = 0 self.total_frames = 0 def smooth(self, current_mask: np.ndarray) -> Tuple[np.ndarray, bool]: """Apply temporal smoothing""" self.total_frames += 1 corrected = False if self.previous_mask is not None: # Calculate change diff = np.mean(np.abs(current_mask - self.previous_mask)) if diff > self.change_threshold: # Apply smoothing smoothed_mask = (self.smoothing_factor * current_mask + (1 - self.smoothing_factor) * self.previous_mask) self.correction_count += 1 corrected = True else: smoothed_mask = current_mask else: smoothed_mask = current_mask self.previous_mask = smoothed_mask.copy() return smoothed_mask, corrected def get_correction_ratio(self) -> float: """Get ratio of frames that needed correction""" return self.correction_count / max(self.total_frames, 1) class HairSegmentationPipeline: """Main hair segmentation pipeline with multiple models and fallbacks""" def __init__(self, config: Optional[Dict] = None): self.config = config or {} self.models = {} self.active_model = None self.fallback_models = [] self.temporal_smoother = TemporalSmoother() self.initialized = False # Setup models self._setup_models() def _setup_models(self): """Setup available models""" try: # Primary model: SAM2 sam2_model = SAM2Model( model_path=self.config.get('sam2_model_path'), device=self.config.get('device', 'auto') ) self.models['sam2'] = sam2_model # MatAnyone for matting matanyone_model = MatAnyoneModel( use_hf_api=self.config.get('use_hf_api', True), hf_token=self.config.get('hf_token') ) self.models['matanyone'] = matanyone_model # Fallback: Traditional CV traditional_model = TraditionalCVModel() self.models['traditional'] = traditional_model except Exception as e: logger.error(f"Model setup failed: {e}") def initialize(self, preferred_model: str = 'sam2') -> bool: """Initialize the pipeline""" logger.info("🚀 Initializing Hair Segmentation Pipeline...") # Try to initialize preferred model if preferred_model in self.models: if self.models[preferred_model].initialize(): self.active_model = preferred_model logger.info(f"✅ Primary model {preferred_model} initialized") else: logger.warning(f"⚠️ Primary model {preferred_model} failed") # Initialize fallback models for model_name, model in self.models.items(): if model_name != self.active_model: if model.initialize(): self.fallback_models.append(model_name) logger.info(f"✅ Fallback model {model_name} ready") # Check if we have at least one working model if self.active_model or self.fallback_models: self.initialized = True logger.info(f"🎯 Pipeline ready - Active: {self.active_model}, Fallbacks: {self.fallback_models}") return True else: logger.error("❌ No working models available") return False def segment_frame(self, frame: np.ndarray, apply_temporal_smoothing: bool = True) -> SegmentationResult: """Segment hair in a single frame""" if not self.initialized: raise RuntimeError("Pipeline not initialized") import time start_time = time.time() # Try active model first mask, model_used, error_msg = self._try_segment_with_model(frame, self.active_model) # If failed, try fallback models if mask is None: for fallback_model in self.fallback_models: mask, model_used, error_msg = self._try_segment_with_model(frame, fallback_model) if mask is not None: break if mask is None: # Complete failure - return empty mask h, w = frame.shape[:2] mask = np.zeros((h, w), dtype=np.float32) model_used = "none" error_msg = "All models failed" # Apply temporal smoothing corrected = False if apply_temporal_smoothing: mask, corrected = self.temporal_smoother.smooth(mask) # Calculate metrics processing_time = time.time() - start_time confidence = self._calculate_confidence(mask) coverage = self._calculate_coverage(mask) asymmetry = self._calculate_asymmetry(mask) quality = self._calculate_quality(mask) return SegmentationResult( mask=mask, confidence=confidence, coverage_percent=coverage, asymmetry_score=asymmetry, processing_time=processing_time, fallback_used=(model_used != self.active_model), quality_score=quality, error_message=error_msg ) def _try_segment_with_model(self, frame: np.ndarray, model_name: str) -> Tuple[Optional[np.ndarray], str, Optional[str]]: """Try to segment with a specific model""" if model_name not in self.models: return None, model_name, f"Model {model_name} not available" try: mask = self.models[model_name].segment(frame) return mask, model_name, None except Exception as e: error_msg = f"Model {model_name} failed: {str(e)}" logger.warning(error_msg) return None, model_name, error_msg def _calculate_confidence(self, mask: np.ndarray) -> float: """Calculate mask confidence using OpenCV instead of skimage""" # Edge sharpness edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150) edge_ratio = np.sum(edges > 0) / mask.size # Mask smoothness using OpenCV Sobel instead of skimage gradient grad_x = cv2.Sobel(mask, cv2.CV_64F, 1, 0, ksize=3) grad_y = cv2.Sobel(mask, cv2.CV_64F, 0, 1, ksize=3) gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2) smoothness = 1.0 / (1.0 + np.std(gradient_magnitude)) return min(edge_ratio * 0.3 + smoothness * 0.7, 1.0) def _calculate_coverage(self, mask: np.ndarray) -> float: """Calculate hair coverage percentage""" return (np.sum(mask > 0.5) / mask.size) * 100 def _calculate_asymmetry(self, mask: np.ndarray) -> float: """Calculate left-right asymmetry score""" h, w = mask.shape[:2] center_x = w // 2 left_half = mask[:, :center_x] right_half = np.fliplr(mask[:, center_x:]) min_width = min(left_half.shape[1], right_half.shape[1]) left_half = left_half[:, :min_width] right_half = right_half[:, :min_width] return np.mean(np.abs(left_half - right_half)) def _calculate_quality(self, mask: np.ndarray) -> float: """Calculate overall mask quality""" # Combine multiple quality metrics confidence = self._calculate_confidence(mask) coverage = self._calculate_coverage(mask) / 100.0 asymmetry_penalty = 1.0 - min(self._calculate_asymmetry(mask), 1.0) return (confidence * 0.5 + coverage * 0.3 + asymmetry_penalty * 0.2) def get_pipeline_stats(self) -> Dict: """Get pipeline performance statistics""" return { 'active_model': self.active_model, 'fallback_models': self.fallback_models, 'temporal_correction_ratio': self.temporal_smoother.get_correction_ratio(), 'total_frames_processed': self.temporal_smoother.total_frames, 'corrections_applied': self.temporal_smoother.correction_count } # Convenience functions def create_pipeline(config: Optional[Dict] = None) -> HairSegmentationPipeline: """Create and initialize hair segmentation pipeline""" pipeline = HairSegmentationPipeline(config) pipeline.initialize() return pipeline def segment_image(image_path: str, config: Optional[Dict] = None) -> SegmentationResult: """Segment hair in a single image""" pipeline = create_pipeline(config) frame = cv2.imread(image_path) return pipeline.segment_frame(frame) def segment_video_frames(video_frames: List[np.ndarray], config: Optional[Dict] = None) -> List[SegmentationResult]: """Segment hair in multiple video frames""" pipeline = create_pipeline(config) results = [] for frame in video_frames: result = pipeline.segment_frame(frame) results.append(result) return results # Example usage if __name__ == "__main__": # Example configuration config = { 'sam2_model_path': None, # Use default 'device': 'auto', 'use_hf_api': True, 'hf_token': None # Set your token if needed } # Create pipeline pipeline = create_pipeline(config) # Test with example frame (you would load your actual frame) test_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) # Segment frame result = pipeline.segment_frame(test_frame) # Print results print(f"Segmentation Results:") print(f" Coverage: {result.coverage_percent:.1f}%") print(f" Confidence: {result.confidence:.3f}") print(f" Quality: {result.quality_score:.3f}") print(f" Processing time: {result.processing_time:.2f}s") print(f" Fallback used: {result.fallback_used}") # Get pipeline stats stats = pipeline.get_pipeline_stats() print(f"\nPipeline Stats:") print(f" Active model: {stats['active_model']}") print(f" Fallbacks: {stats['fallback_models']}") print(f" Correction ratio: {stats['temporal_correction_ratio']:.3f}")