Southisuk commited on
Commit
ceda26c
·
verified ·
1 Parent(s): c13df45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +531 -68
app.py CHANGED
@@ -1,75 +1,538 @@
1
- import json
2
- import gradio as gr
3
- from langchain_community.vectorstores import FAISS
4
- from langchain_community.embeddings import HuggingFaceEmbeddings
5
- from langchain.chains import RetrievalQA
6
- from langchain_community.llms import HuggingFacePipeline
7
- from transformers import pipeline
8
-
9
- # -----------------------------
10
- # 1) Load dataset + build vector DB
11
- # -----------------------------
12
- with open("nbb_merged_full.json", "r", encoding="utf-8") as f:
13
- data = json.load(f)
14
-
15
- texts = []
16
- if isinstance(data, list):
17
- for item in data:
18
- if isinstance(item, dict) and "text" in item:
19
- texts.append(item["text"])
20
- elif isinstance(item, str):
21
- texts.append(item)
22
- elif isinstance(data, dict):
23
- if "text" in data:
24
- texts.append(data["text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  else:
26
- texts.extend([str(v) for v in data.values()])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- print(f" Loaded {len(texts)} documents from dataset")
29
 
30
- # Embedding model
31
- embeddings = HuggingFaceEmbeddings(
32
- model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 
 
33
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Build FAISS DB
36
- db = FAISS.from_texts(texts, embeddings)
37
-
38
- # -----------------------------
39
- # 2) Load LLM (lightweight)
40
- # -----------------------------
41
- # flan-t5-base = เบา / multilingual MiniLM = รองรับหลายภาษา
42
- model_name = "google/flan-t5-base"
43
- pipe = pipeline("text2text-generation", model=model_name, device=-1, max_new_tokens=256)
44
- llm = HuggingFacePipeline(pipeline=pipe)
45
-
46
- # -----------------------------
47
- # 3) QA Chain (RAG)
48
- # -----------------------------
49
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
50
- qa = RetrievalQA.from_chain_type(
51
- llm=llm,
52
- retriever=retriever,
53
- chain_type="stuff"
54
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- # -----------------------------
57
- # 4) Gradio UI
58
- # -----------------------------
59
- def chatbot(message, history):
60
- if not message.strip():
61
- return "⚠️ ກະລຸນາພິມຄຳຖາມ"
62
- result = qa.run(message)
63
- return result
64
-
65
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
66
- gr.Markdown("<h1 style='text-align: center; color: green;'>🌾 Lao Chatbot (RAG)</h1>")
67
- chatbot_ui = gr.Chatbot(height=500)
68
- msg = gr.Textbox(placeholder="ພິມຄຳຖາມທີ່ນີ້...", label="Input")
69
- clear_btn = gr.Button("🧹 Clear Chat")
70
-
71
- msg.submit(fn=chatbot, inputs=[msg, chatbot_ui], outputs=chatbot_ui)
72
- clear_btn.click(lambda: None, None, chatbot_ui, queue=False)
73
-
74
- if __name__ == "__main__":
75
- demo.launch()
 
1
+ import json, os, re
2
+ import numpy as np
3
+ from scipy.sparse import hstack, csr_matrix
4
+ from sklearn.feature_extraction.text import TfidfVectorizer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from google.colab import files
7
+
8
+ # === Cell 1: Load & Normalize new merged dataset (supports content + csv_row) ===
9
+ import json, os
10
+ from google.colab import files
11
+
12
+ # คุณสามารถชี้ไปยังไฟล์รวมใหม่ได้เลย เช่น:
13
+ PREFERRED_PATHS = [
14
+ "/content/nbb_merged_full.json", # ถ้ารันใน Colab หลังอัปโหลดไฟล์นี้
15
+ "/content/intents_dataset_v1_lo.json", # เผื่อยังใช้ไฟล์เก่า
16
+ "/mnt/data/nbb_merged_full.json", # เผื่อรันบนเครื่อง/สภาพแวดล้อมอื่น
17
+ ]
18
+
19
+ DATASET_PATH = None
20
+ for p in PREFERRED_PATHS:
21
+ if os.path.exists(p):
22
+ DATASET_PATH = p
23
+ break
24
+
25
+ if DATASET_PATH is None:
26
+ print("กรุณาอัปโหลดไฟล์ dataset (.json) ที่รวมของใหม่ (เช่น nbb_merged_full.json)")
27
+ uploaded = files.upload()
28
+ assert uploaded, "ไม่ได้อัปโหลดไฟล์"
29
+ fname = list(uploaded.keys())[0]
30
+ DATASET_PATH = f"/content/{fname}"
31
+ os.rename(fname, DATASET_PATH)
32
+
33
+ with open(DATASET_PATH, "r", encoding="utf-8") as f:
34
+ RAW_DATA = json.load(f)
35
+
36
+ def _get(d, *chain, default=""):
37
+ x = d
38
+ for k in chain:
39
+ if not isinstance(x, dict) or k not in x:
40
+ return default
41
+ x = x[k]
42
+ return x
43
+
44
+ def normalize_record(d: dict) -> dict:
45
+ """
46
+ คืนค่า document สกีมาเดียว:
47
+ {
48
+ id, section, title,
49
+ content: { lo: "...", en_summary: "..." },
50
+ keywords: [...],
51
+ score_boost: float
52
+ }
53
+ """
54
+ base_id = d.get("id", None) or f"auto_{hash(json.dumps(d, ensure_ascii=False)) & 0xffffffff}"
55
+ section = d.get("section", d.get("source_type", ""))
56
+ title = d.get("title") or _get(d, "data", "title") or _get(d, "data", "topic") or base_id
57
+
58
+ # กรณีเดิม: มี content.lo อยู่แล้ว
59
+ lo_text = _get(d, "content", "lo", default="").strip()
60
+ en_sum = _get(d, "content", "en_summary", default="").strip()
61
+
62
+ # กรณี CSV แถวใหม่: ใช้ answer เป็น content.lo, และดัน question เข้าไปใน keywords ช่วยค้น
63
+ if not lo_text and ("data" in d or d.get("source_type") == "csv_row"):
64
+ data = d.get("data", {})
65
+ ans = str(data.get("answer", "") or "").strip()
66
+ que = str(data.get("question", "") or "").strip()
67
+ top = str(data.get("topic", "") or "").strip()
68
+ lo_text = ans
69
+ # เติมสรุปอังกฤษถ้ามีอยู่เดิม
70
+ if not en_sum and isinstance(_get(d, "content"), dict):
71
+ en_sum = _get(d, "content", "en_summary", default="")
72
+ # รวบรวม keywords จาก topic/section/question สั้น ๆ
73
+ kws = []
74
+ if top: kws.append(top)
75
+ if section: kws.append(section)
76
+ if que: kws.append(que[:120])
77
+ keywords = list(dict.fromkeys((d.get("keywords") or []) + kws))
78
  else:
79
+ keywords = d.get("keywords") or []
80
+
81
+ score_boost = float(d.get("score_boost", 1.0))
82
+
83
+ return {
84
+ "id": base_id,
85
+ "section": section,
86
+ "title": title,
87
+ "content": {
88
+ "lo": lo_text,
89
+ "en_summary": en_sum
90
+ },
91
+ "keywords": keywords,
92
+ "score_boost": score_boost,
93
+ "_raw": d # เก็บต้นฉบับไว้ตรวจสอบ/อ้างอิง
94
+ }
95
+
96
+ # รวมทุกเรคคอร์ด แล้วคัดเฉพาะที่มีเนื้อหาให้สร้างดัชนีได้
97
+ DOCS = []
98
+ for d in RAW_DATA:
99
+ try:
100
+ nd = normalize_record(d)
101
+ if (nd.get("content", {}) or {}).get("lo", "").strip():
102
+ DOCS.append(nd)
103
+ except Exception as e:
104
+ # ข้ามเรคคอร์ดที่เสีย
105
+ pass
106
+
107
+ assert DOCS, "ไม่พบเอกสารที่มี content.lo หลัง normalize — กรุณาตรวจไฟล์ dataset"
108
+ print(f"[OK] Loaded & normalized {len(DOCS)} docs from: {DATASET_PATH}")
109
+
110
+ # === Cell 2: Build index text from unified schema (content + csv_row) ===
111
+ import re
112
+
113
+ ZWSP = "\u200b"
114
+
115
+ def normalize_lo(text: str) -> str:
116
+ if not text: return ""
117
+ t = text.replace(ZWSP, " ")
118
+ t = re.sub(r"\s+", " ", t).strip()
119
+ return t
120
+
121
+ def build_index_text(doc: dict) -> str:
122
+ title = normalize_lo(doc.get("title", ""))
123
+ lo = normalize_lo(doc.get("content", {}).get("lo", ""))
124
+ en = normalize_lo(doc.get("content", {}).get("en_summary", ""))
125
+ kws = ", ".join(doc.get("keywords", []) or [])
126
+ sec = normalize_lo(doc.get("section", ""))
127
+
128
+ # เพิ่ม section และ keywords เพื่อช่วยค้น
129
+ # NOTE: ถ้ามีคำถามจาก CSV เราได้ยัดไว้ใน keywords ไปแล้วบางส่วน
130
+ return "\n".join([t for t in [title, lo, en, sec, kws] if t]).strip()
131
+
132
+ CORPUS = [build_index_text(d) for d in DOCS]
133
+ IDS = [d["id"] for d in DOCS]
134
+ SECTIONS = [d.get("section", "") for d in DOCS]
135
+ BOOSTS = [float(d.get("score_boost", 1.0)) for d in DOCS]
136
+ ID2DOC = {d["id"]: d for d in DOCS}
137
 
138
+ print(f"[OK] Built corpus of {len(CORPUS)} items.")
139
 
140
+ word_vec = TfidfVectorizer(
141
+ analyzer="word",
142
+ ngram_range=(1,2), # 1-2 คำ พอ ไม่หนักไป
143
+ min_df=1, max_df=0.95,
144
+ sublinear_tf=True
145
  )
146
+ char_vec = TfidfVectorizer(
147
+ analyzer="char_wb", # สร้าง n-gram ในกรอบคำ (กันสัญลักษณ์รบกวน)
148
+ ngram_range=(3,5),
149
+ min_df=1, max_df=0.98,
150
+ sublinear_tf=True
151
+ )
152
+
153
+ Xw = word_vec.fit_transform(CORPUS)
154
+ Xc = char_vec.fit_transform(CORPUS)
155
+ X = hstack([Xw, Xc]).tocsr()
156
+
157
+ TOP_K = 20
158
+ FINAL_TOP_N = 3
159
+ MIN_CONF = 0.12 # TF-IDF scale จะเล็กกว่า embedding; ตั้ง 0.1-0.2 เป็นเกตเริ่มต้น
160
+
161
+ # Placeholder function for keyword_intent_hint
162
+ def keyword_intent_hint(q: str) -> list:
163
+ """
164
+ Placeholder function for keyword_intent_hint.
165
+ Replace with actual implementation if needed.
166
+ """
167
+ return []
168
+
169
+ SECTION_WEIGHTS = {} # Add a placeholder for SECTION_WEIGHTS if it's not defined elsewhere
170
+
171
+
172
+ def vectorize_query(q: str) -> csr_matrix:
173
+ qn = normalize_lo(q)
174
+ qw = word_vec.transform([qn])
175
+ qc = char_vec.transform([qn])
176
+ return hstack([qw, qc]).tocsr()
177
+
178
+ def search(q: str, k: int = TOP_K):
179
+ qv = vectorize_query(q)
180
+ sims = cosine_similarity(qv, X)[0] # shape = (N,)
181
+ # จัดอันดับ
182
+ idxs = np.argsort(-sims)[:k]
183
+ hits = []
184
+ hints = keyword_intent_hint(q)
185
+ for ix in idxs:
186
+ base = float(sims[ix])
187
+ sec = SECTIONS[ix]
188
+ boost = BOOSTS[ix]
189
+ # section weights
190
+ if sec in SECTION_WEIGHTS:
191
+ boost *= SECTION_WEIGHTS[sec]
192
+ # keyword hints
193
+ if sec in hints:
194
+ boost *= 1.10
195
+ final = base * boost
196
+ hits.append({
197
+ "id": IDS[ix],
198
+ "score": base,
199
+ "final_score": final,
200
+ "section": sec
201
+ })
202
+ # เรียงตาม final_score
203
+ hits.sort(key=lambda h: h["final_score"], reverse=True)
204
+ return hits
205
+
206
+ def answer_template_only(q: str):
207
+ hits = search(q, k=TOP_K)
208
+ if not hits or hits[0]["score"] < MIN_CONF:
209
+ return "ຂໍອະໄພ ບໍ່ພົບຂໍ້ມູນໃນຖານຄວາມຮູ້.", []
210
+ chunks, cits = [], []
211
+ for h in hits[:FINAL_TOP_N]:
212
+ d = ID2DOC[h["id"]]
213
+ title = d.get("title", d["id"])
214
+ lo = d.get("content",{}).get("lo","")
215
+ chunks.append(f"• {title}\n{lo}")
216
+ cits.append(h["id"])
217
+ return "\n\n".join(chunks), cits
218
+
219
+ # ============================================================
220
+ # LLM-Guarded RAG (เสริมพลังจากโหมดที่ 1) — ติดตั้งเพิ่มน้อยสุด
221
+ # ใช้ llama-cpp-python + โมเดล GGUF เบาๆ (Qwen2.5-3B หรือ Llama 3.2 3B, 4-bit)
222
+ # ============================================================
223
+
224
+ import os, json, re, time
225
+ from google.colab import files
226
+
227
+ # -------- 1) ติดตั้ง llama-cpp-python (ตัวเดียวพอ) --------
228
+ try:
229
+ import llama_cpp
230
+ except Exception:
231
+ # ติดตั้งเฉพาะเมื่อยังไม่มี (เวอร์ชันเสถียรกับ Py311/Colab)
232
+ !pip -q install llama-cpp-python==0.2.90
233
+ import llama_cpp
234
+
235
+ from llama_cpp import Llama
236
+
237
+ # -------- 2) เตรียมโมเดล GGUF --------
238
+ # เลือกอย่างใดอย่างหนึ่ง:
239
+ # (A) ให้ระบบพยายามดาวน์โหลดจาก Hugging Face (ต้องมีเน็ต)
240
+ # (B) ถ้าไม่อยากดาวน์โหลด: อัปโหลดไฟล์ .gguf เอง แล้วตั้งชื่อ local-llm.gguf
241
+
242
+ MODEL_PATH = "/content/local-llm.gguf"
243
+
244
+ def ensure_model():
245
+ if os.path.exists(MODEL_PATH):
246
+ return True
247
+ print("ยังไม่มีโมเดล .gguf → เลือกวิธีใดวิธีหนึ่ง:")
248
+ print(" 1) อัปโหลดไฟล์ .gguf ด้วยตนเอง (แนะนำ Q4_K_M ~ 2GB) แล้วตั้งชื่อ local-llm.gguf")
249
+ print(" 2) หรือให้ช่วยดาวน์โหลด (ต้องใช้เน็ต): Qwen2.5-3B-Instruct Q4_K_M")
250
+ choice = input("พิมพ์ 1 (upload) / 2 (download): ").strip()
251
+ if choice == "1":
252
+ uploaded = files.upload()
253
+ assert uploaded, "ไม่ได้อัปโหลดไฟล์"
254
+ fname = list(uploaded.keys())[0]
255
+ os.rename(fname, MODEL_PATH)
256
+ print("อัปโหลดแล้ว:", MODEL_PATH)
257
+ return True
258
+ else:
259
+ try:
260
+ from huggingface_hub import hf_hub_download
261
+ except Exception:
262
+ # ติดตั้งเฉพาะเมื่อจำเป็น
263
+ !pip -q install huggingface_hub==0.25.2
264
+ from huggingface_hub import hf_hub_download
265
+ REPO_ID = "Qwen/Qwen2.5-3B-Instruct-GGUF"
266
+ FNAME = "qwen2.5-3b-instruct-q4_k_m.gguf"
267
+ try:
268
+ p = hf_hub_download(repo_id=REPO_ID, filename=FNAME, local_dir="/content", local_dir_use_symlinks=False)
269
+ os.rename(p, MODEL_PATH)
270
+ print("ดาวน์โหลดสำเร็จ:", MODEL_PATH)
271
+ return True
272
+ except Exception as e:
273
+ print("ดาวน์โหลดไม่สำเร็จ:", e)
274
+ print("โปรดอัปโหลดไฟล์ .gguf เอง แล้วตั้งชื่อ local-llm.gguf")
275
+ return False
276
+
277
+ ok = ensure_model()
278
+ assert ok and os.path.exists(MODEL_PATH), "ยังไม่มีโมเดล .gguf ให้ใช้งาน"
279
+
280
+ # -------- 3) โหลดโมเดลด้วย llama.cpp + อุ่นเครื่อง --------
281
+ from llama_cpp import Llama
282
 
283
+ # ปรับค่าตามเครื่อง:
284
+ # - ถ้า Colab (T4): n_gpu_layers=128, n_batch=512
285
+ # - ถ้า GTX1650 4GB: n_gpu_layers=24~32, n_batch=256 (ถ้า OOM ให้ลดลง หรือตั้ง 0 = CPU)
286
+ LLM = Llama(
287
+ model_path=MODEL_PATH,
288
+ n_ctx=2048, # พอสำหรับบริบท 1 ชิ้น + คำตอบ
289
+ n_threads=8,
290
+ n_gpu_layers=128, # <-- GTX1650 ให้ใช้ 24~32 แทน
291
+ n_batch=512, # <-- GTX1650 ใช้ 256
292
+ logits_all=False,
293
+ verbose=False
 
 
 
 
 
 
 
 
294
  )
295
+ print("✅ LLM loaded:", MODEL_PATH)
296
+
297
+ # อุ่นเครื่องรอบแรก ลดดีเลย์ในการตอบครั้งถัดไป
298
+ try:
299
+ _ = LLM("Warmup", max_tokens=1)
300
+ print("🔥 Warmup done")
301
+ except Exception as e:
302
+ print("⚠️ Warmup skipped:", e)
303
+
304
+ # =========================
305
+ # BEST: Guarded RAG + Auto-Judge + Router + Logging (รองรับ llama.cpp ของคุณ)
306
+ # ต้องมีตัวแปรก่อนหน้า: LLM, search(query,k) -> hits, ID2DOC (dict), และ (ถ้ามี) answer_template_only()
307
+ # =========================
308
+
309
+ import os, re, json, time
310
+ from datetime import datetime
311
+
312
+ # ---------- CONFIG ----------
313
+ #TOP_K = 5
314
+ #CHUNK_LIMIT = 250
315
+ #MAX_TOKENS = 64
316
+
317
+ TOP_K = 10 # ค้นเอกสารเบื้องต้น
318
+ FINAL_TOP_N = 1 # ส่งเข้า LLM แค่ 1 ชิ้น (เร็วและนิ่ง)
319
+ MIN_CONF = 0.14 # เกณฑ์ความเชื่อมั่นของ retrieval (TF-IDF)
320
+ CHUNK_LIMIT = 360 # ตัดความยาว context/ชิ้น
321
+ MAX_TOKENS = 96 # จำกัดความย���วคำตอบ
322
+ TEMP = 0.2
323
+ QUALITY_LOG = "/content/quality_feedback.jsonl"
324
+
325
+ # ---------- SYSTEM RULES ----------
326
+ SYSTEM_RULES = """
327
+ You are a Lao banking assistant for NAYOBY BANK (NBB).
328
+
329
+ HARD RULES (do not break):
330
+ 1) Answer ONLY from the provided Context. Do NOT use outside knowledge or make assumptions.
331
+ 2) If the answer is not clearly in the Context, reply in Lao: "ຂໍອະໄພ ຂ້ອຍບໍ່ພົບຂໍ້ມູນໃນຖານຄວາມຮູ້."
332
+ 3) Cite the evidence ids at the end in square brackets (1–3 ids).
333
+ 4) Default reply in Lao; if the whole user question is Thai/English, reply with that language; keep product terms exactly as in Context.
334
+ 5) Never invent numbers, dates, fees, branches, or contacts beyond the Context.
335
+
336
+ STYLE:
337
+ - Concise (≤ 100 Lao words). Direct answer first, bullets if needed.
338
+ - Keep terminology exactly as in Context.
339
+
340
+ FORMAT:
341
+ - End the last line with citations like: [id_a, id_b]
342
+ """
343
+
344
+ # ---------- PROMPT BUILDER ----------
345
+ def _build_context(hits, n=FINAL_TOP_N, limit=CHUNK_LIMIT):
346
+ parts, used = [], []
347
+ for h in hits[:n]:
348
+ d = ID2DOC[h["id"]]
349
+ title = d.get("title", h["id"])
350
+ lo = (d.get("content", {}).get("lo", "") or "")[:limit]
351
+ parts.append(f"[{h['id']}] {title}\n{lo}")
352
+ used.append(h["id"])
353
+ return "\n\n".join(parts), used
354
+
355
+ def _build_prompt(query, hits):
356
+ ctx, _ = _build_context(hits)
357
+ return (
358
+ f"{SYSTEM_RULES}\n\n"
359
+ f"### Context:\n{ctx}\n\n"
360
+ f"### Question:\n{query}\n\n"
361
+ "### Answer:\n"
362
+ )
363
+
364
+ # ---------- LLM ANSWER (Guarded) ----------
365
+ def llm_guarded_answer_best(query: str):
366
+ hits = search(query, k=TOP_K)
367
+ if not hits or hits[0]["score"] < MIN_CONF:
368
+ return "ຂໍອະໄພ ບໍ່ພົບຂໍ້ມູນໃນຖານຄວາມຮູ້.", [], hits
369
+
370
+ prompt = _build_prompt(query, hits)
371
+ # warmup ลดดีเลย์ครั้งแรก
372
+ try: _ = LLM("Warmup", max_tokens=1)
373
+ except: pass
374
+
375
+ out = LLM(
376
+ prompt,
377
+ max_tokens=MAX_TOKENS,
378
+ temperature=TEMP,
379
+ top_p=0.9,
380
+ repeat_penalty=1.1,
381
+ stop=["</s>", "### Question:", "### Context:"]
382
+ )
383
+ text = out["choices"][0]["text"].strip()
384
+ cites = [h["id"] for h in hits[:FINAL_TOP_N]]
385
+ return text, cites, hits
386
+
387
+ # ---------- TEMPLATE FALLBACK (ถ้าไม่มีให้ใช้เวอร์ชันย่อ) ----------
388
+ def _template_only_from_hits(hits):
389
+ if not hits:
390
+ return "ຂໍອະໄພ ບໍ່ພົບຂໍ້ມູນໃນຖານຄວາມຮູ້.", []
391
+ d = ID2DOC[hits[0]["id"]]
392
+ lo = d.get("content", {}).get("lo", "") or ""
393
+ return lo, [d["id"]]
394
+
395
+ # ---------- HEURISTICS (กันพลาดเร็ว) ----------
396
+ def _tok(s): return re.findall(r"[\w\-\.%]+", s.lower(), flags=re.U)
397
+ def _numbers(s): return re.findall(r"\d+(?:[.,]\d+)?", s)
398
+
399
+ def heuristic_label(query, answer, ctx_text, hits, citations):
400
+ verdict, reasons = None, []
401
+
402
+ max_sim = hits[0]["score"] if hits else 0.0
403
+ avg_top3 = sum([h["score"] for h in hits[:3]])/max(1,len(hits[:3]))
404
+ if not citations:
405
+ return "INCORRECT", "no citations"
406
+ if max_sim < MIN_CONF:
407
+ return "INCORRECT", f"low sim {max_sim:.2f}"
408
+
409
+ # overlap ของคำในคำตอบที่อยู่ใน context
410
+ ans_t = set(_tok(answer)); ctx_t = set(_tok(ctx_text))
411
+ overlap = len(ans_t & ctx_t) / max(1, len(ans_t))
412
+ if overlap < 0.25:
413
+ verdict, reasons = "ALMOST", [f"low overlap {overlap:.2f}"]
414
+
415
+ # ตัวเลขที่โผล่ในคำตอบแต่ไม่มีใน context
416
+ ans_nums = set(_numbers(answer))
417
+ ctx_nums = set(_numbers(ctx_text))
418
+ invented = ans_nums - ctx_nums
419
+ if invented:
420
+ # ถ้าตัวเลขเยอะและไม่อยู่ใน context ให้ลดเป็น INCORRECT
421
+ return "INCORRECT", f"invented numbers: {sorted(invented)}"
422
+
423
+ if verdict is None:
424
+ verdict, reasons = "CORRECT", [f"sim {max_sim:.2f}, overlap {overlap:.2f}"]
425
+ return verdict, "; ".join(reasons)
426
+
427
+ # ---------- LLM-AS-A-JUDGE (ใช้โมเดลของคุณ) ----------
428
+ JUDGE_PROMPT = """
429
+ You are a strict evaluator for a Lao banking RAG system.
430
+ Decide if the Answer is CORRECT, ALMOST, or INCORRECT based ONLY on the Context and the Question.
431
+ - CORRECT: fully supported by Context; no invented facts; answers the question.
432
+ - ALMOST: mostly supported but missing a key detail or minor phrasing errors.
433
+ - INCORRECT: unsupported/contradicted/invented/wrong numbers/off-topic.
434
+ Return pure JSON: {"verdict":"CORRECT|ALMOST|INCORRECT","reason":"<=25 Lao words"}
435
+ """
436
+
437
+ def judge_with_llm_same_model(question, ctx_text, answer):
438
+ prompt = (
439
+ f"{JUDGE_PROMPT}\n\n"
440
+ f"### Context:\n{ctx_text}\n\n"
441
+ f"### Question:\n{question}\n\n"
442
+ f"### Answer:\n{answer}\n\n"
443
+ "### Your JSON:\n"
444
+ )
445
+ res = LLM(prompt, max_tokens=96, temperature=0.0, stop=["</s>", "###"])
446
+ raw = res["choices"][0]["text"].strip()
447
+ # ดึง JSON ออกมาแบบกันพลาด
448
+ m = re.search(r"\{.*\}", raw, re.S)
449
+ try:
450
+ return json.loads(m.group(0) if m else raw)
451
+ except Exception:
452
+ return {"verdict":"INCORRECT","reason":"judge parsing failed"}
453
+
454
+ # ---------- ROUTER + LOGGING ----------
455
+ def _build_ctx_text(hits):
456
+ ctx, used = _build_context(hits, n=FINAL_TOP_N, limit=CHUNK_LIMIT)
457
+ return ctx, used
458
+
459
+ def log_quality(record: dict):
460
+ with open(QUALITY_LOG, "a", encoding="utf-8") as f:
461
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
462
+
463
+ def smart_answer(query: str, use_judge=True, allow_template_fallback=True):
464
+ # 1) ตอบด้วย LLM-Guarded
465
+ ans, cites, hits = llm_guarded_answer_best(query)
466
+
467
+ # 2) แปลง context ที่ใช้จริง
468
+ ctx_text, used_ids = _build_ctx_text(hits)
469
+
470
+ # 3) Heuristics
471
+ h_verdict, h_reason = heuristic_label(query, ans, ctx_text, hits, cites)
472
+
473
+ # 4) LLM Judge (สั้นและเร็ว)
474
+ j_verdict, j_reason = None, None
475
+ if use_judge:
476
+ j = judge_with_llm_same_model(query, ctx_text, ans)
477
+ j_verdict = (j.get("verdict") or "").upper()
478
+ j_reason = j.get("reason","").strip()
479
+
480
+ # 5) รวมคำตัดสิน (เข้มงวด = เอา “แย่กว่า”)
481
+ order = {"INCORRECT":0, "ALMOST":1, "CORRECT":2}
482
+ final_v = h_verdict
483
+ final_r = f"Heur:{h_reason}"
484
+ if j_verdict in order and order[j_verdict] < order[final_v]:
485
+ final_v = j_verdict
486
+ final_r = f"Judge:{j_reason} | Heur:{h_reason}"
487
+
488
+ # 6) ถ้าแย่ → fallback เป็น Template-only (ถ้าต้องการ)
489
+ if allow_template_fallback and final_v in ("INCORRECT","ALMOST"):
490
+ try:
491
+ t_ans, t_cites = answer_template_only(query) # ถ้ามีฟังก์ชันของคุณอยู่แล้ว
492
+ except NameError:
493
+ t_ans, t_cites = _template_only_from_hits(hits)
494
+ ans = t_ans
495
+ cites = t_cites
496
+ final_v = "CORRECT" # แหล่งอ้างอิงตรงจากฐานความรู้ (ไม่แต่ง)
497
+
498
+ # 7) บันทึกล็อกเพื่อปรับปรุงภายหลัง
499
+ rec = {
500
+ "ts": datetime.utcnow().isoformat(),
501
+ "query": query,
502
+ "answer": ans,
503
+ "citations": cites,
504
+ "final_verdict": final_v,
505
+ "final_reason": final_r,
506
+ "heur_verdict": h_verdict, "heur_reason": h_reason,
507
+ "judge_verdict": j_verdict, "judge_reason": j_reason,
508
+ "top_sim": hits[0]["score"] if hits else 0.0,
509
+ "used_ids": used_ids
510
+ }
511
+ os.makedirs(os.path.dirname(QUALITY_LOG), exist_ok=True)
512
+ log_quality(rec)
513
+
514
+ # 8) ส่งผลกลับ
515
+ return ans, cites, final_v, final_r
516
+
517
+ # ---------- ตัวอย่างเรียกใช้งาน ----------
518
+ # ans, cites, verdict, reason = smart_answer("ອັດຕາດອກເບ້ຍ ໄລຍະສັ້ນ ເທົ່າໃດ?")
519
+ # print(ans, cites, verdict, reason)
520
+
521
+ import gradio as gr
522
+
523
+ def gradio_smart(q):
524
+ try:
525
+ ans, cites, verdict, reason = smart_answer(q, use_judge=True, allow_template_fallback=True)
526
+ cite_str = ", ".join(cites) if cites else "-"
527
+ return f"{ans}\n\nອ້າງອີງ: {cite_str}\nຜົນປະເມີນ: {verdict} — {reason}"
528
+ except Exception as e:
529
+ return f"⚠️ Error: {e}"
530
+
531
+ with gr.Blocks(title="NBB RAG — Smart (Guarded + Judge + Router)") as demo:
532
+ gr.Markdown("### ພິມຄຳຖາມ → ລະບົບຈະສະຫຼຸບ Context")
533
+ q = gr.Textbox(label="ຄຳຖາມ", lines=2)
534
+ btn = gr.Button("ຖາມ")
535
+ out = gr.Textbox(label="ຄຳຕອບ", lines=18)
536
+ btn.click(fn=gradio_smart, inputs=q, outputs=out)
537
 
538
+ demo.launch()