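# Image_Upscaler / app.py
# Hugging Face file-viewer header captured with the source (not part of the program):
#   JS6969's picture | Update app.py | 8c30d17 verified | raw | history blame | 12.4 kB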
import gradio as gr
import cv2
import numpy
import os
import random
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# ────────────────────────────────────────────────────────
# Globals
# ────────────────────────────────────────────────────────
# Path of the most recent output image written by realesrgan(); reset()
# deletes that file and sets this back to None.
last_file = None
# Color-mode label ("RGBA" or "RGB") that image_properties() stores for the
# current input image; it starts out as "RGBA".
img_mode = "RGBA"
# ────────────────────────────────────────────────────────
# Utilities
# ────────────────────────────────────────────────────────
def rnd_string(x: int) -> str:
    """Build a random identifier that is ``x`` characters long.

    Each character is drawn independently (with replacement) from the
    lowercase ASCII letters, the underscore and the decimal digits.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = []
    for _ in range(x):
        # One random.choice call per character, in order.
        picked.append(random.choice(alphabet))
    return "".join(picked)
def reset():
    """Clear the image components and remove the last saved output file.

    When ``last_file`` holds a path, the file is deleted on a best-effort
    basis (a failure is only printed) and the global is set back to ``None``.

    Returns:
        A pair of ``gr.update(value=None)`` objects, one for each image
        component wired to the reset button.
    """
    global last_file
    if last_file:
        try:
            print(f"Deleting {last_file} ...")
            os.remove(last_file)
        except Exception as err:
            # Deliberately non-fatal: the UI should still be cleared.
            print("Delete error:", err)
        last_file = None
    first_cleared = gr.update(value=None)
    second_cleared = gr.update(value=None)
    return first_cleared, second_cleared
def has_transparency(img):
    """Report whether a PIL image carries any transparency.

    Adapted from
    https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent

    Args:
        img: A PIL image (anything exposing ``info``, ``mode`` and
            ``getextrema()`` works).

    Returns:
        ``True`` if the image declares a ``transparency`` entry in its info
        dict, or if it has an alpha band (modes ``RGBA``, ``LA``, ``PA``)
        whose minimum value is below 255; ``False`` otherwise.
    """
    # Palette/tRNS-style transparency is advertised through the info dict.
    # A palette image lacking that entry has no designated transparent index
    # to look for, so no palette scan is needed after this check.
    if img.info.get("transparency") is not None:
        return True

    # Modes whose last band is an alpha channel: any value below 255 in that
    # band means at least one pixel is not fully opaque.
    if img.mode in ("RGBA", "LA", "PA"):
        alpha_min = img.getextrema()[-1][0]
        return alpha_min < 255

    return False
def image_properties(img):
    """Describe the input image and remember its color mode.

    Sets the module-level ``img_mode`` to ``"RGBA"`` when the image has
    transparency and to ``"RGB"`` otherwise.

    Returns:
        A one-line summary with width, height and color mode, or ``None``
        when no image is supplied.
    """
    global img_mode
    if not img:
        return None
    img_mode = "RGBA" if has_transparency(img) else "RGB"
    return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
def model_tip_text(model_name: str) -> str:
    """Return human-friendly Markdown guidance for the chosen model.

    Args:
        model_name: Name of the upscaler model as listed in the UI dropdown.

    Returns:
        A short Markdown tip for the model, or an empty string when the
        name is not recognised.
    """
    # The multiplication signs and dashes were stored as mis-decoded UTF-8
    # ("Γ—", "β€”"); they are written here as the intended characters.
    tips = {
        "RealESRGAN_x4plus": (
            "**RealESRGAN_x4plus (4×)** — Best for photoreal images (portraits, landscapes). "
            "Balanced detail recovery. Good default for Flux realism."
        ),
        "RealESRNet_x4plus": (
            "**RealESRNet_x4plus (4×)** — Softer but great on noisy/compressed sources "
            "(old JPEGs, screenshots)."
        ),
        "RealESRGAN_x4plus_anime_6B": (
            "**RealESRGAN_x4plus_anime_6B (4×)** — For anime/illustrations/line art only. "
            "Not recommended for real-life photos."
        ),
        "RealESRGAN_x2plus": (
            "**RealESRGAN_x2plus (2×)** — Faster, lighter 2× cleanup when you don't need 4×."
        ),
        "realesr-general-x4v3": (
            "**realesr-general-x4v3 (4×)** — Versatile mixed-content model with adjustable denoise. "
            "**Denoise Strength** slider only affects this model (blends with the WDN variant). "
            "Try 0.3–0.5 for slightly cleaner, sharper results."
        ),
    }
    return tips.get(model_name, "")
# ────────────────────────────────────────────────────────
# Core upscaling
# ────────────────────────────────────────────────────────
def _select_model(model_name):
    """Build the network for *model_name* and list its weight files.

    Returns:
        A ``(model, netscale, file_url)`` tuple: the network object, its
        native upscaling factor and the list of weight URLs. For
        ``realesr-general-x4v3`` the list is ``[base, wdn]`` in that order.

    Raises:
        ValueError: If *model_name* is not one of the supported models.
    """
    if model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        # Both base and WDN weights are needed; order matters for DNI.
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
        ]
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model, netscale, file_url


def _ensure_weights(file_url, weights_dir):
    """Download any weight file missing from *weights_dir*; return local paths in URL order."""
    os.makedirs(weights_dir, exist_ok=True)
    local_paths = []
    for url in file_url:
        local_path = os.path.join(weights_dir, os.path.basename(url))
        if not os.path.isfile(local_path):
            local_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True)
        local_paths.append(local_path)
    return local_paths


def _torch_cuda_available():
    """Return True when PyTorch reports a usable CUDA device."""
    try:
        # Local import: this file does not import torch at module level.
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Restore and upscale an image with Real-ESRGAN.

    Args:
        img: Input PIL image, or ``None`` when nothing has been uploaded.
        model_name: One of the supported Real-ESRGAN model names.
        denoise_strength: Value in [0, 1]; only used by
            ``realesr-general-x4v3``, where it blends the base weights with
            the WDN weights.
        face_enhance: When truthy, run GFPGAN face restoration with the
            Real-ESRGAN upsampler as background upsampler.
        outscale: Requested final upscaling factor.

    Returns:
        The result as an RGB or RGBA ``numpy`` array, or ``None`` when no
        image was supplied or inference failed with a ``RuntimeError``.
        As a side effect the result is written to the working directory
        and its path stored in the global ``last_file``.

    Raises:
        ValueError: If *model_name* is unknown.
    """
    global last_file

    if img is None:
        return None

    # ----- Select backbone + weights -----
    model, netscale, file_url = _select_model(model_name)

    # ----- Ensure weights are on disk -----
    root_dir = os.path.dirname(os.path.abspath(__file__))
    local_paths = _ensure_weights(file_url, os.path.join(root_dir, 'weights'))

    if model_name == 'realesr-general-x4v3':
        # local_paths is [base, wdn]. RealESRGANer interpolates the two
        # checkpoints as dni_weight[0] * base + dni_weight[1] * wdn.
        # Upstream Real-ESRGAN's inference script gives the base model
        # `denoise_strength` and the WDN model the remainder (0 keeps noise,
        # 1 denoises strongly); the same mapping is used here.
        model_path = local_paths
        denoise_strength = float(denoise_strength)
        dni_weight = [denoise_strength, 1.0 - denoise_strength]
    else:
        model_path = local_paths[0]
        dni_weight = None

    # ----- CUDA / precision / tiling -----
    # RealESRGANer runs on PyTorch, so ask PyTorch (not OpenCV) about CUDA.
    use_cuda = _torch_cuda_available()
    gpu_id = 0 if use_cuda else None

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=256,  # Safe VRAM default; increase if you have headroom
        tile_pad=10,
        pre_pad=10,
        half=use_cuda,  # FP16 on GPU
        gpu_id=gpu_id,
    )

    # ----- Optional face enhancement -----
    face_enhancer = None
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler,
        )

    # ----- Convert PIL -> cv2 (handle RGB/RGBA) -----
    cv_img = numpy.array(img)
    if cv_img.ndim == 3 and cv_img.shape[2] == 4:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
    else:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)

    # ----- Enhance -----
    try:
        if face_enhancer:
            _, _, output = face_enhancer.enhance(
                cv_img, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = upsampler.enhance(cv_img, outscale=int(outscale))
    except RuntimeError as error:
        print('Error', error)
        print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
        return None

    # ----- cv2 -> RGBA/RGB for Gradio, also save -----
    if output.ndim == 3 and output.shape[2] == 4:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
        extension = 'png'  # keeps the alpha channel
    else:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        extension = 'jpg'

    out_filename = f"output_{rnd_string(8)}.{extension}"
    try:
        saved = cv2.imwrite(out_filename, output)
    except Exception as e:
        # Saving is best-effort: the image is still returned for display.
        print("Save error:", e)
    else:
        if saved:
            last_file = out_filename
        else:
            # imwrite signals failure through its return value; only record
            # the path when a file was actually written.
            print("Save error:", f"cv2.imwrite could not write {out_filename}")

    return display_img  # ndarray so Gradio displays immediately
# ────────────────────────────────────────────────────────
# UI
# ────────────────────────────────────────────────────────
def main():
    """Assemble the Gradio Blocks interface, wire its events and launch it."""
    # NOTE(review): the widget nesting below was reconstructed from a source
    # whose indentation was lost; confirm the tips panel belongs inside the
    # accordion and the output image beside the input group.
    default_model = "RealESRGAN_x4plus"  # photoreal default
    available_models = [
        "RealESRGAN_x4plus",
        "RealESRNet_x4plus",
        "RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x2plus",
        "realesr-general-x4v3",
    ]

    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        with gr.Accordion("Upscaling options", open=True):
            with gr.Row():
                model_picker = gr.Dropdown(
                    label="Upscaler model",
                    choices=available_models,
                    value=default_model,
                    show_label=True,
                )
                denoise_slider = gr.Slider(
                    label="Denoise Strength (only for realesr-general-x4v3)",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.5,
                )
                scale_slider = gr.Slider(
                    label="Resolution upscale",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=4,
                    show_label=True,
                )
                face_checkbox = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
            # Model tips panel (refreshed whenever the dropdown changes)
            tips_panel = gr.Markdown(model_tip_text("RealESRGAN_x4plus"))

        with gr.Row():
            with gr.Group():
                source_view = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                source_info = gr.Textbox(label="Image Properties", max_lines=1)
            result_view = gr.Image(label="Output Image", image_mode="RGBA")

        with gr.Row():
            clear_button = gr.Button("Remove images")
            upscale_button = gr.Button("Upscale")

        # Event wiring
        source_view.change(fn=image_properties, inputs=source_view, outputs=source_info)
        model_picker.change(fn=model_tip_text, inputs=model_picker, outputs=tips_panel)
        upscale_button.click(
            fn=realesrgan,
            inputs=[source_view, model_picker, denoise_slider, face_checkbox, scale_slider],
            outputs=result_view,
        )
        clear_button.click(fn=reset, inputs=[], outputs=[result_view, source_view])

        gr.Markdown("")  # spacer

    demo.launch()


if __name__ == "__main__":
    main()