# ChatbotRAG / scenario_engine.py
# minhvtt's picture
# Upload 20 files
# 2ecdea6 verified
# raw
# history blame
# 12.7 kB
"""
Scenario Engine for FSM-based Conversations
Executes multi-turn scripted conversations from JSON definitions
"""
import json
import os
import re
from typing import Dict, Optional, List, Any
from datetime import datetime
class ScenarioEngine:
"""
Execute scenario-based conversations
Load scenarios from JSON and manage step-by-step flow
"""
def __init__(self, scenarios_dir: str = "scenarios"):
self.scenarios_dir = scenarios_dir
self.scenarios = self._load_scenarios()
def _load_scenarios(self) -> Dict[str, Dict]:
"""Load all scenario JSON files"""
scenarios = {}
if not os.path.exists(self.scenarios_dir):
print(f"⚠ Scenarios directory not found: {self.scenarios_dir}")
return scenarios
for filename in os.listdir(self.scenarios_dir):
if filename.endswith('.json'):
filepath = os.path.join(self.scenarios_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f:
scenario = json.load(f)
scenario_id = scenario.get('scenario_id')
if scenario_id:
scenarios[scenario_id] = scenario
print(f"✓ Loaded scenario: {scenario_id}")
return scenarios
def start_scenario(self, scenario_id: str, initial_data: Dict = None) -> Dict[str, Any]:
"""
Start a new scenario with optional initial data
Args:
scenario_id: Scenario to start
initial_data: External data to inject (event_name, mood, etc.)
Returns:
{
"message": str,
"new_state": {...},
"end_scenario": bool
}
"""
if scenario_id not in self.scenarios:
return {
"message": "Xin lỗi, tính năng này đang được cập nhật.",
"new_state": {},
"end_scenario": True
}
scenario = self.scenarios[scenario_id]
first_step = scenario['steps'][0]
# Initialize with external data
scenario_data = initial_data.copy() if initial_data else {}
# Build first message with initial data
message = self._build_message(first_step, scenario_data, None)
return {
"message": message,
"new_state": {
"active_scenario": scenario_id,
"scenario_step": 1,
"scenario_data": scenario_data,
"last_activity": datetime.utcnow().isoformat()
},
"end_scenario": False
}
def next_step(
self,
scenario_id: str,
current_step: int,
user_input: str,
scenario_data: Dict,
rag_service: Optional[Any] = None
) -> Dict[str, Any]:
"""
Process user input and move to next step
Args:
scenario_id: Active scenario ID
current_step: Current step number
user_input: User's message
scenario_data: Data collected so far
rag_service: Optional RAG service for hybrid queries
Returns:
{
"message": str,
"new_state": {...} | None,
"end_scenario": bool,
"action": str | None
}
"""
if scenario_id not in self.scenarios:
return {"message": "Error: Scenario not found", "end_scenario": True}
scenario = self.scenarios[scenario_id]
current_step_config = self._get_step(scenario, current_step)
if not current_step_config:
return {"message": "Error: Step not found", "end_scenario": True}
# Validate input if needed
expected_type = current_step_config.get('expected_input_type')
if expected_type:
validation_error = self._validate_input(user_input, expected_type)
if validation_error:
return {
"message": validation_error,
"new_state": None, # Don't change state
"end_scenario": False
}
# Handle branching
if 'branches' in current_step_config:
branch_result = self._handle_branches(
current_step_config['branches'],
user_input,
scenario_data
)
next_step_id = branch_result['next_step']
scenario_data.update(branch_result.get('save_data', {}))
else:
next_step_id = current_step_config.get('next_step')
# Save user input
input_field = current_step_config.get('save_as', f'step_{current_step}_input')
scenario_data[input_field] = user_input
# Get next step config
next_step_config = self._get_step(scenario, next_step_id)
if not next_step_config:
return {"message": "Cảm ơn bạn!", "end_scenario": True}
# Check if scenario ends
if next_step_config.get('end_scenario'):
return {
"message": next_step_config['bot_message'],
"new_state": None,
"end_scenario": True,
"action": next_step_config.get('action')
}
# Build next message
message = self._build_message(
next_step_config,
scenario_data,
rag_service
)
return {
"message": message,
"new_state": {
"active_scenario": scenario_id,
"scenario_step": next_step_id,
"scenario_data": scenario_data,
"last_activity": datetime.utcnow().isoformat()
},
"end_scenario": False,
"action": next_step_config.get('action')
}
def _get_step(self, scenario: Dict, step_id: int) -> Optional[Dict]:
"""Get step config by ID"""
for step in scenario['steps']:
if step['id'] == step_id:
return step
return None
def _validate_input(self, user_input: str, expected_type: str) -> Optional[str]:
"""
Validate user input
Returns error message or None if valid
"""
if expected_type == 'email':
if not re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', user_input):
return "Email không hợp lệ. Vui lòng nhập lại (vd: ten@email.com)"
elif expected_type == 'phone':
# Simple Vietnamese phone validation
clean = re.sub(r'[^\d]', '', user_input)
if len(clean) < 9 or len(clean) > 11:
return "Số điện thoại không hợp lệ. Vui lòng nhập lại (10-11 số)"
return None
def _handle_branches(
self,
branches: Dict,
user_input: str,
scenario_data: Dict
) -> Dict:
"""
Handle branch logic
Returns:
{"next_step": int, "save_data": {...}}
"""
user_input_lower = user_input.lower().strip()
for branch_name, branch_config in branches.items():
if branch_name == 'default':
continue
patterns = branch_config.get('patterns', [])
for pattern in patterns:
if pattern.lower() in user_input_lower:
return {
"next_step": branch_config['next_step'],
"save_data": branch_config.get('save_data', {})
}
# Default branch
default_name = branches.get('default_branch', list(branches.keys())[0])
default_branch = branches.get(default_name, list(branches.values())[0])
return {
"next_step": default_branch['next_step'],
"save_data": default_branch.get('save_data', {})
}
def _build_message(
self,
step_config: Dict,
scenario_data: Dict,
rag_service: Optional[Any]
) -> str:
"""
Build bot message with 3-layer data resolution:
1. scenario_data (initial + user inputs)
2. RAG results (if rag_query_template exists)
3. Merged template vars
"""
# Layer 1: Base data (initial + user inputs)
# Map common template vars from scenario_data
template_data = {
'event_name': scenario_data.get('event_name', scenario_data.get('step_1_input', 'sự kiện này')),
'mood': scenario_data.get('mood', scenario_data.get('step_1_input', '')),
'interest': scenario_data.get('interest', scenario_data.get('step_1_input', '')),
'interest_tag': scenario_data.get('interest_tag', scenario_data.get('step_1_input', '')),
**scenario_data # Include all scenario data
}
# Layer 2: RAG query (if specified)
if 'rag_query_template' in step_config:
try:
# Build query from template
query = step_config['rag_query_template'].format(**template_data)
if rag_service:
# Execute RAG search
results = self._execute_rag_query(query, rag_service)
template_data['rag_results'] = results
else:
# Fallback if no RAG service
template_data['rag_results'] = "(Đang tải thông tin...)"
except Exception as e:
print(f"⚠ RAG query error: {e}")
template_data['rag_results'] = ""
# Layer 3: Build final message
if 'bot_message_template' in step_config:
try:
return step_config['bot_message_template'].format(**template_data)
except KeyError as e:
print(f"⚠ Template var missing: {e}")
print(f"📋 Available vars: {list(template_data.keys())}")
# Fallback: replace missing vars with placeholder
import re
message = step_config['bot_message_template']
# Find all {var} patterns
missing_vars = re.findall(r'\{(\w+)\}', message)
for var in missing_vars:
if var not in template_data:
template_data[var] = f"[{var}]"
print(f"⚠ Adding placeholder for: {var}")
return message.format(**template_data)
return step_config.get('bot_message', '')
def _execute_rag_query(self, query: str, rag_service: Any) -> str:
"""
Execute RAG query and format results
Returns formatted string of top results
"""
try:
# Simple search (we'll integrate with actual RAG later)
# For now, return placeholder
return f"[Kết quả tìm kiếm cho: {query}]\n1. Sự kiện A\n2. Sự kiện B"
except Exception as e:
print(f"⚠ RAG execution error: {e}")
return ""
# Manual smoke test: walks the "price_inquiry" scenario loaded from ./scenarios.
if __name__ == "__main__":
    engine = ScenarioEngine()

    print("\nTest: Start price_inquiry scenario")
    result = engine.start_scenario("price_inquiry")
    print(f"Bot: {result['message']}")
    print(f"State: {result['new_state']}")

    # Each turn: (label shown in the log, text sent as the user's reply)
    turns = [("Show A", "Show A"), ("nhóm", "nhóm 5 người")]
    state = result.get('new_state')
    for label, user_input in turns:
        # Stop when the scenario is unavailable or already over.
        if result.get('end_scenario') or not state:
            print("Scenario ended; skipping remaining turns.")
            break
        print(f"\nTest: User answers '{label}'")
        result = engine.next_step(
            scenario_id=state['active_scenario'],
            current_step=state['scenario_step'],
            user_input=user_input,
            scenario_data=state['scenario_data']
        )
        print(f"Bot: {result['message']}")
        # new_state is None on a validation error (keep the current state) or
        # when the scenario ends; only advance when a new state is returned.
        if result.get('new_state') is not None:
            state = result['new_state']

    if state:
        # next_step() updates scenario_data in place, so this also includes
        # input collected on a final, scenario-ending turn.
        print(f"Data collected: {state['scenario_data']}")