minhvtt commited on
Commit
fba01f9
·
verified ·
1 Parent(s): b85b8b1

Upload 12 files

Browse files
Files changed (2) hide show
  1. advanced_rag.py +65 -64
  2. main.py +6 -14
advanced_rag.py CHANGED
@@ -150,11 +150,22 @@ Alternative queries (one per line):"""
150
  for result in results:
151
  doc_id = result["id"]
152
  if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
 
 
 
 
 
 
 
 
 
 
 
153
  all_results[doc_id] = RetrievedDocument(
154
  id=doc_id,
155
- text=result["metadata"].get("text", ""),
156
  confidence=result["confidence"],
157
- metadata=result["metadata"]
158
  )
159
 
160
  # Sort by confidence and return top_k
@@ -170,12 +181,12 @@ Alternative queries (one per line):"""
170
  """
171
  Rerank documents using Cross-Encoder (Best Case 2025)
172
  Cross-Encoder provides superior relevance scoring compared to bi-encoders
173
-
174
  Args:
175
  query: Original user query
176
  documents: Retrieved documents to rerank
177
  top_k: Number of top documents to return
178
-
179
  Returns:
180
  Reranked documents
181
  """
@@ -184,29 +195,38 @@ Alternative queries (one per line):"""
184
 
185
  # Prepare query-document pairs for Cross-Encoder
186
  pairs = [[query, doc.text] for doc in documents]
187
-
188
- # Get Cross-Encoder scores
189
  ce_scores = self.cross_encoder.predict(pairs)
190
-
191
- # Normalize CE scores using sigmoid (convert logits to 0-1 range)
192
- import math
193
- def sigmoid(x):
194
- return 1 / (1 + math.exp(-x))
195
-
196
- ce_scores_normalized = [sigmoid(float(score)) for score in ce_scores]
197
-
198
- # Create reranked documents with normalized scores
 
 
 
 
 
 
 
 
 
199
  reranked = []
200
- for doc, ce_score_norm in zip(documents, ce_scores_normalized):
201
- # Use ONLY Cross-Encoder score (it's more accurate than cosine similarity)
202
  reranked.append(RetrievedDocument(
203
  id=doc.id,
204
  text=doc.text,
205
- confidence=float(ce_score_norm),
206
  metadata=doc.metadata
207
  ))
208
-
209
- # Sort by Cross-Encoder score
210
  reranked.sort(key=lambda x: x.confidence, reverse=True)
211
  return reranked[:top_k]
212
 
@@ -217,47 +237,32 @@ Alternative queries (one per line):"""
217
  max_tokens: int = 500
218
  ) -> List[RetrievedDocument]:
219
  """
220
- Compress context to most relevant parts
221
- Remove redundant information and keep only relevant sentences
222
  """
223
  compressed_docs = []
224
 
225
  for doc in documents:
226
- # Split into sentences
227
- sentences = self._split_sentences(doc.text)
228
-
229
- # Score each sentence based on relevance to query
230
- scored_sentences = []
231
- query_words = set(query.lower().split())
232
-
233
- for sent in sentences:
234
- sent_words = set(sent.lower().split())
235
- # Simple relevance: word overlap
236
- overlap = len(query_words & sent_words)
237
- if overlap > 0:
238
- scored_sentences.append((sent, overlap))
239
-
240
- # Sort by relevance and take top sentences
241
- scored_sentences.sort(key=lambda x: x[1], reverse=True)
242
-
243
- # Reconstruct compressed text (up to max_tokens)
244
- compressed_text = ""
245
- word_count = 0
246
- for sent, score in scored_sentences:
247
- sent_words = len(sent.split())
248
- if word_count + sent_words <= max_tokens:
249
- compressed_text += sent + " "
250
- word_count += sent_words
251
- else:
252
- break
253
-
254
- # If nothing selected, take original first part
255
- if not compressed_text.strip():
256
- compressed_text = doc.text[:max_tokens * 5] # Rough estimate
257
 
258
  compressed_docs.append(RetrievedDocument(
259
  id=doc.id,
260
- text=compressed_text.strip(),
261
  confidence=doc.confidence,
262
  metadata=doc.metadata
263
  ))
@@ -386,22 +391,18 @@ Alternative queries (one per line):"""
386
  system_message: str = "You are a helpful AI assistant."
387
  ) -> str:
388
  """
389
- Build optimized RAG prompt for LLM
390
- Uses best practices for prompt engineering
391
  """
392
  prompt_template = f"""{system_message}
393
 
394
  {context}
395
 
396
- INSTRUCTIONS:
397
  1. Dựa trên CONTEXT phía trên, hãy trả lời câu hỏi của người dùng
398
- 2. Context đã được lọc với độ tương đồng cao - LUÔN SỬ DỤNG nếu liên quan đến câu hỏi
399
- 3. Trích dẫn thông tin cụ thể từ context khi trả lời
400
  4. CHỈ nói "Tôi không tìm thấy thông tin liên quan" nếu context HOÀN TOÀN KHÔNG đề cập đến chủ đề được hỏi
401
- 5. Trả lời bằng tiếng Việt nếu câu hỏi tiếng Việt
402
-
403
- USER QUESTION: {query}
404
-
405
- YOUR ANSWER:"""
406
 
407
  return prompt_template
 
150
  for result in results:
151
  doc_id = result["id"]
152
  if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
153
+ # Lấy text từ metadata - hỗ trợ cả "text" (string) và "texts" (array)
154
+ metadata = result["metadata"]
155
+ doc_text = metadata.get("text", "")
156
+ if not doc_text and "texts" in metadata:
157
+ # Nếu là array, join thành string
158
+ texts_arr = metadata.get("texts", [])
159
+ if isinstance(texts_arr, list):
160
+ doc_text = "\n".join(texts_arr)
161
+ else:
162
+ doc_text = str(texts_arr)
163
+
164
  all_results[doc_id] = RetrievedDocument(
165
  id=doc_id,
166
+ text=doc_text,
167
  confidence=result["confidence"],
168
+ metadata=metadata
169
  )
170
 
171
  # Sort by confidence and return top_k
 
181
  """
182
  Rerank documents using Cross-Encoder (Best Case 2025)
183
  Cross-Encoder provides superior relevance scoring compared to bi-encoders
184
+
185
  Args:
186
  query: Original user query
187
  documents: Retrieved documents to rerank
188
  top_k: Number of top documents to return
189
+
190
  Returns:
191
  Reranked documents
192
  """
 
195
 
196
  # Prepare query-document pairs for Cross-Encoder
197
  pairs = [[query, doc.text] for doc in documents]
198
+
199
+ # Get Cross-Encoder scores (raw logits)
200
  ce_scores = self.cross_encoder.predict(pairs)
201
+ ce_scores = [float(s) for s in ce_scores]
202
+
203
+ # Min-Max normalization để scale về 0-1
204
+ # Thay vì sigmoid (cho điểm rất thấp với logits âm)
205
+ min_score = min(ce_scores)
206
+ max_score = max(ce_scores)
207
+
208
+ if max_score - min_score > 0.001: # Có sự khác biệt giữa các scores
209
+ ce_scores_normalized = [
210
+ (score - min_score) / (max_score - min_score)
211
+ for score in ce_scores
212
+ ]
213
+ else:
214
+ # Tất cả scores gần như bằng nhau -> giữ original confidence
215
+ ce_scores_normalized = [doc.confidence for doc in documents]
216
+
217
+ # Combine: 70% Cross-Encoder ranking + 30% original cosine similarity
218
+ # Để giữ lại một phần semantic similarity từ embedding
219
  reranked = []
220
+ for doc, ce_norm in zip(documents, ce_scores_normalized):
221
+ combined_score = 0.7 * ce_norm + 0.3 * doc.confidence
222
  reranked.append(RetrievedDocument(
223
  id=doc.id,
224
  text=doc.text,
225
+ confidence=float(combined_score),
226
  metadata=doc.metadata
227
  ))
228
+
229
+ # Sort by combined score
230
  reranked.sort(key=lambda x: x.confidence, reverse=True)
231
  return reranked[:top_k]
232
 
 
237
  max_tokens: int = 500
238
  ) -> List[RetrievedDocument]:
239
  """
240
+ Compress context - giữ nguyên nội dung quan trọng, chỉ truncate nếu quá dài
241
+ KHÔNG dùng word overlap loại bỏ sai thông tin quan trọng
242
  """
243
  compressed_docs = []
244
 
245
  for doc in documents:
246
+ text = doc.text.strip()
247
+
248
+ # Chỉ truncate nếu text quá dài (ước tính ~4 chars/token)
249
+ max_chars = max_tokens * 4
250
+ if len(text) > max_chars:
251
+ # Cắt thông minh tại câu gần nhất
252
+ truncated = text[:max_chars]
253
+ last_period = max(
254
+ truncated.rfind('.'),
255
+ truncated.rfind('!'),
256
+ truncated.rfind('?'),
257
+ truncated.rfind('\n')
258
+ )
259
+ if last_period > max_chars * 0.5: # Nếu tìm thấy dấu câu ở nửa sau
260
+ truncated = truncated[:last_period + 1]
261
+ text = truncated.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  compressed_docs.append(RetrievedDocument(
264
  id=doc.id,
265
+ text=text,
266
  confidence=doc.confidence,
267
  metadata=doc.metadata
268
  ))
 
391
  system_message: str = "You are a helpful AI assistant."
392
  ) -> str:
393
  """
394
+ Build optimized RAG system prompt for LLM
395
+ Query sẽ được gửi riêng trong user message
396
  """
397
  prompt_template = f"""{system_message}
398
 
399
  {context}
400
 
401
+ HƯỚNG DẪN TRẢ LỜI:
402
  1. Dựa trên CONTEXT phía trên, hãy trả lời câu hỏi của người dùng
403
+ 2. Context đã được hệ thống tìm kiếm lọc - HÃY SỬ DỤNG thông tin này để trả lời
404
+ 3. Trích dẫn thông tin cụ thể từ context khi trả lời (tên sự kiện, địa điểm, thời gian, v.v.)
405
  4. CHỈ nói "Tôi không tìm thấy thông tin liên quan" nếu context HOÀN TOÀN KHÔNG đề cập đến chủ đề được hỏi
406
+ 5. Trả lời bằng tiếng Việt, ngắn gọn đầy đủ thông tin"""
 
 
 
 
407
 
408
  return prompt_template
main.py CHANGED
@@ -778,20 +778,12 @@ Example:
778
  model="openai/gpt-oss-20b"
779
  )
780
 
781
- # Build messages
782
- if request.use_advanced_rag and context_used:
783
- # Advanced RAG prompt already contains query + instructions
784
- # Just send it as system message, user message is empty
785
- messages = [
786
- {"role": "system", "content": "You are a helpful assistant."},
787
- {"role": "user", "content": system_message}
788
- ]
789
- else:
790
- # Basic RAG or no RAG - normal message flow
791
- messages = [
792
- {"role": "system", "content": system_message},
793
- {"role": "user", "content": request.message}
794
- ]
795
 
796
  # Generate response
797
  response = ""
 
778
  model="openai/gpt-oss-20b"
779
  )
780
 
781
+ # Build messages - luôn dùng cấu trúc chuẩn
782
+ # System = instructions + context, User = query
783
+ messages = [
784
+ {"role": "system", "content": system_message},
785
+ {"role": "user", "content": request.message}
786
+ ]
 
 
 
 
 
 
 
 
787
 
788
  # Generate response
789
  response = ""