ambujm22 commited on
Commit
3c1f2a9
·
verified ·
1 Parent(s): 08f8861

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -14
app.py CHANGED
@@ -1,26 +1,204 @@
1
- # --- ZeroGPU must see a GPU-decorated function at import time ---
 
 
 
 
 
2
  import spaces
3
 
4
  @spaces.GPU(duration=10)
5
- def _gpu_probe():
 
6
  return "ok"
7
 
8
- # --- Gradio app kept trivial to prove boot path ---
 
 
 
 
 
9
  import gradio as gr
10
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def echo(x):
13
- return f"ok: {x}"
14
 
15
- with gr.Blocks(title="Hello") as demo:
16
- inp = gr.Textbox(label="Say something")
17
- out = gr.Textbox(label="Reply")
18
- inp.submit(echo, inp, out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Expose *either* 'demo' or a FastAPI 'app'. We'll use FastAPI + mount:
21
  app = FastAPI()
22
 
23
- # Mount Gradio at root so Spaces healthcheck to "/" gets 200
24
- app = gr.mount_gradio_app(app, demo.queue(), path="/")
 
 
 
25
 
26
- # DO NOT run uvicorn here β€” Spaces runs the server.
 
 
 
1
# ---------- MUST RUN FIRST: Gradio CDN flag + ZeroGPU probe ----------
import os

# Serve Gradio frontend assets from the CDN; respect any pre-set value.
os.environ.setdefault("GRADIO_USE_CDN", "true")

# ZeroGPU requires at least one @spaces.GPU-decorated function to exist at
# import time, so register a trivial probe unconditionally.
import spaces

@spaces.GPU(duration=10)
def _gpu_probe() -> str:
    # Placeholder body: never invoked; exists solely so the ZeroGPU
    # startup check passes.
    return "ok"
13
 
14
# ---------- Standard imports ----------
import sys
import subprocess
from pathlib import Path
from typing import Tuple, Optional

import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import hf_hub_download

# Detect ZeroGPU to decide whether to CALL the GPU function.
# NOTE(review): assumes Spaces sets SPACE_RUNTIME=zerogpu on ZeroGPU
# hardware — confirm against the Spaces runtime docs.
USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"

# Directory layout: the Space root holds the cloned inference repo and a
# local cache directory for downloaded model weights.
SPACE_ROOT = Path(__file__).parent.resolve()
REPO_DIR = SPACE_ROOT / "SonicMasterRepo"  # git clone target (see ensure_repo)
WEIGHTS_REPO = "amaai-lab/SonicMaster"  # HF Hub repo id for the checkpoint
WEIGHTS_FILE = "model.safetensors"  # checkpoint filename on the Hub
CACHE_DIR = SPACE_ROOT / "weights"  # local weights cache
CACHE_DIR.mkdir(parents=True, exist_ok=True)  # side effect at import time
34
+
35
# ---------- 1) Pull weights from HF Hub ----------
def get_weights_path() -> Path:
    """Download (or reuse a cached copy of) the SonicMaster checkpoint.

    Returns:
        Path to ``model.safetensors`` materialized under ``CACHE_DIR``.
    """
    # `local_dir_use_symlinks` and `resume_download` are deprecated no-ops in
    # current huggingface_hub (downloads always resume; local_dir always gets
    # a real file), and `force_download=False` is the default — dropping them
    # silences DeprecationWarnings without changing behavior.
    return Path(
        hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
            local_dir=CACHE_DIR.as_posix(),
        )
    )
47
+
48
# ---------- 2) Clone GitHub repo ----------
def ensure_repo() -> Path:
    """Shallow-clone the SonicMaster repo if absent and add it to sys.path.

    Returns:
        ``REPO_DIR`` (the clone target), whether freshly cloned or pre-existing.
    """
    repo_path = REPO_DIR.as_posix()
    if not REPO_DIR.exists():
        clone_cmd = [
            "git", "clone", "--depth", "1",
            "https://github.com/AMAAI-Lab/SonicMaster",
            repo_path,
        ]
        # check=True: a failed clone should surface, not be ignored.
        subprocess.run(clone_cmd, check=True)
    if repo_path not in sys.path:
        sys.path.append(repo_path)
    return REPO_DIR
60
+
61
# ---------- 3) Examples ----------
def build_examples():
    """Pair up to ten sample WAVs from the cloned repo with canned prompts.

    Returns:
        A list of ``[wav_path, prompt]`` pairs suitable for ``gr.Examples``.
    """
    sample_dir = ensure_repo() / "samples" / "inputs"
    wavs = sorted(f for f in sample_dir.glob("*.wav") if f.is_file())
    prompts = [
        "Increase the clarity of this song by emphasizing treble frequencies.",
        "Make this song sound more boomy by amplifying the low end bass frequencies.",
        "Can you make this sound louder, please?",
        "Make the audio smoother and less distorted.",
        "Improve the balance in this song.",
        "Disentangle the left and right channels to give this song a stereo feeling.",
        "Correct the unnatural frequency emphasis. Reduce the roominess or echo.",
        "Raise the level of the vocals, please.",
        "Increase the clarity of this song by emphasizing treble frequencies.",
        "Please, dereverb this audio.",
    ]
    # zip stops at the shorter sequence, matching the original index pairing
    # (at most 10 wavs against exactly 10 prompts).
    return [[wav.as_posix(), text] for wav, text in zip(wavs[:10], prompts)]
80
+
81
# ---------- 4) I/O helpers ----------
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    """Write `wav` at sample rate `sr` to `path` via soundfile.

    soundfile expects (frames, channels); arrays that look channel-first
    (fewer rows than columns) are transposed before writing.
    """
    data = wav
    if data.ndim == 2 and data.shape[0] < data.shape[1]:
        # Heuristic: (channels, frames) layout — flip to (frames, channels).
        data = data.T
    sf.write(path.as_posix(), data, sr)
86
+
87
def read_audio(path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file; return (samples, sample_rate).

    float64 samples are downcast to float32; all other dtypes pass through.
    """
    samples, rate = sf.read(path, always_2d=False)
    if samples.dtype == np.float64:
        samples = samples.astype(np.float32)
    return samples, rate
90
+
91
# ---------- 5) Core inference (subprocess calling the repo script) ----------
def run_sonicmaster_cli(input_wav_path: Path,
                        prompt: str,
                        out_path: Path,
                        _logs: list,
                        progress: Optional[gr.Progress] = None) -> bool:
    """Run SonicMaster inference by invoking the repo's CLI script.

    Tries two plausible flag spellings against ``infer_single.py`` and returns
    as soon as one produces a non-empty output file.

    Args:
        input_wav_path: path to the temporary input WAV.
        prompt: text instruction for the model.
        out_path: destination for the enhanced WAV.
        _logs: mutable list; diagnostics from failed attempts are appended
            here (previously accepted but never written to).
        progress: optional Gradio progress reporter.

    Returns:
        True if some command produced a non-empty ``out_path``, else False.
    """
    if progress:
        progress(0.15, desc="Loading weights & repo")
    ckpt = get_weights_path()
    repo = ensure_repo()

    py = sys.executable or "python3"
    script_candidates = [repo / "infer_single.py"]

    candidate_cmds = []
    for script in script_candidates:
        if script.exists():
            # The repo's CLI flag names are not pinned; try two spellings.
            candidate_cmds.append([
                py, script.as_posix(),
                "--ckpt", ckpt.as_posix(),
                "--input", input_wav_path.as_posix(),
                "--prompt", prompt,
                "--output", out_path.as_posix(),
            ])
            candidate_cmds.append([
                py, script.as_posix(),
                "--weights", ckpt.as_posix(),
                "--input", input_wav_path.as_posix(),
                "--text", prompt,
                "--out", out_path.as_posix(),
            ])

    for idx, cmd in enumerate(candidate_cmds, start=1):
        if progress:
            progress(0.35 + 0.05 * idx, desc=f"Running inference (try {idx})")
        try:
            # Inherit env so CUDA_VISIBLE_DEVICES set by ZeroGPU reaches the child.
            result = subprocess.run(cmd, capture_output=True, text=True,
                                    check=True, env=os.environ.copy())
        except subprocess.CalledProcessError as exc:
            # Was a silent `except Exception: continue`; keep the child's
            # stderr so failures are diagnosable instead of vanishing.
            _logs.append(f"[try {idx}] exit {exc.returncode}: {exc.stderr}")
            continue
        except OSError as exc:
            _logs.append(f"[try {idx}] failed to launch: {exc}")
            continue
        if out_path.exists() and out_path.stat().st_size > 0:
            if progress:
                progress(0.9, desc="Post-processing output")
            return True
        _logs.append(f"[try {idx}] ran but produced no output: {result.stderr}")
    return False
133
+
134
# ---------- 6) REAL GPU function (always defined; only CALLED on ZeroGPU) ----------
@spaces.GPU(duration=180)
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
    """Run the CLI inference inside a ZeroGPU allocation.

    Returns True when the subprocess produced a non-empty output file.
    """
    # Importing torch here lets CUDA initialize inside the GPU context;
    # failure to import is tolerated because the subprocess does the work.
    try:
        import torch  # noqa: F401
    except Exception:
        pass
    src = Path(input_path)
    dst = Path(output_path)
    return run_sonicmaster_cli(src, prompt, dst, _logs=[], progress=None)
144
+
145
# ---------- 7) Gradio callback ----------
def enhance_audio_ui(audio_path: str,
                     prompt: str,
                     progress=gr.Progress(track_tqdm=True)) -> Tuple[int, np.ndarray]:
    """Gradio handler: run enhancement and return ``(sample_rate, samples)``.

    Falls back to echoing the unmodified input audio when inference fails so
    the UI always shows something.

    Raises:
        gr.Error: if either the audio or the text prompt is missing.
    """
    if not audio_path or not prompt:
        raise gr.Error("Please provide audio and a text prompt.")

    wav, sr = read_audio(audio_path)
    tmp_in, tmp_out = SPACE_ROOT / "tmp_in.wav", SPACE_ROOT / "tmp_out.wav"
    if tmp_out.exists():
        try:
            tmp_out.unlink()
        except OSError:
            # Was a bare `except: pass`, which also swallowed
            # KeyboardInterrupt/SystemExit; only filesystem errors are expected.
            pass
    save_temp_wav(wav, sr, tmp_in)

    if progress:
        progress(0.3, desc="Starting inference")
    if USE_ZEROGPU:
        ok = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
    else:
        ok = run_sonicmaster_cli(tmp_in, prompt, tmp_out, _logs=[], progress=progress)

    if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
        out_wav, out_sr = read_audio(tmp_out.as_posix())
        return (out_sr, out_wav)
    # Inference failed: return the original audio unchanged.
    return (sr, wav)
170
+
171
# ---------- 8) Gradio UI ----------
# Fixes mis-encoded characters in user-facing strings ("–" for the en dash
# in the title, "πŸš€" for the rocket emoji on the button).
with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
    gr.Markdown("## 🎧 SonicMaster\nUpload or choose an example, write a text prompt, then click **Enhance**.")
    with gr.Row():
        with gr.Column():
            # Left column: inputs.
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., reduce reverb")
            run_btn = gr.Button("🚀 Enhance", variant="primary")
            # NOTE: build_examples() clones the repo at UI-construction time.
            gr.Examples(examples=build_examples(), inputs=[in_audio, prompt])
        with gr.Column():
            # Right column: output.
            out_audio = gr.Audio(label="Enhanced Audio (output)")
    # concurrency_limit=1 serializes jobs: one inference at a time.
    run_btn.click(fn=enhance_audio_ui,
                  inputs=[in_audio, prompt],
                  outputs=[out_audio],
                  concurrency_limit=1)
186
+
187
# ---------- 9) FastAPI mount & disconnect handler ----------
from fastapi import FastAPI, Request
from starlette.responses import PlainTextResponse
from starlette.requests import ClientDisconnect

# Warm caches at import time: fetch weights and clone the repo once.
_ = get_weights_path()
_ = ensure_repo()

app = FastAPI()

@app.exception_handler(ClientDisconnect)
async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
    # 499 follows nginx's "client closed request" convention.
    return PlainTextResponse("Client disconnected", status_code=499)

# Mount Gradio at root so the Spaces healthcheck to "/" returns 200.
app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")

if __name__ == "__main__":
    # Local development only — on Spaces the platform launches the server.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)