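"""Dual-mode speech transcription utilities.

Transcribes audio either with a faster-whisper model (v2 fast path: each stereo
channel is transcribed separately and merged with per-speaker labels) or with a
Hugging Face `transformers` ASR pipeline (v1 path). Paths, model identifiers and
runtime options are read from `settings.py`.
"""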
from faster_whisper import WhisperModel
from transformers import pipeline
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
import time
import sys
from pathlib import Path
import glob
import ctypes
import numpy as np
from settings import (
    DEBUG_MODE,
    MODEL_PATH_V2_FAST,
    MODEL_PATH_V1,
    LEFT_CHANNEL_TEMP_PATH,
    RIGHT_CHANNEL_TEMP_PATH,
    RESAMPLING_FREQ,
    BATCH_SIZE,
    TASK,
)
def load_cudnn():
    if not torch.cuda.is_available():
        if DEBUG_MODE: print("[INFO] CUDA is not available, skipping cuDNN setup.")
        return
    if DEBUG_MODE: print(f"[INFO] sys.platform: {sys.platform}")
    if sys.platform == "win32":
        torch_lib_dir = Path(torch.__file__).parent / "lib"
        if torch_lib_dir.exists():
            os.add_dll_directory(str(torch_lib_dir))
            if DEBUG_MODE: print(f"[INFO] Added DLL directory: {torch_lib_dir}")
        else:
            if DEBUG_MODE: print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
    elif sys.platform == "linux":
        site_packages = Path(torch.__file__).resolve().parents[1]
        cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib"
        if not cudnn_dir.exists():
            if DEBUG_MODE: print(f"[ERROR] cudnn dir not found: {cudnn_dir}")
            return
        pattern = str(cudnn_dir / "libcudnn_cnn*.so*")
        matching_files = sorted(glob.glob(pattern))
        if not matching_files:
            if DEBUG_MODE: print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}")
            return
        for so_path in matching_files:
            try:
                ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
                if DEBUG_MODE: print(f"[INFO] Loaded: {so_path}")
            except OSError as e:
                if DEBUG_MODE: print(f"[WARNING] Failed to load {so_path}: {e}")
    else:
        if DEBUG_MODE: print("[WARNING] sys.platform is not win32 or linux")
def get_settings():
    is_cuda_available = torch.cuda.is_available()
    if is_cuda_available:
        device = "cuda"
        compute_type = "default"
    else:
        device = "cpu"
        compute_type = "default"
    if DEBUG_MODE: print(f"[SETTINGS] Device: {device}")
    return device, compute_type
def load_model(use_v2_fast, device, compute_type):
    if DEBUG_MODE:
        print(f"[MODEL LOADING] use_v2_fast: {use_v2_fast}")
    if use_v2_fast:
        model = WhisperModel(
            MODEL_PATH_V2_FAST,
            device=device,
            compute_type=compute_type,
        )
    else:
        model = pipeline(
            task="automatic-speech-recognition",
            model=MODEL_PATH_V1,
            chunk_length_s=30,
            device=device,
            token=os.getenv("HF_TOKEN"),
        )
    return model
def split_input_stereo_channels(audio_path):
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"[FORMAT AUDIO] Unsupported file format for: {audio_path}")
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"[FORMAT AUDIO] Audio {audio_path} has {len(channels)} channels (instead of 2).")
    # pydub's split_to_mono() returns [left, right] for stereo input.
    channels[0].export(LEFT_CHANNEL_TEMP_PATH, format="wav")   # Left
    channels[1].export(RIGHT_CHANNEL_TEMP_PATH, format="wav")  # Right
def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
    compute_type = compute_type.lower()
    if device.startswith("cuda"):
        if "float16" in compute_type or "int8" in compute_type:
            audio_np_dtype = np.float16
        else:
            audio_np_dtype = np.float32
    else:
        audio_np_dtype = np.float32
    return audio_np_dtype
def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:  # Downmix stereo to mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze()
    np_dtype = compute_type_to_audio_dtype(compute_type, device)
    input_audio = input_audio.numpy().astype(np_dtype)
    if DEBUG_MODE:
        print(f"[FORMAT AUDIO] Audio dtype after casting for compute_type '{compute_type}': {input_audio.dtype}")
    return input_audio
def process_waveforms(device: str, compute_type: str):
    left_waveform = format_audio(LEFT_CHANNEL_TEMP_PATH, compute_type, device)
    right_waveform = format_audio(RIGHT_CHANNEL_TEMP_PATH, compute_type, device)
    return left_waveform, right_waveform
def transcribe_pipeline(audio, model):
    text = model(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": TASK}, return_timestamps=True)["text"]
    return text
def transcribe_channels(left_waveform, right_waveform, model):
    left_result, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
    right_result, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
    left_result = list(left_result)
    right_result = list(right_result)
    return left_result, right_result
# TODO refactor and rename this function
def post_process_transcription(transcription, max_repeats=2):
    tokens = re.findall(r"\b\w+\'?\w*\b[.,!?]?", transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Collapse runs of a short repeated substring (e.g. "okokok") to one occurrence.
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r"\s+", " ", cleaned_transcription).strip()
    return cleaned_transcription
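# Illustrative example of the cleanup above (assumed behaviour, default max_repeats=2):
#   post_process_transcription("yes yes yes yes okay")  ->  "yes yes okay"
# i.e. at most two consecutive repeats of the same token are kept.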
# TODO not used right now, decide to use it or not
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    merged_transcription = ''
    current_speaker = None
    current_segment = []
    for i in range(1, len(segments) - 1, 2):
        speaker_tag = segments[i]
        text = segments[i + 1].strip()
        speaker = re.search(r'\d{2}', speaker_tag).group()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            if current_speaker is not None:
                merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
            current_speaker = speaker
            current_segment = [text]
    if current_speaker is not None:
        merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
    return merged_transcription.strip()
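# Illustrative example (assumed input format with [SPEAKER_NN] tags):
#   "[SPEAKER_00] hi [SPEAKER_00] there [SPEAKER_01] hello"
# merges consecutive same-speaker segments into:
#   "[SPEAKER_00] hi there\n[SPEAKER_01] hello"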
def get_segments(result, speaker_label):
    final_segments = [
        (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
        for seg in result if seg.text
    ]
    return final_segments
def post_process_transcripts(left_result, right_result):
    left_segs = get_segments(left_result, "Speaker 1")
    right_segs = get_segments(right_result, "Speaker 2")
    merged_transcript = sorted(
        left_segs + right_segs,
        key=lambda x: float(x[0]) if x[0] is not None else float("inf")
    )
    clean_output = ""
    for start, end, speaker, text in merged_transcript:
        clean_output += f"[{speaker}]: {text}\n"
    return clean_output.strip()
def cleanup_temp_files(*file_paths):
    for path in file_paths:
        if path and os.path.exists(path):
            if DEBUG_MODE: print(f"Removing path: {path}")
            os.remove(path)
def generate(audio_path, use_v2_fast):
    load_cudnn()
    device, requested_compute_type = get_settings()
    model = load_model(use_v2_fast, device, requested_compute_type)
    if use_v2_fast:
        actual_compute_type = model.model.compute_type
    else:
        actual_compute_type = "float32"  # HF pipeline safe default
    if DEBUG_MODE:
        print(f"[SETTINGS] Requested compute_type: {requested_compute_type}")
        print(f"[SETTINGS] Actual compute_type: {actual_compute_type}")
    if use_v2_fast:
        split_input_stereo_channels(audio_path)
        left_waveform, right_waveform = process_waveforms(device, actual_compute_type)
        left_result, right_result = transcribe_channels(left_waveform, right_waveform, model)
        output = post_process_transcripts(left_result, right_result)
        cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    else:
        audio = format_audio(audio_path, actual_compute_type, device)
        merged_results = transcribe_pipeline(audio, model)
        output = post_process_transcription(merged_results)
    return output
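# Illustrative usage sketch (not part of the original module); the file name below
# is hypothetical, any stereo WAV or MP3 recording works. With use_v2_fast=True the
# two channels are transcribed separately and merged with per-speaker labels; with
# use_v2_fast=False the HF pipeline transcribes a mono downmix instead.
if __name__ == "__main__":
    transcript = generate("example_call.wav", use_v2_fast=True)
    print(transcript)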