MogensR committed on
Commit
a00a1ac
·
1 Parent(s): 1c5f026

Create model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +437 -0
model_loader.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Loading Module
3
+ Handles loading and validation of SAM2 and MatAnyone AI models
4
+ """
5
+
6
+ import os
7
+ import gc
8
+ import time
9
+ import logging
10
+ import tempfile
11
+ import traceback
12
+ from typing import Optional, Dict, Any, Tuple, Union
13
+ from pathlib import Path
14
+
15
+ import torch
16
+ import hydra
17
+ import gradio as gr
18
+ from omegaconf import DictConfig, OmegaConf
19
+
20
+ from exceptions import ModelLoadingError, ConfigurationError
21
+ from device_manager import DeviceManager
22
+ from memory_manager import MemoryManager
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ class ModelLoader:
27
+ """
28
+ Comprehensive model loading and management for SAM2 and MatAnyone
29
+ """
30
+
31
def __init__(self, device_manager: DeviceManager, memory_manager: MemoryManager):
    """
    Wire up collaborators, empty model slots, paths and load statistics

    Args:
        device_manager: Queried once via get_optimal_device() for the target device
        memory_manager: Kept for the GPU memory cleanup done in _cleanup_models()
    """
    # Collaborators and the device every model is placed on
    self.device_manager = device_manager
    self.memory_manager = memory_manager
    self.device = device_manager.get_optimal_device()

    # Model slots stay empty until load_all_models() fills them
    self.sam2_predictor = None
    self.matanyone_model = None
    self.matanyone_core = None

    # Config directory is made absolute; checkpoints stay relative to the working directory
    self.configs_dir = os.path.abspath("Configs")
    self.checkpoints_dir = "./checkpoints"
    os.makedirs(self.checkpoints_dir, exist_ok=True)

    # Timing and bookkeeping reported by get_model_info()
    self.loading_stats = dict(
        sam2_load_time=0.0,
        matanyone_load_time=0.0,
        total_load_time=0.0,
        models_loaded=False,
        loading_attempts=0,
    )

    logger.info(f"ModelLoader initialized for device: {self.device}")
    self._apply_gradio_patch()
57
+
58
def _apply_gradio_patch(self):
    """
    Apply the Gradio schema monkey patch to prevent validation errors

    Replaces gradio.components.base.Component.get_config with a wrapper that
    strips the "show_progress_bar", "min_width" and "scale" keys from the
    returned config. The patch is process-wide (it modifies the class), so it
    is applied at most once: a marker attribute on the wrapper prevents every
    new ModelLoader from stacking another wrapper on top of the previous one.
    Import or attribute failures are logged as warnings and otherwise ignored.
    """
    try:
        import gradio.components.base as gradio_base

        component_cls = gradio_base.Component
        current_get_config = component_cls.get_config

        # Already wrapped by an earlier ModelLoader instance: nothing to do
        if getattr(current_get_config, "_model_loader_patched", False):
            logger.debug("Gradio schema monkey patch already applied")
            return

        original_get_config = current_get_config

        def patched_get_config(self):
            config = original_get_config(self)
            # Remove problematic keys that cause validation errors
            for key in ("show_progress_bar", "min_width", "scale"):
                config.pop(key, None)
            return config

        # Marker checked above so the wrapper is installed only once per process
        patched_get_config._model_loader_patched = True
        component_cls.get_config = patched_get_config
        logger.debug("Applied Gradio schema monkey patch")

    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not apply Gradio monkey patch: {e}")
77
+
78
def load_all_models(self, progress: Optional[gr.Progress] = None) -> bool:
    """
    Load both SAM2 and MatAnyone models with comprehensive error handling

    Any previously loaded models are released first. SAM2 is loaded before
    MatAnyone, and per-model and total load times are recorded in
    self.loading_stats.

    Args:
        progress: Gradio progress callback

    Returns:
        bool: True once both models are loaded; failures raise instead of
        returning False

    Raises:
        ModelLoadingError: If either model cannot be loaded. The original
        exception is chained as the cause, and all models are cleaned up
        before raising.
    """
    start_time = time.time()
    self.loading_stats['loading_attempts'] += 1

    try:
        logger.info("Starting model loading process...")
        self._maybe_progress(progress, 0.0, "Initializing model loading...")

        # Clear any existing models
        self._cleanup_models()

        # Load SAM2 first (typically faster)
        logger.info("Loading SAM2 predictor...")
        self._maybe_progress(progress, 0.1, "Loading SAM2 predictor...")
        # Timed from here so the figure excludes the cleanup above,
        # matching how the MatAnyone load is timed below
        sam2_start = time.time()
        self.sam2_predictor = self._load_sam2_predictor(progress)

        if self.sam2_predictor is None:
            raise ModelLoadingError("Failed to load SAM2 predictor")

        sam2_time = time.time() - sam2_start
        self.loading_stats['sam2_load_time'] = sam2_time
        logger.info(f"SAM2 loaded in {sam2_time:.2f}s")

        # Load MatAnyone
        logger.info("Loading MatAnyone model...")
        self._maybe_progress(progress, 0.6, "Loading MatAnyone model...")
        matanyone_start = time.time()

        self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress)

        if self.matanyone_model is None:
            raise ModelLoadingError("Failed to load MatAnyone model")

        matanyone_time = time.time() - matanyone_start
        self.loading_stats['matanyone_load_time'] = matanyone_time
        logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")

        # Final setup
        total_time = time.time() - start_time
        self.loading_stats['total_load_time'] = total_time
        self.loading_stats['models_loaded'] = True

        self._maybe_progress(progress, 1.0, "Models loaded successfully!")
        logger.info(f"All models loaded successfully in {total_time:.2f}s")

        return True

    except Exception as e:
        error_msg = f"Model loading failed: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")

        # Cleanup on failure
        self._cleanup_models()
        self.loading_stats['models_loaded'] = False

        # Routed through _maybe_progress (explicit None check, callback errors
        # swallowed) so a failing progress callback cannot replace the
        # ModelLoadingError raised below
        self._maybe_progress(progress, 1.0, f"Error: {error_msg}")

        raise ModelLoadingError(error_msg) from e
146
+
147
def _load_sam2_predictor(self, progress: Optional[gr.Progress] = None):
    """
    Load a SAM2 predictor, falling back to smaller model variants on failure

    Candidates are tried from largest to smallest. When the device manager
    reports its memory, the larger variants are skipped on small devices.
    A checkpoint missing from self.checkpoints_dir is downloaded from the
    Hugging Face Hub into that directory.

    Args:
        progress: Gradio progress callback

    Returns:
        The SAM2ImagePredictor of the first variant that loads successfully

    Raises:
        ModelLoadingError: If the Configs directory does not exist, or if
        every candidate fails (the individual failure reasons are included
        in the message)
    """
    if not os.path.isdir(self.configs_dir):
        raise ModelLoadingError(f"SAM2 Configs directory not found at '{self.configs_dir}'")

    # One "<config>: <reason>" entry per failed attempt, reported if all fail
    failures = []

    def try_load_sam2(config_name_with_yaml: str, checkpoint_name: str):
        """Attempt to load SAM2 with given config and checkpoint"""
        config_name = config_name_with_yaml.replace(".yaml", "")
        try:
            checkpoint_path = os.path.join(self.checkpoints_dir, checkpoint_name)
            logger.info(f"Attempting SAM2 checkpoint: {checkpoint_path}")

            # Download checkpoint if needed
            if not os.path.exists(checkpoint_path):
                logger.info(f"Downloading {checkpoint_name} from Hugging Face Hub...")
                self._maybe_progress(progress, 0.2, f"Downloading {checkpoint_name}...")

                from huggingface_hub import hf_hub_download
                # Hub repository ids use hyphens (facebook/sam2-hiera-large)
                # while the config and checkpoint file names use underscores
                repo = f"facebook/{config_name.replace('_', '-')}"
                # local_dir places the file directly at
                # <checkpoints_dir>/<checkpoint_name>, so the exists() check
                # above finds it on the next run
                checkpoint_path = hf_hub_download(
                    repo_id=repo,
                    filename=checkpoint_name,
                    local_dir=self.checkpoints_dir
                )
                logger.info(f"Download complete: {checkpoint_path}")

            # Reset and initialize Hydra
            global_hydra = hydra.core.global_hydra.GlobalHydra.instance()
            if global_hydra.is_initialized():
                global_hydra.clear()

            # initialize_config_dir takes an absolute directory; initialize()
            # resolves config_path relative to the calling file rather than
            # the working directory the original relpath was computed from
            hydra.initialize_config_dir(
                version_base=None,
                config_dir=self.configs_dir,
                job_name=f"sam2_load_{int(time.time())}"
            )

            # Build SAM2 model
            self._maybe_progress(progress, 0.4, f"Building {config_name}...")

            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            # Pass the device explicitly: build_sam2 defaults to "cuda" and
            # moves the model there itself, which fails on CPU-only hosts
            sam2_model = build_sam2(config_name, checkpoint_path, device=self.device)
            predictor = SAM2ImagePredictor(sam2_model)

            logger.info(f"SAM2 {config_name} loaded successfully on {self.device}")
            return predictor

        except Exception as e:
            error_msg = f"Failed to load SAM2 {config_name_with_yaml}: {e}"
            logger.warning(error_msg)
            failures.append(f"{config_name_with_yaml}: {e}")
            return None

    # Try different SAM2 model sizes based on device capabilities
    model_attempts = [
        ("sam2_hiera_large.yaml", "sam2_hiera_large.pt"),
        ("sam2_hiera_base_plus.yaml", "sam2_hiera_base_plus.pt"),
        ("sam2_hiera_small.yaml", "sam2_hiera_small.pt"),
        ("sam2_hiera_tiny.yaml", "sam2_hiera_tiny.pt")
    ]

    # Prioritize model size based on device memory
    if hasattr(self.device_manager, 'get_device_memory_gb'):
        memory_gb = self.device_manager.get_device_memory_gb()
        # A missing reading keeps the full candidate list instead of
        # failing on the comparison
        if memory_gb is not None:
            if memory_gb < 4:
                model_attempts = model_attempts[2:]  # Only small and tiny
            elif memory_gb < 8:
                model_attempts = model_attempts[1:]  # Skip large

    for config_yaml, checkpoint_pt in model_attempts:
        predictor = try_load_sam2(config_yaml, checkpoint_pt)
        if predictor is not None:
            return predictor

    details = f": {'; '.join(failures)}" if failures else ""
    raise ModelLoadingError(f"All SAM2 model loading attempts failed{details}")
232
+
233
def _load_matanyone_model(self, progress: Optional[gr.Progress] = None):
    """
    Walk through the MatAnyone loading strategies until one succeeds

    Args:
        progress: Gradio progress callback

    Returns:
        Tuple[model, core] from the first strategy that yields both objects

    Raises:
        ModelLoadingError: When no strategy produces a model and a core
    """
    candidates = (
        self._load_matanyone_strategy_1,
        self._load_matanyone_strategy_2,
        self._load_matanyone_strategy_3,
        self._load_matanyone_strategy_4,
    )

    for number, attempt in enumerate(candidates, start=1):
        try:
            logger.info(f"Trying MatAnyone loading strategy {number}...")
            self._maybe_progress(progress, 0.7 + (number * 0.05), f"MatAnyone strategy {number}...")

            loaded = attempt()
            model, core = loaded
            # A strategy that returns an incomplete pair is skipped
            if model is None or core is None:
                continue

            logger.info(f"MatAnyone loaded successfully with strategy {number}")
            return model, core

        except Exception as exc:
            # A failing strategy is only a warning; the next one gets its turn
            logger.warning(f"MatAnyone strategy {number} failed: {exc}")

    raise ModelLoadingError("All MatAnyone loading strategies failed")
265
+
266
def _load_matanyone_strategy_1(self):
    """MatAnyone loading strategy 1: build network and inference core from the model/inference submodules"""
    from matanyone.model.matanyone import MatAnyOne
    from matanyone.inference.inference_core import InferenceCore

    # Half precision is requested only when running on CUDA
    settings = {
        'model': {'name': 'MatAnyOne'},
        'device': str(self.device),
        'fp16': self.device.type == 'cuda',
    }
    cfg = OmegaConf.create(settings)

    network = MatAnyOne(cfg)
    inference_core = InferenceCore(network, cfg)
    return network, inference_core
281
+
282
def _load_matanyone_strategy_2(self):
    """MatAnyone loading strategy 2: same construction using the package's top-level exports"""
    from matanyone import MatAnyOne
    from matanyone import InferenceCore

    settings = dict(model_name='matanyone', device=str(self.device))
    cfg = OmegaConf.create(settings)

    network = MatAnyOne(cfg)
    return network, InferenceCore(network, cfg)
296
+
297
def _load_matanyone_strategy_3(self):
    """MatAnyone loading strategy 3: pretrained model plus inference engine, trying two package layouts"""
    try:
        from matanyone.models.matanyone import MatAnyOneModel
        from matanyone.core import InferenceEngine
    except ImportError:
        # Second layout: the same names under a `src` subpackage
        from matanyone.src.models import MatAnyOneModel
        from matanyone.src.core import InferenceEngine

    if self.device.type == 'cuda':
        precision = 'fp16'
    else:
        precision = 'fp32'

    pretrained_options = {
        'model_path': None,  # Will use default
        'device': self.device,
        'precision': precision,
    }

    pretrained = MatAnyOneModel.from_pretrained(pretrained_options)
    return pretrained, InferenceEngine(pretrained)
316
+
317
def _load_matanyone_strategy_4(self):
    """MatAnyone loading strategy 4: fetch the weights file from the Hugging Face Hub and load it"""
    from huggingface_hub import hf_hub_download
    from matanyone import load_model_from_hub

    # Weights are cached under the checkpoints directory
    weights_file = hf_hub_download(
        repo_id="PeiqingYang/MatAnyone",
        filename="pytorch_model.bin",
        cache_dir=self.checkpoints_dir,
    )

    loaded = load_model_from_hub(weights_file, device=self.device)

    # The one loaded object fills both the model and the core slot
    return loaded, loaded
332
+
333
def _cleanup_models(self):
    """
    Clean up loaded models and free memory

    Drops this loader's references to the SAM2 predictor and the MatAnyone
    model and core, runs a garbage collection pass, then asks the memory
    manager to release GPU memory.
    """
    # Rebinding to None releases the references; a preceding `del` adds nothing
    self.sam2_predictor = None
    self.matanyone_model = None
    self.matanyone_core = None

    # Collect before the GPU cleanup so objects kept alive only by reference
    # cycles are destroyed first and their memory can be released in this pass
    # NOTE(review): assumes cleanup_gpu_memory() empties the device cache - confirm
    gc.collect()
    self.memory_manager.cleanup_gpu_memory()

    logger.debug("Model cleanup completed")
352
+
353
+ def _maybe_progress(self, progress: Optional[gr.Progress], value: float, desc: str):
354
+ """Update progress if callback is available"""
355
+ if progress is not None:
356
+ try:
357
+ progress(value, desc=desc)
358
+ except Exception as e:
359
+ logger.debug(f"Progress update failed: {e}")
360
+
361
def get_model_info(self) -> Dict[str, Any]:
    """
    Get information about loaded models

    Returns:
        Dict with the keys 'models_loaded', 'sam2_loaded', 'matanyone_loaded',
        'device' and 'loading_stats' (a copy of the statistics dict). The keys
        'sam2_model_type' and 'matanyone_model_type' are added only for models
        that are currently loaded.
    """
    info = {
        'models_loaded': self.loading_stats['models_loaded'],
        'sam2_loaded': self.sam2_predictor is not None,
        'matanyone_loaded': self.matanyone_model is not None,
        'device': str(self.device),
        'loading_stats': self.loading_stats.copy()
    }

    if self.sam2_predictor is not None:
        try:
            info['sam2_model_type'] = type(self.sam2_predictor.model).__name__
        except Exception:
            # The predictor may not expose `.model`; narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit are not swallowed
            info['sam2_model_type'] = "Unknown"

    if self.matanyone_model is not None:
        # Reading a type name cannot fail, so no exception guard is needed
        info['matanyone_model_type'] = type(self.matanyone_model).__name__

    return info
389
+
390
def validate_models(self) -> bool:
    """
    Report whether the loader considers its models usable

    Only bookkeeping is checked: the 'models_loaded' flag must be set and both
    the SAM2 predictor and the MatAnyone model must be present. No inference
    is run.

    Returns:
        bool: True if models are valid
    """
    try:
        everything_present = (
            self.loading_stats['models_loaded']
            and self.sam2_predictor is not None
            and self.matanyone_model is not None
        )
        if not everything_present:
            return False

        # Try basic model operations
        # This could include running a small test inference
        logger.info("Model validation passed")
        return True

    except Exception as exc:
        logger.error(f"Model validation failed: {exc}")
        return False
413
+
414
def reload_models(self, progress: Optional[gr.Progress] = None) -> bool:
    """
    Drop the current models and load them again (useful for error recovery)

    Args:
        progress: Gradio progress callback

    Returns:
        bool: True if reload successful
    """
    logger.info("Reloading models...")

    # Reset state first so nothing stale is reported as ready during the reload
    self._cleanup_models()
    self.loading_stats['models_loaded'] = False

    reloaded = self.load_all_models(progress)
    return reloaded
429
+
430
@property
def models_ready(self) -> bool:
    """Tell whether the load flag is set and both model objects are present"""
    loaded_flag = self.loading_stats['models_loaded']
    if not loaded_flag:
        return loaded_flag
    if self.sam2_predictor is None:
        return False
    return self.matanyone_model is not None