Spaces:
Sleeping
Sleeping
File size: 4,559 Bytes
6861e2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from typing import Dict, Any
from app.rag_chain import get_qa_chain
from app.models.model import LLM as llm
from app.summary_chain import get_summary_chain
from app.translate_chain import get_translate_chain
from app.rewrite_chain import get_rewrite_chain
# --- OUTPUT COLLECTOR ---
# Latest result of each pipeline step, keyed by step output name.
# Every slot starts as None; the tool functions fill them in and run_agent
# resets them at the start of each run.
output_collector = dict.fromkeys(("answer", "summary", "translation", "formal_email"))
# --- TOOLS ---
def answer_with_rag_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Answer *input_text* with the retrieval-augmented QA chain.

    A QA chain for ``input_lang`` is obtained from ``get_qa_chain`` on every
    call and invoked with ``{"input": input_text}``. Whatever ``invoke``
    returns is recorded in ``output_collector["answer"]`` and returned as-is.
    Extra keyword arguments are accepted and ignored.
    """
    qa_result = get_qa_chain(language=input_lang).invoke({"input": input_text})
    output_collector["answer"] = qa_result
    return qa_result
def summarize_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Summarise *input_text* using the summary model and prompt for *input_lang*.

    Args:
        input_text: Text to summarise.
        input_lang: Language key passed to ``get_summary_chain``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The summary text (the model reply's ``.content``). The same text is
        stored in ``output_collector["summary"]``.
    """
    # Named `model` rather than `llm` so the module-level `llm` import is not shadowed.
    model, prompt = get_summary_chain(input_lang)
    reply = model.invoke(prompt.format_messages(input=input_text))
    # Return the text, not the message object. node_summarize stores this value
    # in the agent state, and the translate/rewrite steps later format it into
    # their prompts, where a message object would render as its repr.
    summary = reply.content
    output_collector["summary"] = summary
    return summary
def translate_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Translate *input_text* into *target_lang*.

    Args:
        input_text: Text to translate.
        input_lang: Language key passed to ``get_translate_chain``; also the
            fallback target when ``target_lang`` is empty.
        target_lang: Desired output language; a falsy value (``None``/``""``)
            means "use ``input_lang``".

    Returns:
        The translated text (the model reply's ``.content``). The same text is
        stored in ``output_collector["translation"]``.
    """
    if not target_lang:
        # No target language given: fall back to the input language.
        target_lang = input_lang
    model, prompt = get_translate_chain(input_lang)
    reply = model.invoke(prompt.format_messages(input=input_text, target_lang=target_lang))
    # Return plain text so node_rewrite can feed it into the rewrite prompt.
    translation = reply.content
    output_collector["translation"] = translation
    return translation
def rewrite_email_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Rewrite *input_text* as a formal e-mail in *target_lang*.

    Args:
        input_text: Text to rewrite.
        input_lang: Language key passed to ``get_rewrite_chain``; also the
            fallback target when ``target_lang`` is empty.
        target_lang: Desired output language; a falsy value (``None``/``""``)
            means "use ``input_lang``", matching ``translate_tool``.

    Returns:
        The e-mail text (the model reply's ``.content``). The same text is
        stored in ``output_collector["formal_email"]``.
    """
    if not target_lang:
        # run_agent defaults target_lang to None. Without this fallback the
        # literal string "None" would be formatted into the prompt.
        target_lang = input_lang
    model, prompt = get_rewrite_chain(input_lang)
    reply = model.invoke(prompt.format_messages(input=input_text, target_lang=target_lang))
    # Return the text to honour the `-> str` contract.
    formal_email = reply.content
    output_collector["formal_email"] = formal_email
    return formal_email
# --- LANGGRAPH STATE & NODES ---
class AgentState(Dict[str, Any]):
    """State dictionary passed through the agent pipeline's node functions.

    This is a ``dict`` subclass. The annotated names below describe the
    expected keys. The ``= None`` values are class attributes only: they do not
    create dictionary entries, so ``AgentState()["answer"]`` raises
    ``KeyError``. ``run_agent`` builds the state as a plain ``dict`` holding the
    three input keys, and each node function adds its output key when it runs.
    """

    # Raw user text; input to the RAG answer step.
    user_input: str
    # Language of the input; forwarded to every tool's chain factory.
    input_lang: str
    # Requested output language for translation / e-mail rewriting.
    # run_agent may set this to None.
    target_lang: str
    # Step outputs, added by the corresponding node functions.
    # Absent from the dict until that node has run.
    answer: str = None
    summary: str = None
    translation: str = None
    formal_email: str = None
def node_answer_by_RAG(state: AgentState) -> AgentState:
    """Entry step: answer the user's input with RAG and record it as ``state["answer"]``.

    The state is mutated in place and returned for chaining.
    """
    state["answer"] = answer_with_rag_tool(state["user_input"], state["input_lang"])
    return state
def node_summarize(state: AgentState) -> AgentState:
    """Summarise ``state["answer"]`` and record the result as ``state["summary"]``.

    Expects the answer step to have run first. The state is mutated in place
    and returned.
    """
    rag_answer = state["answer"]
    state["summary"] = summarize_tool(rag_answer, state["input_lang"])
    return state
def node_translate(state: AgentState) -> AgentState:
    """Translate the most condensed text available into ``state["translation"]``.

    Uses a truthy ``state["summary"]`` when present, otherwise
    ``state["answer"]``. The state is mutated in place and returned.
    """
    source_text = state.get("summary")
    if not source_text:
        source_text = state["answer"]
    state["translation"] = translate_tool(source_text, state["input_lang"], state["target_lang"])
    return state
def node_rewrite(state: AgentState) -> AgentState:
    """Rewrite the latest available text as a formal e-mail into ``state["formal_email"]``.

    The source text is, in order of preference: a truthy translation, then a
    truthy summary, then the RAG answer. The state is mutated in place and
    returned.
    """
    # Include the summary in the fallback chain, as the step ordering intends.
    # Previously a summarize+email run (no translation) ignored the summary the
    # user asked for and rewrote the full answer instead.
    text_to_rewrite = state.get("translation") or state.get("summary") or state["answer"]
    state["formal_email"] = rewrite_email_tool(text_to_rewrite, state["input_lang"], state["target_lang"])
    return state
# --- LANGGRAPH GRAPH DEFINITION ---
# Register the four pipeline nodes on a StateGraph keyed by AgentState.
# run_agent() below steps through the node functions itself and does not
# invoke this graph.
graph = StateGraph(AgentState)
for _node_name, _node_fn in (
    ("answer_by_RAG", node_answer_by_RAG),
    ("summarize", node_summarize),
    ("translate", node_translate),
    ("rewrite", node_rewrite),
):
    graph.add_node(_node_name, _node_fn)

graph.set_entry_point("answer_by_RAG")

# Every node is wired directly to END and there are no node-to-node edges,
# so the compiled graph presumably runs only the entry node when invoked.
for node in ("answer_by_RAG", "summarize", "translate", "rewrite"):
    graph.add_edge(node, END)

compiled_graph = graph.compile()
# --- MAIN API CALL ---
def run_agent(user_input: str, input_lang: str = "Deutsch", target_lang: str = None, do_summarize: bool = False, do_translate: bool = False, do_email: bool = False):
    """Run the pipeline step by step, yielding the accumulated outputs after each step.

    The RAG answer step always runs. Summarise, translate and formal-e-mail
    rewrite steps follow in that fixed order when their flags are set.

    Args:
        user_input: The user's question or text.
        input_lang: Language of the input; forwarded to every step.
        target_lang: Output language for translation / rewriting (may be None).
        do_summarize: Add the summarise step.
        do_translate: Add the translate step.
        do_email: Add the formal e-mail rewrite step.

    Yields:
        After each executed step, a new dict containing only the non-None
        entries of the module-level ``output_collector``.
    """
    print(user_input, input_lang, target_lang, do_summarize, do_translate, do_email)

    # The collector is module-global, so clear results left by a previous run.
    # NOTE(review): because it is shared, two runs interleaving (e.g. concurrent
    # requests) would see each other's outputs.
    for key in output_collector:
        output_collector[key] = None

    state = {
        "user_input": user_input,
        "input_lang": input_lang,
        "target_lang": target_lang,
    }

    # Explicit (enabled, node) table instead of building "node_<name>" strings and
    # resolving them through globals(). Renames stay visible to linters and IDEs,
    # and a missing node raises NameError before any step runs rather than a
    # KeyError partway through the stream.
    steps = (
        (True, node_answer_by_RAG),
        (do_summarize, node_summarize),
        (do_translate, node_translate),
        (do_email, node_rewrite),
    )
    for enabled, node_fn in steps:
        if not enabled:
            continue
        state = node_fn(state)
        # Yield the current outputs after each node.
        yield {k: v for k, v in output_collector.items() if v is not None}
|