Spaces:
Runtime error
Runtime error
Update edubot.py
Browse files
edubot.py
CHANGED
|
@@ -4,7 +4,7 @@ from langchain.vectorstores import FAISS
|
|
| 4 |
from langchain.llms import CTransformers
|
| 5 |
from langchain.chains import RetrievalQA
|
| 6 |
from config import *
|
| 7 |
-
|
| 8 |
class EduBotCreator:
|
| 9 |
|
| 10 |
def __init__(self):
|
|
@@ -23,7 +23,7 @@ class EduBotCreator:
|
|
| 23 |
custom_prompt_temp = PromptTemplate(template=self.prompt_temp,
|
| 24 |
input_variables=self.input_variables)
|
| 25 |
return custom_prompt_temp
|
| 26 |
-
|
| 27 |
def load_llm(self):
|
| 28 |
llm = CTransformers(
|
| 29 |
model = self.model_ckpt,
|
|
@@ -32,7 +32,7 @@ class EduBotCreator:
|
|
| 32 |
temperature = self.temperature
|
| 33 |
)
|
| 34 |
return llm
|
| 35 |
-
|
| 36 |
def load_vectordb(self):
|
| 37 |
hfembeddings = HuggingFaceEmbeddings(
|
| 38 |
model_name=self.embedder,
|
|
@@ -42,7 +42,6 @@ class EduBotCreator:
|
|
| 42 |
vector_db = FAISS.load_local(self.vector_db_path, hfembeddings)
|
| 43 |
return vector_db
|
| 44 |
|
| 45 |
-
|
| 46 |
def create_bot(self, custom_prompt, vectordb, llm):
|
| 47 |
retrieval_qa_chain = RetrievalQA.from_chain_type(
|
| 48 |
llm=llm,
|
|
|
|
| 4 |
from langchain.llms import CTransformers
|
| 5 |
from langchain.chains import RetrievalQA
|
| 6 |
from config import *
|
| 7 |
+
|
| 8 |
class EduBotCreator:
|
| 9 |
|
| 10 |
def __init__(self):
|
|
|
|
| 23 |
custom_prompt_temp = PromptTemplate(template=self.prompt_temp,
|
| 24 |
input_variables=self.input_variables)
|
| 25 |
return custom_prompt_temp
|
| 26 |
+
|
| 27 |
def load_llm(self):
|
| 28 |
llm = CTransformers(
|
| 29 |
model = self.model_ckpt,
|
|
|
|
| 32 |
temperature = self.temperature
|
| 33 |
)
|
| 34 |
return llm
|
| 35 |
+
|
| 36 |
def load_vectordb(self):
|
| 37 |
hfembeddings = HuggingFaceEmbeddings(
|
| 38 |
model_name=self.embedder,
|
|
|
|
| 42 |
vector_db = FAISS.load_local(self.vector_db_path, hfembeddings)
|
| 43 |
return vector_db
|
| 44 |
|
|
|
|
| 45 |
def create_bot(self, custom_prompt, vectordb, llm):
|
| 46 |
retrieval_qa_chain = RetrievalQA.from_chain_type(
|
| 47 |
llm=llm,
|