from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import parametrize

from TTS.vocoder.layers.lvc_block import LVCBlock

LRELU_SLOPE = 0.1

class UnivnetGenerator(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        cond_channels: int,
        upsample_factors: List[int],
        lvc_layers_each_block: int,
        lvc_kernel_size: int,
        kpnet_hidden_channels: int,
        kpnet_conv_size: int,
        dropout: float,
        use_weight_norm=True,
    ):
| """Univnet Generator network. | |
| Paper: https://arxiv.org/pdf/2106.07889.pdf | |
| Args: | |
| in_channels (int): Number of input tensor channels. | |
| out_channels (int): Number of channels of the output tensor. | |
| hidden_channels (int): Number of hidden network channels. | |
| cond_channels (int): Number of channels of the conditioning tensors. | |
| upsample_factors (List[int]): List of uplsample factors for the upsampling layers. | |
| lvc_layers_each_block (int): Number of LVC layers in each block. | |
| lvc_kernel_size (int): Kernel size of the LVC layers. | |
| kpnet_hidden_channels (int): Number of hidden channels in the key-point network. | |
| kpnet_conv_size (int): Number of convolution channels in the key-point network. | |
| dropout (float): Dropout rate. | |
| use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True. | |
| """ | |
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cond_channels = cond_channels
        self.upsample_scale = np.prod(upsample_factors)
        self.lvc_block_nums = len(upsample_factors)

        # define first convolution
        self.first_conv = torch.nn.Conv1d(
            in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
        )

        # define residual blocks
        self.lvc_blocks = torch.nn.ModuleList()
        cond_hop_length = 1
        for n in range(self.lvc_block_nums):
            cond_hop_length = cond_hop_length * upsample_factors[n]
            lvcb = LVCBlock(
                in_channels=hidden_channels,
                cond_channels=cond_channels,
                upsample_ratio=upsample_factors[n],
                conv_layers=lvc_layers_each_block,
                conv_kernel_size=lvc_kernel_size,
                cond_hop_length=cond_hop_length,
                kpnet_hidden_channels=kpnet_hidden_channels,
                kpnet_conv_size=kpnet_conv_size,
                kpnet_dropout=dropout,
            )
            self.lvc_blocks.append(lvcb)

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
                ),
            ]
        )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Local conditioning auxiliary features (B, C, T').

        Returns:
            Tensor: Output tensor (B, out_channels, T' * prod(upsample_factors)).
        """
        # random noise at the frame rate; the LVC blocks upsample it to audio rate
        x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
        x = x.to(self.first_conv.bias.device)

        x = self.first_conv(x)
        for n in range(self.lvc_block_nums):
            x = self.lvc_blocks[n](x, c)

        # apply final layers
        for f in self.last_conv_layers:
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = f(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                parametrize.remove_parametrizations(m, "weight")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)
                # print(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1
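    # Worked example (illustrative numbers, not taken from a TTS config): with
    # layers=30, stacks=3, kernel_size=3, each of the 3 cycles uses dilations
    # 1, 2, ..., 512, so sum(dilations) = 3 * (2**10 - 1) = 3069 and the
    # receptive field is (3 - 1) * 3069 + 1 = 6139 samples.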
    def receptive_field_size(self):
        """Return receptive field size."""
        # NOTE: `self.layers`, `self.stacks`, and `self.kernel_size` are not set
        # in __init__; they must be assigned before calling this method.
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

    @torch.no_grad()
    def inference(self, c):
        """Perform inference.

        Args:
            c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.

        Returns:
            Tensor: Output tensor (B, out_channels, T * prod(upsample_factors)).
        """
        # Move the conditioning features to the model's device and dtype;
        # `forward` draws its own noise input, so no noise is prepared here.
        c = c.to(next(self.parameters()))
        return self.forward(c)
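

if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The hyperparameters below are
    # illustrative assumptions (mel-style 80-channel conditioning, 256x
    # total upsampling), not values taken from a released TTS config.
    model = UnivnetGenerator(
        in_channels=64,
        out_channels=1,
        hidden_channels=32,
        cond_channels=80,
        upsample_factors=[8, 8, 4],
        lvc_layers_each_block=4,
        lvc_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        dropout=0.0,
    )
    mel = torch.randn(2, 80, 50)  # (B, cond_channels, T') dummy conditioning
    wav = model.inference(mel)
    # T' frames are upsampled by prod(upsample_factors) = 8 * 8 * 4 = 256:
    print(wav.shape)  # expected: torch.Size([2, 1, 12800])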