ambujm22 committed · verified
Commit: e22a58e · Parent(s): 48bfead

Update app.py

Files changed (1):
  1. app.py +226 -140
app.py CHANGED
@@ -1,203 +1,289 @@
 
 # ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
 import os
 os.environ.setdefault("GRADIO_USE_CDN", "true")

-# A GPU-decorated function MUST exist at import time for ZeroGPU.
-# Import spaces unconditionally and register a tiny probe.
 import spaces

 @spaces.GPU(duration=10)
-def _gpu_probe() -> str:
-    # Never called; only here so ZeroGPU startup check passes.
-    return "ok"

 # ---------- Standard imports ----------
-import sys
-import subprocess
 from pathlib import Path
-from typing import Tuple, Optional

 import gradio as gr
 import numpy as np
 import soundfile as sf
 from huggingface_hub import hf_hub_download

-# Detect ZeroGPU to decide whether to CALL the GPU function.
-USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
-
 SPACE_ROOT = Path(__file__).parent.resolve()
 REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
 WEIGHTS_REPO = "amaai-lab/SonicMaster"
 WEIGHTS_FILE = "model.safetensors"
 CACHE_DIR = SPACE_ROOT / "weights"
 CACHE_DIR.mkdir(parents=True, exist_ok=True)

-# ---------- 1) Pull weights from HF Hub ----------
-def get_weights_path() -> Path:
-    return Path(
-        hf_hub_download(
             repo_id=WEIGHTS_REPO,
             filename=WEIGHTS_FILE,
-            local_dir=CACHE_DIR.as_posix(),
             local_dir_use_symlinks=False,
             force_download=False,
             resume_download=True,
         )
-    )

-# ---------- 2) Clone GitHub repo ----------
-def ensure_repo() -> Path:
-    if not REPO_DIR.exists():
-        subprocess.run(
-            ["git", "clone", "--depth", "1",
-             "https://github.com/AMAAI-Lab/SonicMaster",
-             REPO_DIR.as_posix()],
-            check=True,
-        )
-    if REPO_DIR.as_posix() not in sys.path:
-        sys.path.append(REPO_DIR.as_posix())
     return REPO_DIR

-# ---------- 3) Examples ----------
-def build_examples():
-    repo = ensure_repo()
-    wav_dir = repo / "samples" / "inputs"
-    wav_paths = sorted(p for p in wav_dir.glob("*.wav") if p.is_file())
-    prompts = [
-        "Increase the clarity of this song by emphasizing treble frequencies.",
-        "Make this song sound more boomy by amplifying the low end bass frequencies.",
-        "Can you make this sound louder, please?",
-        "Make the audio smoother and less distorted.",
-        "Improve the balance in this song.",
-        "Disentangle the left and right channels to give this song a stereo feeling.",
-        "Correct the unnatural frequency emphasis. Reduce the roominess or echo.",
-        "Raise the level of the vocals, please.",
-        "Increase the clarity of this song by emphasizing treble frequencies.",
-        "Please, dereverb this audio.",
-    ]
-    return [[p.as_posix(), prompts[i] if i < len(prompts) else prompts[-1]]
-            for i, p in enumerate(wav_paths[:10])]
-
-# ---------- 4) I/O helpers ----------
 def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
-    if wav.ndim == 2 and wav.shape[0] < wav.shape[1]:
-        wav = wav.T
-    sf.write(path.as_posix(), wav, sr)

 def read_audio(path: str) -> Tuple[np.ndarray, int]:
     wav, sr = sf.read(path, always_2d=False)
-    return wav.astype(np.float32) if wav.dtype == np.float64 else wav, sr
-
-# ---------- 5) Core inference (subprocess calling your repo script) ----------
-def run_sonicmaster_cli(input_wav_path: Path,
-                        prompt: str,
-                        out_path: Path,
-                        _logs: list,
-                        progress: Optional[gr.Progress] = None) -> bool:
-    if progress: progress(0.15, desc="Loading weights & repo")
-    ckpt = get_weights_path()
-    repo = ensure_repo()

     py = sys.executable or "python3"
-    script_candidates = [repo / "infer_single.py"]
-
-    CANDIDATE_CMDS = []
-    for script in script_candidates:
-        if script.exists():
-            CANDIDATE_CMDS.append([
-                py, script.as_posix(),
-                "--ckpt", ckpt.as_posix(),
-                "--input", input_wav_path.as_posix(),
-                "--prompt", prompt,
-                "--output", out_path.as_posix(),
-            ])
-            CANDIDATE_CMDS.append([
-                py, script.as_posix(),
-                "--weights", ckpt.as_posix(),
-                "--input", input_wav_path.as_posix(),
-                "--text", prompt,
-                "--out", out_path.as_posix(),
-            ])
-
-    for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
-        try:
-            if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
-            # inherit env so CUDA_VISIBLE_DEVICES from ZeroGPU reaches subprocess
-            subprocess.run(cmd, capture_output=True, text=True, check=True, env=os.environ.copy())
-            if out_path.exists() and out_path.stat().st_size > 0:
-                if progress: progress(0.9, desc="Post-processing output")
-                return True
-        except Exception:
-            continue
-    return False
-
-# ---------- 6) REAL GPU function (always defined; only CALLED on ZeroGPU) ----------
 @spaces.GPU(duration=180)
-def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
-    # Import torch here so CUDA initializes inside GPU context
     try:
         import torch  # noqa: F401
     except Exception:
         pass
     from pathlib import Path as _P
-    return run_sonicmaster_cli(_P(input_path), prompt, _P(output_path), _logs=[], progress=None)
-
-# ---------- 7) Gradio callback ----------
-def enhance_audio_ui(audio_path: str,
-                     prompt: str,
-                     progress=gr.Progress(track_tqdm=True)) -> Tuple[int, np.ndarray]:
-    if not audio_path or not prompt:
-        raise gr.Error("Please provide audio and a text prompt.")
-
-    wav, sr = read_audio(audio_path)
-    tmp_in, tmp_out = SPACE_ROOT / "tmp_in.wav", SPACE_ROOT / "tmp_out.wav"
-    if tmp_out.exists():
-        try: tmp_out.unlink()
-        except: pass
-    save_temp_wav(wav, sr, tmp_in)
-
-    if progress: progress(0.3, desc="Starting inference")
-    if USE_ZEROGPU:
-        ok = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
-    else:
-        ok = run_sonicmaster_cli(tmp_in, prompt, tmp_out, _logs=[], progress=progress)

-    if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
-        out_wav, out_sr = read_audio(tmp_out.as_posix())
-        return (out_sr, out_wav)
-    else:
-        return (sr, wav)

-# ---------- 8) Gradio UI ----------
 with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
-    gr.Markdown("## 🎧 SonicMaster\nUpload or choose an example, write a text prompt, then click **Enhance**.")
     with gr.Row():
-        with gr.Column():
             in_audio = gr.Audio(label="Input Audio", type="filepath")
-            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., reduce reverb")
-            run_btn = gr.Button("🚀 Enhance", variant="primary")
-            gr.Examples(examples=build_examples(), inputs=[in_audio, prompt])
-        with gr.Column():
             out_audio = gr.Audio(label="Enhanced Audio (output)")
-    run_btn.click(fn=enhance_audio_ui,
-                  inputs=[in_audio, prompt],
-                  outputs=[out_audio],
-                  concurrency_limit=1)

-# ---------- 9) FastAPI mount & disconnect handler ----------
 from fastapi import FastAPI, Request
 from starlette.responses import PlainTextResponse
-from starlette.requests import ClientDisconnect
-
-_ = get_weights_path(); _ = ensure_repo()

 app = FastAPI()

 @app.exception_handler(ClientDisconnect)
 async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
     return PlainTextResponse("Client disconnected", status_code=499)

-app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")

 if __name__ == "__main__":
     import uvicorn
 
+
 # ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
 import os
 os.environ.setdefault("GRADIO_USE_CDN", "true")

 import spaces

 @spaces.GPU(duration=10)
+def _gpu_probe(a: int = 1, b: int = 1) -> int:
+    # Never called; exists so the ZeroGPU startup check passes.
+    return a + b

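Aside on the probe pattern: ZeroGPU only admits a Space whose code registers at least one `@spaces.GPU`-decorated function at import time, which is all `_gpu_probe` does. A minimal standalone sketch of how such a function behaves (assumes the `spaces` package and torch are installed; `cuda_report` is a hypothetical name):

```python
import spaces

@spaces.GPU(duration=30)  # duration bounds the GPU lease, in seconds
def cuda_report() -> str:
    # On ZeroGPU hardware, CUDA becomes visible only inside a decorated call.
    import torch
    return f"cuda available: {torch.cuda.is_available()}"

if __name__ == "__main__":
    print(cuda_report())  # on CPU-only hardware this simply reports False
```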
 # ---------- Standard imports ----------
 from pathlib import Path
+from typing import Optional, Tuple, List
+import subprocess
+import sys
+import traceback

 import gradio as gr
 import numpy as np
 import soundfile as sf
 from huggingface_hub import hf_hub_download

+# ---------- Config ----------
 SPACE_ROOT = Path(__file__).parent.resolve()
 REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
+REPO_URL = "https://github.com/AMAAI-Lab/SonicMaster"
 WEIGHTS_REPO = "amaai-lab/SonicMaster"
 WEIGHTS_FILE = "model.safetensors"
 CACHE_DIR = SPACE_ROOT / "weights"
 CACHE_DIR.mkdir(parents=True, exist_ok=True)

+# ZeroGPU detection (heuristic)
+USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
+
+# ---------- Lazy resources ----------
+_weights_path: Optional[Path] = None
+_repo_ready: bool = False
+
+def get_weights_path(progress: Optional[gr.Progress] = None) -> Path:
+    """Fetch model weights lazily and cache the resolved path."""
+    global _weights_path
+    if _weights_path is None:
+        if progress:
+            progress(0.10, desc="Downloading model weights (first run)")
+        wp = hf_hub_download(
             repo_id=WEIGHTS_REPO,
             filename=WEIGHTS_FILE,
+            local_dir=str(CACHE_DIR),
             local_dir_use_symlinks=False,
             force_download=False,
             resume_download=True,
         )
+        _weights_path = Path(wp)
+    return _weights_path

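Note that `hf_hub_download` is already idempotent: repeat calls resolve from the local cache, so the `_weights_path` global mainly skips the hub lookup on later calls. A quick standalone check of the caching behavior (assumes network access to the public weights repo):

```python
from pathlib import Path
from huggingface_hub import hf_hub_download

# First call downloads (or resumes); later calls resolve from the local cache.
wp = Path(hf_hub_download(repo_id="amaai-lab/SonicMaster",
                          filename="model.safetensors"))
print(wp, f"{wp.stat().st_size / 1e6:.1f} MB")
```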
+def ensure_repo(progress: Optional[gr.Progress] = None) -> Path:
+    """Clone the inference repo lazily and put it on sys.path."""
+    global _repo_ready
+    if not _repo_ready:
+        if not REPO_DIR.exists():
+            if progress:
+                progress(0.18, desc="Cloning SonicMaster repo (first run)")
+            subprocess.run(
+                ["git", "clone", "--depth", "1", REPO_URL, REPO_DIR.as_posix()],
+                check=True,
+            )
+        if REPO_DIR.as_posix() not in sys.path:
+            sys.path.append(REPO_DIR.as_posix())
+        _repo_ready = True
     return REPO_DIR

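One caveat: `sys.path.append` only affects this interpreter, while the actual inference below runs in a subprocess that is invoked by script path, not by import. If a repo script ever needed to import sibling modules as packages, the path would have to travel via the environment instead; a hedged sketch (the `PYTHONPATH` forwarding is an assumption, not something the current scripts require):

```python
import os
import subprocess
import sys
from pathlib import Path

REPO_DIR = Path("SonicMasterRepo")  # hypothetical local checkout

env = os.environ.copy()
# Prepend the repo so child interpreters can import modules from it.
env["PYTHONPATH"] = os.pathsep.join(
    filter(None, [REPO_DIR.as_posix(), env.get("PYTHONPATH")])
)
subprocess.run(
    [sys.executable, "-c", "import sys; print(sys.path[1])"],
    env=env, check=True,
)
```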
+# ---------- Audio helpers ----------
 def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
+    # Ensure (frames, channels) shape for soundfile
+    if wav.ndim == 1:
+        data = wav
+    else:
+        # (channels, samples) -> (samples, channels)
+        data = wav.T if wav.shape[0] < wav.shape[1] else wav
+    if data.dtype == np.float64:
+        data = data.astype(np.float32)
+    sf.write(path.as_posix(), data, sr)

 def read_audio(path: str) -> Tuple[np.ndarray, int]:
     wav, sr = sf.read(path, always_2d=False)
+    if wav.dtype == np.float64:
+        wav = wav.astype(np.float32)
+    return wav, sr
+
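The transpose heuristic exists because soundfile expects (frames, channels) while audio arrays often arrive as (channels, samples). A standalone round-trip showing the convention the helpers normalize to (assumes channels < samples, which holds for any real recording):

```python
import numpy as np
import soundfile as sf

stereo = np.random.randn(2, 48000).astype(np.float32)  # (channels, samples)
sf.write("check.wav", stereo.T, 48000)                 # soundfile wants (frames, channels)

wav, sr = sf.read("check.wav", always_2d=False)
print(wav.shape, sr)  # -> (48000, 2) 48000
```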
+# ---------- CLI runner ----------
+def _candidate_commands(py: str, script: Path, ckpt: Path, inp: Path, prompt: str, out: Path) -> List[List[str]]:
+    """Try multiple arg styles commonly found in repos."""
+    combos = [
+        # infer_single.py (common)
+        [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--prompt", prompt, "--output", out.as_posix()],
+        [py, script.as_posix(), "--weights", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--out", out.as_posix()],
+        # other possible entrypoints
+        [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--output", out.as_posix()],
+    ]
+    return combos
+
+def run_sonicmaster_cli(
+    input_wav_path: Path,
+    prompt: str,
+    out_path: Path,
+    progress: Optional[gr.Progress] = None,
+) -> Tuple[bool, str]:
+    """
+    Returns (ok, message). Captures stdout/stderr and succeeds on the first
+    command that leaves a non-empty output file.
+    """
+    if progress:
+        progress(0.14, desc="Preparing inference")
+    ckpt = get_weights_path(progress=progress)
+    repo = ensure_repo(progress=progress)
+
+    # Candidate scripts to try
+    script_candidates = [
+        repo / "infer_single.py",
+        repo / "inference_fullsong.py",
+        repo / "inference_ptload_batch.py",
+    ]
+    scripts = [s for s in script_candidates if s.exists()]
+    if not scripts:
+        return False, "No inference script found in the repo (expected infer_single.py or similar)."

     py = sys.executable or "python3"
+    env = os.environ.copy()  # keep CUDA_VISIBLE_DEVICES etc.
+
+    last_err = ""
+    for idx, script in enumerate(scripts, start=1):
+        for jdx, cmd in enumerate(_candidate_commands(py, script, ckpt, input_wav_path, prompt, out_path), start=1):
+            try:
+                if progress:
+                    progress(min(0.20 + 0.08 * (idx + jdx), 0.70), desc=f"Running {script.name} (try {idx}.{jdx})")
+                res = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env)
+                if out_path.exists() and out_path.stat().st_size > 0:
+                    if progress:
+                        progress(0.88, desc="Post-processing output")
+                    # Return any informative stdout as the message
+                    msg = (res.stdout or "").strip()
+                    return True, msg if msg else "Inference completed."
+                else:
+                    last_err = f"{script.name} produced no output file."
+            except subprocess.CalledProcessError as e:
+                # Collect stderr/stdout for the user
+                snippet = "\n".join(filter(None, [e.stdout or "", e.stderr or ""])).strip()
+                last_err = snippet if snippet else f"{script.name} failed with return code {e.returncode}."
+            except Exception as e:
+                last_err = f"Unexpected error: {e}\n{traceback.format_exc()}"
+
+    return False, last_err or "All candidate commands failed without an error message."
+
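For reference, the runner's contract in isolation: it returns a `(bool, str)` pair rather than raising, so callers can surface the captured stderr/stdout directly. A hypothetical invocation (assumes this file imports cleanly as `app` and that the placeholder paths exist):

```python
from pathlib import Path

from app import run_sonicmaster_cli  # assumes this module is importable

ok, msg = run_sonicmaster_cli(
    Path("tmp_in.wav"),                        # placeholder input
    "Reduce reverb and brighten the vocals.",
    Path("tmp_out.wav"),                       # placeholder output
)
print("success" if ok else "failed")
print(msg[:500])  # stdout on success, captured stderr/stdout on failure
```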
+# ---------- REAL GPU function (called only if using ZeroGPU / GPU available) ----------
 @spaces.GPU(duration=180)
+def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> Tuple[bool, str]:
     try:
+        # Initialize CUDA inside the GPU context
         import torch  # noqa: F401
     except Exception:
         pass
     from pathlib import Path as _P
+    return run_sonicmaster_cli(_P(input_path), prompt, _P(output_path), progress=None)

+def _has_cuda() -> bool:
+    try:
+        import torch
+        return torch.cuda.is_available()
+    except Exception:
+        return False
+
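The two-signal dispatch below exists because on ZeroGPU the main process typically reports no CUDA device outside a `@spaces.GPU` call, so `torch.cuda.is_available()` alone would route everything to CPU. A condensed sketch of the same decision:

```python
import os

def pick_backend() -> str:
    # Env heuristic first: ZeroGPU exposes CUDA only inside @spaces.GPU calls.
    if os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu":
        return "zerogpu"
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"  # classic GPU Space, or a local machine with a GPU
    except Exception:
        pass
    return "cpu"

print(pick_backend())
```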
+# ---------- UI callback ----------
+def enhance_audio_ui(
+    audio_path: str,
+    prompt: str,
+    progress=gr.Progress(track_tqdm=True),
+) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
+    """
+    Returns (audio, message). On failure, audio is None and message carries the error text.
+    """
+    try:
+        if not prompt:
+            raise gr.Error("Please provide a text prompt.")
+        if not audio_path:
+            raise gr.Error("Please upload or select an input audio file.")
+
+        wav, sr = read_audio(audio_path)
+
+        tmp_in = SPACE_ROOT / "tmp_in.wav"
+        tmp_out = SPACE_ROOT / "tmp_out.wav"
+        if tmp_out.exists():
+            try:
+                tmp_out.unlink()
+            except Exception:
+                pass
+
+        if progress:
+            progress(0.06, desc="Preparing audio")
+        save_temp_wav(wav, sr, tmp_in)
+
+        # Choose execution path: prefer a real GPU if available, else CPU
+        use_gpu_call = USE_ZEROGPU or _has_cuda()
+
+        if progress:
+            progress(0.12, desc="Starting inference")
+        if use_gpu_call:
+            ok, msg = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
+        else:
+            ok, msg = run_sonicmaster_cli(tmp_in, prompt, tmp_out, progress=progress)
+
+        if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
+            # gr.Audio expects (sample_rate, data) for numpy output
+            out_wav, out_sr = read_audio(tmp_out.as_posix())
+            return ((out_sr, out_wav), msg or "Done.")
+        else:
+            # On failure: DON'T echo input audio — return None and the error message
+            if not msg:
+                msg = "Inference failed without a specific error message."
+            return (None, msg.strip())
+
+    except gr.Error as e:
+        return (None, str(e))
+    except Exception as e:
+        return (None, f"Unexpected error: {e}\n{traceback.format_exc()}")
+
+# ---------- Gradio UI ----------
+PROMPT_EXAMPLES = [
+    ["Increase the clarity of this song by emphasizing treble frequencies."],
+    ["Make this song sound more boomy by amplifying the low end bass frequencies."],
+    ["Make the audio smoother and less distorted."],
+    ["Improve the balance in this song."],
+    ["Reduce roominess/echo (dereverb)."],
+    ["Raise the level of the vocals."],
+    ["Give the song a wider stereo image."],
+]

 with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
+    gr.Markdown("## 🎧 SonicMaster\nUpload audio, enter a prompt, then click **Enhance**.\n"
+                "- Progress appears below during the first run (weights/repo download).\n"
+                "- If something fails, you'll see the **error message** instead of the input audio.")
     with gr.Row():
+        with gr.Column(scale=1):
             in_audio = gr.Audio(label="Input Audio", type="filepath")
+            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., Reduce reverb and brighten the vocals.")
+            run_btn = gr.Button("🚀 Enhance", variant="primary")
+            gr.Examples(
+                examples=PROMPT_EXAMPLES,
+                inputs=[prompt],  # prompt-only examples to avoid heavy file ops at startup
+                label="Prompt Examples",
+            )
+        with gr.Column(scale=1):
             out_audio = gr.Audio(label="Enhanced Audio (output)")
+            status = gr.Textbox(label="Status / Messages", interactive=False, lines=6)

+    # On click, return audio + message
+    run_btn.click(
+        fn=enhance_audio_ui,
+        inputs=[in_audio, prompt],
+        outputs=[out_audio, status],
+        concurrency_limit=1,
+    )
+
+# Queue BEFORE mounting so the mounted app is ready immediately.
+# Concurrency is already capped per event via concurrency_limit above;
+# Gradio 4's queue() no longer accepts concurrency_count.
+demo = demo.queue(max_size=16)
+
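On the queue change above: Gradio 4 removed `queue(concurrency_count=...)`; concurrency is now capped per event (as with `concurrency_limit=1` on the click handler) or queue-wide via `default_concurrency_limit`. A minimal sketch of the current idiom:

```python
import gradio as gr

with gr.Blocks() as d:
    btn = gr.Button("go")
    out = gr.Textbox()
    # Per-event cap: at most one of these handlers runs at a time.
    btn.click(lambda: "done", outputs=out, concurrency_limit=1)

d.queue(max_size=16)  # or d.queue(default_concurrency_limit=1, max_size=16)
```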
+# ---------- FastAPI mount & health ----------
 from fastapi import FastAPI, Request
 from starlette.responses import PlainTextResponse
+from starlette.requests import ClientDisconnect  # canonical location in Starlette

 app = FastAPI()

+@app.get("/health")
+def _health():
+    return {"ok": True}
+
 @app.exception_handler(ClientDisconnect)
 async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
     return PlainTextResponse("Client disconnected", status_code=499)

+# Mount Gradio at root (Spaces looks here)
+app = gr.mount_gradio_app(app, demo, path="/")

 if __name__ == "__main__":
     import uvicorn
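The hunk ends at `import uvicorn`, so whatever closes the `__main__` block is not shown. For local testing outside Spaces, a typical tail would look like the following (host and port are assumptions; 7860 is just the usual Spaces port):

```python
if __name__ == "__main__":
    import uvicorn
    # Hypothetical local entrypoint for the FastAPI app defined above.
    uvicorn.run(app, host="0.0.0.0", port=7860)
```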