"""
Model management and optimization for BackgroundFX Pro.

Fixes MatAnyone quality issues and manages model loading.
"""

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, Any, Optional

import gc
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


@dataclass
class ModelConfig:
    """Configuration for model management."""

    sam2_checkpoint: str = "checkpoints/sam2_hiera_large.pt"
    sam2_config: str = "configs/sam2_hiera_l.yaml"
    matanyone_checkpoint: str = "checkpoints/matanyone_v2.pth"
    device: str = "cuda"
    dtype: torch.dtype = torch.float16
    optimize_memory: bool = True
    use_amp: bool = True
    cache_size: int = 5
    enable_quality_fixes: bool = True
    matanyone_enhancement: bool = True
    use_tensorrt: bool = False
    batch_size: int = 1


class ModelCache:
    """Intelligent model caching system."""

    def __init__(self, max_size: int = 5):
        self.cache: Dict[str, Any] = {}
        self.max_size = max_size
        self.access_count: Dict[str, int] = {}
        self.memory_usage: Dict[str, float] = {}

    def add(self, key: str, model: Any, memory_size: float):
        """Add model to cache with memory tracking."""
        if len(self.cache) >= self.max_size and self.access_count:
            lru_key = min(self.access_count, key=self.access_count.get)
            self.remove(lru_key)

        self.cache[key] = model
        self.access_count[key] = 0
        self.memory_usage[key] = memory_size

    def get(self, key: str) -> Optional[Any]:
        """Get model from cache."""
        if key in self.cache:
            self.access_count[key] += 1
            return self.cache[key]
        return None

    def remove(self, key: str):
        """Remove model from cache and free memory."""
        if key in self.cache:
            del self.cache[key]
            self.access_count.pop(key, None)
            self.memory_usage.pop(key, None)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def clear(self):
        """Clear entire cache."""
        for key in list(self.cache.keys()):
            self.remove(key)
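

# Example (hedged sketch, not executed): ModelCache evicts the least-recently-
# used entry once `max_size` is reached; `memory_size` is tracked for
# reporting only.
#
#     cache = ModelCache(max_size=2)
#     cache.add("sam2", object(), memory_size=2.5)
#     cache.add("matanyone", object(), memory_size=1.2)
#     cache.get("sam2")                    # bump "sam2" access count
#     cache.add("depth", object(), 0.8)    # evicts "matanyone" (never accessed)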


class MatAnyoneModel(nn.Module):
    """Enhanced MatAnyone model with quality fixes."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.base_model: Optional[nn.Module] = None
        self.quality_enhancer = QualityEnhancer() if config.enable_quality_fixes else None
        self.loaded = False

    def load(self):
        """Load MatAnyone model with optimizations."""
        if self.loaded:
            return

        try:
            checkpoint_path = Path(self.config.matanyone_checkpoint)
            if not checkpoint_path.exists():
                logger.warning(f"MatAnyone checkpoint not found at {checkpoint_path}")
                return

            state_dict = torch.load(checkpoint_path, map_location=self.config.device)

            self.base_model = self._build_matanyone_architecture()
            self._load_weights_safe(state_dict)

            if self.config.optimize_memory:
                self._optimize_model()

            # Keep the enhancer on the same device as the base model so its
            # post-processing convolutions don't hit a CPU/GPU mismatch.
            if self.quality_enhancer is not None:
                self.quality_enhancer.to(self.config.device)
                self.quality_enhancer.eval()

            self.loaded = True
            logger.info("MatAnyone model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load MatAnyone model: {e}")
            self.loaded = False

    def _build_matanyone_architecture(self) -> nn.Module:
        """Build MatAnyone architecture (placeholder)."""

        class MatAnyoneBase(nn.Module):
            def __init__(self):
                super().__init__()
                self.encoder = nn.Sequential(
                    nn.Conv2d(4, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 128, 3, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(128, 256, 3, stride=2, padding=1),
                    nn.ReLU(),
                )
                self.decoder = nn.Sequential(
                    nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
                    nn.ReLU(),
                    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 4, 3, padding=1),
                    nn.Sigmoid(),
                )

            def forward(self, x):
                features = self.encoder(x)
                return self.decoder(features)

        model = MatAnyoneBase().to(self.config.device)
        if (
            self.config.dtype == torch.float16
            and "cuda" in str(self.config.device).lower()
            and torch.cuda.is_available()
        ):
            model = model.half()
        return model

    def _load_weights_safe(self, state_dict: Dict):
        """Safely load weights with compatibility handling."""
        if self.base_model is None:
            return

        model_dict = self.base_model.state_dict()

        compatible_dict = {}
        for k, v in state_dict.items():
            k_clean = k[7:] if k.startswith("module.") else k  # strip DataParallel prefix
            if k_clean in model_dict and model_dict[k_clean].shape == v.shape:
                compatible_dict[k_clean] = v
            else:
                logger.warning(f"Skipping incompatible weight: {k}")

        model_dict.update(compatible_dict)
        self.base_model.load_state_dict(model_dict, strict=False)
        logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights")

    def _optimize_model(self):
        """Optimize model for inference."""
        if self.base_model is None:
            return

        self.base_model.eval()

        for p in self.base_model.parameters():
            p.requires_grad = False

        if self.config.use_tensorrt:
            try:
                self._optimize_with_tensorrt()
            except Exception as e:
                logger.warning(f"TensorRT optimization failed: {e}")

    def _optimize_with_tensorrt(self):
        """Placeholder for optional TensorRT optimization."""
        raise NotImplementedError("TensorRT path not implemented")

    def forward(self, image: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Enhanced forward pass with quality fixes."""
        if not self.loaded:
            self.load()

        if self.base_model is None:
            return {
                "alpha": mask.unsqueeze(1),
                "foreground": image,
                "confidence": torch.tensor([0.0], device=image.device),
            }

        # Concatenate RGB and mask into a 4-channel input.
        x = torch.cat([image, mask.unsqueeze(1)], dim=1)

        if self.config.matanyone_enhancement:
            x = self._preprocess_input(x)

        amp_enabled = (
            self.config.use_amp
            and torch.cuda.is_available()
            and "cuda" in str(self.config.device).lower()
        )
        with torch.cuda.amp.autocast(enabled=amp_enabled):
            # Match the input dtype to the (possibly half-precision) weights.
            x = x.to(next(self.base_model.parameters()).dtype)
            output = self.base_model(x)

            alpha = output[:, 3:4, :, :]
            foreground = output[:, :3, :, :]

            # Run the enhancer inside the autocast region so its float32
            # weights are cast consistently with half-precision activations.
            if self.quality_enhancer is not None:
                alpha = self.quality_enhancer.enhance_alpha(alpha, mask)
                foreground = self.quality_enhancer.enhance_foreground(foreground, image)

        # Back to full precision for the artifact-fixing heuristics, which mix
        # the outputs with the float32 input mask.
        alpha = alpha.float()
        foreground = foreground.float()

        alpha = self._fix_matanyone_artifacts(alpha, mask)

        return {
            "alpha": alpha,
            "foreground": foreground,
            "confidence": self._compute_confidence(alpha, mask),
        }

    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """Preprocess input to improve MatAnyone quality."""
        if x.shape[2] > 64:
            x = self._bilateral_filter_torch(x)
            x = torch.clamp(x, 0, 1)

        mask_channel = x[:, 3:4, :, :]
        mask_enhanced = self._enhance_mask_edges(mask_channel)
        x = torch.cat([x[:, :3, :, :], mask_enhanced], dim=1)
        return x

    def _fix_matanyone_artifacts(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
        """Fix common MatAnyone artifacts."""
        alpha = self._fix_edge_bleeding(alpha, original_mask)
        alpha = self._fix_transparency_issues(alpha)
        alpha = self._ensure_mask_consistency(alpha, original_mask)
        return alpha

    def _fix_edge_bleeding(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
        """Fix edge bleeding artifacts."""
        # conv2d needs a 4-D (B, 1, H, W) input; the mask arrives as (B, H, W).
        mask4 = original_mask.unsqueeze(1) if original_mask.dim() == 3 else original_mask
        edges = self._detect_edges_torch(mask4)
        edge_mask = F.max_pool2d(edges, kernel_size=5, stride=1, padding=2)

        alpha_refined = alpha.clone()
        edge_region = edge_mask > 0.1
        if edge_region.any():
            alpha_refined[edge_region] = (
                0.7 * alpha[edge_region] + 0.3 * mask4.expand_as(alpha)[edge_region]
            )
        return alpha_refined

    def _fix_transparency_issues(self, alpha: torch.Tensor) -> torch.Tensor:
        """Fix transparency artifacts by pushing mid-range values toward 0 or 1."""
        mid_range = (alpha > 0.2) & (alpha < 0.8)
        alpha_fixed = alpha.clone()
        alpha_fixed[mid_range] = torch.where(
            alpha[mid_range] > 0.5,
            torch.clamp(alpha[mid_range] * 1.2, max=1.0),
            torch.clamp(alpha[mid_range] * 0.8, min=0.0),
        )
        # torch.nn.functional has no gaussian_blur; use the depthwise helper below.
        alpha_fixed = self._gaussian_blur(alpha_fixed, kernel_size=3)
        return alpha_fixed

    def _ensure_mask_consistency(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
        """Ensure consistency with the original mask."""
        if original_mask.dim() == 2:
            original_mask = original_mask.unsqueeze(0).unsqueeze(0)
        elif original_mask.dim() == 3:
            original_mask = original_mask.unsqueeze(1)

        # Force agreement where the original mask is confidently background/foreground.
        alpha = torch.where(original_mask < 0.1, torch.zeros_like(alpha), alpha)
        alpha = torch.where(original_mask > 0.9, torch.ones_like(alpha) * 0.95, alpha)
        return alpha

    def _compute_confidence(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
        """Compute a per-image confidence score (1 = perfect agreement with the mask)."""
        if original_mask.dim() < alpha.dim():
            original_mask = original_mask.unsqueeze(1).expand_as(alpha)
        diff = torch.abs(alpha - original_mask)
        return 1.0 - torch.mean(diff, dim=(1, 2, 3))

    def _bilateral_filter_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Approximate a bilateral filter with a plain Gaussian blur."""
        return self._gaussian_blur(x, kernel_size=5)
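
    # `torch.nn.functional` has no `gaussian_blur`, so the original calls would
    # raise AttributeError. This is a minimal depthwise stand-in (kernel size
    # and sigma defaults are assumptions, not values from the original
    # pipeline); torchvision's gaussian_blur could be used instead.
    @staticmethod
    def _gaussian_blur(x: torch.Tensor, kernel_size: int = 3, sigma: float = 1.0) -> torch.Tensor:
        """Depthwise Gaussian blur for (B, C, H, W) tensors."""
        coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - kernel_size // 2
        kernel_1d = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        kernel_1d = kernel_1d / kernel_1d.sum()
        # Separable kernel: the outer product gives the 2-D weights; grouped
        # conv2d then blurs each channel independently.
        kernel_2d = torch.outer(kernel_1d, kernel_1d)
        channels = x.shape[1]
        weight = kernel_2d.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, padding=kernel_size // 2, groups=channels)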

    def _enhance_mask_edges(self, mask: torch.Tensor) -> torch.Tensor:
        """Enhance edges in the mask channel."""
        edges = self._detect_edges_torch(mask)
        return torch.clamp(mask + 0.3 * edges, 0, 1)

    def _detect_edges_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Detect edges with Sobel filters (expects a single-channel (B, 1, H, W) input)."""
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
        edges_x = F.conv2d(x, sobel_x, padding=1)
        edges_y = F.conv2d(x, sobel_y, padding=1)
        return torch.sqrt(edges_x ** 2 + edges_y ** 2)


class SAM2Model:
    """SAM2 model wrapper with optimizations."""

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model = None
        self.predictor = None
        self.loaded = False

    def load(self):
        """Load SAM2 model."""
        if self.loaded:
            return

        try:
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            self.model = build_sam2(
                config_file=self.config.sam2_config,
                ckpt_path=self.config.sam2_checkpoint,
                device=self.config.device,
            )
            self.predictor = SAM2ImagePredictor(self.model)

            self.loaded = True
            logger.info("SAM2 model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load SAM2 model: {e}")
            self.loaded = False

    def predict(self, image: np.ndarray, prompts: Optional[Dict] = None) -> np.ndarray:
        """Generate segmentation mask."""
        if not self.loaded:
            self.load()

        if self.predictor is None:
            return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        self.predictor.set_image(image)

        if prompts:
            masks, scores, _ = self.predictor.predict(
                point_coords=prompts.get("points"),
                point_labels=prompts.get("labels"),
                box=prompts.get("box"),
                multimask_output=True,
            )
            mask = masks[int(np.argmax(scores))]
        else:
            # Automatic mask generation is not part of SAM2ImagePredictor's
            # stable API; fall back to an empty mask if it is unavailable.
            try:
                masks = self.predictor.generate_auto_masks(image)
                mask = masks[0] if len(masks) > 0 else np.zeros_like(image[:, :, 0])
            except Exception:
                mask = np.zeros_like(image[:, :, 0])

        return mask
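

# Example (hedged): prompts follow the SAM2ImagePredictor convention of (N, 2)
# pixel coordinates with matching labels (1 = foreground, 0 = background).
#
#     sam2 = SAM2Model(ModelConfig())
#     mask = sam2.predict(frame, prompts={
#         "points": np.array([[320, 240]]),
#         "labels": np.array([1]),
#     })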


class QualityEnhancer(nn.Module):
    """Neural quality enhancement module."""

    def __init__(self):
        super().__init__()
        self.alpha_refiner = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
            nn.Sigmoid(),
        )

        self.foreground_enhancer = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh(),
        )

    def enhance_alpha(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
        """Enhance alpha quality by blending the refined prediction with the input."""
        refined = self.alpha_refiner(alpha)
        return torch.clamp(0.7 * refined + 0.3 * alpha, 0, 1)

    def enhance_foreground(self, foreground: torch.Tensor, original_image: torch.Tensor) -> torch.Tensor:
        """Enhance foreground quality with a small residual correction."""
        residual = self.foreground_enhancer(foreground)
        enhanced = torch.clamp(foreground + 0.1 * residual, -1, 1)

        # Preserve the input's value range if it was already in [0, 1].
        if foreground.min() >= 0.0 and foreground.max() <= 1.0:
            enhanced = torch.clamp(enhanced, 0.0, 1.0)
        return enhanced


class ModelManager:
    """Central model management system."""

    def __init__(self, config: Optional[ModelConfig] = None):
        self.config = config or ModelConfig()
        self.cache = ModelCache(max_size=self.config.cache_size)

        self.sam2 = SAM2Model(self.config)
        self.matanyone = MatAnyoneModel(self.config)

    def load_all(self):
        """Load all models."""
        logger.info("Loading all models...")
        self.sam2.load()
        self.matanyone.load()
        logger.info("All models loaded")

    def get_sam2(self) -> "SAM2Model":
        """Get SAM2 model (lazy-loaded)."""
        if not self.sam2.loaded:
            self.sam2.load()
        return self.sam2

    def get_matanyone(self) -> "MatAnyoneModel":
        """Get MatAnyone model (lazy-loaded)."""
        if not self.matanyone.loaded:
            self.matanyone.load()
        return self.matanyone

    def process_frame(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Process a single frame through the pipeline."""
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        image_tensor = image_tensor.to(self.config.device)

        if mask is None:
            mask = self.sam2.predict(image)

        mask_tensor = torch.from_numpy(mask).float()
        if mask_tensor.max() > 1.0:  # normalize 0-255 masks to [0, 1]
            mask_tensor = mask_tensor / 255.0
        mask_tensor = mask_tensor.unsqueeze(0).to(self.config.device)  # add batch dim

        with torch.no_grad():
            result = self.matanyone(image_tensor, mask_tensor)

        output = {
            "alpha": result["alpha"].squeeze().cpu().numpy(),
            "foreground": (
                result["foreground"].squeeze().permute(1, 2, 0).cpu().numpy() * 255.0
            ).clip(0, 255).astype(np.uint8),
            "confidence": result["confidence"].detach().cpu().numpy(),
        }
        return output

    def cleanup(self):
        """Clean up models and free memory."""
        self.cache.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


class ModelType(Enum):
    SAM2 = "sam2"
    MATANYONE = "matanyone"


class ModelFactory:
    """
    Lightweight factory that returns cached model instances by type.

    Kept for backward compatibility with modules importing from core.models.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        self.config = config or ModelConfig()
        self._instances: Dict[ModelType, Any] = {}

    def get(self, model_type: "ModelType | str"):
        """Return (and cache) a model instance for the given type."""
        if isinstance(model_type, str):
            try:
                model_type = ModelType(model_type.lower())
            except ValueError:
                raise ValueError(f"Unknown model type: {model_type}")

        if model_type == ModelType.SAM2:
            if model_type not in self._instances:
                self._instances[model_type] = SAM2Model(self.config)
            return self._instances[model_type]

        if model_type == ModelType.MATANYONE:
            if model_type not in self._instances:
                self._instances[model_type] = MatAnyoneModel(self.config)
            return self._instances[model_type]

        raise ValueError(f"Unsupported model type: {model_type}")

    # Backward-compatible alias.
    create = get


__all__ = [
    "ModelManager",
    "SAM2Model",
    "MatAnyoneModel",
    "ModelConfig",
    "ModelCache",
    "QualityEnhancer",
    "ModelType",
    "ModelFactory",
]
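

if __name__ == "__main__":
    # Minimal smoke test (a hedged sketch, not part of the production API):
    # runs the pipeline on synthetic data. Without real checkpoints the models
    # log warnings and fall back to passthrough outputs.
    logging.basicConfig(level=logging.INFO)

    config = ModelConfig(device="cuda" if torch.cuda.is_available() else "cpu")
    manager = ModelManager(config)

    frame = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
    dummy_mask = np.zeros((256, 256), dtype=np.uint8)
    dummy_mask[64:192, 64:192] = 1  # square foreground region

    result = manager.process_frame(frame, mask=dummy_mask)
    print("alpha:", result["alpha"].shape, "confidence:", result["confidence"])
    manager.cleanup()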