Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import re | |
| from copy import deepcopy | |
| from typing import Dict, Tuple | |
| from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode | |
| from lagent.utils import GeneratorWithReturn | |
| from .graph import ExecutionAction, WebSearchGraph | |
| from .streaming import AsyncStreamingAgentForInternLM, StreamingAgentForInternLM | |
def _update_ref(
    ref: str, ref2url: Dict[int, dict], ptr: int
) -> Tuple[str, Dict[int, dict], int]:
    """Renumber the ``[[n]]`` citations in ``ref`` so they start after ``ptr``.

    The distinct citation ids found in ``ref`` are sorted and mapped to
    ``ptr + 1, ptr + 2, ...``. Every ``[[n]]`` occurrence is rewritten with
    its new id, and the entries of ``ref2url`` for those ids are re-keyed
    the same way.

    Args:
        ref: Text containing citations such as ``[[3]]``.
        ref2url: Mapping from the original citation ids to their source
            entries. May be empty or ``None``. Ids cited in ``ref`` but
            absent here are logged and get no entry in the returned mapping.
        ptr: Offset added to the new (1-based) ids.

    Returns:
        ``(updated_ref, updated_ref2url, added_ptr)``, where ``added_ptr`` is
        the amount the caller should advance ``ptr`` by before renumbering
        the next text.
    """
    citation = re.compile(r"\[\[(\d+)\]\]")
    ref2url = ref2url or {}

    # Sort the ids instead of relying on set iteration order, so the new
    # numbering is deterministic and follows the original id order.
    cited = sorted({int(n) for n in citation.findall(ref)})
    renumber = {old: ptr + idx for idx, old in enumerate(cited, start=1)}

    updated_ref = citation.sub(
        lambda match: f"[[{renumber[int(match.group(1))]}]]", ref
    )

    # Report exactly which cited ids have no source entry. A bare assert
    # would produce an empty message and is stripped under `python -O`.
    missing = [old for old in cited if old not in ref2url]
    if missing:
        logging.info("Illegal reference id: %s", missing)

    updated_ref2url = {renumber[old]: ref2url[old] for old in cited if old in ref2url}

    # NOTE(review): the extra +1 leaves one unused number between the
    # citations of consecutive texts. It is kept for compatibility with the
    # existing numbering; confirm whether contiguous ids are preferred.
    return updated_ref, updated_ref2url, len(cited) + 1
def _generate_references_from_graph(graph: Dict[str, dict]) -> Tuple[str, Dict[int, dict]]:
    """Merge the search results of every graph node into one reference text.

    Every node except ``root`` and ``response`` contributes a section of the
    form ``## <node content>`` followed by its response text. The citations in
    that text are renumbered by ``_update_ref`` so ids stay unique across nodes.

    Args:
        graph: Mapping from node name to node data. Each searched node is
            expected to carry ``memory["agent.memory"]`` (whose third entry is
            the ``ActionExecutor`` result, a JSON object keyed by citation id),
            ``response["content"]`` and ``content``.

    Returns:
        The sections joined by blank lines, and the merged mapping from the
        renumbered citation ids to the entries parsed from the search results.

    Raises:
        AssertionError: if a node's third memory entry was not sent by an
            ``ActionExecutor``.
    """
    offset = 0
    sections = []
    merged_urls = {}
    for node_name, node in graph.items():
        if node_name in ("root", "response"):
            continue
        # Each node searches exactly once, so its tool result sits at index 2.
        search_msg = node["memory"]["agent.memory"][2]
        assert search_msg["sender"].endswith("ActionExecutor")
        raw_results = json.loads(search_msg["content"])
        id_to_source = {int(key): value for key, value in raw_results.items()}
        section_body, renumbered, step = _update_ref(
            node["response"]["content"], id_to_source, offset
        )
        offset += step
        sections.append(f'## {node["content"]}\n\n{section_body}')
        merged_urls.update(renumbered)
    return "\n\n".join(sections), merged_urls
class MindSearchAgent(StreamingAgentForInternLM):
    """Agent that alternates LLM turns with execution of the LLM's tool calls.

    Each turn streams the response of ``self.agent``. When that response is a
    tool call, its content is run by ``ExecutionAction``. The resulting graph
    nodes are turned into a renumbered reference text, which is sent back to
    the LLM as an ``ActionExecutor`` message. Once ``finish_condition`` holds,
    ``summary_prompt`` is sent instead and the LLM's reply is streamed as the
    final answer.
    """

    def __init__(
        self,
        searcher_cfg: dict,
        summary_prompt: str,
        finish_condition=lambda m: "add_response_node" in m.content,
        max_turn: int = 10,
        **kwargs,
    ):
        """Configure the agent.

        Args:
            searcher_cfg: Stored on ``WebSearchGraph.SEARCHER_CONFIG``.
                NOTE(review): this is class-level state shared by every agent
                in the process; the most recently constructed agent wins.
            summary_prompt: Content sent once ``finish_condition`` holds, to
                elicit the final answer.
            finish_condition: Predicate on the LLM's last message. By default
                it holds once the content mentions ``add_response_node``.
            max_turn: Maximum number of LLM / execution rounds.
            **kwargs: Forwarded to ``StreamingAgentForInternLM``.
        """
        # NOTE(review): unlike AsyncMindSearchAgent, this does not touch
        # `WebSearchGraph.is_async`. If an async agent was built earlier in the
        # same process, that class-level flag is still True here. Confirm
        # whether that matters for WebSearchGraph.
        WebSearchGraph.SEARCHER_CONFIG = searcher_cfg
        super().__init__(finish_condition=finish_condition, max_turn=max_turn, **kwargs)
        self.summary_prompt = summary_prompt
        self.action = ExecutionAction()

    @staticmethod
    def _is_tool_call(message: AgentMessage) -> bool:
        """Return True when the parsed LLM output carries a truthy ``tool_type``."""
        return isinstance(message.formatted, dict) and bool(message.formatted.get("tool_type"))

    @staticmethod
    def _tool_stream_state(message: AgentMessage, last_agent_state):
        """Return the ``stream_state`` for a streamed chunk that is part of a tool call."""
        if message.stream_state == ModelStatusCode.END:
            # The model finished this response. Step CODING / PLUGIN_START to
            # the following status code; any other state is kept as-is.
            return last_agent_state + int(
                last_agent_state
                in [
                    AgentStatusCode.CODING,
                    AgentStatusCode.PLUGIN_START,
                ]
            )
        if message.formatted["tool_type"] == "plugin":
            return AgentStatusCode.PLUGIN_START
        return AgentStatusCode.CODING

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Stream the whole search-and-answer interaction for one query.

        Args:
            message: The user query, as an ``AgentMessage`` or a plain string
                (wrapped into a ``"user"`` message).
            session_id: Forwarded to ``self.agent``.
            **kwargs: Forwarded to ``self.agent``.

        Yields:
            ``AgentMessage`` objects:
            - LLM chunks, with ``stream_state`` set and the current graph
              state merged into ``formatted``.
            - Messages produced while executing the tool call.
            - The ``ActionExecutor`` messages fed back to the LLM.

            Generation stops after a plain (non-tool) answer or after the
            summary reply. NOTE(review): it also stops silently, without an
            END message, once ``max_turn`` turns are used up.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        _graph_state = dict(node={}, adjacency_list={}, ref2url={})
        # SECURITY NOTE(review): LLM-generated content is executed by
        # ExecutionAction with this module's globals(), so it can read and
        # rebind module-level names. Confirm ExecutionAction isolates it.
        local_dict, global_dict = {}, globals()
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            # `message` is deliberately rebound by this loop. Afterwards it
            # holds the last message the LLM produced for this turn.
            for message in self.agent(message, session_id=session_id, **kwargs):
                if self._is_tool_call(message):
                    message.stream_state = self._tool_stream_state(message, last_agent_state)
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                message.formatted.update(deepcopy(_graph_state))
                yield message
                last_agent_state = message.stream_state
            # Use the same tool-call predicate as for streamed chunks. Indexing
            # `formatted["tool_type"]` directly could raise KeyError/TypeError.
            if not self._is_tool_call(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            gen = GeneratorWithReturn(
                self.action.run(message.content, local_dict, global_dict, True)
            )
            for graph_exec in gen:
                graph_exec.formatted["ref2url"] = deepcopy(_graph_state["ref2url"])
                yield graph_exec
            # gen.ret[1] holds the graph nodes and gen.ret[2] the adjacency list.
            reference, references_url = _generate_references_from_graph(gen.ret[1])
            _graph_state.update(node=gen.ret[1], adjacency_list=gen.ret[2], ref2url=references_url)
            if self.finish_condition(message):
                message = AgentMessage(
                    sender="ActionExecutor",
                    content=self.summary_prompt,
                    formatted=deepcopy(_graph_state),
                    stream_state=message.stream_state + 1,  # plugin or code return
                )
                yield message
                # summarize the references to generate the final answer
                for message in self.agent(message, session_id=session_id, **kwargs):
                    message.formatted.update(deepcopy(_graph_state))
                    yield message
                return
            # Feed the renumbered reference text back to the LLM for the next turn.
            message = AgentMessage(
                sender="ActionExecutor",
                content=reference,
                formatted=deepcopy(_graph_state),
                stream_state=message.stream_state + 1,  # plugin or code return
            )
            yield message
class AsyncMindSearchAgent(AsyncStreamingAgentForInternLM):
    """Async variant of MindSearchAgent: LLM turns are consumed with ``async for``.

    Each turn streams the response of ``self.agent``. When that response is a
    tool call, its content is run by ``ExecutionAction``. The resulting graph
    nodes are turned into a renumbered reference text, which is sent back to
    the LLM as an ``ActionExecutor`` message. Once ``finish_condition`` holds,
    ``summary_prompt`` is sent instead and the LLM's reply is streamed as the
    final answer.
    """

    def __init__(
        self,
        searcher_cfg: dict,
        summary_prompt: str,
        finish_condition=lambda m: "add_response_node" in m.content,
        max_turn: int = 10,
        **kwargs,
    ):
        """Configure the agent and switch WebSearchGraph into async mode.

        Args:
            searcher_cfg: Stored on ``WebSearchGraph.SEARCHER_CONFIG``.
                NOTE(review): this is class-level state shared by every agent
                in the process; the most recently constructed agent wins.
            summary_prompt: Content sent once ``finish_condition`` holds, to
                elicit the final answer.
            finish_condition: Predicate on the LLM's last message. By default
                it holds once the content mentions ``add_response_node``.
            max_turn: Maximum number of LLM / execution rounds.
            **kwargs: Forwarded to ``AsyncStreamingAgentForInternLM``.
        """
        WebSearchGraph.SEARCHER_CONFIG = searcher_cfg
        # NOTE(review): these class-level settings also affect any
        # MindSearchAgent in the same process. start_loop() runs on every
        # construction; confirm it is safe to call repeatedly.
        WebSearchGraph.is_async = True
        WebSearchGraph.start_loop()
        super().__init__(finish_condition=finish_condition, max_turn=max_turn, **kwargs)
        self.summary_prompt = summary_prompt
        self.action = ExecutionAction()

    @staticmethod
    def _is_tool_call(message: AgentMessage) -> bool:
        """Return True when the parsed LLM output carries a truthy ``tool_type``."""
        return isinstance(message.formatted, dict) and bool(message.formatted.get("tool_type"))

    @staticmethod
    def _tool_stream_state(message: AgentMessage, last_agent_state):
        """Return the ``stream_state`` for a streamed chunk that is part of a tool call."""
        if message.stream_state == ModelStatusCode.END:
            # The model finished this response. Step CODING / PLUGIN_START to
            # the following status code; any other state is kept as-is.
            return last_agent_state + int(
                last_agent_state
                in [
                    AgentStatusCode.CODING,
                    AgentStatusCode.PLUGIN_START,
                ]
            )
        if message.formatted["tool_type"] == "plugin":
            return AgentStatusCode.PLUGIN_START
        return AgentStatusCode.CODING

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Asynchronously stream the whole search-and-answer interaction for one query.

        Args:
            message: The user query, as an ``AgentMessage`` or a plain string
                (wrapped into a ``"user"`` message).
            session_id: Forwarded to ``self.agent``.
            **kwargs: Forwarded to ``self.agent``.

        Yields:
            ``AgentMessage`` objects:
            - LLM chunks, with ``stream_state`` set and the current graph
              state merged into ``formatted``.
            - Messages produced while executing the tool call.
            - The ``ActionExecutor`` messages fed back to the LLM.

            Generation stops after a plain (non-tool) answer or after the
            summary reply. NOTE(review): it also stops silently, without an
            END message, once ``max_turn`` turns are used up.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        _graph_state = dict(node={}, adjacency_list={}, ref2url={})
        # SECURITY NOTE(review): LLM-generated content is executed by
        # ExecutionAction with this module's globals(), so it can read and
        # rebind module-level names. Confirm ExecutionAction isolates it.
        local_dict, global_dict = {}, globals()
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            # `message` is deliberately rebound by this loop. Afterwards it
            # holds the last message the LLM produced for this turn.
            async for message in self.agent(message, session_id=session_id, **kwargs):
                if self._is_tool_call(message):
                    message.stream_state = self._tool_stream_state(message, last_agent_state)
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                message.formatted.update(deepcopy(_graph_state))
                yield message
                last_agent_state = message.stream_state
            # Use the same tool-call predicate as for streamed chunks. Indexing
            # `formatted["tool_type"]` directly could raise KeyError/TypeError.
            if not self._is_tool_call(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            gen = GeneratorWithReturn(
                self.action.run(message.content, local_dict, global_dict, True)
            )
            # NOTE(review): this is a synchronous iteration inside an async
            # generator. If execution blocks, the event loop blocks with it.
            for graph_exec in gen:
                graph_exec.formatted["ref2url"] = deepcopy(_graph_state["ref2url"])
                yield graph_exec
            # gen.ret[1] holds the graph nodes and gen.ret[2] the adjacency list.
            reference, references_url = _generate_references_from_graph(gen.ret[1])
            _graph_state.update(node=gen.ret[1], adjacency_list=gen.ret[2], ref2url=references_url)
            if self.finish_condition(message):
                message = AgentMessage(
                    sender="ActionExecutor",
                    content=self.summary_prompt,
                    formatted=deepcopy(_graph_state),
                    stream_state=message.stream_state + 1,  # plugin or code return
                )
                yield message
                # summarize the references to generate the final answer
                async for message in self.agent(message, session_id=session_id, **kwargs):
                    message.formatted.update(deepcopy(_graph_state))
                    yield message
                return
            # Feed the renumbered reference text back to the LLM for the next turn.
            message = AgentMessage(
                sender="ActionExecutor",
                content=reference,
                formatted=deepcopy(_graph_state),
                stream_state=message.stream_state + 1,  # plugin or code return
            )
            yield message