# ChatbotRAG / hybrid_chat_stream.py
"""
Hybrid Chat Streaming Endpoint
Real-time SSE streaming for scenarios + RAG
"""
from typing import AsyncGenerator
import asyncio
from datetime import datetime, timezone
from stream_utils import (
format_sse, stream_text_slowly,
EVENT_STATUS, EVENT_TOKEN, EVENT_DONE, EVENT_ERROR, EVENT_METADATA
)
# Import scenario handlers
from scenario_handlers.price_inquiry import PriceInquiryHandler
from scenario_handlers.event_recommendation import EventRecommendationHandler
from scenario_handlers.post_event_feedback import PostEventFeedbackHandler
from scenario_handlers.exit_intent_rescue import ExitIntentRescueHandler
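# Routing overview (see hybrid_chat_stream below):
#   scenario:*      -> handle_scenario_stream (multi-step guided flows)
#   rag:with_resume -> handle_rag_stream, then a short hint resuming the scenario
#   anything else   -> handle_rag_stream (pure document Q&A)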
async def hybrid_chat_stream(
request,
conversation_service,
intent_classifier,
embedding_service, # For handlers
qdrant_service, # For handlers
advanced_rag,
hf_token,
lead_storage
) -> AsyncGenerator[str, None]:
"""
Stream chat responses in real-time (SSE format)
Yields SSE events:
- status: "Đang suy nghĩ...", "Đang tìm kiếm..."
- token: Individual text chunks
- metadata: Context, session info
- done: Completion signal
- error: Error messages
"""
try:
# === SESSION MANAGEMENT ===
session_id = request.session_id
if not session_id:
session_id = conversation_service.create_session(
metadata={"user_agent": "api", "created_via": "stream"},
user_id=request.user_id
)
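        # Emit the session id first so the client can reuse it on follow-up turns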
yield format_sse(EVENT_METADATA, {"session_id": session_id})
# === INTENT CLASSIFICATION ===
        yield format_sse(EVENT_STATUS, "Đang phân tích câu hỏi...")  # "Analyzing your question..."
scenario_state = conversation_service.get_scenario_state(session_id) or {}
intent = intent_classifier.classify(request.message, scenario_state)
# === ROUTING ===
if intent.startswith("scenario:"):
# Scenario flow with simulated streaming using handlers
async for sse_event in handle_scenario_stream(
intent, request.message, session_id,
scenario_state, embedding_service, qdrant_service,
conversation_service, lead_storage
):
yield sse_event
elif intent == "rag:with_resume":
# Quick RAG answer + resume scenario
            yield format_sse(EVENT_STATUS, "Đang tra cứu...")  # "Looking it up..."
async for sse_event in handle_rag_stream(
request, advanced_rag, embedding_service, qdrant_service
):
yield sse_event
            # Resume hint: "All set! Let's get back to the earlier question ^^"
            async for chunk in stream_text_slowly(
                "\n\n---\nVậy nha! Quay lại câu hỏi trước nhé ^^",
                chars_per_chunk=5,
                delay_ms=15
            ):
                yield chunk
else: # Pure RAG
            yield format_sse(EVENT_STATUS, "Đang tìm kiếm trong tài liệu...")  # "Searching the documents..."
async for sse_event in handle_rag_stream(
request, advanced_rag, embedding_service, qdrant_service
):
yield sse_event
        # === SAVE HISTORY ===
        # TODO: persist the full response once streaming completes; this
        # requires buffering the streamed chunks server-side.
        # === DONE ===
        yield format_sse(EVENT_DONE, {
            "session_id": session_id,
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
except Exception as e:
yield format_sse(EVENT_ERROR, str(e))
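# Usage sketch (illustrative, not part of this module): the generator above is
# meant to be mounted on an SSE endpoint. Assuming the project defines a FastAPI
# app, a ChatRequest model, and the service singletons elsewhere, wiring could
# look like:
#
#   from fastapi.responses import StreamingResponse
#
#   @app.post("/api/chat/stream")
#   async def chat_stream(request: ChatRequest):
#       return StreamingResponse(
#           hybrid_chat_stream(
#               request, conversation_service, intent_classifier,
#               embedding_service, qdrant_service, advanced_rag,
#               hf_token, lead_storage,
#           ),
#           media_type="text/event-stream",
#       )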
async def handle_scenario_stream(
intent, user_message, session_id,
scenario_state, embedding_service, qdrant_service,
conversation_service, lead_storage
) -> AsyncGenerator[str, None]:
"""
Handle scenario with simulated typing effect using dedicated handlers
"""
# Initialize all scenario handlers
handlers = {
'price_inquiry': PriceInquiryHandler(embedding_service, qdrant_service, lead_storage),
'event_recommendation': EventRecommendationHandler(embedding_service, qdrant_service, lead_storage),
'post_event_feedback': PostEventFeedbackHandler(embedding_service, qdrant_service, lead_storage),
'exit_intent_rescue': ExitIntentRescueHandler(embedding_service, qdrant_service, lead_storage)
}
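    # Handler result contract, inferred from how `result` is consumed below:
    #   message          - text to stream to the user (required)
    #   loading_message  - optional status shown while the handler runs RAG
    #   end_scenario     - True when the flow has finished
    #   new_state        - replacement scenario state to persist
    #   action           - optional side effect, e.g. "send_pdf_email"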
# Get scenario response using handlers
if intent == "scenario:continue":
scenario_id = scenario_state.get("active_scenario")
if scenario_id not in handlers:
            yield format_sse(EVENT_ERROR, f"Scenario '{scenario_id}' không tồn tại")  # "... does not exist"
return
handler = handlers[scenario_id]
result = handler.next_step(
current_step=scenario_state.get("scenario_step", 1),
user_input=user_message,
scenario_data=scenario_state.get("scenario_data", {})
)
else:
scenario_type = intent.split(":", 1)[1]
if scenario_type not in handlers:
            yield format_sse(EVENT_ERROR, f"Scenario '{scenario_type}' không tồn tại")  # "... does not exist"
return
handler = handlers[scenario_type]
initial_data = scenario_state.get("scenario_data", {})
result = handler.start(initial_data=initial_data)
# Show loading message if RAG is being performed
if result.get("loading_message"):
yield format_sse(EVENT_STATUS, result["loading_message"])
# Small delay to let UI show loading
await asyncio.sleep(0.1)
# Update state
if result.get("end_scenario"):
conversation_service.clear_scenario(session_id)
elif result.get("new_state"):
conversation_service.set_scenario_state(session_id, result["new_state"])
# Execute actions
if result.get("action") and lead_storage:
action = result['action']
scenario_data = result.get('new_state', {}).get('scenario_data', {})
if action == "send_pdf_email":
lead_storage.save_lead(
event_name=scenario_data.get('step_1_input', 'Unknown'),
email=scenario_data.get('step_5_input'),
interests={"group": scenario_data.get('group_size'), "wants_pdf": True},
session_id=session_id
)
elif action == "save_lead_phone":
lead_storage.save_lead(
event_name=scenario_data.get('step_1_input', 'Unknown'),
email=scenario_data.get('step_5_input'),
phone=scenario_data.get('step_8_input'),
interests={"group": scenario_data.get('group_size'), "wants_reminder": True},
session_id=session_id
)
# Stream response with typing effect
response_text = result["message"]
async for chunk in stream_text_slowly(
response_text,
chars_per_chunk=4, # Faster for scenarios
delay_ms=15
):
yield chunk
yield format_sse(EVENT_METADATA, {
"mode": "scenario",
"scenario_active": not result.get("end_scenario")
})
async def handle_rag_stream(
request, advanced_rag, embedding_service, qdrant_service
) -> AsyncGenerator[str, None]:
"""
Handle RAG with real LLM streaming
"""
# RAG search (sync part)
context_used = []
if request.use_rag:
query_embedding = embedding_service.encode_text(request.message)
results = qdrant_service.search(
query_embedding=query_embedding,
limit=request.top_k,
score_threshold=request.score_threshold,
ef=256
)
context_used = results
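    # Each hit is assumed to be a dict like {"score": ..., "metadata": {"text": ...}},
    # based on how the results are consumed below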
    # Build the context string (currently unused by the simple answer below;
    # kept for the future LLM prompt)
    if context_used:
        context_str = "\n\n".join([
            f"[{i+1}] {r['metadata'].get('text', '')[:500]}"
            for i, r in enumerate(context_used[:3])
        ])
    else:
        context_str = "Không tìm thấy thông tin liên quan."  # "No relevant information found."
    # Simple extractive answer for now; real LLM generation can be plugged in later
    if context_used:
        # "Based on the documents, <top passage>..."
        response_text = f"Dựa trên tài liệu, {context_used[0]['metadata'].get('text', '')[:300]}..."
    else:
        # "Sorry, I could not find information about this question."
        response_text = "Xin lỗi, tôi không tìm thấy thông tin về câu hỏi này."
# Simulate streaming (will be replaced with real HF streaming)
async for chunk in stream_text_slowly(
response_text,
chars_per_chunk=3,
delay_ms=20
):
yield chunk
yield format_sse(EVENT_METADATA, {
"mode": "rag",
"context_count": len(context_used)
})
# TODO: Implement real HF InferenceClient streaming
# This requires updating advanced_rag.py to support stream=True
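# A possible shape for that integration (a sketch only; assumes a chat-capable
# model behind huggingface_hub's InferenceClient, with blocking iteration
# offloaded to a thread in production, e.g. via asyncio.to_thread):
#
#   from huggingface_hub import InferenceClient
#
#   def llm_token_iter(prompt: str, hf_token: str):
#       client = InferenceClient(token=hf_token)
#       # chat_completion(..., stream=True) yields chunks whose
#       # choices[0].delta.content holds the next token(s)
#       for chunk in client.chat_completion(
#           messages=[{"role": "user", "content": prompt}],
#           stream=True,
#       ):
#           delta = chunk.choices[0].delta.content
#           if delta:
#               yield delta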