MogensR committed on
Commit
a8a12b2
·
1 Parent(s): 7b9f1c5

Update utils/refinement.py

Browse files
Files changed (1) hide show
  1. utils/refinement.py +21 -13
utils/refinement.py CHANGED
@@ -130,13 +130,16 @@ def _refine_with_matanyone(
130
  ) -> np.ndarray:
131
  """Use MatAnyone model for mask refinement."""
132
  try:
 
 
 
133
  # Convert BGR to RGB and normalize
134
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
135
  h, w = image_rgb.shape[:2]
136
 
137
  # Convert to torch tensor format (C, H, W) and normalize to [0, 1]
138
  image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
139
- image_tensor = image_tensor.unsqueeze(0) # Add batch dimension (1, C, H, W)
140
 
141
  # Ensure mask is binary uint8
142
  if mask.dtype != np.uint8:
@@ -144,9 +147,9 @@ def _refine_with_matanyone(
144
  if mask.ndim == 3:
145
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
146
 
147
- # Convert mask to tensor
148
  mask_tensor = torch.from_numpy(mask).float() / 255.0
149
- mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
150
 
151
  # Try different methods on InferenceCore
152
  result = None
@@ -157,7 +160,7 @@ def _refine_with_matanyone(
157
 
158
  with torch.no_grad():
159
  if hasattr(model, 'step'):
160
- # Step method for iterative processing (don't call reset)
161
  result = model.step(image_tensor, mask_tensor)
162
  elif hasattr(model, 'process_frame'):
163
  result = model.process_frame(image_tensor, mask_tensor)
@@ -203,18 +206,21 @@ def _refine_batch_with_matanyone(
203
  ) -> List[np.ndarray]:
204
  """Process batch of frames through MatAnyone for temporal consistency."""
205
  try:
 
 
 
206
  batch_size = len(frames)
207
  h, w = frames[0].shape[:2]
208
 
209
- # Convert frames to tensor batch
210
  frame_tensors = []
211
  for frame in frames:
212
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
213
  tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
214
  frame_tensors.append(tensor)
215
 
216
- # Stack into batch (N, C, H, W)
217
- batch_tensor = torch.stack(frame_tensors)
218
 
219
  # Prepare first mask for initialization
220
  first_mask = masks[0]
@@ -223,9 +229,9 @@ def _refine_batch_with_matanyone(
223
  if first_mask.ndim == 3:
224
  first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY)
225
 
226
- # Convert first mask to tensor
227
  first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
228
- first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0)
229
 
230
  refined_masks = []
231
 
@@ -241,12 +247,13 @@ def _refine_batch_with_matanyone(
241
  elif hasattr(model, 'step'):
242
  # Process frames sequentially with memory
243
  for i, frame_tensor in enumerate(frame_tensors):
 
244
  if i == 0:
245
  # First frame with mask
246
- result = model.step(frame_tensor.unsqueeze(0), first_mask_tensor)
247
  else:
248
  # Subsequent frames use memory from previous
249
- result = model.step(frame_tensor.unsqueeze(0), None)
250
 
251
  alpha = _extract_alpha_from_result(result)
252
  refined_masks.append(_tensor_to_mask(alpha, h, w))
@@ -256,9 +263,10 @@ def _refine_batch_with_matanyone(
256
  log.warning("MatAnyone batch processing not available, using frame-by-frame")
257
  for frame_tensor, mask in zip(frame_tensors, masks):
258
  mask_tensor = torch.from_numpy(mask).float() / 255.0
259
- mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)
 
260
 
261
- result = model(frame_tensor.unsqueeze(0), mask_tensor)
262
  alpha = _extract_alpha_from_result(result)
263
  refined_masks.append(_tensor_to_mask(alpha, h, w))
264
 
 
130
  ) -> np.ndarray:
131
  """Use MatAnyone model for mask refinement."""
132
  try:
133
+ # Set device to GPU (Tesla T4 on cuda:0)
134
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
135
+
136
  # Convert BGR to RGB and normalize
137
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
138
  h, w = image_rgb.shape[:2]
139
 
140
  # Convert to torch tensor format (C, H, W) and normalize to [0, 1]
141
  image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
142
+ image_tensor = image_tensor.unsqueeze(0).to(device) # Add batch dimension and move to GPU
143
 
144
  # Ensure mask is binary uint8
145
  if mask.dtype != np.uint8:
 
147
  if mask.ndim == 3:
148
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
149
 
150
+ # Convert mask to tensor and move to GPU
151
  mask_tensor = torch.from_numpy(mask).float() / 255.0
152
+ mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device) # (1, 1, H, W) on GPU
153
 
154
  # Try different methods on InferenceCore
155
  result = None
 
160
 
161
  with torch.no_grad():
162
  if hasattr(model, 'step'):
163
+ # Step method for iterative processing
164
  result = model.step(image_tensor, mask_tensor)
165
  elif hasattr(model, 'process_frame'):
166
  result = model.process_frame(image_tensor, mask_tensor)
 
206
  ) -> List[np.ndarray]:
207
  """Process batch of frames through MatAnyone for temporal consistency."""
208
  try:
209
+ # Set device to GPU (Tesla T4 on cuda:0)
210
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
211
+
212
  batch_size = len(frames)
213
  h, w = frames[0].shape[:2]
214
 
215
+ # Convert frames to tensor batch and move to GPU
216
  frame_tensors = []
217
  for frame in frames:
218
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
219
  tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
220
  frame_tensors.append(tensor)
221
 
222
+ # Stack into batch (N, C, H, W) and move to GPU
223
+ batch_tensor = torch.stack(frame_tensors).to(device)
224
 
225
  # Prepare first mask for initialization
226
  first_mask = masks[0]
 
229
  if first_mask.ndim == 3:
230
  first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY)
231
 
232
+ # Convert first mask to tensor and move to GPU
233
  first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
234
+ first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
235
 
236
  refined_masks = []
237
 
 
247
  elif hasattr(model, 'step'):
248
  # Process frames sequentially with memory
249
  for i, frame_tensor in enumerate(frame_tensors):
250
+ frame_on_device = frame_tensor.unsqueeze(0).to(device)
251
  if i == 0:
252
  # First frame with mask
253
+ result = model.step(frame_on_device, first_mask_tensor)
254
  else:
255
  # Subsequent frames use memory from previous
256
+ result = model.step(frame_on_device, None)
257
 
258
  alpha = _extract_alpha_from_result(result)
259
  refined_masks.append(_tensor_to_mask(alpha, h, w))
 
263
  log.warning("MatAnyone batch processing not available, using frame-by-frame")
264
  for frame_tensor, mask in zip(frame_tensors, masks):
265
  mask_tensor = torch.from_numpy(mask).float() / 255.0
266
+ mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
267
+ frame_on_device = frame_tensor.unsqueeze(0).to(device)
268
 
269
+ result = model(frame_on_device, mask_tensor)
270
  alpha = _extract_alpha_from_result(result)
271
  refined_masks.append(_tensor_to_mask(alpha, h, w))
272