"""
Model Loading Module

Handles loading and validation of the SAM2 and MatAnyone AI models.
"""

import os
import gc
import sys
import time
import shutil
import logging
import tempfile
import traceback
from typing import Optional, Dict, Any, Callable, Tuple, Union
from pathlib import Path

import torch
import gradio as gr
from omegaconf import DictConfig, OmegaConf

import exceptions
import device_manager
import memory_manager

logger = logging.getLogger(__name__)


class HardCacheCleaner:
    """
    Comprehensive cache-cleaning system used to resolve SAM2 loading issues.
    Clears the Python module cache, HuggingFace caches, and temp files.
    """

    @staticmethod
    def clean_all_caches(verbose: bool = True):
        """Clean all caches that might interfere with SAM2 loading."""
        if verbose:
            logger.info("Starting comprehensive cache cleanup...")

        HardCacheCleaner._clean_python_cache(verbose)
        HardCacheCleaner._clean_huggingface_cache(verbose)
        HardCacheCleaner._clean_pytorch_cache(verbose)
        HardCacheCleaner._clean_temp_directories(verbose)
        HardCacheCleaner._clear_import_cache(verbose)
        HardCacheCleaner._force_gc_cleanup(verbose)

        if verbose:
            logger.info("Cache cleanup completed")

    @staticmethod
    def _clean_python_cache(verbose: bool = True):
        """Clean Python bytecode and module caches."""
        try:
            # Drop any already-imported sam2 modules so they are re-imported fresh.
            sam2_modules = [key for key in sys.modules.keys() if 'sam2' in key.lower()]
            for module in sam2_modules:
                if verbose:
                    logger.info(f"Removing cached module: {module}")
                del sys.modules[module]

            # Remove __pycache__ directories; iterate over a copy of dirs so the
            # walk can be pruned in place.
            for root, dirs, files in os.walk("."):
                for dir_name in dirs[:]:
                    if dir_name == "__pycache__":
                        cache_path = os.path.join(root, dir_name)
                        if verbose:
                            logger.info(f"Removing __pycache__: {cache_path}")
                        shutil.rmtree(cache_path, ignore_errors=True)
                        dirs.remove(dir_name)

        except Exception as e:
            logger.warning(f"Python cache cleanup failed: {e}")

    @staticmethod
    def _clean_huggingface_cache(verbose: bool = True):
        """Remove SAM2-related entries from HuggingFace and local model caches."""
        try:
            cache_paths = [
                os.path.expanduser("~/.cache/huggingface/"),
                os.path.expanduser("~/.cache/torch/"),
                "./checkpoints/",
                "./.cache/",
            ]

            for cache_path in cache_paths:
                if os.path.exists(cache_path):
                    if verbose:
                        logger.info(f"Cleaning cache directory: {cache_path}")

                    for root, dirs, files in os.walk(cache_path):
                        for file in files:
                            if any(pattern in file.lower() for pattern in ['sam2', 'segment-anything-2']):
                                file_path = os.path.join(root, file)
                                try:
                                    os.remove(file_path)
                                    if verbose:
                                        logger.info(f"Removed cached file: {file_path}")
                                except OSError:
                                    pass

                        for dir_name in dirs[:]:
                            if any(pattern in dir_name.lower() for pattern in ['sam2', 'segment-anything-2']):
                                dir_path = os.path.join(root, dir_name)
                                try:
                                    shutil.rmtree(dir_path, ignore_errors=True)
                                    if verbose:
                                        logger.info(f"Removed cached directory: {dir_path}")
                                    dirs.remove(dir_name)
                                except Exception:
                                    pass

        except Exception as e:
            logger.warning(f"HuggingFace cache cleanup failed: {e}")

    @staticmethod
    def _clean_pytorch_cache(verbose: bool = True):
        """Release cached CUDA memory held by PyTorch."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                if verbose:
                    logger.info("Cleared PyTorch CUDA cache")
        except Exception as e:
            logger.warning(f"PyTorch cache cleanup failed: {e}")

    @staticmethod
    def _clean_temp_directories(verbose: bool = True):
        """Remove SAM2-related items from temporary directories."""
        try:
            temp_dirs = [tempfile.gettempdir(), "/tmp", "./tmp", "./temp"]

            for temp_dir in temp_dirs:
                if os.path.exists(temp_dir):
                    for item in os.listdir(temp_dir):
                        if 'sam2' in item.lower() or 'segment' in item.lower():
                            item_path = os.path.join(temp_dir, item)
                            try:
                                if os.path.isfile(item_path):
                                    os.remove(item_path)
                                elif os.path.isdir(item_path):
                                    shutil.rmtree(item_path, ignore_errors=True)
                                if verbose:
                                    logger.info(f"Removed temp item: {item_path}")
                            except OSError:
                                pass

        except Exception as e:
            logger.warning(f"Temp directory cleanup failed: {e}")

    @staticmethod
    def _clear_import_cache(verbose: bool = True):
        """Invalidate Python's import-system caches."""
        try:
            import importlib

            importlib.invalidate_caches()

            if verbose:
                logger.info("Cleared Python import cache")

        except Exception as e:
            logger.warning(f"Import cache cleanup failed: {e}")

    @staticmethod
    def _force_gc_cleanup(verbose: bool = True):
        """Force a full garbage-collection pass."""
        try:
            collected = gc.collect()
            if verbose:
                logger.info(f"Garbage collection freed {collected} objects")
        except Exception as e:
            logger.warning(f"Garbage collection failed: {e}")


class ModelLoader:
    """
    Comprehensive model loading and management for SAM2 and MatAnyone.
    Handles automatic config detection, multiple fallback strategies, and memory management.
    """

    def __init__(self, device_mgr: device_manager.DeviceManager, memory_mgr: memory_manager.MemoryManager):
        self.device_manager = device_mgr
        self.memory_manager = memory_mgr
        self.device = self.device_manager.get_optimal_device()

        # Model handles (populated by load_all_models)
        self.sam2_predictor = None
        self.matanyone_model = None
        self.matanyone_core = None

        # Local checkpoint directory
        self.checkpoints_dir = "./checkpoints"
        os.makedirs(self.checkpoints_dir, exist_ok=True)

        # Loading statistics
        self.loading_stats = {
            'sam2_load_time': 0.0,
            'matanyone_load_time': 0.0,
            'total_load_time': 0.0,
            'models_loaded': False,
            'loading_attempts': 0
        }

        logger.info(f"ModelLoader initialized for device: {self.device}")
        self._apply_gradio_patch()
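
    # Construction sketch (illustrative; assumes DeviceManager and
    # MemoryManager can be instantiated with no arguments -- adapt to the
    # app's actual wiring):
    #
    #     dev_mgr = device_manager.DeviceManager()
    #     mem_mgr = memory_manager.MemoryManager()
    #     loader = ModelLoader(dev_mgr, mem_mgr)
    #     sam2, matanyone = loader.load_all_models()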

    def _apply_gradio_patch(self):
        """Apply a Gradio schema monkey patch to prevent validation errors."""
        try:
            import gradio.components.base
            original_get_config = gradio.components.base.Component.get_config

            def patched_get_config(self):
                config = original_get_config(self)
                # Drop keys that trip Gradio's schema validation on some versions.
                config.pop("show_progress_bar", None)
                config.pop("min_width", None)
                config.pop("scale", None)
                return config

            gradio.components.base.Component.get_config = patched_get_config
            logger.debug("Applied Gradio schema monkey patch")

        except (ImportError, AttributeError) as e:
            logger.warning(f"Could not apply Gradio monkey patch: {e}")

    def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_event=None) -> Tuple[Any, Any]:
        """
        Load both SAM2 and MatAnyone models with comprehensive error handling.

        Args:
            progress_callback: Progress update callback
            cancel_event: Event to check for cancellation (currently unused)

        Returns:
            Tuple of (sam2_predictor, matanyone_model); (None, None) on failure
        """
        start_time = time.time()
        self.loading_stats['loading_attempts'] += 1

        try:
            logger.info("Starting model loading process...")
            if progress_callback:
                progress_callback(0.0, "Initializing model loading...")

            # Release any previously loaded models before reloading.
            self._cleanup_models()

            # --- SAM2 ---
            logger.info("Loading SAM2 predictor...")
            if progress_callback:
                progress_callback(0.1, "Loading SAM2 predictor...")

            self.sam2_predictor = self._load_sam2_predictor(progress_callback)

            if self.sam2_predictor is None:
                raise exceptions.ModelLoadingError("SAM2", "Failed to load SAM2 predictor")

            sam2_time = time.time() - start_time
            self.loading_stats['sam2_load_time'] = sam2_time
            logger.info(f"SAM2 loaded in {sam2_time:.2f}s")

            # --- MatAnyone ---
            logger.info("Loading MatAnyone model...")
            if progress_callback:
                progress_callback(0.6, "Loading MatAnyone model...")

            matanyone_start = time.time()
            self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress_callback)

            if self.matanyone_model is None:
                raise exceptions.ModelLoadingError("MatAnyone", "Failed to load MatAnyone model")

            matanyone_time = time.time() - matanyone_start
            self.loading_stats['matanyone_load_time'] = matanyone_time
            logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")

            total_time = time.time() - start_time
            self.loading_stats['total_load_time'] = total_time
            self.loading_stats['models_loaded'] = True

            if progress_callback:
                progress_callback(1.0, "Models loaded successfully!")

            logger.info(f"All models loaded successfully in {total_time:.2f}s")

            return self.sam2_predictor, self.matanyone_model

        except Exception as e:
            error_msg = f"Model loading failed: {e}"
            logger.error(f"{error_msg}\n{traceback.format_exc()}")

            # Leave the loader in a clean state so a retry can start fresh.
            self._cleanup_models()
            self.loading_stats['models_loaded'] = False

            if progress_callback:
                progress_callback(1.0, f"Error: {error_msg}")

            return None, None
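
    # Usage sketch (illustrative, continuing the construction example above):
    # wiring a simple progress callback.
    #
    #     def on_progress(fraction, message):
    #         print(f"[{fraction:5.0%}] {message}")
    #
    #     sam2, matanyone = loader.load_all_models(progress_callback=on_progress)
    #     if sam2 is None:
    #         raise RuntimeError(loader.get_load_summary())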

    def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
        """
        Load SAM2 via the HuggingFace Transformers integration, with cache cleanup.
        This method works reliably on HuggingFace Spaces without config-file issues.

        Args:
            progress_callback: Progress update callback

        Returns:
            A SAM2 predictor/pipeline/model, or None if every strategy failed
        """
        logger.info("=== USING NEW HF TRANSFORMERS SAM2 LOADER ===")

        if progress_callback:
            progress_callback(0.15, "Cleaning caches...")

        HardCacheCleaner.clean_all_caches(verbose=True)

        # Pick a model size based on available device memory.
        model_size = "large"
        if hasattr(self.device_manager, 'get_device_memory_gb'):
            try:
                memory_gb = self.device_manager.get_device_memory_gb()
                if memory_gb < 4:
                    model_size = "tiny"
                elif memory_gb < 8:
                    model_size = "base"
                logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
            except Exception as e:
                logger.warning(f"Could not determine device memory: {e}")

        model_map = {
            "tiny": "facebook/sam2.1-hiera-tiny",
            "small": "facebook/sam2.1-hiera-small",
            "base": "facebook/sam2.1-hiera-base-plus",
            "large": "facebook/sam2.1-hiera-large"
        }

        model_id = model_map.get(model_size, model_map["large"])

        if progress_callback:
            progress_callback(0.3, f"Loading SAM2 {model_size}...")

        # Strategy 1: Transformers mask-generation pipeline.
        try:
            logger.info("Trying Transformers pipeline approach...")
            from transformers import pipeline

            sam2_pipeline = pipeline(
                "mask-generation",
                model=model_id,
                device=0 if str(self.device).startswith("cuda") else -1
            )

            logger.info("SAM2 loaded successfully via Transformers pipeline")
            return sam2_pipeline

        except Exception as e:
            logger.warning(f"Pipeline approach failed: {e}")

        # Strategy 2: direct Transformers model/processor classes.
        try:
            logger.info("Trying direct Transformers classes...")
            from transformers import Sam2Processor, Sam2Model

            processor = Sam2Processor.from_pretrained(model_id)
            model = Sam2Model.from_pretrained(model_id).to(self.device)

            logger.info("SAM2 loaded successfully via Transformers classes")
            return {"model": model, "processor": processor}

        except Exception as e:
            logger.warning(f"Direct class approach failed: {e}")

        # Strategy 3: the official sam2 package.
        try:
            logger.info("Trying official SAM2 from_pretrained...")
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            predictor = SAM2ImagePredictor.from_pretrained(model_id)

            logger.info("SAM2 loaded successfully via official from_pretrained")
            return predictor

        except Exception as e:
            logger.warning(f"Official from_pretrained approach failed: {e}")

        # Strategy 4: pre-download the checkpoint, then load via Transformers.
        try:
            logger.info("Trying fallback checkpoint approach...")
            from huggingface_hub import hf_hub_download
            from transformers import Sam2Model

            # Downloading first warms the local HF cache for from_pretrained.
            checkpoint_path = hf_hub_download(
                repo_id=model_id,
                filename="model.safetensors" if "sam2.1" in model_id else "pytorch_model.bin"
            )

            logger.info(f"Downloaded checkpoint to: {checkpoint_path}")

            model = Sam2Model.from_pretrained(model_id)
            model = model.to(self.device)

            logger.info("SAM2 loaded successfully via fallback approach")
            return model

        except Exception as e:
            logger.warning(f"Fallback approach failed: {e}")

        logger.error("All SAM2 loading methods failed")
        return None
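
    # Note: depending on which strategy succeeded, self.sam2_predictor may be a
    # Transformers pipeline, a {"model", "processor"} dict, or a
    # SAM2ImagePredictor. Callers therefore need to dispatch on shape -- sketch:
    #
    #     result = loader.sam2_predictor
    #     if isinstance(result, dict):            # Transformers model + processor
    #         model, processor = result["model"], result["processor"]
    #     elif hasattr(result, "set_image"):      # official SAM2ImagePredictor
    #         predictor = result
    #     else:                                   # pipeline or bare Sam2Model
    #         sam2 = result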

    def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
        """
        Load the MatAnyone model, trying multiple import strategies in order.

        Args:
            progress_callback: Progress update callback

        Returns:
            Tuple of (model, core), or (None, None) if every strategy failed
        """
        import_strategies = [
            self._load_matanyone_strategy_1,
            self._load_matanyone_strategy_2,
            self._load_matanyone_strategy_3,
            self._load_matanyone_strategy_4
        ]

        for i, strategy in enumerate(import_strategies, 1):
            try:
                logger.info(f"Trying MatAnyone loading strategy {i}...")
                if progress_callback:
                    progress_callback(0.7 + (i * 0.05), f"MatAnyone strategy {i}...")

                model, core = strategy()
                if model is not None and core is not None:
                    logger.info(f"MatAnyone loaded successfully with strategy {i}")
                    return model, core

            except Exception as e:
                logger.warning(f"MatAnyone strategy {i} failed: {e}")
                continue

        logger.error("All MatAnyone loading strategies failed")
        return None, None

    def _load_matanyone_strategy_1(self):
        """MatAnyone loading strategy 1: official HuggingFace InferenceCore."""
        from matanyone import InferenceCore

        # InferenceCore doubles as both the model handle and the inference core.
        processor = InferenceCore("PeiqingYang/MatAnyone")
        return processor, processor

    def _load_matanyone_strategy_2(self):
        """MatAnyone loading strategy 2: alternative import paths."""
        from matanyone import MatAnyOne
        from matanyone import InferenceCore

        cfg = OmegaConf.create({
            'model_name': 'matanyone',
            'device': str(self.device)
        })

        model = MatAnyOne(cfg)
        core = InferenceCore(model, cfg)

        return model, core

    def _load_matanyone_strategy_3(self):
        """MatAnyone loading strategy 3: repository-specific imports."""
        try:
            from matanyone.models.matanyone import MatAnyOneModel
            from matanyone.core import InferenceEngine
        except ImportError:
            from matanyone.src.models import MatAnyOneModel
            from matanyone.src.core import InferenceEngine

        config = {
            'model_path': None,
            'device': self.device,
            'precision': 'fp16' if self.device.type == 'cuda' else 'fp32'
        }

        model = MatAnyOneModel.from_pretrained(config)
        engine = InferenceEngine(model)

        return model, engine

    def _load_matanyone_strategy_4(self):
        """MatAnyone loading strategy 4: direct model class."""
        from matanyone.model.matanyone import MatAnyone

        model = MatAnyone.from_pretrained("not-lain/matanyone")
        return model, model
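
    # Extending the cascade only requires another zero-argument method that
    # returns a (model, core) pair, appended to import_strategies above.
    # Sketch of a hypothetical strategy 5, following the strategy-1 pattern
    # (the repo id is an assumption for illustration, not a real mirror):
    #
    #     def _load_matanyone_strategy_5(self):
    #         """Hypothetical: load from a mirrored HF repo."""
    #         from matanyone import InferenceCore
    #         processor = InferenceCore("your-org/MatAnyone-mirror")
    #         return processor, processor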

    def _cleanup_models(self):
        """Clean up loaded models and free memory."""
        if self.sam2_predictor is not None:
            del self.sam2_predictor
            self.sam2_predictor = None

        if self.matanyone_model is not None:
            del self.matanyone_model
            self.matanyone_model = None

        if self.matanyone_core is not None:
            del self.matanyone_core
            self.matanyone_core = None

        # Reclaim host and device memory now that the references are gone.
        self.memory_manager.cleanup_aggressive()
        gc.collect()

        logger.debug("Model cleanup completed")

    def cleanup(self):
        """Clean up all resources."""
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about loaded models.

        Returns:
            Dict with model information and statistics
        """
        info = {
            'models_loaded': self.loading_stats['models_loaded'],
            'sam2_loaded': self.sam2_predictor is not None,
            'matanyone_loaded': self.matanyone_model is not None,
            'device': str(self.device),
            'loading_stats': self.loading_stats.copy()
        }

        if self.sam2_predictor is not None:
            try:
                info['sam2_model_type'] = type(self.sam2_predictor).__name__
            except Exception:
                info['sam2_model_type'] = "Unknown"

        if self.matanyone_model is not None:
            try:
                info['matanyone_model_type'] = type(self.matanyone_model).__name__
            except Exception:
                info['matanyone_model_type'] = "Unknown"

        return info
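
    # Example payload (illustrative values only):
    #
    #     {'models_loaded': True, 'sam2_loaded': True, 'matanyone_loaded': True,
    #      'device': 'cuda', 'loading_stats': {...},
    #      'sam2_model_type': 'SAM2ImagePredictor',
    #      'matanyone_model_type': 'InferenceCore'}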

    def get_status(self) -> Dict[str, Any]:
        """Get model loader status (kept for backward compatibility)."""
        return self.get_model_info()

    def get_load_summary(self) -> str:
        """Get a human-readable summary of model loading."""
        if not self.loading_stats['models_loaded']:
            return "Models not loaded"

        sam2_time = self.loading_stats['sam2_load_time']
        matanyone_time = self.loading_stats['matanyone_load_time']
        total_time = self.loading_stats['total_load_time']

        summary = f"Models loaded successfully in {total_time:.1f}s\n"
        summary += f"SAM2: {sam2_time:.1f}s\n"
        summary += f"MatAnyone: {matanyone_time:.1f}s\n"
        summary += f"Device: {self.device}"

        return summary

    def validate_models(self) -> bool:
        """
        Validate that models are properly loaded.

        Note: this currently checks only that the model handles are present;
        it does not run a test inference.

        Returns:
            bool: True if models are valid
        """
        try:
            if not self.loading_stats['models_loaded']:
                return False

            if self.sam2_predictor is None or self.matanyone_model is None:
                return False

            logger.info("Model validation passed")
            return True

        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False
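
    # A deeper functional check would run a tiny inference -- sketch, assuming
    # the predictor is an official SAM2ImagePredictor:
    #
    #     import numpy as np
    #     dummy = np.zeros((64, 64, 3), dtype=np.uint8)
    #     loader.sam2_predictor.set_image(dummy)
    #     masks, scores, _ = loader.sam2_predictor.predict(
    #         point_coords=np.array([[32, 32]]),
    #         point_labels=np.array([1]),
    #     )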

    def reload_models(self, progress_callback: Optional[Callable] = None) -> Tuple[Any, Any]:
        """
        Reload all models (useful for error recovery).

        Args:
            progress_callback: Progress update callback

        Returns:
            Tuple of (sam2_predictor, matanyone_model)
        """
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats['models_loaded'] = False

        return self.load_all_models(progress_callback)

    @property
    def models_ready(self) -> bool:
        """Check whether all models are loaded and ready."""
        return (
            self.loading_stats['models_loaded'] and
            self.sam2_predictor is not None and
            self.matanyone_model is not None
        )
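

# Minimal standalone demo -- a sketch, assuming DeviceManager and MemoryManager
# can be constructed with no arguments (adapt to the real wiring of this app):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    loader = ModelLoader(device_manager.DeviceManager(), memory_manager.MemoryManager())
    sam2, matanyone = loader.load_all_models(
        progress_callback=lambda frac, msg: print(f"[{frac:4.0%}] {msg}")
    )

    print(loader.get_load_summary())
    print(f"Ready: {loader.models_ready}")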