import json
from typing import (
    Annotated,
    Sequence,
    TypedDict,
)

from langchain_core.messages import ToolMessage, AnyMessage, RemoveMessage
from langgraph.graph.message import add_messages

from .prompt import conversation_prompt
from src.config.llm import model

# Once the history holds more than this many messages, it gets trimmed ...
_HISTORY_TRIM_THRESHOLD = 25
# ... down to only this many of the most recent messages.
_HISTORY_KEEP_LAST = 5
# Agent name assigned by ``trim_history`` when none is set yet.
_DEFAULT_ACTIVE_AGENT = "Roleplay Agent"


class State(TypedDict):
    """The state of the agent."""

    unit: str
    vocabulary: list
    key_structures: list
    practice_questions: list
    student_level: list
    # Name of the agent handling the conversation. Declared here so that the
    # default written by ``trim_history`` is a real state key (LangGraph only
    # creates channels for keys declared on the state schema).
    active_agent: str
    # ``add_messages`` merges updates by message id and applies
    # ``RemoveMessage`` markers, so nodes return deltas, not full histories.
    messages: Annotated[Sequence[AnyMessage], add_messages]


tools = []
tools_by_name = {tool.name: tool for tool in tools}


def trim_history(state: State):
    """Default the active agent and prune old messages from the history.

    If ``active_agent`` is missing or falsy, it is set to
    ``_DEFAULT_ACTIVE_AGENT``. If the history holds more than
    ``_HISTORY_TRIM_THRESHOLD`` messages, every message except the last
    ``_HISTORY_KEEP_LAST`` is scheduled for deletion. The deletion is expressed
    as ``RemoveMessage`` markers under ``messages``, which the ``add_messages``
    reducer applies.

    Args:
        state: Current graph state. It is not mutated.

    Returns:
        A new dict with all keys of ``state``, plus ``active_agent``, and, when
        trimming happens, the removal markers in place of ``messages``.
    """
    updated = dict(state)
    if not updated.get("active_agent"):
        updated["active_agent"] = _DEFAULT_ACTIVE_AGENT

    history = updated.get("messages", [])
    if len(history) > _HISTORY_TRIM_THRESHOLD:
        # NOTE(review): the kept tail may begin with a ToolMessage whose
        # originating AI tool call was removed; some chat APIs reject that.
        stale = history[: len(history) - _HISTORY_KEEP_LAST]
        updated["messages"] = [RemoveMessage(id=message.id) for message in stale]
    return updated


def _serialize_tool_result(result) -> str:
    """Render a tool result as ToolMessage content.

    Strings pass through unchanged, which avoids wrapping them in JSON quotes.
    Anything else is JSON-encoded with ``ensure_ascii=False`` so that non-ASCII
    text stays readable instead of being escaped as ``\\uXXXX`` sequences.
    """
    if isinstance(result, str):
        return result
    return json.dumps(result, ensure_ascii=False)


# Define our tool node
def tool_node(state: State):
    """Execute every tool call requested by the most recent message.

    Args:
        state: Graph state; only the last entry of ``messages`` is read.

    Returns:
        ``{"messages": [...]}`` with one ``ToolMessage`` per tool call, in call
        order. A call naming an unregistered tool gets an error
        ``ToolMessage`` rather than raising, so every call still receives a
        reply. If the last message carries no tool calls, the list is empty.
    """
    last_message = state["messages"][-1]
    outputs = []
    for tool_call in getattr(last_message, "tool_calls", None) or []:
        name = tool_call["name"]
        tool = tools_by_name.get(name)
        if tool is None:
            content = f"Error: unknown tool {name!r}."
        else:
            content = _serialize_tool_result(tool.invoke(tool_call["args"]))
        outputs.append(
            ToolMessage(
                content=content,
                name=name,
                tool_call_id=tool_call["id"],
            )
        )
    return {"messages": outputs}


async def agent(state: State):
    """Run the conversation prompt through the model and append its reply.

    Args:
        state: Graph state, passed as-is as the prompt's input variables.

    Returns:
        ``{"messages": response}`` where ``response`` is the model output;
        ``add_messages`` appends it to the history.
    """
    chain = conversation_prompt | model
    response = await chain.ainvoke(state)
    return {"messages": response}