Spaces:
Paused
Paused
| from abc import ABC, abstractmethod | |
| from contextlib import asynccontextmanager | |
| from typing import List, Optional | |
| from pydantic import BaseModel, Field, model_validator | |
| from app.llm import LLM | |
| from app.logger import logger | |
| from app.sandbox.client import SANDBOX_CLIENT | |
| from app.schema import ROLE_TYPE, AgentState, Memory, Message | |
| class BaseAgent(BaseModel, ABC): | |
| """Abstract base class for managing agent state and execution. | |
| Provides foundational functionality for state transitions, memory management, | |
| and a step-based execution loop. Subclasses must implement the `step` method. | |
| """ | |
| # Core attributes | |
| name: str = Field(..., description="Unique name of the agent") | |
| description: Optional[str] = Field(None, description="Optional agent description") | |
| # Prompts | |
| system_prompt: Optional[str] = Field( | |
| None, description="System-level instruction prompt" | |
| ) | |
| next_step_prompt: Optional[str] = Field( | |
| None, description="Prompt for determining next action" | |
| ) | |
| # Dependencies | |
| llm: LLM = Field(default_factory=LLM, description="Language model instance") | |
| memory: Memory = Field(default_factory=Memory, description="Agent's memory store") | |
| state: AgentState = Field( | |
| default=AgentState.IDLE, description="Current agent state" | |
| ) | |
| # Execution control | |
| max_steps: int = Field(default=10, description="Maximum steps before termination") | |
| current_step: int = Field(default=0, description="Current step in execution") | |
| duplicate_threshold: int = 2 | |
| class Config: | |
| arbitrary_types_allowed = True | |
| extra = "allow" # Allow extra fields for flexibility in subclasses | |
| def initialize_agent(self) -> "BaseAgent": | |
| """Initialize agent with default settings if not provided.""" | |
| if self.llm is None or not isinstance(self.llm, LLM): | |
| self.llm = LLM(config_name=self.name.lower()) | |
| if not isinstance(self.memory, Memory): | |
| self.memory = Memory() | |
| return self | |
| async def state_context(self, new_state: AgentState): | |
| """Context manager for safe agent state transitions. | |
| Args: | |
| new_state: The state to transition to during the context. | |
| Yields: | |
| None: Allows execution within the new state. | |
| Raises: | |
| ValueError: If the new_state is invalid. | |
| """ | |
| if not isinstance(new_state, AgentState): | |
| raise ValueError(f"Invalid state: {new_state}") | |
| previous_state = self.state | |
| self.state = new_state | |
| try: | |
| yield | |
| except Exception as e: | |
| self.state = AgentState.ERROR # Transition to ERROR on failure | |
| raise e | |
| finally: | |
| self.state = previous_state # Revert to previous state | |
| def update_memory( | |
| self, | |
| role: ROLE_TYPE, # type: ignore | |
| content: str, | |
| base64_image: Optional[str] = None, | |
| **kwargs, | |
| ) -> None: | |
| """Add a message to the agent's memory. | |
| Args: | |
| role: The role of the message sender (user, system, assistant, tool). | |
| content: The message content. | |
| base64_image: Optional base64 encoded image. | |
| **kwargs: Additional arguments (e.g., tool_call_id for tool messages). | |
| Raises: | |
| ValueError: If the role is unsupported. | |
| """ | |
| message_map = { | |
| "user": Message.user_message, | |
| "system": Message.system_message, | |
| "assistant": Message.assistant_message, | |
| "tool": lambda content, **kw: Message.tool_message(content, **kw), | |
| } | |
| if role not in message_map: | |
| raise ValueError(f"Unsupported message role: {role}") | |
| # Create message with appropriate parameters based on role | |
| kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})} | |
| self.memory.add_message(message_map[role](content, **kwargs)) | |
| async def run(self, request: Optional[str] = None) -> str: | |
| """Execute the agent's main loop asynchronously. | |
| Args: | |
| request: Optional initial user request to process. | |
| Returns: | |
| A string summarizing the execution results. | |
| Raises: | |
| RuntimeError: If the agent is not in IDLE state at start. | |
| """ | |
| if self.state != AgentState.IDLE: | |
| raise RuntimeError(f"Cannot run agent from state: {self.state}") | |
| if request: | |
| self.update_memory("user", request) | |
| results: List[str] = [] | |
| async with self.state_context(AgentState.RUNNING): | |
| while ( | |
| self.current_step < self.max_steps and self.state != AgentState.FINISHED | |
| ): | |
| self.current_step += 1 | |
| logger.info(f"Executing step {self.current_step}/{self.max_steps}") | |
| step_result = await self.step() | |
| # Check for stuck state | |
| if self.is_stuck(): | |
| self.handle_stuck_state() | |
| results.append(f"Step {self.current_step}: {step_result}") | |
| if self.current_step >= self.max_steps: | |
| self.current_step = 0 | |
| self.state = AgentState.IDLE | |
| results.append(f"Terminated: Reached max steps ({self.max_steps})") | |
| await SANDBOX_CLIENT.cleanup() | |
| return "\n".join(results) if results else "No steps executed" | |
| async def step(self) -> str: | |
| """Execute a single step in the agent's workflow. | |
| Must be implemented by subclasses to define specific behavior. | |
| """ | |
| def handle_stuck_state(self): | |
| """Handle stuck state by adding a prompt to change strategy""" | |
| stuck_prompt = "\ | |
| Observed duplicate responses. Consider new strategies and avoid repeating ineffective paths already attempted." | |
| self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}" | |
| logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}") | |
| def is_stuck(self) -> bool: | |
| """Check if the agent is stuck in a loop by detecting duplicate content""" | |
| if len(self.memory.messages) < 2: | |
| return False | |
| last_message = self.memory.messages[-1] | |
| if not last_message.content: | |
| return False | |
| # Count identical content occurrences | |
| duplicate_count = sum( | |
| 1 | |
| for msg in reversed(self.memory.messages[:-1]) | |
| if msg.role == "assistant" and msg.content == last_message.content | |
| ) | |
| return duplicate_count >= self.duplicate_threshold | |
| def messages(self) -> List[Message]: | |
| """Retrieve a list of messages from the agent's memory.""" | |
| return self.memory.messages | |
| def messages(self, value: List[Message]): | |
| """Set the list of messages in the agent's memory.""" | |
| self.memory.messages = value | |