# Image_Upscaler / app.py
# (Hugging Face file-viewer residue kept for provenance: "JS6969's picture",
#  "Update app.py", "356195b verified", "raw", "history blame", "24.3 kB")
# ────────────────────────────────────────────────────────
# TorchVision compat shim (MUST be before importing basicsr)
# Fixes: ModuleNotFoundError: torchvision.transforms.functional_tensor
# ────────────────────────────────────────────────────────
import sys, types
try:
    import torchvision.transforms.functional_tensor as _ft  # noqa: F401
except ImportError:
    # The module is missing in this torchvision build. Register a minimal
    # stand-in under the old dotted name so that later imports of
    # `torchvision.transforms.functional_tensor` resolve. Only
    # `rgb_to_grayscale` is provided, taken from `transforms.functional`.
    from torchvision.transforms import functional as _F

    _mod = types.ModuleType("torchvision.transforms.functional_tensor")
    _mod.rgb_to_grayscale = _F.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = _mod
# ────────────────────────────────────────────────────────
# Spaces ZeroGPU decorator (safe no-op locally)
# ────────────────────────────────────────────────────────
try:
    import spaces

    GPU = spaces.GPU
except Exception:  # deliberate best-effort: any failure means "no ZeroGPU here"
    def GPU(*args, **kwargs):
        """No-op stand-in for ``spaces.GPU`` when the package is unavailable.

        Works both as a bare decorator (``@GPU``) and as a decorator factory
        (``@GPU()`` / ``@GPU(duration=...)``); the decorated function is
        returned unchanged in either form.
        """
        # Bare usage: the decorated function arrives as the only positional arg.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def _wrap(f):
            return f

        return _wrap
# ────────────────────────────────────────────────────────
# Standard imports
# ────────────────────────────────────────────────────────
import gradio as gr
import cv2
import numpy
import os
import random
import inspect
from pathlib import Path
import zipfile
import tempfile
from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# ────────────────────────────────────────────────────────
# Globals
# ────────────────────────────────────────────────────────
# Path of the most recent output file written by realesrgan(); reset() deletes
# that file and sets this back to None.
last_file = None
# Color mode string ("RGBA" or "RGB") last computed by image_properties().
img_mode = "RGBA"
# ────────────────────────────────────────────────────────
# Utilities
# ────────────────────────────────────────────────────────
def rnd_string(x: int) -> str:
    """Return ``x`` random characters drawn from lowercase letters, ``_`` and digits."""
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = []
    for _ in range(x):
        picked.append(random.choice(alphabet))
    return "".join(picked)
def reset():
    """Delete the last saved output file, if any, and clear two image components.

    Returns:
        A pair of ``gr.update(value=None)`` objects, one per output component
        wired to this callback.
    """
    global last_file
    stale_path = last_file
    if stale_path:
        try:
            print(f"Deleting {stale_path} ...")
            os.remove(stale_path)
        except Exception as e:
            # Best effort: a failed delete must not block clearing the UI.
            print("Delete error:", e)
        last_file = None
    first_clear = gr.update(value=None)
    second_clear = gr.update(value=None)
    return first_clear, second_clear
def has_transparency(img):
    """Return True when ``img`` carries transparency information.

    Checks, in order: a non-None ``"transparency"`` entry in ``img.info``;
    for mode ``"P"``, whether any entry from ``img.getcolors()`` matches the
    transparency marker; for mode ``"RGBA"``, whether the minimum of the
    fourth band is below 255. Any other case yields False.
    """
    meta = img.info
    if meta.get("transparency", None) is not None:
        return True

    mode = img.mode
    if mode == "P":
        marker = meta.get("transparency", -1)
        return any(index == marker for _, index in img.getcolors())
    if mode == "RGBA":
        alpha_min = img.getextrema()[3][0]
        return alpha_min < 255
    return False
def image_properties(img):
    """Describe the image's size and color mode; returns None for a falsy ``img``.

    Also stores the detected mode ("RGBA" when ``has_transparency`` reports
    transparency, otherwise "RGB") in the module-level ``img_mode``.
    """
    global img_mode
    if not img:
        return None
    img_mode = "RGBA" if has_transparency(img) else "RGB"
    width = img.size[0]
    height = img.size[1]
    return f"Resolution: Width: {width}, Height: {height} | Color Mode: {img_mode}"
def model_tip_text(model_name: str) -> str:
    """Return the Markdown usage tip for ``model_name``.

    Args:
        model_name: One of the model identifiers offered in the UI dropdown.

    Returns:
        The tip text, or an empty string for an unknown model name.
    """
    # The multiplication signs and dashes below were previously mis-encoded
    # (UTF-8 bytes decoded with the wrong codec) and showed up as garbage in
    # the UI.
    tips = {
        "RealESRGAN_x4plus": (
            "**RealESRGAN_x4plus (4×)** — Best for photoreal images (portraits, landscapes). "
            "Balanced detail recovery. Good default for Flux realism."
        ),
        "RealESRNet_x4plus": (
            "**RealESRNet_x4plus (4×)** — Softer but great on noisy/compressed sources "
            "(old JPEGs, screenshots)."
        ),
        "RealESRGAN_x4plus_anime_6B": (
            "**RealESRGAN_x4plus_anime_6B (4×)** — For anime/illustrations/line art only. "
            "Not recommended for real-life photos."
        ),
        "RealESRGAN_x2plus": (
            "**RealESRGAN_x2plus (2×)** — Faster, lighter 2× cleanup when you don't need 4×."
        ),
        "realesr-general-x4v3": (
            "**realesr-general-x4v3 (4×)** — Versatile mixed-content model with adjustable denoise. "
            "**Denoise Strength** slider only affects this model (blends with the WDN variant). "
            "Try 0.3–0.5 for slightly cleaner, sharper results."
        ),
    }
    return tips.get(model_name, "")
# ────────────────────────────────────────────────────────
# RRDBNet builder that tolerates different Basicsr signatures
# ────────────────────────────────────────────────────────
def build_rrdb(scale: int, num_block: int):
    """Create an RRDBNet, tolerating several constructor signatures.

    The constructor is tried, in order, with:
      1) keyword style (num_in_ch/num_out_ch/num_feat/num_block/num_grow_ch/scale)
      2) alt keyword style (in_nc/out_nc/nf/nb/gc/sf)
      3) positional with gc before scale
      4) positional with scale before gc

    Args:
        scale: Upscaling factor passed to the network.
        num_block: Number of blocks passed to the network.

    Returns:
        The first successfully constructed ``_RRDBNet`` instance.

    Raises:
        TypeError: If every attempted signature is rejected with a TypeError;
            the last failure is attached as the cause.
    """
    # Each entry defers construction so a TypeError from one calling
    # convention simply moves on to the next.
    attempts = (
        lambda: _RRDBNet(
            num_in_ch=3, num_out_ch=3,
            num_feat=64, num_block=num_block,
            num_grow_ch=32, scale=scale,
        ),
        lambda: _RRDBNet(
            in_nc=3, out_nc=3,
            nf=64, nb=num_block,
            gc=32, sf=scale,
        ),
        lambda: _RRDBNet(3, 3, 64, num_block, 32, scale),
        lambda: _RRDBNet(3, 3, 64, num_block, scale, 32),
    )
    last_error = None
    for construct in attempts:
        try:
            return construct()
        except TypeError as err:
            last_error = err
    raise TypeError(f"RRDBNet signature not recognized: {last_error}") from last_error
#Factor an upsampler builder
def _select_model(model_name: str):
    """Return ``(model, netscale, weight_urls)`` for a supported model name.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    if model_name == 'RealESRGAN_x4plus':
        return build_rrdb(scale=4, num_block=23), 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        ]
    if model_name == 'RealESRNet_x4plus':
        return build_rrdb(scale=4, num_block=23), 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth',
        ]
    if model_name == 'RealESRGAN_x4plus_anime_6B':
        return build_rrdb(scale=4, num_block=6), 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
        ]
    if model_name == 'RealESRGAN_x2plus':
        return build_rrdb(scale=2, num_block=23), 2, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        ]
    if model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'
        )
        # Order matters: [base, WDN] must line up with the dni_weight pair.
        return model, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
        ]
    raise ValueError(f"Unknown model: {model_name}")


def _ensure_weights(weight_urls):
    """Download missing weight files into ``weights/`` next to this file.

    Returns the local paths, in the same order as ``weight_urls``.
    """
    root_dir = os.path.dirname(os.path.abspath(__file__))
    weights_dir = os.path.join(root_dir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)
    local_paths = []
    for url in weight_urls:
        local_path = os.path.join(weights_dir, os.path.basename(url))
        if not os.path.isfile(local_path):
            load_file_from_url(url=url, model_dir=weights_dir, progress=True)
        local_paths.append(local_path)
    return local_paths


def _cuda_available() -> bool:
    """Report whether OpenCV sees at least one CUDA-enabled device."""
    # NOTE(review): this probes OpenCV's CUDA support, not PyTorch's. Confirm
    # whether torch.cuda.is_available() was the intended check for `half`.
    try:
        return bool(hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0)
    except Exception:
        return False


def _build_face_enhancer(upsampler, outscale):
    """Create a GFPGANer that uses ``upsampler`` for the background."""
    from gfpgan import GFPGANer  # imported lazily; only needed when enabled

    return GFPGANer(
        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
        upscale=outscale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )


def _enhance(bgr_img, upsampler, face_enhancer, outscale):
    """Run the face enhancer if one is given, else the plain upsampler."""
    if face_enhancer:
        _, _, output = face_enhancer.enhance(
            bgr_img, has_aligned=False, only_center_face=False, paste_back=True
        )
    else:
        output, _ = upsampler.enhance(bgr_img, outscale=int(outscale))
    return output


def _unique_output_path(out_dir, stem, ext):
    """Return ``out_dir/stem+ext``, adding ``_1``, ``_2``, ... if it already exists."""
    candidate = out_dir / (stem + ext)
    counter = 1
    while candidate.exists():
        candidate = out_dir / f"{stem}_{counter}{ext}"
        counter += 1
    return candidate


def _sample_even(seq, n=30):
    """Return up to ``n`` evenly spaced items of ``seq`` as strings, without duplicates."""
    if not seq or n <= 0:
        return []
    if len(seq) <= n:
        return [str(p) for p in seq]
    if n == 1:
        return [str(seq[0])]
    step = (len(seq) - 1) / (n - 1)
    # dict.fromkeys keeps first-seen order while dropping repeated indices.
    indices = dict.fromkeys(round(i * step) for i in range(n))
    return [str(seq[i]) for i in indices]


def get_upsampler(model_name: str, outscale: int, tile: int = 256, denoise_strength: float = 0.5):
    """Build a ready-to-use ``RealESRGANer`` for ``model_name``.

    Args:
        model_name: Supported model identifier (see ``_select_model``).
        outscale: Accepted for call-site compatibility; not used here (the
            output scale is applied when calling ``enhance``).
        tile: Tile size; a falsy value (0/None) falls back to 256.
        denoise_strength: Only used for ``realesr-general-x4v3``, where the
            base and WDN weights are blended as
            ``[1 - denoise_strength, denoise_strength]``. Clamped to [0, 1].

    Returns:
        ``(upsampler, netscale, use_cuda, model_path)``; ``model_path`` is a
        two-item list for ``realesr-general-x4v3`` and a single path otherwise.

    Raises:
        ValueError: If ``model_name`` is unknown.
    """
    model, netscale, weight_urls = _select_model(model_name)
    local_paths = _ensure_weights(weight_urls)

    if model_name == 'realesr-general-x4v3':
        # When model_path is a list, RealESRGANer blends the checkpoints in
        # its constructor and needs a matching dni_weight list. Passing None
        # here (as before) failed at construction time.
        strength = min(1.0, max(0.0, float(denoise_strength)))
        model_path = local_paths
        dni_weight = [1.0 - strength, strength]  # base, WDN
    else:
        model_path = local_paths[0]
        dni_weight = None

    use_cuda = _cuda_available()
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=int(tile) if tile else 256,  # VRAM-safe default; lower to 128 if OOM
        tile_pad=10,
        pre_pad=10,
        half=bool(use_cuda),
        gpu_id=0 if use_cuda else None,
    )
    return upsampler, netscale, use_cuda, model_path


# ────────────────────────────────────────────────────────
# Core upscaling
# Decorated for Hugging Face Spaces ZeroGPU
# ────────────────────────────────────────────────────────
@GPU()  # lets Spaces know this function uses GPU; safe no-op locally
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Upscale a single PIL image and return it as an RGB(A) ndarray.

    The result is also written to ``output_<random>.<png|jpg>`` in the current
    working directory, and that path is remembered in ``last_file``.

    Returns:
        The upscaled image as an ndarray for display, or None when ``img`` is
        None or enhancement raised a RuntimeError.

    Raises:
        ValueError: If ``model_name`` is unknown.
    """
    global last_file
    if img is None:
        return None

    upsampler, _, _, _ = get_upsampler(
        model_name, outscale, tile=256, denoise_strength=denoise_strength
    )
    face_enhancer = _build_face_enhancer(upsampler, outscale) if face_enhance else None

    # ----- PIL -> cv2 -----
    cv_img = numpy.array(img)
    if cv_img.ndim == 3 and cv_img.shape[2] == 4:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
    else:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)

    # ----- Enhance -----
    try:
        output = _enhance(cv_img, upsampler, face_enhancer, outscale)
    except RuntimeError as error:
        print('Error', error)
        print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
        return None

    # ----- cv2 -> display ndarray, also save -----
    if output.ndim == 3 and output.shape[2] == 4:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
        extension = 'png'
    else:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        extension = 'jpg'

    out_filename = f"output_{rnd_string(8)}.{extension}"
    try:
        # cv2.imwrite signals most failures by returning False, not raising.
        if cv2.imwrite(out_filename, output):
            last_file = out_filename
        else:
            print("Save error:", f"could not write {out_filename}")
    except Exception as e:
        print("Save error:", e)
    return display_img


def render_progress(pct: float, text: str = "") -> str:
    """Return an HTML progress bar plus label; ``pct`` is clamped to 0..100."""
    pct = max(0.0, min(100.0, float(pct)))
    bar = f"<div style='width:100%;border:1px solid #ddd;border-radius:6px;overflow:hidden;height:12px;'><div style='height:100%;width:{pct:.1f}%;background:#3b82f6;'></div></div>"
    label = f"<div style='font-size:12px;opacity:.8;margin-top:4px;'>{text} {pct:.1f}%</div>"
    return bar + label


def batch_realesrgan(
    files: list,  # from gr.Files (type='filepath')
    model_name: str,
    denoise_strength: float,
    face_enhance: bool,
    outscale: int,
    tile: int,
    batch_size: int = 16,
):
    """Upscale many images, streaming progress as a generator.

    Each yield is ``(gallery, zip_file, details, progress_html)``. While work
    is in progress the first two items are None; the final yield carries the
    sampled gallery paths and the path of a zip holding every output. Outputs
    keep the source file name; a numeric suffix is added only when two inputs
    would otherwise map to the same output name.
    """
    from PIL import Image  # hoisted out of the per-file loop

    if not files:
        yield None, None, "No files uploaded.", render_progress(0, "Idle")
        return

    # UI number fields may deliver None, floats or 0; normalize once so the
    # loop step and the progress message agree (and never divide by zero).
    try:
        per_batch = max(1, int(batch_size))
    except (TypeError, ValueError):
        per_batch = 16

    # Build upsampler once (much faster than per-image). The denoise blend is
    # applied here because it cannot be changed on the instance afterwards.
    upsampler, _, _, _ = get_upsampler(
        model_name, outscale, tile=tile, denoise_strength=denoise_strength
    )
    face_enhancer = _build_face_enhancer(upsampler, outscale) if face_enhance else None

    # Prepare work/output dirs
    work = Path(tempfile.mkdtemp(prefix="batch_up_"))
    out_dir = work / "upscaled"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Normalize list of input paths
    src_paths = [Path(f.name if hasattr(f, "name") else f) for f in files]
    total = len(src_paths)
    total_batches = (total + per_batch - 1) // per_batch
    done = 0
    out_paths = []

    for start in range(0, total, per_batch):
        batch_no = start // per_batch + 1
        for src in src_paths[start:start + per_batch]:
            try:
                # Load as RGB consistently
                with Image.open(src) as im:
                    arr = numpy.array(im.convert("RGB"))
                arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
                output = _enhance(arr, upsampler, face_enhancer, outscale)

                # Preserve original file name & (reasonable) extension
                orig_ext = src.suffix.lower()
                ext = orig_ext if orig_ext in (".png", ".jpg", ".jpeg") else ".png"
                if output.ndim == 3 and output.shape[2] == 4:
                    ext = ".png"  # 4ch → PNG (keeps alpha)
                out_path = _unique_output_path(out_dir, src.stem, ext)

                if ext in (".jpg", ".jpeg"):
                    cv2.imwrite(str(out_path), output, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
                else:
                    cv2.imwrite(str(out_path), output)  # PNG default
                out_paths.append(out_path)
            except Exception as e:
                # Continue on errors so one bad file does not abort the batch
                print(f"[batch] Error on {src}: {e}")

            done += 1
            pct = (done / total) * 100.0 if total else 0.0
            remaining = max(0, total - done)
            msg = f"Upscaling… {done}/{total} done · {remaining} remaining (batch {batch_no}/{total_batches})"
            yield None, None, msg, render_progress(pct, msg)

    if not out_paths:
        yield None, None, "No outputs produced.", render_progress(100, "Finished")
        return

    out_paths = sorted(out_paths)  # stable
    # Small even-sampled gallery for preview
    gallery = _sample_even(out_paths, 30)

    # Zip with same file names
    zip_path = work / "upscaled.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for p in out_paths:
            zf.write(p, arcname=p.name)

    details = f"Upscaled {len(out_paths)} images → {out_dir}"
    yield gallery, str(zip_path), details, render_progress(100, "Complete")
# ────────────────────────────────────────────────────────
# UI
# ────────────────────────────────────────────────────────
def main():
    """Build the Gradio UI and return the Blocks app without launching it.

    The file previously defined ``main`` twice: a full UI builder that
    launched the app itself and returned nothing, followed by a placeholder
    stub that shadowed it. This single definition keeps the full UI and
    returns ``demo`` so the ``__main__`` guard can queue and launch it.

    NOTE(review): the nesting of components inside rows/accordions is
    reconstructed; confirm the layout against the intended design.
    """
    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        with gr.Accordion("Upscaling options", open=True):
            with gr.Row():
                model_name = gr.Dropdown(
                    label="Upscaler model",
                    choices=[
                        "RealESRGAN_x4plus",
                        "RealESRNet_x4plus",
                        "RealESRGAN_x4plus_anime_6B",
                        "RealESRGAN_x2plus",
                        "realesr-general-x4v3",
                    ],
                    value="RealESRGAN_x4plus",
                    show_label=True
                )
                denoise_strength = gr.Slider(
                    label="Denoise Strength (only for realesr-general-x4v3)",
                    minimum=0, maximum=1, step=0.1, value=0.5
                )
                outscale = gr.Slider(
                    label="Resolution upscale",
                    minimum=1, maximum=6, step=1, value=4, show_label=True
                )
                face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
            model_tips = gr.Markdown(model_tip_text("RealESRGAN_x4plus"))

        with gr.Row():
            with gr.Group():
                input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
            output_image = gr.Image(label="Output Image", image_mode="RGBA")

        with gr.Row():
            reset_btn = gr.Button("Remove images")
            restore_btn = gr.Button("Upscale")

        input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
        model_name.change(fn=model_tip_text, inputs=model_name, outputs=model_tips)
        restore_btn.click(
            fn=realesrgan,
            inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
            outputs=output_image
        )
        reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])

        # --- Batch Upscale (multi-image) ---
        gr.Markdown("### Batch Upscale")
        with gr.Accordion("Batch options", open=True):
            with gr.Row():
                batch_files = gr.Files(
                    label="Upload multiple images (PNG/JPG/JPEG)",
                    type="filepath",
                    file_types=[".png", ".jpg", ".jpeg"],
                )
            with gr.Row():
                batch_tile = gr.Number(label="Tile size (0/auto → 256)", value=256, precision=0)
                batch_size = gr.Number(label="Batch size (images per batch)", value=16, precision=0)
            with gr.Row():
                batch_btn = gr.Button("Upscale Batch", variant="primary")

        batch_prog = gr.HTML(render_progress(0.0, "Idle"))
        batch_gallery = gr.Gallery(label="Preview (sampled 30)", columns=6, height=420)
        batch_zip = gr.File(label="Download upscaled.zip")
        batch_details = gr.Markdown("")

        # Wire it up (generator → streaming)
        batch_btn.click(
            fn=batch_realesrgan,
            inputs=[batch_files, model_name, denoise_strength, face_enhance, outscale, batch_tile, batch_size],
            outputs=[batch_gallery, batch_zip, batch_details, batch_prog],
        )
        gr.Markdown("")  # spacer

    return demo


if __name__ == "__main__":
    demo = main()
    # Disable SSR (ZeroGPU + Gradio logs suggested turning this off);
    # set share=True for a public link.
    demo.queue().launch(ssr_mode=False)