File size: 3,460 Bytes
48dbef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c06cf0
48dbef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c06cf0
 
 
 
 
48dbef5
 
 
 
4c06cf0
48dbef5
 
4c06cf0
48dbef5
 
4c06cf0
199495b
 
48dbef5
 
 
 
 
 
 
 
 
 
 
 
 
 
4c06cf0
 
48dbef5
 
 
 
 
 
 
4c06cf0
 
48dbef5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import spaces
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio, save_audio
from pyharp.labels import LabelList

from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model

import gradio as gr
import torchaudio
import torch
from pathlib import Path

# Model card handed to build_endpoint when the Gradio app is assembled.
model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    author="Alexandre Défossez, et al.",
    description="Separates a music mixture into all individual stems using a Demucs model.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"],
)


# Pretrained Demucs model names offered in the model dropdown.
DEMUX_MODELS = [
    "mdx_extra_q",
    "mdx_extra",
    "htdemucs",
    "mdx_q",
]
# Human-readable stem labels, in the order the stems are returned/output.
STEM_NAMES = [
    "Drums",
    "Bass",
    "Vocals",
    "Instrumental (No Vocals)",
]

# Process-wide cache of loaded models, keyed by model name
# (populated lazily by get_cached_model).
LOADED_MODELS: dict = {}

def get_cached_model(model_name: str):
    """Return the pretrained Demucs model ``model_name``, loading it on first use.

    The first request for a name loads it with ``pretrained.get_model``,
    switches it to eval mode, moves it to CUDA when available (CPU otherwise)
    and stores it in ``LOADED_MODELS``; later requests return that same object.
    """
    # Fast path: already loaded by an earlier call.
    if model_name in LOADED_MODELS:
        return LOADED_MODELS[model_name]

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded = pretrained.get_model(model_name)
    loaded.eval()
    loaded.to(target_device)

    LOADED_MODELS[model_name] = loaded
    return loaded

# Separation Logic (all stems)
def separate_all_stems(audio_file_path: str, model_name: str):
    """Separate an audio file into drums, bass, vocals and instrumental stems.

    The file is loaded with torchaudio and converted to the channel count and
    sample rate the Demucs model expects. It is then separated with
    ``apply_model``, and each stem is converted back to the input's sample
    rate (and to mono if the input was mono).

    Args:
        audio_file_path: Path to an audio file readable by ``torchaudio.load``.
        model_name: Name of a pretrained Demucs model (e.g. one of
            ``DEMUX_MODELS``).

    Returns:
        A list of four ``AudioSignal`` objects ordered
        ``[drums, bass, vocals, instrumental]``, where ``instrumental`` is the
        sum of every source the model produces other than ``vocals``.

    Raises:
        ValueError: If the model does not produce ``drums``, ``bass`` and
            ``vocals`` sources.
    """
    model = get_cached_model(model_name)
    device = next(model.parameters()).device
    model_sr = model.samplerate
    model_channels = model.audio_channels

    # Look stems up by name rather than position so models with a different
    # source order or count (e.g. 6-stem variants) are handled correctly.
    sources = list(model.sources)
    missing = {"drums", "bass", "vocals"} - set(sources)
    if missing:
        raise ValueError(
            f"Model {model_name!r} does not produce required sources: {sorted(missing)}"
        )

    waveform, sr = torchaudio.load(audio_file_path)
    is_mono = waveform.shape[0] == 1
    if is_mono:
        waveform = waveform.repeat(model_channels, 1)
    elif waveform.shape[0] > model_channels:
        # Multichannel input: keep the leading channels the model can take.
        waveform = waveform[:model_channels]

    # Demucs is trained at a fixed rate; feeding audio at another rate
    # degrades separation, so resample in (and back out below).
    if sr != model_sr:
        waveform = torchaudio.functional.resample(waveform, sr, model_sr)

    with torch.no_grad():
        stems = apply_model(
            model,
            waveform.unsqueeze(0).to(device),
            overlap=0.2,
            shifts=1,
            split=True,
        )[0]

    by_name = dict(zip(sources, stems))

    def to_signal(audio):
        # Restore the input's channel layout and sample rate.
        if is_mono:
            audio = audio.mean(dim=0, keepdim=True)
        audio = audio.cpu()
        if sr != model_sr:
            audio = torchaudio.functional.resample(audio, model_sr, sr)
        return AudioSignal(audio.numpy().astype("float32"), sample_rate=sr)

    # Instrumental = everything except vocals (drums + bass + other for the
    # standard 4-source models).
    instrumental = torch.stack(
        [stem for name, stem in by_name.items() if name != "vocals"]
    ).sum(dim=0)

    return [
        to_signal(by_name["drums"]),
        to_signal(by_name["bass"]),
        to_signal(by_name["vocals"]),
        to_signal(instrumental),
    ]

# Process Function
def process_fn(audio_file_path, model_name):
    """Separate ``audio_file_path`` and write each stem to its own WAV file.

    Returns:
        A tuple of whatever ``save_audio`` returns for each stem, one entry
        per ``STEM_NAMES`` label and in that order.
    """
    def stem_filename(stem_name):
        slug = stem_name.lower().replace(' ', '_')
        return f"demucs_{model_name}_{slug}.wav"

    separated = separate_all_stems(audio_file_path, model_name)
    return tuple(
        save_audio(stem_signal, stem_filename(label))
        for label, stem_signal in zip(STEM_NAMES, separated)
    )

# Gradio App
with gr.Blocks() as demo:
    # Inputs: the mixture to separate and the Demucs checkpoint to use.
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=DEMUX_MODELS,
        label="Demucs Model",
        value="htdemucs",
    )

    # Outputs: one audio component per stem. Labels come from STEM_NAMES so
    # they stay aligned with the order of the tuple process_fn returns.
    output_drums, output_bass, output_vocals, output_instrumental = (
        gr.Audio(type="filepath", label=stem_name) for stem_name in STEM_NAMES
    )

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums,
            output_bass,
            output_vocals,
            output_instrumental,
        ],
        process_fn=process_fn,
    )

demo.queue().launch(share=True, show_error=True)