import argparse
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # For TensorBoard logging
from tabulate import tabulate  # To tabulate loss and epoch
from tqdm import tqdm

from dataset import get_loader
from model import SeqToSeq
from utils import save_checkpoint, load_checkpoint, print_examples


def main(args):
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=transform,
        batch_size=args.batch_size,  # use the CLI argument (was hard-coded to 64)
        num_workers=2,
    )

    with open('vocab.json') as f:
        vocab = json.load(f)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab['stoi'])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    # For TensorBoard
    writer = SummaryWriter(args.log_dir)
    step = 0

    model_params = {
        'embed_size': embed_size,
        'hidden_size': hidden_size,
        'vocab_size': vocab_size,
        'num_layers': num_layers,
    }

    # Initialize model, loss and optimizer
    model = SeqToSeq(**model_params, device=device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi']["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Always train the encoder's final fc layer; train the rest of the CNN only if train_CNN is True
    for name, param in model.encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    # Load from a saved checkpoint
    if load_model:
        step = load_checkpoint(torch.load(args.save_path), model, optimizer)

    model.train()
    best_loss, best_epoch = float("inf"), 0  # track the best (lowest) end-of-epoch loss
    for epoch in range(num_epochs):
        print_examples(model, device, vocab['itos'])

        for idx, (imgs, captions) in tqdm(
            enumerate(train_loader), total=len(train_loader), leave=False
        ):
            imgs = imgs.to(device)
            captions = captions.to(device)

            # Feed the caption tokens (except the last) as decoder inputs
            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1

            optimizer.zero_grad()
            loss.backward()  # was loss.backward(loss), which scales gradients by the loss value
            optimizer.step()

        # Loss of the last batch in this epoch
        train_loss = loss.item()
        if train_loss < best_loss:
            best_loss = train_loss
            best_epoch = epoch + 1
            if save_model:
                checkpoint = {
                    "model_params": model_params,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                save_checkpoint(checkpoint, args.save_path)

        table = [["Loss:", train_loss],
                 ["Step:", step],
                 ["Epoch:", epoch + 1],
                 ["Best Loss:", best_loss],
                 ["Best Epoch:", best_epoch]]
        print(tabulate(table))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Paths
    parser.add_argument('--root_dir', type=str, default='./flickr30k/flickr30k_images', help='path to images folder')
    parser.add_argument('--csv_file', type=str, default='./flickr30k/results.csv', help='path to captions csv file')
    parser.add_argument('--log_dir', type=str, default='./drive/MyDrive/TensorBoard/', help='path to save tensorboard logs')
    parser.add_argument('--save_path', type=str, default='./drive/MyDrive/checkpoints/Seq2Seq.pt', help='path to save checkpoint')
    # Training and model hyperparameters
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_layers', type=int, default=3, help='number of lstm layers')
    args = parser.parse_args()
    main(args)
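

# A minimal example invocation (a sketch, assuming this script is saved as train.py and the
# Flickr30k images/captions sit at the default paths above; adjust the flags to your setup):
#
#   python train.py --root_dir ./flickr30k/flickr30k_images --csv_file ./flickr30k/results.csv \
#       --batch_size 64 --num_epochs 100 --lr 0.001 --num_layers 3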