Spaces:
Running
Running
| import os | |
| import cv2 | |
| import torch | |
| from realesrgan import RealESRGANer | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from realesrgan.archs.srvgg_arch import SRVGGNetCompact | |
| from basicsr.utils.download_util import load_file_from_url | |
# Project root: the parent of the directory that contains this module.
# Weight files are resolved relative to it (see RealEsrUpsamplerZoo).
ROOT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
class RealEsrUpsamplerZoo:
    """Build a Real-ESRGAN background upsampler from a named pretrained checkpoint.

    The checkpoint is looked up in ``_MODEL_ZOO`` by upsampler name and scale,
    downloaded into ``<ROOT_DIR>/SR_Inference/<bg_upsampler_name>/weights`` if
    it is not already present, and wrapped in a ``RealESRGANer`` that is stored
    on ``self.bg_upsampler``.

    Args:
        upscale: Upscaling factor (coerced with ``int``). Supported values
            depend on ``bg_upsampler_name``; see ``_MODEL_ZOO``.
        bg_upsampler_name: ``"realesrgan"`` (x2, x4), ``"realesrnet"`` (x4) or
            ``"anime"`` (x4).
        prefered_net_in_upsampler: Backbone architecture handed to
            ``get_prefered_net`` (``"RRDBNet"`` or ``"SRVGGNetCompact"``).

    Raises:
        ValueError: If the upsampler name, the scale for that name, or the
            backbone name is not supported.
    """

    # name -> {scale: (weights file name, download URL, RRDBNet num_block)}
    # num_block is 23 for the "plus" checkpoints and 6 for the anime_6B one;
    # a mismatch makes the state-dict load fail.
    _MODEL_ZOO = {
        "realesrgan": {
            2: (
                "RealESRGAN_x2plus.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                23,
            ),
            4: (
                "RealESRGAN_x4plus.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                23,
            ),
        },
        "realesrnet": {
            4: (
                "RealESRNet_x4plus.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
                23,
            ),
        },
        "anime": {
            4: (
                "RealESRGAN_x4plus_anime_6B.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                6,
            ),
        },
    }

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
    ):
        self.upscale = int(upscale)

        # ---------------- resolve checkpoint (validate before building nets) ----------------
        checkpoints = self._MODEL_ZOO.get(bg_upsampler_name)
        if checkpoints is None:
            raise ValueError(f"No model implemented for: {bg_upsampler_name}")
        try:
            file_name, url, num_block = checkpoints[self.upscale]
        except KeyError:
            raise ValueError(
                f"{bg_upsampler_name} model not available for upscaling x{self.upscale}"
            ) from None

        # NOTE(review): every checkpoint in _MODEL_ZOO appears to be an RRDBNet
        # checkpoint (per the upstream Real-ESRGAN inference script), so choosing
        # "SRVGGNetCompact" here will likely fail when the weights load - confirm.
        model = self.get_prefered_net(
            prefered_net_in_upsampler, self.upscale, num_block=num_block
        )

        # ---------------- load background upsampler (download on first use) ----------------
        weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", bg_upsampler_name, "weights"
        )
        model_path = os.path.join(weights_path, file_name)
        if not os.path.isfile(model_path):
            model_path = load_file_from_url(
                url=url, model_dir=weights_path, progress=True, file_name=None
            )
        self.bg_upsampler = RealESRGANer(
            scale=self.upscale,
            model_path=model_path,
            model=model,
            tile=0,
            tile_pad=0,
            pre_pad=0,
            half=False,
        )

    @staticmethod
    def get_prefered_net(prefered_net_in_upsampler, upscale=2, num_block=23):
        """Instantiate the backbone network for the upsampler (weights not loaded).

        Declared as a ``staticmethod`` so that ``self.get_prefered_net(name, scale)``
        binds the arguments correctly. The original plain ``def`` had no ``self``,
        so that call passed one positional argument too many.

        Args:
            prefered_net_in_upsampler: ``"RRDBNet"`` or ``"SRVGGNetCompact"``.
            upscale: Upscaling factor (coerced with ``int``).
            num_block: Number of RRDB blocks for ``RRDBNet``; ignored for
                ``SRVGGNetCompact``. The default of 23 matches the previously
                hard-coded value.

        Returns:
            The freshly constructed network.

        Raises:
            ValueError: If ``prefered_net_in_upsampler`` is not recognised.
        """
        if prefered_net_in_upsampler == "RRDBNet":
            return RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=num_block,
                num_grow_ch=32,
                scale=int(upscale),
            )
        if prefered_net_in_upsampler == "SRVGGNetCompact":
            return SRVGGNetCompact(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_conv=16,
                upscale=int(upscale),
                act_type="prelu",
            )
        raise ValueError(f"No net named: {prefered_net_in_upsampler} implemented!")


# Module-level alias kept for callers that used the helper as a plain function.
get_prefered_net = RealEsrUpsamplerZoo.get_prefered_net