Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch as th | |
| from PIL import Image | |
| from torchvision.utils import save_image | |
| from datasets.utils import flow_viz | |
| from datasets.utils.warping import register_model2 | |
# Bilinear warper used by visualize_dewarping; constructed once at import time.
# NOTE(review): (512, 512) is presumably the spatial size the warper is set up
# for — confirm against datasets.utils.warping.register_model2.
_REG_SIZE = (512, 512)
reg_model_bilin = register_model2(_REG_SIZE, 'bilinear')
def _vis_save_flow(flow, path):
    """Colour-code one flow tensor with ``flow_viz.flow_to_image`` and save it to ``path``."""
    flow_np = flow.detach().permute(1, 2, 0).float().cpu().numpy()
    plt.imsave(path, flow_viz.flow_to_image(flow_np) / 255.0)


def _vis_save_image(img, path):
    """Save one channel-first image tensor to ``path`` after a plain cast to ``uint8``."""
    arr = img.permute(1, 2, 0).detach().cpu().numpy()
    Image.fromarray(arr.astype(np.uint8)).save(path)


def visualize(sample, category, rate, name_dataset, i, batch_vis, source_vis, target_vis, mask):
    """Write per-sample flow, image and warping visualisations for one batch.

    Output goes under ``vis_{category}/{rate}_1_{name_dataset + 2}/``, into the
    sub-directories ``pred_flow``, ``gt_flow``, ``src_samples``, ``trg_samples``,
    ``dewarped_pred``, ``dewarped_gt`` and ``mask``. Every file name ends in
    ``_{i}_{j}.png``, where ``j`` is the index within the batch.

    Args:
        sample: predicted flow batch. Each ``sample[j]`` is permuted
            ``(1, 2, 0)`` before ``flow_viz.flow_to_image``, so presumably
            ``(B, 2, H, W)`` — confirm with the caller.
        category: used only to build the output directory name.
        rate: used only to build the output directory name.
        name_dataset: must support ``+ 2``. The directory name uses
            ``name_dataset + 2``, so passing a ``str`` raises ``TypeError``.
        i: iteration/batch index used in the output file names.
        batch_vis: ground-truth flow batch, with the same layout as ``sample``.
        source_vis: source image batch. Images are cast straight to
            ``uint8``, so values are presumably in [0, 255].
        target_vis: target image batch, same assumptions as ``source_vis``.
        mask: per-sample mask of the matched region on the target. It is used
            to mask the warped images and is also saved on its own.
    """
    out_dir = f'vis_{category}/{rate}_1_{name_dataset+2}'
    # pred flow, gt flow, original source, original target, pred dewarped,
    # gt dewarped, mask of the matched region on the target.
    for sub in ('pred_flow', 'gt_flow', 'src_samples', 'trg_samples',
                'dewarped_pred', 'dewarped_gt', 'mask'):
        os.makedirs(f'{out_dir}/{sub}', exist_ok=True)

    for j in range(len(sample)):
        _vis_save_flow(sample[j], f'{out_dir}/pred_flow/flow_{i}_{j}.png')
        _vis_save_flow(batch_vis[j], f'{out_dir}/gt_flow/gt_{i}_{j}.png')
        _vis_save_image(source_vis[j], f'{out_dir}/src_samples/src_{i}_{j}.png')
        _vis_save_image(target_vis[j], f'{out_dir}/trg_samples/trg_{i}_{j}.png')

        src_j = source_vis[j:j+1]
        mask_j = mask[j:j+1].float()
        for flow, sub in ((sample, 'dewarped_pred'), (batch_vis, 'dewarped_gt')):
            # Warp sample j with its *own* flow only. Passing the whole flow batch
            # against a single image either fails on a batch-size mismatch or
            # wastes B-1 warps whose results are discarded.
            # NOTE(review): `warp` is not defined or imported in the visible part
            # of this file (visualize_dewarping uses reg_model_bilin) — confirm
            # it is in scope.
            warped = warp(src_j.to(flow.device).float(), flow[j:j+1])
            _vis_save_image(warped[0], f'{out_dir}/{sub}/warped_{i}_{j}.png')
            _vis_save_image((warped * mask_j)[0], f'{out_dir}/{sub}/warped_masked_{i}_{j}.png')

        # Stack the mask three times (a 3-channel image when mask[j] is (H, W)).
        mask_vis = th.stack((mask[j], mask[j], mask[j]))
        save_image(mask_vis.float(), f'{out_dir}/mask/mask_{i}_{j}.png')
def _dewarp_and_save(source_vis, flow, path):
    """Warp ``source_vis`` by ``flow`` with ``reg_model_bilin`` and save batch element 0 to ``path``."""
    warped = reg_model_bilin([source_vis.to(flow.device).float(), flow])
    warped = warped[0].permute(1, 2, 0).detach().cpu().numpy()
    Image.fromarray(warped.astype(np.uint8)).save(path)


def visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow=None):
    """Save the source image dewarped by the predicted flow (and optionally a reference flow).

    Output root is ``vis_hp/{settings.env.eval_dataset_name}/{settings.name}/``.

    - ``dewarped_pred/warped_<stem>.png`` is written from ``sample``. ``<stem>``
      is the base name of ``data_path[0]`` without its extension.
    - If ``ref_flow`` is given, ``dewarped_pred_ref/warped_<base name>`` is also
      written from ``ref_flow``.

    Args:
        settings: object providing ``env.eval_dataset_name`` and ``name``.
        sample: predicted flow, passed to ``reg_model_bilin`` together with
            ``source_vis``.
        data: unused here; kept for interface compatibility.
        i: unused here; kept for interface compatibility.
        source_vis: source image tensor. Only batch element 0 of the warped
            result is saved.
        data_path: sequence of paths. Only ``data_path[0]`` is used, for naming.
        ref_flow: optional second flow to dewarp with.
    """
    out_dir = f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}'
    # pred_flow(_ref) directories are still created for the flow visualisation,
    # which is currently disabled.
    os.makedirs(f'{out_dir}/pred_flow', exist_ok=True)
    os.makedirs(f'{out_dir}/dewarped_pred', exist_ok=True)

    file_name = data_path[0].split('/')[-1]
    # splitext instead of [:-4], so `.jpeg`/`.tiff` and extension-less names are
    # stripped correctly.
    stem = os.path.splitext(file_name)[0]
    _dewarp_and_save(source_vis, sample, f'{out_dir}/dewarped_pred/warped_{stem}.png')

    if ref_flow is not None:
        os.makedirs(f'{out_dir}/pred_flow_ref', exist_ok=True)
        os.makedirs(f'{out_dir}/dewarped_pred_ref', exist_ok=True)
        # NOTE(review): unlike the main output, this keeps the original file name
        # and extension, so PIL chooses the image format from it (e.g. JPEG for
        # .jpg). The name is preserved for compatibility — consider unifying it.
        _dewarp_and_save(source_vis, ref_flow, f'{out_dir}/dewarped_pred_ref/warped_{file_name}')