# agents-course-v2 / agent.py
# NOTE: header below was scraped from the hosting page (commit b36ff59,
# "clean up project repo") and converted to comments so the file parses.
import base64
from typing import List, TypedDict, Annotated, Optional
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langgraph.graph.message import add_messages
from langgraph.graph import START, StateGraph, MessagesState, END
from langgraph.prebuilt import ToolNode, tools_condition
from dotenv import load_dotenv
from prompts import ORCHESTRATOR_SYSTEM_PROMPT, RETRIEVER_SYSTEM_PROMPT, RESEARCH_SYSTEM_PROMPT, MATH_SYSTEM_PROMPT
from tools import DATABASE_TOOLS, FILE_TOOLS, RESEARCH_TOOLS, MATH_TOOLS, ALL_TOOLS
import gradio as gr
import os
import requests
import pandas as pd
import json
import time
import sys
import traceback
# Load environment variables (e.g. OPENAI_API_KEY) from the local .env file.
load_dotenv()
# Silence the HuggingFace tokenizers fork/parallelism warning emitted by tool deps.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TODO: check whether any tools are missing from the tools folder (arxiv, youtube, wikipedia, etc.)
# ─────────────────────────────────────────────────────────────────────────────
# AGENT & GRAPH SETUP
# ─────────────────────────────────────────────────────────────────────────────
# Deterministic model (temperature=0) shared by every agent invocation.
llm = ChatOpenAI(model="gpt-4o", temperature=0)
# ─────────────────────────────────────────────────────────────────────────────
# SIMPLE AGENT SETUP (following course pattern)
# ─────────────────────────────────────────────────────────────────────────────
# Single-agent graph over the prebuilt MessagesState — no custom routing state.
builder = StateGraph(MessagesState)
# Single agent node that handles everything
def gaia_agent(state: MessagesState):
    """Single agent node: one tool-enabled LLM call over the conversation.

    Prepends the GAIA-specific system prompt to the messages accumulated in
    *state*, invokes the model with every tool bound, and returns the reply
    so LangGraph appends it to the message history.
    """
    # System prompt tuned for GAIA's exact-answer grading.
    instructions = SystemMessage(content="""
You are a precise QA agent specialized in answering GAIA benchmark questions.
CRITICAL RESPONSE RULES:
- Answer with ONLY the exact answer, no explanations or conversational text
- NO XML tags, NO "FINAL ANSWER:", NO introductory phrases
- For lists: comma-separated, alphabetized if requested, no trailing punctuation
- For numbers: use exact format requested (USD as 12.34, codes bare, etc.)
- For yes/no: respond only "Yes" or "No"
AVAILABLE TOOLS:
- Database search tools: Use to find similar questions in the knowledge base
- File processing tools: Use for Excel, CSV, audio, video, image analysis
- Research tools: Use for web search and current information
- Math tools: Use for calculations and numerical analysis
WORKFLOW:
1. First try database search tools to find similar questions
2. If database returns "NO_EXACT_MATCH", continue with other appropriate tools
3. Use research tools for web search if needed
4. Use math tools for calculations if needed
5. Always provide the exact final answer, never return internal tool messages
IMPORTANT: Never return tool result messages like "NO_EXACT_MATCH" as your final answer.
Always process the question and provide the actual answer.
Your goal is to provide exact answers that match GAIA ground truth precisely.
""".strip())
    # Let the model decide which of the tools to call.
    tool_llm = llm.bind_tools(ALL_TOOLS)
    conversation = [instructions, *state["messages"]]
    reply = tool_llm.invoke(conversation)
    return {"messages": [reply]}
# Simple routing: tools or end
def should_continue(state: MessagesState):
"""Simple routing: use tools if requested, otherwise end."""
last_message = state["messages"][-1]
# If agent wants to use tools, go to tools
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
return "tools"
# Otherwise, we're done
return END
# Add nodes
# Wire the two-node loop: agent → (tools → agent)* → END.
builder.add_node("agent", gaia_agent)
builder.add_node("tools", ToolNode(ALL_TOOLS))
# Every run starts at the agent node.
builder.add_edge(START, "agent")
# should_continue sends the run to "tools" when the agent requested them, else END.
builder.add_conditional_edges("agent", should_continue)
builder.add_edge("tools", "agent")  # Return to agent after using tools
# Compiled runnable used by answer_gaia_question below.
graph = builder.compile()
# ─────────────────────────────────────────────────────────────────────────────
# GAIA API INTERACTION FUNCTIONS
# ─────────────────────────────────────────────────────────────────────────────
def get_gaia_questions():
    """Fetch the full question list from the GAIA scoring API.

    Returns:
        The parsed JSON payload (a list of question dicts), or an empty
        list when the request fails for any reason.
    """
    try:
        # Bounded timeout so a hung API call cannot stall the whole run.
        response = requests.get(
            "https://agents-course-unit4-scoring.hf.space/questions",
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best-effort: callers treat an empty list as "no questions available".
        print(f"Error fetching GAIA questions: {e}")
        return []
def get_random_gaia_question():
    """Fetch a single random question from the GAIA scoring API.

    Returns:
        The parsed JSON question dict, or None when the request fails.
    """
    try:
        # Bounded timeout so a hung API call cannot stall the whole run.
        response = requests.get(
            "https://agents-course-unit4-scoring.hf.space/random-question",
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best-effort: callers check for None before using the result.
        print(f"Error fetching random GAIA question: {e}")
        return None
def answer_gaia_question(question_text: str, debug: bool = False) -> str:
    """Answer a single GAIA question by running the compiled agent graph.

    Args:
        question_text: The raw GAIA question text.
        debug: When True, print the full message trace and any traceback.

    Returns:
        The agent's final answer as a stripped string, "No answer generated"
        when the graph produced no messages, or an "Error: ..." string when
        the invocation raised.
    """
    try:
        initial_state = {
            "messages": [HumanMessage(content=question_text)]
        }
        if debug:
            print(f"🔍 Processing question: {question_text}")
        # Single entry point: the graph loops agent ↔ tools until END.
        result = graph.invoke(initial_state)
        if debug:
            print(f"📊 Total messages in conversation: {len(result.get('messages', []))}")
            for i, msg in enumerate(result.get('messages', [])):
                print(f" Message {i+1}: {type(msg).__name__} - {str(msg.content)[:100]}...")
        if result and "messages" in result and result["messages"]:
            # NOTE(review): assumes the final message content is a str; it can
            # be a list for multimodal replies — confirm against model config.
            final_answer = result["messages"][-1].content.strip()
            if debug:
                print(f"🎯 Final answer: {final_answer}")
            return final_answer
        return "No answer generated"
    except Exception as e:
        if debug:
            print(f"❌ Error details: {e}")
            # traceback is imported at module level; no local re-import needed.
            traceback.print_exc()
        print(f"Error answering question: {e}")
        return f"Error: {str(e)}"
# ─────────────────────────────────────────────────────────────────────────────
# TESTING AND VALIDATION
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: print the graph structure, run canned GAIA-style questions
    # covering math, retrieval, lists, files and audio, then try one live
    # question from the scoring API if it is reachable.
    print("🔍 Enhanced GAIA Agent Graph Structure:")
    try:
        print(graph.get_graph().draw_mermaid())
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate;
    # diagram rendering is best-effort and must never block the tests.
    except Exception:
        print("Could not generate mermaid diagram")
    print("\n🧪 Testing with GAIA-style questions...")
    # Test questions that cover different GAIA capabilities.
    test_questions = [
        "What is 2 + 2?",
        "What is the capital of France?",
        "List the vegetables from this list: broccoli, apple, carrot. Alphabetize and use comma separation.",
        "Given the Excel file at test_sales.xlsx, what were total sales for food? Express in USD with two decimals.",
        "Examine the audio file at ./test.wav. What is its transcript?",
    ]
    # Add an audio test only when the local fixture actually exists.
    if os.path.exists("test.wav"):
        test_questions.append("What does the speaker say in the audio file test.wav?")
    for i, question in enumerate(test_questions, 1):
        print(f"\n📝 Test {i}: {question}")
        try:
            answer = answer_gaia_question(question)
            print(f"✅ Answer: {answer!r}")
        except Exception as e:
            print(f"❌ Error: {e}")
        print("-" * 80)
    # Test with a real GAIA question if the API is available.
    print("\n🌍 Testing with real GAIA question...")
    try:
        random_q = get_random_gaia_question()
        if random_q:
            print(f"📋 GAIA Question: {random_q.get('question', 'N/A')}")
            answer = answer_gaia_question(random_q.get('question', ''))
            print(f"🎯 Agent Answer: {answer!r}")
            print(f"💡 Task ID: {random_q.get('task_id', 'N/A')}")
    except Exception as e:
        print(f"Could not test with real GAIA question: {e}")