import gradio as gr
from gradio.themes.ocean import Ocean
import torch
import numpy as np
import supervision as sv
from transformers import (
AutoModelForCausalLM,
Qwen3VLForConditionalGeneration,
Qwen3VLProcessor,
)
import json
import ast
import re
from PIL import Image
from spaces import GPU
# --- Constants and Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"
CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
"Query": "What's in this image?",
"Caption": "Enter caption length: short, normal, or long",
"Point": "Select an object from suggestions or enter manually",
"Detect": "Select an object from suggestions or enter manually",
}
# --- Model Loading ---
# Load Moondream
moondream = AutoModelForCausalLM.from_pretrained(
"moondream/moondream3-preview",
trust_remote_code=True,
dtype=DTYPE,
device_map=DEVICE,
revision="main",
).eval()
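# trust_remote_code is required here: the task-specific methods used below
# (.query, .caption, .point, .detect) come from Moondream's custom modeling code.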
# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-4B-Instruct",
dtype=DTYPE,
device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
"Qwen/Qwen3-VL-4B-Instruct",
)
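# The processor bundles the tokenizer, chat template, and image preprocessing
# needed to turn (image, prompt) pairs into Qwen3-VL model inputs.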
# --- Utility Functions ---
def safe_parse_json(text: str):
text = text.strip()
text = re.sub(r"^```(json)?", "", text)
text = re.sub(r"```$", "", text)
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
try:
return ast.literal_eval(text)
except Exception:
return {}
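# Example: safe_parse_json('```json\n[{"bbox_2d": [10, 20, 300, 400]}]\n```')
# strips the code fences and returns [{"bbox_2d": [10, 20, 300, 400]}];
# anything that neither json.loads nor ast.literal_eval can parse yields {}.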
@GPU
def get_suggested_objects(image: Image.Image):
"""Get suggested objects in the image using Moondream"""
if image is None:
return []
try:
result = moondream.query(
image=image,
question="List the objects in the image in python list format.",
reasoning=False,
)
suggested_objects = ast.literal_eval(result["answer"])
        if isinstance(suggested_objects, list):
            # Cap the list so the UI shows at most 3 suggestions
            return suggested_objects[:3]
        return []
except Exception as e:
print(f"Error getting suggestions: {e}")
return []
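# annotate_image expects Moondream-style result dicts with coordinates normalized
# to [0, 1]: {"points": [{"x": ..., "y": ...}]} for keypoints, or
# {"objects": [{"x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...}]} for boxes.
# process_qwen below converts Qwen's output into the same shape.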
def annotate_image(image: Image.Image, result: dict):
if not isinstance(image, Image.Image):
return image # Return original if not a valid image
if not isinstance(result, dict):
return image # Return original if result is not a dict
original_width, original_height = image.size
# Handle Point annotations
if "points" in result and result["points"]:
points_list = []
for point in result.get("points", []):
x = int(point["x"] * original_width)
y = int(point["y"] * original_height)
points_list.append([x, y])
if not points_list:
return image
points_array = np.array(points_list).reshape(1, -1, 2)
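        # sv.KeyPoints expects xy shaped (num_detections, num_keypoints, 2);
        # all points are grouped under a single detection here.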
key_points = sv.KeyPoints(xy=points_array)
vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
annotated_image = vertex_annotator.annotate(
scene=image.copy(), key_points=key_points
)
return annotated_image
# Handle Detection annotations
if "objects" in result and result["objects"]:
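        # supervision parses Moondream-format boxes (normalized x_min/y_min/x_max/y_max)
        # and scales them to pixel coordinates using resolution_wh.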
detections = sv.Detections.from_vlm(
sv.VLM.MOONDREAM,
result,
resolution_wh=image.size,
)
if len(detections) == 0:
return image
box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5)
annotated_scene = box_annotator.annotate(
scene=image.copy(), detections=detections
)
return annotated_scene
return image
# --- Inference Functions ---
def run_qwen_inference(image: Image.Image, prompt: str):
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
inputs = qwen_processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
).to(DEVICE)
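    # With tokenize=True and return_dict=True, apply_chat_template returns the full
    # set of model inputs (token ids, attention mask, image features) for generate().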
with torch.inference_mode():
generated_ids = qwen_model.generate(
**inputs,
max_new_tokens=512,
)
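    # generate() returns prompt + completion token ids, so slice off the prompt
    # portion and decode only the newly generated tokens.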
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = qwen_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
return output_text
@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
if category == "Query":
return run_qwen_inference(image, prompt), {}
elif category == "Caption":
full_prompt = f"Provide a {prompt} length caption for the image."
return run_qwen_inference(image, full_prompt), {}
elif category == "Point":
full_prompt = (
f"Provide 2d point coordinates for {prompt}. Report in JSON format."
)
output_text = run_qwen_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
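        # Qwen is expected to return a list like [{"point_2d": [x, y], ...}] with
        # coordinates on a 0-1000 grid, hence the division by 1000 to normalize.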
points_result = {"points": []}
if isinstance(parsed_json, list):
for item in parsed_json:
                if isinstance(item, dict) and "point_2d" in item and len(item["point_2d"]) == 2:
x, y = item["point_2d"]
points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
return json.dumps(points_result, indent=2), points_result
elif category == "Detect":
full_prompt = (
f"Provide bounding box coordinates for {prompt}. Report in JSON format."
)
output_text = run_qwen_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
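        # Expected shape: a list like [{"bbox_2d": [xmin, ymin, xmax, ymax], ...}],
        # again on a 0-1000 grid; values are normalized to [0, 1] for annotate_image.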
objects_result = {"objects": []}
if isinstance(parsed_json, list):
for item in parsed_json:
                if isinstance(item, dict) and "bbox_2d" in item and len(item["bbox_2d"]) == 4:
xmin, ymin, xmax, ymax = item["bbox_2d"]
objects_result["objects"].append(
{
"x_min": xmin / 1000.0,
"y_min": ymin / 1000.0,
"x_max": xmax / 1000.0,
"y_max": ymax / 1000.0,
}
)
return json.dumps(objects_result, indent=2), objects_result
return "Invalid category", {}
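# Moondream exposes each task as a dedicated method; .point() and .detect() already
# return dicts in the {"points": [...]} / {"objects": [...]} format that
# annotate_image consumes, so no post-processing is needed.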
@GPU
def process_moondream(image: Image.Image, category: str, prompt: str):
if category == "Query":
result = moondream.query(image=image, question=prompt)
return result["answer"], {}
elif category == "Caption":
result = moondream.caption(image, length=prompt)
return result["caption"], {}
elif category == "Point":
result = moondream.point(image, prompt)
return json.dumps(result, indent=2), result
elif category == "Detect":
result = moondream.detect(image, prompt)
return json.dumps(result, indent=2), result
return "Invalid category", {}
# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
    """Refresh the suggestion choices and reset the prompt when the task category changes."""
text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)
if image is None or category not in ["Point", "Detect", "Caption"]:
return gr.Radio(choices=[], visible=False), text_box
if category == "Caption":
return gr.Radio(choices=["short", "normal", "long"], visible=True), text_box
suggestions = get_suggested_objects(image)
if suggestions:
return gr.Radio(choices=suggestions, visible=True, interactive=True), text_box
else:
        return gr.Radio(choices=["No suggestions available"], visible=True, interactive=True), text_box
def update_prompt_from_radio(selected_object):
"""Update prompt textbox when a radio option is selected"""
if selected_object:
return gr.Textbox(value=selected_object)
return gr.Textbox(value="")
def process_inputs(image, category, prompt):
if image is None:
raise gr.Error("Please upload an image.")
if not prompt:
raise gr.Error("Please provide a prompt.")
# Resize the image to make inference quicker
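    # PIL's thumbnail() resizes in place and keeps the aspect ratio, so neither
    # side exceeds 512 px; both models then see the same downscaled image.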
image.thumbnail((512, 512))
# Process with Qwen
qwen_text, qwen_data = process_qwen(image, category, prompt)
qwen_annotated_image = annotate_image(image, qwen_data)
# Process with Moondream
moondream_text, moondream_data = process_moondream(image, category, prompt)
moondream_annotated_image = annotate_image(image, moondream_data)
return qwen_annotated_image, qwen_text, moondream_annotated_image, moondream_text
css_hide_share = """
button#gradio-share-link-button-0 {
display: none !important;
}
"""
# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
gr.Markdown("# 👓 Object Understanding with Vision Language Models")
    gr.Markdown(
        "### Explore object detection, visual grounding, keypoint detection, and object counting through natural language prompts."
    )
gr.Markdown("""
*Powered by [Qwen3-VL 4B](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct) and [Moondream 3 Preview](https://huggingface.co/moondream/moondream3-preview). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
*Moondream 3 uses the [moondream-preview](https://huggingface.co/vikhyatk/moondream2/blob/main/moondream.py) API, calling `detect` for the "Detect" task, `point` for "Point", `caption` for "Caption", and reasoning-based `query` for "Query".*
""")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="pil", label="Input Image")
category_select = gr.Radio(
choices=CATEGORIES,
value=CATEGORIES[0],
label="Select Task Category",
interactive=True,
)
# Suggested objects radio (hidden by default)
suggestions_radio = gr.Radio(
choices=[],
label="Suggestions",
visible=False,
interactive=True,
)
prompt_input = gr.Textbox(
placeholder=PLACEHOLDERS[CATEGORIES[0]],
label="Prompt",
lines=2,
)
submit_btn = gr.Button("Compare Models", variant="primary")
with gr.Column(scale=2):
with gr.Row():
with gr.Column():
gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct")
qwen_img_output = gr.Image(label="Annotated Image")
qwen_text_output = gr.Textbox(
label="Text Output", lines=8, interactive=False
)
with gr.Column():
gr.Markdown("### moondream/moondream3-preview")
moon_img_output = gr.Image(label="Annotated Image")
moon_text_output = gr.Textbox(
label="Text Output", lines=8, interactive=False
)
gr.Examples(
examples=[
["examples/example_1.jpg", "Query", "How many cars are in the image?"],
["examples/example_1.jpg", "Caption", ""],
["examples/example_2.JPG", "Point", ""],
["examples/example_2.JPG", "Detect", ""],
],
inputs=[image_input, category_select, prompt_input],
)
# --- Event Listeners ---
category_select.change(
fn=on_category_and_image_change,
inputs=[image_input, category_select],
outputs=[suggestions_radio, prompt_input],
)
suggestions_radio.change(
fn=update_prompt_from_radio,
inputs=[suggestions_radio],
outputs=[prompt_input],
)
submit_btn.click(
fn=process_inputs,
inputs=[image_input, category_select, prompt_input],
outputs=[qwen_img_output, qwen_text_output, moon_img_output, moon_text_output],
)
if __name__ == "__main__":
demo.launch()