backend_chatbot / app /services /reranker_service.py
helal94hb1's picture
fix: new embeddings and rerankh5
c046733
import logging
import torch
from typing import List, Dict
from transformers import PreTrainedTokenizer
import math
import os
import requests
# --- Import your custom model architecture ---
from app.models.expert_judge_model import ExpertJudgeCrossEncoder, get_tokenizer
from app.core import state
from app.core.config import settings
logger = logging.getLogger(__name__)
def _download_file(url: str, dest: str, timeout=(10, 120), chunk_size: int = 8192) -> None:
    """
    Stream ``url`` to ``dest`` without ever leaving a partial file at ``dest``.

    The response body is written to ``dest + ".part"`` and moved into place
    with ``os.replace`` only after it has been fully read. An interrupted
    download therefore cannot later pass an ``os.path.exists(dest)`` check and
    be loaded as if it were a complete file.

    Args:
        url: Source URL to fetch.
        dest: Final destination path; missing parent directories are created.
        timeout: ``requests`` (connect, read) timeout in seconds. The read
            timeout applies to each socket read, not to the whole transfer.
        chunk_size: Number of bytes per streamed chunk.

    Raises:
        Any ``requests`` exception (including ``HTTPError`` from
        ``raise_for_status``) or ``OSError`` from the filesystem.
    """
    dest_dir = os.path.dirname(dest)
    if dest_dir:  # dirname of a bare filename is "", which os.makedirs rejects
        os.makedirs(dest_dir, exist_ok=True)
    tmp_path = f"{dest}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # defensive: skip empty chunks
                        f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # On success the temp file has already been renamed away; on failure,
        # remove the truncated file so a later call retries cleanly.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                logger.warning(f"Could not remove partial download {tmp_path}")


def load_reranker_model():
    """
    Load the custom-trained ExpertJudgeCrossEncoder and its tokenizer into ``state``.

    If the weights file at ``settings.RERANKER_MODEL_PATH`` does not exist and
    ``settings.S3_RERANKER_URL`` is set, the file is downloaded first.

    On success this sets ``state.reranker_model``,
    ``state.reranker_tokenizer`` and ``state.reranker_model_loaded``.

    Returns:
        True if the model is loaded (or already was), False if the download
        or the model load failed. Failures are logged, never raised.
    """
    if state.reranker_model_loaded:
        logger.info("Re-ranker model already loaded in state.")
        return True

    model_path = settings.RERANKER_MODEL_PATH
    if not os.path.exists(model_path) and settings.S3_RERANKER_URL:
        logger.info(f"Re-ranker model not found at {model_path}. Downloading from S3...")
        try:
            _download_file(settings.S3_RERANKER_URL, model_path)
            logger.info("Successfully downloaded re-ranker model from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download re-ranker model from S3: {e}")
            return False

    base_model_name = settings.RERANKER_MODEL_NAME
    logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
    logger.info(f"Using base model architecture: {base_model_name}")
    try:
        # 1. Build the architecture, then load the saved state_dict into it.
        model = ExpertJudgeCrossEncoder(model_name=base_model_name)
        # NOTE(review): torch.load unpickles the file, and the file may have
        # been fetched from a remote URL. If the installed torch supports it
        # (>= 1.13), consider weights_only=True, since only a state_dict is
        # expected here.
        model.load_state_dict(torch.load(model_path, map_location=state.device))

        # 2. Prepare the model for inference.
        model.to(state.device)
        model.eval()

        # 3. Load the tokenizer that matches the base model.
        tokenizer = get_tokenizer(model_name=base_model_name)

        # 4. Publish both to shared state; set the loaded flag last.
        state.reranker_model = model
        state.reranker_tokenizer = tokenizer
        state.reranker_model_loaded = True
        logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
        return True
    except Exception as e:
        logger.exception(f"Failed to load custom Cross-Encoder model: {e}")
        return False
def rerank_chunks(query: str, chunks: List[Dict], metadata_map: Dict) -> List[Dict]:
"""
Re-ranks chunks using the custom ExpertJudgeCrossEncoder with the concise
v4 input format.
"""
if not state.reranker_model_loaded or not chunks:
logger.warning("Re-ranker not loaded or no chunks provided. Returning original order.")
return chunks
logger.info(f"Re-ranking {len(chunks)} chunks with custom Expert Judge (v4 format)...")
model = state.reranker_model
tokenizer: PreTrainedTokenizer = state.reranker_tokenizer
scores = []
# --- EDIT: Create pairs of [query, passage] for the tokenizer ---
# The BGE-Reranker model was trained on this simple format.
query_passage_pairs = [[query, chunk.get('text', '')] for chunk in chunks]
with torch.no_grad():
# --- EDIT: Tokenize the pairs directly ---
# The tokenizer will handle adding [CLS], [SEP] tokens correctly for pairs.
encoded = tokenizer(
query_passage_pairs,
padding=True,
truncation=True,
return_tensors='pt',
max_length=512
)
encoded = {k: v.to(state.device) for k, v in encoded.items()}
logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])
# scores = torch.sigmoid(logits).squeeze().cpu().numpy()
scores = logits.squeeze().cpu().numpy()
# Add the new, more accurate score to each chunk dictionary
if len(chunks) == 1:
chunks[0]['rerank_score'] = float(scores)
else:
for i, chunk in enumerate(chunks):
chunk['rerank_score'] = scores[i]
# Sort the chunks by the new score in descending order
sorted_chunks = sorted(chunks, key=lambda x: x.get('rerank_score', 0.0), reverse=True)
logger.info("--- Top 5 Custom Re-ranked Chunks (v4) ---")
for i, chunk in enumerate(sorted_chunks[:100]):
logger.info(f" {i+1}. Chunk ID: {chunk.get('id', 'N/A')} | New Score: {chunk.get('rerank_score', 0.0):.4f}")
logger.info("--------------------------------")
return sorted_chunks