Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from torch.nn import Conv1d | |
| from torch.nn import ConvTranspose1d | |
| from torch.nn.utils import weight_norm | |
| from torch.nn.utils import remove_weight_norm | |
| from .nsf import SourceModuleHnNSF | |
| from .bigv import init_weights, AMPBlock, SnakeAlias | |
class Generator(torch.nn.Module):
    """BigVGAN-style generator with a neural-source-filter (NSF) harmonic branch.

    This is our main BigVGAN model. It applies anti-aliased periodic
    activation for the resblocks.

    Processing pipeline:

    * Mel features are projected by ``conv_pre``.
    * The result is upsampled stage by stage with transposed convolutions.
    * At every stage, a harmonic excitation derived from ``f0`` (via
      ``SourceModuleHnNSF``) is injected through ``noise_convs``.
    * A group of ``AMPBlock`` residual blocks follows, and their outputs are
      averaged.
    * ``SnakeAlias``, ``conv_post`` and ``tanh`` produce the waveform.

    Args:
        hp: hyper-parameter object. Reads ``hp.gen.mel_channels``,
            ``hp.gen.upsample_initial_channel``, ``hp.gen.upsample_rates``,
            ``hp.gen.upsample_kernel_sizes``, ``hp.gen.resblock_kernel_sizes``,
            ``hp.gen.resblock_dilation_sizes`` and ``hp.audio.sampling_rate``.
    """

    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        self.num_kernels = len(hp.gen.resblock_kernel_sizes)
        self.num_upsamples = len(hp.gen.upsample_rates)
        # Set once weight norm has been stripped so it is never removed twice.
        self._weight_norm_removed = False
        # pre conv
        self.conv_pre = weight_norm(
            Conv1d(hp.gen.mel_channels, hp.gen.upsample_initial_channel, 7, 1, padding=3))
        # nsf: upsample frame-level f0 by the total upsampling factor
        # (product of all stage rates).
        self.f0_upsamp = torch.nn.Upsample(
            scale_factor=np.prod(hp.gen.upsample_rates))
        self.m_source = SourceModuleHnNSF(sampling_rate=hp.audio.sampling_rate)
        self.noise_convs = nn.ModuleList()
        # Transposed-conv upsamplers. These do not apply anti-aliasing.
        # Module creation order (ups[i] then noise_convs[i]) is kept as-is so
        # default parameter init under a fixed seed is unchanged.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)):
            ch_in = hp.gen.upsample_initial_channel // (2 ** i)
            ch_out = hp.gen.upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(ConvTranspose1d(ch_in, ch_out, k, u, padding=(k - u) // 2)))
            # nsf: stride the full-rate excitation down by the upsampling still
            # to come, matching this stage's time resolution.
            if i + 1 < len(hp.gen.upsample_rates):
                stride_f0 = int(np.prod(hp.gen.upsample_rates[i + 1:]))
                self.noise_convs.append(
                    Conv1d(
                        1,
                        ch_out,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                # The last stage already runs at the output rate: a 1x1 projection.
                self.noise_convs.append(Conv1d(1, ch_out, kernel_size=1))
        # Residual blocks built from anti-aliased multi-periodicity composition
        # (AMP) modules: num_kernels consecutive blocks per upsampling stage.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = hp.gen.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes):
                self.resblocks.append(AMPBlock(ch, k, d))
        # post conv: channel count after the last upsampling stage. It is
        # computed explicitly rather than reusing the loop variable, which would
        # be undefined when there are no stages.
        ch = hp.gen.upsample_initial_channel // (2 ** len(self.ups))
        self.activation_post = SnakeAlias(ch)
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        # weight initialization (only the upsamplers are re-initialised here)
        self.ups.apply(init_weights)

    def _harmonic_source(self, f0):
        """Return the NSF harmonic excitation for ``f0`` with channels on dim 1.

        A channel axis is inserted at dim 1 before upsampling. ``f0`` is
        presumably (batch, frames); TODO confirm against callers. The
        ``m_source`` output is transposed back so it matches the layout consumed
        by the single-input-channel ``noise_convs``.
        """
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
        return self.m_source(f0).transpose(1, 2)

    def forward(self, x, f0, train=True):
        """Synthesize a waveform from mel features ``x`` and pitch ``f0``.

        Args:
            x: mel features fed to ``conv_pre`` (``hp.gen.mel_channels`` channels).
            f0: frame-level pitch; upsampled to the output rate internally.
            train: when True, Gaussian noise (std 0.1) is added to ``x`` as a
                perturbation. NOTE(review): this defaults to True independently
                of ``self.training``; ``inference`` passes False explicitly.

        Returns:
            Single-channel output of ``conv_post`` squashed by ``tanh``.
        """
        har_source = self._harmonic_source(f0)
        # pre conv
        if train:
            x = x + torch.randn_like(x) * 0.1  # Perturbation
        x = self.conv_pre(x)
        x = x * torch.tanh(F.softplus(x))  # Mish activation
        for i in range(self.num_upsamples):
            # upsampling, then inject the stage-aligned harmonic excitation
            x = self.ups[i](x)
            x = x + self.noise_convs[i](har_source)
            # AMP blocks: average the outputs of this stage's num_kernels blocks
            offset = i * self.num_kernels
            xs = None
            for j in range(self.num_kernels):
                out = self.resblocks[offset + j](x)
                xs = out if xs is None else xs + out
            x = xs / self.num_kernels
        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalization in place from ups, resblocks and conv_pre.

        Calling this more than once is safe: later calls are no-ops instead of
        raising. ``noise_convs`` and ``conv_post`` are built without weight
        norm, so they are not touched.
        """
        if getattr(self, "_weight_norm_removed", False):
            return
        # Inside this method the bare name refers to the module-level
        # torch.nn.utils.remove_weight_norm, not to this method.
        for layer in self.ups:
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        self._weight_norm_removed = True

    def eval(self, inference=False):
        """Switch to evaluation mode and optionally strip weight norm.

        Args:
            inference: when True, also call ``remove_weight_norm``. Leave it
                False during validation inside a training loop, where weight
                norm must stay.

        Returns:
            ``self``, matching ``torch.nn.Module.eval`` so calls can be chained.
        """
        super().eval()
        if inference:
            self.remove_weight_norm()
        return self

    @staticmethod
    def _to_int16(audio):
        """Scale float audio to 16-bit PCM.

        Values are clamped to [-32768, 32767] and then cast with ``short()``.
        """
        max_wav_value = 32768.0
        # squeeze() drops every size-1 dim; for a batch of one this leaves
        # only the time axis.
        audio = audio.squeeze() * max_wav_value
        audio = audio.clamp(min=-max_wav_value, max=max_wav_value - 1)
        return audio.short()

    def inference(self, mel, f0):
        """Synthesize 16-bit PCM audio from ``mel`` and ``f0``.

        The input perturbation noise is disabled.
        """
        audio = self.forward(mel, f0, False)
        return self._to_int16(audio)

    def pitch2wav(self, f0):
        """Render only the NSF harmonic excitation for ``f0`` as 16-bit PCM."""
        return self._to_int16(self._harmonic_source(f0))