# models/wrappers/matanyone_wrapper.py
import torch
import torch.nn.functional as F
from typing import Optional, Dict, Any, Tuple, Union
import numpy as np
class MatAnyOneWrapper:
def __init__(self, core, device=None, config=None):
"""
Initialize MatAnyone wrapper with enhanced configuration.
Args:
core: MatAnyone InferenceCore instance
device: torch device (auto-detect if None)
config: Optional configuration dict for processing parameters
"""
self.core = core
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.config = config or {}
# Default processing parameters
self.threshold = self.config.get('threshold', 0.5)
self.edge_refinement = self.config.get('edge_refinement', True)
self.hair_refinement = self.config.get('hair_refinement', True)
# Component weights for multi-layer processing
self.component_weights = self.config.get('component_weights', {
'base': 1.0,
'hair': 1.2,
'edge': 1.5,
'detail': 1.1
})
        # Move the underlying model to the target device and switch it to
        # eval mode. Some InferenceCore variants manage this themselves, so
        # failures here are tolerated rather than treated as fatal.
        try:
            self.core.model.to(self.device)
        except Exception:
            pass
        try:
            self.core.model.eval()
        except Exception:
            pass
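    # Example (illustrative, not from the source): constructing the wrapper,
    # assuming `core` is a MatAnyone InferenceCore built elsewhere. The config
    # keys shown are exactly the ones this wrapper reads.
    #
    #   wrapper = MatAnyOneWrapper(
    #       core,
    #       device="cuda",
    #       config={
    #           "threshold": 0.5,
    #           "edge_refinement": True,
    #           "hair_refinement": True,
    #           "component_weights": {"base": 1.0, "hair": 1.2, "edge": 1.5, "detail": 1.1},
    #       },
    #   )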
@torch.inference_mode()
def step(self,
image_tensor: torch.Tensor,
mask_tensor: Optional[torch.Tensor] = None,
objects: Optional[Dict] = None,
first_frame_pred: bool = False,
components: Optional[Dict[str, torch.Tensor]] = None,
**kwargs) -> torch.Tensor:
"""
Process a single frame with optional component masks.
Args:
            image_tensor: (1,3,H,W) float32 in [0, 1]; moved to self.device if needed
            mask_tensor: optional (1,1,H,W) float32 in [0, 1]; moved to self.device if needed
objects: Optional object tracking info
first_frame_pred: Whether this is the first frame
components: Optional dict with keys like 'hair', 'edge', 'detail'
Each value is a (1,1,H,W) tensor
**kwargs: Additional arguments for InferenceCore
Returns:
(1,1,H,W) float32 probabilities in [0..1]
"""
# Ensure everything is on the correct device
image_tensor = image_tensor.to(self.device, non_blocking=True)
if mask_tensor is not None:
mask_tensor = mask_tensor.to(self.device, non_blocking=True)
# Process component masks if provided
if components:
components = {
k: v.to(self.device, non_blocking=True)
for k, v in components.items()
}
# Main inference call
try:
# Adapt to actual InferenceCore API
out = self.core.step(
image_tensor=image_tensor,
mask_tensor=mask_tensor,
first_frame_pred=first_frame_pred,
objects=objects,
**kwargs
)
except TypeError:
# Fallback for different API signatures
out = self.core.step(
frame=image_tensor,
mask=mask_tensor,
**kwargs
)
# Normalize output shape
out = self._normalize_output(out)
# Apply component-based refinement if available
if components:
out = self._refine_with_components(out, components)
# Apply edge refinement if enabled
if self.edge_refinement and mask_tensor is not None:
out = self._refine_edges(out, image_tensor, mask_tensor)
return out.clamp_(0, 1)
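    # Example (illustrative): processing a single frame. Shapes follow the
    # docstring; the tensors here stand in for a real frame and mask.
    #
    #   frame = torch.rand(1, 3, 480, 854, device=wrapper.device)
    #   mask = torch.zeros(1, 1, 480, 854, device=wrapper.device)
    #   alpha = wrapper.step(frame, mask, first_frame_pred=True)  # (1,1,480,854)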
def _normalize_output(self, out: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
"""Normalize output to (1,1,H,W) tensor."""
if isinstance(out, torch.Tensor):
if out.ndim == 3: # (1,H,W) → (1,1,H,W)
out = out.unsqueeze(1)
elif out.ndim == 2: # (H,W) → (1,1,H,W)
out = out.unsqueeze(0).unsqueeze(0)
else:
out = torch.as_tensor(out, dtype=torch.float32, device=self.device)
if out.ndim == 2:
out = out.unsqueeze(0).unsqueeze(0)
elif out.ndim == 3:
out = out.unsqueeze(1)
return out
def _refine_with_components(self,
base_mask: torch.Tensor,
components: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Refine mask using component layers (hair, edge, etc).
Args:
base_mask: (1,1,H,W) base alpha mask
components: Dict of component masks
Returns:
Refined (1,1,H,W) mask
"""
refined = base_mask.clone()
# Apply hair refinement
if 'hair' in components and self.hair_refinement:
hair_mask = components['hair']
weight = self.component_weights.get('hair', 1.0)
            # Boost alpha in hair regions: union with the weighted hair mask
refined = torch.where(
hair_mask > 0.1,
torch.maximum(refined, hair_mask * weight),
refined
)
# Apply edge refinement
if 'edge' in components:
edge_mask = components['edge']
weight = self.component_weights.get('edge', 1.0)
# Sharpen edges
refined = self._apply_edge_enhancement(refined, edge_mask, weight)
        # Apply detail mask if available: inside detail regions, blend the
        # base alpha toward the (weight-boosted) detail mask itself
        if 'detail' in components:
            detail_mask = components['detail']
            weight = self.component_weights.get('detail', 1.0)
            refined = refined * (1 - detail_mask) + detail_mask * weight
return refined.clamp_(0, 1)
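    # Example (illustrative): a components dict as consumed above. Each value
    # is a (1,1,H,W) soft mask; 'hair' and 'edge' are unioned in via max,
    # 'detail' is blended in directly.
    #
    #   components = {"hair": hair_mask, "edge": edge_mask, "detail": detail_mask}
    #   refined = wrapper.step(frame, mask, components=components)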
def _refine_edges(self,
mask: torch.Tensor,
image: torch.Tensor,
reference_mask: torch.Tensor) -> torch.Tensor:
"""
Apply edge refinement using image gradients.
Args:
mask: (1,1,H,W) mask to refine
image: (1,3,H,W) source image
            reference_mask: (1,1,H,W) reference mask (accepted for API
                symmetry; not used by the current implementation)
Returns:
Edge-refined mask
"""
# Compute image gradients for edge detection
gray = image.mean(dim=1, keepdim=True)
# Sobel filters for edge detection
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
dtype=torch.float32, device=self.device)
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
dtype=torch.float32, device=self.device)
sobel_x = sobel_x.view(1, 1, 3, 3)
sobel_y = sobel_y.view(1, 1, 3, 3)
# Apply Sobel filters
edge_x = F.conv2d(gray, sobel_x, padding=1)
edge_y = F.conv2d(gray, sobel_y, padding=1)
edges = torch.sqrt(edge_x**2 + edge_y**2)
# Normalize edges
edges = edges / (edges.max() + 1e-7)
        # Build a smoothed copy of the mask
        kernel_size = 3
        smoothed = F.avg_pool2d(mask, kernel_size, stride=1, padding=1)
        # Blend by edge strength: flat regions keep the original mask,
        # strong image edges mix in the smoothed version, feathering the
        # matte boundary
        alpha = 1 - edges * 0.5
        refined = mask * alpha + smoothed * (1 - alpha)
return refined.clamp_(0, 1)
def _apply_edge_enhancement(self,
mask: torch.Tensor,
edge_mask: torch.Tensor,
weight: float) -> torch.Tensor:
"""Apply edge enhancement using edge mask."""
        # Blur the edge mask with a 3x3 box filter for smoother boundaries
        kernel = torch.ones(1, 1, 3, 3, device=self.device) / 9
        blurred_edges = F.conv2d(edge_mask, kernel, padding=1)
        # Where the blurred edge response is active, boost the mask via max
        enhanced = torch.where(
            blurred_edges > 0.1,
            torch.maximum(mask, blurred_edges * weight),
            mask
        )
return enhanced
def process_batch(self,
images: torch.Tensor,
masks: Optional[torch.Tensor] = None,
components_batch: Optional[Dict[str, torch.Tensor]] = None,
**kwargs) -> torch.Tensor:
"""
Process a batch of frames.
Args:
images: (B,3,H,W) batch of images
masks: Optional (B,1,H,W) batch of masks
components_batch: Optional dict of component batches
**kwargs: Additional arguments
Returns:
(B,1,H,W) batch of refined masks
"""
batch_size = images.shape[0]
results = []
for i in range(batch_size):
image = images[i:i+1]
mask = masks[i:i+1] if masks is not None else None
# Extract components for this frame
components = None
if components_batch:
components = {
k: v[i:i+1] for k, v in components_batch.items()
}
# Process frame
result = self.step(
image,
mask,
components=components,
first_frame_pred=(i == 0),
**kwargs
)
results.append(result)
return torch.cat(results, dim=0)
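    # Example (illustrative): frames go through step() one by one, so only
    # frame 0 gets first_frame_pred=True and temporal order is preserved.
    #
    #   clip = torch.rand(8, 3, 480, 854, device=wrapper.device)   # 8 frames
    #   init = torch.zeros(8, 1, 480, 854, device=wrapper.device)  # per-frame masks
    #   alphas = wrapper.process_batch(clip, init)                 # (8,1,480,854)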
def output_prob_to_mask(self, prob: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
"""Convert probability map to binary mask."""
if isinstance(prob, torch.Tensor):
return (prob > self.threshold).float()
t = torch.as_tensor(prob, device=self.device)
return (t > self.threshold).float()
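    # Example (illustrative): binarizing a probability map with the
    # configured threshold (default 0.5).
    #
    #   hard = wrapper.output_prob_to_mask(alpha)  # values are exactly 0.0 or 1.0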
def apply_morphology(self,
mask: torch.Tensor,
operation: str = 'close',
kernel_size: int = 5) -> torch.Tensor:
"""
Apply morphological operations to clean up mask.
Args:
mask: Binary mask tensor
operation: 'close', 'open', 'dilate', or 'erode'
kernel_size: Size of morphological kernel
Returns:
Processed mask
"""
        kernel = torch.ones(1, 1, kernel_size, kernel_size, device=self.device)

        def dilate(m: torch.Tensor) -> torch.Tensor:
            # A pixel is set if any pixel under the kernel is set
            return (F.conv2d(m, kernel, padding=kernel_size // 2) > 0).float()

        def erode(m: torch.Tensor) -> torch.Tensor:
            # A pixel survives only if every pixel under the kernel is set
            return (F.conv2d(m, kernel, padding=kernel_size // 2) >= kernel_size ** 2).float()

        if operation == 'dilate':
            mask = dilate(mask)
        elif operation == 'erode':
            mask = erode(mask)
        elif operation == 'close':
            # Dilation followed by erosion: fills small holes
            mask = erode(dilate(mask))
        elif operation == 'open':
            # Erosion followed by dilation: removes small specks
            mask = dilate(erode(mask))
        return mask
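    # Example (illustrative): cleaning a thresholded mask. 'close' fills
    # pinholes, 'open' removes isolated specks; both expect a {0,1} float mask.
    #
    #   binary = wrapper.output_prob_to_mask(alpha)
    #   cleaned = wrapper.apply_morphology(binary, operation="close", kernel_size=5)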
def get_alpha_matte(self,
image: torch.Tensor,
mask: torch.Tensor,
trimap: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Get alpha matte with optional trimap refinement.
Args:
image: (1,3,H,W) RGB image
mask: (1,1,H,W) initial mask
trimap: Optional (1,1,H,W) trimap (0=bg, 0.5=unknown, 1=fg)
Returns:
(1,1,H,W) refined alpha matte
"""
# Process through MatAnyone
alpha = self.step(image, mask)
        # Apply trimap constraints if provided: exact 0/1 pins background and
        # foreground; 0.5 (unknown) regions keep the predicted alpha
        if trimap is not None:
            alpha = torch.where(trimap == 0, torch.zeros_like(alpha), alpha)
            alpha = torch.where(trimap == 1, torch.ones_like(alpha), alpha)
return alpha
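    # Example (illustrative): pinning known regions with a trimap. The row
    # slice is arbitrary and only demonstrates marking known background.
    #
    #   trimap = torch.full_like(mask, 0.5)       # everything unknown
    #   trimap[..., :50, :] = 0.0                 # top rows: known background
    #   alpha = wrapper.get_alpha_matte(frame, mask, trimap=trimap)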
def composite(self,
foreground: torch.Tensor,
background: torch.Tensor,
alpha: torch.Tensor) -> torch.Tensor:
"""
Composite foreground over background using alpha.
Args:
foreground: (1,3,H,W) foreground image
background: (1,3,H,W) background image
alpha: (1,1,H,W) alpha matte
Returns:
(1,3,H,W) composited image
"""
return foreground * alpha + background * (1 - alpha)
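
# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module): exercises
# the wrapper end to end without real MatAnyone weights. `_StubCore` is a
# hypothetical stand-in whose step() just blurs the input mask; it matches
# the keyword signature the wrapper tries first. A real run would pass an
# actual InferenceCore instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubModel(torch.nn.Module):
        def forward(self, x):
            return x

    class _StubCore:
        def __init__(self):
            self.model = _StubModel()

        def step(self, image_tensor, mask_tensor=None, first_frame_pred=False,
                 objects=None, **kwargs):
            # Return a lightly blurred copy of the mask as a fake probability map
            if mask_tensor is None:
                mask_tensor = torch.zeros_like(image_tensor[:, :1])
            return F.avg_pool2d(mask_tensor, 3, stride=1, padding=1)

    wrapper = MatAnyOneWrapper(_StubCore(), device="cpu")
    frame = torch.rand(1, 3, 64, 64)
    mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
    alpha = wrapper.step(frame, mask, first_frame_pred=True)
    comp = wrapper.composite(frame, torch.zeros_like(frame), alpha)
    print("alpha:", tuple(alpha.shape), "composite:", tuple(comp.shape))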