File size: 1,300 Bytes
c0a7f25
a4cb278
c0a7f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4cb278
c0a7f25
 
 
 
 
a4cb278
 
c0a7f25
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from langgraph.graph import StateGraph, START, END
from .func import State, trim_history, agent, tool_node
from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import InMemorySaver


class LessonPracticeAgent:
    """Factory that builds the lesson-practice LangGraph agent.

    Calling an instance assembles a ``StateGraph(State)`` with three nodes
    (``trim_history``, ``agent`` and ``tools``) and compiles it with a
    checkpointer. Control flow:

        START -> trim_history -> agent
        agent -> tools   (last message requests tool calls)
        agent -> END     (otherwise)
        tools -> agent   (tool output is fed back to the model)
    """

    # Sentinel meaning "caller did not pass a checkpointer". ``None`` cannot
    # play this role because callers may legitimately pass ``None`` to
    # ``compile`` themselves, and that value must be forwarded unchanged.
    _USE_DEFAULT = object()

    # A single saver shared by every graph compiled without an explicit
    # checkpointer. It is created once, at class definition, so threads
    # stored by one compiled graph stay visible to graphs compiled later
    # from the default. The sharing is deliberate and visible here rather
    # than hidden in a default argument.
    _default_checkpointer = InMemorySaver()

    @staticmethod
    def should_continue(state: State) -> str:
        """Route after the ``agent`` node.

        Args:
            state: Graph state; ``state["messages"]`` is read.

        Returns:
            ``"continue"`` when the newest message carries tool calls (so
            ``tools`` runs next), otherwise ``"end"``. An empty history, or
            a newest message with no ``tool_calls`` attribute, counts as
            "no tool calls".
        """
        messages = state["messages"]
        if not messages:
            return "end"
        # getattr: presumably only AI messages define ``tool_calls``; any
        # other message type must not crash the router.
        if getattr(messages[-1], "tool_calls", None):
            return "continue"
        return "end"

    def node(self, graph: StateGraph) -> StateGraph:
        """Register the ``trim_history``, ``agent`` and ``tools`` nodes.

        Args:
            graph: Graph to add the nodes to (mutated in place).

        Returns:
            The same *graph*, for chaining.
        """
        graph.add_node("trim_history", trim_history)
        graph.add_node("agent", agent)
        graph.add_node("tools", tool_node)
        return graph

    def edge(self, graph: StateGraph) -> StateGraph:
        """Wire the control flow between the registered nodes.

        Args:
            graph: Graph whose nodes were added by :meth:`node`.

        Returns:
            The same *graph*, for chaining.
        """
        graph.add_edge(START, "trim_history")
        graph.add_edge("trim_history", "agent")
        graph.add_conditional_edges(
            "agent", self.should_continue, {"end": END, "continue": "tools"}
        )
        # Feed tool results back to the model. Without this edge ``tools``
        # is a dead end: the run stops right after the tools execute, and
        # the agent never gets to respond to their output.
        # NOTE(review): if ``tool_node`` routes itself (e.g. returns a
        # ``Command(goto=...)``), drop this edge to avoid a double hop.
        graph.add_edge("tools", "agent")
        return graph

    def __call__(self, checkpointer=_USE_DEFAULT) -> CompiledStateGraph:
        """Build and compile the agent graph.

        Args:
            checkpointer: Checkpointer handed to ``compile``. When omitted,
                the class-wide shared ``InMemorySaver`` is used. Any value
                passed explicitly, including ``None``, is forwarded
                unchanged.

        Returns:
            The compiled graph.
        """
        if checkpointer is self._USE_DEFAULT:
            checkpointer = self._default_checkpointer
        graph = self.edge(self.node(StateGraph(State)))
        return graph.compile(checkpointer=checkpointer)


# Module-level factory instance; call ``lesson_practice_agent()`` to obtain a
# compiled graph.
lesson_practice_agent: LessonPracticeAgent = LessonPracticeAgent()