# train_settings/dvd/evaluation.py
import os
import time

import numpy as np
import torch
import torch as th  # the code below uses both the `torch` and `th` aliases
import torch.nn.functional as F
from torchvision.utils import save_image as tv_save_image
from tqdm import tqdm

from train_settings.dvd.feature_backbones.VGG_features import VGGPyramid
from utils_flow.visualization_utils import visualize_dewarping
from .eval_utils import extract_raw_features_single, extract_raw_features_single2
from .improved_diffusion import dist_util
from .improved_diffusion.gaussian_diffusion import GaussianDiffusion


def prepare_data(settings, batch_preprocessing, SIZE, data):
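    """Unpack an evaluation batch and move its tensors to the active device.

    Returns the original image size, the source/target images at 512 and 256
    resolution, the ground-truth flow maps when the dataset provides them
    (None otherwise), CPU copies of the source/target kept for visualization,
    and a correspondence mask (an all-True [1, 512, 512] mask is used as a
    fallback when none is provided).
    """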
if 'source_image_ori' in data:
source_vis = data['source_image_ori'] # B, C, 512, 512 torch.uint8 cpu
else:
source_vis = data['source_image']
if 'target_image' in data:
target_vis = data['target_image']
else:
target_vis = None
_, _, H_ori, W_ori = source_vis.shape
# data = batch_preprocessing(data)
source = data['source_image'].to(dist_util.dev()) # [1, 3, 914, 1380] torch.float32
if 'source_image_0' in data:
source_0 = data['source_image_0'].to(dist_util.dev())
else:
source_0 = None
if 'target_image' in data:
target = data['target_image'] # [1, 3, 914, 1380] torch.float32
else:
target = None
if 'flow_map' in data:
batch_ori = data['flow_map'] # [1, 2, 914, 1380] torch.float32
else:
batch_ori = None
if 'flow_map_inter' in data:
batch_ori_inter = data['flow_map_inter'] # [1, 2, 914, 1380] torch.float32
else:
batch_ori_inter = None
if target is not None:
target = F.interpolate(target, size=512, mode='bilinear', align_corners=False) # [1, 3, 512, 512]
target_256 = data['target_image_256'].to(dist_util.dev()) # [1, 3, 256, 256]
else:
target = None
target_256 = None
# source = F.interpolate(source, size=512, mode='bilinear', align_corners=False) #[1, 3, 512, 512]
# source_256 = data['source_image_256'].to(dist_util.dev()) # [1, 3, 256, 256]
    if settings.env.eval_dataset == 'hp-240':  # branch not taken for DocUNet evaluation
        source_256 = source
        target_256 = target
    else:  # branch taken for DocUNet evaluation
        data['source_image_256'] = F.interpolate(input=source.float(), size=(256, 256), mode='area')
source_256 = data['source_image_256'].to(dist_util.dev())
if 'target_image_256' in data:
target_256 = data['target_image_256']
else:
target_256 = None
if 'correspondence_mask' in data:
mask = data['correspondence_mask'] # torch.bool [1, 914, 1380]
else:
        mask = torch.ones((1, 512, 512), dtype=torch.bool).to(dist_util.dev())  # all-True fallback mask
return data, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
def run_sample_lr_dewarping(
settings, logger, diffusion, model, radius, source, feature_size,
raw_corr, init_flow, c20, source_64, pyramid, doc_mask,
seg_map_all=None, textline_map=None, init_feat=None
):
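    """Run DDIM sampling for the low-resolution flow (feature_size x feature_size).

    Packs the conditioning inputs (initial flow, source features, document mask,
    and the optional segmentation / text-line maps) into model_kwargs, draws one
    sample with diffusion.ddim_sample_loop, and clamps it to [-1, 1].
    """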
# init_flow = init_flow * feature_size
# coords = initialize_flow(init_flow.shape[0], feature_size, feature_size, dist_util.dev())
# coords_warped = coords + init_flow
# local_corr = local_Corr(
# raw_corr.view(1, 1, feature_size, feature_size, feature_size, feature_size).to(dist_util.dev()),
# coords_warped.to(dist_util.dev()),
# radius,
# )
# local_corr = F.interpolate(
# local_corr.view(1, (2 * radius + 1) ** 2, feature_size, feature_size),
# size=feature_size,
# mode='bilinear',
# align_corners=True,
# )
# init_flow = F.interpolate(init_flow, size=feature_size, mode='bilinear', align_corners=True)
# init_flow /= feature_size
    model_kwsettings = {
        'init_flow': init_flow,
        'src_feat': c20,
        'src_64': None,
        'y512': source,
        'tmode': settings.env.train_mode,
        'mask_cat': doc_mask,
        'init_feat': init_feat,
        'iter': settings.env.iter,
    }  # 'trg_feat': trg_feat
# [1, 81, 64, 64] [1, 2, 64, 64] [1, 64, 64, 64]
    if not settings.env.use_gt_mask:
        model_kwsettings['mask_y512'] = seg_map_all  # [b, 384, 64, 64]
    if settings.env.use_line_mask:
        model_kwsettings['line_msk'] = textline_map
image_size_h, image_size_w = feature_size, feature_size
# tv_save_image(source,"vis_hp/debug_vis/source.png")
# tv_save_image(doc_mask,"vis_hp/debug_vis/mask512_8877.png")
logger.info(f"\nStarting sampling")
sample, _ = diffusion.ddim_sample_loop(
model,
(1, 2, image_size_h, image_size_w), # 1,2,64,64
noise=None,
clip_denoised=settings.env.clip_denoised, # false
model_kwargs=model_kwsettings,
eta=0.0,
progress=True,
denoised_fn=None,
sampling_kwargs={'src_img': source}, # 'trg_img': target
logger=logger,
n_batch=settings.env.n_batch,
        time_variant=settings.env.time_variant,
pyramid=pyramid
)
sample = th.clamp(sample, min=-1, max=1)
return sample
def run_evaluation_docunet(
settings, logger, val_loader, diffusion: GaussianDiffusion, model,
    pretrained_dewarp_model, pretrained_line_seg_model=None, pretrained_seg_model=None
):
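    """Evaluate document dewarping on a DocUNet-style benchmark.

    For each image, a pretrained dewarping model predicts an initial backward
    map, optional segmentation / text-line models provide conditioning masks,
    VGG features are extracted, and the diffusion model refines the flow with
    DDIM sampling. Results are optionally visualized and the average per-image
    sampling time is printed at the end.
    """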
os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
# batch_preprocessing = DocBatchPreprocessing(
# settings, apply_mask=False, apply_mask_zero_borders=False, sparse_ground_truth=False
# )
batch_preprocessing = None
pbar = tqdm(enumerate(val_loader), total=len(val_loader))
pyramid = VGGPyramid(train=False).to(dist_util.dev())
SIZE = None
    sample_times = []  # per-image DDIM sampling time, in seconds
for i, data in pbar:
radius = 4
raw_corr = None
image_size = 64
data_path = data['path']
# ref test
# source_288 = F.interpolate(data['source_image']/255., size=(288), mode='bilinear', align_corners=True).to(dist_util.dev())
        source_288 = F.interpolate(data['source_image'], size=288, mode='bilinear', align_corners=True).to(dist_util.dev())
# tv_save_image(data['source_image']/255., "vis_hp/msk5/in{}".format(data['path'][0].split('/')[-1]))
        if settings.env.time_variant:
init_feat = torch.zeros((data['source_image'].shape[0], 256, image_size, image_size), dtype=torch.float32).to(dist_util.dev())
else:
init_feat = None
with torch.inference_mode():
            ref_bm, mask_x = pretrained_dewarp_model(source_288)  # ref_bm: [1, 2, 288, 288] in [0, 288]; mask_x in [0, 1]
# base = coords_grid_tensor((288,288)).to(ref_bm.device) # [1, 2, 288, 288]
# ref_flow = ref_bm - base
            ref_flow = ref_bm / 287.0  # normalized backward map, roughly [0, 1]; [1, 2, 288, 288]
if settings.env.use_init_flow:
                init_flow = F.interpolate(ref_flow, size=image_size, mode='bilinear', align_corners=True)  # [24, 2, 64, 64]
else:
init_flow = torch.zeros((data['source_image'].shape[0], 2, image_size, image_size), dtype=torch.float32).to(dist_util.dev())
# mask_x = F.interpolate(mask_x, size=(512), mode='bilinear', align_corners=True) # 0-1
# data['source_image'] = mask_x*data['source_image'].to(dist_util.dev()) # 0-255
            # mask_x_vis = mask_x * data['source_image'].to(dist_util.dev())  # no single mask threshold works best
            # tv_save_image(mask_x_vis, "vis_hp/msk_wore/{}".format(data['path'][0].split('/')[-1]))  # [0, 1], (288, 288)
(
data,
H_ori, # 512
W_ori, # 512
source, # [1, 3, 512, 512] 0-1
target, # None
batch_ori, # None
batch_ori_inter, # None
source_256,# [1, 3, 256, 256] 0-1
target_256, # None
            source_vis,  # [1, 3, H, W], kept on CPU, used only for visualization
target_vis, # None
            mask,  # [1, 512, 512], all True
source_0
) = prepare_data(settings, batch_preprocessing, SIZE, data)
with torch.no_grad():
            if not settings.env.use_gt_mask:
# ref_bm, mask_x = self.pretrained_dewarp_model(source_288) # [1,2,288,288] bm 0~288 mskx0-256
mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
                hx6, hx5d, hx4d, hx3d, hx2d, hx1d = [
                    F.interpolate(t, size=image_size, mode='bilinear', align_corners=False)
                    for t in (hx6, hx5d, hx4d, hx3d, hx2d, hx1d)
                ]
                seg_map_all = torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1)  # [b, 384, 64, 64]
# tv_save_image(mskx,"vis_hp/debug_vis/mskx.png")
                if settings.env.use_line_mask:
                    textline_map, textline_mask = pretrained_line_seg_model(mskx)  # [3, 64, 256, 256]
                    textline_map = F.interpolate(textline_map, size=image_size, mode='bilinear', align_corners=False)  # [3, 64, 64, 64]
                else:
                    textline_map = None  # avoid an undefined name when use_line_mask is False
            else:
                seg_map_all = None
                textline_map = None
if settings.env.train_VGG:
c20 = None
feature_size = image_size
else:
feature_size = image_size
            if settings.env.train_mode in ('stage_1_dit_cat', 'stage_1_dit_cross'):
                with th.no_grad():
                    c20 = extract_raw_features_single2(pyramid, source, source_256, feature_size)  # [24, 1, 64, 64, 64, 64]
                    # averaged cross-correlation: shallowest VGG features downsampled from 512x512 to 64x64
            else:
                with th.no_grad():
                    c20 = extract_raw_features_single(pyramid, source, source_256, feature_size)  # [24, 1, 64, 64, 64, 64]
                    # averaged cross-correlation: shallowest VGG features downsampled from 512x512 to 64x64
source_64 = None # F.interpolate(source, size=(feature_size), mode='bilinear', align_corners=True)
logger.info(f"Starting sampling with VGG Features")
# init_flow = correlation_to_flow_w_argmax(
# raw_corr.view(1, 1, feature_size, feature_size, feature_size, feature_size),
# output_shape=(feature_size, feature_size),
        # )  # B, 2, 64, 64, initial offset field
        begin_sample = time.time()
sample = run_sample_lr_dewarping(
settings,
logger,
diffusion,
model,
radius, # 4
source, # [B, 3, 512, 512] 0~1
feature_size, # 64
raw_corr, # None
init_flow, # [B, 2, 64, 64] -1~1
c20, # # [B, 64, 64, 64]
source_64, # None
pyramid,
            mask_x,  # doc_mask; alternative: F.interpolate(mskx, size=512, mode='bilinear', align_corners=True)[:, :1, :, :]
seg_map_all,
textline_map,
init_feat
        )  # sample: [1, 2, 64, 64], offsets in [-1, 1], the result of five DDIM steps
        sample_times.append(time.time() - begin_sample)  # per-image sampling timing ends here
# if settings.env.use_sr_net == True: # false
# logger.info('Running super resolution')
# sample_sr = None
# for j in range(1):
# batch_ori, sample_sr, init_flow_sr = run_sample_sr(
# settings, logger, diffusion_sr, model_sr, pyramid, data, sample, sample_sr
# )
# sample_ = F.interpolate(sample_sr, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
# sample_[:, 0, :, :] = sample_[:, 0, :, :] * W_ori
# sample_[:, 1, :, :] = sample_[:, 1, :, :] * H_ori
# sample_ = sample_.permute(0, 2, 3, 1)[mask]
# batch_ori_ = batch_ori.permute(0, 2, 3, 1)[mask]
# epe = th.sum((sample_ - batch_ori_.to(sample_.device)) ** 2, dim=1).sqrt()
# logger.info(f'sr iter: {i}, epe: {epe.mean()}')
# sample = F.interpolate(sample_sr, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
# sample[:, 0, :, :] = sample[:, 0, :, :] * W_ori
# sample[:, 1, :, :] = sample[:, 1, :, :] * H_ori
# init_flow = F.interpolate(init_flow_sr, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
# # init_flow[:, 0, :, :] = init_flow[:, 0, :, :] * W_ori
# # init_flow[:, 1, :, :] = init_flow[:, 1, :, :] * H_ori
# sample = th.mean(sample[0], dim=0, keepdim=True)
        if not settings.env.use_sr_net:
            sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True)  # offset field in [-1, +1]
# sample[:, 0, :, :] = sample[:, 0, :, :] * W_ori
# sample[:, 1, :, :] = sample[:, 1, :, :] * H_ori
base = F.interpolate(coords_grid_tensor((512,512))/511., size=(H_ori, W_ori), mode='bilinear', align_corners=True)
# sample = ( ((sample + base.to(sample.device)) )*2 - 1 )
            sample = ((sample + base.to(sample.device)) * 2 - 1) * 0.987  # cf. (2 * (bm / 286.8) - 1) * 0.99
            ref_flow = None  # ref_flow visualization is disabled; assign a tensor to re-enable the block below
            if ref_flow is not None:
                ref_flow = F.interpolate(ref_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True)  # offset field in [-1, +1]
                # ref_flow[:, 0, :, :] = ref_flow[:, 0, :, :] * W_ori
                # ref_flow[:, 1, :, :] = ref_flow[:, 1, :, :] * H_ori
                ref_flow = (ref_flow + base.to(ref_flow.device)) * 2 - 1
# init_flow = F.interpolate(init_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
else:
            raise ValueError("use_sr_net=True is not supported by this evaluation script")
if settings.env.visualize:
visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow)
# sample = sample.permute(0, 2, 3, 1)[mask]
# init_flow[:, 0, :, :] = init_flow[:, 0, :, :] * W_ori
# init_flow[:, 1, :, :] = init_flow[:, 1, :, :] * H_ori
# init_flow = init_flow.permute(0, 2, 3, 1)[mask]
# print("Elapsed time:{:.2f} minutes ".format(trian_t/60))
print(len(trian_t))
print("Elapsed time:{:.2f} avg_second ".format(sum(trian_t) / len(trian_t)))
def coords_grid_tensor(perturbed_img_shape):
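    """Build a pixel-coordinate grid of shape [1, 2, H, W] for the given (H, W).

    Channel 0 holds the x (column) coordinate and channel 1 the y (row)
    coordinate, both in pixels.
    """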
im_x, im_y = np.mgrid[0:perturbed_img_shape[0]-1:complex(perturbed_img_shape[0]),
0:perturbed_img_shape[1]-1:complex(perturbed_img_shape[1])]
    coords = np.stack((im_y, im_x), axis=2)  # channel order: x first, then y (row-major grid)
    coords = th.from_numpy(coords).float().permute(2, 0, 1).to(dist_util.dev())  # (2, H, W)
    return coords.unsqueeze(0)  # [1, 2, H, W]
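

# A minimal usage sketch (not part of the original pipeline): resampling an
# image with the [-1, 1] normalized backward map produced in
# run_evaluation_docunet, mirroring the F.grid_sample call in validate below.
# The function name and argument names are illustrative only.
def dewarp_with_backward_map(image, bm_norm):
    # image: [B, 3, H, W]; bm_norm: [B, 2, H, W] with (x, y) channels in [-1, 1]
    return F.grid_sample(input=image, grid=bm_norm.permute(0, 2, 3, 1), align_corners=True)
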
def validate(local_rank, args, val_loader, model, criterion):
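    """Qualitative sanity check: warp a batch with the predicted and the
    ground-truth backward maps (coordinates normalized by 992 to [-1, 1])
    and save the results under backup/test/. `args` and `criterion` are
    currently unused.
    """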
for i, sample in enumerate(val_loader):
input1, label = sample # [2, 3, 288, 288],[2, 2, 288, 288]
        input1 = input1.to(local_rank, non_blocking=True)
        label = label.to(local_rank, non_blocking=True)
# label = (label/288.0-0.5)*2
with torch.no_grad():
output = model(input1) # [3b, 2, 288, 288]
            # loss = F.l1_loss(output, label)  # strong supervision on synthetic images
# test point
# bm_test=(output/288.0-0.5)*2
            bm_test = (output / 992.0 - 0.5) * 2  # map absolute coordinates to [-1, 1] for grid_sample
            label = (label / 992.0 - 0.5) * 2
# bm_test = output
bm_test = F.interpolate(bm_test, size=(1000,1000), mode='bilinear', align_corners=True)
label = F.interpolate(label, size=(1000,1000), mode='bilinear', align_corners=True)
input1 = F.interpolate(input1, size=(1000,1000), mode='bilinear', align_corners=True)
            regis_image1 = F.grid_sample(input=input1, grid=bm_test.permute(0, 2, 3, 1), align_corners=True)
            regis_image2 = F.grid_sample(input=input1, grid=label.permute(0, 2, 3, 1), align_corners=True)
# regis_image2 = F.grid_sample(input=a_sample[None], grid=bm_test[None].permute(0,2,3,1), align_corners=True)
tv_save_image(input1[0], "backup/test/ori.png")
tv_save_image(regis_image1[0], "backup/test/aaa.png")
tv_save_image(regis_image2[0], "backup/test/gt.png")
# warped_src = warp(source_vis.to(sample.device).float(), sample) # [1, 3, 1629, 981]
# warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy() # (1873, 1353, 3)
# warped_src = Image.fromarray((warped_src).astype(np.uint8))
# warped_src.save(f"vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred/warped_{data_path[0].split('/')[-1]}")
return None