Spaces:
Runtime error
Runtime error
| from torch import nn | |
| from torch.nn import init | |
| import torch | |
| import torch.nn.functional as F | |
class conv_block(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    The first stage maps ``ch_in`` channels to ``ch_out``; the second keeps
    ``ch_out`` channels. Kernel size 3 with stride 1 and padding 1 leaves the
    spatial size of the input unchanged.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        # Stage one consumes ch_in channels, stage two consumes ch_out.
        for in_channels in (ch_in, ch_out):
            stages.append(
                nn.Conv2d(
                    in_channels,
                    ch_out,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=True,
                )
            )
            stages.append(nn.BatchNorm2d(ch_out))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/activation stages."""
        return self.conv(x)
class up_conv(nn.Module):
    """Upsample by a factor of 2, then apply 3x3 conv -> batch-norm -> ReLU.

    ``nn.Upsample`` is used with its default interpolation mode; the
    convolution maps ``ch_in`` channels to ``ch_out`` and keeps the
    (already doubled) spatial size.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        project = nn.Conv2d(
            ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True
        )
        normalize = nn.BatchNorm2d(ch_out)
        activate = nn.ReLU(inplace=True)
        self.up = nn.Sequential(enlarge, project, normalize, activate)

    def forward(self, x):
        """Return the upsampled and convolved version of ``x``."""
        return self.up(x)
class U_Net(nn.Module):
    """U-Net style encoder/decoder with skip connections and sigmoid output.

    The encoder applies five ``conv_block`` stages (64, 128, 256, 512, 1024
    channels) separated by 2x2 max-pooling. The decoder mirrors it with four
    ``up_conv`` + ``conv_block`` stages, concatenating each upsampled map
    with the encoder map of the same level along the channel dimension.
    A final 1x1 convolution followed by a sigmoid produces the output.

    Args:
        img_ch: Number of channels of the input image.
        output_ch: Number of channels produced by the final 1x1 convolution.
        multi_stage: Accepted for interface compatibility; this class does
            not read it.
    """

    def __init__(self, img_ch=3, output_ch=1, multi_stage=False):
        super(U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder. Each Up_conv* takes twice the channels of its Up* output
        # because the skip connection is concatenated in front of it.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
        self.activation = nn.Sequential(nn.Sigmoid())

        # Re-initialise every submodule (Module.apply visits them recursively).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule's parameters in place.

        Modules whose class name contains "Conv" or "Linear" get their weight
        drawn according to ``init_type`` and their bias (if any) set to 0.
        ``BatchNorm2d`` modules get weight ~ N(1.0, gain) and bias 0.
        Any other module is left untouched.

        Raises:
            NotImplementedError: If ``init_type`` is not a supported scheme.
        """
        # The scheme is fixed here; the other branches below only become
        # active if these two values are edited.
        init_type = "normal"
        gain = 0.02

        classname = m.__class__.__name__
        is_conv_or_linear = "Conv" in classname or "Linear" in classname
        if hasattr(m, "weight") and is_conv_or_linear:
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    @staticmethod
    def _merge_skip(skip, up):
        """Concatenate ``skip`` and ``up`` along the channel dimension.

        Max-pooling rounds odd sizes down, so after upsampling by 2 the
        decoder map can be smaller than the encoder map it is paired with.
        When the last two dimensions differ, ``up`` is resized to ``skip``'s
        size first; when they already match, nothing is changed.
        """
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(up, size=skip.shape[-2:], mode="nearest")
        return torch.cat((skip, up), dim=1)

    def forward(self, x):
        """Compute the sigmoid-activated output map for ``x``.

        The output has ``output_ch`` channels and the same height and width
        as ``x``; height and width need not be multiples of 16.
        """
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = self._merge_skip(x4, d5)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = self._merge_skip(x3, d4)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = self._merge_skip(x2, d3)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = self._merge_skip(x1, d2)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        d1 = self.activation(d1)
        return d1