import base64
import json
import logging
import os
import sys
import tempfile
import time
from types import SimpleNamespace

import gradio as gr
import numpy as np
import yaml

sys.path.insert(0, './')

# Heavy dependencies for the real inference path. They are commented out so
# the demo-mode UI can start without them; uncomment them together with the
# init_model_my() call below to enable actual decoding.
# import spaces
# import librosa
# import torch
# import torchaudio
# from wenet.utils.init_tokenizer import init_tokenizer
# from wenet.utils.init_model import init_model

def makedir_for_file(filepath):
    """Create the parent directory of `filepath` if it does not exist yet."""
    dirpath = os.path.dirname(filepath)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)


def load_dict_from_yaml(file_path: str):
    """Load a YAML file into a plain dict."""
    with open(file_path, 'rt', encoding='utf-8') as f:
        return yaml.load(f, Loader=yaml.FullLoader)

# Resolve paths relative to this script and inline the logo as Base64 so the
# footer HTML renders regardless of the working directory.
abs_path = os.path.abspath(__file__)
with open(os.path.join(os.path.dirname(abs_path), "lab.png"), "rb") as image_file:
    encoded_string = base64.b64encode(image_file.read()).decode("utf-8")

# Task -> prompt mapping. The dropdown keys are English; the prompt values are
# kept in Chinese because they are the instructions the OSUM model expects.
TASK_PROMPT_MAPPING = {
    "ASR (Automatic Speech Recognition)": "执行语音识别任务,将音频转换为文字。",
    "SRWT (Speech Recognition with Timestamps)": "请转录音频内容,并为每个英文词汇及其对应的中文翻译标注出精确到0.1秒的起止时间,时间范围用<>括起来。",
    "VED (Vocal Event Detection)(Categories: laugh, cough, cry, screaming, sigh, throat clearing, sneeze, other)": "请将音频转录为文字记录,并在记录末尾标注<音频事件>标签,音频事件共8种:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other。",
    "SER (Speech Emotion Recognition)(Categories: sad, anger, neutral, happy, surprise, fear, disgust, other)": "请将音频内容转录成文字记录,并在记录末尾标注<情感>标签,情感共8种:sad,anger,neutral,happy,surprise,fear,disgust,和other。",
    "SSR (Speaking Style Recognition)(Categories: news/science, horror story, fairy tale, customer service, poetry/prose, audiobook, daily speech, other)": "请将音频内容进行文字转录,并在最后添加<风格>标签,标签共8种:新闻科普、恐怖故事、童话故事、客服、诗歌散文、有声书、日常口语、其他。",
    "SGC (Speaker Gender Classification)(Categories: female, male)": "请将音频转录为文本,并在文本结尾处标注<性别>标签,性别为female或male。",
    "SAP (Speaker Age Prediction)(Categories: child, adult, old)": "请将音频转录为文本,并在文本结尾处标注<年龄>标签,年龄划分为child、adult和old三种。",
    "STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。"
}
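
# A small convenience sketch (not wired into the UI below, which inlines the
# same lookup): resolve a dropdown choice to the prompt string the model
# expects. The function name and the English fallback text are illustrative.
def resolve_task_prompt(task_choice: str, custom_prompt: str = "") -> str:
    if task_choice == "Custom Task Prompt":
        return custom_prompt
    return TASK_PROMPT_MAPPING.get(task_choice, "Unknown task type")
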
def init_model_my():
    """Download the OSUM checkpoint and build the model/tokenizer pair."""
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    config_path = "train.yaml"
    from huggingface_hub import hf_hub_download
    # Download the checkpoint (.pt) from the Hugging Face Hub.
    pt_file_path = hf_hub_download(repo_id="ASLP-lab/OSUM", filename="infer.pt")
    args = SimpleNamespace(**{
        "checkpoint": pt_file_path,
    })
    configs = load_dict_from_yaml(config_path)
    model, configs = init_model(args, configs)
    model = model.cuda()
    tokenizer = init_tokenizer(configs)
    print(model)
    return model, tokenizer
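
# A minimal device-selection sketch, assuming torch is importable once the
# real inference path is enabled; it would let init_model_my() fall back to
# CPU instead of failing on a GPU-less Space. The helper name is illustrative.
def pick_device():
    import torch
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
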
# Loading the real model is disabled so the demo UI can start without the GPU
# dependencies; uncomment to enable true inference.
# global_model, tokenizer = init_model_my()
print("app init done (demo mode: model not loaded)")

def do_resample(input_wav_path, output_wav_path):
    """Convert the input audio to 16 kHz mono and save it to `output_wav_path`."""
    print(f'input_wav_path: {input_wav_path}, output_wav_path: {output_wav_path}')
    waveform, sample_rate = torchaudio.load(input_wav_path)
    # Average multi-channel audio down to a single channel.
    num_channels = waveform.shape[0]
    if num_channels > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=16000)(waveform)
    makedir_for_file(output_wav_path)
    torchaudio.save(output_wav_path, waveform, 16000)
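
# A fallback sketch for environments where torchaudio is unavailable, assuming
# librosa and soundfile are installed (neither is required by the demo path
# above); it performs the same 16 kHz mono conversion.
def do_resample_librosa(input_wav_path, output_wav_path):
    import librosa
    import soundfile as sf
    # librosa resamples and downmixes to mono on load.
    waveform, _ = librosa.load(input_wav_path, sr=16000, mono=True)
    makedir_for_file(output_wav_path)
    sf.write(output_wav_path, waveform, 16000)
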
# @spaces.GPU
def true_decode_fuc(input_wav_path, input_prompt):
    """Run real inference on one utterance: resample, extract Whisper-style
    log-mel features, then generate text conditioned on the task prompt."""
    # input_prompt = TASK_PROMPT_MAPPING.get(input_prompt, "Unknown task type")
    print(f"wav_path: {input_wav_path}, prompt: {input_prompt}")
    timestamp_ms = int(time.time() * 1000)
    now_file_tmp_path_resample = os.path.join(
        tempfile.gettempdir(), f'{timestamp_ms}_resample.wav')
    do_resample(input_wav_path, now_file_tmp_path_resample)
    input_wav_path = now_file_tmp_path_resample
    waveform, sample_rate = torchaudio.load(input_wav_path)
    waveform = waveform.squeeze(0)  # (channel=1, sample) -> (sample,)
    print(f'waveform shape: {waveform.shape}, sample_rate: {sample_rate}')
    # 80-dim log-mel features: 25 ms window (400 samples) with a 10 ms hop
    # (160 samples) at 16 kHz, normalized the same way as Whisper.
    window = torch.hann_window(400)
    stft = torch.stft(waveform,
                      400,
                      160,
                      window=window,
                      return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    filters = torch.from_numpy(
        librosa.filters.mel(sr=sample_rate,
                            n_fft=400,
                            n_mels=80))
    mel_spec = filters @ magnitudes
    # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    feat = log_spec.transpose(0, 1)
    feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).cuda()
    feat = feat.unsqueeze(0).cuda()
    # feat = feat.half()
    # feat_lens = feat_lens.half()
    # Requires global_model from the init_model_my() call above.
    model = global_model.cuda()
    model.eval()
    res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt)[0]
    print("decode result:", res_text)
    return res_text
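
# Illustrative call of the real path (assumes the model was initialized and
# that "example.wav" exists; both are placeholders, not shipped assets):
# true_decode_fuc("example.wav",
#                 TASK_PROMPT_MAPPING["ASR (Automatic Speech Recognition)"])
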
def do_decode(input_wav_path, input_prompt):
    print(f'input_wav_path= {input_wav_path}, input_prompt= {input_prompt}')
    # Demo mode: the real decode call is disabled.
    # output_res = true_decode_fuc(input_wav_path, input_prompt)
    output_res = f"Test result: input_wav_path= {input_wav_path}, input_prompt= {input_prompt}"
    return output_res

def save_to_jsonl(if_correct, wav, prompt, res):
    data = {
        "if_correct": if_correct,
        "wav": wav,
        "task": prompt,
        "res": res
    }
    with open("results.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(data, ensure_ascii=False) + "\n")
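
# A companion sketch for inspecting the feedback log; assumes results.jsonl
# was written by save_to_jsonl above. The function name is illustrative.
def load_results(path="results.jsonl"):
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    return records
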
def download_audio(input_wav_path):
    # Hand the recorded file path to the DownloadButton, if there is one.
    if input_wav_path:
        return input_wav_path
    else:
        return None

# Custom CSS for the footer layout.
CSS = """
.custom-footer {
    position: fixed;
    bottom: 20px;            /* distance from the bottom of the page */
    left: 50%;
    transform: translateX(-50%);
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 20px;
    text-align: center;
    font-weight: bold;
    padding-bottom: 20px;    /* extra spacing at the bottom */
}
.custom-footer p {
    margin: 0;
}
.custom-footer img {
    height: 80px;
    width: auto;
}
"""

# Build the Gradio interface.
with gr.Blocks(css=CSS) as demo:
    # Title.
    gr.Markdown(
        """
        <div style="display: flex; align-items: center; justify-content: center; text-align: center;">
            <h1 style="font-family: 'Arial', sans-serif; color: #014377; font-size: 32px; margin-bottom: 0; display: inline-block; vertical-align: middle;">
                OSUM Speech Understanding Model Test
            </h1>
        </div>
        """
    )
    # Audio input and the output textbox.
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="Record", type="filepath")
        with gr.Column(scale=1, min_width=300):
            output_text = gr.Textbox(label="Output", lines=8, placeholder="The generated result will be displayed here...", interactive=False)
    # Task selection and the custom-prompt textbox.
    with gr.Row():
        task_dropdown = gr.Dropdown(
            label="Task",
            choices=list(TASK_PROMPT_MAPPING.keys()) + ["Custom Task Prompt"],
            value="ASR (Automatic Speech Recognition)"
        )
        custom_prompt_input = gr.Textbox(label="Custom Task Prompt", placeholder="Please enter a custom task prompt...", visible=False)
    # Buttons: download on the left, process on the right.
    with gr.Row():
        download_button = gr.DownloadButton("Download Recording", variant="secondary", elem_classes=["button-height", "download-button"])
        submit_button = gr.Button("Start to Process", variant="primary", elem_classes=["button-height", "submit-button"])
    # Feedback widgets, hidden until a result is produced.
    with gr.Row(visible=False) as confirmation_row:
        gr.Markdown("Please determine whether the result is correct:")
        confirmation_buttons = gr.Radio(
            choices=["Correct", "Incorrect"],
            label="",
            interactive=True,
            container=False,
            elem_classes="confirmation-buttons"
        )
        save_button = gr.Button("Submit Feedback", variant="secondary")
    # Footer.
    gr.HTML(
        f"""
        <div class="custom-footer">
            <div>
                <p>
                    <a href="http://www.nwpu-aslp.org/" target="_blank">Audio, Speech and Language Processing Group (ASLP@NPU)</a>,
                </p>
                <p>
                    Northwestern Polytechnical University
                </p>
                <p>
                    <a href="https://github.com/ASLP-lab/OSUM" target="_blank">GitHub</a>
                </p>
            </div>
            <img src="data:image/png;base64,{encoded_string}" alt="OSUM Logo">
        </div>
        """
    )
    # Event handlers.
    def show_confirmation(output_res, input_wav_path, input_prompt):
        return gr.update(visible=True), output_res, input_wav_path, input_prompt

    def save_result(if_correct, wav, prompt, res):
        save_to_jsonl(if_correct, wav, prompt, res)
        return gr.update(visible=False)

    def handle_submit(input_wav_path, task_choice, custom_prompt):
        try:
            if task_choice == "Custom Task Prompt":
                input_prompt = custom_prompt
            else:
                input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "Unknown task type")
            output_res = do_decode(input_wav_path, input_prompt)
            return output_res
        except Exception as e:
            print(f"Error in handle_submit: {e}")
            return "Error occurred. Please check the input."
    # Toggle the custom-prompt textbox when the task selection changes.
    task_dropdown.change(
        fn=lambda choice: gr.update(visible=choice == "Custom Task Prompt"),
        inputs=task_dropdown,
        outputs=custom_prompt_input
    )
    submit_button.click(
        fn=handle_submit,
        inputs=[audio_input, task_dropdown, custom_prompt_input],
        outputs=output_text
    ).then(
        fn=show_confirmation,
        inputs=[output_text, audio_input, task_dropdown],
        outputs=[confirmation_row, output_text, audio_input, task_dropdown]
    )
    download_button.click(
        fn=download_audio,
        inputs=[audio_input],
        outputs=[download_button]
    )
    save_button.click(
        fn=save_result,
        inputs=[confirmation_buttons, audio_input, task_dropdown, output_text],
        outputs=confirmation_row
    )

if __name__ == "__main__":
    demo.launch()
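
# If concurrent users are expected, Gradio's request queue can be enabled
# instead (a sketch, assuming the default queue settings are acceptable):
# demo.queue().launch()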