Spaces:
Sleeping
Sleeping
File size: 4,822 Bytes
58de15f 6868c13 a9465d3 58de15f 757abfc 58de15f 757abfc 58de15f 757abfc 58de15f 757abfc 58de15f 757abfc 58de15f a9465d3 58de15f a9465d3 58de15f 8595342 58de15f 8595342 58de15f 8595342 58de15f bea2de8 58de15f c046733 58de15f a9465d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import logging
import torch
from typing import List, Dict
from transformers import PreTrainedTokenizer
import math
import os
import requests
# --- Import your custom model architecture ---
from app.models.expert_judge_model import ExpertJudgeCrossEncoder, get_tokenizer
from app.core import state
from app.core.config import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _download_reranker_model(url: str, model_path: str, timeout: float = 60.0) -> None:
    """
    Stream the re-ranker weights from ``url`` into ``model_path``.

    The payload is written to a sibling ``<model_path>.part`` file and only
    moved onto ``model_path`` after the transfer completes. A partial or
    failed download therefore never leaves a truncated file at
    ``model_path``, which ``load_reranker_model`` would otherwise treat as
    present and never re-download.

    Args:
        url: HTTP(S) URL of the weights file.
        model_path: Final destination path for the weights.
        timeout: Seconds passed to ``requests`` as the connect/read timeout.

    Raises:
        Any exception from ``requests`` (connection/HTTP errors) or from
        filesystem operations; the ``.part`` file is removed first.
    """
    target_dir = os.path.dirname(model_path)
    # os.makedirs("") raises, so only create a directory when the path has one.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    tmp_path = model_path + ".part"
    try:
        # Without a timeout a stalled connection would block startup forever.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the file only once it is complete.
        os.replace(tmp_path, model_path)
    except Exception:
        # Best-effort cleanup of the partial file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def load_reranker_model():
    """
    Load the custom-trained ExpertJudgeCrossEncoder re-ranker into ``state``.

    If no weights file exists at ``settings.RERANKER_MODEL_PATH`` and
    ``settings.S3_RERANKER_URL`` is set, the weights are downloaded first.
    On success the model (moved to ``state.device`` and put in eval mode)
    and its tokenizer are stored on ``state``, and
    ``state.reranker_model_loaded`` is set to True. If the model is already
    loaded, this returns immediately.

    Returns:
        bool: True if the model is loaded (now or previously), False if the
        download or the load failed. Failures are logged, not raised.
    """
    if state.reranker_model_loaded:
        logger.info("Re-ranker model already loaded in state.")
        return True

    model_path = settings.RERANKER_MODEL_PATH
    # Fetch the weights on a fresh server where the file is not yet on disk.
    if not os.path.exists(model_path) and settings.S3_RERANKER_URL:
        logger.info(f"Re-ranker model not found at {model_path}. Downloading from S3...")
        try:
            _download_reranker_model(settings.S3_RERANKER_URL, model_path)
            logger.info("Successfully downloaded re-ranker model from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download re-ranker model from S3: {e}")
            return False

    base_model_name = settings.RERANKER_MODEL_NAME
    logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
    logger.info(f"Using base model architecture: {base_model_name}")
    try:
        # 1. Build the model architecture.
        model = ExpertJudgeCrossEncoder(model_name=base_model_name)
        # 2. Load the saved state_dict into it.
        # NOTE(review): torch.load unpickles the file, which may have been
        # downloaded from a remote URL. Only a state_dict is expected, so
        # consider weights_only=True (available in torch >= 1.13).
        state_dict = torch.load(model_path, map_location=state.device)
        model.load_state_dict(state_dict)
        # 3. Prepare the model for inference.
        model.to(state.device)
        model.eval()
        # 4. Load the matching tokenizer.
        tokenizer = get_tokenizer(model_name=base_model_name)
        # 5. Publish both to shared state; set the flag last so it is only
        #    True once the model and tokenizer are both present.
        state.reranker_model = model
        state.reranker_tokenizer = tokenizer
        state.reranker_model_loaded = True
        logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
        return True
    except Exception as e:
        logger.exception(f"Failed to load custom Cross-Encoder model: {e}")
        return False
def rerank_chunks(query: str, chunks: List[Dict], metadata_map: Dict) -> List[Dict]:
    """
    Re-rank ``chunks`` against ``query`` with the loaded ExpertJudgeCrossEncoder.

    Each chunk's ``'text'`` (empty string if missing) is paired with the query,
    and all pairs are scored in one forward pass. The raw model output (no
    sigmoid is applied) is stored on each chunk dict as ``'rerank_score'``,
    as a plain Python float. The input dicts are mutated in place.

    Args:
        query: The search query.
        chunks: Candidate chunk dicts. ``'text'`` is scored; ``'id'`` is only
            used for logging.
        metadata_map: Not used by this implementation; kept for interface
            compatibility with callers.

    Returns:
        A new list of the chunks sorted by ``'rerank_score'``, highest first.
        If the re-ranker is not loaded or ``chunks`` is empty, ``chunks`` is
        returned unchanged.

    Raises:
        ValueError: If the model does not produce exactly one score per chunk.
    """
    if not state.reranker_model_loaded or not chunks:
        logger.warning("Re-ranker not loaded or no chunks provided. Returning original order.")
        return chunks

    logger.info(f"Re-ranking {len(chunks)} chunks with custom Expert Judge (v4 format)...")
    model = state.reranker_model
    tokenizer: PreTrainedTokenizer = state.reranker_tokenizer

    # Plain [query, passage] pairs; the tokenizer inserts the special and
    # separator tokens for pair inputs itself.
    # NOTE(review): the previous comment stated the base (BGE-Reranker) model
    # was trained on this format — confirm against the fine-tuning code.
    query_passage_pairs = [[query, chunk.get('text', '')] for chunk in chunks]

    with torch.no_grad():
        encoded = tokenizer(
            query_passage_pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512,
        )
        encoded = {k: v.to(state.device) for k, v in encoded.items()}
        logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])

    # Flatten to one score per pair. This assumes the model emits a single
    # logit per pair, e.g. shape (N,) or (N, 1) — TODO confirm.
    # .tolist() yields plain Python floats for any N (JSON-serializable, and it
    # also works for bfloat16 logits, unlike .numpy()).
    scores = logits.reshape(-1).cpu().tolist()
    if len(scores) != len(chunks):
        raise ValueError(
            f"Re-ranker produced {len(scores)} scores for {len(chunks)} chunks; "
            "expected exactly one logit per query/passage pair."
        )

    for chunk, score in zip(chunks, scores):
        chunk['rerank_score'] = score

    sorted_chunks = sorted(chunks, key=lambda c: c['rerank_score'], reverse=True)

    log_top_k = 5  # keep in sync with the header text below
    logger.info("--- Top 5 Custom Re-ranked Chunks (v4) ---")
    for rank, chunk in enumerate(sorted_chunks[:log_top_k], start=1):
        logger.info(f" {rank}. Chunk ID: {chunk.get('id', 'N/A')} | New Score: {chunk['rerank_score']:.4f}")
    logger.info("--------------------------------")
    return sorted_chunks
|