# Talk2DINO / src/plot.py
# Committed by lorebianchi98 — commit 593b176 ("First commit")
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
def plot_qualitative(image, sim, palette, texts, alpha=0.6, legend_height=0.1):
    """Render a segmentation overlay plus a class legend into an RGB array.

    The image is drawn at its native pixel size with the color-coded mask
    alpha-blended on top. A legend strip is appended below it, containing one
    colored swatch and label for every class ID present in ``sim``.

    Args:
        image: HxWx3 uint8 image.
        sim: HxW segmentation mask with integer class IDs. Every ID present
            must be a valid index into ``palette`` and ``texts``.
        palette: list of [R, G, B] colors (0-255), indexed by class ID.
        texts: list of class names, indexed by class ID.
        alpha: opacity of the segmentation overlay drawn over the image.
        legend_height: height of the legend strip as a fraction of the image
            height (the figure is ``H * (1 + legend_height)`` pixels tall).

    Returns:
        A uint8 array of shape (~H * (1 + legend_height), W, 3) holding the
        rendered figure (alpha channel dropped).
    """
    # Render off-screen with the Agg canvas directly. This way the result
    # does not depend on the active pyplot backend, and no figure is
    # registered in (or leaked from) pyplot's global figure manager.
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    unique_classes = np.unique(sim)
    num_classes = len(unique_classes)

    # Color-code the mask: each pixel gets the palette color of its class.
    qualitative_plot = np.zeros((sim.shape[0], sim.shape[1], 3), dtype=np.uint8)
    for cls in unique_classes:
        qualitative_plot[sim == cls] = np.array(palette[cls])

    # Normalize images to [0, 1] for alpha blending.
    img_float = image.astype(np.float32) / 255.0
    overlay_float = qualitative_plot.astype(np.float32) / 255.0

    # One figure inch == `dpi` pixels, so the image keeps its native size.
    dpi = 100
    fig_height = img_float.shape[0] / dpi
    fig_width = img_float.shape[1] / dpi
    fig = Figure(figsize=(fig_width, fig_height * (1 + legend_height)), dpi=dpi)
    canvas = FigureCanvasAgg(fig)

    # The figure is (1 + legend_height) image-heights tall, so the legend
    # occupies legend_height / (1 + legend_height) of it. Splitting at this
    # fraction makes the image axis exactly H pixels tall, so imshow does
    # not shrink the image or add side margins.
    legend_frac = legend_height / (1 + legend_height)

    # Main image axis, with the mask overlay blended on top.
    ax_img = fig.add_axes([0, legend_frac, 1, 1 - legend_frac])
    ax_img.imshow(img_float)
    ax_img.imshow(overlay_float, alpha=alpha)
    ax_img.axis("off")

    # Legend axis below the image.
    ax_legend = fig.add_axes([0, 0, 1, legend_frac])
    ax_legend.axis("off")

    if num_classes:
        # Each class gets an equal horizontal slot; the swatch fills 80% of it.
        rect_width = 1 / num_classes * 0.8
        for idx, cls in enumerate(unique_classes):
            color = np.array(palette[cls]) / 255.0
            x0 = idx / num_classes
            # Rectangle: (x, y), width, height in legend-axis coordinates.
            ax_legend.add_patch(
                Rectangle((x0, 0.1), rect_width, 0.6, facecolor=color)
            )
            # Label centered horizontally above its swatch.
            ax_legend.text(x0 + rect_width / 2, 0.8, texts[cls],
                           ha='center', va='bottom', fontsize=10)

    # Rasterize and extract the RGBA buffer as a NumPy array.
    canvas.draw()
    buf = np.asarray(canvas.buffer_rgba())
    return buf[:, :, :3].copy()  # drop alpha; copy detaches from the canvas buffer