lllindsey0615 committed on
Commit
f489d22
·
1 Parent(s): 3d43232

handling mp3 input

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import spaces
2
  from pyharp.core import ModelCard, build_endpoint
3
- from pyharp.media.audio import load_audio, save_audio
4
  from pyharp.labels import LabelList
5
 
6
  from audiotools import AudioSignal
@@ -8,7 +8,6 @@ from demucs import pretrained
8
  from demucs.apply import apply_model
9
 
10
  import gradio as gr
11
- import torchaudio
12
  import torch
13
  from pathlib import Path
14
 
@@ -20,7 +19,6 @@ model_card = ModelCard(
20
  tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
21
  )
22
 
23
-
24
  DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
25
  STEM_NAMES = ["Drums", "Bass", "Vocals", "Instrumental (No Vocals)"]
26
 
@@ -36,8 +34,10 @@ def get_cached_model(model_name: str):
36
  LOADED_MODELS[model_name] = model
37
  return LOADED_MODELS[model_name]
38
 
39
- # Separation Logic (all stems)
40
  def separate_all_stems(audio_file_path: str, model_name: str):
 
 
41
  signal = AudioSignal(audio_file_path)
42
  signal = signal.resample(44100) # expects 44.1kHz
43
 
@@ -46,19 +46,19 @@ def separate_all_stems(audio_file_path: str, model_name: str):
46
  signal = signal.convert_to(stereo=True)
47
 
48
  sr = signal.sample_rate
 
49
  waveform = signal.audio_data.float() # [channels, samples]
50
  waveform = waveform.unsqueeze(0) # [1, channels, samples]
51
 
52
-
53
  with torch.no_grad():
54
  stems_batch = apply_model(
55
  model,
56
- #waveform.unsqueeze(0).to(next(model.parameters()).device),
57
  waveform.to(next(model.parameters()).device),
58
  overlap=0.2,
59
  shifts=1,
60
  split=True,
61
  )
 
62
  stems = stems_batch[0]
63
 
64
  output_signals = []
@@ -86,13 +86,12 @@ def process_fn(audio_file_path, model_name):
86
  filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
87
  output_path = Path(filename)
88
 
89
- # Use .export() to control output format
90
  signal.export(output_path, format=extension)
91
  outputs.append(str(output_path))
92
 
93
  return tuple(outputs)
94
 
95
- # Gradio App
96
  with gr.Blocks() as demo:
97
  input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
98
 
@@ -102,12 +101,10 @@ with gr.Blocks() as demo:
102
  value="htdemucs"
103
  )
104
 
105
- # Outputs: Multiple stems
106
  output_drums = gr.Audio(type="filepath", label="Drums")
107
  output_bass = gr.Audio(type="filepath", label="Bass")
108
  output_vocals = gr.Audio(type="filepath", label="Vocals")
109
  output_instrumental = gr.Audio(type="filepath", label="Instrumental (No Vocals)")
110
- #output_labels = gr.JSON(label="Labels")
111
 
112
  app = build_endpoint(
113
  model_card=model_card,
 
1
  import spaces
2
  from pyharp.core import ModelCard, build_endpoint
3
+ from pyharp.media.audio import save_audio
4
  from pyharp.labels import LabelList
5
 
6
  from audiotools import AudioSignal
 
8
  from demucs.apply import apply_model
9
 
10
  import gradio as gr
 
11
  import torch
12
  from pathlib import Path
13
 
 
19
  tags=["demucs", "source-separation", "pyharp", "stems", "multi-output"]
20
  )
21
 
 
22
  DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
23
  STEM_NAMES = ["Drums", "Bass", "Vocals", "Instrumental (No Vocals)"]
24
 
 
34
  LOADED_MODELS[model_name] = model
35
  return LOADED_MODELS[model_name]
36
 
37
+ # Separation Logic
38
  def separate_all_stems(audio_file_path: str, model_name: str):
39
+ model = get_cached_model(model_name)
40
+
41
  signal = AudioSignal(audio_file_path)
42
  signal = signal.resample(44100) # expects 44.1kHz
43
 
 
46
  signal = signal.convert_to(stereo=True)
47
 
48
  sr = signal.sample_rate
49
+
50
  waveform = signal.audio_data.float() # [channels, samples]
51
  waveform = waveform.unsqueeze(0) # [1, channels, samples]
52
 
 
53
  with torch.no_grad():
54
  stems_batch = apply_model(
55
  model,
 
56
  waveform.to(next(model.parameters()).device),
57
  overlap=0.2,
58
  shifts=1,
59
  split=True,
60
  )
61
+
62
  stems = stems_batch[0]
63
 
64
  output_signals = []
 
86
  filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
87
  output_path = Path(filename)
88
 
 
89
  signal.export(output_path, format=extension)
90
  outputs.append(str(output_path))
91
 
92
  return tuple(outputs)
93
 
94
+ # Gradio App
95
  with gr.Blocks() as demo:
96
  input_audio = gr.Audio(type="filepath", label="Input Audio").harp_required(True)
97
 
 
101
  value="htdemucs"
102
  )
103
 
 
104
  output_drums = gr.Audio(type="filepath", label="Drums")
105
  output_bass = gr.Audio(type="filepath", label="Bass")
106
  output_vocals = gr.Audio(type="filepath", label="Vocals")
107
  output_instrumental = gr.Audio(type="filepath", label="Instrumental (No Vocals)")
 
108
 
109
  app = build_endpoint(
110
  model_card=model_card,