# ChatbotRAG/chat_endpoint.py
"""
Chat endpoint với Multi-turn Conversation + Function Calling
"""
import ast
import json
import re
from datetime import datetime

from fastapi import HTTPException
from huggingface_hub import InferenceClient

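# The `request` parameter below is a Pydantic ChatRequest defined elsewhere in
# the repo (one of the other uploaded files). A minimal sketch reconstructed
# from the fields this handler actually reads; the defaults are illustrative
# assumptions, not the original values:
#
#     class ChatRequest(BaseModel):
#         message: str
#         session_id: str | None = None
#         user_id: str | None = None
#         system_message: str = "You are a helpful assistant."
#         use_rag: bool = True
#         use_advanced_rag: bool = False
#         use_reranking: bool = False
#         use_compression: bool = False
#         use_query_expansion: bool = False
#         enable_tools: bool = False
#         top_k: int = 5
#         score_threshold: float = 0.5
#         max_tokens: int = 512
#         temperature: float = 0.7
#         top_p: float = 0.95
#         hf_token: str | None = None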
async def chat_endpoint(
request, # ChatRequest
conversation_service,
tools_service,
advanced_rag,
embedding_service,
qdrant_service,
chat_history_collection,
hf_token
):
"""
Multi-turn conversational chatbot với RAG + Function Calling
Flow:
1. Session management - create hoặc load existing session
2. RAG search - retrieve context nếu enabled
3. Build messages với conversation history + tools prompt
4. LLM generation - có thể trigger tool calls
5. Execute tools nếu cần
6. Final LLM response với tool results
7. Save to conversation history
"""
try:
# ===== 1. SESSION MANAGEMENT =====
session_id = request.session_id
if not session_id:
# Create new session (server-side)
session_id = conversation_service.create_session(
metadata={"user_agent": "api", "created_via": "chat_endpoint"},
user_id=request.user_id # NEW: Pass user_id from request
)
print(f"Created new session: {session_id} for user: {request.user_id or 'anonymous'}")
else:
# Validate existing session
if not conversation_service.session_exists(session_id):
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found. It may have expired."
)
# Load conversation history
conversation_history = conversation_service.get_conversation_history(session_id)
# ===== 2. RAG SEARCH =====
context_used = []
rag_stats = None
context_text = ""
if request.use_rag:
if request.use_advanced_rag:
# Use Advanced RAG Pipeline
hf_client = None
if request.hf_token or hf_token:
hf_client = InferenceClient(token=request.hf_token or hf_token)
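                # hf_client is presumably used for the pipeline's LLM-backed steps (e.g. query expansion)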
documents, stats = advanced_rag.hybrid_rag_pipeline(
query=request.message,
top_k=request.top_k,
score_threshold=request.score_threshold,
use_reranking=request.use_reranking,
use_compression=request.use_compression,
use_query_expansion=request.use_query_expansion,
max_context_tokens=500,
hf_client=hf_client
)
# Convert to dict format
context_used = [
{
"id": doc.id,
"confidence": doc.confidence,
"metadata": doc.metadata
}
for doc in documents
]
rag_stats = stats
# Format context
context_text = advanced_rag.format_context_for_llm(documents)
else:
# Basic RAG
query_embedding = embedding_service.encode_text(request.message)
results = qdrant_service.search(
query_embedding=query_embedding,
limit=request.top_k,
score_threshold=request.score_threshold
)
context_used = results
context_text = "\n\nRelevant Context:\n"
for i, doc in enumerate(context_used, 1):
doc_text = doc["metadata"].get("text", "")
if not doc_text:
doc_text = " ".join(doc["metadata"].get("texts", []))
confidence = doc["confidence"]
context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
        # ===== 3. BUILD MESSAGES WITH TOOLS PROMPT =====
messages = []
        # System message with RAG context + tools instruction
if request.use_rag and context_used:
if request.use_advanced_rag:
base_prompt = advanced_rag.build_rag_prompt(
query="", # Query sẽ đi trong user message
context=context_text,
system_message=request.system_message
)
else:
base_prompt = f"""{request.system_message}
{context_text}
HƯỚNG DẪN:
- Sử dụng thông tin từ context trên để trả lời câu hỏi.
- Trả lời tự nhiên, thân thiện, không copy nguyên văn.
- Nếu tìm thấy sự kiện, hãy tóm tắt các thông tin quan trọng nhất.
"""
else:
base_prompt = request.system_message
        # Add the tools instruction if enabled
if request.enable_tools:
tools_prompt = tools_service.get_tools_prompt()
system_message_with_tools = f"{base_prompt}\n\n{tools_prompt}"
else:
system_message_with_tools = base_prompt
        # Start the message list with the system message
messages.append({"role": "system", "content": system_message_with_tools})
# Add conversation history (past turns)
messages.extend(conversation_history)
# Add current user message
messages.append({"role": "user", "content": request.message})
# ===== 4. LLM GENERATION =====
token = request.hf_token or hf_token
tool_calls_made = []
if not token:
response = f"""[LLM Response Placeholder]
Context retrieved: {len(context_used)} documents
User question: {request.message}
Session: {session_id}
To enable actual LLM generation:
1. Set HUGGINGFACE_TOKEN environment variable, OR
2. Pass hf_token in request body
"""
else:
try:
client = InferenceClient(
token=token,
model="openai/gpt-oss-20b" # Hoặc model khác
)
# First LLM call
first_response = ""
                try:
                    for msg in client.chat_completion(
                        messages,
                        max_tokens=request.max_tokens,
                        stream=True,
                        temperature=request.temperature,
                        top_p=request.top_p,
                    ):
                        choices = msg.choices
                        if choices and choices[0].delta.content:
                            first_response += choices[0].delta.content
                except Exception as e:
                    # The HF API raises an error when the LLM answers with raw
                    # JSON (a tool call); recover the "failed_generation" payload.
                    error_str = str(e)
                    if "tool_use_failed" in error_str and "failed_generation" in error_str:
                        # The error string is usually a printable dict literal.
                        try:
                            error_dict = ast.literal_eval(error_str)
                            first_response = error_dict.get("failed_generation", "")
                        except (ValueError, SyntaxError):
                            # Fallback: pull the JSON straight out of the string.
                            match = re.search(r"'failed_generation': '({.*?})'", error_str)
                            if match:
                                first_response = match.group(1)
                            else:
                                raise
                    else:
                        raise
# ===== 5. PARSE & EXECUTE TOOLS =====
if request.enable_tools:
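                    # parse_and_execute presumably returns None when the response contains no tool-call JSON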
tool_result = await tools_service.parse_and_execute(first_response)
if tool_result:
# Tool was called!
tool_calls_made.append(tool_result)
# Add tool result to messages
messages.append({"role": "assistant", "content": first_response})
messages.append({
"role": "user",
"content": f"TOOL RESULT:\n{json.dumps(tool_result['result'], ensure_ascii=False, indent=2)}\n\nHãy dùng thông tin này để trả lời câu hỏi của user."
})
                        # Second LLM call with the tool results
final_response = ""
for msg in client.chat_completion(
messages,
max_tokens=request.max_tokens,
stream=True,
temperature=request.temperature,
top_p=request.top_p,
):
                            choices = msg.choices
                            if choices and choices[0].delta.content:
                                final_response += choices[0].delta.content
response = final_response
else:
# No tool call, use first response
response = first_response
else:
response = first_response
except Exception as e:
response = f"Error generating response with LLM: {str(e)}\n\nContext was retrieved successfully, but LLM generation failed."
# ===== 6. SAVE TO CONVERSATION HISTORY =====
conversation_service.add_message(
session_id,
"user",
request.message
)
conversation_service.add_message(
session_id,
"assistant",
response,
metadata={
"rag_stats": rag_stats,
"tool_calls": tool_calls_made,
"context_count": len(context_used)
}
)
# Also save to legacy chat_history collection
chat_data = {
"session_id": session_id,
"user_message": request.message,
"assistant_response": response,
"context_used": context_used,
"tool_calls": tool_calls_made,
"timestamp": datetime.utcnow()
}
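        # NOTE: insert_one is a synchronous (PyMongo-style) call inside an async
        # handler; a Motor (async) collection would need an await here instead.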
chat_history_collection.insert_one(chat_data)
# ===== 7. RETURN RESPONSE =====
return {
"response": response,
"context_used": context_used,
"timestamp": datetime.utcnow().isoformat(),
"rag_stats": rag_stats,
"session_id": session_id,
"tool_calls": tool_calls_made if tool_calls_made else None
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
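

# --- Illustrative wiring (a sketch, not part of the original upload) ---
# One way this handler might be mounted in the FastAPI app, assuming the
# service objects are constructed at startup; every name below other than
# chat_endpoint and ChatRequest is hypothetical.
#
#     import os
#     from fastapi import FastAPI
#
#     app = FastAPI()
#
#     @app.post("/chat")
#     async def chat(request: ChatRequest):
#         return await chat_endpoint(
#             request,
#             conversation_service=conversation_service,
#             tools_service=tools_service,
#             advanced_rag=advanced_rag,
#             embedding_service=embedding_service,
#             qdrant_service=qdrant_service,
#             chat_history_collection=db["chat_history"],
#             hf_token=os.getenv("HUGGINGFACE_TOKEN"),
#         )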