File size: 22,055 Bytes
454acc0
 
 
 
 
 
 
19dba77
 
 
454acc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0fc94b
 
 
 
 
 
9b78dcf
c1e6e2a
 
 
 
 
 
 
 
 
9b78dcf
 
 
c1e6e2a
9b78dcf
c1e6e2a
 
9b78dcf
c1e6e2a
 
9b78dcf
c1e6e2a
 
 
 
 
 
 
 
19dba77
c1e6e2a
 
19dba77
c1e6e2a
 
 
 
 
19dba77
 
c1e6e2a
 
 
 
 
 
 
 
 
 
 
19dba77
 
c1e6e2a
 
 
 
 
 
 
19dba77
9b78dcf
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dba77
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dba77
c1e6e2a
 
 
 
 
 
 
 
 
 
 
19dba77
c1e6e2a
 
 
19dba77
c1e6e2a
19dba77
c1e6e2a
 
 
 
 
 
 
 
 
9b78dcf
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dba77
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dba77
c1e6e2a
 
19dba77
c1e6e2a
 
 
 
 
19dba77
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
19dba77
c1e6e2a
 
 
 
 
 
 
19dba77
 
c1e6e2a
 
 
 
 
 
 
19dba77
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b78dcf
 
 
 
c1e6e2a
9b78dcf
 
454acc0
 
 
 
 
 
 
 
 
 
 
9b78dcf
 
454acc0
 
 
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454acc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e6e2a
454acc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e6e2a
454acc0
 
 
 
 
 
 
 
c1e6e2a
454acc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e6e2a
 
 
 
 
 
 
 
 
 
 
 
 
454acc0
 
 
 
 
c1e6e2a
454acc0
 
 
 
 
c1e6e2a
454acc0
9b78dcf
c0fc94b
 
9b78dcf
c0fc94b
454acc0
 
 
9b78dcf
 
19dba77
9b78dcf
 
 
c1e6e2a
 
 
 
 
454acc0
 
c1e6e2a
454acc0
 
 
 
 
19dba77
454acc0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
import streamlit as st
import google.generativeai as genai
import os
import json
import base64
from dotenv import load_dotenv
from streamlit_local_storage import LocalStorage
import plotly.graph_objects as go
import numpy as np
import re

# --- PAGE CONFIGURATION ---
# Browser-tab title and icon, plus a full-width layout for the tutor UI.
st.set_page_config(
    layout="wide",
    page_icon="🧠",
    page_title="Math Jegna - Your AI Math Tutor",
)

# Handle to the browser's localStorage, used to persist chats between visits.
localS = LocalStorage()

# --- HELPER FUNCTIONS ---
def format_chat_for_download(chat_history):
    """Render a chat transcript as Markdown text suitable for a file download.

    Each message becomes a bold speaker heading ("You" for user messages,
    "Math Mentor" for every other role), followed by the message content and
    a horizontal-rule separator.

    Args:
        chat_history: Sequence of dicts with "role" and "content" keys.

    Returns:
        The complete Markdown document as one string.
    """
    sections = ["# Math Mentor Chat\n\n"]
    for entry in chat_history:
        speaker = "Math Mentor"
        if entry["role"] == "user":
            speaker = "You"
        sections.append(f"**{speaker}:**\n{entry['content']}\n\n---\n\n")
    return "".join(sections)

def convert_role_for_gemini(role):
    """Translate a Streamlit chat role into the role name Gemini expects.

    Streamlit labels AI turns "assistant", while the Gemini API calls them
    "model". Any other role (such as "user") is passed through unchanged.
    """
    return "model" if role == "assistant" else role

def should_generate_visual(user_prompt, ai_response):
    """Decide whether a K-12 visual aid is worth attempting for this exchange.

    The user's prompt and the tutor's reply are joined with a space,
    lower-cased, and scanned for elementary-math trigger words. Matching is
    plain substring containment, so e.g. "add" also matches "adding".

    Returns:
        True if at least one trigger word occurs, otherwise False.
    """
    trigger_words = (
        'add', 'subtract', 'multiply', 'divide', 'times', 'plus', 'minus',
        'count', 'counting', 'number', 'numbers', 'fraction', 'fractions',
        'shape', 'shapes', 'circle', 'square', 'triangle', 'rectangle',
        'money', 'coins', 'dollars', 'cents', 'time', 'clock',
        'pattern', 'patterns', 'grouping', 'groups', 'tens', 'ones',
        'place value', 'hundred', 'thousand', 'comparison', 'greater', 'less',
        'equal', 'equals', 'measurement', 'length', 'height', 'weight',
    )

    haystack = " ".join((user_prompt, ai_response)).lower()
    for word in trigger_words:
        if word in haystack:
            return True
    return False

def create_counting_visual(numbers):
    """Create visual counting aids using dots/circles.

    Each count becomes a group of circles laid out five per row (first row on
    top), with the count printed beneath it. Groups sit side by side, 7
    x-units apart, each in its own color.

    Args:
        numbers: A single count or a list of counts. Only the first six are
            considered, and only counts from 1 to 20 are drawn. A skipped
            count still reserves its slot and color.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when no count is
        drawable or the input cannot be plotted (e.g. non-numeric counts).
    """
    try:
        if not isinstance(numbers, list):
            numbers = [numbers]

        # (slot, count) pairs worth drawing. Zero/negative counts are excluded
        # because they would make dots_per_row zero or negative below.
        groups = [(i, num) for i, num in enumerate(numbers[:6]) if 1 <= num <= 20]
        if not groups:
            # Nothing drawable: avoid returning an empty, title-only chart.
            return None

        fig = go.Figure()
        colors = ['red', 'blue', 'green', 'orange', 'purple', 'yellow']

        for i, num in groups:
            dots_per_row = min(5, num)
            rows = (num - 1) // dots_per_row + 1
            x_offset = i * 7  # Separate groups

            fig.add_trace(go.Scatter(
                x=[x_offset + dot % dots_per_row for dot in range(num)],
                y=[rows - 1 - dot // dots_per_row for dot in range(num)],
                mode='markers',
                marker=dict(
                    size=20,
                    color=colors[i],
                    symbol='circle',
                    line=dict(width=2, color='black')
                ),
                name=f'{num} items',
                showlegend=True
            ))

            # Number label centered just below the group. Annotations have no
            # `showlegend` property (Plotly raises ValueError if it is passed),
            # and showarrow=False stops Plotly drawing a pointer arrow.
            fig.add_annotation(
                x=x_offset + (dots_per_row - 1) / 2,
                y=-1,
                text=str(num),
                showarrow=False,
                font=dict(size=24, color=colors[i])
            )

        fig.update_layout(
            title="Counting Visualization",
            showlegend=True,
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            height=400,
            template="simple_white"
        )

        return fig
    except (TypeError, ValueError):
        # Visuals are optional extras: unsupported input yields no picture
        # instead of an error. Not a bare except, so Streamlit's own
        # control-flow exceptions still propagate.
        return None

def create_addition_visual(num1, num2):
    """Create visual addition using manipulatives.

    Draws ``num1`` red circles, a "+" in a one-slot gap, then ``num2`` blue
    circles, followed by "=" and the sum, all on a single row.

    Args:
        num1: First addend; must be between 0 and 10.
        num2: Second addend; must be between 0 and 10.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` if either addend is out
        of range or the figure could not be built.
    """
    try:
        # Keep it simple for young learners; negative counts cannot be shown
        # as a number of objects.
        if not (0 <= num1 <= 10 and 0 <= num2 <= 10):
            return None

        total = num1 + num2
        fig = go.Figure()
        circle_border = dict(width=2, color='black')

        # First addend: red circles at x = 0 .. num1 - 1.
        fig.add_trace(go.Scatter(
            x=list(range(num1)), y=[1] * num1,
            mode='markers',
            marker=dict(size=25, color='red', symbol='circle', line=circle_border),
            name=f'First group: {num1}',
            showlegend=True
        ))

        # Second addend: blue circles after the empty slot at x = num1.
        fig.add_trace(go.Scatter(
            x=list(range(num1 + 1, total + 1)), y=[1] * num2,
            mode='markers',
            marker=dict(size=25, color='blue', symbol='circle', line=circle_border),
            name=f'Second group: {num2}',
            showlegend=True
        ))

        # Operator and result labels on the same row as the circles.
        # Annotations do not support `showlegend` (Plotly raises ValueError
        # for it), and showarrow=False keeps them as plain text.
        for x, text, color in (
            (num1, "+", 'black'),           # centered in the gap between groups
            (total + 1, "=", 'black'),
            (total + 2, str(total), 'green'),
        ):
            fig.add_annotation(
                x=x,
                y=1,
                text=text,
                showarrow=False,
                font=dict(size=30, color=color)
            )

        fig.update_layout(
            title=f"Addition: {num1} + {num2} = {total}",
            showlegend=True,
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False, range=[-1, total + 3]),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False, range=[0, 2.5]),
            height=300,
            template="simple_white"
        )

        return fig
    except (TypeError, ValueError):
        # Optional visual: non-numeric input yields no picture, not an error.
        return None

def create_fraction_visual(numerator, denominator):
    """Create a visual fraction as a circle cut into equal slices.

    The first ``numerator`` slices are shaded light blue and the rest light
    gray.

    Args:
        numerator: Number of shaded slices, from 0 to ``denominator``.
        denominator: Number of equal slices, from 1 to 12.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when the fraction is
        outside the supported range or the figure could not be built.
    """
    try:
        # Keep it simple. A zero/negative denominator would otherwise draw an
        # empty chart, and a negative numerator has no slice meaning.
        if not (1 <= denominator <= 12 and 0 <= numerator <= denominator):
            return None

        fig = go.Figure()

        # Slice boundaries, evenly spaced around the full circle.
        angles = np.linspace(0, 2 * np.pi, denominator + 1)

        for i, (start, end) in enumerate(zip(angles[:-1], angles[1:])):
            theta = np.linspace(start, end, 50)
            # Close each wedge through the center so fill='toself' fills it.
            x = np.concatenate([[0], np.cos(theta), [0]])
            y = np.concatenate([[0], np.sin(theta), [0]])
            shaded = i < numerator

            fig.add_trace(go.Scatter(
                x=x, y=y,
                fill='toself',
                mode='lines',
                line=dict(color='black', width=2),
                fillcolor='lightblue' if shaded else 'lightgray',
                name=f'Slice {i+1}' if shaded else '',
                showlegend=False
            ))

        fig.update_layout(
            title=f"Fraction: {numerator}/{denominator}",
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False, scaleanchor="y", scaleratio=1),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            height=400,
            template="simple_white"
        )

        return fig
    except (TypeError, ValueError):
        # Optional visual: unsupported input yields no picture, not an error.
        return None

def create_place_value_visual(number):
    """Create a place value visualization for a whole number from 0 to 9999.

    Each place (thousands, hundreds, tens, ones) gets its own column of
    colored square blocks, five per row, one block per unit of that digit.
    The place name and the digit are printed under every column, including
    zero digits.

    Args:
        number: Non-negative integer no larger than 9999.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when ``number`` is out
        of range or the figure could not be built.
    """
    try:
        # Keep it reasonable for elementary learners; negative numbers have
        # no digit-block representation here.
        if not 0 <= number <= 9999:
            return None

        labels = ['Thousands', 'Hundreds', 'Tens', 'Ones']
        colors = ['red', 'blue', 'green', 'orange']
        # Digit in each place, most significant first (325 -> 0, 3, 2, 5).
        digits = [(number // 10 ** power) % 10 for power in (3, 2, 1, 0)]

        fig = go.Figure()

        # Columns start 2 x-units apart; blocks are 0.3 apart within a column.
        for pos, val, label, color in zip((0, 2, 4, 6), digits, labels, colors):
            if val > 0:
                blocks_per_row = min(5, val)
                fig.add_trace(go.Scatter(
                    x=[pos + (block % blocks_per_row) * 0.3 for block in range(val)],
                    y=[(block // blocks_per_row) * 0.3 for block in range(val)],
                    mode='markers',
                    marker=dict(
                        size=15,
                        color=color,
                        symbol='square',
                        line=dict(width=1, color='black')
                    ),
                    name=f'{label}: {val}',
                    showlegend=True
                ))

            # Place name and digit under the column (pos + 0.6 is the middle
            # of a full five-block row). Annotations do not accept
            # `showlegend`, since Plotly raises ValueError for it, and
            # showarrow=False keeps them as plain text.
            fig.add_annotation(
                x=pos + 0.6,
                y=-0.5,
                text=label,
                showarrow=False,
                font=dict(size=12)
            )
            fig.add_annotation(
                x=pos + 0.6,
                y=-0.8,
                text=str(val),
                showarrow=False,
                font=dict(size=16, color=color)
            )

        fig.update_layout(
            title=f"Place Value: {number}",
            showlegend=True,
            # Upper bound leaves room for the last ones block at x = 7.2.
            xaxis=dict(showgrid=False, showticklabels=False, zeroline=False, range=[-0.5, 7.5]),
            yaxis=dict(showgrid=False, showticklabels=False, zeroline=False),
            height=400,
            template="simple_white"
        )

        return fig
    except (TypeError, ValueError):
        # Optional visual: unsupported input yields no picture, not an error.
        return None

def generate_k12_visual(user_prompt, ai_response):
    """Generate an age-appropriate visualization for a K-12 student's question.

    Only ``user_prompt`` is inspected; ``ai_response`` is accepted for
    interface compatibility but is currently unused. The checks run in order:
    counting, addition, fractions, place value, and finally a plain counting
    picture for a small number the user asked to see. The first check that
    matches decides the result.

    Returns:
        A Plotly figure from one of the ``create_*_visual`` helpers, or
        ``None`` when no visual applies. If an unexpected error occurs it is
        shown with ``st.error`` and ``None`` is returned.
    """
    try:
        user_lower = user_prompt.lower()

        # COUNTING NUMBERS: "count to 7" -> 7; "counting" with no number -> 5.
        count_match = re.search(r'count.*?(\d+)', user_lower)
        if count_match or 'counting' in user_lower:
            number = int(count_match.group(1)) if count_match else 5
            return create_counting_visual(number)

        # SIMPLE ADDITION: "3 + 4", "3 plus 4", or "add 3 and 4".
        # Proceed only when both addends were actually found; a '+' without a
        # digit pair must never reach .group() on a failed match.
        add_match = re.search(r'(\d+)\s*(?:\+|plus)\s*(\d+)', user_lower)
        if add_match is None and 'add' in user_lower:
            add_match = re.search(r'(\d+)\s*and\s*(\d+)', user_lower)
        if add_match:
            num1, num2 = int(add_match.group(1)), int(add_match.group(2))
            if num1 <= 10 and num2 <= 10:  # Keep it simple
                return create_addition_visual(num1, num2)

        # FRACTIONS: "3/4"; the word "fraction" alone defaults to 1/2.
        fraction_match = re.search(r'(\d+)/(\d+)', user_prompt)
        if fraction_match or 'fraction' in user_lower:
            if fraction_match:
                num, den = int(fraction_match.group(1)), int(fraction_match.group(2))
            else:
                num, den = 1, 2  # Default to 1/2
            return create_fraction_visual(num, den)

        # PLACE VALUE ("place" also covers "place value"): first 1-4 digit number.
        if 'place' in user_lower:
            place_match = re.search(r'\b(\d{1,4})\b', user_prompt)
            if place_match:
                return create_place_value_visual(int(place_match.group(1)))

        # NUMBERS (general counting): e.g. "show me 8" -> dots for 8.
        number_match = re.search(r'\b(\d+)\b', user_prompt)
        if number_match and any(word in user_lower for word in ['show', 'count', 'number']):
            number = int(number_match.group(1))
            if 1 <= number <= 20:
                return create_counting_visual(number)

        return None

    except Exception as e:
        st.error(f"Could not generate K-12 visual: {e}")
        return None

# --- API KEY & MODEL CONFIGURATION ---
# Copy any variables from a local .env file into the process environment so
# the os.getenv() fallback below can find GOOGLE_API_KEY.
load_dotenv()
api_key = None

# Prefer Streamlit secrets; fall back to the environment variable when the
# secret is missing (KeyError) or the secrets file cannot be found
# (FileNotFoundError). NOTE(review): some Streamlit versions may raise a
# different error type when no secrets file exists -- confirm.
try:
    api_key = st.secrets["GOOGLE_API_KEY"]
except (KeyError, FileNotFoundError):
    api_key = os.getenv("GOOGLE_API_KEY")

if api_key:
    genai.configure(api_key=api_key)
    
    # Main text model
    # The system_instruction literal is handed to the SDK exactly as written,
    # including its leading indentation and blank lines. It defines the tutor
    # persona, the Professor B teaching approach, and the fixed refusal reply
    # for non-math questions.
    model = genai.GenerativeModel(
        model_name="gemini-2.5-flash-lite",
        system_instruction="""
        You are "Math Jegna", an AI math tutor specializing in K-12 mathematics using the Professor B methodology.
        
        FOCUS ON ELEMENTARY CONCEPTS:
        - Basic counting (1-100)
        - Simple addition and subtraction (single digits to start)
        - Beginning multiplication (times tables)
        - Basic fractions (halves, thirds, quarters)
        - Place value (ones, tens, hundreds)
        - Shape recognition
        - Simple word problems
        - Money and time concepts
        
        PROFESSOR B METHODOLOGY - ESSENTIAL PRINCIPLES:
        1. Present math as a STORY that connects ideas
        2. Use MENTAL GYMNASTICS - fun games and finger counting
        3. Build from concrete to abstract naturally
        4. NO ROTE MEMORIZATION - focus on understanding patterns and connections
        5. Eliminate math anxiety through simple, truthful explanations
        6. Use manipulatives and visual aids
        7. Allow accelerated learning when student shows mastery
        
        TEACHING STYLE:
        - Start with what the child already knows
        - Build new concepts as natural extensions of previous learning
        - Use simple, clear language appropriate for the age
        - Make math enjoyable and reduce tension
        - Connect everything to real-world experiences
        - Celebrate understanding, not just correct answers
        
        VISUAL AIDS: Mention when visual aids will help, using phrases like:
        - "Let me show you this with counting dots..."
        - "I'll create a picture to help you see this..."
        - "A visual will make this clearer..."
        
        Remember: You're helping young minds discover the beauty and logic of mathematics through stories and connections, not through drilling and memorization.
        
        You are strictly forbidden from answering non-mathematical questions. If asked non-math questions, respond only with: "I can only answer mathematical questions. Please ask me a question about counting, adding, shapes, or another math topic."
        """
    )
else:
    # No key from either source: show an error and stop running the script.
    st.error("🚨 Google API Key not found! Please add it to your secrets or a local .env file.")
    st.stop()

# --- SESSION STATE & LOCAL STORAGE INITIALIZATION ---
def _default_chats():
    """Return a fresh chats dict holding only the welcome conversation."""
    return {
        "New Chat": [
            {"role": "assistant", "content": "Hello! I'm Math Jegna, your friendly math helper! 🧠✨\n\nI love helping kids learn math through fun stories and pictures. Try asking me about:\n- Counting numbers\n- Adding or subtracting\n- Fractions like 1/2\n- Shapes and patterns\n- Or any math question!\n\nWhat would you like to learn about today?"}
        ]
    }


def _load_shared_chat():
    """Decode a conversation passed in the ``shared_chat`` query parameter.

    The parameter is expected to hold URL-safe base64 of a JSON list of
    message dicts, each with "role" and "content" keys.

    Returns:
        The decoded message list, or ``None`` when the parameter is absent,
        cannot be decoded, or does not have the expected shape.
    """
    shared_chat_b64 = st.query_params.get("shared_chat")
    if not shared_chat_b64:
        return None
    try:
        messages = json.loads(base64.urlsafe_b64decode(shared_chat_b64).decode())
    except (TypeError, ValueError):
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses: a malformed link falls back to saved chats.
        return None
    # The query string is untrusted input, so check the shape the chat
    # renderer relies on (message["role"], message["content"]).
    if not isinstance(messages, list):
        return None
    if not all(isinstance(m, dict) and "role" in m and "content" in m for m in messages):
        return None
    return messages


def _load_saved_chats():
    """Restore ``(chats, active_chat_key)`` from browser local storage.

    Returns ``None`` when nothing usable is stored: no data, unparsable
    JSON, or an empty or invalid ``chats`` mapping. A stored active key that
    no longer names a chat is replaced with the first chat's key.
    """
    saved_data_json = localS.getItem("math_mentor_chats")
    if not saved_data_json:
        return None
    try:
        saved_data = json.loads(saved_data_json)
    except (TypeError, ValueError):
        return None
    if not isinstance(saved_data, dict):
        return None
    chats = saved_data.get("chats")
    if not isinstance(chats, dict) or not chats:
        return None
    active_chat_key = saved_data.get("active_chat_key", "New Chat")
    if active_chat_key not in chats:
        active_chat_key = next(iter(chats))
    return chats, active_chat_key


if "chats" not in st.session_state:
    shared_messages = _load_shared_chat()
    if shared_messages is not None:
        st.session_state.chats = {"Shared Chat": shared_messages}
        st.session_state.active_chat_key = "Shared Chat"
        st.query_params.clear()
    else:
        restored = _load_saved_chats()
        if restored is not None:
            st.session_state.chats, st.session_state.active_chat_key = restored
        else:
            st.session_state.chats = _default_chats()
            st.session_state.active_chat_key = "New Chat"

# --- RENAME DIALOG ---
@st.dialog("Rename Chat")
def rename_chat(chat_key):
    """Modal dialog for renaming the chat stored under ``chat_key``.

    Surrounding whitespace in the new name is ignored. Empty names, and names
    already used by a different chat, are rejected with an error message. On
    a successful save, the renamed chat keeps its position in the sidebar,
    becomes the active chat, and the app reruns.
    """
    st.write(f"Enter a new name for '{chat_key}':")
    new_name = st.text_input("New Name", key=f"rename_input_{chat_key}").strip()
    if st.button("Save", key=f"save_rename_{chat_key}"):
        chats = st.session_state.chats
        if not new_name:
            st.error("Name cannot be empty.")
        elif new_name != chat_key and new_name in chats:
            st.error("A chat with this name already exists.")
        else:
            # Rebuild the dict instead of pop + insert so the chat does not
            # jump to the bottom of the sidebar. Saving the unchanged name is
            # a harmless no-op rather than an "already exists" error.
            st.session_state.chats = {
                (new_name if key == chat_key else key): messages
                for key, messages in chats.items()
            }
            st.session_state.active_chat_key = new_name
            st.rerun()

# --- DELETE CONFIRMATION DIALOG ---
@st.dialog("Delete Chat")
def delete_chat(chat_key):
    """Confirmation dialog that permanently removes ``chat_key`` from the session.

    Nothing happens until the user confirms. If the removed chat was the one
    on screen, the first remaining chat becomes active, then the app reruns.
    """
    st.warning(f"Are you sure you want to delete '{chat_key}'? This cannot be undone.")
    confirmed = st.button("Yes, Delete", type="primary", key=f"confirm_delete_{chat_key}")
    if not confirmed:
        return

    chats = st.session_state.chats
    del chats[chat_key]
    if chat_key == st.session_state.active_chat_key:
        st.session_state.active_chat_key = next(iter(chats))
    st.rerun()

# --- SIDEBAR CHAT MANAGEMENT ---
st.sidebar.title("📚 My Math Chats")
st.sidebar.divider()

if st.sidebar.button("➕ New Chat", use_container_width=True):
    # Name the chat "New Chat N" using the smallest N not already taken.
    i = 1
    while f"New Chat {i}" in st.session_state.chats:
        i += 1
    new_chat_key = f"New Chat {i}"
    st.session_state.chats[new_chat_key] = [
        {"role": "assistant", "content": "Hi there! Ready for some fun math learning? Ask me about counting, adding, shapes, or anything else! 🚀🔢"}
    ]
    st.session_state.active_chat_key = new_chat_key
    st.rerun()

st.sidebar.divider()

# One expander per chat with select / rename / share / delete controls.
for chat_key in list(st.session_state.chats.keys()):
    is_active = (chat_key == st.session_state.active_chat_key)
    expander_label = f"**{chat_key} (Active)**" if is_active else chat_key
    messages = st.session_state.chats[chat_key]

    with st.sidebar.expander(expander_label):
        if st.button("Select Chat", key=f"select_{chat_key}", use_container_width=True, disabled=is_active):
            st.session_state.active_chat_key = chat_key
            st.rerun()

        if st.button("Rename", key=f"rename_{chat_key}", use_container_width=True):
            rename_chat(chat_key)

        with st.popover("Share", use_container_width=True):
            st.markdown("**Download Conversation**")
            # Replace characters that are awkward or unsafe in file names
            # (e.g. '/'); for plain names this just turns spaces into '_'.
            safe_name = re.sub(r"[^\w\-]+", "_", chat_key).strip("_") or "chat"
            st.download_button(
                label="Download as Markdown",
                data=format_chat_for_download(messages),
                file_name=f"{safe_name}.md",
                mime="text/markdown"
            )
            st.markdown("**Share via Link**")
            # The page URL alone does not carry the conversation. The app
            # loads a shared chat from the `shared_chat` query parameter
            # (URL-safe base64 of the JSON message list), so build that here.
            share_token = base64.urlsafe_b64encode(json.dumps(messages).encode()).decode()
            st.info("To share, add the text below to the end of this app's web address and send the full link to someone.")
            st.code(f"?shared_chat={share_token}", language=None)

        if st.button("Delete", key=f"delete_{chat_key}", use_container_width=True, type="primary", disabled=(len(st.session_state.chats) <= 1)):
            delete_chat(chat_key)

# --- MAIN CHAT INTERFACE ---
# Recover from a missing or stale active chat (e.g. one restored from
# storage) instead of failing with KeyError on the lookup below.
if not st.session_state.chats:
    st.session_state.chats["New Chat"] = []
if st.session_state.active_chat_key not in st.session_state.chats:
    st.session_state.active_chat_key = next(iter(st.session_state.chats))
active_chat = st.session_state.chats[st.session_state.active_chat_key]

st.title(f"Math Helper: {st.session_state.active_chat_key} 🧠")
st.write("🎯 Perfect for young learners! Ask about counting, adding, shapes, fractions, and more!")

# Add some example prompts for young learners
with st.expander("💡 Try asking me about..."):
    st.write("""
    - **Counting**: "Show me how to count to 10"
    - **Addition**: "What is 3 + 4?"
    - **Fractions**: "What is 1/2?"
    - **Place Value**: "What is the place value of 325?"
    - **Shapes**: "Tell me about triangles"
    - **Time**: "How do I read a clock?"
    """)

# Replay the stored conversation. The user avatar is the technologist emoji
# (person + ZWJ + laptop), written with escapes so the joiner is explicit.
for message in active_chat:
    avatar = "\U0001F9D1\u200d\U0001F4BB" if message["role"] == "user" else "🧠"
    with st.chat_message(name=message["role"], avatar=avatar):
        st.markdown(message["content"])

if user_prompt := st.chat_input("Ask me a math question!"):
    active_chat.append({"role": "user", "content": user_prompt})
    # Technologist emoji (person + ZWJ + laptop), escaped so the joiner is explicit.
    with st.chat_message("user", avatar="\U0001F9D1\u200d\U0001F4BB"):
        st.markdown(user_prompt)

    with st.chat_message("assistant", avatar="🧠"):
        with st.spinner("Math Jegna is thinking... 🤔"):
            try:
                # Replay everything before the prompt just appended as Gemini
                # history, mapping Streamlit's "assistant" role to "model".
                chat_session = model.start_chat(history=[
                    {'role': convert_role_for_gemini(msg['role']), 'parts': [msg['content']]}
                    for msg in active_chat[:-1] if 'content' in msg
                ])
                response = chat_session.send_message(user_prompt)
                ai_response_text = response.text
                st.markdown(ai_response_text)

                # Store the text response
                active_chat.append({"role": "assistant", "content": ai_response_text})

                # Check if we should generate a visual aid
                if should_generate_visual(user_prompt, ai_response_text):
                    with st.spinner("Creating a helpful picture... 🎨"):
                        k12_fig = generate_k12_visual(user_prompt, ai_response_text)
                        if k12_fig:
                            st.plotly_chart(k12_fig, use_container_width=True)
                            st.success("✨ Here's a picture to help you understand!")

            except Exception as e:
                # Top-level UI boundary: show the failure in the chat rather
                # than crashing the app, and keep it in the transcript.
                error_message = f"Oops! Something went wrong. Let me try again! 🤖\n\n**Error:** {e}"
                st.error(error_message)
                active_chat.append({"role": "assistant", "content": error_message})

# --- SAVE DATA TO LOCAL STORAGE ---
# Persist every chat plus the currently selected one on each script run, so
# the initialization step above can restore them after a reload.
data_to_save = dict(
    chats=st.session_state.chats,
    active_chat_key=st.session_state.active_chat_key,
)
serialized = json.dumps(data_to_save)
localS.setItem("math_mentor_chats", serialized)