from typing import List

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class ModelIDZeroShot:
    """Zero-shot image labeler backed by CLIP (ViT-B/32)."""

    def __init__(self, label_space: List[str]):
        self.labels = label_space
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    @torch.no_grad()
    def top_label(self, image: Image.Image) -> str:
        # One text prompt per candidate label.
        prompts = [f"A photo of a {x}" for x in self.labels]
        inp = self.processor(
            text=prompts,
            images=image.convert("RGB"),
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        out = self.model(**inp)
        # logits_per_image has shape (1, num_labels); pick the highest-scoring label.
        idx = int(out.logits_per_image[0].argmax().item())
        return self.labels[idx]
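

# Minimal usage sketch. Assumptions not taken from the original: the image path
# "example.jpg" and the label names below are illustrative placeholders only;
# substitute your own file and label space.
if __name__ == "__main__":
    classifier = ModelIDZeroShot(["cat", "dog", "bicycle"])
    img = Image.open("example.jpg")
    print(classifier.top_label(img))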