Xlordo committed on
Commit
361de4d
Β·
verified Β·
1 Parent(s): 29a733e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sentence_transformers import SentenceTransformer
3
+ import faiss
4
+ import numpy as np
5
+ from datasets import load_dataset
6
+ from sklearn.metrics import precision_score, recall_score
7
+ import pandas as pd
8
+ import os
9
+
10
# Load the SBERT sentence-embedding model used to encode both the passage
# corpus and incoming queries into the same vector space.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
12
+
13
# Load an MS MARCO subset as the demo corpus (capped at 10,000 passages).
# FIX: the HF "ms_marco" dataset has configs "v1.1"/"v2.1" — there is no
# "passage" config, and examples have no flat "passage" column; the texts
# are nested under passages["passage_text"], so flatten them here.
dataset = load_dataset("ms_marco", "v1.1", split="train[:2000]")
passages = [
    text
    for example_passages in dataset["passages"]
    for text in example_passages["passage_text"]
][:10000]
# Embed the whole corpus once at startup; encode() returns a 2-D numpy array
# of shape (num_passages, embedding_dim).
passage_embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
17
+
18
# Build FAISS index over the passage embeddings.
# IndexFlatL2 performs exact (brute-force) L2 search — fine at this corpus size.
# NOTE(review): SBERT models are usually paired with cosine / inner-product
# similarity (IndexFlatIP on normalized vectors); plain L2 on unnormalized
# embeddings can rank differently — confirm this is intended.
dimension = passage_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(passage_embeddings)
22
+
23
# Persistent CSV log of evaluation runs; seed it with a header-only file on
# first launch so later reads/appends always find a valid CSV.
log_file = "results_log.csv"
_LOG_COLUMNS = ["Query", "Relevant Passage", "Precision@10", "Recall@10", "F1@10", "MRR", "nDCG@10"]
if not os.path.exists(log_file):
    empty_log = pd.DataFrame(columns=_LOG_COLUMNS)
    empty_log.to_csv(log_file, index=False)
27
+
28
# Search function: embed the query and look up its nearest passages.
def semantic_search(query, k=10):
    """Return the k passages closest to *query*, joined by blank lines."""
    embedded = model.encode([query], convert_to_numpy=True)
    _, hit_ids = index.search(embedded, k)
    hits = []
    for passage_id in hit_ids[0]:
        hits.append(passages[passage_id])
    return "\n\n".join(hits)
34
+
35
# Helper functions for metrics
def mean_reciprocal_rank(y_true):
    """Return 1/rank of the first relevant label (rel == 1) in y_true, or 0 if none."""
    return next((1 / rank for rank, rel in enumerate(y_true, start=1) if rel == 1), 0)
41
+
42
def ndcg_at_k(y_true, k=10):
    """Normalized DCG over the first k relevance labels.

    Gain is 2**rel - 1 with a log2(position + 1) discount; returns 0.0 when
    no relevant label appears in the top-k window.
    """
    rels = np.array(y_true)[:k]
    if rels.sum() == 0:
        return 0.0
    discounts = np.log2(np.arange(2, len(rels) + 2))
    dcg = np.sum((2 ** rels - 1) / discounts)
    best_order = np.sort(rels)[::-1]
    idcg = np.sum((2 ** best_order - 1) / discounts)
    return dcg / idcg if idcg > 0 else 0.0
50
+
51
# Evaluation function: score one query against a known relevant passage and log it.
def evaluate(query, relevant_passage, k=10):
    """Run semantic_search for *query*, compute IR metrics against
    *relevant_passage*, append a row to the CSV log, and return the row
    plus a log-count status string."""
    retrieved = semantic_search(query, k).split("\n\n")

    # Binary relevance: a hit is any result that contains the known passage text.
    y_true = [int(relevant_passage in hit) for hit in retrieved]
    y_pred = [1] * len(retrieved)

    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    # Epsilon in the denominator keeps F1 defined when both scores are 0.
    f1 = 2 * (precision * recall) / (precision + recall + 1e-9)
    mrr = mean_reciprocal_rank(y_true)
    ndcg = ndcg_at_k(y_true, k)

    new_row = {
        "Query": query,
        "Relevant Passage": relevant_passage,
        "Precision@10": round(precision, 3),
        "Recall@10": round(recall, 3),
        "F1@10": round(f1, 3),
        "MRR": round(mrr, 3),
        "nDCG@10": round(ndcg, 3),
    }

    # Append the new row to the persistent CSV log.
    log_df = pd.concat([pd.read_csv(log_file), pd.DataFrame([new_row])], ignore_index=True)
    log_df.to_csv(log_file, index=False)

    return new_row, f"πŸ“Š {len(log_df)} evaluations logged so far."
80
+
81
# Function to download CSV log
def download_log():
    """Return the CSV log path so Gradio's File component can serve it."""
    return log_file
84
+
85
# Function to check current log count
def check_log_count():
    """Return a status string with the number of evaluations logged so far."""
    logged = len(pd.read_csv(log_file))
    return f"πŸ“Š {logged} evaluations logged so far."
89
+
90
# Gradio interface: search UI, evaluation UI, and log download/count controls.
with gr.Blocks() as demo:
    gr.Markdown("## πŸ”Ž Semantic Search with SBERT (MS MARCO Subset)")

    # --- Search controls ---
    with gr.Row():
        query_input = gr.Textbox(label="Enter your search query")
        k_input = gr.Slider(1, 20, value=10, step=1, label="Top-K Results")

    results_output = gr.Textbox(label="Search Results", lines=10)
    run_btn = gr.Button("Search")

    run_btn.click(fn=semantic_search, inputs=[query_input, k_input], outputs=results_output)

    # --- Evaluation section (note: reuses k_input from the search controls) ---
    gr.Markdown("### πŸ“Š Evaluation")
    with gr.Row():
        eval_query = gr.Textbox(label="Evaluation Query")
        relevant_passage = gr.Textbox(label="Known Relevant Passage")
    eval_btn = gr.Button("Run Evaluation")
    eval_output = gr.JSON(label="Evaluation Metrics")
    eval_counter = gr.Label(label="Evaluation Log Count")

    # evaluate returns (metrics dict, status string) mapped to the two outputs.
    eval_btn.click(fn=evaluate, inputs=[eval_query, relevant_passage, k_input], outputs=[eval_output, eval_counter])

    # --- Log download / count section ---
    gr.Markdown("### πŸ“‚ Download Logged Results")
    download_btn = gr.Button("Download CSV")
    file_output = gr.File()
    count_btn = gr.Button("Check Log Count")
    count_output = gr.Label(label="Evaluation Log Count")

    download_btn.click(fn=download_log, outputs=file_output)
    count_btn.click(fn=check_log_count, outputs=count_output)

# Launch the app (blocking call; serves the UI).
demo.launch()