File size: 4,559 Bytes
6861e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from typing import Dict, Any
from app.rag_chain import get_qa_chain
from app.models.model import LLM as llm
from app.summary_chain import get_summary_chain
from app.translate_chain import get_translate_chain
from app.rewrite_chain import get_rewrite_chain

# --- OUTPUT COLLECTOR ---
# Module-level accumulator that each tool writes into; run_agent() resets it
# per call and yields its non-None entries after every executed node.
output_collector = dict.fromkeys(
    ("answer", "summary", "translation", "formal_email")
)

# --- TOOLS ---

def answer_with_rag_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Answer *input_text* via the RAG QA chain and record it in the collector.

    Builds a language-specific QA chain, invokes it with the user text, and
    stores the result under ``output_collector["answer"]`` before returning it.
    Extra keyword arguments are accepted (and ignored) for a uniform tool
    signature.
    """
    qa_chain = get_qa_chain(language=input_lang)
    answer = qa_chain.invoke({"input": input_text})
    output_collector["answer"] = answer
    return answer

def summarize_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Summarize *input_text* and record the summary in the collector.

    Fix: previously this returned the raw chat-message object even though the
    function is annotated ``-> str`` and the collector stored ``result.content``
    — downstream consumers of the return value received a message object
    instead of text. Now the summary text is both stored and returned.
    The local model variable is also renamed so it no longer shadows the
    module-level ``llm`` import.
    """
    model, prompt = get_summary_chain(input_lang)
    result = model.invoke(prompt.format_messages(input=input_text))
    summary_text = result.content
    output_collector["summary"] = summary_text
    return summary_text

def translate_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Translate *input_text* into *target_lang* and record the translation.

    Falls back to *input_lang* when no target language is given.

    Fix: previously this returned the raw chat-message object even though the
    function is annotated ``-> str`` and the collector stored ``result.content``.
    Now the translated text is both stored and returned. The local model
    variable is also renamed so it no longer shadows the module-level ``llm``
    import.
    """
    if not target_lang:
        target_lang = input_lang  # default to the input language
    model, prompt = get_translate_chain(input_lang)
    result = model.invoke(prompt.format_messages(input=input_text, target_lang=target_lang))
    translation_text = result.content
    output_collector["translation"] = translation_text
    return translation_text

def rewrite_email_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Rewrite *input_text* as a formal email and record it in the collector.

    Fix: previously this returned the raw chat-message object even though the
    function is annotated ``-> str`` and the collector stored ``result.content``.
    Now the email text is both stored and returned. For consistency with
    ``translate_tool``, an empty/None *target_lang* now falls back to
    *input_lang* instead of being passed through as None. The local model
    variable is also renamed so it no longer shadows the module-level ``llm``
    import.
    """
    if not target_lang:
        target_lang = input_lang  # default to the input language
    model, prompt = get_rewrite_chain(input_lang)
    result = model.invoke(prompt.format_messages(input=input_text, target_lang=target_lang))
    email_text = result.content
    output_collector["formal_email"] = email_text
    return email_text

# --- LANGGRAPH STATE & NODES ---

class AgentState(Dict[str, Any]):
    """Mutable state dict passed through the pipeline nodes.

    Subclasses ``Dict[str, Any]``, so nodes read and write entries via item
    access (e.g. ``state["answer"]``). The class-level annotations below are
    descriptive only — they are not enforced at runtime and the ``= None``
    defaults live on the class, not in the dict itself.
    """
    # Inputs supplied by run_agent():
    user_input: str    # raw user request text
    input_lang: str    # language of the input (e.g. "Deutsch")
    target_lang: str   # requested output language; may be None
    # Outputs produced by the nodes (None until the node has run):
    answer: str = None        # set by node_answer_by_RAG
    summary: str = None       # set by node_summarize
    translation: str = None   # set by node_translate
    formal_email: str = None  # set by node_rewrite

def node_answer_by_RAG(state: AgentState) -> AgentState:
    """Graph node: produce the RAG answer and store it in the state."""
    state["answer"] = answer_with_rag_tool(state["user_input"], state["input_lang"])
    return state

def node_summarize(state: AgentState) -> AgentState:
    """Graph node: summarize the RAG answer and store it in the state."""
    state["summary"] = summarize_tool(state["answer"], state["input_lang"])
    return state

def node_translate(state: AgentState) -> AgentState:
    """Graph node: translate the summary (if present, else the answer)."""
    source_text = state.get("summary") or state["answer"]
    state["translation"] = translate_tool(
        source_text, state["input_lang"], state["target_lang"]
    )
    return state

def node_rewrite(state: AgentState) -> AgentState:
    """Graph node: rewrite the latest text in the state as a formal email.

    Fix: the original comment promised "translation if exists, else summary,
    else answer", but the code skipped the summary fallback
    (``state.get("translation") or state["answer"]``). The fallback chain now
    matches the documented intent.
    """
    text_to_rewrite = (
        state.get("translation") or state.get("summary") or state["answer"]
    )
    state["formal_email"] = rewrite_email_tool(
        text_to_rewrite, state["input_lang"], state["target_lang"]
    )
    return state

# --- LANGGRAPH GRAPH DEFINITION ---

graph = StateGraph(AgentState)

# Register every pipeline node from one table so the node list lives in
# exactly one place.
_NODE_IMPLS = {
    "answer_by_RAG": node_answer_by_RAG,
    "summarize": node_summarize,
    "translate": node_translate,
    "rewrite": node_rewrite,
}
for _name, _fn in _NODE_IMPLS.items():
    graph.add_node(_name, _fn)

graph.set_entry_point("answer_by_RAG")

# No inter-node edges: run_agent() steps through the node functions manually
# according to its execution plan, so each node only needs a terminating END
# edge (any node can be the last one executed).
for _name in _NODE_IMPLS:
    graph.add_edge(_name, END)

compiled_graph = graph.compile()

# --- MAIN API CALL ---

def run_agent(user_input: str, input_lang: str = "Deutsch", target_lang: str = None, do_summarize: bool = False, do_translate: bool = False, do_email: bool = False):
    """Run the requested pipeline steps sequentially, yielding partial results.

    Generator: after each executed node it yields a dict of all outputs
    collected so far (only non-None entries of ``output_collector``).
    RAG answering always runs first; summarize/translate/rewrite are appended
    to the plan according to the boolean flags.

    Fixes: removed the leftover debug ``print`` of all arguments, and replaced
    the fragile ``globals()[f"node_{name}"]`` dynamic lookup with an explicit
    dispatch list of node functions.

    NOTE(review): ``output_collector`` is shared module state, so concurrent
    calls would interleave results — confirm single-threaded usage.
    """
    # Reset the shared collector so results from a previous call don't leak.
    for key in output_collector:
        output_collector[key] = None

    state = {
        "user_input": user_input,
        "input_lang": input_lang,
        "target_lang": target_lang,
    }

    # Build the sequential execution plan; answering always runs first.
    plan = [node_answer_by_RAG]
    if do_summarize:
        plan.append(node_summarize)
    if do_translate:
        plan.append(node_translate)
    if do_email:
        plan.append(node_rewrite)

    for node_fn in plan:
        state = node_fn(state)
        # Yield the outputs produced so far after each completed step.
        yield {k: v for k, v in output_collector.items() if v is not None}