Spaces:
Paused
Paused
File size: 7,026 Bytes
2f28b62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
from app.llm import LLM
from app.logger import logger
from app.sandbox.client import SANDBOX_CLIENT
from app.schema import ROLE_TYPE, AgentState, Memory, Message
class BaseAgent(BaseModel, ABC):
    """Abstract base class for managing agent state and execution.

    Provides foundational functionality for state transitions, memory
    management, and a step-based execution loop. Subclasses must implement
    the `step` method.
    """

    # Core attributes
    name: str = Field(..., description="Unique name of the agent")
    description: Optional[str] = Field(None, description="Optional agent description")

    # Prompts
    system_prompt: Optional[str] = Field(
        None, description="System-level instruction prompt"
    )
    next_step_prompt: Optional[str] = Field(
        None, description="Prompt for determining next action"
    )

    # Dependencies
    llm: LLM = Field(default_factory=LLM, description="Language model instance")
    memory: Memory = Field(default_factory=Memory, description="Agent's memory store")
    state: AgentState = Field(
        default=AgentState.IDLE, description="Current agent state"
    )

    # Execution control
    max_steps: int = Field(default=10, description="Maximum steps before termination")
    current_step: int = Field(default=0, description="Current step in execution")

    # Number of earlier assistant messages with identical content that makes
    # `is_stuck` report a loop.
    duplicate_threshold: int = 2

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"  # Allow extra fields for flexibility in subclasses

    @model_validator(mode="after")
    def initialize_agent(self) -> "BaseAgent":
        """Initialize agent with default settings if not provided.

        Replaces `llm` with an `LLM` configured under the lower-cased agent
        name, and `memory` with a fresh `Memory`, whenever the current value
        is not an instance of the expected type.

        Returns:
            The validated agent instance.
        """
        # `isinstance` is already False for None, so no separate None check.
        if not isinstance(self.llm, LLM):
            self.llm = LLM(config_name=self.name.lower())
        if not isinstance(self.memory, Memory):
            self.memory = Memory()
        return self

    @asynccontextmanager
    async def state_context(self, new_state: AgentState):
        """Context manager for safe agent state transitions.

        Args:
            new_state: The state to transition to during the context.

        Yields:
            None: Allows execution within the new state.

        Raises:
            ValueError: If the new_state is invalid.
        """
        if not isinstance(new_state, AgentState):
            raise ValueError(f"Invalid state: {new_state}")

        previous_state = self.state
        self.state = new_state
        try:
            yield
        except Exception:
            # NOTE: the `finally` clause below runs right after this handler,
            # so the ERROR state is transient and the agent ends up back in
            # `previous_state` once the exception has propagated.
            self.state = AgentState.ERROR  # Transition to ERROR on failure
            raise  # bare raise keeps the original traceback intact
        finally:
            self.state = previous_state  # Revert to previous state

    def update_memory(
        self,
        role: ROLE_TYPE,  # type: ignore
        content: str,
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Add a message to the agent's memory.

        Args:
            role: The role of the message sender (user, system, assistant, tool).
            content: The message content.
            base64_image: Optional base64 encoded image.
            **kwargs: Additional arguments (e.g., tool_call_id for tool messages).
                Only forwarded when `role` is "tool"; ignored otherwise.

        Raises:
            ValueError: If the role is unsupported.
        """
        message_map = {
            "user": Message.user_message,
            "system": Message.system_message,
            "assistant": Message.assistant_message,
            "tool": Message.tool_message,
        }

        if role not in message_map:
            raise ValueError(f"Unsupported message role: {role}")

        # Extra keyword arguments are only meaningful for tool messages.
        extra = kwargs if role == "tool" else {}
        # NOTE(review): `base64_image` is forwarded to every factory, including
        # the system one - confirm each `Message` factory accepts that keyword.
        message = message_map[role](content, base64_image=base64_image, **extra)
        self.memory.add_message(message)

    async def run(self, request: Optional[str] = None) -> str:
        """Execute the agent's main loop asynchronously.

        Steps are executed until `max_steps` is reached or the agent's state
        becomes FINISHED. The sandbox client is cleaned up once the loop ends,
        whether it completed normally or a step raised.

        Args:
            request: Optional initial user request to process.

        Returns:
            A string summarizing the execution results.

        Raises:
            RuntimeError: If the agent is not in IDLE state at start.
        """
        if self.state != AgentState.IDLE:
            raise RuntimeError(f"Cannot run agent from state: {self.state}")

        if request:
            self.update_memory("user", request)

        results: List[str] = []
        try:
            async with self.state_context(AgentState.RUNNING):
                while (
                    self.current_step < self.max_steps
                    and self.state != AgentState.FINISHED
                ):
                    self.current_step += 1
                    logger.info(f"Executing step {self.current_step}/{self.max_steps}")
                    step_result = await self.step()

                    # Check for stuck state
                    if self.is_stuck():
                        self.handle_stuck_state()

                    results.append(f"Step {self.current_step}: {step_result}")

                if self.current_step >= self.max_steps:
                    # Step budget exhausted: reset the counter for the next run.
                    self.current_step = 0
                    self.state = AgentState.IDLE
                    results.append(f"Terminated: Reached max steps ({self.max_steps})")
        finally:
            # Release sandbox resources even when a step raised; previously
            # cleanup was skipped on the error path.
            await SANDBOX_CLIENT.cleanup()

        return "\n".join(results) if results else "No steps executed"

    @abstractmethod
    async def step(self) -> str:
        """Execute a single step in the agent's workflow.

        Must be implemented by subclasses to define specific behavior.
        """

    def handle_stuck_state(self):
        """Handle stuck state by adding a prompt to change strategy.

        Prepends a strategy-change hint to `next_step_prompt`. When no prompt
        is set the hint becomes the whole prompt, and a prompt that already
        starts with the hint is left unchanged so repeated detections do not
        keep growing it.
        """
        stuck_prompt = (
            "Observed duplicate responses. Consider new strategies and "
            "avoid repeating ineffective paths already attempted."
        )
        current_prompt = self.next_step_prompt
        if not current_prompt:
            # `next_step_prompt` defaults to None; interpolating it would put
            # the literal text "None" into the prompt.
            self.next_step_prompt = stuck_prompt
        elif not current_prompt.startswith(stuck_prompt):
            self.next_step_prompt = f"{stuck_prompt}\n{current_prompt}"
        # Logged on every detection, including when the hint is already in place.
        logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}")

    def is_stuck(self) -> bool:
        """Check if the agent is stuck in a loop by detecting duplicate content.

        Returns:
            True when at least `duplicate_threshold` earlier assistant messages
            carry exactly the same content as the most recent message; False
            when there are fewer than two messages or the latest one is empty.
        """
        history = self.memory.messages
        if len(history) < 2:
            return False

        last_message = history[-1]
        if not last_message.content:
            return False

        # Count identical content occurrences among earlier assistant messages.
        duplicate_count = sum(
            1
            for msg in history[:-1]
            if msg.role == "assistant" and msg.content == last_message.content
        )
        return duplicate_count >= self.duplicate_threshold

    @property
    def messages(self) -> List[Message]:
        """Retrieve a list of messages from the agent's memory."""
        return self.memory.messages

    @messages.setter
    def messages(self, value: List[Message]):
        """Set the list of messages in the agent's memory."""
        self.memory.messages = value
|