from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import Ollama

# Load the dialogue corpus from the current directory
loader = DirectoryLoader(".", glob="all_dialogues.txt")
docs = loader.load()

# Split documents into overlapping chunks so retrieval returns focused passages
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200
)
texts = text_splitter.split_documents(docs)

# Sentence-transformers model used to embed each chunk
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
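
# The next step loads a FAISS index from disk, so the index has to exist
# already. A minimal one-time build sketch, assuming it hasn't been saved yet
# (FAISS.from_documents embeds every chunk; save_local writes
# index.faiss / index.pkl into the folder that load_local reads below):
import os

if not os.path.exists("index.faiss"):
    FAISS.from_documents(texts, embeddings).save_local("./")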

# Load the persisted FAISS index, reusing the embedding model from above
db = FAISS.load_local(
    folder_path="./",
    embeddings=embeddings,
    # Pickle deserialization: only enable for index files you created yourself
    allow_dangerous_deserialization=True,
)

from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever

# Dense retriever backed by the FAISS index (top 3 chunks by vector similarity)
vector_retriever = db.as_retriever(search_kwargs={"k": 3})

# Sparse keyword (BM25) retriever over the same chunks; requires the rank_bm25 package
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 2

# Hybrid retrieval: fuse dense and sparse rankings, weighted 60/40
ensemble_retriever = EnsembleRetriever(
    retrievers=[vector_retriever, bm25_retriever],
    weights=[0.6, 0.4],
)
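
# Quick smoke test of the hybrid retriever (the query string is just an
# illustrative placeholder; .invoke returns the fused list of Documents):
for doc in ensemble_retriever.invoke("example question about the dialogues"):
    print(doc.page_content[:80])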

from langchain_core.prompts import ChatPromptTemplate

def respond_rag_ollama(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    num_ctx: int = 2048,
    num_predict: int = 128,
    temperature: float = 0.8,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    stop: list[str] | None = None,
):
    """Stream a retrieval-grounded answer for `message`.

    The (message, history, ...) signature follows the convention of chat UIs
    such as Gradio's ChatInterface; `history` is accepted but not folded into
    the prompt here.
    """
    # Retrieve supporting chunks and flatten them into a single context string
    docs = ensemble_retriever.invoke(message)
    context = "\n\n".join(doc.page_content for doc in docs)

    # {context} and {question} are prompt variables filled in when the chain runs
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("human", """Context: {context}

Question: {question}

Rules:
- If the answer isn't in the context, respond with "I don't know"
- Keep answers under 5 sentences
- Include book/season references when possible"""),
    ])

    # Local Llama 3 served by Ollama; generation knobs come from the arguments
    llm = Ollama(
        model="llama3:8b-instruct-q4_0",
        temperature=temperature,
        num_ctx=num_ctx,
        num_predict=num_predict,
        top_k=top_k,
        repeat_penalty=repeat_penalty,
        # Honor caller-supplied stop sequences, defaulting to Llama 3's end-of-turn token
        stop=stop or ["<|eot_id|>"],
    )

    # LCEL pipeline: fill the prompt, then stream tokens from the model.
    # Accumulate and yield the running text so a streaming chat UI can redraw it.
    chain = prompt_template | llm
    partial = ""
    for token in chain.stream({"context": context, "question": message}):
        partial += token
        yield partial
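
# The generator above drops straight into a streaming chat front end. A
# minimal wiring sketch, assuming Gradio (not used elsewhere in this script);
# additional_inputs are passed positionally after (message, history), and the
# remaining parameters keep their defaults:
import gradio as gr

demo = gr.ChatInterface(
    respond_rag_ollama,
    additional_inputs=[
        gr.Textbox("Answer using only the provided context.", label="System message"),
        gr.Slider(512, 8192, value=2048, step=512, label="num_ctx"),
    ],
)
demo.launch()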