helal94hb1 committed
Commit c53ff60 · 1 Parent(s): bea2de8

fix: new embeddings and reranker2

Files changed (1):
  1. app/services/retrieval.py +4 -7
app/services/retrieval.py CHANGED
@@ -57,10 +57,8 @@ def load_retrieval_artifacts():
         # Load into state
         state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
         state.chunk_ids_in_order = artifacts['chunk_ids']
-        state.wq_weights = torch.from_numpy(artifacts['wq_weights']).to(device)
         state.temperature = artifacts['temperature'][0]  # Extract scalar from array
         logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
-        logger.info(f"Loaded Wq matrix of shape: {state.wq_weights.shape}")
         logger.info(f"Loaded temperature value: {state.temperature:.4f}")

     except Exception as e:
@@ -86,7 +84,7 @@ def load_retrieval_artifacts():

 # In app/services/retrieval.py

-def find_top_gnn_chunks(query_text: str, top_n: int = 100) -> List[Tuple[str, float]]:
+def find_top_gnn_chunks(query_text: str, top_n: int = 200) -> List[Tuple[str, float]]:
     """
     Performs a similarity search that is mathematically identical to the trained model,
     but without loading the GNN itself. It uses pre-transformed embeddings.
@@ -104,9 +102,8 @@ def find_top_gnn_chunks(query_text: str, top_n: int = 100) -> List[Tuple[str, float]]:
         query_text, convert_to_tensor=True, device=state.device
     )

-    # 2. Apply the learned 'Wq' transformation to the query embedding
-    q_trans = F.linear(query_embedding.unsqueeze(0), state.wq_weights)
-    q_trans_normalized = F.normalize(q_trans, p=2, dim=-1)
+    # 2. Apply query normalization to the query embedding
+    q_trans_normalized = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)

     # 3. Convert to numpy for fast similarity calculation
     query_vec_np = q_trans_normalized.cpu().numpy()
@@ -115,7 +112,7 @@ def find_top_gnn_chunks(query_text: str, top_n: int = 100) -> List[Tuple[str, float]]:
     similarities = (query_vec_np @ state.transformed_chunk_embeddings.T)[0]

     # 5. Apply the learned temperature scaling
-    scaled_similarities = similarities * state.temperature
+    scaled_similarities = similarities * np.exp(state.temperature)

     # 6. Combine with IDs, sort, and return top N
     results = list(zip(state.chunk_ids_in_order, scaled_similarities))
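For orientation, here is a minimal sketch of what the retrieval path looks like after this commit, with the two removed Wq lines gone and the new normalization and scaling in place. The state object, its embedding_model attribute (assumed to be a SentenceTransformer-style encoder), and the artifact shapes are assumptions inferred from the surrounding code, not guaranteed by the diff.

import numpy as np
import torch.nn.functional as F
from typing import List, Tuple

def find_top_chunks_sketch(state, query_text: str, top_n: int = 200) -> List[Tuple[str, float]]:
    # 1. Embed the query on the configured device (SentenceTransformer-style encode call).
    query_embedding = state.embedding_model.encode(
        query_text, convert_to_tensor=True, device=state.device
    )
    # 2. L2-normalize the raw query embedding (no learned Wq projection any more).
    q_normalized = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)
    # 3. Dot product against the pre-transformed chunk matrix (assumed float32, shape (num_chunks, dim)).
    similarities = (q_normalized.cpu().numpy() @ state.transformed_chunk_embeddings.T)[0]
    # 4. Scale by exp(temperature), matching the updated line in the diff.
    scaled = similarities * np.exp(state.temperature)
    # 5. Pair with chunk IDs, sort descending, return the top_n.
    ranked = sorted(zip(state.chunk_ids_in_order, scaled), key=lambda item: item[1], reverse=True)
    return ranked[:top_n]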
 
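For contrast with the deleted lines above, a small sketch of the query-side transform before and after this commit. The 768-dimensional size and the random tensors are placeholders; state.wq_weights is assumed to have been a square (out_features, in_features) matrix, as implied by the F.linear call it was passed to.

import torch
import torch.nn.functional as F

query_embedding = torch.randn(768)   # placeholder for the encoder output
wq_weights = torch.randn(768, 768)   # placeholder for the removed learned projection

# Before: project the query through the learned Wq, then L2-normalize.
q_before = F.normalize(F.linear(query_embedding.unsqueeze(0), wq_weights), p=2, dim=-1)

# After: L2-normalize the raw query embedding directly.
q_after = F.normalize(query_embedding.unsqueeze(0), p=2, dim=-1)

# The chunk side is pre-transformed offline, so dropping Wq on the query side only
# stays consistent if the stored transformed_chunk_embeddings were regenerated to
# match -- which is what the "new embeddings" part of the commit message suggests.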
 
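Finally, the scaling step changes from multiplying by state.temperature to multiplying by np.exp(state.temperature). The diff does not say why, but a common reason is that training optimizes a log-temperature (keeping the effective scale positive), in which case the stored scalar has to be exponentiated at inference time. A small numeric illustration under that assumption, with made-up values:

import numpy as np

similarities = np.array([0.82, 0.47, 0.15])   # hypothetical cosine similarities
stored_temperature = 2.3                      # hypothetical artifact value (assumed log-scale)

before = similarities * stored_temperature           # ~[1.89, 1.08, 0.35]
after = similarities * np.exp(stored_temperature)    # ~[8.18, 4.69, 1.50]

# Either way the scale factor is a positive constant per query, so the ranking of
# chunks is unchanged; only the magnitude of the scores passed to any downstream
# reranker or threshold changes.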