helal94hb1 committed
Commit bea2de8 · 1 Parent(s): 8b1b13a

fix: new embeddings and reranker

app/api/v2_endpoints.py CHANGED
@@ -29,13 +29,53 @@ from app.services import query_expansion_service
 from app.core import state
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
 router = APIRouter()
 
 # --- Constants ---
 CONTEXT_CHUNK_COUNT = 100
 # --- MODIFIED: TOTAL_RETRIEVAL_COUNT is now the number of candidates for the re-ranker ---
-RERANK_CANDIDATE_COUNT = 200
+RERANK_CANDIDATE_COUNT = 100
+
+def dynamic_top_k_selection(
+    reranked_docs: List[Dict[str, Any]],
+    k_min: int = 3,
+    k_max: int = 15,
+    fall_off_threshold: float = 1.0  # Start with a threshold of 1.0 logit score drop
+) -> List[Dict[str, Any]]:
+    """
+    Selects a dynamic number of documents based on score fall-off.
+    """
+    if not reranked_docs:
+        return []
+
+    if len(reranked_docs) <= k_min:
+        return reranked_docs
+
+    scores = np.array([doc.get('rerank_score', -float('inf')) for doc in reranked_docs])
+    score_diffs = np.diff(scores) * -1  # Make differences positive as scores are descending
+
+    elbow_index = -1
+    # Start searching for a large fall-off after the k_min-th document
+    for i in range(k_min - 1, len(score_diffs)):
+        if score_diffs[i] > fall_off_threshold:
+            # The drop is after this document, so we take up to and including this one.
+            elbow_index = i + 1
+            break
+
+    if elbow_index != -1:
+        # We found a significant drop
+        final_k = elbow_index
+    else:
+        # No significant drop found, take the max allowed
+        final_k = k_max
+
+    # Ensure final_k is within the [k_min, k_max] bounds and also within list size
+    final_k = min(max(final_k, k_min), k_max, len(reranked_docs))
+
+    logger.info(f"Dynamic K selection: Found elbow at index {elbow_index}. "
+                f"Selected final K of {final_k} from {len(reranked_docs)} candidates.")
+
+    return reranked_docs[:final_k]
 
 # --- Startup Event (Loads data into state) ---
 @router.on_event("startup")
@@ -172,10 +212,23 @@ async def handle_v2_query(
     retrieved_scores_float = [float(score) for chunk_id, score in search_results]
 
     candidate_chunks = []
+    missing_chunk_count = 0
     for chunk_id, initial_score in search_results:
         chunk_text = state.chunk_content_map.get(str(chunk_id))
         if chunk_text:
            candidate_chunks.append({"id": str(chunk_id), "text": chunk_text})
+        else:
+            missing_chunk_count += 1
+            logger.warning(
+                f"Data consistency warning: Retrieved chunk_id '{chunk_id}' "
+                f"not found in the in-memory chunk_content_map."
+            )
+
+    # --- LOGGING POINT 2: After building the candidate list ---
+    logger.debug(
+        f"Successfully built {len(candidate_chunks)} candidate chunks for re-ranking. "
+        f"{missing_chunk_count} chunks were dropped due to missing text content."
+    )
 
     # --- MODIFIED: Offload the blocking re-ranker function to a threadpool ---
     reranked_chunks = await run_in_threadpool(
@@ -186,12 +239,18 @@ async def handle_v2_query(
     )
 
     if reranked_chunks:
-        score_threshold = settings.RERANKER_SCORE_THRESHOLD
-        filtered_chunks = [c for c in reranked_chunks if c['rerank_score'] > score_threshold]
+        filtered_chunks = dynamic_top_k_selection(
+            reranked_docs=reranked_chunks,
+            k_min=settings.RERANKER_K_MIN,  # e.g., 3
+            k_max=settings.RERANKER_K_MAX,  # e.g., 100
+            fall_off_threshold=settings.RERANKER_FALLOFF_THRESHOLD  # e.g., 1.0
+        )
+        # score_threshold = settings.RERANKER_SCORE_THRESHOLD
+        # filtered_chunks = [c for c in reranked_chunks if c['rerank_score'] > score_threshold]
 
-        if not filtered_chunks:
-            logger.warning(f"No chunks met the score threshold of {score_threshold}. Using only the top-ranked chunk.")
-            filtered_chunks = reranked_chunks[:1]
+        # if not filtered_chunks:
+        #     logger.warning(f"No chunks met the score threshold of {score_threshold}. Using only the top-ranked chunk.")
+        #     filtered_chunks = reranked_chunks[:1]
 
         # --- MODIFIED: Offload the blocking sequence organization to a threadpool ---
         organized_chunks = await run_in_threadpool(
@@ -213,7 +272,17 @@ async def handle_v2_query(
         else:
             llm_answer = "I found relevant documents, but could not construct an answer."
 
-        # ... (rest of the logic for top_result_preview remains the same) ...
+        top_result_preview = None
+        if reranked_chunks:
+            top_chunk = reranked_chunks[0]
+            top_metadata = state.chunk_metadata_map.get(top_chunk['id'], {})
+            top_result_preview = schemas.TopResultPreview(
+                id=top_chunk['id'],
+                score=float(top_chunk['rerank_score']),
+                content_preview=top_chunk['text'][:150],
+                original_file=top_metadata.get('original_file')
+            )
+
 
     else:
         llm_answer = "Could not re-rank the search results."
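Note on the new selection logic: dynamic_top_k_selection cuts the candidate list at the first large score drop (the "elbow") after the first k_min documents, and falls back to k_max when scores decay smoothly. A minimal standalone sketch of the same rule, using made-up score values rather than real reranker output:

import numpy as np

def pick_k(scores, k_min=3, k_max=15, fall_off=1.0):
    # scores must be sorted descending, as reranked results are
    if len(scores) <= k_min:
        return len(scores)
    drops = -np.diff(np.asarray(scores))        # positive gaps between neighbours
    for i in range(k_min - 1, len(drops)):
        if drops[i] > fall_off:                 # first big fall-off after the k_min-th doc
            return min(max(i + 1, k_min), k_max)
    return min(k_max, len(scores))              # no elbow found: take the max allowed

scores = [8.2, 7.9, 7.7, 7.5, 4.1, 3.9, 3.8]    # hypothetical logits; big drop after the 4th
print(pick_k(scores))                           # -> 4

With the drop of 3.4 between the 4th and 5th scores exceeding the 1.0 threshold, the cut lands at K=4, exactly as the committed function would choose.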
app/core/config.py CHANGED
@@ -26,8 +26,11 @@ class Settings(BaseSettings):
     S3_RERANKER_URL: Optional[str] = None
 
     RERANKER_MODEL_PATH: str = "data/best_expert_judge_cross_encoder.pt"  # Or the exact name of your saved .pt file
-    RERANKER_MODEL_NAME: str = "BAAI/bge-reranker-base"  # The base model used in your training script
-    RERANKER_SCORE_THRESHOLD: float = 0.3
+    RERANKER_MODEL_NAME: str = "mixedbread-ai/mxbai-rerank-base-v2"  # The base model used in your training script
+    RERANKER_SCORE_THRESHOLD: float = 0.0
+    RERANKER_K_MIN: int = 5
+    RERANKER_FALLOFF_THRESHOLD: float = 1.0
+    RERANKER_K_MAX: int = 100
 
     SEQUENCE_EXPANSION_THRESHOLD: float = 0.68
 
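The new RERANKER_K_MIN / RERANKER_K_MAX / RERANKER_FALLOFF_THRESHOLD knobs are ordinary pydantic BaseSettings fields, so, assuming the project relies on pydantic's default environment parsing, they can be tuned per deployment without a code change. A minimal sketch (the override values are hypothetical):

import os
os.environ["RERANKER_K_MIN"] = "3"                  # hypothetical per-deployment override
os.environ["RERANKER_FALLOFF_THRESHOLD"] = "0.5"

from app.core.config import Settings                # the class changed in this commit
settings = Settings()
print(settings.RERANKER_K_MIN)                      # -> 3, read from the environment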
app/services/llm_service.py CHANGED
@@ -116,7 +116,7 @@ async def generate_answer(query: str, context_used: str) -> str:
         answer = "⚠️ LLM returned an empty response."
     else:
         logger.info("LLM query-based response generated.")
-        logger.debug(f"LLM Request preview: {messages}")
-        logger.debug(f"LLM Response preview: {answer}...")
+        logger.info(f"LLM Request preview: {messages}")
+        logger.info(f"LLM Response preview: {answer}...")
 
     return answer
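These debug-to-info promotions pair with the removal of logger.setLevel(logging.DEBUG) in v2_endpoints.py: without a per-module override, a module logger inherits its level from the root configuration, so anything that must appear under a default INFO setup has to be logged at INFO or above. A small illustration (the basicConfig call is an assumption about the deployment, not shown in this repo):

import logging
logging.basicConfig(level=logging.INFO)             # assumed root config for the service
log = logging.getLogger("app.services.llm_service")
log.info("request/response previews are visible")   # emitted: INFO >= INFO
log.debug("this would be filtered out")             # suppressed: DEBUG < INFO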
app/services/reranker_service.py CHANGED
@@ -97,7 +97,8 @@ def rerank_chunks(query: str, chunks: List[Dict], metadata_map: Dict) -> List[Di
     encoded = {k: v.to(state.device) for k, v in encoded.items()}
 
     logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])
-    scores = torch.sigmoid(logits).squeeze().cpu().numpy()
+    # scores = torch.sigmoid(logits).squeeze().cpu().numpy()
+    scores = logits.squeeze().cpu().numpy()
 
     # Add the new, more accurate score to each chunk dictionary
     if len(chunks) == 1:
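Switching from torch.sigmoid(logits) to raw logits leaves the ranking itself unchanged, since sigmoid is strictly monotonic; what changes is the scale the downstream fall-off threshold sees. Gaps near the top of the sigmoid are squashed toward zero, while logit gaps stay spread out, which is presumably why fall_off_threshold is described as a logit score drop. A quick check with made-up values:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([6.0, 5.0, 1.0, 0.0])             # made-up reranker logits, descending
print(np.argsort(-logits))                          # [0 1 2 3]
print(np.argsort(-sigmoid(logits)))                 # [0 1 2 3] -- identical ordering
print(-np.diff(logits))                             # [1. 4. 1.] -- gaps in logit space
print(-np.diff(sigmoid(logits)))                    # ~[0.004 0.262 0.231] -- top gap collapses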
app/services/retrieval.py CHANGED
@@ -129,8 +129,8 @@ def find_top_gnn_chunks(query_text: str, top_n: int = 100) -> List[Tuple[str, fl
     logger.info(f"Similarity search completed in {duration:.4f} seconds.")
 
     # --- MODIFIED: Detailed logging to include original_file ---
-    top_results_to_log = results[:30]
-    logger.debug("--- Top 30 Retrieved Chunks (from retrieval service) ---")
+    top_results_to_log = results[:200]
+    logger.info("--- Top 200 Retrieved Chunks (from retrieval service) ---")
     for i, (chunk_id, score) in enumerate(top_results_to_log):
         # Look up metadata for the current chunk_id
         chunk_id_str = str(chunk_id)  # Ensure the key is a string for lookup
@@ -138,8 +138,8 @@ def find_top_gnn_chunks(query_text: str, top_n: int = 100) -> List[Tuple[str, fl
         original_file = metadata.get('original_file', 'File not found')
 
         # Updated log message to include the original file
-        logger.debug(f" {i+1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}")
-    logger.debug("----------------------------------------------------------------")
+        logger.info(f" {i+1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}")
+    logger.info("----------------------------------------------------------------")
     # --- END OF MODIFICATION ---
 
     return results[:top_n]