my-gradio-app / rag /query_engine.py
Nguyen Trong Lap
Recreate history without binary blobs
eeb0f9c
import time
import json
from typing import Dict, Any, List
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from langchain_community.callbacks.manager import get_openai_callback
from config.settings import MODEL, CHROMA_PATH, EMBEDDING_MODEL, OPENAI_API_KEY, OPENAI_BASE_URL
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings
import traceback
# =================================================================
# 1. Khởi tạo Mô hình và RAG Chain (Chỉ load 1 lần)
# =================================================================
try:
# Khởi tạo LLM
print("CHROMA_PATH:", CHROMA_PATH)
llm = ChatOpenAI(model=MODEL, temperature=0.1, api_key= OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
# Khởi tạo Embeddings
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
# Khởi tạo Vector Store và Retriever
db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
# Cấu hình top_k = 2 (faster search)
retriever = db.as_retriever(search_kwargs={"k": 2})
# Khởi tạo RetrievalQA Chain
qa_chain = RetrievalQA.from_chain_type(
llm,
chain_type="stuff",
retriever=retriever,
return_source_documents=True
)
except Exception as e:
print(f"⚠️ Lỗi khởi tạo RAG Chain: {e}")
traceback.print_exc()
print(f"⚠️ Lỗi khởi tạo RAG Chain: {e}. Vui lòng chạy ingest.py và kiểm tra API Key.")
qa_chain = None
# =================================================================
# 2. Hàm Truy vấn với Context và Log (Thành viên B)
# =================================================================
def query_with_context(query: str, system_prompt: str) -> Dict[str, Any]:
"""
Truy vấn RAG Chain, log thời gian/token, và trả về kết quả JSON.
Args:
query: Câu hỏi của người dùng.
system_prompt: Ngữ cảnh hệ thống cho LLM.
Returns:
Dict[str, Any]: {answer, source_docs, metadata}
"""
if not qa_chain:
return {
"answer": "Hệ thống RAG chưa được khởi tạo. Vui lòng kiểm tra API Key và chạy ingest.py.",
"source_docs": [],
"metadata": {"time_s": 0.0, "tokens": 0, "status": "ERROR"}
}
start_time = time.time()
# Tích hợp System Prompt vào LLM
qa_chain.combine_documents_chain.llm_chain.prompt.messages[0].prompt.template = system_prompt
# Sử dụng get_openai_callback để theo dõi số lượng token và chi phí
with get_openai_callback() as cb:
response = qa_chain.invoke(query)
total_tokens = getattr(cb, "total_tokens", 0)
# Lấy thông tin token
total_tokens = cb.total_tokens
end_time = time.time()
query_time = end_time - start_time
# Lấy kết quả và nguồn tài liệu
answer = response['result']
source_documents = response['source_documents']
# Chuẩn hóa nguồn tài liệu để có thể serialize thành JSON
formatted_sources: List[Dict[str, Any]] = []
for doc in source_documents:
# Get clean preview without truncating mid-sentence
content = doc.page_content.strip()
preview = content[:300] if len(content) > 300 else content
# Get source info from metadata
metadata = doc.metadata
source_name = metadata.get('source', 'Unknown')
formatted_sources.append({
"content_preview": preview,
"metadata": metadata,
"source_name": source_name
})
# Log thời gian và token (Console Log)
print("--- RAG Query Log ---")
print(f"Query: {query}")
print(f"Thời gian truy vấn: {query_time:.2f} giây")
print(f"Tổng số token sử dụng: {total_tokens}")
print("----------------------")
# Xuất kết quả ra JSON
return {
"answer": answer,
"source_docs": formatted_sources,
"metadata": {
"time_s": round(query_time, 2),
"tokens": total_tokens,
"status": "SUCCESS"
}
}