Spaces:
Build error
LLM Google T5 integration #3
opened by AryanJh

app.py CHANGED

@@ -16,17 +16,26 @@ class BrockEventsRAG:
    def __init__(self):
        """Initialize the RAG system with improved caching"""
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
-       self.
+       self.embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
+
+       # ChromaDB client setup
+       self.chroma_client = chromadb.Client(Settings(persist_directory="chroma_db", chroma_db_impl="duckdb+parquet"))
+
+       # LLM model setup
+       self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+       self.llm = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+

        # Get current date range
        self.eastern = pytz.timezone('America/New_York')
        self.today = datetime.now(self.eastern).replace(hour=0, minute=0, second=0, microsecond=0)
        self.date_range_end = self.today + timedelta(days=14)
-
+
        # Cache directory setup
        os.makedirs("cache", exist_ok=True)
        self.cache_file = "cache/events_cache.json"
-
+
+
        # Initialize or reset collection
        try:
            self.collection = self.chroma_client.create_collection(
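
Note on this hunk: `Settings(chroma_db_impl="duckdb+parquet")` is the legacy chromadb configuration; chromadb 0.4 and later rejects it at client construction, which is a plausible source of the Space's build error. The `HuggingFaceEmbeddings` instance also wraps the same `all-MiniLM-L6-v2` model as the `SentenceTransformer` created just above, so one of the two may be redundant unless it is used elsewhere. Below is a minimal sketch of the same setup against the current chromadb client API, assuming chromadb >= 0.4; the collection name and the `transformers` import line are not shown in the diff, so they are placeholders here.

import chromadb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Persistent on-disk client; replaces Settings(persist_directory=..., chroma_db_impl="duckdb+parquet")
chroma_client = chromadb.PersistentClient(path="chroma_db")
collection = chroma_client.get_or_create_collection(name="brock_events")  # placeholder collection name

# google/flan-t5-small tokenizer and model, as in the diff
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
llm = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")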

@@ -42,69 +51,18 @@ class BrockEventsRAG:

        # Load initial events
        self.update_database()
-
-    def
-        """
-        try:
-            # Convert datetime objects to strings for JSON serialization
-            serializable_data = {
-                'last_update': data['last_update'],
-                'events': []
-            }
-
-            for event in data['events']:
-                event_copy = event.copy()
-                # Convert datetime objects to strings
-                if event_copy.get('start_time'):
-                    event_copy['start_time'] = event_copy['start_time'].isoformat()
-                if event_copy.get('end_time'):
-                    event_copy['end_time'] = event_copy['end_time'].isoformat()
-                serializable_data['events'].append(event_copy)
-
-            with open(self.cache_file, 'w', encoding='utf-8') as f:
-                json.dump(serializable_data, f, ensure_ascii=False, indent=2)
-            print(f"Cache saved successfully to {self.cache_file}")
-
-        except Exception as e:
-            print(f"Error saving cache: {e}")
-
-    def load_cache(self) -> dict:
-        """Load and parse cached events data"""
+
+    def fetch_rss_feed(self, url: str) -> List[Dict]:
+        """Fetch and parse RSS feed from the given URL"""
        try:
-
-
-
-
-            # Convert string timestamps back to datetime objects
-            for event in data['events']:
-                if event.get('start_time'):
-                    event['start_time'] = datetime.fromisoformat(event['start_time'])
-                if event.get('end_time'):
-                    event['end_time'] = datetime.fromisoformat(event['end_time'])
-
-            return data
-            return {'last_update': None, 'events': []}
-
+            feed = feedparser.parse(url)
+            entries = feed.entries
+            print(f"Fetched {len(entries)} entries from the feed.")
+            return entries
        except Exception as e:
-            print(f"Error
-            return
-
-    def should_update_cache(self) -> bool:
-        """Check if cache needs updating (older than 24 hours)"""
-        try:
-            cached_data = self.load_cache()
-            if not cached_data['last_update']:
-                return True
-
-            last_update = datetime.fromisoformat(cached_data['last_update'])
-            time_since_update = datetime.now() - last_update
-
-            return time_since_update.total_seconds() > 86400  # 24 hours
+            print(f"Error fetching RSS feed: {e}")
+            return []

-        except Exception as e:
-            print(f"Error checking cache: {e}")
-            return True
-
    def parse_event_datetime(self, entry) -> tuple:
        """Parse start and end times from both RSS and HTML"""
        try:
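
A caveat on the new `fetch_rss_feed` helper: `feedparser.parse` generally does not raise on network or parse failures, so the `except` branch will rarely fire; feedparser reports problems through its `bozo` flag instead. A small sketch of that check, written as a standalone function with otherwise the same behaviour:

import feedparser
from typing import Dict, List

def fetch_rss_feed(url: str) -> List[Dict]:
    """Fetch and parse an RSS feed, reporting failures via feedparser's bozo flag."""
    feed = feedparser.parse(url)
    if feed.bozo:  # feedparser sets this instead of raising on fetch/parse errors
        print(f"Error fetching RSS feed: {feed.bozo_exception}")
        return []
    print(f"Fetched {len(feed.entries)} entries from the feed.")
    return feed.entries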

@@ -294,6 +252,28 @@ class BrockEventsRAG:
        except Exception as e:
            print(f"Error during query: {e}")
            return None
+
+    def generate_response_with_llm(events: List[Dict]) -> str:
+        """Use the LLM to generate a natural language response for the given events."""
+        try:
+            if not events:
+                input_text = "There are no events matching the query. How should I respond?"
+            else:
+                event_summaries = "\n".join([
+                    f"Event: {event['title']}. Start: {event['start_time']}, Location: {event['location']}."
+                    for event in events
+                ])
+                input_text = f"Format this information into a friendly response: {event_summaries}"
+
+            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
+            outputs = self.llm.generate(**inputs)
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return response
+        except Exception as e:
+            print(f"Error generating response: {e}")
+            return "Sorry, I couldn't generate a response."
+
+
    def generate_response(self, question: str, history: list) -> str:
        """Generate a response based on the query and chat history"""
        try:
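
As written, `generate_response_with_llm` has no `self` parameter but reads `self.tokenizer` and `self.llm`, so calling it as a method will fail with a TypeError; with no length argument, `generate` also stops after roughly 20 tokens by default. A minimal corrected sketch, assuming the attributes created in `__init__` above; the `max_new_tokens=128` cap is an arbitrary choice, not taken from the diff.

from typing import Dict, List

# Drop-in method sketch; assumes self.tokenizer / self.llm from __init__ above.
def generate_response_with_llm(self, events: List[Dict]) -> str:
    """Use the LLM to generate a natural language response for the given events."""
    if not events:
        input_text = "There are no events matching the query. How should I respond?"
    else:
        event_summaries = "\n".join(
            f"Event: {event['title']}. Start: {event['start_time']}, Location: {event['location']}."
            for event in events
        )
        input_text = f"Format this information into a friendly response: {event_summaries}"

    inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    outputs = self.llm.generate(**inputs, max_new_tokens=128)  # the default length limit truncates replies
    return self.tokenizer.decode(outputs[0], skip_special_tokens=True)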

@@ -308,7 +288,7 @@ class BrockEventsRAG:
            is_location_query = any(word in question_lower for word in ['where', 'location', 'place', 'building', 'room'])

            # Format the response
-            response =
+            response = generate_response_with_llm(matched_events)

            # Add top 3 matching events
            for i, (doc, metadata) in enumerate(zip(results['documents'][0][:3], results['metadatas'][0][:3]), 1):
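
Inside `generate_response` the call also needs the `self.` prefix, and `matched_events` is not defined in any of the hunks shown. One way to build it from the query results used just below, assuming the stored metadata carries the `title`, `start_time`, and `location` keys that `generate_response_with_llm` formats:

# Hypothetical glue code inside generate_response: derive the events list
# from the top ChromaDB results before asking the LLM to phrase the answer.
matched_events = [
    {
        "title": metadata.get("title", ""),
        "start_time": metadata.get("start_time", ""),
        "location": metadata.get("location", ""),
    }
    for metadata in results["metadatas"][0][:3]
]
response = self.generate_response_with_llm(matched_events)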

@@ -326,7 +306,69 @@ class BrockEventsRAG:
        except Exception as e:
            print(f"Error generating response: {e}")
            return "I encountered an error while searching for events. Please try asking in a different way."
-
+    def save_cache(self, data: dict):
+        """Save events data to cache file"""
+        try:
+            # Convert datetime objects to strings for JSON serialization
+            serializable_data = {
+                'last_update': data['last_update'],
+                'events': []
+            }
+
+            for event in data['events']:
+                event_copy = event.copy()
+                # Convert datetime objects to strings
+                if event_copy.get('start_time'):
+                    event_copy['start_time'] = event_copy['start_time'].isoformat()
+                if event_copy.get('end_time'):
+                    event_copy['end_time'] = event_copy['end_time'].isoformat()
+                serializable_data['events'].append(event_copy)
+
+            with open(self.cache_file, 'w', encoding='utf-8') as f:
+                json.dump(serializable_data, f, ensure_ascii=False, indent=2)
+            print(f"Cache saved successfully to {self.cache_file}")
+
+        except Exception as e:
+            print(f"Error saving cache: {e}")
+    """
+    def load_cache(self) -> dict:
+        #Load and parse cached events data
+        try:
+            if os.path.exists(self.cache_file):
+                with open(self.cache_file, 'r', encoding='utf-8') as f:
+                    data = json.load(f)
+
+                # Convert string timestamps back to datetime objects
+                for event in data['events']:
+                    if event.get('start_time'):
+                        event['start_time'] = datetime.fromisoformat(event['start_time'])
+                    if event.get('end_time'):
+                        event['end_time'] = datetime.fromisoformat(event['end_time'])
+
+                return data
+            return {'last_update': None, 'events': []}
+
+        except Exception as e:
+            print(f"Error loading cache: {e}")
+            return {'last_update': None, 'events': []}
+
+    def should_update_cache(self) -> bool:
+        #Check if cache needs updating (older than 24 hours)
+        try:
+            cached_data = self.load_cache()
+            if not cached_data['last_update']:
+                return True
+
+            last_update = datetime.fromisoformat(cached_data['last_update'])
+            time_since_update = datetime.now() - last_update
+
+            return time_since_update.total_seconds() > 86400  # 24 hours
+
+        except Exception as e:
+            print(f"Error checking cache: {e}")
+            return True
+    """
+
def create_demo():
    # Initialize the RAG system
    rag_system = BrockEventsRAG()
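
The `"""` pair added in this hunk turns `load_cache` and `should_update_cache` into an unused string literal, so those names are no longer defined on the class; any caller that still expects them (for example `update_database`, which this diff does not show) would raise an AttributeError. A defensive sketch of such a call site, under that assumption; `FEED_URL` is a placeholder, not taken from the diff:

# Hypothetical call site (e.g. inside update_database): guard against the cache
# helpers being absent now that they sit inside a string literal.
if hasattr(self, "should_update_cache") and not self.should_update_cache():
    events = self.load_cache()["events"]
else:
    events = self.fetch_rss_feed(FEED_URL)  # FEED_URL is a placeholder constant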