VideoBackgroundReplacer / utils /cv_processing.py
MogensR's picture
Update utils/cv_processing.py
26841b5
raw
history blame
44.4 kB
"""
Computer Vision Processing Module for BackgroundFX Pro
Contains segmentation, mask refinement, background replacement, and helper functions
"""
# Set OMP_NUM_THREADS at the very beginning to prevent libgomp errors
import os
if 'OMP_NUM_THREADS' not in os.environ:
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'
import logging
from typing import Optional, Tuple, Dict, Any
import numpy as np
import cv2
import torch
logger = logging.getLogger(__name__)
# ============================================================================
# CONFIGURATION AND CONSTANTS
# ============================================================================
# Version control flags for CV functions
# When False, segment_person_hq delegates to segment_person_hq_original.
USE_ENHANCED_SEGMENTATION = True
# NOTE(review): not referenced in the visible part of this module - confirm it is read elsewhere.
USE_AUTO_TEMPORAL_CONSISTENCY = True
# Selects _segment_with_intelligent_prompts over _segment_with_basic_prompts.
USE_INTELLIGENT_PROMPTING = True
# Enables the _auto_refine_mask_iteratively pass after the initial prediction.
USE_ITERATIVE_REFINEMENT = True
# Professional background templates.
#
# Keys read by this module:
#   "type"       - "color" or "gradient" (see create_professional_background)
#   "colors"     - hex strings; written as full 6-digit "#rrggbb" values because
#                  the colour parsers in this module slice them into three
#                  2-character pairs
#   "direction"  - gradient layout: vertical / horizontal / diagonal / radial /
#                  soft_radial
#   "brightness", "contrast" - multipliers applied by _apply_background_adjustments
# "name", "description" and "chroma_key" are not read in this module.
PROFESSIONAL_BACKGROUNDS = {
    "office_modern": {
        "name": "Modern Office",
        "type": "gradient",
        "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
        "direction": "diagonal",
        "description": "Clean, contemporary office environment",
        "brightness": 0.95,
        "contrast": 1.1,
    },
    "studio_blue": {
        "name": "Professional Blue",
        "type": "gradient",
        "colors": ["#1e3c72", "#2a5298", "#3498db"],
        "direction": "radial",
        "description": "Broadcast-quality blue studio",
        "brightness": 0.9,
        "contrast": 1.2,
    },
    "studio_green": {
        "name": "Broadcast Green",
        "type": "color",
        "colors": ["#00b894"],
        "chroma_key": True,
        "description": "Professional green screen replacement",
        "brightness": 1.0,
        "contrast": 1.0,
    },
    "minimalist": {
        "name": "Minimalist White",
        "type": "gradient",
        # Last stop was the 3-digit shorthand "#ddd", which the pair-slicing
        # parser cannot read; spelled out in full here.
        "colors": ["#ffffff", "#f1f2f6", "#dddddd"],
        "direction": "soft_radial",
        "description": "Clean, minimal background",
        "brightness": 0.98,
        "contrast": 0.9,
    },
    "warm_gradient": {
        "name": "Warm Sunset",
        "type": "gradient",
        "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
        "direction": "diagonal",
        "description": "Warm, inviting atmosphere",
        "brightness": 0.85,
        "contrast": 1.15,
    },
    "tech_dark": {
        "name": "Tech Dark",
        "type": "gradient",
        "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
        "direction": "vertical",
        "description": "Modern tech/gaming setup",
        "brightness": 0.7,
        "contrast": 1.3,
    },
}
# ============================================================================
# CUSTOM EXCEPTIONS
# ============================================================================
class SegmentationError(Exception):
    """Raised when person segmentation fails and no fallback result is returned."""
class MaskRefinementError(Exception):
    """Raised when mask refinement fails and no fallback result is returned."""
class BackgroundReplacementError(Exception):
    """Raised when background replacement fails and no fallback result is returned."""
# ============================================================================
# MAIN SEGMENTATION FUNCTIONS
# ============================================================================
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """Segment the person in ``image`` with the enhanced, automated pipeline.

    Delegates to ``segment_person_hq_original`` when
    ``USE_ENHANCED_SEGMENTATION`` is off. Otherwise the predictor is primed
    with the image, prompted (per ``USE_INTELLIGENT_PROMPTING``), optionally
    refined (per ``USE_ITERATIVE_REFINEMENT``) and finally validated.

    Args:
        image: Input frame; ``None`` or an empty array is rejected.
        predictor: Object exposing ``set_image`` and ``predict``; may be ``None``.
        fallback_enabled: When true, failures degrade to
            ``_fallback_segmentation`` instead of raising.

    Returns:
        The mask from the prompting helpers, or the fallback mask.

    Raises:
        SegmentationError: For invalid input, or for any failure while
            ``fallback_enabled`` is false.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    logger.debug("Using ENHANCED segmentation with intelligent automation")
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    def _recover(failure: str) -> np.ndarray:
        # Single place deciding between "degrade gracefully" and "surface the error".
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(failure)

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
            return _recover("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            return _recover(f"Predictor setup failed: {e}")

        prompt_fn = (_segment_with_intelligent_prompts if USE_INTELLIGENT_PROMPTING
                     else _segment_with_basic_prompts)
        mask = prompt_fn(image, predictor)

        if USE_ITERATIVE_REFINEMENT and mask is not None:
            mask = _auto_refine_mask_iteratively(image, mask, predictor)

        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            return _recover("Poor mask quality")

        logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask
    except SegmentationError:
        # Deliberate failures raised above must not be swallowed by the catch-all.
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        return _recover(f"Unexpected error: {e}")
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """Segment the person using the original fixed-grid prompting (kept for rollback).

    Eight positive points laid out around the frame centre are sent to
    ``predictor.predict``; the highest-scoring mask is post-processed and
    validated.

    Args:
        image: Input frame; ``None`` or an empty array is rejected.
        predictor: Object exposing ``set_image`` and ``predict``; may be ``None``.
        fallback_enabled: When true, failures degrade to
            ``_fallback_segmentation`` instead of raising.

    Returns:
        The processed mask, or the fallback mask.

    Raises:
        SegmentationError: For invalid input, or for any failure while
            ``fallback_enabled`` is false.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    def _recover(failure: str) -> np.ndarray:
        # Single place deciding between "degrade gracefully" and "surface the error".
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(failure)

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
            return _recover("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            return _recover(f"Predictor setup failed: {e}")

        h, w = image.shape[:2]
        mid_x, mid_y = w // 2, h // 2
        # Fixed layout of positive prompts, in (x, y) order.
        points = np.array([
            [mid_x, h // 4],
            [mid_x, mid_y],
            [mid_x, 3 * h // 4],
            [w // 3, mid_y],
            [2 * w // 3, mid_y],
            [mid_x, h // 6],
            [w // 4, 2 * h // 3],
            [3 * w // 4, 2 * h // 3],
        ], dtype=np.float32)
        labels = np.ones(len(points), dtype=np.int32)

        try:
            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )
        except Exception as e:
            logger.error(f"SAM2 prediction failed: {e}")
            return _recover(f"Prediction failed: {e}")

        if masks is None or len(masks) == 0:
            logger.warning("SAM2 returned no masks")
            return _recover("No masks generated")

        if scores is not None and len(scores) > 0:
            best_idx = np.argmax(scores)
            best_mask = masks[best_idx]
            logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
        else:
            logger.warning("SAM2 returned no scores")
            best_mask = masks[0]

        mask = _process_mask(best_mask)
        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            return _recover("Poor mask quality")

        logger.debug(f"Segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask
    except SegmentationError:
        # Deliberate failures raised above must not be swallowed by the catch-all.
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        return _recover(f"Unexpected error: {e}")
# ============================================================================
# MASK REFINEMENT FUNCTIONS
# ============================================================================
def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
                   fallback_enabled: bool = True) -> np.ndarray:
    """Refine ``mask`` with MatAnyone when available, else with OpenCV filtering.

    Args:
        image: Frame the mask belongs to.
        mask: Raw mask; normalised through ``_process_mask`` first.
        matanyone_processor: Optional refinement backend; ``None`` skips it.
        fallback_enabled: When true, ``enhance_mask_opencv_advanced`` is used
            whenever MatAnyone is absent, fails, or yields a mask that does
            not pass ``_validate_mask_quality``.

    Returns:
        The refined mask.

    Raises:
        MaskRefinementError: If an input is ``None``, or refinement fails
            while ``fallback_enabled`` is false.
    """
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")

    try:
        mask = _process_mask(mask)

        if matanyone_processor is not None:
            try:
                logger.debug("Attempting MatAnyone refinement")
                candidate = _matanyone_refine(image, mask, matanyone_processor)
                accepted = (candidate is not None
                            and _validate_mask_quality(candidate, image.shape[:2]))
                if accepted:
                    logger.debug("MatAnyone refinement successful")
                    return candidate
                logger.warning("MatAnyone produced poor quality mask")
            except Exception as e:
                # Best effort: any MatAnyone problem just routes to the fallback below.
                logger.warning(f"MatAnyone refinement failed: {e}")

        if not fallback_enabled:
            raise MaskRefinementError("MatAnyone failed and fallback disabled")
        logger.debug("Using enhanced OpenCV refinement")
        return enhance_mask_opencv_advanced(image, mask)
    except MaskRefinementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected mask refinement error: {e}")
        if not fallback_enabled:
            raise MaskRefinementError(f"Unexpected error: {e}")
        return enhance_mask_opencv_advanced(image, mask)
def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Clean up ``mask`` using a chain of OpenCV filters guided by ``image``.

    Pipeline: bilateral filter, guided-filter approximation, morphological
    close (5x5) then open (3x3), a light Gaussian blur, and a final binary
    threshold at 127. If any step raises, a Gaussian-blurred copy of the
    mask (as it stood when the failure happened) is returned instead.
    """
    try:
        if mask.ndim == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if mask.max() <= 1.0:
            # Masks normalised to [0, 1] are rescaled to the 0-255 range.
            mask = (mask * 255).astype(np.uint8)

        smoothed = cv2.bilateralFilter(mask, 9, 75, 75)
        smoothed = _guided_filter_approx(image, smoothed, radius=8, eps=0.2)

        # Close small holes first, then remove small specks.
        for operation, kernel_size in ((cv2.MORPH_CLOSE, (5, 5)), (cv2.MORPH_OPEN, (3, 3))):
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
            smoothed = cv2.morphologyEx(smoothed, operation, kernel)

        smoothed = cv2.GaussianBlur(smoothed, (3, 3), 0.8)
        return cv2.threshold(smoothed, 127, 255, cv2.THRESH_BINARY)[1]
    except Exception as e:
        logger.warning(f"Enhanced OpenCV refinement failed: {e}")
        return cv2.GaussianBlur(mask, (5, 5), 1.0)
# ============================================================================
# BACKGROUND REPLACEMENT FUNCTIONS
# ============================================================================
def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
                          fallback_enabled: bool = True) -> np.ndarray:
    """Composite ``frame`` over ``background`` using ``mask`` as the alpha source.

    The background is resized to the frame, the mask is normalised to a
    single-channel uint8 image in 0-255 with the frame's height and width,
    and ``_advanced_compositing`` is attempted before falling back to
    ``_simple_compositing``.

    Args:
        frame: Foreground frame.
        mask: Foreground mask; either 0-255 or normalised to [0, 1], with one
            or three channels.
        background: Replacement background of any size.
        fallback_enabled: When true, failures degrade to
            ``_simple_compositing`` instead of raising.

    Returns:
        The composited frame.

    Raises:
        BackgroundReplacementError: If an input is ``None``, or compositing
            fails while ``fallback_enabled`` is false.
    """
    if frame is None or mask is None or background is None:
        raise BackgroundReplacementError("Invalid input frame, mask, or background")

    try:
        frame_h, frame_w = frame.shape[:2]
        background = cv2.resize(background, (frame_w, frame_h),
                                interpolation=cv2.INTER_LANCZOS4)

        if len(mask.shape) == 3:
            if mask.shape[2] == 1:
                # A trailing singleton channel is not a colour image; just drop it.
                mask = mask[:, :, 0]
            else:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        # Rescale BEFORE casting: casting a [0, 1] float mask to uint8 first
        # truncates every soft alpha value to 0 and destroys feathered edges.
        if mask.max() <= 1.0:
            logger.debug("Converting normalized mask to 0-255 range")
            mask = mask.astype(np.float32) * 255.0
        if mask.dtype != np.uint8:
            mask = np.clip(mask, 0, 255).astype(np.uint8)

        if mask.shape[:2] != (frame_h, frame_w):
            # Compositing multiplies mask and frame element-wise, so sizes must agree.
            mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_LINEAR)

        try:
            result = _advanced_compositing(frame, mask, background)
            logger.debug("Advanced compositing successful")
            return result
        except Exception as e:
            logger.warning(f"Advanced compositing failed: {e}")
            if fallback_enabled:
                return _simple_compositing(frame, mask, background)
            raise BackgroundReplacementError(f"Advanced compositing failed: {e}") from e
    except BackgroundReplacementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected background replacement error: {e}")
        if fallback_enabled:
            return _simple_compositing(frame, mask, background)
        raise BackgroundReplacementError(f"Unexpected error: {e}") from e
def create_professional_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Render the background described by ``bg_config`` at ``width`` x ``height``.

    ``bg_config["type"]`` selects a solid colour (``"color"``) or a gradient
    (``"gradient"``); any other type yields neutral gray. The result is then
    passed through ``_apply_background_adjustments``. If anything raises,
    a plain neutral-gray canvas is returned without adjustments.
    """
    def _neutral_gray() -> np.ndarray:
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)

    try:
        if bg_config["type"] == "color":
            canvas = _create_solid_background(bg_config, width, height)
        elif bg_config["type"] == "gradient":
            canvas = _create_gradient_background_enhanced(bg_config, width, height)
        else:
            canvas = _neutral_gray()
        return _apply_background_adjustments(canvas, bg_config)
    except Exception as e:
        logger.error(f"Background creation error: {e}")
        return _neutral_gray()
# ============================================================================
# VALIDATION FUNCTION
# ============================================================================
def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """Check that ``video_path`` points to a video this module is willing to process.

    Limits enforced: non-empty file of at most 2 GB, at least one frame,
    frame rate in (0, 120], resolution in (0, 4096] on both axes, and a
    duration of at most 300 seconds.

    Args:
        video_path: Path to the candidate video file.

    Returns:
        ``(True, summary)`` when the file passes every check, otherwise
        ``(False, reason)``. Never raises; unexpected errors are reported in
        the reason string.
    """
    if not video_path or not os.path.exists(video_path):
        return False, "Video file not found"

    try:
        file_size = os.path.getsize(video_path)
        if file_size == 0:
            return False, "Video file is empty"
        if file_size > 2 * 1024 * 1024 * 1024:
            return False, "Video file too large (>2GB)"

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return False, "Cannot open video file"
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            # Release on every path, including "cannot open" and property errors.
            cap.release()

        if frame_count <= 0:
            return False, f"Video appears to be empty ({frame_count} frames)"
        # Written as a positive range test so that a NaN frame rate is rejected too.
        if not (0 < fps <= 120):
            return False, f"Invalid frame rate: {fps}"
        if width <= 0 or height <= 0:
            return False, f"Invalid resolution: {width}x{height}"
        if width > 4096 or height > 4096:
            return False, f"Resolution too high: {width}x{height} (max 4096x4096)"

        duration = frame_count / fps
        if duration > 300:
            return False, f"Video too long: {duration:.1f}s (max 300s)"

        return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"
    except Exception as e:
        return False, f"Error validating video: {str(e)}"
# ============================================================================
# HELPER FUNCTIONS - SEGMENTATION
# ============================================================================
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run the predictor with automatically generated positive/negative prompts.

    Prompts come from ``_generate_smart_prompts``; if it yields no positive
    point, the image centre is used. The highest-scoring mask (or the first
    one when no scores are returned) is passed through ``_process_mask``.

    Raises:
        SegmentationError: If the predictor returns no masks.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        h, w = image.shape[:2]
        pos_points, neg_points = _generate_smart_prompts(image)
        if len(pos_points) == 0:
            pos_points = np.array([[w // 2, h // 2]], dtype=np.float32)

        n_pos, n_neg = len(pos_points), len(neg_points)
        points = np.vstack([pos_points, neg_points])
        # Label 1 marks foreground prompts, 0 marks background prompts.
        labels = np.hstack([
            np.ones(n_pos, dtype=np.int32),
            np.zeros(n_neg, dtype=np.int32),
        ])
        logger.debug(f"Using {len(pos_points)} positive, {len(neg_points)} negative points")

        with torch.no_grad():
            masks, scores, _ = predictor.predict(
                point_coords=points,
                point_labels=labels,
                multimask_output=True,
            )

        if masks is None or len(masks) == 0:
            raise SegmentationError("No masks generated")

        if scores is None or len(scores) == 0:
            best_mask = masks[0]
        else:
            best_idx = np.argmax(scores)
            best_mask = masks[best_idx]
            logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")

        return _process_mask(best_mask)
    except Exception as e:
        logger.error(f"Intelligent prompting failed: {e}")
        raise
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run the predictor with a fixed prompt layout.

    Three positive points run down the vertical centre line and four
    negative points sit near the image corners. The highest-scoring mask
    (or the first one when no scores are returned) is passed through
    ``_process_mask``.

    Raises:
        SegmentationError: If the predictor returns no masks.
    """
    h, w = image.shape[:2]
    mid_x = w // 2
    left, right = w // 10, 9 * w // 10
    top, bottom = h // 10, 9 * h // 10

    foreground_pts = np.array(
        [[mid_x, h // 3], [mid_x, h // 2], [mid_x, 2 * h // 3]],
        dtype=np.float32,
    )
    background_pts = np.array(
        [[left, top], [right, top], [left, bottom], [right, bottom]],
        dtype=np.float32,
    )

    points = np.vstack([foreground_pts, background_pts])
    # Three foreground labels (1) followed by four background labels (0).
    labels = np.concatenate([
        np.ones(3, dtype=np.int32),
        np.zeros(4, dtype=np.int32),
    ])

    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )

    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks generated")

    has_scores = scores is not None and len(scores) > 0
    best_idx = np.argmax(scores) if has_scores else 0
    return _process_mask(masks[best_idx])
def _saliency_prompt_points(image: np.ndarray, max_points: int = 3) -> list:
    """Return ``[x, y]`` centroids of the largest salient regions (possibly empty)."""
    h, w = image.shape[:2]
    detector = cv2.saliency.StaticSaliencySpectralResidual_create()
    success, saliency_map = detector.computeSaliency(image)
    if not success:
        return []

    _, salient = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours((salient * 255).astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    centroids = []
    for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:max_points]:
        moments = cv2.moments(contour)
        if moments["m00"] == 0:
            continue
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        if 0 < cx < w and 0 < cy < h:
            centroids.append([cx, cy])
    return centroids


def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Generate positive and negative point prompts for ``image``.

    Positive points are centroids of up to three salient regions. When the
    saliency detector is unavailable, reports failure, or yields no usable
    centroid, three points on the vertical centre line are used instead.
    Negative points sit near the four corners and the top/bottom edge
    centres, clamped so they stay inside small images.

    Returns:
        ``(positive_points, negative_points)`` as float32 arrays of ``[x, y]``.
    """
    h, w = image.shape[:2]
    # Upper bound used to keep every prompt inside the image.
    upper = np.array([max(w - 1, 0), max(h - 1, 0)], dtype=np.float32)

    try:
        try:
            centroids = _saliency_prompt_points(image)
            if not centroids:
                # Covers both "detector reported failure" and "no usable
                # centroid"; previously the former left the result unbound.
                raise Exception("No valid saliency points found")
            logger.debug(f"Generated {len(centroids)} saliency-based points")
            positive_points = np.array(centroids, dtype=np.float32)
        except Exception as e:
            logger.debug(f"Saliency method failed: {e}, using fallback")
            positive_points = np.array([
                [w // 2, h // 3],
                [w // 2, h // 2],
                [w // 2, 2 * h // 3],
            ], dtype=np.float32)

        negative_points = np.array([
            [10, 10],
            [w - 10, 10],
            [10, h - 10],
            [w - 10, h - 10],
            [w // 2, 5],
            [w // 2, h - 5],
        ], dtype=np.float32)
        return positive_points, np.clip(negative_points, 0, upper)
    except Exception as e:
        logger.warning(f"Smart prompt generation failed: {e}")
        positive_points = np.array([[w // 2, h // 2]], dtype=np.float32)
        negative_points = np.array([[10, 10], [w - 10, 10]], dtype=np.float32)
        return positive_points, np.clip(negative_points, 0, upper)
# ============================================================================
# HELPER FUNCTIONS - REFINEMENT
# ============================================================================
def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
                                  predictor: Any, max_iterations: int = 2) -> np.ndarray:
    """Try to improve ``initial_mask`` by re-prompting the predictor.

    Each round scores the current mask with ``_assess_mask_quality``. The
    loop ends when the score exceeds 0.85, when no problem areas are found,
    when a re-prediction does not raise the score, or when a re-prediction
    raises. Otherwise corrective prompts are sent to the predictor and an
    improved mask replaces the current one.

    Returns:
        The best mask found, or ``initial_mask`` if the routine itself fails.
    """
    try:
        best = initial_mask.copy()
        for iteration in range(max_iterations):
            quality_score = _assess_mask_quality(best, image)
            logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")
            if quality_score > 0.85:
                logger.debug(f"Quality sufficient after {iteration} iterations")
                break

            problem_areas = _find_mask_errors(best, image)
            if not np.any(problem_areas):
                logger.debug("No problem areas detected")
                break

            fix_points, fix_labels = _generate_corrective_prompts(image, best, problem_areas)
            if len(fix_points) == 0:
                # Nothing to prompt with this round; move on to the next one.
                continue

            try:
                with torch.no_grad():
                    # NOTE(review): mask_input receives the full-resolution
                    # 0/255 mask; confirm this is the format predictor.predict
                    # expects for mask_input.
                    masks, scores, _ = predictor.predict(
                        point_coords=fix_points,
                        point_labels=fix_labels,
                        mask_input=best[None, :, :],
                        multimask_output=False,
                    )
                if masks is None or len(masks) == 0:
                    continue

                candidate = _process_mask(masks[0])
                if _assess_mask_quality(candidate, image) > quality_score:
                    best = candidate
                    logger.debug(f"Improved mask in iteration {iteration}")
                else:
                    logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
                    break
            except Exception as e:
                logger.debug(f"Refinement iteration {iteration} failed: {e}")
                break
        return best
    except Exception as e:
        logger.warning(f"Iterative refinement failed: {e}")
        return initial_mask
def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
    """Score ``mask`` heuristically; higher is better.

    Four sub-scores are combined with weights 0.3 / 0.2 / 0.3 / 0.2:

    * area: 1.0 when 5%-80% of the frame is foreground, tapering outside;
    * centring: 1.0 minus the smaller of the centroid's normalised x/y
      offsets from the frame centre (0.0 for an empty mask);
    * smoothness: penalises a high density of Canny edges in the mask;
    * connectivity: loses 0.2 per connected component beyond the first.

    Every sub-score lies in [0, 1], so the result does too.

    Returns:
        The weighted score as a plain float, or 0.5 if the assessment fails.
    """
    try:
        h, w = image.shape[:2]
        total_area = h * w
        foreground = mask > 127

        area_ratio = np.sum(foreground) / total_area
        if area_ratio < 0.05:
            area_score = area_ratio / 0.05
        elif area_ratio <= 0.8:
            area_score = 1.0
        else:
            area_score = max(0, 1.0 - (area_ratio - 0.8) / 0.2)

        if np.any(foreground):
            ys, xs = np.where(foreground)
            offset_x = abs(np.mean(xs) / w - 0.5)
            offset_y = abs(np.mean(ys) / h - 0.5)
            center_score = 1.0 - min(offset_x, offset_y)
        else:
            center_score = 0.0

        edges = cv2.Canny(mask, 50, 150)
        edge_density = np.sum(edges > 0) / total_area
        smoothness_score = max(0, 1.0 - edge_density * 10)

        # num_labels counts the background label too, so a single blob gives 2.
        # An empty mask gives 1, which would push the raw value to 1.2 - clamp it.
        num_labels, _ = cv2.connectedComponents(mask)
        connectivity_score = min(1.0, max(0.0, 1.0 - (num_labels - 2) * 0.2))

        scores = [area_score, center_score, smoothness_score, connectivity_score]
        return float(np.average(scores, weights=[0.3, 0.2, 0.3, 0.2]))
    except Exception as e:
        logger.warning(f"Quality assessment failed: {e}")
        return 0.5
def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
    """Flag pixels where image edges and mask edges disagree.

    Canny edges of the grayscale image are XOR-ed with Canny edges of the
    mask, and the result is dilated with a 5x5 elliptical kernel.

    Returns:
        Boolean array, True where a discrepancy was found; all False if the
        detection itself fails.
    """
    try:
        image_edges = cv2.Canny(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 50, 150)
        outline_edges = cv2.Canny(mask, 50, 150)
        mismatch = cv2.bitwise_xor(image_edges, outline_edges)
        grow = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        return cv2.dilate(mismatch, grow, iterations=1) > 0
    except Exception as e:
        logger.warning(f"Error detection failed: {e}")
        return np.zeros_like(mask, dtype=bool)
def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
                                 problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Turn problem regions into point prompts that push the mask the other way.

    For every problem contour larger than 100 px, its centroid becomes a
    prompt: label 1 if the mask is currently background there (< 127),
    label 0 if it is currently foreground.

    Returns:
        ``(points, labels)``; points is ``(N, 2)`` in ``[x, y]`` order. Both
        are empty when nothing qualifies or generation fails.
    """
    no_points = np.array([]).reshape(0, 2)
    no_labels = np.array([], dtype=np.int32)
    try:
        contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        points = []
        labels = []
        for contour in contours:
            if cv2.contourArea(contour) <= 100:
                continue
            moments = cv2.moments(contour)
            if moments["m00"] == 0:
                continue
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])
            points.append([cx, cy])
            # Ask for the opposite of what the mask currently says at this spot.
            labels.append(1 if mask[cy, cx] < 127 else 0)

        if not points:
            return no_points, no_labels
        return np.array(points, dtype=np.float32), np.array(labels, dtype=np.int32)
    except Exception as e:
        logger.warning(f"Corrective prompt generation failed: {e}")
        return np.array([]).reshape(0, 2), np.array([], dtype=np.int32)
# ============================================================================
# HELPER FUNCTIONS - PROCESSING
# ============================================================================
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Normalise a raw mask to a clean 2-D uint8 image containing only 0 and 255.

    Handles boolean masks, float masks of any precision (either [0, 1] or
    0-255, with NaNs treated as 0 and out-of-range values clipped) and
    integer masks (0/1 masks are scaled to 0/255). The result is closed and
    opened with a 3x3 kernel and thresholded at 127.

    Returns:
        The processed mask. If processing fails, a centred rectangle of the
        mask's size (256x256 when the size cannot be determined) is
        returned instead.
    """
    try:
        mask = np.asarray(mask)
        if mask.ndim > 2:
            mask = mask.squeeze()
        if mask.ndim > 2:
            mask = mask[:, :, 0] if mask.shape[2] > 0 else mask.sum(axis=2)

        if mask.dtype == bool:
            mask = mask.astype(np.uint8) * 255
        elif np.issubdtype(mask.dtype, np.floating):
            # Every float width is covered, not just float32/float64.
            mask = np.nan_to_num(mask.astype(np.float32), nan=0.0)
            if mask.max() <= 1.0:
                mask = mask * 255
            # Clip before casting so negative or oversized values cannot wrap.
            mask = np.clip(mask, 0, 255).astype(np.uint8)
        else:
            if mask.max() <= 1:
                # Integer 0/1 masks would otherwise vanish at the 127 threshold.
                mask = mask * 255
            mask = np.clip(mask, 0, 255).astype(np.uint8)

        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask
    except Exception as e:
        logger.error(f"Mask processing failed: {e}")
        # getattr keeps the handler itself from raising when mask has no shape.
        shape = getattr(mask, "shape", ())
        h, w = shape[:2] if len(shape) >= 2 else (256, 256)
        fallback = np.zeros((h, w), dtype=np.uint8)
        fallback[h // 4:3 * h // 4, w // 4:3 * w // 4] = 255
        return fallback
def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool:
"""Validate that the mask meets quality criteria"""
try:
h, w = image_shape
mask_area = np.sum(mask > 127)
total_area = h * w
area_ratio = mask_area / total_area
if area_ratio < 0.05 or area_ratio > 0.8:
logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}")
return False
mask_binary = mask > 127
mask_center_y, mask_center_x = np.where(mask_binary)
if len(mask_center_y) == 0:
logger.warning("Empty mask")
return False
center_y = np.mean(mask_center_y)
center_x = np.mean(mask_center_x)
if center_y < h * 0.2 or center_y > h * 0.9:
logger.warning(f"Mask center too far from expected person location: y={center_y/h:.2f}")
return False
return True
except Exception as e:
logger.warning(f"Mask validation error: {e}")
return True
def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
    """Produce a mask without any AI model.

    Strategy, in order:

    1. Background subtraction: pixels differing by more than 30 gray levels
       from the median border colour are foreground; used only if the
       result passes ``_validate_mask_quality``.
    2. A centred ellipse (radii ``w // 3`` and ``h // 2.5``).
    3. If even that fails, a centred rectangle.
    """
    try:
        logger.info("Using fallback segmentation strategy")
        h, w = image.shape[:2]

        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # The frame border is taken as a sample of the background.
            border = np.concatenate([gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]])
            background_level = np.median(border)
            distance = np.abs(gray.astype(float) - background_level)
            candidate = (distance > 30).astype(np.uint8) * 255

            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
                candidate = cv2.morphologyEx(candidate, operation, kernel)

            if _validate_mask_quality(candidate, image.shape[:2]):
                logger.info("Background subtraction fallback successful")
                return candidate
        except Exception as e:
            logger.warning(f"Background subtraction fallback failed: {e}")

        ellipse_mask = np.zeros((h, w), dtype=np.uint8)
        cx, cy = w // 2, h // 2
        rx, ry = w // 3, h // 2.5
        yy, xx = np.ogrid[:h, :w]
        inside = ((xx - cx) / rx) ** 2 + ((yy - cy) / ry) ** 2 <= 1
        ellipse_mask[inside] = 255
        logger.info("Using geometric fallback mask")
        return ellipse_mask
    except Exception as e:
        logger.error(f"All fallback strategies failed: {e}")
        h, w = image.shape[:2]
        box_mask = np.zeros((h, w), dtype=np.uint8)
        box_mask[h // 6:5 * h // 6, w // 4:3 * w // 4] = 255
        return box_mask
def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
    """Run ``processor`` on ``(image, mask)`` through whichever interface it offers.

    Tried in order: ``processor.infer``, ``processor.process``, then calling
    ``processor`` itself.

    Returns:
        The result passed through ``_process_mask``, or ``None`` when the
        processor has no known interface, returns ``None``, or raises.
    """
    try:
        if hasattr(processor, 'infer'):
            run = processor.infer
        elif hasattr(processor, 'process'):
            run = processor.process
        elif callable(processor):
            run = processor
        else:
            logger.warning("Unknown MatAnyone interface")
            return None

        raw = run(image, mask)
        if raw is None:
            return None

        cleaned = _process_mask(raw)
        logger.debug("MatAnyone refinement successful")
        return cleaned
    except Exception as e:
        logger.warning(f"MatAnyone processing error: {e}")
        return None
def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
    """Smooth ``mask`` with a box-filter guided filter steered by ``guide``.

    Args:
        guide: Guidance image; a 3-D array is converted from BGR to gray.
        mask: Mask to filter, scaled here from 0-255 to [0, 1].
        radius: Half-width of the box window (window side is ``2*radius+1``).
        eps: Regulariser added to the guide variance.

    Returns:
        The filtered mask as uint8 in 0-255, or ``mask`` unchanged on failure.
    """
    try:
        side = 2 * radius + 1

        def box_mean(values: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(values, -1, (side, side))

        gray = guide if len(guide.shape) != 3 else cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)
        guide_f = gray.astype(np.float32) / 255.0
        mask_f = mask.astype(np.float32) / 255.0

        guide_mean = box_mean(guide_f)
        mask_mean = box_mean(mask_f)
        covariance = box_mean(guide_f * mask_f) - guide_mean * mask_mean
        variance = box_mean(guide_f * guide_f) - guide_mean * guide_mean

        # Per-window linear model: mask ~= slope * guide + offset.
        slope = covariance / (variance + eps)
        offset = mask_mean - slope * guide_mean

        filtered = box_mean(slope) * guide_f + box_mean(offset)
        return np.clip(filtered * 255, 0, 255).astype(np.uint8)
    except Exception as e:
        logger.warning(f"Guided filter approximation failed: {e}")
        return mask
# ============================================================================
# HELPER FUNCTIONS - COMPOSITING
# ============================================================================
def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Blend ``frame`` over ``background`` with a feathered alpha built from ``mask``.

    The mask is binarised at 100, closed and opened (5x5 ellipse), blurred,
    gamma-adjusted (power 0.8) and pushed slightly towards 0/1 before being
    used as alpha. Frame colours near the edge are nudged towards the
    background by ``_color_match_edges``.

    Raises:
        Exception: Any failure is logged and re-raised for the caller to handle.
    """
    try:
        cutoff = 100
        _, binary = cv2.threshold(mask, cutoff, 255, cv2.THRESH_BINARY)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
            binary = cv2.morphologyEx(binary, operation, kernel)

        alpha = cv2.GaussianBlur(binary.astype(np.float32), (5, 5), 1.0)
        alpha = alpha / 255.0
        alpha = np.power(alpha, 0.8)
        # Sharpen the transition: boost confident foreground, damp the rest.
        alpha = np.where(alpha > 0.5,
                         np.minimum(alpha * 1.1, 1.0),
                         alpha * 0.9)

        foreground = _color_match_edges(frame, background, alpha).astype(np.float32)
        backdrop = background.astype(np.float32)
        alpha_rgb = np.stack([alpha, alpha, alpha], axis=2)

        blended = foreground * alpha_rgb + backdrop * (1 - alpha_rgb)
        return np.clip(blended, 0, 255).astype(np.uint8)
    except Exception as e:
        logger.error(f"Advanced compositing error: {e}")
        raise
def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Mix 10% of the background colour into ``frame`` along alpha edges.

    Edge pixels are those where the absolute Sobel response of ``alpha``
    (first derivative in both x and y) exceeds 0.1. Only the first three
    channels are adjusted.

    Returns:
        The adjusted frame as uint8; ``frame`` itself when there are no edge
        pixels or the adjustment fails.
    """
    try:
        response = np.abs(cv2.Sobel(alpha, cv2.CV_64F, 1, 1, ksize=3))
        on_edge = response > 0.1
        if not np.any(on_edge):
            return frame

        strength = 0.1
        adjusted = frame.copy().astype(np.float32)
        backdrop = background.astype(np.float32)

        rgb = adjusted[:, :, :3]
        mixed = rgb * (1 - strength) + backdrop[:, :, :3] * strength
        adjusted[:, :, :3] = np.where(on_edge[:, :, np.newaxis], mixed, rgb)

        return np.clip(adjusted, 0, 255).astype(np.uint8)
    except Exception as e:
        logger.warning(f"Color matching failed: {e}")
        return frame
def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Hard-edged compositing used when the advanced path is unavailable.

    The background is resized to the frame, the mask is reduced to one
    channel, rescaled if normalised, binarised at 127, and used as a 0/1
    selector between frame and background.

    Returns:
        The composited uint8 image, or ``frame`` unchanged on failure.
    """
    try:
        logger.info("Using simple compositing fallback")
        frame_h, frame_w = frame.shape[:2]
        background = cv2.resize(background, (frame_w, frame_h))

        if mask.ndim == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)

        binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        selector = binary.astype(np.float32) / 255.0
        selector_rgb = np.stack([selector, selector, selector], axis=2)

        composed = frame * selector_rgb + background * (1 - selector_rgb)
        return composed.astype(np.uint8)
    except Exception as e:
        logger.error(f"Simple compositing failed: {e}")
        return frame
# ============================================================================
# HELPER FUNCTIONS - BACKGROUND CREATION
# ============================================================================
def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
"""Create solid color background"""
color_hex = bg_config["colors"][0].lstrip('#')
color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
color_bgr = color_rgb[::-1]
return np.full((height, width, 3), color_bgr, dtype=np.uint8)
def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Create a BGR gradient canvas from ``bg_config``.

    Args:
        bg_config: Mapping with ``"colors"`` (hex strings, ``#rrggbb`` or the
            3-digit shorthand ``#rgb``) and an optional ``"direction"``:
            ``vertical`` (default), ``horizontal``, ``diagonal``, ``radial``
            or ``soft_radial``. Unknown directions render as vertical.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        ``(height, width, 3)`` uint8 array in BGR order; neutral gray if
        gradient creation fails.
    """
    try:
        direction = bg_config.get("direction", "vertical")

        rgb_colors = []
        for color_hex in bg_config["colors"]:
            color_hex = color_hex.lstrip('#')
            if len(color_hex) == 3:
                # Expand CSS-style shorthand so the pair slicing below works.
                color_hex = "".join(digit * 2 for digit in color_hex)
            rgb_colors.append(tuple(int(color_hex[i:i + 2], 16) for i in (0, 2, 4)))
        if not rgb_colors:
            rgb_colors = [(128, 128, 128)]

        if direction == "horizontal":
            background = _create_horizontal_gradient(rgb_colors, width, height)
        elif direction == "diagonal":
            background = _create_diagonal_gradient(rgb_colors, width, height)
        elif direction in ("radial", "soft_radial"):
            background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
        else:
            # "vertical" and any unrecognised direction.
            background = _create_vertical_gradient(rgb_colors, width, height)

        # The gradient helpers work in RGB; the rest of the module works in BGR.
        return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)
    except Exception as e:
        logger.error(f"Gradient creation error: {e}")
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create a top-to-bottom gradient through ``colors``, vectorised with NumPy.

    Row ``y`` uses progress ``y / height``, so the first row is exactly the
    first colour and the last row stops just short of the final colour.

    Returns:
        ``(height, width, 3)`` uint8 array in the channel order of ``colors``.
    """
    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    # max(..., 1) only guards the empty canvas; it never changes a real result.
    progress = np.arange(height, dtype=np.float64) / max(height, 1)
    for c in range(3):
        # One value per row, broadcast across the whole row.
        row_values = _vectorized_color_interpolation(colors, progress, c)
        gradient[:, :, c] = row_values[:, np.newaxis]
    return gradient
def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create a left-to-right gradient through ``colors``, vectorised with NumPy.

    Column ``x`` uses progress ``x / width``, so the first column is exactly
    the first colour and the last column stops just short of the final colour.

    Returns:
        ``(height, width, 3)`` uint8 array in the channel order of ``colors``.
    """
    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    # max(..., 1) only guards the empty canvas; it never changes a real result.
    progress = np.arange(width, dtype=np.float64) / max(width, 1)
    for c in range(3):
        # One value per column, broadcast down the whole column.
        column_values = _vectorized_color_interpolation(colors, progress, c)
        gradient[:, :, c] = column_values[np.newaxis, :]
    return gradient
def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create a top-left to bottom-right gradient through ``colors``.

    Progress at pixel ``(x, y)`` is ``(x + y) / (width + height)``, clipped
    to [0, 1].

    Returns:
        ``(height, width, 3)`` uint8 array in the channel order of ``colors``.
    """
    row_index = np.arange(height).reshape(-1, 1)
    col_index = np.arange(width).reshape(1, -1)
    # Broadcasting the two index vectors yields x + y for every pixel.
    progress = np.clip((col_index + row_index) / (width + height), 0, 1)

    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for channel in range(3):
        gradient[:, :, channel] = _vectorized_color_interpolation(colors, progress, channel)
    return gradient
def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
    """Create a gradient radiating from the canvas centre through ``colors``.

    Progress is the distance from the centre divided by the centre-to-corner
    distance, clipped to [0, 1].

    Args:
        colors: Colour stops as 3-tuples.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        soft: When true, progress is raised to the power 0.7.

    Returns:
        ``(height, width, 3)`` uint8 array in the channel order of ``colors``.
    """
    center_x, center_y = width // 2, height // 2
    max_distance = np.sqrt(center_x**2 + center_y**2)
    if max_distance == 0:
        # Canvases at most one pixel wide and high have their centre at the
        # origin; avoid 0/0 so those pixels get the first colour, not NaN.
        max_distance = 1.0

    y_coords, x_coords = np.mgrid[0:height, 0:width]
    distances = np.sqrt((x_coords - center_x)**2 + (y_coords - center_y)**2)
    progress = np.clip(distances / max_distance, 0, 1)
    if soft:
        progress = np.power(progress, 0.7)

    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for c in range(3):
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
    return gradient
def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray:
"""Vectorized color interpolation for performance"""
if len(colors) == 1:
return np.full_like(progress, colors[0][channel], dtype=np.uint8)
num_segments = len(colors) - 1
segment_progress = progress * num_segments
segment_indices = np.floor(segment_progress).astype(int)
segment_indices = np.clip(segment_indices, 0, num_segments - 1)
local_progress = segment_progress - segment_indices
start_colors = np.array([colors[i][channel] for i in range(len(colors))])
end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))])
start_vals = start_colors[segment_indices]
end_vals = end_colors[segment_indices]
result = start_vals + (end_vals - start_vals) * local_progress
return np.clip(result, 0, 255).astype(np.uint8)
def _interpolate_color(colors: list, progress: float) -> tuple:
"""Interpolate between multiple colors"""
if len(colors) == 1:
return colors[0]
elif len(colors) == 2:
r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress)
g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress)
b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress)
return (r, g, b)
else:
segment = progress * (len(colors) - 1)
idx = int(segment)
local_progress = segment - idx
if idx >= len(colors) - 1:
return colors[-1]
c1, c2 = colors[idx], colors[idx + 1]
r = int(c1[0] + (c2[0] - c1[0]) * local_progress)
g = int(c1[1] + (c2[1] - c1[1]) * local_progress)
b = int(c1[2] + (c2[2] - c1[2]) * local_progress)
return (r, g, b)
def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray:
"""Apply brightness and contrast adjustments to background"""
try:
brightness = bg_config.get("brightness", 1.0)
contrast = bg_config.get("contrast", 1.0)
if brightness != 1.0 or contrast != 1.0:
background = background.astype(np.float32)
background = background * contrast * brightness
background = np.clip(background, 0, 255).astype(np.uint8)
return background
except Exception as e:
logger.warning(f"Background adjustment failed: {e}")
return background