ABAO77 committed on
Commit
a4cb278
·
1 Parent(s): c0827a3

Add trim_history function and update lesson chat route to handle text and audio inputs

Browse files
src/agents/lesson_practice/flow.py CHANGED
@@ -1,5 +1,5 @@
1
  from langgraph.graph import StateGraph, START, END
2
- from .func import State, agent, tool_node
3
  from langgraph.graph.state import CompiledStateGraph
4
  from langgraph.checkpoint.memory import InMemorySaver
5
 
@@ -18,12 +18,14 @@ class LessonPracticeAgent:
18
  return "continue"
19
 
20
  def node(self, graph: StateGraph):
 
21
  graph.add_node("agent", agent)
22
  graph.add_node("tools", tool_node)
23
  return graph
24
 
25
  def edge(self, graph: StateGraph):
26
- graph.add_edge(START, "agent")
 
27
  graph.add_conditional_edges(
28
  "agent", self.should_continue, {"end": END, "continue": "tools"}
29
  )
 
1
  from langgraph.graph import StateGraph, START, END
2
+ from .func import State, trim_history, agent, tool_node
3
  from langgraph.graph.state import CompiledStateGraph
4
  from langgraph.checkpoint.memory import InMemorySaver
5
 
 
18
  return "continue"
19
 
20
  def node(self, graph: StateGraph):
21
+ graph.add_node("trim_history", trim_history)
22
  graph.add_node("agent", agent)
23
  graph.add_node("tools", tool_node)
24
  return graph
25
 
26
  def edge(self, graph: StateGraph):
27
+ graph.add_edge(START, "trim_history")
28
+ graph.add_edge("trim_history", "agent")
29
  graph.add_conditional_edges(
30
  "agent", self.should_continue, {"end": END, "continue": "tools"}
31
  )
src/agents/lesson_practice/func.py CHANGED
@@ -3,7 +3,7 @@ from typing import (
3
  Sequence,
4
  TypedDict,
5
  )
6
- from langchain_core.messages import ToolMessage, AnyMessage
7
  from langgraph.graph.message import add_messages
8
  import json
9
  from .prompt import conversation_prompt
@@ -26,6 +26,19 @@ tools = []
26
  tools_by_name = {tool.name: tool for tool in tools}
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Define our tool node
30
  def tool_node(state: State):
31
  outputs = []
 
3
  Sequence,
4
  TypedDict,
5
  )
6
+ from langchain_core.messages import ToolMessage, AnyMessage, RemoveMessage
7
  from langgraph.graph.message import add_messages
8
  import json
9
  from .prompt import conversation_prompt
 
26
  tools_by_name = {tool.name: tool for tool in tools}
27
 
28
 
29
def trim_history(state: State):
    """Normalize state before the agent runs: set a default active agent
    and cap the message history.

    When the history holds more than 25 messages, emit RemoveMessage
    markers for everything except the 5 most recent, so the graph's
    add_messages reducer deletes the older entries by id.

    Returns the (mutated) state dict.
    """
    if not state.get("active_agent"):
        state["active_agent"] = "Roleplay Agent"
    history = state.get("messages", [])
    if len(history) > 25:
        # Keep only the 5 newest messages; history[:-5] is every message
        # scheduled for removal (same set as the old index-based loop).
        state["messages"] = [
            RemoveMessage(id=message.id) for message in history[:-5]
        ]
    return state
40
+
41
+
42
  # Define our tool node
43
  def tool_node(state: State):
44
  outputs = []
src/apis/routes/lesson_route.py CHANGED
@@ -1,4 +1,13 @@
1
- from fastapi import APIRouter, status, Depends, BackgroundTasks, HTTPException
 
 
 
 
 
 
 
 
 
2
  from fastapi.responses import JSONResponse
3
  from src.utils.logger import logger
4
  from pydantic import BaseModel, Field
@@ -9,6 +18,7 @@ import json
9
  import os
10
  import uuid
11
  from datetime import datetime
 
12
 
13
  router = APIRouter(prefix="/lesson", tags=["AI"])
14
 
@@ -144,18 +154,101 @@ async def search_lessons_by_unit(unit_name: str):
144
 
145
 
146
  @router.post("/chat")
147
- async def chat(request: LessonPracticeRequest):
148
- response = await lesson_practice_agent().ainvoke(
149
- {
150
- "unit": request.unit,
151
- "vocabulary": request.vocabulary,
152
- "key_structures": request.key_structures,
153
- "practice_questions": request.practice_questions,
154
- "student_level": request.student_level,
155
- "messages": [request.query],
156
- },
157
- {"configurable": {"thread_id": request.session_id}},
158
- )
159
- return JSONResponse(
160
- content=response["messages"][-1].content, status_code=status.HTTP_200_OK
161
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import (
2
+ APIRouter,
3
+ status,
4
+ Depends,
5
+ BackgroundTasks,
6
+ HTTPException,
7
+ File,
8
+ UploadFile,
9
+ Form,
10
+ )
11
  from fastapi.responses import JSONResponse
12
  from src.utils.logger import logger
13
  from pydantic import BaseModel, Field
 
18
  import os
19
  import uuid
20
  from datetime import datetime
21
+ import base64
22
 
23
  router = APIRouter(prefix="/lesson", tags=["AI"])
24
 
 
154
 
155
 
156
# Supported audio extensions mapped to MIME types; unknown extensions
# fall back to audio/wav. Hoisted to module level so the table is built
# once, not per request.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}


async def _audio_message_part(audio_file: UploadFile) -> dict:
    """Read an uploaded audio file and return a base64 message part.

    The part has the shape {"type": "audio", "source_type": "base64",
    "data": <b64>, "mime_type": <mime>}; the MIME type is guessed from
    the filename extension, defaulting to audio/wav.

    Raises HTTPException(400) if the upload cannot be read or encoded.
    """
    try:
        audio_data = await audio_file.read()
        audio_base64 = base64.b64encode(audio_data).decode("utf-8")
        # Filename may be missing entirely; treat that as a wav upload.
        file_extension = (
            audio_file.filename.rsplit(".", 1)[-1].lower()
            if audio_file.filename
            else "wav"
        )
        return {
            "type": "audio",
            "source_type": "base64",
            "data": audio_base64,
            "mime_type": _AUDIO_MIME_TYPES.get(file_extension, "audio/wav"),
        }
    except Exception as e:
        logger.error(f"Error processing audio file: {str(e)}")
        raise HTTPException(
            status_code=400, detail=f"Error processing audio file: {str(e)}"
        ) from e


@router.post("/chat")
async def chat(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    lesson_data: str = Form(
        ..., description="The lesson data as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the lesson practice agent.

    Accepts multipart form data: a session id (used as the LangGraph
    thread id), a JSON-encoded lesson payload, and at least one of a
    text message or an audio upload. Returns {"response": <agent text>}.

    Raises HTTPException(400) on missing/invalid input and
    HTTPException(500) if the agent invocation fails.
    """
    # At least one modality is required.
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # lesson_data arrives as a JSON string inside the form.
    try:
        lesson_dict = json.loads(lesson_data)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid lesson_data JSON format"
        ) from e

    if not lesson_dict:
        raise HTTPException(status_code=400, detail="Lesson data not provided")

    # Build the multimodal user message: text part first, then audio.
    message_content = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})
    if audio_file:
        message_content.append(await _audio_message_part(audio_file))

    message = {"role": "user", "content": message_content}

    try:
        response = await lesson_practice_agent().ainvoke(
            {
                "messages": [message],
                "unit": lesson_dict.get("unit", ""),
                "vocabulary": lesson_dict.get("vocabulary", []),
                "key_structures": lesson_dict.get("key_structures", []),
                "practice_questions": lesson_dict.get("practice_questions", []),
                "student_level": lesson_dict.get("student_level", "beginner"),
            },
            {"configurable": {"thread_id": session_id}},
        )

        # The agent's reply is the last message in the returned state.
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")

        return JSONResponse(content={"response": ai_response})

    except Exception as e:
        # Boundary handler: surface agent failures as a 500 with detail.
        logger.error(f"Error in lesson practice: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e