# (removed non-source residue: "Spaces:" / "Build error" lines injected by the extraction tool)
| """ | |
| @Date: 2021/09/19 | |
| @description: | |
| """ | |
| import json | |
| import os | |
| import argparse | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import glob | |
| from tqdm import tqdm | |
| from PIL import Image | |
| from config.defaults import merge_from_file, get_config | |
| from dataset.mp3d_dataset import MP3DDataset | |
| from dataset.zind_dataset import ZindDataset | |
| from models.build import build_model | |
| from loss import GradLoss | |
| from postprocessing.post_process import post_process | |
| from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama | |
| from utils.boundary import corners2boundaries, layout2depth | |
| from utils.conversion import depth2xyz | |
| from utils.logger import get_logger | |
| from utils.misc import tensor2np_d, tensor2np | |
| from evaluation.accuracy import show_grad | |
| from models.lgt_net import LGT_Net | |
| from utils.writer import xyz2json | |
| from visualization.boundary import draw_boundaries | |
| from visualization.floorplan import draw_floorplan, draw_iou_floorplan | |
| from visualization.obj3d import create_3d_obj | |
def parse_option():
    """Parse command-line options for the inference script.

    Returns:
        argparse.Namespace: parsed options with ``mode`` forced to ``'test'``.
    """
    parser = argparse.ArgumentParser(
        description='Panorama Layout Transformer training and evaluation script')
    parser.add_argument('--img_glob', type=str, required=True,
                        help='image glob path')
    parser.add_argument('--cfg', type=str, required=True, metavar='FILE',
                        help='path of config file')
    parser.add_argument('--post_processing', type=str, default='manhattan',
                        choices=['manhattan', 'atalanta', 'original'],
                        help='post-processing type')
    parser.add_argument('--output_dir', type=str, default='src/output',
                        help='path of output')
    parser.add_argument('--visualize_3d', action='store_true',
                        help='visualize_3d')
    parser.add_argument('--output_3d', action='store_true',
                        help='output_3d')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device')

    parsed = parser.parse_args()
    parsed.mode = 'test'  # this script only ever runs inference

    # Echo the effective configuration for reproducibility.
    print("arguments:")
    for key in vars(parsed):
        print(key, ":", getattr(parsed, key))
    print("-" * 50)
    return parsed
def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
    """Render a 2D visualization of a layout prediction.

    Draws the raw predicted boundary (green) and, when present, the
    post-processed boundary (red) on the panorama; optionally appends a
    depth-gradient strip at the bottom and a floorplan on the right.

    Args:
        img: panorama image, float array in [0, 1], shape (H, W, 3) — assumed; TODO confirm.
        dt: prediction dict with 'depth', 'ratio' tensors and optionally 'processed_xyz'.
        show_depth: append the depth-gradient strip.
        show_floorplan: append the floorplan panel.
        show: display the result with matplotlib.
        save_path: if given, write the result as an 8-bit PNG.

    Returns:
        The composed visualization as a float image array.
    """
    dt_np = tensor2np_d(dt)
    depth_pred = dt_np['depth'][0]
    xyz_pred = depth2xyz(np.abs(depth_pred))
    ratio_pred = dt_np['ratio'][0][0]

    # Raw prediction boundary (green).
    boundaries = corners2boundaries(ratio_pred, corners_xyz=xyz_pred, step=None,
                                    visible=False, length=img.shape[1])
    vis = draw_boundaries(img, boundary_list=boundaries, boundary_color=[0, 1, 0])

    # Post-processed boundary (red), only when post-processing ran.
    if 'processed_xyz' in dt:
        boundaries = corners2boundaries(ratio_pred, corners_xyz=dt['processed_xyz'][0],
                                        step=None, visible=False, length=img.shape[1])
        vis = draw_boundaries(vis, boundary_list=boundaries, boundary_color=[1, 0, 0])

    if show_depth:
        grad_strip = show_depth_normal_grad(dt)
        # Replace the bottom rows of the pano with the gradient strip.
        vis = np.concatenate([vis[0:-grad_strip.shape[0], :, :], grad_strip], axis=0)
        # vis = grad_strip.transpose(1, 2, 0)[100:]

    if show_floorplan:
        if 'processed_xyz' in dt:
            floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], xyz_pred[..., ::2],
                                           dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
        else:
            floorplan = show_alpha_floorplan(xyz_pred, border_color=[0, 1, 0, 1])
        # Trim 60px of padding from each side before concatenating.
        vis = np.concatenate([vis, floorplan[:, 60:-60, :]], axis=1)

    if show:
        plt.imshow(vis)
        plt.show()
    if save_path:
        Image.fromarray((vis * 255).astype(np.uint8)).save(save_path)
    return vis
def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
    """Align a panorama image with its detected vanishing points (VP).

    Args:
        img_ori: input panorama image array.
        q_error: qError parameter forwarded to panoEdgeDetection.
        refine_iter: refinement iterations for VP detection.
        vp_cache_path: optional path of a VP cache file; read when it exists,
            written after detection otherwise.

    Returns:
        (aligned_image, vp): the rotated panorama and the 3x3 VP array.
    """
    # Bug fix: os.path.exists(None) raises TypeError; the write branch below
    # already treats None as "no cache", so the read branch must as well.
    if vp_cache_path is not None and os.path.exists(vp_cache_path):
        # Reuse cached vanishing points (one whitespace-separated triplet per line).
        with open(vp_cache_path) as f:
            vp = [[float(v) for v in line.rstrip().split(' ')] for line in f]
        vp = np.array(vp)
    else:
        # VP detection and line segment extraction.
        _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
                                                 qError=q_error,
                                                 refineIter=refine_iter)
    i_img = rotatePanorama(img_ori, vp[2::-1])

    if vp_cache_path is not None:
        # Persist the VPs so subsequent runs skip detection.
        with open(vp_cache_path, 'w') as f:
            for i in range(3):
                f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))
    return i_img, vp
def show_depth_normal_grad(dt):
    """Render the predicted depth's gradient visualization as a 1024x60 strip."""
    conv = GradLoss().to(dt['depth'].device).grad_conv
    grad_img = show_grad(dt['depth'][0], conv, 50)
    # Nearest-neighbor keeps the discrete gradient colors crisp when resizing.
    return cv2.resize(grad_img, (1024, 60), interpolation=cv2.INTER_NEAREST)
def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
    """Draw a floorplan with a semi-transparent fill composited on a grey background.

    Args:
        dt_xyz: predicted corner coordinates; only the x/z columns are used.
        side_l: side length (pixels) of the square floorplan image.
        border_color: RGBA border color; defaults to red.

    Returns:
        RGB float image array in [0, 1] of shape (side_l, side_l, 3).
    """
    if border_color is None:
        border_color = [1, 0, 0, 1]
    fill_color = [0.2, 0.2, 0.2, 0.2]
    dt_floorplan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
                                  border_color=border_color, side_l=side_l, show=False,
                                  center_color=[1, 0, 0, 1])
    dt_floorplan = Image.fromarray((dt_floorplan * 255).astype(np.uint8), mode='RGBA')
    # Bug fix: np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype.
    back = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float64)
    back[..., :] = [0.8, 0.8, 0.8, 1]  # opaque light grey background
    back = Image.fromarray((back * 255).astype(np.uint8), mode='RGBA')
    # Alpha-composite the floorplan over the background, then drop alpha.
    composited = Image.alpha_composite(back, dt_floorplan).convert("RGB")
    return np.array(composited) / 255.0
def save_pred_json(xyz, ration, save_path):
    """Convert a predicted layout to JSON and write it to disk.

    NOTE(review): `ration` looks like a typo for `ratio`; the name is kept
    so existing keyword call sites stay valid.

    Returns:
        The JSON-serializable dict that was written.
    """
    json_data = xyz2json(xyz, ration)
    with open(save_path, 'w') as f:
        f.write(json.dumps(json_data, indent=4) + '\n')
    return json_data
def inference():
    """Run inference over every image matched by args.img_glob.

    Reads the module globals ``img_paths``, ``logger``, ``model`` and ``args``
    set up in the ``__main__`` block.
    """
    if len(img_paths) == 0:
        logger.error('No images found')
        return

    bar = tqdm(img_paths, ncols=100)
    for img_path in bar:
        if not os.path.isfile(img_path):
            # Bug fix: corrected garbled error message ("not is file").
            logger.error(f'The {img_path} is not a file')
            continue
        name = os.path.basename(img_path).split('.')[0]
        bar.set_description(name)
        # Resize to the model's expected 1024x512 pano resolution, drop alpha if any.
        img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
        if args.post_processing is not None and 'manhattan' in args.post_processing:
            # Manhattan post-processing needs a VP-aligned panorama.
            bar.set_description("Preprocessing")
            img, vp = preprocess(img, vp_cache_path=os.path.join(args.output_dir, f"{name}_vp.txt"))
        img = (img / 255.0).astype(np.float32)
        # Bug fix: run_one_inference requires a logger argument; the original
        # call omitted it and raised TypeError.
        run_one_inference(img, model, args, name, logger=logger)
def inference_dataset(dataset):
    """Run inference on every sample of a dataset yielding dicts with 'id' and a CHW 'image'."""
    progress = tqdm(dataset, ncols=100)
    for sample in progress:
        progress.set_description(sample['id'])
        # Model input pipeline expects HWC, dataset provides CHW.
        run_one_inference(sample['image'].transpose(1, 2, 0), model, args,
                          name=sample['id'], logger=logger)
def run_one_inference(img, model, args, name, logger=None, show=True, show_depth=True,
                      show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
    """Run the model on one panorama and write all outputs.

    Produces ``{name}_pred.png`` (2D visualization), ``{name}_pred.json``
    (layout) and optionally ``{name}_3d{mesh_format}`` in args.output_dir.

    Args:
        img: float HWC panorama in [0, 1].
        model: layout network; called with a 1xCxHxW tensor on args.device.
        args: parsed options (uses post_processing, output_dir, device,
            visualize_3d, output_3d).
        name: basename used for all output files.
        logger: optional logger; defaults to the module logger. (Backward-
            compatible fix: `inference()` calls this without a logger.)
        show / show_depth / show_floorplan: forwarded to visualize_2d.
        mesh_format: extension of the exported 3D mesh.
        mesh_resolution: boundary sampling length for the mesh.
    """
    if logger is None:
        logger = get_logger()
    model.eval()
    logger.info("model inference...")
    dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))

    if args.post_processing != 'original':
        logger.info(f"post-processing, type:{args.post_processing}...")
        dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)

    visualize_2d(img, dt,
                 show_depth=show_depth,
                 show_floorplan=show_floorplan,
                 show=show,
                 save_path=os.path.join(args.output_dir, f"{name}_pred.png"))

    # Prefer the post-processed layout when available.
    output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))

    logger.info("saving predicted layout json...")
    json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
                               save_path=os.path.join(args.output_dir, f"{name}_pred.json"))
    # if args.visualize_3d:
    #     from visualization.visualizer.visualizer import visualize_3d
    #     visualize_3d(json_data, (img * 255).astype(np.uint8))

    if args.visualize_3d or args.output_3d:
        dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
                                           length=mesh_resolution if 'processed_xyz' in dt else None,
                                           visible=True if 'processed_xyz' in dt else False)
        dt_layout_depth = layout2depth(dt_boundaries, show=False)

        logger.info("creating 3d mesh ...")
        create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
                      save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
                      mesh=True, show=args.visualize_3d)
if __name__ == '__main__':
    # Module-level globals below (logger, args, config, model, img_paths) are
    # read by inference()/inference_dataset()/run_one_inference().
    logger = get_logger()
    args = parse_option()
    config = get_config(args)

    # Fall back to CPU when CUDA was requested but is unavailable.
    if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu ...')
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()

    model, _, _, _ = build_model(config, logger)
    os.makedirs(args.output_dir, exist_ok=True)
    img_paths = sorted(glob.glob(args.img_glob))

    inference()

    # Example dataset-based evaluation, kept for reference:
    # dataset = MP3DDataset(root_dir='./src/dataset/mp3d', mode='test', split_list=[
    #     ['7y3sRwLe3Va', '155fac2d50764bf09feb6c8f33e8fb76'],
    #     ['e9zR4mvMWw7', 'c904c55a5d0e420bbd6e4e030b9fe5b4'],
    # ])
    # dataset = ZindDataset(root_dir='./src/dataset/zind', mode='test', split_list=[
    #     '1169_pano_21',
    #     '0583_pano_59',
    # ], vp_align=True)
    # inference_dataset(dataset)