import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch
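
# Assumed runtime dependencies (not pinned in this file): gradio, transformers,
# accelerate, torch, and pillow. google/medgemma-4b-it is a gated checkpoint on the
# Hugging Face Hub, so an authenticated token with access to the model is assumed
# to be available in the environment.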
model_id = "google/medgemma-4b-it"
# Load model and processor
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Requires 'accelerate'
)
processor = AutoProcessor.from_pretrained(model_id)

def generate_report(image, clinical_info):
    if image is None:
        return "Please upload a medical image."

    # Create message list for chat-style input
    user_content = []
    if clinical_info:
        user_content.append({"type": "text", "text": f"Patient info: {clinical_info}"})
    user_content.append({"type": "text", "text": "Please describe the medical image in a radiology report style."})
    user_content.append({"type": "image", "image": image})

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist."}]},
        {"role": "user", "content": user_content}
    ]

    # Process input: apply the chat template, tokenize, and move tensors to the model device
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    # Generate with sampling, then decode only the newly generated tokens (skip the prompt)
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=512, do_sample=True, top_p=0.9, top_k=50)
    generated_ids = output[0]
    decoded = processor.decode(generated_ids[input_len:], skip_special_tokens=True)
    return decoded.strip()
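
# Minimal local sanity check (hypothetical file path, for illustration only):
#   report = generate_report(Image.open("sample_cxr.png"), "65-year-old male with cough")
#   print(report)
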
# Gradio interface
gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="pil", label="Upload Medical Image (X-ray, etc)"),
        gr.Textbox(lines=2, placeholder="e.g. Prior diagnosis: pneumonia. 65-year-old male with cough...", label="Optional Clinical Info")
    ],
    outputs=gr.Textbox(label="Generated Radiology Report"),
    title="🧠 MedGemma Radiology Report Generator",
    description="Upload a medical image and optionally include clinical info (like prior findings or diagnosis). Powered by Google's MedGemma-4B model.",
    allow_flagging="never"
).launch()
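
# Note: the bfloat16 + device_map="auto" settings above assume a CUDA GPU with enough
# memory for the 4B model; on CPU-only hardware generation will be very slow.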