|
|
""" |
|
|
Model Loading Module |
|
|
Handles loading and validation of SAM2 and MatAnyone AI models |
|
|
""" |
|
|
|
|
|
import os
import gc
import time
import logging
import tempfile
import traceback
from typing import Optional, Dict, Any, Tuple, Union, Callable
from pathlib import Path

import torch
import hydra
import gradio as gr
from omegaconf import DictConfig, OmegaConf

import exceptions
import device_manager
import memory_manager
|
|
|
|
|
# Module-scoped logger, named after this module so its level/handlers can be tuned by name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
class ModelLoader:
    """
    Comprehensive model loading and management for SAM2 and MatAnyone.

    Holds the loaded SAM2 predictor and the MatAnyone model / inference core,
    records load timings in ``loading_stats``, and offers helpers to reload,
    validate and release the models.
    """

    # SAM2 variants to try, largest first, as (hydra config file, checkpoint file).
    SAM2_MODEL_ATTEMPTS = (
        ("sam2_hiera_large.yaml", "sam2_hiera_large.pt"),
        ("sam2_hiera_base_plus.yaml", "sam2_hiera_base_plus.pt"),
        ("sam2_hiera_small.yaml", "sam2_hiera_small.pt"),
        ("sam2_hiera_tiny.yaml", "sam2_hiera_tiny.pt"),
    )

    def __init__(self, device_mgr: "device_manager.DeviceManager", memory_mgr: "memory_manager.MemoryManager"):
        """
        Args:
            device_mgr: Supplies the target device via ``get_optimal_device()``
                and, optionally, ``get_device_memory_gb()`` for model sizing.
            memory_mgr: Used to release GPU memory when models are cleaned up.
        """
        self.device_manager = device_mgr
        self.memory_manager = memory_mgr
        self.device = self.device_manager.get_optimal_device()

        # Loaded model handles; None until load_all_models() succeeds.
        self.sam2_predictor = None
        self.matanyone_model = None
        self.matanyone_core = None

        # Both locations are resolved relative to the current working directory.
        self.configs_dir = os.path.abspath("Configs")
        self.checkpoints_dir = "./checkpoints"
        os.makedirs(self.checkpoints_dir, exist_ok=True)

        self.loading_stats = {
            'sam2_load_time': 0.0,
            'matanyone_load_time': 0.0,
            'total_load_time': 0.0,
            'models_loaded': False,
            'loading_attempts': 0
        }

        logger.info(f"ModelLoader initialized for device: {self.device}")
        self._apply_gradio_patch()

    def _apply_gradio_patch(self):
        """
        Apply Gradio schema monkey patch to prevent validation errors.

        Wraps ``gradio.components.base.Component.get_config`` so the returned
        config dict no longer contains ``show_progress_bar``, ``min_width`` or
        ``scale``. The patch is process-global; a marker attribute ensures it is
        applied only once, so constructing several loaders does not stack wrappers.
        """
        try:
            import gradio.components.base

            component_cls = gradio.components.base.Component
            original_get_config = component_cls.get_config
            if getattr(original_get_config, "_model_loader_patched", False):
                logger.debug("Gradio schema monkey patch already applied")
                return

            def patched_get_config(component, *args, **kwargs):
                config = original_get_config(component, *args, **kwargs)
                for key in ("show_progress_bar", "min_width", "scale"):
                    config.pop(key, None)
                return config

            patched_get_config._model_loader_patched = True
            component_cls.get_config = patched_get_config
            logger.debug("Applied Gradio schema monkey patch")

        except (ImportError, AttributeError) as e:
            logger.warning(f"Could not apply Gradio monkey patch: {e}")

    @staticmethod
    def _is_cancelled(cancel_event) -> bool:
        """Return True if ``cancel_event`` is set; ``None`` means "never cancelled".

        Expects an object exposing ``is_set()`` (e.g. ``threading.Event``).
        """
        return cancel_event is not None and cancel_event.is_set()

    def _raise_if_cancelled(self, cancel_event) -> None:
        """Abort the current load by raising ModelLoadingError if cancellation was requested."""
        if self._is_cancelled(cancel_event):
            raise exceptions.ModelLoadingError("Model loading cancelled")

    def load_all_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None,
    ) -> Tuple[Any, Any]:
        """
        Load both SAM2 and MatAnyone models with comprehensive error handling.

        Any previously loaded models are released first. Cancellation is checked
        before SAM2 is loaded and again before MatAnyone is loaded.

        Args:
            progress_callback: Called as ``progress_callback(fraction, message)``.
            cancel_event: Optional event checked for cancellation (``is_set()``).

        Returns:
            Tuple of (sam2_predictor, matanyone_model). On failure or
            cancellation the error is logged, partially loaded models are
            released, and (None, None) is returned instead of raising.
        """
        start_time = time.time()
        self.loading_stats['loading_attempts'] += 1

        try:
            logger.info("Starting model loading process...")
            if progress_callback:
                progress_callback(0.0, "Initializing model loading...")

            # Release leftovers from any earlier attempt; models are not ready
            # again until this load completes.
            self._cleanup_models()
            self.loading_stats['models_loaded'] = False
            self._raise_if_cancelled(cancel_event)

            # --- SAM2 ---
            logger.info("Loading SAM2 predictor...")
            if progress_callback:
                progress_callback(0.1, "Loading SAM2 predictor...")

            sam2_start = time.time()
            self.sam2_predictor = self._load_sam2_predictor(progress_callback)
            if self.sam2_predictor is None:
                raise exceptions.ModelLoadingError("Failed to load SAM2 predictor")

            sam2_time = time.time() - sam2_start
            self.loading_stats['sam2_load_time'] = sam2_time
            logger.info(f"SAM2 loaded in {sam2_time:.2f}s")

            self._raise_if_cancelled(cancel_event)

            # --- MatAnyone ---
            logger.info("Loading MatAnyone model...")
            if progress_callback:
                progress_callback(0.6, "Loading MatAnyone model...")

            matanyone_start = time.time()
            self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress_callback)
            if self.matanyone_model is None:
                raise exceptions.ModelLoadingError("Failed to load MatAnyone model")

            matanyone_time = time.time() - matanyone_start
            self.loading_stats['matanyone_load_time'] = matanyone_time
            logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")

            total_time = time.time() - start_time
            self.loading_stats['total_load_time'] = total_time
            self.loading_stats['models_loaded'] = True

            if progress_callback:
                progress_callback(1.0, "Models loaded successfully!")
            logger.info(f"All models loaded successfully in {total_time:.2f}s")

            return self.sam2_predictor, self.matanyone_model

        except Exception as e:
            error_msg = f"Model loading failed: {str(e)}"
            logger.error(f"{error_msg}\n{traceback.format_exc()}")

            # Do not keep half-loaded models around.
            self._cleanup_models()
            self.loading_stats['models_loaded'] = False

            if progress_callback:
                progress_callback(1.0, f"Error: {error_msg}")

            return None, None

    def _load_sam2_predictor(self, progress_callback: Optional[Callable[[float, str], None]] = None):
        """
        Load SAM2 predictor with multiple fallback strategies.

        Tries the candidate variants from ``_sam2_candidates()`` in order and
        returns the first one that loads.

        Args:
            progress_callback: Progress update callback.

        Returns:
            SAM2ImagePredictor or None if every candidate failed.
        """
        if not os.path.isdir(self.configs_dir):
            logger.warning(f"SAM2 Configs directory not found at '{self.configs_dir}', trying fallback loading")

        for config_yaml, checkpoint_pt in self._sam2_candidates():
            predictor = self._try_load_sam2(config_yaml, checkpoint_pt, progress_callback)
            if predictor is not None:
                return predictor

        logger.error("All SAM2 model loading attempts failed")
        return None

    def _sam2_candidates(self):
        """Return the SAM2 (config, checkpoint) pairs to try, skipping large ones on low-memory devices.

        Below 4 GB only small/tiny are tried; below 8 GB large is skipped. If the
        device manager cannot report memory, all variants are tried.
        """
        attempts = list(self.SAM2_MODEL_ATTEMPTS)

        if hasattr(self.device_manager, 'get_device_memory_gb'):
            try:
                memory_gb = self.device_manager.get_device_memory_gb()
                if memory_gb < 4:
                    attempts = attempts[2:]
                elif memory_gb < 8:
                    attempts = attempts[1:]
            except Exception as e:
                logger.warning(f"Could not determine device memory: {e}")

        return attempts

    @staticmethod
    def _sam2_repo_id(config_name_with_yaml: str) -> str:
        """Map a config file name to its Hugging Face repo id.

        E.g. ``sam2_hiera_base_plus.yaml`` -> ``facebook/sam2-hiera-base-plus``.
        The official SAM2 repos use hyphens even though the config and
        checkpoint file names use underscores.
        """
        stem = config_name_with_yaml.replace(".yaml", "")
        return f"facebook/{stem.replace('_', '-')}"

    def _resolve_sam2_checkpoint(
        self,
        config_name_with_yaml: str,
        checkpoint_name: str,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> Optional[str]:
        """Return a local path to ``checkpoint_name``, downloading it from the Hub if missing.

        Returns None (after logging a warning) if the download fails.
        """
        checkpoint_path = os.path.join(self.checkpoints_dir, checkpoint_name)
        logger.info(f"Attempting SAM2 checkpoint: {checkpoint_path}")

        if os.path.exists(checkpoint_path):
            return checkpoint_path

        logger.info(f"Downloading {checkpoint_name} from Hugging Face Hub...")
        if progress_callback:
            progress_callback(0.2, f"Downloading {checkpoint_name}...")

        try:
            from huggingface_hub import hf_hub_download

            # local_dir (rather than cache_dir) stores the file at
            # checkpoints_dir/checkpoint_name, so the existence check above
            # finds it on subsequent runs.
            checkpoint_path = hf_hub_download(
                repo_id=self._sam2_repo_id(config_name_with_yaml),
                filename=checkpoint_name,
                local_dir=self.checkpoints_dir,
            )
        except Exception as download_error:
            logger.warning(f"Failed to download {checkpoint_name}: {download_error}")
            return None

        logger.info(f"Download complete: {checkpoint_path}")
        return checkpoint_path

    def _init_hydra_configs(self) -> None:
        """(Re)initialize Hydra's global config search path at ``configs_dir`` if it exists.

        When the directory is missing Hydra is left untouched.
        NOTE(review): build_sam2 then presumably relies on the sam2 package's
        own Hydra setup — confirm.
        """
        if not os.path.isdir(self.configs_dir):
            return

        from hydra.core.global_hydra import GlobalHydra

        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()

        # initialize_config_dir takes an absolute directory. hydra.initialize()
        # would interpret a relative config_path against this module's location,
        # not the working directory that configs_dir was resolved from.
        hydra.initialize_config_dir(
            version_base=None,
            config_dir=self.configs_dir,
            job_name=f"sam2_load_{int(time.time())}",
        )

    def _try_load_sam2(
        self,
        config_name_with_yaml: str,
        checkpoint_name: str,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ):
        """Attempt to load SAM2 with the given config and checkpoint; return None on any failure."""
        try:
            checkpoint_path = self._resolve_sam2_checkpoint(
                config_name_with_yaml, checkpoint_name, progress_callback
            )
            if checkpoint_path is None:
                return None

            self._init_hydra_configs()

            config_name = config_name_with_yaml.replace(".yaml", "")
            if progress_callback:
                progress_callback(0.4, f"Building {config_name}...")

            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            # build_sam2() defaults to device="cuda"; pass the selected device so
            # the checkpoint is loaded and placed correctly on CPU/MPS hosts too.
            sam2_model = build_sam2(config_name, checkpoint_path, device=self.device)
            predictor = SAM2ImagePredictor(sam2_model)

            logger.info(f"SAM2 {config_name} loaded successfully on {self.device}")
            return predictor

        except Exception as e:
            error_msg = f"Failed to load SAM2 {config_name_with_yaml}: {e}"
            logger.warning(error_msg)
            return None

    def _load_matanyone_model(self, progress_callback: Optional[Callable[[float, str], None]] = None):
        """
        Load MatAnyone model with multiple import strategies.

        Strategies are tried in order; the first that returns a non-None model
        and core wins. Exceptions from a strategy are logged and the next one is
        tried.
        NOTE(review): the import paths / constructors used by the strategies are
        not tied to a specific matanyone release here — confirm which one applies.

        Args:
            progress_callback: Progress update callback.

        Returns:
            Tuple[model, core] or (None, None).
        """
        import_strategies = [
            self._load_matanyone_strategy_1,
            self._load_matanyone_strategy_2,
            self._load_matanyone_strategy_3,
            self._load_matanyone_strategy_4,
        ]

        for i, strategy in enumerate(import_strategies, 1):
            try:
                logger.info(f"Trying MatAnyone loading strategy {i}...")
                if progress_callback:
                    progress_callback(0.7 + (i * 0.05), f"MatAnyone strategy {i}...")

                model, core = strategy()
                if model is not None and core is not None:
                    logger.info(f"MatAnyone loaded successfully with strategy {i}")
                    return model, core

            except Exception as e:
                logger.warning(f"MatAnyone strategy {i} failed: {e}")

        logger.error("All MatAnyone loading strategies failed")
        return None, None

    def _load_matanyone_strategy_1(self):
        """MatAnyone loading strategy 1: direct model import (fp16 enabled on CUDA)."""
        from matanyone.model.matanyone import MatAnyOne
        from matanyone.inference.inference_core import InferenceCore

        cfg = OmegaConf.create({
            'model': {'name': 'MatAnyOne'},
            'device': str(self.device),
            'fp16': self.device.type == 'cuda'
        })

        net = MatAnyOne(cfg)
        core = InferenceCore(net, cfg)
        return net, core

    def _load_matanyone_strategy_2(self):
        """MatAnyone loading strategy 2: alternative top-level import paths."""
        from matanyone import MatAnyOne
        from matanyone import InferenceCore

        cfg = OmegaConf.create({
            'model_name': 'matanyone',
            'device': str(self.device)
        })

        model = MatAnyOne(cfg)
        core = InferenceCore(model, cfg)
        return model, core

    def _load_matanyone_strategy_3(self):
        """MatAnyone loading strategy 3: repository-specific imports, with a ``src`` layout fallback."""
        try:
            from matanyone.models.matanyone import MatAnyOneModel
            from matanyone.core import InferenceEngine
        except ImportError:
            from matanyone.src.models import MatAnyOneModel
            from matanyone.src.core import InferenceEngine

        config = {
            'model_path': None,
            'device': self.device,
            'precision': 'fp16' if self.device.type == 'cuda' else 'fp32'
        }

        model = MatAnyOneModel.from_pretrained(config)
        engine = InferenceEngine(model)
        return model, engine

    def _load_matanyone_strategy_4(self):
        """MatAnyone loading strategy 4: download weights from the Hugging Face Hub."""
        from huggingface_hub import hf_hub_download
        from matanyone import load_model_from_hub

        model_path = hf_hub_download(
            repo_id="PeiqingYang/MatAnyone",
            filename="pytorch_model.bin",
            cache_dir=self.checkpoints_dir
        )

        model = load_model_from_hub(model_path, device=self.device)

        # This path yields a single object, used as both model and inference core.
        return model, model

    def _cleanup_models(self):
        """
        Clean up loaded models and free memory.

        Drops the model references, runs the garbage collector, then asks the
        memory manager to release GPU memory. GPU cleanup is best-effort: a
        failure is logged rather than raised, so this is safe to call from
        error handlers.
        """
        self.sam2_predictor = None
        self.matanyone_model = None
        self.matanyone_core = None

        # Collect first so objects kept alive only by reference cycles are freed
        # before the GPU memory release (presumably a CUDA cache flush).
        gc.collect()
        try:
            self.memory_manager.cleanup_gpu_memory()
        except Exception as e:
            logger.warning(f"GPU memory cleanup failed: {e}")

        logger.debug("Model cleanup completed")

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about loaded models.

        Returns:
            Dict with load flags, device, a copy of ``loading_stats`` and, for
            each loaded model, its type name (``"Unknown"`` if the SAM2
            predictor has no ``model`` attribute).
        """
        info = {
            'models_loaded': self.loading_stats['models_loaded'],
            'sam2_loaded': self.sam2_predictor is not None,
            'matanyone_loaded': self.matanyone_model is not None,
            'device': str(self.device),
            'loading_stats': self.loading_stats.copy()
        }

        if self.sam2_predictor is not None:
            try:
                info['sam2_model_type'] = type(self.sam2_predictor.model).__name__
            except AttributeError:
                info['sam2_model_type'] = "Unknown"

        if self.matanyone_model is not None:
            info['matanyone_model_type'] = type(self.matanyone_model).__name__

        return info

    def get_status(self) -> Dict[str, Any]:
        """Get model loader status (alias of get_model_info, kept for backward compatibility)."""
        return self.get_model_info()

    def get_load_summary(self) -> str:
        """Get a human-readable, multi-line summary of the last successful model load."""
        if not self.loading_stats['models_loaded']:
            return "Models not loaded"

        stats = self.loading_stats
        return "\n".join([
            f"Models loaded successfully in {stats['total_load_time']:.1f}s",
            f"SAM2: {stats['sam2_load_time']:.1f}s",
            f"MatAnyone: {stats['matanyone_load_time']:.1f}s",
            f"Device: {self.device}",
        ])

    def validate_models(self) -> bool:
        """
        Validate that models are properly loaded.

        Currently this only checks the load flag and that both model handles
        are present; no inference is run.

        Returns:
            bool: True if models are valid.
        """
        if not self.models_ready:
            return False

        logger.info("Model validation passed")
        return True

    def reload_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None,
    ) -> Tuple[Any, Any]:
        """
        Reload all models (useful for error recovery).

        Args:
            progress_callback: Progress update callback.
            cancel_event: Optional event checked for cancellation, forwarded to
                load_all_models.

        Returns:
            Tuple of (sam2_predictor, matanyone_model), or (None, None) on failure.
        """
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats['models_loaded'] = False

        return self.load_all_models(progress_callback, cancel_event)

    def cleanup(self):
        """Clean up all resources held by the loader."""
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")

    @property
    def models_ready(self) -> bool:
        """Check if all models are loaded and ready."""
        return (
            self.loading_stats['models_loaded'] and
            self.sam2_predictor is not None and
            self.matanyone_model is not None
        )