MogensR commited on
Commit
9da6723
·
1 Parent(s): 7277e52

Delete models/models

Browse files
Files changed (1) hide show
  1. models/models/loader.py +0 -491
models/models/loader.py DELETED
@@ -1,491 +0,0 @@
1
- """
2
- Model Loading Module
3
- Handles loading and validation of SAM2 and MatAnyone AI models
4
- """
5
-
6
- # ============================================================================
7
- # IMPORTS AND DEPENDENCIES
8
- # ============================================================================
9
-
10
- import os
11
- import sys
12
- import gc
13
- import traceback
14
- from pathlib import Path
15
- from typing import Optional, Dict, Any, Tuple
16
- import logging
17
-
18
- import torch
19
- import numpy as np
20
- from huggingface_hub import hf_hub_download
21
-
22
- # ============================================================================
23
- # SYSTEM PATH CONFIGURATION
24
- # ============================================================================
25
-
26
- def setup_paths():
27
- """Configure system paths for model imports"""
28
- current_dir = Path(__file__).parent.absolute()
29
- project_root = current_dir.parent.parent
30
-
31
- paths_to_add = [
32
- str(project_root),
33
- str(project_root / "models"),
34
- str(project_root / "models" / "sam2"),
35
- str(project_root / "models" / "matting"),
36
- ]
37
-
38
- for path in paths_to_add:
39
- if path not in sys.path:
40
- sys.path.insert(0, path)
41
-
42
- return project_root
43
-
44
- PROJECT_ROOT = setup_paths()
45
-
46
- # ============================================================================
47
- # MODEL IMPORTS (After Path Setup)
48
- # ============================================================================
49
-
50
- try:
51
- from sam2.build_sam import build_sam2_video_predictor
52
- from models.matting.model_initialization import ModelInitializer
53
- except ImportError as e:
54
- logging.error(f"Failed to import models: {e}")
55
- logging.error(f"Current sys.path: {sys.path}")
56
- raise
57
-
58
- # ============================================================================
59
- # CONFIGURATION
60
- # ============================================================================
61
-
62
- class ModelConfig:
63
- """Model configuration and paths"""
64
-
65
- # Model identifiers
66
- SAM2_MODEL_ID = "facebook/sam2-hiera-large"
67
- MATANYONE_REPO = "bytedance/Matting-Anything"
68
-
69
- # Model filenames
70
- SAM2_CHECKPOINT = "sam2_hiera_large.pt"
71
- SAM2_CONFIG = "sam2_hiera_l.yaml"
72
- MATANYONE_CHECKPOINT = "model_any_mat_vitl.pth"
73
-
74
- # Default paths
75
- DEFAULT_CACHE_DIR = Path.home() / ".cache" / "huggingface" / "hub"
76
-
77
- # Device configuration
78
- CUDA_AVAILABLE = torch.cuda.is_available()
79
- DEFAULT_DEVICE = "cuda" if CUDA_AVAILABLE else "cpu"
80
-
81
- # Memory thresholds (in GB)
82
- MIN_GPU_MEMORY = 8.0
83
- MIN_RAM = 16.0
84
- RECOMMENDED_GPU_MEMORY = 12.0
85
- RECOMMENDED_RAM = 32.0
86
-
87
- # ============================================================================
88
- # UTILITY FUNCTIONS
89
- # ============================================================================
90
-
91
- def get_memory_info() -> Dict[str, float]:
92
- """Get system memory information"""
93
- import psutil
94
-
95
- memory_info = {
96
- 'ram_total_gb': psutil.virtual_memory().total / (1024**3),
97
- 'ram_available_gb': psutil.virtual_memory().available / (1024**3),
98
- 'ram_used_percent': psutil.virtual_memory().percent
99
- }
100
-
101
- if torch.cuda.is_available():
102
- try:
103
- gpu_props = torch.cuda.get_device_properties(0)
104
- memory_info.update({
105
- 'gpu_name': gpu_props.name,
106
- 'gpu_total_gb': gpu_props.total_memory / (1024**3),
107
- 'gpu_allocated_gb': torch.cuda.memory_allocated(0) / (1024**3),
108
- 'gpu_reserved_gb': torch.cuda.memory_reserved(0) / (1024**3)
109
- })
110
- except Exception as e:
111
- logging.warning(f"Could not get GPU memory info: {e}")
112
-
113
- return memory_info
114
-
115
- def clean_memory():
116
- """Clean up GPU and system memory"""
117
- if torch.cuda.is_available():
118
- torch.cuda.empty_cache()
119
- torch.cuda.synchronize()
120
- gc.collect()
121
-
122
- def format_memory_status(memory_info: Dict[str, float]) -> str:
123
- """Format memory information for display"""
124
- lines = [
125
- "=== System Memory Status ===",
126
- f"RAM: {memory_info['ram_available_gb']:.1f}GB / {memory_info['ram_total_gb']:.1f}GB available ({memory_info['ram_used_percent']:.1f}% used)"
127
- ]
128
-
129
- if 'gpu_name' in memory_info:
130
- lines.extend([
131
- f"GPU: {memory_info['gpu_name']}",
132
- f" Total: {memory_info['gpu_total_gb']:.1f}GB",
133
- f" Allocated: {memory_info['gpu_allocated_gb']:.2f}GB",
134
- f" Reserved: {memory_info['gpu_reserved_gb']:.2f}GB"
135
- ])
136
- else:
137
- lines.append("GPU: Not available")
138
-
139
- return "\n".join(lines)
140
-
141
- # ============================================================================
142
- # MODEL LOADER CLASS
143
- # ============================================================================
144
-
145
- class ModelLoader:
146
- """Manages loading and initialization of AI models"""
147
-
148
- def __init__(self, cache_dir: Optional[Path] = None, device: Optional[str] = None):
149
- """
150
- Initialize model loader
151
-
152
- Args:
153
- cache_dir: Directory for caching models
154
- device: Device to load models on ('cuda' or 'cpu')
155
- """
156
- self.cache_dir = Path(cache_dir) if cache_dir else ModelConfig.DEFAULT_CACHE_DIR
157
- self.device = device or ModelConfig.DEFAULT_DEVICE
158
-
159
- # Model instances
160
- self.sam2_predictor = None
161
- self.matanyone_model = None
162
-
163
- # Status tracking
164
- self.models_loaded = False
165
- self.load_errors = []
166
-
167
- # Setup logging
168
- self.logger = logging.getLogger(__name__)
169
-
170
- # Validate system requirements
171
- self._check_system_requirements()
172
-
173
- def _check_system_requirements(self) -> bool:
174
- """Check if system meets minimum requirements"""
175
- memory_info = get_memory_info()
176
- warnings = []
177
-
178
- # Check RAM
179
- if memory_info['ram_total_gb'] < ModelConfig.MIN_RAM:
180
- warnings.append(f"⚠️ Low RAM: {memory_info['ram_total_gb']:.1f}GB (minimum {ModelConfig.MIN_RAM}GB recommended)")
181
-
182
- # Check GPU
183
- if self.device == "cuda":
184
- if 'gpu_total_gb' in memory_info:
185
- if memory_info['gpu_total_gb'] < ModelConfig.MIN_GPU_MEMORY:
186
- warnings.append(f"⚠️ Low GPU memory: {memory_info['gpu_total_gb']:.1f}GB (minimum {ModelConfig.MIN_GPU_MEMORY}GB recommended)")
187
- else:
188
- warnings.append("⚠️ Could not detect GPU memory")
189
-
190
- if warnings:
191
- self.logger.warning("\n".join(warnings))
192
-
193
- return len(warnings) == 0
194
-
195
- def _download_model(self, repo_id: str, filename: str, repo_type: str = "model") -> Path:
196
- """
197
- Download model from Hugging Face Hub
198
-
199
- Args:
200
- repo_id: Repository ID on Hugging Face
201
- filename: Name of the file to download
202
- repo_type: Type of repository
203
-
204
- Returns:
205
- Path to downloaded file
206
- """
207
- try:
208
- self.logger.info(f"Downloading {filename} from {repo_id}...")
209
-
210
- local_path = hf_hub_download(
211
- repo_id=repo_id,
212
- filename=filename,
213
- repo_type=repo_type,
214
- cache_dir=str(self.cache_dir),
215
- resume_download=True
216
- )
217
-
218
- self.logger.info(f"✓ Downloaded to: {local_path}")
219
- return Path(local_path)
220
-
221
- except Exception as e:
222
- error_msg = f"Failed to download {filename}: {str(e)}"
223
- self.logger.error(error_msg)
224
- self.load_errors.append(error_msg)
225
- raise
226
-
227
- def _load_sam2(self) -> bool:
228
- """
229
- Load SAM2 model
230
-
231
- Returns:
232
- Success status
233
- """
234
- try:
235
- self.logger.info("Loading SAM2 model...")
236
-
237
- # Download checkpoint and config
238
- checkpoint_path = self._download_model(
239
- ModelConfig.SAM2_MODEL_ID,
240
- ModelConfig.SAM2_CHECKPOINT
241
- )
242
-
243
- config_path = self._download_model(
244
- ModelConfig.SAM2_MODEL_ID,
245
- ModelConfig.SAM2_CONFIG
246
- )
247
-
248
- # Verify files exist
249
- if not checkpoint_path.exists():
250
- raise FileNotFoundError(f"SAM2 checkpoint not found: {checkpoint_path}")
251
- if not config_path.exists():
252
- raise FileNotFoundError(f"SAM2 config not found: {config_path}")
253
-
254
- # Build predictor
255
- self.sam2_predictor = build_sam2_video_predictor(
256
- str(config_path),
257
- str(checkpoint_path),
258
- device=self.device
259
- )
260
-
261
- if self.sam2_predictor is None:
262
- raise RuntimeError("SAM2 predictor initialization returned None")
263
-
264
- self.logger.info("✓ SAM2 model loaded successfully")
265
- return True
266
-
267
- except Exception as e:
268
- error_msg = f"Failed to load SAM2: {str(e)}"
269
- self.logger.error(error_msg)
270
- self.logger.debug(traceback.format_exc())
271
- self.load_errors.append(error_msg)
272
- return False
273
-
274
- def _load_matanyone(self) -> bool:
275
- """
276
- Load MatAnyone model
277
-
278
- Returns:
279
- Success status
280
- """
281
- try:
282
- self.logger.info("Loading MatAnyone model...")
283
-
284
- # Download checkpoint
285
- checkpoint_path = self._download_model(
286
- ModelConfig.MATANYONE_REPO,
287
- ModelConfig.MATANYONE_CHECKPOINT
288
- )
289
-
290
- if not checkpoint_path.exists():
291
- raise FileNotFoundError(f"MatAnyone checkpoint not found: {checkpoint_path}")
292
-
293
- # Initialize model
294
- model_init = ModelInitializer(device=self.device)
295
- self.matanyone_model = model_init.setup_models(str(checkpoint_path))
296
-
297
- if self.matanyone_model is None:
298
- raise RuntimeError("MatAnyone model initialization returned None")
299
-
300
- self.logger.info("✓ MatAnyone model loaded successfully")
301
- return True
302
-
303
- except Exception as e:
304
- error_msg = f"Failed to load MatAnyone: {str(e)}"
305
- self.logger.error(error_msg)
306
- self.logger.debug(traceback.format_exc())
307
- self.load_errors.append(error_msg)
308
- return False
309
-
310
- def load_models(self, force_reload: bool = False) -> bool:
311
- """
312
- Load all models
313
-
314
- Args:
315
- force_reload: Force reload even if already loaded
316
-
317
- Returns:
318
- Success status
319
- """
320
- if self.models_loaded and not force_reload:
321
- self.logger.info("Models already loaded")
322
- return True
323
-
324
- self.logger.info("Starting model loading process...")
325
- self.logger.info(format_memory_status(get_memory_info()))
326
-
327
- # Clean memory before loading
328
- clean_memory()
329
-
330
- # Reset status
331
- self.models_loaded = False
332
- self.load_errors = []
333
-
334
- # Load models
335
- sam2_success = self._load_sam2()
336
- matanyone_success = self._load_matanyone()
337
-
338
- # Update status
339
- self.models_loaded = sam2_success and matanyone_success
340
-
341
- # Report results
342
- if self.models_loaded:
343
- self.logger.info("✅ All models loaded successfully")
344
- self.logger.info(format_memory_status(get_memory_info()))
345
- else:
346
- self.logger.error("❌ Some models failed to load:")
347
- for error in self.load_errors:
348
- self.logger.error(f" - {error}")
349
-
350
- return self.models_loaded
351
-
352
- def validate_models(self) -> Tuple[bool, Dict[str, Any]]:
353
- """
354
- Validate loaded models
355
-
356
- Returns:
357
- Tuple of (success status, validation results)
358
- """
359
- results = {
360
- 'sam2': {'loaded': False, 'functional': False, 'error': None},
361
- 'matanyone': {'loaded': False, 'functional': False, 'error': None},
362
- 'memory': get_memory_info()
363
- }
364
-
365
- # Check SAM2
366
- if self.sam2_predictor is not None:
367
- results['sam2']['loaded'] = True
368
- try:
369
- # Simple functionality test
370
- test_frames = torch.randn(1, 3, 256, 256).to(self.device)
371
- with torch.no_grad():
372
- # Just check if the model responds without error
373
- _ = self.sam2_predictor.model.image_encoder(test_frames)
374
- results['sam2']['functional'] = True
375
- except Exception as e:
376
- results['sam2']['error'] = str(e)
377
- self.logger.error(f"SAM2 validation failed: {e}")
378
-
379
- # Check MatAnyone
380
- if self.matanyone_model is not None:
381
- results['matanyone']['loaded'] = True
382
- try:
383
- # Simple functionality test
384
- test_input = torch.randn(1, 4, 256, 256).to(self.device)
385
- with torch.no_grad():
386
- # Just check if the model responds without error
387
- _ = self.matanyone_model(test_input)
388
- results['matanyone']['functional'] = True
389
- except Exception as e:
390
- results['matanyone']['error'] = str(e)
391
- self.logger.error(f"MatAnyone validation failed: {e}")
392
-
393
- # Overall success
394
- success = (results['sam2']['functional'] and
395
- results['matanyone']['functional'])
396
-
397
- return success, results
398
-
399
- def cleanup(self):
400
- """Clean up models and free memory"""
401
- self.logger.info("Cleaning up models...")
402
-
403
- # Delete model instances
404
- if self.sam2_predictor is not None:
405
- del self.sam2_predictor
406
- self.sam2_predictor = None
407
-
408
- if self.matanyone_model is not None:
409
- del self.matanyone_model
410
- self.matanyone_model = None
411
-
412
- # Clear memory
413
- clean_memory()
414
-
415
- self.models_loaded = False
416
- self.logger.info("✓ Cleanup complete")
417
-
418
- def get_load_summary(self) -> str:
419
- """Get summary of loaded models"""
420
- lines = ["=== Model Load Summary ==="]
421
-
422
- if self.models_loaded:
423
- lines.append("✅ Models loaded successfully")
424
- lines.append(f" - SAM2: {'Loaded' if self.sam2_predictor else 'Not loaded'}")
425
- lines.append(f" - MatAnyone: {'Loaded' if self.matanyone_model else 'Not loaded'}")
426
- else:
427
- lines.append("❌ Models not fully loaded")
428
- if self.load_errors:
429
- lines.append("Errors:")
430
- for error in self.load_errors:
431
- lines.append(f" - {error}")
432
-
433
- lines.append("")
434
- lines.append(format_memory_status(get_memory_info()))
435
-
436
- return "\n".join(lines)
437
-
438
- def get_matanyone(self):
439
- """Get MatAnyone model for backward compatibility"""
440
- return self.matanyone_model
441
-
442
- def get_sam2(self):
443
- """Get SAM2 predictor for backward compatibility"""
444
- return self.sam2_predictor
445
-
446
- # ============================================================================
447
- # MAIN EXECUTION
448
- # ============================================================================
449
-
450
- def main():
451
- """Test model loading"""
452
- logging.basicConfig(
453
- level=logging.INFO,
454
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
455
- )
456
-
457
- print("Starting model loader test...")
458
- print(f"Project root: {PROJECT_ROOT}")
459
- print(f"Python path includes: {sys.path[:3]}")
460
-
461
- # Create loader
462
- loader = ModelLoader()
463
-
464
- # Load models
465
- success = loader.load_models()
466
-
467
- if success:
468
- print("\n✅ Models loaded successfully!")
469
-
470
- # Validate models
471
- valid, results = loader.validate_models()
472
-
473
- print("\nValidation Results:")
474
- print(f" SAM2: {results['sam2']}")
475
- print(f" MatAnyone: {results['matanyone']}")
476
-
477
- if valid:
478
- print("\n✅ All models validated successfully!")
479
- else:
480
- print("\n⚠️ Some models failed validation")
481
- else:
482
- print("\n❌ Failed to load models")
483
- print(loader.get_load_summary())
484
-
485
- # Cleanup
486
- input("\nPress Enter to cleanup and exit...")
487
- loader.cleanup()
488
- print("Done!")
489
-
490
- if __name__ == "__main__":
491
- main()