ChatbotRAG / agent_service.py
minhvtt's picture
Update agent_service.py
35d088a verified
raw
history blame
15.3 kB
"""
Agent Service - Central Brain for Sales & Feedback Agents
Manages LLM conversation loop with tool calling
"""
from typing import Dict, Any, List, Optional
import os
from tools_service import ToolsService
class AgentService:
    """
    Central brain for the Sales & Feedback agents.

    Runs the conversation loop User -> LLM -> Tools -> Response: the LLM
    signals tool calls by emitting JSON blocks in its text output; those
    blocks are parsed, executed through ToolsService, and their results
    are fed back into the conversation until the LLM produces a plain
    natural-language answer (or the iteration budget runs out).
    """

    def __init__(
        self,
        tools_service: "ToolsService",
        embedding_service,
        qdrant_service,
        advanced_rag,
        hf_token: str
    ):
        """
        Args:
            tools_service: Executes the tools the LLM may request.
            embedding_service: Text -> vector encoder used for RAG search.
            qdrant_service: Vector store queried for event search.
            advanced_rag: Advanced RAG pipeline (stored but not referenced
                elsewhere in this class — presumably used by callers or
                planned functionality; verify before removing).
            hf_token: HuggingFace API token used for LLM inference.
        """
        self.tools_service = tools_service
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service
        self.advanced_rag = advanced_rag
        self.hf_token = hf_token
        # System prompts keyed by "<mode>_agent", loaded once at startup.
        self.prompts = self._load_prompts()

    def _load_prompts(self) -> Dict[str, str]:
        """
        Load system prompts from prompts/<mode>.txt.

        Best effort: a missing or unreadable prompt file must not crash
        startup, so failures fall back to an empty string.
        """
        prompts: Dict[str, str] = {}
        prompts_dir = "prompts"
        for mode in ["sales_agent", "feedback_agent"]:
            filepath = os.path.join(prompts_dir, f"{mode}.txt")
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    prompts[mode] = f.read()
                print(f"✓ Loaded prompt: {mode}")
            except Exception as e:
                print(f"⚠️ Error loading {mode} prompt: {e}")
                prompts[mode] = ""
        return prompts

    async def chat(
        self,
        user_message: str,
        conversation_history: List[Dict],
        mode: str = "sales",  # "sales" or "feedback"
        user_id: Optional[str] = None,
        max_iterations: int = 3
    ) -> Dict[str, Any]:
        """
        Main conversation loop.

        Args:
            user_message: User's input.
            conversation_history: Previous messages
                [{"role": "user", "content": ...}, ...].
            mode: "sales" or "feedback".
            user_id: User ID (for feedback mode to check purchase history).
                NOTE(review): currently unused inside this method — the LLM
                is expected to obtain it via tools/prompt; confirm intended.
            max_iterations: Maximum tool-call iterations to prevent
                infinite loops.

        Returns:
            {
                "message": "Bot response",
                "tool_calls": [...],  # tools called (for debugging)
                "mode": mode
            }
        """
        print(f"\n🤖 Agent Mode: {mode}")
        print(f"👤 User Message: {user_message}")

        # Select system prompt and build the conversation context.
        system_prompt = self._get_system_prompt(mode)
        messages = self._build_messages(system_prompt, conversation_history, user_message)

        # Agentic loop: the LLM may call tools several times in a row.
        tool_calls_made: List[Dict[str, Any]] = []
        current_response: Optional[str] = None
        # FIX: pre-initialize so `max_iterations <= 0` (or an empty loop)
        # cannot raise NameError at the `final_response` fallback below.
        llm_response = ""

        for iteration in range(max_iterations):
            print(f"\n🔄 Iteration {iteration + 1}")

            llm_response = await self._call_llm(messages)
            print(f"🧠 LLM Response: {llm_response[:200]}...")

            # No tool call detected -> this is the final user-facing answer.
            tool_call = self._parse_tool_call(llm_response)
            if not tool_call:
                current_response = llm_response
                break

            # Execute the requested tool and record it for debugging.
            print(f"🔧 Tool Called: {tool_call['tool_name']}")
            tool_result = await self.tools_service.execute_tool(
                tool_call['tool_name'],
                tool_call['arguments']
            )
            tool_calls_made.append({
                "function": tool_call['tool_name'],
                "arguments": tool_call['arguments'],
                "result": tool_result
            })

            # Feed the tool result back so the next LLM turn can use it.
            # (_format_tool_result expects a {"result": ...} wrapper.)
            messages.append({
                "role": "assistant",
                "content": llm_response
            })
            messages.append({
                "role": "system",
                "content": f"Tool Result:\n{self._format_tool_result({'result': tool_result})}"
            })

            # Special case: the tool delegates to RAG search; replace the
            # just-appended tool-result message with the search output.
            if isinstance(tool_result, dict) and tool_result.get("action") == "run_rag_search":
                rag_results = await self._execute_rag_search(tool_result["query"])
                messages[-1]["content"] = f"RAG Search Results:\n{rag_results}"

        # Fall back to the last raw LLM output if the iteration budget was
        # exhausted while the model was still requesting tools.
        final_response = current_response or llm_response
        final_response = self._clean_response(final_response)

        return {
            "message": final_response,
            "tool_calls": tool_calls_made,
            "mode": mode
        }

    def _get_system_prompt(self, mode: str) -> str:
        """Return the system prompt for `mode` with the tools definition appended.

        Unknown modes fall back to the sales agent prompt.
        """
        prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
        base_prompt = self.prompts.get(prompt_key, "")
        tools_definition = self._get_tools_definition()
        return f"{base_prompt}\n\n{tools_definition}"

    def _get_tools_definition(self) -> str:
        """Return the tools definition, in plain text, for prompt injection."""
        return """
# AVAILABLE TOOLS
You can call the following tools when needed. To call a tool, output a JSON block like this:
```json
{
"tool_call": "tool_name",
"arguments": {
"arg1": "value1",
"arg2": "value2"
}
}
```
## Tools List:
### 1. search_events
Search for events matching user criteria.
Arguments:
- query (string): Search keywords
- vibe (string, optional): Mood/vibe (e.g., "chill", "sôi động")
- time (string, optional): Time period (e.g., "cuối tuần này")
Example:
```json
{"tool_call": "search_events", "arguments": {"query": "nhạc rock", "vibe": "sôi động"}}
```
### 2. get_event_details
Get detailed information about a specific event.
Arguments:
- event_id (string): Event ID from search results
Example:
```json
{"tool_call": "get_event_details", "arguments": {"event_id": "6900ae38eb03f29702c7fd1d"}}
```
### 3. get_purchased_events (Feedback mode only)
Check which events the user has attended.
Arguments:
- user_id (string): User ID
Example:
```json
{"tool_call": "get_purchased_events", "arguments": {"user_id": "user_123"}}
```
### 4. save_feedback
Save user's feedback/review for an event.
Arguments:
- event_id (string): Event ID
- rating (integer): 1-5 stars
- comment (string, optional): User's comment
Example:
```json
{"tool_call": "save_feedback", "arguments": {"event_id": "abc123", "rating": 5, "comment": "Tuyệt vời!"}}
```
### 5. save_lead
Save customer contact information.
Arguments:
- email (string, optional): Email address
- phone (string, optional): Phone number
- interest (string, optional): What they're interested in
Example:
```json
{"tool_call": "save_lead", "arguments": {"email": "user@example.com", "interest": "Rock show"}}
```
**IMPORTANT:**
- Call tools ONLY when you need real-time data
- After receiving tool results, respond naturally to the user
- Don't expose raw JSON to users - always format nicely
"""

    def _build_messages(
        self,
        system_prompt: str,
        history: List[Dict],
        user_message: str
    ) -> List[Dict]:
        """Assemble the full messages array: system prompt, history, new user turn."""
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(history)
        messages.append({"role": "user", "content": user_message})
        return messages

    async def _call_llm(self, messages: List[Dict]) -> str:
        """
        Call the HuggingFace LLM directly using AsyncInferenceClient.

        Streams tokens and concatenates them; on any failure returns a
        Vietnamese apology message instead of raising, so the chat loop
        never crashes on transient inference errors.
        """
        try:
            from huggingface_hub import AsyncInferenceClient

            prompt = self._messages_to_prompt(messages)
            client = AsyncInferenceClient(token=self.hf_token)

            # text_generation(stream=True) yields tokens asynchronously.
            response_text = ""
            async for token in await client.text_generation(
                prompt=prompt,
                model="openai/gpt-oss-20b",
                max_new_tokens=512,
                temperature=0.7,
                stream=True
            ):
                response_text += token
            return response_text
        except Exception as e:
            print(f"⚠️ LLM Call Error: {e}")
            return "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"

    def _messages_to_prompt(self, messages: List[Dict]) -> str:
        """Flatten the messages array into a single [ROLE]-tagged prompt string."""
        prompt_parts = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                prompt_parts.append(f"[SYSTEM]\n{content}\n")
            elif role == "user":
                prompt_parts.append(f"[USER]\n{content}\n")
            elif role == "assistant":
                prompt_parts.append(f"[ASSISTANT]\n{content}\n")
        return "\n".join(prompt_parts)

    def _format_tool_result(self, tool_result: Dict) -> str:
        """Format a {"result": ...} wrapper as "key: value" lines for the LLM.

        The bookkeeping keys "success"/"error" are omitted; non-dict results
        are stringified as-is.
        """
        result = tool_result.get("result", {})
        if isinstance(result, dict):
            formatted = []
            for key, value in result.items():
                if key not in ["success", "error"]:
                    formatted.append(f"{key}: {value}")
            return "\n".join(formatted)
        return str(result)

    async def _execute_rag_search(self, query_params: Dict) -> str:
        """
        Execute a RAG search for event discovery (search_events delegate).

        Args:
            query_params: {"query": ..., "vibe": ...} dict. FIX: a bare
                string is also accepted and treated as the query, since the
                caller forwards tool_result["query"] whose type is not
                guaranteed by this class.

        Returns:
            A numbered plain-text list of matches, or a Vietnamese
            "no events found" message.
        """
        # Robustness: tolerate a plain query string from the tool payload.
        if isinstance(query_params, str):
            query_params = {"query": query_params}

        query = query_params.get("query", "")
        vibe = query_params.get("vibe", "")
        search_text = f"{query} {vibe}".strip()
        print(f"🔍 RAG Search: {search_text}")

        # Embed the query and search the "events" vector collection.
        embedding = self.embedding_service.encode_text(search_text)
        results = self.qdrant_service.search(
            collection_name="events",
            query_vector=embedding,
            limit=5
        )

        # One line per hit: truncated description plus the event ID the LLM
        # needs for follow-up get_event_details calls.
        formatted = []
        for i, result in enumerate(results, 1):
            payload = result.payload or {}
            texts = payload.get("texts", [])
            text = texts[0] if texts else ""
            event_id = payload.get("id_use", "")
            formatted.append(f"{i}. {text[:100]}... (ID: {event_id})")
        return "\n".join(formatted) if formatted else "Không tìm thấy sự kiện phù hợp."

    def _parse_tool_call(self, llm_response: str) -> Optional[Dict]:
        """
        Detect a structured tool call in the LLM's free-text response.

        Tries, in order: a ```json fenced block, inline JSON objects
        (including one level of nesting), and finally a brace-matching scan
        for an arbitrarily nested object.

        Returns:
            {"tool_name": "...", "arguments": {...}} or None.
        """
        import json
        import re

        # Method 1: fenced ```json block.
        json_match = re.search(r'```json\s*(\{.*?\})\s*```', llm_response, re.DOTALL)
        if json_match:
            try:
                data = json.loads(json_match.group(1))
                return self._extract_tool_from_json(data)
            except json.JSONDecodeError:
                pass

        # Method 2: inline JSON objects (handles one nesting level).
        json_objects = re.findall(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', llm_response)
        for json_str in json_objects:
            try:
                data = json.loads(json_str)
                tool_call = self._extract_tool_from_json(data)
                if tool_call:
                    return tool_call
            except json.JSONDecodeError:
                continue

        # Method 3: brace-count scan from the first '{' for deeper nesting.
        try:
            if '{' in llm_response and '}' in llm_response:
                start = llm_response.find('{')
                count = 0
                for i, char in enumerate(llm_response[start:], start):
                    if char == '{':
                        count += 1
                    elif char == '}':
                        count -= 1
                        if count == 0:
                            json_str = llm_response[start:i + 1]
                            data = json.loads(json_str)
                            return self._extract_tool_from_json(data)
        except (json.JSONDecodeError, ValueError):
            pass

        return None

    def _extract_tool_from_json(self, data: dict) -> Optional[Dict]:
        """
        Extract tool call information from parsed JSON.

        Supports multiple formats the LLM may emit:
        - {"tool_call": "search_events", "arguments": {...}}
        - {"function": "search_events", "parameters": {...}}
        - {"name": "search_events", "args": {...}}
        - {"search_events": {...}}  (tool name used directly as key)
        """
        # Format 1: tool_call + arguments.
        if "tool_call" in data and isinstance(data["tool_call"], str):
            return {
                "tool_name": data["tool_call"],
                "arguments": data.get("arguments", {})
            }
        # Format 2: function + parameters.
        if "function" in data:
            return {
                "tool_name": data["function"],
                "arguments": data.get("parameters", data.get("arguments", {}))
            }
        # Format 3: name + args.
        if "name" in data:
            return {
                "tool_name": data["name"],
                "arguments": data.get("args", data.get("arguments", {}))
            }
        # Format 4: a known tool name used directly as the key.
        valid_tools = ["search_events", "get_event_details", "get_purchased_events", "save_feedback", "save_lead"]
        for tool in valid_tools:
            if tool in data:
                return {
                    "tool_name": tool,
                    "arguments": data[tool] if isinstance(data[tool], dict) else {}
                }
        return None

    def _clean_response(self, response: str) -> str:
        """Strip JSON artifacts (fenced blocks, tool-call lines) from the final answer."""
        # Drop anything from the first code fence onward.
        if "```json" in response:
            response = response.split("```json")[0]
        if "```" in response:
            response = response.split("```")[0]
        # Drop from the first line containing an inline tool-call object.
        if "{" in response and "tool_call" in response:
            lines = response.split("\n")
            cleaned = []
            for line in lines:
                if "{" in line and "tool_call" in line:
                    break
                cleaned.append(line)
            response = "\n".join(cleaned)
        return response.strip()