Spaces:
Sleeping
Sleeping
Update the chatbot to a LangGraph agent handling four tasks: RAG, summary, translation and email. Update prompts. Update Gradio UI.
6861e2a
| from langgraph.graph import StateGraph, END | |
| from langgraph.graph.state import CompiledStateGraph | |
| from typing import Dict, Any | |
| from app.rag_chain import get_qa_chain | |
| from app.models.model import LLM as llm | |
| from app.summary_chain import get_summary_chain | |
| from app.translate_chain import get_translate_chain | |
| from app.rewrite_chain import get_rewrite_chain | |
# --- OUTPUT COLLECTOR ---
# Module-level slots the tool functions write their results into; run_agent
# resets every slot to None at the start of each run and streams the
# non-None ones back to the caller.
output_collector = dict.fromkeys(("answer", "summary", "translation", "formal_email"))
# --- TOOLS ---
def answer_with_rag_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Answer ``input_text`` with the retrieval-augmented QA chain.

    The chain is obtained from ``get_qa_chain(language=input_lang)`` on every
    call and invoked with ``{"input": input_text}``. Whatever the chain returns
    is recorded in ``output_collector["answer"]`` and returned unchanged
    (presumably a string, per the annotation — depends on get_qa_chain).
    Extra keyword arguments are accepted and ignored.
    """
    qa_result = get_qa_chain(language=input_lang).invoke({"input": input_text})
    output_collector["answer"] = qa_result
    return qa_result
def summarize_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Summarize ``input_text`` using the summary prompt for ``input_lang``.

    The model response's ``.content`` (the summary text) is recorded in
    ``output_collector["summary"]`` and returned. Returning the text rather
    than the raw response object keeps the ``-> str`` contract, so a later
    step that formats the summary into another prompt receives the summary
    itself instead of the stringified message with its metadata.
    Extra keyword arguments are accepted and ignored.
    """
    # Named ``model`` so the module-level ``llm`` import is not shadowed.
    model, prompt = get_summary_chain(input_lang)
    response = model.invoke(prompt.format_messages(input=input_text))
    summary = response.content
    output_collector["summary"] = summary
    return summary
def translate_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Translate ``input_text`` into ``target_lang``.

    Uses the translate prompt obtained for ``input_lang``. An empty or None
    ``target_lang`` falls back to ``input_lang``. The model response's
    ``.content`` (the translated text) is recorded in
    ``output_collector["translation"]`` and returned as a plain string, so a
    later step (the email rewrite) formats the translation itself into its
    prompt rather than the stringified message object.
    """
    if not target_lang:
        target_lang = input_lang  # default to the input language
    model, prompt = get_translate_chain(input_lang)
    response = model.invoke(prompt.format_messages(input=input_text, target_lang=target_lang))
    translation = response.content
    output_collector["translation"] = translation
    return translation
def rewrite_email_tool(input_text: str, input_lang: str, target_lang: str) -> str:
    """Rewrite ``input_text`` as a formal email.

    Uses the rewrite prompt obtained for ``input_lang``. Like
    ``translate_tool``, an empty or None ``target_lang`` falls back to
    ``input_lang``; otherwise the prompt would be formatted with the text
    "None". The model response's ``.content`` (the email text) is recorded in
    ``output_collector["formal_email"]`` and returned as a plain string.
    """
    if not target_lang:
        target_lang = input_lang  # same fallback as translate_tool
    model, prompt = get_rewrite_chain(input_lang)
    response = model.invoke(prompt.format_messages(input=input_text, target_lang=target_lang))
    email_text = response.content
    output_collector["formal_email"] = email_text
    return email_text
# --- LANGGRAPH STATE & NODES ---
class AgentState(Dict[str, Any]):
    """State shared between the agent's graph nodes.

    A ``dict`` subclass whose class-level annotations list the keys the node
    functions read and write via ``state[...]``. The ``= None`` values below
    are class attributes, not dict entries: a state dict does not contain
    ``answer``/``summary``/... until a node sets them, which is why the nodes
    use ``state.get(...)`` for optional keys.

    NOTE(review): the output fields are annotated ``str`` but default to
    ``None``; ``Optional[str]`` would be more accurate. Presumably LangGraph's
    ``StateGraph`` derives its channels from these annotations — confirm
    before changing them.
    """
    user_input: str  # the user's question, passed to the RAG tool
    input_lang: str  # language passed to every tool
    target_lang: str  # target language for translate/rewrite; run_agent may pass None
    answer: str = None  # set by node_answer_by_RAG
    summary: str = None  # set by node_summarize
    translation: str = None  # set by node_translate
    formal_email: str = None  # set by node_rewrite
def node_answer_by_RAG(state: AgentState) -> AgentState:
    """Graph node: answer ``state["user_input"]`` with the RAG tool.

    The tool's result is written to ``state["answer"]``; the same (mutated)
    state object is returned.
    """
    question, language = state["user_input"], state["input_lang"]
    state["answer"] = answer_with_rag_tool(question, language)
    return state
def node_summarize(state: AgentState) -> AgentState:
    """Graph node: summarize the RAG answer.

    Passes ``state["answer"]`` to ``summarize_tool`` and stores the result as
    ``state["summary"]``; the same (mutated) state object is returned.
    """
    state["summary"] = summarize_tool(state["answer"], state["input_lang"])
    return state
def node_translate(state: AgentState) -> AgentState:
    """Graph node: translate the summary, or the RAG answer if there is none.

    A missing or empty ``state["summary"]`` falls back to ``state["answer"]``.
    The ``translate_tool`` result is stored as ``state["translation"]``; the
    same (mutated) state object is returned.
    """
    source_text = state.get("summary")
    if not source_text:
        source_text = state["answer"]
    state["translation"] = translate_tool(source_text, state["input_lang"], state["target_lang"])
    return state
def node_rewrite(state: AgentState) -> AgentState:
    """Graph node: rewrite the most-processed text so far as a formal email.

    Source preference is the translation, then the summary, then the RAG
    answer (the first that is present and non-empty). Previously the summary
    was skipped, so "summarize + email" without translation ignored the
    summary. The ``rewrite_email_tool`` result is stored as
    ``state["formal_email"]``; the same (mutated) state object is returned.
    """
    source = state.get("translation") or state.get("summary") or state["answer"]
    # The summary/translation entries may hold a chat-model response object
    # rather than a string; unwrap ``.content`` so the prompt gets plain text.
    source = getattr(source, "content", source)
    state["formal_email"] = rewrite_email_tool(source, state["input_lang"], state["target_lang"])
    return state
# --- LANGGRAPH GRAPH DEFINITION ---
# Step name -> node function, registered on the graph in this order.
_NODE_FUNCTIONS = {
    "answer_by_RAG": node_answer_by_RAG,
    "summarize": node_summarize,
    "translate": node_translate,
    "rewrite": node_rewrite,
}

graph = StateGraph(AgentState)
for node, node_function in _NODE_FUNCTIONS.items():
    graph.add_node(node, node_function)
graph.set_entry_point("answer_by_RAG")

# run_agent does not use compiled_graph: it calls the node functions directly
# in its own execution plan. Every node is wired straight to END so that each
# one is a valid final step of the compiled graph.
for node in _NODE_FUNCTIONS:
    graph.add_edge(node, END)
compiled_graph = graph.compile()
# --- MAIN API CALL ---
def run_agent(
    user_input: str,
    input_lang: str = "Deutsch",
    target_lang: str = None,
    do_summarize: bool = False,
    do_translate: bool = False,
    do_email: bool = False,
):
    """Run the agent steps in order, streaming the collected outputs.

    The RAG answer step always runs first. The summarize, translate and
    formal-email steps follow in that fixed order, each only when its flag
    is truthy.

    Args:
        user_input: The user's question for the RAG step.
        input_lang: Language passed to every tool (default ``"Deutsch"``).
        target_lang: Target language for the translate/rewrite steps; may be None.
        do_summarize: Include the summarize step.
        do_translate: Include the translate step.
        do_email: Include the formal-email rewrite step.

    Yields:
        After each step, a new dict holding the non-None entries of
        ``output_collector``.
    """
    print(user_input, input_lang, target_lang, do_summarize, do_translate, do_email)

    # output_collector is module-global: clear results from any previous run.
    output_collector.update(dict.fromkeys(output_collector))

    state = {
        "user_input": user_input,
        "input_lang": input_lang,
        "target_lang": target_lang,
    }

    optional_steps = (
        (do_summarize, node_summarize),
        (do_translate, node_translate),
        (do_email, node_rewrite),
    )
    steps = [node_answer_by_RAG]
    steps.extend(step for enabled, step in optional_steps if enabled)

    # Node functions are called directly (not via compiled_graph) so the
    # caller receives the outputs accumulated after every step.
    for step in steps:
        state = step(state)
        yield {key: value for key, value in output_collector.items() if value is not None}