import spaces
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import save_audio
from pyharp.labels import LabelList
from audiotools import AudioSignal
from demucs import pretrained
from demucs.apply import apply_model
import gradio as gr
import torch
from pathlib import Path
import numpy as np
# ModelCard
model_card = ModelCard(
    name="Demucs Stem Separator (All Stems)",
    description=(
        "Separates a full music track into four individual audio stems (Drums, Bass, Vocals, "
        "Other/instrumental) using the Demucs model. Available models:\n"
        "htdemucs: best overall fidelity and the most natural sound.\n"
        "mdx_extra: balanced trade-off; faster processing with a small drop in quality.\n"
        "mdx_extra_q: fastest and lightest, with noticeable quality loss; ideal for quick previews.\n"
    ),
    author="Alexandre Défossez, et al.",
    tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
)
DEMUX_MODELS = ["htdemucs", "mdx_extra", "mdx_extra_q"]
STEM_NAMES = ["Drums", "Bass", "Vocals", "Other"]
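# Order of the stems as presented to the user. Demucs itself emits
# drums, bass, other, vocals; separate_all_stems() reorders to match this list.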
# Global model cache
LOADED_MODELS = {}
def get_cached_model(model_name: str):
    """Load a pretrained Demucs model once, move it to the best device, and reuse it."""
    if model_name not in LOADED_MODELS:
        model = pretrained.get_model(model_name)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        LOADED_MODELS[model_name] = model
    return LOADED_MODELS[model_name]
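# Example: get_cached_model("htdemucs") loads (and downloads, if needed) the weights
# on the first call, then returns the cached, device-ready instance on later calls.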
# Separation Logic
def separate_all_stems(audio_file_path: str, model_name: str):
    model = get_cached_model(model_name)
    signal = AudioSignal(audio_file_path)
    signal = signal.resample(44100)  # Demucs models expect 44.1 kHz audio
    # audiotools stores audio_data as [batch, channels, samples]
    is_mono = signal.num_channels == 1
    if is_mono:
        # Demucs expects stereo input, so duplicate the mono channel
        signal.audio_data = signal.audio_data.repeat(1, 2, 1)
    sr = signal.sample_rate
    # Ensure audio_data is a float torch.Tensor
    audio = signal.audio_data
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    audio = audio.float()
    # Collapse any extra dimensions down to [channels, samples]
    if audio.ndim > 2:
        audio = audio.squeeze()
    # Demucs expects a batch dimension: [1, channels, samples]
    waveform = audio.unsqueeze(0)
    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.to(next(model.parameters()).device),
            overlap=0.2,  # overlap between chunks when splitting
            shifts=1,     # average over random shifts for slightly cleaner separation
            split=True,   # process the track in chunks to limit memory use
        )
    stems = stems_batch[0]  # drop the batch dimension
    output_signals = []
    for stem in stems:
        if is_mono:
            # Fold the duplicated channels back down to a single mono channel
            stem = stem.mean(dim=0, keepdim=True)
        stem_signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
        output_signals.append(stem_signal)
    # Demucs returns sources in the order drums, bass, other, vocals;
    # reorder them to match STEM_NAMES (Drums, Bass, Vocals, Other)
    drums, bass, other, vocals = output_signals
    return [drums, bass, vocals, other]
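# Illustrative usage (not executed by the Space; assumes a local file "song.wav"):
#   drums, bass, vocals, other = separate_all_stems("song.wav", "htdemucs")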
# Process Function
def process_fn(audio_file_path, model_name):
    if model_name not in DEMUX_MODELS:
        raise ValueError(f"Unsupported model selected: {model_name}")
    output_signals = separate_all_stems(audio_file_path, model_name)
    # Keep the input container format: mp3 stays mp3, everything else is written as wav
    is_mp3 = Path(audio_file_path).suffix.lower() == ".mp3"
    extension = "mp3" if is_mp3 else "wav"
    outputs = []
    for stem_name, signal in zip(STEM_NAMES, output_signals):
        filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
        output_path = Path(filename)
        save_audio(signal, output_path)
        outputs.append(str(output_path))
    return tuple(outputs)
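# The returned file paths follow STEM_NAMES order and map one-to-one onto the
# output components declared in the Gradio app below.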
# Gradio App
with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
    model_dropdown = gr.Dropdown(
        choices=DEMUX_MODELS,
        label="Demucs Model",
        value="htdemucs",
        info=(
            "Choose which Demucs variant to use for separation:\n"
            "1. htdemucs: High-quality Hybrid Transformer Demucs trained on MusDB + 800 songs.\n"
            "2. mdx_extra: Trained with extended data. Balanced model offering a clear trade-off between quality and speed.\n"
            "3. mdx_extra_q: Quantized lightweight version of mdx_extra.\n"
        )
    )
    output_drums = gr.Audio(type="filepath", label="Drums")
    output_bass = gr.Audio(type="filepath", label="Bass")
    output_vocals = gr.Audio(type="filepath", label="Vocals")
    output_instrumental = gr.Audio(type="filepath", label="Other")

    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio, model_dropdown],
        output_components=[
            output_drums,
            output_bass,
            output_vocals,
            output_instrumental
        ],
        process_fn=process_fn
    )
demo.queue().launch(share=True, show_error=True)