Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -56,9 +56,9 @@ def ui():
|
|
| 56 |
pil_image = PIL.Image.fromarray(image)
|
| 57 |
image = preprocess(pil_image).unsqueeze(0).to(device)
|
| 58 |
|
| 59 |
-
option = st.selectbox('Please select the Model',('
|
| 60 |
|
| 61 |
-
if option=='
|
| 62 |
with torch.no_grad():
|
| 63 |
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
| 64 |
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
|
@@ -66,12 +66,12 @@ def ui():
|
|
| 66 |
|
| 67 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
| 68 |
st.markdown("**PREDICTION:** " + generated_text_prefix)
|
| 69 |
-
elif option=='
|
| 70 |
out = inference(uploaded_file)
|
| 71 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
| 72 |
st.markdown("**PREDICTION:** " + out)
|
| 73 |
|
| 74 |
-
elif option=='
|
| 75 |
out=show_n_generate(uploaded_file, greedy = False, model = model_trained)
|
| 76 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
| 77 |
st.markdown("**PREDICTION:** " + out)
|
|
|
|
| 56 |
pil_image = PIL.Image.fromarray(image)
|
| 57 |
image = preprocess(pil_image).unsqueeze(0).to(device)
|
| 58 |
|
| 59 |
+
option = st.selectbox('Please select the Model',('Clip Captioning','Attention Decoder','VIT+GPT2'))
|
| 60 |
|
| 61 |
+
if option=='Clip Captioning':
|
| 62 |
with torch.no_grad():
|
| 63 |
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
| 64 |
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
|
|
|
| 66 |
|
| 67 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
| 68 |
st.markdown("**PREDICTION:** " + generated_text_prefix)
|
| 69 |
+
elif option=='Attention Decoder':
|
| 70 |
out = inference(uploaded_file)
|
| 71 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
| 72 |
st.markdown("**PREDICTION:** " + out)
|
| 73 |
|
| 74 |
+
elif option=='VIT+GPT2':
|
| 75 |
out=show_n_generate(uploaded_file, greedy = False, model = model_trained)
|
| 76 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
| 77 |
st.markdown("**PREDICTION:** " + out)
|