Spaces:
Running
Running
| import os | |
| import cv2 | |
| import sys | |
| import torch | |
| import numpy as np | |
| import os.path as osp | |
| from PIL import Image | |
| from basicsr.utils import img2tensor | |
# Project root: the parent of the directory that contains this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# `root_path` is the same directory as ROOT_DIR; derive it instead of
# recomputing it a second way, keeping the name for existing users.
root_path = ROOT_DIR
# Make project-local packages (e.g. `SR_Inference`) importable; skip the
# append when the entry is already present so re-imports do not duplicate it.
if root_path not in sys.path:
    sys.path.append(root_path)
| from SR_Inference.hat.hat_arch import HATArch | |
class HAT:
    """Inference wrapper around ``HATArch`` for single-image super-resolution.

    The constructor builds the network, loads the pretrained checkpoint
    ``HAT_SRx{upscale}_ImageNet-pretrain.pth`` from
    ``ROOT_DIR/SR_Inference/hat/weights`` and puts the model in eval mode.
    Calling the instance upscales one image. The model runs on CUDA when
    available, otherwise on CPU.
    """

    def __init__(
        self,
        upscale=2,
        in_chans=3,
        img_size=(480, 640),
        window_size=16,
        compress_ratio=3,
        squeeze_factor=30,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.0,
        depths=(6, 6, 6, 6, 6, 6),
        embed_dim=180,
        num_heads=(6, 6, 6, 6, 6, 6),
        mlp_ratio=2,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    ):
        """Build the HAT network and load its pretrained weights.

        All architecture arguments are forwarded to ``HATArch``. They must
        match the architecture the checkpoint was trained with, because
        ``load_state_dict`` is called with its default strict matching.

        Args:
            upscale: Super-resolution factor (cast to ``int``). It is passed
                to the network and also selects the checkpoint file
                ``HAT_SRx{upscale}_ImageNet-pretrain.pth``.
            depths: Per-stage settings, passed to ``HATArch`` as lists.
            num_heads: Per-stage settings, passed to ``HATArch`` as lists.
                Tuples are used as defaults to avoid shared mutable default
                arguments. Any sequence is accepted.
            in_chans, img_size, window_size, compress_ratio, squeeze_factor,
            conv_scale, overlap_ratio, img_range, embed_dim, mlp_ratio,
            upsampler, resi_connection: Forwarded unchanged to ``HATArch``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist (raised
                by ``torch.load``).
            KeyError: If the checkpoint has neither ``params_ema`` nor
                ``params``.
        """
        upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------ load model for img enhancement -------------------
        self.sr_model = HATArch(
            img_size=img_size,
            upscale=upscale,
            in_chans=in_chans,
            window_size=window_size,
            compress_ratio=compress_ratio,
            squeeze_factor=squeeze_factor,
            conv_scale=conv_scale,
            overlap_ratio=overlap_ratio,
            img_range=img_range,
            depths=list(depths),
            embed_dim=embed_dim,
            num_heads=list(num_heads),
            mlp_ratio=mlp_ratio,
            upsampler=upsampler,
            resi_connection=resi_connection,
        ).to(self.device)

        ckpt_path = os.path.join(
            ROOT_DIR,
            "SR_Inference",
            "hat",
            "weights",
            f"HAT_SRx{upscale}_ImageNet-pretrain.pth",
        )
        # NOTE(security): torch.load unpickles the file, so only load trusted
        # checkpoints.
        loadnet = torch.load(ckpt_path, map_location=self.device)
        # Use "params_ema" when the checkpoint has it, otherwise "params".
        keyname = "params_ema" if "params_ema" in loadnet else "params"
        self.sr_model.load_state_dict(loadnet[keyname])
        self.sr_model.eval()

    @staticmethod
    def _tensor_to_bgr_uint8(tensor):
        """Convert a (C, H, W) RGB float tensor to an (H, W, C) BGR uint8 array.

        Values are min-max stretched per image to the full 0-255 range
        (truncated, not rounded) rather than clamped to [0, 1]. This is kept to
        preserve the outputs this wrapper has always produced. A constant image
        (zero value range) maps to all zeros instead of NaN-derived garbage.
        """
        # detach() makes .numpy() safe even if the tensor still tracks grad.
        rgb = tensor.detach().permute(1, 2, 0).cpu().numpy()
        lo = rgb.min()
        span = rgb.max() - lo
        if span > 0:
            rgb = (rgb - lo) / span
        else:
            rgb = np.zeros_like(rgb)
        rgb_u8 = (rgb * 255).astype(np.uint8)
        # Reverse the channel axis (RGB -> BGR) and return a contiguous copy,
        # as OpenCV consumers expect.
        return np.ascontiguousarray(rgb_u8[:, :, ::-1])

    def __call__(self, img):
        """Super-resolve a single image.

        Args:
            img: Image with values in 0-255 and BGR channel order, as returned
                by ``cv2.imread``. Presumably ``uint8`` of shape (H, W, 3).
                NOTE(review): HAT variants usually need H and W to be multiples
                of ``window_size``. No padding is done here, so confirm this
                against ``HATArch``.

        Returns:
            A contiguous ``uint8`` BGR array of shape (H', W', C) containing
            the model output min-max stretched to 0-255. H' and W' are expected
            to be ``upscale`` times the input size, as determined by
            ``HATArch``.
        """
        img_tensor = (
            img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
            .unsqueeze(0)
            .to(self.device)
        )
        # Pure inference. Without no_grad the output carries autograd history,
        # which wastes memory and makes the later .numpy() call raise.
        with torch.no_grad():
            restored = self.sr_model(img_tensor)[0]
        return self._tensor_to_bgr_uint8(restored)
if __name__ == "__main__":
    # Demo: upscale one EyeDentify frame 2x and save the result.
    img_path = f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png"
    # cv2.imread signals failure by returning None rather than raising. Check
    # it before the expensive model load so a bad path fails fast and clearly.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")

    hat = HAT(upscale=2)
    sr_img = hat(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)
    out_path = f"{saving_dir}/sr_img_hat.png"
    # cv2.imwrite also reports failure only through its boolean return value.
    if not cv2.imwrite(out_path, sr_img):
        raise OSError(f"Could not write image: {out_path}")