# File size: 7,254 Bytes
# 73c6377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document
from langchain_community.retrievers import BM25Retriever

from . import utils
from .config import logger
from .tracing import traceable

# Global configuration for retrieval parameters.
# NOTE(review): both defaults retrieve only one document per retriever, so a
# default hybrid_search() returns at most two documents — confirm this is the
# intended setting.
DEFAULT_K_VECTOR = 1  # Default k used by vector_search() when the caller passes none
DEFAULT_K_BM25 = 1     # Default k used by bm25_search() when the caller passes none

# Module-level state, filled in lazily by _ensure_initialized() on first use
_vector_store = None  # vector store obtained via utils (None until built / when no data)
_chunks = None  # chunk list from utils.load_chunks(); the BM25 retriever is built from it
_vector_retriever = None  # retriever view of _vector_store
_bm25_retriever = None  # BM25Retriever built from _chunks
_hybrid_retriever = None  # ensemble of both retrievers, or whichever one is available
_initialized = False  # True once initialization has completed

# Serializes first-time initialization so that concurrent callers (for example
# a background initialize_eagerly() call and an incoming query) do not both run
# the expensive vector-store processing and retriever construction.
_init_lock = threading.Lock()


def _prepare_vector_store():
    """Return the vector store to search, or None when no data is available.

    Tries, in order: processing new data, loading an existing store, and
    building one from the cached chunks. Any exception raised by ``utils`` is
    logged and re-raised.
    """
    try:
        logger.info("🔄 Processing new data and updating vector store if needed...")
        vector_store = utils.process_new_data_and_update_vector_store()
        if vector_store is None:
            # Fall back to load existing if processing found no new files
            vector_store = utils.load_vector_store()
        if vector_store is None:
            # As a last resort, create from whatever is already in cache (if any)
            logger.info("ℹ️ No vector store found; attempting creation from cached chunks...")
            cached_chunks = utils.load_chunks() or []
            if cached_chunks:
                vector_store = utils.create_vector_store(cached_chunks)
                logger.info("✅ Vector store created from cached chunks")
            else:
                logger.warning("⚠️ No data available to build a vector store. Retrievers may not function until data is provided.")
        return vector_store
    except Exception as e:
        logger.error(f"Error preparing vector store: {str(e)}")
        raise


def _load_bm25_chunks():
    """Return the cached chunks for the BM25 retriever (an empty list if none).

    Any exception raised by ``utils.load_chunks`` is logged and re-raised.
    """
    try:
        logger.info("📦 Loading chunks cache for BM25 retriever...")
        chunks = utils.load_chunks() or []
        if not chunks:
            logger.warning("⚠️ No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
        return chunks
    except Exception as e:
        logger.error(f"Error loading chunks: {str(e)}")
        raise


def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup).

    Builds the vector store, the chunk list, the vector and BM25 retrievers,
    and the hybrid retriever, then publishes them in the module globals.
    Subsequent calls return immediately. Initialization runs under
    ``_init_lock``, so concurrent first calls build everything only once.

    Raises:
        RuntimeError: if neither a vector nor a BM25 retriever could be built.
        Exception: any error from the ``utils`` data-loading helpers is
            logged and re-raised; the module stays uninitialized, so the next
            call retries.
    """
    global _vector_store, _chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized

    # Fast path: once initialized, no locking is needed.
    if _initialized:
        return

    with _init_lock:
        # Re-check: another thread may have finished initializing while this
        # one was waiting for the lock.
        if _initialized:
            return

        logger.info("🔄 Initializing retrievers (first time use)...")

        vector_store = _prepare_vector_store()
        # Merged chunks for BM25 (includes previous + new)
        chunks = _load_bm25_chunks()

        logger.info("🔍 Creating vector retriever...")
        vector_retriever = vector_store.as_retriever(search_kwargs={"k": 5}) if vector_store else None

        logger.info("📝 Creating BM25 retriever...")
        bm25_retriever = BM25Retriever.from_documents(chunks) if chunks else None
        if bm25_retriever:
            bm25_retriever.k = 5

        logger.info("🔄 Creating hybrid retriever...")
        if vector_retriever and bm25_retriever:
            hybrid_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_retriever],
                weights=[0.2, 0.8],
            )
        elif vector_retriever:
            logger.warning("ℹ️ BM25 retriever unavailable; using vector retriever only.")
            hybrid_retriever = vector_retriever
        elif bm25_retriever:
            hybrid_retriever = bm25_retriever
        else:
            raise RuntimeError("Neither vector or BM25 retrievers could be initialized. Provide data under data/new_data and retry.")

        # Publish only after every step succeeded, so a failed attempt leaves
        # no half-initialized globals behind. _initialized is set last so a
        # thread that sees it True also sees the retrievers.
        _vector_store = vector_store
        _chunks = chunks
        _vector_retriever = vector_retriever
        _bm25_retriever = bm25_retriever
        _hybrid_retriever = hybrid_retriever
        _initialized = True
        logger.info("✅ Retrievers initialized successfully.")


def initialize_eagerly() -> None:
    """Build the retrievers right away instead of waiting for the first query.

    Intended for background warm-up. When initialization has already
    completed, this returns immediately without doing any work.
    """
    _ensure_initialized()


def is_initialized() -> bool:
    """Report whether the lazy retriever setup has already completed.

    This only inspects the module flag; it never triggers initialization.
    """
    return bool(_initialized)


# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
# Shared worker pool that hybrid_search() uses to run the vector and BM25
# searches concurrently. It is created once at import time and reused by
# every call.
_retrieval_pool = ThreadPoolExecutor(max_workers=4)


def _get_doc_id(doc: Document) -> str:
    """Build a key identifying *doc* for de-duplicating merged search results.

    The key combines the ``source`` and ``page_number`` metadata (``'unknown'``
    when absent) with a hash of the document's complete ``page_content``.
    Hashing the whole content, rather than only a prefix, prevents distinct
    chunks that merely share their opening text (e.g. a repeated header) from
    being collapsed into a single result.

    Note: ``hash()`` of a ``str`` is salted per interpreter process, so the key
    is only meaningful within the current process and must not be persisted.
    """
    source = doc.metadata.get("source", "unknown")
    page = doc.metadata.get("page_number", "unknown")
    return f"{source}_{page}_{hash(doc.page_content)}"


def _match_provider(doc, provider: str) -> bool:
    if not provider:
        return True
    prov = str(doc.metadata.get("provider", "")).strip().lower()
    return prov == provider.strip().lower()


@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: int | None = None, fetch_k: int = 20):
    """Search the FAISS vector store with an optional provider metadata filter.

    Args:
        query: Free-text query passed to ``similarity_search``.
        provider: When given, only documents whose ``provider`` metadata
            matches it case- and whitespace-insensitively (the
            ``_match_provider`` rule) are returned.
        k: Maximum number of documents to return; defaults to DEFAULT_K_VECTOR.
        fetch_k: Number of nearest neighbours fetched before the provider
            filter is applied. It is ignored when no provider is given and is
            raised to ``k`` if smaller.

    Returns:
        Up to ``k`` documents in the order returned by ``similarity_search``.
        Returns an empty list when no vector store exists, when ``k <= 0``, or
        when the search fails (the failure is logged).

    Raises:
        Any exception from ``_ensure_initialized`` propagates.
    """
    _ensure_initialized()
    if not _vector_store:
        return []

    # Use global default if k is not specified
    if k is None:
        k = DEFAULT_K_VECTOR

    try:
        if k <= 0:
            return []
        if not provider:
            return _vector_store.similarity_search(query, k=k)

        # Filter in Python with _match_provider instead of passing
        # filter={"provider": provider} to the store. An exact-match backend
        # filter is stricter than _match_provider, so it would silently drop
        # documents whose provider differs only in case or surrounding
        # whitespace. Over-fetch so enough candidates survive the filter.
        candidates = _vector_store.similarity_search(query, k=max(k, fetch_k))
        return [d for d in candidates if _match_provider(d, provider)][:k]
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []


# Serializes use of the shared BM25 retriever. bm25_search() runs on
# _retrieval_pool worker threads and has to change the retriever's ``k``
# temporarily; without the lock, concurrent calls could overwrite each
# other's value between setting it and running the search.
_bm25_lock = threading.Lock()


@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: int | None = None):
    """Search BM25 using the global retriever with an optional provider filter.

    Args:
        query: Free-text query passed to the BM25 retriever.
        provider: When given, only documents whose ``provider`` metadata
            matches it (the ``_match_provider`` rule) are returned.
        k: Maximum number of documents to return; defaults to DEFAULT_K_BM25.

    Returns:
        Up to ``k`` documents in BM25 ranking order. Returns an empty list
        when no BM25 retriever exists, when ``k <= 0``, or when the search
        fails (the failure is logged).

    Raises:
        Any exception from ``_ensure_initialized`` propagates.
    """
    _ensure_initialized()

    # Use global default if k is not specified
    if k is None:
        k = DEFAULT_K_BM25

    try:
        if not _bm25_retriever or k <= 0:
            return []

        # The provider filter runs after ranking. When filtering, rank every
        # indexed chunk; otherwise the top-k could all belong to other
        # providers and the filtered result would come back empty.
        fetch_k = max(k, len(_chunks or [])) if provider else k

        with _bm25_lock:
            configured_k = _bm25_retriever.k
            _bm25_retriever.k = fetch_k
            try:
                docs = _bm25_retriever.invoke(query) or []
            finally:
                # Restore the k set at initialization. The same retriever
                # object is also handed to the EnsembleRetriever.
                _bm25_retriever.k = configured_k

        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []


def hybrid_search(query: str, provider: str | None = None, k_vector: int = None, k_bm25: int = None):
    """Run vector and BM25 retrieval concurrently and merge their hits.

    Both searches receive the same optional provider filter. ``k_vector`` and
    ``k_bm25`` fall back to DEFAULT_K_VECTOR / DEFAULT_K_BM25 when omitted.
    The result lists the vector hits first, then the BM25 hits. Only the
    first document seen for each ``_get_doc_id`` key is kept.
    """
    # Build the retrievers up front, before any work is handed to pool threads.
    _ensure_initialized()

    vector_k = DEFAULT_K_VECTOR if k_vector is None else k_vector
    bm25_k = DEFAULT_K_BM25 if k_bm25 is None else k_bm25

    vector_future = _retrieval_pool.submit(vector_search, query, provider, vector_k)
    bm25_future = _retrieval_pool.submit(bm25_search, query, provider, bm25_k)

    # Dicts keep insertion order, and setdefault leaves the first entry per
    # key in place, so this is an order-preserving "first one wins" de-dup.
    unique_by_id = {}
    for doc in [*vector_future.result(), *bm25_future.result()]:
        unique_by_id.setdefault(_get_doc_id(doc), doc)
    merged = list(unique_by_id.values())

    logger.info(f"Hybrid search returned {len(merged)} unique documents")
    return merged