"""
CAG Service (Cache-Augmented Generation)
Semantic caching layer for RAG system using Qdrant
This module implements intelligent caching to reduce latency and LLM costs
by serving semantically similar queries from cache.
"""
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance, VectorParams, PointStruct,
    SearchParams, Filter, FieldCondition, MatchValue, Range
)
class CAGService:
    """
    Cache-Augmented Generation Service

    Semantic caching layer backed by Qdrant: stores (query embedding,
    response) pairs and serves hits for semantically similar queries to
    cut latency and LLM cost.

    Features:
    - Semantic similarity-based cache lookup (cosine similarity)
    - TTL (Time-To-Live) for automatic cache expiration
    - Configurable similarity threshold
    """

    def __init__(
        self,
        embedding_service,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        cache_collection: str = "semantic_cache",
        vector_size: int = 1024,
        similarity_threshold: float = 0.9,
        ttl_hours: int = 24
    ):
        """
        Initialize CAG Service

        Args:
            embedding_service: Embedding service for query encoding
                (must expose ``encode_text(str)`` returning a numpy array)
            qdrant_url: Qdrant Cloud URL (falls back to $QDRANT_URL)
            qdrant_api_key: Qdrant API key (falls back to $QDRANT_API_KEY)
            cache_collection: Collection name for cache
            vector_size: Embedding dimension
            similarity_threshold: Min similarity for cache hit (0-1)
            ttl_hours: Cache entry lifetime in hours

        Raises:
            ValueError: If neither the arguments nor the environment
                provide a Qdrant URL and API key.
        """
        self.embedding_service = embedding_service
        self.cache_collection = cache_collection
        self.similarity_threshold = similarity_threshold
        self.ttl_hours = ttl_hours

        # Explicit arguments take precedence over environment variables.
        url = qdrant_url or os.getenv("QDRANT_URL")
        api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
        if not url or not api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY required for CAG")

        self.client = QdrantClient(url=url, api_key=api_key)
        self.vector_size = vector_size

        # Ensure cache collection exists
        self._ensure_cache_collection()
        print(f"✓ CAG Service initialized (cache: {cache_collection}, threshold: {similarity_threshold})")

    def _ensure_cache_collection(self):
        """Create cache collection if it doesn't exist"""
        collections = self.client.get_collections().collections
        exists = any(c.name == self.cache_collection for c in collections)
        if not exists:
            print(f"Creating semantic cache collection: {self.cache_collection}")
            self.client.create_collection(
                collection_name=self.cache_collection,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE
                )
            )
            print("✓ Semantic cache collection created")

    def _embed_query(self, query: str):
        """Encode *query* and flatten to a 1-D vector (shared by lookup/save)."""
        embedding = self.embedding_service.encode_text(query)
        # Some encoders return a (1, dim) batch; Qdrant needs a flat vector.
        if len(embedding.shape) > 1:
            embedding = embedding.flatten()
        return embedding

    @staticmethod
    def _parse_cached_at(raw) -> Optional[datetime]:
        """
        Parse a stored ISO-8601 timestamp into an aware UTC datetime.

        Returns None when the value is missing or malformed so callers
        can treat the entry as expired instead of raising (previously a
        missing "cached_at" crashed check_cache with a TypeError).
        Naive timestamps — written by older code that used
        datetime.utcnow() — are assumed to be UTC.
        """
        if not raw:
            return None
        try:
            ts = datetime.fromisoformat(raw)
        except (TypeError, ValueError):
            return None
        if ts.tzinfo is None:
            # Legacy naive entry: interpret as UTC for TTL comparison.
            ts = ts.replace(tzinfo=timezone.utc)
        return ts

    def check_cache(
        self,
        query: str
    ) -> Optional[Dict[str, Any]]:
        """
        Check if query has a cached response

        Args:
            query: User query string

        Returns:
            Cached data if found (with response, context, metadata),
            None otherwise (miss, or a hit past its TTL — which is
            deleted as a side effect)
        """
        query_embedding = self._embed_query(query)

        # Nearest cached query at or above the similarity threshold (top-1).
        search_result = self.client.query_points(
            collection_name=self.cache_collection,
            query=query_embedding.tolist(),
            limit=1,
            score_threshold=self.similarity_threshold,
            with_payload=True
        ).points
        if not search_result:
            return None

        hit = search_result[0]

        # Check TTL; a missing/corrupt timestamp counts as expired.
        cached_at = self._parse_cached_at(hit.payload.get("cached_at"))
        expired = (
            cached_at is None
            or datetime.now(timezone.utc) > cached_at + timedelta(hours=self.ttl_hours)
        )
        if expired:
            # Cache expired, delete it
            self.client.delete(
                collection_name=self.cache_collection,
                points_selector=[hit.id]
            )
            return None

        # Cache hit!
        return {
            "response": hit.payload.get("response"),
            "context_used": hit.payload.get("context_used", []),
            "rag_stats": hit.payload.get("rag_stats"),
            "cached_query": hit.payload.get("original_query"),
            "similarity_score": float(hit.score),
            "cached_at": cached_at.isoformat(),
            "cache_hit": True
        }

    def save_to_cache(
        self,
        query: str,
        response: str,
        context_used: list,
        rag_stats: Optional[Dict] = None
    ) -> str:
        """
        Save query-response pair to cache

        Args:
            query: Original user query
            response: Generated response
            context_used: Retrieved context documents
            rag_stats: RAG pipeline statistics

        Returns:
            Cache entry ID (UUID string)
        """
        query_embedding = self._embed_query(query)

        cache_id = str(uuid.uuid4())
        point = PointStruct(
            id=cache_id,
            vector=query_embedding.tolist(),
            payload={
                "original_query": query,
                "response": response,
                "context_used": context_used,
                "rag_stats": rag_stats or {},
                # Aware UTC timestamp (datetime.utcnow() is deprecated).
                "cached_at": datetime.now(timezone.utc).isoformat(),
                "cache_type": "semantic"
            }
        )

        # Save to Qdrant
        self.client.upsert(
            collection_name=self.cache_collection,
            points=[point]
        )
        return cache_id

    def clear_cache(self) -> bool:
        """
        Clear all cache entries

        Drops and recreates the collection (cheaper than deleting
        points one by one).

        Returns:
            True on success, False if Qdrant raised an error
        """
        try:
            # Delete and recreate collection
            self.client.delete_collection(collection_name=self.cache_collection)
            self._ensure_cache_collection()
            print("✓ Semantic cache cleared")
            return True
        except Exception as e:
            # Best-effort: report and signal failure rather than crash.
            print(f"Error clearing cache: {e}")
            return False

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        Returns:
            Cache statistics (size, status, TTL, threshold); empty dict
            on error
        """
        try:
            info = self.client.get_collection(collection_name=self.cache_collection)
            return {
                "total_entries": info.points_count,
                "vectors_count": info.vectors_count,
                "status": info.status,
                "ttl_hours": self.ttl_hours,
                "similarity_threshold": self.similarity_threshold
            }
        except Exception as e:
            print(f"Error getting cache stats: {e}")
            return {}