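# PyHARP + Gradio app: separates a music mixture into stems (drums, bass,
# vocals, and an instrumental mix) with Demucs and exposes the process as a
# HARP endpoint.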
import spaces
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio
from pyharp.labels import LabelList
from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model
import gradio as gr
import torch
from pathlib import Path
import numpy as np
# ModelCard
model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    description="Separates a music mixture into all individual stems using a Demucs model.",
    author="Alexandre Défossez, et al.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
)
DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
STEM_NAMES = ["Drums", "Bass", "Vocals", "Instrumental"]
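# All of the models above are four-source Demucs variants; their native source
# order is [drums, bass, other, vocals], which separate_all_stems reorders below.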
# Global model cache so each Demucs checkpoint is loaded only once
LOADED_MODELS = {}

def get_cached_model(model_name: str):
    if model_name not in LOADED_MODELS:
        model = pretrained.get_model(model_name)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        LOADED_MODELS[model_name] = model
    return LOADED_MODELS[model_name]
# Separation Logic
def separate_all_stems(audio_file_path: str, model_name: str):
    model = get_cached_model(model_name)
    signal = AudioSignal(audio_file_path)
    signal = signal.resample(44100)  # Demucs models expect 44.1 kHz input
    is_mono = signal.num_channels == 1
    if is_mono:
        # Duplicate the mono channel; audio_data is [batch, channels, samples]
        signal.audio_data = signal.audio_data.repeat(1, 2, 1)
    sr = signal.sample_rate
    # Ensure audio_data is a torch.Tensor
    audio = signal.audio_data
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    audio = audio.float()  # [batch, channels, samples]
    # Drop the batch dimension if present
    if audio.ndim > 2:
        audio = audio.squeeze(0)
    # Demucs expects input of shape [batch, channels, samples]
    waveform = audio.unsqueeze(0)
    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.to(next(model.parameters()).device),
            overlap=0.2,
            shifts=1,
            split=True,
        )
    # apply_model returns [batch, sources, channels, samples]
    stems = stems_batch[0]
    output_signals = []
    for stem in stems:
        if is_mono:
            # Fold back to mono for inputs that were originally mono
            stem = stem.mean(dim=0, keepdim=True)
        stem_signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        output_signals.append(stem_signal)
    # Demucs source order is [drums, bass, other, vocals]
    drums, bass, other, vocals = output_signals
    # Instrumental = everything except vocals
    instrumental = drums + bass + other
    return [drums, bass, vocals, instrumental]
# Process Function
@spaces.GPU  # request a GPU when running on ZeroGPU Spaces
def process_fn(audio_file_path, model_name):
    output_signals = separate_all_stems(audio_file_path, model_name)
    # Keep the output format consistent with the input file
    is_mp3 = Path(audio_file_path).suffix.lower() == ".mp3"
    extension = "mp3" if is_mp3 else "wav"
    outputs = []
    for stem_name, stem_signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
        output_path = Path(filename)
        save_audio(stem_signal, output_path)
        outputs.append(str(output_path))
    return tuple(outputs)
# Gradio App
with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
    model_dropdown = gr.Dropdown(
        choices=DEMUX_MODELS,
        label="Demucs Model",
        value="htdemucs"
    )
    output_drums = gr.Audio(type="filepath", label="Drums")
    output_bass = gr.Audio(type="filepath", label="Bass")
    output_vocals = gr.Audio(type="filepath", label="Vocals")
    output_instrumental = gr.Audio(type="filepath", label="Instrumental")
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums,
            output_bass,
            output_vocals,
            output_instrumental
        ],
        process_fn=process_fn
    )
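# share=True also prints a public Gradio URL; HARP can connect to either that
# link or the local URL shown at launch.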
demo.queue().launch(share=True, show_error=True)