File size: 1,662 Bytes
20c3a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce33d09
 
 
20c3a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a412ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from langgraph.graph import StateGraph, START, END
from .func import State, trim_history, call_practice_agent, call_teaching_agent
from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import InMemorySaver


# Sentinel meaning "no checkpointer argument was given". It lets __call__ build a
# fresh InMemorySaver per call while still forwarding an explicit ``None``.
_DEFAULT_CHECKPOINTER = object()


class LessonPractice2Agent:
    """Builder for the lesson-practice graph.

    The graph runs ``trim_history`` first. It then routes to the agent node
    named by ``state["active_agent"]``: the Practice Agent or the Teaching
    Agent, falling back to the Teaching Agent. Calling an instance builds and
    compiles the graph.
    """

    # Node names. These are also the values the router compares
    # ``state["active_agent"]`` against.
    PRACTICE_AGENT = "Practice Agent"
    TEACHING_AGENT = "Teaching Agent"

    @staticmethod
    def route_to_active_agent(state: State) -> str:
        """Return the name of the agent node to run after ``trim_history``.

        Returns the Practice Agent only when ``state["active_agent"]`` equals
        that node name. In every other case (key missing, ``None``, or any
        other value) it returns the Teaching Agent.
        """
        # Use ``.get`` rather than ``[]``: if ``active_agent`` has not been set
        # in state yet, the documented fallback applies instead of a KeyError.
        # NOTE(review): assumes State is dict-like (e.g. a TypedDict), as the
        # original subscript access implies.
        active = state.get("active_agent")
        if active == LessonPractice2Agent.PRACTICE_AGENT:
            return LessonPractice2Agent.PRACTICE_AGENT
        return LessonPractice2Agent.TEACHING_AGENT

    def node(self, graph: StateGraph) -> StateGraph:
        """Register the history-trimming node and both agent nodes on ``graph``.

        Each agent node declares the other agent as a possible destination,
        so the two agents can hand control to one another.
        """
        graph.add_node("trim_history", trim_history)
        graph.add_node(
            self.PRACTICE_AGENT,
            call_practice_agent,
            destinations=(self.TEACHING_AGENT,),
        )
        graph.add_node(
            self.TEACHING_AGENT,
            call_teaching_agent,
            destinations=(self.PRACTICE_AGENT,),
        )
        return graph

    def edge(self, graph: StateGraph) -> StateGraph:
        """Wire START -> ``trim_history``, then route conditionally to an agent."""
        graph.add_edge(START, "trim_history")
        graph.add_conditional_edges(
            "trim_history",
            self.route_to_active_agent,
            {
                self.PRACTICE_AGENT: self.PRACTICE_AGENT,
                self.TEACHING_AGENT: self.TEACHING_AGENT,
            },
        )
        return graph

    def __call__(self, checkpointer=_DEFAULT_CHECKPOINTER) -> CompiledStateGraph:
        """Build and compile the lesson-practice graph.

        Args:
            checkpointer: Checkpoint saver passed to ``compile``. When omitted,
                a new ``InMemorySaver`` is created for this call. Any value
                passed explicitly, including ``None``, is forwarded to
                ``compile`` unchanged.

        Returns:
            The compiled graph.
        """
        if checkpointer is _DEFAULT_CHECKPOINTER:
            # Create the saver here, per call. A default of ``InMemorySaver()``
            # in the signature would be evaluated once, when the class is
            # defined. One saver would then be shared by every graph compiled
            # without an explicit checkpointer.
            checkpointer = InMemorySaver()
        graph = StateGraph(State)
        graph = self.node(graph)
        graph = self.edge(graph)
        return graph.compile(checkpointer=checkpointer)


# Module-level builder instance. Call it to obtain a compiled graph.
lesson_practice_agent = LessonPractice2Agent()