Spaces:
Configuration error
Configuration error
| #https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| from custom_controlnet_aux.util import custom_torch_download | |
class UNet(nn.Module):
    """U-Net style segmentation network built on MobileNetV2 feature blocks.

    The encoder reuses feature blocks 0-16 of torchvision's MobileNetV2,
    initialised from the ``mobilenet_v2-b0353104.pth`` checkpoint. The decoder
    is five "upsample x2 + 3x3 conv" stages; every stage after the first
    consumes the previous decoder output concatenated (along the channel axis)
    with the encoder feature map of the matching level.

    The last stage ends in ``nn.Softmax2d``, so ``forward`` returns
    per-location class probabilities over ``NUM_SEG_CLASSES`` channels rather
    than raw logits.
    """

    def __init__(self):
        """Build the encoder from MobileNetV2 weights and create the decoder."""
        super().__init__()
        self.NUM_SEG_CLASSES = 7  # Background, hair, face, eye, mouth, skin, clothes

        # No constructor arguments: the default builds the architecture without
        # fetching weights, and avoids the `pretrained=` keyword that newer
        # torchvision releases deprecate in favour of `weights=`.
        mobilenet_v2 = torchvision.models.mobilenet_v2()

        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved on; load_state_dict copies the values into the
        # module's own parameters afterwards.
        # NOTE(review): torch.load unpickles the downloaded file. If the
        # minimum supported torch version allows it, consider passing
        # weights_only=True here.
        backbone_state = torch.load(
            custom_torch_download(filename="mobilenet_v2-b0353104.pth"),
            map_location="cpu",
        )
        mobilenet_v2.load_state_dict(backbone_state, strict=True)
        mob_blocks = mobilenet_v2.features

        # Encoder (channel counts as annotated by the upstream implementation)
        self.en_block0 = self._encoder_stage(mob_blocks, 0, 2)    # in_ch=3   out_ch=16
        self.en_block1 = self._encoder_stage(mob_blocks, 2, 4)    # in_ch=16  out_ch=24
        self.en_block2 = self._encoder_stage(mob_blocks, 4, 7)    # in_ch=24  out_ch=32
        self.en_block3 = self._encoder_stage(mob_blocks, 7, 14)   # in_ch=32  out_ch=96
        self.en_block4 = self._encoder_stage(mob_blocks, 14, 17)  # in_ch=96  out_ch=160

        # Decoder. Stages are created in the same order and with the same
        # layer layout as before, so state_dict keys (e.g. "de_block4.1.weight")
        # are unchanged. Inputs of de_block3..de_block0 are doubled because
        # they receive [decoder output, encoder skip] concatenated.
        self.de_block4 = self._decoder_stage(160, 96)     # in_ch=160  out_ch=96
        self.de_block3 = self._decoder_stage(96 * 2, 32)  # in_ch=96x2 out_ch=32
        self.de_block2 = self._decoder_stage(32 * 2, 24)  # in_ch=32x2 out_ch=24
        self.de_block1 = self._decoder_stage(24 * 2, 16)  # in_ch=24x2 out_ch=16
        self.de_block0 = nn.Sequential(                   # in_ch=16x2 out_ch=7
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(16 * 2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
            nn.Softmax2d(),
        )

    @staticmethod
    def _encoder_stage(features, start, stop):
        """Wrap ``features[start:stop]`` (stop exclusive) in a fresh ``nn.Sequential``."""
        return nn.Sequential(*[features[i] for i in range(start, stop)])

    @staticmethod
    def _decoder_stage(in_ch, out_ch):
        """Return one decoder stage: upsample x2, 3x3 conv, instance norm, LeakyReLU, dropout."""
        return nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2),
        )

    def forward(self, x):
        """Run the encoder/decoder and return class probabilities.

        Args:
            x: Input image batch. Assumed to be ``(N, 3, H, W)``; H and W
                presumably need to be multiples of 32 so that the five
                upsampled decoder outputs line up with the encoder skips for
                concatenation. TODO confirm against callers.

        Returns:
            Tensor with ``NUM_SEG_CLASSES`` channels produced by
            ``nn.Softmax2d`` (values sum to 1 over the channel axis).
        """
        # Encoder path: keep every intermediate map for the skip connections.
        e0 = self.en_block0(x)
        e1 = self.en_block1(e0)
        e2 = self.en_block2(e1)
        e3 = self.en_block3(e2)
        e4 = self.en_block4(e3)

        # Decoder path: decoder output first, encoder skip second, joined on
        # the channel axis (dim 1) before each following stage.
        d4 = self.de_block4(e4)
        d3 = self.de_block3(torch.cat((d4, e3), 1))
        d2 = self.de_block2(torch.cat((d3, e2), 1))
        d1 = self.de_block1(torch.cat((d2, e1), 1))
        return self.de_block0(torch.cat((d1, e0), 1))