File size: 3,487 Bytes
91fc62a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision
from matplotlib import colors
def get_part_color(n_parts):
colormap = ('red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid',
'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson',
'honeydew', 'thistle',
'red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid',
'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson',
'honeydew', 'thistle')[:n_parts]
part_color = []
for i in range(n_parts):
part_color.append(colors.to_rgb(colormap[i]))
part_color = np.array(part_color)
return part_color
def denormalize(img):
mean = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1)
std = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1)
img = img * std + mean
img = torch.clamp(img, min=0, max=1)
return img
def draw_matrix(mat):
fig = plt.figure()
sns.heatmap(mat, annot=True, fmt='.2f', cmap="YlGnBu")
ncols, nrows = fig.canvas.get_width_height()
fig.canvas.draw()
plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
plt.close(fig)
return plot
def draw_kp_grid(img, kp):
kp_color = get_part_color(kp.shape[1])
img = img[:64].permute(0, 2, 3, 1).detach().cpu()
kp = kp.detach().cpu()[:64]
fig = plt.figure(figsize=(8, 8))
gs = gridspec.GridSpec(8, 8)
gs.update(wspace=0, hspace=0)
for i, sample in enumerate(img):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.imshow(sample, vmin=0, vmax=1)
ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
ncols, nrows = fig.canvas.get_width_height()
fig.canvas.draw()
plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
plt.close(fig)
return plot
def draw_kp_grid_unnorm(img, kp):
kp_color = get_part_color(kp.shape[1])
img = img[:64].permute(0, 2, 3, 1).detach().cpu()
kp = kp.detach().cpu()[:64]
fig = plt.figure(figsize=(8, 8))
gs = gridspec.GridSpec(8, 8)
gs.update(wspace=0, hspace=0)
for i, sample in enumerate(img):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.imshow(sample)
ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
ncols, nrows = fig.canvas.get_width_height()
fig.canvas.draw()
plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
plt.close(fig)
return plot
def draw_img_grid(img):
img = img[:64].detach().cpu()
nrow = min(8, img.shape[0])
img = torchvision.utils.make_grid(img[:64], nrow=nrow).permute(1, 2, 0)
return torch.clamp(img * 255, min=0, max=255).numpy().astype(np.uint8)
|