VideoBackgroundReplacer / core /hair_segmentation.py
MogensR's picture
Update core/hair_segmentation.py
cb75d8c
raw
history blame
29.2 kB
"""
Advanced hair segmentation pipeline for BackgroundFX Pro.
Specialized module for accurate hair detection and segmentation.
"""
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
import logging
from scipy import ndimage
# from skimage import morphology, filters
logger = logging.getLogger(__name__)
@dataclass
class HairConfig:
    """Configuration for hair segmentation.

    Every field has a default, so ``HairConfig()`` yields a usable
    configuration. New fields are appended at the end so positional
    construction keeps working.
    """

    # Threshold applied to the combined colour/texture probability.
    min_hair_confidence: float = 0.6
    # Scales the threshold used on the Gabor-based hair edge response.
    edge_sensitivity: float = 0.8
    # Enables the fine strand detection step.
    strand_detection: bool = True
    # Thickness (pixels) used when drawing detected strand segments.
    strand_thickness: int = 2
    # Enables left/right asymmetry detection and correction.
    asymmetry_correction: bool = True
    # Left/right area ratio above which a mask is treated as asymmetric.
    max_asymmetry_ratio: float = 1.5
    # Enables the optional neural feature branch.
    use_deep_features: bool = False
    # Number of morphological refinement passes.
    refinement_iterations: int = 3
    # Enables the guided-filter matting step.
    alpha_matting: bool = True
    # True: bilateral final smoothing; False: Gaussian final smoothing.
    preserve_details: bool = True
    # Sigma of the Gaussian used when ``preserve_details`` is False.
    smoothing_sigma: float = 1.0
    # Asymmetry score above which a mask is treated as asymmetric. The
    # asymmetry detector reads this attribute, but it was missing from the
    # dataclass, which made detection fail with AttributeError.
    # NOTE(review): 0.3 is a provisional default - tune on real data.
    symmetry_threshold: float = 0.3
class HairSegmentationPipeline:
    """Complete hair segmentation pipeline.

    Combines colour and texture cues with optional deep features, asymmetry
    correction, strand detection, mask refinement, edge enhancement,
    guided-filter matting and a final smoothing pass.
    """

    # HSV ranges (lower, upper) treated as hair colours.
    _HSV_HAIR_RANGES = (
        ((0, 0, 0), (180, 255, 30)),       # black
        ((10, 20, 20), (20, 255, 100)),    # brown
        ((15, 30, 50), (25, 255, 200)),    # blonde
        ((0, 50, 50), (10, 255, 150)),     # red / auburn, low hues
        ((160, 50, 50), (180, 255, 150)),  # red / auburn, high hues
        ((0, 0, 50), (180, 30, 200)),      # gray / white
    )

    def __init__(self, config: Optional["HairConfig"] = None):
        """Create the pipeline and its helper stages.

        Args:
            config: Pipeline settings; a default ``HairConfig`` is created
                when omitted.
        """
        self.config = config or HairConfig()
        # Give every stage the *resolved* configuration. Forwarding the raw
        # argument handed them ``None`` whenever the default was used.
        self.mask_refiner = HairMaskRefiner(self.config)
        self.asymmetry_detector = AsymmetryDetector(self.config)
        self.edge_enhancer = HairEdgeEnhancer(self.config)
        # Optional deep learning model
        self.deep_model = None
        if self.config.use_deep_features:
            self.deep_model = HairNet()

    def segment(self, image: np.ndarray,
                initial_mask: Optional[np.ndarray] = None,
                prompts: Optional[Dict] = None) -> Dict[str, Optional[np.ndarray]]:
        """Perform complete hair segmentation.

        Args:
            image: Input image; the colour cue converts it with
                ``COLOR_BGR2HSV``, so a 3-channel BGR image is expected.
            initial_mask: Optional prior mask used to constrain detection and
                to weight the confidence map. Values above 1 are assumed to
                be on an 8-bit 0-255 scale and are rescaled to [0, 1].
            prompts: Accepted for interface compatibility; currently unused.

        Returns:
            Dictionary containing:
                - 'mask': Final hair mask
                - 'confidence': Confidence map
                - 'strands': Fine hair strands mask (``None`` when strand
                  detection is disabled)
                - 'edges': Hair edge map
        """
        if initial_mask is not None:
            initial_mask = self._normalize_mask(initial_mask)

        # 1. Initial hair detection
        hair_regions = self._detect_hair_regions(image, initial_mask)

        # 2. Deep feature extraction (if enabled)
        if self.deep_model is not None and self.config.use_deep_features:
            deep_features = self._extract_deep_features(image)
            hair_regions = self._enhance_with_deep_features(hair_regions, deep_features)

        # 3. Detect and correct asymmetry
        if self.config.asymmetry_correction:
            asymmetry_info = self.asymmetry_detector.detect(hair_regions, image)
            if asymmetry_info['is_asymmetric']:
                logger.info("Correcting hair asymmetry: %.3f", asymmetry_info['score'])
                hair_regions = self.asymmetry_detector.correct(
                    hair_regions, asymmetry_info
                )

        # 4. Detect fine hair strands and merge them into the main mask
        strands_mask = None
        if self.config.strand_detection:
            strands_mask = self._detect_hair_strands(image, hair_regions)
            hair_regions = self._integrate_strands(hair_regions, strands_mask)

        # 5. Refine mask
        refined_mask = self.mask_refiner.refine(image, hair_regions)

        # 6. Edge enhancement
        edges = self.edge_enhancer.enhance(refined_mask, image)
        refined_mask = self._apply_edge_enhancement(refined_mask, edges)

        # 7. Alpha matting (if enabled)
        if self.config.alpha_matting:
            refined_mask = self._apply_alpha_matting(image, refined_mask)

        # 8. Final smoothing
        final_mask = self._final_smoothing(refined_mask)

        # 9. Compute confidence
        confidence = self._compute_confidence(final_mask, initial_mask)

        return {
            'mask': final_mask,
            'confidence': confidence,
            'strands': strands_mask,
            'edges': edges
        }

    @staticmethod
    def _normalize_mask(mask: np.ndarray) -> np.ndarray:
        """Return ``mask`` as float32 in [0, 1].

        Masks whose maximum exceeds 1 are assumed to be 8-bit (0-255) and are
        divided by 255. NOTE(review): masks with another value range (for
        example 16-bit) would need different scaling.
        """
        normalized = np.asarray(mask, dtype=np.float32)
        if normalized.size and normalized.max() > 1.0:
            normalized = normalized / 255.0
        return np.clip(normalized, 0.0, 1.0)

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Convert a 3-dimensional BGR image to grayscale; pass 2-D through."""
        if image.ndim == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return image

    def _detect_hair_regions(self, image: np.ndarray,
                             initial_mask: Optional[np.ndarray]) -> np.ndarray:
        """Detect hair regions from colour and texture cues.

        The two cues are mixed 60/40, optionally gated by a dilated version
        of ``initial_mask``, thresholded with ``min_hair_confidence`` and
        cleaned of small connected components.
        """
        color_mask = self._detect_by_color(image)
        texture_mask = self._detect_by_texture(image)
        hair_probability = 0.6 * color_mask + 0.4 * texture_mask

        if initial_mask is not None:
            # Normalising first keeps a 0/255 prior from scaling the
            # probability by 255 and pushing it past the threshold.
            prior = self._normalize_mask(initial_mask)
            # Dilate the prior slightly so hair edges are not cut off.
            kernel = np.ones((15, 15), np.uint8)
            dilated_initial = cv2.dilate(prior, kernel, iterations=2)
            hair_probability = hair_probability * dilated_initial

        hair_mask = (hair_probability > self.config.min_hair_confidence).astype(np.float32)
        return self._remove_small_regions(hair_mask)

    def _detect_by_color(self, image: np.ndarray) -> np.ndarray:
        """Detect hair by colour: union of HSV ranges, then a Gaussian blur."""
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        union = np.zeros(hsv.shape[:2], dtype=np.uint8)
        for lower, upper in self._HSV_HAIR_RANGES:
            # inRange yields 0/255, so a running maximum is the union.
            union = np.maximum(union, cv2.inRange(hsv, lower, upper))
        combined = union.astype(np.float32) / 255.0
        # Smooth the result
        return cv2.GaussianBlur(combined, (7, 7), 2.0)

    def _detect_by_texture(self, image: np.ndarray) -> np.ndarray:
        """Detect hair by texture using a bank of Gabor filters."""
        gray = self._to_gray(image)

        # Gabor filters over three sigmas and six orientations.
        texture_responses = []
        for scale in (3, 5, 7):
            for angle in (0, 30, 60, 90, 120, 150):
                kernel = cv2.getGaborKernel(
                    (21, 21), scale, np.deg2rad(angle), 10.0, 0.5, 0, ktype=cv2.CV_32F
                )
                response = cv2.filter2D(gray, cv2.CV_32F, kernel)
                texture_responses.append(np.abs(response))

        # Mean response, min-max normalised to [0, 1].
        texture_map = np.mean(texture_responses, axis=0)
        low, high = np.min(texture_map), np.max(texture_map)
        texture_map = (texture_map - low) / (high - low + 1e-6)

        coherence = self._compute_texture_coherence(texture_responses)
        return texture_map * coherence

    def _compute_texture_coherence(self, responses: List[np.ndarray]) -> np.ndarray:
        """Compute texture coherence (consistency across filter responses).

        Returns ``1 - min(variance / mean, 1)`` per pixel, or all ones when
        fewer than two responses are given.
        NOTE(review): the responses are not normalised, so the ratio depends
        on the intensity scale of the input.
        """
        if len(responses) < 2:
            return np.ones_like(responses[0])
        response_stack = np.stack(responses, axis=0)
        variance = np.var(response_stack, axis=0)
        mean = np.mean(response_stack, axis=0) + 1e-6
        # Low variance relative to mean = high coherence
        return 1.0 - np.minimum(variance / mean, 1.0)

    def _detect_hair_strands(self, image: np.ndarray,
                             hair_mask: np.ndarray) -> np.ndarray:
        """Detect fine hair strands near the current hair mask.

        Straight segments come from a probabilistic Hough transform on
        low-threshold Canny edges; curved strands come from a ridge filter.
        Both are restricted to a dilated copy of ``hair_mask``.
        """
        gray = self._to_gray(image)

        # Edge detection with low thresholds for fine details
        edges = cv2.Canny(gray, 10, 30)
        lines = cv2.HoughLinesP(
            edges, 1, np.pi / 180, threshold=20,
            minLineLength=10, maxLineGap=5
        )

        # Proximity region. Computed once, up front: it used to be built
        # inside the line loop, so it was undefined when no lines were found
        # and recomputed for every line otherwise.
        kernel = np.ones((15, 15), np.uint8)
        dilated_hair = cv2.dilate(hair_mask, kernel, iterations=1)

        strand_mask = np.zeros(gray.shape, dtype=np.float32)
        if lines is not None:
            for x1, y1, x2, y2 in lines.reshape(-1, 4):
                x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                mid_x, mid_y = (x1 + x2) // 2, (y1 + y2) // 2
                # Keep only segments whose midpoint lies near hair.
                if dilated_hair[mid_y, mid_x] > 0:
                    cv2.line(strand_mask, (x1, y1), (x2, y2), 1.0,
                             self.config.strand_thickness)

        # Ridge detection for curved strands. The previous call to
        # ``filters.frangi`` referenced a module that is never imported.
        ridges = self._ridge_response(gray, sigmas=range(1, 4))
        ridges = (ridges - np.min(ridges)) / (np.max(ridges) - np.min(ridges) + 1e-6)

        # Combine with line detection
        strand_mask = np.maximum(strand_mask, ridges * dilated_hair)

        # Threshold and clean
        strand_mask = (strand_mask > 0.3).astype(np.float32)
        return cv2.morphologyEx(strand_mask, cv2.MORPH_CLOSE,
                                np.ones((3, 3), np.uint8))

    @staticmethod
    def _ridge_response(gray: np.ndarray, sigmas=(1, 2, 3),
                        beta: float = 0.5) -> np.ndarray:
        """Frangi-style multi-scale response to dark, line-like structures.

        For each sigma the image is Gaussian-smoothed, the Hessian is
        estimated with finite differences and a vesselness score is derived
        from its eigenvalues. The maximum over all scales is returned.
        This is an approximation and is not guaranteed to match the
        scikit-image implementation numerically.

        Args:
            gray: 2-D grayscale image.
            sigmas: Gaussian scales to evaluate.
            beta: Sensitivity to deviation from a line-like structure.

        Returns:
            float64 array shaped like ``gray`` with values in [0, 1).
        """
        image = np.asarray(gray, dtype=np.float64)
        response = np.zeros(image.shape, dtype=np.float64)

        for sigma in sigmas:
            smoothed = ndimage.gaussian_filter(image, sigma)
            d_row, d_col = np.gradient(smoothed)
            h_rr, h_rc = np.gradient(d_row)
            _, h_cc = np.gradient(d_col)

            # Scale normalisation makes responses comparable across sigmas.
            scale = float(sigma) ** 2
            h_rr, h_rc, h_cc = h_rr * scale, h_rc * scale, h_cc * scale

            # Closed-form eigenvalues of the symmetric 2x2 Hessian.
            mean = 0.5 * (h_rr + h_cc)
            spread = np.sqrt((0.5 * (h_rr - h_cc)) ** 2 + h_rc ** 2)
            eig_hi, eig_lo = mean + spread, mean - spread
            hi_dominates = np.abs(eig_hi) >= np.abs(eig_lo)
            major = np.where(hi_dominates, eig_hi, eig_lo)
            minor = np.where(hi_dominates, eig_lo, eig_hi)

            structure_sq = major ** 2 + minor ** 2
            c = 0.5 * np.sqrt(structure_sq.max())
            if c < 1e-9:
                # Numerically flat at this scale: nothing to detect.
                continue

            ratio = np.divide(minor, major, out=np.zeros_like(minor),
                              where=major != 0)
            vesselness = (np.exp(-(ratio ** 2) / (2.0 * beta ** 2))
                          * (1.0 - np.exp(-structure_sq / (2.0 * c ** 2))))
            # A dark ridge is an intensity minimum across the ridge, so the
            # dominant curvature must be positive.
            vesselness[major <= 0] = 0.0
            response = np.maximum(response, vesselness)

        return response

    def _integrate_strands(self, hair_mask: np.ndarray,
                           strands_mask: Optional[np.ndarray]) -> np.ndarray:
        """Merge detected strands (weighted 0.8) into the main hair mask."""
        if strands_mask is None:
            return hair_mask
        integrated = np.maximum(hair_mask, strands_mask * 0.8)
        # Smooth the integration
        integrated = cv2.GaussianBlur(integrated, (5, 5), 1.0)
        return np.clip(integrated, 0, 1)

    def _extract_deep_features(self, image: np.ndarray) -> Optional[torch.Tensor]:
        """Run the encoder on ``image``; returns ``None`` without a model."""
        if self.deep_model is None:
            return None
        # HWC -> NCHW, scaled by 1/255.
        input_tensor = (
            torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1).unsqueeze(0).float() / 255.0
        )
        with torch.no_grad():
            return self.deep_model.extract_features(input_tensor)

    def _enhance_with_deep_features(self, mask: np.ndarray,
                                    features: Optional[torch.Tensor]) -> np.ndarray:
        """Blend the decoder's hair probability (30%) into ``mask`` (70%)."""
        if features is None or self.deep_model is None:
            return mask
        # The decoder's parameters require gradients; without no_grad/detach
        # the ``.numpy()`` conversion below raises a RuntimeError.
        with torch.no_grad():
            hair_prob = self.deep_model.process_features(features)
        hair_prob = hair_prob.detach().squeeze().cpu().numpy()
        # Resize to match mask
        hair_prob = cv2.resize(hair_prob, (mask.shape[1], mask.shape[0]))
        enhanced = 0.7 * mask + 0.3 * hair_prob
        return np.clip(enhanced, 0, 1)

    def _apply_alpha_matting(self, image: np.ndarray,
                             mask: np.ndarray) -> np.ndarray:
        """Soften the mask with a guided filter steered by the gray image.

        This is a simple approximation of alpha matting; dedicated matting
        methods may give better results.
        """
        gray = self._to_gray(image).astype(np.float32) / 255.0
        radius = 20
        epsilon = 0.01
        alpha = self._guided_filter(mask, gray, radius, epsilon)
        return np.clip(alpha, 0, 1)

    def _guided_filter(self, p: np.ndarray, I: np.ndarray,
                       radius: int, epsilon: float) -> np.ndarray:
        """Guided filter of ``p`` using guidance image ``I``.

        Args:
            p: Image to filter (the mask).
            I: Guidance image with the same shape as ``p``.
            radius: Window radius; the box window is ``2 * radius + 1`` wide.
            epsilon: Regularisation added to the guidance variance.

        Returns:
            Filtered float32 image.
        """
        guide = np.asarray(I, dtype=np.float32)
        source = np.asarray(p, dtype=np.float32)
        # ``radius`` used to be passed as the window *size*; an even size is
        # not centred on the pixel and halves the intended support.
        side = 2 * int(radius) + 1
        window = (side, side)

        def box(values: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(values, cv2.CV_32F, window)

        mean_I = box(guide)
        mean_p = box(source)
        cov_Ip = box(guide * source) - mean_I * mean_p
        var_I = box(guide * guide) - mean_I * mean_I

        a = cov_Ip / (var_I + epsilon)
        b = mean_p - a * mean_I
        return box(a) * guide + box(b)

    def _apply_edge_enhancement(self, mask: np.ndarray,
                                edges: np.ndarray) -> np.ndarray:
        """Strengthen the mask at detected edges (edge weight 0.3)."""
        edge_weight = 0.3
        return np.clip(mask + edge_weight * edges, 0, 1)

    def _final_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Apply final smoothing: bilateral if ``preserve_details``, else Gaussian."""
        if self.config.preserve_details:
            # Clip before the 8-bit conversion so stray values cannot wrap.
            as_bytes = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
            return cv2.bilateralFilter(as_bytes, 9, 75, 75) / 255.0
        return cv2.GaussianBlur(mask, (5, 5), self.config.smoothing_sigma)

    def _compute_confidence(self, mask: np.ndarray,
                            initial_mask: Optional[np.ndarray]) -> np.ndarray:
        """Compute a per-pixel confidence map in [0, 1].

        Values near 0 or 1 in ``mask`` count as confident. When an initial
        mask is given, agreement with it contributes 30% of the score.
        """
        confidence = np.abs(mask - 0.5) * 2
        if initial_mask is not None:
            # Normalise so a 0/255 prior is compared on the same scale.
            prior = self._normalize_mask(initial_mask)
            agreement = 1 - np.abs(mask - prior)
            confidence = 0.7 * confidence + 0.3 * agreement
        return np.clip(confidence, 0, 1)

    def _remove_small_regions(self, mask: np.ndarray,
                              min_size: int = 100) -> np.ndarray:
        """Zero out connected components smaller than ``min_size`` pixels.

        Components are found on ``mask > 0.5``; kept components retain their
        original mask values.
        """
        binary = (mask > 0.5).astype(np.uint8)
        _, labels, stats, _ = cv2.connectedComponentsWithStats(binary)

        # One lookup table over the labels instead of an image pass per label.
        keep_label = stats[:, cv2.CC_STAT_AREA] >= min_size
        keep_label[0] = False  # label 0 is the background
        keep_pixel = keep_label[labels]

        cleaned = np.zeros_like(mask)
        cleaned[keep_pixel] = mask[keep_pixel]
        return cleaned
class HairMaskRefiner:
    """Refines hair masks with repeated morphological clean-up passes."""

    def __init__(self, config: Optional["HairConfig"] = None):
        """Store the configuration.

        Args:
            config: Settings providing ``refinement_iterations``. A default
                ``HairConfig`` is created when ``None`` is passed, so a
                missing configuration no longer fails later on attribute
                access.
        """
        self.config = config if config is not None else HairConfig()

    def refine(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Refine ``mask`` over ``refinement_iterations`` passes.

        Args:
            image: Source image; forwarded to each pass, currently unused.
            mask: Mask to refine; it is copied, not modified.

        Returns:
            The refined mask.
        """
        refined = mask.copy()
        for iteration in range(self.config.refinement_iterations):
            # Progressive refinement
            refined = self._refine_iteration(image, refined, iteration)
        return refined

    @staticmethod
    def _kernel_size(iteration: int) -> int:
        """Structuring-element size for a pass: 5, 4, 3, ... never below 1.

        Without the lower bound the size reaches zero on the sixth pass and
        the structuring element cannot be built.
        """
        return max(1, 5 - iteration)

    def _refine_iteration(self, image: np.ndarray, mask: np.ndarray,
                          iteration: int) -> np.ndarray:
        """Run one pass: close gaps, remove noise, smooth boundaries."""
        size = self._kernel_size(iteration)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
        # Close gaps
        refined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        # Remove noise
        refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)
        # Smooth boundaries
        return cv2.GaussianBlur(refined, (3, 3), 0.5)
class AsymmetryDetector:
    """Detects and corrects left/right asymmetry in hair masks."""

    # Used when the configuration has no ``symmetry_threshold`` attribute.
    # NOTE(review): provisional value - tune on real data.
    _DEFAULT_SYMMETRY_THRESHOLD = 0.3

    def __init__(self, config: Optional["HairConfig"] = None):
        """Store the configuration (a default ``HairConfig`` if ``None``)."""
        self.config = config if config is not None else HairConfig()

    def detect(self, mask: np.ndarray, image: np.ndarray) -> Dict[str, Any]:
        """Measure how asymmetric ``mask`` is about its centre of mass.

        Args:
            mask: 2-D hair mask with values in [0, 1].
            image: Source image; currently unused.

        Returns:
            Dict with 'is_asymmetric', 'score', 'center_x', 'area_ratio',
            'pixel_diff' and 'edge_diff'.
        """
        center_x = self._find_center_line(mask)
        # Compare equally wide strips on either side of the centre line.
        half = min(center_x, mask.shape[1] - center_x)
        if half <= 0:
            # One side is empty: nothing to compare or mirror.
            return {
                'is_asymmetric': False,
                'score': 0.0,
                'center_x': center_x,
                'area_ratio': 1.0,
                'pixel_diff': 0.0,
                'edge_diff': 0.0
            }

        left_mask = mask[:, center_x - half:center_x].astype(np.float32)
        right_mask = mask[:, center_x:center_x + half].astype(np.float32)
        right_flipped = right_mask[:, ::-1]

        # Pixel-wise difference after mirroring the right strip.
        pixel_diff = float(np.mean(np.abs(left_mask - right_flipped)))

        # Area comparison
        left_area = int(np.count_nonzero(left_mask > 0.5))
        right_area = int(np.count_nonzero(right_mask > 0.5))
        if left_area == 0 and right_area == 0:
            area_ratio = 1.0  # two empty strips are balanced
        else:
            area_ratio = max(left_area, right_area) / (min(left_area, right_area) + 1e-6)

        # Edge comparison. Widen before subtracting: uint8 arithmetic wraps
        # around, turning a 0-255 difference into 1.
        left_edges = cv2.Canny((np.clip(left_mask, 0, 1) * 255).astype(np.uint8), 50, 150)
        right_edges = cv2.Canny((np.clip(right_mask, 0, 1) * 255).astype(np.uint8), 50, 150)
        edge_delta = np.abs(
            left_edges.astype(np.int16) - right_edges[:, ::-1].astype(np.int16)
        )
        edge_diff = float(np.mean(edge_delta)) / 255.0

        # Overall asymmetry score
        asymmetry_score = (0.4 * pixel_diff
                           + 0.3 * (area_ratio - 1.0) / 2.0
                           + 0.3 * edge_diff)

        # The threshold is read defensively because the configuration may
        # not define it.
        threshold = getattr(self.config, 'symmetry_threshold',
                            self._DEFAULT_SYMMETRY_THRESHOLD)
        is_asymmetric = bool(asymmetry_score > threshold
                             or area_ratio > self.config.max_asymmetry_ratio)

        return {
            'is_asymmetric': is_asymmetric,
            'score': asymmetry_score,
            'center_x': center_x,
            'area_ratio': area_ratio,
            'pixel_diff': pixel_diff,
            'edge_diff': edge_diff
        }

    def correct(self, mask: np.ndarray, asymmetry_info: Dict[str, Any]) -> np.ndarray:
        """Correct detected asymmetry by mirroring the denser side.

        Args:
            mask: 2-D hair mask.
            asymmetry_info: Result of ``detect``; only 'center_x' is read.

        Returns:
            A corrected copy of ``mask``.
        """
        center_x = int(asymmetry_info['center_x'])
        corrected = self._mirror_blend(mask, center_x)

        # Smooth the centre seam horizontally.
        seam_width = 10
        seam_start = max(0, center_x - seam_width)
        seam_end = min(mask.shape[1], center_x + seam_width)
        if seam_end > seam_start:
            seam = np.ascontiguousarray(corrected[:, seam_start:seam_end])
            corrected[:, seam_start:seam_end] = cv2.GaussianBlur(seam, (7, 1), 2.0)
        return corrected

    @staticmethod
    def _mirror_blend(mask: np.ndarray, center_x: int) -> np.ndarray:
        """Blend the denser half, mirrored, into the other half.

        Only the strip both halves share (adjacent to ``center_x``) is
        blended, 70% mirrored reference plus 30% original. Columns outside
        that strip keep their original values. Previously unequal halves
        could raise a broadcasting error or leave columns zeroed.
        """
        corrected = mask.copy()
        left_mask = mask[:, :center_x]
        right_mask = mask[:, center_x:]
        overlap = min(left_mask.shape[1], right_mask.shape[1])
        if overlap == 0:
            return corrected

        near_left = left_mask[:, left_mask.shape[1] - overlap:]
        near_right = right_mask[:, :overlap]

        # The denser side is used as reference (usually more complete).
        left_density = np.mean(left_mask > 0.5)
        right_density = np.mean(right_mask > 0.5)
        if left_density > right_density:
            # Mirror left onto right.
            corrected[:, center_x:center_x + overlap] = (
                0.7 * near_left[:, ::-1] + 0.3 * near_right
            )
        else:
            # Mirror right onto left.
            corrected[:, center_x - overlap:center_x] = (
                0.7 * near_right[:, ::-1] + 0.3 * near_left
            )
        return corrected

    def _find_center_line(self, mask: np.ndarray) -> int:
        """Return the x coordinate of the centre of mass of ``mask > 0.5``.

        Falls back to the horizontal image centre for an empty mask.
        """
        column_mass = (mask > 0.5).sum(axis=0, dtype=np.float64)
        total = column_mass.sum()
        if total > 0:
            first_moment = np.dot(column_mass, np.arange(column_mass.shape[0]))
            return int(first_moment / total)
        # Fallback to image center
        return mask.shape[1] // 2
class HairEdgeEnhancer:
    """Builds an edge map for hair masks from image and mask edges."""

    def __init__(self, config: Optional["HairConfig"] = None):
        """Store the configuration (a default ``HairConfig`` if ``None``)."""
        self.config = config if config is not None else HairConfig()

    def enhance(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Compute a thinned, normalised edge map.

        Combines multi-scale image edges, edges of the mask itself and
        Gabor-based hair edges, then applies non-maximum suppression.

        Args:
            mask: Hair mask with values in [0, 1], same size as ``image``.
            image: BGR or grayscale source image.

        Returns:
            Edge map scaled so its maximum is (just under) 1.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image

        # Multi-scale edges of the image
        edges = self._multi_scale_edges(gray)
        # Edges of the mask itself
        mask_bytes = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
        mask_edges = cv2.Canny(mask_bytes, 30, 100) / 255.0
        # Hair-specific edges
        hair_edges = self._detect_hair_edges(gray, mask)

        # Weighted maximum of the three cues
        combined_edges = np.maximum(
            edges * 0.3, np.maximum(mask_edges * 0.3, hair_edges * 0.4)
        )
        return self._non_max_suppression(combined_edges)

    def _multi_scale_edges(self, gray: np.ndarray) -> np.ndarray:
        """Average Canny edges computed at full, half and third resolution."""
        full_size = (gray.shape[1], gray.shape[0])
        edges_list = []
        for scale in (1, 2, 3):
            scaled = gray if scale == 1 else cv2.resize(
                gray, None, fx=1 / scale, fy=1 / scale
            )
            # Thresholds grow with the downscale factor.
            edges = cv2.Canny(scaled, 30 * scale, 80 * scale)
            if scale > 1:
                edges = cv2.resize(edges, full_size)
            edges_list.append(edges / 255.0)
        return np.mean(edges_list, axis=0)

    def _detect_hair_edges(self, gray: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Detect hair-like edges with Gabor filters, limited to the mask.

        Returns a binary float32 map: 1 where the normalised, mask-weighted
        response exceeds ``edge_sensitivity * 0.5``.
        """
        hair_edges = np.zeros(gray.shape, dtype=np.float32)
        # Six orientations, 30 degrees apart
        for angle in range(0, 180, 30):
            kernel = cv2.getGaborKernel(
                (11, 11), 3.0, np.deg2rad(angle), 8.0, 0.5, 0, ktype=cv2.CV_32F
            )
            filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
            hair_edges = np.maximum(hair_edges, np.abs(filtered))

        # Normalize
        hair_edges = hair_edges / (np.max(hair_edges) + 1e-6)
        # Mask to hair regions
        hair_edges = hair_edges * mask
        # Threshold
        return (hair_edges > self.config.edge_sensitivity * 0.5).astype(np.float32)

    def _non_max_suppression(self, edges: np.ndarray) -> np.ndarray:
        """Thin ``edges`` by keeping only local gradient-magnitude maxima.

        Returns the surviving gradient magnitudes divided by their maximum.
        """
        # Sobel with a CV_32F output needs a matching input depth; the
        # combined edge map arrives as float64.
        edges32 = np.ascontiguousarray(edges, dtype=np.float32)
        dx = cv2.Sobel(edges32, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(edges32, cv2.CV_32F, 0, 1, ksize=3)

        magnitude = np.sqrt(dx ** 2 + dy ** 2)
        # Fold directions into [0, 180] degrees.
        direction = np.rad2deg(np.arctan2(dy, dx))
        direction[direction < 0] += 180

        suppressed = self._suppress_non_maxima(magnitude, direction)
        return suppressed / (np.max(suppressed) + 1e-6)

    @staticmethod
    def _suppress_non_maxima(magnitude: np.ndarray,
                             direction: np.ndarray) -> np.ndarray:
        """Zero every interior pixel that is not a maximum along its direction.

        Directions (degrees) are quantised to four orientations. A pixel is
        kept when its magnitude is >= both neighbours in that orientation.
        Border pixels are always zero. Vectorised replacement for a per-pixel
        Python loop, using the same neighbour rules.
        """
        suppressed = np.zeros_like(magnitude)
        if magnitude.shape[0] < 3 or magnitude.shape[1] < 3:
            return suppressed

        centre = magnitude[1:-1, 1:-1]
        angle = direction[1:-1, 1:-1]

        # Larger of the two neighbours for each orientation.
        horizontal = np.maximum(magnitude[1:-1, :-2], magnitude[1:-1, 2:])
        rising = np.maximum(magnitude[:-2, 2:], magnitude[2:, :-2])      # diagonal /
        vertical = np.maximum(magnitude[:-2, 1:-1], magnitude[2:, 1:-1])
        falling = np.maximum(magnitude[:-2, :-2], magnitude[2:, 2:])     # diagonal \

        is_horizontal = (((angle >= 0) & (angle < 22.5))
                         | ((angle >= 157.5) & (angle <= 180)))
        is_rising = (angle >= 22.5) & (angle < 67.5)
        is_vertical = (angle >= 67.5) & (angle < 112.5)
        strongest = np.select(
            [is_horizontal, is_rising, is_vertical],
            [horizontal, rising, vertical],
            default=falling,
        )

        suppressed[1:-1, 1:-1] = np.where(centre >= strongest, centre, 0)
        return suppressed
class HairNet(nn.Module):
    """Small placeholder CNN for hair feature extraction.

    The encoder has three 3x3 convolutions with two 2x max-pooling steps
    (3 -> 32 -> 64 -> 128 channels). The decoder mirrors it with two 2x
    upsampling steps and ends in a single-channel sigmoid map.
    """

    def __init__(self):
        super().__init__()

        # Simplified architecture - swap in a real model if needed.
        encoder_layers = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        decoder_layers = [
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder on ``x`` and return its feature map."""
        encoded = self.encoder(x)
        return encoded

    def process_features(self, features: torch.Tensor) -> torch.Tensor:
        """Run the decoder on ``features`` and return the hair probability."""
        probability = self.decoder(features)
        return probability

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full pass: encoder followed by decoder."""
        return self.process_features(self.extract_features(x))
# Utility functions
def visualize_hair_segmentation(image: np.ndarray,
                                results: Dict[str, np.ndarray],
                                save_path: Optional[str] = None) -> np.ndarray:
    """Render a 2x2 overview of hair segmentation results.

    Layout: original (top-left), mask overlay (top-right), confidence map
    (bottom-left), edges and strands (bottom-right).

    Args:
        image: Source image, BGR ``uint8``; a 2-D grayscale image is
            converted to BGR first.
        results: Segmentation output. 'mask' is required; 'confidence',
            'edges' and 'strands' are drawn when present.
        save_path: Optional file path to write the visualization to.

    Returns:
        The ``uint8`` BGR visualization, twice the height and width of
        ``image``.
    """
    def to_uint8(values: np.ndarray) -> np.ndarray:
        # Clip first: out-of-range floats would wrap in the uint8 cast.
        return (np.clip(values, 0.0, 1.0) * 255).astype(np.uint8)

    if image.ndim == 2:
        # The grid and the overlays below need three channels.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    h, w = image.shape[:2]
    # Create visualization grid
    viz = np.zeros((h * 2, w * 2, 3), dtype=np.uint8)

    # Original image
    viz[:h, :w] = image

    # Hair mask overlay (green channel)
    mask_colored = np.zeros_like(image)
    mask_colored[:, :, 1] = to_uint8(results['mask'])
    viz[:h, w:] = cv2.addWeighted(image, 0.7, mask_colored, 0.3, 0)

    # Confidence map
    if 'confidence' in results:
        viz[h:, :w] = cv2.applyColorMap(
            to_uint8(results['confidence']), cv2.COLORMAP_JET
        )

    # Edges (red channel) and strands (blue channel)
    if 'edges' in results and 'strands' in results:
        edges_viz = np.zeros_like(image)
        edges_viz[:, :, 2] = to_uint8(results['edges'])
        if results['strands'] is not None:
            edges_viz[:, :, 0] = to_uint8(results['strands'])
        viz[h:, w:] = edges_viz

    # Add labels
    labels = (
        ("Original", (10, 30)),
        ("Hair Mask", (w + 10, 30)),
        ("Confidence", (10, h + 30)),
        ("Edges/Strands", (w + 10, h + 30)),
    )
    for text, origin in labels:
        cv2.putText(viz, text, origin, cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 255), 2)

    if save_path:
        # imwrite reports failure through its return value; surface it
        # instead of dropping it silently.
        if not cv2.imwrite(save_path, viz):
            logger.warning("Failed to write hair segmentation visualization to %s",
                           save_path)
    return viz
# Public API: the names exported by a star-import of this module.
__all__ = [
'HairSegmentationPipeline',
'HairConfig',
'HairMaskRefiner',
'AsymmetryDetector',
'HairEdgeEnhancer',
'HairNet',
'visualize_hair_segmentation'
]