Spaces:
Running
on
Zero
Running
on
Zero
# ========== MUST BE FIRST: Gradio SDK entry + ZeroGPU probes ==========
import os

os.environ.setdefault("GRADIO_USE_CDN", "true")

# Optional: 'spaces' is present on Hugging Face Spaces; harmless to try locally.
try:
    import spaces
except Exception:
    class _DummySpaces:
        """Local stand-in for the `spaces` package: `GPU` is a no-op decorator."""

        def GPU(self, *args, **kwargs):
            # Support both decorator forms of the real package:
            #   @spaces.GPU          -> called with the function itself
            #   @spaces.GPU(...)     -> called with options, returns a decorator
            if args and callable(args[0]):
                return args[0]

            def deco(fn):
                return fn

            return deco

    spaces = _DummySpaces()


# Public GPU entry points. ZeroGPU discovers GPU functions through the
# `spaces.GPU` decorator (names alone are not detected), so the probes are
# decorated. The parenthesised form works with the real package and the fallback.
@spaces.GPU()
def gpu_probe(a: int = 1, b: int = 1):
    """Trivial GPU-decorated probe; returns ``a + b``."""
    return a + b


@spaces.GPU()
def gpu_echo(x: str = "ok"):
    """Trivial GPU-decorated probe; returns ``x`` unchanged."""
    return x
| # ========== Standard imports ========== | |
| import sys | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Tuple, Optional, List, Dict, Any | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| from huggingface_hub import hf_hub_download | |
# ZeroGPU runtime hint (safe on CPU).
# SPACE_RUNTIME=zerogpu is the original opt-in switch. SPACES_ZERO_GPU is
# assumed to be the flag the `spaces` package reads on ZeroGPU hardware
# (NOTE(review): confirm against the current `spaces` release).
USE_ZEROGPU = (
    os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
    or os.getenv("SPACES_ZERO_GPU", "").strip().lower() in ("1", "true", "yes")
)
SPACE_ROOT = Path(__file__).parent.resolve()
REPO_DIR = SPACE_ROOT / "SonicMasterRepo"  # lazy clone target of REPO_URL
REPO_URL = "https://github.com/AMAAI-Lab/SonicMaster"
WEIGHTS_REPO = "amaai-lab/SonicMaster"  # Hub repo holding the checkpoint
WEIGHTS_FILE = "model.safetensors"
CACHE_DIR = SPACE_ROOT / "weights"  # local download dir for the checkpoint
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# ========== Lazy resources (no heavy work at import) ==========
_weights_path: Optional[Path] = None  # memoised result of get_weights_path()
_repo_ready: bool = False  # set by ensure_repo() once the repo is on sys.path


def get_weights_path(progress: Optional[gr.Progress] = None) -> Path:
    """Download (first call only) and return the local path of the model weights.

    ``WEIGHTS_FILE`` from the Hub repo ``WEIGHTS_REPO`` is placed under
    ``CACHE_DIR``. The resolved path is cached in ``_weights_path`` so later
    calls return immediately.

    Args:
        progress: Optional Gradio progress tracker, updated before downloading.

    Returns:
        Path to the local weights file.
    """
    global _weights_path
    if _weights_path is not None:
        return _weights_path
    if progress:
        progress(0.10, desc="Downloading model weights (first run)")
    # `resume_download` and `local_dir_use_symlinks` are deprecated in
    # huggingface_hub (>= 0.23): downloads resume automatically and files in
    # `local_dir` are stored as real files, so both arguments are omitted.
    downloaded = hf_hub_download(
        repo_id=WEIGHTS_REPO,
        filename=WEIGHTS_FILE,
        local_dir=str(CACHE_DIR),
        force_download=False,
    )
    _weights_path = Path(downloaded)
    return _weights_path
def ensure_repo(progress: Optional[gr.Progress] = None) -> Path:
    """Make the SonicMaster source tree importable, cloning it on first use.

    A shallow ``git clone`` of ``REPO_URL`` into ``REPO_DIR`` runs only when
    that directory is missing. The directory is then appended to ``sys.path``
    (once), and ``_repo_ready`` short-circuits later calls. A failed clone
    raises ``subprocess.CalledProcessError``.

    Returns:
        ``REPO_DIR``.
    """
    global _repo_ready
    if _repo_ready:
        return REPO_DIR
    if not REPO_DIR.exists():
        if progress:
            progress(0.18, desc="Cloning SonicMaster repo (first run)")
        clone_cmd = ["git", "clone", "--depth", "1", REPO_URL, REPO_DIR.as_posix()]
        subprocess.run(clone_cmd, check=True)
    repo_entry = REPO_DIR.as_posix()
    if repo_entry not in sys.path:
        sys.path.append(repo_entry)
    _repo_ready = True
    return REPO_DIR
# ========== Helpers ==========
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    """Write ``wav`` to ``path`` at sample rate ``sr`` using soundfile.

    A 2-D array with fewer rows than columns is treated as channels-first and
    transposed to (samples, channels). float64 data is downcast to float32.
    """
    is_channels_first = wav.ndim == 2 and wav.shape[0] < wav.shape[1]
    data = wav.T if is_channels_first else wav
    if data.dtype == np.float64:
        data = data.astype(np.float32)
    sf.write(path.as_posix(), data, sr)
def read_audio(path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file with soundfile as ``(samples, sample_rate)``.

    ``always_2d=False`` is passed to ``sf.read``. float64 samples are
    converted to float32; any other dtype is returned as read.
    """
    samples, rate = sf.read(path, always_2d=False)
    samples = samples.astype(np.float32) if samples.dtype == np.float64 else samples
    return samples, rate
| def _candidate_commands(py: str, script: Path, ckpt: Path, inp: Path, prompt: str, out: Path) -> List[List[str]]: | |
| # Try common flag layouts | |
| return [ | |
| [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--prompt", prompt, "--output", out.as_posix()], | |
| [py, script.as_posix(), "--weights",ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--out", out.as_posix()], | |
| [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--output", out.as_posix()], | |
| ] | |
def run_sonicmaster_cli(
    input_wav_path: Path,
    prompt: str,
    out_path: Path,
    progress: Optional[gr.Progress] = None,
) -> Tuple[bool, str]:
    """Run the SonicMaster inference scripts via subprocess; return ``(ok, message)``.

    Each inference script present in the cloned repo is tried with each argv
    layout from ``_candidate_commands``. The first attempt that exits with
    status 0 *and* leaves a non-empty file at ``out_path`` wins.

    Args:
        input_wav_path: Input WAV handed to the script.
        prompt: Text instruction handed to the script.
        out_path: Output path the script is asked to write.
        progress: Optional Gradio progress tracker.

    Returns:
        On success: ``(True, script stdout)``, or ``(True, "Inference completed.")``
        when stdout is empty. On failure: ``(False, message describing the last
        failed attempt)``.
    """
    import traceback

    if progress:
        progress(0.14, desc="Preparing inference")
    ckpt = get_weights_path(progress=progress)
    repo = ensure_repo(progress=progress)
    candidates = [repo / "infer_single.py", repo / "inference_fullsong.py", repo / "inference_ptload_batch.py"]
    scripts = [s for s in candidates if s.exists()]
    if not scripts:
        return False, "No inference script found in the repo (expected infer_single.py or similar)."
    py = sys.executable or "python3"
    env = os.environ.copy()
    last_err = ""
    for sidx, script in enumerate(scripts, 1):
        for cidx, cmd in enumerate(_candidate_commands(py, script, ckpt, input_wav_path, prompt, out_path), 1):
            try:
                # Drop output left behind by an earlier failed attempt. Otherwise a
                # later attempt that exits 0 without writing anything would be
                # mistaken for a success.
                out_path.unlink(missing_ok=True)
                if progress:
                    progress(min(0.20 + 0.08 * (sidx + cidx), 0.70), desc=f"Running {script.name} (try {sidx}.{cidx})")
                # argv list with no shell: the user prompt is passed as a single argument.
                res = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env)
                if out_path.exists() and out_path.stat().st_size > 0:
                    if progress:
                        progress(0.88, desc="Post-processing output")
                    # Whitespace-only stdout still yields a useful message.
                    return True, (res.stdout or "").strip() or "Inference completed."
                last_err = f"{script.name} produced no output file."
            except subprocess.CalledProcessError as e:
                # A non-matching flag layout presumably also ends up here
                # (non-zero exit from the script's argument parser).
                snippet = "\n".join(filter(None, [e.stdout or "", e.stderr or ""])).strip()
                last_err = snippet if snippet else f"{script.name} failed with return code {e.returncode}."
            except Exception as e:
                last_err = f"Unexpected error: {e}\n{traceback.format_exc()}"
    return False, last_err or "All candidate commands failed."
# ========== GPU path (called only if ZeroGPU/GPU available) ==========
# `spaces.GPU` is what makes ZeroGPU attach a GPU for the duration of the call.
# Without it, this function runs without a GPU on ZeroGPU hardware. Locally the
# file's `spaces` fallback turns the decorator into a no-op. The parenthesised
# form works with both the real package and that fallback.
# duration: assumed upper bound (seconds) for one CLI run — tune to observed runtime.
@spaces.GPU(duration=120)
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> Tuple[bool, str]:
    """Run the SonicMaster CLI inside a GPU-allocated call.

    Arguments and results are kept to simple picklable types (str paths and a
    ``(bool, str)`` tuple), since ZeroGPU may execute the call in a separate
    worker process. No progress tracker is passed across that boundary.

    Returns:
        ``(ok, message)`` exactly as returned by ``run_sonicmaster_cli``.
    """
    return run_sonicmaster_cli(Path(input_path), prompt, Path(output_path), progress=None)
| def _has_cuda() -> bool: | |
| try: | |
| import torch | |
| return torch.cuda.is_available() | |
| except Exception: | |
| return False | |
# ========== Examples (lazy) ==========
# Suggested prompts for the bundled sample audios, paired by position:
# set_example_selection() uses PROMPTS_10[i] for the i-th sample file and
# falls back to the last entry when the index is past the end of this list.
# NOTE(review): entries 0 and 8 are identical ("Increase the clarity ...");
# one of them may have been meant to be a different prompt — confirm against
# the sample files in the repo.
PROMPTS_10 = [
    "Increase the clarity of this song by emphasizing treble frequencies.",
    "Make this song sound more boomy by amplifying the low end bass frequencies.",
    "Can you make this sound louder, please?",
    "Make the audio smoother and less distorted.",
    "Improve the balance in this song.",
    "Disentangle the left and right channels to give this song a stereo feeling.",
    "Correct the unnatural frequency emphasis. Reduce the roominess or echo.",
    "Raise the level of the vocals, please.",
    "Increase the clarity of this song by emphasizing treble frequencies.",
    "Please, dereverb this audio.",
]
def list_example_files(progress: Optional[gr.Progress] = None) -> List[str]:
    """Return POSIX paths of at most 10 ``.wav`` files in the repo's ``samples/inputs``.

    The files are sorted by path. The repo is cloned on first use via
    ``ensure_repo``, and a missing directory yields an empty list.
    """
    inputs_dir = ensure_repo(progress=progress) / "samples" / "inputs"
    found = [candidate for candidate in inputs_dir.glob("*.wav") if candidate.is_file()]
    found.sort()
    return [wav.as_posix() for wav in found[:10]]
def load_examples(_: Any = None, progress=gr.Progress()) -> Dict[str, Any]:
    """Collect the bundled sample audios for the sample picker.

    Args:
        _: Ignored; lets the function accept an optional event payload.
        progress: Gradio progress tracker forwarded to the lazy repo clone.

    Returns:
        A dict that always has the same three keys:
        ``"choices"`` (dropdown labels such as ``"01 — name.wav"``, numbered
        from 1), ``"paths"`` (file paths, index-aligned with ``choices``) and
        ``"status"`` (a human-readable summary).
    """
    paths = list_example_files(progress=progress)
    if not paths:
        return {
            "choices": [],
            "paths": [],
            "status": "No sample .wav files found in repo/samples/inputs.",
        }
    # The leading 1-based number is what set_example_selection() parses back.
    labels = [f"{i:02d} — {Path(p).name}" for i, p in enumerate(paths, 1)]
    return {
        "choices": labels,
        "paths": paths,
        "status": f"Loaded {len(paths)} sample audios.",
    }
def set_example_selection(idx_label: str, paths: List[str]) -> Tuple[str, str]:
    """Map a picked dropdown label to ``(audio_path, suggested_prompt)``.

    Labels look like ``"01 — file.wav"``; the leading number is 1-based. An
    unparsable label falls back to the first sample, and the index is clamped
    to the available paths. Positions past the end of ``PROMPTS_10`` reuse its
    last prompt. Returns ``("", "")`` when no label or no paths are given.
    """
    if not (idx_label and paths):
        return "", ""
    try:
        position = int(idx_label.split()[0]) - 1
    except Exception:
        position = 0
    position = min(max(position, 0), len(paths) - 1)
    suggested = PROMPTS_10[position] if position < len(PROMPTS_10) else PROMPTS_10[-1]
    return paths[position], suggested
# ========== Gradio callback ==========
def enhance_audio_ui(
    audio_path: str,
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
    """Gradio handler for the Enhance button.

    Steps:
    1. Re-encode the input as WAV in a private temporary directory.
    2. Run SonicMaster, through ``enhance_on_gpu`` when ZeroGPU/CUDA is
       detected and directly through ``run_sonicmaster_cli`` otherwise.
    3. Load the result back into memory.

    The temporary directory is removed before returning.

    Args:
        audio_path: Path of the uploaded/selected input audio.
        prompt: Text instruction for the model.
        progress: Gradio progress tracker.

    Returns:
        ``((sample_rate, samples), message)`` on success, or
        ``(None, error text)`` on failure. Errors are reported, never raised.
    """
    import shutil
    import tempfile
    import traceback

    work_dir: Optional[Path] = None
    try:
        if not (prompt and prompt.strip()):
            raise gr.Error("Please provide a text prompt.")
        if not audio_path:
            raise gr.Error("Please upload or select an input audio file.")
        wav, sr = read_audio(audio_path)
        # Per-request scratch directory. Fixed file names shared by every request
        # would let concurrent or leftover runs clobber each other's audio.
        work_dir = Path(tempfile.mkdtemp(prefix="sonicmaster_"))
        tmp_in = work_dir / "tmp_in.wav"
        tmp_out = work_dir / "tmp_out.wav"
        if progress:
            progress(0.06, desc="Preparing audio")
        save_temp_wav(wav, sr, tmp_in)
        use_gpu_call = USE_ZEROGPU or _has_cuda()
        if progress:
            progress(0.12, desc="Starting inference")
        if use_gpu_call:
            ok, msg = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
        else:
            ok, msg = run_sonicmaster_cli(tmp_in, prompt, tmp_out, progress=progress)
        if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
            out_wav, out_sr = read_audio(tmp_out.as_posix())
            return (out_sr, out_wav), (msg or "Done.")
        return None, (msg or "Inference failed without a specific error message.")
    except gr.Error as e:
        return None, str(e)
    except Exception as e:
        return None, f"Unexpected error: {e}\n{traceback.format_exc()}"
    finally:
        if work_dir is not None:
            # The output was already read into memory above; cleanup is best-effort.
            shutil.rmtree(work_dir, ignore_errors=True)
# ========== Gradio UI ==========
with gr.Blocks(title="SonicMaster — Text-Guided Restoration & Mastering", fill_height=True) as _demo:
    gr.Markdown(
        "## 🎧 SonicMaster\n"
        "Upload audio or **load sample audios**, write a prompt, then click **Enhance**.\n"
        "- On failure, the **Status** box shows the exact error."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Sample loader (lazy: the repo is cloned only when the button is clicked)
            with gr.Accordion("Sample audios (10)", open=False):
                load_btn = gr.Button("📥 Load 10 sample audios")
                samples_dropdown = gr.Dropdown(choices=[], label="Pick a sample", interactive=True)
                samples_state = gr.State([])  # sample file paths, index-aligned with the dropdown labels
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., Reduce reverb and brighten vocals.")
            run_btn = gr.Button("🚀 Enhance", variant="primary")
            # Optional quick prompt examples (text-only)
            gr.Examples(
                examples=[[p] for p in [
                    "Reduce roominess/echo (dereverb).",
                    "Raise the level of the vocals.",
                    "Give the song a wider stereo image.",
                ]],
                inputs=[prompt],
                label="Prompt Examples",
            )
        with gr.Column(scale=1):
            out_audio = gr.Audio(label="Enhanced Audio (output)")
            status = gr.Textbox(label="Status / Messages", interactive=False, lines=8)

    # --- Wire up the sample loader ---
    # 1) Load samples on button click (lazy clone).
    #    `load_examples` returns a dict, but an event handler must return one
    #    value per output component. Also, `.then()` does NOT receive the
    #    previous step's return value. This adapter therefore unpacks the dict
    #    into all three outputs within a single event.
    def _populate_samples(progress=gr.Progress()):
        try:
            result = load_examples(progress=progress)
        except Exception as e:  # e.g. the lazy `git clone` failed
            return gr.update(), [], f"Failed to load sample audios: {e}"
        return (
            gr.update(choices=result.get("choices", [])),
            result.get("paths", []),
            result.get("status", ""),
        )

    load_btn.click(
        fn=_populate_samples,
        inputs=None,
        outputs=[samples_dropdown, samples_state, status],
    )

    # 2) When a sample is chosen, set audio path + suggested prompt.
    #    An empty selection leaves the Audio/Text components untouched instead
    #    of pushing "" into them (which would wipe the user's current input).
    def _select_sample(label, paths):
        if not label or not paths:
            return gr.update(), gr.update()
        return set_example_selection(label, paths)

    samples_dropdown.change(
        fn=_select_sample,
        inputs=[samples_dropdown, samples_state],
        outputs=[in_audio, prompt],
    )

    # --- Enhance button ---
    run_btn.click(
        fn=enhance_audio_ui,
        inputs=[in_audio, prompt],
        outputs=[out_audio, status],
        concurrency_limit=1,
    )
# Expose the queued app under every name a launcher/supervisor might look up;
# all three names refer to the same queued Blocks object.
demo = iface = app = _demo.queue(max_size=16)
# Local debugging only: serve on all interfaces, port 7860.
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)