File size: 3,068 Bytes
20c3a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce33d09
20c3a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce33d09
20c3a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce33d09
20c3a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import TypedDict
from src.config.llm import model
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool
from langchain_core.messages import RemoveMessage
from .prompt import practice_agent_prompt, teaching_agent_prompt
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from loguru import logger


class State(TypedDict):
    """Shared state dictionary consumed by the agent and routing functions in this module."""
    # Name of the agent that should handle the conversation. The routing
    # function compares it against "Practice Agent" / "Teaching Agent";
    # it may be None (trim_history substitutes "Teaching Agent" when unset).
    active_agent: str | None
    # Conversation history. Annotated with add_messages so that updates to
    # this key are combined with the existing list through that reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    # The remaining fields are interpolated into both agents' prompt
    # templates via str.format when an agent is built.
    # NOTE(review): element types of the list fields are not visible here.
    unit: str
    vocabulary: list
    key_structures: list
    practice_questions: list
    student_level: str


def trim_history(state: State, *, max_messages: int = 25, keep_last: int = 5):
    """Default the active agent and prune an over-long message history.

    The incoming ``state`` mapping is left untouched; all changes are made on
    a shallow copy, which is what gets returned.

    Args:
        state: Current graph state.
        max_messages: Pruning only happens when the history holds more than
            this many messages.
        keep_last: Number of most recent messages to retain when pruning.

    Returns:
        A shallow copy of ``state`` in which ``active_agent`` is set to
        ``"Teaching Agent"`` if it was missing or falsy. When pruning
        applies, ``messages`` is replaced by one ``RemoveMessage`` per
        dropped message; these are removal markers intended for the
        ``add_messages`` reducer declared on ``State.messages``.
    """
    # Work on a copy so the caller's mapping is never mutated as a side effect.
    updated = dict(state)
    if not updated.get("active_agent"):
        updated["active_agent"] = "Teaching Agent"

    history = updated.get("messages", [])
    if len(history) > max_messages:
        cut = max(len(history) - keep_last, 0)
        # Do not let the retained tail begin with a tool result whose
        # originating tool call is being removed: chat-model APIs commonly
        # reject a tool message that does not follow its tool call.
        while cut < len(history) and getattr(history[cut], "type", None) == "tool":
            cut += 1
        updated["messages"] = [
            RemoveMessage(id=message.id) for message in history[:cut]
        ]
    return updated


async def call_practice_agent(state: State):
    """Build the practice agent from the current state and run it on the history.

    Args:
        state: Graph state; the lesson fields are read to fill the practice
            prompt template and ``messages`` is passed to the agent.

    Returns:
        A dict with a single ``"messages"`` key holding the messages from
        the agent's response.
    """
    logger.info("Calling practice agent...")

    # The only tool offered to the model: pass control to the teaching agent.
    handoff_to_teaching = create_handoff_tool(
        agent_name="Teaching Agent",
        description="Hand off to Teaching Agent when user asks for grammar explanations, Vietnamese help, makes repeated fundamental errors, or needs more structured learning support",
    )

    # Fill the prompt template from the lesson-related state fields.
    prompt_fields = (
        "unit",
        "vocabulary",
        "key_structures",
        "practice_questions",
        "student_level",
    )
    system_prompt = practice_agent_prompt.format(
        **{field: state[field] for field in prompt_fields}
    )

    agent = create_react_agent(
        model,
        [handoff_to_teaching],
        prompt=system_prompt,
        name="Practice Agent",
    )
    result = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": result["messages"]}


async def call_teaching_agent(state: State):
    """Build the teaching agent from the current state and run it on the history.

    Args:
        state: Graph state; the lesson fields are read to fill the teaching
            prompt template and ``messages`` is passed to the agent.

    Returns:
        A dict with a single ``"messages"`` key holding the messages from
        the agent's response.
    """
    logger.info("Calling teaching agent...")

    # The only tool offered to the model: pass control to the practice agent.
    handoff_to_practice = create_handoff_tool(
        agent_name="Practice Agent",
        description="Hand off to Practice Agent when user demonstrates understanding, confidence, and is ready for natural English conversation practice",
    )

    # Fill the prompt template from the lesson-related state fields.
    prompt_fields = (
        "unit",
        "vocabulary",
        "key_structures",
        "practice_questions",
        "student_level",
    )
    system_prompt = teaching_agent_prompt.format(
        **{field: state[field] for field in prompt_fields}
    )

    agent = create_react_agent(
        model,
        [handoff_to_practice],
        prompt=system_prompt,
        name="Teaching Agent",
    )
    result = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": result["messages"]}


def route_to_active_agent(state: State) -> str:
    """Return the name of the agent that should handle the current state.

    Args:
        state: Graph state; only ``active_agent`` is read.

    Returns:
        ``"Practice Agent"`` or ``"Teaching Agent"`` when ``active_agent``
        holds one of those names. If the key is missing, ``None``, or any
        other value, ``"Teaching Agent"`` is returned so that the function
        always yields a string, as its annotation promises.
    """
    active = state.get("active_agent")
    if active in ("Practice Agent", "Teaching Agent"):
        return active
    # Same default that trim_history applies when no agent is active yet.
    return "Teaching Agent"