import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch

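# Note: google/medgemma-4b-it is a gated model on Hugging Face. You must accept
# its terms on the model page and authenticate (e.g. `huggingface-cli login`)
# before the weights can be downloaded.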
model_id = "google/medgemma-4b-it"

# Load model and processor
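# bfloat16 halves memory use vs. float32; at this precision the 4B model needs
# roughly 8 GB of accelerator memory (rough estimate: 4B params x 2 bytes).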
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Requires 'accelerate'
)
processor = AutoProcessor.from_pretrained(model_id)

def generate_report(image: Image.Image | None, clinical_info: str) -> str:
    if image is None:
        return "Please upload a medical image."

    # Create message list for chat-style input
    user_content = []
    if clinical_info:
        user_content.append({"type": "text", "text": f"Patient info: {clinical_info}"})
    user_content.append({"type": "text", "text": "Please describe the medical image in a radiology report style."})
    user_content.append({"type": "image", "image": image})

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist."}]},
        {"role": "user", "content": user_content}
    ]

    # Tokenize the chat-formatted conversation; the dtype cast below only
    # affects floating-point tensors (pixel values), integer token ids are unaffected
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)

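    # Prompt length in tokens; used below to return only the newly generated text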
    input_len = inputs["input_ids"].shape[-1]

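    # Sampling (top_p/top_k) yields varied reports; set do_sample=False for
    # deterministic greedy decoding if reproducibility matters more than variety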
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=512, do_sample=True, top_p=0.9, top_k=50)

    generated_ids = output[0]
    decoded = processor.decode(generated_ids[input_len:], skip_special_tokens=True)

    return decoded.strip()

# Gradio interface
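# Serves at http://127.0.0.1:7860 by default; pass share=True to launch()
# for a temporary public link.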
demo = gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="pil", label="Upload Medical Image (X-ray, etc.)"),
        gr.Textbox(lines=2, placeholder="e.g. Prior diagnosis: pneumonia. 65-year-old male with cough...", label="Optional Clinical Info"),
    ],
    outputs=gr.Textbox(label="Generated Radiology Report"),
    title="🧠 MedGemma Radiology Report Generator",
    description="Upload a medical image and optionally include clinical info (like prior findings or diagnosis). Powered by Google's MedGemma-4B model.",
    allow_flagging="never",  # renamed to flagging_mode="never" in Gradio 5
)

if __name__ == "__main__":
    demo.launch()