# Run_code_api/src/apis/routes/chat_route.py
from fastapi import (
APIRouter,
status,
HTTPException,
File,
UploadFile,
Form,
)
from fastapi.responses import JSONResponse, StreamingResponse
from src.utils.logger import logger
from src.agents.role_play.flow import role_play_agent
from src.services.tts_service import tts_service
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from src.agents.role_play.scenarios import get_scenarios
import json
import base64
import asyncio


router = APIRouter(prefix="/ai", tags=["AI"])


class RoleplayRequest(BaseModel):
query: str = Field(..., description="User's query for the AI agent")
session_id: str = Field(
..., description="Session ID for tracking user interactions"
)
scenario: Dict[str, Any] = Field(..., description="The scenario for the roleplay")
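

# NOTE: The roleplay endpoints below read multipart/form-data (Form/File fields)
# rather than this JSON body model; the "scenario" form field carries the scenario
# object serialized as a JSON string.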
@router.get("/scenarios", status_code=status.HTTP_200_OK)
async def list_scenarios():
"""Get all available scenarios"""
return JSONResponse(content=get_scenarios())
@router.post("/roleplay", status_code=status.HTTP_200_OK)
async def roleplay(
session_id: str = Form(
..., description="Session ID for tracking user interactions"
),
scenario: str = Form(
..., description="The scenario for the roleplay as JSON string"
),
text_message: Optional[str] = Form(None, description="Text message from user"),
audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
"""Send a message (text or audio) to the roleplay agent"""
logger.info(f"Received roleplay request: {session_id}")
# Validate that at least one input is provided
if not text_message and not audio_file:
raise HTTPException(
status_code=400, detail="Either text_message or audio_file must be provided"
)
# Parse scenario from JSON string
try:
scenario_dict = json.loads(scenario)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid scenario JSON format")
if not scenario_dict:
raise HTTPException(status_code=400, detail="Scenario not provided")
# Prepare message content
message_content = []
# Handle text input
if text_message:
message_content.append({"type": "text", "text": text_message})
# Handle audio input
if audio_file:
try:
# Read audio file content
audio_data = await audio_file.read()
# Convert to base64
audio_base64 = base64.b64encode(audio_data).decode("utf-8")
# Determine mime type based on file extension
file_extension = (
audio_file.filename.split(".")[-1].lower()
if audio_file.filename
else "wav"
)
mime_type_map = {
"wav": "audio/wav",
"mp3": "audio/mpeg",
"ogg": "audio/ogg",
"webm": "audio/webm",
"m4a": "audio/mp4",
}
mime_type = mime_type_map.get(file_extension, "audio/wav")
message_content.append(
{
"type": "audio",
"source_type": "base64",
"data": audio_base64,
"mime_type": mime_type,
}
)
except Exception as e:
logger.error(f"Error processing audio file: {str(e)}")
raise HTTPException(
status_code=400, detail=f"Error processing audio file: {str(e)}"
)
# Create message in the required format
message = {"role": "user", "content": message_content}
try:
response = await role_play_agent().ainvoke(
{
"messages": [message],
"scenario_title": scenario_dict["scenario_title"],
"scenario_description": scenario_dict["scenario_description"],
"scenario_context": scenario_dict["scenario_context"],
"your_role": scenario_dict["your_role"],
"key_vocabulary": scenario_dict["key_vocabulary"],
},
{"configurable": {"thread_id": session_id}},
)
# Extract AI response content
ai_response = response["messages"][-1].content
logger.info(f"AI response: {ai_response}")
return JSONResponse(content={"response": ai_response})
except Exception as e:
logger.error(f"Error in roleplay: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@router.post("/roleplay/stream", status_code=status.HTTP_200_OK)
async def roleplay_stream(
session_id: str = Form(
..., description="Session ID for tracking user interactions"
),
scenario: str = Form(
..., description="The scenario for the roleplay as JSON string"
),
text_message: Optional[str] = Form(None, description="Text message from user"),
audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
audio: bool = Form(False, description="Whether to return TTS audio response"),
):
"""Send a message (text or audio) to the roleplay agent with streaming response"""
logger.info(f"Received streaming roleplay request: {session_id}")
# Validate that at least one input is provided
if not text_message and not audio_file:
raise HTTPException(
status_code=400, detail="Either text_message or audio_file must be provided"
)
# Parse scenario from JSON string
try:
scenario_dict = json.loads(scenario)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid scenario JSON format")
if not scenario_dict:
raise HTTPException(status_code=400, detail="Scenario not provided")
# Prepare message content
message_content = []
# Handle text input
if text_message:
message_content.append({"type": "text", "text": text_message})
# Handle audio input
if audio_file:
try:
# Read audio file content
audio_data = await audio_file.read()
# Convert to base64
audio_base64 = base64.b64encode(audio_data).decode("utf-8")
# Determine mime type based on file extension
file_extension = (
audio_file.filename.split(".")[-1].lower()
if audio_file.filename
else "wav"
)
mime_type_map = {
"wav": "audio/wav",
"mp3": "audio/mpeg",
"ogg": "audio/ogg",
"webm": "audio/webm",
"m4a": "audio/mp4",
}
mime_type = mime_type_map.get(file_extension, "audio/wav")
message_content.append(
{
"type": "audio",
"source_type": "base64",
"data": audio_base64,
"mime_type": mime_type,
}
)
except Exception as e:
logger.error(f"Error processing audio file: {str(e)}")
raise HTTPException(
status_code=400, detail=f"Error processing audio file: {str(e)}"
)
# Create message in the required format
message = {"role": "user", "content": message_content}
async def generate_stream():
"""Generator function for streaming responses"""
        accumulated_content = ""  # full assistant text, accumulated for optional TTS
        conversation_ended = False  # set when the agent issues an end_conversation tool call
try:
input_graph = {
"messages": [message],
"scenario_title": scenario_dict["scenario_title"],
"scenario_description": scenario_dict["scenario_description"],
"scenario_context": scenario_dict["scenario_context"],
"your_role": scenario_dict["your_role"],
"key_vocabulary": scenario_dict["key_vocabulary"],
}
config = {"configurable": {"thread_id": session_id}}
async for event in role_play_agent().astream(
input=input_graph,
stream_mode=["messages", "values"],
config=config,
subgraphs=True,
):
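                # With subgraphs=True and a list of stream modes, astream yields
                # (namespace, stream_mode, payload) triples.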
_, event_type, message_chunk = event
if event_type == "messages":
# message_chunk is a tuple, get the first element which is the actual AIMessageChunk
if isinstance(message_chunk, tuple) and len(message_chunk) > 0:
actual_message = message_chunk[0]
content = getattr(actual_message, "content", "")
else:
actual_message = message_chunk
content = getattr(message_chunk, "content", "")
# Check if this is a tool call message and if it's an end conversation tool call
if (
hasattr(actual_message, "tool_calls")
and actual_message.tool_calls
):
# Check if any tool call is for ending the conversation
for tool_call in actual_message.tool_calls:
if (
isinstance(tool_call, dict)
and tool_call.get("name") == "end_conversation"
):
# Send a special termination message to the client
termination_data = {
"type": "termination",
"content": "Conversation ended",
"reason": tool_call.get("args", {}).get("reason", "Unknown reason")
}
yield f"data: {json.dumps(termination_data)}\n\n"
conversation_ended = True
break
if content and not conversation_ended:
# Accumulate content for TTS
accumulated_content += content
# Create SSE-formatted response
response_data = {
"type": "message_chunk",
"content": content,
"metadata": {
"agent": getattr(actual_message, "name", "unknown"),
"id": getattr(actual_message, "id", ""),
"usage_metadata": getattr(
actual_message, "usage_metadata", {}
),
},
}
yield f"data: {json.dumps(response_data)}\n\n"
# Small delay to prevent overwhelming the client
await asyncio.sleep(0.01)
# Only send completion signal if conversation wasn't ended by tool call
if not conversation_ended:
# Generate TTS audio if requested
audio_data = None
if audio and accumulated_content.strip():
try:
logger.info(
f"Generating TTS for accumulated content: {len(accumulated_content)} chars"
)
audio_result = await tts_service.text_to_speech(accumulated_content)
if audio_result:
audio_data = {
"audio_data": audio_result["audio_data"],
"mime_type": audio_result["mime_type"],
"format": audio_result["format"],
}
logger.info("TTS audio generated successfully")
else:
logger.warning("TTS generation failed")
except Exception as tts_error:
logger.error(f"TTS generation error: {str(tts_error)}")
# Send completion signal with optional audio
completion_data = {"type": "completion", "content": "", "audio": audio_data}
yield f"data: {json.dumps(completion_data)}\n\n"
except Exception as e:
logger.error(f"Error in streaming roleplay: {str(e)}")
error_data = {"type": "error", "content": str(e)}
yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
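

# Example client (illustrative sketch, not executed here): consuming the SSE stream
# with httpx, assuming the app is served at http://localhost:8000 with this router
# mounted at the application root. Adjust the URL and scenario payload as needed.
#
#   import asyncio, json
#   import httpx
#
#   async def consume():
#       form = {
#           "session_id": "demo-session",
#           "scenario": json.dumps({
#               "scenario_title": "...", "scenario_description": "...",
#               "scenario_context": "...", "your_role": "...", "key_vocabulary": "...",
#           }),
#           "text_message": "Hello!",
#           "audio": "false",
#       }
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream(
#               "POST", "http://localhost:8000/ai/roleplay/stream", data=form
#           ) as resp:
#               async for line in resp.aiter_lines():
#                   if line.startswith("data: "):
#                       print(json.loads(line[len("data: "):]))
#
#   asyncio.run(consume())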