hash-map committed on
Commit
5fd30d3
·
verified ·
1 Parent(s): f008f3f

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +32 -20
rag.py CHANGED
@@ -5,6 +5,34 @@ from langchain_community.retrievers import BM25Retriever
5
  from langchain_community.llms import Ollama
6
  from langchain_text_splitters import RecursiveCharacterTextSplitter
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  loader = DirectoryLoader('.', glob="all_dialogues.txt")
10
  docs = loader.load()
@@ -47,19 +75,13 @@ def ensemble_retriever(query):
47
  from langchain_community.llms import HuggingFaceHub
48
  from langchain_core.prompts import ChatPromptTemplate
49
 
50
- def respond_rag_huggingface(
51
- message: str,
52
- system_message: str = " you are game of thrones measter answer the given question strictly based on the context provived.if u donot know the answer reply i dont know donot give gibberish answers",
53
- num_predict: int = 128,
54
- temperature: float = 0.8,
55
- ):
56
- # 1. Retrieve context
57
  docs = ensemble_retriever(message)
58
  context = "\n\n".join(doc.page_content for doc in docs)
59
 
60
- # 2. Prompt
61
  prompt_template = ChatPromptTemplate.from_messages([
62
- ("system", system_message),
63
  ("human", """Context: {context}
64
 
65
  Question: {question}
@@ -70,21 +92,11 @@ def respond_rag_huggingface(
70
  - Include book/season references when possible""")
71
  ])
72
 
73
- # 3. HuggingFace LLM (e.g., use `HuggingFaceH4/zephyr-7b-beta`)
74
- llm = HuggingFaceHub(
75
- repo_id="mistralai/Mistral-7B-Instruct-v0.1",
76
- model_kwargs={
77
- "temperature": temperature,
78
- "max_new_tokens": num_predict
79
- }
80
- )
81
-
82
- # 4. Run chain
83
  chain = prompt_template | llm
84
  response = chain.invoke({"context": context, "question": message})
85
-
86
  return response.content
87
 
 
88
  __all__ = ["respond_rag_huggingface"]
89
  # def respond_rag_ollama(
90
  # message: str,
 
5
  from langchain_community.llms import Ollama
6
  from langchain_text_splitters import RecursiveCharacterTextSplitter
7
 
8
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
9
+ from langchain_community.llms import HuggingFacePipeline
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+ from langchain_community.document_loaders import DirectoryLoader
12
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings
14
+ from langchain_community.vectorstores import FAISS
15
+ from langchain_community.retrievers import BM25Retriever
16
+
17
+ # Load Zephyr model
18
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
19
+ model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
20
+
21
+ # Create HF pipeline
22
+ hf_pipeline = pipeline(
23
+ "text-generation",
24
+ model=model,
25
+ tokenizer=tokenizer,
26
+ max_new_tokens=128,
27
+ temperature=0.8,
28
+ pad_token_id=tokenizer.eos_token_id,
29
+ )
30
+
31
+ # Wrap in LangChain LLM
32
+ llm = HuggingFacePipeline(pipeline=hf_pipeline)
33
+
34
+ # Define your RAG response function
35
+
36
 
37
  loader = DirectoryLoader('.', glob="all_dialogues.txt")
38
  docs = loader.load()
 
75
  from langchain_community.llms import HuggingFaceHub
76
  from langchain_core.prompts import ChatPromptTemplate
77
 
78
+
79
+ def respond_rag_huggingface(message: str):
 
 
 
 
 
80
  docs = ensemble_retriever(message)
81
  context = "\n\n".join(doc.page_content for doc in docs)
82
 
 
83
  prompt_template = ChatPromptTemplate.from_messages([
84
+ ("system", "you are game of thrones measter answer the given question strictly based on the context provived.if u donot know the answer reply i dont know donot give gibberish answers"),
85
  ("human", """Context: {context}
86
 
87
  Question: {question}
 
92
  - Include book/season references when possible""")
93
  ])
94
 
 
 
 
 
 
 
 
 
 
 
95
  chain = prompt_template | llm
96
  response = chain.invoke({"context": context, "question": message})
 
97
  return response.content
98
 
99
+
100
  __all__ = ["respond_rag_huggingface"]
101
  # def respond_rag_ollama(
102
  # message: str,