Commit
·
3d43232
1
Parent(s):
c8b06c0
handling mp3 format
Browse files
app.py
CHANGED
|
@@ -38,19 +38,17 @@ def get_cached_model(model_name: str):
|
|
| 38 |
|
| 39 |
# Separation Logic (all stems)
|
| 40 |
def separate_all_stems(audio_file_path: str, model_name: str):
|
| 41 |
-
model = get_cached_model(model_name)
|
| 42 |
-
#waveform, sr = torchaudio.load(audio_file_path)
|
| 43 |
-
#is_mono = waveform.shape[0] == 1
|
| 44 |
-
#if is_mono:
|
| 45 |
-
#waveform = waveform.repeat(2, 1)
|
| 46 |
-
|
| 47 |
signal = AudioSignal(audio_file_path)
|
| 48 |
signal = signal.resample(44100) # expects 44.1kHz
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
signal = signal.convert_to(stereo=True)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
waveform =
|
|
|
|
|
|
|
| 54 |
|
| 55 |
with torch.no_grad():
|
| 56 |
stems_batch = apply_model(
|
|
@@ -80,11 +78,17 @@ def separate_all_stems(audio_file_path: str, model_name: str):
|
|
| 80 |
def process_fn(audio_file_path, model_name):
|
| 81 |
output_signals = separate_all_stems(audio_file_path, model_name)
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
outputs = []
|
| 84 |
for stem_name, signal in zip(STEM_NAMES, output_signals):
|
| 85 |
-
filename = f"demucs_{model_name}_{stem_name.lower().replace(' ', '_')}.
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
return tuple(outputs)
|
| 90 |
|
|
|
|
| 38 |
|
| 39 |
# Separation Logic (all stems)
|
| 40 |
def separate_all_stems(audio_file_path: str, model_name: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
signal = AudioSignal(audio_file_path)
|
| 42 |
signal = signal.resample(44100) # expects 44.1kHz
|
| 43 |
+
|
| 44 |
+
is_mono = signal.num_channels == 1
|
| 45 |
+
if is_mono:
|
| 46 |
signal = signal.convert_to(stereo=True)
|
| 47 |
|
| 48 |
+
sr = signal.sample_rate
|
| 49 |
+
waveform = signal.audio_data.float() # [channels, samples]
|
| 50 |
+
waveform = waveform.unsqueeze(0) # [1, channels, samples]
|
| 51 |
+
|
| 52 |
|
| 53 |
with torch.no_grad():
|
| 54 |
stems_batch = apply_model(
|
|
|
|
| 78 |
def process_fn(audio_file_path, model_name):
    """Separate an audio file into stems and write one file per stem.

    The stems come from ``separate_all_stems(audio_file_path, model_name)``
    and are paired, in order, with the names in ``STEM_NAMES``. Each stem is
    written with its ``export`` method to a relative path named
    ``demucs_<model_name>_<stem name, lowercased, spaces -> underscores>.<ext>``.

    Args:
        audio_file_path: Path of the input audio file. Only its suffix is
            inspected here; the path itself is forwarded to
            ``separate_all_stems``.
        model_name: Model identifier, forwarded to ``separate_all_stems``
            and embedded in every output filename.

    Returns:
        A tuple of the written paths as strings, in ``STEM_NAMES`` order.
    """
    stems = separate_all_stems(audio_file_path, model_name)

    # Mirror the input container: ".mp3" (any letter case) stays MP3,
    # every other suffix falls back to WAV.
    if Path(audio_file_path).suffix.lower() == ".mp3":
        extension = "mp3"
    else:
        extension = "wav"

    written = []
    for label, stem in zip(STEM_NAMES, stems):
        slug = label.lower().replace(" ", "_")
        target = Path(f"demucs_{model_name}_{slug}.{extension}")
        # Pass the format explicitly so the encoder matches the extension.
        stem.export(target, format=extension)
        written.append(str(target))

    return tuple(written)
|
| 94 |
|