"""
CAG Service (Cache-Augmented Generation)

Semantic caching layer for a RAG system using Qdrant.

This module implements intelligent caching to reduce latency and LLM costs by
serving semantically similar queries from cache.
"""

import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    PointStruct,
    SearchParams,
    Filter,
    FieldCondition,
    MatchValue,
    Range,
)


class CAGService:
    """
    Cache-Augmented Generation Service

    Features:
    - Semantic similarity-based cache lookup (cosine similarity)
    - TTL (Time-To-Live) for automatic cache expiration
    - Configurable similarity threshold

    Timestamps are written as timezone-aware UTC ISO-8601 strings
    (e.g. ``2024-01-01T12:00:00+00:00``). Naive timestamps, as written by
    earlier versions of this class via ``datetime.utcnow()``, are read as UTC.
    """

    # Number of nearest neighbours inspected per lookup. Looking past the
    # single best match lets a still-valid entry be served when the closest
    # one has expired, instead of reporting a miss.
    _LOOKUP_CANDIDATES = 5

    def __init__(
        self,
        embedding_service,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        cache_collection: str = "semantic_cache",
        vector_size: int = 1024,
        similarity_threshold: float = 0.9,
        ttl_hours: int = 24
    ):
        """
        Initialize CAG Service

        Args:
            embedding_service: Embedding service for query encoding; must
                expose ``encode_text(str)`` returning an array-like vector
            qdrant_url: Qdrant Cloud URL (falls back to env QDRANT_URL)
            qdrant_api_key: Qdrant API key (falls back to env QDRANT_API_KEY)
            cache_collection: Collection name for cache
            vector_size: Embedding dimension (used when creating the collection)
            similarity_threshold: Min similarity score for a cache hit
            ttl_hours: Cache entry lifetime in hours

        Raises:
            ValueError: If no Qdrant URL or API key is available.
        """
        self.embedding_service = embedding_service
        self.cache_collection = cache_collection
        self.similarity_threshold = similarity_threshold
        self.ttl_hours = ttl_hours

        # Explicit arguments win; otherwise fall back to the environment.
        url = qdrant_url or os.getenv("QDRANT_URL")
        api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")

        if not url or not api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY required for CAG")

        self.client = QdrantClient(url=url, api_key=api_key)
        self.vector_size = vector_size

        # Ensure cache collection exists
        self._ensure_cache_collection()

        print(f"✓ CAG Service initialized (cache: {cache_collection}, threshold: {similarity_threshold})")

    def _ensure_cache_collection(self):
        """Create cache collection if it doesn't exist (cosine distance)."""
        collections = self.client.get_collections().collections
        exists = any(c.name == self.cache_collection for c in collections)

        if not exists:
            print(f"Creating semantic cache collection: {self.cache_collection}")
            self.client.create_collection(
                collection_name=self.cache_collection,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE
                )
            )
            print("✓ Semantic cache collection created")

    def _embed_query(self, query: str) -> list:
        """Encode *query* and return it as a flat Python list for Qdrant.

        Accepts anything ``encode_text`` returns that either has ``shape`` /
        ``flatten`` / ``tolist`` (e.g. a NumPy array) or is a plain sequence,
        which is converted with ``np.asarray`` first.
        """
        embedding = self.embedding_service.encode_text(query)
        if not hasattr(embedding, "shape"):
            embedding = np.asarray(embedding)
        # A batch-shaped result such as (1, dim) is flattened to (dim,).
        if len(embedding.shape) > 1:
            embedding = embedding.flatten()
        return embedding.tolist()

    @staticmethod
    def _parse_cached_at(value: Any) -> Optional[datetime]:
        """Parse a stored ``cached_at`` value into an aware UTC datetime.

        Returns None when the value is missing or not a valid ISO-8601
        string. Naive values (legacy ``utcnow()`` output) are taken as UTC.
        """
        if not isinstance(value, str):
            return None
        try:
            parsed = datetime.fromisoformat(value)
        except ValueError:
            return None
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def check_cache(
        self,
        query: str
    ) -> Optional[Dict[str, Any]]:
        """
        Check if query has a cached response

        Up to ``_LOOKUP_CANDIDATES`` entries above the similarity threshold
        are inspected best-first. Entries that are past their TTL, or whose
        timestamp is missing/unparseable, are deleted from the collection;
        the best remaining fresh entry (if any) is returned.

        Args:
            query: User query string

        Returns:
            Cached data if found (with response, context, metadata), None otherwise
        """
        query_vector = self._embed_query(query)

        # Search for similar queries in cache
        candidates = self.client.query_points(
            collection_name=self.cache_collection,
            query=query_vector,
            limit=self._LOOKUP_CANDIDATES,
            score_threshold=self.similarity_threshold,
            with_payload=True
        ).points

        now = datetime.now(timezone.utc)
        ttl = timedelta(hours=self.ttl_hours)

        stale_ids = []
        fresh = None
        for hit in candidates:  # ordered by descending similarity
            payload = hit.payload or {}
            cached_at = self._parse_cached_at(payload.get("cached_at"))
            if cached_at is not None and now <= cached_at + ttl:
                fresh = (hit, payload, cached_at)
                break
            # Expired or unusable entry: evict it so it stops matching.
            stale_ids.append(hit.id)

        if stale_ids:
            self.client.delete(
                collection_name=self.cache_collection,
                points_selector=stale_ids
            )

        if fresh is None:
            return None

        hit, payload, cached_at = fresh

        # Cache hit!
        return {
            "response": payload.get("response"),
            "context_used": payload.get("context_used", []),
            "rag_stats": payload.get("rag_stats"),
            "cached_query": payload.get("original_query"),
            "similarity_score": float(hit.score),
            "cached_at": cached_at.isoformat(),
            "cache_hit": True
        }

    def save_to_cache(
        self,
        query: str,
        response: str,
        context_used: list,
        rag_stats: Optional[Dict] = None
    ) -> str:
        """
        Save query-response pair to cache

        Args:
            query: Original user query
            response: Generated response
            context_used: Retrieved context documents
            rag_stats: RAG pipeline statistics (stored as {} when None)

        Returns:
            Cache entry ID (a random UUID4 string)
        """
        query_vector = self._embed_query(query)

        # Create cache entry
        cache_id = str(uuid.uuid4())
        point = PointStruct(
            id=cache_id,
            vector=query_vector,
            payload={
                "original_query": query,
                "response": response,
                "context_used": context_used,
                "rag_stats": rag_stats or {},
                # Aware UTC timestamp; read back by _parse_cached_at.
                "cached_at": datetime.now(timezone.utc).isoformat(),
                "cache_type": "semantic"
            }
        )

        # Save to Qdrant
        self.client.upsert(
            collection_name=self.cache_collection,
            points=[point]
        )

        return cache_id

    def clear_cache(self) -> bool:
        """
        Clear all cache entries by dropping and recreating the collection.

        Best-effort: any exception is printed and reported as False.

        Returns:
            Success status
        """
        try:
            self.client.delete_collection(collection_name=self.cache_collection)
            self._ensure_cache_collection()
            print("✓ Semantic cache cleared")
            return True
        except Exception as e:
            print(f"Error clearing cache: {e}")
            return False

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        Best-effort: any exception is printed and an empty dict is returned.

        Returns:
            Cache statistics (entry count, status, configuration)
        """
        try:
            info = self.client.get_collection(collection_name=self.cache_collection)
            return {
                "total_entries": info.points_count,
                # Read defensively: depending on the qdrant-client version this
                # attribute may be absent or None, and a bare attribute access
                # would turn the whole stats call into {}.
                "vectors_count": getattr(info, "vectors_count", None),
                "status": info.status,
                "ttl_hours": self.ttl_hours,
                "similarity_threshold": self.similarity_threshold
            }
        except Exception as e:
            print(f"Error getting cache stats: {e}")
            return {}