# original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py

import itertools
import logging
import os
import zlib

from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union

import ctranslate2
import numpy as np
import tokenizers

from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import (
    SpeechTimestampsMap,
    VadOptions,
    collect_chunks,
    get_speech_timestamps,
)


class Word(NamedTuple):
    start: float
    end: float
    word: str
    probability: float


class Segment(NamedTuple):
    id: int
    seek: int
    start: float
    end: float
    text: str
    tokens: List[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
    words: Optional[List[Word]]


class TranscriptionOptions(NamedTuple):
    beam_size: int
    best_of: int
    patience: float
    length_penalty: float
    repetition_penalty: float
    no_repeat_ngram_size: int
    log_prob_threshold: Optional[float]
    no_speech_threshold: Optional[float]
    compression_ratio_threshold: Optional[float]
    condition_on_previous_text: bool
    prompt_reset_on_temperature: float
    temperatures: List[float]
    initial_prompt: Optional[Union[str, Iterable[int]]]
    prefix: Optional[str]
    suppress_blank: bool
    suppress_tokens: Optional[List[int]]
    without_timestamps: bool
    max_initial_timestamp: float
    word_timestamps: bool
    prepend_punctuations: str
    append_punctuations: str


class TranscriptionInfo(NamedTuple):
    language: str
    language_probability: float
    duration: float
    duration_after_vad: float
    all_language_probs: Optional[List[Tuple[str, float]]]
    transcription_options: TranscriptionOptions
    vad_options: VadOptions


class WhisperModel:
    def __init__(
        self,
        model_size_or_path: str,
        device: str = "auto",
        device_index: Union[int, List[int]] = 0,
        compute_type: str = "default",
        cpu_threads: int = 0,
        num_workers: int = 1,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
    ):
        """Initializes the Whisper model.

        Args:
          model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
            small, small.en, medium, medium.en, large-v1, large-v2, or large), a path to a converted
            model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
            When a size or a model ID is configured, the converted model is downloaded
            from the Hugging Face Hub.
          device: Device to use for computation ("cpu", "cuda", "auto").
          device_index: Device ID to use.
            The model can also be loaded on multiple GPUs by passing a list of IDs
            (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
            when transcribe() is called from multiple Python threads (see also num_workers).
          compute_type: Type to use for computation.
            See https://opennmt.net/CTranslate2/quantization.html.
          cpu_threads: Number of threads to use when running on CPU (4 by default).
            A non zero value overrides the OMP_NUM_THREADS environment variable.
          num_workers: When transcribe() is called from multiple Python threads,
            having multiple workers enables true parallelism when running the model
            (concurrent calls to self.model.generate() will run in parallel).
            This can improve the global throughput at the cost of increased memory usage.
          download_root: Directory where the models should be saved. If not set, the models
            are saved in the standard Hugging Face cache directory.
          local_files_only: If True, avoid downloading the file and return the path to the
            local cached file if it exists.
        """
        self.logger = get_logger()

        if os.path.isdir(model_size_or_path):
            model_path = model_size_or_path
        else:
            model_path = download_model(
                model_size_or_path,
                local_files_only=local_files_only,
                cache_dir=download_root,
            )

        self.model = ctranslate2.models.Whisper(
            model_path,
            device=device,
            device_index=device_index,
            compute_type=compute_type,
            intra_threads=cpu_threads,
            inter_threads=num_workers,
        )

        tokenizer_file = os.path.join(model_path, "tokenizer.json")
        if os.path.isfile(tokenizer_file):
            self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
        else:
            self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
                "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
            )

        self.feature_extractor = FeatureExtractor()
        self.num_samples_per_token = self.feature_extractor.hop_length * 2
        self.frames_per_second = (
            self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
        )
        self.tokens_per_second = (
            self.feature_extractor.sampling_rate // self.num_samples_per_token
        )
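        # Decoding constants from the reference Whisper model: the encoder halves the
        # feature length, so one output position covers 2 mel frames (input_stride) and
        # 0.02 seconds of audio (time_precision); max_length caps the number of tokens
        # decoded per 30-second window.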
        self.input_stride = 2
        self.time_precision = 0.02
        self.max_length = 448
    @property
    def supported_languages(self) -> List[str]:
        """The languages supported by the model."""
        return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
    def transcribe(
        self,
        audio: Union[str, BinaryIO, np.ndarray],
        language: Optional[str] = None,
        task: str = "transcribe",
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1,
        length_penalty: float = 1,
        repetition_penalty: float = 1,
        no_repeat_ngram_size: int = 0,
        temperature: Union[float, List[float], Tuple[float, ...]] = [
            0.0,
            0.2,
            0.4,
            0.6,
            0.8,
            1.0,
        ],
        compression_ratio_threshold: Optional[float] = 2.4,
        log_prob_threshold: Optional[float] = -1.0,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        prompt_reset_on_temperature: float = 0.5,
        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
        prefix: Optional[str] = None,
        suppress_blank: bool = True,
        suppress_tokens: Optional[List[int]] = [-1],
        without_timestamps: bool = False,
        max_initial_timestamp: float = 1.0,
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        vad_filter: bool = False,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
        """Transcribes an input file.

        Arguments:
          audio: Path to the input file (or a file-like object), or the audio waveform.
          language: The language spoken in the audio. It should be a language code such
            as "en" or "fr". If not set, the language will be detected in the first 30 seconds
            of audio.
          task: Task to execute (transcribe or translate).
          beam_size: Beam size to use for decoding.
          best_of: Number of candidates when sampling with non-zero temperature.
          patience: Beam search patience factor.
          length_penalty: Exponential length penalty constant.
          repetition_penalty: Penalty applied to the score of previously generated tokens
            (set > 1 to penalize).
          no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
          temperature: Temperature for sampling. It can be a tuple of temperatures,
            which will be successively used upon failures according to either
            `compression_ratio_threshold` or `log_prob_threshold`.
          compression_ratio_threshold: If the gzip compression ratio is above this value,
            treat as failed.
          log_prob_threshold: If the average log probability over sampled tokens is
            below this value, treat as failed.
          no_speech_threshold: If the no_speech probability is higher than this value AND
            the average log probability over sampled tokens is below `log_prob_threshold`,
            consider the segment as silent.
          condition_on_previous_text: If True, the previous output of the model is provided
            as a prompt for the next window; disabling may make the text inconsistent across
            windows, but the model becomes less prone to getting stuck in a failure loop,
            such as repetition looping or timestamps going out of sync.
          prompt_reset_on_temperature: Resets prompt if temperature is above this value.
            Arg has effect only if condition_on_previous_text is True.
          initial_prompt: Optional text string or iterable of token ids to provide as a
            prompt for the first window.
          prefix: Optional text to provide as a prefix for the first window.
          suppress_blank: Suppress blank outputs at the beginning of the sampling.
          suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
            of symbols as defined in the model config.json file.
          without_timestamps: Only sample text tokens.
          max_initial_timestamp: The initial timestamp cannot be later than this.
          word_timestamps: Extract word-level timestamps using the cross-attention pattern
            and dynamic time warping, and include the timestamps for each word in each segment.
          prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
            with the next word
          append_punctuations: If word_timestamps is True, merge these punctuation symbols
            with the previous word
          vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
            without speech. This step is using the Silero VAD model
            https://github.com/snakers4/silero-vad.
          vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
            parameters and default values in the class `VadOptions`).
        Returns:
          A tuple with:

            - a list of transcribed segments
            - an instance of TranscriptionInfo
        """
        sampling_rate = self.feature_extractor.sampling_rate

        if not isinstance(audio, np.ndarray):
            audio = decode_audio(audio, sampling_rate=sampling_rate)

        duration = audio.shape[0] / sampling_rate
        duration_after_vad = duration

        self.logger.info(
            "Processing audio with duration %s", format_timestamp(duration)
        )

        if vad_filter:
            if vad_parameters is None:
                vad_parameters = VadOptions()
            elif isinstance(vad_parameters, dict):
                vad_parameters = VadOptions(**vad_parameters)
            speech_chunks = get_speech_timestamps(audio, vad_parameters)
            audio = collect_chunks(audio, speech_chunks)
            duration_after_vad = audio.shape[0] / sampling_rate

            self.logger.info(
                "VAD filter removed %s of audio",
                format_timestamp(duration - duration_after_vad),
            )

            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "VAD filter kept the following audio segments: %s",
                    ", ".join(
                        "[%s -> %s]"
                        % (
                            format_timestamp(chunk["start"] / sampling_rate),
                            format_timestamp(chunk["end"] / sampling_rate),
                        )
                        for chunk in speech_chunks
                    ),
                )
        else:
            speech_chunks = None

        features = self.feature_extractor(audio)

        encoder_output = None
        all_language_probs = None

        if language is None:
            if not self.model.is_multilingual:
                language = "en"
                language_probability = 1
            else:
                segment = features[:, : self.feature_extractor.nb_max_frames]
                encoder_output = self.encode(segment)
                # results is a list of tuple[str, float] with language names and
                # probabilities.
                results = self.model.detect_language(encoder_output)[0]
                # Parse language names to strip out markers
                all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
                # Get top language token and probability
                language, language_probability = all_language_probs[0]

                self.logger.info(
                    "Detected language '%s' with probability %.2f",
                    language,
                    language_probability,
                )
        else:
            if not self.model.is_multilingual and language != "en":
                self.logger.warning(
                    "The current model is English-only but the language parameter is set to '%s'; "
                    "using 'en' instead." % language
                )
                language = "en"

            language_probability = 1

        tokenizer = Tokenizer(
            self.hf_tokenizer,
            self.model.is_multilingual,
            task=task,
            language=language,
        )

        options = TranscriptionOptions(
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            log_prob_threshold=log_prob_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
            condition_on_previous_text=condition_on_previous_text,
            prompt_reset_on_temperature=prompt_reset_on_temperature,
            temperatures=(
                temperature if isinstance(temperature, (list, tuple)) else [temperature]
            ),
            initial_prompt=initial_prompt,
            prefix=prefix,
            suppress_blank=suppress_blank,
            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
            without_timestamps=without_timestamps,
            max_initial_timestamp=max_initial_timestamp,
            word_timestamps=word_timestamps,
            prepend_punctuations=prepend_punctuations,
            append_punctuations=append_punctuations,
        )

        segments = self.generate_segments(features, tokenizer, options, encoder_output)

        if speech_chunks:
            segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)

        info = TranscriptionInfo(
            language=language,
            language_probability=language_probability,
            duration=duration,
            duration_after_vad=duration_after_vad,
            transcription_options=options,
            vad_options=vad_parameters,
            all_language_probs=all_language_probs,
        )

        return segments, info
    def generate_segments(
        self,
        features: np.ndarray,
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
        encoder_output: Optional[ctranslate2.StorageView] = None,
    ) -> Iterable[Segment]:
        content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
        idx = 0
        seek = 0
        all_tokens = []
        prompt_reset_since = 0

        if options.initial_prompt is not None:
            if isinstance(options.initial_prompt, str):
                initial_prompt = " " + options.initial_prompt.strip()
                initial_prompt_tokens = tokenizer.encode(initial_prompt)
                all_tokens.extend(initial_prompt_tokens)
            else:
                all_tokens.extend(options.initial_prompt)

        last_speech_timestamp = 0.0
        all_segments = []
        while seek < content_frames:
            time_offset = seek * self.feature_extractor.time_per_frame
            segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
            segment_size = min(
                self.feature_extractor.nb_max_frames, content_frames - seek
            )
            segment_duration = segment_size * self.feature_extractor.time_per_frame

            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "Processing segment at %s", format_timestamp(time_offset)
                )

            previous_tokens = all_tokens[prompt_reset_since:]
            prompt = self.get_prompt(
                tokenizer,
                previous_tokens,
                without_timestamps=options.without_timestamps,
                prefix=options.prefix if seek == 0 else None,
            )

            if seek > 0 or encoder_output is None:
                encoder_output = self.encode(segment)

            (
                result,
                avg_logprob,
                temperature,
                compression_ratio,
            ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)

            if options.no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > options.no_speech_threshold

                if (
                    options.log_prob_threshold is not None
                    and avg_logprob > options.log_prob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    self.logger.debug(
                        "No speech threshold is met (%f > %f)",
                        result.no_speech_prob,
                        options.no_speech_threshold,
                    )

                    # fast-forward to the next segment boundary
                    seek += segment_size
                    continue

            tokens = result.sequences_ids[0]

            previous_seek = seek
            current_segments = []
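            # A lone timestamp token at the very end of the window means the model
            # stopped exactly on a segment boundary; pairs of consecutive timestamp
            # tokens inside the window mark boundaries between segments.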
            single_timestamp_ending = (
                len(tokens) >= 2
                and tokens[-2] < tokenizer.timestamp_begin
                and tokens[-1] >= tokenizer.timestamp_begin
            )

            consecutive_timestamps = [
                i
                for i in range(len(tokens))
                if i > 0
                and tokens[i] >= tokenizer.timestamp_begin
                and tokens[i - 1] >= tokenizer.timestamp_begin
            ]

            if len(consecutive_timestamps) > 0:
                slices = list(consecutive_timestamps)
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0] - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1] - tokenizer.timestamp_begin
                    )
                    start_time = (
                        time_offset + start_timestamp_position * self.time_precision
                    )
                    end_time = (
                        time_offset + end_timestamp_position * self.time_precision
                    )

                    current_segments.append(
                        dict(
                            seek=seek,
                            start=start_time,
                            end=end_time,
                            tokens=sliced_tokens,
                        )
                    )
                    last_slice = current_slice

                if single_timestamp_ending:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_position = (
                        tokens[last_slice - 1] - tokenizer.timestamp_begin
                    )
                    seek += last_timestamp_position * self.input_stride
            else:
                duration = segment_duration
                timestamps = [
                    token for token in tokens if token >= tokenizer.timestamp_begin
                ]
                if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
                    last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
                    duration = last_timestamp_position * self.time_precision

                current_segments.append(
                    dict(
                        seek=seek,
                        start=time_offset,
                        end=time_offset + duration,
                        tokens=tokens,
                    )
                )

                seek += segment_size

            if options.word_timestamps:
                self.add_word_timestamps(
                    current_segments,
                    tokenizer,
                    encoder_output,
                    segment_size,
                    options.prepend_punctuations,
                    options.append_punctuations,
                    last_speech_timestamp=last_speech_timestamp,
                )

                word_end_timestamps = [
                    w["end"] for s in current_segments for w in s["words"]
                ]
                if len(word_end_timestamps) > 0:
                    last_speech_timestamp = word_end_timestamps[-1]
                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                    seek_shift = round(
                        (word_end_timestamps[-1] - time_offset) * self.frames_per_second
                    )

                    if seek_shift > 0:
                        seek = previous_seek + seek_shift

            for segment in current_segments:
                tokens = segment["tokens"]
                text = tokenizer.decode(tokens)

                if segment["start"] == segment["end"] or not text.strip():
                    continue

                all_tokens.extend(tokens)
                idx += 1

                all_segments.append(Segment(
                    id=idx,
                    seek=seek,
                    start=segment["start"],
                    end=segment["end"],
                    text=text,
                    tokens=tokens,
                    temperature=temperature,
                    avg_logprob=avg_logprob,
                    compression_ratio=compression_ratio,
                    no_speech_prob=result.no_speech_prob,
                    words=(
                        [Word(**word) for word in segment["words"]]
                        if options.word_timestamps
                        else None
                    ),
                ))

            if (
                not options.condition_on_previous_text
                or temperature > options.prompt_reset_on_temperature
            ):
                if options.condition_on_previous_text:
                    self.logger.debug(
                        "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
                        temperature,
                        options.prompt_reset_on_temperature,
                    )

                prompt_reset_since = len(all_tokens)

        return all_segments
    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
        # When the model is running on multiple GPUs, the encoder output should be moved
        # to the CPU since we don't know which GPU will handle the next job.
        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1

        features = np.expand_dims(features, 0)
        features = get_ctranslate2_storage(features)

        return self.model.encode(features, to_cpu=to_cpu)
    def generate_with_fallback(
        self,
        encoder_output: ctranslate2.StorageView,
        prompt: List[int],
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
        decode_result = None
        all_results = []
        below_cr_threshold_results = []

        max_initial_timestamp_index = int(
            round(options.max_initial_timestamp / self.time_precision)
        )

        for temperature in options.temperatures:
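            # Temperature 0 decodes with beam search; higher temperatures switch to
            # sampling with `best_of` hypotheses. Each temperature is tried in turn
            # until a result passes the quality thresholds checked below.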
            if temperature > 0:
                kwargs = {
                    "beam_size": 1,
                    "num_hypotheses": options.best_of,
                    "sampling_topk": 0,
                    "sampling_temperature": temperature,
                }
            else:
                kwargs = {
                    "beam_size": options.beam_size,
                    "patience": options.patience,
                }

            result = self.model.generate(
                encoder_output,
                [prompt],
                length_penalty=options.length_penalty,
                repetition_penalty=options.repetition_penalty,
                no_repeat_ngram_size=options.no_repeat_ngram_size,
                max_length=self.max_length,
                return_scores=True,
                return_no_speech_prob=True,
                suppress_blank=options.suppress_blank,
                suppress_tokens=options.suppress_tokens,
                max_initial_timestamp_index=max_initial_timestamp_index,
                **kwargs,
            )[0]

            tokens = result.sequences_ids[0]

            # Recover the average log prob from the returned score.
            seq_len = len(tokens)
            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
            avg_logprob = cum_logprob / (seq_len + 1)

            text = tokenizer.decode(tokens).strip()
            compression_ratio = get_compression_ratio(text)

            decode_result = (
                result,
                avg_logprob,
                temperature,
                compression_ratio,
            )
            all_results.append(decode_result)

            needs_fallback = False

            if options.compression_ratio_threshold is not None:
                if compression_ratio > options.compression_ratio_threshold:
                    needs_fallback = True  # too repetitive

                    self.logger.debug(
                        "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
                        temperature,
                        compression_ratio,
                        options.compression_ratio_threshold,
                    )
                else:
                    below_cr_threshold_results.append(decode_result)

            if (
                options.log_prob_threshold is not None
                and avg_logprob < options.log_prob_threshold
            ):
                needs_fallback = True  # average log probability is too low

                self.logger.debug(
                    "Log probability threshold is not met with temperature %.1f (%f < %f)",
                    temperature,
                    avg_logprob,
                    options.log_prob_threshold,
                )

            if (
                options.no_speech_threshold is not None
                and result.no_speech_prob > options.no_speech_threshold
            ):
                needs_fallback = False  # silence

            if not needs_fallback:
                break
        else:
            # all failed, select the result with the highest average log probability
            decode_result = max(
                below_cr_threshold_results or all_results, key=lambda x: x[1]
            )

        return decode_result
    def get_prompt(
        self,
        tokenizer: Tokenizer,
        previous_tokens: List[int],
        without_timestamps: bool = False,
        prefix: Optional[str] = None,
    ) -> List[int]:
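        # The decoder prompt is assembled as: an optional <|startofprev|> token plus the
        # tail of the previous window's tokens, then the start-of-transcript sequence
        # (language and task tokens), an optional <|notimestamps|> token, and finally an
        # optional user-provided prefix for the first window.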
        prompt = []

        if previous_tokens:
            prompt.append(tokenizer.sot_prev)
            prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

        prompt.extend(tokenizer.sot_sequence)

        if without_timestamps:
            prompt.append(tokenizer.no_timestamps)

        if prefix:
            prefix_tokens = tokenizer.encode(" " + prefix.strip())
            if len(prefix_tokens) >= self.max_length // 2:
                prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
            if not without_timestamps:
                prompt.append(tokenizer.timestamp_begin)
            prompt.extend(prefix_tokens)

        return prompt
    def add_word_timestamps(
        self,
        segments: List[dict],
        tokenizer: Tokenizer,
        encoder_output: ctranslate2.StorageView,
        num_frames: int,
        prepend_punctuations: str,
        append_punctuations: str,
        last_speech_timestamp: float,
    ) -> None:
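        # find_alignment maps the decoded text tokens to encoder frames via the
        # cross-attention alignment; this method turns that mapping into per-word
        # dicts (word, start, end, probability) and attaches them to each segment
        # in place.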
        if len(segments) == 0:
            return

        text_tokens_per_segment = [
            [token for token in segment["tokens"] if token < tokenizer.eot]
            for segment in segments
        ]

        text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
        alignment = self.find_alignment(
            tokenizer, text_tokens, encoder_output, num_frames
        )
        word_durations = np.array([word["end"] - word["start"] for word in alignment])
        word_durations = word_durations[word_durations.nonzero()]
        median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
        max_duration = median_duration * 2

        # hack: truncate long words at sentence boundaries.
        # a better segmentation algorithm based on VAD should be able to replace this.
        if len(word_durations) > 0:
            sentence_end_marks = ".。!!??"
            # ensure words at sentence boundaries
            # are not longer than twice the median word duration.
            for i in range(1, len(alignment)):
                if alignment[i]["end"] - alignment[i]["start"] > max_duration:
                    if alignment[i]["word"] in sentence_end_marks:
                        alignment[i]["end"] = alignment[i]["start"] + max_duration
                    elif alignment[i - 1]["word"] in sentence_end_marks:
                        alignment[i]["start"] = alignment[i]["end"] - max_duration

        merge_punctuations(alignment, prepend_punctuations, append_punctuations)

        time_offset = (
            segments[0]["seek"]
            * self.feature_extractor.hop_length
            / self.feature_extractor.sampling_rate
        )

        word_index = 0

        for segment, text_tokens in zip(segments, text_tokens_per_segment):
            saved_tokens = 0
            words = []

            while word_index < len(alignment) and saved_tokens < len(text_tokens):
                timing = alignment[word_index]

                if timing["word"]:
                    words.append(
                        dict(
                            word=timing["word"],
                            start=round(time_offset + timing["start"], 2),
                            end=round(time_offset + timing["end"], 2),
                            probability=timing["probability"],
                        )
                    )

                saved_tokens += len(timing["tokens"])
                word_index += 1

            # hack: truncate long words at segment boundaries.
            # a better segmentation algorithm based on VAD should be able to replace this.
            if len(words) > 0:
                # ensure the first and second word after a pause is not longer than
                # twice the median word duration.
                if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
                    words[0]["end"] - words[0]["start"] > max_duration
                    or (
                        len(words) > 1
                        and words[1]["end"] - words[0]["start"] > max_duration * 2
                    )
                ):
                    if (
                        len(words) > 1
                        and words[1]["end"] - words[1]["start"] > max_duration
                    ):
                        boundary = max(
                            words[1]["end"] / 2, words[1]["end"] - max_duration
                        )
                        words[0]["end"] = words[1]["start"] = boundary
                    words[0]["start"] = max(0, words[0]["end"] - max_duration)

                # prefer the segment-level start timestamp if the first word is too long.
                if (
                    segment["start"] < words[0]["end"]
                    and segment["start"] - 0.5 > words[0]["start"]
                ):
                    words[0]["start"] = max(
                        0, min(words[0]["end"] - median_duration, segment["start"])
                    )
                else:
                    segment["start"] = words[0]["start"]

                # prefer the segment-level end timestamp if the last word is too long.
                if (
                    segment["end"] > words[-1]["start"]
                    and segment["end"] + 0.5 < words[-1]["end"]
                ):
                    words[-1]["end"] = max(
                        words[-1]["start"] + median_duration, segment["end"]
                    )
                else:
                    segment["end"] = words[-1]["end"]

                last_speech_timestamp = segment["end"]

            segment["words"] = words
    def find_alignment(
        self,
        tokenizer: Tokenizer,
        text_tokens: List[int],
        encoder_output: ctranslate2.StorageView,
        num_frames: int,
        median_filter_width: int = 7,
    ) -> List[dict]:
        if len(text_tokens) == 0:
            return []

        result = self.model.align(
            encoder_output,
            tokenizer.sot_sequence,
            [text_tokens],
            num_frames,
            median_filter_width=median_filter_width,
        )[0]

        text_token_probs = result.text_token_probs
        alignments = result.alignments

        text_indices = np.array([pair[0] for pair in alignments])
        time_indices = np.array([pair[1] for pair in alignments])

        words, word_tokens = tokenizer.split_to_word_tokens(
            text_tokens + [tokenizer.eot]
        )
        word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
        if len(word_boundaries) <= 1:
            return []
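        # The alignment is a monotonic path of (text_index, time_index) pairs; the
        # frames where the text index jumps forward give each token's start time, and
        # the word boundaries then pick out word-level start and end times.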
        jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
        jump_times = time_indices[jumps] / self.tokens_per_second
        start_times = jump_times[word_boundaries[:-1]]
        end_times = jump_times[word_boundaries[1:]]
        word_probabilities = [
            np.mean(text_token_probs[i:j])
            for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
        ]

        return [
            dict(
                word=word, tokens=tokens, start=start, end=end, probability=probability
            )
            for word, tokens, start, end, probability in zip(
                words, word_tokens, start_times, end_times, word_probabilities
            )
        ]

    def destroy(self):
        del self.model

def restore_speech_timestamps(
    segments: Iterable[Segment],
    speech_chunks: List[dict],
    sampling_rate: int,
) -> Iterable[Segment]:
    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

    restored_segments = []
    for segment in segments:
        if segment.words:
            words = []
            for word in segment.words:
                # Ensure the word start and end times are resolved to the same chunk.
                middle = (word.start + word.end) / 2
                chunk_index = ts_map.get_chunk_index(middle)
                word = word._replace(
                    start=ts_map.get_original_time(word.start, chunk_index),
                    end=ts_map.get_original_time(word.end, chunk_index),
                )
                words.append(word)

            segment = segment._replace(
                start=words[0].start,
                end=words[-1].end,
                words=words,
            )
        else:
            segment = segment._replace(
                start=ts_map.get_original_time(segment.start),
                end=ts_map.get_original_time(segment.end),
            )

        # Collect the remapped segments; the original loop discarded them and
        # returned the input unchanged.
        restored_segments.append(segment)

    return restored_segments

def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
    segment = np.ascontiguousarray(segment)
    segment = ctranslate2.StorageView.from_array(segment)
    return segment


def get_compression_ratio(text: str) -> float:
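    # Repetitive text compresses very well with zlib, so a high ratio of raw size to
    # compressed size is used as a signal of degenerate, looping output.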
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))

def get_suppressed_tokens(
    tokenizer: Tokenizer,
    suppress_tokens: Optional[List[int]],
) -> Optional[List[int]]:
    if not suppress_tokens or -1 in suppress_tokens:
        return suppress_tokens

    suppress_tokens = list(suppress_tokens)

    # Ensure the following special tokens are suppressed when the user does
    # not use the default set (-1).
    suppress_tokens.extend(
        [
            tokenizer.transcribe,
            tokenizer.translate,
            tokenizer.sot,
            tokenizer.sot_prev,
            tokenizer.sot_lm,
        ]
    )

    return sorted(set(suppress_tokens))

def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
            # prepend it to the following word
            following["word"] = previous["word"] + following["word"]
            following["tokens"] = previous["tokens"] + following["tokens"]
            previous["word"] = ""
            previous["tokens"] = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous["word"].endswith(" ") and following["word"] in appended:
            # append it to the previous word
            previous["word"] = previous["word"] + following["word"]
            previous["tokens"] = previous["tokens"] + following["tokens"]
            following["word"] = ""
            following["tokens"] = []
        else:
            i = j
        j += 1
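
# Minimal usage sketch (not part of the original module). The model size, device,
# compute type, and audio path below are illustrative placeholders; any converted
# CTranslate2 Whisper model and any audio input supported by decode_audio works.
if __name__ == "__main__":
    model = WhisperModel("tiny", device="cpu", compute_type="int8")

    segments, info = model.transcribe("audio.wav", beam_size=5, vad_filter=True)
    print(
        "Detected language '%s' with probability %.2f"
        % (info.language, info.language_probability)
    )
    for segment in segments:
        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))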