Yongyi Zang
commited on
Commit
Β·
3617711
1
Parent(s):
7acb2e5
Change Files
Browse files- __pycache__/model.cpython-313.pyc +0 -0
- app.py +6 -4
__pycache__/model.cpython-313.pyc
ADDED
|
Binary file (22 kB). View file
|
|
|
app.py
CHANGED
|
@@ -24,7 +24,7 @@ def _get_model(ckpt_name: str):
|
|
| 24 |
raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
|
| 25 |
if ckpt_name in _model_cache:
|
| 26 |
return _model_cache[ckpt_name]
|
| 27 |
-
ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.
|
| 28 |
model = UFormer(config).to(DEVICE).eval()
|
| 29 |
state = torch.load(ckpt_path, map_location=DEVICE)
|
| 30 |
model.load_state_dict(state)
|
|
@@ -43,6 +43,7 @@ def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=
|
|
| 43 |
out = np.zeros_like(x_pad)
|
| 44 |
norm = np.zeros((1, x_pad.shape[1]))
|
| 45 |
n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
|
|
|
|
| 46 |
|
| 47 |
for i in range(n_chunks):
|
| 48 |
s = i * hop
|
|
@@ -52,7 +53,8 @@ def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=
|
|
| 52 |
out[:, s:s+chunk] += y * win
|
| 53 |
norm[:, s:s+chunk] += win
|
| 54 |
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
# ββββββββββββββββββββββ
|
| 58 |
# 3) Restore function for Gradio
|
|
@@ -81,12 +83,12 @@ def restore_fn(audio_path, checkpoint):
|
|
| 81 |
demo = gr.Interface(
|
| 82 |
fn=restore_fn,
|
| 83 |
inputs=[
|
| 84 |
-
gr.Audio(
|
| 85 |
gr.Dropdown(VALID_CKPTS, label="Checkpoint")
|
| 86 |
],
|
| 87 |
outputs=gr.Audio(type="filepath", label="Restored Output"),
|
| 88 |
title="π΅ Music Source Restoration",
|
| 89 |
-
description="Upload
|
| 90 |
allow_flagging="never"
|
| 91 |
)
|
| 92 |
|
|
|
|
| 24 |
raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
|
| 25 |
if ckpt_name in _model_cache:
|
| 26 |
return _model_cache[ckpt_name]
|
| 27 |
+
ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
|
| 28 |
model = UFormer(config).to(DEVICE).eval()
|
| 29 |
state = torch.load(ckpt_path, map_location=DEVICE)
|
| 30 |
model.load_state_dict(state)
|
|
|
|
| 43 |
out = np.zeros_like(x_pad)
|
| 44 |
norm = np.zeros((1, x_pad.shape[1]))
|
| 45 |
n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
|
| 46 |
+
print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")
|
| 47 |
|
| 48 |
for i in range(n_chunks):
|
| 49 |
s = i * hop
|
|
|
|
| 53 |
out[:, s:s+chunk] += y * win
|
| 54 |
norm[:, s:s+chunk] += win
|
| 55 |
|
| 56 |
+
eps = 1e-8
|
| 57 |
+
return (out / (norm + eps))[:, :T]
|
| 58 |
|
| 59 |
# ββββββββββββββββββββββ
|
| 60 |
# 3) Restore function for Gradio
|
|
|
|
| 83 |
demo = gr.Interface(
|
| 84 |
fn=restore_fn,
|
| 85 |
inputs=[
|
| 86 |
+
gr.Audio(sources="upload", type="filepath", label="Your Input"),
|
| 87 |
gr.Dropdown(VALID_CKPTS, label="Checkpoint")
|
| 88 |
],
|
| 89 |
outputs=gr.Audio(type="filepath", label="Restored Output"),
|
| 90 |
title="π΅ Music Source Restoration",
|
| 91 |
+
description="Upload an (stereo) audio file and choose an instrument/group checkpoint to restore. Please note that these are baseline models for demonstration purposes only, and most of them don't perform really well...",
|
| 92 |
allow_flagging="never"
|
| 93 |
)
|
| 94 |
|