# WeavePrompt / similarity_metrics.py
# Uploaded by kevin1kevin1k via huggingface_hub ("Upload folder using huggingface_hub"),
# commit fb2f0a7 (verified).
from weave_prompt import ImageSimilarityMetric
from PIL import Image
import lpips
import torch
import numpy as np
class LPIPSImageSimilarityMetric(ImageSimilarityMetric):
    """Image similarity metric using LPIPS perceptual similarity.

    Both images are converted to RGB, resized to a common square resolution,
    scaled from [0, 255] to [-1, 1], and compared with an ``lpips.LPIPS``
    model. The resulting LPIPS distance ``d`` is mapped to a similarity score
    ``max(0.0, 1.0 - d)``. Higher scores mean the images are perceptually
    closer.
    """

    def __init__(self, net: str = 'alex', device: str = 'cpu', image_size: int = 256):
        """Create the LPIPS model.

        Args:
            net: Backbone name forwarded to ``lpips.LPIPS(net=...)``.
            device: Torch device that the model and the input tensors are
                placed on.
            image_size: Side length, in pixels, that both images are resized
                to before comparison. Defaults to 256, the value previously
                hard-coded.

        Raises:
            ValueError: If ``image_size`` is not a positive integer.
        """
        if image_size <= 0:
            raise ValueError(f"image_size must be positive, got {image_size}")
        # Call .eval() explicitly: LPIPS applies dropout by default, and some
        # lpips versions do not switch to eval mode on their own. In train
        # mode, scores for the same pair of images could differ between calls.
        self.lpips_model = lpips.LPIPS(net=net).to(device).eval()
        self.device = device
        self.image_size = image_size

    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        """Return the LPIPS-based similarity between two images.

        Args:
            generated_img: Image produced by the generator.
            target_img: Reference image to compare against.

        Returns:
            ``1.0 - lpips_distance``, clamped so it is never below 0.0.
        """
        size = (self.image_size, self.image_size)

        def img_to_tensor(img: Image.Image) -> torch.Tensor:
            # Force 3 channels: drops the alpha channel of RGBA PNGs and
            # expands grayscale/palette images.
            rgb = img.convert('RGB')
            arr = np.asarray(rgb.resize(size), dtype=np.float32) / 255.0
            arr = arr.transpose(2, 0, 1)  # HWC -> CHW
            tensor = torch.tensor(arr).unsqueeze(0)  # add batch dimension
            return tensor * 2 - 1  # LPIPS expects inputs in [-1, 1]

        gen_tensor = img_to_tensor(generated_img).to(self.device)
        tgt_tensor = img_to_tensor(target_img).to(self.device)

        # Inference only: skip building an autograd graph that .item() would
        # discard anyway.
        with torch.no_grad():
            distance = self.lpips_model(gen_tensor, tgt_tensor).item()

        return max(0.0, 1.0 - distance)