yuhangzang commited on
Commit
babd02b
·
1 Parent(s): d173683

Gallery: pair image+prompt examples; load prompt on selection; add bottom citation bib section

Browse files
Files changed (2) hide show
  1. app.py +76 -18
  2. requirements.txt +3 -0
app.py CHANGED
@@ -27,7 +27,8 @@ def _load_model_and_processor():
27
  try:
28
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
29
  MODEL_ID,
30
- torch_dtype=DTYPE,
 
31
  attn_implementation=attn_impl,
32
  device_map="auto",
33
  )
@@ -36,7 +37,8 @@ def _load_model_and_processor():
36
  # Fallback for environments without flash-attn
37
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
38
  MODEL_ID,
39
- torch_dtype=DTYPE,
 
40
  attn_implementation="eager",
41
  device_map="auto",
42
  )
@@ -142,17 +144,32 @@ def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
142
 
143
  def build_ui():
144
  with gr.Blocks() as demo:
145
- gr.Markdown("# Spark-VL ZeroGPU Demo\nUpload an image or choose from the example gallery, then enter a prompt.")
146
-
147
- # Build an image gallery from ./examples
148
- def _gather_examples() -> List[str]:
 
 
 
 
149
  imgs = []
150
  for ext in ("jpg", "jpeg", "png", "webp"):
151
  imgs.extend(glob.glob(os.path.join("examples", f"*.{ext}")))
152
  # Deduplicate while keeping order
153
- return list(dict.fromkeys(sorted(imgs)))
154
-
155
- example_images = _gather_examples()
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # Load default image if exists
158
  default_path = os.path.join("examples", "example_0.png")
@@ -161,26 +178,42 @@ def build_ui():
161
  with gr.Row():
162
  with gr.Column(scale=1):
163
  image = gr.Image(type="pil", label="Image", value=default_image)
 
 
 
 
 
 
 
 
 
 
 
 
164
  gallery = gr.Gallery(
165
- value=example_images,
166
- label="Example Gallery",
167
  show_label=True,
168
  columns=4,
169
- height=240,
170
  allow_preview=True,
171
  )
172
 
173
  # When a thumbnail is clicked, load it into the image input
174
- def _on_gallery_select(evt: gr.SelectData):
 
175
  idx = evt.index
176
- if 0 <= idx < len(example_images):
 
177
  try:
178
- return Image.open(example_images[idx])
179
  except Exception:
180
- return None
181
- return None
 
 
182
 
183
- gallery.select(fn=_on_gallery_select, outputs=image)
184
 
185
  with gr.Column(scale=1):
186
  prompt = gr.Textbox(
@@ -198,6 +231,14 @@ def build_ui():
198
  top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
199
  run = gr.Button("Generate")
200
 
 
 
 
 
 
 
 
 
201
  output = gr.Textbox(label="Model Output", lines=8)
202
 
203
  run.click(
@@ -207,6 +248,23 @@ def build_ui():
207
  show_progress=True,
208
  )
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  demo.queue(max_size=10).launch()
211
  return demo
212
 
 
27
  try:
28
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
29
  MODEL_ID,
30
+ # `torch_dtype` was deprecated in Transformers; use `dtype` instead.
31
+ dtype=DTYPE,
32
  attn_implementation=attn_impl,
33
  device_map="auto",
34
  )
 
37
  # Fallback for environments without flash-attn
38
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
39
  MODEL_ID,
40
+ # Use the new `dtype` kwarg for consistency with deprecations
41
+ dtype=DTYPE,
42
  attn_implementation="eager",
43
  device_map="auto",
44
  )
 
144
 
145
  def build_ui():
146
  with gr.Blocks() as demo:
147
+ gr.Markdown("# Spark-VL ZeroGPU Demo\nUpload an image or choose from the example gallery (image + prompt), then enter a prompt.")
148
+
149
+ # Build an image+prompt gallery from ./examples
150
+ # Each example is an image file with an optional sidecar .txt containing the prompt.
151
+ # If a .txt is present (same basename), we will display a caption and load the
152
+ # prompt alongside the image when the thumbnail is selected.
153
+ def _gather_examples() -> List[tuple]:
154
+ pairs = [] # (image_path, prompt_text)
155
  imgs = []
156
  for ext in ("jpg", "jpeg", "png", "webp"):
157
  imgs.extend(glob.glob(os.path.join("examples", f"*.{ext}")))
158
  # Deduplicate while keeping order
159
+ for img_path in list(dict.fromkeys(sorted(imgs))):
160
+ stem, _ = os.path.splitext(img_path)
161
+ prompt_path = stem + ".txt"
162
+ prompt_text = None
163
+ if os.path.exists(prompt_path):
164
+ try:
165
+ with open(prompt_path, "r", encoding="utf-8") as fh:
166
+ prompt_text = fh.read().strip()
167
+ except Exception:
168
+ prompt_text = None
169
+ pairs.append((img_path, prompt_text))
170
+ return pairs
171
+
172
+ example_pairs = _gather_examples()
173
 
174
  # Load default image if exists
175
  default_path = os.path.join("examples", "example_0.png")
 
178
  with gr.Row():
179
  with gr.Column(scale=1):
180
  image = gr.Image(type="pil", label="Image", value=default_image)
181
+ # Prepare gallery items as (image, caption) so users can see
182
+ # that a prompt is associated with each example.
183
+ def _gallery_items():
184
+ items = []
185
+ for img_path, prompt_text in example_pairs:
186
+ caption = (prompt_text or "").strip()
187
+ # Keep captions compact to avoid tall tiles
188
+ if len(caption) > 120:
189
+ caption = caption[:117] + "..."
190
+ items.append((img_path, caption))
191
+ return items
192
+
193
  gallery = gr.Gallery(
194
+ value=_gallery_items(),
195
+ label="Examples (Image + Prompt)",
196
  show_label=True,
197
  columns=4,
198
+ height=260,
199
  allow_preview=True,
200
  )
201
 
202
  # When a thumbnail is clicked, load it into the image input
203
+ def _on_gallery_select(evt: gr.SelectData, cur_prompt: str = ""):
204
+ # Load both the example image and its paired prompt
205
  idx = evt.index
206
+ if 0 <= idx < len(example_pairs):
207
+ img_path, prompt_text = example_pairs[idx]
208
  try:
209
+ img_val = Image.open(img_path)
210
  except Exception:
211
+ img_val = None
212
+ # If no prompt sidecar, preserve the user's current prompt
213
+ return img_val, (prompt_text if prompt_text is not None else cur_prompt)
214
+ return None, cur_prompt
215
 
216
+ # Defer wiring the select handler until after the prompt component is created
217
 
218
  with gr.Column(scale=1):
219
  prompt = gr.Textbox(
 
231
  top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
232
  run = gr.Button("Generate")
233
 
234
+ # Now that both components exist, wire the gallery->(image,prompt) binding
235
+ try:
236
+ gallery.select(fn=_on_gallery_select, inputs=[prompt], outputs=[image, prompt])
237
+ except Exception:
238
+ # If the event cannot be bound (e.g., running in a limited environment),
239
+ # just skip wiring without breaking the app.
240
+ pass
241
+
242
  output = gr.Textbox(label="Model Output", lines=8)
243
 
244
  run.click(
 
248
  show_progress=True,
249
  )
250
 
251
+ # Citation section at the bottom
252
+ gr.Markdown(
253
+ """
254
+ ---
255
+ If you find this project useful, please kindly cite:
256
+
257
+ ```bibtex
258
+ @article{liu2025spark,
259
+ title={SPARK: Synergistic Policy And Reward Co-Evolving Framework},
260
+ author={Liu, Ziyu and Zang, Yuhang and Ding, Shengyuan and Cao, Yuhang and Dong, Xiaoyi and Duan, Haodong and Lin, Dahua and Wang, Jiaqi},
261
+ journal={arXiv preprint arXiv:2509.22624},
262
+ year={2025}
263
+ }
264
+ ```
265
+ """
266
+ )
267
+
268
  demo.queue(max_size=10).launch()
269
  return demo
270
 
requirements.txt CHANGED
@@ -5,3 +5,6 @@ gradio>=5.49.1
5
  spaces>=0.24.0
6
  pillow
7
  torchvision
 
 
 
 
5
  spaces>=0.24.0
6
  pillow
7
  torchvision
8
+
+ # Optional: FlashAttention v2 for faster attention on compatible Linux CUDA GPUs.
9
+ # This installs only on 64-bit Linux. It will be skipped on macOS/Windows/ARM.
10
+ flash-attn; platform_system == "Linux" and platform_machine == "x86_64"