#!/usr/bin/env python3
"""
Unified Model Loader
Coordinates separate SAM2 and MatAnyone loaders for cleaner architecture
"""
from __future__ import annotations

import os
import gc
import time
import logging
from typing import Optional, Dict, Any, Tuple, Callable

import torch

from core.exceptions import ModelLoadingError
from utils.hardware.device_manager import DeviceManager
from utils.system.memory_manager import MemoryManager

# Import the specialized loaders
from models.loaders.sam2_loader import SAM2Loader
from models.loaders.matanyone_loader import MatAnyoneLoader

logger = logging.getLogger(__name__)


class LoadedModel:
    """Container for loaded model information"""

    def __init__(self, model=None, model_id: str = "", load_time: float = 0.0,
                 device: str = "", framework: str = ""):
        self.model = model
        self.model_id = model_id
        self.load_time = load_time
        self.device = device
        self.framework = framework

    def to_dict(self) -> Dict[str, Any]:
        return {
            "model_id": self.model_id,
            "framework": self.framework,
            "device": self.device,
            "load_time": self.load_time,
            "loaded": self.model is not None,
        }
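
# Illustrative usage sketch (not exercised in this module): the container is
# framework-agnostic, so status payloads can be built uniformly, e.g.
#
#     lm = LoadedModel(model=predictor, model_id="sam2.1-hiera-large",
#                      load_time=3.2, device="cuda", framework="sam2")
#     lm.to_dict()  # -> {"model_id": "sam2.1-hiera-large", "framework": "sam2",
#                   #     "device": "cuda", "load_time": 3.2, "loaded": True}
#
# The model_id value above is hypothetical; the real value comes from SAM2Loader.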


class ModelLoader:
    """Main model loader that coordinates SAM2 and MatAnyone loaders"""

    def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
        self.device_manager = device_mgr
        self.memory_manager = memory_mgr
        self.device = self.device_manager.get_optimal_device()

        # Initialize specialized loaders
        self.sam2_loader = SAM2Loader(device=str(self.device))
        self.matanyone_loader = MatAnyoneLoader(device=str(self.device))

        # Model storage
        self.sam2_predictor: Optional[LoadedModel] = None
        self.matanyone_model: Optional[LoadedModel] = None

        # Statistics
        self.loading_stats = {
            "sam2_load_time": 0.0,
            "matanyone_load_time": 0.0,
            "total_load_time": 0.0,
            "models_loaded": False,
            "loading_attempts": 0,
        }

        logger.info(f"ModelLoader initialized for device: {self.device}")

    def load_all_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """
        Load all models using specialized loaders

        Args:
            progress_callback: Optional callback for progress updates
            cancel_event: Optional threading.Event for cancellation

        Returns:
            Tuple of (sam2_model, matanyone_model)
        """
        start_time = time.time()
        self.loading_stats["loading_attempts"] += 1

        try:
            logger.info("Starting model loading process...")
            if progress_callback:
                progress_callback(0.0, "Initializing model loading...")

            # Clean up any existing models
            self._cleanup_models()

            # Load SAM2
            if progress_callback:
                progress_callback(0.1, "Loading SAM2 model...")
            sam2_start = time.time()
            sam2_model = self.sam2_loader.load()
            sam2_time = time.time() - sam2_start

            if sam2_model:
                self.sam2_predictor = LoadedModel(
                    model=sam2_model,
                    model_id=self.sam2_loader.model_id,
                    load_time=sam2_time,
                    device=str(self.device),
                    framework="sam2"
                )
                self.loading_stats["sam2_load_time"] = sam2_time
                logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
            else:
                logger.warning("SAM2 loading failed")

            # Check for cancellation
            if cancel_event and cancel_event.is_set():
                if progress_callback:
                    progress_callback(1.0, "Model loading cancelled")
                return self.sam2_predictor, None

            # Load MatAnyone
            if progress_callback:
                progress_callback(0.6, "Loading MatAnyone model...")
            matanyone_start = time.time()
            matanyone_model = self.matanyone_loader.load()
            matanyone_time = time.time() - matanyone_start

            if matanyone_model:
                self.matanyone_model = LoadedModel(
                    model=matanyone_model,
                    model_id=self.matanyone_loader.model_id,
                    load_time=matanyone_time,
                    device=str(self.device),
                    framework="matanyone"
                )
                self.loading_stats["matanyone_load_time"] = matanyone_time
                logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")
            else:
                logger.warning("MatAnyone loading failed")

            # Update statistics
            total_time = time.time() - start_time
            self.loading_stats["total_load_time"] = total_time
            self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)

            # Final progress update
            if progress_callback:
                if self.loading_stats["models_loaded"]:
                    progress_callback(1.0, "Models loaded successfully")
                else:
                    progress_callback(1.0, "Model loading completed with failures")

            logger.info(f"Model loading completed in {total_time:.2f}s")
            return self.sam2_predictor, self.matanyone_model

        except Exception as e:
            error_msg = f"Model loading failed: {str(e)}"
            logger.error(error_msg)
            self._cleanup_models()
            self.loading_stats["models_loaded"] = False
            if progress_callback:
                progress_callback(1.0, f"Error: {error_msg}")
            return None, None
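
    # Typical call pattern (illustrative sketch, given an already constructed
    # `loader`): callers pass a (fraction, message) callback and an optional
    # threading.Event so a UI thread can report progress and request cancellation.
    #
    #     import threading
    #     cancel = threading.Event()
    #     sam2, matanyone = loader.load_all_models(
    #         progress_callback=lambda frac, msg: print(f"{frac:.0%} {msg}"),
    #         cancel_event=cancel,
    #     )
    #
    # Setting `cancel` between the SAM2 and MatAnyone stages returns early with
    # (sam2_predictor, None).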

    def reload_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """Reload all models from scratch"""
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats["models_loaded"] = False
        return self.load_all_models(progress_callback)

    @property
    def models_ready(self) -> bool:
        """Check if any models are loaded and ready"""
        return self.sam2_predictor is not None or self.matanyone_model is not None

    def get_sam2(self):
        """Get SAM2 predictor model"""
        return self.sam2_predictor.model if self.sam2_predictor else None

    def get_matanyone(self):
        """Get MatAnyone processor model"""
        return self.matanyone_model.model if self.matanyone_model else None
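
    # Downstream usage sketch (assumptions: the object returned by SAM2Loader
    # exposes the set_image/predict interface that validate_models() checks for;
    # `frame`, `points`, and `labels` are caller-supplied data):
    #
    #     predictor = loader.get_sam2()
    #     if predictor is not None:
    #         predictor.set_image(frame)
    #         masks, scores, logits = predictor.predict(point_coords=points,
    #                                                   point_labels=labels)
    #
    # MatAnyone is driven analogously through get_matanyone(); validate_models()
    # only asserts that it has a `step` or `process` method.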

    def validate_models(self) -> bool:
        """Validate that loaded models have expected interfaces"""
        try:
            valid = False
            if self.sam2_predictor:
                model = self.sam2_predictor.model
                if hasattr(model, "set_image") and hasattr(model, "predict"):
                    valid = True
                    logger.info("SAM2 model validated")
            if self.matanyone_model:
                model = self.matanyone_model.model
                if hasattr(model, "step") or hasattr(model, "process"):
                    valid = True
                    logger.info("MatAnyone model validated")
            return valid
        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get detailed information about loaded models"""
        info = {
            "models_loaded": self.loading_stats["models_loaded"],
            "device": str(self.device),
            "loading_stats": self.loading_stats.copy(),
        }
        # Add SAM2 info
        info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}
        # Add MatAnyone info
        info["matanyone"] = self.matanyone_loader.get_info() if self.matanyone_loader else {}
        return info

    def get_load_summary(self) -> str:
        """Get human-readable loading summary"""
        if not self.loading_stats["models_loaded"]:
            return "No models loaded"

        lines = []
        lines.append(f"Models loaded in {self.loading_stats['total_load_time']:.1f}s")
        if self.sam2_predictor:
            lines.append(f"✓ SAM2: {self.loading_stats['sam2_load_time']:.1f}s")
            lines.append(f" Model: {self.sam2_predictor.model_id}")
        else:
            lines.append("✗ SAM2: Failed to load")
        if self.matanyone_model:
            lines.append(f"✓ MatAnyone: {self.loading_stats['matanyone_load_time']:.1f}s")
            lines.append(f" Model: {self.matanyone_model.model_id}")
        else:
            lines.append("✗ MatAnyone: Failed to load")
        lines.append(f"Device: {self.device}")
        return "\n".join(lines)

    def cleanup(self):
        """Clean up all resources"""
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")

    def _cleanup_models(self):
        """Internal cleanup of loaded models"""
        # Clean up SAM2
        if self.sam2_loader:
            self.sam2_loader.cleanup()
        if self.sam2_predictor:
            del self.sam2_predictor
            self.sam2_predictor = None

        # Clean up MatAnyone
        if self.matanyone_loader:
            self.matanyone_loader.cleanup()
        if self.matanyone_model:
            del self.matanyone_model
            self.matanyone_model = None

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Garbage collection
        gc.collect()
        logger.debug("Model cleanup completed")
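

# Minimal smoke-test sketch. This block is illustrative only: it assumes
# DeviceManager and MemoryManager can be constructed without arguments, which is
# not shown in this file and may differ in the real classes.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    device_mgr = DeviceManager()    # assumed no-arg constructor
    memory_mgr = MemoryManager()    # assumed no-arg constructor
    loader = ModelLoader(device_mgr, memory_mgr)

    # Load both models, logging coarse progress updates
    sam2, matanyone = loader.load_all_models(
        progress_callback=lambda frac, msg: logger.info(f"[{frac:.0%}] {msg}")
    )

    print(loader.get_load_summary())
    if loader.models_ready and loader.validate_models():
        logger.info(f"Model info: {loader.get_model_info()}")
    loader.cleanup()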