import torch
import torch.nn as nn
import torchvision.models as models


class InceptionEncoder(nn.Module):
    """CNN encoder: Inception v3 with its classifier replaced by a linear projection to embed_size."""

    def __init__(self, embed_size, train_CNN=False):
        super(InceptionEncoder, self).__init__()
        self.train_CNN = train_CNN
        # Load ImageNet weights; aux_logits=False drops the auxiliary classifier so forward()
        # returns a single tensor. Inception v3 expects 299x299 RGB inputs. Recent torchvision
        # releases may require the weights=... API (and aux_logits=True) instead of pretrained=True.
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(0.5)
        # Freeze the pretrained backbone unless fine-tuning is requested;
        # the freshly initialised fc projection always stays trainable.
        for name, param in self.inception.named_parameters():
            param.requires_grad = self.train_CNN or name.startswith("fc.")

    def forward(self, images):
        features = self.inception(images)  # (batch, embed_size)
        norm_features = self.bn(features)
        return self.dropout(self.relu(norm_features))


class LstmDecoder(nn.Module):
    """LSTM decoder: predicts caption logits from the image embedding and the shifted caption."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
        super(LstmDecoder, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.device = device
        self.embed = nn.Embedding(vocab_size, embed_size)
        # batch_first stays at the default, so sequences are laid out as (seq_len, batch, embed_size).
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=self.num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, encoder_out, captions):
        # Zero initial hidden and cell states, one per layer and batch element.
        h0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size, device=self.device)
        c0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size, device=self.device)
        embeddings = self.dropout(self.embed(captions))
        # Prepend the image embedding as the first step of the input sequence.
        embeddings = torch.cat((encoder_out.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings, (h0, c0))
        outputs = self.linear(hiddens)  # (seq_len + 1, batch, vocab_size)
        return outputs


class SeqToSeq(nn.Module):
    """Encoder-decoder model for image captioning."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
        super(SeqToSeq, self).__init__()
        self.encoder = InceptionEncoder(embed_size)
        self.decoder = LstmDecoder(embed_size, hidden_size, vocab_size, num_layers, device)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        # Greedy decoding: feed the image embedding first, then the embedding of the
        # previously predicted token, until "<EOS>" or max_length tokens are produced.
        # `vocabulary` maps string indices to tokens; call model.eval() beforehand so
        # BatchNorm and Dropout behave correctly for a single image.
        result_caption = []
        with torch.no_grad():
            x = self.encoder(image).unsqueeze(0)  # (1, 1, embed_size): one step, batch of one
            states = None
            for _ in range(max_length):
                hiddens, states = self.decoder.lstm(x, states)
                output = self.decoder.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                x = self.decoder.embed(predicted).unsqueeze(0)
                if vocabulary[str(predicted.item())] == "<EOS>":
                    break
        return [vocabulary[str(idx)] for idx in result_caption]
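

# Minimal usage sketch (assumption: the hyper-parameters and the tiny index-to-token map below
# are illustrative only, not taken from this repository). It runs a training-style forward pass
# and greedy caption generation on random data, mainly to document the expected tensor shapes.
if __name__ == "__main__":
    embed_size, hidden_size, num_layers = 256, 512, 1
    itos = {"0": "<SOS>", "1": "a", "2": "dog", "3": "<EOS>"}  # hypothetical vocabulary
    vocab_size = len(itos)

    model = SeqToSeq(embed_size, hidden_size, vocab_size, num_layers, device="cpu")
    model.eval()  # inference behaviour for BatchNorm/Dropout; needed for single-image captioning

    images = torch.randn(2, 3, 299, 299)              # Inception v3 expects 299x299 inputs
    captions = torch.randint(0, vocab_size, (10, 2))  # (seq_len, batch)
    logits = model(images, captions)
    print(logits.shape)                               # torch.Size([11, 2, 4]): image step + 10 tokens

    print(model.caption_image(torch.randn(1, 3, 299, 299), itos))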