| """ | |
| Advanced RAG techniques for improved retrieval and generation (Best Case 2025) | |
| Includes: LLM-Based Query Expansion, Cross-Encoder Reranking, Contextual Compression, Hybrid Search | |
| """ | |
| from typing import List, Dict, Optional, Tuple | |
| import numpy as np | |
| from dataclasses import dataclass | |
| import re | |
| from sentence_transformers import CrossEncoder | |
@dataclass
class RetrievedDocument:
    """Document retrieved from vector database"""
    id: str
    text: str
    confidence: float
    metadata: Dict
class AdvancedRAG:
    """Advanced RAG system with 2025 best practices"""

    def __init__(self, embedding_service, qdrant_service):
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service

        # Initialize Cross-Encoder for reranking (state-of-the-art)
        print("Loading Cross-Encoder model for reranking...")
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        print("✓ Cross-Encoder loaded")
    def expand_query_llm(
        self,
        query: str,
        hf_client=None
    ) -> List[str]:
        """
        Expand query using an LLM (2025 best practice)
        Generates query variations, sub-queries, and hypothetical answers

        Args:
            query: Original user query
            hf_client: HuggingFace InferenceClient (optional)

        Returns:
            List of expanded queries
        """
        queries = [query]

        # Fall back to rule-based expansion if no LLM client is available
        if not hf_client:
            return self._expand_query_rule_based(query)

        try:
            # LLM-based expansion prompt
            expansion_prompt = f"""Given this user question, generate 2-3 alternative phrasings or sub-questions that would help retrieve relevant information.

User Question: {query}

Alternative queries (one per line):"""

            # Generate expansions
            response = ""
            for msg in hf_client.chat_completion(
                messages=[{"role": "user", "content": expansion_prompt}],
                max_tokens=150,
                stream=True,
                temperature=0.7
            ):
                if msg.choices and msg.choices[0].delta.content:
                    response += msg.choices[0].delta.content

            # Parse expansions
            lines = [line.strip() for line in response.split('\n') if line.strip()]

            # Filter out numbered lists, dashes, etc.
            clean_lines = []
            for line in lines:
                # Remove common list markers (e.g. "1.", "2)", "-", "*", "•")
                cleaned = re.sub(r'^[\d\-\*\•]+[\.\)]?\s*', '', line)
                if cleaned and len(cleaned) > 5:
                    clean_lines.append(cleaned)

            queries.extend(clean_lines[:3])  # Add top 3 expansions

        except Exception as e:
            print(f"LLM expansion failed, using rule-based: {e}")
            return self._expand_query_rule_based(query)

        return queries[:4]  # Original + 3 expansions
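    # Illustrative behaviour (hypothetical output; actual expansions depend on the LLM):
    #   expand_query_llm("What are the library opening hours?") might return
    #   ["What are the library opening hours?", "When does the library open?",
    #    "Library opening schedule", "Opening times for the library"]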
    def _expand_query_rule_based(self, query: str) -> List[str]:
        """
        Fallback rule-based query expansion
        Simple but effective Vietnamese-aware expansion
        """
        queries = [query]

        # Vietnamese question words
        question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào',
                          'sao', 'tại sao', 'có', 'là', 'được', 'không', 'làm sao']

        query_lower = query.lower()
        for qw in question_words:
            if qw in query_lower:
                variant = query_lower.replace(qw, '').strip()
                if variant and variant != query_lower:
                    queries.append(variant)
                    break  # One variation is enough

        # Extract key phrases
        words = query.split()
        if len(words) > 3:
            key_phrases = ' '.join(words[1:]) if words[0].lower() in question_words else ' '.join(words[:3])
            if key_phrases not in queries:
                queries.append(key_phrases)

        return queries[:3]
    def multi_query_retrieval(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5,
        expanded_queries: Optional[List[str]] = None
    ) -> List[RetrievedDocument]:
        """
        Retrieve documents using multiple query variations
        Combines results from all query variations with deduplication
        """
        if expanded_queries is None:
            expanded_queries = [query]

        all_results = {}  # Deduplicate by doc_id

        for q in expanded_queries:
            # Generate embedding for each query variant
            query_embedding = self.embedding_service.encode_text(q)

            # Search in Qdrant
            results = self.qdrant_service.search(
                query_embedding=query_embedding,
                limit=top_k,
                score_threshold=score_threshold
            )

            # Add to results (keep highest score for duplicates)
            for result in results:
                doc_id = result["id"]
                if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
                    all_results[doc_id] = RetrievedDocument(
                        id=doc_id,
                        text=result["metadata"].get("text", ""),
                        confidence=result["confidence"],
                        metadata=result["metadata"]
                    )

        # Sort by confidence and return the top candidates
        sorted_results = sorted(all_results.values(), key=lambda x: x.confidence, reverse=True)
        return sorted_results[:top_k * 2]  # Return extra candidates for reranking
    def rerank_documents_cross_encoder(
        self,
        query: str,
        documents: List[RetrievedDocument],
        top_k: int = 5
    ) -> List[RetrievedDocument]:
        """
        Rerank documents using a Cross-Encoder (2025 best practice)
        Cross-Encoders provide superior relevance scoring compared to bi-encoders

        Args:
            query: Original user query
            documents: Retrieved documents to rerank
            top_k: Number of top documents to return

        Returns:
            Reranked documents
        """
        if not documents:
            return documents

        # Prepare query-document pairs for the Cross-Encoder
        pairs = [[query, doc.text] for doc in documents]

        # Get Cross-Encoder scores
        ce_scores = self.cross_encoder.predict(pairs)

        # Create reranked documents with new scores
        reranked = []
        for doc, ce_score in zip(documents, ce_scores):
            # Combine CE score with original confidence (weighted)
            combined_score = 0.7 * float(ce_score) + 0.3 * doc.confidence
            reranked.append(RetrievedDocument(
                id=doc.id,
                text=doc.text,
                confidence=float(combined_score),
                metadata=doc.metadata
            ))

        # Sort by the new combined score
        reranked.sort(key=lambda x: x.confidence, reverse=True)
        return reranked[:top_k]
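    # NOTE: cross-encoder/ms-marco-MiniLM-L-6-v2 outputs raw relevance logits rather
    # than 0-1 probabilities, so the 0.7/0.3 blend above is a heuristic; if a
    # normalized scale matters, consider applying a sigmoid to the CE scores first.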
    def compress_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        max_tokens: int = 500
    ) -> List[RetrievedDocument]:
        """
        Compress context to its most relevant parts
        Removes redundant information and keeps only sentences relevant to the query
        (max_tokens is approximated by counting whitespace-separated words)
        """
        compressed_docs = []

        for doc in documents:
            # Split into sentences
            sentences = self._split_sentences(doc.text)

            # Score each sentence by relevance to the query
            scored_sentences = []
            query_words = set(query.lower().split())

            for sent in sentences:
                sent_words = set(sent.lower().split())
                # Simple relevance: word overlap
                overlap = len(query_words & sent_words)
                if overlap > 0:
                    scored_sentences.append((sent, overlap))

            # Sort by relevance and take the top sentences
            scored_sentences.sort(key=lambda x: x[1], reverse=True)

            # Reconstruct compressed text (up to max_tokens words)
            compressed_text = ""
            word_count = 0
            for sent, score in scored_sentences:
                sent_words = len(sent.split())
                if word_count + sent_words <= max_tokens:
                    compressed_text += sent + " "
                    word_count += sent_words
                else:
                    break

            # If nothing was selected, take the beginning of the original text
            if not compressed_text.strip():
                compressed_text = doc.text[:max_tokens * 5]  # Rough character estimate

            compressed_docs.append(RetrievedDocument(
                id=doc.id,
                text=compressed_text.strip(),
                confidence=doc.confidence,
                metadata=doc.metadata
            ))

        return compressed_docs
    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences (Vietnamese-aware)"""
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]
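    # Example: _split_sentences("Xin chào. Bạn khỏe không?") -> ["Xin chào", "Bạn khỏe không"]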
    def hybrid_rag_pipeline(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5,
        use_reranking: bool = True,
        use_compression: bool = True,
        use_query_expansion: bool = True,
        max_context_tokens: int = 500,
        hf_client=None
    ) -> Tuple[List[RetrievedDocument], Dict]:
        """
        Complete advanced RAG pipeline (2025 best practices)

        1. LLM-based query expansion
        2. Multi-query retrieval
        3. Cross-Encoder reranking
        4. Contextual compression

        Args:
            query: User query
            top_k: Number of documents to return
            score_threshold: Minimum relevance score
            use_reranking: Enable Cross-Encoder reranking
            use_compression: Enable context compression
            use_query_expansion: Enable LLM-based query expansion
            max_context_tokens: Max tokens for compression
            hf_client: HuggingFace InferenceClient for expansion

        Returns:
            (documents, stats)
        """
        stats = {
            "original_query": query,
            "expanded_queries": [],
            "initial_results": 0,
            "after_rerank": 0,
            "after_compression": 0,
            "used_cross_encoder": use_reranking,
            "used_llm_expansion": use_query_expansion and hf_client is not None
        }

        # Step 1: Query expansion (LLM-based or rule-based)
        if use_query_expansion:
            expanded_queries = self.expand_query_llm(query, hf_client)
        else:
            expanded_queries = [query]
        stats["expanded_queries"] = expanded_queries

        # Step 2: Multi-query retrieval
        documents = self.multi_query_retrieval(
            query=query,
            top_k=top_k * 2,  # Get more candidates for reranking
            score_threshold=score_threshold,
            expanded_queries=expanded_queries
        )
        stats["initial_results"] = len(documents)

        # Step 3: Cross-Encoder reranking
        if use_reranking and documents:
            documents = self.rerank_documents_cross_encoder(
                query=query,
                documents=documents,
                top_k=top_k
            )
        else:
            documents = documents[:top_k]
        stats["after_rerank"] = len(documents)

        # Step 4: Contextual compression (optional)
        if use_compression and documents:
            documents = self.compress_context(
                query=query,
                documents=documents,
                max_tokens=max_context_tokens
            )
        stats["after_compression"] = len(documents)

        return documents, stats
    def format_context_for_llm(
        self,
        documents: List[RetrievedDocument],
        include_metadata: bool = True
    ) -> str:
        """
        Format retrieved documents into a context string for the LLM
        Uses a clear structure for improved LLM understanding
        """
        if not documents:
            return ""

        context_parts = ["RELEVANT CONTEXT:\n"]

        for i, doc in enumerate(documents, 1):
            context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---")
            context_parts.append(doc.text)

            if include_metadata and doc.metadata:
                # Add useful metadata
                meta_str = []
                for key, value in doc.metadata.items():
                    if key not in ['text', 'texts'] and value:
                        meta_str.append(f"{key}: {value}")
                if meta_str:
                    context_parts.append(f"[Metadata: {', '.join(meta_str)}]")

        context_parts.append("\n--- End of Context ---\n")
        return "\n".join(context_parts)
    def build_rag_prompt(
        self,
        query: str,
        context: str,
        system_message: str = "You are a helpful AI assistant."
    ) -> str:
        """
        Build an optimized RAG prompt for the LLM
        Uses prompt-engineering best practices
        """
        prompt_template = f"""{system_message}

{context}

INSTRUCTIONS:
1. Answer the user's question using ONLY the information provided in the context above
2. If the context doesn't contain relevant information, say "Tôi không tìm thấy thông tin liên quan trong dữ liệu."
3. Cite relevant parts of the context when answering
4. Be concise and accurate
5. Answer in Vietnamese if the question is in Vietnamese

USER QUESTION: {query}

YOUR ANSWER:"""
        return prompt_template
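
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The stub services below are
# hypothetical and merely mimic the interfaces this class calls
# (embedding_service.encode_text and qdrant_service.search); replace them
# with the real services used elsewhere in this project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubEmbeddingService:
        def encode_text(self, text: str):
            # Deterministic pseudo-embedding, good enough for a smoke test
            return np.random.default_rng(abs(hash(text)) % (2 ** 32)).random(384).tolist()

    class _StubQdrantService:
        def search(self, query_embedding, limit: int, score_threshold: float):
            # Return fake hits in the shape multi_query_retrieval expects
            return [
                {"id": "doc-1", "confidence": 0.82,
                 "metadata": {"text": "The library opens at 8am and closes at 9pm on weekdays."}},
                {"id": "doc-2", "confidence": 0.61,
                 "metadata": {"text": "Membership cards can be renewed online."}},
            ][:limit]

    rag = AdvancedRAG(_StubEmbeddingService(), _StubQdrantService())
    docs, stats = rag.hybrid_rag_pipeline(
        query="What are the library opening hours?",
        top_k=2,
        use_query_expansion=False,  # no LLM client in this sketch
    )
    print(stats)
    print(rag.format_context_for_llm(docs))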