REVISION = "bce9358ca7928fc17c0c82d5fa2253aa681a4624"
try:
    import spaces

    IN_SPACES = True
except ImportError:
    # Not running on Hugging Face Spaces: provide a stand-in for `spaces.GPU`
    # so the decorated handler below still works locally.
    from functools import wraps
    import inspect

    class spaces:
        @staticmethod
        def GPU(duration):
            def decorator(func):
                if inspect.isgeneratorfunction(func):
                    # Keep generator semantics so streaming handlers still yield.
                    @wraps(func)  # Preserves the original function's metadata
                    def wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)
                else:
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        return func(*args, **kwargs)
                return wrapper

            return decorator

    IN_SPACES = False

import torch
import os
import gradio as gr
import json
from queue import Queue
from threading import Thread
from transformers import AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
os.environ["HF_TOKEN"] = os.environ.get("TOKEN_FROM_SECRET") or True

moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map={"": "cuda"},
    revision=REVISION,
)
moondream.eval()
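
# The model is used below via moondream.query(image, question, spatial_refs=[(x, y)]),
# where (x, y) is a point in normalized (0-1) image coordinates that grounds the
# question at a specific location in the image.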
@spaces.GPU(duration=10)
def localized_query(img, x, y, question):
    if img is None:
        # No image uploaded yet: clear the answer and hide the annotation preview.
        yield "", gr.update(visible=False, value=None)
        return

    # Ask the model about the selected point; (x, y) are normalized 0-1 coordinates.
    answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"]

    # Draw a marker at the selected point on a copy of the image.
    w, h = img.size
    x, y = x * w, y * h
    img_clone = img.copy()
    draw = ImageDraw.Draw(img_clone)
    draw.ellipse(
        (x - 5, y - 5, x + 5, y + 5),
        fill="red",
        outline="blue",
    )

    yield answer, gr.update(visible=True, value=img_clone)
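
# Example of calling the handler directly (hypothetical local test, not used by the UI;
# "sample.jpg" is a placeholder path):
#   from PIL import Image
#   answer, ann_update = next(localized_query(Image.open("sample.jpg"), 0.5, 0.5, "What is this?"))
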
js = ""
css = """
.output-text span p {
font-size: 1.4rem !important;
}
.chain-of-thought {
opacity: 0.7 !important;
}
.chain-of-thought span.label {
display: none;
}
.chain-of-thought span.textspan {
padding-right: 0;
}
"""

with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
        pass

    gr.Markdown(
        """
# 🌔 grounded visual question answering

upload an image, then click on it to ask a question about that region of the image.
"""
    )

    input_image = gr.State(None)
    with gr.Row():
        with gr.Column():
            @gr.render()
            def show_inputs():
                with gr.Group():
                    with gr.Row():
                        prompt = gr.Textbox(
                            label="Input",
                            value="What is this?",
                            scale=4,
                        )
                        submit = gr.Button("Submit")
                    img = gr.Image(type="pil", label="Upload an Image")
                    x_slider = gr.Slider(label="x", minimum=0, maximum=1, randomize=True)
                    y_slider = gr.Slider(label="y", minimum=0, maximum=1, randomize=True)

                # `output` and `ann` are defined in the second column below; they exist by
                # the time this render callback runs, so they can be used as event outputs.
                submit.click(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                prompt.submit(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                x_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                y_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                img.change(localized_query, [img, x_slider, y_slider, prompt], [output, ann])

                def select_handler(image, evt: gr.SelectData):
                    # Convert the clicked pixel position to normalized (0-1) coordinates.
                    w, h = image.size
                    return [evt.index[0] / w, evt.index[1] / h]

                img.select(select_handler, img, [x_slider, y_slider])

        with gr.Column():
            output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
            ann = gr.Image(visible=False)

demo.queue().launch()