| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # Resnet Blocks | |
| class ResnetBlockFC(nn.Module): | |
| ''' Fully connected ResNet Block class. | |
| Args: | |
| size_in (int): input dimension | |
| size_out (int): output dimension | |
| size_h (int): hidden dimension | |
| ''' | |
| def __init__(self, size_in, size_out=None, size_h=None): | |
| super().__init__() | |
| # Attributes | |
| if size_out is None: | |
| size_out = size_in | |
| if size_h is None: | |
| size_h = min(size_in, size_out) | |
| self.size_in = size_in | |
| self.size_h = size_h | |
| self.size_out = size_out | |
| # Submodules | |
| self.fc_0 = nn.Linear(size_in, size_h) | |
| self.fc_1 = nn.Linear(size_h, size_out) | |
| self.actvn = nn.ReLU() | |
| if size_in == size_out: | |
| self.shortcut = None | |
| else: | |
| self.shortcut = nn.Linear(size_in, size_out, bias=False) | |
| # Initialization | |
| nn.init.zeros_(self.fc_1.weight) | |
| def forward(self, x): | |
| net = self.fc_0(self.actvn(x)) | |
| dx = self.fc_1(self.actvn(net)) | |
| if self.shortcut is not None: | |
| x_s = self.shortcut(x) | |
| else: | |
| x_s = x | |
| return x_s + dx |