| import matplotlib.gridspec as gridspec | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
| import torch | |
| import torchvision | |
| from matplotlib import colors | |
def get_part_color(n_parts):
    """Return one RGB colour per part as a float array.

    Colours come from a fixed palette of 32 named matplotlib colours. Each
    name is converted to a float triple in ``[0, 1]`` with
    ``matplotlib.colors.to_rgb``. Part ``i`` always receives palette entry
    ``i % 32``, so any number of parts is supported. The previous version
    raised IndexError above 64 parts; for up to 64 parts it returned the same
    colours as this one.

    Args:
        n_parts: Number of colours to produce. Values <= 0 yield an empty
            array of shape ``(0, 3)``.

    Returns:
        ``np.ndarray`` of dtype float64 and shape ``(n_parts, 3)``.
    """
    # NOTE: 'yellowgreen' appears twice (indices 9 and 15). Both entries are
    # kept so that existing part -> colour assignments do not change.
    palette = ('red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
               'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue',
               'navy', 'orchid', 'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue',
               'darkviolet', 'fuchsia', 'crimson', 'honeydew', 'thistle')
    # Cycle the palette instead of indexing past its end.
    rgb = [colors.to_rgb(palette[i % len(palette)]) for i in range(n_parts)]
    # reshape keeps the (n, 3) layout even when the list is empty.
    return np.array(rgb, dtype=float).reshape(-1, 3)
def denormalize(img):
    """Undo a mean-0.5 / std-0.5 normalisation and clip to ``[0, 1]``.

    Computes ``img * 0.5 + 0.5`` per channel, which inverts
    ``(x - 0.5) / 0.5``. The result is then clamped to the unit interval.

    Args:
        img: Tensor whose channel axis is third from the end. The constants
            are shaped ``(1, 3, 1, 1)`` and broadcast against it, so this is
            presumably an ``(N, 3, H, W)`` batch (TODO confirm at call sites).

    Returns:
        Tensor on ``img``'s device with every value in ``[0, 1]``.
    """
    # Mean and std are both 0.5 for every channel, so one tensor serves as both.
    half = torch.full((1, 3, 1, 1), 0.5, device=img.device)
    restored = img * half + half
    return restored.clamp(min=0, max=1)
def draw_matrix(mat):
    """Render ``mat`` as an annotated heatmap and return it as an RGB image.

    Each cell is annotated with its value formatted as ``'.2f'`` and coloured
    with the ``"YlGnBu"`` colormap via ``seaborn.heatmap``.

    Args:
        mat: Matrix-like data passed straight to ``seaborn.heatmap``.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(H, W, 3)`` with the
        rendered figure's pixels.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        sns.heatmap(mat, annot=True, fmt='.2f', cmap="YlGnBu", ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed tostring_rgb(). It also
        # carries the true pixel shape, which get_width_height() may not report
        # on HiDPI canvases. Drop alpha and copy so the result outlives the figure.
        plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    return plot
def draw_kp_grid(img, kp):
    """Render up to 64 images with keypoints overlaid on an 8x8 grid.

    Pixels are displayed with ``vmin=0, vmax=1``. Keypoints are drawn as
    ``'+'`` markers, one colour per keypoint index from ``get_part_color``.

    Args:
        img: Image batch permuted with ``(0, 2, 3, 1)`` before display, so
            presumably ``(N, C, H, W)`` with values in ``[0, 1]`` (TODO
            confirm). Only the first 64 images are drawn.
        kp: Keypoint tensor indexed as ``kp[i, j, axis]``. Axis 1 is plotted
            on x and axis 0 on y, i.e. presumably ``(row, col)`` pixel
            coordinates of shape ``(N, K, 2)`` (TODO confirm). It needs at
            least as many samples as images drawn.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(H, W, 3)`` with the
        rendered figure's pixels.
    """
    kp_color = get_part_color(kp.shape[1])
    img = img[:64].permute(0, 2, 3, 1).detach().cpu()
    kp = kp.detach().cpu()[:64]

    fig = plt.figure(figsize=(8, 8))
    try:
        gs = gridspec.GridSpec(8, 8)
        gs.update(wspace=0, hspace=0)
        for i, sample in enumerate(img):
            # Attach subplots to this figure explicitly rather than
            # relying on pyplot's implicit "current figure".
            ax = fig.add_subplot(gs[i])
            ax.axis('off')  # also hides tick labels
            ax.imshow(sample, vmin=0, vmax=1)
            ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed tostring_rgb() and
        # reports the real pixel shape (HiDPI-safe). Drop alpha and copy.
        plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)
    return plot
def draw_kp_grid_unnorm(img, kp):
    """Render up to 64 images with keypoints overlaid on an 8x8 grid.

    Same as ``draw_kp_grid`` except that ``imshow`` is called without
    ``vmin``/``vmax``, so matplotlib's default display scaling applies.

    Args:
        img: Image batch permuted with ``(0, 2, 3, 1)`` before display, so
            presumably ``(N, C, H, W)`` (TODO confirm value range expected by
            callers). Only the first 64 images are drawn.
        kp: Keypoint tensor indexed as ``kp[i, j, axis]``. Axis 1 is plotted
            on x and axis 0 on y, presumably ``(N, K, 2)`` row/col pixel
            coordinates (TODO confirm).

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(H, W, 3)`` with the
        rendered figure's pixels.
    """
    kp_color = get_part_color(kp.shape[1])
    img = img[:64].permute(0, 2, 3, 1).detach().cpu()
    kp = kp.detach().cpu()[:64]

    fig = plt.figure(figsize=(8, 8))
    try:
        gs = gridspec.GridSpec(8, 8)
        gs.update(wspace=0, hspace=0)
        for i, sample in enumerate(img):
            # Attach subplots to this figure explicitly rather than
            # relying on pyplot's implicit "current figure".
            ax = fig.add_subplot(gs[i])
            ax.axis('off')  # also hides tick labels
            ax.imshow(sample)
            ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed tostring_rgb() and
        # reports the real pixel shape (HiDPI-safe). Drop alpha and copy.
        plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)
    return plot
def draw_img_grid(img):
    """Tile up to 64 images into a single ``uint8`` image array.

    Images are laid out by ``torchvision.utils.make_grid`` with at most 8 per
    row. The grid is then scaled by 255 and rounded to the nearest integer.

    Args:
        img: Image batch, presumably ``(N, C, H, W)`` with values in
            ``[0, 1]`` (TODO confirm). Only the first 64 images are used.

    Returns:
        ``np.ndarray`` of dtype uint8 with the grid in ``(H, W, C)`` layout.
    """
    img = img[:64].detach().cpu()
    nrow = min(8, img.shape[0])
    grid = torchvision.utils.make_grid(img, nrow=nrow).permute(1, 2, 0)
    # Add 0.5 before the truncating uint8 cast so values round to nearest
    # (e.g. 0.5 -> 128, not 127), matching torchvision's save_image.
    return grid.mul(255).add(0.5).clamp(0, 255).to(torch.uint8).numpy()