MogensR committed on
Commit
cc63301
·
1 Parent(s): 0961c3b

Create cache_cleaner.py

Browse files
Files changed (1) hide show
  1. cache_cleaner.py +364 -0
cache_cleaner.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================ #
2
+ # HARD CACHE CLEANER + WORKING SAM2 LOADER FOR HUGGINGFACE SPACES
3
+ # ============================================================================ #
4
+
5
+ import os
6
+ import gc
7
+ import sys
8
+ import shutil
9
+ import tempfile
10
+ import logging
11
+ import traceback
12
+ from pathlib import Path
13
+ from typing import Optional, Dict, Any, Tuple
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class HardCacheCleaner:
    """
    Best-effort cache cleaner intended to run before (re)loading SAM2.

    Every step is a static method that catches its own exceptions and logs a
    warning instead of raising, so one failing step never prevents the
    remaining steps from running.

    The steps cover: ``sys.modules`` entries and ``__pycache__`` directories,
    SAM2-named entries in known cache directories, the PyTorch CUDA cache,
    SAM2/segment-named entries in temp directories, the importlib finder
    caches, and a forced garbage collection.
    """

    # Lower-case substrings that mark a file/directory inside a cache
    # directory as SAM2-related.
    _CACHE_PATTERNS = ("sam2", "segment-anything-2")

    # Lower-case substrings that mark a temp-directory entry for removal.
    # NOTE(review): "segment" is broad and may match unrelated temp items
    # (e.g. media segments) - confirm this is intended.
    _TEMP_PATTERNS = ("sam2", "segment")

    @staticmethod
    def _name_matches(name: str, patterns) -> bool:
        """Return True if any of ``patterns`` occurs in ``name`` (case-insensitive)."""
        lowered = name.lower()
        return any(pattern in lowered for pattern in patterns)

    @staticmethod
    def clean_all_caches(verbose: bool = True):
        """
        Run every cleanup step, in a fixed order.

        Args:
            verbose: When True, progress is reported through ``logger.info``.
        """
        if verbose:
            logger.info("🧹 Starting comprehensive cache cleanup...")

        # 1. Python module cache (sys.modules + __pycache__)
        HardCacheCleaner._clean_python_cache(verbose)

        # 2. HuggingFace / torch / local checkpoint caches
        HardCacheCleaner._clean_huggingface_cache(verbose)

        # 3. PyTorch CUDA cache
        HardCacheCleaner._clean_pytorch_cache(verbose)

        # 4. Temp directories
        HardCacheCleaner._clean_temp_directories(verbose)

        # 5. Import finder caches
        HardCacheCleaner._clear_import_cache(verbose)

        # 6. Garbage collection
        HardCacheCleaner._force_gc_cleanup(verbose)

        if verbose:
            logger.info("✅ Cache cleanup completed")

    @staticmethod
    def _clean_python_cache(verbose: bool = True):
        """
        Drop SAM2 modules from ``sys.modules`` and delete ``__pycache__`` dirs.

        Any module whose name contains ``sam2`` (case-insensitive) is removed
        from ``sys.modules``. Every ``__pycache__`` directory below the current
        working directory is deleted.
        """
        try:
            # Snapshot the keys first: sys.modules is mutated in the loop below.
            stale_modules = [name for name in list(sys.modules) if "sam2" in name.lower()]
            for name in stale_modules:
                if verbose:
                    logger.info(f"🗑️ Removing cached module: {name}")
                # pop() tolerates the entry having disappeared in the meantime.
                sys.modules.pop(name, None)

            for root, dirs, _files in os.walk("."):
                if "__pycache__" not in dirs:
                    continue
                cache_path = os.path.join(root, "__pycache__")
                if verbose:
                    logger.info(f"🗑️ Removing __pycache__: {cache_path}")
                shutil.rmtree(cache_path, ignore_errors=True)
                # Prune in place so os.walk does not descend into the removed dir.
                dirs.remove("__pycache__")

        except Exception as e:
            logger.warning(f"Python cache cleanup failed: {e}")

    @staticmethod
    def _clean_huggingface_cache(verbose: bool = True):
        """
        Remove SAM2-named files and directories from known cache directories.

        Only entries whose name contains ``sam2`` or ``segment-anything-2``
        (case-insensitive) are deleted; everything else is left in place.
        Entries that cannot be deleted are skipped and logged at debug level.
        """
        try:
            cache_paths = [
                os.path.expanduser("~/.cache/huggingface/"),
                os.path.expanduser("~/.cache/torch/"),
                "./checkpoints/",
                "./.cache/",
            ]
            patterns = HardCacheCleaner._CACHE_PATTERNS

            for cache_path in cache_paths:
                if not os.path.exists(cache_path):
                    continue

                if verbose:
                    logger.info(f"🗑️ Cleaning cache directory: {cache_path}")

                for root, dirs, files in os.walk(cache_path):
                    for file_name in files:
                        if not HardCacheCleaner._name_matches(file_name, patterns):
                            continue
                        file_path = os.path.join(root, file_name)
                        try:
                            os.remove(file_path)
                        except OSError as e:
                            # Best effort: keep going, but leave a trace.
                            logger.debug(f"Could not remove cached file {file_path}: {e}")
                            continue
                        if verbose:
                            logger.info(f"🗑️ Removed cached file: {file_path}")

                    # Iterate over a copy because matching entries are pruned
                    # from ``dirs`` so os.walk does not descend into them.
                    for dir_name in dirs[:]:
                        if not HardCacheCleaner._name_matches(dir_name, patterns):
                            continue
                        dir_path = os.path.join(root, dir_name)
                        shutil.rmtree(dir_path, ignore_errors=True)
                        if verbose:
                            logger.info(f"🗑️ Removed cached directory: {dir_path}")
                        dirs.remove(dir_name)

        except Exception as e:
            logger.warning(f"HuggingFace cache cleanup failed: {e}")

    @staticmethod
    def _clean_pytorch_cache(verbose: bool = True):
        """Call ``torch.cuda.empty_cache()`` when CUDA is available."""
        try:
            # Imported lazily so the module can be imported without torch.
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                if verbose:
                    logger.info("🗑️ Cleared PyTorch CUDA cache")
        except Exception as e:
            logger.warning(f"PyTorch cache cleanup failed: {e}")

    @staticmethod
    def _clean_temp_directories(verbose: bool = True):
        """
        Remove top-level temp entries whose name contains ``sam2`` or ``segment``.

        The system temp directory, ``/tmp``, ``./tmp`` and ``./temp`` are
        scanned (non-recursively). Each physical directory is scanned once,
        even when several candidates resolve to the same location.
        """
        try:
            candidates = [tempfile.gettempdir(), "/tmp", "./tmp", "./temp"]
            visited = set()

            for temp_dir in candidates:
                if not os.path.isdir(temp_dir):
                    continue

                # tempfile.gettempdir() is frequently "/tmp"; avoid a second pass.
                real_dir = os.path.realpath(temp_dir)
                if real_dir in visited:
                    continue
                visited.add(real_dir)

                try:
                    entries = os.listdir(temp_dir)
                except OSError as e:
                    # An unreadable directory must not abort the other candidates.
                    logger.debug(f"Could not list temp directory {temp_dir}: {e}")
                    continue

                for item in entries:
                    if not HardCacheCleaner._name_matches(item, HardCacheCleaner._TEMP_PATTERNS):
                        continue
                    item_path = os.path.join(temp_dir, item)
                    try:
                        if os.path.isfile(item_path) or os.path.islink(item_path):
                            # Symlinks are unlinked; their targets are not touched.
                            os.remove(item_path)
                        elif os.path.isdir(item_path):
                            shutil.rmtree(item_path, ignore_errors=True)
                        else:
                            # Neither file, link nor directory: nothing removed,
                            # so do not report it as removed.
                            continue
                    except OSError as e:
                        logger.debug(f"Could not remove temp item {item_path}: {e}")
                        continue
                    if verbose:
                        logger.info(f"🗑️ Removed temp item: {item_path}")

        except Exception as e:
            logger.warning(f"Temp directory cleanup failed: {e}")

    @staticmethod
    def _clear_import_cache(verbose: bool = True):
        """Invalidate importlib's finder caches via ``importlib.invalidate_caches()``."""
        try:
            import importlib

            importlib.invalidate_caches()

            if verbose:
                logger.info("🗑️ Cleared Python import cache")

        except Exception as e:
            logger.warning(f"Import cache cleanup failed: {e}")

    @staticmethod
    def _force_gc_cleanup(verbose: bool = True):
        """Run ``gc.collect()`` and report the count it returns."""
        try:
            collected = gc.collect()
            if verbose:
                logger.info(f"🗑️ Garbage collection freed {collected} objects")
        except Exception as e:
            logger.warning(f"Garbage collection failed: {e}")
177
+
178
+
179
class WorkingSAM2Loader:
    """
    SAM2 loader built around the HuggingFace Transformers integration.

    Each loading strategy is attempted inside its own ``try`` block; a failure
    is logged as a warning and the next strategy is tried. The loaders never
    raise: they return ``None`` when nothing could be loaded.
    """

    # Supported ``model_size`` values and the Hub repo id each one maps to.
    _MODEL_IDS = {
        "tiny": "facebook/sam2.1-hiera-tiny",
        "small": "facebook/sam2.1-hiera-small",
        "base": "facebook/sam2.1-hiera-base-plus",
        "large": "facebook/sam2.1-hiera-large",
    }

    @staticmethod
    def _resolve_model_id(model_size: str) -> str:
        """Map ``model_size`` to a Hub repo id, defaulting to the large model."""
        model_ids = WorkingSAM2Loader._MODEL_IDS
        if model_size in model_ids:
            return model_ids[model_size]
        # Keep the permissive fallback, but make the substitution visible.
        logger.warning(
            f"Unknown model_size {model_size!r}; falling back to 'large' "
            f"(expected one of {sorted(model_ids)})"
        )
        return model_ids["large"]

    @staticmethod
    def load_sam2_transformers_approach(device: str = "cuda", model_size: str = "large") -> Optional[Any]:
        """
        Load SAM2, trying three strategies in order.

        1. ``transformers.pipeline("mask-generation", ...)``
        2. ``transformers.Sam2Processor`` + ``transformers.Sam2Model``
        3. ``sam2.sam2_image_predictor.SAM2ImagePredictor.from_pretrained``

        Args:
            device: Device string forwarded to every strategy (e.g. ``"cuda"``,
                ``"cuda:1"``, ``"cpu"``).
            model_size: One of ``"tiny"``, ``"small"``, ``"base"``, ``"large"``.
                Unknown values fall back to ``"large"`` with a warning.

        Returns:
            The pipeline (strategy 1), a ``{"model", "processor"}`` dict
            (strategy 2), the predictor (strategy 3), or ``None`` if every
            strategy failed. Callers must handle all of these shapes.
        """
        try:
            logger.info("🤖 Loading SAM2 via HuggingFace Transformers...")

            model_id = WorkingSAM2Loader._resolve_model_id(model_size)
            logger.info(f"Using model: {model_id}")

            # Method 1: Transformers pipeline
            try:
                from transformers import pipeline

                # Forward the device string as-is so values such as "cuda:1"
                # are honoured instead of being mapped to CPU, matching the
                # ``.to(device)`` used by Method 2.
                sam2_pipeline = pipeline(
                    "mask-generation",
                    model=model_id,
                    device=device,
                )

                logger.info("✅ SAM2 loaded successfully via Transformers pipeline")
                return sam2_pipeline

            except Exception as e:
                logger.warning(f"Pipeline approach failed: {e}")

            # Method 2: SAM2 classes directly via Transformers
            try:
                from transformers import Sam2Processor, Sam2Model

                processor = Sam2Processor.from_pretrained(model_id)
                model = Sam2Model.from_pretrained(model_id).to(device)

                logger.info("✅ SAM2 loaded successfully via Transformers classes")
                return {"model": model, "processor": processor}

            except Exception as e:
                logger.warning(f"Direct class approach failed: {e}")

            # Method 3: official SAM2 package with .from_pretrained()
            try:
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                # Pass the device through; otherwise the requested device is
                # ignored by this strategy.
                predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)

                logger.info("✅ SAM2 loaded successfully via official from_pretrained")
                return predictor

            except Exception as e:
                logger.warning(f"Official from_pretrained approach failed: {e}")

            return None

        except Exception as e:
            logger.error(f"All SAM2 loading methods failed: {e}")
            return None

    @staticmethod
    def load_sam2_fallback_approach(device: str = "cuda") -> Optional[Any]:
        """
        Fallback: load the large SAM2 model through ``transformers.Sam2Model``.

        No separate checkpoint download is performed here: the previously
        downloaded file was never used, and a failing download prevented the
        model load below from being attempted at all.

        Args:
            device: Device string the loaded model is moved to.

        Returns:
            The model moved to ``device``, or ``None`` if loading failed.
        """
        logger.info("🔄 Trying fallback SAM2 loading approach...")
        try:
            from transformers import Sam2Model

            model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-large")
            return model.to(device)

        except Exception as e:
            logger.warning(f"Transformers fallback failed: {e}")
            return None
287
+
288
+
289
+ # ============================================================================ #
290
+ # INTEGRATED MODEL LOADER WITH CACHE CLEANING
291
+ # ============================================================================ #
292
+
293
def load_sam2_with_cache_cleanup(
    device: str = "cuda",
    model_size: str = "large",
    force_cache_clean: bool = True,
    verbose: bool = True,
) -> Tuple[Optional[Any], str]:
    """
    Load SAM2, optionally cleaning caches first.

    The primary loader is tried first and the fallback loader second. This
    function does not raise: unexpected errors are logged with a traceback and
    reported through the returned status text.

    Args:
        device: Device string forwarded to the loaders.
        model_size: Model size forwarded to the primary loader.
        force_cache_clean: When True, run ``HardCacheCleaner.clean_all_caches``
            before loading.
        verbose: Forwarded to the cache cleaner to enable progress logging.

    Returns:
        Tuple of ``(model, status_message)``. ``model`` is ``None`` when every
        loading method failed; ``status_message`` is a newline-joined log of
        the steps taken.
    """
    status_messages = []

    try:
        # Step 1: Clean caches if requested
        if force_cache_clean:
            status_messages.append("🧹 Cleaning caches...")
            HardCacheCleaner.clean_all_caches(verbose=verbose)
            status_messages.append("✅ Cache cleanup completed")

        # Step 2: Try primary loading method
        status_messages.append("🤖 Loading SAM2 (primary method)...")
        model = WorkingSAM2Loader.load_sam2_transformers_approach(device, model_size)

        if model is not None:
            status_messages.append("✅ SAM2 loaded successfully!")
            return model, "\n".join(status_messages)

        # Step 3: Try fallback method
        status_messages.append("🔄 Trying fallback loading method...")
        model = WorkingSAM2Loader.load_sam2_fallback_approach(device)

        if model is not None:
            status_messages.append("✅ SAM2 loaded successfully (fallback)!")
            return model, "\n".join(status_messages)

        # Step 4: All methods failed
        status_messages.append("❌ All SAM2 loading methods failed")
        return None, "\n".join(status_messages)

    except Exception as e:
        error_msg = f"❌ Critical error in SAM2 loading: {e}"
        # logger.exception attaches the active traceback to the log record.
        logger.exception(error_msg)
        status_messages.append(error_msg)
        return None, "\n".join(status_messages)
340
+
341
+
342
+ # ============================================================================ #
343
+ # USAGE EXAMPLE
344
+ # ============================================================================ #
345
+
346
if __name__ == "__main__":
    # Usage example: run this module directly to try the loader end to end.

    # Without a configured handler, the INFO-level progress messages requested
    # by verbose=True are dropped (Python's last-resort handler only emits
    # WARNING and above), so configure logging before loading.
    logging.basicConfig(level=logging.INFO)

    print("Testing SAM2 loader with cache cleanup...")

    # Load SAM2 with full cache cleanup
    model, status = load_sam2_with_cache_cleanup(
        device="cuda",
        model_size="large",
        force_cache_clean=True,
        verbose=True,
    )

    print("Status:", status)

    if model is not None:
        print("SAM2 loaded successfully!")
        print("Model type:", type(model))
    else:
        print("SAM2 loading failed completely")