Spaces:
Running
Running
| import os | |
| import cv2 | |
| import sys | |
| import torch | |
| import os.path as osp | |
| from basicsr.utils import img2tensor, tensor2img | |
| from torchvision.transforms.functional import normalize | |
| from facexlib.utils.face_restoration_helper import FaceRestoreHelper | |
| ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) | |
| sys.path.append(root_path) | |
| from SR_Inference.codeformer.codeformer_arch import CodeFormerArch | |
| from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo | |
class CodeFormer:
    """Face-aware super-resolution built on the CodeFormer network.

    The whole image is upscaled with the background upsampler obtained from
    ``RealEsrUpsamplerZoo``. If ``FaceRestoreHelper`` finds a face, the kept
    face crop is restored by ``CodeFormerArch`` and pasted back onto the
    upscaled background.
    """

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
        fidelity_weight=0.8,
    ):
        """Create the background upsampler, the face helper and the SR model.

        Args:
            upscale: Output scale factor; coerced to ``int`` and shared by the
                background upsampler and the face helper.
            bg_upsampler_name: Forwarded unchanged to ``RealEsrUpsamplerZoo``.
            prefered_net_in_upsampler: Forwarded unchanged to
                ``RealEsrUpsamplerZoo``.
            fidelity_weight: Value passed to the network as ``w`` on each call.
        """
        self.upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.fidelity_weight = fidelity_weight

        # ------------------------ set up background upsampler ------------------------
        upsampler_zoo = RealEsrUpsamplerZoo(
            upscale=self.upscale,
            bg_upsampler_name=bg_upsampler_name,
            prefered_net_in_upsampler=prefered_net_in_upsampler,
        )
        self.bg_upsampler = upsampler_zoo.bg_upsampler

        # ------------------ set up FaceRestoreHelper -------------------
        # The helper's detection/parsing weights are looked up in the gfpgan
        # weights directory of this project.
        gfpgan_weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", "gfpgan", "weights"
        )
        self.face_restorer_helper = FaceRestoreHelper(
            upscale_factor=self.upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model="retinaface_resnet50",
            save_ext="png",
            use_parse=True,
            device=self.device,
            model_rootpath=gfpgan_weights_path,
        )

        # ------------------ load model -------------------
        self.sr_model = CodeFormerArch().to(self.device)
        ckpt_path = os.path.join(
            ROOT_DIR, "SR_Inference", "codeformer", "weights", "codeformer_v0.1.0.pth"
        )
        loadnet = torch.load(ckpt_path, map_location=self.device)
        # Prefer the "params_ema" entry when the checkpoint provides one.
        keyname = "params_ema" if "params_ema" in loadnet else "params"
        self.sr_model.load_state_dict(loadnet[keyname])
        self.sr_model.eval()

    def _restore_face(self, cropped_face):
        """Run one aligned face crop through the network; return a uint8 image."""
        face_t = img2tensor(imgs=cropped_face / 255.0, bgr2rgb=True, float32=True)
        # Map [0, 1] to [-1, 1]; the output is converted back with
        # min_max=(-1, 1) below.
        normalize(
            tensor=face_t,
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            inplace=True,
        )
        face_t = face_t.unsqueeze(0).to(self.device)

        # Inference only: disable autograd so no graph is built and no
        # intermediate activations are retained.
        with torch.no_grad():
            output = self.sr_model(face_t, w=self.fidelity_weight, adain=True)[0]

        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        return restored_face.astype("uint8")

    def __call__(self, img):
        """Super-resolve ``img`` and restore the detected face, if any.

        Args:
            img: Input image as accepted by the background upsampler's
                ``enhance`` and by ``FaceRestoreHelper.read_image``
                (presumably a BGR array as returned by ``cv2.imread``; the
                face crop is converted with ``bgr2rgb=True`` -- confirm
                against callers).

        Returns:
            The upscaled background with the restored face pasted in, or the
            upscaled background alone when no face crop was produced.
        """
        bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]

        helper = self.face_restorer_helper
        helper.clean_all()  # reset the helper before processing a new image
        helper.read_image(img)
        helper.get_face_landmarks_5(
            only_keep_largest=True, only_center_face=False, eye_dist_threshold=5
        )
        helper.align_warp_face()

        if len(helper.cropped_faces) == 0:
            # No face crop available: fall back to plain background upscaling.
            return bg_img

        # Only the first crop is restored (detection uses only_keep_largest=True).
        restored_face = self._restore_face(helper.cropped_faces[0])
        helper.add_restored_face(restored_face)
        helper.get_inverse_affine(None)
        return helper.paste_faces_to_input_image(upsample_img=bg_img)
if __name__ == "__main__":
    # Manual smoke test: super-resolve one sample frame and write the result
    # under rough_works/SR_imgs.
    img_path = f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png"
    img = cv2.imread(img_path)
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising; fail here, before the model weights are loaded.
    if img is None:
        raise FileNotFoundError(f"Could not read input image: {img_path}")

    codeformer = CodeFormer(upscale=2, fidelity_weight=1.0)
    sr_img = codeformer(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)
    out_path = f"{saving_dir}/sr_img.png"
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(out_path, sr_img):
        raise OSError(f"Could not write output image: {out_path}")