# ChatbotRAG / cag_service.py
# Uploaded by minhvtt ("Upload 12 files", commit 0f17bfd, verified)
"""
CAG Service (Cache-Augmented Generation)
Semantic caching layer for RAG system using Qdrant
This module implements intelligent caching to reduce latency and LLM costs
by serving semantically similar queries from cache.
"""
from typing import Optional, Dict, Any, Tuple
from datetime import datetime, timedelta, timezone
import os
import uuid

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance, VectorParams, PointStruct,
    SearchParams, Filter, FieldCondition, MatchValue, Range
)
class CAGService:
    """
    Cache-Augmented Generation Service

    Semantic cache in front of a RAG pipeline. Each answered query is stored
    in a Qdrant collection as a point whose vector is the query embedding.
    A later query is served the stored response when its embedding is close
    enough to a cached one (cosine similarity >= threshold) and the entry is
    younger than the TTL.

    Features:
    - Semantic similarity-based cache lookup (cosine similarity)
    - TTL (Time-To-Live) for automatic cache expiration
    - Configurable similarity threshold
    """

    # Number of nearest cached queries inspected per lookup. Looking past the
    # single best match lets an expired or corrupt top entry be skipped (and
    # purged) without hiding a still-valid runner-up above the threshold.
    LOOKUP_CANDIDATES = 5

    def __init__(
        self,
        embedding_service,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        cache_collection: str = "semantic_cache",
        vector_size: int = 1024,
        similarity_threshold: float = 0.9,
        ttl_hours: int = 24
    ):
        """
        Initialize CAG Service and make sure the cache collection exists.

        Args:
            embedding_service: Object exposing ``encode_text(str)`` that returns
                an array-like vector; multi-dimensional results are flattened.
            qdrant_url: Qdrant URL (falls back to the QDRANT_URL env var)
            qdrant_api_key: Qdrant API key (falls back to the QDRANT_API_KEY env var)
            cache_collection: Collection name for cache
            vector_size: Embedding dimension used when creating the collection
            similarity_threshold: Min cosine similarity for a cache hit (0-1)
            ttl_hours: Cache entry lifetime in hours

        Raises:
            ValueError: If no Qdrant URL or API key is available.
        """
        self.embedding_service = embedding_service
        self.cache_collection = cache_collection
        self.similarity_threshold = similarity_threshold
        self.ttl_hours = ttl_hours

        # Initialize Qdrant client (explicit arguments win over environment)
        url = qdrant_url or os.getenv("QDRANT_URL")
        api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
        if not url or not api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY required for CAG")
        self.client = QdrantClient(url=url, api_key=api_key)
        self.vector_size = vector_size

        # Ensure cache collection exists (needs vector_size set above)
        self._ensure_cache_collection()
        print(f"✓ CAG Service initialized (cache: {cache_collection}, threshold: {similarity_threshold})")

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Naive UTC is the format already stored in cache payloads. This is the
        same value ``datetime.utcnow()`` produced, without its Python 3.12+
        deprecation warning.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _parse_cached_at(value: Any) -> Optional[datetime]:
        """Parse a stored ``cached_at`` payload value into a naive UTC datetime.

        Returns None when the value is missing, is not a string, or is not
        valid ISO-8601. Callers treat such entries as expired instead of
        crashing on them.
        """
        if not isinstance(value, str):
            return None
        try:
            parsed = datetime.fromisoformat(value)
        except ValueError:
            return None
        if parsed.tzinfo is not None:
            # Normalise timezone-aware timestamps to the naive-UTC convention
            # so they can be compared with _utcnow().
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    def _encode_query(self, query: str) -> list:
        """Embed ``query`` and return it as a flat Python list for Qdrant."""
        embedding = self.embedding_service.encode_text(query)
        if not hasattr(embedding, "shape"):
            # Plain sequences (e.g. a list of floats) become a numpy array.
            embedding = np.asarray(embedding)
        if len(embedding.shape) > 1:
            embedding = embedding.flatten()
        return embedding.tolist()

    def _ensure_cache_collection(self):
        """Create the cache collection (cosine distance) if it doesn't exist."""
        collections = self.client.get_collections().collections
        exists = any(c.name == self.cache_collection for c in collections)
        if not exists:
            print(f"Creating semantic cache collection: {self.cache_collection}")
            self.client.create_collection(
                collection_name=self.cache_collection,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE
                )
            )
            print("✓ Semantic cache collection created")

    def check_cache(
        self,
        query: str
    ) -> Optional[Dict[str, Any]]:
        """
        Check if query has a cached response.

        Looks at up to ``LOOKUP_CANDIDATES`` cached queries whose similarity
        is at least ``similarity_threshold``, in the order Qdrant returns
        them. The first entry that is still within the TTL is returned.
        Entries that are expired, or whose ``cached_at`` cannot be parsed,
        are skipped along the way and deleted (best effort).

        Args:
            query: User query string

        Returns:
            Cached data if found (with response, context, metadata), None otherwise
        """
        search_result = self.client.query_points(
            collection_name=self.cache_collection,
            query=self._encode_query(query),
            limit=self.LOOKUP_CANDIDATES,
            score_threshold=self.similarity_threshold,
            with_payload=True
        ).points

        now = self._utcnow()
        ttl = timedelta(hours=self.ttl_hours)
        stale_ids = []
        cached = None

        for hit in search_result:
            payload = hit.payload or {}
            cached_at = self._parse_cached_at(payload.get("cached_at"))
            if cached_at is None or now > cached_at + ttl:
                # Expired (or unusable) entry: remember it for cleanup and
                # keep looking for a fresh, slightly less similar match.
                stale_ids.append(hit.id)
                continue

            # Cache hit!
            cached = {
                "response": payload.get("response"),
                "context_used": payload.get("context_used", []),
                "rag_stats": payload.get("rag_stats"),
                "cached_query": payload.get("original_query"),
                "similarity_score": float(hit.score),
                "cached_at": cached_at.isoformat(),
                "cache_hit": True
            }
            break

        if stale_ids:
            # Cleanup is best-effort: a failed delete must not turn a valid
            # hit (or a plain miss) into an error for the caller.
            try:
                self.client.delete(
                    collection_name=self.cache_collection,
                    points_selector=stale_ids
                )
            except Exception as e:
                print(f"Error deleting expired cache entries: {e}")

        return cached

    def save_to_cache(
        self,
        query: str,
        response: str,
        context_used: list,
        rag_stats: Optional[Dict] = None
    ) -> str:
        """
        Save query-response pair to cache.

        Args:
            query: Original user query
            response: Generated response
            context_used: Retrieved context documents (stored in the payload as-is)
            rag_stats: RAG pipeline statistics (stored as {} when None)

        Returns:
            Cache entry ID (a random UUID4 string)
        """
        cache_id = str(uuid.uuid4())
        point = PointStruct(
            id=cache_id,
            vector=self._encode_query(query),
            payload={
                "original_query": query,
                "response": response,
                "context_used": context_used,
                "rag_stats": rag_stats or {},
                # Naive-UTC ISO string, the same format check_cache parses.
                "cached_at": self._utcnow().isoformat(),
                "cache_type": "semantic"
            }
        )

        # Save to Qdrant
        self.client.upsert(
            collection_name=self.cache_collection,
            points=[point]
        )
        return cache_id

    def clear_cache(self) -> bool:
        """
        Clear all cache entries by dropping and recreating the collection.

        Returns:
            True on success, False if any Qdrant call raised (error is printed)
        """
        try:
            self.client.delete_collection(collection_name=self.cache_collection)
            self._ensure_cache_collection()
            print("✓ Semantic cache cleared")
            return True
        except Exception as e:
            print(f"Error clearing cache: {e}")
            return False

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Entry count, collection status and the configured TTL/threshold,
            or {} if the collection info could not be fetched (error is printed)
        """
        try:
            info = self.client.get_collection(collection_name=self.cache_collection)
            return {
                "total_entries": info.points_count,
                # ``vectors_count`` is deprecated in newer qdrant-client
                # versions and may be absent. Report None instead of letting
                # an AttributeError wipe out the whole stats result.
                "vectors_count": getattr(info, "vectors_count", None),
                "status": info.status,
                "ttl_hours": self.ttl_hours,
                "similarity_threshold": self.similarity_threshold
            }
        except Exception as e:
            print(f"Error getting cache stats: {e}")
            return {}