Puffin / scripts /camera /visualization /visualize_batch.py
KangLiao's picture
init
ace9173
"""Visualization of predicted and ground truth for a single batch."""
"""Adapted from https://github.com/cvg/GeoCalib"""
from typing import Any, Dict
import numpy as np
import torch
from scripts.camera.geometry.perspective_fields import get_latitude_field
from scripts.camera.utils.conversions import rad2deg
from scripts.camera.utils.tensor import batch_to_device
from scripts.camera.visualization.viz2d import (
plot_confidences,
plot_heatmaps,
plot_image_grid,
plot_latitudes,
plot_vector_fields,
)
def make_up_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Build a figure with the ground-truth up field and, if present, the up confidence.

    Every row of the figure corresponds to one sample of the batch. Column 0 is
    the plain image, column 1 receives the ground-truth up field and, when
    ``pred`` contains ``"up_confidence"``, column 2 receives that confidence map.

    Args:
        pred (Dict[str, torch.Tensor]): Predictions. Must contain "up_field" for a
            figure to be produced; "up_confidence" is optional.
        data (Dict[str, torch.Tensor]): Batch holding "image" and the ground-truth
            "up_field".
        n_pairs (int): Maximum number of samples to visualize (capped at the
            number of images in ``data``).

    Returns:
        Dict[str, Any]: ``{"up": figure}``, or an empty dict when ``pred`` has no
        "up_field" entry.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))

    if "up_field" not in pred:
        return {}

    has_confidence = "up_confidence" in pred

    up_fields = []
    for i in range(n_pairs):
        row = [data["up_field"][i]]
        if has_confidence:
            row.append(pred["up_confidence"][i])

        row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
        up_fields.append(row)

    # create figure: one extra leading column that shows the image alone
    n_cols = len(up_fields[0]) + 1
    imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * n_cols for i in range(n_pairs)]
    fig, ax = plot_image_grid(imgs, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for i in range(n_pairs):
        plot_vector_fields([up_fields[i][0]], axes=ax[i, [1]])

        if has_confidence:
            # The confidence is the second entry of the row and belongs in the
            # column right after the GT field. The previous indices (3 and 4)
            # pointed past the end of both the row and the axes grid.
            plot_confidences([up_fields[i][1]], axes=ax[i, [2]])

    return {"up": fig}
def make_latitude_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Build a figure with the ground-truth latitude field and, if present, its confidence.

    Every row of the figure corresponds to one sample of the batch. Column 0 is
    the plain image, column 1 receives the ground-truth latitude field (converted
    to degrees) and, when ``pred`` contains ``"latitude_confidence"``, column 2
    receives that confidence map.

    Args:
        pred (Dict[str, torch.Tensor]): Predictions. Must contain "latitude_field"
            for a figure to be produced; "latitude_confidence" is optional.
        data (Dict[str, torch.Tensor]): Batch holding "image" and the ground-truth
            "latitude_field".
        n_pairs (int, optional): Maximum number of samples to visualize (capped at
            the number of images in ``data``). Defaults to 2.

    Returns:
        Dict[str, Any]: ``{"latitude": figure}``, or an empty dict when ``pred`` has
        no "latitude_field" entry.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))

    if "latitude_field" not in pred:
        return {}

    has_confidence = "latitude_confidence" in pred

    latitude_fields = []
    for i in range(n_pairs):
        # Only the first channel of the GT field is shown, in degrees.
        row = [rad2deg(data["latitude_field"][i][0])]
        if has_confidence:
            row.append(pred["latitude_confidence"][i])

        row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
        latitude_fields.append(row)

    # create figure: one extra leading column that shows the image alone
    n_cols = len(latitude_fields[0]) + 1
    imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * n_cols for i in range(n_pairs)]
    fig, ax = plot_image_grid(imgs, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for i in range(n_pairs):
        plot_latitudes([latitude_fields[i][0]], is_radians=False, axes=ax[i, [1]])

        if has_confidence:
            # The confidence is the second entry of the row and belongs in the
            # column right after the GT field. The previous indices (3 and 4)
            # pointed past the end of both the row and the axes grid.
            plot_confidences([latitude_fields[i][1]], axes=ax[i, [2]])

    return {"latitude": fig}
def make_camera_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Plot latitude fields derived from ground-truth and predicted cameras.

    For each visualized sample a latitude field is computed from the ground-truth
    camera and gravity; if ``pred`` also provides "gravity", a second field is
    computed from the predicted camera and gravity. The fields are converted to
    degrees and drawn next to the plain image.

    Args:
        pred (Dict[str, torch.Tensor]): Predictions; needs "camera" for a figure to
            be produced and "gravity" for the prediction column.
        data (Dict[str, torch.Tensor]): Batch holding "image", "camera" and "gravity".
        n_pairs (int, optional): Maximum number of samples to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: ``{"camera": figure}``, or an empty dict when ``pred`` has
        no "camera" entry.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))

    if "camera" not in pred:
        return {}

    # "camera" is guaranteed from here on, so only "gravity" decides whether the
    # prediction column is drawn.
    show_pred = "gravity" in pred
    titles = ["Cameras GT", "Cameras Pred"] if show_pred else ["Cameras GT"]

    def _as_degree_map(field):
        return rad2deg(field).squeeze(-1).float().numpy()[0]

    latitudes = []
    for idx in range(n_pairs):
        fields = [get_latitude_field(data["camera"][idx], data["gravity"][idx])]
        if show_pred:
            fields.append(get_latitude_field(pred["camera"][idx], pred["gravity"][idx]))
        latitudes.append([_as_degree_map(field) for field in fields])

    # create figure: one extra leading column that shows the image alone
    n_rows = len(latitudes)
    n_cols = len(latitudes[0]) + 1
    images = [
        [data["image"][idx].permute(1, 2, 0).cpu().clip(0, 1)] * n_cols for idx in range(n_pairs)
    ]
    header = ["Image"] + titles
    fig, ax = plot_image_grid(images, titles=[header] * n_rows, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for idx, row in enumerate(latitudes):
        plot_latitudes(row, is_radians=False, axes=ax[idx, 1:])

    return {"camera": fig}
def make_perspective_figures(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Collect the up-field and latitude-field figures for one batch.

    Args:
        pred (Dict[str, torch.Tensor]): Predicted perspective fields.
        data (Dict[str, torch.Tensor]): Ground truth perspective fields; must
            contain "image".
        n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: Figures keyed by name, merged from ``make_up_figure`` and
        ``make_latitude_figure``. A sub-figure that was not produced is simply
        absent from the dict.
    """
    n_pairs = min(n_pairs, len(data["image"]))
    figures = make_up_figure(pred, data, n_pairs)
    figures |= make_latitude_figure(pred, data, n_pairs)
    # NOTE: the camera figure is currently disabled.
    # figures |= make_camera_figure(pred, data, n_pairs)

    # Plain loop: tight_layout() is called for its side effect only.
    for fig in figures.values():
        fig.tight_layout()

    return figures