MogensR committed on
Commit
9f687f8
·
1 Parent(s): 2450c76

Update utilities.py

Browse files
Files changed (1) hide show
  1. utilities.py +110 -1
utilities.py CHANGED
@@ -400,7 +400,7 @@ def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any
400
  raise MaskRefinementError(f"Unexpected error: {e}")
401
 
402
  def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
403
- """Attempt MatAnyone mask refinement"""
404
  try:
405
  # Different possible MatAnyone interfaces
406
  if hasattr(processor, 'infer'):
@@ -419,12 +419,121 @@ def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Op
419
  # Process the refined mask
420
  refined_mask = _process_mask(refined_mask)
421
 
 
422
  return refined_mask
423
 
424
  except Exception as e:
425
  logger.warning(f"MatAnyone processing error: {e}")
426
  return None
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
429
  """
430
  Advanced OpenCV-based mask enhancement with multiple techniques
 
400
  raise MaskRefinementError(f"Unexpected error: {e}")
401
 
402
  def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
403
+ """Attempt MatAnyone mask refinement - Python 3.10 compatible"""
404
  try:
405
  # Different possible MatAnyone interfaces
406
  if hasattr(processor, 'infer'):
 
419
  # Process the refined mask
420
  refined_mask = _process_mask(refined_mask)
421
 
422
+ logger.debug("MatAnyone refinement successful")
423
  return refined_mask
424
 
425
  except Exception as e:
426
  logger.warning(f"MatAnyone processing error: {e}")
427
  return None
428
 
429
def _background_matting_v2_refine(image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
    """Refine a mask with BackgroundMattingV2 when that package is importable.

    A trimap is derived from ``mask`` and handed, together with the image,
    to ``inference_img``; the returned alpha is converted to a uint8 mask.

    Args:
        image: Image array, indexed as (H, W, C) by the ``permute`` below.
            NOTE(review): no BGR->RGB conversion happens here -- confirm the
            channel order ``inference_img`` expects.
        mask: Mask array used to build the trimap. It is thresholded at 127
            there, so presumably 0-255 valued -- TODO confirm with callers.

    Returns:
        The refined mask as uint8 in 0-255, or ``None`` when the dependency
        is missing or inference raises.
    """
    try:
        # Optional dependency: imported lazily so this module loads without it
        from inference_images import inference_img
        import torch

        # HWC -> CHW and scale to [0, 1]
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        # Trimap derived from the incoming mask
        trimap = _create_trimap_from_mask(mask)
        trimap_tensor = torch.from_numpy(trimap).float()

        # Prefer the GPU when one is available
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        with torch.no_grad():
            alpha = inference_img(
                image_tensor.unsqueeze(0).to(device),
                trimap_tensor.unsqueeze(0).unsqueeze(0).to(device),
            )

        # Back to numpy; clip before the uint8 cast so alpha values slightly
        # outside [0, 1] saturate instead of wrapping around
        refined_mask = alpha.cpu().squeeze().numpy()
        refined_mask = np.clip(refined_mask * 255, 0, 255).astype(np.uint8)

        logger.info("BackgroundMattingV2 refinement successful")
        return refined_mask

    except ImportError:
        logger.warning("BackgroundMattingV2 not available")
        return None
    except Exception as e:
        # Best-effort refinement: report the failure and let the caller fall back
        logger.warning(f"BackgroundMattingV2 error: {e}")
        return None
466
+
467
def _rembg_refine(image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
    """Refine a mask by blending it with the alpha matte produced by rembg.

    The image is run through rembg's ``u2net`` model, the alpha channel of
    the result is taken as a second mask, and the two masks are blended
    70% rembg / 30% original.

    Args:
        image: Image array converted with ``cv2.COLOR_BGR2RGB`` before being
            passed to rembg, so it is treated as BGR.
        mask: Existing mask; divided by 255 before blending, so presumably
            0-255 valued and the same height/width as ``image`` -- TODO
            confirm with callers.

    Returns:
        The blended mask as uint8 in 0-255, or ``None`` when rembg/PIL are
        missing or processing raises.
    """
    try:
        # Optional dependencies: imported lazily so this module loads without them
        from rembg import remove, new_session
        from PIL import Image

        # Creating a session loads the model, so build it once and reuse it
        # on later calls instead of reloading it every time
        session = getattr(_rembg_refine, "_session", None)
        if session is None:
            session = new_session('u2net')
            _rembg_refine._session = session

        # rembg works on PIL images in RGB order
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        # Remove background
        output = remove(pil_image, session=session)

        # Extract alpha channel as mask
        if output.mode == 'RGBA':
            alpha = np.array(output)[:, :, 3]
        else:
            # Fallback: convert to grayscale
            alpha = np.array(output.convert('L'))

        # Normalise both masks to [0, 1] before blending
        original_mask_norm = mask.astype(np.float32) / 255.0
        rembg_mask_norm = alpha.astype(np.float32) / 255.0

        # Weighted combination: 70% rembg, 30% original
        combined = 0.7 * rembg_mask_norm + 0.3 * original_mask_norm
        combined = np.clip(combined * 255, 0, 255).astype(np.uint8)

        logger.info("Rembg refinement successful")
        return combined

    except ImportError:
        logger.warning("Rembg not available")
        return None
    except Exception as e:
        # Best-effort refinement: report the failure and let the caller fall back
        logger.warning(f"Rembg error: {e}")
        return None
506
+
507
def _create_trimap_from_mask(mask: np.ndarray, erode_size: int = 10, dilate_size: int = 20) -> np.ndarray:
    """Build a three-level trimap from a mask for BackgroundMattingV2.

    Output values: 0 = background, 128 = unknown band, 255 = sure foreground.

    Args:
        mask: Input mask; binarised at a threshold of 127.
        erode_size: Side length of the elliptical kernel used to shrink the
            mask into the sure-foreground region.
        dilate_size: Side length of the elliptical kernel used to grow the
            mask; the grown area outside the sure foreground is "unknown".

    Returns:
        A uint8 trimap. If any OpenCV step raises, a plain two-level
        (0 / 255) threshold of ``mask`` is returned instead.
    """
    try:
        # Binarise: values above 127 become 255, everything else 0
        binarized = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]

        # Start from an all-background canvas
        result = np.zeros_like(mask, dtype=np.uint8)

        # Shrinking the mask leaves only pixels that are safely foreground
        fg_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_size, erode_size))
        definite_fg = cv2.erode(binarized, fg_kernel, iterations=1)

        # Growing the mask marks the outer limit of the uncertain band
        band_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_size, dilate_size))
        expanded = cv2.dilate(binarized, band_kernel, iterations=1)

        is_foreground = definite_fg == 255
        is_uncertain = (expanded == 255) & (definite_fg == 0)
        result[is_foreground] = 255  # Sure foreground
        result[is_uncertain] = 128  # Unknown
        # Everything else stays 0 (background)
    except Exception as e:
        logger.warning(f"Trimap creation failed: {e}")
        # Degrade to a simple two-level trimap based on the original mask
        return np.where(mask > 127, 255, 0).astype(np.uint8)

    return result
536
+
537
  def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
538
  """
539
  Advanced OpenCV-based mask enhancement with multiple techniques