import os
import json
import time
import hashlib
import pickle
import shutil
from pathlib import Path
from typing import Any, Optional, Dict, Union
from datetime import datetime, timedelta


class ModelCache:
    """Manages caching for AI models and generated content."""

    def __init__(self, cache_dir: Optional[Union[str, Path]] = None):
        if cache_dir is None:
            # Use HuggingFace Spaces persistent storage if available
            if os.path.exists("/data"):
                cache_dir = "/data/cache"
            else:
                cache_dir = Path.home() / ".cache" / "digipal"
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Cache subdirectories
        self.model_cache_dir = self.cache_dir / "models"
        self.generation_cache_dir = self.cache_dir / "generations"
        self.audio_cache_dir = self.cache_dir / "audio"
        for dir_path in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
            dir_path.mkdir(exist_ok=True)

        # Cache settings
        self.max_cache_size_gb = 10  # Maximum cache size in GB
        self.cache_expiry_days = 7   # Cache expiry in days
        self.generation_cache_enabled = True

        # In-memory cache for fast access
        self.memory_cache = {}
        self.cache_stats = self._load_cache_stats()

    def cache_model_weights(self, model_id: str, model_data: Any) -> bool:
        """Cache model weights to disk."""
        try:
            model_hash = self._get_hash(model_id)
            cache_path = self.model_cache_dir / f"{model_hash}.pkl"
            with open(cache_path, 'wb') as f:
                pickle.dump(model_data, f)
            # Update cache stats
            self._update_cache_stats('model', model_id, cache_path.stat().st_size)
            return True
        except Exception as e:
            print(f"Failed to cache model {model_id}: {e}")
            return False

    def get_cached_model(self, model_id: str) -> Optional[Any]:
        """Retrieve cached model weights."""
        try:
            model_hash = self._get_hash(model_id)
            cache_path = self.model_cache_dir / f"{model_hash}.pkl"
            if cache_path.exists():
                # Check that the cache entry has not expired
                if self._is_cache_valid(cache_path):
                    with open(cache_path, 'rb') as f:
                        return pickle.load(f)
            return None
        except Exception as e:
            print(f"Failed to load cached model {model_id}: {e}")
            return None

    def cache_generation(self, prompt: str, result: Dict[str, Any],
                         generation_type: str = "monster") -> str:
        """Cache generation results. Returns the cache key, or "" on failure."""
        if not self.generation_cache_enabled:
            return ""
        try:
            # Create a unique key for this generation, sharded by key prefix
            cache_key = self._get_generation_key(prompt, generation_type)
            cache_dir = self.generation_cache_dir / generation_type / cache_key[:2]
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_file = cache_dir / f"{cache_key}.json"

            # Work on a shallow copy so the caller's result dict is not
            # mutated when non-serializable values are swapped for file paths
            cached_result = dict(result)
            cache_data = {
                'prompt': prompt,
                'type': generation_type,
                'timestamp': datetime.now().isoformat(),
                'result': cached_result
            }

            # Persist non-JSON-serializable values as files
            if 'image' in result and hasattr(result['image'], 'save'):
                image_path = cache_dir / f"{cache_key}_image.png"
                result['image'].save(image_path)
                cached_result['image'] = str(image_path)
            if 'model_3d' in result and isinstance(result['model_3d'], str):
                # Copy the 3D model file into the cache
                model_ext = Path(result['model_3d']).suffix
                model_cache_path = cache_dir / f"{cache_key}_model{model_ext}"
                shutil.copy2(result['model_3d'], model_cache_path)
                cached_result['model_3d'] = str(model_cache_path)

            # Save cache data
            with open(cache_file, 'w') as f:
                json.dump(cache_data, f, indent=2)

            # Update stats
            self._update_cache_stats('generation', cache_key, cache_file.stat().st_size)
            return cache_key
        except Exception as e:
            print(f"Failed to cache generation: {e}")
            return ""

    def get_cached_generation(self, prompt: str, generation_type: str = "monster") -> Optional[Dict[str, Any]]:
        """Retrieve a cached generation if available."""
        if not self.generation_cache_enabled:
            return None
        try:
            cache_key = self._get_generation_key(prompt, generation_type)
            cache_file = self.generation_cache_dir / generation_type / cache_key[:2] / f"{cache_key}.json"
            if cache_file.exists() and self._is_cache_valid(cache_file):
                with open(cache_file, 'r') as f:
                    cache_data = json.load(f)
                # Rehydrate associated files; PIL is imported lazily so the
                # module does not hard-require it
                result = cache_data['result']
                if 'image' in result and isinstance(result['image'], str):
                    from PIL import Image
                    if os.path.exists(result['image']):
                        result['image'] = Image.open(result['image'])
                return result
            return None
        except Exception as e:
            print(f"Failed to load cached generation: {e}")
            return None

    def cache_audio_transcription(self, audio_path: str, transcription: str) -> bool:
        """Cache audio transcription results, keyed by a hash of the audio file."""
        try:
            # Hash the audio content so identical files share one cache entry
            with open(audio_path, 'rb') as f:
                audio_hash = hashlib.md5(f.read()).hexdigest()
            cache_file = self.audio_cache_dir / f"{audio_hash}.json"
            cache_data = {
                'audio_path': audio_path,
                'transcription': transcription,
                'timestamp': datetime.now().isoformat()
            }
            with open(cache_file, 'w') as f:
                json.dump(cache_data, f)
            return True
        except Exception as e:
            print(f"Failed to cache audio transcription: {e}")
            return False

    def get_cached_transcription(self, audio_path: str) -> Optional[str]:
        """Get a cached audio transcription."""
        try:
            with open(audio_path, 'rb') as f:
                audio_hash = hashlib.md5(f.read()).hexdigest()
            cache_file = self.audio_cache_dir / f"{audio_hash}.json"
            if cache_file.exists() and self._is_cache_valid(cache_file):
                with open(cache_file, 'r') as f:
                    cache_data = json.load(f)
                return cache_data['transcription']
            return None
        except Exception as e:
            print(f"Failed to load cached transcription: {e}")
            return None

    def add_to_memory_cache(self, key: str, value: Any, ttl_seconds: int = 300):
        """Add an item to the in-memory cache with a TTL."""
        expiry_time = time.time() + ttl_seconds
        self.memory_cache[key] = {
            'value': value,
            'expiry': expiry_time
        }

    def get_from_memory_cache(self, key: str) -> Optional[Any]:
        """Get an item from the in-memory cache, evicting it if expired."""
        if key in self.memory_cache:
            cache_item = self.memory_cache[key]
            if time.time() < cache_item['expiry']:
                return cache_item['value']
            else:
                # Remove expired item
                del self.memory_cache[key]
        return None

    def clear_expired_cache(self) -> int:
        """Clear expired cache entries; returns the number of bytes freed."""
        current_time = datetime.now()
        cleared_size = 0

        # Clear the on-disk caches
        for cache_type in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
            for file_path in cache_type.rglob('*'):
                if file_path.is_file():
                    file_age = current_time - datetime.fromtimestamp(file_path.stat().st_mtime)
                    if file_age > timedelta(days=self.cache_expiry_days):
                        file_size = file_path.stat().st_size
                        file_path.unlink()
                        cleared_size += file_size

        # Clear the in-memory cache
        expired_keys = [
            key for key, item in self.memory_cache.items()
            if time.time() > item['expiry']
        ]
        for key in expired_keys:
            del self.memory_cache[key]

        print(f"Cleared {cleared_size / (1024**2):.2f} MB of expired cache")
        return cleared_size

    def get_cache_size(self) -> Dict[str, float]:
        """Get current cache sizes in MB."""
        sizes = {
            'models': 0,
            'generations': 0,
            'audio': 0,
            'total': 0
        }

        # Sum file sizes per subdirectory (in bytes)
        for file_path in self.model_cache_dir.rglob('*'):
            if file_path.is_file():
                sizes['models'] += file_path.stat().st_size
        for file_path in self.generation_cache_dir.rglob('*'):
            if file_path.is_file():
                sizes['generations'] += file_path.stat().st_size
        for file_path in self.audio_cache_dir.rglob('*'):
            if file_path.is_file():
                sizes['audio'] += file_path.stat().st_size

        # Convert to MB
        for key in sizes:
            sizes[key] = sizes[key] / (1024 ** 2)
        sizes['total'] = sizes['models'] + sizes['generations'] + sizes['audio']
        return sizes

    def enforce_size_limit(self):
        """Enforce the cache size limit by removing the oldest entries first."""
        cache_size = self.get_cache_size()
        if cache_size['total'] > self.max_cache_size_gb * 1024:  # limit in MB
            # Collect all cache files with their sizes and timestamps
            all_files = []
            for cache_dir in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
                for file_path in cache_dir.rglob('*'):
                    if file_path.is_file():
                        all_files.append({
                            'path': file_path,
                            'size': file_path.stat().st_size,
                            'mtime': file_path.stat().st_mtime
                        })

            # Sort by modification time (oldest first)
            all_files.sort(key=lambda x: x['mtime'])

            # Remove files until the cache is under 80% of the limit
            current_size = cache_size['total'] * (1024 ** 2)          # bytes
            target_size = self.max_cache_size_gb * (1024 ** 3) * 0.8  # bytes
            for file_info in all_files:
                if current_size <= target_size:
                    break
                file_info['path'].unlink()
                current_size -= file_info['size']
                print(f"Removed {file_info['path'].name} to enforce cache limit")

    def _get_hash(self, text: str) -> str:
        """Get the MD5 hash of a string (used only as a cache key)."""
        return hashlib.md5(text.encode()).hexdigest()

    def _get_generation_key(self, prompt: str, generation_type: str) -> str:
        """Get a unique key for the generation cache."""
        combined = f"{generation_type}:{prompt}"
        return self._get_hash(combined)

    def _is_cache_valid(self, cache_path: Path) -> bool:
        """Check whether a cache file exists and has not expired."""
        if not cache_path.exists():
            return False
        file_age = datetime.now() - datetime.fromtimestamp(cache_path.stat().st_mtime)
        return file_age < timedelta(days=self.cache_expiry_days)

    def _load_cache_stats(self) -> Dict[str, Any]:
        """Load cache statistics from disk, or start with empty stats."""
        stats_file = self.cache_dir / "cache_stats.json"
        if stats_file.exists():
            with open(stats_file, 'r') as f:
                return json.load(f)
        return {
            'total_hits': 0,
            'total_misses': 0,
            'last_cleanup': datetime.now().isoformat(),
            'entries': {}
        }

    def _update_cache_stats(self, cache_type: str, key: str, size: int):
        """Record a cache entry and persist statistics to disk."""
        self.cache_stats['entries'][key] = {
            'type': cache_type,
            'size': size,
            'timestamp': datetime.now().isoformat()
        }
        # Save stats
        stats_file = self.cache_dir / "cache_stats.json"
        with open(stats_file, 'w') as f:
            json.dump(self.cache_stats, f, indent=2)

    def get_cache_info(self) -> Dict[str, Any]:
        """Get cache information and statistics."""
        sizes = self.get_cache_size()
        return {
            'sizes': sizes,
            'stats': self.cache_stats,
            'memory_cache_items': len(self.memory_cache),
            'cache_dir': str(self.cache_dir),
            'max_size_gb': self.max_cache_size_gb,
            'expiry_days': self.cache_expiry_days
        }
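

# A minimal usage sketch of ModelCache, assuming it runs outside HuggingFace
# Spaces (so the cache lands in ~/.cache/digipal). The prompt string and the
# cached payload below are illustrative assumptions, not values used by the
# class above.
if __name__ == "__main__":
    cache = ModelCache()

    # Look up a generation by prompt + type; cache it on a miss
    result = cache.get_cached_generation("a friendly dragon", "monster")
    if result is None:
        result = {"description": "a friendly dragon with green scales"}
        cache.cache_generation("a friendly dragon", result, "monster")

    # Short-lived in-memory caching for hot values
    cache.add_to_memory_cache("last_prompt", "a friendly dragon", ttl_seconds=60)
    print(cache.get_from_memory_cache("last_prompt"))

    # Housekeeping: drop expired entries, then trim to the size limit
    cache.clear_expired_cache()
    cache.enforce_size_limit()
    print(cache.get_cache_info()['sizes'])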