Add configuration, graph, runner, and tools modules to enhance agent functionality. Introduce a Configuration class for managing parameters, implement an AgentRunner for executing the agent graph, and create tools for general search and mathematical calculations. Update test_agent.py to reflect new import paths and improve overall code organization.
Files changed:
- services/configuration.py → configuration.py (+0 -0)
- services/graph.py → graph.py (+52 -48)
- api/runner.py → runner.py (+59 -3)
- test_agent.py (+1 -1)
- services/tools.py → tools.py (+30 -0)
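
For orientation, a minimal sketch of how the reorganized modules fit together, based on the __main__ block added to runner.py. It assumes AgentRunner() takes no required arguments (as in that block); the question string is invented:

    from langgraph.types import Command

    from runner import AgentRunner

    runner = AgentRunner()
    response = runner("What is 15% of 240?")  # illustrative question
    print(response)

    # If the graph pauses on an interrupt, resume with a Command:
    # response = runner(Command(resume="c"))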
services/configuration.py → configuration.py
RENAMED
File without changes.
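
The deleted graph.py code below calls Configuration.from_runnable_config(config). The class itself is unchanged by this commit and not shown, so here is only a hedged sketch of that common LangGraph configuration pattern; the field names and defaults are invented for illustration:

    from dataclasses import dataclass, fields
    from typing import Optional

    from langchain_core.runnables import RunnableConfig


    @dataclass(kw_only=True)
    class Configuration:
        """Illustrative parameters; the real class may differ."""

        model_id: str = "gpt-4o-mini"  # hypothetical default
        max_steps: int = 10            # hypothetical default

        @classmethod
        def from_runnable_config(
            cls, config: Optional[RunnableConfig] = None
        ) -> "Configuration":
            # Read overrides from config["configurable"], falling back to
            # the dataclass defaults for anything not supplied.
            configurable = (config or {}).get("configurable", {})
            values = {
                f.name: configurable[f.name]
                for f in fields(cls)
                if f.name in configurable
            }
            return cls(**values)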
services/graph.py → graph.py
RENAMED

@@ -6,14 +6,14 @@ from datetime import datetime
 from typing import Dict, List, Optional, TypedDict, Union
 
 import yaml
-from services.configuration import Configuration
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.runnables import RunnableConfig
 from langgraph.graph import END, StateGraph
 from langgraph.types import interrupt
 from smolagents import CodeAgent, LiteLLMModel
 
-from services.tools import tools
+from configuration import Configuration
+from tools import tools
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)

@@ -33,7 +33,7 @@ else:
 litellm.drop_params = True
 
 # Load default prompt templates from local file
-current_dir = os.path.dirname(os.path.  [rest of old line truncated in source]
+current_dir = os.path.dirname(os.path.abspath(__file__))
 prompts_dir = os.path.join(current_dir, "prompts")
 yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

@@ -150,48 +150,50 @@ class AgentNode:
 class StepCallbackNode:
     """Node that handles step callbacks and user interaction."""
 
-    def __call__(self, ...) -> AgentState:
-        """Handle step callback and user interaction."""
-        # Get configuration
-        cfg = Configuration.from_runnable_config(config)
-        [old interrupt-based body, largely truncated in the source: it built a
-        step_log ending in "answer": state["answer"], appended it via
-        state["step_logs"].append(step_log), called interrupt(...), read
-        user_input = interrupt_result[0], branched on the user input to set
-        state["is_complete"] = True and return state, and on Exception logged
-        logger.warning(f"Error during interrupt: {str(e)}") and returned state
-        without marking it complete]
+    def __init__(self, name: str):
+        self.name = name
+
+    def __call__(self, state: dict) -> dict:
+        """Process the state and handle user interaction."""
+        print(f"\nCurrent step: {state.get('step_count', 0)}")
+        print(f"Question: {state.get('question', 'No question')}")
+        print(f"Current answer: {state.get('answer', 'No answer yet')}\n")
+
+        while True:
+            choice = input(
+                "Enter 'c' to continue, 'q' to quit, 'i' for more info, or 'r' to reject answer: "
+            ).lower()
+
+            if choice == "c":
+                # Mark as complete to continue
+                state["is_complete"] = True
+                return state
+            elif choice == "q":
+                # Mark as complete and set answer to None to quit
+                state["is_complete"] = True
+                state["answer"] = None
+                return state
+            elif choice == "i":
+                # Show more information but don't mark as complete
+                print("\nAdditional Information:")
+                print(f"Messages: {state.get('messages', [])}")
+                print(f"Step Logs: {state.get('step_logs', [])}")
+                print(f"Context: {state.get('context', {})}")
+                print(f"Memory Buffer: {state.get('memory_buffer', [])}")
+                print(f"Last Action: {state.get('last_action', None)}")
+                print(f"Action History: {state.get('action_history', [])}")
+                print(f"Error Count: {state.get('error_count', 0)}")
+                print(f"Success Count: {state.get('success_count', 0)}\n")
+            elif choice == "r":
+                # Reject the current answer and continue execution
+                print("\nRejecting current answer and continuing execution...")
+                # Clear the message history to prevent confusion
+                state["messages"] = []
+                state["answer"] = None
+                state["is_complete"] = False
+                return state
+            else:
+                print("Invalid choice. Please enter 'c', 'q', 'i', or 'r'.")

@@ -201,7 +203,7 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
 
     # Add nodes
     workflow.add_node("agent", agent)
-    workflow.add_node("callback", StepCallbackNode())
+    workflow.add_node("callback", StepCallbackNode("callback"))
 
     # Add edges
     workflow.add_edge("agent", "callback")

@@ -209,22 +211,24 @@
     # Add conditional edges for callback
     def should_continue(state: AgentState) -> str:
        """Determine the next node based on state."""
-        # If we have no answer, continue
+        # If we have no answer, continue to agent
         if not state["answer"]:
-            logger.info("No answer found, continuing")
-            return "agent"
-
-        # If we have an answer but it's not complete, continue
-        if not state["is_complete"]:
-            logger.info(f"Found answer but not complete: {state['answer']}")
+            logger.info("No answer found, continuing to agent")
             return "agent"
 
         # If we have an answer and it's complete, we're done
-        logger.info(f"Found complete answer: {state['answer']}")
-        return END
+        if state["is_complete"]:
+            logger.info(f"Found complete answer: {state['answer']}")
+            return END
+
+        # Otherwise, go to callback for user input
+        logger.info(f"Waiting for user input for answer: {state['answer']}")
+        return "callback"
 
     workflow.add_conditional_edges(
         "callback",
+        should_continue,
+        {END: END, "agent": "agent", "callback": "callback"},
     )
 
     # Set entry point
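
Since runner.py resumes with Command(resume=...), the compiled graph presumably runs with a checkpointer attached. A hedged sketch of driving the graph directly; it assumes AgentNode() needs no constructor arguments (its signature is outside these hunks) and that AgentState accepts the keys used above:

    from langgraph.checkpoint.memory import MemorySaver

    workflow = build_agent_graph(AgentNode())  # assumed no-arg constructor
    graph = workflow.compile(checkpointer=MemorySaver())

    # A thread_id is required once a checkpointer is attached.
    config = {"configurable": {"thread_id": "demo"}}
    state = graph.invoke(
        {
            "question": "What is 2 + 2?",  # illustrative
            "messages": [],
            "step_logs": [],
            "answer": None,
            "is_complete": False,
            # AgentState may require further keys not shown in these hunks.
        },
        config=config,
    )
    print(state["answer"])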
api/runner.py → runner.py
RENAMED

@@ -1,10 +1,11 @@
 import logging
 import os
+import re
 import uuid
 
 from langgraph.types import Command
 
-from services.graph import agent_graph
+from graph import agent_graph
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)  # Default to INFO level

@@ -48,6 +49,26 @@ class AgentRunner:
         if "messages" in state and state["messages"]:
             for msg in reversed(state["messages"]):
                 if hasattr(msg, "content") and msg.content:
+                    # Look for code blocks that might contain the answer
+                    if "```" in msg.content:
+                        # Extract code between ```py and ``` or ```python and ```
+                        code_match = re.search(
+                            r"```(?:py|python)?\s*\n(.*?)\n```", msg.content, re.DOTALL
+                        )
+                        if code_match:
+                            code = code_match.group(1)
+                            # Look for final_answer call
+                            final_answer_match = re.search(
+                                r"final_answer\((.*?)\)", code
+                            )
+                            if final_answer_match:
+                                answer = final_answer_match.group(1)
+                                logger.info(
+                                    f"Found answer in final_answer call: {answer}"
+                                )
+                                return answer
+
+                    # If no code block with final_answer, use the content
                     logger.info(f"Found answer in message: {msg.content}")
                     return msg.content
 

@@ -99,7 +120,9 @@ class AgentRunner:
             answer = self._extract_answer(chunk)
             if answer:
                 self.last_state = chunk
-                return answer
+                # If the state is complete, return the answer
+                if chunk.get("is_complete", False):
+                    return answer
             else:
                 logger.debug(f"Skipping chunk without answer: {chunk}")
         else:

@@ -111,7 +134,9 @@ class AgentRunner:
             answer = self._extract_answer(result)
             if answer:
                 self.last_state = result
-                return answer
+                # If the state is complete, return the answer
+                if result.get("is_complete", False):
+                    return answer
             else:
                 logger.debug(f"Skipping result without answer: {result}")
 

@@ -122,3 +147,34 @@
         except Exception as e:
             logger.error(f"Error processing input: {str(e)}")
             raise
+
+
+if __name__ == "__main__":
+    import argparse
+
+    from langgraph.types import Command
+
+    # Set up argument parser
+    parser = argparse.ArgumentParser(description="Run the agent with a question")
+    parser.add_argument("question", type=str, help="The question to ask the agent")
+    parser.add_argument(
+        "--resume",
+        type=str,
+        help="Value to resume with after an interrupt",
+        default=None,
+    )
+    args = parser.parse_args()
+
+    # Create agent runner
+    runner = AgentRunner()
+
+    if args.resume:
+        # Resume from interrupt with provided value
+        print(f"\nResuming with value: {args.resume}")
+        response = runner(Command(resume=args.resume))
+    else:
+        # Initial run with question
+        print(f"\nAsking question: {args.question}")
+        response = runner(args.question)
+
+    print(f"\nFinal response: {response}")
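
To make the new extraction logic concrete, a small self-contained demo of the two regexes; the agent message text is invented:

    import re

    # Invented example of an agent message containing a fenced code block.
    content = "Here is my plan.\n```py\nx = 21 * 2\nfinal_answer(x)\n```"

    code_match = re.search(r"```(?:py|python)?\s*\n(.*?)\n```", content, re.DOTALL)
    code = code_match.group(1)           # "x = 21 * 2\nfinal_answer(x)"
    final = re.search(r"final_answer\((.*?)\)", code)
    print(final.group(1))                # prints the raw argument text: x

Note that the extracted answer is the literal argument text ("x" here), not an evaluated value; the runner returns it as the answer string.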
test_agent.py
CHANGED

@@ -2,7 +2,7 @@ import logging
 
 import pytest
 
-from api.runner import AgentRunner
+from runner import AgentRunner
 
 # Configure test logger
 test_logger = logging.getLogger("test_agent")
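
With the flat layout, the tests exercise the runner through the top-level module. A hedged smoke-test sketch; the test body is invented, only the import path comes from the diff:

    from runner import AgentRunner

    def test_runner_constructs():
        # AgentRunner() is instantiated without arguments in runner.py's
        # __main__ block, so bare construction should succeed.
        runner = AgentRunner()
        assert runner is not None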
services/tools.py → tools.py
RENAMED

@@ -47,9 +47,39 @@ class GeneralSearchTool(Tool):
         return "\n\n---\n\n".join(output)
 
 
+class MathTool(Tool):
+    name = "math"
+    description = """Performs mathematical calculations and returns the result."""
+    inputs = {
+        "expression": {
+            "type": "string",
+            "description": "The mathematical expression to evaluate.",
+        }
+    }
+    output_type = "string"
+
+    def forward(self, expression: str) -> str:
+        try:
+            # Use eval with a restricted set of builtins for safety
+            safe_dict = {
+                "__builtins__": {
+                    "abs": abs,
+                    "round": round,
+                    "min": min,
+                    "max": max,
+                    "sum": sum,
+                }
+            }
+            result = eval(expression, safe_dict)
+            return str(result)
+        except Exception as e:
+            raise Exception(f"Error evaluating expression: {str(e)}")
+
+
 # Export all tools
 tools = [
     # DuckDuckGoSearchTool(),
     GeneralSearchTool(),
+    MathTool(),
     # WikipediaSearchTool(),
 ]
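
A quick usage sketch of the new tool; the expressions are illustrative:

    tool = MathTool()
    print(tool.forward("2 + 2 * 10"))          # "22"
    print(tool.forward("round(min(3.7, 5))"))  # "4"

Worth noting: whitelisting __builtins__ narrows what an expression can call, but eval is still not a true sandbox (attribute access on literals remains possible), so expressions should be treated as trusted input.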