| 
							 | 
						from typing import Any, Callable, Dict, List, Optional | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from air_benchmark import AIRBench, Retriever | 
					
					
						
						| 
							 | 
						from llama_index.core import VectorStoreIndex | 
					
					
						
						| 
							 | 
						from llama_index.core.node_parser import SentenceSplitter | 
					
					
						
						| 
							 | 
						from llama_index.embeddings.openai import OpenAIEmbedding | 
					
					
						
						| 
							 | 
						from llama_index.llms.openai import OpenAI | 
					
					
						
						| 
							 | 
						from llama_index.retrievers.bm25 import BM25Retriever | 
					
					
						
						| 
							 | 
						from llama_index.core.retrievers import QueryFusionRetriever | 
					
					
						
						| 
							 | 
						from llama_index.core.schema import Document, NodeWithScore | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def create_retriever_fn(documents: List[Document], top_k: int) -> Callable[[str], List[NodeWithScore]]: | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    nodes = SentenceSplitter(chunk_size=1024, chunk_overlap=128)(documents) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    vector_index = VectorStoreIndex( | 
					
					
						
						| 
							 | 
						        nodes=nodes,  | 
					
					
						
						| 
							 | 
						        embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002") | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						    vector_retriever = vector_index.as_retriever(similarity_top_k=top_k) | 
					
					
						
						| 
							 | 
						    bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    retriever = QueryFusionRetriever( | 
					
					
						
						| 
							 | 
						        [vector_retriever, bm25_retriever],  | 
					
					
						
						| 
							 | 
						        similarity_top_k=top_k,  | 
					
					
						
						| 
							 | 
						        num_queries=3,  | 
					
					
						
						| 
							 | 
						        mode="dist_based_score", | 
					
					
						
						| 
							 | 
						        llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1) | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def _retriever(query: str) -> List[NodeWithScore]: | 
					
					
						
						| 
							 | 
						        return retriever.retrieve(query) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return _retriever | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class LlamaRetriever(Retriever): | 
					
					
						
						| 
							 | 
						    def __init__( | 
					
					
						
						| 
							 | 
						        self,  | 
					
					
						
						| 
							 | 
						        name: str,  | 
					
					
						
						| 
							 | 
						        create_retriever_fn: Callable[[List[Document], int], Callable[[str], List[NodeWithScore]]],  | 
					
					
						
						| 
							 | 
						        search_top_k: int = 1000, | 
					
					
						
						| 
							 | 
						    ) -> None: | 
					
					
						
						| 
							 | 
						        self.name = name | 
					
					
						
						| 
							 | 
						        self.search_top_k | 
					
					
						
						| 
							 | 
						        self.create_retriever_fn = create_retriever_fn | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __str__(self): | 
					
					
						
						| 
							 | 
						        return self.name | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __call__( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        corpus: Dict[str, Dict[str, Any]], | 
					
					
						
						| 
							 | 
						        queries: Dict[str, str], | 
					
					
						
						| 
							 | 
						        **kwargs, | 
					
					
						
						| 
							 | 
						    ) -> Dict[str, Dict[str, float]]: | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        Retrieve relevant documents for each query | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        documents = [] | 
					
					
						
						| 
							 | 
						        for doc_id, doc in corpus.items(): | 
					
					
						
						| 
							 | 
						            text = doc.pop("text") | 
					
					
						
						| 
							 | 
						            assert text is not None | 
					
					
						
						| 
							 | 
						            documents.append(Document(id_=doc_id, text=text, metadata={**doc})) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        retriever = self.create_retriever_fn(documents) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        query_ids = list(queries.keys()) | 
					
					
						
						| 
							 | 
						        results = {qid: {} for qid in query_ids} | 
					
					
						
						| 
							 | 
						        for qid in query_ids: | 
					
					
						
						| 
							 | 
						            query = queries[qid] | 
					
					
						
						| 
							 | 
						            if isinstance(query, list): | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                query = "; ".join(query) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            nodes = retriever(query) | 
					
					
						
						| 
							 | 
						            for node in nodes: | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                results[qid][node.node.ref_doc_id] = node.score | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        return results | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
# --- Run AIR-Bench over the fused llama-index retriever ---------------------

# LlamaRetriever will call create_retriever_fn(documents, top_k) to build a
# fresh retrieval pipeline per evaluated task.
retriever = LlamaRetriever("vector_bm25_fusion", create_retriever_fn)

# Restrict the benchmark to a small slice: English arxiv long-doc tasks of
# the 24.04 release. NOTE(review): presumably the dataset is downloaded on
# first run — confirm cache location/size before running in CI.
evaluation = AIRBench(
    benchmark_version="AIR-Bench_24.04",
    task_types=["long-doc"],
    domains=["arxiv"],
    languages=["en"],
)

# Writes per-task retrieval results under ./llama_results; overwrite=True
# replaces any previous run's output. This calls the OpenAI API (embeddings
# + query rewrites), so it requires OPENAI_API_KEY and incurs cost.
evaluation.run(
    retriever,
    output_dir="./llama_results",
    overwrite=True
)
					
					
						
						| 
							 | 
						
 |