"""
PyTorch Inference implementation for LSNet models
"""

import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.models import create_model

# Import to register lsnet models
try:
    import lsnet.lsnet_artist  # noqa: F401
except ImportError as e:
    print(f"Error importing lsnet: {e}")
    raise


class PyTorchInference:
    def __init__(self, checkpoint_path, model_arch="lsnet_xl_artist",
                 num_classes=None, feature_dim=None, device="cpu"):
        """
        Initialize PyTorch inference

        Args:
            checkpoint_path: Path to PyTorch checkpoint file
            model_arch: Model architecture name
            num_classes: Number of output classes (inferred from checkpoint if None)
            feature_dim: Feature dimension (inferred from checkpoint if None)
            device: Device to run inference on
        """
        self.checkpoint_path = checkpoint_path
        self.model_arch = model_arch
        self.device = device

        # Hardcoded input size mapping - based on actual model definitions
        self.input_size = self._get_input_size(model_arch)
        print(f"Using input size: {self.input_size} for model {model_arch}")

        # Load checkpoint
        state_dict = self.load_checkpoint_state(checkpoint_path)
        state_dict = self.normalize_state_dict_keys(state_dict)

        # Resolve model parameters
        if num_classes is None:
            num_classes = self.resolve_num_classes(state_dict)
        if feature_dim is None:
            feature_dim = self.resolve_feature_dim(state_dict)

        # Create model - don't pass img_size, let the model use its default
        self.model = create_model(
            model_arch,
            pretrained=False,
            num_classes=num_classes,
            feature_dim=feature_dim,
        )

        # Load weights
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(device)
        self.model.eval()

        # Get transform - override with our correct input size
        # We manually set the input_size instead of relying on the model's config
        config = resolve_data_config(
            {"input_size": (3, self.input_size, self.input_size)}, model=self.model
        )
        self.transform = create_transform(**config)
        print(f"Created transform with input size: {self.input_size}")

    def _get_input_size(self, model_arch):
        """Get input size based on model architecture - hardcoded to match actual model definitions"""
        if model_arch == 'lsnet_xl_artist_448':
            return 448
        else:
            # All other artist models use 224
            return 224

    @staticmethod
    def load_checkpoint_state(checkpoint_path: str):
        """Load checkpoint and return model weights"""
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if isinstance(checkpoint, dict):
            if "model" in checkpoint:
                return checkpoint["model"]
            if "model_ema" in checkpoint:
                return checkpoint["model_ema"]
        return checkpoint

    @staticmethod
    def normalize_state_dict_keys(state_dict):
        """Remove distributed training prefixes"""
        normalized = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_key = key[len("module.") :]
            else:
                new_key = key
            normalized[new_key] = value
        return normalized

    @staticmethod
    def resolve_num_classes(state_dict) -> int:
        """Infer number of classes from checkpoint"""
        for key, value in state_dict.items():
            if key.endswith("head.l.weight") or key.endswith("head.weight"):
                return value.shape[0]
        # Default
        return 1000

    @staticmethod
    def resolve_feature_dim(state_dict) -> int:
        """Infer feature dimension from checkpoint"""
        for key, value in state_dict.items():
            if key.endswith("head.bn.weight"):
                return value.shape[0]

        # Fallback: any 2-D (or higher) head weight; its second dimension is
        # the input feature size of that layer
        for key, value in state_dict.items():
            if "head" in key and "weight" in key and len(value.shape) >= 2:
                return value.shape[1]

        # Default
        return 768

    def preprocess(self, image):
        """
        Preprocess PIL image for PyTorch inference

        Args:
            image: PIL Image

        Returns:
            torch tensor ready for inference
        """
        image = image.convert("RGB")
        tensor = self.transform(image)
        print(f"Preprocessed image to tensor shape: {tensor.shape}")
        return tensor.unsqueeze(0)

    def predict(self, image, top_k=5, threshold=0.0):
        """
        Run inference on image

        Args:
            image: PIL Image
            top_k: Unused; kept for interface compatibility (raw logits are returned)
            threshold: Unused; kept for interface compatibility

        Returns:
            logits: Raw model output as a numpy array
        """
        input_tensor = self.preprocess(image)

        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            print(f"Running inference on tensor shape: {input_tensor.shape}")
            # Use return_features=False to get classification logits
            logits = self.model(input_tensor, return_features=False)

        return logits.cpu().numpy()[0]
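

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # "model.pth" and "example.jpg" are placeholder paths, and the default
    # lsnet_xl_artist architecture defined above is assumed.
    from PIL import Image

    infer = PyTorchInference("model.pth")
    logits = infer.predict(Image.open("example.jpg"))
    print(f"Logits shape: {logits.shape}, argmax class index: {logits.argmax()}")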