ambujm22 commited on
Commit
5bce909
·
verified ·
1 Parent(s): 30c0cfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -21
app.py CHANGED
@@ -1,33 +1,246 @@
1
- # ---------- Gradio Space entrypoint (no FastAPI/Uvicorn) ----------
2
  import os
3
  os.environ.setdefault("GRADIO_USE_CDN", "true")
4
 
5
- import gradio as gr
6
-
7
- # Optional: harmless on CPU; useful once you switch to ZeroGPU hardware
8
  try:
9
  import spaces
10
- @spaces.GPU(duration=10)
11
- def gpu_probe(a: int = 1, b: int = 1):
12
- return a + b
13
  except Exception:
14
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Build a tiny UI
17
- def echo(s: str) -> str:
18
- return f"echo: {s}"
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- with gr.Blocks(title="Hello Space") as _demo:
21
- gr.Markdown("### βœ… App is alive\nType to echo.")
22
- inp = gr.Textbox(label="Input", value="hello")
23
- out = gr.Textbox(label="Output")
24
- inp.submit(echo, inp, out)
 
25
 
26
- # Expose all common names that the Space supervisor might look for
27
- demo = _demo.queue(max_size=8) # primary export
28
- iface = demo # alias
29
- app = demo # another alias
30
 
31
- # For local debugging only (ignored on Spaces)
32
  if __name__ == "__main__":
33
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
# ========== MUST BE FIRST: Gradio SDK entry + ZeroGPU probes ==========
import os

os.environ.setdefault("GRADIO_USE_CDN", "true")

# Optional: 'spaces' is present on Spaces; harmless to try locally.
try:
    import spaces
except Exception:
    # Off-Spaces fallback: expose a no-op @spaces.GPU decorator factory so
    # the module imports cleanly without the `spaces` package installed.
    class _DummySpaces:
        def GPU(self, *_, **__):
            return lambda fn: fn

    spaces = _DummySpaces()


# PUBLIC names so ZeroGPU supervisor can detect them
@spaces.GPU(duration=10)
def gpu_probe(a: int = 1, b: int = 1):
    """Tiny ZeroGPU health probe: return the sum of its two ints."""
    return a + b


@spaces.GPU(duration=10)
def gpu_echo(x: str = "ok"):
    """Tiny ZeroGPU health probe: return its argument unchanged."""
    return x
23
+
24
+ # ========== Standard imports ==========
25
+ import sys
26
+ import subprocess
27
+ from pathlib import Path
28
+ from typing import Tuple, Optional, List
29
+
30
+ import gradio as gr
31
+ import numpy as np
32
+ import soundfile as sf
33
+ from huggingface_hub import hf_hub_download
34
+
35
# ZeroGPU runtime hint (still safe on CPU)
USE_ZEROGPU = os.environ.get("SPACE_RUNTIME", "").lower() == "zerogpu"

# Filesystem layout and model coordinates.
SPACE_ROOT = Path(__file__).parent.resolve()
REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
REPO_URL = "https://github.com/AMAAI-Lab/SonicMaster"
WEIGHTS_REPO = "amaai-lab/SonicMaster"
WEIGHTS_FILE = "model.safetensors"
CACHE_DIR = SPACE_ROOT / "weights"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# ========== Lazy resources (no heavy work at import) ==========
_weights_path: Optional[Path] = None  # memoized path of the downloaded weights
_repo_ready: bool = False  # True once the repo is cloned and on sys.path
49
+
50
def get_weights_path(progress: Optional[gr.Progress] = None) -> Path:
    """Download/resolve model weights lazily; memoized after the first call.

    Args:
        progress: optional Gradio progress tracker to surface download status.

    Returns:
        Local path to the downloaded weights file.
    """
    global _weights_path
    if _weights_path is None:
        if progress:
            progress(0.10, desc="Downloading model weights (first run)")
        # NOTE: `local_dir_use_symlinks` and `resume_download` are deprecated
        # in recent huggingface_hub releases (ignored/warned, slated for
        # removal), and `force_download=False` is the default — a plain
        # local_dir download is the supported path and resumes automatically.
        wp = hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
            local_dir=str(CACHE_DIR),
        )
        _weights_path = Path(wp)
    return _weights_path
65
+
66
def ensure_repo(progress: Optional[gr.Progress] = None) -> Path:
    """Clone the SonicMaster repo on first use and make it importable.

    Args:
        progress: optional Gradio progress tracker for the first-run clone.

    Returns:
        Path to the local repo checkout.
    """
    global _repo_ready
    if _repo_ready:
        return REPO_DIR
    if not REPO_DIR.exists():
        if progress:
            progress(0.18, desc="Cloning SonicMaster repo (first run)")
        # Shallow clone keeps the first-run download small.
        subprocess.run(
            ["git", "clone", "--depth", "1", REPO_URL, REPO_DIR.as_posix()],
            check=True,
        )
    repo_path = REPO_DIR.as_posix()
    if repo_path not in sys.path:
        sys.path.append(repo_path)
    _repo_ready = True
    return REPO_DIR
80
+
81
# ========== Helpers ==========
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    """Write an in-memory waveform to `path` as a WAV file via soundfile.

    soundfile expects (samples, channels); a 2-D array with more columns than
    rows is treated as channels-first (heuristic) and transposed. float64 data
    is downcast to float32 before writing.
    """
    channels_first = wav.ndim == 2 and wav.shape[0] < wav.shape[1]
    if channels_first:
        wav = wav.T
    if wav.dtype == np.float64:
        wav = wav.astype(np.float32)
    sf.write(path.as_posix(), wav, sr)
89
+
90
def read_audio(path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file and return (waveform, sample_rate) with float32 data."""
    data, rate = sf.read(path, always_2d=False)
    if data.dtype == np.float64:
        data = data.astype(np.float32)
    return data, rate
95
+
96
+ def _candidate_commands(py: str, script: Path, ckpt: Path, inp: Path, prompt: str, out: Path) -> List[List[str]]:
97
+ # Try common flag layouts
98
+ return [
99
+ [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--prompt", prompt, "--output", out.as_posix()],
100
+ [py, script.as_posix(), "--weights",ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--out", out.as_posix()],
101
+ [py, script.as_posix(), "--ckpt", ckpt.as_posix(), "--input", inp.as_posix(), "--text", prompt, "--output", out.as_posix()],
102
+ ]
103
+
104
def run_sonicmaster_cli(
    input_wav_path: Path,
    prompt: str,
    out_path: Path,
    progress: Optional[gr.Progress] = None,
) -> Tuple[bool, str]:
    """Run inference scripts via subprocess; return (ok, message).

    Tries each inference script present in the cloned repo, and for each
    script tries every flag layout from _candidate_commands, until one run
    produces a non-empty `out_path`. On success `message` is the run's
    stdout; on failure it is the last captured error text.
    """
    if progress: progress(0.14, desc="Preparing inference")
    # Lazy one-time setup: weights download and repo clone (both memoized).
    ckpt = get_weights_path(progress=progress)
    repo = ensure_repo(progress=progress)

    # Known entry points across repo revisions; keep only those that exist.
    candidates = [repo / "infer_single.py", repo / "inference_fullsong.py", repo / "inference_ptload_batch.py"]
    scripts = [s for s in candidates if s.exists()]
    if not scripts:
        return False, "No inference script found in the repo (expected infer_single.py or similar)."

    py = sys.executable or "python3"
    env = os.environ.copy()

    last_err = ""
    for sidx, script in enumerate(scripts, 1):
        for cidx, cmd in enumerate(_candidate_commands(py, script, ckpt, input_wav_path, prompt, out_path), 1):
            try:
                if progress:
                    # Rough progress interpolation across attempts, capped at 0.70.
                    progress(min(0.20 + 0.08 * (sidx + cidx), 0.70), desc=f"Running {script.name} (try {sidx}.{cidx})")
                res = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env)
                # A zero exit code alone is not success: the script must have
                # actually written a non-empty output file.
                if out_path.exists() and out_path.stat().st_size > 0:
                    if progress: progress(0.88, desc="Post-processing output")
                    return True, (res.stdout or "Inference completed.").strip()
                last_err = f"{script.name} produced no output file."
            except subprocess.CalledProcessError as e:
                # Prefer the subprocess's own stdout/stderr as the error text.
                snippet = "\n".join(filter(None, [e.stdout or "", e.stderr or ""])).strip()
                last_err = snippet if snippet else f"{script.name} failed with return code {e.returncode}."
            except Exception as e:
                import traceback
                last_err = f"Unexpected error: {e}\n{traceback.format_exc()}"
    return False, last_err or "All candidate commands failed."
141
+
142
# ========== GPU path (called only if ZeroGPU/GPU available) ==========
@spaces.GPU(duration=180)
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> Tuple[bool, str]:
    """GPU-decorated wrapper around run_sonicmaster_cli (str paths in, (ok, msg) out)."""
    # Import torch for its side effect only — presumably to warm the CUDA
    # context inside the GPU slot; a missing/broken torch is tolerated.
    try:
        import torch  # noqa: F401
    except Exception:
        pass
    from pathlib import Path as _Path
    return run_sonicmaster_cli(_Path(input_path), prompt, _Path(output_path), progress=None)
151
+
152
+ def _has_cuda() -> bool:
153
+ try:
154
+ import torch
155
+ return torch.cuda.is_available()
156
+ except Exception:
157
+ return False
158
+
159
# ========== Gradio callback ==========
def enhance_audio_ui(
    audio_path: str,
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
    """
    Returns (audio, message). On failure, audio=None and message=error text.
    """
    try:
        # Input validation: raised gr.Error is caught below and turned into
        # a status message rather than a traceback.
        if not prompt:
            raise gr.Error("Please provide a text prompt.")
        if not audio_path:
            raise gr.Error("Please upload or select an input audio file.")

        samples, rate = read_audio(audio_path)
        tmp_in = SPACE_ROOT / "tmp_in.wav"
        tmp_out = SPACE_ROOT / "tmp_out.wav"
        # Remove any stale output so the success check below is reliable.
        if tmp_out.exists():
            try:
                tmp_out.unlink()
            except Exception:
                pass

        if progress:
            progress(0.06, desc="Preparing audio")
        save_temp_wav(samples, rate, tmp_in)

        prefer_gpu = USE_ZEROGPU or _has_cuda()
        if progress:
            progress(0.12, desc="Starting inference")

        if prefer_gpu:
            ok, msg = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
        else:
            ok, msg = run_sonicmaster_cli(tmp_in, prompt, tmp_out, progress=progress)

        if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
            result_wav, result_sr = read_audio(tmp_out.as_posix())
            return (result_sr, result_wav), (msg or "Done.")
        return None, (msg or "Inference failed without a specific error message.")

    except gr.Error as e:
        return None, str(e)
    except Exception as e:
        import traceback
        return None, f"Unexpected error: {e}\n{traceback.format_exc()}"
203
+
204
# ========== Gradio UI ==========
# Example prompts; gr.Examples expects one input-row (list) per example.
_PROMPT_TEXTS = (
    "Increase the clarity of this song by emphasizing treble frequencies.",
    "Make this song sound more boomy by amplifying the low end bass frequencies.",
    "Make the audio smoother and less distorted.",
    "Improve the balance in this song.",
    "Reduce roominess/echo (dereverb).",
    "Raise the level of the vocals.",
    "Give the song a wider stereo image.",
)
PROMPT_EXAMPLES = [[text] for text in _PROMPT_TEXTS]
214
 
215
# Build the UI. All component labels/strings are kept exactly as-is.
with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as _demo:
    gr.Markdown(
        "## 🎧 SonicMaster\n"
        "Upload or choose an example prompt, write your own instruction, then click **Enhance**.\n"
        "- First run downloads model weights & repo (progress will show).\n"
        "- On failure, the **Status** box shows the exact error (we won't echo the input audio)."
    )
    with gr.Row():
        # Left column: inputs and trigger.
        with gr.Column():
            # type="filepath" hands the callback a path string, not raw audio.
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., Reduce reverb and brighten vocals.")
            run_btn = gr.Button("πŸš€ Enhance", variant="primary")
            gr.Examples(examples=PROMPT_EXAMPLES, inputs=[prompt], label="Prompt Examples")
        # Right column: results and diagnostics.
        with gr.Column():
            out_audio = gr.Audio(label="Enhanced Audio (output)")
            status = gr.Textbox(label="Status / Messages", interactive=False, lines=8)

    # concurrency_limit=1 serializes runs — the callback reuses fixed
    # tmp_in.wav/tmp_out.wav paths, so parallel runs would collide.
    run_btn.click(
        fn=enhance_audio_ui,
        inputs=[in_audio, prompt],
        outputs=[out_audio, status],
        concurrency_limit=1,
    )

# Expose all common names the supervisor might look for
demo = _demo.queue(max_size=16)
iface = demo
app = demo
243
 
244
# Local debugging only — on Spaces the supervisor serves `demo` itself,
# so this branch never runs there.
if __name__ == "__main__":
    # Bind all interfaces on the standard Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)