Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 SparkAudio | |
| # 2025 Xinsheng Wang (w.xinshawn@gmail.com) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import torch | |
| import soundfile as sf | |
| import logging | |
| import argparse | |
| import gradio as gr | |
| from datetime import datetime | |
| from cli.SparkTTS import SparkTTS | |
| from sparktts.utils.token_parser import LEVELS_MAP_UI | |
| from huggingface_hub import snapshot_download | |
| import spaces | |
# Process-wide cache for the loaded SparkTTS instance. It starts empty and is
# filled either eagerly by build_ui() or lazily on the first generate() call.
MODEL = None
def initialize_model(model_dir=None, device="cpu"):
    """Construct a SparkTTS model, downloading the checkpoint if needed.

    This function does not cache anything itself; callers store the result
    in the module-level ``MODEL`` variable.

    Args:
        model_dir: Directory holding the model files. When None, the
            "SparkAudio/Spark-TTS-0.5B" snapshot is downloaded into
            ``pretrained_models/Spark-TTS-0.5B`` and that location is used.
        device: Anything accepted by ``torch.device`` (for example "cpu" or
            "cuda").

    Returns:
        The constructed ``SparkTTS`` instance.
    """
    if model_dir is None:
        local_dir = "pretrained_models/Spark-TTS-0.5B"
        # Log the real download target; model_dir is still None at this point.
        logging.info(f"Downloading model to: {local_dir}")
        model_dir = snapshot_download(
            "SparkAudio/Spark-TTS-0.5B", local_dir=local_dir
        )

    logging.info(f"Loading model from: {model_dir}")
    return SparkTTS(model_dir, torch.device(device))
def generate(text, prompt_speech, prompt_text, gender, pitch, speed):
    """Run one inference pass on the shared model and return its output.

    The shared ``MODEL`` is created on first use. All six arguments are
    forwarded positionally, in this order, to ``model.inference``; the value
    it returns is handed back unchanged.
    """
    global MODEL

    # Lazily create the shared model the first time audio is requested.
    if MODEL is None:
        target = "cuda" if torch.cuda.is_available() else "cpu"
        MODEL = initialize_model(device=target)
    tts = MODEL

    # NOTE(review): `spaces` is imported at file level but no @spaces.GPU
    # decorator is visible on this function — confirm whether one is intended.
    # The move to the GPU is requested on every call when CUDA is available.
    if torch.cuda.is_available():
        print("Moving model to GPU")
        tts.to("cuda")

    # Inference only: gradient tracking is disabled for the forward pass.
    with torch.no_grad():
        return tts.inference(
            text, prompt_speech, prompt_text, gender, pitch, speed
        )
def run_tts(
    text,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
    sample_rate=16000,
):
    """Perform TTS inference and save the generated audio as a WAV file.

    Args:
        text: Text to synthesise.
        prompt_text: Optional transcript of the prompt speech. A value of at
            most one character is treated as absent.
        prompt_speech: Optional reference audio, forwarded to ``generate``.
        gender: Optional gender setting, forwarded to ``generate``.
        pitch: Optional pitch setting, forwarded to ``generate``.
        speed: Optional speed setting, forwarded to ``generate``.
        save_dir: Directory for the output file; created if missing.
        sample_rate: Sample rate passed to ``soundfile.write``. Defaults to
            16000, the value that was previously hard-coded.

    Returns:
        Path of the WAV file that was written.
    """
    logging.info(f"Saving audio to: {save_dir}")

    # An empty or single-character transcript carries no usable information.
    if prompt_text is not None and len(prompt_text) <= 1:
        prompt_text = None

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # The path is chosen before a potentially long inference, so a
    # one-second timestamp lets two requests in the same second overwrite
    # each other. Microsecond resolution keeps the names distinct.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")
    wav = generate(text, prompt_speech, prompt_text, gender, pitch, speed)

    sf.write(save_path, wav, samplerate=sample_rate)
    logging.info(f"Audio saved at: {save_path}")
    return save_path
def build_ui(model_dir, device=0):
    """Build the Gradio Blocks app for Spark-TTS.

    Loads the shared model into the module-level ``MODEL`` cache when it is
    not already loaded, then assembles a two-tab interface ("Voice Clone" and
    "Voice Creation") whose buttons call ``run_tts``.

    Args:
        model_dir: Model directory, or None to let ``initialize_model``
            download the default checkpoint.
        device: Requested device. Only the literal string "cpu" forces CPU;
            any other value selects "cuda" when CUDA is available. A device
            index such as "cuda:0" is not preserved.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    global MODEL

    # Initialize model with proper device handling
    device = "cuda" if torch.cuda.is_available() and device != "cpu" else "cpu"
    if MODEL is None:
        MODEL = initialize_model(model_dir, device=device)
        if device == "cuda":
            # generate() ignores the return value of .to(), so .to() may work
            # in place and return nothing. Only rebind the cache when a
            # model object is actually handed back.
            moved = MODEL.to(device)
            if moved is not None:
                MODEL = moved

    # Define callback function for voice cloning
    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """
        Gradio callback to clone voice using text and optional prompt speech.
        - text: The input text to be synthesised.
        - prompt_text: Additional textual info for the prompt (optional).
        - prompt_wav_upload/prompt_wav_record: Audio files used as reference.
        Returns the path of the generated audio file.
        """
        # Prefer the uploaded file; fall back to the microphone recording.
        prompt_speech = prompt_wav_upload or prompt_wav_record
        # A missing (None), empty, or one-character transcript means "none".
        if not prompt_text or len(prompt_text) < 2:
            prompt_text = None
        return run_tts(
            text,
            prompt_text=prompt_text,
            prompt_speech=prompt_speech,
        )

    # Define callback function for creating new voices
    def voice_creation(text, gender, pitch, speed):
        """
        Gradio callback to create a synthetic voice with adjustable parameters.
        - text: The input text for synthesis.
        - gender: 'male' or 'female'.
        - pitch/speed: Slider values looked up in LEVELS_MAP_UI.
        Returns the path of the generated audio file.
        """
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        return run_tts(
            text,
            gender=gender,
            pitch=pitch_val,
            speed=speed_val,
        )

    with gr.Blocks() as demo:
        # Use HTML for centered title
        gr.HTML('<h1 style="text-align: center;">(Official) Spark-TTS by SparkAudio</h1>')
        with gr.Row():
            gr.Image(
                type="filepath",
                value="moyin_3s.png",
                interactive=False,
                height=300,
                width=300,
            )
        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                clone_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                clone_button = gr.Button("Generate")
                clone_button.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[clone_audio_output],
                )

            # Voice Creation Tab
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                creation_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[creation_audio_output],
                )

    return demo
def parse_arguments(argv=None):
    """Parse command-line options for the Spark TTS Gradio server.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads the process command line, so existing
            zero-argument callers behave as before.

    Returns:
        An ``argparse.Namespace`` with ``model_dir``, ``device``,
        ``server_name`` and ``server_port`` attributes.
    """
    parser = argparse.ArgumentParser(description="Spark TTS Gradio server.")
    parser.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help="Path to the model directory.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use (e.g., 'cpu' or 'cuda:0').",
    )
    parser.add_argument(
        "--server_name",
        type=str,
        default=None,
        help="Server host/IP for Gradio app.",
    )
    parser.add_argument(
        "--server_port",
        type=int,
        default=None,
        help="Server port for Gradio app.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, build the demo with the
    # requested model directory and device, then serve it on the requested
    # host and port.
    args = parse_arguments()
    demo = build_ui(model_dir=args.model_dir, device=args.device)
    demo.launch(server_name=args.server_name, server_port=args.server_port)