import torch, json import torchvision from torchvision import transforms, models from PIL import Image def build_model(arch, dropout, width, freeze_backbone, num_classes=2): import torch.nn as nn if arch == "smallcnn": class SmallCNN(nn.Module): def __init__(self, num_classes=2, dropout=0.2, width=32): super().__init__() c = width self.features = nn.Sequential( nn.Conv2d(3, c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(c, 2*c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(2*c, 4*c, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), ) self.head = nn.Sequential(nn.Flatten(), nn.Dropout(dropout), nn.Linear(4*c, num_classes)) def forward(self, x): return self.head(self.features(x)) return SmallCNN(num_classes=num_classes, dropout=dropout, width=width) elif arch == "resnet18": m = models.resnet18(weights=None) # weights not needed for inference after loading state_dict in_features = m.fc.in_features import torch.nn as nn m.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes)) return m elif arch == "mobilenet_v3_small": m = models.mobilenet_v3_small(weights=None) in_features = m.classifier[-1].in_features import torch.nn as nn m.classifier[-1] = nn.Linear(in_features, num_classes) return m else: raise ValueError("Unknown arch") def load_model(model_path="model_state.pt", config_path="config.json", device="cpu"): with open(config_path) as f: cfg = json.load(f) model = build_model(cfg["arch"], cfg["dropout"], cfg["width"], cfg["freeze_backbone"], cfg["num_classes"]) state = torch.load(model_path, map_location=device) model.load_state_dict(state, strict=True) model.to(device).eval() tfm = transforms.Compose([ transforms.Resize(int(cfg["img_size"]*1.14)), transforms.CenterCrop(cfg["img_size"]), transforms.ToTensor(), transforms.Normalize(mean=cfg["mean"], std=cfg["std"]), ]) return model, tfm, cfg def predict_image(image_path, model, tfm, device="cpu"): img = Image.open(image_path).convert("RGB") x = tfm(img).unsqueeze(0).to(device) with torch.no_grad(): logits = model(x) probs = torch.softmax(logits, dim=1).cpu().numpy().ravel().tolist() pred = int(logits.argmax(dim=1).item()) return pred, probs