# Hugging Face Spaces app (page status when captured: Running)
| import os | |
| import gradio as gr | |
| import spaces | |
| from infer_rvc_python import BaseLoader | |
| import random | |
| import logging | |
| import time | |
| import soundfile as sf | |
| from infer_rvc_python.main import download_manager | |
| import zipfile | |
| import asyncio | |
| import librosa | |
| import traceback | |
| import numpy as np | |
| import urllib.request | |
| import shutil | |
| import threading | |
| from pedalboard import Pedalboard, Reverb, Compressor, HighpassFilter | |
| from pedalboard.io import AudioFile | |
| from pydub import AudioSegment | |
| import noisereduce as nr | |
| import edge_tts | |
| from huggingface_hub import hf_hub_download, HfApi | |
| from typing import List, Tuple, Optional, Dict, Any | |
| import json | |
| from pathlib import Path | |
# ---------------------------------------------------------------------------
# Logging: every record goes both to "rvc_app.log" and to the console.
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("rvc_app.log"), logging.StreamHandler()],
)
logger = logging.getLogger("RVC_APP")

# Only errors from the inference library are let through.
logging.getLogger("infer_rvc_python").setLevel(logging.ERROR)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
PITCH_ALGO_OPT = ["pm", "harvest", "crepe", "rmvpe", "rmvpe+"]
MAX_FILE_SIZE = 500 * 1024 * 1024  # upper bound for a downloaded file: 500 MB
DOWNLOAD_DIR = "downloads"
OUTPUT_DIR = "output"
CONFIG_FILE = "rvc_config.json"
SUPPORTED_AUDIO_FORMATS = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]

# Working directories are created at import time.
for _required_dir in (DOWNLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# ---------------------------------------------------------------------------
# Page header (raw HTML rendered at the top of the UI)
# ---------------------------------------------------------------------------
title = "<center><strong><font size='7'>🔊 RVC+</font></strong></center>"
description = """
<div style="text-align: center; font-size: 1.1em; color: #aaa; margin: 10px 0;">
This demo is for educational and research purposes only.<br>
Misuse of voice conversion technology is unethical. Use responsibly.<br>
Authors are not liable for inappropriate usage.
</div>
"""

# ---------------------------------------------------------------------------
# Theme building blocks
# ---------------------------------------------------------------------------
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# Custom "OrangeRed" palette, attached to gradio's colors module so it can be
# referenced as colors.orange_red like the built-in hues.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",  # OrangeRed base color
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
class OrangeRedTheme(Soft):
    """Soft-based theme with gray primary and OrangeRed secondary hues.

    All constructor arguments are keyword-only and are forwarded to
    ``Soft.__init__``; fonts are not passed, so the base theme's fonts apply.
    After construction a fixed set of CSS-variable overrides is applied
    through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
        )

        # Page and block backgrounds.
        background_overrides = dict(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
        )
        # Primary (call-to-action) buttons use the secondary hue.
        primary_button_overrides = dict(
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
        )
        # Secondary buttons use the primary (gray) hue.
        secondary_button_overrides = dict(
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
        )
        # Sliders, block chrome and miscellaneous accents.
        misc_overrides = dict(
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

        super().set(
            **background_overrides,
            **primary_button_overrides,
            **secondary_button_overrides,
            **misc_overrides,
        )


# Single shared theme instance used when building the UI.
orange_red_theme = OrangeRedTheme()
# Custom stylesheet handed to gr.Blocks(css=...).
# NOTE(review): many selectors below (.chatbot, .banner-message,
# .model-dropdown-container, ...) do not correspond to components built in the
# visible part of this file; they look inherited from another app - confirm
# before pruning.
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.1em !important;
}
:root {
--color-grey-50: #f9fafb;
--banner-background: var(--secondary-400);
--banner-text-color: var(--primary-100);
--banner-background-dark: var(--secondary-800);
--banner-text-color-dark: var(--primary-100);
--banner-chrome-height: calc(16px + 43px);
--chat-chrome-height-wide-no-banner: 320px;
--chat-chrome-height-narrow-no-banner: 450px;
--chat-chrome-height-wide: calc(var(--chat-chrome-height-wide-no-banner) + var(--banner-chrome-height));
--chat-chrome-height-narrow: calc(var(--chat-chrome-height-narrow-no-banner) + var(--banner-chrome-height));
}
.banner-message { background-color: var(--banner-background); padding: 5px; margin: 0; border-radius: 5px; border: none; }
.banner-message-text { font-size: 13px; font-weight: bolder; color: var(--banner-text-color) !important; }
body.dark .banner-message { background-color: var(--banner-background-dark) !important; }
body.dark .gradio-container .contain .banner-message .banner-message-text { color: var(--banner-text-color-dark) !important; }
.toast-body { background-color: var(--color-grey-50); }
.html-container:has(.css-styles) { padding: 0; margin: 0; }
.css-styles { height: 0; }
.model-message { text-align: end; }
.model-dropdown-container { display: flex; align-items: center; gap: 10px; padding: 0; }
.user-input-container .multimodal-textbox{ border: none !important; }
.control-button { height: 51px; }
button.cancel { border: var(--button-border-width) solid var(--button-cancel-border-color); background: var(--button-cancel-background-fill); color: var(--button-cancel-text-color); box-shadow: var(--button-cancel-shadow); }
button.cancel:hover, .cancel[disabled] { background: var(--button-cancel-background-fill-hover); color: var(--button-cancel-text-color-hover); }
.opt-out-message { top: 8px; }
.opt-out-message .html-container, .opt-out-checkbox label { font-size: 14px !important; padding: 0 !important; margin: 0 !important; color: var(--neutral-400) !important; }
div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; max-height: 900px !important; }
div.no-padding { padding: 0 !important; }
@media (max-width: 1280px) { div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; } }
@media (max-width: 1024px) {
.responsive-row { flex-direction: column; }
.model-message { text-align: start; font-size: 10px !important; }
.model-dropdown-container { flex-direction: column; align-items: flex-start; }
div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-narrow)) !important; }
}
@media (max-width: 400px) {
.responsive-row { flex-direction: column; }
.model-message { text-align: start; font-size: 10px !important; }
.model-dropdown-container { flex-direction: column; align-items: flex-start; }
div.block.chatbot { max-height: 360px !important; }
}
@media (max-height: 932px) { .chatbot { max-height: 500px !important; } }
@media (max-height: 1280px) { div.block.chatbot { max-height: 800px !important; } }
"""
| # Model Management Class | |
class ModelManager:
    """Manages model loading, downloading, and the persisted app config.

    Attributes:
        converter: ``BaseLoader`` instance used to run voice conversion.
        loaded_models: dict reserved as a cache for loaded models.
        config: dict with the keys ``"recent_models"`` (list of paths) and
            ``"default_settings"`` (dict), persisted in ``CONFIG_FILE``.
    """

    # Hosts accepted as Hugging Face Hub links ("hf.co" is the short domain
    # used in the UI hint text).
    _HF_HOSTS = ("huggingface.co", "www.huggingface.co", "hf.co")

    def __init__(self):
        self.converter = BaseLoader(only_cpu=False, hubert_path=None, rmvpe_path=None)
        self.loaded_models = {}  # Cache for loaded models
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration file, falling back to defaults.

        A missing file, unreadable JSON, a non-dict payload, or keys of the
        wrong type never propagate: the result always contains a list under
        ``"recent_models"`` and a dict under ``"default_settings"``.
        """
        config: Dict[str, Any] = {"recent_models": [], "default_settings": {}}
        if os.path.exists(CONFIG_FILE):
            try:
                with open(CONFIG_FILE, 'r') as f:
                    loaded = json.load(f)
            except Exception as e:
                logger.error(f"Failed to load config: {e}")
            else:
                if isinstance(loaded, dict):
                    config.update(loaded)
        # Guard against hand-edited / older config files.
        if not isinstance(config.get("recent_models"), list):
            config["recent_models"] = []
        if not isinstance(config.get("default_settings"), dict):
            config["default_settings"] = {}
        return config

    def save_config(self):
        """Save current configuration to file (best effort; errors are logged)."""
        try:
            with open(CONFIG_FILE, 'w') as f:
                json.dump(self.config, f)
        except Exception as e:
            logger.error(f"Failed to save config: {e}")

    def add_recent_model(self, model_path: str):
        """Record ``model_path`` as the most recently used model.

        A path that is already listed is moved to the end instead of being
        ignored, so the list stays ordered oldest -> newest. Only the last
        five entries are kept, and the config is saved afterwards.
        """
        recent = [p for p in self.config.get("recent_models", []) if p != model_path]
        recent.append(model_path)
        self.config["recent_models"] = recent[-5:]
        self.save_config()

    def find_files(self, directory: str, exts: Tuple[str, ...] = (".pth", ".index", ".zip")) -> List[str]:
        """Return paths of entries directly inside ``directory`` whose name
        ends with one of ``exts`` (case-insensitive), sorted by name.

        ``exts`` is expected to hold lowercase suffixes. Sub-directories are
        not searched.
        """
        suffixes = tuple(exts)
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(suffixes)
        ]

    def unzip_in_folder(self, zip_path: str, extract_to: str):
        """Extract every file of ``zip_path`` flat into ``extract_to``.

        Directory components of archive members are dropped, which both
        flattens the layout and prevents path traversal ("../x").
        """
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            for member in zip_ref.infolist():
                if member.is_dir():
                    continue
                # Normalise Windows separators before taking the basename.
                flat_name = os.path.basename(member.filename.replace("\\", "/"))
                if not flat_name:
                    continue
                member.filename = flat_name
                zip_ref.extract(member, extract_to)

    @staticmethod
    def _parse_hf_url(url: str) -> Tuple[str, str, str]:
        """Split a Hugging Face file URL into ``(repo_id, revision, filename)``.

        Accepted shape: ``<host>/<user>/<repo>/(resolve|blob)/<revision>/<path>``
        where ``<host>`` is one of ``_HF_HOSTS``. A missing scheme is
        tolerated and any query string / fragment is ignored.

        Raises:
            ValueError: if the host is not a Hugging Face host or the path
                does not have the expected structure.
        """
        from urllib.parse import unquote, urlparse

        candidate = url.strip()
        if "://" not in candidate:
            candidate = f"https://{candidate}"
        parsed = urlparse(candidate)
        host = (parsed.hostname or "").lower()
        if host not in ModelManager._HF_HOSTS:
            raise ValueError("❌ Only Hugging Face links are allowed.")
        segments = [unquote(seg) for seg in parsed.path.split("/") if seg]
        if len(segments) < 5 or segments[2] not in ("resolve", "blob"):
            raise ValueError("❌ Unable to parse Hugging Face URL.")
        repo_id = f"{segments[0]}/{segments[1]}"
        revision = segments[3]
        filename = "/".join(segments[4:])
        return repo_id, revision, filename

    def get_file_size(self, url: str) -> int:
        """Return the size in bytes of the file a Hugging Face URL points to.

        Raises:
            ValueError: if the URL is not a parsable Hugging Face file link.
            RuntimeError: if the repository cannot be queried, the file is
                not listed in it, its size is unknown, or it exceeds
                ``MAX_FILE_SIZE``.
        """
        repo_id, revision, filename = self._parse_hf_url(url)
        try:
            # files_metadata=True is required, otherwise sibling sizes are None.
            file_info = HfApi().repo_info(
                repo_id=repo_id,
                revision=revision,
                repo_type="model",
                files_metadata=True,
            )
            file_entry = next((f for f in file_info.siblings if f.rfilename == filename), None)
            if not file_entry:
                raise ValueError(f"❌ File '{filename}' not found in repository '{repo_id}'.")
            file_size = file_entry.size
            if file_size is None:
                raise ValueError(f"❌ Size of '{filename}' is not available.")
            if file_size > MAX_FILE_SIZE:
                raise ValueError(f"⚠️ File too large: {file_size / 1e6:.1f} MB (>500MB)")
            return file_size
        except Exception as e:
            raise RuntimeError(f"❌ Failed to fetch file info: {str(e)}") from e

    def clear_directory_later(self, directory: str, delay: int = 30):
        """Delete ``directory`` after ``delay`` seconds in a daemon thread."""
        def _clear():
            time.sleep(delay)
            if os.path.exists(directory):
                shutil.rmtree(directory, ignore_errors=True)
                logger.info(f"🧹 Cleaned up: {directory}")

        threading.Thread(target=_clear, daemon=True).start()

    def find_model_and_index(self, directory: str) -> Tuple[Optional[str], Optional[str]]:
        """Return the first ``.pth`` and first ``.index`` file in ``directory``
        (each ``None`` when absent)."""
        files = self.find_files(directory)
        model = next((f for f in files if f.lower().endswith(".pth")), None)
        index = next((f for f in files if f.lower().endswith(".index")), None)
        return model, index

    def download_model(self, url_data: str) -> Tuple[str, Optional[str]]:
        """Download a model (and optional index) from Hugging Face.

        Args:
            url_data: one or two comma-separated Hugging Face file URLs
                (a ``.zip``, or a ``.pth`` plus an ``.index``).

        Returns:
            Absolute paths ``(model, index)``; ``index`` is ``None`` when no
            index file was found.

        Raises:
            ValueError: for missing / too many / unparsable URLs.
            RuntimeError: when the size check fails (see ``get_file_size``).
            gr.Error: when the download or extraction fails.
        """
        import tempfile

        if not url_data or not url_data.strip():
            raise ValueError("❌ No URL provided.")
        urls = [u.strip() for u in url_data.split(",") if u.strip()]
        if not urls:
            raise ValueError("❌ No URL provided.")
        if len(urls) > 2:
            raise ValueError("❌ Provide up to two URLs (model.pth, index.index).")

        # Validate size first
        for url in urls:
            self.get_file_size(url)

        # A unique folder per download: a random 4-digit name could collide
        # and let one request's cleanup delete another request's files.
        os.makedirs(DOWNLOAD_DIR, exist_ok=True)
        directory = tempfile.mkdtemp(prefix="model_", dir=DOWNLOAD_DIR)

        try:
            for url in urls:
                repo_id, revision, filename = self._parse_hf_url(url)
                local_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    revision=revision,
                    cache_dir=directory,
                    local_dir=directory,
                    local_dir_use_symlinks=False
                )
                # Files stored in a repo sub-folder land in a sub-directory;
                # move them up so find_files() (non-recursive) can see them.
                flat_path = os.path.join(directory, os.path.basename(local_path))
                if os.path.abspath(local_path) != os.path.abspath(flat_path):
                    shutil.move(local_path, flat_path)

            # Unzip if needed
            for archive in self.find_files(directory, (".zip",)):
                self.unzip_in_folder(archive, directory)

            model, index = self.find_model_and_index(directory)
            if not model:
                raise ValueError("❌ .pth model file not found in downloaded content.")

            gr.Info(f"✅ Model loaded: {os.path.basename(model)}")
            if index:
                gr.Info(f"📌 Index loaded: {os.path.basename(index)}")
            else:
                gr.Warning("⚠️ Index file not found – conversion may be less accurate.")

            # Schedule cleanup
            self.clear_directory_later(directory, delay=30)

            # NOTE(review): the path recorded here lives inside `directory`,
            # which is scheduled for deletion above - recent entries for
            # downloaded models may therefore point to files that are gone.
            self.add_recent_model(os.path.abspath(model))

            return os.path.abspath(model), os.path.abspath(index) if index else None
        except Exception as e:
            shutil.rmtree(directory, ignore_errors=True)
            logger.error(f"Download failed: {e}")
            raise gr.Error(f"❌ Download failed: {str(e)}") from e
| # Audio Processing Class | |
class AudioProcessor:
    """Handles audio processing tasks like noise reduction and effects.

    All methods are static and are called on the class itself, e.g.
    ``AudioProcessor.apply_noisereduce(paths)``.
    """

    @staticmethod
    def apply_noisereduce(audio_paths: List[str]) -> List[str]:
        """Apply noise reduction to audio files.

        For every input a ``<name>_denoised.wav`` file is written next to it.
        When processing a file fails, the error is logged and the original
        path is returned for that entry instead.

        Returns:
            One path per input, in the same order.
        """
        results = []
        for path in audio_paths:
            out_path = f"{os.path.splitext(path)[0]}_denoised.wav"
            try:
                audio = AudioSegment.from_file(path)
                samples = np.array(audio.get_array_of_samples())
                sr = audio.frame_rate
                channels = audio.channels

                signal = samples.astype(np.float32)
                # pydub interleaves the channels (L R L R ...); split them so
                # each channel is denoised on its own: (channels, frames).
                if channels > 1:
                    signal = signal.reshape((-1, channels)).T
                reduced = nr.reduce_noise(y=signal, sr=sr, prop_decrease=0.6)
                if channels > 1:
                    reduced = reduced.T.reshape(-1)  # back to interleaved order

                # Convert back to the original integer PCM type. Writing the
                # float32 bytes directly would be re-read as integer samples
                # of `sample_width` bytes and produce garbage audio.
                limits = np.iinfo(samples.dtype)
                pcm = np.clip(np.rint(np.nan_to_num(reduced)), limits.min, limits.max)
                pcm = pcm.astype(samples.dtype)

                reduced_audio = AudioSegment(
                    pcm.tobytes(),
                    frame_rate=sr,
                    sample_width=audio.sample_width,
                    channels=channels
                )
                reduced_audio.export(out_path, format="wav")
                results.append(out_path)
                gr.Info("🔊 Noise reduction applied.")
            except Exception as e:
                logger.error(f"Noise reduction failed: {e}")
                results.append(path)
        return results

    @staticmethod
    def apply_audio_effects(audio_paths: List[str]) -> List[str]:
        """Apply a high-pass filter, compressor and light reverb to each file.

        For every input a ``<name>_reverb.wav`` file is written next to it.
        When processing a file fails, the error is logged and the original
        path is returned for that entry instead.

        Returns:
            One path per input, in the same order.
        """
        results = []
        board = Pedalboard([
            HighpassFilter(cutoff_frequency_hz=80),
            Compressor(ratio=4, threshold_db=-15),
            Reverb(room_size=0.15, damping=0.7, wet_level=0.15, dry_level=0.85)
        ])
        for path in audio_paths:
            out_path = f"{os.path.splitext(path)[0]}_reverb.wav"
            try:
                with AudioFile(path) as f:
                    with AudioFile(out_path, 'w', f.samplerate, f.num_channels) as o:
                        # Stream the file in one-second chunks.
                        while f.tell() < f.frames:
                            chunk = f.read(int(f.samplerate))
                            # reset=False keeps the effect state (reverb tail,
                            # compressor envelope) across chunk boundaries.
                            effected = board(chunk, f.samplerate, reset=False)
                            o.write(effected)
                results.append(out_path)
                gr.Info("🎛️ Audio effects applied.")
            except Exception as e:
                logger.error(f"Effects failed: {e}")
                results.append(path)
        return results

    @staticmethod
    def validate_audio_files(file_paths: List[str]) -> List[str]:
        """Return the paths whose extension is in ``SUPPORTED_AUDIO_FORMATS``.

        The comparison is case-insensitive; a UI warning is shown for every
        skipped file.
        """
        valid_files = []
        for path in file_paths:
            extension = os.path.splitext(path)[1].lower()
            if extension in SUPPORTED_AUDIO_FORMATS:
                valid_files.append(path)
            else:
                gr.Warning(f"⚠️ Skipping unsupported file: {os.path.basename(path)}")
        return valid_files
| # TTS Handler Class | |
class TTSHandler:
    """Handles text-to-speech functionality (Edge TTS).

    Voices are presented in the UI as ``"<ShortName>-<Gender>"`` (see
    ``get_voice_list``); the gender suffix is removed again before the name
    is handed to ``edge_tts``.
    """

    # Gender labels that may follow the voice short name in UI labels.
    _GENDER_SUFFIXES = ("male", "female", "neutral")

    @staticmethod
    def _voice_short_name(voice: str) -> str:
        """Strip a trailing ``-<Gender>`` from a UI voice label.

        ``"en-US-JennyNeural-Female"`` -> ``"en-US-JennyNeural"``. A label
        without such a suffix is returned unchanged.
        """
        base, sep, tail = voice.rpartition("-")
        if sep and tail.lower() in TTSHandler._GENDER_SUFFIXES:
            return base
        return voice

    @staticmethod
    async def generate_tts(text: str, voice: str, output_path: str):
        """Synthesize ``text`` with ``voice`` and save the audio to ``output_path``."""
        communicate = edge_tts.Communicate(text, TTSHandler._voice_short_name(voice))
        await communicate.save(output_path)

    @staticmethod
    def infer_tts(tts_voice: str, tts_text: str, play_tts: bool) -> Tuple[List[str], Optional[str]]:
        """Generate TTS audio with the specified voice.

        Returns:
            ``([path], path)`` when ``play_tts`` is true, otherwise
            ``([path], None)``, where ``path`` is the generated MP3.

        Raises:
            ValueError: if the text is empty or no voice is selected.
            gr.Error: if speech generation fails.
        """
        if not tts_text or not tts_text.strip():
            raise ValueError("❌ Text is empty.")
        if not tts_voice:
            raise ValueError("❌ No voice selected.")

        folder = f"tts_{random.randint(10000, 99999)}"
        out_dir = os.path.join(OUTPUT_DIR, folder)
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, "tts_output.mp3")
        try:
            asyncio.run(TTSHandler.generate_tts(tts_text, tts_voice, out_path))
        except Exception as e:
            # Do not leave an empty / partial output folder behind.
            shutil.rmtree(out_dir, ignore_errors=True)
            logger.error(f"TTS generation failed: {e}")
            raise gr.Error(f"TTS generation failed: {str(e)}") from e
        return [out_path], (out_path if play_tts else None)

    @staticmethod
    def get_voice_list() -> List[str]:
        """Get the sorted list of available voices as ``"<ShortName>-<Gender>"``.

        Falls back to a single default voice when the list cannot be fetched.
        """
        try:
            return sorted(
                [f"{v['ShortName']}-{v['Gender']}" for v in asyncio.run(edge_tts.list_voices())]
            )
        except Exception as e:
            logger.error(f"Failed to get voice list: {e}")
            return ["en-US-JennyNeural-Female"]  # Fallback
| # Main Conversion Function | |
def run_conversion(
    audio_files: List[str],
    model_path: str,
    pitch_algo: str,
    pitch_level: int,
    index_path: Optional[str],
    index_rate: float,
    filter_radius: int,
    rms_mix_rate: float,
    protect: float,
    denoise: bool,
    effects: bool,
    model_manager: "ModelManager"
) -> List[str]:
    """Run voice conversion on the provided audio files.

    Args:
        audio_files: paths of the input audio files.
        model_path: path of the ``.pth`` voice model.
        pitch_algo: pitch extraction algorithm name.
        pitch_level: pitch level passed to the converter as ``pitch_lvl``.
        index_path: optional ``.index`` file.
        index_rate: passed to the converter as ``index_influence``.
        filter_radius: passed (as int) as ``respiration_median_filtering``.
        rms_mix_rate: passed to the converter as ``envelope_ratio``.
        protect: passed to the converter as ``consonant_breath_protection``.
        denoise: run noise reduction on the converted files.
        effects: run the audio-effects chain on the converted files.
        model_manager: object whose ``converter`` performs the conversion.

    Returns:
        Paths of the resulting audio files.

    Raises:
        ValueError: if no (valid) audio file or no model is provided.
        gr.Error: if the conversion itself fails.
    """
    if not audio_files:
        raise ValueError("❌ Please upload at least one audio file.")
    if not model_path:
        # Fail early with a clear message instead of deep inside the converter.
        raise ValueError("❌ Please upload or download a .pth model first.")

    # Validate audio files
    audio_files = AudioProcessor.validate_audio_files(audio_files)
    if not audio_files:
        raise ValueError("❌ No valid audio files provided.")

    random_tag = f"USER_{random.randint(10000000, 99999999)}"

    # MP3 inputs get a fixed 44.1 kHz resample rate; 0 is passed otherwise.
    # Case-insensitive, matching the extension check in validation.
    has_mp3 = any(f.lower().endswith(".mp3") for f in audio_files)

    # Configure converter
    model_manager.converter.apply_conf(
        tag=random_tag,
        file_model=model_path,
        pitch_algo=pitch_algo,
        pitch_lvl=pitch_level,
        file_index=index_path,
        index_influence=index_rate,
        respiration_median_filtering=int(filter_radius),
        envelope_ratio=rms_mix_rate,
        consonant_breath_protection=protect,
        resample_sr=44100 if has_mp3 else 0,
    )

    # Run conversion
    try:
        results = model_manager.converter(audio_files, random_tag, overwrite=False, parallel_workers=8)
    except Exception as e:
        logger.error(f"Conversion failed: {e}")
        raise gr.Error(f"❌ Conversion failed: {str(e)}") from e

    # Post-processing
    if denoise:
        results = AudioProcessor.apply_noisereduce(results)
    if effects:
        results = AudioProcessor.apply_audio_effects(results)
    return results
| # Gradio UI Builder | |
def create_ui() -> "gr.Blocks":
    """Create and configure the Gradio UI.

    Builds three tabs (voice conversion, text-to-speech, settings) plus an
    examples section, wires all event handlers, and returns the ``gr.Blocks``
    app without launching it.
    """
    import sys

    # Initialize model manager
    model_manager = ModelManager()

    # Settings previously stored through "Save as Default".
    saved_settings = model_manager.config.get("default_settings") or {}
    setting_keys = (
        "pitch_algo", "pitch_level", "index_rate", "filter_radius",
        "rms_mix_rate", "protect", "denoise", "reverb",
    )

    def _saved(key, fallback):
        """Return the saved default for ``key`` or ``fallback`` when unset."""
        value = saved_settings.get(key)
        return fallback if value is None else value

    def _convert(*args):
        """Run a conversion with the shared model manager.

        The manager is captured by closure instead of being passed through
        gr.State, because gr.State deep-copies its initial value.
        """
        return run_conversion(*args, model_manager)

    def _save_defaults(*values):
        """Persist the current conversion settings as defaults."""
        model_manager.config["default_settings"] = dict(zip(setting_keys, values))
        model_manager.save_config()

    def _clear_cache():
        """Remove all downloaded models but keep the download directory."""
        shutil.rmtree(DOWNLOAD_DIR, ignore_errors=True)
        os.makedirs(DOWNLOAD_DIR, exist_ok=True)
        gr.Info("Cache cleared")

    def _pick_recent(path):
        """Return a recent-model path for the file input if it still exists."""
        if not path:
            return None
        if not os.path.exists(path):
            gr.Warning(f"⚠️ Model file no longer exists: {os.path.basename(path)}")
            return None
        return path

    default_pitch_algo = _saved("pitch_algo", "rmvpe+")
    if default_pitch_algo not in PITCH_ALGO_OPT:
        default_pitch_algo = "rmvpe+"

    with gr.Blocks(theme=orange_red_theme, title="RVC+", fill_width=True, delete_cache=(3200, 3200), css=css) as app:
        gr.HTML(title)
        gr.HTML(description)

        with gr.Tabs():
            # ============= TAB 1: Voice Conversion =============
            with gr.Tab("🎤 Voice Conversion", id=0):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🔊 Upload Audio")
                        audio_input = gr.File(
                            label="Audio Files (WAV, MP3, OGG, FLAC, M4A)",
                            file_count="multiple",
                            type="filepath"
                        )

                        gr.Markdown("### 📥 Load Model")
                        model_file = gr.File(label="Upload .pth Model", type="filepath")
                        index_file = gr.File(label="Upload .index File (Optional)", type="filepath")

                        # Recent models dropdown
                        recent_models = gr.Dropdown(
                            label="Recent Models",
                            choices=model_manager.config.get("recent_models", []),
                            value=None,
                            interactive=True
                        )
                        recent_models.change(
                            _pick_recent,
                            inputs=[recent_models],
                            outputs=[model_file]
                        )

                        use_url = gr.Checkbox(label="🌐 Download from Hugging Face URL", value=False)
                        with gr.Group(visible=False) as url_group:
                            gr.Markdown(
                                "🔗 Paste Hugging Face link(s):<br>"
                                "• Direct ZIP: `https://hf.co/user/repo/resolve/main/model.zip`<br>"
                                "• Separate files: `https://hf.co/user/repo/resolve/main/model.pth, https://hf.co/user/repo/resolve/main/model.index`"
                            )
                            model_url = gr.Textbox(
                                placeholder="https://huggingface.co/user/repo/resolve/main/file.pth",
                                label="Model URL(s)",
                                lines=2
                            )
                            download_btn = gr.Button("⬇️ Download Model", variant="secondary")

                        use_url.change(
                            lambda x: gr.update(visible=x),
                            inputs=[use_url],
                            outputs=[url_group]
                        )
                        download_btn.click(
                            model_manager.download_model,
                            inputs=[model_url],
                            outputs=[model_file, index_file]
                        ).then(
                            lambda: gr.update(visible=False),  # Hide URL group after download
                            outputs=[url_group]
                        ).then(
                            lambda: gr.update(choices=model_manager.config.get("recent_models", [])),
                            outputs=[recent_models]
                        )

                    with gr.Column(scale=1):
                        gr.Markdown("### ⚙️ Conversion Settings")
                        with gr.Group():
                            pitch_algo = gr.Dropdown(PITCH_ALGO_OPT, value=default_pitch_algo, label="Pitch Algorithm")
                            pitch_level = gr.Slider(-24, 24, value=_saved("pitch_level", 0), step=1, label="Pitch Level")
                            index_rate = gr.Slider(0, 1, value=_saved("index_rate", 0.75), label="Index Influence")
                            filter_radius = gr.Slider(0, 7, value=_saved("filter_radius", 3), step=1, label="Median Filter")
                            rms_mix_rate = gr.Slider(0, 1, value=_saved("rms_mix_rate", 0.25), label="Volume Envelope")
                            protect = gr.Slider(0, 0.5, value=_saved("protect", 0.5), label="Consonant Protection")
                            denoise = gr.Checkbox(bool(_saved("denoise", False)), label="🔇 Denoise Output")
                            reverb = gr.Checkbox(bool(_saved("reverb", False)), label="🎛️ Add Reverb")

                        conversion_settings = [
                            pitch_algo, pitch_level, index_rate, filter_radius,
                            rms_mix_rate, protect, denoise, reverb,
                        ]

                        # Save settings button (order matches setting_keys)
                        save_settings_btn = gr.Button("💾 Save as Default", size="sm")
                        save_settings_btn.click(_save_defaults, inputs=conversion_settings)

                        convert_btn = gr.Button("🚀 Convert Voice", variant="primary", size="lg")
                        output_files = gr.File(label="✅ Converted Audio", file_count="multiple")

                        # Positional order must match run_conversion's signature.
                        conversion_inputs = [
                            audio_input,
                            model_file,
                            pitch_algo,
                            pitch_level,
                            index_file,
                            index_rate,
                            filter_radius,
                            rms_mix_rate,
                            protect,
                            denoise,
                            reverb,
                        ]
                        convert_btn.click(
                            _convert,
                            inputs=conversion_inputs,
                            outputs=output_files,
                        )

            # ============= TAB 2: Text-to-Speech =============
            with gr.Tab("🗣️ Text-to-Speech", id=1):
                gr.Markdown("### Convert text to speech using Edge TTS.")

                # Get voice list
                tts_voice_list = TTSHandler.get_voice_list()

                with gr.Row():
                    with gr.Column(scale=1):
                        tts_text = gr.Textbox(
                            placeholder="Enter your text here...",
                            label="Text Input",
                            lines=5
                        )
                        tts_voice = gr.Dropdown(
                            tts_voice_list,
                            value=tts_voice_list[0] if tts_voice_list else None,
                            label="Voice"
                        )
                        tts_play = gr.Checkbox(False, label="🎧 Auto-play audio")
                        tts_btn = gr.Button("🔊 Generate Speech", variant="secondary")
                    with gr.Column(scale=1):
                        tts_output_audio = gr.File(label="Download Audio", type="filepath")
                        tts_preview = gr.Audio(label="Preview", visible=False, autoplay=True)

                tts_btn.click(
                    TTSHandler.infer_tts,
                    inputs=[tts_voice, tts_text, tts_play],
                    outputs=[tts_output_audio, tts_preview],
                ).then(
                    # Show the preview player only when it received audio.
                    lambda x: gr.update(visible=bool(x)),
                    inputs=[tts_preview],
                    outputs=[tts_preview]
                )

            # ============= TAB 3: Settings =============
            with gr.Tab("⚙️ Settings", id=2):
                gr.Markdown("### Application Settings")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Model Management")
                        clear_cache_btn = gr.Button("🗑️ Clear Model Cache", variant="secondary")
                        clear_cache_btn.click(_clear_cache, outputs=[])

                        gr.Markdown("#### Recent Models")
                        gr.DataFrame(
                            value=[[model] for model in model_manager.config.get("recent_models", [])],
                            headers=["Model Path"],
                            datatype=["str"],
                            interactive=False
                        )
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### System Information")
                        gr.HTML(
                            f"""
                            <div>
                            <p><strong>Python Version:</strong> {sys.version}</p>
                            <p><strong>Platform:</strong> {sys.platform}</p>
                            <p><strong>Download Directory:</strong> {os.path.abspath(DOWNLOAD_DIR)}</p>
                            <p><strong>Output Directory:</strong> {os.path.abspath(OUTPUT_DIR)}</p>
                            </div>
                            """
                        )

        # Examples
        gr.Markdown("### 📚 Examples")
        gr.Examples(
            examples=[
                ["./test.ogg", "./model.pth", "rmvpe+", 0, "./model.index", 0.75, 3, 0.25, 0.5, False, False],
                ["./example3/test3.wav", "./example3/zip_link.txt", "rmvpe+", 0, None, 0.75, 3, 0.25, 0.5, True, True],
            ],
            inputs=conversion_inputs,
            outputs=output_files,
            fn=_convert,
            cache_examples=False,
        )
    return app
# Script entry point: build the UI, enable request queuing, start the server.
if __name__ == "__main__":
    # Only files inside these directories may be served to the browser.
    servable_dirs = [DOWNLOAD_DIR, OUTPUT_DIR]

    app = create_ui()
    app.queue(default_concurrency_limit=10)
    app.launch(
        share=True,
        debug=False,
        show_api=True,
        max_threads=40,
        allowed_paths=servable_dirs,
    )