ABAO77 committed on
Commit c5ca6dc · 1 Parent(s): cef1d4a

feat: refactor Wav2Vec2 character ASR to support quantization and improve model loading

src/AI_Models/wave2vec_inference.py ADDED
@@ -0,0 +1,183 @@
+import torch
+from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor
+import onnxruntime as rt
+import numpy as np
+import librosa
+
+
+class Wave2Vec2Inference:
+    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
+        # Honor the use_gpu flag instead of always pinning the model to CPU
+        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
+        if use_lm_if_possible:
+            self.processor = AutoProcessor.from_pretrained(model_name)
+        else:
+            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
+        self.model = AutoModelForCTC.from_pretrained(model_name)
+        self.model.to(self.device)
+        # Avoid the mutable-default-argument pitfall for hotwords
+        self.hotwords = hotwords if hotwords is not None else []
+        self.use_lm_if_possible = use_lm_if_possible
+
+    def buffer_to_text(self, audio_buffer):
+        if len(audio_buffer) == 0:
+            return ""
+
+        inputs = self.processor(
+            torch.tensor(audio_buffer),
+            sampling_rate=16_000,
+            return_tensors="pt",
+            padding=True,
+        )
+
+        with torch.no_grad():
+            logits = self.model(
+                inputs.input_values.to(self.device),
+                attention_mask=inputs.attention_mask.to(self.device),
+            ).logits
+
+        if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
+            transcription = self.processor.decode(
+                logits[0].cpu().numpy(),
+                hotwords=self.hotwords,
+                # hotword_weight=self.hotword_weight,
+                output_word_offsets=True,
+            )
+            confidence = transcription.lm_score / len(transcription.text.split(" "))
+            transcription: str = transcription.text
+        else:
+            predicted_ids = torch.argmax(logits, dim=-1)
+            transcription: str = self.processor.batch_decode(predicted_ids)[0]
+            # confidence = self.confidence_score(logits, predicted_ids)
+        return transcription.lower()
+
+    def confidence_score(self, logits, predicted_ids):
+        scores = torch.nn.functional.softmax(logits, dim=-1)
+        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
+        mask = torch.logical_and(
+            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
+            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
+        )
+
+        character_scores = pred_scores.masked_select(mask)
+        total_average = torch.sum(character_scores) / len(character_scores)
+        return total_average
+
+    def file_to_text(self, filename):
+        # librosa is already imported at module level; no local re-import needed
+        audio_input, _ = librosa.load(filename, sr=16000)
+        return self.buffer_to_text(audio_input)
+
+
+class Wave2Vec2ONNXInference:
+    def __init__(self, model_name, onnx_path):
+        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
+        options = rt.SessionOptions()
+        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
+        self.model = rt.InferenceSession(onnx_path, options)
+
+    def buffer_to_text(self, audio_buffer):
+        if len(audio_buffer) == 0:
+            return ""
+
+        inputs = self.processor(
+            torch.tensor(audio_buffer),
+            sampling_rate=16_000,
+            return_tensors="np",
+            padding=True,
+        )
+
+        input_values = inputs.input_values
+        onnx_outputs = self.model.run(
+            None, {self.model.get_inputs()[0].name: input_values}
+        )[0]
+        prediction = np.argmax(onnx_outputs, axis=-1)
+
+        transcription = self.processor.decode(prediction.squeeze().tolist())
+        return transcription.lower()
+
+    def file_to_text(self, filename):
+        audio_input, _ = librosa.load(filename, sr=16000)
+        return self.buffer_to_text(audio_input)
+
+
+from transformers import Wav2Vec2ForCTC
+
+# Adapted from: https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py
+
+
+def convert_to_onnx(model_id_or_path, onnx_model_name):
+    print(f"Converting {model_id_or_path} to onnx")
+    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
+    audio_len = 250000
+
+    x = torch.randn(1, audio_len, requires_grad=True)
+
+    torch.onnx.export(
+        model,  # model being run
+        x,  # model input (or a tuple for multiple inputs)
+        onnx_model_name,  # where to save the model (can be a file or file-like object)
+        export_params=True,  # store the trained parameter weights inside the model file
+        opset_version=14,  # the ONNX version to export the model to
+        do_constant_folding=True,  # whether to execute constant folding for optimization
+        input_names=["input"],  # the model's input names
+        output_names=["output"],  # the model's output names
+        dynamic_axes={
+            "input": {1: "audio_len"},  # variable-length axes
+            "output": {1: "audio_len"},
+        },
+    )
+
+
+def quantize_onnx_model(onnx_model_path, quantized_model_path):
+    print("Starting quantization...")
+    from onnxruntime.quantization import quantize_dynamic, QuantType
+
+    quantize_dynamic(
+        onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
+    )
+
+    print(f"Quantized model saved to: {quantized_model_path}")
+
+
+def export_to_onnx(
+    model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
+):
+    onnx_model_name = model.split("/")[-1] + ".onnx"
+    convert_to_onnx(model, onnx_model_name)
+    if quantize:
+        quantized_model_name = model.split("/")[-1] + ".quant.onnx"
+        quantize_onnx_model(onnx_model_name, quantized_model_name)
+
+
+if __name__ == "__main__":
+    import time
+
+    asr = Wave2Vec2Inference("facebook/wav2vec2-large-960h-lv60-self")
+
+    # Warm-up runs
+    print("Warming up...")
+    for i in range(2):
+        asr.file_to_text("test.wav")
+        print(f"Warm up {i+1} completed")
+
+    # Test runs
+    print("Running tests...")
+    times = []
+    for i in range(10):
+        start_time = time.time()
+        text = asr.file_to_text("test.wav")
+        end_time = time.time()
+        execution_time = end_time - start_time
+        times.append(execution_time)
+        print(f"Test {i+1}: {execution_time:.3f}s - {text}")
+
+    # Calculate average time
+    average_time = sum(times) / len(times)
+    print(f"\nAverage execution time: {average_time:.3f}s")
+    print(f"Min time: {min(times):.3f}s")
+    print(f"Max time: {max(times):.3f}s")
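
For quick reference, a minimal usage sketch of what this new module exposes. This is not part of the commit; the model name is the script's default, and "test.wav" stands in for any audio file:

    from src.AI_Models.wave2vec_inference import (
        Wave2Vec2Inference,
        Wave2Vec2ONNXInference,
        export_to_onnx,
    )

    # PyTorch path: load the CTC model and transcribe a file
    asr = Wave2Vec2Inference("facebook/wav2vec2-large-960h-lv60-self")
    print(asr.file_to_text("test.wav"))

    # ONNX path: export (and dynamically quantize) once, then run via onnxruntime
    export_to_onnx("facebook/wav2vec2-large-960h-lv60-self", quantize=True)
    onnx_asr = Wave2Vec2ONNXInference(
        "facebook/wav2vec2-large-960h-lv60-self",
        "wav2vec2-large-960h-lv60-self.quant.onnx",  # file name produced by export_to_onnx
    )
    print(onnx_asr.file_to_text("test.wav"))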
src/apis/controllers/speaking_controller.py CHANGED
@@ -1,9 +1,4 @@
-from fastapi import FastAPI, UploadFile, File, Form, HTTPException, APIRouter
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import List, Dict, Optional
-import tempfile
-import os
+from typing import List, Dict
 import numpy as np
 import librosa
 import nltk
@@ -11,14 +6,15 @@ import eng_to_ipa as ipa
 import torch
 import re
 from collections import defaultdict
-import warnings
-from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
 from loguru import logger
-import onnxruntime
 import time
-
+from src.AI_Models.wave2vec_inference import (
+    Wave2Vec2Inference,
+    Wave2Vec2ONNXInference,
+    export_to_onnx,
+)
 
 # Download required NLTK data
 try:
@@ -128,7 +124,7 @@ class Wav2Vec2CharacterASR:
         self,
         model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
         onnx: bool = False,
-        onnx_model_path: str = "./wav2vec2_asr.onnx",
+        quantized: bool = False,
     ):
         """
         Initialize Wav2Vec2 character-level model
@@ -138,185 +134,48 @@
         onnx: If True, use ONNX runtime for inference. If False, use Transformers
         onnx_model_path: Path to the ONNX model file (only used if onnx=True)
         """
-        self.model_name = model_name
         self.use_onnx = onnx
-        self.onnx_model_path = onnx_model_path
         self.sample_rate = 16000
-
-        print(f"Loading Wav2Vec2 character model: {model_name}")
-        print(f"Using {'ONNX' if onnx else 'Transformers'} for inference")
-
-        if self.use_onnx:
-            self._init_onnx_model()
-        else:
-            self._init_transformers_model()
-
-    def _init_onnx_model(self):
-        """Initialize ONNX model and processor"""
-        # Check if ONNX model exists, if not create it
-        if not os.path.exists(self.onnx_model_path):
-            print(f"ONNX model not found at {self.onnx_model_path}. Creating it...")
-            self._create_onnx_model()
-
-        try:
-            # Load ONNX model
-            self.session = onnxruntime.InferenceSession(self.onnx_model_path)
-            self.input_name = self.session.get_inputs()[0].name
-            self.output_name = self.session.get_outputs()[0].name
-
-            # Load processor
-            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
-
-            print("ONNX Wav2Vec2 character model loaded successfully")
-
-        except Exception as e:
-            print(f"Error loading ONNX model: {e}")
-            raise
-
-    def _init_transformers_model(self):
-        """Initialize Transformers model and processor"""
-        try:
-            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
-            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
-            self.model.eval()
-            print("Wav2Vec2 character model loaded successfully")
-        except Exception as e:
-            print(f"Error loading model {self.model_name}: {e}")
-            # Fallback to base model
-            fallback_model = "facebook/wav2vec2-base-960h"
-            print(f"Trying fallback model: {fallback_model}")
-            try:
-                self.processor = Wav2Vec2Processor.from_pretrained(fallback_model)
-                self.model = Wav2Vec2ForCTC.from_pretrained(fallback_model)
-                self.model.eval()
-                self.model_name = fallback_model
-                print("Fallback model loaded successfully")
-            except Exception as e2:
-                raise Exception(
-                    f"Failed to load both models. Original error: {e}, Fallback error: {e2}"
-                )
-
-    def _create_onnx_model(self):
-        """Create ONNX model if it doesn't exist"""
-        try:
-            # Import the converter from model_convert
-            from src.model_convert.wav2vec2onnx import Wav2Vec2ONNXConverter
-
-            print("Creating new ONNX model...")
-            converter = Wav2Vec2ONNXConverter(self.model_name)
-            created_path = converter.convert_to_onnx(
-                onnx_path=self.onnx_model_path,
-                input_length=160000,  # 10 seconds
-                opset_version=14,
+        self.model_name = model_name
+        # Check whether the ONNX model file already exists; export it if not
+        if onnx:
+            import os
+
+            if not os.path.exists(
+                "wav2vec2-large-960h-lv60-self"
+                + (".quant" if quantized else "")
+                + ".onnx"
+            ):
+                export_to_onnx(model_name, quantize=quantized)
+        self.model = (
+            Wave2Vec2Inference(model_name)
+            if not onnx
+            else Wave2Vec2ONNXInference(
+                model_name,
+                "wav2vec2-large-960h-lv60-self"
+                + (".quant" if quantized else "")
+                + ".onnx",
             )
-            print(f"✓ ONNX model created successfully at: {created_path}")
-
-        except ImportError as e:
-            print(f"Error importing Wav2Vec2ONNXConverter: {e}")
-            raise e
+        )
 
     def transcribe_to_characters(self, audio_path: str) -> Dict:
-        """
-        Transcribe audio directly to characters (no language model correction)
-        Returns raw character sequence as produced by the model
-        """
-        if self.use_onnx:
-            return self._transcribe_onnx(audio_path)
-        else:
-            return self._transcribe_transformers(audio_path)
-
-    def _transcribe_onnx(self, audio_path: str) -> Dict:
-        """Transcribe using ONNX runtime"""
-        try:
-            # Load audio
-            start_time = time.time()
-            speech, sr = librosa.load(audio_path, sr=self.sample_rate)
-
-            # Prepare input for ONNX
-            input_values = self.processor(
-                speech, sampling_rate=self.sample_rate, return_tensors="np"
-            ).input_values
-
-            # Run ONNX inference
-            ort_inputs = {self.input_name: input_values}
-            ort_outputs = self.session.run([self.output_name], ort_inputs)
-            logits = ort_outputs[0]
-
-            # Get predictions
-            predicted_ids = np.argmax(logits, axis=-1)
-
-            # Decode to characters directly
-            character_transcript = self.processor.batch_decode(predicted_ids)[0]
-            logger.info(f"character_transcript {character_transcript}")
-
-            # Clean up character transcript
-            character_transcript = self._clean_character_transcript(
-                character_transcript
-            )
-
-            # Convert characters to phoneme-like representation
-            phoneme_like_transcript = self._characters_to_phoneme_representation(
-                character_transcript
-            )
-
-            # Calculate confidence scores
-            confidence_scores = self._calculate_confidence_scores(logits)
-            logger.info(
-                f"Wav2Vec2 ONNX transcription time: {time.time() - start_time:.2f}s"
-            )
-
-            return {
-                "character_transcript": character_transcript,
-                "phoneme_representation": phoneme_like_transcript,
-                "raw_predicted_ids": predicted_ids[0].tolist(),
-                "confidence_scores": confidence_scores[:100],  # Limit for JSON
-            }
-
-        except Exception as e:
-            print(f"ONNX transcription error: {e}")
-            return self._empty_result()
-
-    def _transcribe_transformers(self, audio_path: str) -> Dict:
-        """Transcribe using Transformers"""
         try:
-            # Load audio
             start_time = time.time()
-            speech, sr = librosa.load(audio_path, sr=self.sample_rate)
-
-            # Prepare input
-            input_values = self.processor(
-                speech, sampling_rate=self.sample_rate, return_tensors="pt"
-            ).input_values
-
-            # Get model predictions (no language model involved)
-            with torch.no_grad():
-                logits = self.model(input_values).logits
-                predicted_ids = torch.argmax(logits, dim=-1)
-
-            # Decode to characters directly
-            character_transcript = self.processor.batch_decode(predicted_ids)[0]
-
-            # Clean up character transcript
+            character_transcript = self.model.file_to_text(audio_path)
             character_transcript = self._clean_character_transcript(
                 character_transcript
             )
 
-            # Convert characters to phoneme-like representation
             phoneme_like_transcript = self._characters_to_phoneme_representation(
                 character_transcript
             )
 
-            logger.info(
-                f"Transformers transcription time: {time.time() - start_time:.2f}s"
-            )
+            logger.info(f"Transcription time: {time.time() - start_time:.2f}s")
 
             return {
                 "character_transcript": character_transcript,
                 "phoneme_representation": phoneme_like_transcript,
-                "raw_predicted_ids": predicted_ids[0].tolist(),
-                "confidence_scores": torch.softmax(logits, dim=-1)
-                .max(dim=-1)[0][0]
-                .tolist()[:100],  # Limit for JSON
             }
 
         except Exception as e:
@@ -988,7 +847,7 @@ class SimplePronunciationAssessor:
         if mode == "advanced":
             print("Step 1: Using Wav2Vec2 character transcription...")
             asr_result = self.wav2vec2_asr.transcribe_to_characters(audio_path)
-            model_info = f"Wav2Vec2-Character ({self.wav2vec2_asr.model_name})"
+            model_info = f"Wav2Vec2-Character ({self.wav2vec2_asr.model})"
         else:  # normal mode
             print("Step 1: Using Whisper transcription...")
             asr_result = self.whisper_asr.transcribe_to_text(audio_path)
@@ -1046,19 +905,3 @@
 
         total_score = sum(comparison["score"] for comparison in phoneme_comparisons)
         return total_score / len(phoneme_comparisons)
-
-
-def convert_numpy_types(obj):
-    """Convert numpy types to Python native types"""
-    if isinstance(obj, np.integer):
-        return int(obj)
-    elif isinstance(obj, np.floating):
-        return float(obj)
-    elif isinstance(obj, np.ndarray):
-        return obj.tolist()
-    elif isinstance(obj, dict):
-        return {key: convert_numpy_types(value) for key, value in obj.items()}
-    elif isinstance(obj, list):
-        return [convert_numpy_types(item) for item in obj]
-    else:
-        return obj
 
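With this change, the controller delegates all model loading and decoding to the new inference module. A hedged sketch of how the refactored class would now be exercised (the audio path is illustrative, not from the commit):

    # Instantiate with the new constructor signature
    asr = Wav2Vec2CharacterASR(
        model_name="facebook/wav2vec2-large-960h-lv60-self",
        onnx=True,       # route inference through Wave2Vec2ONNXInference
        quantized=True,  # load the dynamically quantized .quant.onnx weights
    )
    result = asr.transcribe_to_characters("sample.wav")  # hypothetical file
    print(result["character_transcript"])
    print(result["phoneme_representation"])
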
src/apis/routes/speaking_route.py CHANGED
@@ -10,8 +10,8 @@ from src.apis.controllers.speaking_controller import (
     SimpleG2P,
     PhonemeComparator,
     SimplePronunciationAssessor,
-    convert_numpy_types,
 )
+from src.utils.speaking_utils import convert_numpy_types
 
 warnings.filterwarnings("ignore")
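
The route now imports convert_numpy_types from src.utils.speaking_utils instead of from the controller. That utils module is not shown in this commit; judging from the function removed from speaking_controller.py above, it presumably carries the same recursive converter:

    import numpy as np

    def convert_numpy_types(obj):
        """Convert numpy types to Python native types"""
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {key: convert_numpy_types(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [convert_numpy_types(item) for item in obj]
        else:
            return obj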