import torch
from transformers import (
    AutoModelForCTC,
    AutoProcessor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
)
import onnxruntime as rt
import numpy as np
import librosa


class Wave2Vec2Inference:
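    """CTC speech-to-text with a Hugging Face wav2vec2 checkpoint, using the
    pyctcdecode language-model decoder when the processor provides one."""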
    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
        self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.hotwords = hotwords if hotwords is not None else []
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        if len(audio_buffer) == 0:
            return ""

        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )

        with torch.no_grad():
            # Some checkpoints (e.g. wav2vec2-base) do not return an attention mask.
            attention_mask = (
                inputs.attention_mask.to(self.device)
                if "attention_mask" in inputs
                else None
            )
            logits = self.model(
                inputs.input_values.to(self.device), attention_mask=attention_mask
            ).logits

        if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
            transcription = self.processor.decode(
                logits[0].cpu().numpy(),
                hotwords=self.hotwords,
                # hotword_weight=self.hotword_weight,
                output_word_offsets=True,
            )
            confidence = transcription.lm_score / len(transcription.text.split(" "))
            transcription: str = transcription.text
        else:
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription: str = self.processor.batch_decode(predicted_ids)[0]
            # confidence = self.confidence_score(logits, predicted_ids)
        return transcription.lower()

    def confidence_score(self, logits, predicted_ids):
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        mask = torch.logical_and(
            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
        )

        character_scores = pred_scores.masked_select(mask)
        total_average = torch.sum(character_scores) / len(character_scores)
        return total_average

    def file_to_text(self, filename):
        audio_input, samplerate = librosa.load(filename, sr=16000)
        return self.buffer_to_text(audio_input)


class Wave2Vec2ONNXInference:
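    """Greedy CTC decoding on top of an exported wav2vec2 ONNX graph, run with onnxruntime."""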
    def __init__(self, model_name, onnx_path):
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        # self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.model = rt.InferenceSession(onnx_path, options)

    def buffer_to_text(self, audio_buffer):
        if len(audio_buffer) == 0:
            return ""

        inputs = self.processor(
            audio_buffer,
            sampling_rate=16_000,
            return_tensors="np",
            padding=True,
        )

        # The exported graph expects float32 input.
        input_values = inputs.input_values.astype(np.float32)
        onnx_outputs = self.model.run(
            None, {self.model.get_inputs()[0].name: input_values}
        )[0]
        prediction = np.argmax(onnx_outputs, axis=-1)

        transcription = self.processor.decode(prediction.squeeze().tolist())
        return transcription.lower()

    def file_to_text(self, filename):
        audio_input, samplerate = librosa.load(filename, sr=16000)
        return self.buffer_to_text(audio_input)


# Conversion script adapted from: https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py


def convert_to_onnx(model_id_or_path, onnx_model_name):
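    """Export a wav2vec2 CTC checkpoint to ONNX with a dynamic time axis."""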
    print(f"Converting {model_id_or_path} to onnx")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
    audio_len = 250000  # dummy waveform length in samples (~15.6 s at 16 kHz)

    x = torch.randn(1, audio_len)  # dummy input used to trace the graph

    torch.onnx.export(
        model,  # model being run
        x,  # model input (or a tuple for multiple inputs)
        onnx_model_name,  # where to save the model (can be a file or file-like object)
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=14,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            "input": {1: "audio_len"},  # variable length axes
            "output": {1: "audio_len"},
        },
    )


def quantize_onnx_model(onnx_model_path, quantized_model_path):
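    """Apply dynamic (weight-only) uint8 quantization, shrinking the ONNX file roughly 4x."""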
    print("Starting quantization...")
    from onnxruntime.quantization import quantize_dynamic, QuantType

    quantize_dynamic(
        onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
    )

    print(f"Quantized model saved to: {quantized_model_path}")


def export_to_onnx(
    model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
):
    onnx_model_name = model.split("/")[-1] + ".onnx"
    convert_to_onnx(model, onnx_model_name)
    if quantize:
        quantized_model_name = model.split("/")[-1] + ".quant.onnx"
        quantize_onnx_model(onnx_model_name, quantized_model_name)
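
# Example usage of the ONNX path (a sketch; assumes the export above has produced
# wav2vec2-large-960h-lv60-self.quant.onnx in the working directory and that a
# test.wav file is available, as in the __main__ block below):
#
#   export_to_onnx("facebook/wav2vec2-large-960h-lv60-self", quantize=True)
#   onnx_asr = Wave2Vec2ONNXInference(
#       "facebook/wav2vec2-large-960h-lv60-self",
#       "wav2vec2-large-960h-lv60-self.quant.onnx",
#   )
#   print(onnx_asr.file_to_text("test.wav"))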


if __name__ == "__main__":
    from loguru import logger
    import time

    asr = Wave2Vec2Inference("facebook/wav2vec2-large-960h-lv60-self")

    # Warm up runs
    print("Warming up...")
    for i in range(2):
        asr.file_to_text("test.wav")
        print(f"Warm up {i+1} completed")

    # Test runs
    print("Running tests...")
    times = []
    for i in range(10):
        start_time = time.time()
        text = asr.file_to_text("test.wav")
        end_time = time.time()
        execution_time = end_time - start_time
        times.append(execution_time)
        print(f"Test {i+1}: {execution_time:.3f}s - {text}")

    # Calculate average time
    average_time = sum(times) / len(times)
    print(f"\nAverage execution time: {average_time:.3f}s")
    print(f"Min time: {min(times):.3f}s")
    print(f"Max time: {max(times):.3f}s")