# HiFiFace-inference-demo / models/semantic_face_fusion_model.py
# xuehongyang
# ser
# 83d8d3c
import torch.nn as nn
import torch.nn.functional as F
from models.model_blocks import AdaInResBlock
from models.model_blocks import ResBlock
from models.model_blocks import UpSamplingBlock
class SemanticFaceFusionModule(nn.Module):
    """Semantic Face Fusion module.

    Blends low-level decoder and encoder features under a predicted soft face
    mask, so that lighting and background from the target image are preserved
    outside the face region.
    """

    def __init__(self):
        super().__init__()
        # Refines the low-level encoder feature map (256 -> 256 channels).
        self.sigma = ResBlock(256, 256)
        # Predicts a single-channel soft mask in [0, 1] from the decoder features.
        self.low_mask_predict = nn.Sequential(nn.Conv2d(256, 1, 3, 1, 1), nn.Sigmoid())
        # Two AdaIN residual blocks that condition the fused features on v_sid.
        self.z_fuse_block_1 = AdaInResBlock(256, 256)
        self.z_fuse_block_2 = AdaInResBlock(256, 256)
        # Projects the fused features to a 3-channel low-resolution image.
        # NOTE(review): the LeakyReLU is in-place, so calling this block mutates
        # its input tensor (see forward). The Sequential layout determines the
        # state_dict keys ("i_low_block.1.*"); keep it unchanged for checkpoints.
        self.i_low_block = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 3, 3, 1, 1),
        )
        # Produces the output image and its mask from the fused features.
        self.f_up = UpSamplingBlock()

    def forward(self, target_image, z_enc, z_dec, v_sid):
        """Fuse low-level features and produce the re-targeted images and masks.

        Parameters:
        ----------
        target_image: target face image; (B, 3, H, W) in the __main__ demo.
        z_enc: low-level encoder feature map at 1/4 of the image resolution,
            256 channels.
        z_dec: low-level decoder feature map at 1/4 of the image resolution,
            256 channels.
        v_sid: the 3D shape-aware identity vector; assumed (B, D) with D as
            expected by AdaInResBlock (the demo uses D=769) -- TODO confirm.

        Returns:
        --------
        i_r: re-targeted image, blended with target_image under m_r.
        i_low: 1/4-size re-targeted image, blended with the downsampled
            target_image under m_low.
        m_r: face mask produced by the upsampling block.
        m_low: 1/4-size face mask in [0, 1] (sigmoid output).
        """
        z_enc = self.sigma(z_enc)
        # Estimate the face mask for the low-level decoder features.
        m_low = self.low_mask_predict(z_dec)
        # Fused low-level feature map: decoder features inside the mask,
        # encoder features outside it.
        z_fuse = m_low * z_dec + (1 - m_low) * z_enc
        z_fuse = self.z_fuse_block_1(z_fuse, v_sid)
        z_fuse = self.z_fuse_block_2(z_fuse, v_sid)

        # NOTE(review): i_low_block starts with LeakyReLU(inplace=True), so this
        # call also applies the activation to z_fuse itself; f_up below therefore
        # receives the activated features. Pretrained weights may depend on this
        # behaviour, so it is documented here rather than changed.
        i_low = self.i_low_block(z_fuse)
        # Downsample the target to exactly the mask's spatial size so the blend
        # always lines up; for feature maps exactly 1/4 of the image size this
        # equals nearest-neighbour interpolation with scale_factor=0.25.
        target_low = F.interpolate(target_image, size=m_low.shape[-2:])
        i_low = m_low * i_low + (1 - m_low) * target_low

        i_r, m_r = self.f_up(z_fuse)
        # Outside the predicted face mask, keep the original target pixels.
        i_r = m_r * i_r + (1 - m_r) * target_image
        return i_r, i_low, m_r, m_low
if __name__ == "__main__":
    import torch

    # Smoke test: one forward pass on random inputs, then print output shapes.
    batch = 1
    target = torch.randn(batch, 3, 256, 256)
    enc_feat = torch.randn(batch, 256, 64, 64)
    dec_feat = torch.randn(batch, 256, 64, 64)
    identity = torch.randn(batch, 769)

    fusion = SemanticFaceFusionModule()
    outputs = fusion(target, enc_feat, dec_feat, identity)
    # Order: i_r, i_low, m_r, m_low.
    print(*(out.shape for out in outputs))