# DvD / train_settings / dvd / train_TDiff.py
# (header reconstructed from Hugging Face Space page residue: uploaded by
# "hanquansanren", commit "Add application file", 05fb4ab)
import torch
import torchvision.transforms as transforms
from termcolor import colored
from datasets.load_pre_made_dataset import \
Doc_Dataset, Aug_Doc_Dataset, Doc3d_Dataset, Mix_Dataset
from datasets.batch_processing import GLUNetBatchPreprocessing
from utils_data.image_transforms import ArrayToTensor
from utils_data.loaders import Loader
from train_settings.models.geotr.geotr_core import GeoTr, GeoTr_Seg, GeoTr_Seg_womask, GeoTr_Seg_Inf,\
reload_segmodel, reload_model, Seg
from ..models.geotr.unet_model import UNet
from .improved_diffusion import dist_util, logger
from .improved_diffusion.resample import create_named_schedule_sampler
from .improved_diffusion.script_util import (args_to_dict,
create_model_and_diffusion,
model_and_diffusion_defaults)
from .improved_diffusion.train_util import TrainLoop
def run(settings):
    """Train the DvD document-dewarping diffusion model.

    Sets up distributed training, builds the diffusion model (optionally
    resuming from a checkpoint), loads two frozen pretrained helpers (a
    line-segmentation UNet and a document-segmentation model), constructs
    the training data loader for the configured dataset, and finally runs
    ``TrainLoop(...).run_loop_dewarping()``.

    Args:
        settings: project settings object; all configuration is read from
            ``settings.env`` (dataset name, batch size, checkpoint paths,
            optimizer hyper-parameters, ...). ``settings.device`` is set as
            a side effect.

    Raises:
        ValueError: if ``settings.env.dataset_name`` is not one of the
            supported keys ('doc_debug', 'aug_doc', 'doc3d').
    """
    settings.description = 'train settings for dvd'

    # --- distributed / logging setup ------------------------------------
    dist_util.setup_dist()
    torch.cuda.set_device(dist_util.dev())
    logger.configure(dir=f"{settings.env.train_mode}_{settings.env.dataset_name}")

    # --- diffusion model -------------------------------------------------
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        device=dist_util.dev(),
        train_mode=settings.env.train_mode,
        tv=settings.env.time_variant,
        **args_to_dict(settings, model_and_diffusion_defaults().keys())
    )
    if settings.env.resume_checkpoint:
        # strict=False so a checkpoint whose parameter set does not exactly
        # match the current architecture can still be (partially) resumed.
        state_dict = dist_util.load_state_dict(settings.env.resume_checkpoint, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
    settings.device = dist_util.dev()
    print(f"Setting device to {settings.device}")
    model = model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(settings.env.schedule_sampler, diffusion)

    # --- frozen pretrained helpers ---------------------------------------
    # Both helpers are loaded strictly from their checkpoints and kept in
    # eval mode; they act as fixed feature providers during training.
    pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
    line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
    pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
    pretrained_line_seg_model.to(dist_util.dev())
    pretrained_line_seg_model.eval()

    pretrained_seg_model = Seg()
    seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
    pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
    pretrained_seg_model.to(dist_util.dev())
    pretrained_seg_model.eval()

    # --- training data ---------------------------------------------------
    logger.log("creating data loader...")
    # Image pre-processing proper is done inside the network; here we only
    # convert arrays to tensors (images stay integral, flows become float).
    img_transforms = transforms.Compose([ArrayToTensor(get_float=False)])
    flow_transform = transforms.Compose([ArrayToTensor()])
    dataset_cls = {
        'doc_debug': Doc_Dataset,
        'aug_doc': Aug_Doc_Dataset,
        'doc3d': Doc3d_Dataset,
    }.get(settings.env.dataset_name)
    if dataset_cls is None:
        # Previously an unknown name fell through and later crashed with
        # UnboundLocalError on train_loader; fail fast and clearly instead.
        raise ValueError(f"Unsupported dataset_name: {settings.env.dataset_name!r}")
    # NOTE(review): all three datasets read root=settings.env.doc_debug —
    # confirm this is intentional for 'aug_doc' and 'doc3d'.
    train_dataset, _ = dataset_cls(root=settings.env.doc_debug,
                                   source_image_transform=img_transforms,
                                   target_image_transform=None,
                                   flow_transform=flow_transform,
                                   split=1,
                                   get_mapping=False)
    # Bug fix: the original 'aug_doc' branch never built a loader, which
    # crashed below with UnboundLocalError. The loader is now built once,
    # identically, for every dataset branch.
    train_loader = Loader('train', train_dataset, batch_size=settings.env.batch_size, shuffle=True,
                          drop_last=False, training=True, num_workers=settings.env.n_threads)

    # The diffusion process needs the dataset name for its semantic setting.
    setattr(diffusion, 'dataset', settings.env.dataset_name)
    print(colored('==> ', 'blue') + 'model created.')

    # --- training loop ---------------------------------------------------
    logger.log("training...")
    batch_preprocessing = GLUNetBatchPreprocessing(settings, apply_mask=False, apply_mask_zero_borders=False,
                                                   sparse_ground_truth=False)
    TrainLoop(
        model=model,
        pretrained_dewarp_model=pretrained_seg_model,
        pretrained_line_seg_model=pretrained_line_seg_model,
        diffusion=diffusion,
        settings=settings,
        batch_preprocessing=batch_preprocessing,
        data=train_loader,
        batch_size=settings.env.batch_size,
        microbatch=settings.env.microbatch,
        lr=settings.env.lr,
        ema_rate=settings.env.ema_rate,
        log_interval=settings.env.log_interval,
        save_interval=settings.env.save_interval,
        resume_checkpoint=settings.env.resume_checkpoint,
        use_fp16=settings.env.use_fp16,
        fp16_scale_growth=settings.env.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=settings.env.weight_decay,
        lr_anneal_steps=settings.env.lr_anneal_steps,
        resume_step=settings.env.resume_step,
        use_gt_mask=settings.env.use_gt_mask,
        use_init_flow=settings.env.use_init_flow,
        train_mode=settings.env.train_mode,
        use_line_mask=settings.env.use_line_mask,
    ).run_loop_dewarping()