"""
Computer Vision Processing Module for BackgroundFX Pro

Contains segmentation, mask refinement, background replacement, and helper
functions.
"""

import os

# Cap BLAS/OpenMP thread pools before numpy/torch are imported, unless the
# caller has already configured them.
if 'OMP_NUM_THREADS' not in os.environ:
    os.environ['OMP_NUM_THREADS'] = '4'
    os.environ['MKL_NUM_THREADS'] = '4'

import logging
from typing import Optional, Tuple, Dict, Any

import numpy as np
import cv2
import torch

logger = logging.getLogger(__name__)


# Feature flags for the enhanced pipeline; each can be disabled to roll back
# to the original behavior.
USE_ENHANCED_SEGMENTATION = True
USE_AUTO_TEMPORAL_CONSISTENCY = True
USE_INTELLIGENT_PROMPTING = True
USE_ITERATIVE_REFINEMENT = True


# Colors must be 6-digit hex strings; the parsers below do not accept
# 3-digit shorthand.
PROFESSIONAL_BACKGROUNDS = {
    "office_modern": {
        "name": "Modern Office",
        "type": "gradient",
        "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
        "direction": "diagonal",
        "description": "Clean, contemporary office environment",
        "brightness": 0.95,
        "contrast": 1.1
    },
    "studio_blue": {
        "name": "Professional Blue",
        "type": "gradient",
        "colors": ["#1e3c72", "#2a5298", "#3498db"],
        "direction": "radial",
        "description": "Broadcast-quality blue studio",
        "brightness": 0.9,
        "contrast": 1.2
    },
    "studio_green": {
        "name": "Broadcast Green",
        "type": "color",
        "colors": ["#00b894"],
        "chroma_key": True,
        "description": "Professional green screen replacement",
        "brightness": 1.0,
        "contrast": 1.0
    },
    "minimalist": {
        "name": "Minimalist White",
        "type": "gradient",
        "colors": ["#ffffff", "#f1f2f6", "#dddddd"],
        "direction": "soft_radial",
        "description": "Clean, minimal background",
        "brightness": 0.98,
        "contrast": 0.9
    },
    "warm_gradient": {
        "name": "Warm Sunset",
        "type": "gradient",
        "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
        "direction": "diagonal",
        "description": "Warm, inviting atmosphere",
        "brightness": 0.85,
        "contrast": 1.15
    },
    "tech_dark": {
        "name": "Tech Dark",
        "type": "gradient",
        "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
        "direction": "vertical",
        "description": "Modern tech/gaming setup",
        "brightness": 0.7,
        "contrast": 1.3
    }
}
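

# Usage sketch (illustrative only, not called at import time): render one of
# the presets above at 720p. `create_professional_background` is defined later
# in this module; `_demo_render_preset` is a hypothetical helper.
def _demo_render_preset() -> np.ndarray:
    """Render the 'studio_blue' preset as a 1280x720 BGR image."""
    config = PROFESSIONAL_BACKGROUNDS["studio_blue"]
    return create_professional_background(config, width=1280, height=720)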


class SegmentationError(Exception):
    """Custom exception for segmentation failures"""
    pass


class MaskRefinementError(Exception):
    """Custom exception for mask refinement failures"""
    pass


class BackgroundReplacementError(Exception):
    """Custom exception for background replacement failures"""
    pass
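

# Error-handling sketch: callers that disable fallbacks should catch the
# stage-specific exception. `_demo_strict_segmentation` is a hypothetical
# helper; `image` and `predictor` are placeholders for the caller's objects.
def _demo_strict_segmentation(image: np.ndarray, predictor: Any) -> Optional[np.ndarray]:
    """Run segmentation with fallbacks disabled, returning None on failure."""
    try:
        return segment_person_hq(image, predictor, fallback_enabled=False)
    except SegmentationError as exc:
        logger.error(f"Segmentation failed without fallback: {exc}")
        return None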


def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """High-quality person segmentation with intelligent automation"""
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    logger.debug("Using ENHANCED segmentation with intelligent automation")

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
                return _fallback_segmentation(image)
            else:
                raise SegmentationError("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            if fallback_enabled:
                return _fallback_segmentation(image)
            else:
                raise SegmentationError(f"Predictor setup failed: {e}")

        if USE_INTELLIGENT_PROMPTING:
            mask = _segment_with_intelligent_prompts(image, predictor)
        else:
            mask = _segment_with_basic_prompts(image, predictor)

        if USE_ITERATIVE_REFINEMENT and mask is not None:
            mask = _auto_refine_mask_iteratively(image, mask, predictor)

        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            if fallback_enabled:
                return _fallback_segmentation(image)
            else:
                raise SegmentationError("Poor mask quality")

        logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError:
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        if fallback_enabled:
            return _fallback_segmentation(image)
        else:
            raise SegmentationError(f"Unexpected error: {e}")
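

# Usage sketch: with `predictor=None` and fallbacks enabled, the function
# exercises the classical (non-SAM2) fallback path, so this runs without any
# model weights. `_demo_fallback_segmentation` and its synthetic frame are
# illustrative assumptions, not part of the pipeline.
def _demo_fallback_segmentation() -> np.ndarray:
    """Segment a synthetic frame using only the fallback path."""
    frame = np.full((480, 640, 3), 200, dtype=np.uint8)   # bright background
    cv2.circle(frame, (320, 240), 120, (30, 30, 30), -1)  # dark "subject"
    return segment_person_hq(frame, predictor=None, fallback_enabled=True)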


def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """Original version of person segmentation, kept for rollback"""
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
                return _fallback_segmentation(image)
            else:
                raise SegmentationError("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            if fallback_enabled:
                return _fallback_segmentation(image)
            else:
                raise SegmentationError(f"Predictor setup failed: {e}")

        h, w = image.shape[:2]

        # Fixed grid of positive prompts covering where a centered subject's
        # head, torso, and shoulders usually sit.
        points = np.array([
            [w//2, h//4],
            [w//2, h//2],
            [w//2, 3*h//4],
            [w//3, h//2],
            [2*w//3, h//2],
            [w//2, h//6],
            [w//4, 2*h//3],
            [3*w//4, 2*h//3],
        ], dtype=np.float32)

        labels = np.ones(len(points), dtype=np.int32)

        try:
            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True
                )
        except Exception as e:
            logger.error(f"SAM2 prediction failed: {e}")
            if fallback_enabled:
                return _fallback_segmentation(image)
            else:
                raise SegmentationError(f"Prediction failed: {e}")

        if masks is None or len(masks) == 0:
            logger.warning("SAM2 returned no masks")
            if fallback_enabled:
                return _fallback_segmentation(image)
            else:
                raise SegmentationError("No masks generated")

        if scores is None or len(scores) == 0:
            logger.warning("SAM2 returned no scores")
            best_mask = masks[0]
        else:
            best_idx = np.argmax(scores)
            best_mask = masks[best_idx]
            logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")

        mask = _process_mask(best_mask)

        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            if fallback_enabled:
                return _fallback_segmentation(image)
            else:
                raise SegmentationError("Poor mask quality")

        logger.debug(f"Segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError:
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        if fallback_enabled:
            return _fallback_segmentation(image)
        else:
            raise SegmentationError(f"Unexpected error: {e}")


def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
                   fallback_enabled: bool = True) -> np.ndarray:
    """Enhanced mask refinement with MatAnyone and robust fallbacks"""
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")

    try:
        mask = _process_mask(mask)

        if matanyone_processor is not None:
            try:
                logger.debug("Attempting MatAnyone refinement")
                refined_mask = _matanyone_refine(image, mask, matanyone_processor)

                if refined_mask is not None and _validate_mask_quality(refined_mask, image.shape[:2]):
                    logger.debug("MatAnyone refinement successful")
                    return refined_mask
                else:
                    logger.warning("MatAnyone produced poor quality mask")

            except Exception as e:
                logger.warning(f"MatAnyone refinement failed: {e}")

        if fallback_enabled:
            logger.debug("Using enhanced OpenCV refinement")
            return enhance_mask_opencv_advanced(image, mask)
        else:
            raise MaskRefinementError("MatAnyone failed and fallback disabled")

    except MaskRefinementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected mask refinement error: {e}")
        if fallback_enabled:
            return enhance_mask_opencv_advanced(image, mask)
        else:
            raise MaskRefinementError(f"Unexpected error: {e}")


def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Advanced OpenCV-based mask enhancement with multiple techniques"""
    try:
        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)

        # Edge-preserving smoothing, then edge-aware refinement against the image.
        refined_mask = cv2.bilateralFilter(mask, 9, 75, 75)
        refined_mask = _guided_filter_approx(image, refined_mask, radius=8, eps=0.2)

        # Close small holes, then remove small speckles.
        kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel_close)

        kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_OPEN, kernel_open)

        # Soften edges slightly before re-binarizing.
        refined_mask = cv2.GaussianBlur(refined_mask, (3, 3), 0.8)
        _, refined_mask = cv2.threshold(refined_mask, 127, 255, cv2.THRESH_BINARY)

        return refined_mask

    except Exception as e:
        logger.warning(f"Enhanced OpenCV refinement failed: {e}")
        return cv2.GaussianBlur(mask, (5, 5), 1.0)


def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
                          fallback_enabled: bool = True) -> np.ndarray:
    """Enhanced background replacement with comprehensive error handling"""
    if frame is None or mask is None or background is None:
        raise BackgroundReplacementError("Invalid input frame, mask, or background")

    try:
        background = cv2.resize(background, (frame.shape[1], frame.shape[0]),
                                interpolation=cv2.INTER_LANCZOS4)

        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        # Rescale normalized masks BEFORE any uint8 cast; casting a float mask
        # in [0, 1] to uint8 first would truncate it to {0, 1}.
        if mask.max() <= 1.0:
            logger.debug("Converting normalized mask to 0-255 range")
            mask = (mask * 255).astype(np.uint8)
        elif mask.dtype != np.uint8:
            mask = np.clip(mask, 0, 255).astype(np.uint8)

        try:
            result = _advanced_compositing(frame, mask, background)
            logger.debug("Advanced compositing successful")
            return result

        except Exception as e:
            logger.warning(f"Advanced compositing failed: {e}")
            if fallback_enabled:
                return _simple_compositing(frame, mask, background)
            else:
                raise BackgroundReplacementError(f"Advanced compositing failed: {e}")

    except BackgroundReplacementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected background replacement error: {e}")
        if fallback_enabled:
            return _simple_compositing(frame, mask, background)
        else:
            raise BackgroundReplacementError(f"Unexpected error: {e}")
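

# End-to-end sketch: segment a frame, build a preset background, and composite.
# All inputs here are synthetic and the helper is hypothetical; in real use
# `frame` comes from a video reader and `predictor` is a loaded SAM2 predictor.
def _demo_replace_background() -> np.ndarray:
    """Full pipeline on a synthetic frame using fallback segmentation."""
    frame = np.full((480, 640, 3), 180, dtype=np.uint8)
    cv2.rectangle(frame, (240, 120), (400, 480), (40, 40, 40), -1)  # "person"
    mask = segment_person_hq(frame, predictor=None, fallback_enabled=True)
    bg = create_professional_background(PROFESSIONAL_BACKGROUNDS["office_modern"], 640, 480)
    return replace_background_hq(frame, mask, bg)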


def create_professional_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Enhanced professional background creation with quality improvements"""
    try:
        if bg_config["type"] == "color":
            background = _create_solid_background(bg_config, width, height)
        elif bg_config["type"] == "gradient":
            background = _create_gradient_background_enhanced(bg_config, width, height)
        else:
            background = np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)

        background = _apply_background_adjustments(background, bg_config)

        return background

    except Exception as e:
        logger.error(f"Background creation error: {e}")
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)


def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """Enhanced video file validation with detailed checks"""
    if not video_path or not os.path.exists(video_path):
        return False, "Video file not found"

    try:
        file_size = os.path.getsize(video_path)
        if file_size == 0:
            return False, "Video file is empty"

        if file_size > 2 * 1024 * 1024 * 1024:
            return False, "Video file too large (>2GB)"

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return False, "Cannot open video file"

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        cap.release()

        if frame_count == 0:
            return False, "Video appears to be empty (0 frames)"

        if fps <= 0 or fps > 120:
            return False, f"Invalid frame rate: {fps}"

        if width <= 0 or height <= 0:
            return False, f"Invalid resolution: {width}x{height}"

        if width > 4096 or height > 4096:
            return False, f"Resolution too high: {width}x{height} (max 4096x4096)"

        duration = frame_count / fps
        if duration > 300:
            return False, f"Video too long: {duration:.1f}s (max 300s)"

        return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"

    except Exception as e:
        return False, f"Error validating video: {str(e)}"
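

# Usage sketch: gate processing on validation. The helper name and path below
# are placeholders.
def _demo_validate(video_path: str = "input.mp4") -> bool:
    """Log the validation verdict for a video file and return it."""
    ok, detail = validate_video_file(video_path)
    if ok:
        logger.info(f"Video accepted: {detail}")
    else:
        logger.warning(f"Video rejected: {detail}")
    return ok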


def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Intelligent automatic prompt generation for segmentation"""
    try:
        h, w = image.shape[:2]
        pos_points, neg_points = _generate_smart_prompts(image)

        if len(pos_points) == 0:
            pos_points = np.array([[w//2, h//2]], dtype=np.float32)

        points = np.vstack([pos_points, neg_points])
        labels = np.hstack([
            np.ones(len(pos_points), dtype=np.int32),
            np.zeros(len(neg_points), dtype=np.int32)
        ])

        logger.debug(f"Using {len(pos_points)} positive, {len(neg_points)} negative points")

        with torch.no_grad():
            masks, scores, _ = predictor.predict(
                point_coords=points,
                point_labels=labels,
                multimask_output=True
            )

        if masks is None or len(masks) == 0:
            raise SegmentationError("No masks generated")

        if scores is not None and len(scores) > 0:
            best_idx = np.argmax(scores)
            best_mask = masks[best_idx]
            logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
        else:
            best_mask = masks[0]

        return _process_mask(best_mask)

    except Exception as e:
        logger.error(f"Intelligent prompting failed: {e}")
        raise


def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Basic prompting method for segmentation"""
    h, w = image.shape[:2]

    positive_points = np.array([
        [w//2, h//3],
        [w//2, h//2],
        [w//2, 2*h//3],
    ], dtype=np.float32)

    negative_points = np.array([
        [w//10, h//10],
        [9*w//10, h//10],
        [w//10, 9*h//10],
        [9*w//10, 9*h//10],
    ], dtype=np.float32)

    points = np.vstack([positive_points, negative_points])
    # Three positive (center-column) points followed by four negative corners.
    labels = np.array([1, 1, 1, 0, 0, 0, 0], dtype=np.int32)

    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True
        )

    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks generated")

    best_idx = np.argmax(scores) if scores is not None and len(scores) > 0 else 0
    best_mask = masks[best_idx]

    return _process_mask(best_mask)


def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Generate optimal positive/negative points automatically"""
    try:
        h, w = image.shape[:2]

        try:
            # cv2.saliency is provided by opencv-contrib; if it is missing or
            # the computation fails, fall through to the fixed-point fallback.
            saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
            success, saliency_map = saliency.computeSaliency(image)
            if not success:
                raise Exception("Saliency computation failed")

            saliency_thresh = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)[1]
            contours, _ = cv2.findContours((saliency_thresh * 255).astype(np.uint8),
                                           cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            positive_points = []
            if contours:
                # Use the centroids of the three largest salient regions.
                for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
                    M = cv2.moments(contour)
                    if M["m00"] != 0:
                        cx = int(M["m10"] / M["m00"])
                        cy = int(M["m01"] / M["m00"])
                        if 0 < cx < w and 0 < cy < h:
                            positive_points.append([cx, cy])

            if positive_points:
                logger.debug(f"Generated {len(positive_points)} saliency-based points")
                positive_points = np.array(positive_points, dtype=np.float32)
            else:
                raise Exception("No valid saliency points found")

        except Exception as e:
            logger.debug(f"Saliency method failed: {e}, using fallback")
            positive_points = np.array([
                [w//2, h//3],
                [w//2, h//2],
                [w//2, 2*h//3],
            ], dtype=np.float32)

        # Negative points hug the corners and the top/bottom edge midpoints,
        # where the subject is least likely to be.
        negative_points = np.array([
            [10, 10],
            [w-10, 10],
            [10, h-10],
            [w-10, h-10],
            [w//2, 5],
            [w//2, h-5],
        ], dtype=np.float32)

        return positive_points, negative_points

    except Exception as e:
        logger.warning(f"Smart prompt generation failed: {e}")
        h, w = image.shape[:2]
        positive_points = np.array([[w//2, h//2]], dtype=np.float32)
        negative_points = np.array([[10, 10], [w-10, 10]], dtype=np.float32)
        return positive_points, negative_points


def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
                                  predictor: Any, max_iterations: int = 2) -> np.ndarray:
    """Automatically refine mask based on quality assessment"""
    try:
        current_mask = initial_mask.copy()

        for iteration in range(max_iterations):
            quality_score = _assess_mask_quality(current_mask, image)
            logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")

            if quality_score > 0.85:
                logger.debug(f"Quality sufficient after {iteration} iterations")
                break

            problem_areas = _find_mask_errors(current_mask, image)

            if np.any(problem_areas):
                corrective_points, corrective_labels = _generate_corrective_prompts(
                    image, current_mask, problem_areas
                )

                if len(corrective_points) > 0:
                    try:
                        with torch.no_grad():
                            masks, scores, _ = predictor.predict(
                                point_coords=corrective_points,
                                point_labels=corrective_labels,
                                mask_input=current_mask[None, :, :],
                                multimask_output=False
                            )

                        if masks is not None and len(masks) > 0:
                            refined_mask = _process_mask(masks[0])

                            if _assess_mask_quality(refined_mask, image) > quality_score:
                                current_mask = refined_mask
                                logger.debug(f"Improved mask in iteration {iteration}")
                            else:
                                logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
                                break

                    except Exception as e:
                        logger.debug(f"Refinement iteration {iteration} failed: {e}")
                        break
            else:
                logger.debug("No problem areas detected")
                break

        return current_mask

    except Exception as e:
        logger.warning(f"Iterative refinement failed: {e}")
        return initial_mask


def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
    """Assess mask quality automatically"""
    try:
        h, w = image.shape[:2]
        scores = []

        # 1. Area: a person mask should cover roughly 5-80% of the frame.
        mask_area = np.sum(mask > 127)
        total_area = h * w
        area_ratio = mask_area / total_area

        if 0.05 <= area_ratio <= 0.8:
            area_score = 1.0
        elif area_ratio < 0.05:
            area_score = area_ratio / 0.05
        else:
            area_score = max(0, 1.0 - (area_ratio - 0.8) / 0.2)
        scores.append(area_score)

        # 2. Centering: penalize masks whose centroid drifts from the frame center.
        mask_binary = mask > 127
        if np.any(mask_binary):
            mask_center_y, mask_center_x = np.where(mask_binary)
            center_y = np.mean(mask_center_y) / h
            center_x = np.mean(mask_center_x) / w

            center_score = 1.0 - min(abs(center_x - 0.5), abs(center_y - 0.5))
            scores.append(center_score)
        else:
            scores.append(0.0)

        # 3. Smoothness: a ragged boundary produces a high edge density.
        edges = cv2.Canny(mask, 50, 150)
        edge_density = np.sum(edges > 0) / total_area
        smoothness_score = max(0, 1.0 - edge_density * 10)
        scores.append(smoothness_score)

        # 4. Connectivity: ideally one component (plus the background label).
        num_labels, _ = cv2.connectedComponents(mask)
        connectivity_score = max(0, 1.0 - (num_labels - 2) * 0.2)
        scores.append(connectivity_score)

        # Weighted average of [area, centering, smoothness, connectivity].
        weights = [0.3, 0.2, 0.3, 0.2]
        overall_score = np.average(scores, weights=weights)

        return overall_score

    except Exception as e:
        logger.warning(f"Quality assessment failed: {e}")
        return 0.5
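

# Worked example (synthetic, hypothetical helper): a centered solid ellipse
# should score near 1.0 on all four criteria, while an empty mask scores 0 on
# the area and centering criteria.
def _demo_quality_scores() -> Tuple[float, float]:
    """Score a clean elliptical mask and an empty mask for comparison."""
    image = np.zeros((360, 480, 3), dtype=np.uint8)
    good = np.zeros((360, 480), dtype=np.uint8)
    cv2.ellipse(good, (240, 180), (100, 150), 0, 0, 360, 255, -1)
    empty = np.zeros((360, 480), dtype=np.uint8)
    return _assess_mask_quality(good, image), _assess_mask_quality(empty, image)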


def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
    """Identify problematic areas in mask"""
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        mask_edges = cv2.Canny(mask, 50, 150)
        edge_discrepancy = cv2.bitwise_xor(edges, mask_edges)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        error_regions = cv2.dilate(edge_discrepancy, kernel, iterations=1)
        return error_regions > 0
    except Exception as e:
        logger.warning(f"Error detection failed: {e}")
        return np.zeros_like(mask, dtype=bool)


def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
                                 problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Generate corrective prompts based on problem areas"""
    try:
        contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        corrective_points = []
        corrective_labels = []

        for contour in contours:
            if cv2.contourArea(contour) > 100:
                M = cv2.moments(contour)
                if M["m00"] != 0:
                    cx = int(M["m10"] / M["m00"])
                    cy = int(M["m01"] / M["m00"])

                    current_mask_value = mask[cy, cx]

                    # Problem spots outside the mask become positive prompts;
                    # spots inside it become negative prompts.
                    if current_mask_value < 127:
                        corrective_points.append([cx, cy])
                        corrective_labels.append(1)
                    else:
                        corrective_points.append([cx, cy])
                        corrective_labels.append(0)

        return (np.array(corrective_points, dtype=np.float32) if corrective_points else np.array([]).reshape(0, 2),
                np.array(corrective_labels, dtype=np.int32) if corrective_labels else np.array([], dtype=np.int32))

    except Exception as e:
        logger.warning(f"Corrective prompt generation failed: {e}")
        return np.array([]).reshape(0, 2), np.array([], dtype=np.int32)


def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Process raw mask to ensure correct format and range"""
    try:
        if len(mask.shape) > 2:
            mask = mask.squeeze()

        if len(mask.shape) > 2:
            # Still multi-channel after squeeze; keep the first channel.
            mask = mask[:, :, 0]

        if mask.dtype == bool:
            mask = mask.astype(np.uint8) * 255
        elif mask.dtype == np.float32 or mask.dtype == np.float64:
            if mask.max() <= 1.0:
                mask = (mask * 255).astype(np.uint8)
            else:
                mask = np.clip(mask, 0, 255).astype(np.uint8)
        else:
            mask = mask.astype(np.uint8)

        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        return mask

    except Exception as e:
        logger.error(f"Mask processing failed: {e}")
        h, w = mask.shape[:2] if len(mask.shape) >= 2 else (256, 256)
        fallback = np.zeros((h, w), dtype=np.uint8)
        fallback[h//4:3*h//4, w//4:3*w//4] = 255
        return fallback


def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool:
    """Validate that the mask meets quality criteria"""
    try:
        h, w = image_shape
        mask_area = np.sum(mask > 127)
        total_area = h * w

        area_ratio = mask_area / total_area
        if area_ratio < 0.05 or area_ratio > 0.8:
            logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}")
            return False

        mask_binary = mask > 127
        mask_center_y, mask_center_x = np.where(mask_binary)

        if len(mask_center_y) == 0:
            logger.warning("Empty mask")
            return False

        center_y = np.mean(mask_center_y)
        center_x = np.mean(mask_center_x)

        if center_y < h * 0.2 or center_y > h * 0.9:
            logger.warning(f"Mask center too far from expected person location: y={center_y/h:.2f}")
            return False

        return True

    except Exception as e:
        logger.warning(f"Mask validation error: {e}")
        return True


def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
    """Fallback segmentation when AI models fail"""
    try:
        logger.info("Using fallback segmentation strategy")
        h, w = image.shape[:2]

        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

            # Estimate the background intensity from the frame border, then
            # keep everything that differs from it noticeably.
            edge_pixels = np.concatenate([
                gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]
            ])
            bg_color = np.median(edge_pixels)

            diff = np.abs(gray.astype(float) - bg_color)
            mask = (diff > 30).astype(np.uint8) * 255

            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

            if _validate_mask_quality(mask, image.shape[:2]):
                logger.info("Background subtraction fallback successful")
                return mask

        except Exception as e:
            logger.warning(f"Background subtraction fallback failed: {e}")

        # Last resort: a centered ellipse roughly where a person would stand.
        mask = np.zeros((h, w), dtype=np.uint8)

        center_x, center_y = w // 2, h // 2
        radius_x, radius_y = w / 3.0, h / 2.5

        y, x = np.ogrid[:h, :w]
        mask_ellipse = ((x - center_x) / radius_x) ** 2 + ((y - center_y) / radius_y) ** 2 <= 1
        mask[mask_ellipse] = 255

        logger.info("Using geometric fallback mask")
        return mask

    except Exception as e:
        logger.error(f"All fallback strategies failed: {e}")
        h, w = image.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)
        mask[h//6:5*h//6, w//4:3*w//4] = 255
        return mask


def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
    """Attempt MatAnyone mask refinement"""
    try:
        if hasattr(processor, 'infer'):
            refined_mask = processor.infer(image, mask)
        elif hasattr(processor, 'process'):
            refined_mask = processor.process(image, mask)
        elif callable(processor):
            refined_mask = processor(image, mask)
        else:
            logger.warning("Unknown MatAnyone interface")
            return None

        if refined_mask is None:
            return None

        refined_mask = _process_mask(refined_mask)
        logger.debug("MatAnyone refinement successful")
        return refined_mask

    except Exception as e:
        logger.warning(f"MatAnyone processing error: {e}")
        return None


def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
    """Approximation of the guided filter (He et al.) for edge-aware smoothing"""
    try:
        guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY) if len(guide.shape) == 3 else guide
        guide_gray = guide_gray.astype(np.float32) / 255.0
        mask_float = mask.astype(np.float32) / 255.0

        kernel_size = 2 * radius + 1

        # Local means and guide/mask correlation over the box window.
        mean_guide = cv2.boxFilter(guide_gray, -1, (kernel_size, kernel_size))
        mean_mask = cv2.boxFilter(mask_float, -1, (kernel_size, kernel_size))
        corr_guide_mask = cv2.boxFilter(guide_gray * mask_float, -1, (kernel_size, kernel_size))

        cov_guide_mask = corr_guide_mask - mean_guide * mean_mask
        mean_guide_sq = cv2.boxFilter(guide_gray * guide_gray, -1, (kernel_size, kernel_size))
        var_guide = mean_guide_sq - mean_guide * mean_guide

        # Per-pixel linear model: mask ~ a * guide + b.
        a = cov_guide_mask / (var_guide + eps)
        b = mean_mask - a * mean_guide

        mean_a = cv2.boxFilter(a, -1, (kernel_size, kernel_size))
        mean_b = cv2.boxFilter(b, -1, (kernel_size, kernel_size))

        output = mean_a * guide_gray + mean_b
        output = np.clip(output * 255, 0, 255).astype(np.uint8)

        return output

    except Exception as e:
        logger.warning(f"Guided filter approximation failed: {e}")
        return mask
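

# Alternative sketch: if opencv-contrib-python is installed, cv2.ximgproc
# provides a reference guided filter that can stand in for the approximation
# above. `_guided_filter_contrib` is a hypothetical helper; the hasattr check
# keeps the contrib dependency optional.
def _guided_filter_contrib(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
    """Edge-aware mask smoothing via cv2.ximgproc when available."""
    if hasattr(cv2, "ximgproc"):
        # ximgproc expects eps on the guide's intensity scale; for 0-255
        # images the normalized-domain eps is scaled by 255^2.
        return cv2.ximgproc.guidedFilter(guide, mask, radius, eps * 255 ** 2)
    return _guided_filter_approx(guide, mask, radius, eps)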


def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Advanced compositing with edge feathering and color correction"""
    try:
        threshold = 100
        _, mask_binary = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_CLOSE, kernel)
        mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_OPEN, kernel)

        # Feather the edge, then shape the alpha falloff with a mild gamma.
        mask_smooth = cv2.GaussianBlur(mask_binary.astype(np.float32), (5, 5), 1.0)
        mask_smooth = mask_smooth / 255.0

        mask_smooth = np.power(mask_smooth, 0.8)

        # Push alpha toward 0/1 to reduce semi-transparent halos.
        mask_smooth = np.where(mask_smooth > 0.5,
                               np.minimum(mask_smooth * 1.1, 1.0),
                               mask_smooth * 0.9)

        frame_adjusted = _color_match_edges(frame, background, mask_smooth)

        alpha_3ch = np.stack([mask_smooth] * 3, axis=2)

        frame_float = frame_adjusted.astype(np.float32)
        background_float = background.astype(np.float32)

        # Standard alpha blend: result = alpha * fg + (1 - alpha) * bg.
        result = frame_float * alpha_3ch + background_float * (1 - alpha_3ch)
        result = np.clip(result, 0, 255).astype(np.uint8)

        return result

    except Exception as e:
        logger.error(f"Advanced compositing error: {e}")
        raise


def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Subtle color matching at edges to reduce halos"""
    try:
        # Gradient magnitude of the alpha channel marks the transition band.
        # (A single Sobel call with dx=1, dy=1 would compute the mixed second
        # derivative and miss axis-aligned edges.)
        grad_x = cv2.Sobel(alpha, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(alpha, cv2.CV_64F, 0, 1, ksize=3)
        edge_mask = np.sqrt(grad_x ** 2 + grad_y ** 2)
        edge_mask = (edge_mask > 0.1).astype(np.float32)

        edge_areas = edge_mask > 0
        if not np.any(edge_areas):
            return frame

        frame_adjusted = frame.copy().astype(np.float32)
        background_float = background.astype(np.float32)

        # Blend 10% of the background color into the foreground along edges.
        adjustment_strength = 0.1
        for c in range(3):
            frame_adjusted[:, :, c] = np.where(
                edge_areas,
                frame_adjusted[:, :, c] * (1 - adjustment_strength) +
                background_float[:, :, c] * adjustment_strength,
                frame_adjusted[:, :, c]
            )

        return np.clip(frame_adjusted, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.warning(f"Color matching failed: {e}")
        return frame


def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Simple fallback compositing method"""
    try:
        logger.info("Using simple compositing fallback")

        background = cv2.resize(background, (frame.shape[1], frame.shape[0]))

        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)

        _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        mask_norm = mask_binary.astype(np.float32) / 255.0
        mask_3ch = np.stack([mask_norm] * 3, axis=2)

        result = frame * mask_3ch + background * (1 - mask_3ch)
        return result.astype(np.uint8)

    except Exception as e:
        logger.error(f"Simple compositing failed: {e}")
        return frame


def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Create solid color background"""
    # Parse "#rrggbb" into an (r, g, b) tuple, then reverse to BGR for OpenCV.
    color_hex = bg_config["colors"][0].lstrip('#')
    color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
    color_bgr = color_rgb[::-1]
    return np.full((height, width, 3), color_bgr, dtype=np.uint8)


def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Create enhanced gradient background with better quality"""
    try:
        colors = bg_config["colors"]
        direction = bg_config.get("direction", "vertical")

        rgb_colors = []
        for color_hex in colors:
            color_hex = color_hex.lstrip('#')
            rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
            rgb_colors.append(rgb)

        if not rgb_colors:
            rgb_colors = [(128, 128, 128)]

        if direction == "vertical":
            background = _create_vertical_gradient(rgb_colors, width, height)
        elif direction == "horizontal":
            background = _create_horizontal_gradient(rgb_colors, width, height)
        elif direction == "diagonal":
            background = _create_diagonal_gradient(rgb_colors, width, height)
        elif direction in ["radial", "soft_radial"]:
            background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
        else:
            background = _create_vertical_gradient(rgb_colors, width, height)

        return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)

    except Exception as e:
        logger.error(f"Gradient creation error: {e}")
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)


def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create vertical gradient, interpolating row by row"""
    gradient = np.zeros((height, width, 3), dtype=np.uint8)

    for y in range(height):
        progress = y / height if height > 0 else 0
        color = _interpolate_color(colors, progress)
        gradient[y, :] = color

    return gradient


def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create horizontal gradient, interpolating column by column"""
    gradient = np.zeros((height, width, 3), dtype=np.uint8)

    for x in range(width):
        progress = x / width if width > 0 else 0
        color = _interpolate_color(colors, progress)
        gradient[:, x] = color

    return gradient
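

# Sketch of a fully vectorized vertical gradient: it reuses the same
# _vectorized_color_interpolation helper as the diagonal/radial variants, so
# the per-row Python loop above can be avoided for large frames.
# `_create_vertical_gradient_fast` is a hypothetical alternative, not part of
# the pipeline.
def _create_vertical_gradient_fast(colors: list, width: int, height: int) -> np.ndarray:
    """Vectorized alternative to _create_vertical_gradient."""
    progress = np.repeat(np.linspace(0.0, 1.0, height, endpoint=False)[:, None], width, axis=1)
    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for c in range(3):
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
    return gradient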


def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create diagonal gradient using vectorized operations"""
    y_coords, x_coords = np.mgrid[0:height, 0:width]
    max_distance = width + height
    progress = (x_coords + y_coords) / max_distance
    progress = np.clip(progress, 0, 1)

    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for c in range(3):
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)

    return gradient


def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
    """Create radial gradient using vectorized operations"""
    center_x, center_y = width // 2, height // 2
    max_distance = np.sqrt(center_x**2 + center_y**2)

    y_coords, x_coords = np.mgrid[0:height, 0:width]
    distances = np.sqrt((x_coords - center_x)**2 + (y_coords - center_y)**2)
    progress = distances / max_distance
    progress = np.clip(progress, 0, 1)

    if soft:
        progress = np.power(progress, 0.7)

    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for c in range(3):
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)

    return gradient


def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray:
    """Vectorized color interpolation for performance"""
    if len(colors) == 1:
        return np.full_like(progress, colors[0][channel], dtype=np.uint8)

    num_segments = len(colors) - 1
    segment_progress = progress * num_segments
    segment_indices = np.floor(segment_progress).astype(int)
    segment_indices = np.clip(segment_indices, 0, num_segments - 1)
    local_progress = segment_progress - segment_indices

    start_colors = np.array([colors[i][channel] for i in range(len(colors))])
    end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))])

    start_vals = start_colors[segment_indices]
    end_vals = end_colors[segment_indices]

    result = start_vals + (end_vals - start_vals) * local_progress
    return np.clip(result, 0, 255).astype(np.uint8)


def _interpolate_color(colors: list, progress: float) -> tuple:
    """Interpolate between multiple colors"""
    if len(colors) == 1:
        return colors[0]
    elif len(colors) == 2:
        r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress)
        g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress)
        b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress)
        return (r, g, b)
    else:
        segment = progress * (len(colors) - 1)
        idx = int(segment)
        local_progress = segment - idx
        if idx >= len(colors) - 1:
            return colors[-1]
        c1, c2 = colors[idx], colors[idx + 1]
        r = int(c1[0] + (c2[0] - c1[0]) * local_progress)
        g = int(c1[1] + (c2[1] - c1[1]) * local_progress)
        b = int(c1[2] + (c2[2] - c1[2]) * local_progress)
        return (r, g, b)


def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray:
    """Apply brightness and contrast adjustments to background"""
    try:
        brightness = bg_config.get("brightness", 1.0)
        contrast = bg_config.get("contrast", 1.0)

        if brightness != 1.0 or contrast != 1.0:
            background = background.astype(np.float32)
            # Contrast pivots around mid-gray so it is distinct from the
            # brightness gain; a plain product would make the two equivalent.
            background = (background - 128.0) * contrast + 128.0
            background = background * brightness
            background = np.clip(background, 0, 255).astype(np.uint8)

        return background

    except Exception as e:
        logger.warning(f"Background adjustment failed: {e}")
        return background