ABAO77 committed on
Commit cef1d4a · 1 Parent(s): 1a5420f

feat: implement Wav2Vec2 character-level ASR with ONNX and Transformers support, add phoneme comparison and feedback generation
src/apis/controllers/speaking_controller.py CHANGED
@@ -121,58 +121,91 @@ class WhisperASR:
         }
 
 
-class Wav2Vec2CharacterASRONNX:
-    """Wav2Vec2 character-level ASR with ONNX runtime - no language model correction"""
+class Wav2Vec2CharacterASR:
+    """Wav2Vec2 character-level ASR with support for both ONNX and Transformers inference"""
 
     def __init__(
         self,
+        model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
+        onnx: bool = False,
         onnx_model_path: str = "./wav2vec2_asr.onnx",
-        processor_name: str = "facebook/wav2vec2-base-960h",
     ):
         """
-        Initialize Wav2Vec2 ONNX character-level model
-        Automatically creates ONNX model if it doesn't exist
+        Initialize Wav2Vec2 character-level model
 
         Args:
-            onnx_model_path: Path to the ONNX model file
-            processor_name: HuggingFace model name for the processor
+            model_name: HuggingFace model name
+            onnx: If True, use ONNX runtime for inference. If False, use Transformers
+            onnx_model_path: Path to the ONNX model file (only used if onnx=True)
         """
-        print(f"Loading Wav2Vec2 ONNX model from: {onnx_model_path}")
-        print(f"Loading processor: {processor_name}")
+        self.model_name = model_name
+        self.use_onnx = onnx
+        self.onnx_model_path = onnx_model_path
+        self.sample_rate = 16000
+
+        print(f"Loading Wav2Vec2 character model: {model_name}")
+        print(f"Using {'ONNX' if onnx else 'Transformers'} for inference")
+
+        if self.use_onnx:
+            self._init_onnx_model()
+        else:
+            self._init_transformers_model()
 
+    def _init_onnx_model(self):
+        """Initialize ONNX model and processor"""
         # Check if ONNX model exists, if not create it
-        if not os.path.exists(onnx_model_path):
-            print(f"ONNX model not found at {onnx_model_path}. Creating it...")
-            self._create_onnx_model(onnx_model_path, processor_name)
+        if not os.path.exists(self.onnx_model_path):
+            print(f"ONNX model not found at {self.onnx_model_path}. Creating it...")
+            self._create_onnx_model()
 
         try:
             # Load ONNX model
-            self.session = onnxruntime.InferenceSession(onnx_model_path)
+            self.session = onnxruntime.InferenceSession(self.onnx_model_path)
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
 
             # Load processor
-            self.processor = Wav2Vec2Processor.from_pretrained(processor_name)
+            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
 
             print("ONNX Wav2Vec2 character model loaded successfully")
-            self.model_name = processor_name
-            self.onnx_path = onnx_model_path
-            self.sample_rate = 16000
 
         except Exception as e:
             print(f"Error loading ONNX model: {e}")
             raise
 
-    def _create_onnx_model(self, onnx_model_path: str, processor_name: str):
+    def _init_transformers_model(self):
+        """Initialize Transformers model and processor"""
+        try:
+            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
+            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
+            self.model.eval()
+            print("Wav2Vec2 character model loaded successfully")
+        except Exception as e:
+            print(f"Error loading model {self.model_name}: {e}")
+            # Fallback to base model
+            fallback_model = "facebook/wav2vec2-base-960h"
+            print(f"Trying fallback model: {fallback_model}")
+            try:
+                self.processor = Wav2Vec2Processor.from_pretrained(fallback_model)
+                self.model = Wav2Vec2ForCTC.from_pretrained(fallback_model)
+                self.model.eval()
+                self.model_name = fallback_model
+                print("Fallback model loaded successfully")
+            except Exception as e2:
+                raise Exception(
+                    f"Failed to load both models. Original error: {e}, Fallback error: {e2}"
+                )
+
+    def _create_onnx_model(self):
         """Create ONNX model if it doesn't exist"""
         try:
             # Import the converter from model_convert
             from src.model_convert.wav2vec2onnx import Wav2Vec2ONNXConverter
 
             print("Creating new ONNX model...")
-            converter = Wav2Vec2ONNXConverter(processor_name)
+            converter = Wav2Vec2ONNXConverter(self.model_name)
             created_path = converter.convert_to_onnx(
-                onnx_path=onnx_model_path,
+                onnx_path=self.onnx_model_path,
                 input_length=160000,  # 10 seconds
                 opset_version=14,
             )
@@ -184,9 +217,16 @@ class Wav2Vec2CharacterASRONNX:
 
     def transcribe_to_characters(self, audio_path: str) -> Dict:
         """
-        Transcribe audio directly to characters using ONNX model (no language model correction)
+        Transcribe audio directly to characters (no language model correction)
         Returns raw character sequence as produced by the model
         """
+        if self.use_onnx:
+            return self._transcribe_onnx(audio_path)
+        else:
+            return self._transcribe_transformers(audio_path)
+
+    def _transcribe_onnx(self, audio_path: str) -> Dict:
+        """Transcribe using ONNX runtime"""
         try:
             # Load audio
             start_time = time.time()
@@ -233,14 +273,56 @@ class Wav2Vec2CharacterASRONNX:
             }
 
         except Exception as e:
-            print(f"Transcription error: {e}")
+            print(f"ONNX transcription error: {e}")
+            return self._empty_result()
+
+    def _transcribe_transformers(self, audio_path: str) -> Dict:
+        """Transcribe using Transformers"""
+        try:
+            # Load audio
+            start_time = time.time()
+            speech, sr = librosa.load(audio_path, sr=self.sample_rate)
+
+            # Prepare input
+            input_values = self.processor(
+                speech, sampling_rate=self.sample_rate, return_tensors="pt"
+            ).input_values
+
+            # Get model predictions (no language model involved)
+            with torch.no_grad():
+                logits = self.model(input_values).logits
+                predicted_ids = torch.argmax(logits, dim=-1)
+
+            # Decode to characters directly
+            character_transcript = self.processor.batch_decode(predicted_ids)[0]
+
+            # Clean up character transcript
+            character_transcript = self._clean_character_transcript(
+                character_transcript
+            )
+
+            # Convert characters to phoneme-like representation
+            phoneme_like_transcript = self._characters_to_phoneme_representation(
+                character_transcript
+            )
+
+            logger.info(
+                f"Transformers transcription time: {time.time() - start_time:.2f}s"
+            )
+
             return {
-                "character_transcript": "",
-                "phoneme_representation": "",
-                "raw_predicted_ids": [],
-                "confidence_scores": [],
+                "character_transcript": character_transcript,
+                "phoneme_representation": phoneme_like_transcript,
+                "raw_predicted_ids": predicted_ids[0].tolist(),
+                "confidence_scores": torch.softmax(logits, dim=-1)
+                .max(dim=-1)[0][0]
+                .tolist()[:100],  # Limit for JSON
             }
 
+        except Exception as e:
+            print(f"Transformers transcription error: {e}")
+            return self._empty_result()
+
     def _calculate_confidence_scores(self, logits: np.ndarray) -> List[float]:
         """Calculate confidence scores from logits using numpy"""
         # Apply softmax
@@ -257,27 +339,23 @@ class Wav2Vec2CharacterASRONNX:
         logger.info(f"Raw transcript before cleaning: {transcript}")
         cleaned = re.sub(r"\s+", " ", transcript)
         cleaned = cleaned.strip().lower()
-
         return cleaned
 
     def _characters_to_phoneme_representation(self, text: str) -> str:
         """Convert character-based transcript to phoneme-like representation for comparison"""
-        # This is a simple character-to-phoneme mapping for pronunciation comparison
-        # The idea is to convert the raw character output to something comparable with reference phonemes
-
         if not text:
             return ""
 
         words = text.split()
         phoneme_words = []
-
-        # Use our G2P to convert transcript words to phonemes
         g2p = SimpleG2P()
-
         for word in words:
             try:
-                word_data = g2p.text_to_phonemes(word)[0]
-                phoneme_words.extend(word_data["phonemes"])
+                if g2p:
+                    word_data = g2p.text_to_phonemes(word)[0]
+                    phoneme_words.extend(word_data["phonemes"])
+                else:
+                    phoneme_words.extend(self._simple_letter_to_phoneme(word))
             except:
                 # Fallback: simple letter-to-sound mapping
                 phoneme_words.extend(self._simple_letter_to_phoneme(word))
@@ -322,17 +400,35 @@ class Wav2Vec2CharacterASRONNX:
 
         return phonemes
 
-    def get_model_info(self) -> Dict:
-        """Get information about the loaded ONNX model"""
+    def _empty_result(self) -> Dict:
+        """Return empty result structure"""
         return {
-            "onnx_model_path": self.onnx_path,
-            "processor_name": self.model_name,
-            "input_name": self.input_name,
-            "output_name": self.output_name,
+            "character_transcript": "",
+            "phoneme_representation": "",
+            "raw_predicted_ids": [],
+            "confidence_scores": [],
+        }
+
+    def get_model_info(self) -> Dict:
+        """Get information about the loaded model"""
+        info = {
+            "model_name": self.model_name,
             "sample_rate": self.sample_rate,
-            "session_providers": self.session.get_providers(),
+            "inference_method": "ONNX" if self.use_onnx else "Transformers",
        }
 
+        if self.use_onnx:
+            info.update(
+                {
+                    "onnx_model_path": self.onnx_model_path,
+                    "input_name": self.input_name,
+                    "output_name": self.output_name,
+                    "session_providers": self.session.get_providers(),
+                }
+            )
+
+        return info
+
 
 class SimpleG2P:
     """Simple Grapheme-to-Phoneme converter for reference text"""
@@ -866,7 +962,7 @@ class SimplePronunciationAssessor:
 
     def __init__(self):
         print("Initializing Simple Pronunciation Assessor...")
-        self.wav2vec2_asr = Wav2Vec2CharacterASRONNX()  # Advanced mode
+        self.wav2vec2_asr = Wav2Vec2CharacterASR()  # Advanced mode
         self.whisper_asr = WhisperASR()  # Normal mode
         self.word_analyzer = WordAnalyzer()
         self.feedback_generator = SimpleFeedbackGenerator()
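
For context, a minimal usage sketch of the refactored class; "sample.wav" is a placeholder path, not part of the commit:

    # Transformers backend by default; pass onnx=True for the ONNX runtime path
    asr = Wav2Vec2CharacterASR()
    result = asr.transcribe_to_characters("sample.wav")  # placeholder audio file
    print(result["character_transcript"])    # raw CTC character output
    print(result["phoneme_representation"])  # phoneme-like string for comparison
    print(asr.get_model_info())              # reports model_name, inference_method, ...
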
src/utils/speaking_utils.py ADDED
@@ -0,0 +1,556 @@
+from typing import List, Dict
+import numpy as np
+import nltk
+import eng_to_ipa as ipa
+import re
+from collections import defaultdict
+
+
+try:
+    nltk.download("cmudict", quiet=True)
+    from nltk.corpus import cmudict
+except:
+    print("Warning: NLTK data not available")
+
+
+class SimpleG2P:
+    """Simple Grapheme-to-Phoneme converter for reference text"""
+
+    def __init__(self):
+        try:
+            self.cmu_dict = cmudict.dict()
+        except:
+            self.cmu_dict = {}
+            print("Warning: CMU dictionary not available")
+
+    def text_to_phonemes(self, text: str) -> List[Dict]:
+        """Convert text to phoneme sequence"""
+        words = self._clean_text(text).split()
+        phoneme_sequence = []
+
+        for word in words:
+            word_phonemes = self._get_word_phonemes(word)
+            phoneme_sequence.append(
+                {
+                    "word": word,
+                    "phonemes": word_phonemes,
+                    "ipa": self._get_ipa(word),
+                    "phoneme_string": " ".join(word_phonemes),
+                }
+            )
+
+        return phoneme_sequence
+
+    def get_reference_phoneme_string(self, text: str) -> str:
+        """Get reference phoneme string for comparison"""
+        phoneme_sequence = self.text_to_phonemes(text)
+        all_phonemes = []
+
+        for word_data in phoneme_sequence:
+            all_phonemes.extend(word_data["phonemes"])
+
+        return " ".join(all_phonemes)
+
+    def _clean_text(self, text: str) -> str:
+        """Clean text for processing"""
+        text = re.sub(r"[^\w\s\']", " ", text)
+        text = re.sub(r"\s+", " ", text)
+        return text.lower().strip()
+
+    def _get_word_phonemes(self, word: str) -> List[str]:
+        """Get phonemes for a word"""
+        word_lower = word.lower()
+
+        if word_lower in self.cmu_dict:
+            # Remove stress markers and convert to Wav2Vec2 phoneme format
+            phonemes = self.cmu_dict[word_lower][0]
+            clean_phonemes = [re.sub(r"[0-9]", "", p) for p in phonemes]
+            return self._convert_to_wav2vec_format(clean_phonemes)
+        else:
+            return self._estimate_phonemes(word)
+
+    def _convert_to_wav2vec_format(self, cmu_phonemes: List[str]) -> List[str]:
+        """Convert CMU phonemes to Wav2Vec2 format"""
+        # Mapping from CMU to Wav2Vec2/eSpeak phonemes
+        cmu_to_espeak = {
+            "AA": "ɑ",
+            "AE": "æ",
+            "AH": "ʌ",
+            "AO": "ɔ",
+            "AW": "aʊ",
+            "AY": "aɪ",
+            "EH": "ɛ",
+            "ER": "ɝ",
+            "EY": "eɪ",
+            "IH": "ɪ",
+            "IY": "i",
+            "OW": "oʊ",
+            "OY": "ɔɪ",
+            "UH": "ʊ",
+            "UW": "u",
+            "B": "b",
+            "CH": "tʃ",
+            "D": "d",
+            "DH": "ð",
+            "F": "f",
+            "G": "ɡ",
+            "HH": "h",
+            "JH": "dʒ",
+            "K": "k",
+            "L": "l",
+            "M": "m",
+            "N": "n",
+            "NG": "ŋ",
+            "P": "p",
+            "R": "r",
+            "S": "s",
+            "SH": "ʃ",
+            "T": "t",
+            "TH": "θ",
+            "V": "v",
+            "W": "w",
+            "Y": "j",
+            "Z": "z",
+            "ZH": "ʒ",
+        }
+
+        converted = []
+        for phoneme in cmu_phonemes:
+            converted_phoneme = cmu_to_espeak.get(phoneme, phoneme.lower())
+            converted.append(converted_phoneme)
+
+        return converted
+
+    def _get_ipa(self, word: str) -> str:
+        """Get IPA transcription"""
+        try:
+            return ipa.convert(word)
+        except:
+            return f"/{word}/"
+
+    def _estimate_phonemes(self, word: str) -> List[str]:
+        """Estimate phonemes for unknown words"""
+        # Basic phoneme estimation with eSpeak-style output
+        phoneme_map = {
+            "ch": ["tʃ"],
+            "sh": ["ʃ"],
+            "th": ["θ"],
+            "ph": ["f"],
+            "ck": ["k"],
+            "ng": ["ŋ"],
+            "qu": ["k", "w"],
+            "a": ["æ"],
+            "e": ["ɛ"],
+            "i": ["ɪ"],
+            "o": ["ʌ"],
+            "u": ["ʌ"],
+            "b": ["b"],
+            "c": ["k"],
+            "d": ["d"],
+            "f": ["f"],
+            "g": ["ɡ"],
+            "h": ["h"],
+            "j": ["dʒ"],
+            "k": ["k"],
+            "l": ["l"],
+            "m": ["m"],
+            "n": ["n"],
+            "p": ["p"],
+            "r": ["r"],
+            "s": ["s"],
+            "t": ["t"],
+            "v": ["v"],
+            "w": ["w"],
+            "x": ["k", "s"],
+            "y": ["j"],
+            "z": ["z"],
+        }
+
+        word = word.lower()
+        phonemes = []
+        i = 0
+
+        while i < len(word):
+            # Check 2-letter combinations first
+            if i <= len(word) - 2:
+                two_char = word[i : i + 2]
+                if two_char in phoneme_map:
+                    phonemes.extend(phoneme_map[two_char])
+                    i += 2
+                    continue
+
+            # Single character
+            char = word[i]
+            if char in phoneme_map:
+                phonemes.extend(phoneme_map[char])
+
+            i += 1
+
+        return phonemes
+
+
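
A quick usage sketch for SimpleG2P (output shown is approximate and assumes nltk's cmudict and eng_to_ipa are available):

    g2p = SimpleG2P()
    print(g2p.get_reference_phoneme_string("think this"))  # e.g. "θ ɪ ŋ k ð ɪ s"
    print(g2p.text_to_phonemes("think")[0]["phonemes"])    # e.g. ['θ', 'ɪ', 'ŋ', 'k']
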
+class PhonemeComparator:
+    """Compare reference and learner phoneme sequences"""
+
+    def __init__(self):
+        # Vietnamese speakers' common phoneme substitutions
+        self.substitution_patterns = {
+            "θ": ["f", "s", "t"],  # TH → F, S, T
+            "ð": ["d", "z", "v"],  # DH → D, Z, V
+            "v": ["w", "f"],  # V → W, F
+            "r": ["l"],  # R → L
+            "l": ["r"],  # L → R
+            "z": ["s"],  # Z → S
+            "ʒ": ["ʃ", "z"],  # ZH → SH, Z
+            "ŋ": ["n"],  # NG → N
+        }
+
+        # Difficulty levels for Vietnamese speakers
+        self.difficulty_map = {
+            "θ": 0.9,  # th (think)
+            "ð": 0.9,  # th (this)
+            "v": 0.8,  # v
+            "z": 0.8,  # z
+            "ʒ": 0.9,  # zh (measure)
+            "r": 0.7,  # r
+            "l": 0.6,  # l
+            "w": 0.5,  # w
+            "f": 0.4,  # f
+            "s": 0.3,  # s
+            "ʃ": 0.5,  # sh
+            "tʃ": 0.4,  # ch
+            "dʒ": 0.5,  # j
+            "ŋ": 0.3,  # ng
+        }
+
+    def compare_phoneme_sequences(
+        self, reference_phonemes: str, learner_phonemes: str
+    ) -> List[Dict]:
+        """Compare reference and learner phoneme sequences"""
+
+        # Split phoneme strings
+        ref_phones = reference_phonemes.split()
+        learner_phones = learner_phonemes.split()
+
+        print(f"Reference phonemes: {ref_phones}")
+        print(f"Learner phonemes: {learner_phones}")
+
+        # Simple alignment comparison
+        comparisons = []
+        max_len = max(len(ref_phones), len(learner_phones))
+
+        for i in range(max_len):
+            ref_phoneme = ref_phones[i] if i < len(ref_phones) else ""
+            learner_phoneme = learner_phones[i] if i < len(learner_phones) else ""
+
+            if ref_phoneme and learner_phoneme:
+                # Both present - check accuracy
+                if ref_phoneme == learner_phoneme:
+                    status = "correct"
+                    score = 1.0
+                elif self._is_acceptable_substitution(ref_phoneme, learner_phoneme):
+                    status = "acceptable"
+                    score = 0.7
+                else:
+                    status = "wrong"
+                    score = 0.2
+
+            elif ref_phoneme and not learner_phoneme:
+                # Missing phoneme
+                status = "missing"
+                score = 0.0
+
+            elif learner_phoneme and not ref_phoneme:
+                # Extra phoneme
+                status = "extra"
+                score = 0.0
+            else:
+                continue
+
+            comparison = {
+                "position": i,
+                "reference_phoneme": ref_phoneme,
+                "learner_phoneme": learner_phoneme,
+                "status": status,
+                "score": score,
+                "difficulty": self.difficulty_map.get(ref_phoneme, 0.3),
+            }
+
+            comparisons.append(comparison)
+
+        return comparisons
+
+    def _is_acceptable_substitution(self, reference: str, learner: str) -> bool:
+        """Check if learner phoneme is acceptable substitution for Vietnamese speakers"""
+        acceptable = self.substitution_patterns.get(reference, [])
+        return learner in acceptable
+
+
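
A usage sketch for PhonemeComparator with a typical Vietnamese θ→t substitution; scores follow the tables above:

    comparator = PhonemeComparator()
    diffs = comparator.compare_phoneme_sequences("θ ɪ ŋ k", "t ɪ n k")
    for d in diffs:
        print(d["reference_phoneme"], d["learner_phoneme"], d["status"], d["score"])
    # θ t acceptable 0.7   (t is listed as an acceptable substitution for θ)
    # ɪ ɪ correct 1.0
    # ŋ n acceptable 0.7
    # k k correct 1.0
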
+# =============================================================================
+# WORD ANALYZER
+# =============================================================================
+
+
+class WordAnalyzer:
+    """Analyze word-level pronunciation accuracy using character-based ASR"""
+
+    def __init__(self):
+        self.g2p = SimpleG2P()
+        self.comparator = PhonemeComparator()
+
+    def analyze_words(self, reference_text: str, learner_phonemes: str) -> Dict:
+        """Analyze word-level pronunciation using phoneme representation from character ASR"""
+
+        # Get reference phonemes by word
+        reference_words = self.g2p.text_to_phonemes(reference_text)
+
+        # Get overall phoneme comparison
+        reference_phoneme_string = self.g2p.get_reference_phoneme_string(reference_text)
+        phoneme_comparisons = self.comparator.compare_phoneme_sequences(
+            reference_phoneme_string, learner_phonemes
+        )
+
+        # Map phonemes back to words
+        word_highlights = self._create_word_highlights(
+            reference_words, phoneme_comparisons
+        )
+
+        # Identify wrong words
+        wrong_words = self._identify_wrong_words(word_highlights, phoneme_comparisons)
+
+        return {
+            "word_highlights": word_highlights,
+            "phoneme_differences": phoneme_comparisons,
+            "wrong_words": wrong_words,
+        }
+
+    def _create_word_highlights(
+        self, reference_words: List[Dict], phoneme_comparisons: List[Dict]
+    ) -> List[Dict]:
+        """Create word highlighting data"""
+
+        word_highlights = []
+        phoneme_index = 0
+
+        for word_data in reference_words:
+            word = word_data["word"]
+            word_phonemes = word_data["phonemes"]
+            num_phonemes = len(word_phonemes)
+
+            # Get phoneme scores for this word
+            word_phoneme_scores = []
+            for j in range(num_phonemes):
+                if phoneme_index + j < len(phoneme_comparisons):
+                    comparison = phoneme_comparisons[phoneme_index + j]
+                    word_phoneme_scores.append(comparison["score"])
+
+            # Calculate word score
+            word_score = np.mean(word_phoneme_scores) if word_phoneme_scores else 0.0
+
+            # Create word highlight
+            highlight = {
+                "word": word,
+                "score": float(word_score),
+                "status": self._get_word_status(word_score),
+                "color": self._get_word_color(word_score),
+                "phonemes": word_phonemes,
+                "ipa": word_data["ipa"],
+                "phoneme_scores": word_phoneme_scores,
+                "phoneme_start_index": phoneme_index,
+                "phoneme_end_index": phoneme_index + num_phonemes - 1,
+            }
+
+            word_highlights.append(highlight)
+            phoneme_index += num_phonemes
+
+        return word_highlights
+
+    def _identify_wrong_words(
+        self, word_highlights: List[Dict], phoneme_comparisons: List[Dict]
+    ) -> List[Dict]:
+        """Identify words that were pronounced incorrectly"""
+
+        wrong_words = []
+
+        for word_highlight in word_highlights:
+            if word_highlight["score"] < 0.6:  # Threshold for wrong pronunciation
+
+                # Find specific phoneme errors for this word
+                start_idx = word_highlight["phoneme_start_index"]
+                end_idx = word_highlight["phoneme_end_index"]
+
+                wrong_phonemes = []
+                missing_phonemes = []
+
+                for i in range(start_idx, min(end_idx + 1, len(phoneme_comparisons))):
+                    comparison = phoneme_comparisons[i]
+
+                    if comparison["status"] == "wrong":
+                        wrong_phonemes.append(
+                            {
+                                "expected": comparison["reference_phoneme"],
+                                "actual": comparison["learner_phoneme"],
+                                "difficulty": comparison["difficulty"],
+                            }
+                        )
+                    elif comparison["status"] == "missing":
+                        missing_phonemes.append(
+                            {
+                                "phoneme": comparison["reference_phoneme"],
+                                "difficulty": comparison["difficulty"],
+                            }
+                        )
+
+                wrong_word = {
+                    "word": word_highlight["word"],
+                    "score": word_highlight["score"],
+                    "expected_phonemes": word_highlight["phonemes"],
+                    "ipa": word_highlight["ipa"],
+                    "wrong_phonemes": wrong_phonemes,
+                    "missing_phonemes": missing_phonemes,
+                    "tips": self._get_vietnamese_tips(wrong_phonemes, missing_phonemes),
+                }
+
+                wrong_words.append(wrong_word)
+
+        return wrong_words
+
+    def _get_word_status(self, score: float) -> str:
+        """Get word status from score"""
+        if score >= 0.8:
+            return "excellent"
+        elif score >= 0.6:
+            return "good"
+        elif score >= 0.4:
+            return "needs_practice"
+        else:
+            return "poor"
+
+    def _get_word_color(self, score: float) -> str:
+        """Get color for word highlighting"""
+        if score >= 0.8:
+            return "#22c55e"  # Green
+        elif score >= 0.6:
+            return "#84cc16"  # Light green
+        elif score >= 0.4:
+            return "#eab308"  # Yellow
+        else:
+            return "#ef4444"  # Red
+
+    def _get_vietnamese_tips(
+        self, wrong_phonemes: List[Dict], missing_phonemes: List[Dict]
+    ) -> List[str]:
+        """Get Vietnamese-specific pronunciation tips"""
+
+        tips = []
+
+        # Tips for specific Vietnamese pronunciation challenges
+        vietnamese_tips = {
+            "θ": "Đặt lưỡi giữa răng trên và dưới, thổi nhẹ (think, three)",
+            "ð": "Giống θ nhưng rung dây thanh âm (this, that)",
+            "v": "Chạm môi dưới vào răng trên, không dùng cả hai môi như tiếng Việt",
+            "r": "Cuộn lưỡi nhưng không chạm vào vòm miệng, không lăn lưỡi",
+            "l": "Đầu lưỡi chạm vào vòm miệng sau răng",
+            "z": "Giống âm 's' nhưng có rung dây thanh âm",
+            "ʒ": "Giống âm 'ʃ' (sh) nhưng có rung dây thanh âm",
+            "w": "Tròn môi như âm 'u', không dùng răng như âm 'v'",
+        }
+
+        # Add tips for wrong phonemes
+        for wrong in wrong_phonemes:
+            expected = wrong["expected"]
+            actual = wrong["actual"]
+
+            if expected in vietnamese_tips:
+                tips.append(f"Âm '{expected}': {vietnamese_tips[expected]}")
+            else:
+                tips.append(f"Luyện âm '{expected}' thay vì '{actual}'")
+
+        # Add tips for missing phonemes
+        for missing in missing_phonemes:
+            phoneme = missing["phoneme"]
+            if phoneme in vietnamese_tips:
+                tips.append(f"Thiếu âm '{phoneme}': {vietnamese_tips[phoneme]}")
+
+        return tips
+
+
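
A sketch of WordAnalyzer on the same example; with every phoneme correct or acceptable the word averages 0.85, so it is highlighted green and not flagged:

    analyzer = WordAnalyzer()
    analysis = analyzer.analyze_words("think", "t ɪ n k")
    h = analysis["word_highlights"][0]
    print(h["score"], h["status"], h["color"])  # 0.85 excellent #22c55e
    print(analysis["wrong_words"])              # [] (score is above the 0.6 threshold)
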
+class SimpleFeedbackGenerator:
+    """Generate simple, actionable feedback in Vietnamese"""
+
+    def generate_feedback(
+        self,
+        overall_score: float,
+        wrong_words: List[Dict],
+        phoneme_comparisons: List[Dict],
+    ) -> List[str]:
+        """Generate Vietnamese feedback"""
+
+        feedback = []
+
+        # Overall feedback in Vietnamese
+        if overall_score >= 0.8:
+            feedback.append("Phát âm rất tốt! Bạn đã làm xuất sắc.")
+        elif overall_score >= 0.6:
+            feedback.append("Phát âm khá tốt, còn một vài điểm cần cải thiện.")
+        elif overall_score >= 0.4:
+            feedback.append(
+                "Cần luyện tập thêm. Tập trung vào những từ được đánh dấu đỏ."
+            )
+        else:
+            feedback.append("Hãy luyện tập chậm và rõ ràng hơn.")
+
+        # Wrong words feedback
+        if wrong_words:
+            if len(wrong_words) <= 3:
+                word_names = [w["word"] for w in wrong_words]
+                feedback.append(f"Các từ cần luyện tập: {', '.join(word_names)}")
+            else:
+                feedback.append(
+                    f"Có {len(wrong_words)} từ cần luyện tập. Tập trung vào từng từ một."
+                )
+
+        # Most problematic phonemes
+        problem_phonemes = defaultdict(int)
+        for comparison in phoneme_comparisons:
+            if comparison["status"] in ["wrong", "missing"]:
+                phoneme = comparison["reference_phoneme"]
+                problem_phonemes[phoneme] += 1
+
+        if problem_phonemes:
+            most_difficult = sorted(
+                problem_phonemes.items(), key=lambda x: x[1], reverse=True
+            )
+            top_problem = most_difficult[0][0]
+
+            phoneme_tips = {
+                "θ": "Lưỡi giữa răng, thổi nhẹ",
+                "ð": "Lưỡi giữa răng, rung dây thanh",
+                "v": "Môi dưới chạm răng trên",
+                "r": "Cuộn lưỡi, không chạm vòm miệng",
+                "l": "Lưỡi chạm vòm miệng",
+                "z": "Như 's' nhưng rung dây thanh",
+            }
+
+            if top_problem in phoneme_tips:
+                feedback.append(
+                    f"Âm khó nhất '{top_problem}': {phoneme_tips[top_problem]}"
+                )
+
+        return feedback
+
+
+def convert_numpy_types(obj):
+    """Convert numpy types to Python native types"""
+    if isinstance(obj, np.integer):
+        return int(obj)
+    elif isinstance(obj, np.floating):
+        return float(obj)
+    elif isinstance(obj, np.ndarray):
+        return obj.tolist()
+    elif isinstance(obj, dict):
+        return {key: convert_numpy_types(value) for key, value in obj.items()}
+    elif isinstance(obj, list):
+        return [convert_numpy_types(item) for item in obj]
+    else:
+        return obj
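
An end-to-end sketch tying the pieces together; averaging word scores into overall_score is an illustrative assumption here, since the controller's actual aggregation is not shown in this commit:

    analysis = WordAnalyzer().analyze_words("think this", "t ɪ n k d ɪ s")
    overall = float(np.mean([h["score"] for h in analysis["word_highlights"]]))
    feedback = SimpleFeedbackGenerator().generate_feedback(
        overall, analysis["wrong_words"], analysis["phoneme_differences"]
    )
    payload = convert_numpy_types({"overall_score": overall, **analysis})  # JSON-safe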