# app/services/retrieval.py

import logging
import time
import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Tuple
from sentence_transformers import SentenceTransformer
import os
import requests

# Import settings and state
from app.core.config import settings
from app.core import state
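
# The settings object is assumed to expose (defined in app/core/config.py,
# which is not shown here):
#   RETRIEVAL_ARTIFACTS_PATH  - local path to the precomputed .npz artifacts
#   S3_ARTIFACTS_URL          - optional URL for downloading the artifacts
#   QUERY_ENCODER_MODEL_NAME  - sentence-transformers model used to encode queries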

logger = logging.getLogger(__name__)

def load_retrieval_artifacts():
    """
    Loads all necessary artifacts for retrieval. If running on a new server,
    it first downloads the artifacts from S3.
    """
    if state.artifacts_loaded:
        logger.info("Retrieval artifacts already loaded in state.")
        return True
 
    # Download the artifacts from S3 if they are not present locally.
    artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
    if not os.path.exists(artifacts_path) and settings.S3_ARTIFACTS_URL:
        logger.info(f"Artifacts file not found at {artifacts_path}. Downloading from S3...")
        try:
            # Ensure the parent directory of the artifacts file exists
            os.makedirs(os.path.dirname(artifacts_path), exist_ok=True)
            with requests.get(settings.S3_ARTIFACTS_URL, stream=True) as r:
                r.raise_for_status()
                with open(artifacts_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            logger.info("Successfully downloaded artifacts from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download artifacts from S3: {e}")
            return False
 
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device for retrieval: {device}")
    state.device = device
 
    # 1. Load the pre-computed artifacts file
    logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
    try:
        if not os.path.exists(artifacts_path):
            logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
            return False
        artifacts = np.load(artifacts_path, allow_pickle=True)
        # Load into state
        state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
        state.chunk_ids_in_order = artifacts['chunk_ids']
        state.temperature = artifacts['temperature'][0] # Extract scalar from array
        logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
        logger.info(f"Loaded temperature value: {state.temperature:.4f}")
 
    except Exception as e:
        logger.exception(f"Failed to load and process retrieval artifacts: {e}")
        return False
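
    # For reference, the artifacts file is expected to be an .npz produced by
    # an offline export step along these lines (a sketch; the real export
    # script is not part of this module, and the variable names here are
    # illustrative):
    #
    #   np.savez(
    #       settings.RETRIEVAL_ARTIFACTS_PATH,
    #       transformed_chunk_embeddings=chunk_matrix,  # (num_chunks, dim)
    #       chunk_ids=np.array(all_chunk_ids),          # (num_chunks,)
    #       temperature=np.array([learned_log_temp]),   # scalar in log-space
    #   )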
 
    # 2. Load the Sentence Transformer model for encoding queries
    logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
    try:
        cache_dir = "data/cache/"
        logger.info(f"Using cache directory for models: {cache_dir}")
        query_encoder = SentenceTransformer(
            settings.QUERY_ENCODER_MODEL_NAME, device=device, cache_folder=cache_dir
        )
        query_encoder.eval()
        state.query_encoder_model = query_encoder
        logger.info("Query encoder model loaded successfully.")
    except Exception as e:
        logger.exception(f"Failed to load Sentence Transformer model: {e}")
        return False
    state.artifacts_loaded = True
    return True


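# Example startup wiring (a sketch, not part of the original module: it
# assumes this service runs inside a FastAPI app, which may not match the
# real entry point):
#
#   from fastapi import FastAPI
#   from app.services.retrieval import load_retrieval_artifacts
#
#   app = FastAPI()
#
#   @app.on_event("startup")
#   def _load_retrieval() -> None:
#       if not load_retrieval_artifacts():
#           raise RuntimeError("Failed to load retrieval artifacts")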

def find_top_gnn_chunks(query_text: str, top_n: int = 200) -> List[Tuple[str, float]]:
    """
    Performs a similarity search that is mathematically identical to the trained model,
    but without loading the GNN itself. It uses pre-transformed embeddings.
    """
    if not state.artifacts_loaded:
        logger.error("Service not ready. Retrieval artifacts not loaded.")
        return []

    start_time = time.time()
    
    try:
        with torch.no_grad():
            # 1. Encode query text into an embedding vector
            query_embedding = state.query_encoder_model.encode(
                query_text, convert_to_tensor=True, device=state.device
            )
            
            # 2. L2-normalize the query embedding
            q_normalized = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)

            # 3. Convert to numpy for the similarity calculation
            query_vec_np = q_normalized.cpu().numpy()
            
            # 4. Dot product against every pre-transformed chunk embedding
            similarities = (query_vec_np @ state.transformed_chunk_embeddings.T)[0]
            
            # 5. Apply the learned temperature (stored in log-space, hence exp)
            scaled_similarities = similarities * np.exp(state.temperature)

        # 6. Combine with IDs, sort, and return top N
        results = list(zip(state.chunk_ids_in_order, scaled_similarities))
        results.sort(key=lambda item: item[1], reverse=True)

    except Exception as e:
        logger.exception(f"Error during model-based similarity search: {e}")
        return []

    duration = time.time() - start_time
    logger.info(f"Similarity search completed in {duration:.4f} seconds.")

    # Log the top results with their source files for easier debugging.
    # state.chunk_metadata_map is assumed to be populated elsewhere at startup.
    top_results_to_log = results[:30]
    logger.info("--- Top 30 Retrieved Chunks (from retrieval service) ---")
    for i, (chunk_id, score) in enumerate(top_results_to_log):
        # Metadata keys are strings, so normalize the chunk ID before lookup
        chunk_id_str = str(chunk_id)
        metadata = state.chunk_metadata_map.get(chunk_id_str, {})
        original_file = metadata.get('original_file', 'File not found')
        logger.info(f"  {i+1}. Chunk ID: {chunk_id_str} | File: {original_file} | Score: {score:.4f}")
    logger.info("----------------------------------------------------------------")
    
    return results[:top_n]
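

# Minimal manual smoke test (a sketch; assumes valid artifacts exist at
# settings.RETRIEVAL_ARTIFACTS_PATH or are downloadable from
# settings.S3_ARTIFACTS_URL, and that state.chunk_metadata_map is populated
# so the per-chunk logging above can resolve file names):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if load_retrieval_artifacts():
        for chunk_id, score in find_top_gnn_chunks("example query", top_n=5):
            print(f"{chunk_id}\t{score:.4f}")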