Spaces:
Sleeping
Sleeping
Add configuration, graph, runner, and tools modules to enhance agent functionality. Introduce a Configuration class for managing parameters, implement an AgentRunner for executing the agent graph, and create tools for general search and mathematical calculations. Update test_agent.py to reflect new import paths and improve overall code organization.
13388e5
unverified
| """Define the agent graph and its components.""" | |
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union

import yaml
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel

from configuration import Configuration
from tools import tools
# Logging setup for this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

import litellm

# LITELLM_DEBUG=true (compared case-insensitively) switches on verbose LiteLLM
# output and DEBUG-level logging for this module; any other value, or leaving
# the variable unset, keeps LiteLLM quiet and this logger at INFO.
_litellm_debug = os.getenv("LITELLM_DEBUG", "false").lower() == "true"
litellm.set_verbose = _litellm_debug
logger.setLevel(logging.DEBUG if _litellm_debug else logging.INFO)

# Have LiteLLM drop parameters that are not supported
litellm.drop_params = True
# Load the default prompt templates from prompts/code_agent.yaml, resolved
# relative to this file so the lookup does not depend on the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
# Decode explicitly as UTF-8: without an encoding argument open() falls back to
# the platform locale, which can garble or reject non-ASCII prompt text.
with open(yaml_path, "r", encoding="utf-8") as f:
    prompt_templates = yaml.safe_load(f)
# Build the model wrapper and the code agent from the default configuration
config = Configuration()
model = LiteLLMModel(
    model_id=config.model_id,
    api_base=config.api_base,
    api_key=config.api_key,
)
agent = CodeAgent(
    tools=tools,
    model=model,
    prompt_templates=prompt_templates,
    add_base_tools=True,
    max_steps=1,  # Execute one step at a time
    # NOTE(review): logging.DEBUG is the stdlib level constant (10); confirm
    # this is the scale CodeAgent's verbosity_level expects.
    verbosity_level=logging.DEBUG,
)
class AgentState(TypedDict):
    """State passed between the nodes of the agent graph."""

    # Conversation messages; the agent node adds an AIMessage per result
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    # Question handed to the agent on every run
    question: str
    # Latest agent result, or None when there is none / it was rejected
    answer: Optional[str]
    step_logs: List[Dict]
    # Set by the callback node once the user accepts the answer or quits
    is_complete: bool
    # Number of successful agent runs so far
    step_count: int

    # Memory-related fields
    context: Dict[str, Any]  # For storing contextual information
    memory_buffer: List[Dict]  # For storing important information across steps
    last_action: Optional[str]  # Track the last action taken
    action_history: List[Dict]  # History of actions taken
    error_count: int  # Track error frequency
    success_count: int  # Track successful operations
class AgentNode:
    """Graph node that runs the wrapped agent on the question in the state."""

    def __init__(self, agent: CodeAgent):
        """Store the agent whose ``run`` method ``__call__`` invokes.

        Args:
            agent: The agent to run on each invocation of this node.
        """
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent once and return an updated copy of the state.

        On success the returned state carries the agent's result as
        ``answer``, a new ``AIMessage`` in ``messages``, an incremented
        ``step_count`` and ``success_count``, and new entries in
        ``action_history`` and (for truthy results) ``memory_buffer``. The
        lists in the input state are left untouched: every list that changes
        is rebuilt rather than appended to in place.

        Args:
            state: Current graph state.
            config: Optional runnable configuration; it is resolved through
                ``Configuration.from_runnable_config`` and logged.

        Returns:
            A new state dict with the updates described above.

        Raises:
            Exception: Whatever the agent run (or the state update) raises is
                logged, recorded on ``state`` and re-raised unchanged.
        """
        # Log current state
        logger.info("Current state before processing:")
        logger.info("Messages: %s", state["messages"])
        logger.info("Question: %s", state["question"])
        logger.info("Answer: %s", state["answer"])

        # Get configuration
        cfg = Configuration.from_runnable_config(config)
        logger.info("Using configuration: %s", cfg)

        # Log execution start
        logger.info("Starting agent execution")

        try:
            # Run the agent
            # NOTE(review): the result is used as AIMessage content as-is;
            # confirm agent.run always returns a string here.
            result = self.agent.run(state["question"])

            # state.copy() is shallow, so appending to the copied lists would
            # also modify the caller's state. Build fresh lists instead.
            step = state["step_count"]
            new_state = state.copy()
            new_state["messages"] = [*state["messages"], AIMessage(content=result)]
            new_state["answer"] = result
            new_state["step_count"] = step + 1
            new_state["last_action"] = "agent_response"
            new_state["action_history"] = [
                *state["action_history"],
                {
                    "step": step,
                    "action": "agent_response",
                    "result": result,
                },
            ]
            new_state["success_count"] = state["success_count"] + 1

            # Store important information in memory buffer
            if result:
                new_state["memory_buffer"] = [
                    *state["memory_buffer"],
                    {
                        "step": step,
                        "content": result,
                        "timestamp": datetime.now().isoformat(),
                    },
                ]
        except Exception as e:
            logger.exception("Error during agent execution: %s", e)
            # The exception is re-raised, so no new state can be returned.
            # Record the failure on the input state itself; that is the only
            # place a caller catching the exception can still see it.
            # .get/.setdefault keep a missing key from masking the real error.
            state["error_count"] = state.get("error_count", 0) + 1
            state.setdefault("action_history", []).append(
                {"step": state.get("step_count", 0), "action": "error", "error": str(e)}
            )
            raise

        # Log updated state
        logger.info("Updated state after processing:")
        logger.info("Messages: %s", new_state["messages"])
        logger.info("Question: %s", new_state["question"])
        logger.info("Answer: %s", new_state["answer"])
        return new_state
class StepCallbackNode:
    """Interactive node: show the latest result and ask the user what to do."""

    def __init__(self, name: str):
        """Remember the name given to this node."""
        self.name = name

    def __call__(self, state: dict) -> dict:
        """Prompt on stdin until the user picks an action, then update state.

        The state is modified in place and the same dict is returned:

        * ``c`` - set ``is_complete`` to True.
        * ``q`` - set ``is_complete`` to True and clear ``answer``.
        * ``r`` - clear ``messages`` and ``answer``, set ``is_complete`` False.
        * ``i`` - print the remaining state fields and prompt again.

        Input is compared case-insensitively; anything else prints a hint and
        prompts again.
        """
        print(f"\nCurrent step: {state.get('step_count', 0)}")
        print(f"Question: {state.get('question', 'No question')}")
        print(f"Current answer: {state.get('answer', 'No answer yet')}\n")

        prompt = (
            "Enter 'c' to continue, 'q' to quit, 'i' for more info, or 'r' to reject answer: "
        )
        while True:
            choice = input(prompt).lower()

            if choice == "i":
                # Informational only: nothing changes, ask again
                self._print_details(state)
                continue

            if choice == "c":
                # Accept the current answer
                state.update(is_complete=True)
                return state

            if choice == "q":
                # Quit: complete, but with the answer cleared
                state.update(is_complete=True, answer=None)
                return state

            if choice == "r":
                # Reject: wipe messages and answer so execution continues
                print("\nRejecting current answer and continuing execution...")
                state.update(messages=[], answer=None, is_complete=False)
                return state

            print("Invalid choice. Please enter 'c', 'q', 'i', or 'r'.")

    @staticmethod
    def _print_details(state: dict) -> None:
        """Print the state fields not shown in the initial summary."""
        print("\nAdditional Information:")
        print(f"Messages: {state.get('messages', [])}")
        print(f"Step Logs: {state.get('step_logs', [])}")
        print(f"Context: {state.get('context', {})}")
        print(f"Memory Buffer: {state.get('memory_buffer', [])}")
        print(f"Last Action: {state.get('last_action', None)}")
        print(f"Action History: {state.get('action_history', [])}")
        print(f"Error Count: {state.get('error_count', 0)}")
        print(f"Success Count: {state.get('success_count', 0)}\n")
def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build and compile the agent graph.

    The graph has two nodes: ``"agent"`` (the entry point) and ``"callback"``
    (a ``StepCallbackNode``). The agent always hands over to the callback;
    after the callback, ``should_continue`` picks the next node.

    Args:
        agent: Node callable to register under the name ``"agent"``.

    Returns:
        The result of ``workflow.compile()``.
    """
    # Initialize the graph
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode("callback"))

    # Add edges
    workflow.add_edge("agent", "callback")

    # Add conditional edges for callback
    def should_continue(state: AgentState) -> str:
        """Determine the node to run after the callback."""
        # Completion must be checked first: quitting marks the state complete
        # AND clears the answer, so testing for a missing answer first would
        # route a quit straight back to the agent and never end the run.
        if state["is_complete"]:
            if state["answer"]:
                logger.info(f"Found complete answer: {state['answer']}")
            else:
                logger.info("Run marked complete without an answer, stopping")
            return END

        # Not complete and no answer (e.g. it was rejected): run the agent again
        if not state["answer"]:
            logger.info("No answer found, continuing to agent")
            return "agent"

        # Otherwise, go to callback for user input
        logger.info(f"Waiting for user input for answer: {state['answer']}")
        return "callback"

    workflow.add_conditional_edges(
        "callback",
        should_continue,
        {END: END, "agent": "agent", "callback": "callback"},
    )

    # Set entry point
    workflow.set_entry_point("agent")
    return workflow.compile()
# Module-level compiled graph wrapping the shared agent
agent_graph = build_agent_graph(agent=AgentNode(agent=agent))