Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import faiss | |
| import numpy as np | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics import ndcg_score | |
| # ---------------------------- | |
| # Load dataset (MS MARCO v1.1) | |
| # ---------------------------- | |
dataset = load_dataset("ms_marco", "v1.1", split="train[:10000]")
# MS MARCO v1.1 records have no top-level "passage" field: the candidate
# passages of each query are stored as a list of strings under
# item["passages"]["passage_text"], so those lists are flattened into one corpus.
# NOTE: each record contributes several passages, so the corpus ends up larger
# than the number of rows selected by the split above.
# dict.fromkeys drops repeated texts while keeping first-seen order, so every
# distinct passage maps to exactly one position in this list.
passages = list(
    dict.fromkeys(
        text
        for item in dataset
        for text in item["passages"]["passage_text"]
    )
)
print(f"Loaded {len(passages)} passages")
| # ---------------------------- | |
| # Load SBERT model | |
| # ---------------------------- | |
# The checkpoint identifier is named separately so the encoder used for both
# passages and queries is easy to spot and swap.
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
model = SentenceTransformer(MODEL_NAME)
| # ---------------------------- | |
| # Build FAISS index | |
| # ---------------------------- | |
# Embed the whole corpus in a single call; convert_to_numpy=True yields a
# NumPy matrix whose second axis is the embedding width.
embeddings = model.encode(
    passages,
    show_progress_bar=True,
    convert_to_numpy=True,
)

# Flat L2 index sized to the embedding width; vectors are added in corpus
# order, so FAISS labels double as positions in `passages`.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

print("FAISS index built with", index.ntotal, "passages")
| # ---------------------------- | |
| # Search function | |
| # ---------------------------- | |
def search(query, k=10):
    """Return up to ``k`` passages nearest to ``query`` in the FAISS index.

    Args:
        query: Free-text query; it is embedded with the module-level ``model``.
        k: Maximum number of neighbours to return.

    Returns:
        List of ``(passage, distance)`` tuples in the order FAISS reports them,
        with each distance converted to a Python ``float``. The list is shorter
        than ``k`` when the index holds fewer than ``k`` vectors, and empty
        when ``k`` is not positive or the index is empty.
    """
    # Never request more neighbours than the index can supply.
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    query_vec = model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_vec, k)

    # FAISS pads missing neighbours with label -1. Used as a Python index that
    # would silently alias the last passage, so padded slots are dropped.
    return [
        (passages[i], float(dist))
        for i, dist in zip(indices[0], distances[0])
        if i >= 0
    ]
| # ---------------------------- | |
| # Evaluation metrics | |
| # ---------------------------- | |
def evaluate(query, relevant_passages, k=10):
    """Compute IR metrics for a query given a list of relevant passages (ground truth).

    Args:
        query: Free-text query forwarded to ``search``.
        relevant_passages: Passage texts judged relevant. They are matched
            against corpus passages by exact string equality; duplicates are
            ignored. May be empty or ``None``.
        k: Rank cutoff used for retrieval and for every metric.

    Returns:
        Dict with the keys ``Precision@k``, ``Recall@k``, ``F1``, ``MRR`` and
        ``nDCG@k`` (the actual value of ``k`` appears in the key, so the
        default yields ``Precision@10`` etc.), each rounded to 3 decimals.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer")

    # A set gives O(1) membership tests and collapses duplicated ground truth.
    relevant = set(relevant_passages or ())
    retrieved = [passage for passage, _ in search(query, k)]

    # Binary relevance of each retrieved slot, in rank order.
    hits = [1 if passage in relevant else 0 for passage in retrieved]
    n_hits = sum(hits)

    precision = n_hits / k
    # Count distinct relevant texts so a passage retrieved twice cannot push
    # recall above 1.
    recall = len(relevant.intersection(retrieved)) / len(relevant) if relevant else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    mrr = 1.0 / (hits.index(1) + 1) if n_hits else 0

    # nDCG is computed over the whole corpus so the ideal ranking accounts for
    # relevant passages that were not retrieved.
    y_true_full = np.array([[1 if passage in relevant else 0 for passage in passages]])

    # First corpus position of every text, built in one pass instead of a
    # linear passages.index() scan per retrieved passage.
    first_position = {}
    for pos, passage in enumerate(passages):
        first_position.setdefault(passage, pos)

    # Scores are derived from rank, not from ``1 - distance``: that expression
    # turns negative once the distance exceeds 1, which would rank a retrieved
    # passage below the zero-scored non-retrieved ones.
    y_scores_full = np.zeros((1, len(passages)))
    for rank, passage in enumerate(retrieved):
        pos = first_position[passage]
        if y_scores_full[0, pos] == 0:  # keep the best rank for repeated texts
            y_scores_full[0, pos] = len(retrieved) - rank

    ndcg = float(ndcg_score(y_true_full, y_scores_full, k=k))

    return {
        f"Precision@{k}": round(precision, 3),
        f"Recall@{k}": round(recall, 3),
        "F1": round(f1, 3),
        "MRR": round(mrr, 3),
        f"nDCG@{k}": round(ndcg, 3),
    }
| # ---------------------------- | |
| # Gradio interface | |
| # ---------------------------- | |
def gradio_interface(query, relevant_texts):
    """Gradio callback: run the search and, if ground truth is given, score it.

    Args:
        query: Text from the query textbox.
        relevant_texts: Newline-separated ground-truth passages; blank or
            whitespace-only input skips evaluation.

    Returns:
        Tuple ``(results, metrics)`` where ``results`` is the top-10 output of
        ``search`` and ``metrics`` is the dict from ``evaluate``, or an empty
        dict when no ground truth was supplied.
    """
    top_hits = search(query, k=10)

    # Without ground truth there is nothing to score.
    if not relevant_texts.strip():
        return top_hits, {}

    # One ground-truth passage per non-blank line, surrounding whitespace removed.
    stripped_lines = (line.strip() for line in relevant_texts.split("\n"))
    ground_truth = [line for line in stripped_lines if line]
    return top_hits, evaluate(query, ground_truth, k=10)
# Input widgets: the query plus an optional, newline-separated ground-truth list.
query_box = gr.Textbox(label="Enter your query")
truth_box = gr.Textbox(
    label="Enter relevant passages (ground truth, one per line)",
    placeholder="Optional",
)

# Output widgets: a table of (passage, distance) rows and the metric scores.
results_table = gr.Dataframe(headers=["Passage", "Distance"], label="Top-10 Results")
metrics_label = gr.Label(label="Evaluation Metrics")

# Wire the callback to the widgets; input/output order matches the
# gradio_interface signature and return tuple.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[query_box, truth_box],
    outputs=[results_table, metrics_label],
    title="SBERT + FAISS Semantic Search",
    description="Enter a query to search MS MARCO passages. Optionally provide ground truth passages to compute IR metrics.",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()