MogensR commited on
Commit
69083e6
·
1 Parent(s): 1109131

Update utils/refinement.py

Browse files
Files changed (1) hide show
  1. utils/refinement.py +97 -15
utils/refinement.py CHANGED
@@ -10,6 +10,7 @@
10
 
11
  import cv2
12
  import numpy as np
 
13
 
14
  log = logging.getLogger(__name__)
15
 
@@ -84,21 +85,102 @@ def _refine_with_matanyone(
84
  model: Any
85
  ) -> np.ndarray:
86
  """Use MatAnyone model for mask refinement."""
87
- # Check if model has expected interface
88
- if hasattr(model, 'process'):
89
- result = model.process(image, mask)
90
- elif hasattr(model, 'refine'):
91
- result = model.refine(image, mask)
92
- elif callable(model):
93
- result = model(image, mask)
94
- else:
95
- raise MaskRefinementError("MatAnyone model doesn't have expected interface")
96
-
97
- # Convert result to binary mask
98
- if result is None:
99
- raise MaskRefinementError("MatAnyone returned None")
100
-
101
- return _process_mask(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  # ============================================================================
104
  # CLASSICAL REFINEMENT
 
10
 
11
  import cv2
12
  import numpy as np
13
+ import torch
14
 
15
  log = logging.getLogger(__name__)
16
 
 
85
  model: Any
86
  ) -> np.ndarray:
87
  """Use MatAnyone model for mask refinement."""
88
+ try:
89
+ # MatAnyone's InferenceCore expects torch tensors
90
+ # Convert BGR to RGB and normalize
91
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
92
+ h, w = image_rgb.shape[:2]
93
+
94
+ # Convert to torch tensor format (C, H, W) and normalize to [0, 1]
95
+ image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
96
+ image_tensor = image_tensor.unsqueeze(0) # Add batch dimension (1, C, H, W)
97
+
98
+ # Ensure mask is binary uint8
99
+ if mask.dtype != np.uint8:
100
+ mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
101
+ if mask.ndim == 3:
102
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
103
+
104
+ # Convert mask to tensor
105
+ mask_tensor = torch.from_numpy(mask).float() / 255.0
106
+ mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
107
+
108
+ # MatAnyone InferenceCore workflow for single frame
109
+ # The model should have been initialized as InferenceCore(matanyone_model)
110
+ result = None
111
+
112
+ if hasattr(model, 'process_frame'):
113
+ # Single frame processing method
114
+ with torch.no_grad():
115
+ result = model.process_frame(image_tensor, mask_tensor)
116
+ elif hasattr(model, 'step'):
117
+ # Step method for iterative processing
118
+ with torch.no_grad():
119
+ # Initialize memory with first frame
120
+ model.reset()
121
+ # Process frame with mask
122
+ result = model.step(image_tensor, mask_tensor)
123
+ elif hasattr(model, 'forward'):
124
+ # Direct forward pass
125
+ with torch.no_grad():
126
+ result = model.forward(image_tensor, mask_tensor)
127
+ elif hasattr(model, 'predict'):
128
+ # Predict method
129
+ with torch.no_grad():
130
+ result = model.predict(image_tensor, mask_tensor)
131
+ elif hasattr(model, '__call__'):
132
+ # Callable model
133
+ with torch.no_grad():
134
+ result = model(image_tensor, mask_tensor)
135
+ else:
136
+ # Try to find any method that might work
137
+ methods = [m for m in dir(model) if not m.startswith('_')]
138
+ processing_methods = [m for m in methods if any(keyword in m.lower()
139
+ for keyword in ['process', 'refine', 'matte', 'alpha', 'predict'])]
140
+ if processing_methods:
141
+ method = getattr(model, processing_methods[0])
142
+ with torch.no_grad():
143
+ result = method(image_tensor, mask_tensor)
144
+ else:
145
+ raise MaskRefinementError(f"MatAnyone model has no recognized processing method. Available methods: {methods}")
146
+
147
+ if result is None:
148
+ raise MaskRefinementError("MatAnyone returned None")
149
+
150
+ # Handle different return types
151
+ if isinstance(result, tuple) or isinstance(result, list):
152
+ # Extract alpha matte from tuple/list result
153
+ alpha = result[0] if len(result) > 0 else None
154
+ elif isinstance(result, dict):
155
+ # Extract from dictionary result
156
+ alpha = result.get('alpha', result.get('matte', result.get('mask', None)))
157
+ else:
158
+ alpha = result
159
+
160
+ if alpha is None:
161
+ raise MaskRefinementError("Could not extract alpha matte from MatAnyone result")
162
+
163
+ # Convert back to numpy
164
+ if isinstance(alpha, torch.Tensor):
165
+ alpha = alpha.squeeze().cpu().numpy() # Remove batch dimensions
166
+
167
+ # Ensure proper shape
168
+ if alpha.ndim == 3:
169
+ alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0)
170
+
171
+ # Convert to uint8
172
+ if alpha.dtype != np.uint8:
173
+ alpha = (alpha * 255).clip(0, 255).astype(np.uint8)
174
+
175
+ # Resize if needed
176
+ if alpha.shape != (h, w):
177
+ alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)
178
+
179
+ return _process_mask(alpha)
180
+
181
+ except Exception as e:
182
+ log.error(f"MatAnyone processing error: {str(e)}")
183
+ raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}")
184
 
185
  # ============================================================================
186
  # CLASSICAL REFINEMENT