Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import torchaudio.transforms as T | |
| from hifigan.models import Generator as HiFiGAN | |
| from hifigan.utils import AttrDict | |
| from torch import Tensor | |
| from torchaudio.sox_effects import apply_effects_tensor | |
| from wavlm.WavLM import WavLM | |
| from knnvc_utils import generate_matrix_from_index | |
# WavLM layer index passed as `output_layer` on the fast feature-extraction path
# of `KNeighborsVC.get_features`.
SPEAKER_INFORMATION_LAYER = 6
# Layer weighting built from the layer index above; `KNeighborsVC.__init__` turns it
# into a tensor and uses it as the default `weights`.
# NOTE(review): presumably selects only SPEAKER_INFORMATION_LAYER (the fast path reads
# just that layer when the weights match) -- confirm against knnvc_utils.
SPEAKER_INFORMATION_WEIGHTS = generate_matrix_from_index(SPEAKER_INFORMATION_LAYER)
def fast_cosine_dist(source_feats: Tensor, matching_pool: Tensor, device: str = 'cpu') -> Tensor:
    """ Pairwise cosine distance between the rows of `source_feats` and `matching_pool`.

    Works like `torch.cdist`, but always over the last dimension and for cosine
    distance. The dot products are recovered from Euclidean distances through
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b).

    `source_feats` and its norms are moved to `device`; `matching_pool` is used
    where it already lives. The result has one row per source vector and one
    column per matching-pool vector.
    """
    src_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
    pool_norms = torch.norm(matching_pool, p=2, dim=-1)

    # Column / row views so the norms broadcast against the (n_source, n_pool) grid.
    src_col = src_norms[:, None]
    pool_row = pool_norms[None]

    sq_euclid = torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0] ** 2
    dot = (-sq_euclid + src_col ** 2 + pool_row ** 2) / 2
    return 1 - (dot / (src_col * pool_row))
class KNeighborsVC(nn.Module):
    """ kNN voice conversion: WavLM feature extraction, kNN feature matching and HiFiGAN vocoding. """

    def __init__(self,
        wavlm: WavLM,
        hifigan: HiFiGAN,
        hifigan_cfg: AttrDict,
        device='cuda'
    ) -> None:
        """ kNN-VC matcher.
        Arguments:
            - `wavlm` : trained WavLM model
            - `hifigan`: trained hifigan model
            - `hifigan_cfg`: hifigan config to use for vocoding.
            - `device`: device on which the default weighting tensor is created and to
              which waveforms are moved before feature extraction.
        """
        super().__init__()
        # set which features to extract from wavlm
        self.weighting = torch.tensor(SPEAKER_INFORMATION_WEIGHTS, device=device)[:, None]
        # load hifigan
        self.hifigan = hifigan.eval()
        self.h = hifigan_cfg
        # store wavlm
        self.wavlm = wavlm.eval()
        self.device = torch.device(device)
        self.sr = self.h.sampling_rate
        # samples per feature frame; used by `match` to turn a target duration into a frame count
        self.hop_length = 320

    def get_matching_set(self, wavs: list[Path] | list[Tensor], weights=None, vad_trigger_level=7) -> Tensor:
        """ Get concatenated wavlm features for the matching set using all waveforms in `wavs`,
        specified as either a list of paths or list of loaded waveform tensors of
        shape (channels, T), assumed to be of 16kHz sample rate.
        Optionally specify custom WavLM feature weighting with `weights`.
        Returns the per-waveform features concatenated along dim 0, on the CPU.
        """
        # Resolve the default weighting once rather than once per waveform.
        if weights is None:
            weights = self.weighting
        feats = [
            self.get_features(wav, weights=weights, vad_trigger_level=vad_trigger_level)
            for wav in wavs
        ]
        return torch.concat(feats, dim=0).cpu()

    def vocode(self, c: Tensor) -> Tensor:
        """ Vocode features with hifigan. `c` is of shape (bs, seq_len, c_dim) """
        y_g_hat = self.hifigan(c)
        y_g_hat = y_g_hat.squeeze(1)
        return y_g_hat

    def get_features(self, path, weights=None, vad_trigger_level=0):
        """Returns features of `path` waveform as a tensor of shape (seq_len, dim), optionally perform VAD trimming
        on start/end with `vad_trigger_level`.

        `path` is either a `str` / `pathlib.Path` pointing at an audio file, or an already
        loaded waveform tensor (treated as being at `self.sr`). `weights` defaults to
        `self.weighting`; when it matches `self.weighting` only the single
        SPEAKER_INFORMATION_LAYER output is extracted, otherwise all layer outputs are
        combined using `weights`.
        """
        if weights is None:
            weights = self.weighting
        else:
            # Keep custom weights on the same device as the default weighting so the
            # comparison below cannot fail on a device mismatch.
            weights = weights.to(self.weighting.device)

        # load audio
        # isinstance (not an exact type match): Path() instances are really
        # PosixPath / WindowsPath, which an exact `type(...)` comparison misses.
        if isinstance(path, (str, Path)):
            x, sr = torchaudio.load(path, normalize=True)
        else:
            x: Tensor = path
            sr = self.sr
            if x.dim() == 1:
                x = x[None]

        if sr != self.sr:
            print(f"resample {sr} to {self.sr} in {path}")
            x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=self.sr)
            sr = self.sr

        # trim silence from front and back
        if vad_trigger_level > 1e-3:
            transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
            x_front_trim = transform(x)
            # Vad trims only the front, so trim the tail on a time-reversed copy.
            # torch.flip is used instead of the sox "reverse" effect, which lacks Windows support.
            waveform_reversed = torch.flip(x_front_trim, (-1,))
            waveform_reversed_front_trim = transform(waveform_reversed)
            x = torch.flip(waveform_reversed_front_trim, (-1,))

        # extract the representation of each layer
        wav_input_16khz = x.to(self.device)
        if torch.allclose(weights, self.weighting):
            # use fastpath
            features = self.wavlm.extract_features(wav_input_16khz, output_layer=SPEAKER_INFORMATION_LAYER, ret_layer_results=False)[0]
            features = features.squeeze(0)
        else:
            # use slower weighted
            _, layer_results = self.wavlm.extract_features(wav_input_16khz, output_layer=self.wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
            features = torch.cat([layer_out.transpose(0, 1) for layer_out, _ in layer_results], dim=0) # (n_layers, seq_len, dim)
            # save full sequence
            features = (features * weights[:, None]).sum(dim=0) # (seq_len, dim)

        return features

    def match(self, query_seq: Tensor, matching_set: Tensor, synth_set: Tensor | None = None,
              topk: int = 4, tgt_loudness_db: float | None = -16,
              target_duration: float | None = None, device: str | None = None) -> Tensor:
        """ Given `query_seq`, `matching_set`, and `synth_set` tensors of shape (N, dim), perform kNN regression matching
        with k=`topk`. Inputs:
            - `query_seq`: Tensor (N1, dim) of the input/source query features.
            - `matching_set`: Tensor (N2, dim) of the matching set used as the 'training set' for the kNN algorithm.
            - `synth_set`: optional Tensor (N2, dim) corresponding to the matching set. We use the matching set to assign each query
                vector to a vector in the matching set, and then use the corresponding vector from the synth set during HiFiGAN synthesis.
                By default, and for best performance, this should be identical to the matching set.
            - `topk`: k in the kNN -- the number of nearest neighbors to average over.
            - `tgt_loudness_db`: float db used to normalize the output volume. Set to None to disable.
            - `target_duration`: if set to a float, interpolate resulting waveform duration to be equal to this value in seconds.
            - `device`: if None, uses default device at initialization. Otherwise uses specified device
        Returns:
            - converted waveform of shape (T,)
        """
        device = torch.device(device) if device is not None else self.device
        matching_set = matching_set.to(device)
        # Without an explicit synth set, synthesize from the matching set itself.
        synth_set = matching_set if synth_set is None else synth_set.to(device)
        query_seq = query_seq.to(device)

        if target_duration is not None:
            target_samples = int(target_duration * self.sr)
            scale_factor = (target_samples / self.hop_length) / query_seq.shape[0] # n_targ_feats / n_input_feats
            query_seq = F.interpolate(query_seq.T[None], scale_factor=scale_factor, mode='linear')[0].T

        dists = fast_cosine_dist(query_seq, matching_set, device=device)
        best = dists.topk(k=topk, largest=False, dim=-1)
        # Average the synth vectors of the k nearest neighbours of every query frame.
        out_feats = synth_set[best.indices].mean(dim=1)

        prediction = self.vocode(out_feats[None].to(device)).cpu().squeeze()

        # normalization
        if tgt_loudness_db is None:
            return prediction
        src_loudness = torchaudio.functional.loudness(prediction[None], self.h.sampling_rate)
        return torchaudio.functional.gain(prediction, tgt_loudness_db - src_loudness)