Commit bea2de8
Parent(s): 8b1b13a
fix: new embeddings and reranker
Files changed:
- app/api/v2_endpoints.py (+77 -8)
- app/core/config.py (+5 -2)
- app/services/llm_service.py (+2 -2)
- app/services/reranker_service.py (+2 -1)
- app/services/retrieval.py (+4 -4)
app/api/v2_endpoints.py CHANGED

@@ -29,13 +29,53 @@ from app.services import query_expansion_service
 from app.core import state
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
 router = APIRouter()
 
 # --- Constants ---
 CONTEXT_CHUNK_COUNT = 100
 # --- MODIFIED: TOTAL_RETRIEVAL_COUNT is now the number of candidates for the re-ranker ---
-RERANK_CANDIDATE_COUNT =
+RERANK_CANDIDATE_COUNT = 100
+
+def dynamic_top_k_selection(
+    reranked_docs: List[Dict[str, Any]],
+    k_min: int = 3,
+    k_max: int = 15,
+    fall_off_threshold: float = 1.0  # Start with a threshold of 1.0 logit score drop
+) -> List[Dict[str, Any]]:
+    """
+    Selects a dynamic number of documents based on score fall-off.
+    """
+    if not reranked_docs:
+        return []
+
+    if len(reranked_docs) <= k_min:
+        return reranked_docs
+
+    scores = np.array([doc.get('rerank_score', -float('inf')) for doc in reranked_docs])
+    score_diffs = np.diff(scores) * -1  # Make differences positive as scores are descending
+
+    elbow_index = -1
+    # Start searching for a large fall-off after the k_min-th document
+    for i in range(k_min - 1, len(score_diffs)):
+        if score_diffs[i] > fall_off_threshold:
+            # The drop is after this document, so we take up to and including this one.
+            elbow_index = i + 1
+            break
+
+    if elbow_index != -1:
+        # We found a significant drop
+        final_k = elbow_index
+    else:
+        # No significant drop found, take the max allowed
+        final_k = k_max
+
+    # Ensure final_k is within the [k_min, k_max] bounds and also within list size
+    final_k = min(max(final_k, k_min), k_max, len(reranked_docs))
+
+    logger.info(f"Dynamic K selection: Found elbow at index {elbow_index}. "
+                f"Selected final K of {final_k} from {len(reranked_docs)} candidates.")
+
+    return reranked_docs[:final_k]
 
 # --- Startup Event (Loads data into state) ---
 @router.on_event("startup")
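A minimal standalone sketch of how the new fall-off selection behaves, assuming the re-ranker returns raw logit scores sorted in descending order (the sample values below are invented for illustration):

import numpy as np  # dynamic_top_k_selection relies on np.diff over the scores

# Hypothetical re-ranker output, already sorted descending.
docs = [{"id": str(i), "rerank_score": s}
        for i, s in enumerate([7.9, 7.4, 7.1, 6.9, 3.2, 3.0, 2.8])]

# Positive gaps between neighbours: [0.5, 0.3, 0.2, 3.7, 0.2, 0.2].
# The first gap above fall_off_threshold=1.0 sits after the 4th document,
# so elbow_index = 4 and the top 4 documents are kept (within [k_min, k_max]).
selected = dynamic_top_k_selection(docs, k_min=3, k_max=15, fall_off_threshold=1.0)
print([d["id"] for d in selected])  # ['0', '1', '2', '3']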
@@ -172,10 +212,23 @@ async def handle_v2_query(
     retrieved_scores_float = [float(score) for chunk_id, score in search_results]
 
     candidate_chunks = []
+    missing_chunk_count = 0
     for chunk_id, initial_score in search_results:
         chunk_text = state.chunk_content_map.get(str(chunk_id))
         if chunk_text:
             candidate_chunks.append({"id": str(chunk_id), "text": chunk_text})
+        else:
+            missing_chunk_count += 1
+            logger.warning(
+                f"Data consistency warning: Retrieved chunk_id '{chunk_id}' "
+                f"not found in the in-memory chunk_content_map."
+            )
+
+    # --- LOGGING POINT 2: After building the candidate list ---
+    logger.debug(
+        f"Successfully built {len(candidate_chunks)} candidate chunks for re-ranking. "
+        f"{missing_chunk_count} chunks were dropped due to missing text content."
+    )
 
     # --- MODIFIED: Offload the blocking re-ranker function to a threadpool ---
     reranked_chunks = await run_in_threadpool(
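For reference, the list handed to the re-ranker is plain id/text dicts, with ids coerced to strings so they match the chunk_content_map keys; a minimal illustration of the expected shape (values invented):

candidate_chunks = [
    {"id": "1042", "text": "First candidate chunk body..."},
    {"id": "2211", "text": "Second candidate chunk body..."},
]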
@@ -186,12 +239,18 @@ async def handle_v2_query(
     )
 
     if reranked_chunks:
-        score_threshold = settings.RERANKER_SCORE_THRESHOLD
-        filtered_chunks = [c for c in reranked_chunks if c['rerank_score'] > score_threshold]
+        filtered_chunks = dynamic_top_k_selection(
+            reranked_docs=reranked_chunks,
+            k_min=settings.RERANKER_K_MIN,  # e.g., 3
+            k_max=settings.RERANKER_K_MAX,  # e.g., 100
+            fall_off_threshold=settings.RERANKER_FALLOFF_THRESHOLD  # e.g., 1.0
+        )
+        # score_threshold = settings.RERANKER_SCORE_THRESHOLD
+        # filtered_chunks = [c for c in reranked_chunks if c['rerank_score'] > score_threshold]
 
-        if not filtered_chunks:
-            logger.warning(f"No chunks met the score threshold of {score_threshold}. Using only the top-ranked chunk.")
-            filtered_chunks = reranked_chunks[:1]
+        # if not filtered_chunks:
+        #     logger.warning(f"No chunks met the score threshold of {score_threshold}. Using only the top-ranked chunk.")
+        #     filtered_chunks = reranked_chunks[:1]
 
         # --- MODIFIED: Offload the blocking sequence organization to a threadpool ---
         organized_chunks = await run_in_threadpool(
@@ -213,7 +272,17 @@ async def handle_v2_query(
         else:
             llm_answer = "I found relevant documents, but could not construct an answer."
 
-
+        top_result_preview = None
+        if reranked_chunks:
+            top_chunk = reranked_chunks[0]
+            top_metadata = state.chunk_metadata_map.get(top_chunk['id'], {})
+            top_result_preview = schemas.TopResultPreview(
+                id=top_chunk['id'],
+                score=float(top_chunk['rerank_score']),
+                content_preview=top_chunk['text'][:150],
+                original_file=top_metadata.get('original_file')
+            )
+
 
     else:
         llm_answer = "Could not re-rank the search results."
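The response now carries a preview of the top re-ranked hit. Only the call site appears in this diff; a minimal Pydantic model consistent with the fields used there would look like the sketch below (the real definition lives in the app's schemas module and may differ):

from typing import Optional
from pydantic import BaseModel

class TopResultPreview(BaseModel):
    # Sketch inferred from the call site above, not the actual definition.
    id: str
    score: float
    content_preview: str
    original_file: Optional[str] = None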
app/core/config.py CHANGED

@@ -26,8 +26,11 @@ class Settings(BaseSettings):
     S3_RERANKER_URL: Optional[str] = None
 
     RERANKER_MODEL_PATH: str = "data/best_expert_judge_cross_encoder.pt"  # Or the exact name of your saved .pt file
-    RERANKER_MODEL_NAME: str = "
-    RERANKER_SCORE_THRESHOLD: float = 0.
+    RERANKER_MODEL_NAME: str = "mixedbread-ai/mxbai-rerank-base-v2"  # The base model used in your training script
+    RERANKER_SCORE_THRESHOLD: float = 0.0
+    RERANKER_K_MIN: int = 5
+    RERANKER_FALLOFF_THRESHOLD: int = 1
+    RERANKER_K_MAX: int = 100
 
     SEQUENCE_EXPANSION_THRESHOLD: float =0.68
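Because Settings extends pydantic's BaseSettings, the new re-ranker knobs can presumably be tuned per deployment without code changes, either through matching environment variables or constructor keyword overrides. A sketch, assuming the default settings sources and no custom env prefix:

from app.core.config import Settings

# Hypothetical experiment: widen the elbow search and demand a sharper drop.
settings = Settings(RERANKER_K_MIN=3, RERANKER_K_MAX=50, RERANKER_FALLOFF_THRESHOLD=2)
print(settings.RERANKER_MODEL_NAME)  # mixedbread-ai/mxbai-rerank-base-v2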
app/services/llm_service.py CHANGED

@@ -116,7 +116,7 @@ async def generate_answer(query: str, context_used: str) -> str:
         answer = "⚠️ LLM returned an empty response."
     else:
         logger.info("LLM query-based response generated.")
-        logger.
-        logger.
+        logger.info(f"LLM Request preview: {messages}")
+        logger.info(f"LLM Response preview: {answer}...")
 
     return answer
app/services/reranker_service.py CHANGED

@@ -97,7 +97,8 @@ def rerank_chunks(query: str, chunks: List[Dict], metadata_map: Dict) -> List[Dict]:
     encoded = {k: v.to(state.device) for k, v in encoded.items()}
 
     logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])
-    scores = torch.sigmoid(logits).squeeze().cpu().numpy()
+    # scores = torch.sigmoid(logits).squeeze().cpu().numpy()
+    scores = logits.squeeze().cpu().numpy()
 
     # Add the new, more accurate score to each chunk dictionary
     if len(chunks) == 1:
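Dropping the sigmoid matters for the new elbow logic: dynamic_top_k_selection thresholds on absolute score gaps, and a sigmoid squashes well-separated logits into near-identical probabilities close to 1, hiding the fall-off. A quick illustration (logit values invented):

import numpy as np

logits = np.array([8.0, 6.5, 2.0])   # raw cross-encoder outputs
probs = 1 / (1 + np.exp(-logits))    # [0.9997, 0.9985, 0.8808]

print(np.diff(logits) * -1)  # [1.5 4.5]        -> clear drop above a 1.0 threshold
print(np.diff(probs) * -1)   # [0.0012 0.1177]  -> gaps collapse near 1.0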
app/services/retrieval.py CHANGED

@@ -129,8 +129,8 @@ def find_top_gnn_chunks(query_text: str, top_n: int = 100) -> List[Tuple[str, float]]:
     logger.info(f"Similarity search completed in {duration:.4f} seconds.")
 
     # --- MODIFIED: Detailed logging to include original_file ---
-    top_results_to_log = results[:
-    logger.
+    top_results_to_log = results[:200]
+    logger.info("--- Top 30 Retrieved Chunks (from retrieval service) ---")
     for i, (chunk_id, score) in enumerate(top_results_to_log):
         # Look up metadata for the current chunk_id
         chunk_id_str = str(chunk_id)  # Ensure the key is a string for lookup

@@ -138,8 +138,8 @@ def find_top_gnn_chunks(query_text: str, top_n: int = 100) -> List[Tuple[str, float]]:
         original_file = metadata.get('original_file', 'File not found')
 
         # Updated log message to include the original file
-        logger.
-        logger.
+        logger.info(f" {i+1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}")
+    logger.info("----------------------------------------------------------------")
     # --- END OF MODIFICATION ---
 
     return results[:top_n]
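Each retrieved candidate is now logged with its source file alongside the similarity score; a representative line in the new format (chunk id, file name, and score invented) looks like:

 1. Chunk ID: 1042 | File: report_2021.pdf | Score: 0.8731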