Spaces:
Sleeping
Sleeping
import html
import json
import logging
import os
import random
import re
import textwrap
from collections import Counter
from datetime import datetime

import faiss
import numpy as np
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# Module logger for this app. basicConfig below sets the root logger to INFO
# (it is a no-op when the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
# ============================
# DATA PREPARATION
# ============================
def prepare_dataset(sample_size=1000, seed=None, min_length=10):
    """Load the TweetEval ``emotion`` config and turn a random sample into RAG records.

    Args:
        sample_size: upper bound on the number of training rows sampled
            (default 1000).
        seed: optional seed for a private ``random.Random`` so the sample is
            reproducible. ``None`` (default) samples from the module-level
            ``random`` generator.
        min_length: records whose cleaned text has ``min_length`` characters
            or fewer are dropped (default 10).

    Returns:
        list of dicts with keys ``text`` (cleaned), ``emotion`` (label name)
        and ``original_text`` (the raw tweet).
    """
    logger.info("Loading emotion dataset...")
    ds = load_dataset("cardiffnlp/tweet_eval", "emotion")
    train_data = ds['train']

    # Prefer the label names shipped with the dataset's ClassLabel feature;
    # fall back to the known order of this config.
    label_feature = train_data.features.get('label')
    emotion_labels = list(
        getattr(label_feature, 'names', None) or ["anger", "joy", "optimism", "sadness"]
    )

    def clean_text(text):
        """Lower-case *text* and strip URLs, punctuation, digits and extra whitespace."""
        text = text.lower()
        text = re.sub(r"http\S+", "", text)  # remove URLs
        text = re.sub(r"[^\w\s]", "", text)  # remove special characters
        text = re.sub(r"\d+", "", text)  # remove numbers
        text = re.sub(r"\s+", " ", text)  # normalize whitespace
        return text.strip()

    rng = random if seed is None else random.Random(seed)
    # Sample row positions rather than materializing the whole split as a list.
    positions = rng.sample(range(len(train_data)), min(sample_size, len(train_data)))

    rag_json = []
    for row in train_data.select(positions):
        cleaned_text = clean_text(row['text'])
        if len(cleaned_text) > min_length:  # filter out very short texts
            rag_json.append({
                "text": cleaned_text,
                "emotion": emotion_labels[row['label']],
                "original_text": row['text'],
            })

    logger.info("Dataset prepared with %d samples", len(rag_json))
    return rag_json
# ============================
# EMOTION DETECTION MODEL
# ============================
class EmotionDetector:
    """Classify text into one of this app's four emotion categories.

    Wraps a Hugging Face ``text-classification`` pipeline built on the
    ``bhadresh-savani/distilbert-base-uncased-emotion`` checkpoint and folds its
    labels onto the categories used by the RAG data
    (anger / joy / optimism / sadness).
    """

    # Model label (lower-cased) -> app category. Unknown labels use FALLBACK_EMOTION.
    EMOTION_MAPPING = {
        'anger': 'anger',
        'joy': 'joy',
        'love': 'joy',
        'happiness': 'joy',
        'sadness': 'sadness',
        'fear': 'sadness',
        'surprise': 'optimism',
        'optimism': 'optimism',
    }
    # Category reported for unknown labels and, with FALLBACK_CONFIDENCE, on errors.
    FALLBACK_EMOTION = 'optimism'
    FALLBACK_CONFIDENCE = 0.5

    def __init__(self):
        """Load the tokenizer and model and build the classification pipeline.

        Raises:
            Exception: whatever the loaders raise; it is logged here and
                re-raised so the caller decides how to present it.
        """
        self.model_name = "bhadresh-savani/distilbert-base-uncased-emotion"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
            self.classifier = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                return_all_scores=False,
            )
        except Exception:
            # No UI call here: the caller already reports initialization failures.
            logger.exception("Error loading emotion model %s", self.model_name)
            raise

    def detect_emotion(self, text):
        """Return ``(category, confidence)`` for *text*.

        The category is the top model label mapped through EMOTION_MAPPING and
        the confidence is that label's score. On any error the exception is
        logged and ``(FALLBACK_EMOTION, FALLBACK_CONFIDENCE)`` is returned.
        """
        try:
            # truncation=True: inputs longer than the model's maximum sequence
            # length would otherwise make the pipeline raise.
            top = self.classifier(text, truncation=True)[0]
            emotion = top['label'].lower()
            return self.EMOTION_MAPPING.get(emotion, self.FALLBACK_EMOTION), top['score']
        except Exception:
            logger.exception("Error in emotion detection")
            return self.FALLBACK_EMOTION, self.FALLBACK_CONFIDENCE
# ============================
# RAG SYSTEM WITH FAISS
# ============================
class RAGSystem:
    """Emotion-filtered semantic retrieval over the prepared RAG records.

    Each record's ``text`` is embedded once with
    ``sentence-transformers/all-MiniLM-L6-v2`` and stored in exact (flat L2)
    FAISS indexes: one over the whole corpus (``self.index``) and one per
    emotion label. All of them are built in ``__init__`` so that a query never
    re-indexes data.
    """

    def __init__(self, rag_data):
        """Embed *rag_data* and build the FAISS indexes.

        Args:
            rag_data: non-empty list of dicts with at least ``'text'`` and
                ``'emotion'`` keys (e.g. the records from ``prepare_dataset``).

        Raises:
            ValueError: if *rag_data* is empty (there is nothing to index).
        """
        if not rag_data:
            raise ValueError("rag_data must contain at least one record")
        self.rag_data = rag_data
        self.texts = [entry['text'] for entry in rag_data]

        # Initialize embedding model
        self.embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # FAISS expects a C-contiguous float32 matrix.
        self.embeddings = np.ascontiguousarray(
            self.embed_model.encode(self.texts, convert_to_numpy=True, show_progress_bar=False),
            dtype=np.float32,
        )
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

        # emotion -> (corpus positions, index over exactly those rows).
        # Positions stay in corpus order, matching a filter over the full corpus.
        positions_by_emotion = {}
        for position, entry in enumerate(rag_data):
            positions_by_emotion.setdefault(entry['emotion'], []).append(position)
        self._emotion_indexes = {}
        for emotion, positions in positions_by_emotion.items():
            emotion_index = faiss.IndexFlatL2(dimension)
            emotion_index.add(self.embeddings[positions])
            self._emotion_indexes[emotion] = (positions, emotion_index)

    def retrieve_templates(self, user_input, detected_emotion, top_k=3):
        """Return up to *top_k* corpus texts closest to *user_input*, nearest first.

        Only records labelled *detected_emotion* are searched; if there are
        none, the whole corpus is searched instead. ``top_k <= 0`` returns ``[]``.
        """
        if top_k <= 0:
            return []
        if detected_emotion in self._emotion_indexes:
            positions, index = self._emotion_indexes[detected_emotion]
        else:
            positions, index = range(len(self.texts)), self.index

        query = np.ascontiguousarray(
            self.embed_model.encode([user_input], convert_to_numpy=True),
            dtype=np.float32,
        )
        _, hits = index.search(query, min(top_k, index.ntotal))
        # FAISS pads missing results with -1; skip them.
        return [self.texts[positions[hit]] for hit in hits[0] if hit >= 0]
# ============================
# RESPONSE GENERATOR
# ============================
class ResponseGenerator:
    """Compose an empathetic reply from the detected emotion and retrieved examples.

    A reply is a randomly chosen canned opener for the detected emotion plus
    (when retrieval returns anything) a short snippet of one retrieved text,
    followed by a safety disclaimer.
    """

    # Appended to every successful reply.
    DISCLAIMER = (
        "\n\n⚠️ This is an automated response. For serious emotional concerns, "
        "please consult a mental health professional."
    )
    # Appended to the generic error reply.
    ERROR_DISCLAIMER = "\n\n⚠️ This is an automated response. Please consult a professional if needed."
    # Maximum length of the retrieved snippet quoted back to the user.
    SNIPPET_WIDTH = 80
    SNIPPET_PLACEHOLDER = "…"

    def __init__(self, emotion_detector, rag_system):
        """
        Args:
            emotion_detector: object with ``detect_emotion(text) -> (emotion, confidence)``.
            rag_system: object with ``retrieve_templates(user_input, emotion, top_k=...)``
                returning a list of strings.
        """
        self.emotion_detector = emotion_detector
        self.rag_system = rag_system
        # Empathetic response openers by emotion
        self.response_templates = {
            'anger': [
                "I can understand why you're feeling frustrated. It's completely valid to feel this way.",
                "Your anger is understandable. Sometimes situations can be really challenging.",
                "I hear that you're upset, and that's okay. These feelings are important.",
            ],
            'sadness': [
                "I'm sorry you're going through a difficult time. Your feelings are valid.",
                "It sounds like you're dealing with something really tough right now.",
                "I can sense your sadness, and I want you to know that it's okay to feel this way.",
            ],
            'joy': [
                "I'm so happy to hear about your positive experience! That's wonderful.",
                "Your joy is contagious! It's great to hear such positive news.",
                "I love hearing about things that make you happy. That sounds amazing!",
            ],
            'optimism': [
                "Your positive outlook is inspiring. That's a great way to look at things.",
                "I appreciate your hopeful perspective. That's really encouraging.",
                "It's wonderful to hear your optimistic thoughts. Keep that positive energy!",
            ],
        }

    def _snippet(self, text):
        """Shorten *text* to SNIPPET_WIDTH characters on a word boundary."""
        snippet = textwrap.shorten(text, width=self.SNIPPET_WIDTH, placeholder=self.SNIPPET_PLACEHOLDER)
        if snippet == self.SNIPPET_PLACEHOLDER:
            # A single token longer than the width: fall back to a hard cut.
            snippet = text[: self.SNIPPET_WIDTH - 1] + self.SNIPPET_PLACEHOLDER
        return snippet

    def generate_response(self, user_input, top_k=3):
        """Return ``(reply, emotion, confidence)`` for *user_input*.

        On any internal error the exception is logged (not shown to the user)
        and a generic apology is returned with emotion ``'neutral'`` and
        confidence ``0.0``.
        """
        try:
            detected_emotion, confidence = self.emotion_detector.detect_emotion(user_input)
            templates = self.rag_system.retrieve_templates(
                user_input, detected_emotion, top_k=top_k
            )
            base_responses = self.response_templates.get(
                detected_emotion,
                self.response_templates['optimism'],
            )
            selected_base = random.choice(base_responses)
            if templates:
                snippet = self._snippet(random.choice(templates))
                response = (
                    f"{selected_base} I can relate to what you're sharing - {snippet}. "
                    "Remember that your feelings are important and valid."
                )
            else:
                response = selected_base
            return response + self.DISCLAIMER, detected_emotion, confidence
        except Exception:
            # Keep internals (exception text, paths) out of the chat window.
            logger.exception("Failed to generate response")
            error_msg = "I apologize, but I encountered an error while generating a response."
            return error_msg + self.ERROR_DISCLAIMER, 'neutral', 0.0
# ============================
# STREAMLIT APP
# ============================
def main():
    """Render the whole Streamlit page.

    Streamlit re-runs the script top to bottom on every interaction, so
    anything that must persist between runs (chat history, the response
    generator) lives in ``st.session_state``.
    """
    # Kept as the first Streamlit call of the run (older Streamlit versions require it).
    st.set_page_config(
        page_title="Empathetic Chatbot",
        page_icon="💬",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    _inject_css()
    _render_header()

    # Initialize session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "initialized" not in st.session_state:
        initialize_chatbot()

    _render_sidebar()

    chat_column, about_column = st.columns([3, 1])
    with chat_column:
        _render_chat()
    with about_column:
        _render_about()


# Page-wide CSS: gradient background, chat bubbles and emotion badge colours.
# `<style>` starts at column 0 so markdown treats it as raw HTML, not a code block.
_APP_CSS = """
<style>
.stApp {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.main-header {
    background: rgba(255,255,255,0.1);
    padding: 1rem;
    border-radius: 10px;
    text-align: center;
    margin-bottom: 2rem;
}
.chat-message {
    padding: 1rem;
    margin: 0.5rem 0;
    border-radius: 10px;
    background: rgba(255,255,255,0.9);
    max-width: 70%; /* limit bubble width */
}
.user-message {
    background: rgba(100, 149, 237, 0.2);
    margin-left: auto; /* push to right */
    margin-right: 1rem; /* spacing from edge */
    text-align: left; /* keep text aligned inside bubble */
}
.bot-message {
    background: rgba(50, 205, 50, 0.1);
    margin-left: 1rem; /* spacing from left edge */
    margin-right: auto; /* push to left */
    text-align: left;
}
.emotion-badge {
    display: inline-block;
    padding: 0.25rem 0.5rem;
    border-radius: 15px;
    font-size: 0.8rem;
    font-weight: bold;
    margin: 0.25rem;
}
.anger { background-color: #ffebee; color: #c62828; }
.sadness { background-color: #e3f2fd; color: #1565c0; }
.joy { background-color: #f3e5f5; color: #7b1fa2; }
.optimism { background-color: #e8f5e8; color: #2e7d32; }
.neutral { background-color: #eeeeee; color: #424242; }
</style>
"""


def _inject_css():
    """Inject the page-wide CSS."""
    st.markdown(_APP_CSS, unsafe_allow_html=True)


def _render_header():
    """Draw the title banner."""
    st.markdown(
        '<div class="main-header">'
        "<h1>💬 Emotion-Aware Empathetic Chatbot</h1>"
        "<p>Your AI companion for emotional support and understanding</p>"
        "</div>",
        unsafe_allow_html=True,
    )


def _render_sidebar():
    """Draw the sidebar: emotion statistics, test/clear buttons and sample messages."""
    with st.sidebar:
        st.header("🎛️ Controls")

        emotion_counts = Counter(
            chat['emotion'] for chat in st.session_state.chat_history if 'emotion' in chat
        )
        if emotion_counts:
            st.subheader("📊 Emotion Statistics")
            for emotion, count in emotion_counts.items():
                st.markdown(
                    f'<span class="emotion-badge {html.escape(emotion)}">'
                    f"{html.escape(emotion.title())}: {count}</span>",
                    unsafe_allow_html=True,
                )
        st.markdown("---")

        if st.button("🧪 Test Emotion Detection"):
            test_emotion_detection()
        if st.button("🗑️ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()
        st.markdown("---")

        st.subheader("💡 Try these sample messages:")
        sample_messages = [
            "I'm feeling really happy today!",
            "I'm so frustrated with everything",
            "I feel really sad and alone",
            "I'm really surprised about what happened!",
        ]
        for msg in sample_messages:
            if st.button(f"📝 {msg[:20]}...", key=f"sample_{msg}"):
                process_message(msg)
                # Rerun so the statistics above include the new message.
                st.rerun()


def _as_html_text(text):
    """Escape *text* for embedding in HTML, turning newlines into <br> tags.

    Converting newlines matters: a blank line inside the markdown source
    would end the raw-HTML block and break the bubble.
    """
    return html.escape(str(text)).replace("\n", "<br>")


def _user_bubble_html(chat):
    """Return the HTML for the user's side of one chat-history entry."""
    return (
        '<div class="chat-message user-message">'
        f"<strong>🧑 You ({_as_html_text(chat['timestamp'])}):</strong><br>"
        f"💭 {_as_html_text(chat['user'])}"
        "</div>"
    )


def _bot_bubble_html(chat):
    """Return the HTML for the bot's side of one chat-history entry, with its emotion badge."""
    emotion_class = chat.get('emotion', 'optimism')
    confidence = chat.get('confidence', 0.0)
    return (
        '<div class="chat-message bot-message">'
        "<strong>🤖 Bot:</strong> "
        f'<span class="emotion-badge {html.escape(emotion_class)}">'
        f"{html.escape(emotion_class.title())} ({confidence:.2f})</span><br>"
        f"💭 {_as_html_text(chat['bot'])}"
        "</div>"
    )


def _render_chat():
    """Draw the last ten exchanges (oldest first) followed by the message form."""
    st.subheader("💭 Chat Interface")
    with st.container():
        for chat in st.session_state.chat_history[-10:]:
            # User text and bot text are escaped: they are rendered with unsafe_allow_html.
            st.markdown(_user_bubble_html(chat), unsafe_allow_html=True)
            st.markdown(_bot_bubble_html(chat), unsafe_allow_html=True)
    st.markdown("---")

    # A form submits on Enter and clears the box afterwards, so a second click
    # on Send cannot re-send the previous message.
    with st.form("chat_form", clear_on_submit=True):
        user_input = st.text_input(
            "Your message:",
            placeholder="Type how you're feeling...",
            key="user_input",
            help="Share your thoughts and emotions",
        )
        submitted = st.form_submit_button("Send 📤", type="primary", use_container_width=True)
    if submitted:
        if user_input.strip():
            process_message(user_input)
            st.rerun()
        else:
            st.warning("⚠️ Please enter a message.")


def _render_about():
    """Draw the static 'About' panel."""
    st.subheader("ℹ️ About")
    st.info(
        "This chatbot uses:\n"
        "- **Emotion Detection**: Identifies your emotional state\n"
        "- **RAG System**: Retrieves relevant response templates\n"
        "- **Empathetic Responses**: Provides caring, supportive replies"
    )
@st.cache_resource(show_spinner=False)
def _build_response_generator():
    """Load the dataset and models and wire them into a ResponseGenerator.

    Cached with ``st.cache_resource``, so the expensive loading happens once per
    server process instead of once per browser session. The returned object is
    shared by all sessions; it keeps no per-user data.
    """
    rag_data = prepare_dataset()
    emotion_detector = EmotionDetector()
    rag_system = RAGSystem(rag_data)
    return ResponseGenerator(emotion_detector, rag_system)


def initialize_chatbot():
    """Attach the (cached) ResponseGenerator to this session and mark it initialized.

    On failure the exception is logged, an error is shown, and the current
    script run is halted with ``st.stop()``.
    """
    try:
        with st.spinner("🔄 Initializing chatbot systems..."):
            response_generator = _build_response_generator()
        # Store in session state
        st.session_state.response_generator = response_generator
        st.session_state.initialized = True
        st.success("✅ Chatbot initialized successfully!")
    except Exception as e:
        logger.exception("Chatbot initialization failed")
        st.error(f"❌ Initialization failed: {e}")
        st.stop()
def process_message(user_input):
    """Generate a reply to *user_input* and append the exchange to the chat history.

    Each history entry is a dict with ``user``, ``bot``, ``emotion``,
    ``confidence`` and an ``HH:MM:SS`` ``timestamp`` (local time of the
    machine running the app). If this session has no response generator yet,
    nothing is recorded and a warning is logged.
    """
    if "response_generator" not in st.session_state:
        logger.warning("process_message called before the chatbot was initialized; message dropped")
        return
    with st.spinner("🤖 Generating response..."):
        response, emotion, confidence = st.session_state.response_generator.generate_response(user_input)
    st.session_state.chat_history.append({
        "user": user_input,
        "bot": response,
        "emotion": emotion,
        "confidence": confidence,
        "timestamp": datetime.now().strftime("%H:%M:%S"),
    })
def test_emotion_detection():
    """Run the session's emotion detector on a few fixed messages and show the results.

    Does nothing if this session has no response generator yet.
    """
    if "response_generator" not in st.session_state:
        return
    detector = st.session_state.response_generator.emotion_detector
    test_messages = [
        "I'm so happy today!",
        "I feel terrible and sad",
        "This makes me really angry",
        "I'm really surprised about what happened!",
    ]
    st.subheader("🧪 Emotion Detection Test Results")
    for msg in test_messages:
        emotion, confidence = detector.detect_emotion(msg)
        # Two trailing spaces before "\n" force a markdown line break; a plain
        # newline would join both fields on one line.
        st.markdown(
            f'**Message:** "{msg}"  \n'
            f'**Emotion:** <span class="emotion-badge {emotion}">'
            f"{emotion.title()} ({confidence:.3f})</span>",
            unsafe_allow_html=True,
        )
if __name__ == "__main__":
    # Build the page only when this file is executed as the main script.
    main()