File size: 1,109 Bytes
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

from weave_prompt import ImageSimilarityMetric
from PIL import Image
import lpips
import torch
import numpy as np

class LPIPSImageSimilarityMetric(ImageSimilarityMetric):
    """Image similarity metric using LPIPS perceptual similarity.

    The LPIPS model returns a perceptual *distance* between two images.
    This metric reports ``1 - distance`` clamped below at 0, so a larger
    score means the images are perceptually closer.
    """

    # (width, height) every image is resized to before comparison, so that
    # images with different dimensions can be compared. Subclasses may
    # override this constant.
    INPUT_SIZE = (256, 256)

    def __init__(self, net: str = 'alex', device: str = 'cpu'):
        """Load the LPIPS network and place it on ``device``.

        Args:
            net: Backbone name forwarded to ``lpips.LPIPS``.
            device: Torch device used for the model and the input tensors.
        """
        self.lpips_model = lpips.LPIPS(net=net).to(device)
        # This class only scores images; keep the network in inference mode
        # (expected to be a no-op if the library already did so).
        self.lpips_model.eval()
        self.device = device

    def _to_tensor(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a ``(1, 3, H, W)`` float tensor in [-1, 1]."""
        rgb = img.convert('RGB')  # Ensure 3 channels (e.g. RGBA / palette PNG)
        pixels = np.array(rgb.resize(self.INPUT_SIZE)).astype(np.float32) / 255.0
        chw = np.ascontiguousarray(pixels.transpose(2, 0, 1))  # HWC to CHW
        batch = torch.from_numpy(chw).unsqueeze(0)
        # Rescale from [0, 1] to the [-1, 1] range LPIPS expects.
        return (batch * 2 - 1).to(self.device)

    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        """Return the perceptual similarity of two images.

        Args:
            generated_img: Image to evaluate.
            target_img: Reference image to compare against.

        Returns:
            ``max(0.0, 1.0 - lpips_distance)``; higher means more similar.
        """
        gen_tensor = self._to_tensor(generated_img)
        tgt_tensor = self._to_tensor(target_img)
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            distance = self.lpips_model(gen_tensor, tgt_tensor).item()
        # Distances above 1 would give a negative score, so clamp at 0.
        return max(0.0, 1.0 - distance)