import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
import numpy as np

# ---------- Load model ----------
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# ---------- Load MS MARCO data ----------
# Build the corpus and the relevance labels from the same split so that
# passage indices line up (the original mixed two datasets, so the labeled
# ids never matched the encoded corpus). This assumes the Hugging Face
# "ms_marco" v1.1 schema, where each example has a "passages" dict with
# parallel "passage_text" and "is_selected" lists.
eval_dataset = load_dataset("ms_marco", "v1.1", split="validation[:500]")  # small sample

passages = []              # flat corpus of passage texts
query_id_to_text = {}      # query index -> query text
query_id_to_relevant = {}  # query index -> set of relevant corpus indices
for i, example in enumerate(eval_dataset):
    query_id_to_text[i] = example["query"]
    relevant = set()
    for text, selected in zip(example["passages"]["passage_text"],
                              example["passages"]["is_selected"]):
        if selected:
            relevant.add(len(passages))
        passages.append(text)
    query_id_to_relevant[i] = relevant

# Map index -> passage
id_to_passage = {i: p for i, p in enumerate(passages)}

# Precompute embeddings for the whole corpus (one-off; slow on CPU)
passage_embeddings = model.encode(passages, convert_to_tensor=True, show_progress_bar=True)
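# Optional variant (an assumption, not part of the original app): normalizing
# the embeddings once lets you score with dot product, which equals cosine
# similarity on unit vectors and is slightly cheaper per query:
#   passage_embeddings = model.encode(passages, convert_to_tensor=True,
#                                     normalize_embeddings=True)
#   scores = util.dot_score(query_embedding, passage_embeddings)[0]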
# ---------- Evaluation metrics ----------
def precision_at_k(relevant, retrieved, k):
    return len(set(relevant) & set(retrieved[:k])) / k

def recall_at_k(relevant, retrieved, k):
    return len(set(relevant) & set(retrieved[:k])) / len(relevant) if relevant else 0

def f1_at_k(relevant, retrieved, k):
    p = precision_at_k(relevant, retrieved, k)
    r = recall_at_k(relevant, retrieved, k)
    return 2 * p * r / (p + r) if (p + r) > 0 else 0

def mrr(relevant, retrieved):
    for i, r in enumerate(retrieved):
        if r in relevant:
            return 1 / (i + 1)
    return 0

def ndcg_at_k(relevant, retrieved, k):
    # Binary-relevance nDCG: gain 1 for each relevant hit, log2 discount.
    dcg = 0
    for i, r in enumerate(retrieved[:k]):
        if r in relevant:
            dcg += 1 / np.log2(i + 2)
    ideal_dcg = sum(1 / np.log2(i + 2) for i in range(min(len(relevant), k)))
    return dcg / ideal_dcg if ideal_dcg > 0 else 0
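# Toy sanity check for the metrics (illustrative values, not MS MARCO data):
# with relevant = {1, 3} and retrieved = [3, 0, 1, 2], the first hit is at
# rank 1, two of the four retrieved items are relevant, and both relevant
# items are found.
assert mrr({1, 3}, [3, 0, 1, 2]) == 1.0
assert precision_at_k({1, 3}, [3, 0, 1, 2], 4) == 0.5
assert recall_at_k({1, 3}, [3, 0, 1, 2], 4) == 1.0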
# ---------- Search ----------
def semantic_search(query, top_k=10):
    query_embedding = model.encode(query, convert_to_tensor=True)
    scores = util.cos_sim(query_embedding, passage_embeddings)[0]
    top_results = scores.topk(k=top_k)
    retrieved_indices = [int(idx) for idx in top_results.indices]
    results = [(id_to_passage[idx], float(scores[idx])) for idx in retrieved_indices]
    return results, retrieved_indices
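# Example usage (hypothetical query string; any free text works):
#   results, idxs = semantic_search("what is the capital of france", top_k=3)
#   results is a list of (passage_text, cosine_score) pairs, best first.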
# ---------- Gradio interface ----------
def search_and_evaluate(query):
    results, retrieved_indices = semantic_search(query, top_k=10)

    # Metrics are only meaningful when the query exactly matches one of the
    # loaded MS MARCO validation queries; otherwise no labels are found.
    relevant_indices = set()
    for i, q in query_id_to_text.items():
        if q.strip().lower() == query.strip().lower():
            relevant_indices = query_id_to_relevant[i]
            break

    metrics = {
        "Precision@10": precision_at_k(relevant_indices, retrieved_indices, 10),
        "Recall@10": recall_at_k(relevant_indices, retrieved_indices, 10),
        "F1@10": f1_at_k(relevant_indices, retrieved_indices, 10),
        "MRR": mrr(relevant_indices, retrieved_indices),
        "nDCG@10": ndcg_at_k(relevant_indices, retrieved_indices, 10),
    }

    output_text = "### Search Results:\n"
    for i, (text, score) in enumerate(results, 1):
        output_text += f"{i}. {text} (score: {score:.4f})\n\n"

    output_text += "\n### Evaluation Metrics:\n"
    if not relevant_indices:
        output_text += "(no relevance labels matched this query; metrics default to 0)\n"
    for k, v in metrics.items():
        output_text += f"{k}: {v:.4f}\n"
    return output_text
iface = gr.Interface(
    fn=search_and_evaluate,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=gr.Textbox(label="Results + Metrics"),
    title="SBERT Semantic Search + Evaluation Metrics",
    description=(
        "Semantic search over a small MS MARCO v1.1 sample (500 validation "
        "queries and their candidate passages) using all-mpnet-base-v2, with "
        "standard retrieval metrics computed against the labeled relevant passages."
    ),
)

if __name__ == "__main__":
    iface.launch()