|
|
|
|
|
import logging
from typing import Optional, Dict, Any, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
|
|
|
|
|
class MatAnyOneWrapper:
    """Adapter around an inference core that adds mask post-processing.

    The wrapper forwards frames to ``core.step`` and then optionally refines
    the returned probability map with caller-supplied component masks
    (hair / edge / detail) and an image-gradient-driven edge smoothing pass.
    It also offers small utilities for thresholding, morphology, trimap
    constraints and alpha compositing.
    """

    _logger = logging.getLogger(__name__)

    # Component values at or below this level are treated as "not present".
    _COMPONENT_THRESHOLD = 0.1
    # Maximum share of the smoothed mask blended in at the strongest edge.
    _EDGE_BLEND = 0.5

    # Each morphological operation expressed as a sequence of primitive passes.
    _MORPH_STAGES = {
        'dilate': ('dilate',),
        'erode': ('erode',),
        'close': ('dilate', 'erode'),
        'open': ('erode', 'dilate'),
    }

    def __init__(self, core, device=None, config=None):
        """
        Initialize MatAnyone wrapper with enhanced configuration.

        Args:
            core: Inference core exposing ``step(...)`` and, optionally, a
                ``model`` attribute.
            device: torch device (auto-detect CUDA/CPU if None)
            config: Optional configuration dict. Recognised keys:
                'threshold' (default 0.5), 'edge_refinement' (default True),
                'hair_refinement' (default True) and 'component_weights'
                (dict with 'base', 'hair', 'edge', 'detail' weights).
        """
        self.core = core
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config or {}

        self.threshold = self.config.get('threshold', 0.5)
        self.edge_refinement = self.config.get('edge_refinement', True)
        self.hair_refinement = self.config.get('hair_refinement', True)

        self.component_weights = self.config.get('component_weights', {
            'base': 1.0,
            'hair': 1.2,
            'edge': 1.5,
            'detail': 1.1,
        })

        self._prepare_model()

    def _prepare_model(self) -> None:
        """Best-effort: move ``core.model`` to the target device and set eval mode.

        Failures are logged instead of raised so that cores without a
        conventional ``model`` attribute can still be wrapped.
        """
        model = getattr(self.core, 'model', None)
        if model is None:
            self._logger.debug(
                "Inference core exposes no 'model'; skipping device/eval setup")
            return

        try:
            model.to(self.device)
        except Exception:
            self._logger.warning(
                "Could not move model to device %s", self.device, exc_info=True)

        try:
            model.eval()
        except Exception:
            self._logger.warning(
                "Could not switch model to eval mode", exc_info=True)

    @torch.inference_mode()
    def step(self,
             image_tensor: torch.Tensor,
             mask_tensor: Optional[torch.Tensor] = None,
             objects: Optional[Dict] = None,
             first_frame_pred: bool = False,
             components: Optional[Dict[str, torch.Tensor]] = None,
             **kwargs) -> torch.Tensor:
        """
        Process a single frame with optional component masks.

        Args:
            image_tensor: (1,3,H,W) float32 [0..1]; moved to self.device
            mask_tensor: (1,1,H,W) float32 [0..1]; moved to self.device
            objects: Optional object tracking info, forwarded to the core
            first_frame_pred: Whether this is the first frame
            components: Optional dict with keys like 'hair', 'edge', 'detail'
                Each value is a (1,1,H,W) tensor
            **kwargs: Additional arguments for the core's ``step``

        Returns:
            (1,1,H,W) float32 probabilities in [0..1]. The result is always a
            new tensor; the object returned by the core is never modified.
        """
        image_tensor = image_tensor.to(self.device, non_blocking=True)
        if mask_tensor is not None:
            mask_tensor = mask_tensor.to(self.device, non_blocking=True)

        if components:
            components = {
                name: layer.to(self.device, non_blocking=True)
                for name, layer in components.items()
            }

        # Two calling conventions are attempted; a TypeError from the first
        # (typically "unexpected keyword argument") triggers the second, which
        # does not forward `objects` / `first_frame_pred`.
        # NOTE(review): confirm these keyword names against the core actually
        # in use -- they cannot be verified from this file.
        try:
            out = self.core.step(
                image_tensor=image_tensor,
                mask_tensor=mask_tensor,
                first_frame_pred=first_frame_pred,
                objects=objects,
                **kwargs
            )
        except TypeError:
            out = self.core.step(
                frame=image_tensor,
                mask=mask_tensor,
                **kwargs
            )

        out = self._normalize_output(out)

        if components:
            out = self._refine_with_components(out, components)

        if self.edge_refinement and mask_tensor is not None:
            out = self._refine_edges(out, image_tensor, mask_tensor)

        # Non-in-place clamp: when no refinement ran, `out` may be the very
        # tensor the core returned, and clamping it in place would mutate
        # whatever the core still holds a reference to.
        return out.clamp(0, 1)

    def _normalize_output(self, out: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Normalize output to a 4-D float32 tensor on self.device.

        2-D input (H,W) becomes (1,1,H,W); 3-D input gets a channel axis
        inserted at dim 1; other ranks are passed through unchanged.
        """
        # as_tensor is a no-op for tensors already matching dtype/device.
        out = torch.as_tensor(out, dtype=torch.float32, device=self.device)
        if out.ndim == 2:
            return out.unsqueeze(0).unsqueeze(0)
        if out.ndim == 3:
            return out.unsqueeze(1)
        return out

    def _refine_with_components(self,
                                base_mask: torch.Tensor,
                                components: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Refine mask using component layers (hair, edge, etc).

        Args:
            base_mask: (1,1,H,W) base alpha mask (not modified)
            components: Dict of component masks

        Returns:
            Refined (1,1,H,W) mask clamped to [0..1]
        """
        refined = base_mask.clone()
        weights = self.component_weights

        if self.hair_refinement and 'hair' in components:
            hair = components['hair']
            # Where hair is detected, raise the mask to at least the weighted
            # hair value; never lower it.
            refined = torch.where(
                hair > self._COMPONENT_THRESHOLD,
                torch.maximum(refined, hair * weights.get('hair', 1.0)),
                refined,
            )

        if 'edge' in components:
            refined = self._apply_edge_enhancement(
                refined, components['edge'], weights.get('edge', 1.0))

        if 'detail' in components:
            detail = components['detail']
            # Linear blend towards the weighted detail value, with the detail
            # mask itself acting as the blend factor.
            refined = refined * (1 - detail) + detail * weights.get('detail', 1.0)

        return refined.clamp_(0, 1)

    def _refine_edges(self,
                      mask: torch.Tensor,
                      image: torch.Tensor,
                      reference_mask: torch.Tensor) -> torch.Tensor:
        """
        Apply edge refinement using image gradients.

        The mask is blended with a 3x3 box-smoothed copy of itself; the share
        of the smoothed copy grows with the normalized Sobel gradient
        magnitude of the image, up to 50% at the strongest edge.

        Args:
            mask: (1,1,H,W) mask to refine
            image: (1,3,H,W) source image
            reference_mask: (1,1,H,W) reference mask. Currently unused; kept
                for interface compatibility.

        Returns:
            Edge-refined mask clamped to [0..1]
        """
        gray = image.mean(dim=1, keepdim=True)

        sobel_x = torch.tensor(
            [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]],
            dtype=gray.dtype, device=gray.device)
        # The vertical Sobel kernel is the transpose of the horizontal one;
        # stacking both gives a single 2-output-channel convolution.
        kernels = torch.stack((sobel_x, sobel_x.t())).unsqueeze(1)

        # Replicate padding: zero padding would fabricate a strong gradient
        # along the whole frame border and also distort the normalization.
        padded = F.pad(gray, (1, 1, 1, 1), mode='replicate')
        gradients = F.conv2d(padded, kernels)
        edges = gradients.pow(2).sum(dim=1, keepdim=True).sqrt()
        edges = edges / (edges.max() + 1e-7)

        # count_include_pad=False keeps border pixels from being darkened by
        # the implicit zeros outside the frame.
        smoothed = F.avg_pool2d(
            mask, 3, stride=1, padding=1, count_include_pad=False)

        keep = 1 - edges * self._EDGE_BLEND
        blended = mask * keep + smoothed * (1 - keep)
        return blended.clamp_(0, 1)

    def _apply_edge_enhancement(self,
                                mask: torch.Tensor,
                                edge_mask: torch.Tensor,
                                weight: float) -> torch.Tensor:
        """Raise the mask where the box-blurred edge mask is significant.

        The edge mask is spread with a 3x3 mean filter; wherever the spread
        value exceeds the component threshold the mask is lifted to at least
        ``spread * weight``.
        """
        # Build the kernel next to the data so device/dtype always match.
        box = torch.ones(
            1, 1, 3, 3, dtype=edge_mask.dtype, device=edge_mask.device) / 9
        spread = F.conv2d(edge_mask, box, padding=1)

        return torch.where(
            spread > self._COMPONENT_THRESHOLD,
            torch.maximum(mask, spread * weight),
            mask,
        )

    def process_batch(self,
                      images: torch.Tensor,
                      masks: Optional[torch.Tensor] = None,
                      components_batch: Optional[Dict[str, torch.Tensor]] = None,
                      **kwargs) -> torch.Tensor:
        """
        Process a batch of frames sequentially via ``step``.

        Args:
            images: (B,3,H,W) batch of images
            masks: Optional (B,1,H,W) batch of masks
            components_batch: Optional dict of (B,1,H,W) component batches
            **kwargs: Additional arguments forwarded to ``step``. Must not
                contain 'first_frame_pred' or 'components', which are set
                per frame here.

        Returns:
            (B,1,H,W) batch of refined masks; an empty (0,1,H,W) tensor when
            the batch is empty.
        """
        batch_size = images.shape[0]
        if batch_size == 0:
            # torch.cat refuses an empty list, so short-circuit.
            height, width = images.shape[-2:]
            return torch.zeros(
                (0, 1, height, width), dtype=torch.float32, device=self.device)

        results = []
        for index in range(batch_size):
            frame = slice(index, index + 1)  # keeps the leading batch axis

            frame_components = None
            if components_batch:
                frame_components = {
                    name: layer[frame]
                    for name, layer in components_batch.items()
                }

            results.append(self.step(
                images[frame],
                masks[frame] if masks is not None else None,
                components=frame_components,
                first_frame_pred=(index == 0),
                **kwargs
            ))

        return torch.cat(results, dim=0)

    def output_prob_to_mask(self, prob: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Convert probability map to a binary float mask using self.threshold.

        Tensors stay on their current device; other array-likes are first
        converted to a tensor on self.device.
        """
        if not isinstance(prob, torch.Tensor):
            prob = torch.as_tensor(prob, device=self.device)
        return (prob > self.threshold).float()

    def apply_morphology(self,
                         mask: torch.Tensor,
                         operation: str = 'close',
                         kernel_size: int = 5) -> torch.Tensor:
        """
        Apply morphological operations to clean up mask.

        Pixels outside the frame never influence the result: they count as
        background for dilation and as foreground for erosion, so foreground
        touching the frame border is not eroded from the outside.

        Args:
            mask: Binary mask tensor, (N,1,H,W)
            operation: 'close', 'open', 'dilate', or 'erode'
            kernel_size: Side length of the square structuring element (>= 1).
                Even sizes are supported and keep the output the same size.

        Returns:
            Processed float32 mask with values in {0, 1} and the same spatial
            size as the input

        Raises:
            ValueError: If ``operation`` is unknown or ``kernel_size`` < 1.
        """
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        stages = self._MORPH_STAGES.get(operation)
        if stages is None:
            raise ValueError(
                f"Unknown morphological operation {operation!r}; "
                f"expected one of {sorted(self._MORPH_STAGES)}")

        result = mask.to(torch.float32)
        for stage in stages:
            result = self._morph_pass(result, kernel_size, erode=(stage == 'erode'))
        return result

    @staticmethod
    def _morph_pass(mask: torch.Tensor, kernel_size: int, erode: bool) -> torch.Tensor:
        """Run one binary dilation or erosion with a square structuring element."""
        # Explicit (possibly asymmetric) padding keeps the output size equal
        # to the input size for both odd and even kernels.
        before = (kernel_size - 1) // 2
        after = kernel_size // 2
        pad = (before, after, before, after)

        if erode:
            # Pad with foreground so the frame border itself does not erode;
            # a min-pool is expressed as a negated max-pool.
            padded = F.pad(mask, pad, mode='constant', value=1.0)
            lowest = -F.max_pool2d(-padded, kernel_size, stride=1)
            return (lowest >= 1).float()

        padded = F.pad(mask, pad, mode='constant', value=0.0)
        highest = F.max_pool2d(padded, kernel_size, stride=1)
        return (highest > 0).float()

    def get_alpha_matte(self,
                        image: torch.Tensor,
                        mask: torch.Tensor,
                        trimap: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Get alpha matte with optional trimap refinement.

        Args:
            image: (1,3,H,W) RGB image
            mask: (1,1,H,W) initial mask
            trimap: Optional (1,1,H,W) trimap (0=bg, 0.5=unknown, 1=fg)

        Returns:
            (1,1,H,W) refined alpha matte; where a trimap is given, pixels
            equal to 0 are forced to 0 and pixels equal to 1 are forced to 1.
        """
        alpha = self.step(image, mask)

        if trimap is not None:
            # The alpha lives on the wrapper's device; align the trimap.
            trimap = trimap.to(alpha.device)
            alpha = torch.where(trimap == 0, torch.zeros_like(alpha), alpha)
            alpha = torch.where(trimap == 1, torch.ones_like(alpha), alpha)

        return alpha

    def composite(self,
                  foreground: torch.Tensor,
                  background: torch.Tensor,
                  alpha: torch.Tensor) -> torch.Tensor:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: (1,3,H,W) foreground image
            background: (1,3,H,W) background image
            alpha: (1,1,H,W) alpha matte, broadcast over the channel axis

        Returns:
            (1,3,H,W) composited image
        """
        return foreground * alpha + background * (1 - alpha)