import gradio as gr
import random
import torch
import logging
import os
import time
import json
import tempfile
import uuid
from pathlib import Path

import soundfile as sf
import torchaudio
import hydra
from omegaconf import DictConfig, ListConfig, OmegaConf
from huggingface_hub import snapshot_download as hf_snapshot_download
# from modelscope import snapshot_download as ms_snapshot_download
import spaces

from slam_llm.utils.model_utils import get_custom_model_factory
from slam_llm.utils.dataset_utils import get_preprocessed_dataset
from examples.tts.utils.codec_utils import audio_decode_cosyvoice
from examples.tts.speech_dataset_tts_gradio import SpeechDatasetOnline

# --- 1. Global configuration and caches ---
# Only the configuration is loaded at module scope; the model itself is loaded lazily.
# Global model cache to avoid reloading
MODEL_CACHE = {}
TOKENIZER_CACHE = {}

config_path = "config/"
train_config = OmegaConf.load(os.path.join(config_path, "train_config.yaml"))
model_config = OmegaConf.load(os.path.join(config_path, "model_config.yaml"))
dataset_config = OmegaConf.load(os.path.join(config_path, "dataset_config.yaml"))
decode_config = OmegaConf.load(os.path.join(config_path, "decode_config.yaml"))
kwargs = OmegaConf.load(os.path.join(config_path, "kwargs.yaml"))

# Downloading the model files at startup is safe (no CUDA initialization involved).
if not Path(kwargs.ckpt_path).exists():
    print("--- Model checkpoint not found, downloading from Hugging Face Hub... ---")
    hf_snapshot_download(repo_id="yhaha/EmoVoice", local_dir="./ckpts/")
    print("--- Model download complete. ---")


# --- 2. Model loading function (lazy loading) ---
def load_model():
    global MODEL_CACHE, TOKENIZER_CACHE
    if "EmoVoice_1.5B" not in MODEL_CACHE:
        print("--- Initializing model... ---")
        # All CUDA- and model-initialization code lives here so that nothing
        # touches the GPU at import time.
        torch.cuda.manual_seed(train_config.seed)
        torch.manual_seed(train_config.seed)
        random.seed(train_config.seed)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_factory = get_custom_model_factory(model_config)
        model, tokenizer = model_factory(train_config, model_config, **kwargs)
        model.to(device)
        model.eval()
        print("--- Model loaded successfully to device:", device, "---")

        MODEL_CACHE["EmoVoice_1.5B"] = model
        TOKENIZER_CACHE["EmoVoice_1.5B"] = tokenizer
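
# --- Optional: cache accessor sketch ---
# A small convenience wrapper (hypothetical, not used elsewhere in this app)
# that makes the lazy-loading contract explicit: load on first access, then reuse.
def get_model_and_tokenizer(name="EmoVoice_1.5B"):
    if name not in MODEL_CACHE:
        load_model()
    return MODEL_CACHE[name], TOKENIZER_CACHE[name]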

# --- 3. Gradio inference function ---
@spaces.GPU
@torch.inference_mode()
def infer_tts(text, emotion_text_prompt, neutral_speaker_wav):
    """
    Online inputs, for example:
        text = "The kettle SCREAMED as it reached boiling point, mirroring my inner tension."
        emotion_text_prompt = "Parallel emotions with rising heat, an audible cry of pent emotion."
        neutral_speaker_wav = "/nfs/yangguanrou.ygr/data/gpt4o_third/output_gpt4o/neutral/gpt4o_23948_neutral_ash.wav"
    """
    # Fetch the already-loaded model from the cache.
    model = MODEL_CACHE["EmoVoice_1.5B"]
    tokenizer = TOKENIZER_CACHE["EmoVoice_1.5B"]
    codec_decoder = model.codec_decoder
    device = next(model.parameters()).device  # read the correct device directly from the model

    output_path = "output/"
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    dataset = SpeechDatasetOnline(dataset_config, tokenizer)
    batch = dataset.prepare_input(text, emotion_text_prompt, neutral_speaker_wav)

    task_type = decode_config.task_type
    code_layer = model_config.vocab_config.code_layer
    code_type = model_config.code_type
    num_latency_tokens = dataset_config.num_latency_tokens
    modeling_paradigm = dataset_config.modeling_paradigm
    interleaved_text_token_num = dataset_config.interleaved_text_token_num
    interleaved_audio_token_num = dataset_config.interleaved_audio_token_num

    output_text_only = kwargs.get("output_text_only", False)
    speech_sample_rate = kwargs.get("speech_sample_rate", 24000)
    audio_prompt_path = kwargs.get("audio_prompt_path", None)
    tone_dir = "neutral_prompt_speech"

    # Move all tensors in the batch to the model's device.
    for key in batch:
        batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
    audio_prompt_path = batch["neutral_speaker_wav"][0]

    if modeling_paradigm == "parallel" or modeling_paradigm == "interleaved":
        model_outputs = model.generate(**batch, **decode_config)

    if modeling_paradigm == "parallel" or modeling_paradigm == "serial":
        text_outputs = model_outputs[code_layer]
        audio_outputs = model_outputs[:code_layer]

    if modeling_paradigm != "serial":
        # Hitting the max-token limit usually means decoding failed; fall back.
        if audio_outputs[0].shape[0] == decode_config.max_new_tokens:
            return "default.wav"

    audio_tokens = [audio_outputs[layer] for layer in range(code_layer)] if code_layer > 0 else audio_outputs
    audio_hat = audio_decode_cosyvoice(
        audio_tokens, model_config, codec_decoder, audio_prompt_path,
        code_layer, num_latency_tokens, speed=1.0,
    )
    if audio_hat is None:
        return "default.wav"

    # Write the waveform to a uniquely named file so concurrent requests do not collide.
    unique_filename = f"{uuid.uuid4()}.wav"
    save_path = os.path.join(output_path, unique_filename)
    sf.write(save_path, audio_hat.squeeze().cpu().numpy(), speech_sample_rate)

    if device.type == "cuda":
        torch.cuda.empty_cache()
    return save_path
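
# --- Optional: programmatic usage sketch ---
# A minimal, illustrative way to drive synthesis without the web UI. This helper
# is hypothetical and is not called anywhere in the app; "neutral.wav" is a
# placeholder path to a neutral-tone recording of the target speaker.
def example_offline_synthesis():
    load_model()
    wav_path = infer_tts(
        "Wobbly tables ruin everything!",
        "Expressing aggravated displeasure and discontent.",
        "neutral.wav",  # hypothetical neutral speaker prompt
    )
    print("Synthesized speech saved to:", wav_path)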

# --- 4. Gradio UI ---
text = gr.Textbox(lines=1, label="Text", placeholder="e.g., 'Wobbly tables ruin everything!'")
emotion_text_prompt = gr.Textbox(lines=1, label="Emotion Text Prompt", placeholder="e.g., 'Expressing aggravated displeasure and discontent.'")
neutral_speaker_wav = gr.Audio(label="Neutral Speaker Prompt WAV", type="filepath")

description_text = """
### **EmoVoice** is an emotion-controllable TTS model that leverages large language models (LLMs) to enable fine-grained, freestyle natural-language emotion control.
### [📖 **Arxiv**](https://arxiv.org/abs/2504.12867) | [💻 **GitHub**](https://github.com/yanghaha0908/EmoVoice) | [🤗 **Model**](https://huggingface.co/yhaha/EmoVoice) | [🚀 **Space**](https://huggingface.co/spaces/chenxie95/EmoVoice) | [🌐 **Demo Page**](https://yanghaha0908.github.io/EmoVoice/)
"""

gr_interface = gr.Interface(
    fn=infer_tts,
    inputs=[text, emotion_text_prompt, neutral_speaker_wav],
    outputs=[gr.Audio(label="🎵 Emotional Speech", type="filepath")],
    title="EmoVoice: LLM-based Emotional Text-To-Speech Model with Freestyle Text Prompting",
    description=description_text,
    # flagging_mode="never",
    # cache_examples="lazy",
)

if __name__ == "__main__":
    load_model()
    # gr_interface.queue(15).launch()
    gr_interface.launch()
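
# --- Optional: deployment sketch ---
# When self-hosting (outside Hugging Face Spaces), request queuing and a public
# share link can be enabled with standard Gradio options, e.g.:
#
#     gr_interface.queue().launch(share=True)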