File size: 4,496 Bytes
b7a3e32
 
 
 
 
 
 
 
6cbca40
 
c0827a3
6cbca40
4909ef6
 
b9b869f
b7a3e32
6cbca40
 
 
 
 
 
 
 
 
 
 
 
 
 
b9b869f
6cbca40
 
 
 
b7a3e32
 
 
 
 
 
 
 
 
 
 
4909ef6
b7a3e32
 
 
 
 
 
 
 
 
 
 
 
 
6cbca40
 
b7a3e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9b869f
b7a3e32
c0827a3
b7a3e32
 
c0827a3
 
 
 
 
b7a3e32
 
 
c0827a3
b7a3e32
 
 
c0827a3
b7a3e32
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from fastapi import (
    APIRouter,
    status,
    HTTPException,
    File,
    UploadFile,
    Form,
)
from fastapi.responses import JSONResponse
from src.utils.logger import logger
from src.agents.role_play.flow import role_play_agent
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from src.agents.role_play.scenarios import get_scenarios
import json
import base64

router = APIRouter(prefix="/ai", tags=["AI"])


class RoleplayRequest(BaseModel):
    # Pydantic model bundling the three required pieces of a role-play
    # request: the user's query, a session identifier, and the scenario.
    # NOTE(review): no endpoint visible in this file references this model
    # (the /roleplay route reads multipart form fields instead) — confirm
    # whether it is still used elsewhere before relying on it.

    # Free-text query from the user; required (no default).
    query: str = Field(..., description="User's query for the AI agent")
    # Identifier used to track a user's interactions; required.
    session_id: str = Field(
        ..., description="Session ID for tracking user interactions"
    )
    # Arbitrary string-keyed mapping describing the scenario; its keys are
    # not validated by this model.
    scenario: Dict[str, Any] = Field(..., description="The scenario for the roleplay")


@router.get("/scenarios", status_code=status.HTTP_200_OK)
async def list_scenarios():
    """Return every available role-play scenario as a JSON response."""
    # Serve whatever get_scenarios() yields, unmodified, as the response body.
    available_scenarios = get_scenarios()
    return JSONResponse(content=available_scenarios)


# Keys every scenario payload must carry; each one is forwarded to the agent
# under the same name.
_REQUIRED_SCENARIO_KEYS = (
    "scenario_title",
    "scenario_description",
    "scenario_context",
    "your_role",
    "key_vocabulary",
)

# File extension -> MIME type for recognised audio uploads.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}
# Used when the upload has no filename or an unrecognised extension.
_DEFAULT_AUDIO_MIME_TYPE = "audio/wav"


def _parse_scenario(raw_scenario: str) -> Dict[str, Any]:
    """Decode the scenario form field and check it has every required key.

    Raises:
        HTTPException: 400 when the value is not valid JSON, is empty, is not
            a JSON object, or lacks one of the required scenario keys.
    """
    try:
        scenario_dict = json.loads(raw_scenario)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from exc

    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    # Valid JSON may still be a list, string or number; indexing those by key
    # would otherwise surface later as an opaque 500.
    if not isinstance(scenario_dict, dict):
        raise HTTPException(status_code=400, detail="Scenario must be a JSON object")

    missing_keys = [key for key in _REQUIRED_SCENARIO_KEYS if key not in scenario_dict]
    if missing_keys:
        raise HTTPException(
            status_code=400,
            detail=f"Scenario is missing required fields: {', '.join(missing_keys)}",
        )

    return scenario_dict


def _guess_audio_mime_type(filename: Optional[str]) -> str:
    """Map an upload's filename extension to a MIME type, defaulting to WAV."""
    if not filename:
        return _DEFAULT_AUDIO_MIME_TYPE
    file_extension = filename.split(".")[-1].lower()
    return _AUDIO_MIME_TYPES.get(file_extension, _DEFAULT_AUDIO_MIME_TYPE)


async def _build_audio_part(audio_file: UploadFile) -> Optional[Dict[str, str]]:
    """Read the upload and wrap it as a base64 audio content part.

    Returns:
        The content-part dict, or None when the upload contains no bytes.

    Raises:
        HTTPException: 400 when reading the uploaded file fails.
    """
    try:
        audio_data = await audio_file.read()
    except Exception as e:
        logger.error(f"Error processing audio file: {str(e)}")
        raise HTTPException(
            status_code=400, detail=f"Error processing audio file: {str(e)}"
        ) from e

    # A zero-byte upload carries no audio, so the caller treats it as absent.
    if not audio_data:
        return None

    return {
        "type": "audio",
        "source_type": "base64",
        "data": base64.b64encode(audio_data).decode("utf-8"),
        "mime_type": _guess_audio_mime_type(audio_file.filename),
    }


@router.post("/roleplay", status_code=status.HTTP_200_OK)
async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text and/or audio) to the roleplay agent.

    Args:
        session_id: Passed to the agent as the configurable ``thread_id``.
        scenario: JSON-encoded object containing ``scenario_title``,
            ``scenario_description``, ``scenario_context``, ``your_role`` and
            ``key_vocabulary``.
        text_message: Optional text content of the user's message.
        audio_file: Optional audio upload, forwarded to the agent base64-encoded.

    Returns:
        JSONResponse of the form ``{"response": <content of the last message>}``.

    Raises:
        HTTPException: 400 when neither usable text nor audio is supplied, when
            the scenario is invalid, or when the audio cannot be read; 500 when
            the agent invocation or response handling fails.
    """
    logger.info(f"Received roleplay request: {session_id}")

    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # Validate the scenario up front so malformed input is reported as a 400
    # instead of being swallowed by the agent's broad error handler below.
    scenario_dict = _parse_scenario(scenario)

    # Prepare message content
    message_content = []

    if text_message:
        message_content.append({"type": "text", "text": text_message})

    if audio_file:
        audio_part = await _build_audio_part(audio_file)
        if audio_part is None:
            logger.warning(f"Ignoring empty audio upload: {session_id}")
        else:
            message_content.append(audio_part)

    # An empty audio upload without text leaves nothing to send to the agent.
    if not message_content:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # Create message in the required format
    message = {"role": "user", "content": message_content}

    agent_input = {"messages": [message]}
    agent_input.update({key: scenario_dict[key] for key in _REQUIRED_SCENARIO_KEYS})

    try:
        response = await role_play_agent().ainvoke(
            agent_input,
            {"configurable": {"thread_id": session_id}},
        )

        # Extract AI response content
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")

        return JSONResponse(content={"response": ai_response})

    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e