fix: update import statement for BaseModel in agent.py and add timing logs in speaking_controller.py

src/agents/evaluation/agent.py
CHANGED

@@ -1,5 +1,5 @@
 from langchain_core.prompts import ChatPromptTemplate
-from
+from pydantic import BaseModel, Field
 from src.config.llm import model
 from src.utils.logger import logger
 from .prompt import evaluation_prompt

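The corrected line imports BaseModel and Field directly from pydantic, which is what the evaluation agent's structured-output schema needs alongside langchain-core. A minimal sketch of how such a schema is typically wired to a prompt and chat model — the EvaluationResult fields and the commented chain are illustrative assumptions, not code from this repo:

```python
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field


class EvaluationResult(BaseModel):
    # Hypothetical schema; the real fields live in src/agents/evaluation/agent.py
    score: int = Field(description="Overall score from 0 to 100")
    feedback: str = Field(description="Short feedback for the learner")


prompt = ChatPromptTemplate.from_messages(
    [("system", "You are an evaluator."), ("human", "{answer}")]
)

# `model` would be the chat model configured in src.config.llm; with_structured_output
# binds the pydantic schema so the chain returns an EvaluationResult instance.
# chain = prompt | model.with_structured_output(EvaluationResult)
# result = chain.invoke({"answer": "I goed to school yesterday."})
```
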
src/apis/controllers/speaking_controller.py
CHANGED

@@ -17,8 +17,8 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
 from loguru import logger
 import onnxruntime
+import time
 
-warnings.filterwarnings("ignore")
 
 # Download required NLTK data
 try:

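Besides adding the time import, this hunk drops the module-level warnings.filterwarnings("ignore"), so library warnings from transformers, onnxruntime, and friends will surface again. If one specific warning still needs silencing, a scoped filter is the usual narrower alternative (a sketch, not part of this commit):

```python
import warnings

with warnings.catch_warnings():
    # Suppress only one category, and only around the noisy call
    warnings.simplefilter("ignore", category=UserWarning)
    ...  # noisy library call goes here
```
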
@@ -66,7 +66,8 @@ class WhisperASR:
         Returns transcript and confidence score
         """
         try:
-
+
+            start_time = time.time()
             audio, sr = librosa.load(audio_path, sr=self.sample_rate)
 
             # Process audio

@@ -95,7 +96,7 @@ class WhisperASR:
             # Convert to phoneme representation for comparison
             g2p = SimpleG2P()
             phoneme_representation = g2p.get_reference_phoneme_string(transcript)
-
+            logger.info(f"Whisper transcription time: {time.time() - start_time:.2f}s")
             return {
                 "character_transcript": transcript,
                 "phoneme_representation": phoneme_representation,

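Both transcription paths now follow the same timing pattern: capture time.time() before loading the audio and log the elapsed seconds through loguru once decoding finishes. The same idea can be factored into a small context manager — a sketch built on the imports this file already has; the name log_duration is illustrative:

```python
import time
from contextlib import contextmanager

from loguru import logger


@contextmanager
def log_duration(label: str):
    """Log how many seconds the wrapped block took."""
    start = time.time()
    try:
        yield
    finally:
        logger.info(f"{label}: {time.time() - start:.2f}s")


# Usage (illustrative):
# with log_duration("Whisper transcription time"):
#     audio, sr = librosa.load(audio_path, sr=16000)
#     ...  # run the model and decode
```
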
@@ -179,49 +180,7 @@ class Wav2Vec2CharacterASRONNX:
 
         except ImportError as e:
             print(f"Error importing Wav2Vec2ONNXConverter: {e}")
-
-            self._fallback_create_onnx_model(onnx_model_path, processor_name)
-
-        except Exception as e:
-            print(f"Error creating ONNX model: {e}")
-            # Try fallback method
-            self._fallback_create_onnx_model(onnx_model_path, processor_name)
-
-    def _fallback_create_onnx_model(self, onnx_model_path: str, processor_name: str):
-        """Fallback method to create ONNX model using basic torch.onnx.export"""
-        try:
-            print("Using fallback method to create ONNX model...")
-
-            # Load PyTorch model
-            model = Wav2Vec2ForCTC.from_pretrained(processor_name)
-            model.eval()
-
-            # Create dummy input
-            dummy_input = torch.randn(1, 160000, dtype=torch.float32)
-
-            # Export to ONNX
-            with torch.no_grad():
-                torch.onnx.export(
-                    model,
-                    dummy_input,
-                    onnx_model_path,
-                    input_names=["input_values"],
-                    output_names=["logits"],
-                    dynamic_axes={
-                        "input_values": {0: "batch_size", 1: "sequence_length"},
-                        "logits": {0: "batch_size", 1: "sequence_length"},
-                    },
-                    opset_version=14,
-                    do_constant_folding=True,
-                    verbose=False,
-                    export_params=True,
-                )
-
-            print(f"✓ Fallback ONNX model created at: {onnx_model_path}")
-
-        except Exception as e:
-            print(f"Fallback method also failed: {e}")
-            raise Exception(f"Could not create ONNX model: {e}")
+            raise e
 
     def transcribe_to_characters(self, audio_path: str) -> Dict:
         """

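With the fallback exporter removed, an ImportError from Wav2Vec2ONNXConverter now propagates instead of being papered over by a hand-rolled torch.onnx.export. One small point: `raise e` works, but inside an except block a bare `raise` is the more idiomatic way to re-raise the original exception (a sketch with a hypothetical import path, not the committed code):

```python
try:
    from wav2vec2_onnx_converter import Wav2Vec2ONNXConverter  # hypothetical module path
except ImportError as e:
    print(f"Error importing Wav2Vec2ONNXConverter: {e}")
    raise  # re-raise the original ImportError with its traceback intact
```
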
@@ -230,6 +189,7 @@ class Wav2Vec2CharacterASRONNX:
         """
         try:
             # Load audio
+            start_time = time.time()
             speech, sr = librosa.load(audio_path, sr=self.sample_rate)
 
             # Prepare input for ONNX

@@ -261,6 +221,9 @@ class Wav2Vec2CharacterASRONNX:
 
             # Calculate confidence scores
             confidence_scores = self._calculate_confidence_scores(logits)
+            logger.info(
+                f"Wav2Vec2 ONNX transcription time: {time.time() - start_time:.2f}s"
+            )
 
             return {
                 "character_transcript": character_transcript,

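A minor aside on the measurement itself: time.time() is wall-clock and can jump if the system clock changes, while time.perf_counter() is the monotonic clock usually preferred for durations. A drop-in variant of the same timing (illustrative, not the committed code):

```python
import time

start_time = time.perf_counter()
# ... load audio, run the ONNX session, decode ...
elapsed = time.perf_counter() - start_time
# logger.info(f"Wav2Vec2 ONNX transcription time: {elapsed:.2f}s")
```
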
@@ -934,6 +897,7 @@ class SimplePronunciationAssessor:
             print("Step 1: Using Whisper transcription...")
             asr_result = self.whisper_asr.transcribe_to_text(audio_path)
             model_info = f"Whisper ({self.whisper_asr.model_name})"
+            print(f"Whisper ASR result: {asr_result}")
 
             character_transcript = asr_result["character_transcript"]
             phoneme_representation = asr_result["phoneme_representation"]

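The new print of the raw ASR result is handy while debugging, but since loguru is already imported in this module the same output could go through the logger so it respects levels and sinks. A sketch of the equivalent call (not part of this commit; the dict content is illustrative):

```python
from loguru import logger

asr_result = {"character_transcript": "...", "phoneme_representation": "..."}  # illustrative
logger.debug("Whisper ASR result: {}", asr_result)
```
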
|