#!/usr/bin/env python3 """ utils.refinement High-quality mask refinement for BackgroundFX Pro. """ from __future__ import annotations from typing import Any, Optional, Tuple, List import logging import cv2 import numpy as np import torch log = logging.getLogger(__name__) # ============================================================================ # CUSTOM EXCEPTION # ============================================================================ class MaskRefinementError(Exception): """Custom exception for mask refinement errors""" pass # ============================================================================ # EXPORTS # ============================================================================ __all__ = [ "refine_mask_hq", "refine_masks_batch", "MaskRefinementError", ] # ============================================================================ # MAIN API - SINGLE FRAME # ============================================================================ def refine_mask_hq( image: np.ndarray, mask: np.ndarray, matanyone_model: Optional[Any] = None, fallback_enabled: bool = True ) -> np.ndarray: """ High-quality mask refinement with multiple strategies. Args: image: Original BGR image mask: Initial binary mask (0/255) matanyone_model: Optional MatAnyone model for AI refinement fallback_enabled: Whether to use fallback methods if AI fails Returns: Refined binary mask (0/255) """ if image is None or mask is None: raise MaskRefinementError("Invalid input image or mask") if image.shape[:2] != mask.shape[:2]: raise MaskRefinementError(f"Image shape {image.shape[:2]} doesn't match mask shape {mask.shape[:2]}") # Try AI-based refinement first if model available if matanyone_model is not None: try: refined = _refine_with_matanyone(image, mask, matanyone_model) if _validate_refined_mask(refined, mask): return refined log.warning("MatAnyone refinement failed validation") except Exception as e: log.warning(f"MatAnyone refinement failed: {e}") # Fallback to classical refinement methods if fallback_enabled: try: return _classical_refinement(image, mask) except Exception as e: log.warning(f"Classical refinement failed: {e}") return mask # Return original if all fails return mask # ============================================================================ # BATCH PROCESSING FOR TEMPORAL CONSISTENCY # ============================================================================ def refine_masks_batch( frames: List[np.ndarray], masks: List[np.ndarray], matanyone_model: Optional[Any] = None, fallback_enabled: bool = True ) -> List[np.ndarray]: """ Refine multiple masks using MatAnyone's temporal consistency. Args: frames: List of BGR images masks: List of initial binary masks matanyone_model: MatAnyone InferenceCore model fallback_enabled: Whether to use fallback methods Returns: List of refined binary masks """ if not frames or not masks: return masks if len(frames) != len(masks): raise MaskRefinementError(f"Frame count {len(frames)} doesn't match mask count {len(masks)}") if matanyone_model is not None: try: refined = _refine_batch_with_matanyone(frames, masks, matanyone_model) # Validate all masks if all(_validate_refined_mask(r, m) for r, m in zip(refined, masks)): return refined log.warning("Batch MatAnyone refinement failed validation") except Exception as e: log.warning(f"Batch MatAnyone refinement failed: {e}") # Fallback to frame-by-frame classical refinement if fallback_enabled: return [_classical_refinement(f, m) for f, m in zip(frames, masks)] return masks # ============================================================================ # AI-BASED REFINEMENT - SINGLE FRAME # ============================================================================ def _refine_with_matanyone( image: np.ndarray, mask: np.ndarray, model: Any ) -> np.ndarray: """Use MatAnyone model for mask refinement.""" try: # Set device to GPU (Tesla T4 on cuda:0) device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Convert BGR to RGB and normalize image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) h, w = image_rgb.shape[:2] # Convert to torch tensor format (C, H, W) and normalize to [0, 1] image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0 image_tensor = image_tensor.unsqueeze(0).to(device) # Add batch dimension and move to GPU # CRITICAL: Ensure mask is 2D before processing if mask.ndim == 3: # Convert multi-channel to single channel if mask.shape[2] == 3: mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) else: mask = mask[:, :, 0] # Ensure mask is binary uint8 if mask.dtype != np.uint8: mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8) # Final verification that mask is 2D assert mask.ndim == 2, f"Mask must be 2D after conversion, got shape {mask.shape}" assert mask.shape == (h, w), f"Mask shape {mask.shape} doesn't match image shape ({h}, {w})" # Convert mask to tensor and move to GPU mask_tensor = torch.from_numpy(mask).float() / 255.0 mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device) # (1, 1, H, W) on GPU # Verify tensor dimensions assert mask_tensor.shape == (1, 1, h, w), f"Mask tensor wrong shape: {mask_tensor.shape}, expected (1, 1, {h}, {w})" # Try different methods on InferenceCore result = None # Log available methods for debugging methods = [m for m in dir(model) if not m.startswith('_')] log.debug(f"MatAnyone InferenceCore methods: {methods}") with torch.no_grad(): if hasattr(model, 'step'): # Step method for iterative processing result = model.step(image_tensor, mask_tensor) elif hasattr(model, 'process_frame'): result = model.process_frame(image_tensor, mask_tensor) elif hasattr(model, 'forward'): result = model.forward(image_tensor, mask_tensor) elif hasattr(model, '__call__'): result = model(image_tensor, mask_tensor) else: raise MaskRefinementError(f"No recognized method. Available: {methods}") if result is None: raise MaskRefinementError("MatAnyone returned None") # Extract alpha matte from result alpha = _extract_alpha_from_result(result) # Convert back to numpy and resize if needed if isinstance(alpha, torch.Tensor): alpha = alpha.squeeze().cpu().numpy() if alpha.ndim == 3: alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0) if alpha.dtype != np.uint8: alpha = (alpha * 255).clip(0, 255).astype(np.uint8) if alpha.shape != (h, w): alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR) return _process_mask(alpha) except Exception as e: log.error(f"MatAnyone processing error: {str(e)}") raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}") # ============================================================================ # AI-BASED REFINEMENT - BATCH # ============================================================================ def _refine_batch_with_matanyone( frames: List[np.ndarray], masks: List[np.ndarray], model: Any ) -> List[np.ndarray]: """Process batch of frames through MatAnyone for temporal consistency.""" try: # Set device to GPU (Tesla T4 on cuda:0) device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') batch_size = len(frames) h, w = frames[0].shape[:2] # Convert frames to tensor batch and move to GPU frame_tensors = [] for frame in frames: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0 frame_tensors.append(tensor) # Stack into batch (N, C, H, W) and move to GPU batch_tensor = torch.stack(frame_tensors).to(device) # Prepare first mask for initialization first_mask = masks[0] # CRITICAL: Ensure first mask is 2D if first_mask.ndim == 3: if first_mask.shape[2] == 3: first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY) else: first_mask = first_mask[:, :, 0] if first_mask.dtype != np.uint8: first_mask = (first_mask * 255).astype(np.uint8) if first_mask.max() <= 1 else first_mask.astype(np.uint8) assert first_mask.ndim == 2, f"First mask must be 2D, got shape {first_mask.shape}" # Convert first mask to tensor and move to GPU first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0 first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0).to(device) refined_masks = [] with torch.no_grad(): # Check for batch processing methods if hasattr(model, 'process_batch'): # Direct batch processing results = model.process_batch(batch_tensor, first_mask_tensor) for result in results: alpha = _extract_alpha_from_result(result) refined_masks.append(_tensor_to_mask(alpha, h, w)) elif hasattr(model, 'step'): # Process frames sequentially with memory for i, frame_tensor in enumerate(frame_tensors): frame_on_device = frame_tensor.unsqueeze(0).to(device) if i == 0: # First frame with mask result = model.step(frame_on_device, first_mask_tensor) else: # Subsequent frames use memory from previous result = model.step(frame_on_device, None) alpha = _extract_alpha_from_result(result) refined_masks.append(_tensor_to_mask(alpha, h, w)) else: # Fallback to processing each frame with its mask log.warning("MatAnyone batch processing not available, using frame-by-frame") for frame_tensor, mask in zip(frame_tensors, masks): # Ensure each mask is 2D if mask.ndim == 3: if mask.shape[2] == 3: mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) else: mask = mask[:, :, 0] mask_tensor = torch.from_numpy(mask).float() / 255.0 mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device) frame_on_device = frame_tensor.unsqueeze(0).to(device) result = model(frame_on_device, mask_tensor) alpha = _extract_alpha_from_result(result) refined_masks.append(_tensor_to_mask(alpha, h, w)) return refined_masks except Exception as e: log.error(f"Batch MatAnyone processing error: {str(e)}") raise MaskRefinementError(f"Batch processing failed: {str(e)}") # ============================================================================ # HELPER FUNCTIONS # ============================================================================ def _extract_alpha_from_result(result): """Extract alpha matte from various result formats.""" if isinstance(result, (tuple, list)): return result[0] if len(result) > 0 else None elif isinstance(result, dict): return result.get('alpha', result.get('matte', result.get('mask', None))) else: return result def _tensor_to_mask(tensor, target_h, target_w): """Convert tensor to numpy mask with proper sizing.""" if isinstance(tensor, torch.Tensor): mask = tensor.squeeze().cpu().numpy() else: mask = tensor if mask.ndim == 3: mask = mask[0] if mask.shape[0] == 1 else mask.mean(axis=0) if mask.dtype != np.uint8: mask = (mask * 255).clip(0, 255).astype(np.uint8) if mask.shape != (target_h, target_w): mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_LINEAR) return mask def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool: """Check if refined mask is reasonable.""" if refined is None or refined.size == 0: return False refined_area = np.sum(refined > 127) original_area = np.sum(original > 127) if refined_area == 0: return False ratio = refined_area / max(original_area, 1) return 0.5 <= ratio <= 2.0 def _process_mask(mask: np.ndarray) -> np.ndarray: """Convert any mask format to binary 0/255.""" if mask.dtype == np.float32 or mask.dtype == np.float64: if mask.max() <= 1.0: mask = (mask * 255).astype(np.uint8) if mask.dtype != np.uint8: mask = mask.astype(np.uint8) if mask.ndim == 3: mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY) return binary # ============================================================================ # CLASSICAL REFINEMENT # ============================================================================ def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray: """Apply classical CV techniques for mask refinement.""" refined = mask.copy() kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) refined = cv2.morphologyEx(refined, cv2.MORPH_CLOSE, kernel) refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel) refined = _edge_aware_smooth(image, refined) refined = _feather_edges(refined, radius=3) refined = _remove_small_components(refined, min_area_ratio=0.005) return refined def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray: """Apply edge-aware smoothing using guided filter.""" mask_float = mask.astype(np.float32) / 255.0 radius = 5 eps = 0.01 gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0 mean_I = cv2.boxFilter(gray, -1, (radius, radius)) mean_p = cv2.boxFilter(mask_float, -1, (radius, radius)) mean_Ip = cv2.boxFilter(gray * mask_float, -1, (radius, radius)) cov_Ip = mean_Ip - mean_I * mean_p mean_II = cv2.boxFilter(gray * gray, -1, (radius, radius)) var_I = mean_II - mean_I * mean_I a = cov_Ip / (var_I + eps) b = mean_p - a * mean_I mean_a = cv2.boxFilter(a, -1, (radius, radius)) mean_b = cv2.boxFilter(b, -1, (radius, radius)) refined = mean_a * gray + mean_b return (refined * 255).clip(0, 255).astype(np.uint8) def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray: """Slightly blur edges for smoother transitions.""" if radius <= 0: return mask blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2) _, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY) return binary def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray: """Remove small disconnected components.""" num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) if num_labels <= 1: return mask total_area = mask.shape[0] * mask.shape[1] min_area = int(total_area * min_area_ratio) areas = stats[1:, cv2.CC_STAT_AREA] if len(areas) == 0: return mask max_label = np.argmax(areas) + 1 cleaned = np.zeros_like(mask) for label in range(1, num_labels): if stats[label, cv2.CC_STAT_AREA] >= min_area or label == max_label: cleaned[labels == label] = 255 return cleaned