Spaces:
Sleeping
Sleeping
File size: 4,496 Bytes
b7a3e32 6cbca40 c0827a3 6cbca40 4909ef6 b9b869f b7a3e32 6cbca40 b9b869f 6cbca40 b7a3e32 4909ef6 b7a3e32 6cbca40 b7a3e32 b9b869f b7a3e32 c0827a3 b7a3e32 c0827a3 b7a3e32 c0827a3 b7a3e32 c0827a3 b7a3e32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from fastapi import (
APIRouter,
status,
HTTPException,
File,
UploadFile,
Form,
)
from fastapi.responses import JSONResponse
from src.utils.logger import logger
from src.agents.role_play.flow import role_play_agent
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from src.agents.role_play.scenarios import get_scenarios
import json
import base64
# Router for the AI endpoints below; every route is mounted under /ai and
# grouped under the "AI" tag in the generated OpenAPI docs.
router = APIRouter(tags=["AI"], prefix="/ai")
# Pydantic schema with three required fields: the user's query, the session id
# used to track the conversation, and the scenario as a free-form dict.
# (Documented with comments rather than a docstring so the model's JSON-schema
# description stays unchanged.)
class RoleplayRequest(BaseModel):
    query: str = Field(default=..., description="User's query for the AI agent")
    session_id: str = Field(
        default=..., description="Session ID for tracking user interactions"
    )
    scenario: Dict[str, Any] = Field(
        default=..., description="The scenario for the roleplay"
    )
@router.get("/scenarios", status_code=status.HTTP_200_OK)
async def list_scenarios():
    """Get all available scenarios"""
    # The result of get_scenarios() is serialized unchanged as the JSON body.
    available = get_scenarios()
    return JSONResponse(content=available)
# Scenario fields the roleplay agent's state is built from. The endpoint checks
# that the client-supplied scenario carries each one and forwards them verbatim.
_REQUIRED_SCENARIO_KEYS = (
    "scenario_title",
    "scenario_description",
    "scenario_context",
    "your_role",
    "key_vocabulary",
)

# Lower-cased upload file extension -> MIME type attached to the audio part.
# Unknown or missing extensions fall back to "audio/wav".
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}


def _parse_scenario(raw_scenario: str) -> Dict[str, Any]:
    """Decode the ``scenario`` form field into a dict holding every required key.

    Args:
        raw_scenario: The scenario as a JSON string.

    Returns:
        The decoded scenario object.

    Raises:
        HTTPException: 400 if the string is not valid JSON, decodes to an empty
            value, is not a JSON object, or lacks any of ``_REQUIRED_SCENARIO_KEYS``.
    """
    try:
        scenario_dict = json.loads(raw_scenario)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from e
    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")
    # Reject non-objects (list, number, string) here so they are reported as a
    # client error instead of failing later during key lookup with a 500.
    if not isinstance(scenario_dict, dict):
        raise HTTPException(
            status_code=400, detail="Scenario must be a JSON object"
        )
    # Same reasoning for missing keys: fail fast with a 400 that names them.
    missing = [key for key in _REQUIRED_SCENARIO_KEYS if key not in scenario_dict]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Scenario is missing required fields: {', '.join(missing)}",
        )
    return scenario_dict


def _audio_mime_type(filename: Optional[str]) -> str:
    """Map an uploaded file name to an audio MIME type (default ``audio/wav``)."""
    extension = filename.split(".")[-1].lower() if filename else "wav"
    return _AUDIO_MIME_TYPES.get(extension, "audio/wav")


@router.post("/roleplay", status_code=status.HTTP_200_OK)
async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the roleplay agent"""
    # Builds one user message from the text and/or audio parts, invokes the
    # roleplay agent with the scenario fields on the thread keyed by
    # session_id, and returns {"response": <content of the agent's last message>}.
    # Bad client input -> HTTP 400; failures while running the agent -> HTTP 500.
    logger.info(f"Received roleplay request: {session_id}")

    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    scenario_dict = _parse_scenario(scenario)

    message_content = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})

    if audio_file:
        try:
            audio_data = await audio_file.read()
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e
        if audio_data:
            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": base64.b64encode(audio_data).decode("utf-8"),
                    "mime_type": _audio_mime_type(audio_file.filename),
                }
            )
        elif not message_content:
            # An empty upload with no accompanying text would leave the agent
            # with nothing to respond to; an empty upload next to text is skipped.
            raise HTTPException(status_code=400, detail="Audio file is empty")

    message = {"role": "user", "content": message_content}
    # Agent state: the user message plus each required scenario field, verbatim.
    agent_input = {"messages": [message]}
    agent_input.update({key: scenario_dict[key] for key in _REQUIRED_SCENARIO_KEYS})

    try:
        response = await role_play_agent().ainvoke(
            agent_input,
            {"configurable": {"thread_id": session_id}},
        )
        # Extract AI response content
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e
|