HBV_AI_Assistant/core/retrievers.py
from concurrent.futures import ThreadPoolExecutor
from . import utils
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document
from .config import logger
from .tracing import traceable
# Global configuration for retrieval parameters. These are per-call defaults;
# callers can pass a larger k for broader context at the cost of prompt size.
DEFAULT_K_VECTOR = 1  # Number of documents to retrieve from vector search
DEFAULT_K_BM25 = 1  # Number of documents to retrieve from BM25 search
# Global variables for lazy loading
_vector_store = None
_chunks = None
_vector_retriever = None
_bm25_retriever = None
_hybrid_retriever = None
_initialized = False
def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup)."""
    global _vector_store, _chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized
    if _initialized:
        return

    logger.info("πŸ”„ Initializing retrievers (first time use)...")

    # Process any new data and update the vector store and chunks cache
    try:
        logger.info("πŸ”„ Processing new data and updating vector store if needed...")
        _vector_store = utils.process_new_data_and_update_vector_store()
        if _vector_store is None:
            # Fall back to loading the existing store if processing found no new files
            _vector_store = utils.load_vector_store()
        if _vector_store is None:
            # As a last resort, create one from whatever is already in the cache (if anything)
            logger.info("ℹ️ No vector store found; attempting creation from cached chunks...")
            cached_chunks = utils.load_chunks() or []
            if cached_chunks:
                _vector_store = utils.create_vector_store(cached_chunks)
                logger.info("βœ… Vector store created from cached chunks")
            else:
                logger.warning("⚠️ No data available to build a vector store. Retrievers may not function until data is provided.")
    except Exception as e:
        logger.error(f"Error preparing vector store: {str(e)}")
        raise

    # Load merged chunks for BM25 (includes previous + new)
    try:
        logger.info("πŸ“¦ Loading chunks cache for BM25 retriever...")
        _chunks = utils.load_chunks() or []
        if not _chunks:
            logger.warning("⚠️ No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
    except Exception as e:
        logger.error(f"Error loading chunks: {str(e)}")
        raise

    # Create vector retriever
    logger.info("πŸ” Creating vector retriever...")
    _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None

    # Create BM25 retriever
    logger.info("πŸ“ Creating BM25 retriever...")
    _bm25_retriever = BM25Retriever.from_documents(_chunks) if _chunks else None
    if _bm25_retriever:
        _bm25_retriever.k = 5

    # Create hybrid retriever; lexical BM25 matches are down-weighted (0.2)
    # relative to dense vector similarity (0.8) when the rankings are fused
    logger.info("πŸ”„ Creating hybrid retriever...")
    if _vector_retriever and _bm25_retriever:
        _hybrid_retriever = EnsembleRetriever(
            retrievers=[_bm25_retriever, _vector_retriever],
            weights=[0.2, 0.8],
        )
    elif _vector_retriever:
        logger.warning("ℹ️ BM25 retriever unavailable; using vector retriever only.")
        _hybrid_retriever = _vector_retriever
    elif _bm25_retriever:
        _hybrid_retriever = _bm25_retriever
    else:
        raise RuntimeError(
            "Neither the vector nor the BM25 retriever could be initialized. "
            "Provide data under data/new_data and retry."
        )

    _initialized = True
    logger.info("βœ… Retrievers initialized successfully.")
def initialize_eagerly():
    """Force initialization of retrievers for background loading."""
    _ensure_initialized()


def is_initialized() -> bool:
    """Check whether retrievers are already initialized."""
    return _initialized
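
# Illustrative warm-up sketch (an assumption about the host app, not part of
# this module's flow): the lazy path above pays the full load cost on the
# first query, so a server can hide that latency by warming the retrievers on
# a daemon thread while it boots, e.g.:
#
#     import threading
#     threading.Thread(target=initialize_eagerly, daemon=True).start()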
# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
_retrieval_pool = ThreadPoolExecutor(max_workers=4)
def _get_doc_id(doc: Document) -> str:
    """Generate a unique identifier for a document."""
    source = doc.metadata.get('source', 'unknown')
    page = doc.metadata.get('page_number', 'unknown')
    # Note: hash() is salted per process, so IDs are stable within a run but
    # not across runs; that is sufficient for in-memory deduplication.
    content_hash = hash(doc.page_content[:200])  # Hash the first 200 chars
    return f"{source}_{page}_{content_hash}"
def _match_provider(doc, provider: str) -> bool:
    if not provider:
        return True
    prov = str(doc.metadata.get("provider", "")).strip().lower()
    return prov == provider.strip().lower()
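
# Matching is case- and whitespace-insensitive: a document tagged
# {"provider": " WHO "} (an assumed tag; actual values depend on how chunks
# were labeled at ingestion) matches provider="who", while an empty or None
# provider argument matches every document.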
@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: int | None = None):
    """Search the FAISS vector store with an optional provider metadata filter."""
    _ensure_initialized()
    if not _vector_store:
        return []
    # Use the global default if k is not specified
    if k is None:
        k = DEFAULT_K_VECTOR
    try:
        # Standard search, with a metadata filter when a provider is given
        if provider:
            docs = _vector_store.similarity_search(query, k=k, filter={"provider": provider})
        else:
            docs = _vector_store.similarity_search(query, k=k)
        # Post-filter by provider in case the backend filter is lenient
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []
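
# Illustrative usage (the provider tag "who" is an assumption about how chunks
# were labeled during ingestion, not a value defined in this module):
#
#     docs = vector_search("tenofovir dosing in chronic HBV", provider="who", k=3)
#     for d in docs:
#         print(d.metadata.get("source"), d.metadata.get("page_number"))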
@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: int | None = None):
    """Search BM25 using the global retriever with an optional provider filter."""
    _ensure_initialized()
    # Use the global default if k is not specified
    if k is None:
        k = DEFAULT_K_BM25
    try:
        if not _bm25_retriever:
            return []
        # Standard search. Note: this mutates the shared retriever's k, so
        # concurrent calls with different k values may race.
        _bm25_retriever.k = max(1, k)
        docs = _bm25_retriever.invoke(query) or []
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []
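
# The same call shape works here; BM25 favors exact keyword overlap, so a
# query like bm25_search("HBsAg seroconversion", k=3) (illustrative only)
# tends to surface chunks containing that literal term.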
def hybrid_search(query: str, provider: str | None = None, k_vector: int | None = None, k_bm25: int | None = None):
    """Combine vector and BM25 results (provider-filtered if a provider is given)."""
    _ensure_initialized()  # Ensure retrievers are initialized before parallel execution
    # Use the global defaults if not specified
    if k_vector is None:
        k_vector = DEFAULT_K_VECTOR
    if k_bm25 is None:
        k_bm25 = DEFAULT_K_BM25

    # Run both searches in parallel on the shared thread pool
    f_vector = _retrieval_pool.submit(vector_search, query, provider, k_vector)
    f_bm25 = _retrieval_pool.submit(bm25_search, query, provider, k_bm25)
    v_docs = f_vector.result()
    b_docs = f_bm25.result()

    # Merge uniquely by document ID, keeping vector results first
    seen = set()
    merged = []
    for d in v_docs + b_docs:
        doc_id = _get_doc_id(d)
        if doc_id not in seen:
            seen.add(doc_id)
            merged.append(d)
    logger.info(f"Hybrid search returned {len(merged)} unique documents")
    return merged
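

# Minimal smoke test, assuming the vector store / chunks cache described above
# is already populated; the query text is illustrative only.
if __name__ == "__main__":
    results = hybrid_search("first-line treatment options for chronic hepatitis B")
    for doc in results:
        print(f"- {doc.metadata.get('source', 'unknown')} "
              f"(page {doc.metadata.get('page_number', '?')})")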