MogensR's picture
Update models/__init__.py
8205af3
raw
history blame
8.82 kB
"""
BackgroundFX Pro Models Module.
Comprehensive model management, optimization, and deployment.
"""
from typing import Optional

from .registry import (
    ModelRegistry,
    ModelInfo,
    ModelStatus,
    ModelTask,
    ModelFramework
)
from .downloader import (
    ModelDownloader,
    DownloadStatus,
    DownloadProgress
)
from .loaders.model_loader import (
    ModelLoader,
    LoadedModel
)
from .optimizer import (
    ModelOptimizer,
    OptimizationResult
)
__all__ = [
    # Registry
    'ModelRegistry',
    'ModelInfo',
    'ModelStatus',
    'ModelTask',
    'ModelFramework',
    # Downloader
    'ModelDownloader',
    'DownloadStatus',
    'DownloadProgress',
    # Loader
    'ModelLoader',
    'LoadedModel',
    # Optimizer
    'ModelOptimizer',
    'OptimizationResult',
    # High-level interface: the facade class is exported alongside the helper
    # functions that accept and return it.
    'ModelManager',
    'create_model_manager',
    'download_all_models',
    'optimize_for_deployment',
    'benchmark_models'
]

# Version
__version__ = '1.0.0'
class ModelManager:
    """
    High-level model management interface.

    A facade that wires together a ``ModelRegistry``, ``ModelDownloader``,
    ``ModelLoader`` and ``ModelOptimizer`` and forwards calls to them.
    """

    def __init__(self, models_dir: Optional[str] = None, device: str = 'auto'):
        """
        Initialize model manager.

        Args:
            models_dir: Directory for model storage. When omitted (or empty),
                ``~/.backgroundfx/models`` is used.
            device: Device string forwarded to ``ModelLoader``.
        """
        from pathlib import Path
        self.models_dir = Path(models_dir) if models_dir else Path.home() / ".backgroundfx" / "models"
        self.device = device
        # Components are layered: downloader and loader share the registry,
        # and the optimizer operates through the loader.
        self.registry = ModelRegistry(self.models_dir)
        self.downloader = ModelDownloader(self.registry)
        self.loader = ModelLoader(self.registry, device=device)
        self.optimizer = ModelOptimizer(self.loader)

    def _resolve_model_id(self, model_id: Optional[str], task: Optional[str]) -> Optional[str]:
        """
        Choose the model ID to use for a request.

        A truthy ``model_id`` always wins. Otherwise, when ``task`` is given,
        the registry's best model for that task is used.

        Returns:
            The chosen model ID, or ``None`` when nothing could be resolved.

        Raises:
            Whatever ``ModelTask(task)`` raises for an unknown task string
            (``ValueError`` if ``ModelTask`` is an ``Enum``).
        """
        if model_id:
            return model_id
        if not task:
            return None
        best_model = self.registry.get_best_model(ModelTask(task))
        return best_model.model_id if best_model else None

    def setup(self, task: Optional[str] = None, download: bool = True) -> bool:
        """
        Setup models for a specific task.

        Args:
            task: Task type (segmentation, matting, etc.), passed through
                unchanged to the downloader.
            download: Download missing models. When False nothing is done.

        Returns:
            The downloader's result when downloading, otherwise True.
        """
        if download:
            return self.downloader.download_required_models(task)
        return True

    def get_model(self, model_id: Optional[str] = None,
                  task: Optional[str] = None) -> Optional["LoadedModel"]:
        """
        Get a loaded model by ID or task.

        Args:
            model_id: Specific model ID (takes precedence over ``task``).
            task: Task type used to pick the registry's best model.

        Returns:
            The loaded model, or ``None`` if no model could be resolved.
        """
        resolved = self._resolve_model_id(model_id, task)
        if resolved is None:
            return None
        return self.loader.load_model(resolved)

    def predict(self, input_data, model_id: Optional[str] = None,
                task: Optional[str] = None, **kwargs):
        """
        Run prediction with a model.

        Args:
            input_data: Input data forwarded to the loader.
            model_id: Model ID (takes precedence over ``task``).
            task: Task type used to pick the registry's best model.
            **kwargs: Additional arguments forwarded to ``loader.predict``.

        Returns:
            The loader's prediction result, or ``None`` if no model could be
            resolved.
        """
        resolved = self._resolve_model_id(model_id, task)
        if resolved is None:
            return None
        return self.loader.predict(resolved, input_data, **kwargs)

    def optimize(self, model_id: str, optimization_type: str = 'quantization', **kwargs):
        """
        Optimize a model.

        Args:
            model_id: Model to optimize.
            optimization_type: Type of optimization.
            **kwargs: Optimization parameters forwarded to the optimizer.

        Returns:
            Whatever ``ModelOptimizer.optimize_model`` returns.
        """
        return self.optimizer.optimize_model(model_id, optimization_type, **kwargs)

    def benchmark(self, task: Optional[str] = None) -> dict:
        """
        Report metrics for available models.

        Every model whose status is ``AVAILABLE`` (optionally filtered by
        task) is loaded through the loader. Speed, accuracy and memory values
        come from the registry metadata; only ``load_time`` comes from the
        loaded model. No inference is run here.

        Args:
            task: Optional task filter.

        Returns:
            Mapping of model ID to a dict of metrics.
        """
        models = self.registry.list_models()
        if task:
            task_enum = ModelTask(task)
            models = [m for m in models if m.task == task_enum]

        results = {}
        for model_info in models:
            if model_info.status != ModelStatus.AVAILABLE:
                continue
            loaded = self.loader.load_model(model_info.model_id)
            if not loaded:
                continue
            results[model_info.model_id] = {
                'name': model_info.name,
                'framework': model_info.framework.value,
                'size_mb': model_info.file_size / (1024 * 1024),
                'speed_fps': model_info.speed_fps,
                'accuracy': model_info.accuracy,
                'memory_mb': model_info.memory_mb,
                'load_time': loaded.load_time
            }
        return results

    def cleanup(self, days: int = 30):
        """
        Clean up unused models.

        Args:
            days: Days threshold for unused models.

        Returns:
            The result of ``registry.cleanup_unused_models(days)`` (expected
            to be the list of removed models).
        """
        return self.registry.cleanup_unused_models(days)

    def get_stats(self) -> dict:
        """
        Get model management statistics.

        Returns:
            Dict with registry statistics, loader memory usage, and a mapping
            of model ID to its download ``progress`` value.
        """
        return {
            'registry': self.registry.get_statistics(),
            'loader': self.loader.get_memory_usage(),
            'downloads': {
                model_id: progress.progress
                for model_id, progress in self.downloader.get_all_progress().items()
            }
        }
# Convenience functions
def create_model_manager(models_dir: str = None, device: str = 'auto') -> ModelManager:
    """
    Build a new ModelManager.

    Args:
        models_dir: Storage directory for models; ``None`` lets the manager
            choose its default location.
        device: Device string passed through to the manager.

    Returns:
        A freshly constructed manager.
    """
    manager = ModelManager(models_dir=models_dir, device=device)
    return manager
def download_all_models(manager: Optional["ModelManager"] = None, force: bool = False) -> bool:
    """
    Download every model known to the registry.

    Args:
        manager: Model manager instance; a default one is created when omitted.
        force: Force re-download (forwarded to ``download_models_async``).

    Returns:
        True only if every download future returned a truthy result; False if
        any returned a falsy value or raised an ``Exception``.
    """
    if manager is None:
        manager = create_model_manager()
    model_ids = [info.model_id for info in manager.registry.list_models()]
    futures = manager.downloader.download_models_async(model_ids, force=force)

    success = True
    # result() is called on every future (no early exit), so one failure does
    # not stop the remaining downloads from being collected.
    for future in futures.values():
        try:
            if not future.result():
                success = False
        except Exception:
            # A failed download just marks the batch unsuccessful. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
            # abort the wait instead of being silently swallowed.
            success = False
    return success
def _deployment_strategy(target: str, device: str) -> tuple:
    """Return the ``(optimization_type, kwargs)`` pair for a deployment target."""
    # NOTE(review): device is compared literally. With the ModelManager
    # default device='auto', the 'mps'/'cuda' branches are never taken;
    # confirm whether a resolved device should be used here instead.
    if target == 'edge':
        return 'quantization', {'quantization_type': 'dynamic'}
    if target == 'mobile':
        return ('coreml' if device == 'mps' else 'tflite'), {}
    if target == 'cloud':
        return ('tensorrt' if device == 'cuda' else 'onnx'), {'fp16': True}
    # Unknown targets fall back to the 'onnx' optimization with no options.
    return 'onnx', {}


def optimize_for_deployment(manager: Optional["ModelManager"] = None,
                            target: str = 'edge',
                            models: Optional[list] = None) -> dict:
    """
    Optimize models for deployment.

    Args:
        manager: Model manager; a default one is created when omitted.
        target: Deployment target (edge, cloud, mobile); anything else falls
            back to 'onnx'.
        models: Specific model IDs to optimize. ``None`` means every model the
            registry reports as AVAILABLE; an empty list optimizes nothing.

    Returns:
        Mapping of model ID to its (truthy) optimization result. Models whose
        optimization returned a falsy value are omitted.
    """
    if manager is None:
        manager = create_model_manager()

    optimization, kwargs = _deployment_strategy(target, manager.device)

    # Only fall back to "all available" when no selection was given at all;
    # an explicit empty list must not trigger optimizing every model.
    if models is None:
        available = manager.registry.list_models(status=ModelStatus.AVAILABLE)
        models = [m.model_id for m in available]

    results = {}
    for model_id in models:
        result = manager.optimize(model_id, optimization, **kwargs)
        if result:
            results[model_id] = result
    return results
def benchmark_models(manager: ModelManager = None, task: str = None) -> dict:
    """
    Convenience wrapper that runs ModelManager.benchmark.

    Args:
        manager: Manager to benchmark with; a default one is created when
            omitted.
        task: Optional task filter passed straight to ``benchmark``.

    Returns:
        The per-model metrics dict produced by ``benchmark``.
    """
    active_manager = manager if manager else create_model_manager()
    return active_manager.benchmark(task)