import spaces
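# Hugging Face Spaces helper; on ZeroGPU hardware the GPU-bound function is
# typically wrapped with the @spaces.GPU decorator (not applied here).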
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio
from pyharp.labels import LabelList

from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model

import gradio as gr
import torch
from pathlib import Path
import numpy as np

# ModelCard
model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    description=(
        "Separates a full music track into four stems (Drums, Bass, Vocals, and Other / instrumental) "
        "using the Demucs model. Available models:\n"
        "htdemucs: best overall fidelity and most natural sound. "
        "mdx_extra: balanced trade-off, faster processing with a small drop in quality. "
        "mdx_extra_q: fastest and lightest, noticeable quality loss; ideal for quick previews.\n"
    ),
    author="Alexandre Défossez, et al.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
)

DEMUX_MODELS = ["htdemucs", "mdx_extra", "mdx_extra_q"]
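# Display order of the HARP outputs; Demucs itself emits sources in the order
# drums, bass, other, vocals (handled in separate_all_stems below).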
STEM_NAMES = ["Drums", "Bass", "Vocals", "Other"]

# Global model cache
LOADED_MODELS = {}

def get_cached_model(model_name: str):
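    """Load each Demucs variant once, move it to GPU if available, and reuse it."""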
    if model_name not in LOADED_MODELS:
        model = pretrained.get_model(model_name)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        LOADED_MODELS[model_name] = model
    return LOADED_MODELS[model_name]

# Separation Logic
def separate_all_stems(audio_file_path: str, model_name: str):
    model = get_cached_model(model_name)

    signal = AudioSignal(audio_file_path)
    signal = signal.resample(44100)  # Demucs models expect 44.1 kHz audio

    is_mono = signal.num_channels == 1
    if is_mono:
        # AudioSignal stores audio as [batch, channels, samples]; duplicate the
        # mono channel so Demucs (which expects stereo input) receives two channels.
        signal.audio_data = signal.audio_data.repeat(1, 2, 1)

    sr = signal.sample_rate  

    # Ensure audio_data is a torch.Tensor
    audio = signal.audio_data
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)

    audio = audio.float()  # AudioSignal data is [batch, channels, samples]

    # Drop the batch dimension, then rebuild the [1, channels, samples] batch
    # that demucs.apply.apply_model expects.
    if audio.ndim > 2:
        audio = audio.squeeze(0)

    waveform = audio.unsqueeze(0)


    with torch.no_grad():
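        # split=True processes the track in overlapping chunks (overlap=0.2) to
        # bound memory use; shifts=1 applies one random time shift and averages
        # the prediction, trading a little extra compute for smoother output.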
        stems_batch = apply_model(
            model,
            waveform.to(next(model.parameters()).device),
            overlap=0.2,
            shifts=1,
            split=True,
        )

    stems = stems_batch[0]  # [num_sources, channels, samples]

    output_signals = []
    for stem in stems:
        if is_mono:
            stem = stem.mean(dim=0, keepdim=True)
        signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        output_signals.append(signal)

    # Demucs emits sources in the order drums, bass, other, vocals;
    # reorder to match STEM_NAMES (Drums, Bass, Vocals, Other).
    drums, bass, other, vocals = output_signals

    return [drums, bass, vocals, other]

# Process Function
def process_fn(audio_file_path, model_name):
    if model_name not in DEMUX_MODELS:
        raise ValueError(f"Unsupported model selected: {model_name}")

    output_signals = separate_all_stems(audio_file_path, model_name)

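    # Keep the output extension consistent with the input: mp3 in, mp3 out; otherwise wav.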
    is_mp3 = Path(audio_file_path).suffix.lower() == ".mp3"
    extension = "mp3" if is_mp3 else "wav"

    outputs = []
    for stem_name, signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
        output_path = Path(filename)

        save_audio(signal, output_path)
        outputs.append(str(output_path))

    return tuple(outputs)

# Gradio App
with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=DEMUX_MODELS,
        label="Demucs Model",
        value="htdemucs",
        info=(
            "Choose which Demucs variant to use for separation:\n"
            "1. htdemucs: High-quality Hybrid Transformer Demucs trained on MusDB + 800 songs.\n"
            "2. mdx_extra: Trained with extended data. Balanced model offering a clear trade-off between quality and speed.\n"
            "3. mdx_extra_q: Quantized lightweight version of mdx_extra.\n"
        )
    )

    output_drums = gr.Audio(type="filepath", label="Drums")
    output_bass = gr.Audio(type="filepath", label="Bass")
    output_vocals = gr.Audio(type="filepath", label="Vocals")
    output_other = gr.Audio(type="filepath", label="Other")

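    # build_endpoint (pyharp) wires the Gradio components and process_fn into a
    # HARP-compatible processing endpoint.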
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums,
            output_bass,
            output_vocals,
            output_other
        ],
        process_fn=process_fn
    )

demo.queue().launch(share=True, show_error=True)