MogensR committed
Commit bcb51b3 · 1 Parent(s): 57e7830

Update model_loader.py

Files changed (1)
  1. model_loader.py +107 -68
model_loader.py CHANGED
@@ -17,9 +17,10 @@
 import gradio as gr
 from omegaconf import DictConfig, OmegaConf
 
-from exceptions import ModelLoadingError, ConfigurationError
-from device_manager import DeviceManager
-from memory_manager import MemoryManager
+# Import modular components
+import exceptions
+import device_manager
+import memory_manager
 
 logger = logging.getLogger(__name__)
 
@@ -28,10 +29,10 @@ class ModelLoader:
     Comprehensive model loading and management for SAM2 and MatAnyone
     """
 
-    def __init__(self, device_manager: DeviceManager, memory_manager: MemoryManager):
-        self.device_manager = device_manager
-        self.memory_manager = memory_manager
-        self.device = device
+    def __init__(self, device_mgr: device_manager.DeviceManager, memory_mgr: memory_manager.MemoryManager):
+        self.device_manager = device_mgr
+        self.memory_manager = memory_mgr
+        self.device = self.device_manager.get_optimal_device()
 
         # Model storage
         self.sam2_predictor = None
@@ -75,33 +76,37 @@ def patched_get_config(self):
         except (ImportError, AttributeError) as e:
             logger.warning(f"Could not apply Gradio monkey patch: {e}")
 
-    def load_all_models(self, progress: Optional[gr.Progress] = None) -> bool:
+    def load_all_models(self, progress_callback: Optional[callable] = None, cancel_event=None) -> Tuple[Any, Any]:
         """
         Load both SAM2 and MatAnyone models with comprehensive error handling
 
         Args:
-            progress: Gradio progress callback
+            progress_callback: Progress update callback
+            cancel_event: Event to check for cancellation
 
         Returns:
-            bool: True if all models loaded successfully
+            Tuple of (sam2_predictor, matanyone_model)
         """
         start_time = time.time()
         self.loading_stats['loading_attempts'] += 1
 
         try:
             logger.info("Starting model loading process...")
-            self._maybe_progress(progress, 0.0, "Initializing model loading...")
+            if progress_callback:
+                progress_callback(0.0, "Initializing model loading...")
 
             # Clear any existing models
             self._cleanup_models()
 
             # Load SAM2 first (typically faster)
             logger.info("Loading SAM2 predictor...")
-            self._maybe_progress(progress, 0.1, "Loading SAM2 predictor...")
-            self.sam2_predictor = self._load_sam2_predictor(progress)
+            if progress_callback:
+                progress_callback(0.1, "Loading SAM2 predictor...")
+
+            self.sam2_predictor = self._load_sam2_predictor(progress_callback)
 
             if self.sam2_predictor is None:
-                raise ModelLoadingError("Failed to load SAM2 predictor")
+                raise exceptions.ModelLoadingError("Failed to load SAM2 predictor")
 
             sam2_time = time.time() - start_time
             self.loading_stats['sam2_load_time'] = sam2_time
@@ -109,13 +114,15 @@ def load_all_models(self, progress: Optional[gr.Progress] = None) -> bool:
 
             # Load MatAnyone
             logger.info("Loading MatAnyone model...")
-            self._maybe_progress(progress, 0.6, "Loading MatAnyone model...")
+            if progress_callback:
+                progress_callback(0.6, "Loading MatAnyone model...")
+
             matanyone_start = time.time()
 
-            self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress)
+            self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress_callback)
 
             if self.matanyone_model is None:
-                raise ModelLoadingError("Failed to load MatAnyone model")
+                raise exceptions.ModelLoadingError("Failed to load MatAnyone model")
 
             matanyone_time = time.time() - matanyone_start
             self.loading_stats['matanyone_load_time'] = matanyone_time
@@ -126,10 +133,12 @@ def load_all_models(self, progress: Optional[gr.Progress] = None) -> bool:
             self.loading_stats['total_load_time'] = total_time
             self.loading_stats['models_loaded'] = True
 
-            self._maybe_progress(progress, 1.0, "Models loaded successfully!")
+            if progress_callback:
+                progress_callback(1.0, "Models loaded successfully!")
+
             logger.info(f"All models loaded successfully in {total_time:.2f}s")
 
-            return True
+            return self.sam2_predictor, self.matanyone_model
 
         except Exception as e:
             error_msg = f"Model loading failed: {str(e)}"
@@ -139,23 +148,23 @@ def load_all_models(self, progress: Optional[gr.Progress] = None) -> bool:
             self._cleanup_models()
             self.loading_stats['models_loaded'] = False
 
-            if progress:
-                progress(1.0, desc=f"Error: {error_msg}")
+            if progress_callback:
+                progress_callback(1.0, f"Error: {error_msg}")
 
-            raise ModelLoadingError(error_msg) from e
+            return None, None
 
-    def _load_sam2_predictor(self, progress: Optional[gr.Progress] = None):
+    def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
         """
         Load SAM2 predictor with multiple fallback strategies
 
         Args:
-            progress: Gradio progress callback
+            progress_callback: Progress update callback
 
         Returns:
            SAM2ImagePredictor or None
        """
         if not os.path.isdir(self.configs_dir):
-            raise ModelLoadingError(f"SAM2 Configs directory not found at '{self.configs_dir}'")
+            logger.warning(f"SAM2 Configs directory not found at '{self.configs_dir}', trying fallback loading")
 
         def try_load_sam2(config_name_with_yaml: str, checkpoint_name: str):
             """Attempt to load SAM2 with given config and checkpoint"""
@@ -166,31 +175,38 @@ def try_load_sam2(config_name_with_yaml: str, checkpoint_name: str):
             # Download checkpoint if needed
             if not os.path.exists(checkpoint_path):
                 logger.info(f"Downloading {checkpoint_name} from Hugging Face Hub...")
-                self._maybe_progress(progress, 0.2, f"Downloading {checkpoint_name}...")
+                if progress_callback:
+                    progress_callback(0.2, f"Downloading {checkpoint_name}...")
 
-                from huggingface_hub import hf_hub_download
-                repo = f"facebook/{config_name_with_yaml.replace('.yaml','')}"
-                checkpoint_path = hf_hub_download(
-                    repo_id=repo,
-                    filename=checkpoint_name,
-                    cache_dir=self.checkpoints_dir,
-                    local_dir_use_symlinks=False
-                )
-                logger.info(f"Download complete: {checkpoint_path}")
+                try:
+                    from huggingface_hub import hf_hub_download
+                    repo = f"facebook/{config_name_with_yaml.replace('.yaml','')}"
+                    checkpoint_path = hf_hub_download(
+                        repo_id=repo,
+                        filename=checkpoint_name,
+                        cache_dir=self.checkpoints_dir,
+                        local_dir_use_symlinks=False
+                    )
+                    logger.info(f"Download complete: {checkpoint_path}")
+                except Exception as download_error:
+                    logger.warning(f"Failed to download {checkpoint_name}: {download_error}")
+                    return None
 
-            # Reset and initialize Hydra
-            if hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
-                hydra.core.global_hydra.GlobalHydra.instance().clear()
-
-            hydra.initialize(
-                version_base=None,
-                config_path=os.path.relpath(self.configs_dir),
-                job_name=f"sam2_load_{int(time.time())}"
-            )
+            # Reset and initialize Hydra if configs directory exists
+            if os.path.isdir(self.configs_dir):
+                if hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
+                    hydra.core.global_hydra.GlobalHydra.instance().clear()
+
+                hydra.initialize(
+                    version_base=None,
+                    config_path=os.path.relpath(self.configs_dir),
+                    job_name=f"sam2_load_{int(time.time())}"
+                )
 
             # Build SAM2 model
             config_name = config_name_with_yaml.replace(".yaml", "")
-            self._maybe_progress(progress, 0.4, f"Building {config_name}...")
+            if progress_callback:
+                progress_callback(0.4, f"Building {config_name}...")
 
             from sam2.build_sam import build_sam2
             from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -217,25 +233,29 @@ def try_load_sam2(config_name_with_yaml: str, checkpoint_name: str):
 
         # Prioritize model size based on device memory
         if hasattr(self.device_manager, 'get_device_memory_gb'):
-            memory_gb = self.device_manager.get_device_memory_gb()
-            if memory_gb < 4:
-                model_attempts = model_attempts[2:] # Only tiny and small
-            elif memory_gb < 8:
-                model_attempts = model_attempts[1:] # Skip large
+            try:
+                memory_gb = self.device_manager.get_device_memory_gb()
+                if memory_gb < 4:
+                    model_attempts = model_attempts[2:] # Only tiny and small
+                elif memory_gb < 8:
+                    model_attempts = model_attempts[1:] # Skip large
+            except Exception as e:
+                logger.warning(f"Could not determine device memory: {e}")
 
         for config_yaml, checkpoint_pt in model_attempts:
             predictor = try_load_sam2(config_yaml, checkpoint_pt)
             if predictor is not None:
                 return predictor
 
-        raise ModelLoadingError("All SAM2 model loading attempts failed")
+        logger.error("All SAM2 model loading attempts failed")
+        return None
 
-    def _load_matanyone_model(self, progress: Optional[gr.Progress] = None):
+    def _load_matanyone_model(self, progress_callback: Optional[callable] = None):
         """
         Load MatAnyone model with multiple import strategies
 
         Args:
-            progress: Gradio progress callback
+            progress_callback: Progress update callback
 
         Returns:
             Tuple[model, core] or (None, None)
@@ -250,7 +270,8 @@ def _load_matanyone_model(self, progress: Optional[gr.Progress] = None):
         for i, strategy in enumerate(import_strategies, 1):
             try:
                 logger.info(f"Trying MatAnyone loading strategy {i}...")
-                self._maybe_progress(progress, 0.7 + (i * 0.05), f"MatAnyone strategy {i}...")
+                if progress_callback:
+                    progress_callback(0.7 + (i * 0.05), f"MatAnyone strategy {i}...")
 
                 model, core = strategy()
                 if model is not None and core is not None:
@@ -261,7 +282,8 @@ def _load_matanyone_model(self, progress: Optional[gr.Progress] = None):
                 logger.warning(f"MatAnyone strategy {i} failed: {e}")
                 continue
 
-        raise ModelLoadingError("All MatAnyone loading strategies failed")
+        logger.error("All MatAnyone loading strategies failed")
+        return None, None
 
     def _load_matanyone_strategy_1(self):
         """MatAnyone loading strategy 1: Direct model import"""
@@ -350,14 +372,6 @@ def _cleanup_models(self):
 
         logger.debug("Model cleanup completed")
 
-    def _maybe_progress(self, progress: Optional[gr.Progress], value: float, desc: str):
-        """Update progress if callback is available"""
-        if progress is not None:
-            try:
-                progress(value, desc=desc)
-            except Exception as e:
-                logger.debug(f"Progress update failed: {e}")
-
     def get_model_info(self) -> Dict[str, Any]:
         """
         Get information about loaded models
@@ -387,6 +401,26 @@ def get_model_info(self) -> Dict[str, Any]:
 
         return info
 
+    def get_status(self) -> Dict[str, Any]:
+        """Get model loader status for backward compatibility"""
+        return self.get_model_info()
+
+    def get_load_summary(self) -> str:
+        """Get a human-readable summary of model loading"""
+        if not self.loading_stats['models_loaded']:
+            return "Models not loaded"
+
+        sam2_time = self.loading_stats['sam2_load_time']
+        matanyone_time = self.loading_stats['matanyone_load_time']
+        total_time = self.loading_stats['total_load_time']
+
+        summary = f"Models loaded successfully in {total_time:.1f}s\n"
+        summary += f"SAM2: {sam2_time:.1f}s\n"
+        summary += f"MatAnyone: {matanyone_time:.1f}s\n"
+        summary += f"Device: {self.device}"
+
+        return summary
+
     def validate_models(self) -> bool:
         """
         Validate that models are properly loaded and functional
@@ -411,21 +445,26 @@ def validate_models(self) -> bool:
             logger.error(f"Model validation failed: {e}")
             return False
 
-    def reload_models(self, progress: Optional[gr.Progress] = None) -> bool:
+    def reload_models(self, progress_callback: Optional[callable] = None) -> Tuple[Any, Any]:
        """
         Reload all models (useful for error recovery)
 
         Args:
-            progress: Gradio progress callback
+            progress_callback: Progress update callback
 
         Returns:
-            bool: True if reload successful
+            Tuple of (sam2_predictor, matanyone_model)
         """
         logger.info("Reloading models...")
         self._cleanup_models()
         self.loading_stats['models_loaded'] = False
 
-        return self.load_all_models(progress)
+        return self.load_all_models(progress_callback)
+
+    def cleanup(self):
+        """Clean up all resources"""
+        self._cleanup_models()
+        logger.info("ModelLoader cleanup completed")
 
     @property
     def models_ready(self) -> bool:
@@ -434,4 +473,4 @@ def models_ready(self) -> bool:
             self.loading_stats['models_loaded'] and
             self.sam2_predictor is not None and
             self.matanyone_model is not None
-        )
+        )
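
For reference, a minimal sketch of how a caller might use the revised interface: load_all_models() now returns a (sam2_predictor, matanyone_model) tuple and reports failure as (None, None) rather than raising ModelLoadingError. The no-argument DeviceManager()/MemoryManager() constructors below are assumptions for illustration, not part of this commit.

import logging

import device_manager
import memory_manager
from model_loader import ModelLoader

logging.basicConfig(level=logging.INFO)

def report_progress(value, desc):
    # Matches the positional (value, desc) calls made by ModelLoader.
    print(f"[{value:.0%}] {desc}")

# Assumed no-argument constructors; adapt to the real manager APIs.
loader = ModelLoader(device_manager.DeviceManager(), memory_manager.MemoryManager())
sam2_predictor, matanyone_model = loader.load_all_models(progress_callback=report_progress)

if sam2_predictor is None or matanyone_model is None:
    # Failure is now signalled by (None, None) instead of an exception.
    print("Model loading failed; check the logs for details.")
else:
    print(loader.get_load_summary())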