# SonicMaster: Hugging Face Space app (ZeroGPU runtime, with CPU-safe fallbacks)
import os

os.environ.setdefault("GRADIO_USE_CDN", "true")  # set before gradio is imported

try:
    import spaces  # HF Spaces SDK (provides @spaces.GPU on ZeroGPU)
except Exception:
    # Fallback shim for local/CPU runs: @spaces.GPU(...) becomes a no-op decorator.
    class _DummySpaces:
        def GPU(self, *_, **__):
            def deco(fn):
                return fn
            return deco

    spaces = _DummySpaces()

# Decorating these probes with @spaces.GPU() is an assumption: ZeroGPU expects
# at least one GPU-decorated function to be present at startup, and the shim
# above exists precisely so the decorator is harmless on CPU.
@spaces.GPU()
def gpu_probe(a: int = 1, b: int = 1):
    return a + b

@spaces.GPU()
def gpu_echo(x: str = "ok"):
    return x
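# Local smoke test (illustrative; safe on CPU thanks to the shim):
#   >>> gpu_probe(2, 3)   # -> 5
#   >>> gpu_echo("hi")    # -> 'hi'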
# ================= Standard imports =================
import sys
import subprocess
from pathlib import Path
from typing import Tuple, Optional, List, Any

import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import hf_hub_download

# Runtime hints (safe on CPU)
USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"

SPACE_ROOT = Path(__file__).parent.resolve()
REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
REPO_URL = "https://github.com/AMAAI-Lab/SonicMaster"
WEIGHTS_REPO = "amaai-lab/SonicMaster"
WEIGHTS_FILE = "model.safetensors"
CACHE_DIR = SPACE_ROOT / "weights"
CACHE_DIR.mkdir(parents=True, exist_ok=True)
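# Resulting on-disk layout (illustrative):
#   <SPACE_ROOT>/SonicMasterRepo/            <- shallow clone of REPO_URL
#   <SPACE_ROOT>/weights/model.safetensors   <- checkpoint from WEIGHTS_REPO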
# ================ Repo clone AT STARTUP (so examples show immediately) ================
def ensure_repo() -> Path:
    if not REPO_DIR.exists():
        subprocess.run(
            ["git", "clone", "--depth", "1", REPO_URL, REPO_DIR.as_posix()],
            check=True,
        )
    if REPO_DIR.as_posix() not in sys.path:
        sys.path.append(REPO_DIR.as_posix())
    return REPO_DIR

# Clone now so examples are available immediately
ensure_repo()
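# Note: this assumes `git` is available on PATH (true on the standard Spaces
# images); the clone is idempotent and is skipped once REPO_DIR exists.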
# ================ Weights: still lazy (download at first run) ================
_weights_path: Optional[Path] = None

def get_weights_path(progress: Optional[gr.Progress] = None) -> Path:
    """Download/resolve weights lazily (keeps startup fast)."""
    global _weights_path
    if _weights_path is None:
        if progress:
            progress(0.10, desc="Downloading model weights (first run)")
        wp = hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
            local_dir=str(CACHE_DIR),
            local_dir_use_symlinks=False,  # deprecated no-op in recent huggingface_hub
            force_download=False,
            resume_download=True,  # deprecated no-op in recent huggingface_hub
        )
        _weights_path = Path(wp)
    return _weights_path
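# The first call downloads the checkpoint into CACHE_DIR; later calls return
# the cached Path immediately, e.g.:
#   ckpt = get_weights_path()  # -> <SPACE_ROOT>/weights/model.safetensors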
# ================== Helpers ==================
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    # Ensure shape (samples, channels)
    if wav.ndim == 2 and wav.shape[0] < wav.shape[1]:
        wav = wav.T
    if wav.dtype == np.float64:
        wav = wav.astype(np.float32)
    sf.write(path.as_posix(), wav, sr)

def read_audio(path: str) -> Tuple[np.ndarray, int]:
    wav, sr = sf.read(path, always_2d=False)
    if wav.dtype == np.float64:
        wav = wav.astype(np.float32)
    return wav, sr
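# soundfile expects (frames, channels) on write, hence the transpose above for
# channels-first (C, N) arrays; float64 is narrowed to float32 before writing.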
def _candidate_commands(py: str, script: Path, ckpt: Path, inp: Path, prompt: str, out: Path) -> List[List[str]]:
    """
    Only support infer_single.py variants.
    Expected primary flags: --ckpt --input --prompt --output
    """
    return [
        [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--prompt", prompt, "--output", out.as_posix()],
    ]
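# Assembled command (illustrative):
#   python SonicMasterRepo/infer_single.py --ckpt weights/model.safetensors \
#       --input tmp_in.wav --prompt "Reduce the reverb" --output tmp_out.wav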
def run_sonicmaster_cli(
    input_wav_path: Path,
    prompt: str,
    out_path: Path,
    progress: Optional[gr.Progress] = None,
) -> Tuple[bool, str]:
    """Run inference via subprocess; returns (ok, message). Uses ONLY infer_single.py."""
    # Ensure a non-empty prompt for the CLI
    prompt = (prompt or "").strip() or "Enhance the input audio"
    if progress:
        progress(0.14, desc="Preparing inference")
    ckpt = get_weights_path(progress=progress)
    script = REPO_DIR / "infer_single.py"
    if not script.exists():
        return False, "infer_single.py not found in the SonicMaster repo."
    py = sys.executable or "python3"
    env = os.environ.copy()
    last_err = ""
    for cidx, cmd in enumerate(_candidate_commands(py, script, ckpt, input_wav_path, prompt, out_path), 1):
        try:
            if progress:
                progress(min(0.25 + 0.10 * cidx, 0.70), desc=f"Running infer_single.py (try {cidx})")
            res = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env)
            if out_path.exists() and out_path.stat().st_size > 0:
                if progress:
                    progress(0.88, desc="Post-processing output")
                return True, (res.stdout or "Inference completed.").strip()
            last_err = "infer_single.py finished but produced no output file."
        except subprocess.CalledProcessError as e:
            snippet = "\n".join(filter(None, [e.stdout or "", e.stderr or ""])).strip()
            last_err = snippet or f"infer_single.py failed with return code {e.returncode}."
        except Exception as e:
            import traceback
            last_err = f"Unexpected error with infer_single.py: {e}\n{traceback.format_exc()}"
    return False, last_err or "All candidate commands failed."
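# _candidate_commands returns a list so alternative flag spellings could be
# appended later without touching this loop; currently there is one variant.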
# ============ GPU path (ZeroGPU) ============
# Restoring the @spaces.GPU decorator here is an assumption implied by the
# original "safe cap for ZeroGPU tiers" note; the 120 s duration is illustrative.
@spaces.GPU(duration=120)  # safe cap for ZeroGPU tiers
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> Tuple[bool, str]:
    try:
        import torch  # noqa: F401  # load torch (if present) inside the GPU worker
    except Exception:
        pass
    from pathlib import Path as _P
    return run_sonicmaster_cli(_P(input_path), prompt, _P(output_path), progress=None)
def _has_cuda() -> bool:
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        return False
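# On ZeroGPU, CUDA is generally only visible inside @spaces.GPU-decorated calls,
# which is why the main process also checks the SPACE_RUNTIME hint (USE_ZEROGPU).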
# ================== Examples @ STARTUP ==================
PROMPTS_10 = [
    "Increase the clarity of this song by emphasizing treble frequencies.",
    "Make this song sound more boomy by amplifying the low end bass frequencies.",
    "Can you make this sound louder, please?",
    "Make the audio smoother and less distorted.",
    "Improve the balance in this song.",
    "Disentangle the left and right channels to give this song a stereo feeling.",
    "Correct the unnatural frequency emphasis. Reduce the roominess or echo.",
    "Raise the level of the vocals, please.",
    "Increase the clarity of this song by emphasizing treble frequencies.",
    "Please, dereverb this audio.",
]

def build_startup_examples() -> List[List[Any]]:
    """Build up to 10 (audio_path, prompt) pairs from the repo at import time."""
    wav_dir = REPO_DIR / "samples" / "inputs"
    wav_paths = sorted(p for p in wav_dir.glob("*.wav") if p.is_file())
    ex = []
    for i, p in enumerate(wav_paths[:10]):
        pr = PROMPTS_10[i] if i < len(PROMPTS_10) else PROMPTS_10[-1]
        ex.append([p.as_posix(), pr])
    return ex

STARTUP_EXAMPLES = build_startup_examples()
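# Each entry pairs a sample file with a prompt (path illustrative):
#   ["SonicMasterRepo/samples/inputs/<name>.wav", "Increase the clarity ..."]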
# ================== Main callback ==================
def enhance_audio_ui(
    audio_path: str,
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Returns (audio, message). On failure, audio=None and message=error text.
    """
    try:
        # Normalize/fallback so --prompt is always passed
        prompt = (prompt or "").strip()
        if not prompt:
            prompt = "Enhance the input audio"
        if not audio_path:
            raise gr.Error("Please upload or select an input audio file.")
        wav, sr = read_audio(audio_path)
        tmp_in = SPACE_ROOT / "tmp_in.wav"
        tmp_out = SPACE_ROOT / "tmp_out.wav"
        if tmp_out.exists():
            try:
                tmp_out.unlink()
            except Exception:
                pass
        if progress:
            progress(0.06, desc="Preparing audio")
        save_temp_wav(wav, sr, tmp_in)
        use_gpu_call = USE_ZEROGPU or _has_cuda()
        if progress:
            progress(0.12, desc="Starting inference")
        if use_gpu_call:
            ok, msg = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
        else:
            ok, msg = run_sonicmaster_cli(tmp_in, prompt, tmp_out, progress=progress)
        if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
            out_wav, out_sr = read_audio(tmp_out.as_posix())
            return (out_sr, out_wav), (msg or "Done.")
        else:
            return None, (msg or "Inference failed without a specific error message.")
    except gr.Error as e:
        return None, str(e)
    except Exception as e:
        import traceback
        return None, f"Unexpected error: {e}\n{traceback.format_exc()}"
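# The (sample_rate, np.ndarray) tuple matches what gr.Audio accepts as a value,
# so the output component can render the enhanced audio directly.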
# ================== Gradio UI ==================
with gr.Blocks(title="SonicMaster: Text-Guided Restoration & Mastering", fill_height=True) as _demo:
    gr.Markdown(
        "## 🎧 SonicMaster\n"
        "Upload audio or pick an example, write a prompt (or leave it blank), then click **Enhance**.\n"
        "If left blank, we'll use a generic prompt: _Enhance the input audio_.\n"
        "- The enhanced audio may take a few seconds to appear after processing. Please wait until the output loads.\n"
        "- On the first run, the model weights must be downloaded, which takes a while.\n"
        "\n"
        "If you enjoy this model, please cite [our paper](https://huggingface.co/papers/2508.03448)."
    )
    with gr.Row():
        with gr.Column(scale=1):
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., Reduce reverb and brighten vocals. (Optional)")
            run_btn = gr.Button("Enhance", variant="primary")
            # Show 10 audio+prompt examples immediately at startup
            if STARTUP_EXAMPLES:
                gr.Examples(
                    examples=STARTUP_EXAMPLES,
                    inputs=[in_audio, prompt],
                    label="Sample Inputs (10)",
                )
            else:
                gr.Markdown("> ⚠️ No sample .wav files found in `samples/inputs/`.")
        with gr.Column(scale=1):
            out_audio = gr.Audio(label="Enhanced Audio (output)")
            status = gr.Textbox(label="Status / Messages", interactive=False, lines=8)
    run_btn.click(
        fn=enhance_audio_ui,
        inputs=[in_audio, prompt],
        outputs=[out_audio, status],
        concurrency_limit=1,
    )
# Expose all common names the supervisor might look for
demo = _demo.queue(max_size=16)
iface = demo
app = demo

# Local debugging only
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)