Spaces:
Build error
Build error
| ''' | |
| @author: Zhigang Jiang | |
| @time: 2022/05/23 | |
| @description: Gradio demo that predicts the 3d room layout of an rgb panorama with LGT-Net. | |
| ''' | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import torch | |
| os.system('pip install --upgrade --no-cache-dir gdown') | |
| from PIL import Image | |
| from utils.logger import get_logger | |
| from config.defaults import get_config | |
| from inference import preprocess, run_one_inference | |
| from models.build import build_model | |
| from argparse import Namespace | |
| import gdown | |
def down_ckpt(model_cfg, ckpt_dir):
    """Make sure the pre-trained checkpoint for ``model_cfg`` exists in ``ckpt_dir``.

    The checkpoint is stored as ``best.pkl`` inside ``ckpt_dir`` and is only
    downloaded when that file is not there yet. A ``model_cfg`` that does not
    appear in the table below is ignored: nothing is downloaded and no error
    is raised.

    :param model_cfg: config path used as the lookup key, e.g. 'src/config/mp3d.yaml'
    :param ckpt_dir: directory that should contain ``best.pkl``
    """
    # Each entry is [config path, Google Drive file id].
    known_ckpts = [
        ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
        ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
        ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
        ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
        ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha']
    ]
    matching = [entry for entry in known_ckpts if entry[0] == model_cfg]

    for entry in matching:
        target = os.path.join(ckpt_dir, 'best.pkl')
        if os.path.exists(target):
            # Already downloaded on a previous run.
            continue
        logger.info(f"Downloading {entry}")
        os.makedirs(ckpt_dir, exist_ok=True)
        gdown.download(f"https://drive.google.com/uc?id={entry[1]}", target, False)
def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
    """Run layout inference on one panorama and return the paths of the produced files.

    :param img_path: path of the input panorama image
    :param pre_processing: when truthy, run ``preprocess`` on the image and cache
        the vanishing-point file under 'src/output'
    :param weight_name: 'mp3d' or 'zind'; selects which loaded model is used
    :param post_processing: post-processing method name, stored on ``args``
    :param visualization: collection of enabled 2d visualizations
        ('depth-normal-gradient' and/or '2d-floorplan')
    :param mesh_format: file extension of the exported mesh, including the dot
    :param mesh_resolution: mesh resolution; converted with ``int`` before use
    :return: list of five paths: the 2d prediction image, the 3d mesh (twice, once
        for the 3d viewer and once for the file download), the vanishing-point
        file and the layout json
    :raises NotImplementedError: if ``weight_name`` is neither 'mp3d' nor 'zind'
    """
    # NOTE: these options are written onto the module-level ``args`` object,
    # which is shared by every call of this function.
    args.pre_processing = pre_processing
    args.post_processing = post_processing

    if weight_name == 'mp3d':
        model = mp3d_model
    elif weight_name == 'zind':
        model = zind_model
    else:
        logger.error("unknown pre-trained weight name")
        raise NotImplementedError(f"unknown pre-trained weight name: {weight_name!r}")

    img_name = os.path.basename(img_path).split('.')[0]

    # The context manager releases the file handle as soon as the pixels are read.
    with Image.open(img_path) as src_img:
        # Grayscale / palette / CMYK inputs are converted to RGB up front: for a
        # 2-D (single channel) array the ``[..., :3]`` slice below would cut the
        # image width down to 3 columns instead of selecting colour channels.
        # RGB and RGBA inputs take the same path as before.
        if src_img.mode in ('RGB', 'RGBA'):
            rgb_ready = src_img
        else:
            rgb_ready = src_img.convert('RGB')
        resized = rgb_ready.resize((1024, 512), Image.Resampling.BICUBIC)
    # Drop the alpha channel of RGBA inputs.
    img = np.array(resized)[..., :3]

    vp_cache_path = 'src/demo/default_vp.txt'
    if args.pre_processing:
        vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, _ = preprocess(img, vp_cache_path=vp_cache_path)

    # Scale pixel values from [0, 255] to [0, 1].
    img = (img / 255.0).astype(np.float32)

    run_one_inference(img, model, args, img_name,
                      logger=logger, show=False,
                      show_depth='depth-normal-gradient' in visualization,
                      show_floorplan='2d-floorplan' in visualization,
                      mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))

    mesh_path = os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}")
    return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
            mesh_path,
            mesh_path,
            vp_cache_path,
            os.path.join(args.output_dir, f"{img_name}_pred.json")]
def get_model(args):
    """Build the model described by ``args.cfg``.

    The checkpoint for ``args.cfg`` is fetched first via ``down_ckpt``. When
    either ``args.device`` or the config's ``TRAIN.DEVICE`` mentions 'cuda' but
    CUDA is not available, both are switched to "cpu" before the model is built.

    :param args: namespace providing at least ``cfg`` and ``device``;
        ``device`` may be overwritten with "cpu"
    :return: the first element of the 4-tuple returned by ``build_model``
    """
    config = get_config(args)
    down_ckpt(args.cfg, config.CKPT.DIR)

    cuda_requested = 'cuda' in args.device or 'cuda' in config.TRAIN.DEVICE
    if cuda_requested and not torch.cuda.is_available():
        # Log before overwriting so the message shows the requested device.
        logger.info(f'The {args.device} is not available, will use cpu ...')
        # The config must be unfrozen to be edited, then frozen again.
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()

    # Only the model is needed here; the other three returned values are discarded.
    model, _, _, _ = build_model(config, logger)
    return model
if __name__ == '__main__':
    # Script entry point: load both pre-trained models, then serve the Gradio UI.
    logger = get_logger()
    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # ``greet`` reads ``mp3d_model`` / ``zind_model`` as module-level names.
    args.cfg = 'src/config/mp3d.yaml'
    mp3d_model = get_model(args)
    args.cfg = 'src/config/zind.yaml'
    zind_model = get_model(args)

    description = (
        "This demo of the project "
        "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. "
        "It uses the Geometry-Aware Transformer Network to predict the 3d room layout of an rgb panorama."
    )

    # Input widgets, in the same order as the parameters of ``greet``.
    input_components = [
        gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
        gr.Checkbox(label='pre-processing', value=True),
        gr.Radio(['mp3d', 'zind'], label='pre-trained weight', value='mp3d'),
        gr.Radio(['manhattan', 'atalanta', 'original'], label='post-processing method', value='manhattan'),
        gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'],
                         label='2d-visualization',
                         value=['depth-normal-gradient', '2d-floorplan']),
        gr.Radio(['.gltf', '.obj', '.glb'], label='output format of 3d mesh', value='.gltf'),
        gr.Radio(['128', '256', '512', '1024'], label='output resolution of 3d mesh', value='256'),
    ]

    # Output widgets, in the same order as the list returned by ``greet``.
    output_components = [
        gr.Image(label='predicted result 2d-visualization', type='filepath'),
        gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(label='3d mesh file'),
        gr.File(label='vanishing point information'),
        gr.File(label='layout json'),
    ]

    # (image path, pre-processing, weight, post-processing); every example uses
    # both 2d visualizations, the '.gltf' format and resolution '256'.
    example_specs = [
        ('src/demo/pano_demo1.png', True, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo2.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo3.png', False, 'mp3d', 'manhattan'),
        ('src/demo/zind_demo1.png', True, 'zind', 'manhattan'),
        ('src/demo/zind_demo2.png', False, 'zind', 'atalanta'),
        ('src/demo/zind_demo3.png', True, 'zind', 'manhattan'),
        ('src/demo/other_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/other_demo2.png', True, 'mp3d', 'manhattan'),
    ]
    example_rows = [
        [path, use_pre, weight, post, ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256']
        for path, use_pre, weight, post in example_specs
    ]

    demo = gr.Interface(fn=greet,
                        inputs=input_components,
                        outputs=output_components,
                        examples=example_rows,
                        title='LGT-Net',
                        allow_flagging="never",
                        cache_examples=False,
                        description=description)
    demo.launch(debug=True, enable_queue=False)