""" Fixed MatAnyone Model Interface Simplified and reliable model loading """ import torch import torch.nn as nn from typing import Union, Optional from pathlib import Path class SimpleMatteModel(nn.Module): """ Simplified matting model that ensures proper tensor handling """ def __init__(self, backbone_channels: int = 3): super().__init__() # Simple encoder-decoder architecture self.encoder = nn.Sequential( # Initial conv nn.Conv2d(backbone_channels, 64, 7, padding=3), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(inplace=True), # Downsampling blocks nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(inplace=True), # Bottleneck nn.Conv2d(256, 512, 3, stride=2, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, padding=1), nn.ReLU(inplace=True), ) self.decoder = nn.Sequential( # Upsampling blocks nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(inplace=True), nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(inplace=True), nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(inplace=True), # Final prediction nn.Conv2d(64, 1, 3, padding=1), nn.Sigmoid() ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass ensuring tensor operations Args: x: Input tensor (B, C, H, W) Returns: torch.Tensor: Alpha matte (B, 1, H, W) """ if not isinstance(x, torch.Tensor): raise TypeError(f"Input must be torch.Tensor, got {type(x)}") # Encode features = self.encoder(x) # Decode alpha = self.decoder(features) return alpha def forward_with_prob(self, image: torch.Tensor, prob: torch.Tensor) -> torch.Tensor: """ Forward pass with probability guidance Args: image: Input image (B, 3, H, W) prob: Probability mask (B, 1, H, W) Returns: torch.Tensor: Alpha matte (B, 1, H, W) """ if not isinstance(image, torch.Tensor) or not isinstance(prob, torch.Tensor): raise TypeError("Both inputs must be torch.Tensor") # Concatenate image and probability as input x = torch.cat([image, prob], dim=1) # (B, 4, H, W) # Forward pass return self.forward(x) def load_pretrained_weights(model: nn.Module, checkpoint_path: Union[str, Path]) -> nn.Module: """ Load pretrained weights with error handling Args: model: Model to load weights into checkpoint_path: Path to checkpoint file Returns: nn.Module: Model with loaded weights """ checkpoint_path = Path(checkpoint_path) if not checkpoint_path.exists(): print(f"Warning: Checkpoint not found at {checkpoint_path}") print("Using randomly initialized weights") return model try: # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location='cpu') # Extract state dict if isinstance(checkpoint, dict): if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif 'model' in checkpoint: state_dict = checkpoint['model'] else: state_dict = checkpoint else: state_dict = checkpoint # Load weights with flexible key matching model_dict = model.state_dict() matched_dict = {} for key, value in state_dict.items(): # Remove module prefix if present clean_key = key.replace('module.', '') if clean_key in model_dict: if model_dict[clean_key].shape == value.shape: matched_dict[clean_key] = value else: print(f"Shape mismatch for {clean_key}: model {model_dict[clean_key].shape} vs checkpoint {value.shape}") else: print(f"Key not found in model: {clean_key}") # Load matched weights model_dict.update(matched_dict) model.load_state_dict(model_dict) print(f"Loaded {len(matched_dict)} weights from {checkpoint_path}") except Exception as e: print(f"Error loading checkpoint: {e}") print("Using randomly initialized weights") return model def get_matanyone_model(checkpoint_path: Union[str, Path], device: Union[str, torch.device] = 'cpu', backbone_channels: int = 3) -> nn.Module: """ FIXED MODEL LOADING: Create and load MatAnyone model Args: checkpoint_path: Path to model checkpoint device: Device to load model on backbone_channels: Number of input channels (3 for RGB, 4 for RGB + prob) Returns: nn.Module: Loaded model """ # Determine input channels based on usage # If we're using probability guidance, we need 4 channels (RGB + prob) # Otherwise, 3 channels (RGB only) input_channels = 3 # Support both RGB and RGB+prob inputs # Create model model = SimpleMatteModel(backbone_channels=input_channels) # Load pretrained weights if available model = load_pretrained_weights(model, checkpoint_path) # Move to device device = torch.device(device) model = model.to(device) model.eval() print(f"MatAnyone model loaded on {device}") print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") return model # Fallback for compatibility with original MatAnyone interface def build_model(*args, **kwargs): """Compatibility function for original MatAnyone interface""" return get_matanyone_model(*args, **kwargs) class ModelWrapper: """ Wrapper to match original MatAnyone model interface """ def __init__(self, model: nn.Module): self.model = model self.device = next(model.parameters()).device def __call__(self, *args, **kwargs): return self.model(*args, **kwargs) def eval(self): return self.model.eval() def train(self, mode=True): return self.model.train(mode) def to(self, device): return ModelWrapper(self.model.to(device)) def parameters(self): return self.model.parameters() def state_dict(self): return self.model.state_dict() def load_state_dict(self, state_dict): return self.model.load_state_dict(state_dict)