# DvD / utils_flow / visualization_utils.py
# Committed by hanquansanren
# Commit 05fb4ab: "Add application file"
import os
import matplotlib.pyplot as plt
import numpy as np
import torch as th
from PIL import Image
from torchvision.utils import save_image
from datasets.utils import flow_viz
from datasets.utils.warping import register_model2
# Shared bilinear warping module, built once at import time and reused by
# visualize_dewarping. The (512, 512) size is passed straight through to
# register_model2; how it interacts with larger inputs is not visible here.
reg_model_bilin = register_model2(
    (512, 512),
    'bilinear',
)
def _flow_to_rgb(flow):
    """Colour-code a single flow tensor with ``flow_viz.flow_to_image``.

    The tensor is permuted from channels-first to channels-last before the
    conversion (presumably (2, H, W) -> (H, W, 2) — confirm against callers).
    """
    flow_hwc = flow.detach().permute(1, 2, 0).float().cpu().numpy()
    return flow_viz.flow_to_image(flow_hwc)


def _chw_to_pil(img):
    """Convert a channels-first image tensor into a uint8 PIL image.

    No rescaling is applied, so values are assumed to already be on a 0-255
    scale. They are clipped to [0, 255] before the (truncating) uint8 cast so
    out-of-range values cannot wrap around.
    """
    arr = img.detach().permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))


def visualize(sample, category, rate, name_dataset, i, batch_vis, source_vis, target_vis, mask):
    """Dump per-sample flow and image visualizations for one batch to disk.

    For every index ``j`` in ``sample`` this writes, under
    ``vis_{category}/{rate}_1_{name_dataset+2}/``:

    * ``pred_flow/flow_{i}_{j}.png``    colour-coded predicted flow ``sample[j]``
    * ``gt_flow/gt_{i}_{j}.png``        colour-coded ground-truth flow ``batch_vis[j]``
    * ``src_samples/src_{i}_{j}.png``   original source image
    * ``trg_samples/trg_{i}_{j}.png``   original target image
    * ``dewarped_pred/warped_{i}_{j}.png`` and ``warped_masked_{i}_{j}.png``:
      source ``j`` warped by the predicted flow, plain and multiplied by ``mask[j]``
    * ``dewarped_gt/warped_{i}_{j}.png`` and ``warped_masked_{i}_{j}.png``:
      the same, using the ground-truth flow
    * ``mask/mask_{i}_{j}.png``         ``mask[j]`` replicated to 3 channels

    Args:
        sample: batch of predicted flows, channels-first per sample.
        category: first component of the output directory name.
        rate: second component of the output directory name.
        name_dataset: used as ``name_dataset+2`` in the directory name, so
            presumably an int — confirm against callers.
        i: batch index, embedded in every output file name.
        batch_vis: batch of ground-truth flows, same layout as ``sample``.
        source_vis: source image batch (channels-first; assumed 0-255 scale).
        target_vis: target image batch (channels-first; assumed 0-255 scale).
        mask: per-sample mask of the matched region on the target; multiplied
            into the warped source images.
    """
    out_dir = f'vis_{category}/{rate}_1_{name_dataset+2}'
    for sub in ('pred_flow',      # predicted flow
                'gt_flow',        # ground-truth flow
                'src_samples',    # original source
                'trg_samples',    # original target
                'dewarped_pred',  # source dewarped by predicted flow
                'dewarped_gt',    # source dewarped by ground-truth flow
                'mask'):          # mask of the matched region on the target
        os.makedirs(f'{out_dir}/{sub}', exist_ok=True)

    for j in range(len(sample)):
        plt.imsave(f'{out_dir}/pred_flow/flow_{i}_{j}.png', _flow_to_rgb(sample[j]) / 255.0)
        plt.imsave(f'{out_dir}/gt_flow/gt_{i}_{j}.png', _flow_to_rgb(batch_vis[j]) / 255.0)

        _chw_to_pil(source_vis[j]).save(f'{out_dir}/src_samples/src_{i}_{j}.png')
        _chw_to_pil(target_vis[j]).save(f'{out_dir}/trg_samples/trg_{i}_{j}.png')

        for flows, sub in ((sample, 'dewarped_pred'), (batch_vis, 'dewarped_gt')):
            # Pair source j with flow j only. Slicing both to a batch of one
            # keeps their batch sizes equal (presumably required if `warp` is
            # grid_sample-based), instead of warping one source against every
            # flow in the batch.
            # NOTE(review): `warp` is not imported or defined in this module;
            # it must be provided elsewhere — confirm before relying on this.
            warped = warp(source_vis[j:j + 1].to(flows.device).float(), flows[j:j + 1])
            warped_masked = warped * mask[j:j + 1].float()
            _chw_to_pil(warped[0]).save(f'{out_dir}/{sub}/warped_{i}_{j}.png')
            _chw_to_pil(warped_masked[0]).save(f'{out_dir}/{sub}/warped_masked_{i}_{j}.png')

        # Replicate the single-channel mask to 3 channels for saving as RGB.
        mask_vis = th.stack((mask[j], mask[j], mask[j]))
        save_image(mask_vis.float(), f'{out_dir}/mask/mask_{i}_{j}.png')
def _save_first_dewarped(source_vis, flow, path):
    """Dewarp ``source_vis`` with ``flow`` via ``reg_model_bilin``; save batch element 0 to ``path``.

    The source is cast to float and moved to the flow's device first. No
    rescaling is applied before the uint8 cast, so the warped values are
    assumed to be on a 0-255 scale. They are clipped to [0, 255] so
    out-of-range values cannot wrap around.
    """
    warped = reg_model_bilin([source_vis.to(flow.device).float(), flow])
    arr = warped[0].permute(1, 2, 0).detach().cpu().numpy()
    Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)).save(path)


def visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow=None):
    """Save the source image dewarped by the predicted flow, and optionally by a reference flow.

    Only the first batch element is written. Outputs go under
    ``vis_hp/{settings.env.eval_dataset_name}/{settings.name}/``:

    * ``dewarped_pred/warped_{stem}.png``: dewarped by ``sample``, where
      ``stem`` is the basename of ``data_path[0]`` without its extension.
    * ``dewarped_pred_ref/warped_{basename}``: dewarped by ``ref_flow``, only
      when ``ref_flow`` is given. The source's full basename is kept, so the
      file format follows the original extension.

    Args:
        settings: config object; ``settings.env.eval_dataset_name`` and
            ``settings.name`` form the output directory.
        sample: predicted flow batch fed to ``reg_model_bilin`` with the source.
        data: unused in this function; kept for interface compatibility.
        i: unused in this function; kept for interface compatibility.
        source_vis: source image batch to dewarp.
        data_path: sequence of input paths; only ``data_path[0]`` is used.
        ref_flow: optional second flow batch to dewarp with as well.
    """
    out_dir = f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}'
    file_name = os.path.basename(data_path[0])
    # splitext strips an extension of any length, not only a 3-character one.
    stem = os.path.splitext(file_name)[0]

    # The pred_flow* directories are still created, although no flow image
    # is currently written into them.
    os.makedirs(f'{out_dir}/pred_flow', exist_ok=True)
    os.makedirs(f'{out_dir}/dewarped_pred', exist_ok=True)
    _save_first_dewarped(source_vis, sample, f'{out_dir}/dewarped_pred/warped_{stem}.png')

    if ref_flow is not None:
        os.makedirs(f'{out_dir}/pred_flow_ref', exist_ok=True)
        os.makedirs(f'{out_dir}/dewarped_pred_ref', exist_ok=True)
        _save_first_dewarped(source_vis, ref_flow, f'{out_dir}/dewarped_pred_ref/warped_{file_name}')