Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import cv2 | |
| import os | |
| from loss import batch_episym | |
| from tqdm import tqdm | |
| import sys | |
| ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
| sys.path.insert(0, ROOT_DIR) | |
| from utils import evaluation_utils, train_utils | |
def valid(valid_loader, model, match_loss, config, model_config):
    """Run one pass over the validation loader and return reduced statistics.

    The model is switched to eval mode for the pass and back to train mode
    before returning. Every batch is moved with ``train_utils.tocuda``, fed
    through ``model`` and scored with ``match_loss.run``; the resulting
    entries are accumulated across batches.

    Args:
        valid_loader: iterable yielding validation batches.
        model: network called as ``model(batch)``.
        match_loss: object whose ``run(batch, result)`` returns a dict with
            at least ``acc_corr``, ``acc_incorr`` and ``total_loss`` (plus
            ``mid_acc_corr``, ``pre_seed_conf`` and ``recall_seed_conf`` when
            ``config.model_name == "SGM"``).
        config: provides ``local_rank`` and ``model_name``.
        model_config: provides ``layer_num`` and ``seedlayer``.

    Returns:
        Tuple ``(total_loss, total_acc_corr, total_acc_incorr,
        total_precision, total_recall, total_acc_mid)``. ``total_loss`` is the
        un-averaged sum over batches passed through
        ``train_utils.reduce_tensor(..., "sum")``; the other five are divided
        by the number of batches and passed through
        ``train_utils.reduce_tensor(..., "mean")``.
        NOTE(review): reduce_tensor presumably aggregates across distributed
        workers -- confirm against train_utils.
    """
    model.eval()
    batches = iter(valid_loader)
    batch_count = 0

    # Scalar statistics start as plain ints and become whatever type the loss
    # dict supplies after the first accumulation.
    loss_sum = 0
    acc_corr = 0
    acc_incorr = 0
    precision = torch.zeros(model_config.layer_num, device="cuda")
    recall = torch.zeros(model_config.layer_num, device="cuda")
    acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device="cuda")

    with torch.no_grad():
        # Only rank 0 shows a progress bar.
        if config.local_rank == 0:
            batches = tqdm(batches)
            print("validating...")

        for batch in batches:
            batch_count += 1
            batch = train_utils.tocuda(batch)
            prediction = model(batch)
            stats = match_loss.run(batch, prediction)

            acc_corr += stats["acc_corr"]
            acc_incorr += stats["acc_incorr"]
            loss_sum += stats["total_loss"]

            # The seeding statistics are only accumulated for the SGM model.
            if config.model_name != "SGM":
                continue
            acc_mid += stats["mid_acc_corr"]
            precision = precision + stats["pre_seed_conf"]
            recall = recall + stats["recall_seed_conf"]

        # Turn the running sums into per-batch means (the loss stays a sum).
        acc_corr /= batch_count
        acc_incorr /= batch_count
        precision /= batch_count
        recall /= batch_count
        acc_mid /= batch_count

        # apply tensor reduction
        loss_sum = train_utils.reduce_tensor(loss_sum, "sum")
        acc_corr = train_utils.reduce_tensor(acc_corr, "mean")
        acc_incorr = train_utils.reduce_tensor(acc_incorr, "mean")
        precision = train_utils.reduce_tensor(precision, "mean")
        recall = train_utils.reduce_tensor(recall, "mean")
        acc_mid = train_utils.reduce_tensor(acc_mid, "mean")

    model.train()
    return loss_sum, acc_corr, acc_incorr, precision, recall, acc_mid
def dump_train_vis(res, data, step, config):
    """Write a match visualisation image for every pair in a training batch.

    Matches are taken from ``res["p"]`` (last row and column dropped): a
    correspondence is kept when its best score exceeds 0.2 and the argmax is
    mutual in both directions. Each kept match is flagged as inlier when its
    epipolar distance under ``data["e_gt"]`` is below ``config.inlier_th``.
    The drawing is saved to
    ``<train_vis_folder>/train_vis/<log_base>/<step>/<name>.png``.

    Args:
        res: model output dict; ``res["p"]`` is a batched score matrix.
            NOTE(review): presumably the last row/column is a dustbin
            (unmatched) slot -- confirm against the model.
        data: batch dict with ``x1``, ``x2``, ``kpt1``, ``kpt2``, ``e_gt``,
            ``img_path1`` and ``img_path2``.
        step: training step, used as a sub-folder name.
        config: provides ``inlier_th``, ``train_vis_folder`` and ``log_base``.
    """
    # batch matching
    p = res["p"][:, :-1, :-1]
    score, index1 = torch.max(p, dim=-1)
    _, index2 = torch.max(p, dim=-2)
    mask_th = score > 0.2
    # Mutual check: row i's best column must pick row i back. The index is
    # built on p's own device instead of the current CUDA device.
    row_ids = torch.arange(p.shape[1], device=p.device)[None]
    mask_mc = index2.gather(index=index1, dim=1) == row_ids
    mask_p = mask_th & mask_mc  # B*N

    pair_index = index1[:, :, None].expand(-1, -1, 2)
    corr1, corr2 = data["x1"], data["x2"].gather(index=pair_index, dim=1)
    corr1_kpt, corr2_kpt = data["kpt1"], data["kpt2"].gather(index=pair_index, dim=1)
    epi_dis = batch_episym(corr1, corr2, data["e_gt"])
    mask_inlier = epi_dis < config.inlier_th  # B*N

    # dump vis
    for cur_mask_p, cur_mask_inlier, cur_corr1, cur_corr2, img_path1, img_path2 in zip(
        mask_p, mask_inlier, corr1_kpt, corr2_kpt, data["img_path1"], data["img_path2"]
    ):
        img1, img2 = cv2.imread(img_path1), cv2.imread(img_path2)
        # The inlier flags are filtered with the same mask as the keypoints so
        # that flag i describes drawn match i.
        # NOTE(review): assumes draw_match pairs inlier[i] with match i.
        dis_play = evaluation_utils.draw_match(
            img1,
            img2,
            cur_corr1[cur_mask_p].cpu().numpy(),
            cur_corr2[cur_mask_p].cpu().numpy(),
            inlier=cur_mask_inlier[cur_mask_p].cpu().numpy(),
        )
        # <image1 file>_<image2 file>_<image1 parent folder>
        path1_parts = img_path1.split("/")
        base_name_seq = "_".join(
            (path1_parts[-1], img_path2.split("/")[-1], path1_parts[-2])
        )
        save_path = os.path.join(
            config.train_vis_folder,
            "train_vis",
            config.log_base,
            str(step),
            base_name_seq + ".png",
        )
        # cv2.imwrite does not create missing directories and only reports
        # failure through its return value, so make sure the folder exists.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        cv2.imwrite(save_path, dis_play)