# Hugging Face Space status: Runtime error
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import numpy as np | |
| from datasets import load_dataset | |
| from sklearn.metrics import precision_score, recall_score | |
| import pandas as pd | |
| import os | |
# Load SBERT model
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Load an MS MARCO subset (up to 10,000 passages for the demo).
# The Hub's "ms_marco" dataset has no "passage" config and no "passage" column:
# its configs are "v1.1" and "v2.1", and each row stores the candidate passages
# for one query under passages["passage_text"]. The previous call would
# therefore fail at start-up. Load a slice of rows and flatten the per-query
# passage lists into a single list of strings.
# NOTE(review): 2,000 rows assumes several passages per row on average
# (enough to reach MAX_PASSAGES); confirm against the dataset card.
MAX_PASSAGES = 10_000
dataset = load_dataset("ms_marco", "v1.1", split="train[:2000]")
passages = [
    text
    for row_passages in dataset["passages"]
    for text in row_passages["passage_text"]
][:MAX_PASSAGES]
passage_embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
# Build FAISS index: an exact, brute-force index ranked by L2 distance.
# Embeddings are added in the same order as `passages`, so a hit at index
# position i maps straight back to passages[i].
dimension = passage_embeddings.shape[-1]
index = faiss.IndexFlatL2(dimension)
index.add(passage_embeddings)
# CSV log file. On first start-up it is seeded with just a header row, so every
# later pd.read_csv() finds the expected columns even before any evaluation.
log_file = "results_log.csv"
if not os.path.exists(log_file):
    pd.DataFrame(
        columns=[
            "Query",
            "Relevant Passage",
            "Precision@10",
            "Recall@10",
            "F1@10",
            "MRR",
            "nDCG@10",
        ]
    ).to_csv(log_file, index=False)
# Search function
def semantic_search(query, k=10):
    """Return the ``k`` passages closest to ``query``, joined by blank lines.

    Args:
        query: Free-text search query.
        k: Number of results to return. It is cast to ``int`` because UI
            widgets such as a Gradio slider may deliver a float, which the
            FAISS ``search`` wrapper does not accept as a result count.

    Returns:
        The retrieved passages, best match first, separated by ``"\\n\\n"``.
        Returns an empty string when ``k`` is less than 1.
    """
    k = int(k)
    if k < 1:
        return ""
    query_vec = model.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_vec, k)
    # FAISS marks unfilled result slots (k larger than the index) with -1.
    # passages[-1] would silently return the last passage, so drop them.
    results = [passages[i] for i in indices[0] if i >= 0]
    return "\n\n".join(results)
# Helper functions for metrics
def mean_reciprocal_rank(y_true):
    """Return ``1 / rank`` of the first relevant entry in ``y_true``, or 0 if none.

    ``y_true`` is a ranked sequence of relevance labels, best result first.
    An entry counts as relevant only when it equals 1, and ranks are 1-based.
    The function scores a single ranked list; it does not average over
    multiple queries.
    """
    first_rank = next(
        (position for position, label in enumerate(y_true, start=1) if label == 1),
        None,
    )
    if first_rank is None:
        return 0
    return 1 / first_rank
def ndcg_at_k(y_true, k=10):
    """Return nDCG over the first ``k`` entries of a ranked relevance list.

    Each label contributes a gain of ``2**rel - 1``, discounted by
    ``log2(rank + 1)`` for 1-based ranks. The ideal ordering is the same
    labels sorted in descending order, so relevant items that never appear in
    ``y_true`` do not lower the score. Returns 0.0 when the labels sum to zero.
    """
    gains = np.array(y_true)[:k]
    if gains.sum() == 0:
        return 0.0
    # 0-based position i is discounted by log2(i + 2), so the top rank gets 1.
    discounts = np.log2(np.arange(len(gains)) + 2)
    dcg = np.sum((2**gains - 1) / discounts)
    idcg = np.sum((2 ** np.sort(gains)[::-1] - 1) / discounts)
    if idcg > 0:
        return dcg / idcg
    return 0.0
| # Evaluation function | |
| def evaluate(query, relevant_passage, k=10): | |
| results_text = semantic_search(query, k) | |
| results_list = results_text.split("\n\n") | |
| y_true = [1 if relevant_passage in r else 0 for r in results_list] | |
| y_pred = [1] * len(results_list) | |
| precision = precision_score(y_true, y_pred, zero_division=0) | |
| recall = recall_score(y_true, y_pred, zero_division=0) | |
| f1 = 2 * (precision * recall) / (precision + recall + 1e-9) | |
| mrr = mean_reciprocal_rank(y_true) | |
| ndcg = ndcg_at_k(y_true, k) | |
| # Log results to CSV | |
| df = pd.read_csv(log_file) | |
| new_row = { | |
| "Query": query, | |
| "Relevant Passage": relevant_passage, | |
| "Precision@10": round(precision, 3), | |
| "Recall@10": round(recall, 3), | |
| "F1@10": round(f1, 3), | |
| "MRR": round(mrr, 3), | |
| "nDCG@10": round(ndcg, 3), | |
| } | |
| df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True) | |
| df.to_csv(log_file, index=False) | |
| return new_row, f"π {len(df)} evaluations logged so far." | |
# Function to download CSV log
def download_log():
    """Return the CSV log's file path; the UI sends it to a ``gr.File`` output."""
    csv_path = log_file
    return csv_path
# Function to check current log count
def check_log_count():
    """Return a status line with the number of evaluations stored in ``log_file``.

    The CSV is read from disk on every call, so the count includes rows
    written since start-up.
    """
    logged_rows = pd.read_csv(log_file).shape[0]
    return f"π {logged_rows} evaluations logged so far."
# Gradio interface: a search panel, an evaluation panel, and log utilities.
# NOTE(review): the "π" before each heading/status message looks like an emoji
# that was mis-decoded somewhere along the way; confirm the intended glyph.
with gr.Blocks() as demo:
    gr.Markdown("## π Semantic Search with SBERT (MS MARCO Subset)")

    # Search panel: query box and Top-K slider side by side.
    with gr.Row():
        query_input = gr.Textbox(label="Enter your search query")
        k_input = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Top-K Results")
    results_output = gr.Textbox(label="Search Results", lines=10)
    run_btn = gr.Button("Search")

    # Evaluation panel; it reuses the Top-K slider from the search panel.
    gr.Markdown("### π Evaluation")
    with gr.Row():
        eval_query = gr.Textbox(label="Evaluation Query")
        relevant_passage = gr.Textbox(label="Known Relevant Passage")
    eval_btn = gr.Button("Run Evaluation")
    eval_output = gr.JSON(label="Evaluation Metrics")
    eval_counter = gr.Label(label="Evaluation Log Count")

    # Log utilities: download the CSV or show how many rows it holds.
    gr.Markdown("### π Download Logged Results")
    download_btn = gr.Button("Download CSV")
    file_output = gr.File()
    count_btn = gr.Button("Check Log Count")
    count_output = gr.Label(label="Evaluation Log Count")

    # Event wiring, registered in the same order as the panels above.
    run_btn.click(fn=semantic_search, inputs=[query_input, k_input], outputs=results_output)
    eval_btn.click(
        fn=evaluate,
        inputs=[eval_query, relevant_passage, k_input],
        outputs=[eval_output, eval_counter],
    )
    download_btn.click(fn=download_log, outputs=file_output)
    count_btn.click(fn=check_log_count, outputs=count_output)

demo.launch()