Spaces:
Running
Running
import hashlib
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document
from langchain_community.retrievers import BM25Retriever

from . import utils
from .config import logger
from .tracing import traceable
# Global configuration for retrieval parameters.
# NOTE(review): these defaults are 1 despite the old comment claiming they were
# "increased for more comprehensive context" — confirm the intended values.
DEFAULT_K_VECTOR = 1  # Number of documents to retrieve from vector search
DEFAULT_K_BM25 = 1  # Number of documents to retrieve from BM25 search

# Global variables for lazy loading (populated by _ensure_initialized on first use)
_vector_store = None        # vector store returned by utils.* loaders/builders
_chunks = None              # cached document chunks used to build the BM25 retriever
_vector_retriever = None    # retriever view over _vector_store (k=5)
_bm25_retriever = None      # BM25Retriever built from _chunks (k=5)
_hybrid_retriever = None    # EnsembleRetriever over both, or whichever one exists
_initialized = False        # guard so the setup in _ensure_initialized runs only once
def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup).

    Performs, in order:
      1. Vector store preparation: process any new data; fall back to loading
         an existing store; as a last resort build one from cached chunks.
      2. Load the merged chunk cache for BM25.
      3. Build the vector, BM25 and hybrid (ensemble) retrievers.

    Idempotent: returns immediately once ``_initialized`` is set.

    Raises:
        RuntimeError: if neither a vector nor a BM25 retriever could be built.
        Exception: errors from the underlying utils are logged and re-raised.
    """
    global _vector_store, _chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized
    if _initialized:
        return
    # NOTE(review): not guarded by a lock; two threads racing on the first call
    # could both run the setup below — confirm callers serialize initialization.
    logger.info("Initializing retrievers (first time use)...")

    # Process any new data and update vector store and chunks cache
    try:
        logger.info("Processing new data and updating vector store if needed...")
        _vector_store = utils.process_new_data_and_update_vector_store()
        if _vector_store is None:
            # Fall back to load existing if processing found no new files
            _vector_store = utils.load_vector_store()
        if _vector_store is None:
            # As a last resort, create from whatever is already in cache (if any)
            logger.info("No vector store found; attempting creation from cached chunks...")
            cached_chunks = utils.load_chunks() or []
            if cached_chunks:
                _vector_store = utils.create_vector_store(cached_chunks)
                logger.info("Vector store created from cached chunks")
            else:
                logger.warning("No data available to build a vector store. Retrievers may not function until data is provided.")
    except Exception as e:
        logger.error(f"Error preparing vector store: {str(e)}")
        raise

    # Load merged chunks for BM25 (includes previous + new)
    try:
        logger.info("Loading chunks cache for BM25 retriever...")
        _chunks = utils.load_chunks() or []
        if not _chunks:
            logger.warning("No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
    except Exception as e:
        logger.error(f"Error loading chunks: {str(e)}")
        raise

    # Create vector retriever
    logger.info("Creating vector retriever...")
    _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None

    # Create BM25 retriever
    logger.info("Creating BM25 retriever...")
    _bm25_retriever = BM25Retriever.from_documents(_chunks) if _chunks else None
    if _bm25_retriever:
        _bm25_retriever.k = 5

    # Create hybrid retriever: BM25 weighted 0.2, dense vector search 0.8
    logger.info("Creating hybrid retriever...")
    if _vector_retriever and _bm25_retriever:
        _hybrid_retriever = EnsembleRetriever(
            retrievers=[_bm25_retriever, _vector_retriever],
            weights=[0.2, 0.8]
        )
    elif _vector_retriever:
        logger.warning("BM25 retriever unavailable; using vector retriever only.")
        _hybrid_retriever = _vector_retriever
    elif _bm25_retriever:
        _hybrid_retriever = _bm25_retriever
    else:
        raise RuntimeError("Neither vector nor BM25 retrievers could be initialized. Provide data under data/new_data and retry.")

    _initialized = True
    logger.info("Retrievers initialized successfully.")
def initialize_eagerly():
    """Force initialization of retrievers (e.g. from a background loader at startup).

    Delegates to _ensure_initialized(), which is a no-op after the first
    successful run, so calling this multiple times is safe.
    """
    _ensure_initialized()
def is_initialized() -> bool:
    """Return True once the retrievers have been fully initialized."""
    return _initialized
# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------

# Shared pool used by hybrid_search to run vector and BM25 searches in parallel.
# Module lifetime; never shut down explicitly.
_retrieval_pool = ThreadPoolExecutor(max_workers=4)
def _get_doc_id(doc: Document) -> str:
    """Generate a stable unique identifier for a document.

    Combines the document's source, page number, and a digest of the first
    200 characters of its content, so duplicate chunks from the same page
    deduplicate while distinct content yields distinct IDs.

    Args:
        doc: Document whose metadata may carry 'source' and 'page_number'.

    Returns:
        A string of the form "<source>_<page>_<digest>".
    """
    source = doc.metadata.get('source', 'unknown')
    page = doc.metadata.get('page_number', 'unknown')
    # md5 hexdigest is deterministic across processes and runs, unlike the
    # builtin hash(), which is randomized per run by PYTHONHASHSEED.
    content_hash = hashlib.md5(doc.page_content[:200].encode("utf-8")).hexdigest()
    return f"{source}_{page}_{content_hash}"
| def _match_provider(doc, provider: str) -> bool: | |
| if not provider: | |
| return True | |
| prov = str(doc.metadata.get("provider", "")).strip().lower() | |
| return prov == provider.strip().lower() | |
def vector_search(query: str, provider: str | None = None, k: int | None = None):
    """Search the FAISS vector store with an optional provider metadata filter.

    Args:
        query: Natural-language search query.
        provider: Optional provider name; when given, results are restricted
            to documents whose 'provider' metadata matches (case-insensitive).
        k: Number of documents to return; defaults to DEFAULT_K_VECTOR.

    Returns:
        A list of matching Documents (empty on error or when no store exists).
    """
    _ensure_initialized()
    if not _vector_store:
        return []
    # Use global default if k is not specified
    if k is None:
        k = DEFAULT_K_VECTOR
    try:
        if provider:
            docs = _vector_store.similarity_search(query, k=k, filter={"provider": provider})
        else:
            docs = _vector_store.similarity_search(query, k=k)
        # Post-filter by provider in case the backend filter is lenient
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []
def bm25_search(query: str, provider: str | None = None, k: int | None = None):
    """Search BM25 using the global retriever with an optional provider filter.

    Args:
        query: Natural-language search query.
        provider: Optional provider name; when given, results are restricted
            to documents whose 'provider' metadata matches (case-insensitive).
        k: Number of documents to return; defaults to DEFAULT_K_BM25.

    Returns:
        A list of up to k matching Documents (empty on error or no retriever).
    """
    _ensure_initialized()
    # Use global default if k is not specified
    if k is None:
        k = DEFAULT_K_BM25
    try:
        if not _bm25_retriever:
            return []
        # NOTE(review): mutating the shared retriever's k is not thread-safe;
        # concurrent searches with different k values could race — confirm
        # whether concurrent callers ever pass differing k.
        _bm25_retriever.k = max(1, k)
        docs = _bm25_retriever.invoke(query) or []
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        # Trim in case the provider filter left more than k or k was clamped
        return docs[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []
def hybrid_search(query: str, provider: str | None = None, k_vector: int | None = None, k_bm25: int | None = None):
    """Combine vector and BM25 results (provider-filtered if provided).

    Runs both searches in parallel on the shared retrieval thread pool and
    merges the results, deduplicating by document ID (vector results first).

    Args:
        query: Natural-language search query.
        provider: Optional provider name used to filter both result sets.
        k_vector: Vector result count; defaults to DEFAULT_K_VECTOR.
        k_bm25: BM25 result count; defaults to DEFAULT_K_BM25.

    Returns:
        A list of unique Documents from both retrievers.
    """
    # Initialize up front so the two pool tasks don't race on first-time setup
    _ensure_initialized()
    # Use global defaults if not specified
    if k_vector is None:
        k_vector = DEFAULT_K_VECTOR
    if k_bm25 is None:
        k_bm25 = DEFAULT_K_BM25

    f_vector = _retrieval_pool.submit(vector_search, query, provider, k_vector)
    f_bm25 = _retrieval_pool.submit(bm25_search, query, provider, k_bm25)
    v_docs = f_vector.result()
    b_docs = f_bm25.result()

    # Merge uniquely by document ID, preserving order (vector hits first)
    seen = set()
    merged = []
    for d in v_docs + b_docs:
        doc_id = _get_doc_id(d)
        if doc_id not in seen:
            seen.add(doc_id)
            merged.append(d)
    logger.info(f"Hybrid search returned {len(merged)} unique documents")
    return merged