""" |
|
|
Model Loader for Hugging Face Spaces |
|
|
- Robust SAM2 loader with multiple strategies |
|
|
- Correct MatAnyOne loader via official InferenceCore (no transformers) |
|
|
- Clean progress reporting, cleanup, and diagnostics |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import gc |
|
|
import time |
|
|
import logging |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
from typing import Optional, Dict, Any, Tuple, Callable |
|
|
|
|
|
import torch |
|
|
|
|
|
from core.exceptions import ModelLoadingError |
|
|
from utils.hardware.device_manager import DeviceManager |
|
|
from utils.system.memory_manager import MemoryManager |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|


class LoadedModel:
    """Container for a loaded model plus metadata about how it was loaded."""

    def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, device: str = "", framework: str = ""):
        self.model = model
        self.model_id = model_id
        self.load_time = load_time
        self.device = device
        self.framework = framework

    def to_dict(self) -> Dict[str, Any]:
        return {
            "model_id": self.model_id,
            "framework": self.framework,
            "device": self.device,
            "load_time": self.load_time,
            "loaded": self.model is not None,
        }


class ModelLoader:
    """Loads SAM2 and MatAnyOne with progress reporting and graceful fallbacks."""

    def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
        self.device_manager = device_mgr
        self.memory_manager = memory_mgr
        self.device = self.device_manager.get_optimal_device()

        self.sam2_predictor: Optional[LoadedModel] = None
        self.matanyone_model: Optional[LoadedModel] = None

        self.checkpoints_dir = "./checkpoints"
        os.makedirs(self.checkpoints_dir, exist_ok=True)

        self.loading_stats = {
            "sam2_load_time": 0.0,
            "matanyone_load_time": 0.0,
            "total_load_time": 0.0,
            "models_loaded": False,
            "loading_attempts": 0,
        }

        logger.info(f"ModelLoader initialized for device: {self.device}")

    def load_all_models(
        self, progress_callback: Optional[Callable[[float, str], None]] = None, cancel_event=None
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """Load SAM2 and MatAnyOne. Returns (LoadedModel|None, LoadedModel|None)."""
        start_time = time.time()
        self.loading_stats["loading_attempts"] += 1

        try:
            logger.info("Starting model loading process...")
            if progress_callback:
                progress_callback(0.0, "Initializing model loading...")

            self._cleanup_models()

            logger.info("Loading SAM2 predictor...")
            if progress_callback:
                progress_callback(0.1, "Loading SAM2 predictor...")
            sam2_loaded = self._load_sam2_predictor(progress_callback)

            if sam2_loaded is None:
                logger.warning("SAM2 loading failed - a limited fallback will be used at runtime if needed.")
            else:
                self.sam2_predictor = sam2_loaded
                self.loading_stats["sam2_load_time"] = self.sam2_predictor.load_time
                logger.info(f"SAM2 loaded in {self.loading_stats['sam2_load_time']:.2f}s")

            # Honor cancellation between the two model loads.
            if cancel_event is not None and getattr(cancel_event, "is_set", lambda: False)():
                if progress_callback:
                    progress_callback(1.0, "Model loading cancelled")
                return self.sam2_predictor, None

            logger.info("Loading MatAnyOne model...")
            if progress_callback:
                progress_callback(0.6, "Loading MatAnyOne model...")
            matanyone_loaded = self._load_matanyone(progress_callback)

            if matanyone_loaded is None:
                logger.warning("MatAnyOne loading failed - will use simple refinement fallbacks.")
            else:
                self.matanyone_model = matanyone_loaded
                self.loading_stats["matanyone_load_time"] = self.matanyone_model.load_time
                logger.info(f"MatAnyOne loaded in {self.loading_stats['matanyone_load_time']:.2f}s")

            total_time = time.time() - start_time
            self.loading_stats["total_load_time"] = total_time
            self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)

            if progress_callback:
                if self.loading_stats["models_loaded"]:
                    progress_callback(1.0, "Models loaded (fallbacks available if any model failed)")
                else:
                    progress_callback(1.0, "Using fallback methods (models failed to load)")

            logger.info(f"Model loading completed in {total_time:.2f}s")
            return self.sam2_predictor, self.matanyone_model

        except Exception as e:
            error_msg = f"Model loading failed: {str(e)}"
            logger.error(f"{error_msg}\n{traceback.format_exc()}")
            self._cleanup_models()
            self.loading_stats["models_loaded"] = False
            if progress_callback:
                progress_callback(1.0, f"Error: {error_msg}")
            return None, None
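
    # Cancellation sketch (illustrative, not part of the loader itself): any
    # object exposing is_set() works, e.g. a threading.Event set from another
    # thread:
    #
    #     cancel = threading.Event()
    #     sam2, _ = loader.load_all_models(cancel_event=cancel)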

    def reload_models(self, progress_callback: Optional[Callable[[float, str], None]] = None) -> Tuple[
        Optional[LoadedModel], Optional[LoadedModel]
    ]:
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats["models_loaded"] = False
        return self.load_all_models(progress_callback)

    @property
    def models_ready(self) -> bool:
        return self.sam2_predictor is not None or self.matanyone_model is not None

    def get_sam2(self):
        return self.sam2_predictor.model if self.sam2_predictor is not None else None

    def get_matanyone(self):
        return self.matanyone_model.model if self.matanyone_model is not None else None

    def validate_models(self) -> bool:
        try:
            ok = False
            if self.sam2_predictor is not None:
                model = self.sam2_predictor.model
                if hasattr(model, "set_image") or hasattr(model, "predict"):
                    ok = True
            if self.matanyone_model is not None:
                ok = True
            return ok
        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        info = {
            "models_loaded": self.loading_stats["models_loaded"],
            "sam2_loaded": self.sam2_predictor is not None,
            "matanyone_loaded": self.matanyone_model is not None,
            "device": str(self.device),
            "loading_stats": self.loading_stats.copy(),
        }
        if self.sam2_predictor is not None:
            info["sam2_model_type"] = type(self.sam2_predictor.model).__name__
            info["sam2_metadata"] = self.sam2_predictor.to_dict()
        if self.matanyone_model is not None:
            info["matanyone_model_type"] = type(self.matanyone_model.model).__name__
            info["matanyone_metadata"] = self.matanyone_model.to_dict()
        return info

    def get_load_summary(self) -> str:
        if not self.loading_stats["models_loaded"]:
            return "Models not loaded"
        sam2_time = self.loading_stats["sam2_load_time"]
        matanyone_time = self.loading_stats["matanyone_load_time"]
        total_time = self.loading_stats["total_load_time"]
        summary = f"Models loaded in {total_time:.1f}s\n"
        if self.sam2_predictor:
            summary += f"✓ SAM2: {sam2_time:.1f}s (ID: {self.sam2_predictor.model_id})\n"
        else:
            summary += "✗ SAM2: Failed (using fallback)\n"
        if self.matanyone_model:
            summary += f"✓ MatAnyOne: {matanyone_time:.1f}s (ID: {self.matanyone_model.model_id})\n"
        else:
            summary += "✗ MatAnyOne: Failed (using simple refinement)\n"
        summary += f"Device: {self.device}"
        return summary
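
    # Illustrative output shape (times and IDs vary by run and hardware):
    #
    #     Models loaded in 12.3s
    #     ✓ SAM2: 8.1s (ID: facebook/sam2.1-hiera-large)
    #     ✓ MatAnyOne: 4.2s (ID: PeiqingYang/MatAnyone)
    #     Device: cuda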

    def cleanup(self):
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")

    def _load_sam2_predictor(self, progress_callback: Optional[Callable[[float, str], None]] = None) -> Optional[LoadedModel]:
        """Try SAM2 loading strategies in order: official -> transformers -> dummy fallback."""
        # Pick a checkpoint size that fits the available VRAM.
        model_size = "large"
        try:
            if hasattr(self.device_manager, "get_device_memory_gb"):
                memory_gb = self.device_manager.get_device_memory_gb()
                if memory_gb < 4:
                    model_size = "tiny"
                elif memory_gb < 8:
                    model_size = "small"
                elif memory_gb < 12:
                    model_size = "base"
                logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB VRAM")
        except Exception as e:
            logger.warning(f"Could not determine device memory: {e}")
            model_size = "tiny"

        model_map = {
            "tiny": "facebook/sam2.1-hiera-tiny",
            "small": "facebook/sam2.1-hiera-small",
            "base": "facebook/sam2.1-hiera-base-plus",
            "large": "facebook/sam2.1-hiera-large",
        }
        model_id = model_map.get(model_size, model_map["tiny"])

        if progress_callback:
            progress_callback(0.3, f"Loading SAM2 ({model_size})...")

        methods = [
            ("official", self._try_load_sam2_official, model_id),
            ("direct", self._try_load_sam2_direct, model_id),
            ("manual", self._try_load_sam2_manual, model_id),
        ]

        for name, fn, mid in methods:
            try:
                logger.info(f"Attempting SAM2 load via {name} method ({mid})...")
                result = fn(mid)
                if result is not None:
                    logger.info(f"SAM2 loaded successfully via {name} method")
                    return result
            except Exception as e:
                logger.error(f"SAM2 {name} method failed: {e}")
                logger.debug(traceback.format_exc())
                continue

        logger.error("All SAM2 loading methods failed")
        return None

    def _try_load_sam2_official(self, model_id: str) -> Optional[LoadedModel]:
        """Official predictor path (Meta's SAM2ImagePredictor)."""
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        # Keep HF Hub downloads simple on Spaces (no symlinks, no accelerated transfer).
        os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

        cache_dir = os.path.join(self.checkpoints_dir, "sam2_cache")
        os.makedirs(cache_dir, exist_ok=True)

        t0 = time.time()
        # Extra kwargs are best-effort; some sam2 versions may reject them, in
        # which case the exception falls through to the next loading strategy.
        predictor = SAM2ImagePredictor.from_pretrained(
            model_id,
            cache_dir=cache_dir,
            local_files_only=False,
            trust_remote_code=True,
        )
        if hasattr(predictor, "model"):
            predictor.model = predictor.model.to(self.device)
        t1 = time.time()

        return LoadedModel(
            model=predictor, model_id=model_id, load_time=t1 - t0, device=str(self.device), framework="sam2"
        )

    def _try_load_sam2_direct(self, model_id: str) -> Optional[LoadedModel]:
        """Transformers AutoModel path (best-effort; API may vary)."""
        from transformers import AutoModel, AutoProcessor

        t0 = time.time()
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(self.device)

        try:
            processor = AutoProcessor.from_pretrained(model_id)
        except Exception:
            processor = None

        t1 = time.time()

        class SAM2Wrapper:
            """Adapts the raw transformers model to the predictor-style interface."""

            def __init__(self, model, processor=None):
                self.model = model
                self.processor = processor
                self.current_image = None

            def set_image(self, image):
                self.current_image = image

            def predict(self, *args, **kwargs):
                return self.model(*args, **kwargs)

        wrapped = SAM2Wrapper(model, processor)

        return LoadedModel(
            model=wrapped,
            model_id=model_id,
            load_time=t1 - t0,
            device=str(self.device),
            framework="sam2-transformers",
        )

    def _try_load_sam2_manual(self, model_id: str) -> Optional[LoadedModel]:
        """Dummy fallback that won't crash the app (returns a full-frame mask)."""

        class DummySAM2:
            def __init__(self, device):
                self.device = device
                self.model = None
                self.current_image = None

            def set_image(self, image):
                self.current_image = image

            def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
                import numpy as np

                if self.current_image is not None:
                    h, w = self.current_image.shape[:2]
                else:
                    h, w = 512, 512
                # A constant all-ones mask with a neutral score: safe, but only
                # useful for keeping downstream code running.
                return {
                    "masks": np.ones((1, h, w), dtype=np.float32),
                    "scores": np.array([0.5]),
                    "logits": np.ones((1, h, w), dtype=np.float32),
                }

        t0 = time.time()
        dummy = DummySAM2(self.device)
        t1 = time.time()

        logger.warning("Using manual SAM2 fallback (limited functionality)")
        return LoadedModel(
            model=dummy, model_id=f"{model_id}-fallback", load_time=t1 - t0, device=str(self.device), framework="sam2-fallback"
        )

    def _load_matanyone(self, progress_callback: Optional[Callable[[float, str], None]] = None) -> Optional[LoadedModel]:
        """Load MatAnyOne via the official package API, with a safe fallback."""
        if progress_callback:
            progress_callback(0.7, "Loading MatAnyOne (InferenceCore)...")
        try:
            return self._try_load_matanyone_official()
        except Exception as e:
            logger.error(f"MatAnyOne official loader failed: {e}")
            logger.debug(traceback.format_exc())
            logger.warning("Falling back to simple MatAnyOne placeholder.")
            return self._try_load_matanyone_fallback()

    def _try_load_matanyone_official(self) -> Optional[LoadedModel]:
        """
        Official MatAnyOne via the package's InferenceCore.

        IMPORTANT: pass the model id positionally; do NOT use repo_id= or transformers.
        """
        from matanyone import InferenceCore

        t0 = time.time()
        processor = InferenceCore("PeiqingYang/MatAnyone")
        t1 = time.time()

        return LoadedModel(
            model=processor,
            model_id="PeiqingYang/MatAnyone",
            load_time=t1 - t0,
            device=str(self.device),
            framework="matanyone",
        )

    def _try_load_matanyone_fallback(self) -> Optional[LoadedModel]:
        """Minimal placeholder that safely passes masks through."""

        class FallbackMatAnyone:
            def __init__(self, device):
                self.device = device

            def process(self, image, mask):
                # No refinement: return the input mask unchanged.
                return mask

        t0 = time.time()
        model = FallbackMatAnyone(self.device)
        t1 = time.time()

        logger.warning("Using MatAnyOne fallback (limited functionality)")
        return LoadedModel(
            model=model, model_id="MatAnyone-fallback", load_time=t1 - t0, device=str(self.device), framework="matanyone-fallback"
        )

    def _cleanup_models(self):
        # Drop references, collect garbage, then release cached GPU memory.
        self.sam2_predictor = None
        self.matanyone_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.debug("Model cleanup completed")
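

# Minimal usage sketch (assumes DeviceManager() and MemoryManager() can be
# constructed without arguments; adjust to your app's wiring):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    def on_progress(fraction: float, message: str) -> None:
        # Any (float, str) -> None callable satisfies the callback contract.
        print(f"[{fraction:>6.1%}] {message}")

    loader = ModelLoader(DeviceManager(), MemoryManager())
    sam2, matanyone = loader.load_all_models(progress_callback=on_progress)
    print(loader.get_load_summary())
    loader.cleanup()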