File size: 14,249 Bytes
67b36a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
babd02b
 
67b36a4
 
 
 
00111a4
67b36a4
 
 
babd02b
 
67b36a4
 
 
 
 
 
 
 
 
 
 
 
d173683
 
 
 
 
 
 
 
 
 
 
 
 
67b36a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d173683
67b36a4
 
 
 
 
 
 
 
 
 
 
 
00111a4
67b36a4
 
00111a4
 
67b36a4
00111a4
67b36a4
00111a4
67b36a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00111a4
67b36a4
 
 
 
 
 
 
 
fd68401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
babd02b
 
 
 
 
 
 
00111a4
 
 
67b36a4
babd02b
 
 
 
 
 
 
 
 
 
 
 
 
 
67b36a4
00111a4
 
 
67b36a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d173683
67b36a4
 
 
 
 
fd68401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
babd02b
67b36a4
 
 
 
 
 
 
 
 
babd02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35fd22b
67b36a4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import os
import time
import glob
from typing import List

import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

MODEL_ID = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
DTYPE = torch.bfloat16

_model = None
_processor = None
_attn_impl = None


def _load_model_and_processor():
    global _model, _processor, _attn_impl
    if _model is not None and _processor is not None:
        return _model, _processor

    # Prefer flash-attn if available, otherwise fall back to eager.
    attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            # `torch_dtype` was deprecated in Transformers; use `dtype` instead.
            dtype=DTYPE,
            attn_implementation=attn_impl,
            device_map="auto",
        )
        _attn_impl = attn_impl
    except (ImportError, ValueError, RuntimeError):
        # Fallback for environments without flash-attn
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            # Use the new `dtype` kwarg for consistency with deprecations
            dtype=DTYPE,
            attn_implementation="eager",
            device_map="auto",
        )
        _attn_impl = "eager"

    processor = AutoProcessor.from_pretrained(MODEL_ID)

    _model = model
    _processor = processor
    return _model, _processor


# Optionally preload the model at app startup so first click is fast.
# - On ZeroGPU, this will instantiate on CPU (no GPU at startup), so the
#   first generate only needs to move tensors to CUDA.
# - You can disable by setting env `PRELOAD_MODEL=0`.
if os.environ.get("PRELOAD_MODEL", "1") not in ("0", "false", "False"):
    try:
        _load_model_and_processor()
        print(f"[preload] Loaded {MODEL_ID} (attn_impl={_attn_impl})", flush=True)
    except Exception as e:
        # Don't block app if preload fails; fallback to lazy load on first call
        print(f"[preload] Skipped due to: {type(e).__name__}: {e}", flush=True)


def _prepare_inputs(image, prompt):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    chat_text = _processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = _processor(
        text=[chat_text],
        # Pass the single image directly; template contains <image> placeholder
        images=[image] if image is not None else None,
        return_tensors="pt",
    )
    return inputs


def _decode(generated_ids, input_ids):
    # Trim the prompt part before decoding
    trimmed = generated_ids[:, input_ids.shape[1] :]
    out = _processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return out[0].strip() if out else ""


@spaces.GPU(duration=120)
def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
    if image is None:
        return "Please upload an image."
    prompt = (prompt or "").strip()
    if not prompt:
        return "Please enter a prompt."

    start = time.time()
    model, _ = _load_model_and_processor()

    try:
        # Ensure model resides on GPU during the call
        dev = next(model.parameters()).device
        if dev.type != "cuda":
            model.to("cuda")
            dev = torch.device("cuda")
    except StopIteration:
        dev = torch.device("cuda")

    try:
        inputs = _prepare_inputs(image, prompt)
        inputs = {k: v.to(dev) if hasattr(v, "to") else v for k, v in inputs.items()}

        gen_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "use_cache": True,
        }
        with torch.inference_mode():
            out_ids = model.generate(**inputs, **gen_kwargs)
        text = _decode(out_ids, inputs["input_ids"])
        took = time.time() - start
        return f"{text}\n\n[attn={_attn_impl}, time={took:.1f}s]"
    except Exception as e:
        return f"Inference failed: {type(e).__name__}: {e}"
    finally:
        # Release GPU memory cache for ZeroGPU
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass


def build_ui():
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Spark: Synergistic Policy And Reward Co-Evolving Framework

            <h3 align="center">
              πŸ“–<a href="https://arxiv.org/abs/2509.22624">Paper</a> 
            | πŸ€—<a href="https://huggingface.co/internlm/Spark-VL-7B">Models</a> 
            | πŸ€—<a href="https://huggingface.co/datasets/internlm/Spark-Data">Datasets</a>
            | πŸ€—<a href="https://huggingface.co/papers/2509.22624">Daily Paper</a>
            </h3>

            **🌈 Introduction:** We propose SPARK, <strong>a unified framework that integrates policy and reward into a single model for joint and synchronous training</strong>. SPARK can automatically derive reward and reflection data from verifiable reward, enabling <strong>self-learning and self-evolution</strong>.

            **πŸ€— Models:** We release the checkpoints at [internlm/Spark-VL-7B](https://huggingface.co/internlm/Spark-VL-7B).

            **πŸ€— Datasets:** Training data is available at [internlm/Spark-Data](https://huggingface.co/datasets/internlm/Spark-Data).

            **πŸ’» Training Code:** The training code and implementation details can be found at [InternLM/Spark](https://github.com/InternLM/Spark).

            ---

            πŸ“Έ **Upload an image and enter a prompt** or πŸ–ΌοΈ **choose the input from the example gallery** (image + prompt).
            """
        )

        # Build an image+prompt gallery from ./examples
        # Each example is an image file with an optional sidecar .txt containing the prompt.
        # If a .txt is present (same basename), we will display a caption and load the
        # prompt alongside the image when the thumbnail is selected.
        def _gather_examples() -> List[tuple]:
            pairs = []  # (image_path, prompt_text)
            imgs = []
            for ext in ("jpg", "jpeg", "png", "webp"):
                imgs.extend(glob.glob(os.path.join("examples", f"*.{ext}")))
            # Deduplicate while keeping order
            for img_path in list(dict.fromkeys(sorted(imgs))):
                stem, _ = os.path.splitext(img_path)
                prompt_path = stem + ".txt"
                prompt_text = None
                if os.path.exists(prompt_path):
                    try:
                        with open(prompt_path, "r", encoding="utf-8") as fh:
                            prompt_text = fh.read().strip()
                    except Exception:
                        prompt_text = None
                pairs.append((img_path, prompt_text))
            return pairs

        example_pairs = _gather_examples()

        # Load default image if exists
        default_path = os.path.join("examples", "example_0.png")
        default_image = Image.open(default_path) if os.path.exists(default_path) else None

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Image", value=default_image)

            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    value=(
                        "As seen in the diagram, three darts are thrown at nine fixed balloons. "
                        "If a balloon is hit it will burst and the dart continues in the same direction "
                        "it had beforehand. How many balloons will not be hit by a dart?"
                    ),
                    lines=4,
                )
                max_new_tokens = gr.Slider(512, 4096, value=1024, step=8, label="max_new_tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
                top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
                run = gr.Button("Generate")

        # Clear prompt when image is removed
        image.clear(fn=lambda: "", outputs=prompt)

        # Examples section: table-like layout with image and prompt columns
        gr.Markdown("## Examples")

        # Handler for clicking on example images
        def _on_example_click(img_path, prompt_text):
            try:
                img_val = Image.open(img_path)
            except Exception:
                img_val = None
            return img_val, prompt_text

        # Categorize examples by type
        math_examples = []
        reward_examples = []
        other_examples = []

        for img_path, prompt_text in example_pairs:
            basename = os.path.basename(img_path)
            if basename.startswith("example_0"):
                math_examples.append((img_path, prompt_text))
            elif basename.startswith("example_1"):
                reward_examples.append((img_path, prompt_text))
            else:
                other_examples.append((img_path, prompt_text))

        # Display math reasoning examples
        if math_examples:
            gr.Markdown("### πŸ“ Math Reasoning Examples")
            for idx, (img_path, prompt_text) in enumerate(math_examples):
                with gr.Row():
                    with gr.Column(scale=1):
                        ex_img = gr.Image(
                            value=img_path,
                            type="filepath",
                            label=f"Math Example {idx}",
                            interactive=False,
                            show_label=True,
                            height=200,
                        )
                        # Wire click event to load the example
                        ex_img.select(
                            fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
                            outputs=[image, prompt],
                        )
                    with gr.Column(scale=3):
                        ex_text = gr.Textbox(
                            value=prompt_text or "",
                            label="Prompt",
                            lines=8,
                            max_lines=8,
                            interactive=False,
                            show_label=True,
                        )

        # Display reward model examples
        if reward_examples:
            gr.Markdown("### 🎯 Reward Model Examples")
            for idx, (img_path, prompt_text) in enumerate(reward_examples):
                with gr.Row():
                    with gr.Column(scale=1):
                        ex_img = gr.Image(
                            value=img_path,
                            type="filepath",
                            label=f"Reward Example {idx}",
                            interactive=False,
                            show_label=True,
                            height=200,
                        )
                        # Wire click event to load the example
                        ex_img.select(
                            fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
                            outputs=[image, prompt],
                        )
                    with gr.Column(scale=3):
                        ex_text = gr.Textbox(
                            value=prompt_text or "",
                            label="Prompt",
                            lines=8,
                            max_lines=8,
                            interactive=False,
                            show_label=True,
                        )

        # Display other examples if any
        if other_examples:
            gr.Markdown("### πŸ“‹ Other Examples")
            for idx, (img_path, prompt_text) in enumerate(other_examples):
                with gr.Row():
                    with gr.Column(scale=1):
                        ex_img = gr.Image(
                            value=img_path,
                            type="filepath",
                            label=f"Example {idx}",
                            interactive=False,
                            show_label=True,
                            height=200,
                        )
                        # Wire click event to load the example
                        ex_img.select(
                            fn=lambda ip=img_path, pt=prompt_text: _on_example_click(ip, pt),
                            outputs=[image, prompt],
                        )
                    with gr.Column(scale=3):
                        ex_text = gr.Textbox(
                            value=prompt_text or "",
                            label="Prompt",
                            lines=8,
                            max_lines=8,
                            interactive=False,
                            show_label=True,
                        )

        output = gr.Textbox(label="Model Output", lines=8)

        run.click(
            fn=generate,
            inputs=[image, prompt, max_new_tokens, temperature, top_p, top_k],
            outputs=output,
            show_progress=True,
        )

        # Citation section at the bottom
        gr.Markdown(
            """
            ---
            If you find this project useful, please kindly cite:

            ```bibtex
            @article{liu2025spark,
              title={SPARK: Synergistic Policy And Reward Co-Evolving Framework},
              author={Liu, Ziyu and Zang, Yuhang and Ding, Shengyuan and Cao, Yuhang and Dong, Xiaoyi and Duan, Haodong and Lin, Dahua and Wang, Jiaqi},
              journal={arXiv preprint arXiv:2509.22624},
              year={2025}
            }
            ```
            """
        )

        demo.queue(max_size=10).launch()
    return demo


if __name__ == "__main__":
    build_ui()