import gradio as gr
from gradio.themes.ocean import Ocean
import torch
import numpy as np
import supervision as sv
from transformers import (
    AutoModelForCausalLM,
    Qwen3VLForConditionalGeneration,
    Qwen3VLProcessor,
)
import json
import ast
import re
from PIL import Image
from spaces import GPU

# --- Constants and Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
    "Query": "What's in this image?",
    "Caption": "Enter caption length: short, normal, or long",
    "Point": "Select an object from suggestions or enter manually",
    "Detect": "Select an object from suggestions or enter manually",
}

# --- Model Loading ---
# Load Moondream
moondream = AutoModelForCausalLM.from_pretrained(
    "moondream/moondream3-preview",
    trust_remote_code=True,
    dtype=DTYPE,
    device_map=DEVICE,
    revision="main",
).eval()

# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
)


# --- Utility Functions ---
def safe_parse_json(text: str):
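    """Parse model output that may be wrapped in markdown code fences.

    Falls back to ast.literal_eval for Python-style (single-quoted) output
    and returns an empty dict when nothing can be parsed.
    """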
    text = text.strip()
    text = re.sub(r"^```(json)?", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(text)
    except Exception:
        return {}


def get_suggested_objects(image: Image.Image):
    """Get suggested objects in the image using Moondream."""
    if image is None:
        return []
    try:
        result = moondream.query(
            image=image,
            question="What objects are in the image, provide the list.",
            reasoning=False,
        )
        suggested_objects = ast.literal_eval(result["answer"])
        if isinstance(suggested_objects, list):
            return suggested_objects[:3]  # send no more than 3 suggestions
        return []
    except Exception as e:
        print(f"Error getting suggestions: {e}")
        return []


def annotate_image(image: Image.Image, result: dict):
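    """Draw point or box annotations described by `result` onto `image`.

    Point dicts carry normalized x/y values that are scaled to pixel
    coordinates and drawn as keypoints; object dicts are converted to
    supervision Detections and drawn as bounding boxes. The original
    image is returned when there is nothing to draw.
    """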
    if not isinstance(image, Image.Image):
        return image  # Return original if not a valid image
    if not isinstance(result, dict):
        return image  # Return original if result is not a dict

    original_width, original_height = image.size

    # Handle Point annotations
    if "points" in result and result["points"]:
        points_list = []
        for point in result.get("points", []):
            x = int(point["x"] * original_width)
            y = int(point["y"] * original_height)
            points_list.append([x, y])
        if not points_list:
            return image
        points_array = np.array(points_list).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        annotated_image = vertex_annotator.annotate(
            scene=image.copy(), key_points=key_points
        )
        return annotated_image

    # Handle Detection annotations
    if "objects" in result and result["objects"]:
        detections = sv.Detections.from_vlm(
            sv.VLM.MOONDREAM,
            result,
            resolution_wh=image.size,
        )
        if len(detections) == 0:
            return image
        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5)
        annotated_scene = box_annotator.annotate(
            scene=image.copy(), detections=detections
        )
        return annotated_scene

    return image


# --- Inference Functions ---
def run_qwen_inference(image: Image.Image, prompt: str):
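    """Run a single-turn image + text prompt through Qwen3-VL.

    The processor's chat template tokenizes the message, the model
    generates up to 512 new tokens, and only the newly generated
    tokens are decoded.
    """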
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.inference_mode():
        generated_ids = qwen_model.generate(
            **inputs,
            max_new_tokens=512,
        )
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text


def process_qwen(image: Image.Image, category: str, prompt: str):
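    """Build the task-specific prompt for Qwen3-VL and post-process its output.

    Point and Detect parse the returned JSON (`point_2d` / `bbox_2d`),
    treating coordinates as values on a 0-1000 grid and normalizing
    them to [0, 1] so annotate_image can scale them to the image size.
    """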
    if category == "Query":
        return run_qwen_inference(image, prompt), {}
    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}
    elif category == "Point":
        full_prompt = (
            f"Provide 2d point coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        points_result = {"points": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
        return json.dumps(points_result, indent=2), points_result
    elif category == "Detect":
        full_prompt = (
            f"Provide bounding box coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        objects_result = {"objects": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
                    xmin, ymin, xmax, ymax = item["bbox_2d"]
                    objects_result["objects"].append(
                        {
                            "x_min": xmin / 1000.0,
                            "y_min": ymin / 1000.0,
                            "x_max": xmax / 1000.0,
                            "y_max": ymax / 1000.0,
                        }
                    )
        return json.dumps(objects_result, indent=2), objects_result
    return "Invalid category", {}


def process_moondream(image: Image.Image, category: str, prompt: str):
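    """Dispatch each category to the matching Moondream skill.

    query/caption return plain text; point/detect return dicts in the
    format annotate_image expects.
    """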
    if category == "Query":
        result = moondream.query(image=image, question=prompt)
        return result["answer"], {}
    elif category == "Caption":
        result = moondream.caption(image, length=prompt)
        return result["caption"], {}
    elif category == "Point":
        result = moondream.point(image, prompt)
        return json.dumps(result, indent=2), result
    elif category == "Detect":
        result = moondream.detect(image, prompt)
        return json.dumps(result, indent=2), result
    return "Invalid category", {}


# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
    """Update the suggestions radio and prompt box when the category or image changes."""
    text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)
    if image is None or category not in ["Point", "Detect", "Caption"]:
        return gr.Radio(choices=[], visible=False), text_box
    if category == "Caption":
        return gr.Radio(choices=["short", "normal", "long"], visible=True), text_box
    suggestions = get_suggested_objects(image)
    if suggestions:
        return gr.Radio(choices=suggestions, visible=True, interactive=True), text_box
    return gr.Radio(choices=["no choice possible"], visible=True, interactive=True), text_box


def update_prompt_from_radio(selected_object):
    """Copy the selected suggestion into the prompt textbox."""
    if selected_object:
        return gr.Textbox(value=selected_object)
    return gr.Textbox(value="")


@GPU  # request ZeroGPU hardware for the duration of this call
def process_inputs(image, category, prompt):
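    """Run both models on the same image and prompt, and annotate their outputs."""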
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please provide a prompt.")

    # Process with Qwen
    qwen_text, qwen_data = process_qwen(image, category, prompt)
    qwen_annotated_image = annotate_image(image, qwen_data)

    # Process with Moondream
    moondream_text, moondream_data = process_moondream(image, category, prompt)
    moondream_annotated_image = annotate_image(image, moondream_data)

    return qwen_annotated_image, qwen_text, moondream_annotated_image, moondream_text


css_hide_share = """
button#gradio-share-link-button-0 {
    display: none !important;
}
"""

# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
    gr.Markdown("# 👓 Object Understanding with Vision Language Models")
    gr.Markdown(
        "### Explore object detection, visual grounding, keypoint detection, and object counting through natural language prompts."
    )
    gr.Markdown("""
    *Powered by [Qwen3-VL 4B](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct) and [Moondream 3 Preview](https://huggingface.co/moondream/moondream3-preview). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*

    *Moondream 3 uses the [moondream-preview](https://huggingface.co/vikhyatk/moondream2/blob/main/moondream.py) API: `detect` for the "Detect" category, `point` for "Point", `caption` for "Caption", and reasoning-based querying for "Query".*
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            category_select = gr.Radio(
                choices=CATEGORIES,
                value=CATEGORIES[0],
                label="Select Task Category",
                interactive=True,
            )
            # Suggested objects radio (hidden by default)
            suggestions_radio = gr.Radio(
                choices=[],
                label="Suggestions",
                visible=False,
                interactive=True,
            )
            prompt_input = gr.Textbox(
                placeholder=PLACEHOLDERS[CATEGORIES[0]],
                label="Prompt",
                lines=2,
            )
            submit_btn = gr.Button("Compare Models", variant="primary")

        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct")
                    qwen_img_output = gr.Image(label="Annotated Image")
                    qwen_text_output = gr.Textbox(
                        label="Text Output", lines=8, interactive=False
                    )
                with gr.Column():
                    gr.Markdown("### moondream/moondream3-preview")
                    moon_img_output = gr.Image(label="Annotated Image")
                    moon_text_output = gr.Textbox(
                        label="Text Output", lines=8, interactive=False
                    )

    gr.Examples(
        examples=[
            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
            ["examples/example_1.jpg", "Caption", ""],
            ["examples/example_2.JPG", "Point", ""],
            ["examples/example_2.JPG", "Detect", ""],
        ],
        inputs=[image_input, category_select, prompt_input],
    )

    # --- Event Listeners ---
    category_select.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    # Refresh suggestions when the image itself changes as well
    image_input.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    suggestions_radio.change(
        fn=update_prompt_from_radio,
        inputs=[suggestions_radio],
        outputs=[prompt_input],
    )
    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[qwen_img_output, qwen_text_output, moon_img_output, moon_text_output],
    )


if __name__ == "__main__":
    demo.launch()