Spaces:
Sleeping
Sleeping
| from fastapi import ( | |
| APIRouter, | |
| status, | |
| HTTPException, | |
| File, | |
| UploadFile, | |
| Form, | |
| ) | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from src.utils.logger import logger | |
| from src.agents.role_play.flow import role_play_agent | |
| from src.services.tts_service import tts_service | |
| from pydantic import BaseModel, Field | |
| from typing import Dict, Any, Optional | |
| from src.agents.role_play.scenarios import get_scenarios | |
| import json | |
| import base64 | |
| import asyncio | |
# Router that groups the AI endpoints of this module under the "/ai" prefix.
router = APIRouter(
    prefix="/ai",
    tags=["AI"],
)
# Request schema for one roleplay turn: the user's query, the session it
# belongs to, and the scenario definition as an already-parsed mapping.
# NOTE(review): the handlers visible in this part of the file take form fields
# rather than this model -- confirm where (or whether) it is still used.
class RoleplayRequest(BaseModel):
    # Free-form text the user sends to the agent.
    query: str = Field(
        default=...,
        description="User's query for the AI agent",
    )
    # Identifier used to tie successive requests to one conversation.
    session_id: str = Field(
        default=...,
        description="Session ID for tracking user interactions",
    )
    # Scenario definition; the schema of this mapping is not enforced here.
    scenario: Dict[str, Any] = Field(
        default=...,
        description="The scenario for the roleplay",
    )
async def list_scenarios():
    """Return every available roleplay scenario wrapped in a JSON response."""
    # NOTE(review): no route decorator is visible on this handler in this view
    # of the file -- confirm it is registered on `router`.
    available_scenarios = get_scenarios()
    return JSONResponse(content=available_scenarios)
async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text and/or audio) to the roleplay agent and return its reply.

    Args:
        session_id: Conversation identifier; passed to the agent as the
            ``thread_id`` of its run configuration.
        scenario: JSON-encoded object that must contain the keys
            ``scenario_title``, ``scenario_description``, ``scenario_context``,
            ``your_role`` and ``key_vocabulary``.
        text_message: Optional text typed by the user.
        audio_file: Optional uploaded audio; it is base64-encoded and sent to
            the agent with a MIME type guessed from the file extension.

    Returns:
        A ``JSONResponse`` of the form ``{"response": <content of the last
        message returned by the agent>}``.

    Raises:
        HTTPException: 400 when neither input is given, when the scenario is
            not valid JSON, is empty, is not a JSON object or lacks a required
            key, or when the audio file cannot be read; 500 when the agent
            call fails.
    """
    # NOTE(review): no route decorator is visible on this handler in this view
    # of the file -- confirm it is registered on `router`.
    # NOTE(review): the input handling below mirrors `roleplay_stream`;
    # consider extracting a shared helper.
    logger.info(f"Received roleplay request: {session_id}")

    # At least one of the two input channels must carry something.
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # The scenario arrives as a JSON string because the request is a form.
    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from exc

    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    # Reject malformed scenarios here as a client error; otherwise the lookups
    # below would fail inside the agent call and surface as a misleading 500.
    if not isinstance(scenario_dict, dict):
        raise HTTPException(status_code=400, detail="Scenario must be a JSON object")
    required_keys = (
        "scenario_title",
        "scenario_description",
        "scenario_context",
        "your_role",
        "key_vocabulary",
    )
    missing_keys = [key for key in required_keys if key not in scenario_dict]
    if missing_keys:
        raise HTTPException(
            status_code=400,
            detail=f"Scenario is missing required fields: {', '.join(missing_keys)}",
        )

    # Content parts of the single user message handed to the agent.
    message_content = []

    if text_message:
        message_content.append({"type": "text", "text": text_message})

    if audio_file:
        try:
            audio_data = await audio_file.read()
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e

        # Guess the MIME type from the extension; anything unknown (or a
        # missing filename) is treated as WAV.
        file_extension = (
            audio_file.filename.split(".")[-1].lower()
            if audio_file.filename
            else "wav"
        )
        mime_type_map = {
            "wav": "audio/wav",
            "mp3": "audio/mpeg",
            "ogg": "audio/ogg",
            "webm": "audio/webm",
            "m4a": "audio/mp4",
        }
        mime_type = mime_type_map.get(file_extension, "audio/wav")
        message_content.append(
            {
                "type": "audio",
                "source_type": "base64",
                "data": audio_base64,
                "mime_type": mime_type,
            }
        )

    message = {"role": "user", "content": message_content}

    try:
        response = await role_play_agent().ainvoke(
            {
                "messages": [message],
                "scenario_title": scenario_dict["scenario_title"],
                "scenario_description": scenario_dict["scenario_description"],
                "scenario_context": scenario_dict["scenario_context"],
                "your_role": scenario_dict["your_role"],
                "key_vocabulary": scenario_dict["key_vocabulary"],
            },
            {"configurable": {"thread_id": session_id}},
        )
        # The reply is the content of the last message in the returned state.
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e
async def roleplay_stream(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
    audio: bool = Form(False, description="Whether to return TTS audio response"),
):
    """Send a message (text and/or audio) to the roleplay agent and stream the reply.

    The response body is a sequence of ``data: <json>`` events. Each JSON
    payload has a ``type`` field:

    * ``message_chunk`` -- a piece of the agent's reply plus metadata.
    * ``termination`` -- the agent called the ``end_conversation`` tool.
    * ``completion`` -- the stream finished normally; ``audio`` holds the TTS
      result when ``audio`` was requested and generation succeeded, else null.
    * ``error`` -- an exception occurred while streaming.

    Args:
        session_id: Conversation identifier; passed to the agent as the
            ``thread_id`` of its run configuration.
        scenario: JSON-encoded object that must contain the keys
            ``scenario_title``, ``scenario_description``, ``scenario_context``,
            ``your_role`` and ``key_vocabulary``.
        text_message: Optional text typed by the user.
        audio_file: Optional uploaded audio; it is base64-encoded and sent to
            the agent with a MIME type guessed from the file extension.
        audio: When true, the accumulated reply text is passed to the TTS
            service and the result is attached to the ``completion`` event.

    Returns:
        A ``StreamingResponse`` with content type ``text/event-stream``.

    Raises:
        HTTPException: 400 when neither input is given, when the scenario is
            not valid JSON, is empty, is not a JSON object or lacks a required
            key, or when the audio file cannot be read. Failures after
            streaming has started are reported as an ``error`` event instead.
    """
    # NOTE(review): no route decorator is visible on this handler in this view
    # of the file -- confirm it is registered on `router`.
    # NOTE(review): the input handling below mirrors `roleplay`; consider
    # extracting a shared helper.
    logger.info(f"Received streaming roleplay request: {session_id}")

    # At least one of the two input channels must carry something.
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # The scenario arrives as a JSON string because the request is a form.
    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from exc

    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    # Validate before the stream starts: once the generator is running the
    # status line has already been sent, so a bad scenario could only be
    # reported as an in-stream error event instead of a proper 400.
    if not isinstance(scenario_dict, dict):
        raise HTTPException(status_code=400, detail="Scenario must be a JSON object")
    required_keys = (
        "scenario_title",
        "scenario_description",
        "scenario_context",
        "your_role",
        "key_vocabulary",
    )
    missing_keys = [key for key in required_keys if key not in scenario_dict]
    if missing_keys:
        raise HTTPException(
            status_code=400,
            detail=f"Scenario is missing required fields: {', '.join(missing_keys)}",
        )

    # Content parts of the single user message handed to the agent.
    message_content = []

    if text_message:
        message_content.append({"type": "text", "text": text_message})

    if audio_file:
        try:
            audio_data = await audio_file.read()
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e

        # Guess the MIME type from the extension; anything unknown (or a
        # missing filename) is treated as WAV.
        file_extension = (
            audio_file.filename.split(".")[-1].lower()
            if audio_file.filename
            else "wav"
        )
        mime_type_map = {
            "wav": "audio/wav",
            "mp3": "audio/mpeg",
            "ogg": "audio/ogg",
            "webm": "audio/webm",
            "m4a": "audio/mp4",
        }
        mime_type = mime_type_map.get(file_extension, "audio/wav")
        message_content.append(
            {
                "type": "audio",
                "source_type": "base64",
                "data": audio_base64,
                "mime_type": mime_type,
            }
        )

    message = {"role": "user", "content": message_content}

    async def generate_stream():
        """Yield ``data: <json>`` events for the agent's streamed reply."""
        accumulated_content = ""
        conversation_ended = False

        try:
            input_graph = {
                "messages": [message],
                "scenario_title": scenario_dict["scenario_title"],
                "scenario_description": scenario_dict["scenario_description"],
                "scenario_context": scenario_dict["scenario_context"],
                "your_role": scenario_dict["your_role"],
                "key_vocabulary": scenario_dict["key_vocabulary"],
            }
            config = {"configurable": {"thread_id": session_id}}

            async for event in role_play_agent().astream(
                input=input_graph,
                stream_mode=["messages", "values"],
                config=config,
                subgraphs=True,
            ):
                # Each event unpacks as a 3-tuple; the first element is not
                # used here (presumably the subgraph namespace -- confirm).
                _, event_type, message_chunk = event

                # Only "messages" events are forwarded to the client.
                if event_type != "messages":
                    continue

                # The payload may be a tuple whose first element is the
                # actual message chunk.
                if isinstance(message_chunk, tuple) and len(message_chunk) > 0:
                    actual_message = message_chunk[0]
                else:
                    actual_message = message_chunk
                content = getattr(actual_message, "content", "")

                # Detect an end-of-conversation tool call. The extra
                # `not conversation_ended` guard keeps the termination event
                # from being emitted more than once per stream.
                tool_calls = getattr(actual_message, "tool_calls", None)
                if tool_calls and not conversation_ended:
                    for tool_call in tool_calls:
                        if (
                            isinstance(tool_call, dict)
                            and tool_call.get("name") == "end_conversation"
                        ):
                            termination_data = {
                                "type": "termination",
                                "content": "Conversation ended",
                                "reason": tool_call.get("args", {}).get(
                                    "reason", "Unknown reason"
                                ),
                            }
                            yield f"data: {json.dumps(termination_data)}\n\n"
                            conversation_ended = True
                            break

                # After termination the stream is still consumed, but no
                # further chunks are forwarded.
                if content and not conversation_ended:
                    # Keep the full text around for optional TTS at the end.
                    accumulated_content += content

                    response_data = {
                        "type": "message_chunk",
                        "content": content,
                        "metadata": {
                            "agent": getattr(actual_message, "name", "unknown"),
                            "id": getattr(actual_message, "id", ""),
                            "usage_metadata": getattr(
                                actual_message, "usage_metadata", {}
                            ),
                        },
                    }
                    yield f"data: {json.dumps(response_data)}\n\n"

                    # Small delay to prevent overwhelming the client.
                    await asyncio.sleep(0.01)

            # The completion event is skipped when the agent ended the
            # conversation itself; the termination event already closed it.
            if not conversation_ended:
                audio_data = None
                if audio and accumulated_content.strip():
                    # TTS is best-effort: a failure is logged and the
                    # completion event is still sent with `audio` left null.
                    try:
                        logger.info(
                            f"Generating TTS for accumulated content: {len(accumulated_content)} chars"
                        )
                        audio_result = await tts_service.text_to_speech(
                            accumulated_content
                        )
                        if audio_result:
                            audio_data = {
                                "audio_data": audio_result["audio_data"],
                                "mime_type": audio_result["mime_type"],
                                "format": audio_result["format"],
                            }
                            logger.info("TTS audio generated successfully")
                        else:
                            logger.warning("TTS generation failed")
                    except Exception as tts_error:
                        logger.error(f"TTS generation error: {str(tts_error)}")

                completion_data = {
                    "type": "completion",
                    "content": "",
                    "audio": audio_data,
                }
                yield f"data: {json.dumps(completion_data)}\n\n"

        except Exception as e:
            # The response has already started, so errors can only be
            # reported in-band as an event.
            logger.error(f"Error in streaming roleplay: {str(e)}")
            error_data = {"type": "error", "content": str(e)}
            yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(
        generate_stream(),
        # Kept in agreement with the explicit Content-Type header below, which
        # is what is actually sent on the wire.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
        },
    )