Spaces:
Sleeping
Sleeping
| from typing import List, Tuple | |
| import torch | |
| from PIL import Image | |
| from transformers import CLIPModel, CLIPProcessor | |
class ModelIDZeroShot:
    """Zero-shot image classifier over a fixed label set, backed by CLIP.

    Each candidate label is formatted into a text prompt (``"A photo of a
    <label>"`` by default); CLIP's image-to-text logits then decide which
    label best matches a given image.
    """

    DEFAULT_MODEL = "openai/clip-vit-base-patch32"
    DEFAULT_PROMPT = "A photo of a {}"

    def __init__(
        self,
        label_space: List[str],
        *,
        model_name: str = DEFAULT_MODEL,
        prompt_template: str = DEFAULT_PROMPT,
    ):
        """Load the CLIP model and processor and store the candidate labels.

        Args:
            label_space: Candidate labels; ``top_label`` always returns one
                of these.
            model_name: Checkpoint name passed to ``from_pretrained`` for both
                the model and the processor.
            prompt_template: ``str.format`` template with a single ``{}`` slot
                that receives each label.

        Raises:
            ValueError: If ``label_space`` is empty (nothing to choose from).
        """
        labels = list(label_space)
        if not labels:
            raise ValueError("label_space must contain at least one label")
        # Keep a private copy so later mutation of the caller's list cannot
        # silently change which labels are scored.
        self.labels = labels
        self.prompt_template = prompt_template
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def top_label(self, image: Image.Image) -> str:
        """Return the label CLIP scores highest for ``image``.

        Args:
            image: A PIL image; it is converted to RGB before preprocessing.

        Returns:
            The entry of ``self.labels`` whose prompt has the largest
            image-to-text logit.
        """
        prompts = [self.prompt_template.format(label) for label in self.labels]
        inputs = self.processor(
            text=prompts,
            images=image.convert("RGB"),
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        # Inference only: disable autograd so no graph or saved activations
        # are kept alive for a backward pass that never happens.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # One image was passed, so row 0 holds its logits against every prompt.
        best = int(outputs.logits_per_image[0].argmax().item())
        return self.labels[best]