File size: 3,695 Bytes
cef9e84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""For a given parent directory, it consideres all of their subdirectories as different experiments. For each experiment, it finds all subdirectories that start with "val" and compute the metrics on this subdirectory. 
The subdirectory should contain wav files with the same name as the test dataset directory.
"""

import os
import torch
from audioldm_eval import EvaluationHelper

# Evaluation sample rate in Hz expected by the audioldm_eval helper.
SAMPLE_RATE = 16000
# Single-GPU evaluation; always pins the first CUDA device.
device = torch.device(f"cuda:{0}")
# Shared evaluator instance used by evaluate_exp_performance() below.
evaluator = EvaluationHelper(SAMPLE_RATE, device)


def locate_yaml_file(path):
    """Return the path of the first YAML config file directly inside *path*.

    Args:
        path: Directory to scan (non-recursive).

    Returns:
        Joined path of the first file ending in ".yaml", or None when the
        directory contains no YAML file.
    """
    for file in os.listdir(path):
        # endswith avoids false positives such as "config.yaml.bak" that
        # the previous substring check (".yaml" in file) would match.
        if file.endswith(".yaml"):
            return os.path.join(path, file)
    return None


def is_evaluated(path):
    """Check whether a validation output folder has already been scored.

    Evaluation results are stored as JSON files in the parent experiment
    directory, named after the validation folder they belong to.

    Args:
        path: Path of a validation output folder.

    Returns:
        True if a sibling .json result file whose name contains the
        folder's basename exists, False otherwise.
    """
    folder_name = os.path.basename(path)
    # Scan the experiment directory (parent of the val_* folder) for
    # result files left behind by a previous evaluation run.
    for file in os.listdir(os.path.dirname(path)):
        # endswith avoids matching names that merely contain ".json";
        # single pass replaces the original candidate-list build.
        if file.endswith(".json") and folder_name in file:
            return True
    return False


def locate_validation_output(path):
    """Collect not-yet-evaluated validation folders of an experiment.

    Args:
        path: Experiment directory.

    Returns:
        List of sub-directory paths whose name starts with "val_" and for
        which no evaluation result exists yet (see is_evaluated).
    """
    folders = []
    for file in os.listdir(path):
        dirname = os.path.join(path, file)
        # The module contract says validation folders *start* with "val";
        # startswith avoids matching names like "eval_backup" that the
        # previous substring check ("val_" in file) would pick up.
        if file.startswith("val_") and os.path.isdir(dirname):
            if not is_evaluated(dirname):
                folders.append(dirname)
    return folders


def evaluate_exp_performance(exp_name):
    """Evaluate every pending validation folder of one experiment.

    Relies on module globals set in ``__main__``:
    ``latent_diffusion_model_log_path``, ``test_audio_path`` and the
    shared ``evaluator`` instance.

    Args:
        exp_name: Name of the experiment sub-directory to process.
    """
    abs_path_exp = os.path.join(latent_diffusion_model_log_path, exp_name)
    config_yaml_path = locate_yaml_file(abs_path_exp)

    # A missing yaml config means this directory is not an experiment.
    if config_yaml_path is None:
        print("[INFO] %s does not contain a yaml configuration file" % exp_name)
        return

    folders_todo = locate_validation_output(abs_path_exp)

    for folder in folders_todo:
        # Hoisted: the original re-ran os.listdir() up to four times here.
        num_files = len(os.listdir(folder))
        # Infer the test set from the number of generated wav files.
        # NOTE(review): a folder with exactly 5000 files is skipped —
        # confirm whether one of the boundaries should be inclusive.
        if 800 < num_files < 5000:
            test_dataset = "audiocaps"
        elif num_files > 5000:
            test_dataset = "musiccaps"
        else:
            print("[WARNING] skipping experiment", folder, " as it contains only", num_files, "files")
            continue

        test_audio_data_folder = os.path.join(test_audio_path, test_dataset)

        evaluator.main(folder, test_audio_data_folder)


def eval(exps):
    """Run the evaluation for each experiment, continuing past failures.

    NOTE: shadows the builtin ``eval``; the name is kept so existing
    callers keep working.

    Args:
        exps: Iterable of experiment directory names.
    """
    for experiment in exps:
        try:
            evaluate_exp_performance(experiment)
        except Exception as error:  # best effort: report and move on
            print(experiment, error)


if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(description="AudioLDM model evaluation")

    parser.add_argument(
        "-l", "--log_path", type=str, help="the log path", required=True
    )
    parser.add_argument(
        "-e",
        "--exp_name",
        type=str,
        help="the experiment name",
        required=False,
        default=None,
    )

    args = parser.parse_args()

    test_audio_path = "run_logs/genau/testset_data"
    latent_diffusion_model_log_path = args.log_path

    if latent_diffusion_model_log_path != "all":
        exp_name = args.exp_name
        if exp_name is None:
            exps = os.listdir(latent_diffusion_model_log_path)
            eval(exps)
        else:
            eval([exp_name])
    else:
        todo_list = [os.path.abspath("run_logs/genau")]
        for todo in todo_list:
            for latent_diffusion_model_log_path in os.listdir(todo):
                latent_diffusion_model_log_path = os.path.join(
                    todo, latent_diffusion_model_log_path
                )
                if not os.path.isdir(latent_diffusion_model_log_path):
                    continue
                print("[INFO] Evaluationg experiment:", latent_diffusion_model_log_path)
                exps = os.listdir(latent_diffusion_model_log_path)
                eval(exps)