Spaces:
Sleeping
Sleeping
| import torch | |
| import clip | |
| import PIL.Image | |
| import skimage.io as io | |
| import streamlit as st | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup | |
| from model import preprocess,clip_model,generate2,ClipCaptionModel | |
# Model loading code.
# NOTE(review): Streamlit re-executes the whole script on every widget
# interaction, so everything below is reloaded each time. Consider caching the
# loaded models (e.g. st.cache_resource, if the installed Streamlit has it).
import os

device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Number of GPT-2 prefix embeddings the CLIP feature is projected into
# (used in ClipCaptionModel(...) and in the reshape inside ui()); presumably it
# must match the value the checkpoints were trained with -- confirm.
prefix_length = 10

# Raw strings: these Windows paths contain backslashes that a normal string
# literal would parse as (invalid) escape sequences. The environment variables
# allow pointing at checkpoints elsewhere without editing the source.
MODEL_PATH = os.environ.get(
    "CAPTION_MODEL_PATH",
    r'C:\Deep learning lab\DLops Project\Cl+gpt2\model.h5',
)
COCO_MODEL_PATH = os.environ.get(
    "COCO_CAPTION_MODEL_PATH",
    r'C:\Deep learning lab\DLops Project\Cl+gpt2\COCO_model.h5',
)


def _load_caption_model(checkpoint_path):
    """Build a ClipCaptionModel, load its state dict onto CPU, and return it in eval mode.

    Centralizing this guarantees every loaded model is switched to eval mode.
    (Previously the COCO model was never put in eval mode.)
    """
    caption_model = ClipCaptionModel(prefix_length)
    # torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    caption_model.load_state_dict(state_dict)
    return caption_model.eval()


model = _load_caption_model(MODEL_PATH)
coco_model = _load_caption_model(COCO_MODEL_PATH)
def _generate_caption(caption_model, image_tensor):
    """Return the caption *caption_model* generates for a CLIP-preprocessed image tensor.

    The CLIP image embedding is projected through *caption_model*'s own
    ``clip_project`` so the prefix matches the GPT-2 weights that decode it.
    (Previously the COCO branch used the other model's projection.)
    """
    with torch.no_grad():
        prefix = clip_model.encode_image(image_tensor).to(device, dtype=torch.float32)
        prefix_embed = caption_model.clip_project(prefix).reshape(1, prefix_length, -1)
        return generate2(caption_model, tokenizer, embed=prefix_embed)


def ui():
    """Render the page: upload an image, pick a captioning model, and show its caption."""
    st.markdown("# Image Captioning")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is None:
        return

    pixels = io.imread(uploaded_file)
    pil_image = PIL.Image.fromarray(pixels)
    # Add a batch dimension of 1 before encoding.
    image_tensor = preprocess(pil_image).unsqueeze(0).to(device)

    option = st.selectbox('Please select the Model', ('Model', 'COCO Model'))
    # Map each selectbox label to the model that both projects and decodes.
    caption_models = {'Model': model, 'COCO Model': coco_model}
    caption_model = caption_models.get(option)
    if caption_model is None:
        return

    generated_text_prefix = _generate_caption(caption_model, image_tensor)
    st.image(uploaded_file, width=500, channels='RGB')
    st.markdown("**PREDICTION:** " + generated_text_prefix)
# Entry point: render the UI when this file is run as the main script
# (`streamlit run` executes the file as __main__, so the guard fires there too).
if __name__ == "__main__":
    ui()