Spaces:
Running
Running
Upload 12 files
Browse files- advanced_rag.py +65 -64
- main.py +6 -14
advanced_rag.py
CHANGED
|
@@ -150,11 +150,22 @@ Alternative queries (one per line):"""
|
|
| 150 |
for result in results:
|
| 151 |
doc_id = result["id"]
|
| 152 |
if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
all_results[doc_id] = RetrievedDocument(
|
| 154 |
id=doc_id,
|
| 155 |
-
text=
|
| 156 |
confidence=result["confidence"],
|
| 157 |
-
metadata=
|
| 158 |
)
|
| 159 |
|
| 160 |
# Sort by confidence and return top_k
|
|
@@ -170,12 +181,12 @@ Alternative queries (one per line):"""
|
|
| 170 |
"""
|
| 171 |
Rerank documents using Cross-Encoder (Best Case 2025)
|
| 172 |
Cross-Encoder provides superior relevance scoring compared to bi-encoders
|
| 173 |
-
|
| 174 |
Args:
|
| 175 |
query: Original user query
|
| 176 |
documents: Retrieved documents to rerank
|
| 177 |
top_k: Number of top documents to return
|
| 178 |
-
|
| 179 |
Returns:
|
| 180 |
Reranked documents
|
| 181 |
"""
|
|
@@ -184,29 +195,38 @@ Alternative queries (one per line):"""
|
|
| 184 |
|
| 185 |
# Prepare query-document pairs for Cross-Encoder
|
| 186 |
pairs = [[query, doc.text] for doc in documents]
|
| 187 |
-
|
| 188 |
-
# Get Cross-Encoder scores
|
| 189 |
ce_scores = self.cross_encoder.predict(pairs)
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
reranked = []
|
| 200 |
-
for doc,
|
| 201 |
-
|
| 202 |
reranked.append(RetrievedDocument(
|
| 203 |
id=doc.id,
|
| 204 |
text=doc.text,
|
| 205 |
-
confidence=float(
|
| 206 |
metadata=doc.metadata
|
| 207 |
))
|
| 208 |
-
|
| 209 |
-
# Sort by
|
| 210 |
reranked.sort(key=lambda x: x.confidence, reverse=True)
|
| 211 |
return reranked[:top_k]
|
| 212 |
|
|
@@ -217,47 +237,32 @@ Alternative queries (one per line):"""
|
|
| 217 |
max_tokens: int = 500
|
| 218 |
) -> List[RetrievedDocument]:
|
| 219 |
"""
|
| 220 |
-
Compress context
|
| 221 |
-
|
| 222 |
"""
|
| 223 |
compressed_docs = []
|
| 224 |
|
| 225 |
for doc in documents:
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
# Reconstruct compressed text (up to max_tokens)
|
| 244 |
-
compressed_text = ""
|
| 245 |
-
word_count = 0
|
| 246 |
-
for sent, score in scored_sentences:
|
| 247 |
-
sent_words = len(sent.split())
|
| 248 |
-
if word_count + sent_words <= max_tokens:
|
| 249 |
-
compressed_text += sent + " "
|
| 250 |
-
word_count += sent_words
|
| 251 |
-
else:
|
| 252 |
-
break
|
| 253 |
-
|
| 254 |
-
# If nothing selected, take original first part
|
| 255 |
-
if not compressed_text.strip():
|
| 256 |
-
compressed_text = doc.text[:max_tokens * 5] # Rough estimate
|
| 257 |
|
| 258 |
compressed_docs.append(RetrievedDocument(
|
| 259 |
id=doc.id,
|
| 260 |
-
text=
|
| 261 |
confidence=doc.confidence,
|
| 262 |
metadata=doc.metadata
|
| 263 |
))
|
|
@@ -386,22 +391,18 @@ Alternative queries (one per line):"""
|
|
| 386 |
system_message: str = "You are a helpful AI assistant."
|
| 387 |
) -> str:
|
| 388 |
"""
|
| 389 |
-
Build optimized RAG prompt for LLM
|
| 390 |
-
|
| 391 |
"""
|
| 392 |
prompt_template = f"""{system_message}
|
| 393 |
|
| 394 |
{context}
|
| 395 |
|
| 396 |
-
|
| 397 |
1. Dựa trên CONTEXT phía trên, hãy trả lời câu hỏi của người dùng
|
| 398 |
-
2. Context đã được
|
| 399 |
-
3. Trích dẫn thông tin cụ thể từ context khi trả lời
|
| 400 |
4. CHỈ nói "Tôi không tìm thấy thông tin liên quan" nếu context HOÀN TOÀN KHÔNG đề cập đến chủ đề được hỏi
|
| 401 |
-
5. Trả lời bằng tiếng Việt
|
| 402 |
-
|
| 403 |
-
USER QUESTION: {query}
|
| 404 |
-
|
| 405 |
-
YOUR ANSWER:"""
|
| 406 |
|
| 407 |
return prompt_template
|
|
|
|
| 150 |
for result in results:
|
| 151 |
doc_id = result["id"]
|
| 152 |
if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
|
| 153 |
+
# Lấy text từ metadata - hỗ trợ cả "text" (string) và "texts" (array)
|
| 154 |
+
metadata = result["metadata"]
|
| 155 |
+
doc_text = metadata.get("text", "")
|
| 156 |
+
if not doc_text and "texts" in metadata:
|
| 157 |
+
# Nếu là array, join thành string
|
| 158 |
+
texts_arr = metadata.get("texts", [])
|
| 159 |
+
if isinstance(texts_arr, list):
|
| 160 |
+
doc_text = "\n".join(texts_arr)
|
| 161 |
+
else:
|
| 162 |
+
doc_text = str(texts_arr)
|
| 163 |
+
|
| 164 |
all_results[doc_id] = RetrievedDocument(
|
| 165 |
id=doc_id,
|
| 166 |
+
text=doc_text,
|
| 167 |
confidence=result["confidence"],
|
| 168 |
+
metadata=metadata
|
| 169 |
)
|
| 170 |
|
| 171 |
# Sort by confidence and return top_k
|
|
|
|
| 181 |
"""
|
| 182 |
Rerank documents using Cross-Encoder (Best Case 2025)
|
| 183 |
Cross-Encoder provides superior relevance scoring compared to bi-encoders
|
| 184 |
+
|
| 185 |
Args:
|
| 186 |
query: Original user query
|
| 187 |
documents: Retrieved documents to rerank
|
| 188 |
top_k: Number of top documents to return
|
| 189 |
+
|
| 190 |
Returns:
|
| 191 |
Reranked documents
|
| 192 |
"""
|
|
|
|
| 195 |
|
| 196 |
# Prepare query-document pairs for Cross-Encoder
|
| 197 |
pairs = [[query, doc.text] for doc in documents]
|
| 198 |
+
|
| 199 |
+
# Get Cross-Encoder scores (raw logits)
|
| 200 |
ce_scores = self.cross_encoder.predict(pairs)
|
| 201 |
+
ce_scores = [float(s) for s in ce_scores]
|
| 202 |
+
|
| 203 |
+
# Min-Max normalization để scale về 0-1
|
| 204 |
+
# Thay vì sigmoid (cho điểm rất thấp với logits âm)
|
| 205 |
+
min_score = min(ce_scores)
|
| 206 |
+
max_score = max(ce_scores)
|
| 207 |
+
|
| 208 |
+
if max_score - min_score > 0.001: # Có sự khác biệt giữa các scores
|
| 209 |
+
ce_scores_normalized = [
|
| 210 |
+
(score - min_score) / (max_score - min_score)
|
| 211 |
+
for score in ce_scores
|
| 212 |
+
]
|
| 213 |
+
else:
|
| 214 |
+
# Tất cả scores gần như bằng nhau -> giữ original confidence
|
| 215 |
+
ce_scores_normalized = [doc.confidence for doc in documents]
|
| 216 |
+
|
| 217 |
+
# Combine: 70% Cross-Encoder ranking + 30% original cosine similarity
|
| 218 |
+
# Để giữ lại một phần semantic similarity từ embedding
|
| 219 |
reranked = []
|
| 220 |
+
for doc, ce_norm in zip(documents, ce_scores_normalized):
|
| 221 |
+
combined_score = 0.7 * ce_norm + 0.3 * doc.confidence
|
| 222 |
reranked.append(RetrievedDocument(
|
| 223 |
id=doc.id,
|
| 224 |
text=doc.text,
|
| 225 |
+
confidence=float(combined_score),
|
| 226 |
metadata=doc.metadata
|
| 227 |
))
|
| 228 |
+
|
| 229 |
+
# Sort by combined score
|
| 230 |
reranked.sort(key=lambda x: x.confidence, reverse=True)
|
| 231 |
return reranked[:top_k]
|
| 232 |
|
|
|
|
| 237 |
max_tokens: int = 500
|
| 238 |
) -> List[RetrievedDocument]:
|
| 239 |
"""
|
| 240 |
+
Compress context - giữ nguyên nội dung quan trọng, chỉ truncate nếu quá dài
|
| 241 |
+
KHÔNG dùng word overlap vì nó loại bỏ sai thông tin quan trọng
|
| 242 |
"""
|
| 243 |
compressed_docs = []
|
| 244 |
|
| 245 |
for doc in documents:
|
| 246 |
+
text = doc.text.strip()
|
| 247 |
+
|
| 248 |
+
# Chỉ truncate nếu text quá dài (ước tính ~4 chars/token)
|
| 249 |
+
max_chars = max_tokens * 4
|
| 250 |
+
if len(text) > max_chars:
|
| 251 |
+
# Cắt thông minh tại câu gần nhất
|
| 252 |
+
truncated = text[:max_chars]
|
| 253 |
+
last_period = max(
|
| 254 |
+
truncated.rfind('.'),
|
| 255 |
+
truncated.rfind('!'),
|
| 256 |
+
truncated.rfind('?'),
|
| 257 |
+
truncated.rfind('\n')
|
| 258 |
+
)
|
| 259 |
+
if last_period > max_chars * 0.5: # Nếu tìm thấy dấu câu ở nửa sau
|
| 260 |
+
truncated = truncated[:last_period + 1]
|
| 261 |
+
text = truncated.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
compressed_docs.append(RetrievedDocument(
|
| 264 |
id=doc.id,
|
| 265 |
+
text=text,
|
| 266 |
confidence=doc.confidence,
|
| 267 |
metadata=doc.metadata
|
| 268 |
))
|
|
|
|
| 391 |
system_message: str = "You are a helpful AI assistant."
|
| 392 |
) -> str:
|
| 393 |
"""
|
| 394 |
+
Build optimized RAG system prompt for LLM
|
| 395 |
+
Query sẽ được gửi riêng trong user message
|
| 396 |
"""
|
| 397 |
prompt_template = f"""{system_message}
|
| 398 |
|
| 399 |
{context}
|
| 400 |
|
| 401 |
+
HƯỚNG DẪN TRẢ LỜI:
|
| 402 |
1. Dựa trên CONTEXT phía trên, hãy trả lời câu hỏi của người dùng
|
| 403 |
+
2. Context đã được hệ thống tìm kiếm và lọc - HÃY SỬ DỤNG thông tin này để trả lời
|
| 404 |
+
3. Trích dẫn thông tin cụ thể từ context khi trả lời (tên sự kiện, địa điểm, thời gian, v.v.)
|
| 405 |
4. CHỈ nói "Tôi không tìm thấy thông tin liên quan" nếu context HOÀN TOÀN KHÔNG đề cập đến chủ đề được hỏi
|
| 406 |
+
5. Trả lời bằng tiếng Việt, ngắn gọn và đầy đủ thông tin"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
return prompt_template
|
main.py
CHANGED
|
@@ -778,20 +778,12 @@ Example:
|
|
| 778 |
model="openai/gpt-oss-20b"
|
| 779 |
)
|
| 780 |
|
| 781 |
-
# Build messages
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
{"role": "user", "content": system_message}
|
| 788 |
-
]
|
| 789 |
-
else:
|
| 790 |
-
# Basic RAG or no RAG - normal message flow
|
| 791 |
-
messages = [
|
| 792 |
-
{"role": "system", "content": system_message},
|
| 793 |
-
{"role": "user", "content": request.message}
|
| 794 |
-
]
|
| 795 |
|
| 796 |
# Generate response
|
| 797 |
response = ""
|
|
|
|
| 778 |
model="openai/gpt-oss-20b"
|
| 779 |
)
|
| 780 |
|
| 781 |
+
# Build messages - luôn dùng cấu trúc chuẩn
|
| 782 |
+
# System = instructions + context, User = query
|
| 783 |
+
messages = [
|
| 784 |
+
{"role": "system", "content": system_message},
|
| 785 |
+
{"role": "user", "content": request.message}
|
| 786 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 787 |
|
| 788 |
# Generate response
|
| 789 |
response = ""
|