| """ | |
| @Date: 2021/07/17 | |
| @description: | |
| """ | |
import sys
import os
import shutil
import argparse
import json

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torch.cuda
from PIL import Image
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from config.defaults import get_config, get_rank_config
from models.other.criterion import calc_criterion
from models.build import build_model
from models.other.init_env import init_env
from utils.logger import build_logger
from utils.misc import tensor2np_d, tensor2np
from dataset.build import build_loader
from evaluation.accuracy import calc_accuracy, show_heat_map, calc_ce, calc_pe, calc_rmse_delta_1, \
    show_depth_normal_grad, calc_f1_score
from postprocessing.post_process import post_process
try:
    from apex import amp
except ImportError:
    amp = None
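
# NVIDIA apex is optional: the amp code paths below run only when
# config.AMP_OPT_LEVEL != "O0" on a CUDA device. If mixed precision is
# requested without apex installed, amp stays None and training will fail at
# amp.scale_loss, so either install apex or keep AMP_OPT_LEVEL at "O0".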

def parse_option():
    # sys.gettrace() returns a trace function when running under a debugger.
    debug = sys.gettrace() is not None
    parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
    parser.add_argument('--cfg',
                        type=str,
                        metavar='FILE',
                        help='path to config file')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        choices=['train', 'val', 'test'],
                        help='train/val/test mode')
    parser.add_argument('--val_name',
                        type=str,
                        choices=['val', 'test'],
                        help='val name')
    parser.add_argument('--bs', type=int,
                        help='batch size')
    parser.add_argument('--save_eval', action='store_true',
                        help='save eval result')
    parser.add_argument('--post_processing', type=str,
                        choices=['manhattan', 'atalanta', 'manhattan_old'],
                        help='type of post-processing')
    parser.add_argument('--need_cpe', action='store_true',
                        help='need to evaluate corner error and pixel error')
    parser.add_argument('--need_f1', action='store_true',
                        help='need to evaluate f1-score of corners')
    parser.add_argument('--need_rmse', action='store_true',
                        help='need to evaluate root mean squared error and delta error')
    parser.add_argument('--force_cube', action='store_true',
                        help='force cube shape when eval')
    parser.add_argument('--wall_num', type=int,
                        help='wall number')
    args = parser.parse_args()
    args.debug = debug
    print("arguments:")
    for arg in vars(args):
        print(arg, ":", getattr(args, arg))
    print("-" * 50)
    return args
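
# Example invocations (script and config names are illustrative, not fixed by this file):
#   python main.py --cfg src/config/mp3d.yaml --mode train --bs 8
#   python main.py --cfg src/config/mp3d.yaml --mode test --save_eval \
#       --post_processing manhattan --force_cube --need_cpe --need_f1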

def main():
    args = parse_option()
    config = get_config(args)

    if config.TRAIN.SCRATCH and os.path.exists(config.CKPT.DIR) and config.MODE == 'train':
        print(f"Train from scratch, delete checkpoint dir: {config.CKPT.DIR}")
        epochs = [int(name.split('_')[-1].split('.')[0]) for name in os.listdir(config.CKPT.DIR) if 'pkl' in name]
        if len(epochs) > 0:
            last_epoch = np.array(epochs).max()
            if last_epoch > 10:
                c = input(f"delete it (last_epoch: {last_epoch})? (Y/N)\n")
                if c != 'y' and c != 'Y':
                    exit(0)
        shutil.rmtree(config.CKPT.DIR, ignore_errors=True)

    os.makedirs(config.CKPT.DIR, exist_ok=True)
    os.makedirs(config.CKPT.RESULT_DIR, exist_ok=True)
    os.makedirs(config.LOGGER.DIR, exist_ok=True)

    # Default to a single process; a device string like "cuda:0,1" requests one
    # process per listed GPU. Without this default, a bare "cuda" device would
    # leave nprocs undefined below.
    nprocs = 1
    if ':' in config.TRAIN.DEVICE:
        nprocs = len(config.TRAIN.DEVICE.split(':')[-1].split(','))
    if 'cuda' in config.TRAIN.DEVICE:
        if not torch.cuda.is_available():
            print(f"Cuda is not available (config is: {config.TRAIN.DEVICE}), will use cpu ...")
            config.defrost()
            config.TRAIN.DEVICE = "cpu"
            config.freeze()
            nprocs = 1

    if config.MODE == 'train':
        with open(os.path.join(config.CKPT.DIR, "config.yaml"), "w") as f:
            f.write(config.dump(allow_unicode=True))

    if config.TRAIN.DEVICE == 'cpu' or nprocs < 2:
        print(f"Use single process, device: {config.TRAIN.DEVICE}")
        main_worker(0, config, 1)
    else:
        print(f"Use {nprocs} processes ...")
        mp.spawn(main_worker, nprocs=nprocs, args=(config, nprocs), join=True)
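
# Note: torch.multiprocessing.spawn invokes main_worker(i, *args) for each
# i in range(nprocs), so local_rank arrives as the first positional argument.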

def main_worker(local_rank, cfg, world_size):
    config = get_rank_config(cfg, local_rank, world_size)
    logger = build_logger(config)
    writer = SummaryWriter(config.CKPT.DIR)
    logger.info(f"Comment: {config.COMMENT}")
    cur_pid = os.getpid()
    logger.info(f"Current process id: {cur_pid}")
    torch.hub.set_dir(config.CKPT.PYTORCH)  # public API instead of the private torch.hub._hub_dir
    logger.info(f"Pytorch hub dir: {torch.hub.get_dir()}")
    init_env(config.SEED, config.TRAIN.DETERMINISTIC, config.DATA.NUM_WORKERS)

    model, optimizer, criterion, scheduler = build_model(config, logger)
    train_data_loader, val_data_loader = build_loader(config, logger)

    if 'cuda' in config.TRAIN.DEVICE:
        torch.cuda.set_device(config.TRAIN.DEVICE)

    if config.MODE == 'train':
        train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler)
    else:
        iou_results, other_results = val_an_epoch(model, val_data_loader,
                                                  criterion, config, logger, writer=None,
                                                  epoch=config.TRAIN.START_EPOCH)
        results = dict(iou_results, **other_results)
        if config.SAVE_EVAL:
            save_path = os.path.join(config.CKPT.RESULT_DIR, "result.json")
            with open(save_path, 'w+') as f:
                json.dump(results, f, indent=4)
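
# model.save and model.acc_d below are project-specific helpers defined on the
# model wrapper; acc_d appears to track the best value seen so far per metric.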

def save(model, optimizer, epoch, iou_d, logger, writer, config):
    model.save(optimizer, epoch, accuracy=iou_d['full_3d'], logger=logger, acc_d=iou_d, config=config)
    for k in model.acc_d:
        writer.add_scalar(f"BestACC/{k}", model.acc_d[k]['acc'], epoch)

def train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler):
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        logger.info("=" * 200)
        train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch)
        epoch_iou_d, _ = val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch)

        if config.LOCAL_RANK == 0:
            ddp = config.WORLD_SIZE > 1
            save(model.module if ddp else model, optimizer, epoch, epoch_iou_d, logger, writer, config)

        if scheduler is not None:
            # Stop decaying once the learning rate has reached its floor.
            if scheduler.min_lr is not None and optimizer.param_groups[0]['lr'] <= scheduler.min_lr:
                continue
            scheduler.step()
    writer.close()

def train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch=0):
    logger.info(f'Start Train Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
    model.train()
    if len(config.MODEL.FINE_TUNE) > 0:
        # When fine-tuning, keep the feature extractor in eval mode
        # (freezes BatchNorm statistics and disables dropout).
        model.feature_extractor.eval()

    optimizer.zero_grad()
    data_len = len(train_data_loader)
    start_i = data_len * epoch * config.WORLD_SIZE
    bar = enumerate(train_data_loader)
    if config.LOCAL_RANK == 0 and config.SHOW_BAR:
        bar = tqdm(bar, total=data_len, ncols=200)
    device = config.TRAIN.DEVICE
    epoch_loss_d = {}
    for i, gt in bar:
        imgs = gt['image'].to(device, non_blocking=True)
        gt['depth'] = gt['depth'].to(device, non_blocking=True)
        gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
        if 'corner_heat_map' in gt:
            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)

        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
            imgs = imgs.type(torch.float16)
            gt['depth'] = gt['depth'].type(torch.float16)
            gt['ratio'] = gt['ratio'].type(torch.float16)

        dt = model(imgs)
        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
        if config.LOCAL_RANK == 0 and config.SHOW_BAR:
            bar.set_postfix(batch_loss_d)

        optimizer.zero_grad()
        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Interleave per-rank steps so TensorBoard curves from different
        # ranks land on distinct global steps.
        global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
        for key, val in batch_loss_d.items():
            writer.add_scalar(f'TrainBatchLoss/{key}', val, global_step)

    if config.LOCAL_RANK != 0:
        return

    epoch_loss_d = {k: np.array(v).mean() for k, v in epoch_loss_d.items()}
    s = 'TrainEpochLoss: '
    for key, val in epoch_loss_d.items():
        writer.add_scalar(f'TrainEpochLoss/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)
    writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)
    logger.info(f"LearningRate: {optimizer.param_groups[0]['lr']}")

@torch.no_grad()  # no gradients are needed during validation
def val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch=0):
    model.eval()
    logger.info(f'Start Validate Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
    data_len = len(val_data_loader)
    start_i = data_len * epoch * config.WORLD_SIZE
    bar = enumerate(val_data_loader)
    if config.LOCAL_RANK == 0 and config.SHOW_BAR:
        bar = tqdm(bar, total=data_len, ncols=200)
    device = config.TRAIN.DEVICE
    epoch_loss_d = {}
    epoch_iou_d = {
        'visible_2d': [],
        'visible_3d': [],
        'full_2d': [],
        'full_3d': [],
        'height': []
    }
    epoch_other_d = {
        'ce': [],
        'pe': [],
        'f1': [],
        'precision': [],
        'recall': [],
        'rmse': [],
        'delta_1': []
    }

    show_index = np.random.randint(0, data_len)
    for i, gt in bar:
        imgs = gt['image'].to(device, non_blocking=True)
        gt['depth'] = gt['depth'].to(device, non_blocking=True)
        gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
        if 'corner_heat_map' in gt:
            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
        dt = model(imgs)

        vis_w = config.TRAIN.VIS_WEIGHT
        visualization = False  # (config.LOCAL_RANK == 0 and i == show_index) or config.SAVE_EVAL

        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)

        if config.EVAL.POST_PROCESSING is not None:
            depth = tensor2np(dt['depth'])
            dt['processed_xyz'] = post_process(depth, type_name=config.EVAL.POST_PROCESSING,
                                               need_cube=config.EVAL.FORCE_CUBE)

        if config.EVAL.FORCE_CUBE and config.EVAL.NEED_CPE:
            ce = calc_ce(tensor2np_d(dt), tensor2np_d(gt))
            pe = calc_pe(tensor2np_d(dt), tensor2np_d(gt))
            epoch_other_d['ce'].append(ce)
            epoch_other_d['pe'].append(pe)

        if config.EVAL.NEED_F1:
            f1, precision, recall = calc_f1_score(tensor2np_d(dt), tensor2np_d(gt))
            epoch_other_d['f1'].append(f1)
            epoch_other_d['precision'].append(precision)
            epoch_other_d['recall'].append(recall)

        if config.EVAL.NEED_RMSE:
            rmse, delta_1 = calc_rmse_delta_1(tensor2np_d(dt), tensor2np_d(gt))
            epoch_other_d['rmse'].append(rmse)
            epoch_other_d['delta_1'].append(delta_1)

        visb_iou, full_iou, iou_height, pano_bds, full_iou_2ds = calc_accuracy(tensor2np_d(dt), tensor2np_d(gt),
                                                                               visualization, h=vis_w // 2)
        epoch_iou_d['visible_2d'].append(visb_iou[0])
        epoch_iou_d['visible_3d'].append(visb_iou[1])
        epoch_iou_d['full_2d'].append(full_iou[0])
        epoch_iou_d['full_3d'].append(full_iou[1])
        epoch_iou_d['height'].append(iou_height)

        if config.LOCAL_RANK == 0 and config.SHOW_BAR:
            bar.set_postfix(batch_loss_d)

        global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
        if writer:
            for key, val in batch_loss_d.items():
                writer.add_scalar(f'ValBatchLoss/{key}', val, global_step)

        if not visualization:
            continue

        gt_grad_imgs, dt_grad_imgs = show_depth_normal_grad(dt, gt, device, vis_w)
        dt_heat_map_imgs = None
        gt_heat_map_imgs = None
        if 'corner_heat_map' in gt:
            dt_heat_map_imgs, gt_heat_map_imgs = show_heat_map(dt, gt, vis_w)

        if config.TRAIN.VIS_MERGE or config.SAVE_EVAL:
            imgs = []
            for j in range(len(pano_bds)):
                # floorplan = np.concatenate([visb_iou[2][j], full_iou[2][j]], axis=-1)
                # floorplan is prepared here but not currently merged into the output image.
                floorplan = full_iou[2][j]
                margin_w = int(floorplan.shape[-1] * (60 / 512))
                floorplan = floorplan[:, :, margin_w:-margin_w]

                grad_h = dt_grad_imgs[0].shape[1]
                vis_merge = [
                    gt_grad_imgs[j],
                    pano_bds[j][:, grad_h:-grad_h],
                    dt_grad_imgs[j]
                ]
                if 'corner_heat_map' in gt:
                    vis_merge = [dt_heat_map_imgs[j], gt_heat_map_imgs[j]] + vis_merge
                img = np.concatenate(vis_merge, axis=-2)
                img = np.concatenate([img, ], axis=-1)
                # img = gt_grad_imgs[j]
                imgs.append(img)
            if writer:
                writer.add_images('VIS/Merge', np.array(imgs), global_step)
            if config.SAVE_EVAL:
                for k in range(len(imgs)):
                    img = imgs[k] * 255.0
                    save_path = os.path.join(config.CKPT.RESULT_DIR, f"{gt['id'][k]}_{full_iou_2ds[k]:.5f}.png")
                    Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8)).save(save_path)
        elif writer:
            writer.add_images('IoU/Visible_Floorplan', visb_iou[2], global_step)
            writer.add_images('IoU/Full_Floorplan', full_iou[2], global_step)
            writer.add_images('IoU/Boundary', pano_bds, global_step)
            writer.add_images('Grad/gt', gt_grad_imgs, global_step)
            writer.add_images('Grad/dt', dt_grad_imgs, global_step)

    if config.LOCAL_RANK != 0:
        # Keep the (iou, other) tuple shape so callers can always unpack.
        return None, None

    epoch_loss_d = {k: np.array(v).mean() for k, v in epoch_loss_d.items()}
    s = 'ValEpochLoss: '
    for key, val in epoch_loss_d.items():
        if writer:
            writer.add_scalar(f'ValEpochLoss/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)

    epoch_iou_d = {k: np.array(v).mean() for k, v in epoch_iou_d.items()}
    s = 'ValEpochIoU: '
    for key, val in epoch_iou_d.items():
        if writer:
            writer.add_scalar(f'ValEpochIoU/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)

    epoch_other_d = {k: np.array(v).mean() if len(v) > 0 else 0 for k, v in epoch_other_d.items()}
    logger.info(f'other acc: {epoch_other_d}')
    return epoch_iou_d, epoch_other_d

if __name__ == '__main__':
    main()
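
# A minimal sketch of the config keys this script reads (the full schema
# presumably lives in config/defaults.py; values below are illustrative
# placeholders, not recommendations):
#
#   MODE: train
#   SEED: 123
#   TRAIN:
#     DEVICE: 'cuda:0'
#     SCRATCH: False
#     START_EPOCH: 0
#     EPOCHS: 1000
#   CKPT:
#     DIR: 'checkpoints/my_run'
#   EVAL:
#     POST_PROCESSING: 'manhattan'
#     FORCE_CUBE: False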