# GenAU/scripts/generate_and_eval.py
"""For a given parent directory, it consideres all of their subdirectories as different experiments. For each experiment, it finds all subdirectories that start with "val" and compute the metrics on this subdirectory.
The subdirectory should contain wav files with the same name as the test dataset directory.
"""
import os
import torch
import shutil
from audioldm_eval import EvaluationHelper
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader
from src.utilities.data.videoaudio_dataset import VideoAudioDataset
from src.tools.training_utils import (
get_restore_step,
copy_test_subset_data,
)
from src.utilities.model.model_util import instantiate_from_config
from src.tools.configuration import Configuration
SAMPLE_RATE = 16000
evaluator = EvaluationHelper(SAMPLE_RATE, torch.device("cuda:0"))
def locate_yaml_file(path):
    """Return the path of the first .yaml file directly under `path`, or None."""
    for file in os.listdir(path):
        if file.endswith(".yaml"):
            return os.path.join(path, file)
    return None
def is_evaluated(path):
    """Return True if a metrics .json file already exists for this validation folder."""
    candidates = []
    for file in os.listdir(
        os.path.dirname(path)
    ):  # all files inside the experiment folder
        if file.endswith(".json"):
            candidates.append(file)
folder_name = os.path.basename(path)
for candidate in candidates:
if folder_name in candidate:
return True
return False
def locate_validation_output(path):
    """Return all val_* subdirectories of `path` that have not been evaluated yet."""
    folders = []
    for file in os.listdir(path):
        dirname = os.path.join(path, file)
        if file.startswith("val_") and os.path.isdir(dirname):
if not is_evaluated(dirname):
folders.append(dirname)
return folders
def evaluate_exp_performance(exp_name, evaluation_dataset=None):
    """Evaluate every pending validation output folder of a single experiment."""
    abs_path_exp = os.path.join(latent_diffusion_model_log_path, exp_name)
    config_yaml_path = locate_yaml_file(abs_path_exp)
    if config_yaml_path is None:
        print("%s does not contain a yaml configuration file" % exp_name)
        return
folders_todo = locate_validation_output(abs_path_exp)
    for folder in folders_todo:
        # use the dataset given on the command line if any; otherwise infer it
        # from the number of generated files
        if evaluation_dataset is not None:
            test_dataset = evaluation_dataset
        elif len(os.listdir(folder)) == 964:
            test_dataset = "audiocaps"
        elif len(os.listdir(folder)) > 5000:
            test_dataset = "musiccaps"
        else:
            print(f"[Warning, generate_and_eval.py] cannot identify test dataset name at folder {folder}")
            print(f"[INFO] Skipping folder {folder}")
            continue
test_audio_data_folder = os.path.join(test_audio_path, test_dataset)
evaluator.main(folder, test_audio_data_folder)
@torch.no_grad()
def generate_test_audio(configs, config_yaml_path, exp_group_name,
exp_name, use_wav_cond=False, strategy='wo_ema',
batch_size=244, n_candidates_per_samples=1, ckpt=None,
evaluation_dataset=None):
if "seed" in configs.keys():
seed_everything(configs["seed"])
else:
print("SEED EVERYTHING TO 0")
seed_everything(0)
if "precision" in configs['training'].keys():
torch.set_float32_matmul_precision(
configs['training']["precision"]
) # highest, high, medium
log_path = configs['logging']["log_directory"]
if "dataloader_add_ons" in configs["data"].keys():
dataloader_add_ons = configs["data"]["dataloader_add_ons"]
else:
dataloader_add_ons = []
    # NOTE: the batch size from the config overrides the batch_size argument
    batch_size = configs["model"]["params"]["batchsize"]
# set up evaluation parameters
configs['model']['params']['evaluation_params']['n_candidates_per_samples'] = n_candidates_per_samples
if ckpt is not None:
configs["reload_from_ckpt"] = ckpt
val_dataset = VideoAudioDataset(configs, split="test", add_ons=dataloader_add_ons, load_video=False, load_audio=True, sample_single_caption=True)
val_loader = DataLoader(
val_dataset,
num_workers=12, # configs['data'].get('num_workers', 12),
batch_size=max(1, batch_size // configs['model']['params']['evaluation_params']['n_candidates_per_samples']),
)
config_reload_from_ckpt = configs.get("reload_from_ckpt", None)
checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
wandb_path = os.path.join(log_path, exp_group_name, exp_name)
os.makedirs(checkpoint_path, exist_ok=True)
shutil.copy(config_yaml_path, wandb_path)
if config_reload_from_ckpt is not None:
resume_from_checkpoint = config_reload_from_ckpt
        try:
            n_step = int(resume_from_checkpoint.split(".ckpt")[0].split("step=")[1])
        except (IndexError, ValueError):
            print("[Warning] cannot extract model step from the checkpoint filename, using UNK")
            n_step = "UNK"
print("Reload given checkpoint %s" % resume_from_checkpoint)
elif len(os.listdir(checkpoint_path)) > 0:
print("Load checkpoint from path: %s" % checkpoint_path)
restore_step, n_step = get_restore_step(checkpoint_path)
resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
print("Resume from checkpoint", resume_from_checkpoint)
else:
raise "Please specify a pre-trained checkpoint"
guidance_scale = configs["model"]["params"]["evaluation_params"][
"unconditional_guidance_scale"
]
ddim_sampling_steps = configs["model"]["params"]["evaluation_params"][
"ddim_sampling_steps"
]
n_candidates_per_samples = configs["model"]["params"]["evaluation_params"][
"n_candidates_per_samples"
]
configs['model']['params']['ckpt_path'] = resume_from_checkpoint
# change log directory
configs['logging']["log_directory"] = configs['logging']["log_directory"].replace('train', 'evaluation')
latent_diffusion = instantiate_from_config(configs["model"])
latent_diffusion.set_log_dir(configs['logging']["log_directory"], exp_group_name, exp_name)
latent_diffusion.eval()
latent_diffusion = latent_diffusion.cuda()
if use_wav_cond:
latent_diffusion.random_clap_condition(text_prop=0.0)
name = latent_diffusion.get_validation_folder_name(guidance_scale, ddim_sampling_steps, n_candidates_per_samples, step=n_step, tag=strategy)
if strategy == 'wo_ema':
print("[INFO] Using No EMA strategy")
latent_diffusion.use_ema = False
latent_diffusion.name = name
latent_diffusion.unconditional_guidance_scale = guidance_scale
latent_diffusion.ddim_sampling_steps = ddim_sampling_steps
latent_diffusion.n_gen = n_candidates_per_samples
latent_diffusion.generate_sample(
val_loader,
name=name,
unconditional_guidance_scale=guidance_scale,
ddim_steps=ddim_sampling_steps,
n_gen=n_candidates_per_samples,
use_ema=(strategy != 'wo_ema')
)
    # copy the test data if it does not already exist
if evaluation_dataset is not None:
assert evaluation_dataset==val_dataset.dataset_name, f"[ERROR, generate_and_eval.py] the given evaluation dataset {evaluation_dataset} and the specified dataset_name of the test dataset {val_dataset.dataset_name} do not match."
test_data_subset_folder = os.path.join(
os.path.dirname(configs['logging']["log_directory"]),
"testset_data",
val_dataset.dataset_name,
)
os.makedirs(test_data_subset_folder, exist_ok=True)
copy_test_subset_data(val_dataset, test_data_subset_folder)
def evaluate_all(exps, evaluation_dataset=None):
    """Run evaluation for each experiment, reporting but not raising failures.

    Renamed from `eval` to avoid shadowing the builtin.
    """
    for exp in exps:
        try:
            evaluate_exp_performance(exp, evaluation_dataset=evaluation_dataset)
        except Exception as e:
            print(f"[ERROR, generate_and_eval.py] evaluation of {exp} failed: {e}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="AudioLDM model evaluation")
parser.add_argument(
"-c",
"--config_yaml",
type=str,
required=True,
help="path to config .yaml file",
)
parser.add_argument(
"--evaluation_dataset",
type=str,
default=None,
required=False,
help="target dataset to run the evaluation on",
)
parser.add_argument(
"-s",
"--strategy",
type=str,
required=False,
default='wo_ema',
help="The strategy of combining weights from different checkpoint: wo_ema, avg_ckpt, or ema",
)
parser.add_argument(
"-b",
"--batch_size",
type=int,
required=False,
default=64
)
parser.add_argument(
"-nc",
"--n_candidates_per_samples",
type=int,
required=False,
default=1,
help="Normally set it to 1, "
)
    parser.add_argument(
        "-ckpt",
        type=str,
        default=None,
        help="path to a specific checkpoint to reload",
    )
args = parser.parse_args()
assert args.strategy in ['wo_ema', 'avg_ckpt', 'ema']
config_yaml = args.config_yaml
    exp_name = os.path.splitext(os.path.basename(config_yaml))[0]
exp_group_name = os.path.basename(os.path.dirname(config_yaml))
    config_yaml_path = config_yaml
configuration = Configuration(config_yaml_path)
configs = configuration.get_config()
# generate audio
generate_test_audio(configs, config_yaml_path, exp_group_name, exp_name,
strategy=args.strategy, batch_size=args.batch_size,
n_candidates_per_samples=args.n_candidates_per_samples, ckpt=args.ckpt,
evaluation_dataset=args.evaluation_dataset)
test_audio_path = os.path.join(
os.path.dirname(configs['logging']["log_directory"]),
"testset_data"
)
latent_diffusion_model_log_path = os.path.join(configs['logging']["log_directory"], exp_group_name)
    # copy the config file into the evaluation log directory
    shutil.copy(config_yaml_path, os.path.join(configs['logging']["log_directory"], exp_group_name, exp_name))
    evaluate_all([exp_name], evaluation_dataset=args.evaluation_dataset)
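# Example invocation (illustrative; the config path and flag values are
# assumptions about your local setup, not fixed requirements):
#
#   python scripts/generate_and_eval.py \
#       -c configs/<exp_group_name>/<exp_name>.yaml \
#       --evaluation_dataset audiocaps -s wo_ema -b 64 -nc 1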