Spaces:
Sleeping
Sleeping
| """ | |
| State Manager - Quản lý trạng thái và context của agent | |
| """ | |
| import json | |
| from typing import Dict, Any, List, Optional | |
| from dataclasses import dataclass, asdict | |
| from datetime import datetime | |
@dataclass
class ToolResult:
    """Outcome of a single tool invocation.

    Declared as a dataclass: callers construct it with keyword arguments
    (e.g. ``ToolResult(tool_name=..., success=..., result=...)``) and
    ``AgentState.export_state`` serializes it with ``dataclasses.asdict``,
    both of which need the generated ``__init__`` and field metadata.
    """

    tool_name: str
    success: bool
    result: Any  # Tool-specific payload; its shape depends on the tool.
    error_message: Optional[str] = None
    execution_time: Optional[float] = None  # presumably seconds — TODO confirm with callers
    timestamp: str = ""  # ISO-8601 local time; filled in automatically when left empty.

    def __post_init__(self):
        # Stamp the creation time unless the caller supplied one explicitly.
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()
@dataclass
class TaskContext:
    """Context describing the task the agent is currently working on."""

    task_id: str
    question: str
    original_question: str  # Question as first received (before any reversal/processing).
    # Coarse category, e.g. youtube, image, audio, wiki, text, file, math; "unknown" until set.
    question_type: str = "unknown"
    has_file_attachment: bool = False
    # None (the default) is normalized to a fresh empty list per instance.
    detected_urls: Optional[List[str]] = None
    processed_question: str = ""  # Question after processing; empty until set.

    def __post_init__(self):
        # Avoid a shared mutable default: give each instance its own list.
        if self.detected_urls is None:
            self.detected_urls = []
class AgentState:
    """Mutable state for one agent session.

    Holds the task currently being worked on, the tool results recorded for
    it, a small cache of selected successful tool outputs, and the
    conversation history (which is kept across tasks).
    """

    # Tool name -> cached_data key under which its latest successful result
    # is stored by add_tool_result.
    _CACHE_KEY_BY_TOOL: Dict[str, str] = {
        "youtube_tool": "youtube_content",
        "wiki_search": "wiki_content",
        "image_ocr": "image_text",
        "audio_transcript": "audio_text",
    }

    def __init__(self):
        self.current_task: Optional["TaskContext"] = None
        self.tool_results: List["ToolResult"] = []
        self.conversation_history: List[Dict[str, Any]] = []
        self.cached_data: Dict[str, Any] = {}
        # Identifies the session by its local creation time (second resolution).
        self.session_id: str = datetime.now().strftime("%Y%m%d_%H%M%S")

    def start_new_task(self, task_id: str, question: str) -> "TaskContext":
        """Begin a new task and make it the current one.

        Tool results and cached tool outputs from the previous task are
        discarded so they cannot leak into the new task's context (matching
        clear_state); the conversation history is kept.

        Returns:
            The newly created task context.
        """
        self.current_task = TaskContext(
            task_id=task_id,
            question=question,
            original_question=question,
        )
        self.tool_results = []
        self.cached_data = {}
        return self.current_task

    def update_task_context(self, **kwargs) -> None:
        """Set attributes on the current task.

        Keys that are not existing attributes of the task are silently
        ignored, and the call is a no-op when no task is active.
        """
        if self.current_task:
            for key, value in kwargs.items():
                if hasattr(self.current_task, key):
                    setattr(self.current_task, key, value)

    def add_tool_result(self, tool_result: "ToolResult") -> None:
        """Record a tool result and cache its output for tracked tools.

        Only successful results of the tools listed in _CACHE_KEY_BY_TOOL are
        cached; a later success overwrites the earlier cached value.
        """
        self.tool_results.append(tool_result)
        if tool_result.success:
            cache_key = self._CACHE_KEY_BY_TOOL.get(tool_result.tool_name)
            if cache_key is not None:
                self.cached_data[cache_key] = tool_result.result

    def get_tool_results(self, tool_name: Optional[str] = None) -> List["ToolResult"]:
        """Return recorded tool results, optionally only those of *tool_name*.

        Always returns a new list, so callers cannot mutate the internal
        history through the return value.
        """
        if tool_name:
            return [r for r in self.tool_results if r.tool_name == tool_name]
        return list(self.tool_results)

    def has_successful_tool(self, tool_name: str) -> bool:
        """Return True if *tool_name* has at least one successful result."""
        return any(r.tool_name == tool_name and r.success for r in self.tool_results)

    def get_cached_data(self, key: str) -> Any:
        """Return the cached value stored under *key*, or None if absent."""
        return self.cached_data.get(key)

    def add_conversation_turn(self, role: str, content: str,
                              metadata: Optional[Dict[str, Any]] = None):
        """Append a timestamped conversation turn to the history.

        Args:
            role: Speaker label for the turn.
            content: Text of the turn.
            metadata: Optional extra data; stored as an empty dict when omitted.
        """
        turn = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
            "metadata": metadata or {},
        }
        self.conversation_history.append(turn)

    def get_conversation_context(self, max_turns: int = 5) -> List[Dict[str, Any]]:
        """Return up to the *max_turns* most recent conversation turns.

        A non-positive *max_turns* yields an empty list; a bare ``[-0:]``
        slice would otherwise return the entire history.
        """
        if max_turns <= 0:
            return []
        return self.conversation_history[-max_turns:]

    def generate_context_summary(self) -> str:
        """Build a newline-separated, human-readable summary for the model.

        Returns "No active task" when no task is set. Tool names are listed
        once each, in the order they were first recorded, so the output is
        deterministic.
        """
        if not self.current_task:
            return "No active task"

        task = self.current_task
        lines = [
            f"Current Task: {task.task_id}",
            f"Question Type: {task.question_type}",
            f"Original Question: {task.original_question}",
        ]
        # Only mention the processed question when one was set and it differs.
        if task.processed_question and task.processed_question != task.original_question:
            lines.append(f"Processed Question: {task.processed_question}")

        # dict.fromkeys de-duplicates while keeping first-seen order.
        succeeded = list(dict.fromkeys(r.tool_name for r in self.tool_results if r.success))
        if succeeded:
            lines.append(f"Successful Tools: {', '.join(succeeded)}")
        failed = list(dict.fromkeys(r.tool_name for r in self.tool_results if not r.success))
        if failed:
            lines.append(f"Failed Tools: {', '.join(failed)}")

        if self.cached_data:
            lines.append(f"Available Data: {', '.join(self.cached_data.keys())}")

        return "\n".join(lines)

    def export_state(self) -> Dict[str, Any]:
        """Export a dict snapshot of the state for debugging or logging.

        Cached values are omitted; only their keys are included.
        """
        return {
            "session_id": self.session_id,
            "current_task": asdict(self.current_task) if self.current_task else None,
            "tool_results": [asdict(r) for r in self.tool_results],
            "conversation_history": self.conversation_history,
            "cached_data_keys": list(self.cached_data.keys()),
        }

    def clear_state(self):
        """Drop the current task, tool results, and cached data.

        The conversation history is deliberately kept to preserve context.
        """
        self.current_task = None
        self.tool_results = []
        self.cached_data = {}
# Module-level shared AgentState, handed out by get_agent_state().
_agent_state = AgentState()


def get_agent_state() -> AgentState:
    """Return the shared, module-level AgentState instance."""
    return _agent_state


def reset_agent_state():
    """Replace the shared AgentState with a brand-new instance.

    References obtained earlier from get_agent_state() keep pointing at the
    previous instance; callers should fetch the state again after a reset.
    """
    global _agent_state
    fresh_state = AgentState()
    _agent_state = fresh_state
# Utility functions
def _mentions_any(text: str, keywords) -> bool:
    """Return True if any keyword occurs in *text* starting at a word boundary.

    Anchoring only the start of the match avoids false positives inside
    longer words (e.g. "voice" in "invoice", "math" in "aftermath") while
    still matching inflected forms such as "images" or "calculated".
    """
    import re

    pattern = r"\b(?:" + "|".join(re.escape(k) for k in keywords) + r")"
    return re.search(pattern, text) is not None


def analyze_question_type(question: str) -> str:
    """Classify a question into a coarse type using keyword heuristics.

    Checks run in priority order and the first match wins: "youtube",
    "file_url", "image", "audio", "wiki", "math", "file"; otherwise "text".

    Args:
        question: The question text. Non-string input is coerced: dict-like
            objects contribute their "question" (falling back to "content")
            entry, anything else is passed through str().

    Returns:
        One of the type labels listed above.
    """
    if not isinstance(question, str):
        if hasattr(question, 'get'):  # Dict-like payload.
            question = str(question.get('question', question.get('content', str(question))))
        else:
            question = str(question)

    # Domains and file extensions are case-insensitive, so every check works
    # on the lowercased text (e.g. "YouTube.com", "scan.PNG").
    text = question.lower()

    if "youtube.com" in text or "youtu.be" in text:
        return "youtube"
    if "http" in text and any(ext in text for ext in ('.jpg', '.png', '.pdf', '.doc')):
        return "file_url"
    if _mentions_any(text, ('image', 'picture', 'photo', 'diagram')):
        return "image"
    if _mentions_any(text, ('audio', 'sound', 'voice', 'music')):
        return "audio"
    if _mentions_any(text, ('who is', 'what is', 'when was', 'where is')):
        return "wiki"
    if _mentions_any(text, ('calculate', 'solve', 'math', 'equation')):
        return "math"
    if "file:" in text or "attachment" in text:
        return "file"
    return "text"
def _strip_url_trailing_punctuation(url: str) -> str:
    """Drop trailing characters that end the sentence rather than the URL.

    Sentence punctuation is always removed; a closing ")" or "]" is removed
    only while it is unbalanced, so links like ".../Mercury_(planet)" survive.
    """
    openers = {")": "(", "]": "["}
    while url:
        last = url[-1]
        if last in ".,;:!?'":
            url = url[:-1]
        elif last in openers and url.count(openers[last]) < url.count(last):
            url = url[:-1]
        else:
            break
    return url


def detect_urls_in_question(question: str) -> List[str]:
    """Extract URLs from a question, in order of appearance.

    Matches "http://", "https://" and bare "www." links, then trims trailing
    sentence punctuation that the pattern would otherwise capture.

    Args:
        question: The question text. Non-string input is coerced: dict-like
            objects contribute their "question" (falling back to "content")
            entry, anything else is passed through str().

    Returns:
        The list of detected URLs (empty if none).
    """
    import re

    if not isinstance(question, str):
        if hasattr(question, 'get'):  # Dict-like payload.
            question = str(question.get('question', question.get('content', str(question))))
        else:
            question = str(question)

    url_pattern = r'https?://[^\s<>"]+|www\.[^\s<>"]+'
    return [_strip_url_trailing_punctuation(u) for u in re.findall(url_pattern, question)]
# Manual smoke test: run this module directly to exercise the state manager.
def _run_demo() -> None:
    """Create a task, record one wiki_search result, and print the summary."""
    demo_state = get_agent_state()

    demo_task = demo_state.start_new_task("test_001", "Who is Marie Curie?")
    print("Task created:", demo_task.task_id)

    demo_state.add_tool_result(
        ToolResult(
            tool_name="wiki_search",
            success=True,
            result={"title": "Marie Curie", "summary": "Polish physicist..."},
        )
    )

    print("State summary:")
    print(demo_state.generate_context_summary())


if __name__ == "__main__":
    _run_demo()