# -*- coding: utf-8 -*-
"""heai.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1CPgKNfxzP9sPf9nsHmsct1wlUuZL3XpL
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_name = "ibm-granite/granite-3.2-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None
)
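# Rough sizing note (an estimate, not from the original code): at ~2.5B
# parameters, the float16 weights alone need on the order of 5 GB of GPU
# memory; the float32 CPU fallback needs roughly twice that in RAM.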

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Function to generate LLM response
def generate_response(prompt, max_new_tokens=1024):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # model.device is CPU when no accelerator is available, so this move is safe either way
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens; string-replacing the prompt out of
    # the full decode is unreliable because tokenization does not round-trip exactly
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
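
# Optional variant (a sketch, not part of the original app): instruct-tuned
# models such as Granite usually ship a chat template, and wrapping the user
# message with it tends to follow instructions better than raw text prompting.
# Assumes the tokenizer defines a chat template, as ibm-granite instruct models do.
def generate_chat_response(user_message, max_new_tokens=512):
    messages = [{"role": "user", "content": user_message}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the tokens generated after the templated prompt
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()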

# Function for disease prediction
def disease_prediction(symptoms):
    prompt = f"""Based on the following symptoms, provide possible medical conditions and general medication suggestions.
Always emphasize the importance of consulting a doctor for proper diagnosis.

Symptoms: {symptoms}

Possible conditions and recommendations:

**IMPORTANT: This is for informational purposes only. Please consult a healthcare professional for proper diagnosis and treatment.**

Analysis:"""
    return generate_response(prompt, max_new_tokens=1200)

# Function for treatment plan
def treatment_plan(condition, age, gender, medical_history):
    prompt = f"""Generate personalized treatment suggestions for the following patient information. Include home remedies and general medication guidelines.

Medical Condition: {condition}
Age: {age}
Gender: {gender}
Medical History: {medical_history}

Personalized treatment plan including home remedies and medication guidelines:

**IMPORTANT: This is for informational purposes only. Please consult a healthcare professional for proper treatment.**

Treatment Plan:"""
    return generate_response(prompt, max_new_tokens=1200)

# Function for chat with patient
def patient_chat(chat_history, user_input):
    conversation = chat_history + f"\nPatient: {user_input}\nAI:"
    response = generate_response(conversation, max_new_tokens=800)
    chat_history += f"\nPatient: {user_input}\nAI: {response}"
    return chat_history, ""  # second value clears the message box after sending
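
# Note (an observation, not in the original code): generate_response truncates
# prompts to 512 tokens from the right-hand side, so in a long conversation it
# is the *newest* turns that get cut off. A hedged fix is to truncate from the
# left instead, keeping the most recent context:
#   tokenizer.truncation_side = "left"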

# Build Gradio app
with gr.Blocks() as app:
    gr.Markdown("# Medical AI Assistant")
    gr.Markdown("**Disclaimer: This is for informational purposes only. Always consult healthcare professionals for medical advice.**")

    with gr.Tabs():
        with gr.TabItem("Patient Chat"):
            chat_history = gr.Textbox(label="Conversation", lines=15, value="", interactive=False)
            user_input = gr.Textbox(label="Your Message", placeholder="Describe your symptoms or ask questions...", lines=2)
            send_btn = gr.Button("Send")
            send_btn.click(patient_chat, inputs=[chat_history, user_input], outputs=[chat_history, user_input])

        with gr.TabItem("Disease Prediction"):
            with gr.Row():
                with gr.Column():
                    symptoms_input = gr.Textbox(
                        label="Enter Symptoms",
                        placeholder="e.g., fever, headache, cough, fatigue...",
                        lines=4
                    )
                    predict_btn = gr.Button("Analyze Symptoms")
                with gr.Column():
                    prediction_output = gr.Textbox(label="Possible Conditions & Recommendations", lines=20)
            predict_btn.click(disease_prediction, inputs=symptoms_input, outputs=prediction_output)

        with gr.TabItem("Treatment Plans"):
            with gr.Row():
                with gr.Column():
                    condition_input = gr.Textbox(
                        label="Medical Condition",
                        placeholder="e.g., diabetes, hypertension, migraine...",
                        lines=2
                    )
                    age_input = gr.Number(label="Age", value=30)
                    gender_input = gr.Dropdown(
                        choices=["Male", "Female", "Other"],
                        label="Gender",
                        value="Male"
                    )
                    history_input = gr.Textbox(
                        label="Medical History",
                        placeholder="Previous conditions, allergies, medications or None",
                        lines=3
                    )
                    plan_btn = gr.Button("Generate Treatment Plan")
                with gr.Column():
                    plan_output = gr.Textbox(label="Personalized Treatment Plan", lines=20)
            plan_btn.click(
                treatment_plan,
                inputs=[condition_input, age_input, gender_input, history_input],
                outputs=plan_output
            )
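
# Note (an assumption, not in the original): generation can take a while,
# especially on CPU; enabling Gradio's request queue, e.g. app.queue() before
# launching, helps avoid request timeouts when several users hit the app at once.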

# Launch outside the Blocks context so the layout is finalized first
app.launch(server_name="0.0.0.0", server_port=8080, share=True)