Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,081 Bytes
593b176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
def _draw_legend(ax, class_ids, palette, texts):
    """Draw one colour swatch and a centred class name per ID, spaced evenly across ``ax``.

    Positions are given in axes-fraction coordinates (``ax.transAxes``) so the
    layout does not depend on the axes' data limits or autoscaling.
    """
    num_classes = len(class_ids)
    if num_classes == 0:
        return
    swatch_width = 1 / num_classes * 0.8  # 80% of each slot; the rest is a gap
    for idx, cls in enumerate(class_ids):
        cls = int(cls)  # plain int so list indexing works for any integer/bool dtype
        x0 = idx / num_classes
        color = np.asarray(palette[cls]) / 255.0
        # Rectangle: (x, y), width, height -- all in axes fractions.
        ax.add_patch(
            Rectangle((x0, 0.1), swatch_width, 0.6, facecolor=color, transform=ax.transAxes)
        )
        # Label centred horizontally above its swatch.
        ax.text(
            x0 + swatch_width / 2,
            0.8,
            texts[cls],
            ha="center",
            va="bottom",
            fontsize=10,
            transform=ax.transAxes,
        )


def plot_qualitative(image, sim, palette, texts, alpha=0.6, legend_height=0.1):
    """Render a colour-coded segmentation mask over an image, with a legend strip below.

    Args:
        image: HxWx3 uint8 RGB image.
        sim: HxW segmentation mask with integer class IDs; each ID indexes
            ``palette`` and ``texts``.
        palette: sequence of [R, G, B] colours (0-255), indexed by class ID.
        texts: sequence of class names, indexed by class ID.
        alpha: opacity of the colour overlay drawn on top of the image.
        legend_height: height of the legend strip as a fraction of the image
            height. The strip is added below the image, so the image itself is
            rendered at its native H x W pixel size.

    Returns:
        uint8 RGB array of the rendered figure, approximately
        (H * (1 + legend_height)) x W x 3.
    """
    # Cast to an integer index dtype so bool masks are treated as IDs 0/1
    # rather than as a boolean selector in the palette lookup below.
    sim_ids = np.asarray(sim).astype(np.intp, copy=False)
    class_ids = np.unique(sim_ids)

    # One vectorised gather instead of one full-image boolean mask per class.
    qualitative_plot = np.asarray(palette)[sim_ids].astype(np.uint8)

    # Normalize images to [0, 1] floats for alpha blending.
    img_float = np.asarray(image).astype(np.float32) / 255.0
    overlay_float = qualitative_plot.astype(np.float32) / 255.0

    # At dpi=100 one inch equals 100 pixels, so the figure is the image size
    # plus an extra legend strip of legend_height * image height.
    dpi = 100
    fig_height = img_float.shape[0] / dpi
    fig_width = img_float.shape[1] / dpi
    # Axes rectangles are in fractions of the *total* figure height, so the
    # strip's share is legend_height / (1 + legend_height). Using legend_height
    # directly would shrink the image axes and force imshow to resample.
    legend_frac = legend_height / (1 + legend_height)

    fig = plt.figure(figsize=(fig_width, fig_height + legend_height * fig_height), dpi=dpi)
    try:
        # Main image axis: image plus semi-transparent mask overlay. Nearest
        # interpolation keeps class boundaries crisp.
        ax_img = fig.add_axes([0, legend_frac, 1, 1 - legend_frac])
        ax_img.imshow(img_float, interpolation="nearest")
        ax_img.imshow(overlay_float, alpha=alpha, interpolation="nearest")
        ax_img.axis("off")

        # Legend axis occupying the bottom strip.
        ax_legend = fig.add_axes([0, 0, 1, legend_frac])
        ax_legend.axis("off")
        _draw_legend(ax_legend, class_ids, palette, texts)

        # Rasterise and copy out RGB (drop alpha) before the figure is closed.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
    finally:
        # Always release the figure from pyplot's global registry, even on error.
        plt.close(fig)
|