File size: 4,822 Bytes
58de15f
 
 
 
 
 
6868c13
a9465d3
 
 
58de15f
 
 
 
 
 
 
757abfc
 
58de15f
 
 
 
757abfc
 
58de15f
757abfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58de15f
 
 
757abfc
58de15f
 
 
 
 
 
 
 
 
 
757abfc
58de15f
 
 
 
 
 
 
 
 
 
 
 
a9465d3
 
58de15f
 
 
 
 
a9465d3
58de15f
 
 
 
8595342
 
 
58de15f
 
8595342
 
58de15f
8595342
58de15f
 
 
 
 
 
 
 
 
bea2de8
 
58de15f
 
 
 
 
 
 
 
 
 
 
c046733
 
 
 
58de15f
a9465d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
import torch
from typing import List, Dict
from transformers import PreTrainedTokenizer
import math
import os
import requests
# --- Import your custom model architecture ---
from app.models.expert_judge_model import ExpertJudgeCrossEncoder, get_tokenizer

from app.core import state
from app.core.config import settings

logger = logging.getLogger(__name__)

def load_reranker_model():
    """
    Load the custom-trained ExpertJudgeCrossEncoder model into app state.

    If the weights file is missing locally and an S3 URL is configured, the
    weights are downloaded first (to a temp file, then atomically renamed, so
    an interrupted download never leaves a truncated file at the model path).

    Returns:
        bool: True if the model and tokenizer are loaded (or were already
        loaded), False on any download or load failure.
    """
    if state.reranker_model_loaded:
        logger.info("Re-ranker model already loaded in state.")
        return True

    model_path = settings.RERANKER_MODEL_PATH
    if not os.path.exists(model_path) and settings.S3_RERANKER_URL:
        logger.info(f"Re-ranker model not found at {model_path}. Downloading from S3...")
        tmp_path = model_path + ".part"
        try:
            # Create the target directory (e.g. 'data/') if it doesn't exist.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            # Stream to a temp file so a partial download never masquerades
            # as valid weights on the next startup; timeout prevents the
            # loader hanging indefinitely on a stalled connection.
            with requests.get(settings.S3_RERANKER_URL, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            # Atomic rename: model_path only appears once fully written.
            os.replace(tmp_path, model_path)
            logger.info("Successfully downloaded re-ranker model from S3.")
        except Exception as e:
            logger.exception(f"FATAL: Failed to download re-ranker model from S3: {e}")
            # Best-effort cleanup of any partial temp file.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            return False

    # Fail fast with a clear message instead of letting torch.load raise
    # FileNotFoundError below.
    if not os.path.exists(model_path):
        logger.error(f"Re-ranker model file not found at {model_path} and no S3 URL configured.")
        return False

    base_model_name = settings.RERANKER_MODEL_NAME
    logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
    logger.info(f"Using base model architecture: {base_model_name}")

    try:
        # 1. Instantiate the model structure
        model = ExpertJudgeCrossEncoder(model_name=base_model_name)
        # 2. Load the saved weights (the state_dict) into the model structure.
        # NOTE(review): torch.load is pickle-based; since this file is a plain
        # state_dict, consider weights_only=True on torch >= 1.13.
        model.load_state_dict(torch.load(model_path, map_location=state.device))
        # 3. Set up the model for inference
        model.to(state.device)
        model.eval()
        # 4. Load the corresponding tokenizer
        tokenizer = get_tokenizer(model_name=base_model_name)

        # 5. Store both in the shared app state
        state.reranker_model = model
        state.reranker_tokenizer = tokenizer
        state.reranker_model_loaded = True
        logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
        return True
    except Exception as e:
        logger.exception(f"Failed to load custom Cross-Encoder model: {e}")
        return False

def rerank_chunks(query: str, chunks: List[Dict], metadata_map: Dict) -> List[Dict]:
    """
    Re-rank chunks using the custom ExpertJudgeCrossEncoder (v4 input format).

    Each chunk dict gains a 'rerank_score' float key, and the list is returned
    sorted by that score in descending order. If the model is not loaded or
    ``chunks`` is empty, the original list is returned unchanged.

    Args:
        query: The user query to score passages against.
        chunks: Chunk dicts; the passage text is read from the 'text' key.
        metadata_map: Unused here; kept for interface compatibility.

    Returns:
        The same chunk dicts, sorted by 'rerank_score' (highest first).
    """
    if not state.reranker_model_loaded or not chunks:
        logger.warning("Re-ranker not loaded or no chunks provided. Returning original order.")
        return chunks

    logger.info(f"Re-ranking {len(chunks)} chunks with custom Expert Judge (v4 format)...")

    model = state.reranker_model
    tokenizer: PreTrainedTokenizer = state.reranker_tokenizer

    # The BGE-Reranker base was trained on simple [query, passage] pairs;
    # the tokenizer inserts [CLS]/[SEP] for sentence pairs itself.
    query_passage_pairs = [[query, chunk.get('text', '')] for chunk in chunks]

    with torch.no_grad():
        encoded = tokenizer(
            query_passage_pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512
        )
        encoded = {k: v.to(state.device) for k, v in encoded.items()}

        logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])
        # reshape(-1) (unlike squeeze()) always yields a 1-D array, even for
        # a single chunk, so no batch-size-1 special case is needed below.
        scores = logits.reshape(-1).cpu().numpy()

    for chunk, score in zip(chunks, scores):
        # float() converts the numpy scalar so chunks stay JSON-serializable.
        chunk['rerank_score'] = float(score)

    # Sort the chunks by the new score in descending order
    sorted_chunks = sorted(chunks, key=lambda x: x.get('rerank_score', 0.0), reverse=True)

    logger.info("--- Top 5 Custom Re-ranked Chunks (v4) ---")
    for i, chunk in enumerate(sorted_chunks[:5]):
        logger.info(f"   {i+1}. Chunk ID: {chunk.get('id', 'N/A')} | New Score: {chunk.get('rerank_score', 0.0):.4f}")
    logger.info("--------------------------------")

    return sorted_chunks