# Run_code_api/src/AI_Models/wave2vec_inference.py
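"""Wav2Vec2 speech-to-text inference utilities.

Contains a PyTorch inference wrapper (Wave2Vec2Inference), an ONNX Runtime
wrapper (Wave2Vec2ONNXInference), and helpers to export a Wav2Vec2ForCTC
checkpoint to ONNX and optionally quantize it.
"""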
import torch
from transformers import (
AutoModelForCTC,
AutoProcessor,
Wav2Vec2Processor,
Wav2Vec2ForCTC,
)
import onnxruntime as rt
import numpy as np
import librosa
class Wave2Vec2Inference:
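    """Wav2Vec2 CTC speech recognition on PyTorch.

    Uses the checkpoint's LM-boosted decoder when available and requested,
    otherwise falls back to greedy CTC decoding. Runs on CUDA when available
    and `use_gpu` is True.
    """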
    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
if use_lm_if_possible:
self.processor = AutoProcessor.from_pretrained(model_name)
else:
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
self.model = AutoModelForCTC.from_pretrained(model_name)
self.model.to(self.device)
        self.hotwords = hotwords if hotwords is not None else []
self.use_lm_if_possible = use_lm_if_possible
def buffer_to_text(self, audio_buffer):
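        """Transcribe a 16 kHz mono audio buffer and return lowercase text."""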
if len(audio_buffer) == 0:
return ""
inputs = self.processor(
torch.tensor(audio_buffer),
sampling_rate=16_000,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
logits = self.model(
inputs.input_values.to(self.device),
attention_mask=inputs.attention_mask.to(self.device),
).logits
if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
transcription = self.processor.decode(
logits[0].cpu().numpy(),
hotwords=self.hotwords,
# hotword_weight=self.hotword_weight,
output_word_offsets=True,
)
confidence = transcription.lm_score / len(transcription.text.split(" "))
transcription: str = transcription.text
else:
predicted_ids = torch.argmax(logits, dim=-1)
transcription: str = self.processor.batch_decode(predicted_ids)[0]
# confidence = self.confidence_score(logits, predicted_ids)
return transcription.lower()
def confidence_score(self, logits, predicted_ids):
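        """Mean softmax probability of the greedy predictions, ignoring pad and word-delimiter tokens."""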
scores = torch.nn.functional.softmax(logits, dim=-1)
pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
mask = torch.logical_and(
predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
)
character_scores = pred_scores.masked_select(mask)
total_average = torch.sum(character_scores) / len(character_scores)
return total_average
def file_to_text(self, filename):
audio_input, samplerate = librosa.load(filename, sr=16000)
return self.buffer_to_text(audio_input)
class Wave2Vec2ONNXInference:
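    """Wav2Vec2 CTC speech recognition with an exported ONNX model via ONNX Runtime."""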
def __init__(self, model_name, onnx_path):
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
# self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
options = rt.SessionOptions()
options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
self.model = rt.InferenceSession(onnx_path, options)
def buffer_to_text(self, audio_buffer):
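        """Transcribe a 16 kHz mono audio buffer with the ONNX model and return lowercase text."""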
if len(audio_buffer) == 0:
return ""
inputs = self.processor(
torch.tensor(audio_buffer),
sampling_rate=16_000,
return_tensors="np",
padding=True,
)
input_values = inputs.input_values
onnx_outputs = self.model.run(
None, {self.model.get_inputs()[0].name: input_values}
)[0]
prediction = np.argmax(onnx_outputs, axis=-1)
transcription = self.processor.decode(prediction.squeeze().tolist())
return transcription.lower()
def file_to_text(self, filename):
audio_input, samplerate = librosa.load(filename, sr=16000)
return self.buffer_to_text(audio_input)
# Adapted from: https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py
def convert_to_onnx(model_id_or_path, onnx_model_name):
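    """Export a Wav2Vec2ForCTC checkpoint to ONNX with a dynamic audio-length axis."""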
print(f"Converting {model_id_or_path} to onnx")
model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
audio_len = 250000
x = torch.randn(1, audio_len, requires_grad=True)
torch.onnx.export(
model, # model being run
x, # model input (or a tuple for multiple inputs)
onnx_model_name, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=14, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=["input"], # the model's input names
output_names=["output"], # the model's output names
dynamic_axes={
"input": {1: "audio_len"}, # variable length axes
"output": {1: "audio_len"},
},
)
def quantize_onnx_model(onnx_model_path, quantized_model_path):
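    """Apply dynamic (weight-only, uint8) quantization to an ONNX model file."""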
print("Starting quantization...")
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
)
print(f"Quantized model saved to: {quantized_model_path}")
def export_to_onnx(
model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
):
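    """Export the checkpoint to <model-name>.onnx and optionally write <model-name>.quant.onnx."""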
onnx_model_name = model.split("/")[-1] + ".onnx"
convert_to_onnx(model, onnx_model_name)
if quantize:
quantized_model_name = model.split("/")[-1] + ".quant.onnx"
quantize_onnx_model(onnx_model_name, quantized_model_name)
if __name__ == "__main__":
import time
asr = Wave2Vec2Inference("facebook/wav2vec2-large-960h-lv60-self")
# Warm up runs
print("Warming up...")
for i in range(2):
asr.file_to_text("test.wav")
print(f"Warm up {i+1} completed")
# Test runs
print("Running tests...")
times = []
for i in range(10):
start_time = time.time()
text = asr.file_to_text("test.wav")
end_time = time.time()
execution_time = end_time - start_time
times.append(execution_time)
print(f"Test {i+1}: {execution_time:.3f}s - {text}")
# Calculate average time
average_time = sum(times) / len(times)
print(f"\nAverage execution time: {average_time:.3f}s")
print(f"Min time: {min(times):.3f}s")
print(f"Max time: {max(times):.3f}s")