| import importlib | |
| import PIL | |
| import pytorch_lightning as pl | |
| import torch.utils.data | |
| import wandb | |
| from typing import Union | |
| from torchvision import transforms | |
| from utils_.loss import VGGPerceptualLoss | |
| from utils_.visualization import * | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
class Model(pl.LightningModule):
    """Lightning module pairing a keypoint encoder with a keypoint-rasterizing decoder.

    The concrete encoder/decoder classes are loaded dynamically from the
    ``models.<name>`` modules named by the ``encoder`` / ``decoder``
    hyperparameters.
    """

    def __init__(self, **kwargs):
        """Build the model from keyword hyperparameters.

        ``save_hyperparameters()`` stores the constructor arguments in
        ``self.hparams``; this constructor reads ``encoder``, ``decoder`` and
        ``batch_size`` from there.
        """
        super().__init__()
        self.save_hyperparameters()
        # e.g. hparams.encoder == "foo" -> models.foo.Encoder; both sub-modules
        # receive the full hyperparameter namespace.
        self.encoder = importlib.import_module('models.' + self.hparams.encoder).Encoder(self.hparams)
        self.decoder = importlib.import_module('models.' + self.hparams.decoder).Decoder(self.hparams)
        self.batch_size = self.hparams.batch_size
        # Presumably consumed by training code outside this view.
        self.vgg_loss = VGGPerceptualLoss()
        # PIL image -> float tensor in [0, 1] -> (t - 0.5) / 0.5, i.e. [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])

    @torch.no_grad()  # inference/visualisation only; no autograd graph needed
    def forward(self, x: PIL.Image.Image) -> "np.ndarray":
        """Render predicted keypoints and the decoded edge map over an image.

        :param x: a PIL image. It is converted with ``self.transform`` and fed
            to the encoder as a batch of one.
        :return: a ``uint8`` RGB array of shape ``(H, W, 3)``, the same pixel
            size as ``x``. It shows the input at half brightness, plus the
            decoder's edge map (normalized so its maximum is 1), clamped to
            [0, 1], with the keypoints drawn as scatter markers.
        """
        import numpy as np  # explicit, rather than relying on a star import

        w, h = x.size
        # NOTE(review): the channel count follows the PIL mode (e.g. RGBA gives
        # 4 channels); confirm callers pass images in the mode the encoder expects.
        img = self.transform(x).unsqueeze(0).to(self.device)
        kp = self.encoder({'img': img})['keypoints']

        edge_map = self.decoder.rasterize(kp, output_size=64)
        # Per-sample max normalization; the epsilon guards an all-zero map.
        peak = edge_map.flatten(1).amax(dim=1).view(-1, 1, 1, 1)
        edge_map = edge_map / (1e-8 + peak)
        # Upsample before replicating to 3 channels. Interpolation is per-channel,
        # so this matches replicate-then-upsample but does a third of the work.
        edge_map = F.interpolate(edge_map, size=(h, w), mode='bilinear', align_corners=False)
        edge_map = torch.cat([edge_map] * 3, dim=1)

        # Undo the [-1, 1] normalization, dim the image by half, add the edges.
        overlay = torch.clamp(edge_map + (img * 0.5 + 0.5) * 0.5, min=0, max=1)
        overlay_pil = transforms.ToPILImage()(overlay[0].detach().cpu())

        # Column 0 is scaled by the height (row / y) and column 1 by the width
        # (col / x). The *0.5 + 0.5 implies keypoints in [-1, 1] with -1/+1 on
        # the image borders -- assumption, confirm against the encoder.
        kp_unit = kp[0].detach().cpu() * 0.5 + 0.5
        ys = (kp_unit[:, 0] * h).numpy()
        xs = (kp_unit[:, 1] * w).numpy()

        # A 1-inch-wide figure at dpi=w is w pixels wide, and h/w inches tall is
        # h pixels, so the canvas matches the input image size.
        fig = plt.figure(figsize=(1, h / w), dpi=w)
        try:
            # Full-bleed axes. Previously tight_layout() ran before any axes
            # existed (a no-op), and the default subplot margins left a white
            # border around a shrunken image.
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
            # The extent places the image edges at 0..w / 0..h, consistent with
            # the [0, 1] -> [0, w] keypoint scaling above.
            ax.imshow(overlay_pil, extent=(0, w, h, 0), aspect='auto')
            ax.scatter(xs, ys, s=min(w / h, min(1, h / w)), marker='o')
            # Pin the limits so keypoints on or past the border cannot trigger
            # autoscaling (which would add margins and shrink the image).
            ax.set_xlim(0, w)
            ax.set_ylim(h, 0)
            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was deprecated in
            # Matplotlib 3.8 and removed in 3.10. Copy so the result does not
            # alias the canvas buffer.
            plot = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(fig)
        return plot