Run_code_api / src /apis /routes /chat_route.py
ABAO77's picture
remove conversation management
4909ef6
raw
history blame
4.5 kB
from fastapi import (
APIRouter,
status,
HTTPException,
File,
UploadFile,
Form,
)
from fastapi.responses import JSONResponse
from src.utils.logger import logger
from src.agents.role_play.flow import role_play_agent
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from src.agents.role_play.scenarios import get_scenarios
import json
import base64
# Router for the AI chat endpoints in this module: every route below is mounted
# under the "/ai" path prefix and tagged "AI" in the generated OpenAPI schema.
router = APIRouter(prefix="/ai", tags=["AI"])
class RoleplayRequest(BaseModel):
    """JSON request body describing a single roleplay turn.

    All three fields are required (declared with ``...``).
    """

    # Free-form text the user sends to the agent.
    query: str = Field(..., description="User's query for the AI agent")
    # Identifier used to group interactions belonging to the same user session.
    session_id: str = Field(..., description="Session ID for tracking user interactions")
    # Arbitrary scenario definition, kept as an untyped mapping.
    scenario: Dict[str, Any] = Field(..., description="The scenario for the roleplay")
@router.get("/scenarios", status_code=status.HTTP_200_OK)
async def list_scenarios():
    """Return the result of ``get_scenarios()`` wrapped in a JSON response."""
    available = get_scenarios()
    return JSONResponse(content=available)
# Scenario fields forwarded to the roleplay agent; a scenario missing any of
# them is a client error, not a server error.
_REQUIRED_SCENARIO_KEYS = (
    "scenario_title",
    "scenario_description",
    "scenario_context",
    "your_role",
    "key_vocabulary",
)

# File-extension -> MIME type for uploaded audio. Unknown extensions fall back
# to the upload's declared audio/* content type, then to "audio/wav".
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
    "mp4": "audio/mp4",
    "flac": "audio/flac",
    "aac": "audio/aac",
}


def _parse_scenario(scenario: str) -> Dict[str, Any]:
    """Decode the scenario form field and check it holds every required key.

    Raises:
        HTTPException: 400 if the JSON is malformed, empty, not an object,
            or missing any key in ``_REQUIRED_SCENARIO_KEYS``.
    """
    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from e
    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")
    # json.loads may yield a list/str/number; indexing those by key would
    # otherwise surface later as an opaque 500.
    if not isinstance(scenario_dict, dict):
        raise HTTPException(status_code=400, detail="Scenario must be a JSON object")
    missing = [key for key in _REQUIRED_SCENARIO_KEYS if key not in scenario_dict]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Scenario is missing required fields: {', '.join(missing)}",
        )
    return scenario_dict


def _guess_audio_mime_type(audio_file: UploadFile) -> str:
    """Pick a MIME type for the upload: extension first, then declared content type."""
    filename = audio_file.filename or ""
    extension = filename.rsplit(".", 1)[-1].lower() if filename else ""
    mime_type = _AUDIO_MIME_TYPES.get(extension)
    if mime_type:
        return mime_type
    content_type = getattr(audio_file, "content_type", None) or ""
    if content_type.startswith("audio/"):
        return content_type
    return "audio/wav"


async def _read_audio_part(audio_file: UploadFile) -> Optional[Dict[str, Any]]:
    """Read the upload into a base64 audio message part; None if the file is empty.

    Raises:
        HTTPException: 400 if the upload cannot be read.
    """
    try:
        audio_data = await audio_file.read()
    except Exception as e:
        logger.error(f"Error processing audio file: {str(e)}")
        raise HTTPException(
            status_code=400, detail=f"Error processing audio file: {str(e)}"
        ) from e
    if not audio_data:
        return None
    return {
        "type": "audio",
        "source_type": "base64",
        "data": base64.b64encode(audio_data).decode("utf-8"),
        "mime_type": _guess_audio_mime_type(audio_file),
    }


@router.post("/roleplay", status_code=status.HTTP_200_OK)
async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the roleplay agent.

    The text and/or audio are combined into a single user message and passed,
    together with the scenario fields, to ``role_play_agent().ainvoke`` using
    ``session_id`` as the ``thread_id``.

    Returns:
        JSONResponse with ``{"response": <content of the agent's last message>}``.

    Raises:
        HTTPException: 400 for missing/empty input, an unreadable audio file,
            or an invalid scenario; 500 if the agent call fails.
    """
    logger.info(f"Received roleplay request: {session_id}")
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    scenario_dict = _parse_scenario(scenario)

    message_content = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})
    if audio_file:
        audio_part = await _read_audio_part(audio_file)
        if audio_part is not None:
            message_content.append(audio_part)
    # Only reachable when no text was given and the uploaded audio was empty
    # (UploadFile objects are always truthy, so the guard above can't catch it).
    if not message_content:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    message = {"role": "user", "content": message_content}
    agent_input = {
        "messages": [message],
        **{key: scenario_dict[key] for key in _REQUIRED_SCENARIO_KEYS},
    }
    try:
        response = await role_play_agent().ainvoke(
            agent_input,
            {"configurable": {"thread_id": session_id}},
        )
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception(f"Error in roleplay: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e