MogensR commited on
Commit
f4f3e24
·
1 Parent(s): 3f32898

Update models/loaders/model_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/model_loader.py +103 -58
models/loaders/model_loader.py CHANGED
@@ -2,6 +2,14 @@
2
  """
3
  Unified Model Loader
4
  Coordinates separate SAM2 and MatAnyone loaders for cleaner architecture
 
 
 
 
 
 
 
 
5
  """
6
 
7
  from __future__ import annotations
@@ -27,7 +35,7 @@
27
 
28
  class LoadedModel:
29
  """Container for loaded model information"""
30
- def __init__(self, model=None, model_id: str = "", load_time: float = 0.0,
31
  device: str = "", framework: str = ""):
32
  self.model = model
33
  self.model_id = model_id
@@ -42,25 +50,26 @@ def to_dict(self) -> Dict[str, Any]:
42
  "device": self.device,
43
  "load_time": self.load_time,
44
  "loaded": self.model is not None,
 
45
  }
46
 
47
 
48
  class ModelLoader:
49
  """Main model loader that coordinates SAM2 and MatAnyone loaders"""
50
-
51
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
52
  self.device_manager = device_mgr
53
  self.memory_manager = memory_mgr
54
  self.device = self.device_manager.get_optimal_device()
55
-
56
  # Initialize specialized loaders
57
  self.sam2_loader = SAM2Loader(device=str(self.device))
58
  self.matanyone_loader = MatAnyoneLoader(device=str(self.device))
59
-
60
  # Model storage
61
  self.sam2_predictor: Optional[LoadedModel] = None
62
  self.matanyone_model: Optional[LoadedModel] = None
63
-
64
  # Statistics
65
  self.loading_stats = {
66
  "sam2_load_time": 0.0,
@@ -69,7 +78,7 @@ def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
69
  "models_loaded": False,
70
  "loading_attempts": 0,
71
  }
72
-
73
  logger.info(f"ModelLoader initialized for device: {self.device}")
74
 
75
  def load_all_models(
@@ -79,33 +88,29 @@ def load_all_models(
79
  ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
80
  """
81
  Load all models using specialized loaders
82
-
83
- Args:
84
- progress_callback: Optional callback for progress updates
85
- cancel_event: Optional threading.Event for cancellation
86
-
87
  Returns:
88
  Tuple of (sam2_model, matanyone_model)
89
  """
90
  start_time = time.time()
91
  self.loading_stats["loading_attempts"] += 1
92
-
93
  try:
94
  logger.info("Starting model loading process...")
95
  if progress_callback:
96
  progress_callback(0.0, "Initializing model loading...")
97
-
98
  # Clean up any existing models
99
  self._cleanup_models()
100
-
101
- # Load SAM2
102
  if progress_callback:
103
  progress_callback(0.1, "Loading SAM2 model...")
104
-
105
  sam2_start = time.time()
106
  sam2_model = self.sam2_loader.load()
107
  sam2_time = time.time() - sam2_start
108
-
109
  if sam2_model:
110
  self.sam2_predictor = LoadedModel(
111
  model=sam2_model,
@@ -118,21 +123,21 @@ def load_all_models(
118
  logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
119
  else:
120
  logger.warning("SAM2 loading failed")
121
-
122
- # Check for cancellation
123
  if cancel_event and cancel_event.is_set():
124
  if progress_callback:
125
  progress_callback(1.0, "Model loading cancelled")
126
  return self.sam2_predictor, None
127
-
128
- # Load MatAnyone
129
  if progress_callback:
130
  progress_callback(0.6, "Loading MatAnyone model...")
131
-
132
  matanyone_start = time.time()
133
- matanyone_model = self.matanyone_loader.load()
134
  matanyone_time = time.time() - matanyone_start
135
-
136
  if matanyone_model:
137
  self.matanyone_model = LoadedModel(
138
  model=matanyone_model,
@@ -145,31 +150,30 @@ def load_all_models(
145
  logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")
146
  else:
147
  logger.warning("MatAnyone loading failed")
148
-
149
- # Update statistics
150
  total_time = time.time() - start_time
151
  self.loading_stats["total_load_time"] = total_time
152
  self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)
153
-
154
- # Final progress update
155
  if progress_callback:
156
  if self.loading_stats["models_loaded"]:
157
  progress_callback(1.0, "Models loaded successfully")
158
  else:
159
  progress_callback(1.0, "Model loading completed with failures")
160
-
161
  logger.info(f"Model loading completed in {total_time:.2f}s")
162
  return self.sam2_predictor, self.matanyone_model
163
-
164
  except Exception as e:
165
  error_msg = f"Model loading failed: {str(e)}"
166
  logger.error(error_msg)
167
  self._cleanup_models()
168
  self.loading_stats["models_loaded"] = False
169
-
170
  if progress_callback:
171
  progress_callback(1.0, f"Error: {error_msg}")
172
-
173
  return None, None
174
 
175
  def reload_models(
@@ -192,70 +196,105 @@ def get_sam2(self):
192
  return self.sam2_predictor.model if self.sam2_predictor else None
193
 
194
  def get_matanyone(self):
195
- """Get MatAnyone processor model"""
 
 
 
 
 
 
196
  return self.matanyone_model.model if self.matanyone_model else None
197
 
198
  def validate_models(self) -> bool:
199
  """Validate that loaded models have expected interfaces"""
200
  try:
201
  valid = False
202
-
 
203
  if self.sam2_predictor:
204
  model = self.sam2_predictor.model
205
  if hasattr(model, "set_image") and hasattr(model, "predict"):
206
  valid = True
207
  logger.info("SAM2 model validated")
208
-
 
209
  if self.matanyone_model:
210
  model = self.matanyone_model.model
211
- if hasattr(model, "step") or hasattr(model, "process"):
 
 
 
 
 
 
 
212
  valid = True
213
- logger.info("MatAnyone model validated")
214
-
 
 
215
  return valid
216
-
217
  except Exception as e:
218
  logger.error(f"Model validation failed: {e}")
219
  return False
220
 
221
  def get_model_info(self) -> Dict[str, Any]:
222
  """Get detailed information about loaded models"""
223
- info = {
224
  "models_loaded": self.loading_stats["models_loaded"],
225
  "device": str(self.device),
226
  "loading_stats": self.loading_stats.copy(),
227
  }
228
-
229
  # Add SAM2 info
230
  info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}
231
-
232
- # Add MatAnyone info
233
- info["matanyone"] = self.matanyone_loader.get_info() if self.matanyone_loader else {}
234
-
 
 
 
 
 
 
 
 
 
235
  return info
236
 
237
  def get_load_summary(self) -> str:
238
  """Get human-readable loading summary"""
239
  if not self.loading_stats["models_loaded"]:
240
  return "No models loaded"
241
-
242
  lines = []
243
  lines.append(f"Models loaded in {self.loading_stats['total_load_time']:.1f}s")
244
-
245
  if self.sam2_predictor:
246
  lines.append(f"✓ SAM2: {self.loading_stats['sam2_load_time']:.1f}s")
247
  lines.append(f" Model: {self.sam2_predictor.model_id}")
248
  else:
249
  lines.append("✗ SAM2: Failed to load")
250
-
251
  if self.matanyone_model:
252
- lines.append(f"✓ MatAnyone: {self.loading_stats['matanyone_load_time']:.1f}s")
 
 
 
 
 
 
 
 
 
253
  lines.append(f" Model: {self.matanyone_model.model_id}")
254
  else:
255
  lines.append("✗ MatAnyone: Failed to load")
256
-
257
  lines.append(f"Device: {self.device}")
258
-
259
  return "\n".join(lines)
260
 
261
  def cleanup(self):
@@ -269,21 +308,27 @@ def _cleanup_models(self):
269
  if self.sam2_loader:
270
  self.sam2_loader.cleanup()
271
  if self.sam2_predictor:
272
- del self.sam2_predictor
 
 
 
273
  self.sam2_predictor = None
274
-
275
  # Clean up MatAnyone
276
  if self.matanyone_loader:
277
  self.matanyone_loader.cleanup()
278
  if self.matanyone_model:
279
- del self.matanyone_model
 
 
 
280
  self.matanyone_model = None
281
-
282
  # Clear CUDA cache
283
  if torch.cuda.is_available():
284
  torch.cuda.empty_cache()
285
-
286
  # Garbage collection
287
  gc.collect()
288
-
289
- logger.debug("Model cleanup completed")
 
2
  """
3
  Unified Model Loader
4
  Coordinates separate SAM2 and MatAnyone loaders for cleaner architecture
5
+
6
+ Notes:
7
+ - SAM2: exposes set_image(...) and predict(...)
8
+ - MatAnyone: our loader returns a stateful callable adapter:
9
+ - callable(adapter) -> frame0: adapter(image_rgb01, mask01), frames>0: adapter(image_rgb01)
10
+ - optional: adapter.reset() to clear per-video memory
11
+ We therefore validate MatAnyone by checking "callable(...)" and/or presence of "reset",
12
+ not only ".step/.process".
13
  """
14
 
15
  from __future__ import annotations
 
35
 
36
  class LoadedModel:
37
  """Container for loaded model information"""
38
+ def __init__(self, model=None, model_id: str = "", load_time: float = 0.0,
39
  device: str = "", framework: str = ""):
40
  self.model = model
41
  self.model_id = model_id
 
50
  "device": self.device,
51
  "load_time": self.load_time,
52
  "loaded": self.model is not None,
53
+ "model_type": type(self.model).__name__ if self.model is not None else None,
54
  }
55
 
56
 
57
  class ModelLoader:
58
  """Main model loader that coordinates SAM2 and MatAnyone loaders"""
59
+
60
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
61
  self.device_manager = device_mgr
62
  self.memory_manager = memory_mgr
63
  self.device = self.device_manager.get_optimal_device()
64
+
65
  # Initialize specialized loaders
66
  self.sam2_loader = SAM2Loader(device=str(self.device))
67
  self.matanyone_loader = MatAnyoneLoader(device=str(self.device))
68
+
69
  # Model storage
70
  self.sam2_predictor: Optional[LoadedModel] = None
71
  self.matanyone_model: Optional[LoadedModel] = None
72
+
73
  # Statistics
74
  self.loading_stats = {
75
  "sam2_load_time": 0.0,
 
78
  "models_loaded": False,
79
  "loading_attempts": 0,
80
  }
81
+
82
  logger.info(f"ModelLoader initialized for device: {self.device}")
83
 
84
  def load_all_models(
 
88
  ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
89
  """
90
  Load all models using specialized loaders
91
+
 
 
 
 
92
  Returns:
93
  Tuple of (sam2_model, matanyone_model)
94
  """
95
  start_time = time.time()
96
  self.loading_stats["loading_attempts"] += 1
97
+
98
  try:
99
  logger.info("Starting model loading process...")
100
  if progress_callback:
101
  progress_callback(0.0, "Initializing model loading...")
102
+
103
  # Clean up any existing models
104
  self._cleanup_models()
105
+
106
+ # -------------------- Load SAM2 -------------------- #
107
  if progress_callback:
108
  progress_callback(0.1, "Loading SAM2 model...")
109
+
110
  sam2_start = time.time()
111
  sam2_model = self.sam2_loader.load()
112
  sam2_time = time.time() - sam2_start
113
+
114
  if sam2_model:
115
  self.sam2_predictor = LoadedModel(
116
  model=sam2_model,
 
123
  logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
124
  else:
125
  logger.warning("SAM2 loading failed")
126
+
127
+ # Cancellation check
128
  if cancel_event and cancel_event.is_set():
129
  if progress_callback:
130
  progress_callback(1.0, "Model loading cancelled")
131
  return self.sam2_predictor, None
132
+
133
+ # ----------------- Load MatAnyone ------------------ #
134
  if progress_callback:
135
  progress_callback(0.6, "Loading MatAnyone model...")
136
+
137
  matanyone_start = time.time()
138
+ matanyone_model = self.matanyone_loader.load() # returns stateful callable adapter or None
139
  matanyone_time = time.time() - matanyone_start
140
+
141
  if matanyone_model:
142
  self.matanyone_model = LoadedModel(
143
  model=matanyone_model,
 
150
  logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")
151
  else:
152
  logger.warning("MatAnyone loading failed")
153
+
154
+ # ----------------- Finalize stats ------------------ #
155
  total_time = time.time() - start_time
156
  self.loading_stats["total_load_time"] = total_time
157
  self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)
158
+
 
159
  if progress_callback:
160
  if self.loading_stats["models_loaded"]:
161
  progress_callback(1.0, "Models loaded successfully")
162
  else:
163
  progress_callback(1.0, "Model loading completed with failures")
164
+
165
  logger.info(f"Model loading completed in {total_time:.2f}s")
166
  return self.sam2_predictor, self.matanyone_model
167
+
168
  except Exception as e:
169
  error_msg = f"Model loading failed: {str(e)}"
170
  logger.error(error_msg)
171
  self._cleanup_models()
172
  self.loading_stats["models_loaded"] = False
173
+
174
  if progress_callback:
175
  progress_callback(1.0, f"Error: {error_msg}")
176
+
177
  return None, None
178
 
179
  def reload_models(
 
196
  return self.sam2_predictor.model if self.sam2_predictor else None
197
 
198
  def get_matanyone(self):
199
+ """
200
+ Get MatAnyone processor model.
201
+
202
+ IMPORTANT: This returns the stateful callable adapter from MatAnyoneLoader:
203
+ - callable(image_rgb01[, mask01]) -> 2D alpha
204
+ - optional .reset() to clear memory per video
205
+ """
206
  return self.matanyone_model.model if self.matanyone_model else None
207
 
208
  def validate_models(self) -> bool:
209
  """Validate that loaded models have expected interfaces"""
210
  try:
211
  valid = False
212
+
213
+ # Validate SAM2
214
  if self.sam2_predictor:
215
  model = self.sam2_predictor.model
216
  if hasattr(model, "set_image") and hasattr(model, "predict"):
217
  valid = True
218
  logger.info("SAM2 model validated")
219
+
220
+ # Validate MatAnyone (stateful adapter OR raw core)
221
  if self.matanyone_model:
222
  model = self.matanyone_model.model
223
+ if callable(model):
224
+ valid = True
225
+ logger.info("MatAnyone adapter validated (callable)")
226
+ elif hasattr(model, "step") or hasattr(model, "process"):
227
+ valid = True
228
+ logger.info("MatAnyone core validated (.step/.process)")
229
+ elif hasattr(model, "reset"):
230
+ # still accept an adapter exposing reset but not callable (unlikely)
231
  valid = True
232
+ logger.info("MatAnyone object validated via reset()")
233
+ else:
234
+ logger.warning("MatAnyone present but interface not recognized")
235
+
236
  return valid
237
+
238
  except Exception as e:
239
  logger.error(f"Model validation failed: {e}")
240
  return False
241
 
242
  def get_model_info(self) -> Dict[str, Any]:
243
  """Get detailed information about loaded models"""
244
+ info: Dict[str, Any] = {
245
  "models_loaded": self.loading_stats["models_loaded"],
246
  "device": str(self.device),
247
  "loading_stats": self.loading_stats.copy(),
248
  }
249
+
250
  # Add SAM2 info
251
  info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}
252
+
253
+ # Add MatAnyone info (augment with interface hints)
254
+ mat_info = self.matanyone_loader.get_info() if self.matanyone_loader else {}
255
+ try:
256
+ m = self.get_matanyone()
257
+ mat_info["callable"] = bool(callable(m))
258
+ mat_info["has_reset"] = bool(hasattr(m, "reset"))
259
+ mat_info["has_step"] = bool(hasattr(m, "step"))
260
+ mat_info["has_process"] = bool(hasattr(m, "process"))
261
+ except Exception:
262
+ pass
263
+ info["matanyone"] = mat_info
264
+
265
  return info
266
 
267
  def get_load_summary(self) -> str:
268
  """Get human-readable loading summary"""
269
  if not self.loading_stats["models_loaded"]:
270
  return "No models loaded"
271
+
272
  lines = []
273
  lines.append(f"Models loaded in {self.loading_stats['total_load_time']:.1f}s")
274
+
275
  if self.sam2_predictor:
276
  lines.append(f"✓ SAM2: {self.loading_stats['sam2_load_time']:.1f}s")
277
  lines.append(f" Model: {self.sam2_predictor.model_id}")
278
  else:
279
  lines.append("✗ SAM2: Failed to load")
280
+
281
  if self.matanyone_model:
282
+ # Describe adapter/callable for clarity
283
+ iface = []
284
+ m = self.matanyone_model.model
285
+ if callable(m): iface.append("callable")
286
+ if hasattr(m, "reset"): iface.append("reset")
287
+ if hasattr(m, "step"): iface.append("step")
288
+ if hasattr(m, "process"): iface.append("process")
289
+ iface_str = f" ({', '.join(iface)})" if iface else ""
290
+
291
+ lines.append(f"✓ MatAnyone: {self.loading_stats['matanyone_load_time']:.1f}s{iface_str}")
292
  lines.append(f" Model: {self.matanyone_model.model_id}")
293
  else:
294
  lines.append("✗ MatAnyone: Failed to load")
295
+
296
  lines.append(f"Device: {self.device}")
297
+
298
  return "\n".join(lines)
299
 
300
  def cleanup(self):
 
308
  if self.sam2_loader:
309
  self.sam2_loader.cleanup()
310
  if self.sam2_predictor:
311
+ try:
312
+ del self.sam2_predictor
313
+ except Exception:
314
+ pass
315
  self.sam2_predictor = None
316
+
317
  # Clean up MatAnyone
318
  if self.matanyone_loader:
319
  self.matanyone_loader.cleanup()
320
  if self.matanyone_model:
321
+ try:
322
+ del self.matanyone_model
323
+ except Exception:
324
+ pass
325
  self.matanyone_model = None
326
+
327
  # Clear CUDA cache
328
  if torch.cuda.is_available():
329
  torch.cuda.empty_cache()
330
+
331
  # Garbage collection
332
  gc.collect()
333
+
334
+ logger.debug("Model cleanup completed")