"""
PyTorch Inference implementation for LSNet models
"""
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.models import create_model
# Import to register lsnet models
try:
    import lsnet.lsnet_artist  # noqa: F401
except ImportError as e:
    print(f"Error importing lsnet: {e}")
    raise
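
# The import above registers the LSNet artist architectures with timm's model
# registry, so create_model() can build them by name. If in doubt, registration
# can be checked with, e.g.:
#     from timm.models import list_models
#     print(list_models("lsnet*"))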


class PyTorchInference:
    def __init__(self, checkpoint_path, model_arch="lsnet_xl_artist", num_classes=None, feature_dim=None, device="cpu"):
        """
        Initialize PyTorch inference.

        Args:
            checkpoint_path: Path to PyTorch checkpoint file
            model_arch: Model architecture name
            num_classes: Number of output classes (inferred from checkpoint if None)
            feature_dim: Feature dimension (inferred from checkpoint if None)
            device: Device to run inference on
        """
        self.checkpoint_path = checkpoint_path
        self.model_arch = model_arch
        self.device = device

        # Hardcoded input size mapping, based on the actual model definitions
        self.input_size = self._get_input_size(model_arch)
        print(f"Using input size: {self.input_size} for model {model_arch}")

        # Load checkpoint
        state_dict = self.load_checkpoint_state(checkpoint_path)
        state_dict = self.normalize_state_dict_keys(state_dict)

        # Resolve model parameters from the checkpoint if not given explicitly
        if num_classes is None:
            num_classes = self.resolve_num_classes(state_dict)
        if feature_dim is None:
            feature_dim = self.resolve_feature_dim(state_dict)

        # Create model - don't pass img_size, let the model use its default
        self.model = create_model(
            model_arch,
            pretrained=False,
            num_classes=num_classes,
            feature_dim=feature_dim,
        )

        # Load weights
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(device)
        self.model.eval()

        # Build the preprocessing transform, overriding the model's default
        # config with our known-correct input size
        config = resolve_data_config({'input_size': (3, self.input_size, self.input_size)}, model=self.model)
        self.transform = create_transform(**config)
        print(f"Created transform with input size: {self.input_size}")

    def _get_input_size(self, model_arch):
        """Get input size based on model architecture - hardcoded to match the actual model definitions"""
        if model_arch == 'lsnet_xl_artist_448':
            return 448
        else:
            # All other artist models use 224
            return 224

    @staticmethod
    def load_checkpoint_state(checkpoint_path: str):
        """Load checkpoint and return model weights"""
        # weights_only=False is required because full training checkpoints may
        # contain non-tensor objects; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if isinstance(checkpoint, dict):
            if "model" in checkpoint:
                return checkpoint["model"]
            if "model_ema" in checkpoint:
                return checkpoint["model_ema"]
        return checkpoint

    @staticmethod
    def normalize_state_dict_keys(state_dict):
        """Remove distributed training prefixes"""
        normalized = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_key = key[len("module."):]
            else:
                new_key = key
            normalized[new_key] = value
        return normalized

    @staticmethod
    def resolve_num_classes(state_dict) -> int:
        """Infer number of classes from checkpoint"""
        # The classifier weight has shape (num_classes, feature_dim),
        # so its first dimension gives the class count.
        for key, value in state_dict.items():
            if key.endswith("head.l.weight") or key.endswith("head.weight"):
                return value.shape[0]
        # Fall back to the ImageNet-1k default
        return 1000

    @staticmethod
    def resolve_feature_dim(state_dict) -> int:
        """Infer feature dimension from checkpoint"""
        # Prefer the head batch-norm weight, whose length matches the feature dimension
        for key, value in state_dict.items():
            if key.endswith("head.bn.weight"):
                return value.shape[0]
        # Otherwise fall back to the input dimension of any 2D head weight
        for key, value in state_dict.items():
            if "head" in key and "weight" in key and len(value.shape) >= 2:
                return value.shape[1]
        # Default
        return 768

    def preprocess(self, image):
        """
        Preprocess PIL image for PyTorch inference

        Args:
            image: PIL Image

        Returns:
            torch tensor ready for inference
        """
        image = image.convert("RGB")
        tensor = self.transform(image)
        print(f"Preprocessed image to tensor shape: {tensor.shape}")
        return tensor.unsqueeze(0)

    def predict(self, image, top_k=5, threshold=0.0):
        """
        Run inference on an image.

        Args:
            image: PIL Image
            top_k: Number of top predictions to return (currently not applied;
                raw logits are returned)
            threshold: Minimum confidence threshold (currently not applied)

        Returns:
            logits: Raw model output as a numpy array
        """
        input_tensor = self.preprocess(image)
        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            print(f"Running inference on tensor shape: {input_tensor.shape}")
            # Use return_features=False to get classification logits
            logits = self.model(input_tensor, return_features=False)
        return logits.cpu().numpy()[0]
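

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): shows how the class
    # above could be driven end to end. The default checkpoint path below is a
    # placeholder assumption, not a file shipped with the repo; adjust to your setup.
    import argparse

    import numpy as np
    from PIL import Image

    parser = argparse.ArgumentParser(description="Run LSNet inference on a single image")
    parser.add_argument("image", help="Path to the input image")
    parser.add_argument("--checkpoint", default="checkpoints/lsnet_xl_artist.pth",
                        help="Path to the PyTorch checkpoint (placeholder default)")
    parser.add_argument("--arch", default="lsnet_xl_artist", help="Model architecture name")
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    inference = PyTorchInference(args.checkpoint, model_arch=args.arch)
    logits = inference.predict(Image.open(args.image), top_k=args.top_k)

    # predict() returns raw logits, so post-processing (softmax, top-k, thresholding)
    # is done here in the caller.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top_indices = probs.argsort()[::-1][:args.top_k]
    for idx in top_indices:
        print(f"class {idx}: {probs[idx]:.4f}")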