"""
Fallback strategies for BackgroundFX Pro.

Implements robust fallback mechanisms when primary processing fails.
"""

import cv2
import numpy as np
import torch
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
import logging
import traceback

from ..utils.logger import setup_logger
from ..utils.device import DeviceManager
from ..utils.config import ConfigManager
from ..core.quality import QualityAnalyzer

logger = setup_logger(__name__)


class FallbackLevel(Enum):
    """Fallback hierarchy levels, ordered from no degradation to passthrough."""
    NONE = 0
    QUALITY_REDUCTION = 1
    METHOD_SWITCH = 2
    BASIC_PROCESSING = 3
    MINIMAL_PROCESSING = 4
    PASSTHROUGH = 5


@dataclass
class FallbackConfig:
    """Configuration for fallback strategies."""
    # Number of attempts made by execute_with_fallback (at least one is always made).
    max_retries: int = 3
    # Multiplier applied to image size / 'quality' on memory-driven reductions.
    quality_reduction_factor: float = 0.75
    # Floor for the 'quality' keyword argument when it is reduced.
    min_quality: float = 0.3
    # NOTE(review): enable_caching, cache_size and timeout_seconds are not read
    # anywhere in this module.
    enable_caching: bool = True
    cache_size: int = 10
    timeout_seconds: float = 30.0
    # Allow switching to CPU when a CUDA/GPU error is detected.
    gpu_fallback_to_cpu: bool = True
    # Allow downscaling the first 3-D array argument on memory errors.
    progressive_downscale: bool = True
    # Smallest size the downscaler will produce, as (width, height).
    min_resolution: Tuple[int, int] = (320, 240)


class FallbackStrategy:
    """Intelligent fallback strategy manager.

    Runs a callable and, when it raises, classifies the error, degrades the
    call's arguments (device, resolution, quality, enabled features) and
    retries; if every attempt fails a passthrough result is returned instead
    of raising.
    """

    def __init__(self, config: Optional[FallbackConfig] = None):
        self.config = config or FallbackConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        # NOTE(review): not used by this class; kept for external readers.
        self.cache = {}
        # One entry per failed attempt. NOTE(review): grows without bound for
        # the lifetime of the instance; consider trimming in long video jobs.
        self.fallback_history = []
        self.current_level = FallbackLevel.NONE

    @staticmethod
    def _func_name(func) -> str:
        """Return a printable name for *func* (partials/callable objects lack ``__name__``)."""
        return getattr(func, '__name__', repr(func))

    def _reduced_quality(self, quality: float, factor: float) -> float:
        """Scale *quality* by *factor* without dropping below ``min_quality``.

        A quality that is already below ``min_quality`` is left unchanged rather
        than being raised to the floor.
        """
        return max(quality * factor, min(quality, self.config.min_quality))

    def execute_with_fallback(self, func, *args, **kwargs) -> Dict[str, Any]:
        """
        Execute function with automatic fallback on failure.

        Each failed attempt is classified and degraded further, starting from
        the arguments used by the previous attempt, so reductions accumulate.

        Args:
            func: Function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Result dictionary with status and output. On success it contains
            'success', 'result', 'attempts' and the 'fallback_level' that was
            in effect for the successful attempt; on failure see
            ``_final_fallback``.
        """
        name = self._func_name(func)
        # Always make at least one attempt, even with a non-positive config.
        max_attempts = max(1, self.config.max_retries)
        original_args = args
        last_error: Optional[Exception] = None

        # A previous call may have left a degraded level behind; start clean.
        self.current_level = FallbackLevel.NONE

        for attempt in range(max_attempts):
            logger.info(f"Attempt {attempt + 1}/{max_attempts} for {name}")
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                logger.debug(traceback.format_exc())

                # Degrade the arguments of *this* attempt; pass a copy so the
                # handlers never mutate a dict we still hold.
                fallback_result = self._apply_fallback(
                    func, e, attempt, args, dict(kwargs)
                )
                if not fallback_result['handled']:
                    break
                args = fallback_result.get('new_args', args)
                kwargs = fallback_result.get('new_kwargs', kwargs)
                continue

            # Success: report the level that produced this result, then reset.
            level = self.current_level
            self.current_level = FallbackLevel.NONE
            return {
                'success': True,
                'result': result,
                'attempts': attempt + 1,
                'fallback_level': level,
            }

        # All attempts failed - apply final fallback
        logger.error(f"All attempts failed for {name}")
        return self._final_fallback(func, last_error, original_args)

    def _apply_fallback(self, func, error: Exception, attempt: int,
                        args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Pick and apply a fallback handler based on the error.

        Args:
            func: The function that failed (used for history/logging only).
            error: The exception raised by the attempt.
            attempt: Zero-based index of the failed attempt.
            args: Positional arguments used by the failed attempt.
            kwargs: A copy of the keyword arguments of the failed attempt; the
                handlers may modify it and return it as 'new_kwargs'.

        Returns:
            Handler result with 'handled' and optionally 'new_args'/'new_kwargs'.
        """
        message = str(error)
        lowered = message.lower()
        self.fallback_history.append({
            'function': self._func_name(func),
            'error': type(error).__name__,
            'attempt': attempt,
        })

        # GPU memory error - switch to CPU
        if 'CUDA' in message or 'GPU' in message:
            return self._handle_gpu_error(kwargs)
        # Memory error - reduce quality. MemoryError() usually has an empty
        # message, so check the type as well as the text.
        if isinstance(error, MemoryError) or 'memory' in lowered:
            return self._handle_memory_error(args, kwargs)
        # Timeout error - simplify processing
        if (isinstance(error, TimeoutError) or 'timeout' in lowered
                or 'timed out' in lowered):
            return self._handle_timeout_error(kwargs)
        # Model loading error - use simpler model
        if 'model' in lowered:
            return self._handle_model_error(kwargs)
        # Generic error - progressive degradation
        return self._handle_generic_error(attempt, kwargs)

    def _handle_gpu_error(self, kwargs: dict) -> Dict[str, Any]:
        """Handle GPU-related errors by moving to CPU and halving batch_size.

        NOTE(review): 'device' is injected even if the wrapped function does
        not accept it; confirm callers' signatures.
        """
        logger.info("GPU error detected, falling back to CPU")

        if not self.config.gpu_fallback_to_cpu:
            return {'handled': False}

        # Switch to CPU
        self.device_manager.device = torch.device('cpu')
        kwargs['device'] = 'cpu'

        # Reduce batch size if present
        if 'batch_size' in kwargs:
            kwargs['batch_size'] = max(1, kwargs['batch_size'] // 2)

        self.current_level = FallbackLevel.METHOD_SWITCH
        return {'handled': True, 'new_kwargs': kwargs}

    def _handle_memory_error(self, args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Handle memory errors by downscaling the image argument or lowering quality.

        The first 3-D ``np.ndarray`` positional argument is treated as the
        image. It is shrunk by ``quality_reduction_factor``, clamped to
        ``min_resolution`` and never enlarged. When it cannot shrink further,
        the 'quality' keyword (if present) is reduced instead.
        """
        logger.info("Memory error detected, reducing quality")

        image_idx = next(
            (i for i, arg in enumerate(args)
             if isinstance(arg, np.ndarray) and arg.ndim == 3),
            -1,
        )

        if image_idx >= 0 and self.config.progressive_downscale:
            image = args[image_idx]
            h, w = image.shape[:2]
            min_w, min_h = self.config.min_resolution
            factor = self.config.quality_reduction_factor
            # Respect min_resolution, but never upscale a dimension that is
            # already below the minimum.
            new_h = min(h, max(int(h * factor), min_h))
            new_w = min(w, max(int(w * factor), min_w))

            if new_h < h or new_w < w:
                # NOTE(review): only this array is resized; other array
                # arguments (e.g. a matching mask) keep their size - confirm
                # the wrapped functions tolerate this.
                new_args = list(args)
                new_args[image_idx] = cv2.resize(
                    image, (new_w, new_h), interpolation=cv2.INTER_AREA
                )
                self.current_level = FallbackLevel.QUALITY_REDUCTION
                return {
                    'handled': True,
                    'new_args': tuple(new_args),
                    'new_kwargs': kwargs,
                }

        # Reduce other memory-intensive parameters
        if 'quality' in kwargs:
            kwargs['quality'] = self._reduced_quality(
                kwargs['quality'], self.config.quality_reduction_factor
            )
            self.current_level = FallbackLevel.QUALITY_REDUCTION

        return {'handled': True, 'new_kwargs': kwargs}

    def _handle_timeout_error(self, kwargs: dict) -> Dict[str, Any]:
        """Handle timeout errors by disabling expensive options already present in kwargs."""
        logger.info("Timeout detected, simplifying processing")

        # Only keys the caller already passed are overridden.
        simplifications = {
            'use_refinement': False,
            'use_temporal': False,
            'use_guided_filter': False,
            'iterations': 1,
            'num_samples': 1,
        }
        for key, value in simplifications.items():
            if key in kwargs:
                kwargs[key] = value

        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {'handled': True, 'new_kwargs': kwargs}

    def _handle_model_error(self, kwargs: dict) -> Dict[str, Any]:
        """Handle model errors by stepping 'model_type' down, else disabling the model."""
        logger.info("Model error detected, using simpler model")

        model_hierarchy = ['large', 'base', 'small', 'tiny']
        current = kwargs.get('model_type')
        if current in model_hierarchy:
            idx = model_hierarchy.index(current)
            if idx < len(model_hierarchy) - 1:
                kwargs['model_type'] = model_hierarchy[idx + 1]
                self.current_level = FallbackLevel.METHOD_SWITCH
                return {'handled': True, 'new_kwargs': kwargs}

        # Already at the smallest (or an unknown) model: disable model-based processing
        kwargs['use_model'] = False
        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {'handled': True, 'new_kwargs': kwargs}

    def _handle_generic_error(self, attempt: int, kwargs: dict) -> Dict[str, Any]:
        """Handle unclassified errors with degradation that deepens per attempt.

        Attempt 0 lowers 'quality' (if present) by 20%, attempt 1 forces
        ``method='basic'``, later attempts request minimal processing.
        """
        logger.info(f"Generic error, applying degradation level {attempt + 1}")

        if attempt == 0:
            # First attempt - minor quality reduction (respecting min_quality)
            self.current_level = FallbackLevel.QUALITY_REDUCTION
            if 'quality' in kwargs:
                kwargs['quality'] = self._reduced_quality(kwargs['quality'], 0.8)
        elif attempt == 1:
            # Second attempt - switch methods
            self.current_level = FallbackLevel.METHOD_SWITCH
            kwargs['method'] = 'basic'
        else:
            # Final attempt - minimal processing
            self.current_level = FallbackLevel.MINIMAL_PROCESSING
            kwargs['skip_refinement'] = True
            kwargs['fast_mode'] = True

        return {'handled': True, 'new_kwargs': kwargs}

    def _final_fallback(self, func, error: Exception,
                        original_args: tuple) -> Dict[str, Any]:
        """Build the failure result once every attempt has failed.

        The first ``np.ndarray`` among the ORIGINAL positional arguments is
        passed through unchanged as 'result' (``None`` if there is none).
        """
        logger.error(f"Final fallback for {self._func_name(func)}: {str(error)}")
        self.current_level = FallbackLevel.PASSTHROUGH

        passthrough = next(
            (arg for arg in original_args if isinstance(arg, np.ndarray)), None
        )
        return {
            'success': False,
            'result': passthrough,
            'fallback_level': self.current_level,
            'error': str(error),
        }


class ProcessingFallback:
    """Specific fallback implementations for processing operations."""

    # Limits for the in-memory cache used by cached_result(): once it holds
    # more than max_cache_entries items, the cache_evict_count oldest
    # insertions are dropped.
    max_cache_entries: int = 100
    cache_evict_count: int = 20

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.ProcessingFallback")
        self.quality_analyzer = QualityAnalyzer()
        self._cache: Dict[str, Any] = {}

    @staticmethod
    def _to_bgr(image: np.ndarray) -> np.ndarray:
        """Return *image* with 3 channels (gray, single-channel and BGRA inputs are converted)."""
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    @staticmethod
    def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
        """Convert a mask to uint8 in [0, 255] without integer wrap-around.

        Float masks are assumed to be in [0, 1]; bool masks map to 0/255;
        integer masks whose maximum is <= 1 are scaled, others are clipped.
        """
        mask = np.asarray(mask)
        if mask.dtype == np.uint8:
            return mask
        if mask.dtype == np.bool_:
            return mask.astype(np.uint8) * 255
        if np.issubdtype(mask.dtype, np.floating):
            return (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        if mask.size == 0 or mask.max() <= 1:
            return (np.clip(mask, 0, 1) * 255).astype(np.uint8)
        return np.clip(mask, 0, 255).astype(np.uint8)

    def basic_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Basic segmentation using traditional CV methods (GrabCut).
        Used as fallback when ML models fail.

        Args:
            image: Input image (gray, BGR or BGRA; presumably uint8 - GrabCut
                requires 8-bit data)

        Returns:
            Binary uint8 mask (0 background, 255 foreground); a centred
            ellipse mask if GrabCut fails.
        """
        try:
            # GrabCut only accepts 8-bit 3-channel input.
            bgr = self._to_bgr(image)
            h, w = bgr.shape[:2]

            mask = np.zeros((h, w), np.uint8)
            bgd_model = np.zeros((1, 65), np.float64)
            fgd_model = np.zeros((1, 65), np.float64)

            # Initialize rectangle (center 80% of image)
            rect = (int(w * 0.1), int(h * 0.1), int(w * 0.8), int(h * 0.8))
            cv2.grabCut(bgr, mask, rect, bgd_model, fgd_model, 5,
                        cv2.GC_INIT_WITH_RECT)

            # Definite (GC_BGD=0) and probable (GC_PR_BGD=2) background -> 0.
            background = (mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD)
            return np.where(background, 0, 255).astype(np.uint8)
        except Exception as e:
            self.logger.error(f"Basic segmentation failed: {e}")
            # Return center blob as last resort
            return self._center_blob_mask(image.shape[:2])

    def _center_blob_mask(self, shape: Tuple[int, int]) -> np.ndarray:
        """Create a filled centre ellipse (axes = 1/3 of width/height) as ultimate fallback."""
        h, w = shape
        mask = np.zeros((h, w), dtype=np.uint8)

        center = (w // 2, h // 2)
        axes = (w // 3, h // 3)
        cv2.ellipse(mask, center, axes, 0, 0, 360, 255, -1)

        # Blur + re-threshold smooths the rasterised outline while staying binary.
        mask = cv2.GaussianBlur(mask, (21, 21), 10)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask

    def basic_matting(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """
        Basic matting using morphological operations.

        Args:
            image: Input image (currently unused; kept for interface parity)
            mask: Binary mask (uint8 0/255, bool, or float in [0, 1])

        Returns:
            Alpha matte as float32 in [0, 1]
        """
        mask_u8 = self._mask_to_uint8(mask)
        try:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            # Closing fills small holes; opening removes small specks.
            smoothed = cv2.morphologyEx(mask_u8, cv2.MORPH_CLOSE, kernel)
            smoothed = cv2.morphologyEx(smoothed, cv2.MORPH_OPEN, kernel)

            # Edge softening
            smoothed = cv2.GaussianBlur(smoothed, (5, 5), 2)
            return smoothed.astype(np.float32) / 255.0
        except Exception as e:
            self.logger.error(f"Basic matting failed: {e}")
            return mask_u8.astype(np.float32) / 255.0

    def color_difference_keying(self, image: np.ndarray,
                                key_color: Optional[np.ndarray] = None,
                                threshold: float = 30) -> np.ndarray:
        """
        Simple color difference keying for solid backgrounds.

        Args:
            image: Input image (H x W x C, or H x W for single channel)
            key_color: Background color to remove; estimated from the mean of
                the four corner patches (up to 10x10 px) when omitted
            threshold: Euclidean color distance above which a pixel is foreground

        Returns:
            Alpha matte as float32 (softened 0/1 mask); all ones on failure
        """
        try:
            # Compute in float so uint8 inputs cannot wrap around when
            # subtracting; atleast_3d lets single-channel images share the path.
            pixels = np.atleast_3d(image).astype(np.float32)

            if key_color is None:
                h, w = pixels.shape[:2]
                p = min(10, h, w)
                corners = [
                    pixels[:p, :p],
                    pixels[:p, w - p:],
                    pixels[h - p:, :p],
                    pixels[h - p:, w - p:],
                ]
                key = np.mean([c.mean(axis=(0, 1)) for c in corners], axis=0)
            else:
                key = np.asarray(key_color, dtype=np.float32)

            diff = np.linalg.norm(pixels - key, axis=2)
            mask = (diff > threshold).astype(np.float32)

            # Smooth edges
            return cv2.GaussianBlur(mask, (5, 5), 2)
        except Exception as e:
            self.logger.error(f"Color keying failed: {e}")
            return np.ones(image.shape[:2], dtype=np.float32)

    def edge_based_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Edge-based segmentation as fallback.

        Canny edges are closed morphologically and the largest external
        contour is filled.

        Args:
            image: Input image (gray, BGR or BGRA; Canny needs uint8)

        Returns:
            Binary uint8 mask (all zeros when no contour is found; a centred
            ellipse mask on failure)
        """
        try:
            if image.ndim == 2:
                gray = image
            else:
                gray = cv2.cvtColor(self._to_bgr(image), cv2.COLOR_BGR2GRAY)

            edges = cv2.Canny(gray, 50, 150)

            # Close small gaps so object outlines form closed contours.
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
            closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel, iterations=2)

            # [-2] picks the contour list on both OpenCV 3 (3 return values)
            # and OpenCV 4 (2 return values).
            contours = cv2.findContours(
                closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]

            mask = np.zeros(gray.shape, dtype=np.uint8)
            if contours:
                largest = max(contours, key=cv2.contourArea)
                cv2.drawContours(mask, [largest], -1, 255, -1)
            return mask
        except Exception as e:
            self.logger.error(f"Edge segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def cached_result(self, cache_key: str, fallback_func, *args, **kwargs) -> Any:
        """
        Try to retrieve cached result or compute with fallback.

        Args:
            cache_key: Cache identifier
            fallback_func: Function to call if not cached
            *args, **kwargs: Function arguments

        Returns:
            Cached or computed result, or None if the computation raised
            (failures are logged and not cached).
        """
        if cache_key in self._cache:
            self.logger.info(f"Using cached result for {cache_key}")
            return self._cache[cache_key]

        try:
            result = fallback_func(*args, **kwargs)
        except Exception as e:
            self.logger.error(f"Cached computation failed: {e}")
            return None

        self._cache[cache_key] = result
        # Evict the oldest insertions (dicts preserve insertion order).
        if len(self._cache) > self.max_cache_entries:
            for key in list(self._cache)[:self.cache_evict_count]:
                del self._cache[key]
        return result