import sys

from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import DensePassageRetriever, Seq2SeqGenerator
from haystack.pipelines import DocumentSearchPipeline, GenerativeQAPipeline
from haystack.utils import (
    clean_wiki_text,
    convert_files_to_docs,
    fetch_archive_from_http,
    print_answers,
    print_documents,
)
# %% Save/Load FAISS and embeddings
# To try out the snippet below, first delete any previously saved document store,
# including faiss_document_store.db, which FAISSDocumentStore saves and loads by default.
# # Convert files to documents
# dicts = convert_files_to_docs(dir_path=doc_dir, clean_func=clean_wiki_text, split_paragraphs=True)[:10]
# document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", vector_dim=128)
# # document_store = FAISSDocumentStore(sql_url= "sqlite:///faiss_document_store.db")
# retriever = EmbeddingRetriever(document_store=document_store,
#                                embedding_model="yjernite/retribert-base-uncased",
#                                model_format="retribert",
#                                use_gpu=False)
# # Now, let's write the documents to our DB.
# document_store.write_documents(dicts)
# document_store.update_embeddings(retriever)
# document_store.save("my_faiss_index.faiss")
# new_document_store= FAISSDocumentStore.load("my_faiss_index.faiss")
# # new_document_store = FAISSDocumentStore.load(faiss_file_path="testfile_path", sql_url= "sqlite:///faiss_document_store.db")
# %% ------------------------------------------------------------------------------------------------------------
def prepare(index_path="faiss_index.faiss", use_gpu=False):
    """Build a generative QA pipeline on top of a previously saved FAISS index.

    Loads a ``FAISSDocumentStore`` from ``index_path`` and attaches two components:
    a ``DensePassageRetriever`` that uses the ``vblagoje`` LFQA question/passage
    encoders, and a ``Seq2SeqGenerator`` backed by ``vblagoje/bart_lfqa``
    (https://huggingface.co/vblagoje/bart_lfqa). Both are combined into a
    ``GenerativeQAPipeline``.

    Args:
        index_path: Path of the saved FAISS index, passed to
            ``FAISSDocumentStore.load``. Defaults to ``"faiss_index.faiss"``.
        use_gpu: Forwarded to both the retriever and the generator.
            Defaults to ``False`` (CPU), the previously hard-coded behavior.

    Returns:
        The assembled ``GenerativeQAPipeline``.
    """
    # %% Document Store
    document_store = FAISSDocumentStore.load(index_path)

    # %% Retriever (DPR)
    # NOTE(review): presumably these must be the same encoders that produced the
    # embeddings stored in the saved index -- confirm against the indexing script.
    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
        passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
        use_gpu=use_gpu,
    )

    # %% Generator
    generator = Seq2SeqGenerator(
        model_name_or_path="vblagoje/bart_lfqa",
        use_gpu=use_gpu,
    )

    # %% Pipeline
    return GenerativeQAPipeline(generator, retriever)


def answer(pipe, question, k_retriever=3):
    """Ask ``question`` through ``pipe`` and hand back the pipeline's raw result.

    Args:
        pipe: Pipeline object exposing ``run(query, params=...)``, e.g. the one
            returned by ``prepare()``.
        question: The natural-language query string.
        k_retriever: ``top_k`` value sent to the node named ``"Retriever"``.

    Returns:
        Whatever ``pipe.run`` returns, unchanged.
    """
    retriever_settings = {"top_k": k_retriever}
    node_params = {"Retriever": retriever_settings}
    return pipe.run(question, params=node_params)


if __name__ == '__main__':
    # Take the question from the command line if one is given; otherwise fall
    # back to the original demo question so `python script.py` behaves as before.
    question = ' '.join(sys.argv[1:]).strip() or 'Tell me something about Arya Stark?'
    pipe = prepare()
    res = answer(pipe, question)
    print_answers(res, details="minimum")