MogensR committed on
Commit 23796fb · Parent(s): 1aea709

Update models/loaders/model_loader.py

Files changed (1):
  1. models/loaders/model_loader.py +152 -505
models/loaders/model_loader.py CHANGED
@@ -1,11 +1,7 @@
 #!/usr/bin/env python3
 """
-Model Loader for Hugging Face Spaces
-- Robust SAM2 loader with multiple strategies
-- Correct MatAnyOne loader via official InferenceCore (no transformers)
-- Clean progress reporting, cleanup, and diagnostics
-- NEW: Global MatAnyOne step/process shape guard to prevent 5D tensors
-- UPDATED: Enhanced MatAnyone wrapper support for component masks
 """
 
 from __future__ import annotations
@@ -14,8 +10,6 @@
 import gc
 import time
 import logging
-import traceback
-from pathlib import Path
 from typing import Optional, Dict, Any, Tuple, Callable
 
 import torch
@@ -24,14 +18,17 @@
 from utils.hardware.device_manager import DeviceManager
 from utils.system.memory_manager import MemoryManager
 
 logger = logging.getLogger(__name__)
 
 
-# ------------------------------
-# Data wrapper
-# ------------------------------
 class LoadedModel:
-    def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, device: str = "", framework: str = ""):
         self.model = model
         self.model_id = model_id
         self.load_time = load_time
@@ -48,22 +45,23 @@ def to_dict(self) -> Dict[str, Any]:
         }
 
 
-# ------------------------------
-# Loader
-# ------------------------------
 class ModelLoader:
     def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
         self.device_manager = device_mgr
         self.memory_manager = memory_mgr
-        self.device = self.device_manager.get_optimal_device()  # e.g., cuda:0 or cpu
-
         self.sam2_predictor: Optional[LoadedModel] = None
         self.matanyone_model: Optional[LoadedModel] = None
-        self._matanyone_wrapper = None  # Cache for enhanced wrapper
-
-        self.checkpoints_dir = "./checkpoints"
-        os.makedirs(self.checkpoints_dir, exist_ok=True)
-
         self.loading_stats = {
             "sam2_load_time": 0.0,
             "matanyone_load_time": 0.0,
@@ -71,85 +69,114 @@ def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
             "models_loaded": False,
             "loading_attempts": 0,
         }
-
         logger.info(f"ModelLoader initialized for device: {self.device}")
 
-    # ---------- Public API ----------
-
     def load_all_models(
-        self, progress_callback: Optional[Callable[[float, str], None]] = None, cancel_event=None
     ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
         """
-        Loads SAM2 + MatAnyOne. Returns (LoadedModel|None, LoadedModel|None).
         """
         start_time = time.time()
         self.loading_stats["loading_attempts"] += 1
-
         try:
             logger.info("Starting model loading process...")
             if progress_callback:
                 progress_callback(0.0, "Initializing model loading...")
-
             self._cleanup_models()
-
-            # ---- SAM2 ----
-            logger.info("Loading SAM2 predictor...")
             if progress_callback:
-                progress_callback(0.1, "Loading SAM2 predictor...")
-            sam2_loaded = self._load_sam2_predictor(progress_callback)
-
-            if sam2_loaded is None:
-                logger.warning("SAM2 loading failed - a limited fallback will be used at runtime if needed.")
             else:
-                self.sam2_predictor = sam2_loaded
-                self.loading_stats["sam2_load_time"] = self.sam2_predictor.load_time
-                logger.info(f"SAM2 loaded in {self.loading_stats['sam2_load_time']:.2f}s")
-
-            # Early exit if cancelled
-            if cancel_event is not None and getattr(cancel_event, "is_set", lambda: False)():
                 if progress_callback:
                     progress_callback(1.0, "Model loading cancelled")
                 return self.sam2_predictor, None
-
-            # ---- MatAnyOne ----
-            logger.info("Loading MatAnyOne model...")
             if progress_callback:
-                progress_callback(0.6, "Loading MatAnyOne model...")
-            matanyone_loaded = self._load_matanyone(progress_callback)
-
-            if matanyone_loaded is None:
-                logger.warning("MatAnyOne loading failed - will use simple refinement fallbacks.")
             else:
-                self.matanyone_model = matanyone_loaded
-                self.loading_stats["matanyone_load_time"] = self.matanyone_model.load_time
-                logger.info(f"MatAnyOne loaded in {self.loading_stats['matanyone_load_time']:.2f}s")
-
-            # ---- Final status ----
             total_time = time.time() - start_time
             self.loading_stats["total_load_time"] = total_time
             self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)
-
             if progress_callback:
                 if self.loading_stats["models_loaded"]:
-                    progress_callback(1.0, "Models loaded (fallbacks available if any model failed)")
                 else:
-                    progress_callback(1.0, "Using fallback methods (models failed to load)")
-
             logger.info(f"Model loading completed in {total_time:.2f}s")
             return self.sam2_predictor, self.matanyone_model
-
         except Exception as e:
             error_msg = f"Model loading failed: {str(e)}"
-            logger.error(f"{error_msg}\n{traceback.format_exc()}")
             self._cleanup_models()
             self.loading_stats["models_loaded"] = False
             if progress_callback:
                 progress_callback(1.0, f"Error: {error_msg}")
             return None, None
 
-    def reload_models(self, progress_callback: Optional[Callable[[float, str], None]] = None) -> Tuple[
-        Optional[LoadedModel], Optional[LoadedModel]
-    ]:
         logger.info("Reloading models...")
         self._cleanup_models()
         self.loading_stats["models_loaded"] = False
@@ -157,486 +184,106 @@ def reload_models(self, progress_callback: Optional[Callable[[float, str], None]
 
     @property
     def models_ready(self) -> bool:
         return self.sam2_predictor is not None or self.matanyone_model is not None
 
     def get_sam2(self):
-        return self.sam2_predictor.model if self.sam2_predictor is not None else None
 
     def get_matanyone(self):
-        """Get MatAnyone processor, optionally wrapped with enhanced features."""
-        if self.matanyone_model is None:
-            return None
-
-        # Check if we should use the enhanced wrapper
-        try:
-            from app_config import get_config
-            config = get_config()
-
-            if config.matanyone_enabled and (config.use_component_masks or
-                                             config.matanyone_edge_enhancement or
-                                             config.matanyone_hair_refinement):
-                # Use enhanced wrapper for advanced features
-                try:
-                    from models.wrappers.matanyone_wrapper import MatAnyOneWrapper
-
-                    if self._matanyone_wrapper is None:
-                        self._matanyone_wrapper = MatAnyOneWrapper(
-                            self.matanyone_model.model,
-                            device=self.device,
-                            config=config.get_matanyone_config()
-                        )
-                        logger.info("Using enhanced MatAnyone wrapper with component support")
-                    return self._matanyone_wrapper
-                except ImportError as e:
-                    logger.warning(f"Enhanced MatAnyone wrapper not available: {e}")
-                except Exception as e:
-                    logger.error(f"Failed to initialize enhanced MatAnyone wrapper: {e}")
-
-        except Exception as e:
-            logger.debug(f"Could not check for enhanced wrapper configuration: {e}")
-
-        # Return raw model for basic usage
-        return self.matanyone_model.model if self.matanyone_model is not None else None
 
     def validate_models(self) -> bool:
         try:
-            ok = False
-            if self.sam2_predictor is not None:
                 model = self.sam2_predictor.model
-                if hasattr(model, "set_image") or hasattr(model, "predict"):
-                    ok = True
-            if self.matanyone_model is not None:
-                ok = True
-            return ok
         except Exception as e:
             logger.error(f"Model validation failed: {e}")
             return False
 
     def get_model_info(self) -> Dict[str, Any]:
         info = {
             "models_loaded": self.loading_stats["models_loaded"],
-            "sam2_loaded": self.sam2_predictor is not None,
-            "matanyone_loaded": self.matanyone_model is not None,
             "device": str(self.device),
             "loading_stats": self.loading_stats.copy(),
         }
-        if self.sam2_predictor is not None:
-            info["sam2_model_type"] = type(self.sam2_predictor.model).__name__
-            info["sam2_metadata"] = self.sam2_predictor.to_dict()
-        if self.matanyone_model is not None:
-            info["matanyone_model_type"] = type(self.matanyone_model.model).__name__
-            info["matanyone_metadata"] = self.matanyone_model.to_dict()
 
-        # Add wrapper status
-        info["matanyone_wrapper_active"] = self._matanyone_wrapper is not None
 
         return info
 
     def get_load_summary(self) -> str:
         if not self.loading_stats["models_loaded"]:
-            return "Models not loaded"
-        sam2_time = self.loading_stats["sam2_load_time"]
-        matanyone_time = self.loading_stats["matanyone_load_time"]
-        total_time = self.loading_stats["total_load_time"]
-        summary = f"Models loaded in {total_time:.1f}s\n"
         if self.sam2_predictor:
-            summary += f"✓ SAM2: {sam2_time:.1f}s (ID: {self.sam2_predictor.model_id})\n"
         else:
-            summary += "✗ SAM2: Failed (using fallback)\n"
         if self.matanyone_model:
-            summary += f"✓ MatAnyOne: {matanyone_time:.1f}s (ID: {self.matanyone_model.model_id})\n"
-            if self._matanyone_wrapper:
-                summary += " └─ Enhanced wrapper active\n"
         else:
-            summary += "✗ MatAnyOne: Failed (using simple refinement)\n"
-        summary += f"Device: {self.device}"
-        return summary
 
     def cleanup(self):
         self._cleanup_models()
         logger.info("ModelLoader cleanup completed")
 
-    # ---------- Internal: SAM2 ----------
-
-    def _load_sam2_predictor(self, progress_callback: Optional[Callable[[float, str], None]] = None) -> Optional[LoadedModel]:
-        """
-        Try multiple SAM2 loading strategies: official -> transformers -> dummy fallback.
-        """
-        # Choose model size heuristically
-        model_size = "large"
-        try:
-            if hasattr(self.device_manager, "get_device_memory_gb"):
-                memory_gb = self.device_manager.get_device_memory_gb()
-                if memory_gb < 4:
-                    model_size = "tiny"
-                elif memory_gb < 8:
-                    model_size = "small"
-                elif memory_gb < 12:
-                    model_size = "base"
-                logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB VRAM")
-        except Exception as e:
-            logger.warning(f"Could not determine device memory: {e}")
-            model_size = "tiny"
-
-        model_map = {
-            "tiny": "facebook/sam2.1-hiera-tiny",
-            "small": "facebook/sam2.1-hiera-small",
-            "base": "facebook/sam2.1-hiera-base-plus",
-            "large": "facebook/sam2.1-hiera-large",
-        }
-        model_id = model_map.get(model_size, model_map["tiny"])
-
-        if progress_callback:
-            progress_callback(0.3, f"Loading SAM2 ({model_size})...")
-
-        methods = [
-            ("official", self._try_load_sam2_official, model_id),
-            ("direct", self._try_load_sam2_direct, model_id),
-            ("manual", self._try_load_sam2_manual, model_id),
-        ]
-
-        for name, fn, mid in methods:
-            try:
-                logger.info(f"Attempting SAM2 load via {name} method ({mid})...")
-                result = fn(mid)
-                if result is not None:
-                    logger.info(f"SAM2 loaded successfully via {name} method")
-                    return result
-            except Exception as e:
-                logger.error(f"SAM2 {name} method failed: {e}")
-                logger.debug(traceback.format_exc())
-                continue
-
-        logger.error("All SAM2 loading methods failed")
-        return None
-
-    def _try_load_sam2_official(self, model_id: str) -> Optional[LoadedModel]:
-        """
-        Official predictor path (Meta's SAM2ImagePredictor).
-        """
-        from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-        # Space-specific hub flags
-        os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
-        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
-
-        cache_dir = os.path.join(self.checkpoints_dir, "sam2_cache")
-        os.makedirs(cache_dir, exist_ok=True)
-
-        t0 = time.time()
-        predictor = SAM2ImagePredictor.from_pretrained(
-            model_id,
-            cache_dir=cache_dir,
-            local_files_only=False,
-            trust_remote_code=True,
-        )
-        if hasattr(predictor, "model"):
-            predictor.model = predictor.model.to(self.device)
-        t1 = time.time()
-
-        return LoadedModel(
-            model=predictor, model_id=model_id, load_time=t1 - t0, device=str(self.device), framework="sam2"
-        )
-
-    def _try_load_sam2_direct(self, model_id: str) -> Optional[LoadedModel]:
-        """
-        Transformers AutoModel path (best-effort; API may vary).
-        """
-        from transformers import AutoModel, AutoProcessor
-
-        t0 = time.time()
-        model = AutoModel.from_pretrained(
-            model_id,
-            trust_remote_code=True,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-        ).to(self.device)
-
-        try:
-            processor = AutoProcessor.from_pretrained(model_id)
-        except Exception:
-            processor = None
-
-        t1 = time.time()
-
-        class SAM2Wrapper:
-            def __init__(self, model, processor=None):
-                self.model = model
-                self.processor = processor
-
-            def set_image(self, image):
-                self.current_image = image
-
-            def predict(self, *args, **kwargs):
-                return self.model(*args, **kwargs)
-
-        wrapped = SAM2Wrapper(model, processor)
-
-        return LoadedModel(
-            model=wrapped,
-            model_id=model_id,
-            load_time=t1 - t0,
-            device=str(self.device),
-            framework="sam2-transformers",
-        )
-
-    def _try_load_sam2_manual(self, model_id: str) -> Optional[LoadedModel]:
-        """
-        Dummy fallback that won't crash the app.
-        """
-        class DummySAM2:
-            def __init__(self, device):
-                self.device = device
-                self.model = None
-
-            def set_image(self, image):
-                self.current_image = image
-
-            def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
-                import numpy as np
-                if hasattr(self, "current_image"):
-                    h, w = self.current_image.shape[:2]
-                else:
-                    h, w = 512, 512
-                return {
-                    "masks": np.ones((1, h, w), dtype=np.float32),
-                    "scores": np.array([0.5]),
-                    "logits": np.ones((1, h, w), dtype=np.float32),
-                }
-
-        t0 = time.time()
-        dummy = DummySAM2(self.device)
-        t1 = time.time()
-
-        logger.warning("Using manual SAM2 fallback (limited functionality)")
-        return LoadedModel(
-            model=dummy, model_id=f"{model_id}-fallback", load_time=t1 - t0, device=str(self.device), framework="sam2-fallback"
-        )
-
-    # ---------- Internal: MatAnyOne ----------
-
-    def _load_matanyone(self, progress_callback: Optional[Callable[[float, str], None]] = None) -> Optional[LoadedModel]:
-        """
-        Correct MatAnyOne loader using official package API.
-        """
-        if progress_callback:
-            progress_callback(0.7, "Loading MatAnyOne (InferenceCore)...")
-        try:
-            return self._try_load_matanyone_official()
-        except Exception as e:
-            logger.error(f"MatAnyOne official loader failed: {e}")
-            logger.debug(traceback.format_exc())
-            logger.warning("Falling back to simple MatAnyOne placeholder.")
-            return self._try_load_matanyone_fallback()
-
-    def _try_load_matanyone_official(self) -> Optional[LoadedModel]:
-        """
-        Official MatAnyOne via package's InferenceCore.
-        IMPORTANT: pass model id POSITIONALLY; do NOT use repo_id= or transformers.
-        Also: install a shape guard so every call is safe (no 5D tensors).
-        """
-        from matanyone import InferenceCore
-
-        t0 = time.time()
-        processor = InferenceCore("PeiqingYang/MatAnyone")
-
-        # ------------------- BEGIN: GLOBAL SHAPE GUARD PATCH -------------------
-        try:
-            # Lazy import coercers; provide minimal fallbacks if missing.
-            try:
-                from utils.interop import (
-                    ensure_image_nchw,
-                    ensure_mask_for_matanyone,
-                    log_shape,
-                )
-            except Exception as imp_err:
-                logger.warning(f"utils.interop not available ({imp_err}); using minimal inline coercers")
-
-                def log_shape(tag: str, t: torch.Tensor) -> None:
-                    try:
-                        mn = float(t.min()) if t.numel() else float("nan")
-                        mx = float(t.max()) if t.numel() else float("nan")
-                        print(f"[MatAny.guard] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
-                              f"range=[{mn:.4f},{mx:.4f}]")
-                    except Exception:
-                        pass
-
-                def _to_float01(x: torch.Tensor) -> torch.Tensor:
-                    x = x.to(torch.float32)
-                    if x.max() > 1.0:
-                        x = x / 255.0
-                    return x.clamp_(0.0, 1.0)
-
-                def _squeeze_bt(x: torch.Tensor) -> torch.Tensor:
-                    if x.ndim == 5:
-                        # (B,T,C,H,W) → drop T if 1
-                        if x.shape[1] == 1:
-                            x = x.squeeze(1)
-                        if x.ndim == 5 and x.shape[0] == 1:
-                            x = x.squeeze(0)
-                    if x.ndim == 4 and x.shape[0] == 1 and x.shape[1] == 1 and x.shape[-3] == 3:
-                        x = x.squeeze(1)
-                    return x
-
-                def ensure_image_nchw(img: torch.Tensor, device=self.device, want_batched: bool = True) -> torch.Tensor:
-                    img = img.to(device)
-                    img = _squeeze_bt(img)
-                    if img.ndim == 3:
-                        # CHW or HWC
-                        if img.shape[0] in (1, 3):
-                            chw = img
-                        else:
-                            chw = img.permute(2, 0, 1)
-                        chw = _to_float01(chw.contiguous())
-                        return chw.unsqueeze(0) if want_batched else chw
-                    if img.ndim == 4:
-                        N, A, B, C = img.shape
-                        if A == 3:
-                            nchw = img
-                        elif C == 3:
-                            nchw = img.permute(0, 3, 1, 2)
-                        else:
-                            raise AssertionError(f"Cannot infer channels in image: {tuple(img.shape)}")
-                        nchw = _to_float01(nchw.contiguous())
-                        return nchw if want_batched else nchw[0]
-                    raise AssertionError(f"Bad image dims: {tuple(img.shape)}")
-
-                def ensure_mask_for_matanyone(mask: torch.Tensor, *, idx_mask: bool = False,
-                                              threshold: float = 0.5, keep_soft: bool = False,
-                                              device=self.device) -> torch.Tensor:
-                    mask = mask.to(device)
-                    mask = _squeeze_bt(mask)
-                    if idx_mask:
-                        if mask.ndim == 3:
-                            if mask.shape[0] == 1:
-                                idx = (mask[0] >= threshold).to(torch.long)
-                            else:
-                                idx = torch.argmax(mask, dim=0).to(torch.long)
-                                idx = (idx > 0).to(torch.long)
-                        elif mask.ndim == 2:
-                            idx = (mask >= threshold).to(torch.long)
-                        else:
-                            raise AssertionError(f"idx mask must be 2D or 3D; got {tuple(mask.shape)}")
-                        return idx
-                    # channel mask
-                    if mask.ndim == 2:
-                        out = mask.unsqueeze(0)
-                    elif mask.ndim == 3:
-                        if mask.shape[0] == 1:
-                            out = mask
-                        else:
-                            areas = mask.sum(dim=(-2, -1))
-                            out = mask[areas.argmax():areas.argmax()+1]
-                    else:
-                        raise AssertionError(f"mask must be 2D/3D; got {tuple(mask.shape)}")
-                    out = out.to(torch.float32)
-                    if not keep_soft:
-                        out = (out >= threshold).to(torch.float32)
-                    return out.clamp_(0.0, 1.0).contiguous()
-
-            def _guarded_factory(core_obj, method_name: str):
-                core_step = getattr(core_obj, method_name)
-
-                def wrapped_step(*args, **kwargs):
-                    # Extract image/mask/idx_mask whether passed positionally or by name
-                    image = kwargs.get("image", None)
-                    mask = kwargs.get("mask", None)
-                    idx_mask = kwargs.get("idx_mask", kwargs.get("index_mask", False))
-
-                    # Positional fallback guess: (image, mask, ...)
-                    if image is None and len(args) >= 1:
-                        image = args[0]
-                    if mask is None and len(args) >= 2:
-                        mask = args[1]
-
-                    # Coerce shapes
-                    img_nchw = ensure_image_nchw(image, device=self.device, want_batched=True)
-                    log_shape("image_nchw", img_nchw)
-
-                    if idx_mask:
-                        m_fixed = ensure_mask_for_matanyone(mask, idx_mask=True, device=img_nchw.device)
-                        log_shape("idx_hw", m_fixed)
-                    else:
-                        m_fixed = ensure_mask_for_matanyone(mask, idx_mask=False, threshold=0.5, keep_soft=False, device=img_nchw.device)
-                        log_shape("mask_c_hw", m_fixed)
-
-                    # Rebuild kwargs without duplicates
-                    new_kwargs = dict(kwargs)
-                    new_kwargs["idx_mask"] = bool(idx_mask)
-                    new_kwargs["image"] = img_nchw[0]  # common: CHW image
-
-                    if idx_mask:
-                        new_kwargs["mask"] = m_fixed  # (H,W) long
-                    else:
-                        new_kwargs["mask"] = m_fixed  # (1,H,W) float
-
-                    # Try unbatched first, then batched fallback if needed
-                    try:
-                        return core_step(**new_kwargs)
-                    except Exception as e1:
-                        logger.debug(f"MatAnyOne step (CHW) failed, retrying batched NCHW: {e1}")
-                        new_kwargs["image"] = img_nchw  # (1,3,H,W)
-                        try:
-                            return core_step(**new_kwargs)
-                        except Exception as e2:
-                            logger.error(f"MatAnyOne guarded call failed (both modes). Last error: {e2}")
-                            raise
-
-                return wrapped_step
-
-            if hasattr(processor, "step"):
-                processor.step = _guarded_factory(processor, "step")
-                logger.info("Patched MatAnyOne InferenceCore.step with shape guard")
-            if hasattr(processor, "process"):
-                processor.process = _guarded_factory(processor, "process")
-                logger.info("Patched MatAnyOne InferenceCore.process with shape guard")
-        except Exception as guard_err:
-            logger.warning(f"Could not install MatAnyOne guard: {guard_err}")
-        # -------------------- END: GLOBAL SHAPE GUARD PATCH --------------------
-
-        t1 = time.time()
-
-        return LoadedModel(
-            model=processor,
-            model_id="PeiqingYang/MatAnyone",
-            load_time=t1 - t0,
-            device=str(self.device),
-            framework="matanyone",
-        )
-
-    def _try_load_matanyone_fallback(self) -> Optional[LoadedModel]:
-        """
-        Minimal placeholder that safely passes masks through.
-        """
-        class FallbackMatAnyone:
-            def __init__(self, device):
-                self.device = device
-
-            def process(self, image, mask, **kwargs):
-                # Identity pass-through (keeps pipeline alive)
-                return mask
-
-        t0 = time.time()
-        model = FallbackMatAnyone(self.device)
-        t1 = time.time()
-
-        logger.warning("Using MatAnyOne fallback (limited functionality)")
-        return LoadedModel(
-            model=model, model_id="MatAnyone-fallback", load_time=t1 - t0, device=str(self.device), framework="matanyone-fallback"
-        )
-
-    # ---------- Internal: cleanup ----------
-
     def _cleanup_models(self):
-        if self.sam2_predictor is not None:
             del self.sam2_predictor
             self.sam2_predictor = None
-        if self.matanyone_model is not None:
             del self.matanyone_model
             self.matanyone_model = None
-        if self._matanyone_wrapper is not None:
-            del self._matanyone_wrapper
-            self._matanyone_wrapper = None
         if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Model cleanup completed")
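The replacement file below delegates all loading strategy to two specialized modules, models/loaders/sam2_loader.py and models/loaders/matanyone_loader.py, neither of which is included in this commit. The new code only references load(), model_id, get_info(), and cleanup() on each loader; a minimal sketch of that assumed interface follows, with hypothetical placeholder bodies rather than the actual implementations:

# Hypothetical sketch of the loader interface assumed by the new ModelLoader.
# Names are inferred from the calls in the diff below; bodies are placeholders.
from typing import Any, Dict, Optional

class SAM2Loader:
    def __init__(self, device: str = "cpu"):
        self.device = device
        self.model_id = ""  # filled in by load(); real default unknown
        self._model: Optional[Any] = None

    def load(self) -> Optional[Any]:
        """Return a ready predictor, or None on failure."""
        raise NotImplementedError

    def get_info(self) -> Dict[str, Any]:
        """Metadata consumed by ModelLoader.get_model_info()."""
        return {"model_id": self.model_id, "loaded": self._model is not None}

    def cleanup(self) -> None:
        """Release the underlying model and free memory."""
        self._model = None

MatAnyoneLoader would expose the same four members.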
 
 #!/usr/bin/env python3
 """
+Unified Model Loader
+Coordinates separate SAM2 and MatAnyone loaders for cleaner architecture
 """
 
 from __future__ import annotations
 import gc
 import time
 import logging
 from typing import Optional, Dict, Any, Tuple, Callable
 
 import torch
 from utils.hardware.device_manager import DeviceManager
 from utils.system.memory_manager import MemoryManager
 
+# Import the specialized loaders
+from models.loaders.sam2_loader import SAM2Loader
+from models.loaders.matanyone_loader import MatAnyoneLoader
+
 logger = logging.getLogger(__name__)
 
 
 class LoadedModel:
+    """Container for loaded model information"""
+    def __init__(self, model=None, model_id: str = "", load_time: float = 0.0,
+                 device: str = "", framework: str = ""):
         self.model = model
         self.model_id = model_id
         self.load_time = load_time
         }
 
 
 class ModelLoader:
+    """Main model loader that coordinates SAM2 and MatAnyone loaders"""
+
     def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
         self.device_manager = device_mgr
         self.memory_manager = memory_mgr
+        self.device = self.device_manager.get_optimal_device()
+
+        # Initialize specialized loaders
+        self.sam2_loader = SAM2Loader(device=str(self.device))
+        self.matanyone_loader = MatAnyoneLoader(device=str(self.device))
+
+        # Model storage
         self.sam2_predictor: Optional[LoadedModel] = None
         self.matanyone_model: Optional[LoadedModel] = None
+
+        # Statistics
         self.loading_stats = {
             "sam2_load_time": 0.0,
             "matanyone_load_time": 0.0,
             "models_loaded": False,
             "loading_attempts": 0,
         }
+
         logger.info(f"ModelLoader initialized for device: {self.device}")
 
     def load_all_models(
+        self,
+        progress_callback: Optional[Callable[[float, str], None]] = None,
+        cancel_event=None
     ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
         """
+        Load all models using specialized loaders
+
+        Args:
+            progress_callback: Optional callback for progress updates
+            cancel_event: Optional threading.Event for cancellation
+
+        Returns:
+            Tuple of (sam2_model, matanyone_model)
         """
         start_time = time.time()
         self.loading_stats["loading_attempts"] += 1
+
         try:
             logger.info("Starting model loading process...")
             if progress_callback:
                 progress_callback(0.0, "Initializing model loading...")
+
+            # Clean up any existing models
             self._cleanup_models()
+
+            # Load SAM2
             if progress_callback:
+                progress_callback(0.1, "Loading SAM2 model...")
+
+            sam2_start = time.time()
+            sam2_model = self.sam2_loader.load()
+            sam2_time = time.time() - sam2_start
+
+            if sam2_model:
+                self.sam2_predictor = LoadedModel(
+                    model=sam2_model,
+                    model_id=self.sam2_loader.model_id,
+                    load_time=sam2_time,
+                    device=str(self.device),
+                    framework="sam2"
+                )
+                self.loading_stats["sam2_load_time"] = sam2_time
+                logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
             else:
+                logger.warning("SAM2 loading failed")
+
+            # Check for cancellation
+            if cancel_event and cancel_event.is_set():
                 if progress_callback:
                     progress_callback(1.0, "Model loading cancelled")
                 return self.sam2_predictor, None
+
+            # Load MatAnyone
             if progress_callback:
+                progress_callback(0.6, "Loading MatAnyone model...")
+
+            matanyone_start = time.time()
+            matanyone_model = self.matanyone_loader.load()
+            matanyone_time = time.time() - matanyone_start
+
+            if matanyone_model:
+                self.matanyone_model = LoadedModel(
+                    model=matanyone_model,
+                    model_id=self.matanyone_loader.model_id,
+                    load_time=matanyone_time,
+                    device=str(self.device),
+                    framework="matanyone"
+                )
+                self.loading_stats["matanyone_load_time"] = matanyone_time
+                logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")
             else:
+                logger.warning("MatAnyone loading failed")
+
+            # Update statistics
             total_time = time.time() - start_time
             self.loading_stats["total_load_time"] = total_time
             self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)
+
+            # Final progress update
             if progress_callback:
                 if self.loading_stats["models_loaded"]:
+                    progress_callback(1.0, "Models loaded successfully")
                 else:
+                    progress_callback(1.0, "Model loading completed with failures")
+
             logger.info(f"Model loading completed in {total_time:.2f}s")
             return self.sam2_predictor, self.matanyone_model
+
         except Exception as e:
             error_msg = f"Model loading failed: {str(e)}"
+            logger.error(error_msg)
             self._cleanup_models()
             self.loading_stats["models_loaded"] = False
+
             if progress_callback:
                 progress_callback(1.0, f"Error: {error_msg}")
+
             return None, None
 
+    def reload_models(
+        self,
+        progress_callback: Optional[Callable[[float, str], None]] = None
+    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
+        """Reload all models from scratch"""
         logger.info("Reloading models...")
         self._cleanup_models()
         self.loading_stats["models_loaded"] = False
 
     @property
     def models_ready(self) -> bool:
+        """Check if any models are loaded and ready"""
         return self.sam2_predictor is not None or self.matanyone_model is not None
 
     def get_sam2(self):
+        """Get SAM2 predictor model"""
+        return self.sam2_predictor.model if self.sam2_predictor else None
 
     def get_matanyone(self):
+        """Get MatAnyone processor model"""
+        return self.matanyone_model.model if self.matanyone_model else None
 
     def validate_models(self) -> bool:
+        """Validate that loaded models have expected interfaces"""
         try:
+            valid = False
+
+            if self.sam2_predictor:
                 model = self.sam2_predictor.model
+                if hasattr(model, "set_image") and hasattr(model, "predict"):
+                    valid = True
+                    logger.info("SAM2 model validated")
+
+            if self.matanyone_model:
+                model = self.matanyone_model.model
+                if hasattr(model, "step") or hasattr(model, "process"):
+                    valid = True
+                    logger.info("MatAnyone model validated")
+
+            return valid
+
         except Exception as e:
             logger.error(f"Model validation failed: {e}")
             return False
 
     def get_model_info(self) -> Dict[str, Any]:
+        """Get detailed information about loaded models"""
         info = {
             "models_loaded": self.loading_stats["models_loaded"],
             "device": str(self.device),
             "loading_stats": self.loading_stats.copy(),
         }
 
+        # Add SAM2 info
+        info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}
+
+        # Add MatAnyone info
+        info["matanyone"] = self.matanyone_loader.get_info() if self.matanyone_loader else {}
 
         return info
 
     def get_load_summary(self) -> str:
+        """Get human-readable loading summary"""
         if not self.loading_stats["models_loaded"]:
+            return "No models loaded"
+
+        lines = []
+        lines.append(f"Models loaded in {self.loading_stats['total_load_time']:.1f}s")
+
         if self.sam2_predictor:
+            lines.append(f"✓ SAM2: {self.loading_stats['sam2_load_time']:.1f}s")
+            lines.append(f" Model: {self.sam2_predictor.model_id}")
         else:
+            lines.append("✗ SAM2: Failed to load")
+
         if self.matanyone_model:
+            lines.append(f"✓ MatAnyone: {self.loading_stats['matanyone_load_time']:.1f}s")
+            lines.append(f" Model: {self.matanyone_model.model_id}")
         else:
+            lines.append("✗ MatAnyone: Failed to load")
+
+        lines.append(f"Device: {self.device}")
+
+        return "\n".join(lines)
 
     def cleanup(self):
+        """Clean up all resources"""
         self._cleanup_models()
         logger.info("ModelLoader cleanup completed")
 
     def _cleanup_models(self):
+        """Internal cleanup of loaded models"""
+        # Clean up SAM2
+        if self.sam2_loader:
+            self.sam2_loader.cleanup()
+        if self.sam2_predictor:
             del self.sam2_predictor
             self.sam2_predictor = None
+
+        # Clean up MatAnyone
+        if self.matanyone_loader:
+            self.matanyone_loader.cleanup()
+        if self.matanyone_model:
             del self.matanyone_model
             self.matanyone_model = None
+
+        # Clear CUDA cache
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+
+        # Garbage collection
         gc.collect()
+
         logger.debug("Model cleanup completed")
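For reference, a usage sketch of the refactored loader. The argument-free DeviceManager() and MemoryManager() constructors are assumptions for illustration; their real signatures are not part of this commit:

# Hypothetical usage of the new ModelLoader; only load_all_models(),
# get_load_summary(), and cleanup() are taken from the diff above.
import threading

from models.loaders.model_loader import ModelLoader
from utils.hardware.device_manager import DeviceManager
from utils.system.memory_manager import MemoryManager

def on_progress(fraction: float, message: str) -> None:
    # The loader reports fractions in [0.0, 1.0] plus a status string
    print(f"[{fraction:4.0%}] {message}")

cancel = threading.Event()  # call cancel.set() from another thread to stop after SAM2
loader = ModelLoader(DeviceManager(), MemoryManager())
sam2, matanyone = loader.load_all_models(progress_callback=on_progress, cancel_event=cancel)
print(loader.get_load_summary())
loader.cleanup()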