MogensR commited on
Commit
31653b7
·
1 Parent(s): cc63301

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +265 -61
model_loader.py CHANGED
@@ -9,7 +9,9 @@
9
 
10
  import os
11
  import gc
 
12
  import time
 
13
  import logging
14
  import tempfile
15
  import traceback
@@ -27,6 +29,171 @@
27
 
28
  logger = logging.getLogger(__name__)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # ============================================================================ #
31
  # MODEL LOADER CLASS - MAIN INTERFACE
32
  # ============================================================================ #
@@ -169,86 +336,123 @@ def load_all_models(self, progress_callback: Optional[callable] = None, cancel_e
169
  return None, None
170
 
171
  # ============================================================================ #
172
- # SAM2 MODEL LOADING - AUTOMATIC CONFIG DETECTION
173
  # ============================================================================ #
174
 
175
  def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
176
  """
177
- Load SAM2 predictor with automatic config detection - no manual config files needed
178
- Uses build_sam2_video_predictor for automatic configuration based on checkpoint filename
179
 
180
  Args:
181
  progress_callback: Progress update callback
182
 
183
  Returns:
184
- SAM2VideoPredictor or None
185
  """
186
- def try_load_sam2_auto(repo_id: str, filename: str, model_name: str):
187
- """Attempt to load SAM2 with automatic config detection"""
188
- try:
189
- checkpoint_path = os.path.join(self.checkpoints_dir, filename)
190
- logger.info(f"Attempting SAM2 checkpoint: {checkpoint_path}")
191
-
192
- # Download checkpoint if needed
193
- if not os.path.exists(checkpoint_path):
194
- logger.info(f"Downloading {filename} from Hugging Face Hub...")
195
- if progress_callback:
196
- progress_callback(0.2, f"Downloading {filename}...")
197
-
198
- try:
199
- from huggingface_hub import hf_hub_download
200
- checkpoint_path = hf_hub_download(
201
- repo_id=repo_id,
202
- filename=filename,
203
- cache_dir=self.checkpoints_dir,
204
- local_dir_use_symlinks=False
205
- )
206
- logger.info(f"Download complete: {checkpoint_path}")
207
- except Exception as download_error:
208
- logger.warning(f"Failed to download {filename}: {download_error}")
209
- return None
210
-
211
- if progress_callback:
212
- progress_callback(0.4, f"Building SAM2 {model_name}...")
213
-
214
- # Use automatic config detection - NO manual config needed!
215
- from sam2.build_sam import build_sam2_video_predictor
216
-
217
- predictor = build_sam2_video_predictor(checkpoint_path, device=self.device)
218
-
219
- logger.info(f"SAM2 {model_name} loaded successfully on {self.device}")
220
- return predictor
221
-
222
- except Exception as e:
223
- error_msg = f"Failed to load SAM2 {model_name}: {e}"
224
- logger.warning(error_msg)
225
- return None
226
-
227
- # Try different SAM2 models with automatic config detection
228
- model_attempts = [
229
- ("facebook/sam2-hiera-large", "sam2_hiera_large.pt", "hiera_large"),
230
- ("facebook/sam2-hiera-base-plus", "sam2_hiera_base_plus.pt", "hiera_base_plus"),
231
- ("facebook/sam2-hiera-small", "sam2_hiera_small.pt", "hiera_small"),
232
- ("facebook/sam2-hiera-tiny", "sam2_hiera_tiny.pt", "hiera_tiny")
233
- ]
234
 
235
- # Prioritize model size based on device memory
 
 
 
236
  if hasattr(self.device_manager, 'get_device_memory_gb'):
237
  try:
238
  memory_gb = self.device_manager.get_device_memory_gb()
239
  if memory_gb < 4:
240
- model_attempts = model_attempts[2:] # Only tiny and small
241
  elif memory_gb < 8:
242
- model_attempts = model_attempts[1:] # Skip large
 
243
  except Exception as e:
244
  logger.warning(f"Could not determine device memory: {e}")
245
 
246
- for repo_id, filename, model_name in model_attempts:
247
- predictor = try_load_sam2_auto(repo_id, filename, model_name)
248
- if predictor is not None:
249
- return predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- logger.error("All SAM2 model loading attempts failed")
252
  return None
253
 
254
  # ============================================================================ #
@@ -411,7 +615,7 @@ def get_model_info(self) -> Dict[str, Any]:
411
 
412
  if self.sam2_predictor is not None:
413
  try:
414
- info['sam2_model_type'] = type(self.sam2_predictor.model).__name__
415
  except:
416
  info['sam2_model_type'] = "Unknown"
417
 
 
9
 
10
  import os
11
  import gc
12
+ import sys
13
  import time
14
+ import shutil
15
  import logging
16
  import tempfile
17
  import traceback
 
29
 
30
  logger = logging.getLogger(__name__)
31
 
32
+ # ============================================================================ #
33
+ # HARD CACHE CLEANER
34
+ # ============================================================================ #
35
+
36
+ class HardCacheCleaner:
37
+ """
38
+ Comprehensive cache cleaning system to resolve SAM2 loading issues
39
+ Clears Python module cache, HuggingFace cache, and temp files
40
+ """
41
+
42
+ @staticmethod
43
+ def clean_all_caches(verbose: bool = True):
44
+ """Clean all caches that might interfere with SAM2 loading"""
45
+
46
+ if verbose:
47
+ logger.info("Starting comprehensive cache cleanup...")
48
+
49
+ # 1. Clean Python module cache
50
+ HardCacheCleaner._clean_python_cache(verbose)
51
+
52
+ # 2. Clean HuggingFace cache
53
+ HardCacheCleaner._clean_huggingface_cache(verbose)
54
+
55
+ # 3. Clean PyTorch cache
56
+ HardCacheCleaner._clean_pytorch_cache(verbose)
57
+
58
+ # 4. Clean temp directories
59
+ HardCacheCleaner._clean_temp_directories(verbose)
60
+
61
+ # 5. Clear import cache
62
+ HardCacheCleaner._clear_import_cache(verbose)
63
+
64
+ # 6. Force garbage collection
65
+ HardCacheCleaner._force_gc_cleanup(verbose)
66
+
67
+ if verbose:
68
+ logger.info("Cache cleanup completed")
69
+
70
+ @staticmethod
71
+ def _clean_python_cache(verbose: bool = True):
72
+ """Clean Python bytecode cache"""
73
+ try:
74
+ # Clear sys.modules cache for SAM2 related modules
75
+ sam2_modules = [key for key in sys.modules.keys() if 'sam2' in key.lower()]
76
+ for module in sam2_modules:
77
+ if verbose:
78
+ logger.info(f"Removing cached module: {module}")
79
+ del sys.modules[module]
80
+
81
+ # Clear __pycache__ directories
82
+ for root, dirs, files in os.walk("."):
83
+ for dir_name in dirs[:]: # Use slice to modify list during iteration
84
+ if dir_name == "__pycache__":
85
+ cache_path = os.path.join(root, dir_name)
86
+ if verbose:
87
+ logger.info(f"Removing __pycache__: {cache_path}")
88
+ shutil.rmtree(cache_path, ignore_errors=True)
89
+ dirs.remove(dir_name)
90
+
91
+ except Exception as e:
92
+ logger.warning(f"Python cache cleanup failed: {e}")
93
+
94
+ @staticmethod
95
+ def _clean_huggingface_cache(verbose: bool = True):
96
+ """Clean HuggingFace model cache"""
97
+ try:
98
+ cache_paths = [
99
+ os.path.expanduser("~/.cache/huggingface/"),
100
+ os.path.expanduser("~/.cache/torch/"),
101
+ "./checkpoints/",
102
+ "./.cache/",
103
+ ]
104
+
105
+ for cache_path in cache_paths:
106
+ if os.path.exists(cache_path):
107
+ if verbose:
108
+ logger.info(f"Cleaning cache directory: {cache_path}")
109
+
110
+ # Remove SAM2 specific files
111
+ for root, dirs, files in os.walk(cache_path):
112
+ for file in files:
113
+ if any(pattern in file.lower() for pattern in ['sam2', 'segment-anything-2']):
114
+ file_path = os.path.join(root, file)
115
+ try:
116
+ os.remove(file_path)
117
+ if verbose:
118
+ logger.info(f"Removed cached file: {file_path}")
119
+ except:
120
+ pass
121
+
122
+ for dir_name in dirs[:]:
123
+ if any(pattern in dir_name.lower() for pattern in ['sam2', 'segment-anything-2']):
124
+ dir_path = os.path.join(root, dir_name)
125
+ try:
126
+ shutil.rmtree(dir_path, ignore_errors=True)
127
+ if verbose:
128
+ logger.info(f"Removed cached directory: {dir_path}")
129
+ dirs.remove(dir_name)
130
+ except:
131
+ pass
132
+
133
+ except Exception as e:
134
+ logger.warning(f"HuggingFace cache cleanup failed: {e}")
135
+
136
+ @staticmethod
137
+ def _clean_pytorch_cache(verbose: bool = True):
138
+ """Clean PyTorch cache"""
139
+ try:
140
+ import torch
141
+ if torch.cuda.is_available():
142
+ torch.cuda.empty_cache()
143
+ if verbose:
144
+ logger.info("Cleared PyTorch CUDA cache")
145
+ except Exception as e:
146
+ logger.warning(f"PyTorch cache cleanup failed: {e}")
147
+
148
+ @staticmethod
149
+ def _clean_temp_directories(verbose: bool = True):
150
+ """Clean temporary directories"""
151
+ try:
152
+ temp_dirs = [tempfile.gettempdir(), "/tmp", "./tmp", "./temp"]
153
+
154
+ for temp_dir in temp_dirs:
155
+ if os.path.exists(temp_dir):
156
+ for item in os.listdir(temp_dir):
157
+ if 'sam2' in item.lower() or 'segment' in item.lower():
158
+ item_path = os.path.join(temp_dir, item)
159
+ try:
160
+ if os.path.isfile(item_path):
161
+ os.remove(item_path)
162
+ elif os.path.isdir(item_path):
163
+ shutil.rmtree(item_path, ignore_errors=True)
164
+ if verbose:
165
+ logger.info(f"Removed temp item: {item_path}")
166
+ except:
167
+ pass
168
+
169
+ except Exception as e:
170
+ logger.warning(f"Temp directory cleanup failed: {e}")
171
+
172
+ @staticmethod
173
+ def _clear_import_cache(verbose: bool = True):
174
+ """Clear Python import cache"""
175
+ try:
176
+ import importlib
177
+
178
+ # Invalidate import caches
179
+ importlib.invalidate_caches()
180
+
181
+ if verbose:
182
+ logger.info("Cleared Python import cache")
183
+
184
+ except Exception as e:
185
+ logger.warning(f"Import cache cleanup failed: {e}")
186
+
187
+ @staticmethod
188
+ def _force_gc_cleanup(verbose: bool = True):
189
+ """Force garbage collection"""
190
+ try:
191
+ collected = gc.collect()
192
+ if verbose:
193
+ logger.info(f"Garbage collection freed {collected} objects")
194
+ except Exception as e:
195
+ logger.warning(f"Garbage collection failed: {e}")
196
+
197
  # ============================================================================ #
198
  # MODEL LOADER CLASS - MAIN INTERFACE
199
  # ============================================================================ #
 
336
  return None, None
337
 
338
  # ============================================================================ #
339
+ # SAM2 MODEL LOADING - HUGGINGFACE TRANSFORMERS APPROACH
340
  # ============================================================================ #
341
 
342
  def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
343
  """
344
+ Load SAM2 using HuggingFace Transformers integration with cache cleanup
345
+ This method works reliably on HuggingFace Spaces without config file issues
346
 
347
  Args:
348
  progress_callback: Progress update callback
349
 
350
  Returns:
351
+ SAM2 model or None
352
  """
353
+ logger.info("=== USING NEW HF TRANSFORMERS SAM2 LOADER ===")
354
+
355
+ # Step 1: Clean caches before loading
356
+ if progress_callback:
357
+ progress_callback(0.15, "Cleaning caches...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
+ HardCacheCleaner.clean_all_caches(verbose=True)
360
+
361
+ # Step 2: Determine model size based on device memory
362
+ model_size = "large" # default
363
  if hasattr(self.device_manager, 'get_device_memory_gb'):
364
  try:
365
  memory_gb = self.device_manager.get_device_memory_gb()
366
  if memory_gb < 4:
367
+ model_size = "tiny"
368
  elif memory_gb < 8:
369
+ model_size = "base"
370
+ logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
371
  except Exception as e:
372
  logger.warning(f"Could not determine device memory: {e}")
373
 
374
+ # Step 3: Try multiple HuggingFace approaches
375
+ model_map = {
376
+ "tiny": "facebook/sam2.1-hiera-tiny",
377
+ "small": "facebook/sam2.1-hiera-small",
378
+ "base": "facebook/sam2.1-hiera-base-plus",
379
+ "large": "facebook/sam2.1-hiera-large"
380
+ }
381
+
382
+ model_id = model_map.get(model_size, model_map["large"])
383
+
384
+ if progress_callback:
385
+ progress_callback(0.3, f"Loading SAM2 {model_size}...")
386
+
387
+ # Method 1: HuggingFace Transformers Pipeline (most reliable)
388
+ try:
389
+ logger.info("Trying Transformers pipeline approach...")
390
+ from transformers import pipeline
391
+
392
+ sam2_pipeline = pipeline(
393
+ "mask-generation",
394
+ model=model_id,
395
+ device=0 if str(self.device) == "cuda" else -1
396
+ )
397
+
398
+ logger.info("SAM2 loaded successfully via Transformers pipeline")
399
+ return sam2_pipeline
400
+
401
+ except Exception as e:
402
+ logger.warning(f"Pipeline approach failed: {e}")
403
+
404
+ # Method 2: Direct Transformers classes
405
+ try:
406
+ logger.info("Trying direct Transformers classes...")
407
+ from transformers import Sam2Processor, Sam2Model
408
+
409
+ processor = Sam2Processor.from_pretrained(model_id)
410
+ model = Sam2Model.from_pretrained(model_id).to(self.device)
411
+
412
+ logger.info("SAM2 loaded successfully via Transformers classes")
413
+ return {"model": model, "processor": processor}
414
+
415
+ except Exception as e:
416
+ logger.warning(f"Direct class approach failed: {e}")
417
+
418
+ # Method 3: Official SAM2 with from_pretrained
419
+ try:
420
+ logger.info("Trying official SAM2 from_pretrained...")
421
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
422
+
423
+ predictor = SAM2ImagePredictor.from_pretrained(model_id)
424
+
425
+ logger.info("SAM2 loaded successfully via official from_pretrained")
426
+ return predictor
427
+
428
+ except Exception as e:
429
+ logger.warning(f"Official from_pretrained approach failed: {e}")
430
+
431
+ # Method 4: Fallback to direct checkpoint download
432
+ try:
433
+ logger.info("Trying fallback checkpoint approach...")
434
+ from huggingface_hub import hf_hub_download
435
+ from transformers import Sam2Model
436
+
437
+ # Download checkpoint directly
438
+ checkpoint_path = hf_hub_download(
439
+ repo_id=model_id,
440
+ filename="model.safetensors" if "sam2.1" in model_id else "pytorch_model.bin"
441
+ )
442
+
443
+ logger.info(f"Downloaded checkpoint to: {checkpoint_path}")
444
+
445
+ # Load with minimal approach
446
+ model = Sam2Model.from_pretrained(model_id)
447
+ model = model.to(self.device)
448
+
449
+ logger.info("SAM2 loaded successfully via fallback approach")
450
+ return model
451
+
452
+ except Exception as e:
453
+ logger.warning(f"Fallback approach failed: {e}")
454
 
455
+ logger.error("All SAM2 loading methods failed")
456
  return None
457
 
458
  # ============================================================================ #
 
615
 
616
  if self.sam2_predictor is not None:
617
  try:
618
+ info['sam2_model_type'] = type(self.sam2_predictor).__name__
619
  except:
620
  info['sam2_model_type'] = "Unknown"
621