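"""PyHARP Gradio app: separate a music mixture into stems (drums, bass,
vocals, and a combined instrumental) using pretrained Demucs models."""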
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator

import gradio as gr
import torch
import torchaudio
from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio

# ModelCard
model_card = ModelCard(
name="Demucs Stem Separator (All Stems)",
description="Separates a music mixture into all individual stems using a Demucs model.",
author="Alexandre Défossez, et al.",
tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
)
DEMUCS_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
# Order must match the list returned by separate_all_stems()
STEM_NAMES = ["Drums", "Bass", "Vocals", "Instrumental (No Vocals)"]
# Global model cache
LOADED_MODELS = {}
def get_cached_model(model_name: str):
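    """Return a cached Demucs model on the best available device, loading it on first use."""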
if model_name not in LOADED_MODELS:
model = pretrained.get_model(model_name)
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
LOADED_MODELS[model_name] = model
return LOADED_MODELS[model_name]
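
# Note: the first call for a given model downloads its checkpoint and is slow;
# later calls reuse the cached, device-resident model.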
# Separation Logic (all stems)
def separate_all_stems(audio_file_path: str, model_name: str):
    model = get_cached_model(model_name)
    waveform, sr = torchaudio.load(audio_file_path)
    # Demucs runs at a fixed rate (44.1 kHz for these models); resample if needed
    if sr != model.samplerate:
        waveform = torchaudio.functional.resample(waveform, sr, model.samplerate)
        sr = model.samplerate
    # Demucs expects stereo input; duplicate a mono channel and fold it back later
    is_mono = waveform.shape[0] == 1
    if is_mono:
        waveform = waveform.repeat(2, 1)
    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.unsqueeze(0).to(next(model.parameters()).device),
            overlap=0.2,  # overlap between chunks when splitting
            shifts=1,     # random time shifts averaged for a small quality boost
            split=True,   # process the track in chunks to bound memory use
        )
    # Drop the batch dimension: stems has shape (num_sources, channels, samples)
    stems = stems_batch[0]
    output_signals = []
    for stem in stems:
        if is_mono:
            stem = stem.mean(dim=0, keepdim=True)  # fold stereo back down to mono
        signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        output_signals.append(signal)
    # Demucs source order for these models is (drums, bass, other, vocals);
    # summing everything except vocals yields the instrumental
    drums, bass, other, vocals = output_signals
    instrumental = drums + bass + other
    return [drums, bass, vocals, instrumental]
# Process Function
@spaces.GPU  # allocate a GPU for this call on ZeroGPU Spaces (no effect elsewhere)
def process_fn(audio_file_path, model_name):
    output_signals = separate_all_stems(audio_file_path, model_name)
    outputs = []
    for stem_name, signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.wav"
        output_audio_path = save_audio(signal, filename)
        outputs.append(output_audio_path)
    return tuple(outputs)
# Gradio App
with gr.Blocks() as demo:
input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
    model_dropdown = gr.Dropdown(
        choices=DEMUCS_MODELS,
        label="Demucs Model",
        value="htdemucs"
    )
# Outputs: Multiple stems
output_drums = gr.Audio(type="filepath", label="Drums")
output_bass = gr.Audio(type="filepath", label="Bass")
output_vocals = gr.Audio(type="filepath", label="Vocals")
output_instrumental = gr.Audio(type="filepath", label="Instrumental (No Vocals)")
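    # These appear in the same order as the tuple returned by process_fn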
#output_labels = gr.JSON(label="Labels")
app = build_endpoint(
model_card=model_card,
input_components=[input_audio, model_dropdown],
output_components=[
output_drums,
output_bass,
output_vocals,
output_instrumental
],
process_fn=process_fn
)
demo.queue().launch(share=True, show_error=True)
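# Local smoke test, bypassing the Gradio UI ("mix.wav" is a placeholder path):
#   drums, bass, vocals, instrumental = separate_all_stems("mix.wav", "htdemucs")
#   instrumental.write("instrumental.wav")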