Spaces:
Sleeping
Sleeping
add rerank
Browse files
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from pyairtable import Api
|
|
| 13 |
import pickle
|
| 14 |
import re
|
| 15 |
import unicodedata
|
|
|
|
| 16 |
|
| 17 |
# Setup Qdrant Client
|
| 18 |
qdrant_client = QdrantClient(
|
|
@@ -43,15 +44,23 @@ models = {
|
|
| 43 |
"BGE M3": {
|
| 44 |
"model": SentenceTransformer("BAAI/bge-m3"),
|
| 45 |
"collection": "product_bge-m3",
|
| 46 |
-
"threshold": 0.
|
| 47 |
"prefix": ""
|
| 48 |
}
|
| 49 |
}
|
| 50 |
|
|
|
|
|
|
|
| 51 |
# Utils
|
| 52 |
def is_non_thai(text):
    """Report whether ``text`` contains only ASCII letters, digits, ``&``, ``-`` and whitespace.

    An empty string yields ``False`` because the pattern requires at least
    one character.
    """
    matched = re.match(r'^[A-Za-z0-9&\-\s]+$', text)
    return matched is not None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def normalize(text: str) -> str:
|
| 56 |
if is_non_thai(text):
|
| 57 |
return text.strip()
|
|
@@ -90,7 +99,7 @@ def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3
|
|
| 90 |
if not matched:
|
| 91 |
corrected.append(tokens[i])
|
| 92 |
i += 1
|
| 93 |
-
return
|
| 94 |
|
| 95 |
# Global state
|
| 96 |
latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
|
|
@@ -110,6 +119,7 @@ def search_product(query, model_choice):
|
|
| 110 |
query_embed = model.encode(prefix + corrected_query)
|
| 111 |
|
| 112 |
try:
|
|
|
|
| 113 |
result = qdrant_client.query_points(
|
| 114 |
collection_name=collection_name,
|
| 115 |
query=query_embed.tolist(),
|
|
@@ -120,11 +130,25 @@ def search_product(query, model_choice):
|
|
| 120 |
except Exception as e:
|
| 121 |
return f"<p>❌ Qdrant error: {str(e)}</p>"
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
elapsed = time.time() - start_time
|
| 124 |
html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
|
| 125 |
if corrected_query != query:
|
| 126 |
html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
|
| 127 |
-
|
| 128 |
html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
|
| 129 |
result_summary, found = "", False
|
| 130 |
|
|
|
|
| 13 |
import pickle
|
| 14 |
import re
|
| 15 |
import unicodedata
|
| 16 |
+
from FlagEmbedding import FlagReranker
|
| 17 |
|
| 18 |
# Setup Qdrant Client
|
| 19 |
qdrant_client = QdrantClient(
|
|
|
|
| 44 |
"BGE M3": {
|
| 45 |
"model": SentenceTransformer("BAAI/bge-m3"),
|
| 46 |
"collection": "product_bge-m3",
|
| 47 |
+
"threshold": 0.45,
|
| 48 |
"prefix": ""
|
| 49 |
}
|
| 50 |
}
|
| 51 |
|
| 52 |
+
# Reranker instance built from the 'BAAI/bge-reranker-v2-m3' checkpoint with fp16
# enabled; search_product calls reranker.compute_score on [query, product name] pairs.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
|
| 53 |
+
|
| 54 |
# Utils
|
| 55 |
def is_non_thai(text):
    """Return True when ``text`` is made up solely of ASCII letters, digits, ``&``, ``-`` or whitespace.

    The pattern needs at least one character, so an empty string returns
    ``False``.
    """
    hit = re.match(r'^[A-Za-z0-9&\-\s]+$', text)
    return hit is not None
|
| 57 |
|
| 58 |
+
def join_corrected_tokens(corrected: list) -> str:
    """Join corrected tokens back into a single query string.

    Single-character tokens are dropped unless they are in
    ``keyword_whitelist``. If the tokens, concatenated without a separator,
    pass ``is_non_thai``, the kept tokens are joined with spaces; otherwise
    (including when ``corrected`` is empty) they are concatenated directly.
    """
    # The separator is decided on the unfiltered tokens, before short ones are dropped.
    separator = " " if corrected and is_non_thai("".join(corrected)) else ""
    kept = [
        token
        for token in corrected
        if len(token) > 1 or token in keyword_whitelist
    ]
    return separator.join(kept)
|
| 63 |
+
|
| 64 |
def normalize(text: str) -> str:
|
| 65 |
if is_non_thai(text):
|
| 66 |
return text.strip()
|
|
|
|
| 99 |
if not matched:
|
| 100 |
corrected.append(tokens[i])
|
| 101 |
i += 1
|
| 102 |
+
return join_corrected_tokens(corrected)
|
| 103 |
|
| 104 |
# Global state
|
| 105 |
latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
|
|
|
|
| 119 |
query_embed = model.encode(prefix + corrected_query)
|
| 120 |
|
| 121 |
try:
|
| 122 |
+
# 🔍 ดึง top-50 ก่อน rerank
|
| 123 |
result = qdrant_client.query_points(
|
| 124 |
collection_name=collection_name,
|
| 125 |
query=query_embed.tolist(),
|
|
|
|
| 130 |
except Exception as e:
|
| 131 |
return f"<p>❌ Qdrant error: {str(e)}</p>"
|
| 132 |
|
| 133 |
+
# ✅ Rerank Top 10 ด้วย Cross-Encoder (เฉพาะ BGE M3 เท่านั้น)
|
| 134 |
+
if model_choice == "BGE M3" and len(result) > 0:
|
| 135 |
+
topk = 10
|
| 136 |
+
docs = [r.payload.get("name", "") for r in result[:topk]]
|
| 137 |
+
pairs = [[corrected_query, d] for d in docs]
|
| 138 |
+
scores = reranker.compute_score(pairs, normalize=True)
|
| 139 |
+
|
| 140 |
+
# ผสมคะแนน: 0.6 จาก embedding, 0.4 จาก reranker
|
| 141 |
+
result[:topk] = sorted(
|
| 142 |
+
zip(result[:topk], scores),
|
| 143 |
+
key=lambda x: 0.6 * x[0].score + 0.4 * x[1],
|
| 144 |
+
reverse=True
|
| 145 |
+
)
|
| 146 |
+
result[:topk] = [r[0] for r in result[:topk]]
|
| 147 |
+
|
| 148 |
elapsed = time.time() - start_time
|
| 149 |
html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
|
| 150 |
if corrected_query != query:
|
| 151 |
html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
|
|
|
|
| 152 |
html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
|
| 153 |
result_summary, found = "", False
|
| 154 |
|