""" PyTorch Inference implementation for LSNet models """ import torch from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform from timm.models import create_model # Import to register lsnet models try: import lsnet.lsnet_artist # noqa: F401 except ImportError as e: print(f"Error importing lsnet: {e}") raise class PyTorchInference: def __init__(self, checkpoint_path, model_arch="lsnet_xl_artist", num_classes=None, feature_dim=None, device="cpu"): """ Initialize PyTorch inference Args: checkpoint_path: Path to PyTorch checkpoint file model_arch: Model architecture name num_classes: Number of output classes (inferred from checkpoint if None) feature_dim: Feature dimension (inferred from checkpoint if None) device: Device to run inference on """ self.checkpoint_path = checkpoint_path self.model_arch = model_arch self.device = device # Hardcoded input size mapping - based on actual model definitions self.input_size = self._get_input_size(model_arch) print(f"Using input size: {self.input_size} for model {model_arch}") # Load checkpoint state_dict = self.load_checkpoint_state(checkpoint_path) state_dict = self.normalize_state_dict_keys(state_dict) # Resolve model parameters if num_classes is None: num_classes = self.resolve_num_classes(state_dict) if feature_dim is None: feature_dim = self.resolve_feature_dim(state_dict) # Create model - don't pass img_size, let the model use its default self.model = create_model( model_arch, pretrained=False, num_classes=num_classes, feature_dim=feature_dim, ) # Load weights self.model.load_state_dict(state_dict, strict=False) self.model.to(device) self.model.eval() # Get transform - override with our correct input size # We manually set the input_size instead of relying on the model's config config = resolve_data_config({'input_size': (3, self.input_size, self.input_size)}, model=self.model) self.transform = create_transform(**config) print(f"Created transform with input size: {self.input_size}") def _get_input_size(self, model_arch): """Get input size based on model architecture - hardcoded to match actual model definitions""" if model_arch == 'lsnet_xl_artist_448': return 448 else: # All other artist models use 224 return 224 @staticmethod def load_checkpoint_state(checkpoint_path: str): """Load checkpoint and return model weights""" checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) if isinstance(checkpoint, dict): if "model" in checkpoint: return checkpoint["model"] if "model_ema" in checkpoint: return checkpoint["model_ema"] return checkpoint @staticmethod def normalize_state_dict_keys(state_dict): """Remove distributed training prefixes""" normalized = {} for key, value in state_dict.items(): if key.startswith("module."): new_key = key[len("module.") :] else: new_key = key normalized[new_key] = value return normalized @staticmethod def resolve_num_classes(state_dict) -> int: """Infer number of classes from checkpoint""" for key, value in state_dict.items(): if key.endswith("head.l.weight") or key.endswith("head.weight"): return value.shape[0] # Default return 1000 @staticmethod def resolve_feature_dim(state_dict) -> int: """Infer feature dimension from checkpoint""" for key, value in state_dict.items(): if key.endswith("head.bn.weight"): return value.shape[0] for key, value in state_dict.items(): if "head" in key and "weight" in key and len(value.shape) >= 2: return value.shape[1] if len(value.shape) > 1 else value.shape[0] # Default return 768 def preprocess(self, image): """ 
        Preprocess a PIL image for PyTorch inference.

        Args:
            image: PIL Image

        Returns:
            torch tensor ready for inference (with a batch dimension)
        """
        image = image.convert("RGB")
        tensor = self.transform(image)
        print(f"Preprocessed image to tensor shape: {tensor.shape}")
        return tensor.unsqueeze(0)

    def predict(self, image, top_k=5, threshold=0.0):
        """
        Run inference on an image.

        Args:
            image: PIL Image
            top_k: Number of top predictions to return (currently unused)
            threshold: Minimum confidence threshold (currently unused)

        Returns:
            logits: Raw model output as a numpy array
        """
        input_tensor = self.preprocess(image)
        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            print(f"Running inference on tensor shape: {input_tensor.shape}")
            # Use return_features=False to get classification logits
            logits = self.model(input_tensor, return_features=False)
        return logits.cpu().numpy()[0]
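
# Example usage: a minimal sketch of how this class might be driven. The
# checkpoint and image paths below are placeholders, not files that ship
# with this module.
if __name__ == "__main__":
    from PIL import Image

    infer = PyTorchInference(
        "checkpoints/lsnet_xl_artist.pth",  # placeholder checkpoint path
        model_arch="lsnet_xl_artist",
        device="cpu",
    )
    logits = infer.predict(Image.open("example.jpg"))  # placeholder image

    # predict() returns raw logits, so top-k selection happens client-side:
    # convert to probabilities and take the five highest-scoring classes.
    probs = torch.softmax(torch.from_numpy(logits), dim=0)
    top = torch.topk(probs, k=5)
    for idx, p in zip(top.indices.tolist(), top.values.tolist()):
        print(f"class {idx}: {p:.4f}")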