minhvtt committed
Commit cd1985a · verified · 1 Parent(s): a212f5f

Update chat_endpoint.py

Files changed (1):
  1. chat_endpoint.py +282 -261

chat_endpoint.py CHANGED
@@ -1,261 +1,282 @@
- """
- Chat endpoint with multi-turn conversation + function calling
- """
- from fastapi import HTTPException
- from datetime import datetime
- from huggingface_hub import InferenceClient
- from typing import Dict, List
- import json
-
-
- async def chat_endpoint(
-     request,  # ChatRequest
-     conversation_service,
-     tools_service,
-     advanced_rag,
-     embedding_service,
-     qdrant_service,
-     chat_history_collection,
-     hf_token
- ):
-     """
-     Multi-turn conversational chatbot with RAG + function calling
-
-     Flow:
-     1. Session management - create or load an existing session
-     2. RAG search - retrieve context if enabled
-     3. Build messages with conversation history + tools prompt
-     4. LLM generation - may trigger tool calls
-     5. Execute tools if needed
-     6. Final LLM response with tool results
-     7. Save to conversation history
-     """
-     try:
-         # ===== 1. SESSION MANAGEMENT =====
-         session_id = request.session_id
-         if not session_id:
-             # Create new session (server-side)
-             session_id = conversation_service.create_session(
-                 metadata={"user_agent": "api", "created_via": "chat_endpoint"}
-             )
-             print(f"Created new session: {session_id}")
-         else:
-             # Validate existing session
-             if not conversation_service.session_exists(session_id):
-                 raise HTTPException(
-                     status_code=404,
-                     detail=f"Session {session_id} not found. It may have expired."
-                 )
-
-         # Load conversation history
-         conversation_history = conversation_service.get_conversation_history(session_id)
-
-         # ===== 2. RAG SEARCH =====
-         context_used = []
-         rag_stats = None
-         context_text = ""
-
-         if request.use_rag:
-             if request.use_advanced_rag:
-                 # Use Advanced RAG Pipeline
-                 hf_client = None
-                 if request.hf_token or hf_token:
-                     hf_client = InferenceClient(token=request.hf_token or hf_token)
-
-                 documents, stats = advanced_rag.hybrid_rag_pipeline(
-                     query=request.message,
-                     top_k=request.top_k,
-                     score_threshold=request.score_threshold,
-                     use_reranking=request.use_reranking,
-                     use_compression=request.use_compression,
-                     use_query_expansion=request.use_query_expansion,
-                     max_context_tokens=500,
-                     hf_client=hf_client
-                 )
-
-                 # Convert to dict format
-                 context_used = [
-                     {
-                         "id": doc.id,
-                         "confidence": doc.confidence,
-                         "metadata": doc.metadata
-                     }
-                     for doc in documents
-                 ]
-                 rag_stats = stats
-
-                 # Format context
-                 context_text = advanced_rag.format_context_for_llm(documents)
-             else:
-                 # Basic RAG
-                 query_embedding = embedding_service.encode_text(request.message)
-                 results = qdrant_service.search(
-                     query_embedding=query_embedding,
-                     limit=request.top_k,
-                     score_threshold=request.score_threshold
-                 )
-                 context_used = results
-
-                 context_text = "\n\nRelevant Context:\n"
-                 for i, doc in enumerate(context_used, 1):
-                     doc_text = doc["metadata"].get("text", "")
-                     if not doc_text:
-                         doc_text = " ".join(doc["metadata"].get("texts", []))
-                     confidence = doc["confidence"]
-                     context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
-
-         # ===== 3. BUILD MESSAGES WITH TOOLS PROMPT =====
-         messages = []
-
-         # System message with RAG context + tools instruction
-         if request.use_rag and context_used:
-             if request.use_advanced_rag:
-                 base_prompt = advanced_rag.build_rag_prompt(
-                     query="",  # The query goes in the user message
-                     context=context_text,
-                     system_message=request.system_message
-                 )
-             else:
-                 base_prompt = f"""{request.system_message}
-
- {context_text}
-
- INSTRUCTIONS:
- - Use the information from the context above to answer the question.
- - Answer naturally and in a friendly tone; do not copy the context verbatim.
- - If an event is found, summarize its most important details.
- """
-         else:
-             base_prompt = request.system_message
-
-         # Add tools instruction if enabled
-         if request.enable_tools:
-             tools_prompt = tools_service.get_tools_prompt()
-             system_message_with_tools = f"{base_prompt}\n\n{tools_prompt}"
-         else:
-             system_message_with_tools = base_prompt
-
-         # Start the messages list with the system prompt
-         messages.append({"role": "system", "content": system_message_with_tools})
-
-         # Add conversation history (past turns)
-         messages.extend(conversation_history)
-
-         # Add current user message
-         messages.append({"role": "user", "content": request.message})
-
-         # ===== 4. LLM GENERATION =====
-         token = request.hf_token or hf_token
-         tool_calls_made = []
-
-         if not token:
-             response = f"""[LLM Response Placeholder]
-
- Context retrieved: {len(context_used)} documents
- User question: {request.message}
- Session: {session_id}
-
- To enable actual LLM generation:
- 1. Set HUGGINGFACE_TOKEN environment variable, OR
- 2. Pass hf_token in request body
- """
-         else:
-             try:
-                 client = InferenceClient(
-                     token=token,
-                     model="openai/gpt-oss-20b"  # Or another model
-                 )
-
-                 # First LLM call
-                 first_response = ""
-                 for msg in client.chat_completion(
-                     messages,
-                     max_tokens=request.max_tokens,
-                     stream=True,
-                     temperature=request.temperature,
-                     top_p=request.top_p,
-                 ):
-                     choices = msg.choices
-                     if choices and choices[0].delta.content:
-                         first_response += choices[0].delta.content
-
-                 # ===== 5. PARSE & EXECUTE TOOLS =====
-                 if request.enable_tools:
-                     tool_result = await tools_service.parse_and_execute(first_response)
-
-                     if tool_result:
-                         # Tool was called!
-                         tool_calls_made.append(tool_result)
-
-                         # Add tool result to messages
-                         messages.append({"role": "assistant", "content": first_response})
-                         messages.append({
-                             "role": "user",
-                             "content": f"TOOL RESULT:\n{json.dumps(tool_result['result'], ensure_ascii=False, indent=2)}\n\nUse this information to answer the user's question."
-                         })
-
-                         # Second LLM call with tool results
-                         final_response = ""
-                         for msg in client.chat_completion(
-                             messages,
-                             max_tokens=request.max_tokens,
-                             stream=True,
-                             temperature=request.temperature,
-                             top_p=request.top_p,
-                         ):
-                             choices = msg.choices
-                             if choices and choices[0].delta.content:
-                                 final_response += choices[0].delta.content
-
-                         response = final_response
-                     else:
-                         # No tool call, use first response
-                         response = first_response
-                 else:
-                     response = first_response
-
-             except Exception as e:
-                 response = f"Error generating response with LLM: {str(e)}\n\nContext was retrieved successfully, but LLM generation failed."
-
-         # ===== 6. SAVE TO CONVERSATION HISTORY =====
-         conversation_service.add_message(
-             session_id,
-             "user",
-             request.message
-         )
-         conversation_service.add_message(
-             session_id,
-             "assistant",
-             response,
-             metadata={
-                 "rag_stats": rag_stats,
-                 "tool_calls": tool_calls_made,
-                 "context_count": len(context_used)
-             }
-         )
-
-         # Also save to legacy chat_history collection
-         chat_data = {
-             "session_id": session_id,
-             "user_message": request.message,
-             "assistant_response": response,
-             "context_used": context_used,
-             "tool_calls": tool_calls_made,
-             "timestamp": datetime.utcnow()
-         }
-         chat_history_collection.insert_one(chat_data)
-
-         # ===== 7. RETURN RESPONSE =====
-         return {
-             "response": response,
-             "context_used": context_used,
-             "timestamp": datetime.utcnow().isoformat(),
-             "rag_stats": rag_stats,
-             "session_id": session_id,
-             "tool_calls": tool_calls_made if tool_calls_made else None
-         }
-
-     except HTTPException:
-         raise
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+ """
+ Chat endpoint with multi-turn conversation + function calling
+ """
+ from fastapi import HTTPException
+ from datetime import datetime
+ from huggingface_hub import InferenceClient
+ from typing import Dict, List
+ import json
+
+
+ async def chat_endpoint(
+     request,  # ChatRequest
+     conversation_service,
+     tools_service,
+     advanced_rag,
+     embedding_service,
+     qdrant_service,
+     chat_history_collection,
+     hf_token
+ ):
+     """
+     Multi-turn conversational chatbot with RAG + function calling
+
+     Flow:
+     1. Session management - create or load an existing session
+     2. RAG search - retrieve context if enabled
+     3. Build messages with conversation history + tools prompt
+     4. LLM generation - may trigger tool calls
+     5. Execute tools if needed
+     6. Final LLM response with tool results
+     7. Save to conversation history
+     """
+     try:
+         # ===== 1. SESSION MANAGEMENT =====
+         session_id = request.session_id
+         if not session_id:
+             # Create new session (server-side)
+             session_id = conversation_service.create_session(
+                 metadata={"user_agent": "api", "created_via": "chat_endpoint"}
+             )
+             print(f"Created new session: {session_id}")
+         else:
+             # Validate existing session
+             if not conversation_service.session_exists(session_id):
+                 raise HTTPException(
+                     status_code=404,
+                     detail=f"Session {session_id} not found. It may have expired."
+                 )
+
+         # Load conversation history
+         conversation_history = conversation_service.get_conversation_history(session_id)
+
+         # ===== 2. RAG SEARCH =====
+         context_used = []
+         rag_stats = None
+         context_text = ""
+
+         if request.use_rag:
+             if request.use_advanced_rag:
+                 # Use Advanced RAG Pipeline
+                 hf_client = None
+                 if request.hf_token or hf_token:
+                     hf_client = InferenceClient(token=request.hf_token or hf_token)
+
+                 documents, stats = advanced_rag.hybrid_rag_pipeline(
+                     query=request.message,
+                     top_k=request.top_k,
+                     score_threshold=request.score_threshold,
+                     use_reranking=request.use_reranking,
+                     use_compression=request.use_compression,
+                     use_query_expansion=request.use_query_expansion,
+                     max_context_tokens=500,
+                     hf_client=hf_client
+                 )
+
+                 # Convert to dict format
+                 context_used = [
+                     {
+                         "id": doc.id,
+                         "confidence": doc.confidence,
+                         "metadata": doc.metadata
+                     }
+                     for doc in documents
+                 ]
+                 rag_stats = stats
+
+                 # Format context
+                 context_text = advanced_rag.format_context_for_llm(documents)
+             else:
+                 # Basic RAG
+                 query_embedding = embedding_service.encode_text(request.message)
+                 results = qdrant_service.search(
+                     query_embedding=query_embedding,
+                     limit=request.top_k,
+                     score_threshold=request.score_threshold
+                 )
+                 context_used = results
+
+                 context_text = "\n\nRelevant Context:\n"
+                 for i, doc in enumerate(context_used, 1):
+                     doc_text = doc["metadata"].get("text", "")
+                     if not doc_text:
+                         doc_text = " ".join(doc["metadata"].get("texts", []))
+                     confidence = doc["confidence"]
+                     context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
+
+         # ===== 3. BUILD MESSAGES WITH TOOLS PROMPT =====
+         messages = []
+
+         # System message with RAG context + tools instruction
+         if request.use_rag and context_used:
+             if request.use_advanced_rag:
+                 base_prompt = advanced_rag.build_rag_prompt(
+                     query="",  # The query goes in the user message
+                     context=context_text,
+                     system_message=request.system_message
+                 )
+             else:
+                 base_prompt = f"""{request.system_message}
+
+ {context_text}
+
+ INSTRUCTIONS:
+ - Use the information from the context above to answer the question.
+ - Answer naturally and in a friendly tone; do not copy the context verbatim.
+ - If an event is found, summarize its most important details.
+ """
+         else:
+             base_prompt = request.system_message
+
+         # Add tools instruction if enabled
+         if request.enable_tools:
+             tools_prompt = tools_service.get_tools_prompt()
+             system_message_with_tools = f"{base_prompt}\n\n{tools_prompt}"
+         else:
+             system_message_with_tools = base_prompt
+
+         # Start the messages list with the system prompt
+         messages.append({"role": "system", "content": system_message_with_tools})
+
+         # Add conversation history (past turns)
+         messages.extend(conversation_history)
+
+         # Add current user message
+         messages.append({"role": "user", "content": request.message})
+
+         # ===== 4. LLM GENERATION =====
+         token = request.hf_token or hf_token
+         tool_calls_made = []
+
+         if not token:
+             response = f"""[LLM Response Placeholder]
+
+ Context retrieved: {len(context_used)} documents
+ User question: {request.message}
+ Session: {session_id}
+
+ To enable actual LLM generation:
+ 1. Set HUGGINGFACE_TOKEN environment variable, OR
+ 2. Pass hf_token in request body
+ """
+         else:
+             try:
+                 client = InferenceClient(
+                     token=token,
+                     model="openai/gpt-oss-20b"  # Or another model
+                 )
+
+                 # First LLM call
+                 first_response = ""
+                 try:
+                     for msg in client.chat_completion(
+                         messages,
+                         max_tokens=request.max_tokens,
+                         stream=True,
+                         temperature=request.temperature,
+                         top_p=request.top_p,
+                     ):
+                         choices = msg.choices
+                         if choices and choices[0].delta.content:
+                             first_response += choices[0].delta.content
+                 except Exception as e:
+                     # The HF API raises an error when the LLM returns raw JSON (a tool call),
+                     # so extract the "failed_generation" payload from the error
+                     error_str = str(e)
+                     if "tool_use_failed" in error_str and "failed_generation" in error_str:
+                         # Parse the error dict to get the actual JSON response
+                         import ast
+                         try:
+                             error_dict = ast.literal_eval(error_str)
+                             first_response = error_dict.get("failed_generation", "")
+                         except (ValueError, SyntaxError):
+                             # Fallback: extract the JSON from the string
+                             import re
+                             match = re.search(r"'failed_generation': '({.*?})'", error_str)
+                             if match:
+                                 first_response = match.group(1)
+                             else:
+                                 raise e
+                     else:
+                         raise e
+
+                 # ===== 5. PARSE & EXECUTE TOOLS =====
+                 if request.enable_tools:
+                     tool_result = await tools_service.parse_and_execute(first_response)
+
+                     if tool_result:
+                         # Tool was called!
+                         tool_calls_made.append(tool_result)
+
+                         # Add tool result to messages
+                         messages.append({"role": "assistant", "content": first_response})
+                         messages.append({
+                             "role": "user",
+                             "content": f"TOOL RESULT:\n{json.dumps(tool_result['result'], ensure_ascii=False, indent=2)}\n\nUse this information to answer the user's question."
+                         })
+
+                         # Second LLM call with tool results
+                         final_response = ""
+                         for msg in client.chat_completion(
+                             messages,
+                             max_tokens=request.max_tokens,
+                             stream=True,
+                             temperature=request.temperature,
+                             top_p=request.top_p,
+                         ):
+                             choices = msg.choices
+                             if choices and choices[0].delta.content:
+                                 final_response += choices[0].delta.content
+
+                         response = final_response
+                     else:
+                         # No tool call, use first response
+                         response = first_response
+                 else:
+                     response = first_response
+
+             except Exception as e:
+                 response = f"Error generating response with LLM: {str(e)}\n\nContext was retrieved successfully, but LLM generation failed."
+
+         # ===== 6. SAVE TO CONVERSATION HISTORY =====
+         conversation_service.add_message(
+             session_id,
+             "user",
+             request.message
+         )
+         conversation_service.add_message(
+             session_id,
+             "assistant",
+             response,
+             metadata={
+                 "rag_stats": rag_stats,
+                 "tool_calls": tool_calls_made,
+                 "context_count": len(context_used)
+             }
+         )
+
+         # Also save to legacy chat_history collection
+         chat_data = {
+             "session_id": session_id,
+             "user_message": request.message,
+             "assistant_response": response,
+             "context_used": context_used,
+             "tool_calls": tool_calls_made,
+             "timestamp": datetime.utcnow()
+         }
+         chat_history_collection.insert_one(chat_data)
+
+         # ===== 7. RETURN RESPONSE =====
+         return {
+             "response": response,
+             "context_used": context_used,
+             "timestamp": datetime.utcnow().isoformat(),
+             "rag_stats": rag_stats,
+             "session_id": session_id,
+             "tool_calls": tool_calls_made if tool_calls_made else None
+         }
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
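The substantive change in this commit is the inner try/except around the first streaming call: when the model emits a raw JSON tool call, the Hugging Face router rejects it with a "tool_use_failed" error, and the new handler salvages the generation from the error text instead of failing the whole request. Below is a minimal standalone sketch of that recovery logic, assuming (as the endpoint does) that the error's string form is the repr of a dict carrying a "failed_generation" key; the function name extract_failed_generation and the sample error are illustrative, not part of the repository.

import ast
import re


def extract_failed_generation(error_str: str) -> str | None:
    """Recover the raw JSON tool call from a 'tool_use_failed' error string.

    Returns the failed_generation payload, or None when the error is not a
    tool-use failure (the caller should re-raise in that case).
    """
    if "tool_use_failed" not in error_str or "failed_generation" not in error_str:
        return None
    try:
        # Happy path: the whole error string is the repr of a Python dict.
        error_dict = ast.literal_eval(error_str)
        return error_dict.get("failed_generation", "")
    except (ValueError, SyntaxError):
        # Fallback: pull the quoted JSON object out of the raw string.
        match = re.search(r"'failed_generation': '({.*?})'", error_str)
        return match.group(1) if match else None


# Hypothetical error shape the parser expects:
sample = "{'error': 'tool_use_failed', 'failed_generation': '{\"name\": \"search_events\", \"arguments\": {\"query\": \"concerts\"}}'}"
print(extract_failed_generation(sample))
# -> {"name": "search_events", "arguments": {"query": "concerts"}}

Keeping the recovered string in first_response means the existing tools_service.parse_and_execute path handles it exactly like a tool call the model returned normally.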
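For reference, the endpoint receives a ChatRequest object whose definition is not part of this diff. A sketch of that model, reconstructed from the fields chat_endpoint actually reads, follows; every type and default is an assumption inferred from usage, not the repository's real definition.

from typing import Optional

from pydantic import BaseModel


class ChatRequest(BaseModel):
    # All types and defaults below are guesses inferred from how
    # chat_endpoint uses each field; only the field names come from the diff.
    message: str
    session_id: Optional[str] = None    # omit to create a new server-side session
    system_message: str = "You are a helpful assistant."
    use_rag: bool = True
    use_advanced_rag: bool = False
    use_reranking: bool = False
    use_compression: bool = False
    use_query_expansion: bool = False
    top_k: int = 5
    score_threshold: float = 0.5
    enable_tools: bool = False
    hf_token: Optional[str] = None      # falls back to the server's HUGGINGFACE_TOKEN
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95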