| """ | |
| Advanced RAG techniques for improved retrieval and generation (Best Case 2025) | |
| Includes: LLM-Based Query Expansion, Cross-Encoder Reranking, Contextual Compression, Hybrid Search | |
| """ | |
| from typing import List, Dict, Optional, Tuple | |
| import numpy as np | |
| from dataclasses import dataclass | |
| import re | |
| from sentence_transformers import CrossEncoder | |

@dataclass
class RetrievedDocument:
    """Document retrieved from the vector database"""
    id: str
    text: str
    confidence: float
    metadata: Dict

class AdvancedRAG:
    """Advanced RAG system with 2025 best practices"""

    def __init__(self, embedding_service, qdrant_service):
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service

        # Initialize Cross-Encoder for reranking (multilingual for Vietnamese support)
        print("Loading Cross-Encoder model for reranking...")
        # Use multilingual model instead of English-only ms-marco
        self.cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')
        print("✓ Cross-Encoder loaded (multilingual)")

    def expand_query_llm(
        self,
        query: str,
        hf_client=None
    ) -> List[str]:
        """
        Expand the query using an LLM (2025 best practices)
        Generates alternative phrasings and sub-queries

        Args:
            query: Original user query
            hf_client: HuggingFace InferenceClient (optional)

        Returns:
            List of expanded queries
        """
        queries = [query]

        # Fall back to rule-based expansion if no LLM client is available
        if not hf_client:
            return self._expand_query_rule_based(query)

        try:
            # LLM-based expansion prompt
            expansion_prompt = f"""Given this user question, generate 2-3 alternative phrasings or sub-questions that would help retrieve relevant information.
User Question: {query}
Alternative queries (one per line):"""

            # Generate expansions
            response = ""
            for msg in hf_client.chat_completion(
                messages=[{"role": "user", "content": expansion_prompt}],
                max_tokens=256,
                stream=True,
                temperature=0.7,
                model="openai/gpt-oss-20b"
            ):
                if msg.choices and msg.choices[0].delta.content:
                    response += msg.choices[0].delta.content

            # Parse expansions
            lines = [line.strip() for line in response.split('\n') if line.strip()]

            # Filter out numbered lists, dashes, etc.
            clean_lines = []
            for line in lines:
                # Remove common list markers (bullets, or numbers followed by '.' / ')')
                cleaned = re.sub(r'^\s*(?:[-*•]+|\d+[.)])\s*', '', line)
                if cleaned and len(cleaned) > 5:
                    clean_lines.append(cleaned)

            queries.extend(clean_lines[:3])  # Add top 3 expansions
        except Exception as e:
            print(f"LLM expansion failed, using rule-based: {e}")
            return self._expand_query_rule_based(query)

        return queries[:4]  # Original + up to 3 expansions

    def _expand_query_rule_based(self, query: str) -> List[str]:
        """
        Fallback rule-based query expansion
        Simple but effective Vietnamese-aware expansion
        """
        queries = [query]

        # Vietnamese question words
        question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào',
                          'sao', 'tại sao', 'có', 'là', 'được', 'không', 'làm sao']

        query_lower = query.lower()
        for qw in question_words:
            if qw in query_lower:
                variant = query_lower.replace(qw, '').strip()
                if variant and variant != query_lower:
                    queries.append(variant)
                    break  # One variation is enough

        # Extract key phrases
        words = query.split()
        if len(words) > 3:
            key_phrases = ' '.join(words[1:]) if words[0].lower() in question_words else ' '.join(words[:3])
            if key_phrases not in queries:
                queries.append(key_phrases)

        return queries[:3]

    def multi_query_retrieval(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5,
        expanded_queries: Optional[List[str]] = None
    ) -> List[RetrievedDocument]:
        """
        Retrieve documents using multiple query variations
        Combines results from all query variations with deduplication
        """
        if expanded_queries is None:
            expanded_queries = [query]

        all_results = {}  # Deduplicate by doc_id
        for q in expanded_queries:
            # Generate an embedding for each query variant
            query_embedding = self.embedding_service.encode_text(q)

            # Search in Qdrant
            results = self.qdrant_service.search(
                query_embedding=query_embedding,
                limit=top_k,
                score_threshold=score_threshold
            )

            # Add to results (keep the highest score for duplicates)
            for result in results:
                doc_id = result["id"]
                if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
                    # Get text from metadata - supports both "text" (string) and "texts" (array)
                    metadata = result["metadata"]
                    doc_text = metadata.get("text", "")
                    if not doc_text and "texts" in metadata:
                        # If it is an array, join it into a single string
                        texts_arr = metadata.get("texts", [])
                        if isinstance(texts_arr, list):
                            doc_text = "\n".join(texts_arr)
                        else:
                            doc_text = str(texts_arr)

                    all_results[doc_id] = RetrievedDocument(
                        id=doc_id,
                        text=doc_text,
                        confidence=result["confidence"],
                        metadata=metadata
                    )

        # Sort by confidence and return the top candidates
        sorted_results = sorted(all_results.values(), key=lambda x: x.confidence, reverse=True)
        return sorted_results[:top_k * 2]  # Return extra candidates for reranking

    def rerank_documents_cross_encoder(
        self,
        query: str,
        documents: List[RetrievedDocument],
        top_k: int = 5
    ) -> List[RetrievedDocument]:
        """
        Rerank documents using a Cross-Encoder (2025 best practices)
        Cross-Encoders provide superior relevance scoring compared to bi-encoders

        Args:
            query: Original user query
            documents: Retrieved documents to rerank
            top_k: Number of top documents to return

        Returns:
            Reranked documents
        """
        if not documents:
            return documents

        # Prepare query-document pairs for the Cross-Encoder
        pairs = [[query, doc.text] for doc in documents]

        # Get Cross-Encoder scores (raw logits)
        ce_scores = self.cross_encoder.predict(pairs)
        ce_scores = [float(s) for s in ce_scores]

        # Min-max normalization to scale scores to 0-1,
        # instead of sigmoid (which yields very low scores for negative logits)
        min_score = min(ce_scores)
        max_score = max(ce_scores)
        if max_score - min_score > 0.001:  # Scores actually differ
            ce_scores_normalized = [
                (score - min_score) / (max_score - min_score)
                for score in ce_scores
            ]
        else:
            # All scores are nearly equal -> keep the original confidences
            ce_scores_normalized = [doc.confidence for doc in documents]

        # Combine: 70% Cross-Encoder ranking + 30% original cosine similarity,
        # to retain part of the semantic similarity signal from the embeddings
        reranked = []
        for doc, ce_norm in zip(documents, ce_scores_normalized):
            combined_score = 0.7 * ce_norm + 0.3 * doc.confidence
            reranked.append(RetrievedDocument(
                id=doc.id,
                text=doc.text,
                confidence=float(combined_score),
                metadata=doc.metadata
            ))

        # Sort by combined score
        reranked.sort(key=lambda x: x.confidence, reverse=True)
        return reranked[:top_k]

    def compress_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        max_tokens: int = 500
    ) -> List[RetrievedDocument]:
        """
        Compress context - keep the important content intact and only truncate if it is too long
        Word-overlap filtering is NOT used, because it wrongly discards important information
        """
        compressed_docs = []
        for doc in documents:
            text = doc.text.strip()

            # Only truncate if the text is too long (estimate ~4 chars per token)
            max_chars = max_tokens * 4
            if len(text) > max_chars:
                # Truncate intelligently at the nearest sentence boundary
                truncated = text[:max_chars]
                last_period = max(
                    truncated.rfind('.'),
                    truncated.rfind('!'),
                    truncated.rfind('?'),
                    truncated.rfind('\n')
                )
                if last_period > max_chars * 0.5:  # Only if a boundary was found in the second half
                    truncated = truncated[:last_period + 1]
                text = truncated.strip()

            compressed_docs.append(RetrievedDocument(
                id=doc.id,
                text=text,
                confidence=doc.confidence,
                metadata=doc.metadata
            ))

        return compressed_docs

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences (Vietnamese-aware)"""
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]

    def hybrid_rag_pipeline(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5,
        use_reranking: bool = True,
        use_compression: bool = True,
        use_query_expansion: bool = True,
        max_context_tokens: int = 500,
        hf_client=None
    ) -> Tuple[List[RetrievedDocument], Dict]:
        """
        Complete advanced RAG pipeline (2025 best practices)
        1. LLM-based query expansion
        2. Multi-query retrieval
        3. Cross-Encoder reranking
        4. Contextual compression

        Args:
            query: User query
            top_k: Number of documents to return
            score_threshold: Minimum relevance score
            use_reranking: Enable Cross-Encoder reranking
            use_compression: Enable context compression
            use_query_expansion: Enable LLM-based query expansion
            max_context_tokens: Max tokens per document for compression
            hf_client: HuggingFace InferenceClient for expansion

        Returns:
            (documents, stats)
        """
        stats = {
            "original_query": query,
            "expanded_queries": [],
            "initial_results": 0,
            "after_rerank": 0,
            "after_compression": 0,
            "used_cross_encoder": use_reranking,
            "used_llm_expansion": use_query_expansion and hf_client is not None
        }

        # Step 1: Query expansion (LLM-based or rule-based)
        if use_query_expansion:
            expanded_queries = self.expand_query_llm(query, hf_client)
        else:
            expanded_queries = [query]
        stats["expanded_queries"] = expanded_queries

        # Step 2: Multi-query retrieval
        documents = self.multi_query_retrieval(
            query=query,
            top_k=top_k * 2,  # Get more candidates for reranking
            score_threshold=score_threshold,
            expanded_queries=expanded_queries
        )
        stats["initial_results"] = len(documents)

        # Step 3: Cross-Encoder reranking
        if use_reranking and documents:
            documents = self.rerank_documents_cross_encoder(
                query=query,
                documents=documents,
                top_k=top_k
            )
        else:
            documents = documents[:top_k]
        stats["after_rerank"] = len(documents)

        # Step 4: Contextual compression (optional)
        if use_compression and documents:
            documents = self.compress_context(
                query=query,
                documents=documents,
                max_tokens=max_context_tokens
            )
        stats["after_compression"] = len(documents)

        return documents, stats

    def format_context_for_llm(
        self,
        documents: List[RetrievedDocument],
        include_metadata: bool = True
    ) -> str:
        """
        Format retrieved documents into a context string for the LLM
        Uses better structure for improved LLM understanding
        """
        if not documents:
            return ""

        context_parts = ["RELEVANT CONTEXT:\n"]
        for i, doc in enumerate(documents, 1):
            context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---")
            context_parts.append(doc.text)

            if include_metadata and doc.metadata:
                # Add useful metadata
                meta_str = []
                for key, value in doc.metadata.items():
                    if key not in ['text', 'texts'] and value:
                        meta_str.append(f"{key}: {value}")
                if meta_str:
                    context_parts.append(f"[Metadata: {', '.join(meta_str)}]")

        context_parts.append("\n--- End of Context ---\n")
        return "\n".join(context_parts)

    def build_rag_prompt(
        self,
        query: str,
        context: str,
        system_message: str = "You are a helpful AI assistant."
    ) -> str:
        """
        Build the optimized RAG system prompt for the LLM
        The query itself is sent separately in the user message
        """
        # The answering guidelines are kept in Vietnamese because the assistant
        # is expected to reply in Vietnamese.
        prompt_template = f"""{system_message}
{context}
HƯỚNG DẪN TRẢ LỜI:
1. Đóng vai trò là một trợ lý ảo thân thiện, trả lời tự nhiên bằng tiếng Việt.
2. Dựa vào CONTEXT được cung cấp để trả lời câu hỏi.
3. KHÔNG copy nguyên văn text từ context. Hãy tổng hợp lại thông tin một cách mạch lạc.
4. Bắt đầu câu trả lời bằng các cụm từ tự nhiên như: "Dựa trên dữ liệu tôi tìm thấy...", "Tôi có thông tin về các sự kiện sau...", "Có vẻ như đây là những gì bạn đang tìm...".
5. Nếu có nhiều kết quả, hãy liệt kê ngắn gọn các điểm chính (Tên, Thời gian, Địa điểm).
6. Nếu context không liên quan, hãy lịch sự nói rằng bạn chưa tìm thấy thông tin phù hợp trong hệ thống."""
        return prompt_template
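

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes the
# injected services expose encode_text(text) -> vector and
# search(query_embedding, limit, score_threshold) -> list of dicts with
# "id", "confidence", and "metadata" keys, matching how AdvancedRAG calls them
# above; the stub classes and sample data below are hypothetical.
# Note: constructing AdvancedRAG downloads the Cross-Encoder model on first use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubEmbeddingService:
        def encode_text(self, text: str):
            # Toy deterministic embedding; a real service would run a model.
            return np.ones(8, dtype=np.float32)

    class _StubQdrantService:
        def search(self, query_embedding, limit, score_threshold):
            # Canned hits in the shape multi_query_retrieval expects.
            hits = [
                {"id": "doc-1", "confidence": 0.82,
                 "metadata": {"text": "Sự kiện A diễn ra ngày 12/10 tại Hà Nội.", "source": "events"}},
                {"id": "doc-2", "confidence": 0.64,
                 "metadata": {"texts": ["Sự kiện B", "diễn ra tại TP.HCM."], "source": "events"}},
            ]
            return [h for h in hits if h["confidence"] >= score_threshold][:limit]

    rag = AdvancedRAG(_StubEmbeddingService(), _StubQdrantService())
    docs, stats = rag.hybrid_rag_pipeline(
        query="Có sự kiện nào ở Hà Nội không?",
        top_k=2,
        hf_client=None  # no LLM client here, so expansion falls back to rule-based
    )
    print(stats)
    print(rag.format_context_for_llm(docs))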