# Requires: gradio, transformers, accelerate, torch, pillow.
# Note: google/medgemma-4b-it is a gated model on Hugging Face; accept its
# license and authenticate (e.g. `huggingface-cli login`) before first run.
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "google/medgemma-4b-it"

# Load model and processor
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Requires 'accelerate'
)
processor = AutoProcessor.from_pretrained(model_id)


def generate_report(image: Image.Image, clinical_info: str) -> str:
    if image is None:
        return "Please upload a medical image."

    # Build the user turn: optional clinical context, then the instruction and image
    user_content = []
    if clinical_info:
        user_content.append({"type": "text", "text": f"Patient info: {clinical_info}"})
    user_content.append({
        "type": "text",
        "text": "Please describe the medical image in a radiology report style.",
    })
    user_content.append({"type": "image", "image": image})

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist."}]},
        {"role": "user", "content": user_content},
    ]

    # Tokenize the chat; .to(..., dtype=...) casts only floating-point tensors
    # (the pixel values), leaving the integer token ids untouched
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    # Remember the prompt length so only newly generated tokens are decoded
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output = model.generate(
            **inputs, max_new_tokens=512, do_sample=True, top_p=0.9, top_k=50
        )

    decoded = processor.decode(output[0][input_len:], skip_special_tokens=True)
    return decoded.strip()


# Gradio interface
gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="pil", label="Upload Medical Image (X-ray, etc.)"),
        gr.Textbox(
            lines=2,
            placeholder="e.g. Prior diagnosis: pneumonia. 65-year-old male with cough...",
            label="Optional Clinical Info",
        ),
    ],
    outputs=gr.Textbox(label="Generated Radiology Report"),
    title="🧠 MedGemma Radiology Report Generator",
    description=(
        "Upload a medical image and optionally include clinical info (like prior "
        "findings or diagnosis). Powered by Google's MedGemma-4B model."
    ),
    allow_flagging="never",
).launch()