MogensR committed on
Commit
1aea709
·
1 Parent(s): a099dfd

Create models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +215 -0
models/loaders/matanyone_loader.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MatAnyone Model Loader
4
+ Handles MatAnyone loading with proper device initialization
5
+ """
6
+
7
+ import os
8
+ import time
9
+ import logging
10
+ import traceback
11
+ from pathlib import Path
12
+ from typing import Optional, Dict, Any
13
+
14
+ import torch
15
+ import numpy as np
16
+
17
logger = logging.getLogger(__name__)


class MatAnyoneLoader:
    """Load a MatAnyone processor, with a pass-through stand-in as last resort.

    ``load()`` first tries the ``matanyone`` package's ``InferenceCore`` and,
    if that raises, falls back to a tiny local object exposing the same
    ``step`` / ``process`` method names that only smooths the incoming mask.
    """

    def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
        """Remember the target device and make sure the cache directory exists.

        Args:
            device: Device string handed to ``.to()`` for the model and inputs.
            cache_dir: Directory created on construction (``exist_ok=True``).
        """
        self.device = device
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        self.model: Optional[Any] = None
        self.model_id = "PeiqingYang/MatAnyone"
        self.load_time = 0.0

    def load(self) -> Optional[Any]:
        """Load the MatAnyone model, trying each strategy in order.

        Returns:
            The first processor a strategy produces, or ``None`` when every
            strategy raised or returned ``None``.
        """
        logger.info(f"Loading MatAnyone model: {self.model_id}")

        # Ordered from most to least capable; the first success wins.
        strategies = (
            ("official", self._load_official),
            ("fallback", self._load_fallback),
        )

        for strategy_name, strategy_func in strategies:
            logger.info(f"Trying MatAnyone loading strategy: {strategy_name}")
            # perf_counter is monotonic, so the duration cannot go negative
            # if the wall clock is adjusted while loading.
            start_time = time.perf_counter()
            try:
                model = strategy_func()
            except Exception as e:
                # Boundary catch: any loader failure just moves on to the
                # next strategy instead of aborting the whole load.
                logger.error(f"MatAnyone {strategy_name} strategy failed: {e}")
                logger.debug(traceback.format_exc())
                continue

            # Compare against None explicitly: a valid model object may be
            # falsy (e.g. a container type defining __len__).
            if model is None:
                logger.debug(f"MatAnyone {strategy_name} strategy returned no model")
                continue

            self.load_time = time.perf_counter() - start_time
            self.model = model
            logger.info(
                f"MatAnyone loaded successfully via {strategy_name} in {self.load_time:.2f}s"
            )
            return model

        logger.error("All MatAnyone loading strategies failed")
        return None

    def _load_official(self) -> Optional[Any]:
        """Build an ``InferenceCore`` from the ``matanyone`` package.

        Raises:
            ImportError: If ``matanyone`` is not installed (handled by ``load``).
        """
        # Imported lazily so this module can be imported without matanyone.
        from matanyone import InferenceCore

        # The model ID is passed as the single positional argument.
        processor = InferenceCore(self.model_id)

        # Only overwrite attributes the processor actually exposes.
        if hasattr(processor, 'device'):
            processor.device = self.device

        if hasattr(processor, 'model'):
            if hasattr(processor.model, 'to'):
                processor.model = processor.model.to(self.device)
            if hasattr(processor.model, 'eval'):
                processor.model.eval()

        # Wrap step/process so callers may pass numpy arrays or tensors.
        self._patch_processor(processor)

        return processor

    def _patch_processor(self, processor):
        """Replace ``processor.step`` / ``processor.process`` with safe wrappers.

        The wrappers convert numpy inputs to tensors on ``self.device``,
        reshape and normalise them, then delegate to the original ``step``.
        Any exception inside a wrapper is logged and the (possibly partially
        converted) mask is returned instead of propagating.

        Args:
            processor: Object whose ``step`` and/or ``process`` attributes,
                when present, are replaced in place.
        """
        original_step = getattr(processor, 'step', None)
        device = self.device

        def to_device(value):
            """Return ``value`` as a tensor on ``device``; pass other types through."""
            if isinstance(value, np.ndarray):
                return torch.from_numpy(value).to(device)
            if isinstance(value, torch.Tensor):
                return value.to(device)
            return value

        def safe_step(image, mask=None, idx_mask=False, **kwargs):
            """Wrapped step function with proper device handling.

            ``mask`` is optional so that mask-less calls reach the original
            ``step`` instead of failing inside the wrapper.
            """
            try:
                image = to_device(image)
                mask = to_device(mask)

                if image.dim() == 3:
                    # A trailing dim of 1/3/4 is taken to mean HWC -> CHW.
                    if image.shape[-1] in (1, 3, 4):
                        image = image.permute(2, 0, 1)
                    # NOTE(review): a batch dim is added here; upstream
                    # InferenceCore.step may expect a 3xHxW image - confirm.
                    image = image.unsqueeze(0)

                if image.dtype != torch.float32:
                    image = image.float()
                # Values above 1.0 are assumed to be on a 0-255 scale.
                if image.max() > 1.0:
                    image = image / 255.0

                if mask is not None:
                    if mask.dim() == 2:
                        mask = mask.unsqueeze(0)  # Add channel dimension
                    # Index masks keep their integer labels untouched.
                    if not idx_mask:
                        if mask.dtype != torch.float32:
                            mask = mask.float()
                        if mask.max() > 1.0:
                            mask = mask / 255.0

                if original_step is None:
                    # Nothing to delegate to: hand the prepared mask back.
                    return mask
                return original_step(image, mask, idx_mask=idx_mask, **kwargs)

            except Exception as e:
                logger.error(f"MatAnyone step failed: {e}")
                logger.debug(traceback.format_exc())
                # Best-effort: return the mask rather than crash the caller.
                return mask

        def safe_process(image, mask=None, **kwargs):
            """Wrapped process function; delegates to ``safe_step``."""
            # Pop idx_mask so it is not passed twice (explicitly and in kwargs).
            idx_mask = kwargs.pop("idx_mask", False)
            return safe_step(image, mask, idx_mask=idx_mask, **kwargs)

        if hasattr(processor, 'step'):
            processor.step = safe_step
            logger.info("Patched MatAnyone step method for device safety")

        if hasattr(processor, 'process'):
            processor.process = safe_process
            logger.info("Patched MatAnyone process method for device safety")

    def _load_fallback(self) -> Optional[Any]:
        """Create a stand-in processor that only smooths the mask."""

        class FallbackMatAnyone:
            """Minimal object exposing ``step`` / ``process`` like the real processor."""

            def __init__(self, device):
                self.device = device

            def step(self, image, mask, idx_mask=False, **kwargs):
                """Return ``mask`` with a light Gaussian blur where possible.

                Only 2-D and 3-D numpy masks are blurred (3-D masks plane by
                plane along axis 0); anything else is returned unchanged, as
                is every mask when OpenCV is not installed.
                """
                if not isinstance(mask, np.ndarray) or mask.ndim not in (2, 3):
                    return mask

                try:
                    import cv2
                except ImportError:
                    # The fallback must never be the thing that fails.
                    logger.warning("OpenCV not available; returning mask unsmoothed")
                    return mask

                if mask.ndim == 2:
                    return cv2.GaussianBlur(mask, (5, 5), 1.0)

                smoothed = np.zeros_like(mask)
                for i, plane in enumerate(mask):
                    smoothed[i] = cv2.GaussianBlur(plane, (5, 5), 1.0)
                return smoothed

            def process(self, image, mask, **kwargs):
                """Alias for step"""
                return self.step(image, mask, **kwargs)

        logger.warning("Using fallback MatAnyone (limited refinement)")
        return FallbackMatAnyone(self.device)

    def cleanup(self):
        """Drop the loaded model and release cached CUDA memory if available."""
        if self.model is not None:
            del self.model
            self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_info(self) -> Dict[str, Any]:
        """Return a summary dict: loaded flag, model id, device, load time, type name."""
        loaded = self.model is not None
        return {
            "loaded": loaded,
            "model_id": self.model_id,
            "device": self.device,
            "load_time": self.load_time,
            # Same None test as "loaded" so the two keys can never disagree.
            "model_type": type(self.model).__name__ if loaded else None,
        }