import gradio as gr
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from datasets import load_dataset
from sklearn.metrics import precision_score, recall_score
import pandas as pd
import os
# Load SBERT model
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
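# all-mpnet-base-v2 maps text to 768-dimensional vectors; a quick sanity
# check (illustrative, not part of the original app):
#   assert model.encode(["hello"], convert_to_numpy=True).shape == (1, 768)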
# Load an MS MARCO subset (the hub dataset's configs are "v1.1" and "v2.1";
# there is no "passage" config). Each example nests its candidate passages,
# so flatten them and keep 10,000 passage texts for the demo.
dataset = load_dataset("ms_marco", "v1.1", split="train[:1000]")
passages = [text for ex in dataset["passages"] for text in ex["passage_text"]][:10000]
passage_embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
# Build FAISS index
dimension = passage_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(passage_embeddings)
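# Note: IndexFlatL2 ranks by exact Euclidean distance. A common alternative
# (an assumption, not what this app ships) is cosine similarity, using
# L2-normalized embeddings with an inner-product index:
#   faiss.normalize_L2(passage_embeddings)
#   index = faiss.IndexFlatIP(dimension)
#   index.add(passage_embeddings)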
# CSV log file
log_file = "results_log.csv"
if not os.path.exists(log_file):
    pd.DataFrame(
        columns=["Query", "Relevant Passage", "Precision@10", "Recall@10", "F1@10", "MRR", "nDCG@10"]
    ).to_csv(log_file, index=False)
# Search function
def semantic_search(query, k=10):
    k = int(k)  # the Gradio slider can pass k as a float; FAISS needs an int
    query_vec = model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_vec, k)
    results = [passages[i] for i in indices[0]]
    return "\n\n".join(results)
# Helper functions for metrics
def mean_reciprocal_rank(y_true):
    # Reciprocal rank of the first relevant result; 0 if none is relevant
    for rank, rel in enumerate(y_true, start=1):
        if rel == 1:
            return 1 / rank
    return 0.0
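# Example (illustrative): mean_reciprocal_rank([0, 0, 1, 0]) returns 1/3,
# since the first relevant result appears at rank 3.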
def ndcg_at_k(y_true, k=10):
    # Binary-relevance nDCG over the top-k retrieved results
    y_true = np.array(y_true)[:k]
    if y_true.sum() == 0:
        return 0.0
    dcg = np.sum((2**y_true - 1) / np.log2(np.arange(2, len(y_true) + 2)))
    ideal = np.sort(y_true)[::-1]
    idcg = np.sum((2**ideal - 1) / np.log2(np.arange(2, len(ideal) + 2)))
    return dcg / idcg if idcg > 0 else 0.0
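# Worked example (illustrative): for y_true = [0, 1, 0, 1],
#   DCG  = 1/log2(3) + 1/log2(5) ≈ 0.631 + 0.431 = 1.062
#   IDCG = 1/log2(2) + 1/log2(3) ≈ 1.000 + 0.631 = 1.631
#   nDCG ≈ 1.062 / 1.631 ≈ 0.651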
# Evaluation function
def evaluate(query, relevant_passage, k=10):
    results_text = semantic_search(query, k)
    results_list = results_text.split("\n\n")
    # Binary relevance: a result is a hit if it contains the known passage
    y_true = [1 if relevant_passage in r else 0 for r in results_list]
    y_pred = [1] * len(results_list)
    precision = precision_score(y_true, y_pred, zero_division=0)
    # Recall is computed over the retrieved set only (the full relevant set is
    # unknown), so it is 1.0 whenever the passage appears in the top-k at all
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-9)
    mrr = mean_reciprocal_rank(y_true)
    ndcg = ndcg_at_k(y_true, k)
    # Log results to CSV
    df = pd.read_csv(log_file)
    new_row = {
        "Query": query,
        "Relevant Passage": relevant_passage,
        "Precision@10": round(precision, 3),
        "Recall@10": round(recall, 3),
        "F1@10": round(f1, 3),
        "MRR": round(mrr, 3),
        "nDCG@10": round(ndcg, 3),
    }
    df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
    df.to_csv(log_file, index=False)
    return new_row, f"📊 {len(df)} evaluations logged so far."
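# Example call (illustrative; the passage text is a placeholder):
#   metrics, status = evaluate("what is machine learning",
#                              "machine learning is a field of ...", k=10)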
# Function to download the CSV log
def download_log():
    return log_file

# Function to check the current log count
def check_log_count():
    df = pd.read_csv(log_file)
    return f"📊 {len(df)} evaluations logged so far."
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Semantic Search with SBERT (MS MARCO Subset)")
    with gr.Row():
        query_input = gr.Textbox(label="Enter your search query")
        k_input = gr.Slider(1, 20, value=10, step=1, label="Top-K Results")
    results_output = gr.Textbox(label="Search Results", lines=10)
    run_btn = gr.Button("Search")
    run_btn.click(fn=semantic_search, inputs=[query_input, k_input], outputs=results_output)

    gr.Markdown("### 📊 Evaluation")
    with gr.Row():
        eval_query = gr.Textbox(label="Evaluation Query")
        relevant_passage = gr.Textbox(label="Known Relevant Passage")
    eval_btn = gr.Button("Run Evaluation")
    eval_output = gr.JSON(label="Evaluation Metrics")
    eval_counter = gr.Label(label="Evaluation Log Count")
    eval_btn.click(fn=evaluate, inputs=[eval_query, relevant_passage, k_input], outputs=[eval_output, eval_counter])

    gr.Markdown("### 📂 Download Logged Results")
    download_btn = gr.Button("Download CSV")
    file_output = gr.File()
    count_btn = gr.Button("Check Log Count")
    count_output = gr.Label(label="Evaluation Log Count")
    download_btn.click(fn=download_log, outputs=file_output)
    count_btn.click(fn=check_log_count, outputs=count_output)

demo.launch()