Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
def print_examples(model, device, vocab):
    """Print the model's caption for each bundled example image.

    The model is put in eval mode, every image listed below is loaded from
    ``./test_examples/``, preprocessed, captioned, and reported on stdout as
    ``"<file name> PREDICTION: <caption>"``. The model is switched back to
    training mode before returning, even if an image fails to load or
    caption.

    Args:
        model: Captioning model. Must provide ``eval()``, ``train()`` and
            ``caption_image(image, vocab)``; the latter must return an
            iterable of strings, which are joined with spaces for display.
        device: Target forwarded to ``Tensor.to`` for each image batch.
        vocab: Passed through unchanged to ``model.caption_image``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # The printed label is derived from the file that is actually loaded, so
    # the two cannot drift apart (the surfing image used to be reported as
    # "wave.png").
    example_files = (
        "dog.png",
        "dirt_bike.png",
        "surfing.png",
        "horse.png",
        "camera.png",
    )

    model.eval()
    try:
        for file_name in example_files:
            # Context manager releases the file handle as soon as the pixel
            # data has been converted into a tensor.
            with Image.open("./test_examples/" + file_name) as img:
                batch = transform(img.convert("RGB")).unsqueeze(0)
            caption = model.caption_image(batch.to(device), vocab)
            print(file_name + " PREDICTION: " + " ".join(caption))
    finally:
        # NOTE: unconditionally selects training mode (as before), regardless
        # of the mode the model was in when this function was called.
        model.train()
def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to persist; handed to ``torch.save`` as-is.
        filename: Destination forwarded to ``torch.save``. Defaults to a
            fixed Google Drive checkpoint path.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from an in-memory checkpoint.

    Args:
        checkpoint: Mapping with the keys ``"state_dict"``, ``"optimizer"``
            and ``"step"``.
        model: Object whose ``load_state_dict`` receives
            ``checkpoint["state_dict"]``.
        optimizer: Object whose ``load_state_dict`` receives
            ``checkpoint["optimizer"]``.

    Returns:
        The value stored under ``checkpoint["step"]``.
    """
    print("=> Loading checkpoint")

    # Model first, then optimizer; the step is read last.
    model_state = checkpoint["state_dict"]
    model.load_state_dict(model_state)

    optimizer_state = checkpoint["optimizer"]
    optimizer.load_state_dict(optimizer_state)

    return checkpoint["step"]