|
|
import gradio as gr |
|
|
from huggingface_hub import InferenceClient |
|
|
from datasets import load_dataset |
|
|
import random |
|
|
import re |
|
|
|
|
|
|
|
|
# Cache of loaded sample problems; stays None until load_sample_problems() runs.
math_samples = None

# Words that mark a FineWeb-Edu document as math-related.
_FINEWEB_KEYWORDS = ('math', 'calculate', 'solve', 'derivative', 'integral', 'triangle', 'equation')

# Built-in problems used when no dataset yields anything.
# NOTE(review): the log subscript was unreadable in the source; base 2 is
# assumed because it gives the clean solution x = 2.
_FALLBACK_PROBLEMS = (
    "What is the derivative of f(x) = 3x² + 2x - 1?",
    "A triangle has sides of length 5, 12, and 13. What is its area?",
    "If log₂(x) + log₂(x+6) = 4, find the value of x.",
    "Find the limit: lim(x->0) (sin(x)/x)",
    "Solve the system: x + 2y = 7, 3x - y = 4",
    "Calculate the integral of sin(x) from 0 to pi.",
    "What is the probability of rolling a 6 on a die 3 times in a row?",
)


def _take(items, limit):
    """Yield at most *limit* items without pulling extra records from *items*."""
    if limit <= 0:
        return
    for count, item in enumerate(items, start=1):
        yield item
        if count >= limit:
            return


def _iter_gsm8k_questions():
    """Yield the question text of each GSM8K training record."""
    rows = load_dataset("openai/gsm8k", "main", split="train", streaming=True)
    for row in rows:
        yield row["question"]


def _iter_fineweb_snippets():
    """Yield 200-character snippets of FineWeb-Edu documents that mention math."""
    rows = load_dataset(
        "HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True
    )
    for row in rows:
        text = row['text']
        lowered = text.lower()  # lower-case once, not once per keyword
        if any(word in lowered for word in _FINEWEB_KEYWORDS):
            yield text[:200] + " (Solve this math problem.)"


def _iter_ultrachat_prompts():
    """Yield the first message of UltraChat conversations that mention math."""
    # A split must be named here: without it load_dataset() returns a mapping
    # of splits, and iterating that mapping yields split names, not records.
    rows = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft", streaming=True)
    for row in rows:
        prompt = row['messages'][0]['content']
        lowered = prompt.lower()
        if 'math' in lowered or 'calculate' in lowered:
            yield prompt


def load_sample_problems():
    """Load sample problems from all datasets and cache them.

    Streams up to 50 GSM8K questions, 20 FineWeb-Edu snippets and 20 UltraChat
    prompts.  Each source is read independently, so a failure in one source
    keeps whatever the others produced.  The result is stored in the
    module-level ``math_samples`` and returned directly on later calls.

    Returns:
        list[str]: The collected problems, or the built-in fallback problems
        when no source produced anything.
    """
    global math_samples
    if math_samples is not None:
        return math_samples

    sources = (
        ("GSM8K", _iter_gsm8k_questions, 50),
        ("Fineweb-edu", _iter_fineweb_snippets, 20),
        ("Ultrachat", _iter_ultrachat_prompts, 20),
    )
    samples = []
    counts = {}
    for label, iter_source, limit in sources:
        counts[label] = 0
        try:
            for problem in _take(iter_source(), limit):
                samples.append(problem)
                counts[label] += 1
        except Exception as e:
            # Best-effort loading: report the failure and move on to the next source.
            print(f"⚠️ Dataset error ({label}): {e}")

    if samples:
        print(
            f"✅ Loaded {len(samples)} samples: GSM8K ({counts['GSM8K']}), "
            f"Fineweb-edu ({counts['Fineweb-edu']}), Ultrachat ({counts['Ultrachat']})"
        )
    else:
        print("⚠️ No dataset samples available, using fallback")
        samples = list(_FALLBACK_PROBLEMS)

    math_samples = samples
    return math_samples
|
|
|
|
|
def create_math_system_message():
    """Return the system prompt that sets up the mathematics tutor persona.

    The prompt is a raw string so that the LaTeX examples in its guidelines
    reach the model with their backslashes intact instead of being read as
    Python escape sequences (backspace, form feed, tab).

    Returns:
        str: The complete system prompt.
    """
    # NOTE(review): the emoji before "Problem Domains" was unreadable in the
    # source; the book emoji is a stand-in.
    return r"""You are Mathetics AI, an advanced mathematics tutor and problem solver.

🧮 **Your Expertise:**
- Step-by-step problem solving with clear explanations
- Multiple solution approaches when applicable
- Proper mathematical notation and terminology using LaTeX
- Verification of answers through different methods

📚 **Problem Domains:**
- Arithmetic, Algebra, and Number Theory
- Geometry, Trigonometry, and Coordinate Geometry
- Calculus (Limits, Derivatives, Integrals)
- Statistics, Probability, and Data Analysis
- Competition Mathematics (AMC, AIME level)

💡 **Teaching Style:**
1. **Understand the Problem** - Identify what's being asked
2. **Plan the Solution** - Choose the appropriate method
3. **Execute Step-by-Step** - Show all work clearly with LaTeX formatting
4. **Verify the Answer** - Check if the result makes sense
5. **Alternative Methods** - Mention other possible approaches

**LaTeX Guidelines:**
- Use $...$ for inline math: $x^2 + y^2 = z^2$
- Use $$...$$ for display math
- Box final answers: \boxed{answer}
- Fractions: \frac{numerator}{denominator}
- Limits: \lim_{x \to 0}
- Derivatives: \frac{d}{dx} or f'(x)

Always be precise, educational, and encourage mathematical thinking."""
|
|
|
|
|
# \[ ... \] and \( ... \) delimiters; the lookbehind skips a "\\[" line break.
_DISPLAY_BRACKETS = re.compile(r'(?<!\\)\\\[(.+?)\\\]', re.DOTALL)
_INLINE_PARENS = re.compile(r'(?<!\\)\\\((.+?)\\\)', re.DOTALL)
# Text already inside $$...$$ or $...$; the capturing group makes re.split()
# keep those spans in its result.
_MATH_SPAN = re.compile(r'(\$\$.+?\$\$|(?<![\\$])\$[^$\n]+\$)', re.DOTALL)
_BOXED_OPEN = '\\boxed{'


def _wrap_bare_boxed(segment):
    """Wrap each brace-balanced ``\\boxed{...}`` in *segment* in ``$$...$$``."""
    pieces = []
    pos = 0
    while True:
        start = segment.find(_BOXED_OPEN, pos)
        if start == -1:
            break
        depth = 0
        end = None
        # Scan from the opening brace and stop where the braces balance.
        for index in range(start + len(_BOXED_OPEN) - 1, len(segment)):
            char = segment[index]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    end = index + 1
                    break
        if end is None:
            # Unbalanced braces (for example a partial streamed answer): leave as is.
            break
        pieces.append(segment[pos:start])
        pieces.append('$$' + segment[start:end] + '$$')
        pos = end
    pieces.append(segment[pos:])
    return ''.join(pieces)


def render_latex(text):
    """Normalize LaTeX delimiters so a ``$``-based renderer can display the math.

    Three rewrites are applied; everything else is left untouched:

    * ``\\[ ... \\]`` becomes ``$$ ... $$``
    * ``\\( ... \\)`` becomes ``$ ... $``
    * ``\\boxed{...}`` outside any ``$`` span is wrapped in ``$$ ... $$``

    Args:
        text: Text to normalize.  Falsy values and strings shorter than five
            characters are returned unchanged.

    Returns:
        The normalized text, or the input unchanged when formatting fails.
    """
    if not text or len(text) < 5:
        return text

    try:
        # Callables are used as replacements so that backslashes in the math
        # are never interpreted as re.sub template escapes.
        text = _DISPLAY_BRACKETS.sub(lambda m: '$$' + m.group(1).strip() + '$$', text)
        text = _INLINE_PARENS.sub(lambda m: '$' + m.group(1).strip() + '$', text)

        # Even-indexed pieces lie outside math spans; only those may be wrapped.
        pieces = _MATH_SPAN.split(text)
        pieces[::2] = [_wrap_bare_boxed(piece) for piece in pieces[::2]]
        text = ''.join(pieces)
    except (TypeError, re.error) as e:
        print(f"⚠️ LaTeX formatting error: {e}")

    return text
|
|
|
|
|
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream the model's answer to *message*, yielding the text produced so far.

    The first value yielded is a placeholder; each later value is the
    cumulative answer, and the last one is the complete answer passed through
    ``render_latex``.  If the request fails, a single error message is yielded
    instead of raising.

    Args:
        message: The user's new question.
        history: Earlier turns, either ``(user, assistant)`` pairs or dicts
            with ``"role"`` and ``"content"`` keys.  May be empty or ``None``.
        system_message: Optional system prompt; skipped when empty.
        max_tokens: Requested token budget (raised to at least 1536).
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling value forwarded to the model.

    Yields:
        str: The placeholder, then progressively longer versions of the answer.
    """
    # NOTE(review): the emoji in this placeholder and in the error message were
    # unreadable in the source; the ones used here are best guesses.
    yield "🤔 Thinking step-by-step..."

    # The request timeout is a client option; chat_completion() does not
    # accept a ``timeout`` keyword.
    client = InferenceClient(model="Qwen/Qwen2.5-Math-7B-Instruct", timeout=60)

    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})

    for turn in history or []:
        if isinstance(turn, dict):
            # Dict-style entry: forward it when both fields are present.
            if turn.get("role") and turn.get("content"):
                messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    response = ""
    # NOTE(review): this raises the budget to at least 1536 tokens regardless
    # of what the caller asked for; confirm that overriding is intended.
    max_tokens = max(max_tokens, 1536)

    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        for message_chunk in stream:
            choices = message_chunk.choices
            if not choices or not choices[0].delta.content:
                continue
            token = choices[0].delta.content
            response += token

            # Format at sentence or line boundaries and every 50th character;
            # other updates are passed through raw.
            at_boundary = token.strip() in ('.', '!', '?') or token.endswith('\n')
            if len(response) % 50 == 0 or at_boundary:
                yield render_latex(response)
            else:
                yield response

        yield render_latex(response.strip())

    except Exception as e:
        # UI boundary: show the failure in the chat instead of raising.
        yield f"❌ **Error**: {str(e)[:100]}...\n\n💡 Try a simpler problem or wait a moment."
|
|
|
|
|
def get_random_sample():
    """Return a random sample problem, loading the datasets on first use.

    Returns:
        str: A problem chosen at random from the loaded samples, or a fixed
        quadratic-equation problem when no samples are available.
    """
    # load_sample_problems() maintains the math_samples cache itself, so it is
    # cheap to call every time.
    samples = load_sample_problems()
    if samples:
        return random.choice(samples)
    return "Solve for x: 2x² + 5x - 3 = 0"
|
|
|
|
|
def insert_sample_to_chat(difficulty):
    """Return a random sample problem for the chat input box.

    Args:
        difficulty: Accepted for the caller's convenience; it is not used
            when choosing the problem.

    Returns:
        str: The problem text produced by ``get_random_sample``.
    """
    return get_random_sample()
|
|
|
|
|
def show_help():
    """Return the Markdown help text with tips for asking math questions."""
    return """**🧮 Math Help Tips:**

1. **Be Specific**: "Find the derivative of x² + 3x" instead of "help with calculus"
2. **Request Steps**: "Show me step-by-step how to solve..."
3. **Ask for Verification**: "Check if my answer x=5 is correct"
4. **Alternative Methods**: "What's another way to solve this integral?"
5. **Use Clear Notation**: "lim(x->0)" for limits

**Pro Tip**: Crank tokens to 1500+ for competition problems!"""
|
|
|
|
|
|
|
|
def chat_response(message, history, max_tokens=1024, temperature=0.3, top_p=0.85):
    """Stream the tutor's answer into a chat history of ``[user, bot]`` pairs.

    One new pair is added for *message* and its bot half is overwritten with
    each update from ``respond``, so the history grows by exactly one turn.

    Args:
        message: The user's new question.
        history: Existing chat history as ``[user, bot]`` pairs; may be empty
            or ``None``.  The caller's list is not modified.
        max_tokens: Token budget forwarded to ``respond``.
        temperature: Sampling temperature forwarded to ``respond``.
        top_p: Nucleus-sampling value forwarded to ``respond``.

    Yields:
        tuple: ``(history, "")`` - the updated history and an empty string
        for clearing the input box.
    """
    previous_turns = list(history or [])
    updated = previous_turns + [[message, ""]]

    # respond() gets its own list of earlier turns: it reads the history
    # lazily, so sharing ``updated`` would leak the pending turn into the prompt.
    for partial in respond(
        message,
        previous_turns,
        create_math_system_message(),
        max_tokens,
        temperature,
        top_p,
    ):
        updated[-1][1] = partial
        yield updated, ""
|
|
|
|
|
|
|
|
# Gradio interface: chat window, input row, settings row, examples and footer.
# NOTE(review): a few button/slider emoji were unreadable in the source; the
# ones on "Solve", "Max Tokens" and "Help" are best guesses.
with gr.Blocks(
    title="🧮 Mathetics AI",
    theme=gr.themes.Soft(),
    css="""
    /* Enhanced math rendering */
    .markdown-body { font-family: 'Times New Roman', Georgia, serif; line-height: 1.6; }
    .katex { font-size: 1.1em !important; color: #2c3e50; }
    .katex-display { font-size: 1.3em !important; text-align: center; margin: 1em 0; padding: 10px; background: #f8f9fa; border-radius: 8px; }

    /* Chat styling */
    .message { margin: 10px 0; padding: 12px; border-radius: 8px; }
    .user { background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%); border-left: 4px solid #2196f3; }
    .assistant { background: linear-gradient(135deg, #f5f5f5 0%, #eeeeee 100%); border-left: 4px solid #4caf50; }

    /* Sidebar */
    .difficulty-selector { background: linear-gradient(135deg, #fff3e0 0%, #ffe0b2 100%); padding: 15px; border-radius: 10px; margin: 10px 0; border: 1px solid #ffcc80; }

    /* Responsive */
    @media (max-width: 768px) { .katex { font-size: 1em !important; } .katex-display { font-size: 1.1em !important; } }
    """
) as demo:

    gr.Markdown("""
    # 🧮 **Mathetics AI** - Advanced Mathematics Solver

    **Your Personal AI Math Tutor** | Step-by-step solutions with beautiful LaTeX rendering

    ---
    """)

    # avatar_images is documented as image paths or URLs; the emoji strings
    # passed before are neither, so the default avatars are used instead.
    chatbot = gr.Chatbot(
        height=500,
        show_label=False,
        bubble_full_width=False,
        # Register "$...$" as well so the inline math requested by the system
        # prompt is rendered, not only "$$...$$" blocks.
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Ask: 'Find the derivative of 3x² + 2x - 1'",
            scale=4,
            show_label=False,
            lines=2
        )
        submit_btn = gr.Button("🚀 Solve", variant="primary", scale=1)

    with gr.Row():
        with gr.Column(scale=2):
            # NOTE(review): these sliders are not passed to any handler yet;
            # answers are generated with chat_response's own settings.
            token_slider = gr.Slider(256, 2048, value=1024, step=128, label="📝 Max Tokens")
            temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="🎯 Temperature")

        with gr.Column(scale=1):
            difficulty_preset = gr.Dropdown(
                choices=["Elementary", "High School", "College", "Competition"],
                value="High School",
                label="🎯 Difficulty",
                elem_classes=["difficulty-selector"]
            )
            sample_btn = gr.Button("🎲 Random Problem", variant="secondary")
            help_btn = gr.Button("❓ Help", variant="secondary")

    gr.Examples(
        examples=[
            ["Find the derivative of f(x) = 3x² + 2x - 1"],
            ["A triangle has sides 5, 12, and 13. What is its area?"],
            ["Solve: lim(x->0) sin(x)/x"],
            ["What is ∫(2x³ - 5x + 3) dx?"],
            ["Solve the system: x + 2y = 7, 3x - y = 4"]
        ],
        inputs=msg,
        label="💡 Quick Examples"
    )

    def submit_message(message, history):
        """Stream ``chat_response`` updates to the chat window.

        This must itself be a generator function: Gradio decides whether to
        stream by inspecting the registered callback, not its return value.
        """
        yield from chat_response(message, history)

    def clear_chat():
        """Return an empty history and an empty input box."""
        return [], ""

    # chat_response yields (history, ""), so the chatbot is listed first here,
    # the same order clear_chat uses.
    msg.submit(submit_message, [msg, chatbot], [chatbot, msg])
    submit_btn.click(submit_message, [msg, chatbot], [chatbot, msg])

    sample_btn.click(
        insert_sample_to_chat,
        inputs=[difficulty_preset],
        outputs=msg
    )

    # Named output area for the help text, placed below the examples.
    help_output = gr.Markdown(visible=True, label="Help")
    help_btn.click(
        show_help,
        outputs=help_output
    )

    clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg]
    )

    gr.Markdown("""
    ---
    **🔧 Tech:** Qwen2.5-Math-7B • LaTeX rendering • Streaming responses
    **💡 Tip:** Use "lim(x->0)" for limits, crank tokens for complex problems
    """)


if __name__ == "__main__":
    demo.launch()