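"""PyHARP + Gradio app that separates a music track into stems with Demucs.

The endpoint takes one audio file and a Demucs model choice, and returns four
stems: Drums, Bass, Vocals, and Other.
"""
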
import spaces

from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio
from pyharp.labels import LabelList

from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model

import gradio as gr
import torch
from pathlib import Path
import numpy as np


model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    description=(
        "Separates a full music track into four individual audio stems (Drums, Bass, "
        "Vocals, Other/instrumental) using the Demucs model. Available models:\n"
        "htdemucs: best overall fidelity and the most natural sound; "
        "mdx_extra: balanced trade-off, faster processing with a small drop in quality; "
        "mdx_extra_q: fastest and lightest, with noticeable quality loss, ideal for quick previews.\n"
    ),
    author="Alexandre Défossez, et al.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"],
)

DEMUX_MODELS = ["htdemucs", "mdx_extra", "mdx_extra_q"]
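# Stem order returned to the HARP client (Demucs itself outputs drums, bass, other, vocals).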
STEM_NAMES = ["Drums", "Bass", "Vocals", "Other"]

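# Cache of loaded Demucs models, keyed by model name.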
LOADED_MODELS = {}


def get_cached_model(model_name: str):
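    """Load each Demucs model once and cache it for reuse across requests."""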
    if model_name not in LOADED_MODELS:
        model = pretrained.get_model(model_name)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        LOADED_MODELS[model_name] = model
    return LOADED_MODELS[model_name]


def separate_all_stems(audio_file_path: str, model_name: str):
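    """Separate the input file with Demucs and return one AudioSignal per stem.

    The returned list follows STEM_NAMES: Drums, Bass, Vocals, Other.
    """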
    model = get_cached_model(model_name)

    # Demucs models are trained on 44.1 kHz stereo audio.
    signal = AudioSignal(audio_file_path)
    signal = signal.resample(44100)

    is_mono = signal.num_channels == 1
    if is_mono:
        # audio_data has shape (batch, channels, samples); duplicate the channel
        # axis so the stereo Demucs models receive two channels.
        signal.audio_data = signal.audio_data.repeat(1, 2, 1)

    sr = signal.sample_rate

    audio = signal.audio_data
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    audio = audio.float()

    # Drop the singleton batch dimension so audio is (channels, samples).
    if audio.ndim > 2:
        audio = audio.squeeze()

    # apply_model expects a batched input of shape (batch, channels, samples).
    waveform = audio.unsqueeze(0)

    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.to(next(model.parameters()).device),
            overlap=0.2,
            shifts=1,
            split=True,
        )

    stems = stems_batch[0]

    output_signals = []
    for stem in stems:
        if is_mono:
            # Fold the duplicated channels back down to a single mono channel.
            stem = stem.mean(dim=0, keepdim=True)
        signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        output_signals.append(signal)

    # Demucs outputs stems in the order: drums, bass, other, vocals.
    drums, bass, other, vocals = output_signals

    return [drums, bass, vocals, other]


def process_fn(audio_file_path, model_name):
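    """PyHARP process function: write each separated stem to disk and return the file paths."""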
    if model_name not in DEMUX_MODELS:
        raise ValueError(f"Unsupported model selected: {model_name}")

    output_signals = separate_all_stems(audio_file_path, model_name)

    # Mirror the input format: mp3 in gives mp3 out, anything else is saved as wav.
    is_mp3 = Path(audio_file_path).suffix.lower() == ".mp3"
    extension = "mp3" if is_mp3 else "wav"

    outputs = []
    for stem_name, signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
        output_path = Path(filename)

        save_audio(signal, output_path)
        outputs.append(str(output_path))

    return tuple(outputs)


with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=DEMUX_MODELS,
        label="Demucs Model",
        value="htdemucs",
        info=(
            "Choose which Demucs variant to use for separation:\n"
            "1. htdemucs: high-quality Hybrid Transformer Demucs trained on MusDB + 800 extra songs.\n"
            "2. mdx_extra: trained with extended data; a balanced trade-off between quality and speed.\n"
            "3. mdx_extra_q: quantized, lightweight version of mdx_extra.\n"
        ),
    )

    output_drums = gr.Audio(type="filepath", label="Drums")
    output_bass = gr.Audio(type="filepath", label="Bass")
    output_vocals = gr.Audio(type="filepath", label="Vocals")
    output_other = gr.Audio(type="filepath", label="Other")

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums,
            output_bass,
            output_vocals,
            output_other,
        ],
        process_fn=process_fn,
    )

demo.queue().launch(share=True, show_error=True)