import torch
from torchvision import transforms


class DINOv2Processor:
    """Loads a DINOv2 backbone and computes cosine similarity between two PIL images."""

    def __init__(self, model_name="dinov2_vitb14", device="cpu", image_size=518):
        self.model_name = model_name
        self.device = device
        self.image_size = image_size
        self.model = self._load_model()

    def _load_model(self):
        # Fetch the pretrained DINOv2 backbone from torch.hub and put it in inference mode.
        model = torch.hub.load('facebookresearch/dinov2', self.model_name)
        model.eval()
        model.to(self.device)
        return model

    def _preprocess_image(self, image):
        # Standard ImageNet preprocessing; 518 is a multiple of the ViT patch size (14).
        preprocess = transforms.Compose([
            transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # Convert to RGB so grayscale/RGBA inputs match the 3-channel normalization.
        return preprocess(image.convert("RGB"))

    def compute_similarity(self, pil_image1, pil_image2):
        # Embed both images with the frozen backbone.
        img1_t = self._preprocess_image(pil_image1).unsqueeze(0).to(self.device)
        img2_t = self._preprocess_image(pil_image2).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat1 = self.model(img1_t)
            feat2 = self.model(img2_t)
        # L2-normalize the embeddings so their dot product is the cosine similarity.
        feat1 = feat1 / feat1.norm(dim=1, keepdim=True)
        feat2 = feat2 / feat2.norm(dim=1, keepdim=True)
        similarity = (feat1 * feat2).sum(dim=1)
        return similarity.item()
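
A minimal usage sketch, assuming the torch.hub download of the DINOv2 weights succeeds; the image file names below are hypothetical placeholders, and any two PIL-readable images will do:

    if __name__ == "__main__":
        from PIL import Image

        # CPU works but is slow; pass device="cuda" if a GPU is available.
        processor = DINOv2Processor(model_name="dinov2_vitb14", device="cpu")
        score = processor.compute_similarity(Image.open("image_a.jpg"), Image.open("image_b.jpg"))
        # Cosine similarity of the normalized embeddings: a value in [-1, 1], higher = more similar.
        print(f"DINOv2 similarity: {score:.4f}")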