""" Model registry for BackgroundFX Pro. Manages available models, versions, and metadata. """ import json import hashlib from pathlib import Path from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass, field, asdict from enum import Enum from datetime import datetime import requests import yaml import logging logger = logging.getLogger(__name__) class ModelStatus(Enum): """Model availability status.""" AVAILABLE = "available" DOWNLOADING = "downloading" NOT_DOWNLOADED = "not_downloaded" CORRUPTED = "corrupted" DEPRECATED = "deprecated" class ModelTask(Enum): """Model task types.""" SEGMENTATION = "segmentation" MATTING = "matting" ENHANCEMENT = "enhancement" DETECTION = "detection" BACKGROUND_GEN = "background_generation" class ModelFramework(Enum): """Supported frameworks.""" PYTORCH = "pytorch" ONNX = "onnx" TENSORRT = "tensorrt" COREML = "coreml" TFLITE = "tflite" @dataclass class ModelInfo: """Model information and metadata.""" # Basic info model_id: str name: str version: str task: ModelTask framework: ModelFramework # Files and URLs url: str mirror_urls: List[str] = field(default_factory=list) filename: str = "" file_size: int = 0 sha256: Optional[str] = None # Model details description: str = "" author: str = "" license: str = "" paper_url: Optional[str] = None github_url: Optional[str] = None # Performance metrics accuracy: Optional[float] = None speed_fps: Optional[float] = None memory_mb: Optional[int] = None # Requirements min_gpu_memory_gb: float = 0 min_ram_gb: float = 2 requires_gpu: bool = False supported_platforms: List[str] = field(default_factory=lambda: ["windows", "linux", "macos"]) # Configuration input_size: Optional[Tuple[int, int]] = None batch_size: int = 1 config: Dict[str, Any] = field(default_factory=dict) # Status status: ModelStatus = ModelStatus.NOT_DOWNLOADED local_path: Optional[str] = None download_date: Optional[datetime] = None last_used: Optional[datetime] = None use_count: int = 0 def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" data = asdict(self) # Convert enums to strings data['task'] = self.task.value data['framework'] = self.framework.value data['status'] = self.status.value # Convert datetime to ISO format if self.download_date: data['download_date'] = self.download_date.isoformat() if self.last_used: data['last_used'] = self.last_used.isoformat() return data @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo': """Create from dictionary.""" # Convert string enums if 'task' in data: data['task'] = ModelTask(data['task']) if 'framework' in data: data['framework'] = ModelFramework(data['framework']) if 'status' in data: data['status'] = ModelStatus(data['status']) # Convert ISO strings to datetime if 'download_date' in data and data['download_date']: data['download_date'] = datetime.fromisoformat(data['download_date']) if 'last_used' in data and data['last_used']: data['last_used'] = datetime.fromisoformat(data['last_used']) return cls(**data) class ModelRegistry: """Central registry for all available models.""" # Default model definitions DEFAULT_MODELS = { "rmbg-1.4": ModelInfo( model_id="rmbg-1.4", name="RMBG v1.4", version="1.4", task=ModelTask.SEGMENTATION, framework=ModelFramework.ONNX, url="https://huggingface.co/briaai/RMBG-1.4/resolve/main/model.onnx", filename="rmbg_v1.4.onnx", file_size=176_000_000, # ~176MB sha256="d0c3e8c7d98e32b9c30e0c8f228e3c6d1a5e5c8e9f0a1b2c3d4e5f6a7b8c9d0e1", description="State-of-the-art background removal model", author="BRIA AI", license="BRIA 


class ModelRegistry:
    """Central registry for all available models."""

    # Default model definitions
    DEFAULT_MODELS = {
        "rmbg-1.4": ModelInfo(
            model_id="rmbg-1.4",
            name="RMBG v1.4",
            version="1.4",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.ONNX,
            url="https://huggingface.co/briaai/RMBG-1.4/resolve/main/model.onnx",
            filename="rmbg_v1.4.onnx",
            file_size=176_000_000,  # ~176MB
            sha256=None,  # TODO: pin the published SHA-256 digest (verification is skipped while unset)
            description="State-of-the-art background removal model",
            author="BRIA AI",
            license="BRIA RMBG-1.4 Community License",
            github_url="https://github.com/bria-ai/RMBG-1.4",
            accuracy=0.98,
            speed_fps=30,
            memory_mb=500,
            requires_gpu=False,
            input_size=(1024, 1024)
        ),
        "u2net": ModelInfo(
            model_id="u2net",
            name="U2-Net",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/xuebinqin/U-2-Net/releases/download/v1.0/u2net.pth",
            filename="u2net.pth",
            file_size=176_000_000,
            description="Salient object detection for background removal",
            author="Xuebin Qin et al.",
            license="Apache 2.0",
            paper_url="https://arxiv.org/abs/2005.09007",
            accuracy=0.95,
            speed_fps=20,
            memory_mb=800,
            requires_gpu=True,
            input_size=(320, 320)
        ),
        "u2netp": ModelInfo(
            model_id="u2netp",
            name="U2-Net Lite",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/xuebinqin/U-2-Net/releases/download/v1.0/u2netp.pth",
            filename="u2netp.pth",
            file_size=4_700_000,  # ~4.7MB
            description="Lightweight version of U2-Net",
            author="Xuebin Qin et al.",
            license="Apache 2.0",
            accuracy=0.92,
            speed_fps=40,
            memory_mb=200,
            requires_gpu=False,
            input_size=(320, 320)
        ),
        "isnet": ModelInfo(
            model_id="isnet",
            name="IS-Net",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet.pth",
            filename="isnet.pth",
            file_size=450_000_000,
            description="Highly accurate salient object detection",
            author="Xuebin Qin et al.",
            license="Apache 2.0",
            paper_url="https://arxiv.org/abs/2203.03041",
            accuracy=0.97,
            speed_fps=15,
            memory_mb=1200,
            requires_gpu=True,
            min_gpu_memory_gb=4,
            input_size=(1024, 1024)
        ),
        "modnet": ModelInfo(
            model_id="modnet",
            name="MODNet",
            version="1.0",
            task=ModelTask.MATTING,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/ZHKKKe/MODNet/releases/download/v1.0/modnet_photographic_portrait_matting.ckpt",
            filename="modnet.ckpt",
            file_size=25_000_000,
            description="Trimap-free portrait matting",
            author="Zhanghan Ke et al.",
            license="CC BY-NC 4.0",
            paper_url="https://arxiv.org/abs/2011.11961",
            github_url="https://github.com/ZHKKKe/MODNet",
            accuracy=0.94,
            speed_fps=25,
            memory_mb=400,
            requires_gpu=False,
            input_size=(512, 512)
        ),
        "robust_video_matting": ModelInfo(
            model_id="robust_video_matting",
            name="Robust Video Matting",
            version="1.0",
            task=ModelTask.MATTING,
            framework=ModelFramework.ONNX,
            url="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.onnx",
            filename="rvm_mobilenetv3.onnx",
            file_size=14_000_000,
            description="Temporally coherent video matting",
            author="Shanchuan Lin et al.",
            license="GPL-3.0",
            paper_url="https://arxiv.org/abs/2108.11515",
            github_url="https://github.com/PeterL1n/RobustVideoMatting",
            accuracy=0.93,
            speed_fps=30,
            memory_mb=300,
            requires_gpu=False,
            config={"temporal": True, "recurrent": True}
        ),
        "selfie_segmentation": ModelInfo(
            model_id="selfie_segmentation",
            name="MediaPipe Selfie Segmentation",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.TFLITE,
            url="https://storage.googleapis.com/mediapipe-models/selfie_segmentation/selfie_segmentation.tflite",
            filename="selfie_segmentation.tflite",
            file_size=260_000,  # ~260KB
            description="Ultra-lightweight real-time segmentation",
            author="Google MediaPipe",
            license="Apache 2.0",
            accuracy=0.88,
            speed_fps=60,
            memory_mb=50,
            requires_gpu=False,
            input_size=(256, 256)
        )
    }

    def __init__(self,
                 models_dir: Optional[Path] = None,
                 config_file: Optional[Path] = None):
        """
        Initialize model registry.

        Args:
            models_dir: Directory to store downloaded models
            config_file: Optional config file with custom models
        """
        self.models_dir = models_dir or Path.home() / ".backgroundfx" / "models"
        self.models_dir.mkdir(parents=True, exist_ok=True)

        self.registry_file = self.models_dir / "registry.json"
        self.models: Dict[str, ModelInfo] = {}

        # Load registry
        self._load_registry()

        # Load custom config if provided
        if config_file:
            self._load_custom_config(config_file)

        # Update model status
        self._update_model_status()

    def _load_registry(self):
        """Load model registry from file or create default."""
        if self.registry_file.exists():
            try:
                with open(self.registry_file, 'r') as f:
                    data = json.load(f)
                for model_id, model_data in data.items():
                    self.models[model_id] = ModelInfo.from_dict(model_data)
                logger.info(f"Loaded {len(self.models)} models from registry")
            except Exception as e:
                logger.error(f"Failed to load registry: {e}")
                self._initialize_default_registry()
        else:
            self._initialize_default_registry()

    def _initialize_default_registry(self):
        """Initialize with default models."""
        # Deep-copy so per-instance status changes never mutate the class-level defaults
        self.models = copy.deepcopy(self.DEFAULT_MODELS)
        self._save_registry()
        logger.info("Initialized registry with default models")

    def _save_registry(self):
        """Save registry to file."""
        try:
            data = {
                model_id: model.to_dict()
                for model_id, model in self.models.items()
            }
            with open(self.registry_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save registry: {e}")
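
    # Expected custom-config layout for `_load_custom_config` below, shown as an
    # illustrative YAML sketch (keys mirror ModelInfo fields; values are placeholders):
    #
    #     models:
    #       - model_id: my-custom-model
    #         name: My Custom Model
    #         version: "1.0"
    #         task: segmentation          # a ModelTask value
    #         framework: onnx             # a ModelFramework value
    #         url: https://example.com/my_model.onnx
    #         filename: my_model.onnx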
    def _load_custom_config(self, config_file: Path):
        """Load custom model configurations."""
        try:
            with open(config_file, 'r') as f:
                if config_file.suffix in ('.yaml', '.yml'):
                    config = yaml.safe_load(f)
                else:
                    config = json.load(f)

            for model_data in config.get('models', []):
                model = ModelInfo.from_dict(model_data)
                self.models[model.model_id] = model
                logger.info(f"Added custom model: {model.name}")

            self._save_registry()
        except Exception as e:
            logger.error(f"Failed to load custom config: {e}")

    def _update_model_status(self):
        """Update status of all models based on local files."""
        for model_id, model in self.models.items():
            model_path = self.models_dir / model.filename
            if model_path.exists():
                # Verify file integrity
                if self._verify_model_file(model_path, model):
                    model.status = ModelStatus.AVAILABLE
                    model.local_path = str(model_path)
                else:
                    model.status = ModelStatus.CORRUPTED
                    logger.warning(f"Model {model_id} file is corrupted")
            else:
                model.status = ModelStatus.NOT_DOWNLOADED
                model.local_path = None

    def _verify_model_file(self, file_path: Path, model: ModelInfo) -> bool:
        """Verify model file integrity."""
        # Check file size
        if model.file_size > 0:
            actual_size = file_path.stat().st_size
            if abs(actual_size - model.file_size) > 1000:  # Allow 1KB difference
                logger.warning(f"Size mismatch for {model.model_id}: "
                               f"expected {model.file_size}, got {actual_size}")
                return False

        # Check SHA256 if available
        if model.sha256:
            try:
                sha256 = self._calculate_sha256(file_path)
                if sha256 != model.sha256:
                    logger.warning(f"SHA256 mismatch for {model.model_id}")
                    return False
            except Exception as e:
                logger.error(f"Failed to verify SHA256: {e}")
                return False

        return True

    def _calculate_sha256(self, file_path: Path) -> str:
        """Calculate SHA256 hash of file."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()
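
    # To pin a digest for a model file you trust (illustrative sketch; `registry`
    # and the path below are placeholders), compute it with the helper above:
    #
    #     path = Path("~/.backgroundfx/models/u2net.pth").expanduser()
    #     registry.models["u2net"].sha256 = registry._calculate_sha256(path)
    #     registry._save_registry()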
    def register_model(self, model: ModelInfo) -> bool:
        """
        Register a new model.

        Args:
            model: Model information

        Returns:
            True if registered successfully
        """
        try:
            self.models[model.model_id] = model
            self._save_registry()
            logger.info(f"Registered model: {model.name}")
            return True
        except Exception as e:
            logger.error(f"Failed to register model: {e}")
            return False

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Get model information by ID."""
        return self.models.get(model_id)

    def list_models(self,
                    task: Optional[ModelTask] = None,
                    framework: Optional[ModelFramework] = None,
                    status: Optional[ModelStatus] = None) -> List[ModelInfo]:
        """
        List models with optional filtering.

        Args:
            task: Filter by task type
            framework: Filter by framework
            status: Filter by status

        Returns:
            List of matching models
        """
        models = list(self.models.values())

        if task:
            models = [m for m in models if m.task == task]
        if framework:
            models = [m for m in models if m.framework == framework]
        if status:
            models = [m for m in models if m.status == status]

        return models

    def get_best_model(self,
                       task: ModelTask,
                       prefer_speed: bool = False,
                       require_gpu: Optional[bool] = None) -> Optional[ModelInfo]:
        """
        Get the best model for a task.

        Args:
            task: Task type
            prefer_speed: Prefer speed over accuracy
            require_gpu: GPU requirement

        Returns:
            Best matching model, or None if nothing qualifies
        """
        candidates = self.list_models(task=task, status=ModelStatus.AVAILABLE)

        if require_gpu is not None:
            candidates = [m for m in candidates if m.requires_gpu == require_gpu]

        if not candidates:
            return None

        # Sort by preference
        if prefer_speed:
            candidates.sort(key=lambda m: m.speed_fps or 0, reverse=True)
        else:
            candidates.sort(key=lambda m: m.accuracy or 0, reverse=True)

        return candidates[0]

    def update_model_usage(self, model_id: str):
        """Update model usage statistics."""
        if model_id in self.models:
            model = self.models[model_id]
            model.use_count += 1
            model.last_used = datetime.now()
            self._save_registry()

    def get_total_size(self, status: Optional[ModelStatus] = None) -> int:
        """Get total size of models in bytes."""
        models = self.list_models(status=status)
        return sum(m.file_size for m in models)
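
    # Typical selection flow (illustrative; assumes at least one segmentation
    # model has already been downloaded and marked AVAILABLE):
    #
    #     registry = ModelRegistry()
    #     model = registry.get_best_model(ModelTask.SEGMENTATION, prefer_speed=True)
    #     if model is not None:
    #         registry.update_model_usage(model.model_id)
    #         run_inference(model.local_path)  # hypothetical downstream call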
    def cleanup_unused_models(self, days: int = 30) -> List[str]:
        """
        Remove models not used in the specified number of days.

        Args:
            days: Days threshold

        Returns:
            List of removed model IDs
        """
        removed = []
        cutoff = datetime.now().timestamp() - (days * 86400)

        for model_id, model in self.models.items():
            if (model.status == ModelStatus.AVAILABLE and
                    model.last_used and
                    model.last_used.timestamp() < cutoff):
                # Delete the file
                if model.local_path:
                    try:
                        Path(model.local_path).unlink()
                        model.status = ModelStatus.NOT_DOWNLOADED
                        model.local_path = None
                        removed.append(model_id)
                        logger.info(f"Removed unused model: {model_id}")
                    except Exception as e:
                        logger.error(f"Failed to remove model {model_id}: {e}")

        if removed:
            self._save_registry()

        return removed

    def export_registry(self, output_file: Path):
        """Export registry to file."""
        data = {
            'version': '1.0',
            'models': [model.to_dict() for model in self.models.values()]
        }

        with open(output_file, 'w') as f:
            if output_file.suffix in ('.yaml', '.yml'):
                yaml.dump(data, f, default_flow_style=False)
            else:
                json.dump(data, f, indent=2)

    def get_statistics(self) -> Dict[str, Any]:
        """Get registry statistics."""
        total_models = len(self.models)
        downloaded = len([m for m in self.models.values()
                          if m.status == ModelStatus.AVAILABLE])

        task_counts = {}
        for task in ModelTask:
            count = len([m for m in self.models.values() if m.task == task])
            if count > 0:
                task_counts[task.value] = count

        return {
            'total_models': total_models,
            'downloaded_models': downloaded,
            'total_size_mb': self.get_total_size() / (1024 * 1024),
            'downloaded_size_mb': self.get_total_size(ModelStatus.AVAILABLE) / (1024 * 1024),
            'models_by_task': task_counts,
            'most_used': sorted(
                [(m.model_id, m.use_count) for m in self.models.values()],
                key=lambda x: x[1],
                reverse=True
            )[:5]
        }
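

if __name__ == "__main__":
    # Smoke-test sketch using only this module: builds (or loads) a registry
    # under ~/.backgroundfx/models and prints a summary. Nothing is downloaded.
    logging.basicConfig(level=logging.INFO)
    registry = ModelRegistry()
    stats = registry.get_statistics()
    print(f"Models registered: {stats['total_models']} "
          f"({stats['downloaded_models']} downloaded)")
    for info in registry.list_models(task=ModelTask.SEGMENTATION):
        print(f"  {info.model_id}: {info.name} [{info.status.value}]")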