asr-inference / whisper_cs_fase_2.py
AbirMessaoudi's picture
fase_1, fase_2 releases (#46)
1619dcb verified
raw
history blame
3.63 kB
from faster_whisper import WhisperModel
from transformers import pipeline
import os
from settings import MODEL_PATH_AGE_GENDER, MODEL_PATH_METEO, MODEL_PATH_V2_FAST, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH
from audio_utils import debug_print, get_settings, split_input_stereo_channels, process_waveforms, post_process_transcripts, post_merge_consecutive_segments_from_text, cleanup_temp_files
from shout_detector import shout
from silence_detector import silence
from meteo_detector import classify_meteo_event
from age_gender_detector import age_gender, WavLMWrapper
# Hugging Face auth token, read from the environment (None if unset).
hf_token = os.environ.get("HF_AUTH_TOKEN")

# Lazily-initialized process-wide model singletons; populated on first use
# by the corresponding get_*_model() helpers below.
ASR_MODEL = None
AGE_GENDER_MODEL = None
METEO_MODEL = None
def get_asr_model(DEVICE, COMPUTE_TYPE):
    """Return the process-wide faster-whisper ASR model, loading it on first use.

    Subsequent calls reuse the cached ASR_MODEL singleton; DEVICE and
    COMPUTE_TYPE only take effect on the very first call.
    """
    global ASR_MODEL
    if ASR_MODEL is not None:
        return ASR_MODEL
    debug_print("[MODEL LOADING]Loading ASR model...")
    ASR_MODEL = WhisperModel(
        MODEL_PATH_V2_FAST, device=DEVICE, compute_type=COMPUTE_TYPE
    )
    debug_print("[MODEL LOADING]ASR model loaded")
    return ASR_MODEL
def get_age_gender_model(DEVICE):
    """Return the process-wide age/gender model, lazily loaded in eval mode.

    DEVICE only takes effect on the very first call; afterwards the cached
    AGE_GENDER_MODEL singleton is returned unchanged.
    """
    global AGE_GENDER_MODEL
    if AGE_GENDER_MODEL is not None:
        return AGE_GENDER_MODEL
    debug_print("[MODEL LOADING]Loading Age/Gender model...")
    model = WavLMWrapper.from_pretrained(MODEL_PATH_AGE_GENDER).to(DEVICE)
    model.eval()  # inference only — disable dropout/batch-norm updates
    AGE_GENDER_MODEL = model
    debug_print("[MODEL LOADING]Age/Gender model loaded")
    return AGE_GENDER_MODEL
def get_meteo_model(DEVICE):
    """Return the process-wide meteo text-classification pipeline, lazily loaded.

    DEVICE only takes effect on the very first call; afterwards the cached
    METEO_MODEL singleton is returned unchanged.
    """
    global METEO_MODEL
    if METEO_MODEL is not None:
        return METEO_MODEL
    debug_print("[MODEL LOADING]Loading Meteo model...")
    # transformers pipeline convention: CUDA device index 0, or -1 for CPU.
    device_index = 0 if DEVICE == "cuda" else -1
    METEO_MODEL = pipeline(
        task="text-classification",
        model=MODEL_PATH_METEO,
        tokenizer=MODEL_PATH_METEO,
        top_k=None,  # return scores for all labels, not just the best one
        device=device_index,
        token=hf_token,
    )
    debug_print("[MODEL LOADING]Meteo model loaded")
    return METEO_MODEL
def transcribe_faster_asr(left_waveform, right_waveform, model):
    """Transcribe both stereo channels with the given faster-whisper model.

    Returns a (left_segments, right_segments) pair; each result is the
    model's lazy segment iterator materialized into a list.
    """
    channel_results = []
    for waveform in (left_waveform, right_waveform):
        segments, _info = model.transcribe(waveform, beam_size=5, task="transcribe")
        channel_results.append(list(segments))
    return channel_results[0], channel_results[1]
def generate_fase_2(audio_path, model_version, civil_channel):
    """Run the full "fase 2" analysis pipeline on a stereo audio file.

    Pipeline: split stereo channels to temp files, transcribe both channels,
    detect silence/shouting, estimate the civil speaker's age and sex, and
    classify any meteorological event mentioned in the transcript.

    Args:
        audio_path: Path to the stereo input audio file.
        model_version: Unused in this function; kept for caller compatibility.
        civil_channel: "Left" or "Right" — which channel carries the civil
            (non-operator) speaker; anything other than "Left" selects Right.

    Returns:
        Tuple of (transcript text, sex, age description, silence_event,
        shout_event, meteo_event).
    """
    DEVICE, COMPUTE_TYPE = get_settings()
    asr_model = get_asr_model(DEVICE, COMPUTE_TYPE)
    age_gender_model = get_age_gender_model(DEVICE)
    meteo_model = get_meteo_model(DEVICE)
    # Read the compute type actually in use (the backend may differ from the
    # requested COMPUTE_TYPE, e.g. after a fallback).
    actual_compute_type = asr_model.model.compute_type
    debug_print(f"[SETTINGS] Device: {DEVICE}, Compute type: {actual_compute_type}")
    split_input_stereo_channels(audio_path)
    try:
        left_waveform, right_waveform = process_waveforms(DEVICE, actual_compute_type)
        debug_print(f"[SETTINGS] Civil channel: {civil_channel}")
        left_result, right_result = transcribe_faster_asr(left_waveform, right_waveform, asr_model)
        silence_event = silence(audio_path)
        # Select the waveform and per-channel temp file for the civil speaker.
        civil_waveform = left_waveform if civil_channel == "Left" else right_waveform
        civil_path = LEFT_CHANNEL_TEMP_PATH if civil_channel == "Left" else RIGHT_CHANNEL_TEMP_PATH
        shout_event = shout(civil_path)
        age, sex, age_group = age_gender(civil_waveform, age_gender_model, DEVICE)
        age = f"{age_group} (aprox. {age} años)"
        clean_output_asr, clean_output_meteo = post_process_transcripts(left_result, right_result, civil_channel)
        text = '\n' + clean_output_asr
        meteo_event = classify_meteo_event(clean_output_meteo, meteo_model, threshold=0.0)
    finally:
        # BUG FIX: the original cleaned up only on success, leaking the
        # per-channel temp files whenever any pipeline step raised.
        cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    return text, sex, age, silence_event, shout_event, meteo_event