# NOTE: removed web-viewer scrape residue (page status text, commit hashes,
# and a line-number gutter) that was not part of this module and is not
# valid Python.
from typing import (
Annotated,
Sequence,
TypedDict,
)
from langchain_core.messages import ToolMessage, AnyMessage, RemoveMessage
from langgraph.graph.message import add_messages
import json
from .prompt import conversation_prompt
from src.config.llm import model
class State(TypedDict):
    """The state of the agent.

    Shared state carried between LangGraph nodes. All keys except
    ``messages`` are replaced wholesale on update; ``messages`` is merged
    through the ``add_messages`` reducer (append/update by message id,
    delete via ``RemoveMessage``).
    """

    # Lesson unit identifier — presumably a unit name/number; confirm
    # against whatever populates the initial state.
    unit: str
    # Lesson content for the unit; element types are not visible here.
    vocabulary: list
    key_structures: list
    practice_questions: list
    # NOTE(review): named "level" but typed list — confirm intended type.
    student_level: list
    # Conversation history, merged by the add_messages reducer.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # NOTE(review): trim_history reads and writes state["active_agent"],
    # which is not declared here — consider adding `active_agent: str`.
# Registry of tools the model may call. Currently empty; tool_node resolves
# tool calls against tools_by_name, so new tools only need to be added here.
tools = []
tools_by_name = {t.name: t for t in tools}
def trim_history(state: State):
    """Default the active agent and prune old messages from a long history.

    When the history exceeds ``MAX_HISTORY`` messages, returns
    ``RemoveMessage`` markers for all but the last ``KEEP_LAST`` messages;
    the ``add_messages`` reducer turns those markers into deletions.

    Fix: the original mutated the incoming state dict in place (wrote
    ``active_agent`` and reassigned ``messages``). Graph nodes should treat
    state as read-only, so this builds and returns a shallow copy instead;
    the returned mapping is identical to what the original returned.
    """
    MAX_HISTORY = 25  # prune only once the history grows past this
    KEEP_LAST = 5     # number of most-recent messages to retain

    updates = dict(state)
    # Preserve original semantics: any falsy value (missing, None, "")
    # is replaced with the default agent name.
    # NOTE(review): "active_agent" is not declared on the State TypedDict —
    # consider adding it there.
    if not updates.get("active_agent"):
        updates["active_agent"] = "Roleplay Agent"

    history = state.get("messages", [])
    if len(history) > MAX_HISTORY:
        # RemoveMessage entries instruct add_messages to delete by id.
        updates["messages"] = [
            RemoveMessage(id=msg.id) for msg in history[:-KEEP_LAST]
        ]
    return updates
# Define our tool node
def tool_node(state: State):
    """Execute every tool call on the latest message.

    Returns the results wrapped as ToolMessages (one per call, in order),
    under the "messages" key for the add_messages reducer.
    """

    def _execute(call):
        # Resolve the requested tool by name and invoke it with the call's args.
        result = tools_by_name[call["name"]].invoke(call["args"])
        return ToolMessage(
            content=json.dumps(result),
            name=call["name"],
            tool_call_id=call["id"],
        )

    latest = state["messages"][-1]
    return {"messages": [_execute(call) for call in latest.tool_calls]}
async def agent(state: State):
    """Produce the assistant's next reply from the current graph state."""
    # Pipe the conversation prompt into the model and await the async result.
    reply = await (conversation_prompt | model).ainvoke(state)
    return {"messages": reply}