Spaces: Running on Zero
| import torch | |
| import numpy as np | |
| from skimage.metrics import structural_similarity as ssim | |
| from skimage.metrics import peak_signal_noise_ratio as psnr | |
| import lpips | |
| from tqdm import tqdm | |
| import cv2 | |
class VideoQualityEvaluator:
    """Frame-wise quality evaluation between a generated video and a reference.

    Computes PSNR and SSIM (scikit-image) and the LPIPS perceptual metric
    (``lpips`` package, AlexNet backbone). Videos are iterables of frames in
    HWC or CHW layout, as numpy arrays or torch tensors, with values in
    either [0, 1] or [0, 255].
    """

    def __init__(self, device='cuda'):
        """Initialize the evaluator on the given computation device.

        Args:
            device (str): Computation device ('cuda' or 'cpu').
        """
        self.device = device
        # LPIPS perceptual-similarity network; lower scores = more similar.
        self.lpips_model = lpips.LPIPS(net='alex').to(device)

    def _preprocess_frame(self, frame):
        """Convert a frame to a standardized format for evaluation.

        Args:
            frame: Input frame (numpy array or torch tensor), HWC or CHW.

        Returns:
            numpy array in HWC layout with values in [0, 1].
        """
        if isinstance(frame, torch.Tensor):
            frame = frame.detach().cpu().numpy()
        # Integer frames are 8-bit intensities: scale unconditionally, so an
        # all-dark uint8 frame (max <= 1) is still mapped into [0, 1].
        if np.issubdtype(frame.dtype, np.integer):
            frame = frame.astype(np.float64) / 255.0
        elif frame.max() > 1:
            # Float frames that still carry 0-255 values.
            frame = frame / 255.0
        # Convert CHW -> HWC if needed.
        # NOTE(review): assumes 3-channel color frames; a 3-pixel-tall HWC
        # frame would be misdetected as CHW — confirm against callers.
        if frame.ndim == 3 and frame.shape[0] == 3:
            frame = frame.transpose(1, 2, 0)
        return frame

    def _to_lpips_tensor(self, frame):
        """Convert a frame to a 1x3xHxW float tensor in [-1, 1] on self.device.

        LPIPS expects inputs normalized to [-1, 1]; routing every frame
        through :meth:`_preprocess_frame` first guarantees a consistent
        [0, 1] starting point regardless of the caller's dtype or layout.
        """
        frame = self._preprocess_frame(frame)  # HWC, values in [0, 1]
        tensor = torch.from_numpy(np.ascontiguousarray(frame))
        tensor = tensor.permute(2, 0, 1).unsqueeze(0).float()  # HWC -> 1CHW
        return (tensor * 2.0 - 1.0).to(self.device)

    def calculate_psnr(self, vid1, vid2):
        """Calculate average PSNR between two videos.

        Args:
            vid1: First video (iterable of frames).
            vid2: Second video (iterable of frames).

        Returns:
            Mean PSNR across all frame pairs (higher is better).
        """
        psnrs = []
        for f1, f2 in zip(vid1, vid2):
            f1 = self._preprocess_frame(f1)
            f2 = self._preprocess_frame(f2)
            # data_range=1.0 because frames were normalized to [0, 1].
            psnrs.append(psnr(f1, f2, data_range=1.0))
        return np.mean(psnrs)

    def calculate_ssim(self, vid1, vid2):
        """Calculate average SSIM between two videos.

        Args:
            vid1: First video (iterable of frames).
            vid2: Second video (iterable of frames).

        Returns:
            Mean SSIM across all frame pairs (1.0 = identical).
        """
        ssims = []
        for f1, f2 in zip(vid1, vid2):
            f1 = self._preprocess_frame(f1)
            f2 = self._preprocess_frame(f2)
            # channel_axis=2: frames are HWC color images after preprocessing.
            ssims.append(ssim(f1, f2, channel_axis=2, data_range=1.0))
        return np.mean(ssims)

    def calculate_lpips(self, vid1, vid2):
        """Calculate average LPIPS (perceptual distance) between two videos.

        Args:
            vid1: First video (iterable of frames).
            vid2: Second video (iterable of frames).

        Returns:
            Mean LPIPS across all frame pairs (lower is better).
        """
        lpips_values = []
        for f1, f2 in zip(vid1, vid2):
            # Normalize each frame independently — the original code decided
            # scaling for both frames from f1 alone, and skipped the
            # [0,1] -> [-1,1] mapping when inputs were already in [0,1].
            t1 = self._to_lpips_tensor(f1)
            t2 = self._to_lpips_tensor(f2)
            # Inference only: no gradients needed.
            with torch.no_grad():
                lpips_values.append(self.lpips_model(t1, t2).item())
        return np.mean(lpips_values)

    def evaluate_videos(self, generated_video, reference_video,
                        metrics=('psnr', 'ssim', 'lpips')):
        """Comprehensive quality evaluation of a generated video vs. a reference.

        Args:
            generated_video: Model-generated video [T,H,W,C] or [T,C,H,W].
            reference_video: Ground-truth reference video, same layout.
            metrics: Metrics to compute; any of 'psnr', 'ssim', 'lpips'.

        Returns:
            dict mapping each requested metric name to its mean value.

        Raises:
            ValueError: If the two videos have different frame counts.
        """
        # Raise (not assert) so validation survives python -O.
        if len(generated_video) != len(reference_video):
            raise ValueError("Videos must have same number of frames")
        results = {}
        if 'psnr' in metrics:
            results['psnr'] = self.calculate_psnr(generated_video, reference_video)
        if 'ssim' in metrics:
            results['ssim'] = self.calculate_ssim(generated_video, reference_video)
        if 'lpips' in metrics:
            results['lpips'] = self.calculate_lpips(generated_video, reference_video)
        return results