from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.tools.retriever import create_retriever_tool


class RetrieverManager:
    """Build retrievers (vector, ensemble, reranked) on top of one vector store."""

    def __init__(self, vector_store):
        """Store the vector store that every retriever built here relies on.

        Args:
            vector_store: Object exposing ``as_retriever(search_type=...,
                search_kwargs=...)``.
        """
        self.vector_store = vector_store

    def create_base_retriever(self, search_type: str = "similarity", k: int = 3):
        """Create a basic vector store retriever.

        Args:
            search_type: Search strategy forwarded to ``as_retriever``.
            k: Number of results requested, forwarded as ``search_kwargs["k"]``.

        Returns:
            Whatever ``self.vector_store.as_retriever`` returns.
        """
        return self.vector_store.as_retriever(
            search_type=search_type,
            search_kwargs={"k": k},
        )

    def create_ensemble_retriever(
        self,
        texts,
        k: int = 3,
        vector_weight: float = 0.5,
        keyword_weight: float = 0.5,
    ) -> EnsembleRetriever:
        """Create an ensemble retriever combining vector and keyword search.

        Args:
            texts: Documents used to build the BM25 keyword index.
            k: Number of results each underlying retriever is asked for.
            vector_weight: Ensemble weight of the vector store retriever.
            keyword_weight: Ensemble weight of the BM25 retriever.

        Returns:
            An ``EnsembleRetriever`` over the vector and BM25 retrievers, in
            that order, weighted by ``vector_weight`` and ``keyword_weight``.

        Raises:
            ValueError: If ``texts`` contains no documents.
        """
        # Materialize first so one-shot iterables can be checked for emptiness
        # and still be handed to BM25 afterwards.
        documents = list(texts)
        if not documents:
            # A BM25 index cannot be built from an empty corpus; fail here with
            # a clear message instead of an opaque error from deeper down.
            raise ValueError(
                "Cannot build a keyword retriever from an empty document collection."
            )

        vector_retriever = self.create_base_retriever(k=k)
        keyword_retriever = BM25Retriever.from_documents(documents)
        keyword_retriever.k = k

        return EnsembleRetriever(
            retrievers=[vector_retriever, keyword_retriever],
            weights=[vector_weight, keyword_weight],
        )

    def create_compression_retriever(
        self, base_retriever, top_n: int
    ) -> ContextualCompressionRetriever:
        """Create a compression retriever with reranking.

        Args:
            base_retriever: Retriever whose results are passed to the reranker.
            top_n: Value forwarded to ``FlashrankRerank(top_n=...)``.

        Returns:
            A ``ContextualCompressionRetriever`` wrapping ``base_retriever``
            with a ``FlashrankRerank`` compressor.
        """
        compressor = FlashrankRerank(top_n=top_n)
        return ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base_retriever,
        )

    def create_retriever(
        self,
        documents,
        top_n: int,
        k: int = 3,
    ):
        """Build the full retrieval pipeline and expose it as a tool.

        Chains an ensemble (vector + keyword) retriever with a reranking
        compression retriever, then wraps the result with
        ``create_retriever_tool`` under the name ``"retrieve_docs"``.

        Args:
            documents: Documents used to build the keyword index.
            top_n: Value forwarded to the reranker.
            k: Number of results each underlying retriever is asked for.

        Returns:
            The tool object produced by ``create_retriever_tool``.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        base_retriever = self.create_ensemble_retriever(texts=documents, k=k)
        compression_retriever = self.create_compression_retriever(
            base_retriever=base_retriever,
            top_n=top_n,
        )
        retriever_tool = create_retriever_tool(
            compression_retriever,
            "retrieve_docs",
            "use tools for search through the user's provided documents and return relevant information about user query.",
        )
        return retriever_tool