Spaces:
Sleeping
Sleeping
| import copy | |
| from typing import List, Union | |
| from lagent.agents import Agent, AgentForInternLM, AsyncAgent, AsyncAgentForInternLM | |
| from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode | |
class StreamingAgentMixin:
    """Mixin that makes calling an agent produce a stream of responses.

    The host class is expected to provide ``_hooks``, ``name``,
    ``update_memory`` and a ``forward`` method that returns an iterable.
    """

    def __call__(self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs):
        """Run the hooks and ``forward``, yielding each intermediate response.

        Args:
            *message: Input messages for this call.
            session_id: Passed to the hooks, the memory updates and ``forward``.
            **kwargs: Forwarded unchanged to ``self.forward``.

        Yields:
            A shallow copy of every item produced by ``self.forward``; items
            that are not ``AgentMessage`` are unpacked as a
            ``(stream_state, content)`` pair and wrapped first. After the
            stream ends, the last response is yielded once more, after the
            ``after_agent`` hooks have run.
        """
        inputs = message
        for hook in self._hooks.values():
            # Each hook receives a fresh deep copy of the current inputs.
            inputs = copy.deepcopy(inputs)
            replacement = hook.before_agent(self, inputs, session_id)
            if replacement:
                inputs = replacement
        self.update_memory(inputs, session_id=session_id)

        # Fallback that is stored and yielded if ``forward`` produces nothing.
        latest = AgentMessage(sender=self.name, content="")
        for item in self.forward(*inputs, session_id=session_id, **kwargs):
            if isinstance(item, AgentMessage):
                latest = item
            else:
                state, text = item
                latest = AgentMessage(
                    sender=self.name,
                    content=text,
                    stream_state=state,
                )
            yield latest.model_copy()

        # Only the last streamed response is written to memory.
        self.update_memory(latest, session_id=session_id)
        for hook in self._hooks.values():
            latest = latest.model_copy(deep=True)
            replacement = hook.after_agent(self, latest, session_id)
            if replacement:
                latest = replacement
        yield latest
class AsyncStreamingAgentMixin:
    """Mixin that makes calling an async agent produce a stream of responses.

    The host class is expected to provide ``_hooks``, ``name``,
    ``update_memory`` and a ``forward`` method that returns an async iterable.
    """

    async def __call__(
        self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs
    ):
        """Run the hooks and ``forward``, yielding each intermediate response.

        Hooks and memory updates are invoked synchronously (not awaited);
        only ``self.forward`` is consumed asynchronously.

        Args:
            *message: Input messages for this call.
            session_id: Passed to the hooks, the memory updates and ``forward``.
            **kwargs: Forwarded unchanged to ``self.forward``.

        Yields:
            A shallow copy of every item produced by ``self.forward``; items
            that are not ``AgentMessage`` are unpacked as a
            ``(stream_state, content)`` pair and wrapped first. After the
            stream ends, the last response is yielded once more, after the
            ``after_agent`` hooks have run.
        """
        inputs = message
        for hook in self._hooks.values():
            # Each hook receives a fresh deep copy of the current inputs.
            inputs = copy.deepcopy(inputs)
            replacement = hook.before_agent(self, inputs, session_id)
            if replacement:
                inputs = replacement
        self.update_memory(inputs, session_id=session_id)

        # Fallback that is stored and yielded if ``forward`` produces nothing.
        latest = AgentMessage(sender=self.name, content="")
        async for item in self.forward(*inputs, session_id=session_id, **kwargs):
            if isinstance(item, AgentMessage):
                latest = item
            else:
                state, text = item
                latest = AgentMessage(
                    sender=self.name,
                    content=text,
                    stream_state=state,
                )
            yield latest.model_copy()

        # Only the last streamed response is written to memory.
        self.update_memory(latest, session_id=session_id)
        for hook in self._hooks.values():
            latest = latest.model_copy(deep=True)
            replacement = hook.after_agent(self, latest, session_id)
            if replacement:
                latest = replacement
        yield latest
class StreamingAgent(StreamingAgentMixin, Agent):
    """Base streaming agent: streams the LLM output chunk by chunk."""

    def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Stream the LLM reply for the conversation stored in memory.

        The prompt is built from ``self.memory.get(session_id)``; the
        ``*message`` arguments are not read here.

        Args:
            *message: Accepted for interface compatibility; unused in the body.
            session_id: Selects the memory and is passed to ``stream_chat``.
            **kwargs: Forwarded unchanged to ``self.llm.stream_chat``.

        Yields:
            An ``AgentMessage`` carrying the parsed response when
            ``self.output_format`` is truthy, otherwise the bare
            ``(model_state, response)`` tuple.
        """
        prompt = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        stream = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        # The third element of every stream item is ignored.
        for model_state, response, _ in stream:
            if self.output_format:
                yield AgentMessage(
                    sender=self.name,
                    content=response,
                    formatted=self.output_format.parse_response(response),
                    stream_state=model_state,
                )
            else:
                yield (model_state, response)
class AsyncStreamingAgent(AsyncStreamingAgentMixin, AsyncAgent):
    """Base asynchronous streaming agent: streams the LLM output chunk by chunk."""

    async def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Stream the LLM reply for the conversation stored in memory.

        The prompt is built from ``self.memory.get(session_id)``; the
        ``*message`` arguments are not read here.

        Args:
            *message: Accepted for interface compatibility; unused in the body.
            session_id: Selects the memory and is passed to ``stream_chat``.
            **kwargs: Forwarded unchanged to ``self.llm.stream_chat``.

        Yields:
            An ``AgentMessage`` carrying the parsed response when
            ``self.output_format`` is truthy, otherwise the bare
            ``(model_state, response)`` tuple.
        """
        prompt = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        stream = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        # The third element of every stream item is ignored.
        async for model_state, response, _ in stream:
            if self.output_format:
                yield AgentMessage(
                    sender=self.name,
                    content=response,
                    formatted=self.output_format.parse_response(response),
                    stream_state=model_state,
                )
            else:
                yield (model_state, response)
class StreamingAgentForInternLM(StreamingAgentMixin, AgentForInternLM):
    """Streaming counterpart of `lagent.agents.AgentForInternLM`."""

    # NOTE(review): presumably read by the parent class to construct
    # ``self.agent`` -- confirm against AgentForInternLM.
    _INTERNAL_AGENT_CLS = StreamingAgent

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Alternate between streaming the inner agent and running tools.

        Runs at most ``self.max_turn`` turns. In each turn every chunk from
        ``self.agent`` is relabelled with an agent-level ``stream_state`` and
        yielded; then either the run finishes or a tool executor is invoked
        and its result is fed into the next turn.

        Args:
            message: The user input; a plain ``str`` is wrapped in an
                ``AgentMessage`` with sender ``"user"``.
            session_id: Passed to the inner agent and to the tool executor.
            **kwargs: Forwarded unchanged to ``self.agent``.

        Yields:
            Streamed agent messages, tool results, and a final message whose
            ``stream_state`` is ``AgentStatusCode.END`` once
            ``self.finish_condition`` holds.

        Raises:
            RuntimeError: If no ``<tool_type>_executor`` attribute is available
                for the requested tool type.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)

        tool_start_states = [AgentStatusCode.CODING, AgentStatusCode.PLUGIN_START]
        for _turn in range(self.max_turn):
            previous_state = AgentStatusCode.SESSION_READY
            for message in self.agent(message, session_id=session_id, **kwargs):
                parsed = message.formatted
                if not (isinstance(parsed, dict) and parsed.get("tool_type")):
                    # No tool call detected in this chunk: plain text streaming.
                    message.stream_state = AgentStatusCode.STREAM_ING
                elif message.stream_state != ModelStatusCode.END:
                    # Tool call still being generated.
                    if parsed["tool_type"] == "plugin":
                        message.stream_state = AgentStatusCode.PLUGIN_START
                    else:
                        message.stream_state = AgentStatusCode.CODING
                else:
                    # Model finished: step one past a tool-start state, otherwise
                    # keep the previous state. Presumably "+ 1" lands on the
                    # matching end state of AgentStatusCode -- TODO confirm.
                    message.stream_state = previous_state + int(
                        previous_state in tool_start_states
                    )
                yield message
                # Read back after the yield, exactly as the state stands then.
                previous_state = message.stream_state

            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return

            tool_type = message.formatted["tool_type"]
            if not tool_type:
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                continue

            executor = getattr(self, f"{tool_type}_executor", None)
            if not executor:
                raise RuntimeError(f"No available {tool_type} executor")
            tool_return = executor(message, session_id=session_id)
            # The tool result takes the state following the request's state.
            tool_return.stream_state = message.stream_state + 1
            message = tool_return
            yield message
class AsyncStreamingAgentForInternLM(AsyncStreamingAgentMixin, AsyncAgentForInternLM):
    """Streaming counterpart of `lagent.agents.AsyncAgentForInternLM`."""

    # NOTE(review): presumably read by the parent class to construct
    # ``self.agent`` -- confirm against AsyncAgentForInternLM.
    _INTERNAL_AGENT_CLS = AsyncStreamingAgent

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Alternate between streaming the inner agent and running tools.

        Runs at most ``self.max_turn`` turns. In each turn every chunk from
        ``self.agent`` is relabelled with an agent-level ``stream_state`` and
        yielded; then either the run finishes or a tool executor is awaited
        and its result is fed into the next turn.

        Args:
            message: The user input; a plain ``str`` is wrapped in an
                ``AgentMessage`` with sender ``"user"``.
            session_id: Passed to the inner agent and to the tool executor.
            **kwargs: Forwarded unchanged to ``self.agent``.

        Yields:
            Streamed agent messages, tool results, and a final message whose
            ``stream_state`` is ``AgentStatusCode.END`` once
            ``self.finish_condition`` holds.

        Raises:
            RuntimeError: If no ``<tool_type>_executor`` attribute is available
                for the requested tool type.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)

        tool_start_states = [AgentStatusCode.CODING, AgentStatusCode.PLUGIN_START]
        for _turn in range(self.max_turn):
            previous_state = AgentStatusCode.SESSION_READY
            async for message in self.agent(message, session_id=session_id, **kwargs):
                parsed = message.formatted
                if not (isinstance(parsed, dict) and parsed.get("tool_type")):
                    # No tool call detected in this chunk: plain text streaming.
                    message.stream_state = AgentStatusCode.STREAM_ING
                elif message.stream_state != ModelStatusCode.END:
                    # Tool call still being generated.
                    if parsed["tool_type"] == "plugin":
                        message.stream_state = AgentStatusCode.PLUGIN_START
                    else:
                        message.stream_state = AgentStatusCode.CODING
                else:
                    # Model finished: step one past a tool-start state, otherwise
                    # keep the previous state. Presumably "+ 1" lands on the
                    # matching end state of AgentStatusCode -- TODO confirm.
                    message.stream_state = previous_state + int(
                        previous_state in tool_start_states
                    )
                yield message
                # Read back after the yield, exactly as the state stands then.
                previous_state = message.stream_state

            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return

            tool_type = message.formatted["tool_type"]
            if not tool_type:
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                continue

            executor = getattr(self, f"{tool_type}_executor", None)
            if not executor:
                raise RuntimeError(f"No available {tool_type} executor")
            tool_return = await executor(message, session_id=session_id)
            # The tool result takes the state following the request's state.
            tool_return.stream_state = message.stream_state + 1
            message = tool_return
            yield message