Spaces:
Runtime error
Runtime error
| import torch | |
| from ..utils.tensor import batch_to_device | |
| from .viz2d import cm_RdGn, plot_heatmaps, plot_image_grid, plot_keypoints, plot_matches | |
def make_match_figures(pred_, data_, n_pairs=2):
    """Draw the predicted matches of the first ``n_pairs`` image pairs of a batch.

    Args:
        pred_: Prediction dict. If it has a ``"0to1"`` entry, that sub-dict is
            used instead. Reads ``"keypoints0"``, ``"keypoints1"``,
            ``"matches0"`` and ``"gt_matches0"``; optionally ``"heatmap0"`` /
            ``"heatmap1"``.
        data_: Batch dict with ``"view0"`` and ``"view1"`` entries, each holding
            an ``"image"`` tensor and optionally a ``"depth"`` tensor.
        n_pairs: Maximum number of pairs to draw; clipped to the batch size.

    Returns:
        ``{"matching": fig}`` where ``fig`` is the figure returned by
        ``plot_image_grid``.
    """
    if "0to1" in pred_:
        pred_ = pred_["0to1"]
    # Move everything to CPU so tensors can be converted with .numpy() / plotted.
    pred = batch_to_device(pred_, "cpu", non_blocking=False)
    data = batch_to_device(data_, "cpu", non_blocking=False)
    view0, view1 = data["view0"], data["view1"]

    # Never ask for more pairs than the batch holds.
    n_pairs = min(n_pairs, view0["image"].shape[0])
    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0 = pred["matches0"]
    gtm0 = pred["gt_matches0"]

    # Pick the background overlay once: predicted heatmaps take precedence over
    # depth maps; with neither, no overlay is drawn.
    use_heatmaps = "heatmap0" in pred
    use_depth = "depth" in view0 and view0["depth"] is not None

    images, kpts, matches, mcolors, heatmaps = [], [], [], [], []
    for i in range(n_pairs):
        # Keep predicted matches (m0 > -1) whose ground-truth label is >= -1.
        # NOTE(review): labels below -1 presumably mean "ignore / unknown GT";
        # confirm against the ground-truth matcher.
        valid = (m0[i] > -1) & (gtm0[i] >= -1)
        kpm0 = kp0[i][valid].numpy()
        kpm1 = kp1[i][m0[i][valid]].numpy()

        # Images are permuted from (C, H, W) to (H, W, C) for plotting.
        images.append(
            [view0["image"][i].permute(1, 2, 0), view1["image"][i].permute(1, 2, 0)]
        )
        kpts.append([kp0[i], kp1[i]])
        matches.append((kpm0, kpm1))

        if use_heatmaps:
            # Presumably logits; sigmoid maps them to [0, 1] for display.
            heatmaps.append(
                [
                    torch.sigmoid(pred["heatmap0"][i, 0]),
                    torch.sigmoid(pred["heatmap1"][i, 0]),
                ]
            )
        elif use_depth:
            heatmaps.append([view0["depth"][i], view1["depth"][i]])

        # Each kept match is colored by whether it agrees with the ground truth.
        correct = gtm0[i][valid] == m0[i][valid]
        mcolors.append(cm_RdGn(correct).tolist())

    fig, axes = plot_image_grid(images, return_fig=True, set_lim=True)
    # Per row, draw the overlay first, then the keypoints, then the match lines.
    for i in range(n_pairs):
        if heatmaps:
            plot_heatmaps(heatmaps[i], axes=axes[i], a=1.0)
        plot_keypoints(kpts[i], axes=axes[i], colors="royalblue")
        plot_matches(
            *matches[i], color=mcolors[i], axes=axes[i], a=0.5, lw=1.0, ps=0.0
        )
    return {"matching": fig}