Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from torchinfo import summary | |
| import torchvision | |
# Module-level ResNet-50 created at import time with pretrained weights requested.
# NOTE: UNetWithResnet50Encoder.__init__ constructs its own backbone instance, so
# nothing in this file reads this global; it is kept for any external importer.
resnet = torchvision.models.resnet50(pretrained=True)
class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d -> (optional) ReLU building block.

    :param in_channels: number of channels of the incoming feature map
    :param out_channels: number of channels produced by the convolution
    :param padding: padding passed to ``nn.Conv2d``
    :param kernel_size: kernel size passed to ``nn.Conv2d``
    :param stride: stride passed to ``nn.Conv2d``
    :param with_nonlinearity: when False the trailing ReLU is skipped
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # The ReLU module is always registered; the flag only decides whether
        # forward() applies it.
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        """Apply convolution and batch norm, then ReLU if enabled."""
        normalized = self.bn(self.conv(x))
        if not self.with_nonlinearity:
            return normalized
        return self.relu(normalized)
class Bridge(nn.Module):
    """
    Middle layer of the UNet, sitting between the encoder and the decoder.

    It is two ConvBlocks applied back to back: the first maps ``in_channels``
    to ``out_channels``, the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first = ConvBlock(in_channels, out_channels)
        second = ConvBlock(out_channels, out_channels)
        self.bridge = nn.Sequential(first, second)

    def forward(self, x):
        """Run ``x`` through both ConvBlocks in sequence."""
        out = self.bridge(x)
        return out
class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step: Upsample -> concat skip -> ConvBlock -> ConvBlock.

    :param in_channels: channels entering the first ConvBlock, i.e. the number
        of channels AFTER the upsampled map is concatenated with the skip map
    :param out_channels: channels produced by the block
    :param up_conv_in_channels: channels of the tensor fed to the upsampler;
        defaults to ``in_channels``
    :param up_conv_out_channels: channels produced by the upsampler; defaults
        to ``out_channels``
    :param upsampling_method: ``"conv_transpose"`` (stride-2 transposed conv) or
        ``"bilinear"`` (x2 bilinear upsample followed by a 1x1 conv)
    :raises ValueError: if ``upsampling_method`` is not one of the two above
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            # The 1x1 conv operates on the tensor being upsampled, so it must use
            # the up-conv channel counts (not the post-concat in_channels), just
            # like the conv_transpose branch does.
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1),
            )
        else:
            # Fail at construction time instead of with an AttributeError on
            # the first forward pass.
            raise ValueError(
                f"Unknown upsampling_method {upsampling_method!r}; "
                "expected 'conv_transpose' or 'bilinear'"
            )

        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        # Concatenate along the channel dimension (dim 1).
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x
class UNetWithResnet50Encoder(nn.Module):
    """
    UNet whose encoder is a torchvision ResNet-50 and whose head predicts depth.

    The forward pass records a skip tensor at each encoder level, runs the
    Bridge, then walks the up blocks while concatenating the matching skip.
    The final 1x1 conv output is squashed with a sigmoid and scaled by
    ``max_depth``.

    :param max_depth: scale applied to the sigmoid output
    :param n_classes: number of output channels of the final 1x1 convolution
    """

    # Number of levels used to derive the "layer_<k>" skip keys in forward().
    DEPTH = 6

    def __init__(self, max_depth, n_classes=1):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        children = list(resnet.children())

        # First three children form the stem; the fourth is applied separately
        # so the pre-pool activation can be kept as a skip connection.
        self.input_block = nn.Sequential(*children[:3])
        self.input_pool = children[3]

        # Every nn.Sequential child of the backbone becomes an encoder stage.
        self.down_blocks = nn.ModuleList(
            [child for child in children if isinstance(child, nn.Sequential)]
        )

        self.bridge = Bridge(2048, 2048)

        # (in_channels, out_channels, up_conv_in_channels, up_conv_out_channels)
        decoder_specs = (
            (2048, 1024, None, None),
            (1024, 512, None, None),
            (512, 256, None, None),
            (128 + 64, 128, 256, 128),
            (64 + 3, 64, 128, 64),
        )
        self.up_blocks = nn.ModuleList(
            [
                UpBlockForUNetWithResNet50(in_ch, out_ch, up_in, up_out)
                for in_ch, out_ch, up_in, up_out in decoder_specs
            ]
        )

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)
        self.max_depth = max_depth

    def forward(self, x, with_output_feature_map=False):
        """
        :param x: input batch; the raw input itself is reused as the last skip.
            NOTE(review): spatial size presumably has to be a multiple of 32 so
            the upsampled maps line up with the skips in torch.cat - confirm.
        :param with_output_feature_map: accepted for interface compatibility;
            currently ignored, the return value does not depend on it
        :return: dict with key ``'pred_d'`` holding ``sigmoid(out) * max_depth``
        """
        deepest = UNetWithResnet50Encoder.DEPTH - 1

        skips = {"layer_0": x}
        x = self.input_block(x)
        skips["layer_1"] = x
        x = self.input_pool(x)

        for level, stage in enumerate(self.down_blocks, 2):
            x = stage(x)
            # The deepest stage feeds the bridge directly and is not a skip.
            if level != deepest:
                skips[f"layer_{level}"] = x

        x = self.bridge(x)

        # Decoder step k consumes the skip recorded at level (deepest - k).
        for step, decoder in enumerate(self.up_blocks, 1):
            x = decoder(x, skips[f"layer_{deepest - step}"])

        logits = self.out(x)
        return {'pred_d': torch.sigmoid(logits) * self.max_depth}
# Example usage (needs a CUDA device; max_depth is a required argument):
# model = UNetWithResnet50Encoder(max_depth=10).cuda()
# inp = torch.rand((2, 3, 512, 512)).cuda()
# out = model(inp)  # dict with key 'pred_d'
if __name__ == "__main__":
    # Smoke test: build the model and print a layer summary for a 256x256 RGB
    # input. Fall back to CPU so this also runs on machines without CUDA
    # instead of crashing in .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNetWithResnet50Encoder(max_depth=10).to(device)
    summary(model, input_size=(1, 3, 256, 256), device=device)