Spaces:
Sleeping
Sleeping
| import logging | |
| import torch | |
| from typing import List, Dict | |
| from transformers import PreTrainedTokenizer | |
| import math | |
| import os | |
| import requests | |
| # --- Import your custom model architecture --- | |
| from app.models.expert_judge_model import ExpertJudgeCrossEncoder, get_tokenizer | |
| from app.core import state | |
| from app.core.config import settings | |
| logger = logging.getLogger(__name__) | |
def load_reranker_model():
    """
    Load the custom-trained ExpertJudgeCrossEncoder and its tokenizer into ``state``.

    If the weights file at ``settings.RERANKER_MODEL_PATH`` is missing and
    ``settings.S3_RERANKER_URL`` is set, the weights are downloaded first.
    The download goes to a temporary ``<model_path>.part`` file and is moved
    into place only once it has completed, so an interrupted transfer never
    leaves a truncated file that later calls would mistake for a valid model.

    On success ``state.reranker_model``, ``state.reranker_tokenizer`` and
    ``state.reranker_model_loaded`` are set.

    Returns:
        True if the model is loaded (or was already loaded), False if the
        download or the load failed. Failures are logged, not raised.
    """
    if state.reranker_model_loaded:
        logger.info("Re-ranker model already loaded in state.")
        return True

    model_path = settings.RERANKER_MODEL_PATH
    if not os.path.exists(model_path) and settings.S3_RERANKER_URL:
        logger.info(f"Re-ranker model not found at {model_path}. Downloading from S3...")
        partial_path = f"{model_path}.part"
        try:
            # dirname is '' for a bare filename, and os.makedirs('') raises.
            parent_dir = os.path.dirname(model_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            # (connect, read) timeouts: without them a stalled connection
            # would block this call indefinitely.
            with requests.get(settings.S3_RERANKER_URL, stream=True, timeout=(10, 60)) as r:
                r.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            # Atomic rename: model_path only ever holds a complete download.
            os.replace(partial_path, model_path)
            logger.info("Successfully downloaded re-ranker model from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download re-ranker model from S3: {e}")
            # Best-effort cleanup of the partial file; it may not exist if
            # the failure happened before anything was written.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            return False

    base_model_name = settings.RERANKER_MODEL_NAME
    logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
    logger.info(f"Using base model architecture: {base_model_name}")
    try:
        # 1. Instantiate the model structure
        model = ExpertJudgeCrossEncoder(model_name=base_model_name)
        # 2. Load the saved weights (the state_dict) into the model structure
        # NOTE(review): torch.load unpickles the file, and this file may have
        # just been downloaded. Consider passing weights_only=True once the
        # minimum supported torch version is confirmed to accept it.
        model.load_state_dict(torch.load(model_path, map_location=state.device))
        # 3. Set up the model for inference
        model.to(state.device)
        model.eval()
        # 4. Load the corresponding tokenizer
        tokenizer = get_tokenizer(model_name=base_model_name)
        # 5. Store both in the state; the flag is set last so it is never
        #    True while the model or tokenizer is missing.
        state.reranker_model = model
        state.reranker_tokenizer = tokenizer
        state.reranker_model_loaded = True
        logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
        return True
    except Exception as e:
        logger.exception(f"Failed to load custom Cross-Encoder model: {e}")
        return False
def rerank_chunks(query: str, chunks: List[Dict], metadata_map: Dict, batch_size: int = 32) -> List[Dict]:
    """
    Re-rank chunks against a query using the custom ExpertJudgeCrossEncoder.

    Each chunk's ``'text'`` value is paired with ``query`` as ``[query, text]``,
    tokenized as a sentence pair and scored by the model held in ``state``.
    The raw logit is used as the score (no sigmoid is applied).

    Args:
        query: The search query.
        chunks: Chunk dicts. Each one is updated in place with a
            ``'rerank_score'`` key holding a plain Python float.
        metadata_map: Not used by this function; kept so existing callers
            keep working.
        batch_size: Number of (query, passage) pairs scored per forward
            pass, which bounds peak memory for long chunk lists. Values
            below 1 are treated as 1.

    Returns:
        A new list with the same chunk dicts sorted by ``'rerank_score'`` in
        descending order. If the re-ranker is not loaded or ``chunks`` is
        empty, ``chunks`` is returned unchanged.

    Raises:
        ValueError: If the model does not produce exactly one score per pair.
    """
    if not state.reranker_model_loaded or not chunks:
        logger.warning("Re-ranker not loaded or no chunks provided. Returning original order.")
        return chunks

    logger.info(f"Re-ranking {len(chunks)} chunks with custom Expert Judge (v4 format)...")
    model = state.reranker_model
    tokenizer: PreTrainedTokenizer = state.reranker_tokenizer

    # `or ''` also covers chunks whose 'text' key is present but None, which
    # the tokenizer would reject.
    query_passage_pairs = [[query, chunk.get('text') or ''] for chunk in chunks]
    step = max(1, int(batch_size))

    scores: List[float] = []
    with torch.no_grad():
        for start in range(0, len(query_passage_pairs), step):
            batch = query_passage_pairs[start:start + step]
            # Passing pairs lets the tokenizer insert its own special tokens
            # between query and passage.
            encoded = tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            )
            encoded = {k: v.to(state.device) for k, v in encoded.items()}
            logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])
            # reshape(-1) always yields a 1-D tensor, so a single-chunk batch
            # needs no special case (squeeze() would give a 0-d tensor).
            # .float() keeps the conversion working for half/bfloat16 logits,
            # and tolist() produces JSON-serializable Python floats.
            batch_scores = logits.reshape(-1).float().cpu().tolist()
            if len(batch_scores) != len(batch):
                raise ValueError(
                    f"Re-ranker returned {len(batch_scores)} scores for {len(batch)} pairs; "
                    "expected exactly one logit per pair."
                )
            scores.extend(batch_scores)

    for chunk, score in zip(chunks, scores):
        chunk['rerank_score'] = score

    # Sort the chunks by the new score in descending order
    sorted_chunks = sorted(chunks, key=lambda x: x.get('rerank_score', 0.0), reverse=True)

    logger.info("--- Top 5 Custom Re-ranked Chunks (v4) ---")
    # Log only as many entries as the header announces.
    for i, chunk in enumerate(sorted_chunks[:5]):
        logger.info(f" {i+1}. Chunk ID: {chunk.get('id', 'N/A')} | New Score: {chunk.get('rerank_score', 0.0):.4f}")
    logger.info("--------------------------------")
    return sorted_chunks