File size: 1,277 Bytes
78efc3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from functools import lru_cache
from typing import Any, Dict, List, Optional

from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from .utils import getconfig

# Reranker settings are read once, when this module is imported, through the
# project's ``getconfig`` helper. NOTE(review): assumes getconfig returns a
# ConfigParser-like object exposing ``get(section, option)`` — confirm.
_RANKER_SECTION = "reranker"

config = getconfig("params.cfg")
RANKER_MODEL = config.get(_RANKER_SECTION, "MODEL")          # cross-encoder model name
RANKER_TOP_K = int(config.get(_RANKER_SECTION, "TOP_K"))     # default result count

def rerank_context(
    query: str,
    contexts: List[Dict[str, Any]],
    top_n: int = None
) -> List[Dict[str, Any]]:
    """
    Re-ranks a list of context dicts (each with 'page_content' & 'metadata')
    using a cross-encoder and returns the top_n sorted results.
    """
    # wrap into LangChain Documents
    docs = [
        Document(page_content=c["page_content"], metadata=c.get("metadata", {}))
        for c in contexts
    ]

    # instantiate reranker
    n = top_n or RANKER_TOP_K
    model    = HuggingFaceCrossEncoder(model_name=RANKER_MODEL)
    reranker = CrossEncoderReranker(model=model, top_n=n)

    # perform reranking
    reranked: List[Document] = reranker.rerank(query, docs)

    # return as plain dicts
    return [
        {"page_content": d.page_content, "metadata": d.metadata}
        for d in reranked
    ]