""" |
|
|
PyTorch Inference implementation for LSNet models |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from timm.data import resolve_data_config |
|
|
from timm.data.transforms_factory import create_transform |
|
|
from timm.models import create_model |
|
|
|
|
|
|
|
|

# Importing lsnet.lsnet_artist registers the LSNet model variants with timm
# so that create_model() can resolve them by name.
try:
    import lsnet.lsnet_artist
except ImportError as e:
    print(f"Error importing lsnet: {e}")
    raise


class PyTorchInference:
    """Load an LSNet checkpoint and run single-image inference with PyTorch."""

    def __init__(self, checkpoint_path, model_arch="lsnet_xl_artist", num_classes=None, feature_dim=None, device="cpu"):
""" |
|
|
Initialize PyTorch inference |
|
|
|
|
|
Args: |
|
|
checkpoint_path: Path to PyTorch checkpoint file |
|
|
model_arch: Model architecture name |
|
|
num_classes: Number of output classes (inferred from checkpoint if None) |
|
|
feature_dim: Feature dimension (inferred from checkpoint if None) |
|
|
device: Device to run inference on |
|
|
""" |
|
|
self.checkpoint_path = checkpoint_path |
|
|
self.model_arch = model_arch |
|
|
self.device = device |
|
|
|
|
|
|
|
|
self.input_size = self._get_input_size(model_arch) |
|
|
print(f"Using input size: {self.input_size} for model {model_arch}") |
|
|
|
|
|
|
|
|
state_dict = self.load_checkpoint_state(checkpoint_path) |
|
|
state_dict = self.normalize_state_dict_keys(state_dict) |
|
|
|
|
|
|
|
|
if num_classes is None: |
|
|
num_classes = self.resolve_num_classes(state_dict) |
|
|
if feature_dim is None: |
|
|
feature_dim = self.resolve_feature_dim(state_dict) |
|
|
|
|
|
|
|
|
self.model = create_model( |
|
|
model_arch, |
|
|
pretrained=False, |
|
|
num_classes=num_classes, |
|
|
feature_dim=feature_dim, |
|
|
) |
|
|
|
|
|
|
|
|
self.model.load_state_dict(state_dict, strict=False) |
|
|
self.model.to(device) |
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
|
|
|
config = resolve_data_config({'input_size': (3, self.input_size, self.input_size)}, model=self.model) |
|
|
self.transform = create_transform(**config) |
|
|
print(f"Created transform with input size: {self.input_size}") |
|
|
|
|
|

    def _get_input_size(self, model_arch):
        """Get the input size for a model architecture; hardcoded to match the actual model definitions."""
        if model_arch == "lsnet_xl_artist_448":
            return 448
        return 224

    @staticmethod
    def load_checkpoint_state(checkpoint_path: str):
        """Load a checkpoint and return its model weights, preferring "model" over "model_ema"."""
        # weights_only=False allows non-tensor pickled objects in the
        # checkpoint; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if isinstance(checkpoint, dict):
            if "model" in checkpoint:
                return checkpoint["model"]
            if "model_ema" in checkpoint:
                return checkpoint["model_ema"]
        return checkpoint

    @staticmethod
    def normalize_state_dict_keys(state_dict):
        """Strip the "module." prefix that distributed training adds to keys."""
        normalized = {}
        for key, value in state_dict.items():
            new_key = key[len("module."):] if key.startswith("module.") else key
            normalized[new_key] = value
        return normalized

    @staticmethod
    def resolve_num_classes(state_dict) -> int:
        """Infer the number of classes from the classifier head weights in the checkpoint."""
        for key, value in state_dict.items():
            if key.endswith("head.l.weight") or key.endswith("head.weight"):
                return value.shape[0]
        # Fall back to the ImageNet-1k default if no head weights are found.
        return 1000

    @staticmethod
    def resolve_feature_dim(state_dict) -> int:
        """Infer the feature dimension from the checkpoint."""
        for key, value in state_dict.items():
            if key.endswith("head.bn.weight"):
                return value.shape[0]
        # Fall back to the input dimension of any head weight that is at
        # least 2-D (shape[1] is always valid under the >= 2 guard).
        for key, value in state_dict.items():
            if "head" in key and "weight" in key and len(value.shape) >= 2:
                return value.shape[1]
        return 768

    def preprocess(self, image):
        """
        Preprocess a PIL image for PyTorch inference.

        Args:
            image: PIL Image

        Returns:
            Torch tensor of shape (1, C, H, W), ready for inference
        """
        image = image.convert("RGB")
        tensor = self.transform(image)
        print(f"Preprocessed image to tensor shape: {tensor.shape}")
        return tensor.unsqueeze(0)

    def predict(self, image, top_k=5, threshold=0.0):
        """
        Run inference on an image.

        Args:
            image: PIL Image
            top_k: Number of top predictions to return (currently unused; all logits are returned)
            threshold: Minimum confidence threshold (currently unused)

        Returns:
            Raw model logits as a numpy array
        """
        input_tensor = self.preprocess(image)

        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            print(f"Running inference on tensor shape: {input_tensor.shape}")
            logits = self.model(input_tensor, return_features=False)

        return logits.cpu().numpy()[0]
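

# Minimal usage sketch (illustrative, not part of the original module). The
# checkpoint path and image file below are hypothetical placeholders, and the
# softmax/top-k post-processing is one possible way for a caller to consume
# the raw logits that predict() returns.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    inference = PyTorchInference(
        checkpoint_path="checkpoints/lsnet_xl_artist.pth",  # hypothetical path
        model_arch="lsnet_xl_artist",
        device="cpu",
    )
    image = Image.open("example.jpg")  # hypothetical image file
    logits = inference.predict(image)

    # Convert logits to probabilities with a numerically stable softmax,
    # then print the five highest-scoring class indices.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top5 = np.argsort(probs)[::-1][:5]
    for idx in top5:
        print(f"class {idx}: {probs[idx]:.4f}")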