import numpy as np
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle


def plot_qualitative(image, sim, palette, texts, alpha=0.6, legend_height=0.1):
    """Render a segmentation overlay with a colour legend and return its pixels.

    The class mask ``sim`` is painted with ``palette`` colours, alpha-blended
    over ``image``, and a legend band (one swatch + label per class present in
    ``sim``) is drawn underneath. The figure is rasterised off-screen with the
    Agg backend, so no window is opened and pyplot's figure registry is not
    touched.

    Args:
        image: HxWx3 uint8 RGB image. An HxW grayscale image is also accepted
            and replicated to three channels.
        sim: HxW segmentation mask with integer class IDs. IDs are used
            directly as indices into ``palette`` and ``texts`` (note: negative
            IDs index from the end of those lists).
        palette: list of [R, G, B] colours in 0-255, indexed by class ID.
        texts: list of class names, indexed by class ID.
        alpha: opacity of the colour overlay (0 = invisible, 1 = opaque).
        legend_height: height of the legend band as a fraction of the image
            height (the output is about ``H * (1 + legend_height)`` rows tall).
            0 disables the legend.

    Returns:
        uint8 RGB array of shape (~H * (1 + legend_height), W, 3).

    Raises:
        ValueError: if ``legend_height`` is negative or ``image`` and ``sim``
            do not have the same spatial (H, W) shape.
        IndexError: if a class ID in ``sim`` is out of range for ``palette``
            or ``texts``.
    """
    if legend_height < 0:
        raise ValueError(f"legend_height must be >= 0, got {legend_height}")
    if image.shape[:2] != sim.shape:
        raise ValueError(
            f"image spatial shape {image.shape[:2]} does not match "
            f"mask shape {sim.shape}"
        )

    # Paint each class present in the mask with its palette colour.
    class_ids = np.unique(sim)
    overlay = np.zeros((sim.shape[0], sim.shape[1], 3), dtype=np.uint8)
    for cls in class_ids:
        overlay[sim == cls] = np.array(palette[cls])

    # Normalise to [0, 1] floats for alpha blending in imshow.
    img_float = image.astype(np.float32) / 255.0
    if img_float.ndim == 2:
        # imshow would apply a colormap to a 2-D array; show it as gray RGB.
        img_float = np.repeat(img_float[:, :, None], 3, axis=2)
    overlay_float = overlay.astype(np.float32) / 255.0

    # One figure inch per 100 image pixels, plus room for the legend band.
    dpi = 100
    img_h_in = sim.shape[0] / dpi
    img_w_in = sim.shape[1] / dpi
    fig = Figure(figsize=(img_w_in, img_h_in * (1 + legend_height)), dpi=dpi)
    canvas = FigureCanvasAgg(fig)

    # The figure is (1 + legend_height) image-heights tall, so the image axes
    # must occupy 1 / (1 + legend_height) of it to be exactly H pixels tall;
    # the remaining legend_height / (1 + legend_height) is the legend band.
    legend_frac = legend_height / (1 + legend_height)

    ax_img = fig.add_axes([0, legend_frac, 1, 1 - legend_frac])
    ax_img.imshow(img_float)
    ax_img.imshow(overlay_float, alpha=alpha)
    ax_img.axis("off")

    if legend_height > 0 and class_ids.size:
        ax_legend = fig.add_axes([0, 0, 1, legend_frac])
        ax_legend.axis("off")

        # Each class gets an equal horizontal slot; the swatch fills 80% of it.
        slot = 1 / len(class_ids)
        rect_width = slot * 0.8
        for idx, cls in enumerate(class_ids):
            x0 = idx * slot
            color = np.array(palette[cls]) / 255.0
            # Rectangle: (x, y), width, height in legend-axes coordinates.
            ax_legend.add_patch(
                Rectangle((x0, 0.1), rect_width, 0.6, facecolor=color)
            )
            # Label centred horizontally on its swatch, just above it.
            ax_legend.text(
                x0 + rect_width / 2,
                0.8,
                texts[cls],
                ha="center",
                va="bottom",
                fontsize=10,
            )

    # Rasterise and copy out the RGB channels (drop alpha).
    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, :3].copy()