Marti Umbert
committed on
Commit
·
8b05aa3
1
Parent(s):
b3bfb38
whisperlivekit/whisper_streaming_custom/backends.py: use BatchedInferencePipeline in the FasterWhisperASR class and batch_size=16 in its transcribe() function; also create a WhisperXASR class
Browse files
whisperlivekit/whisper_streaming_custom/backends.py
CHANGED
|
@@ -95,7 +95,7 @@ class FasterWhisperASR(ASRBase):
|
|
| 95 |
sep = ""
|
| 96 |
|
| 97 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
| 98 |
-
from faster_whisper import WhisperModel
|
| 99 |
|
| 100 |
if model_dir is not None:
|
| 101 |
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
|
@@ -115,7 +115,9 @@ class FasterWhisperASR(ASRBase):
|
|
| 115 |
compute_type=compute_type,
|
| 116 |
download_root=cache_dir,
|
| 117 |
)
|
| 118 |
-
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
|
| 121 |
segments, info = self.model.transcribe(
|
|
@@ -125,6 +127,7 @@ class FasterWhisperASR(ASRBase):
|
|
| 125 |
beam_size=5,
|
| 126 |
word_timestamps=True,
|
| 127 |
condition_on_previous_text=True,
|
|
|
|
| 128 |
**self.transcribe_kargs,
|
| 129 |
)
|
| 130 |
return list(segments)
|
|
@@ -148,6 +151,60 @@ class FasterWhisperASR(ASRBase):
|
|
| 148 |
def set_translate_task(self):
|
| 149 |
self.transcribe_kargs["task"] = "translate"
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
class MLXWhisper(ASRBase):
|
| 153 |
"""
|
|
|
|
| 95 |
sep = ""
|
| 96 |
|
| 97 |
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
| 98 |
+
from faster_whisper import WhisperModel, BatchedInferencePipeline
|
| 99 |
|
| 100 |
if model_dir is not None:
|
| 101 |
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
|
|
|
| 115 |
compute_type=compute_type,
|
| 116 |
download_root=cache_dir,
|
| 117 |
)
|
| 118 |
+
batched_model = BatchedInferencePipeline(model=model)
|
| 119 |
+
return batched_model
|
| 120 |
+
#return model
|
| 121 |
|
| 122 |
def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
|
| 123 |
segments, info = self.model.transcribe(
|
|
|
|
| 127 |
beam_size=5,
|
| 128 |
word_timestamps=True,
|
| 129 |
condition_on_previous_text=True,
|
| 130 |
+
batch_size=16,
|
| 131 |
**self.transcribe_kargs,
|
| 132 |
)
|
| 133 |
return list(segments)
|
|
|
|
| 151 |
def set_translate_task(self):
|
| 152 |
self.transcribe_kargs["task"] = "translate"
|
| 153 |
|
| 154 |
+
class WhisperXASR(ASRBase):
    """Uses whisperX as the backend.

    NOTE(review): whisperX's pipeline ``transcribe()`` returns a dict
    (``{"segments": [...], "language": ...}``), not a ``(segments, info)``
    tuple, and its segments are plain dicts with no word-level timestamps
    unless ``whisperx.align`` is run afterwards. The original draft unpacked
    a tuple and read ``segment.words`` attributes, which raises at runtime;
    this version adapts to the dict result and falls back to one token per
    segment when word timing is unavailable — confirm against the installed
    whisperX version.
    """

    # Join transcribed words with no extra separator (words already carry
    # their leading spaces, as with the faster-whisper backend).
    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        """Load a whisperX model.

        Either ``model_dir`` (a local model directory) or ``modelsize``
        (a model name such as ``"large-v2"``) must be given. ``cache_dir``
        is accepted for interface parity but not used by whisperX.

        Raises:
            ValueError: if neither ``modelsize`` nor ``model_dir`` is set.
        """
        import whisperx

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. "
                         f"modelsize and cache_dir parameters are not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = modelsize
        else:
            raise ValueError("Either modelsize or model_dir must be set")

        # whisperX requires an explicit device string; hard-coded for now.
        # TODO(review): detect CUDA availability and fall back to "cpu".
        device = "cuda"
        # int8 keeps GPU memory low; use "float16" if accuracy matters more.
        compute_type = "int8"

        import torch
        # Allow TF32 matmuls on Ampere+ GPUs: faster with negligible
        # accuracy loss for inference.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        model = whisperx.load_model(model_size_or_path, device, compute_type=compute_type)

        return model

    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
        """Transcribe ``audio`` and return the list of whisperX segment dicts.

        ``init_prompt`` is accepted for interface parity with the other
        backends but is not forwarded (whisperX has no equivalent argument).
        """
        result = self.model.transcribe(
            audio,
            language=self.original_language,
            batch_size=8,
            **self.transcribe_kargs,
        )
        # whisperX returns {"segments": [...], "language": ...}.
        return list(result["segments"])

    def ts_words(self, segments) -> List[ASRToken]:
        """Convert whisperX segments into ASRToken objects.

        Word-level entries exist only after ``whisperx.align``; without
        them, emit one token per segment so the pipeline still progresses.
        """
        tokens = []
        for segment in segments:
            # Skip likely-silent segments when the model reports the
            # probability (not all whisperX versions include it).
            if segment.get("no_speech_prob", 0.0) > 0.9:
                continue
            words = segment.get("words")
            if words:
                for word in words:
                    # whisperX word entries use "score" for confidence.
                    tokens.append(ASRToken(word["start"], word["end"], word["word"],
                                           probability=word.get("score")))
            else:
                tokens.append(ASRToken(segment["start"], segment["end"], segment["text"]))
        return tokens

    def segments_end_ts(self, segments) -> List[float]:
        """Return the end timestamp of every segment."""
        return [segment["end"] for segment in segments]

    def use_vad(self):
        # whisperX runs its own VAD internally before batching; there is
        # no flag to toggle here.
        pass

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"
|
| 208 |
|
| 209 |
class MLXWhisper(ASRBase):
|
| 210 |
"""
|