import threading
from concurrent.futures import ThreadPoolExecutor

from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document

from . import utils
from .config import logger
from .tracing import traceable

# Global configuration for retrieval parameters.
# Raise these values to retrieve more context per query (more complete
# answers) at the cost of latency.
DEFAULT_K_VECTOR = 1  # Number of documents to retrieve from vector search
DEFAULT_K_BM25 = 1    # Number of documents to retrieve from BM25 search

# Global variables for lazy loading
_vector_store = None
_chunks = None
_vector_retriever = None
_bm25_retriever = None
_hybrid_retriever = None
_initialized = False
_init_lock = threading.Lock()  # Guards one-time initialization across threads


def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup)."""
    global _vector_store, _chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized

    if _initialized:
        return

    with _init_lock:
        # Re-check under the lock: another thread (e.g. an eager background
        # loader) may have finished initialization while this one was waiting.
        if _initialized:
            return

        logger.info("🔄 Initializing retrievers (first time use)...")

        # Process any new data and update the vector store and chunks cache.
        try:
            logger.info("🔄 Processing new data and updating vector store if needed...")
            _vector_store = utils.process_new_data_and_update_vector_store()
            if _vector_store is None:
                # Fall back to loading an existing store if processing found no new files.
                _vector_store = utils.load_vector_store()
            if _vector_store is None:
                # As a last resort, create a store from whatever is already cached (if anything).
                logger.info("ℹ️ No vector store found; attempting creation from cached chunks...")
                cached_chunks = utils.load_chunks() or []
                if cached_chunks:
                    _vector_store = utils.create_vector_store(cached_chunks)
                    logger.info("✅ Vector store created from cached chunks")
                else:
                    logger.warning(
                        "⚠️ No data available to build a vector store. "
                        "Retrievers may not function until data is provided."
                    )
        except Exception as e:
            logger.error(f"Error preparing vector store: {str(e)}")
            raise

        # Load merged chunks for BM25 (includes previous + new).
        try:
            logger.info("📦 Loading chunks cache for BM25 retriever...")
            _chunks = utils.load_chunks() or []
            if not _chunks:
                logger.warning("⚠️ No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
        except Exception as e:
            logger.error(f"Error loading chunks: {str(e)}")
            raise

        # Create vector retriever
        logger.info("🔍 Creating vector retriever...")
        _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None

        # Create BM25 retriever
        logger.info("📝 Creating BM25 retriever...")
        _bm25_retriever = BM25Retriever.from_documents(_chunks) if _chunks else None
        if _bm25_retriever:
            _bm25_retriever.k = 5

        # Create hybrid retriever (BM25 weighted 0.2, vector 0.8)
        logger.info("🔄 Creating hybrid retriever...")
        if _vector_retriever and _bm25_retriever:
            _hybrid_retriever = EnsembleRetriever(
                retrievers=[_bm25_retriever, _vector_retriever],
                weights=[0.2, 0.8],
            )
        elif _vector_retriever:
            logger.warning("ℹ️ BM25 retriever unavailable; using vector retriever only.")
            _hybrid_retriever = _vector_retriever
        elif _bm25_retriever:
            _hybrid_retriever = _bm25_retriever
        else:
            raise RuntimeError(
                "Neither the vector nor the BM25 retriever could be initialized. "
                "Provide data under data/new_data and retry."
            )

        _initialized = True
        logger.info("✅ Retrievers initialized successfully.")


def initialize_eagerly():
    """Force initialization of retrievers (e.g. from a background thread at startup)."""
    _ensure_initialized()


def is_initialized() -> bool:
    """Check whether the retrievers have already been initialized."""
    return _initialized


# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
_retrieval_pool = ThreadPoolExecutor(max_workers=4)


def _get_doc_id(doc: Document) -> str:
    """Generate a unique identifier for a document.

    Note: ``hash()`` on strings is salted per interpreter run, so these IDs
    are stable only within a single process -- sufficient for the in-memory
    deduplication done in ``hybrid_search``.
    """
    source = doc.metadata.get("source", "unknown")
    page = doc.metadata.get("page_number", "unknown")
    content_hash = hash(doc.page_content[:200])  # Hash of the first 200 chars
    return f"{source}_{page}_{content_hash}"


def _match_provider(doc: Document, provider: str) -> bool:
    """Return True if the document's provider metadata matches (or no filter is given)."""
    if not provider:
        return True
    prov = str(doc.metadata.get("provider", "")).strip().lower()
    return prov == provider.strip().lower()


@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: int | None = None):
    """Search the FAISS vector store with an optional provider metadata filter."""
    _ensure_initialized()
    if not _vector_store:
        return []
    # Use the global default if k is not specified.
    if k is None:
        k = DEFAULT_K_VECTOR
    try:
        if provider:
            docs = _vector_store.similarity_search(query, k=k, filter={"provider": provider})
        else:
            docs = _vector_store.similarity_search(query, k=k)
        # Post-filter by provider in case the backend filter is lenient.
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []


@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: int | None = None):
    """Search BM25 using the global retriever with an optional provider filter."""
    _ensure_initialized()
    # Use the global default if k is not specified.
    if k is None:
        k = DEFAULT_K_BM25
    try:
        if not _bm25_retriever:
            return []
        # Note: k is mutated on the shared module-level retriever; concurrent
        # calls with different k values may interleave.
        _bm25_retriever.k = max(1, k)
        docs = _bm25_retriever.invoke(query) or []
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []


def hybrid_search(query: str, provider: str | None = None,
                  k_vector: int | None = None, k_bm25: int | None = None):
    """Combine vector and BM25 results (provider-filtered if a provider is given)."""
    # Ensure retrievers are initialized before submitting parallel searches.
    _ensure_initialized()

    # Use the global defaults if limits are not specified.
    if k_vector is None:
        k_vector = DEFAULT_K_VECTOR
    if k_bm25 is None:
        k_bm25 = DEFAULT_K_BM25

    # Run both searches concurrently on the shared thread pool.
    f_vector = _retrieval_pool.submit(vector_search, query, provider, k_vector)
    f_bm25 = _retrieval_pool.submit(bm25_search, query, provider, k_bm25)
    v_docs = f_vector.result()
    b_docs = f_bm25.result()

    # Merge uniquely by document ID, preserving vector-first ordering.
    seen = set()
    merged = []
    for d in v_docs + b_docs:
        doc_id = _get_doc_id(d)
        if doc_id not in seen:
            seen.add(doc_id)
            merged.append(d)

    logger.info(f"Hybrid search returned {len(merged)} unique documents")
    return merged
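

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API: assumes data has
    # already been placed under data/new_data (or cached chunks exist) and
    # that the module is run with package context, e.g.
    # ``python -m <package>.<module>``. The query string and the "openai"
    # provider tag are hypothetical examples; adjust to your own metadata.
    initialize_eagerly()
    results = hybrid_search(
        "example query",     # sample query text
        provider=None,       # or e.g. "openai" to filter by provider metadata
        k_vector=3,
        k_bm25=3,
    )
    for doc in results:
        print(doc.metadata.get("source", "unknown"), "->", doc.page_content[:80])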