import argparse
import random
from datetime import date
from shutil import copyfile
import cv2 as cv
import numpy as np
from spaces import GPU 
import torch
import torch.backends.cudnn
import admin.settings as ws_settings
import os
import shutil
from pathlib import Path
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
# os.environ["OPENAI_LOGDIR"] = "./logs"
# os.environ["MPI_DISABLED"] = "1"
# os.environ.getattribute("HF_TOKEN")
token = os.getenv("HF_TOKEN", None)
import torch.distributed as dist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import datasets
from utils_data.image_transforms import ArrayToTensor
from train_settings.dvd.improved_diffusion import dist_util, logger
from train_settings.dvd.improved_diffusion.script_util import args_to_dict, create_model_and_diffusion,model_and_diffusion_defaults
from train_settings.models.geotr.geotr_core import GeoTr_Seg_Inf, reload_segmodel, reload_model, Seg   
from train_settings.models.geotr.unet_model import UNet
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torch as th
from train_settings.dvd.improved_diffusion.gaussian_diffusion import GaussianDiffusion
from train_settings.dvd.feature_backbones.VGG_features import VGGPyramid
from train_settings.dvd.eval_utils import extract_raw_features_single,extract_raw_features_single2
from datasets.utils.warping import register_model2

import gradio as gr
from huggingface_hub import hf_hub_download





reg_model_bilin = register_model2((512,512), 'bilinear')
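# register_model2 (from datasets.utils.warping) presumably builds a differentiable warper:
# given [image, grid] it resamples the 512x512 image with bilinear interpolation.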

def coords_grid_tensor(perturbed_img_shape):
    im_x, im_y = np.mgrid[0:perturbed_img_shape[0]-1:complex(perturbed_img_shape[0]), 0:perturbed_img_shape[1]-1:complex(perturbed_img_shape[1])]
    coords = np.stack((im_y, im_x), axis=2)  # x first, then y; row-major order
    coords = th.from_numpy(coords).float().permute(2, 0, 1).to('cuda')  # (2, 512, 512)
    return coords.unsqueeze(0)  # [1, 2, 512, 512]
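
# Minimal usage sketch (shapes inferred from the code above):
#   grid = coords_grid_tensor((512, 512))   # torch.Size([1, 2, 512, 512]), on CUDA
#   grid[0, 0] holds x (column) coordinates, grid[0, 1] holds y (row) coordinates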

def run_sample_lr_dewarping(
    settings, logger, diffusion, model, radius, source, feature_size, 
    raw_corr, init_flow, c20, source_64, pyramid, doc_mask, 
    seg_map_all=None, textline_map=None, init_feat=None
):
    model_kwsettings = {'init_flow': init_flow, 'src_feat': c20, 'src_64':None, 
                        'y512':source, 'tmode':settings.env.train_mode,
                        'mask_cat': doc_mask,
                        'init_feat': init_feat,
                        'iter': settings.env.iter} # 'trg_feat': trg_feat
    # [1, 81, 64, 64] [1, 2, 64, 64] [1, 64, 64, 64]
    if not settings.env.use_gt_mask:
        model_kwsettings['mask_y512'] = seg_map_all # [b, 384, 64, 64]
    if settings.env.use_line_mask:
        model_kwsettings['line_msk'] = textline_map
    image_size_h, image_size_w = feature_size, feature_size

    logger.info("\nStarting sampling")

    sample, _ = diffusion.ddim_sample_loop(
        model,
        (1, 2, image_size_h, image_size_w), # 1,2,64,64
        noise=None,
        clip_denoised=settings.env.clip_denoised, # false
        model_kwargs=model_kwsettings,
        eta=0.0,
        progress=True,
        denoised_fn=None,
        sampling_kwargs={'src_img': source}, # 'trg_img': target
        logger=logger,
        n_batch=settings.env.n_batch,
        time_variant = settings.env.time_variant,
        pyramid=pyramid
    )

    sample = th.clamp(sample, min=-1, max=1)
    return sample
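
# Note: the sample returned by run_sample_lr_dewarping is a low-resolution backward-mapping
# flow in [-1, 1] (shape [1, 2, feature_size, feature_size]), produced by deterministic
# DDIM sampling (eta=0.0).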



def visualize_dewarping_single(settings, sample, source_vis):
    # os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
    # warped_src = warp(source_vis.to(sample.device).float(), sample) # [1, 3, 1629, 981]
    warped_src = reg_model_bilin([source_vis.to(sample.device).float(), sample])
    warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy()#*255. # (1873, 1353, 3)
    warped_src = Image.fromarray((warped_src).astype(np.uint8))

    return warped_src




def prepare_data_single(input_image, input_image_ori):
    source_vis = input_image_ori
    target_vis = None
    _, _, H_ori, W_ori = source_vis.shape
    source = input_image.to('cuda')  # [1, 3, 914, 1380]  torch.float32
    source_0 = None
    target = None
    batch_ori = None
    batch_ori_inter = None
    target_256 = None
    source_256 = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area').to('cuda')
    mask = torch.ones((1, 512, 512), dtype=torch.bool).to('cuda') # None

    return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
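
# Note: most fields returned by prepare_data_single are None placeholders that keep the
# single-image path aligned with the batch evaluation interface; only H_ori/W_ori, source,
# source_256 and source_vis are consumed downstream.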



        
        
@GPU
def run_single_docunet(input_image_ori):
    input_image_ori = np.array(input_image_ori, dtype=np.uint8)  # [H, W, 3]

    # resize to 512x512
    input_image_resized = cv.resize(input_image_ori, (512, 512))  # [512, 512, 3]

    # transpose to channel-first layout
    input_image_ori = np.transpose(input_image_ori, (2, 0, 1))  # [3, H, W]
    input_image = np.transpose(input_image_resized, (2, 0, 1))  # [3, 512, 512]

    input_image = input_image / 255

    input_image_ori = torch.tensor(input_image_ori).unsqueeze(0) # [1, 3, H, W]
    input_image = torch.tensor(input_image).unsqueeze(0).float() # [1, 3, 512, 512]

    os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
    batch_preprocessing = None
    pyramid = VGGPyramid(train=False).to('cuda')
    SIZE = None


    radius = 4
    raw_corr = None
    source_288 = F.interpolate(input_image, size=(288), mode='bilinear', align_corners=True).to('cuda')

    if settings.env.time_variant:
        init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to('cuda') 
    else:
        init_feat = None

    with torch.inference_mode():
        ref_bm, mask_x = pretrained_dewarp_model(source_288) # [1,2,288,288] 0~288  0~1
        ref_flow = ref_bm/287.0 # normalized to ~[0, 1]  # [1,2,288,288]
    if settings.env.use_init_flow: 
        init_flow = F.interpolate(ref_flow, size=(64), mode='bilinear', align_corners=True) # [24, 2, 64, 64]                  
    else:
        init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to('cuda') 
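
    # init_flow seeds the diffusion sampler: either the downsampled GeoTr prediction
    # (use_init_flow) or an all-zero flow; shape [B, 2, 64, 64].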

    (
        data,
        H_ori, # 512
        W_ori, # 512
        source, # [1, 3, 512, 512] 0-1
        target, # None
        batch_ori, # None
        batch_ori_inter, # None
        source_256,# [1, 3, 256, 256] 0-1
        target_256, # None
        source_vis, # [1, 3, H, W] on CPU; used only for visualization
        target_vis, # None
        mask, # [1, 512, 512] all ones
        source_0
    ) = prepare_data_single(input_image, input_image_ori)



    with torch.no_grad():
        if not settings.env.use_gt_mask:
            # ref_bm, mask_x = self.pretrained_dewarp_model(source_288) # [1,2,288,288] bm 0~288 mskx0-256
            mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
            hx6 = F.interpolate(hx6, size=64, mode='bilinear', align_corners=False)
            hx5d = F.interpolate(hx5d, size=64, mode='bilinear', align_corners=False)
            hx4d = F.interpolate(hx4d, size=64, mode='bilinear', align_corners=False)
            hx3d = F.interpolate(hx3d, size=64, mode='bilinear', align_corners=False)
            hx2d = F.interpolate(hx2d, size=64, mode='bilinear', align_corners=False)
            hx1d = F.interpolate(hx1d, size=64, mode='bilinear', align_corners=False)

            seg_map_all = torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1) # [b, 384, 64, 64]
            # tv_save_image(mskx,"vis_hp/debug_vis/mskx.png")
            if settings.env.use_line_mask:
                textline_map, textline_mask = pretrained_line_seg_model(mskx) # [3, 64, 256, 256]
                textline_map = F.interpolate(textline_map, size=64, mode='bilinear', align_corners=False) #  [3, 64, 64, 64]
        else:
            seg_map_all = None
            textline_map = None

    if settings.env.train_VGG:
        c20 = None
        feature_size = 64
    else:
        feature_size = 64
        if settings.env.train_mode == 'stage_1_dit_cat' or settings.env.train_mode == 'stage_1_dit_cross':
            with th.no_grad():
                c20 = extract_raw_features_single2(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
            # mean cross-correlation: shallowest VGG features downsampled (512x512 -> 64x64)
        else:
            with th.no_grad():
                c20 = extract_raw_features_single(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
            # mean cross-correlation: shallowest VGG features downsampled (512x512 -> 64x64)
    
    source_64 = None # F.interpolate(source, size=(feature_size), mode='bilinear', align_corners=True)
    logger.info("Starting sampling with VGG features")

    sample = run_sample_lr_dewarping(
        settings,
        logger,
        diffusion,
        model,
        radius, # 4
        source, # [B, 3, 512, 512] 0~1
        feature_size, # 64
        raw_corr, # None
        init_flow, # [B, 2, 64, 64]   -1~1
        c20, # # [B, 64, 64, 64]   
        source_64, # None
        pyramid,
        mask_x, #mask_x,  # F.interpolate(mskx, size=(512), mode='bilinear', align_corners=True)[:,:1,:,:] , # mask_x 
        seg_map_all,
        textline_map,
        init_feat
    ) # sample: [1, 2, 64, 64] offsets in [-1, 1]; result of 5-step DDIM sampling

    if not settings.env.use_sr_net:
        sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] offset field
        # sample[:, 0, :, :] = sample[:, 0, :, :] * W_ori
        # sample[:, 1, :, :] = sample[:, 1, :, :] * H_ori
        base = F.interpolate(coords_grid_tensor((512,512))/511., size=(H_ori, W_ori), mode='bilinear', align_corners=True)
        # Add the identity grid, rescale to [-1, 1], and shrink slightly (0.987) so the
        # sampling grid stays inside the image border, cf. (2 * (bm / 286.8) - 1) * 0.99
        sample = ((sample + base.to(sample.device)) * 2 - 1) * 0.987
        ref_flow = None  # disabled: the branch below is skipped
        if ref_flow is not None:
            ref_flow = F.interpolate(ref_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] offset field
            # ref_flow[:, 0, :, :] = ref_flow[:, 0, :, :] * W_ori
            # ref_flow[:, 1, :, :] = ref_flow[:, 1, :, :] * H_ori
            ref_flow = (ref_flow + base.to(ref_flow.device))*2 - 1
        # init_flow = F.interpolate(init_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
    else:
        raise ValueError("Invalid value")


    output = visualize_dewarping_single(settings, sample, source_vis)

    return output










parser = argparse.ArgumentParser(description='Run a sampling script from train_settings.')
parser.add_argument('--train_module', type=str, default='dvd', help='Name of module in the "train_settings/" folder.')
parser.add_argument('--train_name', type=str, default='val_TDiff', help='Name of the train settings file.')
parser.add_argument('--cudnn_benchmark', type=lambda v: bool(int(v)), default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).')
parser.add_argument('--seed', type=int, default=1992, help='Pseudo-RNG seed')
parser.add_argument('--name', type=str, default="gradio", help='Name of the experiment')
parser.add_argument('--corruption', action='store_true') # defaults to False; True when the flag is passed

args = parser.parse_args()

args.seed = torch.initial_seed() & (2 ** 32 - 1)  # fold into 32 bits for np.random.seed
print('Seed is {}'.format(args.seed))
random.seed(int(args.seed))
np.random.seed(args.seed)
    
cudnn_benchmark=args.cudnn_benchmark
seed=args.seed
corruption=args.corruption
name=args.name

# This is needed to avoid strange crashes related to opencv
cv.setNumThreads(0)

torch.backends.cudnn.benchmark = cudnn_benchmark

# dd/mm/YY
today = date.today()
d1 = today.strftime("%d/%m/%Y")
print('Sampling:  {}  {}\nDate: {}'.format(args.train_module, args.train_name, d1))

settings = ws_settings.Settings()
settings.module_name = args.train_module
settings.script_name = args.train_name
settings.project_path = 'train_settings/{}/{}'.format(args.train_module, args.train_name) # 'train_settings/DiffMatch/val_DiffMatch'
settings.seed = seed
settings.name = name

save_dir = os.path.join(settings.env.workspace_dir, settings.project_path) # e.g. 'checkpoints' + 'train_settings/DiffMatch/val_DiffMatch'
os.makedirs(save_dir, exist_ok=True)
copyfile(settings.project_path + '.py', os.path.join(save_dir, settings.script_name + '.py'))


settings.severity = 0
settings.corruption_number = 0


# dist_util.setup_dist()
logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
logger.log(f"Corruption Disabled. Evaluating on Original {settings.env.eval_dataset}")
logger.log("Loading model and diffusion...")

model, diffusion = create_model_and_diffusion(
    device='cuda',
    train_mode=settings.env.train_mode, # stage 1
    tv=settings.env.time_variant,
    **args_to_dict(settings, model_and_diffusion_defaults().keys()),
)
setattr(diffusion, "settings", settings)


pretrained_dewarp_model = GeoTr_Seg_Inf()
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
pretrained_dewarp_model.to('cuda')
pretrained_dewarp_model.eval()

if settings.env.use_line_mask:
    pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
    pretrained_seg_model = Seg()
    settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
    # line_model_ckpt = pretrained_line_seg_model.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
    line_model_ckpt = torch.load(settings.env.line_seg_model_path, map_location='cpu')['model']
    pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)    
    pretrained_line_seg_model.to('cuda')
    pretrained_line_seg_model.eval()

    settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
    # seg_model_ckpt = pretrained_seg_model.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
    seg_model_ckpt = torch.load(settings.env.new_seg_model_path, map_location='cpu')['model']
    pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)    
    pretrained_seg_model.to('cuda')
    pretrained_seg_model.eval()

settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
# model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
model_ckpt = torch.load(settings.env.model_path, map_location='cpu')
model.cpu().load_state_dict(model_ckpt, strict=False)
logger.log(f"Model loaded with {settings.env.model_path}")

model.to('cuda')
model.eval()


if __name__ == '__main__':
    # demo = gr.Interface(
    #     fn=run_single_docunet,
    #     inputs=[
    #         gr.Image(type="pil", label="Input Image"),
    #     ],
    #     outputs=[
    #         gr.Image(type="numpy", label="Output Image"),
    #     ],
    #     title="Document Image Dewarping",
    #     description="This is a demo for SIGGRAPH Asia 2025 paper 'DvD: Unleashing a Generative Paradigm for Document Dewarping via Coordinates-based Diffusion Model' ",
    #     examples=EXAMPLES
    # )

    example_img_list = []
    for example_name in ['3_2 copy.png', '25_1 copy.png']:  # avoid shadowing the global `name`
        local_path = hf_hub_download(
            repo_id="hanquansanren/DvD",
            filename=f"examples/{example_name}",
            token=token
        )
        dest_path = Path("examples") / example_name
        dest_path.parent.mkdir(exist_ok=True)
        shutil.copy(local_path, dest_path)
        example_img_list.append([str(dest_path)])
        # example_img_list.append(local_path)
    
    print(example_img_list)


    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='text-align: center;'>Document Image Dewarping Demo</h2>")
        gr.Markdown("This is a demo for SIGGRAPH Asia 2025 paper 'DvD: Unleashing a Generative Paradigm for Document Dewarping via Coordinates-based Diffusion Model' ")
        
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input Image")
            output_image = gr.Image(type="numpy", label="Output Image")
        gr.Examples(
            examples=example_img_list,
            inputs=[input_image],
            label="Click an example to load into Input Image"
        )
        run_btn = gr.Button("Run")
        run_btn.click(fn=run_single_docunet, inputs=[input_image], outputs=[output_image])





    # demo.launch(share=True, debug=True, server_name="10.7.88.77")
    demo.launch(ssr_mode=False)
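
# Typical launch (assuming this file serves as the Space's app.py):
#   python app.py            # defaults: --train_module dvd --train_name val_TDiff --name gradio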