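"""Gradio demo for Spark-VL (default: internlm/Spark-VL-7B) on Hugging Face ZeroGPU.

The model and processor are loaded lazily and cached at module level; each
request runs inside a short-lived ZeroGPU allocation (see the @spaces.GPU call).
"""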
import os
import time
import glob
from typing import List
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
MODEL_ID = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
DTYPE = torch.bfloat16
_model = None
_processor = None
_attn_impl = None
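
# Load the model/processor once and cache them in the module globals above so
# repeated requests skip the slow reload.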
def _load_model_and_processor():
global _model, _processor, _attn_impl
if _model is not None and _processor is not None:
return _model, _processor
# Prefer flash-attn if available, otherwise fall back to eager.
attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")
try:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=DTYPE,
attn_implementation=attn_impl,
device_map="auto",
)
_attn_impl = attn_impl
except Exception:
# Fallback for environments without flash-attn
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=DTYPE,
attn_implementation="eager",
device_map="auto",
)
_attn_impl = "eager"
processor = AutoProcessor.from_pretrained(MODEL_ID)
_model = model
_processor = processor
return _model, _processor
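
# Build processor inputs for a single-image chat turn: render the chat template
# (which inserts the image placeholder tokens), then tokenize the text and
# preprocess the image in one processor call.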
def _prepare_inputs(image, prompt):
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
chat_text = _processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = _processor(
text=[chat_text],
        # Pass the single image directly; the chat template has already inserted
        # the image placeholder tokens into the text.
images=[image] if image is not None else None,
return_tensors="pt",
)
return inputs
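
# generate() returns prompt + completion token ids; strip the prompt length
# before decoding so only newly generated text is returned.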
def _decode(generated_ids, input_ids):
# Trim the prompt part before decoding
trimmed = generated_ids[:, input_ids.shape[1] :]
out = _processor.batch_decode(
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return out[0] if out else ""
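
# ZeroGPU: the decorator requests a GPU only for the duration of this call
# (up to 120 s); the finally block below moves weights back to CPU afterwards.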
@spaces.GPU(duration=120)
def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
if image is None:
return "Please upload an image."
prompt = (prompt or "").strip()
if not prompt:
return "Please enter a prompt."
start = time.time()
model, _ = _load_model_and_processor()
try:
# Ensure model resides on GPU during the call
p = next(model.parameters())
if p.device.type != "cuda":
model.to("cuda")
except StopIteration:
pass
try:
inputs = _prepare_inputs(image, prompt)
dev = next(model.parameters()).device
inputs = {k: v.to(dev) if hasattr(v, "to") else v for k, v in inputs.items()}
        # temperature == 0 is invalid for sampling, so fall back to greedy
        # decoding in that case instead of raising inside generate().
        do_sample = float(temperature) > 0
        gen_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "do_sample": do_sample,
            "use_cache": True,
        }
        if do_sample:
            gen_kwargs["temperature"] = float(temperature)
            gen_kwargs["top_p"] = float(top_p)
            gen_kwargs["top_k"] = int(top_k)
with torch.inference_mode():
out_ids = model.generate(**inputs, **gen_kwargs)
text = _decode(out_ids, inputs["input_ids"])
took = time.time() - start
return f"{text}\n\n[attn={_attn_impl}, time={took:.1f}s]"
except Exception as e:
return f"Inference failed: {type(e).__name__}: {e}"
finally:
# Release GPU quickly on ZeroGPU by moving weights off CUDA.
try:
if hasattr(model, "to"):
model.to("cpu")
torch.cuda.empty_cache()
except Exception:
pass
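
# UI: image input and example gallery on the left; prompt, sampling controls,
# and the Generate button on the right, writing into the output textbox.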
def build_ui():
with gr.Blocks() as demo:
gr.Markdown("# Spark-VL ZeroGPU Demo\nUpload an image or choose from the example gallery, then enter a prompt.")
# Build an image gallery from ./examples
def _gather_examples() -> List[str]:
exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
imgs: List[str] = []
for ptn in exts:
imgs.extend(sorted(glob.glob(os.path.join("examples", ptn))))
# Deduplicate while keeping order
seen = set()
uniq = []
for p in imgs:
if p not in seen:
uniq.append(p)
seen.add(p)
return uniq
example_images = _gather_examples()
default_candidates = [
os.path.join("examples", "example_0.png"),
]
default_image_path = next((p for p in default_candidates if os.path.exists(p)), None)
default_image = Image.open(default_image_path) if default_image_path else None
with gr.Row():
with gr.Column(scale=1):
image = gr.Image(type="pil", label="Image", value=default_image)
gallery = gr.Gallery(
value=example_images,
label="Example Gallery",
show_label=True,
columns=4,
height=240,
allow_preview=True,
)
# When a thumbnail is clicked, load it into the image input
                def _on_gallery_select(evt: gr.SelectData):
try:
idx = int(evt.index)
except Exception:
return None
if idx is None or idx < 0 or idx >= len(example_images):
return None
                    # Return a PIL image, matching the gr.Image(type="pil") input.
try:
return Image.open(example_images[idx])
except Exception:
return example_images[idx]
gallery.select(fn=_on_gallery_select, inputs=None, outputs=image)
with gr.Column(scale=1):
prompt = gr.Textbox(
label="Prompt",
value=(
"As seen in the diagram, three darts are thrown at nine fixed balloons. "
"If a balloon is hit it will burst and the dart continues in the same direction "
"it had beforehand. How many balloons will not be hit by a dart?"
),
lines=4,
)
max_new_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens")
temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
run = gr.Button("Generate")
output = gr.Textbox(label="Model Output", lines=8)
run.click(
fn=generate,
inputs=[image, prompt, max_new_tokens, temperature, top_p, top_k],
outputs=output,
show_progress=True,
)
    # Gradio 4 replaced queue(concurrency_count=...) with default_concurrency_limit.
    demo.queue(default_concurrency_limit=1, max_size=10).launch()
return demo
if __name__ == "__main__":
build_ui()
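
# Local usage sketch (assumption, not part of the Space config): the model ID and
# attention implementation can be overridden via environment variables, e.g.
#
#   SPARK_MODEL_ID=internlm/Spark-VL-7B ATTN_IMPL=eager python app.py
#
# Likely runtime dependencies, judging from the imports above: torch, transformers,
# gradio, spaces, Pillow, and optionally flash-attn.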