File size: 10,369 Bytes
baea23e
a00a1ac
23796fb
 
a00a1ac
 
9685fa7
 
a00a1ac
 
 
 
baea23e
a00a1ac
 
 
baea23e
 
 
a00a1ac
23796fb
 
 
 
a00a1ac
 
9685fa7
7d67503
23796fb
 
 
7d67503
 
 
 
 
 
9685fa7
7d67503
 
 
 
 
9685fa7
7d67503
 
9685fa7
a00a1ac
23796fb
 
644d7d5
bcb51b3
 
23796fb
 
 
 
 
 
 
9685fa7
 
23796fb
 
a00a1ac
9685fa7
 
 
 
 
a00a1ac
23796fb
baea23e
462ff09
9685fa7
23796fb
 
 
9685fa7
 
23796fb
 
 
 
 
 
 
 
9685fa7
a00a1ac
9685fa7
23796fb
a00a1ac
 
bcb51b3
 
23796fb
 
a00a1ac
23796fb
 
bcb51b3
23796fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260d38d
23796fb
 
 
 
9685fa7
 
 
23796fb
 
bcb51b3
23796fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260d38d
23796fb
 
 
a00a1ac
9685fa7
 
23796fb
 
bcb51b3
9685fa7
23796fb
260d38d
23796fb
 
260d38d
9685fa7
23796fb
a00a1ac
 
23796fb
a00a1ac
9685fa7
23796fb
bcb51b3
 
23796fb
bcb51b3
462ff09
23796fb
 
 
 
 
9685fa7
 
 
 
 
 
 
23796fb
9685fa7
 
 
23796fb
 
9685fa7
 
23796fb
 
9685fa7
 
23796fb
9685fa7
23796fb
 
 
9685fa7
23796fb
 
 
 
 
 
 
 
 
 
 
 
9685fa7
 
 
 
 
23796fb
9685fa7
 
 
 
 
19adb9d
23796fb
 
 
 
 
19adb9d
9685fa7
 
 
23796fb
9685fa7
23796fb
 
 
 
 
9685fa7
23796fb
 
9685fa7
23796fb
 
9685fa7
23796fb
 
9685fa7
23796fb
 
 
 
 
9685fa7
 
23796fb
9685fa7
 
 
a00a1ac
23796fb
 
 
 
 
a00a1ac
 
23796fb
 
 
 
 
a00a1ac
 
23796fb
 
260d38d
 
23796fb
 
a00a1ac
23796fb
19adb9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
#!/usr/bin/env python3
"""
Unified Model Loader
Coordinates separate SAM2 and MatAnyone loaders for cleaner architecture
"""

from __future__ import annotations

import os
import gc
import time
import logging
from typing import Optional, Dict, Any, Tuple, Callable

import torch

from core.exceptions import ModelLoadingError
from utils.hardware.device_manager import DeviceManager
from utils.system.memory_manager import MemoryManager

# Import the specialized loaders
from models.loaders.sam2_loader import SAM2Loader
from models.loaders.matanyone_loader import MatAnyoneLoader

logger = logging.getLogger(__name__)


class LoadedModel:
    """Record pairing a loaded model object with metadata about its load.

    Attributes:
        model: The loaded model object, or None when nothing is loaded.
        model_id: Identifier of the model.
        load_time: How long the load took (ModelLoader passes seconds).
        device: Name of the device the model was loaded for.
        framework: Label of the framework/loader that produced the model.
    """

    def __init__(
        self,
        model=None,
        model_id: str = "",
        load_time: float = 0.0,
        device: str = "",
        framework: str = "",
    ):
        self.model, self.model_id = model, model_id
        self.load_time = load_time
        self.device, self.framework = device, framework

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a plain dict.

        The model object itself is not included; it is reported only
        through the boolean ``loaded`` flag.
        """
        is_loaded = self.model is not None
        return dict(
            model_id=self.model_id,
            framework=self.framework,
            device=self.device,
            load_time=self.load_time,
            loaded=is_loaded,
        )


class ModelLoader:
    """Main model loader that coordinates the SAM2 and MatAnyone loaders.

    The class owns one ``SAM2Loader`` and one ``MatAnyoneLoader``, keeps the
    results wrapped in ``LoadedModel`` containers, and tracks timing
    statistics for the most recent loading attempt in ``loading_stats``.
    """

    # Status marks used by get_load_summary(). Written as escapes so the
    # source stays pure ASCII and cannot be garbled by a wrong file encoding.
    _MARK_OK = "\u2713"
    _MARK_FAILED = "\u2717"

    def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
        """Create the specialized loaders for the device chosen by *device_mgr*.

        Args:
            device_mgr: Provides ``get_optimal_device()``; its result is the
                device both loaders are created for.
            memory_mgr: Stored as ``self.memory_manager``; not otherwise used
                by this class.
        """
        self.device_manager = device_mgr
        self.memory_manager = memory_mgr
        self.device = self.device_manager.get_optimal_device()

        # Initialize specialized loaders
        self.sam2_loader = SAM2Loader(device=str(self.device))
        self.matanyone_loader = MatAnyoneLoader(device=str(self.device))

        # Model storage
        self.sam2_predictor: Optional[LoadedModel] = None
        self.matanyone_model: Optional[LoadedModel] = None

        # Statistics
        self.loading_stats = {
            "sam2_load_time": 0.0,
            "matanyone_load_time": 0.0,
            "total_load_time": 0.0,
            "models_loaded": False,
            "loading_attempts": 0,
        }

        logger.info(f"ModelLoader initialized for device: {self.device}")

    def load_all_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """
        Load all models using specialized loaders

        Any previously loaded models are cleaned up first and the per-attempt
        statistics are reset, so ``loading_stats`` always describes this
        attempt only. Cancellation is checked once, between the SAM2 and the
        MatAnyone load.

        Args:
            progress_callback: Optional callback for progress updates,
                called with a fraction in [0.0, 1.0] and a status message
            cancel_event: Optional threading.Event for cancellation

        Returns:
            Tuple of (sam2_model, matanyone_model). Either entry is None when
            that model failed to load; the second entry is None when loading
            was cancelled; both are None when an exception occurred.
        """
        start_time = time.time()
        self.loading_stats["loading_attempts"] += 1

        try:
            logger.info("Starting model loading process...")
            if progress_callback:
                progress_callback(0.0, "Initializing model loading...")

            # Clean up any existing models
            self._cleanup_models()

            # The old models are gone now, so drop the statistics that
            # described them; otherwise a cancelled or failed reload would
            # keep reporting the previous attempt's values.
            self.loading_stats.update(
                sam2_load_time=0.0,
                matanyone_load_time=0.0,
                total_load_time=0.0,
                models_loaded=False,
            )

            # Load SAM2
            if progress_callback:
                progress_callback(0.1, "Loading SAM2 model...")
            self.sam2_predictor = self._load_single(self.sam2_loader, "sam2", "SAM2")

            # Check for cancellation
            if cancel_event and cancel_event.is_set():
                # SAM2 may already be loaded and is returned to the caller,
                # so the statistics must reflect it.
                self._record_totals(start_time)
                if progress_callback:
                    progress_callback(1.0, "Model loading cancelled")
                return self.sam2_predictor, None

            # Load MatAnyone
            if progress_callback:
                progress_callback(0.6, "Loading MatAnyone model...")
            self.matanyone_model = self._load_single(
                self.matanyone_loader, "matanyone", "MatAnyone"
            )

            # Update statistics
            total_time = self._record_totals(start_time)

            # Final progress update
            if progress_callback:
                if self.loading_stats["models_loaded"]:
                    progress_callback(1.0, "Models loaded successfully")
                else:
                    progress_callback(1.0, "Model loading completed with failures")

            logger.info(f"Model loading completed in {total_time:.2f}s")
            return self.sam2_predictor, self.matanyone_model

        except Exception as e:
            error_msg = f"Model loading failed: {str(e)}"
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(error_msg)
            self._cleanup_models()
            self.loading_stats["models_loaded"] = False

            if progress_callback:
                progress_callback(1.0, f"Error: {error_msg}")

            return None, None

    def _load_single(self, loader, framework: str, label: str) -> Optional[LoadedModel]:
        """Run one specialized loader, time it, and wrap a successful result.

        Args:
            loader: Object exposing ``load()`` and ``model_id``.
            framework: Framework tag stored on the wrapper; also the prefix of
                the ``<framework>_load_time`` statistics key.
            label: Human-readable name used in log messages.

        Returns:
            A LoadedModel wrapper, or None when ``loader.load()`` returned a
            falsy value.
        """
        started = time.time()
        model = loader.load()
        elapsed = time.time() - started

        if not model:
            logger.warning(f"{label} loading failed")
            return None

        wrapper = LoadedModel(
            model=model,
            model_id=loader.model_id,
            load_time=elapsed,
            device=str(self.device),
            framework=framework,
        )
        self.loading_stats[f"{framework}_load_time"] = elapsed
        logger.info(f"{label} loaded in {elapsed:.2f}s")
        return wrapper

    def _record_totals(self, start_time: float) -> float:
        """Store the elapsed total time and the models_loaded flag; return the total."""
        total_time = time.time() - start_time
        self.loading_stats["total_load_time"] = total_time
        self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)
        return total_time

    def reload_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """Reload all models from scratch

        Args:
            progress_callback: Optional callback forwarded to load_all_models
            cancel_event: Optional threading.Event forwarded to load_all_models

        Returns:
            Tuple of (sam2_model, matanyone_model) as from load_all_models
        """
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats["models_loaded"] = False
        return self.load_all_models(progress_callback, cancel_event)

    @property
    def models_ready(self) -> bool:
        """Check if any models are loaded and ready"""
        return self.sam2_predictor is not None or self.matanyone_model is not None

    def get_sam2(self):
        """Get SAM2 predictor model, or None when it is not loaded"""
        return self.sam2_predictor.model if self.sam2_predictor else None

    def get_matanyone(self):
        """Get MatAnyone processor model, or None when it is not loaded"""
        return self.matanyone_model.model if self.matanyone_model else None

    def validate_models(self) -> bool:
        """Validate that loaded models have expected interfaces

        SAM2 is accepted when it has both ``set_image`` and ``predict``;
        MatAnyone when it has ``step`` or ``process``.

        Returns:
            True when at least one loaded model passes its check, False when
            none does or when the check itself raises.
        """
        try:
            valid = False

            if self.sam2_predictor:
                model = self.sam2_predictor.model
                if hasattr(model, "set_image") and hasattr(model, "predict"):
                    valid = True
                    logger.info("SAM2 model validated")

            if self.matanyone_model:
                model = self.matanyone_model.model
                if hasattr(model, "step") or hasattr(model, "process"):
                    valid = True
                    logger.info("MatAnyone model validated")

            return valid

        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get detailed information about loaded models

        Returns:
            Dict with the ``models_loaded`` flag, the device name, a copy of
            ``loading_stats`` and each loader's ``get_info()`` result under
            the ``sam2`` and ``matanyone`` keys.
        """
        info = {
            "models_loaded": self.loading_stats["models_loaded"],
            "device": str(self.device),
            "loading_stats": self.loading_stats.copy(),
        }

        # Add SAM2 info
        info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}

        # Add MatAnyone info
        info["matanyone"] = self.matanyone_loader.get_info() if self.matanyone_loader else {}

        return info

    def get_load_summary(self) -> str:
        """Get human-readable loading summary

        Returns:
            "No models loaded" when the ``models_loaded`` statistic is false;
            otherwise a newline-joined report with the total time, one entry
            per model (load time and model id, or a failure note) and the
            device.
        """
        stats = self.loading_stats
        if not stats["models_loaded"]:
            return "No models loaded"

        lines = [f"Models loaded in {stats['total_load_time']:.1f}s"]

        entries = (
            ("SAM2", self.sam2_predictor, "sam2_load_time"),
            ("MatAnyone", self.matanyone_model, "matanyone_load_time"),
        )
        for label, wrapper, time_key in entries:
            if wrapper:
                lines.append(f"{self._MARK_OK} {label}: {stats[time_key]:.1f}s")
                lines.append(f"  Model: {wrapper.model_id}")
            else:
                lines.append(f"{self._MARK_FAILED} {label}: Failed to load")

        lines.append(f"Device: {self.device}")

        return "\n".join(lines)

    def cleanup(self):
        """Clean up all resources"""
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")

    def _cleanup_models(self):
        """Internal cleanup of loaded models

        Calls ``cleanup()`` on both loaders, drops the LoadedModel wrappers,
        runs the garbage collector and, when CUDA is available, empties the
        CUDA cache.
        """
        # Clean up SAM2
        if self.sam2_loader:
            self.sam2_loader.cleanup()
        self.sam2_predictor = None

        # Clean up MatAnyone
        if self.matanyone_loader:
            self.matanyone_loader.cleanup()
        self.matanyone_model = None

        # Collect garbage first: empty_cache() can only release memory that
        # is no longer occupied, so unreachable objects must be freed before it.
        gc.collect()

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.debug("Model cleanup completed")