import os
import asyncio
from typing import Optional

from .retriever import SearchAPIRetriever, SectionRetriever

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter

from ..vector_store import VectorStoreWrapper
from ..utils.costs import estimate_embedding_cost
from ..memory.embeddings import OPENAI_EMBEDDING_MODEL


class VectorstoreCompressor:
    """Retrieves relevant context directly from a vector store."""

    def __init__(self, vector_store: VectorStoreWrapper, max_results: int = 7, filter: Optional[dict] = None, **kwargs):
        self.vector_store = vector_store
        self.max_results = max_results
        self.filter = filter
        self.kwargs = kwargs

    def __pretty_print_docs(self, docs):
        """Format retrieved documents as a single readable string."""
        return "\n".join(f"Source: {d.metadata.get('source')}\n"
                         f"Title: {d.metadata.get('title')}\n"
                         f"Content: {d.page_content}\n"
                         for d in docs)

    async def async_get_context(self, query, max_results=5):
        """Get relevant context from the vector store."""
        results = await self.vector_store.asimilarity_search(query=query, k=max_results, filter=self.filter)
        return self.__pretty_print_docs(results)
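
# Illustrative usage sketch (not part of the original module): assuming `store`
# is an initialized, populated VectorStoreWrapper, the compressor could be used
# like this inside an async function; the variable names and query are hypothetical:
#
#     compressor = VectorstoreCompressor(store, max_results=7)
#     context = await compressor.async_get_context("impact of rising rates", max_results=5)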


class ContextCompressor:
    """Compresses scraped web pages down to the chunks most relevant to a query."""

    def __init__(self, documents, embeddings, max_results=5, **kwargs):
        self.max_results = max_results
        self.documents = documents
        self.kwargs = kwargs
        self.embeddings = embeddings
        # The env var (when set) is a string, so cast it to float for EmbeddingsFilter.
        self.similarity_threshold = float(os.environ.get("SIMILARITY_THRESHOLD", 0.35))

    def __get_contextual_retriever(self):
        # Split pages into chunks, then keep only chunks above the similarity threshold.
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        relevance_filter = EmbeddingsFilter(embeddings=self.embeddings,
                                            similarity_threshold=self.similarity_threshold)
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, relevance_filter]
        )
        base_retriever = SearchAPIRetriever(
            pages=self.documents
        )
        contextual_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=base_retriever
        )
        return contextual_retriever

    def __pretty_print_docs(self, docs, top_n):
        """Format the top_n retrieved documents as a single readable string."""
        return "\n".join(f"Source: {d.metadata.get('source')}\n"
                         f"Title: {d.metadata.get('title')}\n"
                         f"Content: {d.page_content}\n"
                         for i, d in enumerate(docs) if i < top_n)

    async def async_get_context(self, query, max_results=5, cost_callback=None):
        """Retrieve and compress the context most relevant to the query."""
        compressed_docs = self.__get_contextual_retriever()
        if cost_callback:
            cost_callback(estimate_embedding_cost(model=OPENAI_EMBEDDING_MODEL, docs=self.documents))
        # The retriever call is synchronous, so run it in a thread to avoid blocking the event loop.
        relevant_docs = await asyncio.to_thread(compressed_docs.invoke, query)
        return self.__pretty_print_docs(relevant_docs, max_results)


class WrittenContentCompressor:
    """Compresses previously written report sections down to the chunks most relevant to a query."""

    def __init__(self, documents, embeddings, similarity_threshold, **kwargs):
        self.documents = documents
        self.kwargs = kwargs
        self.embeddings = embeddings
        self.similarity_threshold = similarity_threshold

    def __get_contextual_retriever(self):
        # Split sections into chunks, then keep only chunks above the similarity threshold.
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        relevance_filter = EmbeddingsFilter(embeddings=self.embeddings,
                                            similarity_threshold=self.similarity_threshold)
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, relevance_filter]
        )
        base_retriever = SectionRetriever(
            sections=self.documents
        )
        contextual_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=base_retriever
        )
        return contextual_retriever

    def __pretty_docs_list(self, docs, top_n):
        """Return the top_n retrieved sections as a list of formatted strings."""
        return [f"Title: {d.metadata.get('section_title')}\nContent: {d.page_content}\n"
                for i, d in enumerate(docs) if i < top_n]

    async def async_get_context(self, query, max_results=5, cost_callback=None):
        """Retrieve and compress the written content most relevant to the query."""
        compressed_docs = self.__get_contextual_retriever()
        if cost_callback:
            cost_callback(estimate_embedding_cost(model=OPENAI_EMBEDDING_MODEL, docs=self.documents))
        # The retriever call is synchronous, so run it in a thread to avoid blocking the event loop.
        relevant_docs = await asyncio.to_thread(compressed_docs.invoke, query)
        return self.__pretty_docs_list(relevant_docs, max_results)
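
# Illustrative usage sketch (not part of the original module): assuming `pages`
# is a list of scraped page documents compatible with SearchAPIRetriever and
# `embeddings` is a LangChain embeddings instance, a query could be compressed
# like this inside an async function; the variable names and query are hypothetical:
#
#     compressor = ContextCompressor(documents=pages, embeddings=embeddings)
#     context = await compressor.async_get_context("quantum error correction", max_results=5)
#
# WrittenContentCompressor works the same way but over draft report sections
# served by SectionRetriever, and returns a list of strings rather than one
# joined string.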