lllindsey0615 committed
Commit 48dbef5 · 1 Parent(s): 9d368e9

initial commit

Files changed (3)
  1. README.md +1 -1
  2. app.py +111 -0
  3. requirements.txt +18 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 👁
 colorFrom: purple
 colorTo: pink
 sdk: gradio
-sdk_version: 5.44.0
+sdk_version: 5.28.0
 app_file: app.py
 pinned: false
 short_description: Demucs stem separator wrapped in pyharp
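
Note: `sdk_version` in a Space's README front matter selects the Gradio release the Space runs on; the rollback from 5.44.0 to 5.28.0 presumably matches the Gradio version that the pyharp branch pinned in requirements.txt supports.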
app.py ADDED
@@ -0,0 +1,111 @@
+import spaces
+from pyharp.core import ModelCard, build_endpoint
+from pyharp.media.audio import load_audio, save_audio
+from pyharp.labels import LabelList
+
+from audiotools import AudioSignal
+from demucs import pretrained
+from demucs.apply import apply_model
+
+import gradio as gr
+import torchaudio
+import torch
+from pathlib import Path
+
+# ModelCard
+model_card = ModelCard(
+    name="Demucs Stem Separator (All Stems)",
+    description="Separates a music mixture into all individual stems using a Demucs model.",
+    author="Alexandre Défossez, et al.",
+    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
+)
+
+
+DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
+STEM_NAMES = ["Drums", "Bass", "Other", "Vocals"]
+
+# Global model cache
+LOADED_MODELS = {}
+
+def get_cached_model(model_name: str):
+    if model_name not in LOADED_MODELS:
+        model = pretrained.get_model(model_name)
+        model.eval()
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        LOADED_MODELS[model_name] = model
+    return LOADED_MODELS[model_name]
+
+# Separation Logic (all stems)
+def separate_all_stems(audio_file_path: str, model_name: str):
+    model = get_cached_model(model_name)
+    waveform, sr = torchaudio.load(audio_file_path)
+    # Demucs expects a fixed sample rate (model.samplerate, 44.1 kHz); resample if needed
+    if sr != model.samplerate:
+        waveform = torchaudio.functional.resample(waveform, sr, model.samplerate)
+        sr = model.samplerate
+    is_mono = waveform.shape[0] == 1
+    if is_mono:
+        waveform = waveform.repeat(2, 1)
+
+    with torch.no_grad():
+        stems_batch = apply_model(
+            model,
+            waveform.unsqueeze(0).to(next(model.parameters()).device),
+            overlap=0.2,
+            shifts=1,
+            split=True,
+        )
+    stems = stems_batch[0]
+
+    output_signals = []
+    for stem in stems:
+        if is_mono:
+            stem = stem.mean(dim=0, keepdim=True)
+        signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
+        output_signals.append(signal)
+
+    return output_signals  # [drums, bass, other, vocals]
+
+# Process Function
+@spaces.GPU
+def process_fn(audio_file_path, model_name):
+    output_signals = separate_all_stems(audio_file_path, model_name)
+    outputs = []
+    for stem_name, signal in zip(STEM_NAMES, output_signals):
+        filename = f"demucs_{model_name}_{stem_name.lower()}.wav"
+        output_audio_path = save_audio(signal, filename)
+        outputs.append(output_audio_path)
+    return *outputs, LabelList()  # unpack so each stem maps to one output component
+
+# Gradio App
+with gr.Blocks() as demo:
+    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
+
+    model_dropdown = gr.Dropdown(
+        choices=DEMUX_MODELS,
+        label="Demucs Model",
+        value="htdemucs"
+    )
+
+    # Outputs: Multiple stems
+    output_drums = gr.Audio(type="filepath", label="Drums")
+    output_bass = gr.Audio(type="filepath", label="Bass")
+    output_other = gr.Audio(type="filepath", label="Other")
+    output_vocals = gr.Audio(type="filepath", label="Vocals")
+    output_labels = gr.JSON(label="Labels")
+
+    app = build_endpoint(
+        model_card=model_card,
+        input_components=[input_audio, model_dropdown],
+        output_components=[
+            output_drums,
+            output_bass,
+            output_other,
+            output_vocals,
+            output_labels
+        ],
+        process_fn=process_fn
+    )
+
+demo.queue().launch(share=True, show_error=True)
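
For a quick local sanity check of the separation path, a minimal sketch follows (not part of the commit; "mix.wav" is a hypothetical test file, and the functions above are assumed to be in scope, e.g. pasted into a REPL, since importing app.py directly would also launch the demo — launch() runs at module scope):

    # Run one separation and report the shape of each returned stem.
    # AudioSignal.audio_data is a (batch, channels, samples) tensor.
    signals = separate_all_stems("mix.wav", "htdemucs")
    for name, sig in zip(STEM_NAMES, signals):
        print(f"{name}: {tuple(sig.audio_data.shape)} @ {sig.sample_rate} Hz")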
requirements.txt ADDED
@@ -0,0 +1,18 @@
+# PyHARP from correct branch
+git+https://github.com/TEAMuP-dev/pyharp.git@cb/gen-inputs
+
+# Demucs & dependencies
+demucs==4.0.0
+dora-search==0.1.11
+einops==0.6.1
+julius>=0.2.3
+lameenc>=1.2
+openunmix==1.2.1
+
+# Audio & ML
+torch>=1.8.1, <2.1
+torchaudio>=0.8, <2.1
+ffmpeg
+soundfile
+scipy
+numpy<2
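
Note: the `ffmpeg` package on PyPI does not install the ffmpeg binary itself; on Hugging Face Spaces the binary is usually supplied via a separate packages.txt file. WAV input still works without it through torchaudio/soundfile.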