Spaces:
Sleeping
Sleeping
| from fastapi import ( | |
| APIRouter, | |
| status, | |
| HTTPException, | |
| File, | |
| UploadFile, | |
| Form, | |
| ) | |
| from fastapi.responses import JSONResponse | |
| from src.utils.logger import logger | |
| from src.agents.role_play.flow import role_play_agent | |
| from pydantic import BaseModel, Field | |
| from typing import Dict, Any, Optional | |
| from src.agents.role_play.scenarios import get_scenarios | |
| import json | |
| import base64 | |
# Every endpoint registered on this router is served under the "/ai" path
# prefix and grouped under the "AI" tag in the generated OpenAPI docs.
router = APIRouter(
    prefix="/ai",
    tags=["AI"],
)
class RoleplayRequest(BaseModel):
    # Request model for a roleplay turn: the user's query, the session it
    # belongs to, and the scenario definition as a free-form mapping.
    # Field order and descriptions feed pydantic's generated schema, so they
    # are kept exactly as-is; all three fields are required (default=...).
    # NOTE(review): the `roleplay` handler in this module reads multipart
    # Form fields instead of this model — confirm whether it is still used.
    query: str = Field(default=..., description="User's query for the AI agent")
    session_id: str = Field(
        default=..., description="Session ID for tracking user interactions"
    )
    scenario: Dict[str, Any] = Field(
        default=..., description="The scenario for the roleplay"
    )
async def list_scenarios():
    """Get all available scenarios"""
    # Whatever get_scenarios() returns is serialized verbatim as the JSON body.
    available = get_scenarios()
    return JSONResponse(content=available)
# Scenario fields forwarded to the roleplay agent; every one is required.
_REQUIRED_SCENARIO_KEYS = (
    "scenario_title",
    "scenario_description",
    "scenario_context",
    "your_role",
    "key_vocabulary",
)

# Upload filename extension -> MIME type sent to the agent.
# Unknown or missing extensions fall back to "audio/wav".
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}


def _parse_scenario(scenario: str) -> Dict[str, Any]:
    """Decode the scenario form field and check it has every key the agent needs.

    Args:
        scenario: JSON text describing the scenario.

    Returns:
        The decoded scenario mapping.

    Raises:
        HTTPException: 400 if the JSON is malformed, empty, not a JSON object,
            or missing any key in ``_REQUIRED_SCENARIO_KEYS``.
    """
    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from e
    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")
    # Reject non-objects and incomplete scenarios here so the client gets a 400
    # instead of a TypeError/KeyError surfacing as a 500 during the agent call.
    if not isinstance(scenario_dict, dict):
        raise HTTPException(
            status_code=400, detail="Scenario must be a JSON object"
        )
    missing = [key for key in _REQUIRED_SCENARIO_KEYS if key not in scenario_dict]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Scenario is missing required fields: {', '.join(missing)}",
        )
    return scenario_dict


async def _audio_content_part(audio_file: UploadFile) -> Dict[str, Any]:
    """Read an uploaded audio file and wrap it as a base64 ``audio`` content part.

    The MIME type comes from the filename's last extension (case-insensitive).
    It defaults to ``audio/wav`` when there is no filename or the extension is
    not in ``_AUDIO_MIME_TYPES``.

    Raises:
        HTTPException: 400 if the file cannot be read or encoded.
    """
    try:
        audio_data = await audio_file.read()
        audio_base64 = base64.b64encode(audio_data).decode("utf-8")
    except Exception as e:
        logger.error(f"Error processing audio file: {str(e)}")
        raise HTTPException(
            status_code=400, detail=f"Error processing audio file: {str(e)}"
        ) from e
    file_extension = (
        audio_file.filename.rsplit(".", 1)[-1].lower()
        if audio_file.filename
        else "wav"
    )
    return {
        "type": "audio",
        "source_type": "base64",
        "data": audio_base64,
        "mime_type": _AUDIO_MIME_TYPES.get(file_extension, "audio/wav"),
    }


async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text and/or audio) to the roleplay agent.

    Returns:
        JSONResponse ``{"response": ...}`` holding the ``content`` of the last
        message the agent returns.

    Raises:
        HTTPException: 400 if neither text nor audio is supplied, the scenario
            JSON is invalid or incomplete, or the audio cannot be read; 500 if
            the agent invocation fails.
    """
    logger.info(f"Received roleplay request: {session_id}")
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )
    scenario_dict = _parse_scenario(scenario)

    # The text part (if any) precedes the audio part in the message content.
    message_content = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})
    if audio_file:
        message_content.append(await _audio_content_part(audio_file))
    message = {"role": "user", "content": message_content}

    try:
        response = await role_play_agent().ainvoke(
            {
                "messages": [message],
                **{key: scenario_dict[key] for key in _REQUIRED_SCENARIO_KEYS},
            },
            # session_id is passed as the agent's thread_id — presumably so
            # conversation state persists per session; confirm in role_play_agent.
            {"configurable": {"thread_id": session_id}},
        )
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        # NOTE(review): this detail echoes the raw exception text to the
        # client; consider a generic message once callers no longer rely on it.
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e