import os
import time
import glob
from typing import List

import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

MODEL_ID = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
DTYPE = torch.bfloat16

_model = None
_processor = None
_attn_impl = None


def _load_model_and_processor():
    global _model, _processor, _attn_impl
    if _model is not None and _processor is not None:
        return _model, _processor
    # Prefer flash-attn if available, otherwise fall back to eager.
    attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            attn_implementation=attn_impl,
            device_map="auto",
        )
        _attn_impl = attn_impl
    except Exception:
        # Fallback for environments without flash-attn
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            attn_implementation="eager",
            device_map="auto",
        )
        _attn_impl = "eager"
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    _model = model
    _processor = processor
    return _model, _processor


def _prepare_inputs(image, prompt):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    chat_text = _processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = _processor(
        text=[chat_text],
        # Pass the single image directly; the chat template already inserts the image placeholder tokens.
        images=[image] if image is not None else None,
        return_tensors="pt",
    )
    return inputs


def _decode(generated_ids, input_ids):
    # Trim the prompt tokens before decoding so only the new completion is returned.
    trimmed = generated_ids[:, input_ids.shape[1]:]
    out = _processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return out[0] if out else ""


@spaces.GPU
def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
    # The @spaces.GPU decorator requests a ZeroGPU device for the duration of this call.
    if image is None:
        return "Please upload an image."
    prompt = (prompt or "").strip()
    if not prompt:
        return "Please enter a prompt."
    start = time.time()
    model, _ = _load_model_and_processor()
    try:
        # Ensure the model resides on the GPU during the call.
        p = next(model.parameters())
        if p.device.type != "cuda":
            model.to("cuda")
    except StopIteration:
        pass
    try:
        inputs = _prepare_inputs(image, prompt)
        dev = next(model.parameters()).device
        inputs = {k: v.to(dev) if hasattr(v, "to") else v for k, v in inputs.items()}
        # Only pass sampling parameters when sampling; temperature=0 falls back to greedy decoding.
        do_sample = float(temperature) > 0
        gen_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "do_sample": do_sample,
            "use_cache": True,
        }
        if do_sample:
            gen_kwargs.update(
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
            )
        with torch.inference_mode():
            out_ids = model.generate(**inputs, **gen_kwargs)
        text = _decode(out_ids, inputs["input_ids"])
        took = time.time() - start
        return f"{text}\n\n[attn={_attn_impl}, time={took:.1f}s]"
    except Exception as e:
        return f"Inference failed: {type(e).__name__}: {e}"
    finally:
        # Release the GPU quickly on ZeroGPU by moving weights off CUDA.
        try:
            if hasattr(model, "to"):
                model.to("cpu")
            torch.cuda.empty_cache()
        except Exception:
            pass


def build_ui():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Spark-VL ZeroGPU Demo\n"
            "Upload an image or choose from the example gallery, then enter a prompt."
        )

        # Build an image gallery from ./examples
        def _gather_examples() -> List[str]:
            exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
            imgs: List[str] = []
            for ptn in exts:
                imgs.extend(sorted(glob.glob(os.path.join("examples", ptn))))
            # Deduplicate while keeping order
            seen = set()
            uniq = []
            for p in imgs:
                if p not in seen:
                    uniq.append(p)
                    seen.add(p)
            return uniq

        example_images = _gather_examples()
        default_candidates = [
            os.path.join("examples", "example_0.png"),
        ]
        default_image_path = next((p for p in default_candidates if os.path.exists(p)), None)
        default_image = Image.open(default_image_path) if default_image_path else None

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Image", value=default_image)
                gallery = gr.Gallery(
                    value=example_images,
                    label="Example Gallery",
                    show_label=True,
                    columns=4,
                    height=240,
                    allow_preview=True,
                )

                # When a thumbnail is clicked, load it into the image input.
                # The gr.SelectData annotation is required for Gradio to pass the event data.
                def _on_gallery_select(evt: gr.SelectData):
                    try:
                        idx = int(evt.index)
                    except Exception:
                        return None
                    if idx < 0 or idx >= len(example_images):
                        return None
                    # Return a PIL image to match the image component's type="pil".
                    try:
                        return Image.open(example_images[idx])
                    except Exception:
                        return example_images[idx]

                gallery.select(fn=_on_gallery_select, inputs=None, outputs=image)
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    value=(
                        "As seen in the diagram, three darts are thrown at nine fixed balloons. "
                        "If a balloon is hit it will burst and the dart continues in the same direction "
                        "it had beforehand. How many balloons will not be hit by a dart?"
                    ),
                    lines=4,
                )
                max_new_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
                top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
                run = gr.Button("Generate")

        output = gr.Textbox(label="Model Output", lines=8)
        run.click(
            fn=generate,
            inputs=[image, prompt, max_new_tokens, temperature, top_p, top_k],
            outputs=output,
            show_progress=True,
        )
        # Gradio 4 renamed queue(concurrency_count=...) to default_concurrency_limit.
        demo.queue(default_concurrency_limit=1, max_size=10).launch()
    return demo


if __name__ == "__main__":
    build_ui()