yuhangzang committed
Commit 67b36a4 · 1 Parent(s): 12e3e78
Files changed (5)
  1. .gitattributes +4 -0
  2. README.md +23 -0
  3. app.py +216 -0
  4. examples/example_0.png +3 -0
  5. requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -12,3 +12,26 @@ short_description: ' A unified framework for reasoning and reward modeling'
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ ## Usage Notes (ZeroGPU)
+
+ - Choose `Gradio` as the Space type and `ZeroGPU` as the hardware (requires a PRO subscription or an enterprise organization).
+ - This repository contains a minimal working Spark-VL demo: upload an image, enter a text prompt, and the model's generated output is returned.
+ - The key code is in `app.py`:
+   - The inference function is decorated with `spaces.GPU`, so a GPU is requested when it is called and released when it finishes.
+   - `internlm/Spark-VL-7B` is loaded on demand at the first call; `flash_attention_2` is tried first, with a fallback to `eager` if it fails.
+   - After inference the model is moved back to the CPU so ZeroGPU memory is released quickly (a condensed sketch of this pattern follows the diff).
+
+ ### Running locally / on a Space
+
+ 1) After pushing to a Hugging Face Space, select `ZeroGPU` as the hardware in the Space settings.
+
+ 2) The entry point is `app.py`; the UI provides an image input, a prompt box, and sampling parameters (max_new_tokens/temperature/top_p/top_k).
+
+ 3) Optional environment variables:
+    - `SPARK_MODEL_ID`: defaults to `internlm/Spark-VL-7B`.
+    - `ATTN_IMPL`: defaults to `flash_attention_2`; set it to `eager` if flash-attn is unavailable.
+
+ ### Dependencies
+
+ See `requirements.txt` (Gradio 5.x, Transformers 4.45+, qwen-vl-utils, etc.). The ZeroGPU base image already ships a suitable PyTorch build.
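
The ZeroGPU pattern the README describes (lazy load on first call, flash-attn with an eager fallback, CPU offload after generation) condenses to roughly the sketch below. It is illustrative only: `_load` and `infer` are placeholder names, and the complete version is the `app.py` added in this commit.

```python
import os

import spaces
import torch
from transformers import Qwen2_5_VLForConditionalGeneration

_model = None  # loaded lazily on the first request


def _load():
    """Load Spark-VL once; try flash-attn first, then fall back to eager."""
    global _model
    if _model is None:
        model_id = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
        attn = os.environ.get("ATTN_IMPL", "flash_attention_2")
        try:
            _model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_id, torch_dtype=torch.bfloat16, attn_implementation=attn, device_map="auto"
            )
        except Exception:  # e.g. flash-attn is not installed
            _model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_id, torch_dtype=torch.bfloat16, attn_implementation="eager", device_map="auto"
            )
    return _model


@spaces.GPU(duration=120)  # a GPU is attached only while this function runs
def infer(inputs: dict):
    model = _load()
    model.to("cuda")  # move the weights onto the ZeroGPU device
    try:
        batch = {k: v.to("cuda") if hasattr(v, "to") else v for k, v in inputs.items()}
        with torch.inference_mode():
            return model.generate(**batch, max_new_tokens=128)
    finally:
        model.to("cpu")  # hand the GPU back as soon as generation finishes
        torch.cuda.empty_cache()
```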
app.py ADDED
@@ -0,0 +1,216 @@
+ import os
+ import time
+ import glob
+ from typing import List
+
+ import spaces
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+
+ MODEL_ID = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
+ DTYPE = torch.bfloat16
+
+ _model = None
+ _processor = None
+ _attn_impl = None
+
+
+ def _load_model_and_processor():
+     global _model, _processor, _attn_impl
+     if _model is not None and _processor is not None:
+         return _model, _processor
+
+     # Prefer flash-attn if available, otherwise fall back to eager.
+     attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")
+     try:
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             MODEL_ID,
+             torch_dtype=DTYPE,
+             attn_implementation=attn_impl,
+             device_map="auto",
+         )
+         _attn_impl = attn_impl
+     except Exception:
+         # Fallback for environments without flash-attn.
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             MODEL_ID,
+             torch_dtype=DTYPE,
+             attn_implementation="eager",
+             device_map="auto",
+         )
+         _attn_impl = "eager"
+
+     processor = AutoProcessor.from_pretrained(MODEL_ID)
+
+     _model = model
+     _processor = processor
+     return _model, _processor
+
+
+ def _prepare_inputs(image, prompt):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+     chat_text = _processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     inputs = _processor(
+         text=[chat_text],
+         # Pass the single image directly; the chat template already inserts the image placeholder.
+         images=[image] if image is not None else None,
+         return_tensors="pt",
+     )
+     return inputs
+
+
+ def _decode(generated_ids, input_ids):
+     # Trim the prompt tokens so only the newly generated text is decoded.
+     trimmed = generated_ids[:, input_ids.shape[1]:]
+     out = _processor.batch_decode(
+         trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+     return out[0] if out else ""
+
+
+ @spaces.GPU(duration=120)
+ def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
+     if image is None:
+         return "Please upload an image."
+     prompt = (prompt or "").strip()
+     if not prompt:
+         return "Please enter a prompt."
+
+     start = time.time()
+     model, _ = _load_model_and_processor()
+     try:
+         # Ensure the model resides on the GPU during the call.
+         p = next(model.parameters())
+         if p.device.type != "cuda":
+             model.to("cuda")
+     except StopIteration:
+         pass
+
+     try:
+         inputs = _prepare_inputs(image, prompt)
+         dev = next(model.parameters()).device
+         inputs = {k: v.to(dev) if hasattr(v, "to") else v for k, v in inputs.items()}
+
+         gen_kwargs = {
+             "max_new_tokens": int(max_new_tokens),
+             "do_sample": True,
+             "temperature": float(temperature),
+             "top_p": float(top_p),
+             "top_k": int(top_k),
+             "use_cache": True,
+         }
+         with torch.inference_mode():
+             out_ids = model.generate(**inputs, **gen_kwargs)
+         text = _decode(out_ids, inputs["input_ids"])
+         took = time.time() - start
+         return f"{text}\n\n[attn={_attn_impl}, time={took:.1f}s]"
+     except Exception as e:
+         return f"Inference failed: {type(e).__name__}: {e}"
+     finally:
+         # Release the GPU quickly on ZeroGPU by moving the weights off CUDA.
+         try:
+             if hasattr(model, "to"):
+                 model.to("cpu")
+             torch.cuda.empty_cache()
+         except Exception:
+             pass
+
+
+ def build_ui():
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             "# Spark-VL ZeroGPU Demo\n"
+             "Upload an image or choose one from the example gallery, then enter a prompt."
+         )
+
+         # Build an image gallery from ./examples.
+         def _gather_examples() -> List[str]:
+             exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
+             imgs: List[str] = []
+             for ptn in exts:
+                 imgs.extend(sorted(glob.glob(os.path.join("examples", ptn))))
+             # Deduplicate while keeping order.
+             seen = set()
+             uniq = []
+             for p in imgs:
+                 if p not in seen:
+                     uniq.append(p)
+                     seen.add(p)
+             return uniq
+
+         example_images = _gather_examples()
+
+         default_candidates = [
+             os.path.join("examples", "example_0.png"),
+         ]
+         default_image_path = next((p for p in default_candidates if os.path.exists(p)), None)
+         default_image = Image.open(default_image_path) if default_image_path else None
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 image = gr.Image(type="pil", label="Image", value=default_image)
+                 gallery = gr.Gallery(
+                     value=example_images,
+                     label="Example Gallery",
+                     show_label=True,
+                     columns=4,
+                     height=240,
+                     allow_preview=True,
+                 )
+
+                 # When a thumbnail is clicked, load it into the image input.
+                 def _on_gallery_select(evt: gr.SelectData):
+                     try:
+                         idx = int(evt.index)
+                     except Exception:
+                         return None
+                     if idx < 0 or idx >= len(example_images):
+                         return None
+                     # Return a PIL image to match the gr.Image(type="pil") input.
+                     try:
+                         return Image.open(example_images[idx])
+                     except Exception:
+                         return example_images[idx]
+
+                 gallery.select(fn=_on_gallery_select, inputs=None, outputs=image)
+
+             with gr.Column(scale=1):
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     value=(
+                         "As seen in the diagram, three darts are thrown at nine fixed balloons. "
+                         "If a balloon is hit it will burst and the dart continues in the same direction "
+                         "it had beforehand. How many balloons will not be hit by a dart?"
+                     ),
+                     lines=4,
+                 )
+                 max_new_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens")
+                 temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
+                 top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
+                 top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
+                 run = gr.Button("Generate")
+
+         output = gr.Textbox(label="Model Output", lines=8)
+
+         run.click(
+             fn=generate,
+             inputs=[image, prompt, max_new_tokens, temperature, top_p, top_k],
+             outputs=output,
+             show_progress="full",
+         )
+
+     # `concurrency_count` was removed in recent Gradio; `default_concurrency_limit` replaces it.
+     demo.queue(default_concurrency_limit=1, max_size=10).launch()
+     return demo
+
+
+ if __name__ == "__main__":
+     build_ui()
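
For quick local testing without the Gradio UI, the helpers above can be driven directly. This is a hypothetical snippet (the file name `local_test.py`, the prompt text, and the use of `examples/example_0.png` are assumptions, not part of the commit); setting `ATTN_IMPL=eager` sidesteps the flash-attn requirement on machines where it is not installed.

```python
# local_test.py -- hypothetical smoke test for the pipeline defined in app.py
import os

os.environ.setdefault("ATTN_IMPL", "eager")  # avoid requiring flash-attn locally

from PIL import Image

from app import _decode, _load_model_and_processor, _prepare_inputs

model, processor = _load_model_and_processor()
image = Image.open("examples/example_0.png")

inputs = _prepare_inputs(image, "Describe this image.")
inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}

out_ids = model.generate(**inputs, max_new_tokens=128)
print(_decode(out_ids, inputs["input_ids"]))
```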
examples/example_0.png ADDED
Git LFS Details

  • SHA256: df52c4fd4574d96401d0231878e83803bdb64b8d82ba81854a028a4759b7fe55
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers>=4.45.0
+ accelerate>=0.33.0
+ qwen-vl-utils>=0.0.8
+ gradio>=5.49.1
+ spaces>=0.24.0
+ pillow