import torch
import clip
import PIL.Image
from PIL import Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
from model import generate2, ClipCaptionModel
from engine import inference
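# Load the pretrained ViT+GPT2 captioning model along with its image processor and tokenizer,
# then load the weights saved in model_trained.pth on top of it.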
model_trained = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model_trained.load_state_dict(torch.load('model_trained.pth', map_location=torch.device('cpu')), strict=False)
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# Keep the ViT+GPT2 tokenizer under its own name so the GPT2 tokenizer loaded below does not shadow it.
vit_tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
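# Caption an image with the ViT+GPT2 model, using either greedy decoding or top-k sampling.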
def show_n_generate(img, model, greedy=True):
    image = Image.open(img)
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    if greedy:
        generated_ids = model.generate(pixel_values, max_new_tokens=30)
    else:
        generated_ids = model.generate(
            pixel_values,
            do_sample=True,
            max_new_tokens=30,
            top_k=5)
    generated_text = vit_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
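# Load CLIP (ViT-B/32), the GPT2 tokenizer, and two ClipCaptionModel prefix-captioning checkpoints.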
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prefix_length = 10
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')), strict=False)
model = model.eval()
coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')), strict=False)
coco_model = coco_model.eval()
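# Streamlit UI: upload an image, pick one of the captioning models, and show the predicted caption.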
def ui():
    st.markdown("# Image Captioning")
    # st.markdown("## Done By- Vageesh and Rushil")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        image = io.imread(uploaded_file)
        pil_image = PIL.Image.fromarray(image)
        image = preprocess(pil_image).unsqueeze(0).to(device)
        option = st.selectbox('Please select the Model', ('Clip Captioning', 'Attention Decoder', 'VIT+GPT2'))
        if option == 'Clip Captioning':
            # Encode the image with CLIP, project the features to a GPT2 prefix, then decode a caption.
            with torch.no_grad():
                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
                prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + generated_text_prefix)
        elif option == 'Attention Decoder':
            out = inference(uploaded_file)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + out)
        # elif option == 'VIT+GPT2':
        #     out = show_n_generate(uploaded_file, greedy=False, model=model_trained)
        #     st.image(uploaded_file, width=500, channels='RGB')
        #     st.markdown("**PREDICTION:** " + out)
if __name__ == '__main__':
    ui()