Spaces:
Running
Running
| import os | |
| from typing import Union | |
| import torch | |
| import torchaudio | |
| from modules.Denoiser.AudioDenoiser import AudioDenoiser | |
| from modules.utils.constants import MODELS_DIR | |
| from modules.devices import devices | |
| import soundfile as sf | |
# Module-level cache for the denoiser. It starts empty and is filled in by
# TTSAudioDenoiser.load_ad() the first time a denoiser is requested, so all
# TTSAudioDenoiser instances share one AudioDenoiser object.
ad: Union[AudioDenoiser, None] = None


class TTSAudioDenoiser:
    """Wrapper that lazily creates and reuses one module-level AudioDenoiser."""

    def load_ad(self):
        """Return the shared AudioDenoiser, constructing it on first use.

        The model directory is ``MODELS_DIR/Denoise/audio-denoiser-512-32-v1``
        and the denoiser is created for ``devices.device``.
        """
        global ad
        if ad is None:
            model_dir = os.path.join(
                MODELS_DIR,
                "Denoise",
                "audio-denoiser-512-32-v1",
            )
            ad = AudioDenoiser(model_dir, device=devices.device)
            # NOTE(review): the original indentation was not recoverable;
            # this call is assumed to belong to the first-load branch rather
            # than running on every load_ad() call -- confirm.
            ad.model.to(devices.device)
        return ad

    def denoise(self, audio_data, sample_rate, auto_scale=False):
        """Denoise ``audio_data`` with the shared denoiser.

        Args:
            audio_data: Waveform forwarded unchanged to
                ``AudioDenoiser.process_waveform``.
            sample_rate: Sample rate of ``audio_data``, forwarded unchanged.
            auto_scale: Forwarded as the third positional argument of
                ``process_waveform``.

        Returns:
            A ``(sample_rate, waveform)`` tuple, where ``sample_rate`` is the
            denoiser's ``model_sample_rate`` and ``waveform`` is whatever
            ``process_waveform`` returned.
        """
        denoiser = self.load_ad()
        # Read the rate before processing, matching the original statement order.
        output_rate = denoiser.model_sample_rate
        # NOTE(review): auto_scale is passed positionally; verify that it is
        # really the third positional parameter of process_waveform.
        cleaned = denoiser.process_waveform(audio_data, sample_rate, auto_scale)
        return output_rate, cleaned
if __name__ == "__main__":
    # Manual smoke test: read ./test.wav, denoise it, write ./denoised.wav.
    samples, input_rate = sf.read("test.wav")

    # Prepend a dimension of size 1 and cast to float32 before denoising.
    # NOTE(review): presumably test.wav is mono, making this (1, frames) -- confirm.
    audio_tensor = torch.from_numpy(samples).unsqueeze(0).float()
    print(audio_tensor)

    tts_deno = TTSAudioDenoiser()
    sr, denoised = tts_deno.denoise(audio_data=audio_tensor, sample_rate=input_rate)

    # Move the result to the CPU before handing it to torchaudio for saving.
    torchaudio.save("denoised.wav", denoised.cpu(), sample_rate=sr)