"""
Calculate Frechet Audio Distance between two audio directories.
Frechet distance implementation adapted from: https://github.com/mseitzer/pytorch-fid
VGGish adapted from: https://github.com/harritaylor/torchvggish
"""
import os
from glob import glob
from multiprocessing.dummy import Pool as ThreadPool

import numpy as np
import resampy
import soundfile as sf
import torch
from scipy import linalg
from torch import nn
from tqdm import tqdm

from src.torchvggish.torchvggish.vggish import VGGishlocal
# Target sample rate: the pretrained VGGish model expects 16 kHz waveforms.
SAMPLE_RATE = 16000


def load_audio_task(fname):
    """Load a wav file as mono float audio resampled to SAMPLE_RATE.

    Params:
    -- fname : path to a wav file.
    Returns:
    -- (wav_data, SAMPLE_RATE) where wav_data is a float array in [-1.0, +1.0].
    """
    wav_data, sr = sf.read(fname, dtype='int16')
    assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
    # Scale int16 samples into [-1.0, +1.0].
    wav_data = wav_data / 32768.0
    # Average the channels down to mono when the file is multi-channel.
    if wav_data.ndim > 1:
        wav_data = wav_data.mean(axis=1)
    # Resample only when the source rate differs from the target rate.
    if sr != SAMPLE_RATE:
        wav_data = resampy.resample(wav_data, sr, SAMPLE_RATE)
    return wav_data, SAMPLE_RATE
# Uses a pretrained (torch)VGGish model as the embedding extractor, then scores
# the eval set against the background set via the Frechet distance between the
# Gaussian statistics (mean, covariance) of their embeddings.
class FrechetAudioDistance:
    def __init__(self, use_pca=False, use_activation=False, verbose=False, audio_load_worker=8):
        """
        Params:
        -- use_pca           : keep the VGGish PCA postprocessing stage.
        -- use_activation    : keep the final activation layer of the embedding head.
        -- verbose           : show progress bars and diagnostic messages.
        -- audio_load_worker : number of threads used to load audio files.
        """
        # self.__get_model(use_pca=use_pca, use_activation=use_activation)
        self.__get_local_model(local_path='src/torchvggish/docs', use_pca=use_pca, use_activation=use_activation)
        self.verbose = verbose
        self.audio_load_worker = audio_load_worker

    def __get_model(self, use_pca=False, use_activation=False):
        """Download the pretrained VGGish model via torch.hub and configure it.

        Params:
        -- use_pca        : keep the PCA postprocessing step if True.
        -- use_activation : keep the final activation layer if True.
        """
        self.model = torch.hub.load('harritaylor/torchvggish', 'vggish')
        if not use_pca:
            self.model.postprocess = False
        if not use_activation:
            # Drop the last (activation) layer so raw embeddings are returned.
            self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])
        self.model.eval()

    def __get_local_model(self, local_path, use_pca=False, use_activation=False):
        """Load the pretrained VGGish model from a local checkpoint directory.

        Params:
        -- local_path     : directory containing the local VGGish weights.
        -- use_pca        : keep the PCA postprocessing step if True.
        -- use_activation : keep the final activation layer if True.
        """
        self.model = VGGishlocal(local_path)
        if not use_pca:
            self.model.postprocess = False
        if not use_activation:
            # Drop the last (activation) layer so raw embeddings are returned.
            self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])
        self.model.eval()

    def get_embeddings(self, x, sr=16000):
        """
        Get embeddings using VGGish model.
        Params:
        -- x  : Either
                (i) a string which is the directory of a set of audio files, or
                (ii) a list of (np.ndarray audio samples, sample rate) tuples
                     as produced by load_audio_files.
        -- sr : Sampling rate, if x is a list of audio samples. Default value is 16000.
        Returns:
        -- np.ndarray stacking all embeddings; an empty array when nothing could
           be embedded (e.g. every file failed to load).
        """
        embd_lst = []
        if isinstance(x, list):
            try:
                # Rename the loop variable so it does not shadow the `sr` parameter.
                for audio, audio_sr in tqdm(x, disable=(not self.verbose)):
                    embd = self.model.forward(audio, audio_sr)
                    if self.model.device == torch.device('cuda'):
                        embd = embd.cpu()
                    embd_lst.append(embd.detach().numpy())
            except Exception as e:
                print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))
        elif isinstance(x, str):
            try:
                for fname in tqdm(os.listdir(x), disable=(not self.verbose)):
                    embd = self.model.forward(os.path.join(x, fname))
                    if self.model.device == torch.device('cuda'):
                        embd = embd.cpu()
                    embd_lst.append(embd.detach().numpy())
            except Exception as e:
                print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))
        else:
            raise AttributeError("x must be a list of audio samples or a directory path")
        # np.concatenate raises ValueError on an empty list, which would make the
        # empty-set checks in score() unreachable; return an empty array instead.
        if not embd_lst:
            return np.empty((0,))
        return np.concatenate(embd_lst, axis=0)

    def calculate_embd_statistics(self, embd_lst):
        """Return the (mean vector, covariance matrix) of an embedding set.

        Params:
        -- embd_lst : array-like of shape (num_samples, embedding_dim).
        """
        if isinstance(embd_lst, list):
            embd_lst = np.array(embd_lst)
        mu = np.mean(embd_lst, axis=0)
        sigma = np.cov(embd_lst, rowvar=False)
        return mu, sigma

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """
        Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
        Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.
        Params:
        -- mu1    : Numpy array containing the activations of a layer of the
                    inception net (like returned by the function 'get_predictions')
                    for generated samples.
        -- mu2    : The sample mean over activations, precalculated on an
                    representative data set.
        -- sigma1 : The covariance matrix over activations for generated samples.
        -- sigma2 : The covariance matrix over activations, precalculated on an
                    representative data set.
        -- eps    : Diagonal offset added when the covariance product is singular.
        Returns:
        -- : The Frechet Distance.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)
        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'
        diff = mu1 - mu2
        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real
        tr_covmean = np.trace(covmean)
        if self.verbose:
            # Diagnostic breakdown of the individual distance terms.
            print(f"diff^2:{diff.dot(diff)}, sigma1:{np.trace(sigma1)},sigma2:{np.trace(sigma2)},2 * tr_covmean{2 * tr_covmean}")
        return (diff.dot(diff) + np.trace(sigma1)
                + np.trace(sigma2) - 2 * tr_covmean)

    def load_audio_files(self, dir):
        """Load every .wav file under `dir` using a thread pool.

        load_audio_task resamples each file to SAMPLE_RATE.
        Returns a list of (wav_data, sample_rate) tuples.
        """
        task_results = []
        all_wav_files = glob(os.path.join(dir, "*.wav"))
        pool = ThreadPool(self.audio_load_worker)
        pbar = tqdm(total=len(all_wav_files), disable=(not self.verbose))

        def update(*a):
            # Advance the progress bar as each async load completes.
            pbar.update()

        if self.verbose:
            print("[Frechet Audio Distance] Loading audio from {}...".format(dir))
        for fname in all_wav_files:
            res = pool.apply_async(load_audio_task, args=(fname,), callback=update)
            task_results.append(res)
        pool.close()
        pool.join()
        # Each async result yields a (wav_data, sample_rate) tuple.
        return [k.get() for k in task_results]

    def score(self, background_dir, eval_dir, store_embds=False):
        """Compute the Frechet Audio Distance between two audio directories.

        Params:
        -- background_dir : directory with the reference ("background") wav files.
        -- eval_dir       : directory with the wav files to evaluate.
        -- store_embds    : save both embedding matrices as .npy files if True.
        Returns:
        -- the FAD score (float), or -1 on failure / empty input.
        """
        try:
            audio_background = self.load_audio_files(background_dir)
            audio_eval = self.load_audio_files(eval_dir)
            if self.verbose:
                print("audios len", len(audio_background), len(audio_eval))
            embds_background = self.get_embeddings(audio_background)  # (N, 128)
            embds_eval = self.get_embeddings(audio_eval)  # (M, 128)
            if store_embds:
                np.save("embds_background.npy", embds_background)
                np.save("embds_eval.npy", embds_eval)
            if len(embds_background) == 0:
                print("[Frechet Audio Distance] background set dir is empty, exitting...")
                return -1
            if len(embds_eval) == 0:
                print("[Frechet Audio Distance] eval set dir is empty, exitting...")
                return -1
            mu_background, sigma_background = self.calculate_embd_statistics(embds_background)
            mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
            fad_score = self.calculate_frechet_distance(
                mu_background,
                sigma_background,
                mu_eval,
                sigma_eval
            )
            return fad_score
        except Exception as e:
            print("[Frechet Audio Distance] exception thrown, {}".format(str(e)))
            return -1