# app.py — SBERT + FAISS semantic search demo over MS MARCO (Hugging Face Space).
# (Header reconstructed from file-viewer residue: author Xlordo, commit ae43f82, 3.43 kB.)
import gradio as gr
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics import ndcg_score
# ----------------------------
# Load dataset (MS MARCO v1.1)
# ----------------------------
# BUG FIX: in the Hugging Face "ms_marco" v1.1 schema each example has a
# "passages" struct containing a "passage_text" list — there is no top-level
# "passage" key, so the original `item["passage"]` raised KeyError on the
# first iteration. Flatten the per-query passage lists into one flat corpus.
dataset = load_dataset("ms_marco", "v1.1", split="train[:10000]")
passages = [p for item in dataset for p in item["passages"]["passage_text"]]
print(f"Loaded {len(passages)} passages")
# ----------------------------
# Load SBERT model
# ----------------------------
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# ----------------------------
# Build FAISS index
# ----------------------------
# Encode the whole corpus once; all-mpnet-base-v2 yields 768-dim vectors.
embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
dimension = embeddings.shape[1]
# Exact (brute-force) L2 index — fine at this corpus size, no training needed.
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
print("FAISS index built with", index.ntotal, "passages")
# ----------------------------
# Search function
# ----------------------------
def search(query, k=10):
    """Return the top-k corpus passages for *query* as (passage, L2-distance) pairs."""
    q_emb = model.encode([query], convert_to_numpy=True)
    dists, idxs = index.search(q_emb, k)
    # Smaller distance = closer match; FAISS returns results already ranked.
    return [(passages[pos], float(d)) for pos, d in zip(idxs[0], dists[0])]
# ----------------------------
# Evaluation metrics
# ----------------------------
def evaluate(query, relevant_passages, k=10):
    """Compute IR metrics for a query given ground-truth relevant passages.

    Args:
        query: free-text query to run through `search`.
        relevant_passages: list of ground-truth passage texts (matched by
            exact string equality against the corpus).
        k: ranking cutoff for Precision/Recall/MRR/nDCG.

    Returns:
        Dict of rounded metrics: Precision@10, Recall@10, F1, MRR, nDCG@10.
    """
    results = search(query, k)
    retrieved = [passage for passage, _ in results]
    # Set gives O(1) membership tests instead of O(m) list scans below.
    relevant = set(relevant_passages)
    # Binary relevance of the ranked top-k list, in rank order.
    y_true = [1 if p in relevant else 0 for p in retrieved]
    # Full-corpus relevance/score rows in the 2-D shape sklearn's ndcg_score expects.
    y_true_full = np.array([[1 if p in relevant else 0 for p in passages]])
    y_scores_full = np.zeros((1, len(passages)))
    for passage, dist in results:
        # `.index` finds the first occurrence; duplicate passages share the
        # same text, hence the same relevance label, so this is safe.
        pos = passages.index(passage)
        y_scores_full[0, pos] = 1.0 - dist  # higher score = more relevant
    # Metrics
    hits = sum(y_true)
    precision = hits / k
    recall = hits / len(relevant_passages) if relevant_passages else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Reciprocal rank of the first relevant hit within the top k (0 if none).
    mrr = 1.0 / (y_true.index(1) + 1) if 1 in y_true else 0
    ndcg = ndcg_score(y_true_full, y_scores_full, k=k)
    return {
        "Precision@10": round(precision, 3),
        "Recall@10": round(recall, 3),
        "F1": round(f1, 3),
        "MRR": round(mrr, 3),
        "nDCG@10": round(ndcg, 3)
    }
# ----------------------------
# Gradio interface
# ----------------------------
def gradio_interface(query, relevant_texts):
    """Run a top-10 search; when ground-truth lines are supplied, score them too."""
    top_results = search(query, k=10)
    # One ground-truth passage per non-blank line of the textbox.
    ground_truth = [line.strip() for line in relevant_texts.split("\n") if line.strip()]
    metrics = evaluate(query, ground_truth, k=10) if ground_truth else {}
    return top_results, metrics
# Assemble the UI from named components, then wire them into the Interface.
query_box = gr.Textbox(label="Enter your query")
truth_box = gr.Textbox(
    label="Enter relevant passages (ground truth, one per line)",
    placeholder="Optional"
)
results_table = gr.Dataframe(headers=["Passage", "Distance"], label="Top-10 Results")
metrics_panel = gr.Label(label="Evaluation Metrics")

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[query_box, truth_box],
    outputs=[results_table, metrics_panel],
    title="SBERT + FAISS Semantic Search",
    description="Enter a query to search MS MARCO passages. Optionally provide ground truth passages to compute IR metrics."
)

if __name__ == "__main__":
    demo.launch()