import streamlit as st
import pandas as pd
import numpy as np
import json
import time
from datetime import datetime
import requests
from typing import Dict, List, Any
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
from io import StringIO
import uuid
from dotenv import load_dotenv

# Load environment variables from a local .env file (no-op when absent)
load_dotenv()

# Project-local modules
from eda_analyzer import EDAAnalyzer
from database_manager import DatabaseManager

# Streamlit page configuration
st.set_page_config(
    page_title="⚡ Neural Data Analyst Premium",
    page_icon="⚡",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS.
# NOTE(review): the style rules were lost when this file was recovered from a
# collapsed diff; the empty markdown call is kept so the call site is harmless.
st.markdown("""
""", unsafe_allow_html=True)


class NeuralDataAnalyst:
    """Streamlit app orchestrating data upload, EDA, and Groq-backed AI analysis."""

    def __init__(self):
        # DatabaseManager may fail to initialise (e.g. missing DB file or
        # driver); fall back to session-state history instead of crashing.
        try:
            self.db_manager = DatabaseManager()
        except Exception:
            self.db_manager = None

        self.eda_analyzer = EDAAnalyzer()
        self.initialize_session_state()

    def initialize_session_state(self):
        """Initialize all session-state keys and resolve the Groq API key.

        The key is resolved from, in order of preference:
        1. Streamlit secrets (deployment), 2. the process environment,
        3. a forced re-read of the .env file.
        """
        api_key = None

        try:
            if hasattr(st, 'secrets') and 'GROQ_API_KEY' in st.secrets:
                api_key = st.secrets['GROQ_API_KEY']
                print("API key loaded from Streamlit secrets")
            elif 'GROQ_API_KEY' in os.environ:
                api_key = os.environ['GROQ_API_KEY']
                print("API key loaded from environment variable")
            else:
                load_dotenv(override=True)  # Force reload of .env
                if 'GROQ_API_KEY' in os.environ:
                    api_key = os.environ['GROQ_API_KEY']
                    print("API key loaded from .env file")
        except Exception as e:
            print(f"Error loading API key: {e}")

        # Debug output; the key is masked when long enough
        if api_key:
            print(f"✅ API key found: {api_key[:10]}...{api_key[-5:] if len(api_key) > 15 else api_key}")
        else:
            print("❌ No API key found")

        # Seed every session-state key exactly once
        defaults = {
            'api_key': api_key or "",
            'api_connected': bool(api_key),
            'uploaded_data': None,
            'data_schema': "",
            'analysis_history': [],
            'session_id': str(uuid.uuid4()),
            'selected_model': "llama-3.3-70b-versatile",
            'example_query': "",
            'recent_queries': [],
            'show_eda_results': False,
            'show_ai_insights': False,
            'show_advanced_analytics': False,
            'eda_results': None,
            'ai_insights_text': None,
            'show_model_selection': False,
            'current_query': "",
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

        # Force a test connection if we have an API key but haven't connected
        if api_key and not st.session_state.api_connected:
            print("Testing API connection...")
            self.test_api_connection_silent(api_key, st.session_state.selected_model)
but haven't connected + if api_key and not st.session_state.api_connected: + print("Testing API connection...") + self.test_api_connection_silent(api_key, st.session_state.selected_model) + + def render_header(self): + """Render the main header""" + st.markdown('

⚡ NEURAL DATA ANALYST

', unsafe_allow_html=True) + st.markdown('

Premium AI-Powered Business Intelligence Suite

', unsafe_allow_html=True) + + def render_api_config(self): + """Render API configuration section""" + with st.sidebar: + st.markdown("## 🔐 Neural Configuration") + + # Debug section + with st.expander("🔧 Debug Info", expanded=False): + st.write(f"API Key in session: {'Yes' if st.session_state.api_key else 'No'}") + if st.session_state.api_key: + st.write(f"API Key (masked): {st.session_state.api_key[:10]}...{st.session_state.api_key[-5:]}") + st.write(f"API Connected: {st.session_state.api_connected}") + st.write(f"Environment GROQ_API_KEY: {'Set' if os.environ.get('GROQ_API_KEY') else 'Not set'}") + + # Test button + if st.button("🔄 Reload API Key", key="reload_api"): + from dotenv import load_dotenv + load_dotenv(override=True) + api_key = os.environ.get('GROQ_API_KEY') + if api_key: + st.session_state.api_key = api_key + # Test the API key immediately + test_success = self.test_api_connection_silent(api_key, st.session_state.selected_model) + if test_success: + st.session_state.api_connected = True + st.success("✅ API key reloaded and tested successfully!") + else: + st.error("❌ API key loaded but connection test failed") + st.rerun() + else: + st.error("No API key found in .env file") + + if st.button("🧪 Test API Connection", key="test_api"): + if st.session_state.api_key: + with st.spinner("Testing API connection..."): + success = self.test_api_connection_silent(st.session_state.api_key, st.session_state.selected_model) + if success: + st.session_state.api_connected = True + st.success("✅ API connection successful!") + else: + st.error("❌ API connection failed") + else: + st.error("No API key to test") + + # Check if API key is configured + has_api_key = bool(st.session_state.api_key) + + if has_api_key: + st.success("✅ API Key loaded from environment") + + # Model selection + model = st.selectbox( + "AI Model", + [ + "llama-3.3-70b-versatile", + "llama3-70b-8192", + "mixtral-8x7b-32768", + "gemma2-9b-it", + "qwen-qwq-32b", + "deepseek-r1-distill-llama-70b" + 
], + index=0, + key="model_selector" + ) + st.session_state.selected_model = model + + # Connection status + if st.session_state.api_connected: + st.markdown('
⚡ Neural Link: Active
', unsafe_allow_html=True) + else: + st.markdown('
⚡ Neural Link: Connecting...
', unsafe_allow_html=True) + + else: + st.error("❌ No API key configured") + st.markdown(""" + **Setup Required:** + + **For local development:** + Create `.env` file: + ``` + GROQ_API_KEY=your_api_key_here + ``` + + **For Streamlit Cloud:** + Add to app secrets: + ```toml + GROQ_API_KEY = "your_api_key_here" + ``` + + **Get API key:** [Groq Console](https://console.groq.com/keys) + """) + + # History section + st.markdown("---") + st.markdown("## 📊 Analysis History") + + if st.button("🗂️ View History", key="view_history"): + self.show_history() + + if st.button("🗑️ Clear History", key="clear_history"): + if self.db_manager: + self.db_manager.clear_history(st.session_state.session_id) + st.success("History cleared!") + else: + st.session_state.analysis_history = [] + st.success("History cleared!") + + def generate_database_schema(self, df: pd.DataFrame) -> Dict[str, str]: + """Generate database schema from DataFrame""" + schema_parts = [] + table_name = "uploaded_data" + + # Start table definition + schema_parts.append(f"CREATE TABLE {table_name} (") + + column_definitions = [] + for col in df.columns: + # Clean column name (remove spaces, special chars) + clean_col = col.replace(' ', '_').replace('-', '_').replace('.', '_') + clean_col = ''.join(c for c in clean_col if c.isalnum() or c == '_') + + # Determine SQL data type + dtype = df[col].dtype + if pd.api.types.is_integer_dtype(dtype): + sql_type = "INTEGER" + elif pd.api.types.is_float_dtype(dtype): + sql_type = "DECIMAL(10,2)" + elif pd.api.types.is_datetime64_any_dtype(dtype): + sql_type = "DATETIME" + elif pd.api.types.is_bool_dtype(dtype): + sql_type = "BOOLEAN" + else: + # Check if it's a short text field + max_length = df[col].astype(str).str.len().max() if not df[col].empty else 50 + if max_length <= 50: + sql_type = f"VARCHAR({max(50, int(max_length))})" + else: + sql_type = "TEXT" + + column_definitions.append(f" {clean_col} {sql_type}") + + schema_parts.append(",\n".join(column_definitions)) + 
schema_parts.append(");") + + schema = "\n".join(schema_parts) + + # Add simple format for AI queries + simple_schema = f"{table_name}(" + ", ".join([ + f"{col.replace(' ', '_').replace('-', '_').replace('.', '_')}" + for col in df.columns + ]) + ")" + + return { + "sql_schema": schema, + "simple_schema": simple_schema + } + + def test_api_connection_silent(self, api_key: str, model: str) -> bool: + """Test API connection silently and return success status""" + try: + response = requests.post( + "https://api.groq.com/openai/v1/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + }, + json={ + "model": model, + "messages": [{"role": "user", "content": "Say 'OK' in one word."}], + "temperature": 0.1, + "max_tokens": 10 + }, + timeout=10 + ) + success = response.status_code == 200 + if success: + st.session_state.api_connected = True + print("✅ API connection test successful") + else: + print(f"❌ API connection test failed: {response.status_code}") + return success + except Exception as e: + print(f"❌ API connection test error: {e}") + return False + + def render_data_upload(self): + """Render data upload section""" + st.markdown("## 📊 Data Upload & Analysis") + + uploaded_file = st.file_uploader( + "Choose a CSV or JSON file", + type=['csv', 'json'], + help="Upload your data file for comprehensive analysis" + ) + + if uploaded_file is not None: + try: + if uploaded_file.name.endswith('.csv'): + df = pd.read_csv(uploaded_file) + elif uploaded_file.name.endswith('.json'): + df = pd.read_json(uploaded_file) + + st.session_state.uploaded_data = df + + # Generate and store database schema + schema_info = self.generate_database_schema(df) + st.session_state.data_schema = schema_info["simple_schema"] + + # Success message with file info + st.success(f"✅ {uploaded_file.name} loaded successfully!") + + # Key metrics + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("📊 Rows", f"{len(df):,}") + with col2: 
+ st.metric("📋 Columns", len(df.columns)) + with col3: + st.metric("💾 Size", f"{uploaded_file.size / 1024:.1f} KB") + with col4: + st.metric("❓ Missing", f"{df.isnull().sum().sum():,}") + + # Database Schema Section + with st.expander("🗄️ Database Schema", expanded=True): + st.markdown("**Generated Schema for AI Queries:**") + st.code(st.session_state.data_schema, language="sql") + + st.markdown("**Full SQL Schema:**") + st.code(schema_info["sql_schema"], language="sql") + + # Column details + col1, col2 = st.columns(2) + with col1: + st.markdown("**📊 Numeric Columns:**") + numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() + if numeric_cols: + for col in numeric_cols: + st.write(f"• {col} ({df[col].dtype})") + else: + st.write("None found") + + with col2: + st.markdown("**📝 Text Columns:**") + text_cols = df.select_dtypes(include=['object']).columns.tolist() + if text_cols: + for col in text_cols: + st.write(f"• {col} (text)") + else: + st.write("None found") + + # Multiple Visualizations + self.create_multiple_visualizations(df) + + # Action buttons + st.markdown("### 🚀 Analysis Actions") + col1, col2, col3 = st.columns(3) + + with col1: + if st.button("🔬 Complete EDA", key="eda_button", help="Comprehensive Exploratory Data Analysis"): + with st.spinner("Performing comprehensive EDA analysis..."): + st.session_state.eda_results = self.eda_analyzer.perform_complete_eda(df) + st.session_state.show_eda_results = True + st.session_state.show_ai_insights = False + st.session_state.show_advanced_analytics = False + + with col2: + if st.button("🤖 AI Insights", key="ai_insights", help="Generate AI-powered insights"): + if st.session_state.api_key and len(st.session_state.api_key) > 10: # Better API key check + with st.spinner("🤖 Generating AI insights..."): + try: + summary = f"Dataset: {len(df)} rows, {len(df.columns)} columns, Missing: {df.isnull().sum().sum()}" + prompt = f"Analyze this dataset and provide 3 key business insights: {summary}" + 
st.session_state.ai_insights_text = self.make_api_call(st.session_state.selected_model, prompt) + st.session_state.show_ai_insights = True + st.session_state.show_eda_results = False + st.session_state.show_advanced_analytics = False + except Exception as e: + st.error(f"AI insights failed: {str(e)}") + st.session_state.ai_insights_text = f"Error generating insights: {str(e)}" + st.session_state.show_ai_insights = True + else: + st.error("❌ API key not found or invalid. Check the sidebar debug info.") + st.session_state.show_ai_insights = False + + with col3: + if st.button("📊 Advanced Analytics", key="advanced_analytics", help="Advanced statistical analysis"): + st.session_state.show_advanced_analytics = True + st.session_state.show_eda_results = False + st.session_state.show_ai_insights = False + + # Display results based on session state + if st.session_state.show_eda_results and st.session_state.eda_results: + st.markdown("---") + st.markdown("## 🔬 EDA Results") + + # Simple EDA display + overview = st.session_state.eda_results.get('overview', {}) + if overview: + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Rows", f"{overview.get('total_rows', 0):,}") + with col2: + st.metric("Columns", overview.get('total_columns', 0)) + with col3: + st.metric("Missing", f"{overview.get('missing_values_total', 0):,}") + with col4: + st.metric("Duplicates", f"{overview.get('duplicate_rows', 0):,}") + + # Show insights + insights = st.session_state.eda_results.get('insights', []) + if insights: + st.markdown("### 💡 Key Insights") + for insight in insights[:3]: + st.markdown(f"**{insight.get('title', 'Insight')}:** {insight.get('description', 'No description')}") + + elif st.session_state.show_ai_insights and st.session_state.ai_insights_text: + st.markdown("---") + st.markdown("## 🤖 AI Insights") + st.markdown(st.session_state.ai_insights_text) + + elif st.session_state.show_advanced_analytics: + st.markdown("---") + st.markdown("## 📊 Advanced Analytics") + 
+ # Simple analytics + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) > 0: + st.markdown("### 📈 Statistics") + st.dataframe(df[numeric_cols].describe()) + + if len(numeric_cols) >= 2: + st.markdown("### 🔗 Correlations") + corr_matrix = df[numeric_cols].corr() + fig = px.imshow(corr_matrix, text_auto=True, title="Correlation Matrix") + st.plotly_chart(fig, use_container_width=True) + + # Data preview + with st.expander("👀 Data Preview", expanded=False): + st.dataframe(df.head(100), use_container_width=True) + + except Exception as e: + st.error(f"Error loading file: {str(e)}") + else: + # Show sample data option when no file is uploaded + st.info("👆 Upload a CSV or JSON file to get started") + + if st.button("📋 Load Sample Data", help="Load sample sales data for testing"): + sample_data = self.create_sample_data() + st.session_state.uploaded_data = sample_data + schema_info = self.generate_database_schema(sample_data) + st.session_state.data_schema = schema_info["simple_schema"] + st.success("✅ Sample data loaded!") + st.rerun() + + def create_sample_data(self) -> pd.DataFrame: + """Create sample sales data for demonstration""" + np.random.seed(42) + n_rows = 1000 + + # Generate sample sales data + data = { + 'customer_id': range(1, n_rows + 1), + 'customer_name': [f"Customer_{i}" for i in range(1, n_rows + 1)], + 'product': np.random.choice(['Widget A', 'Widget B', 'Widget C', 'Gadget X', 'Gadget Y'], n_rows), + 'sales_amount': np.random.normal(2000, 500, n_rows).round(2), + 'order_date': pd.date_range('2023-01-01', periods=n_rows, freq='D'), + 'region': np.random.choice(['North', 'South', 'East', 'West'], n_rows), + 'sales_rep': np.random.choice(['John Smith', 'Jane Doe', 'Bob Johnson', 'Alice Brown'], n_rows), + 'customer_age': np.random.randint(25, 70, n_rows), + 'customer_segment': np.random.choice(['Premium', 'Standard', 'Basic'], n_rows), + 'discount_percent': np.random.uniform(0, 20, n_rows).round(1) + } + + return 
pd.DataFrame(data) + + def generate_ai_insights(self, df: pd.DataFrame): + """Generate AI-powered insights about the data""" + with st.spinner("🤖 Generating AI insights..."): + try: + # Prepare data summary for AI + summary = f""" + Dataset Analysis: + - Rows: {len(df):,} + - Columns: {len(df.columns)} + - Schema: {st.session_state.data_schema} + - Missing values: {df.isnull().sum().sum():,} + + Column types: + {df.dtypes.to_string()} + + Sample data: + {df.head(3).to_string()} + """ + + prompt = f"""Analyze this dataset and provide 5 key business insights: + + {summary} + + Format as: + 1. **Insight Title**: Description + 2. **Insight Title**: Description + (etc.) + + Focus on business value, patterns, and actionable recommendations.""" + + insights = self.make_api_call(st.session_state.selected_model, prompt) + + st.markdown("### 🤖 AI-Generated Insights") + st.markdown(insights) + + # Save to history + analysis_record = { + "timestamp": datetime.now().isoformat(), + "type": "AI Insights", + "data_shape": df.shape, + "insights": insights, + "session_id": st.session_state.session_id + } + + if self.db_manager: + self.db_manager.save_analysis(analysis_record) + else: + st.session_state.analysis_history.append(analysis_record) + + except Exception as e: + st.error(f"Failed to generate AI insights: {str(e)}") + + def show_advanced_analytics(self, df: pd.DataFrame): + """Show advanced analytics options""" + st.markdown("### 📊 Advanced Analytics") + + analytics_tabs = st.tabs(["📈 Statistical Summary", "🔍 Outlier Detection", "📊 Correlation Analysis"]) + + with analytics_tabs[0]: + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) > 0: + st.markdown("#### 📈 Statistical Summary") + summary_stats = df[numeric_cols].describe() + st.dataframe(summary_stats, use_container_width=True) + + # Skewness and kurtosis + st.markdown("#### 📏 Distribution Metrics") + dist_metrics = pd.DataFrame({ + 'Column': numeric_cols, + 'Skewness': [df[col].skew() 
for col in numeric_cols], + 'Kurtosis': [df[col].kurtosis() for col in numeric_cols] + }) + st.dataframe(dist_metrics, use_container_width=True, hide_index=True) + else: + st.info("No numeric columns found for statistical analysis") + + with analytics_tabs[1]: + self.perform_outlier_analysis(df) + + with analytics_tabs[2]: + self.perform_correlation_analysis(df) + + def perform_outlier_analysis(self, df: pd.DataFrame): + """Perform outlier analysis""" + numeric_cols = df.select_dtypes(include=[np.number]).columns + + if len(numeric_cols) == 0: + st.info("No numeric columns found for outlier analysis") + return + + st.markdown("#### 🔍 Outlier Analysis") + + selected_col = st.selectbox("Select column for outlier analysis", numeric_cols) + + if selected_col: + data = df[selected_col].dropna() + + # IQR method + Q1 = data.quantile(0.25) + Q3 = data.quantile(0.75) + IQR = Q3 - Q1 + lower_bound = Q1 - 1.5 * IQR + upper_bound = Q3 + 1.5 * IQR + + outliers = df[(df[selected_col] < lower_bound) | (df[selected_col] > upper_bound)] + + col1, col2 = st.columns(2) + with col1: + st.metric("Total Outliers", len(outliers)) + with col2: + st.metric("Outlier Percentage", f"{len(outliers)/len(df)*100:.1f}%") + + # Visualization + fig = go.Figure() + + # Normal points + normal_data = df[~df.index.isin(outliers.index)] + fig.add_trace(go.Scatter( + x=normal_data.index, + y=normal_data[selected_col], + mode='markers', + name='Normal', + marker=dict(color='blue', size=4) + )) + + # Outliers + if len(outliers) > 0: + fig.add_trace(go.Scatter( + x=outliers.index, + y=outliers[selected_col], + mode='markers', + name='Outliers', + marker=dict(color='red', size=8, symbol='x') + )) + + fig.update_layout( + title=f'Outlier Detection: {selected_col}', + xaxis_title='Index', + yaxis_title=selected_col, + height=400 + ) + + st.plotly_chart(fig, use_container_width=True) + + def perform_correlation_analysis(self, df: pd.DataFrame): + """Perform correlation analysis""" + numeric_cols = 
df.select_dtypes(include=[np.number]).columns + + if len(numeric_cols) < 2: + st.info("Need at least 2 numeric columns for correlation analysis") + return + + st.markdown("#### 🔗 Correlation Analysis") + + # Correlation matrix + corr_matrix = df[numeric_cols].corr() + + # Heatmap + fig = px.imshow(corr_matrix, + text_auto=True, + aspect="auto", + title="Correlation Matrix", + color_continuous_scale="RdBu_r") + fig.update_layout(height=500) + st.plotly_chart(fig, use_container_width=True) + + # Top correlations + st.markdown("#### 🔝 Strongest Correlations") + + # Get correlation pairs + corr_pairs = [] + for i in range(len(corr_matrix.columns)): + for j in range(i+1, len(corr_matrix.columns)): + corr_pairs.append({ + 'Variable 1': corr_matrix.columns[i], + 'Variable 2': corr_matrix.columns[j], + 'Correlation': corr_matrix.iloc[i, j] + }) + + if corr_pairs: + corr_df = pd.DataFrame(corr_pairs) + corr_df['Abs_Correlation'] = abs(corr_df['Correlation']) + corr_df = corr_df.sort_values('Abs_Correlation', ascending=False) + + # Display top 10 correlations + top_corr = corr_df.head(10)[['Variable 1', 'Variable 2', 'Correlation']] + st.dataframe(top_corr, use_container_width=True, hide_index=True) + + def create_multiple_visualizations(self, df: pd.DataFrame): + """Create multiple visualizations for the uploaded data""" + st.markdown("### 📊 Data Visualizations") + + numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() + categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist() + + # Create tabs for different visualization types + viz_tabs = st.tabs(["📈 Overview", "📊 Distributions", "🔗 Relationships", "📋 Summary"]) + + with viz_tabs[0]: # Overview + col1, col2 = st.columns(2) + + with col1: + if len(numeric_cols) >= 2: + # Correlation heatmap + corr_matrix = df[numeric_cols].corr() + fig = px.imshow(corr_matrix, + text_auto=True, + aspect="auto", + title="Correlation Heatmap", + color_continuous_scale="RdBu_r") + 
fig.update_layout(height=400) + st.plotly_chart(fig, use_container_width=True) + + with col2: + if len(categorical_cols) > 0: + # Categorical distribution + col = categorical_cols[0] + value_counts = df[col].value_counts().head(10) + fig = px.pie(values=value_counts.values, + names=value_counts.index, + title=f"Distribution of {col}") + fig.update_layout(height=400) + st.plotly_chart(fig, use_container_width=True) + + with viz_tabs[1]: # Distributions + if len(numeric_cols) > 0: + col1, col2 = st.columns(2) + + with col1: + # Histogram + selected_num_col = st.selectbox("Select numeric column for histogram", numeric_cols) + fig = px.histogram(df, x=selected_num_col, title=f"Distribution of {selected_num_col}") + st.plotly_chart(fig, use_container_width=True) + + with col2: + # Box plot + fig = px.box(df, y=selected_num_col, title=f"Box Plot of {selected_num_col}") + st.plotly_chart(fig, use_container_width=True) + + # Multiple histograms + if len(numeric_cols) > 1: + st.markdown("#### All Numeric Distributions") + fig = make_subplots( + rows=(len(numeric_cols) + 1) // 2, + cols=2, + subplot_titles=numeric_cols[:6] # Limit to 6 for performance + ) + + for i, col in enumerate(numeric_cols[:6]): + row = (i // 2) + 1 + col_pos = (i % 2) + 1 + fig.add_trace( + go.Histogram(x=df[col], name=col, showlegend=False), + row=row, col=col_pos + ) + + fig.update_layout(height=300 * ((len(numeric_cols[:6]) + 1) // 2)) + st.plotly_chart(fig, use_container_width=True) + + with viz_tabs[2]: # Relationships + if len(numeric_cols) >= 2: + col1, col2 = st.columns(2) + + with col1: + x_col = st.selectbox("X-axis", numeric_cols, key="x_scatter") + with col2: + y_col = st.selectbox("Y-axis", numeric_cols, key="y_scatter", + index=1 if len(numeric_cols) > 1 else 0) + + # Color by categorical if available + color_col = None + if categorical_cols: + color_col = st.selectbox("Color by (optional)", ["None"] + categorical_cols) + color_col = color_col if color_col != "None" else None + + # 
Scatter plot + fig = px.scatter(df, x=x_col, y=y_col, color=color_col, + title=f"{x_col} vs {y_col}") + st.plotly_chart(fig, use_container_width=True) + + # Scatter matrix for first 4 numeric columns + if len(numeric_cols) >= 3: + st.markdown("#### Scatter Matrix") + cols_for_matrix = numeric_cols[:4] + fig = px.scatter_matrix(df[cols_for_matrix], + title="Scatter Matrix (First 4 Numeric Columns)") + fig.update_layout(height=600) + st.plotly_chart(fig, use_container_width=True) + + with viz_tabs[3]: # Summary + # Data summary table + st.markdown("#### 📋 Data Summary") + summary_data = { + "Metric": ["Total Rows", "Total Columns", "Numeric Columns", "Categorical Columns", "Missing Values", "Memory Usage"], + "Value": [ + f"{len(df):,}", + f"{len(df.columns)}", + f"{len(numeric_cols)}", + f"{len(categorical_cols)}", + f"{df.isnull().sum().sum():,}", + f"{df.memory_usage(deep=True).sum() / 1024**2:.2f} MB" + ] + } + summary_df = pd.DataFrame(summary_data) + st.dataframe(summary_df, use_container_width=True, hide_index=True) + + # Column information + st.markdown("#### 📝 Column Information") + col_info = [] + for col in df.columns: + col_info.append({ + "Column": col, + "Type": str(df[col].dtype), + "Non-Null": f"{df[col].count():,}", + "Null": f"{df[col].isnull().sum():,}", + "Unique": f"{df[col].nunique():,}", + "Sample Values": str(df[col].dropna().head(3).tolist())[:50] + "..." + }) + + col_info_df = pd.DataFrame(col_info) + st.dataframe(col_info_df, use_container_width=True, hide_index=True) + + def test_api_connection(self, api_key: str, model: str) -> bool: + """Test API connection""" + try: + response = requests.post( + "https://api.groq.com/openai/v1/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + }, + json={ + "model": model, + "messages": [{"role": "user", "content": "Say 'Connection successful!' 
in exactly 3 words."}], + "temperature": 0.1, + "max_tokens": 50 + }, + timeout=10 + ) + return response.status_code == 200 + except Exception as e: + st.error(f"Connection error: {str(e)}") + return False + + def perform_eda(self, df: pd.DataFrame): + """Perform comprehensive EDA""" + st.markdown("## 🔬 Comprehensive EDA Results") + + with st.spinner("Performing comprehensive analysis..."): + eda_results = self.eda_analyzer.perform_complete_eda(df) + + # Save EDA to history + eda_record = { + "timestamp": datetime.now().isoformat(), + "type": "EDA", + "data_shape": df.shape, + "results": "EDA analysis completed", + "session_id": st.session_state.session_id + } + + if self.db_manager: + self.db_manager.save_analysis(eda_record) + else: + st.session_state.analysis_history.append(eda_record) + + # Display EDA results + self.display_eda_results(eda_results) + + def display_eda_results(self, results: Dict): + """Display EDA results""" + tabs = st.tabs([ + "📊 Overview", + "📈 Distributions", + "🔗 Correlations", + "🎯 Insights", + "📋 Data Quality" + ]) + + with tabs[0]: # Overview + self.display_overview(results['overview']) + + with tabs[1]: # Distributions + self.display_distributions(results['distributions']) + + with tabs[2]: # Correlations + self.display_correlations(results['correlations']) + + with tabs[3]: # Insights + self.display_insights(results['insights']) + + with tabs[4]: # Data Quality + self.display_data_quality(results['data_quality']) + + def display_overview(self, overview: Dict): + """Display overview section""" + st.markdown("### 📊 Dataset Overview") + + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Rows", f"{overview['total_rows']:,}") + with col2: + st.metric("Total Columns", overview['total_columns']) + with col3: + st.metric("Numeric Columns", overview['numeric_columns']) + with col4: + st.metric("Categorical Columns", overview['categorical_columns']) + + if 'summary_stats' in overview: + st.markdown("### 📈 Summary 
Statistics") + st.dataframe(overview['summary_stats'], use_container_width=True) + + def display_distributions(self, distributions: Dict): + """Display distribution plots""" + st.markdown("### 📈 Data Distributions") + + for chart_name, chart_data in distributions.items(): + st.plotly_chart(chart_data, use_container_width=True) + + def display_correlations(self, correlations: Dict): + """Display correlation analysis""" + st.markdown("### 🔗 Correlation Analysis") + + if 'heatmap' in correlations: + st.plotly_chart(correlations['heatmap'], use_container_width=True) + + if 'top_correlations' in correlations: + st.markdown("#### 🔝 Top Correlations") + st.dataframe(correlations['top_correlations'], use_container_width=True) + + def display_insights(self, insights: List[Dict]): + """Display AI-generated insights""" + st.markdown("### 🎯 AI-Generated Insights") + + for i, insight in enumerate(insights): + with st.container(): + st.markdown(f""" +
+

💡 {insight['title']}

+

{insight['description']}

+
+ """, unsafe_allow_html=True) + + def display_data_quality(self, quality: Dict): + """Display data quality metrics""" + st.markdown("### 📋 Data Quality Assessment") + + col1, col2 = st.columns(2) + + with col1: + st.markdown("#### Missing Values") + if 'missing_values' in quality: + st.plotly_chart(quality['missing_values'], use_container_width=True) + + with col2: + st.markdown("#### Data Types") + if 'data_types' in quality: + st.plotly_chart(quality['data_types'], use_container_width=True) + + def render_query_interface(self): + """Render natural language query interface""" + st.markdown("## 🚀 AI Query Interface") + + # Show current schema + if st.session_state.data_schema: + with st.expander("🗄️ Current Data Schema", expanded=False): + st.code(st.session_state.data_schema, language="sql") + + # Natural Language Query input + query_input = st.text_area( + "Natural Language Query", + value=st.session_state.example_query, + placeholder="Example: Show me the top 10 customers by total sales amount", + height=100, + help="Describe what you want to analyze in plain English" + ) + + # Clear the example query after it's been used + if st.session_state.example_query and query_input == st.session_state.example_query: + st.session_state.example_query = "" + + # Analysis buttons + col1, col2 = st.columns(2) + + with col1: + # Check API key more thoroughly + api_available = bool(st.session_state.api_key and len(st.session_state.api_key) > 10) + analyze_disabled = not api_available or not query_input.strip() + + if st.button("🧠 Analyze Query", + key="analyze_single", + disabled=analyze_disabled, + help="Generate SQL and insights for your query"): + if not api_available: + st.error("❌ API key required. 
Check sidebar debug info.") + elif query_input.strip(): + self.analyze_single_query(query_input.strip(), st.session_state.data_schema) + else: + st.error("Please enter a query") + + with col2: + compare_disabled = not api_available or not query_input.strip() + if st.button("⚔️ Model Battle", + key="model_comparison", + disabled=compare_disabled, + help="Compare multiple AI models on your query"): + if not api_available: + st.error("❌ API key required. Check sidebar debug info.") + elif query_input.strip(): + # Store query and show model selection + st.session_state.current_query = query_input.strip() + st.session_state.show_model_selection = True + st.rerun() + else: + st.error("Please enter a query") + + # Show connection status + if not api_available: + st.warning("⚠️ **AI Features Disabled**: API key not detected. Use the '🔄 Reload API Key' button in the sidebar.") + else: + st.success("✅ **AI Features Active**: Ready for natural language queries and model battles!") + + # Recent queries section (simplified) + if hasattr(st.session_state, 'recent_queries') and st.session_state.recent_queries: + with st.expander("📝 Recent Queries", expanded=False): + for i, recent_query in enumerate(st.session_state.recent_queries[-5:]): # Show last 5 + if st.button(f"🔄 {recent_query[:60]}...", key=f"recent_{i}"): + st.session_state.example_query = recent_query + st.rerun() + + # Show model selection interface if activated + if st.session_state.show_model_selection: + self.show_model_selection_interface() + + def analyze_single_query(self, query: str, schema: str = ""): + """Analyze query with single model""" + # Add to recent queries + if query not in st.session_state.recent_queries: + st.session_state.recent_queries.append(query) + # Keep only last 10 queries + st.session_state.recent_queries = st.session_state.recent_queries[-10:] + + with st.spinner(f"🧠 Analyzing with {st.session_state.get('selected_model', 'AI model')}..."): + try: + # Generate SQL + sql_result = 
self.generate_sql(query, schema) + + # Generate insights + insights_result = self.generate_insights(query, schema) + + # Save to history + analysis_record = { + "timestamp": datetime.now().isoformat(), + "type": "Single Query Analysis", + "query": query, + "schema": schema, + "sql_result": sql_result, + "insights": insights_result, + "model": st.session_state.get('selected_model'), + "session_id": st.session_state.session_id + } + + if self.db_manager: + self.db_manager.save_analysis(analysis_record) + else: + st.session_state.analysis_history.append(analysis_record) + + # Display results + self.display_query_results(sql_result, insights_result) + + except Exception as e: + st.error(f"Analysis failed: {str(e)}") + + def generate_sql(self, query: str, schema: str = "") -> str: + """Generate SQL from natural language""" + schema_context = schema if schema else st.session_state.data_schema + + prompt = f"""Convert this natural language query to SQL: + +Database Schema: {schema_context} + +Natural Language Query: {query} + +Instructions: +- Use the exact column names from the schema +- Generate clean, optimized SQL +- Include appropriate WHERE, GROUP BY, ORDER BY clauses +- Use proper SQL syntax +- Return only the SQL query without explanations + +SQL Query:""" + + return self.make_api_call(st.session_state.selected_model, prompt) + + def generate_insights(self, query: str, schema: str = "") -> str: + """Generate business insights""" + schema_context = schema if schema else st.session_state.data_schema + + prompt = f"""Provide detailed business insights for this data analysis query: + +Database Schema: {schema_context} + +Query: {query} + +Generate 4-5 key business insights in this format: +**Insight Title 1**: Detailed explanation of what this analysis reveals about the business +**Insight Title 2**: Another important finding or recommendation +(continue for 4-5 insights) + +Focus on: +- Business implications +- Actionable recommendations +- Data patterns and trends 
+- Strategic insights +- Potential opportunities or risks + +Business Insights:""" + + return self.make_api_call(st.session_state.selected_model, prompt) + + def make_api_call(self, model: str, prompt: str) -> str: + """Make API call to Groq""" + try: + response = requests.post( + "https://api.groq.com/openai/v1/chat/completions", + headers={ + "Authorization": f"Bearer {st.session_state.api_key}", + "Content-Type": "application/json" + }, + json={ + "model": model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.1, + "max_tokens": 1000 + }, + timeout=30 + ) + + if response.status_code == 200: + return response.json()['choices'][0]['message']['content'].strip() + else: + raise Exception(f"API error: {response.status_code}") + + except Exception as e: + raise Exception(f"API call failed: {str(e)}") + + def display_query_results(self, sql_result: str, insights_result: str): + """Display query analysis results""" + st.markdown("## 🎯 Analysis Results") + + tabs = st.tabs(["🔍 SQL Query", "💡 AI Insights", "🔄 Execute Query"]) + + with tabs[0]: + st.markdown("### 🔍 Generated SQL Query") + st.code(sql_result, language='sql') + + # Copy button simulation + if st.button("📋 Copy SQL", key="copy_sql"): + st.success("SQL copied to clipboard! (Use Ctrl+C to copy from the code block above)") + + with tabs[1]: + st.markdown("### 💡 AI-Powered Business Insights") + # Parse and display insights + insights = self.parse_insights_improved(insights_result) + + for i, insight in enumerate(insights): + with st.container(): + st.markdown(f""" +
+

💡 {insight['title']}

+

{insight['text']}

+
+ """, unsafe_allow_html=True) + + with tabs[2]: + st.markdown("### 🔄 Execute Query on Your Data") + + if st.session_state.uploaded_data is not None: + if st.button("▶️ Run SQL on Uploaded Data", key="execute_sql"): + self.execute_sql_on_data(sql_result, st.session_state.uploaded_data) + else: + st.info("Upload data first to execute SQL queries") + + def execute_sql_on_data(self, sql_query: str, df: pd.DataFrame): + """Execute SQL query on the uploaded DataFrame""" + try: + import sqlite3 + import tempfile + + # Create temporary SQLite database + with tempfile.NamedTemporaryFile(delete=False, suffix='.db') as tmp_file: + conn = sqlite3.connect(tmp_file.name) + + # Write DataFrame to SQLite + df.to_sql('uploaded_data', conn, if_exists='replace', index=False) + + # Clean and execute SQL + clean_sql = sql_query.strip() + if clean_sql.lower().startswith('sql:'): + clean_sql = clean_sql[4:].strip() + + # Execute query + result_df = pd.read_sql_query(clean_sql, conn) + conn.close() + + # Display results + st.success("✅ Query executed successfully!") + + col1, col2 = st.columns(2) + with col1: + st.metric("Rows Returned", len(result_df)) + with col2: + st.metric("Columns", len(result_df.columns)) + + # Show results + st.markdown("#### 📊 Query Results") + st.dataframe(result_df, use_container_width=True) + + # Visualization if possible + if len(result_df) > 0: + self.auto_visualize_results(result_df) + + except Exception as e: + st.error(f"Error executing SQL: {str(e)}") + st.info("💡 Tip: The AI-generated SQL might need adjustment for your specific data structure") + + def auto_visualize_results(self, result_df: pd.DataFrame): + """Automatically create visualizations for query results""" + if len(result_df) == 0: + return + + st.markdown("#### 📈 Auto-Generated Visualization") + + numeric_cols = result_df.select_dtypes(include=[np.number]).columns.tolist() + + # Simple bar chart if we have one numeric column + if len(numeric_cols) == 1 and len(result_df) <= 50: + if 
len(result_df.columns) >= 2: + # Use first non-numeric column as x-axis + text_cols = result_df.select_dtypes(include=['object']).columns.tolist() + if text_cols: + fig = px.bar(result_df, + x=text_cols[0], + y=numeric_cols[0], + title=f"{numeric_cols[0]} by {text_cols[0]}") + st.plotly_chart(fig, use_container_width=True) + + # Line chart for time series-like data + elif len(numeric_cols) >= 1 and len(result_df) > 10: + fig = px.line(result_df, + y=numeric_cols[0], + title=f"Trend: {numeric_cols[0]}") + st.plotly_chart(fig, use_container_width=True) + + def parse_insights_improved(self, raw_insights: str) -> List[Dict]: + """Parse insights from API response with improved formatting""" + insights = [] + + # Split by lines and look for patterns + lines = raw_insights.strip().split('\n') + + current_insight = None + for line in lines: + line = line.strip() + if not line: + continue + + # Look for bold titles or numbered items + if line.startswith('**') and line.endswith('**:'): + if current_insight: + insights.append(current_insight) + + title = line.replace('**', '').replace(':', '').strip() + current_insight = {'title': title, 'text': ''} + + elif line.startswith('**') and '**:' in line: + if current_insight: + insights.append(current_insight) + + parts = line.split('**:', 1) + title = parts[0].replace('**', '').strip() + text = parts[1].strip() if len(parts) > 1 else '' + current_insight = {'title': title, 'text': text} + + elif current_insight and line: + # Continue building the current insight + if current_insight['text']: + current_insight['text'] += ' ' + line + else: + current_insight['text'] = line + + elif line.startswith(('1.', '2.', '3.', '4.', '5.', '-', '•')): + # Handle numbered or bulleted insights + if current_insight: + insights.append(current_insight) + + # Extract title and text + clean_line = line.lstrip('1234567890.-• ').strip() + if ':' in clean_line: + parts = clean_line.split(':', 1) + title = parts[0].strip() + text = parts[1].strip() + 
                else:
                    # No "Title: text" colon split available — synthesize a title.
                    title = f"Insight {len(insights) + 1}"
                    text = clean_line

                current_insight = {'title': title, 'text': text}

        # Add the last insight
        if current_insight:
            insights.append(current_insight)

        # If no insights were parsed, create one from the raw text
        if not insights and raw_insights.strip():
            insights.append({
                'title': 'AI Analysis',
                'text': raw_insights.strip()
            })

        return insights[:5]  # Limit to 5 insights

    def run_model_comparison(self, query: str, schema: str = ""):
        """Run the fixed four-model comparison for a query.

        Each model in the hard-coded list is prompted once with the same
        "SQL: / INSIGHT:" template, timed, scored via score_response(), and
        the outcome is persisted (simplified, without response bodies) and
        rendered via display_comparison_results().

        Args:
            query: Natural-language question to compare models on.
            schema: Optional schema text; falls back to the session schema.
        """
        # Double-check API key
        if not st.session_state.api_key or len(st.session_state.api_key) < 10:
            st.error("🔑 Valid API key required for model comparison feature")
            return

        models = [
            "llama-3.3-70b-versatile",
            "llama3-70b-8192",
            "mixtral-8x7b-32768",
            "gemma2-9b-it"
        ]

        st.markdown("## ⚔️ Model Comparison Arena")
        st.markdown("*Comparing multiple AI models on your query...*")

        # Create progress bar
        progress_bar = st.progress(0)
        status_text = st.empty()

        results = []

        for i, model in enumerate(models):
            status_text.text(f"Testing {model}...")
            progress_bar.progress((i + 1) / len(models))

            try:
                start_time = time.time()

                # Create a more focused prompt for model comparison
                comparison_prompt = f"""Analyze this query and provide SQL + business insight:

Schema: {schema if schema else st.session_state.data_schema}
Query: {query}

Respond in this exact format:
SQL: [your SQL query here]
INSIGHT: [your business insight here]

Keep response concise and focused."""

                response = self.make_api_call(model, comparison_prompt)
                response_time = time.time() - start_time

                # Score the response
                score = self.score_response(response, response_time)

                results.append({
                    'model': model,
                    'response': response,
                    'response_time': response_time * 1000,  # Convert to ms
                    'score': score,
                    'success': True
                })

            except Exception as e:
                # A failed model still appears in results so the UI can show it.
                results.append({
                    'model': model,
                    'response': f"Error: {str(e)}",
                    'response_time': 0,
                    'score': 0,
                    'success': False
                })

            # Add small delay between requests to avoid rate limiting
            time.sleep(1)

        # Clear progress indicators
        progress_bar.empty()
        status_text.empty()

        # Save comparison to history
        comparison_record = {
            "timestamp": datetime.now().isoformat(),
            "type": "Model Comparison",
            "query": query,
            "schema": schema,
            "results": [
                {
                    'model': r['model'],
                    'score': r['score'],
                    'success': r['success'],
                    'response_time': r['response_time']
                } for r in results  # Store simplified version for history
            ],
            "session_id": st.session_state.session_id
        }

        if self.db_manager:
            self.db_manager.save_analysis(comparison_record)
        else:
            st.session_state.analysis_history.append(comparison_record)

        # Display comparison results
        self.display_comparison_results(results)

    def score_response(self, response: str, response_time: float) -> int:
        """Score a model response on a 0-100 heuristic scale.

        Scoring breakdown:
        - 40 pts: any SQL keyword present (select/from/where/group by/order by)
        - 30 pts: any insight keyword present (insight/analysis/recommendation/...)
        - up to 20 pts: length (1 point per 20 characters, capped)
        - up to 10 pts: speed (2 points lost per second of latency)

        Args:
            response: Raw model reply text.
            response_time: Latency in seconds (0 scores 0 speed points).

        Returns:
            Rounded integer score clamped to [0, 100]; 0 on any internal error.
        """
        try:
            # Quality scoring based on content
            response_lower = response.lower()

            # Check for SQL presence (40 points max)
            has_sql = any(keyword in response_lower for keyword in ['select', 'from', 'where', 'group by', 'order by'])
            sql_score = 40 if has_sql else 0

            # Check for insight presence (30 points max)
            has_insight = any(keyword in response_lower for keyword in ['insight', 'analysis', 'recommendation', 'business', 'trend'])
            insight_score = 30 if has_insight else 0

            # Response length and completeness (20 points max)
            length_score = min(len(response) / 20, 20)  # 1 point per 20 characters, max 20

            # Speed scoring (10 points max) - faster is better
            if response_time > 0:
                speed_score = max(0, 10 - (response_time * 2))  # Penalty for slow responses
            else:
                speed_score = 0

            # Calculate total score
            total_score = sql_score + insight_score + length_score + speed_score

            # Ensure score is between 0 and 100
            return max(0, min(100, round(total_score)))

        except Exception:
            return 0

    def display_comparison_results(self, results: List[Dict]):
        """Render comparison results: winner cards, per-model tabs, failures,
        and (for 2+ successes) a score/latency chart.

        Args:
            results: Per-model dicts with 'model', 'response', 'response_time'
                (ms), 'score', and 'success' keys.
        """
        if not results:
            st.error("No results to display")
            return

        # Sort by score for ranking
        sorted_results = sorted([r for r in results if r['success']], key=lambda x: x['score'], reverse=True)
        failed_results = [r for r in results if not r['success']]

        # Winner summary
        if sorted_results:
            winner = sorted_results[0]
            # default=winner guards the (unreachable here) empty-iterable case.
            fastest = min([r for r in results if r['success']], key=lambda x: x['response_time'], default=winner)

            # Create winner announcement
            col1, col2, col3 = st.columns(3)

            with col1:
                st.markdown(f"""
+

🏆 WINNER

+

{winner['model']}

+

Score: {winner['score']}/100

+
+ """, unsafe_allow_html=True) + + with col2: + st.markdown(f""" +
+

⚡ FASTEST

+

{fastest['model']}

+

{fastest['response_time']:.0f}ms

+
+ """, unsafe_allow_html=True) + + with col3: + success_rate = len(sorted_results) / len(results) * 100 + st.markdown(f""" +
+

📊 SUCCESS

+

{len(sorted_results)}/{len(results)}

+

{success_rate:.0f}% Success Rate

+
+ """, unsafe_allow_html=True) + + st.markdown("---") + + # Detailed results in tabs + if len(sorted_results) > 0: + # Create tabs for each successful model + tab_names = [f"{'🥇' if i == 0 else '🥈' if i == 1 else '🥉' if i == 2 else '📊'} {result['model']}" + for i, result in enumerate(sorted_results)] + + if len(tab_names) > 0: + tabs = st.tabs(tab_names) + + for i, (tab, result) in enumerate(zip(tabs, sorted_results)): + with tab: + # Performance metrics + metric_col1, metric_col2, metric_col3, metric_col4 = st.columns(4) + + with metric_col1: + st.metric("🏆 Rank", f"#{i+1}") + with metric_col2: + st.metric("📊 Score", f"{result['score']}/100") + with metric_col3: + st.metric("⚡ Speed", f"{result['response_time']:.0f}ms") + with metric_col4: + quality = "Excellent" if result['score'] >= 80 else "Good" if result['score'] >= 60 else "Fair" if result['score'] >= 40 else "Poor" + st.metric("🎯 Quality", quality) + + st.markdown("---") + + # Model response + st.markdown("### 📝 Response") + + # Try to parse SQL and Insight separately + response = result['response'] + + if "SQL:" in response and "INSIGHT:" in response: + parts = response.split("INSIGHT:") + sql_part = parts[0].replace("SQL:", "").strip() + insight_part = parts[1].strip() + + st.markdown("**🔍 Generated SQL:**") + st.code(sql_part, language='sql') + + st.markdown("**💡 Business Insight:**") + st.markdown(insight_part) + else: + # Show full response if parsing fails + st.markdown("**📄 Full Response:**") + st.text_area("", response, height=200, key=f"response_full_{i}") + + # Show failed models if any + if failed_results: + st.markdown("---") + st.markdown("### ❌ Failed Models") + + for result in failed_results: + with st.expander(f"❌ {result['model']} - Failed"): + st.error(f"Error: {result['response']}") + + # Add comparison summary + if len(sorted_results) > 1: + st.markdown("---") + st.markdown("### 📈 Performance Comparison") + + # Create comparison chart + import plotly.graph_objects as go + + models = 
[r['model'].replace('-', '
') for r in sorted_results] + scores = [r['score'] for r in sorted_results] + times = [r['response_time'] for r in sorted_results] + + fig = go.Figure() + + # Add score bars + fig.add_trace(go.Bar( + name='Score', + x=models, + y=scores, + yaxis='y', + marker_color='#FFD700' + )) + + # Add response time line + fig.add_trace(go.Scatter( + name='Response Time (ms)', + x=models, + y=times, + yaxis='y2', + mode='lines+markers', + line=dict(color='#FF6B6B', width=3), + marker=dict(size=8) + )) + + # Update layout for dual y-axis + fig.update_layout( + title='Model Performance Comparison', + xaxis=dict(title='Models'), + yaxis=dict(title='Score (0-100)', side='left'), + yaxis2=dict(title='Response Time (ms)', side='right', overlaying='y'), + height=400, + plot_bgcolor='rgba(0,0,0,0)', + paper_bgcolor='rgba(0,0,0,0)', + font=dict(color='white') + ) + + st.plotly_chart(fig, use_container_width=True) + + def show_history(self): + """Show analysis history""" + st.markdown("## 📊 Analysis History") + + if self.db_manager: + history = self.db_manager.get_history(st.session_state.session_id) + else: + history = st.session_state.analysis_history + + if not history: + st.info("No analysis history found.") + return + + for i, record in enumerate(reversed(history[-10:])): # Show last 10 records + with st.expander(f"{record['type']} - {record['timestamp'][:19]}"): + if record['type'] == 'Single Query Analysis': + st.markdown(f"**Query:** {record['query']}") + if record.get('schema'): + st.markdown(f"**Schema:** {record['schema']}") + st.markdown("**SQL Result:**") + st.code(record['sql_result'], language='sql') + + elif record['type'] == 'Model Comparison': + st.markdown(f"**Query:** {record['query']}") + st.markdown("**Results:**") + for result in record['results']: + st.markdown(f"- {result['model']}: {result['score']}/100") + + elif record['type'] == 'EDA': + st.markdown(f"**Data Shape:** {record['data_shape']}") + st.markdown("**Analysis completed successfully**") + + def 
show_statistical_summary(self, df: pd.DataFrame): + """Show statistical summary without API""" + st.markdown("### 📊 Statistical Summary") + + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) > 0: + st.dataframe(df[numeric_cols].describe(), use_container_width=True) + + # Additional stats + st.markdown("#### 📈 Additional Statistics") + additional_stats = pd.DataFrame({ + 'Column': numeric_cols, + 'Skewness': [df[col].skew() for col in numeric_cols], + 'Kurtosis': [df[col].kurtosis() for col in numeric_cols], + 'Missing %': [df[col].isnull().sum() / len(df) * 100 for col in numeric_cols] + }) + st.dataframe(additional_stats, use_container_width=True, hide_index=True) + else: + st.info("No numeric columns found for statistical analysis") + + def show_outlier_detection(self, df: pd.DataFrame): + """Show outlier detection without API""" + st.markdown("### 🔍 Outlier Detection") + + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) > 0: + selected_col = st.selectbox("Select column for outlier analysis", numeric_cols, key="outlier_col") + + if selected_col: + data = df[selected_col].dropna() + + # IQR method + Q1 = data.quantile(0.25) + Q3 = data.quantile(0.75) + IQR = Q3 - Q1 + lower_bound = Q1 - 1.5 * IQR + upper_bound = Q3 + 1.5 * IQR + + outliers = df[(df[selected_col] < lower_bound) | (df[selected_col] > upper_bound)] + + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total Outliers", len(outliers)) + with col2: + st.metric("Outlier %", f"{len(outliers)/len(df)*100:.1f}%") + with col3: + st.metric("IQR", f"{IQR:.2f}") + + # Show outlier values + if len(outliers) > 0: + st.markdown("#### 🎯 Outlier Values") + st.dataframe(outliers[[selected_col]].head(20), use_container_width=True) + else: + st.info("No numeric columns found for outlier analysis") + + def show_correlation_matrix(self, df: pd.DataFrame): + """Show correlation matrix without API""" + st.markdown("### 🔗 Correlation Analysis") 

        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) >= 2:
            corr_matrix = df[numeric_cols].corr()

            # Heatmap
            fig = px.imshow(corr_matrix,
                           text_auto=True,
                           aspect="auto",
                           title="Correlation Matrix",
                           color_continuous_scale="RdBu_r")
            fig.update_layout(height=500)
            st.plotly_chart(fig, use_container_width=True)

            # Strong correlations: collect every unique upper-triangle pair.
            st.markdown("#### 🔝 Strongest Correlations")
            corr_pairs = []
            for i in range(len(corr_matrix.columns)):
                for j in range(i+1, len(corr_matrix.columns)):
                    corr_pairs.append({
                        'Variable 1': corr_matrix.columns[i],
                        'Variable 2': corr_matrix.columns[j],
                        'Correlation': corr_matrix.iloc[i, j]
                    })

            if corr_pairs:
                corr_df = pd.DataFrame(corr_pairs)
                # Rank by magnitude so strong negative correlations surface too.
                corr_df['Abs_Correlation'] = abs(corr_df['Correlation'])
                corr_df = corr_df.sort_values('Abs_Correlation', ascending=False)
                st.dataframe(corr_df.head(10)[['Variable 1', 'Variable 2', 'Correlation']],
                            use_container_width=True, hide_index=True)
        else:
            st.info("Need at least 2 numeric columns for correlation analysis")

    def show_model_selection_interface(self):
        """Render the persistent model-battle setup panel.

        Reads the pending query from st.session_state.current_query, lets the
        user pick models/rounds/timeout, and hands off to
        run_model_comparison_with_selection() when the battle starts.
        """
        st.markdown("---")
        st.markdown("## ⚔️ Model Battle Setup")
        st.markdown(f"**Query:** {st.session_state.current_query}")

        # Available models
        # NOTE(review): this mapping appears unused below — selection happens
        # via the individual checkboxes; verify before removing.
        available_models = {
            "llama-3.3-70b-versatile": "Llama 3.3 70B (Most Advanced)",
            "llama3-70b-8192": "Llama 3 70B (Reliable)",
            "mixtral-8x7b-32768": "Mixtral 8x7B (Fast & Efficient)",
            "gemma2-9b-it": "Gemma 2 9B (Lightweight)",
            "qwen-qwq-32b": "Qwen QwQ 32B (Reasoning)",
            "deepseek-r1-distill-llama-70b": "DeepSeek R1 70B (Specialized)"
        }

        st.markdown("### 🎯 Select Models for Battle")

        # Model selection with checkboxes - using session state keys
        col1, col2 = st.columns(2)

        with col1:
            st.markdown("**🚀 High-Performance Models:**")
            llama33_selected = st.checkbox("Llama 3.3 70B (Most Advanced)", key="battle_llama33", value=True)
            llama3_selected = st.checkbox("Llama 3 70B (Reliable)", key="battle_llama3", value=True)
            deepseek_selected = st.checkbox("DeepSeek R1 70B (Specialized)", key="battle_deepseek", value=False)

        with col2:
            st.markdown("**⚡ Fast & Efficient Models:**")
            mixtral_selected = st.checkbox("Mixtral 8x7B (Fast & Efficient)", key="battle_mixtral", value=True)
            gemma_selected = st.checkbox("Gemma 2 9B (Lightweight)", key="battle_gemma", value=False)
            qwen_selected = st.checkbox("Qwen QwQ 32B (Reasoning)", key="battle_qwen", value=False)

        # Build selected models list (order fixed: high-performance first).
        selected_models = []
        if llama33_selected:
            selected_models.append("llama-3.3-70b-versatile")
        if llama3_selected:
            selected_models.append("llama3-70b-8192")
        if deepseek_selected:
            selected_models.append("deepseek-r1-distill-llama-70b")
        if mixtral_selected:
            selected_models.append("mixtral-8x7b-32768")
        if gemma_selected:
            selected_models.append("gemma2-9b-it")
        if qwen_selected:
            selected_models.append("qwen-qwq-32b")

        # Show selection summary
        if selected_models:
            st.success(f"✅ **Selected Models:** {len(selected_models)} models ready for battle")

            # Battle configuration
            col1, col2, col3 = st.columns(3)
            with col1:
                test_rounds = st.selectbox("Test Rounds", [1, 2, 3], index=0, help="Number of times to test each model")
            with col2:
                timeout_seconds = st.selectbox("Timeout (seconds)", [10, 20, 30], index=1, help="Max time to wait for each response")
            with col3:
                # Cancel button
                if st.button("❌ Cancel", key="cancel_battle"):
                    st.session_state.show_model_selection = False
                    st.rerun()

            # Start battle button
            col1, col2 = st.columns([3, 1])
            with col1:
                if st.button("🚀 Start Model Battle", key="start_battle_persist", type="primary"):
                    st.session_state.show_model_selection = False  # Hide selection interface
                    self.run_model_comparison_with_selection(
                        st.session_state.current_query,
                        st.session_state.data_schema,
                        selected_models,
                        test_rounds,
                        timeout_seconds
                    )
        else:
st.warning("⚠️ Please select at least one model for the battle") + + # Cancel button for when no models selected + if st.button("❌ Cancel", key="cancel_battle_no_models"): + st.session_state.show_model_selection = False + st.rerun() + + def show_model_selection_and_run(self, query: str, schema: str = ""): + """Legacy method - now redirects to persistent interface""" + st.session_state.current_query = query + st.session_state.show_model_selection = True + + def run_model_comparison_with_selection(self, query: str, schema: str, selected_models: list, rounds: int, timeout: int): + """Run model comparison with selected models""" + if not selected_models: + st.error("No models selected") + return + + # Double-check API key + if not st.session_state.api_key or len(st.session_state.api_key) < 10: + st.error("🔑 Valid API key required for model comparison feature") + return + + st.markdown("## ⚔️ Model Battle Arena") + st.markdown(f"*Testing {len(selected_models)} models with {rounds} round(s) each...*") + + # Create progress bar + total_tests = len(selected_models) * rounds + progress_bar = st.progress(0) + status_text = st.empty() + + results = [] + test_count = 0 + + for model in selected_models: + model_results = [] + + for round_num in range(rounds): + test_count += 1 + status_text.text(f"Testing {model} (Round {round_num + 1}/{rounds})...") + progress_bar.progress(test_count / total_tests) + + try: + start_time = time.time() + + # Create a focused prompt + comparison_prompt = f"""Analyze this query and provide SQL + business insight: + +Schema: {schema if schema else st.session_state.data_schema} +Query: {query} + +Respond in this exact format: +SQL: [your SQL query here] +INSIGHT: [your business insight here] + +Keep response concise and focused.""" + + response = self.make_api_call_with_timeout(model, comparison_prompt, timeout) + response_time = time.time() - start_time + + # Score the response + score = self.score_response(response, response_time) + + 
model_results.append({ + 'response': response, + 'response_time': response_time * 1000, + 'score': score, + 'success': True, + 'round': round_num + 1 + }) + + except Exception as e: + model_results.append({ + 'response': f"Error: {str(e)}", + 'response_time': 0, + 'score': 0, + 'success': False, + 'round': round_num + 1 + }) + + # Small delay between requests + time.sleep(0.5) + + # Calculate average performance for this model + successful_results = [r for r in model_results if r['success']] + if successful_results: + avg_score = sum(r['score'] for r in successful_results) / len(successful_results) + avg_time = sum(r['response_time'] for r in successful_results) / len(successful_results) + best_response = max(successful_results, key=lambda x: x['score'])['response'] + else: + avg_score = 0 + avg_time = 0 + best_response = "All attempts failed" + + results.append({ + 'model': model, + 'avg_score': avg_score, + 'avg_response_time': avg_time, + 'success_rate': len(successful_results) / len(model_results) * 100, + 'best_response': best_response, + 'all_results': model_results, + 'success': len(successful_results) > 0 + }) + + # Clear progress indicators + progress_bar.empty() + status_text.empty() + + # Display enhanced results + self.display_enhanced_comparison_results(results, query) + + def make_api_call_with_timeout(self, model: str, prompt: str, timeout: int) -> str: + """Make API call with custom timeout""" + try: + response = requests.post( + "https://api.groq.com/openai/v1/chat/completions", + headers={ + "Authorization": f"Bearer {st.session_state.api_key}", + "Content-Type": "application/json" + }, + json={ + "model": model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.1, + "max_tokens": 1000 + }, + timeout=timeout + ) + + if response.status_code == 200: + return response.json()['choices'][0]['message']['content'].strip() + else: + raise Exception(f"API error: {response.status_code}") + + except Exception as e: + raise 
Exception(f"API call failed: {str(e)}") + + def display_enhanced_comparison_results(self, results: List[Dict], query: str): + """Display enhanced comparison results with multiple rounds""" + if not results: + st.error("No results to display") + return + + # Sort by average score + sorted_results = sorted([r for r in results if r['success']], key=lambda x: x['avg_score'], reverse=True) + failed_results = [r for r in results if not r['success']] + + # Enhanced winner announcement + if sorted_results: + winner = sorted_results[0] + fastest = min(sorted_results, key=lambda x: x['avg_response_time']) + most_reliable = max(sorted_results, key=lambda x: x['success_rate']) + + st.markdown("### 🏆 Battle Results") + + col1, col2, col3 = st.columns(3) + + with col1: + st.markdown(f""" +
+

🏆 HIGHEST SCORE

+

{winner['model'].replace('-', ' ').title()}

+

Avg Score: {winner['avg_score']:.1f}/100

+
+ """, unsafe_allow_html=True) + + with col2: + st.markdown(f""" +
+

⚡ FASTEST

+

{fastest['model'].replace('-', ' ').title()}

+

Avg: {fastest['avg_response_time']:.0f}ms

+
+ """, unsafe_allow_html=True) + + with col3: + st.markdown(f""" +
+

🎯 MOST RELIABLE

+

{most_reliable['model'].replace('-', ' ').title()}

+

{most_reliable['success_rate']:.0f}% Success

+
+ """, unsafe_allow_html=True) + + st.markdown("---") + + # Performance comparison chart + st.markdown("### 📊 Performance Comparison") + + models = [r['model'].replace('-', ' ').replace('versatile', '').replace('8192', '').title() for r in sorted_results] + scores = [r['avg_score'] for r in sorted_results] + times = [r['avg_response_time'] for r in sorted_results] + + fig = go.Figure() + + fig.add_trace(go.Bar( + name='Average Score', + x=models, + y=scores, + yaxis='y', + marker_color='#FFD700', + text=[f"{s:.1f}" for s in scores], + textposition='auto' + )) + + fig.add_trace(go.Scatter( + name='Response Time (ms)', + x=models, + y=times, + yaxis='y2', + mode='lines+markers', + line=dict(color='#FF6B6B', width=3), + marker=dict(size=10) + )) + + fig.update_layout( + title=f'Model Performance: "{query[:50]}..."', + xaxis=dict(title='Models'), + yaxis=dict(title='Score (0-100)', side='left'), + yaxis2=dict(title='Response Time (ms)', side='right', overlaying='y'), + height=400, + plot_bgcolor='rgba(0,0,0,0)', + paper_bgcolor='rgba(0,0,0,0)', + font=dict(color='white') + ) + + st.plotly_chart(fig, use_container_width=True) + + # Detailed results + st.markdown("### 📋 Detailed Results") + for i, result in enumerate(sorted_results): + with st.expander(f"{'🥇' if i == 0 else '🥈' if i == 1 else '🥉' if i == 2 else '📊'} {result['model']} - Avg Score: {result['avg_score']:.1f}/100"): + + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Avg Score", f"{result['avg_score']:.1f}/100") + with col2: + st.metric("Avg Speed", f"{result['avg_response_time']:.0f}ms") + with col3: + st.metric("Success Rate", f"{result['success_rate']:.0f}%") + with col4: + st.metric("Total Tests", len(result['all_results'])) + + st.markdown("**Best Response:**") + best_response = result['best_response'] + + if "SQL:" in best_response and "INSIGHT:" in best_response: + parts = best_response.split("INSIGHT:") + sql_part = parts[0].replace("SQL:", "").strip() + insight_part = 
parts[1].strip() + + st.markdown("**SQL:**") + st.code(sql_part, language='sql') + st.markdown("**Insight:**") + st.markdown(insight_part) + else: + st.text_area("", best_response, height=150, key=f"best_response_{i}") + + # Show failed models + if failed_results: + st.markdown("### ❌ Failed Models") + for result in failed_results: + st.error(f"**{result['model']}**: All attempts failed") + + def perform_eda_inline(self, df: pd.DataFrame): + """Perform comprehensive EDA and display results inline""" + st.markdown("---") + st.markdown("## 🔬 Comprehensive EDA Results") + + try: + eda_results = self.eda_analyzer.perform_complete_eda(df) + + # Save EDA to history + eda_record = { + "timestamp": datetime.now().isoformat(), + "type": "EDA", + "data_shape": df.shape, + "results": "EDA analysis completed", + "session_id": st.session_state.session_id + } + + if self.db_manager: + self.db_manager.save_analysis(eda_record) + else: + st.session_state.analysis_history.append(eda_record) + + # Display EDA results inline + self.display_eda_results_inline(eda_results) + + except Exception as e: + st.error(f"EDA analysis failed: {str(e)}") + + def display_eda_results_inline(self, results: Dict): + """Display EDA results inline without navigation""" + # Create tabs for different sections + eda_tabs = st.tabs([ + "📊 Overview", + "📈 Distributions", + "🔗 Correlations", + "🎯 Insights", + "📋 Data Quality" + ]) + + with eda_tabs[0]: # Overview + st.markdown("### 📊 Dataset Overview") + overview = results.get('overview', {}) + + if overview: + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Rows", f"{overview.get('total_rows', 0):,}") + with col2: + st.metric("Total Columns", overview.get('total_columns', 0)) + with col3: + st.metric("Numeric Columns", overview.get('numeric_columns', 0)) + with col4: + st.metric("Categorical Columns", overview.get('categorical_columns', 0)) + + # Additional metrics + col5, col6, col7, col8 = st.columns(4) + with col5: + 
st.metric("Missing Values", f"{overview.get('missing_values_total', 0):,}") + with col6: + st.metric("Duplicate Rows", f"{overview.get('duplicate_rows', 0):,}") + with col7: + st.metric("Memory Usage", overview.get('memory_usage', '0 MB')) + with col8: + st.metric("DateTime Columns", overview.get('datetime_columns', 0)) + + if 'summary_stats' in overview: + st.markdown("### 📈 Summary Statistics") + st.dataframe(overview['summary_stats'], use_container_width=True) + + with eda_tabs[1]: # Distributions + st.markdown("### 📈 Data Distributions") + distributions = results.get('distributions', {}) + + if distributions: + for chart_name, chart_data in distributions.items(): + if hasattr(chart_data, 'update_layout'): # Check if it's a plotly figure + st.plotly_chart(chart_data, use_container_width=True) + else: + st.write(f"Chart: {chart_name}") + else: + st.info("No distribution charts available") + + with eda_tabs[2]: # Correlations + st.markdown("### 🔗 Correlation Analysis") + correlations = results.get('correlations', {}) + + if 'heatmap' in correlations: + st.plotly_chart(correlations['heatmap'], use_container_width=True) + + if 'top_correlations' in correlations: + st.markdown("#### 🔝 Top Correlations") + st.dataframe(correlations['top_correlations'], use_container_width=True) + + if 'scatter_matrix' in correlations: + st.plotly_chart(correlations['scatter_matrix'], use_container_width=True) + + with eda_tabs[3]: # Insights + st.markdown("### 🎯 Generated Insights") + insights = results.get('insights', []) + + if insights: + for i, insight in enumerate(insights): + with st.container(): + st.markdown(f""" +
+

💡 {insight.get('title', f'Insight {i+1}')}

+

{insight.get('description', 'No description available')}

+
+ """, unsafe_allow_html=True) + else: + st.info("No insights generated") + + with eda_tabs[4]: # Data Quality + st.markdown("### 📋 Data Quality Assessment") + quality = results.get('data_quality', {}) + + col1, col2 = st.columns(2) + + with col1: + st.markdown("#### Missing Values") + if 'missing_values' in quality: + st.plotly_chart(quality['missing_values'], use_container_width=True) + else: + st.info("No missing values chart available") + + with col2: + st.markdown("#### Data Types") + if 'data_types' in quality: + st.plotly_chart(quality['data_types'], use_container_width=True) + else: + st.info("No data types chart available") + + # Show duplicates info if available + if 'duplicates' in quality: + st.markdown("#### Duplicate Analysis") + dup_info = quality['duplicates'] + col1, col2 = st.columns(2) + with col1: + st.metric("Duplicate Count", dup_info.get('count', 0)) + with col2: + st.metric("Duplicate %", f"{dup_info.get('percentage', 0):.1f}%") + + def generate_ai_insights_inline(self, df: pd.DataFrame): + """Generate AI-powered insights and display inline""" + st.markdown("---") + st.markdown("## 🤖 AI-Generated Insights") + + try: + # Prepare data summary for AI + summary = f""" + Dataset Analysis: + - Rows: {len(df):,} + - Columns: {len(df.columns)} + - Schema: {st.session_state.data_schema} + - Missing values: {df.isnull().sum().sum():,} + + Column types: + {df.dtypes.to_string()} + + Sample data: + {df.head(3).to_string()} + """ + + prompt = f"""Analyze this dataset and provide 5 key business insights: + + {summary} + + Format as: + 1. **Insight Title**: Description + 2. **Insight Title**: Description + (etc.) 
+ + Focus on business value, patterns, and actionable recommendations.""" + + insights = self.make_api_call(st.session_state.selected_model, prompt) + + # Display insights in a nice format + st.markdown("### 💡 Business Intelligence Report") + + # Parse insights and display + parsed_insights = self.parse_insights_improved(insights) + + for i, insight in enumerate(parsed_insights): + st.markdown(f""" +
+

💡 {insight['title']}

+

{insight['text']}

+
+ """, unsafe_allow_html=True) + + # Save to history + analysis_record = { + "timestamp": datetime.now().isoformat(), + "type": "AI Insights", + "data_shape": df.shape, + "insights": insights, + "session_id": st.session_state.session_id + } + + if self.db_manager: + self.db_manager.save_analysis(analysis_record) + else: + st.session_state.analysis_history.append(analysis_record) + + except Exception as e: + st.error(f"Failed to generate AI insights: {str(e)}") + + def show_advanced_analytics_inline(self, df: pd.DataFrame): + """Show advanced analytics inline""" + st.markdown("---") + st.markdown("## 📊 Advanced Analytics") + + analytics_tabs = st.tabs(["📈 Statistical Summary", "🔍 Outlier Detection", "📊 Correlation Analysis", "🎯 Distribution Analysis"]) + + with analytics_tabs[0]: + self.show_statistical_summary_advanced(df) + + with analytics_tabs[1]: + self.show_outlier_analysis_advanced(df) + + with analytics_tabs[2]: + self.show_correlation_analysis_advanced(df) + + with analytics_tabs[3]: + self.show_distribution_analysis_advanced(df) + + def show_statistical_summary_advanced(self, df: pd.DataFrame): + """Advanced statistical summary""" + st.markdown("### 📈 Advanced Statistical Summary") + + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) > 0: + # Basic statistics + st.markdown("#### 📊 Descriptive Statistics") + st.dataframe(df[numeric_cols].describe(), use_container_width=True) + + # Advanced statistics + st.markdown("#### 📏 Advanced Distribution Metrics") + advanced_stats = pd.DataFrame({ + 'Column': numeric_cols, + 'Skewness': [df[col].skew() for col in numeric_cols], + 'Kurtosis': [df[col].kurtosis() for col in numeric_cols], + 'Missing %': [df[col].isnull().sum() / len(df) * 100 for col in numeric_cols], + 'Zeros %': [(df[col] == 0).sum() / len(df) * 100 for col in numeric_cols], + 'Unique Values': [df[col].nunique() for col in numeric_cols] + }) + st.dataframe(advanced_stats, use_container_width=True, hide_index=True) + 
+ # Statistical interpretation + st.markdown("#### 🎯 Statistical Interpretation") + for col in numeric_cols[:3]: # Limit to first 3 columns + skewness = df[col].skew() + kurtosis = df[col].kurtosis() + + skew_interpretation = ( + "Normal" if abs(skewness) < 0.5 else + "Slightly Skewed" if abs(skewness) < 1 else + "Moderately Skewed" if abs(skewness) < 2 else + "Highly Skewed" + ) + + kurt_interpretation = ( + "Normal" if abs(kurtosis) < 1 else + "Heavy-tailed" if kurtosis > 1 else + "Light-tailed" + ) + + st.markdown(f"**{col}:** {skew_interpretation} distribution, {kurt_interpretation} shape") + else: + st.info("No numeric columns found for statistical analysis") + + def show_outlier_analysis_advanced(self, df: pd.DataFrame): + """Advanced outlier analysis""" + st.markdown("### 🔍 Advanced Outlier Detection") + + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) > 0: + # Outlier summary for all columns + st.markdown("#### 📊 Outlier Summary (All Columns)") + + outlier_summary = [] + for col in numeric_cols: + data = df[col].dropna() + if len(data) > 0: + Q1 = data.quantile(0.25) + Q3 = data.quantile(0.75) + IQR = Q3 - Q1 + lower_bound = Q1 - 1.5 * IQR + upper_bound = Q3 + 1.5 * IQR + + outliers = df[(df[col] < lower_bound) | (df[col] > upper_bound)] + outlier_summary.append({ + 'Column': col, + 'Outlier Count': len(outliers), + 'Outlier %': len(outliers) / len(df) * 100, + 'Lower Bound': lower_bound, + 'Upper Bound': upper_bound + }) + + if outlier_summary: + outlier_df = pd.DataFrame(outlier_summary) + st.dataframe(outlier_df, use_container_width=True, hide_index=True) + + # Detailed analysis for selected column + st.markdown("#### 🎯 Detailed Outlier Analysis") + selected_col = st.selectbox("Select column for detailed analysis", numeric_cols, key="detailed_outlier") + + if selected_col: + data = df[selected_col].dropna() + Q1 = data.quantile(0.25) + Q3 = data.quantile(0.75) + IQR = Q3 - Q1 + lower_bound = Q1 - 1.5 * IQR + 
upper_bound = Q3 + 1.5 * IQR + + outliers = df[(df[selected_col] < lower_bound) | (df[selected_col] > upper_bound)] + + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Outliers", len(outliers)) + with col2: + st.metric("Outlier %", f"{len(outliers)/len(df)*100:.1f}%") + with col3: + st.metric("IQR", f"{IQR:.2f}") + with col4: + st.metric("Range", f"{upper_bound - lower_bound:.2f}") + + # Visualization + fig = go.Figure() + + # Box plot + fig.add_trace(go.Box( + y=data, + name=selected_col, + boxpoints='outliers' + )) + + fig.update_layout( + title=f'Box Plot with Outliers: {selected_col}', + height=400 + ) + + st.plotly_chart(fig, use_container_width=True) + + # Show actual outlier values + if len(outliers) > 0: + st.markdown("#### 📋 Outlier Values") + st.dataframe(outliers[[selected_col]].head(20), use_container_width=True) + else: + st.info("No numeric columns found for outlier analysis") + + def show_correlation_analysis_advanced(self, df: pd.DataFrame): + """Advanced correlation analysis""" + st.markdown("### 📊 Advanced Correlation Analysis") + + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) >= 2: + # Correlation matrix + corr_matrix = df[numeric_cols].corr() + + # Enhanced heatmap + fig = px.imshow(corr_matrix, + text_auto=True, + aspect="auto", + title="Enhanced Correlation Matrix", + color_continuous_scale="RdBu_r") + fig.update_layout(height=500) + st.plotly_chart(fig, use_container_width=True) + + # Correlation strength analysis + st.markdown("#### 🎯 Correlation Strength Analysis") + + corr_pairs = [] + for i in range(len(corr_matrix.columns)): + for j in range(i+1, len(corr_matrix.columns)): + corr_value = corr_matrix.iloc[i, j] + strength = ( + "Very Strong" if abs(corr_value) >= 0.8 else + "Strong" if abs(corr_value) >= 0.6 else + "Moderate" if abs(corr_value) >= 0.4 else + "Weak" if abs(corr_value) >= 0.2 else + "Very Weak" + ) + + corr_pairs.append({ + 'Variable 1': corr_matrix.columns[i], + 
'Variable 2': corr_matrix.columns[j], + 'Correlation': corr_value, + 'Abs Correlation': abs(corr_value), + 'Strength': strength + }) + + if corr_pairs: + corr_df = pd.DataFrame(corr_pairs) + corr_df = corr_df.sort_values('Abs Correlation', ascending=False) + + st.dataframe(corr_df[['Variable 1', 'Variable 2', 'Correlation', 'Strength']].head(15), + use_container_width=True, hide_index=True) + + # Strong correlations highlight + strong_corr = corr_df[corr_df['Abs Correlation'] >= 0.6] + if len(strong_corr) > 0: + st.markdown("#### ⚠️ Strong Correlations (>0.6)") + st.dataframe(strong_corr[['Variable 1', 'Variable 2', 'Correlation']], + use_container_width=True, hide_index=True) + st.warning("Strong correlations may indicate multicollinearity issues in modeling.") + else: + st.info("Need at least 2 numeric columns for correlation analysis") + + def show_distribution_analysis_advanced(self, df: pd.DataFrame): + """Advanced distribution analysis""" + st.markdown("### 🎯 Advanced Distribution Analysis") + + numeric_cols = df.select_dtypes(include=[np.number]).columns + categorical_cols = df.select_dtypes(include=['object', 'category']).columns + + if len(numeric_cols) > 0: + st.markdown("#### 📊 Numeric Distribution Analysis") + + selected_num_col = st.selectbox("Select numeric column", numeric_cols, key="dist_numeric") + + col1, col2 = st.columns(2) + + with col1: + # Histogram with normal curve overlay + data = df[selected_num_col].dropna() + + fig = go.Figure() + + # Histogram + fig.add_trace(go.Histogram( + x=data, + name='Distribution', + opacity=0.7, + nbinsx=30 + )) + + fig.update_layout( + title=f'Distribution of {selected_num_col}', + xaxis_title=selected_num_col, + yaxis_title='Frequency', + height=400 + ) + + st.plotly_chart(fig, use_container_width=True) + + with col2: + # QQ plot approximation using percentiles + fig = go.Figure() + + # Box plot + fig.add_trace(go.Box( + y=data, + name=selected_num_col, + boxpoints='all', + jitter=0.3, + pointpos=-1.8 + )) + 
+ fig.update_layout( + title=f'Box Plot of {selected_num_col}', + height=400 + ) + + st.plotly_chart(fig, use_container_width=True) + + # Distribution statistics + st.markdown("#### 📈 Distribution Statistics") + col1, col2, col3, col4 = st.columns(4) + + with col1: + st.metric("Mean", f"{data.mean():.2f}") + with col2: + st.metric("Median", f"{data.median():.2f}") + with col3: + st.metric("Std Dev", f"{data.std():.2f}") + with col4: + st.metric("Range", f"{data.max() - data.min():.2f}") + + if len(categorical_cols) > 0: + st.markdown("#### 🏷️ Categorical Distribution Analysis") + + selected_cat_col = st.selectbox("Select categorical column", categorical_cols, key="dist_categorical") + + value_counts = df[selected_cat_col].value_counts() + + col1, col2 = st.columns(2) + + with col1: + # Bar chart + fig = px.bar( + x=value_counts.index[:15], # Top 15 + y=value_counts.values[:15], + title=f'Top Categories in {selected_cat_col}' + ) + st.plotly_chart(fig, use_container_width=True) + + with col2: + # Pie chart + fig = px.pie( + values=value_counts.values[:10], # Top 10 + names=value_counts.index[:10], + title=f'Distribution of {selected_cat_col}' + ) + st.plotly_chart(fig, use_container_width=True) + + # Category statistics + st.markdown("#### 📊 Category Statistics") + col1, col2, col3, col4 = st.columns(4) + + with col1: + st.metric("Unique Categories", df[selected_cat_col].nunique()) + with col2: + st.metric("Most Frequent", value_counts.index[0]) + with col3: + st.metric("Frequency", value_counts.iloc[0]) + with col4: + st.metric("Missing Values", df[selected_cat_col].isnull().sum()) + + def run(self): + """Main application runner""" + self.render_header() + + # Render sidebar configuration + self.render_api_config() + + # Show API key warning if not configured, but don't block the app + if not st.session_state.api_key: + with st.expander("🔑 API Key Configuration (Optional for Basic Features)", expanded=False): + st.warning("**AI-powered features require a Groq API 
key**") + st.markdown(""" + **For local development:** + 1. Create a `.env` file in your project directory + 2. Add: `GROQ_API_KEY=your_api_key_here` + 3. Restart the application + + **For Streamlit Cloud:** + 1. Go to your app settings + 2. Add to secrets: `GROQ_API_KEY = "your_api_key_here"` + 3. Redeploy the app + + **Get your API key:** [Groq Console](https://console.groq.com/keys) + + **Available without API key:** + - Data upload and preview + - Statistical analysis + - Data visualizations + - Correlation analysis + - Outlier detection + - Complete EDA reports + """) + + # Main content area - always show regardless of API key status + if st.session_state.uploaded_data is not None: + # Show query interface when data is loaded + st.markdown("---") + self.render_query_interface() + else: + # Show data upload when no data is loaded + self.render_data_upload() + +def main(): + """Main function""" + app = NeuralDataAnalyst() + app.run() + +if __name__ == "__main__": + main() \ No newline at end of file