Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter ``y[t] = x[t] - coef * x[t-1]``.

    The filter is applied as a 1-D convolution. The missing ``x[-1]`` for the
    first output sample is taken from reflect padding (i.e. ``x[1]``), so the
    output has as many samples as the input.

    Args:
        coef: pre-emphasis coefficient multiplying the previous sample.
    """

    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        # F.conv1d computes cross-correlation (no kernel flip), so the taps are
        # stored in reverse order: -coef hits x[t-1] and 1.0 hits x[t].
        self.register_buffer(
            "flipped_filter",
            torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Pre-emphasize a batch of signals.

        Args:
            input: 2-D tensor ``(batch, samples)``. At least two samples are
                needed because reflect padding requires the pad to be smaller
                than the padded dimension.

        Returns:
            Tensor of shape ``(batch, 1, samples)`` with the dtype of ``input``.

        Raises:
            AssertionError: if ``input`` is not 2-D.
        """
        assert input.dim() == 2, "The number of dimensions of input tensor must be 2!"
        # conv1d expects a channel axis: (batch, samples) -> (batch, 1, samples).
        input = input.unsqueeze(1)
        # One reflected sample on the left keeps the output length equal to the input length.
        input = F.pad(input, (1, 0), "reflect")
        # Match the taps to the input dtype so e.g. float64 inputs work too;
        # this is a no-op for float32 inputs.
        return F.conv1d(input, self.flipped_filter.to(dtype=input.dtype))
class AFMS(nn.Module):
    """Alpha feature-map scaling (AFMS), applied to the output of each residual block [1, 2].

    Every channel is shifted by a learnable offset ``alpha`` and then multiplied
    by a sigmoid gate derived from that channel's mean over the last axis:
    ``out = (x + alpha) * sigmoid(fc(mean(x)))``.

    References:
        [1] RawNet2: https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
        [2] AFMS: https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        # Per-channel additive offset, shaped (nb_dim, 1) so it broadcasts over the last axis.
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return ``(x + alpha) * gate`` for a ``(batch, nb_dim, length)`` feature map."""
        batch_size, n_channels = x.size(0), x.size(1)
        # Average every channel over the last axis -> (batch, nb_dim).
        pooled = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        # Per-channel gate in (0, 1), reshaped to (batch, nb_dim, 1) for broadcasting.
        gate = self.sig(self.fc(pooled)).view(batch_size, n_channels, -1)
        shifted = x + self.alpha
        return shifted * gate
class Bottle2neck(nn.Module):
    """Multi-scale 1-D bottleneck block with a residual connection and AFMS scaling.

    The input is projected to ``width * scale`` channels
    (``width = floor(planes / scale)``) and split into ``scale`` groups of
    ``width`` channels. The first ``scale - 1`` groups are processed
    hierarchically: every group after the first is added to the previous
    group's output before going through its own dilated convolution. The last
    group passes through unchanged. The groups are concatenated and projected
    to ``planes`` channels. A residual is then added, followed by an optional
    max-pool and :class:`AFMS`.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size: kernel size of the per-group dilated convolutions
            (required). The padding ``floor(kernel_size / 2) * dilation`` keeps
            the length unchanged only for odd kernel sizes. With an even kernel
            size the group lengths disagree and ``forward`` fails in ``torch.cat``.
        dilation: dilation of the per-group convolutions (required).
        scale: number of channel groups. ``scale=1`` means no per-group convolutions.
        pool: if truthy, the kernel size of an ``nn.MaxPool1d`` applied after
            the residual add (its stride defaults to the kernel size).

    Raises:
        TypeError: if ``kernel_size`` or ``dilation`` is ``None``.
    """

    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):
        super().__init__()
        if kernel_size is None or dilation is None:
            # The None defaults are kept only for signature compatibility. Fail
            # clearly here instead of inside the padding arithmetic below.
            raise TypeError("Bottle2neck requires both kernel_size and dilation")

        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)

        # Number of groups with their own convolution; the last group is a pass-through.
        self.nums = scale - 1
        num_pad = math.floor(kernel_size / 2) * dilation
        self.convs = nn.ModuleList(
            nn.Conv1d(
                width,
                width,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=num_pad,
            )
            for _ in range(self.nums)
        )
        self.bns = nn.ModuleList(nn.BatchNorm1d(width) for _ in range(self.nums))

        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        # ``False`` (not a module) when pooling is disabled; forward tests its truthiness.
        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        if inplanes != planes:
            # 1x1 projection so the residual matches the output channel count.
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Apply the block.

        Args:
            x: batched feature map ``(batch, inplanes, length)``.

        Returns:
            Tensor ``(batch, planes, length')``. ``length'`` equals ``length``
            unless max-pooling is enabled.
        """
        residual = self.residual(x)

        # ReLU is applied before batch norm throughout this block.
        out = self.bn1(self.relu(self.conv1(x)))

        groups = torch.split(out, self.width, 1)
        branches = []
        sp = None
        for i in range(self.nums):
            # Every group after the first also receives the previous group's output.
            sp = groups[i] if i == 0 else sp + groups[i]
            sp = self.bns[i](self.relu(self.convs[i](sp)))
            branches.append(sp)
        # The last group skips the convolutions. Building a list keeps scale=1
        # correct: only the single pass-through group reaches conv3 (the conv1
        # output is not concatenated with itself).
        branches.append(groups[self.nums])
        out = torch.cat(branches, 1)

        out = self.bn3(self.relu(self.conv3(out)))
        out += residual
        if self.mp:
            out = self.mp(out)
        return self.afms(out)