|
|
""" |
|
|
Fixed MatAnyone Model Interface |
|
|
Simplified and reliable model loading |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Union, Optional |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class SimpleMatteModel(nn.Module):
    """
    Simplified encoder/decoder matting model that ensures proper tensor handling.

    The encoder halves the spatial resolution three times (overall stride 8)
    and the decoder doubles it three times, finishing with a sigmoid so every
    output value lies in [0, 1].
    """

    # Overall down-sampling factor of the encoder (three stride-2 convolutions).
    _STRIDE = 8

    def __init__(self, backbone_channels: int = 3):
        """
        Build the encoder and decoder stacks.

        Args:
            backbone_channels: Number of channels expected by the first
                convolution, i.e. the channel count of the tensor passed to
                ``forward``.
        """
        super().__init__()

        self.encoder = nn.Sequential(
            # Full resolution, 64 feature channels.
            nn.Conv2d(backbone_channels, 64, 7, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            # 1/2 resolution, 128 feature channels.
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),

            # 1/4 resolution, 256 feature channels.
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),

            # 1/8 resolution, 512 feature channels.
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.decoder = nn.Sequential(
            # 1/8 -> 1/4 resolution.
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),

            # 1/4 -> 1/2 resolution.
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),

            # 1/2 -> full resolution.
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            # Single-channel alpha prediction squashed into [0, 1].
            nn.Conv2d(64, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass ensuring tensor operations

        Inputs whose height or width is not a multiple of 8 are padded at the
        bottom/right edge (replicating the border pixels) before encoding and
        the prediction is cropped back afterwards, so the output always has
        the same spatial size as the input.

        Args:
            x: Input tensor (B, C, H, W)

        Returns:
            torch.Tensor: Alpha matte (B, 1, H, W)

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input must be torch.Tensor, got {type(x)}")

        # Without padding, the three stride-2 convs round odd sizes up and the
        # decoder then returns a larger map than the input (e.g. 10 -> 16).
        height = width = 0
        pad_h = pad_w = 0
        if x.dim() >= 3:
            height, width = x.shape[-2:]
            pad_h = -height % self._STRIDE
            pad_w = -width % self._STRIDE
        needs_pad = bool(pad_h or pad_w)

        if needs_pad:
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        features = self.encoder(x)
        alpha = self.decoder(features)

        if needs_pad:
            # Drop the rows/columns that correspond to the padding.
            alpha = alpha[..., :height, :width].contiguous()

        return alpha

    def forward_with_prob(self, image: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with probability guidance

        ``image`` and ``prob`` are concatenated along the channel dimension
        and passed to ``forward``; the model therefore has to be constructed
        with ``backbone_channels`` equal to the combined channel count
        (4 for an RGB image plus a single-channel mask).

        Args:
            image: Input image (B, 3, H, W)
            prob: Probability mask (B, 1, H, W)

        Returns:
            torch.Tensor: Alpha matte (B, 1, H, W)

        Raises:
            TypeError: If either input is not a ``torch.Tensor``.
        """
        if not isinstance(image, torch.Tensor) or not isinstance(prob, torch.Tensor):
            raise TypeError("Both inputs must be torch.Tensor")

        x = torch.cat([image, prob], dim=1)

        return self.forward(x)
|
|
|
|
|
|
|
|
def load_pretrained_weights(model: nn.Module, checkpoint_path: Union[str, Path]) -> nn.Module:
    """
    Load pretrained weights with error handling

    Loading is best-effort: a missing checkpoint, an unreadable file or any
    other failure is reported on stdout and the model is returned with
    whatever weights it already had. Only entries whose (prefix-stripped)
    name exists in the model and whose shape matches are copied.

    Args:
        model: Model to load weights into
        checkpoint_path: Path to checkpoint file. The file may hold a bare
            state dict or a dict wrapping it under ``'state_dict'`` or
            ``'model'``.

    Returns:
        nn.Module: Model with loaded weights (the same object as ``model``)
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        print("Using randomly initialized weights")
        return model

    # Prefix found on keys of checkpoints saved from a wrapped model.
    wrapper_prefix = 'module.'

    try:
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where supported).
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Unwrap the common "{'state_dict': ...}" / "{'model': ...}" layouts.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for container_key in ('state_dict', 'model'):
                if container_key in checkpoint:
                    state_dict = checkpoint[container_key]
                    break

        model_dict = model.state_dict()
        matched_dict = {}

        for key, value in state_dict.items():
            # Strip only a *leading* 'module.'; a blanket str.replace would
            # also mangle names such as 'submodule.weight' -> 'subweight'.
            if key.startswith(wrapper_prefix):
                clean_key = key[len(wrapper_prefix):]
            else:
                clean_key = key

            if clean_key not in model_dict:
                print(f"Key not found in model: {clean_key}")
            elif model_dict[clean_key].shape != value.shape:
                print(f"Shape mismatch for {clean_key}: model {model_dict[clean_key].shape} vs checkpoint {value.shape}")
            else:
                matched_dict[clean_key] = value

        # strict=False keeps the current values of every unmatched parameter.
        model.load_state_dict(matched_dict, strict=False)

        print(f"Loaded {len(matched_dict)} weights from {checkpoint_path}")

    except Exception as e:
        # Deliberate best-effort boundary: report and fall back.
        print(f"Error loading checkpoint: {e}")
        print("Using randomly initialized weights")

    return model
|
|
|
|
|
|
|
|
def get_matanyone_model(checkpoint_path: Union[str, Path],
                        device: Union[str, torch.device] = 'cpu',
                        backbone_channels: int = 3) -> nn.Module:
    """
    FIXED MODEL LOADING: Create and load MatAnyone model

    Builds a ``SimpleMatteModel``, loads whatever checkpoint weights match via
    ``load_pretrained_weights``, moves the model to ``device`` and switches it
    to evaluation mode. A short summary is printed to stdout.

    Args:
        checkpoint_path: Path to model checkpoint
        device: Device to load model on
        backbone_channels: Number of input channels (3 for RGB, 4 for RGB + prob)

    Returns:
        nn.Module: Loaded model
    """
    # Honour the caller's channel count; it was previously ignored in favour
    # of a hard-coded 3, which made the 4-channel (RGB + prob) configuration
    # impossible to build through this factory.
    model = SimpleMatteModel(backbone_channels=backbone_channels)

    model = load_pretrained_weights(model, checkpoint_path)

    device = torch.device(device)
    model = model.to(device)
    model.eval()

    print(f"MatAnyone model loaded on {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model
|
|
|
|
|
|
|
|
|
|
|
def build_model(*args, **kwargs):
    """Compatibility shim for the original MatAnyone interface.

    Forwards all positional and keyword arguments unchanged to
    ``get_matanyone_model`` and returns its result.
    """
    model = get_matanyone_model(*args, **kwargs)
    return model
|
|
|
|
|
|
|
|
class ModelWrapper:
    """
    Wrapper to match original MatAnyone model interface

    Holds an ``nn.Module`` and forwards calls, mode switches, device moves and
    state-dict access to it. ``device`` mirrors the device of the wrapped
    model's first parameter.
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: The module to wrap.
        """
        self.model = model
        self.device = self._infer_device(model)

    @staticmethod
    def _infer_device(model: nn.Module) -> torch.device:
        """Return the device of the first parameter, or CPU if there is none."""
        # A bare next() would raise StopIteration for parameter-free modules.
        first_param = next(model.parameters(), None)
        if first_param is None:
            return torch.device('cpu')
        return first_param.device

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped model with the given arguments."""
        return self.model(*args, **kwargs)

    def eval(self):
        """Switch the wrapped model to eval mode and return that model."""
        return self.model.eval()

    def train(self, mode=True):
        """Set the wrapped model's training mode and return that model."""
        return self.model.train(mode)

    def to(self, device):
        """
        Move the wrapped model to ``device``.

        The move happens in place on the wrapped module, so this wrapper's
        ``device`` attribute is refreshed and the wrapper itself is returned
        instead of leaving a stale ``device`` behind on the old object.

        Returns:
            ModelWrapper: This wrapper.
        """
        self.model = self.model.to(device)
        self.device = self._infer_device(self.model)
        return self

    def parameters(self):
        """Return the wrapped model's parameter iterator."""
        return self.model.parameters()

    def state_dict(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped model and return its result."""
        return self.model.load_state_dict(state_dict)