Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import json | |
| from neuralnet.model import SeqToSeq | |
| import wget | |
# Release asset holding the pretrained captioning checkpoint.
url = "https://github.com/Koushik0901/Image-Captioning/releases/download/v1.0/flickr30k.pt"

# Fetch the checkpoint only when it is not already on disk. Downloading
# unconditionally on every import wastes bandwidth, and the wget package
# saves repeat downloads under a suffixed name (e.g. "flickr30k (1).pt")
# rather than overwriting, so duplicates would accumulate.
filename = "flickr30k.pt"
if not os.path.exists(filename):
    filename = wget.download(url, out=filename)
def inference(img_path):
    """Generate a caption for the image at ``img_path``.

    Reads the vocabulary from ``./vocab.json``, builds a ``SeqToSeq`` model
    on CPU, restores its weights from ``./flickr30k.pt``, and asks the model
    to caption the image (resized to 299x299 and normalized with mean/std
    0.5 per channel).

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (typically a file path).

    Returns:
        str: The tokens returned by ``model.caption_image`` joined with
        spaces, with the first and last tokens dropped (presumably the
        start/end markers -- confirm against ``SeqToSeq.caption_image``).
    """
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Context manager so the vocabulary file handle is closed promptly.
    with open('./vocab.json', encoding="utf-8") as vocab_file:
        vocabulary = json.load(vocab_file)

    # NOTE(review): vocab_size is hard-coded and must match both the
    # checkpoint and vocab.json -- confirm if either file is regenerated.
    model_params = {
        "embed_size": 256,
        "hidden_size": 512,
        "vocab_size": 7666,
        "num_layers": 3,
        "device": "cpu",
    }
    model = SeqToSeq(**model_params)
    checkpoint = torch.load('./flickr30k.pt', map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Close the image file once it has been converted to a tensor; the
    # leading batch dimension is added by unsqueeze(0).
    with Image.open(img_path) as image:
        img = transform(image.convert("RGB")).unsqueeze(0)

    # Inference only: skip autograd bookkeeping. The previous extra call to
    # model.encoder(img) was removed because its result was never used.
    with torch.no_grad():
        out_captions = model.caption_image(img, vocabulary['itos'], 50)

    return " ".join(out_captions[1:-1])
# Script entry point: caption the bundled sample image and print the result.
if __name__ == '__main__':
    sample_caption = inference('./test_examples/dog.png')
    print(sample_caption)