# app/services/retrieval.py
import logging
import os
import time
from typing import List, Tuple

import numpy as np
import requests
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Import settings and state
from app.core.config import settings
from app.core import state

logger = logging.getLogger(__name__)


def load_retrieval_artifacts() -> bool:
    """
    Loads all necessary artifacts for retrieval. If running on a new server,
    it first downloads the artifacts from S3.
    """
    if state.artifacts_loaded:
        logger.info("Retrieval artifacts already loaded in state.")
        return True

    # --- ADDED: Download from S3 if the file doesn't exist ---
    artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
    if not os.path.exists(artifacts_path) and settings.S3_ARTIFACTS_URL:
        logger.info(f"Artifacts file not found at {artifacts_path}. Downloading from S3...")
        try:
            # Create the 'data' directory if it doesn't exist
            os.makedirs(os.path.dirname(artifacts_path), exist_ok=True)
            with requests.get(settings.S3_ARTIFACTS_URL, stream=True) as r:
                r.raise_for_status()
                with open(artifacts_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            logger.info("Successfully downloaded artifacts from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download artifacts from S3: {e}")
            return False
    # --- END OF ADDITION ---

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device for retrieval: {device}")
    state.device = device

    # 1. Load the pre-computed artifacts file
    logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
    try:
        if not os.path.exists(artifacts_path):
            logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
            return False

        artifacts = np.load(artifacts_path, allow_pickle=True)

        # Load into state
        state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
        state.chunk_ids_in_order = artifacts['chunk_ids']
        state.temperature = artifacts['temperature'][0]  # Extract scalar from 1-element array

        logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
        logger.info(f"Loaded temperature value: {state.temperature:.4f}")
    except Exception as e:
        logger.exception(f"Failed to load and process retrieval artifacts: {e}")
        return False

    # 2. Load the Sentence Transformer model for encoding queries
    logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
    try:
        cache_dir = "data/cache/"
        logger.info(f"Using cache directory for models: {cache_dir}")
        query_encoder = SentenceTransformer(
            settings.QUERY_ENCODER_MODEL_NAME, device=device, cache_folder=cache_dir
        )
        query_encoder.eval()
        state.query_encoder_model = query_encoder
        logger.info("Query encoder model loaded successfully.")
    except Exception as e:
        logger.exception(f"Failed to load Sentence Transformer model: {e}")
        return False

    state.artifacts_loaded = True
    return True


def find_top_gnn_chunks(query_text: str, top_n: int = 200) -> List[Tuple[str, float]]:
    """
    Performs a similarity search that is mathematically identical to the
    trained model, but without loading the GNN itself. It operates on the
    pre-transformed chunk embeddings loaded at startup.
    """
    if not state.artifacts_loaded:
        logger.error("Service not ready. Retrieval artifacts not loaded.")
        return []

    start_time = time.time()
    try:
        with torch.no_grad():
            # 1. Encode the query text into an embedding vector
            query_embedding = state.query_encoder_model.encode(
                query_text, convert_to_tensor=True, device=state.device
            )

            # 2. Apply L2 normalization to the query embedding
            q_trans_normalized = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)

            # 3. Convert to numpy for fast similarity calculation
            query_vec_np = q_trans_normalized.cpu().numpy()

            # 4. Perform the dot product against all chunk embeddings
            similarities = (query_vec_np @ state.transformed_chunk_embeddings.T)[0]

            # 5. Apply the learned temperature scaling
            scaled_similarities = similarities * np.exp(state.temperature)

            # 6. Combine with IDs and sort by score, descending
            results = list(zip(state.chunk_ids_in_order, scaled_similarities))
            results.sort(key=lambda item: item[1], reverse=True)
    except Exception as e:
        logger.exception(f"Error during model-based similarity search: {e}")
        return []

    duration = time.time() - start_time
    logger.info(f"Similarity search completed in {duration:.4f} seconds.")

    # --- MODIFIED: Detailed logging to include original_file ---
    top_results_to_log = results[:30]
    logger.info("--- Top 30 Retrieved Chunks (from retrieval service) ---")
    for i, (chunk_id, score) in enumerate(top_results_to_log):
        # Look up metadata for the current chunk_id
        chunk_id_str = str(chunk_id)  # Ensure the key is a string for lookup
        metadata = state.chunk_metadata_map.get(chunk_id_str, {})
        original_file = metadata.get('original_file', 'File not found')
        logger.info(
            f"  {i+1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}"
        )
    logger.info("----------------------------------------------------------------")
    # --- END OF MODIFICATION ---

    return results[:top_n]
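

# load_retrieval_artifacts() expects an .npz file with three arrays: the
# pre-transformed chunk embeddings, their chunk IDs in matching row order, and
# the learned temperature stored as a 1-element array (unpacked with [0]
# above). The helper below is a minimal, illustrative sketch of how such a
# file could be produced offline; the function name and arguments are
# hypothetical and not part of this service's API. It assumes the embedding
# rows are already L2-normalized so that the dot product in
# find_top_gnn_chunks() behaves as a cosine similarity.
def _example_save_artifacts(path: str,
                            transformed_chunk_embeddings: np.ndarray,
                            chunk_ids: List[str],
                            temperature: float) -> None:
    """Illustrative only: writes a file in the format load_retrieval_artifacts() reads."""
    np.savez(
        path,
        # (num_chunks, dim) float32 matrix, assumed L2-normalized per row.
        transformed_chunk_embeddings=transformed_chunk_embeddings.astype(np.float32),
        # IDs aligned row-for-row with the embedding matrix; stored as an
        # object array, which is why the loader passes allow_pickle=True.
        chunk_ids=np.array(chunk_ids, dtype=object),
        # Stored as a 1-element array; the loader extracts the scalar.
        temperature=np.array([temperature], dtype=np.float32),
    )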
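

# Minimal smoke-test sketch showing how the two public functions are meant to
# be called in sequence. It assumes app.core.config.settings points at a valid
# artifacts file (or S3 URL) and that state.chunk_metadata_map has been
# populated elsewhere; the query string is just an example.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if load_retrieval_artifacts():
        for chunk_id, score in find_top_gnn_chunks("example query text", top_n=5):
            print(f"{chunk_id}: {score:.4f}")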