ChatbotRAG / intent_classifier.py
minhvtt's picture
Update intent_classifier.py
70413d7 verified
raw
history blame
7.32 kB
"""
Intent Classifier for Hybrid RAG + FSM Chatbot
Detects user intent to route between scenario flows and RAG queries
"""
import re
import unicodedata
from typing import Dict, List, Optional
class IntentClassifier:
    """
    Classify user intent using keyword matching.

    Routes to either:
    - Scenario flows (scripted conversations)
    - RAG queries (knowledge retrieval)

    All matching is done on NFC-normalized, lower-cased text, so the same
    Vietnamese word typed in composed or decomposed Unicode form is treated
    identically.
    """

    # Keywords that count as a valid answer to a scenario question.
    CHOICE_KEYWORDS = (
        # Event recommendation choices
        'giá', 'price', 'vé', 'ticket',
        'lineup', 'line-up', 'nghệ sĩ', 'artist',
        'địa điểm', 'location', 'chỗ',
        'thời gian', 'time', 'lịch',
        # General answers
        'có', 'yes', 'ok', 'được', 'không', 'no',
        'chill', 'sôi động', 'hài', 'workshop',
        '1', '2', '3', '4', '5',  # Ratings or choices
    )

    # Words that mark a message as a question (possible off-topic detour).
    QUESTION_WORDS = (
        "gì", "sao", "thế nào", "bao nhiêu", "mấy giờ",
        "ai", "how", "what", "why",
    )

    # Whole-word matchers for the two keyword tuples above. A keyword must
    # not be glued to other word characters. Plain substring tests made "ai"
    # fire inside "mai"/"hai", "how" inside "show", "no" inside "know" and
    # "1" inside "10". Keywords are NFC-normalized to agree with _normalize().
    _CHOICE_RE = re.compile(
        r"(?<!\w)(?:"
        + "|".join(re.escape(unicodedata.normalize("NFC", kw)) for kw in CHOICE_KEYWORDS)
        + r")(?!\w)"
    )
    _QUESTION_RE = re.compile(
        r"(?<!\w)(?:"
        + "|".join(re.escape(unicodedata.normalize("NFC", kw)) for kw in QUESTION_WORDS)
        + r")(?!\w)"
    )

    def __init__(self, scenarios_dir: str = "scenarios"):
        """
        Initialize, auto-loading triggers from scenario JSON files.

        Args:
            scenarios_dir: Directory containing scenario JSON files.
        """
        # Auto-load scenario patterns from JSON files
        self.scenario_patterns = self._load_scenario_patterns(scenarios_dir)

        # General question patterns (RAG). Kept as a per-instance list so
        # add_general_pattern() never mutates state shared between instances.
        # NOTE(review): not consulted by classify() in this file.
        self.general_patterns = [
            # Location
            "ở đâu", "địa điểm", "location", "where",
            "chỗ nào", "tổ chức tại",
            # Time
            "mấy giờ", "khi nào", "when", "time",
            "bao giờ", "thời gian", "ngày nào",
            # Info
            "thông tin", "info", "information",
            "chi tiết", "details", "về",
            # Parking
            "đậu xe", "parking", "gửi xe",
            # Contact
            "liên hệ", "contact", "số điện thoại",
            # Events/content
            "sự kiện", "event", "đâu", "show nào",
            "line-up", "lineup", "performer",
        ]

    @staticmethod
    def _normalize(text: str) -> str:
        """Return *text* lower-cased, NFC-normalized and stripped."""
        return unicodedata.normalize("NFC", text.lower()).strip()

    def _load_scenario_patterns(self, scenarios_dir: str) -> Dict[str, List[str]]:
        """
        Auto-load triggers from all scenario JSON files in *scenarios_dir*.

        Each file is expected to be a JSON object with a ``scenario_id`` and
        a ``triggers`` list. A single trigger given as a plain string is
        accepted as a one-element list. Non-string and blank triggers are
        dropped, because an empty trigger would match every message. Files
        are read in sorted order so that trigger precedence (and which file
        wins for a duplicated ``scenario_id``) is deterministic. Unreadable
        or malformed files are reported and skipped.

        Returns:
            {"scenario_id": ["trigger1", "trigger2", ...]}
        """
        import json
        import os

        patterns: Dict[str, List[str]] = {}

        # isdir (not exists): a path that is a regular file would otherwise
        # make os.listdir raise.
        if not os.path.isdir(scenarios_dir):
            print(f"⚠ Scenarios directory not found: {scenarios_dir}")
            return patterns

        # os.listdir order is arbitrary; sort for reproducible routing.
        for filename in sorted(os.listdir(scenarios_dir)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(scenarios_dir, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    scenario = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"⚠ Error loading {filename}: {e}")
                continue

            if not isinstance(scenario, dict):
                print(f"⚠ Error loading {filename}: expected a JSON object")
                continue

            scenario_id = scenario.get('scenario_id')
            triggers = scenario.get('triggers', [])
            if isinstance(triggers, str):
                triggers = [triggers]
            if not isinstance(triggers, list):
                print(f"⚠ Error loading {filename}: 'triggers' must be a list")
                continue
            triggers = [t for t in triggers if isinstance(t, str) and t.strip()]

            if scenario_id and triggers:
                patterns[scenario_id] = triggers
                print(f"✓ Loaded triggers for: {scenario_id} ({len(triggers)} patterns)")

        return patterns

    def classify(
        self,
        message: str,
        conversation_state: Optional[Dict] = None
    ) -> str:
        """
        Classify user intent, with mid-scenario off-topic detection.

        Args:
            message: Raw user message.
            conversation_state: Optional state dict. Only its
                ``'active_scenario'`` key is read.

        Returns:
            - "scenario:{scenario_id}" - Trigger new scenario
            - "scenario:continue" - Continue active scenario
            - "rag:general" - General RAG query (no active scenario)
            - "rag:with_resume" - RAG query mid-scenario (then resume)
        """
        message_lower = self._normalize(message or "")

        # Check if user is in active scenario
        active_scenario = conversation_state.get('active_scenario') if conversation_state else None

        if active_scenario:
            # A message is off-topic only if it looks like a question (question
            # mark or question word) AND does not contain a valid answer keyword.
            is_valid_answer = self._CHOICE_RE.search(message_lower) is not None
            has_question_mark = "?" in message_lower
            has_question_word = self._QUESTION_RE.search(message_lower) is not None

            if (has_question_mark or has_question_word) and not is_valid_answer:
                print(f"🔀 Off-topic question detected: '{message}' → rag:with_resume")
                return "rag:with_resume"
            # Normal scenario continuation
            return "scenario:continue"

        # Not in scenario: the first scenario whose trigger occurs as a
        # substring wins (triggers are typically multi-word phrases).
        for scenario_id, patterns in self.scenario_patterns.items():
            for pattern in patterns:
                # Skip non-strings and blank triggers; "" would match everything.
                if not isinstance(pattern, str) or not pattern.strip():
                    continue
                if unicodedata.normalize("NFC", pattern.lower()) in message_lower:
                    return f"scenario:{scenario_id}"

        # No scenario match - general RAG query
        return "rag:general"

    def _matches_any_pattern(self, message: str, patterns: List[str]) -> bool:
        """
        Return True if any pattern occurs in *message*.

        A pattern matches either as a case-sensitive substring or as a
        case-insensitive whole word.
        """
        for pattern in patterns:
            # Simple substring match
            if pattern in message:
                return True
            # Word boundary check (adds case-insensitivity)
            if re.search(rf'\b{re.escape(pattern)}\b', message, re.IGNORECASE):
                return True
        return False

    def get_scenario_type(self, intent: str) -> Optional[str]:
        """
        Extract the scenario type from an intent string.

        Args:
            intent: e.g. "scenario:price_inquiry" or "scenario:continue".

        Returns:
            "price_inquiry" for a scenario-start intent. None for
            "scenario:continue", an empty type ("scenario:"), or any
            non-scenario intent.
        """
        prefix = "scenario:"
        if not intent.startswith(prefix):
            return None
        scenario_type = intent[len(prefix):]
        if not scenario_type or scenario_type == "continue":
            return None
        return scenario_type

    def add_scenario_pattern(self, scenario_id: str, patterns: List[str]):
        """
        Dynamically add trigger patterns for *scenario_id*.

        The patterns are copied, so later additions never mutate the
        caller's list.
        """
        if scenario_id in self.scenario_patterns:
            self.scenario_patterns[scenario_id].extend(patterns)
        else:
            self.scenario_patterns[scenario_id] = list(patterns)

    def add_general_pattern(self, patterns: List[str]):
        """
        Dynamically add new general question patterns.
        """
        self.general_patterns.extend(patterns)