# Hosting-page header (not part of the module): MogensR
# Commit: "Update matanyone_fixed/utils/get_default_model.py"
# Revision: 3e4b062 (verified)
"""
Fixed MatAnyone Model Interface
Simplified and reliable model loading
"""
import torch
import torch.nn as nn
from typing import Union, Optional
from pathlib import Path
class SimpleMatteModel(nn.Module):
    """
    Simplified encoder-decoder matting model.

    The encoder reduces the spatial resolution with three stride-2
    convolutions; the decoder restores it with three stride-2 transposed
    convolutions and ends in a sigmoid, so every output value lies in [0, 1].

    Args:
        backbone_channels: Number of channels accepted by the first
            convolution (3 for an image alone, 4 when ``forward_with_prob``
            concatenates a 3-channel image with a 1-channel mask).
    """

    def __init__(self, backbone_channels: int = 3):
        super().__init__()
        # Remembered so forward_with_prob can report a channel mismatch
        # clearly instead of failing inside the first convolution.
        self.in_channels = backbone_channels

        # Simple encoder-decoder architecture. The layer order is kept as-is
        # so that state-dict keys (encoder.0.weight, ...) stay stable.
        self.encoder = nn.Sequential(
            # Initial conv
            nn.Conv2d(backbone_channels, 64, 7, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            # Downsampling blocks
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            # Bottleneck
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            # Upsampling blocks
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            # Final prediction
            nn.Conv2d(64, 1, 3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict an alpha matte with the same spatial size as the input.

        Args:
            x: Input tensor (B, C, H, W)

        Returns:
            torch.Tensor: Alpha matte (B, 1, H, W)

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input must be torch.Tensor, got {type(x)}")

        alpha = self.decoder(self.encoder(x))

        # Each stride-2 conv produces ceil(n / 2) and each transposed conv
        # exactly doubles, so the raw output only matches the input size when
        # H and W are multiples of 8. Resize to honour the (B, 1, H, W)
        # contract for every other size.
        target_size = tuple(x.shape[-2:])
        if tuple(alpha.shape[-2:]) != target_size:
            # Bilinear interpolation needs a batch dimension.
            unbatched = alpha.dim() == 3
            if unbatched:
                alpha = alpha.unsqueeze(0)
            alpha = nn.functional.interpolate(
                alpha, size=target_size, mode="bilinear", align_corners=False
            )
            if unbatched:
                alpha = alpha.squeeze(0)
        return alpha

    def forward_with_prob(self, image: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with probability guidance.

        The image and the probability mask are concatenated along the channel
        dimension, so the model must have been built with
        ``backbone_channels`` equal to the combined channel count.

        Args:
            image: Input image (B, 3, H, W)
            prob: Probability mask (B, 1, H, W)

        Returns:
            torch.Tensor: Alpha matte (B, 1, H, W)

        Raises:
            TypeError: If either input is not a ``torch.Tensor``.
            RuntimeError: If the concatenated channel count does not match
                the channel count the model was built with.
        """
        if not isinstance(image, torch.Tensor) or not isinstance(prob, torch.Tensor):
            raise TypeError("Both inputs must be torch.Tensor")

        # Concatenate image and probability as input
        x = torch.cat([image, prob], dim=1)  # (B, 4, H, W)

        # Fail early with an actionable message; RuntimeError is the type the
        # first convolution would have raised for the same mismatch.
        if x.shape[1] != self.in_channels:
            raise RuntimeError(
                f"forward_with_prob built a {x.shape[1]}-channel input, but the "
                f"model was created with backbone_channels={self.in_channels}; "
                f"create SimpleMatteModel(backbone_channels={x.shape[1]}) to use "
                "probability guidance"
            )
        return self.forward(x)
def load_pretrained_weights(model: nn.Module, checkpoint_path: Union[str, Path]) -> nn.Module:
    """
    Load pretrained weights on a best-effort basis.

    Only checkpoint entries whose (optionally ``module.``-stripped) key exists
    in the model and whose shape matches are loaded; everything else is
    reported and skipped. Any failure leaves the model untouched and is
    reported via ``print`` instead of being raised.

    Args:
        model: Model to load weights into
        checkpoint_path: Path to checkpoint file

    Returns:
        nn.Module: The same model instance, with matched weights loaded
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        print("Using randomly initialized weights")
        return model

    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code depending on the torch version's defaults. Only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Extract state dict: checkpoints may nest it under a well-known key.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for container_key in ('state_dict', 'model'):
                if container_key in checkpoint:
                    state_dict = checkpoint[container_key]
                    break

        # Load weights with flexible key matching
        model_dict = model.state_dict()
        matched_dict = {}
        prefix = 'module.'
        for key, value in state_dict.items():
            # Strip only a single *leading* 'module.' and only when the raw
            # key is not already a model key. A blanket str.replace would
            # also mangle keys such as 'submodule.weight'.
            clean_key = key
            if key not in model_dict and key.startswith(prefix):
                clean_key = key[len(prefix):]

            if clean_key not in model_dict:
                print(f"Key not found in model: {clean_key}")
                continue

            expected_shape = model_dict[clean_key].shape
            if expected_shape != value.shape:
                print(f"Shape mismatch for {clean_key}: model {expected_shape} vs checkpoint {value.shape}")
                continue

            matched_dict[clean_key] = value

        # Load matched weights; unmatched entries keep the model's own values.
        model_dict.update(matched_dict)
        model.load_state_dict(model_dict)
        print(f"Loaded {len(matched_dict)} weights from {checkpoint_path}")
    except Exception as e:
        # Deliberate best-effort: a broken checkpoint must not stop the caller
        # from getting a usable (randomly initialized) model.
        print(f"Error loading checkpoint: {e}")
        print("Using randomly initialized weights")
    return model
def get_matanyone_model(checkpoint_path: Union[str, Path],
                        device: Union[str, torch.device] = 'cpu',
                        backbone_channels: int = 3) -> nn.Module:
    """
    FIXED MODEL LOADING: Create and load MatAnyone model

    Builds a ``SimpleMatteModel``, loads whatever weights can be matched from
    the checkpoint, moves the model to ``device`` and switches it to eval mode.

    Args:
        checkpoint_path: Path to model checkpoint
        device: Device to load model on
        backbone_channels: Number of input channels (3 for RGB, 4 for RGB + prob)

    Returns:
        nn.Module: Loaded model
    """
    # Honour the caller's channel count. It was previously overwritten with a
    # hard-coded 3, which made it impossible to build a model usable with
    # probability guidance (RGB + prob = 4 channels).
    model = SimpleMatteModel(backbone_channels=backbone_channels)

    # Load pretrained weights if available
    model = load_pretrained_weights(model, checkpoint_path)

    # Move to device
    device = torch.device(device)
    model = model.to(device)
    model.eval()

    print(f"MatAnyone model loaded on {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model
# Fallback for compatibility with original MatAnyone interface
def build_model(*args, **kwargs):
    """
    Compatibility entry point for the original MatAnyone interface.

    Forwards every positional and keyword argument unchanged to
    ``get_matanyone_model`` and returns its result.
    """
    model = get_matanyone_model(*args, **kwargs)
    return model
class ModelWrapper:
    """
    Wrapper to match original MatAnyone model interface

    Delegates calls, mode switches, parameter access and state-dict handling
    to the wrapped module, and exposes the module's current device as
    ``self.device``.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.device = self._infer_device(model)

    @staticmethod
    def _infer_device(model: nn.Module) -> torch.device:
        """Return the device of the first parameter or buffer, else CPU."""
        # next(...) with a default avoids StopIteration on modules that have
        # no parameters (e.g. a bare activation layer).
        tensor = next(model.parameters(), None)
        if tensor is None:
            tensor = next(model.buffers(), None)
        return tensor.device if tensor is not None else torch.device('cpu')

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def eval(self):
        # Returns the underlying module (not the wrapper), as before.
        return self.model.eval()

    def train(self, mode=True):
        # Returns the underlying module (not the wrapper), as before.
        return self.model.train(mode)

    def to(self, device):
        """
        Move the wrapped module and refresh ``self.device``.

        ``nn.Module.to`` moves the module in place, so this wrapper is updated
        and returned; creating a second wrapper would leave this one with a
        stale ``device``.
        """
        self.model = self.model.to(device)
        # Re-read from the module rather than trusting the argument.
        self.device = self._infer_device(self.model)
        return self

    def parameters(self):
        return self.model.parameters()

    def state_dict(self):
        return self.model.state_dict()

    def load_state_dict(self, state_dict):
        return self.model.load_state_dict(state_dict)