File size: 13,299 Bytes
f77b2f9
 
 
 
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
f77b2f9
 
 
 
cc63301
 
 
f77b2f9
cc63301
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
f77b2f9
 
 
 
 
 
 
 
 
 
cc63301
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f77b2f9
cc63301
f77b2f9
cc63301
 
f77b2f9
cc63301
 
 
f77b2f9
cc63301
 
 
f77b2f9
cc63301
 
 
f77b2f9
cc63301
 
 
f77b2f9
cc63301
 
 
f77b2f9
cc63301
 
f77b2f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
"""
Cache Management and SAM2 Loading Utilities
Comprehensive cache cleaning system to resolve model loading issues on HF Spaces
"""

import os
import gc
import sys
import shutil
import tempfile
import logging
import traceback
from pathlib import Path
from typing import Optional, Dict, Any, Tuple

logger = logging.getLogger(__name__)

class HardCacheCleaner:
    """
    Best-effort cache cleaning system to resolve SAM2 loading issues.

    Removes SAM2-related entries from the Python module cache, the
    HuggingFace / PyTorch cache directories and the temp directories, then
    invalidates the import caches and runs the garbage collector.

    Every step catches and logs its own failures, so a failing step never
    prevents the remaining steps from running and nothing here raises to
    the caller.
    """

    # Lower-case substrings that mark a cached file or directory as SAM2-related.
    _SAM2_CACHE_PATTERNS = ('sam2', 'segment-anything-2')

    # Lower-case substrings that mark a temp item as SAM2-related.
    # NOTE(review): 'segment' is a broad match and can also hit unrelated
    # items in shared temp directories; kept as-is for compatibility.
    _SAM2_TEMP_PATTERNS = ('sam2', 'segment')

    @staticmethod
    def clean_all_caches(verbose: bool = True):
        """
        Run every cleanup step in order.

        Args:
            verbose: When True, log each removed item and step boundaries
                at INFO level.
        """
        if verbose:
            logger.info("Starting comprehensive cache cleanup...")

        # 1. Clean Python module cache
        HardCacheCleaner._clean_python_cache(verbose)

        # 2. Clean HuggingFace cache
        HardCacheCleaner._clean_huggingface_cache(verbose)

        # 3. Clean PyTorch cache
        HardCacheCleaner._clean_pytorch_cache(verbose)

        # 4. Clean temp directories
        HardCacheCleaner._clean_temp_directories(verbose)

        # 5. Clear import cache
        HardCacheCleaner._clear_import_cache(verbose)

        # 6. Force garbage collection
        HardCacheCleaner._force_gc_cleanup(verbose)

        if verbose:
            logger.info("Cache cleanup completed")

    @staticmethod
    def _load_config():
        """
        Return the application config, or None when it cannot be loaded.

        The config only contributes extra directories to clean, so a missing
        or broken config must not stop the standard locations from being
        cleaned.
        """
        try:
            from config.app_config import get_config
            return get_config()
        except Exception as e:
            logger.warning(f"Could not load app config, using default paths only: {e}")
            return None

    @staticmethod
    def _matches(name: str, patterns) -> bool:
        """Return True when any of *patterns* occurs in *name*, ignoring case."""
        lowered = name.lower()
        return any(pattern in lowered for pattern in patterns)

    @staticmethod
    def _clean_python_cache(verbose: bool = True):
        """
        Drop SAM2 modules from ``sys.modules`` and delete ``__pycache__``
        directories below the current working directory.
        """
        try:
            # Snapshot the keys first: sys.modules is mutated in the loop below.
            sam2_modules = [name for name in list(sys.modules) if 'sam2' in name.lower()]
            for module in sam2_modules:
                if verbose:
                    logger.info(f"Removing cached module: {module}")
                # pop() tolerates the entry having disappeared in the meantime.
                sys.modules.pop(module, None)

            for root, dirs, _files in os.walk("."):
                if "__pycache__" in dirs:
                    cache_path = os.path.join(root, "__pycache__")
                    if verbose:
                        logger.info(f"Removing __pycache__: {cache_path}")
                    shutil.rmtree(cache_path, ignore_errors=True)
                    # Prune in place so os.walk does not descend into it.
                    dirs.remove("__pycache__")

        except Exception as e:
            logger.warning(f"Python cache cleanup failed: {e}")

    @staticmethod
    def _clean_huggingface_cache(verbose: bool = True):
        """
        Delete SAM2-related files and directories from the model caches.

        Scans the user-level HuggingFace and torch caches, the configured
        model cache directory (when the config is available and sets one)
        and the local ``./checkpoints/`` and ``./.cache/`` directories.
        """
        try:
            config = HardCacheCleaner._load_config()

            cache_paths = [
                os.path.expanduser("~/.cache/huggingface/"),
                os.path.expanduser("~/.cache/torch/"),
                getattr(config, "model_cache_dir", None),
                "./checkpoints/",
                "./.cache/",
            ]

            for cache_path in cache_paths:
                # Skip an unset configured path as well as missing directories.
                if not cache_path or not os.path.exists(cache_path):
                    continue

                if verbose:
                    logger.info(f"Cleaning cache directory: {cache_path}")

                for root, dirs, files in os.walk(cache_path):
                    for file_name in files:
                        if not HardCacheCleaner._matches(
                            file_name, HardCacheCleaner._SAM2_CACHE_PATTERNS
                        ):
                            continue
                        file_path = os.path.join(root, file_name)
                        try:
                            os.remove(file_path)
                        except OSError as e:
                            # Best effort: keep going, but leave a trace.
                            logger.debug(f"Could not remove cached file {file_path}: {e}")
                        else:
                            if verbose:
                                logger.info(f"Removed cached file: {file_path}")

                    for dir_name in dirs[:]:  # Iterate a copy; dirs is pruned below
                        if not HardCacheCleaner._matches(
                            dir_name, HardCacheCleaner._SAM2_CACHE_PATTERNS
                        ):
                            continue
                        dir_path = os.path.join(root, dir_name)
                        try:
                            shutil.rmtree(dir_path, ignore_errors=True)
                        except OSError as e:
                            logger.debug(f"Could not remove cached directory {dir_path}: {e}")
                            continue
                        if verbose:
                            logger.info(f"Removed cached directory: {dir_path}")
                        # Prune in place so os.walk does not descend into it.
                        dirs.remove(dir_name)

        except Exception as e:
            logger.warning(f"HuggingFace cache cleanup failed: {e}")

    @staticmethod
    def _clean_pytorch_cache(verbose: bool = True):
        """Release cached CUDA memory held by PyTorch, when CUDA is available."""
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                if verbose:
                    logger.info("Cleared PyTorch CUDA cache")
        except Exception as e:
            logger.warning(f"PyTorch cache cleanup failed: {e}")

    @staticmethod
    def _clean_temp_directories(verbose: bool = True):
        """
        Delete SAM2-related top-level items from the temp directories.

        Scans the configured temp directory (when the config is available
        and sets one), the system temp directory, ``/tmp``, ``./tmp`` and
        ``./temp``. Only direct children are inspected.
        """
        try:
            config = HardCacheCleaner._load_config()

            candidates = [
                getattr(config, "temp_dir", None),
                tempfile.gettempdir(),
                "/tmp",
                "./tmp",
                "./temp",
            ]
            # Drop unset entries and scan each location once; the system temp
            # directory is frequently "/tmp" itself.
            temp_dirs = list(dict.fromkeys(str(path) for path in candidates if path))

            for temp_dir in temp_dirs:
                if not os.path.isdir(temp_dir):
                    continue

                try:
                    items = os.listdir(temp_dir)
                except OSError as e:
                    # One unreadable directory must not stop the others.
                    logger.warning(f"Could not list temp directory {temp_dir}: {e}")
                    continue

                for item in items:
                    if not HardCacheCleaner._matches(
                        item, HardCacheCleaner._SAM2_TEMP_PATTERNS
                    ):
                        continue
                    item_path = os.path.join(temp_dir, item)
                    try:
                        if os.path.isfile(item_path):
                            os.remove(item_path)
                        elif os.path.isdir(item_path):
                            shutil.rmtree(item_path, ignore_errors=True)
                    except OSError as e:
                        logger.debug(f"Could not remove temp item {item_path}: {e}")
                        continue
                    if verbose:
                        logger.info(f"Removed temp item: {item_path}")

        except Exception as e:
            logger.warning(f"Temp directory cleanup failed: {e}")

    @staticmethod
    def _clear_import_cache(verbose: bool = True):
        """Invalidate the import system's finder caches."""
        try:
            import importlib

            # Invalidate import caches
            importlib.invalidate_caches()

            if verbose:
                logger.info("Cleared Python import cache")

        except Exception as e:
            logger.warning(f"Import cache cleanup failed: {e}")

    @staticmethod
    def _force_gc_cleanup(verbose: bool = True):
        """Run a full garbage collection pass and log how many objects it found."""
        try:
            collected = gc.collect()
            if verbose:
                logger.info(f"Garbage collection freed {collected} objects")
        except Exception as e:
            logger.warning(f"Garbage collection failed: {e}")


class WorkingSAM2Loader:
    """
    SAM2 loader built on the HuggingFace Transformers integration.

    Each loading strategy is tried in turn; a failing strategy is logged and
    the next one is attempted. Loader methods never raise: they return None
    when every strategy failed.
    """

    # Supported model sizes and the HuggingFace Hub repositories they map to.
    _MODEL_MAP = {
        "tiny": "facebook/sam2.1-hiera-tiny",
        "small": "facebook/sam2.1-hiera-small",
        "base": "facebook/sam2.1-hiera-base-plus",
        "large": "facebook/sam2.1-hiera-large",
    }

    @staticmethod
    def _resolve_model_id(model_size: str) -> str:
        """Map *model_size* to a Hub repository id, defaulting to the large model."""
        model_map = WorkingSAM2Loader._MODEL_MAP
        if model_size not in model_map:
            # Make the substitution visible instead of silently loading 'large'.
            logger.warning(
                f"Unknown SAM2 model size '{model_size}', falling back to 'large'"
            )
            return model_map["large"]
        return model_map[model_size]

    @staticmethod
    def load_sam2_transformers_approach(device: str = "cuda", model_size: str = "large") -> Optional[Any]:
        """
        Load SAM2, trying three strategies in order.

        Args:
            device: Target device string (e.g. "cuda", "cuda:1", "cpu").
            model_size: One of "tiny", "small", "base" or "large"; any other
                value falls back to "large" with a warning.

        Returns:
            The first object that loads successfully, or None:
            1. a Transformers "mask-generation" pipeline,
            2. a dict ``{"model": Sam2Model, "processor": Sam2Processor}``,
            3. an official ``SAM2ImagePredictor``.
            Callers must handle all three shapes.
        """
        try:
            logger.info("Loading SAM2 via HuggingFace Transformers...")

            model_id = WorkingSAM2Loader._resolve_model_id(model_size)
            logger.info(f"Using model: {model_id}")

            # Method 1: Using Transformers pipeline (most reliable for HF Spaces)
            try:
                from transformers import pipeline

                # Pass the device through unchanged so that values such as
                # "cuda:1" or "mps" are honoured rather than mapped to CPU.
                sam2_pipeline = pipeline(
                    "mask-generation",
                    model=model_id,
                    device=device
                )

                logger.info("SAM2 loaded successfully via Transformers pipeline")
                return sam2_pipeline

            except Exception as e:
                logger.warning(f"Pipeline approach failed: {e}")

            # Method 2: Using SAM2 classes directly via Transformers
            try:
                from transformers import Sam2Processor, Sam2Model

                processor = Sam2Processor.from_pretrained(model_id)
                model = Sam2Model.from_pretrained(model_id).to(device)

                logger.info("SAM2 loaded successfully via Transformers classes")
                return {"model": model, "processor": processor}

            except Exception as e:
                logger.warning(f"Direct class approach failed: {e}")

            # Method 3: Using official SAM2 with .from_pretrained()
            try:
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                # Forward the device explicitly; otherwise the predictor is
                # built on the library's default device, not the requested one.
                predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)

                logger.info("SAM2 loaded successfully via official from_pretrained")
                return predictor

            except Exception as e:
                logger.warning(f"Official from_pretrained approach failed: {e}")

            return None

        except Exception as e:
            logger.error(f"All SAM2 loading methods failed: {e}")
            return None

    @staticmethod
    def load_sam2_fallback_approach(device: str = "cuda") -> Optional[Any]:
        """
        Fallback: load the large SAM2 model directly through Transformers.

        No standalone checkpoint is downloaded here: the load below fetches
        the weights it needs itself, and a separate download would only add
        a large transfer whose failure blocks the actual load attempt.

        Args:
            device: Target device string the model is moved to.

        Returns:
            The ``Sam2Model`` on *device*, or None when loading failed.
        """
        try:
            logger.info("Trying fallback SAM2 loading approach...")

            try:
                from transformers import Sam2Model
                model = Sam2Model.from_pretrained(
                    WorkingSAM2Loader._MODEL_MAP["large"]
                )
                return model.to(device)

            except Exception as e:
                logger.warning(f"Transformers fallback failed: {e}")

            return None

        except Exception as e:
            logger.error(f"Fallback loading failed: {e}")
            return None


def load_sam2_with_cache_cleanup(
    device: str = "cuda", 
    model_size: str = "large",
    force_cache_clean: bool = True,
    verbose: bool = True
) -> Tuple[Optional[Any], str]:
    """
    Optionally clean the caches, then load SAM2 via the primary loader and,
    if that yields nothing, via the fallback loader.

    Args:
        device: Device string handed to both loaders.
        model_size: Model size handed to the primary loader.
        force_cache_clean: Run the cache cleanup before loading when True.
        verbose: Forwarded to the cache cleanup.

    Returns:
        Tuple of (model, status_message). The model is None when every
        loader failed or an unexpected error occurred; the status message
        is the newline-joined list of the steps that were performed.
    """
    progress = []

    def _outcome(loaded):
        # Pair the (possibly missing) model with the progress log so far.
        return loaded, "\n".join(progress)

    try:
        # Step 1: optional cache cleanup
        if force_cache_clean:
            progress.append("Cleaning caches...")
            HardCacheCleaner.clean_all_caches(verbose=verbose)
            progress.append("Cache cleanup completed")

        # Step 2: primary loader
        progress.append("Loading SAM2 (primary method)...")
        primary = WorkingSAM2Loader.load_sam2_transformers_approach(device, model_size)
        if primary is not None:
            progress.append("SAM2 loaded successfully!")
            return _outcome(primary)

        # Step 3: fallback loader; its result decides the final message
        progress.append("Trying fallback loading method...")
        fallback = WorkingSAM2Loader.load_sam2_fallback_approach(device)
        if fallback is None:
            progress.append("All SAM2 loading methods failed")
        else:
            progress.append("SAM2 loaded successfully (fallback)!")
        return _outcome(fallback)

    except Exception as exc:
        failure = f"Critical error in SAM2 loading: {exc}"
        logger.error(f"{failure}\n{traceback.format_exc()}")
        progress.append(failure)
        return _outcome(None)