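"""Hybrid RAG pipeline over a Game of Thrones dialogue corpus.

Loads all_dialogues.txt, splits it into overlapping chunks, and combines a
FAISS vector retriever with a BM25 keyword retriever. The retrieved context is
passed to a Hugging Face hosted LLM (zephyr-7b-beta) to answer questions.
"""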
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever, EnsembleRetriever
from langchain_community.llms import HuggingFaceHub
from langchain_community.llms import Ollama  # only used by the commented-out Ollama variant below
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Load the raw dialogue corpus from the working directory.
loader = DirectoryLoader('.', glob="all_dialogues.txt")
docs = loader.load()

# Split the corpus into overlapping chunks for indexing.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200
)
texts = text_splitter.split_documents(docs)
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Load the pre-built FAISS index (index.faiss / index.pkl expected alongside this
# file); allow_dangerous_deserialization is required because the index is pickled.
db = FAISS.load_local(
    folder_path="./",
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
# Vector Store Retriever
vector_retriever = db.as_retriever(search_kwargs={"k": 3})
# Keyword Retriever (BM25)
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 2
# Combine both: EnsembleRetriever merges the ranked lists with weighted
# Reciprocal Rank Fusion, so strong hits from either retriever reach the context.
ensemble_retriever = EnsembleRetriever(
    retrievers=[vector_retriever, bm25_retriever],
    weights=[0.6, 0.4],  # Tune based on your tests
)
# Use in ask_question()
def respond_rag_huggingface(
    message: str,
    system_message: str = (
        "You are a Game of Thrones maester. Answer the given question strictly "
        "based on the context provided. If you do not know the answer, reply "
        "'I don't know'; do not give gibberish answers."
    ),
    num_predict: int = 128,
    temperature: float = 0.8,
):
# 1. Retrieve context
docs = ensemble_retriever.get_relevant_documents(message)
context = "\n\n".join(doc.page_content for doc in docs)
# 2. Prompt
prompt_template = ChatPromptTemplate.from_messages([
("system", system_message),
("human", """Context: {context}
Question: {question}
Rules:
- If the answer isn't in the context, respond with "I don't know"
- Keep answers under 5 sentences
- Include book/season references when possible""")
])
    # 3. HuggingFace Inference API LLM (requires HUGGINGFACEHUB_API_TOKEN in the
    #    environment), e.g. `HuggingFaceH4/zephyr-7b-beta`.
    llm = HuggingFaceHub(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": num_predict,
        },
    )
    # 4. Run chain; HuggingFaceHub is a plain LLM, so the chain returns a string.
    chain = prompt_template | llm
    response = chain.invoke({"context": context, "question": message})
    return response
__all__ = ["respond_rag_huggingface"]
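
# Alternative backend kept for reference: the same retrieval + prompt flow, but
# answered by a local Ollama model instead of the Hugging Face Inference API.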
# def respond_rag_ollama(
#     message: str,
#     system_message: str = (
#         "You are a Game of Thrones maester. Answer the given question strictly "
#         "based on the context provided. If you do not know the answer, reply "
#         "'I don't know'; do not give gibberish answers."
#     ),
#     num_ctx: int = 2048,
#     num_predict: int = 128,
#     temperature: float = 0.8,
#     top_k: int = 40,
#     repeat_penalty: float = 1.1,
#     stop: list[str] | None = None,
# ):
#     # 1. Retrieve relevant context from the ensemble retriever
#     docs = ensemble_retriever.get_relevant_documents(message)
#     context = "\n\n".join(doc.page_content for doc in docs)
#     # 2. Build a conversational prompt
#     prompt_template = ChatPromptTemplate.from_messages([
#         ("system", system_message),
#         ("human", """Context: {context}
# Question: {question}
# Rules:
# - If the answer isn't in the context, respond with "I don't know"
# - Keep answers under 5 sentences
# - Include book/season references when possible"""),
#     ])
#     # 3. Configure the Ollama LLM with adjustable parameters
#     llm = Ollama(
#         model="llama3:8b-instruct-q4_0",
#         temperature=temperature,
#         num_ctx=num_ctx,
#         num_predict=num_predict,
#         top_k=top_k,
#         repeat_penalty=repeat_penalty,
#         stop=stop or ["<|eot_id|>"],
#     )
#     # 4. Run chain; Ollama is a plain LLM, so the chain returns a string.
#     chain = prompt_template | llm
#     response = chain.invoke({"context": context, "question": message})
#     return response
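
# Minimal smoke-test sketch (assumes the FAISS index files, all_dialogues.txt,
# and a HUGGINGFACEHUB_API_TOKEN are available; the question below is only an
# example, not part of the app).
if __name__ == "__main__":
    print(respond_rag_huggingface("Who is the Lord Commander of the Night's Watch?"))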