from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain_community.llms import Ollama
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
import google.generativeai as genai
import os
# Initialize Gemini (expects GEMINI_API_KEY to be set in the environment)
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")
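# Optional (sketch): generation parameters can be pinned for more repeatable answers, e.g.
# model = genai.GenerativeModel(
#     "gemini-1.5-flash",
#     generation_config={"temperature": 0.2, "max_output_tokens": 256},
# )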
# Load the source corpus (Game of Thrones dialogue transcript)
loader = DirectoryLoader('.', glob="all_dialogues.txt")
docs = loader.load()
# Split documents into overlapping chunks for retrieval
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200
)
texts = text_splitter.split_documents(docs)
# Embeddings must match the model the FAISS index was built with
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
db = FAISS.load_local(
    folder_path="./",
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
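# If no FAISS index has been saved alongside this script yet, one could be built
# from the split documents and persisted (sketch; writes index.faiss / index.pkl):
# db = FAISS.from_documents(texts, embeddings)
# db.save_local("./")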
# Vector Store Retriever
vector_retriever = db.as_retriever(search_kwargs={"k": 5})
# Keyword Retriever (BM25)
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 5
# Combine both retrievers
def ensemble_retriever(query):
    vector_docs = vector_retriever.get_relevant_documents(query)
    bm25_docs = bm25_retriever.get_relevant_documents(query)
    # Optionally weight or deduplicate them
    combined_docs = vector_docs + bm25_docs
    return combined_docs
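# Optional alternative (sketch, assumes LangChain's EnsembleRetriever is available in
# this environment): combine the two retrievers with explicit weights instead of
# plain concatenation.
# from langchain.retrievers import EnsembleRetriever
# weighted_retriever = EnsembleRetriever(
#     retrievers=[vector_retriever, bm25_retriever],
#     weights=[0.5, 0.5],
# )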
# Use the ensemble retriever in the RAG response function below
def respond_rag_huggingface(message: str):
    # Despite the name, generation uses the Gemini model configured above;
    # retrieval uses the FAISS + BM25 ensemble.
    docs = ensemble_retriever(message)
    context = "\n\n".join(doc.page_content for doc in docs)
    system_message = os.environ.get(
        "SYSTEM_MESSAGE",
        "You are a Game of Thrones maester. Answer the given question strictly based on the provided context. If you don't know, say 'I don't know'. Do not guess.",
    )
    prompt = f"""{system_message}
Context:
{context}
Question:
{message}
Rules:
- If the answer isn't in the context, respond with "I don't know"
- Keep answers under 5 sentences
- Include book/season references when possible"""
    response = model.generate_content(prompt)
    return response.text
__all__ = ["respond_rag_huggingface"]
# Optional alternative pipeline using a local Ollama model instead of Gemini.
# Kept here commented out for reference.
# def respond_rag_ollama(
#     message: str,
#     system_message: str = "You are a Game of Thrones maester. Answer the given question strictly based on the provided context. If you do not know the answer, reply 'I don't know'. Do not guess.",
#     num_ctx: int = 2048,
#     num_predict: int = 128,
#     temperature: float = 0.8,
#     top_k: int = 40,
#     repeat_penalty: float = 1.1,
#     stop: list[str] | None = None,
# ):
#     # 1. Retrieve relevant context via the FAISS + BM25 ensemble
#     docs = ensemble_retriever(message)
#     context = "\n\n".join(doc.page_content for doc in docs)
#     # 2. Build a conversational prompt
#     prompt_template = ChatPromptTemplate.from_messages([
#         ("system", system_message),
#         ("human", """Context: {context}
# Question: {question}
# Rules:
# - If the answer isn't in the context, respond with "I don't know"
# - Keep answers under 5 sentences
# - Include book/season references when possible"""),
#     ])
#     # 3. Configure the Ollama LLM with adjustable parameters
#     llm = Ollama(
#         model="llama3:8b-instruct-q4_0",
#         temperature=temperature,
#         num_ctx=num_ctx,
#         num_predict=num_predict,
#         top_k=top_k,
#         repeat_penalty=repeat_penalty,
#         stop=stop or ["<|eot_id|>"],
#     )
#     # 4. Ollama is a plain LLM, so the chain returns a string (not a message object)
#     chain = prompt_template | llm
#     response = chain.invoke({"context": context, "question": message})
#     return response
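# Minimal smoke test (assumes GEMINI_API_KEY is set and the FAISS index files sit
# next to this script); the question below is just an illustrative example.
if __name__ == "__main__":
    print(respond_rag_huggingface("Who is the Lord Commander of the Night's Watch?"))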