|
|
|
|
|
""" |
|
|
Fallback strategies for BackgroundFX Pro. |
|
|
Implements fallback mechanisms that retry, degrade, or pass input through when primary processing fails.
|
|
""" |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from typing import Dict, List, Optional, Tuple, Any |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
import logging |
|
|
import traceback |
|
|
|
|
|
|
|
|
from utils.logger import setup_logger |
|
|
from utils.device import DeviceManager |
|
|
from utils.config import ConfigManager |
|
|
from core.quality import QualityAnalyzer |
|
|
|
|
|
# Module-level logger used by FallbackStrategy; ProcessingFallback creates its
# own child logger. Configuration comes from the project's setup_logger helper.
logger = setup_logger(__name__)
|
|
|
|
|
class FallbackLevel(Enum):
    """Degradation tiers, ordered from no fallback (0) to returning the input as-is (5)."""

    NONE = 0                # primary path succeeded; nothing was degraded
    QUALITY_REDUCTION = 1   # smaller input image or lower ``quality`` setting
    METHOD_SWITCH = 2       # different device, model size or processing method
    BASIC_PROCESSING = 3    # optional stages disabled or model bypassed
    MINIMAL_PROCESSING = 4  # fastest settings, refinement skipped
    PASSTHROUGH = 5         # every attempt failed; input handed back untouched
|
|
|
|
|
@dataclass
class FallbackConfig:
    """Tunable limits for FallbackStrategy's retry and degradation behaviour."""

    # Maximum number of calls execute_with_fallback makes before giving up.
    max_retries: int = 3
    # Multiplier applied to image size / ``quality`` on memory-error retries.
    quality_reduction_factor: float = 0.75
    # Floor for a reduced ``quality`` value.
    min_quality: float = 0.3
    # NOTE(review): enable_caching, cache_size and timeout_seconds are not read
    # by FallbackStrategy as shown in this file - confirm whether anything else
    # consumes them.
    enable_caching: bool = True
    cache_size: int = 10
    timeout_seconds: float = 30.0
    # Let GPU/CUDA failures retry on CPU; when False such failures stop retrying.
    gpu_fallback_to_cpu: bool = True
    # Let memory-error retries downscale the first 3-D ndarray argument.
    progressive_downscale: bool = True
    # Lower bound for downscaling, as (width, height).
    min_resolution: Tuple[int, int] = (320, 240)
|
|
|
|
|
class FallbackStrategy:
    """Run a callable with retries, degrading its arguments after each failure.

    After a failed attempt the exception is classified by type/message (GPU,
    memory, timeout, model, anything else) and the matching ``_handle_*``
    method rewrites the positional/keyword arguments for the next attempt.
    If every attempt fails, ``_final_fallback`` returns the first ndarray
    argument unchanged (passthrough) or ``None``.
    """

    # Quality multiplier used by the generic handler on its first degradation step.
    _GENERIC_QUALITY_FACTOR = 0.8

    def __init__(self, config: Optional[FallbackConfig] = None):
        self.config = config or FallbackConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        # NOTE(review): not read by any method of this class as shown.
        self.cache = {}
        # One entry per failed attempt: function name, exception type name, attempt index.
        self.fallback_history = []
        # Level set by the most recent handler. Reset to NONE at the start of each
        # execute_with_fallback call and again after a successful attempt.
        self.current_level = FallbackLevel.NONE

    @staticmethod
    def _func_name(func) -> str:
        """Return a printable name for ``func``, including partials and callable objects."""
        name = getattr(func, '__name__', None)
        if name is None:
            # functools.partial exposes the wrapped callable as ``.func``.
            name = getattr(getattr(func, 'func', None), '__name__', None)
        return name or type(func).__name__

    @staticmethod
    def _accepts_kwarg(func, name: str) -> bool:
        """Return True if ``func`` can be called with keyword argument ``name``.

        Returns True when the signature cannot be inspected (some builtins and
        C extensions), keeping the historical always-inject behaviour there.
        """
        import inspect

        try:
            params = inspect.signature(func).parameters.values()
        except (TypeError, ValueError):
            return True
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        return any(
            p.kind == inspect.Parameter.VAR_KEYWORD or (p.name == name and p.kind in keyword_kinds)
            for p in params
        )

    def _reduce_quality(self, quality: float, factor: float) -> float:
        """Scale ``quality`` by ``factor``, floored at ``config.min_quality`` but never raised."""
        return min(quality, max(self.config.min_quality, quality * factor))

    def execute_with_fallback(self, func, *args, **kwargs) -> Dict[str, Any]:
        """Call ``func(*args, **kwargs)``, retrying with degraded arguments on failure.

        Each retry uses the arguments produced by the previous handler, so
        degradations (e.g. image downscaling) accumulate across attempts.

        Returns:
            On success: ``{'success': True, 'result', 'attempts', 'fallback_level'}``,
            where ``fallback_level`` is the level whose arguments produced the
            result. Otherwise the dict from ``_final_fallback``.
        """
        func_name = self._func_name(func)
        original_args = args
        last_error = None
        # Start clean so a previous call's level cannot leak into this result.
        self.current_level = FallbackLevel.NONE

        for attempt in range(self.config.max_retries):
            logger.info(f"Attempt {attempt + 1}/{self.config.max_retries} for {func_name}")
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                # Hand the handler the arguments that just failed (not the originals)
                # and a private kwargs copy it may mutate freely.
                fallback_result = self._apply_fallback(func, e, attempt, args, dict(kwargs))
                if not fallback_result['handled']:
                    break
                args = fallback_result.get('new_args', args)
                kwargs = fallback_result.get('new_kwargs', kwargs)
                continue

            level_used = self.current_level
            self.current_level = FallbackLevel.NONE
            return {
                'success': True,
                'result': result,
                'attempts': attempt + 1,
                'fallback_level': level_used
            }

        logger.error(f"All attempts failed for {func_name}")
        return self._final_fallback(func, last_error, original_args)

    def _apply_fallback(self, func, error: Exception, attempt: int, args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Record the failure and dispatch to the handler matching ``error``.

        Args:
            args: positional arguments of the attempt that just failed.
            kwargs: a copy of that attempt's keyword arguments; handlers mutate it.

        Returns:
            The handler's dict (``handled`` plus optional ``new_args`` /
            ``new_kwargs``). Newly injected keywords that ``func`` cannot
            accept are removed from ``new_kwargs``.
        """
        self.fallback_history.append({
            'function': self._func_name(func),
            'error': type(error).__name__,
            'attempt': attempt
        })

        supplied = set(kwargs)
        message = str(error)
        lowered = message.lower()
        if 'CUDA' in message or 'GPU' in message:
            result = self._handle_gpu_error(kwargs)
        elif isinstance(error, MemoryError) or 'memory' in lowered:
            # A bare MemoryError() has an empty message, so check the type as well.
            result = self._handle_memory_error(args, kwargs)
        elif isinstance(error, TimeoutError) or 'timeout' in lowered:
            result = self._handle_timeout_error(kwargs)
        elif 'model' in lowered:
            result = self._handle_model_error(kwargs)
        else:
            result = self._handle_generic_error(attempt, kwargs)

        new_kwargs = result.get('new_kwargs')
        if new_kwargs is not None:
            # Handlers inject hints such as device/method/fast_mode. Passing one the
            # callable does not accept would turn the retry into a TypeError.
            result['new_kwargs'] = {
                key: value
                for key, value in new_kwargs.items()
                if key in supplied or self._accepts_kwarg(func, key)
            }
        return result

    def _handle_gpu_error(self, kwargs: dict) -> Dict[str, Any]:
        """Move processing to CPU and halve ``batch_size`` if present.

        Returns ``{'handled': False}`` when ``config.gpu_fallback_to_cpu`` is
        off, which makes execute_with_fallback stop retrying.
        """
        logger.info("GPU error detected, falling back to CPU")
        if not self.config.gpu_fallback_to_cpu:
            return {'handled': False}

        self.device_manager.device = torch.device('cpu')
        kwargs['device'] = 'cpu'
        if 'batch_size' in kwargs:
            kwargs['batch_size'] = max(1, kwargs['batch_size'] // 2)
        self.current_level = FallbackLevel.METHOD_SWITCH
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_memory_error(self, args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Shrink the first 3-D ndarray argument, otherwise lower ``kwargs['quality']``.

        Each call scales the current image by ``config.quality_reduction_factor``.
        Every dimension is clamped to ``[min_resolution, current size]``, so
        repeated failures shrink it progressively, never below the minimum,
        and never upscale a dimension. Always reports the error as handled.
        """
        logger.info("Memory error detected, reducing quality")

        image_idx = next(
            (i for i, arg in enumerate(args) if isinstance(arg, np.ndarray) and arg.ndim == 3),
            None,
        )
        if image_idx is not None and self.config.progressive_downscale:
            image = args[image_idx]
            h, w = image.shape[:2]
            min_w, min_h = self.config.min_resolution
            factor = self.config.quality_reduction_factor
            new_h = min(h, max(int(h * factor), min_h))
            new_w = min(w, max(int(w * factor), min_w))
            if new_h < h or new_w < w:
                # INTER_AREA is OpenCV's recommended interpolation for shrinking.
                resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
                new_args = list(args)
                new_args[image_idx] = resized
                self.current_level = FallbackLevel.QUALITY_REDUCTION
                return {
                    'handled': True,
                    'new_args': tuple(new_args),
                    'new_kwargs': kwargs
                }

        if 'quality' in kwargs:
            kwargs['quality'] = self._reduce_quality(kwargs['quality'], self.config.quality_reduction_factor)
            self.current_level = FallbackLevel.QUALITY_REDUCTION

        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_timeout_error(self, kwargs: dict) -> Dict[str, Any]:
        """Turn off optional stages and cut iteration counts, for keys the caller supplied."""
        logger.info("Timeout detected, simplifying processing")
        simplifications = {
            'use_refinement': False,
            'use_temporal': False,
            'use_guided_filter': False,
            'iterations': 1,
            'num_samples': 1
        }
        for key, value in simplifications.items():
            if key in kwargs:
                kwargs[key] = value
        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_model_error(self, kwargs: dict) -> Dict[str, Any]:
        """Step ``model_type`` one size down; if impossible, request ``use_model=False``."""
        logger.info("Model error detected, using simpler model")
        model_hierarchy = ['large', 'base', 'small', 'tiny']
        current = kwargs.get('model_type')
        if current in model_hierarchy[:-1]:
            kwargs['model_type'] = model_hierarchy[model_hierarchy.index(current) + 1]
            self.current_level = FallbackLevel.METHOD_SWITCH
            return {
                'handled': True,
                'new_kwargs': kwargs
            }

        # Missing, unknown or already the smallest model: ask to skip the model.
        kwargs['use_model'] = False
        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_generic_error(self, attempt: int, kwargs: dict) -> Dict[str, Any]:
        """Escalate degradation with the attempt index: quality, then method, then minimal mode."""
        logger.info(f"Generic error, applying degradation level {attempt + 1}")
        if attempt == 0:
            self.current_level = FallbackLevel.QUALITY_REDUCTION
            if 'quality' in kwargs:
                kwargs['quality'] = self._reduce_quality(kwargs['quality'], self._GENERIC_QUALITY_FACTOR)
        elif attempt == 1:
            self.current_level = FallbackLevel.METHOD_SWITCH
            kwargs['method'] = 'basic'
        else:
            self.current_level = FallbackLevel.MINIMAL_PROCESSING
            kwargs['skip_refinement'] = True
            kwargs['fast_mode'] = True
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _final_fallback(self, func, error: Exception, original_args: tuple) -> Dict[str, Any]:
        """Give up: return the first ndarray among the original args (or None) untouched."""
        logger.error(f"Final fallback for {self._func_name(func)}: {str(error)}")
        self.current_level = FallbackLevel.PASSTHROUGH
        passthrough = next((arg for arg in original_args if isinstance(arg, np.ndarray)), None)
        return {
            'success': False,
            'result': passthrough,
            'fallback_level': self.current_level,
            'error': str(error)
        }
|
|
|
|
|
class ProcessingFallback:
    """Model-free segmentation, matting and keying routines for when the main pipeline fails.

    Each public method catches its own exceptions, logs them, and returns a
    conservative default instead of raising.
    """

    # cached_result drops the oldest _CACHE_EVICT_COUNT entries once the cache
    # grows past _CACHE_MAX_ENTRIES.
    _CACHE_MAX_ENTRIES = 100
    _CACHE_EVICT_COUNT = 20

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.ProcessingFallback")
        self.quality_analyzer = QualityAnalyzer()

    @staticmethod
    def _as_bgr(image: np.ndarray) -> np.ndarray:
        """Return ``image`` with three channels; grabCut rejects 1- and 4-channel input."""
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    def basic_segmentation(self, image: np.ndarray) -> np.ndarray:
        """Segment the foreground with GrabCut, seeded by the central 80% of the frame.

        Grayscale and BGRA input is converted to 3 channels first.

        Returns:
            uint8 mask, 255 for (probable) foreground and 0 otherwise. On error,
            the ``_center_blob_mask`` fallback instead.
        """
        try:
            color = self._as_bgr(image)
            h, w = color.shape[:2]
            mask = np.zeros((h, w), np.uint8)
            bgd_model = np.zeros((1, 65), np.float64)
            fgd_model = np.zeros((1, 65), np.float64)
            rect = (int(w * 0.1), int(h * 0.1), int(w * 0.8), int(h * 0.8))
            cv2.grabCut(color, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
            # Definite (GC_BGD=0) and probable (GC_PR_BGD=2) background become 0.
            background = (mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD)
            return np.where(background, 0, 255).astype(np.uint8)
        except Exception as e:
            self.logger.error(f"Basic segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def _center_blob_mask(self, shape: Tuple[int, int]) -> np.ndarray:
        """Return a binary uint8 mask with a filled centred ellipse (semi-axes w/3, h/3)."""
        h, w = shape
        mask = np.zeros((h, w), dtype=np.uint8)
        center = (w // 2, h // 2)
        axes = (w // 3, h // 3)
        cv2.ellipse(mask, center, axes, 0, 0, 360, 255, -1)
        # Blur then re-threshold to round off the rasterised ellipse edge.
        mask = cv2.GaussianBlur(mask, (21, 21), 10)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask

    def basic_matting(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Turn a hard mask into a soft float32 alpha in [0, 1] via morphology and blur.

        ``image`` is accepted for interface compatibility but not used.
        Non-uint8 masks are assumed to be in [0, 1].
        """
        try:
            if mask.dtype != np.uint8:
                # Clip so out-of-range values saturate instead of wrapping in the cast.
                mask = np.clip(mask * 255.0, 0, 255).astype(np.uint8)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            # Close small holes, then remove small specks.
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.GaussianBlur(mask, (5, 5), 2)
            return mask.astype(np.float32) / 255.0
        except Exception as e:
            self.logger.error(f"Basic matting failed: {e}")
            return mask.astype(np.float32) / 255.0

    def color_difference_keying(self, image: np.ndarray, key_color: Optional[np.ndarray] = None, threshold: float = 30) -> np.ndarray:
        """Mark pixels whose Euclidean colour distance from ``key_color`` exceeds ``threshold``.

        When ``key_color`` is None it is estimated as the mean of the four
        10x10 corner patches. 2-D (single-channel) images are supported.

        Returns:
            float32 mask in [0, 1] (1 = differs from the key), lightly blurred.
            On error, all ones.
        """
        try:
            # Compute in float so a uint8 image/key pair cannot wrap around on subtraction.
            pixels = image.astype(np.float64)
            if pixels.ndim == 2:
                pixels = pixels[..., np.newaxis]
            if key_color is None:
                h, w = pixels.shape[:2]
                corners = [
                    pixels[0:10, 0:10],
                    pixels[0:10, w-10:w],
                    pixels[h-10:h, 0:10],
                    pixels[h-10:h, w-10:w]
                ]
                key_color = np.mean([np.mean(c, axis=(0, 1)) for c in corners], axis=0)
            key = np.asarray(key_color, dtype=np.float64)
            diff = np.sqrt(np.sum((pixels - key) ** 2, axis=2))
            mask = (diff > threshold).astype(np.float32)
            return cv2.GaussianBlur(mask, (5, 5), 2)
        except Exception as e:
            self.logger.error(f"Color keying failed: {e}")
            return np.ones(image.shape[:2], dtype=np.float32)

    def edge_based_segmentation(self, image: np.ndarray) -> np.ndarray:
        """Fill the largest closed Canny-edge contour as the foreground.

        Returns:
            uint8 mask (255 inside the largest external contour, 0 elsewhere,
            all zeros if none). On error, the ``_center_blob_mask`` fallback.
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
            edges = cv2.Canny(gray, 50, 150)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
            # Bridge small gaps so object outlines form closed contours.
            closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel, iterations=2)
            # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and
            # OpenCV 4 (2-tuple) return conventions.
            contours = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            mask = np.zeros(gray.shape, dtype=np.uint8)
            if len(contours) > 0:
                largest = max(contours, key=cv2.contourArea)
                cv2.drawContours(mask, [largest], -1, 255, -1)
            return mask
        except Exception as e:
            self.logger.error(f"Edge segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def cached_result(self, cache_key: str, fallback_func, *args, **kwargs) -> Any:
        """Memoise ``fallback_func(*args, **kwargs)`` under ``cache_key``.

        The key alone identifies the entry; ``args``/``kwargs`` are not part of it.
        Failures are logged, not cached, and return None. Once the cache
        exceeds ``_CACHE_MAX_ENTRIES``, the oldest-inserted entries are evicted.
        """
        if not hasattr(self, '_cache'):
            self._cache = {}

        if cache_key in self._cache:
            self.logger.info(f"Using cached result for {cache_key}")
            return self._cache[cache_key]

        try:
            result = fallback_func(*args, **kwargs)
        except Exception as e:
            self.logger.error(f"Cached computation failed: {e}")
            return None

        self._cache[cache_key] = result
        if len(self._cache) > self._CACHE_MAX_ENTRIES:
            # Dicts keep insertion order, so the first keys are the oldest entries.
            for key in list(self._cache)[:self._CACHE_EVICT_COUNT]:
                del self._cache[key]
        return result
|
|
|