import os
from typing import Dict
from typing import Optional
from typing import Tuple

import kornia
import lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from arcface_torch.backbones.iresnet import iresnet100
from configs.train_config import TrainConfig
from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
from HRNet.hrnet import HighResolutionNet
from models.discriminator import Discriminator
from models.gan_loss import GANLoss
from models.generator import Generator
from models.init_weight import init_net


class HifiFace:
    """Face-swapping model wrapper (generator + discriminator + frozen auxiliary nets).

    In training mode this class owns, besides the generator:
      - a discriminator with its GAN loss,
      - a frozen 3D face-reconstruction network (``f_3d``) and parametric
        face model used for the shape loss,
      - a frozen ArcFace recognition network (``f_id``) used for the id loss,
      - an LPIPS network for the perceptual loss,
      - optionally a frozen HRNet landmark-heatmap network for eye/mouth
        heatmap losses,
    and the two AdamW optimizers. In inference mode only the generator is
    created and moved to the device.
    """

    def __init__(
        self,
        identity_extractor_config,
        is_training=True,
        device="cpu",
        load_checkpoint: Optional[Tuple[str, int]] = None,
    ):
        # identity_extractor_config: dict with checkpoint paths
        # ("f_3d_checkpoint_path", "f_id_checkpoint_path", "bfm_folder",
        # "hrnet_path") consumed below; also forwarded to Generator.
        # load_checkpoint: optional (directory, index) pair passed to self.load().
        super(HifiFace, self).__init__()
        self.generator = Generator(identity_extractor_config)
        self.is_training = is_training
        if self.is_training:
            self.lr = TrainConfig().lr
            self.use_ddp = TrainConfig().use_ddp
            # Fall back to a large clip norm (effectively no clipping) when unset.
            self.grad_clip = TrainConfig().grad_clip if TrainConfig().grad_clip is not None else 100.0
            self.discriminator = init_net(Discriminator(3))
            self.l1_loss = nn.L1Loss()
            if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
                self.mse_loss = nn.MSELoss()
            self.loss_fn_vgg = lpips.LPIPS(net="vgg")
            self.adv_loss = GANLoss()

            # 3D face reconstruction model (frozen; used for the shape loss)
            self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
            self.f_3d.load_state_dict(
                torch.load(identity_extractor_config["f_3d_checkpoint_path"], map_location="cpu")["net_recon"]
            )
            self.f_3d.eval()
            self.face_model = ParametricFaceModel(bfm_folder=identity_extractor_config["bfm_folder"])
            self.face_model.to("cpu")

            # Face recognition model (frozen; used for the identity loss)
            self.f_id = iresnet100(pretrained=False, fp16=False)
            self.f_id.load_state_dict(torch.load(identity_extractor_config["f_id_checkpoint_path"], map_location="cpu"))
            self.f_id.eval()

            # Mouth/eye heatmap model (frozen HRNet landmark network),
            # only built when one of the heatmap losses is enabled.
            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth = HighResolutionNet()
                checkpoint = torch.load(identity_extractor_config["hrnet_path"], map_location="cpu")
                self.model_mouth.load_state_dict(checkpoint)
                self.model_mouth.eval()

            # Loss weights (see train_forward_generator for how each is applied).
            self.lambda_adv = 1
            self.lambda_seg = 100
            self.lambda_rec = 20
            self.lambda_cyc = 1
            self.lambda_lpips = 5
            self.lambda_shape = 0.5
            self.lambda_id = 5
            self.lambda_eye_hm = 10000.0
            self.lambda_mouth_hm = 10000.0

            # 5x5 structuring element for dilating the target mask
            # in train_forward_generator.
            self.dilation_kernel = torch.ones(5, 5)

        if load_checkpoint is not None:
            self.load(load_checkpoint[0], load_checkpoint[1])
        self.setup(device)

    def save(self, path, idx=None):
        """Save generator (and discriminator) state dicts under ``path``.

        When ``idx`` is given the files are suffixed with it
        (``generator_{idx}.pth``); otherwise plain names are used.
        Under DDP the wrapped ``.module`` state dict is saved so that the
        checkpoint can be loaded into an unwrapped model.
        """
        os.makedirs(path, exist_ok=True)
        if idx is None:
            g_path = os.path.join(path, "generator.pth")
            d_path = os.path.join(path, "discriminator.pth")
        else:
            g_path = os.path.join(path, f"generator_{idx}.pth")
            d_path = os.path.join(path, f"discriminator_{idx}.pth")
        if self.use_ddp:
            torch.save(self.generator.module.state_dict(), g_path)
            torch.save(self.discriminator.module.state_dict(), d_path)
        else:
            torch.save(self.generator.state_dict(), g_path)
            torch.save(self.discriminator.state_dict(), d_path)

    def load(self, path, idx=None):
        """Load checkpoints produced by :meth:`save` from ``path``.

        The discriminator is only loaded in training mode.
        """
        if idx is None:
            g_path = os.path.join(path, "generator.pth")
            d_path = os.path.join(path, "discriminator.pth")
        else:
            g_path = os.path.join(path, f"generator_{idx}.pth")
            d_path = os.path.join(path, f"discriminator_{idx}.pth")
        logger.info(f"Loading generator from {g_path}")
        self.generator.load_state_dict(torch.load(g_path, map_location="cpu"))
        if self.is_training:
            logger.info(f"Loading discriminator from {d_path}")
            self.discriminator.load_state_dict(torch.load(d_path, map_location="cpu"))

    def setup(self, device):
        """Move all sub-modules to ``device``, freeze the auxiliary nets,
        optionally wrap generator/discriminator in DDP, and build optimizers.
        """
        self.generator.to(device)
        if self.is_training:
            self.discriminator.to(device)
            self.l1_loss.to(device)
            if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
                self.mse_loss.to(device)
            self.f_3d.to(device)
            self.f_id.to(device)
            self.loss_fn_vgg.to(device)
            self.face_model.to(device)
            self.adv_loss.to(device)
            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth.to(device)
            # Auxiliary networks only provide loss signals; freeze them.
            self.f_3d.requires_grad_(False)
            self.f_id.requires_grad_(False)
            self.loss_fn_vgg.requires_grad_(False)
            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth.requires_grad_(False)
            self.dilation_kernel = self.dilation_kernel.to(device)
            if self.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP

                import torch.distributed as dist

                self.generator = DDP(self.generator, device_ids=[device])
                self.discriminator = DDP(self.discriminator, device_ids=[device])
                # NOTE(review): rank 0 dumps its weights to /tmp and every rank
                # reloads them — presumably to force identical initial weights
                # across ranks. Fixed /tmp paths can collide between concurrent
                # jobs on the same host; verify this is acceptable.
                if dist.get_rank() == 0:
                    torch.save(self.generator.state_dict(), "/tmp/generator.pth")
                    torch.save(self.discriminator.state_dict(), "/tmp/discriminator.pth")
                dist.barrier()
                self.generator.load_state_dict(torch.load("/tmp/generator.pth", map_location=device))
                self.discriminator.load_state_dict(torch.load("/tmp/discriminator.pth", map_location=device))
            # beta1 = 0 (no first-moment momentum) for both optimizers.
            self.g_optimizer = torch.optim.AdamW(self.generator.parameters(), lr=self.lr, betas=[0, 0.999])
            self.d_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.lr, betas=[0, 0.999])

    def train(self):
        """Switch generator and discriminator to train mode."""
        self.generator.train()
        self.discriminator.train()
        # The id extractor is a frozen module and must stay in eval mode.
        if self.use_ddp:
            self.generator.module.id_extractor.eval()
        else:
            self.generator.id_extractor.eval()

    def eval(self):
        """Switch generator (and, in training mode, discriminator) to eval mode."""
        self.generator.eval()
        if self.is_training:
            self.discriminator.eval()

    def train_forward_generator(self, source_img, target_img, target_mask, same_id_mask):
        """
        Generator loss computation for one training step.

        Parameters:
        -----------
        source_img: torch.Tensor
        target_img: torch.Tensor
        target_mask: torch.Tensor, [B, 1, H, W]
        same_id_mask: torch.Tensor, [B, 1]

        Returns:
        --------
        source_img: torch.Tensor
        target_img: torch.Tensor
        i_cycle: torch.Tensor, cycle image
        i_r: torch.Tensor
        m_r: torch.Tensor
        loss: Dict[torch.Tensor], contain pairs of loss name and loss values
        """
        # Broadcast the per-sample same-id flag to [B, 1, 1, 1] so it can
        # gate image-space losses.
        same = same_id_mask.unsqueeze(-1).unsqueeze(-1)
        # i_r/m_r: full-res swapped image and mask; i_low/m_low: low-res (1/4) outputs.
        i_r, i_low, m_r, m_low = self.generator(source_img, target_img, need_id_grad=False)
        # Cycle pass: swap the target identity back onto the result.
        i_cycle, _, _, _ = self.generator(target_img, i_r, need_id_grad=True)
        d_r = self.discriminator(i_r)

        # SID Loss: shape loss + id loss
        # NOTE(review): only the coefficients/embeddings of the *inputs*
        # (source/target) are computed under no_grad; those of the generated
        # images (c_r, c_low, q_r, q_low, v_id_i_r, v_id_i_low) stay on the
        # graph so loss_shape/loss_id can backpropagate into the generator.
        with torch.no_grad():
            c_s = self.f_3d(F.interpolate(source_img, size=224, mode="bilinear"))
            c_t = self.f_3d(F.interpolate(target_img, size=224, mode="bilinear"))
        c_r = self.f_3d(F.interpolate(i_r, size=224, mode="bilinear"))
        c_low = self.f_3d(F.interpolate(i_low, size=224, mode="bilinear"))
        with torch.no_grad():
            # Fuse source identity coefficients (first 80) with the target's
            # remaining coefficients (expression/pose/etc.) as the shape target.
            c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
            _, _, _, q_fuse = self.face_model.compute_for_render(c_fuse)
        _, _, _, q_r = self.face_model.compute_for_render(c_r)
        _, _, _, q_low = self.face_model.compute_for_render(c_low)
        with torch.no_grad():
            # ArcFace expects 112x112 inputs normalized to [-1, 1].
            v_id_i_s = F.normalize(
                self.f_id(F.interpolate((source_img - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2
            )
        v_id_i_r = F.normalize(self.f_id(F.interpolate((i_r - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
        v_id_i_low = F.normalize(self.f_id(F.interpolate((i_low - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)

        loss_shape = self.l1_loss(q_fuse, q_r) + self.l1_loss(q_fuse, q_low)
        # Clamp to keep an occasional bad 3D fit from destabilizing training.
        loss_shape = torch.clamp(loss_shape, min=0.0, max=10.0)
        # Cosine similarity between source and generated id embeddings
        # (embeddings are L2-normalized above, so the dot product is cosine).
        inner_product_r = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_r.unsqueeze(2)).squeeze()
        inner_product_low = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_low.unsqueeze(2)).squeeze()
        loss_id = self.l1_loss(torch.ones_like(inner_product_r), inner_product_r) + self.l1_loss(
            torch.ones_like(inner_product_low), inner_product_low
        )
        loss_sid = self.lambda_shape * loss_shape + self.lambda_id * loss_id

        # Realism Loss: segmentation loss + reconstruction loss + cycle loss + perceptual loss + adversarial loss
        loss_cycle = self.l1_loss(target_img, i_cycle)

        # Dilate the target mask before comparing against the predicted masks.
        target_mask = kornia.morphology.dilation(target_mask, self.dilation_kernel)
        loss_segmentation = self.l1_loss(
            F.interpolate(target_mask, scale_factor=0.25, mode="bilinear"), m_low
        ) + self.l1_loss(target_mask, m_r)

        # Reconstruction is only meaningful when source and target share an
        # identity, hence the `same` gating.
        loss_reconstruction = self.l1_loss(i_r * same, target_img * same) + self.l1_loss(
            i_low * same, F.interpolate(target_img, scale_factor=0.25, mode="bilinear") * same
        )

        loss_perceptual = self.loss_fn_vgg(target_img * same, i_r * same).mean()
        loss_adversarial = self.adv_loss(d_r, True, for_discriminator=False)
        loss_realism = (
            self.lambda_adv * loss_adversarial
            + self.lambda_seg * loss_segmentation
            + self.lambda_rec * loss_reconstruction
            + self.lambda_cyc * loss_cycle
            + self.lambda_lpips * loss_perceptual
        )

        # Eye heatmap loss
        loss_eye_hm = 0
        # Mouth heatmap loss
        loss_mouth_hm = 0
        if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
            target_hm = self.model_mouth(target_img)
            r_hm = self.model_mouth(i_r)
            if TrainConfig().eye_hm_loss:
                # NOTE(review): channels 96:98 assumed to be the eye-center
                # heatmaps of a 98-landmark (WFLW-style) HRNet — TODO confirm.
                target_eye_hm = target_hm[:, 96:98, :, :]
                r_eye_hm = r_hm[:, 96:98, :, :]
                loss_eye_hm = self.mse_loss(r_eye_hm, target_eye_hm)
                loss_realism = loss_realism + self.lambda_eye_hm * loss_eye_hm
            if TrainConfig().mouth_hm_loss:
                # NOTE(review): channels 76:96 assumed to be the mouth
                # landmark heatmaps — TODO confirm against the HRNet config.
                target_mouth_hm = target_hm[:, 76:96, :, :]
                r_mouth_hm = r_hm[:, 76:96, :, :]
                loss_mouth_hm = self.mse_loss(r_mouth_hm, target_mouth_hm)
                loss_realism = loss_realism + self.lambda_mouth_hm * loss_mouth_hm

        loss_generator = loss_sid + loss_realism
        loss_dict = {
            "loss_shape": loss_shape,
            "loss_id": loss_id,
            "loss_sid": loss_sid,
            "loss_cycle": loss_cycle,
            "loss_segmentation": loss_segmentation,
            "loss_reconstruction": loss_reconstruction,
            "loss_perceptual": loss_perceptual,
            "loss_adversarial": loss_adversarial,
            "loss_realism": loss_realism,
            "loss_generator": loss_generator,
        }
        if TrainConfig().eye_hm_loss:
            loss_dict.update({"loss_eye_hm": loss_eye_hm})
        if TrainConfig().mouth_hm_loss:
            loss_dict.update({"loss_mouth_hm": loss_mouth_hm})
        return (
            source_img,
            target_img,
            i_cycle.detach(),
            i_r.detach(),
            m_r.detach(),
            loss_dict,
        )

    def train_forward_discriminator(self, target_img, i_r):
        """
        Discriminator loss computation for one training step.

        Parameters:
        -----------
        target_img: torch.Tensor, target face image (real sample)
        i_r: torch.Tensor, swapped face (fake sample; detached here)

        Returns:
        --------
        Dict[str]: contains pair of loss name and loss values
        """
        d_gt = self.discriminator(target_img)
        # Detach so the generator receives no gradient from the D update.
        d_fake = self.discriminator(i_r.detach())
        loss_real = self.adv_loss(d_gt, True)
        loss_fake = self.adv_loss(d_fake, False)
        # alpha = torch.rand(target_img.shape[0], 1, 1, 1).to(target_img.device)
        # x_hat = (alpha * target_img.data + (1 - alpha) * i_r.data).requires_grad_(True)
        # out = self.discriminator(x_hat)
        # loss_gp = gradient_penalty(out, x_hat)
        loss_discriminator = loss_real + loss_fake  # + 10 * loss_gp
        return {
            "loss_real": loss_real,
            "loss_fake": loss_fake,
            # "loss_gp": loss_gp,
            "loss_discriminator": loss_discriminator,
        }

    def forward(
        self, source_img: torch.Tensor, target_img: torch.Tensor, shape_rate=None, id_rate=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inference-time face swap.

        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        *_rate: interpolation coefficients; when either is given the
            generator's interpolation path is used (missing one defaults to 1.0)

        Returns:
        --------
        i_r: torch.Tensor, swapped result
        m_r: torch.Tensor, predicted face mask
        """
        if shape_rate is None and id_rate is None:
            i_r, _, m_r, _ = self.generator(source_img, target_img)
        else:
            if shape_rate is None:
                shape_rate = 1.0
            if id_rate is None:
                id_rate = 1.0
            i_r, _, m_r, _ = self.generator.interp(source_img, target_img, shape_rate, id_rate)
        return i_r, m_r

    def optimize(
        self,
        source_img: torch.Tensor,
        target_img: torch.Tensor,
        target_mask: torch.Tensor,
        same_id_mask: torch.Tensor,
    ) -> Tuple[Dict, Dict[str, torch.Tensor]]:
        """
        Run one full training iteration (generator step, then discriminator
        step) and return the losses plus intermediate images.

        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        target_mask: torch.Tensor, target face mask
        same_id_mask: torch.Tensor, flags whether source and target are the same person

        Returns:
        --------
        Tuple[Dict, Dict[str, torch.Tensor]]:
            loss dict (generator + discriminator losses and gradient norms),
            and a dict of images: source/target face, swapped face,
            predicted mask, cycle face
        """
        src_img, tgt_img, i_cycle, i_r, m_r, loss_G_dict = self.train_forward_generator(
            source_img, target_img, target_mask, same_id_mask
        )
        loss_G = loss_G_dict["loss_generator"]
        self.g_optimizer.zero_grad()
        loss_G.backward()
        global_norm_G = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
        self.g_optimizer.step()

        # Discriminator step uses the detached i_r returned above.
        loss_D_dict = self.train_forward_discriminator(tgt_img, i_r)
        loss_D = loss_D_dict["loss_discriminator"]
        self.d_optimizer.zero_grad()
        loss_D.backward()
        global_norm_D = torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
        self.d_optimizer.step()

        total_loss_dict = {"global_norm_G": global_norm_G, "global_norm_D": global_norm_D}
        total_loss_dict.update(loss_G_dict)
        total_loss_dict.update(loss_D_dict)
        return total_loss_dict, {
            "source face": src_img,
            "target face": tgt_img,
            "swapped face": torch.clamp(i_r, min=0.0, max=1.0),
            "pred face mask": m_r,
            "cycle face": i_cycle,
        }


if __name__ == "__main__":
    import torch

    import cv2

    from configs.train_config import TrainConfig

    identity_extractor_config = TrainConfig().identity_extractor_config
    model = HifiFace(identity_extractor_config, is_training=True)
    # src = cv2.imread("/home/xuehongyang/data/test1.jpg")
    # tgt = cv2.imread("/home/xuehongyang/data/test2.jpg")
    # src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
    # tgt = cv2.cvtColor(tgt, cv2.COLOR_BGR2RGB)
    # src = cv2.resize(src, (256, 256))
    # tgt = cv2.resize(tgt, (256, 256))
    # src = src.transpose(2, 0, 1)[None, ...]
    # tgt = tgt.transpose(2, 0, 1)[None, ...]
    # source_img = torch.from_numpy(src).float() / 255.0
    # target_img = torch.from_numpy(tgt).float() / 255.0
    # same_id_mask = torch.Tensor([1]).unsqueeze(0)
    # tgt_mask = target_img[:, 0, :, :].unsqueeze(1)
    # if torch.cuda.is_available():
    #     model.to("cuda:3")
    #     source_img = source_img.to("cuda:3")
    #     target_img = target_img.to("cuda:3")
    #     tgt_mask = tgt_mask.to("cuda:3")
    #     same_id_mask = same_id_mask.to("cuda:3")
    # source_img = source_img.repeat(16, 1, 1, 1)
    # target_img = target_img.repeat(16, 1, 1, 1)
    # tgt_mask = tgt_mask.repeat(16, 1, 1, 1)
    # same_id_mask = same_id_mask.repeat(16, 1)
    # while True:
    #     x = model.optimize(source_img, target_img, tgt_mask, same_id_mask)
    #     print(x[0]["loss_generator"])