| """ | |
| Chat endpoint với Multi-turn Conversation + Function Calling | |
| """ | |
| from fastapi import HTTPException | |
| from datetime import datetime | |
| from huggingface_hub import InferenceClient | |
| from typing import Dict, List | |
| import json | |
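
# NOTE: ChatRequest is defined elsewhere in the project and is not shown in
# this file. The model below is a hypothetical sketch inferred from the fields
# this endpoint actually reads; field names match the code, but the defaults
# are illustrative assumptions.
from typing import Optional
from pydantic import BaseModel


class ChatRequest(BaseModel):
    message: str
    session_id: Optional[str] = None   # omit to create a new session
    user_id: Optional[str] = None
    hf_token: Optional[str] = None     # overrides the server-side token
    system_message: str = "You are a helpful assistant."
    use_rag: bool = True
    use_advanced_rag: bool = False
    top_k: int = 5
    score_threshold: float = 0.5
    use_reranking: bool = False
    use_compression: bool = False
    use_query_expansion: bool = False
    enable_tools: bool = False
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
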
async def chat_endpoint(
    request,  # ChatRequest
    conversation_service,
    tools_service,
    advanced_rag,
    embedding_service,
    qdrant_service,
    chat_history_collection,
    hf_token,
):
| """ | |
| Multi-turn conversational chatbot với RAG + Function Calling | |
| Flow: | |
| 1. Session management - create hoặc load existing session | |
| 2. RAG search - retrieve context nếu enabled | |
| 3. Build messages với conversation history + tools prompt | |
| 4. LLM generation - có thể trigger tool calls | |
| 5. Execute tools nếu cần | |
| 6. Final LLM response với tool results | |
| 7. Save to conversation history | |
| """ | |
    try:
        # ===== 1. SESSION MANAGEMENT =====
        session_id = request.session_id
        if not session_id:
            # Create a new session (server-side)
            session_id = conversation_service.create_session(
                metadata={"user_agent": "api", "created_via": "chat_endpoint"},
                user_id=request.user_id,  # pass the user_id from the request
            )
            print(f"Created new session: {session_id} for user: {request.user_id or 'anonymous'}")
        else:
            # Validate the existing session
            if not conversation_service.session_exists(session_id):
                raise HTTPException(
                    status_code=404,
                    detail=f"Session {session_id} not found. It may have expired.",
                )

        # Load conversation history
        conversation_history = conversation_service.get_conversation_history(session_id)
        # ===== 2. RAG SEARCH =====
        context_used = []
        rag_stats = None
        context_text = ""
        if request.use_rag:
            if request.use_advanced_rag:
                # Use the advanced RAG pipeline
                hf_client = None
                if request.hf_token or hf_token:
                    hf_client = InferenceClient(token=request.hf_token or hf_token)
                documents, stats = advanced_rag.hybrid_rag_pipeline(
                    query=request.message,
                    top_k=request.top_k,
                    score_threshold=request.score_threshold,
                    use_reranking=request.use_reranking,
                    use_compression=request.use_compression,
                    use_query_expansion=request.use_query_expansion,
                    max_context_tokens=500,
                    hf_client=hf_client,
                )
                # Convert to dict format
                context_used = [
                    {
                        "id": doc.id,
                        "confidence": doc.confidence,
                        "metadata": doc.metadata,
                    }
                    for doc in documents
                ]
                rag_stats = stats
                # Format the context for the LLM
                context_text = advanced_rag.format_context_for_llm(documents)
            else:
                # Basic RAG
                query_embedding = embedding_service.encode_text(request.message)
                results = qdrant_service.search(
                    query_embedding=query_embedding,
                    limit=request.top_k,
                    score_threshold=request.score_threshold,
                )
                context_used = results
                context_text = "\n\nRelevant Context:\n"
                for i, doc in enumerate(context_used, 1):
                    doc_text = doc["metadata"].get("text", "")
                    if not doc_text:
                        doc_text = " ".join(doc["metadata"].get("texts", []))
                    confidence = doc["confidence"]
                    context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
        # ===== 3. BUILD MESSAGES WITH TOOLS PROMPT =====
        messages = []
        # System message with RAG context + tools instruction
        if request.use_rag and context_used:
            if request.use_advanced_rag:
                base_prompt = advanced_rag.build_rag_prompt(
                    query="",  # the query goes in the user message
                    context=context_text,
                    system_message=request.system_message,
                )
            else:
                base_prompt = f"""{request.system_message}
{context_text}
INSTRUCTIONS:
- Use the information from the context above to answer the question.
- Answer naturally and in a friendly tone; do not copy verbatim.
- If you find an event, summarize its most important details.
"""
        else:
            base_prompt = request.system_message

        # Add the tools instruction if enabled
        if request.enable_tools:
            tools_prompt = tools_service.get_tools_prompt()
            system_message_with_tools = f"{base_prompt}\n\n{tools_prompt}"
        else:
            system_message_with_tools = base_prompt

        # Start the messages list with the system prompt
        messages.append({"role": "system", "content": system_message_with_tools})
        # Add conversation history (past turns)
        messages.extend(conversation_history)
        # Add the current user message
        messages.append({"role": "user", "content": request.message})
        # ===== 4. LLM GENERATION =====
        token = request.hf_token or hf_token
        tool_calls_made = []
        if not token:
            response = f"""[LLM Response Placeholder]
Context retrieved: {len(context_used)} documents
User question: {request.message}
Session: {session_id}

To enable actual LLM generation:
1. Set the HUGGINGFACE_TOKEN environment variable, OR
2. Pass hf_token in the request body
"""
        else:
            try:
                client = InferenceClient(
                    token=token,
                    model="openai/gpt-oss-20b",  # or another chat model
                )
                # First LLM call
                first_response = ""
                try:
                    for msg in client.chat_completion(
                        messages,
                        max_tokens=request.max_tokens,
                        stream=True,
                        temperature=request.temperature,
                        top_p=request.top_p,
                    ):
                        choices = msg.choices
                        if choices and choices[0].delta.content:
                            first_response += choices[0].delta.content
                except Exception as e:
                    # The HF API raises an error when the LLM returns raw JSON
                    # (a tool call); extract the "failed_generation" payload.
                    error_str = str(e)
                    if "tool_use_failed" in error_str and "failed_generation" in error_str:
                        # Parse the error dict to get the actual JSON response
                        import ast
                        try:
                            error_dict = ast.literal_eval(error_str)
                            first_response = error_dict.get("failed_generation", "")
                        except (ValueError, SyntaxError):
                            # Fallback: extract the JSON from the string
                            import re
                            match = re.search(r"'failed_generation': '({.*?})'", error_str)
                            if match:
                                first_response = match.group(1)
                            else:
                                raise
                    else:
                        raise
                # ===== 5. PARSE & EXECUTE TOOLS =====
                if request.enable_tools:
                    tool_result = await tools_service.parse_and_execute(first_response)
                    if tool_result:
                        # A tool was called!
                        tool_calls_made.append(tool_result)
                        # Feed the tool result back into the conversation
                        messages.append({"role": "assistant", "content": first_response})
                        messages.append({
                            "role": "user",
                            "content": (
                                "TOOL RESULT:\n"
                                f"{json.dumps(tool_result['result'], ensure_ascii=False, indent=2)}"
                                "\n\nUse this information to answer the user's question."
                            ),
                        })
                        # Second LLM call with the tool results
                        final_response = ""
                        for msg in client.chat_completion(
                            messages,
                            max_tokens=request.max_tokens,
                            stream=True,
                            temperature=request.temperature,
                            top_p=request.top_p,
                        ):
                            choices = msg.choices
                            if choices and choices[0].delta.content:
                                final_response += choices[0].delta.content
                        response = final_response
                    else:
                        # No tool call; use the first response
                        response = first_response
                else:
                    response = first_response
            except Exception as e:
                response = (
                    f"Error generating response with LLM: {str(e)}\n\n"
                    "Context was retrieved successfully, but LLM generation failed."
                )
        # ===== 6. SAVE TO CONVERSATION HISTORY =====
        conversation_service.add_message(session_id, "user", request.message)
        conversation_service.add_message(
            session_id,
            "assistant",
            response,
            metadata={
                "rag_stats": rag_stats,
                "tool_calls": tool_calls_made,
                "context_count": len(context_used),
            },
        )

        # Also save to the legacy chat_history collection
        chat_data = {
            "session_id": session_id,
            "user_message": request.message,
            "assistant_response": response,
            "context_used": context_used,
            "tool_calls": tool_calls_made,
            "timestamp": datetime.utcnow(),
        }
        chat_history_collection.insert_one(chat_data)
        # ===== 7. RETURN RESPONSE =====
        return {
            "response": response,
            "context_used": context_used,
            "timestamp": datetime.utcnow().isoformat(),
            "rag_stats": rag_stats,
            "session_id": session_id,
            "tool_calls": tool_calls_made if tool_calls_made else None,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
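
# Usage sketch (hypothetical wiring): how this endpoint might be registered on
# a FastAPI app. The service instances below are assumptions -- substitute your
# own initialized objects for the parameters listed in the function signature.
#
# import os
# from fastapi import FastAPI
#
# app = FastAPI()
#
# @app.post("/chat")
# async def chat(request: ChatRequest):
#     return await chat_endpoint(
#         request,
#         conversation_service=conversation_service,
#         tools_service=tools_service,
#         advanced_rag=advanced_rag,
#         embedding_service=embedding_service,
#         qdrant_service=qdrant_service,
#         chat_history_collection=chat_history_collection,
#         hf_token=os.getenv("HUGGINGFACE_TOKEN"),
#     )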