# Spaces: Build error
import functools

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.autograd import Variable
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

import gan_cls_768
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def clean(txt):
    """Normalise a caption before tokenisation.

    Lower-cases the text, trims surrounding whitespace, then trims leading and
    trailing periods, in that order. Whitespace exposed by removing the periods
    is not trimmed again (``"a bird ."`` becomes ``"a bird "``).
    """
    lowered = txt.lower()
    return lowered.strip().strip('.')
# Fixed token length every caption is padded/truncated to.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenise ``txt`` to exactly ``max_len`` tokens.

    Shorter inputs are padded up to ``max_len`` and longer ones are truncated;
    offset mappings are not requested. Returns whatever ``tokenizer`` returns
    for these options.
    """
    options = {
        'max_length': max_len,
        'padding': 'max_length',
        'truncation': True,
        'return_offsets_mapping': False,
    }
    return tokenizer(txt, **options)
def encode(model_name, model, tokenizer, txt):
    """Embed a caption and return the hidden state at the first token position.

    Args:
        model_name: Unused; kept so existing callers keep working.
        model: Transformer model whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``.
        txt: Raw caption; normalised with ``clean`` and padded/truncated by
            ``tokenize`` before being fed to the model.

    Returns:
        numpy.ndarray: 1-D vector taken from batch 0, token 0 of
        ``last_hidden_state`` (for RoBERTa this is presumably the ``<s>``
        token — confirm if the checkpoint was trained on a different model).
    """
    tokenized = tokenize(tokenizer, clean(txt))
    # Build a fresh dict of batched (1, seq_len) long tensors rather than
    # overwriting the tokenizer's output while iterating over it.
    inputs = {
        key: torch.tensor(value, dtype=torch.long, device=device).unsqueeze(0)
        for key, value in tokenized.items()
    }
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    # Index batch 0 / token 0 explicitly: squeeze() would also drop any other
    # size-1 axis and make the trailing [0] select the wrong element.
    return outputs.last_hidden_state[0, 0].cpu().numpy()
# Caption encoder: a pretrained RoBERTa model plus its tokenizer, loaded once at
# import time and moved to `device`. The config is loaded separately so that
# output_hidden_states can be switched on.
model_name = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(
    model_name,
    config=AutoConfig.from_pretrained(model_name, output_hidden_states=True),
).to(device)
@functools.lru_cache(maxsize=1)
def _load_generator():
    """Build the GAN generator, load ./gen_125.pth into it, and cache it.

    Cached so the checkpoint is read from disk once per process instead of on
    every ``generate_image`` call.
    """
    # Wrapped in DataParallel before loading, presumably because the checkpoint
    # was saved from a DataParallel model ("module."-prefixed keys) — TODO confirm.
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    # map_location='cpu' lets the checkpoint load on CPU-only hosts;
    # load_state_dict copies the values into the module's existing parameters.
    generator.load_state_dict(
        torch.load("./gen_125.pth", map_location=torch.device('cpu'))
    )
    generator.eval()
    return generator


def generate_image(text, n):
    """Generate images conditioned on a caption.

    Args:
        text: Caption to condition on; embedded with ``encode``.
        n: Number of generator forward passes, each with fresh noise.

    Returns:
        list[PIL.Image.Image]: Every image produced by the generator, in order
        (an empty list when ``n`` is 0).
    """
    embed = encode(model_name, model, tokenizer, text)
    generator = _load_generator()
    # Add a batch axis: (hidden,) -> (1, hidden).
    right_embed = torch.as_tensor(embed, dtype=torch.float32).unsqueeze(0).to(device)
    images = []
    # Inference only: avoid building an autograd graph for every sample.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Drawn with the CPU RNG and then moved, so a torch.manual_seed run
            # yields the same noise on CPU and GPU hosts.
            noise = torch.randn(1, 100).to(device).view(1, 100, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # Map [-1, 1] to [0, 255]; assumes outputs are roughly in [-1, 1]
                # (e.g. tanh) — TODO confirm. The clamp keeps the uint8 cast from
                # wrapping if a value falls outside that range.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # CHW -> HWC, the layout PIL expects.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
if __name__ == '__main__':
    # Demo: generate n images for a fixed caption and show them in a grid.
    n = 10
    ncols = 2
    imgs = generate_image('Red images', n)
    # Enough rows to hold every image (5 rows for n=10); at least one row so
    # plt.subplots always receives a valid shape.
    nrows = max(1, (n + ncols - 1) // ncols)
    # squeeze=False keeps `axes` 2-D even for a single row, so ravel() is safe.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for i, axis in enumerate(axes.ravel()):
        if i < len(imgs):
            axis.imshow(imgs[i])
        # Hide ticks and frame on every cell, including an unused trailing one.
        axis.axis('off')
    fig.tight_layout()
    plt.show()
    # For interactive use, read a caption with input() inside a `while True:`
    # loop and call generate_image(caption, n) for each one.