Spaces:
Runtime error
Runtime error
File size: 5,423 Bytes
8f1f240 4851757 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# -*- coding: utf-8 -*-
"""heai.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1CPgKNfxzP9sPf9nsHmsct1wlUuZL3XpL
"""
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the instruction-tuned Granite model and its tokenizer once, at import time.
_cuda_available = torch.cuda.is_available()
model_name = "ibm-granite/granite-3.2-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# On a GPU: fp16 weights with automatic device placement; on CPU: plain fp32.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=(torch.float16 if _cuda_available else torch.float32),
    device_map=("auto" if _cuda_available else None),
)
# Make sure a pad token exists, falling back to EOS when the tokenizer has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Function to generate LLM response
def generate_response(prompt, max_length=1024, max_input_tokens=512):
    """Generate a sampled completion for *prompt* with the loaded causal LM.

    Args:
        prompt: Raw text prompt fed to the model (no chat template is applied).
        max_length: Upper bound on prompt tokens plus generated tokens, passed
            unchanged to ``model.generate``.
        max_input_tokens: Maximum number of prompt tokens kept (positive int).
            Longer prompts are truncated from the LEFT so that the end of the
            prompt (the newest chat turn and the trailing cue such as "AI:"
            or "Analysis:") is what the model sees.

    Returns:
        Only the newly generated text, decoded without special tokens and
        with surrounding whitespace stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep the tail of over-long prompts: right-side truncation would drop the
    # latest user message and the cue the model is supposed to continue from.
    inputs = {k: v[:, -max_input_tokens:] for k, v in encoded.items()}
    # model.device is the CPU when CUDA is unavailable, so this is always safe.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens produced after the prompt. Removing the prompt
    # with str.replace fails whenever the decoded prompt is not byte-identical
    # to the input text (tokenizer normalisation, truncation).
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
# Function for disease prediction
def disease_prediction(symptoms):
    """Ask the model for possible conditions that match the given symptoms.

    Args:
        symptoms: Free-text symptom description from the UI textbox.

    Returns:
        The model's analysis text. If *symptoms* is empty, ``None``, or only
        whitespace, it returns a short request for input instead and makes
        no model call.
    """
    # Don't spend a slow generation on an empty form; the model would only
    # invent conditions for no symptoms at all.
    if not symptoms or not symptoms.strip():
        return "Please enter at least one symptom."
    prompt = f"""Based on the following symptoms, provide possible medical conditions and general medication suggestions.
Always emphasize the importance of consulting a doctor for proper diagnosis.
Symptoms: {symptoms}
Possible conditions and recommendations:
**IMPORTANT: This is for informational purposes only. Please consult a healthcare professional for proper diagnosis and treatment.**
Analysis:"""
    return generate_response(prompt, max_length=1200)
# Function for treatment plan
def treatment_plan(condition, age, gender, medical_history):
    """Ask the model for a personalised treatment plan.

    Args:
        condition: Medical condition text; must be non-blank.
        age: Patient age as delivered by the UI (int, float, or ``None``).
        gender: Selected gender label, or ``None``.
        medical_history: Free-text history; blank is treated as "None".

    Returns:
        The model's treatment-plan text. If *condition* is blank, it returns
        a short request for input instead and makes no model call.
    """
    if not condition or not condition.strip():
        return "Please enter a medical condition."
    # gr.Number yields floats (30 -> 30.0) and None when cleared; render whole
    # ages as integers and make a missing value explicit in the prompt.
    if age is None:
        age = "Not specified"
    elif isinstance(age, float) and age.is_integer():
        age = int(age)
    if gender is None:
        gender = "Not specified"
    # A blank history reads as "None", as the input's placeholder suggests.
    if not medical_history or not medical_history.strip():
        medical_history = "None"
    prompt = f"""Generate personalized treatment suggestions for the following patient information. Include home remedies and general medication guidelines.
Medical Condition: {condition}
Age: {age}
Gender: {gender}
Medical History: {medical_history}
Personalized treatment plan including home remedies and medication guidelines:
**IMPORTANT: This is for informational purposes only. Please consult a healthcare professional for proper treatment.**
Treatment Plan:"""
    return generate_response(prompt, max_length=1200)
# Function for chat with patient
def patient_chat(chat_history, user_input):
    """Append one patient/AI exchange to a plain-text transcript.

    Args:
        chat_history: Transcript so far, made of "\\nPatient: ...\\nAI: ..."
            turns; ``None`` is treated as an empty transcript.
        user_input: The patient's new message. Blank messages are ignored.

    Returns:
        ``(history, history)``: the updated transcript twice, matching the
        two-output shape callers of this handler expect.
    """
    chat_history = chat_history or ""
    # Ignore blank submissions instead of asking the model to answer nothing.
    if not user_input or not user_input.strip():
        return chat_history, chat_history
    conversation = chat_history + f"\nPatient: {user_input}\nAI:"
    response = generate_response(conversation, max_length=800)
    # The raw "Patient:/AI:" prompt invites the model to keep writing the
    # dialogue; keep only the AI's turn and drop any invented patient reply.
    response = response.split("\nPatient:", 1)[0].strip()
    chat_history += f"\nPatient: {user_input}\nAI: {response}"
    return chat_history, chat_history
# Build Gradio app
def _send_message(history, message):
    """Run one chat turn; return the new transcript and an empty message box."""
    new_history, _ = patient_chat(history, message)
    return new_history, ""


with gr.Blocks() as app:
    gr.Markdown("# Medical AI Assistant")
    gr.Markdown("**Disclaimer: This is for informational purposes only. Always consult healthcare professionals for medical advice.**")
    with gr.Tabs():
        with gr.TabItem("Patient Chat"):
            chat_history = gr.Textbox(label="Conversation", lines=15, value="", interactive=False)
            user_input = gr.Textbox(label="Your Message", placeholder="Describe your symptoms or ask questions...", lines=2)
            send_btn = gr.Button("Send")
            # Send the transcript to the conversation box and clear the input,
            # instead of writing the transcript to the same component twice.
            send_btn.click(
                _send_message,
                inputs=[chat_history, user_input],
                outputs=[chat_history, user_input],
            )
        with gr.TabItem("Disease Prediction"):
            with gr.Row():
                with gr.Column():
                    symptoms_input = gr.Textbox(
                        label="Enter Symptoms",
                        placeholder="e.g., fever, headache, cough, fatigue...",
                        lines=4,
                    )
                    predict_btn = gr.Button("Analyze Symptoms")
                with gr.Column():
                    prediction_output = gr.Textbox(label="Possible Conditions & Recommendations", lines=20)
            predict_btn.click(disease_prediction, inputs=symptoms_input, outputs=prediction_output)
        with gr.TabItem("Treatment Plans"):
            with gr.Row():
                with gr.Column():
                    condition_input = gr.Textbox(
                        label="Medical Condition",
                        placeholder="e.g., diabetes, hypertension, migraine...",
                        lines=2,
                    )
                    # precision=0 makes Gradio deliver an int (30, not 30.0).
                    age_input = gr.Number(label="Age", value=30, precision=0)
                    gender_input = gr.Dropdown(
                        choices=["Male", "Female", "Other"],
                        label="Gender",
                        value="Male",
                    )
                    history_input = gr.Textbox(
                        label="Medical History",
                        placeholder="Previous conditions, allergies, medications or None",
                        lines=3,
                    )
                    plan_btn = gr.Button("Generate Treatment Plan")
                with gr.Column():
                    plan_output = gr.Textbox(label="Personalized Treatment Plan", lines=20)
            plan_btn.click(
                treatment_plan,
                inputs=[condition_input, age_input, gender_input, history_input],
                outputs=plan_output,
            )

if __name__ == "__main__":
    import os

    app.launch(
        server_name="0.0.0.0",
        # Honour a port supplied by the hosting environment; default stays 8080.
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "8080")),
        share=True,
    )
|