Update utils/refinement.py
utils/refinement.py (+21 -13)
```diff
--- a/utils/refinement.py
+++ b/utils/refinement.py
@@ -130,13 +130,16 @@ def _refine_with_matanyone(
 ) -> np.ndarray:
     """Use MatAnyone model for mask refinement."""
     try:
+        # Set device to GPU (Tesla T4 on cuda:0)
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
         # Convert BGR to RGB and normalize
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         h, w = image_rgb.shape[:2]
 
         # Convert to torch tensor format (C, H, W) and normalize to [0, 1]
         image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
-        image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension
+        image_tensor = image_tensor.unsqueeze(0).to(device)  # Add batch dimension and move to GPU
 
         # Ensure mask is binary uint8
         if mask.dtype != np.uint8:
@@ -144,9 +147,9 @@ def _refine_with_matanyone(
         if mask.ndim == 3:
             mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
-        # Convert mask to tensor
+        # Convert mask to tensor and move to GPU
         mask_tensor = torch.from_numpy(mask).float() / 255.0
-        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
+        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, H, W) on GPU
 
         # Try different methods on InferenceCore
         result = None
@@ -157,7 +160,7 @@ def _refine_with_matanyone(
 
         with torch.no_grad():
             if hasattr(model, 'step'):
-                # Step method for iterative processing
+                # Step method for iterative processing
                 result = model.step(image_tensor, mask_tensor)
             elif hasattr(model, 'process_frame'):
                 result = model.process_frame(image_tensor, mask_tensor)
@@ -203,18 +206,21 @@ def _refine_batch_with_matanyone(
 ) -> List[np.ndarray]:
     """Process batch of frames through MatAnyone for temporal consistency."""
     try:
+        # Set device to GPU (Tesla T4 on cuda:0)
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
         batch_size = len(frames)
         h, w = frames[0].shape[:2]
 
-        # Convert frames to tensor batch
+        # Convert frames to tensor batch and move to GPU
         frame_tensors = []
         for frame in frames:
             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
             frame_tensors.append(tensor)
 
-        # Stack into batch (N, C, H, W)
-        batch_tensor = torch.stack(frame_tensors)
+        # Stack into batch (N, C, H, W) and move to GPU
+        batch_tensor = torch.stack(frame_tensors).to(device)
 
         # Prepare first mask for initialization
         first_mask = masks[0]
@@ -223,9 +229,9 @@ def _refine_batch_with_matanyone(
         if first_mask.ndim == 3:
             first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY)
 
-        # Convert first mask to tensor
+        # Convert first mask to tensor and move to GPU
        first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
-        first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0)
+        first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
 
         refined_masks = []
 
@@ -241,12 +247,13 @@ def _refine_batch_with_matanyone(
             elif hasattr(model, 'step'):
                 # Process frames sequentially with memory
                 for i, frame_tensor in enumerate(frame_tensors):
+                    frame_on_device = frame_tensor.unsqueeze(0).to(device)
                     if i == 0:
                         # First frame with mask
-                        result = model.step(
+                        result = model.step(frame_on_device, first_mask_tensor)
                     else:
                         # Subsequent frames use memory from previous
-                        result = model.step(
+                        result = model.step(frame_on_device, None)
 
                     alpha = _extract_alpha_from_result(result)
                     refined_masks.append(_tensor_to_mask(alpha, h, w))
@@ -256,9 +263,10 @@ def _refine_batch_with_matanyone(
             log.warning("MatAnyone batch processing not available, using frame-by-frame")
             for frame_tensor, mask in zip(frame_tensors, masks):
                 mask_tensor = torch.from_numpy(mask).float() / 255.0
-                mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)
+                mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
+                frame_on_device = frame_tensor.unsqueeze(0).to(device)
 
-                result = model(
+                result = model(frame_on_device, mask_tensor)
                 alpha = _extract_alpha_from_result(result)
                 refined_masks.append(_tensor_to_mask(alpha, h, w))
 
```
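Both functions now follow the same standard PyTorch device-placement pattern: resolve the device once per call, then move every input tensor onto it before inference. A minimal self-contained sketch of that pattern is below; the helper names are hypothetical, invented for illustration, and only the tensor math is taken from the diff:

```python
import cv2
import numpy as np
import torch

# Pick the GPU when one is visible, otherwise fall back to CPU so the
# same code path still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def to_batched_image(image_bgr: np.ndarray) -> torch.Tensor:
    """BGR uint8 (H, W, 3) image -> (1, 3, H, W) float tensor in [0, 1] on `device`."""
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
    return tensor.unsqueeze(0).to(device)

def to_batched_mask(mask: np.ndarray) -> torch.Tensor:
    """Grayscale uint8 (H, W) mask -> (1, 1, H, W) float tensor in [0, 1] on `device`."""
    return (torch.from_numpy(mask).float() / 255.0).unsqueeze(0).unsqueeze(0).to(device)
```

`.to(device)` is a no-op when a tensor already lives on the target device, so the CPU fallback adds no overhead. The model weights must sit on the same device as the inputs; this change assumes that happens at model load time.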
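The `model.step` branch of the batch path also illustrates the seed-and-propagate idiom of memory-based video matting: only the first frame is paired with a mask, and subsequent frames rely on the model's internal memory. A sketch of that loop shape, assuming the `step(frame, mask_or_none)` signature shown in the diff rather than a verified MatAnyone API:

```python
from typing import List, Optional

import torch

def propagate_masks(model, frames: List[torch.Tensor],
                    first_mask: torch.Tensor) -> List[torch.Tensor]:
    """Run a step-based matting model over a clip, seeding it with one mask."""
    outputs = []
    with torch.no_grad():
        for i, frame in enumerate(frames):
            # Only the first call receives a mask; afterwards the model's
            # internal memory carries the object matte forward in time.
            mask: Optional[torch.Tensor] = first_mask if i == 0 else None
            outputs.append(model.step(frame, mask))
    return outputs
```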