vikhyatk committed (verified)
Commit df5f888 · 1 Parent(s): dcacf66

Update app.py

Files changed (1):
  1. app.py +124 -6
app.py CHANGED
@@ -26,9 +26,11 @@ except ImportError:
     IN_SPACES = False
 
 import torch
-from queue import Queue
 import os
 import gradio as gr
+import json
+
+from queue import Queue
 from threading import Thread
 from transformers import (
     TextIteratorStreamer,
@@ -48,7 +50,7 @@ if IN_SPACES:
 )
 
 auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
-tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
+tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream-next")
 moondream = AutoModelForCausalLM.from_pretrained(
     "vikhyatk/moondream-next",
     trust_remote_code=True,
@@ -57,9 +59,79 @@ moondream = AutoModelForCausalLM.from_pretrained(
     attn_implementation="flash_attention_2",
     token=auth_token if IN_SPACES else None,
 )
+
+# CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"]
+# def get_ckpt(filename):
+#     ckpts = [
+#         torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS
+#     ]
+#     avg_ckpt = {
+#         key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts)
+#         for key in ckpts[0]
+#     }
+#     return avg_ckpt
+# moondream.load_state_dict(get_ckpt("model.pt"))
 moondream.eval()
 
 
+def convert_to_entities(text, coords):
+    """
+    Converts a string with special markers into an entity representation.
+    Markers:
+    - <|coord|> pairs indicate coordinate markers
+    - <|start_ground|> indicates the start of a ground term
+    - <|end_ground|> indicates the end of a ground term
+
+    Returns:
+    - Dictionary with cleaned text and entities with their character positions
+    """
+    # Initialize variables
+    cleaned_text = ""
+    entities = []
+    entity = []
+
+    # Track current position in cleaned text
+    current_pos = 0
+    # Track if we're currently processing an entity
+    in_entity = False
+    entity_start = 0
+
+    i = 0
+    while i < len(text):
+        # Check for markers
+        if text[i : i + 9] == "<|coord|>":
+            i += 9
+            entity.append(coords.pop(0))
+            continue
+
+        elif text[i : i + 16] == "<|start_ground|>":
+            in_entity = True
+            entity_start = current_pos
+            i += 16
+            continue
+
+        elif text[i : i + 14] == "<|end_ground|>":
+            # Store entity position
+            entities.append(
+                {
+                    "entity": json.dumps(entity),
+                    "start": entity_start,
+                    "end": current_pos,
+                }
+            )
+            entity = []
+            in_entity = False
+            i += 14
+            continue
+
+        # Add character to cleaned text
+        cleaned_text += text[i]
+        current_pos += 1
+        i += 1
+
+    return {"text": cleaned_text, "entities": entities}
+
+
 @spaces.GPU(duration=10)
 def answer_question(img, prompt):
     if img is None:
@@ -85,10 +157,12 @@ def answer_question(img, prompt):
     buffer = ""
     for new_text in streamer:
         buffer += new_text
-        yield buffer.strip(), "Thinking..."
+        yield buffer.strip(), {"text": "Thinking...", "entities": []}
 
     answer = queue.get()
-    yield answer["answer"], answer["thought"]
+    thought = convert_to_entities(answer["thought"], answer["coords"])
+
+    yield answer["answer"], thought
 
 
 @spaces.GPU(duration=10)
@@ -135,7 +209,9 @@ def detect(img, object):
         width=3,
     )
 
-    yield f"{len(objs)} detected", gr.update(visible=True, value=img)
+    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
+        visible=True, value=img
+    )
 
 
 js = """
@@ -266,6 +342,12 @@ css = """
 .chain-of-thought {
     opacity: 0.7 !important;
 }
+.chain-of-thought span.label {
+    display: none;
+}
+.chain-of-thought span.textspan {
+    padding-right: 0;
+}
 
 #life-canvas {
     position: fixed;
@@ -294,6 +376,9 @@ with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
         show_label=False,
         value=lambda: "Caption",
     )
+
+    input_image = gr.State(None)
+
     with gr.Row():
         with gr.Column():
 
@@ -312,6 +397,7 @@ with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
             submit.click(answer_question, [img, prompt], [output, thought])
             prompt.submit(answer_question, [img, prompt], [output, thought])
             img.change(answer_question, [img, prompt], [output, thought])
+            img.change(lambda img: img, [img], [input_image])
         elif mode == "Caption":
             with gr.Group():
                 with gr.Row():
@@ -342,10 +428,42 @@ with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
             gr.Markdown("Coming soon!")
 
         with gr.Column():
-            thought = gr.Markdown(elem_classes=["chain-of-thought"], line_breaks=True)
+            thought = gr.HighlightedText(
+                elem_classes=["chain-of-thought"],
+                label="Thinking tokens",
+                interactive=False,
+            )
             output = gr.Markdown(label="Response", elem_classes=["output-text"])
             ann = gr.Image(visible=False)
 
+    def on_select(img, evt: gr.SelectData):
+        if img is None or evt.value[1] is None:
+            return gr.update(visible=False, value=None)
+
+        w, h = img.size
+        if w > 768 or h > 768:
+            img = Resize(768)(img)
+            w, h = img.size
+
+        coords = json.loads(evt.value[1])
+        if len(coords) != 2:
+            raise ValueError("Only points supported right now.")
+        coords[0] = int(coords[0] * w)
+        coords[1] = int(coords[1] * h)
+
+        img_clone = img.copy()
+        draw = ImageDraw.Draw(img_clone)
+        draw.ellipse(
+            (coords[0] - 3, coords[1] - 3, coords[0] + 3, coords[1] + 3),
+            fill="red",
+            outline="red",
+        )
+
+        return gr.update(visible=True, value=img_clone)
+
+    thought.select(on_select, [input_image], [ann])
+    input_image.change(lambda: gr.update(visible=False), [], [ann])
+
     mode_radio.change(
         lambda: ("", "", gr.update(visible=False, value=None)),
        [],
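For reference, a minimal sketch of what the new convert_to_entities helper produces. The marker layout in the sample string is an assumption about how the model interleaves <|coord|> tokens inside grounded spans; the parser itself is the one added in this commit:

sample = "The <|start_ground|>red ball<|coord|><|coord|><|end_ground|> is visible."
coords = [0.41, 0.73]  # normalized x, y consumed by the two <|coord|> markers

result = convert_to_entities(sample, coords)
# result["text"]     == "The red ball is visible."
# result["entities"] == [{"entity": "[0.41, 0.73]", "start": 4, "end": 12}]

This {"text": ..., "entities": [...]} dictionary is the value format gr.HighlightedText accepts, which is why answer_question and detect now yield dicts instead of plain strings for the thought output.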
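The on_select handler relies on the select event of gr.HighlightedText delivering the clicked span as a (text, label) pair in evt.value, with the label carrying the JSON-encoded normalized coordinates. A standalone sketch of that wiring (the demo names here are hypothetical, not from app.py):

import gradio as gr

with gr.Blocks() as demo:
    ht = gr.HighlightedText(
        value={
            "text": "The red ball is visible.",
            "entities": [{"entity": "[0.41, 0.73]", "start": 4, "end": 12}],
        },
        interactive=False,
    )
    out = gr.Textbox()

    def show(evt: gr.SelectData):
        # evt.value[1] is the entity label, i.e. the JSON-encoded
        # normalized coordinates, mirroring app.py's on_select
        return f"span {evt.value[0]!r} -> coords {evt.value[1]}"

    ht.select(show, None, out)

demo.launch()

on_select then scales the normalized point to pixels on the (possibly resized) image: for [0.41, 0.73] on a 768x512 image, the dot is drawn at (int(0.41 * 768), int(0.73 * 512)) = (314, 373).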