import spaces
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio, save_audio
from pyharp.labels import LabelList
from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model
import gradio as gr
import torchaudio
import torch
from pathlib import Path


# ModelCard
model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    description="Separates a music mixture into all individual stems using a Demucs model.",
    author="Alexandre Défossez, et al.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
)

DEMUCS_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
STEM_NAMES = ["Drums", "Bass", "Other", "Vocals"]

# Global model cache so each Demucs model is loaded at most once per process
LOADED_MODELS = {}


def get_cached_model(model_name: str):
    if model_name not in LOADED_MODELS:
        model = pretrained.get_model(model_name)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        LOADED_MODELS[model_name] = model
    return LOADED_MODELS[model_name]


# Separation Logic (all stems)
def separate_all_stems(audio_file_path: str, model_name: str):
    model = get_cached_model(model_name)

    waveform, sr = torchaudio.load(audio_file_path)
    is_mono = waveform.shape[0] == 1
    if is_mono:
        # Demucs expects stereo input; duplicate the mono channel.
        waveform = waveform.repeat(2, 1)
    if sr != model.samplerate:
        # Pretrained Demucs models operate at a fixed sample rate (44.1 kHz);
        # resample the input to match before separating.
        waveform = torchaudio.functional.resample(waveform, sr, model.samplerate)
        sr = model.samplerate

    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.unsqueeze(0).to(next(model.parameters()).device),
            overlap=0.2,
            shifts=1,
            split=True,
        )
    stems = stems_batch[0]  # (num_sources, channels, samples)

    output_signals = []
    for stem in stems:
        if is_mono:
            # Collapse back to mono if the input was mono.
            stem = stem.mean(dim=0, keepdim=True)
        signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        output_signals.append(signal)
    return output_signals  # [drums, bass, other, vocals]


# Process Function
@spaces.GPU
def process_fn(audio_file_path, model_name):
    output_signals = separate_all_stems(audio_file_path, model_name)
    outputs = []
    for stem_name, signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower()}.wav"
        output_audio_path = save_audio(signal, filename)
        outputs.append(output_audio_path)
    # Return one value per output component: four stem paths plus the label list.
    return (*outputs, LabelList())


# Gradio App
with gr.Blocks() as demo:
    # Inputs
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
    model_dropdown = gr.Dropdown(
        choices=DEMUCS_MODELS,
        label="Demucs Model",
        value="htdemucs"
    )

    # Outputs: one audio component per stem, plus labels
    output_drums = gr.Audio(type="filepath", label="Drums")
    output_bass = gr.Audio(type="filepath", label="Bass")
    output_other = gr.Audio(type="filepath", label="Other")
    output_vocals = gr.Audio(type="filepath", label="Vocals")
    output_labels = gr.JSON(label="Labels")

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums, output_bass, output_other, output_vocals, output_labels
        ],
        process_fn=process_fn
    )

demo.queue().launch(share=True, show_error=True)