Spaces:
Running
Running
| import os | |
| import uuid | |
| import glob | |
| import shutil | |
| from pathlib import Path | |
| from multiprocessing.pool import Pool | |
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import tqdm | |
| from modules.networks.faceshifter import FSGenerator | |
| from inference.alignment import norm_crop, norm_crop_with_M, paste_back | |
| from inference.utils import save, get_5_from_98, get_detector, get_lmk | |
| from third_party.PIPNet.lib.tools import get_lmk_model, demo_image | |
| from inference.landmark_smooth import kalman_filter_landmark, savgol_filter_landmark | |
| from inference.tricks import Trick | |
def make_abs_path(fn: str) -> str:
    """Resolve *fn* relative to this source file's directory to an absolute path."""
    return os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))


# Which face-swap backbone the module-level loader below instantiates:
# 'faceshifter', 'simswap_triplet', 'simswap_vanilla' or 'simswap_official'.
fs_model_name = 'faceshifter'
# Generator input resolution (256 or 512); also selects the PL class in extract_generator().
in_size = 256
# MouthNet configuration shared by the generators.
# NOTE(review): crop_param assumed to be a crop box matching the weight filename — confirm.
mouth_net_param = {
    "use": True,
    "feature_dim": 128,
    "crop_param": (28, 56, 84, 112),
    "weight_path": make_abs_path("./weights/arcface/mouth_net_28_56_84_112.pth"),
}
# Post-processing helpers (parsing masks, mouth finetune, GPEN, tensor<->array).
trick = Trick()

# Maps an HxWx3 uint8 RGB array to a float tensor normalized to [-1, 1].
T = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)
tensor2pil_transform = transforms.ToPILImage()
def extract_generator(ckpt: str, pt: str):
    """Extract generator weights from a Lightning training checkpoint.

    Loads the full FaceshifterPL (256) or FaceshifterPL512 (512) module from
    *ckpt*, pulls out its ``generator`` sub-module, and saves that generator's
    ``state_dict`` to *pt*.

    Args:
        ckpt: path to the ``.ckpt`` file produced by the trainer.
        pt: destination path for the extracted generator ``.pth`` file.

    Raises:
        ValueError: if the module-level ``in_size`` is neither 256 nor 512.
    """
    print('[extract_generator] loading ckpt...')
    # Imported lazily: the trainer package is only needed for this one-off extraction.
    from trainer.faceshifter.faceshifter_pl import FaceshifterPL512, FaceshifterPL
    import yaml

    with open(make_abs_path('../../trainer/faceshifter/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param

    if in_size == 256:
        net = FaceshifterPL(n_layers=3, num_D=3, config=config)
    elif in_size == 512:
        net = FaceshifterPL512(n_layers=3, num_D=3, config=config, verbose=False)
    else:
        raise ValueError('Not supported in_size.')

    checkpoint = torch.load(ckpt, map_location="cpu")
    # strict=False: the PL checkpoint also carries keys (discriminator, etc.)
    # that do not belong to this inference module.
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()

    G = net.generator
    torch.save(G.state_dict(), pt)
    print(f'[extract_generator] extracted from {ckpt}, pth saved to {pt}')
''' load model '''
if fs_model_name == 'faceshifter':
    # FaceShifter generator weights, previously extracted by extract_generator().
    pt_path = make_abs_path("./weights/extracted/G_mouth1_t38_post.pth")
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_6.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_6/epoch=3-step=128999.ckpt"
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_4.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_4/epoch=2-step=185999.ckpt"
    if not os.path.exists(pt_path) or 't512' in pt_path:
        # NOTE(review): ckpt_path is only assigned in the commented-out lines above,
        # so reaching this call raises NameError — confirm the intended checkpoint.
        extract_generator(ckpt_path, pt_path)
    fs_model = FSGenerator(
        make_abs_path("./weights/arcface/ms1mv3_arcface_r100_fp16/backbone.pth"),
        mouth_net_param=mouth_net_param,
        in_size=in_size,
        downup=in_size == 512,
    )
    fs_model.load_state_dict(torch.load(pt_path, "cpu"), strict=True)
    fs_model.eval()

    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap source batch i_s onto target batch i_t; return a uint8 HxWx3 array.

        post=True blends the target's background/hair back over the result.
        """
        i_r = fs_model(i_s, i_t)[0]  # x, id_vector, att
        if post:
            # Parsing classes 0 and 17 cover background + hair.
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1.) + 1.) * i_r  # torch 1.12.0
        # Mouth finetuning is only tuned for the 256 model.
        i_r = trick.finetune_mouth(i_s, i_t, i_r) if in_size == 256 else i_r
        img_r = trick.tensor_to_arr(i_r)[0]
        return img_r
elif fs_model_name == 'simswap_triplet' or fs_model_name == 'simswap_vanilla':
    from modules.networks.simswap import Generator_Adain_Upsample
    sw_model = Generator_Adain_Upsample(
        input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False,
        mouth_net_param=mouth_net_param
    )
    if fs_model_name == 'simswap_triplet':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_st5.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_triplet_5/epoch=12-step=782999.ckpt")
    elif fs_model_name == 'simswap_vanilla':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_tmp_sv4_off.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_vanilla_4/epoch=694-step=1487999.ckpt")
    else:
        pt_path = None
        ckpt_path = None
    sw_model.load_state_dict(torch.load(pt_path, "cpu"), strict=False)
    sw_model.eval()
    fs_model = sw_model
    # The ArcFace/MouthNet sub-nets come from the Lightning training module.
    from trainer.simswap.simswap_pl import SimSwapPL
    import yaml
    with open(make_abs_path('../../trainer/simswap/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param
    net = SimSwapPL(config=config, use_official_arc='off' in pt_path)
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()
    sw_mouth_net = net.mouth_net  # maybe None
    sw_netArc = net.netArc
    # NOTE(review): this branch moves everything to CUDA unconditionally —
    # it will fail on a CPU-only host; confirm that is acceptable.
    fs_model = fs_model.cuda()
    sw_mouth_net = sw_mouth_net.cuda() if sw_mouth_net is not None else sw_mouth_net
    sw_netArc = sw_netArc.cuda()

    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """SimSwap variant of the swap entry point; returns a uint8 HxWx3 array."""
        i_r = fs_model(source=i_s, target=i_t, net_arc=sw_netArc, mouth_net=sw_mouth_net,)
        if post:
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
        i_r = i_r.clamp(-1, 1)
        i_r = trick.tensor_to_arr(i_r)[0]
        return i_r
elif fs_model_name == 'simswap_official':
    from simswap.image_infer import SimSwapOfficialImageInfer
    fs_model = SimSwapOfficialImageInfer()
    pt_path = 'Simswap Official'
    # Official SimSwap has no MouthNet; disable it for downstream consumers.
    mouth_net_param = {
        "use": False
    }

    # NOTE(review): this variant takes no `post` kwarg and returns a clamped
    # tensor rather than a numpy array, unlike the other two branches —
    # callers passing post= (e.g. swap_image) would break; confirm intended.
    def infer_batch_to_img(i_s, i_t):
        i_r = fs_model.image_infer(source_tensor=i_s, target_tensor=i_t)
        i_r = i_r.clamp(-1, 1)
        return i_r
else:
    raise ValueError('Not supported fs_model_name.')
print(f'[demo] model loaded from {pt_path}')
def swap_image(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    paste_back=True,
    use_post=False,
    use_gpen=False,
    in_size=256,
):
    """Swap the face from *source_image* onto the single image at *target_path*.

    Writes three debug images into the CWD (cropped_source.png,
    cropped_target.png, result.png) and two outputs into *out_path*: the raw
    swapped crop ("out_<name>") and, via ``save``, the result pasted back onto
    the full target frame ("paste_back_out_<name>").

    Args:
        source_image: path to the identity (source) image.
        target_path: path to the attribute (target) image.
        out_path: output directory; created if missing.
        transform: transform mapping an RGB array to a [-1, 1] CHW tensor.
        G: the face-swap generator. Only eval()/cuda() is applied here;
            inference itself goes through the module-level ``infer_batch_to_img``.
        align_source / align_target: landmark alignment template names for norm_crop.
        gpu_mode: run landmark detection and swapping on CUDA.
        paste_back: unused; kept for interface compatibility. NOTE(review): it
            also shadows the imported ``paste_back`` function inside this scope.
        use_post: blend the target's hair/background back over the result.
        use_gpen: enhance the pasted-back result with GPEN (inside ``save``).
        in_size: crop resolution fed to the generator.
    """
    # basename instead of split("/") so Windows-style paths also work.
    name = "out_" + os.path.basename(target_path)
    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()
    device = torch.device(0) if gpu_mode else torch.device('cpu')

    # Detect 98-point landmarks, reduce to 5 points, then align-crop the source.
    source_img = np.array(Image.open(source_image).convert("RGB"))
    net, detector = get_lmk_model()
    lmk = get_5_from_98(demo_image(source_img, net, detector, device=device)[0])
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)

    # Same for the target, but keep the warp matrix M for pasting back later.
    target = np.array(Image.open(target_path).convert("RGB"))
    original_target = target.copy()
    lmk = get_5_from_98(demo_image(target, net, detector, device=device)[0])
    target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
    target = transform(target).unsqueeze(0)

    if gpu_mode:
        target = target.cuda()
        source_img = source_img.cuda()

    # Debug dumps (RGB -> BGR for cv2).
    cv2.imwrite('cropped_source.png', trick.tensor_to_arr(source_img)[0, :, :, ::-1])
    cv2.imwrite('cropped_target.png', trick.tensor_to_arr(target)[0, :, :, ::-1])

    # both inputs should be 512
    result = infer_batch_to_img(source_img, target, post=use_post)
    cv2.imwrite('result.png', result[:, :, ::-1])

    os.makedirs(out_path, exist_ok=True)
    Image.fromarray(result.astype(np.uint8)).save(os.path.join(out_path, name))
    save((result, M, original_target, os.path.join(out_path, "paste_back_" + name), None),
         trick=trick, use_gpen=use_gpen)
def process_video(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    frames=9999999,
    use_tddfav2=False,
    landmark_smooth="kalman",
    use_gpen=False,
):
    """Swap the face from *source_image* onto every frame of a target video.

    Pipeline: dump the video to ./tmp/*.png via ffmpeg, detect and temporally
    smooth target landmarks, align-crop source and each target frame, run the
    module-level face-swap model per frame, and collect everything the caller
    needs to paste results back and re-encode.

    Args:
        source_image: path to the identity (source) image.
        target_path: path to the target video, or a directory of .png frames.
        out_path: output directory for pasted-back frames (created at the end).
        transform: transform mapping an RGB array to a [-1, 1] CHW tensor.
        G: the generator; only eval()/cuda() is applied here, inference uses
            the module-level fs_model / infer paths directly.
        align_source / align_target: alignment template names for norm_crop.
        gpu_mode: run detection and swapping on CUDA.
        frames: stop after this many frames.
        use_tddfav2: use the 3DDFA-v2 detector instead of PIPNet.
        landmark_smooth: 'kalman', 'savgol', or 'cancel' (no smoothing).
        use_gpen: enhance each swapped crop with GPEN.

    Returns:
        (targets, t_facial_masks, Ms, original_frames, names, fps) — swapped
        crops, per-frame facial masks, warp matrices, original frames, output
        file names, and the source video's frame rate.
    """
    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()
    device = torch.device(0) if gpu_mode else torch.device('cpu')
    ''' Target video to frames (.png) '''
    fps = 25.0  # fallback when target_path is already a frame directory
    if not os.path.isdir(target_path):
        vidcap = cv2.VideoCapture(target_path)
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        # Best-effort cleanup of stale frames from a previous run.
        try:
            for match in glob.glob(os.path.join("./tmp/", "*.png")):
                os.remove(match)
            for match in glob.glob(os.path.join(out_path, "*.png")):
                os.remove(match)
        except Exception as e:
            print(e)
        os.makedirs("./tmp/", exist_ok=True)
        # Lossless-quality PNG dump; -vsync 0 keeps one image per input frame.
        os.system(
            f"ffmpeg -i {target_path} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 ./tmp/frame_%05d.png"
        )
        target_path = "./tmp/"
    globbed_images = sorted(glob.glob(os.path.join(target_path, "*.png")))
    ''' Get target landmarks '''
    print('[Extracting target landmarks...]')
    if not use_tddfav2:
        align_net, align_detector = get_lmk_model()
    else:
        align_net, align_detector = get_detector(gpu_mode=gpu_mode)
    target_lmks = []
    for frame_path in tqdm.tqdm(globbed_images):
        target = np.array(Image.open(frame_path).convert("RGB"))
        lmk = demo_image(target, align_net, align_detector, device=device)
        lmk = lmk[0]  # first detected face only
        target_lmks.append(lmk)
    ''' Landmark smoothing '''
    target_lmks = np.array(target_lmks, np.float32)  # (#frames, 98, 2)
    if landmark_smooth == 'kalman':
        target_lmks = kalman_filter_landmark(target_lmks,
                                             process_noise=0.01,
                                             measure_noise=0.01).astype(np.int32)
    elif landmark_smooth == 'savgol':
        target_lmks = savgol_filter_landmark(target_lmks).astype(np.int32)
    elif landmark_smooth == 'cancel':
        target_lmks = target_lmks.astype(np.int32)
    else:
        raise KeyError('Not supported landmark_smooth choice')
    ''' Crop source image '''
    source_img = np.array(Image.open(source_image).convert("RGB"))
    if not use_tddfav2:
        lmk = get_5_from_98(demo_image(source_img, align_net, align_detector, device=device)[0])
    else:
        lmk = get_lmk(source_img, align_net, align_detector)
    # NOTE(review): in_size here is the module-level global, not a parameter.
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)
    if gpu_mode:
        source_img = source_img.cuda()
    ''' Process by frames '''
    targets = []
    t_facial_masks = []
    Ms = []
    original_frames = []
    names = []
    count = 0
    for image in tqdm.tqdm(globbed_images):
        names.append(os.path.join(out_path, Path(image).name))
        target = np.array(Image.open(image).convert("RGB"))
        original_frames.append(target)
        ''' Crop target frames '''
        lmk = get_5_from_98(target_lmks[count])
        target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
        target = transform(target).unsqueeze(0)  # in [-1,1]
        if gpu_mode:
            target = target.cuda()
        ''' Finetune paste masks '''
        # Facial-region parsing classes (skin, brows, eyes, nose, lips, ...).
        target_facial_mask = trick.get_any_mask(target,
                                                par=[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]).squeeze()  # in [0,1]
        target_facial_mask = target_facial_mask.cpu().numpy().astype(np.float32)
        # NOTE(review): passes the full (#frames, 98, 2) landmark array rather
        # than target_lmks[count] — confirm finetune_mask expects that.
        target_facial_mask = trick.finetune_mask(target_facial_mask, target_lmks)  # in [0,1]
        t_facial_masks.append(target_facial_mask)
        ''' Face swapping '''
        with torch.no_grad():
            if 'faceshifter' in fs_model_name:
                output = fs_model(source_img, target)[0]
                # Blend target hair/background (classes 0, 17) back over the swap.
                target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                target_hair_mask = trick.smooth_mask(target_hair_mask)
                # print(output.shape, target.shape, target_hair_mask.shape)
                output = target_hair_mask * target + (target_hair_mask * (-1.) + 1.) * output
                output = trick.finetune_mouth(source_img, target, output)
            elif 'simswap' in fs_model_name and 'official' not in fs_model_name:
                output = fs_model(source=source_img, target=target,
                                  net_arc=sw_netArc, mouth_net=sw_mouth_net,)
                if 'vanilla' not in fs_model_name:
                    target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                    target_hair_mask = trick.smooth_mask(target_hair_mask)
                    output = target_hair_mask * target + (target_hair_mask * (-1.) + 1.) * output
                    output = trick.finetune_mouth(source_img, target, output)
                output = output.clamp(-1, 1)
            elif 'simswap_official' in fs_model_name:
                output = fs_model.image_infer(source_tensor=source_img, target_tensor=target)
                output = output.clamp(-1, 1)
        # De-normalize from [-1, 1] to [0, 1] for PIL conversion.
        if isinstance(output, tuple):
            target = output[0][0] * 0.5 + 0.5
        else:
            target = output[0] * 0.5 + 0.5
        targets.append(trick.gpen(np.array(tensor2pil_transform(target)), use_gpen=use_gpen))
        Ms.append(M)
        count += 1
        if count > frames:
            break
    os.makedirs(out_path, exist_ok=True)
    return targets, t_facial_masks, Ms, original_frames, names, fps
def swap_image_gr(img1, img2, use_post=False, use_gpen=False):
    """Gradio handler: swap img1's face onto img2 and return the RGB result.

    Each request gets its own directory under ./online_data, keyed by a uuid1
    hex id. Both inputs are written there as PNGs, swap_image() runs on them,
    and the pasted-back output is read from disk and returned as an RGB array.
    """
    request_dir = os.path.join(make_abs_path("./online_data"), uuid.uuid1().hex)
    os.makedirs(request_dir, exist_ok=True)

    src_file = os.path.join(request_dir, "source.png")
    tgt_file = os.path.join(request_dir, "target.png")
    result_file = os.path.join(request_dir, "paste_back_out_target.png")

    # Gradio hands us RGB arrays; cv2 writes BGR, so flip channels.
    cv2.imwrite(src_file, img1[:, :, ::-1])
    cv2.imwrite(tgt_file, img2[:, :, ::-1])

    swap_image(
        src_file,
        tgt_file,
        request_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        align_target='ffhq',
        align_source='ffhq',
        use_post=use_post,
        use_gpen=use_gpen,
        in_size=in_size,
    )

    # Read the pasted-back BGR output and flip back to RGB for Gradio.
    return cv2.imread(result_file)[..., ::-1]
def swap_video_gr(img1, target_path, use_gpen=False, frames=9999999):
    """Gradio handler: swap img1's face onto every frame of a video.

    Saves the source image into a per-request directory under ./online_data,
    runs process_video(), pastes each swapped crop back onto its original
    frame, re-encodes with ffmpeg (preserving the original audio track), and
    returns the path of the resulting mp4 for the Gradio Video component.
    """
    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid1().hex
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    cv2.imwrite(source_path, img1[:, :, ::-1])  # RGB -> BGR for cv2
    out_dir = os.path.join(data_dir, "out")
    out_name = "output.mp4"
    targets, t_facial_masks, Ms, original_frames, names, fps = process_video(
        source_path,
        target_path,
        out_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        frames=frames,
        align_target='ffhq',
        align_source='ffhq',
        use_tddfav2=False,
        use_gpen=use_gpen,
    )
    # NOTE(review): 170 worker processes is very aggressive — each worker
    # receives full frames over IPC; confirm this is tuned for the host.
    pool_process = 170
    audio = True
    concat = False
    if pool_process <= 1:
        # Sequential paste-back path.
        for target, M, original_target, name, t_facial_mask in tqdm.tqdm(
            zip(targets, Ms, original_frames, names, t_facial_masks)
        ):
            # Frames where detection failed are passed through unmodified.
            if M is None or target is None:
                Image.fromarray(original_target.astype(np.uint8)).save(name)
                continue
            Image.fromarray(paste_back(np.array(target), M, original_target, t_facial_mask)).save(name)
    else:
        # Parallel paste-back via inference.utils.save on argument tuples.
        # NOTE(review): unlike the swap_image() call site, no trick/use_gpen
        # kwargs are passed here — confirm save's defaults cover this path.
        with Pool(pool_process) as pool:
            pool.map(save, zip(targets, Ms, original_frames, names, t_facial_masks))
    video_save_path = os.path.join(out_dir, out_name)
    if audio:
        print("use audio")
        # Mux frames with the original audio track (0:v from frames, 1:a from input).
        os.system(
            f"ffmpeg -y -r {fps} -i {out_dir}/frame_%05d.png -i {target_path}"
            f" -map 0:v:0 -map 1:a:0? -c:a copy -c:v libx264 -r {fps} -crf 10 -pix_fmt yuv420p {video_save_path}"
        )
    else:
        print("no audio")
        os.system(
            f"ffmpeg -y -r {fps} -i ./tmp/frame_%05d.png "
            f"-c:v libx264 -r {fps} -crf 10 -pix_fmt yuv420p {video_save_path}"
        )
    # ffmpeg -i left.mp4 -i right.mp4 -filter_complex hstack output.mp4
    if concat:
        # Optional side-by-side comparison of input vs output.
        concat_video_save_path = os.path.join(out_dir, "concat_" + out_name)
        os.system(
            f"ffmpeg -y -i {target_path} -i {video_save_path} -filter_complex hstack {concat_video_save_path}"
        )
    # delete tmp file
    shutil.rmtree("./tmp/")
    for match in glob.glob(os.path.join(out_dir, "*.png")):
        os.remove(match)
    print(video_save_path)
    return video_save_path
if __name__ == "__main__":
    # Fall back to CPU automatically when no CUDA device is present.
    use_gpu = torch.cuda.is_available()
    # Two-tab Gradio UI: single-image swap and video swap.
    with gr.Blocks() as demo:
        gr.Markdown("<div align='center'> <h2> ReliableSwap: Boosting General Face Swapping Via Reliable Supervision </span> </h2> \
                    <a style='font-size:18px;color: #7B68EE' href='https://arxiv.org/abs/2306.05356'>Arxiv</a> \
                    <a style='font-size:18px;color: #7B68EE' href='https://reliable-swap.github.io/'>Homepage</a> \
                    <a style='font-size:18px;color: #7B68EE' href='https://github.com/ygtxr1997/ReliableSwap'> Github </div>")
        gr.Markdown("""
                    <b>Alternatively, try our GitHub <a href=https://github.com/ygtxr1997/ReliableSwap> code </a> on your own GPU. </b>
                    <a style='display:inline-block' href="https://github.com/ygtxr1997/ReliableSwap">
                    <img src="https://img.shields.io/github/stars/ygtxr1997/ReliableSwap?style=social"/></a> \
                    """)
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=3):
                    image1_input = gr.Image(label='source')
                    image2_input = gr.Image(label='target')
                    image_use_post = gr.Checkbox(label="Post-Process")
                    image_use_gpen = gr.Checkbox(label="Super Resolution (based on GPEN)")
                with gr.Column(scale=2):
                    image_output = gr.Image()
                    image_button = gr.Button("Run: Face Swapping", variant="primary")
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=3):
                    image3_input = gr.Image(label='source')
                    video_input = gr.Video(label='target')
                with gr.Column(scale=2):
                    video_output = gr.Video()
                    video_use_gpen = gr.Checkbox(label="Super Resolution (based on GPEN). "
                                                       "Warning: very time-cost on CPU!")
                    video_button = gr.Button("Run: Face Swapping", variant="primary")
        # Wire the buttons to their handlers.
        image_button.click(
            swap_image_gr,
            inputs=[image1_input, image2_input, image_use_post, image_use_gpen],
            outputs=image_output,
        )
        video_button.click(
            swap_video_gr,
            inputs=[image3_input, video_input, video_use_gpen],
            outputs=video_output,
        )
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)