Xlordo committed on
Commit
ae43f82
Β·
verified Β·
1 Parent(s): 829cfaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -102
app.py CHANGED
@@ -1,122 +1,95 @@
1
  import gradio as gr
2
- from sentence_transformers import SentenceTransformer
3
  import faiss
4
  import numpy as np
5
  from datasets import load_dataset
6
- from sklearn.metrics import precision_score, recall_score
7
- import pandas as pd
8
- import os
 
 
 
 
 
 
9
 
 
10
  # Load SBERT model
 
11
  model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
12
 
13
- # Load MS MARCO dataset (10,000 passages for demo)
14
- dataset = load_dataset("ms_marco", "passage", split="train[:10000]")
15
- passages = dataset["passage"]
16
- passage_embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
17
-
18
  # Build FAISS index
19
- dimension = passage_embeddings.shape[1]
 
 
20
  index = faiss.IndexFlatL2(dimension)
21
- index.add(passage_embeddings)
22
-
23
- # CSV log file
24
- log_file = "results_log.csv"
25
- if not os.path.exists(log_file):
26
- pd.DataFrame(columns=["Query", "Relevant Passage", "Precision@10", "Recall@10", "F1@10", "MRR", "nDCG@10"]).to_csv(log_file, index=False)
27
 
 
28
  # Search function
29
- def semantic_search(query, k=10):
 
30
  query_vec = model.encode([query], convert_to_numpy=True)
31
  distances, indices = index.search(query_vec, k)
32
- results = [passages[i] for i in indices[0]]
33
- return "\n\n".join(results)
34
-
35
- # Helper functions for metrics
36
- def mean_reciprocal_rank(y_true):
37
- for rank, rel in enumerate(y_true, start=1):
38
- if rel == 1:
39
- return 1 / rank
40
- return 0
41
-
42
- def ndcg_at_k(y_true, k=10):
43
- y_true = np.array(y_true)[:k]
44
- if y_true.sum() == 0:
45
- return 0.0
46
- dcg = np.sum((2**y_true - 1) / np.log2(np.arange(2, len(y_true) + 2)))
47
- ideal = np.sort(y_true)[::-1]
48
- idcg = np.sum((2**ideal - 1) / np.log2(np.arange(2, len(ideal) + 2)))
49
- return dcg / idcg if idcg > 0 else 0.0
50
-
51
- # Evaluation function
52
- def evaluate(query, relevant_passage, k=10):
53
- results_text = semantic_search(query, k)
54
- results_list = results_text.split("\n\n")
55
-
56
- y_true = [1 if relevant_passage in r else 0 for r in results_list]
57
- y_pred = [1] * len(results_list)
58
-
59
- precision = precision_score(y_true, y_pred, zero_division=0)
60
- recall = recall_score(y_true, y_pred, zero_division=0)
61
- f1 = 2 * (precision * recall) / (precision + recall + 1e-9)
62
- mrr = mean_reciprocal_rank(y_true)
63
- ndcg = ndcg_at_k(y_true, k)
64
-
65
- # Log results to CSV
66
- df = pd.read_csv(log_file)
67
- new_row = {
68
- "Query": query,
69
- "Relevant Passage": relevant_passage,
70
  "Precision@10": round(precision, 3),
71
  "Recall@10": round(recall, 3),
72
- "F1@10": round(f1, 3),
73
  "MRR": round(mrr, 3),
74
- "nDCG@10": round(ndcg, 3),
75
  }
76
- df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
77
- df.to_csv(log_file, index=False)
78
-
79
- return new_row, f"πŸ“Š {len(df)} evaluations logged so far."
80
-
81
- # Function to download CSV log
82
- def download_log():
83
- return log_file
84
-
85
- # Function to check current log count
86
- def check_log_count():
87
- df = pd.read_csv(log_file)
88
- return f"πŸ“Š {len(df)} evaluations logged so far."
89
 
 
90
  # Gradio interface
91
- with gr.Blocks() as demo:
92
- gr.Markdown("## πŸ”Ž Semantic Search with SBERT (MS MARCO Subset)")
93
-
94
- with gr.Row():
95
- query_input = gr.Textbox(label="Enter your search query")
96
- k_input = gr.Slider(1, 20, value=10, step=1, label="Top-K Results")
97
-
98
- results_output = gr.Textbox(label="Search Results", lines=10)
99
- run_btn = gr.Button("Search")
100
-
101
- run_btn.click(fn=semantic_search, inputs=[query_input, k_input], outputs=results_output)
102
-
103
- gr.Markdown("### πŸ“Š Evaluation")
104
- with gr.Row():
105
- eval_query = gr.Textbox(label="Evaluation Query")
106
- relevant_passage = gr.Textbox(label="Known Relevant Passage")
107
- eval_btn = gr.Button("Run Evaluation")
108
- eval_output = gr.JSON(label="Evaluation Metrics")
109
- eval_counter = gr.Label(label="Evaluation Log Count")
110
-
111
- eval_btn.click(fn=evaluate, inputs=[eval_query, relevant_passage, k_input], outputs=[eval_output, eval_counter])
112
-
113
- gr.Markdown("### πŸ“‚ Download Logged Results")
114
- download_btn = gr.Button("Download CSV")
115
- file_output = gr.File()
116
- count_btn = gr.Button("Check Log Count")
117
- count_output = gr.Label(label="Evaluation Log Count")
118
-
119
- download_btn.click(fn=download_log, outputs=file_output)
120
- count_btn.click(fn=check_log_count, outputs=count_output)
121
-
122
- demo.launch()
 
1
  import gradio as gr
 
2
  import faiss
3
  import numpy as np
4
  from datasets import load_dataset
5
+ from sentence_transformers import SentenceTransformer
6
+ from sklearn.metrics import ndcg_score
7
+
8
# ----------------------------
# Load dataset (MS MARCO v1.1)
# ----------------------------
# NOTE: each MS MARCO v1.1 row stores its candidate passages under
# item["passages"]["passage_text"] (a list of strings) — there is no flat
# "passage" key, so item["passage"] would raise KeyError. Flatten every
# row's passage list into one corpus of passage strings.
dataset = load_dataset("ms_marco", "v1.1", split="train[:10000]")
passages = [text for item in dataset for text in item["passages"]["passage_text"]]
print(f"Loaded {len(passages)} passages")
14
 
15
# ----------------------------
# Load SBERT model
# ----------------------------
# General-purpose sentence-embedding model used for both passages and queries.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
19
 
20
# ----------------------------
# Build FAISS index
# ----------------------------
# Embed every passage once, then index the vectors with an exact
# (brute-force) L2 index; smaller distance = closer match.
embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("FAISS index built with", index.ntotal, "passages")
 
 
 
 
28
 
29
# ----------------------------
# Search function
# ----------------------------
def search(query, k=10):
    """Return the top-k passages for *query* as (passage, L2 distance) pairs."""
    embedded_query = model.encode([query], convert_to_numpy=True)
    dists, idxs = index.search(embedded_query, k)
    # Pair each retrieved passage with its distance, best match first.
    return [(passages[i], float(d)) for i, d in zip(idxs[0], dists[0])]
37
+
38
# ----------------------------
# Evaluation metrics
# ----------------------------
def evaluate(query, relevant_passages, k=10):
    """Compute IR metrics for a query given a list of relevant passages (ground truth).

    Runs ``search(query, k)`` and scores the ranked result list with binary
    relevance (a retrieved passage is relevant iff it appears in
    *relevant_passages*). Returns a dict with Precision@10, Recall@10, F1,
    MRR and nDCG@10, each rounded to 3 decimals.
    """
    results = search(query, k)
    retrieved = [passage for passage, _dist in results]

    # Binary relevance judgements for the ranked list, best hit first.
    # (Previously this also built two (1, len(passages)) arrays and called
    # passages.index() per result — O(corpus) work per query, and wrong for
    # duplicate passages; the metrics only need the ranked judgements.)
    y_true = [1 if p in relevant_passages else 0 for p in retrieved]
    hits = sum(y_true)

    # Precision@k counts unfilled ranks as non-relevant (standard definition).
    precision = hits / k
    recall = hits / len(relevant_passages) if relevant_passages else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Reciprocal rank of the first relevant hit, 0 if none retrieved.
    mrr = 1.0 / (y_true.index(1) + 1) if 1 in y_true else 0

    # nDCG@k on the ranked binary relevance: DCG of the actual ordering over
    # the DCG of an ideal ordering (all relevant documents ranked first).
    discounts = 1.0 / np.log2(np.arange(2, len(retrieved) + 2))
    dcg = float(np.dot(y_true, discounts))
    ideal_hits = min(len(relevant_passages), len(retrieved))
    idcg = float(discounts[:ideal_hits].sum())
    ndcg = dcg / idcg if idcg > 0 else 0

    return {
        "Precision@10": round(precision, 3),
        "Recall@10": round(recall, 3),
        "F1": round(f1, 3),
        "MRR": round(mrr, 3),
        "nDCG@10": round(ndcg, 3)
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
# ----------------------------
# Gradio interface
# ----------------------------
def gradio_interface(query, relevant_texts):
    """Run the search; when ground-truth passages are supplied, also score it."""
    results = search(query, k=10)
    metrics = {}
    if relevant_texts.strip():
        # One ground-truth passage per non-empty line of the textbox.
        ground_truth = [line.strip() for line in relevant_texts.split("\n") if line.strip()]
        metrics = evaluate(query, ground_truth, k=10)
    return results, metrics
79
+
80
# Wire the search/evaluation function into a simple two-input, two-output UI.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Enter your query"),
        gr.Textbox(
            label="Enter relevant passages (ground truth, one per line)",
            placeholder="Optional",
        ),
    ],
    outputs=[
        gr.Dataframe(headers=["Passage", "Distance"], label="Top-10 Results"),
        gr.Label(label="Evaluation Metrics"),
    ],
    title="SBERT + FAISS Semantic Search",
    description="Enter a query to search MS MARCO passages. Optionally provide ground truth passages to compute IR metrics.",
)

if __name__ == "__main__":
    demo.launch()