from pathlib import Path
from functools import partial

import gradio as gr
import torch
from torchaudio.functional import resample

import utils.train_util as train_util

def load_model(cfg, ckpt_path, device):
    model = train_util.init_model_from_config(cfg["model"])
    # Load the checkpoint on CPU first, then move the model to the target device.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    model = model.to(device)
    tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
    if not tokenizer.loaded:
        tokenizer.load_state_dict(ckpt["tokenizer"])
    # Tell the model which token ids mark sequence start, end and padding.
    model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
    return model, tokenizer
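
# Sketch of the config.yaml layout that load_model/InferRunner assume; the
# nested specs are placeholders, and only these three top-level keys are read:
#
#   model:     {...}    # spec consumed by train_util.init_model_from_config
#   tokenizer: {...}    # spec consumed by train_util.init_obj_from_dict
#   target_sr: 16000    # hypothetical value; the sample rate the model expects
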
def infer(file, runner):
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Gradio delivers integer PCM; normalize to float in [-1, 1).
    if wav.dtype == torch.short:   # int16
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:   # int32
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
    # Resample to the rate the model was trained on.
    wav = resample(wav, sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0).to(runner.device)
    input_dict = {
        "mode": "inference",
        "wav": wav,
        "wav_len": [wav_len],
        "specaug": False,
        "sample_method": "beam",
        "beam_size": 3,
    }
    with torch.no_grad():
        output_dict = runner.model(input_dict)
    seq = output_dict["seq"].cpu().numpy()
    cap = runner.tokenizer.decode(seq)[0]
    return cap
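
# Offline usage sketch (not executed by this app): load a clip yourself and
# call `infer` directly. The soundfile dependency and the "audio.wav" path
# are assumptions for illustration.
#
#   import soundfile as sf
#   runner = InferRunner("AudioCaps")
#   wav, sr = sf.read("audio.wav", dtype="int16")
#   print(infer((sr, wav), runner))
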
# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)
class InferRunner:
    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.change_model(model_name)

    def change_model(self, model_name):
        exp_dir = Path(f"./checkpoints/{model_name.lower()}")
        cfg = train_util.load_config(exp_dir / "config.yaml")
        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
        self.target_sr = cfg["target_sr"]
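
# `btn.click` below binds the runner object via functools.partial, so the
# radio callback only needs to update that same object in place; no rebinding
# of the global is required.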
def change_model(radio):
    infer_runner.change_model(radio)

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")
    with gr.Row():
        gr.Markdown("""
        [arXiv](https://arxiv.org/abs/2407.14329)
        [GitHub](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
        """)
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            radio.change(fn=change_model, inputs=[radio])
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
            btn.click(
                fn=partial(infer, runner=infer_runner),
                inputs=[file],
                outputs=output
            )
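
# launch() also accepts options such as share=True or server_name="0.0.0.0"
# when the demo must be reachable beyond localhost; the defaults are fine on
# Hugging Face Spaces.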
demo.launch()