"""
Gradio application for MedGemma inference with ZeroGPU.
This script defines a minimal Gradio interface around Google's
``medgemma‑27b‑it`` multi‑modal model. It is designed to run on
Hugging Face Spaces using the **ZeroGPU** hardware option. ZeroGPU
allocates an NVIDIA H200 GPU slice for the duration of each call and
releases it afterwards. The interface accepts a textual **prompt**
(English only), an optional image upload and an optional **system
prompt** to steer the model. All responses are returned in English and
include a short disclaimer reminding users to consult a medical
professional.
If you set an ``API_KEY`` secret in your Space, callers must supply the
same value in the hidden API key field. Otherwise the endpoint will be
publicly accessible. See the README for details.
Note: ZeroGPU Spaces currently only work with the **Gradio** SDK and
support specific versions of PyTorch and Python【916380489845432†L110-L118】.
Running this script outside of a Space will work on CPU or dedicated
GPU hardware, but ZeroGPU GPU allocation only takes effect when the
Space hardware is set to *ZeroGPU (Dynamic resources)*.
"""
import os
from typing import Optional
import gradio as gr
from PIL import Image
import torch
from transformers import (
AutoProcessor,
AutoModelForImageTextToText,
GenerationConfig,
pipeline,
)
import spaces # for the @spaces.GPU decorator
# ----------------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
raise RuntimeError(
"HF_TOKEN environment variable must be set as a Secret in the Space."
)
# Optional API key: when set, clients must provide the same value in the
# hidden ``api_key`` field of the Gradio interface. If not set, no
# authentication is enforced.
API_KEY = os.getenv("API_KEY")
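# Example client call -- a minimal sketch, not executed by this app. It assumes
# the Space is deployed and that the ``gradio_client`` package is installed;
# the Space id below is a placeholder, and the API key argument only matters
# when the ``API_KEY`` secret is configured.
#
#     from gradio_client import Client
#
#     client = Client("your-username/your-medgemma-space")
#     answer = client.predict(
#         "Is there evidence of pneumonia on this chest film?",  # prompt
#         None,                                                  # optional image
#         "",                                                    # optional system prompt
#         "<your API key>",                                      # api_key (if required)
#         api_name="/predict",
#     )
#     print(answer)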
MODEL_ID = "google/medgemma-27b-it"
# Load the processor outside of the GPU context – this is lightweight
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
eos_id = processor.tokenizer.eos_token_id
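# Fall back to the EOS token for padding when the tokenizer defines no pad token.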
pad_id = processor.tokenizer.pad_token_id or eos_id
# Banned phrases to reduce chatty or irrelevant responses
ban_list = [
"Disclaimer",
"disclaimer",
"As an AI Chatbot",
"as an AI Chatbot",
"I cannot give medical advice",
"I cannot provide medical advice",
"I cannot give medical advise",
"user",
"response",
"display",
"response>",
"```",
"label",
"tool_code",
]
bad_words_ids = [processor.tokenizer(b, add_special_tokens=False).input_ids for b in ban_list]
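# ``bad_words_ids`` must be a list of token-id lists, one inner list per banned
# phrase (e.g. roughly [[1234], [567, 890], ...] -- the ids here are purely
# illustrative); generation is then prevented from emitting those exact token
# sequences.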
# Limit the number of generated tokens to shorten inference time. A smaller
# ``max_new_tokens`` helps ensure the call completes within the 180-second
# runtime; see the ZeroGPU documentation for runtime guidance.
gen_cfg = GenerationConfig(
max_new_tokens=60,
do_sample=False,
repetition_penalty=1.12,
no_repeat_ngram_size=6,
length_penalty=1.0,
temperature=0.0,
eos_token_id=eos_id,
pad_token_id=pad_id,
bad_words_ids=bad_words_ids,
)
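# Note: with ``do_sample=False`` decoding is greedy, so ``temperature=0.0`` has
# no effect on which tokens are chosen; it is kept only to signal the intent of
# fully deterministic output.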
# We'll load the model lazily inside run_model so that GPU allocation occurs
# within the ZeroGPU context. The model and pipeline are cached on first use
# so subsequent calls are faster; a simple attribute on the function serves as
# a persistent cache.

# Increase the duration to 180 seconds and enable queueing. The ZeroGPU
# documentation notes that the default runtime is 60 seconds and that a longer
# duration can be requested via the ``duration`` parameter. Enabling the queue
# prevents immediate failure when GPUs are busy. These adjustments help
# mitigate intermittent ``GPU task aborted`` errors.
@spaces.GPU(duration=180, enable_queue=True)
def run_model(prompt: str, image: Optional[Image.Image], system_prompt: Optional[str]) -> str:
"""Execute the MedGemma model.
This function will be run inside the ZeroGPU allocation context. It
lazily loads the model and pipeline on first invocation and reuses
them for subsequent calls. Inputs are combined with an optional
system prompt to produce the full prompt. The model's output is
returned as a plain English string.
Args:
prompt: The user's question (English only).
image: An optional PIL Image. If provided, the model will use
both text and image modalities; otherwise text-only.
system_prompt: An optional system prompt to steer the model. If
None or empty, a default instruction is used.
Returns:
The raw English output from the model (without disclaimer).
"""
    # Lazy-load the model and pipeline on first use
    if not hasattr(run_model, "model"):
        # Determine the appropriate dtype and device map: ``device_map="auto"``
        # splits the model across GPU/CPU if necessary, and bfloat16 is used
        # when CUDA is available to save memory on the H200.
model_kwargs: dict = {
"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
"token": HF_TOKEN,
}
if torch.cuda.is_available():
model_kwargs["device_map"] = "auto"
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
# Apply our generation configuration directly to the model. The
# transformers pipeline does not accept a ``generation_config``
# keyword argument, so we assign it here. See the Spaces
# documentation for more details on setting generation options.
model.generation_config = gen_cfg
# Create a pipeline for convenience. The pipeline will use the
# model's generation_config when invoked.
vlm = pipeline(
task="image-text-to-text",
model=model,
processor=processor,
)
# Store for reuse
run_model.model = model
run_model.vlm = vlm
else:
vlm = run_model.vlm
# Compose the full prompt
sys_prompt = (
system_prompt.strip()
if system_prompt and system_prompt.strip()
else "You are a concise radiology assistant. Answer the user's question based on the image and text."
)
full_prompt = sys_prompt + "\n" + prompt.strip()
# Run inference
if image is not None:
result = vlm(image, full_prompt)
else:
result = vlm(full_prompt)
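    # The pipeline returns a list of result dicts; take the generated text of
    # the first (and only) entry.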
output = result[0]["generated_text"]
return output
def predict(
prompt: str,
image: Optional[Image.Image] = None,
system_prompt: Optional[str] = None,
api_key: Optional[str] = None,
) -> str:
"""Wrapper function for Gradio.
Handles optional API key authentication and appends a disclaimer to
the model's output. See README for details.
Args:
prompt: The user's question in English.
image: An optional PIL image.
system_prompt: Optional system prompt to steer the model.
api_key: Optional API key supplied by the client. If the
``API_KEY`` secret is set and this does not match, the
request is rejected.
Returns:
A string containing the model's answer followed by a
disclaimer. If authentication fails an error message is
returned instead.
"""
# Enforce API key if configured
if API_KEY:
if api_key is None or api_key != API_KEY:
return "Error: Invalid or missing API key."
# Validate prompt
if not prompt or not prompt.strip():
return "Error: Prompt cannot be empty."
try:
answer = run_model(prompt, image, system_prompt)
except Exception as e:
return f"Error during inference: {e}"
disclaimer = (
"\n\nThis response is generated by an AI model and may be incorrect. "
"Always consult a licensed medical professional for health questions."
)
return answer.strip() + disclaimer
def build_demo() -> gr.Interface:
    """Construct the Gradio UI for this application."""
    # Define inputs: prompt, optional image, optional system prompt, and an
    # API key field that is only visible when the ``API_KEY`` secret is
    # configured. Without the secret, any value in that field is ignored.
inputs = [
gr.Textbox(
label="Prompt (English only)",
lines=4,
placeholder="Describe the medical image or ask a question."
),
gr.Image(
type="pil",
label="Optional image"
),
gr.Textbox(
label="Optional system prompt",
lines=2,
placeholder="e.g. You are a concise radiology assistant."
),
gr.Textbox(
label="API key",
lines=1,
placeholder="Enter API key if required",
type="password",
visible=bool(API_KEY),
),
]
outputs = gr.Textbox(label="Answer")
description = (
"Ask MedGemma a question about a medical image or condition. "
"Optionally provide a system prompt to guide the model's behaviour. "
"All responses are in English and include a disclaimer."
)
demo = gr.Interface(
fn=predict,
inputs=inputs,
outputs=outputs,
title="MedGemma ZeroGPU (Gradio)",
description=description,
allow_flagging="never",
)
return demo
demo = build_demo()
if __name__ == "__main__":
    # Running the script directly (e.g. ``python app.py``) serves the UI on
    # Gradio's default local port. In Spaces this block is not executed;
    # Spaces uses the Gradio SDK to run the app from the module-level ``demo``.
    demo.launch()