from langchain import hub
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field
from typing import Literal

from .state import AgentState
from src.llm.llm_interface import llm_groq


class GradeDocs(BaseModel):
    """Structured output schema for the document-relevance grader."""

    binary_score: str = Field(description="Relevance score 'yes' or 'no'")


class RAGWorkflow:
    """Agentic RAG graph: agent -> retrieve -> (generate | rewrite -> agent)."""

    def __init__(self, retriever_tool):
        self.workflow = StateGraph(AgentState)
        self.tools = [retriever_tool]
        self.retrieve = ToolNode([retriever_tool])
        self._setup_nodes()
        self._setup_edges()

    def _setup_nodes(self):
        self.workflow.add_node("agent", self._agent_node)
        self.workflow.add_node("retrieve", self.retrieve)
        self.workflow.add_node("generate", self._generator_node)
        self.workflow.add_node("rewrite", self._rewrite_node)

    def _setup_edges(self):
        self.workflow.add_edge(START, "agent")
        # If the agent emits a tool call, run retrieval; otherwise finish.
        self.workflow.add_conditional_edges(
            "agent",
            tools_condition,
            {"tools": "retrieve", END: END},
        )
        # After retrieval, grade the documents; _grade_docs returns the name
        # of the next node ("generate" or "rewrite"), so no path map is needed.
        self.workflow.add_conditional_edges(
            "retrieve",
            self._grade_docs,
        )
        self.workflow.add_edge("generate", END)
        self.workflow.add_edge("rewrite", "agent")

    def compile(self):
        return self.workflow.compile()

    def _agent_node(self, state):
        """Decide whether to call the retriever tool or answer directly."""
        print("---CALL AGENT---")
        messages = state["messages"]
        model = llm_groq.bind_tools(self.tools)
        # Pass the full message history, not messages[0].content: after a
        # rewrite, the improved question is the latest message, and the model
        # also needs prior tool calls/results for context.
        response = model.invoke(messages)
        return {"messages": [response]}

    def _generator_node(self, state):
        """Generate an answer grounded in the retrieved documents."""
        print("---GENERATE---")
        messages = state["messages"]
        question = messages[0].content
        docs = messages[-1].content  # latest tool message holds the retrieved docs
        prompt = hub.pull("rlm/rag-prompt")
        rag_chain = prompt | llm_groq | StrOutputParser()
        response = rag_chain.invoke({"context": docs, "question": question})
        return {"messages": [response]}

    def _rewrite_node(self, state):
        """Rephrase the question when retrieval returned irrelevant documents."""
        print("---REWRITE---")
        messages = state["messages"]
        question = messages[0].content
        msg = [
            HumanMessage(
                content=f"""Look at the input and try to reason about the underlying semantic intent / meaning.
Here is the initial question:
-------
{question}
-------
Formulate an improved question:""",
            )
        ]
        response = llm_groq.invoke(msg)
        return {"messages": [response]}

    def _grade_docs(self, state) -> Literal["generate", "rewrite"]:
        """Route to 'generate' if the retrieved documents are relevant, else 'rewrite'."""
        print("---CHECK RELEVANCE---")
        llm_with_tool = llm_groq.with_structured_output(GradeDocs)
        prompt = PromptTemplate(
            template="""You are a grader assessing the relevance of a retrieved document to a user question.
Here is the retrieved document:

{context}

Here is the user question: {question}
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.""",
            input_variables=["context", "question"],
        )
        chain = prompt | llm_with_tool
        messages = state["messages"]
        question = messages[0].content
        docs = messages[-1].content
        scored_result = chain.invoke({"question": question, "context": docs})
        if scored_result.binary_score == "yes":
            print("---DECISION: DOCS RELEVANT---")
            return "generate"
        print("---DECISION: DOCS NOT RELEVANT---")
        return "rewrite"
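
# --- Usage sketch (illustrative only; not part of this module) ---
# A minimal sketch of wiring the workflow up, assuming a retriever already
# exists elsewhere: `my_retriever` and the tool name/description below are
# hypothetical placeholders, and AgentState is assumed to carry a "messages"
# key with an add_messages reducer, as the nodes above expect.
#
# from langchain.tools.retriever import create_retriever_tool
# from langchain_core.messages import HumanMessage
#
# retriever_tool = create_retriever_tool(
#     my_retriever,  # hypothetical: e.g. a vector-store retriever
#     "retrieve_docs",
#     "Search and return relevant document passages.",
# )
# graph = RAGWorkflow(retriever_tool).compile()
# final_state = graph.invoke(
#     {"messages": [HumanMessage(content="What does the document say about X?")]}
# )
# print(final_state["messages"][-1])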