|
|
""" |
|
|
BackgroundFX Pro Models Module. |
|
|
Comprehensive model management, optimization, and deployment. |
|
|
""" |
|
|
|
|
|
from pathlib import Path
from typing import Optional

from .registry import (
    ModelRegistry,
    ModelInfo,
    ModelStatus,
    ModelTask,
    ModelFramework
)
from .downloader import (
    ModelDownloader,
    DownloadStatus,
    DownloadProgress
)
from .loaders.model_loader import (
    ModelLoader,
    LoadedModel
)
from .optimizer import (
    ModelOptimizer,
    OptimizationResult
)
|
|
|
|
|
# Public API of the models package. ModelManager is defined in this module
# and returned by create_model_manager, so it must be exported as well
# (it was previously missing, hiding it from `from ... import *`).
__all__ = [
    # Registry
    'ModelRegistry',
    'ModelInfo',
    'ModelStatus',
    'ModelTask',
    'ModelFramework',

    # Downloader
    'ModelDownloader',
    'DownloadStatus',
    'DownloadProgress',

    # Loader
    'ModelLoader',
    'LoadedModel',

    # Optimizer
    'ModelOptimizer',
    'OptimizationResult',

    # High-level interface (defined in this module)
    'ModelManager',
    'create_model_manager',
    'download_all_models',
    'optimize_for_deployment',
    'benchmark_models'
]

# Package version
__version__ = '1.0.0'
|
|
|
|
|
|
|
|
class ModelManager:
    """
    High-level model management interface.

    Facade combining the registry, downloader, loader, and optimizer
    behind one object so callers don't wire the components themselves.
    """

    def __init__(self, models_dir: str = None, device: str = 'auto'):
        """
        Initialize model manager.

        Args:
            models_dir: Directory for model storage; defaults to
                ~/.backgroundfx/models when not provided
            device: Device for model loading (e.g. 'auto', 'cpu', 'cuda', 'mps')
        """
        self.models_dir = Path(models_dir) if models_dir else Path.home() / ".backgroundfx" / "models"
        self.device = device

        # Wire the sub-components: downloader and loader share the registry;
        # the optimizer operates through the loader.
        self.registry = ModelRegistry(self.models_dir)
        self.downloader = ModelDownloader(self.registry)
        self.loader = ModelLoader(self.registry, device=device)
        self.optimizer = ModelOptimizer(self.loader)

    def setup(self, task: str = None, download: bool = True) -> bool:
        """
        Setup models for a specific task.

        Args:
            task: Task type (segmentation, matting, etc.)
            download: Download missing models

        Returns:
            True if setup successful
        """
        if download:
            return self.downloader.download_required_models(task)
        return True

    def get_model(self, model_id: str = None, task: str = None) -> Optional[LoadedModel]:
        """
        Get a loaded model by ID or task.

        Args:
            model_id: Specific model ID (takes precedence over task)
            task: Task type used to select the best registered model

        Returns:
            Loaded model, or None when neither lookup resolves a model
        """
        if model_id:
            return self.loader.load_model(model_id)
        if task:
            # ModelTask is already imported at module level from .registry;
            # no local re-import needed.
            best_model = self.registry.get_best_model(ModelTask(task))
            if best_model:
                return self.loader.load_model(best_model.model_id)
        return None

    def predict(self, input_data, model_id: str = None, task: str = None, **kwargs):
        """
        Run prediction with a model.

        Args:
            input_data: Input data
            model_id: Model ID; resolved from `task` when omitted
            task: Task type used to pick the best model
            **kwargs: Additional arguments forwarded to the loader

        Returns:
            Prediction result, or None when no model could be resolved
        """
        if not model_id and task:
            best_model = self.registry.get_best_model(ModelTask(task))
            if best_model:
                model_id = best_model.model_id

        if model_id:
            return self.loader.predict(model_id, input_data, **kwargs)
        return None

    def optimize(self, model_id: str, optimization_type: str = 'quantization', **kwargs):
        """
        Optimize a model.

        Args:
            model_id: Model to optimize
            optimization_type: Type of optimization
            **kwargs: Optimization parameters

        Returns:
            Optimization result
        """
        return self.optimizer.optimize_model(model_id, optimization_type, **kwargs)

    def benchmark(self, task: str = None) -> dict:
        """
        Benchmark available models.

        Loads every AVAILABLE model (optionally filtered by task) and
        records its registry metadata plus measured load time.

        Args:
            task: Optional task filter

        Returns:
            Mapping of model_id to a dict of benchmark metrics
        """
        results = {}

        models = self.registry.list_models()
        if task:
            task_enum = ModelTask(task)
            models = [m for m in models if m.task == task_enum]

        for model_info in models:
            if model_info.status == ModelStatus.AVAILABLE:
                loaded = self.loader.load_model(model_info.model_id)
                if loaded:
                    results[model_info.model_id] = {
                        'name': model_info.name,
                        'framework': model_info.framework.value,
                        'size_mb': model_info.file_size / (1024 * 1024),
                        'speed_fps': model_info.speed_fps,
                        'accuracy': model_info.accuracy,
                        'memory_mb': model_info.memory_mb,
                        'load_time': loaded.load_time
                    }

        return results

    def cleanup(self, days: int = 30):
        """
        Clean up unused models.

        Args:
            days: Days threshold for unused models

        Returns:
            List of removed models
        """
        return self.registry.cleanup_unused_models(days)

    def get_stats(self) -> dict:
        """Get model management statistics (registry, loader memory, downloads)."""
        return {
            'registry': self.registry.get_statistics(),
            'loader': self.loader.get_memory_usage(),
            'downloads': {
                model_id: progress.progress
                for model_id, progress in self.downloader.get_all_progress().items()
            }
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_model_manager(models_dir: str = None, device: str = 'auto') -> ModelManager:
    """
    Factory for a configured ModelManager.

    Args:
        models_dir: Directory where models are stored
        device: Device used when loading models

    Returns:
        A ready-to-use ModelManager instance
    """
    manager = ModelManager(models_dir, device)
    return manager
|
|
|
|
|
|
|
|
def download_all_models(manager: ModelManager = None, force: bool = False) -> bool:
    """
    Download all registered models.

    Kicks off asynchronous downloads for every model in the registry and
    waits for each future to complete.

    Args:
        manager: Model manager instance; a default one is created when omitted
        force: Force re-download even for models already present

    Returns:
        True if all downloads successful
    """
    if not manager:
        manager = create_model_manager()

    model_ids = [m.model_id for m in manager.registry.list_models()]

    futures = manager.downloader.download_models_async(model_ids, force=force)

    success = True
    for model_id, future in futures.items():
        try:
            if not future.result():
                success = False
        except Exception:
            # A raising future marks the batch as failed but does not abort
            # the remaining downloads. The original bare `except:` also
            # swallowed KeyboardInterrupt/SystemExit; catch Exception only.
            success = False

    return success
|
|
|
|
|
|
|
|
def optimize_for_deployment(manager: ModelManager = None,
                            target: str = 'edge',
                            models: list = None) -> dict:
    """
    Optimize models for deployment.

    Picks an optimization strategy based on the deployment target (and the
    manager's device for mobile/cloud) and applies it to each model.

    Args:
        manager: Model manager
        target: Deployment target (edge, cloud, mobile)
        models: Specific models to optimize; defaults to all AVAILABLE models

    Returns:
        Mapping of model_id to its optimization result
    """
    if not manager:
        manager = create_model_manager()

    # Select the optimization strategy for the requested target.
    if target == 'edge':
        optimization, opt_kwargs = 'quantization', {'quantization_type': 'dynamic'}
    elif target == 'mobile':
        mobile_fmt = 'coreml' if manager.device == 'mps' else 'tflite'
        optimization, opt_kwargs = mobile_fmt, {}
    elif target == 'cloud':
        cloud_fmt = 'tensorrt' if manager.device == 'cuda' else 'onnx'
        optimization, opt_kwargs = cloud_fmt, {'fp16': True}
    else:
        optimization, opt_kwargs = 'onnx', {}

    # Default to every model that is currently available locally.
    if not models:
        available = manager.registry.list_models(status=ModelStatus.AVAILABLE)
        models = [info.model_id for info in available]

    results = {}
    for model_id in models:
        outcome = manager.optimize(model_id, optimization, **opt_kwargs)
        if outcome:
            results[model_id] = outcome

    return results
|
|
|
|
|
|
|
|
def benchmark_models(manager: ModelManager = None, task: str = None) -> dict:
    """
    Benchmark model performance.

    Args:
        manager: Model manager; a default one is created when omitted
        task: Optional task filter

    Returns:
        Benchmark results keyed by model_id
    """
    active_manager = manager if manager else create_model_manager()
    return active_manager.benchmark(task)