# app/services/retrieval.py
import logging
import os
import time
from typing import List, Tuple

import numpy as np
import requests
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Import settings and state
from app.core.config import settings
from app.core import state

logger = logging.getLogger(__name__)
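
# Settings fields this module relies on (summarized for reference; all names
# are taken from their usage below, not invented here):
#   settings.RETRIEVAL_ARTIFACTS_PATH   - local path to the .npz artifacts file
#   settings.S3_ARTIFACTS_URL           - HTTP(S) URL used to fetch the file if absent
#   settings.QUERY_ENCODER_MODEL_NAME   - Sentence Transformer model name/id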
def load_retrieval_artifacts():
    """
    Loads all necessary artifacts for retrieval. If running on a new server,
    it first downloads the artifacts from S3.
    """
    if state.artifacts_loaded:
        logger.info("Retrieval artifacts already loaded in state.")
        return True

    # --- ADDED: Download from S3 if the file doesn't exist locally ---
    artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
    if not os.path.exists(artifacts_path) and settings.S3_ARTIFACTS_URL:
        logger.info(f"Artifacts file not found at {artifacts_path}. Downloading from S3...")
        try:
            # Create the 'data' directory if it doesn't exist
            os.makedirs(os.path.dirname(artifacts_path), exist_ok=True)
            with requests.get(settings.S3_ARTIFACTS_URL, stream=True) as r:
                r.raise_for_status()
                with open(artifacts_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            logger.info("Successfully downloaded artifacts from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download artifacts from S3: {e}")
            return False
    # --- END OF ADDITION ---
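    # NOTE: The download above assumes S3_ARTIFACTS_URL is a public or presigned
    # HTTP(S) URL reachable with plain requests. For a private bucket, an S3
    # client such as boto3 would be needed instead (an alternative, not what
    # this code does).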
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device for retrieval: {device}")
    state.device = device

    # 1. Load the pre-computed artifacts file
    logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
    try:
        if not os.path.exists(artifacts_path):
            logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
            return False
        artifacts = np.load(artifacts_path, allow_pickle=True)
        # Load into state
        state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
        state.chunk_ids_in_order = artifacts['chunk_ids']
        state.temperature = artifacts['temperature'][0]  # Extract the scalar from a 1-element array
        logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
        logger.info(f"Loaded temperature value: {state.temperature:.4f}")
    except Exception as e:
        logger.exception(f"Failed to load and process retrieval artifacts: {e}")
        return False
    # 2. Load the Sentence Transformer model for encoding queries
    logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
    try:
        cache_dir = "data/cache/"
        logger.info(f"Using cache directory for models: {cache_dir}")
        query_encoder = SentenceTransformer(
            settings.QUERY_ENCODER_MODEL_NAME, device=device, cache_folder=cache_dir
        )
        query_encoder.eval()
        state.query_encoder_model = query_encoder
        logger.info("Query encoder model loaded successfully.")
    except Exception as e:
        logger.exception(f"Failed to load Sentence Transformer model: {e}")
        return False

    state.artifacts_loaded = True
    return True
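
# A minimal sketch of how the .npz artifacts file loaded above could be
# produced offline. This is an assumption based on the keys read in
# load_retrieval_artifacts(); the actual export pipeline is not part of this file:
#
#   np.savez_compressed(
#       "data/retrieval_artifacts.npz",             # hypothetical path
#       transformed_chunk_embeddings=embeddings,    # (num_chunks, dim) float32 matrix
#       chunk_ids=np.array(chunk_ids),              # IDs aligned with the rows above
#       temperature=np.array([log_temperature]),    # 1-element array; exp() is applied at query time
#   )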
def find_top_gnn_chunks(query_text: str, top_n: int = 200) -> List[Tuple[str, float]]:
    """
    Performs a similarity search that is mathematically identical to the trained model,
    but without loading the GNN itself. It uses the pre-transformed embeddings.
    """
    if not state.artifacts_loaded:
        logger.error("Service not ready. Retrieval artifacts not loaded.")
        return []

    start_time = time.time()
    try:
        with torch.no_grad():
            # 1. Encode the query text into an embedding vector
            query_embedding = state.query_encoder_model.encode(
                query_text, convert_to_tensor=True, device=state.device
            )
            # 2. L2-normalize the query embedding
            q_trans_normalized = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)
            # 3. Convert to numpy for a fast similarity calculation
            query_vec_np = q_trans_normalized.cpu().numpy()
            # 4. Dot product against all pre-transformed chunk embeddings
            similarities = (query_vec_np @ state.transformed_chunk_embeddings.T)[0]
            # 5. Apply the learned temperature scaling (the stored value is a
            #    log-temperature, so exp() recovers the positive scale factor)
            scaled_similarities = similarities * np.exp(state.temperature)
            # 6. Combine with IDs and sort by descending score
            results = list(zip(state.chunk_ids_in_order, scaled_similarities))
            results.sort(key=lambda item: item[1], reverse=True)
    except Exception as e:
        logger.exception(f"Error during model-based similarity search: {e}")
        return []
    duration = time.time() - start_time
    logger.info(f"Similarity search completed in {duration:.4f} seconds.")

    # --- MODIFIED: Detailed logging to include original_file ---
    top_results_to_log = results[:30]  # keep the slice consistent with the header below
    logger.info("--- Top 30 Retrieved Chunks (from retrieval service) ---")
    for i, (chunk_id, score) in enumerate(top_results_to_log):
        # Look up metadata for the current chunk_id; chunk_metadata_map is
        # expected to be populated elsewhere in app.core.state.
        chunk_id_str = str(chunk_id)  # Ensure the key is a string for the lookup
        metadata = state.chunk_metadata_map.get(chunk_id_str, {})
        original_file = metadata.get('original_file', 'File not found')
        # Log message includes the original file alongside the score
        logger.info(f"  {i+1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}")
    logger.info("----------------------------------------------------------------")
    # --- END OF MODIFICATION ---

    return results[:top_n]
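
# Usage sketch (hypothetical wiring; not part of this module). Assuming the
# service is initialized once at application startup, e.g. from a FastAPI
# startup hook (the app/ layout suggests FastAPI, but that is an assumption):
#
#   from app.services.retrieval import load_retrieval_artifacts, find_top_gnn_chunks
#
#   if load_retrieval_artifacts():
#       for chunk_id, score in find_top_gnn_chunks("example query", top_n=10):
#           print(chunk_id, score)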