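"""PyHARP Gradio app that separates a music track into four stems
(drums, bass, vocals, other) with Demucs and serves them as a HARP endpoint."""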
import spaces
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio
from pyharp.labels import LabelList
from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model
import gradio as gr
import torch
from pathlib import Path
import numpy as np
# ModelCard
model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    description=(
        "Separates a full music track into four audio stems (Drums, Bass, Vocals, Other) "
        "using the Demucs model. Available models:\n"
        "htdemucs: best overall fidelity and most natural sound.\n"
        "mdx_extra: balanced trade-off; faster processing with a small drop in quality.\n"
        "mdx_extra_q: fastest and lightest; noticeable quality loss, ideal for quick previews.\n"
    ),
    author="Alexandre Défossez, et al.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"],
)
DEMUX_MODELS = ["htdemucs", "mdx_extra", "mdx_extra_q"]
STEM_NAMES = ["Drums", "Bass", "Vocals", "Other"]
# Global model cache
LOADED_MODELS = {}
def get_cached_model(model_name: str):
    """Load a Demucs model once, move it to the available device, and reuse it."""
    if model_name not in LOADED_MODELS:
        model = pretrained.get_model(model_name)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        LOADED_MODELS[model_name] = model
    return LOADED_MODELS[model_name]
# Separation Logic
def separate_all_stems(audio_file_path: str, model_name: str):
    model = get_cached_model(model_name)
    signal = AudioSignal(audio_file_path)
    signal = signal.resample(44100)  # Demucs expects 44.1 kHz audio
    is_mono = signal.num_channels == 1
    if is_mono:
        # Duplicate the mono channel so the model receives stereo input
        # (AudioSignal.audio_data is [batch, channels, samples]).
        signal.audio_data = signal.audio_data.repeat(1, 2, 1)
    sr = signal.sample_rate
    # Ensure audio_data is a float torch.Tensor
    audio = signal.audio_data
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    audio = audio.float()
    # Drop any extra singleton dimensions down to [channels, samples]
    if audio.ndim > 2:
        audio = audio.squeeze()
    # Demucs expects a batch dimension: [1, channels, samples]
    waveform = audio.unsqueeze(0)
    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.to(next(model.parameters()).device),
            overlap=0.2,
            shifts=1,
            split=True,
        )
    stems = stems_batch[0]  # [num_sources, channels, samples]
    output_signals = []
    for stem in stems:
        if is_mono:
            # Fold the duplicated channels back down to mono
            stem = stem.mean(dim=0, keepdim=True)
        output_signals.append(
            AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        )
    # Demucs returns sources in the order drums, bass, other, vocals;
    # reorder to match STEM_NAMES (Drums, Bass, Vocals, Other).
    drums, bass, other, vocals = output_signals
    return [drums, bass, vocals, other]
# Process Function
@spaces.GPU  # request a GPU on ZeroGPU Spaces; effectively a no-op elsewhere
def process_fn(audio_file_path, model_name):
    if model_name not in DEMUX_MODELS:
        raise ValueError(f"Unsupported model selected: {model_name}")
    output_signals = separate_all_stems(audio_file_path, model_name)
    # Keep the input's container format for the rendered stems
    is_mp3 = Path(audio_file_path).suffix.lower() == ".mp3"
    extension = "mp3" if is_mp3 else "wav"
    outputs = []
    for stem_name, signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
        output_path = Path(filename)
        save_audio(signal, output_path)
        outputs.append(str(output_path))
    return tuple(outputs)
# Gradio App
with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
    model_dropdown = gr.Dropdown(
        choices=DEMUX_MODELS,
        label="Demucs Model",
        value="htdemucs",
        info=(
            "Choose which Demucs variant to use for separation:\n"
            "1. htdemucs: high-quality Hybrid Transformer Demucs trained on MUSDB plus 800 extra songs.\n"
            "2. mdx_extra: trained on extended data; a balanced trade-off between quality and speed.\n"
            "3. mdx_extra_q: quantized, lightweight version of mdx_extra.\n"
        ),
    )
    output_drums = gr.Audio(type="filepath", label="Drums")
    output_bass = gr.Audio(type="filepath", label="Bass")
    output_vocals = gr.Audio(type="filepath", label="Vocals")
    output_other = gr.Audio(type="filepath", label="Other")
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums,
            output_bass,
            output_vocals,
            output_other,
        ],
        process_fn=process_fn,
    )
demo.queue().launch(share=True, show_error=True)
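# Hypothetical local smoke test (not part of the HARP endpoint; assumes a
# local file named "mix.wav" exists):
#
#   stems = separate_all_stems("mix.wav", "htdemucs")
#   for name, sig in zip(STEM_NAMES, stems):
#       save_audio(sig, Path(f"demucs_test_{name.lower()}.wav"))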