ABAO77's picture
Add trim_history function and update lesson chat route to handle text and audio inputs
a4cb278
raw
history blame
1.57 kB
from typing import (
Annotated,
Sequence,
TypedDict,
)
from langchain_core.messages import ToolMessage, AnyMessage, RemoveMessage
from langgraph.graph.message import add_messages
import json
from .prompt import conversation_prompt
from src.config.llm import model
class State(TypedDict):
    """The state of the agent.

    Holds the lesson context the conversation prompt is rendered from, plus the
    running message history.
    """

    # Lesson unit currently being practised.
    unit: str
    # Lesson material lists; element types are not visible here — TODO confirm.
    vocabulary: list
    key_structures: list
    practice_questions: list
    # NOTE(review): declared as a list; confirm a single level string was not intended.
    student_level: list
    # Conversation history; updates are merged through the add_messages reducer.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Name of the agent handling the conversation. trim_history sets it to
    # "Roleplay Agent" when missing; declaring it here makes it part of the
    # schema so that value is kept instead of being outside the state.
    active_agent: str
# Tools the agent may call. None are registered at the moment.
tools = []
# Index of the registered tools keyed by their ``name`` attribute; tool_node
# resolves incoming tool calls through this mapping.
tools_by_name = {registered_tool.name: registered_tool for registered_tool in tools}
def trim_history(state: State, *, max_messages: int = 25, keep_last: int = 5):
    """Ensure an active agent is set and schedule old messages for removal.

    If ``state`` has no truthy ``"active_agent"``, it is set to ``"Roleplay Agent"``.
    When the history in ``state["messages"]`` is longer than ``max_messages``,
    ``state["messages"]`` is replaced by ``RemoveMessage`` markers for every
    message except the newest ``keep_last``. These markers are meant to be
    applied as deletions by the ``add_messages`` reducer on ``State.messages``.
    Otherwise ``state["messages"]`` is left untouched.

    Args:
        state: The graph state. It is mutated in place.
        max_messages: History length above which trimming happens (default 25).
        keep_last: How many of the most recent messages survive a trim
            (default 5). Values outside ``[0, len(history)]`` are clamped.

    Returns:
        The same ``state`` object after mutation.
    """
    if not state.get("active_agent"):
        state["active_agent"] = "Roleplay Agent"
    history = state.get("messages", [])
    if len(history) > max_messages:
        # Clamp so an oversized or negative keep_last cannot yield a bogus count.
        kept = max(0, min(keep_last, len(history)))
        num_to_remove = len(history) - kept
        # Oldest messages come first; only their ids are needed to delete them.
        state["messages"] = [
            RemoveMessage(id=message.id) for message in history[:num_to_remove]
        ]
    return state
# Tool-execution node: runs the tools requested by the latest message.
def tool_node(state: State):
    """Run each tool call attached to the last message and wrap the results.

    Every entry of ``state["messages"][-1].tool_calls`` is resolved through
    ``tools_by_name`` and invoked with its ``args``. The result is
    JSON-encoded into a ``ToolMessage`` that carries the call's name and id.

    Returns:
        ``{"messages": [...]}`` holding one ``ToolMessage`` per tool call, in call order.

    Raises:
        KeyError: if a call names a tool that is not in ``tools_by_name``.
        TypeError: if a tool result is not JSON-serializable (raised by ``json.dumps``).
    """
    latest = state["messages"][-1]
    results = [
        ToolMessage(
            content=json.dumps(tools_by_name[call["name"]].invoke(call["args"])),
            name=call["name"],
            tool_call_id=call["id"],
        )
        for call in latest.tool_calls
    ]
    return {"messages": results}
async def agent(state: State):
    """Generate the assistant's next reply.

    The whole ``state`` is passed to the ``conversation_prompt | model``
    chain, which is invoked asynchronously.

    Returns:
        ``{"messages": reply}``. Under that key the reply is merged into
        ``State.messages`` by its ``add_messages`` reducer.
    """
    reply = await (conversation_prompt | model).ainvoke(state)
    return {"messages": reply}