import torch
from torch import nn
from torch.nn import functional as F


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)


class Conv2d_res(nn.Module):
    # TensorRT does not support Python 'if' statements in the traced graph, so we define a
    # separate Conv2d_res class for residual blocks instead of branching inside Conv2d.
    def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        out += x  # residual connection; requires cin == cout and stride 1
        return self.act(out)


class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)


class FETE_model(nn.Module):
    def __init__(self):
        super(FETE_model, self).__init__()

        self.face_encoder_blocks = nn.ModuleList(
            [
                nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=2, padding=3)),  # 256,256 -> 128,128
                nn.Sequential(
                    Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 64,64
                    Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 32,32
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 16,16
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 8,8
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 4,4
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(512, 512, kernel_size=3, stride=2, padding=0),  # 1,1
                    Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
                ),
            ]
        )
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
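        # audio_encoder: a (B, 1, 80, 16) mel input (see the shape note in forward)
        # is reduced to a (B, 512, 1, 1) audio embedding.
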
        self.pose_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
        self.emotion_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=7, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
        self.blink_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=1, stride=(1, 2), padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
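        # blink_encoder: none of the layers above stride along the height axis, so the
        # blink input needs height 1 for this stack to reach a (B, 512, 1, 1) embedding.
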
        self.face_decoder_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    Conv2d(2048, 512, kernel_size=1, stride=1, padding=0),
                ),
                nn.Sequential(
                    Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),  # 4,4
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),  # 8,8
                    Self_Attention(512, 512),
                ),
                nn.Sequential(
                    Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),  # 16,16
                    Self_Attention(384, 384),
                ),
                nn.Sequential(
                    Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),  # 32,32
                    Self_Attention(256, 256),
                ),
                nn.Sequential(
                    Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                ),  # 64,64
                nn.Sequential(
                    Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                ),  # 128,128
            ]
        )
        # self.output_block = nn.Sequential(
        #     Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
        #     nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
        #     nn.Sigmoid(),
        # )
        self.output_block = nn.Sequential(
            Conv2dTranspose(80, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )
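        # output_block: upsamples the 128x128 decoder output to 256x256 and maps it to a
        # 3-channel image with a sigmoid.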

    def forward(
        self,
        face_sequences,
        audio_sequences,
        pose_sequences,
        emotion_sequences,
        blink_sequences,
    ):
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        # disabled for inference
        # input_dim_size = len(face_sequences.size())
        # if input_dim_size > 4:
        #     audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
        #     pose_sequences = torch.cat([pose_sequences[:, i] for i in range(pose_sequences.size(1))], dim=0)
        #     emotion_sequences = torch.cat([emotion_sequences[:, i] for i in range(emotion_sequences.size(1))], dim=0)
        #     blink_sequences = torch.cat([blink_sequences[:, i] for i in range(blink_sequences.size(1))], dim=0)
        #     face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
        # print(audio_sequences.size(), face_sequences.size(), pose_sequences.size(), emotion_sequences.size())

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
        pose_embedding = self.pose_encoder(pose_sequences)  # B, 512, 1, 1
        emotion_embedding = self.emotion_encoder(emotion_sequences)  # B, 512, 1, 1
        blink_embedding = self.blink_encoder(blink_sequences)  # B, 512, 1, 1
        inputs_embedding = torch.cat(
            (audio_embedding, pose_embedding, emotion_embedding, blink_embedding), dim=1
        )  # B, 2048, 1, 1
        # print(audio_embedding.size(), pose_embedding.size(), emotion_embedding.size(), inputs_embedding.size())
        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            # print(x.shape)
            feats.append(x)

        x = inputs_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            # print(x.shape)
            # try:
            x = torch.cat((x, feats[-1]), dim=1)
            # except Exception as e:
            #     print(x.size())
            #     print(feats[-1].size())
            #     raise e
            feats.pop()

        x = self.output_block(x)

        # if input_dim_size > 4:
        #     x = torch.split(x, B, dim=0)  # [(B, C, H, W)]
        #     outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)
        # else:
        outputs = x

        return outputs


class Self_Attention(nn.Module):
    """
    Source-Reference Attention Layer
    """

    def __init__(self, in_planes_s, in_planes_r):
        """
        Parameters
        ----------
        in_planes_s: int
            Number of input source feature vector channels.
        in_planes_r: int
            Number of input reference feature vector channels.
        """
        super(Self_Attention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_planes_s, out_channels=in_planes_s // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, source):
        """
        Parameters
        ----------
        source : torch.Tensor
            Source feature maps (B x Cs x Hs x Ws). The reference is the source
            itself (self-attention).

        Returns
        -------
        torch.Tensor
            Source features with the source-reference attention value added
            (same shape as the input). The internal attention map has shape
            (B x Ns x Nr) with Ns = Hs * Ws and Nr = Hr * Wr.
        """
        # Run the attention in float32 and cast back to half at the end if the input was half.
        is_half = source.dtype == torch.float16
        source = source.float() if is_half else source
        reference = source

        s_batchsize, sC, sH, sW = source.size()
        r_batchsize, rC, rH, rW = reference.size()

        proj_query = self.query_conv(source).view(s_batchsize, -1, sH * sW).permute(0, 2, 1)
        proj_key = self.key_conv(reference).view(r_batchsize, -1, rH * rW)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(reference).view(r_batchsize, -1, rH * rW)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(s_batchsize, sC, sH, sW)
        out = self.gamma * out + source
        return out.half() if is_half else out
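

# The snippet below is a minimal smoke test, not part of the original model code.
# The face and mel shapes follow the comments above ((B, 6, 256, 256) faces and
# (B, 1, 80, 16) mel chunks); the pose/emotion/blink shapes are assumptions chosen
# only so that each auxiliary encoder reduces to (B, 512, 1, 1). The real data
# pipeline may use different layouts.
if __name__ == "__main__":
    model = FETE_model().eval()  # eval mode: BatchNorm uses running stats, works with batch size 1
    B = 1
    face = torch.randn(B, 6, 256, 256)   # 6-channel, 256x256 face input as in the first encoder block
    audio = torch.randn(B, 1, 80, 16)    # mel-spectrogram chunk, per the shape note in forward
    pose = torch.randn(B, 1, 4, 16)      # assumed shape; reaches (B, 512, 1, 1)
    emotion = torch.randn(B, 1, 8, 20)   # assumed shape; reaches (B, 512, 1, 1)
    blink = torch.randn(B, 1, 1, 16)     # assumed shape; height must be 1 (see blink_encoder note)
    with torch.no_grad():
        out = model(face, audio, pose, emotion, blink)
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])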