Spaces:
Sleeping
Sleeping
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain.retrievers import ContextualCompressionRetriever | |
| from langchain.retrievers.document_compressors import FlashrankRerank | |
| from langchain.tools.retriever import create_retriever_tool | |
class RetrieverManager:
    """Build LangChain retrievers (and a retriever tool) on top of a vector store.

    Three stages can be composed:

    1. plain vector-store retrieval (``create_base_retriever``),
    2. a hybrid ensemble of vector search and BM25 keyword search
       (``create_ensemble_retriever``),
    3. a FlashRank reranking layer over any base retriever
       (``create_compression_retriever``).

    ``create_retriever`` chains all three and wraps the result as an agent tool.
    """

    # Defaults used by ``create_retriever`` for the generated tool; kept
    # byte-identical to the values this class has always used.
    DEFAULT_TOOL_NAME = "retrieve_docs"
    DEFAULT_TOOL_DESCRIPTION = (
        "use tools for search through the user's provided documents "
        "and return relevant information about user query."
    )

    def __init__(self, vector_store):
        # Expected to expose ``as_retriever(search_type=..., search_kwargs=...)``,
        # e.g. a LangChain VectorStore.
        self.vector_store = vector_store

    def create_base_retriever(self, search_type: str = "similarity", k: int = 3):
        """Create a basic vector store retriever.

        Args:
            search_type: Search mode forwarded to ``vector_store.as_retriever``.
            k: Number of documents the retriever returns per query.

        Returns:
            Whatever ``self.vector_store.as_retriever`` returns.
        """
        return self.vector_store.as_retriever(
            search_type=search_type,
            search_kwargs={"k": k},
        )

    def create_ensemble_retriever(
        self,
        texts,
        k: int = 3,
        vector_weight: float = 0.5,
        keyword_weight: float = 0.5,
        *,
        search_type: str = "similarity",
    ):
        """Create an ensemble retriever combining vector and keyword (BM25) search.

        Args:
            texts: Iterable of LangChain ``Document`` objects used to build the
                in-memory BM25 keyword index.
            k: Number of documents each sub-retriever returns per query.
            vector_weight: Weight given to the vector-store retriever's results.
            keyword_weight: Weight given to the BM25 retriever's results.
            search_type: Search mode for the vector-store retriever.

        Returns:
            An ``EnsembleRetriever`` over ``[vector_retriever, bm25_retriever]``.

        Raises:
            ValueError: If ``texts`` is empty. BM25 cannot index an empty corpus,
                and the library otherwise fails with an opaque unpacking error.
        """
        # Materialise the input so that an empty generator is detected as well.
        texts = list(texts)
        if not texts:
            raise ValueError(
                "texts must contain at least one document to build the BM25 index"
            )

        vector_retriever = self.create_base_retriever(search_type=search_type, k=k)
        keyword_retriever = BM25Retriever.from_documents(texts)
        keyword_retriever.k = k
        return EnsembleRetriever(
            retrievers=[vector_retriever, keyword_retriever],
            weights=[vector_weight, keyword_weight],
        )

    def create_compression_retriever(self, base_retriever, top_n: int):
        """Wrap ``base_retriever`` with FlashRank reranking.

        Args:
            base_retriever: Retriever whose results are reranked.
            top_n: Number of documents kept after reranking; must be >= 1.

        Returns:
            A ``ContextualCompressionRetriever`` using ``FlashrankRerank``.

        Raises:
            ValueError: If ``top_n`` is less than 1. Such a reranker would
                always return no documents.
        """
        if top_n < 1:
            raise ValueError(f"top_n must be >= 1, got {top_n}")
        compressor = FlashrankRerank(top_n=top_n)
        return ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base_retriever,
        )

    def create_retriever(
        self,
        documents,
        top_n: int,
        k: int = 3,
        *,
        vector_weight: float = 0.5,
        keyword_weight: float = 0.5,
        tool_name: str = DEFAULT_TOOL_NAME,
        tool_description: str = DEFAULT_TOOL_DESCRIPTION,
    ):
        """Build the full retrieval pipeline and expose it as an agent tool.

        The pipeline is a hybrid ensemble (vector + BM25) over ``documents``,
        followed by FlashRank reranking that keeps ``top_n`` results.

        Args:
            documents: LangChain ``Document`` objects for the BM25 index.
            top_n: Number of documents kept after reranking.
            k: Number of documents each ensemble member retrieves.
            vector_weight: Ensemble weight of the vector-store retriever.
            keyword_weight: Ensemble weight of the BM25 retriever.
            tool_name: Name given to the generated tool.
            tool_description: Description given to the generated tool.

        Returns:
            The tool produced by ``create_retriever_tool``. Despite this method's
            name, the return value is a tool, not a bare retriever.

        Raises:
            ValueError: If ``documents`` is empty or ``top_n`` < 1.
        """
        base_retriever = self.create_ensemble_retriever(
            texts=documents,
            k=k,
            vector_weight=vector_weight,
            keyword_weight=keyword_weight,
        )
        compression_retriever = self.create_compression_retriever(
            base_retriever=base_retriever,
            top_n=top_n,
        )
        return create_retriever_tool(
            compression_retriever,
            tool_name,
            tool_description,
        )