Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from omegaconf import OmegaConf | |
| from .msd import ScaleDiscriminator | |
| from .mpd import MultiPeriodDiscriminator | |
| from .mrd import MultiResolutionDiscriminator | |
class Discriminator(nn.Module):
    """Combined discriminator that runs three sub-discriminators on one input.

    Holds a multi-resolution (``MRD``), a multi-period (``MPD``) and a scale
    (``MSD``) discriminator and joins their outputs.

    Args:
        hp: Hyper-parameter config, passed to the ``MRD`` and ``MPD``
            constructors (``MSD`` is built without arguments).
    """

    def __init__(self, hp):
        super().__init__()
        # Keep the registration order (MRD, MPD, MSD) and attribute names:
        # they determine parameter order and state_dict key names.
        self.MRD = MultiResolutionDiscriminator(hp)
        self.MPD = MultiPeriodDiscriminator(hp)
        self.MSD = ScaleDiscriminator()

    def forward(self, x):
        """Evaluate every sub-discriminator on ``x`` and join the results.

        Returns:
            ``MRD(x) + MPD(x) + MSD(x)``, i.e. the three outputs combined with
            ``+``. Presumably each is a list of ``(feature_maps, score)``
            pairs, since the ``__main__`` check unpacks them that way.
            TODO: confirm in msd/mpd/mrd.
        """
        by_resolution = self.MRD(x)
        by_period = self.MPD(x)
        by_scale = self.MSD(x)
        return by_resolution + by_period + by_scale
if __name__ == '__main__':
    # Smoke test: build the combined discriminator from the base config, push
    # one random batch through it, then print the output shapes and the number
    # of trainable parameters.
    # Because of the relative imports at the top of this module, run it as a
    # module (python -m <package>.<this_module>), not as a plain script.
    from pathlib import Path

    # Resolve the config relative to this file instead of the current working
    # directory. This is the same file '../config/base.yaml' referred to when
    # launched from this file's own directory, but it now works from any CWD.
    config_path = Path(__file__).resolve().parent.parent / 'config' / 'base.yaml'
    hp = OmegaConf.load(config_path)
    model = Discriminator(hp)

    # Shape (3, 1, 16384): presumably (batch, channels, samples) of mono
    # audio, the layout the sub-discriminators expect. TODO: confirm.
    x = torch.randn(3, 1, 16384)
    print(x.shape)

    # Shapes are only inspected here, so skip building the autograd graph.
    with torch.no_grad():
        output = model(x)

    # Each entry is unpacked as (feature maps, score). This assumes every
    # sub-discriminator returns such pairs; verify in msd/mpd/mrd.
    for features, score in output:
        for feat in features:
            print(feat.shape)
        print(score.shape)

    pytorch_total_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(pytorch_total_params)