Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import google.generativeai as genai | |
| import os | |
| import json | |
| import base64 | |
| from dotenv import load_dotenv | |
| from streamlit_local_storage import LocalStorage | |
| import plotly.graph_objects as go | |
| import numpy as np | |
| import re | |
# --- PAGE CONFIGURATION ---
# Browser-tab title/icon and a wide page layout.
st.set_page_config(
    layout="wide",
    page_title="Math Jegna - Your AI Math Tutor",
    page_icon="π§ ",
)

# Browser local-storage handle; used later in the script to restore and save
# the "math_mentor_chats" entry.
localS = LocalStorage()
# --- HELPER FUNCTIONS ---
def format_chat_for_download(chat_history):
    """Render a chat history as a Markdown document for download.

    Args:
        chat_history: Sequence of message dicts with "role" and "content" keys.

    Returns:
        A Markdown string: a title line, then one bold speaker label and message
        body per entry, each followed by a horizontal rule. Messages whose role
        is "user" are labelled "You"; every other role is labelled "Math Mentor".
    """
    pieces = []
    for entry in chat_history:
        speaker = "You" if entry["role"] == "user" else "Math Mentor"
        pieces.append(f"**{speaker}:**\n{entry['content']}\n\n---\n\n")
    return "# Math Mentor Chat\n\n" + "".join(pieces)
def convert_role_for_gemini(role):
    """Translate a Streamlit chat role into the role name Gemini expects.

    Streamlit labels the bot's turns "assistant" whereas Gemini calls them
    "model"; any other role (e.g. "user") is returned unchanged.
    """
    is_bot_turn = role == "assistant"
    return "model" if is_bot_turn else role
def should_generate_visual(user_prompt, ai_response):
    """Decide whether the exchange touches a K-12 topic worth a picture.

    The prompt and the reply are joined with a space, lower-cased, and searched
    for any of a fixed set of elementary-math keywords. Matching is by plain
    substring, so e.g. "address" matches "add".

    Returns:
        True if at least one keyword occurs, otherwise False.
    """
    topic_words = (
        'add', 'subtract', 'multiply', 'divide', 'times', 'plus', 'minus',
        'count', 'counting', 'number', 'numbers', 'fraction', 'fractions',
        'shape', 'shapes', 'circle', 'square', 'triangle', 'rectangle',
        'money', 'coins', 'dollars', 'cents', 'time', 'clock',
        'pattern', 'patterns', 'grouping', 'groups', 'tens', 'ones',
        'place value', 'hundred', 'thousand', 'comparison', 'greater', 'less',
        'equal', 'equals', 'measurement', 'length', 'height', 'weight',
    )
    haystack = " ".join((user_prompt, ai_response)).lower()
    for word in topic_words:
        if word in haystack:
            return True
    return False
def create_counting_visual(numbers):
    """Create a dot-counting picture for one or more small whole numbers.

    Each number becomes its own coloured group of dots (at most five per row)
    with the number written underneath. Only the first six entries are
    considered, and values outside 1..20 are skipped.

    Args:
        numbers: A single number or a list of numbers to visualise.

    Returns:
        A plotly ``Figure``, or ``None`` when no entry is drawable or the
        figure could not be built.
    """
    try:
        if not isinstance(numbers, list):
            numbers = [numbers]
        colors = ['red', 'blue', 'green', 'orange', 'purple', 'yellow']
        # Keep each value's original slot so its colour and horizontal position
        # follow its place in the input. Values outside 1..20 are skipped:
        # 0 would divide by zero in the row arithmetic below and negatives
        # would draw nothing.
        groups = [(i, num) for i, num in enumerate(numbers[:6]) if 1 <= num <= 20]
        if not groups:
            return None

        fig = go.Figure()
        for i, num in groups:
            dots_per_row = min(5, num)
            rows = (num - 1) // dots_per_row + 1
            x_offset = i * 7  # separate the groups horizontally
            fig.add_trace(go.Scatter(
                x=[x_offset + dot % dots_per_row for dot in range(num)],
                # First row on top (y = rows - 1), the last row at y = 0.
                y=[rows - 1 - dot // dots_per_row for dot in range(num)],
                mode='markers',
                marker=dict(
                    size=20,
                    color=colors[i],
                    symbol='circle',
                    line=dict(width=2, color='black'),
                ),
                name=f'{num} items',
                showlegend=True,
            ))
            # Number label centred just below the bottom row. Layout annotations
            # have no `showlegend` property, so none is passed; showarrow=False
            # draws plain text instead of an arrow callout.
            fig.add_annotation(
                x=x_offset + (dots_per_row - 1) / 2,
                y=-1,
                text=str(num),
                font=dict(size=24, color=colors[i]),
                showarrow=False,
            )
        fig.update_layout(
            title="Counting Visualization",
            showlegend=True,
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            height=400,
            template="simple_white",
        )
        return fig
    except Exception:
        # Best effort: a missing picture must not break the chat.
        return None
def create_addition_visual(num1, num2):
    """Show ``num1 + num2`` as two rows of counters with the sum after them.

    Red circles stand for ``num1``, blue circles for ``num2``. A "+" sits in
    the gap between the groups, followed by "=" and the total in green.

    Args:
        num1: First addend, expected 0..10.
        num2: Second addend, expected 0..10.

    Returns:
        A plotly ``Figure``, or ``None`` if an addend is out of range or the
        figure could not be built.
    """
    try:
        # Keep it simple for young learners: small, non-negative addends only.
        if not (0 <= num1 <= 10 and 0 <= num2 <= 10):
            return None
        total = num1 + num2
        fig = go.Figure()

        # First addend: red circles at x = 0 .. num1-1.
        fig.add_trace(go.Scatter(
            x=list(range(num1)), y=[1] * num1,
            mode='markers',
            marker=dict(size=25, color='red', symbol='circle', line=dict(width=2, color='black')),
            name=f'First group: {num1}',
            showlegend=True,
        ))
        # Second addend: blue circles at x = num1+1 .. total; x = num1 is left
        # empty for the plus sign.
        fig.add_trace(go.Scatter(
            x=list(range(num1 + 1, total + 1)), y=[1] * num2,
            mode='markers',
            marker=dict(size=25, color='blue', symbol='circle', line=dict(width=2, color='black')),
            name=f'Second group: {num2}',
            showlegend=True,
        ))

        # "+" in the gap, then "=" and the sum, on the same row as the circles.
        # Layout annotations have no `showlegend` property, so none is passed;
        # showarrow=False draws plain text instead of arrow callouts.
        symbols = (
            (num1, "+", 'black'),
            (total + 1, "=", 'black'),
            (total + 2, str(total), 'green'),
        )
        for x_pos, text, color in symbols:
            fig.add_annotation(
                x=x_pos,
                y=1,
                text=text,
                font=dict(size=30, color=color),
                showarrow=False,
            )

        fig.update_layout(
            title=f"Addition: {num1} + {num2} = {total}",
            showlegend=True,
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False, range=[-1, total + 3]),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False, range=[0, 2.5]),
            height=300,
            template="simple_white",
        )
        return fig
    except Exception:
        # Best effort: a missing picture must not break the chat.
        return None
def create_fraction_visual(numerator, denominator):
    """Draw ``numerator/denominator`` as a circle cut into equal slices.

    The first ``numerator`` slices (counter-clockwise from angle 0) are shaded
    light blue; the remaining slices are light grey.

    Args:
        numerator: Number of shaded slices, 0..denominator.
        denominator: Number of equal slices, 1..12.

    Returns:
        A plotly ``Figure``, or ``None`` for an unsupported fraction or if the
        figure could not be built.
    """
    try:
        # Only proper fractions with 1-12 equal parts: more slices are hard to
        # read, and a zero or negative denominator has no slices to draw.
        if not (1 <= denominator <= 12 and 0 <= numerator <= denominator):
            return None
        fig = go.Figure()
        # Slice boundary angles around the full circle.
        angles = np.linspace(0, 2 * np.pi, denominator + 1)
        for i in range(denominator):
            theta = np.linspace(angles[i], angles[i + 1], 50)
            # Close each arc through the centre so `fill='toself'` shades a wedge.
            x = np.concatenate([[0], np.cos(theta), [0]])
            y = np.concatenate([[0], np.sin(theta), [0]])
            shaded = i < numerator
            fig.add_trace(go.Scatter(
                x=x, y=y,
                fill='toself',
                mode='lines',
                line=dict(color='black', width=2),
                fillcolor='lightblue' if shaded else 'lightgray',
                name=f'Slice {i+1}' if shaded else '',
                showlegend=False,
            ))
        fig.update_layout(
            title=f"Fraction: {numerator}/{denominator}",
            # Equal axis scaling keeps the circle round.
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False, scaleanchor="y", scaleratio=1),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            height=400,
            template="simple_white",
        )
        return fig
    except Exception:
        # Best effort: a missing picture must not break the chat.
        return None
def create_place_value_visual(number):
    """Draw a place-value chart for a whole number from 0 to 9999.

    Each non-zero digit is shown as that many coloured squares in its column
    (thousands, hundreds, tens, ones). Every column is labelled with its name
    and digit, including zero digits, so zero's placeholder role is visible.

    Args:
        number: Whole number to break into places.

    Returns:
        A plotly ``Figure``, or ``None`` if ``number`` is out of range or the
        figure could not be built.
    """
    try:
        # Keep it reasonable for elementary learners; negatives would also break
        # the digit split below ("-5".zfill(4) == "-005").
        if not 0 <= number <= 9999:
            return None
        # Left-pad to four digits: thousands, hundreds, tens, ones.
        digits = [int(ch) for ch in str(number).zfill(4)]

        fig = go.Figure()
        positions = [0, 2, 4, 6]  # x positions for thousands, hundreds, tens, ones
        labels = ['Thousands', 'Hundreds', 'Tens', 'Ones']
        colors = ['red', 'blue', 'green', 'orange']
        for pos, val, label, color in zip(positions, digits, labels, colors):
            if val > 0:
                # Up to five squares per row; extra rows stack upwards.
                blocks_per_row = min(5, val)
                fig.add_trace(go.Scatter(
                    x=[pos + (block % blocks_per_row) * 0.3 for block in range(val)],
                    y=[(block // blocks_per_row) * 0.3 for block in range(val)],
                    mode='markers',
                    marker=dict(
                        size=15,
                        color=color,
                        symbol='square',
                        line=dict(width=1, color='black'),
                    ),
                    name=f'{label}: {val}',
                    showlegend=True,
                ))
            # Column name and digit under each column. Layout annotations have
            # no `showlegend` property, so none is passed; showarrow=False draws
            # plain text instead of arrow callouts.
            fig.add_annotation(
                x=pos + 0.6, y=-0.5, text=label,
                font=dict(size=12), showarrow=False,
            )
            fig.add_annotation(
                x=pos + 0.6, y=-0.8, text=str(val),
                font=dict(size=16, color=color), showarrow=False,
            )
        fig.update_layout(
            title=f"Place Value: {number}",
            showlegend=True,
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False, range=[-0.5, 7]),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            height=400,
            template="simple_white",
        )
        return fig
    except Exception:
        # Best effort: a missing picture must not break the chat.
        return None
def generate_k12_visual(user_prompt, ai_response):
    """Pick and build an age-appropriate visual aid for the student's question.

    Only ``user_prompt`` is inspected; ``ai_response`` is accepted for
    interface compatibility but not used. Rules are tried in order: counting,
    "a + b" addition, fractions, place value, then a bare number with
    "show"/"count"/"number".

    Returns:
        A plotly ``Figure`` from one of the ``create_*_visual`` helpers, or
        ``None`` when no rule applies or the helper declines. Unexpected errors
        are reported with ``st.error`` and also yield ``None``.
    """
    try:
        user_lower = user_prompt.lower()

        # COUNTING: "count to 7" -> 7; "counting" without a number defaults to 5.
        count_match = re.search(r'count.*?(\d+)', user_lower)
        if count_match or 'counting' in user_lower:
            number = int(count_match.group(1)) if count_match else 5
            return create_counting_visual(number)

        # SIMPLE ADDITION: only when the prompt literally contains "<a> + <b>".
        # Gating on the match itself means a lone "+" ("what does + mean?")
        # no longer reaches add_match.group() with add_match = None.
        add_match = re.search(r'(\d+)\s*\+\s*(\d+)', user_prompt)
        if add_match:
            num1, num2 = int(add_match.group(1)), int(add_match.group(2))
            if num1 <= 10 and num2 <= 10:  # Keep it simple
                return create_addition_visual(num1, num2)

        # FRACTIONS: "3/4" -> 3/4; "fraction" without one defaults to 1/2.
        fraction_match = re.search(r'(\d+)/(\d+)', user_prompt)
        if fraction_match or 'fraction' in user_lower:
            if fraction_match:
                num, den = int(fraction_match.group(1)), int(fraction_match.group(2))
            else:
                num, den = 1, 2
            return create_fraction_visual(num, den)

        # PLACE VALUE: first standalone 1-4 digit number in the prompt.
        if 'place' in user_lower:
            place_match = re.search(r'\b(\d{1,4})\b', user_prompt)
            if place_match:
                return create_place_value_visual(int(place_match.group(1)))

        # NUMBERS (general counting): "show me 8" and similar.
        number_match = re.search(r'\b(\d+)\b', user_prompt)
        if number_match and any(word in user_lower for word in ['show', 'count', 'number']):
            number = int(number_match.group(1))
            if 1 <= number <= 20:
                return create_counting_visual(number)

        return None
    except Exception as e:
        st.error(f"Could not generate K-12 visual: {e}")
        return None
# --- API KEY & MODEL CONFIGURATION ---
load_dotenv()  # make values from a local .env file visible to os.getenv

# Persona, scope and teaching approach for the tutor model.
_SYSTEM_INSTRUCTION = """
You are "Math Jegna", an AI math tutor specializing in K-12 mathematics using the Professor B methodology.
FOCUS ON ELEMENTARY CONCEPTS:
- Basic counting (1-100)
- Simple addition and subtraction (single digits to start)
- Beginning multiplication (times tables)
- Basic fractions (halves, thirds, quarters)
- Place value (ones, tens, hundreds)
- Shape recognition
- Simple word problems
- Money and time concepts
PROFESSOR B METHODOLOGY - ESSENTIAL PRINCIPLES:
1. Present math as a STORY that connects ideas
2. Use MENTAL GYMNASTICS - fun games and finger counting
3. Build from concrete to abstract naturally
4. NO ROTE MEMORIZATION - focus on understanding patterns and connections
5. Eliminate math anxiety through simple, truthful explanations
6. Use manipulatives and visual aids
7. Allow accelerated learning when student shows mastery
TEACHING STYLE:
- Start with what the child already knows
- Build new concepts as natural extensions of previous learning
- Use simple, clear language appropriate for the age
- Make math enjoyable and reduce tension
- Connect everything to real-world experiences
- Celebrate understanding, not just correct answers
VISUAL AIDS: Mention when visual aids will help, using phrases like:
- "Let me show you this with counting dots..."
- "I'll create a picture to help you see this..."
- "A visual will make this clearer..."
Remember: You're helping young minds discover the beauty and logic of mathematics through stories and connections, not through drilling and memorization.
You are strictly forbidden from answering non-mathematical questions. If asked non-math questions, respond only with: "I can only answer mathematical questions. Please ask me a question about counting, adding, shapes, or another math topic."
"""

# Streamlit secrets win; without a secrets entry/file, fall back to the environment.
try:
    api_key = st.secrets["GOOGLE_API_KEY"]
except (KeyError, FileNotFoundError):
    api_key = os.getenv("GOOGLE_API_KEY")

if not api_key:
    st.error("π¨ Google API Key not found! Please add it to your secrets or a local .env file.")
    st.stop()
else:
    genai.configure(api_key=api_key)
    # Main text model used for every chat turn.
    model = genai.GenerativeModel(
        model_name="gemini-2.5-flash-lite",
        system_instruction=_SYSTEM_INSTRUCTION,
    )
# --- SESSION STATE & LOCAL STORAGE INITIALIZATION ---
def _default_chats():
    """Return a fresh chat collection holding only the welcome conversation."""
    return {
        "New Chat": [
            {"role": "assistant", "content": "Hello! I'm Math Jegna, your friendly math helper! π§ β¨\n\nI love helping kids learn math through fun stories and pictures. Try asking me about:\n- Counting numbers\n- Adding or subtracting\n- Fractions like 1/2\n- Shapes and patterns\n- Or any math question!\n\nWhat would you like to learn about today?"}
        ]
    }


def _is_valid_chat(messages):
    """Return True if *messages* is a list of dicts that each have "role" and "content"."""
    return isinstance(messages, list) and all(
        isinstance(msg, dict) and "role" in msg and "content" in msg
        for msg in messages
    )


def _load_shared_chat():
    """Decode the chat passed in the ``shared_chat`` query parameter.

    The parameter is URL-safe base64 of a JSON list of messages.

    Returns:
        The message list, or None when the parameter is absent or does not
        decode to a valid chat.
    """
    encoded = st.query_params.get("shared_chat")
    if not encoded:
        return None
    try:
        chat = json.loads(base64.urlsafe_b64decode(encoded).decode())
    except (TypeError, ValueError):
        # binascii.Error, UnicodeDecodeError and JSONDecodeError are all ValueErrors.
        return None
    return chat if _is_valid_chat(chat) else None


def _load_saved_chats():
    """Load chats previously stored under "math_mentor_chats" in local storage.

    Returns:
        ``(chats, active_chat_key)``, or None when nothing usable is stored.
        Invalid chats are dropped. The active key falls back to the first chat
        when the stored one is missing.
    """
    saved_data_json = localS.getItem("math_mentor_chats")
    if not saved_data_json:
        return None
    try:
        saved_data = json.loads(saved_data_json)
    except (TypeError, ValueError):
        return None
    if not isinstance(saved_data, dict) or not isinstance(saved_data.get("chats"), dict):
        return None
    chats = {key: msgs for key, msgs in saved_data["chats"].items() if _is_valid_chat(msgs)}
    if not chats:
        return None
    active_chat_key = saved_data.get("active_chat_key", "New Chat")
    if not isinstance(active_chat_key, str) or active_chat_key not in chats:
        active_chat_key = next(iter(chats))
    return chats, active_chat_key


if "chats" not in st.session_state:
    saved = _load_saved_chats()
    chats, active_chat_key = saved if saved is not None else (_default_chats(), "New Chat")

    shared_chat = _load_shared_chat()
    if shared_chat is not None:
        # Add the shared conversation next to the visitor's own chats rather
        # than replacing them: the session's chats are written back to local
        # storage at the end of the script, so replacing would erase them.
        shared_key = "Shared Chat"
        suffix = 1
        while shared_key in chats:
            suffix += 1
            shared_key = f"Shared Chat {suffix}"
        chats[shared_key] = shared_chat
        active_chat_key = shared_key
        st.query_params.clear()

    st.session_state.chats = chats
    st.session_state.active_chat_key = active_chat_key
# --- RENAME DIALOG ---
@st.dialog("Rename Chat")
def rename_chat(chat_key):
    """Modal dialog that renames the chat stored under *chat_key*.

    It must be a dialog. It is opened from a sidebar button, and a plain call
    would only render on the one run where that button reads True. The
    "Save" click starts a new run in which the dialog would no longer exist,
    so it would never be handled.

    On success the chat keeps its position in the list, becomes the active
    chat, and the app reruns. Empty (or whitespace-only) names and names
    already in use are rejected with an error message.
    """
    st.write(f"Enter a new name for '{chat_key}':")
    new_name = st.text_input("New Name", key=f"rename_input_{chat_key}").strip()
    if st.button("Save", key=f"save_rename_{chat_key}"):
        if not new_name:
            st.error("Name cannot be empty.")
        elif new_name in st.session_state.chats:
            st.error("A chat with this name already exists.")
        else:
            # Rebuild the dict so the renamed chat keeps its sidebar position.
            st.session_state.chats = {
                (new_name if key == chat_key else key): messages
                for key, messages in st.session_state.chats.items()
            }
            st.session_state.active_chat_key = new_name
            st.rerun()
# --- DELETE CONFIRMATION DIALOG ---
@st.dialog("Delete Chat")
def delete_chat(chat_key):
    """Modal dialog asking the user to confirm deleting the chat *chat_key*.

    It must be a dialog. As a plain call it would only render on the run where
    the sidebar "Delete" button was clicked, and the confirmation click
    (a new run) would never be handled.

    After deletion, if the active chat no longer exists, the first remaining
    chat becomes active. Then the app reruns.
    """
    st.warning(f"Are you sure you want to delete '{chat_key}'? This cannot be undone.")
    if st.button("Yes, Delete", type="primary", key=f"confirm_delete_{chat_key}"):
        st.session_state.chats.pop(chat_key, None)
        if st.session_state.active_chat_key not in st.session_state.chats:
            st.session_state.active_chat_key = next(iter(st.session_state.chats))
        st.rerun()
# --- SIDEBAR CHAT MANAGEMENT ---
st.sidebar.title("π My Math Chats")
st.sidebar.divider()

if st.sidebar.button("β New Chat", use_container_width=True):
    # First unused "New Chat N" name.
    i = 1
    while f"New Chat {i}" in st.session_state.chats:
        i += 1
    new_chat_key = f"New Chat {i}"
    st.session_state.chats[new_chat_key] = [
        {"role": "assistant", "content": "Hi there! Ready for some fun math learning? Ask me about counting, adding, shapes, or anything else! ππ’"}
    ]
    st.session_state.active_chat_key = new_chat_key
    st.rerun()

st.sidebar.divider()

# Iterate over a snapshot of the keys: the rename/delete dialogs mutate the dict.
for chat_key in list(st.session_state.chats.keys()):
    is_active = (chat_key == st.session_state.active_chat_key)
    expander_label = f"**{chat_key} (Active)**" if is_active else chat_key
    with st.sidebar.expander(expander_label):
        if st.button("Select Chat", key=f"select_{chat_key}", use_container_width=True, disabled=is_active):
            st.session_state.active_chat_key = chat_key
            st.rerun()
        if st.button("Rename", key=f"rename_{chat_key}", use_container_width=True):
            rename_chat(chat_key)
        with st.popover("Share", use_container_width=True):
            st.markdown("**Download Conversation**")
            st.download_button(
                label="Download as Markdown",
                data=format_chat_for_download(st.session_state.chats[chat_key]),
                file_name=f"{chat_key.replace(' ', '_')}.md",
                mime="text/markdown",
                # Explicit key so each chat's button has a stable, unique identity.
                key=f"download_{chat_key}",
            )
            st.markdown("**Share via Link**")
            # Same encoding the start-up code decodes from ?shared_chat=...:
            # URL-safe base64 of the chat's JSON message list.
            share_token = base64.urlsafe_b64encode(
                json.dumps(st.session_state.chats[chat_key]).encode()
            ).decode()
            st.info("To share, add the text below to the end of this app's web address and send the full link to someone.")
            st.code(f"?shared_chat={share_token}", language=None)
        if st.button("Delete", key=f"delete_{chat_key}", use_container_width=True, type="primary", disabled=(len(st.session_state.chats) <= 1)):
            delete_chat(chat_key)
# --- MAIN CHAT INTERFACE ---
# If the remembered active key no longer names a chat, fall back to the first
# chat instead of failing with KeyError on the lookup below.
if st.session_state.active_chat_key not in st.session_state.chats:
    st.session_state.active_chat_key = next(iter(st.session_state.chats))
active_chat = st.session_state.chats[st.session_state.active_chat_key]

st.title(f"Math Helper: {st.session_state.active_chat_key} π§ ")
st.write("π― Perfect for young learners! Ask about counting, adding, shapes, fractions, and more!")

# Example prompts for young learners.
with st.expander("π‘ Try asking me about..."):
    st.write("""
- **Counting**: "Show me how to count to 10"
- **Addition**: "What is 3 + 4?"
- **Fractions**: "What is 1/2?"
- **Place Value**: "What is the place value of 325?"
- **Shapes**: "Tell me about triangles"
- **Time**: "How do I read a clock?"
""")

# Replay the stored conversation (message text only).
for message in active_chat:
    with st.chat_message(name=message["role"], avatar="π§βπ»" if message["role"] == "user" else "π§ "):
        st.markdown(message["content"])
# --- CHAT INPUT & RESPONSE ---
# A submitted question is echoed, answered by Gemini with the earlier turns as
# context, and possibly followed by a generated picture.
user_prompt = st.chat_input("Ask me a math question!")
if user_prompt:
    active_chat.append({"role": "user", "content": user_prompt})
    with st.chat_message("user", avatar="π§βπ»"):
        st.markdown(user_prompt)

    with st.chat_message("assistant", avatar="π§ "), st.spinner("Math Jegna is thinking... π€"):
        try:
            # All turns before the prompt just appended become Gemini history;
            # the prompt itself is sent as the new message.
            history = [
                {'role': convert_role_for_gemini(turn['role']), 'parts': [turn['content']]}
                for turn in active_chat[:-1]
                if 'content' in turn
            ]
            reply_text = model.start_chat(history=history).send_message(user_prompt).text
            st.markdown(reply_text)
            active_chat.append({"role": "assistant", "content": reply_text})

            # Optionally follow the answer with a visual aid.
            if should_generate_visual(user_prompt, reply_text):
                with st.spinner("Creating a helpful picture... π¨"):
                    figure = generate_k12_visual(user_prompt, reply_text)
                    if figure:
                        st.plotly_chart(figure, use_container_width=True)
                        st.success("β¨ Here's a picture to help you understand!")
        except Exception as exc:
            # Show the failure and keep it in the transcript as an assistant turn.
            error_message = f"Oops! Something went wrong. Let me try again! π€\n\n**Error:** {exc}"
            st.error(error_message)
            active_chat.append({"role": "assistant", "content": error_message})
# --- SAVE DATA TO LOCAL STORAGE ---
# Write all chats plus the active selection back to the "math_mentor_chats"
# entry that start-up reads from.
localS.setItem(
    "math_mentor_chats",
    json.dumps({
        "chats": st.session_state.chats,
        "active_chat_key": st.session_state.active_chat_key,
    }),
)