Spaces:
Sleeping
Sleeping
| from typing import ( | |
| Annotated, | |
| Sequence, | |
| TypedDict, | |
| ) | |
| from langchain_core.messages import ToolMessage, AnyMessage, RemoveMessage | |
| from langgraph.graph.message import add_messages | |
| import json | |
| from .prompt import conversation_prompt | |
| from src.config.llm import model | |
class State(TypedDict):
    """The state of the agent.

    Only keys declared here become LangGraph state channels; values written
    under undeclared keys are not carried between nodes.
    """

    # NOTE(review): presumably the current lesson unit — confirm with callers.
    unit: str
    vocabulary: list
    key_structures: list
    practice_questions: list
    # NOTE(review): typed as a list in the original; confirm it is not a str.
    student_level: list
    # Conversation history; the add_messages reducer merges updates by message
    # id and applies RemoveMessage markers (see trim_history).
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Name of the agent handling the conversation. trim_history defaults it to
    # "Roleplay Agent"; declared here so that value is actually kept in state.
    active_agent: str
# Tools exposed to the agent; none are registered yet.
tools: list = []

# Name -> tool lookup used by tool_node to dispatch each requested call.
tools_by_name = {registered.name: registered for registered in tools}
# Once the history grows beyond this many messages, it is trimmed.
MAX_HISTORY_MESSAGES = 25
# Number of most recent messages retained after a trim.
KEEP_LAST_MESSAGES = 5


def trim_history(state: State):
    """Default the active agent and trim an overlong message history.

    When ``state["messages"]`` holds more than ``MAX_HISTORY_MESSAGES``
    messages, all but (roughly) the last ``KEEP_LAST_MESSAGES`` are scheduled
    for deletion by emitting ``RemoveMessage`` markers, which the
    ``add_messages`` reducer on ``State.messages`` applies.  The cut point is
    moved back past any ``ToolMessage`` so a retained tool result is never
    separated from the AI message that issued its tool call.

    Args:
        state: The current graph state. It is not mutated.

    Returns:
        A shallow copy of ``state`` in which ``"active_agent"`` defaults to
        ``"Roleplay Agent"`` and, when trimming happened, ``"messages"`` is
        replaced by the list of ``RemoveMessage`` markers.
    """
    updated = dict(state)  # work on a copy instead of mutating the input
    if not updated.get("active_agent"):
        updated["active_agent"] = "Roleplay Agent"

    history = updated.get("messages", [])
    if len(history) > MAX_HISTORY_MESSAGES:
        cut = len(history) - KEEP_LAST_MESSAGES
        # A ToolMessage must follow the AI message carrying its tool call;
        # keep that parent too rather than leaving an orphaned tool result.
        while cut > 0 and isinstance(history[cut], ToolMessage):
            cut -= 1
        updated["messages"] = [RemoveMessage(id=msg.id) for msg in history[:cut]]
    return updated
def tool_node(state: State):
    """Graph node that executes the tool calls requested by the last message.

    Each entry of ``state["messages"][-1].tool_calls`` is dispatched to the
    matching tool in ``tools_by_name``.  Its result is JSON-encoded into a
    ``ToolMessage`` that echoes the call's name and id.  A call naming an
    unregistered tool yields an error ``ToolMessage`` instead of raising
    ``KeyError``, so the model can see the mistake and recover.

    Returns:
        ``{"messages": [...]}`` with one ``ToolMessage`` per tool call; the
        list is empty when there is no last message or it has no tool calls.
    """
    messages = state.get("messages") or []
    last_message = messages[-1] if messages else None
    tool_calls = getattr(last_message, "tool_calls", None) or []

    outputs = []
    for tool_call in tool_calls:
        tool = tools_by_name.get(tool_call["name"])
        if tool is None:
            available = ", ".join(tools_by_name)
            content = (
                f"Error: {tool_call['name']} is not a valid tool, "
                f"try one of [{available}]."
            )
        else:
            content = json.dumps(tool.invoke(tool_call["args"]))
        outputs.append(
            ToolMessage(
                content=content,
                name=tool_call["name"],
                tool_call_id=tool_call["id"],
            )
        )
    return {"messages": outputs}
async def agent(state: State):
    """Graph node that produces the next model reply.

    Pipes ``conversation_prompt`` into ``model``, feeds it the whole state,
    and returns the resulting message as a ``messages`` update.
    """
    chain = conversation_prompt | model
    ai_message = await chain.ainvoke(state)
    return dict(messages=ai_message)