import os
from abc import abstractmethod
from typing import Dict
from typing import Tuple
import kornia
import lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
from HRNet.hrnet import HighResolutionNet
from arcface_torch.backbones.iresnet import iresnet100
# TrainConfig is referenced throughout but was never imported; this module path is an assumption
from configs.train_config import TrainConfig
from models.discriminator import Discriminator
from models.gan_loss import GANLoss
from models.generator import Generator
from models.init_weight import init_net
class HifiFace:
def __init__(
self,
identity_extractor_config,
generator_path,
is_training=False,
device="cpu"
):
        super().__init__()
self.d_optimizer = None
self.g_optimizer = None
self.generator = Generator(identity_extractor_config)
        self.is_training = is_training
        self.use_ddp = False  # default for inference; overwritten from TrainConfig when training
self.device = device
self.generator_path = generator_path
if self.is_training:
self.lr = TrainConfig().lr
self.use_ddp = TrainConfig().use_ddp
self.grad_clip = TrainConfig().grad_clip if TrainConfig().grad_clip is not None else 100.0
self.discriminator = init_net(Discriminator(3))
self.l1_loss = nn.L1Loss()
if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
self.mse_loss = nn.MSELoss()
self.loss_fn_vgg = lpips.LPIPS(net="vgg")
self.adv_loss = GANLoss()
            # 3D face reconstruction model
self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
self.f_3d.load_state_dict(
torch.load(identity_extractor_config["f_3d_checkpoint_path"], map_location="cpu")["net_recon"]
)
self.f_3d.eval()
self.face_model = ParametricFaceModel(bfm_folder=identity_extractor_config["bfm_folder"])
self.face_model.to("cpu")
            # face recognition model (ArcFace)
self.f_id = iresnet100(pretrained=False, fp16=False)
self.f_id.load_state_dict(torch.load(identity_extractor_config["f_id_checkpoint_path"], map_location="cpu"))
self.f_id.eval()
            # facial-landmark heatmap model (used for both the eye and mouth heatmap losses)
if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
self.model_mouth = HighResolutionNet()
checkpoint = torch.load(identity_extractor_config["hrnet_path"], map_location="cpu")
self.model_mouth.load_state_dict(checkpoint)
self.model_mouth.eval()
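            # loss-term weights (the heatmap weights are large, presumably to compensate
            # for the tiny magnitude of the heatmap MSE)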
self.lambda_adv = 1
self.lambda_seg = 100
self.lambda_rec = 20
self.lambda_cyc = 1
self.lambda_lpips = 5
self.lambda_shape = 0.5
self.lambda_id = 5
self.lambda_eye_hm = 10000.0
self.lambda_mouth_hm = 10000.0
self.dilation_kernel = torch.ones(5, 5)
self.load_checkpoint()
self.setup(self.device)
def save(self, path, idx=None):
os.makedirs(path, exist_ok=True)
if idx is None:
g_path = os.path.join(path, "generator.pth")
d_path = os.path.join(path, "discriminator.pth")
else:
g_path = os.path.join(path, f"generator_{idx}.pth")
d_path = os.path.join(path, f"discriminator_{idx}.pth")
if self.use_ddp:
torch.save(self.generator.module.state_dict(), g_path)
torch.save(self.discriminator.module.state_dict(), d_path)
else:
torch.save(self.generator.state_dict(), g_path)
torch.save(self.discriminator.state_dict(), d_path)
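    # note: @abstractmethod is not enforced here since HifiFace does not inherit from abc.ABC;
    # subclasses are still expected to implement load_checkpoint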
@abstractmethod
def load_checkpoint(self):
pass
def setup(self, device):
self.generator.to(device)
if self.is_training:
self.discriminator.to(device)
self.l1_loss.to(device)
if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
self.mse_loss.to(device)
self.f_3d.to(device)
self.f_id.to(device)
self.loss_fn_vgg.to(device)
self.face_model.to(device)
self.adv_loss.to(device)
if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
self.model_mouth.to(device)
self.f_3d.requires_grad_(False)
self.f_id.requires_grad_(False)
self.loss_fn_vgg.requires_grad_(False)
if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
self.model_mouth.requires_grad_(False)
self.dilation_kernel = self.dilation_kernel.to(device)
if self.use_ddp:
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
self.generator = DDP(self.generator, device_ids=[device])
self.discriminator = DDP(self.discriminator, device_ids=[device])
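                # sync initialization across ranks: rank 0 persists its weights,
                # then every rank reloads the same file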
if dist.get_rank() == 0:
torch.save(self.generator.state_dict(), "/tmp/generator.pth")
torch.save(self.discriminator.state_dict(), "/tmp/discriminator.pth")
dist.barrier()
self.generator.load_state_dict(torch.load("/tmp/generator.pth", map_location=device))
self.discriminator.load_state_dict(torch.load("/tmp/discriminator.pth", map_location=device))
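            # beta1 = 0 disables first-moment momentum, a common choice for GAN optimizers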
self.g_optimizer = torch.optim.AdamW(self.generator.parameters(), lr=self.lr, betas=[0, 0.999])
self.d_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.lr, betas=[0, 0.999])
def train(self):
self.generator.train()
self.discriminator.train()
        # the entire id extractor is a frozen, non-trained module
if self.use_ddp:
self.generator.module.id_extractor.eval()
else:
self.generator.id_extractor.eval()
def eval(self):
self.generator.eval()
if self.is_training:
self.discriminator.eval()
def train_forward_generator(self, source_img, target_img, target_mask, same_id_mask):
"""
训练时候 Generator的loss计算
Parameters:
-----------
source_img: torch.Tensor
target_img: torch.Tensor
target_mask: torch.Tensor, [B, 1, H, W]
same_id_mask: torch.Tensor, [B, 1]
Returns:
--------
source_img: torch.Tensor
target_img: torch.Tensor
i_cycle: torch.Tensor, cycle image
i_r: torch.Tensor
m_r: torch.Tensor
loss: Dict[torch.Tensor], contain pairs of loss name and loss values
"""
same = same_id_mask.unsqueeze(-1).unsqueeze(-1)
i_r, i_low, m_r, m_low = self.generator(source_img, target_img, need_id_grad=False)
i_cycle, _, _, _ = self.generator(target_img, i_r, need_id_grad=True)
d_r = self.discriminator(i_r)
# SID Loss: shape loss + id loss
with torch.no_grad():
c_s = self.f_3d(F.interpolate(source_img, size=224, mode="bilinear"))
c_t = self.f_3d(F.interpolate(target_img, size=224, mode="bilinear"))
c_r = self.f_3d(F.interpolate(i_r, size=224, mode="bilinear"))
c_low = self.f_3d(F.interpolate(i_low, size=224, mode="bilinear"))
with torch.no_grad():
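            # Deep3DFaceRecon coefficient layout: the first 80 dims are BFM identity/shape;
            # the remaining dims (expression, texture, pose, lighting) are kept from the target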
c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
_, _, _, q_fuse = self.face_model.compute_for_render(c_fuse)
_, _, _, q_r = self.face_model.compute_for_render(c_r)
_, _, _, q_low = self.face_model.compute_for_render(c_low)
with torch.no_grad():
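            # ArcFace embeddings: iresnet100 expects 112x112 inputs in [-1, 1], hence (x - 0.5) / 0.5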
v_id_i_s = F.normalize(
self.f_id(F.interpolate((source_img - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2
)
v_id_i_r = F.normalize(self.f_id(F.interpolate((i_r - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
v_id_i_low = F.normalize(self.f_id(F.interpolate((i_low - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
loss_shape = self.l1_loss(q_fuse, q_r) + self.l1_loss(q_fuse, q_low)
loss_shape = torch.clamp(loss_shape, min=0.0, max=10.0)
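        # inner products of L2-normalized embeddings are cosine similarities
        # between the source identity and the swapped results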
inner_product_r = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_r.unsqueeze(2)).squeeze()
inner_product_low = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_low.unsqueeze(2)).squeeze()
loss_id = self.l1_loss(torch.ones_like(inner_product_r), inner_product_r) + self.l1_loss(
torch.ones_like(inner_product_low), inner_product_low
)
loss_sid = self.lambda_shape * loss_shape + self.lambda_id * loss_id
# Realism Loss: segmentation loss + reconstruction loss + cycle loss + perceptual loss + adversarial loss
loss_cycle = self.l1_loss(target_img, i_cycle)
        # dilate the target mask so the segmentation target has a small margin around the face region
target_mask = kornia.morphology.dilation(target_mask, self.dilation_kernel)
loss_segmentation = self.l1_loss(
F.interpolate(target_mask, scale_factor=0.25, mode="bilinear"), m_low
) + self.l1_loss(target_mask, m_r)
loss_reconstruction = self.l1_loss(i_r * same, target_img * same) + self.l1_loss(
i_low * same, F.interpolate(target_img, scale_factor=0.25, mode="bilinear") * same
)
loss_perceptual = self.loss_fn_vgg(target_img * same, i_r * same).mean()
loss_adversarial = self.adv_loss(d_r, True, for_discriminator=False)
loss_realism = (
self.lambda_adv * loss_adversarial
+ self.lambda_seg * loss_segmentation
+ self.lambda_rec * loss_reconstruction
+ self.lambda_cyc * loss_cycle
+ self.lambda_lpips * loss_perceptual
)
# eye hm loss
loss_eye_hm = 0
# mouth hm loss
loss_mouth_hm = 0
if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
target_hm = self.model_mouth(target_img)
r_hm = self.model_mouth(i_r)
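            # assumption: the HRNet head outputs 98 WFLW-style landmark heatmaps,
            # with channels 76-95 covering the mouth and 96-97 the pupils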
if TrainConfig().eye_hm_loss:
target_eye_hm = target_hm[:, 96:98, :, :]
r_eye_hm = r_hm[:, 96:98, :, :]
loss_eye_hm = self.mse_loss(r_eye_hm, target_eye_hm)
loss_realism = loss_realism + self.lambda_eye_hm * loss_eye_hm
if TrainConfig().mouth_hm_loss:
target_mouth_hm = target_hm[:, 76:96, :, :]
r_mouth_hm = r_hm[:, 76:96, :, :]
loss_mouth_hm = self.mse_loss(r_mouth_hm, target_mouth_hm)
loss_realism = loss_realism + self.lambda_mouth_hm * loss_mouth_hm
loss_generator = loss_sid + loss_realism
loss_dict = {
"loss_shape": loss_shape,
"loss_id": loss_id,
"loss_sid": loss_sid,
"loss_cycle": loss_cycle,
"loss_segmentation": loss_segmentation,
"loss_reconstruction": loss_reconstruction,
"loss_perceptual": loss_perceptual,
"loss_adversarial": loss_adversarial,
"loss_realism": loss_realism,
"loss_generator": loss_generator,
}
if TrainConfig().eye_hm_loss:
loss_dict.update({"loss_eye_hm": loss_eye_hm})
if TrainConfig().mouth_hm_loss:
loss_dict.update({"loss_mouth_hm": loss_mouth_hm})
return (
source_img,
target_img,
i_cycle.detach(),
i_r.detach(),
m_r.detach(),
loss_dict,
)
def train_forward_discriminator(self, target_img, i_r):
"""
训练时候 Discriminator的loss计算
Parameters:
-----------
target_img: torch.Tensor, 目标脸图片
i_r: torch.Tensor, 换脸结果
Returns:
--------
Dict[str]: contains pair of loss name and loss values
"""
d_gt = self.discriminator(target_img)
d_fake = self.discriminator(i_r.detach())
loss_real = self.adv_loss(d_gt, True)
loss_fake = self.adv_loss(d_fake, False)
# alpha = torch.rand(target_img.shape[0], 1, 1, 1).to(target_img.device)
# x_hat = (alpha * target_img.data + (1 - alpha) * i_r.data).requires_grad_(True)
# out = self.discriminator(x_hat)
# loss_gp = gradient_penalty(out, x_hat)
loss_discriminator = loss_real + loss_fake # + 10 * loss_gp
return {
"loss_real": loss_real,
"loss_fake": loss_fake,
# "loss_gp": loss_gp,
"loss_discriminator": loss_discriminator,
}
    def forward(
        self, source_img: torch.Tensor, target_img: torch.Tensor, shape_rate=None, id_rate=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        *_rate: interpolation coefficients for blending shape and identity

        Returns:
        --------
        i_r: torch.Tensor, swapped result
        m_r: torch.Tensor, predicted face mask
        """
if shape_rate is None and id_rate is None:
i_r, _, m_r, _ = self.generator(source_img, target_img)
else:
if shape_rate is None:
shape_rate = 1.0
if id_rate is None:
id_rate = 1.0
i_r, _, m_r, _ = self.generator.interp(source_img, target_img, shape_rate, id_rate)
return i_r, m_r
def optimize(
self,
source_img: torch.Tensor,
target_img: torch.Tensor,
target_mask: torch.Tensor,
same_id_mask: torch.Tensor,
) -> Tuple[Dict, Dict[str, torch.Tensor]]:
"""
模型的optimize
训练模式下执行一次训练,并返回loss信息和结果
Parameters:
-----------
source_img: torch.Tensor, source face 图像
target_img: torch.Tensor, target face 图像
target_mask: torch.Tensor, target face mask
same_id_mask: torch.Tensor, same id mask, 标识source 和 target是否是同个人
Returns:
--------
Tuple[Dict, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
loss_dict, source_img, target_img, m_r(预测的mask), i_r(换脸结果)
"""
src_img, tgt_img, i_cycle, i_r, m_r, loss_G_dict = self.train_forward_generator(
source_img, target_img, target_mask, same_id_mask
)
loss_G = loss_G_dict["loss_generator"]
self.g_optimizer.zero_grad()
loss_G.backward()
global_norm_G = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
self.g_optimizer.step()
loss_D_dict = self.train_forward_discriminator(tgt_img, i_r)
loss_D = loss_D_dict["loss_discriminator"]
self.d_optimizer.zero_grad()
loss_D.backward()
global_norm_D = torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
self.d_optimizer.step()
total_loss_dict = {"global_norm_G": global_norm_G, "global_norm_D": global_norm_D}
total_loss_dict.update(loss_G_dict)
total_loss_dict.update(loss_D_dict)
return total_loss_dict, {
"source face": src_img,
"target face": tgt_img,
"swapped face": torch.clamp(i_r, min=0.0, max=1.0),
"pred face mask": m_r,
"cycle face": i_cycle,
}
class HifiFaceST(HifiFace):
def __init__(self, identity_extractor_config, device, generator_path):
super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
def load_checkpoint(self):
self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
        logger.info(f"Loaded generator from {self.generator_path}")
class HifiFaceWGM(HifiFace):
def __init__(self, identity_extractor_config, device, generator_path):
super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
def load_checkpoint(self):
self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
        logger.info(f"Loaded generator from {self.generator_path}")
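

# Minimal inference sketch. The checkpoint paths below are hypothetical placeholders;
# the config keys match those read in __init__ above, and 256x256 inputs in [0, 1] are assumed.
if __name__ == "__main__":
    identity_extractor_config = {
        "f_3d_checkpoint_path": "checkpoints/deep3d.pth",  # hypothetical path
        "f_id_checkpoint_path": "checkpoints/arcface_r100.pth",  # hypothetical path
        "bfm_folder": "BFM",  # hypothetical path
        "hrnet_path": "checkpoints/hrnet_wflw.pth",  # hypothetical path
    }
    model = HifiFaceST(identity_extractor_config, device="cpu", generator_path="checkpoints/generator.pth")
    model.eval()
    source = torch.rand(1, 3, 256, 256)  # stand-in source face
    target = torch.rand(1, 3, 256, 256)  # stand-in target face
    with torch.no_grad():
        swapped, mask = model.forward(source, target)
    print(swapped.shape, mask.shape)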