```python
# app/services/retrieval.py
import logging
import os
import time
from typing import List, Tuple

import numpy as np
import requests
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Import application settings and shared state
from app.core.config import settings
from app.core import state

logger = logging.getLogger(__name__)
```
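The module keeps everything it loads on a shared `app.core.state` module rather than on locals, so other services in the app can reach the same artifacts. That module is not shown here; below is a minimal sketch of what this code assumes it contains, with the attribute names taken from the usages that follow and the defaults being my assumption:

```python
# app/core/state.py -- hypothetical sketch; the real module is not shown.
# A plain module used as a process-wide singleton for loaded artifacts.
from typing import Any, Dict, Optional

artifacts_loaded: bool = False
device = None                               # torch.device, set at load time
query_encoder_model = None                  # SentenceTransformer instance
transformed_chunk_embeddings = None         # np.ndarray, shape (num_chunks, dim)
chunk_ids_in_order = None                   # np.ndarray of chunk IDs
temperature: Optional[float] = None         # learned temperature scalar
chunk_metadata_map: Dict[str, Any] = {}     # chunk_id -> metadata, filled elsewhere
```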
```python
def load_retrieval_artifacts():
    """
    Loads all necessary artifacts for retrieval. If running on a new server,
    it first downloads the artifacts from S3.
    """
    if state.artifacts_loaded:
        logger.info("Retrieval artifacts already loaded in state.")
        return True

    # Download the artifacts from S3 if the local file does not exist.
    artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
    if not os.path.exists(artifacts_path) and settings.S3_ARTIFACTS_URL:
        logger.info(f"Artifacts file not found at {artifacts_path}. Downloading from S3...")
        try:
            # Create the target directory if it doesn't exist.
            os.makedirs(os.path.dirname(artifacts_path), exist_ok=True)
            with requests.get(settings.S3_ARTIFACTS_URL, stream=True) as r:
                r.raise_for_status()
                with open(artifacts_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            logger.info("Successfully downloaded artifacts from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download artifacts from S3: {e}")
            return False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device for retrieval: {device}")
    state.device = device

    # 1. Load the pre-computed artifacts file.
    logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
    try:
        if not os.path.exists(artifacts_path):
            logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
            return False
        artifacts = np.load(artifacts_path, allow_pickle=True)
        # Load into shared state.
        state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
        state.chunk_ids_in_order = artifacts['chunk_ids']
        state.temperature = artifacts['temperature'][0]  # Extract the scalar from the array
        logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
        logger.info(f"Loaded temperature value: {state.temperature:.4f}")
    except Exception as e:
        logger.exception(f"Failed to load and process retrieval artifacts: {e}")
        return False

    # 2. Load the Sentence Transformer model for encoding queries.
    logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
    try:
        cache_dir = "data/cache/"
        logger.info(f"Using cache directory for models: {cache_dir}")
        query_encoder = SentenceTransformer(
            settings.QUERY_ENCODER_MODEL_NAME, device=device, cache_folder=cache_dir
        )
        query_encoder.eval()
        state.query_encoder_model = query_encoder
        logger.info("Query encoder model loaded successfully.")
    except Exception as e:
        logger.exception(f"Failed to load Sentence Transformer model: {e}")
        return False

    state.artifacts_loaded = True
    return True
```
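`load_retrieval_artifacts()` expects an `.npz` archive with three arrays: `transformed_chunk_embeddings`, `chunk_ids`, and a one-element `temperature`. The export script that produces it is not part of this file; the sketch below is only an illustration of an offline step that would match that layout. The function name and arguments are mine, and I am assuming the GNN's chunk embeddings are L2-normalized to mirror the query-side normalization used at serving time:

```python
# Hypothetical offline export step; only the .npz layout matches the loader above.
import numpy as np
import torch
import torch.nn.functional as F

def export_retrieval_artifacts(gnn_model, chunk_embeddings, chunk_ids, temperature, path):
    """Write the archive in the layout load_retrieval_artifacts() expects."""
    with torch.no_grad():
        # Pre-apply the trained transform and normalize once, offline,
        # so serving only needs a dot product against these vectors.
        transformed = F.normalize(gnn_model(chunk_embeddings), p=2, dim=-1).cpu().numpy()
    np.savez(
        path,
        transformed_chunk_embeddings=transformed,
        chunk_ids=np.array(chunk_ids),
        temperature=np.array([float(temperature)]),  # loader reads element [0]
    )
```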
```python
def find_top_gnn_chunks(query_text: str, top_n: int = 200) -> List[Tuple[str, float]]:
    """
    Performs a similarity search that is mathematically identical to the trained
    model, but without loading the GNN itself, by using the pre-transformed
    embeddings.
    """
    if not state.artifacts_loaded:
        logger.error("Service not ready. Retrieval artifacts not loaded.")
        return []

    start_time = time.time()
    try:
        with torch.no_grad():
            # 1. Encode the query text into an embedding vector.
            query_embedding = state.query_encoder_model.encode(
                query_text, convert_to_tensor=True, device=state.device
            )
            # 2. L2-normalize the query embedding.
            q_trans_normalized = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)
            # 3. Convert to numpy for a fast similarity calculation.
            query_vec_np = q_trans_normalized.cpu().numpy()
            # 4. Dot product against all pre-transformed chunk embeddings.
            similarities = (query_vec_np @ state.transformed_chunk_embeddings.T)[0]
            # 5. Apply the learned temperature scaling.
            scaled_similarities = similarities * np.exp(state.temperature)
            # 6. Combine with IDs and sort by score, descending.
            results = list(zip(state.chunk_ids_in_order, scaled_similarities))
            results.sort(key=lambda item: item[1], reverse=True)
    except Exception as e:
        logger.exception(f"Error during model-based similarity search: {e}")
        return []

    duration = time.time() - start_time
    logger.info(f"Similarity search completed in {duration:.4f} seconds.")

    # Log the top results with each chunk's original file for debugging.
    # NOTE: chunk_metadata_map is assumed to be populated elsewhere at startup.
    top_results_to_log = results[:30]
    logger.info(f"--- Top {len(top_results_to_log)} Retrieved Chunks (from retrieval service) ---")
    for i, (chunk_id, score) in enumerate(top_results_to_log):
        chunk_id_str = str(chunk_id)  # Ensure the key is a string for lookup
        metadata = state.chunk_metadata_map.get(chunk_id_str, {})
        original_file = metadata.get('original_file', 'File not found')
        logger.info(f"  {i + 1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}")
    logger.info("----------------------------------------------------------------")

    return results[:top_n]
```
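How this service is wired into an application is not shown above. A minimal sketch, assuming a FastAPI app (the `app.core.config` / `app.services` layout suggests one; the route path and response shape are my invention):

```python
# Hypothetical wiring -- the actual app entry point is not shown above.
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException

from app.services.retrieval import load_retrieval_artifacts, find_top_gnn_chunks

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load (and, if needed, download) the artifacts once at startup.
    if not load_retrieval_artifacts():
        raise RuntimeError("Failed to load retrieval artifacts")
    yield

app = FastAPI(lifespan=lifespan)

@app.get("/search")
def search(q: str, top_n: int = 20):
    results = find_top_gnn_chunks(q, top_n=top_n)
    if not results:
        raise HTTPException(status_code=503, detail="Retrieval service not ready")
    return [{"chunk_id": str(cid), "score": float(score)} for cid, score in results]
```

Loading in the startup hook means a cold start pays the S3 download and model load once; every request afterwards is a pure in-memory matrix product.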