Spaces: Build error
```python
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
    PyPDFLoader,
    DataFrameLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import pandas as pd
import threading
import glob
import os
import queue


class SearchableIndex:
    def __init__(self, path):
        self.path = path

    def get_text_splits(self):
        # Split a plain-text file into fixed-size chunks.
        with open(self.path, 'r') as txt:
            data = txt.read()
        text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                    chunk_overlap=0,
                                                    length_function=len)
        doc_list = text_split.split_text(data)
        return doc_list

    def get_pdf_splits(self):
        # Load a PDF page by page and split each page into chunks.
        loader = PyPDFLoader(self.path)
        pages = loader.load_and_split()
        text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                    chunk_overlap=0,
                                                    length_function=len)
        doc_list = []
        for pg in pages:
            pg_splits = text_split.split_text(pg.page_content)
            doc_list.extend(pg_splits)
        return doc_list

    def get_xml_splits(self, target_col, sheet_name):
        # Reads an Excel sheet and wraps the target column as Documents.
        df = pd.read_excel(io=self.path,
                           engine='openpyxl',
                           sheet_name=sheet_name)
        df_loader = DataFrameLoader(df,
                                    page_content_column=target_col)
        excel_docs = df_loader.load()
        return excel_docs

    def get_csv_splits(self):
        csv_loader = CSVLoader(self.path)
        csv_docs = csv_loader.load()
        return csv_docs

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            logger.info("Merge index completed")
            local_db.save_local(index_store)
            return local_db
        else:
            faiss_db.save_local(folder_path=index_store)
            logger.info("New store created and loaded...")
            local_db = FAISS.load_local(index_store, embeddings)
            return local_db

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, path, result_queue):
        if index_files:
            local_db = FAISS.load_local(index_files[0], embeddings)
            file_to_remove = os.path.join(path, 'combined_content.txt')
            if os.path.exists(file_to_remove):
                os.remove(file_to_remove)
        else:
            # No index store found: log a warning and hand back None so the
            # caller waiting on the queue is not blocked forever.
            logger.warning("Index store does not exist")
            local_db = None
        result_queue.put(local_db)  # Put the result in the queue

    @classmethod
    def embed_index(cls, url, path, target_col=None, sheet_name=None):
        embeddings = OpenAIEmbeddings()

        def process_docs(queues, extension):
            # Pick the loader that matches the file extension and hand the
            # splits back through the queue.
            instance = cls(path)
            if extension == ".txt":
                doc_list = instance.get_text_splits()
            elif extension == ".pdf":
                doc_list = instance.get_pdf_splits()
            elif extension == ".xml":
                doc_list = instance.get_xml_splits(target_col, sheet_name)
            elif extension == ".csv":
                doc_list = instance.get_csv_splits()
            else:
                doc_list = None
            queues.put(doc_list)

        if url != 'NO_URL' and path:
            file_extension = os.path.splitext(path)[1].lower()
            data_queue = queue.Queue()
            thread = threading.Thread(target=process_docs, args=(data_queue, file_extension))
            thread.start()
            doc_list = data_queue.get()
            if not doc_list:
                raise ValueError("Unsupported file format")
            # Text/PDF splits are plain strings; Excel/CSV loaders return Documents.
            if isinstance(doc_list[0], str):
                faiss_db = FAISS.from_texts(doc_list, embeddings)
            else:
                faiss_db = FAISS.from_documents(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return local_db, index_store
        elif url == 'NO_URL' and path:
            index_files = glob.glob(os.path.join(path, '*_index'))
            result_queue = queue.Queue()  # Create a queue to store the result
            thread = threading.Thread(target=cls.check_and_load_index,
                                      args=(index_files, embeddings, logger, path, result_queue))
            thread.start()
            local_db = result_queue.get()  # Retrieve the result from the queue
            return local_db

    @classmethod
    def query(cls, question: str, llm, index):
        """Query the vectorstore."""
        llm = llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        chain = RetrievalQA.from_chain_type(
            llm, retriever=index.as_retriever()
        )
        return chain.run(question)


if __name__ == '__main__':
    pass
    # Examples for search query
    # index = SearchableIndex.embed_index(
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents/combined_content.txt")
    # prompt = 'show more detail about types of data collected'
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # result = SearchableIndex.query(prompt, llm=llm, index=index)
    # print(result)
```
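For reference, here is a minimal sketch of how the commented example at the bottom might be run end to end. It is not a tested snippet: it assumes the class above is saved as `searchable_index.py` (a hypothetical module name), that `OPENAI_API_KEY` is set in the environment, and that the `.txt` path is a placeholder; any `url` value other than `'NO_URL'` takes the index-building branch of `embed_index`.

```python
# Minimal usage sketch. Assumptions: the class above lives in
# searchable_index.py (hypothetical name), OPENAI_API_KEY is set, and the
# .txt path below is a placeholder for a real file.
from langchain.chat_models import ChatOpenAI

from searchable_index import SearchableIndex  # hypothetical module name

# Any url other than 'NO_URL' makes embed_index build a fresh FAISS index
# from the file at `path`; it returns (vectorstore, index_store_path).
index, index_store = SearchableIndex.embed_index(
    url="HAS_URL",                                   # placeholder sentinel
    path="learning_documents/combined_content.txt",  # placeholder path
)

llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
answer = SearchableIndex.query(
    'show more detail about types of data collected', llm=llm, index=index
)
print(answer)
```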