Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy | |
| import os | |
| import random | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils.download_util import load_file_from_url | |
| from realesrgan import RealESRGANer | |
| from realesrgan.archs.srvgg_arch import SRVGGNetCompact | |
# ────────────────────────────────────────────────────────
# Globals
# ────────────────────────────────────────────────────────
# Path of the most recently saved output image; reset() removes this file.
last_file = None
# Color mode ("RGB" or "RGBA") last detected by image_properties().
img_mode = "RGBA"
# ────────────────────────────────────────────────────────
# Utilities
# ────────────────────────────────────────────────────────
def rnd_string(x: int) -> str:
    """Build a random identifier that is ``x`` characters long.

    Each character is drawn independently from the lowercase ASCII letters,
    the underscore and the decimal digits.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = []
    for _ in range(x):
        picked.append(random.choice(alphabet))
    return "".join(picked)
def reset():
    """Clear both image components and remove the last saved output file.

    Deleting the file is best-effort: a failure is printed and otherwise
    ignored, and the remembered path is forgotten either way.

    Returns:
        Two ``gr.update(value=None)`` objects, one for each image component
        wired to this callback.
    """
    global last_file
    stale_path = last_file
    if stale_path:
        try:
            print(f"Deleting {stale_path} ...")
            os.remove(stale_path)
        except Exception as err:
            print("Delete error:", err)
    last_file = None
    cleared_first = gr.update(value=None)
    cleared_second = gr.update(value=None)
    return cleared_first, cleared_second
def has_transparency(img):
    """Report whether a PIL image contains any transparency.

    Adapted from
    https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent

    Args:
        img: A PIL image (only ``info``, ``mode`` and ``getextrema()`` are
            used).

    Returns:
        True when ``img.info`` carries a ``transparency`` entry, or when the
        image has an alpha band (modes ``RGBA``, ``LA``, ``PA``) holding at
        least one value below 255. False otherwise.
    """
    # Palette / tRNS-style transparency is advertised through ``info``.
    # Once this check fails there is no transparent palette index to look
    # for, so no per-colour scan of "P" images is needed.
    if img.info.get("transparency") is not None:
        return True

    # For multi-band images getextrema() yields one (min, max) pair per band;
    # in these modes the alpha band is the last one.
    if img.mode in ("RGBA", "LA", "PA"):
        alpha_min = img.getextrema()[-1][0]
        return alpha_min < 255

    return False
def image_properties(img):
    """Summarize the input image and remember its color mode globally.

    The module-level ``img_mode`` is set to ``"RGBA"`` when the image has
    transparency and to ``"RGB"`` otherwise.

    Returns:
        A one-line description of the width, height and color mode, or
        ``None`` when no image is supplied.
    """
    global img_mode
    if not img:
        return None
    img_mode = "RGBA" if has_transparency(img) else "RGB"
    width = img.size[0]
    height = img.size[1]
    return f"Resolution: Width: {width}, Height: {height} | Color Mode: {img_mode}"
def model_tip_text(model_name: str) -> str:
    """Return human-friendly Markdown guidance for the chosen model.

    Args:
        model_name: Name of the selected upscaler model.

    Returns:
        A Markdown snippet describing the model, or an empty string when the
        name is not one of the known models.
    """
    # The multiplication signs and dashes below are intentional non-ASCII
    # characters (U+00D7, U+2014, U+2013); keep this file saved as UTF-8.
    tips = {
        "RealESRGAN_x4plus": (
            "**RealESRGAN_x4plus (4×)** — Best for photoreal images (portraits, landscapes). "
            "Balanced detail recovery. Good default for Flux realism."
        ),
        "RealESRNet_x4plus": (
            "**RealESRNet_x4plus (4×)** — Softer but great on noisy/compressed sources "
            "(old JPEGs, screenshots)."
        ),
        "RealESRGAN_x4plus_anime_6B": (
            "**RealESRGAN_x4plus_anime_6B (4×)** — For anime/illustrations/line art only. "
            "Not recommended for real-life photos."
        ),
        "RealESRGAN_x2plus": (
            "**RealESRGAN_x2plus (2×)** — Faster, lighter 2× cleanup when you don't need 4×."
        ),
        "realesr-general-x4v3": (
            "**realesr-general-x4v3 (4×)** — Versatile mixed-content model with adjustable denoise. "
            "**Denoise Strength** slider only affects this model (blends with the WDN variant). "
            "Try 0.3–0.5 for slightly cleaner, sharper results."
        ),
    }
    return tips.get(model_name, "")
# ────────────────────────────────────────────────────────
# Core upscaling
# ────────────────────────────────────────────────────────
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Restore and upscale an image with Real-ESRGAN.

    Args:
        img: Input PIL image, or ``None`` when nothing has been uploaded.
        model_name: One of ``RealESRGAN_x4plus``, ``RealESRNet_x4plus``,
            ``RealESRGAN_x4plus_anime_6B``, ``RealESRGAN_x2plus`` or
            ``realesr-general-x4v3``.
        denoise_strength: Value in ``[0, 1]`` (clamped). Only used by
            ``realesr-general-x4v3``: 1 uses the base weights only (strongest
            denoise), 0 uses the WDN weights only (keeps the most noise).
        face_enhance: When truthy, faces are restored with GFPGAN and the
            Real-ESRGAN upsampler is used for the background.
        outscale: Final upscaling factor; converted to ``int``.

    Returns:
        The result as an RGB or RGBA ``numpy.ndarray`` ready for display, or
        ``None`` when ``img`` is ``None`` or inference raises ``RuntimeError``.

    Raises:
        ValueError: If ``model_name`` is not a known model.

    Side effects:
        Downloads missing weights into a ``weights`` directory next to this
        file, writes the result to ``output_<random>.<png|jpg>`` in the
        current working directory and stores that path in ``last_file``.
    """
    global last_file

    if img is None:
        return None

    # ----- Select backbone + weights -----
    if model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        # Both checkpoints are needed for DNI; the order [base, WDN] matters.
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
        ]
    else:
        raise ValueError(f"Unknown model: {model_name}")

    # ----- Ensure weights are on disk -----
    root_dir = os.path.dirname(os.path.abspath(__file__))
    weights_dir = os.path.join(root_dir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)

    local_paths = []
    for url in file_url:
        local_path = os.path.join(weights_dir, os.path.basename(url))
        if not os.path.isfile(local_path):
            local_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True)
        local_paths.append(local_path)

    # ----- Model path(s) and DNI weights -----
    if model_name == 'realesr-general-x4v3':
        strength = min(max(float(denoise_strength), 0.0), 1.0)
        model_path = local_paths  # [base, WDN]
        # RealESRGANer blends dni_weight[0] * base + dni_weight[1] * WDN.
        # The WDN checkpoint is the weak-denoise one, so a stronger denoise
        # must put MORE weight on the base model. This matches the upstream
        # inference script: [denoise_strength, 1 - denoise_strength].
        dni_weight = [strength, 1.0 - strength]
    else:
        model_path = local_paths[0]
        dni_weight = None

    # ----- CUDA / precision / tiling -----
    # Inference runs on PyTorch, so PyTorch (not OpenCV's optional CUDA build)
    # decides whether a GPU is usable. Imported here because the module does
    # not import torch at file level; it is a dependency of realesrgan itself.
    import torch
    use_cuda = torch.cuda.is_available()
    gpu_id = 0 if use_cuda else None

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=256,  # Safe VRAM default; increase if you have headroom
        tile_pad=10,
        pre_pad=10,
        half=use_cuda,  # FP16 only makes sense on GPU
        gpu_id=gpu_id
    )

    # Use the same integer factor for both the plain and the GFPGAN path.
    scale_factor = int(outscale)

    # ----- Optional face enhancement -----
    face_enhancer = None
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=scale_factor,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler
        )

    # ----- Convert PIL -> cv2 (handle RGB/RGBA) -----
    cv_img = numpy.array(img)
    if cv_img.ndim == 3 and cv_img.shape[2] == 4:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
    else:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)

    # ----- Enhance -----
    try:
        if face_enhancer:
            _, _, output = face_enhancer.enhance(cv_img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(cv_img, outscale=scale_factor)
    except RuntimeError as error:
        print('Error', error)
        print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
        return None

    # ----- cv2 -> RGBA/RGB for Gradio, also save -----
    if output.ndim == 3 and output.shape[2] == 4:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
        extension = 'png'  # PNG keeps the alpha channel
    else:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        extension = 'jpg'

    out_filename = f"output_{rnd_string(8)}.{extension}"
    try:
        written = cv2.imwrite(out_filename, output)
    except Exception as e:
        print("Save error:", e)
    else:
        # imwrite can signal failure by returning False instead of raising;
        # only remember files that actually exist so reset() can remove them.
        if written:
            last_file = out_filename
        else:
            print("Save error:", f"could not write {out_filename}")

    return display_img  # ndarray so Gradio displays immediately
# ────────────────────────────────────────────────────────
# UI
# ────────────────────────────────────────────────────────
def main():
    """Build the Gradio Blocks interface, wire up its events and launch it."""
    default_model = "RealESRGAN_x4plus"  # photoreal default
    available_models = [
        "RealESRGAN_x4plus",
        "RealESRNet_x4plus",
        "RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x2plus",
        "realesr-general-x4v3",
    ]

    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        # --- Option controls ---
        with gr.Accordion("Upscaling options", open=True):
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    label="Upscaler model",
                    choices=available_models,
                    value=default_model,
                    show_label=True,
                )
                denoise_slider = gr.Slider(
                    label="Denoise Strength (only for realesr-general-x4v3)",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.5,
                )
                scale_slider = gr.Slider(
                    label="Resolution upscale",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=4,
                    show_label=True,
                )
                face_checkbox = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
            # Guidance text for the selected model; refreshed on dropdown change.
            tips_panel = gr.Markdown(model_tip_text(default_model))

        # --- Input / output images ---
        with gr.Row():
            with gr.Group():
                source_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                source_properties = gr.Textbox(label="Image Properties", max_lines=1)
            result_image = gr.Image(label="Output Image", image_mode="RGBA")

        # --- Action buttons ---
        with gr.Row():
            clear_button = gr.Button("Remove images")
            upscale_button = gr.Button("Upscale")

        # --- Event wiring ---
        source_image.change(fn=image_properties, inputs=source_image, outputs=source_properties)
        model_dropdown.change(fn=model_tip_text, inputs=model_dropdown, outputs=tips_panel)
        upscale_inputs = [source_image, model_dropdown, denoise_slider, face_checkbox, scale_slider]
        upscale_button.click(fn=realesrgan, inputs=upscale_inputs, outputs=result_image)
        clear_button.click(fn=reset, inputs=[], outputs=[result_image, source_image])

        gr.Markdown("")  # spacer

    demo.launch()
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    main()