File size: 4,155 Bytes
0d29a98
688a87d
e05052e
 
d430aa2
e05052e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e53af6
bdb8c95
0e53af6
df5f888
 
 
0e53af6
3e074ca
0e53af6
 
 
97f30ad
0e53af6
abc934b
 
8d0b547
abc934b
688a87d
0e53af6
 
 
 
c6c8487
89703a3
3c806f9
8d0b547
3c806f9
 
ed4ba53
06436a8
89703a3
 
 
 
 
4e84543
89703a3
4e84543
89703a3
 
0c5c558
c6c8487
 
77b0c52
9a9a80e
 
 
 
 
 
038e09d
e05052e
 
df5f888
 
 
 
 
 
9a9a80e
 
 
038e09d
31151c4
0ff1ff1
038e09d
c57ffa7
 
0c5c558
0eb629f
 
c57ffa7
 
df5f888
 
 
c57ffa7
 
fddb275
 
 
 
 
 
 
 
 
 
 
 
7e87351
 
fddb275
 
363481b
 
fddb275
 
7e87351
fddb275
 
abc934b
 
c2ea813
f37c69f
e05052e
0e53af6
01fb1e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Pinned model revision for reproducible remote-code loading.
REVISION = "bce9358ca7928fc17c0c82d5fa2253aa681a4624"

try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    class spaces:
        """Local no-op stand-in for the HF Spaces `spaces` package.

        Provides a `GPU` decorator with the same shape as the real one so
        the rest of the file can decorate handlers unconditionally.
        """

        @staticmethod
        def GPU(duration):
            def decorator(func):
                # Pick the wrapper shape at decoration time. A single
                # wrapper containing `yield from` in one branch would
                # itself be a generator function, so wrapping a *regular*
                # function that way would return a generator object
                # instead of the function's result.
                if inspect.isgeneratorfunction(func):
                    @wraps(func)  # preserve the original function's metadata
                    def wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)
                else:
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        return func(*args, **kwargs)

                return wrapper

            return decorator

    IN_SPACES = False

import torch
import os
import gradio as gr
import json

from queue import Queue
from threading import Thread
from transformers import AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize

# Forward the HF access token from the Space secret into the env var the
# hub client reads. Only set it when the secret exists: the original
# `... or True` fallback raised TypeError (os.environ values must be str)
# whenever TOKEN_FROM_SECRET was unset.
_hf_token = os.environ.get("TOKEN_FROM_SECRET")
if _hf_token:
    os.environ["HF_TOKEN"] = _hf_token

# Load the pinned moondream checkpoint onto the GPU in bfloat16.
# NOTE: trust_remote_code executes model code from the hub — acceptable
# here only because the revision is pinned above.
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map={"": "cuda"},
    revision=REVISION
)
moondream.eval()  # inference only — disable dropout etc.


@spaces.GPU(duration=10)
def localized_query(img, x, y, question):
    """Answer *question* about the point (x, y) of *img*.

    x and y are normalized coordinates in [0, 1]. Yields a pair of
    (answer text, Gradio update for the annotated-image component); when
    no image is provided, yields an empty answer and hides the component.
    """
    # Guard clause: nothing to query without an image.
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    # Ground the question at the normalized coordinate.
    answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"]

    # Convert normalized coords to pixel space and mark the queried point
    # with a small dot on a copy of the image (original left untouched).
    width, height = img.size
    px = x * width
    py = y * height
    annotated = img.copy()
    ImageDraw.Draw(annotated).ellipse(
        (px - 5, py - 5, px + 5, py + 5),
        fill="red",
        outline="blue",
    )

    yield answer, gr.update(visible=True, value=annotated)


# Custom JS hook passed to gr.Blocks (currently unused).
js = ""

# Page CSS: enlarge the response text rendered by the Markdown output
# component and de-emphasize/condense elements carrying the
# "chain-of-thought" class.
css = """
    .output-text span p {
        font-size: 1.4rem !important;
    }

    .chain-of-thought {
        opacity: 0.7 !important;
    }
    .chain-of-thought span.label {
        display: none;
    }
    .chain-of-thought span.textspan {
        padding-right: 0;
    }
"""

# Build the Gradio UI: image + point + question on the left, the model's
# answer and the annotated image on the right.
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
        pass

    gr.Markdown(
        """
        # 🌔 grounded visual question answering

        upload an image, then click on it to ask a question about that region of the image.
        """
    )

    # Session-scoped image holder. NOTE(review): appears unused — the event
    # handlers below read the Image component directly; confirm before removing.
    input_image = gr.State(None)

    with gr.Row():
        with gr.Column():
            # Inputs live inside a render callback. NOTE(review): the wiring
            # below closes over `output` and `ann`, which are assigned later
            # (second column) — presumably safe because gr.render callbacks
            # run after the Blocks context is fully built; confirm against
            # Gradio's render lifecycle.
            @gr.render()
            def show_inputs():
    
                with gr.Group():
                    with gr.Row():
                        prompt = gr.Textbox(
                            label="Input",
                            value="What is this?",
                            scale=4,
                        )
                        submit = gr.Button("Submit")
                    img = gr.Image(type="pil", label="Upload an Image")
                    # Normalized (0..1) point coordinates; randomized initial values.
                    x_slider = gr.Slider(label="x", minimum=0, maximum=1, randomize=True)
                    y_slider = gr.Slider(label="y", minimum=0, maximum=1, randomize=True)
                # Re-run the grounded query on every input change (button,
                # textbox submit, slider moves, new image).
                submit.click(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                prompt.submit(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                x_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                y_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                img.change(localized_query, [img, x_slider, y_slider, prompt], [output, ann])
                # Map a click on the image (pixel coords from evt.index) to
                # normalized slider values, which in turn trigger a query.
                def select_handler(image, evt: gr.SelectData):
                    w, h = image.size
                    return [evt.index[0] / w, evt.index[1] / h]
                img.select(select_handler, img, [x_slider, y_slider])

        with gr.Column():
            # Model answer (Markdown) and the point-annotated image.
            output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
            ann = gr.Image(visible=False)


# Enable request queuing (needed for generator handlers) and serve the app.
demo.queue().launch()