| """ | |
| Gradio application for MedGemma inference with ZeroGPU. | |
| This script defines a minimal Gradio interface around Google's | |
| ``medgemma‑27b‑it`` multi‑modal model. It is designed to run on | |
| Hugging Face Spaces using the **ZeroGPU** hardware option. ZeroGPU | |
| allocates an NVIDIA H200 GPU slice for the duration of each call and | |
| releases it afterwards. The interface accepts a textual **prompt** | |
| (English only), an optional image upload and an optional **system | |
| prompt** to steer the model. All responses are returned in English and | |
| include a short disclaimer reminding users to consult a medical | |
| professional. | |
| If you set an ``API_KEY`` secret in your Space, callers must supply the | |
| same value in the hidden API key field. Otherwise the endpoint will be | |
| publicly accessible. See the README for details. | |
| Note: ZeroGPU Spaces currently only work with the **Gradio** SDK and | |
| support specific versions of PyTorch and Python【916380489845432†L110-L118】. | |
| Running this script outside of a Space will work on CPU or dedicated | |
| GPU hardware, but ZeroGPU GPU allocation only takes effect when the | |
| Space hardware is set to *ZeroGPU (Dynamic resources)*. | |
| """ | |

import os
from typing import Optional

import gradio as gr
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    GenerationConfig,
    pipeline,
)
import spaces  # for the @spaces.GPU decorator

# ----------------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError(
        "HF_TOKEN environment variable must be set as a Secret in the Space."
    )

# Optional API key: when set, clients must provide the same value in the
# hidden ``api_key`` field of the Gradio interface. If not set, no
# authentication is enforced.
API_KEY = os.getenv("API_KEY")

MODEL_ID = "google/medgemma-27b-it"

# Load the processor outside of the GPU context; this is lightweight.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
eos_id = processor.tokenizer.eos_token_id
pad_id = processor.tokenizer.pad_token_id or eos_id

# Banned phrases to reduce chatty or irrelevant responses.
ban_list = [
    "Disclaimer",
    "disclaimer",
    "As an AI Chatbot",
    "as an AI Chatbot",
    "I cannot give medical advice",
    "I cannot provide medical advice",
    "I cannot give medical advise",
    "user",
    "response",
    "display",
    "response>",
    "```",
    "label",
    "tool_code",
]
bad_words_ids = [processor.tokenizer(b, add_special_tokens=False).input_ids for b in ban_list]

# Limit the number of generated tokens to shorten inference time. A smaller
# ``max_new_tokens`` helps ensure the call completes within the 180-second
# ZeroGPU allocation requested below. See the ZeroGPU documentation for
# runtime guidance.
gen_cfg = GenerationConfig(
    max_new_tokens=60,
    do_sample=False,
    repetition_penalty=1.12,
    no_repeat_ngram_size=6,
    length_penalty=1.0,
    temperature=0.0,
    eos_token_id=eos_id,
    pad_token_id=pad_id,
    bad_words_ids=bad_words_ids,
)


# The model and pipeline are loaded lazily inside ``run_model`` so that GPU
# allocation occurs within the ZeroGPU context. They are cached on first use
# (as attributes on the function) so subsequent calls are faster.
#
# The ZeroGPU documentation notes that the default allocation is 60 seconds
# and that a longer one can be requested via the ``duration`` parameter of
# ``@spaces.GPU``. Requesting 180 seconds and enabling the queue (see the
# bottom of this file) helps mitigate intermittent ``GPU task aborted``
# errors when GPUs are busy.
@spaces.GPU(duration=180)
def run_model(prompt: str, image: Optional[Image.Image], system_prompt: Optional[str]) -> str:
| """Execute the MedGemma model. | |
| This function will be run inside the ZeroGPU allocation context. It | |
| lazily loads the model and pipeline on first invocation and reuses | |
| them for subsequent calls. Inputs are combined with an optional | |
| system prompt to produce the full prompt. The model's output is | |
| returned as a plain English string. | |
| Args: | |
| prompt: The user's question (English only). | |
| image: An optional PIL Image. If provided, the model will use | |
| both text and image modalities; otherwise text-only. | |
| system_prompt: An optional system prompt to steer the model. If | |
| None or empty, a default instruction is used. | |
| Returns: | |
| The raw English output from the model (without disclaimer). | |
| """ | |
    # Lazy-load the model and pipeline on first use.
    if not hasattr(run_model, "model"):
        # Determine the appropriate dtype and device map. ``device_map="auto"``
        # splits the model across CPU/GPU if necessary; bfloat16 saves memory
        # when CUDA is available (the H200 supports it natively).
        model_kwargs: dict = {
            "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            "token": HF_TOKEN,
        }
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"
        model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
        # Apply our generation configuration directly to the model. The
        # transformers pipeline does not accept a ``generation_config``
        # keyword argument, so we assign it here.
        model.generation_config = gen_cfg
        # Create a pipeline for convenience. The pipeline will use the
        # model's generation_config when invoked.
        vlm = pipeline(
            task="image-text-to-text",
            model=model,
            processor=processor,
        )
        # Store for reuse.
        run_model.model = model
        run_model.vlm = vlm
    else:
        vlm = run_model.vlm

    # Compose the full prompt.
    sys_prompt = (
        system_prompt.strip()
        if system_prompt and system_prompt.strip()
        else "You are a concise radiology assistant. Answer the user's question based on the image and text."
    )
    full_prompt = sys_prompt + "\n" + prompt.strip()

    # Run inference. Pass keyword arguments so the text is not mistaken for
    # the image input when no image is provided.
    if image is not None:
        result = vlm(images=image, text=full_prompt)
    else:
        result = vlm(text=full_prompt)
    output = result[0]["generated_text"]
    return output


def predict(
    prompt: str,
    image: Optional[Image.Image] = None,
    system_prompt: Optional[str] = None,
    api_key: Optional[str] = None,
) -> str:
    """Wrapper function for Gradio.

    Handles optional API key authentication and appends a disclaimer to
    the model's output. See README for details.

    Args:
        prompt: The user's question in English.
        image: An optional PIL image.
        system_prompt: Optional system prompt to steer the model.
        api_key: Optional API key supplied by the client. If the
            ``API_KEY`` secret is set and this does not match, the
            request is rejected.

    Returns:
        A string containing the model's answer followed by a
        disclaimer. If authentication fails an error message is
        returned instead.
    """
    # Enforce API key if configured.
    if API_KEY:
        if api_key is None or api_key != API_KEY:
            return "Error: Invalid or missing API key."
    # Validate prompt.
    if not prompt or not prompt.strip():
        return "Error: Prompt cannot be empty."
    try:
        answer = run_model(prompt, image, system_prompt)
    except Exception as e:
        return f"Error during inference: {e}"
    disclaimer = (
        "\n\nThis response is generated by an AI model and may be incorrect. "
        "Always consult a licensed medical professional for health questions."
    )
    return answer.strip() + disclaimer


def build_demo() -> gr.Interface:
    """Construct the Gradio UI for this application."""
    # Define inputs: prompt, optional image, optional system prompt, and an
    # API key field (shown only when the ``API_KEY`` secret is set; when it
    # is not configured the api_key input is hidden and ignored).
    inputs = [
        gr.Textbox(
            label="Prompt (English only)",
            lines=4,
            placeholder="Describe the medical image or ask a question.",
        ),
        gr.Image(
            type="pil",
            label="Optional image",
        ),
        gr.Textbox(
            label="Optional system prompt",
            lines=2,
            placeholder="e.g. You are a concise radiology assistant.",
        ),
        gr.Textbox(
            label="API key",
            lines=1,
            placeholder="Enter API key if required",
            type="password",
            visible=bool(API_KEY),
        ),
    ]
    outputs = gr.Textbox(label="Answer")
    description = (
        "Ask MedGemma a question about a medical image or condition. "
        "Optionally provide a system prompt to guide the model's behaviour. "
        "All responses are in English and include a disclaimer."
    )
    demo = gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        title="MedGemma ZeroGPU (Gradio)",
        description=description,
        allow_flagging="never",
    )
    return demo


demo = build_demo()
# Enable the request queue so calls wait for a free GPU instead of failing
# immediately when ZeroGPU is busy (see the note above ``run_model``).
demo.queue()

if __name__ == "__main__":
    # Launch with the default settings (share=False, default port). On Spaces
    # the Gradio SDK serves the module-level ``demo`` object; locally,
    # ``demo.launch()`` starts the server.
    demo.launch()
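

# ----------------------------------------------------------------------------
# Example client usage (sketch)
# ----------------------------------------------------------------------------
# The commented snippet below illustrates how a caller might query this Space
# once it is deployed, using the ``gradio_client`` package. It is a sketch
# under assumptions, not part of this app: the Space id
# ("your-username/medgemma-zerogpu") and the image path are placeholders, and
# "/predict" is the default endpoint name for a ``gr.Interface``. If the
# ``API_KEY`` secret is set on the Space, pass the same value as the fourth
# argument; otherwise an empty string is fine.
#
#     from gradio_client import Client, handle_file
#
#     client = Client("your-username/medgemma-zerogpu")
#     answer = client.predict(
#         "Describe the main finding in this chest X-ray.",  # prompt
#         handle_file("chest_xray.png"),                     # optional image
#         "You are a concise radiology assistant.",          # system prompt
#         "",                                                 # API key, if required
#         api_name="/predict",
#     )
#     print(answer)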