Spaces:
Running
Running
| """ | |
| Context Enrichment Module for Medical RAG | |
| This module enriches retrieved documents with surrounding context (adjacent pages) | |
| to provide comprehensive information for expert medical professionals. | |
| """ | |
| from typing import List, Dict, Set, Optional | |
| from langchain.schema import Document | |
| from pathlib import Path | |
| from .config import logger | |
class ContextEnricher:
    """
    Enriches retrieved documents with surrounding pages for richer context.

    For each retrieved chunk, the pages immediately before/after it (from the
    same source document) are looked up in the chunk cache and returned as
    separate context documents alongside the original page.
    """

    def __init__(self, cache_size: int = 100):
        """
        Initialize context enricher with document cache.

        Args:
            cache_size: Maximum number of source documents to cache
        """
        # Maps "provider_disease_source" -> that document's pages, sorted by page.
        self._document_cache: Dict[str, List[Document]] = {}
        self._cache_size = cache_size
        # All chunks, loaded lazily once to avoid redundant reloading.
        self._all_chunks_cache: Optional[List[Document]] = None

    @staticmethod
    def _page_to_int(value, default: int = 0) -> int:
        """
        Best-effort conversion of a page-number metadata value to int.

        Malformed metadata (non-numeric strings, None, etc.) falls back to
        *default* instead of raising, so enrichment never crashes on one
        badly-tagged chunk.
        """
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def enrich_documents(
        self,
        retrieved_docs: List[Document],
        pages_before: int = 1,
        pages_after: int = 1,
        max_enriched_docs: int = 5
    ) -> List[Document]:
        """
        Enrich retrieved documents by adding separate context pages.

        Args:
            retrieved_docs: List of retrieved documents
            pages_before: Number of pages to include before each document
            pages_after: Number of pages to include after each document
            max_enriched_docs: Maximum number of documents to enrich (top results)

        Returns:
            List with original documents + separate context page documents
        """
        if not retrieved_docs:
            return []

        result_docs = []
        processed_sources = set()
        enriched_count = 0

        # Only enrich top documents to avoid overwhelming context.
        docs_to_enrich = retrieved_docs[:max_enriched_docs]

        for doc in docs_to_enrich:
            try:
                source = doc.metadata.get('source', 'unknown')
                page_num = doc.metadata.get('page_number', 1)

                # Skip if this source/page combination was already expanded.
                source_page_key = f"{source}_{page_num}"
                if source_page_key in processed_sources:
                    continue
                processed_sources.add(source_page_key)

                surrounding_docs = self._get_surrounding_pages(
                    doc,
                    pages_before,
                    pages_after
                )

                if surrounding_docs:
                    # Emit one Document per page (context + original).
                    page_docs = self._create_separate_page_documents(
                        doc,
                        surrounding_docs,
                        pages_before,
                        pages_after
                    )
                    result_docs.extend(page_docs)
                    enriched_count += 1

                    page_numbers = [self._page_to_int(d.metadata.get('page_number', 0)) for d in page_docs]
                    logger.debug(f"Enriched {source} page {page_num} with pages: {page_numbers}")
                else:
                    # No surrounding pages found; keep the original with default
                    # enrichment metadata so the downstream schema stays stable.
                    original_with_metadata = self._add_empty_enrichment_metadata(doc)
                    result_docs.append(original_with_metadata)

            except Exception as e:
                # Enrichment is best-effort: never drop a retrieved document.
                logger.warning(f"Could not enrich document from {doc.metadata.get('source')}: {e}")
                original_with_metadata = self._add_empty_enrichment_metadata(doc)
                result_docs.append(original_with_metadata)

        # Remaining (lower-ranked) documents pass through unenriched.
        for doc in retrieved_docs[max_enriched_docs:]:
            original_with_metadata = self._add_empty_enrichment_metadata(doc)
            result_docs.append(original_with_metadata)

        logger.info(f"Enriched {enriched_count} documents with surrounding context pages")
        return result_docs

    def _get_surrounding_pages(
        self,
        doc: Document,
        pages_before: int,
        pages_after: int
    ) -> List[Document]:
        """
        Get surrounding pages for a document.

        Args:
            doc: Original document
            pages_before: Number of pages before
            pages_after: Number of pages after

        Returns:
            List of surrounding documents (including original), deduplicated
            by page number and sorted ascending.
        """
        source = doc.metadata.get('source', 'unknown')
        page_num = doc.metadata.get('page_number', 1)
        provider = doc.metadata.get('provider', 'unknown')
        disease = doc.metadata.get('disease', 'unknown')

        # Try to get full document from cache or load it.
        full_doc_pages = self._get_full_document(source, provider, disease)
        if not full_doc_pages:
            return []

        # Fix: the previous bare int(page_num) raised on non-numeric strings;
        # fall back to page 1 so enrichment proceeds instead of aborting.
        target_page = self._page_to_int(page_num, default=1)

        # Use a dict to deduplicate by page number (keep first occurrence).
        pages_dict = {}
        for page_doc in full_doc_pages:
            raw_page = page_doc.metadata.get('page_number', 0)
            try:
                doc_page_num = int(raw_page)
            except (TypeError, ValueError):
                # Fix: was a bare `except:`. Unparseable page metadata ->
                # skip this chunk (same behavior, narrower catch).
                continue

            # Include pages within the requested window around the target.
            if target_page - pages_before <= doc_page_num <= target_page + pages_after:
                if doc_page_num not in pages_dict:
                    pages_dict[doc_page_num] = page_doc

        # Return sorted by page number.
        return [pages_dict[pn] for pn in sorted(pages_dict)]

    def _get_full_document(
        self,
        source: str,
        provider: str,
        disease: str
    ) -> Optional[List[Document]]:
        """
        Get full document pages from the chunks cache.

        Args:
            source: Source filename
            provider: Provider name
            disease: Disease name

        Returns:
            List of all pages in the document, or None if not found
        """
        cache_key = f"{provider}_{disease}_{source}"

        # Check per-document cache first.
        if cache_key in self._document_cache:
            return self._document_cache[cache_key]

        # Load from chunks cache instead of trying to reload PDFs.
        try:
            from . import utils

            # Load all chunks once; reuse the cached list on later calls.
            if self._all_chunks_cache is None:
                self._all_chunks_cache = utils.load_chunks()
                if self._all_chunks_cache:
                    logger.debug(f"Loaded {len(self._all_chunks_cache)} chunks into enricher cache")

            all_chunks = self._all_chunks_cache
            if not all_chunks:
                logger.debug("No chunks available for enrichment")
                return None

            # Keep only chunks belonging to this exact document
            # (matched by source, provider, and disease).
            doc_pages = [
                chunk for chunk in all_chunks
                if chunk.metadata.get('source', '') == source
                and chunk.metadata.get('provider', '') == provider
                and chunk.metadata.get('disease', '') == disease
            ]

            if not doc_pages:
                logger.debug(f"Could not find chunks for document: {source} (Provider: {provider}, Disease: {disease})")
                return None

            # Sort by page number (robust to malformed metadata).
            doc_pages.sort(key=lambda d: self._page_to_int(d.metadata.get('page_number', 0)))

            # FIFO eviction keeps the cache bounded at self._cache_size.
            if len(self._document_cache) >= self._cache_size:
                self._document_cache.pop(next(iter(self._document_cache)))
            self._document_cache[cache_key] = doc_pages

            logger.debug(f"Loaded {len(doc_pages)} pages for {source} from chunks cache")
            return doc_pages

        except Exception as e:
            logger.warning(f"Error loading document from chunks cache {source}: {e}")
            return None

    def _create_separate_page_documents(
        self,
        original_doc: Document,
        surrounding_docs: List[Document],
        pages_before: int,
        pages_after: int
    ) -> List[Document]:
        """
        Create separate document objects for the original page and context pages.

        Args:
            original_doc: Original retrieved document
            surrounding_docs: List of surrounding documents
            pages_before: Number of pages before (kept for interface compatibility)
            pages_after: Number of pages after (kept for interface compatibility)

        Returns:
            One Document per page, sorted by page number; context pages are
            flagged via metadata['context_enrichment'] = True.
        """
        # Sort by page number so context pages bracket the original in order.
        sorted_docs = sorted(
            surrounding_docs,
            key=lambda d: self._page_to_int(d.metadata.get('page_number', 0))
        )
        original_page = self._page_to_int(original_doc.metadata.get('page_number', 1), default=1)

        result_docs = []
        for doc in sorted_docs:
            page_num = self._page_to_int(doc.metadata.get('page_number', 0))
            # Any page other than the retrieved one is a context page.
            is_context_page = (page_num != original_page)

            page_doc = Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    'context_enrichment': is_context_page,
                    'enriched': False,
                    'pages_included': [],
                    'primary_page': None,
                    'context_pages_before': None,
                    'context_pages_after': None,
                }
            )
            result_docs.append(page_doc)

        return result_docs

    def _add_empty_enrichment_metadata(self, doc: Document) -> Document:
        """
        Return a copy of *doc* with default (empty) enrichment metadata fields.

        Args:
            doc: Original document

        Returns:
            Document with enrichment metadata fields set to default values
        """
        return Document(
            page_content=doc.page_content,
            metadata={
                **doc.metadata,
                'enriched': False,
                'pages_included': [],
                'primary_page': None,
                'context_pages_before': None,
                'context_pages_after': None,
            }
        )
# Global enricher instance: module-level singleton shared by the
# convenience functions below so the document/chunk caches persist
# across calls.
_context_enricher = ContextEnricher(cache_size=100)
def enrich_retrieved_documents(
    documents: List[Document],
    pages_before: int = 1,
    pages_after: int = 1,
    max_enriched: int = 5
) -> List[Document]:
    """
    Convenience wrapper that enriches documents via the shared module-level
    ContextEnricher singleton.

    Args:
        documents: Retrieved documents
        pages_before: Number of pages to include before each document
        pages_after: Number of pages to include after each document
        max_enriched: Maximum number of documents to enrich

    Returns:
        Enriched documents with surrounding context
    """
    enricher = _context_enricher
    return enricher.enrich_documents(
        documents,
        pages_before=pages_before,
        pages_after=pages_after,
        max_enriched_docs=max_enriched,
    )
def get_context_enricher() -> ContextEnricher:
    """Return the shared module-level ContextEnricher singleton."""
    return _context_enricher