nanfangwuyu21's picture
Update the chatbot to a LangGraph agent that handles four tasks: RAG, summarization, translation, and email. Update prompts. Update Gradio UI.
6861e2a
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from typing import Dict, Any
from app.rag_chain import get_qa_chain
from app.models.model import LLM as llm
from app.summary_chain import get_summary_chain
from app.translate_chain import get_translate_chain
from app.rewrite_chain import get_rewrite_chain
# --- OUTPUT COLLECTOR ---
# Latest text produced by each tool, keyed by output name. The tools write
# into this dict; run_agent() resets every entry to None at the start of a run
# and yields the entries that are no longer None after each step.
output_collector = dict.fromkeys(
    ("answer", "summary", "translation", "formal_email")
)
# --- TOOLS ---
def answer_with_rag_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Answer ``input_text`` with the retrieval-augmented QA chain.

    A chain is obtained per call from ``get_qa_chain(language=input_lang)``
    and invoked with ``{"input": input_text}``. The chain's result is stored
    in ``output_collector["answer"]`` and returned unchanged.

    Args:
        input_text: The user's question.
        input_lang: Passed to ``get_qa_chain`` as ``language``.
        **kwargs: Accepted but unused.

    Returns:
        Whatever the QA chain's ``invoke`` returns (annotated as ``str``).
    """
    answer = get_qa_chain(language=input_lang).invoke({"input": input_text})
    output_collector["answer"] = answer
    return answer
def summarize_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Summarize ``input_text`` using the summary model/prompt for ``input_lang``.

    The summary text is stored in ``output_collector["summary"]`` and returned.

    Args:
        input_text: Text to summarize (``node_summarize`` passes the RAG answer).
        input_lang: Passed to ``get_summary_chain`` to obtain the model and prompt.
        **kwargs: Accepted but unused.

    Returns:
        The summary as plain text (the model response's ``content``).
    """
    # Named ``model`` rather than ``llm`` so the module-level ``llm`` import
    # is not shadowed.
    model, prompt = get_summary_chain(input_lang)
    response = model.invoke(prompt.format_messages(input=input_text))
    summary = response.content
    output_collector["summary"] = summary
    # Return the text, not the response object: node_translate / node_rewrite
    # feed this value into further prompts, where a message object would be
    # stringified instead of contributing just its text.
    return summary
def translate_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Translate ``input_text`` into ``target_lang``.

    The translated text is stored in ``output_collector["translation"]`` and
    returned.

    Args:
        input_text: Text to translate (summary or answer, see node_translate).
        input_lang: Passed to ``get_translate_chain`` to obtain the model and
            prompt; also the fallback target language.
        target_lang: Requested output language. Falsy values (e.g. ``None``,
            the ``run_agent`` default) fall back to ``input_lang``.

    Returns:
        The translation as plain text (the model response's ``content``).
    """
    if not target_lang:
        # Default to the input language when no target language was given.
        target_lang = input_lang
    model, prompt = get_translate_chain(input_lang)
    response = model.invoke(
        prompt.format_messages(input=input_text, target_lang=target_lang)
    )
    translation = response.content
    output_collector["translation"] = translation
    # Return the text, not the response object: node_rewrite feeds this value
    # into the email prompt.
    return translation
def rewrite_email_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Rewrite ``input_text`` as a formal email.

    The email text is stored in ``output_collector["formal_email"]`` and
    returned.

    Args:
        input_text: Text to rewrite (see node_rewrite for which text is chosen).
        input_lang: Passed to ``get_rewrite_chain`` to obtain the model and
            prompt; also the fallback target language.
        target_lang: Value passed to the prompt as ``target_lang``. Falsy
            values fall back to ``input_lang``, matching translate_tool.

    Returns:
        The email as plain text (the model response's ``content``).
    """
    if not target_lang:
        # run_agent() defaults target_lang to None; without this fallback the
        # prompt would be formatted with the literal text "None".
        target_lang = input_lang
    model, prompt = get_rewrite_chain(input_lang)
    response = model.invoke(
        prompt.format_messages(input=input_text, target_lang=target_lang)
    )
    formal_email = response.content
    output_collector["formal_email"] = formal_email
    return formal_email
# --- LANGGRAPH STATE & NODES ---
class AgentState(Dict[str, Any]):
    """State schema for the agent graph (passed to ``StateGraph(AgentState)``).

    run_agent() builds the state as a plain dict containing only
    ``user_input``, ``input_lang`` and ``target_lang``. The ``= None`` values
    below are class attributes, not dict entries, so the output keys are absent
    until a node sets them. That is why the nodes read optional keys with
    ``state.get(...)``.
    """

    # NOTE(review): LangGraph state schemas are normally declared as a
    # TypedDict (e.g. ``class AgentState(TypedDict, total=False)``); the
    # annotations here are presumably what LangGraph introspects — confirm
    # before changing the base class or the annotations.
    user_input: str  # the user's question, input to node_answer_by_RAG
    input_lang: str  # language handed to every tool / chain factory
    target_lang: str  # output language for translate/rewrite; may be None (run_agent default)
    answer: str = None  # set by node_answer_by_RAG
    summary: str = None  # set by node_summarize
    translation: str = None  # set by node_translate
    formal_email: str = None  # set by node_rewrite
def node_answer_by_RAG(state: AgentState) -> AgentState:
    """Graph step: answer the user's question via RAG.

    Reads ``user_input`` and ``input_lang`` and writes ``answer``. The state
    is mutated in place and the same object is returned.
    """
    question, lang = state["user_input"], state["input_lang"]
    state["answer"] = answer_with_rag_tool(question, lang)
    return state
def node_summarize(state: AgentState) -> AgentState:
    """Graph step: summarize the RAG answer.

    Reads ``answer`` and ``input_lang`` and writes ``summary``. The state is
    mutated in place and the same object is returned.
    """
    source_text, lang = state["answer"], state["input_lang"]
    state["summary"] = summarize_tool(source_text, lang)
    return state
def node_translate(state: AgentState) -> AgentState:
    """Graph step: translate the summary, or the answer if there is no summary.

    Writes ``translation``. The state is mutated in place and the same object
    is returned.
    """
    source_text = state.get("summary")
    if not source_text:
        # No (or empty) summary: fall back to the full RAG answer.
        source_text = state["answer"]
    state["translation"] = translate_tool(
        source_text, state["input_lang"], state["target_lang"]
    )
    return state
def node_rewrite(state: AgentState) -> AgentState:
    """Graph step: rewrite the most processed text so far as a formal email.

    Source text preference: translation, then summary, then the RAG answer.
    The original code skipped the summary despite its own comment, so
    "summarize + email" built the email from the full answer. Writes
    ``formal_email``; the state is mutated in place and the same object is
    returned.
    """
    text_to_rewrite = (
        state.get("translation") or state.get("summary") or state["answer"]
    )
    state["formal_email"] = rewrite_email_tool(
        text_to_rewrite, state["input_lang"], state["target_lang"]
    )
    return state
# --- LANGGRAPH GRAPH DEFINITION ---
# The graph registers the four node functions, with "answer_by_RAG" as the
# entry point and a direct edge from every node to END. No edges connect the
# nodes to each other, so invoking compiled_graph runs only "answer_by_RAG";
# run_agent() below calls the node functions itself in the order the caller
# requests.
graph = StateGraph(AgentState)
for node, _handler in (
    ("answer_by_RAG", node_answer_by_RAG),
    ("summarize", node_summarize),
    ("translate", node_translate),
    ("rewrite", node_rewrite),
):
    graph.add_node(node, _handler)
graph.set_entry_point("answer_by_RAG")
for node in ("answer_by_RAG", "summarize", "translate", "rewrite"):
    graph.add_edge(node, END)
compiled_graph = graph.compile()
# --- MAIN API CALL ---
def run_agent(user_input: str, input_lang: str = "Deutsch", target_lang: str = None, do_summarize: bool = False, do_translate: bool = False, do_email: bool = False):
    """Answer ``user_input`` via RAG, then run the requested follow-up steps.

    Steps run in the fixed order answer -> summarize -> translate -> rewrite,
    with each enabled step operating on the state left by the previous ones.
    This is a generator: it yields after every step so a UI can stream
    partial results.

    Args:
        user_input: The user's question.
        input_lang: Language handed to every tool (default "Deutsch").
        target_lang: Language for the translation / email steps; may be None.
        do_summarize: Also summarize the answer.
        do_translate: Also translate the summary (or the answer).
        do_email: Also rewrite the text as a formal email.

    Yields:
        A new dict with the non-None entries of ``output_collector`` after
        each step (keys among "answer", "summary", "translation",
        "formal_email").
    """
    import logging

    # Debug logging instead of print(): the user's input is no longer written
    # to stdout on every request unless debug logging is enabled.
    logging.getLogger(__name__).debug(
        "run_agent: input_lang=%s target_lang=%s summarize=%s translate=%s email=%s input=%r",
        input_lang, target_lang, do_summarize, do_translate, do_email, user_input,
    )

    # NOTE(review): output_collector is module-global, so concurrent runs
    # (e.g. several UI sessions at once) share and overwrite it.
    for key in output_collector:
        output_collector[key] = None

    state = {
        "user_input": user_input,
        "input_lang": input_lang,
        "target_lang": target_lang,
    }

    # Reference the node functions directly instead of building
    # f"node_{name}" and looking it up in globals(): no string-built lookups,
    # and renames are visible to linters / IDEs.
    steps = [node_answer_by_RAG]
    if do_summarize:
        steps.append(node_summarize)
    if do_translate:
        steps.append(node_translate)
    if do_email:
        steps.append(node_rewrite)

    for step in steps:
        state = step(state)
        # Snapshot of the outputs produced so far.
        yield {k: v for k, v in output_collector.items() if v is not None}