minhvtt committed on
Commit
70413d7
·
verified ·
1 Parent(s): 885208c

Update intent_classifier.py

Browse files
Files changed (1) hide show
  1. intent_classifier.py +201 -188
intent_classifier.py CHANGED
@@ -1,188 +1,201 @@
1
- """
2
- Intent Classifier for Hybrid RAG + FSM Chatbot
3
- Detects user intent to route between scenario flows and RAG queries
4
- """
5
- from typing import Dict, Optional, List
6
- import re
7
-
8
-
9
class IntentClassifier:
    """
    Classify user intent using keyword matching.

    Each message is routed to either:
    - Scenario flows (scripted conversations)
    - RAG queries (knowledge retrieval)
    """

    # Words that mark a message as a question while a scenario is active.
    _QUESTION_WORDS = ("đâu", "gì", "sao", "where", "what", "how", "when")

    # Whole-word matcher for the words above. Plain substring tests misfire:
    # "how" is contained in "show" and "workshop", which are ordinary answers.
    # \w is Unicode-aware for str patterns, so accented letters count as
    # word characters.
    _QUESTION_WORD_RE = re.compile(
        r"(?<!\w)(?:" + "|".join(re.escape(w) for w in _QUESTION_WORDS) + r")(?!\w)"
    )

    def __init__(self, scenarios_dir: str = "scenarios"):
        """
        Initialize with auto-loading triggers from scenario JSON files.

        Args:
            scenarios_dir: Directory containing scenario JSON files
        """
        # Auto-load scenario patterns from JSON files
        self.scenario_patterns = self._load_scenario_patterns(scenarios_dir)

        # General question patterns (RAG)
        self.general_patterns = [
            # Location
            "ở đâu", "địa điểm", "location", "where",
            "chỗ nào", "tổ chức tại",

            # Time
            "mấy giờ", "khi nào", "when", "time",
            "bao giờ", "thời gian", "ngày nào",

            # Info
            "thông tin", "info", "information",
            "chi tiết", "details", "về",

            # Parking
            "đậu xe", "parking", "gửi xe",

            # Contact
            "liên hệ", "contact", "số điện thoại",

            # Events/content
            "sự kiện", "event", "đâu", "show nào",
            "line-up", "lineup", "performer"
        ]

    def _load_scenario_patterns(self, scenarios_dir: str) -> dict:
        """
        Auto-load triggers from all scenario JSON files.

        Files are read in sorted name order so that the resulting dict (and
        therefore the trigger matching order in ``classify``) is the same on
        every platform. A file that cannot be read or parsed is reported and
        skipped; it never aborts the load.

        Args:
            scenarios_dir: Directory containing scenario JSON files

        Returns:
            {"scenario_id": ["trigger1", "trigger2", ...]}
        """
        import json
        import os

        patterns = {}

        # isdir (not exists): a regular file at this path would make
        # os.listdir raise instead of producing the warning below.
        if not os.path.isdir(scenarios_dir):
            print(f"⚠ Scenarios directory not found: {scenarios_dir}")
            return patterns

        for filename in sorted(os.listdir(scenarios_dir)):
            if not filename.endswith('.json'):
                continue

            filepath = os.path.join(scenarios_dir, filename)
            # Deliberately broad: loading is best-effort, every failure is
            # logged and the remaining files are still processed.
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    scenario = json.load(f)

                scenario_id = scenario.get('scenario_id')
                raw_triggers = scenario.get('triggers', [])
                if not isinstance(raw_triggers, list):
                    raw_triggers = []

                # Drop non-string and blank triggers: an empty trigger is a
                # substring of every message and would hijack all routing.
                triggers = [
                    t for t in raw_triggers
                    if isinstance(t, str) and t.strip()
                ]

                if scenario_id and triggers:
                    patterns[scenario_id] = triggers
                    print(f"✓ Loaded triggers for: {scenario_id} ({len(triggers)} patterns)")
            except Exception as e:
                print(f"⚠ Error loading {filename}: {e}")

        return patterns

    def classify(
        self,
        message: str,
        conversation_state: Optional[Dict] = None
    ) -> str:
        """
        Classify user intent, including mid-scenario off-topic detection.

        While a scenario is active, a message is treated as off-topic when it
        matches one of the general patterns, or when it is short (at most
        four words) and looks like a question.

        Args:
            message: Raw user message
            conversation_state: Optional state dict; its 'active_scenario'
                entry (if truthy) marks the user as being inside a scenario

        Returns:
            - "scenario:{scenario_id}" - Trigger new scenario
            - "scenario:continue" - Continue active scenario
            - "rag:general" - General RAG query (no active scenario)
            - "rag:with_resume" - RAG query mid-scenario (then resume)
        """
        message = message or ""
        message_lower = message.lower().strip()

        # Check if user is in active scenario
        active_scenario = conversation_state.get('active_scenario') if conversation_state else None

        if active_scenario:
            # Question words are matched as whole words so that answers such
            # as "workshop" (which contains "how") are not seen as questions.
            has_question = (
                "?" in message_lower
                or self._QUESTION_WORD_RE.search(message_lower) is not None
            )

            # Check if matches general patterns
            matches_general = self._matches_any_pattern(message_lower, self.general_patterns)

            # Short messages with questions are likely off-topic
            word_count = len(message_lower.split())
            is_short_question = word_count <= 4 and has_question

            if matches_general or is_short_question:
                # User asking off-topic question: RAG with resume
                print(f"🔀 Off-topic detected: '{message}' → rag:with_resume")
                return "rag:with_resume"

            # Normal scenario continuation
            return "scenario:continue"

        # Not in scenario - check for scenario triggers
        for scenario_id, patterns in self.scenario_patterns.items():
            for pattern in patterns:
                # Skip empty patterns: "" is contained in every message.
                if pattern and pattern.lower() in message_lower:
                    return f"scenario:{scenario_id}"

        # No scenario match - general RAG query
        return "rag:general"

    def _matches_any_pattern(self, message: str, patterns: List[str]) -> bool:
        """
        Check if message matches any pattern in list.

        A pattern matches when it is a (case-sensitive) substring of the
        message, or when it occurs as a whole word ignoring case.
        """
        for pattern in patterns:
            # Simple substring match
            if pattern in message:
                return True

            # Case-insensitive word boundary check
            if re.search(rf'\b{re.escape(pattern)}\b', message, re.IGNORECASE):
                return True

        return False

    def get_scenario_type(self, intent: str) -> Optional[str]:
        """
        Extract scenario type from intent string.

        Args:
            intent: "scenario:price_inquiry" or "scenario:continue"

        Returns:
            "price_inquiry", or None when the intent is not a scenario
            intent, is "scenario:continue", or names no scenario at all
        """
        prefix = "scenario:"
        if not intent.startswith(prefix):
            return None

        scenario_type = intent[len(prefix):]
        # A bare "scenario:" carries no id; report it as None rather than "".
        if not scenario_type or scenario_type == "continue":
            return None

        return scenario_type

    def add_scenario_pattern(self, scenario_id: str, patterns: List[str]):
        """
        Dynamically add new scenario patterns.

        The given list is copied, so later additions never modify the
        caller's list.
        """
        self.scenario_patterns.setdefault(scenario_id, []).extend(patterns)

    def add_general_pattern(self, patterns: List[str]):
        """
        Dynamically add new general question patterns.
        """
        self.general_patterns.extend(patterns)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Intent Classifier for Hybrid RAG + FSM Chatbot
3
+ Detects user intent to route between scenario flows and RAG queries
4
+ """
5
+ from typing import Dict, Optional, List
6
+ import re
7
+
8
+
9
class IntentClassifier:
    """
    Classify user intent using keyword matching.

    Each message is routed to either:
    - Scenario flows (scripted conversations)
    - RAG queries (knowledge retrieval)
    """

    # Keywords that count as a valid answer to a scenario question. They are
    # matched as plain substrings of the lower-cased message, which keeps
    # inflected forms such as "tickets" or "okay" accepted.
    _CHOICE_KEYWORDS = (
        # Event recommendation choices
        'giá', 'price', 'vé', 'ticket',
        'lineup', 'line-up', 'nghệ sĩ', 'artist',
        'địa điểm', 'location', 'chỗ',
        'thời gian', 'time', 'lịch',
        # General answers
        'có', 'yes', 'ok', 'được', 'không', 'no',
        'chill', 'sôi động', 'hài', 'workshop',
        '1', '2', '3', '4', '5',  # Ratings or choices
    )

    # Words that mark a message as a question while a scenario is active.
    _QUESTION_WORDS = (
        "gì", "sao", "thế nào", "bao nhiêu", "mấy giờ",
        "ai", "how", "what", "why",
    )

    # Whole-word matcher for the words above. Plain substring tests misfire:
    # "ai" is contained in "gmail", "mai" and "hai", and "how" in "show".
    # \w is Unicode-aware for str patterns, so accented letters count as
    # word characters.
    _QUESTION_WORD_RE = re.compile(
        r"(?<!\w)(?:" + "|".join(re.escape(w) for w in _QUESTION_WORDS) + r")(?!\w)"
    )

    def __init__(self, scenarios_dir: str = "scenarios"):
        """
        Initialize with auto-loading triggers from scenario JSON files.

        Args:
            scenarios_dir: Directory containing scenario JSON files
        """
        # Auto-load scenario patterns from JSON files
        self.scenario_patterns = self._load_scenario_patterns(scenarios_dir)

        # General question patterns (RAG)
        self.general_patterns = [
            # Location
            "ở đâu", "địa điểm", "location", "where",
            "chỗ nào", "tổ chức tại",

            # Time
            "mấy giờ", "khi nào", "when", "time",
            "bao giờ", "thời gian", "ngày nào",

            # Info
            "thông tin", "info", "information",
            "chi tiết", "details", "về",

            # Parking
            "đậu xe", "parking", "gửi xe",

            # Contact
            "liên hệ", "contact", "số điện thoại",

            # Events/content
            "sự kiện", "event", "đâu", "show nào",
            "line-up", "lineup", "performer"
        ]

    def _load_scenario_patterns(self, scenarios_dir: str) -> dict:
        """
        Auto-load triggers from all scenario JSON files.

        Files are read in sorted name order so that the resulting dict (and
        therefore the trigger matching order in ``classify``) is the same on
        every platform. A file that cannot be read or parsed is reported and
        skipped; it never aborts the load.

        Args:
            scenarios_dir: Directory containing scenario JSON files

        Returns:
            {"scenario_id": ["trigger1", "trigger2", ...]}
        """
        import json
        import os

        patterns = {}

        # isdir (not exists): a regular file at this path would make
        # os.listdir raise instead of producing the warning below.
        if not os.path.isdir(scenarios_dir):
            print(f"⚠ Scenarios directory not found: {scenarios_dir}")
            return patterns

        for filename in sorted(os.listdir(scenarios_dir)):
            if not filename.endswith('.json'):
                continue

            filepath = os.path.join(scenarios_dir, filename)
            # Deliberately broad: loading is best-effort, every failure is
            # logged and the remaining files are still processed.
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    scenario = json.load(f)

                scenario_id = scenario.get('scenario_id')
                raw_triggers = scenario.get('triggers', [])
                if not isinstance(raw_triggers, list):
                    raw_triggers = []

                # Drop non-string and blank triggers: an empty trigger is a
                # substring of every message and would hijack all routing.
                triggers = [
                    t for t in raw_triggers
                    if isinstance(t, str) and t.strip()
                ]

                if scenario_id and triggers:
                    patterns[scenario_id] = triggers
                    print(f"✓ Loaded triggers for: {scenario_id} ({len(triggers)} patterns)")
            except Exception as e:
                print(f"⚠ Error loading {filename}: {e}")

        return patterns

    def classify(
        self,
        message: str,
        conversation_state: Optional[Dict] = None
    ) -> str:
        """
        Classify user intent, including mid-scenario off-topic detection.

        While a scenario is active, a message is off-topic only when it looks
        like a question (contains "?" or a question word as a whole word) AND
        contains none of the valid answer keywords. Everything else continues
        the scenario.

        Args:
            message: Raw user message
            conversation_state: Optional state dict; its 'active_scenario'
                entry (if truthy) marks the user as being inside a scenario

        Returns:
            - "scenario:{scenario_id}" - Trigger new scenario
            - "scenario:continue" - Continue active scenario
            - "rag:general" - General RAG query (no active scenario)
            - "rag:with_resume" - RAG query mid-scenario (then resume)
        """
        message = message or ""
        message_lower = message.lower().strip()

        # Check if user is in active scenario
        active_scenario = conversation_state.get('active_scenario') if conversation_state else None

        if active_scenario:
            # Check if message matches valid answer
            is_valid_answer = any(
                keyword in message_lower for keyword in self._CHOICE_KEYWORDS
            )

            # Check if this is a question (off-topic). Question words must
            # match as whole words, otherwise answers like an email address
            # ("...@gmail.com" contains "ai") are misrouted to RAG.
            has_question_mark = "?" in message
            has_question_word = self._QUESTION_WORD_RE.search(message_lower) is not None

            is_off_topic = (has_question_mark or has_question_word) and not is_valid_answer

            if is_off_topic:
                print(f"🔀 Off-topic question detected: '{message}' → rag:with_resume")
                return "rag:with_resume"

            # Normal scenario continuation
            return "scenario:continue"

        # Not in scenario - check for scenario triggers
        for scenario_id, patterns in self.scenario_patterns.items():
            for pattern in patterns:
                # Skip empty patterns: "" is contained in every message.
                if pattern and pattern.lower() in message_lower:
                    return f"scenario:{scenario_id}"

        # No scenario match - general RAG query
        return "rag:general"

    def _matches_any_pattern(self, message: str, patterns: List[str]) -> bool:
        """
        Check if message matches any pattern in list.

        A pattern matches when it is a (case-sensitive) substring of the
        message, or when it occurs as a whole word ignoring case.
        """
        for pattern in patterns:
            # Simple substring match
            if pattern in message:
                return True

            # Case-insensitive word boundary check
            if re.search(rf'\b{re.escape(pattern)}\b', message, re.IGNORECASE):
                return True

        return False

    def get_scenario_type(self, intent: str) -> Optional[str]:
        """
        Extract scenario type from intent string.

        Args:
            intent: "scenario:price_inquiry" or "scenario:continue"

        Returns:
            "price_inquiry", or None when the intent is not a scenario
            intent, is "scenario:continue", or names no scenario at all
        """
        prefix = "scenario:"
        if not intent.startswith(prefix):
            return None

        scenario_type = intent[len(prefix):]
        # A bare "scenario:" carries no id; report it as None rather than "".
        if not scenario_type or scenario_type == "continue":
            return None

        return scenario_type

    def add_scenario_pattern(self, scenario_id: str, patterns: List[str]):
        """
        Dynamically add new scenario patterns.

        The given list is copied, so later additions never modify the
        caller's list.
        """
        self.scenario_patterns.setdefault(scenario_id, []).extend(patterns)

    def add_general_pattern(self, patterns: List[str]):
        """
        Dynamically add new general question patterns.
        """
        self.general_patterns.extend(patterns)