| """ | |
| Model Persistence Manager for LightDiffusion | |
| Keeps models loaded in VRAM for instant reuse between generations | |
| """ | |
| from typing import Dict, Optional, Any, Tuple, List | |
| import logging | |
| from modules.Device import Device | |
| class ModelCache: | |
| """Global model cache to keep models loaded in VRAM""" | |
| def __init__(self): | |
| self._cached_models: Dict[str, Any] = {} | |
| self._cached_clip: Optional[Any] = None | |
| self._cached_vae: Optional[Any] = None | |
| self._cached_model_patcher: Optional[Any] = None | |
| self._cached_conditions: Dict[str, Any] = {} | |
| self._last_checkpoint_path: Optional[str] = None | |
| self._keep_models_loaded: bool = True | |
| self._loaded_models_list: List[Any] = [] | |
| def set_keep_models_loaded(self, keep_loaded: bool) -> None: | |
| """Enable or disable keeping models loaded in VRAM""" | |
| self._keep_models_loaded = keep_loaded | |
| if not keep_loaded: | |
| self.clear_cache() | |
| def get_keep_models_loaded(self) -> bool: | |
| """Check if models should be kept loaded""" | |
| return self._keep_models_loaded | |
| def cache_checkpoint( | |
| self, checkpoint_path: str, model_patcher: Any, clip: Any, vae: Any | |
| ) -> None: | |
| """Cache a loaded checkpoint""" | |
| if not self._keep_models_loaded: | |
| return | |
| self._last_checkpoint_path = checkpoint_path | |
| self._cached_model_patcher = model_patcher | |
| self._cached_clip = clip | |
| self._cached_vae = vae | |
| logging.info(f"Cached checkpoint: {checkpoint_path}") | |
| def get_cached_checkpoint( | |
| self, checkpoint_path: str | |
| ) -> Optional[Tuple[Any, Any, Any]]: | |
| """Get cached checkpoint if available""" | |
| if not self._keep_models_loaded: | |
| return None | |
| if ( | |
| self._last_checkpoint_path == checkpoint_path | |
| and self._cached_model_patcher is not None | |
| and self._cached_clip is not None | |
| and self._cached_vae is not None | |
| ): | |
| logging.info(f"Using cached checkpoint: {checkpoint_path}") | |
| return self._cached_model_patcher, self._cached_clip, self._cached_vae | |
| return None | |
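
    # Illustrative usage sketch (not part of the original module): the checkpoint
    # loader is expected to consult the cache before hitting disk, roughly like
    # this, where `load_checkpoint_from_disk` is a hypothetical stand-in for
    # whatever loader LightDiffusion actually uses:
    #
    #     cache = get_model_cache()
    #     cached = cache.get_cached_checkpoint(path)
    #     if cached is not None:
    #         model_patcher, clip, vae = cached
    #     else:
    #         model_patcher, clip, vae = load_checkpoint_from_disk(path)
    #         cache.cache_checkpoint(path, model_patcher, clip, vae)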
    def cache_sampling_models(self, models: List[Any]) -> None:
        """Cache models used during sampling"""
        if not self._keep_models_loaded:
            return
        self._loaded_models_list = models.copy()

    def get_cached_sampling_models(self) -> List[Any]:
        """Get cached sampling models"""
        if not self._keep_models_loaded:
            return []
        return self._loaded_models_list

    def prevent_model_cleanup(self, conds: Dict[str, Any], models: List[Any]) -> None:
        """Prevent models from being cleaned up if caching is enabled"""
        from modules.cond import cond_util

        if not self._keep_models_loaded:
            # Original cleanup behavior: unload additional models and control models
            cond_util.cleanup_additional_models(models)

            control_cleanup = []
            for k in conds:
                control_cleanup += cond_util.get_models_from_cond(conds[k], "control")
            cond_util.cleanup_additional_models(set(control_cleanup))
        else:
            # Keep models loaded - only cleanup control models that aren't main models
            control_cleanup = []
            for k in conds:
                control_cleanup += cond_util.get_models_from_cond(conds[k], "control")
            cond_util.cleanup_additional_models(set(control_cleanup))
            logging.info("Kept main models loaded in VRAM for reuse")

    def clear_cache(self) -> None:
        """Clear all cached models"""
        if self._cached_model_patcher is not None:
            try:
                # Properly unload the cached models
                if hasattr(self._cached_model_patcher, "model_unload"):
                    self._cached_model_patcher.model_unload()
            except Exception as e:
                logging.warning(f"Error unloading cached model: {e}")

        self._cached_models.clear()
        self._cached_clip = None
        self._cached_vae = None
        self._cached_model_patcher = None
        self._cached_conditions.clear()
        self._last_checkpoint_path = None
        self._loaded_models_list.clear()

        # Force cleanup
        Device.cleanup_models(keep_clone_weights_loaded=False)
        Device.soft_empty_cache(force=True)
        logging.info("Cleared model cache and freed VRAM")

    def get_memory_info(self) -> Dict[str, Any]:
        """Get memory usage information"""
        device = Device.get_torch_device()
        total_mem = Device.get_total_memory(device)
        free_mem = Device.get_free_memory(device)
        used_mem = total_mem - free_mem

        return {
            "total_vram": total_mem / (1024 * 1024 * 1024),  # GB
            "used_vram": used_mem / (1024 * 1024 * 1024),  # GB
            "free_vram": free_mem / (1024 * 1024 * 1024),  # GB
            "cached_models": len(self._cached_models),
            "keep_loaded": self._keep_models_loaded,
            "has_cached_checkpoint": self._cached_model_patcher is not None,
        }


# Global model cache instance
model_cache = ModelCache()


def get_model_cache() -> ModelCache:
    """Get the global model cache instance"""
    return model_cache


def set_keep_models_loaded(keep_loaded: bool) -> None:
    """Global function to enable/disable model persistence"""
    model_cache.set_keep_models_loaded(keep_loaded)


def get_keep_models_loaded() -> bool:
    """Global function to check if models should be kept loaded"""
    return model_cache.get_keep_models_loaded()


def clear_model_cache() -> None:
    """Global function to clear model cache"""
    model_cache.clear_cache()


def get_memory_info() -> Dict[str, Any]:
    """Global function to get memory info"""
    return model_cache.get_memory_info()
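

# Illustrative smoke test (an added sketch, not part of the original module):
# running this file directly inside the LightDiffusion environment should report
# VRAM usage and exercise the persistence toggle without loading any checkpoints.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    info = get_memory_info()
    print(
        f"VRAM: {info['used_vram']:.2f} / {info['total_vram']:.2f} GB used, "
        f"keep_loaded={info['keep_loaded']}"
    )

    # Disabling persistence immediately clears the cache and frees VRAM.
    set_keep_models_loaded(False)
    # Re-enable so subsequent generations cache their checkpoint again.
    set_keep_models_loaded(True)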