import spaces
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio
from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model
import gradio as gr
import torch
from pathlib import Path
import numpy as np

# ModelCard
model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    description=(
        "Separates a full music track into four individual audio stems "
        "(Drums, Bass, Vocals, and Other) using the Demucs model. Available models:\n"
        "htdemucs: best overall fidelity and natural sound.\n"
        "mdx_extra: balanced trade-off, faster processing with a small drop in quality.\n"
        "mdx_extra_q: fastest and lightest, noticeable quality loss; ideal for quick previews.\n"
    ),
    author="Alexandre Défossez, et al.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"],
)

DEMUCS_MODELS = ["htdemucs", "mdx_extra", "mdx_extra_q"]
STEM_NAMES = ["Drums", "Bass", "Vocals", "Other"]

# Global model cache: load each Demucs variant once and reuse it across requests
LOADED_MODELS = {}


def get_cached_model(model_name: str):
    if model_name not in LOADED_MODELS:
        model = pretrained.get_model(model_name)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        LOADED_MODELS[model_name] = model
    return LOADED_MODELS[model_name]


# Separation Logic
def separate_all_stems(audio_file_path: str, model_name: str):
    model = get_cached_model(model_name)

    signal = AudioSignal(audio_file_path)
    signal = signal.resample(44100)  # Demucs expects 44.1 kHz audio
    is_mono = signal.num_channels == 1
    sr = signal.sample_rate

    # Ensure audio_data is a float torch.Tensor
    audio = signal.audio_data
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    audio = audio.float()

    # AudioSignal tensors are [batch, channels, samples]; drop the batch dimension
    if audio.ndim > 2:
        audio = audio.squeeze(0)

    # Demucs expects stereo input, so duplicate the channel of mono files
    if is_mono:
        audio = audio.repeat(2, 1)

    # Final shape: [1, channels, samples]
    waveform = audio.unsqueeze(0)

    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.to(next(model.parameters()).device),
            overlap=0.2,
            shifts=1,
            split=True,
        )

    stems = stems_batch[0]  # [sources, channels, samples]

    output_signals = []
    for stem in stems:
        if is_mono:
            stem = stem.mean(dim=0, keepdim=True)  # downmix back to mono
        stem_signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        output_signals.append(stem_signal)

    # Demucs returns sources in the order drums, bass, other, vocals;
    # reorder them to match STEM_NAMES (Drums, Bass, Vocals, Other)
    drums, bass, other, vocals = output_signals
    return [drums, bass, vocals, other]


# Process Function
@spaces.GPU  # request a GPU for this call when hosted on Hugging Face ZeroGPU Spaces
def process_fn(audio_file_path, model_name):
    if model_name not in DEMUCS_MODELS:
        raise ValueError(f"Unsupported model selected: {model_name}")

    output_signals = separate_all_stems(audio_file_path, model_name)

    # Match the input format: MP3 in, MP3 stems out; otherwise write WAV
    is_mp3 = Path(audio_file_path).suffix.lower() == ".mp3"
    extension = "mp3" if is_mp3 else "wav"

    outputs = []
    for stem_name, stem_signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
        output_path = Path(filename)
        save_audio(stem_signal, output_path)
        outputs.append(str(output_path))

    return tuple(outputs)


# Gradio App
with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=DEMUCS_MODELS,
        label="Demucs Model",
        value="htdemucs",
        info=(
            "Choose which Demucs variant to use for separation:\n"
            "1. htdemucs: high-quality Hybrid Transformer Demucs trained on MusDB + 800 songs.\n"
            "2. mdx_extra: trained with extended data. "
            "Balanced model offering a clear trade-off between quality and speed.\n"
            "3. mdx_extra_q: quantized, lightweight version of mdx_extra.\n"
        ),
    )

    output_drums = gr.Audio(type="filepath", label="Drums")
    output_bass = gr.Audio(type="filepath", label="Bass")
    output_vocals = gr.Audio(type="filepath", label="Vocals")
    output_other = gr.Audio(type="filepath", label="Other")

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums,
            output_bass,
            output_vocals,
            output_other,
        ],
        process_fn=process_fn,
    )

demo.queue().launch(share=True, show_error=True)
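# Quick local smoke test (a minimal sketch: "mix.wav" is a hypothetical input file,
# and this bypasses the Gradio UI by calling process_fn directly):
#
#   stems = process_fn("mix.wav", "htdemucs")
#   for name, path in zip(STEM_NAMES, stems):
#       print(f"{name}: {path}")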