# RAG over Game of Thrones dialogue: hybrid FAISS + BM25 retrieval feeding a
# local HuggingFace text-generation model.

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Load the Zephyr-7B chat model and its tokenizer locally.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
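
# Note: a 7B model in full float32 needs roughly 28 GB of memory. With the
# `accelerate` package installed, a lighter load is possible (a sketch;
# optional, not part of the original pipeline):
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "HuggingFaceH4/zephyr-7b-beta",
#         torch_dtype="auto",
#         device_map="auto",
#     )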

# Build a transformers text-generation pipeline around the model.
hf_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    do_sample=True,  # required for temperature to have an effect
    temperature=0.8,
    pad_token_id=tokenizer.eos_token_id,
)

# Wrap the pipeline so it can be used as a LangChain LLM.
llm = HuggingFacePipeline(pipeline=hf_pipeline)

# Load the dialogue corpus and split it into overlapping chunks.
loader = DirectoryLoader(".", glob="all_dialogues.txt")
docs = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200
)
texts = text_splitter.split_documents(docs)

# Embedding model shared by indexing and retrieval.
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Load a previously saved FAISS index (index.faiss / index.pkl) from disk,
# reusing the embedding model defined above.
db = FAISS.load_local(
    folder_path="./",
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
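
# If no saved index exists yet, it can be built from the split documents and
# persisted for reuse (a sketch; assumes the index should live next to this
# script, adjust folder_path as needed):
#
#     db = FAISS.from_documents(texts, embeddings)
#     db.save_local("./")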

# Dense (semantic) retriever over the FAISS index.
vector_retriever = db.as_retriever(search_kwargs={"k": 5})

# Sparse (keyword) retriever over the same chunks.
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 5


def ensemble_retriever(query):
    """Combine dense (FAISS) and sparse (BM25) results for a query."""
    vector_docs = vector_retriever.invoke(query)
    bm25_docs = bm25_retriever.invoke(query)
    # Concatenate both result lists; overlap between them simply repeats
    # the strongest context in the final prompt.
    combined_docs = vector_docs + bm25_docs
    return combined_docs


def respond_rag_huggingface(message: str):
    """Retrieve context for `message` and generate a grounded answer."""
    docs = ensemble_retriever(message)
    context = "\n\n".join(doc.page_content for doc in docs)

    prompt_template = ChatPromptTemplate.from_messages([
        ("system", "You are a Game of Thrones maester. Answer the given question strictly based on the provided context. If you do not know the answer, reply \"I don't know\" rather than guessing."),
        ("human", """Context: {context}

Question: {question}

Rules:
- If the answer isn't in the context, respond with "I don't know"
- Keep answers under 5 sentences
- Include book/season references when possible"""),
    ])

    chain = prompt_template | llm
    # HuggingFacePipeline returns the generated text as a plain string, so
    # the chain output is returned directly (it has no `.content` attribute).
    response = chain.invoke({"context": context, "question": message})
    return response


__all__ = ["respond_rag_huggingface"]
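

# Minimal smoke test when run directly (a sketch; assumes all_dialogues.txt
# and a saved FAISS index are present in the working directory, and the
# question is only an illustrative example):
if __name__ == "__main__":
    print(respond_rag_huggingface("Who is Jon Snow's mother?"))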