# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1BmTzCgYHoIX81jKTqf4ImJaKRRbxgoTS
"""
import os
import csv
import uuid
import shutil
import textwrap
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.express as px
import torch
import faiss
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
# from google.colab import drive
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline as hf_pipeline
from sentence_transformers import SentenceTransformer
from peft import PeftModel
from huggingface_hub import login
from fpdf import FPDF
from dotenv import load_dotenv
# Install Whisper on the fly if it is not already available
try:
    import whisper
except ImportError:
    os.system("pip install -U openai-whisper")
    import whisper

# Load Whisper model here
whisper_model = whisper.load_model("base")
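# Note (assumption about deployment): "base" is Whisper's ~74M-parameter
# multilingual checkpoint. For an English-only app, the "base.en" variant is
# usually slightly more accurate at the same cost, and "small"/"medium" trade
# speed for accuracy:
# whisper_model = whisper.load_model("base.en")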
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("⚠️ HF_TOKEN not set; downloads of gated models may fail.")
# Mount Google Drive (only needed when running in Colab)
# drive.mount('/content/drive')

# -------------------------------
# 🧠 Configuration
# -------------------------------
base_model_path = "google/gemma-2-9b-it"
# peft_model_path = "Jaamie/gemma-mental-health-qlora"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embedding_model_bge = "BAAI/bge-base-en-v1.5"
# save_path_bge = "./models/bge-base-en-v1.5"
faiss_index_path = "./qa_faiss_embedding.index"
chunked_text_path = "./chunked_text_RAG_text.txt"

READER_MODEL_NAME = "google/gemma-2-9b-it"
# READER_MODEL_NAME = "google/gemma-2b-it"

log_file_path = "./diagnosis_logs.csv"
feedback_file_path = "./feedback_logs.csv"
# -------------------------------
# 🧠 Logging setup
# -------------------------------
if not os.path.exists(log_file_path):
    with open(log_file_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["timestamp", "user_id", "input_type", "query",
                         "diagnosis", "confidence_score", "status"])

# -------------------------------
# 🧠 Feedback setup
# -------------------------------
if not os.path.exists(feedback_file_path):
    with open(feedback_file_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([
            "feedback_id", "timestamp", "user_id", "input_type", "query",
            "diagnosis", "status", "feedback"
        ])

# Ensure directory exists (only needed when caching the embedding model)
# os.makedirs(save_path_bge, exist_ok=True)
# -------------------------------
# 🧠 Model setup
# -------------------------------
# Load Sentence Transformer model (Google Drive caching, Colab only)
# if not os.path.exists(os.path.join(save_path_bge, "config.json")):
#     print("Saving model to Google Drive...")
#     embedding_model = SentenceTransformer(embedding_model_bge)
#     embedding_model.save(save_path_bge)
#     print("Model saved successfully!")
# else:
#     print("Loading model from Google Drive...")
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     embedding_model = SentenceTransformer(save_path_bge, device=device)

embedding_model = SentenceTransformer(embedding_model_bge, device=device)
print("✅ BGE Embedding model loaded from Hugging Face.")

# Load FAISS index
faiss_index = faiss.read_index(faiss_index_path)
print("FAISS index loaded successfully!")

# Load chunked text (chunks are separated by "\n\n---\n\n")
def load_chunked_text():
    with open(chunked_text_path, "r", encoding="utf-8") as f:
        return f.read().split("\n\n---\n\n")

chunked_text = load_chunked_text()
print(f"Loaded {len(chunked_text)} text chunks.")
# Load model for the emotion classifier
emotion_classifier = hf_pipeline("text-classification", model="nateraw/bert-base-uncased-emotion")
# -------------------------------
# 🧠 Load base model + LoRA adapter
# -------------------------------
# base_model = AutoModelForCausalLM.from_pretrained(
#     base_model_path,
#     torch_dtype=torch.float16,
#     device_map="auto"  # Use accelerate for smart placement
# )
#
# # Load the LoRA adapter on top of the base model
# diagnosis_model = PeftModel.from_pretrained(
#     base_model,
#     peft_model_path
# ).to(device)
#
# # Load tokenizer from the same fine-tuned repo
# diagnosis_tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
#
# # Set model to evaluation mode
# diagnosis_model.eval()
# print("✅ Model & tokenizer loaded successfully.")
#
# # Create text-generation pipeline WITHOUT `device` arg
# READER_LLM = pipeline(
#     model=diagnosis_model,
#     tokenizer=diagnosis_tokenizer,
#     task="text-generation",
#     do_sample=True,
#     temperature=0.2,
#     repetition_penalty=1.1,
#     return_full_text=False,
#     max_new_tokens=500
# )
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
# model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME).to(device)

# model_id = "mistralai/Mistral-7B-Instruct-v0.1"
# # model_id = "TheBloke/Gemma-2-7B-IT-GGUF"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16,
#     device_map="auto",
# ).to(device)
READER_LLM = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    do_sample=True,
    temperature=0.2,
    repetition_penalty=1.1,
    return_full_text=False,
    max_new_tokens=500,
    # device=device,
)
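# Memory note: gemma-2-9b-it loaded in default float32 needs roughly 36 GB of
# GPU memory. If that exceeds the available hardware, 4-bit quantization is one
# option. A hedged sketch (assumes the `bitsandbytes` package is installed; not
# used by default):
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     READER_MODEL_NAME,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
# )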
# -------------------------------
# 🧠 Whisper Model Setup
# -------------------------------
def process_whisper_query(audio):
    try:
        audio_data = whisper.load_audio(audio)
        audio_data = whisper.pad_or_trim(audio_data)
        mel = whisper.log_mel_spectrogram(audio_data).to(whisper_model.device)
        result = whisper_model.decode(mel, whisper.DecodingOptions(fp16=False))
        transcribed_text = result.text.strip()
        response, download_path = process_query(transcribed_text, input_type="voice")
        return response, download_path
    except Exception as e:
        return f"⚠️ Error processing audio: {str(e)}", None
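# Note: whisper.pad_or_trim() fixes the input at 30 seconds, so the decode path
# above only transcribes the first 30 s of audio. For longer recordings,
# whisper_model.transcribe() chunks the file automatically, e.g.:
# transcribed_text = whisper_model.transcribe(audio)["text"].strip()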
def extract_diagnosis(response_text: str) -> str:
    # Expects a line like "**Diagnosed Mental Disorder**: Depression"
    for line in response_text.splitlines():
        if "Diagnosed Mental Disorder" in line:
            return line.split(":")[-1].strip()
    return "Unknown"
# Estimate the correctness of the answer (hallucination guard)
def calculate_rag_confidence(query_embedding, top_k_docs_embeddings, generation_logprobs=None):
    """
    Combines retriever and generation signals to compute a confidence score.

    Args:
        query_embedding (np.ndarray): Embedding vector of the user query (shape: [1, dim]).
        top_k_docs_embeddings (np.ndarray): Embedding matrix of top-k retrieved documents (shape: [k, dim]).
        generation_logprobs (list, optional): List of logprobs for generated tokens.

    Returns:
        float: Final confidence score (0 to 1).
    """
    retriever_similarities = cosine_similarity(query_embedding, top_k_docs_embeddings)
    retriever_confidence = float(np.max(retriever_similarities))

    if generation_logprobs:
        gen_confidence = float(np.exp(np.mean(generation_logprobs)))
    else:
        gen_confidence = 0.0  # fallback if unavailable

    # Weighted blend of retrieval similarity and generation certainty
    alpha, beta = 0.6, 0.4
    final_confidence = alpha * retriever_confidence + beta * gen_confidence
    return round(final_confidence, 4)
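# Worked example: with a best retriever similarity of 0.85 and no logprobs
# (the pipeline below never passes them), the score is
# 0.6 * 0.85 + 0.4 * 0.0 = 0.51, so scores are effectively capped at alpha.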
# Main process
def process_query(user_query, input_type="text"):
    # Embed the query
    query_embedding = embedding_model.encode(user_query, normalize_embeddings=True)
    query_embedding = np.array([query_embedding], dtype=np.float32)

    # Search FAISS index for the top 5 relevant chunks
    k = 5
    distances, indices = faiss_index.search(query_embedding, k)
    retrieved_docs = [chunked_text[i] for i in indices[0]]

    # Construct context
    context = "\nExtracted documents:\n" + "".join(
        [f"Document {i}:::\n{doc}\n" for i, doc in enumerate(retrieved_docs)]
    )

    # Detect emotion
    emotion_result = emotion_classifier(user_query)[0]
    print(f"Detected emotion: {emotion_result}")
    emotion = emotion_result['label']
    value = round(emotion_result['score'], 2)

    # Define RAG prompt
    prompt_in_chat_format = [
        {"role": "user", "content": f"""
You are an AI assistant specialized in diagnosing mental disorders in humans.
Using the information contained in the context, answer the question comprehensively.
The **Diagnosed Mental Disorder** must be exactly one item from this list:
[Normal, Depression, Suicidal, Anxiety, Stress, Bi-Polar, Personality Disorder]

Your response must include:
1. **Diagnosed Mental Disorder**
2. **Detected emotion** {emotion}
3. **Intensity of emotion** {value}
4. **Matching Symptoms** from the context
5. **Personalized Treatment**
6. **Helpline Numbers**
7. **Source Link** (if applicable)

Make sure to provide a comprehensive and accurate diagnosis and explain the personalized treatment in detail.
If a disorder cannot be determined, return **Diagnosed Mental Disorder** as "Unknown".
---
Context:
{context}

Question: {user_query}"""},
        {"role": "assistant", "content": ""},
    ]

    RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
        prompt_in_chat_format, tokenize=False, add_generation_prompt=True
    )

    # Generate response
    try:
        response = READER_LLM(RAG_PROMPT_TEMPLATE)
        # print("🔍 Raw LLM output:", response)
        answer = response[0]["generated_text"] if response and "generated_text" in response[0] else "⚠️ No output generated."
    except Exception as e:
        print("❌ Error during generation:", e)
        answer = "⚠️ An error occurred while generating the response."

    # Get embeddings of retrieved docs
    retrieved_doc_embeddings = embedding_model.encode(retrieved_docs, normalize_embeddings=True)
    retrieved_doc_embeddings = np.array(retrieved_doc_embeddings, dtype=np.float32)

    # Calculate RAG-based confidence (no logprobs available from the pipeline)
    confidence_score = calculate_rag_confidence(query_embedding, retrieved_doc_embeddings)

    # Add to response
    answer += f"\n\n🧠 Accuracy & Closeness of the Answer: {confidence_score:.2f}"
    answer += "\n\n*Derived from semantic similarity and generation certainty."

    # Extract the diagnosis
    diagnosis = extract_diagnosis(answer)
    status = "fallback" if diagnosis.lower() == "unknown" else "success"

    # Log interaction (this also populates session_data["latest"])
    log_query(input_type=input_type, query=user_query, diagnosis=diagnosis,
              confidence_score=confidence_score, status=status)

    download_path = create_summary_txt(answer)

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    user_id = session_data["latest"]["user_id"]

    # Prepend session metadata to the answer string
    answer_header = f"🧾 Session ID: {user_id}\n📅 Timestamp: {timestamp}\n\n"
    return answer_header + answer, download_path
# Dashboard interface
def diagnosis_dashboard():
    try:
        df = pd.read_csv(log_file_path)
        if df.empty:
            return "No data logged yet."

        # Filter out unknown or fallback cases if needed
        df = df[df["diagnosis"].notna()]
        df = df[df["diagnosis"].str.lower() != "unknown"]

        # Diagnosis frequency
        diagnosis_counts = df["diagnosis"].value_counts().reset_index()
        diagnosis_counts.columns = ["Diagnosis", "Count"]

        # Create bar chart
        fig = px.bar(
            diagnosis_counts,
            x="Diagnosis",
            y="Count",
            color="Diagnosis",
            title="📊 Mental Health Diagnosis Distribution",
            text_auto=True
        )
        fig.update_layout(showlegend=False)
        return fig
    except Exception as e:
        return f"⚠️ Error loading dashboard: {str(e)}"
# Logging functionality (earlier version, kept for reference)
# def log_query(input_type, query, diagnosis, confidence_score, status):
#     with open(log_file_path, "a", newline="", encoding="utf-8") as f:
#         writer = csv.writer(f, quoting=csv.QUOTE_ALL)
#         writer.writerow([
#             datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
#             input_type.replace('"', '""'),
#             query.replace('"', '""'),
#             diagnosis.replace('"', '""'),
#             str(confidence_score),
#             status
#         ])
session_data = {}

def log_query(input_type, query, diagnosis, confidence_score, status):
    user_id = f"SSuser_ID_{uuid.uuid4().hex[:8]}"

    # Store in-memory session data for feedback use
    session_data["latest"] = {
        "user_id": user_id,
        "input_type": input_type,
        "query": query,
        "diagnosis": diagnosis,
        "confidence_score": confidence_score,
        "status": status
    }

    # csv.QUOTE_ALL already escapes embedded quotes, so no manual escaping is needed
    with open(log_file_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow([
            datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            user_id,
            str(input_type),
            str(query),
            str(diagnosis),
            str(confidence_score),
            str(status)
        ])
def show_logs():
    try:
        df = pd.read_csv(log_file_path)
        return df.tail(100)
    except Exception as e:
        return f"⚠️ Error: {e}"
# def create_summary_pdf(text, filename_prefix="diagnosis_report"):
#     try:
#         filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.pdf"
#         filepath = os.path.join(".", filename)  # Save in current directory
#         pdf = FPDF()
#         pdf.add_page()
#         pdf.set_font("Arial", style='B', size=14)
#         pdf.cell(200, 10, txt="🧠 Mental Health Diagnosis Report", ln=True, align='C')
#         pdf.set_font("Arial", size=12)
#         pdf.ln(10)
#         wrapped = textwrap.wrap(text, width=90)
#         for line in wrapped:
#             pdf.cell(200, 10, txt=line, ln=True)
#         pdf.output(filepath)
#         print(f"✅ PDF created at: {filepath}")
#         return filepath
#     except Exception as e:
#         print(f"❌ Error creating PDF: {e}")
#         return None
def create_summary_txt(text, filename_prefix="diagnosis_report"):
    filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.txt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    print(f"✅ TXT report created: {filename}")
    return filename
# 📥 Feedback (earlier version, kept for reference)
# feedback_data = []
# def submit_feedback(feedback, input_type, query, diagnosis, confidence_score, status):
#     feedback_id = str(uuid.uuid4())
#     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
#     with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
#         writer = csv.writer(f, quoting=csv.QUOTE_ALL)
#         writer.writerow([
#             feedback_id,
#             timestamp,
#             input_type.replace('"', '""'),
#             query.replace('"', '""'),
#             diagnosis.replace('"', '""'),
#             str(confidence_score),
#             status,
#             feedback.replace('"', '""')
#         ])
#     return f"✅ Feedback received! Your Feedback ID: {feedback_id}"
def submit_feedback(feedback):
    if "latest" not in session_data:
        return "⚠️ No diagnosis found for this session. Please get a diagnosis first."

    user_info = session_data["latest"]
    feedback_id = f"fb_{uuid.uuid4().hex[:8]}"
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow([
            feedback_id,
            timestamp,
            user_info["user_id"],
            user_info["input_type"],
            user_info["query"],
            user_info["diagnosis"],
            user_info["status"],
            feedback
        ])
    return f"✅ Feedback received! Your Feedback ID: {feedback_id}"
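# Design note: session_data["latest"] is a process-wide global, so with several
# concurrent users the feedback may attach to someone else's most recent
# diagnosis. Carrying the user_id via gr.State would be a more robust fix.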
def download_feedback_log():
    return feedback_file_path
# Email delivery (requires the `resend` package and API key; currently disabled)
# def send_email_report(to_email, response):
#     response = resend.Emails.send({
#         "from": "MentalBot <noreply@safespaceai.com>",
#         "to": [to_email],
#         "subject": "🧠 Your Personalized Mental Health Report",
#         "text": response
#     })
#     return "✅ Diagnosis report sent to your email!" if response.get("id") else "⚠️ Failed to send email."
# PDF download handler (earlier version, kept for reference)
# def unified_handler(audio, text):
#     if audio:
#         response, download_path = process_whisper_query(audio)
#     else:
#         response, download_path = process_query(text, input_type="text")
#
#     # Ensure download path is valid
#     if download_path and os.path.exists(download_path):
#         return response, download_path
#     else:
#         print("❌ PDF not found or failed to generate.")
#         return response, None
# Text report download handler
def unified_handler(audio, text):
    if not audio and not (text and text.strip()):
        return "⚠️ Please record audio or type a concern first.", None
    if audio:
        response, _ = process_whisper_query(audio)
    else:
        response, _ = process_query(text, input_type="text")
    # Regenerate the report so the downloadable file includes the session header
    download_path = create_summary_txt(response)
    return response, download_path
# Gradio UI
main_assistant_tab = gr.Interface(
    fn=unified_handler,
    inputs=[
        gr.Audio(type="filepath", label="🎙 Speak your concern"),
        gr.Textbox(lines=2, placeholder="Or type your mental health concern here...")
    ],
    outputs=[
        gr.Textbox(label="🧠 Personalized Diagnosis", lines=15, show_copy_button=True),
        gr.File(label="📥 Download Diagnosis Report")
    ],
    title="🧠 SafeSpace AI",
    description="💙 *We care for you.*\n\nSpeak or type your concern to receive AI-powered mental health insights. Get your report emailed or download it as a file."
)

dashboard_tab = gr.Interface(
    fn=diagnosis_dashboard,
    inputs=[],
    outputs=gr.Plot(label="📊 Diagnosis Distribution"),
    title="📊 Usage Dashboard"
)

logs_tab = gr.Interface(
    fn=show_logs,
    inputs=[],
    outputs=gr.Dataframe(label="📋 Diagnosis Logs (Latest 100 entries)"),
    title="📋 Logs"
)

feedback_tab = gr.Interface(
    fn=submit_feedback,
    inputs=[gr.Textbox(label="📝 Share your thoughts")],
    outputs="text",
    title="📝 Submit Feedback"
)

feedback_download_tab = gr.Interface(
    fn=download_feedback_log,
    inputs=[],
    outputs=gr.File(label="📥 Download All Feedback Logs"),
    title="📄 Download Feedback CSV"
)

agent_tab = gr.Interface(
    fn=lambda: "",
    inputs=[],
    outputs=gr.HTML(
        """<button onclick="window.open('https://jaamie-mental-health-agent.hf.space', '_blank')"
        style='padding:10px 20px; font-size:16px; background-color:#4CAF50; color:white; border:none; border-radius:5px;'>
        🧠 Launch Agent SafeSpace 001
        </button>"""
    ),
    title="🤖 Agent SafeSpace 001"
)
# Assemble the tab list
app = gr.TabbedInterface(
    interface_list=[
        main_assistant_tab,
        dashboard_tab,
        logs_tab,
        feedback_tab,
        feedback_download_tab,
        agent_tab
    ],
    tab_names=[
        "🧠 Assistant",
        "📊 Dashboard",
        "📋 Logs",
        "📝 Feedback",
        "📄 Feedback CSV",
        "🤖 Agent 001"
    ]
)
# app.launch(share=True)
print("🚀 SafeSpace AI is live!")

# Launch the Gradio app
if __name__ == "__main__":
    app.launch()