lllindsey0615 committed on
Commit
3d43232
·
1 Parent(s): c8b06c0

handling mp3 format

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -38,19 +38,17 @@ def get_cached_model(model_name: str):
38
 
39
  # Separation Logic (all stems)
40
  def separate_all_stems(audio_file_path: str, model_name: str):
41
- model = get_cached_model(model_name)
42
- #waveform, sr = torchaudio.load(audio_file_path)
43
- #is_mono = waveform.shape[0] == 1
44
- #if is_mono:
45
- #waveform = waveform.repeat(2, 1)
46
-
47
  signal = AudioSignal(audio_file_path)
48
  signal = signal.resample(44100) # expects 44.1kHz
49
- if signal.num_channels == 1:
 
 
50
  signal = signal.convert_to(stereo=True)
51
 
52
- waveform = torch.from_numpy(signal.audio_data).float() # [channels, samples]
53
- waveform = waveform.unsqueeze(0) # [1, channels, samples]
 
 
54
 
55
  with torch.no_grad():
56
  stems_batch = apply_model(
@@ -80,11 +78,17 @@ def separate_all_stems(audio_file_path: str, model_name: str):
80
  def process_fn(audio_file_path, model_name):
81
  output_signals = separate_all_stems(audio_file_path, model_name)
82
 
 
 
 
83
  outputs = []
84
  for stem_name, signal in zip(STEM_NAMES, output_signals):
85
- filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.wav"
86
- output_audio_path = save_audio(signal, filename)
87
- outputs.append(output_audio_path)
 
 
 
88
 
89
  return tuple(outputs)
90
 
 
38
 
39
  # Separation Logic (all stems)
40
  def separate_all_stems(audio_file_path: str, model_name: str):
 
 
 
 
 
 
41
  signal = AudioSignal(audio_file_path)
42
  signal = signal.resample(44100) # expects 44.1kHz
43
+
44
+ is_mono = signal.num_channels == 1
45
+ if is_mono:
46
  signal = signal.convert_to(stereo=True)
47
 
48
+ sr = signal.sample_rate
49
+ waveform = signal.audio_data.float() # [channels, samples]
50
+ waveform = waveform.unsqueeze(0) # [1, channels, samples]
51
+
52
 
53
  with torch.no_grad():
54
  stems_batch = apply_model(
 
78
  def process_fn(audio_file_path, model_name):
79
  output_signals = separate_all_stems(audio_file_path, model_name)
80
 
81
+ is_mp3 = Path(audio_file_path).suffix.lower() == ".mp3"
82
+ extension = "mp3" if is_mp3 else "wav"
83
+
84
  outputs = []
85
  for stem_name, signal in zip(STEM_NAMES, output_signals):
86
+ filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.{extension}"
87
+ output_path = Path(filename)
88
+
89
+ # Use .export() to control output format
90
+ signal.export(output_path, format=extension)
91
+ outputs.append(str(output_path))
92
 
93
  return tuple(outputs)
94