# healai / app.py | Divyashree1326 | Update app.py | commit 4851757 (verified)
# -*- coding: utf-8 -*-
"""heai.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1CPgKNfxzP9sPf9nsHmsct1wlUuZL3XpL
"""
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Granite instruct checkpoint once, at import time, so every request
# reuses the same tokenizer and weights.
model_name = "ibm-granite/granite-3.2-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# With CUDA: half precision and automatic device placement.
# Without CUDA: full precision and no device map.
if torch.cuda.is_available():
    _load_options = {"torch_dtype": torch.float16, "device_map": "auto"}
else:
    _load_options = {"torch_dtype": torch.float32, "device_map": None}
model = AutoModelForCausalLM.from_pretrained(model_name, **_load_options)

# Reuse the EOS token for padding when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Function to generate LLM response
def generate_response(prompt, max_length=1024, max_input_length=512):
    """Generate a sampled completion for ``prompt`` with the module-level model.

    Args:
        prompt: Plain-text prompt fed to the model as-is (no chat template).
        max_length: Total token budget (prompt tokens + generated tokens),
            matching the ``max_length`` meaning used by ``model.generate``.
            At least one new token is always generated.
        max_input_length: Maximum number of prompt tokens kept. Longer
            prompts are truncated from the LEFT, so the most recent text
            (e.g. the latest chat turn and the trailing ``AI:`` /
            ``Analysis:`` cue) is what the model sees.

    Returns:
        Only the newly generated text, with special tokens removed and
        surrounding whitespace stripped.
    """
    # Drop the oldest tokens, not the newest: right-side truncation would cut
    # off the trailing cue the model is supposed to continue from.
    tokenizer.truncation_side = "left"
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_input_length
    )
    # model.device is valid on CPU as well, so move unconditionally.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    # Translate the total-length budget into a count of new tokens; never let
    # it reach zero, even if a caller passes max_length <= max_input_length.
    max_new_tokens = max(max_length - prompt_len, 1)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens produced after the prompt. String-replacing the
    # prompt fails whenever decoding does not reproduce it exactly (for
    # instance after truncation), which leaked prompt text into the answer.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Function for disease prediction
def disease_prediction(symptoms):
    """Ask the model for possible conditions matching a symptom description.

    Args:
        symptoms: Free-text symptoms entered by the user; may be empty or None.

    Returns:
        The model's analysis text. If no symptoms were given, returns a short
        request to enter some and does not call the model.
    """
    symptoms = (symptoms or "").strip()
    if not symptoms:
        # Without symptoms the model would only produce an ungrounded answer.
        return "Please enter one or more symptoms to analyze."

    prompt = "\n".join(
        [
            "Based on the following symptoms, provide possible medical conditions "
            "and general medication suggestions.",
            "Always emphasize the importance of consulting a doctor for proper diagnosis.",
            f"Symptoms: {symptoms}",
            "Possible conditions and recommendations:",
            "**IMPORTANT: This is for informational purposes only. Please consult a "
            "healthcare professional for proper diagnosis and treatment.**",
            "Analysis:",
        ]
    )
    return generate_response(prompt, max_length=1200)
# Function for treatment plan
def treatment_plan(condition, age, gender, medical_history):
    """Ask the model for a personalized treatment plan.

    Args:
        condition: Medical condition to plan for; required (empty/None is rejected).
        age: Patient age. A whole-number float such as 30.0 is shown as "30";
            None is shown as "Not provided".
        gender: Patient gender label, or None.
        medical_history: Free-text history; empty/None is sent as "None".

    Returns:
        The model's treatment plan text. If no condition was given, returns a
        short request to enter one and does not call the model.
    """
    condition = (condition or "").strip()
    if not condition:
        return "Please enter a medical condition to generate a treatment plan."

    # The UI age field can deliver a float (30.0) or None when cleared; render
    # it readably instead of "30.0" / "None".
    try:
        age_value = float(age)
    except (TypeError, ValueError):
        age_text = "Not provided" if age is None else str(age)
    else:
        age_text = str(int(age_value)) if age_value.is_integer() else str(age_value)

    gender_text = gender or "Not provided"
    # Matches the UI placeholder, which asks users to write "None" when empty.
    history_text = (medical_history or "").strip() or "None"

    prompt = "\n".join(
        [
            "Generate personalized treatment suggestions for the following patient "
            "information. Include home remedies and general medication guidelines.",
            f"Medical Condition: {condition}",
            f"Age: {age_text}",
            f"Gender: {gender_text}",
            f"Medical History: {history_text}",
            "Personalized treatment plan including home remedies and medication guidelines:",
            "**IMPORTANT: This is for informational purposes only. Please consult a "
            "healthcare professional for proper treatment.**",
            "Treatment Plan:",
        ]
    )
    return generate_response(prompt, max_length=1200)
# Function for chat with patient
def patient_chat(chat_history, user_input):
    """Append one patient message and the model's reply to the transcript.

    Args:
        chat_history: Plain-text transcript so far; may be empty or None.
        user_input: The patient's new message.

    Returns:
        ``(history, history)``: the updated transcript twice, so it can feed
        two Gradio outputs. A blank message leaves the transcript unchanged
        and does not call the model.
    """
    chat_history = chat_history or ""
    user_input = (user_input or "").strip()
    if not user_input:
        return chat_history, chat_history

    conversation = chat_history + f"\nPatient: {user_input}\nAI:"
    response = generate_response(conversation, max_length=800)
    # The model continues a raw transcript, so it often goes on to invent the
    # next "Patient:" turn; keep only the AI's own reply.
    response = response.split("\nPatient:", 1)[0].strip()

    chat_history += f"\nPatient: {user_input}\nAI: {response}"
    return chat_history, chat_history
def _send_chat_message(history, message):
    """Run one chat turn for the Send button and clear the message box.

    Returns the updated transcript for the Conversation box and an empty
    string for the Your Message box, instead of writing the transcript to
    the same component twice.
    """
    updated_history, _ = patient_chat(history, message)
    return updated_history, ""


# Build Gradio app
with gr.Blocks() as app:
    gr.Markdown("# Medical AI Assistant")
    gr.Markdown(
        "**Disclaimer: This is for informational purposes only. "
        "Always consult healthcare professionals for medical advice.**"
    )
    with gr.Tabs():
        with gr.TabItem("Patient Chat"):
            chat_history = gr.Textbox(
                label="Conversation", lines=15, value="", interactive=False
            )
            user_input = gr.Textbox(
                label="Your Message",
                placeholder="Describe your symptoms or ask questions...",
                lines=2,
            )
            send_btn = gr.Button("Send")
            # Update the transcript and clear the input box after each send.
            send_btn.click(
                _send_chat_message,
                inputs=[chat_history, user_input],
                outputs=[chat_history, user_input],
            )
        with gr.TabItem("Disease Prediction"):
            with gr.Row():
                with gr.Column():
                    symptoms_input = gr.Textbox(
                        label="Enter Symptoms",
                        placeholder="e.g., fever, headache, cough, fatigue...",
                        lines=4,
                    )
                    predict_btn = gr.Button("Analyze Symptoms")
                with gr.Column():
                    prediction_output = gr.Textbox(
                        label="Possible Conditions & Recommendations", lines=20
                    )
            predict_btn.click(
                disease_prediction, inputs=symptoms_input, outputs=prediction_output
            )
        with gr.TabItem("Treatment Plans"):
            with gr.Row():
                with gr.Column():
                    condition_input = gr.Textbox(
                        label="Medical Condition",
                        placeholder="e.g., diabetes, hypertension, migraine...",
                        lines=2,
                    )
                    # precision=0 makes the component round to and return an int,
                    # so the prompt shows "30" rather than "30.0".
                    age_input = gr.Number(label="Age", value=30, precision=0)
                    gender_input = gr.Dropdown(
                        choices=["Male", "Female", "Other"],
                        label="Gender",
                        value="Male",
                    )
                    history_input = gr.Textbox(
                        label="Medical History",
                        placeholder="Previous conditions, allergies, medications or None",
                        lines=3,
                    )
                    plan_btn = gr.Button("Generate Treatment Plan")
                with gr.Column():
                    plan_output = gr.Textbox(
                        label="Personalized Treatment Plan", lines=20
                    )
            plan_btn.click(
                treatment_plan,
                inputs=[condition_input, age_input, gender_input, history_input],
                outputs=plan_output,
            )
# Listen on all interfaces so the app is reachable from outside a container.
# The port defaults to 8080 but can be overridden with GRADIO_SERVER_PORT
# (the variable Gradio itself reads when no explicit port is given), so
# hosting platforms that assign a port keep working.
app.launch(
    server_name="0.0.0.0",
    server_port=int(os.environ.get("GRADIO_SERVER_PORT", "8080")),
    share=True,
)