File size: 1,570 Bytes
c0a7f25
 
 
 
 
a4cb278
c0a7f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4cb278
 
 
 
 
 
 
 
 
 
 
 
 
c0a7f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import (
    Annotated,
    Sequence,
    TypedDict,
)
from langchain_core.messages import ToolMessage, AnyMessage, RemoveMessage
from langgraph.graph.message import add_messages
import json
from .prompt import conversation_prompt
from src.config.llm import model


class State(TypedDict):
    """The state of the agent.

    Dictionary schema received by the node functions in this module
    (``trim_history``, ``tool_node`` and ``agent``). The meaning of the
    lesson-related fields is not visible in this file, so the per-field
    notes below are assumptions to be confirmed.

    NOTE(review): ``trim_history`` also reads and writes an ``active_agent``
    key that is not declared here - confirm whether it belongs in the schema.
    """

    # presumably the lesson unit being practised - TODO confirm
    unit: str
    # presumably the vocabulary items for the unit - TODO confirm element type
    vocabulary: list
    # presumably the grammar/key structures for the unit - TODO confirm
    key_structures: list
    # presumably the practice questions for the unit - TODO confirm
    practice_questions: list
    # NOTE(review): typed as a list although the name suggests a single
    # value - confirm against the code that populates it
    student_level: list
    # Conversation history. The Annotated metadata attaches langgraph's
    # add_messages function to this key.
    messages: Annotated[Sequence[AnyMessage], add_messages]


# Tools made available to the agent; the list is currently empty.
tools = []

# Name -> tool lookup table; tool_node uses it to dispatch each tool call.
tools_by_name = {registered.name: registered for registered in tools}


def trim_history(state: State):
    """Return the state with old messages scheduled for removal.

    When the history holds more than 25 messages, the ``messages`` entry of
    the returned dict is replaced by one ``RemoveMessage`` per message to
    drop, covering everything except the 5 most recent messages. Shorter
    histories are returned unchanged.

    The function also fills in ``active_agent`` with ``"Roleplay Agent"``
    when that key is missing or falsy.
    NOTE(review): ``active_agent`` is not declared in ``State`` - confirm
    whether this default is still needed here.

    Args:
        state: Current agent state. It is not modified; a shallow copy is
            updated and returned instead.

    Returns:
        A new dict with the same keys as ``state`` plus ``active_agent``,
        and with ``messages`` replaced by removal markers when trimming
        applies.
    """
    # Work on a shallow copy so the caller's state dict is never mutated.
    updated = dict(state)
    if not updated.get("active_agent"):
        updated["active_agent"] = "Roleplay Agent"

    history = updated.get("messages", [])
    if len(history) > 25:
        # Everything except the last 5 messages is marked for deletion.
        updated["messages"] = [
            RemoveMessage(id=message.id) for message in history[:-5]
        ]
    return updated


# Tool-execution node
def tool_node(state: State):
    """Run every tool call requested by the most recent message.

    Each entry in the last message's ``tool_calls`` is looked up in
    ``tools_by_name`` by its ``"name"``, invoked with its ``"args"``, and the
    JSON-encoded result is wrapped in a ``ToolMessage`` carrying the call's
    name and id.

    Args:
        state: Agent state; only ``state["messages"][-1]`` is read.

    Returns:
        ``{"messages": [...]}`` with one ``ToolMessage`` per tool call, in
        the order the calls appear.

    Raises:
        KeyError: If a tool call names a tool missing from ``tools_by_name``.
    """
    last_message = state["messages"][-1]
    return {
        "messages": [
            ToolMessage(
                content=json.dumps(
                    tools_by_name[call["name"]].invoke(call["args"])
                ),
                name=call["name"],
                tool_call_id=call["id"],
            )
            for call in last_message.tool_calls
        ]
    }


async def agent(state: State):
    """Produce the next response from the prompt-plus-model pipeline.

    Pipes ``conversation_prompt`` into ``model``, awaits the pipeline's
    ``ainvoke`` with the whole state as input, and returns the result under
    the ``"messages"`` key.

    Args:
        state: Current agent state, passed unchanged to the pipeline.

    Returns:
        ``{"messages": response}`` where ``response`` is whatever the
        pipeline's ``ainvoke`` returns.
    """
    chain = conversation_prompt | model
    return {"messages": await chain.ainvoke(state)}