from functools import reduce
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.backends.cuda import sdp_kernel
from packaging import version

from dac.nn.layers import Snake1d
class ResidualBlock(nn.Module):
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)

class ResConvBlock(ResidualBlock):
    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
        skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
        super().__init__([
            nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
            nn.GroupNorm(1, c_mid),
            Snake1d(c_mid) if use_snake else nn.GELU(),
            nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
            nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
            (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
        ], skip)
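# Usage sketch (illustrative comment, not part of the original module): a
# ResConvBlock maps (batch, c_in, time) -> (batch, c_out, time), adding a 1x1
# convolution on the skip path only when the channel counts differ, e.g.
#   block = ResConvBlock(c_in=64, c_mid=128, c_out=64)
#   y = block(torch.randn(2, 64, 1024))  # -> (2, 64, 1024)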
class SelfAttention1d(nn.Module):
    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(self, input):
        n, c, s = input.shape
        qkv = self.qkv_proj(self.norm(input))
        qkv = qkv.view(
            [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        scale = k.shape[3]**-0.25
        if self.use_flash:
            with sdp_kernel(*self.sdp_kernel_config):
                y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
        else:
            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
            y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
        return input + self.dropout(self.out_proj(y))
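# Shape sketch (illustrative comment, not part of the original module): the
# block is residual, so input and output shapes match, e.g.
#   attn = SelfAttention1d(c_in=64, n_head=4)
#   y = attn(torch.randn(2, 64, 256))  # q/k/v each (2, 4, 256, 16); y -> (2, 64, 256)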
class SkipBlock(nn.Module):
    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        return torch.cat([self.main(input), input], dim=1)

class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn(
            [out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)
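# Usage sketch (illustrative comment, not part of the original module): random
# Fourier features are typically used to embed scalar timesteps, e.g.
#   timestep_embed = FourierFeatures(in_features=1, out_features=256)
#   emb = timestep_embed(torch.rand(8, 1))  # -> (8, 256), cos/sin halves concatenated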
def expand_to_planes(input, shape):
    return input[..., None].repeat([1, 1, shape[2]])

_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
         0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
class Downsample1d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, (self.pad,) * 2, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        x = F.conv1d(x, weight, stride=2)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x

class Upsample1d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
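# Shape sketch (illustrative comment, not part of the original module): both
# modules apply a fixed per-channel FIR filter and halve or double the time axis, e.g.
#   down, up = Downsample1d('linear'), Upsample1d('linear')
#   x = torch.randn(2, 64, 1024)
#   down(x).shape  # -> (2, 64, 512)
#   up(x).shape    # -> (2, 64, 2048)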
def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * kernel_multiplier + 1,
        stride=factor,
        padding=factor * (kernel_multiplier // 2),
    )

def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    if factor == 1:
        return nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    if use_nearest:
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode="nearest"),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
        )
    else:
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + factor % 2,
            output_padding=factor % 2,
        )
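# Shape sketch (illustrative comment, not part of the original module): these
# factory functions return learned resamplers, e.g. for factor=2
#   Downsample1d_2(64, 64, factor=2)(torch.randn(2, 64, 1024)).shape  # -> (2, 64, 512)
#   Upsample1d_2(64, 64, factor=2)(torch.randn(2, 64, 512)).shape     # -> (2, 64, 1024)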
def zero_init(layer):
    nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer

def rms_norm(x, scale, eps):
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)

#rms_norm = torch.compile(rms_norm)

class AdaRMSNorm(nn.Module):
    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.linear = zero_init(nn.Linear(cond_features, features, bias=False))

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
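# Usage sketch (illustrative comment, not part of the original module): the
# conditioning projection is zero-initialized, so the module starts out as a
# plain RMSNorm (scale == 1) and learns a conditional scale during training, e.g.
#   norm = AdaRMSNorm(features=512, cond_features=768)
#   y = norm(torch.randn(2, 100, 512), torch.randn(2, 768))  # -> (2, 100, 512)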
def normalize(x, eps=1e-4):
    dim = list(range(1, x.ndim))
    n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
    alpha = np.sqrt(n.numel() / x.numel())
    return x / torch.add(eps, n, alpha=alpha)

class ForcedWNConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(self.weight))

        fan_in = self.weight[0].numel()

        w = normalize(self.weight) / math.sqrt(fan_in)

        return F.conv1d(x, w, padding='same')
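# Usage sketch (illustrative comment, not part of the original module): the
# weight is re-normalized in place during training (forced weight norm) and
# rescaled by 1/sqrt(fan_in) before the convolution, e.g.
#   conv = ForcedWNConv1d(in_channels=64, out_channels=128)
#   y = conv(torch.randn(2, 64, 1024))  # -> (2, 128, 1024)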
# Kernels

use_compile = True

def compile(function, *args, **kwargs):
    if not use_compile:
        return function
    try:
        return torch.compile(function, *args, **kwargs)
    except RuntimeError:
        return function

def linear_geglu(x, weight, bias=None):
    x = x @ weight.mT
    if bias is not None:
        x = x + bias
    x, gate = x.chunk(2, dim=-1)
    return x * F.gelu(gate)

def rms_norm(x, scale, eps):
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)
# Layers

class LinearGEGLU(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features * 2, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        return linear_geglu(x, self.weight, self.bias)

class RMSNorm(nn.Module):
    def __init__(self, shape, fix_scale=False, eps=1e-6):
        super().__init__()
        self.eps = eps

        if fix_scale:
            self.register_buffer("scale", torch.ones(shape))
        else:
            self.scale = nn.Parameter(torch.ones(shape))

    def extra_repr(self):
        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, self.scale, self.eps)
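# Usage sketch (illustrative comment, not part of the original module):
# LinearGEGLU(512, 2048) projects to 2 * 2048 features internally and returns
# 2048 gated features; RMSNorm(512) rescales the last dimension to unit RMS, e.g.
#   y = LinearGEGLU(512, 2048)(torch.randn(2, 100, 512))  # -> (2, 100, 2048)
#   z = RMSNorm(512)(torch.randn(2, 100, 512))            # -> (2, 100, 512)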
def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)

# try:
#     snake_beta = torch.compile(snake_beta)
# except RuntimeError:
#     pass

# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x
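# Usage sketch (illustrative comment, not part of the original module): SnakeBeta
# is the element-wise activation x + (1/beta) * sin^2(alpha * x) with per-channel
# learnable alpha/beta, applied to [B, C, T] tensors, e.g.
#   act = SnakeBeta(in_features=64)
#   y = act(torch.randn(2, 64, 1024))  # -> (2, 64, 1024)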