Upload 3 files
- app.py +90 -0
- helper.py +66 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,90 @@
import torch
import torchaudio
import torchaudio.functional as F
from torchaudio.utils import download_asset

from pesq import pesq
from pystoi import stoi
import mir_eval
import matplotlib.pyplot as plt

import streamlit as st
from helper import plot_spectrogram, plot_mask, si_snr, generate_mixture, evaluate, get_irms

target_snr = 3

# Parameters for the STFT
N_FFT = 1024
N_HOP = 256
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

# PSD and Souden MVDR transforms
psd_transform = torchaudio.transforms.PSD()
mvdr_transform = torchaudio.transforms.SoudenMVDR()

# Reference microphone channel
REFERENCE_CHANNEL = 0

# Multi-channel noise sample used to build the noisy mixture
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
waveform_noise = waveform_noise.to(torch.double)
stft_noise = stft(waveform_noise)


def ui():
    st.title("Speech Enhancer")
    st.markdown("Made by Vageesh")
    # Audio file uploader
    audio_file = st.file_uploader("Upload an audio file in wav format", type=["wav"])

    if audio_file is not None:
        # The uploaded clip is assumed to match the noise sample in channel count and length
        waveform_clean, sr = torchaudio.load(audio_file)
        waveform_clean = waveform_clean.to(torch.double)
        stft_clean = stft(waveform_clean)
        st.text("Your uploaded audio")
        st.audio(waveform_clean.numpy(), sample_rate=sr)
        # Mixing the uploaded audio with the noise sample at the target SNR
        waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)
        waveform_mix = waveform_mix.to(torch.double)
        # Computing the STFT of the mixture
        stft_mix = stft(waveform_mix)
        # Plotting the spectrogram of the mixture (reference channel)
        spec_fig = plot_spectrogram(stft_mix[REFERENCE_CHANNEL], title="Spectrogram of Mixture Speech (dB)")
        st.pyplot(spec_fig)
        # Playing the mixed audio in Streamlit
        st.audio(waveform_mix.numpy(), sample_rate=sr)
        # Getting the ideal ratio masks (IRMs) for speech and noise
        irm_speech, irm_noise = get_irms(stft_clean, stft_noise)
        # Getting the PSD matrices for speech and noise
        psd_speech = psd_transform(stft_mix, irm_speech)
        psd_noise = psd_transform(stft_mix, irm_noise)
        # Souden MVDR beamforming, then back to the time domain
        stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
        waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])
        # Plotting and playing the enhanced audio
        spec_clean_fig = plot_spectrogram(stft_souden, title="Spectrogram of MVDR-Enhanced Speech (dB)")
        waveform_souden = waveform_souden.reshape(1, -1)
        st.pyplot(spec_clean_fig)
        st.audio(waveform_souden.numpy(), sample_rate=sr)


ui()
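
For context, the MVDR stages used in ui() can be exercised on their own with random tensors; the following is a minimal sketch, assuming a 6-channel, 1-second clip at 16 kHz and a constant mask standing in for the IRMs (both assumptions, not values from this commit), to show the tensor shapes each torchaudio transform expects.

# Minimal offline sketch of the mask-based Souden MVDR pipeline on random audio.
# The 6-channel layout, 16 kHz length, and constant mask are assumptions for illustration.
import torch
import torchaudio

N_FFT, N_HOP = 1024, 256
stft = torchaudio.transforms.Spectrogram(n_fft=N_FFT, hop_length=N_HOP, power=None)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)
psd_transform = torchaudio.transforms.PSD()
mvdr_transform = torchaudio.transforms.SoudenMVDR()

channels, samples = 6, 16000                              # 1 s of 6-channel audio (assumed)
mix = torch.rand(channels, samples, dtype=torch.double)
spec_mix = stft(mix)                                      # complex, shape (channel, freq, time)

# Any (freq, time) mask in [0, 1] works here; a constant mask stands in for the IRMs.
mask_speech = torch.full(spec_mix.shape[-2:], 0.7, dtype=torch.double)
mask_noise = 1.0 - mask_speech

psd_speech = psd_transform(spec_mix, mask_speech)         # (freq, channel, channel)
psd_noise = psd_transform(spec_mix, mask_noise)

spec_enhanced = mvdr_transform(spec_mix, psd_speech, psd_noise, reference_channel=0)
enhanced = istft(spec_enhanced, length=samples)           # single-channel waveform
print(enhanced.shape)                                     # torch.Size([16000])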
helper.py
ADDED
@@ -0,0 +1,66 @@
import torch
import matplotlib.pyplot as plt
import mir_eval
from pesq import pesq
from pystoi import stoi

# Assumed constants: 16 kHz sample rate (as in the torchaudio MVDR tutorial)
# and reference channel 0 (as in app.py)
SAMPLE_RATE = 16000
REFERENCE_CHANNEL = 0


def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    # Returns a matplotlib figure so the caller (e.g. st.pyplot) can render it
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    return figure


def plot_mask(mask, title="Mask", xlim=None):
    mask = mask.numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    return figure


def si_snr(estimate, reference, epsilon=1e-8):
    # Scale-invariant signal-to-noise ratio in dB; expects tensors of shape (1, time)
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)

    reference_pow = reference_pow.mean(axis=1)
    error_pow = error_pow.mean(axis=1)

    si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
    return si_snr.item()


def generate_mixture(waveform_clean, waveform_noise, target_snr):
    # Scales the noise so that the clean/noise power ratio equals target_snr (dB)
    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
    waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
    return waveform_clean + waveform_noise


def evaluate(estimate, reference):
    # Prints SDR, Si-SNR, PESQ, and STOI for an enhanced estimate against its reference
    si_snr_score = si_snr(estimate, reference)
    (
        sdr,
        _,
        _,
        _,
    ) = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)
    pesq_mix = pesq(SAMPLE_RATE, estimate[0].numpy(), reference[0].numpy(), "wb")
    stoi_mix = stoi(reference[0].numpy(), estimate[0].numpy(), SAMPLE_RATE, extended=False)
    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_mix}")
    print(f"STOI score: {stoi_mix}")


def get_irms(stft_clean, stft_noise):
    # Ideal ratio masks for speech and noise, taken at the reference channel
    mag_clean = stft_clean.abs() ** 2
    mag_noise = stft_noise.abs() ** 2
    irm_speech = mag_clean / (mag_clean + mag_noise)
    irm_noise = mag_noise / (mag_clean + mag_noise)
    return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]
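
The helpers can be smoke-tested on their own with synthetic audio; a minimal sketch follows, assuming 2-channel, 1-second clips of random samples and a 3 dB target SNR (illustrative values only).

# Minimal sketch of how the helpers compose, using synthetic tensors; the shapes
# (2 channels, 16000 samples) and the 3 dB target SNR are assumptions for illustration.
import torch
import torchaudio
from helper import generate_mixture, get_irms, si_snr

stft = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, power=None)

clean = torch.rand(2, 16000, dtype=torch.double) - 0.5
noise = torch.rand(2, 16000, dtype=torch.double) - 0.5

mix = generate_mixture(clean.clone(), noise.clone(), target_snr=3)   # noise scaled to 3 dB SNR
irm_speech, irm_noise = get_irms(stft(clean), stft(noise))           # masks at channel 0
print(irm_speech.shape)                                              # (freq, time)
print(si_snr(mix[:1], clean[:1]))                                    # Si-SNR in dB for channel 0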
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
torchaudio
pesq
pystoi
mir_eval
streamlit
matplotlib