import torch
import torch.distributed as dist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import datasets
from utils_data.image_transforms import ArrayToTensor

from .evaluation import run_evaluation_docunet
from .improved_diffusion import dist_util, logger
from .improved_diffusion.script_util import (args_to_dict,
                                             create_model_and_diffusion,
                                             model_and_diffusion_defaults)
from train_settings.models.geotr.geotr_core import (GeoTr, GeoTr_Seg, GeoTr_Seg_Inf,
                                                    reload_segmodel, reload_model, Seg)

from datasets.doc_dataset.doc_benchmark import Doc_dewarping_Data1
from train_settings.dvd.evaluation import validate

from ..models.geotr.unet_model import UNet

class WrappedDiffusionModel(torch.nn.Module):
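    """Expose a diffusion model as a plain ``forward(x)`` module.

    The timestep ``t`` and extra ``model_kwargs`` are frozen at construction
    time, so the wrapped model can be handed to tools that expect a
    single-input network (e.g. tracing or profiling utilities).
    """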
    def __init__(self, model, t, model_kwargs):
        super().__init__()
        self.model = model
        self.t = t
        self.model_kwargs = model_kwargs

    def forward(self, x):
        return self.model(x, self.t, **self.model_kwargs)
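
# Usage sketch (illustrative only; `model` and `model_kwargs` would come from
# the sampling setup in run() below):
#   t = torch.zeros(1, dtype=torch.long)
#   wrapped = WrappedDiffusionModel(model, t, model_kwargs={})
#   y = wrapped(torch.randn(1, 3, 288, 288))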

def get_parameter_number(net):
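    """Return the total and trainable parameter counts of ``net`` as a dict."""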
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
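
# Example: get_parameter_number(model) -> {'Total': <int>, 'Trainable': <int>};
# the two counts differ only when some parameters have requires_grad=False.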

def run(settings):
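    """Load the diffusion and auxiliary models, then evaluate on the configured benchmark."""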
    dist_util.setup_dist()
    logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
    logger.log(f"Corruption disabled. Evaluating on original {settings.env.eval_dataset}")
    logger.log("Loading model and diffusion...")

    # Build the model and diffusion process from the run settings.
    model, diffusion = create_model_and_diffusion(
        device=dist_util.dev(),
        train_mode=settings.env.train_mode, # stage 1
        tv=settings.env.time_variant,
        **args_to_dict(settings, model_and_diffusion_defaults().keys()),
    )
    # Make the full settings visible to the diffusion object during sampling.
    setattr(diffusion, "settings", settings)

    # Pretrained GeoTr-based segmentation/dewarping model used at evaluation time.
    # pretrained_dewarp_model = GeoTr(num_attn_layers=6, num_token=(288//8)**2)
    pretrained_dewarp_model = GeoTr_Seg_Inf()
    reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
    # reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
    pretrained_dewarp_model.to(dist_util.dev())
    pretrained_dewarp_model.eval()

    # Optional text-line and document segmentation models. Default to None so the
    # call to run_evaluation_docunet below is well-defined when line masks are off.
    pretrained_line_seg_model = None
    pretrained_seg_model = None
    if settings.env.use_line_mask:
        pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
        line_model_ckpt = dist_util.load_state_dict(
            settings.env.line_seg_model_path, map_location='cpu')['model']
        pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
        pretrained_line_seg_model.to(dist_util.dev())
        pretrained_line_seg_model.eval()

        pretrained_seg_model = Seg()
        seg_model_ckpt = dist_util.load_state_dict(
            settings.env.new_seg_model_path, map_location='cpu')['model']
        pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
        pretrained_seg_model.to(dist_util.dev())
        pretrained_seg_model.eval()

    model.cpu().load_state_dict(
        dist_util.load_state_dict(settings.env.model_path, map_location="cpu"),
        strict=False,
    )
    logger.log(f"Model loaded from {settings.env.model_path}")

    model.to(dist_util.dev())
    logger.log(f"Parameter counts: {get_parameter_number(model)}")
    model.eval()
    
    logger.log("Creating data loader...")
    logger.info('\n:============== Logging Configs ==============')
    for key, value in settings.env.__dict__.items():
        if key in ['model_path', 'timestep_respacing', 'eval_dataset']:
            logger.info(f"\t{key}:\t{value}") 
    logger.info(':===============================================\n')



    # Channel-first tensor conversion only; pixel values are left unnormalized.
    input_transform = transforms.Compose([ArrayToTensor(get_float=True)])

    if settings.env.eval_dataset_name in ("docunet", "dir300", "anyphoto", "docreal"):
        test_set = datasets.Doc_benchmark(
            settings.env.eval_dataset,
            input_transform,
        )
        test_loader = DataLoader(test_set, batch_size=1, shuffle=True,
                                 drop_last=False, num_workers=8)
        logger.info("Starting sampling")
        run_evaluation_docunet(
            settings, logger, test_loader, diffusion, model,
            pretrained_dewarp_model, pretrained_line_seg_model, pretrained_seg_model)
    elif settings.env.eval_dataset_name == "doc_val":
        val_set = Doc_dewarping_Data1(root_path=settings.env.eval_dataset,
                                      transforms=input_transform, resolution=288,
                                      model_setting="doctr")
        val_loader = DataLoader(val_set, batch_size=1, num_workers=4,
                                drop_last=False, pin_memory=True, shuffle=False)
        prec1 = validate(val_loader, pretrained_dewarp_model)
        logger.log(f"Validation result: {prec1}")

    dist.barrier()
    logger.log("sampling complete")