# app/api/v2_endpoints.py
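"""
V2 API endpoints.

Implements the V2 query pipeline (bi-encoder retrieval -> cross-encoder
re-ranking -> dynamic top-K selection -> sequence-aware context assembly ->
LLM answer generation), plus session management and feedback endpoints.
"""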
import logging
import time
import uuid
from typing import Dict, Optional, Tuple, List, Any

import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Body, status
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.orm import Session

from app.core.config import settings
# --- DB Imports ---
from app.db.database import get_db
from app.db import models
from app.db import schemas

# --- Service Imports ---
from app.services import data_loader
from app.services import retrieval
from app.services import context_builder
from app.services import llm_service
# --- ADDED: Import the new reranker service ---
from app.services import reranker_service
from app.services import parts_combination_service
from app.services import query_expansion_service

# --- State Import ---
from app.core import state

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

router = APIRouter()

# --- Constants ---
CONTEXT_CHUNK_COUNT = 100
# --- MODIFIED: TOTAL_RETRIEVAL_COUNT is now the number of candidates for the re-ranker ---
RERANK_CANDIDATE_COUNT = 100

def dynamic_top_k_selection(
    reranked_docs: List[Dict[str, Any]],
    k_min: int = 3,
    k_max: int = 15,
    fall_off_threshold: float = 1.0  # Start with a threshold of 1.0 logit score drop
) -> List[Dict[str, Any]]:
    """
    Select a dynamic number of documents based on score fall-off.

    Scans the descending 'rerank_score' values for the first drop larger than
    `fall_off_threshold` after the first `k_min` documents and cuts the list
    there; if no such drop is found, falls back to `k_max`. The result is
    always clamped to [k_min, k_max] and to the number of available documents.
    """
    if not reranked_docs:
        return []
    if len(reranked_docs) <= k_min:
        return reranked_docs

    scores = np.array([doc.get('rerank_score', -float('inf')) for doc in reranked_docs])
    score_diffs = np.diff(scores) * -1  # Make differences positive as scores are descending

    elbow_index = -1
    # Start searching for a large fall-off after the k_min-th document
    for i in range(k_min - 1, len(score_diffs)):
        if score_diffs[i] > fall_off_threshold:
            # The drop is after this document, so we take up to and including this one.
            elbow_index = i + 1
            break

    if elbow_index != -1:
        # We found a significant drop
        final_k = elbow_index
    else:
        # No significant drop found, take the max allowed
        final_k = k_max

    # Ensure final_k is within the [k_min, k_max] bounds and also within list size
    final_k = min(max(final_k, k_min), k_max, len(reranked_docs))
    logger.info(f"Dynamic K selection: Found elbow at index {elbow_index}. "
                f"Selected final K of {final_k} from {len(reranked_docs)} candidates.")
    return reranked_docs[:final_k]
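
# Worked example (illustrative values): for descending rerank scores
# [5.1, 4.9, 4.8, 1.2, 1.0] with k_min=3 and fall_off_threshold=1.0, the
# positive drops are [0.2, 0.1, 3.6, 0.2]; the first drop > 1.0 occurs at
# index 2, so the elbow lands after the third document and final_k = 3.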

# --- Startup Event (Loads data into state) ---
async def v2_load_data_on_startup():
    """Load data and models into the central state object on startup."""
    if state.v2_data_loaded:
        logger.info("V2 data already loaded in state. Skipping.")
        return

    logger.info("--- Starting V2 Data Loading Sequence ---")
    start_time = time.time()
    load_success = True

    # Task 1: Load Retrieval Artifacts (Bi-encoder)
    logger.info("Startup Task 1: Loading retrieval artifacts (embeddings, Wq, temp)...")
    artifacts_loaded = retrieval.load_retrieval_artifacts()
    if not artifacts_loaded:
        logger.error("CRITICAL FAILURE: Failed to load retrieval artifacts.")
        load_success = False
    else:
        logger.info("Retrieval artifacts loaded successfully.")

    # Task 2: Load Content Map
    if load_success:
        logger.info("Startup Task 2: Loading Chunk Content Map...")
        if state.chunk_ids_in_order is not None:
            required_ids = set(state.chunk_ids_in_order)
            loaded_content, loaded_metadata = await data_loader.load_chunk_content_map(
                required_chunk_ids=required_ids
            )
            if loaded_content is None or loaded_metadata is None:
                logger.error("CRITICAL FAILURE: Failed to load chunk content/metadata map.")
                load_success = False
            else:
                state.chunk_content_map = loaded_content
                state.chunk_metadata_map = loaded_metadata
                logger.info(f"Chunk Content Map loading completed for {len(loaded_content)} chunks.")
        else:
            logger.warning("Skipping content loading: chunk ID list was not populated during artifact loading.")
    # Task 3: Initialize LLM Client
    if load_success:
        logger.info("Startup Task 3: Initializing OpenAI Client...")
        client_initialized = llm_service.initialize_openai_client()
        if not client_initialized:
            load_success = False
            logger.error("CRITICAL FAILURE: OpenAI client failed to initialize.")
        else:
            logger.info("OpenAI Client initialization completed.")

    # --- ADDED: Startup Task 4: Load Re-ranker Model ---
    if load_success:
        logger.info("Startup Task 4: Loading Re-ranker Model...")
        reranker_loaded = reranker_service.load_reranker_model()
        if not reranker_loaded:
            load_success = False
            logger.error("CRITICAL FAILURE: Re-ranker model failed to initialize.")
        else:
            logger.info("Re-ranker model initialization completed.")
    # --- END OF ADDITION ---

    # ...

    # --- ADDED: Startup Task 5: Load Chunk Type and Sequence Data ---
    if load_success:
        logger.info("Startup Task 5: Loading Chunk Type and Sequence Data...")
        maps_loaded = await parts_combination_service.load_chunk_type_map()
        if not maps_loaded:
            logger.warning("WARNING: Chunk Type/Sequence data failed to load.")
        else:
            logger.info("Chunk Type and Sequence Data initialization completed.")

    # ...

    # Final status update
    if load_success:
        state.v2_data_loaded = True
        duration = time.time() - start_time
        logger.info(f"--- V2 Data Loading Sequence Complete. Duration: {duration:.2f} seconds ---")
    else:
        state.v2_data_loaded = False
        duration = time.time() - start_time
        logger.error(f"--- V2 Data Loading Sequence FAILED. Duration: {duration:.2f} seconds. ---")

@router.post("/query")  # path assumed; the decorator is not present in the source
async def handle_v2_query(
    request: schemas.QueryRequest = Body(...),
    db: Session = Depends(get_db)
):
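    """
    Handle a V2 query end to end: expand abbreviations, retrieve up to
    RERANK_CANDIDATE_COUNT candidates with the bi-encoder, re-rank them,
    select a dynamic top-K, organize the survivors by sequence, build the
    context, and generate an LLM answer. Both the user message and the bot
    response are persisted to the session history.
    """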
| logger.info(f"Received V2 query (Attempt 1): '{request.query[:100]}...' for session: {request.session_id}") | |
| start_time = time.time() | |
| if not state.v2_data_loaded: | |
| raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Service is not ready.") | |
| # ... (session and user message creation remain the same) ... | |
| session_uuid, _ = await get_or_create_session(request.session_id, db) | |
| user_msg = models.ChatMessage(session_id=session_uuid, role=schemas.MessageRole.USER, content=request.query) | |
| db.add(user_msg) | |
| llm_answer = "" | |
| context_string = "" | |
| retrieved_chunk_ids = [] | |
| used_ids_this_attempt = [] | |
| retrieved_scores_float = [] | |
| top_result_preview = None | |
| original_file = None | |
| try: | |
| # --- STEP 1: PRE-PROCESSING (Direct ABBREVIATION Replacement) --- | |
| original_query = request.query | |
| # --- EDIT: Call the new, direct replacement function --- | |
| normalized_query = query_expansion_service.replace_abbreviations(original_query) | |
| if original_query != normalized_query: | |
| logger.info(f"Query expanded from '{original_query}' to '{normalized_query}'") | |
| # --- MODIFIED: Offload the blocking retrieval function to a threadpool --- | |
| search_results: List[Tuple[str, float]] = await run_in_threadpool( | |
| retrieval.find_top_gnn_chunks, | |
| query_text=normalized_query, | |
| top_n=RERANK_CANDIDATE_COUNT | |
| ) | |
| if not search_results: | |
| llm_answer = "Based on the available information, I could not find a specific answer to your query." | |
| else: | |
| retrieved_chunk_ids = [str(chunk_id) for chunk_id, score in search_results] | |
| retrieved_scores_float = [float(score) for chunk_id, score in search_results] | |
| candidate_chunks = [] | |
| missing_chunk_count = 0 | |
| for chunk_id, initial_score in search_results: | |
| chunk_text = state.chunk_content_map.get(str(chunk_id)) | |
| if chunk_text: | |
| candidate_chunks.append({"id": str(chunk_id), "text": chunk_text}) | |
| else: | |
| missing_chunk_count += 1 | |
| logger.warning( | |
| f"Data consistency warning: Retrieved chunk_id '{chunk_id}' " | |
| f"not found in the in-memory chunk_content_map." | |
| ) | |
| # --- LOGGING POINT 2: After building the candidate list --- | |
| logger.debug( | |
| f"Successfully built {len(candidate_chunks)} candidate chunks for re-ranking. " | |
| f"{missing_chunk_count} chunks were dropped due to missing text content." | |
| ) | |
| # --- MODIFIED: Offload the blocking re-ranker function to a threadpool --- | |
| reranked_chunks = await run_in_threadpool( | |
| reranker_service.rerank_chunks, | |
| query=normalized_query, | |
| chunks=candidate_chunks, | |
| metadata_map=state.chunk_metadata_map | |
| ) | |
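            # The re-ranker is assumed to return the candidate dicts sorted by
            # descending 'rerank_score' while preserving their 'id' and 'text'
            # keys; the selection and preview code below relies on that shape.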
            if reranked_chunks:
                filtered_chunks = dynamic_top_k_selection(
                    reranked_docs=reranked_chunks,
                    k_min=settings.RERANKER_K_MIN,  # e.g., 3
                    k_max=settings.RERANKER_K_MAX,  # e.g., 100
                    fall_off_threshold=settings.RERANKER_FALLOFF_THRESHOLD  # e.g., 1.0
                )
                # score_threshold = settings.RERANKER_SCORE_THRESHOLD
                # filtered_chunks = [c for c in reranked_chunks if c['rerank_score'] > score_threshold]
                # if not filtered_chunks:
                #     logger.warning(f"No chunks met the score threshold of {score_threshold}. Using only the top-ranked chunk.")
                #     filtered_chunks = reranked_chunks[:1]

                # --- MODIFIED: Offload the blocking sequence organization to a threadpool ---
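                # organize_chunks_by_sequence presumably restores the selected
                # chunks to their original document order using the type and
                # sequence maps loaded in Startup Task 5.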
                organized_chunks = await run_in_threadpool(
                    parts_combination_service.organize_chunks_by_sequence,
                    chunks=filtered_chunks
                )
                final_chunks_for_context = organized_chunks[:CONTEXT_CHUNK_COUNT]

                # --- This function is very fast, no threadpool needed ---
                ids_for_final_context = [chunk['id'] for chunk in final_chunks_for_context]
                context_string, used_ids_this_attempt = context_builder.build_context_from_ids(
                    top_chunk_ids=ids_for_final_context
                )

                if context_string:
                    # --- MODIFIED: Simply 'await' the now-async llm_service function ---
                    llm_answer = await llm_service.generate_answer(request.query, context_string)
                else:
                    llm_answer = "I found relevant documents, but could not construct an answer."

                # reranked_chunks is non-empty in this branch, so build the preview directly.
                top_chunk = reranked_chunks[0]
                top_metadata = state.chunk_metadata_map.get(top_chunk['id'], {})
                top_result_preview = schemas.TopResultPreview(
                    id=top_chunk['id'],
                    score=float(top_chunk['rerank_score']),
                    content_preview=top_chunk['text'][:150],
                    original_file=top_metadata.get('original_file')
                )
            else:
                llm_answer = "Could not re-rank the search results."
    except Exception as e:
        logger.exception(f"Unexpected error during query processing: {e}")
        llm_answer = "⚠️ An unexpected error occurred."

    if not llm_answer:
        llm_answer = "⚠️ Error: No response generated."

    bot_msg = models.ChatMessage(
        session_id=session_uuid, role=schemas.MessageRole.BOT, content=llm_answer,
        original_query=request.query, retrieved_context_ids=retrieved_chunk_ids,
        used_context_ids=used_ids_this_attempt, attempt_number=1,
        cumulative_used_context_ids=used_ids_this_attempt
    )
    db.add(bot_msg)
    try:
        db.commit()
        db.refresh(bot_msg)
        bot_message_id = bot_msg.id
    except Exception as e:
        db.rollback()
        logger.exception(f"DB commit failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to save conversation messages.")

    response_details = schemas.QueryResultDetail(
        session_id=session_uuid, message_id=bot_message_id, attempt_number=1,
        retrieved_ids=retrieved_chunk_ids, search_scores=retrieved_scores_float, original_file=original_file
    )
    final_response = schemas.QueryResponse(
        llm_answer=llm_answer,
        context_used_preview=(context_string[:200] + "...") if context_string else "No context.",
        top_result_preview=top_result_preview,
        details=response_details
    )
    end_time = time.time()
    logger.info(f"V2 query (Attempt 1) processed in {end_time - start_time:.2f}s")
    return final_response

async def get_or_create_session(session_id: Optional[uuid.UUID], db: Session) -> Tuple[uuid.UUID, bool]:
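    """
    Resolve the session for a request, returning (session_uuid, is_new).
    A missing or unknown session_id yields a brand-new session (added to the
    DB session but not committed here) rather than a 404.
    """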
    is_new = False
    if session_id:
        session = db.query(models.ChatSession).filter(models.ChatSession.id == session_id).first()
        if not session:
            session_id = uuid.uuid4()
            is_new = True
    else:
        session_id = uuid.uuid4()
        is_new = True
    if is_new:
        new_db_session = models.ChatSession(id=session_id, name=f"Session {str(session_id)[:8]}")
        db.add(new_db_session)
    return session_id, is_new

# --- Session Management Endpoints ---
@router.post("/sessions")  # path assumed; the decorator is not present in the source
async def create_v2_session(session_create: schemas.SessionCreate, db: Session = Depends(get_db)):
    logger.info(f"Creating new V2 session with name: '{session_create.name}'")
    session_uuid = uuid.uuid4()
    db_session = models.ChatSession(id=session_uuid, name=session_create.name)
    try:
        db.add(db_session)
        db.commit()
        db.refresh(db_session)
        logger.info(f"Successfully created session {db_session.id}")
        return db_session
    except Exception as e:
        db.rollback()
        logger.exception(f"Failed to create session: {e}")
        raise HTTPException(status_code=500, detail="Failed to create session")

@router.get("/sessions")  # path assumed; the decorator is not present in the source
async def list_v2_sessions(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    logger.info(f"Listing V2 sessions (skip={skip}, limit={limit})")
    sessions = db.query(models.ChatSession).order_by(models.ChatSession.created_at.desc()).offset(skip).limit(limit).all()
    return sessions

@router.patch("/sessions/{session_id}")  # path and method assumed
async def rename_v2_session(session_id: uuid.UUID, session_update: schemas.ChatSessionUpdate, db: Session = Depends(get_db)):
    logger.info(f"Attempting to rename session {session_id} to '{session_update.name}'")
    db_session = db.query(models.ChatSession).filter(models.ChatSession.id == session_id).first()
    if not db_session:
        raise HTTPException(status_code=404, detail="Session not found")
    db_session.name = session_update.name
    try:
        db.add(db_session)
        db.commit()
        db.refresh(db_session)
        logger.info(f"Successfully renamed session {session_id}")
        return db_session
    except Exception as e:
        db.rollback()
        logger.exception(f"Failed to rename session: {e}")
        raise HTTPException(status_code=500, detail="Failed to rename session")

# --- Message Retrieval Endpoint ---
@router.get("/sessions/{session_id}/messages")  # path assumed; the decorator is not present in the source
async def get_v2_session_messages(session_id: uuid.UUID, db: Session = Depends(get_db)):
    logger.info(f"Fetching messages for V2 session: {session_id}")
    session = db.query(models.ChatSession).filter(models.ChatSession.id == session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    messages = db.query(models.ChatMessage).filter(models.ChatMessage.session_id == session_id).order_by(models.ChatMessage.created_at.asc()).all()
    logger.info(f"Found {len(messages)} messages for session {session_id}.")
    return messages

# --- Feedback Endpoint ---
@router.post("/feedback")  # path assumed; the decorator is not present in the source
async def submit_feedback(
    feedback_data: schemas.FeedbackCreate = Body(...),
    db: Session = Depends(get_db)
):
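    """
    Upsert feedback for a message: an existing FeedbackLog row for this
    message_id is updated in place; otherwise a new row is created from the
    payload.
    """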
| logger.info(f"Received feedback submission for message_id: {feedback_data.message_id}, type: {feedback_data.feedback_type.value}") | |
| rated_message = db.query(models.ChatMessage).filter(models.ChatMessage.id == feedback_data.message_id).first() | |
| if not rated_message: raise HTTPException(status_code=404, detail="Message not found") | |
| db_feedback = db.query(models.FeedbackLog).filter(models.FeedbackLog.message_id == feedback_data.message_id).first() | |
| if db_feedback: | |
| db_feedback.feedback_type = feedback_data.feedback_type | |
| db_feedback.feedback_comment = feedback_data.feedback_comment | |
| else: | |
| db_feedback = models.FeedbackLog(**feedback_data.dict()) | |
| db.add(db_feedback) | |
| db.commit() | |
| db.refresh(db_feedback) | |
| if feedback_data.feedback_type == schemas.FeedbackTypeEnum.REJECT: | |
| # Regeneration logic would go here if needed | |
| pass | |
| return db_feedback |
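
# Illustrative feedback payload (field names taken from the FeedbackCreate
# usage above; the exact serialized enum value is an assumption):
#     {
#         "message_id": "<uuid-of-bot-message>",
#         "feedback_type": "reject",
#         "feedback_comment": "Answer cited the wrong part."
#     }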