File size: 2,081 Bytes
593b176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle

def plot_qualitative(image, sim, palette, texts, alpha=0.6, legend_height=0.1, fontsize=10):
    """
    Overlay a colour-coded segmentation mask on an image and render it, with a
    class legend strip underneath, into an RGB NumPy array.

    image: HxWx3 uint8 image
    sim: HxW segmentation mask with integer class IDs
    palette: list of [R,G,B] colors (0-255), indexed by class ID
    texts: list of class names, indexed by class ID
    alpha: transparency for overlay
    legend_height: height of the legend strip as a fraction of the image
        height; the rendered figure is roughly H * (1 + legend_height) pixels
        tall (truncated to whole pixels) and W pixels wide
    fontsize: font size, in points, of the legend labels

    Returns a uint8 array of shape (rows, W, 3): the blended image on top and
    the legend strip (one swatch + label per class present in `sim`) below.

    The figure is rendered off-screen on an Agg canvas and is never registered
    with pyplot, so the result does not depend on the active pyplot backend and
    no figure is left open if rendering raises.
    """
    dpi = 100
    classes = np.unique(sim)

    # Colour every pixel of the mask with its class colour.
    qualitative_plot = np.zeros((sim.shape[0], sim.shape[1], 3), dtype=np.uint8)
    for cls in classes:
        qualitative_plot[sim == cls] = np.array(palette[cls])

    # Normalize images for alpha blending
    img_float = image.astype(np.float32) / 255.0
    overlay_float = qualitative_plot.astype(np.float32) / 255.0

    # Figure sized so one image pixel maps to one output pixel, plus a legend
    # strip legend_height * H pixels tall below it.
    fig_height = img_float.shape[0] / dpi
    fig_width = img_float.shape[1] / dpi
    fig = Figure(figsize=(fig_width, fig_height + legend_height * fig_height), dpi=dpi)
    canvas = FigureCanvasAgg(fig)

    # Split the figure height in the ratio 1 : legend_height so the image axes
    # is exactly H pixels tall. Giving it (1 - legend_height) of a figure that
    # is (1 + legend_height) * H tall would leave only (1 - legend_height**2) * H
    # pixels, and imshow (equal aspect) would shrink the image to fit.
    image_frac = 1.0 / (1.0 + legend_height)
    legend_frac = 1.0 - image_frac

    # Main image axis: the photo, then the mask blended on top.
    ax_img = fig.add_axes([0, legend_frac, 1, image_frac])
    ax_img.imshow(img_float)
    ax_img.imshow(overlay_float, alpha=alpha)
    ax_img.axis("off")

    # Legend axis along the bottom of the figure.
    ax_legend = fig.add_axes([0, 0, 1, legend_frac])
    ax_legend.axis("off")

    # One swatch per class present in the mask, laid out left to right; each
    # slot is 1/num_classes wide and the swatch fills 80% of it.
    num_classes = len(classes)
    rect_width = 0.8 / max(num_classes, 1)
    for idx, cls in enumerate(classes):
        left = idx / num_classes
        color = np.array(palette[cls]) / 255.0
        # Rectangle: (x, y), width, height in legend-axes coordinates
        ax_legend.add_patch(Rectangle((left, 0.1), rect_width, 0.6, facecolor=color))
        # Label centred horizontally above the swatch. Axes.text is not clipped
        # by default, so a large font in a short strip can extend over the image.
        ax_legend.text(left + rect_width / 2, 0.8, texts[cls],
                       ha='center', va='bottom', fontsize=fontsize)

    # Rasterize and extract as NumPy array, dropping the alpha channel.
    canvas.draw()
    buf = np.asarray(canvas.buffer_rgba())
    return buf[:, :, :3].copy()