Spaces:
Build error
Build error
| import os | |
| from typing import Dict | |
| from typing import Optional | |
| from typing import Tuple | |
| import kornia | |
| import lpips | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from loguru import logger | |
| from arcface_torch.backbones.iresnet import iresnet100 | |
| from configs.train_config import TrainConfig | |
| from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel | |
| from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper | |
| from HRNet.hrnet import HighResolutionNet | |
| from models.discriminator import Discriminator | |
| from models.gan_loss import GANLoss | |
| from models.generator import Generator | |
| from models.init_weight import init_net | |
class HifiFace:
    """Training / inference wrapper around the HifiFace generator and discriminator.

    In training mode it also owns the frozen auxiliary networks used by the
    losses (3D face reconstruction model + parametric face model, face
    recognition model, optional HRNet heatmap model), the loss modules and
    the two AdamW optimizers.

    Parameters:
    -----------
    identity_extractor_config: config mapping forwarded to ``Generator``; in
        training mode the keys "f_3d_checkpoint_path", "bfm_folder",
        "f_id_checkpoint_path" and (when a heatmap loss is enabled)
        "hrnet_path" are also read here.
    is_training: when True, build the discriminator, losses and optimizers.
    device: device every network is moved to in ``setup``.
    load_checkpoint: optional ``(directory, idx)`` pair passed to ``load``.
    """

    def __init__(
        self,
        identity_extractor_config,
        is_training=True,
        device="cpu",
        load_checkpoint: Optional[Tuple[str, int]] = None,
    ):
        super().__init__()
        self.generator = Generator(identity_extractor_config)
        self.is_training = is_training
        # Always defined so DDP-aware code paths also work in inference mode.
        self.use_ddp = False

        if self.is_training:
            cfg = TrainConfig()
            use_hm_loss = cfg.eye_hm_loss or cfg.mouth_hm_loss

            self.lr = cfg.lr
            self.use_ddp = cfg.use_ddp
            self.grad_clip = cfg.grad_clip if cfg.grad_clip is not None else 100.0
            self.discriminator = init_net(Discriminator(3))

            self.l1_loss = nn.L1Loss()
            if use_hm_loss:
                self.mse_loss = nn.MSELoss()
            self.loss_fn_vgg = lpips.LPIPS(net="vgg")
            self.adv_loss = GANLoss()

            # 3D face reconstruction model (frozen; feeds the shape loss)
            self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
            self.f_3d.load_state_dict(
                torch.load(identity_extractor_config["f_3d_checkpoint_path"], map_location="cpu")["net_recon"]
            )
            self.f_3d.eval()
            self.face_model = ParametricFaceModel(bfm_folder=identity_extractor_config["bfm_folder"])
            self.face_model.to("cpu")

            # Face recognition model (frozen; feeds the identity loss)
            self.f_id = iresnet100(pretrained=False, fp16=False)
            self.f_id.load_state_dict(torch.load(identity_extractor_config["f_id_checkpoint_path"], map_location="cpu"))
            self.f_id.eval()

            # Heatmap model used by the eye / mouth heatmap losses
            if use_hm_loss:
                self.model_mouth = HighResolutionNet()
                checkpoint = torch.load(identity_extractor_config["hrnet_path"], map_location="cpu")
                self.model_mouth.load_state_dict(checkpoint)
                self.model_mouth.eval()

            # Loss weights
            self.lambda_adv = 1
            self.lambda_seg = 100
            self.lambda_rec = 20
            self.lambda_cyc = 1
            self.lambda_lpips = 5
            self.lambda_shape = 0.5
            self.lambda_id = 5
            self.lambda_eye_hm = 10000.0
            self.lambda_mouth_hm = 10000.0

            # 5x5 structuring element used to dilate the target mask
            self.dilation_kernel = torch.ones(5, 5)

        if load_checkpoint is not None:
            self.load(load_checkpoint[0], load_checkpoint[1])

        self.setup(device)

    @staticmethod
    def _unwrap(module):
        """Return the inner module of a (Distributed)DataParallel wrapper, else *module* itself."""
        if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return module.module
        return module

    @staticmethod
    def _checkpoint_paths(path, idx=None):
        """Return ``(generator_path, discriminator_path)`` inside directory *path*.

        ``idx=None`` gives ``generator.pth`` / ``discriminator.pth``; otherwise
        the files are suffixed with ``_{idx}``.
        """
        suffix = "" if idx is None else f"_{idx}"
        return (
            os.path.join(path, f"generator{suffix}.pth"),
            os.path.join(path, f"discriminator{suffix}.pth"),
        )

    def save(self, path, idx=None):
        """Save the generator (and, in training mode, the discriminator) state dicts.

        State dicts are taken from the unwrapped modules, so checkpoints saved
        under DDP carry no ``module.`` key prefix and can be loaded by ``load``.
        """
        os.makedirs(path, exist_ok=True)
        g_path, d_path = self._checkpoint_paths(path, idx)
        torch.save(self._unwrap(self.generator).state_dict(), g_path)
        # The discriminator only exists in training mode.
        if self.is_training:
            torch.save(self._unwrap(self.discriminator).state_dict(), d_path)

    def load(self, path, idx=None):
        """Load generator (and, in training mode, discriminator) weights saved by ``save``.

        NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        g_path, d_path = self._checkpoint_paths(path, idx)
        logger.info(f"Loading generator from {g_path}")
        self._unwrap(self.generator).load_state_dict(torch.load(g_path, map_location="cpu"))
        if self.is_training:
            logger.info(f"Loading discriminator from {d_path}")
            self._unwrap(self.discriminator).load_state_dict(torch.load(d_path, map_location="cpu"))

    def setup(self, device):
        """Move networks to *device*; in training mode also freeze auxiliaries, wrap in DDP and build optimizers."""
        self.generator.to(device)
        if not self.is_training:
            return

        cfg = TrainConfig()
        use_hm_loss = cfg.eye_hm_loss or cfg.mouth_hm_loss

        self.discriminator.to(device)
        self.l1_loss.to(device)
        if use_hm_loss:
            self.mse_loss.to(device)
        self.f_3d.to(device)
        self.f_id.to(device)
        self.loss_fn_vgg.to(device)
        self.face_model.to(device)
        self.adv_loss.to(device)
        if use_hm_loss:
            self.model_mouth.to(device)

        # Auxiliary networks are never trained.
        self.f_3d.requires_grad_(False)
        self.f_id.requires_grad_(False)
        self.loss_fn_vgg.requires_grad_(False)
        if use_hm_loss:
            self.model_mouth.requires_grad_(False)

        self.dilation_kernel = self.dilation_kernel.to(device)

        if self.use_ddp:
            from torch.nn.parallel import DistributedDataParallel as DDP
            import torch.distributed as dist

            self.generator = DDP(self.generator, device_ids=[device])
            self.discriminator = DDP(self.discriminator, device_ids=[device])
            # Rank 0 writes its weights and every rank reloads them so all replicas start identical.
            # NOTE(review): DDP already broadcasts rank-0 parameters at construction, so this looks
            # redundant; and the fixed /tmp paths collide if several jobs share a node.
            if dist.get_rank() == 0:
                torch.save(self.generator.state_dict(), "/tmp/generator.pth")
                torch.save(self.discriminator.state_dict(), "/tmp/discriminator.pth")
            dist.barrier()
            self.generator.load_state_dict(torch.load("/tmp/generator.pth", map_location=device))
            self.discriminator.load_state_dict(torch.load("/tmp/discriminator.pth", map_location=device))

        self.g_optimizer = torch.optim.AdamW(self.generator.parameters(), lr=self.lr, betas=[0, 0.999])
        self.d_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.lr, betas=[0, 0.999])

    def train(self):
        """Put the generator (and discriminator) in train mode, keeping the id extractor in eval mode."""
        self.generator.train()
        if self.is_training:
            self.discriminator.train()
        # The whole id extractor is a module that is never trained.
        self._unwrap(self.generator).id_extractor.eval()

    def eval(self):
        """Put the generator (and, in training mode, the discriminator) in eval mode."""
        self.generator.eval()
        if self.is_training:
            self.discriminator.eval()

    def train_forward_generator(self, source_img, target_img, target_mask, same_id_mask):
        """
        Compute the generator losses for one training step.

        Parameters:
        -----------
        source_img: torch.Tensor
        target_img: torch.Tensor
        target_mask: torch.Tensor, [B, 1, H, W]
        same_id_mask: torch.Tensor, [B, 1]

        Returns:
        --------
        source_img: torch.Tensor
        target_img: torch.Tensor
        i_cycle: torch.Tensor, cycle image (detached)
        i_r: torch.Tensor, swapped result (detached)
        m_r: torch.Tensor, predicted mask (detached)
        loss: Dict[str, torch.Tensor], pairs of loss name and loss value
        """
        cfg = TrainConfig()
        # [B, 1] -> [B, 1, 1, 1] so the mask broadcasts over image tensors
        same = same_id_mask.unsqueeze(-1).unsqueeze(-1)

        i_r, i_low, m_r, m_low = self.generator(source_img, target_img, need_id_grad=False)
        i_cycle, _, _, _ = self.generator(target_img, i_r, need_id_grad=True)
        d_r = self.discriminator(i_r)

        # SID loss: shape loss + id loss.
        # Features of the real source/target images are constants; features of the
        # generated images (i_r, i_low) keep their graph so these losses reach the generator.
        with torch.no_grad():
            c_s = self.f_3d(F.interpolate(source_img, size=224, mode="bilinear"))
            c_t = self.f_3d(F.interpolate(target_img, size=224, mode="bilinear"))
        c_r = self.f_3d(F.interpolate(i_r, size=224, mode="bilinear"))
        c_low = self.f_3d(F.interpolate(i_low, size=224, mode="bilinear"))

        with torch.no_grad():
            # First 80 coefficients from the source, the rest from the target
            # (presumably identity vs. expression/pose coefficients — confirm against the BFM layout).
            c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
            _, _, _, q_fuse = self.face_model.compute_for_render(c_fuse)
        _, _, _, q_r = self.face_model.compute_for_render(c_r)
        _, _, _, q_low = self.face_model.compute_for_render(c_low)

        # f_id is fed images rescaled from [0, 1] to [-1, 1] at 112x112.
        with torch.no_grad():
            v_id_i_s = F.normalize(
                self.f_id(F.interpolate((source_img - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2
            )
        v_id_i_r = F.normalize(self.f_id(F.interpolate((i_r - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
        v_id_i_low = F.normalize(self.f_id(F.interpolate((i_low - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)

        loss_shape = self.l1_loss(q_fuse, q_r) + self.l1_loss(q_fuse, q_low)
        loss_shape = torch.clamp(loss_shape, min=0.0, max=10.0)

        # Row-wise dot product of L2-normalised embeddings = cosine similarity, shape [B].
        inner_product_r = (v_id_i_s * v_id_i_r).sum(dim=-1)
        inner_product_low = (v_id_i_s * v_id_i_low).sum(dim=-1)
        loss_id = self.l1_loss(torch.ones_like(inner_product_r), inner_product_r) + self.l1_loss(
            torch.ones_like(inner_product_low), inner_product_low
        )

        loss_sid = self.lambda_shape * loss_shape + self.lambda_id * loss_id

        # Realism loss: segmentation + reconstruction + cycle + perceptual + adversarial
        loss_cycle = self.l1_loss(target_img, i_cycle)

        # Dilate the target mask before comparing it with the predicted masks.
        target_mask = kornia.morphology.dilation(target_mask, self.dilation_kernel)
        loss_segmentation = self.l1_loss(
            F.interpolate(target_mask, scale_factor=0.25, mode="bilinear"), m_low
        ) + self.l1_loss(target_mask, m_r)

        # Reconstruction / perceptual terms only count for same-identity pairs.
        loss_reconstruction = self.l1_loss(i_r * same, target_img * same) + self.l1_loss(
            i_low * same, F.interpolate(target_img, scale_factor=0.25, mode="bilinear") * same
        )

        # NOTE(review): lpips.LPIPS expects inputs in [-1, 1] unless called with normalize=True;
        # the images here appear to be in [0, 1] (see the rescaling before f_id) — confirm.
        loss_perceptual = self.loss_fn_vgg(target_img * same, i_r * same).mean()

        loss_adversarial = self.adv_loss(d_r, True, for_discriminator=False)

        loss_realism = (
            self.lambda_adv * loss_adversarial
            + self.lambda_seg * loss_segmentation
            + self.lambda_rec * loss_reconstruction
            + self.lambda_cyc * loss_cycle
            + self.lambda_lpips * loss_perceptual
        )

        # Optional eye / mouth landmark-heatmap losses (channels 96:98 and 76:96 of the
        # heatmap model output; presumably WFLW 98-point layout — confirm).
        loss_eye_hm = 0
        loss_mouth_hm = 0
        if cfg.eye_hm_loss or cfg.mouth_hm_loss:
            target_hm = self.model_mouth(target_img)
            r_hm = self.model_mouth(i_r)
            if cfg.eye_hm_loss:
                loss_eye_hm = self.mse_loss(r_hm[:, 96:98, :, :], target_hm[:, 96:98, :, :])
                loss_realism = loss_realism + self.lambda_eye_hm * loss_eye_hm
            if cfg.mouth_hm_loss:
                loss_mouth_hm = self.mse_loss(r_hm[:, 76:96, :, :], target_hm[:, 76:96, :, :])
                loss_realism = loss_realism + self.lambda_mouth_hm * loss_mouth_hm

        loss_generator = loss_sid + loss_realism

        loss_dict = {
            "loss_shape": loss_shape,
            "loss_id": loss_id,
            "loss_sid": loss_sid,
            "loss_cycle": loss_cycle,
            "loss_segmentation": loss_segmentation,
            "loss_reconstruction": loss_reconstruction,
            "loss_perceptual": loss_perceptual,
            "loss_adversarial": loss_adversarial,
            "loss_realism": loss_realism,
            "loss_generator": loss_generator,
        }
        if cfg.eye_hm_loss:
            loss_dict["loss_eye_hm"] = loss_eye_hm
        if cfg.mouth_hm_loss:
            loss_dict["loss_mouth_hm"] = loss_mouth_hm

        return (
            source_img,
            target_img,
            i_cycle.detach(),
            i_r.detach(),
            m_r.detach(),
            loss_dict,
        )

    def train_forward_discriminator(self, target_img, i_r):
        """
        Compute the discriminator losses for one training step.

        Parameters:
        -----------
        target_img: torch.Tensor, target face image (treated as real)
        i_r: torch.Tensor, swapped result (treated as fake; detached here)

        Returns:
        --------
        Dict[str, torch.Tensor]: pairs of loss name and loss value
        """
        d_gt = self.discriminator(target_img)
        d_fake = self.discriminator(i_r.detach())
        loss_real = self.adv_loss(d_gt, True)
        loss_fake = self.adv_loss(d_fake, False)

        # alpha = torch.rand(target_img.shape[0], 1, 1, 1).to(target_img.device)
        # x_hat = (alpha * target_img.data + (1 - alpha) * i_r.data).requires_grad_(True)
        # out = self.discriminator(x_hat)
        # loss_gp = gradient_penalty(out, x_hat)

        loss_discriminator = loss_real + loss_fake  # + 10 * loss_gp
        return {
            "loss_real": loss_real,
            "loss_fake": loss_fake,
            # "loss_gp": loss_gp,
            "loss_discriminator": loss_discriminator,
        }

    def forward(
        self, source_img: torch.Tensor, target_img: torch.Tensor, shape_rate=None, id_rate=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Swap the face of *source_img* onto *target_img*.

        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        shape_rate, id_rate: optional interpolation coefficients; if either is
            given, ``Generator.interp`` is used and a missing one defaults to 1.0

        Returns:
        --------
        (i_r, m_r): swapped result and predicted mask
        """
        if shape_rate is None and id_rate is None:
            i_r, _, m_r, _ = self.generator(source_img, target_img)
        else:
            shape_rate = 1.0 if shape_rate is None else shape_rate
            id_rate = 1.0 if id_rate is None else id_rate
            # ``interp`` lives on the bare Generator, not on a DDP wrapper.
            i_r, _, m_r, _ = self._unwrap(self.generator).interp(source_img, target_img, shape_rate, id_rate)
        return i_r, m_r

    def optimize(
        self,
        source_img: torch.Tensor,
        target_img: torch.Tensor,
        target_mask: torch.Tensor,
        same_id_mask: torch.Tensor,
    ) -> Tuple[Dict, Dict[str, torch.Tensor]]:
        """
        Run one training step: a generator update followed by a discriminator update.

        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        target_mask: torch.Tensor, target face mask
        same_id_mask: torch.Tensor, marks whether source and target are the same person

        Returns:
        --------
        Tuple[Dict, Dict[str, torch.Tensor]]:
            total_loss_dict: "global_norm_G", "global_norm_D" plus every generator and
                discriminator loss
            images: "source face", "target face", "swapped face" (clamped to [0, 1]),
                "pred face mask", "cycle face"
        """
        src_img, tgt_img, i_cycle, i_r, m_r, loss_G_dict = self.train_forward_generator(
            source_img, target_img, target_mask, same_id_mask
        )
        loss_G = loss_G_dict["loss_generator"]
        self.g_optimizer.zero_grad()
        loss_G.backward()
        global_norm_G = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
        self.g_optimizer.step()

        # Gradients that loss_G left in the discriminator are cleared by d_optimizer.zero_grad().
        loss_D_dict = self.train_forward_discriminator(tgt_img, i_r)
        loss_D = loss_D_dict["loss_discriminator"]
        self.d_optimizer.zero_grad()
        loss_D.backward()
        global_norm_D = torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
        self.d_optimizer.step()

        total_loss_dict = {"global_norm_G": global_norm_G, "global_norm_D": global_norm_D}
        total_loss_dict.update(loss_G_dict)
        total_loss_dict.update(loss_D_dict)

        return total_loss_dict, {
            "source face": src_img,
            "target face": tgt_img,
            "swapped face": torch.clamp(i_r, min=0.0, max=1.0),
            "pred face mask": m_r,
            "cycle face": i_cycle,
        }
if __name__ == "__main__":
    # Smoke test: construct the full training wrapper on CPU, which loads every
    # pretrained checkpoint referenced by TrainConfig().identity_extractor_config.
    # ``torch`` and ``TrainConfig`` are already imported at module level.
    #
    # To run a training step, call
    #     model.optimize(source_img, target_img, target_mask, same_id_mask)
    # with [B, 3, H, W] images in [0, 1], a [B, 1, H, W] mask and a [B, 1]
    # same-id mask. HifiFace has no ``.to`` method: pass ``device=...`` to the
    # constructor and move the inputs to that device yourself.
    identity_extractor_config = TrainConfig().identity_extractor_config
    model = HifiFace(identity_extractor_config, is_training=True)