import lpips
import numpy as np
import torch
from PIL import Image

from weave_prompt import ImageSimilarityMetric


class LPIPSImageSimilarityMetric(ImageSimilarityMetric):
    """Image similarity metric using LPIPS perceptual similarity."""

    def __init__(self, net: str = 'alex', device: str = 'cpu'):
        self.lpips_model = lpips.LPIPS(net=net).to(device)
        self.device = device

    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        def img_to_tensor(img):
            # Force 3 channels so PNGs with alpha or palette modes are handled
            img = img.convert('RGB')
            arr = np.array(img.resize((256, 256))).astype(np.float32) / 255.0
            arr = arr.transpose(2, 0, 1)  # HWC to CHW
            tensor = torch.tensor(arr).unsqueeze(0)  # Add batch dimension
            return tensor * 2 - 1  # LPIPS expects [-1, 1]

        gen_tensor = img_to_tensor(generated_img).to(self.device)
        tgt_tensor = img_to_tensor(target_img).to(self.device)

        # Inference only, so skip building the autograd graph
        with torch.no_grad():
            distance = self.lpips_model(gen_tensor, tgt_tensor).item()

        # LPIPS is a distance (lower means more similar); convert it to a
        # similarity and clamp at 0 in case the distance exceeds 1
        similarity = max(0.0, 1.0 - distance)
        return similarity