Spaces:
Runtime error
Runtime error
| from weave_prompt import ImageSimilarityMetric | |
| from PIL import Image | |
| import lpips | |
| import torch | |
| import numpy as np | |
class LPIPSImageSimilarityMetric(ImageSimilarityMetric):
    """Image similarity metric using LPIPS perceptual similarity.

    LPIPS yields a perceptual *distance* (lower means more alike). This
    class converts it into a similarity score as ``1 - distance``, clamped
    below at 0.0, so that higher means more alike.
    """

    # Both images are resized to this (width, height) before comparison so
    # the two tensors handed to the model always share the same shape.
    _INPUT_SIZE = (256, 256)

    def __init__(self, net: str = 'alex', device: str = 'cpu'):
        """Build the LPIPS model and place it on ``device``.

        Args:
            net: Backbone name forwarded to ``lpips.LPIPS(net=...)``.
            device: Torch device string; the model and every input tensor
                are moved to it.
        """
        self.lpips_model = lpips.LPIPS(net=net).to(device)
        self.device = device

    def _to_tensor(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a ``(1, 3, 256, 256)`` tensor in [-1, 1].

        Args:
            img: Source image in any PIL mode.

        Returns:
            A float32 tensor located on ``self.device``.
        """
        # Force 3 channels so PNGs with alpha or palette modes are handled.
        rgb = img.convert('RGB')
        arr = np.array(rgb.resize(self._INPUT_SIZE)).astype(np.float32) / 255.0
        arr = arr.transpose(2, 0, 1)  # HWC to CHW
        tensor = torch.tensor(arr).unsqueeze(0)  # add the batch dimension
        # LPIPS expects inputs in [-1, 1] rather than [0, 1].
        return (tensor * 2 - 1).to(self.device)

    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        """Return the perceptual similarity between two images.

        Args:
            generated_img: Image to be scored.
            target_img: Reference image to compare against.

        Returns:
            ``max(0.0, 1.0 - lpips_distance)`` as a Python float.
        """
        gen_tensor = self._to_tensor(generated_img)
        tgt_tensor = self._to_tensor(target_img)
        # Only the scalar value is needed, so skip building the autograd
        # graph for the forward pass.
        with torch.no_grad():
            distance = self.lpips_model(gen_tensor, tgt_tensor).item()
        return max(0.0, 1.0 - distance)