import os

import torch.nn.functional as F
import torchaudio
from loguru import logger
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import yaml
# ---------- Settings ----------
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cuda' if GPU_ID != '-1' else 'cpu'
SERVER_PORT = 42208
SERVER_NAME = "0.0.0.0"
SSL_DIR = './keyble_ssl'
FS = 16000
resamplers = {}
MIN_REQUIRED_WAV_LENGTH = 1040

# EXAMPLE_DIR = './examples'
# en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
# jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
# zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))
# ---------- Logging ----------
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Download models ----------
logger.info('============================= Download models ===========================')
model_paths = {
    "SSL-MOS, all training sets": {
        "ckpt": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl"),
        "config": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml"),
    }
}
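
# NOTE: hf_hub_download fetches each file from the Hugging Face Hub (or reuses the
# local cache if it is already present) and returns the path to the cached copy, so
# the entries in model_paths are plain local file paths by the time the setup loop
# below runs.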
# ---------- Model ----------
models = {}
for name, path_dict in model_paths.items():
    logger.info(f'============================= Setting up model for {name} =============')
    checkpoint_path = path_dict["ckpt"]
    config_path = path_dict["config"]
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    if config["model_type"] == "SSLMOS":
        from models.sslmos import SSLMOS
        model = SSLMOS(
            config["model_input"],
            num_listeners=config.get("num_listeners", None),
            num_domains=config.get("num_domains", None),
            **config["model_params"],
        ).to(DEVICE)
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
    model = model.eval().to(DEVICE)
    logger.info(f"Loaded model parameters from {checkpoint_path}.")
    models[name] = model
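
# For reference, a rough sketch of the config.yml fields this script relies on.
# The field names below are the ones read above and in predict(); the example
# values are illustrative assumptions, not the actual contents of the downloaded file:
#
#   model_type: SSLMOS              # selects the model class to import
#   model_input: waveform           # key used to build the inputs dict in predict()
#   inference_mode: mean_listener   # or mean_net
#   model_params: {...}             # keyword arguments forwarded to SSLMOS(...)
#   num_listeners: ...              # optional
#   num_domains: ...                # optional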
def read_wav(wav_path):
    # read waveform
    waveform, sample_rate = torchaudio.load(
        wav_path, channels_first=False
    )  # waveform: [T, 1]

    # resample if needed
    if sample_rate != FS:
        resampler_key = f"{sample_rate}-{FS}"
        if resampler_key not in resamplers:
            resamplers[resampler_key] = torchaudio.transforms.Resample(
                sample_rate, FS, dtype=waveform.dtype
            )
        waveform = resamplers[resampler_key](waveform)

    waveform = waveform.squeeze(-1)

    # always pad to a minimum length
    if waveform.shape[0] < MIN_REQUIRED_WAV_LENGTH:
        to_pad = (MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]) // 2
        waveform = F.pad(waveform, (to_pad, to_pad), "constant", 0)

    # the waveform is now at the target rate, so return FS rather than the
    # original sample_rate
    return waveform, FS
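
# Padding note: MIN_REQUIRED_WAV_LENGTH = 1040 samples is only 65 ms at FS = 16 kHz
# (1040 / 16000 = 0.065 s), so this guard merely keeps extremely short recordings from
# reaching the downstream feature extractor (an assumption about its purpose; the exact
# constant comes from the original script).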
def predict(model_name, wav_file):
    x, fs = read_wav(wav_file)
    logger.info('wav file loaded')

    # set up model input
    # NOTE: `config` is the module-level config left over from the model setup loop
    # above; with a single entry in model_paths it always matches `model_name`.
    model_input = x.unsqueeze(0).to(DEVICE)
    model_lengths = model_input.new_tensor([model_input.size(1)]).long()
    inputs = {
        config["model_input"]: model_input,
        config["model_input"] + "_lengths": model_lengths,
    }

    with torch.no_grad():
        # model forward
        if config["inference_mode"] == "mean_listener":
            outputs = models[model_name].mean_listener_inference(inputs)
        elif config["inference_mode"] == "mean_net":
            outputs = models[model_name].mean_net_inference(inputs)
        pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0]

    return pred_mean_scores
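
# Offline usage sketch (not part of the original app): since gr.Audio is configured
# with type='filepath', predict() receives an ordinary path string and can also be
# called directly for quick scoring outside the UI, e.g.
#
#     score = predict("SSL-MOS, all training sets", "path/to/sample.wav")  # hypothetical file
#     print(f"Predicted MOS: {score:.2f}")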
with gr.Blocks(title="SHEET: Speech Human Evaluation Estimation Toolkit demo") as demo:
    gr.Markdown(
        """
# Demo for SHEET: Speech Human Evaluation Estimation Toolkit
### [[Paper (arXiv)]](https://arxiv.org/abs/2411.03715) [[Code]](https://github.com/unilight/sheet)

**SHEET** is a toolkit for conducting subjective speech quality assessment (SSQA) research. It is designed to work with MOS-Bench, a collection of datasets for benchmarking SSQA models.

In this demo, you can record your own voice or upload speech files to assess their quality.
"""
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Record your speech here!")
            input_wav = gr.Audio(label="Input speech", type='filepath')
            gr.Markdown("## Select a model!")
            model_name = gr.Radio(label="Model", choices=list(model_paths.keys()))
            evaluate_btn = gr.Button(value="Evaluate!")
            # gr.Markdown("### You can use these examples if using a microphone is too troublesome!")
            # gr.Markdown("I recorded the samples using my Macbook Pro, so there might be some noises.")
            # gr.Examples(
            #     examples=en_examples,
            #     inputs=input_wav,
            #     label="English examples"
            # )
            # gr.Examples(
            #     examples=jp_examples,
            #     inputs=input_wav,
            #     label="Japanese examples"
            # )
            # gr.Examples(
            #     examples=zh_examples,
            #     inputs=input_wav,
            #     label="Mandarin examples"
            # )
        with gr.Column():
            gr.Markdown("## The predicted score is here:")
            output_score = gr.Textbox(label="Prediction", interactive=False)

    evaluate_btn.click(predict, [model_name, input_wav], output_score)
if __name__ == '__main__':
    try:
        demo.launch(debug=True)
    except KeyboardInterrupt as e:
        print(e)
    finally:
        demo.close()
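
# NOTE: SERVER_PORT, SERVER_NAME, and SSL_DIR are defined in the settings block but are
# never passed to demo.launch() above. A possible wiring (an assumption, not what the
# original script does) would be:
#
#     demo.launch(server_name=SERVER_NAME, server_port=SERVER_PORT, debug=True)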