kevinkyi commited on
Commit
0081e70
·
verified ·
1 Parent(s): 23ce1ba

Add inference.py

Browse files
Files changed (1) hide show
  1. inference.py +58 -0
inference.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, json
2
+ import torchvision
3
+ from torchvision import transforms, models
4
+ from PIL import Image
5
+
6
+ def build_model(arch, dropout, width, freeze_backbone, num_classes=2):
7
+ import torch.nn as nn
8
+ if arch == "smallcnn":
9
+ class SmallCNN(nn.Module):
10
+ def __init__(self, num_classes=2, dropout=0.2, width=32):
11
+ super().__init__()
12
+ c = width
13
+ self.features = nn.Sequential(
14
+ nn.Conv2d(3, c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
15
+ nn.Conv2d(c, 2*c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
16
+ nn.Conv2d(2*c, 4*c, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
17
+ )
18
+ self.head = nn.Sequential(nn.Flatten(), nn.Dropout(dropout), nn.Linear(4*c, num_classes))
19
+ def forward(self, x): return self.head(self.features(x))
20
+ return SmallCNN(num_classes=num_classes, dropout=dropout, width=width)
21
+ elif arch == "resnet18":
22
+ m = models.resnet18(weights=None) # weights not needed for inference after loading state_dict
23
+ in_features = m.fc.in_features
24
+ import torch.nn as nn
25
+ m.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes))
26
+ return m
27
+ elif arch == "mobilenet_v3_small":
28
+ m = models.mobilenet_v3_small(weights=None)
29
+ in_features = m.classifier[-1].in_features
30
+ import torch.nn as nn
31
+ m.classifier[-1] = nn.Linear(in_features, num_classes)
32
+ return m
33
+ else:
34
+ raise ValueError("Unknown arch")
35
+
36
+ def load_model(model_path="model_state.pt", config_path="config.json", device="cpu"):
37
+ with open(config_path) as f:
38
+ cfg = json.load(f)
39
+ model = build_model(cfg["arch"], cfg["dropout"], cfg["width"], cfg["freeze_backbone"], cfg["num_classes"])
40
+ state = torch.load(model_path, map_location=device)
41
+ model.load_state_dict(state, strict=True)
42
+ model.to(device).eval()
43
+ tfm = transforms.Compose([
44
+ transforms.Resize(int(cfg["img_size"]*1.14)),
45
+ transforms.CenterCrop(cfg["img_size"]),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(mean=cfg["mean"], std=cfg["std"]),
48
+ ])
49
+ return model, tfm, cfg
50
+
51
+ def predict_image(image_path, model, tfm, device="cpu"):
52
+ img = Image.open(image_path).convert("RGB")
53
+ x = tfm(img).unsqueeze(0).to(device)
54
+ with torch.no_grad():
55
+ logits = model(x)
56
+ probs = torch.softmax(logits, dim=1).cpu().numpy().ravel().tolist()
57
+ pred = int(logits.argmax(dim=1).item())
58
+ return pred, probs