# VideoBackgroundReplacer / cache_cleaner.py
# ============================================================================ #
# HARD CACHE CLEANER + WORKING SAM2 LOADER FOR HUGGINGFACE SPACES
# ============================================================================ #
import os
import gc
import sys
import shutil
import tempfile
import logging
import traceback
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
logger = logging.getLogger(__name__)
class HardCacheCleaner:
"""
Comprehensive cache cleaning system to resolve SAM2 loading issues
Clears Python module cache, HuggingFace cache, and temp files
"""
@staticmethod
def clean_all_caches(verbose: bool = True):
"""Clean all caches that might interfere with SAM2 loading"""
if verbose:
logger.info("🧹 Starting comprehensive cache cleanup...")
# 1. Clean Python module cache
HardCacheCleaner._clean_python_cache(verbose)
# 2. Clean HuggingFace cache
HardCacheCleaner._clean_huggingface_cache(verbose)
# 3. Clean PyTorch cache
HardCacheCleaner._clean_pytorch_cache(verbose)
# 4. Clean temp directories
HardCacheCleaner._clean_temp_directories(verbose)
# 5. Clear import cache
HardCacheCleaner._clear_import_cache(verbose)
# 6. Force garbage collection
HardCacheCleaner._force_gc_cleanup(verbose)
if verbose:
logger.info("βœ… Cache cleanup completed")
@staticmethod
def _clean_python_cache(verbose: bool = True):
"""Clean Python bytecode cache"""
try:
# Clear sys.modules cache for SAM2 related modules
sam2_modules = [key for key in sys.modules.keys() if 'sam2' in key.lower()]
for module in sam2_modules:
if verbose:
logger.info(f"πŸ—‘οΈ Removing cached module: {module}")
del sys.modules[module]
# Clear __pycache__ directories
for root, dirs, files in os.walk("."):
for dir_name in dirs[:]: # Use slice to modify list during iteration
if dir_name == "__pycache__":
cache_path = os.path.join(root, dir_name)
if verbose:
logger.info(f"πŸ—‘οΈ Removing __pycache__: {cache_path}")
shutil.rmtree(cache_path, ignore_errors=True)
dirs.remove(dir_name)
except Exception as e:
logger.warning(f"Python cache cleanup failed: {e}")
@staticmethod
def _clean_huggingface_cache(verbose: bool = True):
"""Clean HuggingFace model cache"""
try:
cache_paths = [
os.path.expanduser("~/.cache/huggingface/"),
os.path.expanduser("~/.cache/torch/"),
"./checkpoints/",
"./.cache/",
]
for cache_path in cache_paths:
if os.path.exists(cache_path):
if verbose:
logger.info(f"πŸ—‘οΈ Cleaning cache directory: {cache_path}")
# Remove SAM2 specific files
for root, dirs, files in os.walk(cache_path):
for file in files:
if any(pattern in file.lower() for pattern in ['sam2', 'segment-anything-2']):
file_path = os.path.join(root, file)
try:
os.remove(file_path)
if verbose:
logger.info(f"πŸ—‘οΈ Removed cached file: {file_path}")
except:
pass
for dir_name in dirs[:]:
if any(pattern in dir_name.lower() for pattern in ['sam2', 'segment-anything-2']):
dir_path = os.path.join(root, dir_name)
try:
shutil.rmtree(dir_path, ignore_errors=True)
if verbose:
logger.info(f"πŸ—‘οΈ Removed cached directory: {dir_path}")
dirs.remove(dir_name)
except:
pass
except Exception as e:
logger.warning(f"HuggingFace cache cleanup failed: {e}")
@staticmethod
def _clean_pytorch_cache(verbose: bool = True):
"""Clean PyTorch cache"""
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if verbose:
logger.info("πŸ—‘οΈ Cleared PyTorch CUDA cache")
except Exception as e:
logger.warning(f"PyTorch cache cleanup failed: {e}")
@staticmethod
def _clean_temp_directories(verbose: bool = True):
"""Clean temporary directories"""
try:
            # A set dedupes gettempdir() when it already points at /tmp
            temp_dirs = {tempfile.gettempdir(), "/tmp", "./tmp", "./temp"}
for temp_dir in temp_dirs:
if os.path.exists(temp_dir):
for item in os.listdir(temp_dir):
if 'sam2' in item.lower() or 'segment' in item.lower():
item_path = os.path.join(temp_dir, item)
try:
if os.path.isfile(item_path):
os.remove(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path, ignore_errors=True)
if verbose:
logger.info(f"πŸ—‘οΈ Removed temp item: {item_path}")
except:
pass
except Exception as e:
logger.warning(f"Temp directory cleanup failed: {e}")
@staticmethod
def _clear_import_cache(verbose: bool = True):
"""Clear Python import cache"""
try:
import importlib
# Invalidate import caches
importlib.invalidate_caches()
if verbose:
logger.info("πŸ—‘οΈ Cleared Python import cache")
except Exception as e:
logger.warning(f"Import cache cleanup failed: {e}")
@staticmethod
def _force_gc_cleanup(verbose: bool = True):
"""Force garbage collection"""
try:
collected = gc.collect()
if verbose:
logger.info(f"πŸ—‘οΈ Garbage collection freed {collected} objects")
except Exception as e:
logger.warning(f"Garbage collection failed: {e}")
class WorkingSAM2Loader:
    """
    SAM2 loader that prefers the HuggingFace Transformers integration, which
    sidesteps the local config-file lookups and custom CUDA compilation that
    commonly break the official SAM2 package on HuggingFace Spaces
    """
@staticmethod
def load_sam2_transformers_approach(device: str = "cuda", model_size: str = "large") -> Optional[Any]:
"""
Load SAM2 using HuggingFace Transformers integration
This method works reliably on HuggingFace Spaces
"""
try:
logger.info("πŸ€– Loading SAM2 via HuggingFace Transformers...")
# Model size mapping
model_map = {
"tiny": "facebook/sam2.1-hiera-tiny",
"small": "facebook/sam2.1-hiera-small",
"base": "facebook/sam2.1-hiera-base-plus",
"large": "facebook/sam2.1-hiera-large"
}
model_id = model_map.get(model_size, model_map["large"])
logger.info(f"Using model: {model_id}")
# Method 1: Using Transformers pipeline (most reliable for HF Spaces)
try:
from transformers import pipeline
sam2_pipeline = pipeline(
"mask-generation",
model=model_id,
device=0 if device == "cuda" else -1
)
logger.info("βœ… SAM2 loaded successfully via Transformers pipeline")
return sam2_pipeline
except Exception as e:
logger.warning(f"Pipeline approach failed: {e}")
# Method 2: Using SAM2 classes directly via Transformers
try:
from transformers import Sam2Processor, Sam2Model
processor = Sam2Processor.from_pretrained(model_id)
model = Sam2Model.from_pretrained(model_id).to(device)
logger.info("βœ… SAM2 loaded successfully via Transformers classes")
return {"model": model, "processor": processor}
except Exception as e:
logger.warning(f"Direct class approach failed: {e}")
# Method 3: Using official SAM2 with .from_pretrained()
try:
from sam2.sam2_image_predictor import SAM2ImagePredictor
predictor = SAM2ImagePredictor.from_pretrained(model_id)
logger.info("βœ… SAM2 loaded successfully via official from_pretrained")
return predictor
except Exception as e:
logger.warning(f"Official from_pretrained approach failed: {e}")
            logger.error("❌ All three SAM2 loading methods failed")
            return None
except Exception as e:
logger.error(f"All SAM2 loading methods failed: {e}")
return None
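
    # Illustrative sketch of consuming the three possible return types; the
    # `image` variable and the processor keyword arguments are assumptions,
    # not guaranteed by this module:
    #
    #   result = WorkingSAM2Loader.load_sam2_transformers_approach("cuda")
    #   if isinstance(result, dict):            # {"model": ..., "processor": ...}
    #       inputs = result["processor"](images=image, return_tensors="pt")
    #       outputs = result["model"](**inputs)
    #   elif hasattr(result, "set_image"):      # official SAM2ImagePredictor
    #       result.set_image(image)
    #   elif result is not None:                # Transformers mask-generation pipeline
    #       masks = result(image)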
@staticmethod
def load_sam2_fallback_approach(device: str = "cuda") -> Optional[Any]:
"""
Fallback approach using direct model loading
"""
try:
logger.info("πŸ”„ Trying fallback SAM2 loading approach...")
# Try the simplest possible approach
from huggingface_hub import hf_hub_download
import torch
            # Download the checkpoint directly (the sam2.1 repos use the
            # "sam2.1_" filename prefix for their checkpoints)
            checkpoint_path = hf_hub_download(
                repo_id="facebook/sam2.1-hiera-large",
                filename="sam2.1_hiera_large.pt"
            )
logger.info(f"Downloaded checkpoint to: {checkpoint_path}")
            # Load with minimal dependencies via the Transformers integration
            try:
                from transformers import Sam2Model
model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-large")
return model.to(device)
except Exception as e:
logger.warning(f"Transformers fallback failed: {e}")
return None
except Exception as e:
logger.error(f"Fallback loading failed: {e}")
return None
# ============================================================================ #
# INTEGRATED MODEL LOADER WITH CACHE CLEANING
# ============================================================================ #
def load_sam2_with_cache_cleanup(
device: str = "cuda",
model_size: str = "large",
force_cache_clean: bool = True,
verbose: bool = True
) -> Tuple[Optional[Any], str]:
"""
Load SAM2 with comprehensive cache cleanup
Returns:
Tuple of (model, status_message)
"""
status_messages = []
try:
# Step 1: Clean caches if requested
if force_cache_clean:
status_messages.append("🧹 Cleaning caches...")
HardCacheCleaner.clean_all_caches(verbose=verbose)
            status_messages.append("✅ Cache cleanup completed")
        # Step 2: Try primary loading method
        status_messages.append("🤖 Loading SAM2 (primary method)...")
        model = WorkingSAM2Loader.load_sam2_transformers_approach(device, model_size)
        if model is not None:
            status_messages.append("✅ SAM2 loaded successfully!")
            return model, "\n".join(status_messages)
        # Step 3: Try fallback method
        status_messages.append("🔄 Trying fallback loading method...")
        model = WorkingSAM2Loader.load_sam2_fallback_approach(device)
        if model is not None:
            status_messages.append("✅ SAM2 loaded successfully (fallback)!")
return model, "\n".join(status_messages)
# Step 4: All methods failed
status_messages.append("❌ All SAM2 loading methods failed")
return None, "\n".join(status_messages)
except Exception as e:
error_msg = f"❌ Critical error in SAM2 loading: {e}"
logger.error(f"{error_msg}\n{traceback.format_exc()}")
status_messages.append(error_msg)
return None, "\n".join(status_messages)
# ============================================================================ #
# USAGE EXAMPLE
# ============================================================================ #
if __name__ == "__main__":
    # Minimal smoke test: clean caches, then try each loading method in order
print("Testing SAM2 loader with cache cleanup...")
# Load SAM2 with full cache cleanup
model, status = load_sam2_with_cache_cleanup(
device="cuda",
model_size="large",
force_cache_clean=True,
verbose=True
)
print("Status:", status)
if model is not None:
print("SAM2 loaded successfully!")
print("Model type:", type(model))
else:
print("SAM2 loading failed completely")