Spaces:
Running
Running
| """ | |
| Conversation Service for Multi-turn Chat | |
| Server-side session management | |
| """ | |
| from typing import List, Dict, Optional | |
| from datetime import datetime | |
| from pymongo.collection import Collection | |
| import uuid | |
class ConversationService:
    """
    Manages multi-turn conversation history with server-side sessions.

    Each session is one document in the given MongoDB collection, keyed by
    ``session_id`` and holding: ``user_id``, a ``messages`` array capped at
    ``max_history`` entries on every write, ``scenario_state``, free-form
    ``metadata`` and ``created_at`` / ``updated_at`` timestamps.
    """

    # expireAfterSeconds for the TTL index on ``updated_at``: sessions not
    # touched for this long are auto-deleted by MongoDB (7 days).
    SESSION_TTL_SECONDS = 7 * 24 * 60 * 60

    # Fields returned by the session-summary queries.
    _SUMMARY_PROJECTION = {
        "_id": 0,
        "session_id": 1,
        "user_id": 1,
        "created_at": 1,
        "updated_at": 1,
        "metadata": 1,
    }

    def __init__(self, mongo_collection: Collection, max_history: int = 10):
        """
        Args:
            mongo_collection: MongoDB collection for storing conversations
            max_history: Maximum number of messages kept per session
                (sliding window enforced on every write)
        """
        self.collection = mongo_collection
        self.max_history = max_history
        self._ensure_indexes()

    def _ensure_indexes(self):
        """
        Create the indexes this service relies on.

        Each index is attempted independently so that a failure on one
        (e.g. an option conflict with an existing index) does not prevent
        the others from being created. Failures are printed, not raised,
        keeping start-up best-effort.
        """
        index_specs = [
            ("session_id", {"unique": True}),
            ("user_id", {}),  # supports filtering sessions by user
            # Auto-delete sessions after SESSION_TTL_SECONDS without updates
            ("updated_at", {"expireAfterSeconds": self.SESSION_TTL_SECONDS}),
        ]
        all_ok = True
        for field, options in index_specs:
            try:
                self.collection.create_index(field, **options)
            except Exception as e:
                all_ok = False
                print(f"Conversation indexes already exist or error: {e}")
        if all_ok:
            print("✓ Conversation indexes created")

    @staticmethod
    def _insert_defaults(now: datetime, *exclude: str) -> Dict:
        """
        Build the ``$setOnInsert`` payload for a session created by an upsert.

        ``exclude`` names fields that the calling update already sets through
        another operator; MongoDB rejects an update in which two operators
        touch the same path, so those fields must be left out here.
        """
        defaults = {
            "user_id": None,
            "messages": [],
            "scenario_state": None,
            "metadata": {},
            "created_at": now,
        }
        for field in exclude:
            defaults.pop(field)
        return defaults

    def create_session(self, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
        """
        Create a new conversation session.

        Args:
            metadata: Additional metadata stored with the session
            user_id: User identifier (optional)
        Returns:
            session_id (UUID4 string)
        """
        session_id = str(uuid.uuid4())
        now = datetime.utcnow()  # one timestamp so created_at == updated_at
        self.collection.insert_one({
            "session_id": session_id,
            "user_id": user_id,
            "messages": [],
            "scenario_state": None,
            "metadata": metadata or {},
            "created_at": now,
            "updated_at": now,
        })
        return session_id

    def add_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict] = None
    ):
        """
        Append a message to the conversation history.

        The session is created if it does not exist yet (upsert); in that
        case it receives the same default fields as ``create_session``.
        Only the last ``max_history`` messages are kept.

        Args:
            session_id: Session identifier
            role: "user" or "assistant"
            content: Message text
            metadata: Additional info (rag_stats, tool_calls, etc.)
        """
        now = datetime.utcnow()
        message = {
            "role": role,
            "content": content,
            "timestamp": now.isoformat(),
            "metadata": metadata or {},
        }
        self.collection.update_one(
            {"session_id": session_id},
            {
                "$push": {
                    "messages": {
                        "$each": [message],
                        "$slice": -self.max_history,  # keep only the last N
                    }
                },
                "$set": {"updated_at": now},
                # "messages" is excluded: it is already written by $push
                "$setOnInsert": self._insert_defaults(now, "messages"),
            },
            upsert=True,
        )

    def get_conversation_history(
        self,
        session_id: str,
        limit: Optional[int] = None,
        include_metadata: bool = False
    ) -> List[Dict]:
        """
        Get conversation messages for LLM context.

        Args:
            session_id: Session identifier
            limit: Number of most recent messages to return; a falsy value
                (None or 0) means ``max_history``. Stored history is already
                capped at ``max_history`` on write, so larger values cannot
                return more than that. A negative value returns ``[]``.
            include_metadata: Return the stored message dicts unchanged
                (with timestamp and metadata) instead of role/content pairs
        Returns:
            List of messages, e.g. [{"role": "user", "content": "..."}, ...];
            ``[]`` if the session does not exist.
        """
        session = self.collection.find_one(
            {"session_id": session_id}, {"_id": 0, "messages": 1}
        )
        if not session:
            return []
        window = limit or self.max_history
        if window <= 0:
            return []
        messages = session.get("messages", [])[-window:]
        if include_metadata:
            return messages
        return [{"role": msg["role"], "content": msg["content"]} for msg in messages]

    def get_session_info(self, session_id: str) -> Optional[Dict]:
        """
        Get session metadata.

        Returns:
            The session's summary fields (see ``_SUMMARY_PROJECTION``), or
            None if the session does not exist.
        """
        return self.collection.find_one(
            {"session_id": session_id}, dict(self._SUMMARY_PROJECTION)
        )

    def clear_session(self, session_id: str) -> bool:
        """
        Delete a session document together with its history.

        Returns:
            True if a session was deleted, False if it did not exist.
        """
        result = self.collection.delete_one({"session_id": session_id})
        return result.deleted_count > 0

    def session_exists(self, session_id: str) -> bool:
        """Return True if a session with this id exists."""
        # limit=1: stop counting at the first match (session_id is unique anyway)
        return self.collection.count_documents({"session_id": session_id}, limit=1) > 0

    def get_last_user_message(self, session_id: str) -> Optional[str]:
        """
        Get the content of the most recent message with role "user".

        Useful for context extraction. Returns None if the session does not
        exist or holds no user message.
        """
        session = self.collection.find_one(
            {"session_id": session_id}, {"_id": 0, "messages": 1}
        )
        if not session:
            return None
        # Walk backwards to find the latest user message.
        return next(
            (
                msg["content"]
                for msg in reversed(session.get("messages", []))
                if msg.get("role") == "user"
            ),
            None,
        )

    def list_sessions(
        self,
        limit: int = 50,
        skip: int = 0,
        sort_by: str = "updated_at",
        descending: bool = True,
        user_id: Optional[str] = None
    ) -> List[Dict]:
        """
        List conversation sessions as summaries.

        Args:
            limit: Maximum number of sessions to return
            skip: Number of sessions to skip (for pagination)
            sort_by: Field to sort by (created_at, updated_at)
            descending: Sort in descending order
            user_id: Only return sessions of this user (optional)
        Returns:
            List of dicts with session_id, user_id, created_at, updated_at,
            message_count and metadata. Timestamps may be None for documents
            that lack them.
        """
        query: Dict = {"user_id": user_id} if user_id else {}
        # Fetch messages in the same query so counting needs no extra round-trip.
        projection = {**self._SUMMARY_PROJECTION, "messages": 1}
        cursor = (
            self.collection.find(query, projection)
            .sort(sort_by, -1 if descending else 1)
            .skip(skip)
            .limit(limit)
        )
        return [
            {
                "session_id": session["session_id"],
                "user_id": session.get("user_id"),
                "created_at": session.get("created_at"),
                "updated_at": session.get("updated_at"),
                "message_count": len(session.get("messages") or []),
                "metadata": session.get("metadata", {}),
            }
            for session in cursor
        ]

    def count_sessions(self, user_id: Optional[str] = None) -> int:
        """
        Get the total number of sessions.

        Args:
            user_id: Only count sessions of this user (optional)
        """
        query: Dict = {"user_id": user_id} if user_id else {}
        return self.collection.count_documents(query)

    # ===== Scenario State Management =====

    def get_scenario_state(self, session_id: str) -> Optional[Dict]:
        """
        Get the current scenario state for a session.

        Returns:
            The stored state dict, e.g.
            {
                "active_scenario": "price_inquiry",
                "scenario_step": 3,
                "scenario_data": {...},
                "last_activity": "..."
            }
            or None if the session does not exist or has no active scenario.
        """
        session = self.collection.find_one(
            {"session_id": session_id}, {"_id": 0, "scenario_state": 1}
        )
        if not session:
            return None
        return session.get("scenario_state")

    def set_scenario_state(self, session_id: str, state: Dict):
        """
        Set the scenario state for a session, creating the session if needed.

        Args:
            session_id: Session ID
            state: Scenario state dict
        """
        now = datetime.utcnow()
        self.collection.update_one(
            {"session_id": session_id},
            {
                "$set": {"scenario_state": state, "updated_at": now},
                # "scenario_state" is excluded: it is already written by $set
                "$setOnInsert": self._insert_defaults(now, "scenario_state"),
            },
            upsert=True,
        )

    def clear_scenario(self, session_id: str):
        """
        Clear the scenario state (end the scenario).

        Does nothing if the session does not exist (no upsert).
        """
        self.collection.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    "scenario_state": None,
                    "updated_at": datetime.utcnow(),
                }
            },
        )