MogensR commited on
Commit
0f94e43
·
1 Parent(s): 345218c

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +96 -44
models/loaders/matanyone_loader.py CHANGED
@@ -87,20 +87,52 @@ def _patch_processor(self, processor):
87
  """
88
  Patch the MatAnyone processor to handle device placement and tensor formats correctly
89
  """
90
- original_step = None
91
- original_process = None
92
-
93
- if hasattr(processor, 'step'):
94
- original_step = processor.step
95
- if hasattr(processor, 'process'):
96
- original_process = processor.process
97
 
98
  device = self.device
99
 
100
- def safe_step(image, mask, idx_mask=False, **kwargs):
101
- """Wrapped step function with proper device handling"""
102
  try:
103
- # Ensure inputs are tensors on the correct device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  if isinstance(image, np.ndarray):
105
  image = torch.from_numpy(image).to(device)
106
  elif isinstance(image, torch.Tensor):
@@ -111,61 +143,81 @@ def safe_step(image, mask, idx_mask=False, **kwargs):
111
  elif isinstance(mask, torch.Tensor):
112
  mask = mask.to(device)
113
 
114
- # Handle image format (ensure CHW or NCHW)
115
- if image.dim() == 3:
116
- # HWC to CHW if needed
117
- if image.shape[-1] in [1, 3, 4]:
118
- image = image.permute(2, 0, 1)
119
- # Add batch dimension if needed
120
- if image.dim() == 3:
121
- image = image.unsqueeze(0)
 
 
 
 
 
 
122
 
123
- # Handle mask format
124
  if mask.dim() == 2:
125
- mask = mask.unsqueeze(0) # Add channel dimension
 
 
 
126
 
127
- # Ensure float tensors
128
  if image.dtype != torch.float32:
129
  image = image.float()
130
  if not idx_mask and mask.dtype != torch.float32:
131
  mask = mask.float()
132
 
133
- # Normalize if needed
134
  if image.max() > 1.0:
135
  image = image / 255.0
136
  if not idx_mask and mask.max() > 1.0:
137
  mask = mask / 255.0
138
 
139
- # Call original method
140
  if original_step:
141
- return original_step(image, mask, idx_mask=idx_mask, **kwargs)
142
- else:
143
- # Fallback if no original method
144
- return mask
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- except Exception as e:
147
- logger.error(f"MatAnyone step failed: {e}")
148
- logger.debug(traceback.format_exc())
149
- # Return input mask as fallback
150
  return mask
151
-
152
- def safe_process(image, mask, **kwargs):
153
- """Wrapped process function with proper device handling"""
154
- try:
155
- # Use safe_step for processing
156
- return safe_step(image, mask, idx_mask=False, **kwargs)
157
  except Exception as e:
158
- logger.error(f"MatAnyone process failed: {e}")
159
- return mask
 
 
 
 
 
 
 
160
 
161
- # Apply patches
162
  if hasattr(processor, 'step'):
163
- processor.step = safe_step
164
- logger.info("Patched MatAnyone step method for device safety")
165
 
166
  if hasattr(processor, 'process'):
167
- processor.process = safe_process
168
- logger.info("Patched MatAnyone process method for device safety")
 
 
 
169
 
170
  def _load_fallback(self) -> Optional[Any]:
171
  """Create fallback processor for testing"""
 
87
  """
88
  Patch the MatAnyone processor to handle device placement and tensor formats correctly
89
  """
90
+ original_step = getattr(processor, 'step', None)
91
+ original_process = getattr(processor, 'process', None)
 
 
 
 
 
92
 
93
  device = self.device
94
 
95
+ def safe_wrapper(*args, **kwargs):
96
+ """Universal wrapper that handles both step and process calls"""
97
  try:
98
+ # Handle different calling patterns
99
+ # Pattern 1: step(image, mask, idx_mask=False)
100
+ # Pattern 2: process(image, mask)
101
+ # Pattern 3: Called with just args
102
+ # Pattern 4: Called with kwargs
103
+
104
+ image = None
105
+ mask = None
106
+ idx_mask = kwargs.get('idx_mask', False)
107
+
108
+ # Extract image and mask
109
+ if 'image' in kwargs and 'mask' in kwargs:
110
+ image = kwargs['image']
111
+ mask = kwargs['mask']
112
+ elif len(args) >= 2:
113
+ image = args[0]
114
+ mask = args[1]
115
+ if len(args) > 2:
116
+ idx_mask = args[2]
117
+ elif len(args) == 1:
118
+ # Might be called with just mask for refinement
119
+ mask = args[0]
120
+ # Create dummy image if needed
121
+ if isinstance(mask, np.ndarray):
122
+ h, w = mask.shape[:2] if mask.ndim >= 2 else (512, 512)
123
+ image = np.zeros((h, w, 3), dtype=np.uint8)
124
+ elif isinstance(mask, torch.Tensor):
125
+ h, w = mask.shape[-2:] if mask.dim() >= 2 else (512, 512)
126
+ image = torch.zeros((h, w, 3), dtype=torch.uint8)
127
+
128
+ if image is None or mask is None:
129
+ logger.error(f"MatAnyone called with invalid args: {len(args)} args, kwargs: {kwargs.keys()}")
130
+ # Return something safe
131
+ if mask is not None:
132
+ return mask
133
+ return np.ones((512, 512), dtype=np.float32) * 0.5
134
+
135
+ # Convert to tensors on correct device
136
  if isinstance(image, np.ndarray):
137
  image = torch.from_numpy(image).to(device)
138
  elif isinstance(image, torch.Tensor):
 
143
  elif isinstance(mask, torch.Tensor):
144
  mask = mask.to(device)
145
 
146
+ # Fix image format (ensure CHW or NCHW)
147
+ if image.dim() == 2: # Grayscale HW
148
+ image = image.unsqueeze(0) # CHW
149
+ elif image.dim() == 3:
150
+ # Check if HWC or CHW
151
+ if image.shape[-1] in [1, 3, 4]: # HWC
152
+ image = image.permute(2, 0, 1) # CHW
153
+ # Add batch if needed
154
+ if image.shape[0] in [1, 3, 4]: # CHW
155
+ image = image.unsqueeze(0) # NCHW
156
+ elif image.dim() == 4:
157
+ # Already NCHW, ensure correct channel position
158
+ if image.shape[-1] in [1, 3, 4]: # NHWC
159
+ image = image.permute(0, 3, 1, 2) # NCHW
160
 
161
+ # Fix mask format
162
  if mask.dim() == 2:
163
+ mask = mask.unsqueeze(0) # Add channel: CHW
164
+ elif mask.dim() == 3:
165
+ if mask.shape[0] > 4: # Likely HWC
166
+ mask = mask.permute(2, 0, 1) # CHW
167
 
168
+ # Ensure float and normalized
169
  if image.dtype != torch.float32:
170
  image = image.float()
171
  if not idx_mask and mask.dtype != torch.float32:
172
  mask = mask.float()
173
 
 
174
  if image.max() > 1.0:
175
  image = image / 255.0
176
  if not idx_mask and mask.max() > 1.0:
177
  mask = mask / 255.0
178
 
179
+ # Call original method if it exists
180
  if original_step:
181
+ try:
182
+ result = original_step(image, mask, idx_mask=idx_mask)
183
+ # Convert result back to numpy if needed
184
+ if isinstance(result, torch.Tensor):
185
+ result = result.cpu().numpy()
186
+ return result
187
+ except Exception as e:
188
+ logger.error(f"MatAnyone original step failed: {e}")
189
+
190
+ # Fallback: return slightly processed mask
191
+ if isinstance(mask, torch.Tensor):
192
+ # Apply slight smoothing
193
+ import torch.nn.functional as F
194
+ mask = F.avg_pool2d(mask.unsqueeze(0), 3, stride=1, padding=1)
195
+ mask = mask.squeeze(0).cpu().numpy()
196
 
 
 
 
 
197
  return mask
198
+
 
 
 
 
 
199
  except Exception as e:
200
+ logger.error(f"MatAnyone safe_wrapper failed: {e}")
201
+ import traceback
202
+ logger.debug(traceback.format_exc())
203
+ # Return safe fallback
204
+ if 'mask' in locals() and mask is not None:
205
+ if isinstance(mask, torch.Tensor):
206
+ return mask.cpu().numpy()
207
+ return mask
208
+ return np.ones((512, 512), dtype=np.float32) * 0.5
209
 
210
+ # Apply patches to both methods
211
  if hasattr(processor, 'step'):
212
+ processor.step = safe_wrapper
213
+ logger.info("Patched MatAnyone step method")
214
 
215
  if hasattr(processor, 'process'):
216
+ processor.process = safe_wrapper
217
+ logger.info("Patched MatAnyone process method")
218
+
219
+ # Also add a direct call method
220
+ processor.__call__ = safe_wrapper
221
 
222
  def _load_fallback(self) -> Optional[Any]:
223
  """Create fallback processor for testing"""