Spaces:
Sleeping
Sleeping
File size: 4,143 Bytes
91c6bea 986437f 91c6bea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage
from typing import Literal
from pydantic import BaseModel, Field
from langchain_core.prompts import PromptTemplate
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode, tools_condition
from .state import AgentState
from src.llm.llm_interface import llm_groq
class GradeDocs(BaseModel):
    # Structured-output schema returned by the relevance grader
    # (llm_groq.with_structured_output(GradeDocs) in RAGWorkflow._grade_docs).
    # NOTE(review): intentionally no class docstring — LangChain typically
    # forwards a model's docstring to the LLM as the tool/function description,
    # so adding one would alter the grading prompt.
    binary_score: str = Field(
        ...,  # required: the grader must always emit a score
        description="Relevance score 'yes' or 'no'",
    )
class RAGWorkflow:
    """Agentic RAG graph built on LangGraph.

    Flow: START -> agent -> (tool call? retrieve : END); retrieve is graded,
    relevant documents go to ``generate`` (then END), otherwise the question
    is sent to ``rewrite`` and back to ``agent``.

    NOTE: nothing here caps the rewrite -> agent loop; LangGraph's
    ``recursion_limit`` (set in the invoke config) is the only bound.
    """

    def __init__(self, retriever_tool):
        """Build (but do not compile) the graph around ``retriever_tool``.

        Args:
            retriever_tool: LangChain tool that is bound to the agent LLM and
                executed by the ``retrieve`` ToolNode.
        """
        self.workflow = StateGraph(AgentState)
        self.tools = [retriever_tool]
        self.retrieve = ToolNode([retriever_tool])
        # hub.pull() is a network call; cache the RAG prompt after first use
        # instead of re-downloading it on every generation.
        self._rag_prompt = None
        self._setup_nodes()
        self._setup_edges()

    def _setup_nodes(self):
        """Register the agent, retrieve, generate and rewrite nodes."""
        self.workflow.add_node("agent", self._agent_node)
        self.workflow.add_node("retrieve", self.retrieve)
        self.workflow.add_node("generate", self._generator_node)
        self.workflow.add_node("rewrite", self._rewrite_node)

    def _setup_edges(self):
        """Wire control flow between the nodes."""
        self.workflow.add_edge(START, "agent")
        # tools_condition routes to "tools" when the agent requested a tool
        # call, otherwise to END.
        self.workflow.add_conditional_edges(
            "agent",
            tools_condition,
            {
                "tools": "retrieve",
                END: END,
            },
        )
        # _grade_docs returns the destination node name directly; the explicit
        # map documents the two possible targets.
        self.workflow.add_conditional_edges(
            "retrieve",
            self._grade_docs,
            {
                "generate": "generate",
                "rewrite": "rewrite",
            },
        )
        self.workflow.add_edge("generate", END)
        self.workflow.add_edge("rewrite", "agent")

    def compile(self):
        """Compile the graph into a runnable LangGraph application."""
        return self.workflow.compile()

    def _agent_node(self, state):
        """Let the tool-bound LLM either answer or request a retrieval.

        The most recent message is used as the query. On the first pass that
        is the user's question (assumes the graph is invoked with the question
        as its last input message); after a "rewrite" it is the improved
        question. Using ``messages[0]`` here would discard every rewrite and
        re-send the original question on each loop.

        Returns:
            State update appending the LLM response.
        """
        print("---CALL AGENT---")
        messages = state["messages"]
        model = llm_groq.bind_tools(self.tools)
        response = model.invoke(messages[-1].content)
        return {"messages": [response]}

    def _get_rag_prompt(self):
        """Return the "rlm/rag-prompt" hub prompt, pulling it once per instance."""
        if self._rag_prompt is None:
            self._rag_prompt = hub.pull("rlm/rag-prompt")
        return self._rag_prompt

    def _generator_node(self, state):
        """Answer the user's original question from the retrieved documents.

        Returns:
            State update appending the generated answer (a plain string from
            StrOutputParser).
        """
        print("---GENERATE---")
        messages = state["messages"]
        question = messages[0].content
        # This node is only reached from "retrieve" (via _grade_docs), so the
        # last message is the retriever tool's output.
        docs = messages[-1].content
        rag_chain = self._get_rag_prompt() | llm_groq | StrOutputParser()
        response = rag_chain.invoke({"context": docs, "question": question})
        return {"messages": [response]}

    def _rewrite_node(self, state):
        """Ask the LLM to reformulate the user's original question.

        Returns:
            State update appending the LLM's reformulated question, which the
            agent node picks up as the latest message.
        """
        print("---REWRITE---")
        messages = state["messages"]
        question = messages[0].content
        msg = [
            HumanMessage(
                content=f""" \n
Look at the input and try to reason about the underlying semantic intent / meaning. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate an improved question: """,
            )
        ]
        response = llm_groq.invoke(msg)
        return {"messages": [response]}

    @staticmethod
    def _is_relevant(binary_score):
        """Return True only if the grader's score means "yes".

        Comparison ignores surrounding whitespace and case, because LLMs often
        answer "Yes" or "yes\\n". Non-string / missing scores count as not
        relevant.
        """
        return isinstance(binary_score, str) and binary_score.strip().lower() == "yes"

    def _grade_docs(self, state):
        """Decide whether the retrieved documents are relevant to the question.

        Returns:
            "generate" if the grader judged the documents relevant,
            otherwise "rewrite".
        """
        print("---CHECK RELEVANCE---")
        llm_with_tool = llm_groq.with_structured_output(GradeDocs)
        prompt = PromptTemplate(
            template="""You are a grader assessing relevance of a retrieved document to a user question. \n
Here is the retrieved document: \n\n {context} \n\n
Here is the user question: {question} \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
            input_variables=["context", "question"],
        )
        chain = prompt | llm_with_tool
        messages = state["messages"]
        question = messages[0].content
        # Called right after "retrieve", so the last message is the tool output.
        docs = messages[-1].content
        scored_result = chain.invoke({"question": question, "context": docs})
        # getattr guards against a missing/None structured result.
        if self._is_relevant(getattr(scored_result, "binary_score", None)):
            print("---DECISION: DOCS RELEVANT---")
            return "generate"
        print("---DECISION: DOCS NOT RELEVANT---")
        return "rewrite"
|