from typing import TypedDict

from langchain_core.messages import AnyMessage, RemoveMessage
from langgraph.graph import add_messages
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool
from loguru import logger
# NOTE: this re-binds the ``TypedDict`` name imported from ``typing`` above, so
# the typing_extensions version is the one ``State`` is built from.
from typing_extensions import Annotated, TypedDict

from src.config.llm import model

from .prompt import practice_agent_prompt, teaching_agent_prompt

# Agent names are used both as the ``name`` of each agent and as the values
# stored in ``State.active_agent``; keep them in one place.
_TEACHING_AGENT = "Teaching Agent"
_PRACTICE_AGENT = "Practice Agent"
_AGENT_NAMES = (_PRACTICE_AGENT, _TEACHING_AGENT)

# Agent selected when ``active_agent`` is missing or empty.
_DEFAULT_AGENT = _TEACHING_AGENT

# History is trimmed once it grows beyond _MAX_HISTORY_MESSAGES; after
# trimming only the newest _KEPT_HISTORY_MESSAGES remain.
_MAX_HISTORY_MESSAGES = 25
_KEPT_HISTORY_MESSAGES = 5


class State(TypedDict):
    """Graph state passed between the history trimmer, router and agents."""

    # Name of the agent that should handle the next turn (None if unset).
    active_agent: str | None
    # Conversation history; updates are merged by the ``add_messages`` reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    # The fields below are interpolated verbatim into the agent prompts.
    unit: str
    vocabulary: list
    key_structures: list
    practice_questions: list
    student_level: str


def trim_history(state: State) -> State:
    """Default the active agent and schedule old messages for removal.

    When ``active_agent`` is missing or falsy it is set to "Teaching Agent".
    When the history holds more than 25 messages, ``messages`` is replaced by
    one ``RemoveMessage`` per message except the 5 most recent. These markers
    are meant to be consumed by the ``add_messages`` reducer declared on
    ``State.messages`` (presumably deleting the referenced messages; see the
    LangGraph documentation).

    The input mapping is left untouched; an updated shallow copy is returned.

    Args:
        state: Current graph state.

    Returns:
        A copy of ``state`` with ``active_agent`` and, if trimming was
        triggered, ``messages`` updated.
    """
    updated = dict(state)

    if not updated.get("active_agent"):
        updated["active_agent"] = _DEFAULT_AGENT

    # ``or []`` also covers a state whose ``messages`` value is None.
    history = updated.get("messages") or []
    if len(history) > _MAX_HISTORY_MESSAGES:
        stale = history[: len(history) - _KEPT_HISTORY_MESSAGES]
        updated["messages"] = [RemoveMessage(id=message.id) for message in stale]

    return updated


async def _run_agent(
    state: State,
    *,
    name: str,
    prompt_template,
    handoff_to: str,
    handoff_description: str,
) -> dict:
    """Build a ReAct agent with a single handoff tool and run it on the history.

    Args:
        state: Current graph state; supplies the messages and prompt fields.
        name: Name given to the created agent.
        prompt_template: Object whose ``format(**fields)`` yields the prompt.
        handoff_to: Name of the agent the handoff tool transfers to.
        handoff_description: Tool description shown to the model.

    Returns:
        ``{"messages": ...}`` holding the messages returned by the agent.
    """
    handoff_tool = create_handoff_tool(
        agent_name=handoff_to,
        description=handoff_description,
    )
    # The agent is rebuilt on each call because its prompt depends on the
    # lesson fields currently held in the state.
    agent = create_react_agent(
        model,
        [handoff_tool],
        prompt=prompt_template.format(
            unit=state["unit"],
            vocabulary=state["vocabulary"],
            key_structures=state["key_structures"],
            practice_questions=state["practice_questions"],
            student_level=state["student_level"],
        ),
        name=name,
    )
    response = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": response["messages"]}


async def call_practice_agent(state: State) -> dict:
    """Run the Practice Agent on the current conversation.

    The agent is given one tool, a handoff to the Teaching Agent.

    Args:
        state: Current graph state.

    Returns:
        ``{"messages": ...}`` with the messages produced by the agent run.
    """
    logger.info("Calling practice agent...")
    return await _run_agent(
        state,
        name=_PRACTICE_AGENT,
        prompt_template=practice_agent_prompt,
        handoff_to=_TEACHING_AGENT,
        handoff_description=(
            "Hand off to Teaching Agent when user asks for grammar "
            "explanations, Vietnamese help, makes repeated fundamental "
            "errors, or needs more structured learning support"
        ),
    )


async def call_teaching_agent(state: State) -> dict:
    """Run the Teaching Agent on the current conversation.

    The agent is given one tool, a handoff to the Practice Agent.

    Args:
        state: Current graph state.

    Returns:
        ``{"messages": ...}`` with the messages produced by the agent run.
    """
    logger.info("Calling teaching agent...")
    return await _run_agent(
        state,
        name=_TEACHING_AGENT,
        prompt_template=teaching_agent_prompt,
        handoff_to=_PRACTICE_AGENT,
        handoff_description=(
            "Hand off to Practice Agent when user demonstrates "
            "understanding, confidence, and is ready for natural English "
            "conversation practice"
        ),
    )


def route_to_active_agent(state: State) -> str:
    """Return the name of the agent that should handle the next step.

    Known agent names are returned unchanged. A missing, empty or unknown
    ``active_agent`` falls back to "Teaching Agent" (the same default that
    ``trim_history`` applies) instead of implicitly returning ``None``.

    Args:
        state: Current graph state.

    Returns:
        Either "Practice Agent" or "Teaching Agent".
    """
    active = state.get("active_agent")
    if active in _AGENT_NAMES:
        return active

    if active:
        # A non-empty but unknown name most likely indicates a wiring
        # mistake; surface it in the logs rather than failing silently.
        logger.warning(
            "Unknown active_agent {!r}; routing to {}", active, _DEFAULT_AGENT
        )
    return _DEFAULT_AGENT