Spaces:
Sleeping
Sleeping
| from fastapi import UploadFile, File, Form, HTTPException, APIRouter | |
| from pydantic import BaseModel | |
| from typing import List, Dict | |
| import tempfile | |
| import numpy as np | |
| import re | |
| import warnings | |
| from loguru import logger | |
| from src.apis.controllers.speaking_controller import ( | |
| SimpleG2P, | |
| PhonemeComparator, | |
| SimplePronunciationAssessor, | |
| ) | |
| from src.utils.speaking_utils import convert_numpy_types | |
# Suppress every warning category for the whole process from import time onward.
# NOTE(review): presumably meant to quiet noisy ML-library warnings; it also hides
# deprecation warnings from this project's own code — confirm this is intended.
warnings.filterwarnings("ignore")

# Router that groups this module's endpoints under the "/pronunciation" prefix.
router = APIRouter(prefix="/pronunciation", tags=["Pronunciation"])
class PronunciationAssessmentResult(BaseModel):
    """Response model built from the assessor's result dictionary.

    ``assess_pronunciation`` constructs this model by unpacking the converted
    result dict, so every field below must be present as a key in that dict.
    """

    transcript: str  # What the user actually said (character transcript)
    transcript_phonemes: str  # User's phonemes
    user_phonemes: str  # Alias for transcript_phonemes for UI clarity
    character_transcript: str
    overall_score: float  # NOTE(review): value range is not visible here — confirm
    word_highlights: List[Dict]
    phoneme_differences: List[Dict]
    wrong_words: List[Dict]
    feedback: List[str]
    # The endpoint adds a "processing_time" entry (seconds) to this dict.
    processing_info: Dict
# Single module-level assessor instance, created at import time and reused by
# assess_pronunciation for every request.
assessor = SimplePronunciationAssessor()
async def assess_pronunciation(
    audio: UploadFile = File(..., description="Audio file (.wav, .mp3, .m4a)"),
    reference_text: str = Form(..., description="Reference text to pronounce"),
    mode: str = Form(
        "normal",
        description="Assessment mode: 'normal' (Whisper) or 'advanced' (Wav2Vec2)",
    ),
):
    """
    Pronunciation Assessment API with mode selection

    Key Features:
    - Normal mode: Uses Whisper for more accurate transcription with language model
    - Advanced mode: Uses facebook/wav2vec2-large-960h-lv60-self for character transcription
    - NO language model correction in advanced mode (shows actual pronunciation errors)
    - Character-level accuracy converted to phoneme representation
    - Vietnamese-optimized feedback and tips

    Input: Audio file + Reference text + Mode
    Output: Word highlights + Phoneme differences + Wrong words

    Args:
        audio: Uploaded recording; it is written to a temporary file that is
            handed to the assessor and removed once the request finishes.
        reference_text: Text the user was asked to read (1-500 characters,
            English letters, whitespace and basic punctuation only).
        mode: Either "normal" or "advanced".

    Returns:
        A PronunciationAssessmentResult built from the assessor's result dict,
        with the elapsed time stored under processing_info["processing_time"].

    Raises:
        HTTPException: 400 for an invalid mode, reference text or empty upload;
            500 when the assessment itself fails.
    """
    import os
    import time

    # Monotonic clock: elapsed time must not be affected by wall-clock changes.
    start_time = time.perf_counter()

    # Validate mode
    if mode not in ("normal", "advanced"):
        raise HTTPException(
            status_code=400, detail="Mode must be 'normal' or 'advanced'"
        )

    # Validate inputs
    if not reference_text.strip():
        raise HTTPException(status_code=400, detail="Reference text cannot be empty")
    if len(reference_text) > 500:
        raise HTTPException(
            status_code=400, detail="Reference text too long (max 500 characters)"
        )

    # Check for valid English characters
    if not re.match(r"^[a-zA-Z\s\'\-\.!?,;:]+$", reference_text):
        raise HTTPException(
            status_code=400,
            detail="Text must contain only English letters, spaces, and basic punctuation",
        )

    content = await audio.read()
    if not content:
        raise HTTPException(status_code=400, detail="Audio file cannot be empty")

    # The filename is client-controlled: keep only a short alphanumeric extension
    # so path separators can never end up in the temporary file's suffix.
    file_extension = ".wav"
    if audio.filename:
        candidate = os.path.splitext(os.path.basename(audio.filename))[1]
        if re.fullmatch(r"\.[A-Za-z0-9]{1,10}", candidate):
            file_extension = candidate

    tmp_path = None
    try:
        # Save uploaded file temporarily; delete=False because the assessor
        # reopens it by path, so it is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=file_extension
        ) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(content)
        # The handle is closed here so the assessor can reopen the path on
        # every platform and all bytes are guaranteed to be on disk.

        logger.info(f"Processing audio file: {tmp_path} with mode: {mode}")

        # Run assessment using selected mode
        # NOTE(review): this is a synchronous call inside an async handler and
        # will block the event loop while it runs — confirm this is acceptable.
        result = assessor.assess_pronunciation(tmp_path, reference_text, mode)

        # Add processing time
        processing_time = time.perf_counter() - start_time
        result["processing_info"]["processing_time"] = processing_time

        # Convert numpy types for JSON serialization
        final_result = convert_numpy_types(result)

        logger.info(
            f"Assessment completed in {processing_time:.2f} seconds using {mode} mode"
        )
        return PronunciationAssessmentResult(**final_result)
    except Exception as e:
        # logger.exception records the traceback through the logging pipeline
        # instead of printing it straight to stderr.
        logger.exception(f"Assessment error: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Assessment failed: {str(e)}"
        ) from e
    finally:
        # Always remove the temporary upload, on success and on failure alike.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Could not remove temporary file {tmp_path}: {cleanup_error}"
                )
| # ============================================================================= | |
| # UTILITY ENDPOINTS | |
| # ============================================================================= | |
async def get_word_phonemes(word: str):
    """Get phoneme breakdown for a specific word.

    Args:
        word: The word to analyse; must contain at least one non-space character.

    Returns:
        A dict with the word, its phoneme list, phoneme string and IPA form,
        the mean difficulty of its phonemes, a difficulty level
        ("hard" above 0.6, "medium" above 0.4, otherwise "easy"), and the
        phonemes whose difficulty exceeds 0.6 together with a Vietnamese tip.

    Raises:
        HTTPException: 400 when the word is blank; 500 when the analysis fails.
    """
    # Reject blank input up front so it is reported as a client error instead
    # of failing inside the analysis and surfacing as a 500.
    if not word or not word.strip():
        raise HTTPException(status_code=400, detail="Word cannot be empty")

    try:
        g2p = SimpleG2P()
        comparator = PhonemeComparator()

        phoneme_data = g2p.text_to_phonemes(word)[0]
        phonemes = phoneme_data["phonemes"]

        # Add difficulty analysis for Vietnamese speakers. Each phoneme is
        # looked up once; phonemes missing from the map default to 0.3.
        difficulty_scores = [
            comparator.difficulty_map.get(phoneme, 0.3) for phoneme in phonemes
        ]
        avg_difficulty = (
            float(np.mean(difficulty_scores)) if difficulty_scores else 0.3
        )

        if avg_difficulty > 0.6:
            difficulty_level = "hard"
        elif avg_difficulty > 0.4:
            difficulty_level = "medium"
        else:
            difficulty_level = "easy"

        return {
            "word": word,
            "phonemes": phonemes,
            "phoneme_string": phoneme_data["phoneme_string"],
            "ipa": phoneme_data["ipa"],
            "difficulty_score": avg_difficulty,
            "difficulty_level": difficulty_level,
            "challenging_phonemes": [
                {
                    "phoneme": phoneme,
                    "difficulty": difficulty,
                    "vietnamese_tip": get_vietnamese_tip(phoneme),
                }
                for phoneme, difficulty in zip(phonemes, difficulty_scores)
                if difficulty > 0.6
            ],
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Word analysis error: {str(e)}"
        ) from e
# Articulation tips (in Vietnamese) keyed by IPA phoneme symbol.
_VIETNAMESE_TIPS = {
    "θ": "Đặt lưỡi giữa răng, thổi nhẹ",
    "ð": "Giống θ nhưng rung dây thanh âm",
    "v": "Môi dưới chạm răng trên",
    "r": "Cuộn lưỡi, không chạm vòm miệng",
    "l": "Lưỡi chạm vòm miệng sau răng",
    "z": "Như 's' nhưng rung dây thanh",
    "ʒ": "Như 'ʃ' nhưng rung dây thanh",
    "w": "Tròn môi như 'u'",
}


def get_vietnamese_tip(phoneme: str) -> str:
    """Return the Vietnamese articulation tip for ``phoneme``.

    Phonemes without a dedicated tip get a generic "practise this sound"
    message that embeds the phoneme itself.
    """
    specific_tip = _VIETNAMESE_TIPS.get(phoneme)
    if specific_tip is None:
        return f"Luyện âm {phoneme}"
    return specific_tip