hash-map committed on
Commit
3c1111f
·
verified ·
1 Parent(s): d554370

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +77 -66
rag.py CHANGED
@@ -1,66 +1,77 @@
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import Ollama
from langchain.retrievers import BM25Retriever, EnsembleRetriever

# Load the pre-built FAISS index from disk.
# allow_dangerous_deserialization is required because FAISS indexes are
# pickled; only load an index you created yourself (pickle can run code).
db = FAISS.load_local(
    folder_path="got_embeddings",
    embeddings=HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
    allow_dangerous_deserialization=True,
)

# NOTE(review): the original referenced an undefined name `texts` below,
# which raises NameError at import time. Rebuild the chunked corpus using
# the (previously unused) loader/splitter imports.
# TODO confirm the documents directory path and chunk sizes.
_loader = DirectoryLoader("got_docs")
texts = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100
).split_documents(_loader.load())

# Vector-store (dense) retriever: top 3 chunks by embedding similarity.
vector_retriever = db.as_retriever(search_kwargs={"k": 3})

# Keyword (sparse) retriever: top 2 chunks by BM25 score.
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 2

# Hybrid retrieval: blend dense and sparse results.
ensemble_retriever = EnsembleRetriever(
    retrievers=[vector_retriever, bm25_retriever],
    weights=[0.6, 0.4],  # tune based on your tests
)
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate


def ask_question(question, temparature=0.5, num_ctx=4096, top_k=40, temperature=None):
    """Answer a Game of Thrones question with hybrid RAG over Ollama.

    Parameters
    ----------
    question : str
        The user's question.
    temparature : float, default 0.5
        Sampling temperature. The misspelling is kept so existing keyword
        callers don't break; prefer the new ``temperature`` argument.
    num_ctx : int, default 4096
        Ollama context-window size in tokens.
    top_k : int, default 40
        Sampling top-k cutoff.
    temperature : float or None, default None
        Correctly spelled alias; when given, it overrides ``temparature``.

    Returns
    -------
    str
        The model's answer, grounded in the retrieved context.
    """
    # Backward-compatible fix for the misspelled parameter name.
    effective_temperature = temperature if temperature is not None else temparature

    # 1. Retrieve relevant context from the module-level hybrid retriever.
    docs = ensemble_retriever.get_relevant_documents(question)
    context = "\n\n".join(doc.page_content for doc in docs)

    # 2. Prompt template pinning the model to the retrieved context only.
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", "You are a Game of Thrones expert. Answer strictly based on the context."),
        ("human", """Context: {context}

Question: {question}

Rules:
- If answer isn't in context, say "I don't know"
- Keep answers under 5 sentences
- Include book/season references when possible"""),
    ])

    # 3. Configure Ollama with the requested decoding parameters.
    llm = Ollama(
        model="llama3:8b-instruct-q4_0",
        temperature=effective_temperature,
        num_ctx=num_ctx,
        top_k=top_k,
        repeat_penalty=1.1,
        stop=["<|eot_id|>"],  # llama3 end-of-turn token
    )

    # 4. Generate the response.
    chain = prompt_template | llm
    return chain.invoke({"context": context, "question": question})
 
 
 
 
 
 
 
 
 
 
 
 
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import Ollama
from langchain.retrievers import BM25Retriever, EnsembleRetriever

# Deduplicated: Ollama and ChatPromptTemplate were each imported twice
# back-to-back in the original; one import is sufficient.
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate

# Load the pre-built FAISS index from disk.
# allow_dangerous_deserialization is required because FAISS indexes are
# pickled; only load an index you created yourself (pickle can run code).
db = FAISS.load_local(
    folder_path="got_embeddings",
    embeddings=HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
    allow_dangerous_deserialization=True,
)

# NOTE(review): the original referenced an undefined name `texts` below,
# which raises NameError at import time. Rebuild the chunked corpus using
# the (previously unused) loader/splitter imports.
# TODO confirm the documents directory path and chunk sizes.
_loader = DirectoryLoader("got_docs")
texts = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100
).split_documents(_loader.load())

# Vector-store (dense) retriever: top 3 chunks by embedding similarity.
vector_retriever = db.as_retriever(search_kwargs={"k": 3})

# Keyword (sparse) retriever: top 2 chunks by BM25 score.
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 2

# Hybrid retrieval: blend dense and sparse results.
ensemble_retriever = EnsembleRetriever(
    retrievers=[vector_retriever, bm25_retriever],
    weights=[0.6, 0.4],  # tune based on your tests
)
def respond_rag_ollama(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    num_ctx: int = 2048,
    num_predict: int = 128,
    temperature: float = 0.8,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    stop: list[str] | None = None,
):
    """Stream a RAG-grounded answer from Ollama.

    Parameters
    ----------
    message : str
        The current user question.
    history : list[tuple[str, str]]
        Prior (user, assistant) turns. NOTE(review): accepted for a
        chat-UI callback signature but never used in the prompt —
        confirm whether it should be incorporated.
    system_message : str
        System prompt for the model.
    num_ctx, num_predict, temperature, top_k, repeat_penalty
        Ollama decoding parameters, passed straight through.
    stop : list[str] | None
        Stop sequences; defaults to the llama3 end-of-turn token.

    Yields
    ------
    str
        Incremental chunks of the generated answer.
    """
    # 1. Retrieve relevant context from the module-level hybrid retriever.
    docs = ensemble_retriever.get_relevant_documents(message)
    context = "\n\n".join(doc.page_content for doc in docs)

    # 2. Build the conversational prompt. Plain (non-f) string: {context}
    #    and {question} are template variables filled at invoke time.
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("human", """Context: {context}

Question: {question}

Rules:
- If the answer isn't in the context, respond with "I don't know"
- Keep answers under 5 sentences
- Include book/season references when possible"""),
    ])

    # 3. Configure the LLM. Fix: honor the caller-supplied `stop` list —
    #    the original accepted the parameter but always used the
    #    hardcoded token.
    llm = Ollama(
        model="llama3:8b-instruct-q4_0",
        temperature=temperature,
        num_ctx=num_ctx,
        num_predict=num_predict,
        top_k=top_k,
        repeat_penalty=repeat_penalty,
        stop=stop if stop is not None else ["<|eot_id|>"],
    )

    # 4. Stream the response. Fix: runnables have no `stream_invoke`
    #    method (AttributeError at runtime); the streaming API is
    #    `.stream()`.
    chain = prompt_template | llm
    yield from chain.stream({"context": context, "question": message})