Spaces:
Sleeping
Sleeping
Add model selection (choose which embedding model to search with)
Browse files
app.py
CHANGED
|
@@ -28,14 +28,26 @@ TABLE_NAME = "Feedback_search"
|
|
| 28 |
api = Api(AIRTABLE_API_KEY)
|
| 29 |
table = api.table(BASE_ID, TABLE_NAME)
|
| 30 |
|
| 31 |
-
# Load model
|
| 32 |
-
model = SentenceTransformer('e5_finetuned')
|
| 33 |
-
collection_name = "product_E5_finetune"
|
| 34 |
-
|
| 35 |
# Load whitelist
|
| 36 |
with open("keyword_whitelist.pkl", "rb") as f:
|
| 37 |
keyword_whitelist = pickle.load(f)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
# Utils
|
| 40 |
def is_non_thai(text):
    """Return True when *text* is made up solely of plain Latin-script characters.

    The accepted characters are ASCII letters, digits, ``&``, ``-`` and
    whitespace. Anything else makes the check fail -- that includes Thai
    characters, but also other punctuation such as ``.`` or ``,``. An empty
    string returns False because at least one character is required.

    Args:
        text: The string to inspect.

    Returns:
        bool: True if every character of *text* is in the accepted set.
    """
    # fullmatch anchors both ends itself, so no explicit ^...$ is needed.
    return re.fullmatch(r'[A-Za-z0-9&\-\s]+', text) is not None
|
|
@@ -83,12 +95,19 @@ def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3
|
|
| 83 |
# Global state
|
| 84 |
latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
def search_product(query):
|
| 88 |
start_time = time.time()
|
| 89 |
latest_query_result["raw_query"] = query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
|
| 91 |
-
query_embed = model.encode(
|
| 92 |
|
| 93 |
try:
|
| 94 |
result = qdrant_client.query_points(
|
|
@@ -107,10 +126,10 @@ def search_product(query):
|
|
| 107 |
html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
|
| 108 |
|
| 109 |
html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
|
| 110 |
-
|
| 111 |
result_summary, found = "", False
|
|
|
|
| 112 |
for res in result:
|
| 113 |
-
if res.score
|
| 114 |
found = True
|
| 115 |
name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
|
| 116 |
score = f"{res.score:.4f}"
|
|
@@ -145,12 +164,12 @@ def search_product(query):
|
|
| 145 |
|
| 146 |
return html_output
|
| 147 |
|
| 148 |
-
# Feedback
|
| 149 |
-
def log_feedback(feedback):
|
| 150 |
try:
|
| 151 |
now = datetime.now().strftime("%Y-%m-%d")
|
| 152 |
table.create({
|
| 153 |
-
"model":
|
| 154 |
"timestamp": now,
|
| 155 |
"raw_query": latest_query_result["raw_query"],
|
| 156 |
"query": latest_query_result["query"],
|
|
@@ -166,7 +185,10 @@ def log_feedback(feedback):
|
|
| 166 |
with gr.Blocks() as demo:
|
| 167 |
gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
|
| 168 |
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
| 170 |
result_output = gr.HTML(label="📋 ผลลัพธ์")
|
| 171 |
|
| 172 |
with gr.Row():
|
|
@@ -175,9 +197,8 @@ with gr.Blocks() as demo:
|
|
| 175 |
|
| 176 |
feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
|
| 177 |
|
| 178 |
-
query_input.submit(search_product, inputs=[query_input], outputs=result_output)
|
| 179 |
-
match_btn.click(lambda: log_feedback("match"), outputs=feedback_status)
|
| 180 |
-
not_match_btn.click(lambda: log_feedback("not_match"), outputs=feedback_status)
|
| 181 |
|
| 182 |
-
# Run
|
| 183 |
demo.launch(share=True)
|
|
|
|
| 28 |
api = Api(AIRTABLE_API_KEY)
|
| 29 |
table = api.table(BASE_ID, TABLE_NAME)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# Load whitelist
|
| 32 |
with open("keyword_whitelist.pkl", "rb") as f:
|
| 33 |
keyword_whitelist = pickle.load(f)
|
| 34 |
|
| 35 |
+
# Preload Models
|
| 36 |
+
models = {
|
| 37 |
+
"E5 Finetuned": {
|
| 38 |
+
"model": SentenceTransformer("e5_finetuned"),
|
| 39 |
+
"collection": "product_E5_finetune",
|
| 40 |
+
"threshold": 0.8,
|
| 41 |
+
"prefix": "query: "
|
| 42 |
+
},
|
| 43 |
+
"BGE M3": {
|
| 44 |
+
"model": SentenceTransformer("BAAI/bge-m3"),
|
| 45 |
+
"collection": "product_bge-m3",
|
| 46 |
+
"threshold": 0.5,
|
| 47 |
+
"prefix": ""
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
# Utils
|
| 52 |
def is_non_thai(text):
    """Return True when *text* is made up solely of plain Latin-script characters.

    The accepted characters are ASCII letters, digits, ``&``, ``-`` and
    whitespace. Anything else makes the check fail -- that includes Thai
    characters, but also other punctuation such as ``.`` or ``,``. An empty
    string returns False because at least one character is required.

    Args:
        text: The string to inspect.

    Returns:
        bool: True if every character of *text* is in the accepted set.
    """
    # fullmatch anchors both ends itself, so no explicit ^...$ is needed.
    return re.fullmatch(r'[A-Za-z0-9&\-\s]+', text) is not None
|
|
|
|
| 95 |
# Global state
|
| 96 |
latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
|
| 97 |
|
| 98 |
+
# Search Function
|
| 99 |
+
def search_product(query, model_choice):
|
| 100 |
start_time = time.time()
|
| 101 |
latest_query_result["raw_query"] = query
|
| 102 |
+
|
| 103 |
+
selected = models[model_choice]
|
| 104 |
+
model = selected["model"]
|
| 105 |
+
collection_name = selected["collection"]
|
| 106 |
+
threshold = selected["threshold"]
|
| 107 |
+
prefix = selected["prefix"]
|
| 108 |
+
|
| 109 |
corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
|
| 110 |
+
query_embed = model.encode(prefix + corrected_query)
|
| 111 |
|
| 112 |
try:
|
| 113 |
result = qdrant_client.query_points(
|
|
|
|
| 126 |
html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
|
| 127 |
|
| 128 |
html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
|
|
|
|
| 129 |
result_summary, found = "", False
|
| 130 |
+
|
| 131 |
for res in result:
|
| 132 |
+
if res.score >= threshold:
|
| 133 |
found = True
|
| 134 |
name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
|
| 135 |
score = f"{res.score:.4f}"
|
|
|
|
| 164 |
|
| 165 |
return html_output
|
| 166 |
|
| 167 |
+
# Feedback Function
|
| 168 |
+
def log_feedback(feedback, model_choice):
|
| 169 |
try:
|
| 170 |
now = datetime.now().strftime("%Y-%m-%d")
|
| 171 |
table.create({
|
| 172 |
+
"model": model_choice,
|
| 173 |
"timestamp": now,
|
| 174 |
"raw_query": latest_query_result["raw_query"],
|
| 175 |
"query": latest_query_result["query"],
|
|
|
|
| 185 |
with gr.Blocks() as demo:
|
| 186 |
gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
|
| 187 |
|
| 188 |
+
with gr.Row():
|
| 189 |
+
model_selector = gr.Dropdown(label="🔍 เลือกโมเดล", choices=list(models.keys()), value="E5 Finetuned")
|
| 190 |
+
query_input = gr.Textbox(label="พิมพ์คำค้นหา")
|
| 191 |
+
|
| 192 |
result_output = gr.HTML(label="📋 ผลลัพธ์")
|
| 193 |
|
| 194 |
with gr.Row():
|
|
|
|
| 197 |
|
| 198 |
feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
|
| 199 |
|
| 200 |
+
query_input.submit(search_product, inputs=[query_input, model_selector], outputs=result_output)
|
| 201 |
+
match_btn.click(fn=lambda model: log_feedback("match", model), inputs=model_selector, outputs=feedback_status)
|
| 202 |
+
not_match_btn.click(fn=lambda model: log_feedback("not_match", model), inputs=model_selector, outputs=feedback_status)
|
| 203 |
|
|
|
|
| 204 |
demo.launch(share=True)
|