SonicMaster / app.py
ambujm22's picture
Update app.py
e22a58e verified
raw
history blame
11.2 kB
# ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
import os

# Configure Gradio's CDN flag before `spaces` (and later `gradio`) is
# imported; a value already present in the environment is left alone.
if "GRADIO_USE_CDN" not in os.environ:
    os.environ["GRADIO_USE_CDN"] = "true"

import spaces


@spaces.GPU(duration=10)
def _gpu_probe(a: int = 1, b: int = 1) -> int:
    """Placeholder GPU-decorated function that is never called.

    It exists only so the ZeroGPU startup check finds a ``@spaces.GPU``
    function in this module. Returns ``a + b``.
    """
    total = a + b
    return total
# ---------- Standard imports ----------
from pathlib import Path
from typing import Optional, Tuple, List
import subprocess
import sys
import traceback
import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import hf_hub_download
# ---------- Config ----------
SPACE_ROOT = Path(__file__).parent.resolve()
REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
REPO_URL = "https://github.com/AMAAI-Lab/SonicMaster"
WEIGHTS_REPO = "amaai-lab/SonicMaster"
WEIGHTS_FILE = "model.safetensors"
CACHE_DIR = SPACE_ROOT / "weights"
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# ZeroGPU detection (heuristic)
USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
# ---------- Lazy resources ----------
_weights_path: Optional[Path] = None
_repo_ready: bool = False
def get_weights_path(progress: Optional[gr.Progress] = None) -> Path:
    """Return the local path of the SonicMaster checkpoint, downloading it once.

    The first call downloads ``WEIGHTS_FILE`` from the Hub repo
    ``WEIGHTS_REPO`` into ``CACHE_DIR`` and caches the resulting path in the
    module-level ``_weights_path``; later calls return that path directly.

    Args:
        progress: Optional Gradio progress tracker, updated before downloading.

    Returns:
        Path to the downloaded checkpoint file.

    Raises:
        Any exception raised by ``hf_hub_download`` propagates unchanged.
    """
    global _weights_path
    if _weights_path is None:
        # Compare against None explicitly: gr.Progress defines __len__, so its
        # truthiness is not a reliable "was a tracker passed?" test.
        if progress is not None:
            progress(0.10, desc="Downloading model weights (first run)")
        # `local_dir_use_symlinks` and `resume_download` are deprecated in
        # recent huggingface_hub releases (local_dir downloads are plain files
        # and resume automatically), so they are no longer passed.
        wp = hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
            local_dir=str(CACHE_DIR),
            force_download=False,
        )
        _weights_path = Path(wp)
    return _weights_path
def ensure_repo(progress: Optional[gr.Progress] = None) -> Path:
    """Make sure the SonicMaster inference repo is cloned and on ``sys.path``.

    On the first call, shallow-clones ``REPO_URL`` into ``REPO_DIR`` unless
    that directory already exists, then appends it to ``sys.path``. The
    module-level ``_repo_ready`` flag makes later calls no-ops.

    Args:
        progress: Optional Gradio progress tracker, updated before cloning.

    Returns:
        ``REPO_DIR``.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
        FileNotFoundError: if the ``git`` executable cannot be found.
    """
    global _repo_ready
    if not _repo_ready:
        # NOTE(review): an existing but incomplete REPO_DIR (e.g. from an
        # interrupted clone) is not re-cloned.
        if not REPO_DIR.exists():
            # Explicit None check: truthiness of gr.Progress is not reliable.
            if progress is not None:
                progress(0.18, desc="Cloning SonicMaster repo (first run)")
            subprocess.run(
                ["git", "clone", "--depth", "1", REPO_URL, REPO_DIR.as_posix()],
                check=True,
            )
        repo_str = REPO_DIR.as_posix()
        if repo_str not in sys.path:
            sys.path.append(repo_str)
        _repo_ready = True
    return REPO_DIR
# ---------- Audio helpers ----------
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    """Write ``wav`` to ``path`` at sample rate ``sr`` using soundfile.

    soundfile expects frames along axis 0. A multi-dimensional array whose
    first axis is shorter than its second is assumed to be laid out as
    (channels, samples) and is transposed. float64 data is narrowed to
    float32 before writing.
    """
    data = wav
    # NOTE(review): shape heuristic — a (samples, channels) clip with fewer
    # samples than channels would be transposed wrongly; confirm callers never
    # pass such tiny arrays.
    if wav.ndim != 1 and wav.shape[0] < wav.shape[1]:
        data = wav.T
    if data.dtype == np.float64:
        data = data.astype(np.float32)
    sf.write(path.as_posix(), data, sr)
def read_audio(path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file with soundfile.

    Returns ``(samples, sample_rate)``, in that order. With
    ``always_2d=False``, mono files come back as a 1-D array. float64
    samples are narrowed to float32; other dtypes are returned unchanged.
    """
    data, rate = sf.read(path, always_2d=False)
    narrowed = data.astype(np.float32) if data.dtype == np.float64 else data
    return narrowed, rate
# ---------- CLI runner ----------
def _candidate_commands(py: str, script: Path, ckpt: Path, inp: Path, prompt: str, out: Path) -> List[List[str]]:
"""Try multiple arg styles commonly found in repos."""
combos = [
# infer_single.py (common)
[py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--prompt", prompt, "--output", out.as_posix()],
[py, script.as_posix(), "--weights", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--out", out.as_posix()],
# other possible entrypoints
[py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--output", out.as_posix()],
]
return combos
def run_sonicmaster_cli(
    input_wav_path: Path,
    prompt: str,
    out_path: Path,
    progress: Optional[gr.Progress] = None,
) -> Tuple[bool, str]:
    """Run the SonicMaster inference CLI, trying several scripts and flag styles.

    Each inference script present in the cloned repo is tried with every
    argument layout from ``_candidate_commands``. The first attempt that exits
    with status 0 and leaves a non-empty file at ``out_path`` wins.

    Args:
        input_wav_path: WAV file to enhance.
        prompt: Text instruction forwarded to the script.
        out_path: File the script is asked to write its output to.
        progress: Optional Gradio progress tracker.

    Returns:
        ``(ok, message)``. On success ``message`` is the script's stripped
        stdout, or a default note. On failure it describes the last failed
        attempt (captured stdout/stderr, return code, or exception).

    Raises:
        Errors from ``get_weights_path`` / ``ensure_repo`` (download or clone
        failures) are not caught here.
    """
    # Explicit None checks throughout: truthiness of gr.Progress is unreliable.
    if progress is not None:
        progress(0.14, desc="Preparing inference")
    ckpt = get_weights_path(progress=progress)
    repo = ensure_repo(progress=progress)

    # Candidate scripts to try, in order of preference.
    script_candidates = [
        repo / "infer_single.py",
        repo / "inference_fullsong.py",
        repo / "inference_ptload_batch.py",
    ]
    scripts = [s for s in script_candidates if s.exists()]
    if not scripts:
        return False, "No inference script found in the repo (expected infer_single.py or similar)."

    py = sys.executable or "python3"
    env = os.environ.copy()  # keep CUDA_VISIBLE_DEVICES etc.
    last_err = ""
    for idx, script in enumerate(scripts, start=1):
        commands = _candidate_commands(py, script, ckpt, input_wav_path, prompt, out_path)
        for jdx, cmd in enumerate(commands, start=1):
            try:
                # Drop any output left by an earlier (failed) attempt so it can
                # never be mistaken for this attempt's result.
                if out_path.exists():
                    out_path.unlink()
                if progress is not None:
                    progress(min(0.20 + 0.08 * (idx + jdx), 0.70), desc=f"Running {script.name} (try {idx}.{jdx})")
                res = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env)
                if out_path.exists() and out_path.stat().st_size > 0:
                    if progress is not None:
                        progress(0.88, desc="Post-processing output")
                    # Surface any informative stdout as the status message.
                    msg = (res.stdout or "").strip()
                    return True, msg or "Inference completed."
                last_err = f"{script.name} produced no output file."
            except subprocess.CalledProcessError as e:
                # Collect stderr/stdout for the user.
                snippet = "\n".join(filter(None, [e.stdout or "", e.stderr or ""])).strip()
                last_err = snippet if snippet else f"{script.name} failed with return code {e.returncode}."
            except Exception as e:
                last_err = f"Unexpected error: {e}\n{traceback.format_exc()}"
    return False, last_err or "All candidate commands failed without an error message."
# ---------- REAL GPU function (called only if using ZeroGPU / GPU available) ----------
@spaces.GPU(duration=180)
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> Tuple[bool, str]:
    """GPU-decorated wrapper around ``run_sonicmaster_cli``.

    Decorated with ``spaces.GPU(duration=180)`` so that, on ZeroGPU Spaces, a
    GPU is attached for the call. Progress reporting is disabled
    (``progress=None``).

    Returns:
        ``(ok, message)`` exactly as produced by ``run_sonicmaster_cli``.
    """
    # Best-effort torch import inside the GPU-decorated call; failures ignored.
    try:
        import torch  # noqa: F401
    except Exception:
        pass
    src, dst = Path(input_path), Path(output_path)
    return run_sonicmaster_cli(src, prompt, dst, progress=None)
def _has_cuda() -> bool:
try:
import torch
return torch.cuda.is_available()
except Exception:
return False
# ---------- UI callback ----------
def enhance_audio_ui(
    audio_path: str,
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
    """Gradio click handler: enhance the uploaded audio according to ``prompt``.

    The input is re-encoded to ``tmp_in.wav`` and inference writes
    ``tmp_out.wav``, both under ``SPACE_ROOT``. Inference goes through the
    GPU-decorated path when ZeroGPU or CUDA is detected, otherwise through
    the plain CLI runner.

    Returns:
        ``(audio, message)``. ``audio`` is a ``(sample_rate, samples)`` tuple
        for the output ``gr.Audio`` component. On failure ``audio`` is None
        and ``message`` holds the error text; errors are returned, not raised.
    """
    try:
        prompt = (prompt or "").strip()
        if not prompt:
            raise gr.Error("Please provide a text prompt.")
        if not audio_path:
            raise gr.Error("Please upload or select an input audio file.")

        wav, sr = read_audio(audio_path)
        tmp_in = SPACE_ROOT / "tmp_in.wav"
        tmp_out = SPACE_ROOT / "tmp_out.wav"
        # Best-effort removal of a previous run's output.
        try:
            tmp_out.unlink()
        except OSError:
            pass

        # Explicit None checks: truthiness of gr.Progress is not reliable
        # (it defines __len__ over tracked iterables).
        if progress is not None:
            progress(0.06, desc="Preparing audio")
        save_temp_wav(wav, sr, tmp_in)

        # Choose execution path: prefer real GPU if available, else CPU.
        use_gpu_call = USE_ZEROGPU or _has_cuda()
        if progress is not None:
            progress(0.12, desc="Starting inference")
        if use_gpu_call:
            ok, msg = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
        else:
            ok, msg = run_sonicmaster_cli(tmp_in, prompt, tmp_out, progress=progress)

        if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
            out_wav, out_sr = read_audio(tmp_out.as_posix())
            # read_audio returns (data, sr) but gr.Audio expects (sr, data).
            return (out_sr, out_wav), (msg or "Done.")

        # On failure: DON'T echo input audio — return None and the error message.
        return None, (msg or "Inference failed without a specific error message.").strip()
    except gr.Error as e:
        return (None, str(e))
    except Exception as e:
        return (None, f"Unexpected error: {e}\n{traceback.format_exc()}")
# ---------- Gradio UI ----------
PROMPT_EXAMPLES = [
    ["Increase the clarity of this song by emphasizing treble frequencies."],
    ["Make this song sound more boomy by amplifying the low end bass frequencies."],
    ["Make the audio smoother and less distorted."],
    ["Improve the balance in this song."],
    ["Reduce roominess/echo (dereverb)."],
    ["Raise the level of the vocals."],
    ["Give the song a wider stereo image."],
]

with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
    gr.Markdown(
        "## 🎧 SonicMaster\nUpload audio, enter a prompt, then click **Enhance**.\n"
        "- Progress appears below during the first run (weights/repo download).\n"
        "- If something fails, you'll see the **error message** instead of the input audio."
    )
    with gr.Row():
        with gr.Column(scale=1):
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., Reduce reverb and brighten the vocals.")
            # Label restored from mis-encoded "πŸš€" to the rocket emoji.
            run_btn = gr.Button("🚀 Enhance", variant="primary")
            gr.Examples(
                examples=PROMPT_EXAMPLES,
                inputs=[prompt],  # prompt-only examples to avoid heavy file ops at startup
                label="Prompt Examples",
            )
        with gr.Column(scale=1):
            out_audio = gr.Audio(label="Enhanced Audio (output)")
            status = gr.Textbox(label="Status / Messages", interactive=False, lines=6)

    # On click, return (audio, message) into the output column.
    run_btn.click(
        fn=enhance_audio_ui,
        inputs=[in_audio, prompt],
        outputs=[out_audio, status],
        concurrency_limit=1,
    )

# Queue BEFORE mounting so the mounted app is ready immediately.
# Gradio 4+ no longer accepts `queue(concurrency_count=...)`; concurrency is
# set per listener (above) and via `default_concurrency_limit`.
demo = demo.queue(default_concurrency_limit=1, max_size=16)
# ---------- FastAPI mount & health ----------
from fastapi import FastAPI, Request
from starlette.responses import PlainTextResponse

try:
    # starlette.requests is where ClientDisconnect is defined; try it first.
    from starlette.requests import ClientDisconnect
except ImportError:
    # Defensive fallback for layouts that expose it elsewhere.
    from starlette.exceptions import ClientDisconnect  # type: ignore[no-redef]

app = FastAPI()


@app.get("/health")
def _health():
    """Liveness probe; always returns ``{"ok": True}``."""
    return {"ok": True}


@app.exception_handler(ClientDisconnect)
async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
    """Respond to ClientDisconnect with a plain-text 499 status."""
    return PlainTextResponse("Client disconnected", status_code=499)


# Mount Gradio at root (Spaces looks here).
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)