Spaces:
Sleeping
Sleeping
| import base64 | |
| from typing import List, TypedDict, Annotated, Optional | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage | |
| from langgraph.graph.message import add_messages | |
| from langgraph.graph import START, StateGraph, MessagesState, END | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from dotenv import load_dotenv | |
| from prompts import ORCHESTRATOR_SYSTEM_PROMPT, RETRIEVER_SYSTEM_PROMPT, RESEARCH_SYSTEM_PROMPT, MATH_SYSTEM_PROMPT | |
| from tools import DATABASE_TOOLS, FILE_TOOLS, RESEARCH_TOOLS, MATH_TOOLS, ALL_TOOLS | |
| import gradio as gr | |
| import os | |
| import requests | |
| import pandas as pd | |
| import json | |
| import time | |
| import sys | |
| import traceback | |
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()
# Silence the HuggingFace tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TODO: check if any tools is missing on tools folder (arxiv, youtube, wikipedia, etc.)
# ─────────────────────────────────────────────────────────────────────────────
# AGENT & GRAPH SETUP
# ─────────────────────────────────────────────────────────────────────────────
# Initialize the LLM; temperature=0 for deterministic, benchmark-friendly answers.
llm = ChatOpenAI(model="gpt-4o", temperature=0)
# ─────────────────────────────────────────────────────────────────────────────
# SIMPLE AGENT SETUP (following course pattern)
# ─────────────────────────────────────────────────────────────────────────────
# Build simple agent graph - no complex routing needed.
builder = StateGraph(MessagesState)
| # Single agent node that handles everything | |
def gaia_agent(state: MessagesState):
    """
    Single agent node that handles all GAIA questions with access to all tools.

    Binds every available tool to the LLM and lets it decide which (if any)
    to call; the graph's conditional edge then routes to the ToolNode or END.

    Args:
        state: LangGraph MessagesState; only the "messages" list is read.

    Returns:
        dict: {"messages": [response]} — the LLM reply appended via add_messages.
    """
    messages = state["messages"]
    # Create agent with all tools available (LLM chooses freely among them).
    agent_llm = llm.bind_tools(ALL_TOOLS)
    # System prompt optimized for GAIA's exact-match scoring: bare answers only.
    system_message = SystemMessage(content="""
You are a precise QA agent specialized in answering GAIA benchmark questions.
CRITICAL RESPONSE RULES:
- Answer with ONLY the exact answer, no explanations or conversational text
- NO XML tags, NO "FINAL ANSWER:", NO introductory phrases
- For lists: comma-separated, alphabetized if requested, no trailing punctuation
- For numbers: use exact format requested (USD as 12.34, codes bare, etc.)
- For yes/no: respond only "Yes" or "No"
AVAILABLE TOOLS:
- Database search tools: Use to find similar questions in the knowledge base
- File processing tools: Use for Excel, CSV, audio, video, image analysis
- Research tools: Use for web search and current information
- Math tools: Use for calculations and numerical analysis
WORKFLOW:
1. First try database search tools to find similar questions
2. If database returns "NO_EXACT_MATCH", continue with other appropriate tools
3. Use research tools for web search if needed
4. Use math tools for calculations if needed
5. Always provide the exact final answer, never return internal tool messages
IMPORTANT: Never return tool result messages like "NO_EXACT_MATCH" as your final answer.
Always process the question and provide the actual answer.
Your goal is to provide exact answers that match GAIA ground truth precisely.
""".strip())
    # Prepend the system prompt to the running conversation for this turn.
    messages_with_system = [system_message] + messages
    # Invoke the tool-bound model; response may contain tool_calls.
    response = agent_llm.invoke(messages_with_system)
    return {"messages": [response]}
| # Simple routing: tools or end | |
def should_continue(state: MessagesState):
    """Route after the agent node: run tools if the last message requests any,
    otherwise finish the graph."""
    latest = state["messages"][-1]
    # Messages without a tool_calls attribute (or with an empty list) end the run.
    pending_calls = getattr(latest, 'tool_calls', None)
    return "tools" if pending_calls else END
# Nodes: a single agent plus a ToolNode that executes any requested tool calls.
builder.add_node("agent", gaia_agent)
builder.add_node("tools", ToolNode(ALL_TOOLS))
# Edges: START -> agent; agent -> tools or END (decided by should_continue).
builder.add_edge(START, "agent")
builder.add_conditional_edges("agent", should_continue)
builder.add_edge("tools", "agent")  # Return to agent after using tools
# Compile the builder into an invokable graph.
graph = builder.compile()
# ─────────────────────────────────────────────────────────────────────────────
# GAIA API INTERACTION FUNCTIONS
# ─────────────────────────────────────────────────────────────────────────────
def get_gaia_questions():
    """Fetch the full question list from the GAIA scoring API.

    Returns:
        list: Question dicts from the API, or an empty list on any error.
    """
    try:
        # A timeout is required: requests.get with no timeout can block forever
        # if the remote Space stalls or the connection hangs.
        response = requests.get(
            "https://agents-course-unit4-scoring.hf.space/questions",
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        print(f"Error fetching GAIA questions: {e}")
        return []
def get_random_gaia_question():
    """Fetch a single random question from the GAIA scoring API.

    Returns:
        dict | None: The question payload, or None on any error.
    """
    try:
        # A timeout is required: requests.get with no timeout can block forever
        # if the remote Space stalls or the connection hangs.
        response = requests.get(
            "https://agents-course-unit4-scoring.hf.space/random-question",
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        print(f"Error fetching random GAIA question: {e}")
        return None
def answer_gaia_question(question_text: str, debug: bool = False) -> str:
    """Answer a single GAIA question by running the agent graph to completion.

    Args:
        question_text: The raw GAIA question text.
        debug: When True, print the full message trace for troubleshooting.

    Returns:
        str: The agent's final answer, "No answer generated" if the graph
        produced no messages, or an "Error: ..." string on failure.
    """
    try:
        # Seed the graph with a single human message.
        initial_state = {
            "messages": [HumanMessage(content=question_text)]
        }
        if debug:
            print(f"π Processing question: {question_text}")
        # Run the agent <-> tools loop until the conditional edge returns END.
        result = graph.invoke(initial_state)
        if debug:
            print(f"π Total messages in conversation: {len(result.get('messages', []))}")
            for i, msg in enumerate(result.get('messages', [])):
                print(f" Message {i+1}: {type(msg).__name__} - {str(msg.content)[:100]}...")
        if result and "messages" in result and result["messages"]:
            # The last message holds the final LLM answer.
            final_answer = result["messages"][-1].content.strip()
            if debug:
                print(f"π― Final answer: {final_answer}")
            return final_answer
        else:
            return "No answer generated"
    except Exception as e:
        if debug:
            print(f"β Error details: {e}")
            # traceback is already imported at module level; the previous
            # redundant local `import traceback` has been removed.
            traceback.print_exc()
        print(f"Error answering question: {e}")
        return f"Error: {str(e)}"
# ─────────────────────────────────────────────────────────────────────────────
# TESTING AND VALIDATION
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke-test harness: print the graph, run local sample questions, then
    # try one live question from the GAIA API.
    print("π Enhanced GAIA Agent Graph Structure:")
    try:
        print(graph.get_graph().draw_mermaid())
    except Exception:
        # Diagram rendering is purely informational — never abort on it.
        # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
        print("Could not generate mermaid diagram")
    print("\nπ§ͺ Testing with GAIA-style questions...")
    # Test questions that cover different GAIA capabilities.
    test_questions = [
        "What is 2 + 2?",
        "What is the capital of France?",
        "List the vegetables from this list: broccoli, apple, carrot. Alphabetize and use comma separation.",
        "Given the Excel file at test_sales.xlsx, what were total sales for food? Express in USD with two decimals.",
        "Examine the audio file at ./test.wav. What is its transcript?",
    ]
    # Add an audio-transcription test only if the fixture file exists locally.
    if os.path.exists("test.wav"):
        test_questions.append("What does the speaker say in the audio file test.wav?")
    for i, question in enumerate(test_questions, 1):
        print(f"\nπ Test {i}: {question}")
        try:
            answer = answer_gaia_question(question)
            print(f"β Answer: {answer!r}")
        except Exception as e:
            print(f"β Error: {e}")
        print("-" * 80)
    # Test with a real GAIA question if the API is available.
    print("\nπ Testing with real GAIA question...")
    try:
        random_q = get_random_gaia_question()
        if random_q:
            print(f"π GAIA Question: {random_q.get('question', 'N/A')}")
            answer = answer_gaia_question(random_q.get('question', ''))
            print(f"π― Agent Answer: {answer!r}")
            print(f"π‘ Task ID: {random_q.get('task_id', 'N/A')}")
    except Exception as e:
        print(f"Could not test with real GAIA question: {e}")