Spaces:
Running
Running
| import gradio as gr | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
| from PIL import Image | |
| import torch | |
# Checkpoint identifier shared by the model and its processor.
model_id = "google/medgemma-4b-it"

# Load model and processor.
# device_map="auto" requires the 'accelerate' package to be installed.
_model_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
model = AutoModelForImageTextToText.from_pretrained(model_id, **_model_load_kwargs)
processor = AutoProcessor.from_pretrained(model_id)
def _build_messages(image, clinical_info):
    """Build the chat-style message list (system + user turn) for the model."""
    # Normalise the optional free text: None, "" and whitespace-only input
    # must not produce an empty "Patient info:" segment in the prompt.
    patient_context = (clinical_info or "").strip()

    user_content = []
    if patient_context:
        user_content.append(
            {"type": "text", "text": f"Patient info: {patient_context}"}
        )
    user_content.append(
        {
            "type": "text",
            "text": "Please describe the medical image in a radiology report style.",
        }
    )
    user_content.append({"type": "image", "image": image})

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert radiologist."}],
        },
        {"role": "user", "content": user_content},
    ]


def generate_report(image, clinical_info) -> str:
    """Generate a radiology-report-style description of an uploaded image.

    Args:
        image: The uploaded image, or None when nothing was uploaded.
        clinical_info: Optional free-text patient context. None, empty, and
            whitespace-only values are ignored.

    Returns:
        The decoded model output with surrounding whitespace stripped, or a
        message asking for an upload when ``image`` is None.
    """
    if image is None:
        return "Please upload a medical image."

    # Tokenize the chat messages, then move the tensors to the model's device,
    # requesting bfloat16 to match the dtype passed when the model was loaded.
    inputs = processor.apply_chat_template(
        _build_messages(image, clinical_info),
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    # Remember the prompt length so only newly generated tokens are decoded.
    input_len = inputs["input_ids"].shape[-1]

    # Sampling is enabled, so repeated calls with the same input may differ.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            top_p=0.9,
            top_k=50,
        )

    generated_ids = output[0]
    decoded = processor.decode(generated_ids[input_len:], skip_special_tokens=True)
    return decoded.strip()
# Gradio interface: an image upload plus an optional clinical-info text box,
# wired to generate_report, with flagging turned off.
image_input = gr.Image(type="pil", label="Upload Medical Image (X-ray, etc)")
clinical_info_input = gr.Textbox(
    lines=2,
    placeholder="e.g. Prior diagnosis: pneumonia. 65-year-old male with cough...",
    label="Optional Clinical Info",
)
report_output = gr.Textbox(label="Generated Radiology Report")

demo = gr.Interface(
    fn=generate_report,
    inputs=[image_input, clinical_info_input],
    outputs=report_output,
    title="🧠 MedGemma Radiology Report Generator",
    description="Upload a medical image and optionally include clinical info (like prior findings or diagnosis). Powered by Google's MedGemma-4B model.",
    allow_flagging="never",
)
demo.launch()