Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from models.init_weight import init_net | |
| from models.model_blocks import AdaInResBlock | |
| from models.model_blocks import ResBlock | |
| from models.semantic_face_fusion_model import SemanticFaceFusionModule | |
| from models.shape_aware_identity_model import ShapeAwareIdentityExtractor | |
class Encoder(nn.Module):
    """
    Hififace encoder part.

    A 3x3 convolution (3 -> 64 channels) followed by a stack of ``ResBlock``
    stages. Stage ``i`` maps ``channel_list[i]`` to ``channel_list[i + 1]``
    channels and receives ``down_sample[i]`` as its ``down_sample`` argument.
    """

    def __init__(self):
        super().__init__()
        self.conv_first = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.channel_list = [64, 128, 256, 512, 512, 512, 512, 512]
        self.down_sample = [True, True, True, True, True, False, False]
        # Derive the stages from the configuration lists instead of a
        # hard-coded count, so the lists are the single source of truth.
        stage_specs = zip(self.channel_list, self.channel_list[1:], self.down_sample)
        self.block_list = nn.ModuleList(
            [
                ResBlock(in_channels, out_channels, down_sample=use_down_sample)
                for in_channels, out_channels, use_down_sample in stage_specs
            ]
        )

    def forward(self, x):
        """
        Parameters:
        -----------
        x: input fed to ``conv_first`` (which expects 3 input channels)

        Returns:
        --------
        z_enc: output of the second block (index 1); ``None`` if the stack
            holds fewer than two blocks
        x: output of the last block
        """
        x = self.conv_first(x)
        z_enc = None
        # Iterate over the registered blocks themselves so this loop can never
        # disagree with the number of blocks built in __init__.
        for index, block in enumerate(self.block_list):
            x = block(x)
            if index == 1:
                # Intermediate feature map returned alongside the final one.
                z_enc = x
        return z_enc, x
class Decoder(nn.Module):
    """
    Hififace decoder part.

    A stack of ``AdaInResBlock`` stages. Stage ``i`` maps ``channel_list[i]``
    to ``channel_list[i + 1]`` channels and receives ``up_sample[i]`` as its
    ``up_sample`` argument.
    """

    def __init__(self):
        super().__init__()
        self.block_list = nn.ModuleList()
        self.channel_list = [512, 512, 512, 512, 512, 256]
        self.up_sample = [False, False, True, True, True]
        # Build one block per (in, out, up_sample) triple rather than relying
        # on a hard-coded stage count.
        stage_specs = zip(self.channel_list, self.channel_list[1:], self.up_sample)
        for in_channels, out_channels, use_up_sample in stage_specs:
            self.block_list.append(
                AdaInResBlock(in_channels, out_channels, up_sample=use_up_sample)
            )

    def forward(self, x, id_vector):
        """
        Parameters:
        -----------
        x: encoder encoded feature map
        id_vector: 3d shape aware identity vector

        Returns:
        --------
        z_dec
        """
        # Every block receives the same id_vector; iterate the ModuleList
        # directly so the loop always matches the blocks actually registered.
        for block in self.block_list:
            x = block(x, id_vector)
        return x
class Generator(nn.Module):
    """
    Hififace Generator.

    Combines an identity extractor, an encoder, a decoder and a semantic face
    fusion module into one network.
    """

    def __init__(self, identity_extractor_config):
        """
        Parameters:
        -----------
        identity_extractor_config: configuration forwarded unchanged to
            ``ShapeAwareIdentityExtractor``
        """
        super().__init__()
        self.id_extractor = ShapeAwareIdentityExtractor(identity_extractor_config)
        # Turn off requires_grad on the extractor so it is not trained here.
        self.id_extractor.requires_grad_(False)
        # NOTE(review): init_net presumably applies weight initialisation and
        # returns the module -- confirm against models.init_weight.
        self.encoder = init_net(Encoder())
        self.decoder = init_net(Decoder())
        self.sff_module = init_net(SemanticFaceFusionModule())

    def _generate(self, i_target, shape_aware_id_vector):
        """Run encoder -> decoder -> fusion for an already computed identity vector."""
        z_enc, x = self.encoder(i_target)
        z_dec = self.decoder(x, shape_aware_id_vector)
        i_r, i_low, m_r, m_low = self.sff_module(i_target, z_enc, z_dec, shape_aware_id_vector)
        return i_r, i_low, m_r, m_low

    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """
        Same pipeline as ``forward``, but the identity vector comes from
        ``id_extractor.interp``.

        Parameters:
        -----------
        i_source: source face image
        i_target: target face image
        shape_rate: forwarded to ``id_extractor.interp``
        id_rate: forwarded to ``id_extractor.interp``

        Returns:
        --------
        i_r, i_low, m_r, m_low: the four outputs of the fusion module
        """
        shape_aware_id_vector = self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)
        return self._generate(i_target, shape_aware_id_vector)

    def forward(self, i_source, i_target, need_id_grad=False):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
        need_id_grad: bool, whether to calculate id extractor module's gradient

        Returns:
        --------
        i_r: torch.Tensor
        i_low: torch.Tensor
        m_r: torch.Tensor
        m_low: torch.Tensor
        """
        if need_id_grad:
            # Leave the ambient grad mode untouched so gradients can flow
            # through the identity extractor.
            shape_aware_id_vector = self.id_extractor(i_source, i_target)
        else:
            with torch.no_grad():
                shape_aware_id_vector = self.id_extractor(i_source, i_target)
        return self._generate(i_target, shape_aware_id_vector)