# chatfed_reranker/app/reranker.py
# Source: commit 78efc3f by mtyrrell ("port of generator")
from functools import lru_cache
from typing import Any, Dict, List, Optional

from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from .utils import getconfig

config = getconfig("params.cfg")
# Reranker settings, read from params.cfg via getconfig (keys looked up
# under "reranker"):
#   MODEL -> cross-encoder model name passed to HuggingFaceCrossEncoder
#   TOP_K -> default result count when rerank_context is called without top_n
RANKER_MODEL, RANKER_TOP_K = (
    config.get("reranker", "MODEL"),
    int(config.get("reranker", "TOP_K")),
)
@lru_cache(maxsize=4)
def _get_cross_encoder(model_name: str):
    """Return a HuggingFaceCrossEncoder for *model_name*, built only once.

    Constructing the cross-encoder loads the underlying model, which is far
    too expensive to repeat on every request. The cache keeps one instance
    per model name, and is bounded so unexpected names cannot grow it
    without limit.
    """
    return HuggingFaceCrossEncoder(model_name=model_name)


def rerank_context(
    query: str,
    contexts: List[Dict[str, Any]],
    top_n: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Re-rank context dicts against *query* using a cross-encoder.

    Args:
        query: The query each context is scored against.
        contexts: Dicts with a required ``"page_content"`` string and an
            optional ``"metadata"`` dict (defaults to ``{}``).
        top_n: Maximum number of results to return. ``None`` (the default)
            falls back to the configured ``RANKER_TOP_K``. An explicit
            ``0`` (or any non-positive value) yields an empty list.

    Returns:
        At most ``top_n`` dicts of the same ``{"page_content", "metadata"}``
        shape, ordered by the cross-encoder from most to least relevant.
        An empty ``contexts`` list returns ``[]`` without loading the model.

    Raises:
        KeyError: If a context dict has no ``"page_content"`` key.
    """
    # Nothing to score: avoid touching (or loading) the model at all.
    if not contexts:
        return []

    # Use an `is None` check so an explicit 0 is not mistaken for "unset".
    n = RANKER_TOP_K if top_n is None else top_n
    if n <= 0:
        # A negative slice bound would silently drop items from the end
        # rather than limit the count, so treat it as "return nothing".
        return []

    # Wrap into LangChain Documents, the input type the compressor expects.
    docs = [
        Document(page_content=c["page_content"], metadata=c.get("metadata", {}))
        for c in contexts
    ]

    model = _get_cross_encoder(RANKER_MODEL)
    reranker = CrossEncoderReranker(model=model, top_n=n)
    # CrossEncoderReranker has no `rerank` method. Its public entry point is
    # BaseDocumentCompressor.compress_documents(documents, query), which
    # scores every (query, doc) pair and keeps the top_n best.
    reranked = reranker.compress_documents(docs, query)

    # Return as plain dicts, mirroring the input shape.
    return [
        {"page_content": d.page_content, "metadata": d.metadata}
        for d in reranked
    ]