File size: 11,162 Bytes
e22a58e
3c1f2a9
 
 
 
08f8861
9c3c7f1
edb1292
e22a58e
 
 
edb1292
3c1f2a9
 
e22a58e
 
 
 
3c1f2a9
5f61a8c
3c1f2a9
 
 
 
e22a58e
3c1f2a9
 
e22a58e
3c1f2a9
 
 
 
 
e22a58e
 
 
 
 
 
 
 
 
 
 
 
 
 
3c1f2a9
 
e22a58e
3c1f2a9
 
 
 
e22a58e
 
3c1f2a9
e22a58e
 
 
 
 
 
 
 
 
 
 
 
 
 
3c1f2a9
 
e22a58e
3c1f2a9
e22a58e
 
 
 
 
 
 
 
 
3c1f2a9
 
 
e22a58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c3c7f1
3c1f2a9
e22a58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c1f2a9
e22a58e
3c1f2a9
e22a58e
3c1f2a9
 
 
 
e22a58e
3c1f2a9
e22a58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c1f2a9
 
e22a58e
 
 
3c1f2a9
e22a58e
3c1f2a9
e22a58e
 
 
 
 
 
 
 
3c1f2a9
e22a58e
3c1f2a9
e22a58e
 
 
 
 
 
 
 
 
 
 
 
3c1f2a9
 
e22a58e
 
 
 
edb1292
7108aaf
9c3c7f1
e22a58e
 
 
 
3c1f2a9
 
 
 
e22a58e
 
9c3c7f1
3c1f2a9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290

# ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
import os
os.environ.setdefault("GRADIO_USE_CDN", "true")

import spaces

@spaces.GPU(duration=10)
def _gpu_probe(a: int = 1, b: int = 1) -> int:
    """Placeholder GPU task registered at import time.

    The app never calls it; per the original intent, the decorated function
    exists only so the ZeroGPU startup check finds a GPU function.
    """
    total = a + b
    return total

# ---------- Standard imports ----------
from pathlib import Path
from typing import Optional, Tuple, List
import subprocess
import sys
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import hf_hub_download

# ---------- Config ----------
SPACE_ROOT = Path(__file__).parent.resolve()
REPO_DIR = SPACE_ROOT.joinpath("SonicMasterRepo")  # local clone of the inference code
REPO_URL = "https://github.com/AMAAI-Lab/SonicMaster"
WEIGHTS_REPO = "amaai-lab/SonicMaster"  # Hub repo id holding the checkpoint
WEIGHTS_FILE = "model.safetensors"
CACHE_DIR = SPACE_ROOT.joinpath("weights")  # download target for WEIGHTS_FILE

# Created eagerly at import time so the weight download always has a target dir.
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Heuristic ZeroGPU detection: true only when SPACE_RUNTIME equals "zerogpu"
# (case-insensitive). NOTE(review): confirm the Spaces runtime actually sets
# this variable; if it does not, GPU dispatch relies solely on _has_cuda().
USE_ZEROGPU = os.environ.get("SPACE_RUNTIME", "").lower() == "zerogpu"

# ---------- Lazy resources ----------
_weights_path: Optional[Path] = None  # cached checkpoint path, set after the first download
_repo_ready: bool = False  # set by ensure_repo() once the repo is cloned and on sys.path

def get_weights_path(progress: Optional[gr.Progress] = None) -> Path:
    """Return the local path of the SonicMaster checkpoint, downloading it once.

    The first call fetches ``WEIGHTS_FILE`` from the Hub repo ``WEIGHTS_REPO``
    into ``CACHE_DIR``. The resolved path is cached in ``_weights_path``, and
    later calls return it without touching the network.

    Args:
        progress: Optional Gradio progress tracker, updated before downloading.

    Returns:
        Path to the downloaded weights file.

    Raises:
        Whatever ``hf_hub_download`` raises. Nothing is cached in that case,
        so the next call retries the download.
    """
    global _weights_path
    if _weights_path is not None:
        return _weights_path

    # Compare with None rather than truth-testing: gr.Progress defines
    # __len__, which can raise before anything has been tracked.
    if progress is not None:
        # 0.16 sits between the caller's "Preparing inference" (0.14) and the
        # repo-clone step (0.18), so the bar never jumps backwards.
        progress(0.16, desc="Downloading model weights (first run)")

    # `resume_download` and `local_dir_use_symlinks` are deprecated in recent
    # huggingface_hub releases (they only emit warnings there), so they are
    # no longer passed.
    downloaded = hf_hub_download(
        repo_id=WEIGHTS_REPO,
        filename=WEIGHTS_FILE,
        local_dir=str(CACHE_DIR),
        force_download=False,
    )
    _weights_path = Path(downloaded)
    return _weights_path

def ensure_repo(progress: Optional[gr.Progress] = None) -> Path:
    """Make sure the SonicMaster inference repo is cloned and on ``sys.path``.

    If ``REPO_DIR`` does not exist, this shallow-clones ``REPO_URL`` (depth 1)
    into it. It then appends the repo to ``sys.path`` once and records
    completion in ``_repo_ready``, so later calls return immediately.

    Args:
        progress: Optional Gradio progress tracker, updated before cloning.

    Returns:
        ``REPO_DIR``.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero. In
            that case ``_repo_ready`` stays False and the next call retries.
    """
    global _repo_ready
    if _repo_ready:
        return REPO_DIR

    if not REPO_DIR.exists():
        # Compare with None: truth-testing a gr.Progress can invoke its
        # __len__, which may raise before anything is tracked.
        if progress is not None:
            progress(0.18, desc="Cloning SonicMaster repo (first run)")
        subprocess.run(
            ["git", "clone", "--depth", "1", REPO_URL, REPO_DIR.as_posix()],
            check=True,
        )
    # NOTE(review): an existing REPO_DIR is trusted as-is; a previously
    # interrupted clone would not be detected here.

    repo_str = REPO_DIR.as_posix()
    if repo_str not in sys.path:
        sys.path.append(repo_str)
    _repo_ready = True
    return REPO_DIR

# ---------- Audio helpers ----------
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    """Write ``wav`` to ``path`` at sample rate ``sr`` using soundfile.

    1-D arrays are written unchanged. A multi-dimensional array is transposed
    when its first axis is shorter than its second. This assumes the shorter
    axis is channels, so ``(channels, samples)`` becomes ``(samples, channels)``.
    float64 data is downcast to float32 before writing.
    """
    needs_transpose = wav.ndim != 1 and wav.shape[0] < wav.shape[1]
    frames = wav.T if needs_transpose else wav
    if frames.dtype == np.float64:
        frames = frames.astype(np.float32)
    sf.write(path.as_posix(), frames, sr)

def read_audio(path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file with soundfile and return ``(samples, sample_rate)``.

    The read uses ``always_2d=False``, so soundfile may return a 1-D array
    for mono input. float64 samples are downcast to float32.
    """
    samples, rate = sf.read(path, always_2d=False)
    if samples.dtype == np.float64:
        samples = samples.astype(np.float32)
    return samples, rate

# ---------- CLI runner ----------
def _candidate_commands(py: str, script: Path, ckpt: Path, inp: Path, prompt: str, out: Path) -> List[List[str]]:
    """Try multiple arg styles commonly found in repos."""
    combos = [
        # infer_single.py (common)
        [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--prompt", prompt, "--output", out.as_posix()],
        [py, script.as_posix(), "--weights", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--out", out.as_posix()],
        # other possible entrypoints
        [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--output", out.as_posix()],
    ]
    return combos

def _clear_stale_output(out_path: Path, input_wav_path: Path) -> None:
    """Delete a leftover ``out_path`` unless it is the input file itself."""
    if out_path.resolve() != input_wav_path.resolve():
        out_path.unlink(missing_ok=True)

def run_sonicmaster_cli(
    input_wav_path: Path,
    prompt: str,
    out_path: Path,
    progress: Optional[gr.Progress] = None,
) -> Tuple[bool, str]:
    """Run SonicMaster inference in a subprocess and report the outcome.

    Ensures the weights and repo are present. It then tries each inference
    script that exists in the repo with each argv variant from
    ``_candidate_commands``, and stops at the first attempt that exits 0
    *and* leaves a non-empty file at ``out_path``.

    Args:
        input_wav_path: WAV file to enhance.
        prompt: Text instruction forwarded to the script.
        out_path: Where the script is asked to write its result. A file
            already there is removed before each attempt, unless it is the
            input file. This stops leftovers from an earlier run or a failed
            attempt from being mistaken for fresh output.
        progress: Optional Gradio progress tracker.

    Returns:
        ``(True, stdout or "Inference completed.")`` on success. Otherwise
        ``(False, message)`` describing the last failure observed.
    """
    # Compare with None: truth-testing a gr.Progress can invoke its __len__,
    # which may raise before anything is tracked.
    if progress is not None:
        progress(0.14, desc="Preparing inference")
    ckpt = get_weights_path(progress=progress)
    repo = ensure_repo(progress=progress)

    # Candidate scripts, in priority order; only those present are tried.
    script_candidates = [
        repo / "infer_single.py",
        repo / "inference_fullsong.py",
        repo / "inference_ptload_batch.py",
    ]
    scripts = [s for s in script_candidates if s.exists()]
    if not scripts:
        return False, "No inference script found in the repo (expected infer_single.py or similar)."

    py = sys.executable or "python3"
    env = os.environ.copy()  # keep CUDA_VISIBLE_DEVICES etc.

    last_err = ""
    for idx, script in enumerate(scripts, start=1):
        for jdx, cmd in enumerate(_candidate_commands(py, script, ckpt, input_wav_path, prompt, out_path), start=1):
            try:
                _clear_stale_output(out_path, input_wav_path)
                if progress is not None:
                    progress(min(0.20 + 0.08 * (idx + jdx), 0.70), desc=f"Running {script.name} (try {idx}.{jdx})")
                res = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env)
                if out_path.exists() and out_path.stat().st_size > 0:
                    if progress is not None:
                        progress(0.88, desc="Post-processing output")
                    # Return any informative stdout as the message.
                    msg = (res.stdout or "").strip()
                    return True, msg if msg else "Inference completed."
                last_err = f"{script.name} produced no output file."
            except subprocess.CalledProcessError as e:
                # Surface the script's own stdout/stderr to the user.
                snippet = "\n".join(filter(None, [e.stdout or "", e.stderr or ""])).strip()
                last_err = snippet if snippet else f"{script.name} failed with return code {e.returncode}."
            except Exception as e:
                last_err = f"Unexpected error: {e}\n{traceback.format_exc()}"

    return False, last_err or "All candidate commands failed without an error message."

# ---------- REAL GPU function (called only if using ZeroGPU / GPU available) ----------
@spaces.GPU(duration=180)
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> Tuple[bool, str]:
    """GPU-allocated wrapper around ``run_sonicmaster_cli``.

    Takes plain string paths, makes a best-effort ``import torch`` (any
    failure is ignored), and runs the CLI without a progress tracker.
    Returns the same ``(ok, message)`` pair as ``run_sonicmaster_cli``.
    """
    # Best effort only: the import result is unused and errors are ignored.
    try:
        import torch  # noqa: F401
    except Exception:
        pass
    src, dst = Path(input_path), Path(output_path)
    return run_sonicmaster_cli(src, prompt, dst, progress=None)

def _has_cuda() -> bool:
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        return False

# ---------- UI callback ----------
def enhance_audio_ui(
    audio_path: str,
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
    """Gradio click handler: enhance ``audio_path`` according to ``prompt``.

    Returns:
        ``((sample_rate, samples), message)`` on success. This is the tuple
        order gr.Audio expects for numpy data. On any failure it returns
        ``(None, message)``, so the UI shows the error text instead of
        echoing the input audio.
    """
    import tempfile

    try:
        if not prompt or not prompt.strip():
            raise gr.Error("Please provide a text prompt.")
        if not audio_path:
            raise gr.Error("Please upload or select an input audio file.")

        wav, sr = read_audio(audio_path)

        # Per-request scratch directory: back-to-back or concurrent requests
        # cannot clobber each other's files, and nothing is left behind.
        with tempfile.TemporaryDirectory(prefix="sonicmaster_") as scratch_dir:
            scratch = Path(scratch_dir)
            tmp_in = scratch / "tmp_in.wav"
            tmp_out = scratch / "tmp_out.wav"

            # Compare with None: truth-testing a gr.Progress can invoke its
            # __len__, which may raise before anything is tracked.
            if progress is not None:
                progress(0.06, desc="Preparing audio")
            save_temp_wav(wav, sr, tmp_in)

            # Choose execution path: prefer a real GPU if available, else CPU.
            use_gpu_call = USE_ZEROGPU or _has_cuda()

            if progress is not None:
                progress(0.12, desc="Starting inference")
            if use_gpu_call:
                ok, msg = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
            else:
                ok, msg = run_sonicmaster_cli(tmp_in, prompt, tmp_out, progress=progress)

            if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
                # Read back before the scratch dir is removed. read_audio()
                # returns (samples, rate), but gr.Audio wants (rate, samples).
                out_wav, out_sr = read_audio(tmp_out.as_posix())
                return (out_sr, out_wav), (msg or "Done.")

        # On failure: don't echo the input audio; return None and the error text.
        return None, (msg or "Inference failed without a specific error message.").strip()

    except gr.Error as e:
        return (None, str(e))
    except Exception as e:
        return (None, f"Unexpected error: {e}\n{traceback.format_exc()}")

# ---------- Gradio UI ----------
# Prompt-only example rows: each row holds a single value because the
# examples are bound to the lone `prompt` input in gr.Examples.
PROMPT_EXAMPLES = [
    [text]
    for text in (
        "Increase the clarity of this song by emphasizing treble frequencies.",
        "Make this song sound more boomy by amplifying the low end bass frequencies.",
        "Make the audio smoother and less distorted.",
        "Improve the balance in this song.",
        "Reduce roominess/echo (dereverb).",
        "Raise the level of the vocals.",
        "Give the song a wider stereo image.",
    )
]

with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
    gr.Markdown("## 🎧 SonicMaster\nUpload audio, enter a prompt, then click **Enhance**.\n"
                "- Progress appears below during the first run (weights/repo download).\n"
                "- If something fails, you'll see the **error message** instead of the input audio.")
    with gr.Row():
        with gr.Column(scale=1):
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt   = gr.Textbox(label="Text Prompt", placeholder="e.g., Reduce reverb and brighten the vocals.")
            # Label previously contained mis-decoded bytes ("πŸš€") for the rocket emoji.
            run_btn  = gr.Button("🚀 Enhance", variant="primary")
            gr.Examples(
                examples=PROMPT_EXAMPLES,
                inputs=[prompt],  # prompt-only examples to avoid heavy file ops at startup
                label="Prompt Examples",
            )
        with gr.Column(scale=1):
            out_audio = gr.Audio(label="Enhanced Audio (output)")
            status    = gr.Textbox(label="Status / Messages", interactive=False, lines=6)

    # On click, return audio + message; one inference at a time.
    run_btn.click(
        fn=enhance_audio_ui,
        inputs=[in_audio, prompt],
        outputs=[out_audio, status],
        concurrency_limit=1,
    )

# Queue BEFORE mounting so the mounted app is ready immediately.
# Gradio 4+ rejects `concurrency_count` (passing it raises at startup).
# Concurrency is now set per listener (`concurrency_limit=1` above) and via
# `default_concurrency_limit`, which keeps the original one-at-a-time intent.
demo = demo.queue(default_concurrency_limit=1, max_size=16)

# ---------- FastAPI mount & health ----------
from fastapi import FastAPI, Request
from starlette.responses import PlainTextResponse
try:
    from starlette.exceptions import ClientDisconnect  # Starlette β‰₯0.27
except Exception:
    from starlette.requests import ClientDisconnect  # fallback for older versions

app = FastAPI()


@app.get("/health")
def _health():
    """Liveness probe.

    Registered before the Gradio mount at "/", so this route is matched first.
    """
    payload = {"ok": True}
    return payload


@app.exception_handler(ClientDisconnect)
async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
    """Answer a ClientDisconnect with a plain-text 499 ("client closed request").

    NOTE(review): exceptions raised inside the mounted Gradio sub-app may be
    handled by Gradio itself and never reach this handler — confirm.
    """
    body, code = "Client disconnected", 499
    return PlainTextResponse(body, status_code=code)


# Mount Gradio at root (Spaces looks here); keep this after the routes above.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Script entry point: serve the combined FastAPI + Gradio app on port 7860.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)