minhvtt committed on
Commit 0f17bfd · verified · 1 Parent(s): 5838fbc

Upload 12 files

Files changed (1):
  cag_service.py  +229 -229
cag_service.py
"""
CAG Service (Cache-Augmented Generation)
Semantic caching layer for RAG system using Qdrant

This module implements intelligent caching to reduce latency and LLM costs
by serving semantically similar queries from cache.
"""

from typing import Optional, Dict, Any, Tuple
from datetime import datetime, timedelta
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance, VectorParams, PointStruct,
    SearchParams, Filter, FieldCondition, MatchValue, Range
)
import uuid
import os


class CAGService:
    """
    Cache-Augmented Generation Service

    Features:
    - Semantic similarity-based cache lookup (cosine similarity)
    - TTL (Time-To-Live) for automatic cache expiration
    - Configurable similarity threshold
    """

    def __init__(
        self,
        embedding_service,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        cache_collection: str = "semantic_cache",
        vector_size: int = 1024,
        similarity_threshold: float = 0.9,
        ttl_hours: int = 24
    ):
        """
        Initialize CAG Service

        Args:
            embedding_service: Embedding service for query encoding
            qdrant_url: Qdrant Cloud URL
            qdrant_api_key: Qdrant API key
            cache_collection: Collection name for cache
            vector_size: Embedding dimension
            similarity_threshold: Min similarity for cache hit (0-1)
            ttl_hours: Cache entry lifetime in hours
        """
        self.embedding_service = embedding_service
        self.cache_collection = cache_collection
        self.similarity_threshold = similarity_threshold
        self.ttl_hours = ttl_hours

        # Initialize Qdrant client
        url = qdrant_url or os.getenv("QDRANT_URL")
        api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")

        if not url or not api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY required for CAG")

        self.client = QdrantClient(url=url, api_key=api_key)
        self.vector_size = vector_size

        # Ensure cache collection exists
        self._ensure_cache_collection()

        print(f"✓ CAG Service initialized (cache: {cache_collection}, threshold: {similarity_threshold})")

    def _ensure_cache_collection(self):
        """Create cache collection if it doesn't exist"""
        collections = self.client.get_collections().collections
        exists = any(c.name == self.cache_collection for c in collections)

        if not exists:
            print(f"Creating semantic cache collection: {self.cache_collection}")
            self.client.create_collection(
                collection_name=self.cache_collection,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE
                )
            )
            print("✓ Semantic cache collection created")

    def check_cache(
        self,
        query: str
    ) -> Optional[Dict[str, Any]]:
        """
        Check if query has a cached response

        Args:
            query: User query string

        Returns:
            Cached data if found (with response, context, metadata), None otherwise
        """
        # Generate query embedding
        query_embedding = self.embedding_service.encode_text(query)

        if len(query_embedding.shape) > 1:
            query_embedding = query_embedding.flatten()

        # Search for similar queries in cache
        search_result = self.client.query_points(
            collection_name=self.cache_collection,
            query=query_embedding.tolist(),
            limit=1,
            score_threshold=self.similarity_threshold,
            with_payload=True
        ).points

        if not search_result:
            return None

        hit = search_result[0]

        # Check TTL
        cached_at = datetime.fromisoformat(hit.payload.get("cached_at"))
        expires_at = cached_at + timedelta(hours=self.ttl_hours)

        if datetime.utcnow() > expires_at:
            # Cache expired, delete it
            self.client.delete(
                collection_name=self.cache_collection,
                points_selector=[hit.id]
            )
            return None

        # Cache hit!
        return {
            "response": hit.payload.get("response"),
            "context_used": hit.payload.get("context_used", []),
            "rag_stats": hit.payload.get("rag_stats"),
            "cached_query": hit.payload.get("original_query"),
            "similarity_score": float(hit.score),
            "cached_at": cached_at.isoformat(),
            "cache_hit": True
        }

    def save_to_cache(
        self,
        query: str,
        response: str,
        context_used: list,
        rag_stats: Optional[Dict] = None
    ) -> str:
        """
        Save query-response pair to cache

        Args:
            query: Original user query
            response: Generated response
            context_used: Retrieved context documents
            rag_stats: RAG pipeline statistics

        Returns:
            Cache entry ID
        """
        # Generate query embedding
        query_embedding = self.embedding_service.encode_text(query)

        if len(query_embedding.shape) > 1:
            query_embedding = query_embedding.flatten()

        # Create cache entry
        cache_id = str(uuid.uuid4())

        point = PointStruct(
            id=cache_id,
            vector=query_embedding.tolist(),
            payload={
                "original_query": query,
                "response": response,
                "context_used": context_used,
                "rag_stats": rag_stats or {},
                "cached_at": datetime.utcnow().isoformat(),
                "cache_type": "semantic"
            }
        )

        # Save to Qdrant
        self.client.upsert(
            collection_name=self.cache_collection,
            points=[point]
        )

        return cache_id

    def clear_cache(self) -> bool:
        """
        Clear all cache entries

        Returns:
            Success status
        """
        try:
            # Delete and recreate collection
            self.client.delete_collection(collection_name=self.cache_collection)
            self._ensure_cache_collection()
            print("✓ Semantic cache cleared")
            return True
        except Exception as e:
            print(f"Error clearing cache: {e}")
            return False

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        Returns:
            Cache statistics (size, hit rate, etc.)
        """
        try:
            info = self.client.get_collection(collection_name=self.cache_collection)
            return {
                "total_entries": info.points_count,
                "vectors_count": info.vectors_count,
                "status": info.status,
                "ttl_hours": self.ttl_hours,
                "similarity_threshold": self.similarity_threshold
            }
        except Exception as e:
            print(f"Error getting cache stats: {e}")
            return {}
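
Below is a minimal usage sketch showing how CAGService can sit in front of a RAG pipeline: check the cache first, fall back to generation on a miss, then store the fresh result. The StubEmbeddingService and run_rag_pipeline below are placeholder assumptions for illustration only; they are not part of this commit, and QDRANT_URL / QDRANT_API_KEY must be set in the environment.

# Usage sketch (illustrative assumptions, not part of this commit).
# StubEmbeddingService and run_rag_pipeline stand in for the real services.
import zlib
import numpy as np

from cag_service import CAGService


class StubEmbeddingService:
    """Stand-in embedder: deterministic unit-norm vectors of size 1024."""
    def encode_text(self, text: str) -> np.ndarray:
        rng = np.random.default_rng(zlib.crc32(text.encode("utf-8")))
        vec = rng.normal(size=1024)
        return vec / np.linalg.norm(vec)


def run_rag_pipeline(query: str):
    """Placeholder for the real retrieval + LLM generation step."""
    return f"Answer to: {query}", [], {"retrieved_docs": 0}


# Reads QDRANT_URL / QDRANT_API_KEY from the environment if not passed in.
cag = CAGService(embedding_service=StubEmbeddingService())


def answer(query: str) -> dict:
    # 1) Semantic cache lookup (cosine similarity >= 0.9 by default)
    cached = cag.check_cache(query)
    if cached is not None:
        return cached  # already contains "cache_hit": True

    # 2) Cache miss: run the full RAG pipeline
    response, context_docs, stats = run_rag_pipeline(query)

    # 3) Store the result so semantically similar queries hit the cache later
    cag.save_to_cache(query, response, context_used=context_docs, rag_stats=stats)
    return {
        "response": response,
        "context_used": context_docs,
        "rag_stats": stats,
        "cache_hit": False,
    }

Checking the cache before retrieval and writing back after generation keeps the cache warm without extra round trips; expired entries are pruned lazily inside check_cache once their TTL has passed.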