Spaces:
Sleeping
Sleeping
| import json | |
| import numpy as np | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from llama_cpp import Llama | |
| import re | |
# -------------------- Config --------------------
# Retrieval settings.
TOP_K = 5          # candidates requested from search()
FINAL_TOP_N = 1    # how many of the top hits are placed into the prompt context
MIN_CONF = 0.14    # best-hit cosine similarity below this => "not found" reply
CHUNK_LIMIT = 300  # max characters of a document kept per context chunk

# Generation settings passed to the LLM call.
MAX_TOKENS = 256
TEMP = 0.2

# NOTE(review): not referenced anywhere in the visible code.
QUALITY_LOG = "quality_feedback.jsonl"

# -------------------- Load Dataset --------------------
# The knowledge base is read once, at import time, from a JSON file that is
# expected to sit next to this script.
DATASET_PATH = "nbb_merged_full.json"
with open(DATASET_PATH, encoding="utf-8") as dataset_file:
    RAW_DATA = json.load(dataset_file)
def normalize_record(d):
    """Map one raw dataset record onto the uniform document shape used by the index.

    The Lao text is taken from ``d["content"]["lo"]`` when ``content`` is a
    dict, otherwise from ``d["data"]["answer"]`` when ``data`` is a dict;
    anything else yields an empty string.

    Args:
        d: One raw record (a dict) from the dataset.

    Returns:
        A dict with keys ``id``, ``title``, ``section`` and
        ``content`` (``{"lo": <str>}``). ``content["lo"]`` is always a ``str``
        so callers can safely call ``.strip()`` on it.
    """
    # Local import: hashlib is not among the file's top-level imports.
    import hashlib

    content = d.get("content")
    data = d.get("data")
    if isinstance(content, dict):
        lo_text = content.get("lo", "")
    elif isinstance(data, dict):
        # Guarded: a non-dict "data" value used to raise AttributeError here.
        lo_text = data.get("answer", "")
    else:
        lo_text = ""

    # JSON null / numbers would otherwise break the later ``.strip()`` filter.
    if lo_text is None:
        lo_text = ""
    elif not isinstance(lo_text, str):
        lo_text = str(lo_text)

    if "id" in d:
        doc_id = d["id"]
    else:
        # The built-in hash() of a str is salted per process, so it produced a
        # different id on every restart. A content digest is stable across
        # runs and is only computed when the record really lacks an id.
        payload = json.dumps(d, sort_keys=True, ensure_ascii=False)
        doc_id = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]

    return {
        "id": doc_id,
        "title": d.get("title", ""),
        "section": d.get("section", ""),
        "content": {"lo": lo_text},
    }
# Normalize each raw record exactly once (the previous comprehension called
# normalize_record twice per record) and keep only those with non-blank Lao text.
DOCS = [
    doc
    for doc in map(normalize_record, RAW_DATA)
    if doc["content"]["lo"].strip()
]
# Explicit raise instead of ``assert``: asserts are stripped under ``python -O``,
# which would let an empty corpus reach the vectorizer.
if not DOCS:
    raise ValueError("Dataset ບໍ່ມີ content.lo")

# Parallel views over DOCS: CORPUS[i] and IDS[i] describe the same document,
# and row i of X is its TF-IDF vector.
CORPUS = [doc["content"]["lo"] for doc in DOCS]
IDS = [doc["id"] for doc in DOCS]
ID2DOC = {doc["id"]: doc for doc in DOCS}

# NOTE(review): the default word-level tokenizer relies on word boundaries;
# Lao text is largely written without spaces, so a character n-gram analyzer
# may retrieve better. Changing it would require re-tuning MIN_CONF - confirm.
vectorizer = TfidfVectorizer(
    ngram_range=(1, 2),
    min_df=1,
    max_df=0.95,
    sublinear_tf=True,
)
X = vectorizer.fit_transform(CORPUS)
| # -------------------- Search -------------------- | |
def search(query, k=TOP_K):
    """Return the ``k`` documents most similar to ``query`` by TF-IDF cosine similarity.

    Args:
        query: The user's question as a string.
        k: Maximum number of hits to return. Non-positive values return ``[]``.

    Returns:
        A list of ``{"id": <document id>, "score": <float>}`` dicts, best match
        first. Scores are plain Python floats so hits can be JSON-serialised.
    """
    # A negative k would otherwise slice ``[:k]`` and return "all but the last
    # |k|" documents instead of nothing.
    if k <= 0:
        return []
    query_vec = vectorizer.transform([query])
    sims = cosine_similarity(query_vec, X)[0]
    # Negate so that argsort's ascending order yields the highest scores first.
    best = np.argsort(-sims)[:k]
    return [{"id": IDS[i], "score": float(sims[i])} for i in best]
# -------------------- Load LLM --------------------
# Fetch the quantized GGUF weights from the Hugging Face Hub; the returned
# value is used below as the local model path.
MODEL_PATH = hf_hub_download(
    filename="qwen2.5-1.5b-instruct-q4_k_m.gguf",
    repo_id="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
)

# One shared llama.cpp model instance, created at import time.
# NOTE(review): n_ctx=1024 has to hold SYSTEM_RULES + context + question plus
# up to MAX_TOKENS of output - confirm this is enough for long Lao prompts.
LLM = Llama(
    model_path=MODEL_PATH,
    verbose=False,
    logits_all=False,
    n_ctx=1024,
    n_batch=128,
    n_threads=4,
    n_gpu_layers=0,
)

# Instruction block placed at the top of every prompt built by build_prompt().
SYSTEM_RULES = """
You are a Lao banking assistant for NAYOBY BANK (NBB).
HARD RULES (do not break):
1) Answer ONLY from the provided Context. Do NOT use outside knowledge or make assumptions.
2) If the answer is not clearly in the Context, reply in Lao: "ຂໍອະໄພ ຂ້ອຍບໍ່ພົບຂໍ້ມູນໃນຖານຄວາມຮູ້."
3) Cite the evidence ids at the end in square brackets (1–3 ids).
4) Default reply in Lao; if the whole user question is Thai/English, reply with that language; keep product terms exactly as in Context.
5) Never invent numbers, dates, fees, branches, or contacts beyond the Context.
STYLE:
- Concise (≤ 100 Lao words). Direct answer first, bullets if needed.
- Keep terminology exactly as in Context.
FORMAT:
- End the last line with citations like: [id_a, id_b]
"""
def truncate(text, limit=CHUNK_LIMIT):
    """Cap ``text`` at ``limit`` characters, appending ``"..."`` when it was cut.

    Text that already fits within ``limit`` is returned unchanged.
    """
    if len(text) > limit:
        return text[:limit] + "..."
    return text
def build_prompt(question, hits):
    """Assemble the completion prompt: rules, retrieved context, then the question.

    Args:
        question: The user's question.
        hits: Search results as produced by ``search`` (dicts with an ``"id"``
            key), best match first. Only the first ``FINAL_TOP_N`` are used.

    Returns:
        The full prompt string, ending with ``"Answer:"`` for the model to
        continue.
    """
    chunks = []
    for hit in hits[:FINAL_TOP_N]:
        doc_id = hit["id"]
        passage = truncate(ID2DOC[doc_id]["content"]["lo"])
        # SYSTEM_RULES requires the model to cite evidence ids, so every chunk
        # must carry its id; without it the model can only invent citations.
        chunks.append(f"[{doc_id}] {passage}")
    ctx = "\n\n".join(chunks)
    return f"{SYSTEM_RULES}\n\nContext:\n{ctx}\n\nQuestion:\n{question}\n\nAnswer:"
| # -------------------- Helper functions -------------------- | |
def limit_words(text, max_words=100):
    """Keep at most ``max_words`` whitespace-separated tokens of ``text``.

    Unlike a ``split()``/``" ".join()`` round trip, the original whitespace
    between the kept tokens (line breaks, bullet layout) is preserved, and
    text that is already within the limit is returned as-is apart from
    stripping leading and trailing whitespace.

    Args:
        text: The text to shorten.
        max_words: Maximum number of tokens to keep; non-positive yields ``""``.

    Returns:
        The (possibly shortened) text without leading/trailing whitespace.
    """
    if max_words <= 0:
        return ""
    tokens = list(re.finditer(r"\S+", text))
    if len(tokens) <= max_words:
        return text.strip()
    # Slice the original string so inner newlines and spacing survive.
    return text[tokens[0].start():tokens[max_words - 1].end()]
def clean_citations(text):
    """Normalise the citation list of an answer to one trailing ``[id, ...]`` group.

    The last square-bracket group in ``text`` is treated as the citation list:
    its ids are de-duplicated (first occurrence wins), capped at three, and
    re-appended at the end of the text. Earlier bracket groups and the text
    around them are left untouched.

    Args:
        text: The model's answer.

    Returns:
        ``text`` unchanged when it has no bracket group; otherwise the answer
        followed by the cleaned citation group (omitted if it held no ids).
    """
    groups = list(re.finditer(r"\[([^\[\]\n]*)\]", text))
    if not groups:
        return text

    last = groups[-1]  # use the final citation group
    ids = [part.strip() for part in last.group(1).split(",")]
    # dict.fromkeys de-duplicates while preserving order; keep at most 3 ids.
    ids = list(dict.fromkeys(i for i in ids if i))[:3]

    # Remove exactly the span of the last group. The previous pattern
    # r"\[.*?\]$" matched from the FIRST "[" to the end of the text, which
    # deleted answer content between two bracket groups, and it failed to
    # remove a citation followed by punctuation, duplicating it.
    body = (text[:last.start()].rstrip() + text[last.end():]).strip()
    if not ids:
        return body
    return f"{body} [{', '.join(ids)}]"
| # -------------------- Answer -------------------- | |
def smart_answer(message):
    """Answer ``message`` from the knowledge base, or return the "not found" reply.

    Steps: retrieve candidates with ``search``; refuse when the best score is
    below ``MIN_CONF``; otherwise prompt the LLM with the top hits and
    post-process its completion (word limit, citation cleanup).

    Args:
        message: The user's question.

    Returns:
        The answer text, or the Lao "not found" message when the question is
        blank, retrieval confidence is too low, or the model returns nothing.
    """
    not_found = "ຂໍອະໄພ ບໍ່ພົບຂໍ້ມູນໃນຖານຄວາມຮູ້."

    # Skip retrieval and generation entirely for blank or missing input.
    if not message or not message.strip():
        return not_found

    hits = search(message, k=TOP_K)
    if not hits or hits[0]["score"] < MIN_CONF:
        return not_found

    prompt = build_prompt(message, hits)
    out = LLM(
        prompt,
        max_tokens=MAX_TOKENS,
        temperature=TEMP,
        stop=["\n\nQuestion:", "Context:", "Answer:", "</s>"]
    )
    answer = out["choices"][0]["text"].strip()
    if not answer:
        # An empty completion would otherwise show up as an empty chat reply.
        return not_found

    # Apply the word limit to the body only: truncating the whole answer first
    # could cut off the trailing citation group before it is cleaned.
    tail = re.search(r"\s*\[[^\[\]\n]*\]\s*$", answer)
    if tail:
        body = answer[:tail.start()]
        citation = answer[tail.start():].strip()
        answer = f"{limit_words(body, 100)} {citation}".strip()
    else:
        answer = limit_words(answer, 100)
    return clean_citations(answer)
| # -------------------- Gradio Chatbot -------------------- | |
def respond(message, history):
    """Gradio submit handler: answer ``message`` and return the extended chat history.

    ``history`` is not mutated; a new list with the ``(message, answer)`` pair
    appended is returned.
    """
    reply = smart_answer(message)
    return history + [(message, reply)]
# UI layout: a heading, the chat transcript, and a single-line question box.
with gr.Blocks() as demo:
    gr.Markdown("## ທົດລອງ RDB Chatbot")
    chatbot_ui = gr.Chatbot()
    msg = gr.Textbox(placeholder="ພິມຄຳຖາມບ່ອນນີ້...")
    # Submitting the textbox sends (text, current history) to respond() and
    # writes the returned history back into the chat component.
    msg.submit(fn=respond, inputs=[msg, chatbot_ui], outputs=chatbot_ui)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()