# Hugging Face Spaces status at capture time: Runtime error
import importlib
from types import SimpleNamespace

import gradio as gr
import pandas as pd
import spaces
import torch

from utmosv2.utils import get_dataset, get_model
# Markdown header rendered at the top of the demo page.
# NOTE(review): the "π" in the title looks like a mis-encoded emoji — confirm
# the intended character.
description = (
    "# π UTMOSv2 demo\n\n"
    # The link needs visible text; "[](url)" renders as nothing in Markdown.
    "[GitHub](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)
# All tensors and models below are placed on CUDA; there is no CPU fallback.
device = torch.device("cuda")

# Copy every non-dunder attribute of the fusion_stage3 config module into a
# mutable namespace, with the demo's inference-time settings layered on top
# (later keys in the merged dict win over values from the module).
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{
        **{
            name: value
            for name, value in vars(config).items()
            if not name.startswith("__")
        },
        "reproduce": False,
        "config": "fusion_stage3",
        "print_config": False,
        "data_config": None,
        "phase": "inference",
        "num_workers": 1,
    }
)
# On Hugging Face ZeroGPU hardware a GPU is only attached while a function
# decorated with @spaces.GPU runs; elsewhere the decorator has no effect.
@spaces.GPU
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
    """Predict the mean opinion score (MOS) of a speech file with UTMOSv2.

    The fusion_stage3 checkpoint of each of the five folds is loaded in turn.
    For every fold the test dataset is rebuilt and the model is run five times,
    or once when ``quick`` is set. The result is the mean over all
    fold x repetition predictions, so quick mode still averages across every
    fold.

    Args:
        audio_path: Path of the uploaded audio file (from
            ``gr.Audio(type="filepath")``); ``None`` if nothing was uploaded.
        domain: Data-domain ID stored in the ``dataset`` column of the input.
        quick: Run a single repetition per fold instead of five.

    Returns:
        The averaged MOS prediction as a Python float.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not audio_path:
        raise gr.Error("Please upload a .wav file before submitting.")

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    # Dummy target value; presumably required by get_dataset — TODO confirm.
    data["mos"] = 0

    num_folds = 5
    num_repetitions = 1 if quick else 5
    total = 0.0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for fold in range(num_folds):
            # get_model is handed the shared cfg after the fold/checkpoint are set.
            cfg.now_fold = fold
            cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
            model = get_model(cfg, device).eval()
            for _ in range(num_repetitions):
                # Rebuilt every repetition; per the UI text each repetition
                # uses a different randomly selected frame of the waveform.
                test_dataset = get_dataset(cfg, data, "test")
                # [:-1] drops the last item of the sample (presumably the
                # target) and adds a batch dimension to each model input.
                inputs = [
                    torch.tensor(t, dtype=torch.float32).unsqueeze(0).to(device)
                    for t in test_dataset[0][:-1]
                ]
                p = model(*inputs)
                total += float(p.cpu().numpy()[0][0])
    return total / (num_folds * num_repetitions)
# Gradio UI: audio upload, data-domain selector and quick-mode toggle on the
# left; the predicted MOS on the right.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
                value="sarulab",
            )
            quick = gr.Checkbox(
                label="Quick prediction",
                value=True,
                # Adjacent literals concatenate into a single str. No trailing
                # comma: it would turn this into a 1-tuple, not a string.
                info=(
                    "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
                    "of the input speech waveform for all five folds. "
                    "To make quick predictions by reducing this to a single repetition, "
                    "check this checkbox:"
                ),
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    # Event listeners must be registered inside the Blocks context.
    submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])

demo.queue().launch()