Spaces:
Runtime error
Runtime error
| import torch | |
| from torchvision.utils import make_grid | |
| from torchvision import transforms | |
| import torchvision.transforms.functional as TF | |
| from torch import nn, optim | |
| from torch.optim.lr_scheduler import CosineAnnealingLR | |
| from torch.utils.data import DataLoader, Dataset | |
| from huggingface_hub import hf_hub_download | |
| import requests | |
| import gradio as gr | |
| import numpy as np | |
class Upsample(nn.Module):
    """Decoder stage: transposed conv -> InstanceNorm -> ReLU, with optional dropout.

    With the default ``kernel_size=4, stride=2, padding=1`` the transposed
    convolution doubles the spatial size of its input.

    The attribute names (``dropout``, ``block``, ``dropout_layer``) are part of
    the serialized layout of this module; keep them stable so saved models and
    state dicts continue to load.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed conv.
        kernel_size: forwarded to ``nn.ConvTranspose2d``.
        stride: forwarded to ``nn.ConvTranspose2d``.
        padding: forwarded to ``nn.ConvTranspose2d``.
        dropout: when true, ``Dropout2d(0.5)`` is applied after the block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dropout=True):
        super().__init__()
        self.dropout = dropout
        self.block = nn.Sequential(
            # `bias` expects a bool. The learnable bias is kept enabled so the
            # parameter layout matches already-trained weights.
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=True),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Always constructed (even when unused) so the module tree is the same
        # regardless of the `dropout` flag.
        self.dropout_layer = nn.Dropout2d(0.5)

    def forward(self, x, shortcut=None):
        """Upsample ``x`` and, if given, concatenate ``shortcut`` on the channel axis.

        Args:
            x: input feature map.
            shortcut: optional skip-connection tensor; it is appended after the
                upsampled features along ``dim=1``.

        Returns:
            The upsampled tensor, with the shortcut channels appended when a
            shortcut is supplied.
        """
        x = self.block(x)
        if self.dropout:
            x = self.dropout_layer(x)
        if shortcut is not None:
            x = torch.cat([x, shortcut], dim=1)
        return x
class Downsample(nn.Module):
    """Encoder stage: strided conv -> optional InstanceNorm -> LeakyReLU(0.2).

    With the default ``kernel_size=4, stride=2, padding=1`` the convolution
    halves the spatial size of an even-sized input.

    The attribute names (``conv``, ``norm``, ``relu``, ``apply_norm``) are part
    of the serialized layout of this module; keep them stable so saved models
    and state dicts continue to load.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: forwarded to ``nn.Conv2d``.
        stride: forwarded to ``nn.Conv2d``.
        padding: forwarded to ``nn.Conv2d``.
        apply_instancenorm: when false, the normalisation step is skipped.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, apply_instancenorm=True):
        super().__init__()
        # `bias` expects a bool. The learnable bias is kept enabled so the
        # parameter layout matches already-trained weights.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
        # Always constructed so the module tree does not depend on the flag.
        self.norm = nn.InstanceNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.apply_norm = apply_instancenorm

    def forward(self, x):
        """Apply conv, optional instance norm, then LeakyReLU, and return the result."""
        x = self.conv(x)
        if self.apply_norm:
            x = self.norm(x)
        x = self.relu(x)
        return x
class CycleGAN_Unet_Generator(nn.Module):
    """U-Net style generator built from ``Downsample`` / ``Upsample`` stages.

    Seven encoder stages feed six decoder stages; every decoder stage receives
    the matching encoder output as a skip connection, and a final transposed
    convolution followed by ``tanh`` produces a 3-channel output.

    Args:
        filter: base channel width; the deeper stages use multiples of it.
    """

    def __init__(self, filter=64):
        super().__init__()
        base = filter

        # Encoder. The shape comments assume a 256x256 input.
        encoder = [
            Downsample(3, base, kernel_size=4, apply_instancenorm=False),  # (b, base, 128, 128)
        ]
        encoder_widths = [
            (base, base * 2),      # (b, base * 2, 64, 64)
            (base * 2, base * 4),  # (b, base * 4, 32, 32)
            (base * 4, base * 8),  # (b, base * 8, 16, 16)
            (base * 8, base * 8),  # (b, base * 8, 8, 8)
            (base * 8, base * 8),  # (b, base * 8, 4, 4)
            (base * 8, base * 8),  # (b, base * 8, 2, 2)
        ]
        encoder.extend(Downsample(c_in, c_out) for c_in, c_out in encoder_widths)
        self.downsamples = nn.ModuleList(encoder)

        # Decoder. From the second stage on, the input width is doubled because
        # each Upsample output is concatenated with its skip connection.
        decoder_spec = [
            (base * 8, base * 8, True),
            (base * 16, base * 8, True),
            (base * 16, base * 8, True),
            (base * 16, base * 4, False),
            (base * 8, base * 2, False),
            (base * 4, base, False),
        ]
        self.upsamples = nn.ModuleList(
            [Upsample(c_in, c_out, dropout=use_dropout) for c_in, c_out, use_dropout in decoder_spec]
        )

        # Output head: one more 2x upsampling down to 3 channels, squashed by tanh.
        self.last = nn.Sequential(
            nn.ConvTranspose2d(base * 2, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and return the tanh output."""
        features = []
        for down in self.downsamples:
            x = down(x)
            features.append(x)

        # The deepest feature map is the decoder input, not a skip connection,
        # so walk the remaining encoder outputs from deepest to shallowest.
        for up, skip in zip(self.upsamples, features[-2::-1]):
            x = up(x, skip)

        return self.last(x)
class ImageTransform:
    """Callable holding the preprocessing pipelines for the two phases.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a tensor and
    normalise with mean 0.5 / std 0.5; the ``'train'`` pipeline additionally
    applies random horizontal and vertical flips before the tensor conversion.

    Args:
        img_size: side length (in pixels) images are resized to.
    """

    def __init__(self, img_size=256):
        target_size = (img_size, img_size)

        train_steps = [
            transforms.Resize(target_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
        test_steps = [
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]

        self.transform = {
            'train': transforms.Compose(train_steps),
            'test': transforms.Compose(test_steps),
        }

    def __call__(self, img, phase='train'):
        """Apply the pipeline registered for ``phase`` (``'train'`` or ``'test'``) to ``img``."""
        pipeline = self.transform[phase]
        return pipeline(img)
# Fetch the pretrained generator file from the Hugging Face Hub.
path = hf_hub_download(repo_id='huggan/NeonGAN', filename='model.bin')

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. The loaded object is called directly as a model below, which
# presumably means the file stores a whole pickled module whose classes must be
# defined before this line runs -- confirm against the checkpoint.
model_gen_n = torch.load(path, map_location='cpu')

# Preprocessing used at inference time (the 'test' pipeline is selected by the caller).
transform = ImageTransform(img_size=256)

# Gradio I/O components: one PIL image in, one PIL image out.
inputs = [gr.inputs.Image(type="pil", label="Original Image")]
outputs = [gr.outputs.Image(type="pil", label="Neon Image")]
def get_output_image(img):
    """Run the generator on one image and return the result as a PIL image.

    Args:
        img: input image as delivered by the Gradio input component
            (configured with ``type="pil"``).

    Returns:
        A PIL image built from the generator output after undoing the
        0.5 / 0.5 normalisation and converting to 8-bit channels-last layout.
    """
    img = transform(img, phase='test')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        gen_img = model_gen_n(img.unsqueeze(0))[0]

    # Reverse Normalization, then scale to the 8-bit range. The clamp guards
    # against wrap-around in the uint8 cast if a value falls outside [0, 255].
    gen_img = ((gen_img * 0.5 + 0.5) * 255).clamp(0, 255)
    gen_img = gen_img.detach().cpu().numpy().astype(np.uint8)

    # Channels-first -> channels-last, the layout expected for image arrays.
    gen_img = np.transpose(gen_img, [1, 2, 0])

    # `Image` (PIL) is not imported in this file; use the torchvision helper
    # that is already in scope to build the PIL image.
    return TF.to_pil_image(gen_img)
# Wire the inference function to the declared input/output components and
# start serving the demo.
demo = gr.Interface(
    fn=get_output_image,
    inputs=inputs,
    outputs=outputs,
    theme="huggingface",
)
demo.launch(enable_queue=True)