# app/services/retrieval.py
import logging
import os
import time
from typing import List, Tuple

import numpy as np
import requests
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Import settings and state
from app.core.config import settings
from app.core import state

logger = logging.getLogger(__name__)
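
# The attributes below are assumed to live in app/core/state.py, a plain module
# used as shared mutable state across the app. Sketched here for readability
# (names are inferred from how this file uses them; the real definitions may differ):
#
#   artifacts_loaded: bool = False
#   device: torch.device | None = None
#   transformed_chunk_embeddings: np.ndarray | None = None  # (num_chunks, dim)
#   chunk_ids_in_order: np.ndarray | None = None            # (num_chunks,)
#   temperature: float | None = None                        # used as exp(temperature) below
#   query_encoder_model: SentenceTransformer | None = None
#   chunk_metadata_map: dict = {}                           # chunk_id -> metadata dict
#
# settings is likewise assumed to expose RETRIEVAL_ARTIFACTS_PATH,
# S3_ARTIFACTS_URL, and QUERY_ENCODER_MODEL_NAME (all referenced below).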

def load_retrieval_artifacts():
    """
    Loads all necessary artifacts for retrieval. If running on a new server,
    it first downloads the artifacts from S3.
    """
    if state.artifacts_loaded:
        logger.info("Retrieval artifacts already loaded in state.")
        return True

    # Download the artifacts from S3 if they don't exist locally
    artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
    if not os.path.exists(artifacts_path) and settings.S3_ARTIFACTS_URL:
        logger.info(f"Artifacts file not found at {artifacts_path}. Downloading from S3...")
        try:
            # Create the 'data' directory if it doesn't exist
            os.makedirs(os.path.dirname(artifacts_path), exist_ok=True)
            with requests.get(settings.S3_ARTIFACTS_URL, stream=True) as r:
                r.raise_for_status()
                with open(artifacts_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            logger.info("Successfully downloaded artifacts from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download artifacts from S3: {e}")
            return False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device for retrieval: {device}")
    state.device = device

    # 1. Load the pre-computed artifacts file
    logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
    try:
        if not os.path.exists(artifacts_path):
            logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
            return False
        artifacts = np.load(artifacts_path, allow_pickle=True)
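        # allow_pickle=True is assumed necessary because the archive stores
        # object arrays (e.g., string chunk IDs); it also means the file must
        # come from a trusted source, as unpickling untrusted data can run code.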
        # Load into state
        state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
        state.chunk_ids_in_order = artifacts['chunk_ids']
        state.temperature = artifacts['temperature'][0]  # Extract scalar from 1-element array
        logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
        logger.info(f"Loaded temperature value: {state.temperature:.4f}")
    except Exception as e:
        logger.exception(f"Failed to load and process retrieval artifacts: {e}")
        return False

    # 2. Load the Sentence Transformer model for encoding queries
    logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
    try:
        # Cache model downloads under data/cache/ rather than the default HF cache
        cache_dir = "data/cache/"
        logger.info(f"Using cache directory for models: {cache_dir}")
        query_encoder = SentenceTransformer(
            settings.QUERY_ENCODER_MODEL_NAME, device=device, cache_folder=cache_dir
        )
        query_encoder.eval()
        state.query_encoder_model = query_encoder
        logger.info("Query encoder model loaded successfully.")
    except Exception as e:
        logger.exception(f"Failed to load Sentence Transformer model: {e}")
        return False

    state.artifacts_loaded = True
    return True
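

# For reference, the artifacts archive loaded above is assumed to be an .npz
# file produced offline, roughly like this (a sketch -- only the key names are
# taken from the lookups in load_retrieval_artifacts(); the variable names and
# path are illustrative):
#
#   np.savez(
#       "data/retrieval_artifacts.npz",
#       transformed_chunk_embeddings=chunk_embs,  # (num_chunks, dim), already GNN-transformed
#       chunk_ids=chunk_ids,                      # (num_chunks,) chunk identifiers
#       temperature=np.array([log_temp]),         # 1-element array; used as exp(temperature)
#   )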

def find_top_gnn_chunks(query_text: str, top_n: int = 200) -> List[Tuple[str, float]]:
    """
    Performs a similarity search that is mathematically identical to the trained model,
    but without loading the GNN itself. It uses the pre-transformed embeddings.
    """
    if not state.artifacts_loaded:
        logger.error("Service not ready. Retrieval artifacts not loaded.")
        return []
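
    # Scoring recap: the chunk embeddings were already passed through the
    # trained GNN offline, so at query time the model's score reduces to
    #     score_i = exp(T) * (q . c_i)
    # where q is the L2-normalized query embedding and c_i the i-th stored,
    # pre-transformed chunk embedding. Since exp(T) > 0, the temperature
    # scaling calibrates score magnitudes but never changes the ranking.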
    start_time = time.time()
    try:
        with torch.no_grad():
            # 1. Encode the query text into an embedding vector
            query_embedding = state.query_encoder_model.encode(
                query_text, convert_to_tensor=True, device=state.device
            )
            # 2. L2-normalize the query embedding
            q_trans_normalized = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)
            # 3. Convert to NumPy for the similarity calculation
            query_vec_np = q_trans_normalized.cpu().numpy()
            # 4. Dot product against all pre-transformed chunk embeddings
            similarities = (query_vec_np @ state.transformed_chunk_embeddings.T)[0]
            # 5. Apply the learned temperature scaling
            scaled_similarities = similarities * np.exp(state.temperature)
            # 6. Pair scores with chunk IDs and sort descending
            results = list(zip(state.chunk_ids_in_order, scaled_similarities))
            results.sort(key=lambda item: item[1], reverse=True)
    except Exception as e:
        logger.exception(f"Error during model-based similarity search: {e}")
        return []

    duration = time.time() - start_time
    logger.info(f"Similarity search completed in {duration:.4f} seconds.")

    # Detailed logging of the top results, including each chunk's source file
    top_results_to_log = results[:200]
    logger.info("--- Top 200 Retrieved Chunks (from retrieval service) ---")
    for i, (chunk_id, score) in enumerate(top_results_to_log):
        # Look up metadata for the current chunk_id
        chunk_id_str = str(chunk_id)  # Ensure the key is a string for lookup
        metadata = state.chunk_metadata_map.get(chunk_id_str, {})
        original_file = metadata.get('original_file', 'File not found')
        logger.info(f"  {i+1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}")
    logger.info("----------------------------------------------------------------")

    return results[:top_n]
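

# Illustrative usage (hypothetical -- in the app this module is assumed to be
# driven by a startup hook and the request handlers, not run directly):
#
#   from app.services.retrieval import load_retrieval_artifacts, find_top_gnn_chunks
#
#   if load_retrieval_artifacts():
#       for chunk_id, score in find_top_gnn_chunks("how do I rotate my API key?", top_n=5):
#           print(chunk_id, score)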