Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import ( | |
| AutoModelForCTC, | |
| AutoProcessor, | |
| Wav2Vec2Processor, | |
| Wav2Vec2ForCTC, | |
| ) | |
| import onnxruntime as rt | |
| import numpy as np | |
| import librosa | |
class Wave2Vec2Inference:
    """Speech-to-text with a Hugging Face CTC model, optionally LM-decoded.

    Audio is handled at 16 kHz: that rate is passed to the processor and used
    when loading files. When the loaded processor exposes a ``decoder``
    attribute and ``use_lm_if_possible`` is true, decoding goes through
    ``processor.decode`` (with ``hotwords``). Otherwise greedy argmax CTC
    decoding via ``processor.batch_decode`` is used.
    """

    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
        """Load the processor and CTC model for ``model_name``.

        Args:
            model_name: Hugging Face model id or local path.
            hotwords: Optional iterable of words forwarded to the LM decoder's
                ``decode`` call. ``None`` (the default) means no hotwords; it
                replaces the former mutable ``[]`` default, which was shared
                across all instances.
            use_lm_if_possible: Load via ``AutoProcessor`` (which may bundle an
                LM decoder) instead of ``Wav2Vec2Processor``, and use the
                decoder when the processor has one.
            use_gpu: Run on CUDA when ``torch.cuda.is_available()``.
        """
        self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        # Copy so later mutation of the caller's list cannot leak into this instance.
        self.hotwords = list(hotwords) if hotwords is not None else []
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        """Transcribe a 16 kHz waveform and return the lower-cased text.

        Args:
            audio_buffer: Sequence/array of audio samples (presumably mono,
                1-D; it is wrapped with ``torch.tensor`` as-is).

        Returns:
            The transcription in lower case, or ``""`` for an empty buffer.
        """
        if len(audio_buffer) == 0:
            return ""
        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )
        # Not every feature extractor returns an attention mask, so look it up
        # optionally instead of assuming the key exists.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(
                inputs["input_values"].to(self.device),
                attention_mask=attention_mask,
            ).logits

        if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
            # LM-backed decoding of the first (only) item in the batch.
            decoded = self.processor.decode(
                logits[0].cpu().numpy(),
                hotwords=self.hotwords,
            )
            transcription = decoded.text
        else:
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower()

    def confidence_score(self, logits, predicted_ids):
        """Return the mean softmax probability of the predicted character tokens.

        Word-delimiter and padding tokens are excluded from the average.

        Args:
            logits: Rank-3 tensor, assumed (batch, time, vocab).
            predicted_ids: Integer tensor of predicted token ids, shaped like
                ``logits`` without its last axis.

        Returns:
            A 0-d tensor. It is ``0`` when no character tokens remain after
            masking; the previous code divided by zero and produced NaN.
        """
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = probabilities.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        tokenizer = self.processor.tokenizer
        mask = torch.logical_and(
            predicted_ids.ne(tokenizer.word_delimiter_token_id),
            predicted_ids.ne(tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        if character_scores.numel() == 0:
            return torch.zeros((), dtype=pred_scores.dtype, device=pred_scores.device)
        return character_scores.mean()

    def file_to_text(self, filename):
        """Load ``filename`` resampled to 16 kHz and transcribe it."""
        audio_input, _ = librosa.load(filename, sr=16000)
        return self.buffer_to_text(audio_input)
class Wave2Vec2ONNXInference:
    """Greedy CTC speech-to-text using an exported ONNX model.

    The Hugging Face ``Wav2Vec2Processor`` for ``model_name`` is used for
    feature extraction and token decoding. The acoustic model itself runs
    in an ONNX Runtime session loaded from ``onnx_path``.
    """

    def __init__(self, model_name, onnx_path, providers=None):
        """Load the processor and create the ONNX Runtime session.

        Args:
            model_name: Hugging Face model id or path for the processor.
            onnx_path: Path to the exported ``.onnx`` model.
            providers: Optional list of ONNX Runtime execution providers. When
                ``None``, all providers available in this build are passed
                explicitly. ONNX Runtime >= 1.9 raises ``ValueError`` on
                GPU-enabled builds if ``providers`` is omitted.
        """
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        if providers is None:
            providers = rt.get_available_providers()
        self.model = rt.InferenceSession(onnx_path, options, providers=providers)

    def buffer_to_text(self, audio_buffer):
        """Transcribe a 16 kHz waveform and return the lower-cased text.

        Returns ``""`` for an empty buffer.
        """
        if len(audio_buffer) == 0:
            return ""
        inputs = self.processor(
            np.asarray(audio_buffer, dtype=np.float32),
            sampling_rate=16_000,
            return_tensors="np",
            padding=True,
        )
        # The exported graph has a single input; feed it by its declared name.
        input_name = self.model.get_inputs()[0].name
        logits = self.model.run(None, {input_name: inputs["input_values"]})[0]
        predicted_ids = np.argmax(logits, axis=-1)
        # Index the single batch item instead of squeeze(). squeeze() would also
        # drop a length-1 time axis and hand decode() a bare int.
        transcription = self.processor.decode(predicted_ids[0].tolist())
        return transcription.lower()

    def file_to_text(self, filename):
        """Load ``filename`` resampled to 16 kHz and transcribe it."""
        audio_input, _ = librosa.load(filename, sr=16000)
        return self.buffer_to_text(audio_input)
# Adapted from: https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py
def convert_to_onnx(model_id_or_path, onnx_model_name, audio_len=250000, opset_version=14):
    """Export a ``Wav2Vec2ForCTC`` checkpoint to an ONNX file.

    Args:
        model_id_or_path: Hugging Face model id or local checkpoint path.
        onnx_model_name: Destination path (or file-like object) for the export.
        audio_len: Number of samples in the random dummy waveform used to trace
            the graph. The input length is exported as a dynamic axis, so this
            does not fix the accepted length.
        opset_version: ONNX opset to target.
    """
    print(f"Converting {model_id_or_path} to onnx")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
    # Tracing only needs a correctly shaped tensor; no gradients are involved.
    dummy_input = torch.randn(1, audio_len)
    torch.onnx.export(
        model,  # model being run
        dummy_input,  # model input (or a tuple for multiple inputs)
        onnx_model_name,  # where to save the model (file path or file-like object)
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=opset_version,  # the ONNX opset to export the model to
        do_constant_folding=True,  # execute constant folding for optimization
        input_names=["input"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            "input": {1: "audio_len"},  # variable number of samples
            # The output time axis counts CTC frames, not samples. It needs its
            # own symbolic name: reusing "audio_len" would declare both axes equal.
            "output": {1: "frames"},
        },
    )
def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Write a dynamically quantized copy of an ONNX model.

    Weights are quantized to ``QuantType.QUInt8`` with ONNX Runtime's
    ``quantize_dynamic``. Progress is reported on stdout.

    Args:
        onnx_model_path: Path of the source ONNX model.
        quantized_model_path: Path the quantized model is written to.
    """
    print("Starting quantization...")
    # Imported lazily so the quantization tooling is only needed when used.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    weight_dtype = QuantType.QUInt8
    quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=weight_dtype)
    print(f"Quantized model saved to: {quantized_model_path}")
def export_to_onnx(
    model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
):
    """Export ``model`` to ``<name>.onnx`` and optionally ``<name>.quant.onnx``.

    ``<name>`` is the last path component of ``model``. Output files are
    written relative to the current working directory.

    Args:
        model: Hugging Face model id (``org/name``) or local checkpoint path.
        quantize: Also write a dynamically quantized copy of the export.
    """
    # Strip trailing slashes so a local directory like "./my_model/" yields
    # "my_model.onnx" rather than ".onnx".
    base_name = model.rstrip("/").split("/")[-1]
    onnx_model_name = base_name + ".onnx"
    convert_to_onnx(model, onnx_model_name)
    if quantize:
        quantized_model_name = base_name + ".quant.onnx"
        quantize_onnx_model(onnx_model_name, quantized_model_name)
if __name__ == "__main__":
    # Benchmark: transcribe test.wav repeatedly and report per-run latency.
    import time

    asr = Wave2Vec2Inference("facebook/wav2vec2-large-960h-lv60-self")

    # Warm-up runs are executed but excluded from the timings below.
    print("Warming up...")
    for i in range(2):
        asr.file_to_text("test.wav")
        print(f"Warm up {i+1} completed")

    # Timed runs. perf_counter is a monotonic, high-resolution clock, which
    # suits interval measurement better than the wall clock time.time().
    print("Running tests...")
    times = []
    for i in range(10):
        start_time = time.perf_counter()
        text = asr.file_to_text("test.wav")
        execution_time = time.perf_counter() - start_time
        times.append(execution_time)
        print(f"Test {i+1}: {execution_time:.3f}s - {text}")

    average_time = sum(times) / len(times)
    print(f"\nAverage execution time: {average_time:.3f}s")
    print(f"Min time: {min(times):.3f}s")
    print(f"Max time: {max(times):.3f}s")