"""
CAG Service (Cache-Augmented Generation)

Semantic caching layer for the RAG system, backed by Qdrant.

This module implements intelligent caching to reduce latency and LLM costs
by serving semantically similar queries from cache.
"""
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance, VectorParams, PointStruct,
    SearchParams, Filter, FieldCondition, MatchValue, Range
)
class CAGService:
    """
    Cache-Augmented Generation Service.

    Semantic caching layer for a RAG pipeline backed by Qdrant: incoming
    queries are embedded and matched against previously answered queries by
    cosine similarity, so near-duplicate questions can be served from cache
    instead of re-running retrieval and generation.

    Features:
    - Semantic similarity-based cache lookup (cosine similarity)
    - TTL (Time-To-Live) for automatic cache expiration
    - Configurable similarity threshold
    """

    def __init__(
        self,
        embedding_service,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        cache_collection: str = "semantic_cache",
        vector_size: int = 1024,
        similarity_threshold: float = 0.9,
        ttl_hours: int = 24
    ):
        """
        Initialize CAG Service.

        Args:
            embedding_service: Embedding service for query encoding; must
                expose ``encode_text(str)`` returning a numpy array.
            qdrant_url: Qdrant Cloud URL (falls back to QDRANT_URL env var).
            qdrant_api_key: Qdrant API key (falls back to QDRANT_API_KEY).
            cache_collection: Collection name for cache.
            vector_size: Embedding dimension.
            similarity_threshold: Min similarity for cache hit (0-1).
            ttl_hours: Cache entry lifetime in hours.

        Raises:
            ValueError: If neither the arguments nor the environment provide
                the Qdrant URL and API key.
        """
        self.embedding_service = embedding_service
        self.cache_collection = cache_collection
        self.similarity_threshold = similarity_threshold
        self.ttl_hours = ttl_hours
        self.vector_size = vector_size

        # Resolve connection settings: explicit args win over environment.
        url = qdrant_url or os.getenv("QDRANT_URL")
        api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
        if not url or not api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY required for CAG")

        self.client = QdrantClient(url=url, api_key=api_key)

        # Ensure cache collection exists before first lookup/insert.
        self._ensure_cache_collection()
        print(f"✓ CAG Service initialized (cache: {cache_collection}, threshold: {similarity_threshold})")

    def _ensure_cache_collection(self):
        """Create the cache collection in Qdrant if it doesn't exist yet."""
        collections = self.client.get_collections().collections
        exists = any(c.name == self.cache_collection for c in collections)
        if not exists:
            print(f"Creating semantic cache collection: {self.cache_collection}")
            self.client.create_collection(
                collection_name=self.cache_collection,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    # Cosine distance matches the similarity-threshold semantics.
                    distance=Distance.COSINE
                )
            )
            print("✓ Semantic cache collection created")

    def _embed_query(self, query: str):
        """Encode *query* and flatten the result to a 1-D vector for Qdrant."""
        embedding = self.embedding_service.encode_text(query)
        if len(embedding.shape) > 1:
            embedding = embedding.flatten()
        return embedding

    @staticmethod
    def _parse_cached_at(raw) -> Optional[datetime]:
        """
        Parse a stored 'cached_at' payload value into an aware UTC datetime.

        Returns None for missing or malformed values so callers can treat
        the entry as invalid instead of crashing. Entries written by older
        versions used naive ``utcnow()``; those are interpreted as UTC.
        """
        if not raw:
            return None
        try:
            ts = datetime.fromisoformat(raw)
        except (TypeError, ValueError):
            return None
        if ts.tzinfo is None:
            # Legacy naive timestamp: assume it was written in UTC.
            ts = ts.replace(tzinfo=timezone.utc)
        return ts

    def check_cache(
        self,
        query: str
    ) -> Optional[Dict[str, Any]]:
        """
        Check if query has a cached response.

        Args:
            query: User query string.

        Returns:
            Cached data if found (with response, context, metadata);
            None on a miss, an expired entry, or an entry whose stored
            timestamp is missing/unreadable (such entries are evicted).
        """
        query_embedding = self._embed_query(query)

        # Nearest cached query above the similarity threshold, if any.
        search_result = self.client.query_points(
            collection_name=self.cache_collection,
            query=query_embedding.tolist(),
            limit=1,
            score_threshold=self.similarity_threshold,
            with_payload=True
        ).points
        if not search_result:
            return None
        hit = search_result[0]

        # TTL check. A missing or corrupt timestamp is treated like an
        # expired entry (evict + miss) rather than raising from fromisoformat.
        cached_at = self._parse_cached_at(hit.payload.get("cached_at"))
        expired = (
            cached_at is None
            or datetime.now(timezone.utc) > cached_at + timedelta(hours=self.ttl_hours)
        )
        if expired:
            self.client.delete(
                collection_name=self.cache_collection,
                points_selector=[hit.id]
            )
            return None

        # Cache hit!
        return {
            "response": hit.payload.get("response"),
            "context_used": hit.payload.get("context_used", []),
            "rag_stats": hit.payload.get("rag_stats"),
            "cached_query": hit.payload.get("original_query"),
            "similarity_score": float(hit.score),
            "cached_at": cached_at.isoformat(),
            "cache_hit": True
        }

    def save_to_cache(
        self,
        query: str,
        response: str,
        context_used: list,
        rag_stats: Optional[Dict] = None
    ) -> str:
        """
        Save query-response pair to cache.

        Args:
            query: Original user query.
            response: Generated response.
            context_used: Retrieved context documents.
            rag_stats: RAG pipeline statistics.

        Returns:
            Cache entry ID (UUID string).
        """
        query_embedding = self._embed_query(query)

        cache_id = str(uuid.uuid4())
        point = PointStruct(
            id=cache_id,
            vector=query_embedding.tolist(),
            payload={
                "original_query": query,
                "response": response,
                "context_used": context_used,
                "rag_stats": rag_stats or {},
                # Timezone-aware UTC timestamp; datetime.utcnow() is naive
                # and deprecated since Python 3.12.
                "cached_at": datetime.now(timezone.utc).isoformat(),
                "cache_type": "semantic"
            }
        )

        self.client.upsert(
            collection_name=self.cache_collection,
            points=[point]
        )
        return cache_id

    def clear_cache(self) -> bool:
        """
        Clear all cache entries.

        Returns:
            True on success, False if the reset failed (best-effort: the
            error is printed, not raised).
        """
        try:
            # Dropping and recreating the collection is the cheapest full wipe.
            self.client.delete_collection(collection_name=self.cache_collection)
            self._ensure_cache_collection()
            print("✓ Semantic cache cleared")
            return True
        except Exception as e:
            # Deliberate best-effort: callers only need a success flag.
            print(f"Error clearing cache: {e}")
            return False

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Cache statistics (entry counts, collection status, configured
            TTL and threshold); empty dict if Qdrant can't be reached.
        """
        try:
            info = self.client.get_collection(collection_name=self.cache_collection)
            return {
                "total_entries": info.points_count,
                "vectors_count": info.vectors_count,
                "status": info.status,
                "ttl_hours": self.ttl_hours,
                "similarity_threshold": self.similarity_threshold
            }
        except Exception as e:
            # Best-effort: stats are informational only.
            print(f"Error getting cache stats: {e}")
            return {}