Spaces:
Sleeping
Sleeping
# PRONUNCIATION ASSESSMENT USING WAV2VEC2 (CHARACTER-LEVEL CTC)
# Input: Audio + Reference Text → Output: Word highlights + Phoneme diff + Wrong words
# Uses a character-level Wav2Vec2 CTC model decoded greedily (no language model correction);
# the character transcript is then converted to phonemes with a CMUdict-based G2P for comparison.
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, APIRouter | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Dict, Optional | |
| import tempfile | |
| import os | |
| import numpy as np | |
| import librosa | |
| import nltk | |
| import eng_to_ipa as ipa | |
| import torch | |
| import re | |
| from collections import defaultdict | |
| import warnings | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2PhonemeCTCTokenizer | |
| warnings.filterwarnings("ignore") | |
# Download required NLTK data.
# The corpus is optional: when it cannot be fetched or imported the service
# still starts and SimpleG2P falls back to rule-based phoneme estimation.
try:
    nltk.download("cmudict", quiet=True)
    from nltk.corpus import cmudict
except Exception:  # narrowed from a bare except so Ctrl-C / SystemExit still propagate
    cmudict = None  # keep the name defined; consumers treat any failure as "no dictionary"
    print("Warning: NLTK data not available")
| # ============================================================================= | |
| # MODELS | |
| # ============================================================================= | |
# Router for the pronunciation feature; routes attached to it are served under
# the "/pronunciation" prefix and tagged "Pronunciation".
router = APIRouter(
    prefix="/pronunciation",
    tags=["Pronunciation"],
)
class PronunciationAssessmentResult(BaseModel):
    # --- What the learner said -------------------------------------------
    # Character transcript of what the user actually said.
    transcript: str
    # Phonemes derived from the learner's transcript.
    transcript_phonemes: str
    # Alias of transcript_phonemes, exposed separately for UI clarity.
    user_phonemes: str
    character_transcript: str

    # --- Scoring ---------------------------------------------------------
    overall_score: float
    # One entry per reference word.
    word_highlights: List[Dict]
    # One entry per compared phoneme position.
    phoneme_differences: List[Dict]
    # Subset of words considered mispronounced, with details.
    wrong_words: List[Dict]

    # --- Presentation / diagnostics --------------------------------------
    feedback: List[str]
    processing_info: Dict
| # ============================================================================= | |
| # WAV2VEC2 PHONEME ASR | |
| # ============================================================================= | |
class Wav2Vec2CharacterASR:
    """Character-level Wav2Vec2 CTC recognizer decoded greedily (no language model).

    Besides the raw character transcript, the recognizer produces a
    phoneme string for the transcript (via ``SimpleG2P``, with a
    letter-by-letter fallback) so it can be compared with reference phonemes.
    """

    # Checkpoint tried when the requested one cannot be loaded.
    _FALLBACK_MODEL = "facebook/wav2vec2-base-960h"

    # Cap on the number of per-frame confidences returned (keeps JSON small).
    _MAX_CONFIDENCE_FRAMES = 100

    # Last-resort spelling-to-sound table used when G2P fails for a word.
    _LETTER_TO_PHONEME = {
        'a': 'æ', 'b': 'b', 'c': 'k', 'd': 'd', 'e': 'ɛ',
        'f': 'f', 'g': 'ɡ', 'h': 'h', 'i': 'ɪ', 'j': 'dʒ',
        'k': 'k', 'l': 'l', 'm': 'm', 'n': 'n', 'o': 'ʌ',
        'p': 'p', 'q': 'k', 'r': 'r', 's': 's', 't': 't',
        'u': 'ʌ', 'v': 'v', 'w': 'w', 'x': 'ks', 'y': 'j', 'z': 'z',
    }

    def __init__(self, model_name: str = "facebook/wav2vec2-base-960h"):
        """Load the processor and CTC model, falling back to the base checkpoint.

        Available models:
        - facebook/wav2vec2-large-960h-lv60-self (character-level, no LM)
        - facebook/wav2vec2-base-960h (character-level, no LM)
        - facebook/wav2vec2-large-960h (character-level, no LM)

        Raises:
            Exception: if neither the requested nor the fallback model loads.
        """
        print(f"Loading Wav2Vec2 character model: {model_name}")
        try:
            self._load(model_name)
            print("Wav2Vec2 character model loaded successfully")
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            fallback_model = self._FALLBACK_MODEL
            print(f"Trying fallback model: {fallback_model}")
            try:
                self._load(fallback_model)
                print("Fallback model loaded successfully")
            except Exception as e2:
                raise Exception(
                    f"Failed to load both models. Original error: {e}, Fallback error: {e2}"
                ) from e2
        self.sample_rate = 16000
        # Created lazily on first use and then reused for every transcription.
        self._g2p = None

    def _load(self, model_name: str) -> None:
        """Load processor + model for ``model_name`` and switch to eval mode."""
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        self.model.eval()
        self.model_name = model_name

    def _get_g2p(self):
        """Return the shared G2P converter, building it on first use.

        Building a converter loads the pronunciation dictionary, so it must not
        happen once per request.
        """
        g2p = getattr(self, "_g2p", None)
        if g2p is None:
            g2p = self._g2p = SimpleG2P()
        return g2p

    def transcribe_to_characters(self, audio_path: str) -> Dict:
        """Transcribe an audio file to characters without LM correction.

        Returns a dict with ``character_transcript``, ``phoneme_representation``,
        ``raw_predicted_ids`` and ``confidence_scores`` (truncated). On any
        failure the error is printed and the same keys are returned with empty
        values, so callers never see an exception from this method.
        """
        try:
            # Load audio resampled to the model's expected rate.
            speech, _ = librosa.load(audio_path, sr=self.sample_rate)

            input_values = self.processor(
                speech,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
            ).input_values

            # Greedy CTC decoding: arg-max per frame, no language model involved.
            with torch.no_grad():
                logits = self.model(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)

            character_transcript = self._clean_character_transcript(
                self.processor.batch_decode(predicted_ids)[0]
            )
            phoneme_like_transcript = self._characters_to_phoneme_representation(
                character_transcript
            )

            # Highest softmax probability per frame for the first batch item.
            frame_confidence = torch.softmax(logits, dim=-1).max(dim=-1)[0][0].tolist()

            return {
                "character_transcript": character_transcript,
                "phoneme_representation": phoneme_like_transcript,
                "raw_predicted_ids": predicted_ids[0].tolist(),
                "confidence_scores": frame_confidence[: self._MAX_CONFIDENCE_FRAMES],
            }
        except Exception as e:
            # Deliberate best-effort: report and return an empty result.
            print(f"Transcription error: {e}")
            return {
                "character_transcript": "",
                "phoneme_representation": "",
                "raw_predicted_ids": [],
                "confidence_scores": [],
            }

    def _clean_character_transcript(self, transcript: str) -> str:
        """Collapse whitespace runs, trim, and lowercase the transcript."""
        return re.sub(r'\s+', ' ', transcript).strip().lower()

    def _characters_to_phoneme_representation(self, text: str) -> str:
        """Convert the character transcript to a space-separated phoneme string.

        Each word goes through G2P; if that fails for a word (including G2P
        yielding nothing for it) the word is spelled out with the simple
        letter-to-sound table instead.
        """
        if not text:
            return ""

        g2p = self._get_g2p()
        phonemes: List[str] = []
        for word in text.split():
            try:
                word_phonemes = g2p.text_to_phonemes(word)[0]["phonemes"]
            except Exception:
                # Fallback: simple letter-to-sound mapping
                word_phonemes = self._simple_letter_to_phoneme(word)
            phonemes.extend(word_phonemes)
        return " ".join(phonemes)

    def _simple_letter_to_phoneme(self, word: str) -> List[str]:
        """Map each known letter to one phoneme symbol; unknown characters are skipped."""
        table = self._LETTER_TO_PHONEME
        return [table[letter] for letter in word.lower() if letter in table]
| # ============================================================================= | |
| # SIMPLE G2P FOR REFERENCE | |
| # ============================================================================= | |
class SimpleG2P:
    """Simple grapheme-to-phoneme converter for reference text.

    Words found in the CMU dictionary use its first listed pronunciation
    (stress digits removed, symbols mapped to IPA-style characters); other words
    are estimated with a small spelling-rule table.
    """

    # Parsed CMU dictionary shared by all instances so it is loaded only once.
    _shared_cmu_dict = None

    # Mapping from CMU (ARPAbet) symbols to Wav2Vec2/eSpeak-style phonemes.
    _CMU_TO_ESPEAK = {
        "AA": "ɑ", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ",
        "AY": "aɪ", "EH": "ɛ", "ER": "ɝ", "EY": "eɪ", "IH": "ɪ",
        "IY": "i", "OW": "oʊ", "OY": "ɔɪ", "UH": "ʊ", "UW": "u",
        "B": "b", "CH": "tʃ", "D": "d", "DH": "ð", "F": "f",
        "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k", "L": "l",
        "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "r",
        "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v",
        "W": "w", "Y": "j", "Z": "z", "ZH": "ʒ",
    }

    # Spelling rules for out-of-dictionary words; digraphs are tried first.
    _GRAPHEME_TO_PHONEMES = {
        "ch": ["tʃ"], "sh": ["ʃ"], "th": ["θ"], "ph": ["f"],
        "ck": ["k"], "ng": ["ŋ"], "qu": ["k", "w"],
        "a": ["æ"], "e": ["ɛ"], "i": ["ɪ"], "o": ["ʌ"], "u": ["ʌ"],
        "b": ["b"], "c": ["k"], "d": ["d"], "f": ["f"], "g": ["ɡ"],
        "h": ["h"], "j": ["dʒ"], "k": ["k"], "l": ["l"], "m": ["m"],
        "n": ["n"], "p": ["p"], "r": ["r"], "s": ["s"], "t": ["t"],
        "v": ["v"], "w": ["w"], "x": ["k", "s"], "y": ["j"], "z": ["z"],
    }

    def __init__(self):
        """Attach the CMU dictionary, or an empty dict when it is unavailable."""
        if SimpleG2P._shared_cmu_dict is None:
            try:
                SimpleG2P._shared_cmu_dict = cmudict.dict()
            except Exception:
                # Not cached, so a later instance can retry the load.
                self.cmu_dict = {}
                print("Warning: CMU dictionary not available")
                return
        self.cmu_dict = SimpleG2P._shared_cmu_dict

    def text_to_phonemes(self, text: str) -> List[Dict]:
        """Return one dict per word with ``word``, ``phonemes``, ``ipa`` and ``phoneme_string``."""
        phoneme_sequence = []
        for word in self._clean_text(text).split():
            word_phonemes = self._get_word_phonemes(word)
            phoneme_sequence.append({
                "word": word,
                "phonemes": word_phonemes,
                "ipa": self._get_ipa(word),
                "phoneme_string": " ".join(word_phonemes),
            })
        return phoneme_sequence

    def get_reference_phoneme_string(self, text: str) -> str:
        """Return all phonemes of ``text`` as one space-separated string."""
        return " ".join(
            phoneme
            for word_data in self.text_to_phonemes(text)
            for phoneme in word_data["phonemes"]
        )

    def _clean_text(self, text: str) -> str:
        """Replace punctuation (except apostrophes) with spaces, collapse whitespace, lowercase."""
        text = re.sub(r"[^\w\s\']", " ", text)
        text = re.sub(r"\s+", " ", text)
        return text.lower().strip()

    def _get_word_phonemes(self, word: str) -> List[str]:
        """Return phonemes from the dictionary when present, otherwise an estimate."""
        word_lower = word.lower()
        if word_lower in self.cmu_dict:
            # First listed pronunciation, with stress markers removed.
            clean_phonemes = [
                re.sub(r"[0-9]", "", p) for p in self.cmu_dict[word_lower][0]
            ]
            return self._convert_to_wav2vec_format(clean_phonemes)
        return self._estimate_phonemes(word)

    def _convert_to_wav2vec_format(self, cmu_phonemes: List[str]) -> List[str]:
        """Map CMU symbols to the target phoneme set; unknown symbols are lowercased."""
        table = self._CMU_TO_ESPEAK
        return [table.get(phoneme, phoneme.lower()) for phoneme in cmu_phonemes]

    def _get_ipa(self, word: str) -> str:
        """Return the IPA transcription, or ``/word/`` when conversion fails."""
        try:
            return ipa.convert(word)
        except Exception:
            return f"/{word}/"

    def _estimate_phonemes(self, word: str) -> List[str]:
        """Estimate phonemes for an unknown word from its spelling.

        Scans left to right, preferring two-letter combinations over single
        letters; characters with no rule contribute nothing.
        """
        table = self._GRAPHEME_TO_PHONEMES
        word = word.lower()
        phonemes: List[str] = []
        i = 0
        while i < len(word):
            two_char = word[i:i + 2]
            if len(two_char) == 2 and two_char in table:
                phonemes.extend(table[two_char])
                i += 2
                continue
            phonemes.extend(table.get(word[i], ()))
            i += 1
        return phonemes
| # ============================================================================= | |
| # PHONEME COMPARATOR | |
| # ============================================================================= | |
class PhonemeComparator:
    """Position-by-position comparison of reference and learner phoneme strings."""

    def __init__(self):
        # Substitutions commonly made by Vietnamese speakers (reference -> accepted learner sounds).
        self.substitution_patterns = {
            "θ": ["f", "s", "t"],  # TH → F, S, T
            "ð": ["d", "z", "v"],  # DH → D, Z, V
            "v": ["w", "f"],       # V → W, F
            "r": ["l"],            # R → L
            "l": ["r"],            # L → R
            "z": ["s"],            # Z → S
            "ʒ": ["ʃ", "z"],       # ZH → SH, Z
            "ŋ": ["n"],            # NG → N
        }
        # Difficulty weights for Vietnamese speakers, keyed by phoneme.
        self.difficulty_map = {
            "θ": 0.9,   # th (think)
            "ð": 0.9,   # th (this)
            "v": 0.8,   # v
            "z": 0.8,   # z
            "ʒ": 0.9,   # zh (measure)
            "r": 0.7,   # r
            "l": 0.6,   # l
            "w": 0.5,   # w
            "f": 0.4,   # f
            "s": 0.3,   # s
            "ʃ": 0.5,   # sh
            "tʃ": 0.4,  # ch
            "dʒ": 0.5,  # j
            "ŋ": 0.3,   # ng
        }

    def compare_phoneme_sequences(self, reference_phonemes: str,
                                  learner_phonemes: str) -> List[Dict]:
        """Compare two space-separated phoneme strings index by index.

        The shorter sequence is padded with empty strings, so trailing
        reference phonemes are reported as "missing" and trailing learner
        phonemes as "extra". Returns one dict per position.
        """
        expected_seq = reference_phonemes.split()
        spoken_seq = learner_phonemes.split()
        print(f"Reference phonemes: {expected_seq}")
        print(f"Learner phonemes: {spoken_seq}")

        span = max(len(expected_seq), len(spoken_seq))
        expected_padded = expected_seq + [""] * (span - len(expected_seq))
        spoken_padded = spoken_seq + [""] * (span - len(spoken_seq))

        report = []
        for slot, (expected, spoken) in enumerate(zip(expected_padded, spoken_padded)):
            verdict, points = self._grade(expected, spoken)
            report.append({
                "position": slot,
                "reference_phoneme": expected,
                "learner_phoneme": spoken,
                "status": verdict,
                "score": points,
                "difficulty": self.difficulty_map.get(expected, 0.3),
            })
        return report

    def _grade(self, expected: str, spoken: str):
        """Return ``(status, score)`` for one aligned position.

        At most one of the two arguments is empty, because padding is applied
        to the shorter sequence only.
        """
        if not spoken:
            return "missing", 0.0
        if not expected:
            return "extra", 0.0
        if expected == spoken:
            return "correct", 1.0
        if self._is_acceptable_substitution(expected, spoken):
            return "acceptable", 0.7
        return "wrong", 0.2

    def _is_acceptable_substitution(self, reference: str, learner: str) -> bool:
        """Tell whether ``learner`` is a listed substitution for ``reference``."""
        return learner in self.substitution_patterns.get(reference, [])
| # ============================================================================= | |
| # WORD ANALYZER | |
| # ============================================================================= | |
class WordAnalyzer:
    """Analyze word-level pronunciation accuracy using character-based ASR output."""

    # (minimum score, status label, highlight color), checked from best to worst.
    _SCORE_BANDS = (
        (0.8, "excellent", "#22c55e"),       # Green
        (0.6, "good", "#84cc16"),            # Light green
        (0.4, "needs_practice", "#eab308"),  # Yellow
    )
    _LOWEST_BAND = ("poor", "#ef4444")       # Red

    # Words scoring below this are reported as wrongly pronounced.
    _WRONG_WORD_THRESHOLD = 0.6

    # Tips (in Vietnamese) for sounds that are hard for Vietnamese speakers.
    _VIETNAMESE_TIPS = {
        "θ": "Đặt lưỡi giữa răng trên và dưới, thổi nhẹ (think, three)",
        "ð": "Giống θ nhưng rung dây thanh âm (this, that)",
        "v": "Chạm môi dưới vào răng trên, không dùng cả hai môi như tiếng Việt",
        "r": "Cuộn lưỡi nhưng không chạm vào vòm miệng, không lăn lưỡi",
        "l": "Đầu lưỡi chạm vào vòm miệng sau răng",
        "z": "Giống âm 's' nhưng có rung dây thanh âm",
        "ʒ": "Giống âm 'ʃ' (sh) nhưng có rung dây thanh âm",
        "w": "Tròn môi như âm 'u', không dùng răng như âm 'v'",
    }

    def __init__(self):
        self.g2p = SimpleG2P()
        self.comparator = PhonemeComparator()

    def analyze_words(self, reference_text: str, learner_phonemes: str) -> Dict:
        """Compare learner phonemes with the reference text.

        Returns a dict with ``word_highlights``, ``phoneme_differences`` and
        ``wrong_words``.
        """
        reference_words = self.g2p.text_to_phonemes(reference_text)
        reference_phoneme_string = self.g2p.get_reference_phoneme_string(reference_text)
        phoneme_comparisons = self.comparator.compare_phoneme_sequences(
            reference_phoneme_string, learner_phonemes
        )
        word_highlights = self._create_word_highlights(reference_words, phoneme_comparisons)
        wrong_words = self._identify_wrong_words(word_highlights, phoneme_comparisons)
        return {
            "word_highlights": word_highlights,
            "phoneme_differences": phoneme_comparisons,
            "wrong_words": wrong_words,
        }

    def _create_word_highlights(self, reference_words: List[Dict],
                                phoneme_comparisons: List[Dict]) -> List[Dict]:
        """Build one highlight dict per reference word.

        Comparisons are consumed in order: each word takes as many entries as
        it has reference phonemes. A word with no available scores gets 0.0.
        """
        word_highlights = []
        phoneme_index = 0
        for word_data in reference_words:
            word_phonemes = word_data["phonemes"]
            num_phonemes = len(word_phonemes)

            # Slicing clips at the end of the list, so short inputs are safe.
            word_phoneme_scores = [
                comparison["score"]
                for comparison in phoneme_comparisons[phoneme_index:phoneme_index + num_phonemes]
            ]
            word_score = np.mean(word_phoneme_scores) if word_phoneme_scores else 0.0

            word_highlights.append({
                "word": word_data["word"],
                "score": float(word_score),
                "status": self._get_word_status(word_score),
                "color": self._get_word_color(word_score),
                "phonemes": word_phonemes,
                "ipa": word_data["ipa"],
                "phoneme_scores": word_phoneme_scores,
                "phoneme_start_index": phoneme_index,
                "phoneme_end_index": phoneme_index + num_phonemes - 1,
            })
            phoneme_index += num_phonemes
        return word_highlights

    def _identify_wrong_words(self, word_highlights: List[Dict],
                              phoneme_comparisons: List[Dict]) -> List[Dict]:
        """Return details for every word scoring below the wrong-word threshold."""
        wrong_words = []
        for word_highlight in word_highlights:
            if not word_highlight["score"] < self._WRONG_WORD_THRESHOLD:
                continue

            start_idx = word_highlight["phoneme_start_index"]
            end_idx = word_highlight["phoneme_end_index"]

            wrong_phonemes = []
            missing_phonemes = []
            for comparison in phoneme_comparisons[start_idx:end_idx + 1]:
                if comparison["status"] == "wrong":
                    wrong_phonemes.append({
                        "expected": comparison["reference_phoneme"],
                        "actual": comparison["learner_phoneme"],
                        "difficulty": comparison["difficulty"],
                    })
                elif comparison["status"] == "missing":
                    missing_phonemes.append({
                        "phoneme": comparison["reference_phoneme"],
                        "difficulty": comparison["difficulty"],
                    })

            wrong_words.append({
                "word": word_highlight["word"],
                "score": word_highlight["score"],
                "expected_phonemes": word_highlight["phonemes"],
                "ipa": word_highlight["ipa"],
                "wrong_phonemes": wrong_phonemes,
                "missing_phonemes": missing_phonemes,
                "tips": self._get_vietnamese_tips(wrong_phonemes, missing_phonemes),
            })
        return wrong_words

    def _band_for(self, score: float):
        """Return the ``(status, color)`` pair for ``score``."""
        for minimum, status, color in self._SCORE_BANDS:
            if score >= minimum:
                return status, color
        return self._LOWEST_BAND

    def _get_word_status(self, score: float) -> str:
        """Get word status from score."""
        return self._band_for(score)[0]

    def _get_word_color(self, score: float) -> str:
        """Get color for word highlighting."""
        return self._band_for(score)[1]

    def _get_vietnamese_tips(self, wrong_phonemes: List[Dict],
                             missing_phonemes: List[Dict]) -> List[str]:
        """Get Vietnamese-specific pronunciation tips.

        Wrong phonemes come first, then missing ones. Identical tips are
        emitted once, so a sound that fails several times in a word does not
        repeat the same advice.
        """
        known = self._VIETNAMESE_TIPS
        candidates = []
        for wrong in wrong_phonemes:
            expected = wrong["expected"]
            actual = wrong["actual"]
            if expected in known:
                candidates.append(f"Âm '{expected}': {known[expected]}")
            else:
                candidates.append(f"Luyện âm '{expected}' thay vì '{actual}'")
        for missing in missing_phonemes:
            phoneme = missing["phoneme"]
            if phoneme in known:
                candidates.append(f"Thiếu âm '{phoneme}': {known[phoneme]}")

        # De-duplicate while keeping first-seen order.
        return list(dict.fromkeys(candidates))
| # ============================================================================= | |
| # FEEDBACK GENERATOR | |
| # ============================================================================= | |
class SimpleFeedbackGenerator:
    """Builds a short list of Vietnamese feedback sentences for the learner."""

    def generate_feedback(self, overall_score: float, wrong_words: List[Dict],
                          phoneme_comparisons: List[Dict]) -> List[str]:
        """Return feedback lines: overall verdict, words to practise, hardest sound.

        The overall verdict is always present. The word line appears only when
        there are wrong words, and the sound line only when the most frequent
        wrong/missing reference phoneme has a tip.
        """
        # 1) Overall verdict, picked by the first threshold the score reaches.
        verdict = "Hãy luyện tập chậm và rõ ràng hơn."
        for minimum, message in (
            (0.8, "Phát âm rất tốt! Bạn đã làm xuất sắc."),
            (0.6, "Phát âm khá tốt, còn một vài điểm cần cải thiện."),
            (0.4, "Cần luyện tập thêm. Tập trung vào những từ được đánh dấu đỏ."),
        ):
            if overall_score >= minimum:
                verdict = message
                break
        lines = [verdict]

        # 2) Words to practise: named when there are few, counted otherwise.
        if wrong_words:
            if len(wrong_words) > 3:
                lines.append(f"Có {len(wrong_words)} từ cần luyện tập. Tập trung vào từng từ một.")
            else:
                names = ', '.join(entry["word"] for entry in wrong_words)
                lines.append(f"Các từ cần luyện tập: {names}")

        # 3) Most frequent problematic reference phoneme (first seen wins ties).
        tally = {}
        for item in phoneme_comparisons:
            if item["status"] in ["wrong", "missing"]:
                sound = item["reference_phoneme"]
                tally[sound] = tally.get(sound, 0) + 1

        if tally:
            hardest = max(tally, key=tally.get)
            short_tips = {
                "θ": "Lưỡi giữa răng, thổi nhẹ",
                "ð": "Lưỡi giữa răng, rung dây thanh",
                "v": "Môi dưới chạm răng trên",
                "r": "Cuộn lưỡi, không chạm vòm miệng",
                "l": "Lưỡi chạm vòm miệng",
                "z": "Như 's' nhưng rung dây thanh",
            }
            if hardest in short_tips:
                lines.append(f"Âm khó nhất '{hardest}': {short_tips[hardest]}")

        return lines
| # ============================================================================= | |
| # MAIN PRONUNCIATION ASSESSOR | |
| # ============================================================================= | |
class SimplePronunciationAssessor:
    """Runs the full pipeline: transcription, word analysis, scoring, feedback."""

    def __init__(self):
        print("Initializing Simple Pronunciation Assessor...")
        self.asr = Wav2Vec2CharacterASR()  # character-based ASR
        self.word_analyzer = WordAnalyzer()
        self.feedback_generator = SimpleFeedbackGenerator()
        print("Initialization completed")

    def assess_pronunciation(self, audio_path: str, reference_text: str) -> Dict:
        """Assess the recording at ``audio_path`` against ``reference_text``.

        Returns a dict holding the transcript (and its phonemes), the overall
        score, word highlights, phoneme differences, wrong words, feedback
        lines and a ``processing_info`` dict.
        """
        print("Starting pronunciation assessment...")

        # Transcribe without any language model.
        print("Step 1: Transcribing to characters...")
        recognition = self.asr.transcribe_to_characters(audio_path)
        spoken_text = recognition["character_transcript"]
        spoken_phonemes = recognition["phoneme_representation"]
        print(f"Character transcript: {spoken_text}")
        print(f"Phoneme representation: {spoken_phonemes}")

        # Word-level analysis on the phoneme representation.
        print("Step 2: Analyzing words...")
        analysis = self.word_analyzer.analyze_words(reference_text, spoken_phonemes)
        differences = analysis["phoneme_differences"]
        mistakes = analysis["wrong_words"]

        # Overall score is derived from the per-phoneme comparisons.
        score = self._calculate_overall_score(differences)

        print("Step 3: Generating feedback...")
        advice = self.feedback_generator.generate_feedback(score, mistakes, differences)

        processing_info = {
            "model_used": f"Wav2Vec2-Character ({self.asr.model_name})",
            "character_based": True,
            "language_model_correction": False,
            "raw_output": True,
        }
        print("Assessment completed successfully")
        return {
            "transcript": spoken_text,  # what the user actually said
            "transcript_phonemes": spoken_phonemes,
            "user_phonemes": spoken_phonemes,  # alias for UI clarity
            "character_transcript": spoken_text,
            "overall_score": score,
            "word_highlights": analysis["word_highlights"],
            "phoneme_differences": differences,
            "wrong_words": mistakes,
            "feedback": advice,
            "processing_info": processing_info,
        }

    def _calculate_overall_score(self, phoneme_comparisons: List[Dict]) -> float:
        """Return the mean of the per-position scores, or 0.0 when there are none."""
        if phoneme_comparisons:
            points = sum(entry["score"] for entry in phoneme_comparisons)
            return points / len(phoneme_comparisons)
        return 0.0
| # ============================================================================= | |
| # API ENDPOINT | |
| # ============================================================================= | |
# Module-level assessor shared by the endpoints below. It is constructed at
# import time, which is when the ASR model gets loaded.
assessor = SimplePronunciationAssessor()
def convert_numpy_types(obj):
    """Recursively convert numpy values to Python native types.

    Handles numpy booleans, integers, floats and arrays, and walks through
    dicts, lists and tuples (container types are preserved). Anything else is
    returned unchanged.
    """
    if isinstance(obj, np.bool_):
        # np.bool_ is neither np.integer nor a Python bool, so it needs its own case.
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    return obj
async def assess_pronunciation(
    audio: UploadFile = File(..., description="Audio file (.wav, .mp3, .m4a)"),
    reference_text: str = Form(..., description="Reference text to pronounce")
):
    """
    Pronunciation Assessment API using Wav2Vec2 Character-level Model

    Key Features:
    - Uses the configured Wav2Vec2 character-level model for transcription
    - NO language model correction (shows actual pronunciation errors)
    - Character-level accuracy converted to phoneme representation
    - Vietnamese-optimized feedback and tips

    Input: Audio file + Reference text
    Output: Word highlights + Phoneme differences + Wrong words

    Raises HTTPException 400 for invalid reference text and 500 when the
    assessment itself fails.
    """
    import time
    import traceback

    start_time = time.time()

    # Validate inputs
    if not reference_text.strip():
        raise HTTPException(status_code=400, detail="Reference text cannot be empty")
    if len(reference_text) > 500:
        raise HTTPException(status_code=400, detail="Reference text too long (max 500 characters)")
    # Check for valid English characters
    if not re.match(r"^[a-zA-Z\s\'\-\.!?,;:]+$", reference_text):
        raise HTTPException(
            status_code=400,
            detail="Text must contain only English letters, spaces, and basic punctuation"
        )

    tmp_path = None
    try:
        # The filename comes from the client, so only a short alphanumeric
        # extension is accepted as the temp-file suffix (no path separators).
        file_extension = ".wav"
        if audio.filename and "." in audio.filename:
            candidate = audio.filename.rsplit(".", 1)[-1]
            if re.fullmatch(r"[A-Za-z0-9]{1,10}", candidate):
                file_extension = f".{candidate}"

        content = await audio.read()

        # Write and close the file before handing its path to the assessor, so
        # the audio loader can reopen it on every platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(content)

        print(f"Processing audio file: {tmp_path}")

        # Run assessment using Wav2Vec2 Character model
        result = assessor.assess_pronunciation(tmp_path, reference_text)

        # Add processing time
        processing_time = time.time() - start_time
        result["processing_info"]["processing_time"] = processing_time

        # Convert numpy types for JSON serialization
        final_result = convert_numpy_types(result)
        print(f"Assessment completed in {processing_time:.2f} seconds")
        return PronunciationAssessmentResult(**final_result)
    except Exception as e:
        print(f"Assessment error: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Assessment failed: {str(e)}") from e
    finally:
        # delete=False means the upload would otherwise stay on disk forever.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                # Cleanup is best-effort and must not mask the real outcome.
                print(f"Warning: could not remove temp file {tmp_path}: {cleanup_error}")
| # ============================================================================= | |
| # UTILITY ENDPOINTS | |
| # ============================================================================= | |
async def get_word_phonemes(word: str):
    """Get phoneme breakdown for a specific word

    Returns the phonemes, IPA string and a difficulty estimate for Vietnamese
    speakers. When ``word`` contains several words only the first one is
    analyzed. Raises HTTPException 400 when the input yields no word at all
    and 500 on unexpected failures.
    """
    try:
        # Reuse the converter/comparator owned by the module-level assessor
        # instead of building new ones for every request.
        g2p = assessor.word_analyzer.g2p
        comparator = assessor.word_analyzer.comparator

        entries = g2p.text_to_phonemes(word)
        if not entries:
            # e.g. empty or punctuation-only input: a client error, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="Word must contain at least one letter or digit"
            )
        phoneme_data = entries[0]
        phonemes = phoneme_data["phonemes"]

        # Difficulty analysis for Vietnamese speakers (0.3 for unlisted phonemes).
        difficulty_scores = [comparator.difficulty_map.get(p, 0.3) for p in phonemes]
        avg_difficulty = float(np.mean(difficulty_scores)) if difficulty_scores else 0.3

        return {
            "word": word,
            "phonemes": phonemes,
            "phoneme_string": phoneme_data["phoneme_string"],
            "ipa": phoneme_data["ipa"],
            "difficulty_score": avg_difficulty,
            "difficulty_level": "hard" if avg_difficulty > 0.6 else "medium" if avg_difficulty > 0.4 else "easy",
            "challenging_phonemes": [
                {
                    "phoneme": p,
                    "difficulty": difficulty,
                    "vietnamese_tip": get_vietnamese_tip(p)
                }
                for p, difficulty in zip(phonemes, difficulty_scores)
                if difficulty > 0.6
            ]
        }
    except HTTPException:
        # Keep deliberate HTTP errors (the 400 above) from being rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Word analysis error: {str(e)}") from e
async def health_check():
    """Health check endpoint"""
    # Reading the model name from the shared assessor is the actual check:
    # any failure while doing so is reported as an "error" status payload.
    try:
        return {
            "status": "healthy",
            "model": assessor.asr.model_name,
            "character_based": True,
            "language_model_correction": False,
            "vietnamese_optimized": True,
        }
    except Exception as problem:
        failure = {"status": "error"}
        failure["error"] = str(problem)
        return failure
async def test_model():
    """Test if Wav2Vec2 model is working"""
    # Only attributes of the already-constructed assessor are read here; no
    # inference is run. The sample strings are fixed illustrative values.
    try:
        asr = assessor.asr
        return {
            "model_loaded": True,
            "model_name": asr.model_name,
            "processor_ready": True,
            "sample_rate": asr.sample_rate,
            "sample_characters": "this is a test",
            "sample_phonemes": "ðɪs ɪz ə tɛst",
        }
    except Exception as problem:
        return {"model_loaded": False, "error": str(problem)}
| # ============================================================================= | |
| # HELPER FUNCTIONS | |
| # ============================================================================= | |
def get_vietnamese_tip(phoneme: str) -> str:
    """Return a short Vietnamese articulation tip for ``phoneme``.

    Phonemes without a dedicated tip get a generic "practise this sound"
    message that embeds the phoneme itself.
    """
    generic_tip = f"Luyện âm {phoneme}"
    known_tips = {
        "θ": "Đặt lưỡi giữa răng, thổi nhẹ",
        "ð": "Giống θ nhưng rung dây thanh âm",
        "v": "Môi dưới chạm răng trên",
        "r": "Cuộn lưỡi, không chạm vòm miệng",
        "l": "Lưỡi chạm vòm miệng sau răng",
        "z": "Như 's' nhưng rung dây thanh",
        "ʒ": "Như 'ʃ' nhưng rung dây thanh",
        "w": "Tròn môi như 'u'",
    }
    if phoneme in known_tips:
        return known_tips[phoneme]
    return generic_tip