# NOTE(review): removed non-Python paste artifacts that preceded the script
# (a "File size" header, stray git object hashes, and a bare 1-96 line-number
# column) — none of it was valid Python and it prevented the file from running.
import gradio as gr
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics import ndcg_score

# ----------------------------
# Load dataset (MS MARCO v1.1)
# ----------------------------
dataset = load_dataset("ms_marco", "v1.1", split="train[:10000]")
# Each MS MARCO v1.1 record stores its candidate passages as a list under
# item["passages"]["passage_text"] — there is no top-level "passage" key, so
# the original `item["passage"]` raised KeyError. Flatten every record's
# passage list into one corpus.
passages = [text for item in dataset for text in item["passages"]["passage_text"]]
print(f"Loaded {len(passages)} passages")

# ----------------------------
# Load SBERT model
# ----------------------------
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# ----------------------------
# Build FAISS index
# ----------------------------
embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
dimension = embeddings.shape[1]
# Exact (brute-force) L2 index — fine at this corpus size; no training needed.
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
print("FAISS index built with", index.ntotal, "passages")

# ----------------------------
# Search function
# ----------------------------
def search(query, k=10):
    """Return the top-k corpus passages for *query*.

    Args:
        query: free-text query string.
        k: number of neighbours to retrieve (default 10).

    Returns:
        List of ``(passage, distance)`` pairs, nearest first, where
        ``distance`` is the L2 distance in embedding space (lower = closer).
    """
    query_vec = model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_vec, k)
    # FAISS pads with index -1 when k exceeds index.ntotal; the original code
    # then read passages[-1] (the last passage) with a garbage distance.
    # Skip those padding slots instead.
    return [
        (passages[i], float(dist))
        for i, dist in zip(indices[0], distances[0])
        if i >= 0
    ]

# ----------------------------
# Evaluation metrics
# ----------------------------
def evaluate(query, relevant_passages, k=10):
    """Compute IR metrics for *query* against a ground-truth passage list.

    Args:
        query: query string to run through :func:`search`.
        relevant_passages: list of passage strings judged relevant
            (matched against the corpus by exact string equality).
        k: retrieval depth for Precision/Recall/nDCG (default 10).

    Returns:
        Dict with Precision@10, Recall@10, F1, MRR and nDCG@10, rounded
        to 3 decimals.
    """
    results = search(query, k)
    retrieved = [passage for passage, _ in results]

    # Hoist membership testing into a set: the original did O(len(relevant))
    # list scans for every passage in the corpus.
    relevant = set(relevant_passages)

    # Binary relevance of the retrieved list, in rank order.
    y_true = [1 if p in relevant else 0 for p in retrieved]
    hits = sum(y_true)

    # Full-corpus relevance/score vectors so ndcg_score normalises against the
    # true ideal DCG (all relevant passages), not just the retrieved top-k.
    # First-occurrence position map replaces the original per-result
    # passages.index() scan (O(k*N) -> O(N)); reversed() makes the first
    # occurrence win, matching list.index semantics for duplicate passages.
    pos_of = {p: i for i, p in reversed(list(enumerate(passages)))}
    y_true_full = np.array([[1 if p in relevant else 0 for p in passages]])
    y_scores_full = np.zeros((1, len(passages)))
    for passage, dist in results:
        y_scores_full[0, pos_of[passage]] = 1.0 - dist  # higher score = more relevant

    # Metrics
    precision = hits / k
    recall = hits / len(relevant_passages) if relevant_passages else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Reciprocal rank of the first relevant hit (0 when nothing relevant).
    mrr = 1.0 / (y_true.index(1) + 1) if 1 in y_true else 0
    ndcg = ndcg_score(y_true_full, y_scores_full, k=k)

    return {
        "Precision@10": round(precision, 3),
        "Recall@10": round(recall, 3),
        "F1": round(f1, 3),
        "MRR": round(mrr, 3),
        "nDCG@10": round(ndcg, 3)
    }

# ----------------------------
# Gradio interface
# ----------------------------
def gradio_interface(query, relevant_texts):
    """Gradio callback: run a top-10 search and optionally score it.

    Args:
        query: the search query typed by the user.
        relevant_texts: newline-separated ground-truth passages; may be
            empty, in which case no metrics are computed.

    Returns:
        Tuple of (result rows, metrics dict — empty when no ground truth).
    """
    results = search(query, k=10)
    ground_truth = [line.strip() for line in relevant_texts.split("\n") if line.strip()]
    metrics = evaluate(query, ground_truth, k=10) if ground_truth else {}
    return results, metrics

# Wire the search + evaluation callback into a two-input, two-output web UI.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Enter your query"),
        # One ground-truth passage per line; leaving this blank skips metrics.
        gr.Textbox(label="Enter relevant passages (ground truth, one per line)", placeholder="Optional")
    ],
    outputs=[
        # Rows are (passage, L2 distance) pairs from gradio_interface.
        gr.Dataframe(headers=["Passage", "Distance"], label="Top-10 Results"),
        gr.Label(label="Evaluation Metrics")
    ],
    title="SBERT + FAISS Semantic Search",
    description="Enter a query to search MS MARCO passages. Optionally provide ground truth passages to compute IR metrics."
)

if __name__ == "__main__":
    demo.launch()