MogensR committed on
Commit d42af6c · 1 Parent(s): 39ebfbd

Create models/loader.py

Files changed (1):
  models/models/loader.py +515 -0
models/models/loader.py ADDED
@@ -0,0 +1,515 @@
+"""
+Model loader for BackgroundFX Pro.
+Handles loading, initialization, and management of ML models.
+"""
+
+import torch
+import torch.nn as nn
+import onnxruntime as ort
+import numpy as np
+from pathlib import Path
+from typing import Dict, Optional, Any, Union, List, Tuple
+from dataclasses import dataclass
+import logging
+import gc
+import psutil
+from functools import lru_cache
+
+from .registry import ModelInfo, ModelFramework, ModelTask, ModelRegistry
+from .downloader import ModelDownloader
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LoadedModel:
+    """Container for loaded model."""
+    model_id: str
+    model: Any  # Actual model object
+    framework: ModelFramework
+    device: str
+    memory_usage: int  # In bytes
+    load_time: float  # In seconds
+    metadata: Dict[str, Any]
+
+
+class ModelLoader:
+    """
+    Load and manage ML models with automatic memory management.
+    """
+
+    def __init__(self,
+                 registry: ModelRegistry,
+                 device: Optional[str] = None,
+                 max_memory_gb: float = 4.0,
+                 enable_cache: bool = True):
+        """
+        Initialize model loader.
+
+        Args:
+            registry: Model registry instance
+            device: Device to load models on ('cuda', 'cpu', 'auto')
+            max_memory_gb: Maximum memory usage in GB
+            enable_cache: Enable model caching
+        """
+        self.registry = registry
+        self.downloader = ModelDownloader(registry)
+        self.max_memory_bytes = int(max_memory_gb * 1024 * 1024 * 1024)
+        self.enable_cache = enable_cache
+
+        # Device management
+        self.device = self._setup_device(device)
+        self.providers = self._setup_providers()
+
+        # Model cache
+        self.loaded_models: Dict[str, LoadedModel] = {}
+        self.current_memory_usage = 0
+
+        logger.info(f"ModelLoader initialized with device: {self.device}")
+
+    def _setup_device(self, device: Optional[str]) -> str:
+        """Setup computation device."""
+        if device == 'auto' or device is None:
+            if torch.cuda.is_available():
+                return 'cuda'
+            elif torch.backends.mps.is_available():
+                return 'mps'
+            else:
+                return 'cpu'
+        return device
+
+    def _setup_providers(self) -> List[str]:
+        """Setup ONNX Runtime providers."""
+        providers = []
+
+        if self.device == 'cuda':
+            providers.extend([
+                'CUDAExecutionProvider',
+                'TensorrtExecutionProvider'
+            ])
+        elif self.device == 'mps':
+            providers.append('CoreMLExecutionProvider')
+
+        providers.append('CPUExecutionProvider')
+
+        return providers
+
+    def load_model(self,
+                   model_id: str,
+                   force_reload: bool = False,
+                   device_override: Optional[str] = None) -> Optional[LoadedModel]:
+        """
+        Load a model by ID.
+
+        Args:
+            model_id: Model ID to load
+            force_reload: Force reload even if cached
+            device_override: Override default device
+
+        Returns:
+            Loaded model or None if failed
+        """
+        # Check cache
+        if not force_reload and model_id in self.loaded_models:
+            logger.info(f"Using cached model: {model_id}")
+            self.registry.update_model_usage(model_id)
+            return self.loaded_models[model_id]
+
+        # Get model info
+        model_info = self.registry.get_model(model_id)
+        if not model_info:
+            logger.error(f"Model not found: {model_id}")
+            return None
+
+        # Download if needed
+        if model_info.status != "available":
+            logger.info(f"Downloading model: {model_id}")
+            if not self.downloader.download_model(model_id):
+                logger.error(f"Failed to download model: {model_id}")
+                return None
+
+        # Check memory
+        if not self._check_memory_available(model_info):
+            logger.warning(f"Insufficient memory for model: {model_id}")
+            self._free_memory(model_info.memory_mb * 1024 * 1024 if model_info.memory_mb else 0)
+
+        # Load model
+        device = device_override or self.device
+        loaded = self._load_model_impl(model_info, device)
+
+        if loaded:
+            # Cache model
+            if self.enable_cache:
+                self.loaded_models[model_id] = loaded
+                self.current_memory_usage += loaded.memory_usage
+
+            # Update registry
+            self.registry.update_model_usage(model_id)
+
+            logger.info(f"Successfully loaded model: {model_id}")
+            return loaded
+
+        return None
+
+    def _load_model_impl(self, model_info: ModelInfo, device: str) -> Optional[LoadedModel]:
+        """
+        Implementation of model loading based on framework.
+
+        Args:
+            model_info: Model information
+            device: Device to load on
+
+        Returns:
+            Loaded model or None
+        """
+        import time
+        start_time = time.time()
+
+        try:
+            if model_info.framework == ModelFramework.PYTORCH:
+                model = self._load_pytorch_model(model_info, device)
+            elif model_info.framework == ModelFramework.ONNX:
+                model = self._load_onnx_model(model_info)
+            elif model_info.framework == ModelFramework.TFLITE:
+                model = self._load_tflite_model(model_info)
+            elif model_info.framework == ModelFramework.TENSORRT:
+                model = self._load_tensorrt_model(model_info)
+            else:
+                logger.error(f"Unsupported framework: {model_info.framework}")
+                return None
+
+            if model is None:
+                return None
+
+            # Estimate memory usage
+            memory_usage = self._estimate_model_memory(model, model_info)
+
+            loaded = LoadedModel(
+                model_id=model_info.model_id,
+                model=model,
+                framework=model_info.framework,
+                device=device,
+                memory_usage=memory_usage,
+                load_time=time.time() - start_time,
+                metadata=model_info.config
+            )
+
+            return loaded
+
+        except Exception as e:
+            logger.error(f"Failed to load model {model_info.model_id}: {e}")
+            return None
+
+    def _load_pytorch_model(self, model_info: ModelInfo, device: str) -> Optional[Any]:
+        """Load PyTorch model."""
+        try:
+            model_path = Path(model_info.local_path)
+
+            # Load model
+            if model_path.suffix == '.pth':
+                # Load state dict
+                state_dict = torch.load(model_path, map_location=device)
+
+                # Create model architecture (model-specific)
+                model = self._create_model_architecture(model_info)
+                if model:
+                    model.load_state_dict(state_dict)
+                else:
+                    # Try loading as complete model
+                    model = torch.load(model_path, map_location=device)
+            else:
+                # Load complete model
+                model = torch.load(model_path, map_location=device)
+
+            # Move to device
+            if isinstance(model, nn.Module):
+                model = model.to(device)
+                model.eval()
+
+            return model
+
+        except Exception as e:
+            logger.error(f"PyTorch model loading failed: {e}")
+            return None
+
+    def _load_onnx_model(self, model_info: ModelInfo) -> Optional[Any]:
+        """Load ONNX model."""
+        try:
+            model_path = str(model_info.local_path)
+
+            # Create session options
+            sess_options = ort.SessionOptions()
+            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+            # Add providers based on device
+            providers = self.providers
+
+            # Create inference session
+            session = ort.InferenceSession(
+                model_path,
+                sess_options=sess_options,
+                providers=providers
+            )
+
+            return session
+
+        except Exception as e:
+            logger.error(f"ONNX model loading failed: {e}")
+            return None
+
+    def _load_tflite_model(self, model_info: ModelInfo) -> Optional[Any]:
+        """Load TFLite model."""
+        try:
+            import tensorflow as tf
+
+            model_path = str(model_info.local_path)
+
+            # Load TFLite model
+            interpreter = tf.lite.Interpreter(model_path=model_path)
+            interpreter.allocate_tensors()
+
+            return interpreter
+
+        except Exception as e:
+            logger.error(f"TFLite model loading failed: {e}")
+            return None
+
+    def _load_tensorrt_model(self, model_info: ModelInfo) -> Optional[Any]:
+        """Load TensorRT model."""
+        try:
+            import tensorrt as trt
+            import pycuda.driver as cuda
+            import pycuda.autoinit
+
+            model_path = str(model_info.local_path)
+
+            # Load TensorRT engine
+            with open(model_path, 'rb') as f:
+                engine_data = f.read()
+
+            runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
+            engine = runtime.deserialize_cuda_engine(engine_data)
+            context = engine.create_execution_context()
+
+            return {'engine': engine, 'context': context}
+
+        except Exception as e:
+            logger.error(f"TensorRT model loading failed: {e}")
+            return None
+
+    def _create_model_architecture(self, model_info: ModelInfo) -> Optional[nn.Module]:
+        """
+        Create model architecture for specific models.
+        This would need to be implemented for each model type.
+        """
+        # Model-specific architecture creation
+        # This is where you'd define the architecture for models
+        # that are loaded as state_dicts
+
+        if model_info.model_id == "u2net":
+            # Example: Create U2Net architecture
+            try:
+                from ..core.models import U2NET
+                return U2NET()
+            except Exception:
+                # Architecture unavailable; fall through to the generic None return
+                pass
+
+        return None
+
+    def _estimate_model_memory(self, model: Any, model_info: ModelInfo) -> int:
+        """Estimate model memory usage in bytes."""
+        if model_info.memory_mb:
+            return model_info.memory_mb * 1024 * 1024
+
+        # Estimate based on model type
+        if isinstance(model, nn.Module):
+            # PyTorch model
+            param_size = sum(p.numel() * p.element_size() for p in model.parameters())
+            buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
+            return param_size + buffer_size
+
+        elif hasattr(model, 'get_inputs'):
+            # ONNX model
+            # Rough estimate based on file size
+            file_size = Path(model_info.local_path).stat().st_size
+            return int(file_size * 2)  # Account for runtime overhead
+
+        else:
+            # Default estimate
+            return 500 * 1024 * 1024  # 500MB default
+
+    def _check_memory_available(self, model_info: ModelInfo) -> bool:
+        """Check if enough memory is available."""
+        required = model_info.memory_mb * 1024 * 1024 if model_info.memory_mb else 500 * 1024 * 1024
+
+        if self.device == 'cuda':
+            # Check GPU memory (torch is already imported at module level)
+            try:
+                free_memory = torch.cuda.mem_get_info()[0]
+                return free_memory > required
+            except Exception:
+                pass
+
+        # Check system memory
+        available = psutil.virtual_memory().available
+        return available > required
+
+    def _free_memory(self, required_bytes: int):
+        """Free memory by unloading models."""
+        if not self.enable_cache:
+            return
+
+        # Sort models by last used time
+        models_by_usage = sorted(
+            self.loaded_models.items(),
+            key=lambda x: self.registry.models[x[0]].last_used or 0
+        )
+
+        freed = 0
+        for model_id, loaded_model in models_by_usage:
+            if freed >= required_bytes:
+                break
+
+            # Unload model
+            self.unload_model(model_id)
+            freed += loaded_model.memory_usage
+
+            logger.info(f"Freed memory by unloading: {model_id}")
+
+    def unload_model(self, model_id: str) -> bool:
+        """
+        Unload a model from memory.
+
+        Args:
+            model_id: Model ID to unload
+
+        Returns:
+            True if unloaded
+        """
+        if model_id in self.loaded_models:
+            loaded = self.loaded_models[model_id]
+
+            # Clean up model
+            if isinstance(loaded.model, nn.Module):
+                del loaded.model
+                if self.device == 'cuda':
+                    torch.cuda.empty_cache()
+            else:
+                del loaded.model
+
+            # Update tracking
+            self.current_memory_usage -= loaded.memory_usage
+            del self.loaded_models[model_id]
+
+            # Force garbage collection
+            gc.collect()
+
+            logger.info(f"Unloaded model: {model_id}")
+            return True
+
+        return False
+
+    def unload_all(self):
+        """Unload all models."""
+        model_ids = list(self.loaded_models.keys())
+        for model_id in model_ids:
+            self.unload_model(model_id)
+
+    def get_loaded_models(self) -> List[str]:
+        """Get list of loaded model IDs."""
+        return list(self.loaded_models.keys())
+
+    def get_memory_usage(self) -> Dict[str, Any]:
+        """Get memory usage statistics."""
+        return {
+            'current_usage_mb': self.current_memory_usage / (1024 * 1024),
+            'max_usage_mb': self.max_memory_bytes / (1024 * 1024),
+            'loaded_models': len(self.loaded_models),
+            'models': {
+                model_id: loaded.memory_usage / (1024 * 1024)
+                for model_id, loaded in self.loaded_models.items()
+            }
+        }
+
+    def predict(self,
+                model_id: str,
+                input_data: Union[np.ndarray, torch.Tensor],
+                **kwargs) -> Optional[Any]:
+        """
+        Run prediction with a model.
+
+        Args:
+            model_id: Model ID
+            input_data: Input data
+            **kwargs: Additional arguments
+
+        Returns:
+            Prediction result
+        """
+        # Load model if needed
+        loaded = self.load_model(model_id)
+        if not loaded:
+            return None
+
+        try:
+            if loaded.framework == ModelFramework.PYTORCH:
+                return self._predict_pytorch(loaded.model, input_data, **kwargs)
+            elif loaded.framework == ModelFramework.ONNX:
+                return self._predict_onnx(loaded.model, input_data, **kwargs)
+            elif loaded.framework == ModelFramework.TFLITE:
+                return self._predict_tflite(loaded.model, input_data, **kwargs)
+            else:
+                logger.error(f"Prediction not implemented for: {loaded.framework}")
+                return None
+
+        except Exception as e:
+            logger.error(f"Prediction failed: {e}")
+            return None
+
+    def _predict_pytorch(self, model: nn.Module, input_data: Any, **kwargs) -> Any:
+        """Run PyTorch prediction."""
+        with torch.no_grad():
+            if not isinstance(input_data, torch.Tensor):
+                input_data = torch.from_numpy(input_data)
+
+            input_data = input_data.to(self.device)
+            output = model(input_data)
+
+            if isinstance(output, torch.Tensor):
+                output = output.cpu().numpy()
+
+            return output
+
+    def _predict_onnx(self, session: ort.InferenceSession, input_data: Any, **kwargs) -> Any:
+        """Run ONNX prediction."""
+        if isinstance(input_data, torch.Tensor):
+            # Move to CPU first; .numpy() fails on CUDA/MPS tensors
+            input_data = input_data.detach().cpu().numpy()
+
+        # Get input name
+        input_name = session.get_inputs()[0].name
+
+        # Run inference
+        outputs = session.run(None, {input_name: input_data})
+
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def _predict_tflite(self, interpreter: Any, input_data: Any, **kwargs) -> Any:
+        """Run TFLite prediction."""
+        if isinstance(input_data, torch.Tensor):
+            # Move to CPU first; .numpy() fails on CUDA/MPS tensors
+            input_data = input_data.detach().cpu().numpy()
+
+        # Get input/output details
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+
+        # Set input
+        interpreter.set_tensor(input_details[0]['index'], input_data)
+
+        # Run inference
+        interpreter.invoke()
+
+        # Get output
+        output = interpreter.get_tensor(output_details[0]['index'])
+
+        return output
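
A minimal usage sketch of the loader added in this commit. The ModelRegistry() default constructor, the "u2net" model id, and the import paths are assumptions for illustration only; they are not part of the diff and depend on how the rest of the package is laid out.

# Hypothetical usage sketch; constructor defaults, "u2net", and import paths are assumptions.
import numpy as np

from models.registry import ModelRegistry  # assumed package layout
from models.loader import ModelLoader

registry = ModelRegistry()
loader = ModelLoader(registry, device="auto", max_memory_gb=4.0)

# predict() loads (and downloads, if needed) the model on first use, then caches it.
dummy = np.zeros((1, 3, 320, 320), dtype=np.float32)
mask = loader.predict("u2net", dummy)

print(loader.get_memory_usage())
loader.unload_all()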