File size: 2,361 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch.nn as nn
import torch.nn.functional as F

from models.model_blocks import AdaInResBlock
from models.model_blocks import ResBlock
from models.model_blocks import UpSamplingBlock


class SemanticFaceFusionModule(nn.Module):
    """Semantic face fusion module.

    Blends low-level encoder and decoder feature maps with a predicted
    face mask so that lighting and background from the target image are
    preserved in the re-targeted output.
    """

    def __init__(self):
        super(SemanticFaceFusionModule, self).__init__()

        # Refines the encoder's low-level feature map before fusion.
        self.sigma = ResBlock(256, 256)
        # Predicts a soft (sigmoid) face mask from the decoder features.
        self.low_mask_predict = nn.Sequential(nn.Conv2d(256, 1, 3, 1, 1), nn.Sigmoid())
        # Identity-conditioned (AdaIN) refinement of the fused features.
        self.z_fuse_block_1 = AdaInResBlock(256, 256)
        self.z_fuse_block_2 = AdaInResBlock(256, 256)

        # Projects fused features to a 3-channel low-resolution image.
        self.i_low_block = nn.Sequential(nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 3, 3, 1, 1))

        # Upsamples fused features to full resolution (image + mask).
        self.f_up = UpSamplingBlock()

    def forward(self, target_image, z_enc, z_dec, v_sid):
        """Fuse encoder/decoder features and composite onto the target.

        Parameters
        ----------
        target_image: the target face image
        z_enc: low-level encoder feature map at 1/4 of the input resolution
        z_dec: low-level decoder feature map at 1/4 of the input resolution
        v_sid: the 3D shape aware identity vector

        Returns
        -------
        i_r: re-target image
        i_low: 1/4 size retarget image
        m_r: face mask
        m_low: 1/4 size face mask
        """
        refined_enc = self.sigma(z_enc)

        # Soft face mask estimated from the decoder's low-level features.
        m_low = self.low_mask_predict(z_dec)

        # Fused low-level features: decoder features inside the mask,
        # encoder features outside it.
        fused = m_low * z_dec + (1 - m_low) * refined_enc

        # Identity-aware refinement, conditioned on v_sid.
        for adain_block in (self.z_fuse_block_1, self.z_fuse_block_2):
            fused = adain_block(fused, v_sid)

        # Low-resolution image, composited with a 1/4-scale target so the
        # background outside the mask comes straight from the target.
        i_low = self.i_low_block(fused)
        shrunk_target = F.interpolate(target_image, scale_factor=0.25)
        i_low = m_low * i_low + (1 - m_low) * shrunk_target

        # Full-resolution image and mask, composited the same way.
        i_r, m_r = self.f_up(fused)
        i_r = m_r * i_r + (1 - m_r) * target_image

        return i_r, i_low, m_r, m_low


if __name__ == "__main__":
    import torch

    # Smoke test: run random inputs through the module and show the
    # output shapes (full/quarter-resolution images and masks).
    timg = torch.randn(1, 3, 256, 256)       # target image
    z_enc = torch.randn(1, 256, 64, 64)      # encoder features (1/4 res)
    z_dec = torch.randn(1, 256, 64, 64)      # decoder features (1/4 res)
    v_sid = torch.randn(1, 769)              # identity vector

    model = SemanticFaceFusionModule()
    outputs = model(timg, z_enc, z_dec, v_sid)
    print(*(t.shape for t in outputs))