MogensR committed
Commit baea23e · 1 Parent(s): 9da6723

Update models/loaders/model_loader.py

Files changed (1)
  1. models/loaders/model_loader.py +120 -263
models/loaders/model_loader.py CHANGED
@@ -1,74 +1,51 @@
+#!/usr/bin/env python3
 """
 Model Loading Module
-Handles loading and validation of SAM2 and MatAnyone AI models
+Handles loading and validation of SAM2 and MatAnyOne AI models
+(Modern version for BackgroundFX Pro – only edit this file for model loading logic)
 """
 
-# ============================================================================ #
+# ============================================================================
 # IMPORTS AND DEPENDENCIES
-# ============================================================================ #
-
+# ============================================================================
 import os
 import gc
 import sys
 import time
-import shutil
 import logging
-import tempfile
 import traceback
-from typing import Optional, Dict, Any, Tuple, Union
 from pathlib import Path
+from typing import Optional, Dict, Any, Tuple, Callable
 
 import torch
-from omegaconf import DictConfig, OmegaConf
 
-# Import modular components - Updated paths for BackgroundFX Pro structure
-from core.exceptions import ModelLoadingError  # Updated import
-from utils.hardware.device_manager import DeviceManager  # Updated import
-from utils.system.memory_manager import MemoryManager  # Updated import
+# Modular dependencies (adapt as your structure changes)
+from core.exceptions import ModelLoadingError
+from utils.hardware.device_manager import DeviceManager
+from utils.system.memory_manager import MemoryManager
 
 logger = logging.getLogger(__name__)
 
-# ============================================================================ #
-# DATA CONTAINER CLASSES
-# ============================================================================ #
-
-class LoadedModel:
-    """Container for a loaded model with metadata"""
-    def __init__(self, model=None, model_id: str = "", load_time: float = 0.0):
-        self.model = model
-        self.model_id = model_id
-        self.load_time = load_time
-        self.device = None
-        self.framework = None
-
-    def __repr__(self):
-        return f"LoadedModel(id={self.model_id}, loaded={self.model is not None})"
-
-# ============================================================================ #
-# MODEL LOADER CLASS - MAIN INTERFACE
-# ============================================================================ #
-
+# ============================================================================
+# MODEL LOADER CLASS
+# ============================================================================
 class ModelLoader:
     """
-    Simplified model loading for SAM2 and MatAnyone
-    Uses only the working loading strategies without redundant attempts
+    Loads and manages SAM2 and MatAnyOne models.
+    Tune all model-specific logic/settings here.
     """
-
+
     def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
         self.device_manager = device_mgr
         self.memory_manager = memory_mgr
         self.device = self.device_manager.get_optimal_device()
-
-        # Model storage
+
         self.sam2_predictor = None
-        self.matanyone_model = None
-        self.matanyone_core = None
-
-        # Configuration paths
+        self.matanyone_model = None  # This is usually InferenceCore
+
         self.checkpoints_dir = "./checkpoints"
         os.makedirs(self.checkpoints_dir, exist_ok=True)
-
-        # Model loading statistics
+
         self.loading_stats = {
             'sam2_load_time': 0.0,
             'matanyone_load_time': 0.0,
@@ -76,111 +53,90 @@ def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
             'models_loaded': False,
             'loading_attempts': 0
         }
-
-        logger.info(f"ModelLoader initialized for device: {self.device}")
 
-    # ============================================================================ #
-    # MAIN MODEL LOADING ORCHESTRATION
-    # ============================================================================ #
+        logger.info(f"ModelLoader initialized for device: {self.device}")
 
-    def load_all_models(self, progress_callback: Optional[callable] = None, cancel_event=None) -> Tuple[Any, Any]:
+    # ============================================================================
+    # MAIN LOADING FUNCTION (ORCHESTRATION)
+    # ============================================================================
+    def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_event=None) -> Tuple[Any, Any]:
         """
-        Load both SAM2 and MatAnyone models with comprehensive error handling
-
-        Args:
-            progress_callback: Progress update callback
-            cancel_event: Event to check for cancellation
-
-        Returns:
-            Tuple of (sam2_predictor, matanyone_model)
+        Loads both SAM2 and MatAnyOne models with error handling.
+        Returns: (sam2_predictor, matanyone_model)
         """
         start_time = time.time()
         self.loading_stats['loading_attempts'] += 1
-
+
         try:
             logger.info("Starting model loading process...")
             if progress_callback:
                 progress_callback(0.0, "Initializing model loading...")
-
-            # Clear any existing models
+
             self._cleanup_models()
-
+
             # Load SAM2 first
             logger.info("Loading SAM2 predictor...")
             if progress_callback:
                 progress_callback(0.1, "Loading SAM2 predictor...")
-
             self.sam2_predictor = self._load_sam2_predictor(progress_callback)
-
+
             if self.sam2_predictor is None:
                 logger.warning("SAM2 loading failed - will use fallback segmentation")
             else:
                 sam2_time = time.time() - start_time
                 self.loading_stats['sam2_load_time'] = sam2_time
                 logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
-
-            # Load MatAnyone
-            logger.info("Loading MatAnyone model...")
+
+            # Load MatAnyOne
+            logger.info("Loading MatAnyOne model...")
             if progress_callback:
-                progress_callback(0.6, "Loading MatAnyone model...")
-
+                progress_callback(0.6, "Loading MatAnyOne model...")
             matanyone_start = time.time()
-
-            self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress_callback)
-
+
+            self.matanyone_model = self._load_matanyone_model(progress_callback)
+
             if self.matanyone_model is None:
-                logger.warning("MatAnyone loading failed - will use OpenCV refinement")
+                logger.warning("MatAnyOne loading failed - will use OpenCV refinement")
             else:
                 matanyone_time = time.time() - matanyone_start
                 self.loading_stats['matanyone_load_time'] = matanyone_time
-                logger.info(f"MatAnyone loaded in {matanyone_time:.1f}s")
-
-            # Final setup
+                logger.info(f"MatAnyOne loaded in {matanyone_time:.1f}s")
+
+            # Final status
             total_time = time.time() - start_time
             self.loading_stats['total_load_time'] = total_time
             self.loading_stats['models_loaded'] = True
-
+
             if progress_callback:
                 if self.sam2_predictor or self.matanyone_model:
                     progress_callback(1.0, "Models loaded (with fallbacks available)")
                 else:
                     progress_callback(1.0, "Using fallback methods (models failed to load)")
-
+
             logger.info(f"Model loading completed in {total_time:.2f}s")
-
+
             return self.sam2_predictor, self.matanyone_model
-
+
         except Exception as e:
             error_msg = f"Model loading failed: {str(e)}"
             logger.error(f"{error_msg}\n{traceback.format_exc()}")
-
-            # Cleanup on failure
             self._cleanup_models()
             self.loading_stats['models_loaded'] = False
-
             if progress_callback:
                 progress_callback(1.0, f"Error: {error_msg}")
-
             return None, None
 
-    # ============================================================================ #
-    # SAM2 MODEL LOADING - DIRECT OFFICIAL APPROACH ONLY
-    # ============================================================================ #
-
-    def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
+    # ============================================================================
+    # SAM2 LOADING (OFFICIAL FROM_PRETRAINED)
+    # ============================================================================
+    def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
         """
-        Load SAM2 using only the official from_pretrained method that works
-
-        Args:
-            progress_callback: Progress update callback
-
-        Returns:
-            SAM2 predictor or None
+        Loads SAM2 using the official Hugging Face interface.
+        Returns: SAM2 predictor object or None
         """
-        # Determine model size based on device memory
-        model_size = "large"  # default
-        if hasattr(self.device_manager, 'get_device_memory_gb'):
-            try:
+        model_size = "large"
+        try:
+            if hasattr(self.device_manager, 'get_device_memory_gb'):
                 memory_gb = self.device_manager.get_device_memory_gb()
                 if memory_gb < 4:
                     model_size = "tiny"
@@ -189,139 +145,98 @@ def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
                 elif memory_gb < 12:
                     model_size = "base"
                 logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
-            except Exception as e:
-                logger.warning(f"Could not determine device memory: {e}")
-
+        except Exception as e:
+            logger.warning(f"Could not determine device memory: {e}")
+
         model_map = {
             "tiny": "facebook/sam2.1-hiera-tiny",
-            "small": "facebook/sam2.1-hiera-small",
+            "small": "facebook/sam2.1-hiera-small",
             "base": "facebook/sam2.1-hiera-base-plus",
             "large": "facebook/sam2.1-hiera-large"
         }
-
         model_id = model_map.get(model_size, model_map["large"])
-
+
         if progress_callback:
             progress_callback(0.3, f"Loading SAM2 {model_size} model...")
-
-        # Use ONLY the official SAM2 from_pretrained method that works
+
         try:
-            logger.info(f"Loading SAM2 from {model_id}...")
            from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-            # This is the method that successfully downloads and loads the model
            predictor = SAM2ImagePredictor.from_pretrained(model_id)
-
-            # Move to correct device if needed
            if hasattr(predictor, 'model'):
                predictor.model = predictor.model.to(self.device)
-
            logger.info("SAM2 loaded successfully via official from_pretrained")
            return predictor
-
-        except ImportError as e:
-            logger.error(f"SAM2 module not found. Install with: pip install sam2")
+        except ImportError:
+            logger.error("SAM2 module not found. Install with: pip install sam2")
            return None
-
        except Exception as e:
            logger.error(f"SAM2 loading failed: {e}")
-            # Try downloading checkpoint manually as fallback
-            try:
-                logger.info("Attempting manual checkpoint download...")
-                import urllib.request
-
-                checkpoint_url = f"https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2.1_hiera_{model_size}.pt"
-                checkpoint_path = os.path.join(self.checkpoints_dir, f"sam2.1_hiera_{model_size}.pt")
-
-                if not os.path.exists(checkpoint_path):
-                    logger.info(f"Downloading checkpoint from {checkpoint_url}")
-                    urllib.request.urlretrieve(checkpoint_url, checkpoint_path)
-
-                # Try loading with downloaded checkpoint
-                predictor = SAM2ImagePredictor.from_pretrained(model_id, checkpoint=checkpoint_path)
-                logger.info("SAM2 loaded successfully with manual checkpoint")
-                return predictor
-
-            except Exception as fallback_error:
-                logger.error(f"Manual checkpoint fallback also failed: {fallback_error}")
-                return None
-
-    # ============================================================================ #
-    # MATANYONE MODEL LOADING
-    # ============================================================================ #
-
-    def _load_matanyone_model(self, progress_callback: Optional[callable] = None):
+            return None
+
+    # ============================================================================
+    # MATANYONE LOADING (OFFICIAL INFERENCECORE)
+    # ============================================================================
+    def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
        """
-        Load MatAnyone model - try official method only
-
-        Args:
-            progress_callback: Progress update callback
-
-        Returns:
-            Tuple[model, core] or (None, None)
+        Loads MatAnyOne using Hugging Face official 'matanyone' package.
+        Returns: InferenceCore object or None
+
+        ---------- MATANYONE TUNING SECTION ----------
+        To adjust MatAnyOne settings, change arguments to InferenceCore below!
+        (e.g., for precision, model variant, device, chunk size, etc.)
+        ---------------------------------------------
        """
        try:
-            logger.info("Loading MatAnyone from HuggingFace...")
            if progress_callback:
-                progress_callback(0.7, "Loading MatAnyone model...")
-
+                progress_callback(0.7, "Loading MatAnyOne model...")
+
+            # --- HIGHLIGHT: SET ANY MatAnyOne SETTINGS HERE ---
            from matanyone import InferenceCore
-
-            # Initialize with the official model repo
-            processor = InferenceCore("PeiqingYang/MatAnyone")
-
-            logger.info("MatAnyone loaded successfully")
-            return processor, processor
-
+
+            # Example: To set chunk size or custom model repo, add kwargs here.
+            # See: https://huggingface.co/PeiqingYang/MatAnyone for config options
+
+            matanyone_kwargs = dict(
+                repo_id="PeiqingYang/MatAnyone",  # You can change to any compatible Hugging Face repo
+                device=self.device,               # Device to load on ("cuda" or "cpu")
+                dtype=torch.float32,              # Change to torch.float16 for faster inference on good GPUs
+                # chunk_size=512,                 # Optional: for memory tuning on large videos
+            )
+
+            processor = InferenceCore(**matanyone_kwargs)
+            logger.info("MatAnyOne loaded successfully (InferenceCore)")
+            return processor
+
        except ImportError:
-            logger.error("MatAnyone module not found. Install with: pip install matanyone")
-            return None, None
-
+            logger.error("MatAnyOne module not found. Install with: pip install matanyone")
+            return None
        except Exception as e:
-            logger.error(f"MatAnyone loading failed: {e}")
-            return None, None
-
-    # ============================================================================ #
-    # MODEL MANAGEMENT AND CLEANUP
-    # ============================================================================ #
+            logger.error(f"MatAnyOne loading failed: {e}")
+            return None
 
+    # ============================================================================
+    # MODEL MANAGEMENT AND CLEANUP
+    # ============================================================================
    def _cleanup_models(self):
-        """Clean up loaded models and free memory"""
        if self.sam2_predictor is not None:
            del self.sam2_predictor
            self.sam2_predictor = None
-
        if self.matanyone_model is not None:
            del self.matanyone_model
            self.matanyone_model = None
-
-        if self.matanyone_core is not None:
-            del self.matanyone_core
-            self.matanyone_core = None
-
-        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
-
        logger.debug("Model cleanup completed")
-
+
    def cleanup(self):
-        """Clean up all resources"""
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")
 
-    # ============================================================================ #
-    # MODEL INFORMATION AND STATUS
-    # ============================================================================ #
-
+    # ============================================================================
+    # MODEL INFO AND VALIDATION
+    # ============================================================================
    def get_model_info(self) -> Dict[str, Any]:
-        """
-        Get information about loaded models
-
-        Returns:
-            Dict with model information and statistics
-        """
        info = {
            'models_loaded': self.loading_stats['models_loaded'],
            'sam2_loaded': self.sam2_predictor is not None,
@@ -329,118 +244,60 @@ def get_model_info(self) -> Dict[str, Any]:
            'device': str(self.device),
            'loading_stats': self.loading_stats.copy()
        }
-
        if self.sam2_predictor is not None:
-            try:
-                info['sam2_model_type'] = type(self.sam2_predictor).__name__
-                if hasattr(self.sam2_predictor, 'model'):
-                    info['sam2_has_model'] = True
-                if hasattr(self.sam2_predictor, 'predictor'):
-                    info['sam2_has_predictor'] = True
-            except:
-                info['sam2_model_type'] = "Unknown"
-
+            info['sam2_model_type'] = type(self.sam2_predictor).__name__
        if self.matanyone_model is not None:
-            try:
-                info['matanyone_model_type'] = type(self.matanyone_model).__name__
-            except:
-                info['matanyone_model_type'] = "Unknown"
-
+            info['matanyone_model_type'] = type(self.matanyone_model).__name__
        return info
-
-    def get_status(self) -> Dict[str, Any]:
-        """Get model loader status for backward compatibility"""
-        return self.get_model_info()
-
+
    def get_load_summary(self) -> str:
-        """Get a human-readable summary of model loading"""
        if not self.loading_stats['models_loaded']:
            return "Models not loaded"
-
        sam2_time = self.loading_stats['sam2_load_time']
        matanyone_time = self.loading_stats['matanyone_load_time']
        total_time = self.loading_stats['total_load_time']
-
        summary = f"Models loaded in {total_time:.1f}s\n"
-
        if self.sam2_predictor:
            summary += f"✓ SAM2: {sam2_time:.1f}s\n"
        else:
            summary += f"✗ SAM2: Failed (using fallback)\n"
-
        if self.matanyone_model:
-            summary += f"✓ MatAnyone: {matanyone_time:.1f}s\n"
+            summary += f"✓ MatAnyOne: {matanyone_time:.1f}s\n"
        else:
-            summary += f"✗ MatAnyone: Failed (using OpenCV)\n"
-
+            summary += f"✗ MatAnyOne: Failed (using OpenCV)\n"
        summary += f"Device: {self.device}"
-
        return summary
-
+
    def get_matanyone(self):
-        """Get MatAnyone model for backward compatibility"""
        return self.matanyone_model
-
+
    def get_sam2(self):
-        """Get SAM2 predictor for backward compatibility"""
        return self.sam2_predictor
 
-    # ============================================================================ #
-    # MODEL VALIDATION AND TESTING
-    # ============================================================================ #
-
    def validate_models(self) -> bool:
-        """
-        Validate that models are properly loaded and functional
-
-        Returns:
-            bool: True if at least one model is valid
-        """
        try:
            has_valid_model = False
-
-            # Check SAM2
            if self.sam2_predictor is not None:
-                # Check for required methods/attributes
                if hasattr(self.sam2_predictor, 'set_image') or hasattr(self.sam2_predictor, 'predict'):
                    has_valid_model = True
-                    logger.info("SAM2 validation passed")
-                elif hasattr(self.sam2_predictor, 'model'):
-                    has_valid_model = True
-                    logger.info("SAM2 model found")
-
-            # Check MatAnyone
            if self.matanyone_model is not None:
                has_valid_model = True
-                logger.info("MatAnyone validation passed")
-
            return has_valid_model
-
        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False
 
-    # ============================================================================ #
-    # UTILITY METHODS
-    # ============================================================================ #
-
-    def reload_models(self, progress_callback: Optional[callable] = None) -> Tuple[Any, Any]:
-        """
-        Reload all models (useful for error recovery)
-
-        Args:
-            progress_callback: Progress update callback
-
-        Returns:
-            Tuple of (sam2_predictor, matanyone_model)
-        """
+    def reload_models(self, progress_callback: Optional[Callable] = None) -> Tuple[Any, Any]:
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats['models_loaded'] = False
-
        return self.load_all_models(progress_callback)
-
+
    @property
    def models_ready(self) -> bool:
-        """Check if at least one model is loaded and ready"""
-        return self.sam2_predictor is not None or self.matanyone_model is not None
+        return self.sam2_predictor is not None or self.matanyone_model is not None
+
+# ============================================================================
+# END MODEL LOADER
+# ============================================================================
+
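The updated loader keeps a small public surface: load_all_models(), reload_models(), validate_models(), get_load_summary(), cleanup(), and the models_ready property. Below is a minimal sketch (not part of the commit) of how calling code might drive it, assuming the repository root is on PYTHONPATH and that DeviceManager and MemoryManager can be constructed without arguments; their constructors are not shown in this diff.

import logging

from models.loaders.model_loader import ModelLoader
from utils.hardware.device_manager import DeviceManager
from utils.system.memory_manager import MemoryManager

logging.basicConfig(level=logging.INFO)

def report(fraction: float, message: str) -> None:
    # Matches the (float, str) callback signature used by load_all_models.
    print(f"[{fraction:5.0%}] {message}")

loader = ModelLoader(DeviceManager(), MemoryManager())
sam2_predictor, matanyone_model = loader.load_all_models(progress_callback=report)

if loader.models_ready:
    print(loader.get_load_summary())
else:
    # Both loads failed; the pipeline falls back to OpenCV-based refinement.
    print("No AI models available, using fallback methods.")

loader.cleanup()

Because each loader returns None on failure rather than raising, callers can treat a partially loaded pair as usable and rely on get_load_summary() to report which engine is active.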