# File: genau-demo/GenAU/scripts/evaluate.py
# Repository page header: Moayed, "add demo files", commit cef9e84
"""For a given parent directory, this script considers each of its subdirectories as a separate experiment. For each experiment, it finds all subdirectories whose names contain "val_" and computes the metrics on each of them.
Each such subdirectory should contain wav files with the same names as the files in the test dataset directory.
"""
import os
import torch
from audioldm_eval import EvaluationHelper
# Sampling rate handed to the evaluator.
# NOTE(review): presumably in Hz -- confirm against audioldm_eval.
SAMPLE_RATE = 16000
# The f-string interpolates the literal 0, so this always names the device "cuda:0".
device = torch.device(f"cuda:{0}")
# Single module-level evaluator; evaluate_exp_performance calls evaluator.main on it.
# NOTE: it is constructed when the module is loaded, not inside the __main__ guard.
evaluator = EvaluationHelper(SAMPLE_RATE, device)
def locate_yaml_file(path):
    """Return the path of a YAML configuration file found directly in ``path``.

    Args:
        path: Directory to search. Sub-directories are not descended into.

    Returns:
        ``os.path.join(path, name)`` for the alphabetically first entry whose
        name ends with ``.yaml``, or ``None`` when there is no such entry.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``path`` cannot be listed.
    """
    # os.listdir returns entries in arbitrary order; sort so that the chosen
    # file is the same on every run when several candidates exist.
    for name in sorted(os.listdir(path)):
        # Match the extension rather than a substring, so that names such as
        # "config.yaml.bak" are not mistaken for a configuration file.
        if name.endswith(".yaml"):
            return os.path.join(path, name)
    return None
def is_evaluated(path):
    """Tell whether a result file already exists for the folder ``path``.

    The parent directory of ``path`` (the experiment folder) is listed, and
    the folder counts as evaluated when the name of any entry there contains
    both ".json" and the base name of ``path``.

    Args:
        path: Path of a validation output folder.

    Returns:
        ``True`` if a matching entry is found, ``False`` otherwise.
    """
    # Everything that lives next to the folder, i.e. inside the experiment folder.
    siblings = os.listdir(os.path.dirname(path))
    target = os.path.basename(path)
    return any(".json" in entry and target in entry for entry in siblings)
def locate_validation_output(path):
    """Collect the validation folders of an experiment that still need evaluating.

    Args:
        path: Path of the experiment folder.

    Returns:
        A list with the joined path of every entry of ``path`` whose name
        contains "val_", that is a directory, and for which ``is_evaluated``
        returns a false value. Entries appear in ``os.listdir`` order.
    """
    children = ((name, os.path.join(path, name)) for name in os.listdir(path))
    return [
        full_path
        for name, full_path in children
        # The cheap name test runs first; is_evaluated is only reached for
        # directories whose name matched.
        if "val_" in name and os.path.isdir(full_path) and not is_evaluated(full_path)
    ]
def evaluate_exp_performance(exp_name):
    """Evaluate every pending validation folder of one experiment.

    The experiment folder is ``latent_diffusion_model_log_path/exp_name``.
    Experiments without a yaml file (see ``locate_yaml_file``) are skipped
    with an ``[INFO]`` message. For every folder returned by
    ``locate_validation_output`` the reference dataset is chosen from the
    number of entries in that folder:

    * more than 800 and fewer than 5000 entries -> "audiocaps"
    * 5000 entries or more                      -> "musiccaps"
    * otherwise the folder is skipped with a ``[WARNING]`` message

    NOTE(review): the thresholds presumably reflect the sizes of the two test
    sets -- confirm before changing them.

    Relies on the module-level names ``latent_diffusion_model_log_path``,
    ``test_audio_path`` and ``evaluator``; the first two are assigned in the
    ``__main__`` section of this script.

    Args:
        exp_name: Name of the experiment sub-directory.

    Returns:
        None.
    """
    abs_path_exp = os.path.join(latent_diffusion_model_log_path, exp_name)
    if locate_yaml_file(abs_path_exp) is None:
        print("[INFO] %s does not contain a yaml configuration file" % exp_name)
        return

    for folder in locate_validation_output(abs_path_exp):
        # List the folder once; the count is needed for every comparison and
        # for the warning message.
        num_files = len(os.listdir(folder))
        if 800 < num_files < 5000:
            test_dataset = "audiocaps"
        elif num_files >= 5000:
            # ">=" closes the gap at exactly 5000 entries, which previously
            # matched neither branch and was reported as "contains only 5000 files".
            test_dataset = "musiccaps"
        else:
            print("[WARNING] skipping experiment", folder, " as it contains only", num_files, "files")
            continue
        test_audio_data_folder = os.path.join(test_audio_path, test_dataset)
        evaluator.main(folder, test_audio_data_folder)
def eval(exps):
    """Run ``evaluate_exp_performance`` for each experiment name in ``exps``.

    A failure in one experiment is reported by printing the experiment name
    followed by the exception, and the remaining experiments are still
    processed.

    NOTE: this function's name shadows the built-in ``eval`` inside this module.

    Args:
        exps: Iterable of experiment names.

    Returns:
        None.
    """
    for experiment in exps:
        try:
            evaluate_exp_performance(experiment)
        except Exception as err:
            # Best effort: keep going with the next experiment.
            print(experiment, err)
if __name__ == "__main__":
    # Command-line entry point.
    #   -l/--log_path  directory holding the experiment folders, or the literal
    #                  "all" to walk every directory under run_logs/genau
    #   -e/--exp_name  optional single experiment to evaluate
    import argparse

    parser = argparse.ArgumentParser(description="AudioLDM model evaluation")
    parser.add_argument("-l", "--log_path", type=str, required=True, help="the log path")
    parser.add_argument(
        "-e", "--exp_name", type=str, required=False, default=None, help="the experiment name"
    )
    args = parser.parse_args()

    # Both names are module globals that evaluate_exp_performance reads.
    test_audio_path = "run_logs/genau/testset_data"
    latent_diffusion_model_log_path = args.log_path

    if latent_diffusion_model_log_path == "all":
        for todo in [os.path.abspath("run_logs/genau")]:
            for entry in os.listdir(todo):
                # Rebind the global so the evaluation resolves experiment
                # names relative to the log directory currently being visited.
                latent_diffusion_model_log_path = os.path.join(todo, entry)
                if not os.path.isdir(latent_diffusion_model_log_path):
                    continue
                print("[INFO] Evaluationg experiment:", latent_diffusion_model_log_path)
                eval(os.listdir(latent_diffusion_model_log_path))
    elif args.exp_name is None:
        # No experiment given: evaluate every entry of the log directory.
        eval(os.listdir(latent_diffusion_model_log_path))
    else:
        eval([args.exp_name])