Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import transformers | |
| from torch.utils.data import Dataset | |
| from transformers import ViTFeatureExtractor | |
| from io import BytesIO | |
| from base64 import b64decode | |
| from PIL import Image | |
| from accelerate import Accelerator | |
| import base64 | |
| from config import get_config | |
| from pathlib import Path | |
| from tokenizers import Tokenizer | |
| from tokenizers.models import WordLevel | |
| from tokenizers.trainers import WordLevelTrainer | |
| from tokenizers.pre_tokenizers import Whitespace | |
| from model import build_transformer | |
| import torch.nn.functional as F | |
| from transformers import GPT2TokenizerFast | |
| import streamlit as st | |
def process(model, image, tokenizer, device, max_len=196):
    """Generate caption text for an image.

    Args:
        model: Captioning model. It is switched to eval mode and passed to
            ``greedy_decode``.
        image: Image source, forwarded unchanged to ``get_image``.
        tokenizer: Tokenizer used by ``greedy_decode`` for the start/end
            token ids and here to turn the generated ids back into text.
        device: Device the preprocessed image tensor is moved to.
        max_len: Maximum length of the generated id sequence, start token
            included. Defaults to 196, the value that used to be hard-coded.

    Returns:
        Whatever ``tokenizer.decode`` returns for the generated ids.
    """
    pixel_values = get_image(image)
    model.eval()
    with torch.no_grad():
        # Add a leading batch axis of size 1 before handing it to the model.
        encoder_input = pixel_values.unsqueeze(0).to(device)
        token_ids = greedy_decode(
            model, encoder_input, None, tokenizer, max_len, device
        )
    return tokenizer.decode(token_ids.detach().cpu().numpy())
# get image prompt
def get_image(image):
    """Load an image and turn it into the tensor fed to the encoder.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object).

    Returns:
        The ``pixel_values`` tensor produced by the ViT feature extractor,
        with its leading size-1 batch axis removed.
    """
    # import model
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)

    image = Image.open(image)
    # Normalise every input (grayscale, RGBA, palette, ...) to RGB first.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    enc_input = feature_extractor(image, return_tensors='pt')

    # One squeeze removes the batch axis. Repeating squeeze(0) afterwards
    # would only act on a leading axis of size 1, which is not expected for
    # an RGB image (assumes pixel_values is (1, 3, H, W) -- TODO confirm).
    return enc_input['pixel_values'].squeeze(0)
# get tokenizer
def get_or_build_tokenizer(config):
    """Return the pretrained GPT-2 fast tokenizer with custom special tokens.

    Args:
        config: Accepted for interface compatibility; not consulted here.

    Returns:
        A ``GPT2TokenizerFast`` loaded from ``openai-community/gpt2`` with
        the unknown, start, end and padding tokens overridden.
    """
    # Keyword order is kept as-is: unk, bos, eos, pad.
    special_tokens = {
        'unk_token': '[UNK]',
        'bos_token': '[SOS]',
        'eos_token': '[EOS]',
        'pad_token': '[PAD]',
    }
    return GPT2TokenizerFast.from_pretrained(
        "openai-community/gpt2", **special_tokens
    )
def causal_mask(size):
    """Build a boolean lower-triangular attention mask.

    Args:
        size: Side length of the square mask.

    Returns:
        A ``(1, size, size)`` bool tensor that is True on and below the main
        diagonal and False strictly above it.
    """
    # Ones strictly above the diagonal mark the positions to hide.
    strictly_upper = torch.ones((1, size, size)).triu(diagonal=1).type(torch.int)
    # Invert: everything that is not hidden is allowed.
    return strictly_upper.eq(0)
# get model
def get_model(config, vocab_tgt_len):
    """Construct the transformer from configuration values.

    Args:
        config: Mapping providing ``'seq_len'`` and ``'d_model'``.
        vocab_tgt_len: Target vocabulary size passed to ``build_transformer``.

    Returns:
        The object returned by ``build_transformer``.
    """
    seq_len = config['seq_len']
    d_model = config['d_model']
    return build_transformer(vocab_tgt_len, seq_len, d_model=d_model)
# greedy decode
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    """Generate token ids one at a time, always taking the best-scoring token.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``project``.
        source: Encoder input, passed to ``model.encode``.
        source_mask: Passed through to ``model.decode`` at every step.
        tokenizer_tgt: Tokenizer providing ``convert_tokens_to_ids`` for the
            ``'[SOS]'`` and ``'[EOS]'`` tokens.
        max_len: Maximum length of the returned sequence, start token
            included.
        device: Device on which the decoder input and mask are placed.

    Returns:
        A 1-D tensor of token ids beginning with the start token. It ends
        with the end token when one was generated before ``max_len``.
    """
    sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
    eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')

    # Precompute the encoder output and reuse it for every step.
    # NOTE(review): encode is called with None rather than source_mask;
    # kept as-is -- confirm this is intended.
    encoder_output = model.encode(source, None)

    # Start the decoder input with just the start token, shape (1, 1).
    decoder_input = torch.full((1, 1), sos_idx, dtype=torch.long, device=device)

    # "<" rather than an "== max_len" check, so a max_len below 1 cannot
    # skip the stop condition and loop without bound.
    while decoder_input.size(1) < max_len:
        # Mask so each position only attends to itself and earlier ones.
        decoder_mask = causal_mask(decoder_input.size(1)).long().to(device)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Score the vocabulary for the last position only.
        logits = model.project(out[:, -1])

        # Softmax is monotonic, so the argmax of the raw logits selects the
        # same token without the extra normalisation pass.
        next_word = torch.argmax(logits, dim=1)

        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)

        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)
def image_base64(image):
    """Return the Base64 text encoding of an image's raw bytes.

    Args:
        image: Either a bytes-like object (``bytes``, ``bytearray``,
            ``memoryview``) or a binary file-like object exposing ``read()``.

    Returns:
        The Base64 encoding as a ``str``.
    """
    # Accept raw bytes directly; otherwise read from the file-like object.
    # The parameter itself is used here; the previous body referenced an
    # undefined name and raised NameError.
    if isinstance(image, (bytes, bytearray, memoryview)):
        raw = image
    else:
        raw = image.read()
    return base64.b64encode(raw).decode()
def main():
    """Streamlit entry point: upload an image and display its caption."""
    st.title("Image Captioning with Vision Transformer")

    image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
    if image is None:
        # Nothing uploaded yet; there is nothing more to render.
        return

    st.image(image, use_column_width=True)

    # Everything written inside this placeholder goes to the same slot.
    with st.empty():
        st.write("Processing the image... Please wait.")

        accelerator = Accelerator()
        device = accelerator.device

        config = get_config()
        tokenizer = get_or_build_tokenizer(config)

        # Build the model, let accelerate prepare it, then restore the
        # saved state from the 'models/' directory.
        model = accelerator.prepare(get_model(config, len(tokenizer)))
        accelerator.load_state('models/')

        caption = process(model, image, tokenizer, device)
        st.write(caption)


if __name__ == "__main__":
    main()