|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  |  | 
					
						
						|  | from functools import partial | 
					
						
						|  | from typing import List, Optional | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Conv2dAct(nn.Sequential): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels: int, | 
					
						
						|  | out_channels: int, | 
					
						
						|  | kernel_size: int, | 
					
						
						|  | padding: int = 0, | 
					
						
						|  | stride: int = 1, | 
					
						
						|  | norm_layer: str = "bn", | 
					
						
						|  | num_groups: int = 32, | 
					
						
						|  | activation: str = "ReLU", | 
					
						
						|  | inplace: bool = True, | 
					
						
						|  | ): | 
					
						
						|  | if norm_layer == "bn": | 
					
						
						|  | NormLayer = nn.BatchNorm2d | 
					
						
						|  | elif norm_layer == "gn": | 
					
						
						|  | NormLayer = partial(nn.GroupNorm, num_groups=num_groups) | 
					
						
						|  | else: | 
					
						
						|  | raise Exception( | 
					
						
						|  | f"`norm_layer` must be one of [`bn`, `gn`], got `{norm_layer}`" | 
					
						
						|  | ) | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.conv = nn.Conv2d( | 
					
						
						|  | in_channels, | 
					
						
						|  | out_channels, | 
					
						
						|  | kernel_size=kernel_size, | 
					
						
						|  | stride=stride, | 
					
						
						|  | padding=padding, | 
					
						
						|  | bias=False, | 
					
						
						|  | ) | 
					
						
						|  | self.norm = NormLayer(out_channels) | 
					
						
						|  | self.act = getattr(nn, activation)(inplace=inplace) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
					
						
						|  | return self.act(self.norm(self.conv(x))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class SCSEModule(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels: int, | 
					
						
						|  | reduction: int = 16, | 
					
						
						|  | activation: str = "ReLU", | 
					
						
						|  | inplace: bool = False, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.cSE = nn.Sequential( | 
					
						
						|  | nn.AdaptiveAvgPool2d(1), | 
					
						
						|  | nn.Conv2d(in_channels, in_channels // reduction, 1), | 
					
						
						|  | getattr(nn, activation)(inplace=inplace), | 
					
						
						|  | nn.Conv2d(in_channels // reduction, in_channels, 1), | 
					
						
						|  | ) | 
					
						
						|  | self.sSE = nn.Conv2d(in_channels, 1, 1) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
					
						
						|  | return x * self.cSE(x).sigmoid() + x * self.sSE(x).sigmoid() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Attention(nn.Module): | 
					
						
						|  | def __init__(self, name: str, **params): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | if name is None: | 
					
						
						|  | self.attention = nn.Identity(**params) | 
					
						
						|  | elif name == "scse": | 
					
						
						|  | self.attention = SCSEModule(**params) | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError("Attention {} is not implemented".format(name)) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
					
						
						|  | return self.attention(x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DecoderBlock(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels: int, | 
					
						
						|  | skip_channels: int, | 
					
						
						|  | out_channels: int, | 
					
						
						|  | norm_layer: str = "bn", | 
					
						
						|  | activation: str = "ReLU", | 
					
						
						|  | attention_type: Optional[str] = None, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.conv1 = Conv2dAct( | 
					
						
						|  | in_channels + skip_channels, | 
					
						
						|  | out_channels, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | padding=1, | 
					
						
						|  | norm_layer=norm_layer, | 
					
						
						|  | activation=activation, | 
					
						
						|  | ) | 
					
						
						|  | self.attention1 = Attention( | 
					
						
						|  | attention_type, in_channels=in_channels + skip_channels | 
					
						
						|  | ) | 
					
						
						|  | self.conv2 = Conv2dAct( | 
					
						
						|  | out_channels, | 
					
						
						|  | out_channels, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | padding=1, | 
					
						
						|  | norm_layer=norm_layer, | 
					
						
						|  | activation=activation, | 
					
						
						|  | ) | 
					
						
						|  | self.attention2 = Attention(attention_type, in_channels=out_channels) | 
					
						
						|  |  | 
					
						
						|  | def forward( | 
					
						
						|  | self, x: torch.Tensor, skip: Optional[torch.Tensor] = None | 
					
						
						|  | ) -> torch.Tensor: | 
					
						
						|  | if skip is not None: | 
					
						
						|  | h, w = skip.shape[2:] | 
					
						
						|  | x = F.interpolate(x, size=(h, w), mode="nearest") | 
					
						
						|  | x = torch.cat([x, skip], dim=1) | 
					
						
						|  | x = self.attention1(x) | 
					
						
						|  | else: | 
					
						
						|  | x = F.interpolate(x, scale_factor=(2, 2), mode="nearest") | 
					
						
						|  | x = self.conv1(x) | 
					
						
						|  | x = self.conv2(x) | 
					
						
						|  | x = self.attention2(x) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CenterBlock(nn.Sequential): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels: int, | 
					
						
						|  | out_channels: int, | 
					
						
						|  | norm_layer: str = "bn", | 
					
						
						|  | activation: str = "ReLU", | 
					
						
						|  | ): | 
					
						
						|  | conv1 = Conv2dAct( | 
					
						
						|  | in_channels, | 
					
						
						|  | out_channels, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | padding=1, | 
					
						
						|  | norm_layer=norm_layer, | 
					
						
						|  | activation=activation, | 
					
						
						|  | ) | 
					
						
						|  | conv2 = Conv2dAct( | 
					
						
						|  | out_channels, | 
					
						
						|  | out_channels, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | padding=1, | 
					
						
						|  | norm_layer=norm_layer, | 
					
						
						|  | activation=activation, | 
					
						
						|  | ) | 
					
						
						|  | super().__init__(conv1, conv2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class UnetDecoder(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | decoder_n_blocks: int, | 
					
						
						|  | decoder_channels: List[int], | 
					
						
						|  | encoder_channels: List[int], | 
					
						
						|  | decoder_center_block: bool = False, | 
					
						
						|  | decoder_norm_layer: str = "bn", | 
					
						
						|  | decoder_attention_type: Optional[str] = None, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.decoder_n_blocks = decoder_n_blocks | 
					
						
						|  | self.decoder_channels = decoder_channels | 
					
						
						|  | self.encoder_channels = encoder_channels | 
					
						
						|  | self.decoder_center_block = decoder_center_block | 
					
						
						|  | self.decoder_norm_layer = decoder_norm_layer | 
					
						
						|  | self.decoder_attention_type = decoder_attention_type | 
					
						
						|  |  | 
					
						
						|  | if self.decoder_n_blocks != len(self.decoder_channels): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( | 
					
						
						|  | self.decoder_n_blocks, len(self.decoder_channels) | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | encoder_channels = encoder_channels[::-1] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | head_channels = encoder_channels[0] | 
					
						
						|  | in_channels = [head_channels] + list(self.decoder_channels[:-1]) | 
					
						
						|  | skip_channels = list(encoder_channels[1:]) + [0] | 
					
						
						|  | out_channels = self.decoder_channels | 
					
						
						|  |  | 
					
						
						|  | if self.decoder_center_block: | 
					
						
						|  | self.center = CenterBlock( | 
					
						
						|  | head_channels, head_channels, norm_layer=self.decoder_norm_layer | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | self.center = nn.Identity() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | kwargs = dict( | 
					
						
						|  | norm_layer=self.decoder_norm_layer, | 
					
						
						|  | attention_type=self.decoder_attention_type, | 
					
						
						|  | ) | 
					
						
						|  | blocks = [ | 
					
						
						|  | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) | 
					
						
						|  | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) | 
					
						
						|  | ] | 
					
						
						|  | self.blocks = nn.ModuleList(blocks) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, features: List[torch.Tensor]) -> torch.Tensor: | 
					
						
						|  | features = features[::-1] | 
					
						
						|  |  | 
					
						
						|  | head = features[0] | 
					
						
						|  | skips = features[1:] | 
					
						
						|  |  | 
					
						
						|  | output = [self.center(head)] | 
					
						
						|  | for i, decoder_block in enumerate(self.blocks): | 
					
						
						|  | skip = skips[i] if i < len(skips) else None | 
					
						
						|  | output.append(decoder_block(output[-1], skip)) | 
					
						
						|  |  | 
					
						
						|  | return output | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class SegmentationHead(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels: int, | 
					
						
						|  | out_channels: int, | 
					
						
						|  | size: int, | 
					
						
						|  | kernel_size: int = 3, | 
					
						
						|  | dropout: float = 0.0, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.drop = nn.Dropout2d(p=dropout) | 
					
						
						|  | self.conv = nn.Conv2d( | 
					
						
						|  | in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 | 
					
						
						|  | ) | 
					
						
						|  | if isinstance(size, (tuple, list)): | 
					
						
						|  | self.up = nn.Upsample(size=size, mode="bilinear") | 
					
						
						|  | else: | 
					
						
						|  | self.up = nn.Identity() | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
					
						
						|  | return self.up(self.conv(self.drop(x))) | 
					
						
						|  |  |