"""
PyTorch Inference implementation for LSNet models
"""
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.models import create_model
# Import to register lsnet models
try:
import lsnet.lsnet_artist # noqa: F401
except ImportError as e:
print(f"Error importing lsnet: {e}")
raise


class PyTorchInference:
    def __init__(self, checkpoint_path, model_arch="lsnet_xl_artist", num_classes=None, feature_dim=None, device="cpu"):
        """
        Initialize PyTorch inference.

        Args:
            checkpoint_path: Path to the PyTorch checkpoint file
            model_arch: Model architecture name registered with timm
            num_classes: Number of output classes (inferred from the checkpoint if None)
            feature_dim: Feature dimension (inferred from the checkpoint if None)
            device: Device to run inference on ("cpu" or "cuda")
        """
        self.checkpoint_path = checkpoint_path
        self.model_arch = model_arch
        self.device = device

        # Hardcoded input-size mapping, matching the actual model definitions.
        self.input_size = self._get_input_size(model_arch)
        print(f"Using input size: {self.input_size} for model {model_arch}")

        # Load the checkpoint and normalize its keys.
        state_dict = self.load_checkpoint_state(checkpoint_path)
        state_dict = self.normalize_state_dict_keys(state_dict)

        # Resolve model parameters from the checkpoint when not given explicitly.
        if num_classes is None:
            num_classes = self.resolve_num_classes(state_dict)
        if feature_dim is None:
            feature_dim = self.resolve_feature_dim(state_dict)

        # Create the model. img_size is deliberately not passed; the model uses its default.
        self.model = create_model(
            model_arch,
            pretrained=False,
            num_classes=num_classes,
            feature_dim=feature_dim,
        )

        # Load weights. strict=False tolerates harmless key mismatches
        # (e.g. buffers absent from older checkpoints).
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(device)
        self.model.eval()

        # Build the preprocessing transform, overriding the data config with our
        # known-correct input size instead of relying on the model's own config.
        config = resolve_data_config({'input_size': (3, self.input_size, self.input_size)}, model=self.model)
        self.transform = create_transform(**config)
        print(f"Created transform with input size: {self.input_size}")

    def _get_input_size(self, model_arch):
        """Return the input resolution for the given architecture, hardcoded to match the model definitions."""
        if model_arch == 'lsnet_xl_artist_448':
            return 448
        # All other artist models use 224.
        return 224

    @staticmethod
    def load_checkpoint_state(checkpoint_path: str):
        """Load a checkpoint and return the model weights (state dict)."""
        # weights_only=False is needed because training checkpoints may contain
        # non-tensor objects; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if isinstance(checkpoint, dict):
            if "model" in checkpoint:
                return checkpoint["model"]
            if "model_ema" in checkpoint:
                return checkpoint["model_ema"]
        return checkpoint
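
    # Checkpoint layouts the loader above handles (illustrative):
    #   {"model": {...}}      - a standard training checkpoint
    #   {"model_ema": {...}}  - EMA (exponential moving average) weights
    #   {...}                 - a bare state dict saved directly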

    @staticmethod
    def normalize_state_dict_keys(state_dict):
        """Strip the "module." prefix left behind by distributed (DDP) training."""
        normalized = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_key = key[len("module."):]
            else:
                new_key = key
            normalized[new_key] = value
        return normalized
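
    # Example of the normalization above (illustrative key name):
    #   {"module.backbone.stem.weight": w} -> {"backbone.stem.weight": w}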

    @staticmethod
    def resolve_num_classes(state_dict) -> int:
        """Infer the number of output classes from the classifier head's weight shape."""
        for key, value in state_dict.items():
            if key.endswith("head.l.weight") or key.endswith("head.weight"):
                return value.shape[0]
        # Fall back to the ImageNet default.
        return 1000

    @staticmethod
    def resolve_feature_dim(state_dict) -> int:
        """Infer the feature dimension from the head's batch-norm or weight shapes."""
        for key, value in state_dict.items():
            if key.endswith("head.bn.weight"):
                return value.shape[0]
        # Fall back to any 2-D head weight: its second dimension is the feature size.
        for key, value in state_dict.items():
            if "head" in key and "weight" in key and len(value.shape) >= 2:
                return value.shape[1]
        # Default
        return 768
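
    # Illustrative sanity check for the two helpers above, using a toy state
    # dict with hypothetical shapes (a 512-class head over 768-dim features):
    #   sd = {"head.bn.weight": torch.zeros(768),
    #         "head.l.weight": torch.zeros(512, 768)}
    #   PyTorchInference.resolve_num_classes(sd)  # -> 512
    #   PyTorchInference.resolve_feature_dim(sd)  # -> 768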

    def preprocess(self, image):
        """
        Preprocess a PIL image for PyTorch inference.

        Args:
            image: PIL Image

        Returns:
            A 4-D float tensor of shape (1, 3, H, W), ready for inference
        """
        image = image.convert("RGB")
        tensor = self.transform(image)
        print(f"Preprocessed image to tensor shape: {tensor.shape}")
        # Add the batch dimension expected by the model.
        return tensor.unsqueeze(0)

    def predict(self, image, top_k=5, threshold=0.0):
        """
        Run inference on an image.

        Args:
            image: PIL Image
            top_k: Number of top predictions to return (currently unused; raw logits are returned)
            threshold: Minimum confidence threshold (currently unused)

        Returns:
            Raw model logits for the image as a 1-D numpy array
        """
        input_tensor = self.preprocess(image)
        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            print(f"Running inference on tensor shape: {input_tensor.shape}")
            # return_features=False requests classification logits rather than embeddings.
            logits = self.model(input_tensor, return_features=False)
        return logits.cpu().numpy()[0]
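

# --- Usage sketch (illustrative; not part of the original module) ---
# Shows one way to turn the raw logits from predict() into top-k class
# probabilities. The checkpoint and image paths below are hypothetical.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    engine = PyTorchInference(
        checkpoint_path="lsnet_xl_artist.pth",  # hypothetical checkpoint path
        model_arch="lsnet_xl_artist",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    image = Image.open("example.jpg")  # hypothetical input image
    logits = engine.predict(image)

    # Numerically stable softmax over the logits, then report the five
    # highest-probability class indices.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    for idx in np.argsort(probs)[::-1][:5]:
        print(f"class {idx}: {probs[idx]:.4f}")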