import argparse
import random
from datetime import date
from shutil import copyfile
import cv2 as cv
import numpy as np
from spaces import GPU
import torch
import torch.backends.cudnn
import admin.settings as ws_settings
import os
import shutil
from pathlib import Path
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
# os.environ["OPENAI_LOGDIR"] = "./logs"
# os.environ["MPI_DISABLED"] = "1"
token = os.getenv("HF_TOKEN", None)
import torch.distributed as dist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import datasets
from utils_data.image_transforms import ArrayToTensor
from train_settings.dvd.improved_diffusion import dist_util, logger
from train_settings.dvd.improved_diffusion.script_util import args_to_dict, create_model_and_diffusion, model_and_diffusion_defaults
from train_settings.models.geotr.geotr_core import GeoTr_Seg_Inf, reload_segmodel, reload_model, Seg
from train_settings.models.geotr.unet_model import UNet
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torch as th
from train_settings.dvd.improved_diffusion.gaussian_diffusion import GaussianDiffusion
from train_settings.dvd.feature_backbones.VGG_features import VGGPyramid
from train_settings.dvd.eval_utils import extract_raw_features_single, extract_raw_features_single2
from datasets.utils.warping import register_model2
import gradio as gr
from huggingface_hub import hf_hub_download
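# Bilinear warper from datasets.utils.warping, configured for 512x512 grids:
# given [image, sampling_grid] it resamples the image with the predicted
# backward map (see visualize_dewarping_single below).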
reg_model_bilin = register_model2((512, 512), 'bilinear')
def coords_grid_tensor(perturbed_img_shape):
    im_x, im_y = np.mgrid[0:perturbed_img_shape[0]-1:complex(perturbed_img_shape[0]), 0:perturbed_img_shape[1]-1:complex(perturbed_img_shape[1])]
    coords = np.stack((im_y, im_x), axis=2)  # x first, then y; row-major order
    coords = th.from_numpy(coords).float().permute(2, 0, 1).to('cuda')  # [2, 512, 512]
    return coords.unsqueeze(0)  # [1, 2, 512, 512]
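# run_sample_lr_dewarping drives the reverse diffusion: starting from noise (or
# the GeoTr-initialized flow), a short DDIM loop denoises a [1, 2, 64, 64]
# coordinate field conditioned on VGG features, segmentation maps, and the
# document mask.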
def run_sample_lr_dewarping(
        settings, logger, diffusion, model, radius, source, feature_size,
        raw_corr, init_flow, c20, source_64, pyramid, doc_mask,
        seg_map_all=None, textline_map=None, init_feat=None
):
    model_kwsettings = {'init_flow': init_flow, 'src_feat': c20, 'src_64': None,
                        'y512': source, 'tmode': settings.env.train_mode,
                        'mask_cat': doc_mask,
                        'init_feat': init_feat,
                        'iter': settings.env.iter}  # 'trg_feat': trg_feat
    # [1, 81, 64, 64] [1, 2, 64, 64] [1, 64, 64, 64]
    if not settings.env.use_gt_mask:
        model_kwsettings['mask_y512'] = seg_map_all  # [b, 384, 64, 64]
    if settings.env.use_line_mask:
        model_kwsettings['line_msk'] = textline_map
    image_size_h, image_size_w = feature_size, feature_size
    logger.info("\nStarting sampling")
    sample, _ = diffusion.ddim_sample_loop(
        model,
        (1, 2, image_size_h, image_size_w),  # [1, 2, 64, 64]
        noise=None,
        clip_denoised=settings.env.clip_denoised,  # False
        model_kwargs=model_kwsettings,
        eta=0.0,
        progress=True,
        denoised_fn=None,
        sampling_kwargs={'src_img': source},  # 'trg_img': target
        logger=logger,
        n_batch=settings.env.n_batch,
        time_variant=settings.env.time_variant,
        pyramid=pyramid
    )
    sample = th.clamp(sample, min=-1, max=1)
    return sample
def visualize_dewarping_single(settings, sample, source_vis):
    # warped_src = warp(source_vis.to(sample.device).float(), sample)  # alternative warper
    warped_src = reg_model_bilin([source_vis.to(sample.device).float(), sample])
    warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy()  # (H, W, 3), already in 0-255
    warped_src = Image.fromarray(warped_src.astype(np.uint8))
    return warped_src
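# prepare_data_single mirrors the paired-data pipeline used in training but
# fills all target-side slots with None for single-image inference.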
def prepare_data_single(input_image, input_image_ori):
    source_vis = input_image_ori
    target_vis = None
    _, _, H_ori, W_ori = source_vis.shape
    source = input_image.to('cuda')  # [1, 3, 512, 512] torch.float32
    source_0 = None
    target = None
    batch_ori = None
    batch_ori_inter = None
    target_256 = None
    source_256 = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area').to('cuda')
    mask = torch.ones((1, 512, 512), dtype=torch.bool).to('cuda')
    return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
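# End-to-end inference for one document photo: preprocess to 512x512, predict an
# initial backward map with GeoTr, extract segmentation and VGG features, run
# diffusion sampling, and warp the original-resolution image with the result.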
@GPU  # ZeroGPU: request a GPU for the duration of each inference call
def run_single_docunet(input_image_ori):
    input_image_ori = np.array(input_image_ori, dtype=np.uint8)  # [H, W, 3]
    # resize to 512x512
    input_image_resized = cv.resize(input_image_ori, (512, 512))  # [512, 512, 3]
    # transpose to channel-first layout
    input_image_ori = np.transpose(input_image_ori, (2, 0, 1))  # [3, H, W]
    input_image = np.transpose(input_image_resized, (2, 0, 1))  # [3, 512, 512]
    input_image = input_image / 255
    input_image_ori = torch.tensor(input_image_ori).unsqueeze(0)  # [1, 3, H, W]
    input_image = torch.tensor(input_image).unsqueeze(0).float()  # [1, 3, 512, 512]
    os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
    batch_preprocessing = None
    pyramid = VGGPyramid(train=False).to('cuda')
    SIZE = None
    radius = 4
    raw_corr = None
    source_288 = F.interpolate(input_image, size=(288, 288), mode='bilinear', align_corners=True).to('cuda')
    if settings.env.time_variant:
        init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to('cuda')
    else:
        init_feat = None
    with torch.inference_mode():
        ref_bm, mask_x = pretrained_dewarp_model(source_288)  # bm in [0, 288], mask in [0, 1]; [1, 2, 288, 288]
    ref_flow = ref_bm / 287.0  # normalize the backward map to [0, 1]
    if settings.env.use_init_flow:
        init_flow = F.interpolate(ref_flow, size=(64, 64), mode='bilinear', align_corners=True)  # [1, 2, 64, 64]
    else:
        init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to('cuda')
    (
        data,
        H_ori,  # 512
        W_ori,  # 512
        source,  # [1, 3, 512, 512], in [0, 1]
        target,  # None
        batch_ori,  # None
        batch_ori_inter,  # None
        source_256,  # [1, 3, 256, 256], in [0, 1]
        target_256,  # None
        source_vis,  # [1, 3, H, W] on CPU, used only for visualization
        target_vis,  # None
        mask,  # [1, 512, 512], all ones
        source_0
    ) = prepare_data_single(input_image, input_image_ori)
    with torch.no_grad():
        if not settings.env.use_gt_mask:
            # Multi-scale document segmentation features, resized to the 64x64 latent resolution
            mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
            hx6 = F.interpolate(hx6, size=64, mode='bilinear', align_corners=False)
            hx5d = F.interpolate(hx5d, size=64, mode='bilinear', align_corners=False)
            hx4d = F.interpolate(hx4d, size=64, mode='bilinear', align_corners=False)
            hx3d = F.interpolate(hx3d, size=64, mode='bilinear', align_corners=False)
            hx2d = F.interpolate(hx2d, size=64, mode='bilinear', align_corners=False)
            hx1d = F.interpolate(hx1d, size=64, mode='bilinear', align_corners=False)
            seg_map_all = torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1)  # [b, 384, 64, 64]
            textline_map = None
            if settings.env.use_line_mask:
                textline_map, textline_mask = pretrained_line_seg_model(mskx)  # [b, 64, 256, 256]
                textline_map = F.interpolate(textline_map, size=64, mode='bilinear', align_corners=False)  # [b, 64, 64, 64]
        else:
            seg_map_all = None
            textline_map = None
    if settings.env.train_VGG:
        c20 = None
        feature_size = 64
    else:
        feature_size = 64
        if settings.env.train_mode == 'stage_1_dit_cat' or settings.env.train_mode == 'stage_1_dit_cross':
            with th.no_grad():
                c20 = extract_raw_features_single2(pyramid, source, source_256, feature_size)
                # averaged cross-correlation; shallowest VGG features downsampled from 512x512 to 64x64
        else:
            with th.no_grad():
                c20 = extract_raw_features_single(pyramid, source, source_256, feature_size)
                # averaged cross-correlation; shallowest VGG features downsampled from 512x512 to 64x64
    source_64 = None  # F.interpolate(source, size=(feature_size), mode='bilinear', align_corners=True)
    logger.info("Starting sampling with VGG features")
    sample = run_sample_lr_dewarping(
        settings,
        logger,
        diffusion,
        model,
        radius,  # 4
        source,  # [B, 3, 512, 512], in [0, 1]
        feature_size,  # 64
        raw_corr,  # None
        init_flow,  # [B, 2, 64, 64] normalized GeoTr backward map (or zeros)
        c20,  # [B, 64, 64, 64]
        source_64,  # None
        pyramid,
        mask_x,  # document mask predicted by GeoTr
        seg_map_all,
        textline_map,
        init_feat
    )  # sample: [1, 2, 64, 64] displacement field in [-1, 1], result of 5-step DDIM
    if not settings.env.use_sr_net:
        sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True)  # displacement field
        # Compose the displacement with an identity grid in [0, 1], then rescale to
        # the [-1, 1] convention expected by the warper (cf. (2 * (bm / 286.8) - 1) * 0.99).
        base = F.interpolate(coords_grid_tensor((512, 512)) / 511., size=(H_ori, W_ori), mode='bilinear', align_corners=True)
        sample = ((sample + base.to(sample.device)) * 2 - 1) * 0.987
    else:
        raise ValueError("use_sr_net is not supported in this demo")
    output = visualize_dewarping_single(settings, sample, source_vis)
    return output
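# ---------------------------------------------------------------------------
# Script-level setup: CLI arguments, seeding, settings, and model loading run
# once at startup so the Gradio handler can reuse the loaded weights.
# ---------------------------------------------------------------------------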
parser = argparse.ArgumentParser(description='Run a sampling script in train_settings.')
parser.add_argument('--train_module', type=str, default='dvd', help='Name of module in the "train_settings/" folder.')
parser.add_argument('--train_name', type=str, default='val_TDiff', help='Name of the train settings file.')
parser.add_argument('--cudnn_benchmark', type=int, default=1, help='Set cudnn benchmark on (1) or off (0) (default is on).')
parser.add_argument('--seed', type=int, default=1992, help='Pseudo-RNG seed')
parser.add_argument('--name', type=str, default="gradio", help='Name of the experiment')
parser.add_argument('--corruption', action='store_true')  # defaults to False; True when the flag is passed
args = parser.parse_args()
args.seed = torch.initial_seed() & (2 ** 32 - 1)  # derive a 32-bit seed from torch's initial seed
print('Seed is {}'.format(args.seed))
random.seed(int(args.seed))
np.random.seed(args.seed)
cudnn_benchmark = bool(args.cudnn_benchmark)
seed = args.seed
corruption = args.corruption
name = args.name
# This is needed to avoid strange crashes related to opencv
cv.setNumThreads(0)
torch.backends.cudnn.benchmark = cudnn_benchmark
# dd/mm/YY
today = date.today()
d1 = today.strftime("%d/%m/%Y")
print('Sampling: {} {}\nDate: {}'.format(args.train_module, args.train_name, d1))
settings = ws_settings.Settings()
settings.module_name = args.train_module
settings.script_name = args.train_name
settings.project_path = 'train_settings/{}/{}'.format(args.train_module, args.train_name)  # e.g. 'train_settings/dvd/val_TDiff'
settings.seed = seed
settings.name = name
save_dir = os.path.join(settings.env.workspace_dir, settings.project_path)
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
copyfile(settings.project_path + '.py', os.path.join(save_dir, settings.script_name + '.py'))
settings.severity = 0
settings.corruption_number = 0
# dist_util.setup_dist()
logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
logger.log(f"Corruption disabled. Evaluating on original {settings.env.eval_dataset}")
logger.log("Loading model and diffusion...")
model, diffusion = create_model_and_diffusion(
    device='cuda',
    train_mode=settings.env.train_mode,  # stage 1
    tv=settings.env.time_variant,
    **args_to_dict(settings, model_and_diffusion_defaults().keys()),
)
setattr(diffusion, "settings", settings)
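# Pre-trained GeoTr segmentation/dewarping model: its backward map seeds the
# diffusion process as init_flow inside run_single_docunet.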
pretrained_dewarp_model = GeoTr_Seg_Inf()
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
pretrained_dewarp_model.to('cuda')
pretrained_dewarp_model.eval()
if settings.env.use_line_mask:
    # NOTE: pretrained_seg_model is also needed whenever use_gt_mask is False,
    # so this demo assumes use_line_mask is enabled in the settings.
    pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
    pretrained_seg_model = Seg()
    settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
    line_model_ckpt = torch.load(settings.env.line_seg_model_path, map_location='cpu')['model']
    pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
    pretrained_line_seg_model.to('cuda')
    pretrained_line_seg_model.eval()
    settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
    seg_model_ckpt = torch.load(settings.env.new_seg_model_path, map_location='cpu')['model']
    pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
    pretrained_seg_model.to('cuda')
    pretrained_seg_model.eval()
settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
model_ckpt = torch.load(settings.env.model_path, map_location='cpu')
model.cpu().load_state_dict(model_ckpt, strict=False)
logger.log(f"Model loaded with {settings.env.model_path}")
model.to('cuda')
model.eval()
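# ---------------------------------------------------------------------------
# Gradio UI: download example images from the model repo, then wire a simple
# input/output image interface around run_single_docunet.
# ---------------------------------------------------------------------------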
if __name__ == '__main__':
    example_img_list = []
    for example_name in ['3_2 copy.png', '25_1 copy.png']:
        local_path = hf_hub_download(
            repo_id="hanquansanren/DvD",
            filename=f"examples/{example_name}",
            token=token
        )
        dest_path = Path("examples") / example_name
        dest_path.parent.mkdir(exist_ok=True)
        shutil.copy(local_path, dest_path)
        example_img_list.append([str(dest_path)])
    print(example_img_list)
    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='text-align: center;'>Document Image Dewarping Demo</h2>")
        gr.Markdown("This is a demo for the SIGGRAPH Asia 2025 paper 'DvD: Unleashing a Generative Paradigm for Document Dewarping via Coordinates-based Diffusion Model'.")
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input Image")
            output_image = gr.Image(type="numpy", label="Output Image")
        gr.Examples(
            examples=example_img_list,
            inputs=[input_image],
            label="Click an example to load into Input Image"
        )
        run_btn = gr.Button("Run")
        run_btn.click(fn=run_single_docunet, inputs=[input_image], outputs=[output_image])
    demo.launch(ssr_mode=False)