Update utilities.py
Browse files- utilities.py +110 -1
utilities.py
CHANGED
|
@@ -400,7 +400,7 @@ def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any
|
|
| 400 |
raise MaskRefinementError(f"Unexpected error: {e}")
|
| 401 |
|
| 402 |
def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
|
| 403 |
-
"""Attempt MatAnyone mask refinement"""
|
| 404 |
try:
|
| 405 |
# Different possible MatAnyone interfaces
|
| 406 |
if hasattr(processor, 'infer'):
|
|
@@ -419,12 +419,121 @@ def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Op
|
|
| 419 |
# Process the refined mask
|
| 420 |
refined_mask = _process_mask(refined_mask)
|
| 421 |
|
|
|
|
| 422 |
return refined_mask
|
| 423 |
|
| 424 |
except Exception as e:
|
| 425 |
logger.warning(f"MatAnyone processing error: {e}")
|
| 426 |
return None
|
| 427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 429 |
"""
|
| 430 |
Advanced OpenCV-based mask enhancement with multiple techniques
|
|
|
|
| 400 |
raise MaskRefinementError(f"Unexpected error: {e}")
|
| 401 |
|
| 402 |
def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
|
| 403 |
+
"""Attempt MatAnyone mask refinement - Python 3.10 compatible"""
|
| 404 |
try:
|
| 405 |
# Different possible MatAnyone interfaces
|
| 406 |
if hasattr(processor, 'infer'):
|
|
|
|
| 419 |
# Process the refined mask
|
| 420 |
refined_mask = _process_mask(refined_mask)
|
| 421 |
|
| 422 |
+
logger.debug("MatAnyone refinement successful")
|
| 423 |
return refined_mask
|
| 424 |
|
| 425 |
except Exception as e:
|
| 426 |
logger.warning(f"MatAnyone processing error: {e}")
|
| 427 |
return None
|
| 428 |
|
| 429 |
+
def _background_matting_v2_refine(image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
    """Refine a mask with BackgroundMattingV2, returning None on any failure.

    The image is turned into a channel-first float tensor scaled by 1/255, a
    trimap is derived from ``mask`` via ``_create_trimap_from_mask``, and both
    are handed to ``inference_img`` as single-item batches on CUDA when it is
    available (CPU otherwise).

    Args:
        image: Image array indexed as (height, width, channels). It is divided
            by 255, so 0-255 values are assumed (TODO confirm with callers).
        mask: Coarse mask used to build the trimap.

    Returns:
        The refined mask as a ``uint8`` array, or ``None`` when the optional
        dependencies cannot be imported or inference raises.
    """
    try:
        # Imported lazily so this module still loads without the optional deps.
        from inference_images import inference_img
        import torch

        # HWC -> CHW, scaled to [0, 1]
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        # NOTE(review): the trimap is passed with its raw 0/128/255 values while
        # the image is normalized -- confirm this is what inference_img expects.
        trimap = _create_trimap_from_mask(mask)
        trimap_tensor = torch.from_numpy(trimap).float()

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        with torch.no_grad():
            alpha = inference_img(
                image_tensor.unsqueeze(0).to(device),
                trimap_tensor.unsqueeze(0).unsqueeze(0).to(device)
            )

        refined_mask = alpha.cpu().squeeze().numpy()
        # Clamp before the uint8 cast: values slightly outside [0, 1] would
        # otherwise wrap around instead of saturating.
        refined_mask = np.clip(refined_mask * 255, 0, 255).astype(np.uint8)

        logger.info("BackgroundMattingV2 refinement successful")
        return refined_mask

    except ImportError:
        logger.warning("BackgroundMattingV2 not available")
        return None
    except Exception as e:
        # Best-effort refinement: callers treat None as "no refinement".
        logger.warning(f"BackgroundMattingV2 error: {e}")
        return None
|
| 466 |
+
|
| 467 |
+
# Sessions are expensive to build (the model is loaded on creation), so they
# are created once per model name and reused across calls.
_REMBG_SESSION_CACHE: dict = {}


def _get_rembg_session(model_name: str = 'u2net'):
    """Return the cached rembg session for ``model_name``, creating it on first use."""
    session = _REMBG_SESSION_CACHE.get(model_name)
    if session is None:
        from rembg import new_session
        session = new_session(model_name)
        _REMBG_SESSION_CACHE[model_name] = session
    return session


def _rembg_refine(image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
    """Refine a mask by blending it with the alpha matte produced by rembg.

    The image is converted from BGR to RGB, passed through rembg's ``remove``
    with a reused ``u2net`` session, and the resulting alpha channel is mixed
    with the incoming mask (70% rembg, 30% original).

    Args:
        image: Image array; converted with ``cv2.COLOR_BGR2RGB`` before use.
        mask: Existing mask. It is divided by 255, so 0-255 values are assumed
            (TODO confirm with callers).

    Returns:
        The blended mask as a ``uint8`` array, or ``None`` when rembg/PIL
        cannot be imported or processing raises.
    """
    try:
        # Imported lazily so this module still loads without the optional deps.
        from rembg import remove
        from PIL import Image

        session = _get_rembg_session('u2net')

        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        output = remove(pil_image, session=session)

        if output.mode == 'RGBA':
            alpha = np.array(output)[:, :, 3]
        else:
            # Fallback when no alpha channel is present: use the grayscale image.
            alpha = np.array(output.convert('L'))

        original_mask_norm = mask.astype(np.float32) / 255.0
        rembg_mask_norm = alpha.astype(np.float32) / 255.0

        # Weighted combination: 70% rembg, 30% original
        combined = 0.7 * rembg_mask_norm + 0.3 * original_mask_norm
        combined = np.clip(combined * 255, 0, 255).astype(np.uint8)

        logger.info("Rembg refinement successful")
        return combined

    except ImportError:
        logger.warning("Rembg not available")
        return None
    except Exception as e:
        # Best-effort refinement: callers treat None as "no refinement".
        logger.warning(f"Rembg error: {e}")
        return None
|
| 506 |
+
|
| 507 |
+
def _create_trimap_from_mask(mask: np.ndarray, erode_size: int = 10, dilate_size: int = 20) -> np.ndarray:
    """Build a three-level trimap (0 / 128 / 255) from a mask.

    The mask is binarized at 127. Pixels that survive erosion are marked 255
    (foreground), pixels inside the dilated mask but outside the eroded core
    are marked 128 (unknown), and everything else stays 0 (background).

    Args:
        mask: Input mask to binarize.
        erode_size: Side length of the elliptical kernel used for erosion.
        dilate_size: Side length of the elliptical kernel used for dilation.

    Returns:
        A ``uint8`` trimap. If any step raises, a warning is logged and a
        two-level (0 / 255) trimap thresholded at 127 is returned instead.
    """
    try:
        binarized = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]

        # Start from all-background and paint the other two levels on top.
        result = np.zeros_like(mask, dtype=np.uint8)

        shrink_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_size, erode_size))
        core = cv2.erode(binarized, shrink_kernel, iterations=1)

        grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_size, dilate_size))
        grown = cv2.dilate(binarized, grow_kernel, iterations=1)

        is_foreground = core == 255
        is_unknown = (grown == 255) & (core == 0)

        result[is_foreground] = 255
        result[is_unknown] = 128
        return result

    except Exception as e:
        logger.warning(f"Trimap creation failed: {e}")
        # Degrade to a plain binary trimap with no unknown band.
        return np.where(mask > 127, 255, 0).astype(np.uint8)
|
| 536 |
+
|
| 537 |
def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 538 |
"""
|
| 539 |
Advanced OpenCV-based mask enhancement with multiple techniques
|