File size: 8,301 Bytes
05fb4ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import torch
import torchvision.transforms as transforms
from termcolor import colored

from datasets.load_pre_made_dataset import \
    Doc_Dataset, Aug_Doc_Dataset, Doc3d_Dataset, Mix_Dataset

from datasets.batch_processing import GLUNetBatchPreprocessing
from utils_data.image_transforms import ArrayToTensor
from utils_data.loaders import Loader
from train_settings.models.geotr.geotr_core import GeoTr, GeoTr_Seg, GeoTr_Seg_womask, GeoTr_Seg_Inf,\
                                                     reload_segmodel, reload_model, Seg

from ..models.geotr.unet_model import UNet

from .improved_diffusion import dist_util, logger
from .improved_diffusion.resample import create_named_schedule_sampler
from .improved_diffusion.script_util import (args_to_dict,
                                             create_model_and_diffusion,
                                             model_and_diffusion_defaults)
from .improved_diffusion.train_util import TrainLoop


def _load_frozen_model(module, ckpt_path):
    """Load checkpoint weights into an auxiliary network and freeze it for inference.

    Reads ``ckpt_path`` with ``dist_util.load_state_dict`` (mapped to CPU), takes
    its ``'model'`` entry, loads it strictly into ``module``, moves the module to
    ``dist_util.dev()`` and switches it to eval mode.

    Args:
        module: The ``torch.nn.Module`` to populate.
        ckpt_path: Path to a checkpoint whose top-level mapping has a ``'model'`` key.

    Returns:
        The same ``module``, for convenience.
    """
    state_dict = dist_util.load_state_dict(ckpt_path, map_location='cpu')['model']
    module.load_state_dict(state_dict, strict=True)
    module.to(dist_util.dev())
    module.eval()
    return module


def _build_train_loader(settings):
    """Build the training ``Loader`` for ``settings.env.dataset_name``.

    Args:
        settings: Training settings; reads ``env.dataset_name``, ``env.doc_debug``,
            ``env.batch_size`` and ``env.n_threads``.

    Returns:
        A ``Loader`` over the selected training dataset (shuffled, no drop_last).

    Raises:
        ValueError: if ``settings.env.dataset_name`` is not a supported dataset.
    """
    dataset_classes = {
        'doc_debug': Doc_Dataset,
        'aug_doc': Aug_Doc_Dataset,
        'doc3d': Doc3d_Dataset,
    }
    dataset_name = settings.env.dataset_name
    try:
        dataset_cls = dataset_classes[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; "
            f"expected one of {sorted(dataset_classes)}"
        ) from None

    # Pre-processing of the images is done within the network function, so the
    # transforms only convert arrays to tensors (images with get_float=False;
    # flows channels-first and cast to float).
    img_transforms = transforms.Compose([ArrayToTensor(get_float=False)])
    flow_transform = transforms.Compose([ArrayToTensor()])
    # NOTE(review): every dataset is read from settings.env.doc_debug, including
    # 'aug_doc' and 'doc3d' — confirm this is intended.
    train_dataset, _ = dataset_cls(root=settings.env.doc_debug,
                                   source_image_transform=img_transforms,
                                   target_image_transform=None,
                                   flow_transform=flow_transform,
                                   split=1,
                                   get_mapping=False)
    return Loader('train', train_dataset, batch_size=settings.env.batch_size, shuffle=True,
                  drop_last=False, training=True, num_workers=settings.env.n_threads)


def run(settings):
    """Train the diffusion dewarping model configured by ``settings``.

    Sets up distributed state and logging, builds the model and diffusion
    process (optionally resuming model weights), loads the frozen
    line-segmentation ``UNet`` and ``Seg`` networks, builds the training loader,
    and runs ``TrainLoop.run_loop_dewarping``.

    Args:
        settings: Training settings object. Hyper-parameters and paths are read
            from ``settings.env``; ``settings.description`` and
            ``settings.device`` are written here.

    Raises:
        ValueError: if ``settings.env.dataset_name`` is not a supported dataset.
    """
    settings.description = 'train settings for dvd'

    dist_util.setup_dist()
    torch.cuda.set_device(dist_util.dev())
    logger.configure(dir=f"{settings.env.train_mode}_{settings.env.dataset_name}")

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        device=dist_util.dev(),
        train_mode=settings.env.train_mode,
        tv=settings.env.time_variant,
        **args_to_dict(settings, model_and_diffusion_defaults().keys())
    )
    if settings.env.resume_checkpoint:
        state_dict = dist_util.load_state_dict(settings.env.resume_checkpoint, map_location='cpu')
        # strict=False tolerates mismatched parameters; report them so that a
        # partially-applied resume is visible instead of silent.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            logger.log(f"resume: missing keys {incompatible.missing_keys}, "
                       f"unexpected keys {incompatible.unexpected_keys}")

    settings.device = dist_util.dev()
    print(f"Setting device to {settings.device}")
    model = model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(settings.env.schedule_sampler, diffusion)

    # Frozen auxiliary networks consumed by the training loop.
    pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
    pretrained_seg_model = Seg()
    _load_frozen_model(pretrained_line_seg_model, settings.env.line_seg_model_path)
    _load_frozen_model(pretrained_seg_model, settings.env.new_seg_model_path)

    logger.log("creating data loader...")
    train_loader = _build_train_loader(settings)

    # The dataset name is stored on the diffusion object because of the semantic setting.
    diffusion.dataset = settings.env.dataset_name
    print(colored('==> ', 'blue') + 'model created.')

    logger.log("training...")
    batch_preprocessing = GLUNetBatchPreprocessing(settings, apply_mask=False, apply_mask_zero_borders=False,
                                                   sparse_ground_truth=False)

    # Build the training loop and run it. Note that the Seg network is passed
    # through TrainLoop's `pretrained_dewarp_model` parameter.
    TrainLoop(
        model=model,
        pretrained_dewarp_model=pretrained_seg_model,
        pretrained_line_seg_model=pretrained_line_seg_model,
        diffusion=diffusion,
        settings=settings,
        batch_preprocessing=batch_preprocessing,
        data=train_loader,
        batch_size=settings.env.batch_size,
        microbatch=settings.env.microbatch,
        lr=settings.env.lr,
        ema_rate=settings.env.ema_rate,
        log_interval=settings.env.log_interval,
        save_interval=settings.env.save_interval,
        resume_checkpoint=settings.env.resume_checkpoint,
        use_fp16=settings.env.use_fp16,
        fp16_scale_growth=settings.env.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=settings.env.weight_decay,
        lr_anneal_steps=settings.env.lr_anneal_steps,
        resume_step=settings.env.resume_step,
        use_gt_mask=settings.env.use_gt_mask,
        use_init_flow=settings.env.use_init_flow,
        train_mode=settings.env.train_mode,
        use_line_mask=settings.env.use_line_mask
    ).run_loop_dewarping()