Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, status, Depends, BackgroundTasks, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from src.utils.logger import logger | |
| from src.agents.role_play.func import create_agents | |
| from pydantic import BaseModel, Field | |
| from typing import List, Dict, Any, Optional | |
| from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id | |
| import json | |
| import os | |
| import uuid | |
| from datetime import datetime | |
# Router for this module's endpoints. The "/ai" prefix and "AI" tag apply to
# every route registered on it.
router = APIRouter(prefix="/ai", tags=["AI"])
class RoleplayRequest(BaseModel):
    """Request body for sending one user message to the roleplay agent."""

    # Text forwarded to the agent as the single incoming message.
    query: str = Field(..., description="User's query for the AI agent")
    # Used both as the agent's thread_id and as the key for session bookkeeping.
    session_id: str = Field(
        ..., description="Session ID for tracking user interactions"
    )
    # Passed unchanged to create_agents(); its expected keys are not visible
    # in this file.
    scenario: Dict[str, Any] = Field(..., description="The scenario for the roleplay")
class SessionRequest(BaseModel):
    """Request body that identifies a single conversation session."""

    session_id: str = Field(..., description="Session ID to perform operations on")
class CreateSessionRequest(BaseModel):
    """Request body for creating a new session."""

    # Display name stored on the new session record.
    name: str = Field(..., description="Name for the new session")
class UpdateSessionRequest(BaseModel):
    """Request body for renaming an existing session."""

    # NOTE(review): update_session() matches on its own session_id argument
    # and does not read this field -- confirm whether both are still needed.
    session_id: str = Field(..., description="Session ID to update")
    name: str = Field(..., description="New name for the session")
# Session management helper functions
# Sessions are persisted as a JSON document in this file. The path is
# relative, so it resolves against the process's current working directory.
SESSIONS_FILE = "sessions.json"
def load_sessions() -> List[Dict[str, Any]]:
    """Load the persisted sessions from ``SESSIONS_FILE``.

    Returns:
        The list stored in the file, or an empty list when the file does not
        exist, cannot be read or parsed, or does not contain a JSON list.
        Failures are logged and never raised.
    """
    try:
        # Open directly instead of checking os.path.exists() first: the check
        # is racy and costs an extra filesystem call.
        with open(SESSIONS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing has been saved yet; this is the normal first-run state.
        return []
    except Exception as e:
        logger.error(f"Error loading sessions: {str(e)}")
        return []

    if not isinstance(data, list):
        # Callers append to and iterate over the result, so any other
        # top-level JSON value (object, string, null, ...) is unusable.
        logger.error(
            f"Error loading sessions: expected a JSON list in {SESSIONS_FILE}, "
            f"got {type(data).__name__}"
        )
        return []
    return data
def save_sessions(sessions: List[Dict[str, Any]]):
    """Persist ``sessions`` to ``SESSIONS_FILE`` as pretty-printed JSON.

    The data is written to a temporary file next to the target and then moved
    into place with ``os.replace``. A failure part-way through serialization
    or writing therefore leaves the previously saved file intact, instead of
    a truncated file that would later be read back as "no sessions".

    Args:
        sessions: Session records to store; values that are not JSON
            serializable are stringified via ``default=str``.

    Errors are logged and swallowed; this function never raises.
    """
    # The PID suffix keeps separate worker processes from sharing a temp file.
    tmp_path = f"{SESSIONS_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(sessions, f, ensure_ascii=False, indent=2, default=str)
        os.replace(tmp_path, SESSIONS_FILE)
    except Exception as e:
        logger.error(f"Error saving sessions: {str(e)}")
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup: the temp file may never have been created.
            pass
def create_session(name: str) -> Dict[str, Any]:
    """Build a new session record, append it to the stored list, and return it.

    Args:
        name: Display name for the session.

    Returns:
        The new session dict, with a fresh UUID ``id``, the creation time in
        ISO format, no last message, and a message count of zero.
    """
    new_session: Dict[str, Any] = dict(
        id=str(uuid.uuid4()),
        name=name,
        created_at=datetime.now().isoformat(),
        last_message=None,
        message_count=0,
    )
    # Persist the existing sessions with the new one added at the end.
    save_sessions([*load_sessions(), new_session])
    return new_session
def get_session_by_id(session_id: str) -> Optional[Dict[str, Any]]:
    """Return the first stored session whose ``id`` equals ``session_id``.

    Returns ``None`` when no stored session matches.
    """
    for candidate in load_sessions():
        if candidate["id"] == session_id:
            return candidate
    return None
def update_session_last_message(session_id: str, message: str):
    """Record ``message`` as the session's latest message and bump its count.

    Only the first session matching ``session_id`` is touched. The session
    list is written back whether or not a match was found.
    """
    all_sessions = load_sessions()
    target = next(
        (entry for entry in all_sessions if entry["id"] == session_id), None
    )
    if target is not None:
        target["last_message"] = message
        target["message_count"] = target.get("message_count", 0) + 1
    save_sessions(all_sessions)
def delete_session_by_id(session_id: str) -> bool:
    """Remove every stored session whose ``id`` equals ``session_id``.

    Returns:
        ``True`` if at least one session was removed (the shortened list is
        then saved); ``False`` if nothing matched and nothing was written.
    """
    stored = load_sessions()
    remaining = []
    for entry in stored:
        if entry["id"] != session_id:
            remaining.append(entry)

    # Filtering can only shrink the list, so equal length means no match.
    if len(remaining) == len(stored):
        return False
    save_sessions(remaining)
    return True
async def list_scenarios():
    """Return every scenario from ``get_scenarios()`` as a JSON response."""
    available = get_scenarios()
    return JSONResponse(content=available)
async def roleplay(request: RoleplayRequest):
    """Forward the user's query to the roleplay agent and return its reply.

    Raises:
        HTTPException: 400 when the request carries an empty scenario.
    """
    if not request.scenario:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    agent = create_agents(request.scenario)
    payload = {"messages": [request.query]}
    # The session id doubles as the agent's conversation thread id.
    run_config = {"configurable": {"thread_id": request.session_id}}
    result = await agent.ainvoke(payload, run_config)

    # Record the user's query on the session only after the agent call.
    update_session_last_message(request.session_id, request.query)

    reply = result["messages"][-1]
    return JSONResponse(content=reply.content)
def _message_to_dict(msg: Any) -> Dict[str, Any]:
    """Convert one stored message object into a ``role/content/timestamp`` dict."""
    if not hasattr(msg, "content"):
        # Unexpected format: fall back to the object's string form.
        return {"role": "unknown", "content": str(msg), "timestamp": None}

    if hasattr(msg, "type"):
        role = msg.type
    elif "Human" in msg.__class__.__name__:
        # No explicit type attribute: infer the role from the class name.
        role = "human"
    else:
        role = "ai"

    return {
        "role": role,
        "content": msg.content,
        "timestamp": getattr(msg, "timestamp", None),
    }


async def get_messages(request: SessionRequest):
    """Return all messages stored for a conversation session.

    The response holds the session id, the converted messages, and their
    count. A session with no stored state yields an empty message list.

    Raises:
        HTTPException: 500 when reading the agent state fails.
    """
    try:
        thread_config = {"configurable": {"thread_id": request.session_id}}
        state = create_agents().get_state(thread_config)

        if not state or not state.values:
            return JSONResponse(
                content={
                    "session_id": request.session_id,
                    "messages": [],
                    "total_messages": 0,
                }
            )

        stored = state.values["messages"] if "messages" in state.values else []
        converted = [_message_to_dict(item) for item in stored]

        return JSONResponse(
            content={
                "session_id": request.session_id,
                "messages": converted,
                "total_messages": len(converted),
            }
        )
    except Exception as e:
        logger.error(
            f"Error getting messages for session {request.session_id}: {str(e)}"
        )
        raise HTTPException(status_code=500, detail=f"Failed to get messages: {str(e)}")
async def get_sessions():
    """Return every stored session under the ``sessions`` key.

    Raises:
        HTTPException: 500 when building the response fails.
    """
    try:
        payload = {"sessions": load_sessions()}
        return JSONResponse(content=payload)
    except Exception as e:
        logger.error(f"Error getting sessions: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get sessions: {str(e)}")
async def create_new_session(request: CreateSessionRequest):
    """Create a session named ``request.name`` and return it.

    Raises:
        HTTPException: 500 when session creation fails.
    """
    try:
        created = create_session(request.name)
        payload = {"session": created}
        return JSONResponse(content=payload)
    except Exception as e:
        logger.error(f"Error creating session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {str(e)}"
        )
async def get_session(session_id: str):
    """Return the stored session identified by ``session_id``.

    Raises:
        HTTPException: 404 when no such session exists, 500 on any other
            failure.
    """
    try:
        found = get_session_by_id(session_id)
        if found:
            return JSONResponse(content={"session": found})
        raise HTTPException(status_code=404, detail="Session not found")
    except HTTPException:
        # Let the 404 through untouched instead of wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get session: {str(e)}")
async def update_session(session_id: str, request: UpdateSessionRequest):
    """Rename the session identified by ``session_id`` to ``request.name``.

    Returns the session as re-read from storage after saving.

    Raises:
        HTTPException: 404 when no such session exists, 500 on any other
            failure.
    """
    try:
        all_sessions = load_sessions()
        for entry in all_sessions:
            if entry["id"] == session_id:
                entry["name"] = request.name
                break
        else:
            # The loop finished without a match.
            raise HTTPException(status_code=404, detail="Session not found")

        save_sessions(all_sessions)
        refreshed = get_session_by_id(session_id)
        return JSONResponse(content={"session": refreshed})
    except HTTPException:
        # Let the 404 through untouched instead of wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error updating session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to update session: {str(e)}"
        )
async def delete_session(session_id: str):
    """Delete the session identified by ``session_id``.

    Raises:
        HTTPException: 404 when no such session exists, 500 on any other
            failure.
    """
    try:
        if delete_session_by_id(session_id):
            return JSONResponse(content={"message": "Session deleted successfully"})
        raise HTTPException(status_code=404, detail="Session not found")
    except HTTPException:
        # Let the 404 through untouched instead of wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error deleting session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete session: {str(e)}"
        )