Spaces:
Sleeping
Sleeping
| import logging | |
| import time | |
| from functools import partial | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn.utils.parametrize import remove_parametrizations | |
| from torchaudio.functional import resample | |
| from torchaudio.transforms import MelSpectrogram | |
| from modules import config | |
| from modules.devices import devices | |
| from .hparams import HParams | |
| logger = logging.getLogger(__name__) | |
def inference_chunk(
    model,
    dwav: torch.Tensor,
    sr: int,
    device: torch.device,
    dtype: torch.dtype,
    npad=441,
) -> torch.Tensor:
    """Run ``model`` on a single waveform chunk and return the result on the CPU.

    The chunk is peak-normalized, right-padded with ``npad`` zeros, passed
    through ``model`` as a batch of one, trimmed back to the input length and
    rescaled by the original peak.

    Args:
        model: Callable exposing ``hp.wav_rate``. It is called with a
            ``(1, T + npad)`` tensor and element ``[0]`` of its result is used
            as the output waveform.
        dwav: 1-D input waveform ``(T,)``.
        sr: Sample rate of ``dwav``; must equal ``model.hp.wav_rate``.
        device: Device the forward pass runs on.
        dtype: Dtype the waveform is cast to before the forward pass.
        npad: Number of zero samples appended before the forward pass.

    Returns:
        CPU tensor with at most ``T`` samples, detached from autograd.

    Raises:
        AssertionError: If ``sr`` differs from the model rate or ``dwav`` is
            not 1-D.
    """
    assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
    del sr
    # Check the rank before anything relies on the last axis being time.
    assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"

    length = dwav.shape[-1]
    # Peak level, clamped so that silent input does not divide by zero.
    abs_max = dwav.abs().max().clamp(min=1e-7)

    dwav = dwav.to(device=device, dtype=dtype)
    dwav = dwav / abs_max.to(device=device)  # Normalize
    dwav = F.pad(dwav, (0, npad))

    # Inference only: without no_grad every returned chunk would keep the
    # model's autograd graph alive.
    with torch.no_grad():
        hwav: torch.Tensor = model(dwav[None])[0].cpu()  # (T,)

    hwav = hwav[:length]  # Trim padding
    hwav = hwav * abs_max.cpu()  # Unnormalize; hwav lives on the CPU
    return hwav
def compute_corr(x, y):
    """Return the magnitude of the circular cross-correlation of ``x`` and ``y``.

    The correlation is evaluated along the last axis in the frequency domain:
    ``|ifft(fft(x) * conj(fft(y)))|``.
    """
    spectrum_x = torch.fft.fft(x)
    spectrum_y = torch.fft.fft(y)
    cross_spectrum = spectrum_x * torch.conj(spectrum_y)
    return torch.abs(torch.fft.ifft(cross_spectrum))
def compute_offset(chunk1, chunk2, sr=44100):
    """Estimate the lag between two waveform chunks from their mel spectrograms.

    Both chunks are converted to log-compressed mel spectrograms, which are
    circularly cross-correlated along time. The lag with the highest
    correlation (averaged over mel bands) is converted back to samples.

    Args:
        chunk1: (T,)
        chunk2: (T,)
        sr: Sample rate in Hz; sets the spectrogram hop (``sr // 200``).

    Returns:
        offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset).
            Always a multiple of the spectrogram hop length.
    """
    hop_length = sr // 200  # 5 ms resolution
    win_length = hop_length * 4
    n_fft = 2 ** (win_length - 1).bit_length()  # Smallest power of two >= win_length

    chunk1 = chunk1.float()
    chunk2 = chunk2.float()

    # Build the transform on the inputs' device so its window and filterbank
    # buffers match the chunks (both chunks are assumed to share a device).
    mel_fn = MelSpectrogram(
        sample_rate=sr,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=80,
        f_min=0.0,
        f_max=sr // 2,
    ).to(chunk1.device)

    spec1 = mel_fn(chunk1).log1p()
    spec2 = mel_fn(chunk2).log1p()

    corr = compute_corr(spec1, spec2)  # (F, T)
    corr = corr.mean(dim=0)  # (T,)

    argmax = corr.argmax().item()

    # The correlation is circular: indices in the upper half are negative lags.
    if argmax > len(corr) // 2:
        argmax -= len(corr)

    offset = -argmax * hop_length

    return offset
def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
    """Overlap-add a sequence of chunks into one signal with linear cross-fades.

    Consecutive chunks are nominally ``hop_length`` samples apart and overlap
    by ``chunk_length - hop_length`` samples. Before a chunk is added, its
    position is shifted by the lag that ``compute_offset`` estimates between
    the previous chunk's tail and this chunk's head.

    Args:
        chunks: List of 1-D tensors; chunks shorter than ``chunk_length`` are
            zero-padded on the right.
        chunk_length: Nominal chunk length in samples.
        hop_length: Nominal distance between chunk starts in samples.
        sr: Sample rate forwarded to ``compute_offset``.
        length: If given, the merged signal is truncated to this many samples.

    Returns:
        1-D tensor on the device of ``chunks[0]``; an empty tensor when
        ``chunks`` is empty.
    """
    if not chunks:
        # Nothing to merge; avoid indexing chunks[0] below.
        return torch.zeros(0)

    num_chunks = len(chunks)
    device = chunks[0].device

    signal_length = (num_chunks - 1) * hop_length + chunk_length
    overlap_length = chunk_length - hop_length
    signal = torch.zeros(signal_length, device=device)

    # Ramp over the overlap region, flat over the rest of the chunk.
    fadein = torch.linspace(0, 1, overlap_length, device=device)
    fadein = torch.cat([fadein, torch.ones(hop_length, device=device)])
    fadeout = torch.linspace(1, 0, overlap_length, device=device)
    fadeout = torch.cat([torch.ones(hop_length, device=device), fadeout])

    for i, chunk in enumerate(chunks):
        start = i * hop_length
        end = start + chunk_length

        if len(chunk) < chunk_length:
            chunk = F.pad(chunk, (0, chunk_length - len(chunk)))

        if i > 0:
            pre_region = chunks[i - 1][-overlap_length:]
            cur_region = chunk[:overlap_length]
            offset = compute_offset(pre_region, cur_region, sr=sr)
            start -= offset
            end -= offset

        # Fade only the edges that actually overlap a neighbour, so a lone
        # chunk passes through unfaded.
        if i > 0:
            chunk = chunk * fadein
        if i < num_chunks - 1:
            chunk = chunk * fadeout

        if start < 0:
            # A negative slice start would wrap around to the end of the
            # signal; drop the samples that fall before index 0 instead.
            chunk = chunk[-start:]
            start = 0

        target = signal[start:end]  # View into signal; may be clipped at the end
        target += chunk[: len(target)]

    return signal[:length]
def remove_weight_norm_recursively(module):
    """Remove the ``weight`` parametrization from ``module`` and all submodules.

    Modules without a parametrization on ``weight`` are skipped. Removal uses
    ``remove_parametrizations`` with its defaults, which per the PyTorch docs
    leaves ``weight`` holding its current parametrized value.

    Args:
        module: Root ``torch.nn.Module``; modified in place.
    """
    # Imported locally so this function does not depend on edits elsewhere.
    from torch.nn.utils.parametrize import is_parametrized

    # Snapshot the tree first: removal deletes each module's
    # ``parametrizations`` child while we are walking.
    for submodule in list(module.modules()):
        # Explicit check instead of swallowing every exception, so genuine
        # failures during removal are not hidden.
        if is_parametrized(submodule, "weight"):
            remove_parametrizations(submodule, "weight")
def inference(
    model,
    dwav,
    sr,
    device,
    dtype,
    chunk_seconds: float = 30.0,
    overlap_seconds: float = 1.0,
):
    """Process a waveform with ``model`` in overlapping chunks and merge the result.

    The waveform is resampled to ``model.hp.wav_rate``, split into chunks of
    ``chunk_seconds`` that overlap by ``overlap_seconds``, each chunk is run
    through ``inference_chunk``, and the outputs are stitched together with
    ``merge_chunks``. Weight parametrizations are removed from ``model``
    first, so the model is modified in place.

    Args:
        model: Model exposing ``hp`` with a ``wav_rate`` attribute.
        dwav: Input waveform; presumably 1-D, since ``inference_chunk``
            asserts that for every slice.
        sr: Sample rate of ``dwav``.
        device: Device forwarded to ``inference_chunk``.
        dtype: Dtype forwarded to ``inference_chunk``.
        chunk_seconds: Chunk duration in seconds.
        overlap_seconds: Overlap between consecutive chunks in seconds.

    Returns:
        Tuple ``(hwav, sr)``: the merged waveform and its sample rate, which
        is ``model.hp.wav_rate``.
    """
    from tqdm import trange

    # Silence the progress bar when the runtime configuration asks for it.
    progress = (
        partial(trange, disable=True) if config.runtime_env_vars.off_tqdm else trange
    )

    remove_weight_norm_recursively(model)

    hp: HParams = model.hp
    target_rate = hp.wav_rate

    dwav = resample(
        dwav,
        orig_freq=sr,
        new_freq=target_rate,
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="sinc_interp_kaiser",
        beta=14.769656459379492,
    )
    sr = target_rate  # Everything is in hp.wav_rate from here on

    # Drain pending GPU work so the timer only covers this call.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    total_samples = dwav.shape[-1]
    chunk_length = int(sr * chunk_seconds)
    hop_length = chunk_length - int(sr * overlap_seconds)

    chunks = []
    for begin in progress(0, total_samples, hop_length):
        segment = dwav[begin : begin + chunk_length]
        enhanced = inference_chunk(model, segment, sr, device, dtype)
        chunks.append(enhanced.cpu())
        devices.torch_gc()  # Free cached memory between chunks

    hwav = merge_chunks(
        chunks, chunk_length, hop_length, sr=sr, length=total_samples
    )

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed_time = time.perf_counter() - start_time

    logger.debug(
        f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
    )

    devices.torch_gc()
    return hwav, sr