import logging
import os
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

from .config import DEFAULT_LABELS
from .utils import softmax

logger = logging.getLogger(__name__)


class SimpleVisionModel(nn.Module):
    """
    Wrapper around a lightweight classifier.

    For training, use training/train_vision.py.
    At inference, if checkpoint absent or downloads fail, we return
    rule-based scores (see ``VisionInference.predict``).
    """

    def __init__(self, num_classes: int):
        """Build the backbone.

        Tries a pretrained timm ``mobilenetv3_small_100``; if timm cannot be
        imported or the model cannot be created (e.g. the weight download
        fails), falls back to a tiny pooled MLP so construction never raises.

        Args:
            num_classes: Size of the output logit vector.
        """
        super().__init__()
        try:
            import timm

            self.net = timm.create_model(
                "mobilenetv3_small_100",
                pretrained=True,
                num_classes=num_classes,
            )
        except Exception:
            # Deliberately broad: a missing package and a failed download
            # surface as different exception types, and both should degrade
            # to the fallback network instead of crashing.
            logger.warning(
                "timm backbone unavailable; using fallback MLP classifier",
                exc_info=True,
            )
            # Pools to 8x8 per channel; the 8*8*3 input width means this
            # fallback only accepts 3-channel input.
            self.net = nn.Sequential(
                nn.AdaptiveAvgPool2d((8, 8)),
                nn.Flatten(),
                nn.Linear(8 * 8 * 3, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes),
            )

    def forward(self, x):
        """Return the logits produced by the wrapped network for ``x``."""
        return self.net(x)


class VisionInference:
    """Image classifier front-end with a rule-based fallback.

    ``ready`` is True only when a checkpoint was found and loaded
    successfully. When it is False, ``predict`` skips the (untrained)
    network and returns heuristic scores instead.
    """

    def __init__(
        self,
        labels: Optional[List[str]] = None,
        ckpt_path: str = "checkpoints/vision/best.pt",
    ):
        """Create the model and try to load weights from ``ckpt_path``.

        Args:
            labels: Class names, in logit order. Falls back to
                ``DEFAULT_LABELS`` when None or empty.
            ckpt_path: Path of the checkpoint file. It may hold either a raw
                state dict or a dict with the state dict under ``"model"``.
        """
        self.labels = labels or DEFAULT_LABELS
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SimpleVisionModel(num_classes=len(self.labels)).to(self.device)
        # Inference only: put BatchNorm/Dropout layers into evaluation mode.
        self.model.eval()
        self.transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
        self.ready = False
        if os.path.exists(ckpt_path):
            try:
                # NOTE(security): torch.load unpickles the file; only load
                # checkpoints that come from a trusted source.
                state = torch.load(ckpt_path, map_location=self.device)
                self.model.load_state_dict(state["model"] if "model" in state else state)
                self.ready = True
            except Exception:
                # Best effort: an unreadable or mismatched checkpoint leaves
                # the instance in rule-based mode rather than failing.
                logger.warning(
                    "Could not load checkpoint %s; using rule-based scores",
                    ckpt_path,
                    exc_info=True,
                )
                self.ready = False

    @torch.no_grad()
    def predict(self, image: Image.Image) -> Dict[str, float]:
        """Score ``image`` against every label.

        Args:
            image: PIL image to classify, or None.

        Returns:
            Mapping of label to score. All zeros when ``image`` is None;
            softmax probabilities when a checkpoint is loaded and the forward
            pass succeeds; otherwise normalised rule-based scores.
        """
        if image is None:
            return {label: 0.0 for label in self.labels}

        if self.ready:
            try:
                x = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)
                logits = self.model(x)[0].detach().cpu().tolist()
                probs = softmax(logits)
                return {lbl: float(p) for lbl, p in zip(self.labels, probs)}
            except Exception:
                logger.warning(
                    "Model inference failed; using rule-based scores",
                    exc_info=True,
                )

        return self._rule_based_scores(image)

    def _rule_based_scores(self, image: Image.Image) -> Dict[str, float]:
        """Return normalised heuristic scores from simple colour statistics."""
        import numpy as np

        img = image.convert("RGB").resize((64, 64))
        arr = np.array(img).astype("float32") / 255.0
        contrast = float(arr.mean(axis=2).std())
        red_mean = float(arr[:, :, 0].mean())
        green_mean = float(arr[:, :, 1].mean())
        blue_mean = float(arr[:, :, 2].mean())

        # Small uniform prior so every label gets a non-zero score.
        scores = {label: 0.01 for label in self.labels}

        def bump(label: str, amount: float) -> None:
            # Custom label sets may not contain the heuristic's label names;
            # skip those instead of raising KeyError.
            if label in scores:
                scores[label] += amount

        if contrast > 0.22:
            bump("scratch_dent", 0.2)
            bump("paint_damage", 0.15)
            bump("bumper_damage", 0.1)
        if blue_mean < 0.35 and green_mean < 0.35:
            bump("rust", 0.2)
        if red_mean > 0.55:
            bump("engine_leak", 0.15)

        total = sum(scores.values())
        if not total:
            # Only reachable with an empty label set; avoid dividing by zero.
            return scores
        return {label: value / total for label, value in scores.items()}