"""
MatAnyone Model Loader

Handles MatAnyone loading with proper device initialization.
"""
|
|
|
|
|
import os |
|
|
import time |
|
|
import logging |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
from typing import Optional, Dict, Any |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class MatAnyoneLoader:
    """Dedicated loader for MatAnyone models.

    Tries the official ``matanyone`` package first and falls back to a
    lightweight smoothing processor so callers always receive a usable
    object.  Loaded processors are patched so ``step``/``process``
    tolerate numpy arrays or torch tensors in common layouts
    (HW, HWC, CHW, NCHW) and return numpy results.
    """

    def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
        """
        Args:
            device: Torch device string the model should run on.
            cache_dir: Directory for cached checkpoints; created if missing.
        """
        self.device = device
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        self.model = None  # loaded processor, or None until load() succeeds
        self.model_id = "PeiqingYang/MatAnyone"
        self.load_time = 0.0  # seconds taken by the successful load strategy

    def load(self) -> Optional[Any]:
        """
        Load MatAnyone model.

        Tries each loading strategy in order and keeps the first model
        produced.

        Returns:
            Loaded model or None if every strategy failed.
        """
        logger.info(f"Loading MatAnyone model: {self.model_id}")

        strategies = [
            ("official", self._load_official),
            ("fallback", self._load_fallback)
        ]

        for strategy_name, strategy_func in strategies:
            try:
                logger.info(f"Trying MatAnyone loading strategy: {strategy_name}")
                start_time = time.time()
                model = strategy_func()
                # Explicit None check: truthiness could reject a valid model
                # object that happens to define __bool__/__len__.
                if model is not None:
                    self.load_time = time.time() - start_time
                    self.model = model
                    logger.info(f"MatAnyone loaded successfully via {strategy_name} in {self.load_time:.2f}s")
                    return model
            except Exception as e:
                logger.error(f"MatAnyone {strategy_name} strategy failed: {e}")
                logger.debug(traceback.format_exc())
                continue

        logger.error("All MatAnyone loading strategies failed")
        return None

    def _load_official(self) -> Optional[Any]:
        """Load using official MatAnyone API (raises if the package or
        checkpoint is unavailable; load() catches and falls through)."""
        from matanyone import InferenceCore

        processor = InferenceCore(self.model_id)

        # Point the processor at the requested device when it exposes one.
        if hasattr(processor, 'device'):
            processor.device = self.device

        # Move the underlying network to the device and switch to eval mode.
        if hasattr(processor, 'model'):
            if hasattr(processor.model, 'to'):
                processor.model = processor.model.to(self.device)
                processor.model.eval()

        self._patch_processor(processor)

        return processor

    def _patch_processor(self, processor):
        """
        Patch the MatAnyone processor to handle device placement and tensor
        formats correctly.

        Replaces ``step``/``process`` with a wrapper that normalizes inputs
        (numpy -> tensor, layout, dtype, 0-255 -> 0-1 range), delegates to
        the original ``step``, and degrades to returning a (smoothed) mask
        when the original call fails.
        """
        # Capture the original bound method before it is shadowed below.
        original_step = getattr(processor, 'step', None)

        device = self.device

        def safe_wrapper(*args, **kwargs):
            """Universal wrapper that handles both step and process calls"""
            try:
                image = None
                mask = None
                idx_mask = kwargs.get('idx_mask', False)

                # Accept (image, mask) as keywords or positionally; a single
                # positional argument is treated as a mask-only call.
                if 'image' in kwargs and 'mask' in kwargs:
                    image = kwargs['image']
                    mask = kwargs['mask']
                elif len(args) >= 2:
                    image = args[0]
                    mask = args[1]
                    if len(args) > 2:
                        idx_mask = args[2]
                elif len(args) == 1:
                    mask = args[0]
                    # Synthesize a black image matching the mask's spatial size.
                    if isinstance(mask, np.ndarray):
                        h, w = mask.shape[:2] if mask.ndim >= 2 else (512, 512)
                        image = np.zeros((h, w, 3), dtype=np.uint8)
                    elif isinstance(mask, torch.Tensor):
                        h, w = mask.shape[-2:] if mask.dim() >= 2 else (512, 512)
                        image = torch.zeros((h, w, 3), dtype=torch.uint8)

                if image is None or mask is None:
                    logger.error(f"MatAnyone called with invalid args: {len(args)} args, kwargs: {kwargs.keys()}")
                    # Best effort: hand the caller back its mask, or a neutral one.
                    if mask is not None:
                        return mask
                    return np.ones((512, 512), dtype=np.float32) * 0.5

                # Move both inputs onto the configured device as tensors.
                if isinstance(image, np.ndarray):
                    image = torch.from_numpy(image).to(device)
                elif isinstance(image, torch.Tensor):
                    image = image.to(device)

                if isinstance(mask, np.ndarray):
                    mask = torch.from_numpy(mask).to(device)
                elif isinstance(mask, torch.Tensor):
                    mask = mask.to(device)

                # Normalize image layout towards NCHW.
                if image.dim() == 2:
                    image = image.unsqueeze(0)
                elif image.dim() == 3:
                    # Channel-last (HWC) -> channel-first (CHW).
                    if image.shape[-1] in [1, 3, 4]:
                        image = image.permute(2, 0, 1)
                    if image.shape[0] in [1, 3, 4]:
                        image = image.unsqueeze(0)
                elif image.dim() == 4:
                    # NHWC -> NCHW.
                    if image.shape[-1] in [1, 3, 4]:
                        image = image.permute(0, 3, 1, 2)

                # Normalize mask layout: ensure a leading channel dimension.
                if mask.dim() == 2:
                    mask = mask.unsqueeze(0)
                elif mask.dim() == 3:
                    if mask.shape[0] > 4:
                        # More than 4 leading "channels": assume channel-last.
                        mask = mask.permute(2, 0, 1)

                # Float conversion (index masks keep their integer dtype).
                if image.dtype != torch.float32:
                    image = image.float()
                if not idx_mask and mask.dtype != torch.float32:
                    mask = mask.float()

                # Scale 0-255 data into 0-1.  numel() guard: .max() raises
                # on an empty tensor.
                if image.numel() > 0 and image.max() > 1.0:
                    image = image / 255.0
                if not idx_mask and mask.numel() > 0 and mask.max() > 1.0:
                    mask = mask / 255.0

                if original_step:
                    try:
                        result = original_step(image, mask, idx_mask=idx_mask)
                        # Callers expect numpy back.
                        if isinstance(result, torch.Tensor):
                            result = result.cpu().numpy()
                        return result
                    except Exception as e:
                        logger.error(f"MatAnyone original step failed: {e}")

                # Degraded path: return the mask lightly smoothed.
                if isinstance(mask, torch.Tensor):
                    import torch.nn.functional as F
                    mask = F.avg_pool2d(mask.unsqueeze(0), 3, stride=1, padding=1)
                    mask = mask.squeeze(0).cpu().numpy()

                return mask

            except Exception as e:
                logger.error(f"MatAnyone safe_wrapper failed: {e}")
                logger.debug(traceback.format_exc())

                # Best effort: hand back whatever mask we got this far with.
                if 'mask' in locals() and mask is not None:
                    if isinstance(mask, torch.Tensor):
                        return mask.cpu().numpy()
                    return mask
                return np.ones((512, 512), dtype=np.float32) * 0.5

        if hasattr(processor, 'step'):
            processor.step = safe_wrapper
            logger.info("Patched MatAnyone step method")

        if hasattr(processor, 'process'):
            processor.process = safe_wrapper
            logger.info("Patched MatAnyone process method")

        # NOTE: an instance-level __call__ does NOT affect `processor(...)`
        # syntax (Python resolves dunders on the type).  Kept so explicit
        # `processor.__call__(...)` invocations still hit the wrapper.
        processor.__call__ = safe_wrapper

    def _load_fallback(self) -> Optional[Any]:
        """Create fallback processor for testing"""

        class FallbackMatAnyone:
            # Minimal stand-in exposing the same step/process surface.
            def __init__(self, device):
                self.device = device

            def step(self, image, mask, idx_mask=False, **kwargs):
                """Pass through mask with minor smoothing"""
                if isinstance(mask, np.ndarray):
                    # Local import: cv2 only needed on this degraded path.
                    import cv2
                    if mask.ndim == 2:
                        smoothed = cv2.GaussianBlur(mask, (5, 5), 1.0)
                        return smoothed
                    elif mask.ndim == 3:
                        # Smooth each channel independently.
                        smoothed = np.zeros_like(mask)
                        for i in range(mask.shape[0]):
                            smoothed[i] = cv2.GaussianBlur(mask[i], (5, 5), 1.0)
                        return smoothed
                return mask

            def process(self, image, mask, **kwargs):
                """Alias for step"""
                return self.step(image, mask, **kwargs)

        logger.warning("Using fallback MatAnyone (limited refinement)")
        return FallbackMatAnyone(self.device)

    def cleanup(self):
        """Clean up resources: drop the model reference and release cached
        CUDA memory."""
        if self.model is not None:
            self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_info(self) -> Dict[str, Any]:
        """Get loader information as a plain dict (for status/debug UIs)."""
        return {
            "loaded": self.model is not None,
            "model_id": self.model_id,
            "device": self.device,
            "load_time": self.load_time,
            # Same None-check as "loaded" so the two fields always agree.
            "model_type": type(self.model).__name__ if self.model is not None else None
        }