# Spaces: Runtime error
| import torch.nn as nn | |
| from torch.nn.utils import remove_weight_norm, weight_norm | |
class Depthwise_Separable_Conv1D(nn.Module):
    """Depthwise-separable 1-D convolution.

    Two ``nn.Conv1d`` layers are applied in sequence:

    * ``depth_conv`` -- ``in_channels -> in_channels`` with
      ``groups=in_channels``; it receives ``kernel_size``, ``stride``,
      ``padding``, ``dilation`` and ``padding_mode``.
    * ``point_conv`` -- ``in_channels -> out_channels`` with kernel size 1.

    ``bias``, ``device`` and ``dtype`` are passed to both layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode='zeros',  # TODO: refine this type
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Both stages are created with the same device/dtype.
        factory_kwargs = {"device": device, "dtype": dtype}
        self.depth_conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,  # one filter group per input channel
            bias=bias,
            padding_mode=padding_mode,
            **factory_kwargs,
        )
        self.point_conv = nn.Conv1d(
            in_channels,
            out_channels,
            1,
            bias=bias,
            **factory_kwargs,
        )

    def forward(self, input):
        """Run ``input`` through ``depth_conv`` and then ``point_conv``."""
        depthwise_out = self.depth_conv(input)
        return self.point_conv(depthwise_out)

    def weight_norm(self):
        """Apply weight normalization to the ``weight`` of both convolutions."""
        for attr in ("depth_conv", "point_conv"):
            setattr(self, attr, weight_norm(getattr(self, attr), name='weight'))

    def remove_weight_norm(self):
        """Remove weight normalization from the ``weight`` of both convolutions."""
        for attr in ("depth_conv", "point_conv"):
            # The value returned by torch is stored back on the attribute.
            setattr(self, attr, remove_weight_norm(getattr(self, attr), name='weight'))
class Depthwise_Separable_TransposeConv1D(nn.Module):
    """Depthwise-separable 1-D transposed convolution.

    Two layers are applied in sequence:

    * ``depth_conv`` -- an ``nn.ConvTranspose1d`` mapping
      ``in_channels -> in_channels`` with ``groups=in_channels``; it receives
      ``kernel_size``, ``stride``, ``padding``, ``output_padding``,
      ``dilation`` and ``padding_mode``.
    * ``point_conv`` -- a regular (non-transposed) ``nn.Conv1d`` mapping
      ``in_channels -> out_channels`` with kernel size 1.

    ``bias``, ``device`` and ``dtype`` are passed to both layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True,
        dilation=1,
        padding_mode='zeros',  # TODO: refine this type
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Both stages are created with the same device/dtype.
        factory_kwargs = {"device": device, "dtype": dtype}
        # NOTE(review): padding_mode is forwarded as-is; PyTorch documents only
        # 'zeros' for transposed convolutions -- confirm no caller passes another value.
        self.depth_conv = nn.ConvTranspose1d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=in_channels,  # one filter group per input channel
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
            **factory_kwargs,
        )
        self.point_conv = nn.Conv1d(
            in_channels,
            out_channels,
            1,
            bias=bias,
            **factory_kwargs,
        )

    def forward(self, input):
        """Run ``input`` through ``depth_conv`` and then ``point_conv``."""
        upsampled = self.depth_conv(input)
        return self.point_conv(upsampled)

    def weight_norm(self):
        """Apply weight normalization to the ``weight`` of both layers."""
        for attr in ("depth_conv", "point_conv"):
            setattr(self, attr, weight_norm(getattr(self, attr), name='weight'))

    def remove_weight_norm(self):
        """Remove weight normalization from the ``weight`` of both layers."""
        for layer in (self.depth_conv, self.point_conv):
            remove_weight_norm(layer, name='weight')
def weight_norm_modules(module, name='weight', dim=0):
    """Apply weight normalization to ``module`` and return it.

    Plain modules are passed to ``torch.nn.utils.weight_norm`` with ``name``
    and ``dim``. Depthwise-separable wrappers normalize their own
    sub-convolutions through their ``weight_norm()`` method.
    """
    separable_types = (Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D)
    if not isinstance(module, separable_types):
        return weight_norm(module, name, dim)
    # The wrappers' weight_norm() takes no arguments, so `name` and `dim`
    # are not forwarded on this path.
    module.weight_norm()
    return module
def remove_weight_norm_modules(module, name='weight'):
    """Remove weight normalization from ``module`` and return it.

    Plain modules are passed to ``torch.nn.utils.remove_weight_norm`` with
    ``name``. Depthwise-separable wrappers strip their own sub-convolutions
    through their ``remove_weight_norm()`` method.

    Returns:
        The module that was passed in, mirroring ``weight_norm_modules`` and
        ``torch.nn.utils.remove_weight_norm``. Callers that ignore the return
        value are unaffected.
    """
    separable_types = (Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D)
    if isinstance(module, separable_types):
        # The wrappers' remove_weight_norm() takes no arguments, so `name`
        # is not forwarded on this path.
        module.remove_weight_norm()
        return module
    return remove_weight_norm(module, name)