Marti Umbert committed on
Commit
8b05aa3
·
1 Parent(s): b3bfb38

whisperlivekit/whisper_streaming_custom/backends.py: use BatchedInferencePipeline in FasterWhisperASR class, and batch_size=16 in transcribe() function, also created WhisperXASR class

Browse files
whisperlivekit/whisper_streaming_custom/backends.py CHANGED
@@ -95,7 +95,7 @@ class FasterWhisperASR(ASRBase):
95
  sep = ""
96
 
97
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
98
- from faster_whisper import WhisperModel
99
 
100
  if model_dir is not None:
101
  logger.debug(f"Loading whisper model from model_dir {model_dir}. "
@@ -115,7 +115,9 @@ class FasterWhisperASR(ASRBase):
115
  compute_type=compute_type,
116
  download_root=cache_dir,
117
  )
118
- return model
 
 
119
 
120
  def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
121
  segments, info = self.model.transcribe(
@@ -125,6 +127,7 @@ class FasterWhisperASR(ASRBase):
125
  beam_size=5,
126
  word_timestamps=True,
127
  condition_on_previous_text=True,
 
128
  **self.transcribe_kargs,
129
  )
130
  return list(segments)
@@ -148,6 +151,60 @@ class FasterWhisperASR(ASRBase):
148
  def set_translate_task(self):
149
  self.transcribe_kargs["task"] = "translate"
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  class MLXWhisper(ASRBase):
153
  """
 
95
  sep = ""
96
 
97
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
98
+ from faster_whisper import WhisperModel, BatchedInferencePipeline
99
 
100
  if model_dir is not None:
101
  logger.debug(f"Loading whisper model from model_dir {model_dir}. "
 
115
  compute_type=compute_type,
116
  download_root=cache_dir,
117
  )
118
+ batched_model = BatchedInferencePipeline(model=model)
119
+ return batched_model
120
+ #return model
121
 
122
  def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
123
  segments, info = self.model.transcribe(
 
127
  beam_size=5,
128
  word_timestamps=True,
129
  condition_on_previous_text=True,
130
+ batch_size=16,
131
  **self.transcribe_kargs,
132
  )
133
  return list(segments)
 
151
  def set_translate_task(self):
152
  self.transcribe_kargs["task"] = "translate"
153
 
154
+ class WhisperXASR(ASRBase):
155
+ """Uses whisperX as the backend."""
156
+ sep = ""
157
+
158
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
159
+ import whisperx
160
+
161
+ if model_dir is not None:
162
+ logger.debug(f"Loading whisper model from model_dir {model_dir}. "
163
+ f"modelsize and cache_dir parameters are not used.")
164
+ model_size_or_path = model_dir
165
+ elif modelsize is not None:
166
+ model_size_or_path = modelsize
167
+ else:
168
+ raise ValueError("Either modelsize or model_dir must be set")
169
+ device = "cuda" # WhisperX requires an explicit device; hardcoded to CUDA (GPU) here
170
+ compute_type = "int8" # int8 quantization for lower memory use and faster inference
171
+
172
+ import torch
173
+ torch.backends.cuda.matmul.allow_tf32 = True
174
+ torch.backends.cudnn.allow_tf32 = True
175
+
176
+ model = whisperx.load_model(model_size_or_path, device, compute_type=compute_type)
177
+
178
+ return model
179
+
180
+ def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
181
+ segments, info = self.model.transcribe(
182
+ audio,
183
+ language=self.original_language,
184
+ batch_size=8,
185
+ **self.transcribe_kargs,
186
+ )
187
+ return list(segments)
188
+
189
+ def ts_words(self, segments) -> List[ASRToken]:
190
+ tokens = []
191
+ for segment in segments:
192
+ if segment.no_speech_prob > 0.9:
193
+ continue
194
+ for word in segment.words:
195
+ token = ASRToken(word.start, word.end, word.word, probability=word.probability)
196
+ tokens.append(token)
197
+ return tokens
198
+
199
+ def segments_end_ts(self, segments) -> List[float]:
200
+ return [segment.end for segment in segments]
201
+
202
+ def use_vad(self):
203
+ pass
204
+ # self.transcribe_kargs["vad_filter"] = True
205
+
206
+ def set_translate_task(self):
207
+ self.transcribe_kargs["task"] = "translate"
208
 
209
  class MLXWhisper(ASRBase):
210
  """