MogensR committed on
Commit
45a250f
·
1 Parent(s): b1298b9

Create models/__init__.py

Browse files
Files changed (1) hide show
  1. models/__init__.py +326 -0
models/__init__.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BackgroundFX Pro Models Module.
3
+ Comprehensive model management, optimization, and deployment.
4
+ """
5
+
6
+ from .registry import (
7
+ ModelRegistry,
8
+ ModelInfo,
9
+ ModelStatus,
10
+ ModelTask,
11
+ ModelFramework
12
+ )
13
+
14
+ from .downloader import (
15
+ ModelDownloader,
16
+ DownloadStatus,
17
+ DownloadProgress
18
+ )
19
+
20
+ from .loader import (
21
+ ModelLoader,
22
+ LoadedModel
23
+ )
24
+
25
+ from .optimizer import (
26
+ ModelOptimizer,
27
+ OptimizationResult
28
+ )
29
+
30
__all__ = [
    # Registry
    'ModelRegistry',
    'ModelInfo',
    'ModelStatus',
    'ModelTask',
    'ModelFramework',

    # Downloader
    'ModelDownloader',
    'DownloadStatus',
    'DownloadProgress',

    # Loader
    'ModelLoader',
    'LoadedModel',

    # Optimizer
    'ModelOptimizer',
    'OptimizationResult',

    # High-level interface
    # ModelManager is defined below in this module; without it here,
    # `from models import *` would silently drop the main entry point.
    'ModelManager',
    'create_model_manager',
    'download_all_models',
    'optimize_for_deployment',
    'benchmark_models'
]

# Version
__version__ = '1.0.0'
60
+
61
+
62
class ModelManager:
    """
    High-level model management interface.

    Facade combining the registry, downloader, loader, and optimizer so
    callers can work with models through a single object.
    """

    def __init__(self, models_dir: str = None, device: str = 'auto'):
        """
        Initialize model manager.

        Args:
            models_dir: Directory for model storage; defaults to
                ~/.backgroundfx/models when not given.
            device: Device for model loading (forwarded to the loader).
        """
        from pathlib import Path

        self.models_dir = Path(models_dir) if models_dir else Path.home() / ".backgroundfx" / "models"
        self.device = device

        # Wire up the components; they all share the same registry.
        self.registry = ModelRegistry(self.models_dir)
        self.downloader = ModelDownloader(self.registry)
        self.loader = ModelLoader(self.registry, device=device)
        self.optimizer = ModelOptimizer(self.loader)

    def _best_model_id(self, task: str) -> "str | None":
        """Return the id of the best registered model for *task*, or None."""
        best = self.registry.get_best_model(ModelTask(task))
        return best.model_id if best else None

    def setup(self, task: str = None, download: bool = True) -> bool:
        """
        Setup models for a specific task.

        Args:
            task: Task type (segmentation, matting, etc.)
            download: Download missing models

        Returns:
            True if setup successful (always True when download is False)
        """
        if download:
            return self.downloader.download_required_models(task)
        return True

    def get_model(self, model_id: str = None, task: str = None) -> "LoadedModel | None":
        """
        Get a loaded model by ID or task.

        Args:
            model_id: Specific model ID (takes precedence over task)
            task: Task type used to pick the best registered model

        Returns:
            Loaded model, or None when neither id nor task resolves
        """
        if model_id:
            return self.loader.load_model(model_id)
        if task:
            resolved = self._best_model_id(task)
            if resolved:
                return self.loader.load_model(resolved)
        return None

    def predict(self, input_data, model_id: str = None, task: str = None, **kwargs):
        """
        Run prediction with a model.

        Args:
            input_data: Input data
            model_id: Model ID (resolved from task when omitted)
            task: Task type
            **kwargs: Additional arguments forwarded to the loader

        Returns:
            Prediction result, or None when no model could be resolved
        """
        if not model_id and task:
            model_id = self._best_model_id(task)
        if model_id:
            return self.loader.predict(model_id, input_data, **kwargs)
        return None

    def optimize(self, model_id: str, optimization_type: str = 'quantization', **kwargs):
        """
        Optimize a model.

        Args:
            model_id: Model to optimize
            optimization_type: Type of optimization
            **kwargs: Optimization parameters

        Returns:
            Optimization result
        """
        return self.optimizer.optimize_model(model_id, optimization_type, **kwargs)

    def benchmark(self, task: str = None) -> dict:
        """
        Benchmark available models.

        Args:
            task: Optional task filter

        Returns:
            Mapping of model id to registry metadata plus measured load time
        """
        results = {}

        models = self.registry.list_models()
        if task:
            wanted = ModelTask(task)
            models = [m for m in models if m.task == wanted]

        for model_info in models:
            # Only models already present on disk can be loaded and timed.
            if model_info.status != ModelStatus.AVAILABLE:
                continue
            loaded = self.loader.load_model(model_info.model_id)
            if not loaded:
                continue
            results[model_info.model_id] = {
                'name': model_info.name,
                'framework': model_info.framework.value,
                'size_mb': model_info.file_size / (1024 * 1024),
                'speed_fps': model_info.speed_fps,
                'accuracy': model_info.accuracy,
                'memory_mb': model_info.memory_mb,
                'load_time': loaded.load_time
            }

        return results

    def cleanup(self, days: int = 30):
        """
        Clean up unused models.

        Args:
            days: Days threshold for unused models

        Returns:
            List of removed models
        """
        return self.registry.cleanup_unused_models(days)

    def get_stats(self) -> dict:
        """Get model management statistics (registry, loader memory, downloads)."""
        return {
            'registry': self.registry.get_statistics(),
            'loader': self.loader.get_memory_usage(),
            'downloads': {
                model_id: progress.progress
                for model_id, progress in self.downloader.get_all_progress().items()
            }
        }
217
+
218
+
219
+ # Convenience functions
220
+
221
def create_model_manager(models_dir: str = None, device: str = 'auto') -> ModelManager:
    """
    Build a ready-to-use model manager.

    Args:
        models_dir: Directory for models (library default when None)
        device: Device for loading

    Returns:
        A freshly constructed ModelManager
    """
    manager = ModelManager(models_dir, device)
    return manager
233
+
234
+
235
def download_all_models(manager: "ModelManager" = None, force: bool = False) -> bool:
    """
    Download all available models.

    Args:
        manager: Model manager instance (a default one is created when None)
        force: Force re-download even for models already on disk

    Returns:
        True only if every download completed successfully
    """
    if not manager:
        manager = create_model_manager()

    models = manager.registry.list_models()
    model_ids = [m.model_id for m in models]

    # Kick off all downloads concurrently; each value is a future-like
    # object whose .result() reports per-model success.
    futures = manager.downloader.download_models_async(model_ids, force=force)

    success = True
    for model_id, future in futures.items():
        try:
            if not future.result():
                success = False
        except Exception:
            # A raising future marks the batch failed, but we keep draining
            # the remaining futures. The original bare `except:` would also
            # have swallowed KeyboardInterrupt/SystemExit; catch Exception only.
            success = False

    return success
263
+
264
+
265
def optimize_for_deployment(manager: ModelManager = None,
                            target: str = 'edge',
                            models: list = None) -> dict:
    """
    Optimize models for a deployment target.

    Args:
        manager: Model manager (a default one is created when None)
        target: Deployment target (edge, cloud, mobile)
        models: Specific model ids to optimize; every available model when falsy

    Returns:
        Mapping of model id to its optimization result
    """
    manager = manager or create_model_manager()

    # Pick the (optimization type, parameters) pair for the requested target.
    if target == 'edge':
        strategy, params = 'quantization', {'quantization_type': 'dynamic'}
    elif target == 'mobile':
        strategy, params = ('coreml' if manager.device == 'mps' else 'tflite'), {}
    elif target == 'cloud':
        strategy, params = ('tensorrt' if manager.device == 'cuda' else 'onnx'), {'fp16': True}
    else:
        strategy, params = 'onnx', {}

    # Fall back to everything currently available in the registry.
    if not models:
        models = [m.model_id
                  for m in manager.registry.list_models(status=ModelStatus.AVAILABLE)]

    # Optimize each model; only truthy results are reported back.
    results = {}
    for model_id in models:
        outcome = manager.optimize(model_id, strategy, **params)
        if outcome:
            results[model_id] = outcome

    return results
310
+
311
+
312
def benchmark_models(manager: ModelManager = None, task: str = None) -> dict:
    """
    Benchmark model performance.

    Args:
        manager: Model manager (a default one is created when None)
        task: Optional task filter

    Returns:
        Per-model benchmark results keyed by model id
    """
    active = manager if manager else create_model_manager()
    return active.benchmark(task)