# asr-inference / whisper_cs_fase_1.py
# Sarah Solito
# Fase_1 and Fase_2 releases, code cleaned
# d6fb6a2
# raw
# history blame
# 3.13 kB
from faster_whisper import WhisperModel
from transformers import pipeline
import os
import time
from settings import MODEL_PATH_V2_FAST, MODEL_PATH_V1, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, BATCH_SIZE, TASK
from audio_utils import debug_print, get_settings, split_input_stereo_channels, format_audio, process_waveforms, post_process_transcripts, post_process_transcription, post_merge_consecutive_segments_from_text, cleanup_temp_files
# Token read from the HF_TOKEN environment variable (None when it is unset);
# it is passed as ``token`` when the v1 transformers pipeline is built.
hf_token = os.environ.get("HF_TOKEN")

# Module-level caches for the two ASR back ends. They start empty and are
# filled on first use by get_asr_model_v1 / get_asr_model_v2.
ASR_MODEL_V1 = None
ASR_MODEL_V2 = None
def get_asr_model_v2(DEVICE, COMPUTE_TYPE):
    """Return the shared v2_fast ``WhisperModel``, building it on first use.

    The instance is kept in the module-level ``ASR_MODEL_V2``. Once that
    global is set, every later call returns it unchanged, so ``DEVICE`` and
    ``COMPUTE_TYPE`` only take effect on the call that actually loads the
    model.

    Args:
        DEVICE: Forwarded to ``WhisperModel`` as ``device``.
        COMPUTE_TYPE: Forwarded to ``WhisperModel`` as ``compute_type``.

    Returns:
        The cached ``WhisperModel`` loaded from ``MODEL_PATH_V2_FAST``.
    """
    global ASR_MODEL_V2

    # Fast path: an earlier call already loaded the model.
    if ASR_MODEL_V2 is not None:
        return ASR_MODEL_V2

    debug_print("[MODEL LOADING] Loading ASR v2_fast model...")
    ASR_MODEL_V2 = WhisperModel(MODEL_PATH_V2_FAST, device=DEVICE, compute_type=COMPUTE_TYPE)
    debug_print("[MODEL LOADING]v2_fast model loaded")
    return ASR_MODEL_V2
def get_asr_model_v1(DEVICE):
    """Return the shared v1 transformers ASR pipeline, building it on first use.

    The pipeline is kept in the module-level ``ASR_MODEL_V1``. Once that
    global is set, later calls return it unchanged and ``DEVICE`` is not
    consulted again.

    Args:
        DEVICE: Device name; ``"cuda"`` maps to pipeline device index ``0``,
            any other value maps to ``-1``.

    Returns:
        The cached ``automatic-speech-recognition`` pipeline for
        ``MODEL_PATH_V1``.
    """
    global ASR_MODEL_V1

    # Fast path: an earlier call already built the pipeline.
    if ASR_MODEL_V1 is not None:
        return ASR_MODEL_V1

    debug_print("[MODEL LOADING]Loading ASR v1 pipeline model...")
    # The pipeline is given a numeric device index rather than a device name.
    device_index = 0 if DEVICE == "cuda" else -1
    ASR_MODEL_V1 = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_PATH_V1,
        chunk_length_s=30,
        device=device_index,
        token=hf_token,
    )
    debug_print("[MODEL LOADING]ASR v1 model loaded")
    return ASR_MODEL_V1
def transcribe_asr(audio, model):
    """Run ``model`` on ``audio`` and return the transcribed text.

    ``model`` is called with ``batch_size=BATCH_SIZE``, the configured
    ``TASK`` inside ``generate_kwargs`` and ``return_timestamps=True``.
    Only the ``"text"`` entry of the result is returned; anything else in
    the result is discarded.

    Args:
        audio: Input handed to ``model`` as its first positional argument.
        model: Callable ASR model (the v1 pipeline in this module).

    Returns:
        The value stored under ``"text"`` in the model's output.
    """
    output = model(
        audio,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": TASK},
        return_timestamps=True,
    )
    return output["text"]
def transcribe_faster_asr(left_waveform, right_waveform, model):
    """Transcribe the left and right channel waveforms with ``model``.

    ``model.transcribe`` is called once per channel (left first) with
    ``beam_size=5`` and ``task="transcribe"``. Each call must return a
    two-item result; the first item (the segments iterable) is kept and the
    second is ignored. Both calls are issued before either segments iterable
    is consumed. With faster-whisper the segments are presumably a lazy
    generator, so decoding would happen during the final ``list``
    conversions (confirm against the faster-whisper docs).

    Args:
        left_waveform: Audio for the left channel.
        right_waveform: Audio for the right channel.
        model: Object exposing ``transcribe(audio, beam_size=..., task=...)``.

    Returns:
        A ``(left_segments, right_segments)`` tuple of lists.
    """
    segment_streams = []
    for waveform in (left_waveform, right_waveform):
        segments, _info = model.transcribe(waveform, beam_size=5, task="transcribe")
        segment_streams.append(segments)

    left_stream, right_stream = segment_streams
    # Materialise left before right, after both transcribe calls were made.
    return list(left_stream), list(right_stream)
def generate_fase_1(audio_path, model_version, civil_channel):
    """Run phase-1 ("Fase 1") ASR inference on ``audio_path`` and return the text.

    Two back ends are supported:

    * ``model_version == "v2_fast"``: the input is split into stereo channels,
      both channels are transcribed with the cached faster-whisper model, and
      the per-channel results are merged by ``post_process_transcripts``.
    * any other value: the audio is formatted with ``format_audio``,
      transcribed with the cached v1 transformers pipeline, and cleaned up by
      ``post_process_transcription``.

    Args:
        audio_path: Path of the input audio, passed to
            ``split_input_stereo_channels`` or ``format_audio``.
        model_version: ``"v2_fast"`` selects the faster-whisper path; every
            other value falls back to the v1 pipeline.
        civil_channel: Forwarded to ``post_process_transcripts``; only used on
            the ``v2_fast`` path.

    Returns:
        The post-processed transcription (first item returned by
        ``post_process_transcripts`` for ``v2_fast``, otherwise the result of
        ``post_process_transcription``).
    """
    DEVICE, COMPUTE_TYPE = get_settings()
    debug_print(f"[Fase1] Starting inference with model version: {model_version}")

    if model_version == "v2_fast":
        asr_model = get_asr_model_v2(DEVICE, COMPUTE_TYPE)
        # Read the compute type back from the loaded model instead of trusting
        # the requested one (presumably they can differ — TODO confirm).
        actual_compute_type = asr_model.model.compute_type
        debug_print(f"[SETTINGS] Device: {DEVICE}, Compute type: {actual_compute_type}")

        # Presumably writes the per-channel temp files that are removed below.
        split_input_stereo_channels(audio_path)
        try:
            left_waveform, right_waveform = process_waveforms(DEVICE, actual_compute_type)
            debug_print(f"[SETTINGS] Civil channel: {civil_channel}")
            left_result, right_result = transcribe_faster_asr(left_waveform, right_waveform, asr_model)
            text, _ = post_process_transcripts(left_result, right_result, civil_channel)
        finally:
            # Always remove the channel temp files once the split succeeded,
            # even when loading, transcription or post-processing fails.
            cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    else:
        actual_compute_type = "float32"  # HF pipeline safe default
        debug_print(f"[SETTINGS] Device: {DEVICE}, Compute type: {actual_compute_type}")
        asr_model = get_asr_model_v1(DEVICE)
        audio = format_audio(audio_path, actual_compute_type, DEVICE)
        result = transcribe_asr(audio, asr_model)
        text = post_process_transcription(result)

    return text