"""This recipe runs inference with CLAP through a Gradio demo.

It supports models distilled with tinyCLAP (https://arxiv.org/abs/2311.14517).

Authors
 * Francesco Paissan 2024
"""

import sys

import gradio as gr
import speechbrain as sb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.metric_stats import MetricStats

torch.backends.cudnn.enabled = False

eps = 1e-10


class CLAPBrain(sb.Brain):
    def preprocess(self, wavs):
        """Pre-processes the waveforms into log-mel spectrogram features."""
        x = self.hparams.spectrogram_extractor(wavs)
        x = self.hparams.logmel_extractor(x)

        return x

    def prepare_txt_features(self, text):
        """Prepares the text features fed to the CLAP text encoder."""
        txt_inp = self.hparams.txt_tokenizer(
            text,
            max_length=self.hparams.text_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        return txt_inp

    def compute_sim(self, audio_embed, caption_embed):
        """Computes the CLAP similarity metric."""
        # Dot product between audio and caption embeddings; for L2-normalized
        # embeddings this is the cosine similarity reported by the demo.
        similarity = audio_embed @ caption_embed.t()

        return similarity

    def compute_forward(self, batch, stage):
        """Computes the shared-space embeddings for a batch of waveforms and captions."""
        if len(batch) == 2:
            wavs, caption = batch
        else:
            wavs, caption, _, _ = batch

        wavs = wavs.to(self.device).squeeze(1)

        # Log-mel features for the audio encoders.
        x_sb = self.preprocess(wavs)

        # Tokenized captions for the text encoder.
        text_inp = self.prepare_txt_features(caption)

        # Full (teacher) CLAP: projected text and audio embeddings.
        txt_shared, aud_shared = self.hparams.clap(
            x_sb,
            text_inp.input_ids.data,
            text_inp.token_type_ids.data,
            text_inp.attention_mask.data,
        )

        # When the full CLAP model is not among self.modules, run the distilled
        # (tinyCLAP) student audio encoder and L2-normalize its embedding.
        aud_shared_student = None
        if not hasattr(self.modules, "clap"):
            aud_shared_student, _, _ = self.modules.clap_student(x_sb)
            aud_shared_student = aud_shared_student / aud_shared_student.norm(
                dim=1, keepdim=True
            )

        return txt_shared, aud_shared, aud_shared_student


def audio_preprocess(x, sample_rate):
    """Loads an audio file, resamples it to `sample_rate` and downmixes it to mono."""
    tmp, sr = torchaudio.load(x)
    resample = T.Resample(sr, sample_rate)

    tmp = resample(tmp)
    # Sum the channels to obtain a mono waveform of shape [1, T].
    tmp = tmp.sum(0, keepdim=True)

    return tmp
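
# Example usage of `audio_preprocess` (the sample rate is illustrative; the demo
# takes it from `hparams["sample_rate"]`):
#
#     wav = audio_preprocess("./siren.wav", 44100)  # mono waveform, shape [1, T]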


def inference_wrapper(clap_brain):
    """Builds the Gradio inference function around a CLAPBrain instance."""

    @torch.no_grad()
    def f(wav_path, prompt):
        clap_brain.modules.eval()
        tmp = audio_preprocess(wav_path, clap_brain.hparams.sample_rate)

        # compute_forward returns (text_embed, teacher_audio_embed, student_audio_embed);
        # the similarity is computed between the student audio and text embeddings.
        ret = clap_brain.compute_forward([tmp, prompt], stage=sb.Stage.TEST)
        sim = clap_brain.compute_sim(ret[2], ret[0])

        return f"tinyCLAP similarity is: {round(sim.item(), 2)}"

    return f
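
# A hypothetical standalone call to the returned function (outside of Gradio):
#
#     infer = inference_wrapper(clap_brain)
#     infer("./siren.wav", "this is the sound of sirens wailing")
#     # -> "tinyCLAP similarity is: ..."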


if __name__ == "__main__":
    hparams_file = "hparams/inference.yaml"

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, {})
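
    # At this point `hparams` holds the objects instantiated from the YAML
    # (CLAP model, feature extractors, tokenizer, pretrainer, ...); the empty
    # dict passed to load_hyperpyyaml means no hyperparameter overrides.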
if hparams["use_tensorboard"]: |
|
|
from speechbrain.utils.train_logger import TensorboardLogger |
|
|
|
|
|
hparams["tensorboard_train_logger"] = TensorboardLogger( |
|
|
hparams["tensorboard_logs_folder"] |
|
|
) |
|
|
|
|
|
hparams["clap"].to(hparams["device"]) |
|
|
hparams["clap"].requires_grad_(False) |
|
|
hparams["clap"].eval() |
|
|
|
|

    if hparams["zs_eval"]:
        # Note: zero-shot evaluation relies on a `datasets` dict built by the
        # training recipe; it is not constructed in this inference script.
        hparams["class_list"] = datasets["train"].dataset.classes

    if hparams["audioenc_name_student"] is not None:
        if hparams["projection_only"]:
            print("Freezing Base AudioEncoder. Updating only the projection layers.")
            hparams["student_model"].base.requires_grad_(False)

    hparams["spectrogram_extractor"].to(hparams["device"])
    hparams["logmel_extractor"].to(hparams["device"])

    clap_brain = CLAPBrain(
        modules=hparams["modules"],
        hparams=hparams,
    )
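
    # `load_CLAP` (defined in the YAML) is expected to be a SpeechBrain Pretrainer:
    # collect_files() fetches the pretrained checkpoint (on the main process only,
    # via run_on_main) and load_collected() loads the parameters into the modules.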
    if hparams["pretrained_CLAP"] is not None:
        print("Loading CLAP model...")
        run_on_main(hparams["load_CLAP"].collect_files)
        hparams["load_CLAP"].load_collected()

    inference_api = inference_wrapper(clap_brain)

    examples_list = [
        ["./tunztunz_music.wav", "this is the sound of house music"],
        ["./siren.wav", "this is the sound of sirens wailing"],
        [
            "./whistling_and_chirping.wav",
            "someone is whistling while birds are chirping",
        ],
    ]

    demo = gr.Interface(
        fn=inference_api,
        inputs=[gr.Audio(type="filepath"), gr.Textbox()],
        outputs=["text"],
        examples=examples_list,
    )
    demo.launch()
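
    # Note: by default demo.launch() serves the interface locally; passing
    # share=True to launch() creates a temporary public link.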