"""Build and exercise a LangGraph agent that answers time questions via an MCP tool server."""
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_core.tools import BaseTool | |
| from langchain_mcp_adapters.client import MultiServerMCPClient | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.graph import START, StateGraph | |
| from langgraph.graph.message import MessagesState | |
| from langgraph.graph.state import CompiledStateGraph | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from chattr import ( | |
| ASSETS_DIR, | |
| MODEL_API_KEY, | |
| MODEL_NAME, | |
| MODEL_TEMPERATURE, | |
| MODEL_URL, | |
| ) | |
# System prompt prepended to the conversation on every model invocation
# (see call_model inside create_graph).
SYSTEM_MESSAGE: SystemMessage = SystemMessage(
    content="You are a helpful assistant that can answer questions about the time."
)
async def create_graph() -> CompiledStateGraph:
    """
    Asynchronously create and compile a conversational state graph for a
    time-answering assistant with integrated external MCP tools.

    The graph has two nodes: "agent", which invokes the chat model with the
    system message prepended, and "tools", which executes any tool calls the
    model emits; control loops between them until the model stops calling
    tools.

    Returns:
        CompiledStateGraph: The compiled state graph ready for execution.

    Raises:
        RuntimeError: If the chat model cannot be initialized or the MCP
            tools cannot be bound to it.
    """
    # The MCP "time" server runs as a Docker container speaking stdio.
    _mcp_client = MultiServerMCPClient(
        {
            "time": {
                "command": "docker",
                "args": ["run", "-i", "--rm", "mcp/time"],
                "transport": "stdio",
            }
        }
    )
    _tools: list[BaseTool] = await _mcp_client.get_tools()
    try:
        _base_model: ChatOpenAI = ChatOpenAI(
            base_url=MODEL_URL,
            model=MODEL_NAME,
            api_key=MODEL_API_KEY,
            temperature=MODEL_TEMPERATURE,
        )
        # bind_tools returns a Runnable wrapper, not a ChatOpenAI instance,
        # so keep it in its own name rather than rebinding a ChatOpenAI-typed
        # variable (the original annotated the rebound value as ChatOpenAI,
        # which was incorrect).
        _model = _base_model.bind_tools(_tools, parallel_tool_calls=False)
    except Exception as e:
        raise RuntimeError(
            f"Failed to initialize ChatOpenAI model: {e}"
        ) from e

    def call_model(state: MessagesState) -> MessagesState:
        """
        Generate a new message state by invoking the chat model with the
        system message prepended to the current messages.

        Parameters:
            state (MessagesState): The current state containing a list of
                messages.

        Returns:
            MessagesState: A state update with the model's response appended
            to the messages.
        """
        return {
            "messages": [_model.invoke([SYSTEM_MESSAGE] + state["messages"])]
        }

    _builder: StateGraph = StateGraph(MessagesState)
    _builder.add_node("agent", call_model)
    _builder.add_node("tools", ToolNode(_tools))
    _builder.add_edge(START, "agent")
    # Route to "tools" when the last message contains tool calls, else END.
    _builder.add_conditional_edges("agent", tools_condition)
    _builder.add_edge("tools", "agent")
    graph: CompiledStateGraph = _builder.compile()
    return graph
def draw_graph(graph: CompiledStateGraph) -> None:
    """
    Save a Mermaid-rendered PNG visualization of *graph* into the assets
    directory.
    """
    output_path = ASSETS_DIR / "graph.png"
    graph.get_graph().draw_mermaid_png(output_file_path=output_path)
| if __name__ == "__main__": | |
| import asyncio | |
| async def test_graph(): | |
| """ | |
| Asynchronously creates and tests the conversational state graph by sending a time-related query and printing the resulting messages. | |
| """ | |
| g: CompiledStateGraph = await create_graph() | |
| messages = await g.ainvoke( | |
| {"messages": [HumanMessage(content="What is the time?")]} | |
| ) | |
| for m in messages["messages"]: | |
| m.pretty_print() | |
| asyncio.run(test_graph()) | |