from concurrent.futures import ThreadPoolExecutor

from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document

from . import utils
from .config import logger
from .tracing import traceable

# Global configuration for retrieval parameters.
# Raise these values for more comprehensive context at the cost of a larger prompt.
DEFAULT_K_VECTOR = 1  # Number of documents to retrieve from vector search
DEFAULT_K_BM25 = 1  # Number of documents to retrieve from BM25 search

# Global variables for lazy loading
_vector_store = None
_chunks = None
_vector_retriever = None
_bm25_retriever = None
_hybrid_retriever = None
_initialized = False

def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup)."""
    global _vector_store, _chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized
    if _initialized:
        return
    logger.info("Initializing retrievers (first-time use)...")

    # Process any new data and update the vector store and chunks cache.
    try:
        logger.info("Processing new data and updating vector store if needed...")
        _vector_store = utils.process_new_data_and_update_vector_store()
        if _vector_store is None:
            # Fall back to loading the existing store if processing found no new files.
            _vector_store = utils.load_vector_store()
        if _vector_store is None:
            # As a last resort, create one from whatever is already in the cache (if any).
            logger.info("No vector store found; attempting creation from cached chunks...")
            cached_chunks = utils.load_chunks() or []
            if cached_chunks:
                _vector_store = utils.create_vector_store(cached_chunks)
                logger.info("Vector store created from cached chunks")
            else:
                logger.warning("No data available to build a vector store. Retrievers may not function until data is provided.")
    except Exception as e:
        logger.error(f"Error preparing vector store: {str(e)}")
        raise

    # Load merged chunks (previous + new) for the BM25 retriever.
    try:
        logger.info("Loading chunks cache for BM25 retriever...")
        _chunks = utils.load_chunks() or []
        if not _chunks:
            logger.warning("No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
    except Exception as e:
        logger.error(f"Error loading chunks: {str(e)}")
        raise

    # Create vector retriever.
    logger.info("Creating vector retriever...")
    _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None

    # Create BM25 retriever.
    logger.info("Creating BM25 retriever...")
    _bm25_retriever = BM25Retriever.from_documents(_chunks) if _chunks else None
    if _bm25_retriever:
        _bm25_retriever.k = 5

    # Create the hybrid retriever, weighting vector results (0.8) over BM25 (0.2).
    logger.info("Creating hybrid retriever...")
    if _vector_retriever and _bm25_retriever:
        _hybrid_retriever = EnsembleRetriever(
            retrievers=[_bm25_retriever, _vector_retriever],
            weights=[0.2, 0.8]
        )
    elif _vector_retriever:
        logger.warning("BM25 retriever unavailable; using vector retriever only.")
        _hybrid_retriever = _vector_retriever
    elif _bm25_retriever:
        logger.warning("Vector retriever unavailable; using BM25 retriever only.")
        _hybrid_retriever = _bm25_retriever
    else:
        raise RuntimeError("Neither vector nor BM25 retriever could be initialized. Provide data under data/new_data and retry.")

    _initialized = True
    logger.info("Retrievers initialized successfully.")


def initialize_eagerly():
    """Force initialization of retrievers for background loading."""
    _ensure_initialized()


def is_initialized() -> bool:
    """Check whether retrievers are already initialized."""
    return _initialized

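
# A minimal usage sketch (hypothetical, not part of this module's API): call
# initialize_eagerly() from a daemon thread at application startup so the first
# user query does not pay the one-time loading cost. Adapt this to your
# framework's startup hook (e.g., a FastAPI lifespan handler).
#
#     import threading
#     threading.Thread(target=initialize_eagerly, daemon=True).start()
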
# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
_retrieval_pool = ThreadPoolExecutor(max_workers=4)

def _get_doc_id(doc: Document) -> str:
    """Generate a unique identifier for a document chunk."""
    source = doc.metadata.get('source', 'unknown')
    page = doc.metadata.get('page_number', 'unknown')
    # Hash the first 200 chars of content; built-in hash() is only stable
    # within a single process, which suffices for in-request deduplication.
    content_hash = hash(doc.page_content[:200])
    return f"{source}_{page}_{content_hash}"


def _match_provider(doc: Document, provider: str) -> bool:
    """Return True when the document's provider metadata matches (case-insensitive)."""
    if not provider:
        return True
    prov = str(doc.metadata.get("provider", "")).strip().lower()
    return prov == provider.strip().lower()


@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: int | None = None):
    """Search the FAISS vector store with an optional provider metadata filter."""
    _ensure_initialized()
    if not _vector_store:
        return []
    # Use the global default if k is not specified.
    if k is None:
        k = DEFAULT_K_VECTOR
    try:
        if provider:
            docs = _vector_store.similarity_search(query, k=k, filter={"provider": provider})
        else:
            docs = _vector_store.similarity_search(query, k=k)
        # Post-filter on provider in case the backend filter is lenient.
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []


@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: int | None = None):
    """Search BM25 using the global retriever with an optional provider filter."""
    _ensure_initialized()
    # Use the global default if k is not specified.
    if k is None:
        k = DEFAULT_K_BM25
    try:
        if not _bm25_retriever:
            return []
        # Note: this mutates the shared retriever's k before invoking it.
        _bm25_retriever.k = max(1, k)
        docs = _bm25_retriever.invoke(query) or []
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []


def hybrid_search(query: str, provider: str | None = None, k_vector: int | None = None, k_bm25: int | None = None):
    """Combine vector and BM25 results (provider-filtered if provided)."""
    _ensure_initialized()  # Ensure retrievers are initialized before parallel execution.
    # Use the global defaults if not specified.
    if k_vector is None:
        k_vector = DEFAULT_K_VECTOR
    if k_bm25 is None:
        k_bm25 = DEFAULT_K_BM25
    # Run both searches in parallel on the shared thread pool.
    f_vector = _retrieval_pool.submit(vector_search, query, provider, k_vector)
    f_bm25 = _retrieval_pool.submit(bm25_search, query, provider, k_bm25)
    v_docs = f_vector.result()
    b_docs = f_bm25.result()
    # Merge uniquely by document ID, keeping vector results first.
    seen = set()
    merged = []
    for d in v_docs + b_docs:
        doc_id = _get_doc_id(d)
        if doc_id not in seen:
            seen.add(doc_id)
            merged.append(d)
    logger.info(f"Hybrid search returned {len(merged)} unique documents")
    return merged
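

# A minimal usage sketch (hypothetical; the query string and provider value are
# illustrative, not part of the module). Run it from the package root, e.g.
# `python -m <your_package>.retrieval`, assuming data has already been ingested.
if __name__ == "__main__":
    results = hybrid_search("How do I configure billing alerts?", provider="aws")
    for doc in results:
        source = doc.metadata.get("source", "unknown")
        page = doc.metadata.get("page_number", "unknown")
        print(f"[{source} p.{page}] {doc.page_content[:120]}...")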