import os

import google.generativeai as genai
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Initialize Gemini; the API key is read from the GEMINI_API_KEY environment variable (never hard-code it)
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
model = genai.GenerativeModel("gemini-2.5-flash")

# Load and chunk the source dialogue file
loader = DirectoryLoader('.', glob="all_dialogues.txt")
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=300, chunk_overlap=100
)
texts = text_splitter.split_documents(docs)
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Load the pre-built FAISS index from disk, reusing the embedding model above
db = FAISS.load_local(
    folder_path="./",
    embeddings=embeddings,
    allow_dangerous_deserialization=True
)
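# Note: the index loaded above is assumed to have been built and saved ahead of
# time. A minimal sketch of how it could be created from the chunks produced by
# the splitter (the folder path "./" is an assumption matching the load_local call):
#
#     db = FAISS.from_documents(texts, embeddings)
#     db.save_local("./")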
# Vector Store Retriever
vector_retriever = db.as_retriever(search_kwargs={"k": 10})
# Keyword Retriever (BM25)
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 5
# Combine both retrievers: dense (FAISS) results first, then keyword (BM25) results
def ensemble_retriever(query):
    vector_docs = vector_retriever.invoke(query)
    bm25_docs = bm25_retriever.invoke(query)
    # Simple concatenation; duplicates between the two result sets are not removed
    combined_docs = vector_docs + bm25_docs
    return combined_docs
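# Optional alternative (a sketch, assuming the `langchain` package is installed):
# LangChain's EnsembleRetriever fuses the two result lists with weighted
# reciprocal rank fusion instead of plain concatenation.
#
#     from langchain.retrievers import EnsembleRetriever
#
#     ensemble = EnsembleRetriever(
#         retrievers=[bm25_retriever, vector_retriever],
#         weights=[0.5, 0.5],  # illustrative weights, not tuned
#     )
#     docs = ensemble.invoke("Who is Jon Snow's mother?")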
# Gemini-backed RAG response function used by the app's question handler
def respond_rag_huggingface(message: str):
    # Retrieve supporting passages and stitch them into the prompt
    docs = ensemble_retriever(message)
    context = "\n\n".join(doc.page_content for doc in docs)

    system_message = os.environ.get(
        "SYSTEM_MESSAGE",
        "You are a Game of Thrones maester and Harry Potter's Dumbledore. "
        "Answer the given question based on your knowledge, providing accurate details "
        "without mentioning any specific sources or context used. "
        "State how much you know about the topic, and do not provide faulty answers. "
        "If the answer is unclear, clarify what you mean rather than saying 'I do not know.'"
    )

    prompt = f"""{system_message}

Context:
{context}

Question:
{message}

Rules:
- Do not mention the context or where the information comes from
- State how much you know about the topic (e.g., 'I have detailed knowledge,' 'I have some knowledge,' or 'My knowledge is limited')
- Keep answers under 5 sentences
- Include book/season references when possible
- Answer based on relevant knowledge from Game of Thrones and Harry Potter
"""

    response = model.generate_content(prompt)
    return response.text
__all__ = ["respond_rag_huggingface"]
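# Minimal manual check when running this module directly (a sketch; the example
# question is illustrative only and requires GEMINI_API_KEY to be set):
if __name__ == "__main__":
    print(respond_rag_huggingface("Who is the Night King?"))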
# Alternative: local Ollama backend (kept for reference; requires
# `from langchain_community.llms import Ollama` and
# `from langchain_core.prompts import ChatPromptTemplate`).
# def respond_rag_ollama(
#     message: str,
#     system_message: str = "You are a Game of Thrones maester. Answer the given question strictly based on the context provided. If you do not know the answer, reply 'I don't know'; do not give gibberish answers.",
#     num_ctx: int = 2048,
#     num_predict: int = 128,
#     temperature: float = 0.8,
#     top_k: int = 40,
#     repeat_penalty: float = 1.1,
#     stop: list[str] | None = None,
# ):
#     # 1. Retrieve relevant context from both retrievers
#     docs = ensemble_retriever(message)
#     context = "\n\n".join(doc.page_content for doc in docs)
#     # 2. Build a conversational prompt
#     prompt_template = ChatPromptTemplate.from_messages([
#         ("system", system_message),
#         ("human", """Context: {context}
# Question: {question}
# Rules:
# - If the answer isn't in the context, respond with "I don't know"
# - Keep answers under 5 sentences
# - Include book/season references when possible"""),
#     ])
#     # 3. Configure the Ollama LLM with adjustable parameters
#     llm = Ollama(
#         model="llama3:8b-instruct-q4_0",
#         temperature=temperature,
#         num_ctx=num_ctx,
#         num_predict=num_predict,
#         top_k=top_k,
#         repeat_penalty=repeat_penalty,
#         stop=stop or ["<|eot_id|>"],
#     )
#     chain = prompt_template | llm
#     # Ollama is a plain LLM, so the chain returns a string rather than a message object
#     return chain.invoke({"context": context, "question": message})