import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
import numpy as np

# ---------- Load model ----------
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# ---------- Load MS MARCO dataset ----------
# Small validation sample; each record holds one query plus its candidate
# passages, with `is_selected` flagging the relevant ones.
dataset = load_dataset("ms_marco", "v1.1", split="validation[:500]")

# Build a flat passage corpus and record, for every query, which corpus
# indices are relevant (the qrels). Building both in a single pass keeps
# the retrieved indices and the relevance judgments aligned.
passages = []
query_id_to_text = {}
query_id_to_relevant = {}
for qid, record in enumerate(dataset):
    query_id_to_text[qid] = record["query"]
    relevant = set()
    for text, is_selected in zip(record["passages"]["passage_text"],
                                 record["passages"]["is_selected"]):
        if is_selected:
            relevant.add(len(passages))
        passages.append(text)
    query_id_to_relevant[qid] = relevant

# Precompute embeddings (the slow step; done once at startup)
passage_embeddings = model.encode(passages, convert_to_tensor=True)

# Map corpus index -> passage text
id_to_passage = dict(enumerate(passages))

# ---------- Evaluation metrics ----------
def precision_at_k(relevant, retrieved, k):
    return len(set(relevant) & set(retrieved[:k])) / k

def recall_at_k(relevant, retrieved, k):
    return len(set(relevant) & set(retrieved[:k])) / len(relevant) if relevant else 0

def f1_at_k(relevant, retrieved, k):
    p = precision_at_k(relevant, retrieved, k)
    r = recall_at_k(relevant, retrieved, k)
    return 2 * p * r / (p + r) if (p + r) > 0 else 0

def mrr(relevant, retrieved):
    for i, r in enumerate(retrieved):
        if r in relevant:
            return 1 / (i + 1)
    return 0

def ndcg_at_k(relevant, retrieved, k):
    dcg = 0
    for i, r in enumerate(retrieved[:k]):
        if r in relevant:
            dcg += 1 / np.log2(i + 2)
    ideal_dcg = sum(1 / np.log2(i + 2) for i in range(min(len(relevant), k)))
    return dcg / ideal_dcg if ideal_dcg > 0 else 0

# ---------- Search ----------
def semantic_search(query, top_k=10):
    query_embedding = model.encode(query, convert_to_tensor=True)
    scores = util.cos_sim(query_embedding, passage_embeddings)[0]
    top_results = scores.topk(k=min(top_k, len(passages)))
    retrieved_indices = [int(idx) for idx in top_results.indices]
    results = [(id_to_passage[idx], float(scores[idx])) for idx in retrieved_indices]
    return results, retrieved_indices

# ---------- Gradio interface ----------
def search_and_evaluate(query):
    results, retrieved_indices = semantic_search(query, top_k=10)

    # Look up the true relevance judgments if the query is in the sample
    relevant_indices = set()
    for i, q in query_id_to_text.items():
        if q.strip().lower() == query.strip().lower():
            relevant_indices = query_id_to_relevant[i]
            break

    metrics = {
        "Precision@10": precision_at_k(relevant_indices, retrieved_indices, 10),
        "Recall@10": recall_at_k(relevant_indices, retrieved_indices, 10),
        "F1@10": f1_at_k(relevant_indices, retrieved_indices, 10),
        "MRR": mrr(relevant_indices, retrieved_indices),
        "nDCG@10": ndcg_at_k(relevant_indices, retrieved_indices, 10),
    }

    output_text = "### Search Results:\n"
    for i, (text, score) in enumerate(results, 1):
        output_text += f"{i}. {text} (score: {score:.4f})\n\n"

    output_text += "\n### Evaluation Metrics:\n"
    if not relevant_indices:
        output_text += "(query not in the MS MARCO sample; no relevance judgments, so all metrics are 0)\n"
    for k, v in metrics.items():
        output_text += f"{k}: {v:.4f}\n"

    return output_text

iface = gr.Interface(
    fn=search_and_evaluate,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=gr.Textbox(label="Results + Metrics"),
    title="SBERT Semantic Search + Evaluation Metrics",
    description="Semantic search over an MS MARCO validation sample using all-mpnet-base-v2, with true relevance judgments for the evaluation metrics.",
)

if __name__ == "__main__":
    iface.launch()
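
# ---------- Optional: cache embeddings across restarts ----------
# A minimal sketch, not part of the original app: encoding a few thousand
# passages with all-mpnet-base-v2 is slow on CPU, so the precomputed tensor
# could be persisted to disk and reloaded on startup. The cache file name
# below is an assumption chosen for illustration.
#
# import os
# import torch
#
# CACHE_PATH = "passage_embeddings.pt"  # hypothetical cache location
# if os.path.exists(CACHE_PATH):
#     passage_embeddings = torch.load(CACHE_PATH)
# else:
#     passage_embeddings = model.encode(passages, convert_to_tensor=True)
#     torch.save(passage_embeddings, CACHE_PATH)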