final_agent_course / utils /state_manager.py
tuan3335's picture
structure code
92d2175
"""
State Manager - Quản lý trạng thái và context của agent
"""
import json
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
@dataclass
class ToolResult:
    """Outcome of a single tool invocation.

    Attributes:
        tool_name: Name of the tool that produced this result.
        success: Whether the tool call succeeded.
        result: Payload returned by the tool (any type).
        error_message: Optional error description.
        execution_time: Optional duration of the call.
        timestamp: ISO-format creation time; filled in automatically
            when left empty.
    """
    tool_name: str
    success: bool
    result: Any
    error_message: Optional[str] = None
    execution_time: Optional[float] = None
    timestamp: str = ""

    def __post_init__(self):
        # An empty timestamp means "not supplied": stamp with the current time.
        self.timestamp = self.timestamp or datetime.now().isoformat()
@dataclass
class TaskContext:
    """Context describing a single task.

    Attributes:
        task_id: Identifier of the task.
        question: The question text for the task.
        original_question: The original question (before it is reversed).
        question_type: Category label; one of youtube, image, audio, wiki,
            text, file, math, or "unknown" when not yet classified.
        has_file_attachment: Whether the task comes with a file attachment.
        detected_urls: URLs found in the question. ``None`` (the default) is
            replaced by a fresh empty list for each instance.
        processed_question: The question after processing.
    """
    task_id: str
    question: str
    original_question: str  # Original question (before it is reversed)
    question_type: str = "unknown"  # youtube, image, audio, wiki, text, file, math
    has_file_attachment: bool = False
    # None is a legal input here, so the annotation must say so.
    detected_urls: Optional[List[str]] = None
    processed_question: str = ""  # Question after processing

    def __post_init__(self):
        # Give every instance its own list rather than sharing a mutable default.
        if self.detected_urls is None:
            self.detected_urls = []
class AgentState:
    """Manage the agent's state: current task, tool results, cache, history.

    Attributes:
        current_task: Context of the task in progress, or None.
        tool_results: Results recorded for the current task, in call order.
        conversation_history: Conversation turns; kept across tasks.
        cached_data: Results of selected successful tools under fixed keys.
        session_id: Creation time of this object as ``YYYYMMDD_HHMMSS``.
    """

    # Tools whose successful result is cached, mapped to the cache key used.
    _CACHE_KEYS = {
        "youtube_tool": "youtube_content",
        "wiki_search": "wiki_content",
        "image_ocr": "image_text",
        "audio_transcript": "audio_text",
    }

    def __init__(self):
        self.current_task: Optional["TaskContext"] = None
        self.tool_results: List["ToolResult"] = []
        self.conversation_history: List[Dict[str, Any]] = []
        self.cached_data: Dict[str, Any] = {}
        self.session_id: str = datetime.now().strftime("%Y%m%d_%H%M%S")

    def start_new_task(self, task_id: str, question: str) -> "TaskContext":
        """Start a new task and return its context.

        The question is stored both as ``question`` and ``original_question``.
        Tool results from the previous task are discarded; the cache and the
        conversation history are left as they are.
        """
        self.current_task = TaskContext(
            task_id=task_id,
            question=question,
            original_question=question
        )
        self.tool_results = []  # Reset tool results for the new task
        return self.current_task

    def update_task_context(self, **kwargs) -> None:
        """Update fields of the current task's context.

        Keys that are not existing attributes of the task are ignored, and
        nothing happens when there is no current task.
        """
        if self.current_task:
            for key, value in kwargs.items():
                if hasattr(self.current_task, key):
                    setattr(self.current_task, key, value)

    def add_tool_result(self, tool_result: "ToolResult") -> None:
        """Record a tool result and cache it if it is a known, successful tool."""
        self.tool_results.append(tool_result)
        # Cache a few important results for later lookup
        if tool_result.success:
            cache_key = self._CACHE_KEYS.get(tool_result.tool_name)
            if cache_key is not None:
                self.cached_data[cache_key] = tool_result.result

    def get_tool_results(self, tool_name: Optional[str] = None) -> List["ToolResult"]:
        """Return recorded tool results, optionally only those of ``tool_name``.

        Without a tool name the internal list itself is returned.
        """
        if tool_name:
            return [r for r in self.tool_results if r.tool_name == tool_name]
        return self.tool_results

    def has_successful_tool(self, tool_name: str) -> bool:
        """Return True if ``tool_name`` has at least one successful result."""
        return any(r.tool_name == tool_name and r.success for r in self.tool_results)

    def get_cached_data(self, key: str) -> Any:
        """Return the cached value for ``key``, or None when it is absent."""
        return self.cached_data.get(key)

    def add_conversation_turn(self, role: str, content: str,
                              metadata: Optional[Dict[str, Any]] = None) -> None:
        """Append a conversation turn stamped with the current time.

        Args:
            role: Speaker of the turn.
            content: Text of the turn.
            metadata: Optional extra data; stored as ``{}`` when omitted.
        """
        turn = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
            "metadata": metadata or {}
        }
        self.conversation_history.append(turn)

    def get_conversation_context(self, max_turns: int = 5) -> List[Dict[str, Any]]:
        """Return the most recent ``max_turns`` conversation turns.

        A non-positive ``max_turns`` yields an empty list. Without this guard
        ``history[-0:]`` would return the whole history.
        """
        if max_turns <= 0:
            return []
        return self.conversation_history[-max_turns:]

    def generate_context_summary(self) -> str:
        """Build a newline-separated, human-readable summary for the AI.

        Returns "No active task" when no task has been started. Tool names
        are listed once each, in the order they were first recorded.
        """
        task = self.current_task
        if not task:
            return "No active task"

        summary_parts = [
            f"Current Task: {task.task_id}",
            f"Question Type: {task.question_type}",
            f"Original Question: {task.original_question}",
        ]
        # Skip the line while the processed question is still unset ("");
        # otherwise an empty "Processed Question: " entry would be emitted.
        if task.processed_question and task.processed_question != task.original_question:
            summary_parts.append(f"Processed Question: {task.processed_question}")

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # summary is deterministic (a set would not guarantee any order).
        successful_tools = dict.fromkeys(
            r.tool_name for r in self.tool_results if r.success
        )
        if successful_tools:
            summary_parts.append(f"Successful Tools: {', '.join(successful_tools)}")

        failed_tools = dict.fromkeys(
            r.tool_name for r in self.tool_results if not r.success
        )
        if failed_tools:
            summary_parts.append(f"Failed Tools: {', '.join(failed_tools)}")

        # Available cached data
        if self.cached_data:
            summary_parts.append(f"Available Data: {', '.join(self.cached_data)}")

        return "\n".join(summary_parts)

    def export_state(self) -> Dict[str, Any]:
        """Export the state as a plain dict for debugging or logging.

        Only the keys of the cache are exported, not the cached values.
        """
        return {
            "session_id": self.session_id,
            "current_task": asdict(self.current_task) if self.current_task else None,
            "tool_results": [asdict(r) for r in self.tool_results],
            "conversation_history": self.conversation_history,
            "cached_data_keys": list(self.cached_data.keys())
        }

    def clear_state(self):
        """Clear the task, tool results and cache (ready for a new task)."""
        self.current_task = None
        self.tool_results = []
        self.cached_data = {}
        # conversation_history is kept on purpose to maintain context
# Singleton instance: created once at import time and handed out by
# get_agent_state()
_agent_state = AgentState()
def get_agent_state() -> AgentState:
    """Return the module-wide AgentState instance currently in use."""
    return _agent_state
def reset_agent_state():
    """Replace the module-wide agent state with a brand-new AgentState.

    The global name is rebound to a new object, so any reference obtained
    earlier keeps pointing at the previous instance.
    """
    global _agent_state
    _agent_state = AgentState()
# Utility functions
def analyze_question_type(question: str) -> str:
    """Classify a question into a coarse type using substring matching.

    All checks are case-insensitive and are evaluated in a fixed order; the
    first match wins:

    1. ``"youtube"``  - contains "youtube.com" or "youtu.be"
    2. ``"file_url"`` - contains "http" and one of .jpg/.png/.pdf/.doc
    3. ``"image"``, ``"audio"``, ``"wiki"``, ``"math"`` - keyword groups
    4. ``"file"``     - contains "file:" or "attachment"
    5. ``"text"``     - fallback

    Args:
        question: The question text. A dict-like object is also accepted;
            its ``"question"`` entry (or else ``"content"``) is used. Any
            other non-string value is converted with ``str()``.

    Returns:
        One of the type labels listed above.
    """
    # Type check and convert if needed
    if not isinstance(question, str):
        if hasattr(question, 'get'):  # Is dict-like
            question = str(question.get('question', question.get('content', str(question))))
        else:
            question = str(question)

    # Every check below works on the lower-cased text so that
    # "YouTube.com", "HTTP://..." or ".PNG" are recognised as well.
    text = question.lower()

    # Check for URLs/file attachments
    if "youtube.com" in text or "youtu.be" in text:
        return "youtube"
    if "http" in text and any(ext in text for ext in ('.jpg', '.png', '.pdf', '.doc')):
        return "file_url"

    # Keyword groups, checked in this order.
    keyword_rules = (
        ("image", ('image', 'picture', 'photo', 'diagram')),
        ("audio", ('audio', 'sound', 'voice', 'music')),
        ("wiki", ('who is', 'what is', 'when was', 'where is')),
        ("math", ('calculate', 'solve', 'math', 'equation')),
    )
    for label, keywords in keyword_rules:
        if any(word in text for word in keywords):
            return label

    if "file:" in text or "attachment" in text:
        return "file"
    return "text"
def _strip_trailing_punctuation(url: str) -> str:
    """Remove sentence punctuation and unmatched closing brackets from the end."""
    closers = {')': '(', ']': '[', '}': '{'}
    while url:
        last = url[-1]
        if last in ".,;:!?'":
            url = url[:-1]
        elif last in closers and url.count(closers[last]) < url.count(last):
            # A closer with no matching opener inside the URL belongs to the
            # surrounding text, e.g. "(see https://example.com)".
            url = url[:-1]
        else:
            break
    return url


def detect_urls_in_question(question: str) -> List[str]:
    """Find the URLs mentioned in a question.

    Matches ``http://``/``https://`` URLs (scheme matched case-insensitively)
    and bare ``www.`` addresses. Punctuation that merely ends the surrounding
    sentence (``. , ; : ! ? '``) and unmatched closing brackets are trimmed
    from the end of each match, so "Visit https://example.com/a." yields
    "https://example.com/a".

    Args:
        question: The question text. A dict-like object is also accepted;
            its ``"question"`` entry (or else ``"content"``) is used. Any
            other non-string value is converted with ``str()``.

    Returns:
        The URLs in order of appearance; an empty list when there are none.
    """
    import re

    # Type check and convert if needed
    if not isinstance(question, str):
        if hasattr(question, 'get'):  # Is dict-like
            question = str(question.get('question', question.get('content', str(question))))
        else:
            question = str(question)

    url_pattern = r'https?://[^\s<>"]+|www\.[^\s<>"]+'
    candidates = re.findall(url_pattern, question, flags=re.IGNORECASE)
    trimmed = (_strip_trailing_punctuation(candidate) for candidate in candidates)
    return [url for url in trimmed if url]
# Manual smoke test, run only when the module is executed as a script
if __name__ == "__main__":
    # Exercise the shared state manager
    demo_state = get_agent_state()

    # Open a task and confirm it was registered
    demo_task = demo_state.start_new_task("test_001", "Who is Marie Curie?")
    print("Task created:", demo_task.task_id)

    # Record one successful wiki lookup
    demo_state.add_tool_result(
        ToolResult(
            tool_name="wiki_search",
            success=True,
            result={"title": "Marie Curie", "summary": "Polish physicist..."},
        )
    )

    print("State summary:")
    print(demo_state.generate_context_summary())