Ravi-9's picture
Adding train/evaluate/metrics files
86e31ba verified
raw
history blame
1.84 kB
# utils/metrics.py
import numpy as np
import librosa
def calculate_msd(pred_audio, target_audio, sr=22050):
    """
    Mel Spectral Distance (MSD) between predicted and target audio.

    Both signals are turned into mel power spectrograms, converted to dB
    (each relative to its own maximum, via ``ref=np.max``), and compared
    with a mean squared difference.

    Args:
        pred_audio: Predicted waveform passed to librosa.
        target_audio: Reference waveform passed to librosa.
        sr: Sampling rate used to build the mel filterbank.

    Returns:
        Mean squared difference between the two dB mel spectrograms,
        computed over the frames both signals have in common.
    """
    # Convert to mel-spectrogram
    pred_mel = librosa.feature.melspectrogram(y=pred_audio, sr=sr)
    target_mel = librosa.feature.melspectrogram(y=target_audio, sr=sr)

    # Signals of different duration yield different frame counts, which
    # would make the element-wise subtraction below fail. Compare only the
    # overlapping frames (time is the last axis of the spectrogram).
    n_frames = min(pred_mel.shape[-1], target_mel.shape[-1])
    pred_mel = pred_mel[..., :n_frames]
    target_mel = target_mel[..., :n_frames]

    # Convert to dB
    pred_db = librosa.power_to_db(pred_mel, ref=np.max)
    target_db = librosa.power_to_db(target_mel, ref=np.max)

    # Mean squared difference
    return np.mean((pred_db - target_db) ** 2)
def calculate_f0_correlation(pred_audio, target_audio, sr=22050):
    """
    Pitch correlation (F0 correlation) between predicted and target.

    F0 tracks are estimated with ``librosa.pyin`` (search range 50-500 Hz)
    and correlated over the frames where both tracks have a pitch estimate.

    Args:
        pred_audio: Predicted waveform passed to librosa.
        target_audio: Reference waveform passed to librosa.
        sr: Sampling rate of both waveforms.

    Returns:
        Pearson correlation coefficient of the two F0 tracks, or 0.0 when
        the correlation is undefined (fewer than two jointly voiced frames,
        or one of the tracks is constant over those frames).
    """
    f0_pred, _, _ = librosa.pyin(pred_audio, fmin=50, fmax=500, sr=sr)
    f0_target, _, _ = librosa.pyin(target_audio, fmin=50, fmax=500, sr=sr)

    # Signals of different duration give tracks of different length; align
    # them on the common prefix so the boolean mask below is well-formed.
    n_frames = min(len(f0_pred), len(f0_target))
    f0_pred = f0_pred[:n_frames]
    f0_target = f0_target[:n_frames]

    # Remove NaNs (frames without a pitch estimate in either track)
    mask = ~np.isnan(f0_pred) & ~np.isnan(f0_target)
    voiced_pred = f0_pred[mask]
    voiced_target = f0_target[mask]

    # Correlation needs at least two points.
    if voiced_pred.size < 2:
        return 0.0

    # A constant track has zero variance, for which np.corrcoef yields NaN
    # (and a runtime warning); report "no correlation" instead.
    if (voiced_pred.max() == voiced_pred.min()
            or voiced_target.max() == voiced_target.min()):
        return 0.0

    return float(np.corrcoef(voiced_pred, voiced_target)[0, 1])
def calculate_phoneme_accuracy(pred_phonemes, target_phonemes):
    """
    Position-wise phoneme accuracy.

    Predicted and target symbols are paired by index (pairing stops at the
    shorter sequence) and the number of equal pairs is divided by the
    length of the target sequence.

    Args:
        pred_phonemes: Sequence of predicted phoneme symbols.
        target_phonemes: Sequence of reference phoneme symbols.

    Returns:
        Fraction of target positions that were predicted correctly, or 0.0
        when the target sequence is empty.
    """
    n_target = len(target_phonemes)
    if n_target == 0:
        return 0.0

    n_matches = 0
    for predicted, expected in zip(pred_phonemes, target_phonemes):
        n_matches += (predicted == expected)

    return n_matches / n_target
def calculate_spectral_convergence(pred_audio, target_audio, sr=22050):
    """
    Spectral convergence: how close the predicted spectrum is to the target.

    Computed as ``||target - pred||_F / ||target||_F`` on STFT magnitudes,
    so 0 means identical magnitude spectrograms.

    Args:
        pred_audio: Predicted waveform passed to librosa.
        target_audio: Reference waveform passed to librosa.
        sr: Not used by this computation; kept so the signature matches
            the other metrics in this module.

    Returns:
        Ratio of the Frobenius norm of the magnitude difference to the
        Frobenius norm of the target magnitude, over the frames both
        signals have in common.
    """
    pred_spec = np.abs(librosa.stft(pred_audio))
    target_spec = np.abs(librosa.stft(target_audio))

    # Signals of different duration yield different frame counts, which
    # would make the subtraction fail; compare only the overlapping frames.
    n_frames = min(pred_spec.shape[-1], target_spec.shape[-1])
    pred_spec = pred_spec[..., :n_frames]
    target_spec = target_spec[..., :n_frames]

    numerator = np.linalg.norm(target_spec - pred_spec, 'fro')
    denominator = np.linalg.norm(target_spec, 'fro')

    # Small epsilon guards against division by zero for a silent target.
    return numerator / (denominator + 1e-8)