# -*- coding: utf-8 -*-
"""heai.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1CPgKNfxzP9sPf9nsHmsct1wlUuZL3XpL
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_name = "ibm-granite/granite-3.2-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
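
# Optional sanity check (illustrative only, not part of the app): confirm which
# device and dtype the weights actually loaded with before serving requests.
# Left commented out so it does nothing at import time.
# print(next(model.parameters()).device, next(model.parameters()).dtype)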
# Function to generate an LLM response for a given prompt
def generate_response(prompt, max_length=1024):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens so the prompt is not echoed back
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return response.strip()
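
# Hedged alternative (not used by this app): Granite instruct checkpoints are
# expected to ship a chat template, so prompts could also be built with
# tokenizer.apply_chat_template instead of hand-written strings, e.g.:
# messages = [{"role": "user", "content": "I have a mild fever and a sore throat."}]
# prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# response = generate_response(prompt)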
# Function for disease prediction
def disease_prediction(symptoms):
    prompt = f"""Based on the following symptoms, provide possible medical conditions and general medication suggestions.
Always emphasize the importance of consulting a doctor for proper diagnosis.

Symptoms: {symptoms}

Possible conditions and recommendations:
**IMPORTANT: This is for informational purposes only. Please consult a healthcare professional for proper diagnosis and treatment.**

Analysis:"""
    return generate_response(prompt, max_length=1200)
# Function for treatment plan
def treatment_plan(condition, age, gender, medical_history):
    prompt = f"""Generate personalized treatment suggestions for the following patient information. Include home remedies and general medication guidelines.

Medical Condition: {condition}
Age: {age}
Gender: {gender}
Medical History: {medical_history}

Personalized treatment plan including home remedies and medication guidelines:
**IMPORTANT: This is for informational purposes only. Please consult a healthcare professional for proper treatment.**

Treatment Plan:"""
    return generate_response(prompt, max_length=1200)
# Function for chat with patient
def patient_chat(chat_history, user_input):
    conversation = chat_history + f"\nPatient: {user_input}\nAI:"
    response = generate_response(conversation, max_length=800)
    chat_history += f"\nPatient: {user_input}\nAI: {response}"
    # Return the updated transcript once and clear the message box
    return chat_history, ""
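
# Quick smoke test (hedged example, commented out so it does not run a
# generation at import time); it mirrors how the Gradio click handler below
# threads the transcript string through patient_chat:
# history, _ = patient_chat("", "I have had a headache for two days.")
# print(history)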
# Build Gradio app
with gr.Blocks() as app:
    gr.Markdown("# Medical AI Assistant")
    gr.Markdown("**Disclaimer: This is for informational purposes only. Always consult healthcare professionals for medical advice.**")

    with gr.Tabs():
        with gr.TabItem("Patient Chat"):
            chat_history = gr.Textbox(label="Conversation", lines=15, value="", interactive=False)
            user_input = gr.Textbox(label="Your Message", placeholder="Describe your symptoms or ask questions...", lines=2)
            send_btn = gr.Button("Send")
            send_btn.click(patient_chat, inputs=[chat_history, user_input], outputs=[chat_history, user_input])
| with gr.TabItem("Disease Prediction"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| symptoms_input = gr.Textbox( | |
| label="Enter Symptoms", | |
| placeholder="e.g., fever, headache, cough, fatigue...", | |
| lines=4 | |
| ) | |
| predict_btn = gr.Button("Analyze Symptoms") | |
| with gr.Column(): | |
| prediction_output = gr.Textbox(label="Possible Conditions & Recommendations", lines=20) | |
| predict_btn.click(disease_prediction, inputs=symptoms_input, outputs=prediction_output) | |
| with gr.TabItem("Treatment Plans"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| condition_input = gr.Textbox( | |
| label="Medical Condition", | |
| placeholder="e.g., diabetes, hypertension, migraine...", | |
| lines=2 | |
| ) | |
| age_input = gr.Number(label="Age", value=30) | |
| gender_input = gr.Dropdown( | |
| choices=["Male", "Female", "Other"], | |
| label="Gender", | |
| value="Male" | |
| ) | |
| history_input = gr.Textbox( | |
| label="Medical History", | |
| placeholder="Previous conditions, allergies, medications or None", | |
| lines=3 | |
| ) | |
| plan_btn = gr.Button("Generate Treatment Plan") | |
| with gr.Column(): | |
| plan_output = gr.Textbox(label="Personalized Treatment Plan", lines=20) | |
| plan_btn.click( | |
| treatment_plan, | |
| inputs=[condition_input, age_input, gender_input, history_input], | |
| outputs=plan_output | |
| ) | |

# Hugging Face Spaces serves Gradio apps on port 7860; share=True is not needed there
app.launch(server_name="0.0.0.0", server_port=7860)