Spaces:
Sleeping
Sleeping
| from datasets import load_dataset | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, Wav2Vec2ForCTC, Wav2Vec2Processor | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| import random | |
| import faiss | |
| import json | |
| import logging | |
| import re | |
| import streamlit as st | |
| from datetime import datetime | |
| import os | |
| import torch | |
| import librosa | |
| from gtts import gTTS | |
| import tempfile | |
| import io | |
| import base64 | |
| import time | |
# Set up logging: configure the root logger at INFO once at import time and
# expose a module-scoped logger used by the classes below.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
# ============================
# AUDIO PROCESSING UTILITIES
# ============================
class AudioProcessor:
    """Speech-to-text (Wav2Vec2 CTC) and text-to-speech (gTTS) helpers.

    If the STT model fails to load, ``stt_processor`` and ``stt_model`` are
    left as ``None`` and ``speech_to_text_from_bytes`` returns a status string
    instead of raising.
    """

    # Rate (Hz) that librosa resamples to and that is declared to the processor.
    SAMPLE_RATE = 16000
    # Peak absolute amplitude below which a clip is treated as silence.
    SILENCE_THRESHOLD = 0.01

    def __init__(self):
        """Initialize audio processing components (STT processor and model)."""
        try:
            # Load Wav2Vec2 model for speech-to-text
            self.stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
            self.stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
            logger.info("β STT model loaded successfully")
        except Exception as e:
            logger.error(f"β Error loading STT model: {e}")
            self.stt_processor = None
            self.stt_model = None

    @staticmethod
    def _is_silent(samples, threshold=0.01):
        """Return True if *samples* is empty or its peak |amplitude| < *threshold*."""
        samples = np.asarray(samples)
        # np.max raises on a zero-size array, so an empty clip counts as silence.
        return samples.size == 0 or float(np.max(np.abs(samples))) < threshold

    @staticmethod
    def _remove_quietly(path):
        """Delete *path* if it is set; log instead of raising when removal fails."""
        if path is None:
            return
        try:
            os.unlink(path)
        except OSError as e:
            logger.warning(f"Could not remove temp file {path}: {e}")

    def speech_to_text_from_bytes(self, audio_bytes):
        """Convert speech to text from audio bytes.

        Args:
            audio_bytes: Encoded audio (e.g. a WAV recording) that librosa can read.

        Returns:
            The stripped transcription, or one of these status strings:
            "STT model not available", "No speech detected. Please speak louder."
            (empty or silent clip), "Could not transcribe audio" (empty decode),
            or "Error processing audio: <error>".
        """
        if self.stt_processor is None or self.stt_model is None:
            return "STT model not available"
        if not audio_bytes:
            return "No speech detected. Please speak louder."

        tmp_file_path = None
        try:
            # librosa loads from a path, so spill the bytes to a temp file.
            # Record the path first so it is cleaned up even if the write fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                tmp_file_path = tmp_file.name
                tmp_file.write(audio_bytes)

            audio_input, _sr = librosa.load(tmp_file_path, sr=self.SAMPLE_RATE)
            if self._is_silent(audio_input, self.SILENCE_THRESHOLD):
                return "No speech detected. Please speak louder."

            input_values = self.stt_processor(
                audio_input, return_tensors="pt", sampling_rate=self.SAMPLE_RATE
            ).input_values
            with torch.no_grad():
                logits = self.stt_model(input_values).logits

            # Greedy CTC decoding: most likely token per frame.
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.stt_processor.batch_decode(predicted_ids)[0].strip()
            return transcription if transcription else "Could not transcribe audio"
        except Exception as e:
            logger.error(f"Error in speech-to-text: {e}")
            return f"Error processing audio: {str(e)}"
        finally:
            # Runs on every exit path, including load/inference failures.
            self._remove_quietly(tmp_file_path)

    def text_to_speech(self, text, lang='en'):
        """Convert text to speech using gTTS.

        Returns:
            Path of a temporary MP3 file, or ``None`` if synthesis or saving failed.
            The file is created with ``delete=False`` and this method does not
            delete it on success.
        """
        tmp_path = None
        try:
            tts = gTTS(text=text, lang=lang, slow=False)
            # Reserve a path, then close the handle before gTTS writes to it.
            # Reopening a still-open NamedTemporaryFile fails on Windows.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                tmp_path = tmp_file.name
            tts.save(tmp_path)
            return tmp_path
        except Exception as e:
            logger.error(f"Error in text-to-speech: {e}")
            # Don't leave an empty or partial MP3 behind.
            self._remove_quietly(tmp_path)
            return None
# ============================
# DATA PREPARATION
# ============================
def prepare_dataset(sample_size=1000, min_length=10, seed=None):
    """Load the tweet_eval "emotion" dataset and convert a sample to RAG records.

    Args:
        sample_size: Maximum number of training rows to sample (default 1000).
        min_length: Records whose cleaned text has this many characters or
            fewer are dropped (default 10).
        seed: Optional seed for a reproducible sample. ``None`` uses the global
            ``random`` state, as before.

    Returns:
        A list of dicts with keys ``text`` (cleaned), ``emotion`` (one of
        "anger", "joy", "optimism", "sadness") and ``original_text``. If loading
        or processing fails, a four-entry fallback list with the same keys is
        returned instead.
    """
    # Names indexed by the dataset's integer ``label`` column.
    emotion_labels = ["anger", "joy", "optimism", "sadness"]

    def clean_text(text):
        """Lower-case *text* and strip URLs, punctuation, digits and extra whitespace."""
        text = text.lower()
        text = re.sub(r"http\S+", "", text)  # remove URLs
        text = re.sub(r"[^\w\s]", "", text)  # remove special characters
        text = re.sub(r"\d+", "", text)  # remove numbers
        text = re.sub(r"\s+", " ", text)  # normalize whitespace
        return text.strip()

    # A private Random keeps a seeded sample from disturbing the global RNG.
    rng = random if seed is None else random.Random(seed)

    try:
        print("π Loading emotion dataset...")
        ds = load_dataset("cardiffnlp/tweet_eval", "emotion")
        train_data = ds['train']

        # Sample row indices rather than materialising every row with list().
        n_rows = len(train_data)
        indices = rng.sample(range(n_rows), max(0, min(sample_size, n_rows)))

        rag_json = []
        for idx in indices:
            row = train_data[idx]
            label = row['label']
            # A negative id would silently index from the end of the list, and a
            # too-large one would abort the whole load, so skip such rows.
            if not 0 <= label < len(emotion_labels):
                continue
            cleaned_text = clean_text(row['text'])
            if len(cleaned_text) > min_length:  # filter out very short texts
                rag_json.append({
                    "text": cleaned_text,
                    "emotion": emotion_labels[label],
                    "original_text": row['text'],
                })
        print(f"Dataset prepared with {len(rag_json)} samples")
        return rag_json
    except Exception as e:
        print(f"Warning: Could not load dataset: {e}")
        # Minimal fallback with the same keys as the loaded records.
        fallback = [
            ("feeling happy and excited", "joy"),
            ("really angry and frustrated", "anger"),
            ("sad and lonely today", "sadness"),
            ("optimistic about the future", "optimism"),
        ]
        return [
            {"text": text, "emotion": emotion, "original_text": text}
            for text, emotion in fallback
        ]
# ============================
# FIXED EMOTION DETECTION MODEL
# ============================
class EmotionDetector:
    """Transformer emotion classifier with a keyword-based fallback.

    The candidate checkpoints in ``model_options`` are tried in order, and the
    first one that loads is used. Predictions are collapsed onto four
    categories: anger, joy, optimism and sadness. If no model loads, or
    classification fails, a whole-word keyword count is used instead.
    """

    # Collapses the label sets of the candidate models onto our four categories.
    # Unknown labels default to 'optimism' in detect_emotion.
    EMOTION_MAPPING = {
        'anger': 'anger',
        'disgust': 'sadness',
        'neutral': 'optimism',
        'joy': 'joy',
        'love': 'joy',
        'happiness': 'joy',
        'sadness': 'sadness',
        'fear': 'sadness',
        'surprise': 'optimism',
        'optimism': 'optimism',
        # Labels from SemEval-2018-style label sets (presumably emitted by the
        # cardiffnlp checkpoint). Previously 'pessimism' fell through to
        # 'optimism'.
        'pessimism': 'sadness',
        'anticipation': 'optimism',
        'trust': 'optimism',
        # Additional mappings for different model outputs
        'positive': 'joy',
        'negative': 'sadness',
        'admiration': 'joy',
        'amusement': 'joy',
        'annoyance': 'anger',
        'approval': 'optimism',
        'caring': 'joy',
        'confusion': 'sadness',
        'curiosity': 'optimism',
        'desire': 'optimism',
        'disappointment': 'sadness',
        'disapproval': 'anger',
        'embarrassment': 'sadness',
        'excitement': 'joy',
        'gratitude': 'joy',
        'grief': 'sadness',
        'nervousness': 'sadness',
        'pride': 'joy',
        'realization': 'optimism',
        'relief': 'joy',
        'remorse': 'sadness',
    }

    # Keyword lists for the rule-based fallback. The order matters: ties go to
    # the earliest emotion.
    FALLBACK_KEYWORDS = {
        'joy': ['happy', 'joy', 'excited', 'thrilled', 'wonderful', 'amazing', 'great', 'fantastic', 'love', 'awesome'],
        'anger': ['angry', 'mad', 'furious', 'annoyed', 'frustrated', 'irritated', 'hate', 'terrible', 'awful'],
        'sadness': ['sad', 'depressed', 'upset', 'down', 'lonely', 'miserable', 'disappointed', 'heartbroken'],
        'optimism': ['hope', 'optimistic', 'positive', 'confident', 'believe', 'future', 'better', 'improve'],
    }

    def __init__(self):
        """Load the first candidate model that works, else enable the fallback."""
        # Try multiple emotion models in order of preference
        self.model_options = [
            "j-hartmann/emotion-english-distilroberta-base",
            "cardiffnlp/twitter-roberta-base-emotion-latest",
            "nateraw/bert-base-uncased-emotion",
            "michellejieli/emotion_text_classifier",
        ]
        self.model = None
        self.tokenizer = None
        self.classifier = None

        for model_name in self.model_options:
            try:
                st.info(f"π Trying to load {model_name}...")
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                # No device_map, explicit dtype and low_cpu_mem_usage=False:
                # these avoid the meta-tensor loading issues.
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    model_name,
                    device_map=None,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=False,
                )
                # Move to CPU explicitly if needed
                if torch.cuda.is_available():
                    self.model = self.model.to('cpu')
                # The pipeline's default output for one string is
                # [{'label': ..., 'score': ...}] (top-1 only).
                self.classifier = pipeline(
                    "text-classification",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    device=-1,  # Force CPU usage
                )
                st.success(f"β Successfully loaded {model_name}")
                break
            except Exception as e:
                st.warning(f"β οΈ Failed to load {model_name}: {str(e)}")
                # Drop any partially loaded state from the failed attempt.
                self.model = None
                self.tokenizer = None
                self.classifier = None

        # Fall back to simple rule-based detection if all models fail
        self.use_fallback = self.classifier is None
        if self.use_fallback:
            st.warning("β οΈ All emotion models failed. Using rule-based fallback.")

    def detect_emotion_fallback(self, text):
        """Rule-based emotion detection by whole-word keyword counts.

        Whole-word matching avoids substring false positives such as "made"
        counting as "mad" or "download" counting as "down".

        Returns:
            ``(emotion, confidence)``: the emotion with the most keyword hits and
            ``min(0.3 * hits + 0.4, 0.9)``, or ``('optimism', 0.5)`` when nothing
            matches (including ``None`` or empty text).
        """
        words = set(re.findall(r"[a-z']+", (text or "").lower()))
        emotion_scores = {
            emotion: sum(1 for keyword in keywords if keyword in words)
            for emotion, keywords in self.FALLBACK_KEYWORDS.items()
        }
        best = max(emotion_scores, key=emotion_scores.get)
        if emotion_scores[best] > 0:
            return best, min(emotion_scores[best] * 0.3 + 0.4, 0.9)
        return 'optimism', 0.5

    def detect_emotion(self, text):
        """Detect emotion from *text*, using the keyword fallback when needed.

        Returns:
            ``(emotion, confidence)`` where emotion is one of anger, joy,
            optimism or sadness. The fallback is used when no model is loaded,
            the text is blank or ``None``, or the classifier raises.
        """
        if text is None:
            text = ""
        if self.use_fallback or not text.strip():
            return self.detect_emotion_fallback(text)
        try:
            # truncation=True keeps over-length inputs from raising in the model.
            result = self.classifier(text, truncation=True)
            top = result[0]
            emotion = top['label'].lower()
            confidence = top['score']
            return self.EMOTION_MAPPING.get(emotion, 'optimism'), confidence
        except Exception as e:
            logger.error(f"Error in emotion detection: {e}")
            # Fall back to rule-based detection
            return self.detect_emotion_fallback(text)
# ============================
# LIGHTWEIGHT EMOTION DETECTOR (ALTERNATIVE)
# ============================
class LightweightEmotionDetector:
    """A simple, reliable emotion detector that doesn't rely on heavy models.

    Each emotion is scored by whole-word keyword hits (+1 each) and phrase
    hits (+2 each). Intensifier words then raise the winner's confidence.
    """

    # Intensity modifiers. Each distinct one present adds 0.05 to confidence.
    INTENSIFIERS = ('very', 'really', 'extremely', 'so', 'absolutely', 'totally', 'completely')

    def __init__(self):
        # Enhanced keyword-based emotion detection
        self.emotion_patterns = {
            'joy': {
                'keywords': ['happy', 'joy', 'joyful', 'excited', 'thrilled', 'wonderful', 'amazing', 'great', 'fantastic',
                             'love', 'awesome', 'brilliant', 'perfect', 'delighted', 'cheerful', 'elated', 'glad', 'pleased'],
                'phrases': ['feel good', 'so happy', 'really excited', 'love it', 'makes me happy', 'feeling great']
            },
            'anger': {
                'keywords': ['angry', 'mad', 'furious', 'annoyed', 'frustrated', 'irritated', 'hate', 'terrible', 'awful',
                             'disgusting', 'outraged', 'livid', 'enraged', 'pissed', 'infuriated', 'resentful'],
                'phrases': ['so angry', 'really mad', 'hate it', 'makes me angry', 'fed up', 'sick of']
            },
            'sadness': {
                'keywords': ['sad', 'depressed', 'upset', 'down', 'lonely', 'miserable', 'disappointed', 'heartbroken',
                             'devastated', 'hopeless', 'melancholy', 'sorrowful', 'dejected', 'despondent', 'gloomy'],
                'phrases': ['feel sad', 'so down', 'really upset', 'makes me sad', 'feeling low', 'broken hearted']
            },
            'optimism': {
                'keywords': ['hope', 'hopeful', 'optimistic', 'positive', 'confident', 'believe', 'future', 'better',
                             'improve', 'progress', 'opportunity', 'potential', 'bright', 'promising', 'encouraging'],
                'phrases': ['looking forward', 'things will get better', 'positive about', 'have hope', 'bright future']
            }
        }

    def detect_emotion(self, text):
        """Detect emotion using whole-word keyword and phrase matching.

        Args:
            text: Input text. ``None`` or whitespace-only text yields the
                blank-input default.

        Returns:
            ``(emotion, confidence)``. Blank input gives ``('optimism', 0.5)``.
            No hits gives ``('optimism', 0.6)``. Otherwise the confidence is
            ``min(min(0.2 * score + 0.5, 0.95) + 0.05 * intensifiers, 0.98)``.
        """
        if text is None or not text.strip():
            return 'optimism', 0.5

        text_lower = text.lower()
        # Whole-word matching. Plain substring tests counted "made" as "mad",
        # "download" as "down" and "every"/"also" as intensifiers.
        words = set(re.findall(r"[a-z']+", text_lower))

        emotion_scores = {}
        for emotion, patterns in self.emotion_patterns.items():
            keyword_hits = sum(1 for keyword in patterns['keywords'] if keyword in words)
            phrase_hits = sum(
                1 for phrase in patterns['phrases']
                if re.search(r"\b" + re.escape(phrase) + r"\b", text_lower)
            )
            # Phrase matches carry double weight.
            emotion_scores[emotion] = keyword_hits + 2 * phrase_hits

        intensity_boost = sum(1 for word in self.INTENSIFIERS if word in words) * 0.5

        # Get the emotion with highest score (ties go to the earliest emotion)
        if max(emotion_scores.values()) > 0:
            detected_emotion = max(emotion_scores, key=emotion_scores.get)
            base_confidence = min(emotion_scores[detected_emotion] * 0.2 + 0.5, 0.95)
            confidence = min(base_confidence + intensity_boost * 0.1, 0.98)
        else:
            detected_emotion = 'optimism'  # Default to optimism
            confidence = 0.6
        return detected_emotion, confidence
# ============================
# RAG SYSTEM WITH FAISS
# ============================
class RAGSystem:
    """
    Retrieval-Augmented Generation (RAG) system for selecting text templates
    based on user input and detected emotion.

    Each entry of ``rag_data`` must have a ``text`` key and should have an
    ``emotion`` key. If embedding setup fails, ``embed_model``, ``embeddings``
    and ``index`` are all ``None`` and retrieval returns ``[]``.
    """

    def __init__(self, rag_data):
        """Embed every entry's ``text`` with MiniLM and build a FAISS L2 index."""
        self.rag_data = rag_data
        self.texts = [entry['text'] for entry in rag_data]
        self.embed_model = None
        self.embeddings = None
        self.index = None

        if not self.texts:
            st.warning("β οΈ No RAG data available. Using simple responses.")
            return

        try:
            # Initialize embedding model
            self.embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            # Create embeddings
            self.embeddings = self.embed_model.encode(
                self.texts,
                convert_to_numpy=True,
                show_progress_bar=False
            )
            # Create FAISS index
            dimension = self.embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(self.embeddings)
        except Exception as e:
            st.warning(f"β οΈ Could not initialize RAG system: {e}")
            self.embed_model = None
            self.embeddings = None
            self.index = None

    def retrieve_templates(self, user_input, detected_emotion, top_k=3):
        """Return up to *top_k* stored texts nearest to *user_input*.

        Candidates are restricted to entries whose ``emotion`` equals
        *detected_emotion*. If none match, all entries are considered. Results
        are ordered by increasing squared L2 distance between embeddings.

        Returns:
            A list of texts, or ``[]`` when the system is not initialised,
            ``top_k`` is not positive, or an error occurs.
        """
        # Explicit None checks: torch container modules define __len__, so the
        # truthiness of a model object is not a presence test.
        if self.embed_model is None or self.index is None:
            return []
        try:
            if top_k <= 0:
                return []
            # Filter by emotion first
            candidate_ids = [
                i for i, entry in enumerate(self.rag_data)
                if entry.get('emotion') == detected_emotion
            ]
            if not candidate_ids:
                candidate_ids = list(range(len(self.rag_data)))

            candidates = self.embeddings[candidate_ids]
            query = self.embed_model.encode([user_input], convert_to_numpy=True)

            # Exact brute-force squared-L2 search, the same metric as
            # faiss.IndexFlatL2, without rebuilding an index on every call.
            sq_dists = np.sum((candidates - query) ** 2, axis=1)
            k = min(top_k, len(candidate_ids))
            nearest = np.argsort(sq_dists, kind="stable")[:k]
            return [self.texts[candidate_ids[j]] for j in nearest]
        except Exception as e:
            logger.error(f"Error in template retrieval: {e}")
            return []
# ============================
# RESPONSE GENERATOR
# ============================
class ResponseGenerator:
    """Build empathetic replies from a detected emotion plus optional RAG context.

    ``emotion_detector`` must expose ``detect_emotion(text)`` returning
    ``(emotion, confidence)``. ``rag_system`` may be ``None``, or an object with
    an ``embed_model`` attribute and a
    ``retrieve_templates(user_input, emotion, top_k=...)`` method.
    """

    # Maximum number of characters of a retrieved template quoted in a reply.
    CONTEXT_CHARS = 80

    def __init__(self, emotion_detector, rag_system):
        self.emotion_detector = emotion_detector
        self.rag_system = rag_system
        # Empathetic response templates by emotion
        self.response_templates = {
            'anger': [
                "I can understand why you're feeling frustrated. It's completely valid to feel this way.",
                "Your anger is understandable. Sometimes situations can be really challenging.",
                "I hear that you're upset, and that's okay. These feelings are important."
            ],
            'sadness': [
                "I'm sorry you're going through a difficult time. Your feelings are valid.",
                "It sounds like you're dealing with something really tough right now.",
                "I can sense your sadness, and I want you to know that it's okay to feel this way."
            ],
            'joy': [
                "I'm so happy to hear about your positive experience! That's wonderful.",
                "Your joy is contagious! It's great to hear such positive news.",
                "I love hearing about things that make you happy. That sounds amazing!"
            ],
            'optimism': [
                "Your positive outlook is inspiring. That's a great way to look at things.",
                "I appreciate your hopeful perspective. That's really encouraging.",
                "It's wonderful to hear your optimistic thoughts. Keep that positive energy!"
            ],
            'neutral': [
                "Thanks for sharing that. I hear you.",
                "I understand. Let's continue exploring this topic together.",
                "I appreciate you telling me that. Let's keep going."
            ]
        }

    @staticmethod
    def _shorten(text, limit=80):
        """Trim *text* to at most *limit* chars, breaking at a space when possible."""
        text = text.strip()
        if len(text) <= limit:
            return text
        cut = text[:limit]
        # The cut already falls on a word boundary.
        if text[limit].isspace() or cut[-1].isspace():
            return cut.rstrip()
        head, sep, _tail = cut.rpartition(" ")
        # A single over-long word has no space to break at, so hard-cut it.
        return head.rstrip() if sep else cut

    def generate_response(self, user_input, top_k=3):
        """Generate an empathetic response using RAG context when available.

        Returns:
            ``(response_with_disclaimer, emotion, confidence)``. On any exception
            the response is an apology that includes the error text, with
            emotion ``'neutral'`` and confidence ``0.0``.
        """
        try:
            # Step 1: Detect emotion
            detected_emotion, confidence = self.emotion_detector.detect_emotion(user_input)

            # Step 2: Retrieve relevant templates (if RAG is available). Use
            # explicit None checks: torch container modules define __len__, so
            # truthiness is not a presence test.
            templates = []
            rag = self.rag_system
            if rag is not None and getattr(rag, "embed_model", None) is not None:
                templates = rag.retrieve_templates(
                    user_input,
                    detected_emotion,
                    top_k=top_k
                )

            # Step 3: Create response using templates and emotion
            base_responses = self.response_templates.get(
                detected_emotion,
                self.response_templates['optimism']
            )
            selected_base = random.choice(base_responses)

            # Quote a word-boundary snippet of a retrieved template, if any.
            snippet = self._shorten(random.choice(templates), self.CONTEXT_CHARS) if templates else ""
            if snippet:
                response = f"{selected_base} I can relate to what you're sharing - {snippet}. Remember that your feelings are important and valid."
            else:
                response = selected_base

            # Add disclaimer
            disclaimer = "\n\nβ οΈ This is an automated response. For serious emotional concerns, please consult a mental health professional."
            return response + disclaimer, detected_emotion, confidence
        except Exception as e:
            error_msg = f"I apologize, but I encountered an error: {str(e)}"
            disclaimer = "\n\nβ οΈ This is an automated response. Please consult a professional if needed."
            return error_msg + disclaimer, 'neutral', 0.0
# ============================
# SIMPLE RESPONSE GENERATOR (FALLBACK)
# ============================
class SimpleResponseGenerator:
    """Template-only reply builder used when no RAG system is available.

    Picks a canned reply for the detected emotion, then appends a short
    follow-up depending on the length and content of the user's message.
    """

    def __init__(self, emotion_detector):
        """Store the detector and build the per-emotion reply templates.

        Args:
            emotion_detector: Object exposing ``detect_emotion(text)`` that
                returns an ``(emotion, confidence)`` pair.
        """
        self.emotion_detector = emotion_detector

        anger_replies = [
            "I can understand why you're feeling frustrated. It's completely valid to feel this way. Sometimes situations can be really challenging, and it's important to acknowledge these feelings.",
            "Your anger is understandable. When things don't go as expected, it's natural to feel upset. Would you like to talk about what's causing these feelings?",
            "I hear that you're upset, and that's okay. These feelings are important and deserve attention. Take a moment to breathe if you need it."
        ]
        sadness_replies = [
            "I'm sorry you're going through a difficult time. Your feelings are valid, and it's okay to feel sad sometimes. Remember that this feeling will pass.",
            "It sounds like you're dealing with something really tough right now. I want you to know that it's perfectly normal to feel this way, and you're not alone.",
            "I can sense your sadness, and I want you to know that it's okay to feel this way. Sometimes life presents us with challenges that naturally make us feel down."
        ]
        joy_replies = [
            "I'm so happy to hear about your positive experience! That's wonderful, and your joy is really uplifting. It's great when life gives us these beautiful moments.",
            "Your joy is contagious! It's amazing to hear such positive news. These happy moments are precious and worth celebrating.",
            "I love hearing about things that make you happy. That sounds absolutely amazing! Your enthusiasm is really inspiring."
        ]
        optimism_replies = [
            "Your positive outlook is truly inspiring. That's such a great way to look at things, and your hopefulness is really encouraging.",
            "I appreciate your hopeful perspective. That kind of optimism can make such a difference, not just for you but for others around you too.",
            "It's wonderful to hear your optimistic thoughts. Keep that positive energy flowing - it's a powerful force for good!"
        ]

        self.response_templates = {
            'anger': anger_replies,
            'sadness': sadness_replies,
            'joy': joy_replies,
            'optimism': optimism_replies,
        }

    def generate_response(self, user_input, top_k=3):
        """Return ``(reply, emotion, confidence)`` for *user_input* without RAG.

        ``top_k`` is accepted for signature parity with
        ``ResponseGenerator.generate_response`` and is not used. Emotions without
        templates reuse the 'optimism' replies. On any exception the reply is an
        apology containing the error text, with emotion ``'optimism'`` and
        confidence ``0.0``.
        """
        try:
            emotion, score = self.emotion_detector.detect_emotion(user_input)
            candidates = self.response_templates.get(emotion, self.response_templates['optimism'])
            reply = random.choice(candidates)

            # Long messages get an acknowledgement; requests for help get an
            # invitation to continue.
            if len(user_input) > 100:
                follow_up = " I can see you've shared quite a bit with me, and I appreciate your openness."
            elif any(cue in user_input.lower() for cue in ('help', 'advice', 'what should')):
                follow_up = " If you'd like to talk more about this, I'm here to listen."
            else:
                follow_up = ""

            footer = "\n\nβ οΈ This is an automated response. For serious emotional concerns, please consult a mental health professional."
            return reply + follow_up + footer, emotion, score
        except Exception as exc:
            apology = "I apologize, but I encountered an error: " + str(exc)
            footer = "\n\nβ οΈ This is an automated response. Please consult a professional if needed."
            return apology + footer, 'optimism', 0.0
| # ============================ | |
| # STREAMLIT APP | |
| # ============================ | |
| def main(): | |
| # Page config with better settings | |
| st.set_page_config( | |
| page_title="Empathetic AI Companion", | |
| page_icon="π€", | |
| layout="wide", | |
| initial_sidebar_state="expanded" | |
| ) | |
| # CSS with modern design | |
| st.markdown(""" | |
| <style> | |
| /* Import Google Fonts */ | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); | |
| /* Global styles */ | |
| .stApp { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| font-family: 'Inter', sans-serif; | |
| } | |
| /* Main header - more elegant */ | |
| .main-header { | |
| background: rgba(255, 255, 255, 0.15); | |
| padding: 2rem; | |
| border-radius: 20px; | |
| text-align: center; | |
| margin-bottom: 2rem; | |
| backdrop-filter: blur(20px); | |
| border: 1px solid rgba(255, 255, 255, 0.2); | |
| color: white; | |
| box-shadow: 0 8px 32px rgba(0,0,0,0.1); | |
| transition: all 0.3s ease; | |
| } | |
| .main-header:hover { | |
| transform: translateY(-5px); | |
| box-shadow: 0 12px 40px rgba(0,0,0,0.2); | |
| } | |
| .main-header h1 { | |
| font-size: 2.5rem; | |
| font-weight: 700; | |
| margin-bottom: 0.5rem; | |
| background: linear-gradient(45deg, #fff, #f0f0f0); | |
| -webkit-background-clip: text; | |
| -webkit-text-fill-color: transparent; | |
| } | |
| .main-header p { | |
| font-size: 1.2rem; | |
| opacity: 0.9; | |
| font-weight: 400; | |
| margin: 0; | |
| } | |
| /* Improved chat messages */ | |
| .chat-message { | |
| margin-bottom: 1.5rem; | |
| animation: fadeInUp 0.5s ease; | |
| } | |
| @keyframes fadeInUp { | |
| from { opacity: 0; transform: translateY(20px); } | |
| to { opacity: 1; transform: translateY(0); } | |
| } | |
| .user-message { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| color: white; | |
| padding: 1rem 1.5rem; | |
| border-radius: 20px 20px 5px 20px; | |
| margin-left: auto; | |
| margin-right: 0; | |
| max-width: 75%; | |
| box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3); | |
| font-weight: 500; | |
| line-height: 1.5; | |
| } | |
| .bot-message { | |
| background: linear-gradient(to top, #a18cd1 0%, #fbc2eb 100%);; | |
| color: white; | |
| padding: 1rem 1.5rem; | |
| border-radius: 20px 20px 20px 5px; | |
| margin-left: 0; | |
| margin-right: auto; | |
| max-width: 75%; | |
| box-shadow: 0 4px 15px rgba(240, 147, 251, 0.3); | |
| font-weight: 500; | |
| line-height: 1.5; | |
| } | |
| /* Message headers */ | |
| .message-header { | |
| font-size: 0.85rem; | |
| opacity: 0.9; | |
| margin-bottom: 0.5rem; | |
| font-weight: 600; | |
| } | |
| /* Emotion badges - hidden but styled */ | |
| .emotion-badge { | |
| display: inline-block; | |
| padding: 0.2rem 0.6rem; | |
| border-radius: 12px; | |
| font-size: 0.75rem; | |
| font-weight: 600; | |
| margin-left: 0.5rem; | |
| opacity: 0.8; | |
| } | |
| /* Enhanced buttons */ | |
| .stButton > button { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| color: white !important; | |
| border: none !important; | |
| border-radius: 50px !important; | |
| padding: 1rem 2rem !important; | |
| font-weight: 600 !important; | |
| font-size: 1rem !important; | |
| transition: all 0.3s ease !important; | |
| box-shadow: 0 6px 20px rgba(102, 126, 234, 0.3) !important; | |
| min-height: 50px !important; | |
| } | |
| .stButton > button:hover { | |
| transform: translateY(-3px) !important; | |
| box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4) !important; | |
| background: linear-gradient(135deg, #7c8ff0 0%, #8a5ab8 100%) !important; | |
| } | |
| /* Play button styling */ | |
| .play-button { | |
| background: linear-gradient(135deg, #28a745 0%, #20c997 100%) !important; | |
| border-radius: 25px !important; | |
| padding: 0.5rem 1rem !important; | |
| font-size: 0.9rem !important; | |
| margin-top: 0.5rem !important; | |
| box-shadow: 0 4px 15px rgba(40, 167, 69, 0.3) !important; | |
| } | |
| /* Sidebar enhancements */ | |
| .css-1d391kg { | |
| background: rgba(255, 255, 255, 0.1) !important; | |
| backdrop-filter: blur(20px) !important; | |
| } | |
| /* Stats and metrics */ | |
| .metric-card { | |
| background: rgba(255, 255, 255, 0.9); | |
| padding: 1.5rem; | |
| border-radius: 15px; | |
| text-align: center; | |
| box-shadow: 0 4px 15px rgba(0,0,0,0.05); | |
| margin-bottom: 1rem; | |
| transition: transform 0.3s ease; | |
| } | |
| .metric-card:hover { | |
| transform: translateY(-3px); | |
| } | |
| /* Progress bars */ | |
| .stProgress > div > div > div { | |
| background: linear-gradient(90deg, #667eea, #764ba2) !important; | |
| border-radius: 10px !important; | |
| } | |
| /* Hide default Streamlit elements */ | |
| .stDeployButton {display: none;} | |
| footer {visibility: hidden;} | |
| .stApp > header {visibility: hidden;} | |
| /* Custom scrollbar */ | |
| .chat-container::-webkit-scrollbar { | |
| width: 6px; | |
| } | |
| /* π Audio recorder container fix */ | |
| .audio-recorder-container { | |
| background: transparent !important; | |
| border: none !important; | |
| box-shadow: none !important; | |
| padding: 0 !important; | |
| margin: 0 !important; | |
| } | |
| /* π€ Recorder button style */ | |
| .audio-recorder-container button { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| color: #fff !important; | |
| border: none !important; | |
| border-radius: 50% !important; /* Makes it a perfect circle */ | |
| width: 60px !important; | |
| height: 60px !important; | |
| font-size: 1.2rem !important; | |
| font-weight: bold !important; | |
| cursor: pointer !important; | |
| box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important; | |
| transition: all 0.3s ease !important; | |
| } | |
| /* Hover effect */ | |
| .audio-recorder-container button:hover { | |
| transform: scale(1.08); | |
| box-shadow: 0 6px 18px rgba(0,0,0,0.35) !important; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # Enhanced Header with animation | |
| st.markdown(""" | |
| <div class="main-header"> | |
| <h1>π€ Empathetic AI Companion</h1> | |
| <p>Your intelligent partner for emotional support and meaningful conversations</p> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Initialize session state | |
| if "chat_history" not in st.session_state: | |
| st.session_state.chat_history = [] | |
| if "initialized" not in st.session_state: | |
| initialize_chatbot() | |
| if "audio_processor" not in st.session_state: | |
| st.session_state.audio_processor = AudioProcessor() | |
| if "last_transcription" not in st.session_state: | |
| st.session_state.last_transcription = "" | |
| # Enhanced Sidebar | |
| with st.sidebar: | |
| st.markdown("### ποΈ Control Panel") | |
| # Voice Settings Section | |
| with st.expander("ποΈ Voice Settings", expanded=True): | |
| tts_language = st.selectbox( | |
| "Text-to-Speech ptions", | |
| options=['en', 'es', 'fr', 'de', 'it'], | |
| index=0, | |
| help="Choose your preferred TTS accent" | |
| ) | |
| st.session_state.tts_language = tts_language | |
| auto_tts = st.toggle( | |
| "Auto-play Bot Responses", | |
| value=False, | |
| help="Automatically play TTS for all bot responses" | |
| ) | |
| st.session_state.auto_tts = auto_tts | |
| st.divider() | |
| # Enhanced Statistics Section | |
| if st.session_state.chat_history: | |
| with st.expander("π Session Analytics", expanded=False): | |
| emotions = [chat['emotion'] for chat in st.session_state.chat_history if 'emotion' in chat] | |
| if emotions: | |
| emotion_counts = {} | |
| for emotion in emotions: | |
| emotion_counts[emotion] = emotion_counts.get(emotion, 0) + 1 | |
| # Display emotion distribution | |
| for emotion, count in emotion_counts.items(): | |
| percentage = (count / len(emotions)) * 100 | |
| st.metric( | |
| f"{emotion.title()}", | |
| f"{count} messages", | |
| f"{percentage:.1f}%" | |
| ) | |
| # Quick Actions | |
| with st.expander("β‘ Quick Actions", expanded=True): | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| if st.button("π§ͺ Test AI", use_container_width=True): | |
| test_emotion_detection() | |
| with col2: | |
| if st.button("ποΈ Clear Chat", use_container_width=True): | |
| st.session_state.chat_history = [] | |
| st.session_state.last_transcription = "" | |
| st.rerun() | |
| st.divider() | |
| # Sample Messages - More engaging | |
| with st.expander("π‘ Try These Messages", expanded=False): | |
| sample_messages = [ | |
| ("π", "I'm feeling really happy today!"), | |
| ("π€", "I'm so frustrated with everything"), | |
| ("π’", "I feel really sad and alone"), | |
| ("π", "I'm excited about my future!") | |
| ] | |
| for i, (emoji, msg) in enumerate(sample_messages): | |
| if st.button(f"{emoji} {msg[:20]}...", key=f"sample_{i}", use_container_width=True): | |
| process_message(msg) | |
| st.rerun() | |
| st.divider() | |
| # Enhanced Info Section | |
| st.markdown(""" | |
| <div style="background: rgba(255,255,255,0.1); padding: 1rem; border-radius: 10px; backdrop-filter: blur(10px);"> | |
| <h4 style="color: white; margin-bottom: 0.5rem;">β¨ Features</h4> | |
| <ul style="color: rgba(255,255,255,0.9); font-size: 0.9rem; margin: 0;"> | |
| <li>π€ Voice Recording & STT</li> | |
| <li>π Natural TTS Responses</li> | |
| <li>π Advanced Emotion AI</li> | |
| <li>π¬ Context-Aware Replies</li> | |
| <li>π Real-time Analytics</li> | |
| </ul> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Main Layout - Improved | |
| col_main, col_stats = st.columns([7, 3]) | |
| with col_main: | |
| # Enhanced Chat Display | |
| st.markdown('<div class="chat-container">', unsafe_allow_html=True) | |
| if st.session_state.chat_history: | |
| for i, chat in enumerate(st.session_state.chat_history[-15:]): # Show more messages | |
| # User message with better styling | |
| st.markdown(f""" | |
| <div class="chat-message"> | |
| <div class="user-message"> | |
| <div class="message-header">π§ You β’ {chat['timestamp']}</div> | |
| {chat['user']} | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Bot response with enhanced styling | |
| emotion_class = chat.get('emotion', 'optimism') | |
| confidence = chat.get('confidence', 0.0) | |
| st.markdown(f""" | |
| <div class="chat-message"> | |
| <div class="bot-message"> | |
| <div class="message-header"> | |
| π€ AI Assistant | |
| <span class="emotion-badge {emotion_class}"> | |
| {emotion_class.title()} {confidence:.0%} | |
| </span> | |
| </div> | |
| {chat['bot'].replace('β οΈ', 'β οΈ ')} | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Enhanced TTS button | |
| col_tts, col_spacer = st.columns([2, 6]) | |
| with col_tts: | |
| if st.button(f"π Play Audio", key=f"tts_{i}", help="Listen to response"): | |
| play_tts(chat['bot']) | |
| # Auto-play logic | |
| if (st.session_state.auto_tts and | |
| i == len(st.session_state.chat_history) - 1 and | |
| chat.get('should_play_tts', False)): | |
| play_tts(chat['bot']) | |
| st.session_state.chat_history[-1]['should_play_tts'] = False | |
| # Enhanced Input Section | |
| st.markdown('<div class="input-section">', unsafe_allow_html=True) | |
| # Input layout | |
| col_text = st.container() | |
| col_voice, col_send = st.columns(2) | |
| with col_text: | |
| user_input = st.text_input( | |
| "", | |
| placeholder="Share what's on your mind... How can I help you today?", | |
| label_visibility="collapsed", | |
| key="main_input" | |
| ) | |
| from audio_recorder_streamlit import audio_recorder | |
| with col_voice: | |
| audio_bytes = audio_recorder() | |
| if audio_bytes: | |
| st.audio(audio_bytes, format="audio/wav") | |
| with col_send: | |
| if st.button("π€ Send Message", type="primary", key="send_btn", use_container_width=True): | |
| if user_input.strip(): | |
| process_message(user_input.strip()) | |
| st.rerun() | |
| # Voice processing with better feedback | |
| if audio_bytes is not None: | |
| with st.spinner("π Processing your voice..."): | |
| transcription = st.session_state.audio_processor.speech_to_text_from_bytes(audio_bytes) | |
| if transcription and transcription not in ["No speech detected. Please speak louder.", "Could not transcribe audio"]: | |
| st.success(f"ποΈ **Transcribed:** \"{transcription}\"") | |
| if transcription != st.session_state.last_transcription: | |
| st.session_state.last_transcription = transcription | |
| process_message(transcription, from_voice=True) | |
| st.rerun() | |
| else: | |
| st.warning(f"β οΈ {transcription}") | |
| st.markdown('</div>', unsafe_allow_html=True) | |
| # Enhanced Statistics Panel | |
| with col_stats: | |
| if st.session_state.chat_history: | |
| st.markdown("### π Live Insights") | |
| # Emotion trends | |
| recent_emotions = [ | |
| chat.get('emotion', 'optimism') | |
| for chat in st.session_state.chat_history[-10:] | |
| if 'emotion' in chat | |
| ] | |
| if recent_emotions: | |
| st.markdown("**Recent Emotions:**") | |
| emotion_scores = {'anger': 0, 'sadness': 0, 'joy': 0, 'optimism': 0} | |
| for emotion in recent_emotions: | |
| emotion_scores[emotion] = emotion_scores.get(emotion, 0) + 1 | |
| total = len(recent_emotions) | |
| for emotion, count in emotion_scores.items(): | |
| if count > 0: | |
| progress = count / total | |
| st.progress(progress, text=f"{emotion.title()}: {count}/{total}") | |
| # Session metrics | |
| if len(st.session_state.chat_history) > 2: | |
| st.divider() | |
| st.markdown("**Session Overview:**") | |
| total_messages = len(st.session_state.chat_history) | |
| emotions = [chat.get('emotion', 'optimism') for chat in st.session_state.chat_history] | |
| # Metrics cards | |
| st.metric("Messages", total_messages) | |
| if emotions: | |
| most_common = max(set(emotions), key=emotions.count) | |
| st.metric("Dominant Emotion", most_common.title()) | |
| # Mood indicator | |
| positive_emotions = ['joy', 'optimism'] | |
| positive_count = sum(1 for e in emotions if e in positive_emotions) | |
| mood_score = positive_count / len(emotions) | |
| if mood_score > 0.6: | |
| st.success("π Positive Mood") | |
| elif mood_score > 0.4: | |
| st.info("π Balanced Mood") | |
| else: | |
| st.warning("π Needs Support") | |
| else: | |
| # Getting started tips | |
| st.markdown(""" | |
| ### π Getting Started | |
| **Tips for better conversations:** | |
| - Be specific about your feelings | |
| - Share context about your situation | |
| - Use voice input for natural interaction | |
| - Try the sample messages below | |
| **Privacy Note:** | |
| Your conversations are processed locally and not stored permanently. | |
| """) | |
def initialize_chatbot():
    """Load every AI component into ``st.session_state`` while showing progress.

    Steps, in order:
      1. ``st.session_state.rag_data`` <- ``prepare_dataset()``
      2. ``st.session_state.emotion_detector`` <- ``EmotionDetector()``
      3. ``st.session_state.rag_system`` <- ``RAGSystem(rag_data)``
      4. ``st.session_state.response_generator`` <- ``ResponseGenerator(detector, rag_system)``

    On success ``st.session_state.initialized`` is set to True, the progress
    widgets are removed and a success message is shown.

    On any exception the traceback is logged, the half-filled progress widgets
    are removed, an error plus a hint are shown, and the run is halted with
    ``st.stop()``.
    """
    with st.spinner("π Loading AI models..."):
        # Pre-bind so the failure path can clean up even if creating a widget raised.
        progress_bar = None
        status_text = None
        try:
            progress_bar = st.progress(0)
            status_text = st.empty()

            # Load dataset
            status_text.text("π Loading emotion dataset...")
            progress_bar.progress(25)
            st.session_state.rag_data = prepare_dataset()

            # Initialize emotion detector
            status_text.text("π§ Loading emotion detection model...")
            progress_bar.progress(50)
            st.session_state.emotion_detector = EmotionDetector()

            # Initialize RAG system
            status_text.text("π Setting up knowledge retrieval...")
            progress_bar.progress(75)
            st.session_state.rag_system = RAGSystem(st.session_state.rag_data)

            # Initialize response generator
            status_text.text("π¬ Preparing response generation...")
            progress_bar.progress(100)
            st.session_state.response_generator = ResponseGenerator(
                st.session_state.emotion_detector,
                st.session_state.rag_system
            )

            st.session_state.initialized = True

            # Clear loading elements
            progress_bar.empty()
            status_text.empty()
            st.success("β AI Companion ready! Start your conversation below.")
        except Exception as e:
            # Keep the traceback in the server log; the UI only shows str(e).
            logger.exception("Chatbot initialization failed")
            # Don't leave a stalled progress bar / status line next to the error.
            for placeholder in (progress_bar, status_text):
                if placeholder is not None:
                    placeholder.empty()
            st.error(f"β Failed to initialize: {str(e)}")
            st.info("π‘ Try refreshing the page or check your internet connection.")
            st.stop()
def process_message(user_input, from_voice=False):
    """Generate a bot reply for *user_input* and append it to the chat history.

    Args:
        user_input: The user's message text. ``None``/empty or
            whitespace-only input is ignored.
        from_voice: True when the text came from speech transcription. It is
            stored on the chat entry and reflected in the log line.

    Side effects:
        Calls ``st.session_state.response_generator.generate_response`` with
        ``top_k=3``. It then appends a dict with keys ``user``, ``bot``,
        ``emotion``, ``confidence``, ``timestamp`` (``HH:MM``), ``from_voice``
        and ``should_play_tts`` (copied from ``st.session_state.auto_tts``,
        default False) to ``st.session_state.chat_history``.

    Errors are not raised. They are shown in the UI and logged with their
    traceback.
    """
    if not user_input or not user_input.strip():
        return
    try:
        # Show typing indicator while the reply is being generated
        with st.spinner("π€ AI is thinking..."):
            bot_response, detected_emotion, confidence = st.session_state.response_generator.generate_response(
                user_input,
                top_k=3
            )

        chat_entry = {
            'user': user_input,
            'bot': bot_response,
            'emotion': detected_emotion,
            'confidence': confidence,
            'timestamp': datetime.now().strftime("%H:%M"),
            'from_voice': from_voice,
            'should_play_tts': st.session_state.get('auto_tts', False)
        }
        st.session_state.chat_history.append(chat_entry)

        # Lazy %-style args: if the message cannot be formatted, the logging
        # module reports it itself instead of raising here, after the entry
        # has already been saved.
        logger.info(
            "User (%s): %s... | Emotion: %s (%.2f)",
            'Voice' if from_voice else 'Text',
            user_input[:50],
            detected_emotion,
            confidence,
        )
    except Exception as e:
        st.error(f"β Something went wrong: {str(e)}")
        st.info("π‘ Please try again or rephrase your message.")
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception("Processing error: %s", e)
def _clean_tts_text(text, max_chars=500):
    """Reduce *text* to speakable characters and cap it at *max_chars*.

    Every character other than a word character, whitespace, or one of
    ``. , ! ? '`` (emoji, markup, ``-``, ``/``, dashes, ...) is replaced by a
    space rather than deleted. Words joined by a symbol therefore stay
    separate ("self-care" -> "self care"). Runs of spaces are collapsed, and
    the result is stripped, truncated and right-stripped again.

    Returns an empty string for ``None``/empty input or text with nothing
    speakable.
    """
    if not text:
        return ""
    speakable = text.replace('β οΈ', '')
    speakable = re.sub(r"[^\w\s.,!?']", " ", speakable)
    speakable = re.sub(r" {2,}", " ", speakable).strip()
    return speakable[:max_chars].rstrip()


def play_tts(text):
    """Synthesize *text* and play it in the page via ``st.audio`` (autoplay).

    The text is cleaned and capped at 500 characters by ``_clean_tts_text``.
    Nothing happens if that leaves an empty string.

    The language is read from ``st.session_state.tts_language`` (default
    ``'en'``). ``st.session_state.audio_processor.text_to_speech`` returns a
    path to an audio file (falsy when nothing was produced). That file is read,
    played, and always deleted afterwards, even if reading or playback fails.

    Errors are logged and reported with a toast instead of being raised.
    """
    try:
        clean_text = _clean_tts_text(text)
        if not clean_text:
            return

        tts_lang = st.session_state.get('tts_language', 'en')
        with st.spinner("π Generating audio..."):
            audio_file = st.session_state.audio_processor.text_to_speech(
                clean_text,
                lang=tts_lang
            )
        if not audio_file:
            return

        try:
            with open(audio_file, 'rb') as f:
                audio_bytes = f.read()
            st.audio(audio_bytes, format='audio/mp3', autoplay=True)
        finally:
            # Previously the unlink was skipped whenever read/playback raised,
            # leaking one file per failed attempt.
            try:
                os.unlink(audio_file)
            except OSError:
                logger.warning("Could not delete TTS temp file %s", audio_file)
    except Exception as e:
        logger.error(f"TTS error: {e}")
        st.toast("β οΈ Could not generate audio", icon="π")
def test_emotion_detection():
    """Render a demo that runs the emotion detector on a few fixed sentences.

    For each sample sentence the page shows:
      * the sentence itself;
      * the detected label (title-cased) with its confidence as a percentage;
      * an emoji for that label, with a fallback for labels outside the map.

    Samples are separated by dividers, with none after the last one. The
    detector is read from ``st.session_state.emotion_detector`` for each
    sentence.
    """
    demo_sentences = (
        "I'm absolutely thrilled about my new promotion!",
        "I feel completely overwhelmed and sad today",
        "This traffic is making me so angry and frustrated!",
        "I have hope that everything will work out perfectly",
    )
    # Label -> emoji lookup; constant, so it is built once instead of per sample.
    label_to_emoji = {'anger': 'π ', 'sadness': 'π’', 'joy': 'π', 'optimism': 'π'}
    last_index = len(demo_sentences) - 1

    st.markdown("### π§ͺ Emotion Detection Demo")
    for idx, sentence in enumerate(demo_sentences):
        with st.container():
            label, score = st.session_state.emotion_detector.detect_emotion(sentence)
            text_col, emoji_col = st.columns([3, 1])
            with text_col:
                st.write(f"**Text:** {sentence}")
                st.write(f"**Detected:** {label.title()} ({score:.1%} confidence)")
            with emoji_col:
                icon = label_to_emoji.get(label, 'π€')
                st.markdown(f"### {icon}")
            if idx != last_index:
                st.divider()
if __name__ == "__main__":
    # Entry point: build the app only when this file is executed directly,
    # not when it is imported as a module.
    main()