ABAO77 committed on
Commit
b7a3e32
·
1 Parent(s): 0f1954e

Enhance roleplay API to support audio messages and add comprehensive test cases

Browse files
requirements.txt CHANGED
@@ -11,4 +11,5 @@ langchain
11
  langgraph-swarm
12
  langchain-google-genai
13
  python-dotenv
14
- loguru
 
 
11
  langgraph-swarm
12
  langchain-google-genai
13
  python-dotenv
14
+ loguru
15
+ python-multipart
sessions.json CHANGED
@@ -3,8 +3,8 @@
3
  "id": "82a6779d-ad13-4edd-a046-575e563a4348",
4
  "name": "New Conversation",
5
  "created_at": "2025-08-21T11:57:23.992279",
6
- "last_message": "I would like a coffee",
7
- "message_count": 7
8
  },
9
  {
10
  "id": "4fbf6c50-6054-4f3d-ac4e-d8281c306d72",
 
3
  "id": "82a6779d-ad13-4edd-a046-575e563a4348",
4
  "name": "New Conversation",
5
  "created_at": "2025-08-21T11:57:23.992279",
6
+ "last_message": "[Audio message]",
7
+ "message_count": 37
8
  },
9
  {
10
  "id": "4fbf6c50-6054-4f3d-ac4e-d8281c306d72",
src/apis/routes/chat_route.py CHANGED
@@ -1,4 +1,13 @@
1
- from fastapi import APIRouter, status, Depends, BackgroundTasks, HTTPException
 
 
 
 
 
 
 
 
 
2
  from fastapi.responses import JSONResponse
3
  from src.utils.logger import logger
4
  from src.agents.role_play.func import create_agents
@@ -9,6 +18,7 @@ import json
9
  import os
10
  import uuid
11
  from datetime import datetime
 
12
 
13
  router = APIRouter(prefix="/ai", tags=["AI"])
14
 
@@ -113,22 +123,103 @@ async def list_scenarios():
113
 
114
 
115
  @router.post("/roleplay", status_code=status.HTTP_200_OK)
116
- async def roleplay(request: RoleplayRequest):
117
- """Send a message to the roleplay agent"""
118
- scenario = request.scenario
119
- if not scenario:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  raise HTTPException(status_code=400, detail="Scenario not provided")
121
- response = await create_agents(scenario).ainvoke(
122
- {
123
- "messages": [request.query],
124
- },
125
- {"configurable": {"thread_id": request.session_id}},
126
- )
127
 
128
- # Update session with last message
129
- update_session_last_message(request.session_id, request.query)
 
 
 
 
 
 
 
 
 
 
130
 
131
- return JSONResponse(content=response["messages"][-1].content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  @router.post("/get-messages", status_code=status.HTTP_200_OK)
 
1
+ from fastapi import (
2
+ APIRouter,
3
+ status,
4
+ Depends,
5
+ BackgroundTasks,
6
+ HTTPException,
7
+ File,
8
+ UploadFile,
9
+ Form,
10
+ )
11
  from fastapi.responses import JSONResponse
12
  from src.utils.logger import logger
13
  from src.agents.role_play.func import create_agents
 
18
  import os
19
  import uuid
20
  from datetime import datetime
21
+ import base64
22
 
23
  router = APIRouter(prefix="/ai", tags=["AI"])
24
 
 
123
 
124
 
125
  @router.post("/roleplay", status_code=status.HTTP_200_OK)
126
async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a user message (text, audio, or both) to the roleplay agent.

    The request is multipart form data. ``scenario`` is a JSON-encoded object
    that is decoded and handed to ``create_agents``; the conversation thread
    is keyed by ``session_id``.

    Args:
        session_id: Thread identifier passed to the agent as ``thread_id`` and
            used to update the session's last-message summary.
        scenario: JSON string describing the roleplay scenario.
        text_message: Optional text typed by the user. Whitespace-only text is
            treated as absent.
        audio_file: Optional uploaded audio. It is base64-encoded and sent as
            an ``audio`` content part. A zero-byte upload is treated as absent.

    Returns:
        ``JSONResponse`` with body ``{"response": <last agent message content>}``.

    Raises:
        HTTPException: 400 when neither usable text nor audio is supplied,
            when ``scenario`` is not valid JSON / is empty / is not a JSON
            object, or when the upload cannot be read; 500 when the agent
            invocation fails.
    """
    has_text = bool(text_message and text_message.strip())

    # Validate that at least one input is provided
    if not has_text and audio_file is None:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # Parse scenario from JSON string
    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from e

    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    # Valid JSON can still be a list, string or number; only an object makes
    # sense as a scenario description.
    if not isinstance(scenario_dict, dict):
        raise HTTPException(status_code=400, detail="Scenario must be a JSON object")

    # Prepare message content
    message_content = []

    if has_text:
        message_content.append({"type": "text", "text": text_message})

    if audio_file is not None:
        try:
            audio_data = await audio_file.read()
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e

        # An empty file part (e.g. a form submitted with no file selected)
        # carries no audio; do not forward an empty payload to the model.
        if audio_data:
            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": base64.b64encode(audio_data).decode("utf-8"),
                    "mime_type": _guess_audio_mime_type(
                        audio_file.filename, audio_file.content_type
                    ),
                }
            )

    # Re-check after discarding blank text / empty uploads.
    if not message_content:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # Create message in the required format
    message = {"role": "user", "content": message_content}

    try:
        response = await create_agents(scenario_dict).ainvoke(
            {
                "messages": [message],
            },
            {"configurable": {"thread_id": session_id}},
        )

        # Update session with last message (use text if available, otherwise indicate audio)
        last_message = text_message if has_text else "[Audio message]"
        update_session_last_message(session_id, last_message)

        # Extract AI response content
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")

        return JSONResponse(content={"response": ai_response})

    except HTTPException:
        # Preserve deliberate HTTP errors instead of masking them as a 500.
        raise
    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        # NOTE(review): the exception text is echoed to the client; consider a
        # generic message if internal details must not leak.
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e


# Known audio file extensions and the MIME type reported to the model for each.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}


def _guess_audio_mime_type(
    filename: Optional[str], content_type: Optional[str]
) -> str:
    """Pick a MIME type for an upload from its extension, then its declared type.

    Falls back to ``audio/wav`` when neither the extension nor the
    client-declared content type identifies an audio format.
    """
    extension = os.path.splitext(filename or "")[1].lstrip(".").lower()
    if extension in _AUDIO_MIME_TYPES:
        return _AUDIO_MIME_TYPES[extension]

    # Browser recordings are often uploaded as "blob" with no extension but a
    # correct content type such as "audio/webm;codecs=opus".
    declared = (content_type or "").split(";")[0].strip().lower()
    if declared.startswith("audio/"):
        return declared

    return "audio/wav"
223
 
224
 
225
  @router.post("/get-messages", status_code=status.HTTP_200_OK)
test_audio_api.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for audio roleplay API
4
+ """
5
+
6
+ import requests
7
+ import json
8
+ import base64
9
+
10
import os

# API configuration.
# Defaults to a locally running server. To target a deployed instance without
# editing this file, set ROLEPLAY_API_BASE_URL, e.g.
#   ROLEPLAY_API_BASE_URL=https://abao77-run-code-api.hf.space
# A trailing slash is dropped so "{API_BASE_URL}/api/..." never contains "//".
API_BASE_URL = (
    os.environ.get("ROLEPLAY_API_BASE_URL") or "http://localhost:8000"
).rstrip("/")
13
+
14
def test_text_message():
    """Send a text-only message to the roleplay endpoint and print the outcome.

    This is a manual smoke test against a running server: it reports success
    or failure on stdout and never raises.
    """
    print("Testing text message...")

    scenario = {
        "scenario_title": "Restaurant Order",
        "scenario_description": "Order food at a restaurant",
        "scenario_context": "You are at a restaurant and want to order food",
        "your_role": "Customer",
        "key_vocabulary": ["menu", "order", "bill", "table"],
    }

    # The endpoint takes form fields; the scenario travels as a JSON string.
    data = {
        "session_id": "test-session-123",
        "scenario": json.dumps(scenario),
        "text_message": "Hello, I'd like to see the menu please.",
    }

    try:
        # Bound the wait so an unreachable or stalled server cannot hang the script.
        response = requests.post(
            f"{API_BASE_URL}/api/ai/roleplay", data=data, timeout=60
        )
        if response.ok:
            result = response.json()
            print("✅ Text message test successful!")
            print(f"Response: {result.get('response', 'No response')}")
        else:
            print(f"❌ Text message test failed: {response.status_code}")
            print(f"Error: {response.text}")
    except Exception as e:
        # Deliberately broad: report the failure and let later checks still run.
        print(f"❌ Text message test error: {e}")
43
+
44
def test_audio_message():
    """Send an audio-only message to the roleplay endpoint and print the outcome.

    The upload is a placeholder byte string, not real audio, so this checks the
    multipart plumbing only. Reports on stdout and never raises.
    """
    print("\nTesting audio message...")

    scenario = {
        "scenario_title": "Restaurant Order",
        "scenario_description": "Order food at a restaurant",
        "scenario_context": "You are at a restaurant and want to order food",
        "your_role": "Customer",
        "key_vocabulary": ["menu", "order", "bill", "table"],
    }

    # Create a dummy audio file (in real scenario, this would be actual audio)
    dummy_audio_data = b"fake_audio_data_for_testing"

    data = {
        "session_id": "test-session-456",
        "scenario": json.dumps(scenario),
    }

    # (filename, content, content type) tuple understood by requests.
    files = {
        "audio_file": ("test_audio.wav", dummy_audio_data, "audio/wav"),
    }

    try:
        # Bound the wait so an unreachable or stalled server cannot hang the script.
        response = requests.post(
            f"{API_BASE_URL}/api/ai/roleplay", data=data, files=files, timeout=60
        )
        if response.ok:
            result = response.json()
            print("✅ Audio message test successful!")
            print(f"Response: {result.get('response', 'No response')}")
        else:
            print(f"❌ Audio message test failed: {response.status_code}")
            print(f"Error: {response.text}")
    except Exception as e:
        # Deliberately broad: report the failure and let later checks still run.
        print(f"❌ Audio message test error: {e}")
79
+
80
def test_combined_message():
    """Send text and audio together to the roleplay endpoint and print the outcome.

    The upload is a placeholder byte string, not real audio. Reports on stdout
    and never raises.
    """
    print("\nTesting combined text + audio message...")

    scenario = {
        "scenario_title": "Restaurant Order",
        "scenario_description": "Order food at a restaurant",
        "scenario_context": "You are at a restaurant and want to order food",
        "your_role": "Customer",
        "key_vocabulary": ["menu", "order", "bill", "table"],
    }

    dummy_audio_data = b"fake_audio_data_for_testing"

    data = {
        "session_id": "test-session-789",
        "scenario": json.dumps(scenario),
        "text_message": "I have a question about the menu",
    }

    # (filename, content, content type) tuple understood by requests.
    files = {
        "audio_file": ("question.wav", dummy_audio_data, "audio/wav"),
    }

    try:
        # Bound the wait so an unreachable or stalled server cannot hang the script.
        response = requests.post(
            f"{API_BASE_URL}/api/ai/roleplay", data=data, files=files, timeout=60
        )
        if response.ok:
            result = response.json()
            print("✅ Combined message test successful!")
            print(f"Response: {result.get('response', 'No response')}")
        else:
            print(f"❌ Combined message test failed: {response.status_code}")
            print(f"Error: {response.text}")
    except Exception as e:
        # Deliberately broad: report the failure and let later checks still run.
        print(f"❌ Combined message test error: {e}")
115
+
116
def main():
    """Run the text, audio, and combined roleplay API checks in order."""
    print("🧪 Testing Audio Roleplay API")
    print("=" * 50)

    test_text_message()
    test_audio_message()
    test_combined_message()

    print("\n" + "=" * 50)
    print("🏁 Testing completed!")


if __name__ == "__main__":
    main()