File size: 3,999 Bytes
981b0ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import cv2
import numpy as np
import os
import os.path as osp
import time
import gradio as gr

os.environ["GRADIO_TEMP_DIR"] = "./gradio_tmp"

from models.TextEnhancement import MARCONetPlus
from utils.utils_image import imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN

# Pick the compute device once at import time: the default CUDA device when
# PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Background restoration model (lazy loading): stays None until
# load_bg_model() succeeds.
BGModel = None


def load_bg_model():
    """Load the BSRGAN background super-resolution model on first use.

    Builds a 2x RRDBNet (imported as ``BSRGAN``), fills it with the weights in
    ``./checkpoints/bsrgan_bg.pth``, switches it to eval mode, freezes its
    parameters, moves it to ``device`` and stores it in the module-level
    ``BGModel``. Once ``BGModel`` is set, further calls return it unchanged.

    Checkpoint tensors are copied into the model positionally (in state-dict
    order) rather than by key name, so the checkpoint's key names do not need
    to match the architecture's.

    Returns:
        The loaded model (the same object as the global ``BGModel``).

    Raises:
        RuntimeError: if the checkpoint has fewer entries than the model's
            state dict (some weights would otherwise stay randomly
            initialised), or, via ``load_state_dict``, if a tensor shape does
            not match.
    """
    global BGModel
    if BGModel is not None:
        return BGModel

    # Build into a local and publish to the global only once fully loaded.
    # Otherwise a failed load (e.g. missing checkpoint) would leave a randomly
    # initialised model in BGModel that every later call would silently use.
    model = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
    model_old = torch.load('./checkpoints/bsrgan_bg.pth', map_location=device)
    state_dict = model.state_dict()

    # zip() would silently stop at the shorter mapping; a short checkpoint
    # would leave the trailing model parameters untouched.
    if len(model_old) < len(state_dict):
        raise RuntimeError(
            f"BSRGAN checkpoint has {len(model_old)} entries but the model "
            f"expects {len(state_dict)}; refusing to use partially "
            f"initialised weights"
        )

    # Positional remapping: i-th checkpoint tensor -> i-th model entry.
    for (_, param), key in zip(model_old.items(), list(state_dict.keys())):
        state_dict[key] = param
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.requires_grad_(False)  # inference only; no gradients needed

    BGModel = model.to(device)
    return BGModel

# Text restoration model, constructed eagerly at import time on `device`.
# Checkpoint order matters: the paths are passed positionally.
_TEXT_MODEL_CKPTS = (
    './checkpoints/net_w_encoder_860000.pth',    # w-encoder weights
    './checkpoints/net_prior_860000.pth',        # prior network weights
    './checkpoints/net_sr_860000.pth',           # SR network weights
    './checkpoints/yolo11m_short_character.pt',  # YOLO weights (presumably the character detector)
)
TextModel = MARCONetPlus(*_TEXT_MODEL_CKPTS, device=device)

def _restore_background(img_L):
    """Super-resolve a whole BGR image with the BSRGAN background model.

    The image is first resized so both sides are multiples of 8 (presumably
    what the network expects). If the forward pass fails with a runtime or
    memory error, most likely CUDA out-of-memory on a large image, the CUDA
    cache is released and the pass is retried once on a copy whose sides are
    capped at 1536 px. A second failure propagates to the caller.

    Args:
        img_L: BGR image array as produced by ``cv2.cvtColor`` (presumably
            ``H x W x 3`` uint8).

    Returns:
        The model output converted back to an image with ``tensor2uint``.
    """
    load_bg_model()
    height_L, width_L = img_L.shape[:2]

    # Round each side down to a multiple of 8, never below 8, so tiny inputs
    # do not produce an empty target size for cv2.resize.
    img_in = cv2.resize(
        img_L,
        (max(8, width_L // 8 * 8), max(8, height_L // 8 * 8)),
        interpolation=cv2.INTER_AREA,
    )
    with torch.no_grad():
        try:
            img_out = BGModel(uint2tensor4(img_in).to(device))
        except (RuntimeError, MemoryError):
            # Typically out of memory: free cached CUDA blocks and retry on a
            # downscaled copy whose sides are at most 1536 px.
            torch.cuda.empty_cache()
            max_size = 1536
            scale = min(max_size / width_L, max_size / height_L, 1.0)
            new_width = max(8, int(width_L * scale) // 8 * 8)
            new_height = max(8, int(height_L * scale) // 8 * 8)
            img_in = cv2.resize(img_L, (new_width, new_height), interpolation=cv2.INTER_AREA)
            img_out = BGModel(uint2tensor4(img_in).to(device))
    return tensor2uint(img_out)


def _bgr_to_rgb_uint8(img_bgr):
    """Reverse the channel axis (BGR -> RGB) and convert to uint8.

    Values are clipped to [0, 255] first, so out-of-range floats saturate
    instead of wrapping around in the uint8 cast.
    """
    return np.clip(img_bgr[:, :, ::-1], 0, 255).astype(np.uint8)


def gradio_inference(input_img, aligned=False, bg_sr=False, scale_factor=2):
    """Run MARCONetPlus text restoration, optionally with background SR.

    Args:
        input_img: RGB image as delivered by ``gr.Image(type="pil")``, or None.
        aligned: True when the input is already a cropped text region. The
            background-SR step is skipped and the first entry of the
            ``en_texts`` returned by ``TextModel.handle_texts`` is the output.
        bg_sr: when True (and not ``aligned``), super-resolve the background
            with BSRGAN before text restoration; otherwise the input itself
            is used as the background.
        scale_factor: output upscaling factor. Coerced to an int (minimum 1)
            because a Gradio slider may deliver it as a float, while
            ``cv2.resize`` needs integer sizes.

    Returns:
        The restored image as an RGB uint8 array (``H*sf x W*sf`` when not
        aligned), or None when there is no input or nothing was restored.
    """
    if input_img is None:
        return None

    scale_factor = max(1, int(round(scale_factor)))

    # Convert input image (PIL, RGB) to OpenCV's BGR channel order
    img_L = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)
    height_L, width_L = img_L.shape[:2]

    # Background: BSRGAN output, or simply the input image
    if not aligned and bg_sr:
        img_E = _restore_background(img_L)
    else:
        img_E = img_L

    # Bring the background to the final output size
    width_S = width_L * scale_factor
    height_S = height_L * scale_factor
    img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)

    # Text restoration; only SQ (full image) and en_texts (aligned crops) are used
    SQ, _ori_texts, en_texts, _debug_texts, _pred_texts = TextModel.handle_texts(
        img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned
    )

    if SQ is None:
        return None

    if not aligned:
        out_bgr = cv2.resize(SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA)
    else:
        if not en_texts:
            # Nothing restored for the aligned crop; avoid IndexError.
            return None
        out_bgr = en_texts[0]

    return _bgr_to_rgb_uint8(out_bgr)

# Gradio UI
# Layout: a title, a row with the input/output images, a row with the three
# inference options, and a button that runs gradio_inference.
with gr.Blocks() as demo:
    gr.Markdown("# MARCONetPlus Text Image Restoration")

    # The input reaches the callback as a PIL image; the callback's numpy
    # array result is displayed by the output component.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input Image")
        output_img = gr.Image(type="numpy", label="Restored Output")

    # Options forwarded to gradio_inference.
    # NOTE(review): the slider value may arrive as a float rather than an
    # int; the callback should not assume an integer.
    with gr.Row():
        aligned = gr.Checkbox(label="Aligned (cropped text regions)", value=False)
        bg_sr = gr.Checkbox(label="Background SR (BSRGAN)", value=False)
        scale_factor = gr.Slider(1, 4, value=2, step=1, label="Scale Factor")

    run_btn = gr.Button("Run Inference")

    # Input values are passed to gradio_inference positionally, so this list
    # must stay in the same order as its parameters
    # (input_img, aligned, bg_sr, scale_factor).
    run_btn.click(
        fn=gradio_inference,
        inputs=[input_img, aligned, bg_sr, scale_factor],
        outputs=[output_img]
    )

if __name__ == "__main__":
    # Serve the demo on every network interface, port 7121.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7121}
    demo.launch(**launch_options)