Spaces:
Sleeping
Sleeping
Refactor app.py and update import paths in test_agent.py to improve code organization. Introduce new files for agent configuration, graph definition, and tools, enhancing the overall structure and functionality of the agent system.
Browse files- agent.py → api/runner.py +8 -20
- app.py +0 -1
- configuration.py → services/configuration.py +0 -0
- graph.py → services/graph.py +13 -17
- tools.py → services/tools.py +0 -0
- test_agent.py +23 -3
agent.py → api/runner.py
RENAMED
|
@@ -4,7 +4,7 @@ import uuid
|
|
| 4 |
|
| 5 |
from langgraph.types import Command
|
| 6 |
|
| 7 |
-
from graph import agent_graph
|
| 8 |
|
| 9 |
# Configure logging
|
| 10 |
logging.basicConfig(level=logging.INFO) # Default to INFO level
|
|
@@ -86,32 +86,20 @@ class AgentRunner:
|
|
| 86 |
}
|
| 87 |
logger.info(f"Initial state: {initial_state}")
|
| 88 |
|
| 89 |
-
# Use stream to get
|
| 90 |
logger.info("Starting graph stream for initial question")
|
| 91 |
for chunk in self.graph.stream(initial_state, config):
|
| 92 |
logger.debug(f"Received chunk: {chunk}")
|
| 93 |
-
|
| 94 |
if isinstance(chunk, dict):
|
| 95 |
if "__interrupt__" in chunk:
|
| 96 |
logger.info("Detected interrupt in stream")
|
| 97 |
logger.info(f"Interrupt details: {chunk['__interrupt__']}")
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
logger.debug(f"Received resume result: {result}")
|
| 105 |
-
if isinstance(result, dict):
|
| 106 |
-
answer = self._extract_answer(result)
|
| 107 |
-
if answer:
|
| 108 |
-
self.last_state = result
|
| 109 |
-
return answer
|
| 110 |
-
else:
|
| 111 |
-
answer = self._extract_answer(chunk)
|
| 112 |
-
if answer:
|
| 113 |
-
self.last_state = chunk
|
| 114 |
-
return answer
|
| 115 |
else:
|
| 116 |
logger.debug(f"Skipping chunk without answer: {chunk}")
|
| 117 |
else:
|
|
|
|
| 4 |
|
| 5 |
from langgraph.types import Command
|
| 6 |
|
| 7 |
+
from services.graph import agent_graph
|
| 8 |
|
| 9 |
# Configure logging
|
| 10 |
logging.basicConfig(level=logging.INFO) # Default to INFO level
|
|
|
|
| 86 |
}
|
| 87 |
logger.info(f"Initial state: {initial_state}")
|
| 88 |
|
| 89 |
+
# Use stream to get results
|
| 90 |
logger.info("Starting graph stream for initial question")
|
| 91 |
for chunk in self.graph.stream(initial_state, config):
|
| 92 |
logger.debug(f"Received chunk: {chunk}")
|
|
|
|
| 93 |
if isinstance(chunk, dict):
|
| 94 |
if "__interrupt__" in chunk:
|
| 95 |
logger.info("Detected interrupt in stream")
|
| 96 |
logger.info(f"Interrupt details: {chunk['__interrupt__']}")
|
| 97 |
+
# Let the graph handle the interrupt naturally
|
| 98 |
+
continue
|
| 99 |
+
answer = self._extract_answer(chunk)
|
| 100 |
+
if answer:
|
| 101 |
+
self.last_state = chunk
|
| 102 |
+
return answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
else:
|
| 104 |
logger.debug(f"Skipping chunk without answer: {chunk}")
|
| 105 |
else:
|
app.py
CHANGED
|
@@ -3,7 +3,6 @@ import os
|
|
| 3 |
import gradio as gr
|
| 4 |
import pandas as pd
|
| 5 |
import requests
|
| 6 |
-
|
| 7 |
from agent import AgentRunner
|
| 8 |
|
| 9 |
# (Keep Constants as is)
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import pandas as pd
|
| 5 |
import requests
|
|
|
|
| 6 |
from agent import AgentRunner
|
| 7 |
|
| 8 |
# (Keep Constants as is)
|
configuration.py → services/configuration.py
RENAMED
|
File without changes
|
graph.py → services/graph.py
RENAMED
|
@@ -6,14 +6,14 @@ from datetime import datetime
|
|
| 6 |
from typing import Dict, List, Optional, TypedDict, Union
|
| 7 |
|
| 8 |
import yaml
|
|
|
|
| 9 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 10 |
from langchain_core.runnables import RunnableConfig
|
| 11 |
from langgraph.graph import END, StateGraph
|
| 12 |
from langgraph.types import interrupt
|
| 13 |
from smolagents import CodeAgent, LiteLLMModel
|
| 14 |
|
| 15 |
-
from
|
| 16 |
-
from tools import tools
|
| 17 |
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -33,7 +33,7 @@ else:
|
|
| 33 |
litellm.drop_params = True
|
| 34 |
|
| 35 |
# Load default prompt templates from local file
|
| 36 |
-
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 37 |
prompts_dir = os.path.join(current_dir, "prompts")
|
| 38 |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
|
| 39 |
|
|
@@ -182,9 +182,7 @@ class StepCallbackNode:
|
|
| 182 |
logger.info(f"Current answer: {state['answer']}")
|
| 183 |
return state
|
| 184 |
elif user_input.lower() == "c":
|
| 185 |
-
#
|
| 186 |
-
if state["answer"]:
|
| 187 |
-
state["is_complete"] = True
|
| 188 |
return state
|
| 189 |
else:
|
| 190 |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
|
|
@@ -192,9 +190,7 @@ class StepCallbackNode:
|
|
| 192 |
|
| 193 |
except Exception as e:
|
| 194 |
logger.warning(f"Error during interrupt: {str(e)}")
|
| 195 |
-
#
|
| 196 |
-
if state["answer"]:
|
| 197 |
-
state["is_complete"] = True
|
| 198 |
return state
|
| 199 |
|
| 200 |
|
|
@@ -213,19 +209,19 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
|
|
| 213 |
# Add conditional edges for callback
|
| 214 |
def should_continue(state: AgentState) -> str:
|
| 215 |
"""Determine the next node based on state."""
|
| 216 |
-
# If we have
|
| 217 |
-
if state["answer"]
|
| 218 |
-
logger.info(
|
| 219 |
-
return
|
| 220 |
|
| 221 |
# If we have an answer but it's not complete, continue
|
| 222 |
-
if state["
|
| 223 |
logger.info(f"Found answer but not complete: {state['answer']}")
|
| 224 |
return "agent"
|
| 225 |
|
| 226 |
-
# If we have
|
| 227 |
-
logger.info("
|
| 228 |
-
return
|
| 229 |
|
| 230 |
workflow.add_conditional_edges(
|
| 231 |
"callback", should_continue, {END: END, "agent": "agent"}
|
|
|
|
| 6 |
from typing import Dict, List, Optional, TypedDict, Union
|
| 7 |
|
| 8 |
import yaml
|
| 9 |
+
from services.configuration import Configuration
|
| 10 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 11 |
from langchain_core.runnables import RunnableConfig
|
| 12 |
from langgraph.graph import END, StateGraph
|
| 13 |
from langgraph.types import interrupt
|
| 14 |
from smolagents import CodeAgent, LiteLLMModel
|
| 15 |
|
| 16 |
+
from services.tools import tools
|
|
|
|
| 17 |
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 33 |
litellm.drop_params = True
|
| 34 |
|
| 35 |
# Load default prompt templates from local file
|
| 36 |
+
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 37 |
prompts_dir = os.path.join(current_dir, "prompts")
|
| 38 |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
|
| 39 |
|
|
|
|
| 182 |
logger.info(f"Current answer: {state['answer']}")
|
| 183 |
return state
|
| 184 |
elif user_input.lower() == "c":
|
| 185 |
+
# Continue without marking as complete
|
|
|
|
|
|
|
| 186 |
return state
|
| 187 |
else:
|
| 188 |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
|
|
|
|
| 190 |
|
| 191 |
except Exception as e:
|
| 192 |
logger.warning(f"Error during interrupt: {str(e)}")
|
| 193 |
+
# Continue without marking as complete
|
|
|
|
|
|
|
| 194 |
return state
|
| 195 |
|
| 196 |
|
|
|
|
| 209 |
# Add conditional edges for callback
|
| 210 |
def should_continue(state: AgentState) -> str:
|
| 211 |
"""Determine the next node based on state."""
|
| 212 |
+
# If we have no answer, continue
|
| 213 |
+
if not state["answer"]:
|
| 214 |
+
logger.info("No answer found, continuing")
|
| 215 |
+
return "agent"
|
| 216 |
|
| 217 |
# If we have an answer but it's not complete, continue
|
| 218 |
+
if not state["is_complete"]:
|
| 219 |
logger.info(f"Found answer but not complete: {state['answer']}")
|
| 220 |
return "agent"
|
| 221 |
|
| 222 |
+
# If we have an answer and it's complete, we're done
|
| 223 |
+
logger.info(f"Found complete answer: {state['answer']}")
|
| 224 |
+
return END
|
| 225 |
|
| 226 |
workflow.add_conditional_edges(
|
| 227 |
"callback", should_continue, {END: END, "agent": "agent"}
|
tools.py → services/tools.py
RENAMED
|
File without changes
|
test_agent.py
CHANGED
|
@@ -2,7 +2,7 @@ import logging
|
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
|
| 5 |
-
from
|
| 6 |
|
| 7 |
# Configure test logger
|
| 8 |
test_logger = logging.getLogger("test_agent")
|
|
@@ -194,9 +194,29 @@ def test_simple_math_calculation_with_steps():
|
|
| 194 |
|
| 195 |
# Verify final answer
|
| 196 |
expected_result = 1302.678
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
assert (
|
| 198 |
-
|
| 199 |
-
), f"Response should contain the result {expected_result}"
|
|
|
|
|
|
|
| 200 |
assert (
|
| 201 |
"final_answer" in response.lower()
|
| 202 |
), "Response should indicate it's using final_answer"
|
|
|
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
|
| 5 |
+
from api.runner import AgentRunner
|
| 6 |
|
| 7 |
# Configure test logger
|
| 8 |
test_logger = logging.getLogger("test_agent")
|
|
|
|
| 194 |
|
| 195 |
# Verify final answer
|
| 196 |
expected_result = 1302.678
|
| 197 |
+
|
| 198 |
+
# Extract all numbers from the response
|
| 199 |
+
import re
|
| 200 |
+
|
| 201 |
+
# First check for LaTeX formatting
|
| 202 |
+
latex_match = re.search(r"\\boxed{([^}]+)}", response)
|
| 203 |
+
if latex_match:
|
| 204 |
+
# Extract number from LaTeX box
|
| 205 |
+
latex_content = latex_match.group(1)
|
| 206 |
+
numbers = re.findall(r"\d+\.?\d*", latex_content)
|
| 207 |
+
else:
|
| 208 |
+
# Extract all numbers from the response
|
| 209 |
+
numbers = re.findall(r"\d+\.?\d*", response)
|
| 210 |
+
|
| 211 |
+
assert numbers, "Response should contain at least one number"
|
| 212 |
+
|
| 213 |
+
# Check if any number matches the expected result
|
| 214 |
+
has_correct_result = any(abs(float(n) - expected_result) < 0.001 for n in numbers)
|
| 215 |
assert (
|
| 216 |
+
has_correct_result
|
| 217 |
+
), f"Response should contain the result {expected_result}, got {response}"
|
| 218 |
+
|
| 219 |
+
# Verify the response indicates it's a final answer
|
| 220 |
assert (
|
| 221 |
"final_answer" in response.lower()
|
| 222 |
), "Response should indicate it's using final_answer"
|