import contextlib
import logging
import math
import os
from typing import List, Dict

import requests
import torch
from transformers import PreTrainedTokenizer

# --- Import your custom model architecture ---
from app.models.expert_judge_model import ExpertJudgeCrossEncoder, get_tokenizer
from app.core import state
from app.core.config import settings

logger = logging.getLogger(__name__)

# Bytes read per iteration while streaming the model file from S3.
_DOWNLOAD_CHUNK_BYTES = 1024 * 1024
# (connect, read) timeouts in seconds for the S3 request. With stream=True the
# read timeout bounds each stall between received bytes, not the whole transfer.
_DOWNLOAD_TIMEOUT = (10, 60)
# How many of the top re-ranked chunks are written to the log.
_LOG_TOP_N = 5


def _download_file(url: str, dest_path: str) -> None:
    """Stream ``url`` to ``dest_path``, moving it into place only when complete.

    The payload is written to a sibling ``<dest_path>.part`` file and renamed
    with ``os.replace`` after the transfer succeeds. An interrupted download
    therefore never leaves a truncated file at ``dest_path``; a later call would
    otherwise treat it as a valid model, because callers only check existence.

    Raises:
        requests.RequestException: on connection, timeout or HTTP-status errors.
        OSError: on filesystem errors.
    """
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:  # os.makedirs("") raises when dest_path is a bare filename
        os.makedirs(dest_dir, exist_ok=True)

    tmp_path = f"{dest_path}.part"
    try:
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_BYTES):
                    f.write(chunk)
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Remove the partial file on any failure (including interrupts), then
        # let the original exception propagate to the caller.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def load_reranker_model():
    """Load the custom-trained ExpertJudgeCrossEncoder and its tokenizer into ``state``.

    If the weights file at ``settings.RERANKER_MODEL_PATH`` is missing and
    ``settings.S3_RERANKER_URL`` is set, the file is downloaded first.

    On success this sets ``state.reranker_model``, ``state.reranker_tokenizer``
    and ``state.reranker_model_loaded = True``.

    Returns:
        True if the model is (or already was) loaded; False if the download or
        the load failed. Failures are logged with their traceback.
    """
    if state.reranker_model_loaded:
        logger.info("Re-ranker model already loaded in state.")
        return True

    model_path = settings.RERANKER_MODEL_PATH
    if not os.path.exists(model_path) and settings.S3_RERANKER_URL:
        logger.info(f"Re-ranker model not found at {model_path}. Downloading from S3...")
        try:
            _download_file(settings.S3_RERANKER_URL, model_path)
            logger.info("Successfully downloaded re-ranker model from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download re-ranker model from S3: {e}")
            return False

    base_model_name = settings.RERANKER_MODEL_NAME
    logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
    logger.info(f"Using base model architecture: {base_model_name}")

    try:
        # 1. Instantiate the model structure.
        model = ExpertJudgeCrossEncoder(model_name=base_model_name)

        # 2. Load the saved weights (a state_dict) into that structure.
        # SECURITY NOTE(review): torch.load unpickles the file, and this file may
        # have been fetched from a URL. Consider torch.load(..., weights_only=True)
        # once the deployed torch version (>= 1.13) is confirmed.
        state_dict = torch.load(model_path, map_location=state.device)
        model.load_state_dict(state_dict)

        # 3. Prepare the model for inference.
        model.to(state.device)
        model.eval()

        # 4. Load the tokenizer matching the base architecture.
        tokenizer = get_tokenizer(model_name=base_model_name)

        # 5. Publish both to the shared state.
        state.reranker_model = model
        state.reranker_tokenizer = tokenizer
        state.reranker_model_loaded = True
        logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
        return True
    except Exception as e:
        logger.exception(f"Failed to load custom Cross-Encoder model: {e}")
        return False


def rerank_chunks(query: str, chunks: List[Dict], metadata_map: Dict) -> List[Dict]:
    """Re-rank ``chunks`` against ``query`` with the custom ExpertJudgeCrossEncoder (v4 format).

    Each chunk is scored as a ``[query, chunk_text]`` pair in a single batch.

    Args:
        query: The search query.
        chunks: Chunk dicts. ``'text'`` is scored and ``'id'`` is only logged.
            Every dict is mutated in place: a float ``'rerank_score'`` is added.
        metadata_map: Currently unused; kept for interface compatibility.

    Returns:
        A new list holding the same dicts, sorted by ``'rerank_score'``
        descending. If the model is not loaded or ``chunks`` is empty, the
        input list is returned unchanged.

    Raises:
        ValueError: if the model does not produce exactly one score per chunk.
    """
    if not state.reranker_model_loaded or not chunks:
        logger.warning("Re-ranker not loaded or no chunks provided. Returning original order.")
        return chunks

    logger.info(f"Re-ranking {len(chunks)} chunks with custom Expert Judge (v4 format)...")
    model = state.reranker_model
    tokenizer: PreTrainedTokenizer = state.reranker_tokenizer

    # The BGE-Reranker base was trained on plain [query, passage] pairs; the
    # tokenizer inserts the pair special tokens itself. `or ''` also covers
    # chunks whose 'text' key exists but holds None.
    query_passage_pairs = [[query, chunk.get('text') or ''] for chunk in chunks]

    with torch.no_grad():
        encoded = tokenizer(
            query_passage_pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512,
        )
        encoded = {name: tensor.to(state.device) for name, tensor in encoded.items()}
        logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])

    # Raw logits are used as scores. A sigmoid would not change the ranking,
    # since it is monotonic.
    # reshape(-1).tolist() yields plain Python floats for any batch size.
    # Unlike .numpy(), it also works for bf16 tensors, and it keeps
    # numpy.float32 out of the chunk dicts, which would break JSON
    # serialisation.
    scores = logits.reshape(-1).cpu().tolist()
    if len(scores) != len(chunks):
        raise ValueError(
            f"Re-ranker returned {len(scores)} scores for {len(chunks)} chunks; "
            "expected exactly one logit per query/passage pair."
        )

    for chunk, score in zip(chunks, scores):
        chunk['rerank_score'] = score

    sorted_chunks = sorted(chunks, key=lambda c: c['rerank_score'], reverse=True)

    logger.info(f"--- Top {_LOG_TOP_N} Custom Re-ranked Chunks (v4) ---")
    for rank, chunk in enumerate(sorted_chunks[:_LOG_TOP_N], start=1):
        logger.info(f" {rank}. Chunk ID: {chunk.get('id', 'N/A')} | New Score: {chunk['rerank_score']:.4f}")
    logger.info("--------------------------------")
    return sorted_chunks