# ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
import os

os.environ.setdefault("GRADIO_USE_CDN", "true")

import spaces


@spaces.GPU
def _gpu_probe(a: int = 1, b: int = 1) -> int:
    # Never called by the UI; the @spaces.GPU decorator is what satisfies the
    # ZeroGPU startup check, which requires at least one decorated function.
    return a + b
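# Note: on ZeroGPU, `import spaces` must happen before any CUDA-touching import
# (torch, etc.), which is why this block sits at the very top of the file.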

# ---------- Standard imports ----------
from pathlib import Path
from typing import Optional, Tuple, List
import subprocess
import sys
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import hf_hub_download

# ---------- Config ----------
SPACE_ROOT = Path(__file__).parent.resolve()
REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
REPO_URL = "https://github.com/AMAAI-Lab/SonicMaster"
WEIGHTS_REPO = "amaai-lab/SonicMaster"
WEIGHTS_FILE = "model.safetensors"
CACHE_DIR = SPACE_ROOT / "weights"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# ZeroGPU detection (heuristic)
USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
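# SPACE_RUNTIME is an assumption here, not a documented contract; when it is
# absent, _has_cuda() below is the fallback signal for choosing the GPU path.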

# ---------- Lazy resources ----------
_weights_path: Optional[Path] = None
_repo_ready: bool = False


def get_weights_path(progress: Optional[gr.Progress] = None) -> Path:
    """Fetch model weights lazily and cache the resolved path."""
    global _weights_path
    if _weights_path is None:
        if progress:
            progress(0.10, desc="Downloading model weights (first run)")
        # force/resume/symlink flags dropped: recent huggingface_hub resumes
        # downloads by default and has deprecated local_dir_use_symlinks.
        wp = hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
            local_dir=str(CACHE_DIR),
        )
        _weights_path = Path(wp)
    return _weights_path
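# hf_hub_download is idempotent: once the file is in CACHE_DIR, later calls just
# resolve the existing path, so only the first request pays the download cost.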


def ensure_repo(progress: Optional[gr.Progress] = None) -> Path:
    """Clone the inference repo lazily and put it on sys.path."""
    global _repo_ready
    if not _repo_ready:
        if not REPO_DIR.exists():
            if progress:
                progress(0.18, desc="Cloning SonicMaster repo (first run)")
            subprocess.run(
                ["git", "clone", "--depth", "1", REPO_URL, REPO_DIR.as_posix()],
                check=True,
            )
        if REPO_DIR.as_posix() not in sys.path:
            sys.path.append(REPO_DIR.as_posix())
        _repo_ready = True
    return REPO_DIR
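# The sys.path tweak makes repo modules importable in-process; the runner below
# actually shells out via subprocess, so it mainly helps future direct imports.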

# ---------- Audio helpers ----------
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    # Ensure (N, C) shape for soundfile
    if wav.ndim == 1:
        data = wav
    else:
        # (channels, samples) -> (samples, channels)
        data = wav.T if wav.shape[0] < wav.shape[1] else wav
    if data.dtype == np.float64:
        data = data.astype(np.float32)
    sf.write(path.as_posix(), data, sr)


def read_audio(path: str) -> Tuple[np.ndarray, int]:
    wav, sr = sf.read(path, always_2d=False)
    if wav.dtype == np.float64:
        wav = wav.astype(np.float32)
    return wav, sr
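# soundfile reads floating-point data as float64 by default, hence the downcasts:
# float32 halves memory and is what most audio models and Gradio playback expect.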

# ---------- CLI runner ----------
def _candidate_commands(py: str, script: Path, ckpt: Path, inp: Path, prompt: str, out: Path) -> List[List[str]]:
    """Try multiple arg styles commonly found in repos."""
    combos = [
        # infer_single.py (common)
        [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--prompt", prompt, "--output", out.as_posix()],
        [py, script.as_posix(), "--weights", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--out", out.as_posix()],
        # other possible entrypoints
        [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--output", out.as_posix()],
    ]
    return combos
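# These flag spellings are guesses at the upstream CLI, not a confirmed interface;
# if the SonicMaster scripts rename an argument, add the new combo here rather
# than special-casing the runner below.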


def run_sonicmaster_cli(
    input_wav_path: Path,
    prompt: str,
    out_path: Path,
    progress: Optional[gr.Progress] = None,
) -> Tuple[bool, str]:
    """
    Returns (ok, message). Tries each candidate script/arg combination until one
    produces a non-empty output file, capturing stdout/stderr for error reporting.
    """
    if progress:
        progress(0.14, desc="Preparing inference")
    ckpt = get_weights_path(progress=progress)
    repo = ensure_repo(progress=progress)

    # Candidate scripts to try
    script_candidates = [
        repo / "infer_single.py",
        repo / "inference_fullsong.py",
        repo / "inference_ptload_batch.py",
    ]
    scripts = [s for s in script_candidates if s.exists()]
    if not scripts:
        return False, "No inference script found in the repo (expected infer_single.py or similar)."

    py = sys.executable or "python3"
    env = os.environ.copy()  # keep CUDA_VISIBLE_DEVICES etc.
    last_err = ""
    for idx, script in enumerate(scripts, start=1):
        for jdx, cmd in enumerate(_candidate_commands(py, script, ckpt, input_wav_path, prompt, out_path), start=1):
            try:
                if progress:
                    progress(min(0.20 + 0.08 * (idx + jdx), 0.70), desc=f"Running {script.name} (try {idx}.{jdx})")
                res = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env)
                if out_path.exists() and out_path.stat().st_size > 0:
                    if progress:
                        progress(0.88, desc="Post-processing output")
                    # Return any informative stdout as message
                    msg = (res.stdout or "").strip()
                    return True, msg if msg else "Inference completed."
                else:
                    last_err = f"{script.name} produced no output file."
            except subprocess.CalledProcessError as e:
                # Collect stderr/stdout for the user
                snippet = "\n".join(filter(None, [e.stdout or "", e.stderr or ""])).strip()
                last_err = snippet if snippet else f"{script.name} failed with return code {e.returncode}."
            except Exception as e:
                last_err = f"Unexpected error: {e}\n{traceback.format_exc()}"
    return False, last_err or "All candidate commands failed without an error message."
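# Design note: inference runs in a subprocess so the heavy model imports stay out
# of this process; that also avoids touching CUDA at import time in the main
# process, which ZeroGPU does not allow outside @spaces.GPU calls.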


# ---------- REAL GPU function (called only if using ZeroGPU / GPU available) ----------
@spaces.GPU
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> Tuple[bool, str]:
    # The decorator asks ZeroGPU to attach a GPU for the duration of this call.
    try:
        # Initialize CUDA inside the GPU context
        import torch  # noqa: F401
    except Exception:
        pass
    from pathlib import Path as _P
    return run_sonicmaster_cli(_P(input_path), prompt, _P(output_path), progress=None)


def _has_cuda() -> bool:
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        return False
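# torch is imported lazily in both helpers above so the app still boots when torch
# is absent or broken; the try/except turns any import/CUDA failure into "no GPU".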


# ---------- UI callback ----------
def enhance_audio_ui(
    audio_path: str,
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
    """
    Returns (audio, message). On failure, audio=None and message=error text.
    """
    try:
        if not prompt:
            raise gr.Error("Please provide a text prompt.")
        if not audio_path:
            raise gr.Error("Please upload or select an input audio file.")

        wav, sr = read_audio(audio_path)
        tmp_in = SPACE_ROOT / "tmp_in.wav"
        tmp_out = SPACE_ROOT / "tmp_out.wav"
        if tmp_out.exists():
            try:
                tmp_out.unlink()
            except Exception:
                pass

        if progress:
            progress(0.06, desc="Preparing audio")
        save_temp_wav(wav, sr, tmp_in)
        # Choose execution path: prefer real GPU if available, else CPU
        use_gpu_call = USE_ZEROGPU or _has_cuda()
        if progress:
            progress(0.12, desc="Starting inference")
        if use_gpu_call:
            ok, msg = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
        else:
            ok, msg = run_sonicmaster_cli(tmp_in, prompt, tmp_out, progress=progress)

        if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
            # gr.Audio expects (sample_rate, np.ndarray) when fed numpy data,
            # so swap the (wav, sr) order that read_audio returns.
            out_wav, out_sr = read_audio(tmp_out.as_posix())
            return ((out_sr, out_wav), msg or "Done.")
        else:
            # On failure: DON'T echo input audio; return None and the error message
            if not msg:
                msg = "Inference failed without a specific error message."
            return (None, msg.strip())
    except gr.Error as e:
        return (None, str(e))
    except Exception as e:
        return (None, f"Unexpected error: {e}\n{traceback.format_exc()}")


# ---------- Gradio UI ----------
PROMPT_EXAMPLES = [
    ["Increase the clarity of this song by emphasizing treble frequencies."],
    ["Make this song sound more boomy by amplifying the low end bass frequencies."],
    ["Make the audio smoother and less distorted."],
    ["Improve the balance in this song."],
    ["Reduce roominess/echo (dereverb)."],
    ["Raise the level of the vocals."],
    ["Give the song a wider stereo image."],
]

with gr.Blocks(title="SonicMaster: Text-Guided Restoration & Mastering", fill_height=True) as demo:
    gr.Markdown(
        "## 🎧 SonicMaster\nUpload audio, enter a prompt, then click **Enhance**.\n"
        "- Progress appears below during the first run (weights/repo download).\n"
        "- If something fails, you'll see the **error message** instead of the input audio."
    )

    with gr.Row():
        with gr.Column(scale=1):
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., Reduce reverb and brighten the vocals.")
            run_btn = gr.Button("🚀 Enhance", variant="primary")
            gr.Examples(
                examples=PROMPT_EXAMPLES,
                inputs=[prompt],  # prompt-only examples to avoid heavy file ops at startup
                label="Prompt Examples",
            )
        with gr.Column(scale=1):
            out_audio = gr.Audio(label="Enhanced Audio (output)")
            status = gr.Textbox(label="Status / Messages", interactive=False, lines=6)

    # On click, return audio + message
    run_btn.click(
        fn=enhance_audio_ui,
        inputs=[in_audio, prompt],
        outputs=[out_audio, status],
        concurrency_limit=1,
    )

# Queue BEFORE mounting so the mounted app is ready immediately.
# Gradio 4 removed queue(concurrency_count=...); per-event concurrency is the
# concurrency_limit=1 set on the click handler above.
demo.queue(max_size=16)

# ---------- FastAPI mount & health ----------
from fastapi import FastAPI, Request
from starlette.responses import PlainTextResponse

try:
    from starlette.requests import ClientDisconnect  # canonical location
except Exception:
    from starlette.exceptions import ClientDisconnect  # some versions re-export here

app = FastAPI()


@app.get("/health")
def _health():
    # Lightweight liveness probe, reachable even while Gradio is busy.
    return {"ok": True}


async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
    # 499 is the conventional (nginx-style) "client closed request" status.
    return PlainTextResponse("Client disconnected", status_code=499)


# Register the handler; without this, the function above is never invoked.
app.add_exception_handler(ClientDisconnect, client_disconnect_handler)

# Mount Gradio at root (Spaces looks here)
app = gr.mount_gradio_app(app, demo, path="/")
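# With the mount in place, the Space serves the Gradio UI at "/" and the JSON
# health check at "/health" (e.g. `curl http://localhost:7860/health`).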

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)