File size: 873 Bytes
94332c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import List, Optional, Tuple

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

class ModelIDZeroShot:
    """Zero-shot image classifier over a fixed label space using CLIP.

    Each label is turned into the prompt ``"A photo of a {label}"``. The
    prompts are encoded once at construction time. ``top_label`` then only has
    to encode the query image and pick the label whose prompt embedding is
    most similar to it.
    """

    def __init__(
        self,
        label_space: List[str],
        *,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None,
    ):
        """Load the CLIP model/processor and pre-encode the label prompts.

        Args:
            label_space: Candidate labels. The list is copied, so later
                mutation by the caller cannot desynchronise labels from the
                cached text embeddings.
            model_name: Hugging Face checkpoint passed to ``from_pretrained``
                for both the model and the processor.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.

        Raises:
            ValueError: If ``label_space`` is empty.
        """
        labels = list(label_space)
        if not labels:
            # Fail early with a clear message instead of deep inside
            # tokenisation / argmax on an empty tensor.
            raise ValueError("label_space must contain at least one label")
        self.labels = labels
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        # Label prompts never change, so run the text encoder only once.
        self._text_features = self._encode_labels()

    @torch.no_grad()
    def _encode_labels(self) -> torch.Tensor:
        """Return L2-normalised text embeddings, one row per entry of ``self.labels``."""
        prompts = [f"A photo of a {x}" for x in self.labels]
        text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True).to(self.device)
        feats = self.model.get_text_features(**text_inputs)
        return feats / feats.norm(p=2, dim=-1, keepdim=True)

    @torch.no_grad()
    def top_label(self, image: "Image.Image") -> str:
        """Return the label whose prompt best matches ``image``.

        The image is converted to RGB before preprocessing. Similarity is the
        cosine between the image embedding and each cached label embedding.
        CLIP's ``logits_per_image`` is the same cosine scaled by a positive
        temperature, so the selected label is the same as ranking by logits.
        """
        image_inputs = self.processor(images=image.convert("RGB"), return_tensors="pt").to(self.device)
        feats = self.model.get_image_features(**image_inputs)
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        # A single image is processed, so row 0 holds its embedding; the
        # product yields one score per label.
        scores = feats[0] @ self._text_features.T
        return self.labels[int(scores.argmax().item())]