# SelfReVision / app.py
# Last change by jrfish: "Update app.py" (commit 89d1322).
import gradio as gr
import random
from PIL import Image
import os
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
# Sample images bundled with the app live in this directory (relative to the
# working directory).
image_dir = "images"
# Paths of every ".jpg" file in image_dir, computed once at import time.
images = [
    os.path.join(image_dir, filename)
    for filename in os.listdir(image_dir)
    if filename.endswith(".jpg")
]
# Image shown when the page first loads; get_random_image also falls back to
# it when image_dir contains no ".jpg" files.
STATIC_IMAGE_PATH = f"{image_dir}/Places365_val_00000009.jpg"
# Load the BLIP captioning processor and model at import time, placing the
# model on the GPU when CUDA is available and on the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)
def run_vlm(image, prompt):
    """Run the BLIP model on ``image`` with ``prompt`` and return decoded text.

    The image and prompt are encoded together by the module-level
    ``processor``, moved to ``device``, and passed to ``model.generate``.
    Only the first generated sequence is decoded, with special tokens
    stripped.

    Args:
        image: The image to describe (``process`` passes a PIL image opened
            from disk).
        prompt: Text given to the processor alongside the image.

    Returns:
        The decoded model output as a string.
    """
    encoded = processor(image, prompt, return_tensors="pt")
    encoded = encoded.to(device)
    generated_ids = model.generate(**encoded)
    first_sequence = generated_ids[0]
    return processor.decode(first_sequence, skip_special_tokens=True)
def get_random_image(event=None):
    """Pick a random ``.jpg`` from ``image_dir`` for the UI.

    Args:
        event: Unused by the body.

    Returns:
        Two ``gr.update`` values carrying the same image path. The UI routes
        them to the image preview and the hidden path textbox. When
        ``image_dir`` holds no ``.jpg`` files, the path is
        ``STATIC_IMAGE_PATH``.

    Raises:
        OSError: If ``image_dir`` cannot be listed (e.g. it does not exist).
    """
    candidates = [
        name for name in os.listdir(image_dir) if name.endswith(".jpg")
    ]
    if candidates:
        chosen = os.path.join(image_dir, random.choice(candidates))
        print(chosen)  # trace the pick on stdout
    else:
        chosen = STATIC_IMAGE_PATH
    return gr.update(value=chosen), gr.update(value=chosen)
def process(image_path, user_prompt):
    """Open the image at ``image_path`` and run the model on it.

    Args:
        image_path: Filesystem path of the image to describe (in the UI this
            is the hidden path textbox's value).
        user_prompt: Text prompt forwarded to ``run_vlm``.

    Returns:
        The text produced by ``run_vlm`` for this image and prompt.
    """
    # Image.open is lazy and keeps the underlying file open until the image
    # is closed. The context manager releases the handle after inference,
    # including when run_vlm raises, so repeated clicks do not leak file
    # descriptors.
    with Image.open(image_path) as image:
        return run_vlm(image, user_prompt)
with gr.Blocks() as demo:
    # First row: the image preview, a hidden textbox holding the selected
    # image's path (the input actually read by "Run Model"), and the prompt.
    with gr.Row():
        image_display = gr.Image(
            label="Selected Image",
            type="filepath",
            value=STATIC_IMAGE_PATH,
        )
        image_path = gr.Textbox(visible=False, value=STATIC_IMAGE_PATH)
        user_prompt = gr.Textbox(label="User Prompt")
    # Second row: the action buttons and the model's text output.
    with gr.Row():
        random_button = gr.Button("Random Photo")
        run_button = gr.Button("Run Model")
        output = gr.Textbox(label="Model Output")

    # "Random Photo" updates both the preview and the hidden path together.
    random_button.click(
        fn=get_random_image,
        outputs=[image_display, image_path],
    )
    # "Run Model" captions the image at the hidden path with the user prompt.
    run_button.click(
        fn=process,
        inputs=[image_path, user_prompt],
        outputs=output,
    )
# Earlier version of the layout, kept for reference only; it is not executed.
# It wires a `random_image` handler, whereas the active UI above uses
# `get_random_image`.
# with gr.Blocks() as demo:
# with gr.Row():
# image_display = gr.Image(type="filepath", label="Selected Image")
# random_button = gr.Button("Randomize Photo")
# image_path = gr.Textbox(visible=False)
# user_prompt = gr.Textbox(label="User Prompt")
# run_button = gr.Button("Run Model")
# output = gr.Textbox(label="Model Output")
# random_button.click(fn=random_image, outputs=[image_display, image_path])
# run_button.click(fn=process, inputs=[image_path, user_prompt], outputs=output)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script
    # (e.g. `python app.py`), so importing the module does not launch a server.
    demo.launch()