Spaces: Running on Zero
| import torch | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import os.path as osp | |
| import time | |
| import gradio as gr | |
| os.environ["GRADIO_TEMP_DIR"] = "./gradio_tmp" | |
| from models.TextEnhancement import MARCONetPlus | |
| from utils.utils_image import imread_uint, uint2tensor4, tensor2uint | |
| from networks.rrdbnet2_arch import RRDBNet as BSRGAN | |
# Use the default CUDA device when PyTorch reports one is available; otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# BSRGAN background model: starts as None and is filled in lazily by load_bg_model().
BGModel = None
def load_bg_model():
    """Lazily build the BSRGAN background super-resolution model.

    On the first call this constructs a 2x RRDBNet, copies the weights from
    ``./checkpoints/bsrgan_bg.pth`` into it, freezes all parameters, moves it
    to ``device`` and stores it in the module-level ``BGModel``. Later calls
    return the cached model without reloading.

    Returns:
        The loaded model (the same object as the global ``BGModel``).

    Raises:
        RuntimeError: if the checkpoint's tensor count differs from the
            model's state dict, so the positional weight mapping cannot be
            trusted. Errors from ``torch.load`` / ``load_state_dict``
            propagate unchanged.
    """
    global BGModel
    if BGModel is not None:
        return BGModel

    model = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
    checkpoint = torch.load('./checkpoints/bsrgan_bg.pth', map_location=device)
    target = model.state_dict()

    # Weights are matched by position rather than by key name (presumably the
    # checkpoint's key names don't match this architecture's). A count
    # mismatch would make zip() silently drop the tail and leave those
    # tensors randomly initialised, which strict=True cannot detect.
    if len(checkpoint) != len(target):
        raise RuntimeError(
            f"BSRGAN checkpoint has {len(checkpoint)} tensors but the model "
            f"expects {len(target)}; refusing positional weight mapping"
        )
    remapped = {name: tensor for name, tensor in zip(target.keys(), checkpoint.values())}
    model.load_state_dict(remapped, strict=True)

    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    # Publish to the global only once fully loaded. A failed load is then
    # retried on the next call instead of caching a half-initialised model.
    BGModel = model.to(device)
    return BGModel
# Text restoration model (MARCONetPlus), constructed eagerly at import time on
# `device`. gradio_inference() calls its handle_texts() method.
# The four checkpoint paths are passed positionally, so their order must match
# MARCONetPlus's constructor signature (not visible in this file). Judging by
# the file names they are presumably: W-encoder, prior, SR network, and a
# YOLO11 character detector. TODO confirm against models/TextEnhancement.py.
TextModel = MARCONetPlus(
    './checkpoints/net_w_encoder_860000.pth',
    './checkpoints/net_prior_860000.pth',
    './checkpoints/net_sr_860000.pth',
    './checkpoints/yolo11m_short_character.pt',
    device=device
)
def _floor_to_multiple(value, multiple=8):
    """Round *value* down to a multiple of *multiple*, but never below *multiple* itself."""
    # Clamping avoids a zero-sized cv2.resize target for very small images.
    return max(multiple, int(value) // multiple * multiple)


def _to_bgr(input_img):
    """Convert a PIL image (or an RGB/RGBA/grayscale uint8 array) to a 3-channel BGR array."""
    if hasattr(input_img, "convert"):
        # PIL image: normalise any mode (RGBA, L, P, ...) to RGB before the channel swap.
        input_img = input_img.convert("RGB")
    arr = np.array(input_img)
    if arr.ndim == 2:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    if arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)


def _bg_model_input(img_bgr, width, height):
    """Resize *img_bgr* to (width, height) snapped down to multiples of 8 and return it as a tensor on `device`."""
    # NOTE(review): the multiple-of-8 sizing is carried over from the original
    # code; presumably a size constraint of the BSRGAN network. Confirm.
    dsize = (_floor_to_multiple(width), _floor_to_multiple(height))
    resized = cv2.resize(img_bgr, dsize, interpolation=cv2.INTER_AREA)
    return uint2tensor4(resized).to(device)


def _super_resolve_background(img_bgr, width, height, max_side=1536):
    """Run the BSRGAN background model, retrying once on a size-capped copy if the first pass fails."""
    load_bg_model()
    with torch.no_grad():
        try:
            out = BGModel(_bg_model_input(img_bgr, width, height))
        except RuntimeError:
            # Typically out-of-memory (torch.cuda.OutOfMemoryError subclasses
            # RuntimeError). Catching RuntimeError rather than using a bare
            # except lets KeyboardInterrupt/SystemExit propagate. Free cached
            # CUDA blocks and retry with neither side larger than max_side.
            torch.cuda.empty_cache()
            scale = min(max_side / width, max_side / height, 1.0)
            out = BGModel(_bg_model_input(img_bgr, int(width * scale), int(height * scale)))
    return tensor2uint(out)


def gradio_inference(input_img, aligned=False, bg_sr=False, scale_factor=2):
    """Restore the text in an image with MARCONetPlus, optionally super-resolving the background.

    Args:
        input_img: Image from the Gradio ``gr.Image(type="pil")`` input (a PIL
            image). An RGB/RGBA/grayscale uint8 numpy array is also accepted.
            ``None`` (nothing uploaded) returns ``None`` immediately.
        aligned: Forwarded to ``TextModel.handle_texts`` as ``is_aligned``.
            When true, background SR is skipped and the first entry of the
            returned ``en_texts`` becomes the output.
        bg_sr: Run the BSRGAN background model before text restoration
            (ignored when ``aligned`` is true).
        scale_factor: Output upscaling factor. It is rounded to an int >= 1
            because it is used to compute pixel sizes (a slider value may
            arrive as a float).

    Returns:
        An RGB ``uint8`` numpy array, or ``None`` when there is no input or
        the text model produced no result.
    """
    if input_img is None:
        return None
    scale_factor = max(1, int(round(scale_factor)))

    img_L = _to_bgr(input_img)
    height_L, width_L = img_L.shape[:2]

    # Background: BSRGAN super-resolution if requested, else the input itself.
    if bg_sr and not aligned:
        img_E = _super_resolve_background(img_L, width_L, height_L)
    else:
        img_E = img_L

    # Bring the background to the final output resolution.
    width_S, height_S = width_L * scale_factor, height_L * scale_factor
    img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)

    # Text restoration; only SQ and en_texts are used here.
    SQ, _ori_texts, en_texts, _debug_texts, _pred_texts = TextModel.handle_texts(
        img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned
    )
    if SQ is None:
        return None

    if aligned:
        if en_texts is None or len(en_texts) == 0:
            return None
        restored = en_texts[0]
    else:
        restored = cv2.resize(SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA)

    # Clip before the uint8 cast (an out-of-range float->uint8 cast wraps or is
    # undefined), then flip BGR -> RGB for Gradio's numpy image output.
    return np.clip(restored, 0, 255)[:, :, ::-1].astype(np.uint8)
# Gradio UI: input and output images side by side, options row below, then a run button.
with gr.Blocks() as demo:
    gr.Markdown("# MARCONetPlus Text Image Restoration")

    with gr.Row():
        input_img = gr.Image(label="Input Image", type="pil")
        output_img = gr.Image(label="Restored Output", type="numpy")

    with gr.Row():
        aligned = gr.Checkbox(value=False, label="Aligned (cropped text regions)")
        bg_sr = gr.Checkbox(value=False, label="Background SR (BSRGAN)")
        scale_factor = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Scale Factor")

    # Component order here must match gradio_inference's positional parameters.
    run_inputs = [input_img, aligned, bg_sr, scale_factor]
    run_btn = gr.Button("Run Inference")
    run_btn.click(fn=gradio_inference, inputs=run_inputs, outputs=[output_img])
if __name__ == "__main__":
    # Listen on all network interfaces (not just localhost) on a fixed port.
    launch_options = dict(server_name="0.0.0.0", server_port=7121)
    demo.launch(**launch_options)