"""Lesson API routes: read lessons from the bundled JSON file and chat with the lesson practice agent."""

from fastapi import (
    APIRouter,
    status,
    Depends,
    BackgroundTasks,
    HTTPException,
    File,
    UploadFile,
    Form,
)
from fastapi.responses import JSONResponse, StreamingResponse
from src.utils.logger import logger
from src.services.tts_service import tts_service
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Tuple
from src.agents.lesson_practice.flow import lesson_practice_agent
from src.apis.models.lesson_models import Lesson, LessonResponse, LessonDetailResponse
import json
import os
import uuid
from datetime import datetime
import base64
import asyncio

router = APIRouter(prefix="/lesson", tags=["AI"])

# Maps the extension of an uploaded audio file to the MIME type forwarded to
# the agent. Extensions that are not listed fall back to the default below.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}
_DEFAULT_AUDIO_MIME_TYPE = "audio/wav"


class LessonPracticeRequest(BaseModel):
    """Request body describing a lesson plus the user's query for a practice session."""

    unit: str = Field(..., description="Unit of the lesson")
    vocabulary: list = Field(..., description="Vocabulary for the lesson")
    key_structures: list = Field(..., description="Key structures for the lesson")
    practice_questions: list = Field(
        ..., description="Practice questions for the lesson"
    )
    student_level: str = Field("beginner", description="Student's level of English")
    query: str = Field(..., description="User query for the lesson")
    session_id: str = Field(..., description="Session ID for the lesson")


# Helper function to load lessons from JSON file
def load_lessons_from_file() -> List[Lesson]:
    """
    Load lessons from the JSON file

    The file is looked up at ``../../data/lessons.json`` relative to this
    module. Entries that cannot be converted to ``Lesson`` are logged and
    skipped so that one malformed entry does not hide the valid ones.

    Returns:
        List[Lesson]: The parsed lessons; an empty list when the file is
        missing or cannot be read/decoded.
    """
    try:
        lessons_file_path = os.path.join(
            os.path.dirname(__file__), "..", "..", "data", "lessons.json"
        )

        if not os.path.exists(lessons_file_path):
            logger.warning(f"Lessons file not found at {lessons_file_path}")
            return []

        with open(lessons_file_path, "r", encoding="utf-8") as file:
            lessons_data = json.load(file)

        # Convert to Lesson objects
        lessons = []
        for lesson_data in lessons_data:
            try:
                lessons.append(Lesson(**lesson_data))
            except Exception as e:
                # The entry may not even be a dict (which is itself a reason
                # for the failure), so only call .get() when it is safe;
                # otherwise the handler would raise and discard every lesson.
                lesson_ref = (
                    lesson_data.get("id", "unknown")
                    if isinstance(lesson_data, dict)
                    else "unknown"
                )
                logger.error(f"Error parsing lesson {lesson_ref}: {str(e)}")
                continue

        return lessons
    except Exception as e:
        logger.error(f"Error loading lessons: {str(e)}")
        return []


def _parse_lesson_data(lesson_data: str) -> Dict[str, Any]:
    """Decode the ``lesson_data`` form field; raise HTTP 400 unless it is a non-empty JSON object."""
    try:
        lesson_dict = json.loads(lesson_data)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid lesson_data JSON format")

    if not lesson_dict:
        raise HTTPException(status_code=400, detail="Lesson data not provided")

    # Valid JSON that is not an object (e.g. a list or a number) has no
    # .get(), so reject it here instead of failing later with a server error.
    if not isinstance(lesson_dict, dict):
        raise HTTPException(
            status_code=400, detail="lesson_data must be a JSON object"
        )

    return lesson_dict


async def _build_message_content(
    text_message: Optional[str], audio_file: Optional[UploadFile]
) -> List[Dict[str, Any]]:
    """Build the message content parts: optional text first, then optional base64-encoded audio."""
    message_content: List[Dict[str, Any]] = []

    # Handle text input
    if text_message:
        message_content.append({"type": "text", "text": text_message})

    # Handle audio input
    if audio_file:
        try:
            audio_data = await audio_file.read()
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")

            # Determine mime type based on file extension
            file_extension = (
                audio_file.filename.split(".")[-1].lower()
                if audio_file.filename
                else "wav"
            )
            mime_type = _AUDIO_MIME_TYPES.get(
                file_extension, _DEFAULT_AUDIO_MIME_TYPE
            )

            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_base64,
                    "mime_type": mime_type,
                }
            )
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            )

    return message_content


async def _prepare_chat_request(
    lesson_data: str,
    text_message: Optional[str],
    audio_file: Optional[UploadFile],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Validate the chat form inputs and return ``(lesson_dict, user_message)``.

    Checks run in a fixed order (input present, lesson data valid, audio
    readable); each failure raises ``HTTPException`` with status 400.
    """
    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    lesson_dict = _parse_lesson_data(lesson_data)
    message_content = await _build_message_content(text_message, audio_file)

    # Create message in the required format
    message = {"role": "user", "content": message_content}
    return lesson_dict, message


def _build_graph_input(
    lesson_dict: Dict[str, Any], message: Dict[str, Any]
) -> Dict[str, Any]:
    """Assemble the agent input from the user message and lesson fields, defaulting missing fields."""
    return {
        "messages": [message],
        "unit": lesson_dict.get("unit", ""),
        "vocabulary": lesson_dict.get("vocabulary", []),
        "key_structures": lesson_dict.get("key_structures", []),
        "practice_questions": lesson_dict.get("practice_questions", []),
        "student_level": lesson_dict.get("student_level", "beginner"),
    }


@router.get("/all", response_model=LessonResponse)
async def get_all_lessons():
    """
    Get all available lessons

    Returns:
        LessonResponse: Contains list of all lessons and total count

    Raises:
        HTTPException: 500 if the lessons cannot be retrieved
    """
    try:
        lessons = load_lessons_from_file()
        return LessonResponse(lessons=lessons, total=len(lessons))
    except Exception as e:
        logger.error(f"Error retrieving lessons: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve lessons",
        )


@router.get("/{lesson_id}", response_model=LessonDetailResponse)
async def get_lesson_by_id(lesson_id: str):
    """
    Get a specific lesson by ID

    Args:
        lesson_id (str): The unique identifier of the lesson

    Returns:
        LessonDetailResponse: Contains the lesson details

    Raises:
        HTTPException: 404 if no lesson has this ID, 500 on any other failure
    """
    try:
        lessons = load_lessons_from_file()

        # Find the lesson with the specified ID
        lesson = next(
            (candidate for candidate in lessons if candidate.id == lesson_id), None
        )

        if lesson is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Lesson with ID '{lesson_id}' not found",
            )

        return LessonDetailResponse(lesson=lesson)
    except HTTPException:
        # Let the 404 above through instead of converting it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error retrieving lesson {lesson_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve lesson",
        )


@router.get("/search/unit/{unit_name}", response_model=LessonResponse)
async def search_lessons_by_unit(unit_name: str):
    """
    Search lessons by unit name (case-insensitive partial match)

    Args:
        unit_name (str): Part of the unit name to search for

    Returns:
        LessonResponse: Contains list of matching lessons

    Raises:
        HTTPException: 500 if the search fails
    """
    try:
        lessons = load_lessons_from_file()

        # Filter lessons by unit name (case-insensitive partial match)
        needle = unit_name.lower()
        matching_lessons = [
            lesson for lesson in lessons if needle in lesson.unit.lower()
        ]

        return LessonResponse(lessons=matching_lessons, total=len(matching_lessons))
    except Exception as e:
        logger.error(f"Error searching lessons by unit '{unit_name}': {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to search lessons",
        )


@router.post("/chat")
async def chat(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    lesson_data: str = Form(..., description="The lesson data as JSON string"),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the lesson practice v2 agent with Practice and Teaching agents

    Args:
        session_id: Passed to the agent as the ``thread_id`` of the run.
        lesson_data: JSON object (as a string) with the lesson fields.
        text_message: Optional text from the user.
        audio_file: Optional audio upload; forwarded base64-encoded.

    Returns:
        JSONResponse: ``{"response": <content of the last agent message>}``

    Raises:
        HTTPException: 400 for missing/invalid input, 500 if the agent call fails
    """
    lesson_dict, message = await _prepare_chat_request(
        lesson_data, text_message, audio_file
    )

    try:
        response = await lesson_practice_agent().ainvoke(
            _build_graph_input(lesson_dict, message),
            {"configurable": {"thread_id": session_id}},
        )

        # Extract AI response content
        ai_response = response["messages"][-1].content
        logger.info(f"AI response (v2): {ai_response}")

        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in lesson practice v2: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@router.post("/chat/stream", status_code=status.HTTP_200_OK)
async def chat_stream(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    lesson_data: str = Form(..., description="The lesson data as JSON string"),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
    audio: bool = Form(False, description="Whether to return TTS audio response"),
):
    """Send a message (text or audio) to the lesson practice v2 agent with streaming response

    The response body is a sequence of ``data: <json>\\n\\n`` events whose
    ``type`` is ``message_chunk`` (one per non-empty chunk), then a final
    ``completion`` (carrying the TTS audio when ``audio`` is true and
    synthesis succeeded, else ``null``), or ``error`` if streaming fails.

    Args:
        session_id: Passed to the agent as the ``thread_id`` of the run.
        lesson_data: JSON object (as a string) with the lesson fields.
        text_message: Optional text from the user.
        audio_file: Optional audio upload; forwarded base64-encoded.
        audio: Whether to synthesize the accumulated reply text with TTS.

    Raises:
        HTTPException: 400 for missing/invalid input (before streaming starts)
    """
    logger.info(f"Received streaming lesson practice v2 request: {session_id}")

    lesson_dict, message = await _prepare_chat_request(
        lesson_data, text_message, audio_file
    )

    async def generate_stream():
        """Generator function for streaming responses"""
        accumulated_content = ""

        try:
            input_graph = _build_graph_input(lesson_dict, message)
            config = {"configurable": {"thread_id": session_id}}

            async for event in lesson_practice_agent().astream(
                input=input_graph,
                stream_mode=["messages"],
                config=config,
                subgraphs=True,
            ):
                _, event_type, message_chunk = event
                if event_type != "messages":
                    continue

                # message_chunk is a tuple, get the first element which is the actual AIMessageChunk
                if isinstance(message_chunk, tuple) and len(message_chunk) > 0:
                    actual_message = message_chunk[0]
                else:
                    actual_message = message_chunk
                content = getattr(actual_message, "content", "")

                if not content:
                    continue

                # Accumulate content for TTS. Only plain strings can be
                # concatenated; a chunk whose content is a list of content
                # blocks is still streamed below but left out of the TTS text.
                if isinstance(content, str):
                    accumulated_content += content

                # Create SSE-formatted response
                response_data = {
                    "type": "message_chunk",
                    "content": content,
                    "metadata": {
                        "agent": getattr(actual_message, "name", "unknown"),
                        "id": getattr(actual_message, "id", ""),
                        "usage_metadata": getattr(
                            actual_message, "usage_metadata", {}
                        ),
                    },
                }
                yield f"data: {json.dumps(response_data)}\n\n"

                # Small delay to prevent overwhelming the client
                await asyncio.sleep(0.01)

            # Generate TTS audio if requested
            audio_data = None
            if audio and accumulated_content.strip():
                try:
                    logger.info(
                        f"Generating TTS for lesson v2 content: {len(accumulated_content)} chars"
                    )
                    audio_result = await tts_service.text_to_speech(accumulated_content)
                    if audio_result:
                        audio_data = {
                            "audio_data": audio_result["audio_data"],
                            "mime_type": audio_result["mime_type"],
                            "format": audio_result["format"],
                        }
                        logger.info("Lesson v2 TTS audio generated successfully")
                    else:
                        logger.warning("Lesson v2 TTS generation failed")
                except Exception as tts_error:
                    # TTS is best-effort: the completion event is still sent,
                    # just without audio.
                    logger.error(f"Lesson v2 TTS generation error: {str(tts_error)}")

            # Send completion signal with optional audio
            completion_data = {"type": "completion", "content": "", "audio": audio_data}
            yield f"data: {json.dumps(completion_data)}\n\n"

        except Exception as e:
            # Headers are already sent once streaming has begun, so failures
            # are reported to the client as an in-band error event.
            logger.error(f"Error in streaming lesson practice v2: {str(e)}")
            error_data = {"type": "error", "content": str(e)}
            yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(
        generate_stream(),
        # Keep media_type in agreement with the explicit Content-Type header
        # below (it previously said "text/plain").
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
        },
    )