Spaces:
Running
Running
| import os | |
| import cv2 | |
| import sys | |
| import torch | |
| import os.path as osp | |
| from gfpgan import GFPGANer | |
| from basicsr.utils.download_util import load_file_from_url | |
# Repository root: two directory levels above this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Kept for backward compatibility. The original separate computation
# (abspath(join(__file__, pardir, pardir))) resolves to the same directory.
root_path = ROOT_DIR
# Make the repository root importable so the `SR_Inference` package import
# that follows can be resolved. Guard against adding duplicate entries.
if root_path not in sys.path:
    sys.path.append(root_path)
| from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo | |
class GFPGAN:
    """Callable face-restoration wrapper around ``GFPGANer``.

    A background upsampler is taken from ``RealEsrUpsamplerZoo``. Instances
    are called with an image and return the restored image:
    ``GFPGAN()(img)``.
    """

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
    ):
        """Build the background upsampler and load the GFPGAN v1.3 weights.

        Args:
            upscale: Upscaling factor, coerced with ``int()``. It is passed
                both to the background upsampler zoo and to ``GFPGANer``.
            bg_upsampler_name: Forwarded to ``RealEsrUpsamplerZoo``.
            prefered_net_in_upsampler: Forwarded to ``RealEsrUpsamplerZoo``.
                The spelling is part of the public interface; keep it.

        ``GFPGANv1.3.pth`` is looked up in
        ``<ROOT_DIR>/SR_Inference/gfpgan/weights``. If the file is missing,
        it is downloaded there from the official GFPGAN release.
        """
        upscale = int(upscale)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------------ set up background upsampler ------------------------
        upsampler_zoo = RealEsrUpsamplerZoo(
            upscale=upscale,
            bg_upsampler_name=bg_upsampler_name,
            prefered_net_in_upsampler=prefered_net_in_upsampler,
        )
        bg_upsampler = upsampler_zoo.bg_upsampler

        # ------------------------ load model ------------------------
        gfpgan_weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", "gfpgan", "weights"
        )
        gfpgan_model_path = os.path.join(gfpgan_weights_path, "GFPGANv1.3.pth")
        if not os.path.isfile(gfpgan_model_path):
            # Weights are not cached locally yet: fetch them once into the
            # weights directory; the helper returns the saved file path.
            url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
            gfpgan_model_path = load_file_from_url(
                url=url,
                model_dir=gfpgan_weights_path,
                progress=True,
                file_name="GFPGANv1.3.pth",
            )
        self.sr_model = GFPGANer(
            upscale=upscale,
            bg_upsampler=bg_upsampler,
            model_path=gfpgan_model_path,
            device=device,
        )

    def __call__(self, img):
        """Restore/enhance ``img`` with the GFPGAN model.

        Args:
            img: Image array, e.g. as returned by ``cv2.imread``. The expected
                channel order and dtype are whatever ``GFPGANer.enhance``
                accepts; presumably BGR uint8 (confirm).

        Returns:
            The third value returned by ``GFPGANer.enhance``, i.e. the
            restored full image. The cropped and restored face lists are
            discarded.

        Raises:
            ValueError: If ``img`` is None, e.g. because ``cv2.imread``
                failed to read the file. Without this check the failure
                would surface as an obscure error inside the library.
        """
        if img is None:
            raise ValueError("img is None; the input image could not be loaded")
        # ------------------------ restore/enhance image using GFPGAN model ------------------------
        _, _, sr_img = self.sr_model.enhance(img)
        return sr_img
if __name__ == "__main__":
    # Demo: restore one frame and save the result.
    # Optional overrides: `python <this file> [input_image] [output_dir]`.
    # With no arguments the original hard-coded paths are used.
    input_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else os.path.join(
            ROOT_DIR, "data", "EyeDentify", "Wo_SR", "original", "1", "1", "frame_01.png"
        )
    )
    saving_dir = (
        sys.argv[2]
        if len(sys.argv) > 2
        else os.path.join(ROOT_DIR, "rough_works", "SR_imgs")
    )

    # cv2.imread returns None instead of raising on a missing or unreadable
    # file. Check it before building the model (which may download weights).
    img = cv2.imread(input_path)
    if img is None:
        raise FileNotFoundError(f"Could not read input image: {input_path}")

    gfpgan = GFPGAN(
        upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
    )
    sr_img = gfpgan(img=img)

    os.makedirs(saving_dir, exist_ok=True)
    out_path = os.path.join(saving_dir, "sr_img_gfpgan.png")
    # cv2.imwrite signals failure through its boolean return value.
    if not cv2.imwrite(out_path, sr_img):
        raise OSError(f"Failed to write output image: {out_path}")