serviceadvisor / car_core /models_zeroshot.py
viswanani's picture
Upload 16 files
94332c9 verified
raw
history blame contribute delete
873 Bytes
from typing import List, Tuple
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
class ModelIDZeroShot:
    """Zero-shot image labeller backed by a CLIP checkpoint.

    Each candidate label is wrapped in a short text prompt, scored against
    the image by the CLIP model, and the best-scoring label is returned.
    """

    # Single source of truth for the checkpoint so that the model weights and
    # the processor are always loaded from the same identifier.
    DEFAULT_CHECKPOINT = "openai/clip-vit-base-patch32"

    def __init__(self, label_space: List[str], *, model_name: str = DEFAULT_CHECKPOINT):
        """Load the CLIP model and processor and remember the candidate labels.

        Args:
            label_space: Candidate labels; ``top_label`` returns one of these.
            model_name: Checkpoint identifier passed to both
                ``CLIPModel.from_pretrained`` and
                ``CLIPProcessor.from_pretrained``. Defaults to
                ``DEFAULT_CHECKPOINT``.
        """
        self.labels = label_space
        # Use the GPU when torch reports one, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    @torch.no_grad()
    def top_label(self, image: "Image.Image") -> str:
        """Return the label whose prompt scores highest against ``image``.

        Every label is turned into the prompt ``"A photo of a <label>"``; the
        image is converted to RGB; both are run through the processor and the
        model in one forward pass with gradient tracking disabled.

        Args:
            image: The picture to classify.

        Returns:
            The entry of ``self.labels`` at the argmax of the image-to-text
            logits.

        Raises:
            ValueError: If there are no labels to score the image against.
        """
        # Fail early with a clear message instead of an opaque error from the
        # tokenizer or from argmax over an empty tensor.
        if not self.labels:
            raise ValueError("label_space is empty; there are no labels to score the image against")

        prompts = [f"A photo of a {label}" for label in self.labels]
        batch = self.processor(
            text=prompts,
            images=image.convert("RGB"),
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        logits = self.model(**batch).logits_per_image
        # Only one image is submitted, so row 0 holds its score per prompt.
        best = int(logits[0].argmax().item())
        return self.labels[best]