Spaces:
Paused
Paused
| import asyncio | |
| from typing import List, Dict, Optional, Set | |
| from ..context.compression import ContextCompressor, WrittenContentCompressor, VectorstoreCompressor | |
| from ..actions.utils import stream_output | |
class ContextManager:
    """Manages context for the researcher agent.

    Each public method optionally streams a progress message (when
    ``researcher.verbose`` is truthy) and then delegates the retrieval to one
    of the compressor classes imported at the top of this module.
    """

    def __init__(self, researcher):
        """Keep a reference to the researcher this manager works for.

        Args:
            researcher: Object whose ``verbose``, ``websocket``, ``memory``,
                ``vector_store`` and ``add_costs`` attributes are read by the
                methods below.
        """
        self.researcher = researcher

    async def get_similar_content_by_query(self, query, pages):
        """Return context for ``query`` drawn from ``pages``.

        Args:
            query: The query forwarded to the compressor.
            pages: Documents handed to ``ContextCompressor`` as ``documents``.

        Returns:
            Whatever ``ContextCompressor.async_get_context`` returns when asked
            for at most 10 results.
        """
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_query_content",
                f"π Getting relevant content based on query: {query}...",
                self.researcher.websocket,
            )

        context_compressor = ContextCompressor(
            documents=pages, embeddings=self.researcher.memory.get_embeddings()
        )
        return await context_compressor.async_get_context(
            query=query, max_results=10, cost_callback=self.researcher.add_costs
        )

    async def get_similar_content_by_query_with_vectorstore(self, query, filter):
        """Return context for ``query`` from the researcher's vector store.

        Args:
            query: The query forwarded to the compressor.
            filter: Passed through unchanged as the second positional argument
                of ``VectorstoreCompressor``. The name shadows the builtin but
                is kept because it is part of the public signature.

        Returns:
            Whatever ``VectorstoreCompressor.async_get_context`` returns when
            asked for at most 8 results.
        """
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_query_format",
                f" Getting relevant content based on query: {query}...",
                self.researcher.websocket,
            )

        vectorstore_compressor = VectorstoreCompressor(self.researcher.vector_store, filter)
        return await vectorstore_compressor.async_get_context(query=query, max_results=8)

    async def get_similar_written_contents_by_draft_section_titles(
        self,
        current_subtopic: str,
        draft_section_titles: List[str],
        written_contents: List[Dict],
        max_results: int = 10
    ) -> List[str]:
        """Collect written contents relevant to a subtopic and its draft titles.

        One lookup is run concurrently for ``current_subtopic`` and for every
        entry of ``draft_section_titles``. The per-query results are merged
        with duplicates removed and then truncated to ``max_results``.

        The merge keeps a stable order: results of the current subtopic come
        first, followed by those of each draft title in the given order, and
        within one query the order returned by the lookup is kept. This makes
        the truncation deterministic; merging through a ``set`` would leave
        both the order and the surviving items arbitrary.

        Args:
            current_subtopic: Subtopic used as the first query.
            draft_section_titles: Additional queries, one per draft title.
            written_contents: Previously written contents to search in.
            max_results: Upper bound on the number of returned items.

        Returns:
            De-duplicated relevant contents, at most ``max_results`` long.
        """
        all_queries = [current_subtopic] + draft_section_titles

        # asyncio.gather returns results in the order of the awaitables, so
        # ``results[i]`` corresponds to ``all_queries[i]``.
        results = await asyncio.gather(
            *(
                self.__get_similar_written_contents_by_query(query, written_contents)
                for query in all_queries
            )
        )

        # dict preserves insertion order, giving an order-stable de-duplication.
        unique_contents = dict.fromkeys(
            content for per_query in results for content in per_query
        )
        relevant_contents = list(unique_contents)[:max_results]

        if relevant_contents and self.researcher.verbose:
            prettier_contents = "\n".join(relevant_contents)
            await stream_output(
                "logs", "relevant_contents_context", f"π {prettier_contents}", self.researcher.websocket
            )

        return relevant_contents

    async def __get_similar_written_contents_by_query(self,
        query: str,
        written_contents: List[Dict],
        similarity_threshold: float = 0.5,
        max_results: int = 10
    ) -> List[str]:
        """Look up written contents similar to a single ``query``.

        Args:
            query: The query forwarded to the compressor.
            written_contents: Documents handed to ``WrittenContentCompressor``.
            similarity_threshold: Forwarded to ``WrittenContentCompressor``.
            max_results: Forwarded to ``async_get_context``.

        Returns:
            Whatever ``WrittenContentCompressor.async_get_context`` returns.
        """
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_relevant_written_content",
                f"π Getting relevant written content based on query: {query}...",
                self.researcher.websocket,
            )

        written_content_compressor = WrittenContentCompressor(
            documents=written_contents,
            embeddings=self.researcher.memory.get_embeddings(),
            similarity_threshold=similarity_threshold
        )
        return await written_content_compressor.async_get_context(
            query=query, max_results=max_results, cost_callback=self.researcher.add_costs
        )