MogensR committed on
Commit
7b9f1c5
·
1 Parent(s): 69083e6

Update utils/refinement.py

Browse files
Files changed (1) hide show
  1. utils/refinement.py +181 -94
utils/refinement.py CHANGED
@@ -5,7 +5,7 @@
5
  """
6
 
7
  from __future__ import annotations
8
- from typing import Any, Optional, Tuple
9
  import logging
10
 
11
  import cv2
@@ -26,11 +26,12 @@ class MaskRefinementError(Exception):
26
  # ============================================================================
27
  __all__ = [
28
  "refine_mask_hq",
 
29
  "MaskRefinementError",
30
  ]
31
 
32
  # ============================================================================
33
- # MAIN API
34
  # ============================================================================
35
  def refine_mask_hq(
36
  image: np.ndarray,
@@ -77,7 +78,50 @@ def refine_mask_hq(
77
  return mask
78
 
79
  # ============================================================================
80
- # AI-BASED REFINEMENT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # ============================================================================
82
  def _refine_with_matanyone(
83
  image: np.ndarray,
@@ -86,7 +130,6 @@ def _refine_with_matanyone(
86
  ) -> np.ndarray:
87
  """Use MatAnyone model for mask refinement."""
88
  try:
89
- # MatAnyone's InferenceCore expects torch tensors
90
  # Convert BGR to RGB and normalize
91
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
92
  h, w = image_rgb.shape[:2]
@@ -105,74 +148,42 @@ def _refine_with_matanyone(
105
  mask_tensor = torch.from_numpy(mask).float() / 255.0
106
  mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
107
 
108
- # MatAnyone InferenceCore workflow for single frame
109
- # The model should have been initialized as InferenceCore(matanyone_model)
110
  result = None
111
 
112
- if hasattr(model, 'process_frame'):
113
- # Single frame processing method
114
- with torch.no_grad():
115
- result = model.process_frame(image_tensor, mask_tensor)
116
- elif hasattr(model, 'step'):
117
- # Step method for iterative processing
118
- with torch.no_grad():
119
- # Initialize memory with first frame
120
- model.reset()
121
- # Process frame with mask
122
  result = model.step(image_tensor, mask_tensor)
123
- elif hasattr(model, 'forward'):
124
- # Direct forward pass
125
- with torch.no_grad():
126
  result = model.forward(image_tensor, mask_tensor)
127
- elif hasattr(model, 'predict'):
128
- # Predict method
129
- with torch.no_grad():
130
- result = model.predict(image_tensor, mask_tensor)
131
- elif hasattr(model, '__call__'):
132
- # Callable model
133
- with torch.no_grad():
134
  result = model(image_tensor, mask_tensor)
135
- else:
136
- # Try to find any method that might work
137
- methods = [m for m in dir(model) if not m.startswith('_')]
138
- processing_methods = [m for m in methods if any(keyword in m.lower()
139
- for keyword in ['process', 'refine', 'matte', 'alpha', 'predict'])]
140
- if processing_methods:
141
- method = getattr(model, processing_methods[0])
142
- with torch.no_grad():
143
- result = method(image_tensor, mask_tensor)
144
  else:
145
- raise MaskRefinementError(f"MatAnyone model has no recognized processing method. Available methods: {methods}")
146
 
147
  if result is None:
148
  raise MaskRefinementError("MatAnyone returned None")
149
 
150
- # Handle different return types
151
- if isinstance(result, tuple) or isinstance(result, list):
152
- # Extract alpha matte from tuple/list result
153
- alpha = result[0] if len(result) > 0 else None
154
- elif isinstance(result, dict):
155
- # Extract from dictionary result
156
- alpha = result.get('alpha', result.get('matte', result.get('mask', None)))
157
- else:
158
- alpha = result
159
-
160
- if alpha is None:
161
- raise MaskRefinementError("Could not extract alpha matte from MatAnyone result")
162
 
163
- # Convert back to numpy
164
  if isinstance(alpha, torch.Tensor):
165
- alpha = alpha.squeeze().cpu().numpy() # Remove batch dimensions
166
-
167
- # Ensure proper shape
168
  if alpha.ndim == 3:
169
  alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0)
170
-
171
- # Convert to uint8
172
  if alpha.dtype != np.uint8:
173
  alpha = (alpha * 255).clip(0, 255).astype(np.uint8)
174
-
175
- # Resize if needed
176
  if alpha.shape != (h, w):
177
  alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)
178
 
@@ -183,44 +194,121 @@ def _refine_with_matanyone(
183
  raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}")
184
 
185
  # ============================================================================
186
- # CLASSICAL REFINEMENT
187
  # ============================================================================
188
- def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
189
- """Apply classical CV techniques for mask refinement."""
190
- refined = mask.copy()
191
-
192
- # 1. Morphological operations to clean up
193
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
194
- refined = cv2.morphologyEx(refined, cv2.MORPH_CLOSE, kernel)
195
- refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # 2. Edge-aware smoothing
198
- refined = _edge_aware_smooth(image, refined)
199
 
200
- # 3. Feather edges slightly
201
- refined = _feather_edges(refined, radius=3)
202
 
203
- # 4. Remove small disconnected components
204
- refined = _remove_small_components(refined, min_area_ratio=0.005)
205
 
206
- return refined
207
 
208
- # ============================================================================
209
- # HELPER FUNCTIONS
210
- # ============================================================================
211
  def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool:
212
  """Check if refined mask is reasonable."""
213
  if refined is None or refined.size == 0:
214
  return False
215
 
216
- # Check if mask has reasonable coverage
217
  refined_area = np.sum(refined > 127)
218
  original_area = np.sum(original > 127)
219
 
220
  if refined_area == 0:
221
  return False
222
 
223
- # Allow some variation but not extreme changes
224
  ratio = refined_area / max(original_area, 1)
225
  return 0.5 <= ratio <= 2.0
226
 
@@ -239,41 +327,45 @@ def _process_mask(mask: np.ndarray) -> np.ndarray:
239
  _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
240
  return binary
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
243
  """Apply edge-aware smoothing using guided filter."""
244
- # Convert to float for processing
245
  mask_float = mask.astype(np.float32) / 255.0
246
-
247
- # Simple guided filter approximation
248
  radius = 5
249
  eps = 0.01
250
 
251
- # Use image as guide
252
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
253
 
254
- # Box filter for mean
255
  mean_I = cv2.boxFilter(gray, -1, (radius, radius))
256
  mean_p = cv2.boxFilter(mask_float, -1, (radius, radius))
257
  mean_Ip = cv2.boxFilter(gray * mask_float, -1, (radius, radius))
258
 
259
- # Covariance
260
  cov_Ip = mean_Ip - mean_I * mean_p
261
-
262
- # Variance
263
  mean_II = cv2.boxFilter(gray * gray, -1, (radius, radius))
264
  var_I = mean_II - mean_I * mean_I
265
 
266
- # Coefficients
267
  a = cov_Ip / (var_I + eps)
268
  b = mean_p - a * mean_I
269
 
270
- # Filter
271
  mean_a = cv2.boxFilter(a, -1, (radius, radius))
272
  mean_b = cv2.boxFilter(b, -1, (radius, radius))
273
 
274
  refined = mean_a * gray + mean_b
275
-
276
- # Convert back to binary
277
  return (refined * 255).clip(0, 255).astype(np.uint8)
278
 
279
  def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
@@ -281,10 +373,8 @@ def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
281
  if radius <= 0:
282
  return mask
283
 
284
- # Blur then threshold to maintain binary nature
285
  blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2)
286
  _, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)
287
-
288
  return binary
289
 
290
  def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
@@ -294,18 +384,15 @@ def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) ->
294
  if num_labels <= 1:
295
  return mask
296
 
297
- # Calculate minimum area
298
  total_area = mask.shape[0] * mask.shape[1]
299
  min_area = int(total_area * min_area_ratio)
300
 
301
- # Find largest component (excluding background)
302
  areas = stats[1:, cv2.CC_STAT_AREA]
303
  if len(areas) == 0:
304
  return mask
305
 
306
  max_label = np.argmax(areas) + 1
307
 
308
- # Keep only components above threshold or the largest one
309
  cleaned = np.zeros_like(mask)
310
  for label in range(1, num_labels):
311
  if stats[label, cv2.CC_STAT_AREA] >= min_area or label == max_label:
 
5
  """
6
 
7
  from __future__ import annotations
8
+ from typing import Any, Optional, Tuple, List
9
  import logging
10
 
11
  import cv2
 
26
  # ============================================================================
27
  __all__ = [
28
  "refine_mask_hq",
29
+ "refine_masks_batch",
30
  "MaskRefinementError",
31
  ]
32
 
33
  # ============================================================================
34
+ # MAIN API - SINGLE FRAME
35
  # ============================================================================
36
  def refine_mask_hq(
37
  image: np.ndarray,
 
78
  return mask
79
 
80
  # ============================================================================
81
+ # BATCH PROCESSING FOR TEMPORAL CONSISTENCY
82
+ # ============================================================================
83
def refine_masks_batch(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> List[np.ndarray]:
    """
    Refine a sequence of masks, preferring MatAnyone over classical CV.

    The MatAnyone result is only accepted when it contains exactly one mask
    per input mask and every refined mask passes ``_validate_refined_mask``.
    Otherwise the function falls back to per-frame classical refinement
    (when enabled) or returns the input masks unchanged.

    Args:
        frames: List of BGR images
        masks: List of initial binary masks, one per frame
        matanyone_model: MatAnyone InferenceCore model, or None to skip it
        fallback_enabled: Whether to use classical refinement as fallback

    Returns:
        List of refined masks; the original ``masks`` object when either
        input is empty or when no refinement could be applied.

    Raises:
        MaskRefinementError: If both lists are non-empty but differ in length.
    """
    if not frames or not masks:
        return masks

    if len(frames) != len(masks):
        raise MaskRefinementError(f"Frame count {len(frames)} doesn't match mask count {len(masks)}")

    if matanyone_model is not None:
        try:
            refined = _refine_batch_with_matanyone(frames, masks, matanyone_model)
            # zip() stops at the shorter sequence, so a result list that is
            # too short would otherwise pass validation and be returned with
            # fewer masks than frames. Check the count explicitly.
            if len(refined) == len(masks) and all(
                _validate_refined_mask(r, m) for r, m in zip(refined, masks)
            ):
                return refined
            log.warning("Batch MatAnyone refinement failed validation")
        except Exception as e:
            # Best-effort: any model failure degrades to the fallback path.
            log.warning("Batch MatAnyone refinement failed: %s", e)

    # Fallback to frame-by-frame classical refinement
    if fallback_enabled:
        return [_classical_refinement(f, m) for f, m in zip(frames, masks)]

    return masks
122
+
123
+ # ============================================================================
124
+ # AI-BASED REFINEMENT - SINGLE FRAME
125
  # ============================================================================
126
  def _refine_with_matanyone(
127
  image: np.ndarray,
 
130
  ) -> np.ndarray:
131
  """Use MatAnyone model for mask refinement."""
132
  try:
 
133
  # Convert BGR to RGB and normalize
134
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
135
  h, w = image_rgb.shape[:2]
 
148
  mask_tensor = torch.from_numpy(mask).float() / 255.0
149
  mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
150
 
151
+ # Try different methods on InferenceCore
 
152
  result = None
153
 
154
+ # Log available methods for debugging
155
+ methods = [m for m in dir(model) if not m.startswith('_')]
156
+ log.debug(f"MatAnyone InferenceCore methods: {methods}")
157
+
158
+ with torch.no_grad():
159
+ if hasattr(model, 'step'):
160
+ # Step method for iterative processing (don't call reset)
 
 
 
161
  result = model.step(image_tensor, mask_tensor)
162
+ elif hasattr(model, 'process_frame'):
163
+ result = model.process_frame(image_tensor, mask_tensor)
164
+ elif hasattr(model, 'forward'):
165
  result = model.forward(image_tensor, mask_tensor)
166
+ elif hasattr(model, '__call__'):
 
 
 
 
 
 
167
  result = model(image_tensor, mask_tensor)
 
 
 
 
 
 
 
 
 
168
  else:
169
+ raise MaskRefinementError(f"No recognized method. Available: {methods}")
170
 
171
  if result is None:
172
  raise MaskRefinementError("MatAnyone returned None")
173
 
174
+ # Extract alpha matte from result
175
+ alpha = _extract_alpha_from_result(result)
 
 
 
 
 
 
 
 
 
 
176
 
177
+ # Convert back to numpy and resize if needed
178
  if isinstance(alpha, torch.Tensor):
179
+ alpha = alpha.squeeze().cpu().numpy()
180
+
 
181
  if alpha.ndim == 3:
182
  alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0)
183
+
 
184
  if alpha.dtype != np.uint8:
185
  alpha = (alpha * 255).clip(0, 255).astype(np.uint8)
186
+
 
187
  if alpha.shape != (h, w):
188
  alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)
189
 
 
194
  raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}")
195
 
196
  # ============================================================================
197
+ # AI-BASED REFINEMENT - BATCH
198
  # ============================================================================
199
def _mask_to_unit_tensor(mask: np.ndarray) -> torch.Tensor:
    """Convert a mask array to a float tensor in [0, 1] shaped (1, 1, H, W)."""
    if mask.dtype != np.uint8:
        # Non-uint8 masks are scaled by 255; clip first so out-of-range
        # values cannot wrap around when cast to uint8.
        mask = np.clip(mask * 255, 0, 255).astype(np.uint8)
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # torch.from_numpy rejects some non-contiguous views (e.g. flipped slices).
    mask = np.ascontiguousarray(mask)
    return (torch.from_numpy(mask).float() / 255.0).unsqueeze(0).unsqueeze(0)


def _refine_batch_with_matanyone(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    model: Any
) -> List[np.ndarray]:
    """
    Process a batch of frames through MatAnyone for temporal consistency.

    Dispatches on the methods the model exposes:
      * ``process_batch``: one call with the stacked (N, C, H, W) frames and
        the first mask.
      * ``step``: sequential calls; only the first call receives a mask and
        later calls pass ``None`` (presumably the model carries state
        between calls - confirm against the MatAnyone InferenceCore API).
      * otherwise: the model is called per frame with that frame's own mask.

    Args:
        frames: List of BGR images
        masks: List of initial masks, one per frame
        model: MatAnyone model object

    Returns:
        List of uint8 masks sized to the first frame's height and width.

    Raises:
        MaskRefinementError: Wraps any exception raised during processing.
    """
    try:
        h, w = frames[0].shape[:2]

        # Per-frame (C, H, W) float tensors in [0, 1], RGB channel order.
        frame_tensors = [
            torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1).float() / 255.0
            for frame in frames
        ]

        first_mask_tensor = _mask_to_unit_tensor(masks[0])

        refined_masks = []

        with torch.no_grad():
            if hasattr(model, 'process_batch'):
                # Stack only on this path: it is the only consumer of the
                # batch tensor, and stacking fails for mixed frame sizes.
                batch_tensor = torch.stack(frame_tensors)
                results = model.process_batch(batch_tensor, first_mask_tensor)
                for result in results:
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))

            elif hasattr(model, 'step'):
                for i, frame_tensor in enumerate(frame_tensors):
                    # Only the first frame is seeded with a mask.
                    seed = first_mask_tensor if i == 0 else None
                    result = model.step(frame_tensor.unsqueeze(0), seed)
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))

            else:
                log.warning("MatAnyone batch processing not available, using frame-by-frame")
                for frame_tensor, mask in zip(frame_tensors, masks):
                    # Normalise every mask the same way as the first one so
                    # float or 3-channel masks do not break this path.
                    mask_tensor = _mask_to_unit_tensor(mask)
                    result = model(frame_tensor.unsqueeze(0), mask_tensor)
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))

        return refined_masks

    except Exception as e:
        log.error("Batch MatAnyone processing error: %s", e)
        raise MaskRefinementError(f"Batch processing failed: {str(e)}") from e
270
+
271
+ # ============================================================================
272
+ # HELPER FUNCTIONS
273
+ # ============================================================================
274
+ def _extract_alpha_from_result(result):
275
+ """Extract alpha matte from various result formats."""
276
+ if isinstance(result, (tuple, list)):
277
+ return result[0] if len(result) > 0 else None
278
+ elif isinstance(result, dict):
279
+ return result.get('alpha', result.get('matte', result.get('mask', None)))
280
+ else:
281
+ return result
282
+
283
+ def _tensor_to_mask(tensor, target_h, target_w):
284
+ """Convert tensor to numpy mask with proper sizing."""
285
+ if isinstance(tensor, torch.Tensor):
286
+ mask = tensor.squeeze().cpu().numpy()
287
+ else:
288
+ mask = tensor
289
 
290
+ if mask.ndim == 3:
291
+ mask = mask[0] if mask.shape[0] == 1 else mask.mean(axis=0)
292
 
293
+ if mask.dtype != np.uint8:
294
+ mask = (mask * 255).clip(0, 255).astype(np.uint8)
295
 
296
+ if mask.shape != (target_h, target_w):
297
+ mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
298
 
299
+ return mask
300
 
 
 
 
301
  def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool:
302
  """Check if refined mask is reasonable."""
303
  if refined is None or refined.size == 0:
304
  return False
305
 
 
306
  refined_area = np.sum(refined > 127)
307
  original_area = np.sum(original > 127)
308
 
309
  if refined_area == 0:
310
  return False
311
 
 
312
  ratio = refined_area / max(original_area, 1)
313
  return 0.5 <= ratio <= 2.0
314
 
 
327
  _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
328
  return binary
329
 
330
+ # ============================================================================
331
+ # CLASSICAL REFINEMENT
332
+ # ============================================================================
333
def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Apply classical CV techniques for mask refinement.

    Steps, in order:
      1. Morphological close then open with a 5x5 elliptical kernel.
      2. Edge-aware smoothing guided by ``image``.
      3. Edge feathering with radius 3.
      4. Removal of small connected components (area ratio below 0.005).

    Args:
        image: BGR image the mask belongs to
        mask: Initial mask; it is not modified

    Returns:
        The refined mask.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    # morphologyEx returns a new array, so no defensive copy is needed.
    refined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)

    refined = _edge_aware_smooth(image, refined)
    refined = _feather_edges(refined, radius=3)
    refined = _remove_small_components(refined, min_area_ratio=0.005)

    return refined
345
+
def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Apply edge-aware smoothing using a box-filter guided filter.

    The grayscale image is the guide: within each window the mask is
    modelled as ``a * gray + b``, and the window coefficients are then
    averaged before being applied.

    Args:
        image: Guide image; BGR, or already single-channel
        mask: Mask on a 0-255 scale, same height and width as ``image``

    Returns:
        Smoothed uint8 mask. It is not re-thresholded, so values between 0
        and 255 can occur.
    """
    ksize = (5, 5)  # full box-filter window size, not a radius
    eps = 0.01      # regulariser; keeps the division stable in flat regions

    def _box(values: np.ndarray) -> np.ndarray:
        return cv2.boxFilter(values, -1, ksize)

    # A single-channel guide is used directly; cvtColor(BGR2GRAY) would
    # raise on 2-D input.
    guide = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = guide.astype(np.float32) / 255.0
    mask_float = mask.astype(np.float32) / 255.0

    mean_I = _box(gray)
    mean_p = _box(mask_float)
    mean_Ip = _box(gray * mask_float)
    mean_II = _box(gray * gray)

    cov_Ip = mean_Ip - mean_I * mean_p
    var_I = mean_II - mean_I * mean_I

    # Per-window linear coefficients.
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    refined = _box(a) * gray + _box(b)

    return (refined * 255).clip(0, 255).astype(np.uint8)
370
 
371
  def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
 
373
  if radius <= 0:
374
  return mask
375
 
 
376
  blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2)
377
  _, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)
 
378
  return binary
379
 
380
  def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
 
384
  if num_labels <= 1:
385
  return mask
386
 
 
387
  total_area = mask.shape[0] * mask.shape[1]
388
  min_area = int(total_area * min_area_ratio)
389
 
 
390
  areas = stats[1:, cv2.CC_STAT_AREA]
391
  if len(areas) == 0:
392
  return mask
393
 
394
  max_label = np.argmax(areas) + 1
395
 
 
396
  cleaned = np.zeros_like(mask)
397
  for label in range(1, num_labels):
398
  if stats[label, cv2.CC_STAT_AREA] >= min_area or label == max_label: